[SctpDataChannel] Add a copy of the sid for the network thread.

* Rename id_ -> id_s_, add id_n_ and thread guards.
* Same for getters, sid() -> sid_s(), add sid_n()

As more things migrate over to the network thread, we'll only need the
_n variant.

Bug: webrtc:11547
Change-Id: Ic998330f4c81b0f6833967631ac70edc2ca2301c
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/299141
Commit-Queue: Tomas Gunnarsson <tommi@webrtc.org>
Reviewed-by: Harald Alvestrand <hta@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#39724}
This commit is contained in:
Tommi 2023-03-30 12:01:56 +02:00 committed by WebRTC LUCI CQ
parent c888db24a4
commit 1158bde7c2
5 changed files with 68 additions and 67 deletions

View File

@ -117,7 +117,7 @@ void DataChannelController::OnChannelClosed(int channel_id) {
StreamId sid(channel_id); StreamId sid(channel_id);
sid_allocator_.ReleaseSid(sid); sid_allocator_.ReleaseSid(sid);
auto it = absl::c_find_if(sctp_data_channels_n_, auto it = absl::c_find_if(sctp_data_channels_n_,
[&](const auto& c) { return c->sid() == sid; }); [&](const auto& c) { return c->sid_n() == sid; });
if (it != sctp_data_channels_n_.end()) if (it != sctp_data_channels_n_.end())
sctp_data_channels_n_.erase(it); sctp_data_channels_n_.erase(it);
@ -343,9 +343,10 @@ void DataChannelController::AllocateSctpSids(rtc::SSLRole role) {
RTC_DCHECK_RUN_ON(network_thread()); RTC_DCHECK_RUN_ON(network_thread());
for (auto it = sctp_data_channels_n_.begin(); for (auto it = sctp_data_channels_n_.begin();
it != sctp_data_channels_n_.end();) { it != sctp_data_channels_n_.end();) {
if (!(*it)->sid().HasValue()) { if (!(*it)->sid_n().HasValue()) {
StreamId sid = sid_allocator_.AllocateSid(role); StreamId sid = sid_allocator_.AllocateSid(role);
if (sid.HasValue()) { if (sid.HasValue()) {
(*it)->SetSctpSid_n(sid);
AddSctpDataStream(sid); AddSctpDataStream(sid);
channels_to_update.push_back(std::make_pair((*it).get(), sid)); channels_to_update.push_back(std::make_pair((*it).get(), sid));
} else { } else {
@ -373,22 +374,20 @@ void DataChannelController::AllocateSctpSids(rtc::SSLRole role) {
return c.get() == pair.first; return c.get() == pair.first;
}); });
RTC_DCHECK(it != sctp_data_channels_.end()); RTC_DCHECK(it != sctp_data_channels_.end());
(*it)->SetSctpSid(pair.second); (*it)->SetSctpSid_s(pair.second);
} }
} }
void DataChannelController::OnSctpDataChannelClosed(SctpDataChannel* channel) { void DataChannelController::OnSctpDataChannelClosed(SctpDataChannel* channel) {
RTC_DCHECK_RUN_ON(signaling_thread()); RTC_DCHECK_RUN_ON(signaling_thread());
// TODO(tommi): `sid()` should be called on the network thread. network_thread()->BlockingCall([&] {
// `sid()` and `SctpDataChannel::id_`should have thread guards to enforce
// correct usage.
network_thread()->BlockingCall([&, sid = channel->sid()] {
RTC_DCHECK_RUN_ON(network_thread()); RTC_DCHECK_RUN_ON(network_thread());
// After the closing procedure is done, it's safe to use this ID for // After the closing procedure is done, it's safe to use this ID for
// another data channel. // another data channel.
if (sid.HasValue()) if (channel->sid_n().HasValue()) {
sid_allocator_.ReleaseSid(sid); sid_allocator_.ReleaseSid(channel->sid_n());
}
auto it = absl::c_find_if(sctp_data_channels_n_, [&](const auto& c) { auto it = absl::c_find_if(sctp_data_channels_n_, [&](const auto& c) {
return c.get() == channel; return c.get() == channel;
@ -463,15 +462,14 @@ void DataChannelController::NotifyDataChannelsOfTransportCreated() {
RTC_DCHECK_RUN_ON(network_thread()); RTC_DCHECK_RUN_ON(network_thread());
RTC_DCHECK(data_channel_transport()); RTC_DCHECK(data_channel_transport());
// TODO(tommi): Move the blocking call to `AddSctpDataStream` from for (const auto& channel : sctp_data_channels_n_) {
// `SctpDataChannel::OnTransportChannelCreated` to here and be consistent if (channel->sid_n().HasValue())
// with other call sites to `AddSctpDataStream`. We're already AddSctpDataStream(channel->sid_n());
// on the right (network) thread here. }
signaling_thread()->PostTask(SafeTask(signaling_safety_.flag(), [this] { signaling_thread()->PostTask(SafeTask(signaling_safety_.flag(), [this] {
RTC_DCHECK_RUN_ON(signaling_thread()); RTC_DCHECK_RUN_ON(signaling_thread());
auto copy = sctp_data_channels_; for (const auto& channel : sctp_data_channels_) {
for (const auto& channel : copy) {
channel->OnTransportChannelCreated(); channel->OnTransportChannelCreated();
} }
})); }));
@ -480,8 +478,9 @@ void DataChannelController::NotifyDataChannelsOfTransportCreated() {
std::vector<rtc::scoped_refptr<SctpDataChannel>>::iterator std::vector<rtc::scoped_refptr<SctpDataChannel>>::iterator
DataChannelController::FindChannel(StreamId stream_id) { DataChannelController::FindChannel(StreamId stream_id) {
RTC_DCHECK_RUN_ON(signaling_thread()); RTC_DCHECK_RUN_ON(signaling_thread());
return absl::c_find_if(sctp_data_channels_, return absl::c_find_if(sctp_data_channels_, [&](const auto& c) {
[&](const auto& c) { return c->sid() == stream_id; }); return c->sid_s() == stream_id;
});
} }
rtc::Thread* DataChannelController::network_thread() const { rtc::Thread* DataChannelController::network_thread() const {

View File

@ -89,7 +89,7 @@ class SctpDataChannelTest : public ::testing::Test {
void SetChannelReady() { void SetChannelReady() {
controller_->set_transport_available(true); controller_->set_transport_available(true);
webrtc_data_channel_->OnTransportChannelCreated(); webrtc_data_channel_->OnTransportChannelCreated();
if (!webrtc_data_channel_->sid().HasValue()) { if (!webrtc_data_channel_->sid_s().HasValue()) {
SetChannelSid(webrtc_data_channel_, StreamId(0)); SetChannelSid(webrtc_data_channel_, StreamId(0));
} }
controller_->set_ready_to_send(true); controller_->set_ready_to_send(true);
@ -105,7 +105,7 @@ class SctpDataChannelTest : public ::testing::Test {
RTC_DCHECK(sid.HasValue()); RTC_DCHECK(sid.HasValue());
network_thread_.BlockingCall( network_thread_.BlockingCall(
[&]() { controller_->AddSctpDataStream(sid); }); [&]() { controller_->AddSctpDataStream(sid); });
channel->SetSctpSid(sid); channel->SetSctpSid_s(sid);
} }
void AddObserver() { void AddObserver() {
@ -141,11 +141,11 @@ TEST_F(SctpDataChannelTest, VerifyConfigurationGetters) {
// Check the non-const part of the configuration. // Check the non-const part of the configuration.
EXPECT_EQ(webrtc_data_channel_->id(), init_.id); EXPECT_EQ(webrtc_data_channel_->id(), init_.id);
EXPECT_EQ(webrtc_data_channel_->sid(), StreamId()); EXPECT_EQ(webrtc_data_channel_->sid_s(), StreamId());
SetChannelReady(); SetChannelReady();
EXPECT_EQ(webrtc_data_channel_->id(), 0); EXPECT_EQ(webrtc_data_channel_->id(), 0);
EXPECT_EQ(webrtc_data_channel_->sid(), StreamId(0)); EXPECT_EQ(webrtc_data_channel_->sid_s(), StreamId(0));
} }
// Verifies that the data channel is connected to the transport after creation. // Verifies that the data channel is connected to the transport after creation.
@ -156,10 +156,10 @@ TEST_F(SctpDataChannelTest, ConnectedToTransportOnCreated) {
EXPECT_TRUE(controller_->IsConnected(dc.get())); EXPECT_TRUE(controller_->IsConnected(dc.get()));
// The sid is not set yet, so it should not have added the streams. // The sid is not set yet, so it should not have added the streams.
EXPECT_FALSE(controller_->IsStreamAdded(dc->sid())); EXPECT_FALSE(controller_->IsStreamAdded(dc->sid_s()));
SetChannelSid(dc, StreamId(0)); SetChannelSid(dc, StreamId(0));
EXPECT_TRUE(controller_->IsStreamAdded(dc->sid())); EXPECT_TRUE(controller_->IsStreamAdded(dc->sid_s()));
} }
// Tests the state of the data channel. // Tests the state of the data channel.

View File

@ -175,7 +175,8 @@ SctpDataChannel::SctpDataChannel(
rtc::Thread* network_thread) rtc::Thread* network_thread)
: signaling_thread_(signaling_thread), : signaling_thread_(signaling_thread),
network_thread_(network_thread), network_thread_(network_thread),
id_(config.id), id_s_(config.id),
id_n_(config.id),
internal_id_(GenerateUniqueId()), internal_id_(GenerateUniqueId()),
label_(label), label_(label),
protocol_(config.protocol), protocol_(config.protocol),
@ -260,7 +261,8 @@ bool SctpDataChannel::negotiated() const {
} }
int SctpDataChannel::id() const { int SctpDataChannel::id() const {
return id_.stream_id_int(); RTC_DCHECK_RUN_ON(signaling_thread_);
return id_s_.stream_id_int();
} }
Priority SctpDataChannel::priority() const { Priority SctpDataChannel::priority() const {
@ -333,14 +335,21 @@ bool SctpDataChannel::Send(const DataBuffer& buffer) {
return true; return true;
} }
void SctpDataChannel::SetSctpSid(const StreamId& sid) { void SctpDataChannel::SetSctpSid_s(StreamId sid) {
RTC_DCHECK_RUN_ON(signaling_thread_); RTC_DCHECK_RUN_ON(signaling_thread_);
RTC_DCHECK(!id_.HasValue()); RTC_DCHECK(!id_s_.HasValue());
RTC_DCHECK(sid.HasValue()); RTC_DCHECK(sid.HasValue());
RTC_DCHECK_NE(handshake_state_, kHandshakeWaitingForAck); RTC_DCHECK_NE(handshake_state_, kHandshakeWaitingForAck);
RTC_DCHECK_EQ(state_, kConnecting); RTC_DCHECK_EQ(state_, kConnecting);
id_ = sid; id_s_ = sid;
}
void SctpDataChannel::SetSctpSid_n(StreamId sid) {
RTC_DCHECK_RUN_ON(network_thread_);
RTC_DCHECK(!id_n_.HasValue());
RTC_DCHECK(sid.HasValue());
id_n_ = sid;
} }
void SctpDataChannel::OnClosingProcedureStartedRemotely() { void SctpDataChannel::OnClosingProcedureStartedRemotely() {
@ -370,16 +379,8 @@ void SctpDataChannel::OnClosingProcedureComplete() {
void SctpDataChannel::OnTransportChannelCreated() { void SctpDataChannel::OnTransportChannelCreated() {
RTC_DCHECK_RUN_ON(signaling_thread_); RTC_DCHECK_RUN_ON(signaling_thread_);
RTC_DCHECK(controller_);
connected_to_transport_ = true; connected_to_transport_ = true;
if (id_.HasValue()) {
// TODO(bugs.webrtc.org/11547): Move this call over to DCC and do it when we
// get the initial notification from the transport, on the network thread.
network_thread_->BlockingCall(
[c = controller_.get(), sid = id_] { c->AddSctpDataStream(sid); });
}
} }
void SctpDataChannel::OnTransportChannelClosed(RTCError error) { void SctpDataChannel::OnTransportChannelClosed(RTCError error) {
@ -407,18 +408,18 @@ void SctpDataChannel::OnDataReceived(DataMessageType type,
// Ignore it if we are not expecting an ACK message. // Ignore it if we are not expecting an ACK message.
RTC_LOG(LS_WARNING) RTC_LOG(LS_WARNING)
<< "DataChannel received unexpected CONTROL message, sid = " << "DataChannel received unexpected CONTROL message, sid = "
<< id_.stream_id_int(); << id_s_.stream_id_int();
return; return;
} }
if (ParseDataChannelOpenAckMessage(payload)) { if (ParseDataChannelOpenAckMessage(payload)) {
// We can send unordered as soon as we receive the ACK message. // We can send unordered as soon as we receive the ACK message.
handshake_state_ = kHandshakeReady; handshake_state_ = kHandshakeReady;
RTC_LOG(LS_INFO) << "DataChannel received OPEN_ACK message, sid = " RTC_LOG(LS_INFO) << "DataChannel received OPEN_ACK message, sid = "
<< id_.stream_id_int(); << id_s_.stream_id_int();
} else { } else {
RTC_LOG(LS_WARNING) RTC_LOG(LS_WARNING)
<< "DataChannel failed to parse OPEN_ACK message, sid = " << "DataChannel failed to parse OPEN_ACK message, sid = "
<< id_.stream_id_int(); << id_s_.stream_id_int();
} }
return; return;
} }
@ -427,7 +428,7 @@ void SctpDataChannel::OnDataReceived(DataMessageType type,
type == DataMessageType::kText); type == DataMessageType::kText);
RTC_DLOG(LS_VERBOSE) << "DataChannel received DATA message, sid = " RTC_DLOG(LS_VERBOSE) << "DataChannel received DATA message, sid = "
<< id_.stream_id_int(); << id_s_.stream_id_int();
// We can send unordered as soon as we receive any DATA message since the // We can send unordered as soon as we receive any DATA message since the
// remote side must have received the OPEN (and old clients do not send // remote side must have received the OPEN (and old clients do not send
// OPEN_ACK). // OPEN_ACK).
@ -514,7 +515,7 @@ void SctpDataChannel::UpdateState() {
switch (state_) { switch (state_) {
case kConnecting: { case kConnecting: {
if (connected_to_transport_) { if (connected_to_transport_ && controller_) {
if (handshake_state_ == kHandshakeShouldSendOpen) { if (handshake_state_ == kHandshakeShouldSendOpen) {
rtc::CopyOnWriteBuffer payload; rtc::CopyOnWriteBuffer payload;
WriteDataChannelOpenMessage(label_, protocol_, priority_, ordered_, WriteDataChannelOpenMessage(label_, protocol_, priority_, ordered_,
@ -534,7 +535,7 @@ void SctpDataChannel::UpdateState() {
DeliverQueuedReceivedData(); DeliverQueuedReceivedData();
} }
} else { } else {
RTC_DCHECK(!id_.HasValue()); RTC_DCHECK(!id_s_.HasValue());
} }
break; break;
} }
@ -542,7 +543,7 @@ void SctpDataChannel::UpdateState() {
break; break;
} }
case kClosing: { case kClosing: {
if (connected_to_transport_) { if (connected_to_transport_ && controller_) {
// Wait for all queued data to be sent before beginning the closing // Wait for all queued data to be sent before beginning the closing
// procedure. // procedure.
if (queued_send_data_.Empty() && queued_control_data_.Empty()) { if (queued_send_data_.Empty() && queued_control_data_.Empty()) {
@ -550,9 +551,9 @@ void SctpDataChannel::UpdateState() {
// to complete; after calling RemoveSctpDataStream, // to complete; after calling RemoveSctpDataStream,
// OnClosingProcedureComplete will end up called asynchronously // OnClosingProcedureComplete will end up called asynchronously
// afterwards. // afterwards.
if (!started_closing_procedure_ && controller_ && id_.HasValue()) { if (!started_closing_procedure_ && id_s_.HasValue()) {
started_closing_procedure_ = true; started_closing_procedure_ = true;
network_thread_->BlockingCall([c = controller_.get(), sid = id_] { network_thread_->BlockingCall([c = controller_.get(), sid = id_s_] {
c->RemoveSctpDataStream(sid); c->RemoveSctpDataStream(sid);
}); });
} }
@ -640,7 +641,7 @@ bool SctpDataChannel::SendDataMessage(const DataBuffer& buffer,
send_params.type = send_params.type =
buffer.binary ? DataMessageType::kBinary : DataMessageType::kText; buffer.binary ? DataMessageType::kBinary : DataMessageType::kText;
RTCError error = controller_->SendData(id_, send_params, buffer.data); RTCError error = controller_->SendData(id_s_, send_params, buffer.data);
if (error.ok()) { if (error.ok()) {
++messages_sent_; ++messages_sent_;
@ -691,20 +692,12 @@ void SctpDataChannel::SendQueuedControlMessages() {
} }
} }
void SctpDataChannel::QueueControlMessage(
const rtc::CopyOnWriteBuffer& buffer) {
RTC_DCHECK_RUN_ON(signaling_thread_);
queued_control_data_.PushBack(std::make_unique<DataBuffer>(buffer, true));
}
bool SctpDataChannel::SendControlMessage(const rtc::CopyOnWriteBuffer& buffer) { bool SctpDataChannel::SendControlMessage(const rtc::CopyOnWriteBuffer& buffer) {
RTC_DCHECK_RUN_ON(signaling_thread_); RTC_DCHECK_RUN_ON(signaling_thread_);
RTC_DCHECK(connected_to_transport_); RTC_DCHECK(connected_to_transport_);
RTC_DCHECK(id_.HasValue()); RTC_DCHECK(id_s_.HasValue());
RTC_DCHECK(controller_);
if (!controller_) {
return false;
}
bool is_open_message = handshake_state_ == kHandshakeShouldSendOpen; bool is_open_message = handshake_state_ == kHandshakeShouldSendOpen;
RTC_DCHECK(!is_open_message || !negotiated_); RTC_DCHECK(!is_open_message || !negotiated_);
@ -715,10 +708,10 @@ bool SctpDataChannel::SendControlMessage(const rtc::CopyOnWriteBuffer& buffer) {
send_params.ordered = ordered_ || is_open_message; send_params.ordered = ordered_ || is_open_message;
send_params.type = DataMessageType::kControl; send_params.type = DataMessageType::kControl;
RTCError err = controller_->SendData(id_, send_params, buffer); RTCError err = controller_->SendData(id_s_, send_params, buffer);
if (err.ok()) { if (err.ok()) {
RTC_DLOG(LS_VERBOSE) << "Sent CONTROL message on channel " RTC_DLOG(LS_VERBOSE) << "Sent CONTROL message on channel "
<< id_.stream_id_int(); << id_s_.stream_id_int();
if (handshake_state_ == kHandshakeShouldSendAck) { if (handshake_state_ == kHandshakeShouldSendAck) {
handshake_state_ = kHandshakeReady; handshake_state_ = kHandshakeReady;
@ -726,7 +719,7 @@ bool SctpDataChannel::SendControlMessage(const rtc::CopyOnWriteBuffer& buffer) {
handshake_state_ = kHandshakeWaitingForAck; handshake_state_ = kHandshakeWaitingForAck;
} }
} else if (err.type() == RTCErrorType::RESOURCE_EXHAUSTED) { } else if (err.type() == RTCErrorType::RESOURCE_EXHAUSTED) {
QueueControlMessage(buffer); queued_control_data_.PushBack(std::make_unique<DataBuffer>(buffer, true));
} else { } else {
RTC_LOG(LS_ERROR) << "Closing the DataChannel due to a failure to send" RTC_LOG(LS_ERROR) << "Closing the DataChannel due to a failure to send"
" the CONTROL message, send_result = " " the CONTROL message, send_result = "

View File

@ -192,7 +192,8 @@ class SctpDataChannel : public DataChannelInterface {
// Sets the SCTP sid and adds to transport layer if not set yet. Should only // Sets the SCTP sid and adds to transport layer if not set yet. Should only
// be called once. // be called once.
void SetSctpSid(const StreamId& sid); void SetSctpSid_s(StreamId sid);
void SetSctpSid_n(StreamId sid);
// The remote side started the closing procedure by resetting its outgoing // The remote side started the closing procedure by resetting its outgoing
// stream (our incoming stream). Sets state to kClosing. // stream (our incoming stream). Sets state to kClosing.
@ -215,7 +216,14 @@ class SctpDataChannel : public DataChannelInterface {
// stats purposes (see also `GetStats()`). // stats purposes (see also `GetStats()`).
int internal_id() const { return internal_id_; } int internal_id() const { return internal_id_; }
const StreamId& sid() const { return id_; } StreamId sid_s() const {
RTC_DCHECK_RUN_ON(signaling_thread_);
return id_s_;
}
StreamId sid_n() const {
RTC_DCHECK_RUN_ON(network_thread_);
return id_n_;
}
// Reset the allocator for internal ID values for testing, so that // Reset the allocator for internal ID values for testing, so that
// the internal IDs generated are predictable. Test only. // the internal IDs generated are predictable. Test only.
@ -250,12 +258,12 @@ class SctpDataChannel : public DataChannelInterface {
bool QueueSendDataMessage(const DataBuffer& buffer); bool QueueSendDataMessage(const DataBuffer& buffer);
void SendQueuedControlMessages(); void SendQueuedControlMessages();
void QueueControlMessage(const rtc::CopyOnWriteBuffer& buffer);
bool SendControlMessage(const rtc::CopyOnWriteBuffer& buffer); bool SendControlMessage(const rtc::CopyOnWriteBuffer& buffer);
rtc::Thread* const signaling_thread_; rtc::Thread* const signaling_thread_;
rtc::Thread* const network_thread_; rtc::Thread* const network_thread_;
StreamId id_; StreamId id_s_ RTC_GUARDED_BY(signaling_thread_);
StreamId id_n_ RTC_GUARDED_BY(network_thread_);
const int internal_id_; const int internal_id_;
const std::string label_; const std::string label_;
const std::string protocol_; const std::string protocol_;

View File

@ -51,8 +51,8 @@ class FakeDataChannelController
std::move(my_weak_ptr), std::string(label), std::move(my_weak_ptr), std::string(label),
transport_available_, init, signaling_thread_, transport_available_, init, signaling_thread_,
network_thread_); network_thread_);
if (transport_available_ && channel->sid().HasValue()) { if (transport_available_ && channel->sid_n().HasValue()) {
AddSctpDataStream(channel->sid()); AddSctpDataStream(channel->sid_n());
} }
return channel; return channel;
}); });
@ -103,8 +103,9 @@ class FakeDataChannelController
signaling_thread_->PostTask(SafeTask(signaling_safety_.flag(), [this, sid] { signaling_thread_->PostTask(SafeTask(signaling_safety_.flag(), [this, sid] {
// Unlike the real SCTP transport, act like the closing procedure finished // Unlike the real SCTP transport, act like the closing procedure finished
// instantly. // instantly.
auto it = absl::c_find_if(connected_channels_, auto it = absl::c_find_if(connected_channels_, [&](const auto* c) {
[&](const auto* c) { return c->sid() == sid; }); return c->sid_s() == sid;
});
// This path mimics the DCC's OnChannelClosed handler since the FDCC // This path mimics the DCC's OnChannelClosed handler since the FDCC
// (this class) doesn't have a transport that would do that. // (this class) doesn't have a transport that would do that.
if (it != connected_channels_.end()) if (it != connected_channels_.end())