From 095ae15d6b9ff60357b44ed6f4997754079eff2e Mon Sep 17 00:00:00 2001 From: jbauch Date: Fri, 18 Dec 2015 01:39:55 -0800 Subject: [PATCH] Keep listening if "accept" returns an invalid socket. There is an issue in PhysicalSocket::Accept where the flag to continue listening is not set in "enabled_events_" if "accept" returns an error. This CL fixes this (initial idea by silviu.cpp@gmail.com). BUG=webrtc:2030 Review URL: https://codereview.webrtc.org/1452903006 Cr-Commit-Position: refs/heads/master@{#11080} --- webrtc/base/physicalsocketserver.cc | 1175 +++++++++--------- webrtc/base/physicalsocketserver.h | 102 ++ webrtc/base/physicalsocketserver_unittest.cc | 161 ++- webrtc/base/socket_unittest.h | 16 +- 4 files changed, 855 insertions(+), 599 deletions(-) diff --git a/webrtc/base/physicalsocketserver.cc b/webrtc/base/physicalsocketserver.cc index 4a4c0a36ba..524617221c 100644 --- a/webrtc/base/physicalsocketserver.cc +++ b/webrtc/base/physicalsocketserver.cc @@ -44,7 +44,6 @@ #include "webrtc/base/byteorder.h" #include "webrtc/base/common.h" #include "webrtc/base/logging.h" -#include "webrtc/base/nethelpers.h" #include "webrtc/base/physicalsocketserver.h" #include "webrtc/base/timeutils.h" #include "webrtc/base/winping.h" @@ -97,463 +96,455 @@ static const int ICMP_HEADER_SIZE = 8u; static const int ICMP_PING_TIMEOUT_MILLIS = 10000u; #endif -class PhysicalSocket : public AsyncSocket, public sigslot::has_slots<> { - public: - PhysicalSocket(PhysicalSocketServer* ss, SOCKET s = INVALID_SOCKET) - : ss_(ss), s_(s), enabled_events_(0), error_(0), - state_((s == INVALID_SOCKET) ? CS_CLOSED : CS_CONNECTED), - resolver_(NULL) { +PhysicalSocket::PhysicalSocket(PhysicalSocketServer* ss, SOCKET s) + : ss_(ss), s_(s), enabled_events_(0), error_(0), + state_((s == INVALID_SOCKET) ? CS_CLOSED : CS_CONNECTED), + resolver_(nullptr) { #if defined(WEBRTC_WIN) - // EnsureWinsockInit() ensures that winsock is initialized. The default - // version of this function doesn't do anything because winsock is - // initialized by constructor of a static object. If neccessary libjingle - // users can link it with a different version of this function by replacing - // win32socketinit.cc. See win32socketinit.cc for more details. - EnsureWinsockInit(); + // EnsureWinsockInit() ensures that winsock is initialized. The default + // version of this function doesn't do anything because winsock is + // initialized by constructor of a static object. If neccessary libjingle + // users can link it with a different version of this function by replacing + // win32socketinit.cc. See win32socketinit.cc for more details. + EnsureWinsockInit(); #endif - if (s_ != INVALID_SOCKET) { - enabled_events_ = DE_READ | DE_WRITE; + if (s_ != INVALID_SOCKET) { + enabled_events_ = DE_READ | DE_WRITE; - int type = SOCK_STREAM; - socklen_t len = sizeof(type); - VERIFY(0 == getsockopt(s_, SOL_SOCKET, SO_TYPE, (SockOptArg)&type, &len)); - udp_ = (SOCK_DGRAM == type); - } - } - - ~PhysicalSocket() override { - Close(); - } - - // Creates the underlying OS socket (same as the "socket" function). - virtual bool Create(int family, int type) { - Close(); - s_ = ::socket(family, type, 0); + int type = SOCK_STREAM; + socklen_t len = sizeof(type); + VERIFY(0 == getsockopt(s_, SOL_SOCKET, SO_TYPE, (SockOptArg)&type, &len)); udp_ = (SOCK_DGRAM == type); - UpdateLastError(); - if (udp_) - enabled_events_ = DE_READ | DE_WRITE; - return s_ != INVALID_SOCKET; } +} - SocketAddress GetLocalAddress() const override { - sockaddr_storage addr_storage = {0}; - socklen_t addrlen = sizeof(addr_storage); - sockaddr* addr = reinterpret_cast(&addr_storage); - int result = ::getsockname(s_, addr, &addrlen); - SocketAddress address; - if (result >= 0) { - SocketAddressFromSockAddrStorage(addr_storage, &address); - } else { - LOG(LS_WARNING) << "GetLocalAddress: unable to get local addr, socket=" - << s_; - } - return address; +PhysicalSocket::~PhysicalSocket() { + Close(); +} + +bool PhysicalSocket::Create(int family, int type) { + Close(); + s_ = ::socket(family, type, 0); + udp_ = (SOCK_DGRAM == type); + UpdateLastError(); + if (udp_) + enabled_events_ = DE_READ | DE_WRITE; + return s_ != INVALID_SOCKET; +} + +SocketAddress PhysicalSocket::GetLocalAddress() const { + sockaddr_storage addr_storage = {0}; + socklen_t addrlen = sizeof(addr_storage); + sockaddr* addr = reinterpret_cast(&addr_storage); + int result = ::getsockname(s_, addr, &addrlen); + SocketAddress address; + if (result >= 0) { + SocketAddressFromSockAddrStorage(addr_storage, &address); + } else { + LOG(LS_WARNING) << "GetLocalAddress: unable to get local addr, socket=" + << s_; } + return address; +} - SocketAddress GetRemoteAddress() const override { - sockaddr_storage addr_storage = {0}; - socklen_t addrlen = sizeof(addr_storage); - sockaddr* addr = reinterpret_cast(&addr_storage); - int result = ::getpeername(s_, addr, &addrlen); - SocketAddress address; - if (result >= 0) { - SocketAddressFromSockAddrStorage(addr_storage, &address); - } else { - LOG(LS_WARNING) << "GetRemoteAddress: unable to get remote addr, socket=" - << s_; - } - return address; +SocketAddress PhysicalSocket::GetRemoteAddress() const { + sockaddr_storage addr_storage = {0}; + socklen_t addrlen = sizeof(addr_storage); + sockaddr* addr = reinterpret_cast(&addr_storage); + int result = ::getpeername(s_, addr, &addrlen); + SocketAddress address; + if (result >= 0) { + SocketAddressFromSockAddrStorage(addr_storage, &address); + } else { + LOG(LS_WARNING) << "GetRemoteAddress: unable to get remote addr, socket=" + << s_; } + return address; +} - int Bind(const SocketAddress& bind_addr) override { - sockaddr_storage addr_storage; - size_t len = bind_addr.ToSockAddrStorage(&addr_storage); - sockaddr* addr = reinterpret_cast(&addr_storage); - int err = ::bind(s_, addr, static_cast(len)); - UpdateLastError(); +int PhysicalSocket::Bind(const SocketAddress& bind_addr) { + sockaddr_storage addr_storage; + size_t len = bind_addr.ToSockAddrStorage(&addr_storage); + sockaddr* addr = reinterpret_cast(&addr_storage); + int err = ::bind(s_, addr, static_cast(len)); + UpdateLastError(); #if !defined(NDEBUG) - if (0 == err) { - dbg_addr_ = "Bound @ "; - dbg_addr_.append(GetLocalAddress().ToString()); - } + if (0 == err) { + dbg_addr_ = "Bound @ "; + dbg_addr_.append(GetLocalAddress().ToString()); + } #endif - return err; + return err; +} + +int PhysicalSocket::Connect(const SocketAddress& addr) { + // TODO(pthatcher): Implicit creation is required to reconnect... + // ...but should we make it more explicit? + if (state_ != CS_CLOSED) { + SetError(EALREADY); + return SOCKET_ERROR; } - - int Connect(const SocketAddress& addr) override { - // TODO: Implicit creation is required to reconnect... - // ...but should we make it more explicit? - if (state_ != CS_CLOSED) { - SetError(EALREADY); - return SOCKET_ERROR; - } - if (addr.IsUnresolvedIP()) { - LOG(LS_VERBOSE) << "Resolving addr in PhysicalSocket::Connect"; - resolver_ = new AsyncResolver(); - resolver_->SignalDone.connect(this, &PhysicalSocket::OnResolveResult); - resolver_->Start(addr); - state_ = CS_CONNECTING; - return 0; - } - - return DoConnect(addr); - } - - int DoConnect(const SocketAddress& connect_addr) { - if ((s_ == INVALID_SOCKET) && - !Create(connect_addr.family(), SOCK_STREAM)) { - return SOCKET_ERROR; - } - sockaddr_storage addr_storage; - size_t len = connect_addr.ToSockAddrStorage(&addr_storage); - sockaddr* addr = reinterpret_cast(&addr_storage); - int err = ::connect(s_, addr, static_cast(len)); - UpdateLastError(); - if (err == 0) { - state_ = CS_CONNECTED; - } else if (IsBlockingError(GetError())) { - state_ = CS_CONNECTING; - enabled_events_ |= DE_CONNECT; - } else { - return SOCKET_ERROR; - } - - enabled_events_ |= DE_READ | DE_WRITE; + if (addr.IsUnresolvedIP()) { + LOG(LS_VERBOSE) << "Resolving addr in PhysicalSocket::Connect"; + resolver_ = new AsyncResolver(); + resolver_->SignalDone.connect(this, &PhysicalSocket::OnResolveResult); + resolver_->Start(addr); + state_ = CS_CONNECTING; return 0; } - int GetError() const override { - CritScope cs(&crit_); - return error_; + return DoConnect(addr); +} + +int PhysicalSocket::DoConnect(const SocketAddress& connect_addr) { + if ((s_ == INVALID_SOCKET) && + !Create(connect_addr.family(), SOCK_STREAM)) { + return SOCKET_ERROR; + } + sockaddr_storage addr_storage; + size_t len = connect_addr.ToSockAddrStorage(&addr_storage); + sockaddr* addr = reinterpret_cast(&addr_storage); + int err = ::connect(s_, addr, static_cast(len)); + UpdateLastError(); + if (err == 0) { + state_ = CS_CONNECTED; + } else if (IsBlockingError(GetError())) { + state_ = CS_CONNECTING; + enabled_events_ |= DE_CONNECT; + } else { + return SOCKET_ERROR; } - void SetError(int error) override { - CritScope cs(&crit_); - error_ = error; - } + enabled_events_ |= DE_READ | DE_WRITE; + return 0; +} - ConnState GetState() const override { return state_; } +int PhysicalSocket::GetError() const { + CritScope cs(&crit_); + return error_; +} - int GetOption(Option opt, int* value) override { - int slevel; - int sopt; - if (TranslateOption(opt, &slevel, &sopt) == -1) - return -1; - socklen_t optlen = sizeof(*value); - int ret = ::getsockopt(s_, slevel, sopt, (SockOptArg)value, &optlen); - if (ret != -1 && opt == OPT_DONTFRAGMENT) { +void PhysicalSocket::SetError(int error) { + CritScope cs(&crit_); + error_ = error; +} + +AsyncSocket::ConnState PhysicalSocket::GetState() const { + return state_; +} + +int PhysicalSocket::GetOption(Option opt, int* value) { + int slevel; + int sopt; + if (TranslateOption(opt, &slevel, &sopt) == -1) + return -1; + socklen_t optlen = sizeof(*value); + int ret = ::getsockopt(s_, slevel, sopt, (SockOptArg)value, &optlen); + if (ret != -1 && opt == OPT_DONTFRAGMENT) { #if defined(WEBRTC_LINUX) && !defined(WEBRTC_ANDROID) - *value = (*value != IP_PMTUDISC_DONT) ? 1 : 0; + *value = (*value != IP_PMTUDISC_DONT) ? 1 : 0; #endif - } - return ret; } + return ret; +} - int SetOption(Option opt, int value) override { - int slevel; - int sopt; - if (TranslateOption(opt, &slevel, &sopt) == -1) - return -1; - if (opt == OPT_DONTFRAGMENT) { +int PhysicalSocket::SetOption(Option opt, int value) { + int slevel; + int sopt; + if (TranslateOption(opt, &slevel, &sopt) == -1) + return -1; + if (opt == OPT_DONTFRAGMENT) { #if defined(WEBRTC_LINUX) && !defined(WEBRTC_ANDROID) - value = (value) ? IP_PMTUDISC_DO : IP_PMTUDISC_DONT; + value = (value) ? IP_PMTUDISC_DO : IP_PMTUDISC_DONT; #endif - } - return ::setsockopt(s_, slevel, sopt, (SockOptArg)&value, sizeof(value)); } + return ::setsockopt(s_, slevel, sopt, (SockOptArg)&value, sizeof(value)); +} - int Send(const void* pv, size_t cb) override { - int sent = ::send(s_, reinterpret_cast(pv), (int)cb, +int PhysicalSocket::Send(const void* pv, size_t cb) { + int sent = ::send(s_, reinterpret_cast(pv), (int)cb, #if defined(WEBRTC_LINUX) && !defined(WEBRTC_ANDROID) - // Suppress SIGPIPE. Without this, attempting to send on a socket whose - // other end is closed will result in a SIGPIPE signal being raised to - // our process, which by default will terminate the process, which we - // don't want. By specifying this flag, we'll just get the error EPIPE - // instead and can handle the error gracefully. - MSG_NOSIGNAL + // Suppress SIGPIPE. Without this, attempting to send on a socket whose + // other end is closed will result in a SIGPIPE signal being raised to + // our process, which by default will terminate the process, which we + // don't want. By specifying this flag, we'll just get the error EPIPE + // instead and can handle the error gracefully. + MSG_NOSIGNAL #else - 0 + 0 #endif - ); - UpdateLastError(); - MaybeRemapSendError(); - // We have seen minidumps where this may be false. - ASSERT(sent <= static_cast(cb)); - if ((sent < 0) && IsBlockingError(GetError())) { - enabled_events_ |= DE_WRITE; - } - return sent; + ); + UpdateLastError(); + MaybeRemapSendError(); + // We have seen minidumps where this may be false. + ASSERT(sent <= static_cast(cb)); + if ((sent < 0) && IsBlockingError(GetError())) { + enabled_events_ |= DE_WRITE; } + return sent; +} - int SendTo(const void* buffer, - size_t length, - const SocketAddress& addr) override { - sockaddr_storage saddr; - size_t len = addr.ToSockAddrStorage(&saddr); - int sent = ::sendto( - s_, static_cast(buffer), static_cast(length), +int PhysicalSocket::SendTo(const void* buffer, + size_t length, + const SocketAddress& addr) { + sockaddr_storage saddr; + size_t len = addr.ToSockAddrStorage(&saddr); + int sent = ::sendto( + s_, static_cast(buffer), static_cast(length), #if defined(WEBRTC_LINUX) && !defined(WEBRTC_ANDROID) - // Suppress SIGPIPE. See above for explanation. - MSG_NOSIGNAL, + // Suppress SIGPIPE. See above for explanation. + MSG_NOSIGNAL, #else - 0, + 0, #endif - reinterpret_cast(&saddr), static_cast(len)); - UpdateLastError(); - MaybeRemapSendError(); - // We have seen minidumps where this may be false. - ASSERT(sent <= static_cast(length)); - if ((sent < 0) && IsBlockingError(GetError())) { - enabled_events_ |= DE_WRITE; - } - return sent; + reinterpret_cast(&saddr), static_cast(len)); + UpdateLastError(); + MaybeRemapSendError(); + // We have seen minidumps where this may be false. + ASSERT(sent <= static_cast(length)); + if ((sent < 0) && IsBlockingError(GetError())) { + enabled_events_ |= DE_WRITE; } + return sent; +} - int Recv(void* buffer, size_t length) override { - int received = ::recv(s_, static_cast(buffer), - static_cast(length), 0); - if ((received == 0) && (length != 0)) { - // Note: on graceful shutdown, recv can return 0. In this case, we - // pretend it is blocking, and then signal close, so that simplifying - // assumptions can be made about Recv. - LOG(LS_WARNING) << "EOF from socket; deferring close event"; - // Must turn this back on so that the select() loop will notice the close - // event. - enabled_events_ |= DE_READ; - SetError(EWOULDBLOCK); - return SOCKET_ERROR; - } - UpdateLastError(); - int error = GetError(); - bool success = (received >= 0) || IsBlockingError(error); - if (udp_ || success) { - enabled_events_ |= DE_READ; - } - if (!success) { - LOG_F(LS_VERBOSE) << "Error = " << error; - } - return received; +int PhysicalSocket::Recv(void* buffer, size_t length) { + int received = ::recv(s_, static_cast(buffer), + static_cast(length), 0); + if ((received == 0) && (length != 0)) { + // Note: on graceful shutdown, recv can return 0. In this case, we + // pretend it is blocking, and then signal close, so that simplifying + // assumptions can be made about Recv. + LOG(LS_WARNING) << "EOF from socket; deferring close event"; + // Must turn this back on so that the select() loop will notice the close + // event. + enabled_events_ |= DE_READ; + SetError(EWOULDBLOCK); + return SOCKET_ERROR; } - - int RecvFrom(void* buffer, size_t length, SocketAddress* out_addr) override { - sockaddr_storage addr_storage; - socklen_t addr_len = sizeof(addr_storage); - sockaddr* addr = reinterpret_cast(&addr_storage); - int received = ::recvfrom(s_, static_cast(buffer), - static_cast(length), 0, addr, &addr_len); - UpdateLastError(); - if ((received >= 0) && (out_addr != NULL)) - SocketAddressFromSockAddrStorage(addr_storage, out_addr); - int error = GetError(); - bool success = (received >= 0) || IsBlockingError(error); - if (udp_ || success) { - enabled_events_ |= DE_READ; - } - if (!success) { - LOG_F(LS_VERBOSE) << "Error = " << error; - } - return received; + UpdateLastError(); + int error = GetError(); + bool success = (received >= 0) || IsBlockingError(error); + if (udp_ || success) { + enabled_events_ |= DE_READ; } - - int Listen(int backlog) override { - int err = ::listen(s_, backlog); - UpdateLastError(); - if (err == 0) { - state_ = CS_CONNECTING; - enabled_events_ |= DE_ACCEPT; -#if !defined(NDEBUG) - dbg_addr_ = "Listening @ "; - dbg_addr_.append(GetLocalAddress().ToString()); -#endif - } - return err; + if (!success) { + LOG_F(LS_VERBOSE) << "Error = " << error; } + return received; +} - AsyncSocket* Accept(SocketAddress* out_addr) override { - sockaddr_storage addr_storage; - socklen_t addr_len = sizeof(addr_storage); - sockaddr* addr = reinterpret_cast(&addr_storage); - SOCKET s = ::accept(s_, addr, &addr_len); - UpdateLastError(); - if (s == INVALID_SOCKET) - return NULL; +int PhysicalSocket::RecvFrom(void* buffer, + size_t length, + SocketAddress* out_addr) { + sockaddr_storage addr_storage; + socklen_t addr_len = sizeof(addr_storage); + sockaddr* addr = reinterpret_cast(&addr_storage); + int received = ::recvfrom(s_, static_cast(buffer), + static_cast(length), 0, addr, &addr_len); + UpdateLastError(); + if ((received >= 0) && (out_addr != nullptr)) + SocketAddressFromSockAddrStorage(addr_storage, out_addr); + int error = GetError(); + bool success = (received >= 0) || IsBlockingError(error); + if (udp_ || success) { + enabled_events_ |= DE_READ; + } + if (!success) { + LOG_F(LS_VERBOSE) << "Error = " << error; + } + return received; +} + +int PhysicalSocket::Listen(int backlog) { + int err = ::listen(s_, backlog); + UpdateLastError(); + if (err == 0) { + state_ = CS_CONNECTING; enabled_events_ |= DE_ACCEPT; - if (out_addr != NULL) - SocketAddressFromSockAddrStorage(addr_storage, out_addr); - return ss_->WrapSocket(s); +#if !defined(NDEBUG) + dbg_addr_ = "Listening @ "; + dbg_addr_.append(GetLocalAddress().ToString()); +#endif } + return err; +} - int Close() override { - if (s_ == INVALID_SOCKET) - return 0; - int err = ::closesocket(s_); - UpdateLastError(); - s_ = INVALID_SOCKET; - state_ = CS_CLOSED; - enabled_events_ = 0; - if (resolver_) { - resolver_->Destroy(false); - resolver_ = NULL; - } - return err; +AsyncSocket* PhysicalSocket::Accept(SocketAddress* out_addr) { + // Always re-subscribe DE_ACCEPT to make sure new incoming connections will + // trigger an event even if DoAccept returns an error here. + enabled_events_ |= DE_ACCEPT; + sockaddr_storage addr_storage; + socklen_t addr_len = sizeof(addr_storage); + sockaddr* addr = reinterpret_cast(&addr_storage); + SOCKET s = DoAccept(s_, addr, &addr_len); + UpdateLastError(); + if (s == INVALID_SOCKET) + return nullptr; + if (out_addr != nullptr) + SocketAddressFromSockAddrStorage(addr_storage, out_addr); + return ss_->WrapSocket(s); +} + +int PhysicalSocket::Close() { + if (s_ == INVALID_SOCKET) + return 0; + int err = ::closesocket(s_); + UpdateLastError(); + s_ = INVALID_SOCKET; + state_ = CS_CLOSED; + enabled_events_ = 0; + if (resolver_) { + resolver_->Destroy(false); + resolver_ = nullptr; } + return err; +} - int EstimateMTU(uint16_t* mtu) override { - SocketAddress addr = GetRemoteAddress(); - if (addr.IsAnyIP()) { - SetError(ENOTCONN); - return -1; - } +int PhysicalSocket::EstimateMTU(uint16_t* mtu) { + SocketAddress addr = GetRemoteAddress(); + if (addr.IsAnyIP()) { + SetError(ENOTCONN); + return -1; + } #if defined(WEBRTC_WIN) - // Gets the interface MTU (TTL=1) for the interface used to reach |addr|. - WinPing ping; - if (!ping.IsValid()) { + // Gets the interface MTU (TTL=1) for the interface used to reach |addr|. + WinPing ping; + if (!ping.IsValid()) { + SetError(EINVAL); // can't think of a better error ID + return -1; + } + int header_size = ICMP_HEADER_SIZE; + if (addr.family() == AF_INET6) { + header_size += IPV6_HEADER_SIZE; + } else if (addr.family() == AF_INET) { + header_size += IP_HEADER_SIZE; + } + + for (int level = 0; PACKET_MAXIMUMS[level + 1] > 0; ++level) { + int32_t size = PACKET_MAXIMUMS[level] - header_size; + WinPing::PingResult result = ping.Ping(addr.ipaddr(), size, + ICMP_PING_TIMEOUT_MILLIS, + 1, false); + if (result == WinPing::PING_FAIL) { SetError(EINVAL); // can't think of a better error ID return -1; + } else if (result != WinPing::PING_TOO_LARGE) { + *mtu = PACKET_MAXIMUMS[level]; + return 0; } - int header_size = ICMP_HEADER_SIZE; - if (addr.family() == AF_INET6) { - header_size += IPV6_HEADER_SIZE; - } else if (addr.family() == AF_INET) { - header_size += IP_HEADER_SIZE; - } + } - for (int level = 0; PACKET_MAXIMUMS[level + 1] > 0; ++level) { - int32_t size = PACKET_MAXIMUMS[level] - header_size; - WinPing::PingResult result = ping.Ping(addr.ipaddr(), size, - ICMP_PING_TIMEOUT_MILLIS, - 1, false); - if (result == WinPing::PING_FAIL) { - SetError(EINVAL); // can't think of a better error ID - return -1; - } else if (result != WinPing::PING_TOO_LARGE) { - *mtu = PACKET_MAXIMUMS[level]; - return 0; - } - } - - ASSERT(false); - return -1; + ASSERT(false); + return -1; #elif defined(WEBRTC_MAC) - // No simple way to do this on Mac OS X. - // SIOCGIFMTU would work if we knew which interface would be used, but - // figuring that out is pretty complicated. For now we'll return an error - // and let the caller pick a default MTU. - SetError(EINVAL); - return -1; + // No simple way to do this on Mac OS X. + // SIOCGIFMTU would work if we knew which interface would be used, but + // figuring that out is pretty complicated. For now we'll return an error + // and let the caller pick a default MTU. + SetError(EINVAL); + return -1; #elif defined(WEBRTC_LINUX) - // Gets the path MTU. - int value; - socklen_t vlen = sizeof(value); - int err = getsockopt(s_, IPPROTO_IP, IP_MTU, &value, &vlen); - if (err < 0) { - UpdateLastError(); - return err; - } + // Gets the path MTU. + int value; + socklen_t vlen = sizeof(value); + int err = getsockopt(s_, IPPROTO_IP, IP_MTU, &value, &vlen); + if (err < 0) { + UpdateLastError(); + return err; + } - ASSERT((0 <= value) && (value <= 65536)); - *mtu = value; - return 0; + ASSERT((0 <= value) && (value <= 65536)); + *mtu = value; + return 0; #elif defined(__native_client__) - // Most socket operations, including this, will fail in NaCl's sandbox. - error_ = EACCES; - return -1; + // Most socket operations, including this, will fail in NaCl's sandbox. + error_ = EACCES; + return -1; #endif +} + + +SOCKET PhysicalSocket::DoAccept(SOCKET socket, + sockaddr* addr, + socklen_t* addrlen) { + return ::accept(socket, addr, addrlen); +} + +void PhysicalSocket::OnResolveResult(AsyncResolverInterface* resolver) { + if (resolver != resolver_) { + return; } - SocketServer* socketserver() { return ss_; } - - protected: - void OnResolveResult(AsyncResolverInterface* resolver) { - if (resolver != resolver_) { - return; - } - - int error = resolver_->GetError(); - if (error == 0) { - error = DoConnect(resolver_->address()); - } else { - Close(); - } - - if (error) { - SetError(error); - SignalCloseEvent(this, error); - } + int error = resolver_->GetError(); + if (error == 0) { + error = DoConnect(resolver_->address()); + } else { + Close(); } - void UpdateLastError() { - SetError(LAST_SYSTEM_ERROR); + if (error) { + SetError(error); + SignalCloseEvent(this, error); } +} - void MaybeRemapSendError() { +void PhysicalSocket::UpdateLastError() { + SetError(LAST_SYSTEM_ERROR); +} + +void PhysicalSocket::MaybeRemapSendError() { #if defined(WEBRTC_MAC) - // https://developer.apple.com/library/mac/documentation/Darwin/ - // Reference/ManPages/man2/sendto.2.html - // ENOBUFS - The output queue for a network interface is full. - // This generally indicates that the interface has stopped sending, - // but may be caused by transient congestion. - if (GetError() == ENOBUFS) { - SetError(EWOULDBLOCK); - } -#endif + // https://developer.apple.com/library/mac/documentation/Darwin/ + // Reference/ManPages/man2/sendto.2.html + // ENOBUFS - The output queue for a network interface is full. + // This generally indicates that the interface has stopped sending, + // but may be caused by transient congestion. + if (GetError() == ENOBUFS) { + SetError(EWOULDBLOCK); } +#endif +} - static int TranslateOption(Option opt, int* slevel, int* sopt) { - switch (opt) { - case OPT_DONTFRAGMENT: +int PhysicalSocket::TranslateOption(Option opt, int* slevel, int* sopt) { + switch (opt) { + case OPT_DONTFRAGMENT: #if defined(WEBRTC_WIN) - *slevel = IPPROTO_IP; - *sopt = IP_DONTFRAGMENT; - break; + *slevel = IPPROTO_IP; + *sopt = IP_DONTFRAGMENT; + break; #elif defined(WEBRTC_MAC) || defined(BSD) || defined(__native_client__) - LOG(LS_WARNING) << "Socket::OPT_DONTFRAGMENT not supported."; - return -1; + LOG(LS_WARNING) << "Socket::OPT_DONTFRAGMENT not supported."; + return -1; #elif defined(WEBRTC_POSIX) - *slevel = IPPROTO_IP; - *sopt = IP_MTU_DISCOVER; - break; + *slevel = IPPROTO_IP; + *sopt = IP_MTU_DISCOVER; + break; #endif - case OPT_RCVBUF: - *slevel = SOL_SOCKET; - *sopt = SO_RCVBUF; - break; - case OPT_SNDBUF: - *slevel = SOL_SOCKET; - *sopt = SO_SNDBUF; - break; - case OPT_NODELAY: - *slevel = IPPROTO_TCP; - *sopt = TCP_NODELAY; - break; - case OPT_DSCP: - LOG(LS_WARNING) << "Socket::OPT_DSCP not supported."; - return -1; - case OPT_RTP_SENDTIME_EXTN_ID: - return -1; // No logging is necessary as this not a OS socket option. - default: - ASSERT(false); - return -1; - } - return 0; + case OPT_RCVBUF: + *slevel = SOL_SOCKET; + *sopt = SO_RCVBUF; + break; + case OPT_SNDBUF: + *slevel = SOL_SOCKET; + *sopt = SO_SNDBUF; + break; + case OPT_NODELAY: + *slevel = IPPROTO_TCP; + *sopt = TCP_NODELAY; + break; + case OPT_DSCP: + LOG(LS_WARNING) << "Socket::OPT_DSCP not supported."; + return -1; + case OPT_RTP_SENDTIME_EXTN_ID: + return -1; // No logging is necessary as this not a OS socket option. + default: + ASSERT(false); + return -1; } - - PhysicalSocketServer* ss_; - SOCKET s_; - uint8_t enabled_events_; - bool udp_; - int error_; - // Protects |error_| that is accessed from different threads. - mutable CriticalSection crit_; - ConnState state_; - AsyncResolver* resolver_; - -#if !defined(NDEBUG) - std::string dbg_addr_; -#endif -}; + return 0; +} #if defined(WEBRTC_POSIX) class EventDispatcher : public Dispatcher { @@ -791,115 +782,119 @@ class PosixSignalDispatcher : public Dispatcher { PhysicalSocketServer *owner_; }; -class SocketDispatcher : public Dispatcher, public PhysicalSocket { - public: - explicit SocketDispatcher(PhysicalSocketServer *ss) : PhysicalSocket(ss) { - } - SocketDispatcher(SOCKET s, PhysicalSocketServer *ss) : PhysicalSocket(ss, s) { - } +SocketDispatcher::SocketDispatcher(PhysicalSocketServer *ss) + : PhysicalSocket(ss) { +} - ~SocketDispatcher() override { - Close(); - } +SocketDispatcher::SocketDispatcher(SOCKET s, PhysicalSocketServer *ss) + : PhysicalSocket(ss, s) { +} - bool Initialize() { - ss_->Add(this); - fcntl(s_, F_SETFL, fcntl(s_, F_GETFL, 0) | O_NONBLOCK); +SocketDispatcher::~SocketDispatcher() { + Close(); +} + +bool SocketDispatcher::Initialize() { + ss_->Add(this); + fcntl(s_, F_SETFL, fcntl(s_, F_GETFL, 0) | O_NONBLOCK); + return true; +} + +bool SocketDispatcher::Create(int type) { + return Create(AF_INET, type); +} + +bool SocketDispatcher::Create(int family, int type) { + // Change the socket to be non-blocking. + if (!PhysicalSocket::Create(family, type)) + return false; + + return Initialize(); +} + +int SocketDispatcher::GetDescriptor() { + return s_; +} + +bool SocketDispatcher::IsDescriptorClosed() { + // We don't have a reliable way of distinguishing end-of-stream + // from readability. So test on each readable call. Is this + // inefficient? Probably. + char ch; + ssize_t res = ::recv(s_, &ch, 1, MSG_PEEK); + if (res > 0) { + // Data available, so not closed. + return false; + } else if (res == 0) { + // EOF, so closed. return true; - } - - virtual bool Create(int type) { - return Create(AF_INET, type); - } - - bool Create(int family, int type) override { - // Change the socket to be non-blocking. - if (!PhysicalSocket::Create(family, type)) - return false; - - return Initialize(); - } - - int GetDescriptor() override { return s_; } - - bool IsDescriptorClosed() override { - // We don't have a reliable way of distinguishing end-of-stream - // from readability. So test on each readable call. Is this - // inefficient? Probably. - char ch; - ssize_t res = ::recv(s_, &ch, 1, MSG_PEEK); - if (res > 0) { - // Data available, so not closed. - return false; - } else if (res == 0) { - // EOF, so closed. - return true; - } else { // error - switch (errno) { - // Returned if we've already closed s_. - case EBADF: - // Returned during ungraceful peer shutdown. - case ECONNRESET: - return true; - default: - // Assume that all other errors are just blocking errors, meaning the - // connection is still good but we just can't read from it right now. - // This should only happen when connecting (and at most once), because - // in all other cases this function is only called if the file - // descriptor is already known to be in the readable state. However, - // it's not necessary a problem if we spuriously interpret a - // "connection lost"-type error as a blocking error, because typically - // the next recv() will get EOF, so we'll still eventually notice that - // the socket is closed. - LOG_ERR(LS_WARNING) << "Assuming benign blocking error"; - return false; - } + } else { // error + switch (errno) { + // Returned if we've already closed s_. + case EBADF: + // Returned during ungraceful peer shutdown. + case ECONNRESET: + return true; + default: + // Assume that all other errors are just blocking errors, meaning the + // connection is still good but we just can't read from it right now. + // This should only happen when connecting (and at most once), because + // in all other cases this function is only called if the file + // descriptor is already known to be in the readable state. However, + // it's not necessary a problem if we spuriously interpret a + // "connection lost"-type error as a blocking error, because typically + // the next recv() will get EOF, so we'll still eventually notice that + // the socket is closed. + LOG_ERR(LS_WARNING) << "Assuming benign blocking error"; + return false; } } +} - uint32_t GetRequestedEvents() override { return enabled_events_; } +uint32_t SocketDispatcher::GetRequestedEvents() { + return enabled_events_; +} - void OnPreEvent(uint32_t ff) override { - if ((ff & DE_CONNECT) != 0) - state_ = CS_CONNECTED; - if ((ff & DE_CLOSE) != 0) - state_ = CS_CLOSED; +void SocketDispatcher::OnPreEvent(uint32_t ff) { + if ((ff & DE_CONNECT) != 0) + state_ = CS_CONNECTED; + if ((ff & DE_CLOSE) != 0) + state_ = CS_CLOSED; +} + +void SocketDispatcher::OnEvent(uint32_t ff, int err) { + // Make sure we deliver connect/accept first. Otherwise, consumers may see + // something like a READ followed by a CONNECT, which would be odd. + if ((ff & DE_CONNECT) != 0) { + enabled_events_ &= ~DE_CONNECT; + SignalConnectEvent(this); } - - void OnEvent(uint32_t ff, int err) override { - // Make sure we deliver connect/accept first. Otherwise, consumers may see - // something like a READ followed by a CONNECT, which would be odd. - if ((ff & DE_CONNECT) != 0) { - enabled_events_ &= ~DE_CONNECT; - SignalConnectEvent(this); - } - if ((ff & DE_ACCEPT) != 0) { - enabled_events_ &= ~DE_ACCEPT; - SignalReadEvent(this); - } - if ((ff & DE_READ) != 0) { - enabled_events_ &= ~DE_READ; - SignalReadEvent(this); - } - if ((ff & DE_WRITE) != 0) { - enabled_events_ &= ~DE_WRITE; - SignalWriteEvent(this); - } - if ((ff & DE_CLOSE) != 0) { - // The socket is now dead to us, so stop checking it. - enabled_events_ = 0; - SignalCloseEvent(this, err); - } + if ((ff & DE_ACCEPT) != 0) { + enabled_events_ &= ~DE_ACCEPT; + SignalReadEvent(this); } - - int Close() override { - if (s_ == INVALID_SOCKET) - return 0; - - ss_->Remove(this); - return PhysicalSocket::Close(); + if ((ff & DE_READ) != 0) { + enabled_events_ &= ~DE_READ; + SignalReadEvent(this); } -}; + if ((ff & DE_WRITE) != 0) { + enabled_events_ &= ~DE_WRITE; + SignalWriteEvent(this); + } + if ((ff & DE_CLOSE) != 0) { + // The socket is now dead to us, so stop checking it. + enabled_events_ = 0; + SignalCloseEvent(this, err); + } +} + +int SocketDispatcher::Close() { + if (s_ == INVALID_SOCKET) + return 0; + + ss_->Remove(this); + return PhysicalSocket::Close(); +} class FileDispatcher: public Dispatcher, public AsyncFile { public: @@ -1015,126 +1010,120 @@ private: WSAEVENT hev_; }; -class SocketDispatcher : public Dispatcher, public PhysicalSocket { - public: - static int next_id_; - int id_; - bool signal_close_; - int signal_err_; +SocketDispatcher::SocketDispatcher(PhysicalSocketServer* ss) + : PhysicalSocket(ss), + id_(0), + signal_close_(false) { +} - SocketDispatcher(PhysicalSocketServer* ss) - : PhysicalSocket(ss), - id_(0), - signal_close_(false) { - } +SocketDispatcher::SocketDispatcher(SOCKET s, PhysicalSocketServer* ss) + : PhysicalSocket(ss, s), + id_(0), + signal_close_(false) { +} - SocketDispatcher(SOCKET s, PhysicalSocketServer* ss) - : PhysicalSocket(ss, s), - id_(0), - signal_close_(false) { - } +SocketDispatcher::~SocketDispatcher() { + Close(); +} - virtual ~SocketDispatcher() { - Close(); - } +bool SocketDispatcher::Initialize() { + ASSERT(s_ != INVALID_SOCKET); + // Must be a non-blocking + u_long argp = 1; + ioctlsocket(s_, FIONBIO, &argp); + ss_->Add(this); + return true; +} - bool Initialize() { - ASSERT(s_ != INVALID_SOCKET); - // Must be a non-blocking - u_long argp = 1; - ioctlsocket(s_, FIONBIO, &argp); - ss_->Add(this); - return true; - } +bool SocketDispatcher::Create(int type) { + return Create(AF_INET, type); +} - virtual bool Create(int type) { - return Create(AF_INET, type); - } +bool SocketDispatcher::Create(int family, int type) { + // Create socket + if (!PhysicalSocket::Create(family, type)) + return false; - virtual bool Create(int family, int type) { - // Create socket - if (!PhysicalSocket::Create(family, type)) - return false; + if (!Initialize()) + return false; - if (!Initialize()) - return false; + do { id_ = ++next_id_; } while (id_ == 0); + return true; +} - do { id_ = ++next_id_; } while (id_ == 0); - return true; - } +int SocketDispatcher::Close() { + if (s_ == INVALID_SOCKET) + return 0; - virtual int Close() { - if (s_ == INVALID_SOCKET) - return 0; + id_ = 0; + signal_close_ = false; + ss_->Remove(this); + return PhysicalSocket::Close(); +} - id_ = 0; - signal_close_ = false; - ss_->Remove(this); - return PhysicalSocket::Close(); - } +uint32_t SocketDispatcher::GetRequestedEvents() { + return enabled_events_; +} - virtual uint32_t GetRequestedEvents() { return enabled_events_; } +void SocketDispatcher::OnPreEvent(uint32_t ff) { + if ((ff & DE_CONNECT) != 0) + state_ = CS_CONNECTED; + // We set CS_CLOSED from CheckSignalClose. +} - virtual void OnPreEvent(uint32_t ff) { - if ((ff & DE_CONNECT) != 0) - state_ = CS_CONNECTED; - // We set CS_CLOSED from CheckSignalClose. - } - - virtual void OnEvent(uint32_t ff, int err) { - int cache_id = id_; - // Make sure we deliver connect/accept first. Otherwise, consumers may see - // something like a READ followed by a CONNECT, which would be odd. - if (((ff & DE_CONNECT) != 0) && (id_ == cache_id)) { - if (ff != DE_CONNECT) - LOG(LS_VERBOSE) << "Signalled with DE_CONNECT: " << ff; - enabled_events_ &= ~DE_CONNECT; +void SocketDispatcher::OnEvent(uint32_t ff, int err) { + int cache_id = id_; + // Make sure we deliver connect/accept first. Otherwise, consumers may see + // something like a READ followed by a CONNECT, which would be odd. + if (((ff & DE_CONNECT) != 0) && (id_ == cache_id)) { + if (ff != DE_CONNECT) + LOG(LS_VERBOSE) << "Signalled with DE_CONNECT: " << ff; + enabled_events_ &= ~DE_CONNECT; #if !defined(NDEBUG) - dbg_addr_ = "Connected @ "; - dbg_addr_.append(GetRemoteAddress().ToString()); + dbg_addr_ = "Connected @ "; + dbg_addr_.append(GetRemoteAddress().ToString()); #endif - SignalConnectEvent(this); - } - if (((ff & DE_ACCEPT) != 0) && (id_ == cache_id)) { - enabled_events_ &= ~DE_ACCEPT; - SignalReadEvent(this); - } - if ((ff & DE_READ) != 0) { - enabled_events_ &= ~DE_READ; - SignalReadEvent(this); - } - if (((ff & DE_WRITE) != 0) && (id_ == cache_id)) { - enabled_events_ &= ~DE_WRITE; - SignalWriteEvent(this); - } - if (((ff & DE_CLOSE) != 0) && (id_ == cache_id)) { - signal_close_ = true; - signal_err_ = err; - } + SignalConnectEvent(this); } - - virtual WSAEVENT GetWSAEvent() { - return WSA_INVALID_EVENT; + if (((ff & DE_ACCEPT) != 0) && (id_ == cache_id)) { + enabled_events_ &= ~DE_ACCEPT; + SignalReadEvent(this); } - - virtual SOCKET GetSocket() { - return s_; + if ((ff & DE_READ) != 0) { + enabled_events_ &= ~DE_READ; + SignalReadEvent(this); } - - virtual bool CheckSignalClose() { - if (!signal_close_) - return false; - - char ch; - if (recv(s_, &ch, 1, MSG_PEEK) > 0) - return false; - - state_ = CS_CLOSED; - signal_close_ = false; - SignalCloseEvent(this, signal_err_); - return true; + if (((ff & DE_WRITE) != 0) && (id_ == cache_id)) { + enabled_events_ &= ~DE_WRITE; + SignalWriteEvent(this); } -}; + if (((ff & DE_CLOSE) != 0) && (id_ == cache_id)) { + signal_close_ = true; + signal_err_ = err; + } +} + +WSAEVENT SocketDispatcher::GetWSAEvent() { + return WSA_INVALID_EVENT; +} + +SOCKET SocketDispatcher::GetSocket() { + return s_; +} + +bool SocketDispatcher::CheckSignalClose() { + if (!signal_close_) + return false; + + char ch; + if (recv(s_, &ch, 1, MSG_PEEK) > 0) + return false; + + state_ = CS_CLOSED; + signal_close_ = false; + SignalCloseEvent(this, signal_err_); + return true; +} int SocketDispatcher::next_id_ = 0; @@ -1190,7 +1179,7 @@ Socket* PhysicalSocketServer::CreateSocket(int family, int type) { return socket; } else { delete socket; - return 0; + return nullptr; } } @@ -1204,7 +1193,7 @@ AsyncSocket* PhysicalSocketServer::CreateAsyncSocket(int family, int type) { return dispatcher; } else { delete dispatcher; - return 0; + return nullptr; } } @@ -1214,7 +1203,7 @@ AsyncSocket* PhysicalSocketServer::WrapSocket(SOCKET s) { return dispatcher; } else { delete dispatcher; - return 0; + return nullptr; } } @@ -1343,7 +1332,7 @@ bool PhysicalSocketServer::Wait(int cmsWait, bool process_io) { int errcode = 0; // Reap any error code, which can be signaled through reads or writes. - // TODO: Should we set errcode if getsockopt fails? + // TODO(pthatcher): Should we set errcode if getsockopt fails? if (FD_ISSET(fd, &fdsRead) || FD_ISSET(fd, &fdsWrite)) { socklen_t len = sizeof(errcode); ::getsockopt(fd, SOL_SOCKET, SO_ERROR, &errcode, &len); @@ -1352,7 +1341,7 @@ bool PhysicalSocketServer::Wait(int cmsWait, bool process_io) { // Check readable descriptors. If we're waiting on an accept, signal // that. Otherwise we're waiting for data, check to see if we're // readable or really closed. - // TODO: Only peek at TCP descriptors. + // TODO(pthatcher): Only peek at TCP descriptors. if (FD_ISSET(fd, &fdsRead)) { FD_CLR(fd, &fdsRead); if (pdispatcher->GetRequestedEvents() & DE_ACCEPT) { @@ -1526,7 +1515,7 @@ bool PhysicalSocketServer::Wait(int cmsWait, bool process_io) { if (dw == WSA_WAIT_FAILED) { // Failed? - // TODO: need a better strategy than this! + // TODO(pthatcher): need a better strategy than this! WSAGetLastError(); ASSERT(false); return false; diff --git a/webrtc/base/physicalsocketserver.h b/webrtc/base/physicalsocketserver.h index af09e0b988..ae1f10f596 100644 --- a/webrtc/base/physicalsocketserver.h +++ b/webrtc/base/physicalsocketserver.h @@ -14,6 +14,7 @@ #include #include "webrtc/base/asyncfile.h" +#include "webrtc/base/nethelpers.h" #include "webrtc/base/scoped_ptr.h" #include "webrtc/base/socketserver.h" #include "webrtc/base/criticalsection.h" @@ -115,6 +116,107 @@ class PhysicalSocketServer : public SocketServer { #endif }; +class PhysicalSocket : public AsyncSocket, public sigslot::has_slots<> { + public: + PhysicalSocket(PhysicalSocketServer* ss, SOCKET s = INVALID_SOCKET); + ~PhysicalSocket() override; + + // Creates the underlying OS socket (same as the "socket" function). + virtual bool Create(int family, int type); + + SocketAddress GetLocalAddress() const override; + SocketAddress GetRemoteAddress() const override; + + int Bind(const SocketAddress& bind_addr) override; + int Connect(const SocketAddress& addr) override; + + int GetError() const override; + void SetError(int error) override; + + ConnState GetState() const override; + + int GetOption(Option opt, int* value) override; + int SetOption(Option opt, int value) override; + + int Send(const void* pv, size_t cb) override; + int SendTo(const void* buffer, + size_t length, + const SocketAddress& addr) override; + + int Recv(void* buffer, size_t length) override; + int RecvFrom(void* buffer, size_t length, SocketAddress* out_addr) override; + + int Listen(int backlog) override; + AsyncSocket* Accept(SocketAddress* out_addr) override; + + int Close() override; + + int EstimateMTU(uint16_t* mtu) override; + + SocketServer* socketserver() { return ss_; } + + protected: + int DoConnect(const SocketAddress& connect_addr); + + // Make virtual so ::accept can be overwritten in tests. + virtual SOCKET DoAccept(SOCKET socket, sockaddr* addr, socklen_t* addrlen); + + void OnResolveResult(AsyncResolverInterface* resolver); + + void UpdateLastError(); + void MaybeRemapSendError(); + + static int TranslateOption(Option opt, int* slevel, int* sopt); + + PhysicalSocketServer* ss_; + SOCKET s_; + uint8_t enabled_events_; + bool udp_; + mutable CriticalSection crit_; + int error_ GUARDED_BY(crit_); + ConnState state_; + AsyncResolver* resolver_; + +#if !defined(NDEBUG) + std::string dbg_addr_; +#endif +}; + +class SocketDispatcher : public Dispatcher, public PhysicalSocket { + public: + explicit SocketDispatcher(PhysicalSocketServer *ss); + SocketDispatcher(SOCKET s, PhysicalSocketServer *ss); + ~SocketDispatcher() override; + + bool Initialize(); + + virtual bool Create(int type); + bool Create(int family, int type) override; + +#if defined(WEBRTC_WIN) + WSAEVENT GetWSAEvent() override; + SOCKET GetSocket() override; + bool CheckSignalClose() override; +#elif defined(WEBRTC_POSIX) + int GetDescriptor() override; + bool IsDescriptorClosed() override; +#endif + + uint32_t GetRequestedEvents() override; + void OnPreEvent(uint32_t ff) override; + void OnEvent(uint32_t ff, int err) override; + + int Close() override; + +#if defined(WEBRTC_WIN) + private: + static int next_id_; + int id_; + bool signal_close_; + int signal_err_; +#endif // WEBRTC_WIN +}; + } // namespace rtc #endif // WEBRTC_BASE_PHYSICALSOCKETSERVER_H__ diff --git a/webrtc/base/physicalsocketserver_unittest.cc b/webrtc/base/physicalsocketserver_unittest.cc index ad0d4657e8..5ff4859e13 100644 --- a/webrtc/base/physicalsocketserver_unittest.cc +++ b/webrtc/base/physicalsocketserver_unittest.cc @@ -22,9 +22,82 @@ namespace rtc { -class PhysicalSocketTest : public SocketTest { +class PhysicalSocketTest; + +class FakeSocketDispatcher : public SocketDispatcher { + public: + explicit FakeSocketDispatcher(PhysicalSocketServer* ss) + : SocketDispatcher(ss) { + } + + protected: + SOCKET DoAccept(SOCKET socket, sockaddr* addr, socklen_t* addrlen) override; }; +class FakePhysicalSocketServer : public PhysicalSocketServer { + public: + explicit FakePhysicalSocketServer(PhysicalSocketTest* test) + : test_(test) { + } + + AsyncSocket* CreateAsyncSocket(int type) override { + SocketDispatcher* dispatcher = new FakeSocketDispatcher(this); + if (dispatcher->Create(type)) { + return dispatcher; + } else { + delete dispatcher; + return nullptr; + } + } + + AsyncSocket* CreateAsyncSocket(int family, int type) override { + SocketDispatcher* dispatcher = new FakeSocketDispatcher(this); + if (dispatcher->Create(family, type)) { + return dispatcher; + } else { + delete dispatcher; + return nullptr; + } + } + + PhysicalSocketTest* GetTest() const { return test_; } + + private: + PhysicalSocketTest* test_; +}; + +class PhysicalSocketTest : public SocketTest { + public: + // Set flag to simluate failures when calling "::accept" on a AsyncSocket. + void SetFailAccept(bool fail) { fail_accept_ = fail; } + bool FailAccept() const { return fail_accept_; } + + protected: + PhysicalSocketTest() + : server_(new FakePhysicalSocketServer(this)), + scope_(server_.get()), + fail_accept_(false) { + } + + void ConnectInternalAcceptError(const IPAddress& loopback); + + rtc::scoped_ptr server_; + SocketServerScope scope_; + bool fail_accept_; +}; + +SOCKET FakeSocketDispatcher::DoAccept(SOCKET socket, + sockaddr* addr, + socklen_t* addrlen) { + FakePhysicalSocketServer* ss = + static_cast(socketserver()); + if (ss->GetTest()->FailAccept()) { + return INVALID_SOCKET; + } + + return SocketDispatcher::DoAccept(socket, addr, addrlen); +} + TEST_F(PhysicalSocketTest, TestConnectIPv4) { SocketTest::TestConnectIPv4(); } @@ -51,6 +124,92 @@ TEST_F(PhysicalSocketTest, TestConnectFailIPv4) { SocketTest::TestConnectFailIPv4(); } +void PhysicalSocketTest::ConnectInternalAcceptError(const IPAddress& loopback) { + testing::StreamSink sink; + SocketAddress accept_addr; + + // Create two clients. + scoped_ptr client1(server_->CreateAsyncSocket(loopback.family(), + SOCK_STREAM)); + sink.Monitor(client1.get()); + EXPECT_EQ(AsyncSocket::CS_CLOSED, client1->GetState()); + EXPECT_PRED1(IsUnspecOrEmptyIP, client1->GetLocalAddress().ipaddr()); + + scoped_ptr client2(server_->CreateAsyncSocket(loopback.family(), + SOCK_STREAM)); + sink.Monitor(client2.get()); + EXPECT_EQ(AsyncSocket::CS_CLOSED, client2->GetState()); + EXPECT_PRED1(IsUnspecOrEmptyIP, client2->GetLocalAddress().ipaddr()); + + // Create server and listen. + scoped_ptr server( + server_->CreateAsyncSocket(loopback.family(), SOCK_STREAM)); + sink.Monitor(server.get()); + EXPECT_EQ(0, server->Bind(SocketAddress(loopback, 0))); + EXPECT_EQ(0, server->Listen(5)); + EXPECT_EQ(AsyncSocket::CS_CONNECTING, server->GetState()); + + // Ensure no pending server connections, since we haven't done anything yet. + EXPECT_FALSE(sink.Check(server.get(), testing::SSE_READ)); + EXPECT_TRUE(nullptr == server->Accept(&accept_addr)); + EXPECT_TRUE(accept_addr.IsNil()); + + // Attempt first connect to listening socket. + EXPECT_EQ(0, client1->Connect(server->GetLocalAddress())); + EXPECT_FALSE(client1->GetLocalAddress().IsNil()); + EXPECT_NE(server->GetLocalAddress(), client1->GetLocalAddress()); + + // Client is connecting, outcome not yet determined. + EXPECT_EQ(AsyncSocket::CS_CONNECTING, client1->GetState()); + EXPECT_FALSE(sink.Check(client1.get(), testing::SSE_OPEN)); + EXPECT_FALSE(sink.Check(client1.get(), testing::SSE_CLOSE)); + + // Server has pending connection, try to accept it (will fail). + EXPECT_TRUE_WAIT((sink.Check(server.get(), testing::SSE_READ)), kTimeout); + // Simulate "::accept" returning an error. + SetFailAccept(true); + scoped_ptr accepted(server->Accept(&accept_addr)); + EXPECT_FALSE(accepted); + ASSERT_TRUE(accept_addr.IsNil()); + + // Ensure no more pending server connections. + EXPECT_FALSE(sink.Check(server.get(), testing::SSE_READ)); + EXPECT_TRUE(nullptr == server->Accept(&accept_addr)); + EXPECT_TRUE(accept_addr.IsNil()); + + // Attempt second connect to listening socket. + EXPECT_EQ(0, client2->Connect(server->GetLocalAddress())); + EXPECT_FALSE(client2->GetLocalAddress().IsNil()); + EXPECT_NE(server->GetLocalAddress(), client2->GetLocalAddress()); + + // Client is connecting, outcome not yet determined. + EXPECT_EQ(AsyncSocket::CS_CONNECTING, client2->GetState()); + EXPECT_FALSE(sink.Check(client2.get(), testing::SSE_OPEN)); + EXPECT_FALSE(sink.Check(client2.get(), testing::SSE_CLOSE)); + + // Server has pending connection, try to accept it (will succeed). + EXPECT_TRUE_WAIT((sink.Check(server.get(), testing::SSE_READ)), kTimeout); + SetFailAccept(false); + scoped_ptr accepted2(server->Accept(&accept_addr)); + ASSERT_TRUE(accepted2); + EXPECT_FALSE(accept_addr.IsNil()); + EXPECT_EQ(accepted2->GetRemoteAddress(), accept_addr); +} + +TEST_F(PhysicalSocketTest, TestConnectAcceptErrorIPv4) { + ConnectInternalAcceptError(kIPv4Loopback); +} + +// Crashes on Linux. See webrtc:4923. +#if defined(WEBRTC_LINUX) +#define MAYBE_TestConnectAcceptErrorIPv6 DISABLED_TestConnectAcceptErrorIPv6 +#else +#define MAYBE_TestConnectAcceptErrorIPv6 TestConnectAcceptErrorIPv6 +#endif +TEST_F(PhysicalSocketTest, MAYBE_TestConnectAcceptErrorIPv6) { + ConnectInternalAcceptError(kIPv6Loopback); +} + // Crashes on Linux. See webrtc:4923. #if defined(WEBRTC_LINUX) #define MAYBE_TestConnectFailIPv6 DISABLED_TestConnectFailIPv6 diff --git a/webrtc/base/socket_unittest.h b/webrtc/base/socket_unittest.h index d368afb3f5..e4a6b32705 100644 --- a/webrtc/base/socket_unittest.h +++ b/webrtc/base/socket_unittest.h @@ -21,8 +21,9 @@ namespace rtc { // socketserver, and call the SocketTest test methods. class SocketTest : public testing::Test { protected: - SocketTest() : ss_(NULL), kIPv4Loopback(INADDR_LOOPBACK), - kIPv6Loopback(in6addr_loopback) {} + SocketTest() : kIPv4Loopback(INADDR_LOOPBACK), + kIPv6Loopback(in6addr_loopback), + ss_(nullptr) {} virtual void SetUp() { ss_ = Thread::Current()->socketserver(); } void TestConnectIPv4(); void TestConnectIPv6(); @@ -57,6 +58,10 @@ class SocketTest : public testing::Test { void TestGetSetOptionsIPv4(); void TestGetSetOptionsIPv6(); + static const int kTimeout = 5000; // ms + const IPAddress kIPv4Loopback; + const IPAddress kIPv6Loopback; + private: void ConnectInternal(const IPAddress& loopback); void ConnectWithDnsLookupInternal(const IPAddress& loopback, @@ -77,12 +82,13 @@ class SocketTest : public testing::Test { void UdpReadyToSend(const IPAddress& loopback); void GetSetOptionsInternal(const IPAddress& loopback); - static const int kTimeout = 5000; // ms SocketServer* ss_; - const IPAddress kIPv4Loopback; - const IPAddress kIPv6Loopback; }; +// For unbound sockets, GetLocalAddress / GetRemoteAddress return AF_UNSPEC +// values on Windows, but an empty address of the same family on Linux/MacOS X. +bool IsUnspecOrEmptyIP(const IPAddress& address); + } // namespace rtc #endif // WEBRTC_BASE_SOCKET_UNITTEST_H_