From 767898c0481d510ffdc55ac3ae39f3f448d09b41 Mon Sep 17 00:00:00 2001 From: Hanna Silen Date: Wed, 5 Oct 2022 18:48:36 +0200 Subject: [PATCH] Add SpeechProbabilityBuffer Add a buffer class to store speech probabilities and to estimate speech activity. Follows the implementation of speech activity computation in LoudnessHistogram but uses floats for computations. Bug: webrtc:7494 Change-Id: I6ee72ec52919904ea4e1fbe51d61993aa7813c9f Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/277801 Reviewed-by: Alessio Bazzica Commit-Queue: Hanna Silen Cr-Commit-Position: refs/heads/main@{#38309} --- modules/audio_processing/agc2/BUILD.gn | 19 + .../agc2/speech_probability_buffer.cc | 105 ++++++ .../agc2/speech_probability_buffer.h | 80 ++++ .../speech_probability_buffer_unittest.cc | 346 ++++++++++++++++++ 4 files changed, 550 insertions(+) create mode 100644 modules/audio_processing/agc2/speech_probability_buffer.cc create mode 100644 modules/audio_processing/agc2/speech_probability_buffer.h create mode 100644 modules/audio_processing/agc2/speech_probability_buffer_unittest.cc diff --git a/modules/audio_processing/agc2/BUILD.gn b/modules/audio_processing/agc2/BUILD.gn index d6e684e588..f7d7842c07 100644 --- a/modules/audio_processing/agc2/BUILD.gn +++ b/modules/audio_processing/agc2/BUILD.gn @@ -181,6 +181,23 @@ rtc_source_set("gain_map") { sources = [ "gain_map_internal.h" ] } +rtc_library("input_volume_controller") { + sources = [ + "speech_probability_buffer.cc", + "speech_probability_buffer.h", + ] + + visibility = [ + "..:gain_controller2", + "./*", + ] + + deps = [ + "../../../rtc_base:checks", + "../../../rtc_base:gtest_prod", + ] +} + rtc_library("noise_level_estimator") { sources = [ "noise_level_estimator.cc", @@ -317,6 +334,7 @@ rtc_library("input_volume_controller_unittests") { "clipping_predictor_evaluator_unittest.cc", "clipping_predictor_level_buffer_unittest.cc", "clipping_predictor_unittest.cc", + "speech_probability_buffer_unittest.cc", ] configs += [ "..:apm_debug_dump" ] @@ -325,6 +343,7 @@ rtc_library("input_volume_controller_unittests") { ":clipping_predictor", ":clipping_predictor_evaluator", ":gain_map", + ":input_volume_controller", "../../../rtc_base:checks", "../../../rtc_base:random", "../../../rtc_base:safe_conversions", diff --git a/modules/audio_processing/agc2/speech_probability_buffer.cc b/modules/audio_processing/agc2/speech_probability_buffer.cc new file mode 100644 index 0000000000..7746f6c000 --- /dev/null +++ b/modules/audio_processing/agc2/speech_probability_buffer.cc @@ -0,0 +1,105 @@ +/* + * Copyright (c) 2022 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/agc2/speech_probability_buffer.h" + +#include + +#include "rtc_base/checks.h" + +namespace webrtc { +namespace { + +constexpr float kActivityThreshold = 0.9f; +constexpr int kNumAnalysisFrames = 100; +// We use 12 in AGC2 adaptive digital, but with a slightly different logic. +constexpr int kTransientWidthThreshold = 7; + +} // namespace + +SpeechProbabilityBuffer::SpeechProbabilityBuffer( + float low_probability_threshold) + : low_probability_threshold_(low_probability_threshold), + probabilities_(kNumAnalysisFrames) { + RTC_DCHECK_GE(low_probability_threshold, 0.0f); + RTC_DCHECK_LE(low_probability_threshold, 1.0f); + RTC_DCHECK(!probabilities_.empty()); +} + +void SpeechProbabilityBuffer::Update(float probability) { + // Remove the oldest entry if the circular buffer is full. + if (buffer_is_full_) { + const float oldest_probability = probabilities_[buffer_index_]; + sum_probabilities_ -= oldest_probability; + } + + // Check for transients. + if (probability <= low_probability_threshold_) { + // Set a probability lower than the threshold to zero. + probability = 0.0f; + + // Check if this has been a transient. + if (num_high_probability_observations_ <= kTransientWidthThreshold) { + RemoveTransient(); + } + num_high_probability_observations_ = 0; + } else if (num_high_probability_observations_ <= kTransientWidthThreshold) { + ++num_high_probability_observations_; + } + + // Update the circular buffer and the current sum. + probabilities_[buffer_index_] = probability; + sum_probabilities_ += probability; + + // Increment the buffer index and check for wrap-around. + if (++buffer_index_ >= kNumAnalysisFrames) { + buffer_index_ = 0; + buffer_is_full_ = true; + } +} + +void SpeechProbabilityBuffer::RemoveTransient() { + // Don't expect to be here if high-activity region is longer than + // `kTransientWidthThreshold` or there has not been any transient. + RTC_DCHECK_LE(num_high_probability_observations_, kTransientWidthThreshold); + + // Replace previously added probabilities with zero. + int index = + (buffer_index_ > 0) ? (buffer_index_ - 1) : (kNumAnalysisFrames - 1); + + while (num_high_probability_observations_-- > 0) { + sum_probabilities_ -= probabilities_[index]; + probabilities_[index] = 0.0f; + + // Update the circular buffer index. + index = (index > 0) ? (index - 1) : (kNumAnalysisFrames - 1); + } +} + +bool SpeechProbabilityBuffer::IsActiveSegment() const { + if (!buffer_is_full_) { + return false; + } + if (sum_probabilities_ < kActivityThreshold * kNumAnalysisFrames) { + return false; + } + return true; +} + +void SpeechProbabilityBuffer::Reset() { + sum_probabilities_ = 0.0f; + + // Empty the circular buffer. + buffer_index_ = 0; + buffer_is_full_ = false; + num_high_probability_observations_ = 0; +} + +} // namespace webrtc diff --git a/modules/audio_processing/agc2/speech_probability_buffer.h b/modules/audio_processing/agc2/speech_probability_buffer.h new file mode 100644 index 0000000000..3056a3eeab --- /dev/null +++ b/modules/audio_processing/agc2/speech_probability_buffer.h @@ -0,0 +1,80 @@ +/* + * Copyright (c) 2022 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AGC2_SPEECH_PROBABILITY_BUFFER_H_ +#define MODULES_AUDIO_PROCESSING_AGC2_SPEECH_PROBABILITY_BUFFER_H_ + +#include + +#include "rtc_base/gtest_prod_util.h" + +namespace webrtc { + +// This class implements a circular buffer that stores speech probabilities +// for a speech segment and estimates speech activity for that segment. +class SpeechProbabilityBuffer { + public: + // Ctor. The value of `low_probability_threshold` is required to be on the + // range [0.0f, 1.0f]. + explicit SpeechProbabilityBuffer(float low_probability_threshold); + ~SpeechProbabilityBuffer() {} + SpeechProbabilityBuffer(const SpeechProbabilityBuffer&) = delete; + SpeechProbabilityBuffer& operator=(const SpeechProbabilityBuffer&) = delete; + + // Adds `probability` in the buffer and computes an updatds sum of the buffer + // probabilities. Value of `probability` is required to be on the range + // [0.0f, 1.0f]. + void Update(float probability); + + // Resets the histogram, forgets the past. + void Reset(); + + // Returns true if the segment is active (a long enough segment with an + // average speech probability above `low_probability_threshold`). + bool IsActiveSegment() const; + + private: + void RemoveTransient(); + + // Use only for testing. + float GetSumProbabilities() const { return sum_probabilities_; } + + FRIEND_TEST_ALL_PREFIXES(SpeechProbabilityBufferTest, + CheckSumAfterInitialization); + FRIEND_TEST_ALL_PREFIXES(SpeechProbabilityBufferTest, CheckSumAfterUpdate); + FRIEND_TEST_ALL_PREFIXES(SpeechProbabilityBufferTest, CheckSumAfterReset); + FRIEND_TEST_ALL_PREFIXES(SpeechProbabilityBufferTest, + CheckSumAfterTransientNotRemoved); + FRIEND_TEST_ALL_PREFIXES(SpeechProbabilityBufferTest, + CheckSumAfterTransientRemoved); + + const float low_probability_threshold_; + + // Sum of probabilities stored in `probabilities_`. Must be updated if + // `probabilities_` is updated. + float sum_probabilities_ = 0.0f; + + // Circular buffer for probabilities. + std::vector probabilities_; + + // Current index of the circular buffer, where the newest data will be written + // to, therefore, pointing to the oldest data if buffer is full. + int buffer_index_ = 0; + + // Indicates if the buffer is full and adding a new value removes the oldest + // value. + int buffer_is_full_ = false; + + int num_high_probability_observations_ = 0; +}; + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AGC2_SPEECH_PROBABILITY_BUFFER_H_ diff --git a/modules/audio_processing/agc2/speech_probability_buffer_unittest.cc b/modules/audio_processing/agc2/speech_probability_buffer_unittest.cc new file mode 100644 index 0000000000..89cc209d9d --- /dev/null +++ b/modules/audio_processing/agc2/speech_probability_buffer_unittest.cc @@ -0,0 +1,346 @@ +/* + * Copyright (c) 2022 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/agc2/speech_probability_buffer.h" + +#include + +#include "test/gtest.h" + +namespace webrtc { +namespace { + +constexpr float kAbsError = 0.001f; +constexpr float kActivityThreshold = 0.9f; +constexpr float kLowProbabilityThreshold = 0.2f; +constexpr int kNumAnalysisFrames = 100; + +} // namespace + +TEST(SpeechProbabilityBufferTest, CheckSumAfterInitialization) { + SpeechProbabilityBuffer buffer(kLowProbabilityThreshold); + + EXPECT_EQ(buffer.GetSumProbabilities(), 0.0f); +} + +TEST(SpeechProbabilityBufferTest, CheckSumAfterUpdate) { + SpeechProbabilityBuffer buffer(kLowProbabilityThreshold); + + buffer.Update(0.7f); + + EXPECT_NEAR(buffer.GetSumProbabilities(), 0.7f, kAbsError); + + buffer.Update(0.6f); + + EXPECT_NEAR(buffer.GetSumProbabilities(), 1.3f, kAbsError); + + for (int i = 0; i < kNumAnalysisFrames - 1; ++i) { + buffer.Update(1.0f); + } + + EXPECT_NEAR(buffer.GetSumProbabilities(), 99.6f, kAbsError); +} + +TEST(SpeechProbabilityBufferTest, CheckSumAfterReset) { + SpeechProbabilityBuffer buffer(kLowProbabilityThreshold); + + buffer.Update(0.7f); + buffer.Update(0.6f); + buffer.Update(0.3f); + + EXPECT_GT(buffer.GetSumProbabilities(), 0.0f); + + buffer.Reset(); + + EXPECT_EQ(buffer.GetSumProbabilities(), 0.0f); +} + +TEST(SpeechProbabilityBufferTest, CheckSumAfterTransientNotRemoved) { + SpeechProbabilityBuffer buffer(kLowProbabilityThreshold); + + buffer.Update(1.0f); + buffer.Update(1.0f); + buffer.Update(1.0f); + buffer.Update(1.0f); + buffer.Update(1.0f); + buffer.Update(1.0f); + buffer.Update(1.0f); + buffer.Update(1.0f); + buffer.Update(1.0f); + + buffer.Update(0.0f); + + EXPECT_NEAR(buffer.GetSumProbabilities(), 9.0f, kAbsError); + + buffer.Update(0.0f); + + EXPECT_NEAR(buffer.GetSumProbabilities(), 9.0f, kAbsError); +} + +TEST(SpeechProbabilityBufferTest, CheckSumAfterTransientRemoved) { + SpeechProbabilityBuffer buffer(kLowProbabilityThreshold); + + buffer.Update(0.0f); + buffer.Update(0.0f); + buffer.Update(0.0f); + buffer.Update(0.0f); + buffer.Update(0.0f); + buffer.Update(0.0f); + buffer.Update(1.0f); + buffer.Update(1.0f); + buffer.Update(1.0f); + + EXPECT_NEAR(buffer.GetSumProbabilities(), 3.0f, kAbsError); + + buffer.Update(0.0f); + + EXPECT_NEAR(buffer.GetSumProbabilities(), 0.0f, kAbsError); +} + +TEST(SpeechProbabilityBufferTest, CheckSegmentIsNotActiveAfterNoUpdates) { + SpeechProbabilityBuffer buffer(kLowProbabilityThreshold); + + EXPECT_FALSE(buffer.IsActiveSegment()); +} + +TEST(SpeechProbabilityBufferTest, CheckSegmentIsActiveChangesFromFalseToTrue) { + SpeechProbabilityBuffer buffer(kLowProbabilityThreshold); + + // Add low probabilities until the buffer is full. That's not enough + // to make `IsActiveSegment()` to return true. + for (int i = 0; i < kNumAnalysisFrames; ++i) { + buffer.Update(0.0f); + } + + EXPECT_FALSE(buffer.IsActiveSegment()); + + // Add high probabilities until `IsActiveSegment()` returns true. + for (int i = 0; i < kActivityThreshold * kNumAnalysisFrames - 1; ++i) { + buffer.Update(1.0f); + } + + EXPECT_FALSE(buffer.IsActiveSegment()); + + buffer.Update(1.0f); + + EXPECT_TRUE(buffer.IsActiveSegment()); +} + +TEST(SpeechProbabilityBufferTest, CheckSegmentIsActiveChangesFromTrueToFalse) { + SpeechProbabilityBuffer buffer(kLowProbabilityThreshold); + + // Add high probabilities until the buffer is full. That's enough to + // make `IsActiveSegment()` to return true. + for (int i = 0; i < kNumAnalysisFrames; ++i) { + buffer.Update(1.0f); + } + + EXPECT_TRUE(buffer.IsActiveSegment()); + + // Add low probabilities until `IsActiveSegment()` returns false. + for (int i = 0; i < (1.0f - kActivityThreshold) * kNumAnalysisFrames - 1; + ++i) { + buffer.Update(0.0f); + } + + EXPECT_TRUE(buffer.IsActiveSegment()); + + buffer.Update(1.0f); + + EXPECT_TRUE(buffer.IsActiveSegment()); +} + +TEST(SpeechProbabilityBufferTest, + CheckSegmentIsActiveAfterUpdatesWithHighProbabilities) { + SpeechProbabilityBuffer buffer(kLowProbabilityThreshold); + + for (int i = 0; i < kNumAnalysisFrames - 1; ++i) { + buffer.Update(1.0f); + } + + EXPECT_FALSE(buffer.IsActiveSegment()); + + buffer.Update(1.0f); + + EXPECT_TRUE(buffer.IsActiveSegment()); +} + +TEST(SpeechProbabilityBufferTest, + CheckSegmentIsNotActiveAfterUpdatesWithLowProbabilities) { + SpeechProbabilityBuffer buffer(kLowProbabilityThreshold); + + for (int i = 0; i < kNumAnalysisFrames - 1; ++i) { + buffer.Update(0.3f); + } + + EXPECT_FALSE(buffer.IsActiveSegment()); + + buffer.Update(0.3f); + + EXPECT_FALSE(buffer.IsActiveSegment()); +} + +TEST(SpeechProbabilityBufferTest, CheckSegmentIsActiveAfterBufferIsFull) { + SpeechProbabilityBuffer buffer(kLowProbabilityThreshold); + + for (int i = 0; i < kNumAnalysisFrames - 1; ++i) { + buffer.Update(1.0f); + } + + EXPECT_FALSE(buffer.IsActiveSegment()); + + buffer.Update(1.0f); + + EXPECT_TRUE(buffer.IsActiveSegment()); + + buffer.Update(1.0f); + + EXPECT_TRUE(buffer.IsActiveSegment()); +} + +TEST(SpeechProbabilityBufferTest, CheckSegmentIsNotActiveAfterBufferIsFull) { + SpeechProbabilityBuffer buffer(kLowProbabilityThreshold); + + for (int i = 0; i < kNumAnalysisFrames - 1; ++i) { + buffer.Update(0.29f); + } + + EXPECT_FALSE(buffer.IsActiveSegment()); + + buffer.Update(0.29f); + + EXPECT_FALSE(buffer.IsActiveSegment()); + + buffer.Update(0.29f); + + EXPECT_FALSE(buffer.IsActiveSegment()); +} + +TEST(SpeechProbabilityBufferTest, CheckSegmentIsNotActiveAfterReset) { + SpeechProbabilityBuffer buffer(kLowProbabilityThreshold); + + for (int i = 0; i < kNumAnalysisFrames; ++i) { + buffer.Update(1.0f); + } + + EXPECT_TRUE(buffer.IsActiveSegment()); + + buffer.Reset(); + + EXPECT_FALSE(buffer.IsActiveSegment()); +} + +TEST(SpeechProbabilityBufferTest, + CheckSegmentIsNotActiveAfterTransientRemovedAfterFewUpdates) { + SpeechProbabilityBuffer buffer(kLowProbabilityThreshold); + + buffer.Update(0.4f); + buffer.Update(0.4f); + buffer.Update(0.0f); + + EXPECT_FALSE(buffer.IsActiveSegment()); +} + +TEST(SpeechProbabilityBufferTest, + CheckSegmentIsActiveAfterTransientNotRemoved) { + SpeechProbabilityBuffer buffer(kLowProbabilityThreshold); + + for (int i = 0; i < kNumAnalysisFrames; ++i) { + buffer.Update(1.0f); + } + + buffer.Update(0.7f); + buffer.Update(0.8f); + buffer.Update(0.9f); + buffer.Update(1.0f); + + EXPECT_TRUE(buffer.IsActiveSegment()); + + buffer.Update(0.0f); + + EXPECT_TRUE(buffer.IsActiveSegment()); + + buffer.Update(0.7f); + + EXPECT_TRUE(buffer.IsActiveSegment()); +} + +TEST(SpeechProbabilityBufferTest, + CheckSegmentIsNotActiveAfterTransientNotRemoved) { + SpeechProbabilityBuffer buffer(kLowProbabilityThreshold); + + for (int i = 0; i < kNumAnalysisFrames; ++i) { + buffer.Update(0.1f); + } + + buffer.Update(0.7f); + buffer.Update(0.8f); + buffer.Update(0.9f); + buffer.Update(1.0f); + + EXPECT_FALSE(buffer.IsActiveSegment()); + + buffer.Update(0.0f); + + EXPECT_FALSE(buffer.IsActiveSegment()); + + buffer.Update(0.7f); + + EXPECT_FALSE(buffer.IsActiveSegment()); +} + +TEST(SpeechProbabilityBufferTest, + CheckSegmentIsNotActiveAfterTransientRemoved) { + SpeechProbabilityBuffer buffer(kLowProbabilityThreshold); + + for (int i = 0; i < kNumAnalysisFrames; ++i) { + buffer.Update(0.1f); + } + + buffer.Update(0.7f); + buffer.Update(0.8f); + buffer.Update(0.9f); + buffer.Update(1.0f); + + EXPECT_FALSE(buffer.IsActiveSegment()); + + buffer.Update(0.0f); + + EXPECT_FALSE(buffer.IsActiveSegment()); + + buffer.Update(0.7f); + + EXPECT_FALSE(buffer.IsActiveSegment()); +} + +TEST(SpeechProbabilityBufferTest, CheckSegmentIsActiveAfterTransientRemoved) { + SpeechProbabilityBuffer buffer(kLowProbabilityThreshold); + + for (int i = 0; i < kNumAnalysisFrames; ++i) { + buffer.Update(1.0f); + } + + buffer.Update(0.7f); + buffer.Update(0.8f); + buffer.Update(0.9f); + buffer.Update(1.0f); + + EXPECT_TRUE(buffer.IsActiveSegment()); + + buffer.Update(0.0f); + + EXPECT_TRUE(buffer.IsActiveSegment()); + + buffer.Update(0.7f); + + EXPECT_TRUE(buffer.IsActiveSegment()); +} + +} // namespace webrtc