diff --git a/pc/BUILD.gn b/pc/BUILD.gn index 7e07843bd8..383e746688 100644 --- a/pc/BUILD.gn +++ b/pc/BUILD.gn @@ -875,6 +875,7 @@ rtc_library("sctp_data_channel") { "../api:priority", "../api:rtc_error", "../api:scoped_refptr", + "../api:sequence_checker", "../api/transport:datagram_transport_interface", "../media:media_channel", "../media:rtc_data_sctp_transport_internal", @@ -887,6 +888,7 @@ rtc_library("sctp_data_channel") { "../rtc_base:threading", "../rtc_base:weak_ptr", "../rtc_base/containers:flat_set", + "../rtc_base/system:no_unique_address", "../rtc_base/system:unused", ] absl_deps = [ diff --git a/pc/data_channel_controller.cc b/pc/data_channel_controller.cc index e98752d43d..7eb1501a59 100644 --- a/pc/data_channel_controller.cc +++ b/pc/data_channel_controller.cc @@ -275,9 +275,10 @@ DataChannelController::InternalCreateSctpDataChannel( // the network thread. (unless there's no transport). Change this so that // the role is checked on the network thread and any network thread related // initialization is done at the same time (to avoid additional hops). - if (pc_->GetSctpSslRole(&role) && !sid_allocator_.AllocateSid(role, &sid)) { - RTC_LOG(LS_ERROR) << "No id can be allocated for the SCTP data channel."; - return nullptr; + if (pc_->GetSctpSslRole(&role)) { + sid = sid_allocator_.AllocateSid(role); + if (!sid.HasValue()) + return nullptr; } // Note that when we get here, the ID may still be invalid. } else if (!sid_allocator_.ReserveSid(sid)) { @@ -325,9 +326,8 @@ void DataChannelController::AllocateSctpSids(rtc::SSLRole role) { std::vector> channels_to_close; for (const auto& channel : sctp_data_channels_) { if (!channel->sid().HasValue()) { - StreamId sid; - if (!sid_allocator_.AllocateSid(role, &sid)) { - RTC_LOG(LS_ERROR) << "Failed to allocate SCTP sid, closing channel."; + StreamId sid = sid_allocator_.AllocateSid(role); + if (!sid.HasValue()) { channels_to_close.push_back(channel); continue; } diff --git a/pc/data_channel_controller.h b/pc/data_channel_controller.h index 2aa8ab1e34..28a7e16eb5 100644 --- a/pc/data_channel_controller.h +++ b/pc/data_channel_controller.h @@ -148,7 +148,7 @@ class DataChannelController : public SctpDataChannelControllerInterface, bool data_channel_transport_ready_to_send_ RTC_GUARDED_BY(signaling_thread()) = false; - SctpSidAllocator sid_allocator_ /* RTC_GUARDED_BY(signaling_thread()) */; + SctpSidAllocator sid_allocator_; std::vector> sctp_data_channels_ RTC_GUARDED_BY(signaling_thread()); bool has_used_data_channels_ RTC_GUARDED_BY(signaling_thread()) = false; diff --git a/pc/data_channel_unittest.cc b/pc/data_channel_unittest.cc index 4eeeac1e24..f92c05cca6 100644 --- a/pc/data_channel_unittest.cc +++ b/pc/data_channel_unittest.cc @@ -653,18 +653,10 @@ class SctpSidAllocatorTest : public ::testing::Test { // Verifies that an even SCTP id is allocated for SSL_CLIENT and an odd id for // SSL_SERVER. TEST_F(SctpSidAllocatorTest, SctpIdAllocationBasedOnRole) { - StreamId id; - EXPECT_TRUE(allocator_.AllocateSid(rtc::SSL_SERVER, &id)); - EXPECT_EQ(1, id.stream_id_int()); - id.reset(); - EXPECT_TRUE(allocator_.AllocateSid(rtc::SSL_CLIENT, &id)); - EXPECT_EQ(0, id.stream_id_int()); - id.reset(); - EXPECT_TRUE(allocator_.AllocateSid(rtc::SSL_SERVER, &id)); - EXPECT_EQ(3, id.stream_id_int()); - id.reset(); - EXPECT_TRUE(allocator_.AllocateSid(rtc::SSL_CLIENT, &id)); - EXPECT_EQ(2, id.stream_id_int()); + EXPECT_EQ(allocator_.AllocateSid(rtc::SSL_SERVER), StreamId(1)); + EXPECT_EQ(allocator_.AllocateSid(rtc::SSL_CLIENT), StreamId(0)); + EXPECT_EQ(allocator_.AllocateSid(rtc::SSL_SERVER), StreamId(3)); + EXPECT_EQ(allocator_.AllocateSid(rtc::SSL_CLIENT), StreamId(2)); } // Verifies that SCTP ids of existing DataChannels are not reused. @@ -672,14 +664,14 @@ TEST_F(SctpSidAllocatorTest, SctpIdAllocationNoReuse) { StreamId old_id(1); EXPECT_TRUE(allocator_.ReserveSid(old_id)); - StreamId new_id; - EXPECT_TRUE(allocator_.AllocateSid(rtc::SSL_SERVER, &new_id)); + StreamId new_id = allocator_.AllocateSid(rtc::SSL_SERVER); + EXPECT_TRUE(new_id.HasValue()); EXPECT_NE(old_id, new_id); old_id = StreamId(0); EXPECT_TRUE(allocator_.ReserveSid(old_id)); - new_id.reset(); - EXPECT_TRUE(allocator_.AllocateSid(rtc::SSL_CLIENT, &new_id)); + new_id = allocator_.AllocateSid(rtc::SSL_CLIENT); + EXPECT_TRUE(new_id.HasValue()); EXPECT_NE(old_id, new_id); } @@ -690,41 +682,33 @@ TEST_F(SctpSidAllocatorTest, SctpIdReusedForRemovedDataChannel) { EXPECT_TRUE(allocator_.ReserveSid(odd_id)); EXPECT_TRUE(allocator_.ReserveSid(even_id)); - StreamId allocated_id; - EXPECT_TRUE(allocator_.AllocateSid(rtc::SSL_SERVER, &allocated_id)); + StreamId allocated_id = allocator_.AllocateSid(rtc::SSL_SERVER); EXPECT_EQ(odd_id.stream_id_int() + 2, allocated_id.stream_id_int()); - allocated_id.reset(); - EXPECT_TRUE(allocator_.AllocateSid(rtc::SSL_CLIENT, &allocated_id)); + allocated_id = allocator_.AllocateSid(rtc::SSL_CLIENT); EXPECT_EQ(even_id.stream_id_int() + 2, allocated_id.stream_id_int()); - allocated_id.reset(); - EXPECT_TRUE(allocator_.AllocateSid(rtc::SSL_SERVER, &allocated_id)); + allocated_id = allocator_.AllocateSid(rtc::SSL_SERVER); EXPECT_EQ(odd_id.stream_id_int() + 4, allocated_id.stream_id_int()); - allocated_id.reset(); - EXPECT_TRUE(allocator_.AllocateSid(rtc::SSL_CLIENT, &allocated_id)); + allocated_id = allocator_.AllocateSid(rtc::SSL_CLIENT); EXPECT_EQ(even_id.stream_id_int() + 4, allocated_id.stream_id_int()); allocator_.ReleaseSid(odd_id); allocator_.ReleaseSid(even_id); // Verifies that removed ids are reused. - allocated_id.reset(); - EXPECT_TRUE(allocator_.AllocateSid(rtc::SSL_SERVER, &allocated_id)); + allocated_id = allocator_.AllocateSid(rtc::SSL_SERVER); EXPECT_EQ(odd_id, allocated_id); - allocated_id.reset(); - EXPECT_TRUE(allocator_.AllocateSid(rtc::SSL_CLIENT, &allocated_id)); + allocated_id = allocator_.AllocateSid(rtc::SSL_CLIENT); EXPECT_EQ(even_id, allocated_id); // Verifies that used higher ids are not reused. - allocated_id.reset(); - EXPECT_TRUE(allocator_.AllocateSid(rtc::SSL_SERVER, &allocated_id)); + allocated_id = allocator_.AllocateSid(rtc::SSL_SERVER); EXPECT_EQ(odd_id.stream_id_int() + 6, allocated_id.stream_id_int()); - allocated_id.reset(); - EXPECT_TRUE(allocator_.AllocateSid(rtc::SSL_CLIENT, &allocated_id)); + allocated_id = allocator_.AllocateSid(rtc::SSL_CLIENT); EXPECT_EQ(even_id.stream_id_int() + 6, allocated_id.stream_id_int()); } diff --git a/pc/sctp_data_channel.cc b/pc/sctp_data_channel.cc index f02895424b..24efae674c 100644 --- a/pc/sctp_data_channel.cc +++ b/pc/sctp_data_channel.cc @@ -117,25 +117,32 @@ bool InternalDataChannelInit::IsValid() const { return true; } -bool SctpSidAllocator::AllocateSid(rtc::SSLRole role, StreamId* sid) { - int potential_sid = (role == rtc::SSL_CLIENT) ? 0 : 1; - while (potential_sid <= static_cast(cricket::kMaxSctpSid)) { - *sid = StreamId(potential_sid); - if (used_sids_.insert(*sid).second) - return true; - potential_sid += 2; - } - sid->reset(); - return false; +SctpSidAllocator::SctpSidAllocator() { + sequence_checker_.Detach(); } -bool SctpSidAllocator::ReserveSid(const StreamId& sid) { +StreamId SctpSidAllocator::AllocateSid(rtc::SSLRole role) { + RTC_DCHECK_RUN_ON(&sequence_checker_); + int potential_sid = (role == rtc::SSL_CLIENT) ? 0 : 1; + while (potential_sid <= static_cast(cricket::kMaxSctpSid)) { + StreamId sid(potential_sid); + if (used_sids_.insert(sid).second) + return sid; + potential_sid += 2; + } + RTC_LOG(LS_ERROR) << "SCTP sid allocation pool exhausted."; + return StreamId(); +} + +bool SctpSidAllocator::ReserveSid(StreamId sid) { + RTC_DCHECK_RUN_ON(&sequence_checker_); if (!sid.HasValue() || sid.stream_id_int() > cricket::kMaxSctpSid) return false; return used_sids_.insert(sid).second; } -void SctpSidAllocator::ReleaseSid(const StreamId& sid) { +void SctpSidAllocator::ReleaseSid(StreamId sid) { + RTC_DCHECK_RUN_ON(&sequence_checker_); used_sids_.erase(sid); } diff --git a/pc/sctp_data_channel.h b/pc/sctp_data_channel.h index e57c5fdee4..e0f8bb24b6 100644 --- a/pc/sctp_data_channel.h +++ b/pc/sctp_data_channel.h @@ -22,12 +22,14 @@ #include "api/priority.h" #include "api/rtc_error.h" #include "api/scoped_refptr.h" +#include "api/sequence_checker.h" #include "api/transport/data_channel_transport_interface.h" #include "pc/data_channel_utils.h" #include "pc/sctp_utils.h" #include "rtc_base/containers/flat_set.h" #include "rtc_base/copy_on_write_buffer.h" #include "rtc_base/ssl_stream_adapter.h" // For SSLRole +#include "rtc_base/system/no_unique_address.h" #include "rtc_base/thread.h" #include "rtc_base/thread_annotations.h" #include "rtc_base/weak_ptr.h" @@ -73,20 +75,22 @@ struct InternalDataChannelInit : public DataChannelInit { // Helper class to allocate unique IDs for SCTP DataChannels. class SctpSidAllocator { public: + SctpSidAllocator(); // Gets the first unused odd/even id based on the DTLS role. If `role` is // SSL_CLIENT, the allocated id starts from 0 and takes even numbers; // otherwise, the id starts from 1 and takes odd numbers. - // Returns false if no ID can be allocated. - bool AllocateSid(rtc::SSLRole role, StreamId* sid); + // If a `StreamId` cannot be allocated, `StreamId::HasValue()` will be false. + StreamId AllocateSid(rtc::SSLRole role); // Attempts to reserve a specific sid. Returns false if it's unavailable. - bool ReserveSid(const StreamId& sid); + bool ReserveSid(StreamId sid); // Indicates that `sid` isn't in use any more, and is thus available again. - void ReleaseSid(const StreamId& sid); + void ReleaseSid(StreamId sid); private: - flat_set used_sids_; + flat_set used_sids_ RTC_GUARDED_BY(&sequence_checker_); + RTC_NO_UNIQUE_ADDRESS SequenceChecker sequence_checker_; }; // SctpDataChannel is an implementation of the DataChannelInterface based on