Use 16bit unsigned for channel id for TURN

Bug: webrtc:345518625
Change-Id: I0ee879e9a35cd9831e035a661d54201dc6defac9
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/353901
Reviewed-by: Mirko Bonadei <mbonadei@webrtc.org>
Commit-Queue: Tomas Gunnarsson <tommi@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#43447}
This commit is contained in:
Tommi 2024-11-22 15:08:09 +01:00 committed by WebRTC LUCI CQ
parent 89432bc225
commit 924dc088dc
5 changed files with 41 additions and 45 deletions

View File

@ -130,7 +130,7 @@ class TurnChannelBindRequest : public StunRequest {
public:
TurnChannelBindRequest(TurnPort* port,
TurnEntry* entry,
int channel_id,
uint16_t channel_id,
const rtc::SocketAddress& ext_addr);
~TurnChannelBindRequest() override;
void OnSent() override;
@ -139,10 +139,10 @@ class TurnChannelBindRequest : public StunRequest {
void OnTimeout() override;
private:
TurnPort* port_;
TurnEntry* entry_;
int channel_id_;
rtc::SocketAddress ext_addr_;
TurnPort* const port_;
TurnEntry* entry_; // Could be WeakPtr.
const uint16_t channel_id_;
const rtc::SocketAddress ext_addr_;
};
// Manages a "connection" to a remote destination. We will attempt to bring up
@ -155,9 +155,7 @@ class TurnEntry : public sigslot::has_slots<> {
TurnPort* port() { return port_; }
int channel_id() const { return channel_id_; }
// For testing only.
void set_channel_id(int channel_id) { channel_id_ = channel_id; }
uint16_t channel_id() const { return channel_id_; }
const rtc::SocketAddress& address() const { return ext_addr_; }
BindState state() const { return state_; }
@ -197,9 +195,9 @@ class TurnEntry : public sigslot::has_slots<> {
webrtc::CallbackList<TurnEntry*> destroyed_callback_list_;
private:
TurnPort* port_;
int channel_id_;
rtc::SocketAddress ext_addr_;
TurnPort* const port_;
const uint16_t channel_id_;
const rtc::SocketAddress ext_addr_;
BindState state_;
// List of associated connection instances to keep track of how many and
// which connections are associated with this entry. Once this is empty,
@ -1095,7 +1093,7 @@ void TurnPort::HandleDataIndication(const char* data,
data_attr->length(), ext_addr, PROTO_UDP, packet_time_us);
}
void TurnPort::HandleChannelData(int channel_id,
void TurnPort::HandleChannelData(uint16_t channel_id,
const char* data,
size_t size,
int64_t packet_time_us) {
@ -1257,7 +1255,7 @@ TurnEntry* TurnPort::FindEntry(const rtc::SocketAddress& addr) const {
return (it != entries_.end()) ? it->get() : nullptr;
}
TurnEntry* TurnPort::FindEntry(int channel_id) const {
TurnEntry* TurnPort::FindEntry(uint16_t channel_id) const {
auto it = absl::c_find_if(entries_, [&channel_id](const auto& e) {
return e->channel_id() == channel_id;
});
@ -1309,16 +1307,6 @@ void TurnPort::SetCallbacksForTest(CallbacksForTest* callbacks) {
callbacks_for_test_ = callbacks;
}
bool TurnPort::SetEntryChannelIdForTesting(const rtc::SocketAddress& address,
int channel_id) {
TurnEntry* entry = FindEntry(address);
if (!entry) {
return false;
}
entry->set_channel_id(channel_id);
return true;
}
std::string TurnPort::ReconstructServerUrl() {
// https://www.rfc-editor.org/rfc/rfc7065#section-3.1
// turnURI = scheme ":" host [ ":" port ]
@ -1720,7 +1708,7 @@ void TurnCreatePermissionRequest::OnTimeout() {
TurnChannelBindRequest::TurnChannelBindRequest(
TurnPort* port,
TurnEntry* entry,
int channel_id,
uint16_t channel_id,
const rtc::SocketAddress& ext_addr)
: StunRequest(port->request_manager(),
std::make_unique<TurnMessage>(TURN_CHANNEL_BIND_REQUEST)),

View File

@ -190,13 +190,6 @@ class TurnPort : public Port {
void set_credentials(const RelayCredentials& credentials) {
credentials_ = credentials;
}
// Finds the turn entry with `address` and sets its channel id.
// Returns true if the entry is found.
// This method must not be used in production, it is a test only
// utility that doesn't check the channel id is valid according to
// RFC5766.
bool SetEntryChannelIdForTesting(const rtc::SocketAddress& address,
int channel_id);
void HandleConnectionDestroyed(Connection* conn) override;
@ -308,7 +301,7 @@ class TurnPort : public Port {
void HandleDataIndication(const char* data,
size_t size,
int64_t packet_time_us);
void HandleChannelData(int channel_id,
void HandleChannelData(uint16_t channel_id,
const char* data,
size_t size,
int64_t packet_time_us);
@ -327,7 +320,7 @@ class TurnPort : public Port {
bool HasPermission(const rtc::IPAddress& ipaddr) const;
TurnEntry* FindEntry(const rtc::SocketAddress& address) const;
TurnEntry* FindEntry(int channel_id) const;
TurnEntry* FindEntry(uint16_t channel_id) const;
// Marks the connection with remote address `address` failed and
// pruned (a.k.a. write-timed-out). Returns true if a connection is found.

View File

@ -1541,12 +1541,9 @@ TEST_F(TurnPortTest, TestChannelBindGetErrorResponse) {
ASSERT_TRUE(conn2 != nullptr);
conn1->Ping(0);
EXPECT_TRUE_SIMULATED_WAIT(conn1->writable(), kSimulatedRtt * 2, fake_clock_);
// TODO(bugs.webrtc.org/345518625): SetEntryChannelIdForTesting should not be
// a public method. Instead we should set an option on the fake TURN server to
// force it to send a channel bind errors.
int illegal_channel_id = kMaxTurnChannelNumber + 1u;
ASSERT_TRUE(turn_port_->SetEntryChannelIdForTesting(
udp_port_->Candidates()[0].address(), illegal_channel_id));
// Tell the TURN server to reject all bind requests from now on.
turn_server_.server()->set_reject_bind_requests(true);
std::string data = "ABC";
conn1->Send(data.data(), data.length(), options);

View File

@ -135,6 +135,8 @@ void TurnServer::OnNewInternalConnection(rtc::Socket* socket) {
}
void TurnServer::AcceptConnection(rtc::Socket* server_socket) {
RTC_DCHECK_RUN_ON(thread_);
// Check if someone is trying to connect to us.
rtc::SocketAddress accept_addr;
rtc::Socket* accepted_socket = server_socket->Accept(&accept_addr);
@ -192,6 +194,7 @@ void TurnServer::OnInternalPacket(rtc::AsyncPacketSocket* socket,
void TurnServer::HandleStunMessage(TurnServerConnection* conn,
rtc::ArrayView<const uint8_t> payload) {
RTC_DCHECK_RUN_ON(thread_);
TurnMessage msg;
rtc::ByteBufferReader buf(payload);
if (!msg.Read(&buf) || (buf.Length() > 0)) {
@ -577,6 +580,7 @@ std::string TurnServerAllocation::ToString() const {
}
void TurnServerAllocation::HandleTurnMessage(const TurnMessage* msg) {
RTC_DCHECK_RUN_ON(thread_);
RTC_DCHECK(msg != NULL);
switch (msg->type()) {
case STUN_ALLOCATE_REQUEST:
@ -680,6 +684,7 @@ void TurnServerAllocation::HandleSendIndication(const TurnMessage* msg) {
void TurnServerAllocation::HandleCreatePermissionRequest(
const TurnMessage* msg) {
RTC_DCHECK_RUN_ON(server_->thread_);
// Check mandatory attributes.
const StunAddressAttribute* peer_attr =
msg->GetAddress(STUN_ATTR_XOR_PEER_ADDRESS);
@ -707,6 +712,13 @@ void TurnServerAllocation::HandleCreatePermissionRequest(
}
void TurnServerAllocation::HandleChannelBindRequest(const TurnMessage* msg) {
RTC_DCHECK_RUN_ON(server_->thread_);
if (server_->reject_bind_requests_) {
RTC_LOG(LS_ERROR) << "HandleChannelBindRequest: Rejecting bind requests";
SendBadRequestResponse(msg);
return;
}
// Check mandatory attributes.
const StunUInt32Attribute* channel_attr =
msg->GetUInt32(STUN_ATTR_CHANNEL_NUMBER);
@ -718,7 +730,7 @@ void TurnServerAllocation::HandleChannelBindRequest(const TurnMessage* msg) {
}
// Check that channel id is valid.
int channel_id = channel_attr->value() >> 16;
uint16_t channel_id = static_cast<uint16_t>(channel_attr->value() >> 16);
if (channel_id < kMinTurnChannelNumber ||
channel_id > kMaxTurnChannelNumber) {
SendBadRequestResponse(msg);

View File

@ -73,14 +73,14 @@ class TurnServerConnection {
// handles TURN messages (via HandleTurnMessage) and channel data messages
// (via HandleChannelData) for this allocation when received by the server.
// The object informs the server when its lifetime timer expires.
class TurnServerAllocation {
class TurnServerAllocation final {
public:
TurnServerAllocation(TurnServer* server_,
webrtc::TaskQueueBase* thread,
const TurnServerConnection& conn,
rtc::AsyncPacketSocket* server_socket,
absl::string_view key);
virtual ~TurnServerAllocation();
~TurnServerAllocation();
TurnServerConnection* conn() { return &conn_; }
const std::string& key() const { return key_; }
@ -99,8 +99,8 @@ class TurnServerAllocation {
private:
struct Channel {
webrtc::ScopedTaskSafety pending_delete;
int id;
rtc::SocketAddress peer;
const uint16_t id;
const rtc::SocketAddress peer;
};
struct Permission {
webrtc::ScopedTaskSafety pending_delete;
@ -235,6 +235,11 @@ class TurnServer : public sigslot::has_slots<> {
reject_private_addresses_ = filter;
}
void set_reject_bind_requests(bool filter) {
RTC_DCHECK_RUN_ON(thread_);
reject_bind_requests_ = filter;
}
void set_enable_permission_checks(bool enable) {
RTC_DCHECK_RUN_ON(thread_);
enable_permission_checks_ = enable;
@ -341,7 +346,8 @@ class TurnServer : public sigslot::has_slots<> {
// otu - one-time-use. Server will respond with 438 if it's
// sees the same nonce in next transaction.
bool enable_otu_nonce_ RTC_GUARDED_BY(thread_);
bool reject_private_addresses_ = false;
bool reject_private_addresses_ RTC_GUARDED_BY(thread_) = false;
bool reject_bind_requests_ RTC_GUARDED_BY(thread_) = false;
// Check for permission when receiving an external packet.
bool enable_permission_checks_ = true;