636 lines
22 KiB
C++
636 lines
22 KiB
C++
#include "NetRequest.hpp"
|
||
|
||
#include <mutex>
|
||
#include <condition_variable>
|
||
#include <unordered_map>
|
||
#include <fstream>
|
||
#include <sstream>
|
||
#include <atomic>
|
||
|
||
namespace ntq
|
||
{
|
||
namespace
|
||
{
|
||
// Joins a base path and a request path into an absolute request target.
// Exactly one '/' separates the two parts and the result always starts with '/'.
// When both parts are empty the result is "/" (an empty string is not a valid
// HTTP request target).
static std::string joinPath(const std::string &base, const std::string &path)
{
    if (base.empty() && path.empty()) return "/";
    if (base.empty()) return path[0] == '/' ? path : "/" + path;

    const std::string prefix = base[0] == '/' ? base : "/" + base;
    if (path.empty()) return prefix;

    const bool prefix_slash = prefix.back() == '/';
    const bool path_slash = path[0] == '/';
    if (prefix_slash && path_slash) return prefix + path.substr(1);   // avoid "//"
    if (!prefix_slash && !path_slash) return prefix + '/' + path;     // add the separator
    return prefix + path;
}
|
||
|
||
static std::string paramsToQuery(const httplib::Params ¶ms)
|
||
{
|
||
if (params.empty()) return {};
|
||
std::string s;
|
||
bool first = true;
|
||
for (auto &kv : params)
|
||
{
|
||
if (!first) s += '&';
|
||
first = false;
|
||
s += kv.first;
|
||
s += '=';
|
||
s += kv.second;
|
||
}
|
||
return s;
|
||
}
|
||
|
||
// Returns the headers of `a` overridden by `b`.
// Every header name present in `b` replaces all of `a`'s values for that name.
// All of `b`'s values are kept, including repeated names (httplib::Headers is a
// multimap).
static httplib::Headers mergeHeaders(const httplib::Headers &a, const httplib::Headers &b)
{
    httplib::Headers merged = a;
    // Drop every overridden name first. Erasing an already-erased name is a no-op,
    // so repeated names in `b` no longer wipe out each other.
    for (const auto &kv : b) merged.erase(kv.first);
    // Then add all of b's entries; multimap range insertion keeps the relative
    // order of entries with equal keys.
    merged.insert(b.begin(), b.end());
    return merged;
}
|
||
}
|
||
|
||
// Counting gate that bounds how many callers may be inside a section at once.
// Callers acquire a slot with a scoped ConcurrencyGate::Guard. The constructor
// blocks until active < limit; the destructor releases the slot.
class ConcurrencyGate
{
public:
    // A limit of 0 would block every caller forever, so it is clamped to 1
    // (the same rule set_limit() applies).
    explicit ConcurrencyGate(std::size_t limit) : limit_(limit > 0 ? limit : 1), active_(0) {}

    ConcurrencyGate(const ConcurrencyGate &) = delete;
    ConcurrencyGate &operator=(const ConcurrencyGate &) = delete;

    // Changes the limit (0 is treated as 1). Already-admitted callers are not
    // affected. Waiters are woken because a larger limit may admit several of them.
    void set_limit(std::size_t limit)
    {
        std::lock_guard<std::mutex> lk(mtx_);
        limit_ = limit > 0 ? limit : 1;
        cv_.notify_all();
    }

    // Scoped slot holder: enters the gate on construction and leaves it on destruction.
    struct Guard
    {
        ConcurrencyGate &g;
        explicit Guard(ConcurrencyGate &gate) : g(gate) { g.enter(); }
        ~Guard() { g.leave(); }

        // A copy would call leave() twice for a single enter().
        Guard(const Guard &) = delete;
        Guard &operator=(const Guard &) = delete;
    };

private:
    friend struct Guard;

    // Blocks until a slot is free, then takes it.
    void enter()
    {
        std::unique_lock<std::mutex> lk(mtx_);
        cv_.wait(lk, [&] { return active_ < limit_; });
        ++active_;
    }

    // Releases a slot and wakes one waiter, since exactly one slot became free.
    void leave()
    {
        std::lock_guard<std::mutex> lk(mtx_);
        if (active_ > 0) --active_;
        cv_.notify_one();
    }

    std::size_t limit_;   // maximum number of concurrent holders (always >= 1)
    std::size_t active_;  // current number of holders
    std::mutex mtx_;      // guards limit_ and active_
    std::condition_variable cv_;
};
|
||
|
||
struct NetRequest::Impl
|
||
{
|
||
RequestOptions opts;
|
||
LogCallback logger;
|
||
Stats stats;
|
||
|
||
// 并发控制
|
||
ConcurrencyGate gate{4};
|
||
|
||
// 缓存
|
||
struct CacheEntry
|
||
{
|
||
HttpResponse resp;
|
||
std::chrono::steady_clock::time_point expiry;
|
||
};
|
||
bool cache_enabled = false;
|
||
std::chrono::milliseconds cache_ttl{0};
|
||
std::unordered_map<std::string, CacheEntry> cache;
|
||
std::mutex cache_mtx;
|
||
|
||
void log(const std::string &msg)
|
||
{
|
||
if (logger) logger(msg);
|
||
}
|
||
|
||
template <typename ClientT>
|
||
void apply_client_options(ClientT &cli)
|
||
{
|
||
const time_t c_sec = static_cast<time_t>(opts.connect_timeout_ms / 1000);
|
||
const time_t c_usec = static_cast<time_t>((opts.connect_timeout_ms % 1000) * 1000);
|
||
const time_t r_sec = static_cast<time_t>(opts.read_timeout_ms / 1000);
|
||
const time_t r_usec = static_cast<time_t>((opts.read_timeout_ms % 1000) * 1000);
|
||
const time_t w_sec = static_cast<time_t>(opts.write_timeout_ms / 1000);
|
||
const time_t w_usec = static_cast<time_t>((opts.write_timeout_ms % 1000) * 1000);
|
||
cli.set_connection_timeout(c_sec, c_usec);
|
||
cli.set_read_timeout(r_sec, r_usec);
|
||
cli.set_write_timeout(w_sec, w_usec);
|
||
cli.set_keep_alive(opts.keep_alive);
|
||
}
|
||
|
||
std::string build_full_path(const std::string &path) const
|
||
{
|
||
return joinPath(opts.base_path, path);
|
||
}
|
||
|
||
std::string cache_key(const std::string &path, const httplib::Params ¶ms, const httplib::Headers &headers)
|
||
{
|
||
std::ostringstream oss;
|
||
oss << opts.scheme << "://" << opts.host << ':' << opts.port << build_full_path(path);
|
||
if (!params.empty()) oss << '?' << paramsToQuery(params);
|
||
for (auto &kv : headers) oss << '|' << kv.first << '=' << kv.second;
|
||
return oss.str();
|
||
}
|
||
|
||
void record_latency(double ms)
|
||
{
|
||
stats.last_latency_ms = ms;
|
||
const double alpha = 0.2;
|
||
if (stats.avg_latency_ms <= 0.0) stats.avg_latency_ms = ms;
|
||
else stats.avg_latency_ms = alpha * ms + (1.0 - alpha) * stats.avg_latency_ms;
|
||
}
|
||
|
||
static ErrorCode map_error()
|
||
{
|
||
// 简化:无法区分具体错误码,统一归为 Network
|
||
return ErrorCode::Network;
|
||
}
|
||
};
|
||
|
||
// Copies the options and fills in a default port for the scheme:
// https with port 0 or 80 -> 443, http with port 0 -> 80.
// NOTE(review): the 80 -> 443 rewrite presumably exists because 80 is the
// RequestOptions default port; confirm against NetRequest.hpp.
NetRequest::NetRequest(const RequestOptions &options)
    : impl_(new Impl)
{
    impl_->opts = options;
    RequestOptions &o = impl_->opts;
    if (o.scheme == "https")
    {
        // Previously an https request with port 0 stayed on port 0.
        if (o.port == 0 || o.port == 80) o.port = 443;
    }
    else if (o.scheme == "http" && o.port == 0)
    {
        o.port = 80;
    }
}
|
||
|
||
// Destroys the implementation object allocated by the constructor.
NetRequest::~NetRequest() { delete impl_; }
|
||
|
||
// Installs the callback used for diagnostic messages. A callback that tests
// false (e.g. an empty one) silences logging.
void NetRequest::setLogger(LogCallback logger) { impl_->logger = std::move(logger); }
|
||
|
||
void NetRequest::setMaxConcurrentRequests(size_t n)
|
||
{
|
||
impl_->gate.set_limit(n > 0 ? n : 1);
|
||
}
|
||
|
||
void NetRequest::enableCache(std::chrono::milliseconds ttl)
|
||
{
|
||
impl_->cache_enabled = true;
|
||
impl_->cache_ttl = ttl.count() > 0 ? ttl : std::chrono::milliseconds(1000);
|
||
}
|
||
|
||
void NetRequest::disableCache()
|
||
{
|
||
impl_->cache_enabled = false;
|
||
std::lock_guard<std::mutex> lk(impl_->cache_mtx);
|
||
impl_->cache.clear();
|
||
}
|
||
|
||
// Performs a GET request. When caching is enabled, the response may be served
// from the cache and stored into it.
//   path    request path, joined onto opts.base_path
//   query   query parameters (the plain overload is used when empty)
//   headers per-request headers; they override opts.default_headers by name
//   err     optional out-parameter; ErrorCode::None on success
// Returns the response for any HTTP status. Returns nullopt when the transport
// fails or HTTPS is requested without OpenSSL support.
ntq::optional<HttpResponse> NetRequest::Get(const std::string &path,
                                            const httplib::Params &query,
                                            const httplib::Headers &headers,
                                            ErrorCode *err)
{
    ConcurrencyGate::Guard guard(impl_->gate);
    impl_->stats.total_requests++;
    auto start = std::chrono::steady_clock::now();

    // Merge once: the same header set feeds both the cache key and the request.
    const auto merged_headers = mergeHeaders(impl_->opts.default_headers, headers);
    std::string key;

    if (impl_->cache_enabled)
    {
        key = impl_->cache_key(path, query, merged_headers);
        std::lock_guard<std::mutex> lk(impl_->cache_mtx);
        auto it = impl_->cache.find(key);
        if (it != impl_->cache.end() && std::chrono::steady_clock::now() < it->second.expiry)
        {
            if (err) *err = ErrorCode::None;
            auto resp = it->second.resp;
            resp.from_cache = true;
            return resp;
        }
    }

    ntq::optional<HttpResponse> result;
    ErrorCode local_err = ErrorCode::None;
    const auto full_path = impl_->build_full_path(path);

    if (impl_->opts.scheme == "https")
    {
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
        httplib::SSLClient cli(impl_->opts.host.c_str(), impl_->opts.port);
        impl_->apply_client_options(cli);
        auto res = query.empty() ? cli.Get(full_path.c_str(), merged_headers)
                                 : cli.Get(full_path.c_str(), query, merged_headers);
        if (res)
        {
            HttpResponse r;
            r.status = res->status;
            r.body = res->body;
            r.headers = res->headers;
            r.from_cache = false;
            result = std::move(r);
        }
        else
        {
            local_err = Impl::map_error();
        }
#else
        impl_->log("HTTPS requested but OpenSSL is not enabled; falling back to error.");
        local_err = ErrorCode::SSL;
#endif
    }
    else
    {
        httplib::Client cli(impl_->opts.host.c_str(), impl_->opts.port);
        impl_->apply_client_options(cli);
        auto res = query.empty() ? cli.Get(full_path.c_str(), merged_headers)
                                 : cli.Get(full_path.c_str(), query, merged_headers);
        if (res)
        {
            HttpResponse r;
            r.status = res->status;
            r.body = res->body;
            r.headers = res->headers;
            r.from_cache = false;
            result = std::move(r);
        }
        else
        {
            local_err = Impl::map_error();
        }
    }

    auto end = std::chrono::steady_clock::now();
    impl_->record_latency(std::chrono::duration<double, std::milli>(end - start).count());

    if (!result.has_value())
    {
        impl_->stats.total_errors++;
        if (err) *err = local_err;
        return ntq::nullopt;
    }

    const HttpResponse &fresh = *result;
    // Cache only 2xx responses so transient failures (e.g. 500/503) are not
    // replayed to every caller for the whole TTL.
    const bool cacheable = fresh.status >= 200 && fresh.status < 300;
    if (cacheable && impl_->cache_enabled)
    {
        // The key is still empty if caching was switched on mid-request.
        if (key.empty()) key = impl_->cache_key(path, query, merged_headers);
        std::lock_guard<std::mutex> lk(impl_->cache_mtx);
        impl_->cache[key] = Impl::CacheEntry{fresh, std::chrono::steady_clock::now() + impl_->cache_ttl};
    }

    if (err) *err = ErrorCode::None;
    return result;
}
|
||
|
||
// POSTs `json` with Content-Type application/json.
//   headers per-request headers; they override opts.default_headers by name
//   err     optional out-parameter; ErrorCode::None on success
// Returns the response for any HTTP status. Returns nullopt when the transport
// fails or HTTPS is requested without OpenSSL support.
ntq::optional<HttpResponse> NetRequest::PostJson(const std::string &path,
                                                 const std::string &json,
                                                 const httplib::Headers &headers,
                                                 ErrorCode *err)
{
    ConcurrencyGate::Guard guard(impl_->gate);
    impl_->stats.total_requests++;
    auto start = std::chrono::steady_clock::now();

    ntq::optional<HttpResponse> result;
    ErrorCode local_err = ErrorCode::None;

    const auto full_path = impl_->build_full_path(path);
    const auto merged_headers = mergeHeaders(impl_->opts.default_headers, headers);

    if (impl_->opts.scheme == "https")
    {
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
        httplib::SSLClient cli(impl_->opts.host.c_str(), impl_->opts.port);
        impl_->apply_client_options(cli);
        auto res = cli.Post(full_path.c_str(), merged_headers, json, "application/json");
        if (res)
        {
            // Named fields rather than positional aggregate init, matching Get().
            HttpResponse r;
            r.status = res->status;
            r.body = res->body;
            r.headers = res->headers;
            r.from_cache = false;
            result = std::move(r);
        }
        else
        {
            local_err = Impl::map_error();
        }
#else
        // Report the misconfiguration, as Get() does, instead of failing silently.
        impl_->log("HTTPS requested but OpenSSL is not enabled; falling back to error.");
        local_err = ErrorCode::SSL;
#endif
    }
    else
    {
        httplib::Client cli(impl_->opts.host.c_str(), impl_->opts.port);
        impl_->apply_client_options(cli);
        auto res = cli.Post(full_path.c_str(), merged_headers, json, "application/json");
        if (res)
        {
            HttpResponse r;
            r.status = res->status;
            r.body = res->body;
            r.headers = res->headers;
            r.from_cache = false;
            result = std::move(r);
        }
        else
        {
            local_err = Impl::map_error();
        }
    }

    auto end = std::chrono::steady_clock::now();
    impl_->record_latency(std::chrono::duration<double, std::milli>(end - start).count());
    if (!result)
    {
        impl_->stats.total_errors++;
        if (err) *err = local_err;
        return ntq::nullopt;
    }
    if (err) *err = ErrorCode::None;
    return result;
}
|
||
|
||
// POSTs `form` as form parameters via httplib's Params overload.
//   headers per-request headers; they override opts.default_headers by name
//   err     optional out-parameter; ErrorCode::None on success
// Returns the response for any HTTP status. Returns nullopt when the transport
// fails or HTTPS is requested without OpenSSL support.
ntq::optional<HttpResponse> NetRequest::PostForm(const std::string &path,
                                                 const httplib::Params &form,
                                                 const httplib::Headers &headers,
                                                 ErrorCode *err)
{
    ConcurrencyGate::Guard guard(impl_->gate);
    impl_->stats.total_requests++;
    auto start = std::chrono::steady_clock::now();

    ntq::optional<HttpResponse> result;
    ErrorCode local_err = ErrorCode::None;

    const auto full_path = impl_->build_full_path(path);
    const auto merged_headers = mergeHeaders(impl_->opts.default_headers, headers);

    if (impl_->opts.scheme == "https")
    {
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
        httplib::SSLClient cli(impl_->opts.host.c_str(), impl_->opts.port);
        impl_->apply_client_options(cli);
        auto res = cli.Post(full_path.c_str(), merged_headers, form);
        if (res)
        {
            // Named fields rather than positional aggregate init, matching Get().
            HttpResponse r;
            r.status = res->status;
            r.body = res->body;
            r.headers = res->headers;
            r.from_cache = false;
            result = std::move(r);
        }
        else
        {
            local_err = Impl::map_error();
        }
#else
        // Report the misconfiguration, as Get() does, instead of failing silently.
        impl_->log("HTTPS requested but OpenSSL is not enabled; falling back to error.");
        local_err = ErrorCode::SSL;
#endif
    }
    else
    {
        httplib::Client cli(impl_->opts.host.c_str(), impl_->opts.port);
        impl_->apply_client_options(cli);
        auto res = cli.Post(full_path.c_str(), merged_headers, form);
        if (res)
        {
            HttpResponse r;
            r.status = res->status;
            r.body = res->body;
            r.headers = res->headers;
            r.from_cache = false;
            result = std::move(r);
        }
        else
        {
            local_err = Impl::map_error();
        }
    }

    auto end = std::chrono::steady_clock::now();
    impl_->record_latency(std::chrono::duration<double, std::milli>(end - start).count());
    if (!result)
    {
        impl_->stats.total_errors++;
        if (err) *err = local_err;
        return ntq::nullopt;
    }
    if (err) *err = ErrorCode::None;
    return result;
}
|
||
|
||
std::future<ntq::optional<HttpResponse>> NetRequest::GetAsync(const std::string &path,
|
||
const httplib::Params &query,
|
||
const httplib::Headers &headers,
|
||
ErrorCode *err)
|
||
{
|
||
return std::async(std::launch::async, [this, path, query, headers, err]() mutable {
|
||
ErrorCode local;
|
||
auto r = Get(path, query, headers, &local);
|
||
if (err) *err = local;
|
||
return r;
|
||
});
|
||
}
|
||
|
||
std::future<ntq::optional<HttpResponse>> NetRequest::PostJsonAsync(const std::string &path,
|
||
const std::string &json,
|
||
const httplib::Headers &headers,
|
||
ErrorCode *err)
|
||
{
|
||
return std::async(std::launch::async, [this, path, json, headers, err]() mutable {
|
||
ErrorCode local;
|
||
auto r = PostJson(path, json, headers, &local);
|
||
if (err) *err = local;
|
||
return r;
|
||
});
|
||
}
|
||
|
||
std::future<ntq::optional<HttpResponse>> NetRequest::PostFormAsync(const std::string &path,
|
||
const httplib::Params &form,
|
||
const httplib::Headers &headers,
|
||
ErrorCode *err)
|
||
{
|
||
return std::async(std::launch::async, [this, path, form, headers, err]() mutable {
|
||
ErrorCode local;
|
||
auto r = PostForm(path, form, headers, &local);
|
||
if (err) *err = local;
|
||
return r;
|
||
});
|
||
}
|
||
|
||
bool NetRequest::DownloadToFile(const std::string &path,
|
||
const std::string &local_file,
|
||
const httplib::Headers &headers,
|
||
bool resume,
|
||
size_t /*chunk_size*/,
|
||
ErrorCode *err)
|
||
{
|
||
ConcurrencyGate::Guard guard(impl_->gate);
|
||
impl_->stats.total_requests++;
|
||
auto start = std::chrono::steady_clock::now();
|
||
|
||
std::ios_base::openmode mode = std::ios::binary | std::ios::out;
|
||
size_t offset = 0;
|
||
if (resume)
|
||
{
|
||
std::ifstream in(local_file, std::ios::binary | std::ios::ate);
|
||
if (in)
|
||
{
|
||
offset = static_cast<size_t>(in.tellg());
|
||
}
|
||
mode |= std::ios::app;
|
||
}
|
||
else
|
||
{
|
||
mode |= std::ios::trunc;
|
||
}
|
||
|
||
std::ofstream out(local_file, mode);
|
||
if (!out)
|
||
{
|
||
if (err) *err = ErrorCode::IOError;
|
||
impl_->stats.total_errors++;
|
||
return false;
|
||
}
|
||
|
||
auto merged_headers = mergeHeaders(impl_->opts.default_headers, headers);
|
||
if (resume && offset > 0)
|
||
{
|
||
merged_headers.emplace("Range", "bytes=" + std::to_string(offset) + "-");
|
||
}
|
||
|
||
const auto full_path = impl_->build_full_path(path);
|
||
int status_code = 0;
|
||
ErrorCode local_err = ErrorCode::None;
|
||
|
||
auto content_receiver = [&](const char *data, size_t data_length) {
|
||
out.write(data, static_cast<std::streamsize>(data_length));
|
||
return static_cast<bool>(out);
|
||
};
|
||
|
||
bool ok = false;
|
||
if (impl_->opts.scheme == "https")
|
||
{
|
||
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
||
httplib::SSLClient cli(impl_->opts.host.c_str(), impl_->opts.port);
|
||
impl_->apply_client_options(cli);
|
||
auto res = cli.Get(full_path.c_str(), merged_headers, content_receiver);
|
||
if (res)
|
||
{
|
||
status_code = res->status;
|
||
ok = (status_code == 200 || status_code == 206);
|
||
}
|
||
else
|
||
{
|
||
local_err = Impl::map_error();
|
||
}
|
||
#else
|
||
local_err = ErrorCode::SSL;
|
||
#endif
|
||
}
|
||
else
|
||
{
|
||
httplib::Client cli(impl_->opts.host.c_str(), impl_->opts.port);
|
||
impl_->apply_client_options(cli);
|
||
auto res = cli.Get(full_path.c_str(), merged_headers, content_receiver);
|
||
if (res)
|
||
{
|
||
status_code = res->status;
|
||
ok = (status_code == 200 || status_code == 206);
|
||
}
|
||
else
|
||
{
|
||
local_err = Impl::map_error();
|
||
}
|
||
}
|
||
|
||
out.close();
|
||
auto end = std::chrono::steady_clock::now();
|
||
impl_->record_latency(std::chrono::duration<double, std::milli>(end - start).count());
|
||
|
||
if (!ok)
|
||
{
|
||
impl_->stats.total_errors++;
|
||
if (err) *err = local_err;
|
||
return false;
|
||
}
|
||
if (err) *err = ErrorCode::None;
|
||
return true;
|
||
}
|
||
|
||
// Returns a copy of the accumulated request statistics.
NetRequest::Stats NetRequest::getStats() const
{
    const Stats snapshot = impl_->stats;
    return snapshot;
}
|
||
|
||
// ------------------------- Quick helpers -------------------------
|
||
namespace {
|
||
// Result of parse_url().
struct ParsedURL {
    std::string scheme;          // lower-cased; "http" or "https" when ok
    std::string host;            // host name, or IPv6 literal without brackets
    int port = 0;                // explicit port, or the scheme default (80/443)
    std::string path_and_query;  // starts with '/', fragment removed
    bool ok = false;             // true when the URL was accepted
};

// Minimal parser for absolute URLs of the form
//   scheme://host[:port][/path][?query][#fragment]
// where host may be a bracketed IPv6 literal ("[::1]").
// Rejects (ok == false):
//   - schemes other than http/https (case-insensitive)
//   - an empty host
//   - a non-numeric port, or a port outside 1..65535
static ParsedURL parse_url(const std::string &url)
{
    ParsedURL p;

    const auto pos_scheme = url.find("://");
    if (pos_scheme == std::string::npos) return p;

    p.scheme = url.substr(0, pos_scheme);
    for (auto &c : p.scheme)
        if (c >= 'A' && c <= 'Z') c = static_cast<char>(c - 'A' + 'a');
    if (p.scheme != "http" && p.scheme != "https") return p;

    // The authority ends at the first '/', '?' or '#', so "http://h?q=1" keeps
    // the query out of the host.
    const std::size_t auth_begin = pos_scheme + 3;
    std::size_t auth_end = url.find_first_of("/?#", auth_begin);
    if (auth_end == std::string::npos) auth_end = url.size();
    const std::string authority = url.substr(auth_begin, auth_end - auth_begin);

    std::string port_str;
    if (!authority.empty() && authority[0] == '[')
    {
        // IPv6 literal: "[addr]" or "[addr]:port".
        const auto close = authority.find(']');
        if (close == std::string::npos) return p;
        p.host = authority.substr(1, close - 1);
        if (close + 1 < authority.size())
        {
            if (authority[close + 1] != ':') return p;
            port_str = authority.substr(close + 2);
        }
    }
    else
    {
        const auto colon = authority.find(':');
        p.host = authority.substr(0, colon);
        if (colon != std::string::npos) port_str = authority.substr(colon + 1);
    }

    if (port_str.empty())
    {
        p.port = (p.scheme == "https") ? 443 : 80;
    }
    else
    {
        long value = 0;
        for (char c : port_str)
        {
            if (c < '0' || c > '9') return p;  // not a decimal number
            value = value * 10 + (c - '0');
            if (value > 65535) return p;       // outside the TCP port range
        }
        if (value == 0) return p;
        p.port = static_cast<int>(value);
    }

    // Path and query; the fragment is client-side only and never sent.
    std::string rest = url.substr(auth_end);
    const auto hash = rest.find('#');
    if (hash != std::string::npos) rest.erase(hash);
    if (rest.empty() || rest[0] != '/') rest.insert(0, 1, '/');
    p.path_and_query = rest;

    p.ok = !p.host.empty();
    return p;
}
|
||
}
|
||
|
||
// One-shot GET of an absolute URL using a temporary NetRequest.
// Sets *err to InvalidURL (and returns nullopt) when parse_url rejects the URL;
// otherwise behaves like Get().
ntq::optional<HttpResponse> NetRequest::QuickGet(const std::string &url,
                                                 const httplib::Headers &headers,
                                                 ErrorCode *err)
{
    auto p = parse_url(url);
    if (!p.ok)
    {
        if (err) *err = ErrorCode::InvalidURL;
        return ntq::nullopt;  // same optional flavour as the rest of the file
    }
    RequestOptions opt;
    opt.scheme = p.scheme;
    opt.host = p.host;
    opt.port = p.port;
    NetRequest req(opt);
    // The constructor rewrites https + port 80 to 443. A port taken from the
    // URL is explicit, so restore it.
    req.impl_->opts.port = p.port;
    return req.Get(p.path_and_query, {}, headers, err);
}
|
||
|
||
// One-shot JSON POST to an absolute URL using a temporary NetRequest.
// Sets *err to InvalidURL (and returns nullopt) when parse_url rejects the URL;
// otherwise behaves like PostJson().
ntq::optional<HttpResponse> NetRequest::QuickPostJson(const std::string &url,
                                                      const std::string &json,
                                                      const httplib::Headers &headers,
                                                      ErrorCode *err)
{
    auto p = parse_url(url);
    if (!p.ok)
    {
        if (err) *err = ErrorCode::InvalidURL;
        return ntq::nullopt;  // same optional flavour as the rest of the file
    }
    RequestOptions opt;
    opt.scheme = p.scheme;
    opt.host = p.host;
    opt.port = p.port;
    NetRequest req(opt);
    // The constructor rewrites https + port 80 to 443. A port taken from the
    // URL is explicit, so restore it.
    req.impl_->opts.port = p.port;
    return req.PostJson(p.path_and_query, json, headers, err);
}
|
||
|
||
// One-shot form POST to an absolute URL using a temporary NetRequest.
// Sets *err to InvalidURL (and returns nullopt) when parse_url rejects the URL;
// otherwise behaves like PostForm().
ntq::optional<HttpResponse> NetRequest::QuickPostForm(const std::string &url,
                                                      const httplib::Params &form,
                                                      const httplib::Headers &headers,
                                                      ErrorCode *err)
{
    auto p = parse_url(url);
    if (!p.ok)
    {
        if (err) *err = ErrorCode::InvalidURL;
        return ntq::nullopt;  // same optional flavour as the rest of the file
    }
    RequestOptions opt;
    opt.scheme = p.scheme;
    opt.host = p.host;
    opt.port = p.port;
    NetRequest req(opt);
    // The constructor rewrites https + port 80 to 443. A port taken from the
    // URL is explicit, so restore it.
    req.impl_->opts.port = p.port;
    return req.PostForm(p.path_and_query, form, headers, err);
}
|
||
}
|