From c1ece012cb804bb74c53187f63c9e10a5c49eb54 Mon Sep 17 00:00:00 2001 From: Alessio Bazzica Date: Fri, 25 Sep 2020 14:31:17 +0200 Subject: [PATCH] AGC2 VAD probability: instant decay / slow attack Feature added to gain robustness to occasional VAD speech probability spikes. In such a case, the attack process reduces the chance that the smoothed values are greater than the speech threshold. Bug: webrtc:7494 Change-Id: I6babe5afe30ea3dea021181a19d86bb74b33a98c Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/185046 Commit-Queue: Alessio Bazzica Reviewed-by: Gustaf Ullberg Cr-Commit-Position: refs/heads/master@{#32198} --- modules/audio_processing/agc2/BUILD.gn | 5 + modules/audio_processing/agc2/agc2_common.h | 3 + .../audio_processing/agc2/vad_with_level.cc | 34 +++++- .../audio_processing/agc2/vad_with_level.h | 6 +- .../agc2/vad_with_level_unittest.cc | 113 ++++++++++++++++-- 5 files changed, 145 insertions(+), 16 deletions(-) diff --git a/modules/audio_processing/agc2/BUILD.gn b/modules/audio_processing/agc2/BUILD.gn index 90cbe83360..ffe4bf47c8 100644 --- a/modules/audio_processing/agc2/BUILD.gn +++ b/modules/audio_processing/agc2/BUILD.gn @@ -95,6 +95,7 @@ rtc_library("common") { "../../../rtc_base:rtc_base_approved", "../../../system_wrappers:field_trial", ] + absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] } rtc_library("fixed_digital") { @@ -168,6 +169,7 @@ rtc_library("rnn_vad_with_level") { "vad_with_level.h", ] deps = [ + ":common", "..:audio_frame_view", "../../../api:array_view", "../../../common_audio", @@ -265,9 +267,12 @@ rtc_library("rnn_vad_with_level_unittests") { testonly = true sources = [ "vad_with_level_unittest.cc" ] deps = [ + ":common", ":rnn_vad_with_level", "..:audio_frame_view", "../../../rtc_base:gunit_helpers", + "../../../rtc_base:safe_compare", + "../../../test:test_support", ] } diff --git a/modules/audio_processing/agc2/agc2_common.h b/modules/audio_processing/agc2/agc2_common.h index a6389f4c2d..0549898e26 100644 --- a/modules/audio_processing/agc2/agc2_common.h +++ b/modules/audio_processing/agc2/agc2_common.h @@ -49,6 +49,9 @@ constexpr float kFullBufferLeakFactor = 1.f - 1.f / kFullBufferSizeMs; constexpr float kInitialSpeechLevelEstimateDbfs = -30.f; +// Robust VAD probability and speech decisions. +constexpr float kDefaultSmoothedVadProbabilityAttack = 1.f; + // Saturation Protector settings. float GetInitialSaturationMarginDb(); float GetExtraSaturationMarginOffsetDb(); diff --git a/modules/audio_processing/agc2/vad_with_level.cc b/modules/audio_processing/agc2/vad_with_level.cc index d35fbef90a..3dbb55732b 100644 --- a/modules/audio_processing/agc2/vad_with_level.cc +++ b/modules/audio_processing/agc2/vad_with_level.cc @@ -17,6 +17,7 @@ #include "api/array_view.h" #include "common_audio/include/audio_util.h" #include "common_audio/resampler/include/push_resampler.h" +#include "modules/audio_processing/agc2/agc2_common.h" #include "modules/audio_processing/agc2/rnn_vad/common.h" #include "modules/audio_processing/agc2/rnn_vad/features_extraction.h" #include "modules/audio_processing/agc2/rnn_vad/rnn.h" @@ -61,12 +62,32 @@ class Vad : public VoiceActivityDetector { rnn_vad::RnnBasedVad rnn_vad_; }; +// Returns an updated version of `p_old` by using instant decay and the given +// `attack` on a new VAD probability value `p_new`. +float SmoothedVadProbability(float p_old, float p_new, float attack) { + RTC_DCHECK_GT(attack, 0.f); + RTC_DCHECK_LE(attack, 1.f); + if (p_new < p_old || attack == 1.f) { + // Instant decay (or no smoothing). + return p_new; + } else { + // Attack phase. + return attack * p_new + (1.f - attack) * p_old; + } +} + } // namespace -VadLevelAnalyzer::VadLevelAnalyzer() : vad_(std::make_unique()) {} +VadLevelAnalyzer::VadLevelAnalyzer() + : VadLevelAnalyzer(kDefaultSmoothedVadProbabilityAttack, + std::make_unique()) {} -VadLevelAnalyzer::VadLevelAnalyzer(std::unique_ptr vad) - : vad_(std::move(vad)) { +VadLevelAnalyzer::VadLevelAnalyzer(float vad_probability_attack) + : VadLevelAnalyzer(vad_probability_attack, std::make_unique()) {} + +VadLevelAnalyzer::VadLevelAnalyzer(float vad_probability_attack, + std::unique_ptr vad) + : vad_(std::move(vad)), vad_probability_attack_(vad_probability_attack) { RTC_DCHECK(vad_); } @@ -74,13 +95,18 @@ VadLevelAnalyzer::~VadLevelAnalyzer() = default; VadLevelAnalyzer::Result VadLevelAnalyzer::AnalyzeFrame( AudioFrameView frame) { + // Compute levels. float peak = 0.f; float rms = 0.f; for (const auto& x : frame.channel(0)) { peak = std::max(std::fabs(x), peak); rms += x * x; } - return {vad_->ComputeProbability(frame), + // Compute smoothed speech probability. + vad_probability_ = SmoothedVadProbability( + /*p_old=*/vad_probability_, /*p_new=*/vad_->ComputeProbability(frame), + vad_probability_attack_); + return {vad_probability_, FloatS16ToDbfs(std::sqrt(rms / frame.samples_per_channel())), FloatS16ToDbfs(peak)}; } diff --git a/modules/audio_processing/agc2/vad_with_level.h b/modules/audio_processing/agc2/vad_with_level.h index 56eb79faf7..ce72cdc754 100644 --- a/modules/audio_processing/agc2/vad_with_level.h +++ b/modules/audio_processing/agc2/vad_with_level.h @@ -36,8 +36,10 @@ class VadLevelAnalyzer { // Ctor. Uses the default VAD. VadLevelAnalyzer(); + explicit VadLevelAnalyzer(float vad_probability_attack); // Ctor. Uses a custom `vad`. - explicit VadLevelAnalyzer(std::unique_ptr vad); + VadLevelAnalyzer(float vad_probability_attack, + std::unique_ptr vad); VadLevelAnalyzer(const VadLevelAnalyzer&) = delete; VadLevelAnalyzer& operator=(const VadLevelAnalyzer&) = delete; ~VadLevelAnalyzer(); @@ -47,6 +49,8 @@ class VadLevelAnalyzer { private: std::unique_ptr vad_; + const float vad_probability_attack_; + float vad_probability_ = 0.f; }; } // namespace webrtc diff --git a/modules/audio_processing/agc2/vad_with_level_unittest.cc b/modules/audio_processing/agc2/vad_with_level_unittest.cc index 5017e8369e..fb93c86417 100644 --- a/modules/audio_processing/agc2/vad_with_level_unittest.cc +++ b/modules/audio_processing/agc2/vad_with_level_unittest.cc @@ -10,30 +10,121 @@ #include "modules/audio_processing/agc2/vad_with_level.h" +#include +#include + +#include "modules/audio_processing/agc2/agc2_common.h" +#include "modules/audio_processing/include/audio_frame_view.h" #include "rtc_base/gunit.h" +#include "rtc_base/numerics/safe_compare.h" +#include "test/gmock.h" namespace webrtc { namespace { -TEST(AutomaticGainController2VadWithLevelEstimator, - PeakLevelGreaterThanRmsLevel) { - constexpr size_t kSampleRateHz = 8000; +using ::testing::AnyNumber; +using ::testing::ReturnRoundRobin; - // 10 ms input frame, constant except for one peak value. - // Handcrafted so that the average is lower than the peak value. - std::array frame; - frame.fill(1000.f); - frame[10] = 2000.f; - float* const channel0 = frame.data(); - AudioFrameView frame_view(&channel0, 1, frame.size()); +constexpr float kInstantAttack = 1.f; +constexpr float kSlowAttack = 0.1f; + +constexpr int kSampleRateHz = 8000; + +class MockVad : public VadLevelAnalyzer::VoiceActivityDetector { + public: + MOCK_METHOD(float, + ComputeProbability, + (AudioFrameView frame), + (override)); +}; + +// Creates a `VadLevelAnalyzer` injecting a mock VAD which repeatedly returns +// the next value from `speech_probabilities` until it reaches the end and will +// restart from the beginning. +std::unique_ptr CreateVadLevelAnalyzerWithMockVad( + float vad_probability_attack, + const std::vector& speech_probabilities) { + auto vad = std::make_unique(); + EXPECT_CALL(*vad, ComputeProbability) + .Times(AnyNumber()) + .WillRepeatedly(ReturnRoundRobin(speech_probabilities)); + return std::make_unique(vad_probability_attack, + std::move(vad)); +} + +// 10 ms mono frame. +struct FrameWithView { + // Ctor. Initializes the frame samples with `value`. + FrameWithView(float value = 0.f) + : channel0(samples.data()), + view(&channel0, /*num_channels=*/1, samples.size()) { + samples.fill(value); + } + std::array samples; + const float* const channel0; + const AudioFrameView view; +}; + +TEST(AutomaticGainController2VadLevelAnalyzer, PeakLevelGreaterThanRmsLevel) { + // Handcrafted frame so that the average is lower than the peak value. + FrameWithView frame(1000.f); // Constant frame. + frame.samples[10] = 2000.f; // Except for one peak value. // Compute audio frame levels (the VAD result is ignored). VadLevelAnalyzer analyzer; - auto levels_and_vad_prob = analyzer.AnalyzeFrame(frame_view); + auto levels_and_vad_prob = analyzer.AnalyzeFrame(frame.view); // Compare peak and RMS levels. EXPECT_LT(levels_and_vad_prob.rms_dbfs, levels_and_vad_prob.peak_dbfs); } +// Checks that the unprocessed and the smoothed speech probabilities match when +// instant attack is used. +TEST(AutomaticGainController2VadLevelAnalyzer, NoSpeechProbabilitySmoothing) { + const std::vector speech_probabilities{0.709f, 0.484f, 0.882f, 0.167f, + 0.44f, 0.525f, 0.858f, 0.314f, + 0.653f, 0.965f, 0.413f, 0.f}; + auto analyzer = + CreateVadLevelAnalyzerWithMockVad(kInstantAttack, speech_probabilities); + FrameWithView frame; + for (int i = 0; rtc::SafeLt(i, speech_probabilities.size()); ++i) { + SCOPED_TRACE(i); + EXPECT_EQ(speech_probabilities[i], + analyzer->AnalyzeFrame(frame.view).speech_probability); + } +} + +// Checks that the smoothed speech probability does not instantly converge to +// the unprocessed one when slow attack is used. +TEST(AutomaticGainController2VadLevelAnalyzer, + SlowAttackSpeechProbabilitySmoothing) { + const std::vector speech_probabilities{0.f, 0.f, 1.f, 1.f, 1.f, 1.f}; + auto analyzer = + CreateVadLevelAnalyzerWithMockVad(kSlowAttack, speech_probabilities); + FrameWithView frame; + float prev_probability = 0.f; + for (int i = 0; rtc::SafeLt(i, speech_probabilities.size()); ++i) { + SCOPED_TRACE(i); + const float smoothed_probability = + analyzer->AnalyzeFrame(frame.view).speech_probability; + EXPECT_LT(smoothed_probability, 1.f); // Not enough time to reach 1. + EXPECT_LE(prev_probability, smoothed_probability); // Converge towards 1. + prev_probability = smoothed_probability; + } +} + +// Checks that the smoothed speech probability instantly decays to the +// unprocessed one when slow attack is used. +TEST(AutomaticGainController2VadLevelAnalyzer, SpeechProbabilityInstantDecay) { + const std::vector speech_probabilities{1.f, 1.f, 1.f, 1.f, 1.f, 0.f}; + auto analyzer = + CreateVadLevelAnalyzerWithMockVad(kSlowAttack, speech_probabilities); + FrameWithView frame; + for (int i = 0; rtc::SafeLt(i, speech_probabilities.size() - 1); ++i) { + analyzer->AnalyzeFrame(frame.view); + } + EXPECT_EQ(0.f, analyzer->AnalyzeFrame(frame.view).speech_probability); +} + } // namespace } // namespace webrtc