HTTPDownloader: Add support for progress updates/cancelling

This commit is contained in:
Stenzek
2023-11-24 15:54:43 +10:00
parent cca901c4c6
commit cc6f22163c
10 changed files with 92 additions and 41 deletions

View File

@ -33,6 +33,9 @@ void ReportFormattedErrorAsync(const std::string_view& title, const char* format
bool ConfirmMessage(const std::string_view& title, const std::string_view& message);
bool ConfirmFormattedMessage(const std::string_view& title, const char* format, ...);
/// Returns the user agent to use for HTTP requests.
std::string GetHTTPUserAgent();
/// Opens a URL, using the default application.
void OpenURL(const std::string_view& url);

View File

@ -5,6 +5,7 @@
#include "common/assert.h"
#include "common/log.h"
#include "common/progress_callback.h"
#include "common/string_util.h"
#include "common/timer.h"
@ -34,13 +35,14 @@ void HTTPDownloader::SetMaxActiveRequests(u32 max_active_requests)
m_max_active_requests = max_active_requests;
}
void HTTPDownloader::CreateRequest(std::string url, Request::Callback callback)
void HTTPDownloader::CreateRequest(std::string url, Request::Callback callback, ProgressCallback* progress)
{
Request* req = InternalCreateRequest();
req->parent = this;
req->type = Request::Type::Get;
req->url = std::move(url);
req->callback = std::move(callback);
req->progress = progress;
req->start_time = Common::Timer::GetCurrentValue();
std::unique_lock<std::mutex> lock(m_pending_http_request_lock);
@ -53,7 +55,8 @@ void HTTPDownloader::CreateRequest(std::string url, Request::Callback callback)
LockedAddRequest(req);
}
void HTTPDownloader::CreatePostRequest(std::string url, std::string post_data, Request::Callback callback)
void HTTPDownloader::CreatePostRequest(std::string url, std::string post_data, Request::Callback callback,
ProgressCallback* progress)
{
Request* req = InternalCreateRequest();
req->parent = this;
@ -61,6 +64,7 @@ void HTTPDownloader::CreatePostRequest(std::string url, std::string post_data, R
req->url = std::move(url);
req->post_data = std::move(post_data);
req->callback = std::move(callback);
req->progress = progress;
req->start_time = Common::Timer::GetCurrentValue();
std::unique_lock<std::mutex> lock(m_pending_http_request_lock);
@ -73,12 +77,6 @@ void HTTPDownloader::CreatePostRequest(std::string url, std::string post_data, R
LockedAddRequest(req);
}
bool HTTPDownloader::HasAnyRequests()
{
std::unique_lock<std::mutex> lock(m_pending_http_request_lock);
return !m_pending_http_requests.empty();
}
void HTTPDownloader::LockedPollRequests(std::unique_lock<std::mutex>& lock)
{
if (m_pending_http_requests.empty())
@ -100,11 +98,12 @@ void HTTPDownloader::LockedPollRequests(std::unique_lock<std::mutex>& lock)
continue;
}
if (req->state == Request::State::Started && current_time >= req->start_time &&
if ((req->state == Request::State::Started || req->state == Request::State::Receiving) &&
current_time >= req->start_time &&
Common::Timer::ConvertValueToSeconds(current_time - req->start_time) >= m_timeout)
{
// request timed out
Log_ErrorPrintf("Request for '%s' timed out", req->url.c_str());
Log_ErrorFmt("Request for '{}' timed out", req->url);
req->state.store(Request::State::Cancelled);
m_pending_http_requests.erase(m_pending_http_requests.begin() + index);
@ -117,22 +116,50 @@ void HTTPDownloader::LockedPollRequests(std::unique_lock<std::mutex>& lock)
lock.lock();
continue;
}
else if ((req->state == Request::State::Started || req->state == Request::State::Receiving) && req->progress &&
req->progress->IsCancelled())
{
// request timed out
Log_ErrorFmt("Request for '{}' cancelled", req->url);
req->state.store(Request::State::Cancelled);
m_pending_http_requests.erase(m_pending_http_requests.begin() + index);
lock.unlock();
req->callback(HTTP_STATUS_CANCELLED, std::string(), Request::Data());
CloseRequest(req);
lock.lock();
continue;
}
if (req->state != Request::State::Complete)
{
if (req->progress)
{
const u32 size = static_cast<u32>(req->data.size());
if (size != req->last_progress_update)
{
req->last_progress_update = size;
req->progress->SetProgressRange(req->content_length);
req->progress->SetProgressValue(req->last_progress_update);
}
}
active_requests++;
index++;
continue;
}
// request complete
Log_VerbosePrintf("Request for '%s' complete, returned status code %u and %zu bytes", req->url.c_str(),
req->status_code, req->data.size());
Log_VerboseFmt("Request for '{}' complete, returned status code {} and {} bytes", req->url, req->status_code,
req->data.size());
m_pending_http_requests.erase(m_pending_http_requests.begin() + index);
// run callback with lock unheld
lock.unlock();
req->callback(req->status_code, std::move(req->content_type), std::move(req->data));
req->callback(req->status_code, req->content_type, std::move(req->data));
CloseRequest(req);
lock.lock();
}
@ -197,6 +224,12 @@ u32 HTTPDownloader::LockedGetActiveRequestCount()
return count;
}
bool HTTPDownloader::HasAnyRequests()
{
std::unique_lock<std::mutex> lock(m_pending_http_request_lock);
return !m_pending_http_requests.empty();
}
std::string HTTPDownloader::URLEncode(const std::string_view& str)
{
std::string ret;

View File

@ -13,6 +13,8 @@
#include <string_view>
#include <vector>
class ProgressCallback;
class HTTPDownloader
{
public:
@ -27,7 +29,7 @@ public:
struct Request
{
using Data = std::vector<u8>;
using Callback = std::function<void(s32 status_code, std::string content_type, Data data)>;
using Callback = std::function<void(s32 status_code, const std::string& content_type, Data data)>;
enum class Type
{
@ -46,6 +48,7 @@ public:
HTTPDownloader* parent;
Callback callback;
ProgressCallback* progress;
std::string url;
std::string post_data;
std::string content_type;
@ -53,6 +56,7 @@ public:
u64 start_time;
s32 status_code = 0;
u32 content_length = 0;
u32 last_progress_update = 0;
Type type = Type::Get;
std::atomic<State> state{State::Pending};
};
@ -60,7 +64,7 @@ public:
HTTPDownloader();
virtual ~HTTPDownloader();
static std::unique_ptr<HTTPDownloader> Create(const char* user_agent = DEFAULT_USER_AGENT);
static std::unique_ptr<HTTPDownloader> Create(std::string user_agent = DEFAULT_USER_AGENT);
static std::string URLEncode(const std::string_view& str);
static std::string URLDecode(const std::string_view& str);
static std::string GetExtensionForContentType(const std::string& content_type);
@ -68,12 +72,12 @@ public:
void SetTimeout(float timeout);
void SetMaxActiveRequests(u32 max_active_requests);
void CreateRequest(std::string url, Request::Callback callback);
void CreatePostRequest(std::string url, std::string post_data, Request::Callback callback);
bool HasAnyRequests();
void CreateRequest(std::string url, Request::Callback callback, ProgressCallback* progress = nullptr);
void CreatePostRequest(std::string url, std::string post_data, Request::Callback callback,
ProgressCallback* progress = nullptr);
void PollRequests();
void WaitForAllRequests();
bool HasAnyRequests();
static const char DEFAULT_USER_AGENT[];

View File

@ -25,10 +25,10 @@ HTTPDownloaderCurl::~HTTPDownloaderCurl()
curl_multi_cleanup(m_multi_handle);
}
std::unique_ptr<HTTPDownloader> HTTPDownloader::Create(const char* user_agent)
std::unique_ptr<HTTPDownloader> HTTPDownloader::Create(std::string user_agent)
{
std::unique_ptr<HTTPDownloaderCurl> instance(std::make_unique<HTTPDownloaderCurl>());
if (!instance->Initialize(user_agent))
if (!instance->Initialize(std::move(user_agent)))
return {};
return instance;
@ -37,7 +37,7 @@ std::unique_ptr<HTTPDownloader> HTTPDownloader::Create(const char* user_agent)
static bool s_curl_initialized = false;
static std::once_flag s_curl_initialized_once_flag;
bool HTTPDownloaderCurl::Initialize(const char* user_agent)
bool HTTPDownloaderCurl::Initialize(std::string user_agent)
{
if (!s_curl_initialized)
{
@ -65,7 +65,7 @@ bool HTTPDownloaderCurl::Initialize(const char* user_agent)
return false;
}
m_user_agent = user_agent;
m_user_agent = std::move(user_agent);
return true;
}
@ -76,7 +76,16 @@ size_t HTTPDownloaderCurl::WriteCallback(char* ptr, size_t size, size_t nmemb, v
const size_t transfer_size = size * nmemb;
const size_t new_size = current_size + transfer_size;
req->data.resize(new_size);
req->start_time = Common::Timer::GetCurrentValue();
std::memcpy(&req->data[current_size], ptr, transfer_size);
if (req->content_length == 0)
{
curl_off_t length;
if (curl_easy_getinfo(req->handle, CURLINFO_CONTENT_LENGTH_DOWNLOAD_T, &length) == CURLE_OK)
req->content_length = static_cast<u32>(length);
}
return nmemb;
}
@ -160,8 +169,9 @@ bool HTTPDownloaderCurl::StartRequest(HTTPDownloader::Request* request)
curl_easy_setopt(req->handle, CURLOPT_USERAGENT, m_user_agent.c_str());
curl_easy_setopt(req->handle, CURLOPT_WRITEFUNCTION, &HTTPDownloaderCurl::WriteCallback);
curl_easy_setopt(req->handle, CURLOPT_WRITEDATA, req);
curl_easy_setopt(req->handle, CURLOPT_NOSIGNAL, 1);
curl_easy_setopt(req->handle, CURLOPT_NOSIGNAL, 1L);
curl_easy_setopt(req->handle, CURLOPT_PRIVATE, req);
curl_easy_setopt(req->handle, CURLOPT_FOLLOWLOCATION, 1L);
if (request->type == Request::Type::Post)
{

View File

@ -15,7 +15,7 @@ public:
HTTPDownloaderCurl();
~HTTPDownloaderCurl() override;
bool Initialize(const char* user_agent);
bool Initialize(std::string user_agent);
protected:
Request* InternalCreateRequest() override;

View File

@ -25,16 +25,16 @@ HTTPDownloaderWinHttp::~HTTPDownloaderWinHttp()
}
}
std::unique_ptr<HTTPDownloader> HTTPDownloader::Create(const char* user_agent)
std::unique_ptr<HTTPDownloader> HTTPDownloader::Create(std::string user_agent)
{
std::unique_ptr<HTTPDownloaderWinHttp> instance(std::make_unique<HTTPDownloaderWinHttp>());
if (!instance->Initialize(user_agent))
if (!instance->Initialize(std::move(user_agent)))
return {};
return instance;
}
bool HTTPDownloaderWinHttp::Initialize(const char* user_agent)
bool HTTPDownloaderWinHttp::Initialize(std::string user_agent)
{
static constexpr DWORD dwAccessType = WINHTTP_ACCESS_TYPE_AUTOMATIC_PROXY;

View File

@ -14,7 +14,7 @@ public:
HTTPDownloaderWinHttp();
~HTTPDownloaderWinHttp() override;
bool Initialize(const char* user_agent);
bool Initialize(std::string user_agent);
protected:
Request* InternalCreateRequest() override;