diff --git a/p2p/base/connection.cc b/p2p/base/connection.cc index 0187163d08..38a186da98 100644 --- a/p2p/base/connection.cc +++ b/p2p/base/connection.cc @@ -163,8 +163,10 @@ constexpr int kSupportGoogPingVersionResponseIndex = namespace cricket { // A ConnectionRequest is a STUN binding used to determine writability. -ConnectionRequest::ConnectionRequest(Connection* connection) - : StunRequest(new IceMessage()), connection_(connection) {} +ConnectionRequest::ConnectionRequest(StunRequestManager& manager, + Connection* connection) + : StunRequest(manager, std::make_unique()), + connection_(connection) {} void ConnectionRequest::Prepare(StunMessage* request) { RTC_DCHECK_RUN_ON(connection_->network_thread_); @@ -276,7 +278,7 @@ void ConnectionRequest::OnSent() { connection_->OnConnectionRequestSent(this); // Each request is sent only once. After a single delay , the request will // time out. - timeout_ = true; + set_timed_out(); } int ConnectionRequest::resend_delay() { @@ -986,7 +988,7 @@ int64_t Connection::last_ping_sent() const { void Connection::Ping(int64_t now) { RTC_DCHECK_RUN_ON(network_thread_); last_ping_sent_ = now; - ConnectionRequest* req = new ConnectionRequest(this); + ConnectionRequest* req = new ConnectionRequest(requests_, this); // If not using renomination, we use "1" to mean "nominated" and "0" to mean // "not nominated". If using renomination, values greater than 1 are used for // re-nominated pairs. diff --git a/p2p/base/connection.h b/p2p/base/connection.h index e07482ac0d..d871bc4ceb 100644 --- a/p2p/base/connection.h +++ b/p2p/base/connection.h @@ -57,7 +57,7 @@ struct CandidatePair final : public CandidatePairInterface { // A ConnectionRequest is a simple STUN ping used to determine writability. class ConnectionRequest : public StunRequest { public: - explicit ConnectionRequest(Connection* connection); + ConnectionRequest(StunRequestManager& manager, Connection* connection); void Prepare(StunMessage* request) override; void OnResponse(StunMessage* response) override; void OnErrorResponse(StunMessage* response) override; diff --git a/p2p/base/stun_port.cc b/p2p/base/stun_port.cc index d27ca2f025..5d7c426d45 100644 --- a/p2p/base/stun_port.cc +++ b/p2p/base/stun_port.cc @@ -40,7 +40,10 @@ class StunBindingRequest : public StunRequest { StunBindingRequest(UDPPort* port, const rtc::SocketAddress& addr, int64_t start_time) - : port_(port), server_addr_(addr), start_time_(start_time) {} + : StunRequest(port->request_manager()), + port_(port), + server_addr_(addr), + start_time_(start_time) {} const rtc::SocketAddress& server_addr() const { return server_addr_; } @@ -63,7 +66,7 @@ class StunBindingRequest : public StunRequest { // The keep-alive requests will be stopped after its lifetime has passed. if (WithinLifetime(rtc::TimeMillis())) { - port_->requests_.SendDelayed( + port_->request_manager_.SendDelayed( new StunBindingRequest(port_, server_addr_, start_time_), port_->stun_keepalive_delay()); } @@ -88,7 +91,7 @@ class StunBindingRequest : public StunRequest { int64_t now = rtc::TimeMillis(); if (WithinLifetime(now) && rtc::TimeDiff(now, start_time_) < RETRY_TIMEOUT) { - port_->requests_.SendDelayed( + port_->request_manager_.SendDelayed( new StunBindingRequest(port_, server_addr_, start_time_), port_->stun_keepalive_delay()); } @@ -166,7 +169,7 @@ UDPPort::UDPPort(rtc::Thread* thread, username, password, field_trials), - requests_(thread), + request_manager_(thread), socket_(socket), error_(0), ready_(false), @@ -192,7 +195,7 @@ UDPPort::UDPPort(rtc::Thread* thread, username, password, field_trials), - requests_(thread), + request_manager_(thread), socket_(nullptr), error_(0), ready_(false), @@ -215,7 +218,7 @@ bool UDPPort::Init() { socket_->SignalSentPacket.connect(this, &UDPPort::OnSentPacket); socket_->SignalReadyToSend.connect(this, &UDPPort::OnReadyToSend); socket_->SignalAddressReady.connect(this, &UDPPort::OnLocalAddressReady); - requests_.SignalSendPacket.connect(this, &UDPPort::OnSendPacket); + request_manager_.SignalSendPacket.connect(this, &UDPPort::OnSendPacket); return true; } @@ -225,7 +228,7 @@ UDPPort::~UDPPort() { } void UDPPort::PrepareAddress() { - RTC_DCHECK(requests_.empty()); + RTC_DCHECK(request_manager_.empty()); if (socket_->GetState() == rtc::AsyncPacketSocket::STATE_BOUND) { OnLocalAddressReady(socket_, socket_->GetLocalAddress()); } @@ -390,7 +393,7 @@ void UDPPort::OnReadPacket(rtc::AsyncPacketSocket* socket, // will eat it because it might be a response to a retransmitted packet, and // we already cleared the request when we got the first response. if (server_addresses_.find(remote_addr) != server_addresses_.end()) { - requests_.CheckResponse(data, size); + request_manager_.CheckResponse(data, size); return; } @@ -413,7 +416,7 @@ void UDPPort::OnReadyToSend(rtc::AsyncPacketSocket* socket) { void UDPPort::SendStunBindingRequests() { // We will keep pinging the stun server to make sure our NAT pin-hole stays // open until the deadline (specified in SendStunBindingRequest). - RTC_DCHECK(requests_.empty()); + RTC_DCHECK(request_manager_.empty()); for (ServerAddresses::const_iterator it = server_addresses_.begin(); it != server_addresses_.end(); ++it) { @@ -463,7 +466,7 @@ void UDPPort::SendStunBindingRequest(const rtc::SocketAddress& stun_addr) { } else if (socket_->GetState() == rtc::AsyncPacketSocket::STATE_BOUND) { // Check if `server_addr_` is compatible with the port's ip. if (IsCompatibleAddress(stun_addr)) { - requests_.Send( + request_manager_.Send( new StunBindingRequest(this, stun_addr, rtc::TimeMillis())); } else { // Since we can't send stun messages to the server, we should mark this diff --git a/p2p/base/stun_port.h b/p2p/base/stun_port.h index de40745c3f..3968c17a26 100644 --- a/p2p/base/stun_port.h +++ b/p2p/base/stun_port.h @@ -114,10 +114,12 @@ class UDPPort : public Port { stun_keepalive_lifetime_ = lifetime; } // Returns true if there is a pending request with type `msg_type`. - bool HasPendingRequest(int msg_type) { - return requests_.HasRequest(msg_type); + bool HasPendingRequestForTest(int msg_type) { + return request_manager_.HasRequestForTest(msg_type); } + StunRequestManager& request_manager() { return request_manager_; } + protected: UDPPort(rtc::Thread* thread, rtc::PacketSocketFactory* factory, @@ -244,7 +246,7 @@ class UDPPort : public Port { ServerAddresses server_addresses_; ServerAddresses bind_request_succeeded_servers_; ServerAddresses bind_request_failed_servers_; - StunRequestManager requests_; + StunRequestManager request_manager_; rtc::AsyncPacketSocket* socket_; int error_; int send_error_count_ = 0; diff --git a/p2p/base/stun_port_unittest.cc b/p2p/base/stun_port_unittest.cc index 609de9b152..fa51ed6666 100644 --- a/p2p/base/stun_port_unittest.cc +++ b/p2p/base/stun_port_unittest.cc @@ -393,7 +393,7 @@ TEST_F(StunPortTest, TestStunBindingRequestShortLifetime) { PrepareAddress(); EXPECT_TRUE_SIMULATED_WAIT(done(), kTimeoutMs, fake_clock); EXPECT_TRUE_SIMULATED_WAIT( - !port()->HasPendingRequest(cricket::STUN_BINDING_REQUEST), 2000, + !port()->HasPendingRequestForTest(cricket::STUN_BINDING_REQUEST), 2000, fake_clock); } @@ -404,7 +404,7 @@ TEST_F(StunPortTest, TestStunBindingRequestLongLifetime) { PrepareAddress(); EXPECT_TRUE_SIMULATED_WAIT(done(), kTimeoutMs, fake_clock); EXPECT_TRUE_SIMULATED_WAIT( - port()->HasPendingRequest(cricket::STUN_BINDING_REQUEST), 1000, + port()->HasPendingRequestForTest(cricket::STUN_BINDING_REQUEST), 1000, fake_clock); } diff --git a/p2p/base/stun_request.cc b/p2p/base/stun_request.cc index ed94ccb5fc..532e5821a5 100644 --- a/p2p/base/stun_request.cc +++ b/p2p/base/stun_request.cc @@ -12,6 +12,7 @@ #include #include +#include #include #include "rtc_base/checks.h" @@ -56,7 +57,8 @@ void StunRequestManager::Send(StunRequest* request) { } void StunRequestManager::SendDelayed(StunRequest* request, int delay) { - request->set_manager(this); + RTC_DCHECK_RUN_ON(thread_); + RTC_DCHECK_EQ(this, request->manager()); RTC_DCHECK(requests_.find(request->id()) == requests_.end()); request->Construct(); requests_[request->id()] = request; @@ -67,7 +69,8 @@ void StunRequestManager::SendDelayed(StunRequest* request, int delay) { } } -void StunRequestManager::Flush(int msg_type) { +void StunRequestManager::FlushForTest(int msg_type) { + RTC_DCHECK_RUN_ON(thread_); for (const auto& kv : requests_) { StunRequest* request = kv.second; if (msg_type == kAllRequests || msg_type == request->type()) { @@ -77,7 +80,8 @@ void StunRequestManager::Flush(int msg_type) { } } -bool StunRequestManager::HasRequest(int msg_type) { +bool StunRequestManager::HasRequestForTest(int msg_type) { + RTC_DCHECK_RUN_ON(thread_); for (const auto& kv : requests_) { StunRequest* request = kv.second; if (msg_type == kAllRequests || msg_type == request->type()) { @@ -88,6 +92,7 @@ bool StunRequestManager::HasRequest(int msg_type) { } void StunRequestManager::Remove(StunRequest* request) { + RTC_DCHECK_RUN_ON(thread_); RTC_DCHECK(request->manager() == this); RequestMap::iterator iter = requests_.find(request->id()); if (iter != requests_.end()) { @@ -98,6 +103,7 @@ void StunRequestManager::Remove(StunRequest* request) { } void StunRequestManager::Clear() { + RTC_DCHECK_RUN_ON(thread_); std::vector requests; for (RequestMap::iterator i = requests_.begin(); i != requests_.end(); ++i) requests.push_back(i->second); @@ -110,6 +116,7 @@ void StunRequestManager::Clear() { } bool StunRequestManager::CheckResponse(StunMessage* msg) { + RTC_DCHECK_RUN_ON(thread_); RequestMap::iterator iter = requests_.find(msg->transaction_id()); if (iter == requests_.end()) { // TODO(pthatcher): Log unknown responses without being too spammy @@ -156,7 +163,13 @@ bool StunRequestManager::CheckResponse(StunMessage* msg) { return true; } +bool StunRequestManager::empty() const { + RTC_DCHECK_RUN_ON(thread_); + return requests_.empty(); +} + bool StunRequestManager::CheckResponse(const char* data, size_t size) { + RTC_DCHECK_RUN_ON(thread_); // Check the appropriate bytes of the stream to see if they match the // transaction ID of a response we are expecting. @@ -186,32 +199,33 @@ bool StunRequestManager::CheckResponse(const char* data, size_t size) { return CheckResponse(response.get()); } -StunRequest::StunRequest() - : count_(0), - timeout_(false), - manager_(0), +StunRequest::StunRequest(StunRequestManager& manager) + : manager_(manager), msg_(new StunMessage()), - tstamp_(0) { + tstamp_(0), + count_(0), + timeout_(false) { msg_->SetTransactionID(rtc::CreateRandomString(kStunTransactionIdLength)); } -StunRequest::StunRequest(StunMessage* request) - : count_(0), timeout_(false), manager_(0), msg_(request), tstamp_(0) { +StunRequest::StunRequest(StunRequestManager& manager, + std::unique_ptr request) + : manager_(manager), + msg_(std::move(request)), + tstamp_(0), + count_(0), + timeout_(false) { msg_->SetTransactionID(rtc::CreateRandomString(kStunTransactionIdLength)); } StunRequest::~StunRequest() { - RTC_DCHECK(manager_ != NULL); - if (manager_) { - manager_->Remove(this); - manager_->thread_->Clear(this); - } - delete msg_; + manager_.Remove(this); + manager_.network_thread()->Clear(this); } void StunRequest::Construct() { if (msg_->type() == 0) { - Prepare(msg_); + Prepare(msg_.get()); RTC_DCHECK(msg_->type() != 0); } } @@ -222,24 +236,16 @@ int StunRequest::type() { } const StunMessage* StunRequest::msg() const { - return msg_; -} - -StunMessage* StunRequest::mutable_msg() { - return msg_; + return msg_.get(); } int StunRequest::Elapsed() const { + RTC_DCHECK_RUN_ON(network_thread()); return static_cast(rtc::TimeMillis() - tstamp_); } -void StunRequest::set_manager(StunRequestManager* manager) { - RTC_DCHECK(!manager_); - manager_ = manager; -} - void StunRequest::OnMessage(rtc::Message* pmsg) { - RTC_DCHECK(manager_ != NULL); + RTC_DCHECK_RUN_ON(network_thread()); RTC_DCHECK(pmsg->message_id == MSG_STUN_SEND); if (timeout_) { @@ -252,24 +258,26 @@ void StunRequest::OnMessage(rtc::Message* pmsg) { rtc::ByteBufferWriter buf; msg_->Write(&buf); - manager_->SignalSendPacket(buf.Data(), buf.Length(), this); + manager_.SignalSendPacket(buf.Data(), buf.Length(), this); OnSent(); - manager_->thread_->PostDelayed(RTC_FROM_HERE, resend_delay(), this, - MSG_STUN_SEND, NULL); + manager_.network_thread()->PostDelayed(RTC_FROM_HERE, resend_delay(), this, + MSG_STUN_SEND, NULL); } void StunRequest::OnSent() { + RTC_DCHECK_RUN_ON(network_thread()); count_ += 1; int retransmissions = (count_ - 1); if (retransmissions >= STUN_MAX_RETRANSMISSIONS) { timeout_ = true; } - RTC_LOG(LS_VERBOSE) << "Sent STUN request " << count_ - << "; resend delay = " << resend_delay(); + RTC_DLOG(LS_VERBOSE) << "Sent STUN request " << count_ + << "; resend delay = " << resend_delay(); } int StunRequest::resend_delay() { + RTC_DCHECK_RUN_ON(network_thread()); if (count_ == 0) { return 0; } @@ -278,4 +286,9 @@ int StunRequest::resend_delay() { return std::min(rto, STUN_MAX_RTO); } +void StunRequest::set_timed_out() { + RTC_DCHECK_RUN_ON(network_thread()); + timeout_ = true; +} + } // namespace cricket diff --git a/p2p/base/stun_request.h b/p2p/base/stun_request.h index b417c705cd..51276023a7 100644 --- a/p2p/base/stun_request.h +++ b/p2p/base/stun_request.h @@ -15,6 +15,7 @@ #include #include +#include #include #include "api/transport/stun.h" @@ -47,11 +48,11 @@ class StunRequestManager { // If `msg_type` is kAllRequests, sends all pending requests right away. // Otherwise, sends those that have a matching type right away. // Only for testing. - void Flush(int msg_type); + void FlushForTest(int msg_type); // Returns true if at least one request with `msg_type` is scheduled for // transmission. For testing only. - bool HasRequest(int msg_type); + bool HasRequestForTest(int msg_type); // Removes a stun request that was added previously. This will happen // automatically when a request succeeds, fails, or times out. @@ -65,7 +66,10 @@ class StunRequestManager { bool CheckResponse(StunMessage* msg); bool CheckResponse(const char* data, size_t size); - bool empty() { return requests_.empty(); } + bool empty() const; + + // TODO(tommi): Use TaskQueueBase* instead of rtc::Thread. + rtc::Thread* network_thread() const { return thread_; } // Raised when there are bytes to be sent. sigslot::signal3 SignalSendPacket; @@ -74,27 +78,26 @@ class StunRequestManager { typedef std::map RequestMap; rtc::Thread* const thread_; - RequestMap requests_; - - friend class StunRequest; + RequestMap requests_ RTC_GUARDED_BY(thread_); }; // Represents an individual request to be sent. The STUN message can either be // constructed beforehand or built on demand. class StunRequest : public rtc::MessageHandler { public: - StunRequest(); - explicit StunRequest(StunMessage* request); + explicit StunRequest(StunRequestManager& manager); + StunRequest(StunRequestManager& manager, + std::unique_ptr request); ~StunRequest() override; // Causes our wrapped StunMessage to be Prepared void Construct(); // The manager handling this request (if it has been scheduled for sending). - StunRequestManager* manager() { return manager_; } + StunRequestManager* manager() { return &manager_; } // Returns the transaction ID of this request. - const std::string& id() { return msg_->transaction_id(); } + const std::string& id() const { return msg_->transaction_id(); } // Returns the reduced transaction ID of this request. uint32_t reduced_transaction_id() const { @@ -107,15 +110,11 @@ class StunRequest : public rtc::MessageHandler { // Returns a const pointer to `msg_`. const StunMessage* msg() const; - // Returns a mutable pointer to `msg_`. - StunMessage* mutable_msg(); - // Time elapsed since last send (in ms) int Elapsed() const; protected: - int count_; - bool timeout_; + friend class StunRequestManager; // Fills in a request object to be sent. Note that request's transaction ID // will already be set and cannot be changed. @@ -130,17 +129,21 @@ class StunRequest : public rtc::MessageHandler { // Returns the next delay for resends. virtual int resend_delay(); - private: - void set_manager(StunRequestManager* manager); + webrtc::TaskQueueBase* network_thread() const { + return manager_.network_thread(); + } + void set_timed_out(); + + private: // Handles messages for sending and timeout. void OnMessage(rtc::Message* pmsg) override; - StunRequestManager* manager_; - StunMessage* msg_; - int64_t tstamp_; - - friend class StunRequestManager; + StunRequestManager& manager_; + const std::unique_ptr msg_; + int64_t tstamp_ RTC_GUARDED_BY(network_thread()); + int count_ RTC_GUARDED_BY(network_thread()); + bool timeout_ RTC_GUARDED_BY(network_thread()); }; } // namespace cricket diff --git a/p2p/base/stun_request_unittest.cc b/p2p/base/stun_request_unittest.cc index ce573f087d..b551c342c1 100644 --- a/p2p/base/stun_request_unittest.cc +++ b/p2p/base/stun_request_unittest.cc @@ -10,6 +10,7 @@ #include "p2p/base/stun_request.h" +#include #include #include "rtc_base/fake_clock.h" @@ -19,6 +20,24 @@ #include "test/gtest.h" namespace cricket { +namespace { +std::unique_ptr CreateStunMessage( + StunMessageType type, + const StunMessage* req = nullptr) { + std::unique_ptr msg = std::make_unique(); + msg->SetType(type); + if (req) { + msg->SetTransactionID(req->transaction_id()); + } + return msg; +} + +int TotalDelay(int sends) { + std::vector delays = {0, 250, 750, 1750, 3750, + 7750, 15750, 23750, 31750, 39750}; + return delays[sends]; +} +} // namespace class StunRequestTest : public ::testing::Test, public sigslot::has_slots<> { public: @@ -47,21 +66,6 @@ class StunRequestTest : public ::testing::Test, public sigslot::has_slots<> { void OnTimeout() { timeout_ = true; } protected: - static StunMessage* CreateStunMessage(StunMessageType type, - StunMessage* req) { - StunMessage* msg = new StunMessage(); - msg->SetType(type); - if (req) { - msg->SetTransactionID(req->transaction_id()); - } - return msg; - } - static int TotalDelay(int sends) { - std::vector delays = {0, 250, 750, 1750, 3750, - 7750, 15750, 23750, 31750, 39750}; - return delays[sends]; - } - StunRequestManager manager_; int request_count_; StunMessage* response_; @@ -73,9 +77,20 @@ class StunRequestTest : public ::testing::Test, public sigslot::has_slots<> { // Forwards results to the test class. class StunRequestThunker : public StunRequest { public: - StunRequestThunker(StunMessage* msg, StunRequestTest* test) - : StunRequest(msg), test_(test) {} - explicit StunRequestThunker(StunRequestTest* test) : test_(test) {} + StunRequestThunker(StunRequestManager& manager, + StunMessageType message_type, + StunRequestTest* test) + : StunRequest(manager, CreateStunMessage(message_type)), test_(test) { + Construct(); // Triggers a call to `Prepare()` which sets the type. + } + StunRequestThunker(StunRequestManager& manager, StunRequestTest* test) + : StunRequest(manager), test_(test) { + Construct(); // Triggers a call to `Prepare()` which sets the type. + } + + std::unique_ptr CreateResponseMessage(StunMessageType type) { + return CreateStunMessage(type, msg()); + } private: virtual void OnResponse(StunMessage* res) { test_->OnResponse(res); } @@ -93,127 +108,124 @@ class StunRequestThunker : public StunRequest { // Test handling of a normal binding response. TEST_F(StunRequestTest, TestSuccess) { - StunMessage* req = CreateStunMessage(STUN_BINDING_REQUEST, NULL); + auto* request = new StunRequestThunker(manager_, STUN_BINDING_REQUEST, this); + std::unique_ptr res = + request->CreateResponseMessage(STUN_BINDING_RESPONSE); + manager_.Send(request); + EXPECT_TRUE(manager_.CheckResponse(res.get())); - manager_.Send(new StunRequestThunker(req, this)); - StunMessage* res = CreateStunMessage(STUN_BINDING_RESPONSE, req); - EXPECT_TRUE(manager_.CheckResponse(res)); - - EXPECT_TRUE(response_ == res); + EXPECT_TRUE(response_ == res.get()); EXPECT_TRUE(success_); EXPECT_FALSE(failure_); EXPECT_FALSE(timeout_); - delete res; } // Test handling of an error binding response. TEST_F(StunRequestTest, TestError) { - StunMessage* req = CreateStunMessage(STUN_BINDING_REQUEST, NULL); + auto* request = new StunRequestThunker(manager_, STUN_BINDING_REQUEST, this); + std::unique_ptr res = + request->CreateResponseMessage(STUN_BINDING_ERROR_RESPONSE); + manager_.Send(request); + EXPECT_TRUE(manager_.CheckResponse(res.get())); - manager_.Send(new StunRequestThunker(req, this)); - StunMessage* res = CreateStunMessage(STUN_BINDING_ERROR_RESPONSE, req); - EXPECT_TRUE(manager_.CheckResponse(res)); - - EXPECT_TRUE(response_ == res); + EXPECT_TRUE(response_ == res.get()); EXPECT_FALSE(success_); EXPECT_TRUE(failure_); EXPECT_FALSE(timeout_); - delete res; } // Test handling of a binding response with the wrong transaction id. TEST_F(StunRequestTest, TestUnexpected) { - StunMessage* req = CreateStunMessage(STUN_BINDING_REQUEST, NULL); + auto* request = new StunRequestThunker(manager_, STUN_BINDING_REQUEST, this); + std::unique_ptr res = CreateStunMessage(STUN_BINDING_RESPONSE); - manager_.Send(new StunRequestThunker(req, this)); - StunMessage* res = CreateStunMessage(STUN_BINDING_RESPONSE, NULL); - EXPECT_FALSE(manager_.CheckResponse(res)); + manager_.Send(request); + EXPECT_FALSE(manager_.CheckResponse(res.get())); EXPECT_TRUE(response_ == NULL); EXPECT_FALSE(success_); EXPECT_FALSE(failure_); EXPECT_FALSE(timeout_); - delete res; } // Test that requests are sent at the right times. TEST_F(StunRequestTest, TestBackoff) { rtc::ScopedFakeClock fake_clock; - StunMessage* req = CreateStunMessage(STUN_BINDING_REQUEST, NULL); + auto* request = new StunRequestThunker(manager_, STUN_BINDING_REQUEST, this); + std::unique_ptr res = + request->CreateResponseMessage(STUN_BINDING_RESPONSE); int64_t start = rtc::TimeMillis(); - manager_.Send(new StunRequestThunker(req, this)); - StunMessage* res = CreateStunMessage(STUN_BINDING_RESPONSE, req); + manager_.Send(request); for (int i = 0; i < 9; ++i) { EXPECT_TRUE_SIMULATED_WAIT(request_count_ != i, STUN_TOTAL_TIMEOUT, fake_clock); int64_t elapsed = rtc::TimeMillis() - start; - RTC_LOG(LS_INFO) << "STUN request #" << (i + 1) << " sent at " << elapsed - << " ms"; + RTC_DLOG(LS_INFO) << "STUN request #" << (i + 1) << " sent at " << elapsed + << " ms"; EXPECT_EQ(TotalDelay(i), elapsed); } - EXPECT_TRUE(manager_.CheckResponse(res)); + EXPECT_TRUE(manager_.CheckResponse(res.get())); - EXPECT_TRUE(response_ == res); + EXPECT_TRUE(response_ == res.get()); EXPECT_TRUE(success_); EXPECT_FALSE(failure_); EXPECT_FALSE(timeout_); - delete res; } // Test that we timeout properly if no response is received. TEST_F(StunRequestTest, TestTimeout) { rtc::ScopedFakeClock fake_clock; - StunMessage* req = CreateStunMessage(STUN_BINDING_REQUEST, NULL); - StunMessage* res = CreateStunMessage(STUN_BINDING_RESPONSE, req); + auto* request = new StunRequestThunker(manager_, STUN_BINDING_REQUEST, this); + std::unique_ptr res = + request->CreateResponseMessage(STUN_BINDING_RESPONSE); - manager_.Send(new StunRequestThunker(req, this)); + manager_.Send(request); SIMULATED_WAIT(false, cricket::STUN_TOTAL_TIMEOUT, fake_clock); - EXPECT_FALSE(manager_.CheckResponse(res)); + EXPECT_FALSE(manager_.CheckResponse(res.get())); EXPECT_TRUE(response_ == NULL); EXPECT_FALSE(success_); EXPECT_FALSE(failure_); EXPECT_TRUE(timeout_); - delete res; } // Regression test for specific crash where we receive a response with the // same id as a request that doesn't have an underlying StunMessage yet. TEST_F(StunRequestTest, TestNoEmptyRequest) { - StunRequestThunker* request = new StunRequestThunker(this); + StunRequestThunker* request = new StunRequestThunker(manager_, this); manager_.SendDelayed(request, 100); StunMessage dummy_req; dummy_req.SetTransactionID(request->id()); - StunMessage* res = CreateStunMessage(STUN_BINDING_RESPONSE, &dummy_req); + std::unique_ptr res = + CreateStunMessage(STUN_BINDING_RESPONSE, &dummy_req); - EXPECT_TRUE(manager_.CheckResponse(res)); + EXPECT_TRUE(manager_.CheckResponse(res.get())); - EXPECT_TRUE(response_ == res); + EXPECT_TRUE(response_ == res.get()); EXPECT_TRUE(success_); EXPECT_FALSE(failure_); EXPECT_FALSE(timeout_); - delete res; } // If the response contains an attribute in the "comprehension required" range // which is not recognized, the transaction should be considered a failure and // the response should be ignored. TEST_F(StunRequestTest, TestUnrecognizedComprehensionRequiredAttribute) { - StunMessage* req = CreateStunMessage(STUN_BINDING_REQUEST, NULL); + auto* request = new StunRequestThunker(manager_, STUN_BINDING_REQUEST, this); + std::unique_ptr res = + request->CreateResponseMessage(STUN_BINDING_ERROR_RESPONSE); - manager_.Send(new StunRequestThunker(req, this)); - StunMessage* res = CreateStunMessage(STUN_BINDING_ERROR_RESPONSE, req); + manager_.Send(request); res->AddAttribute(StunAttribute::CreateUInt32(0x7777)); - EXPECT_FALSE(manager_.CheckResponse(res)); + EXPECT_FALSE(manager_.CheckResponse(res.get())); EXPECT_EQ(nullptr, response_); EXPECT_FALSE(success_); EXPECT_FALSE(failure_); EXPECT_FALSE(timeout_); - delete res; } } // namespace cricket diff --git a/p2p/base/turn_port.cc b/p2p/base/turn_port.cc index d71bc5f265..8729a8277e 100644 --- a/p2p/base/turn_port.cc +++ b/p2p/base/turn_port.cc @@ -1361,7 +1361,8 @@ void TurnPort::MaybeAddTurnLoggingId(StunMessage* msg) { } TurnAllocateRequest::TurnAllocateRequest(TurnPort* port) - : StunRequest(new TurnMessage()), port_(port) {} + : StunRequest(port->request_manager(), std::make_unique()), + port_(port) {} void TurnAllocateRequest::Prepare(StunMessage* request) { // Create the request as indicated in RFC 5766, Section 6.1. @@ -1549,7 +1550,9 @@ void TurnAllocateRequest::OnTryAlternate(StunMessage* response, int code) { } TurnRefreshRequest::TurnRefreshRequest(TurnPort* port) - : StunRequest(new TurnMessage()), port_(port), lifetime_(-1) {} + : StunRequest(port->request_manager(), std::make_unique()), + port_(port), + lifetime_(-1) {} void TurnRefreshRequest::Prepare(StunMessage* request) { // Create the request as indicated in RFC 5766, Section 7.1. @@ -1630,7 +1633,7 @@ TurnCreatePermissionRequest::TurnCreatePermissionRequest( TurnEntry* entry, const rtc::SocketAddress& ext_addr, const std::string& remote_ufrag) - : StunRequest(new TurnMessage()), + : StunRequest(port->request_manager(), std::make_unique()), port_(port), entry_(entry), ext_addr_(ext_addr), @@ -1703,7 +1706,7 @@ TurnChannelBindRequest::TurnChannelBindRequest( TurnEntry* entry, int channel_id, const rtc::SocketAddress& ext_addr) - : StunRequest(new TurnMessage()), + : StunRequest(port->request_manager(), std::make_unique()), port_(port), entry_(entry), channel_id_(channel_id), diff --git a/p2p/base/turn_port.h b/p2p/base/turn_port.h index fa76695087..74d249317f 100644 --- a/p2p/base/turn_port.h +++ b/p2p/base/turn_port.h @@ -171,6 +171,7 @@ class TurnPort : public Port { void OnAllocateMismatch(); rtc::AsyncPacketSocket* socket() const { return socket_; } + StunRequestManager& request_manager() { return request_manager_; } // Signal with resolved server address. // Parameters are port, server address and resolved server address. @@ -188,7 +189,11 @@ class TurnPort : public Port { sigslot::signal2 SignalTurnRefreshResult; sigslot::signal3 SignalCreatePermissionResult; - void FlushRequests(int msg_type) { request_manager_.Flush(msg_type); } + + void FlushRequestsForTest(int msg_type) { + request_manager_.FlushForTest(msg_type); + } + bool HasRequests() { return !request_manager_.empty(); } void set_credentials(const RelayCredentials& credentials) { credentials_ = credentials; diff --git a/p2p/base/turn_port_unittest.cc b/p2p/base/turn_port_unittest.cc index 8ca84b6dad..d1b911f837 100644 --- a/p2p/base/turn_port_unittest.cc +++ b/p2p/base/turn_port_unittest.cc @@ -1241,10 +1241,10 @@ TEST_F(TurnPortTest, TestRefreshRequestGetsErrorResponse) { // This sends out the first RefreshRequest with correct credentials. // When this succeeds, it will schedule a new RefreshRequest with the bad // credential. - turn_port_->FlushRequests(TURN_REFRESH_REQUEST); + turn_port_->FlushRequestsForTest(TURN_REFRESH_REQUEST); EXPECT_TRUE_SIMULATED_WAIT(turn_refresh_success_, kSimulatedRtt, fake_clock_); // Flush it again, it will receive a bad response. - turn_port_->FlushRequests(TURN_REFRESH_REQUEST); + turn_port_->FlushRequestsForTest(TURN_REFRESH_REQUEST); EXPECT_TRUE_SIMULATED_WAIT(!turn_refresh_success_, kSimulatedRtt, fake_clock_); EXPECT_FALSE(turn_port_->connected()); @@ -1458,11 +1458,11 @@ TEST_F(TurnPortTest, TestRefreshCreatePermissionRequest) { // another request with bad_ufrag and bad_pwd. RelayCredentials bad_credentials("bad_user", "bad_pwd"); turn_port_->set_credentials(bad_credentials); - turn_port_->FlushRequests(kAllRequests); + turn_port_->FlushRequestsForTest(kAllRequests); EXPECT_TRUE_SIMULATED_WAIT(turn_create_permission_success_, kSimulatedRtt, fake_clock_); // Flush the requests again; the create-permission-request will fail. - turn_port_->FlushRequests(kAllRequests); + turn_port_->FlushRequestsForTest(kAllRequests); EXPECT_TRUE_SIMULATED_WAIT(!turn_create_permission_success_, kSimulatedRtt, fake_clock_); EXPECT_TRUE(CheckConnectionFailedAndPruned(conn));