From 47d7fbd8fe4077b425c90b580d74a624276bfe7a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Per=20=C3=85hgren?= Date: Tue, 24 Apr 2018 12:44:29 +0200 Subject: [PATCH] Reuse the AEC2 coherence-based gain for the lower bands in AEC3. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This CL overrides the power-based suppressor gain decision with a coherence-based decision for the cases when that indicates a higher suppressor gain. Bug: webrtc:9159,chromium:833801 Change-Id: I0e7d82ac1b8c70ffe9d45907559bb14b1b849d71 Reviewed-on: https://webrtc-review.googlesource.com/71660 Commit-Queue: Per Åhgren Reviewed-by: Gustaf Ullberg Cr-Commit-Position: refs/heads/master@{#22997} --- api/audio/echo_canceller3_config.h | 4 + modules/audio_processing/aec3/BUILD.gn | 5 +- modules/audio_processing/aec3/aec3_fft.cc | 63 ++++- modules/audio_processing/aec3/aec3_fft.h | 10 +- .../audio_processing/aec3/coherence_gain.cc | 257 ++++++++++++++++++ .../audio_processing/aec3/coherence_gain.h | 77 ++++++ modules/audio_processing/aec3/echo_remover.cc | 45 ++- .../audio_processing/aec3/output_selector.cc | 59 ---- .../audio_processing/aec3/output_selector.h | 38 --- .../aec3/output_selector_unittest.cc | 69 ----- .../aec3/suppression_filter.cc | 16 +- .../aec3/suppression_filter.h | 2 +- .../aec3/suppression_filter_unittest.cc | 42 ++- .../audio_processing/aec3/suppression_gain.cc | 33 ++- .../audio_processing/aec3/suppression_gain.h | 27 +- .../aec3/suppression_gain_unittest.cc | 43 ++- 16 files changed, 564 insertions(+), 226 deletions(-) create mode 100644 modules/audio_processing/aec3/coherence_gain.cc create mode 100644 modules/audio_processing/aec3/coherence_gain.h delete mode 100644 modules/audio_processing/aec3/output_selector.cc delete mode 100644 modules/audio_processing/aec3/output_selector.h delete mode 100644 modules/audio_processing/aec3/output_selector_unittest.cc diff --git a/api/audio/echo_canceller3_config.h 
b/api/audio/echo_canceller3_config.h index 3b033dc91d..debd4872da 100644 --- a/api/audio/echo_canceller3_config.h +++ b/api/audio/echo_canceller3_config.h @@ -143,6 +143,10 @@ struct EchoCanceller3Config { float nonlinear_hold = 1; float nonlinear_release = 0.001f; } echo_model; + + struct Suppressor { + size_t bands_with_reliable_coherence = 5; + } suppressor; }; } // namespace webrtc diff --git a/modules/audio_processing/aec3/BUILD.gn b/modules/audio_processing/aec3/BUILD.gn index 2773d353fb..0dfb78e256 100644 --- a/modules/audio_processing/aec3/BUILD.gn +++ b/modules/audio_processing/aec3/BUILD.gn @@ -30,6 +30,8 @@ rtc_static_library("aec3") { "block_processor_metrics.h", "cascaded_biquad_filter.cc", "cascaded_biquad_filter.h", + "coherence_gain.cc", + "coherence_gain.h", "comfort_noise_generator.cc", "comfort_noise_generator.h", "decimator.cc", @@ -66,8 +68,6 @@ rtc_static_library("aec3") { "matched_filter_lag_aggregator.h", "matrix_buffer.cc", "matrix_buffer.h", - "output_selector.cc", - "output_selector.h", "render_buffer.cc", "render_buffer.h", "render_delay_buffer.cc", @@ -172,7 +172,6 @@ if (rtc_include_tests) { "main_filter_update_gain_unittest.cc", "matched_filter_lag_aggregator_unittest.cc", "matched_filter_unittest.cc", - "output_selector_unittest.cc", "render_buffer_unittest.cc", "render_delay_buffer_unittest.cc", "render_delay_controller_metrics_unittest.cc", diff --git a/modules/audio_processing/aec3/aec3_fft.cc b/modules/audio_processing/aec3/aec3_fft.cc index d6690360f8..5c9a061cba 100644 --- a/modules/audio_processing/aec3/aec3_fft.cc +++ b/modules/audio_processing/aec3/aec3_fft.cc @@ -33,6 +33,41 @@ const float kHanning64[kFftLengthBy2] = { 0.15088159f, 0.11697778f, 0.08688061f, 0.06088921f, 0.03926189f, 0.0222136f, 0.00991376f, 0.00248461f, 0.f}; +// Hanning window from Matlab command win = sqrt(hanning(128)). 
+const float kSqrtHanning128[kFftLength] = { + 0.00000000000000f, 0.02454122852291f, 0.04906767432742f, 0.07356456359967f, + 0.09801714032956f, 0.12241067519922f, 0.14673047445536f, 0.17096188876030f, + 0.19509032201613f, 0.21910124015687f, 0.24298017990326f, 0.26671275747490f, + 0.29028467725446f, 0.31368174039889f, 0.33688985339222f, 0.35989503653499f, + 0.38268343236509f, 0.40524131400499f, 0.42755509343028f, 0.44961132965461f, + 0.47139673682600f, 0.49289819222978f, 0.51410274419322f, 0.53499761988710f, + 0.55557023301960f, 0.57580819141785f, 0.59569930449243f, 0.61523159058063f, + 0.63439328416365f, 0.65317284295378f, 0.67155895484702f, 0.68954054473707f, + 0.70710678118655f, 0.72424708295147f, 0.74095112535496f, 0.75720884650648f, + 0.77301045336274f, 0.78834642762661f, 0.80320753148064f, 0.81758481315158f, + 0.83146961230255f, 0.84485356524971f, 0.85772861000027f, 0.87008699110871f, + 0.88192126434835f, 0.89322430119552f, 0.90398929312344f, 0.91420975570353f, + 0.92387953251129f, 0.93299279883474f, 0.94154406518302f, 0.94952818059304f, + 0.95694033573221f, 0.96377606579544f, 0.97003125319454f, 0.97570213003853f, + 0.98078528040323f, 0.98527764238894f, 0.98917650996478f, 0.99247953459871f, + 0.99518472667220f, 0.99729045667869f, 0.99879545620517f, 0.99969881869620f, + 1.00000000000000f, 0.99969881869620f, 0.99879545620517f, 0.99729045667869f, + 0.99518472667220f, 0.99247953459871f, 0.98917650996478f, 0.98527764238894f, + 0.98078528040323f, 0.97570213003853f, 0.97003125319454f, 0.96377606579544f, + 0.95694033573221f, 0.94952818059304f, 0.94154406518302f, 0.93299279883474f, + 0.92387953251129f, 0.91420975570353f, 0.90398929312344f, 0.89322430119552f, + 0.88192126434835f, 0.87008699110871f, 0.85772861000027f, 0.84485356524971f, + 0.83146961230255f, 0.81758481315158f, 0.80320753148064f, 0.78834642762661f, + 0.77301045336274f, 0.75720884650648f, 0.74095112535496f, 0.72424708295147f, + 0.70710678118655f, 0.68954054473707f, 0.67155895484702f, 0.65317284295378f, + 
0.63439328416365f, 0.61523159058063f, 0.59569930449243f, 0.57580819141785f, + 0.55557023301960f, 0.53499761988710f, 0.51410274419322f, 0.49289819222978f, + 0.47139673682600f, 0.44961132965461f, 0.42755509343028f, 0.40524131400499f, + 0.38268343236509f, 0.35989503653499f, 0.33688985339222f, 0.31368174039889f, + 0.29028467725446f, 0.26671275747490f, 0.24298017990326f, 0.21910124015687f, + 0.19509032201613f, 0.17096188876030f, 0.14673047445536f, 0.12241067519922f, + 0.09801714032956f, 0.07356456359967f, 0.04906767432742f, 0.02454122852291f}; + } // namespace // TODO(peah): Change x to be std::array once the rest of the code allows this. @@ -52,6 +87,9 @@ void Aec3Fft::ZeroPaddedFft(rtc::ArrayView x, fft.begin() + kFftLengthBy2, [](float a, float b) { return a * b; }); break; + case Window::kSqrtHanning: + RTC_NOTREACHED(); + break; default: RTC_NOTREACHED(); } @@ -61,14 +99,33 @@ void Aec3Fft::ZeroPaddedFft(rtc::ArrayView x, void Aec3Fft::PaddedFft(rtc::ArrayView x, rtc::ArrayView x_old, + Window window, FftData* X) const { RTC_DCHECK(X); RTC_DCHECK_EQ(kFftLengthBy2, x.size()); RTC_DCHECK_EQ(kFftLengthBy2, x_old.size()); std::array fft; - std::copy(x_old.begin(), x_old.end(), fft.begin()); - std::copy(x.begin(), x.end(), fft.begin() + x_old.size()); - std::copy(x.begin(), x.end(), x_old.begin()); + + switch (window) { + case Window::kRectangular: + std::copy(x_old.begin(), x_old.end(), fft.begin()); + std::copy(x.begin(), x.end(), fft.begin() + x_old.size()); + std::copy(x.begin(), x.end(), x_old.begin()); + break; + case Window::kHanning: + RTC_NOTREACHED(); + break; + case Window::kSqrtHanning: + std::transform(x_old.begin(), x_old.end(), std::begin(kSqrtHanning128), + fft.begin(), std::multiplies()); + std::transform(x.begin(), x.end(), + std::begin(kSqrtHanning128) + x_old.size(), + fft.begin() + x_old.size(), std::multiplies()); + break; + default: + RTC_NOTREACHED(); + } + Fft(&fft, X); } diff --git a/modules/audio_processing/aec3/aec3_fft.h 
b/modules/audio_processing/aec3/aec3_fft.h index f3dddb3f1b..f730284745 100644 --- a/modules/audio_processing/aec3/aec3_fft.h +++ b/modules/audio_processing/aec3/aec3_fft.h @@ -25,7 +25,7 @@ namespace webrtc { // FftData type. class Aec3Fft { public: - enum class Window { kRectangular, kHanning }; + enum class Window { kRectangular, kHanning, kSqrtHanning }; Aec3Fft() = default; // Computes the FFT. Note that both the input and output are modified. @@ -52,6 +52,14 @@ class Aec3Fft { // Fft. After that, x is copied to x_old. void PaddedFft(rtc::ArrayView x, rtc::ArrayView x_old, + FftData* X) const { + PaddedFft(x, x_old, Window::kRectangular, X); + } + + // Padded Fft using a time-domain window. + void PaddedFft(rtc::ArrayView x, + rtc::ArrayView x_old, + Window window, FftData* X) const; private: diff --git a/modules/audio_processing/aec3/coherence_gain.cc b/modules/audio_processing/aec3/coherence_gain.cc new file mode 100644 index 0000000000..ad33382b3a --- /dev/null +++ b/modules/audio_processing/aec3/coherence_gain.cc @@ -0,0 +1,257 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ + +#include "modules/audio_processing/aec3/coherence_gain.h" + +#include + +#include + +#include "rtc_base/checks.h" + +namespace webrtc { + +namespace { + +// Matlab code to produce table: +// overDriveCurve = [sqrt(linspace(0,1,65))' + 1]; +// fprintf(1, '\t%.4f, %.4f, %.4f, %.4f, %.4f, %.4f,\n', overDriveCurve); +const float kOverDriveCurve[kFftLengthBy2Plus1] = { + 1.0000f, 1.1250f, 1.1768f, 1.2165f, 1.2500f, 1.2795f, 1.3062f, 1.3307f, + 1.3536f, 1.3750f, 1.3953f, 1.4146f, 1.4330f, 1.4507f, 1.4677f, 1.4841f, + 1.5000f, 1.5154f, 1.5303f, 1.5449f, 1.5590f, 1.5728f, 1.5863f, 1.5995f, + 1.6124f, 1.6250f, 1.6374f, 1.6495f, 1.6614f, 1.6731f, 1.6847f, 1.6960f, + 1.7071f, 1.7181f, 1.7289f, 1.7395f, 1.7500f, 1.7603f, 1.7706f, 1.7806f, + 1.7906f, 1.8004f, 1.8101f, 1.8197f, 1.8292f, 1.8385f, 1.8478f, 1.8570f, + 1.8660f, 1.8750f, 1.8839f, 1.8927f, 1.9014f, 1.9100f, 1.9186f, 1.9270f, + 1.9354f, 1.9437f, 1.9520f, 1.9601f, 1.9682f, 1.9763f, 1.9843f, 1.9922f, + 2.0000f}; + +// Matlab code to produce table: +// weightCurve = [0 ; 0.3 * sqrt(linspace(0,1,64))' + 0.1]; +// fprintf(1, '\t%.4f, %.4f, %.4f, %.4f, %.4f, %.4f,\n', weightCurve); +const float kWeightCurve[kFftLengthBy2Plus1] = { + 0.0000f, 0.1000f, 0.1378f, 0.1535f, 0.1655f, 0.1756f, 0.1845f, 0.1926f, + 0.2000f, 0.2069f, 0.2134f, 0.2195f, 0.2254f, 0.2309f, 0.2363f, 0.2414f, + 0.2464f, 0.2512f, 0.2558f, 0.2604f, 0.2648f, 0.2690f, 0.2732f, 0.2773f, + 0.2813f, 0.2852f, 0.2890f, 0.2927f, 0.2964f, 0.3000f, 0.3035f, 0.3070f, + 0.3104f, 0.3138f, 0.3171f, 0.3204f, 0.3236f, 0.3268f, 0.3299f, 0.3330f, + 0.3360f, 0.3390f, 0.3420f, 0.3449f, 0.3478f, 0.3507f, 0.3535f, 0.3563f, + 0.3591f, 0.3619f, 0.3646f, 0.3673f, 0.3699f, 0.3726f, 0.3752f, 0.3777f, + 0.3803f, 0.3828f, 0.3854f, 0.3878f, 0.3903f, 0.3928f, 0.3952f, 0.3976f, + 0.4000f}; + +int CmpFloat(const void* a, const void* b) { + const float* da = static_cast(a); + const float* db = static_cast(b); + return (*da > *db) - (*da < *db); +} + +} // namespace + 
+CoherenceGain::CoherenceGain(int sample_rate_hz, size_t num_bands_to_compute) + : num_bands_to_compute_(num_bands_to_compute), + sample_rate_scaler_(sample_rate_hz >= 16000 ? 2 : 1) { + spectra_.Cye.Clear(); + spectra_.Cxy.Clear(); + spectra_.Pe.fill(0.f); + // Initialize to 1 in order to prevent numerical instability in the first + // block. + spectra_.Py.fill(1.f); + spectra_.Px.fill(1.f); +} + +CoherenceGain::~CoherenceGain() = default; + +void CoherenceGain::ComputeGain(const FftData& E, + const FftData& X, + const FftData& Y, + rtc::ArrayView gain) { + std::array coherence_ye; + std::array coherence_xy; + + UpdateCoherenceSpectra(E, X, Y); + ComputeCoherence(coherence_ye, coherence_xy); + FormSuppressionGain(coherence_ye, coherence_xy, gain); +} + +// Updates the following smoothed Power Spectral Densities (PSD): +// - Py : near-end (capture) +// - Pe : residual echo (filter error output) +// - Px : far-end (render) +// - Cye : cross-PSD of near-end and residual echo +// - Cxy : cross-PSD of near-end and far-end +// +void CoherenceGain::UpdateCoherenceSpectra(const FftData& E, + const FftData& X, + const FftData& Y) { + const float s = sample_rate_scaler_ == 1 ? 0.9f : 0.92f; + const float one_minus_s = 1.f - s; + auto& c = spectra_; + + for (size_t i = 0; i < c.Py.size(); i++) { + c.Py[i] = + s * c.Py[i] + one_minus_s * (Y.re[i] * Y.re[i] + Y.im[i] * Y.im[i]); + c.Pe[i] = + s * c.Pe[i] + one_minus_s * (E.re[i] * E.re[i] + E.im[i] * E.im[i]); + // We threshold here to protect against the ill-effects of a zero farend. + // The threshold is not arbitrarily chosen, but balances protection and + // adverse interaction with the algorithm's tuning. + + // Threshold to protect against the ill-effects of a zero far-end. 
+ c.Px[i] = + s * c.Px[i] + + one_minus_s * std::max(X.re[i] * X.re[i] + X.im[i] * X.im[i], 15.f); + + c.Cye.re[i] = + s * c.Cye.re[i] + one_minus_s * (Y.re[i] * E.re[i] + Y.im[i] * E.im[i]); + c.Cye.im[i] = + s * c.Cye.im[i] + one_minus_s * (Y.re[i] * E.im[i] - Y.im[i] * E.re[i]); + + c.Cxy.re[i] = + s * c.Cxy.re[i] + one_minus_s * (Y.re[i] * X.re[i] + Y.im[i] * X.im[i]); + c.Cxy.im[i] = + s * c.Cxy.im[i] + one_minus_s * (Y.re[i] * X.im[i] - Y.im[i] * X.re[i]); + } +} + +void CoherenceGain::FormSuppressionGain( + rtc::ArrayView coherence_ye, + rtc::ArrayView coherence_xy, + rtc::ArrayView gain) { + RTC_DCHECK_EQ(kFftLengthBy2Plus1, coherence_ye.size()); + RTC_DCHECK_EQ(kFftLengthBy2Plus1, coherence_xy.size()); + RTC_DCHECK_EQ(kFftLengthBy2Plus1, gain.size()); + constexpr int kPrefBandSize = 24; + auto& gs = gain_state_; + std::array h_nl_pref; + float h_nl_fb = 0; + float h_nl_fb_low = 0; + const int pref_band_size = kPrefBandSize / sample_rate_scaler_; + const int min_pref_band = 4 / sample_rate_scaler_; + + float h_nl_de_avg = 0.f; + float h_nl_xd_avg = 0.f; + for (int i = min_pref_band; i < pref_band_size + min_pref_band; ++i) { + h_nl_xd_avg += coherence_xy[i]; + h_nl_de_avg += coherence_ye[i]; + } + h_nl_xd_avg /= pref_band_size; + h_nl_xd_avg = 1 - h_nl_xd_avg; + h_nl_de_avg /= pref_band_size; + + if (h_nl_xd_avg < 0.75f && h_nl_xd_avg < gs.h_nl_xd_avg_min) { + gs.h_nl_xd_avg_min = h_nl_xd_avg; + } + + if (h_nl_de_avg > 0.98f && h_nl_xd_avg > 0.9f) { + gs.near_state = true; + } else if (h_nl_de_avg < 0.95f || h_nl_xd_avg < 0.8f) { + gs.near_state = false; + } + + std::array h_nl; + if (gs.h_nl_xd_avg_min == 1) { + gs.overdrive = 15.f; + + if (gs.near_state) { + std::copy(coherence_ye.begin(), coherence_ye.end(), h_nl.begin()); + h_nl_fb = h_nl_de_avg; + h_nl_fb_low = h_nl_de_avg; + } else { + for (size_t i = 0; i < h_nl.size(); ++i) { + h_nl[i] = 1 - coherence_xy[i]; + h_nl[i] = std::max(h_nl[i], 0.f); + } + h_nl_fb = h_nl_xd_avg; + h_nl_fb_low = 
h_nl_xd_avg; + } + } else { + if (gs.near_state) { + std::copy(coherence_ye.begin(), coherence_ye.end(), h_nl.begin()); + h_nl_fb = h_nl_de_avg; + h_nl_fb_low = h_nl_de_avg; + } else { + for (size_t i = 0; i < h_nl.size(); ++i) { + h_nl[i] = std::min(coherence_ye[i], 1 - coherence_xy[i]); + h_nl[i] = std::max(h_nl[i], 0.f); + } + + // Select an order statistic from the preferred bands. + // TODO(peah): Using quicksort now, but a selection algorithm may be + // preferred. + std::copy(h_nl.begin() + min_pref_band, + h_nl.begin() + min_pref_band + pref_band_size, + h_nl_pref.begin()); + std::qsort(h_nl_pref.data(), pref_band_size, sizeof(float), CmpFloat); + + constexpr float kPrefBandQuant = 0.75f; + h_nl_fb = h_nl_pref[static_cast( + floor(kPrefBandQuant * (pref_band_size - 1)))]; + constexpr float kPrefBandQuantLow = 0.5f; + h_nl_fb_low = h_nl_pref[static_cast( + floor(kPrefBandQuantLow * (pref_band_size - 1)))]; + } + } + + // Track the local filter minimum to determine suppression overdrive. + if (h_nl_fb_low < 0.6f && h_nl_fb_low < gs.h_nl_fb_local_min) { + gs.h_nl_fb_local_min = h_nl_fb_low; + gs.h_nl_fb_min = h_nl_fb_low; + gs.h_nl_new_min = 1; + gs.h_nl_min_ctr = 0; + } + gs.h_nl_fb_local_min = + std::min(gs.h_nl_fb_local_min + 0.0008f / sample_rate_scaler_, 1.f); + gs.h_nl_xd_avg_min = + std::min(gs.h_nl_xd_avg_min + 0.0006f / sample_rate_scaler_, 1.f); + + if (gs.h_nl_new_min == 1) { + ++gs.h_nl_min_ctr; + } + if (gs.h_nl_min_ctr == 2) { + gs.h_nl_new_min = 0; + gs.h_nl_min_ctr = 0; + constexpr float epsilon = 1e-10f; + gs.overdrive = std::max( + -18.4f / static_cast(log(gs.h_nl_fb_min + epsilon) + epsilon), + 15.f); + } + + // Smooth the overdrive. + if (gs.overdrive < gs.overdrive_scaling) { + gs.overdrive_scaling = 0.99f * gs.overdrive_scaling + 0.01f * gs.overdrive; + } else { + gs.overdrive_scaling = 0.9f * gs.overdrive_scaling + 0.1f * gs.overdrive; + } + + // Apply the overdrive. 
+ RTC_DCHECK_LE(num_bands_to_compute_, gain.size()); + for (size_t i = 0; i < num_bands_to_compute_; ++i) { + if (h_nl[i] > h_nl_fb) { + h_nl[i] = kWeightCurve[i] * h_nl_fb + (1 - kWeightCurve[i]) * h_nl[i]; + } + gain[i] = powf(h_nl[i], gs.overdrive_scaling * kOverDriveCurve[i]); + } +} + +void CoherenceGain::ComputeCoherence(rtc::ArrayView coherence_ye, + rtc::ArrayView coherence_xy) const { + const auto& c = spectra_; + constexpr float epsilon = 1e-10f; + for (size_t i = 0; i < coherence_ye.size(); ++i) { + coherence_ye[i] = (c.Cye.re[i] * c.Cye.re[i] + c.Cye.im[i] * c.Cye.im[i]) / + (c.Py[i] * c.Pe[i] + epsilon); + coherence_xy[i] = (c.Cxy.re[i] * c.Cxy.re[i] + c.Cxy.im[i] * c.Cxy.im[i]) / + (c.Px[i] * c.Py[i] + epsilon); + } +} + +} // namespace webrtc diff --git a/modules/audio_processing/aec3/coherence_gain.h b/modules/audio_processing/aec3/coherence_gain.h new file mode 100644 index 0000000000..b6e22fdbbc --- /dev/null +++ b/modules/audio_processing/aec3/coherence_gain.h @@ -0,0 +1,77 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_COHERENCE_GAIN_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_COHERENCE_GAIN_H_ + +#include + +#include "modules/audio_processing/aec3/aec3_common.h" +#include "modules/audio_processing/aec3/aec3_fft.h" +#include "rtc_base/constructormagic.h" + +namespace webrtc { + +// Class for computing an echo suppression gain based on the coherence measure. 
+class CoherenceGain { + public: + CoherenceGain(int sample_rate_hz, size_t num_bands_to_compute); + ~CoherenceGain(); + + // Computes the gain based on the FFTs of the filter error output signal, the + // render signal and the capture signal. + void ComputeGain(const FftData& E, + const FftData& X, + const FftData& Y, + rtc::ArrayView gain); + + private: + struct { + FftData Cye; + FftData Cxy; + std::array Px; + std::array Py; + std::array Pe; + } spectra_; + + struct { + float h_nl_fb_min = 1; + float h_nl_fb_local_min = 1; + float h_nl_xd_avg_min = 1.f; + int h_nl_new_min = 0; + float h_nl_min_ctr = 0; + float overdrive = 2; + float overdrive_scaling = 2; + bool near_state = false; + } gain_state_; + + const Aec3Fft fft_; + const size_t num_bands_to_compute_; + const int sample_rate_scaler_; + + // Updates the spectral estimates used for the coherence computation. + void UpdateCoherenceSpectra(const FftData& E, + const FftData& X, + const FftData& Y); + + // Compute the suppression gain based on the coherence. + void FormSuppressionGain(rtc::ArrayView coherence_ye, + rtc::ArrayView coherence_xy, + rtc::ArrayView h_nl); + + // Compute the coherence. 
+ void ComputeCoherence(rtc::ArrayView coherence_ye, + rtc::ArrayView coherence_xy) const; + + RTC_DISALLOW_IMPLICIT_CONSTRUCTORS(CoherenceGain); +}; +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_COHERENCE_GAIN_H_ diff --git a/modules/audio_processing/aec3/echo_remover.cc b/modules/audio_processing/aec3/echo_remover.cc index 8e14032f2c..fea561d837 100644 --- a/modules/audio_processing/aec3/echo_remover.cc +++ b/modules/audio_processing/aec3/echo_remover.cc @@ -22,7 +22,6 @@ #include "modules/audio_processing/aec3/echo_path_variability.h" #include "modules/audio_processing/aec3/echo_remover_metrics.h" #include "modules/audio_processing/aec3/fft_data.h" -#include "modules/audio_processing/aec3/output_selector.h" #include "modules/audio_processing/aec3/render_buffer.h" #include "modules/audio_processing/aec3/render_delay_buffer.h" #include "modules/audio_processing/aec3/residual_echo_estimator.h" @@ -86,12 +85,14 @@ class EchoRemoverImpl final : public EchoRemover { ComfortNoiseGenerator cng_; SuppressionFilter suppression_filter_; RenderSignalAnalyzer render_signal_analyzer_; - OutputSelector output_selector_; ResidualEchoEstimator residual_echo_estimator_; bool echo_leakage_detected_ = false; AecState aec_state_; EchoRemoverMetrics metrics_; bool initial_state_ = true; + std::array e_old_; + std::array x_old_; + std::array y_old_; RTC_DISALLOW_COPY_AND_ASSIGN(EchoRemoverImpl); }; @@ -107,13 +108,16 @@ EchoRemoverImpl::EchoRemoverImpl(const EchoCanceller3Config& config, optimization_(DetectOptimization()), sample_rate_hz_(sample_rate_hz), subtractor_(config, data_dumper_.get(), optimization_), - suppression_gain_(config_, optimization_), + suppression_gain_(config_, optimization_, sample_rate_hz), cng_(optimization_), suppression_filter_(sample_rate_hz_), render_signal_analyzer_(config_), residual_echo_estimator_(config_), aec_state_(config_) { RTC_DCHECK(ValidFullBandRate(sample_rate_hz)); + x_old_.fill(0.f); + y_old_.fill(0.f); + 
e_old_.fill(0.f); } EchoRemoverImpl::~EchoRemoverImpl() = default; @@ -191,6 +195,8 @@ void EchoRemoverImpl::ProcessCapture( fft_.ZeroPaddedFft(y0, Aec3Fft::Window::kRectangular, &Y); LinearEchoPower(E_main_nonwindowed, Y, &S2_linear); Y.Spectrum(optimization_, Y2); + fft_.PaddedFft(y0, y_old_, Aec3Fft::Window::kSqrtHanning, &Y); + std::copy(y0.begin(), y0.end(), y_old_.begin()); // Update the AEC state information. aec_state_.Update(external_delay, subtractor_.FilterFrequencyResponse(), @@ -201,12 +207,11 @@ void EchoRemoverImpl::ProcessCapture( // Choose the linear output. data_dumper_->DumpWav("aec3_output_linear2", kBlockSize, &e_main[0], LowestBandRate(sample_rate_hz_), 1); - output_selector_.FormLinearOutput(aec_state_.UseLinearFilterOutput(), e_main, - y0); - + if (aec_state_.UseLinearFilterOutput()) { + std::copy(e_main.begin(), e_main.end(), y0.begin()); + } data_dumper_->DumpWav("aec3_output_linear", kBlockSize, &y0[0], LowestBandRate(sample_rate_hz_), 1); - data_dumper_->DumpRaw("aec3_output_linear", y0); const auto& E2 = aec_state_.UseLinearFilterOutput() ? E2_main : Y2; // Estimate the residual echo power. @@ -216,12 +221,32 @@ void EchoRemoverImpl::ProcessCapture( // Estimate the comfort noise. cng_.Compute(aec_state_, Y2, &comfort_noise, &high_band_comfort_noise); - // A choose and apply echo suppression gain. - suppression_gain_.GetGain(E2, R2, cng_.NoiseSpectrum(), + // Compute spectra. 
+ const bool suppression_gain_uses_ffts = + config_.suppressor.bands_with_reliable_coherence > 0; + FftData X; + if (suppression_gain_uses_ffts) { + const std::vector& x_aligned = + render_buffer->Block(-aec_state_.FilterDelayBlocks())[0]; + fft_.PaddedFft(x_aligned, x_old_, Aec3Fft::Window::kSqrtHanning, &X); + std::copy(x_aligned.begin(), x_aligned.end(), x_old_.begin()); + } else { + X.Clear(); + } + + FftData E; + fft_.PaddedFft(e_main, e_old_, Aec3Fft::Window::kSqrtHanning, &E); + std::copy(e_main.begin(), e_main.end(), e_old_.begin()); + + const auto& Y_fft = aec_state_.UseLinearFilterOutput() ? E : Y; + + // Compute and apply the suppression gain. + suppression_gain_.GetGain(E2, R2, cng_.NoiseSpectrum(), E, X, Y, render_signal_analyzer_, aec_state_, x, &high_bands_gain, &G); + suppression_filter_.ApplyGain(comfort_noise, high_band_comfort_noise, G, - high_bands_gain, y); + high_bands_gain, Y_fft, y); // Update the metrics. metrics_.Update(aec_state_, cng_.NoiseSpectrum(), G); diff --git a/modules/audio_processing/aec3/output_selector.cc b/modules/audio_processing/aec3/output_selector.cc deleted file mode 100644 index 4f547d98d9..0000000000 --- a/modules/audio_processing/aec3/output_selector.cc +++ /dev/null @@ -1,59 +0,0 @@ -/* - * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. - */ - -#include "modules/audio_processing/aec3/output_selector.h" - -#include -#include - -#include "rtc_base/checks.h" - -namespace webrtc { -namespace { - -// Performs the transition between the signals in a smooth manner. 
-void SmoothFrameTransition(bool from_y_to_e, - rtc::ArrayView e, - rtc::ArrayView y) { - RTC_DCHECK_LT(0u, e.size()); - RTC_DCHECK_EQ(y.size(), e.size()); - - const float change_factor = (from_y_to_e ? 1.f : -1.f) / e.size(); - float averaging = from_y_to_e ? 0.f : 1.f; - for (size_t k = 0; k < e.size(); ++k) { - y[k] += averaging * (e[k] - y[k]); - averaging += change_factor; - } - RTC_DCHECK_EQ(from_y_to_e ? 1.f : 0.f, averaging); -} - -} // namespace - -OutputSelector::OutputSelector() = default; - -OutputSelector::~OutputSelector() = default; - -void OutputSelector::FormLinearOutput( - bool use_subtractor_output, - rtc::ArrayView subtractor_output, - rtc::ArrayView capture) { - RTC_DCHECK_EQ(subtractor_output.size(), capture.size()); - rtc::ArrayView& e_main = subtractor_output; - rtc::ArrayView y = capture; - - if (use_subtractor_output != use_subtractor_output_) { - use_subtractor_output_ = use_subtractor_output; - SmoothFrameTransition(use_subtractor_output_, e_main, y); - } else if (use_subtractor_output_) { - std::copy(e_main.begin(), e_main.end(), y.begin()); - } -} - -} // namespace webrtc diff --git a/modules/audio_processing/aec3/output_selector.h b/modules/audio_processing/aec3/output_selector.h deleted file mode 100644 index 17605a6a45..0000000000 --- a/modules/audio_processing/aec3/output_selector.h +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. 
- */ - -#ifndef MODULES_AUDIO_PROCESSING_AEC3_OUTPUT_SELECTOR_H_ -#define MODULES_AUDIO_PROCESSING_AEC3_OUTPUT_SELECTOR_H_ - -#include "api/array_view.h" -#include "rtc_base/constructormagic.h" - -namespace webrtc { - -// Performs the selection between which of the linear aec output and the -// microphone signal should be used as the echo suppressor output. -class OutputSelector { - public: - OutputSelector(); - ~OutputSelector(); - - // Forms the most appropriate output signal. - void FormLinearOutput(bool use_subtractor_output, - rtc::ArrayView subtractor_output, - rtc::ArrayView capture); - - private: - bool use_subtractor_output_ = false; - RTC_DISALLOW_COPY_AND_ASSIGN(OutputSelector); -}; - -} // namespace webrtc - -#endif // MODULES_AUDIO_PROCESSING_AEC3_OUTPUT_SELECTOR_H_ diff --git a/modules/audio_processing/aec3/output_selector_unittest.cc b/modules/audio_processing/aec3/output_selector_unittest.cc deleted file mode 100644 index c7add1c838..0000000000 --- a/modules/audio_processing/aec3/output_selector_unittest.cc +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. - */ - -#include "modules/audio_processing/aec3/output_selector.h" - -#include -#include - -#include "modules/audio_processing/aec3/aec3_common.h" -#include "test/gtest.h" - -namespace webrtc { - -// Verifies that the switching between the signals in the output works as -// intended. 
-TEST(OutputSelector, ProperSwitching) { - OutputSelector selector; - - std::array y; - std::array e; - std::array e_ref; - std::array y_ref; - auto init_blocks = [](std::array* e, - std::array* y) { - e->fill(10.f); - y->fill(20.f); - }; - - init_blocks(&e_ref, &y_ref); - - init_blocks(&e, &y); - selector.FormLinearOutput(false, e, y); - EXPECT_EQ(y_ref, y); - - init_blocks(&e, &y); - selector.FormLinearOutput(true, e, y); - EXPECT_NE(e_ref, y); - EXPECT_NE(y_ref, y); - - init_blocks(&e, &y); - selector.FormLinearOutput(true, e, y); - EXPECT_EQ(e_ref, y); - - init_blocks(&e, &y); - selector.FormLinearOutput(true, e, y); - EXPECT_EQ(e_ref, y); - - init_blocks(&e, &y); - selector.FormLinearOutput(false, e, y); - EXPECT_NE(e_ref, y); - EXPECT_NE(y_ref, y); - - init_blocks(&e, &y); - selector.FormLinearOutput(false, e, y); - EXPECT_EQ(y_ref, y); - - init_blocks(&e, &y); - selector.FormLinearOutput(false, e, y); - EXPECT_EQ(y_ref, y); -} - -} // namespace webrtc diff --git a/modules/audio_processing/aec3/suppression_filter.cc b/modules/audio_processing/aec3/suppression_filter.cc index 8c92bf5762..87e3008dc7 100644 --- a/modules/audio_processing/aec3/suppression_filter.cc +++ b/modules/audio_processing/aec3/suppression_filter.cc @@ -64,7 +64,6 @@ SuppressionFilter::SuppressionFilter(int sample_rate_hz) fft_(), e_output_old_(NumBandsForRate(sample_rate_hz_)) { RTC_DCHECK(ValidFullBandRate(sample_rate_hz_)); - e_input_old_.fill(0.f); std::for_each(e_output_old_.begin(), e_output_old_.end(), [](std::array& a) { a.fill(0.f); }); } @@ -76,22 +75,14 @@ void SuppressionFilter::ApplyGain( const FftData& comfort_noise_high_band, const std::array& suppression_gain, float high_bands_gain, + const FftData& E_lowest_band, std::vector>* e) { RTC_DCHECK(e); RTC_DCHECK_EQ(e->size(), NumBandsForRate(sample_rate_hz_)); FftData E; - std::array e_extended; - constexpr float kIfftNormalization = 2.f / kFftLength; // Analysis filterbank. 
- std::transform(e_input_old_.begin(), e_input_old_.end(), - std::begin(kSqrtHanning), e_extended.begin(), - std::multiplies()); - std::transform((*e)[0].begin(), (*e)[0].end(), - std::begin(kSqrtHanning) + kFftLengthBy2, - e_extended.begin() + kFftLengthBy2, std::multiplies()); - std::copy((*e)[0].begin(), (*e)[0].end(), e_input_old_.begin()); - fft_.Fft(&e_extended, &E); + E.Assign(E_lowest_band); // Apply gain. std::transform(suppression_gain.begin(), suppression_gain.end(), E.re.begin(), @@ -113,6 +104,9 @@ void SuppressionFilter::ApplyGain( E.im.begin(), E.im.begin(), std::plus()); // Synthesis filterbank. + std::array e_extended; + constexpr float kIfftNormalization = 2.f / kFftLength; + fft_.Ifft(E, &e_extended); std::transform(e_output_old_[0].begin(), e_output_old_[0].end(), std::begin(kSqrtHanning) + kFftLengthBy2, (*e)[0].begin(), diff --git a/modules/audio_processing/aec3/suppression_filter.h b/modules/audio_processing/aec3/suppression_filter.h index 5f91dea28f..237408d7f9 100644 --- a/modules/audio_processing/aec3/suppression_filter.h +++ b/modules/audio_processing/aec3/suppression_filter.h @@ -28,13 +28,13 @@ class SuppressionFilter { const FftData& comfort_noise_high_bands, const std::array& suppression_gain, float high_bands_gain, + const FftData& E_lowest_band, std::vector>* e); private: const int sample_rate_hz_; const OouraFft ooura_fft_; const Aec3Fft fft_; - std::array e_input_old_; std::vector> e_output_old_; RTC_DISALLOW_COPY_AND_ASSIGN(SuppressionFilter); }; diff --git a/modules/audio_processing/aec3/suppression_filter_unittest.cc b/modules/audio_processing/aec3/suppression_filter_unittest.cc index 51b3f91f2a..eaa608eed5 100644 --- a/modules/audio_processing/aec3/suppression_filter_unittest.cc +++ b/modules/audio_processing/aec3/suppression_filter_unittest.cc @@ -42,10 +42,11 @@ void ProduceSinusoid(int sample_rate_hz, TEST(SuppressionFilter, NullOutput) { FftData cn; FftData cn_high_bands; + FftData E; std::array gain; 
EXPECT_DEATH(SuppressionFilter(16000).ApplyGain(cn, cn_high_bands, gain, 1.0f, - nullptr), + E, nullptr), ""); } @@ -62,7 +63,10 @@ TEST(SuppressionFilter, ComfortNoiseInUnityGain) { FftData cn; FftData cn_high_bands; std::array gain; + std::array e_old_; + Aec3Fft fft; + e_old_.fill(0.f); gain.fill(1.f); cn.re.fill(1.f); cn.im.fill(1.f); @@ -71,7 +75,12 @@ TEST(SuppressionFilter, ComfortNoiseInUnityGain) { std::vector> e(3, std::vector(kBlockSize, 0.f)); std::vector> e_ref = e; - filter.ApplyGain(cn, cn_high_bands, gain, 1.f, &e); + + FftData E; + fft.PaddedFft(e[0], e_old_, Aec3Fft::Window::kSqrtHanning, &E); + std::copy(e[0].begin(), e[0].end(), e_old_.begin()); + + filter.ApplyGain(cn, cn_high_bands, gain, 1.f, E, &e); for (size_t k = 0; k < e.size(); ++k) { EXPECT_EQ(e_ref[k], e[k]); @@ -83,8 +92,11 @@ TEST(SuppressionFilter, SignalSuppression) { SuppressionFilter filter(48000); FftData cn; FftData cn_high_bands; + std::array e_old_; + Aec3Fft fft; std::array gain; std::vector> e(3, std::vector(kBlockSize, 0.f)); + e_old_.fill(0.f); gain.fill(1.f); std::for_each(gain.begin() + 10, gain.end(), [](float& a) { a = 0.f; }); @@ -103,7 +115,12 @@ TEST(SuppressionFilter, SignalSuppression) { e[0]); e0_input = std::inner_product(e[0].begin(), e[0].end(), e[0].begin(), e0_input); - filter.ApplyGain(cn, cn_high_bands, gain, 1.f, &e); + + FftData E; + fft.PaddedFft(e[0], e_old_, Aec3Fft::Window::kSqrtHanning, &E); + std::copy(e[0].begin(), e[0].end(), e_old_.begin()); + + filter.ApplyGain(cn, cn_high_bands, gain, 1.f, E, &e); e0_output = std::inner_product(e[0].begin(), e[0].end(), e[0].begin(), e0_output); } @@ -116,10 +133,12 @@ TEST(SuppressionFilter, SignalSuppression) { TEST(SuppressionFilter, SignalTransparency) { SuppressionFilter filter(48000); FftData cn; + std::array e_old_; + Aec3Fft fft; FftData cn_high_bands; std::array gain; std::vector> e(3, std::vector(kBlockSize, 0.f)); - + e_old_.fill(0.f); gain.fill(1.f); std::for_each(gain.begin() + 30, gain.end(), 
[](float& a) { a = 0.f; }); @@ -137,7 +156,12 @@ TEST(SuppressionFilter, SignalTransparency) { e[0]); e0_input = std::inner_product(e[0].begin(), e[0].end(), e[0].begin(), e0_input); - filter.ApplyGain(cn, cn_high_bands, gain, 1.f, &e); + + FftData E; + fft.PaddedFft(e[0], e_old_, Aec3Fft::Window::kSqrtHanning, &E); + std::copy(e[0].begin(), e[0].end(), e_old_.begin()); + + filter.ApplyGain(cn, cn_high_bands, gain, 1.f, E, &e); e0_output = std::inner_product(e[0].begin(), e[0].end(), e[0].begin(), e0_output); } @@ -150,6 +174,8 @@ TEST(SuppressionFilter, Delay) { SuppressionFilter filter(48000); FftData cn; FftData cn_high_bands; + std::array e_old_; + Aec3Fft fft; std::array gain; std::vector> e(3, std::vector(kBlockSize, 0.f)); @@ -167,7 +193,11 @@ TEST(SuppressionFilter, Delay) { } } - filter.ApplyGain(cn, cn_high_bands, gain, 1.f, &e); + FftData E; + fft.PaddedFft(e[0], e_old_, Aec3Fft::Window::kSqrtHanning, &E); + std::copy(e[0].begin(), e[0].end(), e_old_.begin()); + + filter.ApplyGain(cn, cn_high_bands, gain, 1.f, E, &e); if (k > 2) { for (size_t j = 0; j < 2; ++j) { for (size_t i = 0; i < kBlockSize; ++i) { diff --git a/modules/audio_processing/aec3/suppression_gain.cc b/modules/audio_processing/aec3/suppression_gain.cc index 8ebab013bc..110c23279d 100644 --- a/modules/audio_processing/aec3/suppression_gain.cc +++ b/modules/audio_processing/aec3/suppression_gain.cc @@ -342,11 +342,14 @@ void SuppressionGain::LowerBandGain( } SuppressionGain::SuppressionGain(const EchoCanceller3Config& config, - Aec3Optimization optimization) + Aec3Optimization optimization, + int sample_rate_hz) : optimization_(optimization), config_(config), state_change_duration_blocks_( - static_cast(config_.filter.config_change_duration_blocks)) { + static_cast(config_.filter.config_change_duration_blocks)), + coherence_gain_(sample_rate_hz, + config_.suppressor.bands_with_reliable_coherence) { RTC_DCHECK_LT(0, state_change_duration_blocks_); one_by_state_change_duration_blocks_ = 1.f / 
state_change_duration_blocks_; last_gain_.fill(1.f); @@ -355,10 +358,15 @@ SuppressionGain::SuppressionGain(const EchoCanceller3Config& config, last_echo_.fill(0.f); } +SuppressionGain::~SuppressionGain() = default; + void SuppressionGain::GetGain( - const std::array& nearend, - const std::array& echo, - const std::array& comfort_noise, + const std::array& nearend_spectrum, + const std::array& echo_spectrum, + const std::array& comfort_noise_spectrum, + const FftData& linear_aec_fft, + const FftData& render_fft, + const FftData& capture_fft, const RenderSignalAnalyzer& render_signal_analyzer, const AecState& aec_state, const std::vector>& render, @@ -371,8 +379,8 @@ void SuppressionGain::GetGain( bool low_noise_render = low_render_detector_.Detect(render); const rtc::Optional narrow_peak_band = render_signal_analyzer.NarrowPeakBand(); - LowerBandGain(low_noise_render, narrow_peak_band, aec_state, nearend, echo, - comfort_noise, low_band_gain); + LowerBandGain(low_noise_render, narrow_peak_band, aec_state, nearend_spectrum, + echo_spectrum, comfort_noise_spectrum, low_band_gain); const float gain_upper_bound = aec_state.SuppressionGainLimit(); if (gain_upper_bound < 1.f) { @@ -384,6 +392,17 @@ void SuppressionGain::GetGain( // Compute the gain for the upper bands. *high_bands_gain = UpperBandsGain(narrow_peak_band, aec_state.SaturatedEcho(), render, *low_band_gain); + + // Adjust the gain for bands where the coherence indicates not echo. 
+ if (config_.suppressor.bands_with_reliable_coherence > 0) { + std::array G_coherence; + coherence_gain_.ComputeGain(linear_aec_fft, render_fft, capture_fft, + G_coherence); + for (size_t k = 0; k < config_.suppressor.bands_with_reliable_coherence; + ++k) { + (*low_band_gain)[k] = std::max((*low_band_gain)[k], G_coherence[k]); + } + } } void SuppressionGain::SetInitialState(bool state) { diff --git a/modules/audio_processing/aec3/suppression_gain.h b/modules/audio_processing/aec3/suppression_gain.h index a519894086..59eeb35857 100644 --- a/modules/audio_processing/aec3/suppression_gain.h +++ b/modules/audio_processing/aec3/suppression_gain.h @@ -17,6 +17,7 @@ #include "api/audio/echo_canceller3_config.h" #include "modules/audio_processing/aec3/aec3_common.h" #include "modules/audio_processing/aec3/aec_state.h" +#include "modules/audio_processing/aec3/coherence_gain.h" #include "modules/audio_processing/aec3/render_signal_analyzer.h" #include "rtc_base/constructormagic.h" @@ -25,15 +26,21 @@ namespace webrtc { class SuppressionGain { public: SuppressionGain(const EchoCanceller3Config& config, - Aec3Optimization optimization); - void GetGain(const std::array& nearend, - const std::array& echo, - const std::array& comfort_noise, - const RenderSignalAnalyzer& render_signal_analyzer, - const AecState& aec_state, - const std::vector>& render, - float* high_bands_gain, - std::array* low_band_gain); + Aec3Optimization optimization, + int sample_rate_hz); + ~SuppressionGain(); + void GetGain( + const std::array& nearend_spectrum, + const std::array& echo_spectrum, + const std::array& comfort_noise_spectrum, + const FftData& linear_aec_fft, + const FftData& render_fft, + const FftData& capture_fft, + const RenderSignalAnalyzer& render_signal_analyzer, + const AecState& aec_state, + const std::vector>& render, + float* high_bands_gain, + std::array* low_band_gain); // Toggles the usage of the initial state. 
void SetInitialState(bool state); @@ -75,6 +82,8 @@ class SuppressionGain { LowNoiseRenderDetector low_render_detector_; bool initial_state_ = true; int initial_state_change_counter_ = 0; + CoherenceGain coherence_gain_; + RTC_DISALLOW_COPY_AND_ASSIGN(SuppressionGain); }; diff --git a/modules/audio_processing/aec3/suppression_gain_unittest.cc b/modules/audio_processing/aec3/suppression_gain_unittest.cc index 9c12b29b60..128c61eaf3 100644 --- a/modules/audio_processing/aec3/suppression_gain_unittest.cc +++ b/modules/audio_processing/aec3/suppression_gain_unittest.cc @@ -29,15 +29,25 @@ TEST(SuppressionGain, NullOutputGains) { std::array E2; std::array R2; std::array N2; + FftData E; + FftData X; + FftData Y; E2.fill(0.f); R2.fill(0.f); N2.fill(0.f); + E.re.fill(0.f); + E.im.fill(0.f); + X.re.fill(0.f); + X.im.fill(0.f); + Y.re.fill(0.f); + Y.im.fill(0.f); + float high_bands_gain; AecState aec_state(EchoCanceller3Config{}); EXPECT_DEATH( - SuppressionGain(EchoCanceller3Config{}, DetectOptimization()) - .GetGain(E2, R2, N2, RenderSignalAnalyzer((EchoCanceller3Config{})), - aec_state, + SuppressionGain(EchoCanceller3Config{}, DetectOptimization(), 16000) + .GetGain(E2, R2, N2, E, X, Y, + RenderSignalAnalyzer((EchoCanceller3Config{})), aec_state, std::vector>( 3, std::vector(kBlockSize, 0.f)), &high_bands_gain, nullptr), @@ -48,8 +58,8 @@ TEST(SuppressionGain, NullOutputGains) { // Does a sanity check that the gains are correctly computed. 
TEST(SuppressionGain, BasicGainComputation) { - SuppressionGain suppression_gain(EchoCanceller3Config(), - DetectOptimization()); + SuppressionGain suppression_gain(EchoCanceller3Config(), DetectOptimization(), + 16000); RenderSignalAnalyzer analyzer(EchoCanceller3Config{}); float high_bands_gain; std::array E2; @@ -58,6 +68,9 @@ TEST(SuppressionGain, BasicGainComputation) { std::array N2; std::array g; std::array s; + FftData E; + FftData X; + FftData Y; std::vector> x(1, std::vector(kBlockSize, 0.f)); EchoCanceller3Config config; AecState aec_state(config); @@ -73,6 +86,12 @@ TEST(SuppressionGain, BasicGainComputation) { R2.fill(0.1f); N2.fill(100.f); s.fill(10.f); + E.re.fill(sqrtf(E2[0])); + E.im.fill(0.f); + X.re.fill(sqrtf(R2[0])); + X.im.fill(0.f); + Y.re.fill(sqrtf(Y2[0])); + Y.im.fill(0.f); // Ensure that the gain is no longer forced to zero. for (int k = 0; k <= kNumBlocksPerSecond / 5 + 1; ++k) { @@ -87,7 +106,7 @@ TEST(SuppressionGain, BasicGainComputation) { subtractor.FilterImpulseResponse(), subtractor.ConvergedFilter(), subtractor.DivergedFilter(), *render_delay_buffer->GetRenderBuffer(), E2, Y2, s); - suppression_gain.GetGain(E2, R2, N2, analyzer, aec_state, x, + suppression_gain.GetGain(E2, R2, N2, E, X, Y, analyzer, aec_state, x, &high_bands_gain, &g); } std::for_each(g.begin(), g.end(), @@ -98,12 +117,16 @@ TEST(SuppressionGain, BasicGainComputation) { Y2.fill(100.f); R2.fill(0.1f); N2.fill(0.f); + E.re.fill(sqrtf(E2[0])); + X.re.fill(sqrtf(R2[0])); + Y.re.fill(sqrtf(Y2[0])); + for (int k = 0; k < 100; ++k) { aec_state.Update(delay_estimate, subtractor.FilterFrequencyResponse(), subtractor.FilterImpulseResponse(), subtractor.ConvergedFilter(), subtractor.DivergedFilter(), *render_delay_buffer->GetRenderBuffer(), E2, Y2, s); - suppression_gain.GetGain(E2, R2, N2, analyzer, aec_state, x, + suppression_gain.GetGain(E2, R2, N2, E, X, Y, analyzer, aec_state, x, &high_bands_gain, &g); } std::for_each(g.begin(), g.end(), @@ -112,9 +135,11 @@ 
TEST(SuppressionGain, BasicGainComputation) { // Ensure that a strong echo is suppressed. E2.fill(1000000000.f); R2.fill(10000000000000.f); - N2.fill(0.f); + E.re.fill(sqrtf(E2[0])); + X.re.fill(sqrtf(R2[0])); + for (int k = 0; k < 10; ++k) { - suppression_gain.GetGain(E2, R2, N2, analyzer, aec_state, x, + suppression_gain.GetGain(E2, R2, N2, E, X, Y, analyzer, aec_state, x, &high_bands_gain, &g); } std::for_each(g.begin(), g.end(),