From 8dbdf5e3bf8e7e0215d4f56b7836ea195312ee18 Mon Sep 17 00:00:00 2001
From: Alessio Bazzica
Date: Thu, 14 Oct 2021 12:15:20 +0200
Subject: [PATCH] AGC2: VadWithLevel -> VoiceActivityDetectorWrapper 2/2

Internal refactoring of AGC2 to decouple the VAD, its wrapper and the
peak and RMS level measurements.

Bit exactness verified with audioproc_f on a collection of AEC dumps
and Wav files (42 recordings in total).

Bug: webrtc:7494
Change-Id: Ib560f1fcaa601557f4f30e47025c69e91b1b62e0
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/234524
Commit-Queue: Alessio Bazzica
Reviewed-by: Hanna Silen
Cr-Commit-Position: refs/heads/main@{#35208}
---
 modules/audio_processing/agc2/BUILD.gn        |   1 +
 modules/audio_processing/agc2/adaptive_agc.cc |  35 ++++--
 modules/audio_processing/agc2/adaptive_agc.h  |   2 +-
 .../agc2/adaptive_mode_level_estimator.cc     |  27 ++--
 .../agc2/adaptive_mode_level_estimator.h      |   2 +-
 .../adaptive_mode_level_estimator_unittest.cc |  88 ++++++++------
 modules/audio_processing/agc2/vad_wrapper.cc  |  85 ++++++-------
 modules/audio_processing/agc2/vad_wrapper.h   |  54 ++++----
 .../agc2/vad_wrapper_unittest.cc              | 115 +++++++++++-------
 9 files changed, 226 insertions(+), 183 deletions(-)

diff --git a/modules/audio_processing/agc2/BUILD.gn b/modules/audio_processing/agc2/BUILD.gn
index a897c0b1a7..f767a6d7f9 100644
--- a/modules/audio_processing/agc2/BUILD.gn
+++ b/modules/audio_processing/agc2/BUILD.gn
@@ -280,6 +280,7 @@ rtc_library("vad_wrapper_unittests") {
     ":common",
     ":vad_wrapper",
     "..:audio_frame_view",
+    "../../../rtc_base:checks",
     "../../../rtc_base:gunit_helpers",
     "../../../rtc_base:safe_compare",
     "../../../test:test_support",
diff --git a/modules/audio_processing/agc2/adaptive_agc.cc b/modules/audio_processing/agc2/adaptive_agc.cc
index ab1822d8ba..fb06549140 100644
--- a/modules/audio_processing/agc2/adaptive_agc.cc
+++ b/modules/audio_processing/agc2/adaptive_agc.cc
@@ -36,6 +36,24 @@ AvailableCpuFeatures GetAllowedCpuFeatures() {
   return features;
 }
 
+// Peak and RMS audio levels in dBFS.
+struct AudioLevels {
+  float peak_dbfs;
+  float rms_dbfs;
+};
+
+// Computes the audio levels for the first channel in `frame`.
+AudioLevels ComputeAudioLevels(AudioFrameView<float> frame) {
+  float peak = 0.0f;
+  float rms = 0.0f;
+  for (const auto& x : frame.channel(0)) {
+    peak = std::max(std::fabs(x), peak);
+    rms += x * x;
+  }
+  return {FloatS16ToDbfs(peak),
+          FloatS16ToDbfs(std::sqrt(rms / frame.samples_per_channel()))};
+}
+
 }  // namespace
 
 AdaptiveAgc::AdaptiveAgc(
@@ -62,16 +80,17 @@ void AdaptiveAgc::Initialize(int sample_rate_hz, int num_channels) {
 }
 
 void AdaptiveAgc::Process(AudioFrameView<float> frame, float limiter_envelope) {
+  AudioLevels levels = ComputeAudioLevels(frame);
+
   AdaptiveDigitalGainApplier::FrameInfo info;
 
-  VadLevelAnalyzer::Result vad_result = vad_.AnalyzeFrame(frame);
-  info.speech_probability = vad_result.speech_probability;
-  apm_data_dumper_->DumpRaw("agc2_speech_probability",
-                            vad_result.speech_probability);
-  apm_data_dumper_->DumpRaw("agc2_input_rms_dbfs", vad_result.rms_dbfs);
-  apm_data_dumper_->DumpRaw("agc2_input_peak_dbfs", vad_result.peak_dbfs);
+  info.speech_probability = vad_.Analyze(frame);
+  apm_data_dumper_->DumpRaw("agc2_speech_probability", info.speech_probability);
+  apm_data_dumper_->DumpRaw("agc2_input_rms_dbfs", levels.rms_dbfs);
+  apm_data_dumper_->DumpRaw("agc2_input_peak_dbfs", levels.peak_dbfs);
 
-  speech_level_estimator_.Update(vad_result);
+  speech_level_estimator_.Update(levels.rms_dbfs, levels.peak_dbfs,
+                                 info.speech_probability);
 
   info.speech_level_dbfs = speech_level_estimator_.level_dbfs();
   info.speech_level_reliable = speech_level_estimator_.IsConfident();
   apm_data_dumper_->DumpRaw("agc2_speech_level_dbfs", info.speech_level_dbfs);
@@ -81,7 +100,7 @@ void AdaptiveAgc::Process(AudioFrameView<float> frame, float limiter_envelope) {
   info.noise_rms_dbfs = noise_level_estimator_->Analyze(frame);
   apm_data_dumper_->DumpRaw("agc2_noise_rms_dbfs", info.noise_rms_dbfs);
 
-  saturation_protector_->Analyze(info.speech_probability, vad_result.peak_dbfs,
+  saturation_protector_->Analyze(info.speech_probability, levels.peak_dbfs,
                                  info.speech_level_dbfs);
   info.headroom_db = saturation_protector_->HeadroomDb();
   apm_data_dumper_->DumpRaw("agc2_headroom_db", info.headroom_db);
diff --git a/modules/audio_processing/agc2/adaptive_agc.h b/modules/audio_processing/agc2/adaptive_agc.h
index 8ee8378df5..32de680b0b 100644
--- a/modules/audio_processing/agc2/adaptive_agc.h
+++ b/modules/audio_processing/agc2/adaptive_agc.h
@@ -47,7 +47,7 @@ class AdaptiveAgc {
 
  private:
   AdaptiveModeLevelEstimator speech_level_estimator_;
-  VadLevelAnalyzer vad_;
+  VoiceActivityDetectorWrapper vad_;
   AdaptiveDigitalGainApplier gain_controller_;
   ApmDataDumper* const apm_data_dumper_;
   std::unique_ptr<NoiseLevelEstimator> noise_level_estimator_;
diff --git a/modules/audio_processing/agc2/adaptive_mode_level_estimator.cc b/modules/audio_processing/agc2/adaptive_mode_level_estimator.cc
index 81e7d291f6..fe021fec05 100644
--- a/modules/audio_processing/agc2/adaptive_mode_level_estimator.cc
+++ b/modules/audio_processing/agc2/adaptive_mode_level_estimator.cc
@@ -57,15 +57,16 @@ AdaptiveModeLevelEstimator::AdaptiveModeLevelEstimator(
   Reset();
 }
 
-void AdaptiveModeLevelEstimator::Update(
-    const VadLevelAnalyzer::Result& vad_level) {
-  RTC_DCHECK_GT(vad_level.rms_dbfs, -150.f);
-  RTC_DCHECK_LT(vad_level.rms_dbfs, 50.f);
-  RTC_DCHECK_GT(vad_level.peak_dbfs, -150.f);
-  RTC_DCHECK_LT(vad_level.peak_dbfs, 50.f);
-  RTC_DCHECK_GE(vad_level.speech_probability, 0.f);
-  RTC_DCHECK_LE(vad_level.speech_probability, 1.f);
-  if (vad_level.speech_probability < kVadConfidenceThreshold) {
+void AdaptiveModeLevelEstimator::Update(float rms_dbfs,
+                                        float peak_dbfs,
+                                        float speech_probability) {
+  RTC_DCHECK_GT(rms_dbfs, -150.0f);
+  RTC_DCHECK_LT(rms_dbfs, 50.0f);
+  RTC_DCHECK_GT(peak_dbfs, -150.0f);
+  RTC_DCHECK_LT(peak_dbfs, 50.0f);
+  RTC_DCHECK_GE(speech_probability, 0.0f);
+  RTC_DCHECK_LE(speech_probability, 1.0f);
+  if (speech_probability < kVadConfidenceThreshold) {
     // Not a speech frame.
     if (adjacent_speech_frames_threshold_ > 1) {
       // When two or more adjacent speech frames are required in order to update
@@ -93,14 +94,14 @@ void AdaptiveModeLevelEstimator::Update(
     preliminary_state_.time_to_confidence_ms -= kFrameDurationMs;
   }
   // Weighted average of levels with speech probability as weight.
-  RTC_DCHECK_GT(vad_level.speech_probability, 0.f);
-  const float leak_factor = buffer_is_full ? kLevelEstimatorLeakFactor : 1.f;
+  RTC_DCHECK_GT(speech_probability, 0.0f);
+  const float leak_factor = buffer_is_full ? kLevelEstimatorLeakFactor : 1.0f;
   preliminary_state_.level_dbfs.numerator =
       preliminary_state_.level_dbfs.numerator * leak_factor +
-      vad_level.rms_dbfs * vad_level.speech_probability;
+      rms_dbfs * speech_probability;
   preliminary_state_.level_dbfs.denominator =
       preliminary_state_.level_dbfs.denominator * leak_factor +
-      vad_level.speech_probability;
+      speech_probability;
 
   const float level_dbfs = preliminary_state_.level_dbfs.GetRatio();
 
diff --git a/modules/audio_processing/agc2/adaptive_mode_level_estimator.h b/modules/audio_processing/agc2/adaptive_mode_level_estimator.h
index 14da6b7f49..989c8c3572 100644
--- a/modules/audio_processing/agc2/adaptive_mode_level_estimator.h
+++ b/modules/audio_processing/agc2/adaptive_mode_level_estimator.h
@@ -33,7 +33,7 @@ class AdaptiveModeLevelEstimator {
       delete;
 
   // Updates the level estimation.
-  void Update(const VadLevelAnalyzer::Result& vad_data);
+  void Update(float rms_dbfs, float peak_dbfs, float speech_probability);
   // Returns the estimated speech plus noise level.
   float level_dbfs() const { return level_dbfs_; }
   // Returns true if the estimator is confident on its current estimate.
diff --git a/modules/audio_processing/agc2/adaptive_mode_level_estimator_unittest.cc b/modules/audio_processing/agc2/adaptive_mode_level_estimator_unittest.cc
index 1cdd91d5d8..684fca188a 100644
--- a/modules/audio_processing/agc2/adaptive_mode_level_estimator_unittest.cc
+++ b/modules/audio_processing/agc2/adaptive_mode_level_estimator_unittest.cc
@@ -33,10 +33,12 @@ constexpr float kConvergenceSpeedTestsLevelTolerance = 0.5f;
 
 // Provides the `vad_level` value `num_iterations` times to `level_estimator`.
 void RunOnConstantLevel(int num_iterations,
-                        const VadLevelAnalyzer::Result& vad_level,
+                        float rms_dbfs,
+                        float peak_dbfs,
+                        float speech_probability,
                         AdaptiveModeLevelEstimator& level_estimator) {
   for (int i = 0; i < num_iterations; ++i) {
-    level_estimator.Update(vad_level);
+    level_estimator.Update(rms_dbfs, peak_dbfs, speech_probability);
   }
 }
 
@@ -47,6 +49,10 @@ constexpr AdaptiveDigitalConfig GetAdaptiveDigitalConfig(
   return config;
 }
 
+constexpr float kNoSpeechProbability = 0.0f;
+constexpr float kLowSpeechProbability = kVadConfidenceThreshold / 2.0f;
+constexpr float kMaxSpeechProbability = 1.0f;
+
 // Level estimator with data dumper.
 struct TestLevelEstimator {
   explicit TestLevelEstimator(int adjacent_speech_frames_threshold)
@@ -55,36 +61,31 @@ struct TestLevelEstimator {
       : data_dumper(0),
         estimator(std::make_unique<AdaptiveModeLevelEstimator>(
            &data_dumper,
            GetAdaptiveDigitalConfig(adjacent_speech_frames_threshold))),
        initial_speech_level_dbfs(estimator->level_dbfs()),
-        vad_level_rms(initial_speech_level_dbfs / 2.0f),
-        vad_level_peak(initial_speech_level_dbfs / 3.0f),
-        vad_data_speech(
-            {/*speech_probability=*/1.0f, vad_level_rms, vad_level_peak}),
-        vad_data_non_speech(
-            {/*speech_probability=*/kVadConfidenceThreshold / 2.0f,
-             vad_level_rms, vad_level_peak}) {
-    RTC_DCHECK_LT(vad_level_rms, vad_level_peak);
-    RTC_DCHECK_LT(initial_speech_level_dbfs, vad_level_rms);
-    RTC_DCHECK_GT(vad_level_rms - initial_speech_level_dbfs, 5.0f)
-        << "Adjust `vad_level_rms` so that the difference from the initial "
+        level_rms_dbfs(initial_speech_level_dbfs / 2.0f),
+        level_peak_dbfs(initial_speech_level_dbfs / 3.0f) {
+    RTC_DCHECK_LT(level_rms_dbfs, level_peak_dbfs);
+    RTC_DCHECK_LT(initial_speech_level_dbfs, level_rms_dbfs);
+    RTC_DCHECK_GT(level_rms_dbfs - initial_speech_level_dbfs, 5.0f)
+        << "Adjust `level_rms_dbfs` so that the difference from the initial "
           "level is wide enough for the tests";
   }
 
   ApmDataDumper data_dumper;
   std::unique_ptr<AdaptiveModeLevelEstimator> estimator;
   const float initial_speech_level_dbfs;
-  const float vad_level_rms;
-  const float vad_level_peak;
-  const VadLevelAnalyzer::Result vad_data_speech;
-  const VadLevelAnalyzer::Result vad_data_non_speech;
+  const float level_rms_dbfs;
+  const float level_peak_dbfs;
 };
 
 // Checks that the level estimator converges to a constant input speech level.
 TEST(GainController2AdaptiveModeLevelEstimator, LevelStabilizes) {
   TestLevelEstimator level_estimator(/*adjacent_speech_frames_threshold=*/1);
   RunOnConstantLevel(/*num_iterations=*/kNumFramesToConfidence,
-                     level_estimator.vad_data_speech,
+                     level_estimator.level_rms_dbfs,
+                     level_estimator.level_peak_dbfs, kMaxSpeechProbability,
                      *level_estimator.estimator);
   const float estimated_level_dbfs = level_estimator.estimator->level_dbfs();
-  RunOnConstantLevel(/*num_iterations=*/1, level_estimator.vad_data_speech,
+  RunOnConstantLevel(/*num_iterations=*/1, level_estimator.level_rms_dbfs,
+                     level_estimator.level_peak_dbfs, kMaxSpeechProbability,
                      *level_estimator.estimator);
   EXPECT_NEAR(level_estimator.estimator->level_dbfs(), estimated_level_dbfs,
               0.1f);
@@ -95,7 +96,8 @@ TEST(GainController2AdaptiveModeLevelEstimator, LevelStabilizes) {
 TEST(GainController2AdaptiveModeLevelEstimator, IsNotConfident) {
   TestLevelEstimator level_estimator(/*adjacent_speech_frames_threshold=*/1);
   RunOnConstantLevel(/*num_iterations=*/kNumFramesToConfidence / 2,
-                     level_estimator.vad_data_speech,
+                     level_estimator.level_rms_dbfs,
+                     level_estimator.level_peak_dbfs, kMaxSpeechProbability,
                      *level_estimator.estimator);
   EXPECT_FALSE(level_estimator.estimator->IsConfident());
 }
@@ -105,7 +107,8 @@ TEST(GainController2AdaptiveModeLevelEstimator, IsNotConfident) {
 TEST(GainController2AdaptiveModeLevelEstimator, IsConfident) {
   TestLevelEstimator level_estimator(/*adjacent_speech_frames_threshold=*/1);
   RunOnConstantLevel(/*num_iterations=*/kNumFramesToConfidence,
-                     level_estimator.vad_data_speech,
+                     level_estimator.level_rms_dbfs,
+                     level_estimator.level_peak_dbfs, kMaxSpeechProbability,
                      *level_estimator.estimator);
   EXPECT_TRUE(level_estimator.estimator->IsConfident());
 }
@@ -117,15 +120,14 @@ TEST(GainController2AdaptiveModeLevelEstimator,
   TestLevelEstimator level_estimator(/*adjacent_speech_frames_threshold=*/1);
   // Simulate speech.
   RunOnConstantLevel(/*num_iterations=*/kNumFramesToConfidence,
-                     level_estimator.vad_data_speech,
+                     level_estimator.level_rms_dbfs,
+                     level_estimator.level_peak_dbfs, kMaxSpeechProbability,
                      *level_estimator.estimator);
   const float estimated_level_dbfs = level_estimator.estimator->level_dbfs();
   // Simulate full-scale non-speech.
   RunOnConstantLevel(/*num_iterations=*/kNumFramesToConfidence,
-                     VadLevelAnalyzer::Result{/*speech_probability=*/0.0f,
-                                              /*rms_dbfs=*/0.0f,
-                                              /*peak_dbfs=*/0.0f},
-                     *level_estimator.estimator);
+                     /*rms_dbfs=*/0.0f, /*peak_dbfs=*/0.0f,
+                     kNoSpeechProbability, *level_estimator.estimator);
   // No estimated level change is expected.
   EXPECT_FLOAT_EQ(level_estimator.estimator->level_dbfs(),
                   estimated_level_dbfs);
@@ -136,10 +138,11 @@ TEST(GainController2AdaptiveModeLevelEstimator,
      ConvergenceSpeedBeforeConfidence) {
   TestLevelEstimator level_estimator(/*adjacent_speech_frames_threshold=*/1);
   RunOnConstantLevel(/*num_iterations=*/kNumFramesToConfidence,
-                     level_estimator.vad_data_speech,
+                     level_estimator.level_rms_dbfs,
+                     level_estimator.level_peak_dbfs, kMaxSpeechProbability,
                      *level_estimator.estimator);
   EXPECT_NEAR(level_estimator.estimator->level_dbfs(),
-              level_estimator.vad_data_speech.rms_dbfs,
+              level_estimator.level_rms_dbfs,
               kConvergenceSpeedTestsLevelTolerance);
 }
 
@@ -150,11 +153,9 @@ TEST(GainController2AdaptiveModeLevelEstimator,
   // Reach confidence using the initial level estimate.
   RunOnConstantLevel(
       /*num_iterations=*/kNumFramesToConfidence,
-      VadLevelAnalyzer::Result{
-          /*speech_probability=*/1.0f,
-          /*rms_dbfs=*/level_estimator.initial_speech_level_dbfs,
-          /*peak_dbfs=*/level_estimator.initial_speech_level_dbfs + 6.0f},
-      *level_estimator.estimator);
+      /*rms_dbfs=*/level_estimator.initial_speech_level_dbfs,
+      /*peak_dbfs=*/level_estimator.initial_speech_level_dbfs + 6.0f,
+      kMaxSpeechProbability, *level_estimator.estimator);
   // No estimate change should occur, but confidence is achieved.
   ASSERT_FLOAT_EQ(level_estimator.estimator->level_dbfs(),
                   level_estimator.initial_speech_level_dbfs);
@@ -165,9 +166,10 @@ TEST(GainController2AdaptiveModeLevelEstimator,
       kConvergenceTimeAfterConfidenceNumFrames > kNumFramesToConfidence, "");
   RunOnConstantLevel(
       /*num_iterations=*/kConvergenceTimeAfterConfidenceNumFrames,
-      level_estimator.vad_data_speech, *level_estimator.estimator);
+      level_estimator.level_rms_dbfs, level_estimator.level_peak_dbfs,
+      kMaxSpeechProbability, *level_estimator.estimator);
   EXPECT_NEAR(level_estimator.estimator->level_dbfs(),
-              level_estimator.vad_data_speech.rms_dbfs,
+              level_estimator.level_rms_dbfs,
               kConvergenceSpeedTestsLevelTolerance);
 }
 
@@ -181,22 +183,28 @@ TEST_P(AdaptiveModeLevelEstimatorParametrization,
        DoNotAdaptToShortSpeechSegments) {
   TestLevelEstimator level_estimator(adjacent_speech_frames_threshold());
   const float initial_level = level_estimator.estimator->level_dbfs();
-  ASSERT_LT(initial_level, level_estimator.vad_data_speech.peak_dbfs);
+  ASSERT_LT(initial_level, level_estimator.level_peak_dbfs);
   for (int i = 0; i < adjacent_speech_frames_threshold() - 1; ++i) {
     SCOPED_TRACE(i);
-    level_estimator.estimator->Update(level_estimator.vad_data_speech);
+    level_estimator.estimator->Update(level_estimator.level_rms_dbfs,
+                                      level_estimator.level_peak_dbfs,
+                                      kMaxSpeechProbability);
     EXPECT_EQ(initial_level, level_estimator.estimator->level_dbfs());
   }
-  level_estimator.estimator->Update(level_estimator.vad_data_non_speech);
+  level_estimator.estimator->Update(level_estimator.level_rms_dbfs,
+                                    level_estimator.level_peak_dbfs,
+                                    kLowSpeechProbability);
   EXPECT_EQ(initial_level, level_estimator.estimator->level_dbfs());
 }
 
 TEST_P(AdaptiveModeLevelEstimatorParametrization, AdaptToEnoughSpeechSegments) {
   TestLevelEstimator level_estimator(adjacent_speech_frames_threshold());
   const float initial_level = level_estimator.estimator->level_dbfs();
-  ASSERT_LT(initial_level, level_estimator.vad_data_speech.peak_dbfs);
+  ASSERT_LT(initial_level, level_estimator.level_peak_dbfs);
   for (int i = 0; i < adjacent_speech_frames_threshold(); ++i) {
-    level_estimator.estimator->Update(level_estimator.vad_data_speech);
+    level_estimator.estimator->Update(level_estimator.level_rms_dbfs,
+                                      level_estimator.level_peak_dbfs,
+                                      kMaxSpeechProbability);
   }
   EXPECT_LT(initial_level, level_estimator.estimator->level_dbfs());
 }
diff --git a/modules/audio_processing/agc2/vad_wrapper.cc b/modules/audio_processing/agc2/vad_wrapper.cc
index 94d5f67d9e..7b61aee99d 100644
--- a/modules/audio_processing/agc2/vad_wrapper.cc
+++ b/modules/audio_processing/agc2/vad_wrapper.cc
@@ -10,13 +10,10 @@
 
 #include "modules/audio_processing/agc2/vad_wrapper.h"
 
-#include <algorithm>
 #include <array>
-#include <cmath>
 #include <utility>
 
 #include "api/array_view.h"
-#include "common_audio/include/audio_util.h"
 #include "common_audio/resampler/include/push_resampler.h"
 #include "modules/audio_processing/agc2/agc2_common.h"
 #include "modules/audio_processing/agc2/rnn_vad/common.h"
@@ -27,82 +24,72 @@
 namespace webrtc {
 namespace {
 
-using VoiceActivityDetector = VadLevelAnalyzer::VoiceActivityDetector;
+constexpr int kNumFramesPerSecond = 100;
 
-// Default VAD that combines a resampler and the RNN VAD.
-// Computes the speech probability on the first channel.
-class Vad : public VoiceActivityDetector {
+class MonoVadImpl : public VoiceActivityDetectorWrapper::MonoVad {
  public:
-  explicit Vad(const AvailableCpuFeatures& cpu_features)
+  explicit MonoVadImpl(const AvailableCpuFeatures& cpu_features)
       : features_extractor_(cpu_features), rnn_vad_(cpu_features) {}
-  Vad(const Vad&) = delete;
-  Vad& operator=(const Vad&) = delete;
-  ~Vad() = default;
+  MonoVadImpl(const MonoVadImpl&) = delete;
+  MonoVadImpl& operator=(const MonoVadImpl&) = delete;
+  ~MonoVadImpl() = default;
 
+  int SampleRateHz() const override { return rnn_vad::kSampleRate24kHz; }
   void Reset() override { rnn_vad_.Reset(); }
-
-  float ComputeProbability(AudioFrameView<const float> frame) override {
-    // The source number of channels is 1, because we always use the 1st
-    // channel.
-    resampler_.InitializeIfNeeded(
-        /*sample_rate_hz=*/static_cast<int>(frame.samples_per_channel() * 100),
-        rnn_vad::kSampleRate24kHz,
-        /*num_channels=*/1);
-
-    std::array<float, rnn_vad::kFrameSize10ms24kHz> work_frame;
-    // Feed the 1st channel to the resampler.
-    resampler_.Resample(frame.channel(0).data(), frame.samples_per_channel(),
-                        work_frame.data(), rnn_vad::kFrameSize10ms24kHz);
-
+  float Analyze(rtc::ArrayView<const float> frame) override {
+    RTC_DCHECK_EQ(frame.size(), rnn_vad::kFrameSize10ms24kHz);
     std::array<float, rnn_vad::kFeatureVectorSize> feature_vector;
     const bool is_silence = features_extractor_.CheckSilenceComputeFeatures(
-        work_frame, feature_vector);
+        /*samples=*/{frame.data(), rnn_vad::kFrameSize10ms24kHz},
+        feature_vector);
     return rnn_vad_.ComputeVadProbability(feature_vector, is_silence);
   }
 
  private:
-  PushResampler<float> resampler_;
   rnn_vad::FeaturesExtractor features_extractor_;
   rnn_vad::RnnVad rnn_vad_;
 };
 
 }  // namespace
 
-VadLevelAnalyzer::VadLevelAnalyzer(int vad_reset_period_ms,
-                                   const AvailableCpuFeatures& cpu_features)
-    : VadLevelAnalyzer(vad_reset_period_ms,
-                       std::make_unique<Vad>(cpu_features)) {}
+VoiceActivityDetectorWrapper::VoiceActivityDetectorWrapper(
+    int vad_reset_period_ms,
+    const AvailableCpuFeatures& cpu_features)
+    : VoiceActivityDetectorWrapper(
+          vad_reset_period_ms,
+          std::make_unique<MonoVadImpl>(cpu_features)) {}
 
-VadLevelAnalyzer::VadLevelAnalyzer(int vad_reset_period_ms,
-                                   std::unique_ptr<VoiceActivityDetector> vad)
-    : vad_(std::move(vad)),
-      vad_reset_period_frames_(
+VoiceActivityDetectorWrapper::VoiceActivityDetectorWrapper(
+    int vad_reset_period_ms,
+    std::unique_ptr<MonoVad> vad)
+    : vad_reset_period_frames_(
          rtc::CheckedDivExact(vad_reset_period_ms, kFrameDurationMs)),
-      time_to_vad_reset_(vad_reset_period_frames_) {
+      time_to_vad_reset_(vad_reset_period_frames_),
+      vad_(std::move(vad)) {
   RTC_DCHECK(vad_);
   RTC_DCHECK_GT(vad_reset_period_frames_, 1);
+  resampled_buffer_.resize(
+      rtc::CheckedDivExact(vad_->SampleRateHz(), kNumFramesPerSecond));
 }
 
-VadLevelAnalyzer::~VadLevelAnalyzer() = default;
+VoiceActivityDetectorWrapper::~VoiceActivityDetectorWrapper() = default;
 
-VadLevelAnalyzer::Result VadLevelAnalyzer::AnalyzeFrame(
-    AudioFrameView<const float> frame) {
+float VoiceActivityDetectorWrapper::Analyze(AudioFrameView<const float> frame) {
   // Periodically reset the VAD.
   time_to_vad_reset_--;
   if (time_to_vad_reset_ <= 0) {
     vad_->Reset();
     time_to_vad_reset_ = vad_reset_period_frames_;
   }
-  // Compute levels.
-  float peak = 0.0f;
-  float rms = 0.0f;
-  for (const auto& x : frame.channel(0)) {
-    peak = std::max(std::fabs(x), peak);
-    rms += x * x;
-  }
-  return {vad_->ComputeProbability(frame),
-          FloatS16ToDbfs(std::sqrt(rms / frame.samples_per_channel())),
-          FloatS16ToDbfs(peak)};
+
+  // Resample the first channel of `frame`.
+  resampler_.InitializeIfNeeded(
+      /*sample_rate_hz=*/frame.samples_per_channel() * kNumFramesPerSecond,
+      vad_->SampleRateHz(), /*num_channels=*/1);
+  resampler_.Resample(frame.channel(0).data(), frame.samples_per_channel(),
+                      resampled_buffer_.data(), resampled_buffer_.size());
+
+  return vad_->Analyze(resampled_buffer_);
 }
 
 }  // namespace webrtc
diff --git a/modules/audio_processing/agc2/vad_wrapper.h b/modules/audio_processing/agc2/vad_wrapper.h
index de73eabe58..f17fcda6b8 100644
--- a/modules/audio_processing/agc2/vad_wrapper.h
+++ b/modules/audio_processing/agc2/vad_wrapper.h
@@ -12,51 +12,57 @@
 #define MODULES_AUDIO_PROCESSING_AGC2_VAD_WRAPPER_H_
 
 #include <memory>
+#include <vector>
 
+#include "api/array_view.h"
+#include "common_audio/resampler/include/push_resampler.h"
 #include "modules/audio_processing/agc2/cpu_features.h"
 #include "modules/audio_processing/include/audio_frame_view.h"
 
 namespace webrtc {
 
-// Class to analyze voice activity and audio levels.
-class VadLevelAnalyzer {
+// Wraps a single-channel Voice Activity Detector (VAD) which is used to
+// analyze the first channel of the input audio frames. Takes care of
+// resampling the input frames to match the sample rate of the wrapped VAD and
+// periodically resets the VAD.
+class VoiceActivityDetectorWrapper {
  public:
-  struct Result {
-    float speech_probability;  // Range: [0, 1].
-    float rms_dbfs;            // Root mean square power (dBFS).
-    float peak_dbfs;           // Peak power (dBFS).
-  };
-
-  // Voice Activity Detector (VAD) interface.
-  class VoiceActivityDetector {
+  // Single channel VAD interface.
+  class MonoVad {
   public:
-    virtual ~VoiceActivityDetector() = default;
+    virtual ~MonoVad() = default;
+    // Returns the sample rate (Hz) required for the input frames analyzed by
+    // `ComputeProbability`.
+    virtual int SampleRateHz() const = 0;
     // Resets the internal state.
     virtual void Reset() = 0;
     // Analyzes an audio frame and returns the speech probability.
-    virtual float ComputeProbability(AudioFrameView<const float> frame) = 0;
+    virtual float Analyze(rtc::ArrayView<const float> frame) = 0;
   };
 
   // Ctor. `vad_reset_period_ms` indicates the period in milliseconds to call
-  // `VadLevelAnalyzer::Reset()`; it must be equal to or greater than the
-  // duration of two frames. Uses `cpu_features` to instantiate the default VAD.
-  VadLevelAnalyzer(int vad_reset_period_ms,
-                   const AvailableCpuFeatures& cpu_features);
+  // `MonoVad::Reset()`; it must be equal to or greater than the duration of two
+  // frames. Uses `cpu_features` to instantiate the default VAD.
+  VoiceActivityDetectorWrapper(int vad_reset_period_ms,
+                               const AvailableCpuFeatures& cpu_features);
   // Ctor. Uses a custom `vad`.
-  VadLevelAnalyzer(int vad_reset_period_ms,
-                   std::unique_ptr<VoiceActivityDetector> vad);
+  VoiceActivityDetectorWrapper(int vad_reset_period_ms,
+                               std::unique_ptr<MonoVad> vad);
 
-  VadLevelAnalyzer(const VadLevelAnalyzer&) = delete;
-  VadLevelAnalyzer& operator=(const VadLevelAnalyzer&) = delete;
-  ~VadLevelAnalyzer();
+  VoiceActivityDetectorWrapper(const VoiceActivityDetectorWrapper&) = delete;
+  VoiceActivityDetectorWrapper& operator=(const VoiceActivityDetectorWrapper&) =
+      delete;
+  ~VoiceActivityDetectorWrapper();
 
-  // Computes the speech probability and the level for `frame`.
-  Result AnalyzeFrame(AudioFrameView<const float> frame);
+  // Analyzes the first channel of `frame` and returns the speech probability.
+  float Analyze(AudioFrameView<const float> frame);
 
  private:
-  std::unique_ptr<VoiceActivityDetector> vad_;
   const int vad_reset_period_frames_;
   int time_to_vad_reset_;
+  PushResampler<float> resampler_;
+  std::unique_ptr<MonoVad> vad_;
+  std::vector<float> resampled_buffer_;
 };
 
 }  // namespace webrtc
diff --git a/modules/audio_processing/agc2/vad_wrapper_unittest.cc b/modules/audio_processing/agc2/vad_wrapper_unittest.cc
index a6e776c8b5..c1f7029ef1 100644
--- a/modules/audio_processing/agc2/vad_wrapper_unittest.cc
+++ b/modules/audio_processing/agc2/vad_wrapper_unittest.cc
@@ -18,6 +18,7 @@
 
 #include "modules/audio_processing/agc2/agc2_common.h"
 #include "modules/audio_processing/include/audio_frame_view.h"
+#include "rtc_base/checks.h"
 #include "rtc_base/gunit.h"
 #include "rtc_base/numerics/safe_compare.h"
 #include "test/gmock.h"
@@ -26,90 +27,78 @@
 namespace webrtc {
 namespace {
 
 using ::testing::AnyNumber;
+using ::testing::Return;
 using ::testing::ReturnRoundRobin;
+using ::testing::Truly;
 
 constexpr int kNoVadPeriodicReset =
     kFrameDurationMs * (std::numeric_limits<int>::max() / kFrameDurationMs);
 
-constexpr int kSampleRateHz = 8000;
+constexpr int kSampleRate8kHz = 8000;
 
-class MockVad : public VadLevelAnalyzer::VoiceActivityDetector {
+class MockVad : public VoiceActivityDetectorWrapper::MonoVad {
  public:
+  MOCK_METHOD(int, SampleRateHz, (), (const override));
   MOCK_METHOD(void, Reset, (), (override));
-  MOCK_METHOD(float,
-              ComputeProbability,
-              (AudioFrameView<const float> frame),
-              (override));
+  MOCK_METHOD(float, Analyze, (rtc::ArrayView<const float> frame), (override));
 };
 
-// Creates a `VadLevelAnalyzer` injecting a mock VAD which repeatedly returns
-// the next value from `speech_probabilities` until it reaches the end and will
-// restart from the beginning.
-std::unique_ptr<VadLevelAnalyzer> CreateVadLevelAnalyzerWithMockVad(
+// Creates a `VoiceActivityDetectorWrapper` injecting a mock VAD that
+// repeatedly returns the next value from `speech_probabilities` and that
+// restarts from the beginning after the last element is returned.
+std::unique_ptr<VoiceActivityDetectorWrapper> CreateMockVadWrapper(
     int vad_reset_period_ms,
    const std::vector<float>& speech_probabilities,
    int expected_vad_reset_calls = 0) {
  auto vad = std::make_unique<MockVad>();
-  EXPECT_CALL(*vad, ComputeProbability)
+  EXPECT_CALL(*vad, SampleRateHz)
      .Times(AnyNumber())
-      .WillRepeatedly(ReturnRoundRobin(speech_probabilities));
+      .WillRepeatedly(Return(kSampleRate8kHz));
  if (expected_vad_reset_calls >= 0) {
    EXPECT_CALL(*vad, Reset).Times(expected_vad_reset_calls);
  }
-  return std::make_unique<VadLevelAnalyzer>(vad_reset_period_ms,
-                                            std::move(vad));
+  EXPECT_CALL(*vad, Analyze)
+      .Times(AnyNumber())
+      .WillRepeatedly(ReturnRoundRobin(speech_probabilities));
+  return std::make_unique<VoiceActivityDetectorWrapper>(vad_reset_period_ms,
+                                                        std::move(vad));
 }
 
 // 10 ms mono frame.
 struct FrameWithView {
   // Ctor. Initializes the frame samples with `value`.
-  explicit FrameWithView(float value = 0.0f)
-      : channel0(samples.data()),
-        view(&channel0, /*num_channels=*/1, samples.size()) {
-    samples.fill(value);
-  }
-  std::array<float, kSampleRateHz / 100> samples;
+  explicit FrameWithView(int sample_rate_hz = kSampleRate8kHz)
+      : samples(rtc::CheckedDivExact(sample_rate_hz, 100), 0.0f),
+        channel0(samples.data()),
+        view(&channel0, /*num_channels=*/1, samples.size()) {}
+  std::vector<float> samples;
   const float* const channel0;
   const AudioFrameView<const float> view;
 };
 
-TEST(GainController2VadLevelAnalyzer, RmsLessThanPeakLevel) {
-  auto analyzer = CreateVadLevelAnalyzerWithMockVad(
-      /*vad_reset_period_ms=*/1500,
-      /*speech_probabilities=*/{1.0f},
-      /*expected_vad_reset_calls=*/0);
-  // Handcrafted frame so that the average is lower than the peak value.
-  FrameWithView frame(1000.0f);  // Constant frame.
-  frame.samples[10] = 2000.0f;   // Except for one peak value.
-  // Compute audio frame levels.
-  auto levels_and_vad_prob = analyzer->AnalyzeFrame(frame.view);
-  EXPECT_LT(levels_and_vad_prob.rms_dbfs, levels_and_vad_prob.peak_dbfs);
-}
-
-// Checks that the expect VAD probabilities are returned.
-TEST(GainController2VadLevelAnalyzer, NoSpeechProbabilitySmoothing) {
+// Checks that the expected speech probabilities are returned.
+TEST(GainController2VoiceActivityDetectorWrapper, CheckSpeechProbabilities) {
   const std::vector<float> speech_probabilities{0.709f, 0.484f, 0.882f, 0.167f,
                                                 0.44f,  0.525f, 0.858f, 0.314f,
                                                 0.653f, 0.965f, 0.413f, 0.0f};
-  auto analyzer = CreateVadLevelAnalyzerWithMockVad(kNoVadPeriodicReset,
-                                                    speech_probabilities);
+  auto vad_wrapper =
+      CreateMockVadWrapper(kNoVadPeriodicReset, speech_probabilities);
   FrameWithView frame;
   for (int i = 0; rtc::SafeLt(i, speech_probabilities.size()); ++i) {
     SCOPED_TRACE(i);
-    EXPECT_EQ(speech_probabilities[i],
-              analyzer->AnalyzeFrame(frame.view).speech_probability);
+    EXPECT_EQ(speech_probabilities[i], vad_wrapper->Analyze(frame.view));
   }
 }
 
 // Checks that the VAD is not periodically reset.
-TEST(GainController2VadLevelAnalyzer, VadNoPeriodicReset) {
+TEST(GainController2VoiceActivityDetectorWrapper, VadNoPeriodicReset) {
   constexpr int kNumFrames = 19;
-  auto analyzer = CreateVadLevelAnalyzerWithMockVad(
-      kNoVadPeriodicReset, /*speech_probabilities=*/{1.0f},
-      /*expected_vad_reset_calls=*/0);
+  auto vad_wrapper =
+      CreateMockVadWrapper(kNoVadPeriodicReset, /*speech_probabilities=*/{1.0f},
+                           /*expected_vad_reset_calls=*/0);
   FrameWithView frame;
   for (int i = 0; i < kNumFrames; ++i) {
-    analyzer->AnalyzeFrame(frame.view);
+    vad_wrapper->Analyze(frame.view);
   }
 }
 
@@ -122,20 +111,52 @@ class VadPeriodResetParametrization
 
 // Checks that the VAD is periodically reset with the expected period.
 TEST_P(VadPeriodResetParametrization, VadPeriodicReset) {
-  auto analyzer = CreateVadLevelAnalyzerWithMockVad(
+  auto vad_wrapper = CreateMockVadWrapper(
      /*vad_reset_period_ms=*/vad_reset_period_frames() * kFrameDurationMs,
      /*speech_probabilities=*/{1.0f},
      /*expected_vad_reset_calls=*/num_frames() / vad_reset_period_frames());
  FrameWithView frame;
  for (int i = 0; i < num_frames(); ++i) {
-    analyzer->AnalyzeFrame(frame.view);
+    vad_wrapper->Analyze(frame.view);
  }
 }
 
-INSTANTIATE_TEST_SUITE_P(GainController2VadLevelAnalyzer,
+INSTANTIATE_TEST_SUITE_P(GainController2VoiceActivityDetectorWrapper,
                          VadPeriodResetParametrization,
                          ::testing::Combine(::testing::Values(1, 19, 123),
                                             ::testing::Values(2, 5, 20, 53)));
 
+class VadResamplingParametrization
+    : public ::testing::TestWithParam<std::tuple<int, int>> {
+ protected:
+  int input_sample_rate_hz() const { return std::get<0>(GetParam()); }
+  int vad_sample_rate_hz() const { return std::get<1>(GetParam()); }
+};
+
+// Checks that regardless of the input audio sample rate, the wrapped VAD
+// analyzes frames having the expected size, that is according to its internal
+// sample rate.
+TEST_P(VadResamplingParametrization, CheckResampledFrameSize) {
+  auto vad = std::make_unique<MockVad>();
+  EXPECT_CALL(*vad, SampleRateHz)
+      .Times(AnyNumber())
+      .WillRepeatedly(Return(vad_sample_rate_hz()));
+  EXPECT_CALL(*vad, Reset).Times(0);
+  EXPECT_CALL(*vad, Analyze(Truly([this](rtc::ArrayView<const float> frame) {
+    return rtc::SafeEq(frame.size(),
+                       rtc::CheckedDivExact(vad_sample_rate_hz(), 100));
+  }))).Times(1);
+  auto vad_wrapper = std::make_unique<VoiceActivityDetectorWrapper>(
+      kNoVadPeriodicReset, std::move(vad));
+  FrameWithView frame(input_sample_rate_hz());
+  vad_wrapper->Analyze(frame.view);
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    GainController2VoiceActivityDetectorWrapper,
+    VadResamplingParametrization,
+    ::testing::Combine(::testing::Values(8000, 16000, 44100, 48000),
+                       ::testing::Values(6000, 8000, 12000, 16000, 24000)));
+
 }  // namespace
 }  // namespace webrtc
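
Usage note (illustrative sketch, not part of the change): after this CL the peak/RMS measurement, the VAD wrapper and the speech level estimator are independent pieces that the caller wires together, as AdaptiveAgc::Process() does above. The sketch below shows that wiring in isolation, under stated assumptions: it uses only names introduced by the patch plus existing WebRTC headers; MeasureFirstChannel() and UpdateSpeechLevel() are hypothetical helper names, and the level computation simply mirrors the file-local ComputeAudioLevels() from adaptive_agc.cc, since that helper is not exported.

  // Minimal sketch of the decoupled AGC2 flow introduced by this CL.
  // Hypothetical helpers; the computation mirrors the patch above.
  #include <algorithm>
  #include <cmath>

  #include "common_audio/include/audio_util.h"
  #include "modules/audio_processing/agc2/adaptive_mode_level_estimator.h"
  #include "modules/audio_processing/agc2/vad_wrapper.h"
  #include "modules/audio_processing/include/audio_frame_view.h"

  namespace {

  // Peak and RMS of the first channel, in dBFS (same math as the file-local
  // ComputeAudioLevels() in adaptive_agc.cc).
  struct Levels {
    float peak_dbfs;
    float rms_dbfs;
  };

  Levels MeasureFirstChannel(webrtc::AudioFrameView<float> frame) {
    float peak = 0.0f;
    float rms = 0.0f;
    for (const auto& x : frame.channel(0)) {
      peak = std::max(std::fabs(x), peak);
      rms += x * x;
    }
    return {webrtc::FloatS16ToDbfs(peak),
            webrtc::FloatS16ToDbfs(
                std::sqrt(rms / frame.samples_per_channel()))};
  }

  // Hypothetical helper: processes one 10 ms frame, wired like
  // AdaptiveAgc::Process(). Returns the current speech level estimate (dBFS).
  float UpdateSpeechLevel(webrtc::VoiceActivityDetectorWrapper& vad,
                          webrtc::AdaptiveModeLevelEstimator& level_estimator,
                          webrtc::AudioFrameView<float> frame) {
    const Levels levels = MeasureFirstChannel(frame);
    // The wrapper resamples the first channel internally and returns the
    // speech probability in [0, 1].
    const float speech_probability = vad.Analyze(frame);
    // The estimator now takes plain floats instead of VadLevelAnalyzer::Result.
    level_estimator.Update(levels.rms_dbfs, levels.peak_dbfs,
                           speech_probability);
    return level_estimator.level_dbfs();
  }

  }  // namespace

The point of the split is visible here: the caller can reuse the level measurement and the estimator without instantiating the VAD, or swap in a custom MonoVad through the wrapper's second constructor, which is what the updated unit tests do with MockVad.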