From 175aa2e95c19a1e02786b09602c6f62b3554f1a2 Mon Sep 17 00:00:00 2001 From: Bjorn Mellem Date: Thu, 8 Nov 2018 11:23:22 -0800 Subject: [PATCH] Implement data channels over media transport. This changes PeerConnection to allow sending and receiving data channel messages over the media transport. If |use_media_transport_for_data_channels| is set, PeerConnection will use a DCT_MEDIA_TRANSPORT mode for data channels. DCT_MEDIA_TRANSPORT acts exactly like DCT_SCTP within the data channel and peer connection layers. On the transport layer, it uses the media transport instead of SCTP. It appears as an RTP data channel in SDP (just as media over media-transport appears as RTP in SDP). Bug: webrtc:9719 Change-Id: I6a90142bd3f43668479c825ed02689dcd0d58b78 Reviewed-on: https://webrtc-review.googlesource.com/c/109740 Commit-Queue: Bjorn Mellem Reviewed-by: Steve Anton Reviewed-by: Niels Moller Cr-Commit-Position: refs/heads/master@{#25575} --- api/test/DEPS | 1 + api/test/loopback_media_transport.h | 119 ++++++- api/test/loopback_media_transport_unittest.cc | 46 +++ media/base/mediaengine.h | 7 +- pc/BUILD.gn | 1 + pc/datachannel.cc | 26 +- pc/datachannel.h | 2 + pc/jseptransportcontroller.cc | 16 +- pc/jseptransportcontroller.h | 4 + pc/peerconnection.cc | 322 ++++++++++++++---- pc/peerconnection.h | 45 +++ pc/peerconnection_datachannel_unittest.cc | 81 ++++- pc/peerconnection_integrationtest.cc | 185 ++++++++-- 13 files changed, 752 insertions(+), 103 deletions(-) diff --git a/api/test/DEPS b/api/test/DEPS index 1fc1f7cb30..9c293b323e 100644 --- a/api/test/DEPS +++ b/api/test/DEPS @@ -9,5 +9,6 @@ specific_include_rules = { "+rtc_base/asyncinvoker.h", "+rtc_base/criticalsection.h", "+rtc_base/thread.h", + "+rtc_base/thread_checker.h", ], } diff --git a/api/test/loopback_media_transport.h b/api/test/loopback_media_transport.h index f3f24d4c98..2524fb47b1 100644 --- a/api/test/loopback_media_transport.h +++ b/api/test/loopback_media_transport.h @@ -11,15 +11,91 @@ #ifndef API_TEST_LOOPBACK_MEDIA_TRANSPORT_H_ #define API_TEST_LOOPBACK_MEDIA_TRANSPORT_H_ +#include #include #include "api/media_transport_interface.h" #include "rtc_base/asyncinvoker.h" #include "rtc_base/criticalsection.h" #include "rtc_base/thread.h" +#include "rtc_base/thread_checker.h" namespace webrtc { +// Wrapper used to hand out unique_ptrs to loopback media transports without +// ownership changes. +class WrapperMediaTransport : public MediaTransportInterface { + public: + explicit WrapperMediaTransport(MediaTransportInterface* wrapped) + : wrapped_(wrapped) {} + + RTCError SendAudioFrame(uint64_t channel_id, + MediaTransportEncodedAudioFrame frame) override { + return wrapped_->SendAudioFrame(channel_id, std::move(frame)); + } + + RTCError SendVideoFrame( + uint64_t channel_id, + const MediaTransportEncodedVideoFrame& frame) override { + return wrapped_->SendVideoFrame(channel_id, frame); + } + + RTCError RequestKeyFrame(uint64_t channel_id) override { + return wrapped_->RequestKeyFrame(channel_id); + } + + void SetReceiveAudioSink(MediaTransportAudioSinkInterface* sink) override { + wrapped_->SetReceiveAudioSink(sink); + } + + void SetReceiveVideoSink(MediaTransportVideoSinkInterface* sink) override { + wrapped_->SetReceiveVideoSink(sink); + } + + void SetTargetTransferRateObserver( + webrtc::TargetTransferRateObserver* observer) override { + wrapped_->SetTargetTransferRateObserver(observer); + } + + void SetMediaTransportStateCallback( + MediaTransportStateCallback* callback) override { + wrapped_->SetMediaTransportStateCallback(callback); + } + + RTCError SendData(int channel_id, + const SendDataParams& params, + const rtc::CopyOnWriteBuffer& buffer) override { + return wrapped_->SendData(channel_id, params, buffer); + } + + RTCError CloseChannel(int channel_id) override { + return wrapped_->CloseChannel(channel_id); + } + + void SetDataSink(DataChannelSink* sink) override { + wrapped_->SetDataSink(sink); + } + + private: + MediaTransportInterface* wrapped_; +}; + +class WrapperMediaTransportFactory : public MediaTransportFactory { + public: + explicit WrapperMediaTransportFactory(MediaTransportInterface* wrapped) + : wrapped_(wrapped) {} + + RTCErrorOr> CreateMediaTransport( + rtc::PacketTransportInternal* packet_transport, + rtc::Thread* network_thread, + const MediaTransportSettings& settings) override { + return {absl::make_unique(wrapped_)}; + } + + private: + MediaTransportInterface* wrapped_; +}; + // Contains two MediaTransportsInterfaces that are connected to each other. // Currently supports audio only. class MediaTransportPair { @@ -31,6 +107,19 @@ class MediaTransportPair { MediaTransportInterface* first() { return &first_; } MediaTransportInterface* second() { return &second_; } + std::unique_ptr first_factory() { + return absl::make_unique(&first_); + } + + std::unique_ptr second_factory() { + return absl::make_unique(&second_); + } + + void SetState(MediaTransportState state) { + first_.SetState(state); + second_.SetState(state); + } + void FlushAsyncInvokes() { first_.FlushAsyncInvokes(); second_.FlushAsyncInvokes(); @@ -81,7 +170,14 @@ class MediaTransportPair { webrtc::TargetTransferRateObserver* observer) override {} void SetMediaTransportStateCallback( - MediaTransportStateCallback* callback) override {} + MediaTransportStateCallback* callback) override { + rtc::CritScope lock(&sink_lock_); + state_callback_ = callback; + invoker_.AsyncInvoke(RTC_FROM_HERE, thread_, [this] { + RTC_DCHECK_RUN_ON(thread_); + OnStateChanged(); + }); + } RTCError SendData(int channel_id, const SendDataParams& params, @@ -109,6 +205,14 @@ class MediaTransportPair { data_sink_ = sink; } + void SetState(MediaTransportState state) { + invoker_.AsyncInvoke(RTC_FROM_HERE, thread_, [this, state] { + RTC_DCHECK_RUN_ON(thread_); + state_ = state; + OnStateChanged(); + }); + } + void FlushAsyncInvokes() { invoker_.Flush(thread_); } private: @@ -136,12 +240,25 @@ class MediaTransportPair { } } + void OnStateChanged() RTC_RUN_ON(thread_) { + rtc::CritScope lock(&sink_lock_); + if (state_callback_) { + state_callback_->OnStateChanged(state_); + } + } + rtc::Thread* const thread_; rtc::CriticalSection sink_lock_; MediaTransportAudioSinkInterface* sink_ RTC_GUARDED_BY(sink_lock_) = nullptr; DataChannelSink* data_sink_ RTC_GUARDED_BY(sink_lock_) = nullptr; + MediaTransportStateCallback* state_callback_ RTC_GUARDED_BY(sink_lock_) = + nullptr; + + MediaTransportState state_ RTC_GUARDED_BY(thread_) = + MediaTransportState::kPending; + LoopbackMediaTransport* const other_; rtc::AsyncInvoker invoker_; diff --git a/api/test/loopback_media_transport_unittest.cc b/api/test/loopback_media_transport_unittest.cc index f85413c55b..ba741a05ca 100644 --- a/api/test/loopback_media_transport_unittest.cc +++ b/api/test/loopback_media_transport_unittest.cc @@ -32,6 +32,11 @@ class MockDataChannelSink : public DataChannelSink { MOCK_METHOD1(OnChannelClosed, void(int)); }; +class MockStateCallback : public MediaTransportStateCallback { + public: + MOCK_METHOD1(OnStateChanged, void(MediaTransportState)); +}; + // Test only uses the sequence number. MediaTransportEncodedAudioFrame CreateAudioFrame(int sequence_number) { static constexpr int kSamplingRateHz = 48000; @@ -122,4 +127,45 @@ TEST(LoopbackMediaTransport, CloseDeliveredToSink) { transport_pair.second()->SetDataSink(nullptr); } +TEST(LoopbackMediaTransport, InitialStateDeliveredWhenCallbackSet) { + std::unique_ptr thread = rtc::Thread::Create(); + thread->Start(); + MediaTransportPair transport_pair(thread.get()); + + MockStateCallback state_callback; + + EXPECT_CALL(state_callback, OnStateChanged(MediaTransportState::kPending)); + transport_pair.first()->SetMediaTransportStateCallback(&state_callback); + transport_pair.FlushAsyncInvokes(); +} + +TEST(LoopbackMediaTransport, ChangedStateDeliveredWhenCallbackSet) { + std::unique_ptr thread = rtc::Thread::Create(); + thread->Start(); + MediaTransportPair transport_pair(thread.get()); + + transport_pair.SetState(MediaTransportState::kWritable); + transport_pair.FlushAsyncInvokes(); + + MockStateCallback state_callback; + + EXPECT_CALL(state_callback, OnStateChanged(MediaTransportState::kWritable)); + transport_pair.first()->SetMediaTransportStateCallback(&state_callback); + transport_pair.FlushAsyncInvokes(); +} + +TEST(LoopbackMediaTransport, StateChangeDeliveredToCallback) { + std::unique_ptr thread = rtc::Thread::Create(); + thread->Start(); + MediaTransportPair transport_pair(thread.get()); + + MockStateCallback state_callback; + + EXPECT_CALL(state_callback, OnStateChanged(MediaTransportState::kPending)); + EXPECT_CALL(state_callback, OnStateChanged(MediaTransportState::kWritable)); + transport_pair.first()->SetMediaTransportStateCallback(&state_callback); + transport_pair.SetState(MediaTransportState::kWritable); + transport_pair.FlushAsyncInvokes(); +} + } // namespace webrtc diff --git a/media/base/mediaengine.h b/media/base/mediaengine.h index e752da8988..62f43f9949 100644 --- a/media/base/mediaengine.h +++ b/media/base/mediaengine.h @@ -159,7 +159,12 @@ class CompositeMediaEngine : public MediaEngineInterface { std::pair engines_; }; -enum DataChannelType { DCT_NONE = 0, DCT_RTP = 1, DCT_SCTP = 2 }; +enum DataChannelType { + DCT_NONE = 0, + DCT_RTP = 1, + DCT_SCTP = 2, + DCT_MEDIA_TRANSPORT = 3 +}; class DataEngineInterface { public: diff --git a/pc/BUILD.gn b/pc/BUILD.gn index 182f07764c..97045ba66d 100644 --- a/pc/BUILD.gn +++ b/pc/BUILD.gn @@ -485,6 +485,7 @@ if (rtc_include_tests) { "../api:fake_frame_decryptor", "../api:fake_frame_encryptor", "../api:libjingle_peerconnection_api", + "../api:loopback_media_transport", "../api:mock_rtp", "../api/units:time_delta", "../logging:fake_rtc_event_log", diff --git a/pc/datachannel.cc b/pc/datachannel.cc index f819d26695..e989586607 100644 --- a/pc/datachannel.cc +++ b/pc/datachannel.cc @@ -118,6 +118,10 @@ rtc::scoped_refptr DataChannel::Create( return channel; } +bool DataChannel::IsSctpLike(cricket::DataChannelType type) { + return type == cricket::DCT_SCTP || type == cricket::DCT_MEDIA_TRANSPORT; +} + DataChannel::DataChannel(DataChannelProviderInterface* provider, cricket::DataChannelType dct, const std::string& label) @@ -147,7 +151,7 @@ bool DataChannel::Init(const InternalDataChannelInit& config) { return false; } handshake_state_ = kHandshakeReady; - } else if (data_channel_type_ == cricket::DCT_SCTP) { + } else if (IsSctpLike(data_channel_type_)) { if (config.id < -1 || config.maxRetransmits < -1 || config.maxRetransmitTime < -1) { RTC_LOG(LS_ERROR) << "Failed to initialize the SCTP data channel due to " @@ -241,7 +245,7 @@ bool DataChannel::Send(const DataBuffer& buffer) { if (!queued_send_data_.Empty()) { // Only SCTP DataChannel queues the outgoing data when the transport is // blocked. - RTC_DCHECK(data_channel_type_ == cricket::DCT_SCTP); + RTC_DCHECK(IsSctpLike(data_channel_type_)); if (!QueueSendDataMessage(buffer)) { RTC_LOG(LS_ERROR) << "Closing the DataChannel due to a failure to queue " "additional data."; @@ -273,7 +277,7 @@ void DataChannel::SetReceiveSsrc(uint32_t receive_ssrc) { void DataChannel::SetSctpSid(int sid) { RTC_DCHECK_LT(config_.id, 0); RTC_DCHECK_GE(sid, 0); - RTC_DCHECK_EQ(data_channel_type_, cricket::DCT_SCTP); + RTC_DCHECK(IsSctpLike(data_channel_type_)); if (config_.id == sid) { return; } @@ -283,7 +287,7 @@ void DataChannel::SetSctpSid(int sid) { } void DataChannel::OnClosingProcedureStartedRemotely(int sid) { - if (data_channel_type_ == cricket::DCT_SCTP && sid == config_.id && + if (IsSctpLike(data_channel_type_) && sid == config_.id && state_ != kClosing && state_ != kClosed) { // Don't bother sending queued data since the side that initiated the // closure wouldn't receive it anyway. See crbug.com/559394 for a lengthy @@ -299,7 +303,7 @@ void DataChannel::OnClosingProcedureStartedRemotely(int sid) { } void DataChannel::OnClosingProcedureComplete(int sid) { - if (data_channel_type_ == cricket::DCT_SCTP && sid == config_.id) { + if (IsSctpLike(data_channel_type_) && sid == config_.id) { // If the closing procedure is complete, we should have finished sending // all pending data and transitioned to kClosing already. RTC_DCHECK_EQ(state_, kClosing); @@ -310,7 +314,7 @@ void DataChannel::OnClosingProcedureComplete(int sid) { } void DataChannel::OnTransportChannelCreated() { - RTC_DCHECK(data_channel_type_ == cricket::DCT_SCTP); + RTC_DCHECK(IsSctpLike(data_channel_type_)); if (!connected_to_provider_) { connected_to_provider_ = provider_->ConnectDataChannel(this); } @@ -348,12 +352,12 @@ void DataChannel::OnDataReceived(const cricket::ReceiveDataParams& params, if (data_channel_type_ == cricket::DCT_RTP && params.ssrc != receive_ssrc_) { return; } - if (data_channel_type_ == cricket::DCT_SCTP && params.sid != config_.id) { + if (IsSctpLike(data_channel_type_) && params.sid != config_.id) { return; } if (params.type == cricket::DMT_CONTROL) { - RTC_DCHECK(data_channel_type_ == cricket::DCT_SCTP); + RTC_DCHECK(IsSctpLike(data_channel_type_)); if (handshake_state_ != kHandshakeWaitingForAck) { // Ignore it if we are not expecting an ACK message. RTC_LOG(LS_WARNING) @@ -570,7 +574,7 @@ bool DataChannel::SendDataMessage(const DataBuffer& buffer, bool queue_if_blocked) { cricket::SendDataParams send_params; - if (data_channel_type_ == cricket::DCT_SCTP) { + if (IsSctpLike(data_channel_type_)) { send_params.ordered = config_.ordered; // Send as ordered if it is still going through OPEN/ACK signaling. if (handshake_state_ != kHandshakeReady && !config_.ordered) { @@ -597,7 +601,7 @@ bool DataChannel::SendDataMessage(const DataBuffer& buffer, return true; } - if (data_channel_type_ != cricket::DCT_SCTP) { + if (!IsSctpLike(data_channel_type_)) { return false; } @@ -649,7 +653,7 @@ void DataChannel::QueueControlMessage(const rtc::CopyOnWriteBuffer& buffer) { bool DataChannel::SendControlMessage(const rtc::CopyOnWriteBuffer& buffer) { bool is_open_message = handshake_state_ == kHandshakeShouldSendOpen; - RTC_DCHECK_EQ(data_channel_type_, cricket::DCT_SCTP); + RTC_DCHECK(IsSctpLike(data_channel_type_)); RTC_DCHECK(writable_); RTC_DCHECK_GE(config_.id, 0); RTC_DCHECK(!is_open_message || !config_.negotiated); diff --git a/pc/datachannel.h b/pc/datachannel.h index cbb0c8b2cf..22ea354c21 100644 --- a/pc/datachannel.h +++ b/pc/datachannel.h @@ -122,6 +122,8 @@ class DataChannel : public DataChannelInterface, public sigslot::has_slots<> { const std::string& label, const InternalDataChannelInit& config); + static bool IsSctpLike(cricket::DataChannelType type); + virtual void RegisterObserver(DataChannelObserver* observer); virtual void UnregisterObserver(); diff --git a/pc/jseptransportcontroller.cc b/pc/jseptransportcontroller.cc index 1613e1e791..cb99686a98 100644 --- a/pc/jseptransportcontroller.cc +++ b/pc/jseptransportcontroller.cc @@ -143,6 +143,15 @@ MediaTransportInterface* JsepTransportController::GetMediaTransport( return jsep_transport->media_transport(); } +MediaTransportState JsepTransportController::GetMediaTransportState( + const std::string& mid) const { + auto jsep_transport = GetJsepTransportForMid(mid); + if (!jsep_transport) { + return MediaTransportState::kPending; + } + return jsep_transport->media_transport_state(); +} + cricket::DtlsTransportInternal* JsepTransportController::GetDtlsTransport( const std::string& mid) const { auto jsep_transport = GetJsepTransportForMid(mid); @@ -1042,7 +1051,7 @@ RTCError JsepTransportController::MaybeCreateJsepTransport( jsep_transport->SignalRtcpMuxActive.connect( this, &JsepTransportController::UpdateAggregateStates_n); jsep_transport->SignalMediaTransportStateChanged.connect( - this, &JsepTransportController::UpdateAggregateStates_n); + this, &JsepTransportController::OnMediaTransportStateChanged_n); SetTransportForMid(content_info.name, jsep_transport.get()); jsep_transports_by_name_[content_info.name] = std::move(jsep_transport); @@ -1224,6 +1233,11 @@ void JsepTransportController::OnTransportStateChanged_n( UpdateAggregateStates_n(); } +void JsepTransportController::OnMediaTransportStateChanged_n() { + SignalMediaTransportStateChanged(); + UpdateAggregateStates_n(); +} + void JsepTransportController::UpdateAggregateStates_n() { RTC_DCHECK(network_thread_->IsCurrent()); diff --git a/pc/jseptransportcontroller.h b/pc/jseptransportcontroller.h index 5747990cd6..8d89795ce3 100644 --- a/pc/jseptransportcontroller.h +++ b/pc/jseptransportcontroller.h @@ -117,6 +117,7 @@ class JsepTransportController : public sigslot::has_slots<> { const std::string& mid) const; MediaTransportInterface* GetMediaTransport(const std::string& mid) const; + MediaTransportState GetMediaTransportState(const std::string& mid) const; /********************* * ICE-related methods @@ -200,6 +201,8 @@ class JsepTransportController : public sigslot::has_slots<> { sigslot::signal1 SignalDtlsHandshakeError; + sigslot::signal<> SignalMediaTransportStateChanged; + private: RTCError ApplyDescription_n(bool local, SdpType type, @@ -311,6 +314,7 @@ class JsepTransportController : public sigslot::has_slots<> { const cricket::Candidates& candidates); void OnTransportRoleConflict_n(cricket::IceTransportInternal* transport); void OnTransportStateChanged_n(cricket::IceTransportInternal* transport); + void OnMediaTransportStateChanged_n(); void UpdateAggregateStates_n(); diff --git a/pc/peerconnection.cc b/pc/peerconnection.cc index 3c5a025f22..54ef67740e 100644 --- a/pc/peerconnection.cc +++ b/pc/peerconnection.cc @@ -631,6 +631,35 @@ absl::optional RTCConfigurationToIceConfigOptionalInt( return rtc_configuration_parameter; } +cricket::DataMessageType ToCricketDataMessageType(DataMessageType type) { + switch (type) { + case DataMessageType::kText: + return cricket::DMT_TEXT; + case DataMessageType::kBinary: + return cricket::DMT_BINARY; + case DataMessageType::kControl: + return cricket::DMT_CONTROL; + default: + return cricket::DMT_NONE; + } + return cricket::DMT_NONE; +} + +DataMessageType ToWebrtcDataMessageType(cricket::DataMessageType type) { + switch (type) { + case cricket::DMT_TEXT: + return DataMessageType::kText; + case cricket::DMT_BINARY: + return DataMessageType::kBinary; + case cricket::DMT_CONTROL: + return DataMessageType::kControl; + case cricket::DMT_NONE: + default: + RTC_NOTREACHED(); + } + return DataMessageType::kControl; +} + } // namespace // Upon completion, posts a task to execute the callback of the @@ -828,6 +857,7 @@ PeerConnection::~PeerConnection() { webrtc_session_desc_factory_.reset(); sctp_invoker_.reset(); sctp_factory_.reset(); + media_transport_invoker_.reset(); transport_controller_.reset(); // port_allocator_ lives on the network thread and should be destroyed there. @@ -1009,10 +1039,18 @@ bool PeerConnection::Initialize( } } - // Enable creation of RTP data channels if the kEnableRtpDataChannels is set. - // It takes precendence over the disable_sctp_data_channels - // PeerConnectionFactoryInterface::Options. - if (configuration.enable_rtp_data_channel) { + if (configuration.use_media_transport_for_data_channels) { + if (configuration.enable_rtp_data_channel) { + RTC_LOG(LS_ERROR) << "enable_rtp_data_channel and " + "use_media_transport_for_data_channels are " + "incompatible and cannot both be set to true"; + return false; + } + data_channel_type_ = cricket::DCT_MEDIA_TRANSPORT; + } else if (configuration.enable_rtp_data_channel) { + // Enable creation of RTP data channels if the kEnableRtpDataChannels is + // set. It takes precendence over the disable_sctp_data_channels + // PeerConnectionFactoryInterface::Options. data_channel_type_ = cricket::DCT_RTP; } else { // DTLS has to be enabled to use SCTP. @@ -2035,6 +2073,16 @@ RTCError PeerConnection::ApplyLocalDescription( // |local_description()|. RTC_DCHECK(local_description()); + if (!is_caller_) { + if (remote_description()) { + // Remote description was applied first, so this PC is the callee. + is_caller_ = false; + } else { + // Local description is applied first, so this PC is the caller. + is_caller_ = true; + } + } + RTCError error = PushdownTransportDescription(cricket::CS_LOCAL, type); if (!error.ok()) { return error; @@ -2117,7 +2165,7 @@ RTCError PeerConnection::ApplyLocalDescription( // If setting the description decided our SSL role, allocate any necessary // SCTP sids. rtc::SSLRole role; - if (data_channel_type() == cricket::DCT_SCTP && GetSctpSslRole(&role)) { + if (DataChannel::IsSctpLike(data_channel_type_) && GetSctpSslRole(&role)) { AllocateSctpSids(role); } @@ -2392,7 +2440,7 @@ RTCError PeerConnection::ApplyRemoteDescription( // If setting the description decided our SSL role, allocate any necessary // SCTP sids. rtc::SSLRole role; - if (data_channel_type() == cricket::DCT_SCTP && GetSctpSslRole(&role)) { + if (DataChannel::IsSctpLike(data_channel_type_) && GetSctpSslRole(&role)) { AllocateSctpSids(role); } @@ -2723,7 +2771,7 @@ RTCError PeerConnection::UpdateDataChannel( if (content.rejected) { DestroyDataChannel(); } else { - if (!rtp_data_channel_ && !sctp_transport_) { + if (!rtp_data_channel_ && !sctp_transport_ && !media_transport_) { if (!CreateDataChannel(content.name)) { LOG_AND_RETURN_ERROR(RTCErrorType::INTERNAL_ERROR, "Failed to create data channel."); @@ -4191,6 +4239,8 @@ absl::optional PeerConnection::GetDataMid() const { return rtp_data_channel_->content_name(); case cricket::DCT_SCTP: return sctp_mid_; + case cricket::DCT_MEDIA_TRANSPORT: + return media_transport_data_mid_; default: return absl::nullopt; } @@ -4553,7 +4603,7 @@ rtc::scoped_refptr PeerConnection::InternalCreateDataChannel( } InternalDataChannelInit new_config = config ? (*config) : InternalDataChannelInit(); - if (data_channel_type() == cricket::DCT_SCTP) { + if (DataChannel::IsSctpLike(data_channel_type_)) { if (new_config.id < 0) { rtc::SSLRole role; if ((GetSctpSslRole(&role)) && @@ -4584,7 +4634,7 @@ rtc::scoped_refptr PeerConnection::InternalCreateDataChannel( } rtp_data_channels_[channel->label()] = channel; } else { - RTC_DCHECK(channel->data_channel_type() == cricket::DCT_SCTP); + RTC_DCHECK(DataChannel::IsSctpLike(data_channel_type_)); sctp_data_channels_.push_back(channel); channel->SignalClosed.connect(this, &PeerConnection::OnSctpDataChannelClosed); @@ -4664,6 +4714,27 @@ void PeerConnection::OnDataChannelOpenMessage( NoteUsageEvent(UsageEvent::DATA_ADDED); } +bool PeerConnection::HandleOpenMessage_s( + const cricket::ReceiveDataParams& params, + const rtc::CopyOnWriteBuffer& buffer) { + if (params.type == cricket::DMT_CONTROL && IsOpenMessage(buffer)) { + // Received OPEN message; parse and signal that a new data channel should + // be created. + std::string label; + InternalDataChannelInit config; + config.id = params.ssrc; + if (!ParseDataChannelOpenMessage(buffer, &label, &config)) { + RTC_LOG(LS_WARNING) << "Failed to parse the OPEN message for ssrc " + << params.ssrc; + return true; + } + config.open_handshake_role = InternalDataChannelInit::kAcker; + OnDataChannelOpenMessage(label, config); + return true; + } + return false; +} + rtc::scoped_refptr> PeerConnection::GetAudioTransceiver() const { // This method only works with Plan B SDP, where there is a single @@ -4907,19 +4978,25 @@ cricket::BaseChannel* PeerConnection::GetChannel( } bool PeerConnection::GetSctpSslRole(rtc::SSLRole* role) { + RTC_DCHECK_RUN_ON(signaling_thread()); if (!local_description() || !remote_description()) { RTC_LOG(LS_INFO) << "Local and Remote descriptions must be applied to get the " "SSL Role of the SCTP transport."; return false; } - if (!sctp_transport_) { + if (!sctp_transport_ && !media_transport_) { RTC_LOG(LS_INFO) << "Non-rejected SCTP m= section is needed to get the " "SSL Role of the SCTP transport."; return false; } - auto dtls_role = transport_controller_->GetDtlsRole(*sctp_mid_); + absl::optional dtls_role; + if (sctp_mid_) { + dtls_role = transport_controller_->GetDtlsRole(*sctp_mid_); + } else if (is_caller_) { + dtls_role = *is_caller_ ? rtc::SSL_SERVER : rtc::SSL_CLIENT; + } if (dtls_role) { *role = *dtls_role; return true; @@ -5165,11 +5242,22 @@ bool PeerConnection::GetRemoteTrackIdBySsrc(uint32_t ssrc, bool PeerConnection::SendData(const cricket::SendDataParams& params, const rtc::CopyOnWriteBuffer& payload, cricket::SendDataResult* result) { - if (!rtp_data_channel_ && !sctp_transport_) { - RTC_LOG(LS_ERROR) << "SendData called when rtp_data_channel_ " - "and sctp_transport_ are NULL."; + if (!rtp_data_channel_ && !sctp_transport_ && !media_transport_) { + RTC_LOG(LS_ERROR) << "SendData called when rtp_data_channel_, " + "sctp_transport_, and media_transport_ are NULL."; return false; } + if (media_transport_) { + SendDataParams send_params; + send_params.type = ToWebrtcDataMessageType(params.type); + send_params.ordered = params.ordered; + if (params.max_rtx_count >= 0) { + send_params.max_rtx_count = params.max_rtx_count; + } else if (params.max_rtx_ms >= 0) { + send_params.max_rtx_ms = params.max_rtx_ms; + } + return media_transport_->SendData(params.sid, send_params, payload).ok(); + } return rtp_data_channel_ ? rtp_data_channel_->SendData(params, payload, result) : network_thread()->Invoke( @@ -5179,13 +5267,23 @@ bool PeerConnection::SendData(const cricket::SendDataParams& params, } bool PeerConnection::ConnectDataChannel(DataChannel* webrtc_data_channel) { - if (!rtp_data_channel_ && !sctp_transport_) { + RTC_DCHECK_RUN_ON(signaling_thread()); + if (!rtp_data_channel_ && !sctp_transport_ && !media_transport_) { // Don't log an error here, because DataChannels are expected to call // ConnectDataChannel in this state. It's the only way to initially tell // whether or not the underlying transport is ready. return false; } - if (rtp_data_channel_) { + if (media_transport_) { + SignalMediaTransportWritable_s.connect(webrtc_data_channel, + &DataChannel::OnChannelReady); + SignalMediaTransportReceivedData_s.connect(webrtc_data_channel, + &DataChannel::OnDataReceived); + SignalMediaTransportChannelClosing_s.connect( + webrtc_data_channel, &DataChannel::OnClosingProcedureStartedRemotely); + SignalMediaTransportChannelClosed_s.connect( + webrtc_data_channel, &DataChannel::OnClosingProcedureComplete); + } else if (rtp_data_channel_) { rtp_data_channel_->SignalReadyToSendData.connect( webrtc_data_channel, &DataChannel::OnChannelReady); rtp_data_channel_->SignalDataReceived.connect(webrtc_data_channel, @@ -5204,13 +5302,19 @@ bool PeerConnection::ConnectDataChannel(DataChannel* webrtc_data_channel) { } void PeerConnection::DisconnectDataChannel(DataChannel* webrtc_data_channel) { - if (!rtp_data_channel_ && !sctp_transport_) { + RTC_DCHECK_RUN_ON(signaling_thread()); + if (!rtp_data_channel_ && !sctp_transport_ && !media_transport_) { RTC_LOG(LS_ERROR) << "DisconnectDataChannel called when rtp_data_channel_ and " "sctp_transport_ are NULL."; return; } - if (rtp_data_channel_) { + if (media_transport_) { + SignalMediaTransportWritable_s.disconnect(webrtc_data_channel); + SignalMediaTransportReceivedData_s.disconnect(webrtc_data_channel); + SignalMediaTransportChannelClosing_s.disconnect(webrtc_data_channel); + SignalMediaTransportChannelClosed_s.disconnect(webrtc_data_channel); + } else if (rtp_data_channel_) { rtp_data_channel_->SignalReadyToSendData.disconnect(webrtc_data_channel); rtp_data_channel_->SignalDataReceived.disconnect(webrtc_data_channel); } else { @@ -5222,6 +5326,10 @@ void PeerConnection::DisconnectDataChannel(DataChannel* webrtc_data_channel) { } void PeerConnection::AddSctpDataStream(int sid) { + if (media_transport_) { + // No-op. Media transport does not need to add streams. + return; + } if (!sctp_transport_) { RTC_LOG(LS_ERROR) << "AddSctpDataStream called when sctp_transport_ is NULL."; @@ -5233,6 +5341,10 @@ void PeerConnection::AddSctpDataStream(int sid) { } void PeerConnection::RemoveSctpDataStream(int sid) { + if (media_transport_) { + media_transport_->CloseChannel(sid); + return; + } if (!sctp_transport_) { RTC_LOG(LS_ERROR) << "RemoveSctpDataStream called when sctp_transport_ is " "NULL."; @@ -5244,10 +5356,43 @@ void PeerConnection::RemoveSctpDataStream(int sid) { } bool PeerConnection::ReadyToSendData() const { + RTC_DCHECK_RUN_ON(signaling_thread()); return (rtp_data_channel_ && rtp_data_channel_->ready_to_send_data()) || + (media_transport_ && media_transport_ready_to_send_data_) || sctp_ready_to_send_data_; } +void PeerConnection::OnDataReceived(int channel_id, + DataMessageType type, + const rtc::CopyOnWriteBuffer& buffer) { + cricket::ReceiveDataParams params; + params.sid = channel_id; + params.type = ToCricketDataMessageType(type); + media_transport_invoker_->AsyncInvoke( + RTC_FROM_HERE, signaling_thread(), [this, params, buffer] { + RTC_DCHECK_RUN_ON(signaling_thread()); + if (!HandleOpenMessage_s(params, buffer)) { + SignalMediaTransportReceivedData_s(params, buffer); + } + }); +} + +void PeerConnection::OnChannelClosing(int channel_id) { + media_transport_invoker_->AsyncInvoke( + RTC_FROM_HERE, signaling_thread(), [this, channel_id] { + RTC_DCHECK_RUN_ON(signaling_thread()); + SignalMediaTransportChannelClosing_s(channel_id); + }); +} + +void PeerConnection::OnChannelClosed(int channel_id) { + media_transport_invoker_->AsyncInvoke( + RTC_FROM_HERE, signaling_thread(), [this, channel_id] { + RTC_DCHECK_RUN_ON(signaling_thread()); + SignalMediaTransportChannelClosed_s(channel_id); + }); +} + absl::optional PeerConnection::sctp_transport_name() const { if (sctp_mid_ && transport_controller_) { auto dtls_transport = transport_controller_->GetDtlsTransport(*sctp_mid_); @@ -5608,7 +5753,7 @@ RTCError PeerConnection::CreateChannels(const SessionDescription& desc) { const cricket::ContentInfo* data = cricket::GetFirstDataContent(&desc); if (data_channel_type_ != cricket::DCT_NONE && data && !data->rejected && - !rtp_data_channel_ && !sctp_transport_) { + !rtp_data_channel_ && !sctp_transport_ && !media_transport_) { if (!CreateDataChannel(data->name)) { LOG_AND_RETURN_ERROR(RTCErrorType::INTERNAL_ERROR, "Failed to create data channel."); @@ -5666,35 +5811,49 @@ cricket::VideoChannel* PeerConnection::CreateVideoChannel( } bool PeerConnection::CreateDataChannel(const std::string& mid) { - bool sctp = (data_channel_type_ == cricket::DCT_SCTP); - if (sctp) { - if (!sctp_factory_) { - RTC_LOG(LS_ERROR) - << "Trying to create SCTP transport, but didn't compile with " - "SCTP support (HAVE_SCTP)"; + switch (data_channel_type_) { + case cricket::DCT_MEDIA_TRANSPORT: + if (network_thread()->Invoke( + RTC_FROM_HERE, + rtc::Bind(&PeerConnection::SetupMediaTransportForDataChannels_n, + this, mid))) { + for (const auto& channel : sctp_data_channels_) { + channel->OnTransportChannelCreated(); + } + return true; + } return false; - } - if (!network_thread()->Invoke( - RTC_FROM_HERE, - rtc::Bind(&PeerConnection::CreateSctpTransport_n, this, mid))) { - return false; - } - for (const auto& channel : sctp_data_channels_) { - channel->OnTransportChannelCreated(); - } - } else { - RtpTransportInternal* rtp_transport = GetRtpTransport(mid); - rtp_data_channel_ = channel_manager()->CreateRtpDataChannel( - configuration_.media_config, rtp_transport, signaling_thread(), mid, - SrtpRequired(), GetCryptoOptions()); - if (!rtp_data_channel_) { - return false; - } - rtp_data_channel_->SignalDtlsSrtpSetupFailure.connect( - this, &PeerConnection::OnDtlsSrtpSetupFailure); - rtp_data_channel_->SignalSentPacket.connect( - this, &PeerConnection::OnSentPacket_w); - rtp_data_channel_->SetRtpTransport(rtp_transport); + case cricket::DCT_SCTP: + if (!sctp_factory_) { + RTC_LOG(LS_ERROR) + << "Trying to create SCTP transport, but didn't compile with " + "SCTP support (HAVE_SCTP)"; + return false; + } + if (!network_thread()->Invoke( + RTC_FROM_HERE, + rtc::Bind(&PeerConnection::CreateSctpTransport_n, this, mid))) { + return false; + } + for (const auto& channel : sctp_data_channels_) { + channel->OnTransportChannelCreated(); + } + return true; + case cricket::DCT_RTP: + default: + RtpTransportInternal* rtp_transport = GetRtpTransport(mid); + rtp_data_channel_ = channel_manager()->CreateRtpDataChannel( + configuration_.media_config, rtp_transport, signaling_thread(), mid, + SrtpRequired(), GetCryptoOptions()); + if (!rtp_data_channel_) { + return false; + } + rtp_data_channel_->SignalDtlsSrtpSetupFailure.connect( + this, &PeerConnection::OnDtlsSrtpSetupFailure); + rtp_data_channel_->SignalSentPacket.connect( + this, &PeerConnection::OnSentPacket_w); + rtp_data_channel_->SetRtpTransport(rtp_transport); + return true; } return true; @@ -5784,22 +5943,8 @@ void PeerConnection::OnSctpTransportDataReceived_n( void PeerConnection::OnSctpTransportDataReceived_s( const cricket::ReceiveDataParams& params, const rtc::CopyOnWriteBuffer& payload) { - RTC_DCHECK(signaling_thread()->IsCurrent()); - if (params.type == cricket::DMT_CONTROL && IsOpenMessage(payload)) { - // Received OPEN message; parse and signal that a new data channel should - // be created. - std::string label; - InternalDataChannelInit config; - config.id = params.ssrc; - if (!ParseDataChannelOpenMessage(payload, &label, &config)) { - RTC_LOG(LS_WARNING) << "Failed to parse the OPEN message for sid " - << params.ssrc; - return; - } - config.open_handshake_role = InternalDataChannelInit::kAcker; - OnDataChannelOpenMessage(label, config); - } else { - // Otherwise just forward the signal. + RTC_DCHECK_RUN_ON(signaling_thread()); + if (!HandleOpenMessage_s(params, payload)) { SignalSctpDataReceived(params, payload); } } @@ -5822,6 +5967,49 @@ void PeerConnection::OnSctpClosingProcedureComplete_n(int sid) { &SignalSctpClosingProcedureComplete, sid)); } +bool PeerConnection::SetupMediaTransportForDataChannels_n( + const std::string& mid) { + media_transport_ = transport_controller_->GetMediaTransport(mid); + if (!media_transport_) { + RTC_LOG(LS_ERROR) << "Media transport is not available for data channels"; + return false; + } + + media_transport_invoker_ = absl::make_unique(); + media_transport_->SetDataSink(this); + media_transport_data_mid_ = mid; + transport_controller_->SignalMediaTransportStateChanged.connect( + this, &PeerConnection::OnMediaTransportStateChanged_n); + // Check the initial state right away, in case transport is already writable. + OnMediaTransportStateChanged_n(); + return true; +} + +void PeerConnection::TeardownMediaTransportForDataChannels_n() { + if (!media_transport_) { + return; + } + transport_controller_->SignalMediaTransportStateChanged.disconnect(this); + media_transport_data_mid_.reset(); + media_transport_->SetDataSink(nullptr); + media_transport_invoker_ = nullptr; + media_transport_ = nullptr; +} + +void PeerConnection::OnMediaTransportStateChanged_n() { + if (!media_transport_data_mid_ || + transport_controller_->GetMediaTransportState( + *media_transport_data_mid_) != MediaTransportState::kWritable) { + return; + } + media_transport_invoker_->AsyncInvoke( + RTC_FROM_HERE, signaling_thread(), [this] { + RTC_DCHECK_RUN_ON(signaling_thread()); + media_transport_ready_to_send_data_ = true; + SignalMediaTransportWritable_s(media_transport_ready_to_send_data_); + }); +} + // Returns false if bundle is enabled and rtcp_mux is disabled. bool PeerConnection::ValidateBundleSettings(const SessionDescription* desc) { bool bundle_enabled = desc->HasGroup(cricket::GROUP_TYPE_BUNDLE); @@ -6336,6 +6524,14 @@ void PeerConnection::DestroyDataChannel() { network_thread()->Invoke(RTC_FROM_HERE, [this] { DestroySctpTransport_n(); }); } + + if (media_transport_) { + OnDataChannelDestroyed(); + network_thread()->Invoke(RTC_FROM_HERE, [this] { + RTC_DCHECK_RUN_ON(network_thread()); + TeardownMediaTransportForDataChannels_n(); + }); + } } void PeerConnection::DestroyBaseChannel(cricket::BaseChannel* channel) { diff --git a/pc/peerconnection.h b/pc/peerconnection.h index 4099992a0e..fe4d777401 100644 --- a/pc/peerconnection.h +++ b/pc/peerconnection.h @@ -52,6 +52,7 @@ class RtcEventLog; // - Generating stats. class PeerConnection : public PeerConnectionInternal, public DataChannelProviderInterface, + public DataChannelSink, public JsepTransportController::Observer, public rtc::MessageHandler, public sigslot::has_slots<> { @@ -632,6 +633,11 @@ class PeerConnection : public PeerConnectionInternal, // Called when a valid data channel OPEN message is received. void OnDataChannelOpenMessage(const std::string& label, const InternalDataChannelInit& config); + // Parses and handles open messages. Returns true if the message is an open + // message, false otherwise. + bool HandleOpenMessage_s(const cricket::ReceiveDataParams& params, + const rtc::CopyOnWriteBuffer& buffer) + RTC_RUN_ON(signaling_thread()); // Returns true if the PeerConnection is configured to use Unified Plan // semantics for creating offers/answers and setting local/remote @@ -733,6 +739,13 @@ class PeerConnection : public PeerConnectionInternal, cricket::DataChannelType data_channel_type() const; + // Implements DataChannelSink. + void OnDataReceived(int channel_id, + DataMessageType type, + const rtc::CopyOnWriteBuffer& buffer) override; + void OnChannelClosing(int channel_id) override; + void OnChannelClosed(int channel_id) override; + // Called when an RTCCertificate is generated or retrieved by // WebRTCSessionDescriptionFactory. Should happen before setLocalDescription. void OnCertificateReady( @@ -830,6 +843,11 @@ class PeerConnection : public PeerConnectionInternal, void OnSctpClosingProcedureStartedRemotely_n(int sid); void OnSctpClosingProcedureComplete_n(int sid); + bool SetupMediaTransportForDataChannels_n(const std::string& mid) + RTC_RUN_ON(network_thread()); + void OnMediaTransportStateChanged_n() RTC_RUN_ON(network_thread()); + void TeardownMediaTransportForDataChannels_n() RTC_RUN_ON(network_thread()); + bool ValidateBundleSettings(const cricket::SessionDescription* desc); bool HasRtcpMuxEnabled(const cricket::ContentInfo* content); // Below methods are helper methods which verifies SDP. @@ -1050,6 +1068,33 @@ class PeerConnection : public PeerConnectionInternal, sigslot::signal1 SignalSctpClosingProcedureStartedRemotely; sigslot::signal1 SignalSctpClosingProcedureComplete; + // Whether this peer is the caller. Set when the local description is applied. + absl::optional is_caller_ RTC_GUARDED_BY(signaling_thread()); + + // Content name (MID) for media transport data channels in SDP. + absl::optional media_transport_data_mid_; + + // Media transport used for data channels. Thread-safe. + MediaTransportInterface* media_transport_ = nullptr; + + // Cached value of whether the media transport is ready to send. + bool media_transport_ready_to_send_data_ RTC_GUARDED_BY(signaling_thread()) = + false; + + // Used to invoke media transport signals on the signaling thread. + std::unique_ptr media_transport_invoker_; + + // Identical to the signals for SCTP, but from media transport: + sigslot::signal1 SignalMediaTransportWritable_s + RTC_GUARDED_BY(signaling_thread()); + sigslot::signal2 + SignalMediaTransportReceivedData_s RTC_GUARDED_BY(signaling_thread()); + sigslot::signal1 SignalMediaTransportChannelClosing_s + RTC_GUARDED_BY(signaling_thread()); + sigslot::signal1 SignalMediaTransportChannelClosed_s + RTC_GUARDED_BY(signaling_thread()); + std::unique_ptr current_local_description_; std::unique_ptr pending_local_description_; std::unique_ptr current_remote_description_; diff --git a/pc/peerconnection_datachannel_unittest.cc b/pc/peerconnection_datachannel_unittest.cc index d7f1fcce5a..cfb5dde22a 100644 --- a/pc/peerconnection_datachannel_unittest.cc +++ b/pc/peerconnection_datachannel_unittest.cc @@ -11,6 +11,7 @@ #include #include "api/peerconnectionproxy.h" +#include "api/test/fake_media_transport.h" #include "media/base/fakemediaengine.h" #include "pc/mediasession.h" #include "pc/peerconnection.h" @@ -31,17 +32,39 @@ using RTCConfiguration = PeerConnectionInterface::RTCConfiguration; using RTCOfferAnswerOptions = PeerConnectionInterface::RTCOfferAnswerOptions; using ::testing::Values; +namespace { + +PeerConnectionFactoryDependencies CreatePeerConnectionFactoryDependencies( + rtc::Thread* network_thread, + rtc::Thread* worker_thread, + rtc::Thread* signaling_thread, + std::unique_ptr media_engine, + std::unique_ptr call_factory, + std::unique_ptr media_transport_factory) { + PeerConnectionFactoryDependencies deps; + deps.network_thread = network_thread; + deps.worker_thread = worker_thread; + deps.signaling_thread = signaling_thread; + deps.media_engine = std::move(media_engine); + deps.call_factory = std::move(call_factory); + deps.media_transport_factory = std::move(media_transport_factory); + return deps; +} + +} // namespace + class PeerConnectionFactoryForDataChannelTest : public rtc::RefCountedObject { public: PeerConnectionFactoryForDataChannelTest() : rtc::RefCountedObject( - rtc::Thread::Current(), - rtc::Thread::Current(), - rtc::Thread::Current(), - absl::make_unique(), - CreateCallFactory(), - nullptr) {} + CreatePeerConnectionFactoryDependencies( + rtc::Thread::Current(), + rtc::Thread::Current(), + rtc::Thread::Current(), + absl::make_unique(), + CreateCallFactory(), + absl::make_unique())) {} std::unique_ptr CreateSctpTransportInternalFactory() { @@ -324,6 +347,52 @@ TEST_P(PeerConnectionDataChannelTest, SctpPortPropagatedFromSdpToTransport) { EXPECT_EQ(kNewRecvPort, callee_transport->local_port()); } +TEST_P(PeerConnectionDataChannelTest, + NoSctpTransportCreatedIfMediaTransportDataChannelsEnabled) { + RTCConfiguration config; + config.use_media_transport_for_data_channels = true; + config.enable_dtls_srtp = false; // SDES is required to use media transport. + auto caller = CreatePeerConnectionWithDataChannel(config); + + ASSERT_TRUE(caller->SetLocalDescription(caller->CreateOffer())); + EXPECT_FALSE(caller->sctp_transport_factory()->last_fake_sctp_transport()); +} + +TEST_P(PeerConnectionDataChannelTest, + MediaTransportDataChannelCreatedEvenIfSctpAvailable) { + RTCConfiguration config; + config.use_media_transport_for_data_channels = true; + config.enable_dtls_srtp = false; // SDES is required to use media transport. + PeerConnectionFactoryInterface::Options options; + options.disable_sctp_data_channels = false; + auto caller = CreatePeerConnectionWithDataChannel(config, options); + + ASSERT_TRUE(caller->SetLocalDescription(caller->CreateOffer())); + EXPECT_FALSE(caller->sctp_transport_factory()->last_fake_sctp_transport()); +} + +TEST_P(PeerConnectionDataChannelTest, + CannotEnableBothMediaTransportAndRtpDataChannels) { + RTCConfiguration config; + config.enable_rtp_data_channel = true; + config.use_media_transport_for_data_channels = true; + config.enable_dtls_srtp = false; // SDES is required to use media transport. + EXPECT_EQ(CreatePeerConnection(config), nullptr); +} + +TEST_P(PeerConnectionDataChannelTest, + MediaTransportDataChannelFailsWithoutSdes) { + RTCConfiguration config; + config.use_media_transport_for_data_channels = true; + config.enable_dtls_srtp = true; // Disables SDES for data sections. + auto caller = CreatePeerConnectionWithDataChannel(config); + + std::string error; + ASSERT_FALSE(caller->SetLocalDescription(caller->CreateOffer(), &error)); + EXPECT_EQ(error, + "Failed to set local offer sdp: Failed to create data channel."); +} + INSTANTIATE_TEST_CASE_P(PeerConnectionDataChannelTest, PeerConnectionDataChannelTest, Values(SdpSemantics::kPlanB, diff --git a/pc/peerconnection_integrationtest.cc b/pc/peerconnection_integrationtest.cc index 236a8bb97d..ccd7d05cb5 100644 --- a/pc/peerconnection_integrationtest.cc +++ b/pc/peerconnection_integrationtest.cc @@ -28,6 +28,7 @@ #include "api/peerconnectioninterface.h" #include "api/peerconnectionproxy.h" #include "api/rtpreceiverinterface.h" +#include "api/test/loopback_media_transport.h" #include "api/umametrics.h" #include "api/video_codecs/builtin_video_decoder_factory.h" #include "api/video_codecs/builtin_video_encoder_factory.h" @@ -251,7 +252,8 @@ class PeerConnectionWrapper : public webrtc::PeerConnectionObserver, webrtc::PeerConnectionDependencies dependencies(nullptr); dependencies.cert_generator = std::move(cert_generator); if (!client->Init(nullptr, nullptr, std::move(dependencies), network_thread, - worker_thread, nullptr)) { + worker_thread, nullptr, + /*media_transport_factory=*/nullptr)) { delete client; return nullptr; } @@ -588,12 +590,14 @@ class PeerConnectionWrapper : public webrtc::PeerConnectionObserver, explicit PeerConnectionWrapper(const std::string& debug_name) : debug_name_(debug_name) {} - bool Init(const PeerConnectionFactory::Options* options, - const PeerConnectionInterface::RTCConfiguration* config, - webrtc::PeerConnectionDependencies dependencies, - rtc::Thread* network_thread, - rtc::Thread* worker_thread, - std::unique_ptr event_log_factory) { + bool Init( + const PeerConnectionFactory::Options* options, + const PeerConnectionInterface::RTCConfiguration* config, + webrtc::PeerConnectionDependencies dependencies, + rtc::Thread* network_thread, + rtc::Thread* worker_thread, + std::unique_ptr event_log_factory, + std::unique_ptr media_transport_factory) { // There's an error in this test code if Init ends up being called twice. RTC_DCHECK(!peer_connection_); RTC_DCHECK(!peer_connection_factory_); @@ -631,6 +635,10 @@ class PeerConnectionWrapper : public webrtc::PeerConnectionObserver, pc_factory_dependencies.event_log_factory = webrtc::CreateRtcEventLogFactory(); } + if (media_transport_factory) { + pc_factory_dependencies.media_transport_factory = + std::move(media_transport_factory); + } peer_connection_factory_ = webrtc::CreateModularPeerConnectionFactory( std::move(pc_factory_dependencies)); @@ -1156,7 +1164,8 @@ class PeerConnectionIntegrationBaseTest : public testing::Test { ss_(new rtc::VirtualSocketServer()), fss_(new rtc::FirewallSocketServer(ss_.get())), network_thread_(new rtc::Thread(fss_.get())), - worker_thread_(rtc::Thread::Create()) { + worker_thread_(rtc::Thread::Create()), + loopback_media_transports_(network_thread_.get()) { network_thread_->SetName("PCNetworkThread", this); worker_thread_->SetName("PCWorkerThread", this); RTC_CHECK(network_thread_->Start()); @@ -1212,7 +1221,8 @@ class PeerConnectionIntegrationBaseTest : public testing::Test { const PeerConnectionFactory::Options* options, const RTCConfiguration* config, webrtc::PeerConnectionDependencies dependencies, - std::unique_ptr event_log_factory) { + std::unique_ptr event_log_factory, + std::unique_ptr media_transport_factory) { RTCConfiguration modified_config; if (config) { modified_config = *config; @@ -1227,7 +1237,8 @@ class PeerConnectionIntegrationBaseTest : public testing::Test { if (!client->Init(options, &modified_config, std::move(dependencies), network_thread_.get(), worker_thread_.get(), - std::move(event_log_factory))) { + std::move(event_log_factory), + std::move(media_transport_factory))) { return nullptr; } return client; @@ -1243,7 +1254,8 @@ class PeerConnectionIntegrationBaseTest : public testing::Test { new webrtc::FakeRtcEventLogFactory(rtc::Thread::Current())); return CreatePeerConnectionWrapper(debug_name, options, config, std::move(dependencies), - std::move(event_log_factory)); + std::move(event_log_factory), + /*media_transport_factory=*/nullptr); } bool CreatePeerConnectionWrappers() { @@ -1264,11 +1276,11 @@ class PeerConnectionIntegrationBaseTest : public testing::Test { sdp_semantics_ = caller_semantics; caller_ = CreatePeerConnectionWrapper( "Caller", nullptr, nullptr, webrtc::PeerConnectionDependencies(nullptr), - nullptr); + nullptr, /*media_transport_factory=*/nullptr); sdp_semantics_ = callee_semantics; callee_ = CreatePeerConnectionWrapper( "Callee", nullptr, nullptr, webrtc::PeerConnectionDependencies(nullptr), - nullptr); + nullptr, /*media_transport_factory=*/nullptr); sdp_semantics_ = original_semantics; return caller_ && callee_; } @@ -1278,10 +1290,28 @@ class PeerConnectionIntegrationBaseTest : public testing::Test { const PeerConnectionInterface::RTCConfiguration& callee_config) { caller_ = CreatePeerConnectionWrapper( "Caller", nullptr, &caller_config, - webrtc::PeerConnectionDependencies(nullptr), nullptr); + webrtc::PeerConnectionDependencies(nullptr), nullptr, + /*media_transport_factory=*/nullptr); callee_ = CreatePeerConnectionWrapper( "Callee", nullptr, &callee_config, - webrtc::PeerConnectionDependencies(nullptr), nullptr); + webrtc::PeerConnectionDependencies(nullptr), nullptr, + /*media_transport_factory=*/nullptr); + return caller_ && callee_; + } + + bool CreatePeerConnectionWrappersWithConfigAndMediaTransportFactory( + const PeerConnectionInterface::RTCConfiguration& caller_config, + const PeerConnectionInterface::RTCConfiguration& callee_config, + std::unique_ptr caller_factory, + std::unique_ptr callee_factory) { + caller_ = + CreatePeerConnectionWrapper("Caller", nullptr, &caller_config, + webrtc::PeerConnectionDependencies(nullptr), + nullptr, std::move(caller_factory)); + callee_ = + CreatePeerConnectionWrapper("Callee", nullptr, &callee_config, + webrtc::PeerConnectionDependencies(nullptr), + nullptr, std::move(callee_factory)); return caller_ && callee_; } @@ -1292,10 +1322,12 @@ class PeerConnectionIntegrationBaseTest : public testing::Test { webrtc::PeerConnectionDependencies callee_dependencies) { caller_ = CreatePeerConnectionWrapper("Caller", nullptr, &caller_config, - std::move(caller_dependencies), nullptr); + std::move(caller_dependencies), nullptr, + /*media_transport_factory=*/nullptr); callee_ = CreatePeerConnectionWrapper("Callee", nullptr, &callee_config, - std::move(callee_dependencies), nullptr); + std::move(callee_dependencies), nullptr, + /*media_transport_factory=*/nullptr); return caller_ && callee_; } @@ -1304,10 +1336,12 @@ class PeerConnectionIntegrationBaseTest : public testing::Test { const PeerConnectionFactory::Options& callee_options) { caller_ = CreatePeerConnectionWrapper( "Caller", &caller_options, nullptr, - webrtc::PeerConnectionDependencies(nullptr), nullptr); + webrtc::PeerConnectionDependencies(nullptr), nullptr, + /*media_transport_factory=*/nullptr); callee_ = CreatePeerConnectionWrapper( "Callee", &callee_options, nullptr, - webrtc::PeerConnectionDependencies(nullptr), nullptr); + webrtc::PeerConnectionDependencies(nullptr), nullptr, + /*media_transport_factory=*/nullptr); return caller_ && callee_; } @@ -1331,7 +1365,8 @@ class PeerConnectionIntegrationBaseTest : public testing::Test { webrtc::PeerConnectionDependencies dependencies(nullptr); dependencies.cert_generator = std::move(cert_generator); return CreatePeerConnectionWrapper("New Peer", nullptr, nullptr, - std::move(dependencies), nullptr); + std::move(dependencies), nullptr, + /*media_transport_factory=*/nullptr); } cricket::TestTurnServer* CreateTurnServer( @@ -1419,6 +1454,10 @@ class PeerConnectionIntegrationBaseTest : public testing::Test { rtc::VirtualSocketServer* virtual_socket_server() { return ss_.get(); } + webrtc::MediaTransportPair* loopback_media_transports() { + return &loopback_media_transports_; + } + PeerConnectionWrapper* caller() { return caller_.get(); } // Set the |caller_| to the |wrapper| passed in and return the @@ -1597,6 +1636,7 @@ class PeerConnectionIntegrationBaseTest : public testing::Test { // on the network thread. std::vector> turn_servers_; std::vector> turn_customizers_; + webrtc::MediaTransportPair loopback_media_transports_; std::unique_ptr caller_; std::unique_ptr callee_; }; @@ -3347,6 +3387,111 @@ TEST_P(PeerConnectionIntegrationTest, #endif // HAVE_SCTP +// This test sets up a call between two parties with audio, video, and a media +// transport data channel. +TEST_P(PeerConnectionIntegrationTest, MediaTransportDataChannelEndToEnd) { + PeerConnectionInterface::RTCConfiguration rtc_config; + rtc_config.use_media_transport_for_data_channels = true; + rtc_config.enable_dtls_srtp = false; // SDES is required for media transport. + ASSERT_TRUE(CreatePeerConnectionWrappersWithConfigAndMediaTransportFactory( + rtc_config, rtc_config, loopback_media_transports()->first_factory(), + loopback_media_transports()->second_factory())); + ConnectFakeSignaling(); + + // Expect that data channel created on caller side will show up for callee as + // well. + caller()->CreateDataChannel(); + caller()->CreateAndSetAndSignalOffer(); + ASSERT_TRUE_WAIT(SignalingStateStable(), kDefaultTimeout); + + // Ensure that the media transport is ready. + loopback_media_transports()->SetState(webrtc::MediaTransportState::kWritable); + loopback_media_transports()->FlushAsyncInvokes(); + + // Caller data channel should already exist (it created one). Callee data + // channel may not exist yet, since negotiation happens in-band, not in SDP. + ASSERT_NE(nullptr, caller()->data_channel()); + ASSERT_TRUE_WAIT(callee()->data_channel() != nullptr, kDefaultTimeout); + EXPECT_TRUE_WAIT(caller()->data_observer()->IsOpen(), kDefaultTimeout); + EXPECT_TRUE_WAIT(callee()->data_observer()->IsOpen(), kDefaultTimeout); + + // Ensure data can be sent in both directions. + std::string data = "hello world"; + caller()->data_channel()->Send(DataBuffer(data)); + EXPECT_EQ_WAIT(data, callee()->data_observer()->last_message(), + kDefaultTimeout); + callee()->data_channel()->Send(DataBuffer(data)); + EXPECT_EQ_WAIT(data, caller()->data_observer()->last_message(), + kDefaultTimeout); +} + +// Ensure that when the callee closes a media transport data channel, the +// closing procedure results in the data channel being closed for the caller +// as well. +TEST_P(PeerConnectionIntegrationTest, MediaTransportDataChannelCalleeCloses) { + PeerConnectionInterface::RTCConfiguration rtc_config; + rtc_config.use_media_transport_for_data_channels = true; + rtc_config.enable_dtls_srtp = false; // SDES is required for media transport. + ASSERT_TRUE(CreatePeerConnectionWrappersWithConfigAndMediaTransportFactory( + rtc_config, rtc_config, loopback_media_transports()->first_factory(), + loopback_media_transports()->second_factory())); + ConnectFakeSignaling(); + + // Create a data channel on the caller and signal it to the callee. + caller()->CreateDataChannel(); + caller()->CreateAndSetAndSignalOffer(); + ASSERT_TRUE_WAIT(SignalingStateStable(), kDefaultTimeout); + + // Ensure that the media transport is ready. + loopback_media_transports()->SetState(webrtc::MediaTransportState::kWritable); + loopback_media_transports()->FlushAsyncInvokes(); + + // Data channels exist and open on both ends of the connection. + ASSERT_NE(nullptr, caller()->data_channel()); + ASSERT_TRUE_WAIT(callee()->data_channel() != nullptr, kDefaultTimeout); + ASSERT_TRUE_WAIT(caller()->data_observer()->IsOpen(), kDefaultTimeout); + ASSERT_TRUE_WAIT(callee()->data_observer()->IsOpen(), kDefaultTimeout); + + // Close the data channel on the callee side, and wait for it to reach the + // "closed" state on both sides. + callee()->data_channel()->Close(); + EXPECT_TRUE_WAIT(!caller()->data_observer()->IsOpen(), kDefaultTimeout); + EXPECT_TRUE_WAIT(!callee()->data_observer()->IsOpen(), kDefaultTimeout); +} + +TEST_P(PeerConnectionIntegrationTest, + MediaTransportDataChannelConfigSentToOtherSide) { + PeerConnectionInterface::RTCConfiguration rtc_config; + rtc_config.use_media_transport_for_data_channels = true; + rtc_config.enable_dtls_srtp = false; // SDES is required for media transport. + ASSERT_TRUE(CreatePeerConnectionWrappersWithConfigAndMediaTransportFactory( + rtc_config, rtc_config, loopback_media_transports()->first_factory(), + loopback_media_transports()->second_factory())); + ConnectFakeSignaling(); + + // Create a data channel with a non-default configuration and signal it to the + // callee. + webrtc::DataChannelInit init; + init.id = 53; + init.maxRetransmits = 52; + caller()->CreateDataChannel("data-channel", &init); + caller()->CreateAndSetAndSignalOffer(); + ASSERT_TRUE_WAIT(SignalingStateStable(), kDefaultTimeout); + + // Ensure that the media transport is ready. + loopback_media_transports()->SetState(webrtc::MediaTransportState::kWritable); + loopback_media_transports()->FlushAsyncInvokes(); + + // Ensure that the data channel exists on the callee with the correct + // configuration. + ASSERT_TRUE_WAIT(callee()->data_channel() != nullptr, kDefaultTimeout); + ASSERT_TRUE_WAIT(callee()->data_observer()->IsOpen(), kDefaultTimeout); + EXPECT_EQ(init.id, callee()->data_channel()->id()); + EXPECT_EQ("data-channel", callee()->data_channel()->label()); + EXPECT_EQ(init.maxRetransmits, callee()->data_channel()->maxRetransmits()); + EXPECT_FALSE(callee()->data_channel()->negotiated()); +} + // Test that the ICE connection and gathering states eventually reach // "complete". TEST_P(PeerConnectionIntegrationTest, IceStatesReachCompletion) {