From 17a02a31d7d2897b75ad69fdac5d10e7475a5865 Mon Sep 17 00:00:00 2001 From: Victor Boivie Date: Mon, 2 May 2022 13:04:37 +0200 Subject: [PATCH] dcsctp: Add public API for setting priorities This is the first part of supporting stream priorities, and adds the API and very basic support for setting and retrieving the stream priority. This commit doesn't in any way change the actual packet sending - the specified priority values are stored, but not acted on. This is all that is client visible, so clients can start using the API as written, and they would never notice that things are missing. Bug: webrtc:5696 Change-Id: I24fce8cbb6f3cba187df99d1d3f45e73621c93c6 Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/261943 Reviewed-by: Harald Alvestrand Commit-Queue: Victor Boivie Cr-Commit-Position: refs/heads/main@{#37034} --- net/dcsctp/public/dcsctp_handover_state.h | 1 + net/dcsctp/public/dcsctp_options.h | 5 +++ net/dcsctp/public/dcsctp_socket.h | 9 +++++ net/dcsctp/public/mock_dcsctp_socket.h | 10 +++++ net/dcsctp/public/types.h | 4 ++ net/dcsctp/socket/dcsctp_socket.cc | 9 +++++ net/dcsctp/socket/dcsctp_socket.h | 2 + net/dcsctp/socket/dcsctp_socket_test.cc | 47 ++++++++++++++++++++++- net/dcsctp/tx/rr_send_queue.cc | 33 ++++++++++++---- net/dcsctp/tx/rr_send_queue.h | 10 +++++ net/dcsctp/tx/rr_send_queue_test.cc | 36 +++++++++++++++++ 11 files changed, 158 insertions(+), 8 deletions(-) diff --git a/net/dcsctp/public/dcsctp_handover_state.h b/net/dcsctp/public/dcsctp_handover_state.h index a58535d45f..36fc37ba89 100644 --- a/net/dcsctp/public/dcsctp_handover_state.h +++ b/net/dcsctp/public/dcsctp_handover_state.h @@ -48,6 +48,7 @@ struct DcSctpSocketHandoverState { uint32_t next_ssn = 0; uint32_t next_unordered_mid = 0; uint32_t next_ordered_mid = 0; + uint16_t priority = 0; }; struct Transmission { uint32_t next_tsn = 0; diff --git a/net/dcsctp/public/dcsctp_options.h b/net/dcsctp/public/dcsctp_options.h index c394552e22..4511bed4a4 100644 --- a/net/dcsctp/public/dcsctp_options.h +++ b/net/dcsctp/public/dcsctp_options.h @@ -71,6 +71,11 @@ struct DcSctpOptions { // `max_receiver_window_buffer_size`). size_t max_message_size = 256 * 1024; + // The default stream priority, if not overridden by + // `SctpSocket::SetStreamPriority`. The default value is selected to be + // compatible with https://www.w3.org/TR/webrtc-priority/, section 4.2-4.3. + StreamPriority default_stream_priority = StreamPriority(256); + // Maximum received window buffer size. This should be a bit larger than the // largest sized message you want to be able to receive. This essentially // limits the memory usage on the receive side. Note that memory is allocated diff --git a/net/dcsctp/public/dcsctp_socket.h b/net/dcsctp/public/dcsctp_socket.h index e15a5bfaf8..0a65dae1d4 100644 --- a/net/dcsctp/public/dcsctp_socket.h +++ b/net/dcsctp/public/dcsctp_socket.h @@ -430,6 +430,15 @@ class DcSctpSocketInterface { // Update the options max_message_size. virtual void SetMaxMessageSize(size_t max_message_size) = 0; + // Sets the priority of an outgoing stream. The initial value, when not set, + // is `DcSctpOptions::default_stream_priority`. + virtual void SetStreamPriority(StreamID stream_id, + StreamPriority priority) = 0; + + // Returns the currently set priority for an outgoing stream. The initial + // value, when not set, is `DcSctpOptions::default_stream_priority`. + virtual StreamPriority GetStreamPriority(StreamID stream_id) const = 0; + // Sends the message `message` using the provided send options. // Sending a message is an asynchrous operation, and the `OnError` callback // may be invoked to indicate any errors in sending the message. diff --git a/net/dcsctp/public/mock_dcsctp_socket.h b/net/dcsctp/public/mock_dcsctp_socket.h index 6560a3f3fa..0fd572bd94 100644 --- a/net/dcsctp/public/mock_dcsctp_socket.h +++ b/net/dcsctp/public/mock_dcsctp_socket.h @@ -41,6 +41,16 @@ class MockDcSctpSocket : public DcSctpSocketInterface { MOCK_METHOD(void, SetMaxMessageSize, (size_t max_message_size), (override)); + MOCK_METHOD(void, + SetStreamPriority, + (StreamID stream_id, StreamPriority priority), + (override)); + + MOCK_METHOD(StreamPriority, + GetStreamPriority, + (StreamID stream_id), + (const, override)); + MOCK_METHOD(SendStatus, Send, (DcSctpMessage message, const SendOptions& send_options), diff --git a/net/dcsctp/public/types.h b/net/dcsctp/public/types.h index caa03bb96f..358e243fc5 100644 --- a/net/dcsctp/public/types.h +++ b/net/dcsctp/public/types.h @@ -31,6 +31,10 @@ using TimeoutID = webrtc::StrongAlias; // other messages on the same stream. using IsUnordered = webrtc::StrongAlias; +// Stream priority, where higher values indicate higher priority. The meaning of +// this value and how it's used depends on the stream scheduler. +using StreamPriority = webrtc::StrongAlias; + // Duration, as milliseconds. Overflows after 24 days. class DurationMs : public webrtc::StrongAlias { public: diff --git a/net/dcsctp/socket/dcsctp_socket.cc b/net/dcsctp/socket/dcsctp_socket.cc index 9d6ae0e3e6..e0a912c950 100644 --- a/net/dcsctp/socket/dcsctp_socket.cc +++ b/net/dcsctp/socket/dcsctp_socket.cc @@ -189,6 +189,7 @@ DcSctpSocket::DcSctpSocket(absl::string_view log_prefix, send_queue_( log_prefix_, options_.max_send_buffer_size, + options_.default_stream_priority, [this](StreamID stream_id) { callbacks_.OnBufferedAmountLow(stream_id); }, @@ -420,6 +421,14 @@ void DcSctpSocket::InternalClose(ErrorKind error, absl::string_view message) { RTC_DCHECK(IsConsistent()); } +void DcSctpSocket::SetStreamPriority(StreamID stream_id, + StreamPriority priority) { + send_queue_.SetStreamPriority(stream_id, priority); +} +StreamPriority DcSctpSocket::GetStreamPriority(StreamID stream_id) const { + return send_queue_.GetStreamPriority(stream_id); +} + SendStatus DcSctpSocket::Send(DcSctpMessage message, const SendOptions& send_options) { RTC_DCHECK_RUN_ON(&thread_checker_); diff --git a/net/dcsctp/socket/dcsctp_socket.h b/net/dcsctp/socket/dcsctp_socket.h index 07e760ab01..d70d0fca54 100644 --- a/net/dcsctp/socket/dcsctp_socket.h +++ b/net/dcsctp/socket/dcsctp_socket.h @@ -96,6 +96,8 @@ class DcSctpSocket : public DcSctpSocketInterface { SocketState state() const override; const DcSctpOptions& options() const override { return options_; } void SetMaxMessageSize(size_t max_message_size) override; + void SetStreamPriority(StreamID stream_id, StreamPriority priority) override; + StreamPriority GetStreamPriority(StreamID stream_id) const override; size_t buffered_amount(StreamID stream_id) const override; size_t buffered_amount_low_threshold(StreamID stream_id) const override; void SetBufferedAmountLowThreshold(StreamID stream_id, size_t bytes) override; diff --git a/net/dcsctp/socket/dcsctp_socket_test.cc b/net/dcsctp/socket/dcsctp_socket_test.cc index 770fd84a18..cc5566f9ae 100644 --- a/net/dcsctp/socket/dcsctp_socket_test.cc +++ b/net/dcsctp/socket/dcsctp_socket_test.cc @@ -2333,6 +2333,51 @@ TEST(DcSctpSocketTest, CloseStreamsWithPendingRequest) { absl::optional msg6 = z.cb.ConsumeReceivedMessage(); ASSERT_TRUE(msg6.has_value()); EXPECT_EQ(msg6->stream_id(), StreamID(3)); -} // namespace +} + +TEST(DcSctpSocketTest, StreamsHaveInitialPriority) { + DcSctpOptions options = {.default_stream_priority = StreamPriority(42)}; + SocketUnderTest a("A", options); + + EXPECT_EQ(a.socket.GetStreamPriority(StreamID(1)), + options.default_stream_priority); + + a.socket.Send(DcSctpMessage(StreamID(2), PPID(53), {1, 2}), kSendOptions); + + EXPECT_EQ(a.socket.GetStreamPriority(StreamID(2)), + options.default_stream_priority); +} + +TEST(DcSctpSocketTest, CanChangeStreamPriority) { + DcSctpOptions options = {.default_stream_priority = StreamPriority(42)}; + SocketUnderTest a("A", options); + + a.socket.SetStreamPriority(StreamID(1), StreamPriority(43)); + EXPECT_EQ(a.socket.GetStreamPriority(StreamID(1)), StreamPriority(43)); + + a.socket.Send(DcSctpMessage(StreamID(2), PPID(53), {1, 2}), kSendOptions); + + a.socket.SetStreamPriority(StreamID(2), StreamPriority(43)); + EXPECT_EQ(a.socket.GetStreamPriority(StreamID(2)), StreamPriority(43)); +} + +TEST_P(DcSctpSocketParametrizedTest, WillHandoverPriority) { + DcSctpOptions options = {.default_stream_priority = StreamPriority(42)}; + auto a = std::make_unique("A", options); + SocketUnderTest z("Z"); + + ConnectSockets(*a, z); + + a->socket.SetStreamPriority(StreamID(1), StreamPriority(43)); + a->socket.Send(DcSctpMessage(StreamID(2), PPID(53), {1, 2}), kSendOptions); + a->socket.SetStreamPriority(StreamID(2), StreamPriority(43)); + + ExchangeMessages(*a, z); + + a = MaybeHandoverSocket(std::move(a)); + + EXPECT_EQ(a->socket.GetStreamPriority(StreamID(1)), StreamPriority(43)); + EXPECT_EQ(a->socket.GetStreamPriority(StreamID(2)), StreamPriority(43)); +} } // namespace } // namespace dcsctp diff --git a/net/dcsctp/tx/rr_send_queue.cc b/net/dcsctp/tx/rr_send_queue.cc index d4ce59d58c..3a2166b813 100644 --- a/net/dcsctp/tx/rr_send_queue.cc +++ b/net/dcsctp/tx/rr_send_queue.cc @@ -30,11 +30,13 @@ namespace dcsctp { RRSendQueue::RRSendQueue(absl::string_view log_prefix, size_t buffer_size, + StreamPriority default_priority, std::function on_buffered_amount_low, size_t total_buffered_amount_low_threshold, std::function on_total_buffered_amount_low) : log_prefix_(std::string(log_prefix) + "fcfs: "), buffer_size_(buffer_size), + default_priority_(default_priority), on_buffered_amount_low_(std::move(on_buffered_amount_low)), total_buffered_amount_(std::move(on_total_buffered_amount_low)) { total_buffered_amount_.SetLowThreshold(total_buffered_amount_low_threshold); @@ -75,6 +77,7 @@ void RRSendQueue::OutgoingStream::AddHandoverState( state.next_ssn = next_ssn_.value(); state.next_ordered_mid = next_ordered_mid_.value(); state.next_unordered_mid = next_unordered_mid_.value(); + state.priority = *priority_; } bool RRSendQueue::IsConsistent() const { @@ -515,12 +518,28 @@ RRSendQueue::OutgoingStream& RRSendQueue::GetOrCreateStreamInfo( return streams_ .emplace(stream_id, OutgoingStream( - stream_id, + stream_id, default_priority_, [this, stream_id]() { on_buffered_amount_low_(stream_id); }, total_buffered_amount_)) .first->second; } +void RRSendQueue::SetStreamPriority(StreamID stream_id, + StreamPriority priority) { + OutgoingStream& stream = GetOrCreateStreamInfo(stream_id); + + stream.set_priority(priority); + RTC_DCHECK(IsConsistent()); +} + +StreamPriority RRSendQueue::GetStreamPriority(StreamID stream_id) const { + auto stream_it = streams_.find(stream_id); + if (stream_it == streams_.end()) { + return default_priority_; + } + return stream_it->second.priority(); +} + HandoverReadinessStatus RRSendQueue::GetHandoverReadiness() const { HandoverReadinessStatus status; if (!IsEmpty()) { @@ -542,12 +561,12 @@ void RRSendQueue::RestoreFromState(const DcSctpSocketHandoverState& state) { for (const DcSctpSocketHandoverState::OutgoingStream& state_stream : state.tx.streams) { StreamID stream_id(state_stream.id); - streams_.emplace(stream_id, OutgoingStream( - stream_id, - [this, stream_id]() { - on_buffered_amount_low_(stream_id); - }, - total_buffered_amount_, &state_stream)); + streams_.emplace( + stream_id, + OutgoingStream( + stream_id, StreamPriority(state_stream.priority), + [this, stream_id]() { on_buffered_amount_low_(stream_id); }, + total_buffered_amount_, &state_stream)); } } } // namespace dcsctp diff --git a/net/dcsctp/tx/rr_send_queue.h b/net/dcsctp/tx/rr_send_queue.h index 57a43ccd66..7ddb426ec9 100644 --- a/net/dcsctp/tx/rr_send_queue.h +++ b/net/dcsctp/tx/rr_send_queue.h @@ -43,6 +43,7 @@ class RRSendQueue : public SendQueue { public: RRSendQueue(absl::string_view log_prefix, size_t buffer_size, + StreamPriority default_priority, std::function on_buffered_amount_low, size_t total_buffered_amount_low_threshold, std::function on_total_buffered_amount_low); @@ -79,6 +80,8 @@ class RRSendQueue : public SendQueue { size_t buffered_amount_low_threshold(StreamID stream_id) const override; void SetBufferedAmountLowThreshold(StreamID stream_id, size_t bytes) override; + void SetStreamPriority(StreamID stream_id, StreamPriority priority); + StreamPriority GetStreamPriority(StreamID stream_id) const; HandoverReadinessStatus GetHandoverReadiness() const; void AddHandoverState(DcSctpSocketHandoverState& state); void RestoreFromState(const DcSctpSocketHandoverState& state); @@ -112,10 +115,12 @@ class RRSendQueue : public SendQueue { public: OutgoingStream( StreamID stream_id, + StreamPriority priority, std::function on_buffered_amount_low, ThresholdWatcher& total_buffered_amount, const DcSctpSocketHandoverState::OutgoingStream* state = nullptr) : stream_id_(stream_id), + priority_(priority), next_unordered_mid_(MID(state ? state->next_unordered_mid : 0)), next_ordered_mid_(MID(state ? state->next_ordered_mid : 0)), next_ssn_(SSN(state ? state->next_ssn : 0)), @@ -166,6 +171,9 @@ class RRSendQueue : public SendQueue { // expired non-partially sent message. bool HasDataToSend(TimeMs now); + void set_priority(StreamPriority priority) { priority_ = priority; } + StreamPriority priority() const { return priority_; } + void AddHandoverState( DcSctpSocketHandoverState::OutgoingStream& state) const; @@ -218,6 +226,7 @@ class RRSendQueue : public SendQueue { bool IsConsistent() const; const StreamID stream_id_; + StreamPriority priority_; PauseState pause_state_ = PauseState::kNotPaused; // MIDs are different for unordered and ordered messages sent on a stream. MID next_unordered_mid_; @@ -247,6 +256,7 @@ class RRSendQueue : public SendQueue { const std::string log_prefix_; const size_t buffer_size_; + const StreamPriority default_priority_; // Called when the buffered amount is below what has been set using // `SetBufferedAmountLowThreshold`. diff --git a/net/dcsctp/tx/rr_send_queue_test.cc b/net/dcsctp/tx/rr_send_queue_test.cc index fbbce58de1..3966c17e58 100644 --- a/net/dcsctp/tx/rr_send_queue_test.cc +++ b/net/dcsctp/tx/rr_send_queue_test.cc @@ -32,6 +32,7 @@ constexpr TimeMs kNow = TimeMs(0); constexpr StreamID kStreamID(1); constexpr PPID kPPID(53); constexpr size_t kMaxQueueSize = 1000; +constexpr StreamPriority kDefaultPriority(10); constexpr size_t kBufferedAmountLowThreshold = 500; constexpr size_t kOneFragmentPacketSize = 100; constexpr size_t kTwoFragmentPacketSize = 101; @@ -41,6 +42,7 @@ class RRSendQueueTest : public testing::Test { RRSendQueueTest() : buf_("log: ", kMaxQueueSize, + kDefaultPriority, on_buffered_amount_low_.AsStdFunction(), kBufferedAmountLowThreshold, on_total_buffered_amount_low_.AsStdFunction()) {} @@ -759,5 +761,39 @@ TEST_F(RRSendQueueTest, WillStayInAStreamAsLongAsThatMessageIsSending) { EXPECT_FALSE(buf_.Produce(kNow, kOneFragmentPacketSize).has_value()); } + +TEST_F(RRSendQueueTest, StreamsHaveInitialPriority) { + EXPECT_EQ(buf_.GetStreamPriority(StreamID(1)), kDefaultPriority); + + buf_.Add(kNow, DcSctpMessage(StreamID(2), kPPID, std::vector(40))); + EXPECT_EQ(buf_.GetStreamPriority(StreamID(2)), kDefaultPriority); +} + +TEST_F(RRSendQueueTest, CanChangeStreamPriority) { + buf_.SetStreamPriority(StreamID(1), StreamPriority(42)); + EXPECT_EQ(buf_.GetStreamPriority(StreamID(1)), StreamPriority(42)); + + buf_.Add(kNow, DcSctpMessage(StreamID(2), kPPID, std::vector(40))); + buf_.SetStreamPriority(StreamID(2), StreamPriority(42)); + EXPECT_EQ(buf_.GetStreamPriority(StreamID(2)), StreamPriority(42)); +} + +TEST_F(RRSendQueueTest, WillHandoverPriority) { + buf_.SetStreamPriority(StreamID(1), StreamPriority(42)); + + buf_.Add(kNow, DcSctpMessage(StreamID(2), kPPID, std::vector(40))); + buf_.SetStreamPriority(StreamID(2), StreamPriority(42)); + + DcSctpSocketHandoverState state; + buf_.AddHandoverState(state); + + RRSendQueue q2("log: ", kMaxQueueSize, kDefaultPriority, + on_buffered_amount_low_.AsStdFunction(), + kBufferedAmountLowThreshold, + on_total_buffered_amount_low_.AsStdFunction()); + q2.RestoreFromState(state); + EXPECT_EQ(q2.GetStreamPriority(StreamID(1)), StreamPriority(42)); + EXPECT_EQ(q2.GetStreamPriority(StreamID(2)), StreamPriority(42)); +} } // namespace } // namespace dcsctp