diff --git a/net/dcsctp/public/dcsctp_handover_state.h b/net/dcsctp/public/dcsctp_handover_state.h index 8907669e42..3ad4ab7f3e 100644 --- a/net/dcsctp/public/dcsctp_handover_state.h +++ b/net/dcsctp/public/dcsctp_handover_state.h @@ -24,6 +24,25 @@ namespace dcsctp { // for serialization. Serialization is not provided by dcSCTP. If needed it has // to be implemented in the calling client. struct DcSctpSocketHandoverState { + enum class SocketState { + kClosed, + kConnected, + }; + SocketState socket_state = SocketState::kClosed; + + uint32_t my_verification_tag = 0; + uint32_t my_initial_tsn = 0; + uint32_t peer_verification_tag = 0; + uint32_t peer_initial_tsn = 0; + uint64_t tie_tag = 0; + + struct Capabilities { + bool partial_reliability = false; + bool message_interleaving = false; + bool reconfig = false; + }; + Capabilities capabilities; + struct Transmission { uint32_t next_tsn = 0; uint32_t next_reset_req_sn = 0; @@ -98,6 +117,7 @@ class HandoverReadinessStatus value() |= status.value(); return *this; } + std::string ToString() const; }; } // namespace dcsctp diff --git a/net/dcsctp/public/dcsctp_socket.h b/net/dcsctp/public/dcsctp_socket.h index 248646e85f..583d037019 100644 --- a/net/dcsctp/public/dcsctp_socket.h +++ b/net/dcsctp/public/dcsctp_socket.h @@ -17,6 +17,7 @@ #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "api/array_view.h" +#include "net/dcsctp/public/dcsctp_handover_state.h" #include "net/dcsctp/public/dcsctp_message.h" #include "net/dcsctp/public/dcsctp_options.h" #include "net/dcsctp/public/packet_observer.h" @@ -355,6 +356,14 @@ class DcSctpSocketInterface { // `DcSctpSocketCallbacks::OnConnected` will be called on success. virtual void Connect() = 0; + // Puts this socket to the state in which the original socket was when its + // `DcSctpSocketHandoverState` was captured by `GetHandoverStateAndClose`. + // `RestoreFromState` is allowed only on the closed socket. + // `DcSctpSocketCallbacks::OnConnected` will be called if a connected socket + // state is restored. + // `DcSctpSocketCallbacks::OnError` will be called on error. + virtual void RestoreFromState(const DcSctpSocketHandoverState& state) = 0; + // Gracefully shutdowns the socket and sends all outstanding data. This is an // asynchronous operation and `DcSctpSocketCallbacks::OnClosed` will be called // on success. @@ -417,6 +426,20 @@ class DcSctpSocketInterface { // Retrieves the latest metrics. virtual Metrics GetMetrics() const = 0; + + // Returns empty bitmask if the socket is in the state in which a snapshot of + // the state can be made by `GetHandoverStateAndClose()`. Return value is + // invalidated by a call to any non-const method. + virtual HandoverReadinessStatus GetHandoverReadiness() const = 0; + + // Collects a snapshot of the socket state that can be used to reconstruct + // this socket in another process. On success this socket object is closed + // synchronously and no callbacks will be made after the method has returned. + // The method fails if the socket is not in a state ready for handover. + // nullopt indicates the failure. `DcSctpSocketCallbacks::OnClosed` will be + // called on success. + virtual absl::optional + GetHandoverStateAndClose() = 0; }; } // namespace dcsctp diff --git a/net/dcsctp/public/mock_dcsctp_socket.h b/net/dcsctp/public/mock_dcsctp_socket.h index b382773fdf..eb1e8ccec9 100644 --- a/net/dcsctp/public/mock_dcsctp_socket.h +++ b/net/dcsctp/public/mock_dcsctp_socket.h @@ -26,6 +26,11 @@ class MockDcSctpSocket : public DcSctpSocketInterface { MOCK_METHOD(void, Connect, (), (override)); + MOCK_METHOD(void, + RestoreFromState, + (const DcSctpSocketHandoverState&), + (override)); + MOCK_METHOD(void, Shutdown, (), (override)); MOCK_METHOD(void, Close, (), (override)); @@ -59,6 +64,15 @@ class MockDcSctpSocket : public DcSctpSocketInterface { (override)); MOCK_METHOD(Metrics, GetMetrics, (), (const, override)); + + MOCK_METHOD(HandoverReadinessStatus, + GetHandoverReadiness, + (), + (const, override)); + MOCK_METHOD(absl::optional, + GetHandoverStateAndClose, + (), + (override)); }; } // namespace dcsctp diff --git a/net/dcsctp/socket/dcsctp_socket.cc b/net/dcsctp/socket/dcsctp_socket.cc index 5211cca350..afc30f61fd 100644 --- a/net/dcsctp/socket/dcsctp_socket.cc +++ b/net/dcsctp/socket/dcsctp_socket.cc @@ -139,8 +139,57 @@ TieTag MakeTieTag(DcSctpSocketCallbacks& cb) { static_cast(tie_tag_lower)); } +constexpr absl::string_view HandoverUnreadinessReasonToString( + HandoverUnreadinessReason reason) { + switch (reason) { + case HandoverUnreadinessReason::kWrongConnectionState: + return "WRONG_CONNECTION_STATE"; + case HandoverUnreadinessReason::kSendQueueNotEmpty: + return "SEND_QUEUE_NOT_EMPTY"; + case HandoverUnreadinessReason::kDataTrackerTsnBlocksPending: + return "DATA_TRACKER_TSN_BLOCKS_PENDING"; + case HandoverUnreadinessReason::kReassemblyQueueDeliveredTSNsGap: + return "REASSEMBLY_QUEUE_DELIVERED_TSN_GAP"; + case HandoverUnreadinessReason::kStreamResetDeferred: + return "STREAM_RESET_DEFERRED"; + case HandoverUnreadinessReason::kOrderedStreamHasUnassembledChunks: + return "ORDERED_STREAM_HAS_UNASSEMBLED_CHUNKS"; + case HandoverUnreadinessReason::kUnorderedStreamHasUnassembledChunks: + return "UNORDERED_STREAM_HAS_UNASSEMBLED_CHUNKS"; + case HandoverUnreadinessReason::kRetransmissionQueueOutstandingData: + return "RETRANSMISSION_QUEUE_OUTSTANDING_DATA"; + case HandoverUnreadinessReason::kRetransmissionQueueFastRecovery: + return "RETRANSMISSION_QUEUE_FAST_RECOVERY"; + case HandoverUnreadinessReason::kRetransmissionQueueNotEmpty: + return "RETRANSMISSION_QUEUE_NOT_EMPTY"; + case HandoverUnreadinessReason::kPendingStreamReset: + return "PENDING_STREAM_RESET"; + case HandoverUnreadinessReason::kPendingStreamResetRequest: + return "PENDING_STREAM_RESET_REQUEST"; + } +} } // namespace +std::string HandoverReadinessStatus::ToString() const { + std::string result; + for (uint32_t bit = 1; + bit <= static_cast(HandoverUnreadinessReason::kMax); + bit *= 2) { + auto flag = static_cast(bit); + if (Contains(flag)) { + if (!result.empty()) { + result.append(","); + } + absl::string_view s = HandoverUnreadinessReasonToString(flag); + result.append(s.data(), s.size()); + } + } + if (result.empty()) { + result = "READY"; + } + return result; +} + DcSctpSocket::DcSctpSocket(absl::string_view log_prefix, DcSctpSocketCallbacks& callbacks, std::unique_ptr packet_observer, @@ -286,6 +335,42 @@ void DcSctpSocket::Connect() { callbacks_.TriggerDeferred(); } +void DcSctpSocket::RestoreFromState(const DcSctpSocketHandoverState& state) { + if (state_ != State::kClosed) { + callbacks_.OnError(ErrorKind::kUnsupportedOperation, + "Only closed socket can be restored from state"); + } else { + if (state.socket_state == + DcSctpSocketHandoverState::SocketState::kConnected) { + VerificationTag my_verification_tag = + VerificationTag(state.my_verification_tag); + connect_params_.verification_tag = my_verification_tag; + + Capabilities capabilities; + capabilities.partial_reliability = state.capabilities.partial_reliability; + capabilities.message_interleaving = + state.capabilities.message_interleaving; + capabilities.reconfig = state.capabilities.reconfig; + + tcb_ = std::make_unique( + timer_manager_, log_prefix_, options_, capabilities, callbacks_, + send_queue_, my_verification_tag, TSN(state.my_initial_tsn), + VerificationTag(state.peer_verification_tag), + TSN(state.peer_initial_tsn), static_cast(0), + TieTag(state.tie_tag), packet_sender_, + [this]() { return state_ == State::kEstablished; }, &state); + RTC_DLOG(LS_VERBOSE) << log_prefix() << "Created peer TCB from state: " + << tcb_->ToString(); + + SetState(State::kEstablished, "restored from handover state"); + callbacks_.OnConnected(); + } + } + + RTC_DCHECK(IsConsistent()); + callbacks_.TriggerDeferred(); +} + void DcSctpSocket::Shutdown() { if (tcb_ != nullptr) { // https://tools.ietf.org/html/rfc4960#section-9.2 @@ -1579,4 +1664,38 @@ void DcSctpSocket::SendShutdownAck() { t2_shutdown_->Start(); } +HandoverReadinessStatus DcSctpSocket::GetHandoverReadiness() const { + HandoverReadinessStatus status; + if (state_ != State::kClosed && state_ != State::kEstablished) { + status.Add(HandoverUnreadinessReason::kWrongConnectionState); + } + if (!send_queue_.IsEmpty()) { + status.Add(HandoverUnreadinessReason::kSendQueueNotEmpty); + } + if (tcb_) { + status.Add(tcb_->GetHandoverReadiness()); + } + return status; +} + +absl::optional +DcSctpSocket::GetHandoverStateAndClose() { + if (!GetHandoverReadiness().IsReady()) { + return absl::nullopt; + } + + DcSctpSocketHandoverState state; + + if (state_ == State::kClosed) { + state.socket_state = DcSctpSocketHandoverState::SocketState::kClosed; + } else if (state_ == State::kEstablished) { + state.socket_state = DcSctpSocketHandoverState::SocketState::kConnected; + tcb_->AddHandoverState(state); + InternalClose(ErrorKind::kNoError, "handover"); + callbacks_.TriggerDeferred(); + } + + return std::move(state); +} + } // namespace dcsctp diff --git a/net/dcsctp/socket/dcsctp_socket.h b/net/dcsctp/socket/dcsctp_socket.h index 60359bd173..508a8a6aad 100644 --- a/net/dcsctp/socket/dcsctp_socket.h +++ b/net/dcsctp/socket/dcsctp_socket.h @@ -85,6 +85,7 @@ class DcSctpSocket : public DcSctpSocketInterface { void ReceivePacket(rtc::ArrayView data) override; void HandleTimeout(TimeoutID timeout_id) override; void Connect() override; + void RestoreFromState(const DcSctpSocketHandoverState& state) override; void Shutdown() override; void Close() override; SendStatus Send(DcSctpMessage message, @@ -98,6 +99,8 @@ class DcSctpSocket : public DcSctpSocketInterface { size_t buffered_amount_low_threshold(StreamID stream_id) const override; void SetBufferedAmountLowThreshold(StreamID stream_id, size_t bytes) override; Metrics GetMetrics() const override; + HandoverReadinessStatus GetHandoverReadiness() const override; + absl::optional GetHandoverStateAndClose() override; // Returns this socket's verification tag, or zero if not yet connected. VerificationTag verification_tag() const { diff --git a/net/dcsctp/socket/dcsctp_socket_test.cc b/net/dcsctp/socket/dcsctp_socket_test.cc index 5f99cc91be..2fadde8a21 100644 --- a/net/dcsctp/socket/dcsctp_socket_test.cc +++ b/net/dcsctp/socket/dcsctp_socket_test.cc @@ -315,6 +315,24 @@ class DcSctpSocketTest : public testing::Test { EXPECT_EQ(sock_z_->state(), SocketState::kConnected); } + void HandoverSocketZ() { + ASSERT_EQ(sock_z_->GetHandoverReadiness(), HandoverReadinessStatus()); + bool is_closed = sock_z_->state() == SocketState::kClosed; + if (!is_closed) { + EXPECT_CALL(cb_z_, OnClosed).Times(1); + } + absl::optional handover_state = + sock_z_->GetHandoverStateAndClose(); + EXPECT_TRUE(handover_state.has_value()); + cb_z_.Reset(); + sock_z_ = std::make_unique("Z", cb_z_, GetPacketObserver("Z"), + options_); + if (!is_closed) { + EXPECT_CALL(cb_z_, OnConnected).Times(1); + } + sock_z_->RestoreFromState(*handover_state); + } + const DcSctpOptions options_; testing::NiceMock cb_a_; testing::NiceMock cb_z_; @@ -322,6 +340,52 @@ class DcSctpSocketTest : public testing::Test { std::unique_ptr sock_z_; }; +// Test parameter that controls whether to perform handovers during the test. A +// test can have multiple points where it conditionally hands over socket Z. +// Either socket Z will be handed over at all those points or handed over never. +enum class HandoverMode { + kNoHandover, + kPerformHandovers, +}; + +class DcSctpSocketParametrizedTest + : public DcSctpSocketTest, + public ::testing::WithParamInterface { + protected: + // Trigger handover for socket Z depending on the current test param. + void MaybeHandoverSocketZ() { + if (GetParam() == HandoverMode::kPerformHandovers) { + HandoverSocketZ(); + } + } + // Trigger handover for socket Z depending on the current test param. + // Then checks message passing to verify the handed over socket is functional. + void MaybeHandoverSocketZAndSendMessage() { + if (GetParam() == HandoverMode::kPerformHandovers) { + HandoverSocketZ(); + } + + ExchangeMessages(*sock_a_, cb_a_, *sock_z_, cb_z_); + sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions); + ExchangeMessages(*sock_a_, cb_a_, *sock_z_, cb_z_); + + absl::optional msg = cb_z_.ConsumeReceivedMessage(); + ASSERT_TRUE(msg.has_value()); + EXPECT_EQ(msg->stream_id(), StreamID(1)); + } +}; + +INSTANTIATE_TEST_SUITE_P(Handovers, + DcSctpSocketParametrizedTest, + testing::Values(HandoverMode::kNoHandover, + HandoverMode::kPerformHandovers), + [](const auto& test_info) { + return test_info.param == + HandoverMode::kPerformHandovers + ? "WithHandovers" + : "NoHandover"; + }); + TEST_F(DcSctpSocketTest, EstablishConnection) { EXPECT_CALL(cb_a_, OnConnected).Times(1); EXPECT_CALL(cb_z_, OnConnected).Times(1); @@ -566,8 +630,8 @@ TEST_F(DcSctpSocketTest, ResendingCookieEchoTooManyTimesAborts) { TEST_F(DcSctpSocketTest, DoesntSendMorePacketsUntilCookieAckHasBeenReceived) { sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), - std::vector(kLargeMessageSize)), - kSendOptions); + std::vector(kLargeMessageSize)), + kSendOptions); sock_a_->Connect(); // Z reads INIT, produces INIT_ACK @@ -623,11 +687,13 @@ TEST_F(DcSctpSocketTest, DoesntSendMorePacketsUntilCookieAckHasBeenReceived) { SizeIs(kLargeMessageSize)); } -TEST_F(DcSctpSocketTest, ShutdownConnection) { +TEST_P(DcSctpSocketParametrizedTest, ShutdownConnection) { ConnectSockets(); + MaybeHandoverSocketZ(); RTC_LOG(LS_INFO) << "Shutting down"; + EXPECT_CALL(cb_z_, OnClosed).Times(1); sock_a_->Shutdown(); // Z reads SHUTDOWN, produces SHUTDOWN_ACK sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket()); @@ -638,6 +704,9 @@ TEST_F(DcSctpSocketTest, ShutdownConnection) { EXPECT_EQ(sock_a_->state(), SocketState::kClosed); EXPECT_EQ(sock_z_->state(), SocketState::kClosed); + + MaybeHandoverSocketZ(); + EXPECT_EQ(sock_z_->state(), SocketState::kClosed); } TEST_F(DcSctpSocketTest, ShutdownTimerExpiresTooManyTimeClosesConnection) { @@ -704,8 +773,9 @@ TEST_F(DcSctpSocketTest, SendMessageAfterEstablished) { EXPECT_EQ(msg->stream_id(), StreamID(1)); } -TEST_F(DcSctpSocketTest, TimeoutResendsPacket) { +TEST_P(DcSctpSocketParametrizedTest, TimeoutResendsPacket) { ConnectSockets(); + MaybeHandoverSocketZ(); sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions); cb_a_.ConsumeSentPacket(); @@ -719,10 +789,13 @@ TEST_F(DcSctpSocketTest, TimeoutResendsPacket) { absl::optional msg = cb_z_.ConsumeReceivedMessage(); ASSERT_TRUE(msg.has_value()); EXPECT_EQ(msg->stream_id(), StreamID(1)); + + MaybeHandoverSocketZAndSendMessage(); } -TEST_F(DcSctpSocketTest, SendALotOfBytesMissedSecondPacket) { +TEST_P(DcSctpSocketParametrizedTest, SendALotOfBytesMissedSecondPacket) { ConnectSockets(); + MaybeHandoverSocketZ(); std::vector payload(kLargeMessageSize); sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), payload), kSendOptions); @@ -739,10 +812,13 @@ TEST_F(DcSctpSocketTest, SendALotOfBytesMissedSecondPacket) { ASSERT_TRUE(msg.has_value()); EXPECT_EQ(msg->stream_id(), StreamID(1)); EXPECT_THAT(msg->payload(), testing::ElementsAreArray(payload)); + + MaybeHandoverSocketZAndSendMessage(); } -TEST_F(DcSctpSocketTest, SendingHeartbeatAnswersWithAck) { +TEST_P(DcSctpSocketParametrizedTest, SendingHeartbeatAnswersWithAck) { ConnectSockets(); + MaybeHandoverSocketZ(); // Inject a HEARTBEAT chunk SctpPacket::Builder b(sock_a_->verification_tag(), DcSctpOptions()); @@ -761,10 +837,13 @@ TEST_F(DcSctpSocketTest, SendingHeartbeatAnswersWithAck) { HeartbeatAckChunk::Parse(ack_packet.descriptors()[0].data)); ASSERT_HAS_VALUE_AND_ASSIGN(HeartbeatInfoParameter info_param, ack.info()); EXPECT_THAT(info_param.info(), ElementsAre(1, 2, 3, 4)); + + MaybeHandoverSocketZAndSendMessage(); } -TEST_F(DcSctpSocketTest, ExpectHeartbeatToBeSent) { +TEST_P(DcSctpSocketParametrizedTest, ExpectHeartbeatToBeSent) { ConnectSockets(); + MaybeHandoverSocketZ(); EXPECT_THAT(cb_a_.ConsumeSentPacket(), IsEmpty()); @@ -786,11 +865,16 @@ TEST_F(DcSctpSocketTest, ExpectHeartbeatToBeSent) { // Feed it to Sock-z and expect a HEARTBEAT_ACK that will be propagated back. sock_z_->ReceivePacket(hb_packet_raw); sock_a_->ReceivePacket(cb_z_.ConsumeSentPacket()); + + MaybeHandoverSocketZAndSendMessage(); } -TEST_F(DcSctpSocketTest, CloseConnectionAfterTooManyLostHeartbeats) { +TEST_P(DcSctpSocketParametrizedTest, + CloseConnectionAfterTooManyLostHeartbeats) { ConnectSockets(); + MaybeHandoverSocketZ(); + EXPECT_CALL(cb_z_, OnClosed).Times(1); EXPECT_THAT(cb_a_.ConsumeSentPacket(), testing::IsEmpty()); // Force-close socket Z so that it doesn't interfere from now on. sock_z_->Close(); @@ -825,12 +909,16 @@ TEST_F(DcSctpSocketTest, CloseConnectionAfterTooManyLostHeartbeats) { // Should suffice as exceeding RTO AdvanceTime(DurationMs(1000)); RunTimers(); + + MaybeHandoverSocketZ(); } -TEST_F(DcSctpSocketTest, RecoversAfterASuccessfulAck) { +TEST_P(DcSctpSocketParametrizedTest, RecoversAfterASuccessfulAck) { ConnectSockets(); + MaybeHandoverSocketZ(); EXPECT_THAT(cb_a_.ConsumeSentPacket(), testing::IsEmpty()); + EXPECT_CALL(cb_z_, OnClosed).Times(1); // Force-close socket Z so that it doesn't interfere from now on. sock_z_->Close(); @@ -882,8 +970,9 @@ TEST_F(DcSctpSocketTest, RecoversAfterASuccessfulAck) { EXPECT_EQ(another_packet.descriptors()[0].type, HeartbeatRequestChunk::kType); } -TEST_F(DcSctpSocketTest, ResetStream) { +TEST_P(DcSctpSocketParametrizedTest, ResetStream) { ConnectSockets(); + MaybeHandoverSocketZ(); sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), {}); sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket()); @@ -906,10 +995,13 @@ TEST_F(DcSctpSocketTest, ResetStream) { // Receiving a response will trigger a callback. Streams are now reset. EXPECT_CALL(cb_a_, OnStreamsResetPerformed).Times(1); sock_a_->ReceivePacket(cb_z_.ConsumeSentPacket()); + + MaybeHandoverSocketZAndSendMessage(); } -TEST_F(DcSctpSocketTest, ResetStreamWillMakeChunksStartAtZeroSsn) { +TEST_P(DcSctpSocketParametrizedTest, ResetStreamWillMakeChunksStartAtZeroSsn) { ConnectSockets(); + MaybeHandoverSocketZ(); std::vector payload(options_.mtu - 100); @@ -956,10 +1048,14 @@ TEST_F(DcSctpSocketTest, ResetStreamWillMakeChunksStartAtZeroSsn) { // Handle SACK sock_a_->ReceivePacket(cb_z_.ConsumeSentPacket()); + + MaybeHandoverSocketZAndSendMessage(); } -TEST_F(DcSctpSocketTest, ResetStreamWillOnlyResetTheRequestedStreams) { +TEST_P(DcSctpSocketParametrizedTest, + ResetStreamWillOnlyResetTheRequestedStreams) { ConnectSockets(); + MaybeHandoverSocketZ(); std::vector payload(options_.mtu - 100); @@ -1034,10 +1130,13 @@ TEST_F(DcSctpSocketTest, ResetStreamWillOnlyResetTheRequestedStreams) { // Handle SACK sock_a_->ReceivePacket(cb_z_.ConsumeSentPacket()); + + MaybeHandoverSocketZAndSendMessage(); } -TEST_F(DcSctpSocketTest, OnePeerReconnects) { +TEST_P(DcSctpSocketParametrizedTest, OnePeerReconnects) { ConnectSockets(); + MaybeHandoverSocketZ(); EXPECT_CALL(cb_a_, OnConnectionRestarted).Times(1); // Let's be evil here - reconnect while a fragmented packet was about to be @@ -1064,8 +1163,9 @@ TEST_F(DcSctpSocketTest, OnePeerReconnects) { EXPECT_THAT(msg->payload(), testing::ElementsAreArray(payload)); } -TEST_F(DcSctpSocketTest, SendMessageWithLimitedRtx) { +TEST_P(DcSctpSocketParametrizedTest, SendMessageWithLimitedRtx) { ConnectSockets(); + MaybeHandoverSocketZ(); SendOptions send_options; send_options.max_retransmissions = 0; @@ -1117,10 +1217,13 @@ TEST_F(DcSctpSocketTest, SendMessageWithLimitedRtx) { absl::optional msg3 = cb_z_.ConsumeReceivedMessage(); EXPECT_FALSE(msg3.has_value()); + + MaybeHandoverSocketZAndSendMessage(); } -TEST_F(DcSctpSocketTest, SendManyFragmentedMessagesWithLimitedRtx) { +TEST_P(DcSctpSocketParametrizedTest, SendManyFragmentedMessagesWithLimitedRtx) { ConnectSockets(); + MaybeHandoverSocketZ(); SendOptions send_options; send_options.unordered = IsUnordered(true); @@ -1210,8 +1313,9 @@ class FakeChunk : public Chunk, public TLVTrait { std::string ToString() const override { return "FAKE"; } }; -TEST_F(DcSctpSocketTest, ReceivingUnknownChunkRespondsWithError) { +TEST_P(DcSctpSocketParametrizedTest, ReceivingUnknownChunkRespondsWithError) { ConnectSockets(); + MaybeHandoverSocketZ(); // Inject a FAKE chunk SctpPacket::Builder b(sock_a_->verification_tag(), DcSctpOptions()); @@ -1228,10 +1332,13 @@ TEST_F(DcSctpSocketTest, ReceivingUnknownChunkRespondsWithError) { UnrecognizedChunkTypeCause cause, error.error_causes().get()); EXPECT_THAT(cause.unrecognized_chunk(), ElementsAre(0x49, 0x00, 0x00, 0x04)); + + MaybeHandoverSocketZAndSendMessage(); } -TEST_F(DcSctpSocketTest, ReceivingErrorChunkReportsAsCallback) { +TEST_P(DcSctpSocketParametrizedTest, ReceivingErrorChunkReportsAsCallback) { ConnectSockets(); + MaybeHandoverSocketZ(); // Inject a ERROR chunk SctpPacket::Builder b(sock_a_->verification_tag(), DcSctpOptions()); @@ -1243,6 +1350,8 @@ TEST_F(DcSctpSocketTest, ReceivingErrorChunkReportsAsCallback) { EXPECT_CALL(cb_a_, OnError(ErrorKind::kPeerReported, HasSubstr("Unrecognized Chunk Type"))); sock_a_->ReceivePacket(b.Build()); + + MaybeHandoverSocketZAndSendMessage(); } TEST_F(DcSctpSocketTest, PassingHighWatermarkWillOnlyAcceptCumAckTsn) { @@ -1359,8 +1468,9 @@ TEST_F(DcSctpSocketTest, SetMaxMessageSize) { EXPECT_EQ(sock_a_->options().max_message_size, 42u); } -TEST_F(DcSctpSocketTest, SendsMessagesWithLowLifetime) { +TEST_P(DcSctpSocketParametrizedTest, SendsMessagesWithLowLifetime) { ConnectSockets(); + MaybeHandoverSocketZ(); // Mock that the time always goes forward. TimeMs now(0); @@ -1394,10 +1504,14 @@ TEST_F(DcSctpSocketTest, SendsMessagesWithLowLifetime) { // Validate that the sockets really make the time move forward. EXPECT_GE(*now, kIterations * 2); + + MaybeHandoverSocketZAndSendMessage(); } -TEST_F(DcSctpSocketTest, DiscardsMessagesWithLowLifetimeIfMustBuffer) { +TEST_P(DcSctpSocketParametrizedTest, + DiscardsMessagesWithLowLifetimeIfMustBuffer) { ConnectSockets(); + MaybeHandoverSocketZ(); SendOptions lifetime_0; lifetime_0.unordered = IsUnordered(true); @@ -1449,53 +1563,65 @@ TEST_F(DcSctpSocketTest, DiscardsMessagesWithLowLifetimeIfMustBuffer) { // But none of the smaller messages. EXPECT_FALSE(cb_z_.ConsumeReceivedMessage().has_value()); + + MaybeHandoverSocketZAndSendMessage(); } -TEST_F(DcSctpSocketTest, HasReasonableBufferedAmountValues) { +TEST_P(DcSctpSocketParametrizedTest, HasReasonableBufferedAmountValues) { ConnectSockets(); + MaybeHandoverSocketZ(); EXPECT_EQ(sock_a_->buffered_amount(StreamID(1)), 0u); sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), - std::vector(kSmallMessageSize)), - kSendOptions); + std::vector(kSmallMessageSize)), + kSendOptions); // Sending a small message will directly send it as a single packet, so // nothing is left in the queue. EXPECT_EQ(sock_a_->buffered_amount(StreamID(1)), 0u); sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), - std::vector(kLargeMessageSize)), - kSendOptions); + std::vector(kLargeMessageSize)), + kSendOptions); // Sending a message will directly start sending a few packets, so the // buffered amount is not the full message size. EXPECT_GT(sock_a_->buffered_amount(StreamID(1)), 0u); EXPECT_LT(sock_a_->buffered_amount(StreamID(1)), kLargeMessageSize); + + MaybeHandoverSocketZAndSendMessage(); } TEST_F(DcSctpSocketTest, HasDefaultOnBufferedAmountLowValueZero) { EXPECT_EQ(sock_a_->buffered_amount_low_threshold(StreamID(1)), 0u); } -TEST_F(DcSctpSocketTest, TriggersOnBufferedAmountLowWithDefaultValueZero) { +TEST_P(DcSctpSocketParametrizedTest, + TriggersOnBufferedAmountLowWithDefaultValueZero) { EXPECT_CALL(cb_a_, OnBufferedAmountLow).Times(0); ConnectSockets(); + MaybeHandoverSocketZ(); EXPECT_CALL(cb_a_, OnBufferedAmountLow(StreamID(1))); sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), - std::vector(kSmallMessageSize)), - kSendOptions); + std::vector(kSmallMessageSize)), + kSendOptions); ExchangeMessages(*sock_a_, cb_a_, *sock_z_, cb_z_); + + EXPECT_CALL(cb_a_, OnBufferedAmountLow).WillRepeatedly(testing::Return()); + MaybeHandoverSocketZAndSendMessage(); } -TEST_F(DcSctpSocketTest, DoesntTriggerOnBufferedAmountLowIfBelowThreshold) { +TEST_P(DcSctpSocketParametrizedTest, + DoesntTriggerOnBufferedAmountLowIfBelowThreshold) { static constexpr size_t kMessageSize = 1000; static constexpr size_t kBufferedAmountLowThreshold = kMessageSize * 10; sock_a_->SetBufferedAmountLowThreshold(StreamID(1), - kBufferedAmountLowThreshold); + kBufferedAmountLowThreshold); EXPECT_CALL(cb_a_, OnBufferedAmountLow).Times(0); ConnectSockets(); + MaybeHandoverSocketZ(); EXPECT_CALL(cb_a_, OnBufferedAmountLow(StreamID(1))).Times(0); sock_a_->Send( @@ -1507,16 +1633,19 @@ TEST_F(DcSctpSocketTest, DoesntTriggerOnBufferedAmountLowIfBelowThreshold) { DcSctpMessage(StreamID(1), PPID(53), std::vector(kMessageSize)), kSendOptions); ExchangeMessages(*sock_a_, cb_a_, *sock_z_, cb_z_); + + MaybeHandoverSocketZAndSendMessage(); } -TEST_F(DcSctpSocketTest, TriggersOnBufferedAmountMultipleTimes) { +TEST_P(DcSctpSocketParametrizedTest, TriggersOnBufferedAmountMultipleTimes) { static constexpr size_t kMessageSize = 1000; static constexpr size_t kBufferedAmountLowThreshold = kMessageSize / 2; sock_a_->SetBufferedAmountLowThreshold(StreamID(1), - kBufferedAmountLowThreshold); + kBufferedAmountLowThreshold); EXPECT_CALL(cb_a_, OnBufferedAmountLow).Times(0); ConnectSockets(); + MaybeHandoverSocketZ(); EXPECT_CALL(cb_a_, OnBufferedAmountLow(StreamID(1))).Times(3); EXPECT_CALL(cb_a_, OnBufferedAmountLow(StreamID(2))).Times(2); @@ -1544,16 +1673,20 @@ TEST_F(DcSctpSocketTest, TriggersOnBufferedAmountMultipleTimes) { DcSctpMessage(StreamID(1), PPID(53), std::vector(kMessageSize)), kSendOptions); ExchangeMessages(*sock_a_, cb_a_, *sock_z_, cb_z_); + + MaybeHandoverSocketZAndSendMessage(); } -TEST_F(DcSctpSocketTest, TriggersOnBufferedAmountLowOnlyWhenCrossingThreshold) { +TEST_P(DcSctpSocketParametrizedTest, + TriggersOnBufferedAmountLowOnlyWhenCrossingThreshold) { static constexpr size_t kMessageSize = 1000; static constexpr size_t kBufferedAmountLowThreshold = kMessageSize * 1.5; sock_a_->SetBufferedAmountLowThreshold(StreamID(1), - kBufferedAmountLowThreshold); + kBufferedAmountLowThreshold); EXPECT_CALL(cb_a_, OnBufferedAmountLow).Times(0); ConnectSockets(); + MaybeHandoverSocketZ(); EXPECT_CALL(cb_a_, OnBufferedAmountLow).Times(0); @@ -1561,8 +1694,8 @@ TEST_F(DcSctpSocketTest, TriggersOnBufferedAmountLowOnlyWhenCrossingThreshold) { // messages will start to be fully buffered. while (sock_a_->buffered_amount(StreamID(1)) <= kBufferedAmountLowThreshold) { sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), - std::vector(kMessageSize)), - kSendOptions); + std::vector(kMessageSize)), + kSendOptions); } size_t initial_buffered = sock_a_->buffered_amount(StreamID(1)); ASSERT_GT(initial_buffered, kBufferedAmountLowThreshold); @@ -1571,36 +1704,46 @@ TEST_F(DcSctpSocketTest, TriggersOnBufferedAmountLowOnlyWhenCrossingThreshold) { // callback. EXPECT_CALL(cb_a_, OnBufferedAmountLow(StreamID(1))).Times(1); ExchangeMessages(*sock_a_, cb_a_, *sock_z_, cb_z_); + + MaybeHandoverSocketZAndSendMessage(); } -TEST_F(DcSctpSocketTest, DoesntTriggerOnTotalBufferAmountLowWhenBelow) { +TEST_P(DcSctpSocketParametrizedTest, + DoesntTriggerOnTotalBufferAmountLowWhenBelow) { ConnectSockets(); + MaybeHandoverSocketZ(); EXPECT_CALL(cb_a_, OnTotalBufferedAmountLow).Times(0); sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), - std::vector(kLargeMessageSize)), - kSendOptions); + std::vector(kLargeMessageSize)), + kSendOptions); ExchangeMessages(*sock_a_, cb_a_, *sock_z_, cb_z_); + + MaybeHandoverSocketZAndSendMessage(); } -TEST_F(DcSctpSocketTest, TriggersOnTotalBufferAmountLowWhenCrossingThreshold) { +TEST_P(DcSctpSocketParametrizedTest, + TriggersOnTotalBufferAmountLowWhenCrossingThreshold) { ConnectSockets(); + MaybeHandoverSocketZ(); EXPECT_CALL(cb_a_, OnTotalBufferedAmountLow).Times(0); // Fill up the send queue completely. for (;;) { if (sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), - std::vector(kLargeMessageSize)), - kSendOptions) == SendStatus::kErrorResourceExhaustion) { + std::vector(kLargeMessageSize)), + kSendOptions) == SendStatus::kErrorResourceExhaustion) { break; } } EXPECT_CALL(cb_a_, OnTotalBufferedAmountLow).Times(1); ExchangeMessages(*sock_a_, cb_a_, *sock_z_, cb_z_); + + MaybeHandoverSocketZAndSendMessage(); } TEST_F(DcSctpSocketTest, InitialMetricsAreZeroed) { @@ -1650,8 +1793,8 @@ TEST_F(DcSctpSocketTest, RxAndTxPacketMetricsIncrease) { // Send one more (large - fragmented), and receive the delayed SACK. sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), - std::vector(options_.mtu * 2 + 1)), - kSendOptions); + std::vector(options_.mtu * 2 + 1)), + kSendOptions); EXPECT_EQ(sock_a_->GetMetrics().unack_data_count, 3u); sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket()); // DATA @@ -1683,12 +1826,13 @@ TEST_F(DcSctpSocketTest, RxAndTxPacketMetricsIncrease) { EXPECT_EQ(*sock_a_->GetMetrics().peer_rwnd_bytes, initial_a_rwnd); } -TEST_F(DcSctpSocketTest, UnackDataAlsoIncludesSendQueue) { +TEST_P(DcSctpSocketParametrizedTest, UnackDataAlsoIncludesSendQueue) { ConnectSockets(); + MaybeHandoverSocketZ(); sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), - std::vector(kLargeMessageSize)), - kSendOptions); + std::vector(kLargeMessageSize)), + kSendOptions); size_t payload_bytes = options_.mtu - SctpPacket::kHeaderSize - DataChunk::kHeaderSize; @@ -1706,14 +1850,17 @@ TEST_F(DcSctpSocketTest, UnackDataAlsoIncludesSendQueue) { EXPECT_LE(sock_a_->GetMetrics().unack_data_count, expected_sent_packets + expected_queued_packets + 2); + + MaybeHandoverSocketZAndSendMessage(); } -TEST_F(DcSctpSocketTest, DoesntSendMoreThanMaxBurstPackets) { +TEST_P(DcSctpSocketParametrizedTest, DoesntSendMoreThanMaxBurstPackets) { ConnectSockets(); + MaybeHandoverSocketZ(); sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), - std::vector(kLargeMessageSize)), - kSendOptions); + std::vector(kLargeMessageSize)), + kSendOptions); for (int i = 0; i < kMaxBurstPackets; ++i) { std::vector packet = cb_a_.ConsumeSentPacket(); @@ -1722,10 +1869,14 @@ TEST_F(DcSctpSocketTest, DoesntSendMoreThanMaxBurstPackets) { } EXPECT_THAT(cb_a_.ConsumeSentPacket(), IsEmpty()); + + ExchangeMessages(*sock_a_, cb_a_, *sock_z_, cb_z_); + MaybeHandoverSocketZAndSendMessage(); } -TEST_F(DcSctpSocketTest, SendsOnlyLargePackets) { +TEST_P(DcSctpSocketParametrizedTest, SendsOnlyLargePackets) { ConnectSockets(); + MaybeHandoverSocketZ(); // A really large message, to ensure that the congestion window is often full. constexpr size_t kMessageSize = 100000; @@ -1765,10 +1916,13 @@ TEST_F(DcSctpSocketTest, SendsOnlyLargePackets) { // The 4 is for padding/alignment. EXPECT_GE(size, options_.mtu - 4); } + + MaybeHandoverSocketZAndSendMessage(); } -TEST_F(DcSctpSocketTest, DoesntBundleForwardTsnWithData) { +TEST_P(DcSctpSocketParametrizedTest, DoesntBundleForwardTsnWithData) { ConnectSockets(); + MaybeHandoverSocketZ(); // Force an RTT measurement using heartbeats. AdvanceTime(options_.heartbeat_interval); @@ -1848,5 +2002,49 @@ TEST_F(DcSctpSocketTest, DoesntBundleForwardTsnWithData) { EXPECT_EQ(packet4.descriptors()[0].type, ForwardTsnChunk::kType); } +TEST_F(DcSctpSocketTest, SendMessagesAfterHandover) { + ConnectSockets(); + + // Send message before handover to move socket to a not initial state + sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions); + sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket()); + cb_z_.ConsumeReceivedMessage(); + + HandoverSocketZ(); + + absl::optional msg; + + RTC_LOG(LS_INFO) << "Sending A #1"; + + sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), {3, 4}), kSendOptions); + sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket()); + + msg = cb_z_.ConsumeReceivedMessage(); + ASSERT_TRUE(msg.has_value()); + EXPECT_EQ(msg->stream_id(), StreamID(1)); + EXPECT_THAT(msg->payload(), testing::ElementsAre(3, 4)); + + RTC_LOG(LS_INFO) << "Sending A #2"; + + sock_a_->Send(DcSctpMessage(StreamID(2), PPID(53), {5, 6}), kSendOptions); + sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket()); + + msg = cb_z_.ConsumeReceivedMessage(); + ASSERT_TRUE(msg.has_value()); + EXPECT_EQ(msg->stream_id(), StreamID(2)); + EXPECT_THAT(msg->payload(), testing::ElementsAre(5, 6)); + + RTC_LOG(LS_INFO) << "Sending Z #1"; + + sock_z_->Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2, 3}), kSendOptions); + sock_a_->ReceivePacket(cb_z_.ConsumeSentPacket()); // ack + sock_a_->ReceivePacket(cb_z_.ConsumeSentPacket()); // data + + msg = cb_a_.ConsumeReceivedMessage(); + ASSERT_TRUE(msg.has_value()); + EXPECT_EQ(msg->stream_id(), StreamID(1)); + EXPECT_THAT(msg->payload(), testing::ElementsAre(1, 2, 3)); +} + } // namespace } // namespace dcsctp diff --git a/net/dcsctp/socket/mock_dcsctp_socket_callbacks.h b/net/dcsctp/socket/mock_dcsctp_socket_callbacks.h index 894dd9ac5a..a49a0b3325 100644 --- a/net/dcsctp/socket/mock_dcsctp_socket_callbacks.h +++ b/net/dcsctp/socket/mock_dcsctp_socket_callbacks.h @@ -150,6 +150,12 @@ class MockDcSctpSocketCallbacks : public DcSctpSocketCallbacks { return timeout_manager_.GetNextExpiredTimeout(); } + void Reset() { + sent_packets_.clear(); + received_messages_.clear(); + timeout_manager_.Reset(); + } + private: const std::string log_prefix_; TimeMs now_ = TimeMs(0); diff --git a/net/dcsctp/socket/transmission_control_block.cc b/net/dcsctp/socket/transmission_control_block.cc index f0f1ab9782..2e4e968737 100644 --- a/net/dcsctp/socket/transmission_control_block.cc +++ b/net/dcsctp/socket/transmission_control_block.cc @@ -183,4 +183,30 @@ std::string TransmissionControlBlock::ToString() const { return sb.Release(); } +HandoverReadinessStatus TransmissionControlBlock::GetHandoverReadiness() const { + HandoverReadinessStatus status; + status.Add(data_tracker_.GetHandoverReadiness()); + status.Add(stream_reset_handler_.GetHandoverReadiness()); + status.Add(reassembly_queue_.GetHandoverReadiness()); + status.Add(retransmission_queue_.GetHandoverReadiness()); + return status; +} + +void TransmissionControlBlock::AddHandoverState( + DcSctpSocketHandoverState& state) { + state.capabilities.partial_reliability = capabilities_.partial_reliability; + state.capabilities.message_interleaving = capabilities_.message_interleaving; + state.capabilities.reconfig = capabilities_.reconfig; + + state.my_verification_tag = my_verification_tag().value(); + state.peer_verification_tag = peer_verification_tag().value(); + state.my_initial_tsn = my_initial_tsn().value(); + state.peer_initial_tsn = peer_initial_tsn().value(); + state.tie_tag = tie_tag().value(); + + data_tracker_.AddHandoverState(state); + stream_reset_handler_.AddHandoverState(state); + reassembly_queue_.AddHandoverState(state); + retransmission_queue_.AddHandoverState(state); +} } // namespace dcsctp diff --git a/net/dcsctp/socket/transmission_control_block.h b/net/dcsctp/socket/transmission_control_block.h index c3766d1546..6d9dfc5e70 100644 --- a/net/dcsctp/socket/transmission_control_block.h +++ b/net/dcsctp/socket/transmission_control_block.h @@ -44,20 +44,22 @@ namespace dcsctp { // closed or restarted, this object will be deleted and/or replaced. class TransmissionControlBlock : public Context { public: - TransmissionControlBlock(TimerManager& timer_manager, - absl::string_view log_prefix, - const DcSctpOptions& options, - const Capabilities& capabilities, - DcSctpSocketCallbacks& callbacks, - SendQueue& send_queue, - VerificationTag my_verification_tag, - TSN my_initial_tsn, - VerificationTag peer_verification_tag, - TSN peer_initial_tsn, - size_t a_rwnd, - TieTag tie_tag, - PacketSender& packet_sender, - std::function is_connection_established) + TransmissionControlBlock( + TimerManager& timer_manager, + absl::string_view log_prefix, + const DcSctpOptions& options, + const Capabilities& capabilities, + DcSctpSocketCallbacks& callbacks, + SendQueue& send_queue, + VerificationTag my_verification_tag, + TSN my_initial_tsn, + VerificationTag peer_verification_tag, + TSN peer_initial_tsn, + size_t a_rwnd, + TieTag tie_tag, + PacketSender& packet_sender, + std::function is_connection_established, + const DcSctpSocketHandoverState* handover_state = nullptr) : log_prefix_(log_prefix), options_(options), timer_manager_(timer_manager), @@ -86,10 +88,14 @@ class TransmissionControlBlock : public Context { packet_sender_(packet_sender), rto_(options), tx_error_counter_(log_prefix, options), - data_tracker_(log_prefix, delayed_ack_timer_.get(), peer_initial_tsn), + data_tracker_(log_prefix, + delayed_ack_timer_.get(), + peer_initial_tsn, + handover_state), reassembly_queue_(log_prefix, peer_initial_tsn, - options.max_receiver_window_buffer_size), + options.max_receiver_window_buffer_size, + handover_state), retransmission_queue_( log_prefix, my_initial_tsn, @@ -100,13 +106,15 @@ class TransmissionControlBlock : public Context { *t3_rtx_, options, capabilities.partial_reliability, - capabilities.message_interleaving), + capabilities.message_interleaving, + handover_state), stream_reset_handler_(log_prefix, this, &timer_manager, &data_tracker_, &reassembly_queue_, - &retransmission_queue_), + &retransmission_queue_, + handover_state), heartbeat_handler_(log_prefix, options, this, &timer_manager_) {} // Implementation of `Context`. @@ -188,6 +196,10 @@ class TransmissionControlBlock : public Context { // Returns a textual representation of this object, for logging. std::string ToString() const; + HandoverReadinessStatus GetHandoverReadiness() const; + + void AddHandoverState(DcSctpSocketHandoverState& state); + private: // Will be called when the retransmission timer (t3-rtx) expires. absl::optional OnRtxTimerExpiry(); diff --git a/net/dcsctp/timer/fake_timeout.h b/net/dcsctp/timer/fake_timeout.h index f2bf10325e..e8f50d93cb 100644 --- a/net/dcsctp/timer/fake_timeout.h +++ b/net/dcsctp/timer/fake_timeout.h @@ -91,6 +91,8 @@ class FakeTimeoutManager { return absl::nullopt; } + void Reset() { timers_.clear(); } + private: const std::function get_time_; webrtc::flat_set timers_;