diff --git a/media/BUILD.gn b/media/BUILD.gn index 74fd7a890c..b3b249ce95 100644 --- a/media/BUILD.gn +++ b/media/BUILD.gn @@ -423,6 +423,7 @@ if (rtc_build_dcsctp) { "../rtc_base:socket", "../rtc_base:stringutils", "../rtc_base:threading", + "../rtc_base/containers:flat_set", "../rtc_base/task_utils:pending_task_safety_flag", "../rtc_base/task_utils:to_queued_task", "../rtc_base/third_party/sigslot:sigslot", @@ -692,6 +693,16 @@ if (rtc_include_tests) { if (is_ios) { deps += [ ":rtc_media_unittests_bundle_data" ] } + + if (rtc_build_dcsctp) { + sources += [ "sctp/dcsctp_transport_unittest.cc" ] + deps += [ + ":rtc_data_dcsctp_transport", + "../net/dcsctp/public:factory", + "../net/dcsctp/public:mocks", + "../net/dcsctp/public:socket", + ] + } } } } diff --git a/media/sctp/dcsctp_transport.cc b/media/sctp/dcsctp_transport.cc index 0a671ced87..6527f6fcaa 100644 --- a/media/sctp/dcsctp_transport.cc +++ b/media/sctp/dcsctp_transport.cc @@ -116,10 +116,21 @@ bool IsEmptyPPID(dcsctp::PPID ppid) { DcSctpTransport::DcSctpTransport(rtc::Thread* network_thread, rtc::PacketTransportInternal* transport, Clock* clock) + : DcSctpTransport(network_thread, + transport, + clock, + std::make_unique()) {} + +DcSctpTransport::DcSctpTransport( + rtc::Thread* network_thread, + rtc::PacketTransportInternal* transport, + Clock* clock, + std::unique_ptr socket_factory) : network_thread_(network_thread), transport_(transport), clock_(clock), random_(clock_->TimeInMicroseconds()), + socket_factory_(std::move(socket_factory)), task_queue_timeout_factory_( *network_thread, [this]() { return TimeMillis(); }, @@ -175,9 +186,8 @@ bool DcSctpTransport::Start(int local_sctp_port, std::make_unique(debug_name_); } - dcsctp::DcSctpSocketFactory factory; - socket_ = - factory.Create(debug_name_, *this, std::move(packet_observer), options); + socket_ = socket_factory_->Create(debug_name_, *this, + std::move(packet_observer), options); } else { if (local_sctp_port != socket_->options().local_port || remote_sctp_port != socket_->options().remote_port) { @@ -202,6 +212,7 @@ bool DcSctpTransport::OpenStream(int sid) { << "): Transport is not started."; return false; } + local_close_.erase(dcsctp::StreamID(static_cast(sid))); return true; } @@ -213,6 +224,7 @@ bool DcSctpTransport::ResetStream(int sid) { return false; } dcsctp::StreamID streams[1] = {dcsctp::StreamID(static_cast(sid))}; + local_close_.insert(streams[0]); socket_->ResetStreams(streams); return true; } @@ -472,7 +484,11 @@ void DcSctpTransport::OnStreamsResetPerformed( RTC_LOG(LS_INFO) << debug_name_ << "->OnStreamsResetPerformed(...): Outgoing stream reset" << ", sid=" << stream_id.value(); - SignalClosingProcedureComplete(stream_id.value()); + if (!local_close_.contains(stream_id)) { + // When the close was not initiated locally, we can signal the end of the + // data channel close procedure when the remote ACKs the reset. + SignalClosingProcedureComplete(stream_id.value()); + } } } @@ -482,8 +498,18 @@ void DcSctpTransport::OnIncomingStreamsReset( RTC_LOG(LS_INFO) << debug_name_ << "->OnIncomingStreamsReset(...): Incoming stream reset" << ", sid=" << stream_id.value(); - SignalClosingProcedureStartedRemotely(stream_id.value()); - SignalClosingProcedureComplete(stream_id.value()); + if (!local_close_.contains(stream_id)) { + // When receiving an incoming stream reset event for a non local close + // procedure, the transport needs to reset the stream in the other + // direction too. + dcsctp::StreamID streams[1] = {stream_id}; + socket_->ResetStreams(streams); + SignalClosingProcedureStartedRemotely(stream_id.value()); + } else { + // The close procedure that was initiated locally is complete when we + // receive and incoming reset event. + SignalClosingProcedureComplete(stream_id.value()); + } } } diff --git a/media/sctp/dcsctp_transport.h b/media/sctp/dcsctp_transport.h index 11c2f829c5..5e3401d471 100644 --- a/media/sctp/dcsctp_transport.h +++ b/media/sctp/dcsctp_transport.h @@ -21,9 +21,11 @@ #include "media/sctp/sctp_transport_internal.h" #include "net/dcsctp/public/dcsctp_options.h" #include "net/dcsctp/public/dcsctp_socket.h" +#include "net/dcsctp/public/dcsctp_socket_factory.h" #include "net/dcsctp/public/types.h" #include "net/dcsctp/timer/task_queue_timeout.h" #include "p2p/base/packet_transport_internal.h" +#include "rtc_base/containers/flat_set.h" #include "rtc_base/copy_on_write_buffer.h" #include "rtc_base/random.h" #include "rtc_base/third_party/sigslot/sigslot.h" @@ -39,6 +41,10 @@ class DcSctpTransport : public cricket::SctpTransportInternal, DcSctpTransport(rtc::Thread* network_thread, rtc::PacketTransportInternal* transport, Clock* clock); + DcSctpTransport(rtc::Thread* network_thread, + rtc::PacketTransportInternal* transport, + Clock* clock, + std::unique_ptr socket_factory); ~DcSctpTransport() override; // cricket::SctpTransportInternal @@ -99,11 +105,13 @@ class DcSctpTransport : public cricket::SctpTransportInternal, Clock* clock_; Random random_; + std::unique_ptr socket_factory_; dcsctp::TaskQueueTimeoutFactory task_queue_timeout_factory_; std::unique_ptr socket_; std::string debug_name_ = "DcSctpTransport"; rtc::CopyOnWriteBuffer receive_buffer_; + flat_set local_close_; bool ready_to_send_data_ = false; }; diff --git a/media/sctp/dcsctp_transport_unittest.cc b/media/sctp/dcsctp_transport_unittest.cc new file mode 100644 index 0000000000..b382dc9548 --- /dev/null +++ b/media/sctp/dcsctp_transport_unittest.cc @@ -0,0 +1,129 @@ +/* + * Copyright 2022 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "media/sctp/dcsctp_transport.h" + +#include +#include + +#include "net/dcsctp/public/mock_dcsctp_socket.h" +#include "net/dcsctp/public/mock_dcsctp_socket_factory.h" +#include "p2p/base/fake_packet_transport.h" +#include "test/gtest.h" + +using ::testing::ByMove; +using ::testing::DoAll; +using ::testing::ElementsAre; +using ::testing::InSequence; +using ::testing::Invoke; +using ::testing::NiceMock; +using ::testing::Return; + +namespace webrtc { + +namespace { +class SctpInternalTransportObserver : public sigslot::has_slots<> { + public: + MOCK_METHOD(void, OnSignalReadyToSendData, ()); + MOCK_METHOD(void, OnSignalAssociationChangeCommunicationUp, ()); + MOCK_METHOD(void, OnSignalClosingProcedureStartedRemotely, (int)); + MOCK_METHOD(void, OnSignalClosingProcedureComplete, (int)); +}; + +class Peer { + public: + Peer() : fake_packet_transport_("transport"), simulated_clock_(1000) { + auto socket_ptr = std::make_unique(); + socket_ = socket_ptr.get(); + + auto mock_dcsctp_socket_factory = + std::make_unique(); + EXPECT_CALL(*mock_dcsctp_socket_factory, Create) + .Times(1) + .WillOnce(Return(ByMove(std::move(socket_ptr)))); + + sctp_transport_ = std::make_unique( + rtc::Thread::Current(), &fake_packet_transport_, &simulated_clock_, + std::move(mock_dcsctp_socket_factory)); + + sctp_transport_->SignalAssociationChangeCommunicationUp.connect( + static_cast(&observer_), + &SctpInternalTransportObserver::OnSignalReadyToSendData); + sctp_transport_->SignalAssociationChangeCommunicationUp.connect( + static_cast(&observer_), + &SctpInternalTransportObserver:: + OnSignalAssociationChangeCommunicationUp); + sctp_transport_->SignalClosingProcedureStartedRemotely.connect( + static_cast(&observer_), + &SctpInternalTransportObserver:: + OnSignalClosingProcedureStartedRemotely); + sctp_transport_->SignalClosingProcedureComplete.connect( + static_cast(&observer_), + &SctpInternalTransportObserver::OnSignalClosingProcedureComplete); + } + + rtc::FakePacketTransport fake_packet_transport_; + webrtc::SimulatedClock simulated_clock_; + dcsctp::MockDcSctpSocket* socket_; + std::unique_ptr sctp_transport_; + NiceMock observer_; +}; +} // namespace + +TEST(DcSctpTransportTest, OpenSequence) { + Peer peer_a; + peer_a.fake_packet_transport_.SetWritable(true); + + EXPECT_CALL(*peer_a.socket_, Connect) + .Times(1) + .WillOnce(Invoke(peer_a.sctp_transport_.get(), + &dcsctp::DcSctpSocketCallbacks::OnConnected)); + EXPECT_CALL(peer_a.observer_, OnSignalReadyToSendData); + EXPECT_CALL(peer_a.observer_, OnSignalAssociationChangeCommunicationUp); + + peer_a.sctp_transport_->Start(5000, 5000, 256 * 1024); +} + +TEST(DcSctpTransportTest, CloseSequence) { + Peer peer_a; + Peer peer_b; + peer_a.fake_packet_transport_.SetDestination(&peer_b.fake_packet_transport_, + false); + { + InSequence sequence; + + EXPECT_CALL(*peer_a.socket_, ResetStreams(ElementsAre(dcsctp::StreamID(1)))) + .WillOnce(DoAll( + Invoke(peer_b.sctp_transport_.get(), + &dcsctp::DcSctpSocketCallbacks::OnIncomingStreamsReset), + Invoke(peer_a.sctp_transport_.get(), + &dcsctp::DcSctpSocketCallbacks::OnStreamsResetPerformed), + Return(dcsctp::ResetStreamsStatus::kPerformed))); + + EXPECT_CALL(*peer_b.socket_, ResetStreams(ElementsAre(dcsctp::StreamID(1)))) + .WillOnce(DoAll( + Invoke(peer_a.sctp_transport_.get(), + &dcsctp::DcSctpSocketCallbacks::OnIncomingStreamsReset), + Invoke(peer_b.sctp_transport_.get(), + &dcsctp::DcSctpSocketCallbacks::OnStreamsResetPerformed), + Return(dcsctp::ResetStreamsStatus::kPerformed))); + + EXPECT_CALL(peer_a.observer_, OnSignalClosingProcedureComplete(1)); + EXPECT_CALL(peer_b.observer_, OnSignalClosingProcedureComplete(1)); + EXPECT_CALL(peer_b.observer_, OnSignalClosingProcedureStartedRemotely(1)); + } + + peer_a.sctp_transport_->Start(5000, 5000, 256 * 1024); + peer_b.sctp_transport_->Start(5000, 5000, 256 * 1024); + peer_a.sctp_transport_->OpenStream(1); + peer_a.sctp_transport_->ResetStream(1); +} + +} // namespace webrtc diff --git a/net/dcsctp/public/BUILD.gn b/net/dcsctp/public/BUILD.gn index 63fd463082..6cb289bf5b 100644 --- a/net/dcsctp/public/BUILD.gn +++ b/net/dcsctp/public/BUILD.gn @@ -57,8 +57,12 @@ rtc_source_set("factory") { rtc_source_set("mocks") { testonly = true - sources = [ "mock_dcsctp_socket.h" ] + sources = [ + "mock_dcsctp_socket.h", + "mock_dcsctp_socket_factory.h", + ] deps = [ + ":factory", ":socket", "../../../test:test_support", ] diff --git a/net/dcsctp/public/dcsctp_socket_factory.cc b/net/dcsctp/public/dcsctp_socket_factory.cc index 338d143424..ebcb5553e3 100644 --- a/net/dcsctp/public/dcsctp_socket_factory.cc +++ b/net/dcsctp/public/dcsctp_socket_factory.cc @@ -20,6 +20,9 @@ #include "net/dcsctp/socket/dcsctp_socket.h" namespace dcsctp { + +DcSctpSocketFactory::~DcSctpSocketFactory() = default; + std::unique_ptr DcSctpSocketFactory::Create( absl::string_view log_prefix, DcSctpSocketCallbacks& callbacks, diff --git a/net/dcsctp/public/dcsctp_socket_factory.h b/net/dcsctp/public/dcsctp_socket_factory.h index dcc68d9b54..ca429d3275 100644 --- a/net/dcsctp/public/dcsctp_socket_factory.h +++ b/net/dcsctp/public/dcsctp_socket_factory.h @@ -20,7 +20,8 @@ namespace dcsctp { class DcSctpSocketFactory { public: - std::unique_ptr Create( + virtual ~DcSctpSocketFactory(); + virtual std::unique_ptr Create( absl::string_view log_prefix, DcSctpSocketCallbacks& callbacks, std::unique_ptr packet_observer, diff --git a/net/dcsctp/public/mock_dcsctp_socket_factory.h b/net/dcsctp/public/mock_dcsctp_socket_factory.h new file mode 100644 index 0000000000..61f05577f2 --- /dev/null +++ b/net/dcsctp/public/mock_dcsctp_socket_factory.h @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2022 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PUBLIC_MOCK_DCSCTP_SOCKET_FACTORY_H_ +#define NET_DCSCTP_PUBLIC_MOCK_DCSCTP_SOCKET_FACTORY_H_ + +#include + +#include "net/dcsctp/public/dcsctp_socket_factory.h" +#include "test/gmock.h" + +namespace dcsctp { + +class MockDcSctpSocketFactory : public DcSctpSocketFactory { + public: + MOCK_METHOD(std::unique_ptr, + Create, + (absl::string_view log_prefix, + DcSctpSocketCallbacks& callbacks, + std::unique_ptr packet_observer, + const DcSctpOptions& options), + (override)); +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PUBLIC_MOCK_DCSCTP_SOCKET_FACTORY_H_ diff --git a/sdk/android/instrumentationtests/src/org/webrtc/PeerConnectionEndToEndTest.java b/sdk/android/instrumentationtests/src/org/webrtc/PeerConnectionEndToEndTest.java index 8efefb3903..f71bd36063 100644 --- a/sdk/android/instrumentationtests/src/org/webrtc/PeerConnectionEndToEndTest.java +++ b/sdk/android/instrumentationtests/src/org/webrtc/PeerConnectionEndToEndTest.java @@ -923,7 +923,6 @@ public class PeerConnectionEndToEndTest { answeringExpectations.expectStateChange(DataChannel.State.CLOSING); offeringExpectations.expectStateChange(DataChannel.State.CLOSED); answeringExpectations.expectStateChange(DataChannel.State.CLOSED); - answeringExpectations.dataChannel.close(); offeringExpectations.dataChannel.close(); assertTrue(offeringExpectations.waitForAllExpectationsToBeSatisfied(DEFAULT_TIMEOUT_SECONDS)); assertTrue(answeringExpectations.waitForAllExpectationsToBeSatisfied(DEFAULT_TIMEOUT_SECONDS)); @@ -1094,7 +1093,6 @@ public class PeerConnectionEndToEndTest { answeringExpectations.expectStateChange(DataChannel.State.CLOSING); offeringExpectations.expectStateChange(DataChannel.State.CLOSED); answeringExpectations.expectStateChange(DataChannel.State.CLOSED); - answeringExpectations.dataChannel.close(); offeringExpectations.dataChannel.close(); assertTrue(offeringExpectations.waitForAllExpectationsToBeSatisfied(DEFAULT_TIMEOUT_SECONDS)); assertTrue(answeringExpectations.waitForAllExpectationsToBeSatisfied(DEFAULT_TIMEOUT_SECONDS));