From eee0e336a2f3de58c286e308e0705265f9a8ba0c Mon Sep 17 00:00:00 2001 From: Victor Boivie Date: Fri, 4 Mar 2022 20:11:44 +0100 Subject: [PATCH] dcsctp: Convert socket tests not to use fixtures Following https://abseil.io/tips/122 to make tests easier to understand and adds a bit of flexibility to create sockets with custom parameters. This also simplifies handover tests. Additionally, AdvanceTime will now also run timers, as that was easily forgotten previously. Bug: None Change-Id: Ieb5eece7aca51c98a7634ed1c61646383ad1712d Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/253782 Reviewed-by: Sergey Sukhanov Commit-Queue: Victor Boivie Cr-Commit-Position: refs/heads/main@{#36141} --- net/dcsctp/socket/dcsctp_socket_test.cc | 1678 +++++++++-------- .../socket/mock_dcsctp_socket_callbacks.h | 6 - net/dcsctp/timer/fake_timeout.h | 2 - 3 files changed, 891 insertions(+), 795 deletions(-) diff --git a/net/dcsctp/socket/dcsctp_socket_test.cc b/net/dcsctp/socket/dcsctp_socket_test.cc index 66876e4e25..f45773baba 100644 --- a/net/dcsctp/socket/dcsctp_socket_test.cc +++ b/net/dcsctp/socket/dcsctp_socket_test.cc @@ -233,13 +233,12 @@ TSN AddTo(TSN tsn, int delta) { return TSN(*tsn + delta); } -DcSctpOptions MakeOptionsForTest(bool enable_message_interleaving) { - DcSctpOptions options; +DcSctpOptions FixupOptions(DcSctpOptions options = {}) { + DcSctpOptions fixup = options; // To make the interval more predictable in tests. - options.heartbeat_interval_include_rtt = false; - options.enable_message_interleaving = enable_message_interleaving; - options.max_burst = kMaxBurstPackets; - return options; + fixup.heartbeat_interval_include_rtt = false; + fixup.max_burst = kMaxBurstPackets; + return fixup; } std::unique_ptr GetPacketObserver(absl::string_view name) { @@ -249,103 +248,90 @@ std::unique_ptr GetPacketObserver(absl::string_view name) { return nullptr; } -class DcSctpSocketTest : public testing::Test { - protected: - explicit DcSctpSocketTest(bool enable_message_interleaving = false) - : options_(MakeOptionsForTest(enable_message_interleaving)), - cb_a_("A"), - cb_z_("Z"), - sock_a_(std::make_unique("A", - cb_a_, - GetPacketObserver("A"), - options_)), - sock_z_(std::make_unique("Z", - cb_z_, - GetPacketObserver("Z"), - options_)) {} +struct SocketUnderTest { + explicit SocketUnderTest(absl::string_view name, + const DcSctpOptions& opts = {}) + : options(FixupOptions(opts)), + cb(name), + socket(name, cb, GetPacketObserver(name), options) {} - void AdvanceTime(DurationMs duration) { - cb_a_.AdvanceTime(duration); - cb_z_.AdvanceTime(duration); - } - - static void ExchangeMessages(DcSctpSocket& sock_a, - MockDcSctpSocketCallbacks& cb_a, - DcSctpSocket& sock_z, - MockDcSctpSocketCallbacks& cb_z) { - bool delivered_packet = false; - do { - delivered_packet = false; - std::vector packet_from_a = cb_a.ConsumeSentPacket(); - if (!packet_from_a.empty()) { - delivered_packet = true; - sock_z.ReceivePacket(std::move(packet_from_a)); - } - std::vector packet_from_z = cb_z.ConsumeSentPacket(); - if (!packet_from_z.empty()) { - delivered_packet = true; - sock_a.ReceivePacket(std::move(packet_from_z)); - } - } while (delivered_packet); - } - - void RunTimers(MockDcSctpSocketCallbacks& cb, DcSctpSocket& socket) { - for (;;) { - absl::optional timeout_id = cb.GetNextExpiredTimeout(); - if (!timeout_id.has_value()) { - break; - } - socket.HandleTimeout(*timeout_id); - } - } - - void RunTimers() { - RunTimers(cb_a_, *sock_a_); - RunTimers(cb_z_, *sock_z_); - } - - // Calls Connect() on `sock_a_` and make the connection established. - void ConnectSockets() { - EXPECT_CALL(cb_a_, OnConnected).Times(1); - EXPECT_CALL(cb_z_, OnConnected).Times(1); - - sock_a_->Connect(); - // Z reads INIT, INIT_ACK, COOKIE_ECHO, COOKIE_ACK - sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket()); - sock_a_->ReceivePacket(cb_z_.ConsumeSentPacket()); - sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket()); - sock_a_->ReceivePacket(cb_z_.ConsumeSentPacket()); - - EXPECT_EQ(sock_a_->state(), SocketState::kConnected); - EXPECT_EQ(sock_z_->state(), SocketState::kConnected); - } - - void HandoverSocketZ() { - ASSERT_EQ(sock_z_->GetHandoverReadiness(), HandoverReadinessStatus()); - bool is_closed = sock_z_->state() == SocketState::kClosed; - if (!is_closed) { - EXPECT_CALL(cb_z_, OnClosed).Times(1); - } - absl::optional handover_state = - sock_z_->GetHandoverStateAndClose(); - EXPECT_TRUE(handover_state.has_value()); - g_handover_state_transformer_for_test(&*handover_state); - cb_z_.Reset(); - sock_z_ = std::make_unique("Z", cb_z_, GetPacketObserver("Z"), - options_); - if (!is_closed) { - EXPECT_CALL(cb_z_, OnConnected).Times(1); - } - sock_z_->RestoreFromState(*handover_state); - } - - const DcSctpOptions options_; - testing::NiceMock cb_a_; - testing::NiceMock cb_z_; - std::unique_ptr sock_a_; - std::unique_ptr sock_z_; + const DcSctpOptions options; + testing::NiceMock cb; + DcSctpSocket socket; }; +void ExchangeMessages(SocketUnderTest& a, SocketUnderTest& z) { + bool delivered_packet = false; + do { + delivered_packet = false; + std::vector packet_from_a = a.cb.ConsumeSentPacket(); + if (!packet_from_a.empty()) { + delivered_packet = true; + z.socket.ReceivePacket(std::move(packet_from_a)); + } + std::vector packet_from_z = z.cb.ConsumeSentPacket(); + if (!packet_from_z.empty()) { + delivered_packet = true; + a.socket.ReceivePacket(std::move(packet_from_z)); + } + } while (delivered_packet); +} + +void RunTimers(SocketUnderTest& s) { + for (;;) { + absl::optional timeout_id = s.cb.GetNextExpiredTimeout(); + if (!timeout_id.has_value()) { + break; + } + s.socket.HandleTimeout(*timeout_id); + } +} + +void AdvanceTime(SocketUnderTest& a, SocketUnderTest& z, DurationMs duration) { + a.cb.AdvanceTime(duration); + z.cb.AdvanceTime(duration); + + RunTimers(a); + RunTimers(z); +} + +// Calls Connect() on `sock_a_` and make the connection established. +void ConnectSockets(SocketUnderTest& a, SocketUnderTest& z) { + EXPECT_CALL(a.cb, OnConnected).Times(1); + EXPECT_CALL(z.cb, OnConnected).Times(1); + + a.socket.Connect(); + // Z reads INIT, INIT_ACK, COOKIE_ECHO, COOKIE_ACK + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); + + EXPECT_EQ(a.socket.state(), SocketState::kConnected); + EXPECT_EQ(z.socket.state(), SocketState::kConnected); +} + +std::unique_ptr HandoverSocket( + std::unique_ptr sut) { + EXPECT_EQ(sut->socket.GetHandoverReadiness(), HandoverReadinessStatus()); + + bool is_closed = sut->socket.state() == SocketState::kClosed; + if (!is_closed) { + EXPECT_CALL(sut->cb, OnClosed).Times(1); + } + absl::optional handover_state = + sut->socket.GetHandoverStateAndClose(); + EXPECT_TRUE(handover_state.has_value()); + g_handover_state_transformer_for_test(&*handover_state); + + auto handover_socket = std::make_unique("H", sut->options); + if (!is_closed) { + EXPECT_CALL(handover_socket->cb, OnConnected).Times(1); + } + handover_socket->socket.RestoreFromState(*handover_state); + return handover_socket; +} + // Test parameter that controls whether to perform handovers during the test. A // test can have multiple points where it conditionally hands over socket Z. // Either socket Z will be handed over at all those points or handed over never. @@ -355,27 +341,31 @@ enum class HandoverMode { }; class DcSctpSocketParametrizedTest - : public DcSctpSocketTest, + : public ::testing::Test, public ::testing::WithParamInterface { protected: - // Trigger handover for socket Z depending on the current test param. - void MaybeHandoverSocketZ() { + // Trigger handover for `sut` depending on the current test param. + std::unique_ptr MaybeHandoverSocket( + std::unique_ptr sut) { if (GetParam() == HandoverMode::kPerformHandovers) { - HandoverSocketZ(); + return HandoverSocket(std::move(sut)); } + return sut; } + // Trigger handover for socket Z depending on the current test param. // Then checks message passing to verify the handed over socket is functional. - void MaybeHandoverSocketZAndSendMessage() { + void MaybeHandoverSocketAndSendMessage(SocketUnderTest& a, + std::unique_ptr z) { if (GetParam() == HandoverMode::kPerformHandovers) { - HandoverSocketZ(); + z = HandoverSocket(std::move(z)); } - ExchangeMessages(*sock_a_, cb_a_, *sock_z_, cb_z_); - sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions); - ExchangeMessages(*sock_a_, cb_a_, *sock_z_, cb_z_); + ExchangeMessages(a, *z); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions); + ExchangeMessages(a, *z); - absl::optional msg = cb_z_.ConsumeReceivedMessage(); + absl::optional msg = z->cb.ConsumeReceivedMessage(); ASSERT_TRUE(msg.has_value()); EXPECT_EQ(msg->stream_id(), StreamID(1)); } @@ -392,451 +382,492 @@ INSTANTIATE_TEST_SUITE_P(Handovers, : "NoHandover"; }); -TEST_F(DcSctpSocketTest, EstablishConnection) { - EXPECT_CALL(cb_a_, OnConnected).Times(1); - EXPECT_CALL(cb_z_, OnConnected).Times(1); - EXPECT_CALL(cb_a_, OnConnectionRestarted).Times(0); - EXPECT_CALL(cb_z_, OnConnectionRestarted).Times(0); +TEST(DcSctpSocketTest, EstablishConnection) { + SocketUnderTest a("A"); + SocketUnderTest z("Z"); - sock_a_->Connect(); + EXPECT_CALL(a.cb, OnConnected).Times(1); + EXPECT_CALL(z.cb, OnConnected).Times(1); + EXPECT_CALL(a.cb, OnConnectionRestarted).Times(0); + EXPECT_CALL(z.cb, OnConnectionRestarted).Times(0); + + a.socket.Connect(); // Z reads INIT, produces INIT_ACK - sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket()); + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); // A reads INIT_ACK, produces COOKIE_ECHO - sock_a_->ReceivePacket(cb_z_.ConsumeSentPacket()); + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); // Z reads COOKIE_ECHO, produces COOKIE_ACK - sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket()); + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); // A reads COOKIE_ACK. - sock_a_->ReceivePacket(cb_z_.ConsumeSentPacket()); + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); - EXPECT_EQ(sock_a_->state(), SocketState::kConnected); - EXPECT_EQ(sock_z_->state(), SocketState::kConnected); + EXPECT_EQ(a.socket.state(), SocketState::kConnected); + EXPECT_EQ(z.socket.state(), SocketState::kConnected); } -TEST_F(DcSctpSocketTest, EstablishConnectionWithSetupCollision) { - EXPECT_CALL(cb_a_, OnConnected).Times(1); - EXPECT_CALL(cb_z_, OnConnected).Times(1); - EXPECT_CALL(cb_a_, OnConnectionRestarted).Times(0); - EXPECT_CALL(cb_z_, OnConnectionRestarted).Times(0); - sock_a_->Connect(); - sock_z_->Connect(); +TEST(DcSctpSocketTest, EstablishConnectionWithSetupCollision) { + SocketUnderTest a("A"); + SocketUnderTest z("Z"); - ExchangeMessages(*sock_a_, cb_a_, *sock_z_, cb_z_); + EXPECT_CALL(a.cb, OnConnected).Times(1); + EXPECT_CALL(z.cb, OnConnected).Times(1); + EXPECT_CALL(a.cb, OnConnectionRestarted).Times(0); + EXPECT_CALL(z.cb, OnConnectionRestarted).Times(0); + a.socket.Connect(); + z.socket.Connect(); - EXPECT_EQ(sock_a_->state(), SocketState::kConnected); - EXPECT_EQ(sock_z_->state(), SocketState::kConnected); + ExchangeMessages(a, z); + + EXPECT_EQ(a.socket.state(), SocketState::kConnected); + EXPECT_EQ(z.socket.state(), SocketState::kConnected); } -TEST_F(DcSctpSocketTest, ShuttingDownWhileEstablishingConnection) { - EXPECT_CALL(cb_a_, OnConnected).Times(0); - EXPECT_CALL(cb_z_, OnConnected).Times(1); - sock_a_->Connect(); +TEST(DcSctpSocketTest, ShuttingDownWhileEstablishingConnection) { + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + + EXPECT_CALL(a.cb, OnConnected).Times(0); + EXPECT_CALL(z.cb, OnConnected).Times(1); + a.socket.Connect(); // Z reads INIT, produces INIT_ACK - sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket()); + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); // A reads INIT_ACK, produces COOKIE_ECHO - sock_a_->ReceivePacket(cb_z_.ConsumeSentPacket()); + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); // Z reads COOKIE_ECHO, produces COOKIE_ACK - sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket()); + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); // Drop COOKIE_ACK, just to more easily verify shutdown protocol. - cb_z_.ConsumeSentPacket(); + z.cb.ConsumeSentPacket(); // As Socket A has received INIT_ACK, it has a TCB and is connected, while // Socket Z needs to receive COOKIE_ECHO to get there. Socket A still has // timers running at this point. - EXPECT_EQ(sock_a_->state(), SocketState::kConnecting); - EXPECT_EQ(sock_z_->state(), SocketState::kConnected); + EXPECT_EQ(a.socket.state(), SocketState::kConnecting); + EXPECT_EQ(z.socket.state(), SocketState::kConnected); // Socket A is now shut down, which should make it stop those timers. - sock_a_->Shutdown(); + a.socket.Shutdown(); - EXPECT_CALL(cb_a_, OnClosed).Times(1); - EXPECT_CALL(cb_z_, OnClosed).Times(1); + EXPECT_CALL(a.cb, OnClosed).Times(1); + EXPECT_CALL(z.cb, OnClosed).Times(1); // Z reads SHUTDOWN, produces SHUTDOWN_ACK - sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket()); + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); // A reads SHUTDOWN_ACK, produces SHUTDOWN_COMPLETE - sock_a_->ReceivePacket(cb_z_.ConsumeSentPacket()); + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); // Z reads SHUTDOWN_COMPLETE. - sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket()); + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); - EXPECT_TRUE(cb_a_.ConsumeSentPacket().empty()); - EXPECT_TRUE(cb_z_.ConsumeSentPacket().empty()); + EXPECT_TRUE(a.cb.ConsumeSentPacket().empty()); + EXPECT_TRUE(z.cb.ConsumeSentPacket().empty()); - EXPECT_EQ(sock_a_->state(), SocketState::kClosed); - EXPECT_EQ(sock_z_->state(), SocketState::kClosed); + EXPECT_EQ(a.socket.state(), SocketState::kClosed); + EXPECT_EQ(z.socket.state(), SocketState::kClosed); } -TEST_F(DcSctpSocketTest, EstablishSimultaneousConnection) { - EXPECT_CALL(cb_a_, OnConnected).Times(1); - EXPECT_CALL(cb_z_, OnConnected).Times(1); - EXPECT_CALL(cb_a_, OnConnectionRestarted).Times(0); - EXPECT_CALL(cb_z_, OnConnectionRestarted).Times(0); - sock_a_->Connect(); +TEST(DcSctpSocketTest, EstablishSimultaneousConnection) { + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + + EXPECT_CALL(a.cb, OnConnected).Times(1); + EXPECT_CALL(z.cb, OnConnected).Times(1); + EXPECT_CALL(a.cb, OnConnectionRestarted).Times(0); + EXPECT_CALL(z.cb, OnConnectionRestarted).Times(0); + a.socket.Connect(); // INIT isn't received by Z, as it wasn't ready yet. - cb_a_.ConsumeSentPacket(); + a.cb.ConsumeSentPacket(); - sock_z_->Connect(); + z.socket.Connect(); // A reads INIT, produces INIT_ACK - sock_a_->ReceivePacket(cb_z_.ConsumeSentPacket()); + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); // Z reads INIT_ACK, sends COOKIE_ECHO - sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket()); + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); // A reads COOKIE_ECHO - establishes connection. - sock_a_->ReceivePacket(cb_z_.ConsumeSentPacket()); + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); - EXPECT_EQ(sock_a_->state(), SocketState::kConnected); + EXPECT_EQ(a.socket.state(), SocketState::kConnected); // Proceed with the remaining packets. - ExchangeMessages(*sock_a_, cb_a_, *sock_z_, cb_z_); + ExchangeMessages(a, z); - EXPECT_EQ(sock_a_->state(), SocketState::kConnected); - EXPECT_EQ(sock_z_->state(), SocketState::kConnected); + EXPECT_EQ(a.socket.state(), SocketState::kConnected); + EXPECT_EQ(z.socket.state(), SocketState::kConnected); } -TEST_F(DcSctpSocketTest, EstablishConnectionLostCookieAck) { - EXPECT_CALL(cb_a_, OnConnected).Times(1); - EXPECT_CALL(cb_z_, OnConnected).Times(1); - EXPECT_CALL(cb_a_, OnConnectionRestarted).Times(0); - EXPECT_CALL(cb_z_, OnConnectionRestarted).Times(0); +TEST(DcSctpSocketTest, EstablishConnectionLostCookieAck) { + SocketUnderTest a("A"); + SocketUnderTest z("Z"); - sock_a_->Connect(); + EXPECT_CALL(a.cb, OnConnected).Times(1); + EXPECT_CALL(z.cb, OnConnected).Times(1); + EXPECT_CALL(a.cb, OnConnectionRestarted).Times(0); + EXPECT_CALL(z.cb, OnConnectionRestarted).Times(0); + + a.socket.Connect(); // Z reads INIT, produces INIT_ACK - sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket()); + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); // A reads INIT_ACK, produces COOKIE_ECHO - sock_a_->ReceivePacket(cb_z_.ConsumeSentPacket()); + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); // Z reads COOKIE_ECHO, produces COOKIE_ACK - sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket()); + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); // COOKIE_ACK is lost. - cb_z_.ConsumeSentPacket(); + z.cb.ConsumeSentPacket(); - EXPECT_EQ(sock_a_->state(), SocketState::kConnecting); - EXPECT_EQ(sock_z_->state(), SocketState::kConnected); + EXPECT_EQ(a.socket.state(), SocketState::kConnecting); + EXPECT_EQ(z.socket.state(), SocketState::kConnected); // This will make A re-send the COOKIE_ECHO - AdvanceTime(DurationMs(options_.t1_cookie_timeout)); - RunTimers(); + AdvanceTime(a, z, DurationMs(a.options.t1_cookie_timeout)); // Z reads COOKIE_ECHO, produces COOKIE_ACK - sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket()); + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); // A reads COOKIE_ACK. - sock_a_->ReceivePacket(cb_z_.ConsumeSentPacket()); + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); - EXPECT_EQ(sock_a_->state(), SocketState::kConnected); - EXPECT_EQ(sock_z_->state(), SocketState::kConnected); + EXPECT_EQ(a.socket.state(), SocketState::kConnected); + EXPECT_EQ(z.socket.state(), SocketState::kConnected); } -TEST_F(DcSctpSocketTest, ResendInitAndEstablishConnection) { - sock_a_->Connect(); +TEST(DcSctpSocketTest, ResendInitAndEstablishConnection) { + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + + a.socket.Connect(); // INIT is never received by Z. ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket init_packet, - SctpPacket::Parse(cb_a_.ConsumeSentPacket())); + SctpPacket::Parse(a.cb.ConsumeSentPacket())); EXPECT_EQ(init_packet.descriptors()[0].type, InitChunk::kType); - AdvanceTime(options_.t1_init_timeout); - RunTimers(); + AdvanceTime(a, z, a.options.t1_init_timeout); // Z reads INIT, produces INIT_ACK - sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket()); + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); // A reads INIT_ACK, produces COOKIE_ECHO - sock_a_->ReceivePacket(cb_z_.ConsumeSentPacket()); + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); // Z reads COOKIE_ECHO, produces COOKIE_ACK - sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket()); + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); // A reads COOKIE_ACK. - sock_a_->ReceivePacket(cb_z_.ConsumeSentPacket()); + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); - EXPECT_EQ(sock_a_->state(), SocketState::kConnected); - EXPECT_EQ(sock_z_->state(), SocketState::kConnected); + EXPECT_EQ(a.socket.state(), SocketState::kConnected); + EXPECT_EQ(z.socket.state(), SocketState::kConnected); } -TEST_F(DcSctpSocketTest, ResendingInitTooManyTimesAborts) { - sock_a_->Connect(); +TEST(DcSctpSocketTest, ResendingInitTooManyTimesAborts) { + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + + a.socket.Connect(); // INIT is never received by Z. ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket init_packet, - SctpPacket::Parse(cb_a_.ConsumeSentPacket())); + SctpPacket::Parse(a.cb.ConsumeSentPacket())); EXPECT_EQ(init_packet.descriptors()[0].type, InitChunk::kType); - for (int i = 0; i < *options_.max_init_retransmits; ++i) { - AdvanceTime(options_.t1_init_timeout * (1 << i)); - RunTimers(); + for (int i = 0; i < *a.options.max_init_retransmits; ++i) { + AdvanceTime(a, z, a.options.t1_init_timeout * (1 << i)); // INIT is resent ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket resent_init_packet, - SctpPacket::Parse(cb_a_.ConsumeSentPacket())); + SctpPacket::Parse(a.cb.ConsumeSentPacket())); EXPECT_EQ(resent_init_packet.descriptors()[0].type, InitChunk::kType); } // Another timeout, after the max init retransmits. - AdvanceTime(options_.t1_init_timeout * (1 << *options_.max_init_retransmits)); - EXPECT_CALL(cb_a_, OnAborted).Times(1); - RunTimers(); + EXPECT_CALL(a.cb, OnAborted).Times(1); + AdvanceTime( + a, z, a.options.t1_init_timeout * (1 << *a.options.max_init_retransmits)); - EXPECT_EQ(sock_a_->state(), SocketState::kClosed); + EXPECT_EQ(a.socket.state(), SocketState::kClosed); } -TEST_F(DcSctpSocketTest, ResendCookieEchoAndEstablishConnection) { - sock_a_->Connect(); +TEST(DcSctpSocketTest, ResendCookieEchoAndEstablishConnection) { + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + + a.socket.Connect(); // Z reads INIT, produces INIT_ACK - sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket()); + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); // A reads INIT_ACK, produces COOKIE_ECHO - sock_a_->ReceivePacket(cb_z_.ConsumeSentPacket()); + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); // COOKIE_ECHO is never received by Z. ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket init_packet, - SctpPacket::Parse(cb_a_.ConsumeSentPacket())); + SctpPacket::Parse(a.cb.ConsumeSentPacket())); EXPECT_EQ(init_packet.descriptors()[0].type, CookieEchoChunk::kType); - AdvanceTime(options_.t1_init_timeout); - RunTimers(); + AdvanceTime(a, z, a.options.t1_init_timeout); // Z reads COOKIE_ECHO, produces COOKIE_ACK - sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket()); + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); // A reads COOKIE_ACK. - sock_a_->ReceivePacket(cb_z_.ConsumeSentPacket()); + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); - EXPECT_EQ(sock_a_->state(), SocketState::kConnected); - EXPECT_EQ(sock_z_->state(), SocketState::kConnected); + EXPECT_EQ(a.socket.state(), SocketState::kConnected); + EXPECT_EQ(z.socket.state(), SocketState::kConnected); } -TEST_F(DcSctpSocketTest, ResendingCookieEchoTooManyTimesAborts) { - sock_a_->Connect(); +TEST(DcSctpSocketTest, ResendingCookieEchoTooManyTimesAborts) { + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + + a.socket.Connect(); // Z reads INIT, produces INIT_ACK - sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket()); + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); // A reads INIT_ACK, produces COOKIE_ECHO - sock_a_->ReceivePacket(cb_z_.ConsumeSentPacket()); + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); // COOKIE_ECHO is never received by Z. ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket init_packet, - SctpPacket::Parse(cb_a_.ConsumeSentPacket())); + SctpPacket::Parse(a.cb.ConsumeSentPacket())); EXPECT_EQ(init_packet.descriptors()[0].type, CookieEchoChunk::kType); - for (int i = 0; i < *options_.max_init_retransmits; ++i) { - AdvanceTime(options_.t1_cookie_timeout * (1 << i)); - RunTimers(); + for (int i = 0; i < *a.options.max_init_retransmits; ++i) { + AdvanceTime(a, z, a.options.t1_cookie_timeout * (1 << i)); // COOKIE_ECHO is resent ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket resent_init_packet, - SctpPacket::Parse(cb_a_.ConsumeSentPacket())); + SctpPacket::Parse(a.cb.ConsumeSentPacket())); EXPECT_EQ(resent_init_packet.descriptors()[0].type, CookieEchoChunk::kType); } // Another timeout, after the max init retransmits. - AdvanceTime(options_.t1_cookie_timeout * - (1 << *options_.max_init_retransmits)); - EXPECT_CALL(cb_a_, OnAborted).Times(1); - RunTimers(); + EXPECT_CALL(a.cb, OnAborted).Times(1); + AdvanceTime( + a, z, + a.options.t1_cookie_timeout * (1 << *a.options.max_init_retransmits)); - EXPECT_EQ(sock_a_->state(), SocketState::kClosed); + EXPECT_EQ(a.socket.state(), SocketState::kClosed); } -TEST_F(DcSctpSocketTest, DoesntSendMorePacketsUntilCookieAckHasBeenReceived) { - sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), +TEST(DcSctpSocketTest, DoesntSendMorePacketsUntilCookieAckHasBeenReceived) { + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), std::vector(kLargeMessageSize)), kSendOptions); - sock_a_->Connect(); + a.socket.Connect(); // Z reads INIT, produces INIT_ACK - sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket()); + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); // A reads INIT_ACK, produces COOKIE_ECHO - sock_a_->ReceivePacket(cb_z_.ConsumeSentPacket()); + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); // COOKIE_ECHO is never received by Z. ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket cookie_echo_packet1, - SctpPacket::Parse(cb_a_.ConsumeSentPacket())); + SctpPacket::Parse(a.cb.ConsumeSentPacket())); EXPECT_THAT(cookie_echo_packet1.descriptors(), SizeIs(2)); EXPECT_EQ(cookie_echo_packet1.descriptors()[0].type, CookieEchoChunk::kType); EXPECT_EQ(cookie_echo_packet1.descriptors()[1].type, DataChunk::kType); - EXPECT_THAT(cb_a_.ConsumeSentPacket(), IsEmpty()); + EXPECT_THAT(a.cb.ConsumeSentPacket(), IsEmpty()); // There are DATA chunks in the sent packet (that was lost), which means that // the T3-RTX timer is running, but as the socket is in kCookieEcho state, it // will be T1-COOKIE that drives retransmissions, so when the T3-RTX expires, // nothing should be retransmitted. - ASSERT_TRUE(options_.rto_initial < options_.t1_cookie_timeout); - AdvanceTime(options_.rto_initial); - RunTimers(); - EXPECT_THAT(cb_a_.ConsumeSentPacket(), IsEmpty()); + ASSERT_TRUE(a.options.rto_initial < a.options.t1_cookie_timeout); + AdvanceTime(a, z, a.options.rto_initial); + EXPECT_THAT(a.cb.ConsumeSentPacket(), IsEmpty()); // When T1-COOKIE expires, both the COOKIE-ECHO and DATA should be present. - AdvanceTime(options_.t1_cookie_timeout - options_.rto_initial); - RunTimers(); + AdvanceTime(a, z, a.options.t1_cookie_timeout - a.options.rto_initial); // And this COOKIE-ECHO and DATA is also lost - never received by Z. ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket cookie_echo_packet2, - SctpPacket::Parse(cb_a_.ConsumeSentPacket())); + SctpPacket::Parse(a.cb.ConsumeSentPacket())); EXPECT_THAT(cookie_echo_packet2.descriptors(), SizeIs(2)); EXPECT_EQ(cookie_echo_packet2.descriptors()[0].type, CookieEchoChunk::kType); EXPECT_EQ(cookie_echo_packet2.descriptors()[1].type, DataChunk::kType); - EXPECT_THAT(cb_a_.ConsumeSentPacket(), IsEmpty()); + EXPECT_THAT(a.cb.ConsumeSentPacket(), IsEmpty()); // COOKIE_ECHO has exponential backoff. - AdvanceTime(options_.t1_cookie_timeout * 2); - RunTimers(); + AdvanceTime(a, z, a.options.t1_cookie_timeout * 2); // Z reads COOKIE_ECHO, produces COOKIE_ACK - sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket()); + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); // A reads COOKIE_ACK. - sock_a_->ReceivePacket(cb_z_.ConsumeSentPacket()); + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); - EXPECT_EQ(sock_a_->state(), SocketState::kConnected); - EXPECT_EQ(sock_z_->state(), SocketState::kConnected); + EXPECT_EQ(a.socket.state(), SocketState::kConnected); + EXPECT_EQ(z.socket.state(), SocketState::kConnected); - ExchangeMessages(*sock_a_, cb_a_, *sock_z_, cb_z_); - EXPECT_THAT(cb_z_.ConsumeReceivedMessage()->payload(), + ExchangeMessages(a, z); + EXPECT_THAT(z.cb.ConsumeReceivedMessage()->payload(), SizeIs(kLargeMessageSize)); } TEST_P(DcSctpSocketParametrizedTest, ShutdownConnection) { - ConnectSockets(); - MaybeHandoverSocketZ(); + SocketUnderTest a("A"); + auto z = std::make_unique("Z"); + + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); RTC_LOG(LS_INFO) << "Shutting down"; - EXPECT_CALL(cb_z_, OnClosed).Times(1); - sock_a_->Shutdown(); + EXPECT_CALL(z->cb, OnClosed).Times(1); + a.socket.Shutdown(); // Z reads SHUTDOWN, produces SHUTDOWN_ACK - sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket()); + z->socket.ReceivePacket(a.cb.ConsumeSentPacket()); // A reads SHUTDOWN_ACK, produces SHUTDOWN_COMPLETE - sock_a_->ReceivePacket(cb_z_.ConsumeSentPacket()); + a.socket.ReceivePacket(z->cb.ConsumeSentPacket()); // Z reads SHUTDOWN_COMPLETE. - sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket()); + z->socket.ReceivePacket(a.cb.ConsumeSentPacket()); - EXPECT_EQ(sock_a_->state(), SocketState::kClosed); - EXPECT_EQ(sock_z_->state(), SocketState::kClosed); + EXPECT_EQ(a.socket.state(), SocketState::kClosed); + EXPECT_EQ(z->socket.state(), SocketState::kClosed); - MaybeHandoverSocketZ(); - EXPECT_EQ(sock_z_->state(), SocketState::kClosed); + z = MaybeHandoverSocket(std::move(z)); + EXPECT_EQ(z->socket.state(), SocketState::kClosed); } -TEST_F(DcSctpSocketTest, ShutdownTimerExpiresTooManyTimeClosesConnection) { - ConnectSockets(); +TEST(DcSctpSocketTest, ShutdownTimerExpiresTooManyTimeClosesConnection) { + SocketUnderTest a("A"); + SocketUnderTest z("Z"); - sock_a_->Shutdown(); + ConnectSockets(a, z); + + a.socket.Shutdown(); // Drop first SHUTDOWN packet. - cb_a_.ConsumeSentPacket(); + a.cb.ConsumeSentPacket(); - EXPECT_EQ(sock_a_->state(), SocketState::kShuttingDown); + EXPECT_EQ(a.socket.state(), SocketState::kShuttingDown); - for (int i = 0; i < *options_.max_retransmissions; ++i) { - AdvanceTime(DurationMs(options_.rto_initial * (1 << i))); - RunTimers(); + for (int i = 0; i < *a.options.max_retransmissions; ++i) { + AdvanceTime(a, z, DurationMs(a.options.rto_initial * (1 << i))); // Dropping every shutdown chunk. ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet, - SctpPacket::Parse(cb_a_.ConsumeSentPacket())); + SctpPacket::Parse(a.cb.ConsumeSentPacket())); EXPECT_EQ(packet.descriptors()[0].type, ShutdownChunk::kType); - EXPECT_TRUE(cb_a_.ConsumeSentPacket().empty()); + EXPECT_TRUE(a.cb.ConsumeSentPacket().empty()); } // The last expiry, makes it abort the connection. - AdvanceTime(options_.rto_initial * (1 << *options_.max_retransmissions)); - EXPECT_CALL(cb_a_, OnAborted).Times(1); - RunTimers(); + EXPECT_CALL(a.cb, OnAborted).Times(1); + AdvanceTime(a, z, + a.options.rto_initial * (1 << *a.options.max_retransmissions)); - EXPECT_EQ(sock_a_->state(), SocketState::kClosed); + EXPECT_EQ(a.socket.state(), SocketState::kClosed); ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet, - SctpPacket::Parse(cb_a_.ConsumeSentPacket())); + SctpPacket::Parse(a.cb.ConsumeSentPacket())); EXPECT_EQ(packet.descriptors()[0].type, AbortChunk::kType); - EXPECT_TRUE(cb_a_.ConsumeSentPacket().empty()); + EXPECT_TRUE(a.cb.ConsumeSentPacket().empty()); } -TEST_F(DcSctpSocketTest, EstablishConnectionWhileSendingData) { - sock_a_->Connect(); +TEST(DcSctpSocketTest, EstablishConnectionWhileSendingData) { + SocketUnderTest a("A"); + SocketUnderTest z("Z"); - sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions); + a.socket.Connect(); + + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions); // Z reads INIT, produces INIT_ACK - sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket()); + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); // // A reads INIT_ACK, produces COOKIE_ECHO - sock_a_->ReceivePacket(cb_z_.ConsumeSentPacket()); + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); // // Z reads COOKIE_ECHO, produces COOKIE_ACK - sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket()); + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); // // A reads COOKIE_ACK. - sock_a_->ReceivePacket(cb_z_.ConsumeSentPacket()); + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); - EXPECT_EQ(sock_a_->state(), SocketState::kConnected); - EXPECT_EQ(sock_z_->state(), SocketState::kConnected); + EXPECT_EQ(a.socket.state(), SocketState::kConnected); + EXPECT_EQ(z.socket.state(), SocketState::kConnected); - absl::optional msg = cb_z_.ConsumeReceivedMessage(); + absl::optional msg = z.cb.ConsumeReceivedMessage(); ASSERT_TRUE(msg.has_value()); EXPECT_EQ(msg->stream_id(), StreamID(1)); } -TEST_F(DcSctpSocketTest, SendMessageAfterEstablished) { - ConnectSockets(); +TEST(DcSctpSocketTest, SendMessageAfterEstablished) { + SocketUnderTest a("A"); + SocketUnderTest z("Z"); - sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions); - sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket()); + ConnectSockets(a, z); - absl::optional msg = cb_z_.ConsumeReceivedMessage(); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions); + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); + + absl::optional msg = z.cb.ConsumeReceivedMessage(); ASSERT_TRUE(msg.has_value()); EXPECT_EQ(msg->stream_id(), StreamID(1)); } TEST_P(DcSctpSocketParametrizedTest, TimeoutResendsPacket) { - ConnectSockets(); - MaybeHandoverSocketZ(); + SocketUnderTest a("A"); + auto z = std::make_unique("Z"); - sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions); - cb_a_.ConsumeSentPacket(); + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions); + a.cb.ConsumeSentPacket(); RTC_LOG(LS_INFO) << "Advancing time"; - AdvanceTime(options_.rto_initial); - RunTimers(); + AdvanceTime(a, *z, a.options.rto_initial); - sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket()); + z->socket.ReceivePacket(a.cb.ConsumeSentPacket()); - absl::optional msg = cb_z_.ConsumeReceivedMessage(); + absl::optional msg = z->cb.ConsumeReceivedMessage(); ASSERT_TRUE(msg.has_value()); EXPECT_EQ(msg->stream_id(), StreamID(1)); - MaybeHandoverSocketZAndSendMessage(); + MaybeHandoverSocketAndSendMessage(a, std::move(z)); } TEST_P(DcSctpSocketParametrizedTest, SendALotOfBytesMissedSecondPacket) { - ConnectSockets(); - MaybeHandoverSocketZ(); + SocketUnderTest a("A"); + auto z = std::make_unique("Z"); + + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); std::vector payload(kLargeMessageSize); - sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), payload), kSendOptions); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), kSendOptions); // First DATA - sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket()); + z->socket.ReceivePacket(a.cb.ConsumeSentPacket()); // Second DATA (lost) - cb_a_.ConsumeSentPacket(); + a.cb.ConsumeSentPacket(); // Retransmit and handle the rest - ExchangeMessages(*sock_a_, cb_a_, *sock_z_, cb_z_); + ExchangeMessages(a, *z); - absl::optional msg = cb_z_.ConsumeReceivedMessage(); + absl::optional msg = z->cb.ConsumeReceivedMessage(); ASSERT_TRUE(msg.has_value()); EXPECT_EQ(msg->stream_id(), StreamID(1)); EXPECT_THAT(msg->payload(), testing::ElementsAreArray(payload)); - MaybeHandoverSocketZAndSendMessage(); + MaybeHandoverSocketAndSendMessage(a, std::move(z)); } TEST_P(DcSctpSocketParametrizedTest, SendingHeartbeatAnswersWithAck) { - ConnectSockets(); - MaybeHandoverSocketZ(); + SocketUnderTest a("A"); + auto z = std::make_unique("Z"); + + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); // Inject a HEARTBEAT chunk - SctpPacket::Builder b(sock_a_->verification_tag(), DcSctpOptions()); + SctpPacket::Builder b(a.socket.verification_tag(), DcSctpOptions()); uint8_t info[] = {1, 2, 3, 4}; Parameters::Builder params_builder; params_builder.Add(HeartbeatInfoParameter(info)); b.Add(HeartbeatRequestChunk(params_builder.Build())); - sock_a_->ReceivePacket(b.Build()); + a.socket.ReceivePacket(b.Build()); // HEARTBEAT_ACK is sent as a reply. Capture it. ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket ack_packet, - SctpPacket::Parse(cb_a_.ConsumeSentPacket())); + SctpPacket::Parse(a.cb.ConsumeSentPacket())); ASSERT_THAT(ack_packet.descriptors(), SizeIs(1)); ASSERT_HAS_VALUE_AND_ASSIGN( HeartbeatAckChunk ack, @@ -844,19 +875,21 @@ TEST_P(DcSctpSocketParametrizedTest, SendingHeartbeatAnswersWithAck) { ASSERT_HAS_VALUE_AND_ASSIGN(HeartbeatInfoParameter info_param, ack.info()); EXPECT_THAT(info_param.info(), ElementsAre(1, 2, 3, 4)); - MaybeHandoverSocketZAndSendMessage(); + MaybeHandoverSocketAndSendMessage(a, std::move(z)); } TEST_P(DcSctpSocketParametrizedTest, ExpectHeartbeatToBeSent) { - ConnectSockets(); - MaybeHandoverSocketZ(); + SocketUnderTest a("A"); + auto z = std::make_unique("Z"); - EXPECT_THAT(cb_a_.ConsumeSentPacket(), IsEmpty()); + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); - AdvanceTime(options_.heartbeat_interval); - RunTimers(); + EXPECT_THAT(a.cb.ConsumeSentPacket(), IsEmpty()); - std::vector hb_packet_raw = cb_a_.ConsumeSentPacket(); + AdvanceTime(a, *z, a.options.heartbeat_interval); + + std::vector hb_packet_raw = a.cb.ConsumeSentPacket(); ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket hb_packet, SctpPacket::Parse(hb_packet_raw)); ASSERT_THAT(hb_packet.descriptors(), SizeIs(1)); @@ -869,86 +902,85 @@ TEST_P(DcSctpSocketParametrizedTest, ExpectHeartbeatToBeSent) { EXPECT_THAT(hb.info()->info(), SizeIs(8)); // Feed it to Sock-z and expect a HEARTBEAT_ACK that will be propagated back. - sock_z_->ReceivePacket(hb_packet_raw); - sock_a_->ReceivePacket(cb_z_.ConsumeSentPacket()); + z->socket.ReceivePacket(hb_packet_raw); + a.socket.ReceivePacket(z->cb.ConsumeSentPacket()); - MaybeHandoverSocketZAndSendMessage(); + MaybeHandoverSocketAndSendMessage(a, std::move(z)); } TEST_P(DcSctpSocketParametrizedTest, CloseConnectionAfterTooManyLostHeartbeats) { - ConnectSockets(); - MaybeHandoverSocketZ(); + SocketUnderTest a("A"); + auto z = std::make_unique("Z"); - EXPECT_CALL(cb_z_, OnClosed).Times(1); - EXPECT_THAT(cb_a_.ConsumeSentPacket(), testing::IsEmpty()); + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + EXPECT_CALL(z->cb, OnClosed).Times(1); + EXPECT_THAT(a.cb.ConsumeSentPacket(), testing::IsEmpty()); // Force-close socket Z so that it doesn't interfere from now on. - sock_z_->Close(); + z->socket.Close(); - DurationMs time_to_next_hearbeat = options_.heartbeat_interval; + DurationMs time_to_next_hearbeat = a.options.heartbeat_interval; - for (int i = 0; i < *options_.max_retransmissions; ++i) { + for (int i = 0; i < *a.options.max_retransmissions; ++i) { RTC_LOG(LS_INFO) << "Letting HEARTBEAT interval timer expire - sending..."; - AdvanceTime(time_to_next_hearbeat); - RunTimers(); + AdvanceTime(a, *z, time_to_next_hearbeat); // Dropping every heartbeat. ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket hb_packet, - SctpPacket::Parse(cb_a_.ConsumeSentPacket())); + SctpPacket::Parse(a.cb.ConsumeSentPacket())); EXPECT_EQ(hb_packet.descriptors()[0].type, HeartbeatRequestChunk::kType); RTC_LOG(LS_INFO) << "Letting the heartbeat expire."; - AdvanceTime(DurationMs(1000)); - RunTimers(); + AdvanceTime(a, *z, DurationMs(1000)); - time_to_next_hearbeat = options_.heartbeat_interval - DurationMs(1000); + time_to_next_hearbeat = a.options.heartbeat_interval - DurationMs(1000); } RTC_LOG(LS_INFO) << "Letting HEARTBEAT interval timer expire - sending..."; - AdvanceTime(time_to_next_hearbeat); - RunTimers(); + AdvanceTime(a, *z, time_to_next_hearbeat); // Last heartbeat - EXPECT_THAT(cb_a_.ConsumeSentPacket(), Not(IsEmpty())); + EXPECT_THAT(a.cb.ConsumeSentPacket(), Not(IsEmpty())); - EXPECT_CALL(cb_a_, OnAborted).Times(1); + EXPECT_CALL(a.cb, OnAborted).Times(1); // Should suffice as exceeding RTO - AdvanceTime(DurationMs(1000)); - RunTimers(); + AdvanceTime(a, *z, DurationMs(1000)); - MaybeHandoverSocketZ(); + z = MaybeHandoverSocket(std::move(z)); } TEST_P(DcSctpSocketParametrizedTest, RecoversAfterASuccessfulAck) { - ConnectSockets(); - MaybeHandoverSocketZ(); + SocketUnderTest a("A"); + auto z = std::make_unique("Z"); - EXPECT_THAT(cb_a_.ConsumeSentPacket(), testing::IsEmpty()); - EXPECT_CALL(cb_z_, OnClosed).Times(1); + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + EXPECT_THAT(a.cb.ConsumeSentPacket(), testing::IsEmpty()); + EXPECT_CALL(z->cb, OnClosed).Times(1); // Force-close socket Z so that it doesn't interfere from now on. - sock_z_->Close(); + z->socket.Close(); - DurationMs time_to_next_hearbeat = options_.heartbeat_interval; + DurationMs time_to_next_hearbeat = a.options.heartbeat_interval; - for (int i = 0; i < *options_.max_retransmissions; ++i) { - AdvanceTime(time_to_next_hearbeat); - RunTimers(); + for (int i = 0; i < *a.options.max_retransmissions; ++i) { + AdvanceTime(a, *z, time_to_next_hearbeat); // Dropping every heartbeat. - cb_a_.ConsumeSentPacket(); + a.cb.ConsumeSentPacket(); RTC_LOG(LS_INFO) << "Letting the heartbeat expire."; - AdvanceTime(DurationMs(1000)); - RunTimers(); + AdvanceTime(a, *z, DurationMs(1000)); - time_to_next_hearbeat = options_.heartbeat_interval - DurationMs(1000); + time_to_next_hearbeat = a.options.heartbeat_interval - DurationMs(1000); } RTC_LOG(LS_INFO) << "Getting the last heartbeat - and acking it"; - AdvanceTime(time_to_next_hearbeat); - RunTimers(); + AdvanceTime(a, *z, time_to_next_hearbeat); - std::vector hb_packet_raw = cb_a_.ConsumeSentPacket(); + std::vector hb_packet_raw = a.cb.ConsumeSentPacket(); ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket hb_packet, SctpPacket::Parse(hb_packet_raw)); ASSERT_THAT(hb_packet.descriptors(), SizeIs(1)); @@ -956,350 +988,363 @@ TEST_P(DcSctpSocketParametrizedTest, RecoversAfterASuccessfulAck) { HeartbeatRequestChunk hb, HeartbeatRequestChunk::Parse(hb_packet.descriptors()[0].data)); - SctpPacket::Builder b(sock_a_->verification_tag(), options_); + SctpPacket::Builder b(a.socket.verification_tag(), a.options); b.Add(HeartbeatAckChunk(std::move(hb).extract_parameters())); - sock_a_->ReceivePacket(b.Build()); + a.socket.ReceivePacket(b.Build()); // Should suffice as exceeding RTO - which will not fire. - EXPECT_CALL(cb_a_, OnAborted).Times(0); - AdvanceTime(DurationMs(1000)); - RunTimers(); - EXPECT_THAT(cb_a_.ConsumeSentPacket(), IsEmpty()); + EXPECT_CALL(a.cb, OnAborted).Times(0); + AdvanceTime(a, *z, DurationMs(1000)); + + EXPECT_THAT(a.cb.ConsumeSentPacket(), IsEmpty()); // Verify that we get new heartbeats again. RTC_LOG(LS_INFO) << "Expecting a new heartbeat"; - AdvanceTime(time_to_next_hearbeat); - RunTimers(); + AdvanceTime(a, *z, time_to_next_hearbeat); ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket another_packet, - SctpPacket::Parse(cb_a_.ConsumeSentPacket())); + SctpPacket::Parse(a.cb.ConsumeSentPacket())); EXPECT_EQ(another_packet.descriptors()[0].type, HeartbeatRequestChunk::kType); } TEST_P(DcSctpSocketParametrizedTest, ResetStream) { - ConnectSockets(); - MaybeHandoverSocketZ(); + SocketUnderTest a("A"); + auto z = std::make_unique("Z"); - sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), {}); - sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket()); + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); - absl::optional msg = cb_z_.ConsumeReceivedMessage(); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), {}); + z->socket.ReceivePacket(a.cb.ConsumeSentPacket()); + + absl::optional msg = z->cb.ConsumeReceivedMessage(); ASSERT_TRUE(msg.has_value()); EXPECT_EQ(msg->stream_id(), StreamID(1)); // Handle SACK - sock_a_->ReceivePacket(cb_z_.ConsumeSentPacket()); + a.socket.ReceivePacket(z->cb.ConsumeSentPacket()); // Reset the outgoing stream. This will directly send a RE-CONFIG. - sock_a_->ResetStreams(std::vector({StreamID(1)})); + a.socket.ResetStreams(std::vector({StreamID(1)})); // Receiving the packet will trigger a callback, indicating that A has // reset its stream. It will also send a RE-CONFIG with a response. - EXPECT_CALL(cb_z_, OnIncomingStreamsReset).Times(1); - sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket()); + EXPECT_CALL(z->cb, OnIncomingStreamsReset).Times(1); + z->socket.ReceivePacket(a.cb.ConsumeSentPacket()); // Receiving a response will trigger a callback. Streams are now reset. - EXPECT_CALL(cb_a_, OnStreamsResetPerformed).Times(1); - sock_a_->ReceivePacket(cb_z_.ConsumeSentPacket()); + EXPECT_CALL(a.cb, OnStreamsResetPerformed).Times(1); + a.socket.ReceivePacket(z->cb.ConsumeSentPacket()); - MaybeHandoverSocketZAndSendMessage(); + MaybeHandoverSocketAndSendMessage(a, std::move(z)); } TEST_P(DcSctpSocketParametrizedTest, ResetStreamWillMakeChunksStartAtZeroSsn) { - ConnectSockets(); - MaybeHandoverSocketZ(); + SocketUnderTest a("A"); + auto z = std::make_unique("Z"); - std::vector payload(options_.mtu - 100); + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); - sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), payload), {}); - sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), payload), {}); + std::vector payload(a.options.mtu - 100); - auto packet1 = cb_a_.ConsumeSentPacket(); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), {}); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), {}); + + auto packet1 = a.cb.ConsumeSentPacket(); EXPECT_THAT(packet1, HasDataChunkWithSsn(SSN(0))); - sock_z_->ReceivePacket(packet1); + z->socket.ReceivePacket(packet1); - auto packet2 = cb_a_.ConsumeSentPacket(); + auto packet2 = a.cb.ConsumeSentPacket(); EXPECT_THAT(packet2, HasDataChunkWithSsn(SSN(1))); - sock_z_->ReceivePacket(packet2); + z->socket.ReceivePacket(packet2); // Handle SACK - sock_a_->ReceivePacket(cb_z_.ConsumeSentPacket()); + a.socket.ReceivePacket(z->cb.ConsumeSentPacket()); - absl::optional msg1 = cb_z_.ConsumeReceivedMessage(); + absl::optional msg1 = z->cb.ConsumeReceivedMessage(); ASSERT_TRUE(msg1.has_value()); EXPECT_EQ(msg1->stream_id(), StreamID(1)); - absl::optional msg2 = cb_z_.ConsumeReceivedMessage(); + absl::optional msg2 = z->cb.ConsumeReceivedMessage(); ASSERT_TRUE(msg2.has_value()); EXPECT_EQ(msg2->stream_id(), StreamID(1)); // Reset the outgoing stream. This will directly send a RE-CONFIG. - sock_a_->ResetStreams(std::vector({StreamID(1)})); + a.socket.ResetStreams(std::vector({StreamID(1)})); // RE-CONFIG, req - sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket()); + z->socket.ReceivePacket(a.cb.ConsumeSentPacket()); // RE-CONFIG, resp - sock_a_->ReceivePacket(cb_z_.ConsumeSentPacket()); + a.socket.ReceivePacket(z->cb.ConsumeSentPacket()); - sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), payload), {}); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), {}); - sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), payload), {}); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), {}); - auto packet3 = cb_a_.ConsumeSentPacket(); + auto packet3 = a.cb.ConsumeSentPacket(); EXPECT_THAT(packet3, HasDataChunkWithSsn(SSN(0))); - sock_z_->ReceivePacket(packet3); + z->socket.ReceivePacket(packet3); - auto packet4 = cb_a_.ConsumeSentPacket(); + auto packet4 = a.cb.ConsumeSentPacket(); EXPECT_THAT(packet4, HasDataChunkWithSsn(SSN(1))); - sock_z_->ReceivePacket(packet4); + z->socket.ReceivePacket(packet4); // Handle SACK - sock_a_->ReceivePacket(cb_z_.ConsumeSentPacket()); + a.socket.ReceivePacket(z->cb.ConsumeSentPacket()); - MaybeHandoverSocketZAndSendMessage(); + MaybeHandoverSocketAndSendMessage(a, std::move(z)); } TEST_P(DcSctpSocketParametrizedTest, ResetStreamWillOnlyResetTheRequestedStreams) { - ConnectSockets(); - MaybeHandoverSocketZ(); + SocketUnderTest a("A"); + auto z = std::make_unique("Z"); - std::vector payload(options_.mtu - 100); + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + std::vector payload(a.options.mtu - 100); // Send two ordered messages on SID 1 - sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), payload), {}); - sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), payload), {}); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), {}); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), {}); - auto packet1 = cb_a_.ConsumeSentPacket(); + auto packet1 = a.cb.ConsumeSentPacket(); EXPECT_THAT(packet1, HasDataChunkWithStreamId(StreamID(1))); EXPECT_THAT(packet1, HasDataChunkWithSsn(SSN(0))); - sock_z_->ReceivePacket(packet1); + z->socket.ReceivePacket(packet1); - auto packet2 = cb_a_.ConsumeSentPacket(); + auto packet2 = a.cb.ConsumeSentPacket(); EXPECT_THAT(packet1, HasDataChunkWithStreamId(StreamID(1))); EXPECT_THAT(packet2, HasDataChunkWithSsn(SSN(1))); - sock_z_->ReceivePacket(packet2); + z->socket.ReceivePacket(packet2); // Handle SACK - sock_a_->ReceivePacket(cb_z_.ConsumeSentPacket()); + a.socket.ReceivePacket(z->cb.ConsumeSentPacket()); // Do the same, for SID 3 - sock_a_->Send(DcSctpMessage(StreamID(3), PPID(53), payload), {}); - sock_a_->Send(DcSctpMessage(StreamID(3), PPID(53), payload), {}); - auto packet3 = cb_a_.ConsumeSentPacket(); + a.socket.Send(DcSctpMessage(StreamID(3), PPID(53), payload), {}); + a.socket.Send(DcSctpMessage(StreamID(3), PPID(53), payload), {}); + auto packet3 = a.cb.ConsumeSentPacket(); EXPECT_THAT(packet3, HasDataChunkWithStreamId(StreamID(3))); EXPECT_THAT(packet3, HasDataChunkWithSsn(SSN(0))); - sock_z_->ReceivePacket(packet3); - auto packet4 = cb_a_.ConsumeSentPacket(); + z->socket.ReceivePacket(packet3); + auto packet4 = a.cb.ConsumeSentPacket(); EXPECT_THAT(packet4, HasDataChunkWithStreamId(StreamID(3))); EXPECT_THAT(packet4, HasDataChunkWithSsn(SSN(1))); - sock_z_->ReceivePacket(packet4); - sock_a_->ReceivePacket(cb_z_.ConsumeSentPacket()); + z->socket.ReceivePacket(packet4); + a.socket.ReceivePacket(z->cb.ConsumeSentPacket()); // Receive all messages. - absl::optional msg1 = cb_z_.ConsumeReceivedMessage(); + absl::optional msg1 = z->cb.ConsumeReceivedMessage(); ASSERT_TRUE(msg1.has_value()); EXPECT_EQ(msg1->stream_id(), StreamID(1)); - absl::optional msg2 = cb_z_.ConsumeReceivedMessage(); + absl::optional msg2 = z->cb.ConsumeReceivedMessage(); ASSERT_TRUE(msg2.has_value()); EXPECT_EQ(msg2->stream_id(), StreamID(1)); - absl::optional msg3 = cb_z_.ConsumeReceivedMessage(); + absl::optional msg3 = z->cb.ConsumeReceivedMessage(); ASSERT_TRUE(msg3.has_value()); EXPECT_EQ(msg3->stream_id(), StreamID(3)); - absl::optional msg4 = cb_z_.ConsumeReceivedMessage(); + absl::optional msg4 = z->cb.ConsumeReceivedMessage(); ASSERT_TRUE(msg4.has_value()); EXPECT_EQ(msg4->stream_id(), StreamID(3)); // Reset SID 1. This will directly send a RE-CONFIG. - sock_a_->ResetStreams(std::vector({StreamID(3)})); + a.socket.ResetStreams(std::vector({StreamID(3)})); // RE-CONFIG, req - sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket()); + z->socket.ReceivePacket(a.cb.ConsumeSentPacket()); // RE-CONFIG, resp - sock_a_->ReceivePacket(cb_z_.ConsumeSentPacket()); + a.socket.ReceivePacket(z->cb.ConsumeSentPacket()); // Send a message on SID 1 and 3 - SID 1 should not be reset, but 3 should. - sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), payload), {}); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), {}); - sock_a_->Send(DcSctpMessage(StreamID(3), PPID(53), payload), {}); + a.socket.Send(DcSctpMessage(StreamID(3), PPID(53), payload), {}); - auto packet5 = cb_a_.ConsumeSentPacket(); + auto packet5 = a.cb.ConsumeSentPacket(); EXPECT_THAT(packet5, HasDataChunkWithStreamId(StreamID(1))); EXPECT_THAT(packet5, HasDataChunkWithSsn(SSN(2))); // Unchanged. - sock_z_->ReceivePacket(packet5); + z->socket.ReceivePacket(packet5); - auto packet6 = cb_a_.ConsumeSentPacket(); + auto packet6 = a.cb.ConsumeSentPacket(); EXPECT_THAT(packet6, HasDataChunkWithStreamId(StreamID(3))); EXPECT_THAT(packet6, HasDataChunkWithSsn(SSN(0))); // Reset. - sock_z_->ReceivePacket(packet6); + z->socket.ReceivePacket(packet6); // Handle SACK - sock_a_->ReceivePacket(cb_z_.ConsumeSentPacket()); + a.socket.ReceivePacket(z->cb.ConsumeSentPacket()); - MaybeHandoverSocketZAndSendMessage(); + MaybeHandoverSocketAndSendMessage(a, std::move(z)); } TEST_P(DcSctpSocketParametrizedTest, OnePeerReconnects) { - ConnectSockets(); - MaybeHandoverSocketZ(); + SocketUnderTest a("A"); + auto z = std::make_unique("Z"); - EXPECT_CALL(cb_a_, OnConnectionRestarted).Times(1); + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + EXPECT_CALL(a.cb, OnConnectionRestarted).Times(1); // Let's be evil here - reconnect while a fragmented packet was about to be // sent. The receiving side should get it in full. std::vector payload(kLargeMessageSize); - sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), payload), kSendOptions); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), kSendOptions); // First DATA - sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket()); + z->socket.ReceivePacket(a.cb.ConsumeSentPacket()); // Create a new association, z2 - and don't use z anymore. - testing::NiceMock cb_z2("Z2"); - DcSctpSocket sock_z2("Z2", cb_z2, nullptr, options_); - - sock_z2.Connect(); + SocketUnderTest z2("Z2"); + z2.socket.Connect(); // Retransmit and handle the rest. As there will be some chunks in-flight that // have the wrong verification tag, those will yield errors. - ExchangeMessages(*sock_a_, cb_a_, sock_z2, cb_z2); + ExchangeMessages(a, z2); - absl::optional msg = cb_z2.ConsumeReceivedMessage(); + absl::optional msg = z2.cb.ConsumeReceivedMessage(); ASSERT_TRUE(msg.has_value()); EXPECT_EQ(msg->stream_id(), StreamID(1)); EXPECT_THAT(msg->payload(), testing::ElementsAreArray(payload)); } TEST_P(DcSctpSocketParametrizedTest, SendMessageWithLimitedRtx) { - ConnectSockets(); - MaybeHandoverSocketZ(); + SocketUnderTest a("A"); + auto z = std::make_unique("Z"); + + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); SendOptions send_options; send_options.max_retransmissions = 0; - std::vector payload(options_.mtu - 100); - sock_a_->Send(DcSctpMessage(StreamID(1), PPID(51), payload), send_options); - sock_a_->Send(DcSctpMessage(StreamID(1), PPID(52), payload), send_options); - sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), payload), send_options); + std::vector payload(a.options.mtu - 100); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(51), payload), send_options); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(52), payload), send_options); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), send_options); // First DATA - sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket()); + z->socket.ReceivePacket(a.cb.ConsumeSentPacket()); // Second DATA (lost) - cb_a_.ConsumeSentPacket(); + a.cb.ConsumeSentPacket(); // Third DATA - sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket()); + z->socket.ReceivePacket(a.cb.ConsumeSentPacket()); // Handle SACK for first DATA - sock_a_->ReceivePacket(cb_z_.ConsumeSentPacket()); + a.socket.ReceivePacket(z->cb.ConsumeSentPacket()); // Handle delayed SACK for third DATA - AdvanceTime(options_.delayed_ack_max_timeout); - RunTimers(); + AdvanceTime(a, *z, a.options.delayed_ack_max_timeout); // Handle SACK for second DATA - sock_a_->ReceivePacket(cb_z_.ConsumeSentPacket()); + a.socket.ReceivePacket(z->cb.ConsumeSentPacket()); // Now the missing data chunk will be marked as nacked, but it might still be // in-flight and the reported gap could be due to out-of-order delivery. So // the RetransmissionQueue will not mark it as "to be retransmitted" until // after the t3-rtx timer has expired. - AdvanceTime(options_.rto_initial); - RunTimers(); + AdvanceTime(a, *z, a.options.rto_initial); // The chunk will be marked as retransmitted, and then as abandoned, which // will trigger a FORWARD-TSN to be sent. // FORWARD-TSN (third) - sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket()); + z->socket.ReceivePacket(a.cb.ConsumeSentPacket()); // Which will trigger a SACK - sock_a_->ReceivePacket(cb_z_.ConsumeSentPacket()); + a.socket.ReceivePacket(z->cb.ConsumeSentPacket()); - absl::optional msg1 = cb_z_.ConsumeReceivedMessage(); + absl::optional msg1 = z->cb.ConsumeReceivedMessage(); ASSERT_TRUE(msg1.has_value()); EXPECT_EQ(msg1->ppid(), PPID(51)); - absl::optional msg2 = cb_z_.ConsumeReceivedMessage(); + absl::optional msg2 = z->cb.ConsumeReceivedMessage(); ASSERT_TRUE(msg2.has_value()); EXPECT_EQ(msg2->ppid(), PPID(53)); - absl::optional msg3 = cb_z_.ConsumeReceivedMessage(); + absl::optional msg3 = z->cb.ConsumeReceivedMessage(); EXPECT_FALSE(msg3.has_value()); - MaybeHandoverSocketZAndSendMessage(); + MaybeHandoverSocketAndSendMessage(a, std::move(z)); } TEST_P(DcSctpSocketParametrizedTest, SendManyFragmentedMessagesWithLimitedRtx) { - ConnectSockets(); - MaybeHandoverSocketZ(); + SocketUnderTest a("A"); + auto z = std::make_unique("Z"); + + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); SendOptions send_options; send_options.unordered = IsUnordered(true); send_options.max_retransmissions = 0; - std::vector payload(options_.mtu * 2 - 100 /* margin */); + std::vector payload(a.options.mtu * 2 - 100 /* margin */); // Sending first message - sock_a_->Send(DcSctpMessage(StreamID(1), PPID(51), payload), send_options); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(51), payload), send_options); // Sending second message - sock_a_->Send(DcSctpMessage(StreamID(1), PPID(52), payload), send_options); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(52), payload), send_options); // Sending third message - sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), payload), send_options); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), send_options); // Sending fourth message - sock_a_->Send(DcSctpMessage(StreamID(1), PPID(54), payload), send_options); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(54), payload), send_options); // First DATA, first fragment - std::vector packet = cb_a_.ConsumeSentPacket(); + std::vector packet = a.cb.ConsumeSentPacket(); EXPECT_THAT(packet, HasDataChunkWithPPID(PPID(51))); - sock_z_->ReceivePacket(std::move(packet)); + z->socket.ReceivePacket(std::move(packet)); // First DATA, second fragment (lost) - packet = cb_a_.ConsumeSentPacket(); + packet = a.cb.ConsumeSentPacket(); EXPECT_THAT(packet, HasDataChunkWithPPID(PPID(51))); // Second DATA, first fragment - packet = cb_a_.ConsumeSentPacket(); + packet = a.cb.ConsumeSentPacket(); EXPECT_THAT(packet, HasDataChunkWithPPID(PPID(52))); - sock_z_->ReceivePacket(std::move(packet)); + z->socket.ReceivePacket(std::move(packet)); // Second DATA, second fragment (lost) - packet = cb_a_.ConsumeSentPacket(); + packet = a.cb.ConsumeSentPacket(); EXPECT_THAT(packet, HasDataChunkWithPPID(PPID(52))); EXPECT_THAT(packet, HasDataChunkWithSsn(SSN(0))); // Third DATA, first fragment - packet = cb_a_.ConsumeSentPacket(); + packet = a.cb.ConsumeSentPacket(); EXPECT_THAT(packet, HasDataChunkWithPPID(PPID(53))); EXPECT_THAT(packet, HasDataChunkWithSsn(SSN(0))); - sock_z_->ReceivePacket(std::move(packet)); + z->socket.ReceivePacket(std::move(packet)); // Third DATA, second fragment (lost) - packet = cb_a_.ConsumeSentPacket(); + packet = a.cb.ConsumeSentPacket(); EXPECT_THAT(packet, HasDataChunkWithPPID(PPID(53))); EXPECT_THAT(packet, HasDataChunkWithSsn(SSN(0))); // Fourth DATA, first fragment - packet = cb_a_.ConsumeSentPacket(); + packet = a.cb.ConsumeSentPacket(); EXPECT_THAT(packet, HasDataChunkWithPPID(PPID(54))); EXPECT_THAT(packet, HasDataChunkWithSsn(SSN(0))); - sock_z_->ReceivePacket(std::move(packet)); + z->socket.ReceivePacket(std::move(packet)); // Fourth DATA, second fragment - packet = cb_a_.ConsumeSentPacket(); + packet = a.cb.ConsumeSentPacket(); EXPECT_THAT(packet, HasDataChunkWithPPID(PPID(54))); EXPECT_THAT(packet, HasDataChunkWithSsn(SSN(0))); - sock_z_->ReceivePacket(std::move(packet)); + z->socket.ReceivePacket(std::move(packet)); - ExchangeMessages(*sock_a_, cb_a_, *sock_z_, cb_z_); + ExchangeMessages(a, *z); // Let the RTX timer expire, and exchange FORWARD-TSN/SACKs - AdvanceTime(options_.rto_initial); - RunTimers(); - ExchangeMessages(*sock_a_, cb_a_, *sock_z_, cb_z_); + AdvanceTime(a, *z, a.options.rto_initial); - absl::optional msg1 = cb_z_.ConsumeReceivedMessage(); + ExchangeMessages(a, *z); + + absl::optional msg1 = z->cb.ConsumeReceivedMessage(); ASSERT_TRUE(msg1.has_value()); EXPECT_EQ(msg1->ppid(), PPID(54)); - ASSERT_FALSE(cb_z_.ConsumeReceivedMessage().has_value()); + ASSERT_FALSE(z->cb.ConsumeReceivedMessage().has_value()); - MaybeHandoverSocketZAndSendMessage(); + MaybeHandoverSocketAndSendMessage(a, std::move(z)); } struct FakeChunkConfig : ChunkConfig { @@ -1322,17 +1367,20 @@ class FakeChunk : public Chunk, public TLVTrait { }; TEST_P(DcSctpSocketParametrizedTest, ReceivingUnknownChunkRespondsWithError) { - ConnectSockets(); - MaybeHandoverSocketZ(); + SocketUnderTest a("A"); + auto z = std::make_unique("Z"); + + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); // Inject a FAKE chunk - SctpPacket::Builder b(sock_a_->verification_tag(), DcSctpOptions()); + SctpPacket::Builder b(a.socket.verification_tag(), DcSctpOptions()); b.Add(FakeChunk()); - sock_a_->ReceivePacket(b.Build()); + a.socket.ReceivePacket(b.Build()); // ERROR is sent as a reply. Capture it. ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket reply_packet, - SctpPacket::Parse(cb_a_.ConsumeSentPacket())); + SctpPacket::Parse(a.cb.ConsumeSentPacket())); ASSERT_THAT(reply_packet.descriptors(), SizeIs(1)); ASSERT_HAS_VALUE_AND_ASSIGN( ErrorChunk error, ErrorChunk::Parse(reply_packet.descriptors()[0].data)); @@ -1341,50 +1389,52 @@ TEST_P(DcSctpSocketParametrizedTest, ReceivingUnknownChunkRespondsWithError) { error.error_causes().get()); EXPECT_THAT(cause.unrecognized_chunk(), ElementsAre(0x49, 0x00, 0x00, 0x04)); - MaybeHandoverSocketZAndSendMessage(); + MaybeHandoverSocketAndSendMessage(a, std::move(z)); } TEST_P(DcSctpSocketParametrizedTest, ReceivingErrorChunkReportsAsCallback) { - ConnectSockets(); - MaybeHandoverSocketZ(); + SocketUnderTest a("A"); + auto z = std::make_unique("Z"); + + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); // Inject a ERROR chunk - SctpPacket::Builder b(sock_a_->verification_tag(), DcSctpOptions()); + SctpPacket::Builder b(a.socket.verification_tag(), DcSctpOptions()); b.Add( ErrorChunk(Parameters::Builder() .Add(UnrecognizedChunkTypeCause({0x49, 0x00, 0x00, 0x04})) .Build())); - EXPECT_CALL(cb_a_, OnError(ErrorKind::kPeerReported, - HasSubstr("Unrecognized Chunk Type"))); - sock_a_->ReceivePacket(b.Build()); + EXPECT_CALL(a.cb, OnError(ErrorKind::kPeerReported, + HasSubstr("Unrecognized Chunk Type"))); + a.socket.ReceivePacket(b.Build()); - MaybeHandoverSocketZAndSendMessage(); + MaybeHandoverSocketAndSendMessage(a, std::move(z)); } -TEST_F(DcSctpSocketTest, PassingHighWatermarkWillOnlyAcceptCumAckTsn) { - // Create a new association, z2 - and don't use z anymore. - testing::NiceMock cb_z2("Z2"); - DcSctpOptions options = options_; +TEST(DcSctpSocketTest, PassingHighWatermarkWillOnlyAcceptCumAckTsn) { + SocketUnderTest a("A"); + constexpr size_t kReceiveWindowBufferSize = 2000; - options.max_receiver_window_buffer_size = kReceiveWindowBufferSize; - options.mtu = 3000; - DcSctpSocket sock_z2("Z2", cb_z2, nullptr, options); + SocketUnderTest z( + "Z", {.mtu = 3000, + .max_receiver_window_buffer_size = kReceiveWindowBufferSize}); - EXPECT_CALL(cb_z2, OnClosed).Times(0); - EXPECT_CALL(cb_z2, OnAborted).Times(0); + EXPECT_CALL(z.cb, OnClosed).Times(0); + EXPECT_CALL(z.cb, OnAborted).Times(0); - sock_a_->Connect(); - std::vector init_data = cb_a_.ConsumeSentPacket(); + a.socket.Connect(); + std::vector init_data = a.cb.ConsumeSentPacket(); ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket init_packet, SctpPacket::Parse(init_data)); ASSERT_HAS_VALUE_AND_ASSIGN( InitChunk init_chunk, InitChunk::Parse(init_packet.descriptors()[0].data)); - sock_z2.ReceivePacket(init_data); - sock_a_->ReceivePacket(cb_z2.ConsumeSentPacket()); - sock_z2.ReceivePacket(cb_a_.ConsumeSentPacket()); - sock_a_->ReceivePacket(cb_z2.ConsumeSentPacket()); + z.socket.ReceivePacket(init_data); + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); // Fill up Z2 to the high watermark limit. constexpr size_t kWatermarkLimit = @@ -1394,99 +1444,105 @@ TEST_F(DcSctpSocketTest, PassingHighWatermarkWillOnlyAcceptCumAckTsn) { TSN tsn = init_chunk.initial_tsn(); AnyDataChunk::Options opts; opts.is_beginning = Data::IsBeginning(true); - sock_z2.ReceivePacket( - SctpPacket::Builder(sock_z2.verification_tag(), options) + z.socket.ReceivePacket( + SctpPacket::Builder(z.socket.verification_tag(), z.options) .Add(DataChunk(tsn, StreamID(1), SSN(0), PPID(53), std::vector(kWatermarkLimit + 1), opts)) .Build()); // First DATA will always trigger a SACK. It's not interesting. - EXPECT_THAT(cb_z2.ConsumeSentPacket(), + EXPECT_THAT(z.cb.ConsumeSentPacket(), AllOf(HasSackWithCumAckTsn(tsn), HasSackWithNoGapAckBlocks())); // This DATA should be accepted - it's advancing cum ack tsn. - sock_z2.ReceivePacket(SctpPacket::Builder(sock_z2.verification_tag(), options) - .Add(DataChunk(AddTo(tsn, 1), StreamID(1), SSN(0), - PPID(53), std::vector(1), - /*options=*/{})) - .Build()); + z.socket.ReceivePacket( + SctpPacket::Builder(z.socket.verification_tag(), z.options) + .Add(DataChunk(AddTo(tsn, 1), StreamID(1), SSN(0), PPID(53), + std::vector(1), + /*options=*/{})) + .Build()); // The receiver might have moved into delayed ack mode. - cb_z2.AdvanceTime(options.rto_initial); - RunTimers(cb_z2, sock_z2); + AdvanceTime(a, z, z.options.rto_initial); EXPECT_THAT( - cb_z2.ConsumeSentPacket(), + z.cb.ConsumeSentPacket(), AllOf(HasSackWithCumAckTsn(AddTo(tsn, 1)), HasSackWithNoGapAckBlocks())); // This DATA will not be accepted - it's not advancing cum ack tsn. - sock_z2.ReceivePacket(SctpPacket::Builder(sock_z2.verification_tag(), options) - .Add(DataChunk(AddTo(tsn, 3), StreamID(1), SSN(0), - PPID(53), std::vector(1), - /*options=*/{})) - .Build()); + z.socket.ReceivePacket( + SctpPacket::Builder(z.socket.verification_tag(), z.options) + .Add(DataChunk(AddTo(tsn, 3), StreamID(1), SSN(0), PPID(53), + std::vector(1), + /*options=*/{})) + .Build()); // Sack will be sent in IMMEDIATE mode when this is happening. EXPECT_THAT( - cb_z2.ConsumeSentPacket(), + z.cb.ConsumeSentPacket(), AllOf(HasSackWithCumAckTsn(AddTo(tsn, 1)), HasSackWithNoGapAckBlocks())); // This DATA will not be accepted either. - sock_z2.ReceivePacket(SctpPacket::Builder(sock_z2.verification_tag(), options) - .Add(DataChunk(AddTo(tsn, 4), StreamID(1), SSN(0), - PPID(53), std::vector(1), - /*options=*/{})) - .Build()); + z.socket.ReceivePacket( + SctpPacket::Builder(z.socket.verification_tag(), z.options) + .Add(DataChunk(AddTo(tsn, 4), StreamID(1), SSN(0), PPID(53), + std::vector(1), + /*options=*/{})) + .Build()); // Sack will be sent in IMMEDIATE mode when this is happening. EXPECT_THAT( - cb_z2.ConsumeSentPacket(), + z.cb.ConsumeSentPacket(), AllOf(HasSackWithCumAckTsn(AddTo(tsn, 1)), HasSackWithNoGapAckBlocks())); // This DATA should be accepted, and it fills the reassembly queue. - sock_z2.ReceivePacket( - SctpPacket::Builder(sock_z2.verification_tag(), options) + z.socket.ReceivePacket( + SctpPacket::Builder(z.socket.verification_tag(), z.options) .Add(DataChunk(AddTo(tsn, 2), StreamID(1), SSN(0), PPID(53), std::vector(kRemainingSize), /*options=*/{})) .Build()); // The receiver might have moved into delayed ack mode. - cb_z2.AdvanceTime(options.rto_initial); - RunTimers(cb_z2, sock_z2); + AdvanceTime(a, z, z.options.rto_initial); EXPECT_THAT( - cb_z2.ConsumeSentPacket(), + z.cb.ConsumeSentPacket(), AllOf(HasSackWithCumAckTsn(AddTo(tsn, 2)), HasSackWithNoGapAckBlocks())); - EXPECT_CALL(cb_z2, OnAborted(ErrorKind::kResourceExhaustion, _)); - EXPECT_CALL(cb_z2, OnClosed).Times(0); + EXPECT_CALL(z.cb, OnAborted(ErrorKind::kResourceExhaustion, _)); + EXPECT_CALL(z.cb, OnClosed).Times(0); // This DATA will make the connection close. It's too full now. - sock_z2.ReceivePacket( - SctpPacket::Builder(sock_z2.verification_tag(), options) + z.socket.ReceivePacket( + SctpPacket::Builder(z.socket.verification_tag(), z.options) .Add(DataChunk(AddTo(tsn, 3), StreamID(1), SSN(0), PPID(53), std::vector(kSmallMessageSize), /*options=*/{})) .Build()); } -TEST_F(DcSctpSocketTest, SetMaxMessageSize) { - sock_a_->SetMaxMessageSize(42u); - EXPECT_EQ(sock_a_->options().max_message_size, 42u); +TEST(DcSctpSocketTest, SetMaxMessageSize) { + SocketUnderTest a("A"); + + a.socket.SetMaxMessageSize(42u); + EXPECT_EQ(a.socket.options().max_message_size, 42u); } TEST_P(DcSctpSocketParametrizedTest, SendsMessagesWithLowLifetime) { - ConnectSockets(); - MaybeHandoverSocketZ(); + SocketUnderTest a("A"); + auto z = std::make_unique("Z"); + + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); // Mock that the time always goes forward. TimeMs now(0); - EXPECT_CALL(cb_a_, TimeMillis).WillRepeatedly([&]() { + EXPECT_CALL(a.cb, TimeMillis).WillRepeatedly([&]() { now += DurationMs(3); return now; }); - EXPECT_CALL(cb_z_, TimeMillis).WillRepeatedly([&]() { + EXPECT_CALL(z->cb, TimeMillis).WillRepeatedly([&]() { now += DurationMs(3); return now; }); @@ -1499,27 +1555,30 @@ TEST_P(DcSctpSocketParametrizedTest, SendsMessagesWithLowLifetime) { send_options.unordered = IsUnordered((i % 2) == 0); send_options.lifetime = DurationMs(i % 3); // 0, 1, 2 ms - sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), send_options); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), send_options); } - ExchangeMessages(*sock_a_, cb_a_, *sock_z_, cb_z_); + ExchangeMessages(a, *z); for (int i = 0; i < kIterations; ++i) { - EXPECT_TRUE(cb_z_.ConsumeReceivedMessage().has_value()); + EXPECT_TRUE(z->cb.ConsumeReceivedMessage().has_value()); } - EXPECT_FALSE(cb_z_.ConsumeReceivedMessage().has_value()); + EXPECT_FALSE(z->cb.ConsumeReceivedMessage().has_value()); // Validate that the sockets really make the time move forward. EXPECT_GE(*now, kIterations * 2); - MaybeHandoverSocketZAndSendMessage(); + MaybeHandoverSocketAndSendMessage(a, std::move(z)); } TEST_P(DcSctpSocketParametrizedTest, DiscardsMessagesWithLowLifetimeIfMustBuffer) { - ConnectSockets(); - MaybeHandoverSocketZ(); + SocketUnderTest a("A"); + auto z = std::make_unique("Z"); + + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); SendOptions lifetime_0; lifetime_0.unordered = IsUnordered(true); @@ -1531,93 +1590,100 @@ TEST_P(DcSctpSocketParametrizedTest, // Mock that the time always goes forward. TimeMs now(0); - EXPECT_CALL(cb_a_, TimeMillis).WillRepeatedly([&]() { + EXPECT_CALL(a.cb, TimeMillis).WillRepeatedly([&]() { now += DurationMs(3); return now; }); - EXPECT_CALL(cb_z_, TimeMillis).WillRepeatedly([&]() { + EXPECT_CALL(z->cb, TimeMillis).WillRepeatedly([&]() { now += DurationMs(3); return now; }); // Fill up the send buffer with a large message. std::vector payload(kLargeMessageSize); - sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), payload), kSendOptions); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), kSendOptions); // And queue a few small messages with lifetime=0 or 1 ms - can't be sent. - sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2, 3}), lifetime_0); - sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), {4, 5, 6}), lifetime_1); - sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), {7, 8, 9}), lifetime_0); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2, 3}), lifetime_0); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {4, 5, 6}), lifetime_1); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {7, 8, 9}), lifetime_0); // Handle all that was sent until congestion window got full. for (;;) { - std::vector packet_from_a = cb_a_.ConsumeSentPacket(); + std::vector packet_from_a = a.cb.ConsumeSentPacket(); if (packet_from_a.empty()) { break; } - sock_z_->ReceivePacket(std::move(packet_from_a)); + z->socket.ReceivePacket(std::move(packet_from_a)); } // Shouldn't be enough to send that large message. - EXPECT_FALSE(cb_z_.ConsumeReceivedMessage().has_value()); + EXPECT_FALSE(z->cb.ConsumeReceivedMessage().has_value()); // Exchange the rest of the messages, with the time ever increasing. - ExchangeMessages(*sock_a_, cb_a_, *sock_z_, cb_z_); + ExchangeMessages(a, *z); // The large message should be delivered. It was sent reliably. - ASSERT_HAS_VALUE_AND_ASSIGN(DcSctpMessage m1, cb_z_.ConsumeReceivedMessage()); + ASSERT_HAS_VALUE_AND_ASSIGN(DcSctpMessage m1, z->cb.ConsumeReceivedMessage()); EXPECT_EQ(m1.stream_id(), StreamID(1)); EXPECT_THAT(m1.payload(), SizeIs(kLargeMessageSize)); // But none of the smaller messages. - EXPECT_FALSE(cb_z_.ConsumeReceivedMessage().has_value()); + EXPECT_FALSE(z->cb.ConsumeReceivedMessage().has_value()); - MaybeHandoverSocketZAndSendMessage(); + MaybeHandoverSocketAndSendMessage(a, std::move(z)); } TEST_P(DcSctpSocketParametrizedTest, HasReasonableBufferedAmountValues) { - ConnectSockets(); - MaybeHandoverSocketZ(); + SocketUnderTest a("A"); + auto z = std::make_unique("Z"); - EXPECT_EQ(sock_a_->buffered_amount(StreamID(1)), 0u); + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); - sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), + EXPECT_EQ(a.socket.buffered_amount(StreamID(1)), 0u); + + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), std::vector(kSmallMessageSize)), kSendOptions); // Sending a small message will directly send it as a single packet, so // nothing is left in the queue. - EXPECT_EQ(sock_a_->buffered_amount(StreamID(1)), 0u); + EXPECT_EQ(a.socket.buffered_amount(StreamID(1)), 0u); - sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), std::vector(kLargeMessageSize)), kSendOptions); // Sending a message will directly start sending a few packets, so the // buffered amount is not the full message size. - EXPECT_GT(sock_a_->buffered_amount(StreamID(1)), 0u); - EXPECT_LT(sock_a_->buffered_amount(StreamID(1)), kLargeMessageSize); + EXPECT_GT(a.socket.buffered_amount(StreamID(1)), 0u); + EXPECT_LT(a.socket.buffered_amount(StreamID(1)), kLargeMessageSize); - MaybeHandoverSocketZAndSendMessage(); + MaybeHandoverSocketAndSendMessage(a, std::move(z)); } -TEST_F(DcSctpSocketTest, HasDefaultOnBufferedAmountLowValueZero) { - EXPECT_EQ(sock_a_->buffered_amount_low_threshold(StreamID(1)), 0u); +TEST(DcSctpSocketTest, HasDefaultOnBufferedAmountLowValueZero) { + SocketUnderTest a("A"); + EXPECT_EQ(a.socket.buffered_amount_low_threshold(StreamID(1)), 0u); } TEST_P(DcSctpSocketParametrizedTest, TriggersOnBufferedAmountLowWithDefaultValueZero) { - EXPECT_CALL(cb_a_, OnBufferedAmountLow).Times(0); - ConnectSockets(); - MaybeHandoverSocketZ(); + SocketUnderTest a("A"); + auto z = std::make_unique("Z"); - EXPECT_CALL(cb_a_, OnBufferedAmountLow(StreamID(1))); - sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), + EXPECT_CALL(a.cb, OnBufferedAmountLow).Times(0); + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + EXPECT_CALL(a.cb, OnBufferedAmountLow(StreamID(1))); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), std::vector(kSmallMessageSize)), kSendOptions); - ExchangeMessages(*sock_a_, cb_a_, *sock_z_, cb_z_); + ExchangeMessages(a, *z); - EXPECT_CALL(cb_a_, OnBufferedAmountLow).WillRepeatedly(testing::Return()); - MaybeHandoverSocketZAndSendMessage(); + EXPECT_CALL(a.cb, OnBufferedAmountLow).WillRepeatedly(testing::Return()); + MaybeHandoverSocketAndSendMessage(a, std::move(z)); } TEST_P(DcSctpSocketParametrizedTest, @@ -1625,64 +1691,70 @@ TEST_P(DcSctpSocketParametrizedTest, static constexpr size_t kMessageSize = 1000; static constexpr size_t kBufferedAmountLowThreshold = kMessageSize * 10; - sock_a_->SetBufferedAmountLowThreshold(StreamID(1), + SocketUnderTest a("A"); + auto z = std::make_unique("Z"); + + a.socket.SetBufferedAmountLowThreshold(StreamID(1), kBufferedAmountLowThreshold); - EXPECT_CALL(cb_a_, OnBufferedAmountLow).Times(0); - ConnectSockets(); - MaybeHandoverSocketZ(); + EXPECT_CALL(a.cb, OnBufferedAmountLow).Times(0); + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); - EXPECT_CALL(cb_a_, OnBufferedAmountLow(StreamID(1))).Times(0); - sock_a_->Send( + EXPECT_CALL(a.cb, OnBufferedAmountLow(StreamID(1))).Times(0); + a.socket.Send( DcSctpMessage(StreamID(1), PPID(53), std::vector(kMessageSize)), kSendOptions); - ExchangeMessages(*sock_a_, cb_a_, *sock_z_, cb_z_); + ExchangeMessages(a, *z); - sock_a_->Send( + a.socket.Send( DcSctpMessage(StreamID(1), PPID(53), std::vector(kMessageSize)), kSendOptions); - ExchangeMessages(*sock_a_, cb_a_, *sock_z_, cb_z_); + ExchangeMessages(a, *z); - MaybeHandoverSocketZAndSendMessage(); + MaybeHandoverSocketAndSendMessage(a, std::move(z)); } TEST_P(DcSctpSocketParametrizedTest, TriggersOnBufferedAmountMultipleTimes) { static constexpr size_t kMessageSize = 1000; static constexpr size_t kBufferedAmountLowThreshold = kMessageSize / 2; - sock_a_->SetBufferedAmountLowThreshold(StreamID(1), + SocketUnderTest a("A"); + auto z = std::make_unique("Z"); + + a.socket.SetBufferedAmountLowThreshold(StreamID(1), kBufferedAmountLowThreshold); - EXPECT_CALL(cb_a_, OnBufferedAmountLow).Times(0); - ConnectSockets(); - MaybeHandoverSocketZ(); + EXPECT_CALL(a.cb, OnBufferedAmountLow).Times(0); + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); - EXPECT_CALL(cb_a_, OnBufferedAmountLow(StreamID(1))).Times(3); - EXPECT_CALL(cb_a_, OnBufferedAmountLow(StreamID(2))).Times(2); - sock_a_->Send( + EXPECT_CALL(a.cb, OnBufferedAmountLow(StreamID(1))).Times(3); + EXPECT_CALL(a.cb, OnBufferedAmountLow(StreamID(2))).Times(2); + a.socket.Send( DcSctpMessage(StreamID(1), PPID(53), std::vector(kMessageSize)), kSendOptions); - ExchangeMessages(*sock_a_, cb_a_, *sock_z_, cb_z_); + ExchangeMessages(a, *z); - sock_a_->Send( + a.socket.Send( DcSctpMessage(StreamID(2), PPID(53), std::vector(kMessageSize)), kSendOptions); - ExchangeMessages(*sock_a_, cb_a_, *sock_z_, cb_z_); + ExchangeMessages(a, *z); - sock_a_->Send( + a.socket.Send( DcSctpMessage(StreamID(1), PPID(53), std::vector(kMessageSize)), kSendOptions); - ExchangeMessages(*sock_a_, cb_a_, *sock_z_, cb_z_); + ExchangeMessages(a, *z); - sock_a_->Send( + a.socket.Send( DcSctpMessage(StreamID(2), PPID(53), std::vector(kMessageSize)), kSendOptions); - ExchangeMessages(*sock_a_, cb_a_, *sock_z_, cb_z_); + ExchangeMessages(a, *z); - sock_a_->Send( + a.socket.Send( DcSctpMessage(StreamID(1), PPID(53), std::vector(kMessageSize)), kSendOptions); - ExchangeMessages(*sock_a_, cb_a_, *sock_z_, cb_z_); + ExchangeMessages(a, *z); - MaybeHandoverSocketZAndSendMessage(); + MaybeHandoverSocketAndSendMessage(a, std::move(z)); } TEST_P(DcSctpSocketParametrizedTest, @@ -1690,72 +1762,83 @@ TEST_P(DcSctpSocketParametrizedTest, static constexpr size_t kMessageSize = 1000; static constexpr size_t kBufferedAmountLowThreshold = kMessageSize * 1.5; - sock_a_->SetBufferedAmountLowThreshold(StreamID(1), - kBufferedAmountLowThreshold); - EXPECT_CALL(cb_a_, OnBufferedAmountLow).Times(0); - ConnectSockets(); - MaybeHandoverSocketZ(); + SocketUnderTest a("A"); + auto z = std::make_unique("Z"); - EXPECT_CALL(cb_a_, OnBufferedAmountLow).Times(0); + a.socket.SetBufferedAmountLowThreshold(StreamID(1), + kBufferedAmountLowThreshold); + EXPECT_CALL(a.cb, OnBufferedAmountLow).Times(0); + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + EXPECT_CALL(a.cb, OnBufferedAmountLow).Times(0); // Add a few messages to fill up the congestion window. When that is full, // messages will start to be fully buffered. - while (sock_a_->buffered_amount(StreamID(1)) <= kBufferedAmountLowThreshold) { - sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), + while (a.socket.buffered_amount(StreamID(1)) <= kBufferedAmountLowThreshold) { + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), std::vector(kMessageSize)), kSendOptions); } - size_t initial_buffered = sock_a_->buffered_amount(StreamID(1)); + size_t initial_buffered = a.socket.buffered_amount(StreamID(1)); ASSERT_GT(initial_buffered, kBufferedAmountLowThreshold); // Start ACKing packets, which will empty the send queue, and trigger the // callback. - EXPECT_CALL(cb_a_, OnBufferedAmountLow(StreamID(1))).Times(1); - ExchangeMessages(*sock_a_, cb_a_, *sock_z_, cb_z_); + EXPECT_CALL(a.cb, OnBufferedAmountLow(StreamID(1))).Times(1); + ExchangeMessages(a, *z); - MaybeHandoverSocketZAndSendMessage(); + MaybeHandoverSocketAndSendMessage(a, std::move(z)); } TEST_P(DcSctpSocketParametrizedTest, DoesntTriggerOnTotalBufferAmountLowWhenBelow) { - ConnectSockets(); - MaybeHandoverSocketZ(); + SocketUnderTest a("A"); + auto z = std::make_unique("Z"); - EXPECT_CALL(cb_a_, OnTotalBufferedAmountLow).Times(0); + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); - sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), + EXPECT_CALL(a.cb, OnTotalBufferedAmountLow).Times(0); + + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), std::vector(kLargeMessageSize)), kSendOptions); - ExchangeMessages(*sock_a_, cb_a_, *sock_z_, cb_z_); + ExchangeMessages(a, *z); - MaybeHandoverSocketZAndSendMessage(); + MaybeHandoverSocketAndSendMessage(a, std::move(z)); } TEST_P(DcSctpSocketParametrizedTest, TriggersOnTotalBufferAmountLowWhenCrossingThreshold) { - ConnectSockets(); - MaybeHandoverSocketZ(); + SocketUnderTest a("A"); + auto z = std::make_unique("Z"); - EXPECT_CALL(cb_a_, OnTotalBufferedAmountLow).Times(0); + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + EXPECT_CALL(a.cb, OnTotalBufferedAmountLow).Times(0); // Fill up the send queue completely. for (;;) { - if (sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), + if (a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), std::vector(kLargeMessageSize)), kSendOptions) == SendStatus::kErrorResourceExhaustion) { break; } } - EXPECT_CALL(cb_a_, OnTotalBufferedAmountLow).Times(1); - ExchangeMessages(*sock_a_, cb_a_, *sock_z_, cb_z_); + EXPECT_CALL(a.cb, OnTotalBufferedAmountLow).Times(1); + ExchangeMessages(a, *z); - MaybeHandoverSocketZAndSendMessage(); + MaybeHandoverSocketAndSendMessage(a, std::move(z)); } -TEST_F(DcSctpSocketTest, InitialMetricsAreZeroed) { - Metrics metrics = sock_a_->GetMetrics(); +TEST(DcSctpSocketTest, InitialMetricsAreZeroed) { + SocketUnderTest a("A"); + + Metrics metrics = a.socket.GetMetrics(); EXPECT_EQ(metrics.tx_packets_count, 0u); EXPECT_EQ(metrics.tx_messages_count, 0u); EXPECT_EQ(metrics.cwnd_bytes.has_value(), false); @@ -1766,85 +1849,90 @@ TEST_F(DcSctpSocketTest, InitialMetricsAreZeroed) { EXPECT_EQ(metrics.peer_rwnd_bytes.has_value(), false); } -TEST_F(DcSctpSocketTest, RxAndTxPacketMetricsIncrease) { - ConnectSockets(); +TEST(DcSctpSocketTest, RxAndTxPacketMetricsIncrease) { + SocketUnderTest a("A"); + SocketUnderTest z("Z"); - const size_t initial_a_rwnd = options_.max_receiver_window_buffer_size * + ConnectSockets(a, z); + + const size_t initial_a_rwnd = a.options.max_receiver_window_buffer_size * ReassemblyQueue::kHighWatermarkLimit; - EXPECT_EQ(sock_a_->GetMetrics().tx_packets_count, 2u); - EXPECT_EQ(sock_a_->GetMetrics().rx_packets_count, 2u); - EXPECT_EQ(sock_a_->GetMetrics().tx_messages_count, 0u); - EXPECT_EQ(*sock_a_->GetMetrics().cwnd_bytes, - options_.cwnd_mtus_initial * options_.mtu); - EXPECT_EQ(sock_a_->GetMetrics().unack_data_count, 0u); + EXPECT_EQ(a.socket.GetMetrics().tx_packets_count, 2u); + EXPECT_EQ(a.socket.GetMetrics().rx_packets_count, 2u); + EXPECT_EQ(a.socket.GetMetrics().tx_messages_count, 0u); + EXPECT_EQ(*a.socket.GetMetrics().cwnd_bytes, + a.options.cwnd_mtus_initial * a.options.mtu); + EXPECT_EQ(a.socket.GetMetrics().unack_data_count, 0u); - EXPECT_EQ(sock_z_->GetMetrics().rx_packets_count, 2u); - EXPECT_EQ(sock_z_->GetMetrics().rx_messages_count, 0u); + EXPECT_EQ(z.socket.GetMetrics().rx_packets_count, 2u); + EXPECT_EQ(z.socket.GetMetrics().rx_messages_count, 0u); - sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions); - EXPECT_EQ(sock_a_->GetMetrics().unack_data_count, 1u); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions); + EXPECT_EQ(a.socket.GetMetrics().unack_data_count, 1u); - sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket()); // DATA - sock_a_->ReceivePacket(cb_z_.ConsumeSentPacket()); // SACK - EXPECT_EQ(*sock_a_->GetMetrics().peer_rwnd_bytes, initial_a_rwnd); - EXPECT_EQ(sock_a_->GetMetrics().unack_data_count, 0u); + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); // DATA + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); // SACK + EXPECT_EQ(*a.socket.GetMetrics().peer_rwnd_bytes, initial_a_rwnd); + EXPECT_EQ(a.socket.GetMetrics().unack_data_count, 0u); - EXPECT_TRUE(cb_z_.ConsumeReceivedMessage().has_value()); + EXPECT_TRUE(z.cb.ConsumeReceivedMessage().has_value()); - EXPECT_EQ(sock_a_->GetMetrics().tx_packets_count, 3u); - EXPECT_EQ(sock_a_->GetMetrics().rx_packets_count, 3u); - EXPECT_EQ(sock_a_->GetMetrics().tx_messages_count, 1u); + EXPECT_EQ(a.socket.GetMetrics().tx_packets_count, 3u); + EXPECT_EQ(a.socket.GetMetrics().rx_packets_count, 3u); + EXPECT_EQ(a.socket.GetMetrics().tx_messages_count, 1u); - EXPECT_EQ(sock_z_->GetMetrics().rx_packets_count, 3u); - EXPECT_EQ(sock_z_->GetMetrics().rx_messages_count, 1u); + EXPECT_EQ(z.socket.GetMetrics().rx_packets_count, 3u); + EXPECT_EQ(z.socket.GetMetrics().rx_messages_count, 1u); // Send one more (large - fragmented), and receive the delayed SACK. - sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), - std::vector(options_.mtu * 2 + 1)), + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), + std::vector(a.options.mtu * 2 + 1)), kSendOptions); - EXPECT_EQ(sock_a_->GetMetrics().unack_data_count, 3u); + EXPECT_EQ(a.socket.GetMetrics().unack_data_count, 3u); - sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket()); // DATA - sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket()); // DATA + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); // DATA + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); // DATA - sock_a_->ReceivePacket(cb_z_.ConsumeSentPacket()); // SACK - EXPECT_EQ(sock_a_->GetMetrics().unack_data_count, 1u); - EXPECT_GT(*sock_a_->GetMetrics().peer_rwnd_bytes, 0u); - EXPECT_LT(*sock_a_->GetMetrics().peer_rwnd_bytes, initial_a_rwnd); + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); // SACK + EXPECT_EQ(a.socket.GetMetrics().unack_data_count, 1u); + EXPECT_GT(*a.socket.GetMetrics().peer_rwnd_bytes, 0u); + EXPECT_LT(*a.socket.GetMetrics().peer_rwnd_bytes, initial_a_rwnd); - sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket()); // DATA + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); // DATA - EXPECT_TRUE(cb_z_.ConsumeReceivedMessage().has_value()); + EXPECT_TRUE(z.cb.ConsumeReceivedMessage().has_value()); - EXPECT_EQ(sock_a_->GetMetrics().tx_packets_count, 6u); - EXPECT_EQ(sock_a_->GetMetrics().rx_packets_count, 4u); - EXPECT_EQ(sock_a_->GetMetrics().tx_messages_count, 2u); + EXPECT_EQ(a.socket.GetMetrics().tx_packets_count, 6u); + EXPECT_EQ(a.socket.GetMetrics().rx_packets_count, 4u); + EXPECT_EQ(a.socket.GetMetrics().tx_messages_count, 2u); - EXPECT_EQ(sock_z_->GetMetrics().rx_packets_count, 6u); - EXPECT_EQ(sock_z_->GetMetrics().rx_messages_count, 2u); + EXPECT_EQ(z.socket.GetMetrics().rx_packets_count, 6u); + EXPECT_EQ(z.socket.GetMetrics().rx_messages_count, 2u); // Delayed sack - AdvanceTime(options_.delayed_ack_max_timeout); - RunTimers(); + AdvanceTime(a, z, a.options.delayed_ack_max_timeout); - sock_a_->ReceivePacket(cb_z_.ConsumeSentPacket()); // SACK - EXPECT_EQ(sock_a_->GetMetrics().unack_data_count, 0u); - EXPECT_EQ(sock_a_->GetMetrics().rx_packets_count, 5u); - EXPECT_EQ(*sock_a_->GetMetrics().peer_rwnd_bytes, initial_a_rwnd); + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); // SACK + EXPECT_EQ(a.socket.GetMetrics().unack_data_count, 0u); + EXPECT_EQ(a.socket.GetMetrics().rx_packets_count, 5u); + EXPECT_EQ(*a.socket.GetMetrics().peer_rwnd_bytes, initial_a_rwnd); } TEST_P(DcSctpSocketParametrizedTest, UnackDataAlsoIncludesSendQueue) { - ConnectSockets(); - MaybeHandoverSocketZ(); + SocketUnderTest a("A"); + auto z = std::make_unique("Z"); - sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), std::vector(kLargeMessageSize)), kSendOptions); size_t payload_bytes = - options_.mtu - SctpPacket::kHeaderSize - DataChunk::kHeaderSize; + a.options.mtu - SctpPacket::kHeaderSize - DataChunk::kHeaderSize; - size_t expected_sent_packets = options_.cwnd_mtus_initial; + size_t expected_sent_packets = a.options.cwnd_mtus_initial; size_t expected_queued_bytes = kLargeMessageSize - expected_sent_packets * payload_bytes; @@ -1853,42 +1941,48 @@ TEST_P(DcSctpSocketParametrizedTest, UnackDataAlsoIncludesSendQueue) { // Due to alignment, padding etc, it's hard to calculate the exact number, but // it should be in this range. - EXPECT_GE(sock_a_->GetMetrics().unack_data_count, + EXPECT_GE(a.socket.GetMetrics().unack_data_count, expected_sent_packets + expected_queued_packets); - EXPECT_LE(sock_a_->GetMetrics().unack_data_count, + EXPECT_LE(a.socket.GetMetrics().unack_data_count, expected_sent_packets + expected_queued_packets + 2); - MaybeHandoverSocketZAndSendMessage(); + MaybeHandoverSocketAndSendMessage(a, std::move(z)); } TEST_P(DcSctpSocketParametrizedTest, DoesntSendMoreThanMaxBurstPackets) { - ConnectSockets(); - MaybeHandoverSocketZ(); + SocketUnderTest a("A"); + auto z = std::make_unique("Z"); - sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), std::vector(kLargeMessageSize)), kSendOptions); for (int i = 0; i < kMaxBurstPackets; ++i) { - std::vector packet = cb_a_.ConsumeSentPacket(); + std::vector packet = a.cb.ConsumeSentPacket(); EXPECT_THAT(packet, Not(IsEmpty())); - sock_z_->ReceivePacket(std::move(packet)); // DATA + z->socket.ReceivePacket(std::move(packet)); // DATA } - EXPECT_THAT(cb_a_.ConsumeSentPacket(), IsEmpty()); + EXPECT_THAT(a.cb.ConsumeSentPacket(), IsEmpty()); - ExchangeMessages(*sock_a_, cb_a_, *sock_z_, cb_z_); - MaybeHandoverSocketZAndSendMessage(); + ExchangeMessages(a, *z); + MaybeHandoverSocketAndSendMessage(a, std::move(z)); } TEST_P(DcSctpSocketParametrizedTest, SendsOnlyLargePackets) { - ConnectSockets(); - MaybeHandoverSocketZ(); + SocketUnderTest a("A"); + auto z = std::make_unique("Z"); + + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); // A really large message, to ensure that the congestion window is often full. constexpr size_t kMessageSize = 100000; - sock_a_->Send( + a.socket.Send( DcSctpMessage(StreamID(1), PPID(53), std::vector(kMessageSize)), kSendOptions); @@ -1896,21 +1990,21 @@ TEST_P(DcSctpSocketParametrizedTest, SendsOnlyLargePackets) { std::vector data_packet_sizes; do { delivered_packet = false; - std::vector packet_from_a = cb_a_.ConsumeSentPacket(); + std::vector packet_from_a = a.cb.ConsumeSentPacket(); if (!packet_from_a.empty()) { data_packet_sizes.push_back(packet_from_a.size()); delivered_packet = true; - sock_z_->ReceivePacket(std::move(packet_from_a)); + z->socket.ReceivePacket(std::move(packet_from_a)); } - std::vector packet_from_z = cb_z_.ConsumeSentPacket(); + std::vector packet_from_z = z->cb.ConsumeSentPacket(); if (!packet_from_z.empty()) { delivered_packet = true; - sock_a_->ReceivePacket(std::move(packet_from_z)); + a.socket.ReceivePacket(std::move(packet_from_z)); } } while (delivered_packet); size_t packet_payload_bytes = - options_.mtu - SctpPacket::kHeaderSize - DataChunk::kHeaderSize; + a.options.mtu - SctpPacket::kHeaderSize - DataChunk::kHeaderSize; // +1 accounts for padding, and rounding up. size_t expected_packets = (kMessageSize + packet_payload_bytes - 1) / packet_payload_bytes + 1; @@ -1922,158 +2016,168 @@ TEST_P(DcSctpSocketParametrizedTest, SendsOnlyLargePackets) { for (size_t size : data_packet_sizes) { // The 4 is for padding/alignment. - EXPECT_GE(size, options_.mtu - 4); + EXPECT_GE(size, a.options.mtu - 4); } - MaybeHandoverSocketZAndSendMessage(); + MaybeHandoverSocketAndSendMessage(a, std::move(z)); } TEST_P(DcSctpSocketParametrizedTest, DoesntBundleForwardTsnWithData) { - ConnectSockets(); - MaybeHandoverSocketZ(); + SocketUnderTest a("A"); + auto z = std::make_unique("Z"); + + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); // Force an RTT measurement using heartbeats. - AdvanceTime(options_.heartbeat_interval); - RunTimers(); + AdvanceTime(a, *z, a.options.heartbeat_interval); // HEARTBEAT - std::vector hb_req_a = cb_a_.ConsumeSentPacket(); - std::vector hb_req_z = cb_z_.ConsumeSentPacket(); + std::vector hb_req_a = a.cb.ConsumeSentPacket(); + std::vector hb_req_z = z->cb.ConsumeSentPacket(); constexpr DurationMs kRtt = DurationMs(80); - AdvanceTime(kRtt); - sock_z_->ReceivePacket(hb_req_a); - sock_a_->ReceivePacket(hb_req_z); + AdvanceTime(a, *z, kRtt); + z->socket.ReceivePacket(hb_req_a); + a.socket.ReceivePacket(hb_req_z); // HEARTBEAT_ACK - sock_a_->ReceivePacket(cb_z_.ConsumeSentPacket()); - sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket()); + a.socket.ReceivePacket(z->cb.ConsumeSentPacket()); + z->socket.ReceivePacket(a.cb.ConsumeSentPacket()); SendOptions send_options; send_options.max_retransmissions = 0; - std::vector payload(options_.mtu - 100); + std::vector payload(a.options.mtu - 100); // Send an initial message that is received, but the SACK was lost - sock_a_->Send(DcSctpMessage(StreamID(1), PPID(51), payload), send_options); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(51), payload), send_options); // DATA - sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket()); + z->socket.ReceivePacket(a.cb.ConsumeSentPacket()); // SACK (lost) - std::vector sack = cb_z_.ConsumeSentPacket(); + std::vector sack = z->cb.ConsumeSentPacket(); // Queue enough messages to fill the congestion window. do { - sock_a_->Send(DcSctpMessage(StreamID(1), PPID(51), payload), send_options); - } while (!cb_a_.ConsumeSentPacket().empty()); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(51), payload), send_options); + } while (!a.cb.ConsumeSentPacket().empty()); // Enqueue at least one more. - sock_a_->Send(DcSctpMessage(StreamID(1), PPID(51), payload), send_options); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(51), payload), send_options); // Let all of them expire by T3-RTX and inspect what's sent. - AdvanceTime(options_.rto_initial); - RunTimers(); + AdvanceTime(a, *z, a.options.rto_initial); - std::vector sent1 = cb_a_.ConsumeSentPacket(); + std::vector sent1 = a.cb.ConsumeSentPacket(); ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet1, SctpPacket::Parse(sent1)); EXPECT_THAT(packet1.descriptors(), SizeIs(1)); EXPECT_EQ(packet1.descriptors()[0].type, ForwardTsnChunk::kType); - std::vector sent2 = cb_a_.ConsumeSentPacket(); + std::vector sent2 = a.cb.ConsumeSentPacket(); ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet2, SctpPacket::Parse(sent2)); EXPECT_GE(packet2.descriptors().size(), 1u); EXPECT_EQ(packet2.descriptors()[0].type, DataChunk::kType); // Drop all remaining packets that A has sent. - while (!cb_a_.ConsumeSentPacket().empty()) { + while (!a.cb.ConsumeSentPacket().empty()) { } // Replay the SACK, and see if a FORWARD-TSN is sent again. - sock_a_->ReceivePacket(sack); + a.socket.ReceivePacket(sack); // It shouldn't be sent as not enough time has passed yet. Instead, more // DATA chunks are sent, that are in the queue. - std::vector sent3 = cb_a_.ConsumeSentPacket(); + std::vector sent3 = a.cb.ConsumeSentPacket(); ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet3, SctpPacket::Parse(sent3)); EXPECT_GE(packet2.descriptors().size(), 1u); EXPECT_EQ(packet3.descriptors()[0].type, DataChunk::kType); // Now let RTT time pass, to allow a FORWARD-TSN to be sent again. - AdvanceTime(kRtt); - sock_a_->ReceivePacket(sack); + AdvanceTime(a, *z, kRtt); + a.socket.ReceivePacket(sack); - std::vector sent4 = cb_a_.ConsumeSentPacket(); + std::vector sent4 = a.cb.ConsumeSentPacket(); ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet4, SctpPacket::Parse(sent4)); EXPECT_THAT(packet4.descriptors(), SizeIs(1)); EXPECT_EQ(packet4.descriptors()[0].type, ForwardTsnChunk::kType); } -TEST_F(DcSctpSocketTest, SendMessagesAfterHandover) { - ConnectSockets(); +TEST(DcSctpSocketTest, SendMessagesAfterHandover) { + SocketUnderTest a("A"); + auto z = std::make_unique("Z"); + + ConnectSockets(a, *z); // Send message before handover to move socket to a not initial state - sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions); - sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket()); - cb_z_.ConsumeReceivedMessage(); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions); + z->socket.ReceivePacket(a.cb.ConsumeSentPacket()); + z->cb.ConsumeReceivedMessage(); - HandoverSocketZ(); + z = HandoverSocket(std::move(z)); absl::optional msg; RTC_LOG(LS_INFO) << "Sending A #1"; - sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), {3, 4}), kSendOptions); - sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket()); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {3, 4}), kSendOptions); + z->socket.ReceivePacket(a.cb.ConsumeSentPacket()); - msg = cb_z_.ConsumeReceivedMessage(); + msg = z->cb.ConsumeReceivedMessage(); ASSERT_TRUE(msg.has_value()); EXPECT_EQ(msg->stream_id(), StreamID(1)); EXPECT_THAT(msg->payload(), testing::ElementsAre(3, 4)); RTC_LOG(LS_INFO) << "Sending A #2"; - sock_a_->Send(DcSctpMessage(StreamID(2), PPID(53), {5, 6}), kSendOptions); - sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket()); + a.socket.Send(DcSctpMessage(StreamID(2), PPID(53), {5, 6}), kSendOptions); + z->socket.ReceivePacket(a.cb.ConsumeSentPacket()); - msg = cb_z_.ConsumeReceivedMessage(); + msg = z->cb.ConsumeReceivedMessage(); ASSERT_TRUE(msg.has_value()); EXPECT_EQ(msg->stream_id(), StreamID(2)); EXPECT_THAT(msg->payload(), testing::ElementsAre(5, 6)); RTC_LOG(LS_INFO) << "Sending Z #1"; - sock_z_->Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2, 3}), kSendOptions); - sock_a_->ReceivePacket(cb_z_.ConsumeSentPacket()); // ack - sock_a_->ReceivePacket(cb_z_.ConsumeSentPacket()); // data + z->socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2, 3}), kSendOptions); + a.socket.ReceivePacket(z->cb.ConsumeSentPacket()); // ack + a.socket.ReceivePacket(z->cb.ConsumeSentPacket()); // data - msg = cb_a_.ConsumeReceivedMessage(); + msg = a.cb.ConsumeReceivedMessage(); ASSERT_TRUE(msg.has_value()); EXPECT_EQ(msg->stream_id(), StreamID(1)); EXPECT_THAT(msg->payload(), testing::ElementsAre(1, 2, 3)); } -TEST_F(DcSctpSocketTest, CanDetectDcsctpImplementation) { - ConnectSockets(); +TEST(DcSctpSocketTest, CanDetectDcsctpImplementation) { + SocketUnderTest a("A"); + SocketUnderTest z("Z"); - EXPECT_EQ(sock_a_->peer_implementation(), SctpImplementation::kDcsctp); + ConnectSockets(a, z); + + EXPECT_EQ(a.socket.peer_implementation(), SctpImplementation::kDcsctp); // As A initiated the connection establishment, Z will not receive enough // information to know about A's implementation - EXPECT_EQ(sock_z_->peer_implementation(), SctpImplementation::kUnknown); + EXPECT_EQ(z.socket.peer_implementation(), SctpImplementation::kUnknown); } -TEST_F(DcSctpSocketTest, BothCanDetectDcsctpImplementation) { - EXPECT_CALL(cb_a_, OnConnected).Times(1); - EXPECT_CALL(cb_z_, OnConnected).Times(1); - sock_a_->Connect(); - sock_z_->Connect(); +TEST(DcSctpSocketTest, BothCanDetectDcsctpImplementation) { + SocketUnderTest a("A"); + SocketUnderTest z("Z"); - ExchangeMessages(*sock_a_, cb_a_, *sock_z_, cb_z_); + EXPECT_CALL(a.cb, OnConnected).Times(1); + EXPECT_CALL(z.cb, OnConnected).Times(1); + a.socket.Connect(); + z.socket.Connect(); - EXPECT_EQ(sock_a_->peer_implementation(), SctpImplementation::kDcsctp); - EXPECT_EQ(sock_z_->peer_implementation(), SctpImplementation::kDcsctp); + ExchangeMessages(a, z); + + EXPECT_EQ(a.socket.peer_implementation(), SctpImplementation::kDcsctp); + EXPECT_EQ(z.socket.peer_implementation(), SctpImplementation::kDcsctp); } } // namespace } // namespace dcsctp diff --git a/net/dcsctp/socket/mock_dcsctp_socket_callbacks.h b/net/dcsctp/socket/mock_dcsctp_socket_callbacks.h index 1e30777e89..803f688a84 100644 --- a/net/dcsctp/socket/mock_dcsctp_socket_callbacks.h +++ b/net/dcsctp/socket/mock_dcsctp_socket_callbacks.h @@ -153,12 +153,6 @@ class MockDcSctpSocketCallbacks : public DcSctpSocketCallbacks { return timeout_manager_.GetNextExpiredTimeout(); } - void Reset() { - sent_packets_.clear(); - received_messages_.clear(); - timeout_manager_.Reset(); - } - private: const std::string log_prefix_; TimeMs now_ = TimeMs(0); diff --git a/net/dcsctp/timer/fake_timeout.h b/net/dcsctp/timer/fake_timeout.h index f87227577b..74ffe5af29 100644 --- a/net/dcsctp/timer/fake_timeout.h +++ b/net/dcsctp/timer/fake_timeout.h @@ -97,8 +97,6 @@ class FakeTimeoutManager { return absl::nullopt; } - void Reset() { timers_.clear(); } - private: const std::function get_time_; webrtc::flat_set timers_;