diff --git a/net/dcsctp/fuzzers/dcsctp_fuzzers.cc b/net/dcsctp/fuzzers/dcsctp_fuzzers.cc index b4b6224ec4..e8fcacffa0 100644 --- a/net/dcsctp/fuzzers/dcsctp_fuzzers.cc +++ b/net/dcsctp/fuzzers/dcsctp_fuzzers.cc @@ -435,6 +435,7 @@ void FuzzSocket(DcSctpSocketInterface& socket, options.unordered = IsUnordered(flags & 0x01); options.max_retransmissions = (flags & 0x02) != 0 ? absl::make_optional(0) : absl::nullopt; + options.lifecycle_id = LifecycleId(42); size_t payload_exponent = (flags >> 2) % 16; size_t payload_size = static_cast(1) << payload_exponent; socket.Send(DcSctpMessage(StreamID(state.GetByte()), PPID(53), diff --git a/net/dcsctp/socket/dcsctp_socket.cc b/net/dcsctp/socket/dcsctp_socket.cc index aa48649fb6..53838193ec 100644 --- a/net/dcsctp/socket/dcsctp_socket.cc +++ b/net/dcsctp/socket/dcsctp_socket.cc @@ -62,6 +62,7 @@ #include "net/dcsctp/public/dcsctp_options.h" #include "net/dcsctp/public/dcsctp_socket.h" #include "net/dcsctp/public/packet_observer.h" +#include "net/dcsctp/public/types.h" #include "net/dcsctp/rx/data_tracker.h" #include "net/dcsctp/rx/reassembly_queue.h" #include "net/dcsctp/socket/callback_deferrer.h" @@ -447,13 +448,20 @@ SendStatus DcSctpSocket::Send(DcSctpMessage message, const SendOptions& send_options) { RTC_DCHECK_RUN_ON(&thread_checker_); CallbackDeferrer::ScopedDeferrer deferrer(callbacks_); + LifecycleId lifecycle_id = send_options.lifecycle_id; if (message.payload().empty()) { + if (lifecycle_id.IsSet()) { + callbacks_.OnLifecycleEnd(lifecycle_id); + } callbacks_.OnError(ErrorKind::kProtocolViolation, "Unable to send empty message"); return SendStatus::kErrorMessageEmpty; } if (message.payload().size() > options_.max_message_size) { + if (lifecycle_id.IsSet()) { + callbacks_.OnLifecycleEnd(lifecycle_id); + } callbacks_.OnError(ErrorKind::kProtocolViolation, "Unable to send too large message"); return SendStatus::kErrorMessageTooLarge; @@ -464,11 +472,17 @@ SendStatus DcSctpSocket::Send(DcSctpMessage message, // "An endpoint should reject any new data request from its upper layer // if it is in the SHUTDOWN-PENDING, SHUTDOWN-SENT, SHUTDOWN-RECEIVED, or // SHUTDOWN-ACK-SENT state." + if (lifecycle_id.IsSet()) { + callbacks_.OnLifecycleEnd(lifecycle_id); + } callbacks_.OnError(ErrorKind::kWrongSequence, "Unable to send message as the socket is shutting down"); return SendStatus::kErrorShuttingDown; } if (send_queue_.IsFull()) { + if (lifecycle_id.IsSet()) { + callbacks_.OnLifecycleEnd(lifecycle_id); + } callbacks_.OnError(ErrorKind::kResourceExhaustion, "Unable to send message as the send queue is full"); return SendStatus::kErrorResourceExhaustion; diff --git a/net/dcsctp/socket/dcsctp_socket_test.cc b/net/dcsctp/socket/dcsctp_socket_test.cc index f4b0b2c89f..d1e2d904e0 100644 --- a/net/dcsctp/socket/dcsctp_socket_test.cc +++ b/net/dcsctp/socket/dcsctp_socket_test.cc @@ -2536,5 +2536,137 @@ TEST(DcSctpSocketTest, MessageWithHigherPrioWillInterruptLowerPrioMessage) { EXPECT_THAT(GetReceivedMessagePpids(z), ElementsAre(101, 102, 201)); } +TEST(DcSctpSocketTest, LifecycleEventsAreGeneratedForAckedMessages) { + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + ConnectSockets(a, z); + + a.socket.Send(DcSctpMessage(StreamID(2), PPID(101), + std::vector(kLargeMessageSize)), + {.lifecycle_id = LifecycleId(41)}); + + a.socket.Send(DcSctpMessage(StreamID(2), PPID(102), + std::vector(kLargeMessageSize)), + kSendOptions); + + a.socket.Send(DcSctpMessage(StreamID(2), PPID(103), + std::vector(kLargeMessageSize)), + {.lifecycle_id = LifecycleId(42)}); + + EXPECT_CALL(a.cb, OnLifecycleMessageDelivered(LifecycleId(41))); + EXPECT_CALL(a.cb, OnLifecycleEnd(LifecycleId(41))); + EXPECT_CALL(a.cb, OnLifecycleMessageDelivered(LifecycleId(42))); + EXPECT_CALL(a.cb, OnLifecycleEnd(LifecycleId(42))); + ExchangeMessages(a, z); + // In case of delayed ack. + AdvanceTime(a, z, a.options.delayed_ack_max_timeout); + ExchangeMessages(a, z); + + EXPECT_THAT(GetReceivedMessagePpids(z), ElementsAre(101, 102, 103)); +} + +TEST(DcSctpSocketTest, LifecycleEventsForFailMaxRetransmissions) { + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + ConnectSockets(a, z); + + std::vector payload(a.options.mtu - 100); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(51), payload), + { + .max_retransmissions = 0, + .lifecycle_id = LifecycleId(1), + }); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(52), payload), + { + .max_retransmissions = 0, + .lifecycle_id = LifecycleId(2), + }); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), + { + .max_retransmissions = 0, + .lifecycle_id = LifecycleId(3), + }); + + // First DATA + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); + // Second DATA (lost) + a.cb.ConsumeSentPacket(); + + EXPECT_CALL(a.cb, OnLifecycleMessageDelivered(LifecycleId(1))); + EXPECT_CALL(a.cb, OnLifecycleEnd(LifecycleId(1))); + EXPECT_CALL(a.cb, OnLifecycleMessageExpired(LifecycleId(2), + /*maybe_delivered=*/true)); + EXPECT_CALL(a.cb, OnLifecycleEnd(LifecycleId(2))); + EXPECT_CALL(a.cb, OnLifecycleMessageDelivered(LifecycleId(3))); + EXPECT_CALL(a.cb, OnLifecycleEnd(LifecycleId(3))); + ExchangeMessages(a, z); + + // Handle delayed SACK. + AdvanceTime(a, z, a.options.delayed_ack_max_timeout); + ExchangeMessages(a, z); + + // The chunk is now NACKed. Let the RTO expire, to discard the message. + AdvanceTime(a, z, a.options.rto_initial); + ExchangeMessages(a, z); + + // Handle delayed SACK. + AdvanceTime(a, z, a.options.delayed_ack_max_timeout); + ExchangeMessages(a, z); + + EXPECT_THAT(GetReceivedMessagePpids(z), ElementsAre(51, 53)); +} + +TEST(DcSctpSocketTest, LifecycleEventsForExpiredMessageWithRetransmitLimit) { + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + ConnectSockets(a, z); + + // Will not be able to send it in full within the congestion window, but will + // need to wait for SACKs to be received for more fragments to be sent. + std::vector payload(kLargeMessageSize); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(51), payload), + { + .max_retransmissions = 0, + .lifecycle_id = LifecycleId(1), + }); + + // First DATA + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); + // Second DATA (lost) + a.cb.ConsumeSentPacket(); + + EXPECT_CALL(a.cb, OnLifecycleMessageExpired(LifecycleId(1), + /*maybe_delivered=*/false)); + EXPECT_CALL(a.cb, OnLifecycleEnd(LifecycleId(1))); + ExchangeMessages(a, z); + + EXPECT_THAT(GetReceivedMessagePpids(z), IsEmpty()); +} + +TEST(DcSctpSocketTest, LifecycleEventsForExpiredMessageWithLifetimeLimit) { + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + + // Send it before the socket is connected, to prevent it from being sent too + // quickly. The idea is that it should be expired before even attempting to + // send it in full. + std::vector payload(kSmallMessageSize); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(51), payload), + { + .lifetime = DurationMs(100), + .lifecycle_id = LifecycleId(1), + }); + + AdvanceTime(a, z, DurationMs(200)); + + EXPECT_CALL(a.cb, OnLifecycleMessageExpired(LifecycleId(1), + /*maybe_delivered=*/false)); + EXPECT_CALL(a.cb, OnLifecycleEnd(LifecycleId(1))); + ConnectSockets(a, z); + ExchangeMessages(a, z); + + EXPECT_THAT(GetReceivedMessagePpids(z), IsEmpty()); +} + } // namespace } // namespace dcsctp diff --git a/net/dcsctp/socket/stream_reset_handler_test.cc b/net/dcsctp/socket/stream_reset_handler_test.cc index e1e54d0422..493b4c4bf7 100644 --- a/net/dcsctp/socket/stream_reset_handler_test.cc +++ b/net/dcsctp/socket/stream_reset_handler_test.cc @@ -107,6 +107,7 @@ class StreamResetHandlerTest : public testing::Test { kArwnd)), retransmission_queue_(std::make_unique( "", + &callbacks_, kMyInitialTsn, kArwnd, producer_, @@ -199,8 +200,8 @@ class StreamResetHandlerTest : public testing::Test { std::make_unique("log: ", kPeerInitialTsn, kArwnd); reasm_->RestoreFromState(state); retransmission_queue_ = std::make_unique( - "", kMyInitialTsn, kArwnd, producer_, [](DurationMs rtt_ms) {}, []() {}, - *t3_rtx_timer_, DcSctpOptions(), + "", &callbacks_, kMyInitialTsn, kArwnd, producer_, + [](DurationMs rtt_ms) {}, []() {}, *t3_rtx_timer_, DcSctpOptions(), /*supports_partial_reliability=*/true, /*use_message_interleaving=*/false); retransmission_queue_->RestoreFromState(state); diff --git a/net/dcsctp/socket/transmission_control_block.cc b/net/dcsctp/socket/transmission_control_block.cc index 44a1b7392c..d769e26069 100644 --- a/net/dcsctp/socket/transmission_control_block.cc +++ b/net/dcsctp/socket/transmission_control_block.cc @@ -89,6 +89,7 @@ TransmissionControlBlock::TransmissionControlBlock( capabilities.message_interleaving), retransmission_queue_( log_prefix, + &callbacks_, my_initial_tsn, a_rwnd, send_queue, diff --git a/net/dcsctp/tx/BUILD.gn b/net/dcsctp/tx/BUILD.gn index e8fbce905f..3cb7df4cc2 100644 --- a/net/dcsctp/tx/BUILD.gn +++ b/net/dcsctp/tx/BUILD.gn @@ -190,6 +190,7 @@ if (rtc_include_tests) { "../public:socket", "../public:types", "../socket:mock_callbacks", + "../socket:mock_callbacks", "../testing:data_generator", "../testing:testing_macros", "../timer", diff --git a/net/dcsctp/tx/retransmission_queue.cc b/net/dcsctp/tx/retransmission_queue.cc index 958879387e..36e2a859ba 100644 --- a/net/dcsctp/tx/retransmission_queue.cc +++ b/net/dcsctp/tx/retransmission_queue.cc @@ -51,6 +51,7 @@ constexpr float kMinBytesRequiredToSendFactor = 0.9; RetransmissionQueue::RetransmissionQueue( absl::string_view log_prefix, + DcSctpSocketCallbacks* callbacks, TSN my_initial_tsn, size_t a_rwnd, SendQueue& send_queue, @@ -60,7 +61,8 @@ RetransmissionQueue::RetransmissionQueue( const DcSctpOptions& options, bool supports_partial_reliability, bool use_message_interleaving) - : options_(options), + : callbacks_(*callbacks), + options_(options), min_bytes_required_to_send_(options.mtu * kMinBytesRequiredToSendFactor), partial_reliability_(supports_partial_reliability), log_prefix_(std::string(log_prefix) + "tx: "), @@ -278,6 +280,21 @@ bool RetransmissionQueue::HandleSack(TimeMs now, const SackChunk& sack) { OutstandingData::AckInfo ack_info = outstanding_data_.HandleSack( cumulative_tsn_ack, sack.gap_ack_blocks(), is_in_fast_recovery()); + // Add lifecycle events for delivered messages. + for (LifecycleId lifecycle_id : ack_info.acked_lifecycle_ids) { + RTC_DLOG(LS_VERBOSE) << "Triggering OnLifecycleMessageDelivered(" + << lifecycle_id.value() << ")"; + callbacks_.OnLifecycleMessageDelivered(lifecycle_id); + callbacks_.OnLifecycleEnd(lifecycle_id); + } + for (LifecycleId lifecycle_id : ack_info.abandoned_lifecycle_ids) { + RTC_DLOG(LS_VERBOSE) << "Triggering OnLifecycleMessageExpired(" + << lifecycle_id.value() << ", true)"; + callbacks_.OnLifecycleMessageExpired(lifecycle_id, + /*maybe_delivered=*/true); + callbacks_.OnLifecycleEnd(lifecycle_id); + } + // Update of outstanding_data_ is now done. Congestion control remains. UpdateReceiverWindow(sack.a_rwnd()); @@ -467,10 +484,14 @@ std::vector> RetransmissionQueue::GetChunksToSend( chunk_opt->data, now, partial_reliability_ ? chunk_opt->max_retransmissions : MaxRetransmits::NoLimit(), - partial_reliability_ ? chunk_opt->expires_at - : TimeMs::InfiniteFuture()); + partial_reliability_ ? chunk_opt->expires_at : TimeMs::InfiniteFuture(), + chunk_opt->lifecycle_id); if (tsn.has_value()) { + if (chunk_opt->lifecycle_id.IsSet()) { + RTC_DCHECK(chunk_opt->data.is_end); + callbacks_.OnLifecycleMessageFullySent(chunk_opt->lifecycle_id); + } to_be_sent.emplace_back(tsn->Wrap(), std::move(chunk_opt->data)); } } diff --git a/net/dcsctp/tx/retransmission_queue.h b/net/dcsctp/tx/retransmission_queue.h index 51eeb5a319..830c0b346d 100644 --- a/net/dcsctp/tx/retransmission_queue.h +++ b/net/dcsctp/tx/retransmission_queue.h @@ -28,6 +28,7 @@ #include "net/dcsctp/packet/data.h" #include "net/dcsctp/public/dcsctp_handover_state.h" #include "net/dcsctp/public/dcsctp_options.h" +#include "net/dcsctp/public/dcsctp_socket.h" #include "net/dcsctp/timer/timer.h" #include "net/dcsctp/tx/outstanding_data.h" #include "net/dcsctp/tx/retransmission_timeout.h" @@ -55,6 +56,7 @@ class RetransmissionQueue { // `on_clear_retransmission_counter` and will also use `t3_rtx`, which is the // SCTP retransmission timer to manage retransmissions. RetransmissionQueue(absl::string_view log_prefix, + DcSctpSocketCallbacks* callbacks, TSN my_initial_tsn, size_t a_rwnd, SendQueue& send_queue, @@ -212,6 +214,7 @@ class RetransmissionQueue { // to the congestion control algorithm. size_t max_bytes_to_send() const; + DcSctpSocketCallbacks& callbacks_; const DcSctpOptions options_; // The minimum bytes required to be available in the congestion window to // allow packets to be sent - to avoid sending too small packets. diff --git a/net/dcsctp/tx/retransmission_queue_test.cc b/net/dcsctp/tx/retransmission_queue_test.cc index f11ebad19a..e62c030bfa 100644 --- a/net/dcsctp/tx/retransmission_queue_test.cc +++ b/net/dcsctp/tx/retransmission_queue_test.cc @@ -28,6 +28,7 @@ #include "net/dcsctp/packet/chunk/sack_chunk.h" #include "net/dcsctp/packet/data.h" #include "net/dcsctp/public/dcsctp_options.h" +#include "net/dcsctp/socket/mock_dcsctp_socket_callbacks.h" #include "net/dcsctp/testing/data_generator.h" #include "net/dcsctp/testing/testing_macros.h" #include "net/dcsctp/timer/fake_timeout.h" @@ -98,7 +99,7 @@ class RetransmissionQueueTest : public testing::Test { RetransmissionQueue CreateQueue(bool supports_partial_reliability = true, bool use_message_interleaving = false) { return RetransmissionQueue( - "", TSN(10), kArwnd, producer_, on_rtt_.AsStdFunction(), + "", &callbacks_, TSN(10), kArwnd, producer_, on_rtt_.AsStdFunction(), on_clear_retransmission_counter_.AsStdFunction(), *timer_, options_, supports_partial_reliability, use_message_interleaving); } @@ -110,7 +111,7 @@ class RetransmissionQueueTest : public testing::Test { queue.AddHandoverState(state); g_handover_state_transformer_for_test(&state); auto queue2 = std::make_unique( - "", TSN(10), kArwnd, producer_, on_rtt_.AsStdFunction(), + "", &callbacks_, TSN(10), kArwnd, producer_, on_rtt_.AsStdFunction(), on_clear_retransmission_counter_.AsStdFunction(), *timer_, options_, /*supports_partial_reliability=*/true, /*use_message_interleaving=*/false); @@ -118,6 +119,7 @@ class RetransmissionQueueTest : public testing::Test { return queue2; } + MockDcSctpSocketCallbacks callbacks_; DcSctpOptions options_; DataGenerator gen_; TimeMs now_ = TimeMs(0);