sockets.cpp (30754B)
1 // SPDX-FileCopyrightText: 2015-2024 Connor McLaughlin <stenzek@gmail.com> 2 // SPDX-License-Identifier: (GPL-3.0 OR CC-BY-NC-ND-4.0) 3 4 #include "sockets.h" 5 #include "platform_misc.h" 6 7 #include "common/assert.h" 8 #include "common/log.h" 9 10 #include <algorithm> 11 #include <cstring> 12 #include <limits> 13 14 #ifndef __APPLE__ 15 #include <malloc.h> // alloca 16 #else 17 #include <alloca.h> 18 #endif 19 20 #ifdef _WIN32 21 22 #include "common/windows_headers.h" 23 24 #include <WS2tcpip.h> 25 #include <WinSock2.h> 26 27 #define SIZE_CAST(x) static_cast<int>(x) 28 using ssize_t = int; 29 using nfds_t = ULONG; 30 31 #else 32 33 #include <arpa/inet.h> 34 #include <errno.h> 35 #include <netinet/in.h> 36 #include <netinet/tcp.h> 37 #include <poll.h> 38 #include <sys/ioctl.h> 39 #include <sys/socket.h> 40 #include <sys/types.h> 41 #include <sys/uio.h> 42 #include <sys/un.h> 43 #include <unistd.h> 44 45 #ifdef __linux__ 46 #include <sys/epoll.h> 47 #endif 48 49 #define ioctlsocket ioctl 50 #define closesocket close 51 #define WSAEWOULDBLOCK EAGAIN 52 #define WSAGetLastError() errno 53 #define WSAPoll poll 54 #define SIZE_CAST(x) x 55 56 #define SOCKET_ERROR -1 57 #define INVALID_SOCKET -1 58 #define SD_BOTH SHUT_RDWR 59 #endif 60 61 Log_SetChannel(Sockets); 62 63 static bool SetNonBlocking(SocketDescriptor sd, Error* error) 64 { 65 // switch to nonblocking mode 66 unsigned long value = 1; 67 if (ioctlsocket(sd, FIONBIO, &value) < 0) 68 { 69 Error::SetSocket(error, "ioctlsocket() failed: ", WSAGetLastError()); 70 return false; 71 } 72 73 return true; 74 } 75 76 void SocketAddress::SetFromSockaddr(const void* sa, size_t length) 77 { 78 m_length = std::min(static_cast<u32>(length), static_cast<u32>(sizeof(m_data))); 79 std::memcpy(m_data, sa, m_length); 80 if (m_length < sizeof(m_data)) 81 std::memset(m_data + m_length, 0, sizeof(m_data) - m_length); 82 } 83 84 bool SocketAddress::IsIPAddress() const 85 { 86 const sockaddr* addr = reinterpret_cast<const sockaddr*>(m_data); 87 return (addr->sa_family == AF_INET || addr->sa_family == AF_INET6); 88 } 89 90 std::optional<SocketAddress> SocketAddress::Parse(Type type, const char* address, u32 port, Error* error) 91 { 92 std::optional<SocketAddress> ret = SocketAddress(); 93 94 switch (type) 95 { 96 case Type::IPv4: 97 { 98 sockaddr_in* sain = reinterpret_cast<sockaddr_in*>(ret->m_data); 99 std::memset(sain, 0, sizeof(sockaddr_in)); 100 sain->sin_family = AF_INET; 101 sain->sin_port = htons(static_cast<u16>(port)); 102 int res = inet_pton(AF_INET, address, &sain->sin_addr); 103 if (res == 1) 104 { 105 ret->m_length = sizeof(sockaddr_in); 106 } 107 else 108 { 109 Error::SetSocket(error, "inet_pton() failed: ", WSAGetLastError()); 110 ret.reset(); 111 } 112 } 113 break; 114 115 case Type::IPv6: 116 { 117 sockaddr_in6* sain6 = reinterpret_cast<sockaddr_in6*>(ret->m_data); 118 std::memset(sain6, 0, sizeof(sockaddr_in6)); 119 sain6->sin6_family = AF_INET; 120 sain6->sin6_port = htons(static_cast<u16>(port)); 121 int res = inet_pton(AF_INET6, address, &sain6->sin6_addr); 122 if (res == 1) 123 { 124 ret->m_length = sizeof(sockaddr_in6); 125 } 126 else 127 { 128 Error::SetSocket(error, "inet_pton() failed: ", WSAGetLastError()); 129 ret.reset(); 130 } 131 } 132 break; 133 134 #ifndef _WIN32 135 case Type::Unix: 136 { 137 sockaddr_un* sun = reinterpret_cast<sockaddr_un*>(ret->m_data); 138 std::memset(sun, 0, sizeof(sockaddr_un)); 139 sun->sun_family = AF_UNIX; 140 141 const size_t len = std::strlen(address); 142 if ((len + 1) <= std::size(sun->sun_path)) 143 { 144 std::memcpy(sun->sun_path, address, len); 145 ret->m_length = sizeof(sockaddr_un); 146 } 147 else 148 { 149 Error::SetStringFmt(error, "Path length {} exceeds {} bytes.", len, std::size(sun->sun_path)); 150 ret.reset(); 151 } 152 } 153 break; 154 #endif 155 156 default: 157 Error::SetStringView(error, "Unknown address type."); 158 ret.reset(); 159 break; 160 } 161 162 return ret; 163 } 164 165 SmallString SocketAddress::ToString() const 166 { 167 SmallString ret; 168 169 const sockaddr* sa = reinterpret_cast<const sockaddr*>(m_data); 170 switch (sa->sa_family) 171 { 172 case AF_INET: 173 { 174 ret.clear(); 175 ret.reserve(128); 176 const char* res = 177 inet_ntop(AF_INET, &reinterpret_cast<const sockaddr_in*>(m_data)->sin_addr, ret.data(), ret.buffer_size()); 178 if (res == nullptr) 179 ret.assign("<unknown>"); 180 else 181 ret.update_size(); 182 183 ret.append_format(":{}", static_cast<u32>(ntohs(reinterpret_cast<const sockaddr_in*>(m_data)->sin_port))); 184 } 185 break; 186 187 case AF_INET6: 188 { 189 ret.clear(); 190 ret.reserve(128); 191 ret.append('['); 192 const char* res = inet_ntop(AF_INET6, &reinterpret_cast<const sockaddr_in6*>(m_data)->sin6_addr, ret.data() + 1, 193 ret.buffer_size() - 1); 194 if (res == nullptr) 195 ret.assign("<unknown>"); 196 else 197 ret.update_size(); 198 199 ret.append_format("]:{}", static_cast<u32>(ntohs(reinterpret_cast<const sockaddr_in6*>(m_data)->sin6_port))); 200 } 201 break; 202 203 #ifndef _WIN32 204 case AF_UNIX: 205 { 206 ret.assign(reinterpret_cast<const sockaddr_un*>(m_data)->sun_path); 207 } 208 break; 209 #endif 210 211 default: 212 { 213 ret.assign("<unknown>"); 214 break; 215 } 216 } 217 218 return ret; 219 } 220 221 BaseSocket::BaseSocket(SocketMultiplexer& multiplexer, SocketDescriptor descriptor) 222 : m_multiplexer(multiplexer), m_descriptor(descriptor) 223 { 224 } 225 226 BaseSocket::~BaseSocket() = default; 227 228 SocketMultiplexer::SocketMultiplexer() = default; 229 230 SocketMultiplexer::~SocketMultiplexer() 231 { 232 CloseAll(); 233 234 #ifdef __linux__ 235 if (m_epoll_fd >= 0) 236 close(m_epoll_fd); 237 #else 238 if (m_poll_array) 239 std::free(m_poll_array); 240 #endif 241 } 242 243 std::unique_ptr<SocketMultiplexer> SocketMultiplexer::Create(Error* error) 244 { 245 std::unique_ptr<SocketMultiplexer> ret; 246 if (PlatformMisc::InitializeSocketSupport(error)) 247 { 248 ret = std::unique_ptr<SocketMultiplexer>(new SocketMultiplexer()); 249 if (!ret->Initialize(error)) 250 ret.reset(); 251 } 252 253 return ret; 254 } 255 256 bool SocketMultiplexer::Initialize(Error* error) 257 { 258 #ifdef __linux__ 259 m_epoll_fd = epoll_create1(0); 260 if (m_epoll_fd < 0) 261 { 262 Error::SetErrno(error, "epoll_create1() failed: ", errno); 263 return false; 264 } 265 266 return true; 267 #else 268 return true; 269 #endif 270 } 271 272 std::shared_ptr<ListenSocket> SocketMultiplexer::InternalCreateListenSocket(const SocketAddress& address, 273 CreateStreamSocketCallback callback, 274 Error* error) 275 { 276 // create and bind socket 277 const sockaddr* sa = reinterpret_cast<const sockaddr*>(address.GetData()); 278 SocketDescriptor descriptor = socket(sa->sa_family, SOCK_STREAM, StreamSocket::GetSocketProtocolForAddress(address)); 279 if (descriptor == INVALID_SOCKET) 280 { 281 Error::SetSocket(error, "socket() failed: ", WSAGetLastError()); 282 return {}; 283 } 284 285 const int reuseaddr_enable = 1; 286 if (setsockopt(descriptor, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<const char*>(&reuseaddr_enable), 287 sizeof(reuseaddr_enable)) < 0) 288 { 289 WARNING_LOG("Failed to set SO_REUSEADDR: {}", Error::CreateSocket(WSAGetLastError()).GetDescription()); 290 } 291 292 if (bind(descriptor, sa, address.GetLength()) < 0) 293 { 294 Error::SetSocket(error, "bind() failed: ", WSAGetLastError()); 295 closesocket(descriptor); 296 return {}; 297 } 298 299 if (listen(descriptor, 5) < 0) 300 { 301 Error::SetSocket(error, "listen() failed: ", WSAGetLastError()); 302 closesocket(descriptor); 303 return {}; 304 } 305 306 if (!SetNonBlocking(descriptor, error)) 307 { 308 closesocket(descriptor); 309 return {}; 310 } 311 312 // create listensocket 313 std::shared_ptr<ListenSocket> ret = std::make_shared<ListenSocket>(*this, descriptor, callback); 314 315 // add to list, register for reads 316 AddOpenSocket(std::static_pointer_cast<BaseSocket>(ret)); 317 SetNotificationMask(ret.get(), descriptor, POLLIN); 318 return ret; 319 } 320 321 std::shared_ptr<StreamSocket> SocketMultiplexer::InternalConnectStreamSocket(const SocketAddress& address, 322 CreateStreamSocketCallback callback, 323 Error* error) 324 { 325 // create and bind socket 326 const sockaddr* sa = reinterpret_cast<const sockaddr*>(address.GetData()); 327 SocketDescriptor descriptor = socket(sa->sa_family, SOCK_STREAM, StreamSocket::GetSocketProtocolForAddress(address)); 328 if (descriptor == INVALID_SOCKET) 329 { 330 Error::SetSocket(error, "socket() failed: ", WSAGetLastError()); 331 return {}; 332 } 333 334 if (connect(descriptor, sa, address.GetLength()) < 0) 335 { 336 Error::SetSocket(error, "connect() failed: ", WSAGetLastError()); 337 closesocket(descriptor); 338 return {}; 339 } 340 341 if (!SetNonBlocking(descriptor, error)) 342 { 343 closesocket(descriptor); 344 return {}; 345 } 346 347 // create stream socket 348 std::shared_ptr<StreamSocket> csocket = callback(*this, descriptor); 349 csocket->InitialSetup(); 350 if (!csocket->IsConnected()) 351 csocket.reset(); 352 353 return csocket; 354 } 355 356 void SocketMultiplexer::AddOpenSocket(std::shared_ptr<BaseSocket> socket) 357 { 358 #ifdef __linux__ 359 struct epoll_event ev = {.events = 0u, .data = {.fd = socket->GetDescriptor()}}; 360 if (epoll_ctl(m_epoll_fd, EPOLL_CTL_ADD, socket->GetDescriptor(), &ev) != 0) [[unlikely]] 361 ERROR_LOG("epoll_ctl() to add socket failed: {}", Error::CreateErrno(errno).GetDescription()); 362 #endif 363 364 std::unique_lock lock(m_open_sockets_lock); 365 DebugAssert(m_open_sockets.find(socket->GetDescriptor()) == m_open_sockets.end()); 366 m_open_sockets.emplace(socket->GetDescriptor(), std::move(socket)); 367 } 368 369 void SocketMultiplexer::AddClientSocket(std::shared_ptr<BaseSocket> socket) 370 { 371 AddOpenSocket(std::move(socket)); 372 m_client_socket_count.fetch_add(1, std::memory_order_acq_rel); 373 } 374 375 void SocketMultiplexer::RemoveOpenSocket(BaseSocket* socket) 376 { 377 std::unique_lock lock(m_open_sockets_lock); 378 const auto iter = m_open_sockets.find(socket->GetDescriptor()); 379 Assert(iter != m_open_sockets.end()); 380 m_open_sockets.erase(iter); 381 382 #ifdef __linux__ 383 if (epoll_ctl(m_epoll_fd, EPOLL_CTL_DEL, socket->GetDescriptor(), nullptr) != 0) [[unlikely]] 384 ERROR_LOG("epoll_ctl() to remove socket failed: {}", Error::CreateErrno(errno).GetDescription()); 385 #else 386 #ifdef _DEBUG 387 for (size_t i = 0; i < m_poll_array_active_size; i++) 388 { 389 pollfd& pfd = m_poll_array[i]; 390 DebugAssert(pfd.fd != socket->GetDescriptor()); 391 } 392 #endif 393 394 // Update size. 395 size_t new_active_size = 0; 396 for (size_t i = 0; i < m_poll_array_active_size; i++) 397 new_active_size = (m_poll_array[i].fd != INVALID_SOCKET) ? (i + 1) : new_active_size; 398 m_poll_array_active_size = new_active_size; 399 #endif 400 } 401 402 void SocketMultiplexer::RemoveClientSocket(BaseSocket* socket) 403 { 404 DebugAssert(m_client_socket_count.load(std::memory_order_acquire) > 0); 405 m_client_socket_count.fetch_sub(1, std::memory_order_acq_rel); 406 RemoveOpenSocket(socket); 407 } 408 409 bool SocketMultiplexer::HasAnyOpenSockets() 410 { 411 std::unique_lock lock(m_open_sockets_lock); 412 return !m_open_sockets.empty(); 413 } 414 415 bool SocketMultiplexer::HasAnyClientSockets() 416 { 417 return (GetClientSocketCount() > 0); 418 } 419 420 size_t SocketMultiplexer::GetClientSocketCount() 421 { 422 return m_client_socket_count.load(std::memory_order_acquire); 423 } 424 425 void SocketMultiplexer::CloseAll() 426 { 427 std::unique_lock lock(m_open_sockets_lock); 428 429 while (!m_open_sockets.empty()) 430 { 431 std::shared_ptr<BaseSocket> socket = m_open_sockets.begin()->second; 432 lock.unlock(); 433 socket->Close(); 434 lock.lock(); 435 } 436 } 437 438 void SocketMultiplexer::SetNotificationMask(BaseSocket* socket, SocketDescriptor descriptor, u32 events) 439 { 440 #ifdef __linux__ 441 struct epoll_event ev = {.events = events, .data = {.fd = descriptor}}; 442 if (epoll_ctl(m_epoll_fd, EPOLL_CTL_MOD, descriptor, &ev) != 0) [[unlikely]] 443 ERROR_LOG("epoll_ctl() for events 0x{:x} failed: {}", events, Error::CreateErrno(errno).GetDescription()); 444 #else 445 std::unique_lock lock(m_poll_array_lock); 446 size_t free_slot = m_poll_array_active_size; 447 for (size_t i = 0; i < m_poll_array_active_size; i++) 448 { 449 pollfd& pfd = m_poll_array[i]; 450 if (pfd.fd != descriptor) 451 { 452 free_slot = (pfd.fd < 0 && free_slot != m_poll_array_active_size) ? i : free_slot; 453 continue; 454 } 455 456 // unbinding? 457 if (events != 0) 458 pfd.events = static_cast<short>(events); 459 else 460 pfd.fd = INVALID_SOCKET; 461 462 return; 463 } 464 465 // don't create entries for null masks 466 if (events == 0) 467 return; 468 469 // need to grow the array? 470 if (free_slot == m_poll_array_max_size) 471 { 472 const size_t new_size = std::max(free_slot + 1, free_slot * 2); 473 pollfd* new_array = static_cast<pollfd*>(std::realloc(m_poll_array, sizeof(pollfd) * new_size)); 474 if (!new_array) 475 Panic("Memory allocation failed."); 476 477 for (size_t i = m_poll_array_max_size; i < new_size; i++) 478 new_array[i] = {.fd = INVALID_SOCKET, .events = 0, .revents = 0}; 479 m_poll_array = new_array; 480 m_poll_array_max_size = new_size; 481 } 482 483 m_poll_array[free_slot] = {.fd = descriptor, .events = static_cast<short>(events), .revents = 0}; 484 m_poll_array_active_size = free_slot + 1; 485 #endif 486 } 487 488 bool SocketMultiplexer::PollEventsWithTimeout(u32 milliseconds) 489 { 490 #ifdef __linux__ 491 constexpr int MAX_EVENTS = 128; 492 struct epoll_event events[MAX_EVENTS]; 493 494 const int nevents = epoll_wait(m_epoll_fd, events, MAX_EVENTS, static_cast<int>(milliseconds)); 495 if (nevents <= 0) 496 return false; 497 498 // find sockets that triggered, we use an array here so we can avoid holding the lock, and if a socket disconnects 499 using PendingSocketPair = std::pair<std::shared_ptr<BaseSocket>, u32>; 500 PendingSocketPair* triggered_sockets = 501 reinterpret_cast<PendingSocketPair*>(alloca(sizeof(PendingSocketPair) * static_cast<size_t>(nevents))); 502 size_t num_triggered_sockets = 0; 503 { 504 std::unique_lock open_lock(m_open_sockets_lock); 505 for (int i = 0; i < nevents; i++) 506 { 507 const epoll_event& ev = events[i]; 508 const auto iter = m_open_sockets.find(ev.data.fd); 509 if (iter == m_open_sockets.end()) [[unlikely]] 510 { 511 ERROR_LOG("Attempting to look up unknown socket {}, this should never happen.", ev.data.fd); 512 continue; 513 } 514 515 // we add a reference here in case the read kills it with a write pending, or something like that 516 new (&triggered_sockets[num_triggered_sockets++]) PendingSocketPair(iter->second->shared_from_this(), ev.events); 517 } 518 } 519 520 // fire events 521 for (size_t i = 0; i < num_triggered_sockets; i++) 522 { 523 PendingSocketPair& psp = triggered_sockets[i]; 524 525 // fire events 526 if (psp.second & (EPOLLRDHUP | EPOLLHUP | EPOLLERR)) 527 { 528 psp.first->OnHangupEvent(); 529 } 530 else 531 { 532 if (psp.second & EPOLLIN) 533 psp.first->OnReadEvent(); 534 if (psp.second & EPOLLOUT) 535 psp.first->OnWriteEvent(); 536 } 537 538 psp.first.~shared_ptr(); 539 } 540 541 return true; 542 #else 543 std::unique_lock lock(m_poll_array_lock); 544 if (m_poll_array_active_size == 0) 545 return false; 546 547 const int res = WSAPoll(m_poll_array, static_cast<nfds_t>(m_poll_array_active_size), milliseconds); 548 if (res <= 0) 549 return false; 550 551 // find sockets that triggered, we use an array here so we can avoid holding the lock, and if a socket disconnects 552 using PendingSocketPair = std::pair<std::shared_ptr<BaseSocket>, u32>; 553 PendingSocketPair* triggered_sockets = 554 reinterpret_cast<PendingSocketPair*>(alloca(sizeof(PendingSocketPair) * static_cast<size_t>(res))); 555 size_t num_triggered_sockets = 0; 556 { 557 std::unique_lock open_lock(m_open_sockets_lock); 558 for (size_t i = 0; i < m_poll_array_active_size; i++) 559 { 560 const pollfd& pfd = m_poll_array[i]; 561 if (pfd.revents == 0) 562 continue; 563 564 const auto iter = m_open_sockets.find(pfd.fd); 565 if (iter == m_open_sockets.end()) [[unlikely]] 566 { 567 ERROR_LOG("Attempting to look up unknown socket {}, this should never happen.", pfd.fd); 568 continue; 569 } 570 571 // we add a reference here in case the read kills it with a write pending, or something like that 572 new (&triggered_sockets[num_triggered_sockets++]) 573 PendingSocketPair(iter->second->shared_from_this(), pfd.revents); 574 } 575 } 576 577 // release lock so connections etc can acquire it 578 lock.unlock(); 579 580 // fire events 581 for (size_t i = 0; i < num_triggered_sockets; i++) 582 { 583 PendingSocketPair& psp = triggered_sockets[i]; 584 585 // fire events 586 if (psp.second & (POLLHUP | POLLERR)) 587 { 588 psp.first->OnHangupEvent(); 589 } 590 else 591 { 592 if (psp.second & POLLIN) 593 psp.first->OnReadEvent(); 594 if (psp.second & POLLOUT) 595 psp.first->OnWriteEvent(); 596 } 597 598 psp.first.~shared_ptr(); 599 } 600 601 return true; 602 #endif 603 } 604 605 ListenSocket::ListenSocket(SocketMultiplexer& multiplexer, SocketDescriptor descriptor, 606 SocketMultiplexer::CreateStreamSocketCallback accept_callback) 607 : BaseSocket(multiplexer, descriptor), m_accept_callback(accept_callback) 608 { 609 // get local address 610 sockaddr_storage sa; 611 socklen_t salen = sizeof(sa); 612 if (getsockname(m_descriptor, reinterpret_cast<sockaddr*>(&sa), &salen) == 0) 613 m_local_address.SetFromSockaddr(&sa, salen); 614 } 615 616 ListenSocket::~ListenSocket() 617 { 618 DebugAssert(m_descriptor == INVALID_SOCKET); 619 } 620 621 void ListenSocket::Close() 622 { 623 if (m_descriptor < 0) 624 return; 625 626 m_multiplexer.SetNotificationMask(this, m_descriptor, 0); 627 m_multiplexer.RemoveOpenSocket(this); 628 closesocket(m_descriptor); 629 m_descriptor = INVALID_SOCKET; 630 } 631 632 void ListenSocket::OnReadEvent() 633 { 634 // connection incoming 635 sockaddr_storage sa; 636 socklen_t salen = sizeof(sa); 637 SocketDescriptor new_descriptor = accept(m_descriptor, reinterpret_cast<sockaddr*>(&sa), &salen); 638 if (new_descriptor == INVALID_SOCKET) 639 { 640 ERROR_LOG("accept() returned {}", Error::CreateSocket(WSAGetLastError()).GetDescription()); 641 return; 642 } 643 644 Error error; 645 if (!SetNonBlocking(new_descriptor, &error)) 646 { 647 ERROR_LOG("Failed to set just-connected socket to nonblocking: {}", error.GetDescription()); 648 closesocket(new_descriptor); 649 return; 650 } 651 652 // create socket, we release our own reference. 653 std::shared_ptr<StreamSocket> client = m_accept_callback(m_multiplexer, new_descriptor); 654 if (!client) 655 { 656 closesocket(new_descriptor); 657 return; 658 } 659 660 m_num_connections_accepted++; 661 client->InitialSetup(); 662 } 663 664 void ListenSocket::OnWriteEvent() 665 { 666 ERROR_LOG("Unexpected OnWriteEvent() in ListenSocket {}", m_local_address.ToString()); 667 } 668 669 void ListenSocket::OnHangupEvent() 670 { 671 ERROR_LOG("Unexpected OnHangupEvent() in ListenSocket {}", m_local_address.ToString()); 672 } 673 674 StreamSocket::StreamSocket(SocketMultiplexer& multiplexer, SocketDescriptor descriptor) 675 : BaseSocket(multiplexer, descriptor) 676 { 677 // get local address 678 sockaddr_storage sa; 679 socklen_t salen = sizeof(sa); 680 if (getsockname(m_descriptor, reinterpret_cast<sockaddr*>(&sa), &salen) == 0) 681 m_local_address.SetFromSockaddr(&sa, salen); 682 683 // get remote address 684 salen = sizeof(sockaddr_storage); 685 if (getpeername(m_descriptor, reinterpret_cast<sockaddr*>(&sa), &salen) == 0) 686 m_remote_address.SetFromSockaddr(&sa, salen); 687 } 688 689 StreamSocket::~StreamSocket() 690 { 691 DebugAssert(m_descriptor == INVALID_SOCKET); 692 } 693 694 u32 StreamSocket::GetSocketProtocolForAddress(const SocketAddress& sa) 695 { 696 const sockaddr* ssa = reinterpret_cast<const sockaddr*>(sa.GetData()); 697 return (ssa->sa_family == AF_INET || ssa->sa_family == AF_INET6) ? IPPROTO_TCP : 0; 698 } 699 700 void StreamSocket::InitialSetup() 701 { 702 // register for notifications 703 m_multiplexer.AddClientSocket(shared_from_this()); 704 m_multiplexer.SetNotificationMask(this, m_descriptor, POLLIN); 705 706 // trigger connected notification 707 std::unique_lock lock(m_lock); 708 OnConnected(); 709 } 710 711 size_t StreamSocket::Read(void* buffer, size_t buffer_size) 712 { 713 std::unique_lock lock(m_lock); 714 if (!m_connected) 715 return 0; 716 717 // try a read 718 const ssize_t len = recv(m_descriptor, static_cast<char*>(buffer), SIZE_CAST(buffer_size), 0); 719 if (len <= 0) 720 { 721 // Check for EAGAIN 722 if (len < 0 && WSAGetLastError() == WSAEWOULDBLOCK) 723 { 724 // Not an error. Just means no data is available. 725 return 0; 726 } 727 728 // error 729 CloseWithError(); 730 return 0; 731 } 732 733 return len; 734 } 735 736 size_t StreamSocket::Write(const void* buffer, size_t buffer_size) 737 { 738 std::unique_lock lock(m_lock); 739 if (!m_connected) 740 return 0; 741 742 // try a write 743 const ssize_t len = send(m_descriptor, static_cast<const char*>(buffer), SIZE_CAST(buffer_size), 0); 744 if (len <= 0) 745 { 746 // Check for EAGAIN 747 if (len < 0 && WSAGetLastError() == WSAEWOULDBLOCK) 748 { 749 // Not an error. Just means no data is available. 750 return 0; 751 } 752 753 // error 754 CloseWithError(); 755 return 0; 756 } 757 758 return len; 759 } 760 761 size_t StreamSocket::WriteVector(const void** buffers, const size_t* buffer_lengths, size_t num_buffers) 762 { 763 std::unique_lock lock(m_lock); 764 if (!m_connected || num_buffers == 0) 765 return 0; 766 767 #ifdef _WIN32 768 769 WSABUF* bufs = static_cast<WSABUF*>(alloca(sizeof(WSABUF) * num_buffers)); 770 for (size_t i = 0; i < num_buffers; i++) 771 { 772 bufs[i].buf = (CHAR*)buffers[i]; 773 bufs[i].len = (ULONG)buffer_lengths[i]; 774 } 775 776 DWORD bytesSent = 0; 777 if (WSASend(m_descriptor, bufs, (DWORD)num_buffers, &bytesSent, 0, nullptr, nullptr) == SOCKET_ERROR) 778 { 779 if (WSAGetLastError() != WSAEWOULDBLOCK) 780 { 781 // Socket error. 782 CloseWithError(); 783 return 0; 784 } 785 } 786 787 return static_cast<size_t>(bytesSent); 788 789 #else // _WIN32 790 791 iovec* bufs = static_cast<iovec*>(alloca(sizeof(iovec) * num_buffers)); 792 for (size_t i = 0; i < num_buffers; i++) 793 { 794 bufs[i].iov_base = (void*)buffers[i]; 795 bufs[i].iov_len = buffer_lengths[i]; 796 } 797 798 ssize_t res = writev(m_descriptor, bufs, num_buffers); 799 if (res < 0) 800 { 801 if (errno != EAGAIN) 802 { 803 // Socket error. 804 CloseWithError(); 805 return 0; 806 } 807 808 res = 0; 809 } 810 811 return static_cast<size_t>(res); 812 813 #endif 814 } 815 816 bool StreamSocket::SetNagleBuffering(bool enabled, Error* error /* = nullptr */) 817 { 818 if (!m_local_address.IsIPAddress()) 819 { 820 Error::SetStringView(error, "Attempting to disable nagle on a non-IP socket."); 821 return false; 822 } 823 824 int disable = enabled ? 0 : 1; 825 if (setsockopt(m_descriptor, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast<char*>(&disable), sizeof(disable)) != 0) 826 { 827 Error::SetSocket(error, "setsockopt(TCP_NODELAY) failed: ", WSAGetLastError()); 828 return false; 829 } 830 831 return true; 832 } 833 834 void StreamSocket::Close() 835 { 836 std::unique_lock lock(m_lock); 837 if (!m_connected) 838 return; 839 840 m_multiplexer.SetNotificationMask(this, m_descriptor, 0); 841 m_multiplexer.RemoveClientSocket(this); 842 shutdown(m_descriptor, SD_BOTH); 843 closesocket(m_descriptor); 844 m_descriptor = INVALID_SOCKET; 845 m_connected = false; 846 847 OnDisconnected(Error::CreateString("Connection explicitly closed.")); 848 } 849 850 void StreamSocket::CloseWithError() 851 { 852 std::unique_lock lock(m_lock); 853 DebugAssert(m_connected); 854 855 Error error; 856 const int error_code = WSAGetLastError(); 857 if (error_code == 0) 858 error.SetStringView("Connection closed by peer."); 859 else 860 error.SetSocket(error_code); 861 862 m_multiplexer.SetNotificationMask(this, m_descriptor, 0); 863 m_multiplexer.RemoveClientSocket(this); 864 closesocket(m_descriptor); 865 m_descriptor = INVALID_SOCKET; 866 m_connected = false; 867 868 OnDisconnected(error); 869 } 870 871 void StreamSocket::OnReadEvent() 872 { 873 // forward through 874 std::unique_lock lock(m_lock); 875 if (m_connected) 876 OnRead(); 877 } 878 879 void StreamSocket::OnWriteEvent() 880 { 881 // shouldn't be called 882 } 883 884 void StreamSocket::OnHangupEvent() 885 { 886 std::unique_lock lock(m_lock); 887 if (!m_connected) 888 return; 889 890 m_multiplexer.SetNotificationMask(this, m_descriptor, 0); 891 m_multiplexer.RemoveClientSocket(this); 892 closesocket(m_descriptor); 893 m_descriptor = INVALID_SOCKET; 894 m_connected = false; 895 896 OnDisconnected(Error::CreateString("Connection closed by peer.")); 897 } 898 899 BufferedStreamSocket::BufferedStreamSocket(SocketMultiplexer& multiplexer, SocketDescriptor descriptor, 900 size_t receive_buffer_size /* = 16384 */, 901 size_t send_buffer_size /* = 16384 */) 902 : StreamSocket(multiplexer, descriptor), m_receive_buffer(receive_buffer_size), m_send_buffer(send_buffer_size) 903 { 904 } 905 906 BufferedStreamSocket::~BufferedStreamSocket() 907 { 908 } 909 910 std::unique_lock<std::recursive_mutex> BufferedStreamSocket::GetLock() 911 { 912 return std::unique_lock(m_lock); 913 } 914 915 std::span<const u8> BufferedStreamSocket::AcquireReadBuffer() const 916 { 917 return std::span<const u8>(m_receive_buffer.data() + m_receive_buffer_offset, m_receive_buffer_size); 918 } 919 920 void BufferedStreamSocket::ReleaseReadBuffer(size_t bytes_consumed) 921 { 922 DebugAssert(bytes_consumed <= m_receive_buffer_size); 923 m_receive_buffer_offset += static_cast<u32>(bytes_consumed); 924 m_receive_buffer_size -= static_cast<u32>(bytes_consumed); 925 926 // Anything left? If not, reset offset. 927 m_receive_buffer_offset = (m_receive_buffer_size == 0) ? 0 : m_receive_buffer_offset; 928 } 929 930 std::span<u8> BufferedStreamSocket::AcquireWriteBuffer(size_t wanted_bytes, bool allow_smaller /* = false */) 931 { 932 if (!m_connected) 933 return {}; 934 935 // If to get the desired space, we need to move backwards, do so. 936 if ((m_send_buffer_offset + m_send_buffer_size + wanted_bytes) > m_send_buffer.size()) 937 { 938 if ((m_send_buffer_size + wanted_bytes) > m_send_buffer.size() && !allow_smaller) 939 { 940 // Not enough space. 941 return {}; 942 } 943 944 // Shuffle buffer backwards. 945 std::memmove(m_send_buffer.data(), m_send_buffer.data() + m_send_buffer_offset, m_send_buffer_size); 946 m_send_buffer_offset = 0; 947 } 948 949 DebugAssert((m_send_buffer_offset + m_send_buffer_size + wanted_bytes) <= m_send_buffer.size()); 950 return std::span<u8>(m_send_buffer.data() + m_send_buffer_offset + m_send_buffer_size, 951 m_send_buffer.size() - m_send_buffer_offset - m_send_buffer_size); 952 } 953 954 void BufferedStreamSocket::ReleaseWriteBuffer(size_t bytes_written, bool commit /* = true */) 955 { 956 if (!m_connected) 957 return; 958 959 DebugAssert((m_send_buffer_offset + m_send_buffer_size + bytes_written) <= m_send_buffer.size()); 960 m_send_buffer_size += static_cast<u32>(bytes_written); 961 962 // Send as much as we can. 963 if (commit && m_send_buffer_size > 0) 964 { 965 const ssize_t res = send(m_descriptor, reinterpret_cast<const char*>(m_send_buffer.data() + m_send_buffer_offset), 966 SIZE_CAST(m_send_buffer_size), 0); 967 if (res < 0 && WSAGetLastError() != WSAEWOULDBLOCK) 968 { 969 CloseWithError(); 970 return; 971 } 972 973 m_send_buffer_offset += static_cast<size_t>(res); 974 m_send_buffer_size -= static_cast<size_t>(res); 975 if (m_send_buffer_size == 0) 976 { 977 m_send_buffer_offset = 0; 978 } 979 else 980 { 981 // Register for writes to finish it off. 982 m_multiplexer.SetNotificationMask(this, m_descriptor, POLLIN | POLLOUT); 983 } 984 } 985 } 986 987 size_t BufferedStreamSocket::Read(void* buffer, size_t buffer_size) 988 { 989 // Read from receive buffer. 990 const std::span<const u8> rdbuf = AcquireReadBuffer(); 991 if (rdbuf.empty()) 992 return 0; 993 994 const size_t bytes_to_read = std::min(rdbuf.size(), buffer_size); 995 std::memcpy(buffer, rdbuf.data(), bytes_to_read); 996 ReleaseReadBuffer(bytes_to_read); 997 return bytes_to_read; 998 } 999 1000 size_t BufferedStreamSocket::Write(const void* buffer, size_t buffer_size) 1001 { 1002 if (!m_connected) 1003 return 0; 1004 1005 // Read from receive buffer. 1006 const std::span<u8> wrbuf = AcquireWriteBuffer(buffer_size, true); 1007 if (wrbuf.empty()) 1008 return 0; 1009 1010 const size_t bytes_to_write = std::min(wrbuf.size(), buffer_size); 1011 std::memcpy(wrbuf.data(), buffer, bytes_to_write); 1012 ReleaseWriteBuffer(bytes_to_write); 1013 return bytes_to_write; 1014 } 1015 1016 size_t BufferedStreamSocket::WriteVector(const void** buffers, const size_t* buffer_lengths, size_t num_buffers) 1017 { 1018 if (!m_connected || num_buffers == 0) 1019 return 0; 1020 1021 size_t total_size = 0; 1022 for (size_t i = 0; i < num_buffers; i++) 1023 total_size += buffer_lengths[i]; 1024 1025 const std::span<u8> wrbuf = AcquireWriteBuffer(total_size, true); 1026 if (wrbuf.empty()) 1027 return 0; 1028 1029 size_t written_bytes = 0; 1030 for (size_t i = 0; i < num_buffers; i++) 1031 { 1032 const size_t bytes_to_write = std::min(wrbuf.size() - written_bytes, buffer_lengths[i]); 1033 if (bytes_to_write == 0) 1034 break; 1035 1036 std::memcpy(&wrbuf[written_bytes], buffers[i], bytes_to_write); 1037 written_bytes += buffer_lengths[i]; 1038 } 1039 1040 return written_bytes; 1041 } 1042 1043 void BufferedStreamSocket::Close() 1044 { 1045 StreamSocket::Close(); 1046 1047 m_receive_buffer_offset = 0; 1048 m_receive_buffer_size = 0; 1049 m_send_buffer_offset = 0; 1050 m_send_buffer_size = 0; 1051 } 1052 1053 void BufferedStreamSocket::OnReadEvent() 1054 { 1055 std::unique_lock lock(m_lock); 1056 if (!m_connected) 1057 return; 1058 1059 // Pull as many bytes as possible into the read buffer. 1060 for (;;) 1061 { 1062 const size_t buffer_space = m_receive_buffer.size() - m_receive_buffer_offset - m_receive_buffer_size; 1063 if (buffer_space == 0) [[unlikely]] 1064 { 1065 // If we're here again, it means OnRead() didn't consume the data, and we overflowed. 1066 ERROR_LOG("Receive buffer overflow, dropping client {}.", GetRemoteAddress().ToString()); 1067 CloseWithError(); 1068 return; 1069 } 1070 1071 const ssize_t res = recv( 1072 m_descriptor, reinterpret_cast<char*>(m_receive_buffer.data() + m_receive_buffer_offset + m_receive_buffer_size), 1073 SIZE_CAST(buffer_space), 0); 1074 if (res <= 0 && WSAGetLastError() != WSAEWOULDBLOCK) 1075 { 1076 CloseWithError(); 1077 return; 1078 } 1079 1080 m_receive_buffer_size += static_cast<size_t>(res); 1081 OnRead(); 1082 1083 // Are we at the end? 1084 if ((m_receive_buffer_offset + m_receive_buffer_size) == m_receive_buffer.size()) 1085 { 1086 // Try to claw back some of the buffer, and try reading again. 1087 if (m_receive_buffer_offset > 0) 1088 { 1089 std::memmove(m_receive_buffer.data(), m_receive_buffer.data() + m_receive_buffer_offset, m_receive_buffer_size); 1090 m_receive_buffer_offset = 0; 1091 continue; 1092 } 1093 } 1094 1095 break; 1096 } 1097 } 1098 1099 void BufferedStreamSocket::OnWriteEvent() 1100 { 1101 std::unique_lock lock(m_lock); 1102 if (!m_connected) 1103 return; 1104 1105 // Send as much as we can. 1106 if (m_send_buffer_size > 0) 1107 { 1108 const ssize_t res = send(m_descriptor, reinterpret_cast<const char*>(m_send_buffer.data() + m_send_buffer_offset), 1109 SIZE_CAST(m_send_buffer_size), 0); 1110 if (res < 0 && WSAGetLastError() != WSAEWOULDBLOCK) 1111 { 1112 CloseWithError(); 1113 return; 1114 } 1115 1116 m_send_buffer_offset += static_cast<size_t>(res); 1117 m_send_buffer_size -= static_cast<size_t>(res); 1118 if (m_send_buffer_size == 0) 1119 m_send_buffer_offset = 0; 1120 } 1121 1122 OnWrite(); 1123 1124 if (m_send_buffer_size == 0) 1125 { 1126 // Are we done? Switch back to reads only. 1127 m_multiplexer.SetNotificationMask(this, m_descriptor, POLLIN); 1128 } 1129 } 1130 1131 void BufferedStreamSocket::OnWrite() 1132 { 1133 }