AGC2 add an interface for the noise level estimator

Done in preparation for the child CL which adds an alternative
implementation.

Bug: webrtc:7494
Change-Id: I4963376afc917eae434a0d0ccee18f21880eefe0
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/214125
Reviewed-by: Jakob Ivarsson <jakobi@webrtc.org>
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#33646}
This commit is contained in:
Alessio Bazzica 2021-04-07 14:57:40 +02:00 committed by Commit Bot
parent c335b0e63b
commit 11bd143974
5 changed files with 109 additions and 94 deletions

View File

@ -58,7 +58,7 @@ AdaptiveAgc::AdaptiveAgc(ApmDataDumper* apm_data_dumper)
kMaxGainChangePerSecondDb,
kMaxOutputNoiseLevelDbfs),
apm_data_dumper_(apm_data_dumper),
noise_level_estimator_(apm_data_dumper) {
noise_level_estimator_(CreateNoiseLevelEstimator(apm_data_dumper)) {
RTC_DCHECK(apm_data_dumper);
}
@ -80,7 +80,7 @@ AdaptiveAgc::AdaptiveAgc(ApmDataDumper* apm_data_dumper,
config.adaptive_digital.max_gain_change_db_per_second,
config.adaptive_digital.max_output_noise_level_dbfs),
apm_data_dumper_(apm_data_dumper),
noise_level_estimator_(apm_data_dumper) {
noise_level_estimator_(CreateNoiseLevelEstimator(apm_data_dumper)) {
RTC_DCHECK(apm_data_dumper);
if (!config.adaptive_digital.use_saturation_protector) {
RTC_LOG(LS_WARNING) << "The saturation protector cannot be disabled.";
@ -94,7 +94,7 @@ void AdaptiveAgc::Process(AudioFrameView<float> frame, float limiter_envelope) {
info.vad_result = vad_.AnalyzeFrame(frame);
speech_level_estimator_.Update(info.vad_result);
info.input_level_dbfs = speech_level_estimator_.level_dbfs();
info.input_noise_level_dbfs = noise_level_estimator_.Analyze(frame);
info.input_noise_level_dbfs = noise_level_estimator_->Analyze(frame);
info.limiter_envelope_dbfs =
limiter_envelope > 0 ? FloatS16ToDbfs(limiter_envelope) : -90.0f;
info.estimate_is_confident = speech_level_estimator_.IsConfident();

View File

@ -11,6 +11,8 @@
#ifndef MODULES_AUDIO_PROCESSING_AGC2_ADAPTIVE_AGC_H_
#define MODULES_AUDIO_PROCESSING_AGC2_ADAPTIVE_AGC_H_
#include <memory>
#include "modules/audio_processing/agc2/adaptive_digital_gain_applier.h"
#include "modules/audio_processing/agc2/adaptive_mode_level_estimator.h"
#include "modules/audio_processing/agc2/noise_level_estimator.h"
@ -42,7 +44,7 @@ class AdaptiveAgc {
VadLevelAnalyzer vad_;
AdaptiveDigitalGainApplier gain_applier_;
ApmDataDumper* const apm_data_dumper_;
NoiseLevelEstimator noise_level_estimator_;
std::unique_ptr<NoiseLevelEstimator> noise_level_estimator_;
};
} // namespace webrtc

View File

@ -18,11 +18,11 @@
#include "api/array_view.h"
#include "common_audio/include/audio_util.h"
#include "modules/audio_processing/agc2/signal_classifier.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"
#include "rtc_base/checks.h"
namespace webrtc {
namespace {
constexpr int kFramesPerSecond = 100;
@ -41,86 +41,106 @@ float EnergyToDbfs(float signal_energy, size_t num_samples) {
const float rms = std::sqrt(signal_energy / num_samples);
return FloatS16ToDbfs(rms);
}
} // namespace
NoiseLevelEstimator::NoiseLevelEstimator(ApmDataDumper* data_dumper)
: data_dumper_(data_dumper), signal_classifier_(data_dumper) {
Initialize(48000);
}
NoiseLevelEstimator::~NoiseLevelEstimator() {}
void NoiseLevelEstimator::Initialize(int sample_rate_hz) {
sample_rate_hz_ = sample_rate_hz;
noise_energy_ = 1.0f;
first_update_ = true;
min_noise_energy_ = sample_rate_hz * 2.0f * 2.0f / kFramesPerSecond;
noise_energy_hold_counter_ = 0;
signal_classifier_.Initialize(sample_rate_hz);
}
float NoiseLevelEstimator::Analyze(const AudioFrameView<const float>& frame) {
data_dumper_->DumpRaw("agc2_noise_level_estimator_hold_counter",
noise_energy_hold_counter_);
const int sample_rate_hz =
static_cast<int>(frame.samples_per_channel() * kFramesPerSecond);
if (sample_rate_hz != sample_rate_hz_) {
Initialize(sample_rate_hz);
}
const float frame_energy = FrameEnergy(frame);
if (frame_energy <= 0.f) {
RTC_DCHECK_GE(frame_energy, 0.f);
data_dumper_->DumpRaw("agc2_noise_level_estimator_signal_type", -1);
return EnergyToDbfs(noise_energy_, frame.samples_per_channel());
class NoiseLevelEstimatorImpl : public NoiseLevelEstimator {
public:
NoiseLevelEstimatorImpl(ApmDataDumper* data_dumper)
: data_dumper_(data_dumper), signal_classifier_(data_dumper) {
Initialize(48000);
}
NoiseLevelEstimatorImpl(const NoiseLevelEstimatorImpl&) = delete;
NoiseLevelEstimatorImpl& operator=(const NoiseLevelEstimatorImpl&) = delete;
~NoiseLevelEstimatorImpl() = default;
if (first_update_) {
// Initialize the noise energy to the frame energy.
first_update_ = false;
data_dumper_->DumpRaw("agc2_noise_level_estimator_signal_type", -1);
noise_energy_ = std::max(frame_energy, min_noise_energy_);
return EnergyToDbfs(noise_energy_, frame.samples_per_channel());
}
float Analyze(const AudioFrameView<const float>& frame) {
data_dumper_->DumpRaw("agc2_noise_level_estimator_hold_counter",
noise_energy_hold_counter_);
const int sample_rate_hz =
static_cast<int>(frame.samples_per_channel() * kFramesPerSecond);
if (sample_rate_hz != sample_rate_hz_) {
Initialize(sample_rate_hz);
}
const float frame_energy = FrameEnergy(frame);
if (frame_energy <= 0.f) {
RTC_DCHECK_GE(frame_energy, 0.f);
data_dumper_->DumpRaw("agc2_noise_level_estimator_signal_type", -1);
return EnergyToDbfs(noise_energy_, frame.samples_per_channel());
}
const SignalClassifier::SignalType signal_type =
signal_classifier_.Analyze(frame.channel(0));
data_dumper_->DumpRaw("agc2_noise_level_estimator_signal_type",
static_cast<int>(signal_type));
if (first_update_) {
// Initialize the noise energy to the frame energy.
first_update_ = false;
data_dumper_->DumpRaw("agc2_noise_level_estimator_signal_type", -1);
noise_energy_ = std::max(frame_energy, min_noise_energy_);
return EnergyToDbfs(noise_energy_, frame.samples_per_channel());
}
// Update the noise estimate in a minimum statistics-type manner.
if (signal_type == SignalClassifier::SignalType::kStationary) {
if (frame_energy > noise_energy_) {
// Leak the estimate upwards towards the frame energy if no recent
// downward update.
noise_energy_hold_counter_ = std::max(noise_energy_hold_counter_ - 1, 0);
const SignalClassifier::SignalType signal_type =
signal_classifier_.Analyze(frame.channel(0));
data_dumper_->DumpRaw("agc2_noise_level_estimator_signal_type",
static_cast<int>(signal_type));
if (noise_energy_hold_counter_ == 0) {
constexpr float kMaxNoiseEnergyFactor = 1.01f;
// Update the noise estimate in a minimum statistics-type manner.
if (signal_type == SignalClassifier::SignalType::kStationary) {
if (frame_energy > noise_energy_) {
// Leak the estimate upwards towards the frame energy if no recent
// downward update.
noise_energy_hold_counter_ =
std::max(noise_energy_hold_counter_ - 1, 0);
if (noise_energy_hold_counter_ == 0) {
constexpr float kMaxNoiseEnergyFactor = 1.01f;
noise_energy_ =
std::min(noise_energy_ * kMaxNoiseEnergyFactor, frame_energy);
}
} else {
// Update smoothly downwards with a limited maximum update magnitude.
constexpr float kMinNoiseEnergyFactor = 0.9f;
constexpr float kNoiseEnergyDeltaFactor = 0.05f;
noise_energy_ =
std::min(noise_energy_ * kMaxNoiseEnergyFactor, frame_energy);
std::max(noise_energy_ * kMinNoiseEnergyFactor,
noise_energy_ - kNoiseEnergyDeltaFactor *
(noise_energy_ - frame_energy));
// Prevent an energy increase for the next 10 seconds.
constexpr int kNumFramesToEnergyIncreaseAllowed = 1000;
noise_energy_hold_counter_ = kNumFramesToEnergyIncreaseAllowed;
}
} else {
// Update smoothly downwards with a limited maximum update magnitude.
constexpr float kMinNoiseEnergyFactor = 0.9f;
constexpr float kNoiseEnergyDeltaFactor = 0.05f;
noise_energy_ =
std::max(noise_energy_ * kMinNoiseEnergyFactor,
noise_energy_ - kNoiseEnergyDeltaFactor *
(noise_energy_ - frame_energy));
// Prevent an energy increase for the next 10 seconds.
constexpr int kNumFramesToEnergyIncreaseAllowed = 1000;
noise_energy_hold_counter_ = kNumFramesToEnergyIncreaseAllowed;
// TODO(bugs.webrtc.org/7494): Remove to not forget the estimated level.
// For a non-stationary signal, leak the estimate downwards in order to
// avoid estimate locking due to incorrect signal classification.
noise_energy_ = noise_energy_ * 0.99f;
}
} else {
// TODO(bugs.webrtc.org/7494): Remove to not forget the estimated level.
// For a non-stationary signal, leak the estimate downwards in order to
// avoid estimate locking due to incorrect signal classification.
noise_energy_ = noise_energy_ * 0.99f;
// Ensure a minimum of the estimate.
noise_energy_ = std::max(noise_energy_, min_noise_energy_);
return EnergyToDbfs(noise_energy_, frame.samples_per_channel());
}
// Ensure a minimum of the estimate.
noise_energy_ = std::max(noise_energy_, min_noise_energy_);
return EnergyToDbfs(noise_energy_, frame.samples_per_channel());
private:
void Initialize(int sample_rate_hz) {
sample_rate_hz_ = sample_rate_hz;
noise_energy_ = 1.0f;
first_update_ = true;
min_noise_energy_ = sample_rate_hz * 2.0f * 2.0f / kFramesPerSecond;
noise_energy_hold_counter_ = 0;
signal_classifier_.Initialize(sample_rate_hz);
}
ApmDataDumper* const data_dumper_;
int sample_rate_hz_;
float min_noise_energy_;
bool first_update_;
float noise_energy_;
int noise_energy_hold_counter_;
SignalClassifier signal_classifier_;
};
} // namespace
std::unique_ptr<NoiseLevelEstimator> CreateNoiseLevelEstimator(
ApmDataDumper* data_dumper) {
return std::make_unique<NoiseLevelEstimatorImpl>(data_dumper);
}
} // namespace webrtc

View File

@ -11,33 +11,26 @@
#ifndef MODULES_AUDIO_PROCESSING_AGC2_NOISE_LEVEL_ESTIMATOR_H_
#define MODULES_AUDIO_PROCESSING_AGC2_NOISE_LEVEL_ESTIMATOR_H_
#include "modules/audio_processing/agc2/signal_classifier.h"
#include <memory>
#include "modules/audio_processing/include/audio_frame_view.h"
namespace webrtc {
class ApmDataDumper;
// Noise level estimator interface.
class NoiseLevelEstimator {
public:
NoiseLevelEstimator(ApmDataDumper* data_dumper);
NoiseLevelEstimator(const NoiseLevelEstimator&) = delete;
NoiseLevelEstimator& operator=(const NoiseLevelEstimator&) = delete;
~NoiseLevelEstimator();
// Returns the estimated noise level in dBFS.
float Analyze(const AudioFrameView<const float>& frame);
private:
void Initialize(int sample_rate_hz);
ApmDataDumper* const data_dumper_;
int sample_rate_hz_;
float min_noise_energy_;
bool first_update_;
float noise_energy_;
int noise_energy_hold_counter_;
SignalClassifier signal_classifier_;
virtual ~NoiseLevelEstimator() = default;
// Analyzes a 10 ms `frame`, updates the noise level estimation and returns
// the value for the latter in dBFS.
virtual float Analyze(const AudioFrameView<const float>& frame) = 0;
};
// Creates a noise level estimator based on stationarity detection.
std::unique_ptr<NoiseLevelEstimator> CreateNoiseLevelEstimator(
ApmDataDumper* data_dumper);
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_NOISE_LEVEL_ESTIMATOR_H_

View File

@ -31,7 +31,7 @@ constexpr int kFramesPerSecond = 100;
float RunEstimator(rtc::FunctionView<float()> sample_generator,
int sample_rate_hz) {
ApmDataDumper data_dumper(0);
NoiseLevelEstimator estimator(&data_dumper);
auto estimator = CreateNoiseLevelEstimator(&data_dumper);
const int samples_per_channel =
rtc::CheckedDivExact(sample_rate_hz, kFramesPerSecond);
VectorFloatFrame signal(1, samples_per_channel, 0.0f);
@ -41,9 +41,9 @@ float RunEstimator(rtc::FunctionView<float()> sample_generator,
for (int j = 0; j < samples_per_channel; ++j) {
frame_view.channel(0)[j] = sample_generator();
}
estimator.Analyze(frame_view);
estimator->Analyze(frame_view);
}
return estimator.Analyze(signal.float_frame_view());
return estimator->Analyze(signal.float_frame_view());
}
class NoiseEstimatorParametrization : public ::testing::TestWithParam<int> {