diff --git a/api/transport/sctp_transport_factory_interface.h b/api/transport/sctp_transport_factory_interface.h index 403899bcf1..cff6b4bbca 100644 --- a/api/transport/sctp_transport_factory_interface.h +++ b/api/transport/sctp_transport_factory_interface.h @@ -18,12 +18,9 @@ // These classes are not part of the API, and are treated as opaque pointers. namespace cricket { class SctpTransportInternal; +class DtlsTransportInternal; } // namespace cricket -namespace rtc { -class PacketTransportInternal; -} // namespace rtc - namespace webrtc { // Factory class which can be used to allow fake SctpTransports to be injected @@ -37,7 +34,7 @@ class SctpTransportFactoryInterface { // Create an SCTP transport using `channel` for the underlying transport. virtual std::unique_ptr CreateSctpTransport( const Environment& env, - rtc::PacketTransportInternal* channel) = 0; + cricket::DtlsTransportInternal* channel) = 0; }; } // namespace webrtc diff --git a/media/BUILD.gn b/media/BUILD.gn index c4ba3fad40..a931c6c489 100644 --- a/media/BUILD.gn +++ b/media/BUILD.gn @@ -698,6 +698,7 @@ rtc_source_set("rtc_data_sctp_transport_internal") { "../api:priority", "../api:rtc_error", "../api/transport:datagram_transport_interface", + "../p2p:dtls_transport_internal", "../p2p:packet_transport_internal", "../rtc_base:copy_on_write_buffer", "../rtc_base:threading", @@ -714,16 +715,21 @@ if (rtc_build_dcsctp) { ":media_channel", ":rtc_data_sctp_transport_internal", "../api:array_view", + "../api:dtls_transport_interface", "../api:libjingle_peerconnection_api", "../api:priority", + "../api:rtc_error", + "../api:sequence_checker", "../api/environment", "../api/task_queue:pending_task_safety_flag", "../api/task_queue:task_queue", + "../api/transport:datagram_transport_interface", "../net/dcsctp/public:factory", "../net/dcsctp/public:socket", "../net/dcsctp/public:types", "../net/dcsctp/public:utils", "../net/dcsctp/timer:task_queue_timeout", + "../p2p:dtls_transport_internal", "../p2p:packet_transport_internal", "../rtc_base:checks", "../rtc_base:copy_on_write_buffer", @@ -753,6 +759,7 @@ rtc_library("rtc_data_sctp_transport_factory") { ":rtc_data_sctp_transport_internal", "../api/environment", "../api/transport:sctp_transport_factory_interface", + "../p2p:dtls_transport_internal", "../rtc_base:threading", "../rtc_base/system:unused", ] @@ -953,6 +960,7 @@ if (rtc_include_tests) { "../api/task_queue", "../api/test/video:function_video_factory", "../api/transport:bitrate_settings", + "../api/transport:datagram_transport_interface", "../api/transport:field_trial_based_config", "../api/transport/rtp:rtp_source", "../api/units:data_rate", diff --git a/media/sctp/dcsctp_transport.cc b/media/sctp/dcsctp_transport.cc index 9d35fc236a..f33bf82d21 100644 --- a/media/sctp/dcsctp_transport.cc +++ b/media/sctp/dcsctp_transport.cc @@ -11,24 +11,39 @@ #include "media/sctp/dcsctp_transport.h" #include +#include #include +#include #include +#include #include +#include #include #include #include "absl/strings/string_view.h" #include "api/array_view.h" #include "api/data_channel_interface.h" +#include "api/dtls_transport_interface.h" #include "api/environment/environment.h" #include "api/priority.h" -#include "media/base/media_channel.h" +#include "api/rtc_error.h" +#include "api/sequence_checker.h" +#include "api/task_queue/task_queue_base.h" +#include "api/transport/data_channel_transport_interface.h" +#include "media/sctp/sctp_transport_internal.h" +#include "net/dcsctp/public/dcsctp_message.h" +#include "net/dcsctp/public/dcsctp_options.h" +#include "net/dcsctp/public/dcsctp_socket.h" #include "net/dcsctp/public/dcsctp_socket_factory.h" #include "net/dcsctp/public/packet_observer.h" #include "net/dcsctp/public/text_pcap_packet_observer.h" +#include "net/dcsctp/public/timeout.h" #include "net/dcsctp/public/types.h" #include "p2p/base/packet_transport_internal.h" +#include "p2p/dtls/dtls_transport_internal.h" #include "rtc_base/checks.h" +#include "rtc_base/copy_on_write_buffer.h" #include "rtc_base/logging.h" #include "rtc_base/network/received_packet.h" #include "rtc_base/socket.h" @@ -119,15 +134,16 @@ bool IsEmptyPPID(dcsctp::PPID ppid) { DcSctpTransport::DcSctpTransport(const Environment& env, rtc::Thread* network_thread, - rtc::PacketTransportInternal* transport) + cricket::DtlsTransportInternal* transport) : DcSctpTransport(env, network_thread, transport, std::make_unique()) {} + DcSctpTransport::DcSctpTransport( const Environment& env, rtc::Thread* network_thread, - rtc::PacketTransportInternal* transport, + cricket::DtlsTransportInternal* transport, std::unique_ptr socket_factory) : network_thread_(network_thread), transport_(transport), @@ -168,7 +184,7 @@ void DcSctpTransport::SetDataChannelSink(DataChannelSink* sink) { } void DcSctpTransport::SetDtlsTransport( - rtc::PacketTransportInternal* transport) { + cricket::DtlsTransportInternal* transport) { RTC_DCHECK_RUN_ON(network_thread_); DisconnectTransportSignals(); transport_ = transport; @@ -662,6 +678,11 @@ void DcSctpTransport::ConnectTransportSignals() { data_channel_sink_->OnTransportClosed({}); } }); + transport_->SubscribeDtlsTransportState( + this, [this](cricket::DtlsTransportInternal* transport, + DtlsTransportState state) { + OnDtlsTransportState(transport, state); + }); } void DcSctpTransport::DisconnectTransportSignals() { @@ -672,6 +693,7 @@ void DcSctpTransport::DisconnectTransportSignals() { transport_->SignalWritableState.disconnect(this); transport_->DeregisterReceivedPacketCallback(this); transport_->SetOnCloseCallback(nullptr); + transport_->UnsubscribeDtlsTransportState(this); } void DcSctpTransport::OnTransportWritableState( @@ -680,10 +702,30 @@ void DcSctpTransport::OnTransportWritableState( RTC_DCHECK_EQ(transport_, transport); RTC_DLOG(LS_VERBOSE) << debug_name_ << "->OnTransportWritableState(), writable=" - << transport->writable(); + << transport->writable() << " socket: " + << (socket_ ? std::to_string( + static_cast(socket_->state())) + : "UNSET"); MaybeConnectSocket(); } +void DcSctpTransport::OnDtlsTransportState( + cricket::DtlsTransportInternal* transport, + webrtc::DtlsTransportState state) { + if (state == DtlsTransportState::kNew && socket_) { + // IF DTLS restart (DtlsTransportState::kNew) + // THEN + // restart socket so that we send an SCPT init + // before any outgoing messages. This is needed + // after DTLS fingerprint changed since peer will discard + // messages with crypto derived from old fingerprint. + RTC_DLOG(LS_INFO) << debug_name_ << " DTLS restart"; + dcsctp::DcSctpOptions options = socket_->options(); + socket_.reset(); + Start(options.local_port, options.remote_port, options.max_message_size); + } +} + void DcSctpTransport::OnTransportReadPacket( rtc::PacketTransportInternal* /* transport */, const rtc::ReceivedPacket& packet) { diff --git a/media/sctp/dcsctp_transport.h b/media/sctp/dcsctp_transport.h index 030babed78..f80214b0d5 100644 --- a/media/sctp/dcsctp_transport.h +++ b/media/sctp/dcsctp_transport.h @@ -27,6 +27,7 @@ #include "net/dcsctp/public/types.h" #include "net/dcsctp/timer/task_queue_timeout.h" #include "p2p/base/packet_transport_internal.h" +#include "p2p/dtls/dtls_transport_internal.h" #include "rtc_base/containers/flat_map.h" #include "rtc_base/copy_on_write_buffer.h" #include "rtc_base/network/received_packet.h" @@ -44,17 +45,17 @@ class DcSctpTransport : public cricket::SctpTransportInternal, public: DcSctpTransport(const Environment& env, rtc::Thread* network_thread, - rtc::PacketTransportInternal* transport); + cricket::DtlsTransportInternal* transport); DcSctpTransport(const Environment& env, rtc::Thread* network_thread, - rtc::PacketTransportInternal* transport, + cricket::DtlsTransportInternal* transport, std::unique_ptr socket_factory); ~DcSctpTransport() override; // cricket::SctpTransportInternal void SetOnConnectedCallback(std::function callback) override; void SetDataChannelSink(DataChannelSink* sink) override; - void SetDtlsTransport(rtc::PacketTransportInternal* transport) override; + void SetDtlsTransport(cricket::DtlsTransportInternal* transport) override; bool Start(int local_sctp_port, int remote_sctp_port, int max_message_size) override; @@ -102,10 +103,12 @@ class DcSctpTransport : public cricket::SctpTransportInternal, void OnTransportWritableState(rtc::PacketTransportInternal* transport); void OnTransportReadPacket(rtc::PacketTransportInternal* transport, const rtc::ReceivedPacket& packet); + void OnDtlsTransportState(cricket::DtlsTransportInternal* transport, + webrtc::DtlsTransportState); void MaybeConnectSocket(); rtc::Thread* network_thread_; - rtc::PacketTransportInternal* transport_; + cricket::DtlsTransportInternal* transport_; Environment env_; Random random_; diff --git a/media/sctp/dcsctp_transport_unittest.cc b/media/sctp/dcsctp_transport_unittest.cc index eb9dc67d71..a6de226032 100644 --- a/media/sctp/dcsctp_transport_unittest.cc +++ b/media/sctp/dcsctp_transport_unittest.cc @@ -11,20 +11,27 @@ #include "media/sctp/dcsctp_transport.h" #include +#include #include #include "api/environment/environment.h" #include "api/environment/environment_factory.h" #include "api/priority.h" +#include "api/rtc_error.h" +#include "api/transport/data_channel_transport_interface.h" +#include "net/dcsctp/public/dcsctp_options.h" +#include "net/dcsctp/public/dcsctp_socket.h" #include "net/dcsctp/public/mock_dcsctp_socket.h" #include "net/dcsctp/public/mock_dcsctp_socket_factory.h" #include "net/dcsctp/public/types.h" -#include "p2p/base/fake_packet_transport.h" +#include "p2p/dtls/fake_dtls_transport.h" +#include "rtc_base/copy_on_write_buffer.h" +#include "rtc_base/thread.h" +#include "test/gmock.h" #include "test/gtest.h" using ::testing::_; using ::testing::ByMove; -using ::testing::DoAll; using ::testing::ElementsAre; using ::testing::InSequence; using ::testing::Invoke; @@ -36,6 +43,9 @@ namespace webrtc { namespace { +constexpr char kTransportName[] = "transport"; +constexpr int kComponent = 77; + const PriorityValue kDefaultPriority = PriorityValue(Priority::kLow); class MockDataChannelSink : public DataChannelSink { @@ -58,7 +68,7 @@ static_assert(!std::is_abstract_v); class Peer { public: Peer() - : fake_packet_transport_("transport"), + : fake_dtls_transport_(kTransportName, kComponent), simulated_clock_(1000), env_(CreateEnvironment(&simulated_clock_)) { auto socket_ptr = std::make_unique(); @@ -71,13 +81,13 @@ class Peer { .WillOnce(Return(ByMove(std::move(socket_ptr)))); sctp_transport_ = std::make_unique( - env_, rtc::Thread::Current(), &fake_packet_transport_, + env_, rtc::Thread::Current(), &fake_dtls_transport_, std::move(mock_dcsctp_socket_factory)); sctp_transport_->SetDataChannelSink(&sink_); sctp_transport_->SetOnConnectedCallback([this]() { sink_.OnConnected(); }); } - rtc::FakePacketTransport fake_packet_transport_; + cricket::FakeDtlsTransport fake_dtls_transport_; webrtc::SimulatedClock simulated_clock_; Environment env_; dcsctp::MockDcSctpSocket* socket_; @@ -89,7 +99,7 @@ class Peer { TEST(DcSctpTransportTest, OpenSequence) { rtc::AutoThread main_thread; Peer peer_a; - peer_a.fake_packet_transport_.SetWritable(true); + peer_a.fake_dtls_transport_.SetWritable(true); EXPECT_CALL(*peer_a.socket_, Connect) .Times(1) @@ -107,8 +117,8 @@ TEST(DcSctpTransportTest, CloseSequence) { rtc::AutoThread main_thread; Peer peer_a; Peer peer_b; - peer_a.fake_packet_transport_.SetDestination(&peer_b.fake_packet_transport_, - false); + peer_a.fake_dtls_transport_.SetDestination(&peer_b.fake_dtls_transport_, + false); { InSequence sequence; @@ -153,8 +163,8 @@ TEST(DcSctpTransportTest, CloseSequenceSimultaneous) { rtc::AutoThread main_thread; Peer peer_a; Peer peer_b; - peer_a.fake_packet_transport_.SetDestination(&peer_b.fake_packet_transport_, - false); + peer_a.fake_dtls_transport_.SetDestination(&peer_b.fake_dtls_transport_, + false); { InSequence sequence; diff --git a/media/sctp/sctp_transport_factory.cc b/media/sctp/sctp_transport_factory.cc index 72aa80c94d..9f6424b3c0 100644 --- a/media/sctp/sctp_transport_factory.cc +++ b/media/sctp/sctp_transport_factory.cc @@ -11,6 +11,7 @@ #include "media/sctp/sctp_transport_factory.h" #include "api/environment/environment.h" +#include "p2p/dtls/dtls_transport_internal.h" #include "rtc_base/system/unused.h" #ifdef WEBRTC_HAVE_DCSCTP @@ -25,9 +26,8 @@ SctpTransportFactory::SctpTransportFactory(rtc::Thread* network_thread) } std::unique_ptr -SctpTransportFactory::CreateSctpTransport( - const webrtc::Environment& env, - rtc::PacketTransportInternal* transport) { +SctpTransportFactory::CreateSctpTransport(const webrtc::Environment& env, + DtlsTransportInternal* transport) { std::unique_ptr result; #ifdef WEBRTC_HAVE_DCSCTP result = std::unique_ptr( diff --git a/media/sctp/sctp_transport_factory.h b/media/sctp/sctp_transport_factory.h index 14eb648376..f1eefe566c 100644 --- a/media/sctp/sctp_transport_factory.h +++ b/media/sctp/sctp_transport_factory.h @@ -26,7 +26,7 @@ class SctpTransportFactory : public webrtc::SctpTransportFactoryInterface { std::unique_ptr CreateSctpTransport( const webrtc::Environment& env, - rtc::PacketTransportInternal* transport) override; + DtlsTransportInternal* transport) override; private: rtc::Thread* network_thread_; diff --git a/media/sctp/sctp_transport_internal.h b/media/sctp/sctp_transport_internal.h index 2dad0cd2bf..24e6fb83a0 100644 --- a/media/sctp/sctp_transport_internal.h +++ b/media/sctp/sctp_transport_internal.h @@ -23,6 +23,7 @@ #include "api/transport/data_channel_transport_interface.h" #include "media/base/media_channel.h" #include "p2p/base/packet_transport_internal.h" +#include "p2p/dtls/dtls_transport_internal.h" #include "rtc_base/copy_on_write_buffer.h" #include "rtc_base/thread.h" @@ -83,7 +84,7 @@ class SctpTransportInternal { // Changes what underlying DTLS transport is uses. Used when switching which // bundled transport the SctpTransport uses. - virtual void SetDtlsTransport(rtc::PacketTransportInternal* transport) = 0; + virtual void SetDtlsTransport(cricket::DtlsTransportInternal* transport) = 0; // When Start is called, connects as soon as possible; this can be called // before DTLS completes, in which case the connection will begin when DTLS diff --git a/pc/data_channel_integrationtest.cc b/pc/data_channel_integrationtest.cc index d779b23142..468366252d 100644 --- a/pc/data_channel_integrationtest.cc +++ b/pc/data_channel_integrationtest.cc @@ -10,17 +10,22 @@ #include +#include #include #include +#include #include #include #include +#include #include #include "absl/algorithm/container.h" #include "api/data_channel_interface.h" #include "api/dtls_transport_interface.h" +#include "api/jsep.h" #include "api/peer_connection_interface.h" +#include "api/rtc_error.h" #include "api/scoped_refptr.h" #include "api/sctp_transport_interface.h" #include "api/stats/rtc_stats_report.h" @@ -30,6 +35,7 @@ #include "p2p/base/transport_info.h" #include "pc/media_session.h" #include "pc/session_description.h" +#include "pc/test/fake_rtc_certificate_generator.h" #include "pc/test/integration_test_helpers.h" #include "pc/test/mock_peer_connection_observers.h" #include "rtc_base/copy_on_write_buffer.h" @@ -38,6 +44,7 @@ #include "rtc_base/gunit.h" #include "rtc_base/logging.h" #include "rtc_base/numerics/safe_conversions.h" +#include "rtc_base/strings/string_builder.h" #include "rtc_base/virtual_socket_server.h" #include "test/gmock.h" #include "test/gtest.h" @@ -1217,6 +1224,111 @@ TEST_F(DataChannelIntegrationTestUnifiedPlan, ASSERT_TRUE_WAIT(!callee()->data_observer()->IsOpen(), kDefaultTimeout); } +TEST_F(DataChannelIntegrationTestUnifiedPlan, DtlsRestart) { + RTCConfiguration config; + ASSERT_TRUE(CreatePeerConnectionWrappersWithConfig(config, config)); + PeerConnectionDependencies dependencies(nullptr); + std::unique_ptr cert_generator( + new FakeRTCCertificateGenerator()); + cert_generator->use_alternate_key(); + dependencies.cert_generator = std::move(cert_generator); + auto callee2 = CreatePeerConnectionWrapper("Callee2", nullptr, &config, + std::move(dependencies), nullptr, + /*reset_encoder_factory=*/false, + /*reset_decoder_factory=*/false); + ConnectFakeSignaling(); + + DataChannelInit dc_init; + dc_init.negotiated = true; + dc_init.id = 77; + caller()->CreateDataChannel("label", &dc_init); + callee()->CreateDataChannel("label", &dc_init); + callee2->CreateDataChannel("label", &dc_init); + + std::unique_ptr offer; + callee()->SetReceivedSdpMunger( + [&](std::unique_ptr& sdp) { + offer = sdp->Clone(); + }); + callee()->SetGeneratedSdpMunger( + [](std::unique_ptr& sdp) { + SetSdpType(sdp, SdpType::kPrAnswer); + }); + std::unique_ptr answer; + caller()->SetReceivedSdpMunger( + [&](std::unique_ptr& sdp) { + answer = sdp->Clone(); + }); + caller()->CreateAndSetAndSignalOffer(); + ASSERT_FALSE(HasFailure()); + EXPECT_EQ(caller()->pc()->signaling_state(), + PeerConnectionInterface::kHaveRemotePrAnswer); + EXPECT_EQ(callee()->pc()->signaling_state(), + PeerConnectionInterface::kHaveLocalPrAnswer); + EXPECT_EQ_WAIT(DataChannelInterface::kOpen, caller()->data_channel()->state(), + kDefaultTimeout); + EXPECT_EQ_WAIT(DataChannelInterface::kOpen, callee()->data_channel()->state(), + kDefaultTimeout); + + callee2->set_signaling_message_receiver(caller()); + + std::atomic caller_sent_on_dc(0); + caller()->set_connection_change_callback( + [&](PeerConnectionInterface::PeerConnectionState new_state) { + if (new_state == + PeerConnectionInterface::PeerConnectionState::kConnected) { + caller()->data_channel()->SendAsync( + DataBuffer("KESO"), [&](RTCError err) { + caller_sent_on_dc.store(err.ok() ? 1 : -1); + }); + } + }); + + std::atomic callee2_sent_on_dc(0); + callee2->set_connection_change_callback( + [&](PeerConnectionInterface::PeerConnectionState new_state) { + if (new_state == + PeerConnectionInterface::PeerConnectionState::kConnected && + callee2->data_channel()->state() == DataChannelInterface::kOpen) { + callee2->data_channel()->SendAsync( + DataBuffer("KENT"), [&](RTCError err) { + callee2_sent_on_dc.store(err.ok() ? 1 : -1); + }); + } + }); + + callee2->data_observer()->set_state_change_callback( + [&](DataChannelInterface::DataState new_state) { + if (callee2->pc()->peer_connection_state() == + PeerConnectionInterface::PeerConnectionState::kConnected && + new_state == DataChannelInterface::kOpen) { + callee2->data_channel()->SendAsync( + DataBuffer("KENT"), [&](RTCError err) { + callee2_sent_on_dc.store(err.ok() ? 1 : -1); + }); + } + }); + + std::string offer_sdp; + EXPECT_TRUE(offer->ToString(&offer_sdp)); + callee2->ReceiveSdpMessage(SdpType::kOffer, offer_sdp); + EXPECT_EQ(caller()->pc()->signaling_state(), + PeerConnectionInterface::kStable); + EXPECT_EQ(callee2->pc()->signaling_state(), PeerConnectionInterface::kStable); + + EXPECT_EQ_WAIT(PeerConnectionInterface::PeerConnectionState::kConnected, + caller()->pc()->peer_connection_state(), kDefaultTimeout); + EXPECT_EQ_WAIT(PeerConnectionInterface::PeerConnectionState::kConnected, + callee2->pc()->peer_connection_state(), kDefaultTimeout); + + ASSERT_TRUE_WAIT(caller_sent_on_dc.load() != 0, kDefaultTimeout); + ASSERT_TRUE_WAIT(callee2_sent_on_dc.load() != 0, kDefaultTimeout); + EXPECT_EQ_WAIT("KENT", caller()->data_observer()->last_message(), + kDefaultTimeout); + EXPECT_EQ_WAIT("KESO", callee2->data_observer()->last_message(), + kDefaultTimeout); +} + #endif // WEBRTC_HAVE_SCTP } // namespace diff --git a/pc/sctp_transport_unittest.cc b/pc/sctp_transport_unittest.cc index 0c873c8c4e..293881a06f 100644 --- a/pc/sctp_transport_unittest.cc +++ b/pc/sctp_transport_unittest.cc @@ -53,7 +53,7 @@ class FakeCricketSctpTransport : public cricket::SctpTransportInternal { on_connected_callback_ = std::move(callback); } void SetDataChannelSink(DataChannelSink* sink) override {} - void SetDtlsTransport(rtc::PacketTransportInternal* transport) override {} + void SetDtlsTransport(cricket::DtlsTransportInternal* transport) override {} bool Start(int local_port, int remote_port, int max_message_size) override { return true; } @@ -114,7 +114,7 @@ class TestSctpTransportObserver : public SctpTransportObserverInterface { const std::vector& States() { return states_; } - const SctpTransportInformation LastReceivedInformation() { return info_; } + SctpTransportInformation LastReceivedInformation() { return info_; } private: std::vector states_; diff --git a/pc/test/integration_test_helpers.h b/pc/test/integration_test_helpers.h index 7d9216e752..c6e438a948 100644 --- a/pc/test/integration_test_helpers.h +++ b/pc/test/integration_test_helpers.h @@ -454,7 +454,7 @@ class PeerConnectionIntegrationWrapper : public PeerConnectionObserver, return data_channels_; } - const MockDataChannelObserver* data_observer() const { + MockDataChannelObserver* data_observer() const { if (data_observers_.size() == 0) { return nullptr; } @@ -740,6 +740,11 @@ class PeerConnectionIntegrationWrapper : public PeerConnectionObserver, return 0; } + void set_connection_change_callback( + std::function func) { + connection_change_callback_ = std::move(func); + } + private: // Constructor used by friend class PeerConnectionIntegrationBaseTest. explicit PeerConnectionIntegrationWrapper(const std::string& debug_name) @@ -780,11 +785,6 @@ class PeerConnectionIntegrationWrapper : public PeerConnectionObserver, : nullptr; } - void set_signaling_message_receiver( - SignalingMessageReceiver* signaling_message_receiver) { - signaling_message_receiver_ = signaling_message_receiver; - } - void set_signaling_delay_ms(int delay_ms) { signaling_delay_ms_ = delay_ms; } void set_signal_ice_candidates(bool signal) { @@ -962,6 +962,12 @@ class PeerConnectionIntegrationWrapper : public PeerConnectionObserver, } // SignalingMessageReceiver callbacks. + public: + void set_signaling_message_receiver( + SignalingMessageReceiver* signaling_message_receiver) { + signaling_message_receiver_ = signaling_message_receiver; + } + void ReceiveSdpMessage(SdpType type, const std::string& msg) override { if (type == SdpType::kOffer) { HandleIncomingOffer(msg); @@ -982,6 +988,7 @@ class PeerConnectionIntegrationWrapper : public PeerConnectionObserver, EXPECT_TRUE(result.value().ok()); } + private: // PeerConnectionObserver callbacks. void OnSignalingChange( PeerConnectionInterface::SignalingState new_state) override { @@ -1021,9 +1028,13 @@ class PeerConnectionIntegrationWrapper : public PeerConnectionObserver, PeerConnectionInterface::IceConnectionState new_state) override { standardized_ice_connection_state_history_.push_back(new_state); } + void OnConnectionChange( PeerConnectionInterface::PeerConnectionState new_state) override { peer_connection_state_history_.push_back(new_state); + if (connection_change_callback_) { + connection_change_callback_(new_state); + } } void OnIceGatheringChange( @@ -1072,6 +1083,7 @@ class PeerConnectionIntegrationWrapper : public PeerConnectionObserver, SendIceMessage(candidate->sdp_mid(), candidate->sdp_mline_index(), ice_sdp); last_candidate_gathered_ = candidate->candidate(); } + void OnIceCandidateError(const std::string& address, int port, const std::string& url, @@ -1172,6 +1184,9 @@ class PeerConnectionIntegrationWrapper : public PeerConnectionObserver, uint64_t audio_concealed_stat_ = 0; std::string rtp_stats_id_; + std::function + connection_change_callback_ = nullptr; + ScopedTaskSafety task_safety_; friend class PeerConnectionIntegrationBaseTest; diff --git a/pc/test/mock_peer_connection_observers.h b/pc/test/mock_peer_connection_observers.h index 063e0d68c5..600850e416 100644 --- a/pc/test/mock_peer_connection_observers.h +++ b/pc/test/mock_peer_connection_observers.h @@ -390,7 +390,13 @@ class MockDataChannelObserver : public DataChannelObserver { void OnBufferedAmountChange(uint64_t previous_amount) override {} - void OnStateChange() override { states_.push_back(channel_->state()); } + void OnStateChange() override { + states_.push_back(channel_->state()); + if (state_change_callback_) { + state_change_callback_(states_.back()); + } + } + void OnMessage(const DataBuffer& buffer) override { messages_.push_back( {std::string(buffer.data.data(), buffer.data.size()), @@ -417,10 +423,16 @@ class MockDataChannelObserver : public DataChannelObserver { return states_; } + void set_state_change_callback( + std::function func) { + state_change_callback_ = std::move(func); + } + private: rtc::scoped_refptr channel_; std::vector states_; std::vector messages_; + std::function state_change_callback_; }; class MockStatsObserver : public StatsObserver { diff --git a/test/pc/sctp/fake_sctp_transport.h b/test/pc/sctp/fake_sctp_transport.h index b5d0866799..3f71d448da 100644 --- a/test/pc/sctp/fake_sctp_transport.h +++ b/test/pc/sctp/fake_sctp_transport.h @@ -25,7 +25,7 @@ class FakeSctpTransport : public cricket::SctpTransportInternal { public: void SetOnConnectedCallback(std::function callback) override {} void SetDataChannelSink(webrtc::DataChannelSink* sink) override {} - void SetDtlsTransport(rtc::PacketTransportInternal* transport) override {} + void SetDtlsTransport(cricket::DtlsTransportInternal* transport) override {} bool Start(int local_port, int remote_port, int max_message_size) override { local_port_.emplace(local_port); remote_port_.emplace(remote_port); @@ -73,7 +73,7 @@ class FakeSctpTransportFactory : public webrtc::SctpTransportFactoryInterface { public: std::unique_ptr CreateSctpTransport( const webrtc::Environment& env, - rtc::PacketTransportInternal*) override { + cricket::DtlsTransportInternal*) override { last_fake_sctp_transport_ = new FakeSctpTransport(); return std::unique_ptr( last_fake_sctp_transport_);