From 06fbe63cbf81c7731e1fcd68f26aa064953f3347 Mon Sep 17 00:00:00 2001 From: Victor Boivie Date: Thu, 21 Sep 2023 12:30:36 +0200 Subject: [PATCH] dcsctp: Exit deferred stream reset on FORWARD-TSN https://datatracker.ietf.org/doc/html/rfc6525#section-5.2.2: E2: If the Sender's Last Assigned TSN is greater than the cumulative acknowledgment point, then the endpoint MUST enter "deferred reset processing". ... until the cumulative acknowledgment point reaches the Sender's Last Assigned TSN. The cumulative acknowledgement point can not only be reached by receiving DATA chunks, but also by receiving a FORWARD-TSN that instructs the receiver to skip them. This was only done for DATA and not for FORWARD-TSN, which is now corrected. Additionally, an unnecessary implicit sending of SACK after having received FORWARD-TSN was removed as this is done anyway every time a packet has been received. This unifies the processing of DATA and FORWARD-TSN more. Bug: webrtc:14600 Change-Id: If797d3c46e741074fe05e322d0aebec765a87968 Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/321400 Reviewed-by: Harald Alvestrand Commit-Queue: Victor Boivie Cr-Commit-Position: refs/heads/main@{#40811} --- net/dcsctp/public/types.h | 3 + net/dcsctp/socket/dcsctp_socket.cc | 25 +++--- net/dcsctp/socket/dcsctp_socket.h | 6 +- net/dcsctp/socket/dcsctp_socket_test.cc | 77 +++++++++++++++++++ .../socket/mock_dcsctp_socket_callbacks.h | 4 + net/dcsctp/timer/fake_timeout.h | 15 ++++ 6 files changed, 115 insertions(+), 15 deletions(-) diff --git a/net/dcsctp/public/types.h b/net/dcsctp/public/types.h index d0725620d8..7d69875d1a 100644 --- a/net/dcsctp/public/types.h +++ b/net/dcsctp/public/types.h @@ -41,6 +41,9 @@ class DurationMs : public webrtc::StrongAlias { constexpr explicit DurationMs(const UnderlyingType& v) : webrtc::StrongAlias(v) {} + static constexpr DurationMs InfiniteDuration() { + return DurationMs(std::numeric_limits::max()); + } // Convenience methods for working with time. constexpr DurationMs& operator+=(DurationMs d) { value_ += d.value_; diff --git a/net/dcsctp/socket/dcsctp_socket.cc b/net/dcsctp/socket/dcsctp_socket.cc index a6845e3a90..6101007896 100644 --- a/net/dcsctp/socket/dcsctp_socket.cc +++ b/net/dcsctp/socket/dcsctp_socket.cc @@ -1102,9 +1102,7 @@ void DcSctpSocket::HandleDataCommon(AnyDataChunk& chunk) { if (tcb_->data_tracker().Observe(tsn, immediate_ack)) { tcb_->reassembly_queue().Add(tsn, std::move(data)); - tcb_->reassembly_queue().MaybeResetStreamsDeferred( - tcb_->data_tracker().last_cumulative_acked_tsn()); - DeliverReassembledMessages(); + MaybeResetStreamsDeferredAndDeliverMessages(); } } @@ -1455,12 +1453,15 @@ void DcSctpSocket::HandleCookieAck( callbacks_.OnConnected(); } -void DcSctpSocket::DeliverReassembledMessages() { - if (tcb_->reassembly_queue().HasMessages()) { - for (auto& message : tcb_->reassembly_queue().FlushMessages()) { - ++metrics_.rx_messages_count; - callbacks_.OnMessageReceived(std::move(message)); - } +void DcSctpSocket::MaybeResetStreamsDeferredAndDeliverMessages() { + // As new data has been received, see if paused streams can be resumed, which + // results in even more data added to the reassembly queue. + tcb_->reassembly_queue().MaybeResetStreamsDeferred( + tcb_->data_tracker().last_cumulative_acked_tsn()); + + for (auto& message : tcb_->reassembly_queue().FlushMessages()) { + ++metrics_.rx_messages_count; + callbacks_.OnMessageReceived(std::move(message)); } } @@ -1710,12 +1711,10 @@ void DcSctpSocket::HandleForwardTsnCommon(const AnyForwardTsnChunk& chunk) { } tcb_->data_tracker().HandleForwardTsn(chunk.new_cumulative_tsn()); tcb_->reassembly_queue().Handle(chunk); + // A forward TSN - for ordered streams - may allow messages to be // delivered. - DeliverReassembledMessages(); - - // Processing a FORWARD_TSN might result in sending a SACK. - tcb_->MaybeSendSack(); + MaybeResetStreamsDeferredAndDeliverMessages(); } void DcSctpSocket::MaybeSendShutdownOrAck() { diff --git a/net/dcsctp/socket/dcsctp_socket.h b/net/dcsctp/socket/dcsctp_socket.h index 157c515d65..4f7d1787a5 100644 --- a/net/dcsctp/socket/dcsctp_socket.h +++ b/net/dcsctp/socket/dcsctp_socket.h @@ -179,8 +179,10 @@ class DcSctpSocket : public DcSctpSocketInterface { // Parses `payload`, which is a serialized packet that is just going to be // sent and prints all chunks. void DebugPrintOutgoing(rtc::ArrayView payload); - // Called whenever there may be reassembled messages, and delivers those. - void DeliverReassembledMessages(); + // Called whenever data has been received, or the cumulative acknowledgment + // TSN has moved, that may result in performing deferred stream resetting and + // delivering messages. + void MaybeResetStreamsDeferredAndDeliverMessages(); // Returns true if there is a TCB, and false otherwise (and reports an error). bool ValidateHasTCB(); diff --git a/net/dcsctp/socket/dcsctp_socket_test.cc b/net/dcsctp/socket/dcsctp_socket_test.cc index 4d8fc8ae86..13202846ac 100644 --- a/net/dcsctp/socket/dcsctp_socket_test.cc +++ b/net/dcsctp/socket/dcsctp_socket_test.cc @@ -9,6 +9,7 @@ */ #include "net/dcsctp/socket/dcsctp_socket.h" +#include #include #include #include @@ -30,6 +31,7 @@ #include "net/dcsctp/packet/chunk/data_chunk.h" #include "net/dcsctp/packet/chunk/data_common.h" #include "net/dcsctp/packet/chunk/error_chunk.h" +#include "net/dcsctp/packet/chunk/forward_tsn_chunk.h" #include "net/dcsctp/packet/chunk/heartbeat_ack_chunk.h" #include "net/dcsctp/packet/chunk/heartbeat_request_chunk.h" #include "net/dcsctp/packet/chunk/idata_chunk.h" @@ -275,6 +277,26 @@ void AdvanceTime(SocketUnderTest& a, SocketUnderTest& z, DurationMs duration) { RunTimers(z); } +// Exchanges messages between `a` and `z`, advancing time until there are no +// more pending timers, or until `max_timeout` is reached. +void ExchangeMessagesAndAdvanceTime( + SocketUnderTest& a, + SocketUnderTest& z, + DurationMs max_timeout = DurationMs(10000)) { + TimeMs time_started = a.cb.TimeMillis(); + while (a.cb.TimeMillis() - time_started < max_timeout) { + ExchangeMessages(a, z); + + DurationMs time_to_next_timeout = + std::min(a.cb.GetTimeToNextTimeout(), z.cb.GetTimeToNextTimeout()); + if (time_to_next_timeout == DurationMs::InfiniteDuration()) { + // No more pending timer. + return; + } + AdvanceTime(a, z, time_to_next_timeout); + } +} + // Calls Connect() on `sock_a_` and make the connection established. void ConnectSockets(SocketUnderTest& a, SocketUnderTest& z) { EXPECT_CALL(a.cb, OnConnected).Times(1); @@ -2977,5 +2999,60 @@ TEST_P(DcSctpSocketParametrizedTest, AllPacketsAfterConnectHaveZeroChecksum) { MaybeHandoverSocketAndSendMessage(a, std::move(z)); } + +TEST(DcSctpSocketTest, HandlesForwardTsnOutOfOrderWithStreamResetting) { + // This test ensures that receiving FORWARD-TSN and RECONFIG out of order is + // handled correctly. + SocketUnderTest a("A", {.heartbeat_interval = DurationMs(0)}); + SocketUnderTest z("Z", {.heartbeat_interval = DurationMs(0)}); + + ConnectSockets(a, z); + std::vector payload(kSmallMessageSize); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(51), payload), + { + .max_retransmissions = 0, + }); + + // Packet is lost. + EXPECT_THAT(a.cb.ConsumeSentPacket(), + HasChunks(ElementsAre( + IsDataChunk(AllOf(Property(&DataChunk::ssn, SSN(0)), + Property(&DataChunk::ppid, PPID(51))))))); + AdvanceTime(a, z, a.options.rto_initial); + + auto fwd_tsn_packet = a.cb.ConsumeSentPacket(); + EXPECT_THAT(fwd_tsn_packet, + HasChunks(ElementsAre(IsChunkType(ForwardTsnChunk::kType)))); + // Reset stream 1 + a.socket.ResetStreams(std::vector({StreamID(1)})); + auto reconfig_packet = a.cb.ConsumeSentPacket(); + EXPECT_THAT(reconfig_packet, + HasChunks(ElementsAre(IsChunkType(ReConfigChunk::kType)))); + + // These two packets are received in the wrong order. + z.socket.ReceivePacket(reconfig_packet); + z.socket.ReceivePacket(fwd_tsn_packet); + ExchangeMessagesAndAdvanceTime(a, z); + + a.socket.Send(DcSctpMessage(StreamID(1), PPID(52), payload), {}); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), {}); + + auto data_packet_2 = a.cb.ConsumeSentPacket(); + auto data_packet_3 = a.cb.ConsumeSentPacket(); + EXPECT_THAT(data_packet_2, HasChunks(ElementsAre(IsDataChunk(AllOf( + Property(&DataChunk::ssn, SSN(0)), + Property(&DataChunk::ppid, PPID(52))))))); + EXPECT_THAT(data_packet_3, HasChunks(ElementsAre(IsDataChunk(AllOf( + Property(&DataChunk::ssn, SSN(1)), + Property(&DataChunk::ppid, PPID(53))))))); + + z.socket.ReceivePacket(data_packet_2); + z.socket.ReceivePacket(data_packet_3); + ASSERT_THAT(z.cb.ConsumeReceivedMessage(), + testing::Optional(Property(&DcSctpMessage::ppid, PPID(52)))); + ASSERT_THAT(z.cb.ConsumeReceivedMessage(), + testing::Optional(Property(&DcSctpMessage::ppid, PPID(53)))); +} + } // namespace } // namespace dcsctp diff --git a/net/dcsctp/socket/mock_dcsctp_socket_callbacks.h b/net/dcsctp/socket/mock_dcsctp_socket_callbacks.h index 8b2a772fa3..150c1b9fa5 100644 --- a/net/dcsctp/socket/mock_dcsctp_socket_callbacks.h +++ b/net/dcsctp/socket/mock_dcsctp_socket_callbacks.h @@ -166,6 +166,10 @@ class MockDcSctpSocketCallbacks : public DcSctpSocketCallbacks { return timeout_manager_.GetNextExpiredTimeout(); } + DurationMs GetTimeToNextTimeout() const { + return timeout_manager_.GetTimeToNextTimeout(); + } + private: const std::string log_prefix_; TimeMs now_ = TimeMs(0); diff --git a/net/dcsctp/timer/fake_timeout.h b/net/dcsctp/timer/fake_timeout.h index 74ffe5af29..4621b2ce83 100644 --- a/net/dcsctp/timer/fake_timeout.h +++ b/net/dcsctp/timer/fake_timeout.h @@ -20,6 +20,7 @@ #include "absl/types/optional.h" #include "api/task_queue/task_queue_base.h" #include "net/dcsctp/public/timeout.h" +#include "net/dcsctp/public/types.h" #include "rtc_base/checks.h" #include "rtc_base/containers/flat_set.h" @@ -53,6 +54,7 @@ class FakeTimeout : public Timeout { } TimeoutID timeout_id() const { return timeout_id_; } + TimeMs expiry() const { return expiry_; } private: const std::function get_time_; @@ -97,6 +99,19 @@ class FakeTimeoutManager { return absl::nullopt; } + DurationMs GetTimeToNextTimeout() const { + TimeMs next_expiry = TimeMs::InfiniteFuture(); + for (const FakeTimeout* timer : timers_) { + if (timer->expiry() < next_expiry) { + next_expiry = timer->expiry(); + } + } + TimeMs now = get_time_(); + return next_expiry != TimeMs::InfiniteFuture() && next_expiry >= now + ? next_expiry - now + : DurationMs::InfiniteDuration(); + } + private: const std::function get_time_; webrtc::flat_set timers_;