diff --git a/net/dcsctp/socket/callback_deferrer.cc b/net/dcsctp/socket/callback_deferrer.cc index 1b7fbacccb..b4af10e88a 100644 --- a/net/dcsctp/socket/callback_deferrer.cc +++ b/net/dcsctp/socket/callback_deferrer.cc @@ -36,12 +36,19 @@ class MessageDeliverer { }; } // namespace +void CallbackDeferrer::Prepare() { + RTC_DCHECK(!prepared_); + prepared_ = true; +} + void CallbackDeferrer::TriggerDeferred() { // Need to swap here. The client may call into the library from within a // callback, and that might result in adding new callbacks to this instance, // and the vector can't be modified while iterated on. + RTC_DCHECK(prepared_); std::vector> deferred; deferred.swap(deferred_); + prepared_ = false; for (auto& cb : deferred) { cb(underlying_); @@ -70,12 +77,14 @@ uint32_t CallbackDeferrer::GetRandomInt(uint32_t low, uint32_t high) { } void CallbackDeferrer::OnMessageReceived(DcSctpMessage message) { + RTC_DCHECK(prepared_); deferred_.emplace_back( [deliverer = MessageDeliverer(std::move(message))]( DcSctpSocketCallbacks& cb) mutable { deliverer.Deliver(cb); }); } void CallbackDeferrer::OnError(ErrorKind error, absl::string_view message) { + RTC_DCHECK(prepared_); deferred_.emplace_back( [error, message = std::string(message)](DcSctpSocketCallbacks& cb) { cb.OnError(error, message); @@ -83,6 +92,7 @@ void CallbackDeferrer::OnError(ErrorKind error, absl::string_view message) { } void CallbackDeferrer::OnAborted(ErrorKind error, absl::string_view message) { + RTC_DCHECK(prepared_); deferred_.emplace_back( [error, message = std::string(message)](DcSctpSocketCallbacks& cb) { cb.OnAborted(error, message); @@ -90,14 +100,17 @@ void CallbackDeferrer::OnAborted(ErrorKind error, absl::string_view message) { } void CallbackDeferrer::OnConnected() { + RTC_DCHECK(prepared_); deferred_.emplace_back([](DcSctpSocketCallbacks& cb) { cb.OnConnected(); }); } void CallbackDeferrer::OnClosed() { + RTC_DCHECK(prepared_); deferred_.emplace_back([](DcSctpSocketCallbacks& cb) { cb.OnClosed(); }); } void CallbackDeferrer::OnConnectionRestarted() { + RTC_DCHECK(prepared_); deferred_.emplace_back( [](DcSctpSocketCallbacks& cb) { cb.OnConnectionRestarted(); }); } @@ -105,6 +118,7 @@ void CallbackDeferrer::OnConnectionRestarted() { void CallbackDeferrer::OnStreamsResetFailed( rtc::ArrayView outgoing_streams, absl::string_view reason) { + RTC_DCHECK(prepared_); deferred_.emplace_back( [streams = std::vector(outgoing_streams.begin(), outgoing_streams.end()), @@ -115,6 +129,7 @@ void CallbackDeferrer::OnStreamsResetFailed( void CallbackDeferrer::OnStreamsResetPerformed( rtc::ArrayView outgoing_streams) { + RTC_DCHECK(prepared_); deferred_.emplace_back( [streams = std::vector(outgoing_streams.begin(), outgoing_streams.end())]( @@ -123,6 +138,7 @@ void CallbackDeferrer::OnStreamsResetPerformed( void CallbackDeferrer::OnIncomingStreamsReset( rtc::ArrayView incoming_streams) { + RTC_DCHECK(prepared_); deferred_.emplace_back( [streams = std::vector(incoming_streams.begin(), incoming_streams.end())]( @@ -130,12 +146,14 @@ void CallbackDeferrer::OnIncomingStreamsReset( } void CallbackDeferrer::OnBufferedAmountLow(StreamID stream_id) { + RTC_DCHECK(prepared_); deferred_.emplace_back([stream_id](DcSctpSocketCallbacks& cb) { cb.OnBufferedAmountLow(stream_id); }); } void CallbackDeferrer::OnTotalBufferedAmountLow() { + RTC_DCHECK(prepared_); deferred_.emplace_back( [](DcSctpSocketCallbacks& cb) { cb.OnTotalBufferedAmountLow(); }); } diff --git a/net/dcsctp/socket/callback_deferrer.h b/net/dcsctp/socket/callback_deferrer.h index ab2739feb1..918b1df32d 100644 --- a/net/dcsctp/socket/callback_deferrer.h +++ b/net/dcsctp/socket/callback_deferrer.h @@ -26,7 +26,6 @@ #include "rtc_base/ref_counted_object.h" namespace dcsctp { - // Defers callbacks until they can be safely triggered. // // There are a lot of callbacks from the dcSCTP library to the client, @@ -44,11 +43,22 @@ namespace dcsctp { // There are a number of exceptions, which is clearly annotated in the API. class CallbackDeferrer : public DcSctpSocketCallbacks { public: + class ScopedDeferrer { + public: + explicit ScopedDeferrer(CallbackDeferrer& callback_deferrer) + : callback_deferrer_(callback_deferrer) { + callback_deferrer_.Prepare(); + } + + ~ScopedDeferrer() { callback_deferrer_.TriggerDeferred(); } + + private: + CallbackDeferrer& callback_deferrer_; + }; + explicit CallbackDeferrer(DcSctpSocketCallbacks& underlying) : underlying_(underlying) {} - void TriggerDeferred(); - // Implementation of DcSctpSocketCallbacks SendPacketStatus SendPacketWithStatus( rtc::ArrayView data) override; @@ -71,7 +81,11 @@ class CallbackDeferrer : public DcSctpSocketCallbacks { void OnTotalBufferedAmountLow() override; private: + void Prepare(); + void TriggerDeferred(); + DcSctpSocketCallbacks& underlying_; + bool prepared_ = false; std::vector> deferred_; }; } // namespace dcsctp diff --git a/net/dcsctp/socket/dcsctp_socket.cc b/net/dcsctp/socket/dcsctp_socket.cc index a1cc12d1e1..10018135e0 100644 --- a/net/dcsctp/socket/dcsctp_socket.cc +++ b/net/dcsctp/socket/dcsctp_socket.cc @@ -281,6 +281,8 @@ void DcSctpSocket::MakeConnectionParameters() { } void DcSctpSocket::Connect() { + CallbackDeferrer::ScopedDeferrer deferrer(callbacks_); + if (state_ == State::kClosed) { MakeConnectionParameters(); RTC_DLOG(LS_INFO) @@ -296,10 +298,11 @@ void DcSctpSocket::Connect() { << "Called Connect on a socket that is not closed"; } RTC_DCHECK(IsConsistent()); - callbacks_.TriggerDeferred(); } void DcSctpSocket::RestoreFromState(const DcSctpSocketHandoverState& state) { + CallbackDeferrer::ScopedDeferrer deferrer(callbacks_); + if (state_ != State::kClosed) { callbacks_.OnError(ErrorKind::kUnsupportedOperation, "Only closed socket can be restored from state"); @@ -334,10 +337,11 @@ void DcSctpSocket::RestoreFromState(const DcSctpSocketHandoverState& state) { } RTC_DCHECK(IsConsistent()); - callbacks_.TriggerDeferred(); } void DcSctpSocket::Shutdown() { + CallbackDeferrer::ScopedDeferrer deferrer(callbacks_); + if (tcb_ != nullptr) { // https://tools.ietf.org/html/rfc4960#section-9.2 // "Upon receipt of the SHUTDOWN primitive from its upper layer, the @@ -361,10 +365,11 @@ void DcSctpSocket::Shutdown() { InternalClose(ErrorKind::kNoError, ""); } RTC_DCHECK(IsConsistent()); - callbacks_.TriggerDeferred(); } void DcSctpSocket::Close() { + CallbackDeferrer::ScopedDeferrer deferrer(callbacks_); + if (state_ != State::kClosed) { if (tcb_ != nullptr) { SctpPacket::Builder b = tcb_->PacketBuilder(); @@ -379,7 +384,6 @@ void DcSctpSocket::Close() { RTC_DLOG(LS_INFO) << log_prefix() << "Called Close on a closed socket"; } RTC_DCHECK(IsConsistent()); - callbacks_.TriggerDeferred(); } void DcSctpSocket::CloseConnectionBecauseOfTooManyTransmissionErrors() { @@ -411,6 +415,8 @@ void DcSctpSocket::InternalClose(ErrorKind error, absl::string_view message) { SendStatus DcSctpSocket::Send(DcSctpMessage message, const SendOptions& send_options) { + CallbackDeferrer::ScopedDeferrer deferrer(callbacks_); + if (message.payload().empty()) { callbacks_.OnError(ErrorKind::kProtocolViolation, "Unable to send empty message"); @@ -445,12 +451,13 @@ SendStatus DcSctpSocket::Send(DcSctpMessage message, } RTC_DCHECK(IsConsistent()); - callbacks_.TriggerDeferred(); return SendStatus::kSuccess; } ResetStreamsStatus DcSctpSocket::ResetStreams( rtc::ArrayView outgoing_streams) { + CallbackDeferrer::ScopedDeferrer deferrer(callbacks_); + if (tcb_ == nullptr) { callbacks_.OnError(ErrorKind::kWrongSequence, "Can't reset streams as the socket is not connected"); @@ -472,7 +479,6 @@ ResetStreamsStatus DcSctpSocket::ResetStreams( } RTC_DCHECK(IsConsistent()); - callbacks_.TriggerDeferred(); return ResetStreamsStatus::kPerformed; } @@ -654,6 +660,8 @@ bool DcSctpSocket::ValidatePacket(const SctpPacket& packet) { } void DcSctpSocket::HandleTimeout(TimeoutID timeout_id) { + CallbackDeferrer::ScopedDeferrer deferrer(callbacks_); + timer_manager_.HandleTimeout(timeout_id); if (tcb_ != nullptr && tcb_->HasTooManyTxErrors()) { @@ -662,10 +670,11 @@ void DcSctpSocket::HandleTimeout(TimeoutID timeout_id) { } RTC_DCHECK(IsConsistent()); - callbacks_.TriggerDeferred(); } void DcSctpSocket::ReceivePacket(rtc::ArrayView data) { + CallbackDeferrer::ScopedDeferrer deferrer(callbacks_); + ++metrics_.rx_packets_count; if (packet_observer_ != nullptr) { @@ -681,7 +690,6 @@ void DcSctpSocket::ReceivePacket(rtc::ArrayView data) { callbacks_.OnError(ErrorKind::kParseFailed, "Failed to parse received SCTP packet"); RTC_DCHECK(IsConsistent()); - callbacks_.TriggerDeferred(); return; } @@ -696,7 +704,6 @@ void DcSctpSocket::ReceivePacket(rtc::ArrayView data) { RTC_DLOG(LS_VERBOSE) << log_prefix() << "Packet failed verification tag check - dropping"; RTC_DCHECK(IsConsistent()); - callbacks_.TriggerDeferred(); return; } @@ -714,7 +721,6 @@ void DcSctpSocket::ReceivePacket(rtc::ArrayView data) { } RTC_DCHECK(IsConsistent()); - callbacks_.TriggerDeferred(); } void DcSctpSocket::DebugPrintOutgoing(rtc::ArrayView payload) { @@ -1646,6 +1652,8 @@ HandoverReadinessStatus DcSctpSocket::GetHandoverReadiness() const { absl::optional DcSctpSocket::GetHandoverStateAndClose() { + CallbackDeferrer::ScopedDeferrer deferrer(callbacks_); + if (!GetHandoverReadiness().IsReady()) { return absl::nullopt; } @@ -1659,7 +1667,6 @@ DcSctpSocket::GetHandoverStateAndClose() { tcb_->AddHandoverState(state); send_queue_.AddHandoverState(state); InternalClose(ErrorKind::kNoError, "handover"); - callbacks_.TriggerDeferred(); } return std::move(state);