Keep listening if "accept" returns an invalid socket.
There is an issue in PhysicalSocket::Accept where the flag to continue listening is not set in "enabled_events_" if "accept" returns an error. This CL fixes this (initial idea by silviu.cpp@gmail.com). BUG=webrtc:2030 Review URL: https://codereview.webrtc.org/1452903006 Cr-Commit-Position: refs/heads/master@{#11080}
This commit is contained in:
parent
88518a22c6
commit
095ae15d6b
File diff suppressed because it is too large
Load Diff
@ -14,6 +14,7 @@
|
||||
#include <vector>
|
||||
|
||||
#include "webrtc/base/asyncfile.h"
|
||||
#include "webrtc/base/nethelpers.h"
|
||||
#include "webrtc/base/scoped_ptr.h"
|
||||
#include "webrtc/base/socketserver.h"
|
||||
#include "webrtc/base/criticalsection.h"
|
||||
@ -115,6 +116,107 @@ class PhysicalSocketServer : public SocketServer {
|
||||
#endif
|
||||
};
|
||||
|
||||
class PhysicalSocket : public AsyncSocket, public sigslot::has_slots<> {
|
||||
public:
|
||||
PhysicalSocket(PhysicalSocketServer* ss, SOCKET s = INVALID_SOCKET);
|
||||
~PhysicalSocket() override;
|
||||
|
||||
// Creates the underlying OS socket (same as the "socket" function).
|
||||
virtual bool Create(int family, int type);
|
||||
|
||||
SocketAddress GetLocalAddress() const override;
|
||||
SocketAddress GetRemoteAddress() const override;
|
||||
|
||||
int Bind(const SocketAddress& bind_addr) override;
|
||||
int Connect(const SocketAddress& addr) override;
|
||||
|
||||
int GetError() const override;
|
||||
void SetError(int error) override;
|
||||
|
||||
ConnState GetState() const override;
|
||||
|
||||
int GetOption(Option opt, int* value) override;
|
||||
int SetOption(Option opt, int value) override;
|
||||
|
||||
int Send(const void* pv, size_t cb) override;
|
||||
int SendTo(const void* buffer,
|
||||
size_t length,
|
||||
const SocketAddress& addr) override;
|
||||
|
||||
int Recv(void* buffer, size_t length) override;
|
||||
int RecvFrom(void* buffer, size_t length, SocketAddress* out_addr) override;
|
||||
|
||||
int Listen(int backlog) override;
|
||||
AsyncSocket* Accept(SocketAddress* out_addr) override;
|
||||
|
||||
int Close() override;
|
||||
|
||||
int EstimateMTU(uint16_t* mtu) override;
|
||||
|
||||
SocketServer* socketserver() { return ss_; }
|
||||
|
||||
protected:
|
||||
int DoConnect(const SocketAddress& connect_addr);
|
||||
|
||||
// Make virtual so ::accept can be overwritten in tests.
|
||||
virtual SOCKET DoAccept(SOCKET socket, sockaddr* addr, socklen_t* addrlen);
|
||||
|
||||
void OnResolveResult(AsyncResolverInterface* resolver);
|
||||
|
||||
void UpdateLastError();
|
||||
void MaybeRemapSendError();
|
||||
|
||||
static int TranslateOption(Option opt, int* slevel, int* sopt);
|
||||
|
||||
PhysicalSocketServer* ss_;
|
||||
SOCKET s_;
|
||||
uint8_t enabled_events_;
|
||||
bool udp_;
|
||||
mutable CriticalSection crit_;
|
||||
int error_ GUARDED_BY(crit_);
|
||||
ConnState state_;
|
||||
AsyncResolver* resolver_;
|
||||
|
||||
#if !defined(NDEBUG)
|
||||
std::string dbg_addr_;
|
||||
#endif
|
||||
};
|
||||
|
||||
class SocketDispatcher : public Dispatcher, public PhysicalSocket {
|
||||
public:
|
||||
explicit SocketDispatcher(PhysicalSocketServer *ss);
|
||||
SocketDispatcher(SOCKET s, PhysicalSocketServer *ss);
|
||||
~SocketDispatcher() override;
|
||||
|
||||
bool Initialize();
|
||||
|
||||
virtual bool Create(int type);
|
||||
bool Create(int family, int type) override;
|
||||
|
||||
#if defined(WEBRTC_WIN)
|
||||
WSAEVENT GetWSAEvent() override;
|
||||
SOCKET GetSocket() override;
|
||||
bool CheckSignalClose() override;
|
||||
#elif defined(WEBRTC_POSIX)
|
||||
int GetDescriptor() override;
|
||||
bool IsDescriptorClosed() override;
|
||||
#endif
|
||||
|
||||
uint32_t GetRequestedEvents() override;
|
||||
void OnPreEvent(uint32_t ff) override;
|
||||
void OnEvent(uint32_t ff, int err) override;
|
||||
|
||||
int Close() override;
|
||||
|
||||
#if defined(WEBRTC_WIN)
|
||||
private:
|
||||
static int next_id_;
|
||||
int id_;
|
||||
bool signal_close_;
|
||||
int signal_err_;
|
||||
#endif // WEBRTC_WIN
|
||||
};
|
||||
|
||||
} // namespace rtc
|
||||
|
||||
#endif // WEBRTC_BASE_PHYSICALSOCKETSERVER_H__
|
||||
|
||||
@ -22,9 +22,82 @@
|
||||
|
||||
namespace rtc {
|
||||
|
||||
class PhysicalSocketTest : public SocketTest {
|
||||
class PhysicalSocketTest;
|
||||
|
||||
class FakeSocketDispatcher : public SocketDispatcher {
|
||||
public:
|
||||
explicit FakeSocketDispatcher(PhysicalSocketServer* ss)
|
||||
: SocketDispatcher(ss) {
|
||||
}
|
||||
|
||||
protected:
|
||||
SOCKET DoAccept(SOCKET socket, sockaddr* addr, socklen_t* addrlen) override;
|
||||
};
|
||||
|
||||
class FakePhysicalSocketServer : public PhysicalSocketServer {
|
||||
public:
|
||||
explicit FakePhysicalSocketServer(PhysicalSocketTest* test)
|
||||
: test_(test) {
|
||||
}
|
||||
|
||||
AsyncSocket* CreateAsyncSocket(int type) override {
|
||||
SocketDispatcher* dispatcher = new FakeSocketDispatcher(this);
|
||||
if (dispatcher->Create(type)) {
|
||||
return dispatcher;
|
||||
} else {
|
||||
delete dispatcher;
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
AsyncSocket* CreateAsyncSocket(int family, int type) override {
|
||||
SocketDispatcher* dispatcher = new FakeSocketDispatcher(this);
|
||||
if (dispatcher->Create(family, type)) {
|
||||
return dispatcher;
|
||||
} else {
|
||||
delete dispatcher;
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
PhysicalSocketTest* GetTest() const { return test_; }
|
||||
|
||||
private:
|
||||
PhysicalSocketTest* test_;
|
||||
};
|
||||
|
||||
class PhysicalSocketTest : public SocketTest {
|
||||
public:
|
||||
// Set flag to simluate failures when calling "::accept" on a AsyncSocket.
|
||||
void SetFailAccept(bool fail) { fail_accept_ = fail; }
|
||||
bool FailAccept() const { return fail_accept_; }
|
||||
|
||||
protected:
|
||||
PhysicalSocketTest()
|
||||
: server_(new FakePhysicalSocketServer(this)),
|
||||
scope_(server_.get()),
|
||||
fail_accept_(false) {
|
||||
}
|
||||
|
||||
void ConnectInternalAcceptError(const IPAddress& loopback);
|
||||
|
||||
rtc::scoped_ptr<FakePhysicalSocketServer> server_;
|
||||
SocketServerScope scope_;
|
||||
bool fail_accept_;
|
||||
};
|
||||
|
||||
SOCKET FakeSocketDispatcher::DoAccept(SOCKET socket,
|
||||
sockaddr* addr,
|
||||
socklen_t* addrlen) {
|
||||
FakePhysicalSocketServer* ss =
|
||||
static_cast<FakePhysicalSocketServer*>(socketserver());
|
||||
if (ss->GetTest()->FailAccept()) {
|
||||
return INVALID_SOCKET;
|
||||
}
|
||||
|
||||
return SocketDispatcher::DoAccept(socket, addr, addrlen);
|
||||
}
|
||||
|
||||
TEST_F(PhysicalSocketTest, TestConnectIPv4) {
|
||||
SocketTest::TestConnectIPv4();
|
||||
}
|
||||
@ -51,6 +124,92 @@ TEST_F(PhysicalSocketTest, TestConnectFailIPv4) {
|
||||
SocketTest::TestConnectFailIPv4();
|
||||
}
|
||||
|
||||
void PhysicalSocketTest::ConnectInternalAcceptError(const IPAddress& loopback) {
|
||||
testing::StreamSink sink;
|
||||
SocketAddress accept_addr;
|
||||
|
||||
// Create two clients.
|
||||
scoped_ptr<AsyncSocket> client1(server_->CreateAsyncSocket(loopback.family(),
|
||||
SOCK_STREAM));
|
||||
sink.Monitor(client1.get());
|
||||
EXPECT_EQ(AsyncSocket::CS_CLOSED, client1->GetState());
|
||||
EXPECT_PRED1(IsUnspecOrEmptyIP, client1->GetLocalAddress().ipaddr());
|
||||
|
||||
scoped_ptr<AsyncSocket> client2(server_->CreateAsyncSocket(loopback.family(),
|
||||
SOCK_STREAM));
|
||||
sink.Monitor(client2.get());
|
||||
EXPECT_EQ(AsyncSocket::CS_CLOSED, client2->GetState());
|
||||
EXPECT_PRED1(IsUnspecOrEmptyIP, client2->GetLocalAddress().ipaddr());
|
||||
|
||||
// Create server and listen.
|
||||
scoped_ptr<AsyncSocket> server(
|
||||
server_->CreateAsyncSocket(loopback.family(), SOCK_STREAM));
|
||||
sink.Monitor(server.get());
|
||||
EXPECT_EQ(0, server->Bind(SocketAddress(loopback, 0)));
|
||||
EXPECT_EQ(0, server->Listen(5));
|
||||
EXPECT_EQ(AsyncSocket::CS_CONNECTING, server->GetState());
|
||||
|
||||
// Ensure no pending server connections, since we haven't done anything yet.
|
||||
EXPECT_FALSE(sink.Check(server.get(), testing::SSE_READ));
|
||||
EXPECT_TRUE(nullptr == server->Accept(&accept_addr));
|
||||
EXPECT_TRUE(accept_addr.IsNil());
|
||||
|
||||
// Attempt first connect to listening socket.
|
||||
EXPECT_EQ(0, client1->Connect(server->GetLocalAddress()));
|
||||
EXPECT_FALSE(client1->GetLocalAddress().IsNil());
|
||||
EXPECT_NE(server->GetLocalAddress(), client1->GetLocalAddress());
|
||||
|
||||
// Client is connecting, outcome not yet determined.
|
||||
EXPECT_EQ(AsyncSocket::CS_CONNECTING, client1->GetState());
|
||||
EXPECT_FALSE(sink.Check(client1.get(), testing::SSE_OPEN));
|
||||
EXPECT_FALSE(sink.Check(client1.get(), testing::SSE_CLOSE));
|
||||
|
||||
// Server has pending connection, try to accept it (will fail).
|
||||
EXPECT_TRUE_WAIT((sink.Check(server.get(), testing::SSE_READ)), kTimeout);
|
||||
// Simulate "::accept" returning an error.
|
||||
SetFailAccept(true);
|
||||
scoped_ptr<AsyncSocket> accepted(server->Accept(&accept_addr));
|
||||
EXPECT_FALSE(accepted);
|
||||
ASSERT_TRUE(accept_addr.IsNil());
|
||||
|
||||
// Ensure no more pending server connections.
|
||||
EXPECT_FALSE(sink.Check(server.get(), testing::SSE_READ));
|
||||
EXPECT_TRUE(nullptr == server->Accept(&accept_addr));
|
||||
EXPECT_TRUE(accept_addr.IsNil());
|
||||
|
||||
// Attempt second connect to listening socket.
|
||||
EXPECT_EQ(0, client2->Connect(server->GetLocalAddress()));
|
||||
EXPECT_FALSE(client2->GetLocalAddress().IsNil());
|
||||
EXPECT_NE(server->GetLocalAddress(), client2->GetLocalAddress());
|
||||
|
||||
// Client is connecting, outcome not yet determined.
|
||||
EXPECT_EQ(AsyncSocket::CS_CONNECTING, client2->GetState());
|
||||
EXPECT_FALSE(sink.Check(client2.get(), testing::SSE_OPEN));
|
||||
EXPECT_FALSE(sink.Check(client2.get(), testing::SSE_CLOSE));
|
||||
|
||||
// Server has pending connection, try to accept it (will succeed).
|
||||
EXPECT_TRUE_WAIT((sink.Check(server.get(), testing::SSE_READ)), kTimeout);
|
||||
SetFailAccept(false);
|
||||
scoped_ptr<AsyncSocket> accepted2(server->Accept(&accept_addr));
|
||||
ASSERT_TRUE(accepted2);
|
||||
EXPECT_FALSE(accept_addr.IsNil());
|
||||
EXPECT_EQ(accepted2->GetRemoteAddress(), accept_addr);
|
||||
}
|
||||
|
||||
TEST_F(PhysicalSocketTest, TestConnectAcceptErrorIPv4) {
|
||||
ConnectInternalAcceptError(kIPv4Loopback);
|
||||
}
|
||||
|
||||
// Crashes on Linux. See webrtc:4923.
|
||||
#if defined(WEBRTC_LINUX)
|
||||
#define MAYBE_TestConnectAcceptErrorIPv6 DISABLED_TestConnectAcceptErrorIPv6
|
||||
#else
|
||||
#define MAYBE_TestConnectAcceptErrorIPv6 TestConnectAcceptErrorIPv6
|
||||
#endif
|
||||
TEST_F(PhysicalSocketTest, MAYBE_TestConnectAcceptErrorIPv6) {
|
||||
ConnectInternalAcceptError(kIPv6Loopback);
|
||||
}
|
||||
|
||||
// Crashes on Linux. See webrtc:4923.
|
||||
#if defined(WEBRTC_LINUX)
|
||||
#define MAYBE_TestConnectFailIPv6 DISABLED_TestConnectFailIPv6
|
||||
|
||||
@ -21,8 +21,9 @@ namespace rtc {
|
||||
// socketserver, and call the SocketTest test methods.
|
||||
class SocketTest : public testing::Test {
|
||||
protected:
|
||||
SocketTest() : ss_(NULL), kIPv4Loopback(INADDR_LOOPBACK),
|
||||
kIPv6Loopback(in6addr_loopback) {}
|
||||
SocketTest() : kIPv4Loopback(INADDR_LOOPBACK),
|
||||
kIPv6Loopback(in6addr_loopback),
|
||||
ss_(nullptr) {}
|
||||
virtual void SetUp() { ss_ = Thread::Current()->socketserver(); }
|
||||
void TestConnectIPv4();
|
||||
void TestConnectIPv6();
|
||||
@ -57,6 +58,10 @@ class SocketTest : public testing::Test {
|
||||
void TestGetSetOptionsIPv4();
|
||||
void TestGetSetOptionsIPv6();
|
||||
|
||||
static const int kTimeout = 5000; // ms
|
||||
const IPAddress kIPv4Loopback;
|
||||
const IPAddress kIPv6Loopback;
|
||||
|
||||
private:
|
||||
void ConnectInternal(const IPAddress& loopback);
|
||||
void ConnectWithDnsLookupInternal(const IPAddress& loopback,
|
||||
@ -77,12 +82,13 @@ class SocketTest : public testing::Test {
|
||||
void UdpReadyToSend(const IPAddress& loopback);
|
||||
void GetSetOptionsInternal(const IPAddress& loopback);
|
||||
|
||||
static const int kTimeout = 5000; // ms
|
||||
SocketServer* ss_;
|
||||
const IPAddress kIPv4Loopback;
|
||||
const IPAddress kIPv6Loopback;
|
||||
};
|
||||
|
||||
// For unbound sockets, GetLocalAddress / GetRemoteAddress return AF_UNSPEC
|
||||
// values on Windows, but an empty address of the same family on Linux/MacOS X.
|
||||
bool IsUnspecOrEmptyIP(const IPAddress& address);
|
||||
|
||||
} // namespace rtc
|
||||
|
||||
#endif // WEBRTC_BASE_SOCKET_UNITTEST_H_
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user