From 15a0c880cf5c4a0107d634cec2c6ad3d0ee23663 Mon Sep 17 00:00:00 2001 From: Victor Boivie Date: Tue, 28 Sep 2021 21:38:34 +0200 Subject: [PATCH] dcsctp: Ensure callbacks are always triggered The previous manual way of triggering the deferred callbacks was very error-prone, and this was also forgotten at a few places. We can do better. Using the RAII programming idiom, the callbacks are now ensured to be called before returning from public methods. Also added additional debug checks to ensure that there is a ScopedDeferrer active whenever callbacks are deferred. Bug: webrtc:13217 Change-Id: I16a8343b52c00fb30acb018d3846acd0a64318e0 Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/233242 Commit-Queue: Victor Boivie Reviewed-by: Florent Castelli Cr-Commit-Position: refs/heads/main@{#35117} --- net/dcsctp/socket/callback_deferrer.cc | 18 ++++++++++++++++ net/dcsctp/socket/callback_deferrer.h | 20 +++++++++++++++--- net/dcsctp/socket/dcsctp_socket.cc | 29 ++++++++++++++++---------- 3 files changed, 53 insertions(+), 14 deletions(-) diff --git a/net/dcsctp/socket/callback_deferrer.cc b/net/dcsctp/socket/callback_deferrer.cc index 1b7fbacccb..b4af10e88a 100644 --- a/net/dcsctp/socket/callback_deferrer.cc +++ b/net/dcsctp/socket/callback_deferrer.cc @@ -36,12 +36,19 @@ class MessageDeliverer { }; } // namespace +void CallbackDeferrer::Prepare() { + RTC_DCHECK(!prepared_); + prepared_ = true; +} + void CallbackDeferrer::TriggerDeferred() { // Need to swap here. The client may call into the library from within a // callback, and that might result in adding new callbacks to this instance, // and the vector can't be modified while iterated on. + RTC_DCHECK(prepared_); std::vector> deferred; deferred.swap(deferred_); + prepared_ = false; for (auto& cb : deferred) { cb(underlying_); @@ -70,12 +77,14 @@ uint32_t CallbackDeferrer::GetRandomInt(uint32_t low, uint32_t high) { } void CallbackDeferrer::OnMessageReceived(DcSctpMessage message) { + RTC_DCHECK(prepared_); deferred_.emplace_back( [deliverer = MessageDeliverer(std::move(message))]( DcSctpSocketCallbacks& cb) mutable { deliverer.Deliver(cb); }); } void CallbackDeferrer::OnError(ErrorKind error, absl::string_view message) { + RTC_DCHECK(prepared_); deferred_.emplace_back( [error, message = std::string(message)](DcSctpSocketCallbacks& cb) { cb.OnError(error, message); @@ -83,6 +92,7 @@ void CallbackDeferrer::OnError(ErrorKind error, absl::string_view message) { } void CallbackDeferrer::OnAborted(ErrorKind error, absl::string_view message) { + RTC_DCHECK(prepared_); deferred_.emplace_back( [error, message = std::string(message)](DcSctpSocketCallbacks& cb) { cb.OnAborted(error, message); @@ -90,14 +100,17 @@ void CallbackDeferrer::OnAborted(ErrorKind error, absl::string_view message) { } void CallbackDeferrer::OnConnected() { + RTC_DCHECK(prepared_); deferred_.emplace_back([](DcSctpSocketCallbacks& cb) { cb.OnConnected(); }); } void CallbackDeferrer::OnClosed() { + RTC_DCHECK(prepared_); deferred_.emplace_back([](DcSctpSocketCallbacks& cb) { cb.OnClosed(); }); } void CallbackDeferrer::OnConnectionRestarted() { + RTC_DCHECK(prepared_); deferred_.emplace_back( [](DcSctpSocketCallbacks& cb) { cb.OnConnectionRestarted(); }); } @@ -105,6 +118,7 @@ void CallbackDeferrer::OnConnectionRestarted() { void CallbackDeferrer::OnStreamsResetFailed( rtc::ArrayView outgoing_streams, absl::string_view reason) { + RTC_DCHECK(prepared_); deferred_.emplace_back( [streams = std::vector(outgoing_streams.begin(), outgoing_streams.end()), @@ -115,6 +129,7 @@ void CallbackDeferrer::OnStreamsResetFailed( void CallbackDeferrer::OnStreamsResetPerformed( rtc::ArrayView outgoing_streams) { + RTC_DCHECK(prepared_); deferred_.emplace_back( [streams = std::vector(outgoing_streams.begin(), outgoing_streams.end())]( @@ -123,6 +138,7 @@ void CallbackDeferrer::OnStreamsResetPerformed( void CallbackDeferrer::OnIncomingStreamsReset( rtc::ArrayView incoming_streams) { + RTC_DCHECK(prepared_); deferred_.emplace_back( [streams = std::vector(incoming_streams.begin(), incoming_streams.end())]( @@ -130,12 +146,14 @@ void CallbackDeferrer::OnIncomingStreamsReset( } void CallbackDeferrer::OnBufferedAmountLow(StreamID stream_id) { + RTC_DCHECK(prepared_); deferred_.emplace_back([stream_id](DcSctpSocketCallbacks& cb) { cb.OnBufferedAmountLow(stream_id); }); } void CallbackDeferrer::OnTotalBufferedAmountLow() { + RTC_DCHECK(prepared_); deferred_.emplace_back( [](DcSctpSocketCallbacks& cb) { cb.OnTotalBufferedAmountLow(); }); } diff --git a/net/dcsctp/socket/callback_deferrer.h b/net/dcsctp/socket/callback_deferrer.h index ab2739feb1..918b1df32d 100644 --- a/net/dcsctp/socket/callback_deferrer.h +++ b/net/dcsctp/socket/callback_deferrer.h @@ -26,7 +26,6 @@ #include "rtc_base/ref_counted_object.h" namespace dcsctp { - // Defers callbacks until they can be safely triggered. // // There are a lot of callbacks from the dcSCTP library to the client, @@ -44,11 +43,22 @@ namespace dcsctp { // There are a number of exceptions, which is clearly annotated in the API. class CallbackDeferrer : public DcSctpSocketCallbacks { public: + class ScopedDeferrer { + public: + explicit ScopedDeferrer(CallbackDeferrer& callback_deferrer) + : callback_deferrer_(callback_deferrer) { + callback_deferrer_.Prepare(); + } + + ~ScopedDeferrer() { callback_deferrer_.TriggerDeferred(); } + + private: + CallbackDeferrer& callback_deferrer_; + }; + explicit CallbackDeferrer(DcSctpSocketCallbacks& underlying) : underlying_(underlying) {} - void TriggerDeferred(); - // Implementation of DcSctpSocketCallbacks SendPacketStatus SendPacketWithStatus( rtc::ArrayView data) override; @@ -71,7 +81,11 @@ class CallbackDeferrer : public DcSctpSocketCallbacks { void OnTotalBufferedAmountLow() override; private: + void Prepare(); + void TriggerDeferred(); + DcSctpSocketCallbacks& underlying_; + bool prepared_ = false; std::vector> deferred_; }; } // namespace dcsctp diff --git a/net/dcsctp/socket/dcsctp_socket.cc b/net/dcsctp/socket/dcsctp_socket.cc index a1cc12d1e1..10018135e0 100644 --- a/net/dcsctp/socket/dcsctp_socket.cc +++ b/net/dcsctp/socket/dcsctp_socket.cc @@ -281,6 +281,8 @@ void DcSctpSocket::MakeConnectionParameters() { } void DcSctpSocket::Connect() { + CallbackDeferrer::ScopedDeferrer deferrer(callbacks_); + if (state_ == State::kClosed) { MakeConnectionParameters(); RTC_DLOG(LS_INFO) @@ -296,10 +298,11 @@ void DcSctpSocket::Connect() { << "Called Connect on a socket that is not closed"; } RTC_DCHECK(IsConsistent()); - callbacks_.TriggerDeferred(); } void DcSctpSocket::RestoreFromState(const DcSctpSocketHandoverState& state) { + CallbackDeferrer::ScopedDeferrer deferrer(callbacks_); + if (state_ != State::kClosed) { callbacks_.OnError(ErrorKind::kUnsupportedOperation, "Only closed socket can be restored from state"); @@ -334,10 +337,11 @@ void DcSctpSocket::RestoreFromState(const DcSctpSocketHandoverState& state) { } RTC_DCHECK(IsConsistent()); - callbacks_.TriggerDeferred(); } void DcSctpSocket::Shutdown() { + CallbackDeferrer::ScopedDeferrer deferrer(callbacks_); + if (tcb_ != nullptr) { // https://tools.ietf.org/html/rfc4960#section-9.2 // "Upon receipt of the SHUTDOWN primitive from its upper layer, the @@ -361,10 +365,11 @@ void DcSctpSocket::Shutdown() { InternalClose(ErrorKind::kNoError, ""); } RTC_DCHECK(IsConsistent()); - callbacks_.TriggerDeferred(); } void DcSctpSocket::Close() { + CallbackDeferrer::ScopedDeferrer deferrer(callbacks_); + if (state_ != State::kClosed) { if (tcb_ != nullptr) { SctpPacket::Builder b = tcb_->PacketBuilder(); @@ -379,7 +384,6 @@ void DcSctpSocket::Close() { RTC_DLOG(LS_INFO) << log_prefix() << "Called Close on a closed socket"; } RTC_DCHECK(IsConsistent()); - callbacks_.TriggerDeferred(); } void DcSctpSocket::CloseConnectionBecauseOfTooManyTransmissionErrors() { @@ -411,6 +415,8 @@ void DcSctpSocket::InternalClose(ErrorKind error, absl::string_view message) { SendStatus DcSctpSocket::Send(DcSctpMessage message, const SendOptions& send_options) { + CallbackDeferrer::ScopedDeferrer deferrer(callbacks_); + if (message.payload().empty()) { callbacks_.OnError(ErrorKind::kProtocolViolation, "Unable to send empty message"); @@ -445,12 +451,13 @@ SendStatus DcSctpSocket::Send(DcSctpMessage message, } RTC_DCHECK(IsConsistent()); - callbacks_.TriggerDeferred(); return SendStatus::kSuccess; } ResetStreamsStatus DcSctpSocket::ResetStreams( rtc::ArrayView outgoing_streams) { + CallbackDeferrer::ScopedDeferrer deferrer(callbacks_); + if (tcb_ == nullptr) { callbacks_.OnError(ErrorKind::kWrongSequence, "Can't reset streams as the socket is not connected"); @@ -472,7 +479,6 @@ ResetStreamsStatus DcSctpSocket::ResetStreams( } RTC_DCHECK(IsConsistent()); - callbacks_.TriggerDeferred(); return ResetStreamsStatus::kPerformed; } @@ -654,6 +660,8 @@ bool DcSctpSocket::ValidatePacket(const SctpPacket& packet) { } void DcSctpSocket::HandleTimeout(TimeoutID timeout_id) { + CallbackDeferrer::ScopedDeferrer deferrer(callbacks_); + timer_manager_.HandleTimeout(timeout_id); if (tcb_ != nullptr && tcb_->HasTooManyTxErrors()) { @@ -662,10 +670,11 @@ void DcSctpSocket::HandleTimeout(TimeoutID timeout_id) { } RTC_DCHECK(IsConsistent()); - callbacks_.TriggerDeferred(); } void DcSctpSocket::ReceivePacket(rtc::ArrayView data) { + CallbackDeferrer::ScopedDeferrer deferrer(callbacks_); + ++metrics_.rx_packets_count; if (packet_observer_ != nullptr) { @@ -681,7 +690,6 @@ void DcSctpSocket::ReceivePacket(rtc::ArrayView data) { callbacks_.OnError(ErrorKind::kParseFailed, "Failed to parse received SCTP packet"); RTC_DCHECK(IsConsistent()); - callbacks_.TriggerDeferred(); return; } @@ -696,7 +704,6 @@ void DcSctpSocket::ReceivePacket(rtc::ArrayView data) { RTC_DLOG(LS_VERBOSE) << log_prefix() << "Packet failed verification tag check - dropping"; RTC_DCHECK(IsConsistent()); - callbacks_.TriggerDeferred(); return; } @@ -714,7 +721,6 @@ void DcSctpSocket::ReceivePacket(rtc::ArrayView data) { } RTC_DCHECK(IsConsistent()); - callbacks_.TriggerDeferred(); } void DcSctpSocket::DebugPrintOutgoing(rtc::ArrayView payload) { @@ -1646,6 +1652,8 @@ HandoverReadinessStatus DcSctpSocket::GetHandoverReadiness() const { absl::optional DcSctpSocket::GetHandoverStateAndClose() { + CallbackDeferrer::ScopedDeferrer deferrer(callbacks_); + if (!GetHandoverReadiness().IsReady()) { return absl::nullopt; } @@ -1659,7 +1667,6 @@ DcSctpSocket::GetHandoverStateAndClose() { tcb_->AddHandoverState(state); send_queue_.AddHandoverState(state); InternalClose(ErrorKind::kNoError, "handover"); - callbacks_.TriggerDeferred(); } return std::move(state);