Add SpeechProbabilityBuffer

Add a buffer class to store speech probabilities and to estimate speech
activity. Follows the implementation of speech activity computation in
LoudnessHistogram but uses floats for computations.

Bug: webrtc:7494
Change-Id: I6ee72ec52919904ea4e1fbe51d61993aa7813c9f
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/277801
Reviewed-by: Alessio Bazzica <alessiob@webrtc.org>
Commit-Queue: Hanna Silen <silen@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#38309}
This commit is contained in:
Hanna Silen 2022-10-05 18:48:36 +02:00 committed by WebRTC LUCI CQ
parent c57a28c46b
commit 767898c048
4 changed files with 550 additions and 0 deletions

View File

@ -181,6 +181,23 @@ rtc_source_set("gain_map") {
sources = [ "gain_map_internal.h" ]
}
rtc_library("input_volume_controller") {
sources = [
"speech_probability_buffer.cc",
"speech_probability_buffer.h",
]
visibility = [
"..:gain_controller2",
"./*",
]
deps = [
"../../../rtc_base:checks",
"../../../rtc_base:gtest_prod",
]
}
rtc_library("noise_level_estimator") {
sources = [
"noise_level_estimator.cc",
@ -317,6 +334,7 @@ rtc_library("input_volume_controller_unittests") {
"clipping_predictor_evaluator_unittest.cc",
"clipping_predictor_level_buffer_unittest.cc",
"clipping_predictor_unittest.cc",
"speech_probability_buffer_unittest.cc",
]
configs += [ "..:apm_debug_dump" ]
@ -325,6 +343,7 @@ rtc_library("input_volume_controller_unittests") {
":clipping_predictor",
":clipping_predictor_evaluator",
":gain_map",
":input_volume_controller",
"../../../rtc_base:checks",
"../../../rtc_base:random",
"../../../rtc_base:safe_conversions",

View File

@ -0,0 +1,105 @@
/*
* Copyright (c) 2022 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/speech_probability_buffer.h"
#include <algorithm>
#include "rtc_base/checks.h"
namespace webrtc {
namespace {
constexpr float kActivityThreshold = 0.9f;
constexpr int kNumAnalysisFrames = 100;
// We use 12 in AGC2 adaptive digital, but with a slightly different logic.
constexpr int kTransientWidthThreshold = 7;
} // namespace
SpeechProbabilityBuffer::SpeechProbabilityBuffer(
float low_probability_threshold)
: low_probability_threshold_(low_probability_threshold),
probabilities_(kNumAnalysisFrames) {
RTC_DCHECK_GE(low_probability_threshold, 0.0f);
RTC_DCHECK_LE(low_probability_threshold, 1.0f);
RTC_DCHECK(!probabilities_.empty());
}
void SpeechProbabilityBuffer::Update(float probability) {
// Remove the oldest entry if the circular buffer is full.
if (buffer_is_full_) {
const float oldest_probability = probabilities_[buffer_index_];
sum_probabilities_ -= oldest_probability;
}
// Check for transients.
if (probability <= low_probability_threshold_) {
// Set a probability lower than the threshold to zero.
probability = 0.0f;
// Check if this has been a transient.
if (num_high_probability_observations_ <= kTransientWidthThreshold) {
RemoveTransient();
}
num_high_probability_observations_ = 0;
} else if (num_high_probability_observations_ <= kTransientWidthThreshold) {
++num_high_probability_observations_;
}
// Update the circular buffer and the current sum.
probabilities_[buffer_index_] = probability;
sum_probabilities_ += probability;
// Increment the buffer index and check for wrap-around.
if (++buffer_index_ >= kNumAnalysisFrames) {
buffer_index_ = 0;
buffer_is_full_ = true;
}
}
void SpeechProbabilityBuffer::RemoveTransient() {
// Don't expect to be here if high-activity region is longer than
// `kTransientWidthThreshold` or there has not been any transient.
RTC_DCHECK_LE(num_high_probability_observations_, kTransientWidthThreshold);
// Replace previously added probabilities with zero.
int index =
(buffer_index_ > 0) ? (buffer_index_ - 1) : (kNumAnalysisFrames - 1);
while (num_high_probability_observations_-- > 0) {
sum_probabilities_ -= probabilities_[index];
probabilities_[index] = 0.0f;
// Update the circular buffer index.
index = (index > 0) ? (index - 1) : (kNumAnalysisFrames - 1);
}
}
bool SpeechProbabilityBuffer::IsActiveSegment() const {
if (!buffer_is_full_) {
return false;
}
if (sum_probabilities_ < kActivityThreshold * kNumAnalysisFrames) {
return false;
}
return true;
}
void SpeechProbabilityBuffer::Reset() {
sum_probabilities_ = 0.0f;
// Empty the circular buffer.
buffer_index_ = 0;
buffer_is_full_ = false;
num_high_probability_observations_ = 0;
}
} // namespace webrtc

View File

@ -0,0 +1,80 @@
/*
* Copyright (c) 2022 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_SPEECH_PROBABILITY_BUFFER_H_
#define MODULES_AUDIO_PROCESSING_AGC2_SPEECH_PROBABILITY_BUFFER_H_
#include <vector>
#include "rtc_base/gtest_prod_util.h"
namespace webrtc {
// This class implements a circular buffer that stores speech probabilities
// for a speech segment and estimates speech activity for that segment.
class SpeechProbabilityBuffer {
public:
// Ctor. The value of `low_probability_threshold` is required to be on the
// range [0.0f, 1.0f].
explicit SpeechProbabilityBuffer(float low_probability_threshold);
~SpeechProbabilityBuffer() {}
SpeechProbabilityBuffer(const SpeechProbabilityBuffer&) = delete;
SpeechProbabilityBuffer& operator=(const SpeechProbabilityBuffer&) = delete;
// Adds `probability` in the buffer and computes an updatds sum of the buffer
// probabilities. Value of `probability` is required to be on the range
// [0.0f, 1.0f].
void Update(float probability);
// Resets the histogram, forgets the past.
void Reset();
// Returns true if the segment is active (a long enough segment with an
// average speech probability above `low_probability_threshold`).
bool IsActiveSegment() const;
private:
void RemoveTransient();
// Use only for testing.
float GetSumProbabilities() const { return sum_probabilities_; }
FRIEND_TEST_ALL_PREFIXES(SpeechProbabilityBufferTest,
CheckSumAfterInitialization);
FRIEND_TEST_ALL_PREFIXES(SpeechProbabilityBufferTest, CheckSumAfterUpdate);
FRIEND_TEST_ALL_PREFIXES(SpeechProbabilityBufferTest, CheckSumAfterReset);
FRIEND_TEST_ALL_PREFIXES(SpeechProbabilityBufferTest,
CheckSumAfterTransientNotRemoved);
FRIEND_TEST_ALL_PREFIXES(SpeechProbabilityBufferTest,
CheckSumAfterTransientRemoved);
const float low_probability_threshold_;
// Sum of probabilities stored in `probabilities_`. Must be updated if
// `probabilities_` is updated.
float sum_probabilities_ = 0.0f;
// Circular buffer for probabilities.
std::vector<float> probabilities_;
// Current index of the circular buffer, where the newest data will be written
// to, therefore, pointing to the oldest data if buffer is full.
int buffer_index_ = 0;
// Indicates if the buffer is full and adding a new value removes the oldest
// value.
int buffer_is_full_ = false;
int num_high_probability_observations_ = 0;
};
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_SPEECH_PROBABILITY_BUFFER_H_

View File

@ -0,0 +1,346 @@
/*
* Copyright (c) 2022 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/speech_probability_buffer.h"
#include <algorithm>
#include "test/gtest.h"
namespace webrtc {
namespace {
constexpr float kAbsError = 0.001f;
constexpr float kActivityThreshold = 0.9f;
constexpr float kLowProbabilityThreshold = 0.2f;
constexpr int kNumAnalysisFrames = 100;
} // namespace
TEST(SpeechProbabilityBufferTest, CheckSumAfterInitialization) {
SpeechProbabilityBuffer buffer(kLowProbabilityThreshold);
EXPECT_EQ(buffer.GetSumProbabilities(), 0.0f);
}
TEST(SpeechProbabilityBufferTest, CheckSumAfterUpdate) {
SpeechProbabilityBuffer buffer(kLowProbabilityThreshold);
buffer.Update(0.7f);
EXPECT_NEAR(buffer.GetSumProbabilities(), 0.7f, kAbsError);
buffer.Update(0.6f);
EXPECT_NEAR(buffer.GetSumProbabilities(), 1.3f, kAbsError);
for (int i = 0; i < kNumAnalysisFrames - 1; ++i) {
buffer.Update(1.0f);
}
EXPECT_NEAR(buffer.GetSumProbabilities(), 99.6f, kAbsError);
}
TEST(SpeechProbabilityBufferTest, CheckSumAfterReset) {
SpeechProbabilityBuffer buffer(kLowProbabilityThreshold);
buffer.Update(0.7f);
buffer.Update(0.6f);
buffer.Update(0.3f);
EXPECT_GT(buffer.GetSumProbabilities(), 0.0f);
buffer.Reset();
EXPECT_EQ(buffer.GetSumProbabilities(), 0.0f);
}
TEST(SpeechProbabilityBufferTest, CheckSumAfterTransientNotRemoved) {
SpeechProbabilityBuffer buffer(kLowProbabilityThreshold);
buffer.Update(1.0f);
buffer.Update(1.0f);
buffer.Update(1.0f);
buffer.Update(1.0f);
buffer.Update(1.0f);
buffer.Update(1.0f);
buffer.Update(1.0f);
buffer.Update(1.0f);
buffer.Update(1.0f);
buffer.Update(0.0f);
EXPECT_NEAR(buffer.GetSumProbabilities(), 9.0f, kAbsError);
buffer.Update(0.0f);
EXPECT_NEAR(buffer.GetSumProbabilities(), 9.0f, kAbsError);
}
TEST(SpeechProbabilityBufferTest, CheckSumAfterTransientRemoved) {
SpeechProbabilityBuffer buffer(kLowProbabilityThreshold);
buffer.Update(0.0f);
buffer.Update(0.0f);
buffer.Update(0.0f);
buffer.Update(0.0f);
buffer.Update(0.0f);
buffer.Update(0.0f);
buffer.Update(1.0f);
buffer.Update(1.0f);
buffer.Update(1.0f);
EXPECT_NEAR(buffer.GetSumProbabilities(), 3.0f, kAbsError);
buffer.Update(0.0f);
EXPECT_NEAR(buffer.GetSumProbabilities(), 0.0f, kAbsError);
}
TEST(SpeechProbabilityBufferTest, CheckSegmentIsNotActiveAfterNoUpdates) {
SpeechProbabilityBuffer buffer(kLowProbabilityThreshold);
EXPECT_FALSE(buffer.IsActiveSegment());
}
TEST(SpeechProbabilityBufferTest, CheckSegmentIsActiveChangesFromFalseToTrue) {
SpeechProbabilityBuffer buffer(kLowProbabilityThreshold);
// Add low probabilities until the buffer is full. That's not enough
// to make `IsActiveSegment()` to return true.
for (int i = 0; i < kNumAnalysisFrames; ++i) {
buffer.Update(0.0f);
}
EXPECT_FALSE(buffer.IsActiveSegment());
// Add high probabilities until `IsActiveSegment()` returns true.
for (int i = 0; i < kActivityThreshold * kNumAnalysisFrames - 1; ++i) {
buffer.Update(1.0f);
}
EXPECT_FALSE(buffer.IsActiveSegment());
buffer.Update(1.0f);
EXPECT_TRUE(buffer.IsActiveSegment());
}
TEST(SpeechProbabilityBufferTest, CheckSegmentIsActiveChangesFromTrueToFalse) {
SpeechProbabilityBuffer buffer(kLowProbabilityThreshold);
// Add high probabilities until the buffer is full. That's enough to
// make `IsActiveSegment()` to return true.
for (int i = 0; i < kNumAnalysisFrames; ++i) {
buffer.Update(1.0f);
}
EXPECT_TRUE(buffer.IsActiveSegment());
// Add low probabilities until `IsActiveSegment()` returns false.
for (int i = 0; i < (1.0f - kActivityThreshold) * kNumAnalysisFrames - 1;
++i) {
buffer.Update(0.0f);
}
EXPECT_TRUE(buffer.IsActiveSegment());
buffer.Update(1.0f);
EXPECT_TRUE(buffer.IsActiveSegment());
}
TEST(SpeechProbabilityBufferTest,
CheckSegmentIsActiveAfterUpdatesWithHighProbabilities) {
SpeechProbabilityBuffer buffer(kLowProbabilityThreshold);
for (int i = 0; i < kNumAnalysisFrames - 1; ++i) {
buffer.Update(1.0f);
}
EXPECT_FALSE(buffer.IsActiveSegment());
buffer.Update(1.0f);
EXPECT_TRUE(buffer.IsActiveSegment());
}
TEST(SpeechProbabilityBufferTest,
CheckSegmentIsNotActiveAfterUpdatesWithLowProbabilities) {
SpeechProbabilityBuffer buffer(kLowProbabilityThreshold);
for (int i = 0; i < kNumAnalysisFrames - 1; ++i) {
buffer.Update(0.3f);
}
EXPECT_FALSE(buffer.IsActiveSegment());
buffer.Update(0.3f);
EXPECT_FALSE(buffer.IsActiveSegment());
}
TEST(SpeechProbabilityBufferTest, CheckSegmentIsActiveAfterBufferIsFull) {
SpeechProbabilityBuffer buffer(kLowProbabilityThreshold);
for (int i = 0; i < kNumAnalysisFrames - 1; ++i) {
buffer.Update(1.0f);
}
EXPECT_FALSE(buffer.IsActiveSegment());
buffer.Update(1.0f);
EXPECT_TRUE(buffer.IsActiveSegment());
buffer.Update(1.0f);
EXPECT_TRUE(buffer.IsActiveSegment());
}
TEST(SpeechProbabilityBufferTest, CheckSegmentIsNotActiveAfterBufferIsFull) {
SpeechProbabilityBuffer buffer(kLowProbabilityThreshold);
for (int i = 0; i < kNumAnalysisFrames - 1; ++i) {
buffer.Update(0.29f);
}
EXPECT_FALSE(buffer.IsActiveSegment());
buffer.Update(0.29f);
EXPECT_FALSE(buffer.IsActiveSegment());
buffer.Update(0.29f);
EXPECT_FALSE(buffer.IsActiveSegment());
}
TEST(SpeechProbabilityBufferTest, CheckSegmentIsNotActiveAfterReset) {
SpeechProbabilityBuffer buffer(kLowProbabilityThreshold);
for (int i = 0; i < kNumAnalysisFrames; ++i) {
buffer.Update(1.0f);
}
EXPECT_TRUE(buffer.IsActiveSegment());
buffer.Reset();
EXPECT_FALSE(buffer.IsActiveSegment());
}
TEST(SpeechProbabilityBufferTest,
CheckSegmentIsNotActiveAfterTransientRemovedAfterFewUpdates) {
SpeechProbabilityBuffer buffer(kLowProbabilityThreshold);
buffer.Update(0.4f);
buffer.Update(0.4f);
buffer.Update(0.0f);
EXPECT_FALSE(buffer.IsActiveSegment());
}
TEST(SpeechProbabilityBufferTest,
CheckSegmentIsActiveAfterTransientNotRemoved) {
SpeechProbabilityBuffer buffer(kLowProbabilityThreshold);
for (int i = 0; i < kNumAnalysisFrames; ++i) {
buffer.Update(1.0f);
}
buffer.Update(0.7f);
buffer.Update(0.8f);
buffer.Update(0.9f);
buffer.Update(1.0f);
EXPECT_TRUE(buffer.IsActiveSegment());
buffer.Update(0.0f);
EXPECT_TRUE(buffer.IsActiveSegment());
buffer.Update(0.7f);
EXPECT_TRUE(buffer.IsActiveSegment());
}
TEST(SpeechProbabilityBufferTest,
CheckSegmentIsNotActiveAfterTransientNotRemoved) {
SpeechProbabilityBuffer buffer(kLowProbabilityThreshold);
for (int i = 0; i < kNumAnalysisFrames; ++i) {
buffer.Update(0.1f);
}
buffer.Update(0.7f);
buffer.Update(0.8f);
buffer.Update(0.9f);
buffer.Update(1.0f);
EXPECT_FALSE(buffer.IsActiveSegment());
buffer.Update(0.0f);
EXPECT_FALSE(buffer.IsActiveSegment());
buffer.Update(0.7f);
EXPECT_FALSE(buffer.IsActiveSegment());
}
TEST(SpeechProbabilityBufferTest,
CheckSegmentIsNotActiveAfterTransientRemoved) {
SpeechProbabilityBuffer buffer(kLowProbabilityThreshold);
for (int i = 0; i < kNumAnalysisFrames; ++i) {
buffer.Update(0.1f);
}
buffer.Update(0.7f);
buffer.Update(0.8f);
buffer.Update(0.9f);
buffer.Update(1.0f);
EXPECT_FALSE(buffer.IsActiveSegment());
buffer.Update(0.0f);
EXPECT_FALSE(buffer.IsActiveSegment());
buffer.Update(0.7f);
EXPECT_FALSE(buffer.IsActiveSegment());
}
TEST(SpeechProbabilityBufferTest, CheckSegmentIsActiveAfterTransientRemoved) {
SpeechProbabilityBuffer buffer(kLowProbabilityThreshold);
for (int i = 0; i < kNumAnalysisFrames; ++i) {
buffer.Update(1.0f);
}
buffer.Update(0.7f);
buffer.Update(0.8f);
buffer.Update(0.9f);
buffer.Update(1.0f);
EXPECT_TRUE(buffer.IsActiveSegment());
buffer.Update(0.0f);
EXPECT_TRUE(buffer.IsActiveSegment());
buffer.Update(0.7f);
EXPECT_TRUE(buffer.IsActiveSegment());
}
} // namespace webrtc