dcsctp: implement socket handover in the DcSctpSocket class and expose the functionality in the API

Bug: webrtc:13154
Change-Id: Idf4f4028c8e65943cb6b41fab0baef1b3584205d
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/232126
Reviewed-by: Victor Boivie <boivie@webrtc.org>
Commit-Queue: Sergey Sukhanov <sergeysu@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#35029}
This commit is contained in:
Sergey Sukhanov 2021-09-17 15:32:48 +02:00 committed by WebRTC LUCI CQ
parent 4893dbe7f1
commit 4397281f38
10 changed files with 492 additions and 69 deletions

View File

@ -24,6 +24,25 @@ namespace dcsctp {
// for serialization. Serialization is not provided by dcSCTP. If needed it has
// to be implemented in the calling client.
struct DcSctpSocketHandoverState {
enum class SocketState {
kClosed,
kConnected,
};
SocketState socket_state = SocketState::kClosed;
uint32_t my_verification_tag = 0;
uint32_t my_initial_tsn = 0;
uint32_t peer_verification_tag = 0;
uint32_t peer_initial_tsn = 0;
uint64_t tie_tag = 0;
struct Capabilities {
bool partial_reliability = false;
bool message_interleaving = false;
bool reconfig = false;
};
Capabilities capabilities;
struct Transmission {
uint32_t next_tsn = 0;
uint32_t next_reset_req_sn = 0;
@ -98,6 +117,7 @@ class HandoverReadinessStatus
value() |= status.value();
return *this;
}
std::string ToString() const;
};
} // namespace dcsctp

View File

@ -17,6 +17,7 @@
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "api/array_view.h"
#include "net/dcsctp/public/dcsctp_handover_state.h"
#include "net/dcsctp/public/dcsctp_message.h"
#include "net/dcsctp/public/dcsctp_options.h"
#include "net/dcsctp/public/packet_observer.h"
@ -355,6 +356,14 @@ class DcSctpSocketInterface {
// `DcSctpSocketCallbacks::OnConnected` will be called on success.
virtual void Connect() = 0;
// Puts this socket to the state in which the original socket was when its
// `DcSctpSocketHandoverState` was captured by `GetHandoverStateAndClose`.
// `RestoreFromState` is allowed only on the closed socket.
// `DcSctpSocketCallbacks::OnConnected` will be called if a connected socket
// state is restored.
// `DcSctpSocketCallbacks::OnError` will be called on error.
virtual void RestoreFromState(const DcSctpSocketHandoverState& state) = 0;
// Gracefully shutdowns the socket and sends all outstanding data. This is an
// asynchronous operation and `DcSctpSocketCallbacks::OnClosed` will be called
// on success.
@ -417,6 +426,20 @@ class DcSctpSocketInterface {
// Retrieves the latest metrics.
virtual Metrics GetMetrics() const = 0;
// Returns empty bitmask if the socket is in the state in which a snapshot of
// the state can be made by `GetHandoverStateAndClose()`. Return value is
// invalidated by a call to any non-const method.
virtual HandoverReadinessStatus GetHandoverReadiness() const = 0;
// Collects a snapshot of the socket state that can be used to reconstruct
// this socket in another process. On success this socket object is closed
// synchronously and no callbacks will be made after the method has returned.
// The method fails if the socket is not in a state ready for handover.
// nullopt indicates the failure. `DcSctpSocketCallbacks::OnClosed` will be
// called on success.
virtual absl::optional<DcSctpSocketHandoverState>
GetHandoverStateAndClose() = 0;
};
} // namespace dcsctp

View File

@ -26,6 +26,11 @@ class MockDcSctpSocket : public DcSctpSocketInterface {
MOCK_METHOD(void, Connect, (), (override));
MOCK_METHOD(void,
RestoreFromState,
(const DcSctpSocketHandoverState&),
(override));
MOCK_METHOD(void, Shutdown, (), (override));
MOCK_METHOD(void, Close, (), (override));
@ -59,6 +64,15 @@ class MockDcSctpSocket : public DcSctpSocketInterface {
(override));
MOCK_METHOD(Metrics, GetMetrics, (), (const, override));
MOCK_METHOD(HandoverReadinessStatus,
GetHandoverReadiness,
(),
(const, override));
MOCK_METHOD(absl::optional<DcSctpSocketHandoverState>,
GetHandoverStateAndClose,
(),
(override));
};
} // namespace dcsctp

View File

@ -139,8 +139,57 @@ TieTag MakeTieTag(DcSctpSocketCallbacks& cb) {
static_cast<uint64_t>(tie_tag_lower));
}
constexpr absl::string_view HandoverUnreadinessReasonToString(
HandoverUnreadinessReason reason) {
switch (reason) {
case HandoverUnreadinessReason::kWrongConnectionState:
return "WRONG_CONNECTION_STATE";
case HandoverUnreadinessReason::kSendQueueNotEmpty:
return "SEND_QUEUE_NOT_EMPTY";
case HandoverUnreadinessReason::kDataTrackerTsnBlocksPending:
return "DATA_TRACKER_TSN_BLOCKS_PENDING";
case HandoverUnreadinessReason::kReassemblyQueueDeliveredTSNsGap:
return "REASSEMBLY_QUEUE_DELIVERED_TSN_GAP";
case HandoverUnreadinessReason::kStreamResetDeferred:
return "STREAM_RESET_DEFERRED";
case HandoverUnreadinessReason::kOrderedStreamHasUnassembledChunks:
return "ORDERED_STREAM_HAS_UNASSEMBLED_CHUNKS";
case HandoverUnreadinessReason::kUnorderedStreamHasUnassembledChunks:
return "UNORDERED_STREAM_HAS_UNASSEMBLED_CHUNKS";
case HandoverUnreadinessReason::kRetransmissionQueueOutstandingData:
return "RETRANSMISSION_QUEUE_OUTSTANDING_DATA";
case HandoverUnreadinessReason::kRetransmissionQueueFastRecovery:
return "RETRANSMISSION_QUEUE_FAST_RECOVERY";
case HandoverUnreadinessReason::kRetransmissionQueueNotEmpty:
return "RETRANSMISSION_QUEUE_NOT_EMPTY";
case HandoverUnreadinessReason::kPendingStreamReset:
return "PENDING_STREAM_RESET";
case HandoverUnreadinessReason::kPendingStreamResetRequest:
return "PENDING_STREAM_RESET_REQUEST";
}
}
} // namespace
std::string HandoverReadinessStatus::ToString() const {
std::string result;
for (uint32_t bit = 1;
bit <= static_cast<uint32_t>(HandoverUnreadinessReason::kMax);
bit *= 2) {
auto flag = static_cast<HandoverUnreadinessReason>(bit);
if (Contains(flag)) {
if (!result.empty()) {
result.append(",");
}
absl::string_view s = HandoverUnreadinessReasonToString(flag);
result.append(s.data(), s.size());
}
}
if (result.empty()) {
result = "READY";
}
return result;
}
DcSctpSocket::DcSctpSocket(absl::string_view log_prefix,
DcSctpSocketCallbacks& callbacks,
std::unique_ptr<PacketObserver> packet_observer,
@ -286,6 +335,42 @@ void DcSctpSocket::Connect() {
callbacks_.TriggerDeferred();
}
void DcSctpSocket::RestoreFromState(const DcSctpSocketHandoverState& state) {
if (state_ != State::kClosed) {
callbacks_.OnError(ErrorKind::kUnsupportedOperation,
"Only closed socket can be restored from state");
} else {
if (state.socket_state ==
DcSctpSocketHandoverState::SocketState::kConnected) {
VerificationTag my_verification_tag =
VerificationTag(state.my_verification_tag);
connect_params_.verification_tag = my_verification_tag;
Capabilities capabilities;
capabilities.partial_reliability = state.capabilities.partial_reliability;
capabilities.message_interleaving =
state.capabilities.message_interleaving;
capabilities.reconfig = state.capabilities.reconfig;
tcb_ = std::make_unique<TransmissionControlBlock>(
timer_manager_, log_prefix_, options_, capabilities, callbacks_,
send_queue_, my_verification_tag, TSN(state.my_initial_tsn),
VerificationTag(state.peer_verification_tag),
TSN(state.peer_initial_tsn), static_cast<size_t>(0),
TieTag(state.tie_tag), packet_sender_,
[this]() { return state_ == State::kEstablished; }, &state);
RTC_DLOG(LS_VERBOSE) << log_prefix() << "Created peer TCB from state: "
<< tcb_->ToString();
SetState(State::kEstablished, "restored from handover state");
callbacks_.OnConnected();
}
}
RTC_DCHECK(IsConsistent());
callbacks_.TriggerDeferred();
}
void DcSctpSocket::Shutdown() {
if (tcb_ != nullptr) {
// https://tools.ietf.org/html/rfc4960#section-9.2
@ -1579,4 +1664,38 @@ void DcSctpSocket::SendShutdownAck() {
t2_shutdown_->Start();
}
HandoverReadinessStatus DcSctpSocket::GetHandoverReadiness() const {
HandoverReadinessStatus status;
if (state_ != State::kClosed && state_ != State::kEstablished) {
status.Add(HandoverUnreadinessReason::kWrongConnectionState);
}
if (!send_queue_.IsEmpty()) {
status.Add(HandoverUnreadinessReason::kSendQueueNotEmpty);
}
if (tcb_) {
status.Add(tcb_->GetHandoverReadiness());
}
return status;
}
absl::optional<DcSctpSocketHandoverState>
DcSctpSocket::GetHandoverStateAndClose() {
if (!GetHandoverReadiness().IsReady()) {
return absl::nullopt;
}
DcSctpSocketHandoverState state;
if (state_ == State::kClosed) {
state.socket_state = DcSctpSocketHandoverState::SocketState::kClosed;
} else if (state_ == State::kEstablished) {
state.socket_state = DcSctpSocketHandoverState::SocketState::kConnected;
tcb_->AddHandoverState(state);
InternalClose(ErrorKind::kNoError, "handover");
callbacks_.TriggerDeferred();
}
return std::move(state);
}
} // namespace dcsctp

View File

@ -85,6 +85,7 @@ class DcSctpSocket : public DcSctpSocketInterface {
void ReceivePacket(rtc::ArrayView<const uint8_t> data) override;
void HandleTimeout(TimeoutID timeout_id) override;
void Connect() override;
void RestoreFromState(const DcSctpSocketHandoverState& state) override;
void Shutdown() override;
void Close() override;
SendStatus Send(DcSctpMessage message,
@ -98,6 +99,8 @@ class DcSctpSocket : public DcSctpSocketInterface {
size_t buffered_amount_low_threshold(StreamID stream_id) const override;
void SetBufferedAmountLowThreshold(StreamID stream_id, size_t bytes) override;
Metrics GetMetrics() const override;
HandoverReadinessStatus GetHandoverReadiness() const override;
absl::optional<DcSctpSocketHandoverState> GetHandoverStateAndClose() override;
// Returns this socket's verification tag, or zero if not yet connected.
VerificationTag verification_tag() const {

View File

@ -315,6 +315,24 @@ class DcSctpSocketTest : public testing::Test {
EXPECT_EQ(sock_z_->state(), SocketState::kConnected);
}
void HandoverSocketZ() {
ASSERT_EQ(sock_z_->GetHandoverReadiness(), HandoverReadinessStatus());
bool is_closed = sock_z_->state() == SocketState::kClosed;
if (!is_closed) {
EXPECT_CALL(cb_z_, OnClosed).Times(1);
}
absl::optional<DcSctpSocketHandoverState> handover_state =
sock_z_->GetHandoverStateAndClose();
EXPECT_TRUE(handover_state.has_value());
cb_z_.Reset();
sock_z_ = std::make_unique<DcSctpSocket>("Z", cb_z_, GetPacketObserver("Z"),
options_);
if (!is_closed) {
EXPECT_CALL(cb_z_, OnConnected).Times(1);
}
sock_z_->RestoreFromState(*handover_state);
}
const DcSctpOptions options_;
testing::NiceMock<MockDcSctpSocketCallbacks> cb_a_;
testing::NiceMock<MockDcSctpSocketCallbacks> cb_z_;
@ -322,6 +340,52 @@ class DcSctpSocketTest : public testing::Test {
std::unique_ptr<DcSctpSocket> sock_z_;
};
// Test parameter that controls whether to perform handovers during the test. A
// test can have multiple points where it conditionally hands over socket Z.
// Either socket Z will be handed over at all those points or handed over never.
enum class HandoverMode {
kNoHandover,
kPerformHandovers,
};
class DcSctpSocketParametrizedTest
: public DcSctpSocketTest,
public ::testing::WithParamInterface<HandoverMode> {
protected:
// Trigger handover for socket Z depending on the current test param.
void MaybeHandoverSocketZ() {
if (GetParam() == HandoverMode::kPerformHandovers) {
HandoverSocketZ();
}
}
// Trigger handover for socket Z depending on the current test param.
// Then checks message passing to verify the handed over socket is functional.
void MaybeHandoverSocketZAndSendMessage() {
if (GetParam() == HandoverMode::kPerformHandovers) {
HandoverSocketZ();
}
ExchangeMessages(*sock_a_, cb_a_, *sock_z_, cb_z_);
sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions);
ExchangeMessages(*sock_a_, cb_a_, *sock_z_, cb_z_);
absl::optional<DcSctpMessage> msg = cb_z_.ConsumeReceivedMessage();
ASSERT_TRUE(msg.has_value());
EXPECT_EQ(msg->stream_id(), StreamID(1));
}
};
INSTANTIATE_TEST_SUITE_P(Handovers,
DcSctpSocketParametrizedTest,
testing::Values(HandoverMode::kNoHandover,
HandoverMode::kPerformHandovers),
[](const auto& test_info) {
return test_info.param ==
HandoverMode::kPerformHandovers
? "WithHandovers"
: "NoHandover";
});
TEST_F(DcSctpSocketTest, EstablishConnection) {
EXPECT_CALL(cb_a_, OnConnected).Times(1);
EXPECT_CALL(cb_z_, OnConnected).Times(1);
@ -566,8 +630,8 @@ TEST_F(DcSctpSocketTest, ResendingCookieEchoTooManyTimesAborts) {
TEST_F(DcSctpSocketTest, DoesntSendMorePacketsUntilCookieAckHasBeenReceived) {
sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53),
std::vector<uint8_t>(kLargeMessageSize)),
kSendOptions);
std::vector<uint8_t>(kLargeMessageSize)),
kSendOptions);
sock_a_->Connect();
// Z reads INIT, produces INIT_ACK
@ -623,11 +687,13 @@ TEST_F(DcSctpSocketTest, DoesntSendMorePacketsUntilCookieAckHasBeenReceived) {
SizeIs(kLargeMessageSize));
}
TEST_F(DcSctpSocketTest, ShutdownConnection) {
TEST_P(DcSctpSocketParametrizedTest, ShutdownConnection) {
ConnectSockets();
MaybeHandoverSocketZ();
RTC_LOG(LS_INFO) << "Shutting down";
EXPECT_CALL(cb_z_, OnClosed).Times(1);
sock_a_->Shutdown();
// Z reads SHUTDOWN, produces SHUTDOWN_ACK
sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket());
@ -638,6 +704,9 @@ TEST_F(DcSctpSocketTest, ShutdownConnection) {
EXPECT_EQ(sock_a_->state(), SocketState::kClosed);
EXPECT_EQ(sock_z_->state(), SocketState::kClosed);
MaybeHandoverSocketZ();
EXPECT_EQ(sock_z_->state(), SocketState::kClosed);
}
TEST_F(DcSctpSocketTest, ShutdownTimerExpiresTooManyTimeClosesConnection) {
@ -704,8 +773,9 @@ TEST_F(DcSctpSocketTest, SendMessageAfterEstablished) {
EXPECT_EQ(msg->stream_id(), StreamID(1));
}
TEST_F(DcSctpSocketTest, TimeoutResendsPacket) {
TEST_P(DcSctpSocketParametrizedTest, TimeoutResendsPacket) {
ConnectSockets();
MaybeHandoverSocketZ();
sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions);
cb_a_.ConsumeSentPacket();
@ -719,10 +789,13 @@ TEST_F(DcSctpSocketTest, TimeoutResendsPacket) {
absl::optional<DcSctpMessage> msg = cb_z_.ConsumeReceivedMessage();
ASSERT_TRUE(msg.has_value());
EXPECT_EQ(msg->stream_id(), StreamID(1));
MaybeHandoverSocketZAndSendMessage();
}
TEST_F(DcSctpSocketTest, SendALotOfBytesMissedSecondPacket) {
TEST_P(DcSctpSocketParametrizedTest, SendALotOfBytesMissedSecondPacket) {
ConnectSockets();
MaybeHandoverSocketZ();
std::vector<uint8_t> payload(kLargeMessageSize);
sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), payload), kSendOptions);
@ -739,10 +812,13 @@ TEST_F(DcSctpSocketTest, SendALotOfBytesMissedSecondPacket) {
ASSERT_TRUE(msg.has_value());
EXPECT_EQ(msg->stream_id(), StreamID(1));
EXPECT_THAT(msg->payload(), testing::ElementsAreArray(payload));
MaybeHandoverSocketZAndSendMessage();
}
TEST_F(DcSctpSocketTest, SendingHeartbeatAnswersWithAck) {
TEST_P(DcSctpSocketParametrizedTest, SendingHeartbeatAnswersWithAck) {
ConnectSockets();
MaybeHandoverSocketZ();
// Inject a HEARTBEAT chunk
SctpPacket::Builder b(sock_a_->verification_tag(), DcSctpOptions());
@ -761,10 +837,13 @@ TEST_F(DcSctpSocketTest, SendingHeartbeatAnswersWithAck) {
HeartbeatAckChunk::Parse(ack_packet.descriptors()[0].data));
ASSERT_HAS_VALUE_AND_ASSIGN(HeartbeatInfoParameter info_param, ack.info());
EXPECT_THAT(info_param.info(), ElementsAre(1, 2, 3, 4));
MaybeHandoverSocketZAndSendMessage();
}
TEST_F(DcSctpSocketTest, ExpectHeartbeatToBeSent) {
TEST_P(DcSctpSocketParametrizedTest, ExpectHeartbeatToBeSent) {
ConnectSockets();
MaybeHandoverSocketZ();
EXPECT_THAT(cb_a_.ConsumeSentPacket(), IsEmpty());
@ -786,11 +865,16 @@ TEST_F(DcSctpSocketTest, ExpectHeartbeatToBeSent) {
// Feed it to Sock-z and expect a HEARTBEAT_ACK that will be propagated back.
sock_z_->ReceivePacket(hb_packet_raw);
sock_a_->ReceivePacket(cb_z_.ConsumeSentPacket());
MaybeHandoverSocketZAndSendMessage();
}
TEST_F(DcSctpSocketTest, CloseConnectionAfterTooManyLostHeartbeats) {
TEST_P(DcSctpSocketParametrizedTest,
CloseConnectionAfterTooManyLostHeartbeats) {
ConnectSockets();
MaybeHandoverSocketZ();
EXPECT_CALL(cb_z_, OnClosed).Times(1);
EXPECT_THAT(cb_a_.ConsumeSentPacket(), testing::IsEmpty());
// Force-close socket Z so that it doesn't interfere from now on.
sock_z_->Close();
@ -825,12 +909,16 @@ TEST_F(DcSctpSocketTest, CloseConnectionAfterTooManyLostHeartbeats) {
// Should suffice as exceeding RTO
AdvanceTime(DurationMs(1000));
RunTimers();
MaybeHandoverSocketZ();
}
TEST_F(DcSctpSocketTest, RecoversAfterASuccessfulAck) {
TEST_P(DcSctpSocketParametrizedTest, RecoversAfterASuccessfulAck) {
ConnectSockets();
MaybeHandoverSocketZ();
EXPECT_THAT(cb_a_.ConsumeSentPacket(), testing::IsEmpty());
EXPECT_CALL(cb_z_, OnClosed).Times(1);
// Force-close socket Z so that it doesn't interfere from now on.
sock_z_->Close();
@ -882,8 +970,9 @@ TEST_F(DcSctpSocketTest, RecoversAfterASuccessfulAck) {
EXPECT_EQ(another_packet.descriptors()[0].type, HeartbeatRequestChunk::kType);
}
TEST_F(DcSctpSocketTest, ResetStream) {
TEST_P(DcSctpSocketParametrizedTest, ResetStream) {
ConnectSockets();
MaybeHandoverSocketZ();
sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), {});
sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket());
@ -906,10 +995,13 @@ TEST_F(DcSctpSocketTest, ResetStream) {
// Receiving a response will trigger a callback. Streams are now reset.
EXPECT_CALL(cb_a_, OnStreamsResetPerformed).Times(1);
sock_a_->ReceivePacket(cb_z_.ConsumeSentPacket());
MaybeHandoverSocketZAndSendMessage();
}
TEST_F(DcSctpSocketTest, ResetStreamWillMakeChunksStartAtZeroSsn) {
TEST_P(DcSctpSocketParametrizedTest, ResetStreamWillMakeChunksStartAtZeroSsn) {
ConnectSockets();
MaybeHandoverSocketZ();
std::vector<uint8_t> payload(options_.mtu - 100);
@ -956,10 +1048,14 @@ TEST_F(DcSctpSocketTest, ResetStreamWillMakeChunksStartAtZeroSsn) {
// Handle SACK
sock_a_->ReceivePacket(cb_z_.ConsumeSentPacket());
MaybeHandoverSocketZAndSendMessage();
}
TEST_F(DcSctpSocketTest, ResetStreamWillOnlyResetTheRequestedStreams) {
TEST_P(DcSctpSocketParametrizedTest,
ResetStreamWillOnlyResetTheRequestedStreams) {
ConnectSockets();
MaybeHandoverSocketZ();
std::vector<uint8_t> payload(options_.mtu - 100);
@ -1034,10 +1130,13 @@ TEST_F(DcSctpSocketTest, ResetStreamWillOnlyResetTheRequestedStreams) {
// Handle SACK
sock_a_->ReceivePacket(cb_z_.ConsumeSentPacket());
MaybeHandoverSocketZAndSendMessage();
}
TEST_F(DcSctpSocketTest, OnePeerReconnects) {
TEST_P(DcSctpSocketParametrizedTest, OnePeerReconnects) {
ConnectSockets();
MaybeHandoverSocketZ();
EXPECT_CALL(cb_a_, OnConnectionRestarted).Times(1);
// Let's be evil here - reconnect while a fragmented packet was about to be
@ -1064,8 +1163,9 @@ TEST_F(DcSctpSocketTest, OnePeerReconnects) {
EXPECT_THAT(msg->payload(), testing::ElementsAreArray(payload));
}
TEST_F(DcSctpSocketTest, SendMessageWithLimitedRtx) {
TEST_P(DcSctpSocketParametrizedTest, SendMessageWithLimitedRtx) {
ConnectSockets();
MaybeHandoverSocketZ();
SendOptions send_options;
send_options.max_retransmissions = 0;
@ -1117,10 +1217,13 @@ TEST_F(DcSctpSocketTest, SendMessageWithLimitedRtx) {
absl::optional<DcSctpMessage> msg3 = cb_z_.ConsumeReceivedMessage();
EXPECT_FALSE(msg3.has_value());
MaybeHandoverSocketZAndSendMessage();
}
TEST_F(DcSctpSocketTest, SendManyFragmentedMessagesWithLimitedRtx) {
TEST_P(DcSctpSocketParametrizedTest, SendManyFragmentedMessagesWithLimitedRtx) {
ConnectSockets();
MaybeHandoverSocketZ();
SendOptions send_options;
send_options.unordered = IsUnordered(true);
@ -1210,8 +1313,9 @@ class FakeChunk : public Chunk, public TLVTrait<FakeChunkConfig> {
std::string ToString() const override { return "FAKE"; }
};
TEST_F(DcSctpSocketTest, ReceivingUnknownChunkRespondsWithError) {
TEST_P(DcSctpSocketParametrizedTest, ReceivingUnknownChunkRespondsWithError) {
ConnectSockets();
MaybeHandoverSocketZ();
// Inject a FAKE chunk
SctpPacket::Builder b(sock_a_->verification_tag(), DcSctpOptions());
@ -1228,10 +1332,13 @@ TEST_F(DcSctpSocketTest, ReceivingUnknownChunkRespondsWithError) {
UnrecognizedChunkTypeCause cause,
error.error_causes().get<UnrecognizedChunkTypeCause>());
EXPECT_THAT(cause.unrecognized_chunk(), ElementsAre(0x49, 0x00, 0x00, 0x04));
MaybeHandoverSocketZAndSendMessage();
}
TEST_F(DcSctpSocketTest, ReceivingErrorChunkReportsAsCallback) {
TEST_P(DcSctpSocketParametrizedTest, ReceivingErrorChunkReportsAsCallback) {
ConnectSockets();
MaybeHandoverSocketZ();
// Inject a ERROR chunk
SctpPacket::Builder b(sock_a_->verification_tag(), DcSctpOptions());
@ -1243,6 +1350,8 @@ TEST_F(DcSctpSocketTest, ReceivingErrorChunkReportsAsCallback) {
EXPECT_CALL(cb_a_, OnError(ErrorKind::kPeerReported,
HasSubstr("Unrecognized Chunk Type")));
sock_a_->ReceivePacket(b.Build());
MaybeHandoverSocketZAndSendMessage();
}
TEST_F(DcSctpSocketTest, PassingHighWatermarkWillOnlyAcceptCumAckTsn) {
@ -1359,8 +1468,9 @@ TEST_F(DcSctpSocketTest, SetMaxMessageSize) {
EXPECT_EQ(sock_a_->options().max_message_size, 42u);
}
TEST_F(DcSctpSocketTest, SendsMessagesWithLowLifetime) {
TEST_P(DcSctpSocketParametrizedTest, SendsMessagesWithLowLifetime) {
ConnectSockets();
MaybeHandoverSocketZ();
// Mock that the time always goes forward.
TimeMs now(0);
@ -1394,10 +1504,14 @@ TEST_F(DcSctpSocketTest, SendsMessagesWithLowLifetime) {
// Validate that the sockets really make the time move forward.
EXPECT_GE(*now, kIterations * 2);
MaybeHandoverSocketZAndSendMessage();
}
TEST_F(DcSctpSocketTest, DiscardsMessagesWithLowLifetimeIfMustBuffer) {
TEST_P(DcSctpSocketParametrizedTest,
DiscardsMessagesWithLowLifetimeIfMustBuffer) {
ConnectSockets();
MaybeHandoverSocketZ();
SendOptions lifetime_0;
lifetime_0.unordered = IsUnordered(true);
@ -1449,53 +1563,65 @@ TEST_F(DcSctpSocketTest, DiscardsMessagesWithLowLifetimeIfMustBuffer) {
// But none of the smaller messages.
EXPECT_FALSE(cb_z_.ConsumeReceivedMessage().has_value());
MaybeHandoverSocketZAndSendMessage();
}
TEST_F(DcSctpSocketTest, HasReasonableBufferedAmountValues) {
TEST_P(DcSctpSocketParametrizedTest, HasReasonableBufferedAmountValues) {
ConnectSockets();
MaybeHandoverSocketZ();
EXPECT_EQ(sock_a_->buffered_amount(StreamID(1)), 0u);
sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53),
std::vector<uint8_t>(kSmallMessageSize)),
kSendOptions);
std::vector<uint8_t>(kSmallMessageSize)),
kSendOptions);
// Sending a small message will directly send it as a single packet, so
// nothing is left in the queue.
EXPECT_EQ(sock_a_->buffered_amount(StreamID(1)), 0u);
sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53),
std::vector<uint8_t>(kLargeMessageSize)),
kSendOptions);
std::vector<uint8_t>(kLargeMessageSize)),
kSendOptions);
// Sending a message will directly start sending a few packets, so the
// buffered amount is not the full message size.
EXPECT_GT(sock_a_->buffered_amount(StreamID(1)), 0u);
EXPECT_LT(sock_a_->buffered_amount(StreamID(1)), kLargeMessageSize);
MaybeHandoverSocketZAndSendMessage();
}
TEST_F(DcSctpSocketTest, HasDefaultOnBufferedAmountLowValueZero) {
EXPECT_EQ(sock_a_->buffered_amount_low_threshold(StreamID(1)), 0u);
}
TEST_F(DcSctpSocketTest, TriggersOnBufferedAmountLowWithDefaultValueZero) {
TEST_P(DcSctpSocketParametrizedTest,
TriggersOnBufferedAmountLowWithDefaultValueZero) {
EXPECT_CALL(cb_a_, OnBufferedAmountLow).Times(0);
ConnectSockets();
MaybeHandoverSocketZ();
EXPECT_CALL(cb_a_, OnBufferedAmountLow(StreamID(1)));
sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53),
std::vector<uint8_t>(kSmallMessageSize)),
kSendOptions);
std::vector<uint8_t>(kSmallMessageSize)),
kSendOptions);
ExchangeMessages(*sock_a_, cb_a_, *sock_z_, cb_z_);
EXPECT_CALL(cb_a_, OnBufferedAmountLow).WillRepeatedly(testing::Return());
MaybeHandoverSocketZAndSendMessage();
}
TEST_F(DcSctpSocketTest, DoesntTriggerOnBufferedAmountLowIfBelowThreshold) {
TEST_P(DcSctpSocketParametrizedTest,
DoesntTriggerOnBufferedAmountLowIfBelowThreshold) {
static constexpr size_t kMessageSize = 1000;
static constexpr size_t kBufferedAmountLowThreshold = kMessageSize * 10;
sock_a_->SetBufferedAmountLowThreshold(StreamID(1),
kBufferedAmountLowThreshold);
kBufferedAmountLowThreshold);
EXPECT_CALL(cb_a_, OnBufferedAmountLow).Times(0);
ConnectSockets();
MaybeHandoverSocketZ();
EXPECT_CALL(cb_a_, OnBufferedAmountLow(StreamID(1))).Times(0);
sock_a_->Send(
@ -1507,16 +1633,19 @@ TEST_F(DcSctpSocketTest, DoesntTriggerOnBufferedAmountLowIfBelowThreshold) {
DcSctpMessage(StreamID(1), PPID(53), std::vector<uint8_t>(kMessageSize)),
kSendOptions);
ExchangeMessages(*sock_a_, cb_a_, *sock_z_, cb_z_);
MaybeHandoverSocketZAndSendMessage();
}
TEST_F(DcSctpSocketTest, TriggersOnBufferedAmountMultipleTimes) {
TEST_P(DcSctpSocketParametrizedTest, TriggersOnBufferedAmountMultipleTimes) {
static constexpr size_t kMessageSize = 1000;
static constexpr size_t kBufferedAmountLowThreshold = kMessageSize / 2;
sock_a_->SetBufferedAmountLowThreshold(StreamID(1),
kBufferedAmountLowThreshold);
kBufferedAmountLowThreshold);
EXPECT_CALL(cb_a_, OnBufferedAmountLow).Times(0);
ConnectSockets();
MaybeHandoverSocketZ();
EXPECT_CALL(cb_a_, OnBufferedAmountLow(StreamID(1))).Times(3);
EXPECT_CALL(cb_a_, OnBufferedAmountLow(StreamID(2))).Times(2);
@ -1544,16 +1673,20 @@ TEST_F(DcSctpSocketTest, TriggersOnBufferedAmountMultipleTimes) {
DcSctpMessage(StreamID(1), PPID(53), std::vector<uint8_t>(kMessageSize)),
kSendOptions);
ExchangeMessages(*sock_a_, cb_a_, *sock_z_, cb_z_);
MaybeHandoverSocketZAndSendMessage();
}
TEST_F(DcSctpSocketTest, TriggersOnBufferedAmountLowOnlyWhenCrossingThreshold) {
TEST_P(DcSctpSocketParametrizedTest,
TriggersOnBufferedAmountLowOnlyWhenCrossingThreshold) {
static constexpr size_t kMessageSize = 1000;
static constexpr size_t kBufferedAmountLowThreshold = kMessageSize * 1.5;
sock_a_->SetBufferedAmountLowThreshold(StreamID(1),
kBufferedAmountLowThreshold);
kBufferedAmountLowThreshold);
EXPECT_CALL(cb_a_, OnBufferedAmountLow).Times(0);
ConnectSockets();
MaybeHandoverSocketZ();
EXPECT_CALL(cb_a_, OnBufferedAmountLow).Times(0);
@ -1561,8 +1694,8 @@ TEST_F(DcSctpSocketTest, TriggersOnBufferedAmountLowOnlyWhenCrossingThreshold) {
// messages will start to be fully buffered.
while (sock_a_->buffered_amount(StreamID(1)) <= kBufferedAmountLowThreshold) {
sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53),
std::vector<uint8_t>(kMessageSize)),
kSendOptions);
std::vector<uint8_t>(kMessageSize)),
kSendOptions);
}
size_t initial_buffered = sock_a_->buffered_amount(StreamID(1));
ASSERT_GT(initial_buffered, kBufferedAmountLowThreshold);
@ -1571,36 +1704,46 @@ TEST_F(DcSctpSocketTest, TriggersOnBufferedAmountLowOnlyWhenCrossingThreshold) {
// callback.
EXPECT_CALL(cb_a_, OnBufferedAmountLow(StreamID(1))).Times(1);
ExchangeMessages(*sock_a_, cb_a_, *sock_z_, cb_z_);
MaybeHandoverSocketZAndSendMessage();
}
TEST_F(DcSctpSocketTest, DoesntTriggerOnTotalBufferAmountLowWhenBelow) {
TEST_P(DcSctpSocketParametrizedTest,
DoesntTriggerOnTotalBufferAmountLowWhenBelow) {
ConnectSockets();
MaybeHandoverSocketZ();
EXPECT_CALL(cb_a_, OnTotalBufferedAmountLow).Times(0);
sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53),
std::vector<uint8_t>(kLargeMessageSize)),
kSendOptions);
std::vector<uint8_t>(kLargeMessageSize)),
kSendOptions);
ExchangeMessages(*sock_a_, cb_a_, *sock_z_, cb_z_);
MaybeHandoverSocketZAndSendMessage();
}
TEST_F(DcSctpSocketTest, TriggersOnTotalBufferAmountLowWhenCrossingThreshold) {
TEST_P(DcSctpSocketParametrizedTest,
TriggersOnTotalBufferAmountLowWhenCrossingThreshold) {
ConnectSockets();
MaybeHandoverSocketZ();
EXPECT_CALL(cb_a_, OnTotalBufferedAmountLow).Times(0);
// Fill up the send queue completely.
for (;;) {
if (sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53),
std::vector<uint8_t>(kLargeMessageSize)),
kSendOptions) == SendStatus::kErrorResourceExhaustion) {
std::vector<uint8_t>(kLargeMessageSize)),
kSendOptions) == SendStatus::kErrorResourceExhaustion) {
break;
}
}
EXPECT_CALL(cb_a_, OnTotalBufferedAmountLow).Times(1);
ExchangeMessages(*sock_a_, cb_a_, *sock_z_, cb_z_);
MaybeHandoverSocketZAndSendMessage();
}
TEST_F(DcSctpSocketTest, InitialMetricsAreZeroed) {
@ -1650,8 +1793,8 @@ TEST_F(DcSctpSocketTest, RxAndTxPacketMetricsIncrease) {
// Send one more (large - fragmented), and receive the delayed SACK.
sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53),
std::vector<uint8_t>(options_.mtu * 2 + 1)),
kSendOptions);
std::vector<uint8_t>(options_.mtu * 2 + 1)),
kSendOptions);
EXPECT_EQ(sock_a_->GetMetrics().unack_data_count, 3u);
sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket()); // DATA
@ -1683,12 +1826,13 @@ TEST_F(DcSctpSocketTest, RxAndTxPacketMetricsIncrease) {
EXPECT_EQ(*sock_a_->GetMetrics().peer_rwnd_bytes, initial_a_rwnd);
}
TEST_F(DcSctpSocketTest, UnackDataAlsoIncludesSendQueue) {
TEST_P(DcSctpSocketParametrizedTest, UnackDataAlsoIncludesSendQueue) {
ConnectSockets();
MaybeHandoverSocketZ();
sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53),
std::vector<uint8_t>(kLargeMessageSize)),
kSendOptions);
std::vector<uint8_t>(kLargeMessageSize)),
kSendOptions);
size_t payload_bytes =
options_.mtu - SctpPacket::kHeaderSize - DataChunk::kHeaderSize;
@ -1706,14 +1850,17 @@ TEST_F(DcSctpSocketTest, UnackDataAlsoIncludesSendQueue) {
EXPECT_LE(sock_a_->GetMetrics().unack_data_count,
expected_sent_packets + expected_queued_packets + 2);
MaybeHandoverSocketZAndSendMessage();
}
TEST_F(DcSctpSocketTest, DoesntSendMoreThanMaxBurstPackets) {
TEST_P(DcSctpSocketParametrizedTest, DoesntSendMoreThanMaxBurstPackets) {
ConnectSockets();
MaybeHandoverSocketZ();
sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53),
std::vector<uint8_t>(kLargeMessageSize)),
kSendOptions);
std::vector<uint8_t>(kLargeMessageSize)),
kSendOptions);
for (int i = 0; i < kMaxBurstPackets; ++i) {
std::vector<uint8_t> packet = cb_a_.ConsumeSentPacket();
@ -1722,10 +1869,14 @@ TEST_F(DcSctpSocketTest, DoesntSendMoreThanMaxBurstPackets) {
}
EXPECT_THAT(cb_a_.ConsumeSentPacket(), IsEmpty());
ExchangeMessages(*sock_a_, cb_a_, *sock_z_, cb_z_);
MaybeHandoverSocketZAndSendMessage();
}
TEST_F(DcSctpSocketTest, SendsOnlyLargePackets) {
TEST_P(DcSctpSocketParametrizedTest, SendsOnlyLargePackets) {
ConnectSockets();
MaybeHandoverSocketZ();
// A really large message, to ensure that the congestion window is often full.
constexpr size_t kMessageSize = 100000;
@ -1765,10 +1916,13 @@ TEST_F(DcSctpSocketTest, SendsOnlyLargePackets) {
// The 4 is for padding/alignment.
EXPECT_GE(size, options_.mtu - 4);
}
MaybeHandoverSocketZAndSendMessage();
}
TEST_F(DcSctpSocketTest, DoesntBundleForwardTsnWithData) {
TEST_P(DcSctpSocketParametrizedTest, DoesntBundleForwardTsnWithData) {
ConnectSockets();
MaybeHandoverSocketZ();
// Force an RTT measurement using heartbeats.
AdvanceTime(options_.heartbeat_interval);
@ -1848,5 +2002,49 @@ TEST_F(DcSctpSocketTest, DoesntBundleForwardTsnWithData) {
EXPECT_EQ(packet4.descriptors()[0].type, ForwardTsnChunk::kType);
}
TEST_F(DcSctpSocketTest, SendMessagesAfterHandover) {
ConnectSockets();
// Send message before handover to move socket to a not initial state
sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions);
sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket());
cb_z_.ConsumeReceivedMessage();
HandoverSocketZ();
absl::optional<DcSctpMessage> msg;
RTC_LOG(LS_INFO) << "Sending A #1";
sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), {3, 4}), kSendOptions);
sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket());
msg = cb_z_.ConsumeReceivedMessage();
ASSERT_TRUE(msg.has_value());
EXPECT_EQ(msg->stream_id(), StreamID(1));
EXPECT_THAT(msg->payload(), testing::ElementsAre(3, 4));
RTC_LOG(LS_INFO) << "Sending A #2";
sock_a_->Send(DcSctpMessage(StreamID(2), PPID(53), {5, 6}), kSendOptions);
sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket());
msg = cb_z_.ConsumeReceivedMessage();
ASSERT_TRUE(msg.has_value());
EXPECT_EQ(msg->stream_id(), StreamID(2));
EXPECT_THAT(msg->payload(), testing::ElementsAre(5, 6));
RTC_LOG(LS_INFO) << "Sending Z #1";
sock_z_->Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2, 3}), kSendOptions);
sock_a_->ReceivePacket(cb_z_.ConsumeSentPacket()); // ack
sock_a_->ReceivePacket(cb_z_.ConsumeSentPacket()); // data
msg = cb_a_.ConsumeReceivedMessage();
ASSERT_TRUE(msg.has_value());
EXPECT_EQ(msg->stream_id(), StreamID(1));
EXPECT_THAT(msg->payload(), testing::ElementsAre(1, 2, 3));
}
} // namespace
} // namespace dcsctp

View File

@ -150,6 +150,12 @@ class MockDcSctpSocketCallbacks : public DcSctpSocketCallbacks {
return timeout_manager_.GetNextExpiredTimeout();
}
void Reset() {
sent_packets_.clear();
received_messages_.clear();
timeout_manager_.Reset();
}
private:
const std::string log_prefix_;
TimeMs now_ = TimeMs(0);

View File

@ -183,4 +183,30 @@ std::string TransmissionControlBlock::ToString() const {
return sb.Release();
}
HandoverReadinessStatus TransmissionControlBlock::GetHandoverReadiness() const {
HandoverReadinessStatus status;
status.Add(data_tracker_.GetHandoverReadiness());
status.Add(stream_reset_handler_.GetHandoverReadiness());
status.Add(reassembly_queue_.GetHandoverReadiness());
status.Add(retransmission_queue_.GetHandoverReadiness());
return status;
}
void TransmissionControlBlock::AddHandoverState(
DcSctpSocketHandoverState& state) {
state.capabilities.partial_reliability = capabilities_.partial_reliability;
state.capabilities.message_interleaving = capabilities_.message_interleaving;
state.capabilities.reconfig = capabilities_.reconfig;
state.my_verification_tag = my_verification_tag().value();
state.peer_verification_tag = peer_verification_tag().value();
state.my_initial_tsn = my_initial_tsn().value();
state.peer_initial_tsn = peer_initial_tsn().value();
state.tie_tag = tie_tag().value();
data_tracker_.AddHandoverState(state);
stream_reset_handler_.AddHandoverState(state);
reassembly_queue_.AddHandoverState(state);
retransmission_queue_.AddHandoverState(state);
}
} // namespace dcsctp

View File

@ -44,20 +44,22 @@ namespace dcsctp {
// closed or restarted, this object will be deleted and/or replaced.
class TransmissionControlBlock : public Context {
public:
TransmissionControlBlock(TimerManager& timer_manager,
absl::string_view log_prefix,
const DcSctpOptions& options,
const Capabilities& capabilities,
DcSctpSocketCallbacks& callbacks,
SendQueue& send_queue,
VerificationTag my_verification_tag,
TSN my_initial_tsn,
VerificationTag peer_verification_tag,
TSN peer_initial_tsn,
size_t a_rwnd,
TieTag tie_tag,
PacketSender& packet_sender,
std::function<bool()> is_connection_established)
TransmissionControlBlock(
TimerManager& timer_manager,
absl::string_view log_prefix,
const DcSctpOptions& options,
const Capabilities& capabilities,
DcSctpSocketCallbacks& callbacks,
SendQueue& send_queue,
VerificationTag my_verification_tag,
TSN my_initial_tsn,
VerificationTag peer_verification_tag,
TSN peer_initial_tsn,
size_t a_rwnd,
TieTag tie_tag,
PacketSender& packet_sender,
std::function<bool()> is_connection_established,
const DcSctpSocketHandoverState* handover_state = nullptr)
: log_prefix_(log_prefix),
options_(options),
timer_manager_(timer_manager),
@ -86,10 +88,14 @@ class TransmissionControlBlock : public Context {
packet_sender_(packet_sender),
rto_(options),
tx_error_counter_(log_prefix, options),
data_tracker_(log_prefix, delayed_ack_timer_.get(), peer_initial_tsn),
data_tracker_(log_prefix,
delayed_ack_timer_.get(),
peer_initial_tsn,
handover_state),
reassembly_queue_(log_prefix,
peer_initial_tsn,
options.max_receiver_window_buffer_size),
options.max_receiver_window_buffer_size,
handover_state),
retransmission_queue_(
log_prefix,
my_initial_tsn,
@ -100,13 +106,15 @@ class TransmissionControlBlock : public Context {
*t3_rtx_,
options,
capabilities.partial_reliability,
capabilities.message_interleaving),
capabilities.message_interleaving,
handover_state),
stream_reset_handler_(log_prefix,
this,
&timer_manager,
&data_tracker_,
&reassembly_queue_,
&retransmission_queue_),
&retransmission_queue_,
handover_state),
heartbeat_handler_(log_prefix, options, this, &timer_manager_) {}
// Implementation of `Context`.
@ -188,6 +196,10 @@ class TransmissionControlBlock : public Context {
// Returns a textual representation of this object, for logging.
std::string ToString() const;
HandoverReadinessStatus GetHandoverReadiness() const;
void AddHandoverState(DcSctpSocketHandoverState& state);
private:
// Will be called when the retransmission timer (t3-rtx) expires.
absl::optional<DurationMs> OnRtxTimerExpiry();

View File

@ -91,6 +91,8 @@ class FakeTimeoutManager {
return absl::nullopt;
}
void Reset() { timers_.clear(); }
private:
const std::function<TimeMs()> get_time_;
webrtc::flat_set<FakeTimeout*> timers_;