From 924dc088dc886cacdbd29fe87e0ec895f57938f2 Mon Sep 17 00:00:00 2001 From: Tommi Date: Fri, 22 Nov 2024 15:08:09 +0100 Subject: [PATCH] Use 16bit unsigned for channel id for TURN Bug: webrtc:345518625 Change-Id: I0ee879e9a35cd9831e035a661d54201dc6defac9 Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/353901 Reviewed-by: Mirko Bonadei Commit-Queue: Tomas Gunnarsson Cr-Commit-Position: refs/heads/main@{#43447} --- p2p/base/turn_port.cc | 36 ++++++++++++---------------------- p2p/base/turn_port.h | 11 ++--------- p2p/base/turn_port_unittest.cc | 9 +++------ p2p/base/turn_server.cc | 14 ++++++++++++- p2p/base/turn_server.h | 16 ++++++++++----- 5 files changed, 41 insertions(+), 45 deletions(-) diff --git a/p2p/base/turn_port.cc b/p2p/base/turn_port.cc index f52e1f2916..b4da910ab3 100644 --- a/p2p/base/turn_port.cc +++ b/p2p/base/turn_port.cc @@ -130,7 +130,7 @@ class TurnChannelBindRequest : public StunRequest { public: TurnChannelBindRequest(TurnPort* port, TurnEntry* entry, - int channel_id, + uint16_t channel_id, const rtc::SocketAddress& ext_addr); ~TurnChannelBindRequest() override; void OnSent() override; @@ -139,10 +139,10 @@ class TurnChannelBindRequest : public StunRequest { void OnTimeout() override; private: - TurnPort* port_; - TurnEntry* entry_; - int channel_id_; - rtc::SocketAddress ext_addr_; + TurnPort* const port_; + TurnEntry* entry_; // Could be WeakPtr. + const uint16_t channel_id_; + const rtc::SocketAddress ext_addr_; }; // Manages a "connection" to a remote destination. We will attempt to bring up @@ -155,9 +155,7 @@ class TurnEntry : public sigslot::has_slots<> { TurnPort* port() { return port_; } - int channel_id() const { return channel_id_; } - // For testing only. - void set_channel_id(int channel_id) { channel_id_ = channel_id; } + uint16_t channel_id() const { return channel_id_; } const rtc::SocketAddress& address() const { return ext_addr_; } BindState state() const { return state_; } @@ -197,9 +195,9 @@ class TurnEntry : public sigslot::has_slots<> { webrtc::CallbackList destroyed_callback_list_; private: - TurnPort* port_; - int channel_id_; - rtc::SocketAddress ext_addr_; + TurnPort* const port_; + const uint16_t channel_id_; + const rtc::SocketAddress ext_addr_; BindState state_; // List of associated connection instances to keep track of how many and // which connections are associated with this entry. Once this is empty, @@ -1095,7 +1093,7 @@ void TurnPort::HandleDataIndication(const char* data, data_attr->length(), ext_addr, PROTO_UDP, packet_time_us); } -void TurnPort::HandleChannelData(int channel_id, +void TurnPort::HandleChannelData(uint16_t channel_id, const char* data, size_t size, int64_t packet_time_us) { @@ -1257,7 +1255,7 @@ TurnEntry* TurnPort::FindEntry(const rtc::SocketAddress& addr) const { return (it != entries_.end()) ? it->get() : nullptr; } -TurnEntry* TurnPort::FindEntry(int channel_id) const { +TurnEntry* TurnPort::FindEntry(uint16_t channel_id) const { auto it = absl::c_find_if(entries_, [&channel_id](const auto& e) { return e->channel_id() == channel_id; }); @@ -1309,16 +1307,6 @@ void TurnPort::SetCallbacksForTest(CallbacksForTest* callbacks) { callbacks_for_test_ = callbacks; } -bool TurnPort::SetEntryChannelIdForTesting(const rtc::SocketAddress& address, - int channel_id) { - TurnEntry* entry = FindEntry(address); - if (!entry) { - return false; - } - entry->set_channel_id(channel_id); - return true; -} - std::string TurnPort::ReconstructServerUrl() { // https://www.rfc-editor.org/rfc/rfc7065#section-3.1 // turnURI = scheme ":" host [ ":" port ] @@ -1720,7 +1708,7 @@ void TurnCreatePermissionRequest::OnTimeout() { TurnChannelBindRequest::TurnChannelBindRequest( TurnPort* port, TurnEntry* entry, - int channel_id, + uint16_t channel_id, const rtc::SocketAddress& ext_addr) : StunRequest(port->request_manager(), std::make_unique(TURN_CHANNEL_BIND_REQUEST)), diff --git a/p2p/base/turn_port.h b/p2p/base/turn_port.h index e8b0f8671d..8e69bca528 100644 --- a/p2p/base/turn_port.h +++ b/p2p/base/turn_port.h @@ -190,13 +190,6 @@ class TurnPort : public Port { void set_credentials(const RelayCredentials& credentials) { credentials_ = credentials; } - // Finds the turn entry with `address` and sets its channel id. - // Returns true if the entry is found. - // This method must not be used in production, it is a test only - // utility that doesn't check the channel id is valid according to - // RFC5766. - bool SetEntryChannelIdForTesting(const rtc::SocketAddress& address, - int channel_id); void HandleConnectionDestroyed(Connection* conn) override; @@ -308,7 +301,7 @@ class TurnPort : public Port { void HandleDataIndication(const char* data, size_t size, int64_t packet_time_us); - void HandleChannelData(int channel_id, + void HandleChannelData(uint16_t channel_id, const char* data, size_t size, int64_t packet_time_us); @@ -327,7 +320,7 @@ class TurnPort : public Port { bool HasPermission(const rtc::IPAddress& ipaddr) const; TurnEntry* FindEntry(const rtc::SocketAddress& address) const; - TurnEntry* FindEntry(int channel_id) const; + TurnEntry* FindEntry(uint16_t channel_id) const; // Marks the connection with remote address `address` failed and // pruned (a.k.a. write-timed-out). Returns true if a connection is found. diff --git a/p2p/base/turn_port_unittest.cc b/p2p/base/turn_port_unittest.cc index 6c16a9d983..d1f97099a8 100644 --- a/p2p/base/turn_port_unittest.cc +++ b/p2p/base/turn_port_unittest.cc @@ -1541,12 +1541,9 @@ TEST_F(TurnPortTest, TestChannelBindGetErrorResponse) { ASSERT_TRUE(conn2 != nullptr); conn1->Ping(0); EXPECT_TRUE_SIMULATED_WAIT(conn1->writable(), kSimulatedRtt * 2, fake_clock_); - // TODO(bugs.webrtc.org/345518625): SetEntryChannelIdForTesting should not be - // a public method. Instead we should set an option on the fake TURN server to - // force it to send a channel bind errors. - int illegal_channel_id = kMaxTurnChannelNumber + 1u; - ASSERT_TRUE(turn_port_->SetEntryChannelIdForTesting( - udp_port_->Candidates()[0].address(), illegal_channel_id)); + + // Tell the TURN server to reject all bind requests from now on. + turn_server_.server()->set_reject_bind_requests(true); std::string data = "ABC"; conn1->Send(data.data(), data.length(), options); diff --git a/p2p/base/turn_server.cc b/p2p/base/turn_server.cc index 494939a306..e8e238f3cf 100644 --- a/p2p/base/turn_server.cc +++ b/p2p/base/turn_server.cc @@ -135,6 +135,8 @@ void TurnServer::OnNewInternalConnection(rtc::Socket* socket) { } void TurnServer::AcceptConnection(rtc::Socket* server_socket) { + RTC_DCHECK_RUN_ON(thread_); + // Check if someone is trying to connect to us. rtc::SocketAddress accept_addr; rtc::Socket* accepted_socket = server_socket->Accept(&accept_addr); @@ -192,6 +194,7 @@ void TurnServer::OnInternalPacket(rtc::AsyncPacketSocket* socket, void TurnServer::HandleStunMessage(TurnServerConnection* conn, rtc::ArrayView payload) { + RTC_DCHECK_RUN_ON(thread_); TurnMessage msg; rtc::ByteBufferReader buf(payload); if (!msg.Read(&buf) || (buf.Length() > 0)) { @@ -577,6 +580,7 @@ std::string TurnServerAllocation::ToString() const { } void TurnServerAllocation::HandleTurnMessage(const TurnMessage* msg) { + RTC_DCHECK_RUN_ON(thread_); RTC_DCHECK(msg != NULL); switch (msg->type()) { case STUN_ALLOCATE_REQUEST: @@ -680,6 +684,7 @@ void TurnServerAllocation::HandleSendIndication(const TurnMessage* msg) { void TurnServerAllocation::HandleCreatePermissionRequest( const TurnMessage* msg) { + RTC_DCHECK_RUN_ON(server_->thread_); // Check mandatory attributes. const StunAddressAttribute* peer_attr = msg->GetAddress(STUN_ATTR_XOR_PEER_ADDRESS); @@ -707,6 +712,13 @@ void TurnServerAllocation::HandleCreatePermissionRequest( } void TurnServerAllocation::HandleChannelBindRequest(const TurnMessage* msg) { + RTC_DCHECK_RUN_ON(server_->thread_); + if (server_->reject_bind_requests_) { + RTC_LOG(LS_ERROR) << "HandleChannelBindRequest: Rejecting bind requests"; + SendBadRequestResponse(msg); + return; + } + // Check mandatory attributes. const StunUInt32Attribute* channel_attr = msg->GetUInt32(STUN_ATTR_CHANNEL_NUMBER); @@ -718,7 +730,7 @@ void TurnServerAllocation::HandleChannelBindRequest(const TurnMessage* msg) { } // Check that channel id is valid. - int channel_id = channel_attr->value() >> 16; + uint16_t channel_id = static_cast(channel_attr->value() >> 16); if (channel_id < kMinTurnChannelNumber || channel_id > kMaxTurnChannelNumber) { SendBadRequestResponse(msg); diff --git a/p2p/base/turn_server.h b/p2p/base/turn_server.h index 94188d6df5..73e53270a1 100644 --- a/p2p/base/turn_server.h +++ b/p2p/base/turn_server.h @@ -73,14 +73,14 @@ class TurnServerConnection { // handles TURN messages (via HandleTurnMessage) and channel data messages // (via HandleChannelData) for this allocation when received by the server. // The object informs the server when its lifetime timer expires. -class TurnServerAllocation { +class TurnServerAllocation final { public: TurnServerAllocation(TurnServer* server_, webrtc::TaskQueueBase* thread, const TurnServerConnection& conn, rtc::AsyncPacketSocket* server_socket, absl::string_view key); - virtual ~TurnServerAllocation(); + ~TurnServerAllocation(); TurnServerConnection* conn() { return &conn_; } const std::string& key() const { return key_; } @@ -99,8 +99,8 @@ class TurnServerAllocation { private: struct Channel { webrtc::ScopedTaskSafety pending_delete; - int id; - rtc::SocketAddress peer; + const uint16_t id; + const rtc::SocketAddress peer; }; struct Permission { webrtc::ScopedTaskSafety pending_delete; @@ -235,6 +235,11 @@ class TurnServer : public sigslot::has_slots<> { reject_private_addresses_ = filter; } + void set_reject_bind_requests(bool filter) { + RTC_DCHECK_RUN_ON(thread_); + reject_bind_requests_ = filter; + } + void set_enable_permission_checks(bool enable) { RTC_DCHECK_RUN_ON(thread_); enable_permission_checks_ = enable; @@ -341,7 +346,8 @@ class TurnServer : public sigslot::has_slots<> { // otu - one-time-use. Server will respond with 438 if it's // sees the same nonce in next transaction. bool enable_otu_nonce_ RTC_GUARDED_BY(thread_); - bool reject_private_addresses_ = false; + bool reject_private_addresses_ RTC_GUARDED_BY(thread_) = false; + bool reject_bind_requests_ RTC_GUARDED_BY(thread_) = false; // Check for permission when receiving an external packet. bool enable_permission_checks_ = true;