duckstation

duckstation — archived from the revision just before upstream turned it into a proprietary software project; this is the libre version.
git clone https://git.neptards.moe/u3shit/duckstation.git
Log | Files | Refs | README | LICENSE

sockets.h (8856B)


      1 // SPDX-FileCopyrightText: 2015-2024 Connor McLaughlin <stenzek@gmail.com>
      2 // SPDX-License-Identifier: (GPL-3.0 OR CC-BY-NC-ND-4.0)
      3 
      4 #pragma once
      5 
#include "common/error.h"
#include "common/heap_array.h"
#include "common/small_string.h"
#include "common/threading.h"
#include "common/types.h"

#include <atomic>
#include <map>
#include <memory>
#include <mutex>
#include <optional>
#include <span>
#include <unordered_map>
#include <vector>
     18 
     19 #ifdef _WIN32
     20 using SocketDescriptor = uintptr_t;
     21 #else
     22 using SocketDescriptor = int;
     23 #endif
     24 
     25 struct pollfd;
     26 
     27 class BaseSocket;
     28 class ListenSocket;
     29 class StreamSocket;
     30 class BufferedStreamSocket;
     31 class SocketMultiplexer;
     32 
     33 struct SocketAddress final
     34 {
     35   enum class Type
     36   {
     37     Unknown,
     38     IPv4,
     39     IPv6,
     40     Unix,
     41   };
     42 
     43   // accessors
     44   const void* GetData() const { return m_data; }
     45   u32 GetLength() const { return m_length; }
     46 
     47   // parse interface
     48   static std::optional<SocketAddress> Parse(Type type, const char* address, u32 port, Error* error);
     49 
     50   // resolve interface
     51   static std::optional<SocketAddress> Resolve(const char* address, u32 port, Error* error);
     52 
     53   // to string interface
     54   SmallString ToString() const;
     55 
     56   // initializers
     57   void SetFromSockaddr(const void* sa, size_t length);
     58 
     59   /// Returns true if the address is IP.
     60   bool IsIPAddress() const;
     61 
     62 private:
     63   u8 m_data[128] = {};
     64   u32 m_length = 0;
     65 };
     66 
     67 class BaseSocket : public std::enable_shared_from_this<BaseSocket>
     68 {
     69   friend SocketMultiplexer;
     70 
     71 public:
     72   BaseSocket(SocketMultiplexer& multiplexer, SocketDescriptor descriptor);
     73   virtual ~BaseSocket();
     74 
     75   ALWAYS_INLINE SocketDescriptor GetDescriptor() const { return m_descriptor; }
     76 
     77   virtual void Close() = 0;
     78 
     79 protected:
     80   virtual void OnReadEvent() = 0;
     81   virtual void OnWriteEvent() = 0;
     82   virtual void OnHangupEvent() = 0;
     83 
     84   SocketMultiplexer& m_multiplexer;
     85   SocketDescriptor m_descriptor;
     86 };
     87 
     88 class SocketMultiplexer final
     89 {
     90   // TODO: Re-introduce worker threads.
     91 
     92 public:
     93   typedef std::shared_ptr<StreamSocket> (*CreateStreamSocketCallback)(SocketMultiplexer& multiplexer,
     94                                                                       SocketDescriptor descriptor);
     95   friend BaseSocket;
     96   friend ListenSocket;
     97   friend StreamSocket;
     98   friend BufferedStreamSocket;
     99 
    100 public:
    101   ~SocketMultiplexer();
    102 
    103   // Factory method.
    104   static std::unique_ptr<SocketMultiplexer> Create(Error* error);
    105 
    106   // Public interface
    107   template<class T>
    108   std::shared_ptr<ListenSocket> CreateListenSocket(const SocketAddress& address, Error* error);
    109   template<class T>
    110   std::shared_ptr<T> ConnectStreamSocket(const SocketAddress& address, Error* error);
    111 
    112   // Returns true if any sockets are currently registered.
    113   bool HasAnyOpenSockets();
    114 
    115   // Returns true if any client sockets are currently connected.
    116   bool HasAnyClientSockets();
    117 
    118   // Returns the number of current client sockets.
    119   size_t GetClientSocketCount();
    120 
    121   // Close all sockets on this multiplexer.
    122   void CloseAll();
    123 
    124   // Poll for events. Returns false if there are no sockets registered.
    125   bool PollEventsWithTimeout(u32 milliseconds);
    126 
    127 protected:
    128   // Internal interface
    129   std::shared_ptr<ListenSocket> InternalCreateListenSocket(const SocketAddress& address,
    130                                                            CreateStreamSocketCallback callback, Error* error);
    131   std::shared_ptr<StreamSocket> InternalConnectStreamSocket(const SocketAddress& address,
    132                                                             CreateStreamSocketCallback callback, Error* error);
    133 
    134 private:
    135   // Hide the constructor.
    136   SocketMultiplexer();
    137 
    138   // Initialization.
    139   bool Initialize(Error* error);
    140 
    141   // Tracking of open sockets.
    142   void AddOpenSocket(std::shared_ptr<BaseSocket> socket);
    143   void AddClientSocket(std::shared_ptr<BaseSocket> socket);
    144   void RemoveOpenSocket(BaseSocket* socket);
    145   void RemoveClientSocket(BaseSocket* socket);
    146 
    147   // Register for notifications
    148   void SetNotificationMask(BaseSocket* socket, SocketDescriptor descriptor, u32 events);
    149 
    150 private:
    151   // We store the fd in the struct to avoid the cache miss reading the object.
    152   using SocketMap = std::unordered_map<SocketDescriptor, std::shared_ptr<BaseSocket>>;
    153 
    154 #ifdef __linux__
    155   int m_epoll_fd = -1;
    156 #else
    157   std::mutex m_poll_array_lock;
    158   pollfd* m_poll_array = nullptr;
    159   size_t m_poll_array_active_size = 0;
    160   size_t m_poll_array_max_size = 0;
    161 #endif
    162 
    163   std::mutex m_open_sockets_lock;
    164   SocketMap m_open_sockets;
    165   std::atomic_size_t m_client_socket_count{0};
    166 };
    167 
    168 template<class T>
    169 std::shared_ptr<ListenSocket> SocketMultiplexer::CreateListenSocket(const SocketAddress& address, Error* error)
    170 {
    171   const CreateStreamSocketCallback callback = [](SocketMultiplexer& multiplexer,
    172                                                  SocketDescriptor descriptor) -> std::shared_ptr<StreamSocket> {
    173     return std::static_pointer_cast<StreamSocket>(std::make_shared<T>(multiplexer, descriptor));
    174   };
    175   return InternalCreateListenSocket(address, callback, error);
    176 }
    177 
    178 template<class T>
    179 std::shared_ptr<T> SocketMultiplexer::ConnectStreamSocket(const SocketAddress& address, Error* error)
    180 {
    181   const CreateStreamSocketCallback callback = [](SocketMultiplexer& multiplexer,
    182                                                  SocketDescriptor descriptor) -> std::shared_ptr<StreamSocket> {
    183     return std::static_pointer_cast<StreamSocket>(std::make_shared<T>(multiplexer, descriptor));
    184   };
    185   return std::static_pointer_cast<T>(InternalConnectStreamSocket(address, callback, error));
    186 }
    187 
    188 class ListenSocket final : public BaseSocket
    189 {
    190   friend SocketMultiplexer;
    191 
    192 public:
    193   ListenSocket(SocketMultiplexer& multiplexer, SocketDescriptor descriptor,
    194                SocketMultiplexer::CreateStreamSocketCallback accept_callback);
    195   virtual ~ListenSocket() override;
    196 
    197   const SocketAddress* GetLocalAddress() const { return &m_local_address; }
    198   u32 GetConnectionsAccepted() const { return m_num_connections_accepted; }
    199 
    200   void Close() override final;
    201 
    202 protected:
    203   void OnReadEvent() override final;
    204   void OnWriteEvent() override final;
    205   void OnHangupEvent() override final;
    206 
    207 private:
    208   SocketMultiplexer::CreateStreamSocketCallback m_accept_callback;
    209   SocketAddress m_local_address = {};
    210   u32 m_num_connections_accepted = 0;
    211 };
    212 
    213 class StreamSocket : public BaseSocket
    214 {
    215 public:
    216   StreamSocket(SocketMultiplexer& multiplexer, SocketDescriptor descriptor);
    217   virtual ~StreamSocket() override;
    218 
    219   static u32 GetSocketProtocolForAddress(const SocketAddress& sa);
    220 
    221   virtual void Close() override;
    222 
    223   // Accessors
    224   const SocketAddress& GetLocalAddress() const { return m_local_address; }
    225   const SocketAddress& GetRemoteAddress() const { return m_remote_address; }
    226   bool IsConnected() const { return m_connected; }
    227 
    228   // Read/write
    229   size_t Read(void* buffer, size_t buffer_size);
    230   size_t Write(const void* buffer, size_t buffer_size);
    231   size_t WriteVector(const void** buffers, const size_t* buffer_lengths, size_t num_buffers);
    232 
    233   /// Disables Nagle's buffering algorithm, i.e. TCP_NODELAY.
    234   bool SetNagleBuffering(bool enabled, Error* error = nullptr);
    235 
    236 protected:
    237   virtual void OnConnected() = 0;
    238   virtual void OnDisconnected(const Error& error) = 0;
    239   virtual void OnRead() = 0;
    240 
    241   virtual void OnReadEvent() override;
    242   virtual void OnWriteEvent() override;
    243   virtual void OnHangupEvent() override;
    244 
    245   void CloseWithError();
    246 
    247 private:
    248   void InitialSetup();
    249 
    250   SocketAddress m_local_address = {};
    251   SocketAddress m_remote_address = {};
    252   std::recursive_mutex m_lock;
    253   bool m_connected = true;
    254 
    255   // Ugly, but needed in order to call the events.
    256   friend SocketMultiplexer;
    257   friend ListenSocket;
    258   friend BufferedStreamSocket;
    259 };
    260 
    261 class BufferedStreamSocket : public StreamSocket
    262 {
    263 public:
    264   BufferedStreamSocket(SocketMultiplexer& multiplexer, SocketDescriptor descriptor, size_t receive_buffer_size = 16384,
    265                        size_t send_buffer_size = 16384);
    266   virtual ~BufferedStreamSocket() override;
    267 
    268   // Must hold the lock when not part of OnRead().
    269   std::unique_lock<std::recursive_mutex> GetLock();
    270   std::span<const u8> AcquireReadBuffer() const;
    271   void ReleaseReadBuffer(size_t bytes_consumed);
    272   std::span<u8> AcquireWriteBuffer(size_t wanted_bytes, bool allow_smaller = false);
    273   void ReleaseWriteBuffer(size_t bytes_written, bool commit = true);
    274 
    275   // Hide StreamSocket read/write methods.
    276   size_t Read(void* buffer, size_t buffer_size);
    277   size_t Write(const void* buffer, size_t buffer_size);
    278   size_t WriteVector(const void** buffers, const size_t* buffer_lengths, size_t num_buffers);
    279   virtual void Close() override;
    280 
    281 protected:
    282   void OnReadEvent() override final;
    283   void OnWriteEvent() override final;
    284   virtual void OnWrite();
    285 
    286 private:
    287   std::vector<u8> m_receive_buffer;
    288   size_t m_receive_buffer_offset = 0;
    289   size_t m_receive_buffer_size = 0;
    290 
    291   std::vector<u8> m_send_buffer;
    292   size_t m_send_buffer_offset = 0;
    293   size_t m_send_buffer_size = 0;
    294 };