From 7efe5332f2f4e43a3df4bbe13de96af26623f81c Mon Sep 17 00:00:00 2001 From: Alessio Bazzica Date: Fri, 8 Apr 2022 11:22:36 +0200 Subject: [PATCH] APM Transient Suppressor (TS): integrate `VoiceProbabilityDelayUnit` This CL adds a component in the TS implementation to return a delayed version of the voice probability values observed when `Suppress()` is called. That is needed in order to temporally align the voice probability values to the processed audio since TS adds algorithmic delay. Bug: webrtc:13663 Change-Id: I5041ace3939d2ce7ba084ae703428e66f1aa06be Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/255860 Reviewed-by: Hanna Silen Commit-Queue: Alessio Bazzica Cr-Commit-Position: refs/heads/main@{#36496} --- modules/audio_processing/transient/BUILD.gn | 3 + .../transient/transient_suppressor.h | 21 +++-- .../transient/transient_suppressor_impl.cc | 33 ++++--- .../transient/transient_suppressor_impl.h | 22 ++--- .../transient_suppressor_unittest.cc | 90 +++++++++++++++++-- 5 files changed, 133 insertions(+), 36 deletions(-) diff --git a/modules/audio_processing/transient/BUILD.gn b/modules/audio_processing/transient/BUILD.gn index 02b56721c4..6d9802e0a2 100644 --- a/modules/audio_processing/transient/BUILD.gn +++ b/modules/audio_processing/transient/BUILD.gn @@ -37,6 +37,7 @@ rtc_library("transient_suppressor_impl") { ] deps = [ ":transient_suppressor_api", + ":voice_probability_delay_unit", "../../../common_audio:common_audio", "../../../common_audio:common_audio_c", "../../../common_audio:fir_filter", @@ -96,6 +97,7 @@ if (rtc_include_tests) { "//testing/gtest", "//third_party/abseil-cpp/absl/flags:flag", "//third_party/abseil-cpp/absl/flags:parse", + "//third_party/abseil-cpp/absl/types:optional", ] } } @@ -124,5 +126,6 @@ if (rtc_include_tests) { "../../../test:test_support", "//testing/gtest", ] + absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] } } diff --git a/modules/audio_processing/transient/transient_suppressor.h b/modules/audio_processing/transient/transient_suppressor.h index dd998a1154..ecb3c3baab 100644 --- a/modules/audio_processing/transient/transient_suppressor.h +++ b/modules/audio_processing/transient/transient_suppressor.h @@ -56,15 +56,18 @@ class TransientSuppressor { // of audio. If voice information is not available, `voice_probability` must // always be set to 1. // `key_pressed` determines if a key was pressed on this audio chunk. - virtual void Suppress(float* data, - size_t data_length, - int num_channels, - const float* detection_data, - size_t detection_length, - const float* reference_data, - size_t reference_length, - float voice_probability, - bool key_pressed) = 0; + // Returns a delayed version of `voice_probability` according to the + // algorithmic delay introduced by this method. In this way, the modified + // `data` and the returned voice probability will be temporally aligned. + virtual float Suppress(float* data, + size_t data_length, + int num_channels, + const float* detection_data, + size_t detection_length, + const float* reference_data, + size_t reference_length, + float voice_probability, + bool key_pressed) = 0; }; } // namespace webrtc diff --git a/modules/audio_processing/transient/transient_suppressor_impl.cc b/modules/audio_processing/transient/transient_suppressor_impl.cc index f3fbf09240..90428464e3 100644 --- a/modules/audio_processing/transient/transient_suppressor_impl.cc +++ b/modules/audio_processing/transient/transient_suppressor_impl.cc @@ -62,6 +62,7 @@ TransientSuppressorImpl::TransientSuppressorImpl(VadMode vad_mode, int detector_rate_hz, int num_channels) : vad_mode_(vad_mode), + voice_probability_delay_unit_(/*delay_num_samples=*/0, sample_rate_hz), analyzed_audio_is_silent_(false), data_length_(0), detection_length_(0), @@ -125,6 +126,9 @@ void TransientSuppressorImpl::Initialize(int sample_rate_hz, RTC_DCHECK_LE(data_length_, analysis_length_); buffer_delay_ = analysis_length_ - data_length_; + voice_probability_delay_unit_.Initialize(/*delay_num_samples=*/buffer_delay_, + sample_rate_hz); + complex_analysis_length_ = analysis_length_ / 2 + 1; RTC_DCHECK_GE(complex_analysis_length_, kMaxVoiceBin); num_channels_ = num_channels; @@ -175,19 +179,21 @@ void TransientSuppressorImpl::Initialize(int sample_rate_hz, using_reference_ = false; } -void TransientSuppressorImpl::Suppress(float* data, - size_t data_length, - int num_channels, - const float* detection_data, - size_t detection_length, - const float* reference_data, - size_t reference_length, - float voice_probability, - bool key_pressed) { +float TransientSuppressorImpl::Suppress(float* data, + size_t data_length, + int num_channels, + const float* detection_data, + size_t detection_length, + const float* reference_data, + size_t reference_length, + float voice_probability, + bool key_pressed) { if (!data || data_length != data_length_ || num_channels != num_channels_ || detection_length != detection_length_ || voice_probability < 0 || voice_probability > 1) { - return; + // The audio is not modified, so the voice probability is returned as is + // (delay not applied). + return voice_probability; } UpdateKeypress(key_pressed); @@ -205,7 +211,9 @@ void TransientSuppressorImpl::Suppress(float* data, float detector_result = detector_->Detect(detection_data, detection_length, reference_data, reference_length); if (detector_result < 0) { - return; + // The audio is not modified, so the voice probability is returned as is + // (delay not applied). + return voice_probability; } using_reference_ = detector_->using_reference(); @@ -235,6 +243,9 @@ void TransientSuppressorImpl::Suppress(float* data, : &in_buffer_[i * analysis_length_], data_length_ * sizeof(*data)); } + + // The audio has been modified, return the delayed voice probability. + return voice_probability_delay_unit_.Delay(voice_probability); } // This should only be called when detection is enabled. UpdateBuffers() must diff --git a/modules/audio_processing/transient/transient_suppressor_impl.h b/modules/audio_processing/transient/transient_suppressor_impl.h index 75caf5b813..4005a16b0a 100644 --- a/modules/audio_processing/transient/transient_suppressor_impl.h +++ b/modules/audio_processing/transient/transient_suppressor_impl.h @@ -17,6 +17,7 @@ #include #include "modules/audio_processing/transient/transient_suppressor.h" +#include "modules/audio_processing/transient/voice_probability_delay_unit.h" #include "rtc_base/gtest_prod_util.h" namespace webrtc { @@ -37,18 +38,18 @@ class TransientSuppressorImpl : public TransientSuppressor { int detector_rate_hz, int num_channels) override; - void Suppress(float* data, - size_t data_length, - int num_channels, - const float* detection_data, - size_t detection_length, - const float* reference_data, - size_t reference_length, - float voice_probability, - bool key_pressed) override; + float Suppress(float* data, + size_t data_length, + int num_channels, + const float* detection_data, + size_t detection_length, + const float* reference_data, + size_t reference_length, + float voice_probability, + bool key_pressed) override; private: - FRIEND_TEST_ALL_PREFIXES(TransientSuppressorImplTest, + FRIEND_TEST_ALL_PREFIXES(TransientSuppressorVadModeParametrization, TypingDetectionLogicWorksAsExpectedForMono); void Suppress(float* in_ptr, float* spectral_mean, float* out_ptr); @@ -61,6 +62,7 @@ class TransientSuppressorImpl : public TransientSuppressor { void SoftRestoration(float* spectral_mean); const VadMode vad_mode_; + VoiceProbabilityDelayUnit voice_probability_delay_unit_; std::unique_ptr detector_; diff --git a/modules/audio_processing/transient/transient_suppressor_unittest.cc b/modules/audio_processing/transient/transient_suppressor_unittest.cc index eb24cd1cf8..ab48504af6 100644 --- a/modules/audio_processing/transient/transient_suppressor_unittest.cc +++ b/modules/audio_processing/transient/transient_suppressor_unittest.cc @@ -10,21 +10,37 @@ #include "modules/audio_processing/transient/transient_suppressor.h" +#include + +#include "absl/types/optional.h" #include "modules/audio_processing/transient/common.h" #include "modules/audio_processing/transient/transient_suppressor_impl.h" #include "test/gtest.h" namespace webrtc { +namespace { +constexpr int kMono = 1; -class TransientSuppressorImplTest +// Returns the index of the first non-zero sample in `samples` or an unspecified +// value if no value is zero. +absl::optional FindFirstNonZeroSample(const std::vector& samples) { + for (size_t i = 0; i < samples.size(); ++i) { + if (samples[i] != 0.0f) { + return i; + } + } + return absl::nullopt; +} + +} // namespace + +class TransientSuppressorVadModeParametrization : public ::testing::TestWithParam {}; -TEST_P(TransientSuppressorImplTest, +TEST_P(TransientSuppressorVadModeParametrization, TypingDetectionLogicWorksAsExpectedForMono) { - static const int kNumChannels = 1; - TransientSuppressorImpl ts(GetParam(), ts::kSampleRate16kHz, - ts::kSampleRate16kHz, kNumChannels); + ts::kSampleRate16kHz, kMono); // Each key-press enables detection. EXPECT_FALSE(ts.detection_enabled_); @@ -88,10 +104,72 @@ TEST_P(TransientSuppressorImplTest, } INSTANTIATE_TEST_SUITE_P( - , TransientSuppressorImplTest, + TransientSuppressorVadModeParametrization, ::testing::Values(TransientSuppressor::VadMode::kDefault, TransientSuppressor::VadMode::kRnnVad, TransientSuppressor::VadMode::kNoVad)); +class TransientSuppressorSampleRateParametrization + : public ::testing::TestWithParam {}; + +// Checks that voice probability and processed audio data are temporally aligned +// after `Suppress()` is called. +TEST_P(TransientSuppressorSampleRateParametrization, + CheckAudioAndVoiceProbabilityTemporallyAligned) { + const int sample_rate_hz = GetParam(); + TransientSuppressorImpl ts(TransientSuppressor::VadMode::kDefault, + sample_rate_hz, + /*detection_rate_hz=*/sample_rate_hz, kMono); + + const int frame_size = sample_rate_hz * ts::kChunkSizeMs / 1000; + std::vector frame(frame_size); + + constexpr int kMaxAttempts = 3; + for (int i = 0; i < kMaxAttempts; ++i) { + SCOPED_TRACE(i); + + // Call `Suppress()` on frames of non-zero audio samples. + std::fill(frame.begin(), frame.end(), 1000.0f); + float delayed_voice_probability = ts.Suppress( + frame.data(), frame.size(), kMono, /*detection_data=*/nullptr, + /*detection_length=*/frame_size, /*reference_data=*/nullptr, + /*reference_length=*/frame_size, /*voice_probability=*/1.0f, + /*key_pressed=*/false); + + // Detect the algorithmic delay of `TransientSuppressorImpl`. + absl::optional frame_delay = FindFirstNonZeroSample(frame); + + // Check that the delayed voice probability is delayed according to the + // measured delay. + if (frame_delay.has_value()) { + if (*frame_delay == 0) { + // When the delay is a multiple integer of the frame duration, + // `Suppress()` returns a copy of a previously observed voice + // probability value. + EXPECT_EQ(delayed_voice_probability, 1.0f); + } else { + // Instead, when the delay is fractional, `Suppress()` returns an + // interpolated value. Since the exact value depends on the + // interpolation method, we only check that the delayed voice + // probability is not zero as it must converge towards the previoulsy + // observed value. + EXPECT_GT(delayed_voice_probability, 0.0f); + } + break; + } else { + // The algorithmic delay is longer than the duration of a single frame. + // Until the delay is detected, the delayed voice probability is zero. + EXPECT_EQ(delayed_voice_probability, 0.0f); + } + } +} + +INSTANTIATE_TEST_SUITE_P(TransientSuppressorImplTest, + TransientSuppressorSampleRateParametrization, + ::testing::Values(ts::kSampleRate8kHz, + ts::kSampleRate16kHz, + ts::kSampleRate32kHz, + ts::kSampleRate48kHz)); + } // namespace webrtc