diff --git a/net/dcsctp/rx/data_tracker.cc b/net/dcsctp/rx/data_tracker.cc index 8faee9e7d2..1f2e43f7f5 100644 --- a/net/dcsctp/rx/data_tracker.cc +++ b/net/dcsctp/rx/data_tracker.cc @@ -373,4 +373,14 @@ void DataTracker::AddHandoverState(DcSctpSocketHandoverState& state) { state.rx.seen_packet = seen_packet_; } +void DataTracker::RestoreFromState(const DcSctpSocketHandoverState& state) { + // Validate that the component is in pristine state. + RTC_DCHECK(additional_tsn_blocks_.empty()); + RTC_DCHECK(duplicate_tsns_.empty()); + RTC_DCHECK(!seen_packet_); + + seen_packet_ = state.rx.seen_packet; + last_cumulative_acked_tsn_ = + tsn_unwrapper_.Unwrap(TSN(state.rx.last_cumulative_acked_tsn)); +} } // namespace dcsctp diff --git a/net/dcsctp/rx/data_tracker.h b/net/dcsctp/rx/data_tracker.h index 603a237245..ea077a9b57 100644 --- a/net/dcsctp/rx/data_tracker.h +++ b/net/dcsctp/rx/data_tracker.h @@ -54,15 +54,12 @@ class DataTracker { DataTracker(absl::string_view log_prefix, Timer* delayed_ack_timer, - TSN peer_initial_tsn, - const DcSctpSocketHandoverState* handover_state = nullptr) + TSN peer_initial_tsn) : log_prefix_(std::string(log_prefix) + "dtrack: "), - seen_packet_(handover_state != nullptr ? handover_state->rx.seen_packet - : false), + seen_packet_(false), delayed_ack_timer_(*delayed_ack_timer), - last_cumulative_acked_tsn_(tsn_unwrapper_.Unwrap( - handover_state ? TSN(handover_state->rx.last_cumulative_acked_tsn) - : TSN(*peer_initial_tsn - 1))) {} + last_cumulative_acked_tsn_( + tsn_unwrapper_.Unwrap(TSN(*peer_initial_tsn - 1))) {} // Indicates if the provided TSN is valid. If this return false, the data // should be dropped and not added to any other buffers, which essentially @@ -110,6 +107,7 @@ class DataTracker { HandoverReadinessStatus GetHandoverReadiness() const; void AddHandoverState(DcSctpSocketHandoverState& state); + void RestoreFromState(const DcSctpSocketHandoverState& state); private: enum class AckState { diff --git a/net/dcsctp/rx/data_tracker_test.cc b/net/dcsctp/rx/data_tracker_test.cc index 43494734b6..f74dd6eb0b 100644 --- a/net/dcsctp/rx/data_tracker_test.cc +++ b/net/dcsctp/rx/data_tracker_test.cc @@ -66,8 +66,9 @@ class DataTrackerTest : public testing::Test { DcSctpSocketHandoverState state; tracker_->AddHandoverState(state); g_handover_state_transformer_for_test(&state); - tracker_ = std::make_unique("log: ", timer_.get(), kInitialTSN, - &state); + tracker_ = + std::make_unique("log: ", timer_.get(), kInitialTSN); + tracker_->RestoreFromState(state); } TimeMs now_ = TimeMs(0); diff --git a/net/dcsctp/rx/interleaved_reassembly_streams.cc b/net/dcsctp/rx/interleaved_reassembly_streams.cc index 847058b7f8..8b316de676 100644 --- a/net/dcsctp/rx/interleaved_reassembly_streams.cc +++ b/net/dcsctp/rx/interleaved_reassembly_streams.cc @@ -32,26 +32,8 @@ namespace dcsctp { InterleavedReassemblyStreams::InterleavedReassemblyStreams( absl::string_view log_prefix, - OnAssembledMessage on_assembled_message, - const DcSctpSocketHandoverState* handover_state) - : log_prefix_(log_prefix), on_assembled_message_(on_assembled_message) { - if (handover_state) { - for (const DcSctpSocketHandoverState::OrderedStream& state : - handover_state->rx.ordered_streams) { - FullStreamId stream_id(IsUnordered(false), StreamID(state.id)); - streams_.emplace( - std::piecewise_construct, std::forward_as_tuple(stream_id), - std::forward_as_tuple(stream_id, this, MID(state.next_ssn))); - } - for (const DcSctpSocketHandoverState::UnorderedStream& state : - handover_state->rx.unordered_streams) { - FullStreamId stream_id(IsUnordered(true), StreamID(state.id)); - streams_.emplace(std::piecewise_construct, - std::forward_as_tuple(stream_id), - std::forward_as_tuple(stream_id, this)); - } - } -} + OnAssembledMessage on_assembled_message) + : log_prefix_(log_prefix), on_assembled_message_(on_assembled_message) {} size_t InterleavedReassemblyStreams::Stream::TryToAssembleMessage( UnwrappedMID mid) { @@ -267,4 +249,24 @@ void InterleavedReassemblyStreams::AddHandoverState( } } +void InterleavedReassemblyStreams::RestoreFromState( + const DcSctpSocketHandoverState& state) { + // Validate that the component is in pristine state. + RTC_DCHECK(streams_.empty()); + + for (const DcSctpSocketHandoverState::OrderedStream& state : + state.rx.ordered_streams) { + FullStreamId stream_id(IsUnordered(false), StreamID(state.id)); + streams_.emplace( + std::piecewise_construct, std::forward_as_tuple(stream_id), + std::forward_as_tuple(stream_id, this, MID(state.next_ssn))); + } + for (const DcSctpSocketHandoverState::UnorderedStream& state : + state.rx.unordered_streams) { + FullStreamId stream_id(IsUnordered(true), StreamID(state.id)); + streams_.emplace(std::piecewise_construct, std::forward_as_tuple(stream_id), + std::forward_as_tuple(stream_id, this)); + } +} + } // namespace dcsctp diff --git a/net/dcsctp/rx/interleaved_reassembly_streams.h b/net/dcsctp/rx/interleaved_reassembly_streams.h index 9d4bbc799d..a7b67707e9 100644 --- a/net/dcsctp/rx/interleaved_reassembly_streams.h +++ b/net/dcsctp/rx/interleaved_reassembly_streams.h @@ -28,10 +28,8 @@ namespace dcsctp { // enabled on the association, i.e. when RFC8260 is in use. class InterleavedReassemblyStreams : public ReassemblyStreams { public: - InterleavedReassemblyStreams( - absl::string_view log_prefix, - OnAssembledMessage on_assembled_message, - const DcSctpSocketHandoverState* handover_state = nullptr); + InterleavedReassemblyStreams(absl::string_view log_prefix, + OnAssembledMessage on_assembled_message); int Add(UnwrappedTSN tsn, Data data) override; @@ -44,6 +42,7 @@ class InterleavedReassemblyStreams : public ReassemblyStreams { HandoverReadinessStatus GetHandoverReadiness() const override; void AddHandoverState(DcSctpSocketHandoverState& state) override; + void RestoreFromState(const DcSctpSocketHandoverState& state) override; private: struct FullStreamId { diff --git a/net/dcsctp/rx/reassembly_queue.cc b/net/dcsctp/rx/reassembly_queue.cc index e0c47f731b..f72c5cb8c1 100644 --- a/net/dcsctp/rx/reassembly_queue.cc +++ b/net/dcsctp/rx/reassembly_queue.cc @@ -39,53 +39,43 @@ namespace { std::unique_ptr CreateStreams( absl::string_view log_prefix, ReassemblyStreams::OnAssembledMessage on_assembled_message, - bool use_message_interleaving, - const DcSctpSocketHandoverState* handover_state) { + bool use_message_interleaving) { if (use_message_interleaving) { return std::make_unique( - log_prefix, std::move(on_assembled_message), handover_state); + log_prefix, std::move(on_assembled_message)); } return std::make_unique( - log_prefix, std::move(on_assembled_message), handover_state); + log_prefix, std::move(on_assembled_message)); } } // namespace -ReassemblyQueue::ReassemblyQueue( - absl::string_view log_prefix, - TSN peer_initial_tsn, - size_t max_size_bytes, - bool use_message_interleaving, - const DcSctpSocketHandoverState* handover_state) +ReassemblyQueue::ReassemblyQueue(absl::string_view log_prefix, + TSN peer_initial_tsn, + size_t max_size_bytes, + bool use_message_interleaving) : log_prefix_(std::string(log_prefix) + "reasm: "), max_size_bytes_(max_size_bytes), watermark_bytes_(max_size_bytes * kHighWatermarkLimit), - last_assembled_tsn_watermark_(tsn_unwrapper_.Unwrap( - handover_state ? TSN(handover_state->rx.last_assembled_tsn) - : TSN(*peer_initial_tsn - 1))), - last_completed_reset_req_seq_nbr_( - handover_state - ? ReconfigRequestSN( - handover_state->rx.last_completed_deferred_reset_req_sn) - : ReconfigRequestSN(0)), + last_assembled_tsn_watermark_( + tsn_unwrapper_.Unwrap(TSN(*peer_initial_tsn - 1))), + last_completed_reset_req_seq_nbr_(ReconfigRequestSN(0)), streams_(CreateStreams( log_prefix_, [this](rtc::ArrayView tsns, DcSctpMessage message) { AddReassembledMessage(tsns, std::move(message)); }, - use_message_interleaving, - handover_state)) {} + use_message_interleaving)) {} void ReassemblyQueue::Add(TSN tsn, Data data) { RTC_DCHECK(IsConsistent()); RTC_DLOG(LS_VERBOSE) << log_prefix_ << "added tsn=" << *tsn << ", stream=" << *data.stream_id << ":" << *data.message_id << ":" << *data.fsn << ", type=" - << (data.is_beginning && data.is_end - ? "complete" - : data.is_beginning - ? "first" - : data.is_end ? "last" : "middle"); + << (data.is_beginning && data.is_end ? "complete" + : data.is_beginning ? "first" + : data.is_end ? "last" + : "middle"); UnwrappedTSN unwrapped_tsn = tsn_unwrapper_.Unwrap(tsn); @@ -309,4 +299,14 @@ void ReassemblyQueue::AddHandoverState(DcSctpSocketHandoverState& state) { streams_->AddHandoverState(state); } +void ReassemblyQueue::RestoreFromState(const DcSctpSocketHandoverState& state) { + // Validate that the component is in pristine state. + RTC_DCHECK(last_completed_reset_req_seq_nbr_ == ReconfigRequestSN(0)); + + last_assembled_tsn_watermark_ = + tsn_unwrapper_.Unwrap(TSN(state.rx.last_assembled_tsn)); + last_completed_reset_req_seq_nbr_ = + ReconfigRequestSN(state.rx.last_completed_deferred_reset_req_sn); + streams_->RestoreFromState(state); +} } // namespace dcsctp diff --git a/net/dcsctp/rx/reassembly_queue.h b/net/dcsctp/rx/reassembly_queue.h index ab5dd5e1b4..91f30a3f69 100644 --- a/net/dcsctp/rx/reassembly_queue.h +++ b/net/dcsctp/rx/reassembly_queue.h @@ -72,8 +72,7 @@ class ReassemblyQueue { ReassemblyQueue(absl::string_view log_prefix, TSN peer_initial_tsn, size_t max_size_bytes, - bool use_message_interleaving = false, - const DcSctpSocketHandoverState* handover_state = nullptr); + bool use_message_interleaving = false); // Adds a data chunk to the queue, with a `tsn` and other parameters in // `data`. @@ -124,6 +123,7 @@ class ReassemblyQueue { HandoverReadinessStatus GetHandoverReadiness() const; void AddHandoverState(DcSctpSocketHandoverState& state); + void RestoreFromState(const DcSctpSocketHandoverState& state); private: bool IsConsistent() const; diff --git a/net/dcsctp/rx/reassembly_queue_test.cc b/net/dcsctp/rx/reassembly_queue_test.cc index cac469f89f..549bc6fce1 100644 --- a/net/dcsctp/rx/reassembly_queue_test.cc +++ b/net/dcsctp/rx/reassembly_queue_test.cc @@ -376,7 +376,8 @@ TEST_F(ReassemblyQueueTest, HandoverInInitialState) { reasm1.AddHandoverState(state); g_handover_state_transformer_for_test(&state); ReassemblyQueue reasm2("log: ", TSN(100), kBufferSize, - /*use_message_interleaving=*/false, &state); + /*use_message_interleaving=*/false); + reasm2.RestoreFromState(state); reasm2.Add(TSN(10), gen_.Ordered({1, 2, 3, 4}, "BE")); EXPECT_THAT(reasm2.FlushMessages(), SizeIs(1)); @@ -392,7 +393,8 @@ TEST_F(ReassemblyQueueTest, HandoverAfterHavingAssembedOneMessage) { reasm1.AddHandoverState(state); g_handover_state_transformer_for_test(&state); ReassemblyQueue reasm2("log: ", TSN(100), kBufferSize, - /*use_message_interleaving=*/false, &state); + /*use_message_interleaving=*/false); + reasm2.RestoreFromState(state); reasm2.Add(TSN(11), gen_.Ordered({1, 2, 3, 4}, "BE")); EXPECT_THAT(reasm2.FlushMessages(), SizeIs(1)); diff --git a/net/dcsctp/rx/reassembly_streams.h b/net/dcsctp/rx/reassembly_streams.h index 06f1a781ce..0ecfac0c0a 100644 --- a/net/dcsctp/rx/reassembly_streams.h +++ b/net/dcsctp/rx/reassembly_streams.h @@ -81,6 +81,7 @@ class ReassemblyStreams { virtual HandoverReadinessStatus GetHandoverReadiness() const = 0; virtual void AddHandoverState(DcSctpSocketHandoverState& state) = 0; + virtual void RestoreFromState(const DcSctpSocketHandoverState& state) = 0; }; } // namespace dcsctp diff --git a/net/dcsctp/rx/traditional_reassembly_streams.cc b/net/dcsctp/rx/traditional_reassembly_streams.cc index f5dc8cacc8..dce6c90131 100644 --- a/net/dcsctp/rx/traditional_reassembly_streams.cc +++ b/net/dcsctp/rx/traditional_reassembly_streams.cc @@ -80,27 +80,9 @@ absl::optional::iterator> FindEnd( TraditionalReassemblyStreams::TraditionalReassemblyStreams( absl::string_view log_prefix, - OnAssembledMessage on_assembled_message, - const DcSctpSocketHandoverState* handover_state) + OnAssembledMessage on_assembled_message) : log_prefix_(log_prefix), - on_assembled_message_(std::move(on_assembled_message)) { - if (handover_state) { - for (const DcSctpSocketHandoverState::OrderedStream& state_stream : - handover_state->rx.ordered_streams) { - ordered_streams_.emplace( - std::piecewise_construct, - std::forward_as_tuple(StreamID(state_stream.id)), - std::forward_as_tuple(this, SSN(state_stream.next_ssn))); - } - for (const DcSctpSocketHandoverState::UnorderedStream& state_stream : - handover_state->rx.unordered_streams) { - unordered_streams_.emplace( - std::piecewise_construct, - std::forward_as_tuple(StreamID(state_stream.id)), - std::forward_as_tuple(this)); - } - } -} + on_assembled_message_(std::move(on_assembled_message)) {} int TraditionalReassemblyStreams::UnorderedStream::Add(UnwrappedTSN tsn, Data data) { @@ -342,4 +324,25 @@ void TraditionalReassemblyStreams::AddHandoverState( } } +void TraditionalReassemblyStreams::RestoreFromState( + const DcSctpSocketHandoverState& state) { + // Validate that the component is in pristine state. + RTC_DCHECK(ordered_streams_.empty()); + RTC_DCHECK(unordered_streams_.empty()); + + for (const DcSctpSocketHandoverState::OrderedStream& state_stream : + state.rx.ordered_streams) { + ordered_streams_.emplace( + std::piecewise_construct, + std::forward_as_tuple(StreamID(state_stream.id)), + std::forward_as_tuple(this, SSN(state_stream.next_ssn))); + } + for (const DcSctpSocketHandoverState::UnorderedStream& state_stream : + state.rx.unordered_streams) { + unordered_streams_.emplace(std::piecewise_construct, + std::forward_as_tuple(StreamID(state_stream.id)), + std::forward_as_tuple(this)); + } +} + } // namespace dcsctp diff --git a/net/dcsctp/rx/traditional_reassembly_streams.h b/net/dcsctp/rx/traditional_reassembly_streams.h index 2fac9ff683..4825afd1ba 100644 --- a/net/dcsctp/rx/traditional_reassembly_streams.h +++ b/net/dcsctp/rx/traditional_reassembly_streams.h @@ -29,10 +29,8 @@ namespace dcsctp { // RFC4960 is to be followed. class TraditionalReassemblyStreams : public ReassemblyStreams { public: - TraditionalReassemblyStreams( - absl::string_view log_prefix, - OnAssembledMessage on_assembled_message, - const DcSctpSocketHandoverState* handover_state = nullptr); + TraditionalReassemblyStreams(absl::string_view log_prefix, + OnAssembledMessage on_assembled_message); int Add(UnwrappedTSN tsn, Data data) override; @@ -45,6 +43,7 @@ class TraditionalReassemblyStreams : public ReassemblyStreams { HandoverReadinessStatus GetHandoverReadiness() const override; void AddHandoverState(DcSctpSocketHandoverState& state) override; + void RestoreFromState(const DcSctpSocketHandoverState& state) override; private: using ChunkMap = std::map; diff --git a/net/dcsctp/rx/traditional_reassembly_streams_test.cc b/net/dcsctp/rx/traditional_reassembly_streams_test.cc index 759962473d..341870442d 100644 --- a/net/dcsctp/rx/traditional_reassembly_streams_test.cc +++ b/net/dcsctp/rx/traditional_reassembly_streams_test.cc @@ -160,8 +160,8 @@ TEST_F(TraditionalReassemblyStreamsTest, NoStreamsCanBeHandedOver) { DcSctpSocketHandoverState state; streams1.AddHandoverState(state); g_handover_state_transformer_for_test(&state); - TraditionalReassemblyStreams streams2("", on_assembled.AsStdFunction(), - &state); + TraditionalReassemblyStreams streams2("", on_assembled.AsStdFunction()); + streams2.RestoreFromState(state); EXPECT_EQ(streams2.Add(tsn(1), gen_.Ordered({1}, "B")), 1); EXPECT_EQ(streams2.Add(tsn(2), gen_.Ordered({2, 3, 4})), 3); @@ -196,8 +196,8 @@ TEST_F(TraditionalReassemblyStreamsTest, DcSctpSocketHandoverState state; streams1.AddHandoverState(state); g_handover_state_transformer_for_test(&state); - TraditionalReassemblyStreams streams2("", on_assembled.AsStdFunction(), - &state); + TraditionalReassemblyStreams streams2("", on_assembled.AsStdFunction()); + streams2.RestoreFromState(state); EXPECT_EQ(streams2.Add(tsn(4), gen_.Ordered({7})), 1); } @@ -229,8 +229,8 @@ TEST_F(TraditionalReassemblyStreamsTest, DcSctpSocketHandoverState state; streams1.AddHandoverState(state); g_handover_state_transformer_for_test(&state); - TraditionalReassemblyStreams streams2("", on_assembled.AsStdFunction(), - &state); + TraditionalReassemblyStreams streams2("", on_assembled.AsStdFunction()); + streams2.RestoreFromState(state); EXPECT_EQ(streams2.Add(tsn(4), gen_.Unordered({7})), 1); } diff --git a/net/dcsctp/socket/dcsctp_socket.cc b/net/dcsctp/socket/dcsctp_socket.cc index 56abb492c0..421b3bfea3 100644 --- a/net/dcsctp/socket/dcsctp_socket.cc +++ b/net/dcsctp/socket/dcsctp_socket.cc @@ -306,6 +306,22 @@ void DcSctpSocket::Connect() { RTC_DCHECK(IsConsistent()); } +void DcSctpSocket::CreateTransmissionControlBlock( + const Capabilities& capabilities, + VerificationTag my_verification_tag, + TSN my_initial_tsn, + VerificationTag peer_verification_tag, + TSN peer_initial_tsn, + size_t a_rwnd, + TieTag tie_tag) { + tcb_ = std::make_unique( + timer_manager_, log_prefix_, options_, capabilities, callbacks_, + send_queue_, my_verification_tag, my_initial_tsn, peer_verification_tag, + peer_initial_tsn, a_rwnd, tie_tag, packet_sender_, + [this]() { return state_ == State::kEstablished; }); + RTC_DLOG(LS_VERBOSE) << log_prefix() << "Created TCB: " << tcb_->ToString(); +} + void DcSctpSocket::RestoreFromState(const DcSctpSocketHandoverState& state) { RTC_DCHECK_RUN_ON(&thread_checker_); CallbackDeferrer::ScopedDeferrer deferrer(callbacks_); @@ -328,15 +344,13 @@ void DcSctpSocket::RestoreFromState(const DcSctpSocketHandoverState& state) { send_queue_.RestoreFromState(state); - tcb_ = std::make_unique( - timer_manager_, log_prefix_, options_, capabilities, callbacks_, - send_queue_, my_verification_tag, TSN(state.my_initial_tsn), + CreateTransmissionControlBlock( + capabilities, my_verification_tag, TSN(state.my_initial_tsn), VerificationTag(state.peer_verification_tag), TSN(state.peer_initial_tsn), static_cast(0), - TieTag(state.tie_tag), packet_sender_, - [this]() { return state_ == State::kEstablished; }, &state); - RTC_DLOG(LS_VERBOSE) << log_prefix() << "Created peer TCB from state: " - << tcb_->ToString(); + TieTag(state.tie_tag)); + + tcb_->RestoreFromState(state); SetState(State::kEstablished, "restored from handover state"); callbacks_.OnConnected(); @@ -1201,14 +1215,18 @@ void DcSctpSocket::HandleInitAck( metrics_.peer_implementation = DeterminePeerImplementation(cookie->data()); - tcb_ = std::make_unique( - timer_manager_, log_prefix_, options_, capabilities, callbacks_, - send_queue_, connect_params_.verification_tag, - connect_params_.initial_tsn, chunk->initiate_tag(), chunk->initial_tsn(), - chunk->a_rwnd(), MakeTieTag(callbacks_), packet_sender_, - [this]() { return state_ == State::kEstablished; }); - RTC_DLOG(LS_VERBOSE) << log_prefix() - << "Created peer TCB: " << tcb_->ToString(); + // If the connection is re-established (peer restarted, but re-used old + // connection), make sure that all message identifiers are reset and any + // partly sent message is re-sent in full. The same is true when the socket + // is closed and later re-opened, which never happens in WebRTC, but is a + // valid operation on the SCTP level. Note that in case of handover, the + // send queue is already re-configured, and shouldn't be reset. + send_queue_.Reset(); + + CreateTransmissionControlBlock(capabilities, connect_params_.verification_tag, + connect_params_.initial_tsn, + chunk->initiate_tag(), chunk->initial_tsn(), + chunk->a_rwnd(), MakeTieTag(callbacks_)); SetState(State::kCookieEchoed, "INIT_ACK received"); @@ -1262,14 +1280,18 @@ void DcSctpSocket::HandleCookieEcho( } if (tcb_ == nullptr) { - tcb_ = std::make_unique( - timer_manager_, log_prefix_, options_, cookie->capabilities(), - callbacks_, send_queue_, connect_params_.verification_tag, + // If the connection is re-established (peer restarted, but re-used old + // connection), make sure that all message identifiers are reset and any + // partly sent message is re-sent in full. The same is true when the socket + // is closed and later re-opened, which never happens in WebRTC, but is a + // valid operation on the SCTP level. Note that in case of handover, the + // send queue is already re-configured, and shouldn't be reset. + send_queue_.Reset(); + + CreateTransmissionControlBlock( + cookie->capabilities(), connect_params_.verification_tag, connect_params_.initial_tsn, cookie->initiate_tag(), - cookie->initial_tsn(), cookie->a_rwnd(), MakeTieTag(callbacks_), - packet_sender_, [this]() { return state_ == State::kEstablished; }); - RTC_DLOG(LS_VERBOSE) << log_prefix() - << "Created peer TCB: " << tcb_->ToString(); + cookie->initial_tsn(), cookie->a_rwnd(), MakeTieTag(callbacks_)); } SctpPacket::Builder b = tcb_->PacketBuilder(); diff --git a/net/dcsctp/socket/dcsctp_socket.h b/net/dcsctp/socket/dcsctp_socket.h index d70d0fca54..157c515d65 100644 --- a/net/dcsctp/socket/dcsctp_socket.h +++ b/net/dcsctp/socket/dcsctp_socket.h @@ -138,6 +138,14 @@ class DcSctpSocket : public DcSctpSocketInterface { bool IsConsistent() const; static constexpr absl::string_view ToString(DcSctpSocket::State state); + void CreateTransmissionControlBlock(const Capabilities& capabilities, + VerificationTag my_verification_tag, + TSN my_initial_tsn, + VerificationTag peer_verification_tag, + TSN peer_initial_tsn, + size_t a_rwnd, + TieTag tie_tag); + // Changes the socket state, given a `reason` (for debugging/logging). void SetState(State state, absl::string_view reason); // Fills in `connect_params` with random verification tag and initial TSN. diff --git a/net/dcsctp/socket/stream_reset_handler_test.cc b/net/dcsctp/socket/stream_reset_handler_test.cc index a9a8b36bf7..e1e54d0422 100644 --- a/net/dcsctp/socket/stream_reset_handler_test.cc +++ b/net/dcsctp/socket/stream_reset_handler_test.cc @@ -193,14 +193,17 @@ class StreamResetHandlerTest : public testing::Test { g_handover_state_transformer_for_test(&state); data_tracker_ = std::make_unique( - "log: ", delayed_ack_timer_.get(), kPeerInitialTsn, &state); - reasm_ = std::make_unique("log: ", kPeerInitialTsn, kArwnd, - &state); + "log: ", delayed_ack_timer_.get(), kPeerInitialTsn); + data_tracker_->RestoreFromState(state); + reasm_ = + std::make_unique("log: ", kPeerInitialTsn, kArwnd); + reasm_->RestoreFromState(state); retransmission_queue_ = std::make_unique( "", kMyInitialTsn, kArwnd, producer_, [](DurationMs rtt_ms) {}, []() {}, *t3_rtx_timer_, DcSctpOptions(), /*supports_partial_reliability=*/true, - /*use_message_interleaving=*/false, &state); + /*use_message_interleaving=*/false); + retransmission_queue_->RestoreFromState(state); handler_ = std::make_unique( "log: ", &ctx_, &timer_manager_, data_tracker_.get(), reasm_.get(), retransmission_queue_.get(), &state); diff --git a/net/dcsctp/socket/transmission_control_block.cc b/net/dcsctp/socket/transmission_control_block.cc index 78331d5e96..44a1b7392c 100644 --- a/net/dcsctp/socket/transmission_control_block.cc +++ b/net/dcsctp/socket/transmission_control_block.cc @@ -37,6 +37,77 @@ namespace dcsctp { +TransmissionControlBlock::TransmissionControlBlock( + TimerManager& timer_manager, + absl::string_view log_prefix, + const DcSctpOptions& options, + const Capabilities& capabilities, + DcSctpSocketCallbacks& callbacks, + SendQueue& send_queue, + VerificationTag my_verification_tag, + TSN my_initial_tsn, + VerificationTag peer_verification_tag, + TSN peer_initial_tsn, + size_t a_rwnd, + TieTag tie_tag, + PacketSender& packet_sender, + std::function is_connection_established) + : log_prefix_(log_prefix), + options_(options), + timer_manager_(timer_manager), + capabilities_(capabilities), + callbacks_(callbacks), + t3_rtx_(timer_manager_.CreateTimer( + "t3-rtx", + absl::bind_front(&TransmissionControlBlock::OnRtxTimerExpiry, this), + TimerOptions(options.rto_initial, + TimerBackoffAlgorithm::kExponential, + /*max_restarts=*/absl::nullopt, + options.max_timer_backoff_duration))), + delayed_ack_timer_(timer_manager_.CreateTimer( + "delayed-ack", + absl::bind_front(&TransmissionControlBlock::OnDelayedAckTimerExpiry, + this), + TimerOptions(options.delayed_ack_max_timeout, + TimerBackoffAlgorithm::kExponential, + /*max_restarts=*/0, + /*max_backoff_duration=*/absl::nullopt, + webrtc::TaskQueueBase::DelayPrecision::kHigh))), + my_verification_tag_(my_verification_tag), + my_initial_tsn_(my_initial_tsn), + peer_verification_tag_(peer_verification_tag), + peer_initial_tsn_(peer_initial_tsn), + tie_tag_(tie_tag), + is_connection_established_(std::move(is_connection_established)), + packet_sender_(packet_sender), + rto_(options), + tx_error_counter_(log_prefix, options), + data_tracker_(log_prefix, delayed_ack_timer_.get(), peer_initial_tsn), + reassembly_queue_(log_prefix, + peer_initial_tsn, + options.max_receiver_window_buffer_size, + capabilities.message_interleaving), + retransmission_queue_( + log_prefix, + my_initial_tsn, + a_rwnd, + send_queue, + absl::bind_front(&TransmissionControlBlock::ObserveRTT, this), + [this]() { tx_error_counter_.Clear(); }, + *t3_rtx_, + options, + capabilities.partial_reliability, + capabilities.message_interleaving), + stream_reset_handler_(log_prefix, + this, + &timer_manager, + &data_tracker_, + &reassembly_queue_, + &retransmission_queue_), + heartbeat_handler_(log_prefix, options, this, &timer_manager_) { + send_queue.EnableMessageInterleaving(capabilities.message_interleaving); +} + void TransmissionControlBlock::ObserveRTT(DurationMs rtt) { DurationMs prev_rto = rto_.rto(); rto_.ObserveRTT(rtt); @@ -232,4 +303,11 @@ void TransmissionControlBlock::AddHandoverState( reassembly_queue_.AddHandoverState(state); retransmission_queue_.AddHandoverState(state); } + +void TransmissionControlBlock::RestoreFromState( + const DcSctpSocketHandoverState& state) { + data_tracker_.RestoreFromState(state); + retransmission_queue_.RestoreFromState(state); + reassembly_queue_.RestoreFromState(state); +} } // namespace dcsctp diff --git a/net/dcsctp/socket/transmission_control_block.h b/net/dcsctp/socket/transmission_control_block.h index f21278845b..8e0e9a3ec5 100644 --- a/net/dcsctp/socket/transmission_control_block.h +++ b/net/dcsctp/socket/transmission_control_block.h @@ -45,92 +45,20 @@ namespace dcsctp { // closed or restarted, this object will be deleted and/or replaced. class TransmissionControlBlock : public Context { public: - TransmissionControlBlock( - TimerManager& timer_manager, - absl::string_view log_prefix, - const DcSctpOptions& options, - const Capabilities& capabilities, - DcSctpSocketCallbacks& callbacks, - SendQueue& send_queue, - VerificationTag my_verification_tag, - TSN my_initial_tsn, - VerificationTag peer_verification_tag, - TSN peer_initial_tsn, - size_t a_rwnd, - TieTag tie_tag, - PacketSender& packet_sender, - std::function is_connection_established, - const DcSctpSocketHandoverState* handover_state = nullptr) - : log_prefix_(log_prefix), - options_(options), - timer_manager_(timer_manager), - capabilities_(capabilities), - callbacks_(callbacks), - t3_rtx_(timer_manager_.CreateTimer( - "t3-rtx", - absl::bind_front(&TransmissionControlBlock::OnRtxTimerExpiry, this), - TimerOptions(options.rto_initial, - TimerBackoffAlgorithm::kExponential, - /*max_restarts=*/absl::nullopt, - options.max_timer_backoff_duration))), - delayed_ack_timer_(timer_manager_.CreateTimer( - "delayed-ack", - absl::bind_front(&TransmissionControlBlock::OnDelayedAckTimerExpiry, - this), - TimerOptions(options.delayed_ack_max_timeout, - TimerBackoffAlgorithm::kExponential, - /*max_restarts=*/0, - /*max_backoff_duration=*/absl::nullopt, - webrtc::TaskQueueBase::DelayPrecision::kHigh))), - my_verification_tag_(my_verification_tag), - my_initial_tsn_(my_initial_tsn), - peer_verification_tag_(peer_verification_tag), - peer_initial_tsn_(peer_initial_tsn), - tie_tag_(tie_tag), - is_connection_established_(std::move(is_connection_established)), - packet_sender_(packet_sender), - rto_(options), - tx_error_counter_(log_prefix, options), - data_tracker_(log_prefix, - delayed_ack_timer_.get(), - peer_initial_tsn, - handover_state), - reassembly_queue_(log_prefix, - peer_initial_tsn, - options.max_receiver_window_buffer_size, - capabilities.message_interleaving, - handover_state), - retransmission_queue_( - log_prefix, - my_initial_tsn, - a_rwnd, - send_queue, - absl::bind_front(&TransmissionControlBlock::ObserveRTT, this), - [this]() { tx_error_counter_.Clear(); }, - *t3_rtx_, - options, - capabilities.partial_reliability, - capabilities.message_interleaving, - handover_state), - stream_reset_handler_(log_prefix, - this, - &timer_manager, - &data_tracker_, - &reassembly_queue_, - &retransmission_queue_, - handover_state), - heartbeat_handler_(log_prefix, options, this, &timer_manager_) { - // If the connection is re-established (peer restarted, but re-used old - // connection), make sure that all message identifiers are reset and any - // partly sent message is re-sent in full. The same is true when the socket - // is closed and later re-opened, which never happens in WebRTC, but is a - // valid operation on the SCTP level. Note that in case of handover, the - // send queue is already re-configured, and shouldn't be reset. - if (handover_state == nullptr) { - send_queue.Reset(); - } - send_queue.EnableMessageInterleaving(capabilities.message_interleaving); - } + TransmissionControlBlock(TimerManager& timer_manager, + absl::string_view log_prefix, + const DcSctpOptions& options, + const Capabilities& capabilities, + DcSctpSocketCallbacks& callbacks, + SendQueue& send_queue, + VerificationTag my_verification_tag, + TSN my_initial_tsn, + VerificationTag peer_verification_tag, + TSN peer_initial_tsn, + size_t a_rwnd, + TieTag tie_tag, + PacketSender& packet_sender, + std::function is_connection_established); // Implementation of `Context`. bool is_connection_established() const override { @@ -216,6 +144,7 @@ class TransmissionControlBlock : public Context { HandoverReadinessStatus GetHandoverReadiness() const; void AddHandoverState(DcSctpSocketHandoverState& state); + void RestoreFromState(const DcSctpSocketHandoverState& handover_state); private: // Will be called when the retransmission timer (t3-rtx) expires. diff --git a/net/dcsctp/tx/outstanding_data.cc b/net/dcsctp/tx/outstanding_data.cc index c013ac5bdd..91651e9a73 100644 --- a/net/dcsctp/tx/outstanding_data.cc +++ b/net/dcsctp/tx/outstanding_data.cc @@ -517,4 +517,12 @@ IForwardTsnChunk OutstandingData::CreateIForwardTsn() const { std::move(skipped_streams)); } +void OutstandingData::ResetSequenceNumbers(UnwrappedTSN next_tsn, + UnwrappedTSN last_cumulative_tsn) { + RTC_DCHECK(outstanding_data_.empty()); + RTC_DCHECK(next_tsn_ == last_cumulative_tsn_ack_.next_value()); + RTC_DCHECK(next_tsn == last_cumulative_tsn.next_value()); + next_tsn_ = next_tsn; + last_cumulative_tsn_ack_ = last_cumulative_tsn; +} } // namespace dcsctp diff --git a/net/dcsctp/tx/outstanding_data.h b/net/dcsctp/tx/outstanding_data.h index 382490b52f..5c638680b7 100644 --- a/net/dcsctp/tx/outstanding_data.h +++ b/net/dcsctp/tx/outstanding_data.h @@ -147,6 +147,10 @@ class OutstandingData { // abandoned, which means that a FORWARD-TSN should be sent. bool ShouldSendForwardTsn() const; + // Sets the next TSN to be used. This is used in handover. + void ResetSequenceNumbers(UnwrappedTSN next_tsn, + UnwrappedTSN last_cumulative_tsn); + private: // A fragmented message's DATA chunk while in the retransmission queue, and // its associated metadata. diff --git a/net/dcsctp/tx/retransmission_queue.cc b/net/dcsctp/tx/retransmission_queue.cc index f26e8baa44..0ca02b0b88 100644 --- a/net/dcsctp/tx/retransmission_queue.cc +++ b/net/dcsctp/tx/retransmission_queue.cc @@ -59,8 +59,7 @@ RetransmissionQueue::RetransmissionQueue( Timer& t3_rtx, const DcSctpOptions& options, bool supports_partial_reliability, - bool use_message_interleaving, - const DcSctpSocketHandoverState* handover_state) + bool use_message_interleaving) : options_(options), min_bytes_required_to_send_(options.mtu * kMinBytesRequiredToSendFactor), partial_reliability_(supports_partial_reliability), @@ -72,25 +71,19 @@ RetransmissionQueue::RetransmissionQueue( on_clear_retransmission_counter_( std::move(on_clear_retransmission_counter)), t3_rtx_(t3_rtx), - cwnd_(handover_state ? handover_state->tx.cwnd - : options_.cwnd_mtus_initial * options_.mtu), - rwnd_(handover_state ? handover_state->tx.rwnd : a_rwnd), + cwnd_(options_.cwnd_mtus_initial * options_.mtu), + rwnd_(a_rwnd), // https://tools.ietf.org/html/rfc4960#section-7.2.1 // "The initial value of ssthresh MAY be arbitrarily high (for // example, implementations MAY use the size of the receiver advertised // window)."" - ssthresh_(handover_state ? handover_state->tx.ssthresh : rwnd_), - partial_bytes_acked_( - handover_state ? handover_state->tx.partial_bytes_acked : 0), + ssthresh_(rwnd_), + partial_bytes_acked_(0), send_queue_(send_queue), outstanding_data_( data_chunk_header_size_, - tsn_unwrapper_.Unwrap(handover_state - ? TSN(handover_state->tx.next_tsn) - : my_initial_tsn), - tsn_unwrapper_.Unwrap(handover_state - ? TSN(handover_state->tx.next_tsn - 1) - : TSN(*my_initial_tsn - 1)), + tsn_unwrapper_.Unwrap(my_initial_tsn), + tsn_unwrapper_.Unwrap(TSN(*my_initial_tsn - 1)), [this](IsUnordered unordered, StreamID stream_id, MID message_id) { return send_queue_.Discard(unordered, stream_id, message_id); }) {} @@ -578,4 +571,21 @@ void RetransmissionQueue::AddHandoverState(DcSctpSocketHandoverState& state) { state.tx.ssthresh = ssthresh_; state.tx.partial_bytes_acked = partial_bytes_acked_; } + +void RetransmissionQueue::RestoreFromState( + const DcSctpSocketHandoverState& state) { + // Validate that the component is in pristine state. + RTC_DCHECK(outstanding_data_.empty()); + RTC_DCHECK(!t3_rtx_.is_running()); + RTC_DCHECK(partial_bytes_acked_ == 0); + + cwnd_ = state.tx.cwnd; + rwnd_ = state.tx.rwnd; + ssthresh_ = state.tx.ssthresh; + partial_bytes_acked_ = state.tx.partial_bytes_acked; + + outstanding_data_.ResetSequenceNumbers( + tsn_unwrapper_.Unwrap(TSN(state.tx.next_tsn)), + tsn_unwrapper_.Unwrap(TSN(state.tx.next_tsn - 1))); +} } // namespace dcsctp diff --git a/net/dcsctp/tx/retransmission_queue.h b/net/dcsctp/tx/retransmission_queue.h index 1958dfd643..51eeb5a319 100644 --- a/net/dcsctp/tx/retransmission_queue.h +++ b/net/dcsctp/tx/retransmission_queue.h @@ -54,18 +54,16 @@ class RetransmissionQueue { // outstanding chunk has been ACKed, it will call // `on_clear_retransmission_counter` and will also use `t3_rtx`, which is the // SCTP retransmission timer to manage retransmissions. - RetransmissionQueue( - absl::string_view log_prefix, - TSN my_initial_tsn, - size_t a_rwnd, - SendQueue& send_queue, - std::function on_new_rtt, - std::function on_clear_retransmission_counter, - Timer& t3_rtx, - const DcSctpOptions& options, - bool supports_partial_reliability = true, - bool use_message_interleaving = false, - const DcSctpSocketHandoverState* handover_state = nullptr); + RetransmissionQueue(absl::string_view log_prefix, + TSN my_initial_tsn, + size_t a_rwnd, + SendQueue& send_queue, + std::function on_new_rtt, + std::function on_clear_retransmission_counter, + Timer& t3_rtx, + const DcSctpOptions& options, + bool supports_partial_reliability = true, + bool use_message_interleaving = false); // Handles a received SACK. Returns true if the `sack` was processed and // false if it was discarded due to received out-of-order and not relevant. @@ -154,6 +152,7 @@ class RetransmissionQueue { HandoverReadinessStatus GetHandoverReadiness() const; void AddHandoverState(DcSctpSocketHandoverState& state); + void RestoreFromState(const DcSctpSocketHandoverState& state); private: enum class CongestionAlgorithmPhase { diff --git a/net/dcsctp/tx/retransmission_queue_test.cc b/net/dcsctp/tx/retransmission_queue_test.cc index 1d28cb23a1..f11ebad19a 100644 --- a/net/dcsctp/tx/retransmission_queue_test.cc +++ b/net/dcsctp/tx/retransmission_queue_test.cc @@ -103,16 +103,19 @@ class RetransmissionQueueTest : public testing::Test { supports_partial_reliability, use_message_interleaving); } - RetransmissionQueue CreateQueueByHandover(RetransmissionQueue& queue) { + std::unique_ptr CreateQueueByHandover( + RetransmissionQueue& queue) { EXPECT_EQ(queue.GetHandoverReadiness(), HandoverReadinessStatus()); DcSctpSocketHandoverState state; queue.AddHandoverState(state); g_handover_state_transformer_for_test(&state); - return RetransmissionQueue( + auto queue2 = std::make_unique( "", TSN(10), kArwnd, producer_, on_rtt_.AsStdFunction(), on_clear_retransmission_counter_.AsStdFunction(), *timer_, options_, /*supports_partial_reliability=*/true, - /*use_message_interleaving=*/false, &state); + /*use_message_interleaving=*/false); + queue2->RestoreFromState(state); + return queue2; } DcSctpOptions options_; @@ -1488,18 +1491,19 @@ TEST_F(RetransmissionQueueTest, HandoverTest) { EXPECT_THAT(GetSentPacketTSNs(queue), SizeIs(2)); queue.HandleSack(now_, SackChunk(TSN(11), kArwnd, {}, {})); - RetransmissionQueue handedover_queue = CreateQueueByHandover(queue); + std::unique_ptr handedover_queue = + CreateQueueByHandover(queue); EXPECT_CALL(producer_, Produce) .WillOnce(CreateChunk()) .WillOnce(CreateChunk()) .WillOnce(CreateChunk()) .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); - EXPECT_THAT(GetSentPacketTSNs(handedover_queue), + EXPECT_THAT(GetSentPacketTSNs(*handedover_queue), testing::ElementsAre(TSN(12), TSN(13), TSN(14))); - handedover_queue.HandleSack(now_, SackChunk(TSN(13), kArwnd, {}, {})); - EXPECT_THAT(handedover_queue.GetChunkStatesForTesting(), + handedover_queue->HandleSack(now_, SackChunk(TSN(13), kArwnd, {}, {})); + EXPECT_THAT(handedover_queue->GetChunkStatesForTesting(), ElementsAre(Pair(TSN(13), State::kAcked), // Pair(TSN(14), State::kInFlight))); }