Add thread checker to SctpSidAllocator

Also changing AllocateSid to return StreamId instead of bool.

Bug: webrtc:11547
Change-Id: I776e917300ddfdbb79e78c01ef880209ec2c5917
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/298301
Reviewed-by: Florent Castelli <orphis@webrtc.org>
Commit-Queue: Tomas Gunnarsson <tommi@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#39630}
This commit is contained in:
Tommi 2023-03-21 18:45:24 +01:00 committed by WebRTC LUCI CQ
parent 0a025345fe
commit 8efaec62b1
6 changed files with 53 additions and 56 deletions

View File

@ -875,6 +875,7 @@ rtc_library("sctp_data_channel") {
"../api:priority", "../api:priority",
"../api:rtc_error", "../api:rtc_error",
"../api:scoped_refptr", "../api:scoped_refptr",
"../api:sequence_checker",
"../api/transport:datagram_transport_interface", "../api/transport:datagram_transport_interface",
"../media:media_channel", "../media:media_channel",
"../media:rtc_data_sctp_transport_internal", "../media:rtc_data_sctp_transport_internal",
@ -887,6 +888,7 @@ rtc_library("sctp_data_channel") {
"../rtc_base:threading", "../rtc_base:threading",
"../rtc_base:weak_ptr", "../rtc_base:weak_ptr",
"../rtc_base/containers:flat_set", "../rtc_base/containers:flat_set",
"../rtc_base/system:no_unique_address",
"../rtc_base/system:unused", "../rtc_base/system:unused",
] ]
absl_deps = [ absl_deps = [

View File

@ -275,9 +275,10 @@ DataChannelController::InternalCreateSctpDataChannel(
// the network thread. (unless there's no transport). Change this so that // the network thread. (unless there's no transport). Change this so that
// the role is checked on the network thread and any network thread related // the role is checked on the network thread and any network thread related
// initialization is done at the same time (to avoid additional hops). // initialization is done at the same time (to avoid additional hops).
if (pc_->GetSctpSslRole(&role) && !sid_allocator_.AllocateSid(role, &sid)) { if (pc_->GetSctpSslRole(&role)) {
RTC_LOG(LS_ERROR) << "No id can be allocated for the SCTP data channel."; sid = sid_allocator_.AllocateSid(role);
return nullptr; if (!sid.HasValue())
return nullptr;
} }
// Note that when we get here, the ID may still be invalid. // Note that when we get here, the ID may still be invalid.
} else if (!sid_allocator_.ReserveSid(sid)) { } else if (!sid_allocator_.ReserveSid(sid)) {
@ -325,9 +326,8 @@ void DataChannelController::AllocateSctpSids(rtc::SSLRole role) {
std::vector<rtc::scoped_refptr<SctpDataChannel>> channels_to_close; std::vector<rtc::scoped_refptr<SctpDataChannel>> channels_to_close;
for (const auto& channel : sctp_data_channels_) { for (const auto& channel : sctp_data_channels_) {
if (!channel->sid().HasValue()) { if (!channel->sid().HasValue()) {
StreamId sid; StreamId sid = sid_allocator_.AllocateSid(role);
if (!sid_allocator_.AllocateSid(role, &sid)) { if (!sid.HasValue()) {
RTC_LOG(LS_ERROR) << "Failed to allocate SCTP sid, closing channel.";
channels_to_close.push_back(channel); channels_to_close.push_back(channel);
continue; continue;
} }

View File

@ -148,7 +148,7 @@ class DataChannelController : public SctpDataChannelControllerInterface,
bool data_channel_transport_ready_to_send_ bool data_channel_transport_ready_to_send_
RTC_GUARDED_BY(signaling_thread()) = false; RTC_GUARDED_BY(signaling_thread()) = false;
SctpSidAllocator sid_allocator_ /* RTC_GUARDED_BY(signaling_thread()) */; SctpSidAllocator sid_allocator_;
std::vector<rtc::scoped_refptr<SctpDataChannel>> sctp_data_channels_ std::vector<rtc::scoped_refptr<SctpDataChannel>> sctp_data_channels_
RTC_GUARDED_BY(signaling_thread()); RTC_GUARDED_BY(signaling_thread());
bool has_used_data_channels_ RTC_GUARDED_BY(signaling_thread()) = false; bool has_used_data_channels_ RTC_GUARDED_BY(signaling_thread()) = false;

View File

@ -653,18 +653,10 @@ class SctpSidAllocatorTest : public ::testing::Test {
// Verifies that an even SCTP id is allocated for SSL_CLIENT and an odd id for // Verifies that an even SCTP id is allocated for SSL_CLIENT and an odd id for
// SSL_SERVER. // SSL_SERVER.
TEST_F(SctpSidAllocatorTest, SctpIdAllocationBasedOnRole) { TEST_F(SctpSidAllocatorTest, SctpIdAllocationBasedOnRole) {
StreamId id; EXPECT_EQ(allocator_.AllocateSid(rtc::SSL_SERVER), StreamId(1));
EXPECT_TRUE(allocator_.AllocateSid(rtc::SSL_SERVER, &id)); EXPECT_EQ(allocator_.AllocateSid(rtc::SSL_CLIENT), StreamId(0));
EXPECT_EQ(1, id.stream_id_int()); EXPECT_EQ(allocator_.AllocateSid(rtc::SSL_SERVER), StreamId(3));
id.reset(); EXPECT_EQ(allocator_.AllocateSid(rtc::SSL_CLIENT), StreamId(2));
EXPECT_TRUE(allocator_.AllocateSid(rtc::SSL_CLIENT, &id));
EXPECT_EQ(0, id.stream_id_int());
id.reset();
EXPECT_TRUE(allocator_.AllocateSid(rtc::SSL_SERVER, &id));
EXPECT_EQ(3, id.stream_id_int());
id.reset();
EXPECT_TRUE(allocator_.AllocateSid(rtc::SSL_CLIENT, &id));
EXPECT_EQ(2, id.stream_id_int());
} }
// Verifies that SCTP ids of existing DataChannels are not reused. // Verifies that SCTP ids of existing DataChannels are not reused.
@ -672,14 +664,14 @@ TEST_F(SctpSidAllocatorTest, SctpIdAllocationNoReuse) {
StreamId old_id(1); StreamId old_id(1);
EXPECT_TRUE(allocator_.ReserveSid(old_id)); EXPECT_TRUE(allocator_.ReserveSid(old_id));
StreamId new_id; StreamId new_id = allocator_.AllocateSid(rtc::SSL_SERVER);
EXPECT_TRUE(allocator_.AllocateSid(rtc::SSL_SERVER, &new_id)); EXPECT_TRUE(new_id.HasValue());
EXPECT_NE(old_id, new_id); EXPECT_NE(old_id, new_id);
old_id = StreamId(0); old_id = StreamId(0);
EXPECT_TRUE(allocator_.ReserveSid(old_id)); EXPECT_TRUE(allocator_.ReserveSid(old_id));
new_id.reset(); new_id = allocator_.AllocateSid(rtc::SSL_CLIENT);
EXPECT_TRUE(allocator_.AllocateSid(rtc::SSL_CLIENT, &new_id)); EXPECT_TRUE(new_id.HasValue());
EXPECT_NE(old_id, new_id); EXPECT_NE(old_id, new_id);
} }
@ -690,41 +682,33 @@ TEST_F(SctpSidAllocatorTest, SctpIdReusedForRemovedDataChannel) {
EXPECT_TRUE(allocator_.ReserveSid(odd_id)); EXPECT_TRUE(allocator_.ReserveSid(odd_id));
EXPECT_TRUE(allocator_.ReserveSid(even_id)); EXPECT_TRUE(allocator_.ReserveSid(even_id));
StreamId allocated_id; StreamId allocated_id = allocator_.AllocateSid(rtc::SSL_SERVER);
EXPECT_TRUE(allocator_.AllocateSid(rtc::SSL_SERVER, &allocated_id));
EXPECT_EQ(odd_id.stream_id_int() + 2, allocated_id.stream_id_int()); EXPECT_EQ(odd_id.stream_id_int() + 2, allocated_id.stream_id_int());
allocated_id.reset(); allocated_id = allocator_.AllocateSid(rtc::SSL_CLIENT);
EXPECT_TRUE(allocator_.AllocateSid(rtc::SSL_CLIENT, &allocated_id));
EXPECT_EQ(even_id.stream_id_int() + 2, allocated_id.stream_id_int()); EXPECT_EQ(even_id.stream_id_int() + 2, allocated_id.stream_id_int());
allocated_id.reset(); allocated_id = allocator_.AllocateSid(rtc::SSL_SERVER);
EXPECT_TRUE(allocator_.AllocateSid(rtc::SSL_SERVER, &allocated_id));
EXPECT_EQ(odd_id.stream_id_int() + 4, allocated_id.stream_id_int()); EXPECT_EQ(odd_id.stream_id_int() + 4, allocated_id.stream_id_int());
allocated_id.reset(); allocated_id = allocator_.AllocateSid(rtc::SSL_CLIENT);
EXPECT_TRUE(allocator_.AllocateSid(rtc::SSL_CLIENT, &allocated_id));
EXPECT_EQ(even_id.stream_id_int() + 4, allocated_id.stream_id_int()); EXPECT_EQ(even_id.stream_id_int() + 4, allocated_id.stream_id_int());
allocator_.ReleaseSid(odd_id); allocator_.ReleaseSid(odd_id);
allocator_.ReleaseSid(even_id); allocator_.ReleaseSid(even_id);
// Verifies that removed ids are reused. // Verifies that removed ids are reused.
allocated_id.reset(); allocated_id = allocator_.AllocateSid(rtc::SSL_SERVER);
EXPECT_TRUE(allocator_.AllocateSid(rtc::SSL_SERVER, &allocated_id));
EXPECT_EQ(odd_id, allocated_id); EXPECT_EQ(odd_id, allocated_id);
allocated_id.reset(); allocated_id = allocator_.AllocateSid(rtc::SSL_CLIENT);
EXPECT_TRUE(allocator_.AllocateSid(rtc::SSL_CLIENT, &allocated_id));
EXPECT_EQ(even_id, allocated_id); EXPECT_EQ(even_id, allocated_id);
// Verifies that used higher ids are not reused. // Verifies that used higher ids are not reused.
allocated_id.reset(); allocated_id = allocator_.AllocateSid(rtc::SSL_SERVER);
EXPECT_TRUE(allocator_.AllocateSid(rtc::SSL_SERVER, &allocated_id));
EXPECT_EQ(odd_id.stream_id_int() + 6, allocated_id.stream_id_int()); EXPECT_EQ(odd_id.stream_id_int() + 6, allocated_id.stream_id_int());
allocated_id.reset(); allocated_id = allocator_.AllocateSid(rtc::SSL_CLIENT);
EXPECT_TRUE(allocator_.AllocateSid(rtc::SSL_CLIENT, &allocated_id));
EXPECT_EQ(even_id.stream_id_int() + 6, allocated_id.stream_id_int()); EXPECT_EQ(even_id.stream_id_int() + 6, allocated_id.stream_id_int());
} }

View File

@ -117,25 +117,32 @@ bool InternalDataChannelInit::IsValid() const {
return true; return true;
} }
bool SctpSidAllocator::AllocateSid(rtc::SSLRole role, StreamId* sid) { SctpSidAllocator::SctpSidAllocator() {
int potential_sid = (role == rtc::SSL_CLIENT) ? 0 : 1; sequence_checker_.Detach();
while (potential_sid <= static_cast<int>(cricket::kMaxSctpSid)) {
*sid = StreamId(potential_sid);
if (used_sids_.insert(*sid).second)
return true;
potential_sid += 2;
}
sid->reset();
return false;
} }
bool SctpSidAllocator::ReserveSid(const StreamId& sid) { StreamId SctpSidAllocator::AllocateSid(rtc::SSLRole role) {
RTC_DCHECK_RUN_ON(&sequence_checker_);
int potential_sid = (role == rtc::SSL_CLIENT) ? 0 : 1;
while (potential_sid <= static_cast<int>(cricket::kMaxSctpSid)) {
StreamId sid(potential_sid);
if (used_sids_.insert(sid).second)
return sid;
potential_sid += 2;
}
RTC_LOG(LS_ERROR) << "SCTP sid allocation pool exhausted.";
return StreamId();
}
bool SctpSidAllocator::ReserveSid(StreamId sid) {
RTC_DCHECK_RUN_ON(&sequence_checker_);
if (!sid.HasValue() || sid.stream_id_int() > cricket::kMaxSctpSid) if (!sid.HasValue() || sid.stream_id_int() > cricket::kMaxSctpSid)
return false; return false;
return used_sids_.insert(sid).second; return used_sids_.insert(sid).second;
} }
void SctpSidAllocator::ReleaseSid(const StreamId& sid) { void SctpSidAllocator::ReleaseSid(StreamId sid) {
RTC_DCHECK_RUN_ON(&sequence_checker_);
used_sids_.erase(sid); used_sids_.erase(sid);
} }

View File

@ -22,12 +22,14 @@
#include "api/priority.h" #include "api/priority.h"
#include "api/rtc_error.h" #include "api/rtc_error.h"
#include "api/scoped_refptr.h" #include "api/scoped_refptr.h"
#include "api/sequence_checker.h"
#include "api/transport/data_channel_transport_interface.h" #include "api/transport/data_channel_transport_interface.h"
#include "pc/data_channel_utils.h" #include "pc/data_channel_utils.h"
#include "pc/sctp_utils.h" #include "pc/sctp_utils.h"
#include "rtc_base/containers/flat_set.h" #include "rtc_base/containers/flat_set.h"
#include "rtc_base/copy_on_write_buffer.h" #include "rtc_base/copy_on_write_buffer.h"
#include "rtc_base/ssl_stream_adapter.h" // For SSLRole #include "rtc_base/ssl_stream_adapter.h" // For SSLRole
#include "rtc_base/system/no_unique_address.h"
#include "rtc_base/thread.h" #include "rtc_base/thread.h"
#include "rtc_base/thread_annotations.h" #include "rtc_base/thread_annotations.h"
#include "rtc_base/weak_ptr.h" #include "rtc_base/weak_ptr.h"
@ -73,20 +75,22 @@ struct InternalDataChannelInit : public DataChannelInit {
// Helper class to allocate unique IDs for SCTP DataChannels. // Helper class to allocate unique IDs for SCTP DataChannels.
class SctpSidAllocator { class SctpSidAllocator {
public: public:
SctpSidAllocator();
// Gets the first unused odd/even id based on the DTLS role. If `role` is // Gets the first unused odd/even id based on the DTLS role. If `role` is
// SSL_CLIENT, the allocated id starts from 0 and takes even numbers; // SSL_CLIENT, the allocated id starts from 0 and takes even numbers;
// otherwise, the id starts from 1 and takes odd numbers. // otherwise, the id starts from 1 and takes odd numbers.
// Returns false if no ID can be allocated. // If a `StreamId` cannot be allocated, `StreamId::HasValue()` will be false.
bool AllocateSid(rtc::SSLRole role, StreamId* sid); StreamId AllocateSid(rtc::SSLRole role);
// Attempts to reserve a specific sid. Returns false if it's unavailable. // Attempts to reserve a specific sid. Returns false if it's unavailable.
bool ReserveSid(const StreamId& sid); bool ReserveSid(StreamId sid);
// Indicates that `sid` isn't in use any more, and is thus available again. // Indicates that `sid` isn't in use any more, and is thus available again.
void ReleaseSid(const StreamId& sid); void ReleaseSid(StreamId sid);
private: private:
flat_set<StreamId> used_sids_; flat_set<StreamId> used_sids_ RTC_GUARDED_BY(&sequence_checker_);
RTC_NO_UNIQUE_ADDRESS SequenceChecker sequence_checker_;
}; };
// SctpDataChannel is an implementation of the DataChannelInterface based on // SctpDataChannel is an implementation of the DataChannelInterface based on