Handle corner case in SctpDataChannel::ObserverAdapter

This handles a corner case whereby an OnStateChange implementation
synchronously calls UnregisterObserver, which would (before this CL)
delete the observer adapter.

(Using No-Try since an import bot won't pass until this CL lands)

No-Try: True
Bug: webrtc:11547
Change-Id: I33a13495aad6151fdd76becfa9a2c8672d80d825
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/300280
Reviewed-by: Mirko Bonadei <mbonadei@webrtc.org>
Commit-Queue: Tomas Gunnarsson <tommi@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#39761}
This commit is contained in:
Tommi 2023-04-04 20:10:56 +02:00 committed by WebRTC LUCI CQ
parent fe53fec24e
commit 1b3c89878e
2 changed files with 58 additions and 10 deletions

View File

@ -179,6 +179,33 @@ TEST_F(SctpDataChannelTest, ConnectedToTransportOnCreated) {
EXPECT_TRUE(controller_->IsStreamAdded(sid));
}
// Tests that calling `UnregisterObserver()` from within the `OnStateChange`
// is safe.
TEST_F(SctpDataChannelTest, UnregisterObserverFromOnStateChange) {
class TrickyObserver : public DataChannelObserver {
public:
explicit TrickyObserver(DataChannelInterface* channel)
: channel_(channel) {}
void OnStateChange() override { channel_->UnregisterObserver(); }
void OnBufferedAmountChange(uint64_t previous_amount) override {}
void OnMessage(const DataBuffer& buffer) override {}
// This test is specifically for the observer adapter inside SctpDataChannel
// that kicks in when the return value from `IsOkToCallOnTheNetworkThread()`
// is false.
bool IsOkToCallOnTheNetworkThread() override { return false; }
private:
DataChannelInterface* channel_;
};
EXPECT_EQ(DataChannelInterface::kConnecting, channel_->state());
TrickyObserver observer(channel_.get());
channel_->RegisterObserver(&observer);
SetChannelReady();
EXPECT_EQ(DataChannelInterface::kOpen, channel_->state());
}
// Tests the state of the data channel.
TEST_F(SctpDataChannelTest, StateTransition) {
AddObserver();

View File

@ -173,6 +173,12 @@ class SctpDataChannel::ObserverAdapter : public DataChannelObserver {
return cached_state_;
}
void SetDelegate(DataChannelObserver* delegate) {
RTC_DCHECK_RUN_ON(signaling_thread());
delegate_ = delegate;
safety_.reset(webrtc::PendingTaskSafetyFlag::CreateDetached());
}
private:
void OnStateChange() override {
RTC_DCHECK_RUN_ON(network_thread());
@ -181,7 +187,8 @@ class SctpDataChannel::ObserverAdapter : public DataChannelObserver {
RTC_DCHECK_RUN_ON(signaling_thread());
cached_state_ = new_state;
inside_state_change_ = true;
delegate_->OnStateChange();
if (delegate_)
delegate_->OnStateChange();
inside_state_change_ = false;
}));
}
@ -189,22 +196,27 @@ class SctpDataChannel::ObserverAdapter : public DataChannelObserver {
void OnMessage(const DataBuffer& buffer) override {
RTC_DCHECK_RUN_ON(network_thread());
signaling_thread()->PostTask(
SafeTask(safety_.flag(),
[this, buffer = buffer] { delegate_->OnMessage(buffer); }));
SafeTask(safety_.flag(), [this, buffer = buffer] {
RTC_DCHECK_RUN_ON(signaling_thread());
if (delegate_)
delegate_->OnMessage(buffer);
}));
}
void OnBufferedAmountChange(uint64_t sent_data_size) override {
RTC_DCHECK_RUN_ON(network_thread());
signaling_thread()->PostTask(
SafeTask(safety_.flag(), [this, sent_data_size] {
delegate_->OnBufferedAmountChange(sent_data_size);
RTC_DCHECK_RUN_ON(signaling_thread());
if (delegate_)
delegate_->OnBufferedAmountChange(sent_data_size);
}));
}
rtc::Thread* signaling_thread() const { return channel_->signaling_thread_; }
rtc::Thread* network_thread() const { return channel_->network_thread_; }
DataChannelObserver* const delegate_;
DataChannelObserver* delegate_ RTC_GUARDED_BY(signaling_thread());
SctpDataChannel* const channel_;
ScopedTaskSafety safety_;
bool inside_state_change_ RTC_GUARDED_BY(signaling_thread()) = false;
@ -291,8 +303,11 @@ void SctpDataChannel::RegisterObserver(DataChannelObserver* observer) {
// should be called on the network thread and IsOkToCallOnTheNetworkThread().
if (!observer->IsOkToCallOnTheNetworkThread()) {
auto prepare_observer = [&]() {
RTC_DCHECK(!observer_adapter_);
observer_adapter_ = std::make_unique<ObserverAdapter>(observer, this);
if (observer_adapter_) {
observer_adapter_->SetDelegate(observer);
} else {
observer_adapter_ = std::make_unique<ObserverAdapter>(observer, this);
}
return observer_adapter_.get();
};
// Instantiate the adapter in the right context and then substitute the
@ -344,12 +359,18 @@ void SctpDataChannel::UnregisterObserver() {
network_thread_->BlockingCall(std::move(unregister_observer));
}
auto clear_observer = [&]() { observer_adapter_.reset(); };
auto clear_delegate = [&]() {
// In case an implementation decides to unregister an observer while
// in a callback from the observer adapter, we can't delete the adapter.
// Instead we'll just clear the delegate pointer.
if (observer_adapter_)
observer_adapter_->SetDelegate(nullptr);
};
if (current_thread != signaling_thread_) {
signaling_thread_->BlockingCall(std::move(clear_observer));
signaling_thread_->BlockingCall(std::move(clear_delegate));
} else {
clear_observer();
clear_delegate();
}
}