diff --git a/api/test/DEPS b/api/test/DEPS index 1fc1f7cb30..9c293b323e 100644 --- a/api/test/DEPS +++ b/api/test/DEPS @@ -9,5 +9,6 @@ specific_include_rules = { "+rtc_base/asyncinvoker.h", "+rtc_base/criticalsection.h", "+rtc_base/thread.h", + "+rtc_base/thread_checker.h", ], } diff --git a/api/test/loopback_media_transport.h b/api/test/loopback_media_transport.h index f3f24d4c98..2524fb47b1 100644 --- a/api/test/loopback_media_transport.h +++ b/api/test/loopback_media_transport.h @@ -11,15 +11,91 @@ #ifndef API_TEST_LOOPBACK_MEDIA_TRANSPORT_H_ #define API_TEST_LOOPBACK_MEDIA_TRANSPORT_H_ +#include #include #include "api/media_transport_interface.h" #include "rtc_base/asyncinvoker.h" #include "rtc_base/criticalsection.h" #include "rtc_base/thread.h" +#include "rtc_base/thread_checker.h" namespace webrtc { +// Wrapper used to hand out unique_ptrs to loopback media transports without +// ownership changes. +class WrapperMediaTransport : public MediaTransportInterface { + public: + explicit WrapperMediaTransport(MediaTransportInterface* wrapped) + : wrapped_(wrapped) {} + + RTCError SendAudioFrame(uint64_t channel_id, + MediaTransportEncodedAudioFrame frame) override { + return wrapped_->SendAudioFrame(channel_id, std::move(frame)); + } + + RTCError SendVideoFrame( + uint64_t channel_id, + const MediaTransportEncodedVideoFrame& frame) override { + return wrapped_->SendVideoFrame(channel_id, frame); + } + + RTCError RequestKeyFrame(uint64_t channel_id) override { + return wrapped_->RequestKeyFrame(channel_id); + } + + void SetReceiveAudioSink(MediaTransportAudioSinkInterface* sink) override { + wrapped_->SetReceiveAudioSink(sink); + } + + void SetReceiveVideoSink(MediaTransportVideoSinkInterface* sink) override { + wrapped_->SetReceiveVideoSink(sink); + } + + void SetTargetTransferRateObserver( + webrtc::TargetTransferRateObserver* observer) override { + wrapped_->SetTargetTransferRateObserver(observer); + } + + void SetMediaTransportStateCallback( + MediaTransportStateCallback* callback) override { + wrapped_->SetMediaTransportStateCallback(callback); + } + + RTCError SendData(int channel_id, + const SendDataParams& params, + const rtc::CopyOnWriteBuffer& buffer) override { + return wrapped_->SendData(channel_id, params, buffer); + } + + RTCError CloseChannel(int channel_id) override { + return wrapped_->CloseChannel(channel_id); + } + + void SetDataSink(DataChannelSink* sink) override { + wrapped_->SetDataSink(sink); + } + + private: + MediaTransportInterface* wrapped_; +}; + +class WrapperMediaTransportFactory : public MediaTransportFactory { + public: + explicit WrapperMediaTransportFactory(MediaTransportInterface* wrapped) + : wrapped_(wrapped) {} + + RTCErrorOr> CreateMediaTransport( + rtc::PacketTransportInternal* packet_transport, + rtc::Thread* network_thread, + const MediaTransportSettings& settings) override { + return {absl::make_unique(wrapped_)}; + } + + private: + MediaTransportInterface* wrapped_; +}; + // Contains two MediaTransportsInterfaces that are connected to each other. // Currently supports audio only. class MediaTransportPair { @@ -31,6 +107,19 @@ class MediaTransportPair { MediaTransportInterface* first() { return &first_; } MediaTransportInterface* second() { return &second_; } + std::unique_ptr first_factory() { + return absl::make_unique(&first_); + } + + std::unique_ptr second_factory() { + return absl::make_unique(&second_); + } + + void SetState(MediaTransportState state) { + first_.SetState(state); + second_.SetState(state); + } + void FlushAsyncInvokes() { first_.FlushAsyncInvokes(); second_.FlushAsyncInvokes(); @@ -81,7 +170,14 @@ class MediaTransportPair { webrtc::TargetTransferRateObserver* observer) override {} void SetMediaTransportStateCallback( - MediaTransportStateCallback* callback) override {} + MediaTransportStateCallback* callback) override { + rtc::CritScope lock(&sink_lock_); + state_callback_ = callback; + invoker_.AsyncInvoke(RTC_FROM_HERE, thread_, [this] { + RTC_DCHECK_RUN_ON(thread_); + OnStateChanged(); + }); + } RTCError SendData(int channel_id, const SendDataParams& params, @@ -109,6 +205,14 @@ class MediaTransportPair { data_sink_ = sink; } + void SetState(MediaTransportState state) { + invoker_.AsyncInvoke(RTC_FROM_HERE, thread_, [this, state] { + RTC_DCHECK_RUN_ON(thread_); + state_ = state; + OnStateChanged(); + }); + } + void FlushAsyncInvokes() { invoker_.Flush(thread_); } private: @@ -136,12 +240,25 @@ class MediaTransportPair { } } + void OnStateChanged() RTC_RUN_ON(thread_) { + rtc::CritScope lock(&sink_lock_); + if (state_callback_) { + state_callback_->OnStateChanged(state_); + } + } + rtc::Thread* const thread_; rtc::CriticalSection sink_lock_; MediaTransportAudioSinkInterface* sink_ RTC_GUARDED_BY(sink_lock_) = nullptr; DataChannelSink* data_sink_ RTC_GUARDED_BY(sink_lock_) = nullptr; + MediaTransportStateCallback* state_callback_ RTC_GUARDED_BY(sink_lock_) = + nullptr; + + MediaTransportState state_ RTC_GUARDED_BY(thread_) = + MediaTransportState::kPending; + LoopbackMediaTransport* const other_; rtc::AsyncInvoker invoker_; diff --git a/api/test/loopback_media_transport_unittest.cc b/api/test/loopback_media_transport_unittest.cc index f85413c55b..ba741a05ca 100644 --- a/api/test/loopback_media_transport_unittest.cc +++ b/api/test/loopback_media_transport_unittest.cc @@ -32,6 +32,11 @@ class MockDataChannelSink : public DataChannelSink { MOCK_METHOD1(OnChannelClosed, void(int)); }; +class MockStateCallback : public MediaTransportStateCallback { + public: + MOCK_METHOD1(OnStateChanged, void(MediaTransportState)); +}; + // Test only uses the sequence number. MediaTransportEncodedAudioFrame CreateAudioFrame(int sequence_number) { static constexpr int kSamplingRateHz = 48000; @@ -122,4 +127,45 @@ TEST(LoopbackMediaTransport, CloseDeliveredToSink) { transport_pair.second()->SetDataSink(nullptr); } +TEST(LoopbackMediaTransport, InitialStateDeliveredWhenCallbackSet) { + std::unique_ptr thread = rtc::Thread::Create(); + thread->Start(); + MediaTransportPair transport_pair(thread.get()); + + MockStateCallback state_callback; + + EXPECT_CALL(state_callback, OnStateChanged(MediaTransportState::kPending)); + transport_pair.first()->SetMediaTransportStateCallback(&state_callback); + transport_pair.FlushAsyncInvokes(); +} + +TEST(LoopbackMediaTransport, ChangedStateDeliveredWhenCallbackSet) { + std::unique_ptr thread = rtc::Thread::Create(); + thread->Start(); + MediaTransportPair transport_pair(thread.get()); + + transport_pair.SetState(MediaTransportState::kWritable); + transport_pair.FlushAsyncInvokes(); + + MockStateCallback state_callback; + + EXPECT_CALL(state_callback, OnStateChanged(MediaTransportState::kWritable)); + transport_pair.first()->SetMediaTransportStateCallback(&state_callback); + transport_pair.FlushAsyncInvokes(); +} + +TEST(LoopbackMediaTransport, StateChangeDeliveredToCallback) { + std::unique_ptr thread = rtc::Thread::Create(); + thread->Start(); + MediaTransportPair transport_pair(thread.get()); + + MockStateCallback state_callback; + + EXPECT_CALL(state_callback, OnStateChanged(MediaTransportState::kPending)); + EXPECT_CALL(state_callback, OnStateChanged(MediaTransportState::kWritable)); + transport_pair.first()->SetMediaTransportStateCallback(&state_callback); + transport_pair.SetState(MediaTransportState::kWritable); + transport_pair.FlushAsyncInvokes(); +} + } // namespace webrtc diff --git a/media/base/mediaengine.h b/media/base/mediaengine.h index e752da8988..62f43f9949 100644 --- a/media/base/mediaengine.h +++ b/media/base/mediaengine.h @@ -159,7 +159,12 @@ class CompositeMediaEngine : public MediaEngineInterface { std::pair engines_; }; -enum DataChannelType { DCT_NONE = 0, DCT_RTP = 1, DCT_SCTP = 2 }; +enum DataChannelType { + DCT_NONE = 0, + DCT_RTP = 1, + DCT_SCTP = 2, + DCT_MEDIA_TRANSPORT = 3 +}; class DataEngineInterface { public: diff --git a/pc/BUILD.gn b/pc/BUILD.gn index 182f07764c..97045ba66d 100644 --- a/pc/BUILD.gn +++ b/pc/BUILD.gn @@ -485,6 +485,7 @@ if (rtc_include_tests) { "../api:fake_frame_decryptor", "../api:fake_frame_encryptor", "../api:libjingle_peerconnection_api", + "../api:loopback_media_transport", "../api:mock_rtp", "../api/units:time_delta", "../logging:fake_rtc_event_log", diff --git a/pc/datachannel.cc b/pc/datachannel.cc index f819d26695..e989586607 100644 --- a/pc/datachannel.cc +++ b/pc/datachannel.cc @@ -118,6 +118,10 @@ rtc::scoped_refptr DataChannel::Create( return channel; } +bool DataChannel::IsSctpLike(cricket::DataChannelType type) { + return type == cricket::DCT_SCTP || type == cricket::DCT_MEDIA_TRANSPORT; +} + DataChannel::DataChannel(DataChannelProviderInterface* provider, cricket::DataChannelType dct, const std::string& label) @@ -147,7 +151,7 @@ bool DataChannel::Init(const InternalDataChannelInit& config) { return false; } handshake_state_ = kHandshakeReady; - } else if (data_channel_type_ == cricket::DCT_SCTP) { + } else if (IsSctpLike(data_channel_type_)) { if (config.id < -1 || config.maxRetransmits < -1 || config.maxRetransmitTime < -1) { RTC_LOG(LS_ERROR) << "Failed to initialize the SCTP data channel due to " @@ -241,7 +245,7 @@ bool DataChannel::Send(const DataBuffer& buffer) { if (!queued_send_data_.Empty()) { // Only SCTP DataChannel queues the outgoing data when the transport is // blocked. - RTC_DCHECK(data_channel_type_ == cricket::DCT_SCTP); + RTC_DCHECK(IsSctpLike(data_channel_type_)); if (!QueueSendDataMessage(buffer)) { RTC_LOG(LS_ERROR) << "Closing the DataChannel due to a failure to queue " "additional data."; @@ -273,7 +277,7 @@ void DataChannel::SetReceiveSsrc(uint32_t receive_ssrc) { void DataChannel::SetSctpSid(int sid) { RTC_DCHECK_LT(config_.id, 0); RTC_DCHECK_GE(sid, 0); - RTC_DCHECK_EQ(data_channel_type_, cricket::DCT_SCTP); + RTC_DCHECK(IsSctpLike(data_channel_type_)); if (config_.id == sid) { return; } @@ -283,7 +287,7 @@ void DataChannel::SetSctpSid(int sid) { } void DataChannel::OnClosingProcedureStartedRemotely(int sid) { - if (data_channel_type_ == cricket::DCT_SCTP && sid == config_.id && + if (IsSctpLike(data_channel_type_) && sid == config_.id && state_ != kClosing && state_ != kClosed) { // Don't bother sending queued data since the side that initiated the // closure wouldn't receive it anyway. See crbug.com/559394 for a lengthy @@ -299,7 +303,7 @@ void DataChannel::OnClosingProcedureStartedRemotely(int sid) { } void DataChannel::OnClosingProcedureComplete(int sid) { - if (data_channel_type_ == cricket::DCT_SCTP && sid == config_.id) { + if (IsSctpLike(data_channel_type_) && sid == config_.id) { // If the closing procedure is complete, we should have finished sending // all pending data and transitioned to kClosing already. RTC_DCHECK_EQ(state_, kClosing); @@ -310,7 +314,7 @@ void DataChannel::OnClosingProcedureComplete(int sid) { } void DataChannel::OnTransportChannelCreated() { - RTC_DCHECK(data_channel_type_ == cricket::DCT_SCTP); + RTC_DCHECK(IsSctpLike(data_channel_type_)); if (!connected_to_provider_) { connected_to_provider_ = provider_->ConnectDataChannel(this); } @@ -348,12 +352,12 @@ void DataChannel::OnDataReceived(const cricket::ReceiveDataParams& params, if (data_channel_type_ == cricket::DCT_RTP && params.ssrc != receive_ssrc_) { return; } - if (data_channel_type_ == cricket::DCT_SCTP && params.sid != config_.id) { + if (IsSctpLike(data_channel_type_) && params.sid != config_.id) { return; } if (params.type == cricket::DMT_CONTROL) { - RTC_DCHECK(data_channel_type_ == cricket::DCT_SCTP); + RTC_DCHECK(IsSctpLike(data_channel_type_)); if (handshake_state_ != kHandshakeWaitingForAck) { // Ignore it if we are not expecting an ACK message. RTC_LOG(LS_WARNING) @@ -570,7 +574,7 @@ bool DataChannel::SendDataMessage(const DataBuffer& buffer, bool queue_if_blocked) { cricket::SendDataParams send_params; - if (data_channel_type_ == cricket::DCT_SCTP) { + if (IsSctpLike(data_channel_type_)) { send_params.ordered = config_.ordered; // Send as ordered if it is still going through OPEN/ACK signaling. if (handshake_state_ != kHandshakeReady && !config_.ordered) { @@ -597,7 +601,7 @@ bool DataChannel::SendDataMessage(const DataBuffer& buffer, return true; } - if (data_channel_type_ != cricket::DCT_SCTP) { + if (!IsSctpLike(data_channel_type_)) { return false; } @@ -649,7 +653,7 @@ void DataChannel::QueueControlMessage(const rtc::CopyOnWriteBuffer& buffer) { bool DataChannel::SendControlMessage(const rtc::CopyOnWriteBuffer& buffer) { bool is_open_message = handshake_state_ == kHandshakeShouldSendOpen; - RTC_DCHECK_EQ(data_channel_type_, cricket::DCT_SCTP); + RTC_DCHECK(IsSctpLike(data_channel_type_)); RTC_DCHECK(writable_); RTC_DCHECK_GE(config_.id, 0); RTC_DCHECK(!is_open_message || !config_.negotiated); diff --git a/pc/datachannel.h b/pc/datachannel.h index cbb0c8b2cf..22ea354c21 100644 --- a/pc/datachannel.h +++ b/pc/datachannel.h @@ -122,6 +122,8 @@ class DataChannel : public DataChannelInterface, public sigslot::has_slots<> { const std::string& label, const InternalDataChannelInit& config); + static bool IsSctpLike(cricket::DataChannelType type); + virtual void RegisterObserver(DataChannelObserver* observer); virtual void UnregisterObserver(); diff --git a/pc/jseptransportcontroller.cc b/pc/jseptransportcontroller.cc index 1613e1e791..cb99686a98 100644 --- a/pc/jseptransportcontroller.cc +++ b/pc/jseptransportcontroller.cc @@ -143,6 +143,15 @@ MediaTransportInterface* JsepTransportController::GetMediaTransport( return jsep_transport->media_transport(); } +MediaTransportState JsepTransportController::GetMediaTransportState( + const std::string& mid) const { + auto jsep_transport = GetJsepTransportForMid(mid); + if (!jsep_transport) { + return MediaTransportState::kPending; + } + return jsep_transport->media_transport_state(); +} + cricket::DtlsTransportInternal* JsepTransportController::GetDtlsTransport( const std::string& mid) const { auto jsep_transport = GetJsepTransportForMid(mid); @@ -1042,7 +1051,7 @@ RTCError JsepTransportController::MaybeCreateJsepTransport( jsep_transport->SignalRtcpMuxActive.connect( this, &JsepTransportController::UpdateAggregateStates_n); jsep_transport->SignalMediaTransportStateChanged.connect( - this, &JsepTransportController::UpdateAggregateStates_n); + this, &JsepTransportController::OnMediaTransportStateChanged_n); SetTransportForMid(content_info.name, jsep_transport.get()); jsep_transports_by_name_[content_info.name] = std::move(jsep_transport); @@ -1224,6 +1233,11 @@ void JsepTransportController::OnTransportStateChanged_n( UpdateAggregateStates_n(); } +void JsepTransportController::OnMediaTransportStateChanged_n() { + SignalMediaTransportStateChanged(); + UpdateAggregateStates_n(); +} + void JsepTransportController::UpdateAggregateStates_n() { RTC_DCHECK(network_thread_->IsCurrent()); diff --git a/pc/jseptransportcontroller.h b/pc/jseptransportcontroller.h index 5747990cd6..8d89795ce3 100644 --- a/pc/jseptransportcontroller.h +++ b/pc/jseptransportcontroller.h @@ -117,6 +117,7 @@ class JsepTransportController : public sigslot::has_slots<> { const std::string& mid) const; MediaTransportInterface* GetMediaTransport(const std::string& mid) const; + MediaTransportState GetMediaTransportState(const std::string& mid) const; /********************* * ICE-related methods @@ -200,6 +201,8 @@ class JsepTransportController : public sigslot::has_slots<> { sigslot::signal1 SignalDtlsHandshakeError; + sigslot::signal<> SignalMediaTransportStateChanged; + private: RTCError ApplyDescription_n(bool local, SdpType type, @@ -311,6 +314,7 @@ class JsepTransportController : public sigslot::has_slots<> { const cricket::Candidates& candidates); void OnTransportRoleConflict_n(cricket::IceTransportInternal* transport); void OnTransportStateChanged_n(cricket::IceTransportInternal* transport); + void OnMediaTransportStateChanged_n(); void UpdateAggregateStates_n(); diff --git a/pc/peerconnection.cc b/pc/peerconnection.cc index 3c5a025f22..54ef67740e 100644 --- a/pc/peerconnection.cc +++ b/pc/peerconnection.cc @@ -631,6 +631,35 @@ absl::optional RTCConfigurationToIceConfigOptionalInt( return rtc_configuration_parameter; } +cricket::DataMessageType ToCricketDataMessageType(DataMessageType type) { + switch (type) { + case DataMessageType::kText: + return cricket::DMT_TEXT; + case DataMessageType::kBinary: + return cricket::DMT_BINARY; + case DataMessageType::kControl: + return cricket::DMT_CONTROL; + default: + return cricket::DMT_NONE; + } + return cricket::DMT_NONE; +} + +DataMessageType ToWebrtcDataMessageType(cricket::DataMessageType type) { + switch (type) { + case cricket::DMT_TEXT: + return DataMessageType::kText; + case cricket::DMT_BINARY: + return DataMessageType::kBinary; + case cricket::DMT_CONTROL: + return DataMessageType::kControl; + case cricket::DMT_NONE: + default: + RTC_NOTREACHED(); + } + return DataMessageType::kControl; +} + } // namespace // Upon completion, posts a task to execute the callback of the @@ -828,6 +857,7 @@ PeerConnection::~PeerConnection() { webrtc_session_desc_factory_.reset(); sctp_invoker_.reset(); sctp_factory_.reset(); + media_transport_invoker_.reset(); transport_controller_.reset(); // port_allocator_ lives on the network thread and should be destroyed there. @@ -1009,10 +1039,18 @@ bool PeerConnection::Initialize( } } - // Enable creation of RTP data channels if the kEnableRtpDataChannels is set. - // It takes precendence over the disable_sctp_data_channels - // PeerConnectionFactoryInterface::Options. - if (configuration.enable_rtp_data_channel) { + if (configuration.use_media_transport_for_data_channels) { + if (configuration.enable_rtp_data_channel) { + RTC_LOG(LS_ERROR) << "enable_rtp_data_channel and " + "use_media_transport_for_data_channels are " + "incompatible and cannot both be set to true"; + return false; + } + data_channel_type_ = cricket::DCT_MEDIA_TRANSPORT; + } else if (configuration.enable_rtp_data_channel) { + // Enable creation of RTP data channels if the kEnableRtpDataChannels is + // set. It takes precendence over the disable_sctp_data_channels + // PeerConnectionFactoryInterface::Options. data_channel_type_ = cricket::DCT_RTP; } else { // DTLS has to be enabled to use SCTP. @@ -2035,6 +2073,16 @@ RTCError PeerConnection::ApplyLocalDescription( // |local_description()|. RTC_DCHECK(local_description()); + if (!is_caller_) { + if (remote_description()) { + // Remote description was applied first, so this PC is the callee. + is_caller_ = false; + } else { + // Local description is applied first, so this PC is the caller. + is_caller_ = true; + } + } + RTCError error = PushdownTransportDescription(cricket::CS_LOCAL, type); if (!error.ok()) { return error; @@ -2117,7 +2165,7 @@ RTCError PeerConnection::ApplyLocalDescription( // If setting the description decided our SSL role, allocate any necessary // SCTP sids. rtc::SSLRole role; - if (data_channel_type() == cricket::DCT_SCTP && GetSctpSslRole(&role)) { + if (DataChannel::IsSctpLike(data_channel_type_) && GetSctpSslRole(&role)) { AllocateSctpSids(role); } @@ -2392,7 +2440,7 @@ RTCError PeerConnection::ApplyRemoteDescription( // If setting the description decided our SSL role, allocate any necessary // SCTP sids. rtc::SSLRole role; - if (data_channel_type() == cricket::DCT_SCTP && GetSctpSslRole(&role)) { + if (DataChannel::IsSctpLike(data_channel_type_) && GetSctpSslRole(&role)) { AllocateSctpSids(role); } @@ -2723,7 +2771,7 @@ RTCError PeerConnection::UpdateDataChannel( if (content.rejected) { DestroyDataChannel(); } else { - if (!rtp_data_channel_ && !sctp_transport_) { + if (!rtp_data_channel_ && !sctp_transport_ && !media_transport_) { if (!CreateDataChannel(content.name)) { LOG_AND_RETURN_ERROR(RTCErrorType::INTERNAL_ERROR, "Failed to create data channel."); @@ -4191,6 +4239,8 @@ absl::optional PeerConnection::GetDataMid() const { return rtp_data_channel_->content_name(); case cricket::DCT_SCTP: return sctp_mid_; + case cricket::DCT_MEDIA_TRANSPORT: + return media_transport_data_mid_; default: return absl::nullopt; } @@ -4553,7 +4603,7 @@ rtc::scoped_refptr PeerConnection::InternalCreateDataChannel( } InternalDataChannelInit new_config = config ? (*config) : InternalDataChannelInit(); - if (data_channel_type() == cricket::DCT_SCTP) { + if (DataChannel::IsSctpLike(data_channel_type_)) { if (new_config.id < 0) { rtc::SSLRole role; if ((GetSctpSslRole(&role)) && @@ -4584,7 +4634,7 @@ rtc::scoped_refptr PeerConnection::InternalCreateDataChannel( } rtp_data_channels_[channel->label()] = channel; } else { - RTC_DCHECK(channel->data_channel_type() == cricket::DCT_SCTP); + RTC_DCHECK(DataChannel::IsSctpLike(data_channel_type_)); sctp_data_channels_.push_back(channel); channel->SignalClosed.connect(this, &PeerConnection::OnSctpDataChannelClosed); @@ -4664,6 +4714,27 @@ void PeerConnection::OnDataChannelOpenMessage( NoteUsageEvent(UsageEvent::DATA_ADDED); } +bool PeerConnection::HandleOpenMessage_s( + const cricket::ReceiveDataParams& params, + const rtc::CopyOnWriteBuffer& buffer) { + if (params.type == cricket::DMT_CONTROL && IsOpenMessage(buffer)) { + // Received OPEN message; parse and signal that a new data channel should + // be created. + std::string label; + InternalDataChannelInit config; + config.id = params.ssrc; + if (!ParseDataChannelOpenMessage(buffer, &label, &config)) { + RTC_LOG(LS_WARNING) << "Failed to parse the OPEN message for ssrc " + << params.ssrc; + return true; + } + config.open_handshake_role = InternalDataChannelInit::kAcker; + OnDataChannelOpenMessage(label, config); + return true; + } + return false; +} + rtc::scoped_refptr> PeerConnection::GetAudioTransceiver() const { // This method only works with Plan B SDP, where there is a single @@ -4907,19 +4978,25 @@ cricket::BaseChannel* PeerConnection::GetChannel( } bool PeerConnection::GetSctpSslRole(rtc::SSLRole* role) { + RTC_DCHECK_RUN_ON(signaling_thread()); if (!local_description() || !remote_description()) { RTC_LOG(LS_INFO) << "Local and Remote descriptions must be applied to get the " "SSL Role of the SCTP transport."; return false; } - if (!sctp_transport_) { + if (!sctp_transport_ && !media_transport_) { RTC_LOG(LS_INFO) << "Non-rejected SCTP m= section is needed to get the " "SSL Role of the SCTP transport."; return false; } - auto dtls_role = transport_controller_->GetDtlsRole(*sctp_mid_); + absl::optional dtls_role; + if (sctp_mid_) { + dtls_role = transport_controller_->GetDtlsRole(*sctp_mid_); + } else if (is_caller_) { + dtls_role = *is_caller_ ? rtc::SSL_SERVER : rtc::SSL_CLIENT; + } if (dtls_role) { *role = *dtls_role; return true; @@ -5165,11 +5242,22 @@ bool PeerConnection::GetRemoteTrackIdBySsrc(uint32_t ssrc, bool PeerConnection::SendData(const cricket::SendDataParams& params, const rtc::CopyOnWriteBuffer& payload, cricket::SendDataResult* result) { - if (!rtp_data_channel_ && !sctp_transport_) { - RTC_LOG(LS_ERROR) << "SendData called when rtp_data_channel_ " - "and sctp_transport_ are NULL."; + if (!rtp_data_channel_ && !sctp_transport_ && !media_transport_) { + RTC_LOG(LS_ERROR) << "SendData called when rtp_data_channel_, " + "sctp_transport_, and media_transport_ are NULL."; return false; } + if (media_transport_) { + SendDataParams send_params; + send_params.type = ToWebrtcDataMessageType(params.type); + send_params.ordered = params.ordered; + if (params.max_rtx_count >= 0) { + send_params.max_rtx_count = params.max_rtx_count; + } else if (params.max_rtx_ms >= 0) { + send_params.max_rtx_ms = params.max_rtx_ms; + } + return media_transport_->SendData(params.sid, send_params, payload).ok(); + } return rtp_data_channel_ ? rtp_data_channel_->SendData(params, payload, result) : network_thread()->Invoke( @@ -5179,13 +5267,23 @@ bool PeerConnection::SendData(const cricket::SendDataParams& params, } bool PeerConnection::ConnectDataChannel(DataChannel* webrtc_data_channel) { - if (!rtp_data_channel_ && !sctp_transport_) { + RTC_DCHECK_RUN_ON(signaling_thread()); + if (!rtp_data_channel_ && !sctp_transport_ && !media_transport_) { // Don't log an error here, because DataChannels are expected to call // ConnectDataChannel in this state. It's the only way to initially tell // whether or not the underlying transport is ready. return false; } - if (rtp_data_channel_) { + if (media_transport_) { + SignalMediaTransportWritable_s.connect(webrtc_data_channel, + &DataChannel::OnChannelReady); + SignalMediaTransportReceivedData_s.connect(webrtc_data_channel, + &DataChannel::OnDataReceived); + SignalMediaTransportChannelClosing_s.connect( + webrtc_data_channel, &DataChannel::OnClosingProcedureStartedRemotely); + SignalMediaTransportChannelClosed_s.connect( + webrtc_data_channel, &DataChannel::OnClosingProcedureComplete); + } else if (rtp_data_channel_) { rtp_data_channel_->SignalReadyToSendData.connect( webrtc_data_channel, &DataChannel::OnChannelReady); rtp_data_channel_->SignalDataReceived.connect(webrtc_data_channel, @@ -5204,13 +5302,19 @@ bool PeerConnection::ConnectDataChannel(DataChannel* webrtc_data_channel) { } void PeerConnection::DisconnectDataChannel(DataChannel* webrtc_data_channel) { - if (!rtp_data_channel_ && !sctp_transport_) { + RTC_DCHECK_RUN_ON(signaling_thread()); + if (!rtp_data_channel_ && !sctp_transport_ && !media_transport_) { RTC_LOG(LS_ERROR) << "DisconnectDataChannel called when rtp_data_channel_ and " "sctp_transport_ are NULL."; return; } - if (rtp_data_channel_) { + if (media_transport_) { + SignalMediaTransportWritable_s.disconnect(webrtc_data_channel); + SignalMediaTransportReceivedData_s.disconnect(webrtc_data_channel); + SignalMediaTransportChannelClosing_s.disconnect(webrtc_data_channel); + SignalMediaTransportChannelClosed_s.disconnect(webrtc_data_channel); + } else if (rtp_data_channel_) { rtp_data_channel_->SignalReadyToSendData.disconnect(webrtc_data_channel); rtp_data_channel_->SignalDataReceived.disconnect(webrtc_data_channel); } else { @@ -5222,6 +5326,10 @@ void PeerConnection::DisconnectDataChannel(DataChannel* webrtc_data_channel) { } void PeerConnection::AddSctpDataStream(int sid) { + if (media_transport_) { + // No-op. Media transport does not need to add streams. + return; + } if (!sctp_transport_) { RTC_LOG(LS_ERROR) << "AddSctpDataStream called when sctp_transport_ is NULL."; @@ -5233,6 +5341,10 @@ void PeerConnection::AddSctpDataStream(int sid) { } void PeerConnection::RemoveSctpDataStream(int sid) { + if (media_transport_) { + media_transport_->CloseChannel(sid); + return; + } if (!sctp_transport_) { RTC_LOG(LS_ERROR) << "RemoveSctpDataStream called when sctp_transport_ is " "NULL."; @@ -5244,10 +5356,43 @@ void PeerConnection::RemoveSctpDataStream(int sid) { } bool PeerConnection::ReadyToSendData() const { + RTC_DCHECK_RUN_ON(signaling_thread()); return (rtp_data_channel_ && rtp_data_channel_->ready_to_send_data()) || + (media_transport_ && media_transport_ready_to_send_data_) || sctp_ready_to_send_data_; } +void PeerConnection::OnDataReceived(int channel_id, + DataMessageType type, + const rtc::CopyOnWriteBuffer& buffer) { + cricket::ReceiveDataParams params; + params.sid = channel_id; + params.type = ToCricketDataMessageType(type); + media_transport_invoker_->AsyncInvoke( + RTC_FROM_HERE, signaling_thread(), [this, params, buffer] { + RTC_DCHECK_RUN_ON(signaling_thread()); + if (!HandleOpenMessage_s(params, buffer)) { + SignalMediaTransportReceivedData_s(params, buffer); + } + }); +} + +void PeerConnection::OnChannelClosing(int channel_id) { + media_transport_invoker_->AsyncInvoke( + RTC_FROM_HERE, signaling_thread(), [this, channel_id] { + RTC_DCHECK_RUN_ON(signaling_thread()); + SignalMediaTransportChannelClosing_s(channel_id); + }); +} + +void PeerConnection::OnChannelClosed(int channel_id) { + media_transport_invoker_->AsyncInvoke( + RTC_FROM_HERE, signaling_thread(), [this, channel_id] { + RTC_DCHECK_RUN_ON(signaling_thread()); + SignalMediaTransportChannelClosed_s(channel_id); + }); +} + absl::optional PeerConnection::sctp_transport_name() const { if (sctp_mid_ && transport_controller_) { auto dtls_transport = transport_controller_->GetDtlsTransport(*sctp_mid_); @@ -5608,7 +5753,7 @@ RTCError PeerConnection::CreateChannels(const SessionDescription& desc) { const cricket::ContentInfo* data = cricket::GetFirstDataContent(&desc); if (data_channel_type_ != cricket::DCT_NONE && data && !data->rejected && - !rtp_data_channel_ && !sctp_transport_) { + !rtp_data_channel_ && !sctp_transport_ && !media_transport_) { if (!CreateDataChannel(data->name)) { LOG_AND_RETURN_ERROR(RTCErrorType::INTERNAL_ERROR, "Failed to create data channel."); @@ -5666,35 +5811,49 @@ cricket::VideoChannel* PeerConnection::CreateVideoChannel( } bool PeerConnection::CreateDataChannel(const std::string& mid) { - bool sctp = (data_channel_type_ == cricket::DCT_SCTP); - if (sctp) { - if (!sctp_factory_) { - RTC_LOG(LS_ERROR) - << "Trying to create SCTP transport, but didn't compile with " - "SCTP support (HAVE_SCTP)"; + switch (data_channel_type_) { + case cricket::DCT_MEDIA_TRANSPORT: + if (network_thread()->Invoke( + RTC_FROM_HERE, + rtc::Bind(&PeerConnection::SetupMediaTransportForDataChannels_n, + this, mid))) { + for (const auto& channel : sctp_data_channels_) { + channel->OnTransportChannelCreated(); + } + return true; + } return false; - } - if (!network_thread()->Invoke( - RTC_FROM_HERE, - rtc::Bind(&PeerConnection::CreateSctpTransport_n, this, mid))) { - return false; - } - for (const auto& channel : sctp_data_channels_) { - channel->OnTransportChannelCreated(); - } - } else { - RtpTransportInternal* rtp_transport = GetRtpTransport(mid); - rtp_data_channel_ = channel_manager()->CreateRtpDataChannel( - configuration_.media_config, rtp_transport, signaling_thread(), mid, - SrtpRequired(), GetCryptoOptions()); - if (!rtp_data_channel_) { - return false; - } - rtp_data_channel_->SignalDtlsSrtpSetupFailure.connect( - this, &PeerConnection::OnDtlsSrtpSetupFailure); - rtp_data_channel_->SignalSentPacket.connect( - this, &PeerConnection::OnSentPacket_w); - rtp_data_channel_->SetRtpTransport(rtp_transport); + case cricket::DCT_SCTP: + if (!sctp_factory_) { + RTC_LOG(LS_ERROR) + << "Trying to create SCTP transport, but didn't compile with " + "SCTP support (HAVE_SCTP)"; + return false; + } + if (!network_thread()->Invoke( + RTC_FROM_HERE, + rtc::Bind(&PeerConnection::CreateSctpTransport_n, this, mid))) { + return false; + } + for (const auto& channel : sctp_data_channels_) { + channel->OnTransportChannelCreated(); + } + return true; + case cricket::DCT_RTP: + default: + RtpTransportInternal* rtp_transport = GetRtpTransport(mid); + rtp_data_channel_ = channel_manager()->CreateRtpDataChannel( + configuration_.media_config, rtp_transport, signaling_thread(), mid, + SrtpRequired(), GetCryptoOptions()); + if (!rtp_data_channel_) { + return false; + } + rtp_data_channel_->SignalDtlsSrtpSetupFailure.connect( + this, &PeerConnection::OnDtlsSrtpSetupFailure); + rtp_data_channel_->SignalSentPacket.connect( + this, &PeerConnection::OnSentPacket_w); + rtp_data_channel_->SetRtpTransport(rtp_transport); + return true; } return true; @@ -5784,22 +5943,8 @@ void PeerConnection::OnSctpTransportDataReceived_n( void PeerConnection::OnSctpTransportDataReceived_s( const cricket::ReceiveDataParams& params, const rtc::CopyOnWriteBuffer& payload) { - RTC_DCHECK(signaling_thread()->IsCurrent()); - if (params.type == cricket::DMT_CONTROL && IsOpenMessage(payload)) { - // Received OPEN message; parse and signal that a new data channel should - // be created. - std::string label; - InternalDataChannelInit config; - config.id = params.ssrc; - if (!ParseDataChannelOpenMessage(payload, &label, &config)) { - RTC_LOG(LS_WARNING) << "Failed to parse the OPEN message for sid " - << params.ssrc; - return; - } - config.open_handshake_role = InternalDataChannelInit::kAcker; - OnDataChannelOpenMessage(label, config); - } else { - // Otherwise just forward the signal. + RTC_DCHECK_RUN_ON(signaling_thread()); + if (!HandleOpenMessage_s(params, payload)) { SignalSctpDataReceived(params, payload); } } @@ -5822,6 +5967,49 @@ void PeerConnection::OnSctpClosingProcedureComplete_n(int sid) { &SignalSctpClosingProcedureComplete, sid)); } +bool PeerConnection::SetupMediaTransportForDataChannels_n( + const std::string& mid) { + media_transport_ = transport_controller_->GetMediaTransport(mid); + if (!media_transport_) { + RTC_LOG(LS_ERROR) << "Media transport is not available for data channels"; + return false; + } + + media_transport_invoker_ = absl::make_unique(); + media_transport_->SetDataSink(this); + media_transport_data_mid_ = mid; + transport_controller_->SignalMediaTransportStateChanged.connect( + this, &PeerConnection::OnMediaTransportStateChanged_n); + // Check the initial state right away, in case transport is already writable. + OnMediaTransportStateChanged_n(); + return true; +} + +void PeerConnection::TeardownMediaTransportForDataChannels_n() { + if (!media_transport_) { + return; + } + transport_controller_->SignalMediaTransportStateChanged.disconnect(this); + media_transport_data_mid_.reset(); + media_transport_->SetDataSink(nullptr); + media_transport_invoker_ = nullptr; + media_transport_ = nullptr; +} + +void PeerConnection::OnMediaTransportStateChanged_n() { + if (!media_transport_data_mid_ || + transport_controller_->GetMediaTransportState( + *media_transport_data_mid_) != MediaTransportState::kWritable) { + return; + } + media_transport_invoker_->AsyncInvoke( + RTC_FROM_HERE, signaling_thread(), [this] { + RTC_DCHECK_RUN_ON(signaling_thread()); + media_transport_ready_to_send_data_ = true; + SignalMediaTransportWritable_s(media_transport_ready_to_send_data_); + }); +} + // Returns false if bundle is enabled and rtcp_mux is disabled. bool PeerConnection::ValidateBundleSettings(const SessionDescription* desc) { bool bundle_enabled = desc->HasGroup(cricket::GROUP_TYPE_BUNDLE); @@ -6336,6 +6524,14 @@ void PeerConnection::DestroyDataChannel() { network_thread()->Invoke(RTC_FROM_HERE, [this] { DestroySctpTransport_n(); }); } + + if (media_transport_) { + OnDataChannelDestroyed(); + network_thread()->Invoke(RTC_FROM_HERE, [this] { + RTC_DCHECK_RUN_ON(network_thread()); + TeardownMediaTransportForDataChannels_n(); + }); + } } void PeerConnection::DestroyBaseChannel(cricket::BaseChannel* channel) { diff --git a/pc/peerconnection.h b/pc/peerconnection.h index 4099992a0e..fe4d777401 100644 --- a/pc/peerconnection.h +++ b/pc/peerconnection.h @@ -52,6 +52,7 @@ class RtcEventLog; // - Generating stats. class PeerConnection : public PeerConnectionInternal, public DataChannelProviderInterface, + public DataChannelSink, public JsepTransportController::Observer, public rtc::MessageHandler, public sigslot::has_slots<> { @@ -632,6 +633,11 @@ class PeerConnection : public PeerConnectionInternal, // Called when a valid data channel OPEN message is received. void OnDataChannelOpenMessage(const std::string& label, const InternalDataChannelInit& config); + // Parses and handles open messages. Returns true if the message is an open + // message, false otherwise. + bool HandleOpenMessage_s(const cricket::ReceiveDataParams& params, + const rtc::CopyOnWriteBuffer& buffer) + RTC_RUN_ON(signaling_thread()); // Returns true if the PeerConnection is configured to use Unified Plan // semantics for creating offers/answers and setting local/remote @@ -733,6 +739,13 @@ class PeerConnection : public PeerConnectionInternal, cricket::DataChannelType data_channel_type() const; + // Implements DataChannelSink. + void OnDataReceived(int channel_id, + DataMessageType type, + const rtc::CopyOnWriteBuffer& buffer) override; + void OnChannelClosing(int channel_id) override; + void OnChannelClosed(int channel_id) override; + // Called when an RTCCertificate is generated or retrieved by // WebRTCSessionDescriptionFactory. Should happen before setLocalDescription. void OnCertificateReady( @@ -830,6 +843,11 @@ class PeerConnection : public PeerConnectionInternal, void OnSctpClosingProcedureStartedRemotely_n(int sid); void OnSctpClosingProcedureComplete_n(int sid); + bool SetupMediaTransportForDataChannels_n(const std::string& mid) + RTC_RUN_ON(network_thread()); + void OnMediaTransportStateChanged_n() RTC_RUN_ON(network_thread()); + void TeardownMediaTransportForDataChannels_n() RTC_RUN_ON(network_thread()); + bool ValidateBundleSettings(const cricket::SessionDescription* desc); bool HasRtcpMuxEnabled(const cricket::ContentInfo* content); // Below methods are helper methods which verifies SDP. @@ -1050,6 +1068,33 @@ class PeerConnection : public PeerConnectionInternal, sigslot::signal1 SignalSctpClosingProcedureStartedRemotely; sigslot::signal1 SignalSctpClosingProcedureComplete; + // Whether this peer is the caller. Set when the local description is applied. + absl::optional is_caller_ RTC_GUARDED_BY(signaling_thread()); + + // Content name (MID) for media transport data channels in SDP. + absl::optional media_transport_data_mid_; + + // Media transport used for data channels. Thread-safe. + MediaTransportInterface* media_transport_ = nullptr; + + // Cached value of whether the media transport is ready to send. + bool media_transport_ready_to_send_data_ RTC_GUARDED_BY(signaling_thread()) = + false; + + // Used to invoke media transport signals on the signaling thread. + std::unique_ptr media_transport_invoker_; + + // Identical to the signals for SCTP, but from media transport: + sigslot::signal1 SignalMediaTransportWritable_s + RTC_GUARDED_BY(signaling_thread()); + sigslot::signal2 + SignalMediaTransportReceivedData_s RTC_GUARDED_BY(signaling_thread()); + sigslot::signal1 SignalMediaTransportChannelClosing_s + RTC_GUARDED_BY(signaling_thread()); + sigslot::signal1 SignalMediaTransportChannelClosed_s + RTC_GUARDED_BY(signaling_thread()); + std::unique_ptr current_local_description_; std::unique_ptr pending_local_description_; std::unique_ptr current_remote_description_; diff --git a/pc/peerconnection_datachannel_unittest.cc b/pc/peerconnection_datachannel_unittest.cc index d7f1fcce5a..cfb5dde22a 100644 --- a/pc/peerconnection_datachannel_unittest.cc +++ b/pc/peerconnection_datachannel_unittest.cc @@ -11,6 +11,7 @@ #include #include "api/peerconnectionproxy.h" +#include "api/test/fake_media_transport.h" #include "media/base/fakemediaengine.h" #include "pc/mediasession.h" #include "pc/peerconnection.h" @@ -31,17 +32,39 @@ using RTCConfiguration = PeerConnectionInterface::RTCConfiguration; using RTCOfferAnswerOptions = PeerConnectionInterface::RTCOfferAnswerOptions; using ::testing::Values; +namespace { + +PeerConnectionFactoryDependencies CreatePeerConnectionFactoryDependencies( + rtc::Thread* network_thread, + rtc::Thread* worker_thread, + rtc::Thread* signaling_thread, + std::unique_ptr media_engine, + std::unique_ptr call_factory, + std::unique_ptr media_transport_factory) { + PeerConnectionFactoryDependencies deps; + deps.network_thread = network_thread; + deps.worker_thread = worker_thread; + deps.signaling_thread = signaling_thread; + deps.media_engine = std::move(media_engine); + deps.call_factory = std::move(call_factory); + deps.media_transport_factory = std::move(media_transport_factory); + return deps; +} + +} // namespace + class PeerConnectionFactoryForDataChannelTest : public rtc::RefCountedObject { public: PeerConnectionFactoryForDataChannelTest() : rtc::RefCountedObject( - rtc::Thread::Current(), - rtc::Thread::Current(), - rtc::Thread::Current(), - absl::make_unique(), - CreateCallFactory(), - nullptr) {} + CreatePeerConnectionFactoryDependencies( + rtc::Thread::Current(), + rtc::Thread::Current(), + rtc::Thread::Current(), + absl::make_unique(), + CreateCallFactory(), + absl::make_unique())) {} std::unique_ptr CreateSctpTransportInternalFactory() { @@ -324,6 +347,52 @@ TEST_P(PeerConnectionDataChannelTest, SctpPortPropagatedFromSdpToTransport) { EXPECT_EQ(kNewRecvPort, callee_transport->local_port()); } +TEST_P(PeerConnectionDataChannelTest, + NoSctpTransportCreatedIfMediaTransportDataChannelsEnabled) { + RTCConfiguration config; + config.use_media_transport_for_data_channels = true; + config.enable_dtls_srtp = false; // SDES is required to use media transport. + auto caller = CreatePeerConnectionWithDataChannel(config); + + ASSERT_TRUE(caller->SetLocalDescription(caller->CreateOffer())); + EXPECT_FALSE(caller->sctp_transport_factory()->last_fake_sctp_transport()); +} + +TEST_P(PeerConnectionDataChannelTest, + MediaTransportDataChannelCreatedEvenIfSctpAvailable) { + RTCConfiguration config; + config.use_media_transport_for_data_channels = true; + config.enable_dtls_srtp = false; // SDES is required to use media transport. + PeerConnectionFactoryInterface::Options options; + options.disable_sctp_data_channels = false; + auto caller = CreatePeerConnectionWithDataChannel(config, options); + + ASSERT_TRUE(caller->SetLocalDescription(caller->CreateOffer())); + EXPECT_FALSE(caller->sctp_transport_factory()->last_fake_sctp_transport()); +} + +TEST_P(PeerConnectionDataChannelTest, + CannotEnableBothMediaTransportAndRtpDataChannels) { + RTCConfiguration config; + config.enable_rtp_data_channel = true; + config.use_media_transport_for_data_channels = true; + config.enable_dtls_srtp = false; // SDES is required to use media transport. + EXPECT_EQ(CreatePeerConnection(config), nullptr); +} + +TEST_P(PeerConnectionDataChannelTest, + MediaTransportDataChannelFailsWithoutSdes) { + RTCConfiguration config; + config.use_media_transport_for_data_channels = true; + config.enable_dtls_srtp = true; // Disables SDES for data sections. + auto caller = CreatePeerConnectionWithDataChannel(config); + + std::string error; + ASSERT_FALSE(caller->SetLocalDescription(caller->CreateOffer(), &error)); + EXPECT_EQ(error, + "Failed to set local offer sdp: Failed to create data channel."); +} + INSTANTIATE_TEST_CASE_P(PeerConnectionDataChannelTest, PeerConnectionDataChannelTest, Values(SdpSemantics::kPlanB, diff --git a/pc/peerconnection_integrationtest.cc b/pc/peerconnection_integrationtest.cc index 236a8bb97d..ccd7d05cb5 100644 --- a/pc/peerconnection_integrationtest.cc +++ b/pc/peerconnection_integrationtest.cc @@ -28,6 +28,7 @@ #include "api/peerconnectioninterface.h" #include "api/peerconnectionproxy.h" #include "api/rtpreceiverinterface.h" +#include "api/test/loopback_media_transport.h" #include "api/umametrics.h" #include "api/video_codecs/builtin_video_decoder_factory.h" #include "api/video_codecs/builtin_video_encoder_factory.h" @@ -251,7 +252,8 @@ class PeerConnectionWrapper : public webrtc::PeerConnectionObserver, webrtc::PeerConnectionDependencies dependencies(nullptr); dependencies.cert_generator = std::move(cert_generator); if (!client->Init(nullptr, nullptr, std::move(dependencies), network_thread, - worker_thread, nullptr)) { + worker_thread, nullptr, + /*media_transport_factory=*/nullptr)) { delete client; return nullptr; } @@ -588,12 +590,14 @@ class PeerConnectionWrapper : public webrtc::PeerConnectionObserver, explicit PeerConnectionWrapper(const std::string& debug_name) : debug_name_(debug_name) {} - bool Init(const PeerConnectionFactory::Options* options, - const PeerConnectionInterface::RTCConfiguration* config, - webrtc::PeerConnectionDependencies dependencies, - rtc::Thread* network_thread, - rtc::Thread* worker_thread, - std::unique_ptr event_log_factory) { + bool Init( + const PeerConnectionFactory::Options* options, + const PeerConnectionInterface::RTCConfiguration* config, + webrtc::PeerConnectionDependencies dependencies, + rtc::Thread* network_thread, + rtc::Thread* worker_thread, + std::unique_ptr event_log_factory, + std::unique_ptr media_transport_factory) { // There's an error in this test code if Init ends up being called twice. RTC_DCHECK(!peer_connection_); RTC_DCHECK(!peer_connection_factory_); @@ -631,6 +635,10 @@ class PeerConnectionWrapper : public webrtc::PeerConnectionObserver, pc_factory_dependencies.event_log_factory = webrtc::CreateRtcEventLogFactory(); } + if (media_transport_factory) { + pc_factory_dependencies.media_transport_factory = + std::move(media_transport_factory); + } peer_connection_factory_ = webrtc::CreateModularPeerConnectionFactory( std::move(pc_factory_dependencies)); @@ -1156,7 +1164,8 @@ class PeerConnectionIntegrationBaseTest : public testing::Test { ss_(new rtc::VirtualSocketServer()), fss_(new rtc::FirewallSocketServer(ss_.get())), network_thread_(new rtc::Thread(fss_.get())), - worker_thread_(rtc::Thread::Create()) { + worker_thread_(rtc::Thread::Create()), + loopback_media_transports_(network_thread_.get()) { network_thread_->SetName("PCNetworkThread", this); worker_thread_->SetName("PCWorkerThread", this); RTC_CHECK(network_thread_->Start()); @@ -1212,7 +1221,8 @@ class PeerConnectionIntegrationBaseTest : public testing::Test { const PeerConnectionFactory::Options* options, const RTCConfiguration* config, webrtc::PeerConnectionDependencies dependencies, - std::unique_ptr event_log_factory) { + std::unique_ptr event_log_factory, + std::unique_ptr media_transport_factory) { RTCConfiguration modified_config; if (config) { modified_config = *config; @@ -1227,7 +1237,8 @@ class PeerConnectionIntegrationBaseTest : public testing::Test { if (!client->Init(options, &modified_config, std::move(dependencies), network_thread_.get(), worker_thread_.get(), - std::move(event_log_factory))) { + std::move(event_log_factory), + std::move(media_transport_factory))) { return nullptr; } return client; @@ -1243,7 +1254,8 @@ class PeerConnectionIntegrationBaseTest : public testing::Test { new webrtc::FakeRtcEventLogFactory(rtc::Thread::Current())); return CreatePeerConnectionWrapper(debug_name, options, config, std::move(dependencies), - std::move(event_log_factory)); + std::move(event_log_factory), + /*media_transport_factory=*/nullptr); } bool CreatePeerConnectionWrappers() { @@ -1264,11 +1276,11 @@ class PeerConnectionIntegrationBaseTest : public testing::Test { sdp_semantics_ = caller_semantics; caller_ = CreatePeerConnectionWrapper( "Caller", nullptr, nullptr, webrtc::PeerConnectionDependencies(nullptr), - nullptr); + nullptr, /*media_transport_factory=*/nullptr); sdp_semantics_ = callee_semantics; callee_ = CreatePeerConnectionWrapper( "Callee", nullptr, nullptr, webrtc::PeerConnectionDependencies(nullptr), - nullptr); + nullptr, /*media_transport_factory=*/nullptr); sdp_semantics_ = original_semantics; return caller_ && callee_; } @@ -1278,10 +1290,28 @@ class PeerConnectionIntegrationBaseTest : public testing::Test { const PeerConnectionInterface::RTCConfiguration& callee_config) { caller_ = CreatePeerConnectionWrapper( "Caller", nullptr, &caller_config, - webrtc::PeerConnectionDependencies(nullptr), nullptr); + webrtc::PeerConnectionDependencies(nullptr), nullptr, + /*media_transport_factory=*/nullptr); callee_ = CreatePeerConnectionWrapper( "Callee", nullptr, &callee_config, - webrtc::PeerConnectionDependencies(nullptr), nullptr); + webrtc::PeerConnectionDependencies(nullptr), nullptr, + /*media_transport_factory=*/nullptr); + return caller_ && callee_; + } + + bool CreatePeerConnectionWrappersWithConfigAndMediaTransportFactory( + const PeerConnectionInterface::RTCConfiguration& caller_config, + const PeerConnectionInterface::RTCConfiguration& callee_config, + std::unique_ptr caller_factory, + std::unique_ptr callee_factory) { + caller_ = + CreatePeerConnectionWrapper("Caller", nullptr, &caller_config, + webrtc::PeerConnectionDependencies(nullptr), + nullptr, std::move(caller_factory)); + callee_ = + CreatePeerConnectionWrapper("Callee", nullptr, &callee_config, + webrtc::PeerConnectionDependencies(nullptr), + nullptr, std::move(callee_factory)); return caller_ && callee_; } @@ -1292,10 +1322,12 @@ class PeerConnectionIntegrationBaseTest : public testing::Test { webrtc::PeerConnectionDependencies callee_dependencies) { caller_ = CreatePeerConnectionWrapper("Caller", nullptr, &caller_config, - std::move(caller_dependencies), nullptr); + std::move(caller_dependencies), nullptr, + /*media_transport_factory=*/nullptr); callee_ = CreatePeerConnectionWrapper("Callee", nullptr, &callee_config, - std::move(callee_dependencies), nullptr); + std::move(callee_dependencies), nullptr, + /*media_transport_factory=*/nullptr); return caller_ && callee_; } @@ -1304,10 +1336,12 @@ class PeerConnectionIntegrationBaseTest : public testing::Test { const PeerConnectionFactory::Options& callee_options) { caller_ = CreatePeerConnectionWrapper( "Caller", &caller_options, nullptr, - webrtc::PeerConnectionDependencies(nullptr), nullptr); + webrtc::PeerConnectionDependencies(nullptr), nullptr, + /*media_transport_factory=*/nullptr); callee_ = CreatePeerConnectionWrapper( "Callee", &callee_options, nullptr, - webrtc::PeerConnectionDependencies(nullptr), nullptr); + webrtc::PeerConnectionDependencies(nullptr), nullptr, + /*media_transport_factory=*/nullptr); return caller_ && callee_; } @@ -1331,7 +1365,8 @@ class PeerConnectionIntegrationBaseTest : public testing::Test { webrtc::PeerConnectionDependencies dependencies(nullptr); dependencies.cert_generator = std::move(cert_generator); return CreatePeerConnectionWrapper("New Peer", nullptr, nullptr, - std::move(dependencies), nullptr); + std::move(dependencies), nullptr, + /*media_transport_factory=*/nullptr); } cricket::TestTurnServer* CreateTurnServer( @@ -1419,6 +1454,10 @@ class PeerConnectionIntegrationBaseTest : public testing::Test { rtc::VirtualSocketServer* virtual_socket_server() { return ss_.get(); } + webrtc::MediaTransportPair* loopback_media_transports() { + return &loopback_media_transports_; + } + PeerConnectionWrapper* caller() { return caller_.get(); } // Set the |caller_| to the |wrapper| passed in and return the @@ -1597,6 +1636,7 @@ class PeerConnectionIntegrationBaseTest : public testing::Test { // on the network thread. std::vector> turn_servers_; std::vector> turn_customizers_; + webrtc::MediaTransportPair loopback_media_transports_; std::unique_ptr caller_; std::unique_ptr callee_; }; @@ -3347,6 +3387,111 @@ TEST_P(PeerConnectionIntegrationTest, #endif // HAVE_SCTP +// This test sets up a call between two parties with audio, video, and a media +// transport data channel. +TEST_P(PeerConnectionIntegrationTest, MediaTransportDataChannelEndToEnd) { + PeerConnectionInterface::RTCConfiguration rtc_config; + rtc_config.use_media_transport_for_data_channels = true; + rtc_config.enable_dtls_srtp = false; // SDES is required for media transport. + ASSERT_TRUE(CreatePeerConnectionWrappersWithConfigAndMediaTransportFactory( + rtc_config, rtc_config, loopback_media_transports()->first_factory(), + loopback_media_transports()->second_factory())); + ConnectFakeSignaling(); + + // Expect that data channel created on caller side will show up for callee as + // well. + caller()->CreateDataChannel(); + caller()->CreateAndSetAndSignalOffer(); + ASSERT_TRUE_WAIT(SignalingStateStable(), kDefaultTimeout); + + // Ensure that the media transport is ready. + loopback_media_transports()->SetState(webrtc::MediaTransportState::kWritable); + loopback_media_transports()->FlushAsyncInvokes(); + + // Caller data channel should already exist (it created one). Callee data + // channel may not exist yet, since negotiation happens in-band, not in SDP. + ASSERT_NE(nullptr, caller()->data_channel()); + ASSERT_TRUE_WAIT(callee()->data_channel() != nullptr, kDefaultTimeout); + EXPECT_TRUE_WAIT(caller()->data_observer()->IsOpen(), kDefaultTimeout); + EXPECT_TRUE_WAIT(callee()->data_observer()->IsOpen(), kDefaultTimeout); + + // Ensure data can be sent in both directions. + std::string data = "hello world"; + caller()->data_channel()->Send(DataBuffer(data)); + EXPECT_EQ_WAIT(data, callee()->data_observer()->last_message(), + kDefaultTimeout); + callee()->data_channel()->Send(DataBuffer(data)); + EXPECT_EQ_WAIT(data, caller()->data_observer()->last_message(), + kDefaultTimeout); +} + +// Ensure that when the callee closes a media transport data channel, the +// closing procedure results in the data channel being closed for the caller +// as well. +TEST_P(PeerConnectionIntegrationTest, MediaTransportDataChannelCalleeCloses) { + PeerConnectionInterface::RTCConfiguration rtc_config; + rtc_config.use_media_transport_for_data_channels = true; + rtc_config.enable_dtls_srtp = false; // SDES is required for media transport. + ASSERT_TRUE(CreatePeerConnectionWrappersWithConfigAndMediaTransportFactory( + rtc_config, rtc_config, loopback_media_transports()->first_factory(), + loopback_media_transports()->second_factory())); + ConnectFakeSignaling(); + + // Create a data channel on the caller and signal it to the callee. + caller()->CreateDataChannel(); + caller()->CreateAndSetAndSignalOffer(); + ASSERT_TRUE_WAIT(SignalingStateStable(), kDefaultTimeout); + + // Ensure that the media transport is ready. + loopback_media_transports()->SetState(webrtc::MediaTransportState::kWritable); + loopback_media_transports()->FlushAsyncInvokes(); + + // Data channels exist and open on both ends of the connection. + ASSERT_NE(nullptr, caller()->data_channel()); + ASSERT_TRUE_WAIT(callee()->data_channel() != nullptr, kDefaultTimeout); + ASSERT_TRUE_WAIT(caller()->data_observer()->IsOpen(), kDefaultTimeout); + ASSERT_TRUE_WAIT(callee()->data_observer()->IsOpen(), kDefaultTimeout); + + // Close the data channel on the callee side, and wait for it to reach the + // "closed" state on both sides. + callee()->data_channel()->Close(); + EXPECT_TRUE_WAIT(!caller()->data_observer()->IsOpen(), kDefaultTimeout); + EXPECT_TRUE_WAIT(!callee()->data_observer()->IsOpen(), kDefaultTimeout); +} + +TEST_P(PeerConnectionIntegrationTest, + MediaTransportDataChannelConfigSentToOtherSide) { + PeerConnectionInterface::RTCConfiguration rtc_config; + rtc_config.use_media_transport_for_data_channels = true; + rtc_config.enable_dtls_srtp = false; // SDES is required for media transport. + ASSERT_TRUE(CreatePeerConnectionWrappersWithConfigAndMediaTransportFactory( + rtc_config, rtc_config, loopback_media_transports()->first_factory(), + loopback_media_transports()->second_factory())); + ConnectFakeSignaling(); + + // Create a data channel with a non-default configuration and signal it to the + // callee. + webrtc::DataChannelInit init; + init.id = 53; + init.maxRetransmits = 52; + caller()->CreateDataChannel("data-channel", &init); + caller()->CreateAndSetAndSignalOffer(); + ASSERT_TRUE_WAIT(SignalingStateStable(), kDefaultTimeout); + + // Ensure that the media transport is ready. + loopback_media_transports()->SetState(webrtc::MediaTransportState::kWritable); + loopback_media_transports()->FlushAsyncInvokes(); + + // Ensure that the data channel exists on the callee with the correct + // configuration. + ASSERT_TRUE_WAIT(callee()->data_channel() != nullptr, kDefaultTimeout); + ASSERT_TRUE_WAIT(callee()->data_observer()->IsOpen(), kDefaultTimeout); + EXPECT_EQ(init.id, callee()->data_channel()->id()); + EXPECT_EQ("data-channel", callee()->data_channel()->label()); + EXPECT_EQ(init.maxRetransmits, callee()->data_channel()->maxRetransmits()); + EXPECT_FALSE(callee()->data_channel()->negotiated()); +} + // Test that the ICE connection and gathering states eventually reach // "complete". TEST_P(PeerConnectionIntegrationTest, IceStatesReachCompletion) {