AGC2: VadWithLevel -> VoiceActivityDetectorWrapper 2/2
Internal refactoring of AGC2 to decouple the VAD, its wrapper and the
peak and RMS level measurements. Bit exactness verified with audioproc_f
on a collection of AEC dumps and Wav files (42 recordings in total).

Bug: webrtc:7494
Change-Id: Ib560f1fcaa601557f4f30e47025c69e91b1b62e0
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/234524
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Reviewed-by: Hanna Silen <silen@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#35208}
parent 54f377308f
commit 8dbdf5e3bf
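The sketch below is not part of the change; it only illustrates, assuming it is compiled inside the WebRTC tree, how the decoupled pieces fit together after this CL: the caller measures the peak and RMS levels itself (see `ComputeAudioLevels()` added to adaptive_agc.cc in the diff), the VAD wrapper returns only the speech probability, and both are passed to the speech level estimator through its new three-argument `Update()`. The free function `AnalyzeFrameSketch` and the adaptive_mode_level_estimator.h include path are assumptions made for illustration; the types and methods are the ones shown in the diff.

// Illustrative only; not part of this CL.
#include "modules/audio_processing/agc2/adaptive_mode_level_estimator.h"  // Assumed path.
#include "modules/audio_processing/agc2/vad_wrapper.h"
#include "modules/audio_processing/include/audio_frame_view.h"

namespace webrtc {

// Hypothetical helper mirroring the new AdaptiveAgc::Process() flow: the
// caller computes the levels (e.g., via ComputeAudioLevels() in
// adaptive_agc.cc), the wrapper returns only the speech probability, and
// both feed the level estimator.
void AnalyzeFrameSketch(VoiceActivityDetectorWrapper& vad,
                        AdaptiveModeLevelEstimator& speech_level_estimator,
                        AudioFrameView<float> frame,
                        float rms_dbfs,
                        float peak_dbfs) {
  const float speech_probability = vad.Analyze(frame);
  speech_level_estimator.Update(rms_dbfs, peak_dbfs, speech_probability);
}

}  // namespace webrtc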
@@ -280,6 +280,7 @@ rtc_library("vad_wrapper_unittests") {
":common",
":vad_wrapper",
"..:audio_frame_view",
"../../../rtc_base:checks",
"../../../rtc_base:gunit_helpers",
"../../../rtc_base:safe_compare",
"../../../test:test_support",
@@ -36,6 +36,24 @@ AvailableCpuFeatures GetAllowedCpuFeatures() {
return features;
}

// Peak and RMS audio levels in dBFS.
struct AudioLevels {
float peak_dbfs;
float rms_dbfs;
};

// Computes the audio levels for the first channel in `frame`.
AudioLevels ComputeAudioLevels(AudioFrameView<float> frame) {
float peak = 0.0f;
float rms = 0.0f;
for (const auto& x : frame.channel(0)) {
peak = std::max(std::fabs(x), peak);
rms += x * x;
}
return {FloatS16ToDbfs(peak),
FloatS16ToDbfs(std::sqrt(rms / frame.samples_per_channel()))};
}

} // namespace

AdaptiveAgc::AdaptiveAgc(
@@ -62,16 +80,17 @@ void AdaptiveAgc::Initialize(int sample_rate_hz, int num_channels) {
}

void AdaptiveAgc::Process(AudioFrameView<float> frame, float limiter_envelope) {
AudioLevels levels = ComputeAudioLevels(frame);

AdaptiveDigitalGainApplier::FrameInfo info;

VadLevelAnalyzer::Result vad_result = vad_.AnalyzeFrame(frame);
info.speech_probability = vad_result.speech_probability;
apm_data_dumper_->DumpRaw("agc2_speech_probability",
vad_result.speech_probability);
apm_data_dumper_->DumpRaw("agc2_input_rms_dbfs", vad_result.rms_dbfs);
apm_data_dumper_->DumpRaw("agc2_input_peak_dbfs", vad_result.peak_dbfs);
info.speech_probability = vad_.Analyze(frame);
apm_data_dumper_->DumpRaw("agc2_speech_probability", info.speech_probability);
apm_data_dumper_->DumpRaw("agc2_input_rms_dbfs", levels.rms_dbfs);
apm_data_dumper_->DumpRaw("agc2_input_peak_dbfs", levels.peak_dbfs);

speech_level_estimator_.Update(vad_result);
speech_level_estimator_.Update(levels.rms_dbfs, levels.peak_dbfs,
info.speech_probability);
info.speech_level_dbfs = speech_level_estimator_.level_dbfs();
info.speech_level_reliable = speech_level_estimator_.IsConfident();
apm_data_dumper_->DumpRaw("agc2_speech_level_dbfs", info.speech_level_dbfs);
@@ -81,7 +100,7 @@ void AdaptiveAgc::Process(AudioFrameView<float> frame, float limiter_envelope) {
info.noise_rms_dbfs = noise_level_estimator_->Analyze(frame);
apm_data_dumper_->DumpRaw("agc2_noise_rms_dbfs", info.noise_rms_dbfs);

saturation_protector_->Analyze(info.speech_probability, vad_result.peak_dbfs,
saturation_protector_->Analyze(info.speech_probability, levels.peak_dbfs,
info.speech_level_dbfs);
info.headroom_db = saturation_protector_->HeadroomDb();
apm_data_dumper_->DumpRaw("agc2_headroom_db", info.headroom_db);
@@ -47,7 +47,7 @@ class AdaptiveAgc {

private:
AdaptiveModeLevelEstimator speech_level_estimator_;
VadLevelAnalyzer vad_;
VoiceActivityDetectorWrapper vad_;
AdaptiveDigitalGainApplier gain_controller_;
ApmDataDumper* const apm_data_dumper_;
std::unique_ptr<NoiseLevelEstimator> noise_level_estimator_;
@@ -57,15 +57,16 @@ AdaptiveModeLevelEstimator::AdaptiveModeLevelEstimator(
Reset();
}

void AdaptiveModeLevelEstimator::Update(
const VadLevelAnalyzer::Result& vad_level) {
RTC_DCHECK_GT(vad_level.rms_dbfs, -150.f);
RTC_DCHECK_LT(vad_level.rms_dbfs, 50.f);
RTC_DCHECK_GT(vad_level.peak_dbfs, -150.f);
RTC_DCHECK_LT(vad_level.peak_dbfs, 50.f);
RTC_DCHECK_GE(vad_level.speech_probability, 0.f);
RTC_DCHECK_LE(vad_level.speech_probability, 1.f);
if (vad_level.speech_probability < kVadConfidenceThreshold) {
void AdaptiveModeLevelEstimator::Update(float rms_dbfs,
float peak_dbfs,
float speech_probability) {
RTC_DCHECK_GT(rms_dbfs, -150.0f);
RTC_DCHECK_LT(rms_dbfs, 50.0f);
RTC_DCHECK_GT(peak_dbfs, -150.0f);
RTC_DCHECK_LT(peak_dbfs, 50.0f);
RTC_DCHECK_GE(speech_probability, 0.0f);
RTC_DCHECK_LE(speech_probability, 1.0f);
if (speech_probability < kVadConfidenceThreshold) {
// Not a speech frame.
if (adjacent_speech_frames_threshold_ > 1) {
// When two or more adjacent speech frames are required in order to update
@@ -93,14 +94,14 @@ void AdaptiveModeLevelEstimator::Update(
preliminary_state_.time_to_confidence_ms -= kFrameDurationMs;
}
// Weighted average of levels with speech probability as weight.
RTC_DCHECK_GT(vad_level.speech_probability, 0.f);
const float leak_factor = buffer_is_full ? kLevelEstimatorLeakFactor : 1.f;
RTC_DCHECK_GT(speech_probability, 0.0f);
const float leak_factor = buffer_is_full ? kLevelEstimatorLeakFactor : 1.0f;
preliminary_state_.level_dbfs.numerator =
preliminary_state_.level_dbfs.numerator * leak_factor +
vad_level.rms_dbfs * vad_level.speech_probability;
rms_dbfs * speech_probability;
preliminary_state_.level_dbfs.denominator =
preliminary_state_.level_dbfs.denominator * leak_factor +
vad_level.speech_probability;
speech_probability;

const float level_dbfs = preliminary_state_.level_dbfs.GetRatio();
@@ -33,7 +33,7 @@ class AdaptiveModeLevelEstimator {
delete;

// Updates the level estimation.
void Update(const VadLevelAnalyzer::Result& vad_data);
void Update(float rms_dbfs, float peak_dbfs, float speech_probability);
// Returns the estimated speech plus noise level.
float level_dbfs() const { return level_dbfs_; }
// Returns true if the estimator is confident on its current estimate.
@@ -33,10 +33,12 @@ constexpr float kConvergenceSpeedTestsLevelTolerance = 0.5f;

// Provides the `vad_level` value `num_iterations` times to `level_estimator`.
void RunOnConstantLevel(int num_iterations,
const VadLevelAnalyzer::Result& vad_level,
float rms_dbfs,
float peak_dbfs,
float speech_probability,
AdaptiveModeLevelEstimator& level_estimator) {
for (int i = 0; i < num_iterations; ++i) {
level_estimator.Update(vad_level);
level_estimator.Update(rms_dbfs, peak_dbfs, speech_probability);
}
}
@@ -47,6 +49,10 @@ constexpr AdaptiveDigitalConfig GetAdaptiveDigitalConfig(
return config;
}

constexpr float kNoSpeechProbability = 0.0f;
constexpr float kLowSpeechProbability = kVadConfidenceThreshold / 2.0f;
constexpr float kMaxSpeechProbability = 1.0f;

// Level estimator with data dumper.
struct TestLevelEstimator {
explicit TestLevelEstimator(int adjacent_speech_frames_threshold)
@@ -55,36 +61,31 @@ struct TestLevelEstimator {
&data_dumper,
GetAdaptiveDigitalConfig(adjacent_speech_frames_threshold))),
initial_speech_level_dbfs(estimator->level_dbfs()),
vad_level_rms(initial_speech_level_dbfs / 2.0f),
vad_level_peak(initial_speech_level_dbfs / 3.0f),
vad_data_speech(
{/*speech_probability=*/1.0f, vad_level_rms, vad_level_peak}),
vad_data_non_speech(
{/*speech_probability=*/kVadConfidenceThreshold / 2.0f,
vad_level_rms, vad_level_peak}) {
RTC_DCHECK_LT(vad_level_rms, vad_level_peak);
RTC_DCHECK_LT(initial_speech_level_dbfs, vad_level_rms);
RTC_DCHECK_GT(vad_level_rms - initial_speech_level_dbfs, 5.0f)
<< "Adjust `vad_level_rms` so that the difference from the initial "
level_rms_dbfs(initial_speech_level_dbfs / 2.0f),
level_peak_dbfs(initial_speech_level_dbfs / 3.0f) {
RTC_DCHECK_LT(level_rms_dbfs, level_peak_dbfs);
RTC_DCHECK_LT(initial_speech_level_dbfs, level_rms_dbfs);
RTC_DCHECK_GT(level_rms_dbfs - initial_speech_level_dbfs, 5.0f)
<< "Adjust `level_rms_dbfs` so that the difference from the initial "
"level is wide enough for the tests";
}
ApmDataDumper data_dumper;
std::unique_ptr<AdaptiveModeLevelEstimator> estimator;
const float initial_speech_level_dbfs;
const float vad_level_rms;
const float vad_level_peak;
const VadLevelAnalyzer::Result vad_data_speech;
const VadLevelAnalyzer::Result vad_data_non_speech;
const float level_rms_dbfs;
const float level_peak_dbfs;
};

// Checks that the level estimator converges to a constant input speech level.
TEST(GainController2AdaptiveModeLevelEstimator, LevelStabilizes) {
TestLevelEstimator level_estimator(/*adjacent_speech_frames_threshold=*/1);
RunOnConstantLevel(/*num_iterations=*/kNumFramesToConfidence,
level_estimator.vad_data_speech,
level_estimator.level_rms_dbfs,
level_estimator.level_peak_dbfs, kMaxSpeechProbability,
*level_estimator.estimator);
const float estimated_level_dbfs = level_estimator.estimator->level_dbfs();
RunOnConstantLevel(/*num_iterations=*/1, level_estimator.vad_data_speech,
RunOnConstantLevel(/*num_iterations=*/1, level_estimator.level_rms_dbfs,
level_estimator.level_peak_dbfs, kMaxSpeechProbability,
*level_estimator.estimator);
EXPECT_NEAR(level_estimator.estimator->level_dbfs(), estimated_level_dbfs,
0.1f);
@@ -95,7 +96,8 @@ TEST(GainController2AdaptiveModeLevelEstimator, LevelStabilizes) {
TEST(GainController2AdaptiveModeLevelEstimator, IsNotConfident) {
TestLevelEstimator level_estimator(/*adjacent_speech_frames_threshold=*/1);
RunOnConstantLevel(/*num_iterations=*/kNumFramesToConfidence / 2,
level_estimator.vad_data_speech,
level_estimator.level_rms_dbfs,
level_estimator.level_peak_dbfs, kMaxSpeechProbability,
*level_estimator.estimator);
EXPECT_FALSE(level_estimator.estimator->IsConfident());
}
@@ -105,7 +107,8 @@ TEST(GainController2AdaptiveModeLevelEstimator, IsNotConfident) {
TEST(GainController2AdaptiveModeLevelEstimator, IsConfident) {
TestLevelEstimator level_estimator(/*adjacent_speech_frames_threshold=*/1);
RunOnConstantLevel(/*num_iterations=*/kNumFramesToConfidence,
level_estimator.vad_data_speech,
level_estimator.level_rms_dbfs,
level_estimator.level_peak_dbfs, kMaxSpeechProbability,
*level_estimator.estimator);
EXPECT_TRUE(level_estimator.estimator->IsConfident());
}
@@ -117,15 +120,14 @@ TEST(GainController2AdaptiveModeLevelEstimator,
TestLevelEstimator level_estimator(/*adjacent_speech_frames_threshold=*/1);
// Simulate speech.
RunOnConstantLevel(/*num_iterations=*/kNumFramesToConfidence,
level_estimator.vad_data_speech,
level_estimator.level_rms_dbfs,
level_estimator.level_peak_dbfs, kMaxSpeechProbability,
*level_estimator.estimator);
const float estimated_level_dbfs = level_estimator.estimator->level_dbfs();
// Simulate full-scale non-speech.
RunOnConstantLevel(/*num_iterations=*/kNumFramesToConfidence,
VadLevelAnalyzer::Result{/*speech_probability=*/0.0f,
/*rms_dbfs=*/0.0f,
/*peak_dbfs=*/0.0f},
*level_estimator.estimator);
/*rms_dbfs=*/0.0f, /*peak_dbfs=*/0.0f,
kNoSpeechProbability, *level_estimator.estimator);
// No estimated level change is expected.
EXPECT_FLOAT_EQ(level_estimator.estimator->level_dbfs(),
estimated_level_dbfs);
@@ -136,10 +138,11 @@ TEST(GainController2AdaptiveModeLevelEstimator,
ConvergenceSpeedBeforeConfidence) {
TestLevelEstimator level_estimator(/*adjacent_speech_frames_threshold=*/1);
RunOnConstantLevel(/*num_iterations=*/kNumFramesToConfidence,
level_estimator.vad_data_speech,
level_estimator.level_rms_dbfs,
level_estimator.level_peak_dbfs, kMaxSpeechProbability,
*level_estimator.estimator);
EXPECT_NEAR(level_estimator.estimator->level_dbfs(),
level_estimator.vad_data_speech.rms_dbfs,
level_estimator.level_rms_dbfs,
kConvergenceSpeedTestsLevelTolerance);
}

@@ -150,11 +153,9 @@ TEST(GainController2AdaptiveModeLevelEstimator,
// Reach confidence using the initial level estimate.
RunOnConstantLevel(
/*num_iterations=*/kNumFramesToConfidence,
VadLevelAnalyzer::Result{
/*speech_probability=*/1.0f,
/*rms_dbfs=*/level_estimator.initial_speech_level_dbfs,
/*peak_dbfs=*/level_estimator.initial_speech_level_dbfs + 6.0f},
*level_estimator.estimator);
/*rms_dbfs=*/level_estimator.initial_speech_level_dbfs,
/*peak_dbfs=*/level_estimator.initial_speech_level_dbfs + 6.0f,
kMaxSpeechProbability, *level_estimator.estimator);
// No estimate change should occur, but confidence is achieved.
ASSERT_FLOAT_EQ(level_estimator.estimator->level_dbfs(),
level_estimator.initial_speech_level_dbfs);
@@ -165,9 +166,10 @@ TEST(GainController2AdaptiveModeLevelEstimator,
kConvergenceTimeAfterConfidenceNumFrames > kNumFramesToConfidence, "");
RunOnConstantLevel(
/*num_iterations=*/kConvergenceTimeAfterConfidenceNumFrames,
level_estimator.vad_data_speech, *level_estimator.estimator);
level_estimator.level_rms_dbfs, level_estimator.level_peak_dbfs,
kMaxSpeechProbability, *level_estimator.estimator);
EXPECT_NEAR(level_estimator.estimator->level_dbfs(),
level_estimator.vad_data_speech.rms_dbfs,
level_estimator.level_rms_dbfs,
kConvergenceSpeedTestsLevelTolerance);
}
@@ -181,22 +183,28 @@ TEST_P(AdaptiveModeLevelEstimatorParametrization,
DoNotAdaptToShortSpeechSegments) {
TestLevelEstimator level_estimator(adjacent_speech_frames_threshold());
const float initial_level = level_estimator.estimator->level_dbfs();
ASSERT_LT(initial_level, level_estimator.vad_data_speech.peak_dbfs);
ASSERT_LT(initial_level, level_estimator.level_peak_dbfs);
for (int i = 0; i < adjacent_speech_frames_threshold() - 1; ++i) {
SCOPED_TRACE(i);
level_estimator.estimator->Update(level_estimator.vad_data_speech);
level_estimator.estimator->Update(level_estimator.level_rms_dbfs,
level_estimator.level_peak_dbfs,
kMaxSpeechProbability);
EXPECT_EQ(initial_level, level_estimator.estimator->level_dbfs());
}
level_estimator.estimator->Update(level_estimator.vad_data_non_speech);
level_estimator.estimator->Update(level_estimator.level_rms_dbfs,
level_estimator.level_peak_dbfs,
kLowSpeechProbability);
EXPECT_EQ(initial_level, level_estimator.estimator->level_dbfs());
}

TEST_P(AdaptiveModeLevelEstimatorParametrization, AdaptToEnoughSpeechSegments) {
TestLevelEstimator level_estimator(adjacent_speech_frames_threshold());
const float initial_level = level_estimator.estimator->level_dbfs();
ASSERT_LT(initial_level, level_estimator.vad_data_speech.peak_dbfs);
ASSERT_LT(initial_level, level_estimator.level_peak_dbfs);
for (int i = 0; i < adjacent_speech_frames_threshold(); ++i) {
level_estimator.estimator->Update(level_estimator.vad_data_speech);
level_estimator.estimator->Update(level_estimator.level_rms_dbfs,
level_estimator.level_peak_dbfs,
kMaxSpeechProbability);
}
EXPECT_LT(initial_level, level_estimator.estimator->level_dbfs());
}
@@ -10,13 +10,10 @@

#include "modules/audio_processing/agc2/vad_wrapper.h"

#include <algorithm>
#include <array>
#include <cmath>
#include <utility>

#include "api/array_view.h"
#include "common_audio/include/audio_util.h"
#include "common_audio/resampler/include/push_resampler.h"
#include "modules/audio_processing/agc2/agc2_common.h"
#include "modules/audio_processing/agc2/rnn_vad/common.h"
@@ -27,82 +24,72 @@
namespace webrtc {
namespace {

using VoiceActivityDetector = VadLevelAnalyzer::VoiceActivityDetector;
constexpr int kNumFramesPerSecond = 100;

// Default VAD that combines a resampler and the RNN VAD.
// Computes the speech probability on the first channel.
class Vad : public VoiceActivityDetector {
class MonoVadImpl : public VoiceActivityDetectorWrapper::MonoVad {
public:
explicit Vad(const AvailableCpuFeatures& cpu_features)
explicit MonoVadImpl(const AvailableCpuFeatures& cpu_features)
: features_extractor_(cpu_features), rnn_vad_(cpu_features) {}
Vad(const Vad&) = delete;
Vad& operator=(const Vad&) = delete;
~Vad() = default;
MonoVadImpl(const MonoVadImpl&) = delete;
MonoVadImpl& operator=(const MonoVadImpl&) = delete;
~MonoVadImpl() = default;

int SampleRateHz() const override { return rnn_vad::kSampleRate24kHz; }
void Reset() override { rnn_vad_.Reset(); }

float ComputeProbability(AudioFrameView<const float> frame) override {
// The source number of channels is 1, because we always use the 1st
// channel.
resampler_.InitializeIfNeeded(
/*sample_rate_hz=*/static_cast<int>(frame.samples_per_channel() * 100),
rnn_vad::kSampleRate24kHz,
/*num_channels=*/1);

std::array<float, rnn_vad::kFrameSize10ms24kHz> work_frame;
// Feed the 1st channel to the resampler.
resampler_.Resample(frame.channel(0).data(), frame.samples_per_channel(),
work_frame.data(), rnn_vad::kFrameSize10ms24kHz);

float Analyze(rtc::ArrayView<const float> frame) override {
RTC_DCHECK_EQ(frame.size(), rnn_vad::kFrameSize10ms24kHz);
std::array<float, rnn_vad::kFeatureVectorSize> feature_vector;
const bool is_silence = features_extractor_.CheckSilenceComputeFeatures(
work_frame, feature_vector);
/*samples=*/{frame.data(), rnn_vad::kFrameSize10ms24kHz},
feature_vector);
return rnn_vad_.ComputeVadProbability(feature_vector, is_silence);
}

private:
PushResampler<float> resampler_;
rnn_vad::FeaturesExtractor features_extractor_;
rnn_vad::RnnVad rnn_vad_;
};

} // namespace

VadLevelAnalyzer::VadLevelAnalyzer(int vad_reset_period_ms,
const AvailableCpuFeatures& cpu_features)
: VadLevelAnalyzer(vad_reset_period_ms,
std::make_unique<Vad>(cpu_features)) {}
VoiceActivityDetectorWrapper::VoiceActivityDetectorWrapper(
int vad_reset_period_ms,
const AvailableCpuFeatures& cpu_features)
: VoiceActivityDetectorWrapper(
vad_reset_period_ms,
std::make_unique<MonoVadImpl>(cpu_features)) {}

VadLevelAnalyzer::VadLevelAnalyzer(int vad_reset_period_ms,
std::unique_ptr<VoiceActivityDetector> vad)
: vad_(std::move(vad)),
vad_reset_period_frames_(
VoiceActivityDetectorWrapper::VoiceActivityDetectorWrapper(
int vad_reset_period_ms,
std::unique_ptr<MonoVad> vad)
: vad_reset_period_frames_(
rtc::CheckedDivExact(vad_reset_period_ms, kFrameDurationMs)),
time_to_vad_reset_(vad_reset_period_frames_) {
time_to_vad_reset_(vad_reset_period_frames_),
vad_(std::move(vad)) {
RTC_DCHECK(vad_);
RTC_DCHECK_GT(vad_reset_period_frames_, 1);
resampled_buffer_.resize(
rtc::CheckedDivExact(vad_->SampleRateHz(), kNumFramesPerSecond));
}

VadLevelAnalyzer::~VadLevelAnalyzer() = default;
VoiceActivityDetectorWrapper::~VoiceActivityDetectorWrapper() = default;

VadLevelAnalyzer::Result VadLevelAnalyzer::AnalyzeFrame(
AudioFrameView<const float> frame) {
float VoiceActivityDetectorWrapper::Analyze(AudioFrameView<const float> frame) {
// Periodically reset the VAD.
time_to_vad_reset_--;
if (time_to_vad_reset_ <= 0) {
vad_->Reset();
time_to_vad_reset_ = vad_reset_period_frames_;
}
// Compute levels.
float peak = 0.0f;
float rms = 0.0f;
for (const auto& x : frame.channel(0)) {
peak = std::max(std::fabs(x), peak);
rms += x * x;
}
return {vad_->ComputeProbability(frame),
FloatS16ToDbfs(std::sqrt(rms / frame.samples_per_channel())),
FloatS16ToDbfs(peak)};

// Resample the first channel of `frame`.
resampler_.InitializeIfNeeded(
/*sample_rate_hz=*/frame.samples_per_channel() * kNumFramesPerSecond,
vad_->SampleRateHz(), /*num_channels=*/1);
resampler_.Resample(frame.channel(0).data(), frame.samples_per_channel(),
resampled_buffer_.data(), resampled_buffer_.size());

return vad_->Analyze(resampled_buffer_);
}

} // namespace webrtc
@@ -12,51 +12,57 @@
#define MODULES_AUDIO_PROCESSING_AGC2_VAD_WRAPPER_H_

#include <memory>
#include <vector>

#include "api/array_view.h"
#include "common_audio/resampler/include/push_resampler.h"
#include "modules/audio_processing/agc2/cpu_features.h"
#include "modules/audio_processing/include/audio_frame_view.h"

namespace webrtc {

// Class to analyze voice activity and audio levels.
class VadLevelAnalyzer {
// Wraps a single-channel Voice Activity Detector (VAD) which is used to analyze
// the first channel of the input audio frames. Takes care of resampling the
// input frames to match the sample rate of the wrapped VAD and periodically
// resets the VAD.
class VoiceActivityDetectorWrapper {
public:
struct Result {
float speech_probability; // Range: [0, 1].
float rms_dbfs; // Root mean square power (dBFS).
float peak_dbfs; // Peak power (dBFS).
};

// Voice Activity Detector (VAD) interface.
class VoiceActivityDetector {
// Single channel VAD interface.
class MonoVad {
public:
virtual ~VoiceActivityDetector() = default;
virtual ~MonoVad() = default;
// Returns the sample rate (Hz) required for the input frames analyzed by
// `ComputeProbability`.
virtual int SampleRateHz() const = 0;
// Resets the internal state.
virtual void Reset() = 0;
// Analyzes an audio frame and returns the speech probability.
virtual float ComputeProbability(AudioFrameView<const float> frame) = 0;
virtual float Analyze(rtc::ArrayView<const float> frame) = 0;
};

// Ctor. `vad_reset_period_ms` indicates the period in milliseconds to call
// `VadLevelAnalyzer::Reset()`; it must be equal to or greater than the
// duration of two frames. Uses `cpu_features` to instantiate the default VAD.
VadLevelAnalyzer(int vad_reset_period_ms,
const AvailableCpuFeatures& cpu_features);
// `MonoVad::Reset()`; it must be equal to or greater than the duration of two
// frames. Uses `cpu_features` to instantiate the default VAD.
VoiceActivityDetectorWrapper(int vad_reset_period_ms,
const AvailableCpuFeatures& cpu_features);
// Ctor. Uses a custom `vad`.
VadLevelAnalyzer(int vad_reset_period_ms,
std::unique_ptr<VoiceActivityDetector> vad);
VoiceActivityDetectorWrapper(int vad_reset_period_ms,
std::unique_ptr<MonoVad> vad);

VadLevelAnalyzer(const VadLevelAnalyzer&) = delete;
VadLevelAnalyzer& operator=(const VadLevelAnalyzer&) = delete;
~VadLevelAnalyzer();
VoiceActivityDetectorWrapper(const VoiceActivityDetectorWrapper&) = delete;
VoiceActivityDetectorWrapper& operator=(const VoiceActivityDetectorWrapper&) =
delete;
~VoiceActivityDetectorWrapper();

// Computes the speech probability and the level for `frame`.
Result AnalyzeFrame(AudioFrameView<const float> frame);
// Analyzes the first channel of `frame` and returns the speech probability.
float Analyze(AudioFrameView<const float> frame);

private:
std::unique_ptr<VoiceActivityDetector> vad_;
const int vad_reset_period_frames_;
int time_to_vad_reset_;
PushResampler<float> resampler_;
std::unique_ptr<MonoVad> vad_;
std::vector<float> resampled_buffer_;
};

} // namespace webrtc
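Not part of the CL: a minimal sketch of how a custom single-channel VAD can be injected through the `MonoVad` interface declared above. `EnergyVad`, its threshold, and `CreateEnergyVadWrapper()` are hypothetical; only the interface, the constructor taking `std::unique_ptr<MonoVad>`, and the 1500 ms reset period (borrowed from the unit tests below) come from this change.

// Illustrative only; not part of this CL.
#include <memory>

#include "api/array_view.h"
#include "modules/audio_processing/agc2/vad_wrapper.h"

namespace webrtc {

// Hypothetical toy detector: flags a frame as speech when its energy exceeds
// an arbitrary threshold. The wrapper resamples the first input channel to
// SampleRateHz() before calling Analyze(), so each frame holds 10 ms of audio.
class EnergyVad : public VoiceActivityDetectorWrapper::MonoVad {
 public:
  int SampleRateHz() const override { return 16000; }  // Assumed rate.
  void Reset() override {}
  float Analyze(rtc::ArrayView<const float> frame) override {
    float energy = 0.0f;
    for (float x : frame) {
      energy += x * x;
    }
    return energy > 1e6f ? 1.0f : 0.0f;  // Arbitrary pseudo-probability.
  }
};

std::unique_ptr<VoiceActivityDetectorWrapper> CreateEnergyVadWrapper() {
  // The reset period must cover at least two 10 ms frames; 1500 ms matches
  // the value used in the unit tests.
  return std::make_unique<VoiceActivityDetectorWrapper>(
      /*vad_reset_period_ms=*/1500, std::make_unique<EnergyVad>());
}

}  // namespace webrtc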
@@ -18,6 +18,7 @@

#include "modules/audio_processing/agc2/agc2_common.h"
#include "modules/audio_processing/include/audio_frame_view.h"
#include "rtc_base/checks.h"
#include "rtc_base/gunit.h"
#include "rtc_base/numerics/safe_compare.h"
#include "test/gmock.h"
@@ -26,90 +27,78 @@ namespace webrtc {
namespace {

using ::testing::AnyNumber;
using ::testing::Return;
using ::testing::ReturnRoundRobin;
using ::testing::Truly;

constexpr int kNoVadPeriodicReset =
kFrameDurationMs * (std::numeric_limits<int>::max() / kFrameDurationMs);

constexpr int kSampleRateHz = 8000;
constexpr int kSampleRate8kHz = 8000;

class MockVad : public VadLevelAnalyzer::VoiceActivityDetector {
class MockVad : public VoiceActivityDetectorWrapper::MonoVad {
public:
MOCK_METHOD(int, SampleRateHz, (), (const override));
MOCK_METHOD(void, Reset, (), (override));
MOCK_METHOD(float,
ComputeProbability,
(AudioFrameView<const float> frame),
(override));
MOCK_METHOD(float, Analyze, (rtc::ArrayView<const float> frame), (override));
};

// Creates a `VadLevelAnalyzer` injecting a mock VAD which repeatedly returns
// the next value from `speech_probabilities` until it reaches the end and will
// restart from the beginning.
std::unique_ptr<VadLevelAnalyzer> CreateVadLevelAnalyzerWithMockVad(
// Creates a `VoiceActivityDetectorWrapper` injecting a mock VAD that
// repeatedly returns the next value from `speech_probabilities` and that
// restarts from the beginning when after the last element is returned.
std::unique_ptr<VoiceActivityDetectorWrapper> CreateMockVadWrapper(
int vad_reset_period_ms,
const std::vector<float>& speech_probabilities,
int expected_vad_reset_calls = 0) {
auto vad = std::make_unique<MockVad>();
EXPECT_CALL(*vad, ComputeProbability)
EXPECT_CALL(*vad, SampleRateHz)
.Times(AnyNumber())
.WillRepeatedly(ReturnRoundRobin(speech_probabilities));
.WillRepeatedly(Return(kSampleRate8kHz));
if (expected_vad_reset_calls >= 0) {
EXPECT_CALL(*vad, Reset).Times(expected_vad_reset_calls);
}
return std::make_unique<VadLevelAnalyzer>(vad_reset_period_ms,
std::move(vad));
EXPECT_CALL(*vad, Analyze)
.Times(AnyNumber())
.WillRepeatedly(ReturnRoundRobin(speech_probabilities));
return std::make_unique<VoiceActivityDetectorWrapper>(vad_reset_period_ms,
std::move(vad));
}

// 10 ms mono frame.
struct FrameWithView {
// Ctor. Initializes the frame samples with `value`.
explicit FrameWithView(float value = 0.0f)
: channel0(samples.data()),
view(&channel0, /*num_channels=*/1, samples.size()) {
samples.fill(value);
}
std::array<float, kSampleRateHz / 100> samples;
explicit FrameWithView(int sample_rate_hz = kSampleRate8kHz)
: samples(rtc::CheckedDivExact(sample_rate_hz, 100), 0.0f),
channel0(samples.data()),
view(&channel0, /*num_channels=*/1, samples.size()) {}
std::vector<float> samples;
const float* const channel0;
const AudioFrameView<const float> view;
};

TEST(GainController2VadLevelAnalyzer, RmsLessThanPeakLevel) {
auto analyzer = CreateVadLevelAnalyzerWithMockVad(
/*vad_reset_period_ms=*/1500,
/*speech_probabilities=*/{1.0f},
/*expected_vad_reset_calls=*/0);
// Handcrafted frame so that the average is lower than the peak value.
FrameWithView frame(1000.0f); // Constant frame.
frame.samples[10] = 2000.0f; // Except for one peak value.
// Compute audio frame levels.
auto levels_and_vad_prob = analyzer->AnalyzeFrame(frame.view);
EXPECT_LT(levels_and_vad_prob.rms_dbfs, levels_and_vad_prob.peak_dbfs);
}

// Checks that the expect VAD probabilities are returned.
TEST(GainController2VadLevelAnalyzer, NoSpeechProbabilitySmoothing) {
// Checks that the expected speech probabilities are returned.
TEST(GainController2VoiceActivityDetectorWrapper, CheckSpeechProbabilities) {
const std::vector<float> speech_probabilities{0.709f, 0.484f, 0.882f, 0.167f,
0.44f, 0.525f, 0.858f, 0.314f,
0.653f, 0.965f, 0.413f, 0.0f};
auto analyzer = CreateVadLevelAnalyzerWithMockVad(kNoVadPeriodicReset,
speech_probabilities);
auto vad_wrapper =
CreateMockVadWrapper(kNoVadPeriodicReset, speech_probabilities);
FrameWithView frame;
for (int i = 0; rtc::SafeLt(i, speech_probabilities.size()); ++i) {
SCOPED_TRACE(i);
EXPECT_EQ(speech_probabilities[i],
analyzer->AnalyzeFrame(frame.view).speech_probability);
EXPECT_EQ(speech_probabilities[i], vad_wrapper->Analyze(frame.view));
}
}

// Checks that the VAD is not periodically reset.
TEST(GainController2VadLevelAnalyzer, VadNoPeriodicReset) {
TEST(GainController2VoiceActivityDetectorWrapper, VadNoPeriodicReset) {
constexpr int kNumFrames = 19;
auto analyzer = CreateVadLevelAnalyzerWithMockVad(
kNoVadPeriodicReset, /*speech_probabilities=*/{1.0f},
/*expected_vad_reset_calls=*/0);
auto vad_wrapper =
CreateMockVadWrapper(kNoVadPeriodicReset, /*speech_probabilities=*/{1.0f},
/*expected_vad_reset_calls=*/0);
FrameWithView frame;
for (int i = 0; i < kNumFrames; ++i) {
analyzer->AnalyzeFrame(frame.view);
vad_wrapper->Analyze(frame.view);
}
}
@@ -122,20 +111,52 @@ class VadPeriodResetParametrization

// Checks that the VAD is periodically reset with the expected period.
TEST_P(VadPeriodResetParametrization, VadPeriodicReset) {
auto analyzer = CreateVadLevelAnalyzerWithMockVad(
auto vad_wrapper = CreateMockVadWrapper(
/*vad_reset_period_ms=*/vad_reset_period_frames() * kFrameDurationMs,
/*speech_probabilities=*/{1.0f},
/*expected_vad_reset_calls=*/num_frames() / vad_reset_period_frames());
FrameWithView frame;
for (int i = 0; i < num_frames(); ++i) {
analyzer->AnalyzeFrame(frame.view);
vad_wrapper->Analyze(frame.view);
}
}

INSTANTIATE_TEST_SUITE_P(GainController2VadLevelAnalyzer,
INSTANTIATE_TEST_SUITE_P(GainController2VoiceActivityDetectorWrapper,
VadPeriodResetParametrization,
::testing::Combine(::testing::Values(1, 19, 123),
::testing::Values(2, 5, 20, 53)));

class VadResamplingParametrization
: public ::testing::TestWithParam<std::tuple<int, int>> {
protected:
int input_sample_rate_hz() const { return std::get<0>(GetParam()); }
int vad_sample_rate_hz() const { return std::get<1>(GetParam()); }
};

// Checks that regardless of the input audio sample rate, the wrapped VAD
// analyzes frames having the expected size, that is according to its internal
// sample rate.
TEST_P(VadResamplingParametrization, CheckResampledFrameSize) {
auto vad = std::make_unique<MockVad>();
EXPECT_CALL(*vad, SampleRateHz)
.Times(AnyNumber())
.WillRepeatedly(Return(vad_sample_rate_hz()));
EXPECT_CALL(*vad, Reset).Times(0);
EXPECT_CALL(*vad, Analyze(Truly([this](rtc::ArrayView<const float> frame) {
return rtc::SafeEq(frame.size(),
rtc::CheckedDivExact(vad_sample_rate_hz(), 100));
}))).Times(1);
auto vad_wrapper = std::make_unique<VoiceActivityDetectorWrapper>(
kNoVadPeriodicReset, std::move(vad));
FrameWithView frame(input_sample_rate_hz());
vad_wrapper->Analyze(frame.view);
}

INSTANTIATE_TEST_SUITE_P(
GainController2VoiceActivityDetectorWrapper,
VadResamplingParametrization,
::testing::Combine(::testing::Values(8000, 16000, 44100, 48000),
::testing::Values(6000, 8000, 12000, 16000, 24000)));

} // namespace
} // namespace webrtc