From 492296cc3c41020ed967e3c9e26468c8463ab420 Mon Sep 17 00:00:00 2001 From: Tommi Date: Sun, 12 Mar 2023 16:59:25 +0100 Subject: [PATCH] Remove the `SctpDataChannel::config_` member variable. Instead there are direct member variables for the various relevant states, some weren't needed, some can be const but the `id` member in particular needs special handling and can't be const. For dealing with the stream id, we now have SctpSid. A class that does range validation, checks thread safety, handles the special `-1` case (for what's essentially an unsigned 16 bit int). Using a special type for this also has the effect that range checking happens more consistently (although I'm not modifying the structs in api/). With upcoming steps of avoiding thread hops, the ID may need to migrate to the network thread, which the thread checks will help with. Along the way, update SctpSidAllocator to use flat_set instead of std::set and moving some of the sctp data channel code to the cc file to help with more accurately tracking code coverage. Bug: webrtc:11547 Change-Id: Iea6e7647ab8f93052044c5afbcc449115206b4e9 Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/296444 Commit-Queue: Tomas Gunnarsson Reviewed-by: Harald Alvestrand Cr-Commit-Position: refs/heads/main@{#39539} --- media/sctp/sctp_transport_internal.h | 2 + pc/BUILD.gn | 6 + pc/DEPS | 1 + pc/data_channel_controller.cc | 28 ++- pc/data_channel_controller_unittest.cc | 51 +++++ pc/data_channel_unittest.cc | 293 ++++++++++++++----------- pc/sctp_data_channel.cc | 152 ++++++++----- pc/sctp_data_channel.h | 64 +++--- pc/sctp_utils.cc | 89 ++++++-- pc/sctp_utils.h | 47 +++- pc/sctp_utils_unittest.cc | 46 ++++ 11 files changed, 529 insertions(+), 250 deletions(-) diff --git a/media/sctp/sctp_transport_internal.h b/media/sctp/sctp_transport_internal.h index 38da554911..fd31176894 100644 --- a/media/sctp/sctp_transport_internal.h +++ b/media/sctp/sctp_transport_internal.h @@ -46,6 +46,8 @@ constexpr int kSctpSendBufferSize = 256 * 1024; constexpr uint16_t kMaxSctpStreams = 1024; constexpr uint16_t kMaxSctpSid = kMaxSctpStreams - 1; constexpr uint16_t kMinSctpSid = 0; +// The maximum number of streams that can be negotiated according to spec. +constexpr uint16_t kSpecMaxSctpSid = 65535; // This is the default SCTP port to use. It is passed along the wire and the // connectee and connector must be using the same port. It is not related to the diff --git a/pc/BUILD.gn b/pc/BUILD.gn index 2776f9f245..43dd114934 100644 --- a/pc/BUILD.gn +++ b/pc/BUILD.gn @@ -517,12 +517,17 @@ rtc_source_set("sctp_utils") { deps = [ "../api:libjingle_peerconnection_api", "../api:priority", + "../api:sequence_checker", "../api/transport:datagram_transport_interface", "../media:media_channel", + "../media:rtc_data_sctp_transport_internal", "../media:rtc_media_base", + "../net/dcsctp/public:types", "../rtc_base:byte_buffer", "../rtc_base:copy_on_write_buffer", "../rtc_base:logging", + "../rtc_base:ssl", + "../rtc_base/system:no_unique_address", ] absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] } @@ -884,6 +889,7 @@ rtc_library("sctp_data_channel") { "../rtc_base:threading", "../rtc_base:threading", "../rtc_base:weak_ptr", + "../rtc_base/containers:flat_set", "../rtc_base/system:unused", "../rtc_base/third_party/sigslot:sigslot", ] diff --git a/pc/DEPS b/pc/DEPS index edb904c712..80a702d716 100644 --- a/pc/DEPS +++ b/pc/DEPS @@ -11,6 +11,7 @@ include_rules = [ "+modules/rtp_rtcp", "+modules/video_coding", "+modules/video_render", + "+net/dcsctp", "+p2p", "+system_wrappers", ] diff --git a/pc/data_channel_controller.cc b/pc/data_channel_controller.cc index 3011c0f5f6..0ebfdfa4cc 100644 --- a/pc/data_channel_controller.cc +++ b/pc/data_channel_controller.cc @@ -277,23 +277,30 @@ DataChannelController::InternalCreateSctpDataChannel( RTC_DCHECK_RUN_ON(signaling_thread()); InternalDataChannelInit new_config = config ? (*config) : InternalDataChannelInit(); - if (new_config.id < 0) { + StreamId sid(new_config.id); + if (!sid.HasValue()) { rtc::SSLRole role; - if ((pc_->GetSctpSslRole(&role)) && - !sid_allocator_.AllocateSid(role, &new_config.id)) { + // TODO(bugs.webrtc.org/11547): `GetSctpSslRole` likely involves a hop to + // the network thread. (unless there's no transport). Change this so that + // the role is checked on the network thread and any network thread related + // initialization is done at the same time (to avoid additional hops). + if (pc_->GetSctpSslRole(&role) && !sid_allocator_.AllocateSid(role, &sid)) { RTC_LOG(LS_ERROR) << "No id can be allocated for the SCTP data channel."; return nullptr; } - } else if (!sid_allocator_.ReserveSid(new_config.id)) { + // Note that when we get here, the ID may still be invalid. + } else if (!sid_allocator_.ReserveSid(sid)) { RTC_LOG(LS_ERROR) << "Failed to create a SCTP data channel " "because the id is already in use or out of range."; return nullptr; } + // In case `sid` has changed. Update `new_config` accordingly. + new_config.id = sid.stream_id_int(); rtc::scoped_refptr channel( SctpDataChannel::Create(weak_factory_.GetWeakPtr(), label, new_config, signaling_thread(), network_thread())); if (!channel) { - sid_allocator_.ReleaseSid(new_config.id); + sid_allocator_.ReleaseSid(sid); return nullptr; } sctp_data_channels_.push_back(channel); @@ -304,13 +311,16 @@ void DataChannelController::AllocateSctpSids(rtc::SSLRole role) { RTC_DCHECK_RUN_ON(signaling_thread()); std::vector> channels_to_close; for (const auto& channel : sctp_data_channels_) { - if (channel->id() < 0) { - int sid; + if (!channel->sid().HasValue()) { + StreamId sid; if (!sid_allocator_.AllocateSid(role, &sid)) { RTC_LOG(LS_ERROR) << "Failed to allocate SCTP sid, closing channel."; channels_to_close.push_back(channel); continue; } + // TODO(bugs.webrtc.org/11547): This hides a blocking call to the network + // thread via AddSctpDataStream. Maybe it's better to move the whole loop + // to the network thread? Maybe even `sctp_data_channels_`? channel->SetSctpSid(sid); } } @@ -326,10 +336,10 @@ void DataChannelController::OnSctpDataChannelClosed(SctpDataChannel* channel) { for (auto it = sctp_data_channels_.begin(); it != sctp_data_channels_.end(); ++it) { if (it->get() == channel) { - if (channel->id() >= 0) { + if (channel->sid().HasValue()) { // After the closing procedure is done, it's safe to use this ID for // another data channel. - sid_allocator_.ReleaseSid(channel->id()); + sid_allocator_.ReleaseSid(channel->sid()); } // Since this method is triggered by a signal from the DataChannel, diff --git a/pc/data_channel_controller_unittest.cc b/pc/data_channel_controller_unittest.cc index 097eed61ed..0d9dd88efd 100644 --- a/pc/data_channel_controller_unittest.cc +++ b/pc/data_channel_controller_unittest.cc @@ -26,12 +26,30 @@ namespace { using ::testing::NiceMock; using ::testing::Return; +class MockDataChannelTransport : public webrtc::DataChannelTransportInterface { + public: + ~MockDataChannelTransport() override {} + + MOCK_METHOD(RTCError, OpenChannel, (int channel_id), (override)); + MOCK_METHOD(RTCError, + SendData, + (int channel_id, + const SendDataParams& params, + const rtc::CopyOnWriteBuffer& buffer), + (override)); + MOCK_METHOD(RTCError, CloseChannel, (int channel_id), (override)); + MOCK_METHOD(void, SetDataSink, (DataChannelSink * sink), (override)); + MOCK_METHOD(bool, IsReadyToSend, (), (const, override)); +}; + class DataChannelControllerTest : public ::testing::Test { protected: DataChannelControllerTest() { pc_ = rtc::make_ref_counted>(); ON_CALL(*pc_, signaling_thread) .WillByDefault(Return(rtc::Thread::Current())); + // TODO(tommi): Return a dedicated thread. + ON_CALL(*pc_, network_thread).WillByDefault(Return(rtc::Thread::Current())); } ~DataChannelControllerTest() override { run_loop_.Flush(); } @@ -116,5 +134,38 @@ TEST_F(DataChannelControllerTest, AsyncChannelCloseTeardown) { rtc::RefCountReleaseStatus::kDroppedLastRef); } +// Allocate the maximum number of data channels and then one more. +// The last allocation should fail. +TEST_F(DataChannelControllerTest, MaxChannels) { + NiceMock transport; + int channel_id = 0; + + ON_CALL(*pc_, GetSctpSslRole).WillByDefault([&](rtc::SSLRole* role) { + *role = (channel_id & 1) ? rtc::SSL_SERVER : rtc::SSL_CLIENT; + return true; + }); + + DataChannelController dcc(pc_.get()); + pc_->network_thread()->BlockingCall( + [&] { dcc.set_data_channel_transport(&transport); }); + + // Allocate the maximum number of channels + 1. Inside the loop, the creation + // process will allocate a stream id for each channel. + for (channel_id = 0; channel_id <= cricket::kMaxSctpStreams; ++channel_id) { + rtc::scoped_refptr channel = + dcc.InternalCreateDataChannelWithProxy( + "label", + std::make_unique(DataChannelInit()).get()); + + if (channel_id == cricket::kMaxSctpStreams) { + // We've reached the maximum and the previous call should have failed. + EXPECT_FALSE(channel.get()); + } else { + // We're still working on saturating the pool. Things should be working. + EXPECT_TRUE(channel.get()); + } + } +} + } // namespace } // namespace webrtc diff --git a/pc/data_channel_unittest.cc b/pc/data_channel_unittest.cc index d9575f0312..973c6943df 100644 --- a/pc/data_channel_unittest.cc +++ b/pc/data_channel_unittest.cc @@ -30,13 +30,13 @@ #include "rtc_base/thread.h" #include "test/gtest.h" -using webrtc::DataChannelInterface; -using webrtc::SctpDataChannel; -using webrtc::SctpSidAllocator; +namespace webrtc { + +namespace { static constexpr int kDefaultTimeout = 10000; -class FakeDataChannelObserver : public webrtc::DataChannelObserver { +class FakeDataChannelObserver : public DataChannelObserver { public: FakeDataChannelObserver() : messages_received_(0), @@ -49,7 +49,7 @@ class FakeDataChannelObserver : public webrtc::DataChannelObserver { ++on_buffered_amount_change_count_; } - void OnMessage(const webrtc::DataBuffer& buffer) { ++messages_received_; } + void OnMessage(const DataBuffer& buffer) { ++messages_received_; } size_t messages_received() const { return messages_received_; } @@ -88,8 +88,8 @@ class SctpDataChannelTest : public ::testing::Test { void SetChannelReady() { controller_->set_transport_available(true); webrtc_data_channel_->OnTransportChannelCreated(); - if (webrtc_data_channel_->id() < 0) { - webrtc_data_channel_->SetSctpSid(0); + if (!webrtc_data_channel_->sid().HasValue()) { + webrtc_data_channel_->SetSctpSid(StreamId(0)); } controller_->set_ready_to_send(true); } @@ -100,12 +100,39 @@ class SctpDataChannelTest : public ::testing::Test { } rtc::AutoThread main_thread_; - webrtc::InternalDataChannelInit init_; + InternalDataChannelInit init_; std::unique_ptr controller_; std::unique_ptr observer_; rtc::scoped_refptr webrtc_data_channel_; }; +TEST_F(SctpDataChannelTest, VerifyConfigurationGetters) { + EXPECT_EQ(webrtc_data_channel_->label(), "test"); + EXPECT_EQ(webrtc_data_channel_->protocol(), init_.protocol); + + // Note that the `init_.reliable` field is deprecated, so we directly set + // it here to match spec behavior for purposes of checking the `reliable()` + // getter. + init_.reliable = (!init_.maxRetransmits && !init_.maxRetransmitTime); + EXPECT_EQ(webrtc_data_channel_->reliable(), init_.reliable); + EXPECT_EQ(webrtc_data_channel_->ordered(), init_.ordered); + EXPECT_EQ(webrtc_data_channel_->negotiated(), init_.negotiated); + EXPECT_EQ(webrtc_data_channel_->priority(), Priority::kLow); + EXPECT_EQ(webrtc_data_channel_->maxRetransmitTime(), + static_cast(-1)); + EXPECT_EQ(webrtc_data_channel_->maxPacketLifeTime(), init_.maxRetransmitTime); + EXPECT_EQ(webrtc_data_channel_->maxRetransmits(), static_cast(-1)); + EXPECT_EQ(webrtc_data_channel_->maxRetransmitsOpt(), init_.maxRetransmits); + + // Check the non-const part of the configuration. + EXPECT_EQ(webrtc_data_channel_->id(), init_.id); + EXPECT_EQ(webrtc_data_channel_->sid(), StreamId()); + + SetChannelReady(); + EXPECT_EQ(webrtc_data_channel_->id(), 0); + EXPECT_EQ(webrtc_data_channel_->sid(), StreamId(0)); +} + // Verifies that the data channel is connected to the transport after creation. TEST_F(SctpDataChannelTest, ConnectedToTransportOnCreated) { controller_->set_transport_available(true); @@ -118,7 +145,7 @@ TEST_F(SctpDataChannelTest, ConnectedToTransportOnCreated) { EXPECT_FALSE(controller_->IsSendStreamAdded(dc->id())); EXPECT_FALSE(controller_->IsRecvStreamAdded(dc->id())); - dc->SetSctpSid(0); + dc->SetSctpSid(StreamId(0)); EXPECT_TRUE(controller_->IsSendStreamAdded(dc->id())); EXPECT_TRUE(controller_->IsRecvStreamAdded(dc->id())); } @@ -135,19 +162,17 @@ TEST_F(SctpDataChannelTest, ConnectedAfterTransportBecomesAvailable) { // Tests the state of the data channel. TEST_F(SctpDataChannelTest, StateTransition) { - EXPECT_EQ(webrtc::DataChannelInterface::kConnecting, - webrtc_data_channel_->state()); + EXPECT_EQ(DataChannelInterface::kConnecting, webrtc_data_channel_->state()); EXPECT_EQ(controller_->channels_opened(), 0); EXPECT_EQ(controller_->channels_closed(), 0); SetChannelReady(); - EXPECT_EQ(webrtc::DataChannelInterface::kOpen, webrtc_data_channel_->state()); + EXPECT_EQ(DataChannelInterface::kOpen, webrtc_data_channel_->state()); EXPECT_EQ(controller_->channels_opened(), 1); EXPECT_EQ(controller_->channels_closed(), 0); webrtc_data_channel_->Close(); - EXPECT_EQ(webrtc::DataChannelInterface::kClosed, - webrtc_data_channel_->state()); + EXPECT_EQ(DataChannelInterface::kClosed, webrtc_data_channel_->state()); EXPECT_TRUE(webrtc_data_channel_->error().ok()); EXPECT_EQ(controller_->channels_opened(), 1); EXPECT_EQ(controller_->channels_closed(), 1); @@ -160,7 +185,7 @@ TEST_F(SctpDataChannelTest, StateTransition) { TEST_F(SctpDataChannelTest, BufferedAmountWhenBlocked) { AddObserver(); SetChannelReady(); - webrtc::DataBuffer buffer("abcd"); + DataBuffer buffer("abcd"); EXPECT_TRUE(webrtc_data_channel_->Send(buffer)); size_t successful_send_count = 1; @@ -191,7 +216,7 @@ TEST_F(SctpDataChannelTest, BufferedAmountWhenBlocked) { TEST_F(SctpDataChannelTest, QueuedDataSentWhenUnblocked) { AddObserver(); SetChannelReady(); - webrtc::DataBuffer buffer("abcd"); + DataBuffer buffer("abcd"); controller_->set_send_blocked(true); EXPECT_TRUE(webrtc_data_channel_->Send(buffer)); @@ -208,7 +233,7 @@ TEST_F(SctpDataChannelTest, QueuedDataSentWhenUnblocked) { TEST_F(SctpDataChannelTest, BlockedWhenSendQueuedDataNoCrash) { AddObserver(); SetChannelReady(); - webrtc::DataBuffer buffer("abcd"); + DataBuffer buffer("abcd"); controller_->set_send_blocked(true); EXPECT_TRUE(webrtc_data_channel_->Send(buffer)); EXPECT_EQ(0U, observer_->on_buffered_amount_change_count()); @@ -230,13 +255,13 @@ TEST_F(SctpDataChannelTest, BlockedWhenSendQueuedDataNoCrash) { TEST_F(SctpDataChannelTest, VerifyMessagesAndBytesSent) { AddObserver(); SetChannelReady(); - std::vector buffers({ - webrtc::DataBuffer("message 1"), - webrtc::DataBuffer("msg 2"), - webrtc::DataBuffer("message three"), - webrtc::DataBuffer("quadra message"), - webrtc::DataBuffer("fifthmsg"), - webrtc::DataBuffer("message of the beast"), + std::vector buffers({ + DataBuffer("message 1"), + DataBuffer("msg 2"), + DataBuffer("message three"), + DataBuffer("quadra message"), + DataBuffer("fifthmsg"), + DataBuffer("message of the beast"), }); // Default values. @@ -279,7 +304,7 @@ TEST_F(SctpDataChannelTest, OpenMessageSent) { SetChannelReady(); EXPECT_GE(webrtc_data_channel_->id(), 0); - EXPECT_EQ(webrtc::DataMessageType::kControl, + EXPECT_EQ(DataMessageType::kControl, controller_->last_send_data_params().type); EXPECT_EQ(controller_->last_sid(), webrtc_data_channel_->id()); } @@ -289,7 +314,7 @@ TEST_F(SctpDataChannelTest, QueuedOpenMessageSent) { SetChannelReady(); controller_->set_send_blocked(false); - EXPECT_EQ(webrtc::DataMessageType::kControl, + EXPECT_EQ(DataMessageType::kControl, controller_->last_send_data_params().type); EXPECT_EQ(controller_->last_sid(), webrtc_data_channel_->id()); } @@ -298,39 +323,39 @@ TEST_F(SctpDataChannelTest, QueuedOpenMessageSent) { // state. TEST_F(SctpDataChannelTest, LateCreatedChannelTransitionToOpen) { SetChannelReady(); - webrtc::InternalDataChannelInit init; + InternalDataChannelInit init; init.id = 1; rtc::scoped_refptr dc = SctpDataChannel::Create(controller_->weak_ptr(), "test1", init, rtc::Thread::Current(), rtc::Thread::Current()); - EXPECT_EQ(webrtc::DataChannelInterface::kConnecting, dc->state()); - EXPECT_TRUE_WAIT(webrtc::DataChannelInterface::kOpen == dc->state(), 1000); + EXPECT_EQ(DataChannelInterface::kConnecting, dc->state()); + EXPECT_TRUE_WAIT(DataChannelInterface::kOpen == dc->state(), 1000); } // Tests that an unordered DataChannel sends data as ordered until the OPEN_ACK // message is received. TEST_F(SctpDataChannelTest, SendUnorderedAfterReceivesOpenAck) { SetChannelReady(); - webrtc::InternalDataChannelInit init; + InternalDataChannelInit init; init.id = 1; init.ordered = false; rtc::scoped_refptr dc = SctpDataChannel::Create(controller_->weak_ptr(), "test1", init, rtc::Thread::Current(), rtc::Thread::Current()); - EXPECT_EQ_WAIT(webrtc::DataChannelInterface::kOpen, dc->state(), 1000); + EXPECT_EQ_WAIT(DataChannelInterface::kOpen, dc->state(), 1000); // Sends a message and verifies it's ordered. - webrtc::DataBuffer buffer("some data"); + DataBuffer buffer("some data"); ASSERT_TRUE(dc->Send(buffer)); EXPECT_TRUE(controller_->last_send_data_params().ordered); // Emulates receiving an OPEN_ACK message. cricket::ReceiveDataParams params; params.sid = init.id; - params.type = webrtc::DataMessageType::kControl; + params.type = DataMessageType::kControl; rtc::CopyOnWriteBuffer payload; - webrtc::WriteDataChannelOpenAckMessage(&payload); + WriteDataChannelOpenAckMessage(&payload); dc->OnDataReceived(params, payload); // Sends another message and verifies it's unordered. @@ -342,20 +367,20 @@ TEST_F(SctpDataChannelTest, SendUnorderedAfterReceivesOpenAck) { // message is received. TEST_F(SctpDataChannelTest, SendUnorderedAfterReceiveData) { SetChannelReady(); - webrtc::InternalDataChannelInit init; + InternalDataChannelInit init; init.id = 1; init.ordered = false; rtc::scoped_refptr dc = SctpDataChannel::Create(controller_->weak_ptr(), "test1", init, rtc::Thread::Current(), rtc::Thread::Current()); - EXPECT_EQ_WAIT(webrtc::DataChannelInterface::kOpen, dc->state(), 1000); + EXPECT_EQ_WAIT(DataChannelInterface::kOpen, dc->state(), 1000); // Emulates receiving a DATA message. cricket::ReceiveDataParams params; params.sid = init.id; - params.type = webrtc::DataMessageType::kText; - webrtc::DataBuffer buffer("data"); + params.type = DataMessageType::kText; + DataBuffer buffer("data"); dc->OnDataReceived(params, buffer.data); // Sends a message and verifies it's unordered. @@ -366,60 +391,57 @@ TEST_F(SctpDataChannelTest, SendUnorderedAfterReceiveData) { // Tests that the channel can't open until it's successfully sent the OPEN // message. TEST_F(SctpDataChannelTest, OpenWaitsForOpenMesssage) { - webrtc::DataBuffer buffer("foo"); + DataBuffer buffer("foo"); controller_->set_send_blocked(true); SetChannelReady(); - EXPECT_EQ(webrtc::DataChannelInterface::kConnecting, - webrtc_data_channel_->state()); + EXPECT_EQ(DataChannelInterface::kConnecting, webrtc_data_channel_->state()); controller_->set_send_blocked(false); - EXPECT_EQ_WAIT(webrtc::DataChannelInterface::kOpen, - webrtc_data_channel_->state(), 1000); - EXPECT_EQ(webrtc::DataMessageType::kControl, + EXPECT_EQ_WAIT(DataChannelInterface::kOpen, webrtc_data_channel_->state(), + 1000); + EXPECT_EQ(DataMessageType::kControl, controller_->last_send_data_params().type); } // Tests that close first makes sure all queued data gets sent. TEST_F(SctpDataChannelTest, QueuedCloseFlushes) { - webrtc::DataBuffer buffer("foo"); + DataBuffer buffer("foo"); controller_->set_send_blocked(true); SetChannelReady(); - EXPECT_EQ(webrtc::DataChannelInterface::kConnecting, - webrtc_data_channel_->state()); + EXPECT_EQ(DataChannelInterface::kConnecting, webrtc_data_channel_->state()); controller_->set_send_blocked(false); - EXPECT_EQ_WAIT(webrtc::DataChannelInterface::kOpen, - webrtc_data_channel_->state(), 1000); + EXPECT_EQ_WAIT(DataChannelInterface::kOpen, webrtc_data_channel_->state(), + 1000); controller_->set_send_blocked(true); webrtc_data_channel_->Send(buffer); webrtc_data_channel_->Close(); controller_->set_send_blocked(false); - EXPECT_EQ_WAIT(webrtc::DataChannelInterface::kClosed, - webrtc_data_channel_->state(), 1000); + EXPECT_EQ_WAIT(DataChannelInterface::kClosed, webrtc_data_channel_->state(), + 1000); EXPECT_TRUE(webrtc_data_channel_->error().ok()); - EXPECT_EQ(webrtc::DataMessageType::kText, - controller_->last_send_data_params().type); + EXPECT_EQ(DataMessageType::kText, controller_->last_send_data_params().type); } // Tests that messages are sent with the right id. TEST_F(SctpDataChannelTest, SendDataId) { - webrtc_data_channel_->SetSctpSid(1); + webrtc_data_channel_->SetSctpSid(StreamId(1)); SetChannelReady(); - webrtc::DataBuffer buffer("data"); + DataBuffer buffer("data"); EXPECT_TRUE(webrtc_data_channel_->Send(buffer)); EXPECT_EQ(1, controller_->last_sid()); } // Tests that the incoming messages with wrong ids are rejected. TEST_F(SctpDataChannelTest, ReceiveDataWithInvalidId) { - webrtc_data_channel_->SetSctpSid(1); + webrtc_data_channel_->SetSctpSid(StreamId(1)); SetChannelReady(); AddObserver(); cricket::ReceiveDataParams params; params.sid = 0; - webrtc::DataBuffer buffer("abcd"); + DataBuffer buffer("abcd"); webrtc_data_channel_->OnDataReceived(params, buffer.data); EXPECT_EQ(0U, observer_->messages_received()); @@ -427,14 +449,14 @@ TEST_F(SctpDataChannelTest, ReceiveDataWithInvalidId) { // Tests that the incoming messages with right ids are accepted. TEST_F(SctpDataChannelTest, ReceiveDataWithValidId) { - webrtc_data_channel_->SetSctpSid(1); + webrtc_data_channel_->SetSctpSid(StreamId(1)); SetChannelReady(); AddObserver(); cricket::ReceiveDataParams params; params.sid = 1; - webrtc::DataBuffer buffer("abcd"); + DataBuffer buffer("abcd"); webrtc_data_channel_->OnDataReceived(params, buffer.data); EXPECT_EQ(1U, observer_->messages_received()); @@ -443,17 +465,17 @@ TEST_F(SctpDataChannelTest, ReceiveDataWithValidId) { // Tests that no CONTROL message is sent if the datachannel is negotiated and // not created from an OPEN message. TEST_F(SctpDataChannelTest, NoMsgSentIfNegotiatedAndNotFromOpenMsg) { - webrtc::InternalDataChannelInit config; + InternalDataChannelInit config; config.id = 1; config.negotiated = true; - config.open_handshake_role = webrtc::InternalDataChannelInit::kNone; + config.open_handshake_role = InternalDataChannelInit::kNone; SetChannelReady(); rtc::scoped_refptr dc = SctpDataChannel::Create(controller_->weak_ptr(), "test1", config, rtc::Thread::Current(), rtc::Thread::Current()); - EXPECT_EQ_WAIT(webrtc::DataChannelInterface::kOpen, dc->state(), 1000); + EXPECT_EQ_WAIT(DataChannelInterface::kOpen, dc->state(), 1000); EXPECT_EQ(0, controller_->last_sid()); } @@ -461,16 +483,16 @@ TEST_F(SctpDataChannelTest, NoMsgSentIfNegotiatedAndNotFromOpenMsg) { // are correct, receiving data both while not open and while open. TEST_F(SctpDataChannelTest, VerifyMessagesAndBytesReceived) { AddObserver(); - std::vector buffers({ - webrtc::DataBuffer("message 1"), - webrtc::DataBuffer("msg 2"), - webrtc::DataBuffer("message three"), - webrtc::DataBuffer("quadra message"), - webrtc::DataBuffer("fifthmsg"), - webrtc::DataBuffer("message of the beast"), + std::vector buffers({ + DataBuffer("message 1"), + DataBuffer("msg 2"), + DataBuffer("message three"), + DataBuffer("quadra message"), + DataBuffer("fifthmsg"), + DataBuffer("message of the beast"), }); - webrtc_data_channel_->SetSctpSid(1); + webrtc_data_channel_->SetSctpSid(StreamId(1)); cricket::ReceiveDataParams params; params.sid = 1; @@ -507,33 +529,33 @@ TEST_F(SctpDataChannelTest, VerifyMessagesAndBytesReceived) { // Tests that OPEN_ACK message is sent if the datachannel is created from an // OPEN message. TEST_F(SctpDataChannelTest, OpenAckSentIfCreatedFromOpenMessage) { - webrtc::InternalDataChannelInit config; + InternalDataChannelInit config; config.id = 1; config.negotiated = true; - config.open_handshake_role = webrtc::InternalDataChannelInit::kAcker; + config.open_handshake_role = InternalDataChannelInit::kAcker; SetChannelReady(); rtc::scoped_refptr dc = SctpDataChannel::Create(controller_->weak_ptr(), "test1", config, rtc::Thread::Current(), rtc::Thread::Current()); - EXPECT_EQ_WAIT(webrtc::DataChannelInterface::kOpen, dc->state(), 1000); + EXPECT_EQ_WAIT(DataChannelInterface::kOpen, dc->state(), 1000); EXPECT_EQ(config.id, controller_->last_sid()); - EXPECT_EQ(webrtc::DataMessageType::kControl, + EXPECT_EQ(DataMessageType::kControl, controller_->last_send_data_params().type); } // Tests the OPEN_ACK role assigned by InternalDataChannelInit. TEST_F(SctpDataChannelTest, OpenAckRoleInitialization) { - webrtc::InternalDataChannelInit init; - EXPECT_EQ(webrtc::InternalDataChannelInit::kOpener, init.open_handshake_role); + InternalDataChannelInit init; + EXPECT_EQ(InternalDataChannelInit::kOpener, init.open_handshake_role); EXPECT_FALSE(init.negotiated); - webrtc::DataChannelInit base; + DataChannelInit base; base.negotiated = true; - webrtc::InternalDataChannelInit init2(base); - EXPECT_EQ(webrtc::InternalDataChannelInit::kNone, init2.open_handshake_role); + InternalDataChannelInit init2(base); + EXPECT_EQ(InternalDataChannelInit::kNone, init2.open_handshake_role); } // Tests that that Send() returns false if the sending buffer is full @@ -546,35 +568,32 @@ TEST_F(SctpDataChannelTest, OpenWhenSendBufferFull) { rtc::CopyOnWriteBuffer buffer(packetSize); memset(buffer.MutableData(), 0, buffer.size()); - webrtc::DataBuffer packet(buffer, true); + DataBuffer packet(buffer, true); controller_->set_send_blocked(true); - for (size_t i = 0; - i < webrtc::DataChannelInterface::MaxSendQueueSize() / packetSize; ++i) { + for (size_t i = 0; i < DataChannelInterface::MaxSendQueueSize() / packetSize; + ++i) { EXPECT_TRUE(webrtc_data_channel_->Send(packet)); } // The sending buffer shoul be full, send returns false. EXPECT_FALSE(webrtc_data_channel_->Send(packet)); - EXPECT_TRUE(webrtc::DataChannelInterface::kOpen == - webrtc_data_channel_->state()); + EXPECT_TRUE(DataChannelInterface::kOpen == webrtc_data_channel_->state()); } // Tests that the DataChannel is closed on transport errors. TEST_F(SctpDataChannelTest, ClosedOnTransportError) { SetChannelReady(); - webrtc::DataBuffer buffer("abcd"); + DataBuffer buffer("abcd"); controller_->set_transport_error(); EXPECT_TRUE(webrtc_data_channel_->Send(buffer)); - EXPECT_EQ(webrtc::DataChannelInterface::kClosed, - webrtc_data_channel_->state()); + EXPECT_EQ(DataChannelInterface::kClosed, webrtc_data_channel_->state()); EXPECT_FALSE(webrtc_data_channel_->error().ok()); - EXPECT_EQ(webrtc::RTCErrorType::NETWORK_ERROR, - webrtc_data_channel_->error().type()); - EXPECT_EQ(webrtc::RTCErrorDetailType::NONE, + EXPECT_EQ(RTCErrorType::NETWORK_ERROR, webrtc_data_channel_->error().type()); + EXPECT_EQ(RTCErrorDetailType::NONE, webrtc_data_channel_->error().error_detail()); } @@ -591,24 +610,23 @@ TEST_F(SctpDataChannelTest, ClosedWhenReceivedBufferFull) { for (size_t i = 0; i < 16 * 1024 + 1; ++i) { webrtc_data_channel_->OnDataReceived(params, buffer); } - EXPECT_EQ(webrtc::DataChannelInterface::kClosed, - webrtc_data_channel_->state()); + EXPECT_EQ(DataChannelInterface::kClosed, webrtc_data_channel_->state()); EXPECT_FALSE(webrtc_data_channel_->error().ok()); - EXPECT_EQ(webrtc::RTCErrorType::RESOURCE_EXHAUSTED, + EXPECT_EQ(RTCErrorType::RESOURCE_EXHAUSTED, webrtc_data_channel_->error().type()); - EXPECT_EQ(webrtc::RTCErrorDetailType::NONE, + EXPECT_EQ(RTCErrorDetailType::NONE, webrtc_data_channel_->error().error_detail()); } // Tests that sending empty data returns no error and keeps the channel open. TEST_F(SctpDataChannelTest, SendEmptyData) { - webrtc_data_channel_->SetSctpSid(1); + webrtc_data_channel_->SetSctpSid(StreamId(1)); SetChannelReady(); - EXPECT_EQ(webrtc::DataChannelInterface::kOpen, webrtc_data_channel_->state()); + EXPECT_EQ(DataChannelInterface::kOpen, webrtc_data_channel_->state()); - webrtc::DataBuffer buffer(""); + DataBuffer buffer(""); EXPECT_TRUE(webrtc_data_channel_->Send(buffer)); - EXPECT_EQ(webrtc::DataChannelInterface::kOpen, webrtc_data_channel_->state()); + EXPECT_EQ(DataChannelInterface::kOpen, webrtc_data_channel_->state()); } // Tests that a channel can be closed without being opened or assigned an sid. @@ -623,8 +641,7 @@ TEST_F(SctpDataChannelTest, NeverOpened) { // See also chromium:1421534. TEST_F(SctpDataChannelTest, UnusedTransitionsDirectlyToClosed) { webrtc_data_channel_->Close(); - EXPECT_EQ(webrtc::DataChannelInterface::kClosed, - webrtc_data_channel_->state()); + EXPECT_EQ(DataChannelInterface::kClosed, webrtc_data_channel_->state()); } // Test that the data channel goes to the "closed" state (and doesn't crash) @@ -634,7 +651,7 @@ TEST_F(SctpDataChannelTest, TransportDestroyedWhileDataBuffered) { rtc::CopyOnWriteBuffer buffer(1024); memset(buffer.MutableData(), 0, buffer.size()); - webrtc::DataBuffer packet(buffer, true); + DataBuffer packet(buffer, true); // Send a packet while sending is blocked so it ends up buffered. controller_->set_send_blocked(true); @@ -643,16 +660,16 @@ TEST_F(SctpDataChannelTest, TransportDestroyedWhileDataBuffered) { // Tell the data channel that its transport is being destroyed. // It should then stop using the transport (allowing us to delete it) and // transition to the "closed" state. - webrtc::RTCError error(webrtc::RTCErrorType::OPERATION_ERROR_WITH_DATA, ""); - error.set_error_detail(webrtc::RTCErrorDetailType::SCTP_FAILURE); + RTCError error(RTCErrorType::OPERATION_ERROR_WITH_DATA, ""); + error.set_error_detail(RTCErrorDetailType::SCTP_FAILURE); webrtc_data_channel_->OnTransportChannelClosed(error); controller_.reset(nullptr); - EXPECT_EQ_WAIT(webrtc::DataChannelInterface::kClosed, - webrtc_data_channel_->state(), kDefaultTimeout); + EXPECT_EQ_WAIT(DataChannelInterface::kClosed, webrtc_data_channel_->state(), + kDefaultTimeout); EXPECT_FALSE(webrtc_data_channel_->error().ok()); - EXPECT_EQ(webrtc::RTCErrorType::OPERATION_ERROR_WITH_DATA, + EXPECT_EQ(RTCErrorType::OPERATION_ERROR_WITH_DATA, webrtc_data_channel_->error().type()); - EXPECT_EQ(webrtc::RTCErrorDetailType::SCTP_FAILURE, + EXPECT_EQ(RTCErrorDetailType::SCTP_FAILURE, webrtc_data_channel_->error().error_detail()); } @@ -662,19 +679,19 @@ TEST_F(SctpDataChannelTest, TransportGotErrorCode) { // Tell the data channel that its transport is being destroyed with an // error code. // It should then report that error code. - webrtc::RTCError error(webrtc::RTCErrorType::OPERATION_ERROR_WITH_DATA, - "Transport channel closed"); - error.set_error_detail(webrtc::RTCErrorDetailType::SCTP_FAILURE); + RTCError error(RTCErrorType::OPERATION_ERROR_WITH_DATA, + "Transport channel closed"); + error.set_error_detail(RTCErrorDetailType::SCTP_FAILURE); error.set_sctp_cause_code( static_cast(cricket::SctpErrorCauseCode::kProtocolViolation)); webrtc_data_channel_->OnTransportChannelClosed(error); controller_.reset(nullptr); - EXPECT_EQ_WAIT(webrtc::DataChannelInterface::kClosed, - webrtc_data_channel_->state(), kDefaultTimeout); + EXPECT_EQ_WAIT(DataChannelInterface::kClosed, webrtc_data_channel_->state(), + kDefaultTimeout); EXPECT_FALSE(webrtc_data_channel_->error().ok()); - EXPECT_EQ(webrtc::RTCErrorType::OPERATION_ERROR_WITH_DATA, + EXPECT_EQ(RTCErrorType::OPERATION_ERROR_WITH_DATA, webrtc_data_channel_->error().type()); - EXPECT_EQ(webrtc::RTCErrorDetailType::SCTP_FAILURE, + EXPECT_EQ(RTCErrorDetailType::SCTP_FAILURE, webrtc_data_channel_->error().error_detail()); EXPECT_EQ( static_cast(cricket::SctpErrorCauseCode::kProtocolViolation), @@ -689,66 +706,80 @@ class SctpSidAllocatorTest : public ::testing::Test { // Verifies that an even SCTP id is allocated for SSL_CLIENT and an odd id for // SSL_SERVER. TEST_F(SctpSidAllocatorTest, SctpIdAllocationBasedOnRole) { - int id; + StreamId id; EXPECT_TRUE(allocator_.AllocateSid(rtc::SSL_SERVER, &id)); - EXPECT_EQ(1, id); + EXPECT_EQ(1, id.stream_id_int()); + id.reset(); EXPECT_TRUE(allocator_.AllocateSid(rtc::SSL_CLIENT, &id)); - EXPECT_EQ(0, id); + EXPECT_EQ(0, id.stream_id_int()); + id.reset(); EXPECT_TRUE(allocator_.AllocateSid(rtc::SSL_SERVER, &id)); - EXPECT_EQ(3, id); + EXPECT_EQ(3, id.stream_id_int()); + id.reset(); EXPECT_TRUE(allocator_.AllocateSid(rtc::SSL_CLIENT, &id)); - EXPECT_EQ(2, id); + EXPECT_EQ(2, id.stream_id_int()); } // Verifies that SCTP ids of existing DataChannels are not reused. TEST_F(SctpSidAllocatorTest, SctpIdAllocationNoReuse) { - int old_id = 1; + StreamId old_id(1); EXPECT_TRUE(allocator_.ReserveSid(old_id)); - int new_id; + StreamId new_id; EXPECT_TRUE(allocator_.AllocateSid(rtc::SSL_SERVER, &new_id)); EXPECT_NE(old_id, new_id); - old_id = 0; + old_id = StreamId(0); EXPECT_TRUE(allocator_.ReserveSid(old_id)); + new_id.reset(); EXPECT_TRUE(allocator_.AllocateSid(rtc::SSL_CLIENT, &new_id)); EXPECT_NE(old_id, new_id); } // Verifies that SCTP ids of removed DataChannels can be reused. TEST_F(SctpSidAllocatorTest, SctpIdReusedForRemovedDataChannel) { - int odd_id = 1; - int even_id = 0; + StreamId odd_id(1); + StreamId even_id(0); EXPECT_TRUE(allocator_.ReserveSid(odd_id)); EXPECT_TRUE(allocator_.ReserveSid(even_id)); - int allocated_id = -1; + StreamId allocated_id; EXPECT_TRUE(allocator_.AllocateSid(rtc::SSL_SERVER, &allocated_id)); - EXPECT_EQ(odd_id + 2, allocated_id); + EXPECT_EQ(odd_id.stream_id_int() + 2, allocated_id.stream_id_int()); + allocated_id.reset(); EXPECT_TRUE(allocator_.AllocateSid(rtc::SSL_CLIENT, &allocated_id)); - EXPECT_EQ(even_id + 2, allocated_id); + EXPECT_EQ(even_id.stream_id_int() + 2, allocated_id.stream_id_int()); + allocated_id.reset(); EXPECT_TRUE(allocator_.AllocateSid(rtc::SSL_SERVER, &allocated_id)); - EXPECT_EQ(odd_id + 4, allocated_id); + EXPECT_EQ(odd_id.stream_id_int() + 4, allocated_id.stream_id_int()); + allocated_id.reset(); EXPECT_TRUE(allocator_.AllocateSid(rtc::SSL_CLIENT, &allocated_id)); - EXPECT_EQ(even_id + 4, allocated_id); + EXPECT_EQ(even_id.stream_id_int() + 4, allocated_id.stream_id_int()); allocator_.ReleaseSid(odd_id); allocator_.ReleaseSid(even_id); // Verifies that removed ids are reused. + allocated_id.reset(); EXPECT_TRUE(allocator_.AllocateSid(rtc::SSL_SERVER, &allocated_id)); EXPECT_EQ(odd_id, allocated_id); + allocated_id.reset(); EXPECT_TRUE(allocator_.AllocateSid(rtc::SSL_CLIENT, &allocated_id)); EXPECT_EQ(even_id, allocated_id); // Verifies that used higher ids are not reused. + allocated_id.reset(); EXPECT_TRUE(allocator_.AllocateSid(rtc::SSL_SERVER, &allocated_id)); - EXPECT_EQ(odd_id + 6, allocated_id); + EXPECT_EQ(odd_id.stream_id_int() + 6, allocated_id.stream_id_int()); + allocated_id.reset(); EXPECT_TRUE(allocator_.AllocateSid(rtc::SSL_CLIENT, &allocated_id)); - EXPECT_EQ(even_id + 6, allocated_id); + EXPECT_EQ(even_id.stream_id_int() + 6, allocated_id.stream_id_int()); } + +} // namespace +} // namespace webrtc diff --git a/pc/sctp_data_channel.cc b/pc/sctp_data_channel.cc index d41baafc9d..81668dd4bc 100644 --- a/pc/sctp_data_channel.cc +++ b/pc/sctp_data_channel.cc @@ -18,7 +18,6 @@ #include "absl/cleanup/cleanup.h" #include "media/sctp/sctp_transport_internal.h" #include "pc/proxy.h" -#include "pc/sctp_utils.h" #include "rtc_base/checks.h" #include "rtc_base/logging.h" #include "rtc_base/system/unused.h" @@ -118,41 +117,26 @@ bool InternalDataChannelInit::IsValid() const { return true; } -bool SctpSidAllocator::AllocateSid(rtc::SSLRole role, int* sid) { +bool SctpSidAllocator::AllocateSid(rtc::SSLRole role, StreamId* sid) { int potential_sid = (role == rtc::SSL_CLIENT) ? 0 : 1; - while (!IsSidAvailable(potential_sid)) { + while (potential_sid <= static_cast(cricket::kMaxSctpSid)) { + *sid = StreamId(potential_sid); + if (used_sids_.insert(*sid).second) + return true; potential_sid += 2; - if (potential_sid > static_cast(cricket::kMaxSctpSid)) { - return false; - } } - - *sid = potential_sid; - used_sids_.insert(potential_sid); - return true; + sid->reset(); + return false; } -bool SctpSidAllocator::ReserveSid(int sid) { - if (!IsSidAvailable(sid)) { +bool SctpSidAllocator::ReserveSid(const StreamId& sid) { + if (!sid.HasValue() || sid.stream_id_int() > cricket::kMaxSctpSid) return false; - } - used_sids_.insert(sid); - return true; + return used_sids_.insert(sid).second; } -void SctpSidAllocator::ReleaseSid(int sid) { - auto it = used_sids_.find(sid); - if (it != used_sids_.end()) { - used_sids_.erase(it); - } -} - -bool SctpSidAllocator::IsSidAvailable(int sid) const { - if (sid < static_cast(cricket::kMinSctpSid) || - sid > static_cast(cricket::kMaxSctpSid)) { - return false; - } - return used_sids_.find(sid) == used_sids_.end(); +void SctpSidAllocator::ReleaseSid(const StreamId& sid) { + used_sids_.erase(sid); } // static @@ -192,16 +176,22 @@ SctpDataChannel::SctpDataChannel( rtc::Thread* network_thread) : signaling_thread_(signaling_thread), network_thread_(network_thread), + id_(config.id), internal_id_(GenerateUniqueId()), label_(label), - config_(config), + protocol_(config.protocol), + max_retransmit_time_(config.maxRetransmitTime), + max_retransmits_(config.maxRetransmits), + priority_(config.priority), + negotiated_(config.negotiated), + ordered_(config.ordered), observer_(nullptr), controller_(std::move(controller)) { RTC_DCHECK_RUN_ON(signaling_thread_); RTC_UNUSED(network_thread_); - RTC_DCHECK(config_.IsValid()); + RTC_DCHECK(config.IsValid()); - switch (config_.open_handshake_role) { + switch (config.open_handshake_role) { case webrtc::InternalDataChannelInit::kNone: // pre-negotiated handshake_state_ = kHandshakeReady; break; @@ -252,9 +242,50 @@ void SctpDataChannel::UnregisterObserver() { observer_ = nullptr; } +std::string SctpDataChannel::label() const { + return label_; +} + bool SctpDataChannel::reliable() const { // May be called on any thread. - return !config_.maxRetransmits && !config_.maxRetransmitTime; + return !max_retransmits_ && !max_retransmit_time_; +} + +bool SctpDataChannel::ordered() const { + return ordered_; +} + +uint16_t SctpDataChannel::maxRetransmitTime() const { + return max_retransmit_time_ ? *max_retransmit_time_ + : static_cast(-1); +} + +uint16_t SctpDataChannel::maxRetransmits() const { + return max_retransmits_ ? *max_retransmits_ : static_cast(-1); +} + +absl::optional SctpDataChannel::maxPacketLifeTime() const { + return max_retransmit_time_; +} + +absl::optional SctpDataChannel::maxRetransmitsOpt() const { + return max_retransmits_; +} + +std::string SctpDataChannel::protocol() const { + return protocol_; +} + +bool SctpDataChannel::negotiated() const { + return negotiated_; +} + +int SctpDataChannel::id() const { + return id_.stream_id_int(); +} + +Priority SctpDataChannel::priority() const { + return priority_ ? *priority_ : Priority::kLow; } uint64_t SctpDataChannel::buffered_amount() const { @@ -327,24 +358,24 @@ bool SctpDataChannel::Send(const DataBuffer& buffer) { return true; } -void SctpDataChannel::SetSctpSid(int sid) { +void SctpDataChannel::SetSctpSid(const StreamId& sid) { RTC_DCHECK_RUN_ON(signaling_thread_); - RTC_DCHECK_LT(config_.id, 0); - RTC_DCHECK_GE(sid, 0); + RTC_DCHECK(!id_.HasValue()); + RTC_DCHECK(sid.HasValue()); RTC_DCHECK_NE(handshake_state_, kHandshakeWaitingForAck); RTC_DCHECK_EQ(state_, kConnecting); - if (config_.id == sid) { + if (id_ == sid) { return; } - const_cast(config_).id = sid; - controller_->AddSctpDataStream(sid); + id_ = sid; + controller_->AddSctpDataStream(sid.stream_id_int()); } void SctpDataChannel::OnClosingProcedureStartedRemotely(int sid) { RTC_DCHECK_RUN_ON(signaling_thread_); - if (sid == config_.id && state_ != kClosing && state_ != kClosed) { + if (id_.stream_id_int() == sid && state_ != kClosing && state_ != kClosed) { // Don't bother sending queued data since the side that initiated the // closure wouldn't receive it anyway. See crbug.com/559394 for a lengthy // discussion about this. @@ -360,7 +391,7 @@ void SctpDataChannel::OnClosingProcedureStartedRemotely(int sid) { void SctpDataChannel::OnClosingProcedureComplete(int sid) { RTC_DCHECK_RUN_ON(signaling_thread_); - if (sid == config_.id) { + if (id_.stream_id_int() == sid) { // If the closing procedure is complete, we should have finished sending // all pending data and transitioned to kClosing already. RTC_DCHECK_EQ(state_, kClosing); @@ -380,13 +411,13 @@ void SctpDataChannel::OnTransportChannelCreated() { } // The sid may have been unassigned when controller_->ConnectDataChannel was // done. So always add the streams even if connected_to_transport_ is true. - if (config_.id >= 0) { - controller_->AddSctpDataStream(config_.id); + if (id_.HasValue()) { + controller_->AddSctpDataStream(id_.stream_id_int()); } } void SctpDataChannel::OnTransportChannelClosed(RTCError error) { - // The SctpTransport is unusable, which could come from multiplie reasons: + // The SctpTransport is unusable, which could come from multiple reasons: // - the SCTP m= section was rejected // - the DTLS transport is closed // - the SCTP transport is closed @@ -404,7 +435,7 @@ DataChannelStats SctpDataChannel::GetStats() const { void SctpDataChannel::OnDataReceived(const cricket::ReceiveDataParams& params, const rtc::CopyOnWriteBuffer& payload) { RTC_DCHECK_RUN_ON(signaling_thread_); - if (params.sid != config_.id) { + if (id_.stream_id_int() != params.sid) { return; } @@ -518,7 +549,9 @@ void SctpDataChannel::UpdateState() { if (connected_to_transport_) { if (handshake_state_ == kHandshakeShouldSendOpen) { rtc::CopyOnWriteBuffer payload; - WriteDataChannelOpenMessage(label_, config_, &payload); + WriteDataChannelOpenMessage(label_, protocol_, priority_, ordered_, + max_retransmits_, max_retransmit_time_, + &payload); SendControlMessage(payload); } else if (handshake_state_ == kHandshakeShouldSendAck) { rtc::CopyOnWriteBuffer payload; @@ -547,9 +580,9 @@ void SctpDataChannel::UpdateState() { // to complete; after calling RemoveSctpDataStream, // OnClosingProcedureComplete will end up called asynchronously // afterwards. - if (!started_closing_procedure_ && controller_ && config_.id >= 0) { + if (!started_closing_procedure_ && controller_ && id_.HasValue()) { started_closing_procedure_ = true; - controller_->RemoveSctpDataStream(config_.id); + controller_->RemoveSctpDataStream(id_.stream_id_int()); } } } else { @@ -630,23 +663,23 @@ bool SctpDataChannel::SendDataMessage(const DataBuffer& buffer, return false; } - send_params.ordered = config_.ordered; + send_params.ordered = ordered_; // Send as ordered if it is still going through OPEN/ACK signaling. - if (handshake_state_ != kHandshakeReady && !config_.ordered) { + if (handshake_state_ != kHandshakeReady && !ordered_) { send_params.ordered = true; RTC_LOG(LS_VERBOSE) << "Sending data as ordered for unordered DataChannel " "because the OPEN_ACK message has not been received."; } - send_params.max_rtx_count = config_.maxRetransmits; - send_params.max_rtx_ms = config_.maxRetransmitTime; + send_params.max_rtx_count = max_retransmits_; + send_params.max_rtx_ms = max_retransmit_time_; send_params.type = buffer.binary ? DataMessageType::kBinary : DataMessageType::kText; cricket::SendDataResult send_result = cricket::SDR_SUCCESS; - bool success = - controller_->SendData(config_.id, send_params, buffer.data, &send_result); + bool success = controller_->SendData(id_.stream_id_int(), send_params, + buffer.data, &send_result); if (success) { ++messages_sent_; @@ -706,26 +739,27 @@ void SctpDataChannel::QueueControlMessage( bool SctpDataChannel::SendControlMessage(const rtc::CopyOnWriteBuffer& buffer) { RTC_DCHECK_RUN_ON(signaling_thread_); RTC_DCHECK(writable_); - RTC_DCHECK_GE(config_.id, 0); + RTC_DCHECK(id_.HasValue()); if (!controller_) { return false; } bool is_open_message = handshake_state_ == kHandshakeShouldSendOpen; - RTC_DCHECK(!is_open_message || !config_.negotiated); + RTC_DCHECK(!is_open_message || !negotiated_); SendDataParams send_params; // Send data as ordered before we receive any message from the remote peer to // make sure the remote peer will not receive any data before it receives the // OPEN message. - send_params.ordered = config_.ordered || is_open_message; + send_params.ordered = ordered_ || is_open_message; send_params.type = DataMessageType::kControl; cricket::SendDataResult send_result = cricket::SDR_SUCCESS; - bool retval = - controller_->SendData(config_.id, send_params, buffer, &send_result); + bool retval = controller_->SendData(id_.stream_id_int(), send_params, buffer, + &send_result); if (retval) { - RTC_LOG(LS_VERBOSE) << "Sent CONTROL message on channel " << config_.id; + RTC_LOG(LS_VERBOSE) << "Sent CONTROL message on channel " + << id_.stream_id_int(); if (handshake_state_ == kHandshakeShouldSendAck) { handshake_state_ = kHandshakeReady; diff --git a/pc/sctp_data_channel.h b/pc/sctp_data_channel.h index 91daaf7e4e..7fa7173b59 100644 --- a/pc/sctp_data_channel.h +++ b/pc/sctp_data_channel.h @@ -25,6 +25,8 @@ #include "api/transport/data_channel_transport_interface.h" #include "media/base/media_channel.h" #include "pc/data_channel_utils.h" +#include "pc/sctp_utils.h" +#include "rtc_base/containers/flat_set.h" #include "rtc_base/copy_on_write_buffer.h" #include "rtc_base/ssl_stream_adapter.h" // For SSLRole #include "rtc_base/third_party/sigslot/sigslot.h" @@ -64,8 +66,6 @@ class SctpDataChannelControllerInterface { virtual ~SctpDataChannelControllerInterface() {} }; -// TODO(tommi): Change to not inherit from DataChannelInit but to have it as -// a const member. Block access to the 'id' member since it cannot be const. struct InternalDataChannelInit : public DataChannelInit { enum OpenHandshakeRole { kOpener, kAcker, kNone }; // The default role is kOpener because the default `negotiated` is false. @@ -86,19 +86,16 @@ class SctpSidAllocator { // SSL_CLIENT, the allocated id starts from 0 and takes even numbers; // otherwise, the id starts from 1 and takes odd numbers. // Returns false if no ID can be allocated. - bool AllocateSid(rtc::SSLRole role, int* sid); + bool AllocateSid(rtc::SSLRole role, StreamId* sid); // Attempts to reserve a specific sid. Returns false if it's unavailable. - bool ReserveSid(int sid); + bool ReserveSid(const StreamId& sid); // Indicates that `sid` isn't in use any more, and is thus available again. - void ReleaseSid(int sid); + void ReleaseSid(const StreamId& sid); private: - // Checks if `sid` is available to be assigned to a new SCTP data channel. - bool IsSidAvailable(int sid) const; - - std::set used_sids_; + webrtc::flat_set used_sids_; }; // SctpDataChannel is an implementation of the DataChannelInterface based on @@ -143,32 +140,20 @@ class SctpDataChannel : public DataChannelInterface, void RegisterObserver(DataChannelObserver* observer) override; void UnregisterObserver() override; - std::string label() const override { return label_; } + std::string label() const override; bool reliable() const override; - bool ordered() const override { return config_.ordered; } - // Backwards compatible accessors - uint16_t maxRetransmitTime() const override { - return config_.maxRetransmitTime ? *config_.maxRetransmitTime - : static_cast(-1); - } - uint16_t maxRetransmits() const override { - return config_.maxRetransmits ? *config_.maxRetransmits - : static_cast(-1); - } - absl::optional maxPacketLifeTime() const override { - return config_.maxRetransmitTime; - } - absl::optional maxRetransmitsOpt() const override { - return config_.maxRetransmits; - } - std::string protocol() const override { return config_.protocol; } - bool negotiated() const override { return config_.negotiated; } - int id() const override { return config_.id; } - Priority priority() const override { - return config_.priority ? *config_.priority : Priority::kLow; - } + bool ordered() const override; - virtual int internal_id() const { return internal_id_; } + // Backwards compatible accessors + uint16_t maxRetransmitTime() const override; + uint16_t maxRetransmits() const override; + + absl::optional maxPacketLifeTime() const override; + absl::optional maxRetransmitsOpt() const override; + std::string protocol() const override; + bool negotiated() const override; + int id() const override; + Priority priority() const override; uint64_t buffered_amount() const override; void Close() override; @@ -202,7 +187,7 @@ class SctpDataChannel : public DataChannelInterface, // Sets the SCTP sid and adds to transport layer if not set yet. Should only // be called once. - void SetSctpSid(int sid); + void SetSctpSid(const StreamId& sid); // The remote side started the closing procedure by resetting its outgoing // stream (our incoming stream). Sets state to kClosing. void OnClosingProcedureStartedRemotely(int sid); @@ -220,6 +205,8 @@ class SctpDataChannel : public DataChannelInterface, DataChannelStats GetStats() const; + const StreamId& sid() const { return id_; } + // Reset the allocator for internal ID values for testing, so that // the internal IDs generated are predictable. Test only. static void ResetInternalIdAllocatorForTesting(int new_value); @@ -259,9 +246,16 @@ class SctpDataChannel : public DataChannelInterface, rtc::Thread* const signaling_thread_; rtc::Thread* const network_thread_; + StreamId id_; const int internal_id_; const std::string label_; - const InternalDataChannelInit config_; + const std::string protocol_; + const absl::optional max_retransmit_time_; + const absl::optional max_retransmits_; + const absl::optional priority_; + const bool negotiated_; + const bool ordered_; + DataChannelObserver* observer_ RTC_GUARDED_BY(signaling_thread_) = nullptr; DataState state_ RTC_GUARDED_BY(signaling_thread_) = kConnecting; RTCError error_ RTC_GUARDED_BY(signaling_thread_); diff --git a/pc/sctp_utils.cc b/pc/sctp_utils.cc index dc83da4f62..3677a9a0bb 100644 --- a/pc/sctp_utils.cc +++ b/pc/sctp_utils.cc @@ -16,6 +16,7 @@ #include "absl/types/optional.h" #include "api/priority.h" +#include "media/sctp/sctp_transport_internal.h" #include "rtc_base/byte_buffer.h" #include "rtc_base/copy_on_write_buffer.h" #include "rtc_base/logging.h" @@ -46,6 +47,53 @@ enum DataChannelPriority { DCO_PRIORITY_HIGH = 1024, }; +StreamId::StreamId() : id_(absl::nullopt) { + thread_checker_.Detach(); +} + +StreamId::StreamId(int id) + : id_(id >= cricket::kMinSctpSid && id <= cricket::kSpecMaxSctpSid + ? absl::optional(static_cast(id)) + : absl::nullopt) { + thread_checker_.Detach(); +} + +StreamId::StreamId(const StreamId& sid) : id_(sid.id_) {} + +bool StreamId::HasValue() const { + RTC_DCHECK_RUN_ON(&thread_checker_); + return id_.has_value(); +} + +int StreamId::stream_id_int() const { + RTC_DCHECK_RUN_ON(&thread_checker_); + return id_.has_value() ? static_cast(id_.value().value()) : -1; +} + +void StreamId::reset() { + RTC_DCHECK_RUN_ON(&thread_checker_); + id_ = absl::nullopt; +} + +StreamId& StreamId::operator=(const StreamId& sid) { + RTC_DCHECK_RUN_ON(&thread_checker_); + RTC_DCHECK_RUN_ON(&sid.thread_checker_); + id_ = sid.id_; + return *this; +} + +bool StreamId::operator==(const StreamId& sid) const { + RTC_DCHECK_RUN_ON(&thread_checker_); + RTC_DCHECK_RUN_ON(&sid.thread_checker_); + return id_ == sid.id_; +} + +bool StreamId::operator<(const StreamId& sid) const { + RTC_DCHECK_RUN_ON(&thread_checker_); + RTC_DCHECK_RUN_ON(&sid.thread_checker_); + return id_ < sid.id_; +} + bool IsOpenMessage(const rtc::CopyOnWriteBuffer& payload) { // Format defined at // https://www.rfc-editor.org/rfc/rfc8832#section-5.1 @@ -165,6 +213,18 @@ bool ParseDataChannelOpenAckMessage(const rtc::CopyOnWriteBuffer& payload) { bool WriteDataChannelOpenMessage(const std::string& label, const DataChannelInit& config, rtc::CopyOnWriteBuffer* payload) { + return WriteDataChannelOpenMessage(label, config.protocol, config.priority, + config.ordered, config.maxRetransmits, + config.maxRetransmitTime, payload); +} + +bool WriteDataChannelOpenMessage(const std::string& label, + const std::string& protocol, + absl::optional opt_priority, + bool ordered, + absl::optional max_retransmits, + absl::optional max_retransmit_time, + rtc::CopyOnWriteBuffer* payload) { // Format defined at // http://tools.ietf.org/html/draft-ietf-rtcweb-data-protocol-09#section-5.1 uint8_t channel_type = 0; @@ -172,8 +232,8 @@ bool WriteDataChannelOpenMessage(const std::string& label, uint16_t priority = 0; // Set priority according to // https://tools.ietf.org/html/draft-ietf-rtcweb-data-channel-12#section-6.4 - if (config.priority) { - switch (*config.priority) { + if (opt_priority) { + switch (*opt_priority) { case Priority::kVeryLow: priority = DCO_PRIORITY_VERY_LOW; break; @@ -188,39 +248,38 @@ bool WriteDataChannelOpenMessage(const std::string& label, break; } } - if (config.ordered) { - if (config.maxRetransmits) { + if (ordered) { + if (max_retransmits) { channel_type = DCOMCT_ORDERED_PARTIAL_RTXS; - reliability_param = *config.maxRetransmits; - } else if (config.maxRetransmitTime) { + reliability_param = *max_retransmits; + } else if (max_retransmit_time) { channel_type = DCOMCT_ORDERED_PARTIAL_TIME; - reliability_param = *config.maxRetransmitTime; + reliability_param = *max_retransmit_time; } else { channel_type = DCOMCT_ORDERED_RELIABLE; } } else { - if (config.maxRetransmits) { + if (max_retransmits) { channel_type = DCOMCT_UNORDERED_PARTIAL_RTXS; - reliability_param = *config.maxRetransmits; - } else if (config.maxRetransmitTime) { + reliability_param = *max_retransmits; + } else if (max_retransmit_time) { channel_type = DCOMCT_UNORDERED_PARTIAL_TIME; - reliability_param = *config.maxRetransmitTime; + reliability_param = *max_retransmit_time; } else { channel_type = DCOMCT_UNORDERED_RELIABLE; } } - rtc::ByteBufferWriter buffer(NULL, - 20 + label.length() + config.protocol.length()); + rtc::ByteBufferWriter buffer(NULL, 20 + label.length() + protocol.length()); // TODO(tommi): Add error handling and check resulting length. buffer.WriteUInt8(DATA_CHANNEL_OPEN_MESSAGE_TYPE); buffer.WriteUInt8(channel_type); buffer.WriteUInt16(priority); buffer.WriteUInt32(reliability_param); buffer.WriteUInt16(static_cast(label.length())); - buffer.WriteUInt16(static_cast(config.protocol.length())); + buffer.WriteUInt16(static_cast(protocol.length())); buffer.WriteString(label); - buffer.WriteString(config.protocol); + buffer.WriteString(protocol); payload->SetData(buffer.Data(), buffer.Length()); return true; } diff --git a/pc/sctp_utils.h b/pc/sctp_utils.h index da854458f4..d0c66defe7 100644 --- a/pc/sctp_utils.h +++ b/pc/sctp_utils.h @@ -14,9 +14,13 @@ #include #include "api/data_channel_interface.h" +#include "api/sequence_checker.h" #include "api/transport/data_channel_transport_interface.h" #include "media/base/media_channel.h" +#include "net/dcsctp/public/types.h" #include "rtc_base/copy_on_write_buffer.h" +#include "rtc_base/ssl_stream_adapter.h" // For SSLRole +#include "rtc_base/system/no_unique_address.h" namespace rtc { class CopyOnWriteBuffer; @@ -25,6 +29,41 @@ class CopyOnWriteBuffer; namespace webrtc { struct DataChannelInit; +// Wraps the `uint16_t` sctp data channel stream id value and does range +// checking. The class interface is `int` based to ease with DataChannelInit +// compatibility and types used in `DataChannelController`'s interface. Going +// forward, `int` compatibility won't be needed and we can either just use +// this class or the internal dcsctp::StreamID type. +class StreamId { + public: + StreamId(); + explicit StreamId(int id); + explicit StreamId(const StreamId& sid); + + // Returns `true` if a valid stream id is contained, in the range of + // kMinSctpSid - kSpecMaxSctpSid ([0..0xffff]). Note that this + // is different than having `kMaxSctpSid` as the upper bound, which is + // the limit that is internally used by `SctpSidAllocator`. Sid values may + // be assigned to `StreamId` outside of `SctpSidAllocator` and have a higher + // id value than supplied by `SctpSidAllocator`, yet is still valid. + bool HasValue() const; + + // Provided for compatibility with existing code that hasn't been updated + // to use `StreamId` directly. New code should not use 'int' for the stream + // id but rather `StreamId` directly. + int stream_id_int() const; + void reset(); + + StreamId& operator=(const StreamId& sid); + bool operator==(const StreamId& sid) const; + bool operator<(const StreamId& sid) const; + bool operator!=(const StreamId& sid) const { return !(operator==(sid)); } + + private: + RTC_NO_UNIQUE_ADDRESS webrtc::SequenceChecker thread_checker_; + absl::optional id_ RTC_GUARDED_BY(thread_checker_); +}; + // Read the message type and return true if it's an OPEN message. bool IsOpenMessage(const rtc::CopyOnWriteBuffer& payload); @@ -34,10 +73,16 @@ bool ParseDataChannelOpenMessage(const rtc::CopyOnWriteBuffer& payload, bool ParseDataChannelOpenAckMessage(const rtc::CopyOnWriteBuffer& payload); +bool WriteDataChannelOpenMessage(const std::string& label, + const std::string& protocol, + absl::optional priority, + bool ordered, + absl::optional max_retransmits, + absl::optional max_retransmit_time, + rtc::CopyOnWriteBuffer* payload); bool WriteDataChannelOpenMessage(const std::string& label, const DataChannelInit& config, rtc::CopyOnWriteBuffer* payload); - void WriteDataChannelOpenAckMessage(rtc::CopyOnWriteBuffer* payload); } // namespace webrtc diff --git a/pc/sctp_utils_unittest.cc b/pc/sctp_utils_unittest.cc index 146886b8cb..3e49824b45 100644 --- a/pc/sctp_utils_unittest.cc +++ b/pc/sctp_utils_unittest.cc @@ -12,12 +12,17 @@ #include +#include + #include "absl/types/optional.h" #include "api/priority.h" +#include "media/sctp/sctp_transport_internal.h" #include "rtc_base/byte_buffer.h" #include "rtc_base/copy_on_write_buffer.h" #include "test/gtest.h" +using webrtc::StreamId; + class SctpUtilsTest : public ::testing::Test { public: void VerifyOpenMessageFormat(const rtc::CopyOnWriteBuffer& packet, @@ -194,3 +199,44 @@ TEST_F(SctpUtilsTest, TestIsOpenMessage) { rtc::CopyOnWriteBuffer empty; EXPECT_FALSE(webrtc::IsOpenMessage(empty)); } + +TEST(SctpSidTest, Basics) { + // These static asserts are mostly here to aid with readability (i.e. knowing + // what these constants represent). + static_assert(cricket::kMinSctpSid == 0, "Min stream id should be 0"); + static_assert(cricket::kMaxSctpSid <= cricket::kSpecMaxSctpSid, ""); + static_assert( + cricket::kSpecMaxSctpSid == std::numeric_limits::max(), + "Max legal sctp stream value should be 0xffff"); + + // cricket::kMaxSctpSid is a chosen value in the webrtc implementation, + // the highest generated `sid` value chosen for resource reservation reasons. + // It's one less than kMaxSctpStreams (1024) or 1023 since sid values are + // zero based. + + EXPECT_TRUE(!StreamId(-1).HasValue()); + EXPECT_TRUE(!StreamId(-2).HasValue()); + EXPECT_TRUE(StreamId(cricket::kMinSctpSid).HasValue()); + EXPECT_TRUE(StreamId(cricket::kMinSctpSid + 1).HasValue()); + EXPECT_TRUE(StreamId(cricket::kSpecMaxSctpSid).HasValue()); + EXPECT_TRUE(StreamId(cricket::kMaxSctpSid).HasValue()); + + // Two illegal values are equal (both not valid). + EXPECT_EQ(StreamId(-1), StreamId(-2)); + // Two different, but legal, values, are not equal. + EXPECT_NE(StreamId(1), StreamId(2)); + // Test operator<() for container compatibility. + EXPECT_LT(StreamId(1), StreamId(2)); + + // Test assignment, value() and reset(). + StreamId sid1; + StreamId sid2(cricket::kMaxSctpSid); + EXPECT_NE(sid1, sid2); + sid1 = sid2; + EXPECT_EQ(sid1, sid2); + + EXPECT_EQ(sid1.stream_id_int(), cricket::kMaxSctpSid); + EXPECT_TRUE(sid1.HasValue()); + sid1.reset(); + EXPECT_FALSE(sid1.HasValue()); +}