diff --git a/net/dcsctp/public/types.h b/net/dcsctp/public/types.h index d0725620d8..7d69875d1a 100644 --- a/net/dcsctp/public/types.h +++ b/net/dcsctp/public/types.h @@ -41,6 +41,9 @@ class DurationMs : public webrtc::StrongAlias { constexpr explicit DurationMs(const UnderlyingType& v) : webrtc::StrongAlias(v) {} + static constexpr DurationMs InfiniteDuration() { + return DurationMs(std::numeric_limits::max()); + } // Convenience methods for working with time. constexpr DurationMs& operator+=(DurationMs d) { value_ += d.value_; diff --git a/net/dcsctp/socket/dcsctp_socket.cc b/net/dcsctp/socket/dcsctp_socket.cc index a6845e3a90..6101007896 100644 --- a/net/dcsctp/socket/dcsctp_socket.cc +++ b/net/dcsctp/socket/dcsctp_socket.cc @@ -1102,9 +1102,7 @@ void DcSctpSocket::HandleDataCommon(AnyDataChunk& chunk) { if (tcb_->data_tracker().Observe(tsn, immediate_ack)) { tcb_->reassembly_queue().Add(tsn, std::move(data)); - tcb_->reassembly_queue().MaybeResetStreamsDeferred( - tcb_->data_tracker().last_cumulative_acked_tsn()); - DeliverReassembledMessages(); + MaybeResetStreamsDeferredAndDeliverMessages(); } } @@ -1455,12 +1453,15 @@ void DcSctpSocket::HandleCookieAck( callbacks_.OnConnected(); } -void DcSctpSocket::DeliverReassembledMessages() { - if (tcb_->reassembly_queue().HasMessages()) { - for (auto& message : tcb_->reassembly_queue().FlushMessages()) { - ++metrics_.rx_messages_count; - callbacks_.OnMessageReceived(std::move(message)); - } +void DcSctpSocket::MaybeResetStreamsDeferredAndDeliverMessages() { + // As new data has been received, see if paused streams can be resumed, which + // results in even more data added to the reassembly queue. + tcb_->reassembly_queue().MaybeResetStreamsDeferred( + tcb_->data_tracker().last_cumulative_acked_tsn()); + + for (auto& message : tcb_->reassembly_queue().FlushMessages()) { + ++metrics_.rx_messages_count; + callbacks_.OnMessageReceived(std::move(message)); } } @@ -1710,12 +1711,10 @@ void DcSctpSocket::HandleForwardTsnCommon(const AnyForwardTsnChunk& chunk) { } tcb_->data_tracker().HandleForwardTsn(chunk.new_cumulative_tsn()); tcb_->reassembly_queue().Handle(chunk); + // A forward TSN - for ordered streams - may allow messages to be // delivered. - DeliverReassembledMessages(); - - // Processing a FORWARD_TSN might result in sending a SACK. - tcb_->MaybeSendSack(); + MaybeResetStreamsDeferredAndDeliverMessages(); } void DcSctpSocket::MaybeSendShutdownOrAck() { diff --git a/net/dcsctp/socket/dcsctp_socket.h b/net/dcsctp/socket/dcsctp_socket.h index 157c515d65..4f7d1787a5 100644 --- a/net/dcsctp/socket/dcsctp_socket.h +++ b/net/dcsctp/socket/dcsctp_socket.h @@ -179,8 +179,10 @@ class DcSctpSocket : public DcSctpSocketInterface { // Parses `payload`, which is a serialized packet that is just going to be // sent and prints all chunks. void DebugPrintOutgoing(rtc::ArrayView payload); - // Called whenever there may be reassembled messages, and delivers those. - void DeliverReassembledMessages(); + // Called whenever data has been received, or the cumulative acknowledgment + // TSN has moved, that may result in performing deferred stream resetting and + // delivering messages. + void MaybeResetStreamsDeferredAndDeliverMessages(); // Returns true if there is a TCB, and false otherwise (and reports an error). bool ValidateHasTCB(); diff --git a/net/dcsctp/socket/dcsctp_socket_test.cc b/net/dcsctp/socket/dcsctp_socket_test.cc index 4d8fc8ae86..13202846ac 100644 --- a/net/dcsctp/socket/dcsctp_socket_test.cc +++ b/net/dcsctp/socket/dcsctp_socket_test.cc @@ -9,6 +9,7 @@ */ #include "net/dcsctp/socket/dcsctp_socket.h" +#include #include #include #include @@ -30,6 +31,7 @@ #include "net/dcsctp/packet/chunk/data_chunk.h" #include "net/dcsctp/packet/chunk/data_common.h" #include "net/dcsctp/packet/chunk/error_chunk.h" +#include "net/dcsctp/packet/chunk/forward_tsn_chunk.h" #include "net/dcsctp/packet/chunk/heartbeat_ack_chunk.h" #include "net/dcsctp/packet/chunk/heartbeat_request_chunk.h" #include "net/dcsctp/packet/chunk/idata_chunk.h" @@ -275,6 +277,26 @@ void AdvanceTime(SocketUnderTest& a, SocketUnderTest& z, DurationMs duration) { RunTimers(z); } +// Exchanges messages between `a` and `z`, advancing time until there are no +// more pending timers, or until `max_timeout` is reached. +void ExchangeMessagesAndAdvanceTime( + SocketUnderTest& a, + SocketUnderTest& z, + DurationMs max_timeout = DurationMs(10000)) { + TimeMs time_started = a.cb.TimeMillis(); + while (a.cb.TimeMillis() - time_started < max_timeout) { + ExchangeMessages(a, z); + + DurationMs time_to_next_timeout = + std::min(a.cb.GetTimeToNextTimeout(), z.cb.GetTimeToNextTimeout()); + if (time_to_next_timeout == DurationMs::InfiniteDuration()) { + // No more pending timer. + return; + } + AdvanceTime(a, z, time_to_next_timeout); + } +} + // Calls Connect() on `sock_a_` and make the connection established. void ConnectSockets(SocketUnderTest& a, SocketUnderTest& z) { EXPECT_CALL(a.cb, OnConnected).Times(1); @@ -2977,5 +2999,60 @@ TEST_P(DcSctpSocketParametrizedTest, AllPacketsAfterConnectHaveZeroChecksum) { MaybeHandoverSocketAndSendMessage(a, std::move(z)); } + +TEST(DcSctpSocketTest, HandlesForwardTsnOutOfOrderWithStreamResetting) { + // This test ensures that receiving FORWARD-TSN and RECONFIG out of order is + // handled correctly. + SocketUnderTest a("A", {.heartbeat_interval = DurationMs(0)}); + SocketUnderTest z("Z", {.heartbeat_interval = DurationMs(0)}); + + ConnectSockets(a, z); + std::vector payload(kSmallMessageSize); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(51), payload), + { + .max_retransmissions = 0, + }); + + // Packet is lost. + EXPECT_THAT(a.cb.ConsumeSentPacket(), + HasChunks(ElementsAre( + IsDataChunk(AllOf(Property(&DataChunk::ssn, SSN(0)), + Property(&DataChunk::ppid, PPID(51))))))); + AdvanceTime(a, z, a.options.rto_initial); + + auto fwd_tsn_packet = a.cb.ConsumeSentPacket(); + EXPECT_THAT(fwd_tsn_packet, + HasChunks(ElementsAre(IsChunkType(ForwardTsnChunk::kType)))); + // Reset stream 1 + a.socket.ResetStreams(std::vector({StreamID(1)})); + auto reconfig_packet = a.cb.ConsumeSentPacket(); + EXPECT_THAT(reconfig_packet, + HasChunks(ElementsAre(IsChunkType(ReConfigChunk::kType)))); + + // These two packets are received in the wrong order. + z.socket.ReceivePacket(reconfig_packet); + z.socket.ReceivePacket(fwd_tsn_packet); + ExchangeMessagesAndAdvanceTime(a, z); + + a.socket.Send(DcSctpMessage(StreamID(1), PPID(52), payload), {}); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), {}); + + auto data_packet_2 = a.cb.ConsumeSentPacket(); + auto data_packet_3 = a.cb.ConsumeSentPacket(); + EXPECT_THAT(data_packet_2, HasChunks(ElementsAre(IsDataChunk(AllOf( + Property(&DataChunk::ssn, SSN(0)), + Property(&DataChunk::ppid, PPID(52))))))); + EXPECT_THAT(data_packet_3, HasChunks(ElementsAre(IsDataChunk(AllOf( + Property(&DataChunk::ssn, SSN(1)), + Property(&DataChunk::ppid, PPID(53))))))); + + z.socket.ReceivePacket(data_packet_2); + z.socket.ReceivePacket(data_packet_3); + ASSERT_THAT(z.cb.ConsumeReceivedMessage(), + testing::Optional(Property(&DcSctpMessage::ppid, PPID(52)))); + ASSERT_THAT(z.cb.ConsumeReceivedMessage(), + testing::Optional(Property(&DcSctpMessage::ppid, PPID(53)))); +} + } // namespace } // namespace dcsctp diff --git a/net/dcsctp/socket/mock_dcsctp_socket_callbacks.h b/net/dcsctp/socket/mock_dcsctp_socket_callbacks.h index 8b2a772fa3..150c1b9fa5 100644 --- a/net/dcsctp/socket/mock_dcsctp_socket_callbacks.h +++ b/net/dcsctp/socket/mock_dcsctp_socket_callbacks.h @@ -166,6 +166,10 @@ class MockDcSctpSocketCallbacks : public DcSctpSocketCallbacks { return timeout_manager_.GetNextExpiredTimeout(); } + DurationMs GetTimeToNextTimeout() const { + return timeout_manager_.GetTimeToNextTimeout(); + } + private: const std::string log_prefix_; TimeMs now_ = TimeMs(0); diff --git a/net/dcsctp/timer/fake_timeout.h b/net/dcsctp/timer/fake_timeout.h index 74ffe5af29..4621b2ce83 100644 --- a/net/dcsctp/timer/fake_timeout.h +++ b/net/dcsctp/timer/fake_timeout.h @@ -20,6 +20,7 @@ #include "absl/types/optional.h" #include "api/task_queue/task_queue_base.h" #include "net/dcsctp/public/timeout.h" +#include "net/dcsctp/public/types.h" #include "rtc_base/checks.h" #include "rtc_base/containers/flat_set.h" @@ -53,6 +54,7 @@ class FakeTimeout : public Timeout { } TimeoutID timeout_id() const { return timeout_id_; } + TimeMs expiry() const { return expiry_; } private: const std::function get_time_; @@ -97,6 +99,19 @@ class FakeTimeoutManager { return absl::nullopt; } + DurationMs GetTimeToNextTimeout() const { + TimeMs next_expiry = TimeMs::InfiniteFuture(); + for (const FakeTimeout* timer : timers_) { + if (timer->expiry() < next_expiry) { + next_expiry = timer->expiry(); + } + } + TimeMs now = get_time_(); + return next_expiry != TimeMs::InfiniteFuture() && next_expiry >= now + ? next_expiry - now + : DurationMs::InfiniteDuration(); + } + private: const std::function get_time_; webrtc::flat_set timers_;