From 64e58309693a64b58a1a4bc910065458589904c0 Mon Sep 17 00:00:00 2001 From: Alessio Bazzica Date: Thu, 14 Oct 2021 15:47:52 +0200 Subject: [PATCH] AGC2: VAD wrapper, add `Initialize()` method Not passing the sample rate to the `VoiceActivityDetectorWrapper` ctor yet since that would require an unnecessary refactoring of `AdaptiveAgc` which will soon be removed. Instead, to ensure correct initialization until the child CL [1] lands, `VoiceActivityDetectorWrapper::initialized_` is temporarily added. Bit exactness verified with audioproc_f on a collection of AEC dumps and Wav files (42 recordings in total). [1] https://webrtc-review.googlesource.com/c/src/+/234583 Bug: webrtc:7494 Change-Id: I4b4be7b8106ba36c958d91bf263a7b30271a1ee3 Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/234587 Commit-Queue: Alessio Bazzica Reviewed-by: Hanna Silen Cr-Commit-Position: refs/heads/main@{#35213} --- modules/audio_processing/agc2/adaptive_agc.cc | 1 + modules/audio_processing/agc2/vad_wrapper.cc | 22 ++++++++--- modules/audio_processing/agc2/vad_wrapper.h | 10 +++++ .../agc2/vad_wrapper_unittest.cc | 37 ++++++++++++++----- 4 files changed, 56 insertions(+), 14 deletions(-) diff --git a/modules/audio_processing/agc2/adaptive_agc.cc b/modules/audio_processing/agc2/adaptive_agc.cc index fb06549140..b5433655c4 100644 --- a/modules/audio_processing/agc2/adaptive_agc.cc +++ b/modules/audio_processing/agc2/adaptive_agc.cc @@ -77,6 +77,7 @@ AdaptiveAgc::~AdaptiveAgc() = default; void AdaptiveAgc::Initialize(int sample_rate_hz, int num_channels) { gain_controller_.Initialize(sample_rate_hz, num_channels); + vad_.Initialize(sample_rate_hz); } void AdaptiveAgc::Process(AudioFrameView frame, float limiter_envelope) { diff --git a/modules/audio_processing/agc2/vad_wrapper.cc b/modules/audio_processing/agc2/vad_wrapper.cc index 7b61aee99d..17d9638be2 100644 --- a/modules/audio_processing/agc2/vad_wrapper.cc +++ b/modules/audio_processing/agc2/vad_wrapper.cc @@ -64,6 +64,8 @@ VoiceActivityDetectorWrapper::VoiceActivityDetectorWrapper( std::unique_ptr vad) : vad_reset_period_frames_( rtc::CheckedDivExact(vad_reset_period_ms, kFrameDurationMs)), + initialized_(false), + frame_size_(0), time_to_vad_reset_(vad_reset_period_frames_), vad_(std::move(vad)) { RTC_DCHECK(vad_); @@ -74,19 +76,29 @@ VoiceActivityDetectorWrapper::VoiceActivityDetectorWrapper( VoiceActivityDetectorWrapper::~VoiceActivityDetectorWrapper() = default; +void VoiceActivityDetectorWrapper::Initialize(int sample_rate_hz) { + RTC_DCHECK_GT(sample_rate_hz, 0); + frame_size_ = rtc::CheckedDivExact(sample_rate_hz, kNumFramesPerSecond); + int status = + resampler_.InitializeIfNeeded(sample_rate_hz, vad_->SampleRateHz(), + /*num_channels=*/1); + constexpr int kStatusOk = 0; + RTC_DCHECK_EQ(status, kStatusOk); + vad_->Reset(); + initialized_ = true; +} + float VoiceActivityDetectorWrapper::Analyze(AudioFrameView frame) { + RTC_DCHECK(initialized_); // Periodically reset the VAD. time_to_vad_reset_--; if (time_to_vad_reset_ <= 0) { vad_->Reset(); time_to_vad_reset_ = vad_reset_period_frames_; } - // Resample the first channel of `frame`. - resampler_.InitializeIfNeeded( - /*sample_rate_hz=*/frame.samples_per_channel() * kNumFramesPerSecond, - vad_->SampleRateHz(), /*num_channels=*/1); - resampler_.Resample(frame.channel(0).data(), frame.samples_per_channel(), + RTC_DCHECK_EQ(frame.samples_per_channel(), frame_size_); + resampler_.Resample(frame.channel(0).data(), frame_size_, resampled_buffer_.data(), resampled_buffer_.size()); return vad_->Analyze(resampled_buffer_); diff --git a/modules/audio_processing/agc2/vad_wrapper.h b/modules/audio_processing/agc2/vad_wrapper.h index f17fcda6b8..0579ca11d4 100644 --- a/modules/audio_processing/agc2/vad_wrapper.h +++ b/modules/audio_processing/agc2/vad_wrapper.h @@ -43,6 +43,7 @@ class VoiceActivityDetectorWrapper { // Ctor. `vad_reset_period_ms` indicates the period in milliseconds to call // `MonoVad::Reset()`; it must be equal to or greater than the duration of two // frames. Uses `cpu_features` to instantiate the default VAD. + // TODO(bugs.webrtc.org/7494): Pass sample rate. VoiceActivityDetectorWrapper(int vad_reset_period_ms, const AvailableCpuFeatures& cpu_features); // Ctor. Uses a custom `vad`. @@ -54,11 +55,20 @@ class VoiceActivityDetectorWrapper { delete; ~VoiceActivityDetectorWrapper(); + // TODO(bugs.webrtc.org/7494): Call initialize in the ctor. + // Initializes the VAD wrapper. Must be called before `Analyze()`. + void Initialize(int sample_rate_hz); + // Analyzes the first channel of `frame` and returns the speech probability. + // `frame` must be a 10 ms frame with the sample rate specified in the last + // `Initialize()` call. float Analyze(AudioFrameView frame); private: const int vad_reset_period_frames_; + // TODO(bugs.webrtc.org/7494): Remove `initialized_`. + bool initialized_; + int frame_size_; int time_to_vad_reset_; PushResampler resampler_; std::unique_ptr vad_; diff --git a/modules/audio_processing/agc2/vad_wrapper_unittest.cc b/modules/audio_processing/agc2/vad_wrapper_unittest.cc index c1f7029ef1..27e5af6843 100644 --- a/modules/audio_processing/agc2/vad_wrapper_unittest.cc +++ b/modules/audio_processing/agc2/vad_wrapper_unittest.cc @@ -43,13 +43,26 @@ class MockVad : public VoiceActivityDetectorWrapper::MonoVad { MOCK_METHOD(float, Analyze, (rtc::ArrayView frame), (override)); }; +// Checks that the ctor and `Initialize()` read the sample rate of the wrapped +// VAD. +TEST(GainController2VoiceActivityDetectorWrapper, CtorAndInitReadSampleRate) { + auto vad = std::make_unique(); + EXPECT_CALL(*vad, SampleRateHz) + .Times(2) + .WillRepeatedly(Return(kSampleRate8kHz)); + EXPECT_CALL(*vad, Reset).Times(AnyNumber()); + auto vad_wrapper = std::make_unique( + kNoVadPeriodicReset, std::move(vad)); + vad_wrapper->Initialize(kSampleRate8kHz); +} + // Creates a `VoiceActivityDetectorWrapper` injecting a mock VAD that // repeatedly returns the next value from `speech_probabilities` and that // restarts from the beginning when after the last element is returned. std::unique_ptr CreateMockVadWrapper( int vad_reset_period_ms, const std::vector& speech_probabilities, - int expected_vad_reset_calls = 0) { + int expected_vad_reset_calls) { auto vad = std::make_unique(); EXPECT_CALL(*vad, SampleRateHz) .Times(AnyNumber()) @@ -67,7 +80,7 @@ std::unique_ptr CreateMockVadWrapper( // 10 ms mono frame. struct FrameWithView { // Ctor. Initializes the frame samples with `value`. - explicit FrameWithView(int sample_rate_hz = kSampleRate8kHz) + explicit FrameWithView(int sample_rate_hz) : samples(rtc::CheckedDivExact(sample_rate_hz, 100), 0.0f), channel0(samples.data()), view(&channel0, /*num_channels=*/1, samples.size()) {} @@ -82,8 +95,10 @@ TEST(GainController2VoiceActivityDetectorWrapper, CheckSpeechProbabilities) { 0.44f, 0.525f, 0.858f, 0.314f, 0.653f, 0.965f, 0.413f, 0.0f}; auto vad_wrapper = - CreateMockVadWrapper(kNoVadPeriodicReset, speech_probabilities); - FrameWithView frame; + CreateMockVadWrapper(kNoVadPeriodicReset, speech_probabilities, + /*expected_vad_reset_calls=*/1); + vad_wrapper->Initialize(kSampleRate8kHz); + FrameWithView frame(kSampleRate8kHz); for (int i = 0; rtc::SafeLt(i, speech_probabilities.size()); ++i) { SCOPED_TRACE(i); EXPECT_EQ(speech_probabilities[i], vad_wrapper->Analyze(frame.view)); @@ -95,8 +110,9 @@ TEST(GainController2VoiceActivityDetectorWrapper, VadNoPeriodicReset) { constexpr int kNumFrames = 19; auto vad_wrapper = CreateMockVadWrapper(kNoVadPeriodicReset, /*speech_probabilities=*/{1.0f}, - /*expected_vad_reset_calls=*/0); - FrameWithView frame; + /*expected_vad_reset_calls=*/1); + vad_wrapper->Initialize(kSampleRate8kHz); + FrameWithView frame(kSampleRate8kHz); for (int i = 0; i < kNumFrames; ++i) { vad_wrapper->Analyze(frame.view); } @@ -114,8 +130,10 @@ TEST_P(VadPeriodResetParametrization, VadPeriodicReset) { auto vad_wrapper = CreateMockVadWrapper( /*vad_reset_period_ms=*/vad_reset_period_frames() * kFrameDurationMs, /*speech_probabilities=*/{1.0f}, - /*expected_vad_reset_calls=*/num_frames() / vad_reset_period_frames()); - FrameWithView frame; + /*expected_vad_reset_calls=*/1 + + num_frames() / vad_reset_period_frames()); + vad_wrapper->Initialize(kSampleRate8kHz); + FrameWithView frame(kSampleRate8kHz); for (int i = 0; i < num_frames(); ++i) { vad_wrapper->Analyze(frame.view); } @@ -141,7 +159,7 @@ TEST_P(VadResamplingParametrization, CheckResampledFrameSize) { EXPECT_CALL(*vad, SampleRateHz) .Times(AnyNumber()) .WillRepeatedly(Return(vad_sample_rate_hz())); - EXPECT_CALL(*vad, Reset).Times(0); + EXPECT_CALL(*vad, Reset).Times(1); EXPECT_CALL(*vad, Analyze(Truly([this](rtc::ArrayView frame) { return rtc::SafeEq(frame.size(), rtc::CheckedDivExact(vad_sample_rate_hz(), 100)); @@ -149,6 +167,7 @@ TEST_P(VadResamplingParametrization, CheckResampledFrameSize) { auto vad_wrapper = std::make_unique( kNoVadPeriodicReset, std::move(vad)); FrameWithView frame(input_sample_rate_hz()); + vad_wrapper->Initialize(input_sample_rate_hz()); vad_wrapper->Analyze(frame.view); }