Keep listening if "accept" returns an invalid socket.

There is an issue in PhysicalSocket::Accept where the flag to continue
listening is not set in "enabled_events_" if "accept" returns an error.
This CL fixes this (initial idea by silviu.cpp@gmail.com).

BUG=webrtc:2030

Review URL: https://codereview.webrtc.org/1452903006

Cr-Commit-Position: refs/heads/master@{#11080}
This commit is contained in:
jbauch 2015-12-18 01:39:55 -08:00 committed by Commit bot
parent 88518a22c6
commit 095ae15d6b
4 changed files with 855 additions and 599 deletions

File diff suppressed because it is too large Load Diff

View File

@ -14,6 +14,7 @@
#include <vector>
#include "webrtc/base/asyncfile.h"
#include "webrtc/base/nethelpers.h"
#include "webrtc/base/scoped_ptr.h"
#include "webrtc/base/socketserver.h"
#include "webrtc/base/criticalsection.h"
@ -115,6 +116,107 @@ class PhysicalSocketServer : public SocketServer {
#endif
};
class PhysicalSocket : public AsyncSocket, public sigslot::has_slots<> {
public:
PhysicalSocket(PhysicalSocketServer* ss, SOCKET s = INVALID_SOCKET);
~PhysicalSocket() override;
// Creates the underlying OS socket (same as the "socket" function).
virtual bool Create(int family, int type);
SocketAddress GetLocalAddress() const override;
SocketAddress GetRemoteAddress() const override;
int Bind(const SocketAddress& bind_addr) override;
int Connect(const SocketAddress& addr) override;
int GetError() const override;
void SetError(int error) override;
ConnState GetState() const override;
int GetOption(Option opt, int* value) override;
int SetOption(Option opt, int value) override;
int Send(const void* pv, size_t cb) override;
int SendTo(const void* buffer,
size_t length,
const SocketAddress& addr) override;
int Recv(void* buffer, size_t length) override;
int RecvFrom(void* buffer, size_t length, SocketAddress* out_addr) override;
int Listen(int backlog) override;
AsyncSocket* Accept(SocketAddress* out_addr) override;
int Close() override;
int EstimateMTU(uint16_t* mtu) override;
SocketServer* socketserver() { return ss_; }
protected:
int DoConnect(const SocketAddress& connect_addr);
// Make virtual so ::accept can be overwritten in tests.
virtual SOCKET DoAccept(SOCKET socket, sockaddr* addr, socklen_t* addrlen);
void OnResolveResult(AsyncResolverInterface* resolver);
void UpdateLastError();
void MaybeRemapSendError();
static int TranslateOption(Option opt, int* slevel, int* sopt);
PhysicalSocketServer* ss_;
SOCKET s_;
uint8_t enabled_events_;
bool udp_;
mutable CriticalSection crit_;
int error_ GUARDED_BY(crit_);
ConnState state_;
AsyncResolver* resolver_;
#if !defined(NDEBUG)
std::string dbg_addr_;
#endif
};
class SocketDispatcher : public Dispatcher, public PhysicalSocket {
public:
explicit SocketDispatcher(PhysicalSocketServer *ss);
SocketDispatcher(SOCKET s, PhysicalSocketServer *ss);
~SocketDispatcher() override;
bool Initialize();
virtual bool Create(int type);
bool Create(int family, int type) override;
#if defined(WEBRTC_WIN)
WSAEVENT GetWSAEvent() override;
SOCKET GetSocket() override;
bool CheckSignalClose() override;
#elif defined(WEBRTC_POSIX)
int GetDescriptor() override;
bool IsDescriptorClosed() override;
#endif
uint32_t GetRequestedEvents() override;
void OnPreEvent(uint32_t ff) override;
void OnEvent(uint32_t ff, int err) override;
int Close() override;
#if defined(WEBRTC_WIN)
private:
static int next_id_;
int id_;
bool signal_close_;
int signal_err_;
#endif // WEBRTC_WIN
};
} // namespace rtc
#endif // WEBRTC_BASE_PHYSICALSOCKETSERVER_H__

View File

@ -22,9 +22,82 @@
namespace rtc {
class PhysicalSocketTest : public SocketTest {
class PhysicalSocketTest;
class FakeSocketDispatcher : public SocketDispatcher {
public:
explicit FakeSocketDispatcher(PhysicalSocketServer* ss)
: SocketDispatcher(ss) {
}
protected:
SOCKET DoAccept(SOCKET socket, sockaddr* addr, socklen_t* addrlen) override;
};
class FakePhysicalSocketServer : public PhysicalSocketServer {
public:
explicit FakePhysicalSocketServer(PhysicalSocketTest* test)
: test_(test) {
}
AsyncSocket* CreateAsyncSocket(int type) override {
SocketDispatcher* dispatcher = new FakeSocketDispatcher(this);
if (dispatcher->Create(type)) {
return dispatcher;
} else {
delete dispatcher;
return nullptr;
}
}
AsyncSocket* CreateAsyncSocket(int family, int type) override {
SocketDispatcher* dispatcher = new FakeSocketDispatcher(this);
if (dispatcher->Create(family, type)) {
return dispatcher;
} else {
delete dispatcher;
return nullptr;
}
}
PhysicalSocketTest* GetTest() const { return test_; }
private:
PhysicalSocketTest* test_;
};
class PhysicalSocketTest : public SocketTest {
public:
// Set flag to simluate failures when calling "::accept" on a AsyncSocket.
void SetFailAccept(bool fail) { fail_accept_ = fail; }
bool FailAccept() const { return fail_accept_; }
protected:
PhysicalSocketTest()
: server_(new FakePhysicalSocketServer(this)),
scope_(server_.get()),
fail_accept_(false) {
}
void ConnectInternalAcceptError(const IPAddress& loopback);
rtc::scoped_ptr<FakePhysicalSocketServer> server_;
SocketServerScope scope_;
bool fail_accept_;
};
SOCKET FakeSocketDispatcher::DoAccept(SOCKET socket,
sockaddr* addr,
socklen_t* addrlen) {
FakePhysicalSocketServer* ss =
static_cast<FakePhysicalSocketServer*>(socketserver());
if (ss->GetTest()->FailAccept()) {
return INVALID_SOCKET;
}
return SocketDispatcher::DoAccept(socket, addr, addrlen);
}
TEST_F(PhysicalSocketTest, TestConnectIPv4) {
SocketTest::TestConnectIPv4();
}
@ -51,6 +124,92 @@ TEST_F(PhysicalSocketTest, TestConnectFailIPv4) {
SocketTest::TestConnectFailIPv4();
}
void PhysicalSocketTest::ConnectInternalAcceptError(const IPAddress& loopback) {
testing::StreamSink sink;
SocketAddress accept_addr;
// Create two clients.
scoped_ptr<AsyncSocket> client1(server_->CreateAsyncSocket(loopback.family(),
SOCK_STREAM));
sink.Monitor(client1.get());
EXPECT_EQ(AsyncSocket::CS_CLOSED, client1->GetState());
EXPECT_PRED1(IsUnspecOrEmptyIP, client1->GetLocalAddress().ipaddr());
scoped_ptr<AsyncSocket> client2(server_->CreateAsyncSocket(loopback.family(),
SOCK_STREAM));
sink.Monitor(client2.get());
EXPECT_EQ(AsyncSocket::CS_CLOSED, client2->GetState());
EXPECT_PRED1(IsUnspecOrEmptyIP, client2->GetLocalAddress().ipaddr());
// Create server and listen.
scoped_ptr<AsyncSocket> server(
server_->CreateAsyncSocket(loopback.family(), SOCK_STREAM));
sink.Monitor(server.get());
EXPECT_EQ(0, server->Bind(SocketAddress(loopback, 0)));
EXPECT_EQ(0, server->Listen(5));
EXPECT_EQ(AsyncSocket::CS_CONNECTING, server->GetState());
// Ensure no pending server connections, since we haven't done anything yet.
EXPECT_FALSE(sink.Check(server.get(), testing::SSE_READ));
EXPECT_TRUE(nullptr == server->Accept(&accept_addr));
EXPECT_TRUE(accept_addr.IsNil());
// Attempt first connect to listening socket.
EXPECT_EQ(0, client1->Connect(server->GetLocalAddress()));
EXPECT_FALSE(client1->GetLocalAddress().IsNil());
EXPECT_NE(server->GetLocalAddress(), client1->GetLocalAddress());
// Client is connecting, outcome not yet determined.
EXPECT_EQ(AsyncSocket::CS_CONNECTING, client1->GetState());
EXPECT_FALSE(sink.Check(client1.get(), testing::SSE_OPEN));
EXPECT_FALSE(sink.Check(client1.get(), testing::SSE_CLOSE));
// Server has pending connection, try to accept it (will fail).
EXPECT_TRUE_WAIT((sink.Check(server.get(), testing::SSE_READ)), kTimeout);
// Simulate "::accept" returning an error.
SetFailAccept(true);
scoped_ptr<AsyncSocket> accepted(server->Accept(&accept_addr));
EXPECT_FALSE(accepted);
ASSERT_TRUE(accept_addr.IsNil());
// Ensure no more pending server connections.
EXPECT_FALSE(sink.Check(server.get(), testing::SSE_READ));
EXPECT_TRUE(nullptr == server->Accept(&accept_addr));
EXPECT_TRUE(accept_addr.IsNil());
// Attempt second connect to listening socket.
EXPECT_EQ(0, client2->Connect(server->GetLocalAddress()));
EXPECT_FALSE(client2->GetLocalAddress().IsNil());
EXPECT_NE(server->GetLocalAddress(), client2->GetLocalAddress());
// Client is connecting, outcome not yet determined.
EXPECT_EQ(AsyncSocket::CS_CONNECTING, client2->GetState());
EXPECT_FALSE(sink.Check(client2.get(), testing::SSE_OPEN));
EXPECT_FALSE(sink.Check(client2.get(), testing::SSE_CLOSE));
// Server has pending connection, try to accept it (will succeed).
EXPECT_TRUE_WAIT((sink.Check(server.get(), testing::SSE_READ)), kTimeout);
SetFailAccept(false);
scoped_ptr<AsyncSocket> accepted2(server->Accept(&accept_addr));
ASSERT_TRUE(accepted2);
EXPECT_FALSE(accept_addr.IsNil());
EXPECT_EQ(accepted2->GetRemoteAddress(), accept_addr);
}
TEST_F(PhysicalSocketTest, TestConnectAcceptErrorIPv4) {
ConnectInternalAcceptError(kIPv4Loopback);
}
// Crashes on Linux. See webrtc:4923.
#if defined(WEBRTC_LINUX)
#define MAYBE_TestConnectAcceptErrorIPv6 DISABLED_TestConnectAcceptErrorIPv6
#else
#define MAYBE_TestConnectAcceptErrorIPv6 TestConnectAcceptErrorIPv6
#endif
TEST_F(PhysicalSocketTest, MAYBE_TestConnectAcceptErrorIPv6) {
ConnectInternalAcceptError(kIPv6Loopback);
}
// Crashes on Linux. See webrtc:4923.
#if defined(WEBRTC_LINUX)
#define MAYBE_TestConnectFailIPv6 DISABLED_TestConnectFailIPv6

View File

@ -21,8 +21,9 @@ namespace rtc {
// socketserver, and call the SocketTest test methods.
class SocketTest : public testing::Test {
protected:
SocketTest() : ss_(NULL), kIPv4Loopback(INADDR_LOOPBACK),
kIPv6Loopback(in6addr_loopback) {}
SocketTest() : kIPv4Loopback(INADDR_LOOPBACK),
kIPv6Loopback(in6addr_loopback),
ss_(nullptr) {}
virtual void SetUp() { ss_ = Thread::Current()->socketserver(); }
void TestConnectIPv4();
void TestConnectIPv6();
@ -57,6 +58,10 @@ class SocketTest : public testing::Test {
void TestGetSetOptionsIPv4();
void TestGetSetOptionsIPv6();
static const int kTimeout = 5000; // ms
const IPAddress kIPv4Loopback;
const IPAddress kIPv6Loopback;
private:
void ConnectInternal(const IPAddress& loopback);
void ConnectWithDnsLookupInternal(const IPAddress& loopback,
@ -77,12 +82,13 @@ class SocketTest : public testing::Test {
void UdpReadyToSend(const IPAddress& loopback);
void GetSetOptionsInternal(const IPAddress& loopback);
static const int kTimeout = 5000; // ms
SocketServer* ss_;
const IPAddress kIPv4Loopback;
const IPAddress kIPv6Loopback;
};
// For unbound sockets, GetLocalAddress / GetRemoteAddress return AF_UNSPEC
// values on Windows, but an empty address of the same family on Linux/MacOS X.
bool IsUnspecOrEmptyIP(const IPAddress& address);
} // namespace rtc
#endif // WEBRTC_BASE_SOCKET_UNITTEST_H_