diff --git a/net/dcsctp/socket/dcsctp_socket.cc b/net/dcsctp/socket/dcsctp_socket.cc index bcff5bef8d..6ddd55933d 100644 --- a/net/dcsctp/socket/dcsctp_socket.cc +++ b/net/dcsctp/socket/dcsctp_socket.cc @@ -368,9 +368,10 @@ SendStatus DcSctpSocket::Send(DcSctpMessage message, return SendStatus::kErrorResourceExhaustion; } - send_queue_.Add(callbacks_.TimeMillis(), std::move(message), send_options); + TimeMs now = callbacks_.TimeMillis(); + send_queue_.Add(now, std::move(message), send_options); if (tcb_ != nullptr) { - tcb_->SendBufferedPackets(); + tcb_->SendBufferedPackets(now); } RTC_DCHECK(IsConsistent()); @@ -1023,6 +1024,7 @@ void DcSctpSocket::HandleInit(const CommonHeader& header, void DcSctpSocket::SendCookieEcho() { RTC_DCHECK(tcb_ != nullptr); + TimeMs now = callbacks_.TimeMillis(); SctpPacket::Builder b = tcb_->PacketBuilder(); b.Add(*cookie_echo_chunk_); @@ -1030,7 +1032,7 @@ void DcSctpSocket::SendCookieEcho() { // "The COOKIE ECHO chunk can be bundled with any pending outbound DATA // chunks, but it MUST be the first chunk in the packet and until the COOKIE // ACK is returned the sender MUST NOT send any other packets to the peer." - tcb_->SendBufferedPackets(b, /*only_one_packet=*/true); + tcb_->SendBufferedPackets(b, now, /*only_one_packet=*/true); } void DcSctpSocket::HandleInitAck( @@ -1143,7 +1145,7 @@ void DcSctpSocket::HandleCookieEcho( // "A COOKIE ACK chunk may be bundled with any pending DATA chunks (and/or // SACK chunks), but the COOKIE ACK chunk MUST be the first chunk in the // packet." - tcb_->SendBufferedPackets(b); + tcb_->SendBufferedPackets(b, callbacks_.TimeMillis()); } bool DcSctpSocket::HandleCookieEchoWithTCB(const CommonHeader& header, @@ -1244,7 +1246,7 @@ void DcSctpSocket::HandleCookieAck( t1_cookie_->Stop(); cookie_echo_chunk_ = absl::nullopt; SetState(State::kEstablished, "COOKIE_ACK received"); - tcb_->SendBufferedPackets(); + tcb_->SendBufferedPackets(callbacks_.TimeMillis()); callbacks_.OnConnected(); } @@ -1261,14 +1263,14 @@ void DcSctpSocket::HandleSack(const CommonHeader& header, absl::optional chunk = SackChunk::Parse(descriptor.data); if (ValidateParseSuccess(chunk) && ValidateHasTCB()) { + TimeMs now = callbacks_.TimeMillis(); SackChunk sack = ChunkValidators::Clean(*std::move(chunk)); - if (tcb_->retransmission_queue().HandleSack(callbacks_.TimeMillis(), - sack)) { + if (tcb_->retransmission_queue().HandleSack(now, sack)) { MaybeSendShutdownOrAck(); // Receiving an ACK will decrease outstanding bytes (maybe now below // cwnd?) or indicate packet loss that may result in sending FORWARD-TSN. - tcb_->SendBufferedPackets(); + tcb_->SendBufferedPackets(now); } else { RTC_DLOG(LS_VERBOSE) << log_prefix() << "Dropping out-of-order SACK with TSN " diff --git a/net/dcsctp/socket/dcsctp_socket_test.cc b/net/dcsctp/socket/dcsctp_socket_test.cc index b662b3e196..a3ddc7f85e 100644 --- a/net/dcsctp/socket/dcsctp_socket_test.cc +++ b/net/dcsctp/socket/dcsctp_socket_test.cc @@ -40,6 +40,7 @@ #include "net/dcsctp/public/dcsctp_message.h" #include "net/dcsctp/public/dcsctp_options.h" #include "net/dcsctp/public/dcsctp_socket.h" +#include "net/dcsctp/public/types.h" #include "net/dcsctp/rx/reassembly_queue.h" #include "net/dcsctp/socket/mock_dcsctp_socket_callbacks.h" #include "net/dcsctp/testing/testing_macros.h" @@ -56,6 +57,7 @@ using ::testing::IsEmpty; using ::testing::SizeIs; constexpr SendOptions kSendOptions; +constexpr size_t kLargeMessageSize = DcSctpOptions::kMaxSafeMTUSize * 20; MATCHER_P(HasDataChunkWithSsn, ssn, "") { absl::optional packet = SctpPacket::Parse(arg); @@ -592,7 +594,7 @@ TEST_F(DcSctpSocketTest, TimeoutResendsPacket) { TEST_F(DcSctpSocketTest, SendALotOfBytesMissedSecondPacket) { ConnectSockets(); - std::vector payload(options_.mtu * 10); + std::vector payload(kLargeMessageSize); sock_a_.Send(DcSctpMessage(StreamID(1), PPID(53), payload), kSendOptions); // First DATA @@ -832,7 +834,7 @@ TEST_F(DcSctpSocketTest, OnePeerReconnects) { EXPECT_CALL(cb_a_, OnConnectionRestarted).Times(1); // Let's be evil here - reconnect while a fragmented packet was about to be // sent. The receiving side should get it in full. - std::vector payload(options_.mtu * 10); + std::vector payload(kLargeMessageSize); sock_a_.Send(DcSctpMessage(StreamID(1), PPID(53), payload), kSendOptions); // First DATA @@ -1068,5 +1070,97 @@ TEST_F(DcSctpSocketTest, SetMaxMessageSize) { EXPECT_EQ(sock_a_.options().max_message_size, 42u); } +TEST_F(DcSctpSocketTest, SendsMessagesWithLowLifetime) { + ConnectSockets(); + + // Mock that the time always goes forward. + TimeMs now(0); + EXPECT_CALL(cb_a_, TimeMillis).WillRepeatedly([&]() { + now += DurationMs(3); + return now; + }); + EXPECT_CALL(cb_z_, TimeMillis).WillRepeatedly([&]() { + now += DurationMs(3); + return now; + }); + + // Queue a few small messages with low lifetime, both ordered and unordered, + // and validate that all are delivered. + static constexpr int kIterations = 100; + for (int i = 0; i < kIterations; ++i) { + SendOptions send_options; + send_options.unordered = IsUnordered((i % 2) == 0); + send_options.lifetime = DurationMs(i % 3); // 0, 1, 2 ms + + sock_a_.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), send_options); + } + + ExchangeMessages(sock_a_, cb_a_, sock_z_, cb_z_); + + for (int i = 0; i < kIterations; ++i) { + EXPECT_TRUE(cb_z_.ConsumeReceivedMessage().has_value()); + } + + EXPECT_FALSE(cb_z_.ConsumeReceivedMessage().has_value()); + + // Validate that the sockets really make the time move forward. + EXPECT_GE(*now, kIterations * 2); +} + +TEST_F(DcSctpSocketTest, DiscardsMessagesWithLowLifetimeIfMustBuffer) { + ConnectSockets(); + + SendOptions lifetime_0; + lifetime_0.unordered = IsUnordered(true); + lifetime_0.lifetime = DurationMs(0); + + SendOptions lifetime_1; + lifetime_1.unordered = IsUnordered(true); + lifetime_1.lifetime = DurationMs(1); + + // Mock that the time always goes forward. + TimeMs now(0); + EXPECT_CALL(cb_a_, TimeMillis).WillRepeatedly([&]() { + now += DurationMs(3); + return now; + }); + EXPECT_CALL(cb_z_, TimeMillis).WillRepeatedly([&]() { + now += DurationMs(3); + return now; + }); + + // Fill up the send buffer with a large message. + std::vector payload(kLargeMessageSize); + sock_a_.Send(DcSctpMessage(StreamID(1), PPID(53), payload), kSendOptions); + + // And queue a few small messages with lifetime=0 or 1 ms - can't be sent. + sock_a_.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2, 3}), lifetime_0); + sock_a_.Send(DcSctpMessage(StreamID(1), PPID(53), {4, 5, 6}), lifetime_1); + sock_a_.Send(DcSctpMessage(StreamID(1), PPID(53), {7, 8, 9}), lifetime_0); + + // Handle all that was sent until congestion window got full. + for (;;) { + std::vector packet_from_a = cb_a_.ConsumeSentPacket(); + if (packet_from_a.empty()) { + break; + } + sock_z_.ReceivePacket(std::move(packet_from_a)); + } + + // Shouldn't be enough to send that large message. + EXPECT_FALSE(cb_z_.ConsumeReceivedMessage().has_value()); + + // Exchange the rest of the messages, with the time ever increasing. + ExchangeMessages(sock_a_, cb_a_, sock_z_, cb_z_); + + // The large message should be delivered. It was sent reliably. + ASSERT_HAS_VALUE_AND_ASSIGN(DcSctpMessage m1, cb_z_.ConsumeReceivedMessage()); + EXPECT_EQ(m1.stream_id(), StreamID(1)); + EXPECT_THAT(m1.payload(), SizeIs(kLargeMessageSize)); + + // But none of the smaller messages. + EXPECT_FALSE(cb_z_.ConsumeReceivedMessage().has_value()); +} + } // namespace } // namespace dcsctp diff --git a/net/dcsctp/socket/mock_dcsctp_socket_callbacks.h b/net/dcsctp/socket/mock_dcsctp_socket_callbacks.h index 9d0bd53372..799f85c274 100644 --- a/net/dcsctp/socket/mock_dcsctp_socket_callbacks.h +++ b/net/dcsctp/socket/mock_dcsctp_socket_callbacks.h @@ -78,6 +78,7 @@ class MockDcSctpSocketCallbacks : public DcSctpSocketCallbacks { << log_prefix_ << "Socket abort: " << ToString(error) << "; " << message; }); + ON_CALL(*this, TimeMillis).WillByDefault([this]() { return now_; }); } MOCK_METHOD(void, SendPacket, @@ -88,8 +89,7 @@ class MockDcSctpSocketCallbacks : public DcSctpSocketCallbacks { return timeout_manager_.CreateTimeout(); } - TimeMs TimeMillis() override { return now_; } - + MOCK_METHOD(TimeMs, TimeMillis, (), (override)); uint32_t GetRandomInt(uint32_t low, uint32_t high) override { return random_.Rand(low, high); } diff --git a/net/dcsctp/socket/transmission_control_block.cc b/net/dcsctp/socket/transmission_control_block.cc index 09d021d820..6e0be6a316 100644 --- a/net/dcsctp/socket/transmission_control_block.cc +++ b/net/dcsctp/socket/transmission_control_block.cc @@ -51,11 +51,12 @@ void TransmissionControlBlock::ObserveRTT(DurationMs rtt) { } absl::optional TransmissionControlBlock::OnRtxTimerExpiry() { + TimeMs now = callbacks_.TimeMillis(); RTC_DLOG(LS_INFO) << log_prefix_ << "Timer " << t3_rtx_->name() << " has expired"; if (IncrementTxErrorCounter("t3-rtx expired")) { retransmission_queue_.HandleT3RtxTimerExpiry(); - SendBufferedPackets(); + SendBufferedPackets(now); } return absl::nullopt; } @@ -76,8 +77,8 @@ void TransmissionControlBlock::MaybeSendSack() { } void TransmissionControlBlock::SendBufferedPackets(SctpPacket::Builder& builder, + TimeMs now, bool only_one_packet) { - TimeMs now = callbacks_.TimeMillis(); for (int packet_idx = 0;; ++packet_idx) { // Only add control chunks to the first packet that is sent, if sending // multiple packets in one go (as allowed by the congestion window). diff --git a/net/dcsctp/socket/transmission_control_block.h b/net/dcsctp/socket/transmission_control_block.h index 0ab1a1ac3a..2f1c9ada6c 100644 --- a/net/dcsctp/socket/transmission_control_block.h +++ b/net/dcsctp/socket/transmission_control_block.h @@ -151,13 +151,14 @@ class TransmissionControlBlock : public Context { // only a single packet will be sent. Otherwise, zero, one or multiple may be // sent. void SendBufferedPackets(SctpPacket::Builder& builder, + TimeMs now, bool only_one_packet = false); // As above, but without passing in a builder and allowing sending many // packets. - void SendBufferedPackets() { + void SendBufferedPackets(TimeMs now) { SctpPacket::Builder builder(peer_verification_tag_, options_); - SendBufferedPackets(builder, /*only_one_packet=*/false); + SendBufferedPackets(builder, now, /*only_one_packet=*/false); } // Returns a textual representation of this object, for logging. diff --git a/net/dcsctp/tx/fcfs_send_queue.cc b/net/dcsctp/tx/fcfs_send_queue.cc index eae90e09f9..f2dc5e40f8 100644 --- a/net/dcsctp/tx/fcfs_send_queue.cc +++ b/net/dcsctp/tx/fcfs_send_queue.cc @@ -36,7 +36,10 @@ void FCFSSendQueue::Add(TimeMs now, // has been added to the queue. absl::optional expires_at = absl::nullopt; if (send_options.lifetime.has_value()) { - expires_at = now + *send_options.lifetime; + // `expires_at` is the time when it expires. Which is slightly larger than + // the message's lifetime, as the message is alive during its entire + // lifetime (which may be zero). + expires_at = now + *send_options.lifetime + DurationMs(1); } queue.emplace_back(std::move(message), expires_at, send_options); } diff --git a/net/dcsctp/tx/fcfs_send_queue_test.cc b/net/dcsctp/tx/fcfs_send_queue_test.cc index ec28b41b25..a67a0a1a9c 100644 --- a/net/dcsctp/tx/fcfs_send_queue_test.cc +++ b/net/dcsctp/tx/fcfs_send_queue_test.cc @@ -191,7 +191,7 @@ TEST_F(FCFSSendQueueTest, ProduceWithLifetimeExpiry) { // Default is no expiry TimeMs now = kNow; buf_.Add(now, DcSctpMessage(kStreamID, kPPID, payload)); - now = now + DurationMs(1000000); + now += DurationMs(1000000); ASSERT_TRUE(buf_.Produce(now, 100)); SendOptions expires_2_seconds; @@ -199,17 +199,17 @@ TEST_F(FCFSSendQueueTest, ProduceWithLifetimeExpiry) { // Add and consume within lifetime buf_.Add(now, DcSctpMessage(kStreamID, kPPID, payload), expires_2_seconds); - now = now + DurationMs(1999); + now += DurationMs(2000); ASSERT_TRUE(buf_.Produce(now, 100)); // Add and consume just outside lifetime buf_.Add(now, DcSctpMessage(kStreamID, kPPID, payload), expires_2_seconds); - now = now + DurationMs(2000); + now += DurationMs(2001); ASSERT_FALSE(buf_.Produce(now, 100)); // A long time after expiry buf_.Add(now, DcSctpMessage(kStreamID, kPPID, payload), expires_2_seconds); - now = now + DurationMs(1000000); + now += DurationMs(1000000); ASSERT_FALSE(buf_.Produce(now, 100)); // Expire one message, but produce the second that is not expired. @@ -219,7 +219,7 @@ TEST_F(FCFSSendQueueTest, ProduceWithLifetimeExpiry) { expires_4_seconds.lifetime = DurationMs(4000); buf_.Add(now, DcSctpMessage(kStreamID, kPPID, payload), expires_4_seconds); - now = now + DurationMs(2000); + now += DurationMs(2001); ASSERT_TRUE(buf_.Produce(now, 100)); ASSERT_FALSE(buf_.Produce(now, 100));