diff --git a/audio/BUILD.gn b/audio/BUILD.gn index bbf6bdd0d9..054e090ba6 100644 --- a/audio/BUILD.gn +++ b/audio/BUILD.gn @@ -152,6 +152,7 @@ if (rtc_include_tests) { "../api/audio_codecs:audio_codecs_api", "../api/audio_codecs/opus:audio_decoder_opus", "../api/audio_codecs/opus:audio_encoder_opus", + "../api/crypto:frame_decryptor_interface", "../api/rtc_event_log", "../api/task_queue:default_task_queue_factory", "../api/units:time_delta", diff --git a/audio/audio_receive_stream.cc b/audio/audio_receive_stream.cc index cc53a746ff..7476e08914 100644 --- a/audio/audio_receive_stream.cc +++ b/audio/audio_receive_stream.cc @@ -246,6 +246,14 @@ void AudioReceiveStream::SetUseTransportCcAndNackHistory(bool use_transport_cc, } } +void AudioReceiveStream::SetFrameDecryptor( + rtc::scoped_refptr frame_decryptor) { + // TODO(bugs.webrtc.org/11993): This is called via WebRtcAudioReceiveStream, + // expect to be called on the network thread. + RTC_DCHECK_RUN_ON(&worker_thread_checker_); + channel_receive_->SetFrameDecryptor(std::move(frame_decryptor)); +} + webrtc::AudioReceiveStream::Stats AudioReceiveStream::GetStats( bool get_and_clear_legacy_stats) const { RTC_DCHECK_RUN_ON(&worker_thread_checker_); diff --git a/audio/audio_receive_stream.h b/audio/audio_receive_stream.h index 4f63155377..108794cd92 100644 --- a/audio/audio_receive_stream.h +++ b/audio/audio_receive_stream.h @@ -91,6 +91,8 @@ class AudioReceiveStream final : public webrtc::AudioReceiveStream, void SetDecoderMap(std::map decoder_map) override; void SetUseTransportCcAndNackHistory(bool use_transport_cc, int history_ms) override; + void SetFrameDecryptor(rtc::scoped_refptr + frame_decryptor) override; webrtc::AudioReceiveStream::Stats GetStats( bool get_and_clear_legacy_stats) const override; diff --git a/audio/channel_receive.cc b/audio/channel_receive.cc index 0582171b62..28568b17c4 100644 --- a/audio/channel_receive.cc +++ b/audio/channel_receive.cc @@ -177,6 +177,9 @@ class ChannelReceive : public ChannelReceiveInterface { rtc::scoped_refptr frame_transformer) override; + void SetFrameDecryptor(rtc::scoped_refptr + frame_decryptor) override; + private: void ReceivePacket(const uint8_t* packet, size_t packet_length, @@ -275,10 +278,12 @@ class ChannelReceive : public ChannelReceiveInterface { SequenceChecker construction_thread_; // E2EE Audio Frame Decryption - rtc::scoped_refptr frame_decryptor_; + rtc::scoped_refptr frame_decryptor_ + RTC_GUARDED_BY(worker_thread_checker_); webrtc::CryptoOptions crypto_options_; - webrtc::AbsoluteCaptureTimeInterpolator absolute_capture_time_interpolator_; + webrtc::AbsoluteCaptureTimeInterpolator absolute_capture_time_interpolator_ + RTC_GUARDED_BY(worker_thread_checker_); webrtc::CaptureClockOffsetUpdater capture_clock_offset_updater_; @@ -889,6 +894,13 @@ void ChannelReceive::SetDepacketizerToDecoderFrameTransformer( InitFrameTransformerDelegate(std::move(frame_transformer)); } +void ChannelReceive::SetFrameDecryptor( + rtc::scoped_refptr frame_decryptor) { + // TODO(bugs.webrtc.org/11993): Expect to be called on the network thread. + RTC_DCHECK_RUN_ON(&worker_thread_checker_); + frame_decryptor_ = std::move(frame_decryptor); +} + NetworkStatistics ChannelReceive::GetNetworkStatistics( bool get_and_clear_legacy_stats) const { RTC_DCHECK_RUN_ON(&worker_thread_checker_); diff --git a/audio/channel_receive.h b/audio/channel_receive.h index c55968b55f..0a51be6649 100644 --- a/audio/channel_receive.h +++ b/audio/channel_receive.h @@ -159,6 +159,9 @@ class ChannelReceiveInterface : public RtpPacketSinkInterface { virtual void SetDepacketizerToDecoderFrameTransformer( rtc::scoped_refptr frame_transformer) = 0; + + virtual void SetFrameDecryptor( + rtc::scoped_refptr frame_decryptor) = 0; }; std::unique_ptr CreateChannelReceive( diff --git a/audio/mock_voe_channel_proxy.h b/audio/mock_voe_channel_proxy.h index 7f140d400d..deef5ae068 100644 --- a/audio/mock_voe_channel_proxy.h +++ b/audio/mock_voe_channel_proxy.h @@ -17,6 +17,7 @@ #include #include +#include "api/crypto/frame_decryptor_interface.h" #include "api/test/mock_frame_encryptor.h" #include "audio/channel_receive.h" #include "audio/channel_send.h" @@ -98,6 +99,11 @@ class MockChannelReceive : public voe::ChannelReceiveInterface { SetDepacketizerToDecoderFrameTransformer, (rtc::scoped_refptr frame_transformer), (override)); + MOCK_METHOD( + void, + SetFrameDecryptor, + (rtc::scoped_refptr frame_decryptor), + (override)); }; class MockChannelSend : public voe::ChannelSendInterface { diff --git a/call/audio_receive_stream.h b/call/audio_receive_stream.h index 45c318c404..2f67f7cc14 100644 --- a/call/audio_receive_stream.h +++ b/call/audio_receive_stream.h @@ -177,6 +177,8 @@ class AudioReceiveStream { virtual void SetDecoderMap(std::map decoder_map) = 0; virtual void SetUseTransportCcAndNackHistory(bool use_transport_cc, int history_ms) = 0; + virtual void SetFrameDecryptor( + rtc::scoped_refptr frame_decryptor) = 0; // Starts stream activity. // When a stream is active, it can receive, process and deliver packets. diff --git a/media/engine/fake_webrtc_call.cc b/media/engine/fake_webrtc_call.cc index 5f484285a5..0190d88e0e 100644 --- a/media/engine/fake_webrtc_call.cc +++ b/media/engine/fake_webrtc_call.cc @@ -113,6 +113,11 @@ void FakeAudioReceiveStream::SetUseTransportCcAndNackHistory( config_.rtp.nack.rtp_history_ms = history_ms; } +void FakeAudioReceiveStream::SetFrameDecryptor( + rtc::scoped_refptr frame_decryptor) { + config_.frame_decryptor = std::move(frame_decryptor); +} + webrtc::AudioReceiveStream::Stats FakeAudioReceiveStream::GetStats( bool get_and_clear_legacy_stats) const { return stats_; diff --git a/media/engine/fake_webrtc_call.h b/media/engine/fake_webrtc_call.h index 79f155cd86..8df85d7564 100644 --- a/media/engine/fake_webrtc_call.h +++ b/media/engine/fake_webrtc_call.h @@ -112,6 +112,8 @@ class FakeAudioReceiveStream final : public webrtc::AudioReceiveStream { std::map decoder_map) override; void SetUseTransportCcAndNackHistory(bool use_transport_cc, int history_ms) override; + void SetFrameDecryptor(rtc::scoped_refptr + frame_decryptor) override; webrtc::AudioReceiveStream::Stats GetStats( bool get_and_clear_legacy_stats) const override; diff --git a/media/engine/webrtc_voice_engine.cc b/media/engine/webrtc_voice_engine.cc index 602d23cf68..575e2325f3 100644 --- a/media/engine/webrtc_voice_engine.cc +++ b/media/engine/webrtc_voice_engine.cc @@ -1218,7 +1218,7 @@ class WebRtcVoiceMediaChannel::WebRtcAudioReceiveStream { rtc::scoped_refptr frame_decryptor) { RTC_DCHECK_RUN_ON(&worker_thread_checker_); config_.frame_decryptor = frame_decryptor; - RecreateAudioReceiveStream(); + stream_->SetFrameDecryptor(std::move(frame_decryptor)); } void SetLocalSsrc(uint32_t local_ssrc) {