diff --git a/net/dcsctp/public/dcsctp_options.h b/net/dcsctp/public/dcsctp_options.h index c394552e22..4511bed4a4 100644 --- a/net/dcsctp/public/dcsctp_options.h +++ b/net/dcsctp/public/dcsctp_options.h @@ -71,6 +71,11 @@ struct DcSctpOptions { // `max_receiver_window_buffer_size`). size_t max_message_size = 256 * 1024; + // The default stream priority, if not overridden by + // `SctpSocket::SetStreamPriority`. The default value is selected to be + // compatible with https://www.w3.org/TR/webrtc-priority/, section 4.2-4.3. + StreamPriority default_stream_priority = StreamPriority(256); + // Maximum received window buffer size. This should be a bit larger than the // largest sized message you want to be able to receive. This essentially // limits the memory usage on the receive side. Note that memory is allocated diff --git a/net/dcsctp/public/dcsctp_socket.h b/net/dcsctp/public/dcsctp_socket.h index e15a5bfaf8..0a65dae1d4 100644 --- a/net/dcsctp/public/dcsctp_socket.h +++ b/net/dcsctp/public/dcsctp_socket.h @@ -430,6 +430,15 @@ class DcSctpSocketInterface { // Update the options max_message_size. virtual void SetMaxMessageSize(size_t max_message_size) = 0; + // Sets the priority of an outgoing stream. The initial value, when not set, + // is `DcSctpOptions::default_stream_priority`. + virtual void SetStreamPriority(StreamID stream_id, + StreamPriority priority) = 0; + + // Returns the currently set priority for an outgoing stream. The initial + // value, when not set, is `DcSctpOptions::default_stream_priority`. + virtual StreamPriority GetStreamPriority(StreamID stream_id) const = 0; + // Sends the message `message` using the provided send options. // Sending a message is an asynchrous operation, and the `OnError` callback // may be invoked to indicate any errors in sending the message. diff --git a/net/dcsctp/public/mock_dcsctp_socket.h b/net/dcsctp/public/mock_dcsctp_socket.h index 6560a3f3fa..0fd572bd94 100644 --- a/net/dcsctp/public/mock_dcsctp_socket.h +++ b/net/dcsctp/public/mock_dcsctp_socket.h @@ -41,6 +41,16 @@ class MockDcSctpSocket : public DcSctpSocketInterface { MOCK_METHOD(void, SetMaxMessageSize, (size_t max_message_size), (override)); + MOCK_METHOD(void, + SetStreamPriority, + (StreamID stream_id, StreamPriority priority), + (override)); + + MOCK_METHOD(StreamPriority, + GetStreamPriority, + (StreamID stream_id), + (const, override)); + MOCK_METHOD(SendStatus, Send, (DcSctpMessage message, const SendOptions& send_options), diff --git a/net/dcsctp/public/types.h b/net/dcsctp/public/types.h index caa03bb96f..358e243fc5 100644 --- a/net/dcsctp/public/types.h +++ b/net/dcsctp/public/types.h @@ -31,6 +31,10 @@ using TimeoutID = webrtc::StrongAlias; // other messages on the same stream. using IsUnordered = webrtc::StrongAlias; +// Stream priority, where higher values indicate higher priority. The meaning of +// this value and how it's used depends on the stream scheduler. +using StreamPriority = webrtc::StrongAlias; + // Duration, as milliseconds. Overflows after 24 days. class DurationMs : public webrtc::StrongAlias { public: diff --git a/net/dcsctp/socket/dcsctp_socket.cc b/net/dcsctp/socket/dcsctp_socket.cc index 9d6ae0e3e6..5f8312d4bd 100644 --- a/net/dcsctp/socket/dcsctp_socket.cc +++ b/net/dcsctp/socket/dcsctp_socket.cc @@ -189,6 +189,7 @@ DcSctpSocket::DcSctpSocket(absl::string_view log_prefix, send_queue_( log_prefix_, options_.max_send_buffer_size, + options_.default_stream_priority, [this](StreamID stream_id) { callbacks_.OnBufferedAmountLow(stream_id); }, @@ -420,6 +421,16 @@ void DcSctpSocket::InternalClose(ErrorKind error, absl::string_view message) { RTC_DCHECK(IsConsistent()); } +void DcSctpSocket::SetStreamPriority(StreamID stream_id, + StreamPriority priority) { + RTC_DCHECK_RUN_ON(&thread_checker_); + send_queue_.SetStreamPriority(stream_id, priority); +} +StreamPriority DcSctpSocket::GetStreamPriority(StreamID stream_id) const { + RTC_DCHECK_RUN_ON(&thread_checker_); + return send_queue_.GetStreamPriority(stream_id); +} + SendStatus DcSctpSocket::Send(DcSctpMessage message, const SendOptions& send_options) { RTC_DCHECK_RUN_ON(&thread_checker_); diff --git a/net/dcsctp/socket/dcsctp_socket.h b/net/dcsctp/socket/dcsctp_socket.h index 07e760ab01..d70d0fca54 100644 --- a/net/dcsctp/socket/dcsctp_socket.h +++ b/net/dcsctp/socket/dcsctp_socket.h @@ -96,6 +96,8 @@ class DcSctpSocket : public DcSctpSocketInterface { SocketState state() const override; const DcSctpOptions& options() const override { return options_; } void SetMaxMessageSize(size_t max_message_size) override; + void SetStreamPriority(StreamID stream_id, StreamPriority priority) override; + StreamPriority GetStreamPriority(StreamID stream_id) const override; size_t buffered_amount(StreamID stream_id) const override; size_t buffered_amount_low_threshold(StreamID stream_id) const override; void SetBufferedAmountLowThreshold(StreamID stream_id, size_t bytes) override; diff --git a/net/dcsctp/socket/dcsctp_socket_test.cc b/net/dcsctp/socket/dcsctp_socket_test.cc index 770fd84a18..cc5566f9ae 100644 --- a/net/dcsctp/socket/dcsctp_socket_test.cc +++ b/net/dcsctp/socket/dcsctp_socket_test.cc @@ -2333,6 +2333,51 @@ TEST(DcSctpSocketTest, CloseStreamsWithPendingRequest) { absl::optional msg6 = z.cb.ConsumeReceivedMessage(); ASSERT_TRUE(msg6.has_value()); EXPECT_EQ(msg6->stream_id(), StreamID(3)); -} // namespace +} + +TEST(DcSctpSocketTest, StreamsHaveInitialPriority) { + DcSctpOptions options = {.default_stream_priority = StreamPriority(42)}; + SocketUnderTest a("A", options); + + EXPECT_EQ(a.socket.GetStreamPriority(StreamID(1)), + options.default_stream_priority); + + a.socket.Send(DcSctpMessage(StreamID(2), PPID(53), {1, 2}), kSendOptions); + + EXPECT_EQ(a.socket.GetStreamPriority(StreamID(2)), + options.default_stream_priority); +} + +TEST(DcSctpSocketTest, CanChangeStreamPriority) { + DcSctpOptions options = {.default_stream_priority = StreamPriority(42)}; + SocketUnderTest a("A", options); + + a.socket.SetStreamPriority(StreamID(1), StreamPriority(43)); + EXPECT_EQ(a.socket.GetStreamPriority(StreamID(1)), StreamPriority(43)); + + a.socket.Send(DcSctpMessage(StreamID(2), PPID(53), {1, 2}), kSendOptions); + + a.socket.SetStreamPriority(StreamID(2), StreamPriority(43)); + EXPECT_EQ(a.socket.GetStreamPriority(StreamID(2)), StreamPriority(43)); +} + +TEST_P(DcSctpSocketParametrizedTest, WillHandoverPriority) { + DcSctpOptions options = {.default_stream_priority = StreamPriority(42)}; + auto a = std::make_unique("A", options); + SocketUnderTest z("Z"); + + ConnectSockets(*a, z); + + a->socket.SetStreamPriority(StreamID(1), StreamPriority(43)); + a->socket.Send(DcSctpMessage(StreamID(2), PPID(53), {1, 2}), kSendOptions); + a->socket.SetStreamPriority(StreamID(2), StreamPriority(43)); + + ExchangeMessages(*a, z); + + a = MaybeHandoverSocket(std::move(a)); + + EXPECT_EQ(a->socket.GetStreamPriority(StreamID(1)), StreamPriority(43)); + EXPECT_EQ(a->socket.GetStreamPriority(StreamID(2)), StreamPriority(43)); +} } // namespace } // namespace dcsctp diff --git a/net/dcsctp/tx/rr_send_queue.cc b/net/dcsctp/tx/rr_send_queue.cc index d4ce59d58c..3a2166b813 100644 --- a/net/dcsctp/tx/rr_send_queue.cc +++ b/net/dcsctp/tx/rr_send_queue.cc @@ -30,11 +30,13 @@ namespace dcsctp { RRSendQueue::RRSendQueue(absl::string_view log_prefix, size_t buffer_size, + StreamPriority default_priority, std::function on_buffered_amount_low, size_t total_buffered_amount_low_threshold, std::function on_total_buffered_amount_low) : log_prefix_(std::string(log_prefix) + "fcfs: "), buffer_size_(buffer_size), + default_priority_(default_priority), on_buffered_amount_low_(std::move(on_buffered_amount_low)), total_buffered_amount_(std::move(on_total_buffered_amount_low)) { total_buffered_amount_.SetLowThreshold(total_buffered_amount_low_threshold); @@ -75,6 +77,7 @@ void RRSendQueue::OutgoingStream::AddHandoverState( state.next_ssn = next_ssn_.value(); state.next_ordered_mid = next_ordered_mid_.value(); state.next_unordered_mid = next_unordered_mid_.value(); + state.priority = *priority_; } bool RRSendQueue::IsConsistent() const { @@ -515,12 +518,28 @@ RRSendQueue::OutgoingStream& RRSendQueue::GetOrCreateStreamInfo( return streams_ .emplace(stream_id, OutgoingStream( - stream_id, + stream_id, default_priority_, [this, stream_id]() { on_buffered_amount_low_(stream_id); }, total_buffered_amount_)) .first->second; } +void RRSendQueue::SetStreamPriority(StreamID stream_id, + StreamPriority priority) { + OutgoingStream& stream = GetOrCreateStreamInfo(stream_id); + + stream.set_priority(priority); + RTC_DCHECK(IsConsistent()); +} + +StreamPriority RRSendQueue::GetStreamPriority(StreamID stream_id) const { + auto stream_it = streams_.find(stream_id); + if (stream_it == streams_.end()) { + return default_priority_; + } + return stream_it->second.priority(); +} + HandoverReadinessStatus RRSendQueue::GetHandoverReadiness() const { HandoverReadinessStatus status; if (!IsEmpty()) { @@ -542,12 +561,12 @@ void RRSendQueue::RestoreFromState(const DcSctpSocketHandoverState& state) { for (const DcSctpSocketHandoverState::OutgoingStream& state_stream : state.tx.streams) { StreamID stream_id(state_stream.id); - streams_.emplace(stream_id, OutgoingStream( - stream_id, - [this, stream_id]() { - on_buffered_amount_low_(stream_id); - }, - total_buffered_amount_, &state_stream)); + streams_.emplace( + stream_id, + OutgoingStream( + stream_id, StreamPriority(state_stream.priority), + [this, stream_id]() { on_buffered_amount_low_(stream_id); }, + total_buffered_amount_, &state_stream)); } } } // namespace dcsctp diff --git a/net/dcsctp/tx/rr_send_queue.h b/net/dcsctp/tx/rr_send_queue.h index 57a43ccd66..7ddb426ec9 100644 --- a/net/dcsctp/tx/rr_send_queue.h +++ b/net/dcsctp/tx/rr_send_queue.h @@ -43,6 +43,7 @@ class RRSendQueue : public SendQueue { public: RRSendQueue(absl::string_view log_prefix, size_t buffer_size, + StreamPriority default_priority, std::function on_buffered_amount_low, size_t total_buffered_amount_low_threshold, std::function on_total_buffered_amount_low); @@ -79,6 +80,8 @@ class RRSendQueue : public SendQueue { size_t buffered_amount_low_threshold(StreamID stream_id) const override; void SetBufferedAmountLowThreshold(StreamID stream_id, size_t bytes) override; + void SetStreamPriority(StreamID stream_id, StreamPriority priority); + StreamPriority GetStreamPriority(StreamID stream_id) const; HandoverReadinessStatus GetHandoverReadiness() const; void AddHandoverState(DcSctpSocketHandoverState& state); void RestoreFromState(const DcSctpSocketHandoverState& state); @@ -112,10 +115,12 @@ class RRSendQueue : public SendQueue { public: OutgoingStream( StreamID stream_id, + StreamPriority priority, std::function on_buffered_amount_low, ThresholdWatcher& total_buffered_amount, const DcSctpSocketHandoverState::OutgoingStream* state = nullptr) : stream_id_(stream_id), + priority_(priority), next_unordered_mid_(MID(state ? state->next_unordered_mid : 0)), next_ordered_mid_(MID(state ? state->next_ordered_mid : 0)), next_ssn_(SSN(state ? state->next_ssn : 0)), @@ -166,6 +171,9 @@ class RRSendQueue : public SendQueue { // expired non-partially sent message. bool HasDataToSend(TimeMs now); + void set_priority(StreamPriority priority) { priority_ = priority; } + StreamPriority priority() const { return priority_; } + void AddHandoverState( DcSctpSocketHandoverState::OutgoingStream& state) const; @@ -218,6 +226,7 @@ class RRSendQueue : public SendQueue { bool IsConsistent() const; const StreamID stream_id_; + StreamPriority priority_; PauseState pause_state_ = PauseState::kNotPaused; // MIDs are different for unordered and ordered messages sent on a stream. MID next_unordered_mid_; @@ -247,6 +256,7 @@ class RRSendQueue : public SendQueue { const std::string log_prefix_; const size_t buffer_size_; + const StreamPriority default_priority_; // Called when the buffered amount is below what has been set using // `SetBufferedAmountLowThreshold`. diff --git a/net/dcsctp/tx/rr_send_queue_test.cc b/net/dcsctp/tx/rr_send_queue_test.cc index fbbce58de1..3966c17e58 100644 --- a/net/dcsctp/tx/rr_send_queue_test.cc +++ b/net/dcsctp/tx/rr_send_queue_test.cc @@ -32,6 +32,7 @@ constexpr TimeMs kNow = TimeMs(0); constexpr StreamID kStreamID(1); constexpr PPID kPPID(53); constexpr size_t kMaxQueueSize = 1000; +constexpr StreamPriority kDefaultPriority(10); constexpr size_t kBufferedAmountLowThreshold = 500; constexpr size_t kOneFragmentPacketSize = 100; constexpr size_t kTwoFragmentPacketSize = 101; @@ -41,6 +42,7 @@ class RRSendQueueTest : public testing::Test { RRSendQueueTest() : buf_("log: ", kMaxQueueSize, + kDefaultPriority, on_buffered_amount_low_.AsStdFunction(), kBufferedAmountLowThreshold, on_total_buffered_amount_low_.AsStdFunction()) {} @@ -759,5 +761,39 @@ TEST_F(RRSendQueueTest, WillStayInAStreamAsLongAsThatMessageIsSending) { EXPECT_FALSE(buf_.Produce(kNow, kOneFragmentPacketSize).has_value()); } + +TEST_F(RRSendQueueTest, StreamsHaveInitialPriority) { + EXPECT_EQ(buf_.GetStreamPriority(StreamID(1)), kDefaultPriority); + + buf_.Add(kNow, DcSctpMessage(StreamID(2), kPPID, std::vector(40))); + EXPECT_EQ(buf_.GetStreamPriority(StreamID(2)), kDefaultPriority); +} + +TEST_F(RRSendQueueTest, CanChangeStreamPriority) { + buf_.SetStreamPriority(StreamID(1), StreamPriority(42)); + EXPECT_EQ(buf_.GetStreamPriority(StreamID(1)), StreamPriority(42)); + + buf_.Add(kNow, DcSctpMessage(StreamID(2), kPPID, std::vector(40))); + buf_.SetStreamPriority(StreamID(2), StreamPriority(42)); + EXPECT_EQ(buf_.GetStreamPriority(StreamID(2)), StreamPriority(42)); +} + +TEST_F(RRSendQueueTest, WillHandoverPriority) { + buf_.SetStreamPriority(StreamID(1), StreamPriority(42)); + + buf_.Add(kNow, DcSctpMessage(StreamID(2), kPPID, std::vector(40))); + buf_.SetStreamPriority(StreamID(2), StreamPriority(42)); + + DcSctpSocketHandoverState state; + buf_.AddHandoverState(state); + + RRSendQueue q2("log: ", kMaxQueueSize, kDefaultPriority, + on_buffered_amount_low_.AsStdFunction(), + kBufferedAmountLowThreshold, + on_total_buffered_amount_low_.AsStdFunction()); + q2.RestoreFromState(state); + EXPECT_EQ(q2.GetStreamPriority(StreamID(1)), StreamPriority(42)); + EXPECT_EQ(q2.GetStreamPriority(StreamID(2)), StreamPriority(42)); +} } // namespace } // namespace dcsctp