sockets.h (8856B)
1 // SPDX-FileCopyrightText: 2015-2024 Connor McLaughlin <stenzek@gmail.com> 2 // SPDX-License-Identifier: (GPL-3.0 OR CC-BY-NC-ND-4.0) 3 4 #pragma once 5 6 #include "common/error.h" 7 #include "common/heap_array.h" 8 #include "common/small_string.h" 9 #include "common/threading.h" 10 #include "common/types.h" 11 12 #include <map> 13 #include <memory> 14 #include <mutex> 15 #include <optional> 16 #include <span> 17 #include <unordered_map> 18 19 #ifdef _WIN32 20 using SocketDescriptor = uintptr_t; 21 #else 22 using SocketDescriptor = int; 23 #endif 24 25 struct pollfd; 26 27 class BaseSocket; 28 class ListenSocket; 29 class StreamSocket; 30 class BufferedStreamSocket; 31 class SocketMultiplexer; 32 33 struct SocketAddress final 34 { 35 enum class Type 36 { 37 Unknown, 38 IPv4, 39 IPv6, 40 Unix, 41 }; 42 43 // accessors 44 const void* GetData() const { return m_data; } 45 u32 GetLength() const { return m_length; } 46 47 // parse interface 48 static std::optional<SocketAddress> Parse(Type type, const char* address, u32 port, Error* error); 49 50 // resolve interface 51 static std::optional<SocketAddress> Resolve(const char* address, u32 port, Error* error); 52 53 // to string interface 54 SmallString ToString() const; 55 56 // initializers 57 void SetFromSockaddr(const void* sa, size_t length); 58 59 /// Returns true if the address is IP. 60 bool IsIPAddress() const; 61 62 private: 63 u8 m_data[128] = {}; 64 u32 m_length = 0; 65 }; 66 67 class BaseSocket : public std::enable_shared_from_this<BaseSocket> 68 { 69 friend SocketMultiplexer; 70 71 public: 72 BaseSocket(SocketMultiplexer& multiplexer, SocketDescriptor descriptor); 73 virtual ~BaseSocket(); 74 75 ALWAYS_INLINE SocketDescriptor GetDescriptor() const { return m_descriptor; } 76 77 virtual void Close() = 0; 78 79 protected: 80 virtual void OnReadEvent() = 0; 81 virtual void OnWriteEvent() = 0; 82 virtual void OnHangupEvent() = 0; 83 84 SocketMultiplexer& m_multiplexer; 85 SocketDescriptor m_descriptor; 86 }; 87 88 class SocketMultiplexer final 89 { 90 // TODO: Re-introduce worker threads. 91 92 public: 93 typedef std::shared_ptr<StreamSocket> (*CreateStreamSocketCallback)(SocketMultiplexer& multiplexer, 94 SocketDescriptor descriptor); 95 friend BaseSocket; 96 friend ListenSocket; 97 friend StreamSocket; 98 friend BufferedStreamSocket; 99 100 public: 101 ~SocketMultiplexer(); 102 103 // Factory method. 104 static std::unique_ptr<SocketMultiplexer> Create(Error* error); 105 106 // Public interface 107 template<class T> 108 std::shared_ptr<ListenSocket> CreateListenSocket(const SocketAddress& address, Error* error); 109 template<class T> 110 std::shared_ptr<T> ConnectStreamSocket(const SocketAddress& address, Error* error); 111 112 // Returns true if any sockets are currently registered. 113 bool HasAnyOpenSockets(); 114 115 // Returns true if any client sockets are currently connected. 116 bool HasAnyClientSockets(); 117 118 // Returns the number of current client sockets. 119 size_t GetClientSocketCount(); 120 121 // Close all sockets on this multiplexer. 122 void CloseAll(); 123 124 // Poll for events. Returns false if there are no sockets registered. 125 bool PollEventsWithTimeout(u32 milliseconds); 126 127 protected: 128 // Internal interface 129 std::shared_ptr<ListenSocket> InternalCreateListenSocket(const SocketAddress& address, 130 CreateStreamSocketCallback callback, Error* error); 131 std::shared_ptr<StreamSocket> InternalConnectStreamSocket(const SocketAddress& address, 132 CreateStreamSocketCallback callback, Error* error); 133 134 private: 135 // Hide the constructor. 136 SocketMultiplexer(); 137 138 // Initialization. 139 bool Initialize(Error* error); 140 141 // Tracking of open sockets. 142 void AddOpenSocket(std::shared_ptr<BaseSocket> socket); 143 void AddClientSocket(std::shared_ptr<BaseSocket> socket); 144 void RemoveOpenSocket(BaseSocket* socket); 145 void RemoveClientSocket(BaseSocket* socket); 146 147 // Register for notifications 148 void SetNotificationMask(BaseSocket* socket, SocketDescriptor descriptor, u32 events); 149 150 private: 151 // We store the fd in the struct to avoid the cache miss reading the object. 152 using SocketMap = std::unordered_map<SocketDescriptor, std::shared_ptr<BaseSocket>>; 153 154 #ifdef __linux__ 155 int m_epoll_fd = -1; 156 #else 157 std::mutex m_poll_array_lock; 158 pollfd* m_poll_array = nullptr; 159 size_t m_poll_array_active_size = 0; 160 size_t m_poll_array_max_size = 0; 161 #endif 162 163 std::mutex m_open_sockets_lock; 164 SocketMap m_open_sockets; 165 std::atomic_size_t m_client_socket_count{0}; 166 }; 167 168 template<class T> 169 std::shared_ptr<ListenSocket> SocketMultiplexer::CreateListenSocket(const SocketAddress& address, Error* error) 170 { 171 const CreateStreamSocketCallback callback = [](SocketMultiplexer& multiplexer, 172 SocketDescriptor descriptor) -> std::shared_ptr<StreamSocket> { 173 return std::static_pointer_cast<StreamSocket>(std::make_shared<T>(multiplexer, descriptor)); 174 }; 175 return InternalCreateListenSocket(address, callback, error); 176 } 177 178 template<class T> 179 std::shared_ptr<T> SocketMultiplexer::ConnectStreamSocket(const SocketAddress& address, Error* error) 180 { 181 const CreateStreamSocketCallback callback = [](SocketMultiplexer& multiplexer, 182 SocketDescriptor descriptor) -> std::shared_ptr<StreamSocket> { 183 return std::static_pointer_cast<StreamSocket>(std::make_shared<T>(multiplexer, descriptor)); 184 }; 185 return std::static_pointer_cast<T>(InternalConnectStreamSocket(address, callback, error)); 186 } 187 188 class ListenSocket final : public BaseSocket 189 { 190 friend SocketMultiplexer; 191 192 public: 193 ListenSocket(SocketMultiplexer& multiplexer, SocketDescriptor descriptor, 194 SocketMultiplexer::CreateStreamSocketCallback accept_callback); 195 virtual ~ListenSocket() override; 196 197 const SocketAddress* GetLocalAddress() const { return &m_local_address; } 198 u32 GetConnectionsAccepted() const { return m_num_connections_accepted; } 199 200 void Close() override final; 201 202 protected: 203 void OnReadEvent() override final; 204 void OnWriteEvent() override final; 205 void OnHangupEvent() override final; 206 207 private: 208 SocketMultiplexer::CreateStreamSocketCallback m_accept_callback; 209 SocketAddress m_local_address = {}; 210 u32 m_num_connections_accepted = 0; 211 }; 212 213 class StreamSocket : public BaseSocket 214 { 215 public: 216 StreamSocket(SocketMultiplexer& multiplexer, SocketDescriptor descriptor); 217 virtual ~StreamSocket() override; 218 219 static u32 GetSocketProtocolForAddress(const SocketAddress& sa); 220 221 virtual void Close() override; 222 223 // Accessors 224 const SocketAddress& GetLocalAddress() const { return m_local_address; } 225 const SocketAddress& GetRemoteAddress() const { return m_remote_address; } 226 bool IsConnected() const { return m_connected; } 227 228 // Read/write 229 size_t Read(void* buffer, size_t buffer_size); 230 size_t Write(const void* buffer, size_t buffer_size); 231 size_t WriteVector(const void** buffers, const size_t* buffer_lengths, size_t num_buffers); 232 233 /// Disables Nagle's buffering algorithm, i.e. TCP_NODELAY. 234 bool SetNagleBuffering(bool enabled, Error* error = nullptr); 235 236 protected: 237 virtual void OnConnected() = 0; 238 virtual void OnDisconnected(const Error& error) = 0; 239 virtual void OnRead() = 0; 240 241 virtual void OnReadEvent() override; 242 virtual void OnWriteEvent() override; 243 virtual void OnHangupEvent() override; 244 245 void CloseWithError(); 246 247 private: 248 void InitialSetup(); 249 250 SocketAddress m_local_address = {}; 251 SocketAddress m_remote_address = {}; 252 std::recursive_mutex m_lock; 253 bool m_connected = true; 254 255 // Ugly, but needed in order to call the events. 256 friend SocketMultiplexer; 257 friend ListenSocket; 258 friend BufferedStreamSocket; 259 }; 260 261 class BufferedStreamSocket : public StreamSocket 262 { 263 public: 264 BufferedStreamSocket(SocketMultiplexer& multiplexer, SocketDescriptor descriptor, size_t receive_buffer_size = 16384, 265 size_t send_buffer_size = 16384); 266 virtual ~BufferedStreamSocket() override; 267 268 // Must hold the lock when not part of OnRead(). 269 std::unique_lock<std::recursive_mutex> GetLock(); 270 std::span<const u8> AcquireReadBuffer() const; 271 void ReleaseReadBuffer(size_t bytes_consumed); 272 std::span<u8> AcquireWriteBuffer(size_t wanted_bytes, bool allow_smaller = false); 273 void ReleaseWriteBuffer(size_t bytes_written, bool commit = true); 274 275 // Hide StreamSocket read/write methods. 276 size_t Read(void* buffer, size_t buffer_size); 277 size_t Write(const void* buffer, size_t buffer_size); 278 size_t WriteVector(const void** buffers, const size_t* buffer_lengths, size_t num_buffers); 279 virtual void Close() override; 280 281 protected: 282 void OnReadEvent() override final; 283 void OnWriteEvent() override final; 284 virtual void OnWrite(); 285 286 private: 287 std::vector<u8> m_receive_buffer; 288 size_t m_receive_buffer_offset = 0; 289 size_t m_receive_buffer_size = 0; 290 291 std::vector<u8> m_send_buffer; 292 size_t m_send_buffer_offset = 0; 293 size_t m_send_buffer_size = 0; 294 };