diff --git a/net/dcsctp/socket/dcsctp_socket.cc b/net/dcsctp/socket/dcsctp_socket.cc index 21b40d854f..7ccca5cbda 100644 --- a/net/dcsctp/socket/dcsctp_socket.cc +++ b/net/dcsctp/socket/dcsctp_socket.cc @@ -479,13 +479,7 @@ ResetStreamsStatus DcSctpSocket::ResetStreams( } tcb_->stream_reset_handler().ResetStreams(outgoing_streams); - absl::optional reconfig = - tcb_->stream_reset_handler().MakeStreamResetRequest(); - if (reconfig.has_value()) { - SctpPacket::Builder builder = tcb_->PacketBuilder(); - builder.Add(*reconfig); - packet_sender_.Send(builder); - } + MaybeSendResetStreamsRequest(); RTC_DCHECK(IsConsistent()); return ResetStreamsStatus::kPerformed; @@ -570,6 +564,16 @@ void DcSctpSocket::MaybeSendShutdownOnPacketReceived(const SctpPacket& packet) { } } +void DcSctpSocket::MaybeSendResetStreamsRequest() { + absl::optional reconfig = + tcb_->stream_reset_handler().MakeStreamResetRequest(); + if (reconfig.has_value()) { + SctpPacket::Builder builder = tcb_->PacketBuilder(); + builder.Add(*reconfig); + packet_sender_.Send(builder); + } +} + bool DcSctpSocket::ValidatePacket(const SctpPacket& packet) { const CommonHeader& header = packet.common_header(); VerificationTag my_verification_tag = @@ -1463,6 +1467,10 @@ void DcSctpSocket::HandleReconfig( absl::optional chunk = ReConfigChunk::Parse(descriptor.data); if (ValidateParseSuccess(chunk) && ValidateHasTCB()) { tcb_->stream_reset_handler().HandleReConfig(*std::move(chunk)); + // Handling this response may result in outgoing stream resets finishing + // (either successfully or with failure). If there still are pending streams + // that were waiting for this request to finish, continue resetting them. + MaybeSendResetStreamsRequest(); } } diff --git a/net/dcsctp/socket/dcsctp_socket.h b/net/dcsctp/socket/dcsctp_socket.h index b1b3ea9d9b..0ab54e801a 100644 --- a/net/dcsctp/socket/dcsctp_socket.h +++ b/net/dcsctp/socket/dcsctp_socket.h @@ -155,6 +155,8 @@ class DcSctpSocket : public DcSctpSocketInterface { void MaybeSendShutdownOrAck(); // If the socket is shutting down, responds SHUTDOWN to any incoming DATA. void MaybeSendShutdownOnPacketReceived(const SctpPacket& packet); + // If there are streams pending to be reset, send a request to reset them. + void MaybeSendResetStreamsRequest(); // Sends a INIT chunk. void SendInit(); // Sends a SHUTDOWN chunk. diff --git a/net/dcsctp/socket/dcsctp_socket_test.cc b/net/dcsctp/socket/dcsctp_socket_test.cc index 848b7d6274..a6b8aa62bf 100644 --- a/net/dcsctp/socket/dcsctp_socket_test.cc +++ b/net/dcsctp/socket/dcsctp_socket_test.cc @@ -61,6 +61,7 @@ using ::testing::ElementsAre; using ::testing::HasSubstr; using ::testing::IsEmpty; using ::testing::SizeIs; +using ::testing::UnorderedElementsAre; constexpr SendOptions kSendOptions; constexpr size_t kLargeMessageSize = DcSctpOptions::kMaxSafeMTUSize * 20; @@ -2262,5 +2263,58 @@ TEST(DcSctpSocketTest, ReceiveBothUnorderedAndOrderedWithSameTSN) { std::vector(10), opts)) .Build()); } + +TEST(DcSctpSocketTest, CloseTwoStreamsAtTheSameTime) { + // Reported as https://crbug.com/1312009. + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + + EXPECT_CALL(z.cb, OnIncomingStreamsReset(ElementsAre(StreamID(1)))).Times(1); + EXPECT_CALL(z.cb, OnIncomingStreamsReset(ElementsAre(StreamID(2)))).Times(1); + EXPECT_CALL(a.cb, OnStreamsResetPerformed(ElementsAre(StreamID(1)))).Times(1); + EXPECT_CALL(a.cb, OnStreamsResetPerformed(ElementsAre(StreamID(2)))).Times(1); + + ConnectSockets(a, z); + + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions); + a.socket.Send(DcSctpMessage(StreamID(2), PPID(53), {1, 2}), kSendOptions); + + ExchangeMessages(a, z); + + a.socket.ResetStreams(std::vector({StreamID(1)})); + a.socket.ResetStreams(std::vector({StreamID(2)})); + + ExchangeMessages(a, z); +} + +TEST(DcSctpSocketTest, CloseThreeStreamsAtTheSameTime) { + // Similar to CloseTwoStreamsAtTheSameTime, but ensuring that the two + // remaining streams are reset at the same time in the second request. + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + + EXPECT_CALL(z.cb, OnIncomingStreamsReset(ElementsAre(StreamID(1)))).Times(1); + EXPECT_CALL(z.cb, OnIncomingStreamsReset( + UnorderedElementsAre(StreamID(2), StreamID(3)))) + .Times(1); + EXPECT_CALL(a.cb, OnStreamsResetPerformed(ElementsAre(StreamID(1)))).Times(1); + EXPECT_CALL(a.cb, OnStreamsResetPerformed( + UnorderedElementsAre(StreamID(2), StreamID(3)))) + .Times(1); + + ConnectSockets(a, z); + + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions); + a.socket.Send(DcSctpMessage(StreamID(2), PPID(53), {1, 2}), kSendOptions); + a.socket.Send(DcSctpMessage(StreamID(3), PPID(53), {1, 2}), kSendOptions); + + ExchangeMessages(a, z); + + a.socket.ResetStreams(std::vector({StreamID(1)})); + a.socket.ResetStreams(std::vector({StreamID(2)})); + a.socket.ResetStreams(std::vector({StreamID(3)})); + + ExchangeMessages(a, z); +} } // namespace } // namespace dcsctp