diff --git a/pc/data_channel_unittest.cc b/pc/data_channel_unittest.cc index 2582561282..99ff5af1ff 100644 --- a/pc/data_channel_unittest.cc +++ b/pc/data_channel_unittest.cc @@ -179,6 +179,33 @@ TEST_F(SctpDataChannelTest, ConnectedToTransportOnCreated) { EXPECT_TRUE(controller_->IsStreamAdded(sid)); } +// Tests that calling `UnregisterObserver()` from within the `OnStateChange` +// is safe. +TEST_F(SctpDataChannelTest, UnregisterObserverFromOnStateChange) { + class TrickyObserver : public DataChannelObserver { + public: + explicit TrickyObserver(DataChannelInterface* channel) + : channel_(channel) {} + void OnStateChange() override { channel_->UnregisterObserver(); } + void OnBufferedAmountChange(uint64_t previous_amount) override {} + void OnMessage(const DataBuffer& buffer) override {} + + // This test is specifically for the observer adapter inside SctpDataChannel + // that kicks in when the return value from `IsOkToCallOnTheNetworkThread()` + // is false. + bool IsOkToCallOnTheNetworkThread() override { return false; } + + private: + DataChannelInterface* channel_; + }; + + EXPECT_EQ(DataChannelInterface::kConnecting, channel_->state()); + TrickyObserver observer(channel_.get()); + channel_->RegisterObserver(&observer); + SetChannelReady(); + EXPECT_EQ(DataChannelInterface::kOpen, channel_->state()); +} + // Tests the state of the data channel. TEST_F(SctpDataChannelTest, StateTransition) { AddObserver(); diff --git a/pc/sctp_data_channel.cc b/pc/sctp_data_channel.cc index 892eca9aa7..bb93743f68 100644 --- a/pc/sctp_data_channel.cc +++ b/pc/sctp_data_channel.cc @@ -173,6 +173,12 @@ class SctpDataChannel::ObserverAdapter : public DataChannelObserver { return cached_state_; } + void SetDelegate(DataChannelObserver* delegate) { + RTC_DCHECK_RUN_ON(signaling_thread()); + delegate_ = delegate; + safety_.reset(webrtc::PendingTaskSafetyFlag::CreateDetached()); + } + private: void OnStateChange() override { RTC_DCHECK_RUN_ON(network_thread()); @@ -181,7 +187,8 @@ class SctpDataChannel::ObserverAdapter : public DataChannelObserver { RTC_DCHECK_RUN_ON(signaling_thread()); cached_state_ = new_state; inside_state_change_ = true; - delegate_->OnStateChange(); + if (delegate_) + delegate_->OnStateChange(); inside_state_change_ = false; })); } @@ -189,22 +196,27 @@ class SctpDataChannel::ObserverAdapter : public DataChannelObserver { void OnMessage(const DataBuffer& buffer) override { RTC_DCHECK_RUN_ON(network_thread()); signaling_thread()->PostTask( - SafeTask(safety_.flag(), - [this, buffer = buffer] { delegate_->OnMessage(buffer); })); + SafeTask(safety_.flag(), [this, buffer = buffer] { + RTC_DCHECK_RUN_ON(signaling_thread()); + if (delegate_) + delegate_->OnMessage(buffer); + })); } void OnBufferedAmountChange(uint64_t sent_data_size) override { RTC_DCHECK_RUN_ON(network_thread()); signaling_thread()->PostTask( SafeTask(safety_.flag(), [this, sent_data_size] { - delegate_->OnBufferedAmountChange(sent_data_size); + RTC_DCHECK_RUN_ON(signaling_thread()); + if (delegate_) + delegate_->OnBufferedAmountChange(sent_data_size); })); } rtc::Thread* signaling_thread() const { return channel_->signaling_thread_; } rtc::Thread* network_thread() const { return channel_->network_thread_; } - DataChannelObserver* const delegate_; + DataChannelObserver* delegate_ RTC_GUARDED_BY(signaling_thread()); SctpDataChannel* const channel_; ScopedTaskSafety safety_; bool inside_state_change_ RTC_GUARDED_BY(signaling_thread()) = false; @@ -291,8 +303,11 @@ void SctpDataChannel::RegisterObserver(DataChannelObserver* observer) { // should be called on the network thread and IsOkToCallOnTheNetworkThread(). if (!observer->IsOkToCallOnTheNetworkThread()) { auto prepare_observer = [&]() { - RTC_DCHECK(!observer_adapter_); - observer_adapter_ = std::make_unique(observer, this); + if (observer_adapter_) { + observer_adapter_->SetDelegate(observer); + } else { + observer_adapter_ = std::make_unique(observer, this); + } return observer_adapter_.get(); }; // Instantiate the adapter in the right context and then substitute the @@ -344,12 +359,18 @@ void SctpDataChannel::UnregisterObserver() { network_thread_->BlockingCall(std::move(unregister_observer)); } - auto clear_observer = [&]() { observer_adapter_.reset(); }; + auto clear_delegate = [&]() { + // In case an implementation decides to unregister an observer while + // in a callback from the observer adapter, we can't delete the adapter. + // Instead we'll just clear the delegate pointer. + if (observer_adapter_) + observer_adapter_->SetDelegate(nullptr); + }; if (current_thread != signaling_thread_) { - signaling_thread_->BlockingCall(std::move(clear_observer)); + signaling_thread_->BlockingCall(std::move(clear_delegate)); } else { - clear_observer(); + clear_delegate(); } }