AEC3: Add multichannel support in the ERLE estimation

Bug: webrtc:10913
Change-Id: I1667146d38dc99d099b140f47cd774a7f203b4f0
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/157047
Reviewed-by: Sam Zackrisson <saza@webrtc.org>
Commit-Queue: Per Åhgren <peah@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#29521}
This commit is contained in:
Per Åhgren 2019-10-17 14:40:54 +02:00 committed by Commit Bot
parent db8df17650
commit 785d4c40ca
14 changed files with 838 additions and 625 deletions

View File

@ -44,7 +44,7 @@ void ComputeAvgRenderReverb(
std::array<float, kFftLengthBy2Plus1> X2_data;
rtc::ArrayView<const float> X2;
if (num_render_channels > 1) {
auto sum_channels =
auto average_channels =
[](size_t num_render_channels,
const std::vector<std::vector<float>>& spectrum_band_0,
rtc::ArrayView<float, kFftLengthBy2Plus1> render_power) {
@ -55,14 +55,18 @@ void ComputeAvgRenderReverb(
render_power[k] += spectrum_band_0[ch][k];
}
}
const float normalizer = 1.f / num_render_channels;
for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) {
render_power[k] *= normalizer;
}
};
sum_channels(num_render_channels, spectrum_buffer.buffer[idx_past],
X2_data);
average_channels(num_render_channels, spectrum_buffer.buffer[idx_past],
X2_data);
reverb_model->UpdateReverbNoFreqShaping(
X2_data, /*power_spectrum_scaling=*/1.0f, reverb_decay);
sum_channels(num_render_channels, spectrum_buffer.buffer[idx_at_delay],
X2_data);
average_channels(num_render_channels, spectrum_buffer.buffer[idx_at_delay],
X2_data);
X2 = X2_data;
} else {
reverb_model->UpdateReverbNoFreqShaping(
@ -110,17 +114,18 @@ AecState::AecState(const EchoCanceller3Config& config,
: data_dumper_(
new ApmDataDumper(rtc::AtomicOps::Increment(&instance_count_))),
config_(config),
num_capture_channels_(num_capture_channels),
initial_state_(config_),
delay_state_(config_, num_capture_channels),
delay_state_(config_, num_capture_channels_),
transparent_state_(config_),
filter_quality_state_(config_, num_capture_channels),
filter_quality_state_(config_, num_capture_channels_),
erl_estimator_(2 * kNumBlocksPerSecond),
erle_estimator_(2 * kNumBlocksPerSecond, config_, num_capture_channels),
filter_analyzer_(config_, num_capture_channels),
erle_estimator_(2 * kNumBlocksPerSecond, config_, num_capture_channels_),
filter_analyzer_(config_, num_capture_channels_),
echo_audibility_(
config_.echo_audibility.use_stationarity_properties_at_init),
reverb_model_estimator_(config_, num_capture_channels),
subtractor_output_analyzers_(num_capture_channels) {}
reverb_model_estimator_(config_, num_capture_channels_),
subtractor_output_analyzer_(num_capture_channels_) {}
AecState::~AecState() = default;
@ -147,9 +152,7 @@ void AecState::HandleEchoPathChange(
} else if (echo_path_variability.gain_change) {
erle_estimator_.Reset(false);
}
for (auto& analyzer : subtractor_output_analyzers_) {
analyzer.HandleEchoPathChange();
}
subtractor_output_analyzer_.HandleEchoPathChange();
}
void AecState::Update(
@ -161,25 +164,19 @@ void AecState::Update(
rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> E2_main,
rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> Y2,
rtc::ArrayView<const SubtractorOutput> subtractor_output) {
const size_t num_capture_channels = subtractor_output_analyzers_.size();
RTC_DCHECK_EQ(num_capture_channels, E2_main.size());
RTC_DCHECK_EQ(num_capture_channels, Y2.size());
RTC_DCHECK_EQ(num_capture_channels, subtractor_output.size());
RTC_DCHECK_EQ(num_capture_channels, subtractor_output_analyzers_.size());
RTC_DCHECK_EQ(num_capture_channels,
RTC_DCHECK_EQ(num_capture_channels_, Y2.size());
RTC_DCHECK_EQ(num_capture_channels_, subtractor_output.size());
RTC_DCHECK_EQ(num_capture_channels_,
adaptive_filter_frequency_responses.size());
RTC_DCHECK_EQ(num_capture_channels, adaptive_filter_impulse_responses.size());
RTC_DCHECK_EQ(num_capture_channels_,
adaptive_filter_impulse_responses.size());
// Analyze the filter outputs and filters.
bool any_filter_converged = false;
bool all_filters_diverged = true;
for (size_t ch = 0; ch < subtractor_output.size(); ++ch) {
subtractor_output_analyzers_[ch].Update(subtractor_output[ch]);
any_filter_converged = any_filter_converged ||
subtractor_output_analyzers_[ch].ConvergedFilter();
all_filters_diverged = all_filters_diverged &&
subtractor_output_analyzers_[ch].DivergedFilter();
}
bool any_filter_converged;
bool all_filters_diverged;
subtractor_output_analyzer_.Update(subtractor_output, &any_filter_converged,
&all_filters_diverged);
bool any_filter_consistent;
float max_echo_path_gain;
filter_analyzer_.Update(adaptive_filter_impulse_responses, render_buffer,
@ -229,16 +226,15 @@ void AecState::Update(
erle_estimator_.Reset(false);
}
erle_estimator_.Update(render_buffer, adaptive_filter_frequency_responses[0],
avg_render_spectrum_with_reverb, Y2[0], E2_main[0],
subtractor_output_analyzers_[0].ConvergedFilter(),
config_.erle.onset_detection);
erle_estimator_.Update(render_buffer, adaptive_filter_frequency_responses,
avg_render_spectrum_with_reverb, Y2, E2_main,
subtractor_output_analyzer_.ConvergedFilters());
// TODO(bugs.webrtc.org/10913): Take all channels into account.
const auto& X2 =
render_buffer.Spectrum(delay_state_.MinDirectPathFilterDelay(),
/*channel=*/0);
erl_estimator_.Update(subtractor_output_analyzers_[0].ConvergedFilter(), X2,
erl_estimator_.Update(subtractor_output_analyzer_.ConvergedFilters()[0], X2,
Y2[0]);
// Detect and flag echo saturation.

View File

@ -150,6 +150,7 @@ class AecState {
static int instance_count_;
std::unique_ptr<ApmDataDumper> data_dumper_;
const EchoCanceller3Config config_;
const size_t num_capture_channels_;
// Class for controlling the transition from the intial state, which in turn
// controls when the filter parameters for the initial state should be used.
@ -314,7 +315,7 @@ class AecState {
EchoAudibility echo_audibility_;
ReverbModelEstimator reverb_model_estimator_;
ReverbModel avg_render_reverb_;
std::vector<SubtractorOutputAnalyzer> subtractor_output_analyzers_;
SubtractorOutputAnalyzer subtractor_output_analyzer_;
};
} // namespace webrtc

View File

@ -15,14 +15,17 @@
namespace webrtc {
ErleEstimator::ErleEstimator(size_t startup_phase_length_blocks_,
ErleEstimator::ErleEstimator(size_t startup_phase_length_blocks,
const EchoCanceller3Config& config,
size_t num_capture_channels)
: startup_phase_length_blocks__(startup_phase_length_blocks_),
use_signal_dependent_erle_(config.erle.num_sections > 1),
: startup_phase_length_blocks_(startup_phase_length_blocks),
fullband_erle_estimator_(config.erle, num_capture_channels),
subband_erle_estimator_(config, num_capture_channels),
signal_dependent_erle_estimator_(config, num_capture_channels) {
subband_erle_estimator_(config, num_capture_channels) {
if (config.erle.num_sections > 1) {
signal_dependent_erle_estimator_ =
std::make_unique<SignalDependentErleEstimator>(config,
num_capture_channels);
}
Reset(true);
}
@ -31,7 +34,9 @@ ErleEstimator::~ErleEstimator() = default;
void ErleEstimator::Reset(bool delay_change) {
fullband_erle_estimator_.Reset();
subband_erle_estimator_.Reset();
signal_dependent_erle_estimator_.Reset();
if (signal_dependent_erle_estimator_) {
signal_dependent_erle_estimator_->Reset();
}
if (delay_change) {
blocks_since_reset_ = 0;
}
@ -39,41 +44,43 @@ void ErleEstimator::Reset(bool delay_change) {
void ErleEstimator::Update(
const RenderBuffer& render_buffer,
const std::vector<std::array<float, kFftLengthBy2Plus1>>&
filter_frequency_response,
rtc::ArrayView<const float> reverb_render_spectrum,
rtc::ArrayView<const float> capture_spectrum,
rtc::ArrayView<const float> subtractor_spectrum,
bool converged_filter,
bool onset_detection) {
RTC_DCHECK_EQ(kFftLengthBy2Plus1, reverb_render_spectrum.size());
RTC_DCHECK_EQ(kFftLengthBy2Plus1, capture_spectrum.size());
RTC_DCHECK_EQ(kFftLengthBy2Plus1, subtractor_spectrum.size());
const auto& X2_reverb = reverb_render_spectrum;
const auto& Y2 = capture_spectrum;
const auto& E2 = subtractor_spectrum;
rtc::ArrayView<const std::vector<std::array<float, kFftLengthBy2Plus1>>>
filter_frequency_responses,
rtc::ArrayView<const float, kFftLengthBy2Plus1>
avg_render_spectrum_with_reverb,
rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> capture_spectra,
rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>>
subtractor_spectra,
const std::vector<bool>& converged_filters) {
RTC_DCHECK_EQ(subband_erle_estimator_.Erle().size(), capture_spectra.size());
RTC_DCHECK_EQ(subband_erle_estimator_.Erle().size(),
subtractor_spectra.size());
const auto& X2_reverb = avg_render_spectrum_with_reverb;
const auto& Y2 = capture_spectra;
const auto& E2 = subtractor_spectra;
if (++blocks_since_reset_ < startup_phase_length_blocks__) {
if (++blocks_since_reset_ < startup_phase_length_blocks_) {
return;
}
subband_erle_estimator_.Update(X2_reverb, Y2, E2, converged_filter,
onset_detection);
subband_erle_estimator_.Update(X2_reverb, Y2, E2, converged_filters);
if (use_signal_dependent_erle_) {
signal_dependent_erle_estimator_.Update(
render_buffer, filter_frequency_response, X2_reverb, Y2, E2,
subband_erle_estimator_.Erle(), converged_filter);
if (signal_dependent_erle_estimator_) {
signal_dependent_erle_estimator_->Update(
render_buffer, filter_frequency_responses, X2_reverb, Y2, E2,
subband_erle_estimator_.Erle(), converged_filters);
}
fullband_erle_estimator_.Update(X2_reverb, Y2, E2, converged_filter);
fullband_erle_estimator_.Update(X2_reverb, Y2, E2, converged_filters);
}
void ErleEstimator::Dump(
const std::unique_ptr<ApmDataDumper>& data_dumper) const {
fullband_erle_estimator_.Dump(data_dumper);
subband_erle_estimator_.Dump(data_dumper);
signal_dependent_erle_estimator_.Dump(data_dumper);
if (signal_dependent_erle_estimator_) {
signal_dependent_erle_estimator_->Dump(data_dumper);
}
}
} // namespace webrtc

View File

@ -15,6 +15,7 @@
#include <array>
#include <memory>
#include <vector>
#include "absl/types/optional.h"
#include "api/array_view.h"
@ -32,7 +33,7 @@ namespace webrtc {
// and another one is done using the aggreation of energy over all the subbands.
class ErleEstimator {
public:
ErleEstimator(size_t startup_phase_length_blocks_,
ErleEstimator(size_t startup_phase_length_blocks,
const EchoCanceller3Config& config,
size_t num_capture_channels);
~ErleEstimator();
@ -41,24 +42,29 @@ class ErleEstimator {
void Reset(bool delay_change);
// Updates the ERLE estimates.
void Update(const RenderBuffer& render_buffer,
const std::vector<std::array<float, kFftLengthBy2Plus1>>&
filter_frequency_response,
rtc::ArrayView<const float> reverb_render_spectrum,
rtc::ArrayView<const float> capture_spectrum,
rtc::ArrayView<const float> subtractor_spectrum,
bool converged_filter,
bool onset_detection);
void Update(
const RenderBuffer& render_buffer,
rtc::ArrayView<const std::vector<std::array<float, kFftLengthBy2Plus1>>>
filter_frequency_responses,
rtc::ArrayView<const float, kFftLengthBy2Plus1>
avg_render_spectrum_with_reverb,
rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>>
capture_spectra,
rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>>
subtractor_spectra,
const std::vector<bool>& converged_filters);
// Returns the most recent subband ERLE estimates.
rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> Erle() const {
return use_signal_dependent_erle_ ? signal_dependent_erle_estimator_.Erle()
: subband_erle_estimator_.Erle();
return signal_dependent_erle_estimator_
? signal_dependent_erle_estimator_->Erle()
: subband_erle_estimator_.Erle();
}
// Returns the subband ERLE that are estimated during onsets. Used
// for logging/testing.
rtc::ArrayView<const float> ErleOnsets() const {
// Returns the subband ERLE that are estimated during onsets (only used for
// testing).
rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> ErleOnsets()
const {
return subband_erle_estimator_.ErleOnsets();
}
@ -80,11 +86,11 @@ class ErleEstimator {
void Dump(const std::unique_ptr<ApmDataDumper>& data_dumper) const;
private:
const size_t startup_phase_length_blocks__;
const bool use_signal_dependent_erle_;
const size_t startup_phase_length_blocks_;
FullBandErleEstimator fullband_erle_estimator_;
SubbandErleEstimator subband_erle_estimator_;
SignalDependentErleEstimator signal_dependent_erle_estimator_;
std::unique_ptr<SignalDependentErleEstimator>
signal_dependent_erle_estimator_;
size_t blocks_since_reset_ = 0;
};

View File

@ -27,21 +27,25 @@ constexpr float kTrueErle = 10.f;
constexpr float kTrueErleOnsets = 1.0f;
constexpr float kEchoPathGain = 3.f;
void VerifyErleBands(rtc::ArrayView<const float> erle,
float reference_lf,
float reference_hf) {
std::for_each(
erle.begin(), erle.begin() + kLowFrequencyLimit,
[reference_lf](float a) { EXPECT_NEAR(reference_lf, a, 0.001); });
std::for_each(
erle.begin() + kLowFrequencyLimit, erle.end(),
[reference_hf](float a) { EXPECT_NEAR(reference_hf, a, 0.001); });
void VerifyErleBands(
rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> erle,
float reference_lf,
float reference_hf) {
for (size_t ch = 0; ch < erle.size(); ++ch) {
std::for_each(
erle[ch].begin(), erle[ch].begin() + kLowFrequencyLimit,
[reference_lf](float a) { EXPECT_NEAR(reference_lf, a, 0.001); });
std::for_each(
erle[ch].begin() + kLowFrequencyLimit, erle[ch].end(),
[reference_hf](float a) { EXPECT_NEAR(reference_hf, a, 0.001); });
}
}
void VerifyErle(rtc::ArrayView<const float> erle,
float erle_time_domain,
float reference_lf,
float reference_hf) {
void VerifyErle(
rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> erle,
float erle_time_domain,
float reference_lf,
float reference_hf) {
VerifyErleBands(erle, reference_lf, reference_hf);
EXPECT_NEAR(reference_lf, erle_time_domain, 0.5);
}
@ -65,160 +69,210 @@ void FormFarendTimeFrame(std::vector<std::vector<std::vector<float>>>* x) {
}
void FormFarendFrame(const RenderBuffer& render_buffer,
float erle,
std::array<float, kFftLengthBy2Plus1>* X2,
std::array<float, kFftLengthBy2Plus1>* E2,
std::array<float, kFftLengthBy2Plus1>* Y2,
float erle) {
rtc::ArrayView<std::array<float, kFftLengthBy2Plus1>> E2,
rtc::ArrayView<std::array<float, kFftLengthBy2Plus1>> Y2) {
const auto& spectrum_buffer = render_buffer.GetSpectrumBuffer();
const auto& X2_from_buffer =
spectrum_buffer.buffer[spectrum_buffer.write][/*channel=*/0];
std::copy(X2_from_buffer.begin(), X2_from_buffer.end(), X2->begin());
std::transform(X2->begin(), X2->end(), Y2->begin(),
[](float a) { return a * kEchoPathGain * kEchoPathGain; });
std::transform(Y2->begin(), Y2->end(), E2->begin(),
[erle](float a) { return a / erle; });
const int num_render_channels = spectrum_buffer.buffer[0].size();
const int num_capture_channels = Y2.size();
} // namespace
void FormNearendFrame(std::vector<std::vector<std::vector<float>>>* x,
std::array<float, kFftLengthBy2Plus1>* X2,
std::array<float, kFftLengthBy2Plus1>* E2,
std::array<float, kFftLengthBy2Plus1>* Y2) {
for (size_t band = 0; band < x->size(); ++band) {
for (size_t channel = 0; channel < (*x)[band].size(); ++channel) {
std::fill((*x)[band][channel].begin(), (*x)[band][channel].end(), 0.f);
X2->fill(0.f);
Y2->fill(500.f * 1000.f * 1000.f);
E2->fill((*Y2)[0]);
X2->fill(0.f);
for (int ch = 0; ch < num_render_channels; ++ch) {
for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) {
(*X2)[k] += spectrum_buffer.buffer[spectrum_buffer.write][ch][k] /
num_render_channels;
}
}
for (int ch = 0; ch < num_capture_channels; ++ch) {
std::transform(X2->begin(), X2->end(), Y2[ch].begin(),
[](float a) { return a * kEchoPathGain * kEchoPathGain; });
std::transform(Y2[ch].begin(), Y2[ch].end(), E2[ch].begin(),
[erle](float a) { return a / erle; });
}
}
void GetFilterFreq(std::vector<std::array<float, kFftLengthBy2Plus1>>&
filter_frequency_response,
size_t delay_headroom_samples) {
const size_t delay_headroom_blocks = delay_headroom_samples / kBlockSize;
for (auto& block_freq_resp : filter_frequency_response) {
block_freq_resp.fill(0.f);
void FormNearendFrame(
std::vector<std::vector<std::vector<float>>>* x,
std::array<float, kFftLengthBy2Plus1>* X2,
rtc::ArrayView<std::array<float, kFftLengthBy2Plus1>> E2,
rtc::ArrayView<std::array<float, kFftLengthBy2Plus1>> Y2) {
for (size_t band = 0; band < x->size(); ++band) {
for (size_t ch = 0; ch < (*x)[band].size(); ++ch) {
std::fill((*x)[band][ch].begin(), (*x)[band][ch].end(), 0.f);
}
}
for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) {
filter_frequency_response[delay_headroom_blocks][k] = kEchoPathGain;
X2->fill(0.f);
for (size_t ch = 0; ch < Y2.size(); ++ch) {
Y2[ch].fill(500.f * 1000.f * 1000.f);
E2[ch].fill(Y2[ch][0]);
}
}
void GetFilterFreq(
size_t delay_headroom_samples,
rtc::ArrayView<std::vector<std::array<float, kFftLengthBy2Plus1>>>
filter_frequency_response) {
const size_t delay_headroom_blocks = delay_headroom_samples / kBlockSize;
for (size_t ch = 0; ch < filter_frequency_response[0].size(); ++ch) {
for (auto& block_freq_resp : filter_frequency_response) {
block_freq_resp[ch].fill(0.f);
}
for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) {
filter_frequency_response[delay_headroom_blocks][ch][k] = kEchoPathGain;
}
}
}
} // namespace
TEST(ErleEstimator, VerifyErleIncreaseAndHold) {
std::array<float, kFftLengthBy2Plus1> X2;
std::array<float, kFftLengthBy2Plus1> E2;
std::array<float, kFftLengthBy2Plus1> Y2;
constexpr size_t kNumRenderChannels = 1;
constexpr size_t kNumCaptureChannels = 1;
constexpr int kSampleRateHz = 48000;
constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz);
EchoCanceller3Config config;
std::vector<std::vector<std::vector<float>>> x(
kNumBands, std::vector<std::vector<float>>(
kNumRenderChannels, std::vector<float>(kBlockSize, 0.f)));
std::vector<std::array<float, kFftLengthBy2Plus1>> filter_frequency_response(
config.filter.main.length_blocks);
std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
RenderDelayBuffer::Create(config, kSampleRateHz, kNumRenderChannels));
for (size_t num_render_channels : {1, 2, 4, 8}) {
for (size_t num_capture_channels : {1, 2, 4}) {
std::array<float, kFftLengthBy2Plus1> X2;
std::vector<std::array<float, kFftLengthBy2Plus1>> E2(
num_capture_channels);
std::vector<std::array<float, kFftLengthBy2Plus1>> Y2(
num_capture_channels);
std::vector<bool> converged_filters(num_capture_channels, true);
GetFilterFreq(filter_frequency_response, config.delay.delay_headroom_samples);
EchoCanceller3Config config;
config.erle.onset_detection = true;
ErleEstimator estimator(0, config, kNumCaptureChannels);
std::vector<std::vector<std::vector<float>>> x(
kNumBands,
std::vector<std::vector<float>>(num_render_channels,
std::vector<float>(kBlockSize, 0.f)));
std::vector<std::vector<std::array<float, kFftLengthBy2Plus1>>>
filter_frequency_response(
config.filter.main.length_blocks,
std::vector<std::array<float, kFftLengthBy2Plus1>>(
num_capture_channels));
std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
RenderDelayBuffer::Create(config, kSampleRateHz,
num_render_channels));
FormFarendTimeFrame(&x);
render_delay_buffer->Insert(x);
render_delay_buffer->PrepareCaptureProcessing();
// Verifies that the ERLE estimate is properly increased to higher values.
FormFarendFrame(*render_delay_buffer->GetRenderBuffer(), &X2, &E2, &Y2,
kTrueErle);
for (size_t k = 0; k < 200; ++k) {
render_delay_buffer->Insert(x);
render_delay_buffer->PrepareCaptureProcessing();
estimator.Update(*render_delay_buffer->GetRenderBuffer(),
filter_frequency_response, X2, Y2, E2, true, true);
GetFilterFreq(config.delay.delay_headroom_samples,
filter_frequency_response);
ErleEstimator estimator(0, config, num_capture_channels);
FormFarendTimeFrame(&x);
render_delay_buffer->Insert(x);
render_delay_buffer->PrepareCaptureProcessing();
// Verifies that the ERLE estimate is properly increased to higher values.
FormFarendFrame(*render_delay_buffer->GetRenderBuffer(), kTrueErle, &X2,
E2, Y2);
for (size_t k = 0; k < 200; ++k) {
render_delay_buffer->Insert(x);
render_delay_buffer->PrepareCaptureProcessing();
estimator.Update(*render_delay_buffer->GetRenderBuffer(),
filter_frequency_response, X2, Y2, E2,
converged_filters);
}
VerifyErle(estimator.Erle(), std::pow(2.f, estimator.FullbandErleLog2()),
config.erle.max_l, config.erle.max_h);
FormNearendFrame(&x, &X2, E2, Y2);
// Verifies that the ERLE is not immediately decreased during nearend
// activity.
for (size_t k = 0; k < 50; ++k) {
render_delay_buffer->Insert(x);
render_delay_buffer->PrepareCaptureProcessing();
estimator.Update(*render_delay_buffer->GetRenderBuffer(),
filter_frequency_response, X2, Y2, E2,
converged_filters);
}
VerifyErle(estimator.Erle(), std::pow(2.f, estimator.FullbandErleLog2()),
config.erle.max_l, config.erle.max_h);
}
}
VerifyErle(estimator.Erle()[0], std::pow(2.f, estimator.FullbandErleLog2()),
config.erle.max_l, config.erle.max_h);
FormNearendFrame(&x, &X2, &E2, &Y2);
// Verifies that the ERLE is not immediately decreased during nearend
// activity.
for (size_t k = 0; k < 50; ++k) {
render_delay_buffer->Insert(x);
render_delay_buffer->PrepareCaptureProcessing();
estimator.Update(*render_delay_buffer->GetRenderBuffer(),
filter_frequency_response, X2, Y2, E2, true, true);
}
VerifyErle(estimator.Erle()[0], std::pow(2.f, estimator.FullbandErleLog2()),
config.erle.max_l, config.erle.max_h);
}
TEST(ErleEstimator, VerifyErleTrackingOnOnsets) {
constexpr size_t kNumRenderChannels = 1;
constexpr size_t kNumCaptureChannels = 1;
constexpr int kSampleRateHz = 48000;
constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz);
std::array<float, kFftLengthBy2Plus1> X2;
std::array<float, kFftLengthBy2Plus1> E2;
std::array<float, kFftLengthBy2Plus1> Y2;
EchoCanceller3Config config;
std::vector<std::vector<std::vector<float>>> x(
kNumBands, std::vector<std::vector<float>>(
kNumRenderChannels, std::vector<float>(kBlockSize, 0.f)));
std::vector<std::array<float, kFftLengthBy2Plus1>> filter_frequency_response(
config.filter.main.length_blocks);
std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
RenderDelayBuffer::Create(config, kSampleRateHz, kNumRenderChannels));
for (size_t num_render_channels : {1, 2, 4, 8}) {
for (size_t num_capture_channels : {1, 2, 4}) {
std::array<float, kFftLengthBy2Plus1> X2;
std::vector<std::array<float, kFftLengthBy2Plus1>> E2(
num_capture_channels);
std::vector<std::array<float, kFftLengthBy2Plus1>> Y2(
num_capture_channels);
std::vector<bool> converged_filters(num_capture_channels, true);
EchoCanceller3Config config;
config.erle.onset_detection = true;
std::vector<std::vector<std::vector<float>>> x(
kNumBands,
std::vector<std::vector<float>>(num_render_channels,
std::vector<float>(kBlockSize, 0.f)));
std::vector<std::vector<std::array<float, kFftLengthBy2Plus1>>>
filter_frequency_response(
config.filter.main.length_blocks,
std::vector<std::array<float, kFftLengthBy2Plus1>>(
num_capture_channels));
std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
RenderDelayBuffer::Create(config, kSampleRateHz,
num_render_channels));
GetFilterFreq(filter_frequency_response, config.delay.delay_headroom_samples);
GetFilterFreq(config.delay.delay_headroom_samples,
filter_frequency_response);
ErleEstimator estimator(0, config, kNumCaptureChannels);
ErleEstimator estimator(/*startup_phase_length_blocks=*/0, config,
num_capture_channels);
FormFarendTimeFrame(&x);
render_delay_buffer->Insert(x);
render_delay_buffer->PrepareCaptureProcessing();
for (size_t burst = 0; burst < 20; ++burst) {
FormFarendFrame(*render_delay_buffer->GetRenderBuffer(), &X2, &E2, &Y2,
kTrueErleOnsets);
for (size_t k = 0; k < 10; ++k) {
FormFarendTimeFrame(&x);
render_delay_buffer->Insert(x);
render_delay_buffer->PrepareCaptureProcessing();
estimator.Update(*render_delay_buffer->GetRenderBuffer(),
filter_frequency_response, X2, Y2, E2, true, true);
}
FormFarendFrame(*render_delay_buffer->GetRenderBuffer(), &X2, &E2, &Y2,
kTrueErle);
for (size_t k = 0; k < 200; ++k) {
render_delay_buffer->Insert(x);
render_delay_buffer->PrepareCaptureProcessing();
estimator.Update(*render_delay_buffer->GetRenderBuffer(),
filter_frequency_response, X2, Y2, E2, true, true);
}
FormNearendFrame(&x, &X2, &E2, &Y2);
for (size_t k = 0; k < 300; ++k) {
render_delay_buffer->Insert(x);
render_delay_buffer->PrepareCaptureProcessing();
estimator.Update(*render_delay_buffer->GetRenderBuffer(),
filter_frequency_response, X2, Y2, E2, true, true);
for (size_t burst = 0; burst < 20; ++burst) {
FormFarendFrame(*render_delay_buffer->GetRenderBuffer(),
kTrueErleOnsets, &X2, E2, Y2);
for (size_t k = 0; k < 10; ++k) {
render_delay_buffer->Insert(x);
render_delay_buffer->PrepareCaptureProcessing();
estimator.Update(*render_delay_buffer->GetRenderBuffer(),
filter_frequency_response, X2, Y2, E2,
converged_filters);
}
FormFarendFrame(*render_delay_buffer->GetRenderBuffer(), kTrueErle, &X2,
E2, Y2);
for (size_t k = 0; k < 200; ++k) {
render_delay_buffer->Insert(x);
render_delay_buffer->PrepareCaptureProcessing();
estimator.Update(*render_delay_buffer->GetRenderBuffer(),
filter_frequency_response, X2, Y2, E2,
converged_filters);
}
FormNearendFrame(&x, &X2, E2, Y2);
for (size_t k = 0; k < 300; ++k) {
render_delay_buffer->Insert(x);
render_delay_buffer->PrepareCaptureProcessing();
estimator.Update(*render_delay_buffer->GetRenderBuffer(),
filter_frequency_response, X2, Y2, E2,
converged_filters);
}
}
VerifyErleBands(estimator.ErleOnsets(), config.erle.min, config.erle.min);
FormNearendFrame(&x, &X2, E2, Y2);
for (size_t k = 0; k < 1000; k++) {
estimator.Update(*render_delay_buffer->GetRenderBuffer(),
filter_frequency_response, X2, Y2, E2,
converged_filters);
}
// Verifies that during ne activity, Erle converges to the Erle for
// onsets.
VerifyErle(estimator.Erle(), std::pow(2.f, estimator.FullbandErleLog2()),
config.erle.min, config.erle.min);
}
}
VerifyErleBands(estimator.ErleOnsets(), config.erle.min, config.erle.min);
FormNearendFrame(&x, &X2, &E2, &Y2);
for (size_t k = 0; k < 1000; k++) {
estimator.Update(*render_delay_buffer->GetRenderBuffer(),
filter_frequency_response, X2, Y2, E2, true, true);
}
// Verifies that during ne activity, Erle converges to the Erle for onsets.
VerifyErle(estimator.Erle()[0], std::pow(2.f, estimator.FullbandErleLog2()),
config.erle.min, config.erle.min);
}
} // namespace webrtc

View File

@ -35,7 +35,9 @@ FullBandErleEstimator::FullBandErleEstimator(
size_t num_capture_channels)
: min_erle_log2_(FastApproxLog2f(config.min + kEpsilon)),
max_erle_lf_log2(FastApproxLog2f(config.max_l + kEpsilon)),
instantaneous_erle_(config),
hold_counters_time_domain_(num_capture_channels, 0),
erle_time_domain_log2_(num_capture_channels, min_erle_log2_),
instantaneous_erle_(num_capture_channels, ErleInstantaneous(config)),
linear_filters_qualities_(num_capture_channels) {
Reset();
}
@ -43,39 +45,49 @@ FullBandErleEstimator::FullBandErleEstimator(
FullBandErleEstimator::~FullBandErleEstimator() = default;
void FullBandErleEstimator::Reset() {
instantaneous_erle_.Reset();
for (auto& instantaneous_erle_ch : instantaneous_erle_) {
instantaneous_erle_ch.Reset();
}
UpdateQualityEstimates();
erle_time_domain_log2_ = min_erle_log2_;
hold_counter_time_domain_ = 0;
std::fill(erle_time_domain_log2_.begin(), erle_time_domain_log2_.end(),
min_erle_log2_);
std::fill(hold_counters_time_domain_.begin(),
hold_counters_time_domain_.end(), 0);
}
void FullBandErleEstimator::Update(rtc::ArrayView<const float> X2,
rtc::ArrayView<const float> Y2,
rtc::ArrayView<const float> E2,
bool converged_filter) {
if (converged_filter) {
// Computes the fullband ERLE.
const float X2_sum = std::accumulate(X2.begin(), X2.end(), 0.0f);
if (X2_sum > kX2BandEnergyThreshold * X2.size()) {
const float Y2_sum = std::accumulate(Y2.begin(), Y2.end(), 0.0f);
const float E2_sum = std::accumulate(E2.begin(), E2.end(), 0.0f);
if (instantaneous_erle_.Update(Y2_sum, E2_sum)) {
hold_counter_time_domain_ = kBlocksToHoldErle;
erle_time_domain_log2_ +=
0.1f * ((instantaneous_erle_.GetInstErleLog2().value()) -
erle_time_domain_log2_);
erle_time_domain_log2_ = rtc::SafeClamp(
erle_time_domain_log2_, min_erle_log2_, max_erle_lf_log2);
void FullBandErleEstimator::Update(
rtc::ArrayView<const float> X2,
rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> Y2,
rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> E2,
const std::vector<bool>& converged_filters) {
for (size_t ch = 0; ch < Y2.size(); ++ch) {
if (converged_filters[ch]) {
// Computes the fullband ERLE.
const float X2_sum = std::accumulate(X2.begin(), X2.end(), 0.0f);
if (X2_sum > kX2BandEnergyThreshold * X2.size()) {
const float Y2_sum =
std::accumulate(Y2[ch].begin(), Y2[ch].end(), 0.0f);
const float E2_sum =
std::accumulate(E2[ch].begin(), E2[ch].end(), 0.0f);
if (instantaneous_erle_[ch].Update(Y2_sum, E2_sum)) {
hold_counters_time_domain_[ch] = kBlocksToHoldErle;
erle_time_domain_log2_[ch] +=
0.1f * ((instantaneous_erle_[ch].GetInstErleLog2().value()) -
erle_time_domain_log2_[ch]);
erle_time_domain_log2_[ch] = rtc::SafeClamp(
erle_time_domain_log2_[ch], min_erle_log2_, max_erle_lf_log2);
}
}
}
}
--hold_counter_time_domain_;
if (hold_counter_time_domain_ <= 0) {
erle_time_domain_log2_ =
std::max(min_erle_log2_, erle_time_domain_log2_ - 0.044f);
}
if (hold_counter_time_domain_ == 0) {
instantaneous_erle_.ResetAccumulators();
--hold_counters_time_domain_[ch];
if (hold_counters_time_domain_[ch] <= 0) {
erle_time_domain_log2_[ch] =
std::max(min_erle_log2_, erle_time_domain_log2_[ch] - 0.044f);
}
if (hold_counters_time_domain_[ch] == 0) {
instantaneous_erle_[ch].ResetAccumulators();
}
}
UpdateQualityEstimates();
@ -84,12 +96,14 @@ void FullBandErleEstimator::Update(rtc::ArrayView<const float> X2,
void FullBandErleEstimator::Dump(
const std::unique_ptr<ApmDataDumper>& data_dumper) const {
data_dumper->DumpRaw("aec3_fullband_erle_log2", FullbandErleLog2());
instantaneous_erle_.Dump(data_dumper);
instantaneous_erle_[0].Dump(data_dumper);
}
void FullBandErleEstimator::UpdateQualityEstimates() {
std::fill(linear_filters_qualities_.begin(), linear_filters_qualities_.end(),
instantaneous_erle_.GetQualityEstimate());
for (size_t ch = 0; ch < instantaneous_erle_.size(); ++ch) {
linear_filters_qualities_[ch] =
instantaneous_erle_[ch].GetQualityEstimate();
}
}
FullBandErleEstimator::ErleInstantaneous::ErleInstantaneous(

View File

@ -17,6 +17,7 @@
#include "absl/types/optional.h"
#include "api/array_view.h"
#include "api/audio/echo_canceller3_config.h"
#include "modules/audio_processing/aec3/aec3_common.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"
namespace webrtc {
@ -33,12 +34,18 @@ class FullBandErleEstimator {
// Updates the ERLE estimator.
void Update(rtc::ArrayView<const float> X2,
rtc::ArrayView<const float> Y2,
rtc::ArrayView<const float> E2,
bool converged_filter);
rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> Y2,
rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> E2,
const std::vector<bool>& converged_filters);
// Returns the fullband ERLE estimates in log2 units.
float FullbandErleLog2() const { return erle_time_domain_log2_; }
float FullbandErleLog2() const {
float min_erle = erle_time_domain_log2_[0];
for (size_t ch = 1; ch < erle_time_domain_log2_.size(); ++ch) {
min_erle = std::min(min_erle, erle_time_domain_log2_[ch]);
}
return min_erle;
}
// Returns an estimation of the current linear filter quality. It returns a
// float number between 0 and 1 mapping 1 to the highest possible quality.
@ -98,11 +105,11 @@ class FullBandErleEstimator {
int num_points_;
};
int hold_counter_time_domain_;
float erle_time_domain_log2_;
const float min_erle_log2_;
const float max_erle_lf_log2;
ErleInstantaneous instantaneous_erle_;
std::vector<int> hold_counters_time_domain_;
std::vector<float> erle_time_domain_log2_;
std::vector<ErleInstantaneous> instantaneous_erle_;
std::vector<absl::optional<float>> linear_filters_qualities_;
};

View File

@ -132,29 +132,38 @@ SignalDependentErleEstimator::SignalDependentErleEstimator(
num_blocks_,
num_sections_)),
erle_(num_capture_channels),
S2_section_accum_(num_sections_),
erle_estimators_(num_sections_),
correction_factors_(num_sections_) {
S2_section_accum_(
num_capture_channels,
std::vector<std::array<float, kFftLengthBy2Plus1>>(num_sections_)),
erle_estimators_(
num_capture_channels,
std::vector<std::array<float, kSubbands>>(num_sections_)),
erle_ref_(num_capture_channels),
correction_factors_(
num_capture_channels,
std::vector<std::array<float, kSubbands>>(num_sections_)),
num_updates_(num_capture_channels),
n_active_sections_(num_capture_channels) {
RTC_DCHECK_LE(num_sections_, num_blocks_);
RTC_DCHECK_GE(num_sections_, 1);
Reset();
}
SignalDependentErleEstimator::~SignalDependentErleEstimator() = default;
void SignalDependentErleEstimator::Reset() {
for (auto& erle : erle_) {
erle.fill(min_erle_);
for (size_t ch = 0; ch < erle_.size(); ++ch) {
erle_[ch].fill(min_erle_);
for (auto& erle_estimator : erle_estimators_[ch]) {
erle_estimator.fill(min_erle_);
}
erle_ref_[ch].fill(min_erle_);
for (auto& factor : correction_factors_[ch]) {
factor.fill(1.0f);
}
num_updates_[ch].fill(0);
n_active_sections_[ch].fill(0);
}
for (auto& erle_estimator : erle_estimators_) {
erle_estimator.fill(min_erle_);
}
erle_ref_.fill(min_erle_);
for (auto& factor : correction_factors_) {
factor.fill(1.0f);
}
num_updates_.fill(0);
}
// Updates the Erle estimate by analyzing the current input signals. It takes
@ -165,44 +174,45 @@ void SignalDependentErleEstimator::Reset() {
// correction factor to the erle that is given as an input to this method.
void SignalDependentErleEstimator::Update(
const RenderBuffer& render_buffer,
const std::vector<std::array<float, kFftLengthBy2Plus1>>&
filter_frequency_response,
rtc::ArrayView<const float> X2,
rtc::ArrayView<const float> Y2,
rtc::ArrayView<const float> E2,
rtc::ArrayView<const std::vector<std::array<float, kFftLengthBy2Plus1>>>
filter_frequency_responses,
rtc::ArrayView<const float, kFftLengthBy2Plus1> X2,
rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> Y2,
rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> E2,
rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> average_erle,
bool converged_filter) {
const std::vector<bool>& converged_filters) {
RTC_DCHECK_GT(num_sections_, 1);
// Gets the number of filter sections that are needed for achieving 90 %
// of the power spectrum energy of the echo estimate.
std::array<size_t, kFftLengthBy2Plus1> n_active_sections;
ComputeNumberOfActiveFilterSections(render_buffer, filter_frequency_response,
n_active_sections);
ComputeNumberOfActiveFilterSections(render_buffer,
filter_frequency_responses);
if (converged_filter) {
// Updates the correction factor that is used for correcting the erle and
// adapt it to the particular characteristics of the input signal.
UpdateCorrectionFactors(X2, Y2, E2, n_active_sections);
}
// Updates the correction factors that is used for correcting the erle and
// adapt it to the particular characteristics of the input signal.
UpdateCorrectionFactors(X2, Y2, E2, converged_filters);
// Applies the correction factor to the input erle for getting a more refined
// erle estimation for the current input signal.
for (size_t k = 0; k < kFftLengthBy2; ++k) {
float correction_factor =
correction_factors_[n_active_sections[k]][band_to_subband_[k]];
erle_[0][k] = rtc::SafeClamp(average_erle[0][k] * correction_factor,
min_erle_, max_erle_[band_to_subband_[k]]);
for (size_t ch = 0; ch < erle_.size(); ++ch) {
for (size_t k = 0; k < kFftLengthBy2; ++k) {
RTC_DCHECK_GT(correction_factors_[ch].size(), n_active_sections_[ch][k]);
float correction_factor =
correction_factors_[ch][n_active_sections_[ch][k]]
[band_to_subband_[k]];
erle_[ch][k] = rtc::SafeClamp(average_erle[ch][k] * correction_factor,
min_erle_, max_erle_[band_to_subband_[k]]);
}
}
}
void SignalDependentErleEstimator::Dump(
const std::unique_ptr<ApmDataDumper>& data_dumper) const {
for (auto& erle : erle_estimators_) {
for (auto& erle : erle_estimators_[0]) {
data_dumper->DumpRaw("aec3_all_erle", erle);
}
data_dumper->DumpRaw("aec3_ref_erle", erle_ref_);
for (auto& factor : correction_factors_) {
data_dumper->DumpRaw("aec3_ref_erle", erle_ref_[0]);
for (auto& factor : correction_factors_[0]) {
data_dumper->DumpRaw("aec3_erle_correction_factor", factor);
}
}
@ -211,163 +221,185 @@ void SignalDependentErleEstimator::Dump(
// together constitute 90% of the estimated echo energy.
void SignalDependentErleEstimator::ComputeNumberOfActiveFilterSections(
const RenderBuffer& render_buffer,
const std::vector<std::array<float, kFftLengthBy2Plus1>>&
filter_frequency_response,
rtc::ArrayView<size_t> n_active_filter_sections) {
rtc::ArrayView<const std::vector<std::array<float, kFftLengthBy2Plus1>>>
filter_frequency_responses) {
RTC_DCHECK_GT(num_sections_, 1);
// Computes an approximation of the power spectrum if the filter would have
// been limited to a certain number of filter sections.
ComputeEchoEstimatePerFilterSection(render_buffer, filter_frequency_response);
ComputeEchoEstimatePerFilterSection(render_buffer,
filter_frequency_responses);
// For each band, computes the number of filter sections that are needed for
// achieving the 90 % energy in the echo estimate.
ComputeActiveFilterSections(n_active_filter_sections);
ComputeActiveFilterSections();
}
void SignalDependentErleEstimator::UpdateCorrectionFactors(
rtc::ArrayView<const float> X2,
rtc::ArrayView<const float> Y2,
rtc::ArrayView<const float> E2,
rtc::ArrayView<const size_t> n_active_sections) {
constexpr float kX2BandEnergyThreshold = 44015068.0f;
constexpr float kSmthConstantDecreases = 0.1f;
constexpr float kSmthConstantIncreases = kSmthConstantDecreases / 2.f;
auto subband_powers = [](rtc::ArrayView<const float> power_spectrum,
rtc::ArrayView<float> power_spectrum_subbands) {
for (size_t subband = 0; subband < kSubbands; ++subband) {
RTC_DCHECK_LE(kBandBoundaries[subband + 1], power_spectrum.size());
power_spectrum_subbands[subband] = std::accumulate(
power_spectrum.begin() + kBandBoundaries[subband],
power_spectrum.begin() + kBandBoundaries[subband + 1], 0.f);
}
};
rtc::ArrayView<const float, kFftLengthBy2Plus1> X2,
rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> Y2,
rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> E2,
const std::vector<bool>& converged_filters) {
for (size_t ch = 0; ch < converged_filters.size(); ++ch) {
if (converged_filters[ch]) {
constexpr float kX2BandEnergyThreshold = 44015068.0f;
constexpr float kSmthConstantDecreases = 0.1f;
constexpr float kSmthConstantIncreases = kSmthConstantDecreases / 2.f;
auto subband_powers = [](rtc::ArrayView<const float> power_spectrum,
rtc::ArrayView<float> power_spectrum_subbands) {
for (size_t subband = 0; subband < kSubbands; ++subband) {
RTC_DCHECK_LE(kBandBoundaries[subband + 1], power_spectrum.size());
power_spectrum_subbands[subband] = std::accumulate(
power_spectrum.begin() + kBandBoundaries[subband],
power_spectrum.begin() + kBandBoundaries[subband + 1], 0.f);
}
};
std::array<float, kSubbands> X2_subbands, E2_subbands, Y2_subbands;
subband_powers(X2, X2_subbands);
subband_powers(E2, E2_subbands);
subband_powers(Y2, Y2_subbands);
std::array<size_t, kSubbands> idx_subbands;
for (size_t subband = 0; subband < kSubbands; ++subband) {
// When aggregating the number of active sections in the filter for
// different bands we choose to take the minimum of all of them. As an
// example, if for one of the bands it is the direct path its main
// contributor to the final echo estimate, we consider the direct path is
// as well the main contributor for the subband that contains that
// particular band. That aggregate number of sections will be later used as
// the identifier of the erle estimator that needs to be updated.
RTC_DCHECK_LE(kBandBoundaries[subband + 1], n_active_sections.size());
idx_subbands[subband] = *std::min_element(
n_active_sections.begin() + kBandBoundaries[subband],
n_active_sections.begin() + kBandBoundaries[subband + 1]);
}
std::array<float, kSubbands> X2_subbands, E2_subbands, Y2_subbands;
subband_powers(X2, X2_subbands);
subband_powers(E2[ch], E2_subbands);
subband_powers(Y2[ch], Y2_subbands);
std::array<size_t, kSubbands> idx_subbands;
for (size_t subband = 0; subband < kSubbands; ++subband) {
// When aggregating the number of active sections in the filter for
// different bands we choose to take the minimum of all of them. As an
// example, if for one of the bands it is the direct path its main
// contributor to the final echo estimate, we consider the direct path
// is as well the main contributor for the subband that contains that
// particular band. That aggregate number of sections will be later used
// as the identifier of the erle estimator that needs to be updated.
RTC_DCHECK_LE(kBandBoundaries[subband + 1],
n_active_sections_[ch].size());
idx_subbands[subband] = *std::min_element(
n_active_sections_[ch].begin() + kBandBoundaries[subband],
n_active_sections_[ch].begin() + kBandBoundaries[subband + 1]);
}
std::array<float, kSubbands> new_erle;
std::array<bool, kSubbands> is_erle_updated;
is_erle_updated.fill(false);
new_erle.fill(0.f);
for (size_t subband = 0; subband < kSubbands; ++subband) {
if (X2_subbands[subband] > kX2BandEnergyThreshold &&
E2_subbands[subband] > 0) {
new_erle[subband] = Y2_subbands[subband] / E2_subbands[subband];
RTC_DCHECK_GT(new_erle[subband], 0);
is_erle_updated[subband] = true;
++num_updates_[subband];
}
}
std::array<float, kSubbands> new_erle;
std::array<bool, kSubbands> is_erle_updated;
is_erle_updated.fill(false);
new_erle.fill(0.f);
for (size_t subband = 0; subband < kSubbands; ++subband) {
if (X2_subbands[subband] > kX2BandEnergyThreshold &&
E2_subbands[subband] > 0) {
new_erle[subband] = Y2_subbands[subband] / E2_subbands[subband];
RTC_DCHECK_GT(new_erle[subband], 0);
is_erle_updated[subband] = true;
++num_updates_[ch][subband];
}
}
for (size_t subband = 0; subband < kSubbands; ++subband) {
const size_t idx = idx_subbands[subband];
RTC_DCHECK_LT(idx, erle_estimators_.size());
float alpha = new_erle[subband] > erle_estimators_[idx][subband]
? kSmthConstantIncreases
: kSmthConstantDecreases;
alpha = static_cast<float>(is_erle_updated[subband]) * alpha;
erle_estimators_[idx][subband] +=
alpha * (new_erle[subband] - erle_estimators_[idx][subband]);
erle_estimators_[idx][subband] = rtc::SafeClamp(
erle_estimators_[idx][subband], min_erle_, max_erle_[subband]);
}
for (size_t subband = 0; subband < kSubbands; ++subband) {
const size_t idx = idx_subbands[subband];
RTC_DCHECK_LT(idx, erle_estimators_[ch].size());
float alpha = new_erle[subband] > erle_estimators_[ch][idx][subband]
? kSmthConstantIncreases
: kSmthConstantDecreases;
alpha = static_cast<float>(is_erle_updated[subband]) * alpha;
erle_estimators_[ch][idx][subband] +=
alpha * (new_erle[subband] - erle_estimators_[ch][idx][subband]);
erle_estimators_[ch][idx][subband] = rtc::SafeClamp(
erle_estimators_[ch][idx][subband], min_erle_, max_erle_[subband]);
}
for (size_t subband = 0; subband < kSubbands; ++subband) {
float alpha = new_erle[subband] > erle_ref_[subband]
? kSmthConstantIncreases
: kSmthConstantDecreases;
alpha = static_cast<float>(is_erle_updated[subband]) * alpha;
erle_ref_[subband] += alpha * (new_erle[subband] - erle_ref_[subband]);
erle_ref_[subband] =
rtc::SafeClamp(erle_ref_[subband], min_erle_, max_erle_[subband]);
}
for (size_t subband = 0; subband < kSubbands; ++subband) {
float alpha = new_erle[subband] > erle_ref_[ch][subband]
? kSmthConstantIncreases
: kSmthConstantDecreases;
alpha = static_cast<float>(is_erle_updated[subband]) * alpha;
erle_ref_[ch][subband] +=
alpha * (new_erle[subband] - erle_ref_[ch][subband]);
erle_ref_[ch][subband] = rtc::SafeClamp(erle_ref_[ch][subband],
min_erle_, max_erle_[subband]);
}
for (size_t subband = 0; subband < kSubbands; ++subband) {
constexpr int kNumUpdateThr = 50;
if (is_erle_updated[subband] && num_updates_[subband] > kNumUpdateThr) {
const size_t idx = idx_subbands[subband];
RTC_DCHECK_GT(erle_ref_[subband], 0.f);
// Computes the ratio between the erle that is updated using all the
// points and the erle that is updated only on signals that share the
// same number of active filter sections.
float new_correction_factor =
erle_estimators_[idx][subband] / erle_ref_[subband];
for (size_t subband = 0; subband < kSubbands; ++subband) {
constexpr int kNumUpdateThr = 50;
if (is_erle_updated[subband] &&
num_updates_[ch][subband] > kNumUpdateThr) {
const size_t idx = idx_subbands[subband];
RTC_DCHECK_GT(erle_ref_[ch][subband], 0.f);
// Computes the ratio between the erle that is updated using all the
// points and the erle that is updated only on signals that share the
// same number of active filter sections.
float new_correction_factor =
erle_estimators_[ch][idx][subband] / erle_ref_[ch][subband];
correction_factors_[idx][subband] +=
0.1f * (new_correction_factor - correction_factors_[idx][subband]);
correction_factors_[ch][idx][subband] +=
0.1f *
(new_correction_factor - correction_factors_[ch][idx][subband]);
}
}
}
}
}
void SignalDependentErleEstimator::ComputeEchoEstimatePerFilterSection(
const RenderBuffer& render_buffer,
const std::vector<std::array<float, kFftLengthBy2Plus1>>&
filter_frequency_response) {
rtc::ArrayView<const std::vector<std::array<float, kFftLengthBy2Plus1>>>
filter_frequency_responses) {
const SpectrumBuffer& spectrum_render_buffer =
render_buffer.GetSpectrumBuffer();
const size_t num_render_channels = spectrum_render_buffer.buffer[0].size();
const size_t num_capture_channels = S2_section_accum_.size();
const float one_by_num_render_channels = 1.f / num_render_channels;
RTC_DCHECK_EQ(S2_section_accum_.size() + 1,
section_boundaries_blocks_.size());
size_t idx_render = render_buffer.Position();
idx_render = spectrum_render_buffer.OffsetIndex(
idx_render, section_boundaries_blocks_[0]);
RTC_DCHECK_EQ(S2_section_accum_.size(), filter_frequency_responses.size());
for (size_t section = 0; section < num_sections_; ++section) {
std::array<float, kFftLengthBy2Plus1> X2_section;
std::array<float, kFftLengthBy2Plus1> H2_section;
X2_section.fill(0.f);
H2_section.fill(0.f);
const size_t block_limit = std::min(section_boundaries_blocks_[section + 1],
filter_frequency_response.size());
for (size_t block = section_boundaries_blocks_[section];
block < block_limit; ++block) {
std::transform(
X2_section.begin(), X2_section.end(),
spectrum_render_buffer.buffer[idx_render][/*channel=*/0].begin(),
X2_section.begin(), std::plus<float>());
std::transform(H2_section.begin(), H2_section.end(),
filter_frequency_response[block].begin(),
H2_section.begin(), std::plus<float>());
idx_render = spectrum_render_buffer.IncIndex(idx_render);
for (size_t capture_ch = 0; capture_ch < num_capture_channels; ++capture_ch) {
RTC_DCHECK_EQ(S2_section_accum_[capture_ch].size() + 1,
section_boundaries_blocks_.size());
size_t idx_render = render_buffer.Position();
idx_render = spectrum_render_buffer.OffsetIndex(
idx_render, section_boundaries_blocks_[0]);
for (size_t section = 0; section < num_sections_; ++section) {
std::array<float, kFftLengthBy2Plus1> X2_section;
std::array<float, kFftLengthBy2Plus1> H2_section;
X2_section.fill(0.f);
H2_section.fill(0.f);
const size_t block_limit =
std::min(section_boundaries_blocks_[section + 1],
filter_frequency_responses[capture_ch].size());
for (size_t block = section_boundaries_blocks_[section];
block < block_limit; ++block) {
for (size_t render_ch = 0;
render_ch < spectrum_render_buffer.buffer[idx_render].size();
++render_ch) {
for (size_t k = 0; k < X2_section.size(); ++k) {
X2_section[k] +=
spectrum_render_buffer.buffer[idx_render][render_ch][k] *
one_by_num_render_channels;
}
}
std::transform(H2_section.begin(), H2_section.end(),
filter_frequency_responses[capture_ch][block].begin(),
H2_section.begin(), std::plus<float>());
idx_render = spectrum_render_buffer.IncIndex(idx_render);
}
std::transform(X2_section.begin(), X2_section.end(), H2_section.begin(),
S2_section_accum_[capture_ch][section].begin(),
std::multiplies<float>());
}
std::transform(X2_section.begin(), X2_section.end(), H2_section.begin(),
S2_section_accum_[section].begin(),
std::multiplies<float>());
}
for (size_t section = 1; section < num_sections_; ++section) {
std::transform(S2_section_accum_[section - 1].begin(),
S2_section_accum_[section - 1].end(),
S2_section_accum_[section].begin(),
S2_section_accum_[section].begin(), std::plus<float>());
for (size_t section = 1; section < num_sections_; ++section) {
std::transform(S2_section_accum_[capture_ch][section - 1].begin(),
S2_section_accum_[capture_ch][section - 1].end(),
S2_section_accum_[capture_ch][section].begin(),
S2_section_accum_[capture_ch][section].begin(),
std::plus<float>());
}
}
}
void SignalDependentErleEstimator::ComputeActiveFilterSections(
rtc::ArrayView<size_t> number_active_filter_sections) const {
std::fill(number_active_filter_sections.begin(),
number_active_filter_sections.end(), 0);
for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) {
size_t section = num_sections_;
float target = 0.9f * S2_section_accum_[num_sections_ - 1][k];
while (section > 0 && S2_section_accum_[section - 1][k] >= target) {
number_active_filter_sections[k] = --section;
void SignalDependentErleEstimator::ComputeActiveFilterSections() {
for (size_t ch = 0; ch < n_active_sections_.size(); ++ch) {
std::fill(n_active_sections_[ch].begin(), n_active_sections_[ch].end(), 0);
for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) {
size_t section = num_sections_;
float target = 0.9f * S2_section_accum_[ch][num_sections_ - 1][k];
while (section > 0 && S2_section_accum_[ch][section - 1][k] >= target) {
n_active_sections_[ch][k] = --section;
}
}
}
}

View File

@ -45,13 +45,13 @@ class SignalDependentErleEstimator {
// to be an estimation of the average Erle achieved by the linear filter.
void Update(
const RenderBuffer& render_buffer,
const std::vector<std::array<float, kFftLengthBy2Plus1>>&
rtc::ArrayView<const std::vector<std::array<float, kFftLengthBy2Plus1>>>
filter_frequency_response,
rtc::ArrayView<const float> X2,
rtc::ArrayView<const float> Y2,
rtc::ArrayView<const float> E2,
rtc::ArrayView<const float, kFftLengthBy2Plus1> X2,
rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> Y2,
rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> E2,
rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> average_erle,
bool converged_filter);
const std::vector<bool>& converged_filters);
void Dump(const std::unique_ptr<ApmDataDumper>& data_dumper) const;
@ -60,22 +60,21 @@ class SignalDependentErleEstimator {
private:
void ComputeNumberOfActiveFilterSections(
const RenderBuffer& render_buffer,
const std::vector<std::array<float, kFftLengthBy2Plus1>>&
filter_frequency_response,
rtc::ArrayView<size_t> n_active_filter_sections);
rtc::ArrayView<const std::vector<std::array<float, kFftLengthBy2Plus1>>>
filter_frequency_responses);
void UpdateCorrectionFactors(rtc::ArrayView<const float> X2,
rtc::ArrayView<const float> Y2,
rtc::ArrayView<const float> E2,
rtc::ArrayView<const size_t> n_active_sections);
void UpdateCorrectionFactors(
rtc::ArrayView<const float, kFftLengthBy2Plus1> X2,
rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> Y2,
rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> E2,
const std::vector<bool>& converged_filters);
void ComputeEchoEstimatePerFilterSection(
const RenderBuffer& render_buffer,
const std::vector<std::array<float, kFftLengthBy2Plus1>>&
filter_frequency_response);
rtc::ArrayView<const std::vector<std::array<float, kFftLengthBy2Plus1>>>
filter_frequency_responses);
void ComputeActiveFilterSections(
rtc::ArrayView<size_t> number_active_filter_sections) const;
void ComputeActiveFilterSections();
const float min_erle_;
const size_t num_sections_;
@ -85,11 +84,13 @@ class SignalDependentErleEstimator {
const std::array<float, kSubbands> max_erle_;
const std::vector<size_t> section_boundaries_blocks_;
std::vector<std::array<float, kFftLengthBy2Plus1>> erle_;
std::vector<std::array<float, kFftLengthBy2Plus1>> S2_section_accum_;
std::vector<std::array<float, kSubbands>> erle_estimators_;
std::array<float, kSubbands> erle_ref_;
std::vector<std::array<float, kSubbands>> correction_factors_;
std::array<int, kSubbands> num_updates_;
std::vector<std::vector<std::array<float, kFftLengthBy2Plus1>>>
S2_section_accum_;
std::vector<std::vector<std::array<float, kSubbands>>> erle_estimators_;
std::vector<std::array<float, kSubbands>> erle_ref_;
std::vector<std::vector<std::array<float, kSubbands>>> correction_factors_;
std::vector<std::array<int, kSubbands>> num_updates_;
std::vector<std::array<size_t, kFftLengthBy2Plus1>> n_active_sections_;
};
} // namespace webrtc

View File

@ -44,13 +44,25 @@ void GetActiveFrame(std::vector<std::vector<std::vector<float>>>* x) {
class TestInputs {
public:
explicit TestInputs(const EchoCanceller3Config& cfg);
TestInputs(const EchoCanceller3Config& cfg,
size_t num_render_channels,
size_t num_capture_channels);
~TestInputs();
const RenderBuffer& GetRenderBuffer() { return *render_buffer_; }
rtc::ArrayView<const float> GetX2() { return X2_; }
rtc::ArrayView<const float> GetY2() { return Y2_; }
rtc::ArrayView<const float> GetE2() { return E2_; }
std::vector<std::array<float, kFftLengthBy2Plus1>> GetH2() { return H2_; }
rtc::ArrayView<const float, kFftLengthBy2Plus1> GetX2() { return X2_; }
rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> GetY2() const {
return Y2_;
}
rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> GetE2() const {
return E2_;
}
rtc::ArrayView<const std::vector<std::array<float, kFftLengthBy2Plus1>>>
GetH2() const {
return H2_;
}
const std::vector<bool>& GetConvergedFilters() const {
return converged_filters_;
}
void Update();
private:
@ -59,24 +71,37 @@ class TestInputs {
std::unique_ptr<RenderDelayBuffer> render_delay_buffer_;
RenderBuffer* render_buffer_;
std::array<float, kFftLengthBy2Plus1> X2_;
std::array<float, kFftLengthBy2Plus1> Y2_;
std::array<float, kFftLengthBy2Plus1> E2_;
std::vector<std::array<float, kFftLengthBy2Plus1>> H2_;
std::vector<std::array<float, kFftLengthBy2Plus1>> Y2_;
std::vector<std::array<float, kFftLengthBy2Plus1>> E2_;
std::vector<std::vector<std::array<float, kFftLengthBy2Plus1>>> H2_;
std::vector<std::vector<std::vector<float>>> x_;
std::vector<bool> converged_filters_;
};
TestInputs::TestInputs(const EchoCanceller3Config& cfg)
: render_delay_buffer_(RenderDelayBuffer::Create(cfg, 16000, 1)),
H2_(cfg.filter.main.length_blocks),
TestInputs::TestInputs(const EchoCanceller3Config& cfg,
size_t num_render_channels,
size_t num_capture_channels)
: render_delay_buffer_(
RenderDelayBuffer::Create(cfg, 16000, num_render_channels)),
Y2_(num_capture_channels),
E2_(num_capture_channels),
H2_(num_capture_channels,
std::vector<std::array<float, kFftLengthBy2Plus1>>(
cfg.filter.main.length_blocks)),
x_(1,
std::vector<std::vector<float>>(1,
std::vector<float>(kBlockSize, 0.f))) {
std::vector<std::vector<float>>(num_render_channels,
std::vector<float>(kBlockSize, 0.f))),
converged_filters_(num_capture_channels, true) {
render_delay_buffer_->AlignFromDelay(4);
render_buffer_ = render_delay_buffer_->GetRenderBuffer();
for (auto& H : H2_) {
H.fill(0.f);
for (auto& H2_ch : H2_) {
for (auto& H2_p : H2_ch) {
H2_p.fill(0.f);
}
}
for (auto& H2_p : H2_[0]) {
H2_p.fill(1.f);
}
H2_[0].fill(1.0f);
}
TestInputs::~TestInputs() = default;
@ -102,40 +127,47 @@ void TestInputs::UpdateCurrentPowerSpectra() {
auto& X2 = spectrum_render_buffer.buffer[idx][/*channel=*/0];
auto& X2_prev = spectrum_render_buffer.buffer[prev_idx][/*channel=*/0];
std::copy(X2.begin(), X2.end(), X2_.begin());
RTC_DCHECK_EQ(X2.size(), Y2_.size());
for (size_t k = 0; k < X2.size(); ++k) {
E2_[k] = 0.01f * X2_prev[k];
Y2_[k] = X2[k] + E2_[k];
for (size_t ch = 0; ch < Y2_.size(); ++ch) {
RTC_DCHECK_EQ(X2.size(), Y2_[ch].size());
for (size_t k = 0; k < X2.size(); ++k) {
E2_[ch][k] = 0.01f * X2_prev[k];
Y2_[ch][k] = X2[k] + E2_[ch][k];
}
}
}
} // namespace
TEST(SignalDependentErleEstimator, SweepSettings) {
const size_t kNumCaptureChannels = 1;
EchoCanceller3Config cfg;
size_t max_length_blocks = 50;
for (size_t blocks = 0; blocks < max_length_blocks; blocks = blocks + 10) {
for (size_t delay_headroom = 0; delay_headroom < 5; ++delay_headroom) {
for (size_t num_sections = 2; num_sections < max_length_blocks;
++num_sections) {
cfg.filter.main.length_blocks = blocks;
cfg.filter.main_initial.length_blocks =
std::min(cfg.filter.main_initial.length_blocks, blocks);
cfg.delay.delay_headroom_samples = delay_headroom * kBlockSize;
cfg.erle.num_sections = num_sections;
if (EchoCanceller3Config::Validate(&cfg)) {
SignalDependentErleEstimator s(cfg, kNumCaptureChannels);
std::array<std::array<float, kFftLengthBy2Plus1>, kNumCaptureChannels>
average_erle;
for (auto& e : average_erle) {
e.fill(cfg.erle.max_l);
}
TestInputs inputs(cfg);
for (size_t n = 0; n < 10; ++n) {
inputs.Update();
s.Update(inputs.GetRenderBuffer(), inputs.GetH2(), inputs.GetX2(),
inputs.GetY2(), inputs.GetE2(), average_erle, true);
for (size_t num_render_channels : {1, 2, 4}) {
for (size_t num_capture_channels : {1, 2, 4}) {
EchoCanceller3Config cfg;
size_t max_length_blocks = 50;
for (size_t blocks = 0; blocks < max_length_blocks;
blocks = blocks + 10) {
for (size_t delay_headroom = 0; delay_headroom < 5; ++delay_headroom) {
for (size_t num_sections = 2; num_sections < max_length_blocks;
++num_sections) {
cfg.filter.main.length_blocks = blocks;
cfg.filter.main_initial.length_blocks =
std::min(cfg.filter.main_initial.length_blocks, blocks);
cfg.delay.delay_headroom_samples = delay_headroom * kBlockSize;
cfg.erle.num_sections = num_sections;
if (EchoCanceller3Config::Validate(&cfg)) {
SignalDependentErleEstimator s(cfg, num_capture_channels);
std::vector<std::array<float, kFftLengthBy2Plus1>> average_erle(
num_capture_channels);
for (auto& e : average_erle) {
e.fill(cfg.erle.max_l);
}
TestInputs inputs(cfg, num_render_channels, num_capture_channels);
for (size_t n = 0; n < 10; ++n) {
inputs.Update();
s.Update(inputs.GetRenderBuffer(), inputs.GetH2(),
inputs.GetX2(), inputs.GetY2(), inputs.GetE2(),
average_erle, inputs.GetConvergedFilters());
}
}
}
}
}
@ -144,25 +176,29 @@ TEST(SignalDependentErleEstimator, SweepSettings) {
}
TEST(SignalDependentErleEstimator, LongerRun) {
const size_t kNumCaptureChannels = 1;
EchoCanceller3Config cfg;
cfg.filter.main.length_blocks = 2;
cfg.filter.main_initial.length_blocks = 1;
cfg.delay.delay_headroom_samples = 0;
cfg.delay.hysteresis_limit_blocks = 0;
cfg.erle.num_sections = 2;
EXPECT_EQ(EchoCanceller3Config::Validate(&cfg), true);
std::array<std::array<float, kFftLengthBy2Plus1>, kNumCaptureChannels>
average_erle;
for (auto& e : average_erle) {
e.fill(cfg.erle.max_l);
}
SignalDependentErleEstimator s(cfg, kNumCaptureChannels);
TestInputs inputs(cfg);
for (size_t n = 0; n < 200; ++n) {
inputs.Update();
s.Update(inputs.GetRenderBuffer(), inputs.GetH2(), inputs.GetX2(),
inputs.GetY2(), inputs.GetE2(), average_erle, true);
for (size_t num_render_channels : {1, 2, 4}) {
for (size_t num_capture_channels : {1, 2, 4}) {
EchoCanceller3Config cfg;
cfg.filter.main.length_blocks = 2;
cfg.filter.main_initial.length_blocks = 1;
cfg.delay.delay_headroom_samples = 0;
cfg.delay.hysteresis_limit_blocks = 0;
cfg.erle.num_sections = 2;
EXPECT_EQ(EchoCanceller3Config::Validate(&cfg), true);
std::vector<std::array<float, kFftLengthBy2Plus1>> average_erle(
num_capture_channels);
for (auto& e : average_erle) {
e.fill(cfg.erle.max_l);
}
SignalDependentErleEstimator s(cfg, num_capture_channels);
TestInputs inputs(cfg, num_render_channels, num_capture_channels);
for (size_t n = 0; n < 200; ++n) {
inputs.Update();
s.Update(inputs.GetRenderBuffer(), inputs.GetH2(), inputs.GetX2(),
inputs.GetY2(), inputs.GetE2(), average_erle,
inputs.GetConvergedFilters());
}
}
}
}

View File

@ -42,10 +42,15 @@ bool EnableMinErleDuringOnsets() {
SubbandErleEstimator::SubbandErleEstimator(const EchoCanceller3Config& config,
size_t num_capture_channels)
: min_erle_(config.erle.min),
: use_onset_detection_(config.erle.onset_detection),
min_erle_(config.erle.min),
max_erle_(SetMaxErleBands(config.erle.max_l, config.erle.max_h)),
use_min_erle_during_onsets_(EnableMinErleDuringOnsets()),
erle_(num_capture_channels) {
accum_spectra_(num_capture_channels),
erle_(num_capture_channels),
erle_onsets_(num_capture_channels),
coming_onset_(num_capture_channels),
hold_counters_(num_capture_channels) {
Reset();
}
@ -55,26 +60,23 @@ void SubbandErleEstimator::Reset() {
for (auto& erle : erle_) {
erle.fill(min_erle_);
}
erle_onsets_.fill(min_erle_);
coming_onset_.fill(true);
hold_counters_.fill(0);
for (size_t ch = 0; ch < erle_onsets_.size(); ++ch) {
erle_onsets_[ch].fill(min_erle_);
coming_onset_[ch].fill(true);
hold_counters_[ch].fill(0);
}
ResetAccumulatedSpectra();
}
void SubbandErleEstimator::Update(rtc::ArrayView<const float> X2,
rtc::ArrayView<const float> Y2,
rtc::ArrayView<const float> E2,
bool converged_filter,
bool onset_detection) {
if (converged_filter) {
// Note that the use of the converged_filter flag already imposed
// a minimum of the erle that can be estimated as that flag would
// be false if the filter is performing poorly.
UpdateAccumulatedSpectra(X2, Y2, E2);
UpdateBands(onset_detection);
}
void SubbandErleEstimator::Update(
rtc::ArrayView<const float, kFftLengthBy2Plus1> X2,
rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> Y2,
rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> E2,
const std::vector<bool>& converged_filters) {
UpdateAccumulatedSpectra(X2, Y2, E2, converged_filters);
UpdateBands(converged_filters);
if (onset_detection) {
if (use_onset_detection_) {
DecreaseErlePerBandForLowRenderSignals();
}
@ -86,97 +88,129 @@ void SubbandErleEstimator::Update(rtc::ArrayView<const float> X2,
void SubbandErleEstimator::Dump(
const std::unique_ptr<ApmDataDumper>& data_dumper) const {
data_dumper->DumpRaw("aec3_erle_onset", ErleOnsets());
data_dumper->DumpRaw("aec3_erle_onset", ErleOnsets()[0]);
}
void SubbandErleEstimator::UpdateBands(bool onset_detection) {
std::array<float, kFftLengthBy2> new_erle;
std::array<bool, kFftLengthBy2> is_erle_updated;
is_erle_updated.fill(false);
for (size_t k = 1; k < kFftLengthBy2; ++k) {
if (accum_spectra_.num_points_[k] == kPointsToAccumulate &&
accum_spectra_.E2_[k] > 0.f) {
new_erle[k] = accum_spectra_.Y2_[k] / accum_spectra_.E2_[k];
is_erle_updated[k] = true;
void SubbandErleEstimator::UpdateBands(
const std::vector<bool>& converged_filters) {
const int num_capture_channels = static_cast<int>(accum_spectra_.Y2.size());
for (int ch = 0; ch < num_capture_channels; ++ch) {
// Note that the use of the converged_filter flag already imposed
// a minimum of the erle that can be estimated as that flag would
// be false if the filter is performing poorly.
if (!converged_filters[ch]) {
continue;
}
}
if (onset_detection) {
std::array<float, kFftLengthBy2> new_erle;
std::array<bool, kFftLengthBy2> is_erle_updated;
is_erle_updated.fill(false);
for (size_t k = 1; k < kFftLengthBy2; ++k) {
if (is_erle_updated[k] && !accum_spectra_.low_render_energy_[k]) {
if (coming_onset_[k]) {
coming_onset_[k] = false;
if (!use_min_erle_during_onsets_) {
float alpha = new_erle[k] < erle_onsets_[k] ? 0.3f : 0.15f;
erle_onsets_[k] = rtc::SafeClamp(
erle_onsets_[k] + alpha * (new_erle[k] - erle_onsets_[k]),
min_erle_, max_erle_[k]);
}
}
hold_counters_[k] = kBlocksForOnsetDetection;
if (accum_spectra_.num_points[ch] == kPointsToAccumulate &&
accum_spectra_.E2[ch][k] > 0.f) {
new_erle[k] = accum_spectra_.Y2[ch][k] / accum_spectra_.E2[ch][k];
is_erle_updated[k] = true;
}
}
}
for (size_t k = 1; k < kFftLengthBy2; ++k) {
if (is_erle_updated[k]) {
float alpha = 0.05f;
if (new_erle[k] < erle_[0][k]) {
alpha = accum_spectra_.low_render_energy_[k] ? 0.f : 0.1f;
if (use_onset_detection_) {
for (size_t k = 1; k < kFftLengthBy2; ++k) {
if (is_erle_updated[k] && !accum_spectra_.low_render_energy[ch][k]) {
if (coming_onset_[ch][k]) {
coming_onset_[ch][k] = false;
if (!use_min_erle_during_onsets_) {
float alpha = new_erle[k] < erle_onsets_[ch][k] ? 0.3f : 0.15f;
erle_onsets_[ch][k] = rtc::SafeClamp(
erle_onsets_[ch][k] +
alpha * (new_erle[k] - erle_onsets_[ch][k]),
min_erle_, max_erle_[k]);
}
}
hold_counters_[ch][k] = kBlocksForOnsetDetection;
}
}
}
for (size_t k = 1; k < kFftLengthBy2; ++k) {
if (is_erle_updated[k]) {
float alpha = 0.05f;
if (new_erle[k] < erle_[ch][k]) {
alpha = accum_spectra_.low_render_energy[ch][k] ? 0.f : 0.1f;
}
erle_[ch][k] =
rtc::SafeClamp(erle_[ch][k] + alpha * (new_erle[k] - erle_[ch][k]),
min_erle_, max_erle_[k]);
}
erle_[0][k] =
rtc::SafeClamp(erle_[0][k] + alpha * (new_erle[k] - erle_[0][k]),
min_erle_, max_erle_[k]);
}
}
}
void SubbandErleEstimator::DecreaseErlePerBandForLowRenderSignals() {
for (size_t k = 1; k < kFftLengthBy2; ++k) {
hold_counters_[k]--;
if (hold_counters_[k] <= (kBlocksForOnsetDetection - kBlocksToHoldErle)) {
if (erle_[0][k] > erle_onsets_[k]) {
erle_[0][k] = std::max(erle_onsets_[k], 0.97f * erle_[0][k]);
RTC_DCHECK_LE(min_erle_, erle_[0][k]);
}
if (hold_counters_[k] <= 0) {
coming_onset_[k] = true;
hold_counters_[k] = 0;
const int num_capture_channels = static_cast<int>(accum_spectra_.Y2.size());
for (int ch = 0; ch < num_capture_channels; ++ch) {
for (size_t k = 1; k < kFftLengthBy2; ++k) {
--hold_counters_[ch][k];
if (hold_counters_[ch][k] <=
(kBlocksForOnsetDetection - kBlocksToHoldErle)) {
if (erle_[ch][k] > erle_onsets_[ch][k]) {
erle_[ch][k] = std::max(erle_onsets_[ch][k], 0.97f * erle_[ch][k]);
RTC_DCHECK_LE(min_erle_, erle_[ch][k]);
}
if (hold_counters_[ch][k] <= 0) {
coming_onset_[ch][k] = true;
hold_counters_[ch][k] = 0;
}
}
}
}
}
void SubbandErleEstimator::ResetAccumulatedSpectra() {
accum_spectra_.Y2_.fill(0.f);
accum_spectra_.E2_.fill(0.f);
accum_spectra_.num_points_.fill(0);
accum_spectra_.low_render_energy_.fill(false);
for (size_t ch = 0; ch < erle_onsets_.size(); ++ch) {
accum_spectra_.Y2[ch].fill(0.f);
accum_spectra_.E2[ch].fill(0.f);
accum_spectra_.num_points[ch] = 0;
accum_spectra_.low_render_energy[ch].fill(false);
}
}
void SubbandErleEstimator::UpdateAccumulatedSpectra(
rtc::ArrayView<const float> X2,
rtc::ArrayView<const float> Y2,
rtc::ArrayView<const float> E2) {
rtc::ArrayView<const float, kFftLengthBy2Plus1> X2,
rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> Y2,
rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> E2,
const std::vector<bool>& converged_filters) {
auto& st = accum_spectra_;
if (st.num_points_[0] == kPointsToAccumulate) {
st.num_points_[0] = 0;
st.Y2_.fill(0.f);
st.E2_.fill(0.f);
st.low_render_energy_.fill(false);
}
std::transform(Y2.begin(), Y2.end(), st.Y2_.begin(), st.Y2_.begin(),
std::plus<float>());
std::transform(E2.begin(), E2.end(), st.E2_.begin(), st.E2_.begin(),
std::plus<float>());
RTC_DCHECK_EQ(st.E2.size(), E2.size());
RTC_DCHECK_EQ(st.E2.size(), E2.size());
const int num_capture_channels = static_cast<int>(Y2.size());
for (int ch = 0; ch < num_capture_channels; ++ch) {
// Note that the use of the converged_filter flag already imposed
// a minimum of the erle that can be estimated as that flag would
// be false if the filter is performing poorly.
if (!converged_filters[ch]) {
continue;
}
for (size_t k = 0; k < X2.size(); ++k) {
st.low_render_energy_[k] =
st.low_render_energy_[k] || X2[k] < kX2BandEnergyThreshold;
if (st.num_points[ch] == kPointsToAccumulate) {
st.num_points[ch] = 0;
st.Y2[ch].fill(0.f);
st.E2[ch].fill(0.f);
st.low_render_energy[ch].fill(false);
}
std::transform(Y2[ch].begin(), Y2[ch].end(), st.Y2[ch].begin(),
st.Y2[ch].begin(), std::plus<float>());
std::transform(E2[ch].begin(), E2[ch].end(), st.E2[ch].begin(),
st.E2[ch].begin(), std::plus<float>());
for (size_t k = 0; k < X2.size(); ++k) {
st.low_render_energy[ch][k] =
st.low_render_energy[ch][k] || X2[k] < kX2BandEnergyThreshold;
}
++st.num_points[ch];
}
st.num_points_[0]++;
st.num_points_.fill(st.num_points_[0]);
}
} // namespace webrtc

View File

@ -35,47 +35,57 @@ class SubbandErleEstimator {
void Reset();
// Updates the ERLE estimate.
void Update(rtc::ArrayView<const float> X2,
rtc::ArrayView<const float> Y2,
rtc::ArrayView<const float> E2,
bool converged_filter,
bool onset_detection);
void Update(rtc::ArrayView<const float, kFftLengthBy2Plus1> X2,
rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> Y2,
rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> E2,
const std::vector<bool>& converged_filters);
// Returns the ERLE estimate.
rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> Erle() const {
return erle_;
}
// Returns the ERLE estimate at onsets.
rtc::ArrayView<const float> ErleOnsets() const { return erle_onsets_; }
// Returns the ERLE estimate at onsets (only used for testing).
rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> ErleOnsets()
const {
return erle_onsets_;
}
void Dump(const std::unique_ptr<ApmDataDumper>& data_dumper) const;
private:
struct AccumulatedSpectra {
std::array<float, kFftLengthBy2Plus1> Y2_;
std::array<float, kFftLengthBy2Plus1> E2_;
std::array<bool, kFftLengthBy2Plus1> low_render_energy_;
std::array<int, kFftLengthBy2Plus1> num_points_;
explicit AccumulatedSpectra(size_t num_capture_channels)
: Y2(num_capture_channels),
E2(num_capture_channels),
low_render_energy(num_capture_channels),
num_points(num_capture_channels) {}
std::vector<std::array<float, kFftLengthBy2Plus1>> Y2;
std::vector<std::array<float, kFftLengthBy2Plus1>> E2;
std::vector<std::array<bool, kFftLengthBy2Plus1>> low_render_energy;
std::vector<int> num_points;
};
void UpdateAccumulatedSpectra(rtc::ArrayView<const float> X2,
rtc::ArrayView<const float> Y2,
rtc::ArrayView<const float> E2);
void UpdateAccumulatedSpectra(
rtc::ArrayView<const float, kFftLengthBy2Plus1> X2,
rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> Y2,
rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> E2,
const std::vector<bool>& converged_filters);
void ResetAccumulatedSpectra();
void UpdateBands(bool onset_detection);
void UpdateBands(const std::vector<bool>& converged_filters);
void DecreaseErlePerBandForLowRenderSignals();
const bool use_onset_detection_;
const float min_erle_;
const std::array<float, kFftLengthBy2Plus1> max_erle_;
const bool use_min_erle_during_onsets_;
AccumulatedSpectra accum_spectra_;
std::vector<std::array<float, kFftLengthBy2Plus1>> erle_;
std::array<float, kFftLengthBy2Plus1> erle_onsets_;
std::array<bool, kFftLengthBy2Plus1> coming_onset_;
std::array<int, kFftLengthBy2Plus1> hold_counters_;
std::vector<std::array<float, kFftLengthBy2Plus1>> erle_onsets_;
std::vector<std::array<bool, kFftLengthBy2Plus1>> coming_onset_;
std::vector<std::array<int, kFftLengthBy2Plus1>> hold_counters_;
};
} // namespace webrtc

View File

@ -16,26 +16,41 @@
namespace webrtc {
SubtractorOutputAnalyzer::SubtractorOutputAnalyzer() {}
SubtractorOutputAnalyzer::SubtractorOutputAnalyzer(size_t num_capture_channels)
: filters_converged_(num_capture_channels, false) {}
void SubtractorOutputAnalyzer::Update(
const SubtractorOutput& subtractor_output) {
const float y2 = subtractor_output.y2;
const float e2_main = subtractor_output.e2_main;
const float e2_shadow = subtractor_output.e2_shadow;
rtc::ArrayView<const SubtractorOutput> subtractor_output,
bool* any_filter_converged,
bool* all_filters_diverged) {
RTC_DCHECK(any_filter_converged);
RTC_DCHECK(all_filters_diverged);
RTC_DCHECK_EQ(subtractor_output.size(), filters_converged_.size());
constexpr float kConvergenceThreshold = 50 * 50 * kBlockSize;
main_filter_converged_ = e2_main < 0.5f * y2 && y2 > kConvergenceThreshold;
shadow_filter_converged_ =
e2_shadow < 0.05f * y2 && y2 > kConvergenceThreshold;
float min_e2 = std::min(e2_main, e2_shadow);
filter_diverged_ = min_e2 > 1.5f * y2 && y2 > 30.f * 30.f * kBlockSize;
*any_filter_converged = false;
*all_filters_diverged = true;
for (size_t ch = 0; ch < subtractor_output.size(); ++ch) {
const float y2 = subtractor_output[ch].y2;
const float e2_main = subtractor_output[ch].e2_main;
const float e2_shadow = subtractor_output[ch].e2_shadow;
constexpr float kConvergenceThreshold = 50 * 50 * kBlockSize;
bool main_filter_converged =
e2_main < 0.5f * y2 && y2 > kConvergenceThreshold;
bool shadow_filter_converged =
e2_shadow < 0.05f * y2 && y2 > kConvergenceThreshold;
float min_e2 = std::min(e2_main, e2_shadow);
bool filter_diverged = min_e2 > 1.5f * y2 && y2 > 30.f * 30.f * kBlockSize;
filters_converged_[ch] = main_filter_converged || shadow_filter_converged;
*any_filter_converged = *any_filter_converged || filters_converged_[ch];
*all_filters_diverged = *all_filters_diverged && filter_diverged;
}
}
void SubtractorOutputAnalyzer::HandleEchoPathChange() {
shadow_filter_converged_ = false;
main_filter_converged_ = false;
filter_diverged_ = false;
std::fill(filters_converged_.begin(), filters_converged_.end(), false);
}
} // namespace webrtc

View File

@ -11,32 +11,32 @@
#ifndef MODULES_AUDIO_PROCESSING_AEC3_SUBTRACTOR_OUTPUT_ANALYZER_H_
#define MODULES_AUDIO_PROCESSING_AEC3_SUBTRACTOR_OUTPUT_ANALYZER_H_
#include <vector>
#include "modules/audio_processing/aec3/subtractor_output.h"
namespace webrtc {
// Class for analyzing the properties subtractor output
// Class for analyzing the properties subtractor output.
class SubtractorOutputAnalyzer {
public:
SubtractorOutputAnalyzer();
explicit SubtractorOutputAnalyzer(size_t num_capture_channels);
~SubtractorOutputAnalyzer() = default;
// Analyses the subtractor output.
void Update(const SubtractorOutput& subtractor_output);
void Update(rtc::ArrayView<const SubtractorOutput> subtractor_output,
bool* any_filter_converged,
bool* all_filters_diverged);
bool ConvergedFilter() const {
return main_filter_converged_ || shadow_filter_converged_;
const std::vector<bool>& ConvergedFilters() const {
return filters_converged_;
}
bool DivergedFilter() const { return filter_diverged_; }
// Handle echo path change.
void HandleEchoPathChange();
private:
bool shadow_filter_converged_ = false;
bool main_filter_converged_ = false;
bool filter_diverged_ = false;
std::vector<bool> filters_converged_;
};
} // namespace webrtc