duckstation

duckstation, archived from the revision just before upstream turned it into a proprietary software project; this version is the libre one.
git clone https://git.neptards.moe/u3shit/duckstation.git
Log | Files | Refs | README | LICENSE

sockets.cpp (30754B)


      1 // SPDX-FileCopyrightText: 2015-2024 Connor McLaughlin <stenzek@gmail.com>
      2 // SPDX-License-Identifier: (GPL-3.0 OR CC-BY-NC-ND-4.0)
      3 
      4 #include "sockets.h"
      5 #include "platform_misc.h"
      6 
      7 #include "common/assert.h"
      8 #include "common/log.h"
      9 
     10 #include <algorithm>
     11 #include <cstring>
     12 #include <limits>
     13 
     14 #ifndef __APPLE__
     15 #include <malloc.h> // alloca
     16 #else
     17 #include <alloca.h>
     18 #endif
     19 
     20 #ifdef _WIN32
     21 
     22 #include "common/windows_headers.h"
     23 
     24 #include <WS2tcpip.h>
     25 #include <WinSock2.h>
     26 
     27 #define SIZE_CAST(x) static_cast<int>(x)
     28 using ssize_t = int;
     29 using nfds_t = ULONG;
     30 
     31 #else
     32 
     33 #include <arpa/inet.h>
     34 #include <errno.h>
     35 #include <netinet/in.h>
     36 #include <netinet/tcp.h>
     37 #include <poll.h>
     38 #include <sys/ioctl.h>
     39 #include <sys/socket.h>
     40 #include <sys/types.h>
     41 #include <sys/uio.h>
     42 #include <sys/un.h>
     43 #include <unistd.h>
     44 
     45 #ifdef __linux__
     46 #include <sys/epoll.h>
     47 #endif
     48 
     49 #define ioctlsocket ioctl
     50 #define closesocket close
     51 #define WSAEWOULDBLOCK EAGAIN
     52 #define WSAGetLastError() errno
     53 #define WSAPoll poll
     54 #define SIZE_CAST(x) x
     55 
     56 #define SOCKET_ERROR -1
     57 #define INVALID_SOCKET -1
     58 #define SD_BOTH SHUT_RDWR
     59 #endif
     60 
     61 Log_SetChannel(Sockets);
     62 
     63 static bool SetNonBlocking(SocketDescriptor sd, Error* error)
     64 {
     65   // switch to nonblocking mode
     66   unsigned long value = 1;
     67   if (ioctlsocket(sd, FIONBIO, &value) < 0)
     68   {
     69     Error::SetSocket(error, "ioctlsocket() failed: ", WSAGetLastError());
     70     return false;
     71   }
     72 
     73   return true;
     74 }
     75 
     76 void SocketAddress::SetFromSockaddr(const void* sa, size_t length)
     77 {
     78   m_length = std::min(static_cast<u32>(length), static_cast<u32>(sizeof(m_data)));
     79   std::memcpy(m_data, sa, m_length);
     80   if (m_length < sizeof(m_data))
     81     std::memset(m_data + m_length, 0, sizeof(m_data) - m_length);
     82 }
     83 
     84 bool SocketAddress::IsIPAddress() const
     85 {
     86   const sockaddr* addr = reinterpret_cast<const sockaddr*>(m_data);
     87   return (addr->sa_family == AF_INET || addr->sa_family == AF_INET6);
     88 }
     89 
     90 std::optional<SocketAddress> SocketAddress::Parse(Type type, const char* address, u32 port, Error* error)
     91 {
     92   std::optional<SocketAddress> ret = SocketAddress();
     93 
     94   switch (type)
     95   {
     96     case Type::IPv4:
     97     {
     98       sockaddr_in* sain = reinterpret_cast<sockaddr_in*>(ret->m_data);
     99       std::memset(sain, 0, sizeof(sockaddr_in));
    100       sain->sin_family = AF_INET;
    101       sain->sin_port = htons(static_cast<u16>(port));
    102       int res = inet_pton(AF_INET, address, &sain->sin_addr);
    103       if (res == 1)
    104       {
    105         ret->m_length = sizeof(sockaddr_in);
    106       }
    107       else
    108       {
    109         Error::SetSocket(error, "inet_pton() failed: ", WSAGetLastError());
    110         ret.reset();
    111       }
    112     }
    113     break;
    114 
    115     case Type::IPv6:
    116     {
    117       sockaddr_in6* sain6 = reinterpret_cast<sockaddr_in6*>(ret->m_data);
    118       std::memset(sain6, 0, sizeof(sockaddr_in6));
    119       sain6->sin6_family = AF_INET;
    120       sain6->sin6_port = htons(static_cast<u16>(port));
    121       int res = inet_pton(AF_INET6, address, &sain6->sin6_addr);
    122       if (res == 1)
    123       {
    124         ret->m_length = sizeof(sockaddr_in6);
    125       }
    126       else
    127       {
    128         Error::SetSocket(error, "inet_pton() failed: ", WSAGetLastError());
    129         ret.reset();
    130       }
    131     }
    132     break;
    133 
    134 #ifndef _WIN32
    135     case Type::Unix:
    136     {
    137       sockaddr_un* sun = reinterpret_cast<sockaddr_un*>(ret->m_data);
    138       std::memset(sun, 0, sizeof(sockaddr_un));
    139       sun->sun_family = AF_UNIX;
    140 
    141       const size_t len = std::strlen(address);
    142       if ((len + 1) <= std::size(sun->sun_path))
    143       {
    144         std::memcpy(sun->sun_path, address, len);
    145         ret->m_length = sizeof(sockaddr_un);
    146       }
    147       else
    148       {
    149         Error::SetStringFmt(error, "Path length {} exceeds {} bytes.", len, std::size(sun->sun_path));
    150         ret.reset();
    151       }
    152     }
    153     break;
    154 #endif
    155 
    156     default:
    157       Error::SetStringView(error, "Unknown address type.");
    158       ret.reset();
    159       break;
    160   }
    161 
    162   return ret;
    163 }
    164 
    165 SmallString SocketAddress::ToString() const
    166 {
    167   SmallString ret;
    168 
    169   const sockaddr* sa = reinterpret_cast<const sockaddr*>(m_data);
    170   switch (sa->sa_family)
    171   {
    172     case AF_INET:
    173     {
    174       ret.clear();
    175       ret.reserve(128);
    176       const char* res =
    177         inet_ntop(AF_INET, &reinterpret_cast<const sockaddr_in*>(m_data)->sin_addr, ret.data(), ret.buffer_size());
    178       if (res == nullptr)
    179         ret.assign("<unknown>");
    180       else
    181         ret.update_size();
    182 
    183       ret.append_format(":{}", static_cast<u32>(ntohs(reinterpret_cast<const sockaddr_in*>(m_data)->sin_port)));
    184     }
    185     break;
    186 
    187     case AF_INET6:
    188     {
    189       ret.clear();
    190       ret.reserve(128);
    191       ret.append('[');
    192       const char* res = inet_ntop(AF_INET6, &reinterpret_cast<const sockaddr_in6*>(m_data)->sin6_addr, ret.data() + 1,
    193                                   ret.buffer_size() - 1);
    194       if (res == nullptr)
    195         ret.assign("<unknown>");
    196       else
    197         ret.update_size();
    198 
    199       ret.append_format("]:{}", static_cast<u32>(ntohs(reinterpret_cast<const sockaddr_in6*>(m_data)->sin6_port)));
    200     }
    201     break;
    202 
    203 #ifndef _WIN32
    204     case AF_UNIX:
    205     {
    206       ret.assign(reinterpret_cast<const sockaddr_un*>(m_data)->sun_path);
    207     }
    208     break;
    209 #endif
    210 
    211     default:
    212     {
    213       ret.assign("<unknown>");
    214       break;
    215     }
    216   }
    217 
    218   return ret;
    219 }
    220 
    221 BaseSocket::BaseSocket(SocketMultiplexer& multiplexer, SocketDescriptor descriptor)
    222   : m_multiplexer(multiplexer), m_descriptor(descriptor)
    223 {
    224 }
    225 
    226 BaseSocket::~BaseSocket() = default;
    227 
    228 SocketMultiplexer::SocketMultiplexer() = default;
    229 
    230 SocketMultiplexer::~SocketMultiplexer()
    231 {
    232   CloseAll();
    233 
    234 #ifdef __linux__
    235   if (m_epoll_fd >= 0)
    236     close(m_epoll_fd);
    237 #else
    238   if (m_poll_array)
    239     std::free(m_poll_array);
    240 #endif
    241 }
    242 
    243 std::unique_ptr<SocketMultiplexer> SocketMultiplexer::Create(Error* error)
    244 {
    245   std::unique_ptr<SocketMultiplexer> ret;
    246   if (PlatformMisc::InitializeSocketSupport(error))
    247   {
    248     ret = std::unique_ptr<SocketMultiplexer>(new SocketMultiplexer());
    249     if (!ret->Initialize(error))
    250       ret.reset();
    251   }
    252 
    253   return ret;
    254 }
    255 
    256 bool SocketMultiplexer::Initialize(Error* error)
    257 {
    258 #ifdef __linux__
    259   m_epoll_fd = epoll_create1(0);
    260   if (m_epoll_fd < 0)
    261   {
    262     Error::SetErrno(error, "epoll_create1() failed: ", errno);
    263     return false;
    264   }
    265 
    266   return true;
    267 #else
    268   return true;
    269 #endif
    270 }
    271 
    272 std::shared_ptr<ListenSocket> SocketMultiplexer::InternalCreateListenSocket(const SocketAddress& address,
    273                                                                             CreateStreamSocketCallback callback,
    274                                                                             Error* error)
    275 {
    276   // create and bind socket
    277   const sockaddr* sa = reinterpret_cast<const sockaddr*>(address.GetData());
    278   SocketDescriptor descriptor = socket(sa->sa_family, SOCK_STREAM, StreamSocket::GetSocketProtocolForAddress(address));
    279   if (descriptor == INVALID_SOCKET)
    280   {
    281     Error::SetSocket(error, "socket() failed: ", WSAGetLastError());
    282     return {};
    283   }
    284 
    285   const int reuseaddr_enable = 1;
    286   if (setsockopt(descriptor, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<const char*>(&reuseaddr_enable),
    287                  sizeof(reuseaddr_enable)) < 0)
    288   {
    289     WARNING_LOG("Failed to set SO_REUSEADDR: {}", Error::CreateSocket(WSAGetLastError()).GetDescription());
    290   }
    291 
    292   if (bind(descriptor, sa, address.GetLength()) < 0)
    293   {
    294     Error::SetSocket(error, "bind() failed: ", WSAGetLastError());
    295     closesocket(descriptor);
    296     return {};
    297   }
    298 
    299   if (listen(descriptor, 5) < 0)
    300   {
    301     Error::SetSocket(error, "listen() failed: ", WSAGetLastError());
    302     closesocket(descriptor);
    303     return {};
    304   }
    305 
    306   if (!SetNonBlocking(descriptor, error))
    307   {
    308     closesocket(descriptor);
    309     return {};
    310   }
    311 
    312   // create listensocket
    313   std::shared_ptr<ListenSocket> ret = std::make_shared<ListenSocket>(*this, descriptor, callback);
    314 
    315   // add to list, register for reads
    316   AddOpenSocket(std::static_pointer_cast<BaseSocket>(ret));
    317   SetNotificationMask(ret.get(), descriptor, POLLIN);
    318   return ret;
    319 }
    320 
    321 std::shared_ptr<StreamSocket> SocketMultiplexer::InternalConnectStreamSocket(const SocketAddress& address,
    322                                                                              CreateStreamSocketCallback callback,
    323                                                                              Error* error)
    324 {
    325   // create and bind socket
    326   const sockaddr* sa = reinterpret_cast<const sockaddr*>(address.GetData());
    327   SocketDescriptor descriptor = socket(sa->sa_family, SOCK_STREAM, StreamSocket::GetSocketProtocolForAddress(address));
    328   if (descriptor == INVALID_SOCKET)
    329   {
    330     Error::SetSocket(error, "socket() failed: ", WSAGetLastError());
    331     return {};
    332   }
    333 
    334   if (connect(descriptor, sa, address.GetLength()) < 0)
    335   {
    336     Error::SetSocket(error, "connect() failed: ", WSAGetLastError());
    337     closesocket(descriptor);
    338     return {};
    339   }
    340 
    341   if (!SetNonBlocking(descriptor, error))
    342   {
    343     closesocket(descriptor);
    344     return {};
    345   }
    346 
    347   // create stream socket
    348   std::shared_ptr<StreamSocket> csocket = callback(*this, descriptor);
    349   csocket->InitialSetup();
    350   if (!csocket->IsConnected())
    351     csocket.reset();
    352 
    353   return csocket;
    354 }
    355 
    356 void SocketMultiplexer::AddOpenSocket(std::shared_ptr<BaseSocket> socket)
    357 {
    358 #ifdef __linux__
    359   struct epoll_event ev = {.events = 0u, .data = {.fd = socket->GetDescriptor()}};
    360   if (epoll_ctl(m_epoll_fd, EPOLL_CTL_ADD, socket->GetDescriptor(), &ev) != 0) [[unlikely]]
    361     ERROR_LOG("epoll_ctl() to add socket failed: {}", Error::CreateErrno(errno).GetDescription());
    362 #endif
    363 
    364   std::unique_lock lock(m_open_sockets_lock);
    365   DebugAssert(m_open_sockets.find(socket->GetDescriptor()) == m_open_sockets.end());
    366   m_open_sockets.emplace(socket->GetDescriptor(), std::move(socket));
    367 }
    368 
    369 void SocketMultiplexer::AddClientSocket(std::shared_ptr<BaseSocket> socket)
    370 {
    371   AddOpenSocket(std::move(socket));
    372   m_client_socket_count.fetch_add(1, std::memory_order_acq_rel);
    373 }
    374 
    375 void SocketMultiplexer::RemoveOpenSocket(BaseSocket* socket)
    376 {
    377   std::unique_lock lock(m_open_sockets_lock);
    378   const auto iter = m_open_sockets.find(socket->GetDescriptor());
    379   Assert(iter != m_open_sockets.end());
    380   m_open_sockets.erase(iter);
    381 
    382 #ifdef __linux__
    383   if (epoll_ctl(m_epoll_fd, EPOLL_CTL_DEL, socket->GetDescriptor(), nullptr) != 0) [[unlikely]]
    384     ERROR_LOG("epoll_ctl() to remove socket failed: {}", Error::CreateErrno(errno).GetDescription());
    385 #else
    386 #ifdef _DEBUG
    387   for (size_t i = 0; i < m_poll_array_active_size; i++)
    388   {
    389     pollfd& pfd = m_poll_array[i];
    390     DebugAssert(pfd.fd != socket->GetDescriptor());
    391   }
    392 #endif
    393 
    394   // Update size.
    395   size_t new_active_size = 0;
    396   for (size_t i = 0; i < m_poll_array_active_size; i++)
    397     new_active_size = (m_poll_array[i].fd != INVALID_SOCKET) ? (i + 1) : new_active_size;
    398   m_poll_array_active_size = new_active_size;
    399 #endif
    400 }
    401 
    402 void SocketMultiplexer::RemoveClientSocket(BaseSocket* socket)
    403 {
    404   DebugAssert(m_client_socket_count.load(std::memory_order_acquire) > 0);
    405   m_client_socket_count.fetch_sub(1, std::memory_order_acq_rel);
    406   RemoveOpenSocket(socket);
    407 }
    408 
    409 bool SocketMultiplexer::HasAnyOpenSockets()
    410 {
    411   std::unique_lock lock(m_open_sockets_lock);
    412   return !m_open_sockets.empty();
    413 }
    414 
    415 bool SocketMultiplexer::HasAnyClientSockets()
    416 {
    417   return (GetClientSocketCount() > 0);
    418 }
    419 
    420 size_t SocketMultiplexer::GetClientSocketCount()
    421 {
    422   return m_client_socket_count.load(std::memory_order_acquire);
    423 }
    424 
    425 void SocketMultiplexer::CloseAll()
    426 {
    427   std::unique_lock lock(m_open_sockets_lock);
    428 
    429   while (!m_open_sockets.empty())
    430   {
    431     std::shared_ptr<BaseSocket> socket = m_open_sockets.begin()->second;
    432     lock.unlock();
    433     socket->Close();
    434     lock.lock();
    435   }
    436 }
    437 
    438 void SocketMultiplexer::SetNotificationMask(BaseSocket* socket, SocketDescriptor descriptor, u32 events)
    439 {
    440 #ifdef __linux__
    441   struct epoll_event ev = {.events = events, .data = {.fd = descriptor}};
    442   if (epoll_ctl(m_epoll_fd, EPOLL_CTL_MOD, descriptor, &ev) != 0) [[unlikely]]
    443     ERROR_LOG("epoll_ctl() for events 0x{:x} failed: {}", events, Error::CreateErrno(errno).GetDescription());
    444 #else
    445   std::unique_lock lock(m_poll_array_lock);
    446   size_t free_slot = m_poll_array_active_size;
    447   for (size_t i = 0; i < m_poll_array_active_size; i++)
    448   {
    449     pollfd& pfd = m_poll_array[i];
    450     if (pfd.fd != descriptor)
    451     {
    452       free_slot = (pfd.fd < 0 && free_slot != m_poll_array_active_size) ? i : free_slot;
    453       continue;
    454     }
    455 
    456     // unbinding?
    457     if (events != 0)
    458       pfd.events = static_cast<short>(events);
    459     else
    460       pfd.fd = INVALID_SOCKET;
    461 
    462     return;
    463   }
    464 
    465   // don't create entries for null masks
    466   if (events == 0)
    467     return;
    468 
    469   // need to grow the array?
    470   if (free_slot == m_poll_array_max_size)
    471   {
    472     const size_t new_size = std::max(free_slot + 1, free_slot * 2);
    473     pollfd* new_array = static_cast<pollfd*>(std::realloc(m_poll_array, sizeof(pollfd) * new_size));
    474     if (!new_array)
    475       Panic("Memory allocation failed.");
    476 
    477     for (size_t i = m_poll_array_max_size; i < new_size; i++)
    478       new_array[i] = {.fd = INVALID_SOCKET, .events = 0, .revents = 0};
    479     m_poll_array = new_array;
    480     m_poll_array_max_size = new_size;
    481   }
    482 
    483   m_poll_array[free_slot] = {.fd = descriptor, .events = static_cast<short>(events), .revents = 0};
    484   m_poll_array_active_size = free_slot + 1;
    485 #endif
    486 }
    487 
// Waits up to `milliseconds` for socket activity and dispatches it to the affected sockets.
// For each triggered socket, a hangup/error condition calls OnHangupEvent() exclusively; otherwise readable data
// calls OnReadEvent() and writability calls OnWriteEvent().
// Returns false if nothing happened (timeout), the wait failed, or (non-Linux) no descriptors are registered;
// returns true once the triggered sockets have been dispatched.
bool SocketMultiplexer::PollEventsWithTimeout(u32 milliseconds)
{
#ifdef __linux__
  // At most MAX_EVENTS sockets are reported per call; any others stay ready for the next call.
  constexpr int MAX_EVENTS = 128;
  struct epoll_event events[MAX_EVENTS];

  const int nevents = epoll_wait(m_epoll_fd, events, MAX_EVENTS, static_cast<int>(milliseconds));
  if (nevents <= 0)
    return false;

  // Collect the triggered sockets (each with a strong reference) into a temporary array first, so the event
  // handlers below run without m_open_sockets_lock held and a socket disconnecting mid-dispatch stays alive.
  // The array lives on the stack via alloca(): entries are placement-constructed and destroyed manually below.
  using PendingSocketPair = std::pair<std::shared_ptr<BaseSocket>, u32>;
  PendingSocketPair* triggered_sockets =
    reinterpret_cast<PendingSocketPair*>(alloca(sizeof(PendingSocketPair) * static_cast<size_t>(nevents)));
  size_t num_triggered_sockets = 0;
  {
    std::unique_lock open_lock(m_open_sockets_lock);
    for (int i = 0; i < nevents; i++)
    {
      const epoll_event& ev = events[i];
      const auto iter = m_open_sockets.find(ev.data.fd);
      if (iter == m_open_sockets.end()) [[unlikely]]
      {
        ERROR_LOG("Attempting to look up unknown socket {}, this should never happen.", ev.data.fd);
        continue;
      }

      // we add a reference here in case the read kills it with a write pending, or something like that
      new (&triggered_sockets[num_triggered_sockets++]) PendingSocketPair(iter->second->shared_from_this(), ev.events);
    }
  }

  // fire events
  for (size_t i = 0; i < num_triggered_sockets; i++)
  {
    PendingSocketPair& psp = triggered_sockets[i];

    // Hangup/error takes priority: when set, read/write handlers are not called for this socket.
    if (psp.second & (EPOLLRDHUP | EPOLLHUP | EPOLLERR))
    {
      psp.first->OnHangupEvent();
    }
    else
    {
      if (psp.second & EPOLLIN)
        psp.first->OnReadEvent();
      if (psp.second & EPOLLOUT)
        psp.first->OnWriteEvent();
    }

    // Manually destroy the placement-constructed pair, releasing the temporary reference.
    psp.first.~shared_ptr();
  }

  return true;
#else
  // Lock order here is m_poll_array_lock, then m_open_sockets_lock. The poll-array lock is held across WSAPoll(),
  // so SetNotificationMask() from another thread blocks until the poll returns.
  std::unique_lock lock(m_poll_array_lock);
  if (m_poll_array_active_size == 0)
    return false;

  // NOTE(review): `milliseconds` is converted implicitly to the int timeout parameter here, whereas the Linux
  // path uses an explicit static_cast<int>.
  const int res = WSAPoll(m_poll_array, static_cast<nfds_t>(m_poll_array_active_size), milliseconds);
  if (res <= 0)
    return false;

  // `res` is the number of entries with non-zero revents, which bounds how many sockets can be collected below.
  // As on Linux, triggered sockets are gathered with strong references into an alloca()'d array so the handlers
  // can run without either lock held.
  using PendingSocketPair = std::pair<std::shared_ptr<BaseSocket>, u32>;
  PendingSocketPair* triggered_sockets =
    reinterpret_cast<PendingSocketPair*>(alloca(sizeof(PendingSocketPair) * static_cast<size_t>(res)));
  size_t num_triggered_sockets = 0;
  {
    std::unique_lock open_lock(m_open_sockets_lock);
    for (size_t i = 0; i < m_poll_array_active_size; i++)
    {
      const pollfd& pfd = m_poll_array[i];
      if (pfd.revents == 0)
        continue;

      const auto iter = m_open_sockets.find(pfd.fd);
      if (iter == m_open_sockets.end()) [[unlikely]]
      {
        ERROR_LOG("Attempting to look up unknown socket {}, this should never happen.", pfd.fd);
        continue;
      }

      // we add a reference here in case the read kills it with a write pending, or something like that
      new (&triggered_sockets[num_triggered_sockets++])
        PendingSocketPair(iter->second->shared_from_this(), pfd.revents);
    }
  }

  // release lock so connections etc can acquire it
  lock.unlock();

  // fire events
  for (size_t i = 0; i < num_triggered_sockets; i++)
  {
    PendingSocketPair& psp = triggered_sockets[i];

    // Hangup/error takes priority over read/write.
    // NOTE(review): POLLNVAL is not handled, so an invalid descriptor only reporting POLLNVAL is never dispatched.
    if (psp.second & (POLLHUP | POLLERR))
    {
      psp.first->OnHangupEvent();
    }
    else
    {
      if (psp.second & POLLIN)
        psp.first->OnReadEvent();
      if (psp.second & POLLOUT)
        psp.first->OnWriteEvent();
    }

    // Manually destroy the placement-constructed pair, releasing the temporary reference.
    psp.first.~shared_ptr();
  }

  return true;
#endif
}
    604 
    605 ListenSocket::ListenSocket(SocketMultiplexer& multiplexer, SocketDescriptor descriptor,
    606                            SocketMultiplexer::CreateStreamSocketCallback accept_callback)
    607   : BaseSocket(multiplexer, descriptor), m_accept_callback(accept_callback)
    608 {
    609   // get local address
    610   sockaddr_storage sa;
    611   socklen_t salen = sizeof(sa);
    612   if (getsockname(m_descriptor, reinterpret_cast<sockaddr*>(&sa), &salen) == 0)
    613     m_local_address.SetFromSockaddr(&sa, salen);
    614 }
    615 
    616 ListenSocket::~ListenSocket()
    617 {
    618   DebugAssert(m_descriptor == INVALID_SOCKET);
    619 }
    620 
    621 void ListenSocket::Close()
    622 {
    623   if (m_descriptor < 0)
    624     return;
    625 
    626   m_multiplexer.SetNotificationMask(this, m_descriptor, 0);
    627   m_multiplexer.RemoveOpenSocket(this);
    628   closesocket(m_descriptor);
    629   m_descriptor = INVALID_SOCKET;
    630 }
    631 
    632 void ListenSocket::OnReadEvent()
    633 {
    634   // connection incoming
    635   sockaddr_storage sa;
    636   socklen_t salen = sizeof(sa);
    637   SocketDescriptor new_descriptor = accept(m_descriptor, reinterpret_cast<sockaddr*>(&sa), &salen);
    638   if (new_descriptor == INVALID_SOCKET)
    639   {
    640     ERROR_LOG("accept() returned {}", Error::CreateSocket(WSAGetLastError()).GetDescription());
    641     return;
    642   }
    643 
    644   Error error;
    645   if (!SetNonBlocking(new_descriptor, &error))
    646   {
    647     ERROR_LOG("Failed to set just-connected socket to nonblocking: {}", error.GetDescription());
    648     closesocket(new_descriptor);
    649     return;
    650   }
    651 
    652   // create socket, we release our own reference.
    653   std::shared_ptr<StreamSocket> client = m_accept_callback(m_multiplexer, new_descriptor);
    654   if (!client)
    655   {
    656     closesocket(new_descriptor);
    657     return;
    658   }
    659 
    660   m_num_connections_accepted++;
    661   client->InitialSetup();
    662 }
    663 
    664 void ListenSocket::OnWriteEvent()
    665 {
    666   ERROR_LOG("Unexpected OnWriteEvent() in ListenSocket {}", m_local_address.ToString());
    667 }
    668 
    669 void ListenSocket::OnHangupEvent()
    670 {
    671   ERROR_LOG("Unexpected OnHangupEvent() in ListenSocket {}", m_local_address.ToString());
    672 }
    673 
    674 StreamSocket::StreamSocket(SocketMultiplexer& multiplexer, SocketDescriptor descriptor)
    675   : BaseSocket(multiplexer, descriptor)
    676 {
    677   // get local address
    678   sockaddr_storage sa;
    679   socklen_t salen = sizeof(sa);
    680   if (getsockname(m_descriptor, reinterpret_cast<sockaddr*>(&sa), &salen) == 0)
    681     m_local_address.SetFromSockaddr(&sa, salen);
    682 
    683   // get remote address
    684   salen = sizeof(sockaddr_storage);
    685   if (getpeername(m_descriptor, reinterpret_cast<sockaddr*>(&sa), &salen) == 0)
    686     m_remote_address.SetFromSockaddr(&sa, salen);
    687 }
    688 
    689 StreamSocket::~StreamSocket()
    690 {
    691   DebugAssert(m_descriptor == INVALID_SOCKET);
    692 }
    693 
    694 u32 StreamSocket::GetSocketProtocolForAddress(const SocketAddress& sa)
    695 {
    696   const sockaddr* ssa = reinterpret_cast<const sockaddr*>(sa.GetData());
    697   return (ssa->sa_family == AF_INET || ssa->sa_family == AF_INET6) ? IPPROTO_TCP : 0;
    698 }
    699 
    700 void StreamSocket::InitialSetup()
    701 {
    702   // register for notifications
    703   m_multiplexer.AddClientSocket(shared_from_this());
    704   m_multiplexer.SetNotificationMask(this, m_descriptor, POLLIN);
    705 
    706   // trigger connected notification
    707   std::unique_lock lock(m_lock);
    708   OnConnected();
    709 }
    710 
    711 size_t StreamSocket::Read(void* buffer, size_t buffer_size)
    712 {
    713   std::unique_lock lock(m_lock);
    714   if (!m_connected)
    715     return 0;
    716 
    717   // try a read
    718   const ssize_t len = recv(m_descriptor, static_cast<char*>(buffer), SIZE_CAST(buffer_size), 0);
    719   if (len <= 0)
    720   {
    721     // Check for EAGAIN
    722     if (len < 0 && WSAGetLastError() == WSAEWOULDBLOCK)
    723     {
    724       // Not an error. Just means no data is available.
    725       return 0;
    726     }
    727 
    728     // error
    729     CloseWithError();
    730     return 0;
    731   }
    732 
    733   return len;
    734 }
    735 
    736 size_t StreamSocket::Write(const void* buffer, size_t buffer_size)
    737 {
    738   std::unique_lock lock(m_lock);
    739   if (!m_connected)
    740     return 0;
    741 
    742   // try a write
    743   const ssize_t len = send(m_descriptor, static_cast<const char*>(buffer), SIZE_CAST(buffer_size), 0);
    744   if (len <= 0)
    745   {
    746     // Check for EAGAIN
    747     if (len < 0 && WSAGetLastError() == WSAEWOULDBLOCK)
    748     {
    749       // Not an error. Just means no data is available.
    750       return 0;
    751     }
    752 
    753     // error
    754     CloseWithError();
    755     return 0;
    756   }
    757 
    758   return len;
    759 }
    760 
    761 size_t StreamSocket::WriteVector(const void** buffers, const size_t* buffer_lengths, size_t num_buffers)
    762 {
    763   std::unique_lock lock(m_lock);
    764   if (!m_connected || num_buffers == 0)
    765     return 0;
    766 
    767 #ifdef _WIN32
    768 
    769   WSABUF* bufs = static_cast<WSABUF*>(alloca(sizeof(WSABUF) * num_buffers));
    770   for (size_t i = 0; i < num_buffers; i++)
    771   {
    772     bufs[i].buf = (CHAR*)buffers[i];
    773     bufs[i].len = (ULONG)buffer_lengths[i];
    774   }
    775 
    776   DWORD bytesSent = 0;
    777   if (WSASend(m_descriptor, bufs, (DWORD)num_buffers, &bytesSent, 0, nullptr, nullptr) == SOCKET_ERROR)
    778   {
    779     if (WSAGetLastError() != WSAEWOULDBLOCK)
    780     {
    781       // Socket error.
    782       CloseWithError();
    783       return 0;
    784     }
    785   }
    786 
    787   return static_cast<size_t>(bytesSent);
    788 
    789 #else // _WIN32
    790 
    791   iovec* bufs = static_cast<iovec*>(alloca(sizeof(iovec) * num_buffers));
    792   for (size_t i = 0; i < num_buffers; i++)
    793   {
    794     bufs[i].iov_base = (void*)buffers[i];
    795     bufs[i].iov_len = buffer_lengths[i];
    796   }
    797 
    798   ssize_t res = writev(m_descriptor, bufs, num_buffers);
    799   if (res < 0)
    800   {
    801     if (errno != EAGAIN)
    802     {
    803       // Socket error.
    804       CloseWithError();
    805       return 0;
    806     }
    807 
    808     res = 0;
    809   }
    810 
    811   return static_cast<size_t>(res);
    812 
    813 #endif
    814 }
    815 
    816 bool StreamSocket::SetNagleBuffering(bool enabled, Error* error /* = nullptr */)
    817 {
    818   if (!m_local_address.IsIPAddress())
    819   {
    820     Error::SetStringView(error, "Attempting to disable nagle on a non-IP socket.");
    821     return false;
    822   }
    823 
    824   int disable = enabled ? 0 : 1;
    825   if (setsockopt(m_descriptor, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast<char*>(&disable), sizeof(disable)) != 0)
    826   {
    827     Error::SetSocket(error, "setsockopt(TCP_NODELAY) failed: ", WSAGetLastError());
    828     return false;
    829   }
    830 
    831   return true;
    832 }
    833 
    834 void StreamSocket::Close()
    835 {
    836   std::unique_lock lock(m_lock);
    837   if (!m_connected)
    838     return;
    839 
    840   m_multiplexer.SetNotificationMask(this, m_descriptor, 0);
    841   m_multiplexer.RemoveClientSocket(this);
    842   shutdown(m_descriptor, SD_BOTH);
    843   closesocket(m_descriptor);
    844   m_descriptor = INVALID_SOCKET;
    845   m_connected = false;
    846 
    847   OnDisconnected(Error::CreateString("Connection explicitly closed."));
    848 }
    849 
    850 void StreamSocket::CloseWithError()
    851 {
    852   std::unique_lock lock(m_lock);
    853   DebugAssert(m_connected);
    854 
    855   Error error;
    856   const int error_code = WSAGetLastError();
    857   if (error_code == 0)
    858     error.SetStringView("Connection closed by peer.");
    859   else
    860     error.SetSocket(error_code);
    861 
    862   m_multiplexer.SetNotificationMask(this, m_descriptor, 0);
    863   m_multiplexer.RemoveClientSocket(this);
    864   closesocket(m_descriptor);
    865   m_descriptor = INVALID_SOCKET;
    866   m_connected = false;
    867 
    868   OnDisconnected(error);
    869 }
    870 
    871 void StreamSocket::OnReadEvent()
    872 {
    873   // forward through
    874   std::unique_lock lock(m_lock);
    875   if (m_connected)
    876     OnRead();
    877 }
    878 
    879 void StreamSocket::OnWriteEvent()
    880 {
    881   // shouldn't be called
    882 }
    883 
    884 void StreamSocket::OnHangupEvent()
    885 {
    886   std::unique_lock lock(m_lock);
    887   if (!m_connected)
    888     return;
    889 
    890   m_multiplexer.SetNotificationMask(this, m_descriptor, 0);
    891   m_multiplexer.RemoveClientSocket(this);
    892   closesocket(m_descriptor);
    893   m_descriptor = INVALID_SOCKET;
    894   m_connected = false;
    895 
    896   OnDisconnected(Error::CreateString("Connection closed by peer."));
    897 }
    898 
    899 BufferedStreamSocket::BufferedStreamSocket(SocketMultiplexer& multiplexer, SocketDescriptor descriptor,
    900                                            size_t receive_buffer_size /* = 16384 */,
    901                                            size_t send_buffer_size /* = 16384 */)
    902   : StreamSocket(multiplexer, descriptor), m_receive_buffer(receive_buffer_size), m_send_buffer(send_buffer_size)
    903 {
    904 }
    905 
    906 BufferedStreamSocket::~BufferedStreamSocket()
    907 {
    908 }
    909 
    910 std::unique_lock<std::recursive_mutex> BufferedStreamSocket::GetLock()
    911 {
    912   return std::unique_lock(m_lock);
    913 }
    914 
    915 std::span<const u8> BufferedStreamSocket::AcquireReadBuffer() const
    916 {
    917   return std::span<const u8>(m_receive_buffer.data() + m_receive_buffer_offset, m_receive_buffer_size);
    918 }
    919 
    920 void BufferedStreamSocket::ReleaseReadBuffer(size_t bytes_consumed)
    921 {
    922   DebugAssert(bytes_consumed <= m_receive_buffer_size);
    923   m_receive_buffer_offset += static_cast<u32>(bytes_consumed);
    924   m_receive_buffer_size -= static_cast<u32>(bytes_consumed);
    925 
    926   // Anything left? If not, reset offset.
    927   m_receive_buffer_offset = (m_receive_buffer_size == 0) ? 0 : m_receive_buffer_offset;
    928 }
    929 
    930 std::span<u8> BufferedStreamSocket::AcquireWriteBuffer(size_t wanted_bytes, bool allow_smaller /* = false */)
    931 {
    932   if (!m_connected)
    933     return {};
    934 
    935   // If to get the desired space, we need to move backwards, do so.
    936   if ((m_send_buffer_offset + m_send_buffer_size + wanted_bytes) > m_send_buffer.size())
    937   {
    938     if ((m_send_buffer_size + wanted_bytes) > m_send_buffer.size() && !allow_smaller)
    939     {
    940       // Not enough space.
    941       return {};
    942     }
    943 
    944     // Shuffle buffer backwards.
    945     std::memmove(m_send_buffer.data(), m_send_buffer.data() + m_send_buffer_offset, m_send_buffer_size);
    946     m_send_buffer_offset = 0;
    947   }
    948 
    949   DebugAssert((m_send_buffer_offset + m_send_buffer_size + wanted_bytes) <= m_send_buffer.size());
    950   return std::span<u8>(m_send_buffer.data() + m_send_buffer_offset + m_send_buffer_size,
    951                        m_send_buffer.size() - m_send_buffer_offset - m_send_buffer_size);
    952 }
    953 
    954 void BufferedStreamSocket::ReleaseWriteBuffer(size_t bytes_written, bool commit /* = true */)
    955 {
    956   if (!m_connected)
    957     return;
    958 
    959   DebugAssert((m_send_buffer_offset + m_send_buffer_size + bytes_written) <= m_send_buffer.size());
    960   m_send_buffer_size += static_cast<u32>(bytes_written);
    961 
    962   // Send as much as we can.
    963   if (commit && m_send_buffer_size > 0)
    964   {
    965     const ssize_t res = send(m_descriptor, reinterpret_cast<const char*>(m_send_buffer.data() + m_send_buffer_offset),
    966                              SIZE_CAST(m_send_buffer_size), 0);
    967     if (res < 0 && WSAGetLastError() != WSAEWOULDBLOCK)
    968     {
    969       CloseWithError();
    970       return;
    971     }
    972 
    973     m_send_buffer_offset += static_cast<size_t>(res);
    974     m_send_buffer_size -= static_cast<size_t>(res);
    975     if (m_send_buffer_size == 0)
    976     {
    977       m_send_buffer_offset = 0;
    978     }
    979     else
    980     {
    981       // Register for writes to finish it off.
    982       m_multiplexer.SetNotificationMask(this, m_descriptor, POLLIN | POLLOUT);
    983     }
    984   }
    985 }
    986 
    987 size_t BufferedStreamSocket::Read(void* buffer, size_t buffer_size)
    988 {
    989   // Read from receive buffer.
    990   const std::span<const u8> rdbuf = AcquireReadBuffer();
    991   if (rdbuf.empty())
    992     return 0;
    993 
    994   const size_t bytes_to_read = std::min(rdbuf.size(), buffer_size);
    995   std::memcpy(buffer, rdbuf.data(), bytes_to_read);
    996   ReleaseReadBuffer(bytes_to_read);
    997   return bytes_to_read;
    998 }
    999 
   1000 size_t BufferedStreamSocket::Write(const void* buffer, size_t buffer_size)
   1001 {
   1002   if (!m_connected)
   1003     return 0;
   1004 
   1005   // Read from receive buffer.
   1006   const std::span<u8> wrbuf = AcquireWriteBuffer(buffer_size, true);
   1007   if (wrbuf.empty())
   1008     return 0;
   1009 
   1010   const size_t bytes_to_write = std::min(wrbuf.size(), buffer_size);
   1011   std::memcpy(wrbuf.data(), buffer, bytes_to_write);
   1012   ReleaseWriteBuffer(bytes_to_write);
   1013   return bytes_to_write;
   1014 }
   1015 
   1016 size_t BufferedStreamSocket::WriteVector(const void** buffers, const size_t* buffer_lengths, size_t num_buffers)
   1017 {
   1018   if (!m_connected || num_buffers == 0)
   1019     return 0;
   1020 
   1021   size_t total_size = 0;
   1022   for (size_t i = 0; i < num_buffers; i++)
   1023     total_size += buffer_lengths[i];
   1024 
   1025   const std::span<u8> wrbuf = AcquireWriteBuffer(total_size, true);
   1026   if (wrbuf.empty())
   1027     return 0;
   1028 
   1029   size_t written_bytes = 0;
   1030   for (size_t i = 0; i < num_buffers; i++)
   1031   {
   1032     const size_t bytes_to_write = std::min(wrbuf.size() - written_bytes, buffer_lengths[i]);
   1033     if (bytes_to_write == 0)
   1034       break;
   1035 
   1036     std::memcpy(&wrbuf[written_bytes], buffers[i], bytes_to_write);
   1037     written_bytes += buffer_lengths[i];
   1038   }
   1039 
   1040   return written_bytes;
   1041 }
   1042 
   1043 void BufferedStreamSocket::Close()
   1044 {
   1045   StreamSocket::Close();
   1046 
   1047   m_receive_buffer_offset = 0;
   1048   m_receive_buffer_size = 0;
   1049   m_send_buffer_offset = 0;
   1050   m_send_buffer_size = 0;
   1051 }
   1052 
   1053 void BufferedStreamSocket::OnReadEvent()
   1054 {
   1055   std::unique_lock lock(m_lock);
   1056   if (!m_connected)
   1057     return;
   1058 
   1059   // Pull as many bytes as possible into the read buffer.
   1060   for (;;)
   1061   {
   1062     const size_t buffer_space = m_receive_buffer.size() - m_receive_buffer_offset - m_receive_buffer_size;
   1063     if (buffer_space == 0) [[unlikely]]
   1064     {
   1065       // If we're here again, it means OnRead() didn't consume the data, and we overflowed.
   1066       ERROR_LOG("Receive buffer overflow, dropping client {}.", GetRemoteAddress().ToString());
   1067       CloseWithError();
   1068       return;
   1069     }
   1070 
   1071     const ssize_t res = recv(
   1072       m_descriptor, reinterpret_cast<char*>(m_receive_buffer.data() + m_receive_buffer_offset + m_receive_buffer_size),
   1073       SIZE_CAST(buffer_space), 0);
   1074     if (res <= 0 && WSAGetLastError() != WSAEWOULDBLOCK)
   1075     {
   1076       CloseWithError();
   1077       return;
   1078     }
   1079 
   1080     m_receive_buffer_size += static_cast<size_t>(res);
   1081     OnRead();
   1082 
   1083     // Are we at the end?
   1084     if ((m_receive_buffer_offset + m_receive_buffer_size) == m_receive_buffer.size())
   1085     {
   1086       // Try to claw back some of the buffer, and try reading again.
   1087       if (m_receive_buffer_offset > 0)
   1088       {
   1089         std::memmove(m_receive_buffer.data(), m_receive_buffer.data() + m_receive_buffer_offset, m_receive_buffer_size);
   1090         m_receive_buffer_offset = 0;
   1091         continue;
   1092       }
   1093     }
   1094 
   1095     break;
   1096   }
   1097 }
   1098 
   1099 void BufferedStreamSocket::OnWriteEvent()
   1100 {
   1101   std::unique_lock lock(m_lock);
   1102   if (!m_connected)
   1103     return;
   1104 
   1105   // Send as much as we can.
   1106   if (m_send_buffer_size > 0)
   1107   {
   1108     const ssize_t res = send(m_descriptor, reinterpret_cast<const char*>(m_send_buffer.data() + m_send_buffer_offset),
   1109                              SIZE_CAST(m_send_buffer_size), 0);
   1110     if (res < 0 && WSAGetLastError() != WSAEWOULDBLOCK)
   1111     {
   1112       CloseWithError();
   1113       return;
   1114     }
   1115 
   1116     m_send_buffer_offset += static_cast<size_t>(res);
   1117     m_send_buffer_size -= static_cast<size_t>(res);
   1118     if (m_send_buffer_size == 0)
   1119       m_send_buffer_offset = 0;
   1120   }
   1121 
   1122   OnWrite();
   1123 
   1124   if (m_send_buffer_size == 0)
   1125   {
   1126     // Are we done? Switch back to reads only.
   1127     m_multiplexer.SetNotificationMask(this, m_descriptor, POLLIN);
   1128   }
   1129 }
   1130 
   1131 void BufferedStreamSocket::OnWrite()
   1132 {
   1133 }