Files
RkApp/Identification/NetraLib/src/NetRequest.cpp
2025-11-21 14:27:36 +08:00

636 lines
22 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#include "NetRequest.hpp"
#include <mutex>
#include <condition_variable>
#include <unordered_map>
#include <fstream>
#include <sstream>
#include <atomic>
namespace ntq
{
namespace
{
// Concatenates a base path and a request path into a single path.
//
// - A non-empty result always starts with '/'.
// - Exactly one '/' separates `base` from `path`, whatever slashes the two
//   pieces carry at the junction.
// - If both inputs are empty the result is empty.
static std::string joinPath(const std::string &base, const std::string &path)
{
    if (base.empty())
    {
        // Only the request path matters; give it a leading slash if needed.
        if (path.empty() || path[0] == '/') return path;
        return "/" + path;
    }

    // Start from the base, normalised to begin with '/'.
    std::string joined;
    if (base[0] != '/') joined += '/';
    joined += base;
    if (path.empty()) return joined;

    const bool base_trailing_slash = base.back() == '/';
    const bool path_leading_slash = path.front() == '/';
    if (base_trailing_slash && path_leading_slash)
    {
        // Both sides bring a slash: drop the one from `path`.
        joined.append(path, 1, std::string::npos);
    }
    else
    {
        // Neither side brings a slash: insert one.
        if (!base_trailing_slash && !path_leading_slash) joined += '/';
        joined += path;
    }
    return joined;
}
// Flattens request parameters into "k1=v1&k2=v2" form, in container order.
// Keys and values are copied verbatim (no percent-encoding is applied).
// Returns an empty string when there are no parameters.
static std::string paramsToQuery(const httplib::Params &params)
{
    std::string query;
    for (auto it = params.begin(); it != params.end(); ++it)
    {
        // Separator goes before every pair except the first one.
        if (it != params.begin()) query += '&';
        query += it->first;
        query += '=';
        query += it->second;
    }
    return query;
}
// Returns the headers of `a` overridden by the headers of `b`.
//
// Every header name that occurs in `b` replaces all entries of that name
// coming from `a`. Names that occur several times in `b` keep all of their
// values: the erase and the insert run as two separate passes, because
// erasing inside the insert loop would also remove values of the same name
// that were just inserted from `b`.
static httplib::Headers mergeHeaders(const httplib::Headers &a, const httplib::Headers &b)
{
    httplib::Headers merged = a;
    // Pass 1: drop everything from `a` that `b` overrides.
    for (const auto &kv : b)
    {
        merged.erase(kv.first);
    }
    // Pass 2: add all entries of `b`, repeated names included.
    for (const auto &kv : b)
    {
        merged.emplace(kv.first, kv.second);
    }
    return merged;
}
}
// Counting gate that bounds how many holders may be inside it at once.
//
// Callers do not use enter()/leave() directly; they create a Guard, which
// blocks in its constructor until a slot is free and gives the slot back in
// its destructor.
class ConcurrencyGate
{
public:
    // `limit` is the maximum number of simultaneous holders. A limit of 0 is
    // raised to 1 (same rule as set_limit); otherwise enter() could never
    // succeed and every Guard would block forever.
    explicit ConcurrencyGate(size_t limit) : limit_(limit > 0 ? limit : 1), active_(0) {}

    // Changes the limit (0 is raised to 1) and wakes all waiters so they can
    // re-check against the new value.
    void set_limit(size_t limit)
    {
        std::lock_guard<std::mutex> lk(mtx_);
        limit_ = limit > 0 ? limit : 1;
        cv_.notify_all();
    }

    // Scoped slot holder: acquires on construction, releases on destruction.
    struct Guard
    {
        ConcurrencyGate &g;
        explicit Guard(ConcurrencyGate &gate) : g(gate) { g.enter(); }
        ~Guard() { g.leave(); }
        // Non-copyable: a copy would release the same slot a second time.
        Guard(const Guard &) = delete;
        Guard &operator=(const Guard &) = delete;
    };

private:
    friend struct Guard;

    // Blocks until fewer than limit_ holders are active, then takes a slot.
    void enter()
    {
        std::unique_lock<std::mutex> lk(mtx_);
        cv_.wait(lk, [&]{ return active_ < limit_; });
        ++active_;
    }

    // Returns a slot and wakes one waiter.
    void leave()
    {
        std::lock_guard<std::mutex> lk(mtx_);
        if (active_ > 0) --active_;
        cv_.notify_one();
    }

    size_t limit_;   // maximum simultaneous holders, always >= 1
    size_t active_;  // current number of holders
    std::mutex mtx_; // protects limit_ and active_
    std::condition_variable cv_;
};
// Private state and helpers behind NetRequest.
struct NetRequest::Impl
{
    RequestOptions opts; // connection settings (scheme, host, port, timeouts, ...)
    LogCallback logger;  // optional log sink, see log()
    Stats stats;         // request counters and latency figures

    // Concurrency control: starts with a limit of 4.
    ConcurrencyGate gate{4};

    // Response cache.
    struct CacheEntry
    {
        HttpResponse resp;                            // stored response
        std::chrono::steady_clock::time_point expiry; // entry is valid strictly before this instant
    };
    bool cache_enabled = false;
    std::chrono::milliseconds cache_ttl{0};
    std::unordered_map<std::string, CacheEntry> cache;
    std::mutex cache_mtx; // guards `cache`

    // Forwards `msg` to the logger; does nothing when no logger is set.
    void log(const std::string &msg)
    {
        if (!logger) return;
        logger(msg);
    }

    // Copies the timeouts and the keep-alive flag from `opts` onto an httplib
    // client. Each millisecond timeout is split into whole seconds plus the
    // remaining microseconds, which is the form the client setters take.
    template <typename ClientT>
    void apply_client_options(ClientT &cli)
    {
        cli.set_connection_timeout(static_cast<time_t>(opts.connect_timeout_ms / 1000),
                                   static_cast<time_t>((opts.connect_timeout_ms % 1000) * 1000));
        cli.set_read_timeout(static_cast<time_t>(opts.read_timeout_ms / 1000),
                             static_cast<time_t>((opts.read_timeout_ms % 1000) * 1000));
        cli.set_write_timeout(static_cast<time_t>(opts.write_timeout_ms / 1000),
                              static_cast<time_t>((opts.write_timeout_ms % 1000) * 1000));
        cli.set_keep_alive(opts.keep_alive);
    }

    // Prefixes `path` with the configured base path.
    std::string build_full_path(const std::string &path) const
    {
        return joinPath(opts.base_path, path);
    }

    // Builds the cache key: "scheme://host:port/full/path[?query]" followed
    // by one "|name=value" segment per header.
    std::string cache_key(const std::string &path, const httplib::Params &params, const httplib::Headers &headers)
    {
        std::ostringstream key;
        key << opts.scheme << "://" << opts.host << ':' << opts.port << build_full_path(path);
        if (!params.empty())
        {
            key << '?' << paramsToQuery(params);
        }
        for (auto it = headers.begin(); it != headers.end(); ++it)
        {
            key << '|' << it->first << '=' << it->second;
        }
        return key.str();
    }

    // Stores the latest latency and folds it into an exponential moving
    // average (weight 0.2 for the new sample). A non-positive average is
    // treated as "no sample yet" and is seeded with `ms`.
    void record_latency(double ms)
    {
        stats.last_latency_ms = ms;
        if (stats.avg_latency_ms <= 0.0)
        {
            stats.avg_latency_ms = ms;
            return;
        }
        const double alpha = 0.2;
        stats.avg_latency_ms = alpha * ms + (1.0 - alpha) * stats.avg_latency_ms;
    }

    // Simplification: the concrete failure cause is not inspected, so every
    // transport failure is reported as ErrorCode::Network.
    static ErrorCode map_error()
    {
        return ErrorCode::Network;
    }
};
// Creates a client for the endpoint described by `options`.
//
// Port normalisation:
// - "https" with port 80 or 0 is switched to 443;
// - "http" with port 0 is switched to 80.
// Port 0 is handled for both schemes, so an https endpoint with an unset
// port no longer tries to connect to port 0.
NetRequest::NetRequest(const RequestOptions &options)
    : impl_(new Impl)
{
    impl_->opts = options;
    auto &port = impl_->opts.port;
    if (impl_->opts.scheme == "https")
    {
        if (port == 80 || port == 0) port = 443;
    }
    else if (impl_->opts.scheme == "http")
    {
        if (port == 0) port = 80;
    }
}
// Releases the Impl object allocated in the constructor.
// NOTE(review): impl_ is deleted manually here, so a copied NetRequest would
// delete the same Impl twice unless the header deletes or defines the copy
// operations — confirm against NetRequest.hpp.
NetRequest::~NetRequest()
{
delete impl_;
}
// Installs the callback that receives this client's log messages.
// Impl::log() only invokes the callback when it tests true, so passing an
// empty callback turns logging off.
// No lock is taken here; NOTE(review): presumably meant to be called before
// requests are issued from other threads — confirm with callers.
void NetRequest::setLogger(LogCallback logger)
{
impl_->logger = std::move(logger);
}
// Sets how many requests may run at the same time through this client.
// A value of 0 is treated as 1.
void NetRequest::setMaxConcurrentRequests(size_t n)
{
    const size_t effective_limit = (n == 0) ? 1 : n;
    impl_->gate.set_limit(effective_limit);
}
// Turns on response caching for Get().
//
// `ttl` is the lifetime of a cached response; a zero or negative value falls
// back to 1000 ms.
//
// The TTL is written while holding cache_mtx, the mutex under which Get()
// reads it when storing an entry, and it is written before the enable flag
// so that a request which sees the cache enabled also sees the new TTL.
void NetRequest::enableCache(std::chrono::milliseconds ttl)
{
    const std::chrono::milliseconds effective_ttl =
        ttl.count() > 0 ? ttl : std::chrono::milliseconds(1000);
    std::lock_guard<std::mutex> lk(impl_->cache_mtx);
    impl_->cache_ttl = effective_ttl;
    impl_->cache_enabled = true;
}
// Turns response caching off and discards every cached response.
void NetRequest::disableCache()
{
    // Stop new lookups/stores first, then empty the map under its mutex.
    impl_->cache_enabled = false;
    {
        std::lock_guard<std::mutex> cache_lock(impl_->cache_mtx);
        impl_->cache.clear();
    }
}
// Performs a blocking HTTP(S) GET.
//
// path    : request path, joined onto the configured base path.
// query   : query parameters; when empty the path is requested as-is.
// headers : per-request headers, overriding same-named default headers.
// err     : optional out-parameter; set to ErrorCode::None on success,
//           otherwise to the failure reason.
//
// Returns the response (any HTTP status counts as a response), or nullopt
// when no response was obtained.
//
// When caching is enabled, an unexpired cached response for the same
// URL + headers is returned with `from_cache == true` and no request is
// sent; a fresh response is stored with the configured TTL. An entry found
// expired during lookup is erased right away so that stale entries do not
// accumulate in the map.
ntq::optional<HttpResponse> NetRequest::Get(const std::string &path,
                                            const httplib::Params &query,
                                            const httplib::Headers &headers,
                                            ErrorCode *err)
{
    ConcurrencyGate::Guard guard(impl_->gate);
    impl_->stats.total_requests++;
    auto start = std::chrono::steady_clock::now();

    // Merged once and reused for the cache key and for the request itself.
    const auto merged_headers = mergeHeaders(impl_->opts.default_headers, headers);

    // Cache key; stays empty until it is first needed.
    std::string key;
    if (impl_->cache_enabled)
    {
        key = impl_->cache_key(path, query, merged_headers);
        std::lock_guard<std::mutex> lk(impl_->cache_mtx);
        auto it = impl_->cache.find(key);
        if (it != impl_->cache.end())
        {
            if (std::chrono::steady_clock::now() < it->second.expiry)
            {
                if (err) *err = ErrorCode::None;
                auto resp = it->second.resp;
                resp.from_cache = true;
                return resp;
            }
            // Expired: it can never be served again, so drop it now.
            impl_->cache.erase(it);
        }
    }

    ntq::optional<HttpResponse> result;
    ErrorCode local_err = ErrorCode::None;
    const auto full_path = impl_->build_full_path(path);
    if (impl_->opts.scheme == "https")
    {
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
        httplib::SSLClient cli(impl_->opts.host.c_str(), impl_->opts.port);
        impl_->apply_client_options(cli);
        auto res = query.empty() ? cli.Get(full_path.c_str(), merged_headers)
                                 : cli.Get(full_path.c_str(), query, merged_headers);
        if (res)
        {
            HttpResponse r;
            r.status = res->status;
            r.body = res->body;
            r.headers = res->headers;
            r.from_cache = false;
            result = r;
        }
        else
        {
            local_err = Impl::map_error();
        }
#else
        impl_->log("HTTPS requested but OpenSSL is not enabled; falling back to error.");
        local_err = ErrorCode::SSL;
#endif
    }
    else
    {
        httplib::Client cli(impl_->opts.host.c_str(), impl_->opts.port);
        impl_->apply_client_options(cli);
        auto res = query.empty() ? cli.Get(full_path.c_str(), merged_headers)
                                 : cli.Get(full_path.c_str(), query, merged_headers);
        if (res)
        {
            HttpResponse r;
            r.status = res->status;
            r.body = res->body;
            r.headers = res->headers;
            r.from_cache = false;
            result = r;
        }
        else
        {
            local_err = Impl::map_error();
        }
    }

    auto end = std::chrono::steady_clock::now();
    impl_->record_latency(std::chrono::duration<double, std::milli>(end - start).count());

    if (!result.has_value())
    {
        impl_->stats.total_errors++;
        if (err) *err = local_err;
        return ntq::nullopt;
    }

    // The flag is re-read on purpose: the cache may have been switched on or
    // off while the request was in flight.
    if (impl_->cache_enabled)
    {
        // A real key is never empty (it always contains "://"), so an empty
        // string means the key has not been built yet.
        if (key.empty()) key = impl_->cache_key(path, query, merged_headers);
        std::lock_guard<std::mutex> lk(impl_->cache_mtx);
        impl_->cache[key] = Impl::CacheEntry{*result, std::chrono::steady_clock::now() + impl_->cache_ttl};
    }
    if (err) *err = ErrorCode::None;
    return result;
}
// Performs a blocking HTTP(S) POST with `json` as the request body and
// "application/json" as its content type.
//
// path    : request path, joined onto the configured base path.
// headers : per-request headers, overriding same-named default headers.
// err     : optional out-parameter; set to ErrorCode::None on success,
//           otherwise to the failure reason.
//
// Returns the response (any HTTP status counts as a response), or nullopt
// when no response was obtained. An https request in a build without
// OpenSSL support fails with ErrorCode::SSL and is logged, as in Get().
ntq::optional<HttpResponse> NetRequest::PostJson(const std::string &path,
                                                 const std::string &json,
                                                 const httplib::Headers &headers,
                                                 ErrorCode *err)
{
    ConcurrencyGate::Guard guard(impl_->gate);
    impl_->stats.total_requests++;
    auto start = std::chrono::steady_clock::now();

    ntq::optional<HttpResponse> result;
    ErrorCode local_err = ErrorCode::None;
    const auto full_path = impl_->build_full_path(path);
    auto merged_headers = mergeHeaders(impl_->opts.default_headers, headers);
    if (impl_->opts.scheme == "https")
    {
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
        httplib::SSLClient cli(impl_->opts.host.c_str(), impl_->opts.port);
        impl_->apply_client_options(cli);
        auto res = cli.Post(full_path.c_str(), merged_headers, json, "application/json");
        if (res)
        {
            HttpResponse r{res->status, res->body, res->headers, false};
            result = r;
        }
        else
        {
            local_err = Impl::map_error();
        }
#else
        impl_->log("HTTPS requested but OpenSSL is not enabled; falling back to error.");
        local_err = ErrorCode::SSL;
#endif
    }
    else
    {
        httplib::Client cli(impl_->opts.host.c_str(), impl_->opts.port);
        impl_->apply_client_options(cli);
        auto res = cli.Post(full_path.c_str(), merged_headers, json, "application/json");
        if (res)
        {
            HttpResponse r{res->status, res->body, res->headers, false};
            result = r;
        }
        else
        {
            local_err = Impl::map_error();
        }
    }

    auto end = std::chrono::steady_clock::now();
    impl_->record_latency(std::chrono::duration<double, std::milli>(end - start).count());

    if (!result)
    {
        impl_->stats.total_errors++;
        if (err) *err = local_err;
        return ntq::nullopt;
    }
    if (err) *err = ErrorCode::None;
    return result;
}
// Performs a blocking HTTP(S) POST that sends `form` as form parameters.
//
// path    : request path, joined onto the configured base path.
// headers : per-request headers, overriding same-named default headers.
// err     : optional out-parameter; set to ErrorCode::None on success,
//           otherwise to the failure reason.
//
// Returns the response (any HTTP status counts as a response), or nullopt
// when no response was obtained. An https request in a build without
// OpenSSL support fails with ErrorCode::SSL and is logged, as in Get().
ntq::optional<HttpResponse> NetRequest::PostForm(const std::string &path,
                                                 const httplib::Params &form,
                                                 const httplib::Headers &headers,
                                                 ErrorCode *err)
{
    ConcurrencyGate::Guard guard(impl_->gate);
    impl_->stats.total_requests++;
    auto start = std::chrono::steady_clock::now();

    ntq::optional<HttpResponse> result;
    ErrorCode local_err = ErrorCode::None;
    const auto full_path = impl_->build_full_path(path);
    auto merged_headers = mergeHeaders(impl_->opts.default_headers, headers);
    if (impl_->opts.scheme == "https")
    {
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
        httplib::SSLClient cli(impl_->opts.host.c_str(), impl_->opts.port);
        impl_->apply_client_options(cli);
        auto res = cli.Post(full_path.c_str(), merged_headers, form);
        if (res)
        {
            HttpResponse r{res->status, res->body, res->headers, false};
            result = r;
        }
        else
        {
            local_err = Impl::map_error();
        }
#else
        impl_->log("HTTPS requested but OpenSSL is not enabled; falling back to error.");
        local_err = ErrorCode::SSL;
#endif
    }
    else
    {
        httplib::Client cli(impl_->opts.host.c_str(), impl_->opts.port);
        impl_->apply_client_options(cli);
        auto res = cli.Post(full_path.c_str(), merged_headers, form);
        if (res)
        {
            HttpResponse r{res->status, res->body, res->headers, false};
            result = r;
        }
        else
        {
            local_err = Impl::map_error();
        }
    }

    auto end = std::chrono::steady_clock::now();
    impl_->record_latency(std::chrono::duration<double, std::milli>(end - start).count());

    if (!result)
    {
        impl_->stats.total_errors++;
        if (err) *err = local_err;
        return ntq::nullopt;
    }
    if (err) *err = ErrorCode::None;
    return result;
}
// Runs Get() on a separate thread (std::launch::async) and returns a future
// for its result.
//
// `path`, `query` and `headers` are copied into the task. `this` and `err`
// are captured as pointers, so both the NetRequest object and the ErrorCode
// behind `err` (when non-null) must stay alive until the task has finished.
std::future<ntq::optional<HttpResponse>> NetRequest::GetAsync(const std::string &path,
                                                              const httplib::Params &query,
                                                              const httplib::Headers &headers,
                                                              ErrorCode *err)
{
    auto task = [this, path, query, headers, err]() mutable -> ntq::optional<HttpResponse> {
        // Get() reports into a task-local code; it is published afterwards.
        ErrorCode status = ErrorCode::None;
        auto response = Get(path, query, headers, &status);
        if (err != nullptr)
        {
            *err = status;
        }
        return response;
    };
    return std::async(std::launch::async, std::move(task));
}
// Runs PostJson() on a separate thread (std::launch::async) and returns a
// future for its result.
//
// `path`, `json` and `headers` are copied into the task. `this` and `err`
// are captured as pointers, so both the NetRequest object and the ErrorCode
// behind `err` (when non-null) must stay alive until the task has finished.
std::future<ntq::optional<HttpResponse>> NetRequest::PostJsonAsync(const std::string &path,
                                                                   const std::string &json,
                                                                   const httplib::Headers &headers,
                                                                   ErrorCode *err)
{
    auto task = [this, path, json, headers, err]() mutable -> ntq::optional<HttpResponse> {
        // PostJson() reports into a task-local code; it is published afterwards.
        ErrorCode status = ErrorCode::None;
        auto response = PostJson(path, json, headers, &status);
        if (err != nullptr)
        {
            *err = status;
        }
        return response;
    };
    return std::async(std::launch::async, std::move(task));
}
// Runs PostForm() on a separate thread (std::launch::async) and returns a
// future for its result.
//
// `path`, `form` and `headers` are copied into the task. `this` and `err`
// are captured as pointers, so both the NetRequest object and the ErrorCode
// behind `err` (when non-null) must stay alive until the task has finished.
std::future<ntq::optional<HttpResponse>> NetRequest::PostFormAsync(const std::string &path,
                                                                   const httplib::Params &form,
                                                                   const httplib::Headers &headers,
                                                                   ErrorCode *err)
{
    auto task = [this, path, form, headers, err]() mutable -> ntq::optional<HttpResponse> {
        // PostForm() reports into a task-local code; it is published afterwards.
        ErrorCode status = ErrorCode::None;
        auto response = PostForm(path, form, headers, &status);
        if (err != nullptr)
        {
            *err = status;
        }
        return response;
    };
    return std::async(std::launch::async, std::move(task));
}
// Downloads `path` with a blocking GET and streams the body into `local_file`.
//
// path       : request path, joined onto the configured base path.
// local_file : destination file.
// headers    : per-request headers, overriding same-named default headers.
// resume     : when true and the file already has N > 0 bytes, request
//              "Range: bytes=N-" and append to the file; when false the
//              file is truncated first.
// chunk_size : currently unused.
// err        : optional out-parameter; ErrorCode::None on success,
//              ErrorCode::IOError for local file failures, ErrorCode::SSL
//              for https without OpenSSL support, otherwise the code from
//              Impl::map_error().
//
// Returns true only if the server answered 200 or 206 and every byte was
// written.
//
// The status is checked in a response handler, i.e. before any body byte
// reaches the file:
// - any status other than 200/206 aborts the transfer, so an error page
//   (404, 416, ...) is never written into the target file;
// - a 200 answer to a resume request means the server ignored the Range
//   header and is sending the whole resource, so the file is restarted from
//   scratch instead of appending the full body to the partial one.
bool NetRequest::DownloadToFile(const std::string &path,
                                const std::string &local_file,
                                const httplib::Headers &headers,
                                bool resume,
                                size_t /*chunk_size*/,
                                ErrorCode *err)
{
    ConcurrencyGate::Guard guard(impl_->gate);
    impl_->stats.total_requests++;
    auto start = std::chrono::steady_clock::now();

    // Size of the existing file when resuming; 0 means "start from scratch".
    size_t offset = 0;
    if (resume)
    {
        std::ifstream in(local_file, std::ios::binary | std::ios::ate);
        if (in)
        {
            // tellg() yields -1 on failure; only a positive size is usable.
            const std::streamoff existing = in.tellg();
            if (existing > 0) offset = static_cast<size_t>(existing);
        }
    }

    const std::ios_base::openmode base_mode = std::ios::binary | std::ios::out;
    std::ofstream out(local_file, base_mode | (resume ? std::ios::app : std::ios::trunc));
    if (!out)
    {
        if (err) *err = ErrorCode::IOError;
        impl_->stats.total_errors++;
        return false;
    }

    auto merged_headers = mergeHeaders(impl_->opts.default_headers, headers);
    if (resume && offset > 0)
    {
        merged_headers.emplace("Range", "bytes=" + std::to_string(offset) + "-");
    }
    const auto full_path = impl_->build_full_path(path);

    int status_code = 0;       // 0 until response headers have been received
    bool status_ok = false;    // server answered 200 or 206
    bool write_failed = false; // a local file operation failed
    ErrorCode local_err = ErrorCode::None;

    // Runs once the status line and headers are in, before any body data.
    auto response_handler = [&](const httplib::Response &response) -> bool {
        status_code = response.status;
        status_ok = (status_code == 200 || status_code == 206);
        if (!status_ok) return false; // keep error bodies out of the file
        if (status_code == 200 && offset > 0)
        {
            // Range was ignored: the full resource follows, so start over.
            out.close();
            out.open(local_file, base_mode | std::ios::trunc);
            if (!out)
            {
                write_failed = true;
                return false;
            }
        }
        return true;
    };
    auto content_receiver = [&](const char *data, size_t data_length) -> bool {
        out.write(data, static_cast<std::streamsize>(data_length));
        if (!out)
        {
            write_failed = true;
            return false; // cancels the transfer
        }
        return true;
    };

    bool transport_ok = false;
    if (impl_->opts.scheme == "https")
    {
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
        httplib::SSLClient cli(impl_->opts.host.c_str(), impl_->opts.port);
        impl_->apply_client_options(cli);
        auto res = cli.Get(full_path.c_str(), merged_headers, response_handler, content_receiver);
        transport_ok = static_cast<bool>(res);
#else
        impl_->log("HTTPS requested but OpenSSL is not enabled; falling back to error.");
        local_err = ErrorCode::SSL;
#endif
    }
    else
    {
        httplib::Client cli(impl_->opts.host.c_str(), impl_->opts.port);
        impl_->apply_client_options(cli);
        auto res = cli.Get(full_path.c_str(), merged_headers, response_handler, content_receiver);
        transport_ok = static_cast<bool>(res);
    }

    // A failed flush on close means the file is incomplete.
    out.close();
    if (out.fail()) write_failed = true;

    auto end = std::chrono::steady_clock::now();
    impl_->record_latency(std::chrono::duration<double, std::milli>(end - start).count());

    const bool ok = transport_ok && status_ok && !write_failed;
    if (!ok)
    {
        if (write_failed)
        {
            local_err = ErrorCode::IOError;
        }
        else if (local_err == ErrorCode::None)
        {
            // Transport failure or rejected HTTP status.
            if (status_code != 0 && !status_ok)
            {
                impl_->log("Download rejected: unexpected HTTP status " + std::to_string(status_code));
            }
            local_err = Impl::map_error();
        }
        impl_->stats.total_errors++;
        if (err) *err = local_err;
        return false;
    }
    if (err) *err = ErrorCode::None;
    return true;
}
// Returns a copy of the current request statistics.
NetRequest::Stats NetRequest::getStats() const
{
    Stats snapshot = impl_->stats;
    return snapshot;
}
// ------------------------- Quick helpers -------------------------
namespace {
// Result of parse_url(). Only meaningful when `ok` is true.
struct ParsedURL {
    std::string scheme;         // lower-cased scheme, e.g. "http" or "https"
    std::string host;           // host name or address (IPv6 without brackets)
    int port = 0;               // explicit port, or the scheme default
    std::string path_and_query; // always starts with '/', fragment removed
    bool ok = false;            // false when the URL could not be parsed
};

// Parses a decimal TCP port in the range 1..65535 into `port`.
// Returns false (leaving `port` untouched) for anything else.
static bool parse_port(const std::string &text, int &port)
{
    if (text.empty() || text.size() > 5) return false;
    int value = 0;
    for (char c : text)
    {
        if (c < '0' || c > '9') return false;
        value = value * 10 + (c - '0');
    }
    if (value < 1 || value > 65535) return false;
    port = value;
    return true;
}

// Small URL parser for  scheme://host[:port][/path][?query][#fragment].
//
// - The scheme is lower-cased; the default port is 443 for "https" and 80
//   for everything else. An empty port ("host:") also yields the default.
// - The authority ends at the first '/' or '?', so "http://host?x=1" gives
//   host "host" and path_and_query "/?x=1".
// - The fragment is dropped: it is never sent to a server.
// - IPv6 literals must be bracketed ("[::1]:8080"); the brackets are removed.
// - `ok` stays false for: missing or empty scheme, empty host, a port that
//   is not a number in 1..65535, a malformed bracketed host, or user
//   information ("user@host"), which is not supported.
static ParsedURL parse_url(const std::string &url)
{
    const std::string::size_type npos = std::string::npos;
    ParsedURL p;

    const auto pos_scheme = url.find("://");
    if (pos_scheme == npos || pos_scheme == 0) return p;
    p.scheme = url.substr(0, pos_scheme);
    for (char &c : p.scheme)
    {
        if (c >= 'A' && c <= 'Z') c = static_cast<char>(c - 'A' + 'a');
    }
    const int default_port = (p.scheme == "https") ? 443 : 80;

    // Cut the fragment off first so that it cannot leak into host or path.
    const std::string::size_type pos_host = pos_scheme + 3;
    const std::string target = url.substr(0, url.find('#', pos_host));

    // The authority runs up to the first '/' or '?'.
    const auto pos_rest = target.find_first_of("/?", pos_host);
    const std::string authority =
        pos_rest == npos ? target.substr(pos_host) : target.substr(pos_host, pos_rest - pos_host);
    if (authority.find('@') != npos) return p;

    std::string port_text;
    if (!authority.empty() && authority[0] == '[')
    {
        // Bracketed IPv6 literal: "[addr]" optionally followed by ":port".
        const auto close = authority.find(']');
        if (close == npos) return p;
        p.host = authority.substr(1, close - 1);
        const std::string tail = authority.substr(close + 1);
        if (!tail.empty())
        {
            if (tail[0] != ':') return p;
            port_text = tail.substr(1);
        }
    }
    else
    {
        const auto pos_colon = authority.find(':');
        p.host = authority.substr(0, pos_colon);
        if (pos_colon != npos) port_text = authority.substr(pos_colon + 1);
    }

    p.port = default_port;
    if (!port_text.empty() && !parse_port(port_text, p.port)) return p;

    if (pos_rest == npos)
    {
        p.path_and_query = "/";
    }
    else if (target[pos_rest] == '?')
    {
        // Query without a path: request the root path with that query.
        p.path_and_query = "/" + target.substr(pos_rest);
    }
    else
    {
        p.path_and_query = target.substr(pos_rest);
    }

    p.ok = !p.host.empty();
    return p;
}
}
// One-shot GET of an absolute URL using a temporary NetRequest.
//
// url     : absolute URL in the form accepted by parse_url().
// headers : request headers.
// err     : optional out-parameter; ErrorCode::InvalidURL when the URL
//           cannot be parsed, otherwise whatever Get() reports.
//
// Returns the response, or nullopt on failure. The empty result is spelled
// ntq::nullopt, like everywhere else in this file, because the return type
// is ntq::optional.
ntq::optional<HttpResponse> NetRequest::QuickGet(const std::string &url,
                                                 const httplib::Headers &headers,
                                                 ErrorCode *err)
{
    auto p = parse_url(url);
    if (!p.ok)
    {
        if (err) *err = ErrorCode::InvalidURL;
        return ntq::nullopt;
    }
    RequestOptions opt;
    opt.scheme = p.scheme;
    opt.host = p.host;
    opt.port = p.port;
    NetRequest req(opt);
    return req.Get(p.path_and_query, {}, headers, err);
}
// One-shot JSON POST to an absolute URL using a temporary NetRequest.
//
// url     : absolute URL in the form accepted by parse_url().
// json    : request body.
// headers : request headers.
// err     : optional out-parameter; ErrorCode::InvalidURL when the URL
//           cannot be parsed, otherwise whatever PostJson() reports.
//
// Returns the response, or nullopt on failure. The empty result is spelled
// ntq::nullopt, like everywhere else in this file, because the return type
// is ntq::optional.
ntq::optional<HttpResponse> NetRequest::QuickPostJson(const std::string &url,
                                                      const std::string &json,
                                                      const httplib::Headers &headers,
                                                      ErrorCode *err)
{
    auto p = parse_url(url);
    if (!p.ok)
    {
        if (err) *err = ErrorCode::InvalidURL;
        return ntq::nullopt;
    }
    RequestOptions opt;
    opt.scheme = p.scheme;
    opt.host = p.host;
    opt.port = p.port;
    NetRequest req(opt);
    return req.PostJson(p.path_and_query, json, headers, err);
}
// One-shot form POST to an absolute URL using a temporary NetRequest.
//
// url     : absolute URL in the form accepted by parse_url().
// form    : form parameters to send.
// headers : request headers.
// err     : optional out-parameter; ErrorCode::InvalidURL when the URL
//           cannot be parsed, otherwise whatever PostForm() reports.
//
// Returns the response, or nullopt on failure. The empty result is spelled
// ntq::nullopt, like everywhere else in this file, because the return type
// is ntq::optional.
ntq::optional<HttpResponse> NetRequest::QuickPostForm(const std::string &url,
                                                      const httplib::Params &form,
                                                      const httplib::Headers &headers,
                                                      ErrorCode *err)
{
    auto p = parse_url(url);
    if (!p.ok)
    {
        if (err) *err = ErrorCode::InvalidURL;
        return ntq::nullopt;
    }
    RequestOptions opt;
    opt.scheme = p.scheme;
    opt.host = p.host;
    opt.port = p.port;
    NetRequest req(opt);
    return req.PostForm(p.path_and_query, form, headers, err);
}
}