Util: Add socket helper classes
This commit is contained in:
268
src/util/sockets.h
Normal file
268
src/util/sockets.h
Normal file
@ -0,0 +1,268 @@
|
||||
// SPDX-FileCopyrightText: 2015-2024 Connor McLaughlin <stenzek@gmail.com>
|
||||
// SPDX-License-Identifier: (GPL-3.0 OR CC-BY-NC-ND-4.0)
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "common/error.h"
|
||||
#include "common/heap_array.h"
|
||||
#include "common/small_string.h"
|
||||
#include "common/threading.h"
|
||||
#include "common/types.h"
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <optional>
|
||||
#include <unordered_map>
|
||||
#include <span>
|
||||
|
||||
#ifdef _WIN32
|
||||
using SocketDescriptor = uintptr_t;
|
||||
#else
|
||||
using SocketDescriptor = int;
|
||||
#endif
|
||||
|
||||
struct pollfd;
|
||||
|
||||
class BaseSocket;
|
||||
class ListenSocket;
|
||||
class StreamSocket;
|
||||
class BufferedStreamSocket;
|
||||
class SocketMultiplexer;
|
||||
|
||||
struct SocketAddress final
|
||||
{
|
||||
enum class Type
|
||||
{
|
||||
Unknown,
|
||||
IPv4,
|
||||
IPv6,
|
||||
Unix,
|
||||
};
|
||||
|
||||
// accessors
|
||||
const void* GetData() const { return m_data; }
|
||||
u32 GetLength() const { return m_length; }
|
||||
|
||||
// parse interface
|
||||
static std::optional<SocketAddress> Parse(Type type, const char* address, u32 port, Error* error);
|
||||
|
||||
// resolve interface
|
||||
static std::optional<SocketAddress> Resolve(const char* address, u32 port, Error* error);
|
||||
|
||||
// to string interface
|
||||
SmallString ToString() const;
|
||||
|
||||
// initializers
|
||||
void SetFromSockaddr(const void* sa, size_t length);
|
||||
|
||||
private:
|
||||
u8 m_data[128] = {};
|
||||
u32 m_length = 0;
|
||||
};
|
||||
|
||||
class BaseSocket : public std::enable_shared_from_this<BaseSocket>
|
||||
{
|
||||
friend SocketMultiplexer;
|
||||
|
||||
public:
|
||||
BaseSocket(SocketMultiplexer& multiplexer, SocketDescriptor descriptor);
|
||||
virtual ~BaseSocket();
|
||||
|
||||
ALWAYS_INLINE SocketDescriptor GetDescriptor() const { return m_descriptor; }
|
||||
|
||||
virtual void Close() = 0;
|
||||
|
||||
protected:
|
||||
virtual void OnReadEvent() = 0;
|
||||
virtual void OnWriteEvent() = 0;
|
||||
|
||||
SocketMultiplexer& m_multiplexer;
|
||||
SocketDescriptor m_descriptor;
|
||||
};
|
||||
|
||||
class SocketMultiplexer final
|
||||
{
|
||||
// TODO: Re-introduce worker threads.
|
||||
|
||||
public:
|
||||
typedef std::shared_ptr<StreamSocket> (*CreateStreamSocketCallback)(SocketMultiplexer& multiplexer,
|
||||
SocketDescriptor descriptor);
|
||||
friend BaseSocket;
|
||||
friend ListenSocket;
|
||||
friend StreamSocket;
|
||||
friend BufferedStreamSocket;
|
||||
|
||||
public:
|
||||
~SocketMultiplexer();
|
||||
|
||||
// Factory method.
|
||||
static std::unique_ptr<SocketMultiplexer> Create(Error* error);
|
||||
|
||||
// Public interface
|
||||
template<class T>
|
||||
std::shared_ptr<ListenSocket> CreateListenSocket(const SocketAddress& address, Error* error);
|
||||
template<class T>
|
||||
std::shared_ptr<T> ConnectStreamSocket(const SocketAddress& address, Error* error);
|
||||
|
||||
// Returns true if any sockets are currently registered.
|
||||
bool HasAnyOpenSockets();
|
||||
|
||||
// Close all sockets on this multiplexer.
|
||||
void CloseAll();
|
||||
|
||||
// Poll for events. Returns false if there are no sockets registered.
|
||||
bool PollEventsWithTimeout(u32 milliseconds);
|
||||
|
||||
protected:
|
||||
// Internal interface
|
||||
std::shared_ptr<ListenSocket> InternalCreateListenSocket(const SocketAddress& address,
|
||||
CreateStreamSocketCallback callback, Error* error);
|
||||
std::shared_ptr<StreamSocket> InternalConnectStreamSocket(const SocketAddress& address,
|
||||
CreateStreamSocketCallback callback, Error* error);
|
||||
|
||||
private:
|
||||
// Hide the constructor.
|
||||
SocketMultiplexer();
|
||||
|
||||
// Tracking of open sockets.
|
||||
void AddOpenSocket(std::shared_ptr<BaseSocket> socket);
|
||||
void RemoveOpenSocket(BaseSocket* socket);
|
||||
|
||||
// Register for notifications
|
||||
void SetNotificationMask(BaseSocket* socket, SocketDescriptor descriptor, u32 events);
|
||||
|
||||
private:
|
||||
// We store the fd in the struct to avoid the cache miss reading the object.
|
||||
using SocketMap = std::unordered_map<SocketDescriptor, std::shared_ptr<BaseSocket>>;
|
||||
|
||||
std::mutex m_poll_array_lock;
|
||||
pollfd* m_poll_array = nullptr;
|
||||
size_t m_poll_array_active_size = 0;
|
||||
size_t m_poll_array_max_size = 0;
|
||||
|
||||
std::mutex m_open_sockets_lock;
|
||||
SocketMap m_open_sockets;
|
||||
};
|
||||
|
||||
template<class T>
|
||||
std::shared_ptr<ListenSocket> SocketMultiplexer::CreateListenSocket(const SocketAddress& address, Error* error)
|
||||
{
|
||||
const CreateStreamSocketCallback callback = [](SocketMultiplexer& multiplexer,
|
||||
SocketDescriptor descriptor) -> std::shared_ptr<StreamSocket> {
|
||||
return std::static_pointer_cast<StreamSocket>(std::make_shared<T>(multiplexer, descriptor));
|
||||
};
|
||||
return InternalCreateListenSocket(address, callback, error);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
std::shared_ptr<T> SocketMultiplexer::ConnectStreamSocket(const SocketAddress& address, Error* error)
|
||||
{
|
||||
const CreateStreamSocketCallback callback = [](SocketMultiplexer& multiplexer,
|
||||
SocketDescriptor descriptor) -> std::shared_ptr<StreamSocket> {
|
||||
return std::static_pointer_cast<StreamSocket>(std::make_shared<T>(multiplexer, descriptor));
|
||||
};
|
||||
return std::static_pointer_cast<T>(InternalConnectStreamSocket(address, callback, error));
|
||||
}
|
||||
|
||||
class ListenSocket final : public BaseSocket
|
||||
{
|
||||
friend SocketMultiplexer;
|
||||
|
||||
public:
|
||||
ListenSocket(SocketMultiplexer& multiplexer, SocketDescriptor descriptor,
|
||||
SocketMultiplexer::CreateStreamSocketCallback accept_callback);
|
||||
virtual ~ListenSocket() override;
|
||||
|
||||
const SocketAddress* GetLocalAddress() const { return &m_local_address; }
|
||||
u32 GetConnectionsAccepted() const { return m_num_connections_accepted; }
|
||||
|
||||
void Close() override final;
|
||||
|
||||
protected:
|
||||
void OnReadEvent() override final;
|
||||
void OnWriteEvent() override final;
|
||||
|
||||
private:
|
||||
SocketMultiplexer::CreateStreamSocketCallback m_accept_callback;
|
||||
SocketAddress m_local_address = {};
|
||||
u32 m_num_connections_accepted = 0;
|
||||
};
|
||||
|
||||
class StreamSocket : public BaseSocket
|
||||
{
|
||||
public:
|
||||
StreamSocket(SocketMultiplexer& multiplexer, SocketDescriptor descriptor);
|
||||
virtual ~StreamSocket() override;
|
||||
|
||||
static u32 GetSocketProtocolForAddress(const SocketAddress& sa);
|
||||
|
||||
virtual void Close() override final;
|
||||
|
||||
// Accessors
|
||||
const SocketAddress& GetLocalAddress() const { return m_local_address; }
|
||||
const SocketAddress& GetRemoteAddress() const { return m_remote_address; }
|
||||
bool IsConnected() const { return m_connected; }
|
||||
|
||||
// Read/write
|
||||
size_t Read(void* buffer, size_t buffer_size);
|
||||
size_t Write(const void* buffer, size_t buffer_size);
|
||||
size_t WriteVector(const void** buffers, const size_t* buffer_lengths, size_t num_buffers);
|
||||
|
||||
protected:
|
||||
virtual void OnConnected() = 0;
|
||||
virtual void OnDisconnected(const Error& error) = 0;
|
||||
virtual void OnRead() = 0;
|
||||
|
||||
virtual void OnReadEvent() override;
|
||||
virtual void OnWriteEvent() override;
|
||||
|
||||
void CloseWithError();
|
||||
|
||||
private:
|
||||
void InitialSetup();
|
||||
|
||||
SocketAddress m_local_address = {};
|
||||
SocketAddress m_remote_address = {};
|
||||
std::recursive_mutex m_lock;
|
||||
bool m_connected = true;
|
||||
|
||||
// Ugly, but needed in order to call the events.
|
||||
friend SocketMultiplexer;
|
||||
friend ListenSocket;
|
||||
friend BufferedStreamSocket;
|
||||
};
|
||||
|
||||
class BufferedStreamSocket : public StreamSocket
|
||||
{
|
||||
public:
|
||||
BufferedStreamSocket(SocketMultiplexer& multiplexer, SocketDescriptor descriptor, size_t receive_buffer_size = 16384,
|
||||
size_t send_buffer_size = 16384);
|
||||
virtual ~BufferedStreamSocket() override;
|
||||
|
||||
// Must hold the lock when not part of OnRead().
|
||||
std::unique_lock<std::recursive_mutex> GetLock();
|
||||
std::span<const u8> AcquireReadBuffer() const;
|
||||
void ReleaseReadBuffer(size_t bytes_consumed);
|
||||
std::span<u8> AcquireWriteBuffer(size_t wanted_bytes, bool allow_smaller = false);
|
||||
void ReleaseWriteBuffer(size_t bytes_written, bool commit = true);
|
||||
|
||||
// Hide StreamSocket read/write methods.
|
||||
size_t Read(void* buffer, size_t buffer_size);
|
||||
size_t Write(const void* buffer, size_t buffer_size);
|
||||
size_t WriteVector(const void** buffers, const size_t* buffer_lengths, size_t num_buffers);
|
||||
|
||||
protected:
|
||||
void OnReadEvent() override final;
|
||||
void OnWriteEvent() override final;
|
||||
virtual void OnWrite();
|
||||
|
||||
private:
|
||||
std::vector<u8> m_receive_buffer;
|
||||
size_t m_receive_buffer_offset = 0;
|
||||
size_t m_receive_buffer_size = 0;
|
||||
|
||||
std::vector<u8> m_send_buffer;
|
||||
size_t m_send_buffer_offset = 0;
|
||||
size_t m_send_buffer_size = 0;
|
||||
};
|
||||
Reference in New Issue
Block a user