diff --git a/pc/data_channel_controller.cc b/pc/data_channel_controller.cc index e42c96916e..39a31e625d 100644 --- a/pc/data_channel_controller.cc +++ b/pc/data_channel_controller.cc @@ -117,7 +117,7 @@ void DataChannelController::OnChannelClosed(int channel_id) { StreamId sid(channel_id); sid_allocator_.ReleaseSid(sid); auto it = absl::c_find_if(sctp_data_channels_n_, - [&](const auto& c) { return c->sid() == sid; }); + [&](const auto& c) { return c->sid_n() == sid; }); if (it != sctp_data_channels_n_.end()) sctp_data_channels_n_.erase(it); @@ -343,9 +343,10 @@ void DataChannelController::AllocateSctpSids(rtc::SSLRole role) { RTC_DCHECK_RUN_ON(network_thread()); for (auto it = sctp_data_channels_n_.begin(); it != sctp_data_channels_n_.end();) { - if (!(*it)->sid().HasValue()) { + if (!(*it)->sid_n().HasValue()) { StreamId sid = sid_allocator_.AllocateSid(role); if (sid.HasValue()) { + (*it)->SetSctpSid_n(sid); AddSctpDataStream(sid); channels_to_update.push_back(std::make_pair((*it).get(), sid)); } else { @@ -373,22 +374,20 @@ void DataChannelController::AllocateSctpSids(rtc::SSLRole role) { return c.get() == pair.first; }); RTC_DCHECK(it != sctp_data_channels_.end()); - (*it)->SetSctpSid(pair.second); + (*it)->SetSctpSid_s(pair.second); } } void DataChannelController::OnSctpDataChannelClosed(SctpDataChannel* channel) { RTC_DCHECK_RUN_ON(signaling_thread()); - // TODO(tommi): `sid()` should be called on the network thread. - // `sid()` and `SctpDataChannel::id_`should have thread guards to enforce - // correct usage. - network_thread()->BlockingCall([&, sid = channel->sid()] { + network_thread()->BlockingCall([&] { RTC_DCHECK_RUN_ON(network_thread()); // After the closing procedure is done, it's safe to use this ID for // another data channel. - if (sid.HasValue()) - sid_allocator_.ReleaseSid(sid); + if (channel->sid_n().HasValue()) { + sid_allocator_.ReleaseSid(channel->sid_n()); + } auto it = absl::c_find_if(sctp_data_channels_n_, [&](const auto& c) { return c.get() == channel; @@ -463,15 +462,14 @@ void DataChannelController::NotifyDataChannelsOfTransportCreated() { RTC_DCHECK_RUN_ON(network_thread()); RTC_DCHECK(data_channel_transport()); - // TODO(tommi): Move the blocking call to `AddSctpDataStream` from - // `SctpDataChannel::OnTransportChannelCreated` to here and be consistent - // with other call sites to `AddSctpDataStream`. We're already - // on the right (network) thread here. + for (const auto& channel : sctp_data_channels_n_) { + if (channel->sid_n().HasValue()) + AddSctpDataStream(channel->sid_n()); + } signaling_thread()->PostTask(SafeTask(signaling_safety_.flag(), [this] { RTC_DCHECK_RUN_ON(signaling_thread()); - auto copy = sctp_data_channels_; - for (const auto& channel : copy) { + for (const auto& channel : sctp_data_channels_) { channel->OnTransportChannelCreated(); } })); @@ -480,8 +478,9 @@ void DataChannelController::NotifyDataChannelsOfTransportCreated() { std::vector>::iterator DataChannelController::FindChannel(StreamId stream_id) { RTC_DCHECK_RUN_ON(signaling_thread()); - return absl::c_find_if(sctp_data_channels_, - [&](const auto& c) { return c->sid() == stream_id; }); + return absl::c_find_if(sctp_data_channels_, [&](const auto& c) { + return c->sid_s() == stream_id; + }); } rtc::Thread* DataChannelController::network_thread() const { diff --git a/pc/data_channel_unittest.cc b/pc/data_channel_unittest.cc index a970aaa345..80a349dfe5 100644 --- a/pc/data_channel_unittest.cc +++ b/pc/data_channel_unittest.cc @@ -89,7 +89,7 @@ class SctpDataChannelTest : public ::testing::Test { void SetChannelReady() { controller_->set_transport_available(true); webrtc_data_channel_->OnTransportChannelCreated(); - if (!webrtc_data_channel_->sid().HasValue()) { + if (!webrtc_data_channel_->sid_s().HasValue()) { SetChannelSid(webrtc_data_channel_, StreamId(0)); } controller_->set_ready_to_send(true); @@ -105,7 +105,7 @@ class SctpDataChannelTest : public ::testing::Test { RTC_DCHECK(sid.HasValue()); network_thread_.BlockingCall( [&]() { controller_->AddSctpDataStream(sid); }); - channel->SetSctpSid(sid); + channel->SetSctpSid_s(sid); } void AddObserver() { @@ -141,11 +141,11 @@ TEST_F(SctpDataChannelTest, VerifyConfigurationGetters) { // Check the non-const part of the configuration. EXPECT_EQ(webrtc_data_channel_->id(), init_.id); - EXPECT_EQ(webrtc_data_channel_->sid(), StreamId()); + EXPECT_EQ(webrtc_data_channel_->sid_s(), StreamId()); SetChannelReady(); EXPECT_EQ(webrtc_data_channel_->id(), 0); - EXPECT_EQ(webrtc_data_channel_->sid(), StreamId(0)); + EXPECT_EQ(webrtc_data_channel_->sid_s(), StreamId(0)); } // Verifies that the data channel is connected to the transport after creation. @@ -156,10 +156,10 @@ TEST_F(SctpDataChannelTest, ConnectedToTransportOnCreated) { EXPECT_TRUE(controller_->IsConnected(dc.get())); // The sid is not set yet, so it should not have added the streams. - EXPECT_FALSE(controller_->IsStreamAdded(dc->sid())); + EXPECT_FALSE(controller_->IsStreamAdded(dc->sid_s())); SetChannelSid(dc, StreamId(0)); - EXPECT_TRUE(controller_->IsStreamAdded(dc->sid())); + EXPECT_TRUE(controller_->IsStreamAdded(dc->sid_s())); } // Tests the state of the data channel. diff --git a/pc/sctp_data_channel.cc b/pc/sctp_data_channel.cc index 825f671b6b..623a153067 100644 --- a/pc/sctp_data_channel.cc +++ b/pc/sctp_data_channel.cc @@ -175,7 +175,8 @@ SctpDataChannel::SctpDataChannel( rtc::Thread* network_thread) : signaling_thread_(signaling_thread), network_thread_(network_thread), - id_(config.id), + id_s_(config.id), + id_n_(config.id), internal_id_(GenerateUniqueId()), label_(label), protocol_(config.protocol), @@ -260,7 +261,8 @@ bool SctpDataChannel::negotiated() const { } int SctpDataChannel::id() const { - return id_.stream_id_int(); + RTC_DCHECK_RUN_ON(signaling_thread_); + return id_s_.stream_id_int(); } Priority SctpDataChannel::priority() const { @@ -333,14 +335,21 @@ bool SctpDataChannel::Send(const DataBuffer& buffer) { return true; } -void SctpDataChannel::SetSctpSid(const StreamId& sid) { +void SctpDataChannel::SetSctpSid_s(StreamId sid) { RTC_DCHECK_RUN_ON(signaling_thread_); - RTC_DCHECK(!id_.HasValue()); + RTC_DCHECK(!id_s_.HasValue()); RTC_DCHECK(sid.HasValue()); RTC_DCHECK_NE(handshake_state_, kHandshakeWaitingForAck); RTC_DCHECK_EQ(state_, kConnecting); - id_ = sid; + id_s_ = sid; +} + +void SctpDataChannel::SetSctpSid_n(StreamId sid) { + RTC_DCHECK_RUN_ON(network_thread_); + RTC_DCHECK(!id_n_.HasValue()); + RTC_DCHECK(sid.HasValue()); + id_n_ = sid; } void SctpDataChannel::OnClosingProcedureStartedRemotely() { @@ -370,16 +379,8 @@ void SctpDataChannel::OnClosingProcedureComplete() { void SctpDataChannel::OnTransportChannelCreated() { RTC_DCHECK_RUN_ON(signaling_thread_); - RTC_DCHECK(controller_); connected_to_transport_ = true; - - if (id_.HasValue()) { - // TODO(bugs.webrtc.org/11547): Move this call over to DCC and do it when we - // get the initial notification from the transport, on the network thread. - network_thread_->BlockingCall( - [c = controller_.get(), sid = id_] { c->AddSctpDataStream(sid); }); - } } void SctpDataChannel::OnTransportChannelClosed(RTCError error) { @@ -407,18 +408,18 @@ void SctpDataChannel::OnDataReceived(DataMessageType type, // Ignore it if we are not expecting an ACK message. RTC_LOG(LS_WARNING) << "DataChannel received unexpected CONTROL message, sid = " - << id_.stream_id_int(); + << id_s_.stream_id_int(); return; } if (ParseDataChannelOpenAckMessage(payload)) { // We can send unordered as soon as we receive the ACK message. handshake_state_ = kHandshakeReady; RTC_LOG(LS_INFO) << "DataChannel received OPEN_ACK message, sid = " - << id_.stream_id_int(); + << id_s_.stream_id_int(); } else { RTC_LOG(LS_WARNING) << "DataChannel failed to parse OPEN_ACK message, sid = " - << id_.stream_id_int(); + << id_s_.stream_id_int(); } return; } @@ -427,7 +428,7 @@ void SctpDataChannel::OnDataReceived(DataMessageType type, type == DataMessageType::kText); RTC_DLOG(LS_VERBOSE) << "DataChannel received DATA message, sid = " - << id_.stream_id_int(); + << id_s_.stream_id_int(); // We can send unordered as soon as we receive any DATA message since the // remote side must have received the OPEN (and old clients do not send // OPEN_ACK). @@ -514,7 +515,7 @@ void SctpDataChannel::UpdateState() { switch (state_) { case kConnecting: { - if (connected_to_transport_) { + if (connected_to_transport_ && controller_) { if (handshake_state_ == kHandshakeShouldSendOpen) { rtc::CopyOnWriteBuffer payload; WriteDataChannelOpenMessage(label_, protocol_, priority_, ordered_, @@ -534,7 +535,7 @@ void SctpDataChannel::UpdateState() { DeliverQueuedReceivedData(); } } else { - RTC_DCHECK(!id_.HasValue()); + RTC_DCHECK(!id_s_.HasValue()); } break; } @@ -542,7 +543,7 @@ void SctpDataChannel::UpdateState() { break; } case kClosing: { - if (connected_to_transport_) { + if (connected_to_transport_ && controller_) { // Wait for all queued data to be sent before beginning the closing // procedure. if (queued_send_data_.Empty() && queued_control_data_.Empty()) { @@ -550,9 +551,9 @@ void SctpDataChannel::UpdateState() { // to complete; after calling RemoveSctpDataStream, // OnClosingProcedureComplete will end up called asynchronously // afterwards. - if (!started_closing_procedure_ && controller_ && id_.HasValue()) { + if (!started_closing_procedure_ && id_s_.HasValue()) { started_closing_procedure_ = true; - network_thread_->BlockingCall([c = controller_.get(), sid = id_] { + network_thread_->BlockingCall([c = controller_.get(), sid = id_s_] { c->RemoveSctpDataStream(sid); }); } @@ -640,7 +641,7 @@ bool SctpDataChannel::SendDataMessage(const DataBuffer& buffer, send_params.type = buffer.binary ? DataMessageType::kBinary : DataMessageType::kText; - RTCError error = controller_->SendData(id_, send_params, buffer.data); + RTCError error = controller_->SendData(id_s_, send_params, buffer.data); if (error.ok()) { ++messages_sent_; @@ -691,20 +692,12 @@ void SctpDataChannel::SendQueuedControlMessages() { } } -void SctpDataChannel::QueueControlMessage( - const rtc::CopyOnWriteBuffer& buffer) { - RTC_DCHECK_RUN_ON(signaling_thread_); - queued_control_data_.PushBack(std::make_unique(buffer, true)); -} - bool SctpDataChannel::SendControlMessage(const rtc::CopyOnWriteBuffer& buffer) { RTC_DCHECK_RUN_ON(signaling_thread_); RTC_DCHECK(connected_to_transport_); - RTC_DCHECK(id_.HasValue()); + RTC_DCHECK(id_s_.HasValue()); + RTC_DCHECK(controller_); - if (!controller_) { - return false; - } bool is_open_message = handshake_state_ == kHandshakeShouldSendOpen; RTC_DCHECK(!is_open_message || !negotiated_); @@ -715,10 +708,10 @@ bool SctpDataChannel::SendControlMessage(const rtc::CopyOnWriteBuffer& buffer) { send_params.ordered = ordered_ || is_open_message; send_params.type = DataMessageType::kControl; - RTCError err = controller_->SendData(id_, send_params, buffer); + RTCError err = controller_->SendData(id_s_, send_params, buffer); if (err.ok()) { RTC_DLOG(LS_VERBOSE) << "Sent CONTROL message on channel " - << id_.stream_id_int(); + << id_s_.stream_id_int(); if (handshake_state_ == kHandshakeShouldSendAck) { handshake_state_ = kHandshakeReady; @@ -726,7 +719,7 @@ bool SctpDataChannel::SendControlMessage(const rtc::CopyOnWriteBuffer& buffer) { handshake_state_ = kHandshakeWaitingForAck; } } else if (err.type() == RTCErrorType::RESOURCE_EXHAUSTED) { - QueueControlMessage(buffer); + queued_control_data_.PushBack(std::make_unique(buffer, true)); } else { RTC_LOG(LS_ERROR) << "Closing the DataChannel due to a failure to send" " the CONTROL message, send_result = " diff --git a/pc/sctp_data_channel.h b/pc/sctp_data_channel.h index f87a3c0f40..588b0cbf63 100644 --- a/pc/sctp_data_channel.h +++ b/pc/sctp_data_channel.h @@ -192,7 +192,8 @@ class SctpDataChannel : public DataChannelInterface { // Sets the SCTP sid and adds to transport layer if not set yet. Should only // be called once. - void SetSctpSid(const StreamId& sid); + void SetSctpSid_s(StreamId sid); + void SetSctpSid_n(StreamId sid); // The remote side started the closing procedure by resetting its outgoing // stream (our incoming stream). Sets state to kClosing. @@ -215,7 +216,14 @@ class SctpDataChannel : public DataChannelInterface { // stats purposes (see also `GetStats()`). int internal_id() const { return internal_id_; } - const StreamId& sid() const { return id_; } + StreamId sid_s() const { + RTC_DCHECK_RUN_ON(signaling_thread_); + return id_s_; + } + StreamId sid_n() const { + RTC_DCHECK_RUN_ON(network_thread_); + return id_n_; + } // Reset the allocator for internal ID values for testing, so that // the internal IDs generated are predictable. Test only. @@ -250,12 +258,12 @@ class SctpDataChannel : public DataChannelInterface { bool QueueSendDataMessage(const DataBuffer& buffer); void SendQueuedControlMessages(); - void QueueControlMessage(const rtc::CopyOnWriteBuffer& buffer); bool SendControlMessage(const rtc::CopyOnWriteBuffer& buffer); rtc::Thread* const signaling_thread_; rtc::Thread* const network_thread_; - StreamId id_; + StreamId id_s_ RTC_GUARDED_BY(signaling_thread_); + StreamId id_n_ RTC_GUARDED_BY(network_thread_); const int internal_id_; const std::string label_; const std::string protocol_; diff --git a/pc/test/fake_data_channel_controller.h b/pc/test/fake_data_channel_controller.h index 5a1ce2baae..26ecc31378 100644 --- a/pc/test/fake_data_channel_controller.h +++ b/pc/test/fake_data_channel_controller.h @@ -51,8 +51,8 @@ class FakeDataChannelController std::move(my_weak_ptr), std::string(label), transport_available_, init, signaling_thread_, network_thread_); - if (transport_available_ && channel->sid().HasValue()) { - AddSctpDataStream(channel->sid()); + if (transport_available_ && channel->sid_n().HasValue()) { + AddSctpDataStream(channel->sid_n()); } return channel; }); @@ -103,8 +103,9 @@ class FakeDataChannelController signaling_thread_->PostTask(SafeTask(signaling_safety_.flag(), [this, sid] { // Unlike the real SCTP transport, act like the closing procedure finished // instantly. - auto it = absl::c_find_if(connected_channels_, - [&](const auto* c) { return c->sid() == sid; }); + auto it = absl::c_find_if(connected_channels_, [&](const auto* c) { + return c->sid_s() == sid; + }); // This path mimics the DCC's OnChannelClosed handler since the FDCC // (this class) doesn't have a transport that would do that. if (it != connected_channels_.end())