diff --git a/net/dcsctp/socket/dcsctp_socket.cc b/net/dcsctp/socket/dcsctp_socket.cc index 822040ef5b..56abb492c0 100644 --- a/net/dcsctp/socket/dcsctp_socket.cc +++ b/net/dcsctp/socket/dcsctp_socket.cc @@ -189,6 +189,7 @@ DcSctpSocket::DcSctpSocket(absl::string_view log_prefix, send_queue_( log_prefix_, options_.max_send_buffer_size, + options_.mtu, options_.default_stream_priority, [this](StreamID stream_id) { callbacks_.OnBufferedAmountLow(stream_id); diff --git a/net/dcsctp/socket/dcsctp_socket_test.cc b/net/dcsctp/socket/dcsctp_socket_test.cc index 82fbb1b90b..e70378ffd3 100644 --- a/net/dcsctp/socket/dcsctp_socket_test.cc +++ b/net/dcsctp/socket/dcsctp_socket_test.cc @@ -371,6 +371,18 @@ std::unique_ptr HandoverSocket( return handover_socket; } +std::vector GetReceivedMessagePpids(SocketUnderTest& z) { + std::vector ppids; + for (;;) { + absl::optional msg = z.cb.ConsumeReceivedMessage(); + if (!msg.has_value()) { + break; + } + ppids.push_back(*msg->ppid()); + } + return ppids; +} + // Test parameter that controls whether to perform handovers during the test. A // test can have multiple points where it conditionally hands over socket Z. // Either socket Z will be handed over at all those points or handed over never. @@ -2403,5 +2415,110 @@ TEST(DcSctpSocketTest, ReconnectSocketWithPendingStreamReset) { ExchangeMessages(a, z); a.socket.ResetStreams(std::vector({StreamID(2)})); } + +TEST(DcSctpSocketTest, SmallSentMessagesWithPrioWillArriveInSpecificOrder) { + DcSctpOptions options = {.enable_message_interleaving = true}; + SocketUnderTest a("A", options); + SocketUnderTest z("A", options); + + a.socket.SetStreamPriority(StreamID(1), StreamPriority(700)); + a.socket.SetStreamPriority(StreamID(2), StreamPriority(200)); + a.socket.SetStreamPriority(StreamID(3), StreamPriority(100)); + + // Enqueue messages before connecting the socket, to ensure they aren't send + // as soon as Send() is called. + a.socket.Send(DcSctpMessage(StreamID(3), PPID(301), + std::vector(kSmallMessageSize)), + kSendOptions); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(101), + std::vector(kSmallMessageSize)), + kSendOptions); + a.socket.Send(DcSctpMessage(StreamID(2), PPID(201), + std::vector(kSmallMessageSize)), + kSendOptions); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(102), + std::vector(kSmallMessageSize)), + kSendOptions); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(103), + std::vector(kSmallMessageSize)), + kSendOptions); + + ConnectSockets(a, z); + ExchangeMessages(a, z); + + std::vector received_ppids; + for (;;) { + absl::optional msg = z.cb.ConsumeReceivedMessage(); + if (!msg.has_value()) { + break; + } + received_ppids.push_back(*msg->ppid()); + } + + EXPECT_THAT(received_ppids, ElementsAre(101, 102, 103, 201, 301)); +} + +TEST(DcSctpSocketTest, LargeSentMessagesWithPrioWillArriveInSpecificOrder) { + DcSctpOptions options = {.enable_message_interleaving = true}; + SocketUnderTest a("A", options); + SocketUnderTest z("A", options); + + a.socket.SetStreamPriority(StreamID(1), StreamPriority(700)); + a.socket.SetStreamPriority(StreamID(2), StreamPriority(200)); + a.socket.SetStreamPriority(StreamID(3), StreamPriority(100)); + + // Enqueue messages before connecting the socket, to ensure they aren't send + // as soon as Send() is called. + a.socket.Send(DcSctpMessage(StreamID(3), PPID(301), + std::vector(kLargeMessageSize)), + kSendOptions); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(101), + std::vector(kLargeMessageSize)), + kSendOptions); + a.socket.Send(DcSctpMessage(StreamID(2), PPID(201), + std::vector(kLargeMessageSize)), + kSendOptions); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(102), + std::vector(kLargeMessageSize)), + kSendOptions); + + ConnectSockets(a, z); + ExchangeMessages(a, z); + + EXPECT_THAT(GetReceivedMessagePpids(z), ElementsAre(101, 102, 201, 301)); +} + +TEST(DcSctpSocketTest, MessageWithHigherPrioWillInterruptLowerPrioMessage) { + DcSctpOptions options = {.enable_message_interleaving = true}; + SocketUnderTest a("A", options); + SocketUnderTest z("Z", options); + + ConnectSockets(a, z); + + a.socket.SetStreamPriority(StreamID(2), StreamPriority(128)); + a.socket.Send(DcSctpMessage(StreamID(2), PPID(201), + std::vector(kLargeMessageSize)), + kSendOptions); + + // Due to a non-zero initial congestion window, the message will already start + // to send, but will not succeed to be sent completely before filling the + // congestion window or stopping due to reaching how many packets that can be + // sent at once (max burst). The important thing is that the entire message + // doesn't get sent in full. + + // Now enqueue two messages; one small and one large higher priority message. + a.socket.SetStreamPriority(StreamID(1), StreamPriority(512)); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(101), + std::vector(kSmallMessageSize)), + kSendOptions); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(102), + std::vector(kLargeMessageSize)), + kSendOptions); + + ExchangeMessages(a, z); + + EXPECT_THAT(GetReceivedMessagePpids(z), ElementsAre(101, 102, 201)); +} + } // namespace } // namespace dcsctp diff --git a/net/dcsctp/socket/transmission_control_block.h b/net/dcsctp/socket/transmission_control_block.h index 038ad3683f..f21278845b 100644 --- a/net/dcsctp/socket/transmission_control_block.h +++ b/net/dcsctp/socket/transmission_control_block.h @@ -129,6 +129,7 @@ class TransmissionControlBlock : public Context { if (handover_state == nullptr) { send_queue.Reset(); } + send_queue.EnableMessageInterleaving(capabilities.message_interleaving); } // Implementation of `Context`. diff --git a/net/dcsctp/tx/mock_send_queue.h b/net/dcsctp/tx/mock_send_queue.h index 82e96b7084..0c8f5d141d 100644 --- a/net/dcsctp/tx/mock_send_queue.h +++ b/net/dcsctp/tx/mock_send_queue.h @@ -52,6 +52,7 @@ class MockSendQueue : public SendQueue { SetBufferedAmountLowThreshold, (StreamID stream_id, size_t bytes), (override)); + MOCK_METHOD(void, EnableMessageInterleaving, (bool enabled), (override)); }; } // namespace dcsctp diff --git a/net/dcsctp/tx/rr_send_queue.cc b/net/dcsctp/tx/rr_send_queue.cc index bec6f08def..174d19b77c 100644 --- a/net/dcsctp/tx/rr_send_queue.cc +++ b/net/dcsctp/tx/rr_send_queue.cc @@ -32,6 +32,7 @@ namespace dcsctp { RRSendQueue::RRSendQueue(absl::string_view log_prefix, size_t buffer_size, + size_t mtu, StreamPriority default_priority, std::function on_buffered_amount_low, size_t total_buffered_amount_low_threshold, @@ -39,8 +40,7 @@ RRSendQueue::RRSendQueue(absl::string_view log_prefix, : log_prefix_(std::string(log_prefix) + "fcfs: "), buffer_size_(buffer_size), default_priority_(default_priority), - // TODO(webrtc:5696): Provide correct MTU. - scheduler_(DcSctpOptions::kMaxSafeMTUSize), + scheduler_(mtu), on_buffered_amount_low_(std::move(on_buffered_amount_low)), total_buffered_amount_(std::move(on_total_buffered_amount_low)) { total_buffered_amount_.SetLowThreshold(total_buffered_amount_low_threshold); diff --git a/net/dcsctp/tx/rr_send_queue.h b/net/dcsctp/tx/rr_send_queue.h index c2f1ee8e73..49c36feab5 100644 --- a/net/dcsctp/tx/rr_send_queue.h +++ b/net/dcsctp/tx/rr_send_queue.h @@ -45,6 +45,7 @@ class RRSendQueue : public SendQueue { public: RRSendQueue(absl::string_view log_prefix, size_t buffer_size, + size_t mtu, StreamPriority default_priority, std::function on_buffered_amount_low, size_t total_buffered_amount_low_threshold, @@ -81,6 +82,9 @@ class RRSendQueue : public SendQueue { } size_t buffered_amount_low_threshold(StreamID stream_id) const override; void SetBufferedAmountLowThreshold(StreamID stream_id, size_t bytes) override; + void EnableMessageInterleaving(bool enabled) override { + scheduler_.EnableMessageInterleaving(enabled); + } void SetStreamPriority(StreamID stream_id, StreamPriority priority); StreamPriority GetStreamPriority(StreamID stream_id) const; diff --git a/net/dcsctp/tx/rr_send_queue_test.cc b/net/dcsctp/tx/rr_send_queue_test.cc index 3966c17e58..7471cccad5 100644 --- a/net/dcsctp/tx/rr_send_queue_test.cc +++ b/net/dcsctp/tx/rr_send_queue_test.cc @@ -36,12 +36,14 @@ constexpr StreamPriority kDefaultPriority(10); constexpr size_t kBufferedAmountLowThreshold = 500; constexpr size_t kOneFragmentPacketSize = 100; constexpr size_t kTwoFragmentPacketSize = 101; +constexpr size_t kMtu = 1100; class RRSendQueueTest : public testing::Test { protected: RRSendQueueTest() : buf_("log: ", kMaxQueueSize, + kMtu, kDefaultPriority, on_buffered_amount_low_.AsStdFunction(), kBufferedAmountLowThreshold, @@ -787,7 +789,7 @@ TEST_F(RRSendQueueTest, WillHandoverPriority) { DcSctpSocketHandoverState state; buf_.AddHandoverState(state); - RRSendQueue q2("log: ", kMaxQueueSize, kDefaultPriority, + RRSendQueue q2("log: ", kMaxQueueSize, kMtu, kDefaultPriority, on_buffered_amount_low_.AsStdFunction(), kBufferedAmountLowThreshold, on_total_buffered_amount_low_.AsStdFunction()); @@ -795,5 +797,25 @@ TEST_F(RRSendQueueTest, WillHandoverPriority) { EXPECT_EQ(q2.GetStreamPriority(StreamID(1)), StreamPriority(42)); EXPECT_EQ(q2.GetStreamPriority(StreamID(2)), StreamPriority(42)); } + +TEST_F(RRSendQueueTest, WillSendMessagesByPrio) { + buf_.EnableMessageInterleaving(true); + buf_.SetStreamPriority(StreamID(1), StreamPriority(10)); + buf_.SetStreamPriority(StreamID(2), StreamPriority(20)); + buf_.SetStreamPriority(StreamID(3), StreamPriority(30)); + + buf_.Add(kNow, DcSctpMessage(StreamID(1), kPPID, std::vector(40))); + buf_.Add(kNow, DcSctpMessage(StreamID(2), kPPID, std::vector(20))); + buf_.Add(kNow, DcSctpMessage(StreamID(3), kPPID, std::vector(10))); + std::vector expected_streams = {3, 2, 2, 1, 1, 1, 1}; + + for (uint16_t stream_num : expected_streams) { + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk, + buf_.Produce(kNow, 10)); + EXPECT_EQ(chunk.data.stream_id, StreamID(stream_num)); + } + EXPECT_FALSE(buf_.Produce(kNow, 1).has_value()); +} + } // namespace } // namespace dcsctp diff --git a/net/dcsctp/tx/send_queue.h b/net/dcsctp/tx/send_queue.h index b2e5a9d436..a7e663530a 100644 --- a/net/dcsctp/tx/send_queue.h +++ b/net/dcsctp/tx/send_queue.h @@ -126,6 +126,12 @@ class SendQueue { // Sets a limit for the `OnBufferedAmountLow` event. virtual void SetBufferedAmountLowThreshold(StreamID stream_id, size_t bytes) = 0; + + // Configures the send queue to support interleaved message sending as + // described in RFC8260. Every send queue starts with this value set as + // disabled, but can later change it when the capabilities of the connection + // have been negotiated. This affects the behavior of the `Produce` method. + virtual void EnableMessageInterleaving(bool enabled) = 0; }; } // namespace dcsctp