diff --git a/media/sctp/sctp_transport.cc b/media/sctp/sctp_transport.cc index 539eebd50e..6578d4cc04 100644 --- a/media/sctp/sctp_transport.cc +++ b/media/sctp/sctp_transport.cc @@ -720,6 +720,21 @@ bool SctpTransport::SendData(const SendDataParams& params, ready_to_send_data_ = false; return false; } + + // Do not queue data to send on a closing stream. + auto it = stream_status_by_sid_.find(params.sid); + if (it == stream_status_by_sid_.end() || !it->second.is_open()) { + RTC_LOG(LS_WARNING) + << debug_name_ + << "->SendData(...): " + "Not sending data because sid is unknown or closing: " + << params.sid; + if (result) { + *result = SDR_ERROR; + } + return false; + } + size_t payload_size = payload.size(); OutgoingMessage message(payload, params); SendDataResult send_message_result = SendMessageInternal(&message); @@ -756,12 +771,11 @@ SendDataResult SctpTransport::SendMessageInternal(OutgoingMessage* message) { } if (message->send_params().type != DMT_CONTROL) { auto it = stream_status_by_sid_.find(message->send_params().sid); - if (it == stream_status_by_sid_.end() || !it->second.is_open()) { - RTC_LOG(LS_WARNING) - << debug_name_ - << "->SendMessageInternal(...): " - "Not sending data because sid is unknown or closing: " - << message->send_params().sid; + if (it == stream_status_by_sid_.end()) { + RTC_LOG(LS_WARNING) << debug_name_ + << "->SendMessageInternal(...): " + "Not sending data because sid is unknown: " + << message->send_params().sid; return SDR_ERROR; } } @@ -1032,13 +1046,19 @@ void SctpTransport::CloseSctpSocket() { bool SctpTransport::SendQueuedStreamResets() { RTC_DCHECK_RUN_ON(network_thread_); + auto needs_reset = + [this](const std::map::value_type& stream) { + // Ignore streams with partial outgoing messages as they are required to + // be fully sent by the WebRTC spec + // https://w3c.github.io/webrtc-pc/#closing-procedure + return stream.second.need_outgoing_reset() && + (!partial_outgoing_message_.has_value() || + partial_outgoing_message_.value().send_params().sid != + static_cast(stream.first)); + }; // Figure out how many streams need to be reset. We need to do this so we can // allocate the right amount of memory for the sctp_reset_streams structure. - size_t num_streams = absl::c_count_if( - stream_status_by_sid_, - [](const std::map::value_type& stream) { - return stream.second.need_outgoing_reset(); - }); + size_t num_streams = absl::c_count_if(stream_status_by_sid_, needs_reset); if (num_streams == 0) { // Nothing to reset. return true; @@ -1057,12 +1077,10 @@ bool SctpTransport::SendQueuedStreamResets() { resetp->srs_number_streams = rtc::checked_cast(num_streams); int result_idx = 0; - for (const std::map::value_type& stream : - stream_status_by_sid_) { - if (!stream.second.need_outgoing_reset()) { - continue; + for (const auto& stream : stream_status_by_sid_) { + if (needs_reset(stream)) { + resetp->srs_stream_list[result_idx++] = stream.first; } - resetp->srs_stream_list[result_idx++] = stream.first; } int ret = @@ -1111,7 +1129,16 @@ bool SctpTransport::SendBufferedMessage() { return false; } RTC_DCHECK_EQ(0u, partial_outgoing_message_->size()); + + int sid = partial_outgoing_message_->send_params().sid; partial_outgoing_message_.reset(); + + // Send the queued stream reset if it was pending for this stream. + auto it = stream_status_by_sid_.find(sid); + if (it->second.need_outgoing_reset()) { + SendQueuedStreamResets(); + } + return true; } diff --git a/media/sctp/sctp_transport_unittest.cc b/media/sctp/sctp_transport_unittest.cc index 120f4e5a27..be3eb8e386 100644 --- a/media/sctp/sctp_transport_unittest.cc +++ b/media/sctp/sctp_transport_unittest.cc @@ -518,6 +518,47 @@ TEST_P(SctpTransportTestWithOrdered, SendLargeBufferedOutgoingMessage) { EXPECT_EQ(2u, receiver2()->num_messages_received()); } +// Tests that a large message gets buffered and later sent by the SctpTransport +// when the sctp library only accepts the message partially during a stream +// reset. +TEST_P(SctpTransportTestWithOrdered, + SendLargeBufferedOutgoingMessageDuringReset) { + bool ordered = GetParam(); + SetupConnectedTransportsWithTwoStreams(); + SctpTransportObserver transport2_observer(transport2()); + + // Wait for initial SCTP association to be formed. + EXPECT_EQ_WAIT(1, transport1_ready_to_send_count(), kDefaultTimeout); + // Make the fake transport unwritable so that messages pile up for the SCTP + // socket. + fake_dtls1()->SetWritable(false); + SendDataResult result; + + // Fill almost all of sctp library's send buffer. + ASSERT_TRUE(SendData(transport1(), /*sid=*/1, + std::string(kSctpSendBufferSize / 2, 'a'), &result, + ordered)); + + std::string buffered_message(kSctpSendBufferSize, 'b'); + // SctpTransport accepts this message by buffering the second half. + ASSERT_TRUE( + SendData(transport1(), /*sid=*/1, buffered_message, &result, ordered)); + // Queue a stream reset + transport1()->ResetStream(/*sid=*/1); + + // Make the transport writable again and expect a "SignalReadyToSendData" at + // some point after sending the buffered message. + fake_dtls1()->SetWritable(true); + EXPECT_EQ_WAIT(2, transport1_ready_to_send_count(), kDefaultTimeout); + + // Queued message should be received by the receiver before receiving the + // reset + EXPECT_TRUE_WAIT(ReceivedData(receiver2(), 1, buffered_message), + kDefaultTimeout); + EXPECT_EQ(2u, receiver2()->num_messages_received()); + EXPECT_TRUE_WAIT(transport2_observer.WasStreamClosed(1), kDefaultTimeout); +} + TEST_P(SctpTransportTestWithOrdered, SendData) { bool ordered = GetParam(); SetupConnectedTransportsWithTwoStreams();