AGC2: VadWithLevel -> VoiceActivityDetectorWrapper 2/2

Internal refactoring of AGC2 to decouple the VAD, its wrapper and the
peak and RMS level measurements.

Bit exactness verified with audioproc_f on a collection of AEC dumps
and Wav files (42 recordings in total).

Bug: webrtc:7494
Change-Id: Ib560f1fcaa601557f4f30e47025c69e91b1b62e0
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/234524
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Reviewed-by: Hanna Silen <silen@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#35208}
This commit is contained in:
Alessio Bazzica 2021-10-14 12:15:20 +02:00 committed by WebRTC LUCI CQ
parent 54f377308f
commit 8dbdf5e3bf
9 changed files with 226 additions and 183 deletions

View File

@ -280,6 +280,7 @@ rtc_library("vad_wrapper_unittests") {
":common",
":vad_wrapper",
"..:audio_frame_view",
"../../../rtc_base:checks",
"../../../rtc_base:gunit_helpers",
"../../../rtc_base:safe_compare",
"../../../test:test_support",

View File

@ -36,6 +36,24 @@ AvailableCpuFeatures GetAllowedCpuFeatures() {
return features;
}
// Peak and RMS audio levels in dBFS.
struct AudioLevels {
// Peak level (dBFS).
float peak_dbfs;
// Root mean square level (dBFS).
float rms_dbfs;
};
// Measures the peak and the RMS levels of the first channel in `frame` and
// returns both of them converted to dBFS.
AudioLevels ComputeAudioLevels(AudioFrameView<float> frame) {
  float max_abs_sample = 0.0f;
  float sum_of_squares = 0.0f;
  for (const float sample : frame.channel(0)) {
    // Keep the argument order: it defines which value wins on ties and NaNs.
    max_abs_sample = std::max(std::fabs(sample), max_abs_sample);
    sum_of_squares += sample * sample;
  }
  const float mean_square = sum_of_squares / frame.samples_per_channel();
  AudioLevels levels;
  levels.peak_dbfs = FloatS16ToDbfs(max_abs_sample);
  levels.rms_dbfs = FloatS16ToDbfs(std::sqrt(mean_square));
  return levels;
}
} // namespace
AdaptiveAgc::AdaptiveAgc(
@ -62,16 +80,17 @@ void AdaptiveAgc::Initialize(int sample_rate_hz, int num_channels) {
}
void AdaptiveAgc::Process(AudioFrameView<float> frame, float limiter_envelope) {
AudioLevels levels = ComputeAudioLevels(frame);
AdaptiveDigitalGainApplier::FrameInfo info;
VadLevelAnalyzer::Result vad_result = vad_.AnalyzeFrame(frame);
info.speech_probability = vad_result.speech_probability;
apm_data_dumper_->DumpRaw("agc2_speech_probability",
vad_result.speech_probability);
apm_data_dumper_->DumpRaw("agc2_input_rms_dbfs", vad_result.rms_dbfs);
apm_data_dumper_->DumpRaw("agc2_input_peak_dbfs", vad_result.peak_dbfs);
info.speech_probability = vad_.Analyze(frame);
apm_data_dumper_->DumpRaw("agc2_speech_probability", info.speech_probability);
apm_data_dumper_->DumpRaw("agc2_input_rms_dbfs", levels.rms_dbfs);
apm_data_dumper_->DumpRaw("agc2_input_peak_dbfs", levels.peak_dbfs);
speech_level_estimator_.Update(vad_result);
speech_level_estimator_.Update(levels.rms_dbfs, levels.peak_dbfs,
info.speech_probability);
info.speech_level_dbfs = speech_level_estimator_.level_dbfs();
info.speech_level_reliable = speech_level_estimator_.IsConfident();
apm_data_dumper_->DumpRaw("agc2_speech_level_dbfs", info.speech_level_dbfs);
@ -81,7 +100,7 @@ void AdaptiveAgc::Process(AudioFrameView<float> frame, float limiter_envelope) {
info.noise_rms_dbfs = noise_level_estimator_->Analyze(frame);
apm_data_dumper_->DumpRaw("agc2_noise_rms_dbfs", info.noise_rms_dbfs);
saturation_protector_->Analyze(info.speech_probability, vad_result.peak_dbfs,
saturation_protector_->Analyze(info.speech_probability, levels.peak_dbfs,
info.speech_level_dbfs);
info.headroom_db = saturation_protector_->HeadroomDb();
apm_data_dumper_->DumpRaw("agc2_headroom_db", info.headroom_db);

View File

@ -47,7 +47,7 @@ class AdaptiveAgc {
private:
AdaptiveModeLevelEstimator speech_level_estimator_;
VadLevelAnalyzer vad_;
VoiceActivityDetectorWrapper vad_;
AdaptiveDigitalGainApplier gain_controller_;
ApmDataDumper* const apm_data_dumper_;
std::unique_ptr<NoiseLevelEstimator> noise_level_estimator_;

View File

@ -57,15 +57,16 @@ AdaptiveModeLevelEstimator::AdaptiveModeLevelEstimator(
Reset();
}
void AdaptiveModeLevelEstimator::Update(
const VadLevelAnalyzer::Result& vad_level) {
RTC_DCHECK_GT(vad_level.rms_dbfs, -150.f);
RTC_DCHECK_LT(vad_level.rms_dbfs, 50.f);
RTC_DCHECK_GT(vad_level.peak_dbfs, -150.f);
RTC_DCHECK_LT(vad_level.peak_dbfs, 50.f);
RTC_DCHECK_GE(vad_level.speech_probability, 0.f);
RTC_DCHECK_LE(vad_level.speech_probability, 1.f);
if (vad_level.speech_probability < kVadConfidenceThreshold) {
void AdaptiveModeLevelEstimator::Update(float rms_dbfs,
float peak_dbfs,
float speech_probability) {
RTC_DCHECK_GT(rms_dbfs, -150.0f);
RTC_DCHECK_LT(rms_dbfs, 50.0f);
RTC_DCHECK_GT(peak_dbfs, -150.0f);
RTC_DCHECK_LT(peak_dbfs, 50.0f);
RTC_DCHECK_GE(speech_probability, 0.0f);
RTC_DCHECK_LE(speech_probability, 1.0f);
if (speech_probability < kVadConfidenceThreshold) {
// Not a speech frame.
if (adjacent_speech_frames_threshold_ > 1) {
// When two or more adjacent speech frames are required in order to update
@ -93,14 +94,14 @@ void AdaptiveModeLevelEstimator::Update(
preliminary_state_.time_to_confidence_ms -= kFrameDurationMs;
}
// Weighted average of levels with speech probability as weight.
RTC_DCHECK_GT(vad_level.speech_probability, 0.f);
const float leak_factor = buffer_is_full ? kLevelEstimatorLeakFactor : 1.f;
RTC_DCHECK_GT(speech_probability, 0.0f);
const float leak_factor = buffer_is_full ? kLevelEstimatorLeakFactor : 1.0f;
preliminary_state_.level_dbfs.numerator =
preliminary_state_.level_dbfs.numerator * leak_factor +
vad_level.rms_dbfs * vad_level.speech_probability;
rms_dbfs * speech_probability;
preliminary_state_.level_dbfs.denominator =
preliminary_state_.level_dbfs.denominator * leak_factor +
vad_level.speech_probability;
speech_probability;
const float level_dbfs = preliminary_state_.level_dbfs.GetRatio();

View File

@ -33,7 +33,7 @@ class AdaptiveModeLevelEstimator {
delete;
// Updates the level estimation.
void Update(const VadLevelAnalyzer::Result& vad_data);
void Update(float rms_dbfs, float peak_dbfs, float speech_probability);
// Returns the estimated speech plus noise level.
float level_dbfs() const { return level_dbfs_; }
// Returns true if the estimator is confident in its current estimate.

View File

@ -33,10 +33,12 @@ constexpr float kConvergenceSpeedTestsLevelTolerance = 0.5f;
// Provides the `rms_dbfs`, `peak_dbfs` and `speech_probability` values
// `num_iterations` times to `level_estimator`.
void RunOnConstantLevel(int num_iterations,
const VadLevelAnalyzer::Result& vad_level,
float rms_dbfs,
float peak_dbfs,
float speech_probability,
AdaptiveModeLevelEstimator& level_estimator) {
for (int i = 0; i < num_iterations; ++i) {
level_estimator.Update(vad_level);
level_estimator.Update(rms_dbfs, peak_dbfs, speech_probability);
}
}
@ -47,6 +49,10 @@ constexpr AdaptiveDigitalConfig GetAdaptiveDigitalConfig(
return config;
}
// Speech probabilities used to drive the level estimator in the tests.
// Lowest possible speech probability.
constexpr float kNoSpeechProbability = 0.0f;
// Speech probability below `kVadConfidenceThreshold` (half of it).
constexpr float kLowSpeechProbability = kVadConfidenceThreshold / 2.0f;
// Highest possible speech probability.
constexpr float kMaxSpeechProbability = 1.0f;
// Level estimator with data dumper.
struct TestLevelEstimator {
explicit TestLevelEstimator(int adjacent_speech_frames_threshold)
@ -55,36 +61,31 @@ struct TestLevelEstimator {
&data_dumper,
GetAdaptiveDigitalConfig(adjacent_speech_frames_threshold))),
initial_speech_level_dbfs(estimator->level_dbfs()),
vad_level_rms(initial_speech_level_dbfs / 2.0f),
vad_level_peak(initial_speech_level_dbfs / 3.0f),
vad_data_speech(
{/*speech_probability=*/1.0f, vad_level_rms, vad_level_peak}),
vad_data_non_speech(
{/*speech_probability=*/kVadConfidenceThreshold / 2.0f,
vad_level_rms, vad_level_peak}) {
RTC_DCHECK_LT(vad_level_rms, vad_level_peak);
RTC_DCHECK_LT(initial_speech_level_dbfs, vad_level_rms);
RTC_DCHECK_GT(vad_level_rms - initial_speech_level_dbfs, 5.0f)
<< "Adjust `vad_level_rms` so that the difference from the initial "
level_rms_dbfs(initial_speech_level_dbfs / 2.0f),
level_peak_dbfs(initial_speech_level_dbfs / 3.0f) {
RTC_DCHECK_LT(level_rms_dbfs, level_peak_dbfs);
RTC_DCHECK_LT(initial_speech_level_dbfs, level_rms_dbfs);
RTC_DCHECK_GT(level_rms_dbfs - initial_speech_level_dbfs, 5.0f)
<< "Adjust `level_rms_dbfs` so that the difference from the initial "
"level is wide enough for the tests";
}
ApmDataDumper data_dumper;
std::unique_ptr<AdaptiveModeLevelEstimator> estimator;
const float initial_speech_level_dbfs;
const float vad_level_rms;
const float vad_level_peak;
const VadLevelAnalyzer::Result vad_data_speech;
const VadLevelAnalyzer::Result vad_data_non_speech;
const float level_rms_dbfs;
const float level_peak_dbfs;
};
// Checks that the level estimator converges to a constant input speech level.
TEST(GainController2AdaptiveModeLevelEstimator, LevelStabilizes) {
TestLevelEstimator level_estimator(/*adjacent_speech_frames_threshold=*/1);
RunOnConstantLevel(/*num_iterations=*/kNumFramesToConfidence,
level_estimator.vad_data_speech,
level_estimator.level_rms_dbfs,
level_estimator.level_peak_dbfs, kMaxSpeechProbability,
*level_estimator.estimator);
const float estimated_level_dbfs = level_estimator.estimator->level_dbfs();
RunOnConstantLevel(/*num_iterations=*/1, level_estimator.vad_data_speech,
RunOnConstantLevel(/*num_iterations=*/1, level_estimator.level_rms_dbfs,
level_estimator.level_peak_dbfs, kMaxSpeechProbability,
*level_estimator.estimator);
EXPECT_NEAR(level_estimator.estimator->level_dbfs(), estimated_level_dbfs,
0.1f);
@ -95,7 +96,8 @@ TEST(GainController2AdaptiveModeLevelEstimator, LevelStabilizes) {
TEST(GainController2AdaptiveModeLevelEstimator, IsNotConfident) {
TestLevelEstimator level_estimator(/*adjacent_speech_frames_threshold=*/1);
RunOnConstantLevel(/*num_iterations=*/kNumFramesToConfidence / 2,
level_estimator.vad_data_speech,
level_estimator.level_rms_dbfs,
level_estimator.level_peak_dbfs, kMaxSpeechProbability,
*level_estimator.estimator);
EXPECT_FALSE(level_estimator.estimator->IsConfident());
}
@ -105,7 +107,8 @@ TEST(GainController2AdaptiveModeLevelEstimator, IsNotConfident) {
TEST(GainController2AdaptiveModeLevelEstimator, IsConfident) {
TestLevelEstimator level_estimator(/*adjacent_speech_frames_threshold=*/1);
RunOnConstantLevel(/*num_iterations=*/kNumFramesToConfidence,
level_estimator.vad_data_speech,
level_estimator.level_rms_dbfs,
level_estimator.level_peak_dbfs, kMaxSpeechProbability,
*level_estimator.estimator);
EXPECT_TRUE(level_estimator.estimator->IsConfident());
}
@ -117,15 +120,14 @@ TEST(GainController2AdaptiveModeLevelEstimator,
TestLevelEstimator level_estimator(/*adjacent_speech_frames_threshold=*/1);
// Simulate speech.
RunOnConstantLevel(/*num_iterations=*/kNumFramesToConfidence,
level_estimator.vad_data_speech,
level_estimator.level_rms_dbfs,
level_estimator.level_peak_dbfs, kMaxSpeechProbability,
*level_estimator.estimator);
const float estimated_level_dbfs = level_estimator.estimator->level_dbfs();
// Simulate full-scale non-speech.
RunOnConstantLevel(/*num_iterations=*/kNumFramesToConfidence,
VadLevelAnalyzer::Result{/*speech_probability=*/0.0f,
/*rms_dbfs=*/0.0f,
/*peak_dbfs=*/0.0f},
*level_estimator.estimator);
/*rms_dbfs=*/0.0f, /*peak_dbfs=*/0.0f,
kNoSpeechProbability, *level_estimator.estimator);
// No estimated level change is expected.
EXPECT_FLOAT_EQ(level_estimator.estimator->level_dbfs(),
estimated_level_dbfs);
@ -136,10 +138,11 @@ TEST(GainController2AdaptiveModeLevelEstimator,
ConvergenceSpeedBeforeConfidence) {
TestLevelEstimator level_estimator(/*adjacent_speech_frames_threshold=*/1);
RunOnConstantLevel(/*num_iterations=*/kNumFramesToConfidence,
level_estimator.vad_data_speech,
level_estimator.level_rms_dbfs,
level_estimator.level_peak_dbfs, kMaxSpeechProbability,
*level_estimator.estimator);
EXPECT_NEAR(level_estimator.estimator->level_dbfs(),
level_estimator.vad_data_speech.rms_dbfs,
level_estimator.level_rms_dbfs,
kConvergenceSpeedTestsLevelTolerance);
}
@ -150,11 +153,9 @@ TEST(GainController2AdaptiveModeLevelEstimator,
// Reach confidence using the initial level estimate.
RunOnConstantLevel(
/*num_iterations=*/kNumFramesToConfidence,
VadLevelAnalyzer::Result{
/*speech_probability=*/1.0f,
/*rms_dbfs=*/level_estimator.initial_speech_level_dbfs,
/*peak_dbfs=*/level_estimator.initial_speech_level_dbfs + 6.0f},
*level_estimator.estimator);
/*rms_dbfs=*/level_estimator.initial_speech_level_dbfs,
/*peak_dbfs=*/level_estimator.initial_speech_level_dbfs + 6.0f,
kMaxSpeechProbability, *level_estimator.estimator);
// No estimate change should occur, but confidence is achieved.
ASSERT_FLOAT_EQ(level_estimator.estimator->level_dbfs(),
level_estimator.initial_speech_level_dbfs);
@ -165,9 +166,10 @@ TEST(GainController2AdaptiveModeLevelEstimator,
kConvergenceTimeAfterConfidenceNumFrames > kNumFramesToConfidence, "");
RunOnConstantLevel(
/*num_iterations=*/kConvergenceTimeAfterConfidenceNumFrames,
level_estimator.vad_data_speech, *level_estimator.estimator);
level_estimator.level_rms_dbfs, level_estimator.level_peak_dbfs,
kMaxSpeechProbability, *level_estimator.estimator);
EXPECT_NEAR(level_estimator.estimator->level_dbfs(),
level_estimator.vad_data_speech.rms_dbfs,
level_estimator.level_rms_dbfs,
kConvergenceSpeedTestsLevelTolerance);
}
@ -181,22 +183,28 @@ TEST_P(AdaptiveModeLevelEstimatorParametrization,
DoNotAdaptToShortSpeechSegments) {
TestLevelEstimator level_estimator(adjacent_speech_frames_threshold());
const float initial_level = level_estimator.estimator->level_dbfs();
ASSERT_LT(initial_level, level_estimator.vad_data_speech.peak_dbfs);
ASSERT_LT(initial_level, level_estimator.level_peak_dbfs);
for (int i = 0; i < adjacent_speech_frames_threshold() - 1; ++i) {
SCOPED_TRACE(i);
level_estimator.estimator->Update(level_estimator.vad_data_speech);
level_estimator.estimator->Update(level_estimator.level_rms_dbfs,
level_estimator.level_peak_dbfs,
kMaxSpeechProbability);
EXPECT_EQ(initial_level, level_estimator.estimator->level_dbfs());
}
level_estimator.estimator->Update(level_estimator.vad_data_non_speech);
level_estimator.estimator->Update(level_estimator.level_rms_dbfs,
level_estimator.level_peak_dbfs,
kLowSpeechProbability);
EXPECT_EQ(initial_level, level_estimator.estimator->level_dbfs());
}
TEST_P(AdaptiveModeLevelEstimatorParametrization, AdaptToEnoughSpeechSegments) {
TestLevelEstimator level_estimator(adjacent_speech_frames_threshold());
const float initial_level = level_estimator.estimator->level_dbfs();
ASSERT_LT(initial_level, level_estimator.vad_data_speech.peak_dbfs);
ASSERT_LT(initial_level, level_estimator.level_peak_dbfs);
for (int i = 0; i < adjacent_speech_frames_threshold(); ++i) {
level_estimator.estimator->Update(level_estimator.vad_data_speech);
level_estimator.estimator->Update(level_estimator.level_rms_dbfs,
level_estimator.level_peak_dbfs,
kMaxSpeechProbability);
}
EXPECT_LT(initial_level, level_estimator.estimator->level_dbfs());
}

View File

@ -10,13 +10,10 @@
#include "modules/audio_processing/agc2/vad_wrapper.h"
#include <algorithm>
#include <array>
#include <cmath>
#include <utility>
#include "api/array_view.h"
#include "common_audio/include/audio_util.h"
#include "common_audio/resampler/include/push_resampler.h"
#include "modules/audio_processing/agc2/agc2_common.h"
#include "modules/audio_processing/agc2/rnn_vad/common.h"
@ -27,82 +24,72 @@
namespace webrtc {
namespace {
using VoiceActivityDetector = VadLevelAnalyzer::VoiceActivityDetector;
constexpr int kNumFramesPerSecond = 100;
// Default mono VAD based on the RNN VAD. Analyzes 10 ms frames sampled at
// 24 kHz; resampling is handled by `VoiceActivityDetectorWrapper`.
class Vad : public VoiceActivityDetector {
class MonoVadImpl : public VoiceActivityDetectorWrapper::MonoVad {
public:
explicit Vad(const AvailableCpuFeatures& cpu_features)
explicit MonoVadImpl(const AvailableCpuFeatures& cpu_features)
: features_extractor_(cpu_features), rnn_vad_(cpu_features) {}
Vad(const Vad&) = delete;
Vad& operator=(const Vad&) = delete;
~Vad() = default;
MonoVadImpl(const MonoVadImpl&) = delete;
MonoVadImpl& operator=(const MonoVadImpl&) = delete;
~MonoVadImpl() = default;
int SampleRateHz() const override { return rnn_vad::kSampleRate24kHz; }
void Reset() override { rnn_vad_.Reset(); }
float ComputeProbability(AudioFrameView<const float> frame) override {
// The source number of channels is 1, because we always use the 1st
// channel.
resampler_.InitializeIfNeeded(
/*sample_rate_hz=*/static_cast<int>(frame.samples_per_channel() * 100),
rnn_vad::kSampleRate24kHz,
/*num_channels=*/1);
std::array<float, rnn_vad::kFrameSize10ms24kHz> work_frame;
// Feed the 1st channel to the resampler.
resampler_.Resample(frame.channel(0).data(), frame.samples_per_channel(),
work_frame.data(), rnn_vad::kFrameSize10ms24kHz);
float Analyze(rtc::ArrayView<const float> frame) override {
RTC_DCHECK_EQ(frame.size(), rnn_vad::kFrameSize10ms24kHz);
std::array<float, rnn_vad::kFeatureVectorSize> feature_vector;
const bool is_silence = features_extractor_.CheckSilenceComputeFeatures(
work_frame, feature_vector);
/*samples=*/{frame.data(), rnn_vad::kFrameSize10ms24kHz},
feature_vector);
return rnn_vad_.ComputeVadProbability(feature_vector, is_silence);
}
private:
PushResampler<float> resampler_;
rnn_vad::FeaturesExtractor features_extractor_;
rnn_vad::RnnVad rnn_vad_;
};
} // namespace
VadLevelAnalyzer::VadLevelAnalyzer(int vad_reset_period_ms,
const AvailableCpuFeatures& cpu_features)
: VadLevelAnalyzer(vad_reset_period_ms,
std::make_unique<Vad>(cpu_features)) {}
// Creates a wrapper around the default VAD (`MonoVadImpl`), which is
// instantiated with `cpu_features`. Delegates to the ctor that takes a custom
// VAD.
VoiceActivityDetectorWrapper::VoiceActivityDetectorWrapper(
int vad_reset_period_ms,
const AvailableCpuFeatures& cpu_features)
: VoiceActivityDetectorWrapper(
vad_reset_period_ms,
std::make_unique<MonoVadImpl>(cpu_features)) {}
VadLevelAnalyzer::VadLevelAnalyzer(int vad_reset_period_ms,
std::unique_ptr<VoiceActivityDetector> vad)
: vad_(std::move(vad)),
vad_reset_period_frames_(
VoiceActivityDetectorWrapper::VoiceActivityDetectorWrapper(
int vad_reset_period_ms,
std::unique_ptr<MonoVad> vad)
: vad_reset_period_frames_(
rtc::CheckedDivExact(vad_reset_period_ms, kFrameDurationMs)),
time_to_vad_reset_(vad_reset_period_frames_) {
time_to_vad_reset_(vad_reset_period_frames_),
vad_(std::move(vad)) {
RTC_DCHECK(vad_);
RTC_DCHECK_GT(vad_reset_period_frames_, 1);
resampled_buffer_.resize(
rtc::CheckedDivExact(vad_->SampleRateHz(), kNumFramesPerSecond));
}
VadLevelAnalyzer::~VadLevelAnalyzer() = default;
VoiceActivityDetectorWrapper::~VoiceActivityDetectorWrapper() = default;
VadLevelAnalyzer::Result VadLevelAnalyzer::AnalyzeFrame(
AudioFrameView<const float> frame) {
float VoiceActivityDetectorWrapper::Analyze(AudioFrameView<const float> frame) {
// Periodically reset the VAD.
time_to_vad_reset_--;
if (time_to_vad_reset_ <= 0) {
vad_->Reset();
time_to_vad_reset_ = vad_reset_period_frames_;
}
// Compute levels.
float peak = 0.0f;
float rms = 0.0f;
for (const auto& x : frame.channel(0)) {
peak = std::max(std::fabs(x), peak);
rms += x * x;
}
return {vad_->ComputeProbability(frame),
FloatS16ToDbfs(std::sqrt(rms / frame.samples_per_channel())),
FloatS16ToDbfs(peak)};
// Resample the first channel of `frame`.
resampler_.InitializeIfNeeded(
/*sample_rate_hz=*/frame.samples_per_channel() * kNumFramesPerSecond,
vad_->SampleRateHz(), /*num_channels=*/1);
resampler_.Resample(frame.channel(0).data(), frame.samples_per_channel(),
resampled_buffer_.data(), resampled_buffer_.size());
return vad_->Analyze(resampled_buffer_);
}
} // namespace webrtc

View File

@ -12,51 +12,57 @@
#define MODULES_AUDIO_PROCESSING_AGC2_VAD_WRAPPER_H_
#include <memory>
#include <vector>
#include "api/array_view.h"
#include "common_audio/resampler/include/push_resampler.h"
#include "modules/audio_processing/agc2/cpu_features.h"
#include "modules/audio_processing/include/audio_frame_view.h"
namespace webrtc {
// Class to analyze voice activity and audio levels.
class VadLevelAnalyzer {
// Wraps a single-channel Voice Activity Detector (VAD) which is used to analyze
// the first channel of the input audio frames. Takes care of resampling the
// input frames to match the sample rate of the wrapped VAD and periodically
// resets the VAD.
class VoiceActivityDetectorWrapper {
public:
struct Result {
float speech_probability; // Range: [0, 1].
float rms_dbfs; // Root mean square power (dBFS).
float peak_dbfs; // Peak power (dBFS).
};
// Voice Activity Detector (VAD) interface.
class VoiceActivityDetector {
// Single channel VAD interface.
class MonoVad {
public:
virtual ~VoiceActivityDetector() = default;
virtual ~MonoVad() = default;
// Returns the sample rate (Hz) required for the input frames analyzed by
// `Analyze`.
virtual int SampleRateHz() const = 0;
// Resets the internal state.
virtual void Reset() = 0;
// Analyzes an audio frame and returns the speech probability.
virtual float ComputeProbability(AudioFrameView<const float> frame) = 0;
virtual float Analyze(rtc::ArrayView<const float> frame) = 0;
};
// Ctor. `vad_reset_period_ms` indicates the period in milliseconds to call
// `VadLevelAnalyzer::Reset()`; it must be equal to or greater than the
// duration of two frames. Uses `cpu_features` to instantiate the default VAD.
VadLevelAnalyzer(int vad_reset_period_ms,
const AvailableCpuFeatures& cpu_features);
// `MonoVad::Reset()`; it must be equal to or greater than the duration of two
// frames. Uses `cpu_features` to instantiate the default VAD.
VoiceActivityDetectorWrapper(int vad_reset_period_ms,
const AvailableCpuFeatures& cpu_features);
// Ctor. Uses a custom `vad`.
VadLevelAnalyzer(int vad_reset_period_ms,
std::unique_ptr<VoiceActivityDetector> vad);
VoiceActivityDetectorWrapper(int vad_reset_period_ms,
std::unique_ptr<MonoVad> vad);
VadLevelAnalyzer(const VadLevelAnalyzer&) = delete;
VadLevelAnalyzer& operator=(const VadLevelAnalyzer&) = delete;
~VadLevelAnalyzer();
VoiceActivityDetectorWrapper(const VoiceActivityDetectorWrapper&) = delete;
VoiceActivityDetectorWrapper& operator=(const VoiceActivityDetectorWrapper&) =
delete;
~VoiceActivityDetectorWrapper();
// Computes the speech probability and the level for `frame`.
Result AnalyzeFrame(AudioFrameView<const float> frame);
// Analyzes the first channel of `frame` and returns the speech probability.
float Analyze(AudioFrameView<const float> frame);
private:
std::unique_ptr<VoiceActivityDetector> vad_;
const int vad_reset_period_frames_;
int time_to_vad_reset_;
PushResampler<float> resampler_;
std::unique_ptr<MonoVad> vad_;
std::vector<float> resampled_buffer_;
};
} // namespace webrtc

View File

@ -18,6 +18,7 @@
#include "modules/audio_processing/agc2/agc2_common.h"
#include "modules/audio_processing/include/audio_frame_view.h"
#include "rtc_base/checks.h"
#include "rtc_base/gunit.h"
#include "rtc_base/numerics/safe_compare.h"
#include "test/gmock.h"
@ -26,90 +27,78 @@ namespace webrtc {
namespace {
using ::testing::AnyNumber;
using ::testing::Return;
using ::testing::ReturnRoundRobin;
using ::testing::Truly;
constexpr int kNoVadPeriodicReset =
kFrameDurationMs * (std::numeric_limits<int>::max() / kFrameDurationMs);
constexpr int kSampleRateHz = 8000;
constexpr int kSampleRate8kHz = 8000;
class MockVad : public VadLevelAnalyzer::VoiceActivityDetector {
class MockVad : public VoiceActivityDetectorWrapper::MonoVad {
public:
MOCK_METHOD(int, SampleRateHz, (), (const override));
MOCK_METHOD(void, Reset, (), (override));
MOCK_METHOD(float,
ComputeProbability,
(AudioFrameView<const float> frame),
(override));
MOCK_METHOD(float, Analyze, (rtc::ArrayView<const float> frame), (override));
};
// Creates a `VadLevelAnalyzer` injecting a mock VAD which repeatedly returns
// the next value from `speech_probabilities` until it reaches the end and will
// restart from the beginning.
std::unique_ptr<VadLevelAnalyzer> CreateVadLevelAnalyzerWithMockVad(
// Creates a `VoiceActivityDetectorWrapper` injecting a mock VAD that
// repeatedly returns the next value from `speech_probabilities` and that
// restarts from the beginning after the last element is returned.
std::unique_ptr<VoiceActivityDetectorWrapper> CreateMockVadWrapper(
int vad_reset_period_ms,
const std::vector<float>& speech_probabilities,
int expected_vad_reset_calls = 0) {
auto vad = std::make_unique<MockVad>();
EXPECT_CALL(*vad, ComputeProbability)
EXPECT_CALL(*vad, SampleRateHz)
.Times(AnyNumber())
.WillRepeatedly(ReturnRoundRobin(speech_probabilities));
.WillRepeatedly(Return(kSampleRate8kHz));
if (expected_vad_reset_calls >= 0) {
EXPECT_CALL(*vad, Reset).Times(expected_vad_reset_calls);
}
return std::make_unique<VadLevelAnalyzer>(vad_reset_period_ms,
std::move(vad));
EXPECT_CALL(*vad, Analyze)
.Times(AnyNumber())
.WillRepeatedly(ReturnRoundRobin(speech_probabilities));
return std::make_unique<VoiceActivityDetectorWrapper>(vad_reset_period_ms,
std::move(vad));
}
// 10 ms mono frame.
struct FrameWithView {
// Ctor. Initializes the frame samples with `value`.
explicit FrameWithView(float value = 0.0f)
: channel0(samples.data()),
view(&channel0, /*num_channels=*/1, samples.size()) {
samples.fill(value);
}
std::array<float, kSampleRateHz / 100> samples;
explicit FrameWithView(int sample_rate_hz = kSampleRate8kHz)
: samples(rtc::CheckedDivExact(sample_rate_hz, 100), 0.0f),
channel0(samples.data()),
view(&channel0, /*num_channels=*/1, samples.size()) {}
std::vector<float> samples;
const float* const channel0;
const AudioFrameView<const float> view;
};
TEST(GainController2VadLevelAnalyzer, RmsLessThanPeakLevel) {
auto analyzer = CreateVadLevelAnalyzerWithMockVad(
/*vad_reset_period_ms=*/1500,
/*speech_probabilities=*/{1.0f},
/*expected_vad_reset_calls=*/0);
// Handcrafted frame so that the average is lower than the peak value.
FrameWithView frame(1000.0f); // Constant frame.
frame.samples[10] = 2000.0f; // Except for one peak value.
// Compute audio frame levels.
auto levels_and_vad_prob = analyzer->AnalyzeFrame(frame.view);
EXPECT_LT(levels_and_vad_prob.rms_dbfs, levels_and_vad_prob.peak_dbfs);
}
// Checks that the expect VAD probabilities are returned.
TEST(GainController2VadLevelAnalyzer, NoSpeechProbabilitySmoothing) {
// Checks that the expected speech probabilities are returned.
TEST(GainController2VoiceActivityDetectorWrapper, CheckSpeechProbabilities) {
const std::vector<float> speech_probabilities{0.709f, 0.484f, 0.882f, 0.167f,
0.44f, 0.525f, 0.858f, 0.314f,
0.653f, 0.965f, 0.413f, 0.0f};
auto analyzer = CreateVadLevelAnalyzerWithMockVad(kNoVadPeriodicReset,
speech_probabilities);
auto vad_wrapper =
CreateMockVadWrapper(kNoVadPeriodicReset, speech_probabilities);
FrameWithView frame;
for (int i = 0; rtc::SafeLt(i, speech_probabilities.size()); ++i) {
SCOPED_TRACE(i);
EXPECT_EQ(speech_probabilities[i],
analyzer->AnalyzeFrame(frame.view).speech_probability);
EXPECT_EQ(speech_probabilities[i], vad_wrapper->Analyze(frame.view));
}
}
// Checks that the VAD is not periodically reset.
TEST(GainController2VadLevelAnalyzer, VadNoPeriodicReset) {
TEST(GainController2VoiceActivityDetectorWrapper, VadNoPeriodicReset) {
constexpr int kNumFrames = 19;
auto analyzer = CreateVadLevelAnalyzerWithMockVad(
kNoVadPeriodicReset, /*speech_probabilities=*/{1.0f},
/*expected_vad_reset_calls=*/0);
auto vad_wrapper =
CreateMockVadWrapper(kNoVadPeriodicReset, /*speech_probabilities=*/{1.0f},
/*expected_vad_reset_calls=*/0);
FrameWithView frame;
for (int i = 0; i < kNumFrames; ++i) {
analyzer->AnalyzeFrame(frame.view);
vad_wrapper->Analyze(frame.view);
}
}
@ -122,20 +111,52 @@ class VadPeriodResetParametrization
// Checks that the VAD is periodically reset with the expected period.
TEST_P(VadPeriodResetParametrization, VadPeriodicReset) {
auto analyzer = CreateVadLevelAnalyzerWithMockVad(
auto vad_wrapper = CreateMockVadWrapper(
/*vad_reset_period_ms=*/vad_reset_period_frames() * kFrameDurationMs,
/*speech_probabilities=*/{1.0f},
/*expected_vad_reset_calls=*/num_frames() / vad_reset_period_frames());
FrameWithView frame;
for (int i = 0; i < num_frames(); ++i) {
analyzer->AnalyzeFrame(frame.view);
vad_wrapper->Analyze(frame.view);
}
}
INSTANTIATE_TEST_SUITE_P(GainController2VadLevelAnalyzer,
INSTANTIATE_TEST_SUITE_P(GainController2VoiceActivityDetectorWrapper,
VadPeriodResetParametrization,
::testing::Combine(::testing::Values(1, 19, 123),
::testing::Values(2, 5, 20, 53)));
// Test fixture parametrized by a pair of sample rates in Hz: that of the input
// audio frames and that of the wrapped VAD.
class VadResamplingParametrization
: public ::testing::TestWithParam<std::tuple<int, int>> {
protected:
// Sample rate (Hz) of the frames passed to the VAD wrapper.
int input_sample_rate_hz() const { return std::get<0>(GetParam()); }
// Sample rate (Hz) reported by the wrapped VAD.
int vad_sample_rate_hz() const { return std::get<1>(GetParam()); }
};
// Checks that the wrapped VAD always receives frames sized according to its
// own sample rate (10 ms worth of samples), whatever the sample rate of the
// input audio is.
TEST_P(VadResamplingParametrization, CheckResampledFrameSize) {
  // Number of samples in a 10 ms frame at the VAD sample rate.
  const int expected_frame_size =
      rtc::CheckedDivExact(vad_sample_rate_hz(), 100);
  const auto has_expected_size =
      [expected_frame_size](rtc::ArrayView<const float> resampled_frame) {
        return rtc::SafeEq(resampled_frame.size(), expected_frame_size);
      };

  auto mock_vad = std::make_unique<MockVad>();
  EXPECT_CALL(*mock_vad, SampleRateHz)
      .Times(AnyNumber())
      .WillRepeatedly(Return(vad_sample_rate_hz()));
  EXPECT_CALL(*mock_vad, Reset).Times(0);
  EXPECT_CALL(*mock_vad, Analyze(Truly(has_expected_size))).Times(1);

  VoiceActivityDetectorWrapper vad_wrapper(kNoVadPeriodicReset,
                                           std::move(mock_vad));
  FrameWithView input_frame(input_sample_rate_hz());
  vad_wrapper.Analyze(input_frame.view);
}
// Runs the resampling tests on every combination of input sample rate (first
// list) and VAD sample rate (second list).
INSTANTIATE_TEST_SUITE_P(
GainController2VoiceActivityDetectorWrapper,
VadResamplingParametrization,
::testing::Combine(::testing::Values(8000, 16000, 44100, 48000),
::testing::Values(6000, 8000, 12000, 16000, 24000)));
} // namespace
} // namespace webrtc