diff --git a/p2p/base/async_stun_tcp_socket.cc b/p2p/base/async_stun_tcp_socket.cc index 64c19c496c..5f8f07227f 100644 --- a/p2p/base/async_stun_tcp_socket.cc +++ b/p2p/base/async_stun_tcp_socket.cc @@ -49,7 +49,7 @@ AsyncStunTCPSocket* AsyncStunTCPSocket::Create( } AsyncStunTCPSocket::AsyncStunTCPSocket(rtc::Socket* socket) - : rtc::AsyncTCPSocketBase(socket, /*listen=*/false, kBufSize) {} + : rtc::AsyncTCPSocketBase(socket, kBufSize) {} int AsyncStunTCPSocket::Send(const void* pv, size_t cb, @@ -125,10 +125,6 @@ void AsyncStunTCPSocket::ProcessInput(char* data, size_t* len) { } } -void AsyncStunTCPSocket::HandleIncomingConnection(rtc::Socket* socket) { - SignalNewConnection(this, new AsyncStunTCPSocket(socket)); -} - size_t AsyncStunTCPSocket::GetExpectedLength(const void* data, size_t len, int* pad_bytes) { diff --git a/p2p/base/async_stun_tcp_socket.h b/p2p/base/async_stun_tcp_socket.h index 2dc9613eec..eb4eef7cdc 100644 --- a/p2p/base/async_stun_tcp_socket.h +++ b/p2p/base/async_stun_tcp_socket.h @@ -36,7 +36,6 @@ class AsyncStunTCPSocket : public rtc::AsyncTCPSocketBase { size_t cb, const rtc::PacketOptions& options) override; void ProcessInput(char* data, size_t* len) override; - void HandleIncomingConnection(rtc::Socket* socket) override; private: // This method returns the message hdr + length written in the header. diff --git a/p2p/base/async_stun_tcp_socket_unittest.cc b/p2p/base/async_stun_tcp_socket_unittest.cc index c774b52fa4..72d6a7fde0 100644 --- a/p2p/base/async_stun_tcp_socket_unittest.cc +++ b/p2p/base/async_stun_tcp_socket_unittest.cc @@ -16,6 +16,7 @@ #include #include #include +#include #include "absl/memory/memory.h" #include "rtc_base/network/sent_packet.h" @@ -59,10 +60,10 @@ static unsigned char kTurnChannelDataMessageWithOddLength[] = { static const rtc::SocketAddress kClientAddr("11.11.11.11", 0); static const rtc::SocketAddress kServerAddr("22.22.22.22", 0); -class AsyncStunServerTCPSocket : public rtc::AsyncTCPSocket { +class AsyncStunServerTCPSocket : public rtc::AsyncTcpListenSocket { public: - explicit AsyncStunServerTCPSocket(rtc::Socket* socket) - : AsyncTCPSocket(socket, true) {} + explicit AsyncStunServerTCPSocket(std::unique_ptr socket) + : AsyncTcpListenSocket(std::move(socket)) {} void HandleIncomingConnection(rtc::Socket* socket) override { SignalNewConnection(this, new AsyncStunTCPSocket(socket)); } @@ -77,9 +78,11 @@ class AsyncStunTCPSocketTest : public ::testing::Test, virtual void SetUp() { CreateSockets(); } void CreateSockets() { - rtc::Socket* server = vss_->CreateSocket(kServerAddr.family(), SOCK_STREAM); + std::unique_ptr server = + absl::WrapUnique(vss_->CreateSocket(kServerAddr.family(), SOCK_STREAM)); server->Bind(kServerAddr); - listen_socket_ = std::make_unique(server); + listen_socket_ = + std::make_unique(std::move(server)); listen_socket_->SignalNewConnection.connect( this, &AsyncStunTCPSocketTest::OnNewConnection); diff --git a/p2p/base/basic_packet_socket_factory.cc b/p2p/base/basic_packet_socket_factory.cc index 901e3b9eeb..e0f21fefdc 100644 --- a/p2p/base/basic_packet_socket_factory.cc +++ b/p2p/base/basic_packet_socket_factory.cc @@ -14,6 +14,7 @@ #include +#include "absl/memory/memory.h" #include "api/async_dns_resolver.h" #include "api/wrapping_async_dns_resolver.h" #include "p2p/base/async_stun_tcp_socket.h" @@ -89,7 +90,7 @@ AsyncListenSocket* BasicPacketSocketFactory::CreateServerTcpSocket( RTC_CHECK(!(opts & PacketSocketFactory::OPT_STUN)); - return new AsyncTCPSocket(socket, true); + return new AsyncTcpListenSocket(absl::WrapUnique(socket)); } AsyncPacketSocket* BasicPacketSocketFactory::CreateClientTcpSocket( @@ -183,7 +184,7 @@ AsyncPacketSocket* BasicPacketSocketFactory::CreateClientTcpSocket( if (tcp_options.opts & PacketSocketFactory::OPT_STUN) { tcp_socket = new cricket::AsyncStunTCPSocket(socket); } else { - tcp_socket = new AsyncTCPSocket(socket, false); + tcp_socket = new AsyncTCPSocket(socket); } return tcp_socket; diff --git a/p2p/base/port_unittest.cc b/p2p/base/port_unittest.cc index 52153d8c90..5bcc186bb2 100644 --- a/p2p/base/port_unittest.cc +++ b/p2p/base/port_unittest.cc @@ -1060,6 +1060,24 @@ class FakeAsyncPacketSocket : public AsyncPacketSocket { State state_; }; +class FakeAsyncListenSocket : public AsyncListenSocket { + public: + // Returns current local address. Address may be set to NULL if the + // socket is not bound yet (GetState() returns STATE_BINDING). + virtual SocketAddress GetLocalAddress() const { return local_address_; } + void Bind(const SocketAddress& address) { + local_address_ = address; + state_ = State::kBound; + } + virtual int GetOption(Socket::Option opt, int* value) { return 0; } + virtual int SetOption(Socket::Option opt, int value) { return 0; } + virtual State GetState() const { return state_; } + + private: + SocketAddress local_address_; + State state_ = State::kClosed; +}; + // Local -> XXXX TEST_F(PortTest, TestLocalToLocal) { TestLocalToLocal(); @@ -1508,8 +1526,8 @@ TEST_F(PortTest, TestDelayedBindingUdp) { } TEST_F(PortTest, TestDisableInterfaceOfTcpPort) { - FakeAsyncPacketSocket* lsocket = new FakeAsyncPacketSocket(); - FakeAsyncPacketSocket* rsocket = new FakeAsyncPacketSocket(); + FakeAsyncListenSocket* lsocket = new FakeAsyncListenSocket(); + FakeAsyncListenSocket* rsocket = new FakeAsyncListenSocket(); FakePacketSocketFactory socket_factory; socket_factory.set_next_server_tcp_socket(lsocket); @@ -1518,10 +1536,8 @@ TEST_F(PortTest, TestDisableInterfaceOfTcpPort) { socket_factory.set_next_server_tcp_socket(rsocket); auto rport = CreateTcpPort(kLocalAddr2, &socket_factory); - lsocket->set_state(AsyncPacketSocket::STATE_BOUND); - lsocket->local_address_ = kLocalAddr1; - rsocket->set_state(AsyncPacketSocket::STATE_BOUND); - rsocket->local_address_ = kLocalAddr2; + lsocket->Bind(kLocalAddr1); + rsocket->Bind(kLocalAddr2); lport->SetIceRole(cricket::ICEROLE_CONTROLLING); lport->SetIceTiebreaker(kTiebreaker1); @@ -1560,17 +1576,17 @@ void PortTest::TestCrossFamilyPorts(int type) { SocketAddress("192.168.1.3", 0), SocketAddress("192.168.1.4", 0), SocketAddress("2001:db8::1", 0), SocketAddress("2001:db8::2", 0)}; for (int i = 0; i < 4; i++) { - FakeAsyncPacketSocket* socket = new FakeAsyncPacketSocket(); if (type == SOCK_DGRAM) { + FakeAsyncPacketSocket* socket = new FakeAsyncPacketSocket(); factory.set_next_udp_socket(socket); ports[i] = CreateUdpPort(addresses[i], &factory); socket->set_state(AsyncPacketSocket::STATE_BINDING); socket->SignalAddressReady(socket, addresses[i]); } else if (type == SOCK_STREAM) { + FakeAsyncListenSocket* socket = new FakeAsyncListenSocket(); factory.set_next_server_tcp_socket(socket); ports[i] = CreateTcpPort(addresses[i], &factory); - socket->set_state(AsyncPacketSocket::STATE_BOUND); - socket->local_address_ = addresses[i]; + socket->Bind(addresses[i]); } ports[i]->PrepareAddress(); } diff --git a/p2p/base/tcp_port.cc b/p2p/base/tcp_port.cc index f9dd1853aa..ea805e5894 100644 --- a/p2p/base/tcp_port.cc +++ b/p2p/base/tcp_port.cc @@ -169,13 +169,11 @@ void TCPPort::PrepareAddress() { // Socket may be in the CLOSED state if Listen() // failed, we still want to add the socket address. RTC_LOG(LS_VERBOSE) << "Preparing TCP address, current state: " - << listen_socket_->GetState(); - if (listen_socket_->GetState() == rtc::AsyncPacketSocket::STATE_BOUND || - listen_socket_->GetState() == rtc::AsyncPacketSocket::STATE_CLOSED) - AddAddress(listen_socket_->GetLocalAddress(), - listen_socket_->GetLocalAddress(), rtc::SocketAddress(), - TCP_PROTOCOL_NAME, "", TCPTYPE_PASSIVE_STR, LOCAL_PORT_TYPE, - ICE_TYPE_PREFERENCE_HOST_TCP, 0, "", true); + << static_cast(listen_socket_->GetState()); + AddAddress(listen_socket_->GetLocalAddress(), + listen_socket_->GetLocalAddress(), rtc::SocketAddress(), + TCP_PROTOCOL_NAME, "", TCPTYPE_PASSIVE_STR, LOCAL_PORT_TYPE, + ICE_TYPE_PREFERENCE_HOST_TCP, 0, "", true); } else { RTC_LOG(LS_INFO) << ToString() << ": Not listening due to firewall restrictions."; diff --git a/rtc_base/async_packet_socket.h b/rtc_base/async_packet_socket.h index 8362604f9e..ce36dd6373 100644 --- a/rtc_base/async_packet_socket.h +++ b/rtc_base/async_packet_socket.h @@ -128,17 +128,31 @@ class RTC_EXPORT AsyncPacketSocket : public sigslot::has_slots<> { // CONNECTED to CLOSED. sigslot::signal2 SignalClose; - // Used only for listening TCP sockets. - sigslot::signal2 SignalNewConnection; - private: RTC_DISALLOW_COPY_AND_ASSIGN(AsyncPacketSocket); }; -// TODO(bugs.webrtc.org/13065): Intended to be broken out into a separate class, -// after downstream has adapted the new name. The main feature to move from -// AsyncPacketSocket to AsyncListenSocket is the SignalNewConnection. -using AsyncListenSocket = AsyncPacketSocket; +// Listen socket, producing an AsyncPacketSocket when a peer connects. +class RTC_EXPORT AsyncListenSocket : public sigslot::has_slots<> { + public: + enum class State { + kClosed, + kBound, + }; + + // Returns current state of the socket. + virtual State GetState() const = 0; + + // Returns current local address. Address may be set to null if the + // socket is not bound yet (GetState() returns kBinding). + virtual SocketAddress GetLocalAddress() const = 0; + + // Get/set options. + virtual int GetOption(Socket::Option opt, int* value) = 0; + virtual int SetOption(Socket::Option opt, int value) = 0; + + sigslot::signal2 SignalNewConnection; +}; void CopySocketInformationToPacketInfo(size_t packet_size_bytes, const AsyncPacketSocket& socket_from, diff --git a/rtc_base/async_tcp_socket.cc b/rtc_base/async_tcp_socket.cc index 76efb6dec1..37a1052d52 100644 --- a/rtc_base/async_tcp_socket.cc +++ b/rtc_base/async_tcp_socket.cc @@ -62,16 +62,11 @@ Socket* AsyncTCPSocketBase::ConnectSocket( } AsyncTCPSocketBase::AsyncTCPSocketBase(Socket* socket, - bool listen, size_t max_packet_size) : socket_(socket), - listen_(listen), max_insize_(max_packet_size), max_outsize_(max_packet_size) { - if (!listen_) { - // Listening sockets don't send/receive data, so they don't need buffers. - inbuf_.EnsureCapacity(kMinimumRecvSize); - } + inbuf_.EnsureCapacity(kMinimumRecvSize); RTC_DCHECK(socket_.get() != nullptr); socket_->SignalConnectEvent.connect(this, @@ -79,12 +74,6 @@ AsyncTCPSocketBase::AsyncTCPSocketBase(Socket* socket, socket_->SignalReadEvent.connect(this, &AsyncTCPSocketBase::OnReadEvent); socket_->SignalWriteEvent.connect(this, &AsyncTCPSocketBase::OnWriteEvent); socket_->SignalCloseEvent.connect(this, &AsyncTCPSocketBase::OnCloseEvent); - - if (listen_) { - if (socket_->Listen(kListenBacklog) < 0) { - RTC_LOG(LS_ERROR) << "Listen() failed with error " << socket_->GetError(); - } - } } AsyncTCPSocketBase::~AsyncTCPSocketBase() {} @@ -106,11 +95,7 @@ AsyncTCPSocket::State AsyncTCPSocketBase::GetState() const { case Socket::CS_CLOSED: return STATE_CLOSED; case Socket::CS_CONNECTING: - if (listen_) { - return STATE_BOUND; - } else { - return STATE_CONNECTING; - } + return STATE_CONNECTING; case Socket::CS_CONNECTED: return STATE_CONNECTED; default: @@ -149,7 +134,6 @@ int AsyncTCPSocketBase::SendTo(const void* pv, } int AsyncTCPSocketBase::FlushOutBuffer() { - RTC_DCHECK(!listen_); RTC_DCHECK_GT(outbuf_.size(), 0); rtc::ArrayView view = outbuf_; int res; @@ -189,7 +173,6 @@ int AsyncTCPSocketBase::FlushOutBuffer() { void AsyncTCPSocketBase::AppendToOutBuffer(const void* pv, size_t cb) { RTC_DCHECK(outbuf_.size() + cb <= max_outsize_); - RTC_DCHECK(!listen_); outbuf_.AppendData(static_cast(pv), cb); } @@ -200,62 +183,44 @@ void AsyncTCPSocketBase::OnConnectEvent(Socket* socket) { void AsyncTCPSocketBase::OnReadEvent(Socket* socket) { RTC_DCHECK(socket_.get() == socket); - if (listen_) { - rtc::SocketAddress address; - rtc::Socket* new_socket = socket->Accept(&address); - if (!new_socket) { - // TODO(stefan): Do something better like forwarding the error - // to the user. - RTC_LOG(LS_ERROR) << "TCP accept failed with error " - << socket_->GetError(); - return; + size_t total_recv = 0; + while (true) { + size_t free_size = inbuf_.capacity() - inbuf_.size(); + if (free_size < kMinimumRecvSize && inbuf_.capacity() < max_insize_) { + inbuf_.EnsureCapacity(std::min(max_insize_, inbuf_.capacity() * 2)); + free_size = inbuf_.capacity() - inbuf_.size(); } - HandleIncomingConnection(new_socket); + int len = socket_->Recv(inbuf_.data() + inbuf_.size(), free_size, nullptr); + if (len < 0) { + // TODO(stefan): Do something better like forwarding the error to the + // user. + if (!socket_->IsBlocking()) { + RTC_LOG(LS_ERROR) << "Recv() returned error: " << socket_->GetError(); + } + break; + } - // Prime a read event in case data is waiting. - new_socket->SignalReadEvent(new_socket); + total_recv += len; + inbuf_.SetSize(inbuf_.size() + len); + if (!len || static_cast(len) < free_size) { + break; + } + } + + if (!total_recv) { + return; + } + + size_t size = inbuf_.size(); + ProcessInput(inbuf_.data(), &size); + + if (size > inbuf_.size()) { + RTC_LOG(LS_ERROR) << "input buffer overflow"; + RTC_NOTREACHED(); + inbuf_.Clear(); } else { - size_t total_recv = 0; - while (true) { - size_t free_size = inbuf_.capacity() - inbuf_.size(); - if (free_size < kMinimumRecvSize && inbuf_.capacity() < max_insize_) { - inbuf_.EnsureCapacity(std::min(max_insize_, inbuf_.capacity() * 2)); - free_size = inbuf_.capacity() - inbuf_.size(); - } - - int len = - socket_->Recv(inbuf_.data() + inbuf_.size(), free_size, nullptr); - if (len < 0) { - // TODO(stefan): Do something better like forwarding the error to the - // user. - if (!socket_->IsBlocking()) { - RTC_LOG(LS_ERROR) << "Recv() returned error: " << socket_->GetError(); - } - break; - } - - total_recv += len; - inbuf_.SetSize(inbuf_.size() + len); - if (!len || static_cast(len) < free_size) { - break; - } - } - - if (!total_recv) { - return; - } - - size_t size = inbuf_.size(); - ProcessInput(inbuf_.data(), &size); - - if (size > inbuf_.size()) { - RTC_LOG(LS_ERROR) << "input buffer overflow"; - RTC_NOTREACHED(); - inbuf_.Clear(); - } else { - inbuf_.SetSize(size); - } + inbuf_.SetSize(size); } } @@ -283,12 +248,11 @@ AsyncTCPSocket* AsyncTCPSocket::Create(Socket* socket, const SocketAddress& bind_address, const SocketAddress& remote_address) { return new AsyncTCPSocket( - AsyncTCPSocketBase::ConnectSocket(socket, bind_address, remote_address), - false); + AsyncTCPSocketBase::ConnectSocket(socket, bind_address, remote_address)); } -AsyncTCPSocket::AsyncTCPSocket(Socket* socket, bool listen) - : AsyncTCPSocketBase(socket, listen, kBufSize) {} +AsyncTCPSocket::AsyncTCPSocket(Socket* socket) + : AsyncTCPSocketBase(socket, kBufSize) {} int AsyncTCPSocket::Send(const void* pv, size_t cb, @@ -343,8 +307,59 @@ void AsyncTCPSocket::ProcessInput(char* data, size_t* len) { } } -void AsyncTCPSocket::HandleIncomingConnection(Socket* socket) { - SignalNewConnection(this, new AsyncTCPSocket(socket, false)); +AsyncTcpListenSocket::AsyncTcpListenSocket(std::unique_ptr socket) + : socket_(std::move(socket)) { + RTC_DCHECK(socket_.get() != nullptr); + socket_->SignalReadEvent.connect(this, &AsyncTcpListenSocket::OnReadEvent); + if (socket_->Listen(kListenBacklog) < 0) { + RTC_LOG(LS_ERROR) << "Listen() failed with error " << socket_->GetError(); + } +} + +AsyncTcpListenSocket::State AsyncTcpListenSocket::GetState() const { + switch (socket_->GetState()) { + case Socket::CS_CLOSED: + return State::kClosed; + case Socket::CS_CONNECTING: + return State::kBound; + default: + RTC_NOTREACHED(); + return State::kClosed; + } +} + +SocketAddress AsyncTcpListenSocket::GetLocalAddress() const { + return socket_->GetLocalAddress(); +} + +int AsyncTcpListenSocket::GetOption(Socket::Option opt, int* value) { + return socket_->GetOption(opt, value); +} + +int AsyncTcpListenSocket::SetOption(Socket::Option opt, int value) { + return socket_->SetOption(opt, value); +} + +void AsyncTcpListenSocket::OnReadEvent(Socket* socket) { + RTC_DCHECK(socket_.get() == socket); + + rtc::SocketAddress address; + rtc::Socket* new_socket = socket->Accept(&address); + if (!new_socket) { + // TODO(stefan): Do something better like forwarding the error + // to the user. + RTC_LOG(LS_ERROR) << "TCP accept failed with error " << socket_->GetError(); + return; + } + + HandleIncomingConnection(new_socket); + + // Prime a read event in case data is waiting. + new_socket->SignalReadEvent(new_socket); +} + +void AsyncTcpListenSocket::HandleIncomingConnection(Socket* socket) { + SignalNewConnection(this, new AsyncTCPSocket(socket)); } } // namespace rtc diff --git a/rtc_base/async_tcp_socket.h b/rtc_base/async_tcp_socket.h index ddf9a436f6..901e5cfe33 100644 --- a/rtc_base/async_tcp_socket.h +++ b/rtc_base/async_tcp_socket.h @@ -28,7 +28,7 @@ namespace rtc { // buffer them in user space. class AsyncTCPSocketBase : public AsyncPacketSocket { public: - AsyncTCPSocketBase(Socket* socket, bool listen, size_t max_packet_size); + AsyncTCPSocketBase(Socket* socket, size_t max_packet_size); ~AsyncTCPSocketBase() override; // Pure virtual methods to send and recv data. @@ -36,8 +36,6 @@ class AsyncTCPSocketBase : public AsyncPacketSocket { size_t cb, const rtc::PacketOptions& options) override = 0; virtual void ProcessInput(char* data, size_t* len) = 0; - // Signals incoming connection. - virtual void HandleIncomingConnection(Socket* socket) = 0; SocketAddress GetLocalAddress() const override; SocketAddress GetRemoteAddress() const override; @@ -76,7 +74,6 @@ class AsyncTCPSocketBase : public AsyncPacketSocket { void OnCloseEvent(Socket* socket, int error); std::unique_ptr socket_; - bool listen_; Buffer inbuf_; Buffer outbuf_; size_t max_insize_; @@ -93,19 +90,37 @@ class AsyncTCPSocket : public AsyncTCPSocketBase { static AsyncTCPSocket* Create(Socket* socket, const SocketAddress& bind_address, const SocketAddress& remote_address); - AsyncTCPSocket(Socket* socket, bool listen); + explicit AsyncTCPSocket(Socket* socket); ~AsyncTCPSocket() override {} int Send(const void* pv, size_t cb, const rtc::PacketOptions& options) override; void ProcessInput(char* data, size_t* len) override; - void HandleIncomingConnection(Socket* socket) override; private: RTC_DISALLOW_COPY_AND_ASSIGN(AsyncTCPSocket); }; +class AsyncTcpListenSocket : public AsyncListenSocket { + public: + explicit AsyncTcpListenSocket(std::unique_ptr socket); + + State GetState() const override; + SocketAddress GetLocalAddress() const override; + + int GetOption(Socket::Option opt, int* value) override; + int SetOption(Socket::Option opt, int value) override; + + virtual void HandleIncomingConnection(rtc::Socket* socket); + + private: + // Called by the underlying socket + void OnReadEvent(Socket* socket); + + std::unique_ptr socket_; +}; + } // namespace rtc #endif // RTC_BASE_ASYNC_TCP_SOCKET_H_ diff --git a/rtc_base/nat_unittest.cc b/rtc_base/nat_unittest.cc index 4b7a117611..2e41684b78 100644 --- a/rtc_base/nat_unittest.cc +++ b/rtc_base/nat_unittest.cc @@ -56,7 +56,7 @@ TestClient* CreateTestClient(SocketFactory* factory, } TestClient* CreateTCPTestClient(Socket* socket) { - return new TestClient(std::make_unique(socket, false)); + return new TestClient(std::make_unique(socket)); } // Tests that when sending from internal_addr to external_addrs through the diff --git a/rtc_base/test_echo_server.h b/rtc_base/test_echo_server.h index 6fdfc249e4..a061ed0ce7 100644 --- a/rtc_base/test_echo_server.h +++ b/rtc_base/test_echo_server.h @@ -41,7 +41,7 @@ class TestEchoServer : public sigslot::has_slots<> { void OnAccept(Socket* socket) { Socket* raw_socket = socket->Accept(nullptr); if (raw_socket) { - AsyncTCPSocket* packet_socket = new AsyncTCPSocket(raw_socket, false); + AsyncTCPSocket* packet_socket = new AsyncTCPSocket(raw_socket); packet_socket->SignalReadPacket.connect(this, &TestEchoServer::OnPacket); packet_socket->SignalClose.connect(this, &TestEchoServer::OnClose); client_sockets_.push_back(packet_socket);