From 522d71bf3605e99f4a2b5fd253d5ccd171575e7d Mon Sep 17 00:00:00 2001 From: peah Date: Thu, 23 Feb 2017 05:16:26 -0800 Subject: [PATCH] Finalization of the first version of EchoCanceller 3 This CL adds the remaining code for the first version of EchoCanceller3. TBR=aleloi@webrtc.org BUG=webrtc:6018 Review-Url: https://codereview.webrtc.org/2678423005 Cr-Commit-Position: refs/heads/master@{#16801} --- webrtc/modules/audio_processing/BUILD.gn | 60 +++- .../aec3/adaptive_fir_filter.cc | 309 ++++++++++++++++++ .../aec3/adaptive_fir_filter.h | 115 +++++++ .../aec3/adaptive_fir_filter_unittest.cc | 219 +++++++++++++ .../audio_processing/aec3/aec3_common.cc | 27 ++ .../aec3/{aec3_constants.h => aec3_common.h} | 21 +- .../modules/audio_processing/aec3/aec3_fft.cc | 42 +++ .../modules/audio_processing/aec3/aec3_fft.h | 60 ++++ .../aec3/aec3_fft_unittest.cc | 211 ++++++++++++ .../audio_processing/aec3/aec_state.cc | 162 +++++++++ .../modules/audio_processing/aec3/aec_state.h | 127 +++++++ .../aec3/aec_state_unittest.cc | 276 ++++++++++++++++ .../audio_processing/aec3/block_framer.h | 2 +- .../aec3/block_framer_unittest.cc | 2 +- .../audio_processing/aec3/block_processor.cc | 2 +- .../aec3/block_processor_unittest.cc | 4 +- .../aec3/comfort_noise_generator.cc | 208 ++++++++++++ .../aec3/comfort_noise_generator.h | 68 ++++ .../aec3/comfort_noise_generator_unittest.cc | 119 +++++++ .../audio_processing/aec3/decimator_by_4.h | 2 +- .../aec3/decimator_by_4_unittest.cc | 2 +- .../audio_processing/aec3/echo_canceller3.cc | 25 +- .../aec3/echo_canceller3_unittest.cc | 86 +++-- .../aec3/echo_path_delay_estimator.cc | 3 +- .../echo_path_delay_estimator_unittest.cc | 4 +- .../aec3/echo_path_variability.cc | 18 + .../aec3/echo_path_variability.h | 3 +- .../aec3/echo_path_variability_unittest.cc | 39 +++ .../audio_processing/aec3/echo_remover.cc | 222 +++++++++++-- .../aec3/echo_remover_unittest.cc | 81 ++++- .../audio_processing/aec3/erl_estimator.cc | 63 ++++ .../audio_processing/aec3/erl_estimator.h | 43 +++ .../aec3/erl_estimator_unittest.cc | 70 ++++ .../audio_processing/aec3/erle_estimator.cc | 65 ++++ .../audio_processing/aec3/erle_estimator.h | 44 +++ .../aec3/erle_estimator_unittest.cc | 66 ++++ .../audio_processing/aec3/fft_buffer.cc | 72 ++++ .../audio_processing/aec3/fft_buffer.h | 70 ++++ .../aec3/fft_buffer_unittest.cc | 76 +++++ .../modules/audio_processing/aec3/fft_data.h | 98 ++++++ .../aec3/fft_data_unittest.cc | 163 +++++++++ .../audio_processing/aec3/frame_blocker.cc | 1 - .../audio_processing/aec3/frame_blocker.h | 2 +- .../aec3/frame_blocker_unittest.cc | 2 +- .../aec3/main_filter_update_gain.cc | 117 +++++++ .../aec3/main_filter_update_gain.h | 56 ++++ .../aec3/main_filter_update_gain_unittest.cc | 287 ++++++++++++++++ .../audio_processing/aec3/matched_filter.cc | 212 ++++++++---- .../audio_processing/aec3/matched_filter.h | 29 +- .../aec3/matched_filter_lag_aggregator.cc | 2 +- .../matched_filter_lag_aggregator_unittest.cc | 4 +- .../aec3/matched_filter_unittest.cc | 71 +++- .../aec3/mock/mock_render_delay_buffer.h | 2 +- .../audio_processing/aec3/output_selector.cc | 72 ++++ .../audio_processing/aec3/output_selector.h | 41 +++ .../aec3/output_selector_unittest.cc | 71 ++++ .../audio_processing/aec3/power_echo_model.cc | 111 +++++++ .../audio_processing/aec3/power_echo_model.h | 61 ++++ .../aec3/power_echo_model_unittest.cc | 133 ++++++++ .../aec3/render_delay_buffer.cc | 2 +- .../aec3/render_delay_buffer_unittest.cc | 2 +- .../aec3/render_delay_controller.cc | 4 +- .../aec3/render_delay_controller_unittest.cc | 17 +- .../aec3/render_signal_analyzer.cc | 66 ++++ .../aec3/render_signal_analyzer.h | 54 +++ .../aec3/render_signal_analyzer_unittest.cc | 123 +++++++ .../aec3/residual_echo_estimator.cc | 215 ++++++++++++ .../aec3/residual_echo_estimator.h | 52 +++ .../aec3/residual_echo_estimator_unittest.cc | 87 +++++ .../aec3/shadow_filter_update_gain.cc | 60 ++++ .../aec3/shadow_filter_update_gain.h | 39 +++ .../shadow_filter_update_gain_unittest.cc | 187 +++++++++++ .../audio_processing/aec3/subtractor.cc | 117 +++++++ .../audio_processing/aec3/subtractor.h | 79 +++++ .../audio_processing/aec3/subtractor_output.h | 44 +++ .../aec3/subtractor_unittest.cc | 175 ++++++++++ .../aec3/suppression_filter.cc | 178 ++++++++++ .../aec3/suppression_filter.h | 43 +++ .../aec3/suppression_filter_unittest.cc | 180 ++++++++++ .../audio_processing/aec3/suppression_gain.cc | 282 ++++++++++++++++ .../audio_processing/aec3/suppression_gain.h | 64 ++++ .../aec3/suppression_gain_unittest.cc | 148 +++++++++ .../audio_processing/audio_processing_impl.cc | 29 +- 83 files changed, 6668 insertions(+), 202 deletions(-) create mode 100644 webrtc/modules/audio_processing/aec3/adaptive_fir_filter.cc create mode 100644 webrtc/modules/audio_processing/aec3/adaptive_fir_filter.h create mode 100644 webrtc/modules/audio_processing/aec3/adaptive_fir_filter_unittest.cc create mode 100644 webrtc/modules/audio_processing/aec3/aec3_common.cc rename webrtc/modules/audio_processing/aec3/{aec3_constants.h => aec3_common.h} (80%) create mode 100644 webrtc/modules/audio_processing/aec3/aec3_fft.cc create mode 100644 webrtc/modules/audio_processing/aec3/aec3_fft.h create mode 100644 webrtc/modules/audio_processing/aec3/aec3_fft_unittest.cc create mode 100644 webrtc/modules/audio_processing/aec3/aec_state.cc create mode 100644 webrtc/modules/audio_processing/aec3/aec_state.h create mode 100644 webrtc/modules/audio_processing/aec3/aec_state_unittest.cc create mode 100644 webrtc/modules/audio_processing/aec3/comfort_noise_generator.cc create mode 100644 webrtc/modules/audio_processing/aec3/comfort_noise_generator.h create mode 100644 webrtc/modules/audio_processing/aec3/comfort_noise_generator_unittest.cc create mode 100644 webrtc/modules/audio_processing/aec3/echo_path_variability.cc create mode 100644 webrtc/modules/audio_processing/aec3/echo_path_variability_unittest.cc create mode 100644 webrtc/modules/audio_processing/aec3/erl_estimator.cc create mode 100644 webrtc/modules/audio_processing/aec3/erl_estimator.h create mode 100644 webrtc/modules/audio_processing/aec3/erl_estimator_unittest.cc create mode 100644 webrtc/modules/audio_processing/aec3/erle_estimator.cc create mode 100644 webrtc/modules/audio_processing/aec3/erle_estimator.h create mode 100644 webrtc/modules/audio_processing/aec3/erle_estimator_unittest.cc create mode 100644 webrtc/modules/audio_processing/aec3/fft_buffer.cc create mode 100644 webrtc/modules/audio_processing/aec3/fft_buffer.h create mode 100644 webrtc/modules/audio_processing/aec3/fft_buffer_unittest.cc create mode 100644 webrtc/modules/audio_processing/aec3/fft_data.h create mode 100644 webrtc/modules/audio_processing/aec3/fft_data_unittest.cc create mode 100644 webrtc/modules/audio_processing/aec3/main_filter_update_gain.cc create mode 100644 webrtc/modules/audio_processing/aec3/main_filter_update_gain.h create mode 100644 webrtc/modules/audio_processing/aec3/main_filter_update_gain_unittest.cc create mode 100644 webrtc/modules/audio_processing/aec3/output_selector.cc create mode 100644 webrtc/modules/audio_processing/aec3/output_selector.h create mode 100644 webrtc/modules/audio_processing/aec3/output_selector_unittest.cc create mode 100644 webrtc/modules/audio_processing/aec3/power_echo_model.cc create mode 100644 webrtc/modules/audio_processing/aec3/power_echo_model.h create mode 100644 webrtc/modules/audio_processing/aec3/power_echo_model_unittest.cc create mode 100644 webrtc/modules/audio_processing/aec3/render_signal_analyzer.cc create mode 100644 webrtc/modules/audio_processing/aec3/render_signal_analyzer.h create mode 100644 webrtc/modules/audio_processing/aec3/render_signal_analyzer_unittest.cc create mode 100644 webrtc/modules/audio_processing/aec3/residual_echo_estimator.cc create mode 100644 webrtc/modules/audio_processing/aec3/residual_echo_estimator.h create mode 100644 webrtc/modules/audio_processing/aec3/residual_echo_estimator_unittest.cc create mode 100644 webrtc/modules/audio_processing/aec3/shadow_filter_update_gain.cc create mode 100644 webrtc/modules/audio_processing/aec3/shadow_filter_update_gain.h create mode 100644 webrtc/modules/audio_processing/aec3/shadow_filter_update_gain_unittest.cc create mode 100644 webrtc/modules/audio_processing/aec3/subtractor.cc create mode 100644 webrtc/modules/audio_processing/aec3/subtractor.h create mode 100644 webrtc/modules/audio_processing/aec3/subtractor_output.h create mode 100644 webrtc/modules/audio_processing/aec3/subtractor_unittest.cc create mode 100644 webrtc/modules/audio_processing/aec3/suppression_filter.cc create mode 100644 webrtc/modules/audio_processing/aec3/suppression_filter.h create mode 100644 webrtc/modules/audio_processing/aec3/suppression_filter_unittest.cc create mode 100644 webrtc/modules/audio_processing/aec3/suppression_gain.cc create mode 100644 webrtc/modules/audio_processing/aec3/suppression_gain.h create mode 100644 webrtc/modules/audio_processing/aec3/suppression_gain_unittest.cc diff --git a/webrtc/modules/audio_processing/BUILD.gn b/webrtc/modules/audio_processing/BUILD.gn index 2684a782c3..c06707aebf 100644 --- a/webrtc/modules/audio_processing/BUILD.gn +++ b/webrtc/modules/audio_processing/BUILD.gn @@ -26,32 +26,68 @@ rtc_static_library("audio_processing") { "aec/aec_resampler.h", "aec/echo_cancellation.cc", "aec/echo_cancellation.h", - "aec3/aec3_constants.h", + "aec3/adaptive_fir_filter.cc", + "aec3/adaptive_fir_filter.h", + "aec3/aec3_common.cc", + "aec3/aec3_common.h", + "aec3/aec3_fft.cc", + "aec3/aec3_fft.h", + "aec3/aec_state.cc", + "aec3/aec_state.h", "aec3/block_framer.cc", "aec3/block_framer.h", "aec3/block_processor.cc", "aec3/block_processor.h", "aec3/cascaded_biquad_filter.cc", "aec3/cascaded_biquad_filter.h", + "aec3/comfort_noise_generator.cc", + "aec3/comfort_noise_generator.h", "aec3/decimator_by_4.cc", "aec3/decimator_by_4.h", "aec3/echo_canceller3.cc", "aec3/echo_canceller3.h", "aec3/echo_path_delay_estimator.cc", "aec3/echo_path_delay_estimator.h", + "aec3/echo_path_variability.cc", "aec3/echo_path_variability.h", "aec3/echo_remover.cc", "aec3/echo_remover.h", + "aec3/erl_estimator.cc", + "aec3/erl_estimator.h", + "aec3/erle_estimator.cc", + "aec3/erle_estimator.h", + "aec3/fft_buffer.cc", + "aec3/fft_buffer.h", + "aec3/fft_data.h", "aec3/frame_blocker.cc", "aec3/frame_blocker.h", + "aec3/main_filter_update_gain.cc", + "aec3/main_filter_update_gain.h", "aec3/matched_filter.cc", "aec3/matched_filter.h", "aec3/matched_filter_lag_aggregator.cc", "aec3/matched_filter_lag_aggregator.h", + "aec3/output_selector.cc", + "aec3/output_selector.h", + "aec3/power_echo_model.cc", + "aec3/power_echo_model.h", "aec3/render_delay_buffer.cc", "aec3/render_delay_buffer.h", "aec3/render_delay_controller.cc", "aec3/render_delay_controller.h", + "aec3/render_signal_analyzer.cc", + "aec3/render_signal_analyzer.h", + "aec3/residual_echo_estimator.cc", + "aec3/residual_echo_estimator.h", + "aec3/shadow_filter_update_gain.cc", + "aec3/shadow_filter_update_gain.h", + "aec3/subtractor.cc", + "aec3/subtractor.h", + "aec3/subtractor_output.h", + "aec3/suppression_filter.cc", + "aec3/suppression_filter.h", + "aec3/suppression_gain.cc", + "aec3/suppression_gain.h", "aecm/aecm_core.cc", "aecm/aecm_core.h", "aecm/echo_control_mobile.cc", @@ -522,22 +558,36 @@ if (rtc_include_tests) { ":audioproc_unittest_proto", ] sources += [ + "aec3/adaptive_fir_filter_unittest.cc", + "aec3/aec3_fft_unittest.cc", + "aec3/aec_state_unittest.cc", "aec3/block_framer_unittest.cc", "aec3/block_processor_unittest.cc", "aec3/cascaded_biquad_filter_unittest.cc", + "aec3/comfort_noise_generator_unittest.cc", "aec3/decimator_by_4_unittest.cc", "aec3/echo_canceller3_unittest.cc", "aec3/echo_path_delay_estimator_unittest.cc", + "aec3/echo_path_variability_unittest.cc", "aec3/echo_remover_unittest.cc", + "aec3/erl_estimator_unittest.cc", + "aec3/erle_estimator_unittest.cc", + "aec3/fft_buffer_unittest.cc", + "aec3/fft_data_unittest.cc", "aec3/frame_blocker_unittest.cc", + "aec3/main_filter_update_gain_unittest.cc", "aec3/matched_filter_lag_aggregator_unittest.cc", "aec3/matched_filter_unittest.cc", - "aec3/mock/mock_block_processor.h", - "aec3/mock/mock_echo_remover.h", - "aec3/mock/mock_render_delay_buffer.h", - "aec3/mock/mock_render_delay_controller.h", + "aec3/output_selector_unittest.cc", + "aec3/power_echo_model_unittest.cc", "aec3/render_delay_buffer_unittest.cc", "aec3/render_delay_controller_unittest.cc", + "aec3/render_signal_analyzer_unittest.cc", + "aec3/residual_echo_estimator_unittest.cc", + "aec3/shadow_filter_update_gain_unittest.cc", + "aec3/subtractor_unittest.cc", + "aec3/suppression_filter_unittest.cc", + "aec3/suppression_gain_unittest.cc", "audio_processing_impl_locking_unittest.cc", "audio_processing_impl_unittest.cc", "audio_processing_unittest.cc", diff --git a/webrtc/modules/audio_processing/aec3/adaptive_fir_filter.cc b/webrtc/modules/audio_processing/aec3/adaptive_fir_filter.cc new file mode 100644 index 0000000000..300baf8eb2 --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/adaptive_fir_filter.cc @@ -0,0 +1,309 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/modules/audio_processing/aec3/adaptive_fir_filter.h" + +#include "webrtc/typedefs.h" +#if defined(WEBRTC_ARCH_X86_FAMILY) +#include +#endif +#include +#include + +#include "webrtc/base/checks.h" +#include "webrtc/modules/audio_processing/aec3/fft_data.h" + +namespace webrtc { + +namespace { + +// Constrains the a partiton of the frequency domain filter to be limited in +// time via setting the relevant time-domain coefficients to zero. +void Constrain(const Aec3Fft& fft, FftData* H) { + std::array h; + fft.Ifft(*H, &h); + constexpr float kScale = 1.0f / kFftLengthBy2; + std::for_each(h.begin(), h.begin() + kFftLengthBy2, + [kScale](float& a) { a *= kScale; }); + std::fill(h.begin() + kFftLengthBy2, h.end(), 0.f); + fft.Fft(&h, H); +} + +// Computes and stores the frequency response of the filter. +void UpdateFrequencyResponse( + rtc::ArrayView H, + std::vector>* H2) { + for (size_t k = 0; k < H.size(); ++k) { + std::transform(H[k].re.begin(), H[k].re.end(), H[k].im.begin(), + (*H2)[k].begin(), + [](float a, float b) { return a * a + b * b; }); + } +} + +// Computes and stores the echo return loss estimate of the filter, which is the +// sum of the partition frequency responses. +void UpdateErlEstimator( + const std::vector>& H2, + std::array* erl) { + erl->fill(0.f); + for (auto& H2_j : H2) { + std::transform(H2_j.begin(), H2_j.end(), erl->begin(), erl->begin(), + std::plus()); + } +} + +// Resets the filter. +void ResetFilter(rtc::ArrayView H) { + for (auto& H_j : H) { + H_j.Clear(); + } +} + +} // namespace + +namespace aec3 { + +// Adapts the filter partitions as H(t+1)=H(t)+G(t)*conj(X(t)). +void AdaptPartitions(const FftBuffer& X_buffer, + const FftData& G, + rtc::ArrayView H) { + rtc::ArrayView X_buffer_data = X_buffer.Buffer(); + size_t index = X_buffer.Position(); + for (auto& H_j : H) { + const FftData& X = X_buffer_data[index]; + for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) { + H_j.re[k] += X.re[k] * G.re[k] + X.im[k] * G.im[k]; + H_j.im[k] += X.re[k] * G.im[k] - X.im[k] * G.re[k]; + } + + index = index < (X_buffer_data.size() - 1) ? index + 1 : 0; + } +} + +#if defined(WEBRTC_ARCH_X86_FAMILY) +// Adapts the filter partitions. (SSE2 variant) +void AdaptPartitions_SSE2(const FftBuffer& X_buffer, + const FftData& G, + rtc::ArrayView H) { + rtc::ArrayView X_buffer_data = X_buffer.Buffer(); + const int lim1 = + std::min(X_buffer_data.size() - X_buffer.Position(), H.size()); + const int lim2 = H.size(); + constexpr int kNumFourBinBands = kFftLengthBy2 / 4; + FftData* H_j; + const FftData* X; + int limit; + int j; + for (int k = 0, n = 0; n < kNumFourBinBands; ++n, k += 4) { + const __m128 G_re = _mm_loadu_ps(&G.re[k]); + const __m128 G_im = _mm_loadu_ps(&G.im[k]); + + H_j = &H[0]; + X = &X_buffer_data[X_buffer.Position()]; + limit = lim1; + j = 0; + do { + for (; j < limit; ++j, ++H_j, ++X) { + const __m128 X_re = _mm_loadu_ps(&X->re[k]); + const __m128 X_im = _mm_loadu_ps(&X->im[k]); + const __m128 H_re = _mm_loadu_ps(&H_j->re[k]); + const __m128 H_im = _mm_loadu_ps(&H_j->im[k]); + const __m128 a = _mm_mul_ps(X_re, G_re); + const __m128 b = _mm_mul_ps(X_im, G_im); + const __m128 c = _mm_mul_ps(X_re, G_im); + const __m128 d = _mm_mul_ps(X_im, G_re); + const __m128 e = _mm_add_ps(a, b); + const __m128 f = _mm_sub_ps(c, d); + const __m128 g = _mm_add_ps(H_re, e); + const __m128 h = _mm_add_ps(H_im, f); + _mm_storeu_ps(&H_j->re[k], g); + _mm_storeu_ps(&H_j->im[k], h); + } + + X = &X_buffer_data[0]; + limit = lim2; + } while (j < lim2); + } + + H_j = &H[0]; + X = &X_buffer_data[X_buffer.Position()]; + limit = lim1; + j = 0; + do { + for (; j < limit; ++j, ++H_j, ++X) { + H_j->re[kFftLengthBy2] += X->re[kFftLengthBy2] * G.re[kFftLengthBy2] + + X->im[kFftLengthBy2] * G.im[kFftLengthBy2]; + H_j->im[kFftLengthBy2] += X->re[kFftLengthBy2] * G.im[kFftLengthBy2] - + X->im[kFftLengthBy2] * G.re[kFftLengthBy2]; + } + + X = &X_buffer_data[0]; + limit = lim2; + } while (j < lim2); +} +#endif + +// Produces the filter output. +void ApplyFilter(const FftBuffer& X_buffer, + rtc::ArrayView H, + FftData* S) { + S->re.fill(0.f); + S->im.fill(0.f); + + rtc::ArrayView X_buffer_data = X_buffer.Buffer(); + size_t index = X_buffer.Position(); + for (auto& H_j : H) { + const FftData& X = X_buffer_data[index]; + for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) { + S->re[k] += X.re[k] * H_j.re[k] - X.im[k] * H_j.im[k]; + S->im[k] += X.re[k] * H_j.im[k] + X.im[k] * H_j.re[k]; + } + index = index < (X_buffer_data.size() - 1) ? index + 1 : 0; + } +} + +#if defined(WEBRTC_ARCH_X86_FAMILY) +// Produces the filter output (SSE2 variant). +void ApplyFilter_SSE2(const FftBuffer& X_buffer, + rtc::ArrayView H, + FftData* S) { + S->re.fill(0.f); + S->im.fill(0.f); + + rtc::ArrayView X_buffer_data = X_buffer.Buffer(); + const int lim1 = + std::min(X_buffer_data.size() - X_buffer.Position(), H.size()); + const int lim2 = H.size(); + constexpr int kNumFourBinBands = kFftLengthBy2 / 4; + const FftData* H_j = &H[0]; + const FftData* X = &X_buffer_data[X_buffer.Position()]; + + int j = 0; + int limit = lim1; + do { + for (; j < limit; ++j, ++H_j, ++X) { + for (int k = 0, n = 0; n < kNumFourBinBands; ++n, k += 4) { + const __m128 X_re = _mm_loadu_ps(&X->re[k]); + const __m128 X_im = _mm_loadu_ps(&X->im[k]); + const __m128 H_re = _mm_loadu_ps(&H_j->re[k]); + const __m128 H_im = _mm_loadu_ps(&H_j->im[k]); + const __m128 S_re = _mm_loadu_ps(&S->re[k]); + const __m128 S_im = _mm_loadu_ps(&S->im[k]); + const __m128 a = _mm_mul_ps(X_re, H_re); + const __m128 b = _mm_mul_ps(X_im, H_im); + const __m128 c = _mm_mul_ps(X_re, H_im); + const __m128 d = _mm_mul_ps(X_im, H_re); + const __m128 e = _mm_sub_ps(a, b); + const __m128 f = _mm_add_ps(c, d); + const __m128 g = _mm_add_ps(S_re, e); + const __m128 h = _mm_add_ps(S_im, f); + _mm_storeu_ps(&S->re[k], g); + _mm_storeu_ps(&S->im[k], h); + } + } + limit = lim2; + X = &X_buffer_data[0]; + } while (j < lim2); + + H_j = &H[0]; + X = &X_buffer_data[X_buffer.Position()]; + j = 0; + limit = lim1; + do { + for (; j < limit; ++j, ++H_j, ++X) { + S->re[kFftLengthBy2] += X->re[kFftLengthBy2] * H_j->re[kFftLengthBy2] - + X->im[kFftLengthBy2] * H_j->im[kFftLengthBy2]; + S->im[kFftLengthBy2] += X->re[kFftLengthBy2] * H_j->im[kFftLengthBy2] + + X->im[kFftLengthBy2] * H_j->re[kFftLengthBy2]; + } + limit = lim2; + X = &X_buffer_data[0]; + } while (j < lim2); +} +#endif + +} // namespace aec3 + +AdaptiveFirFilter::AdaptiveFirFilter(size_t size_partitions, + bool use_filter_statistics, + Aec3Optimization optimization, + ApmDataDumper* data_dumper) + : data_dumper_(data_dumper), + optimization_(optimization), + H_(size_partitions) { + RTC_DCHECK(data_dumper_); + ResetFilter(H_); + + if (use_filter_statistics) { + H2_.reset(new std::vector>( + size_partitions, std::array())); + for (auto H2_k : *H2_) { + H2_k.fill(0.f); + } + + erl_.reset(new std::array()); + erl_->fill(0.f); + } +} + +AdaptiveFirFilter::~AdaptiveFirFilter() = default; + +void AdaptiveFirFilter::HandleEchoPathChange() { + ResetFilter(H_); + if (H2_) { + for (auto H2_k : *H2_) { + H2_k.fill(0.f); + } + RTC_DCHECK(erl_); + erl_->fill(0.f); + } +} + +void AdaptiveFirFilter::Filter(const FftBuffer& X_buffer, FftData* S) const { + RTC_DCHECK(S); + switch (optimization_) { +#if defined(WEBRTC_ARCH_X86_FAMILY) + case Aec3Optimization::kSse2: + aec3::ApplyFilter_SSE2(X_buffer, H_, S); + break; +#endif + default: + aec3::ApplyFilter(X_buffer, H_, S); + } +} + +void AdaptiveFirFilter::Adapt(const FftBuffer& X_buffer, const FftData& G) { + // Adapt the filter. + switch (optimization_) { +#if defined(WEBRTC_ARCH_X86_FAMILY) + case Aec3Optimization::kSse2: + aec3::AdaptPartitions_SSE2(X_buffer, G, H_); + break; +#endif + default: + aec3::AdaptPartitions(X_buffer, G, H_); + } + + // Constrain the filter partitions in a cyclic manner. + Constrain(fft_, &H_[partition_to_constrain_]); + partition_to_constrain_ = partition_to_constrain_ < (H_.size() - 1) + ? partition_to_constrain_ + 1 + : 0; + + // Optionally update the frequency response and echo return loss for the + // filter. + if (H2_) { + RTC_DCHECK(erl_); + UpdateFrequencyResponse(H_, H2_.get()); + UpdateErlEstimator(*H2_, erl_.get()); + } +} + +} // namespace webrtc diff --git a/webrtc/modules/audio_processing/aec3/adaptive_fir_filter.h b/webrtc/modules/audio_processing/aec3/adaptive_fir_filter.h new file mode 100644 index 0000000000..d927f148e1 --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/adaptive_fir_filter.h @@ -0,0 +1,115 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_ADAPTIVE_FIR_FILTER_H_ +#define WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_ADAPTIVE_FIR_FILTER_H_ + +#include +#include +#include + +#include "webrtc/base/array_view.h" +#include "webrtc/base/constructormagic.h" +#include "webrtc/modules/audio_processing/aec3/aec3_common.h" +#include "webrtc/modules/audio_processing/aec3/aec3_fft.h" +#include "webrtc/modules/audio_processing/aec3/fft_buffer.h" +#include "webrtc/modules/audio_processing/aec3/fft_data.h" +#include "webrtc/modules/audio_processing/logging/apm_data_dumper.h" + +namespace webrtc { +namespace aec3 { +// Adapts the filter partitions. +void AdaptPartitions(const FftBuffer& X_buffer, + const FftData& G, + rtc::ArrayView H); +#if defined(WEBRTC_ARCH_X86_FAMILY) +void AdaptPartitions_SSE2(const FftBuffer& X_buffer, + const FftData& G, + rtc::ArrayView H); +#endif + +// Produces the filter output. +void ApplyFilter(const FftBuffer& X_buffer, + rtc::ArrayView H, + FftData* S); +#if defined(WEBRTC_ARCH_X86_FAMILY) +void ApplyFilter_SSE2(const FftBuffer& X_buffer, + rtc::ArrayView H, + FftData* S); +#endif + +} // namespace aec3 + +// Provides a frequency domain adaptive filter functionality. +class AdaptiveFirFilter { + public: + AdaptiveFirFilter(size_t size_partitions, + bool use_filter_statistics, + Aec3Optimization optimization, + ApmDataDumper* data_dumper); + + ~AdaptiveFirFilter(); + + // Produces the output of the filter. + void Filter(const FftBuffer& X_buffer, FftData* S) const; + + // Adapts the filter. + void Adapt(const FftBuffer& X_buffer, const FftData& G); + + // Receives reports that known echo path changes have occured and adjusts + // the filter adaptation accordingly. + void HandleEchoPathChange(); + + // Returns the filter size. + size_t SizePartitions() const { return H_.size(); } + + // Returns the filter based echo return loss. This method can only be used if + // the usage of filter statistics has been specified during the creation of + // the adaptive filter. + const std::array& Erl() const { + RTC_DCHECK(erl_) << "The filter must be created with use_filter_statistics " + "set to true in order to be able to call retrieve the " + "ERL."; + return *erl_; + } + + // Returns the frequency responses for the filter partitions. This method can + // only be used if the usage of filter statistics has been specified during + // the creation of the adaptive filter. + const std::vector>& + FilterFrequencyResponse() const { + RTC_DCHECK(H2_) << "The filter must be created with use_filter_statistics " + "set to true in order to be able to call retrieve the " + "filter frequency responde."; + return *H2_; + } + + void DumpFilter(const char* name) { + for (auto& H : H_) { + data_dumper_->DumpRaw(name, H.re); + data_dumper_->DumpRaw(name, H.im); + } + } + + private: + ApmDataDumper* const data_dumper_; + const Aec3Fft fft_; + const Aec3Optimization optimization_; + std::vector H_; + std::unique_ptr>> H2_; + std::unique_ptr> erl_; + size_t partition_to_constrain_ = 0; + + RTC_DISALLOW_IMPLICIT_CONSTRUCTORS(AdaptiveFirFilter); +}; + +} // namespace webrtc + +#endif // WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_ADAPTIVE_FIR_FILTER_H_ diff --git a/webrtc/modules/audio_processing/aec3/adaptive_fir_filter_unittest.cc b/webrtc/modules/audio_processing/aec3/adaptive_fir_filter_unittest.cc new file mode 100644 index 0000000000..d46eba571b --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/adaptive_fir_filter_unittest.cc @@ -0,0 +1,219 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/modules/audio_processing/aec3/adaptive_fir_filter.h" + +#include +#include +#include +#include "webrtc/typedefs.h" +#if defined(WEBRTC_ARCH_X86_FAMILY) +#include +#endif +#include "webrtc/base/arraysize.h" +#include "webrtc/base/random.h" +#include "webrtc/modules/audio_processing/aec3/aec_state.h" +#include "webrtc/modules/audio_processing/aec3/aec3_fft.h" +#include "webrtc/modules/audio_processing/aec3/render_signal_analyzer.h" +#include "webrtc/modules/audio_processing/aec3/shadow_filter_update_gain.h" +#include "webrtc/modules/audio_processing/logging/apm_data_dumper.h" +#include "webrtc/modules/audio_processing/test/echo_canceller_test_tools.h" +#include "webrtc/system_wrappers/include/cpu_features_wrapper.h" +#include "webrtc/test/gtest.h" + +namespace webrtc { +namespace aec3 { +namespace { + +std::string ProduceDebugText(size_t delay) { + std::ostringstream ss; + ss << ", Delay: " << delay; + return ss.str(); +} + +} // namespace + +#if defined(WEBRTC_ARCH_X86_FAMILY) +// Verifies that the optimized methods are bitexact to their reference +// counterparts. +TEST(AdaptiveFirFilter, TestOptimizations) { + bool use_sse2 = (WebRtc_GetCPUInfo(kSSE2) != 0); + if (use_sse2) { + FftBuffer X_buffer(Aec3Optimization::kNone, 12, std::vector(1, 12)); + std::array x_old; + x_old.fill(0.f); + Random random_generator(42U); + std::vector x(kBlockSize, 0.f); + FftData X; + FftData S_C; + FftData S_SSE2; + FftData G; + Aec3Fft fft; + std::vector H_C(10); + std::vector H_SSE2(10); + for (auto& H_j : H_C) { + H_j.Clear(); + } + for (auto& H_j : H_SSE2) { + H_j.Clear(); + } + + for (size_t k = 0; k < 500; ++k) { + RandomizeSampleVector(&random_generator, x); + fft.PaddedFft(x, x_old, &X); + X_buffer.Insert(X); + + ApplyFilter_SSE2(X_buffer, H_SSE2, &S_SSE2); + ApplyFilter(X_buffer, H_C, &S_C); + for (size_t j = 0; j < S_C.re.size(); ++j) { + EXPECT_FLOAT_EQ(S_C.re[j], S_SSE2.re[j]); + EXPECT_FLOAT_EQ(S_C.im[j], S_SSE2.im[j]); + } + + std::for_each(G.re.begin(), G.re.end(), + [&](float& a) { a = random_generator.Rand(); }); + std::for_each(G.im.begin(), G.im.end(), + [&](float& a) { a = random_generator.Rand(); }); + + AdaptPartitions_SSE2(X_buffer, G, H_SSE2); + AdaptPartitions(X_buffer, G, H_C); + + for (size_t k = 0; k < H_C.size(); ++k) { + for (size_t j = 0; j < H_C[k].re.size(); ++j) { + EXPECT_FLOAT_EQ(H_C[k].re[j], H_SSE2[k].re[j]); + EXPECT_FLOAT_EQ(H_C[k].im[j], H_SSE2[k].im[j]); + } + } + } + } +} + +#endif + +#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) +// Verifies that the check for non-null data dumper works. +TEST(AdaptiveFirFilter, NullDataDumper) { + EXPECT_DEATH(AdaptiveFirFilter(9, true, DetectOptimization(), nullptr), ""); +} + +// Verifies that the check for non-null filter output works. +TEST(AdaptiveFirFilter, NullFilterOutput) { + ApmDataDumper data_dumper(42); + AdaptiveFirFilter filter(9, true, DetectOptimization(), &data_dumper); + FftBuffer X_buffer(Aec3Optimization::kNone, filter.SizePartitions(), + std::vector(1, filter.SizePartitions())); + EXPECT_DEATH(filter.Filter(X_buffer, nullptr), ""); +} + +// Verifies that the check for whether filter statistics are being generated +// works when retrieving the ERL. +TEST(AdaptiveFirFilter, ErlAccessWhenNoFilterStatistics) { + ApmDataDumper data_dumper(42); + AdaptiveFirFilter filter(9, false, DetectOptimization(), &data_dumper); + EXPECT_DEATH(filter.Erl(), ""); +} + +// Verifies that the check for whether filter statistics are being generated +// works when retrieving the filter frequencyResponse. +TEST(AdaptiveFirFilter, FilterFrequencyResponseAccessWhenNoFilterStatistics) { + ApmDataDumper data_dumper(42); + AdaptiveFirFilter filter(9, false, DetectOptimization(), &data_dumper); + EXPECT_DEATH(filter.FilterFrequencyResponse(), ""); +} + +#endif + +// Verifies that the filter statistics can be accessed when filter statistics +// are turned on. +TEST(AdaptiveFirFilter, FilterStatisticsAccess) { + ApmDataDumper data_dumper(42); + AdaptiveFirFilter filter(9, true, DetectOptimization(), &data_dumper); + filter.Erl(); + filter.FilterFrequencyResponse(); +} + +// Verifies that the filter size if correctly repported. +TEST(AdaptiveFirFilter, FilterSize) { + ApmDataDumper data_dumper(42); + for (size_t filter_size = 1; filter_size < 5; ++filter_size) { + AdaptiveFirFilter filter(filter_size, false, DetectOptimization(), + &data_dumper); + EXPECT_EQ(filter_size, filter.SizePartitions()); + } +} + +// Verifies that the filter is being able to properly filter a signal and to +// adapt its coefficients. +TEST(AdaptiveFirFilter, FilterAndAdapt) { + constexpr size_t kNumBlocksToProcess = 500; + ApmDataDumper data_dumper(42); + AdaptiveFirFilter filter(9, true, DetectOptimization(), &data_dumper); + Aec3Fft fft; + FftBuffer X_buffer(Aec3Optimization::kNone, filter.SizePartitions(), + std::vector(1, filter.SizePartitions())); + std::array x_old; + x_old.fill(0.f); + ShadowFilterUpdateGain gain; + Random random_generator(42U); + std::vector x(kBlockSize, 0.f); + std::vector y(kBlockSize, 0.f); + AecState aec_state; + RenderSignalAnalyzer render_signal_analyzer; + FftData X; + std::vector e(kBlockSize, 0.f); + std::array s; + FftData S; + FftData G; + FftData E; + std::array Y2; + std::array E2_main; + std::array E2_shadow; + Y2.fill(0.f); + E2_main.fill(0.f); + E2_shadow.fill(0.f); + + constexpr float kScale = 1.0f / kFftLengthBy2; + + for (size_t delay_samples : {0, 64, 150, 200, 301}) { + DelayBuffer delay_buffer(delay_samples); + SCOPED_TRACE(ProduceDebugText(delay_samples)); + for (size_t k = 0; k < kNumBlocksToProcess; ++k) { + RandomizeSampleVector(&random_generator, x); + delay_buffer.Delay(x, y); + + fft.PaddedFft(x, x_old, &X); + X_buffer.Insert(X); + render_signal_analyzer.Update(X_buffer, aec_state.FilterDelay()); + + filter.Filter(X_buffer, &S); + fft.Ifft(S, &s); + std::transform(y.begin(), y.end(), s.begin() + kFftLengthBy2, e.begin(), + [&](float a, float b) { return a - b * kScale; }); + std::for_each(e.begin(), e.end(), [](float& a) { + a = std::max(std::min(a, 32767.0f), -32768.0f); + }); + fft.ZeroPaddedFft(e, &E); + + gain.Compute(X_buffer, render_signal_analyzer, E, filter.SizePartitions(), + false, &G); + filter.Adapt(X_buffer, G); + aec_state.Update(filter.FilterFrequencyResponse(), + rtc::Optional(), X_buffer, E2_main, E2_shadow, + Y2, x, EchoPathVariability(false, false), false); + } + // Verify that the filter is able to perform well. + EXPECT_LT(1000 * std::inner_product(e.begin(), e.end(), e.begin(), 0.f), + std::inner_product(y.begin(), y.end(), y.begin(), 0.f)); + ASSERT_TRUE(aec_state.FilterDelay()); + EXPECT_EQ(delay_samples / kBlockSize, *aec_state.FilterDelay()); + } +} +} // namespace aec3 +} // namespace webrtc diff --git a/webrtc/modules/audio_processing/aec3/aec3_common.cc b/webrtc/modules/audio_processing/aec3/aec3_common.cc new file mode 100644 index 0000000000..da0f2c4f19 --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/aec3_common.cc @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/modules/audio_processing/aec3/aec3_common.h" + +#include "webrtc/typedefs.h" +#include "webrtc/system_wrappers/include/cpu_features_wrapper.h" + +namespace webrtc { + +Aec3Optimization DetectOptimization() { +#if defined(WEBRTC_ARCH_X86_FAMILY) + if (WebRtc_GetCPUInfo(kSSE2) != 0) { + return Aec3Optimization::kSse2; + } +#endif + return Aec3Optimization::kNone; +} + +} // namespace webrtc diff --git a/webrtc/modules/audio_processing/aec3/aec3_constants.h b/webrtc/modules/audio_processing/aec3/aec3_common.h similarity index 80% rename from webrtc/modules/audio_processing/aec3/aec3_constants.h rename to webrtc/modules/audio_processing/aec3/aec3_common.h index 054b0d8afd..3a5e835e21 100644 --- a/webrtc/modules/audio_processing/aec3/aec3_constants.h +++ b/webrtc/modules/audio_processing/aec3/aec3_common.h @@ -8,15 +8,27 @@ * be found in the AUTHORS file in the root of the source tree. */ -#ifndef WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_AEC3_CONSTANTS_H_ -#define WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_AEC3_CONSTANTS_H_ +#ifndef WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_AEC3_COMMON_H_ +#define WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_AEC3_COMMON_H_ #include +#include "webrtc/typedefs.h" namespace webrtc { +#ifdef _MSC_VER /* visual c++ */ +#define ALIGN16_BEG __declspec(align(16)) +#define ALIGN16_END +#else /* gcc or icc */ +#define ALIGN16_BEG +#define ALIGN16_END __attribute__((aligned(16))) +#endif + +enum class Aec3Optimization { kNone, kSse2 }; + constexpr size_t kFftLengthBy2 = 64; constexpr size_t kFftLengthBy2Plus1 = kFftLengthBy2 + 1; +constexpr size_t kFftLengthBy2Minus1 = kFftLengthBy2 - 1; constexpr size_t kFftLength = 2 * kFftLengthBy2; constexpr size_t kMaxNumBands = 3; @@ -39,6 +51,9 @@ constexpr bool ValidFullBandRate(int sample_rate_hz) { sample_rate_hz == 32000 || sample_rate_hz == 48000; } +// Detects what kind of optimizations to use for the code. +Aec3Optimization DetectOptimization(); + static_assert(1 == NumBandsForRate(8000), "Number of bands for 8 kHz"); static_assert(1 == NumBandsForRate(16000), "Number of bands for 16 kHz"); static_assert(2 == NumBandsForRate(32000), "Number of bands for 32 kHz"); @@ -65,4 +80,4 @@ static_assert(!ValidFullBandRate(8001), } // namespace webrtc -#endif // WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_AEC3_CONSTANTS_H_ +#endif // WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_AEC3_COMMON_H_ diff --git a/webrtc/modules/audio_processing/aec3/aec3_fft.cc b/webrtc/modules/audio_processing/aec3/aec3_fft.cc new file mode 100644 index 0000000000..3f9ff44e9c --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/aec3_fft.cc @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/modules/audio_processing/aec3/aec3_fft.h" + +#include + +#include "webrtc/base/checks.h" + +namespace webrtc { + +// TODO(peah): Change x to be std::array once the rest of the code allows this. +void Aec3Fft::ZeroPaddedFft(rtc::ArrayView x, FftData* X) const { + RTC_DCHECK(X); + RTC_DCHECK_EQ(kFftLengthBy2, x.size()); + std::array fft; + std::fill(fft.begin(), fft.begin() + kFftLengthBy2, 0.f); + std::copy(x.begin(), x.end(), fft.begin() + kFftLengthBy2); + Fft(&fft, X); +} + +void Aec3Fft::PaddedFft(rtc::ArrayView x, + rtc::ArrayView x_old, + FftData* X) const { + RTC_DCHECK(X); + RTC_DCHECK_EQ(kFftLengthBy2, x.size()); + RTC_DCHECK_EQ(kFftLengthBy2, x_old.size()); + std::array fft; + std::copy(x_old.begin(), x_old.end(), fft.begin()); + std::copy(x.begin(), x.end(), fft.begin() + x_old.size()); + std::copy(x.begin(), x.end(), x_old.begin()); + Fft(&fft, X); +} + +} // namespace webrtc diff --git a/webrtc/modules/audio_processing/aec3/aec3_fft.h b/webrtc/modules/audio_processing/aec3/aec3_fft.h new file mode 100644 index 0000000000..6cfe3bd9e7 --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/aec3_fft.h @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_AEC3_FFT_H_ +#define WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_AEC3_FFT_H_ + +#include + +#include "webrtc/base/array_view.h" +#include "webrtc/base/constructormagic.h" +#include "webrtc/modules/audio_processing/aec3/aec3_common.h" +#include "webrtc/modules/audio_processing/aec3/fft_data.h" +#include "webrtc/modules/audio_processing/utility/ooura_fft.h" + +namespace webrtc { + +// Wrapper class that provides 128 point real valued FFT functionality with the +// FftData type. +class Aec3Fft { + public: + Aec3Fft() = default; + // Computes the FFT. Note that both the input and output are modified. + void Fft(std::array* x, FftData* X) const { + RTC_DCHECK(x); + RTC_DCHECK(X); + ooura_fft_.Fft(x->data()); + X->CopyFromPackedArray(*x); + } + // Computes the inverse Fft. + void Ifft(const FftData& X, std::array* x) const { + RTC_DCHECK(x); + X.CopyToPackedArray(x); + ooura_fft_.InverseFft(x->data()); + } + + // Pads the input with kFftLengthBy2 initial zeros before computing the Fft. + void ZeroPaddedFft(rtc::ArrayView x, FftData* X) const; + + // Concatenates the kFftLengthBy2 values long x and x_old before computing the + // Fft. After that, x is copied to x_old. + void PaddedFft(rtc::ArrayView x, + rtc::ArrayView x_old, + FftData* X) const; + + private: + const OouraFft ooura_fft_; + + RTC_DISALLOW_COPY_AND_ASSIGN(Aec3Fft); +}; + +} // namespace webrtc + +#endif // WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_AEC3_FFT_H_ diff --git a/webrtc/modules/audio_processing/aec3/aec3_fft_unittest.cc b/webrtc/modules/audio_processing/aec3/aec3_fft_unittest.cc new file mode 100644 index 0000000000..ae1f52e5af --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/aec3_fft_unittest.cc @@ -0,0 +1,211 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/modules/audio_processing/aec3/aec3_fft.h" + +#include + +#include "webrtc/test/gmock.h" +#include "webrtc/test/gtest.h" + +namespace webrtc { + +#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) + +// Verifies that the check for non-null input in Fft works. +TEST(Aec3Fft, NullFftInput) { + Aec3Fft fft; + FftData X; + EXPECT_DEATH(fft.Fft(nullptr, &X), ""); +} + +// Verifies that the check for non-null input in Fft works. +TEST(Aec3Fft, NullFftOutput) { + Aec3Fft fft; + std::array x; + EXPECT_DEATH(fft.Fft(&x, nullptr), ""); +} + +// Verifies that the check for non-null output in Ifft works. +TEST(Aec3Fft, NullIfftOutput) { + Aec3Fft fft; + FftData X; + EXPECT_DEATH(fft.Ifft(X, nullptr), ""); +} + +// Verifies that the check for non-null output in ZeroPaddedFft works. +TEST(Aec3Fft, NullZeroPaddedFftOutput) { + Aec3Fft fft; + std::array x; + EXPECT_DEATH(fft.ZeroPaddedFft(x, nullptr), ""); +} + +// Verifies that the check for input length in ZeroPaddedFft works. +TEST(Aec3Fft, ZeroPaddedFftWrongInputLength) { + Aec3Fft fft; + FftData X; + std::array x; + EXPECT_DEATH(fft.ZeroPaddedFft(x, &X), ""); +} + +// Verifies that the check for non-null output in PaddedFft works. +TEST(Aec3Fft, NullPaddedFftOutput) { + Aec3Fft fft; + std::array x; + std::array x_old; + EXPECT_DEATH(fft.PaddedFft(x, x_old, nullptr), ""); +} + +// Verifies that the check for input length in PaddedFft works. +TEST(Aec3Fft, PaddedFftWrongInputLength) { + Aec3Fft fft; + FftData X; + std::array x; + std::array x_old; + EXPECT_DEATH(fft.PaddedFft(x, x_old, &X), ""); +} + +// Verifies that the check for length in the old value in PaddedFft works. +TEST(Aec3Fft, PaddedFftWrongOldValuesLength) { + Aec3Fft fft; + FftData X; + std::array x; + std::array x_old; + EXPECT_DEATH(fft.PaddedFft(x, x_old, &X), ""); +} + +#endif + +// Verifies that Fft works as intended. +TEST(Aec3Fft, Fft) { + Aec3Fft fft; + FftData X; + std::array x; + x.fill(0.f); + fft.Fft(&x, &X); + EXPECT_THAT(X.re, ::testing::Each(0.f)); + EXPECT_THAT(X.im, ::testing::Each(0.f)); + + x.fill(0.f); + x[0] = 1.f; + fft.Fft(&x, &X); + EXPECT_THAT(X.re, ::testing::Each(1.f)); + EXPECT_THAT(X.im, ::testing::Each(0.f)); + + x.fill(1.f); + fft.Fft(&x, &X); + EXPECT_EQ(128.f, X.re[0]); + std::for_each(X.re.begin() + 1, X.re.end(), + [](float a) { EXPECT_EQ(0.f, a); }); + EXPECT_THAT(X.im, ::testing::Each(0.f)); +} + +// Verifies that InverseFft works as intended. +TEST(Aec3Fft, Ifft) { + Aec3Fft fft; + FftData X; + std::array x; + + X.re.fill(0.f); + X.im.fill(0.f); + fft.Ifft(X, &x); + EXPECT_THAT(x, ::testing::Each(0.f)); + + X.re.fill(1.f); + X.im.fill(0.f); + fft.Ifft(X, &x); + EXPECT_EQ(64.f, x[0]); + std::for_each(x.begin() + 1, x.end(), [](float a) { EXPECT_EQ(0.f, a); }); + + X.re.fill(0.f); + X.re[0] = 128; + X.im.fill(0.f); + fft.Ifft(X, &x); + EXPECT_THAT(x, ::testing::Each(64.f)); +} + +// Verifies that InverseFft and Fft work as intended. +TEST(Aec3Fft, FftAndIfft) { + Aec3Fft fft; + FftData X; + std::array x; + std::array x_ref; + + int v = 0; + for (int k = 0; k < 20; ++k) { + for (size_t j = 0; j < x.size(); ++j) { + x[j] = v++; + x_ref[j] = x[j] * 64.f; + } + fft.Fft(&x, &X); + fft.Ifft(X, &x); + for (size_t j = 0; j < x.size(); ++j) { + EXPECT_NEAR(x_ref[j], x[j], 0.001f); + } + } +} + +// Verifies that ZeroPaddedFft work as intended. +TEST(Aec3Fft, ZeroPaddedFft) { + Aec3Fft fft; + FftData X; + std::array x_in; + std::array x_ref; + std::array x_out; + + int v = 0; + x_ref.fill(0.f); + for (int k = 0; k < 20; ++k) { + for (size_t j = 0; j < x_in.size(); ++j) { + x_in[j] = v++; + x_ref[j + kFftLengthBy2] = x_in[j] * 64.f; + } + fft.ZeroPaddedFft(x_in, &X); + fft.Ifft(X, &x_out); + for (size_t j = 0; j < x_out.size(); ++j) { + EXPECT_NEAR(x_ref[j], x_out[j], 0.1f); + } + } +} + +// Verifies that ZeroPaddedFft work as intended. +TEST(Aec3Fft, PaddedFft) { + Aec3Fft fft; + FftData X; + std::array x_in; + std::array x_out; + std::array x_old; + std::array x_old_ref; + std::array x_ref; + + int v = 0; + x_old.fill(0.f); + for (int k = 0; k < 20; ++k) { + for (size_t j = 0; j < x_in.size(); ++j) { + x_in[j] = v++; + } + + std::copy(x_old.begin(), x_old.end(), x_ref.begin()); + std::copy(x_in.begin(), x_in.end(), x_ref.begin() + kFftLengthBy2); + std::copy(x_in.begin(), x_in.end(), x_old_ref.begin()); + std::for_each(x_ref.begin(), x_ref.end(), [](float& a) { a *= 64.f; }); + + fft.PaddedFft(x_in, x_old, &X); + fft.Ifft(X, &x_out); + + for (size_t j = 0; j < x_out.size(); ++j) { + EXPECT_NEAR(x_ref[j], x_out[j], 0.1f); + } + + EXPECT_EQ(x_old_ref, x_old); + } +} + +} // namespace webrtc diff --git a/webrtc/modules/audio_processing/aec3/aec_state.cc b/webrtc/modules/audio_processing/aec3/aec_state.cc new file mode 100644 index 0000000000..c18fd6d870 --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/aec_state.cc @@ -0,0 +1,162 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/modules/audio_processing/aec3/aec_state.h" + +#include +#include +#include + +#include "webrtc/base/atomicops.h" +#include "webrtc/base/checks.h" +#include "webrtc/modules/audio_processing/logging/apm_data_dumper.h" + +namespace webrtc { +namespace { + +constexpr float kMaxFilterEstimateStrength = 1000.f; + +// Compute the delay of the adaptive filter as the partition with a distinct +// peak. +void AnalyzeFilter( + const std::vector>& + filter_frequency_response, + std::array* bands_with_reliable_filter, + std::array* filter_estimate_strength, + rtc::Optional* filter_delay) { + const auto& H2 = filter_frequency_response; + + size_t reliable_delays_sum = 0; + size_t num_reliable_delays = 0; + + constexpr size_t kUpperBin = kFftLengthBy2 - 5; + for (size_t k = 1; k < kUpperBin; ++k) { + int peak = 0; + for (size_t j = 0; j < H2.size(); ++j) { + if (H2[j][k] > H2[peak][k]) { + peak = j; + } + } + + if (H2[peak][k] == 0.f) { + (*filter_estimate_strength)[k] = 0.f; + } else if (H2[H2.size() - 1][k] == 0.f) { + (*filter_estimate_strength)[k] = kMaxFilterEstimateStrength; + } else { + (*filter_estimate_strength)[k] = std::min( + kMaxFilterEstimateStrength, H2[peak][k] / H2[H2.size() - 1][k]); + } + + constexpr float kMargin = 10.f; + if (kMargin * H2[H2.size() - 1][k] < H2[peak][k]) { + (*bands_with_reliable_filter)[k] = true; + reliable_delays_sum += peak; + ++num_reliable_delays; + } else { + (*bands_with_reliable_filter)[k] = false; + } + } + (*bands_with_reliable_filter)[0] = (*bands_with_reliable_filter)[1]; + std::fill(bands_with_reliable_filter->begin() + kUpperBin, + bands_with_reliable_filter->end(), + (*bands_with_reliable_filter)[kUpperBin - 1]); + (*filter_estimate_strength)[0] = (*filter_estimate_strength)[1]; + std::fill(filter_estimate_strength->begin() + kUpperBin, + filter_estimate_strength->end(), + (*filter_estimate_strength)[kUpperBin - 1]); + + *filter_delay = + num_reliable_delays > 20 + ? rtc::Optional(reliable_delays_sum / num_reliable_delays) + : rtc::Optional(); +} + +constexpr int kActiveRenderCounterInitial = 50; +constexpr int kActiveRenderCounterMax = 200; +constexpr int kEchoPathChangeCounterInitial = 50; +constexpr int kEchoPathChangeCounterMax = 200; + +} // namespace + +int AecState::instance_count_ = 0; + +AecState::AecState() + : data_dumper_( + new ApmDataDumper(rtc::AtomicOps::Increment(&instance_count_))), + echo_path_change_counter_(kEchoPathChangeCounterInitial), + active_render_counter_(kActiveRenderCounterInitial) { + bands_with_reliable_filter_.fill(false); + filter_estimate_strength_.fill(0.f); +} + +AecState::~AecState() = default; + +void AecState::Update(const std::vector>& + filter_frequency_response, + const rtc::Optional& external_delay_samples, + const FftBuffer& X_buffer, + const std::array& E2_main, + const std::array& E2_shadow, + const std::array& Y2, + rtc::ArrayView x, + const EchoPathVariability& echo_path_variability, + bool echo_leakage_detected) { + filter_length_ = filter_frequency_response.size(); + AnalyzeFilter(filter_frequency_response, &bands_with_reliable_filter_, + &filter_estimate_strength_, &filter_delay_); + // Compute the externally provided delay in partitions. The truncation is + // intended here. + external_delay_ = + external_delay_samples + ? rtc::Optional(*external_delay_samples / kBlockSize) + : rtc::Optional(); + + const float x_energy = std::inner_product(x.begin(), x.end(), x.begin(), 0.f); + + echo_path_change_counter_ = echo_path_variability.AudioPathChanged() + ? kEchoPathChangeCounterMax + : echo_path_change_counter_ - 1; + active_render_counter_ = x_energy > 10000.f * kFftLengthBy2 + ? kActiveRenderCounterMax + : active_render_counter_ - 1; + + usable_linear_estimate_ = filter_delay_ && echo_path_change_counter_ <= 0; + + echo_leakage_detected_ = echo_leakage_detected; + + model_based_aec_feasible_ = usable_linear_estimate_ || external_delay_; + + if (usable_linear_estimate_) { + const auto& X2 = X_buffer.Spectrum(*filter_delay_); + + // TODO(peah): Expose these as stats. + erle_estimator_.Update(X2, Y2, E2_main); + erl_estimator_.Update(X2, Y2); + +// TODO(peah): Add working functionality for headset detection. Until the +// functionality for that is working the headset detector is hardcoded to detect +// no headset. +#if 0 + const auto& erl = erl_estimator_.Erl(); + const int low_erl_band_count = std::count_if( + erl.begin(), erl.end(), [](float a) { return a <= 0.1f; }); + + const int noisy_band_count = std::count_if( + filter_estimate_strength_.begin(), filter_estimate_strength_.end(), + [](float a) { return a <= 10.f; }); + headset_detected_ = low_erl_band_count > 20 && noisy_band_count > 20; +#endif + headset_detected_ = false; + } else { + headset_detected_ = false; + } +} + +} // namespace webrtc diff --git a/webrtc/modules/audio_processing/aec3/aec_state.h b/webrtc/modules/audio_processing/aec3/aec_state.h new file mode 100644 index 0000000000..e3502b4a4f --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/aec_state.h @@ -0,0 +1,127 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_AEC_STATE_H_ +#define WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_AEC_STATE_H_ + +#include +#include +#include + +#include "webrtc/base/array_view.h" +#include "webrtc/base/constructormagic.h" +#include "webrtc/base/optional.h" +#include "webrtc/modules/audio_processing/aec3/aec3_common.h" +#include "webrtc/modules/audio_processing/aec3/echo_path_variability.h" +#include "webrtc/modules/audio_processing/aec3/erle_estimator.h" +#include "webrtc/modules/audio_processing/aec3/erl_estimator.h" +#include "webrtc/modules/audio_processing/aec3/fft_buffer.h" + +namespace webrtc { + +class ApmDataDumper; + +// Handles the state and the conditions for the echo removal functionality. +class AecState { + public: + AecState(); + ~AecState(); + + // Returns whether the linear filter estimate is usable. + bool UsableLinearEstimate() const { return usable_linear_estimate_; } + + // Returns whether there has been echo leakage detected. + bool EchoLeakageDetected() const { return echo_leakage_detected_; } + + // Returns whether it is possible at all to use the model based echo removal + // functionalities. + bool ModelBasedAecFeasible() const { return model_based_aec_feasible_; } + + // Returns whether the render signal is currently active. + bool ActiveRender() const { return active_render_counter_ > 0; } + + // Returns the ERLE. + const std::array& Erle() const { + return erle_estimator_.Erle(); + } + + // Returns the ERL. + const std::array& Erl() const { + return erl_estimator_.Erl(); + } + + // Returns the delay estimate based on the linear filter. + rtc::Optional FilterDelay() const { return filter_delay_; } + + // Returns the externally provided delay. + rtc::Optional ExternalDelay() const { return external_delay_; } + + // Returns the bands where the linear filter is reliable. + const std::array& BandsWithReliableFilter() const { + return bands_with_reliable_filter_; + } + + // Reports whether the filter is poorly aligned. + bool PoorlyAlignedFilter() const { + return FilterDelay() ? *FilterDelay() > 0.75f * filter_length_ : false; + } + + // Returns the strength of the filter. + const std::array& FilterEstimateStrength() const { + return filter_estimate_strength_; + } + + // Returns whether the capture signal is saturated. + bool SaturatedCapture() const { return capture_signal_saturation_; } + + // Updates the capture signal saturation. + void UpdateCaptureSaturation(bool capture_signal_saturation) { + capture_signal_saturation_ = capture_signal_saturation; + } + + // Returns whether a probable headset setup has been detected. + bool HeadsetDetected() const { return headset_detected_; } + + // Updates the aec state. + void Update(const std::vector>& + filter_frequency_response, + const rtc::Optional& external_delay_samples, + const FftBuffer& X_buffer, + const std::array& E2_main, + const std::array& E2_shadow, + const std::array& Y2, + rtc::ArrayView x, + const EchoPathVariability& echo_path_variability, + bool echo_leakage_detected); + + private: + static int instance_count_; + std::unique_ptr data_dumper_; + ErlEstimator erl_estimator_; + ErleEstimator erle_estimator_; + int echo_path_change_counter_; + int active_render_counter_; + bool usable_linear_estimate_ = false; + bool echo_leakage_detected_ = false; + bool model_based_aec_feasible_ = false; + bool capture_signal_saturation_ = false; + bool headset_detected_ = false; + rtc::Optional filter_delay_; + rtc::Optional external_delay_; + std::array bands_with_reliable_filter_; + std::array filter_estimate_strength_; + size_t filter_length_; + + RTC_DISALLOW_COPY_AND_ASSIGN(AecState); +}; + +} // namespace webrtc + +#endif // WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_AEC_STATE_H_ diff --git a/webrtc/modules/audio_processing/aec3/aec_state_unittest.cc b/webrtc/modules/audio_processing/aec3/aec_state_unittest.cc new file mode 100644 index 0000000000..6b25f25e08 --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/aec_state_unittest.cc @@ -0,0 +1,276 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/modules/audio_processing/aec3/aec_state.h" + +#include "webrtc/modules/audio_processing/logging/apm_data_dumper.h" +#include "webrtc/test/gtest.h" + +namespace webrtc { + +// Verify the general functionality of AecState +TEST(AecState, NormalUsage) { + ApmDataDumper data_dumper(42); + AecState state; + FftBuffer X_buffer(Aec3Optimization::kNone, 30, std::vector(1, 30)); + std::array E2_main; + std::array E2_shadow; + std::array Y2; + std::array x; + EchoPathVariability echo_path_variability(false, false); + x.fill(0.f); + + std::vector> + converged_filter_frequency_response(10); + for (auto& v : converged_filter_frequency_response) { + v.fill(0.01f); + } + std::vector> + diverged_filter_frequency_response = converged_filter_frequency_response; + converged_filter_frequency_response[2].fill(100.f); + + // Verify that model based aec feasibility and linear AEC usability are false + // when the filter is diverged and there is no external delay reported. + state.Update(diverged_filter_frequency_response, rtc::Optional(), + X_buffer, E2_main, E2_shadow, Y2, x, echo_path_variability, + false); + EXPECT_FALSE(state.ModelBasedAecFeasible()); + EXPECT_FALSE(state.UsableLinearEstimate()); + + // Verify that model based aec feasibility is true and that linear AEC + // usability is false when the filter is diverged and there is an external + // delay reported. + state.Update(diverged_filter_frequency_response, rtc::Optional(), + X_buffer, E2_main, E2_shadow, Y2, x, echo_path_variability, + false); + EXPECT_FALSE(state.ModelBasedAecFeasible()); + for (int k = 0; k < 50; ++k) { + state.Update(diverged_filter_frequency_response, rtc::Optional(2), + X_buffer, E2_main, E2_shadow, Y2, x, echo_path_variability, + false); + } + EXPECT_TRUE(state.ModelBasedAecFeasible()); + EXPECT_FALSE(state.UsableLinearEstimate()); + + // Verify that linear AEC usability is true when the filter is converged + for (int k = 0; k < 50; ++k) { + state.Update(converged_filter_frequency_response, rtc::Optional(2), + X_buffer, E2_main, E2_shadow, Y2, x, echo_path_variability, + false); + } + EXPECT_TRUE(state.UsableLinearEstimate()); + + // Verify that linear AEC usability becomes false after an echo path change is + // reported + echo_path_variability = EchoPathVariability(true, false); + state.Update(converged_filter_frequency_response, rtc::Optional(2), + X_buffer, E2_main, E2_shadow, Y2, x, echo_path_variability, + false); + EXPECT_FALSE(state.UsableLinearEstimate()); + + // Verify that the active render detection works as intended. + x.fill(101.f); + state.Update(converged_filter_frequency_response, rtc::Optional(2), + X_buffer, E2_main, E2_shadow, Y2, x, echo_path_variability, + false); + EXPECT_TRUE(state.ActiveRender()); + + x.fill(0.f); + for (int k = 0; k < 200; ++k) { + state.Update(converged_filter_frequency_response, rtc::Optional(2), + X_buffer, E2_main, E2_shadow, Y2, x, echo_path_variability, + false); + } + EXPECT_FALSE(state.ActiveRender()); + + x.fill(101.f); + state.Update(converged_filter_frequency_response, rtc::Optional(2), + X_buffer, E2_main, E2_shadow, Y2, x, echo_path_variability, + false); + EXPECT_TRUE(state.ActiveRender()); + + // Verify that echo leakage is properly reported. + state.Update(converged_filter_frequency_response, rtc::Optional(2), + X_buffer, E2_main, E2_shadow, Y2, x, echo_path_variability, + false); + EXPECT_FALSE(state.EchoLeakageDetected()); + + state.Update(converged_filter_frequency_response, rtc::Optional(2), + X_buffer, E2_main, E2_shadow, Y2, x, echo_path_variability, + true); + EXPECT_TRUE(state.EchoLeakageDetected()); + + // Verify that the bands containing reliable filter estimates are properly + // reported. + echo_path_variability = EchoPathVariability(false, false); + for (int k = 0; k < 200; ++k) { + state.Update(converged_filter_frequency_response, rtc::Optional(2), + X_buffer, E2_main, E2_shadow, Y2, x, echo_path_variability, + false); + } + + FftData X; + X.re.fill(10000.f); + X.im.fill(0.f); + for (size_t k = 0; k < X_buffer.Buffer().size(); ++k) { + X_buffer.Insert(X); + } + + Y2.fill(10.f * 1000.f * 1000.f); + E2_main.fill(100.f * Y2[0]); + E2_shadow.fill(100.f * Y2[0]); + state.Update(converged_filter_frequency_response, rtc::Optional(2), + X_buffer, E2_main, E2_shadow, Y2, x, echo_path_variability, + false); + + E2_main.fill(0.1f * Y2[0]); + E2_shadow.fill(E2_main[0]); + for (size_t k = 0; k < Y2.size(); k += 2) { + E2_main[k] = Y2[k]; + E2_shadow[k] = Y2[k]; + } + state.Update(converged_filter_frequency_response, rtc::Optional(2), + X_buffer, E2_main, E2_shadow, Y2, x, echo_path_variability, + false); + + const std::array& reliable_bands = + state.BandsWithReliableFilter(); + + EXPECT_EQ(reliable_bands[0], reliable_bands[1]); + for (size_t k = 1; k < kFftLengthBy2 - 5; ++k) { + EXPECT_TRUE(reliable_bands[k]); + } + for (size_t k = kFftLengthBy2 - 5; k < reliable_bands.size(); ++k) { + EXPECT_EQ(reliable_bands[kFftLengthBy2 - 6], reliable_bands[k]); + } + + // Verify that the ERL is properly estimated + Y2.fill(10.f * X.re[0] * X.re[0]); + for (size_t k = 0; k < 100000; ++k) { + state.Update(converged_filter_frequency_response, rtc::Optional(2), + X_buffer, E2_main, E2_shadow, Y2, x, echo_path_variability, + false); + } + + ASSERT_TRUE(state.UsableLinearEstimate()); + const std::array& erl = state.Erl(); + std::for_each(erl.begin(), erl.end(), + [](float a) { EXPECT_NEAR(10.f, a, 0.1); }); + + // Verify that the ERLE is properly estimated + E2_main.fill(1.f * X.re[0] * X.re[0]); + Y2.fill(10.f * E2_main[0]); + for (size_t k = 0; k < 10000; ++k) { + state.Update(converged_filter_frequency_response, rtc::Optional(2), + X_buffer, E2_main, E2_shadow, Y2, x, echo_path_variability, + false); + } + ASSERT_TRUE(state.UsableLinearEstimate()); + std::for_each(state.Erle().begin(), state.Erle().end(), + [](float a) { EXPECT_NEAR(8.f, a, 0.1); }); + + E2_main.fill(1.f * X.re[0] * X.re[0]); + Y2.fill(5.f * E2_main[0]); + for (size_t k = 0; k < 10000; ++k) { + state.Update(converged_filter_frequency_response, rtc::Optional(2), + X_buffer, E2_main, E2_shadow, Y2, x, echo_path_variability, + false); + } + ASSERT_TRUE(state.UsableLinearEstimate()); + std::for_each(state.Erle().begin(), state.Erle().end(), + [](float a) { EXPECT_NEAR(5.f, a, 0.1); }); +} + +// Verifies the a non-significant delay is correctly identified. +TEST(AecState, NonSignificantDelay) { + AecState state; + FftBuffer X_buffer(Aec3Optimization::kNone, 30, std::vector(1, 30)); + std::array E2_main; + std::array E2_shadow; + std::array Y2; + std::array x; + EchoPathVariability echo_path_variability(false, false); + x.fill(0.f); + + std::vector> frequency_response(30); + for (auto& v : frequency_response) { + v.fill(0.01f); + } + + // Verify that a non-significant filter delay is identified correctly. + state.Update(frequency_response, rtc::Optional(), X_buffer, E2_main, + E2_shadow, Y2, x, echo_path_variability, false); + EXPECT_FALSE(state.FilterDelay()); +} + +// Verifies the delay for a converged filter is correctly identified. +TEST(AecState, ConvergedFilterDelay) { + constexpr int kFilterLength = 10; + AecState state; + FftBuffer X_buffer(Aec3Optimization::kNone, 30, std::vector(1, 30)); + std::array E2_main; + std::array E2_shadow; + std::array Y2; + std::array x; + EchoPathVariability echo_path_variability(false, false); + x.fill(0.f); + + std::vector> frequency_response( + kFilterLength); + + // Verify that the filter delay for a converged filter is properly identified. + for (int k = 0; k < kFilterLength; ++k) { + for (auto& v : frequency_response) { + v.fill(0.01f); + } + frequency_response[k].fill(100.f); + + state.Update(frequency_response, rtc::Optional(), X_buffer, E2_main, + E2_shadow, Y2, x, echo_path_variability, false); + EXPECT_TRUE(k == (kFilterLength - 1) || state.FilterDelay()); + if (k != (kFilterLength - 1)) { + EXPECT_EQ(k, state.FilterDelay()); + } + } +} + +// Verify that the externally reported delay is properly reported and converted. +TEST(AecState, ExternalDelay) { + AecState state; + std::array E2_main; + std::array E2_shadow; + std::array Y2; + std::array x; + E2_main.fill(0.f); + E2_shadow.fill(0.f); + Y2.fill(0.f); + x.fill(0.f); + FftBuffer X_buffer(Aec3Optimization::kNone, 30, std::vector(1, 30)); + std::vector> frequency_response(30); + for (auto& v : frequency_response) { + v.fill(0.01f); + } + + for (size_t k = 0; k < frequency_response.size() - 1; ++k) { + state.Update(frequency_response, rtc::Optional(k * kBlockSize + 5), + X_buffer, E2_main, E2_shadow, Y2, x, + EchoPathVariability(false, false), false); + EXPECT_TRUE(state.ExternalDelay()); + EXPECT_EQ(k, state.ExternalDelay()); + } + + // Verify that the externally reported delay is properly unset when it is no + // longer present. + state.Update(frequency_response, rtc::Optional(), X_buffer, E2_main, + E2_shadow, Y2, x, EchoPathVariability(false, false), false); + EXPECT_FALSE(state.ExternalDelay()); +} + +} // namespace webrtc diff --git a/webrtc/modules/audio_processing/aec3/block_framer.h b/webrtc/modules/audio_processing/aec3/block_framer.h index 8a90300f6c..c8bca8ea44 100644 --- a/webrtc/modules/audio_processing/aec3/block_framer.h +++ b/webrtc/modules/audio_processing/aec3/block_framer.h @@ -15,7 +15,7 @@ #include "webrtc/base/array_view.h" #include "webrtc/base/constructormagic.h" -#include "webrtc/modules/audio_processing/aec3/aec3_constants.h" +#include "webrtc/modules/audio_processing/aec3/aec3_common.h" namespace webrtc { diff --git a/webrtc/modules/audio_processing/aec3/block_framer_unittest.cc b/webrtc/modules/audio_processing/aec3/block_framer_unittest.cc index 38112392bb..b6419f7c60 100644 --- a/webrtc/modules/audio_processing/aec3/block_framer_unittest.cc +++ b/webrtc/modules/audio_processing/aec3/block_framer_unittest.cc @@ -14,7 +14,7 @@ #include #include -#include "webrtc/modules/audio_processing/aec3/aec3_constants.h" +#include "webrtc/modules/audio_processing/aec3/aec3_common.h" #include "webrtc/test/gtest.h" namespace webrtc { diff --git a/webrtc/modules/audio_processing/aec3/block_processor.cc b/webrtc/modules/audio_processing/aec3/block_processor.cc index 550a21073f..223b693813 100644 --- a/webrtc/modules/audio_processing/aec3/block_processor.cc +++ b/webrtc/modules/audio_processing/aec3/block_processor.cc @@ -12,7 +12,7 @@ #include "webrtc/base/atomicops.h" #include "webrtc/base/constructormagic.h" #include "webrtc/base/optional.h" -#include "webrtc/modules/audio_processing/aec3/aec3_constants.h" +#include "webrtc/modules/audio_processing/aec3/aec3_common.h" #include "webrtc/modules/audio_processing/aec3/echo_path_variability.h" #include "webrtc/modules/audio_processing/logging/apm_data_dumper.h" #include "webrtc/system_wrappers/include/logging.h" diff --git a/webrtc/modules/audio_processing/aec3/block_processor_unittest.cc b/webrtc/modules/audio_processing/aec3/block_processor_unittest.cc index ac3af6afb1..78e5c7ed7c 100644 --- a/webrtc/modules/audio_processing/aec3/block_processor_unittest.cc +++ b/webrtc/modules/audio_processing/aec3/block_processor_unittest.cc @@ -17,7 +17,7 @@ #include "webrtc/base/checks.h" #include "webrtc/base/random.h" -#include "webrtc/modules/audio_processing/aec3/aec3_constants.h" +#include "webrtc/modules/audio_processing/aec3/aec3_common.h" #include "webrtc/modules/audio_processing/aec3/mock/mock_echo_remover.h" #include "webrtc/modules/audio_processing/aec3/mock/mock_render_delay_buffer.h" #include "webrtc/modules/audio_processing/aec3/mock/mock_render_delay_controller.h" @@ -231,6 +231,8 @@ TEST(BlockProcessor, VerifyRenderNumBandsCheck) { } } +// TODO(peah): Verify the check for correct number of bands in the capture +// signal. TEST(BlockProcessor, VerifyCaptureNumBandsCheck) { for (auto rate : {8000, 16000, 32000, 48000}) { SCOPED_TRACE(ProduceDebugText(rate)); diff --git a/webrtc/modules/audio_processing/aec3/comfort_noise_generator.cc b/webrtc/modules/audio_processing/aec3/comfort_noise_generator.cc new file mode 100644 index 0000000000..f630b25175 --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/comfort_noise_generator.cc @@ -0,0 +1,208 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/modules/audio_processing/aec3/comfort_noise_generator.h" + +#include "webrtc/typedefs.h" +#if defined(WEBRTC_ARCH_X86_FAMILY) +#include +#endif +#include +#include +#include +#include +#include + +#include "webrtc/common_audio/signal_processing/include/signal_processing_library.h" + +namespace webrtc { + +namespace { + +// Creates an array of uniformly distributed variables. +void TableRandomValue(int16_t* vector, int16_t vector_length, uint32_t* seed) { + for (int i = 0; i < vector_length; i++) { + seed[0] = (seed[0] * ((int32_t)69069) + 1) & (0x80000000 - 1); + vector[i] = (int16_t)(seed[0] >> 16); + } +} + +} // namespace + +namespace aec3 { + +#if defined(WEBRTC_ARCH_X86_FAMILY) + +void EstimateComfortNoise_SSE2(const std::array& N2, + uint32_t* seed, + FftData* lower_band_noise, + FftData* upper_band_noise) { + FftData* N_low = lower_band_noise; + FftData* N_high = upper_band_noise; + + // Compute square root spectrum. + std::array N; + for (size_t k = 0; k < kFftLengthBy2; k += 4) { + __m128 v = _mm_loadu_ps(&N2[k]); + v = _mm_sqrt_ps(v); + _mm_storeu_ps(&N[k], v); + } + + N[kFftLengthBy2] = sqrtf(N2[kFftLengthBy2]); + + // Compute the noise level for the upper bands. + constexpr float kOneByNumBands = 1.f / (kFftLengthBy2Plus1 / 2 + 1); + constexpr int kFftLengthBy2Plus1By2 = kFftLengthBy2Plus1 / 2; + const float high_band_noise_level = + std::accumulate(N.begin() + kFftLengthBy2Plus1By2, N.end(), 0.f) * + kOneByNumBands; + + // Generate complex noise. + std::array random_values_int; + TableRandomValue(random_values_int.data(), random_values_int.size(), seed); + + std::array sin; + std::array cos; + constexpr float kScale = 6.28318530717959f / 32768.0f; + std::transform(random_values_int.begin(), random_values_int.end(), + sin.begin(), [&](int16_t a) { return -sinf(kScale * a); }); + std::transform(random_values_int.begin(), random_values_int.end(), + cos.begin(), [&](int16_t a) { return cosf(kScale * a); }); + + // Form low-frequency noise via spectral shaping. + N_low->re[0] = N_low->re[kFftLengthBy2] = N_high->re[0] = + N_high->re[kFftLengthBy2] = 0.f; + std::transform(cos.begin(), cos.end(), N.begin() + 1, N_low->re.begin() + 1, + std::multiplies()); + std::transform(sin.begin(), sin.end(), N.begin() + 1, N_low->im.begin() + 1, + std::multiplies()); + + // Form the high-frequency noise via simple levelling. + std::transform(cos.begin(), cos.end(), N_high->re.begin() + 1, + [&](float a) { return high_band_noise_level * a; }); + std::transform(sin.begin(), sin.end(), N_high->im.begin() + 1, + [&](float a) { return high_band_noise_level * a; }); +} + +#endif + +void EstimateComfortNoise(const std::array& N2, + uint32_t* seed, + FftData* lower_band_noise, + FftData* upper_band_noise) { + FftData* N_low = lower_band_noise; + FftData* N_high = upper_band_noise; + + // Compute square root spectrum. + std::array N; + std::transform(N2.begin(), N2.end(), N.begin(), + [](float a) { return sqrtf(a); }); + + // Compute the noise level for the upper bands. + constexpr float kOneByNumBands = 1.f / (kFftLengthBy2Plus1 / 2 + 1); + constexpr int kFftLengthBy2Plus1By2 = kFftLengthBy2Plus1 / 2; + const float high_band_noise_level = + std::accumulate(N.begin() + kFftLengthBy2Plus1By2, N.end(), 0.f) * + kOneByNumBands; + + // Generate complex noise. + std::array random_values_int; + TableRandomValue(random_values_int.data(), random_values_int.size(), seed); + + std::array sin; + std::array cos; + constexpr float kScale = 6.28318530717959f / 32768.0f; + std::transform(random_values_int.begin(), random_values_int.end(), + sin.begin(), [&](int16_t a) { return -sinf(kScale * a); }); + std::transform(random_values_int.begin(), random_values_int.end(), + cos.begin(), [&](int16_t a) { return cosf(kScale * a); }); + + // Form low-frequency noise via spectral shaping. + N_low->re[0] = N_low->re[kFftLengthBy2] = N_high->re[0] = + N_high->re[kFftLengthBy2] = 0.f; + std::transform(cos.begin(), cos.end(), N.begin() + 1, N_low->re.begin() + 1, + std::multiplies()); + std::transform(sin.begin(), sin.end(), N.begin() + 1, N_low->im.begin() + 1, + std::multiplies()); + + // Form the high-frequency noise via simple levelling. + std::transform(cos.begin(), cos.end(), N_high->re.begin() + 1, + [&](float a) { return high_band_noise_level * a; }); + std::transform(sin.begin(), sin.end(), N_high->im.begin() + 1, + [&](float a) { return high_band_noise_level * a; }); +} + +} // namespace aec3 + +ComfortNoiseGenerator::ComfortNoiseGenerator(Aec3Optimization optimization) + : optimization_(optimization), + seed_(42), + N2_initial_(new std::array()) { + N2_initial_->fill(0.f); + Y2_smoothed_.fill(0.f); + N2_.fill(1.0e6f); +} + +ComfortNoiseGenerator::~ComfortNoiseGenerator() = default; + +void ComfortNoiseGenerator::Compute( + const AecState& aec_state, + const std::array& capture_spectrum, + FftData* lower_band_noise, + FftData* upper_band_noise) { + RTC_DCHECK(lower_band_noise); + RTC_DCHECK(upper_band_noise); + const auto& Y2 = capture_spectrum; + + if (!aec_state.SaturatedCapture()) { + // Smooth Y2. + std::transform(Y2_smoothed_.begin(), Y2_smoothed_.end(), Y2.begin(), + Y2_smoothed_.begin(), + [](float a, float b) { return a + 0.1f * (b - a); }); + + if (N2_counter_ > 50) { + // Update N2 from Y2_smoothed. + std::transform(N2_.begin(), N2_.end(), Y2_smoothed_.begin(), N2_.begin(), + [](float a, float b) { + return b < a ? (0.9f * b + 0.1f * a) * 1.0002f + : a * 1.0002f; + }); + } + + if (N2_initial_) { + if (++N2_counter_ == 1000) { + N2_initial_.reset(); + } else { + // Compute the N2_initial from N2. + std::transform( + N2_.begin(), N2_.end(), N2_initial_->begin(), N2_initial_->begin(), + [](float a, float b) { return a > b ? b + 0.001f * (a - b) : a; }); + } + } + } + + // Choose N2 estimate to use. + const std::array& N2 = + N2_initial_ ? *N2_initial_ : N2_; + + switch (optimization_) { +#if defined(WEBRTC_ARCH_X86_FAMILY) + case Aec3Optimization::kSse2: + aec3::EstimateComfortNoise_SSE2(N2, &seed_, lower_band_noise, + upper_band_noise); + break; +#endif + default: + aec3::EstimateComfortNoise(N2, &seed_, lower_band_noise, + upper_band_noise); + } +} + +} // namespace webrtc diff --git a/webrtc/modules/audio_processing/aec3/comfort_noise_generator.h b/webrtc/modules/audio_processing/aec3/comfort_noise_generator.h new file mode 100644 index 0000000000..14332f7881 --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/comfort_noise_generator.h @@ -0,0 +1,68 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_COMFORT_NOISE_GENERATOR_H_ +#define WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_COMFORT_NOISE_GENERATOR_H_ + +#include +#include + +#include "webrtc/base/constructormagic.h" +#include "webrtc/modules/audio_processing/aec3/aec_state.h" +#include "webrtc/modules/audio_processing/aec3/aec3_common.h" +#include "webrtc/modules/audio_processing/aec3/fft_data.h" + +namespace webrtc { +namespace aec3 { +#if defined(WEBRTC_ARCH_X86_FAMILY) + +void EstimateComfortNoise_SSE2(const std::array& N2, + uint32_t* seed, + FftData* lower_band_noise, + FftData* upper_band_noise); +#endif +void EstimateComfortNoise(const std::array& N2, + uint32_t* seed, + FftData* lower_band_noise, + FftData* upper_band_noise); + +} // namespace aec3 + +// Generates the comfort noise. +class ComfortNoiseGenerator { + public: + explicit ComfortNoiseGenerator(Aec3Optimization optimization); + ~ComfortNoiseGenerator(); + + // Computes the comfort noise. + void Compute(const AecState& aec_state, + const std::array& capture_spectrum, + FftData* lower_band_noise, + FftData* upper_band_noise); + + // Returns the estimate of the background noise spectrum. + const std::array& NoiseSpectrum() const { + return N2_; + } + + private: + const Aec3Optimization optimization_; + uint32_t seed_; + std::unique_ptr> N2_initial_; + std::array Y2_smoothed_; + std::array N2_; + int N2_counter_ = 0; + + RTC_DISALLOW_IMPLICIT_CONSTRUCTORS(ComfortNoiseGenerator); +}; + +} // namespace webrtc + +#endif // WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_COMFORT_NOISE_GENERATOR_H_ diff --git a/webrtc/modules/audio_processing/aec3/comfort_noise_generator_unittest.cc b/webrtc/modules/audio_processing/aec3/comfort_noise_generator_unittest.cc new file mode 100644 index 0000000000..dcdbab38eb --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/comfort_noise_generator_unittest.cc @@ -0,0 +1,119 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/modules/audio_processing/aec3/comfort_noise_generator.h" + +#include +#include + +#include "webrtc/typedefs.h" +#include "webrtc/base/random.h" +#include "webrtc/system_wrappers/include/cpu_features_wrapper.h" +#include "webrtc/test/gtest.h" + +namespace webrtc { +namespace aec3 { +namespace { + +float Power(const FftData& N) { + std::array N2; + N.Spectrum(Aec3Optimization::kNone, &N2); + return std::accumulate(N2.begin(), N2.end(), 0.f) / N2.size(); +} + +} // namespace + +#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) + +TEST(ComfortNoiseGenerator, NullLowerBandNoise) { + std::array N2; + FftData noise; + EXPECT_DEATH(ComfortNoiseGenerator(DetectOptimization()) + .Compute(AecState(), N2, nullptr, &noise), + ""); +} + +TEST(ComfortNoiseGenerator, NullUpperBandNoise) { + std::array N2; + FftData noise; + EXPECT_DEATH(ComfortNoiseGenerator(DetectOptimization()) + .Compute(AecState(), N2, &noise, nullptr), + ""); +} + +#endif + +#if defined(WEBRTC_ARCH_X86_FAMILY) +// Verifies that the optimized methods are bitexact to their reference +// counterparts. +TEST(ComfortNoiseGenerator, TestOptimizations) { + if (WebRtc_GetCPUInfo(kSSE2) != 0) { + Random random_generator(42U); + uint32_t seed = 42; + uint32_t seed_SSE2 = 42; + std::array N2; + FftData lower_band_noise; + FftData upper_band_noise; + FftData lower_band_noise_SSE2; + FftData upper_band_noise_SSE2; + for (int k = 0; k < 10; ++k) { + for (size_t j = 0; j < N2.size(); ++j) { + N2[j] = random_generator.Rand() * 1000.f; + } + + EstimateComfortNoise(N2, &seed, &lower_band_noise, &upper_band_noise); + EstimateComfortNoise_SSE2(N2, &seed_SSE2, &lower_band_noise_SSE2, + &upper_band_noise_SSE2); + for (size_t j = 0; j < lower_band_noise.re.size(); ++j) { + EXPECT_NEAR(lower_band_noise.re[j], lower_band_noise_SSE2.re[j], + 0.00001f); + EXPECT_NEAR(upper_band_noise.re[j], upper_band_noise_SSE2.re[j], + 0.00001f); + } + for (size_t j = 1; j < lower_band_noise.re.size() - 1; ++j) { + EXPECT_NEAR(lower_band_noise.im[j], lower_band_noise_SSE2.im[j], + 0.00001f); + EXPECT_NEAR(upper_band_noise.im[j], upper_band_noise_SSE2.im[j], + 0.00001f); + } + } + } +} + +#endif + +TEST(ComfortNoiseGenerator, CorrectLevel) { + ComfortNoiseGenerator cng(DetectOptimization()); + AecState aec_state; + + std::array N2; + N2.fill(1000.f * 1000.f); + + FftData n_lower; + FftData n_upper; + n_lower.re.fill(0.f); + n_lower.im.fill(0.f); + n_upper.re.fill(0.f); + n_upper.im.fill(0.f); + + // Ensure instantaneous updata to nonzero noise. + cng.Compute(aec_state, N2, &n_lower, &n_upper); + EXPECT_LT(0.f, Power(n_lower)); + EXPECT_LT(0.f, Power(n_upper)); + + for (int k = 0; k < 10000; ++k) { + cng.Compute(aec_state, N2, &n_lower, &n_upper); + } + EXPECT_NEAR(N2[0], Power(n_lower), N2[0] / 10.f); + EXPECT_NEAR(N2[0], Power(n_upper), N2[0] / 10.f); +} + +} // namespace aec3 +} // namespace webrtc diff --git a/webrtc/modules/audio_processing/aec3/decimator_by_4.h b/webrtc/modules/audio_processing/aec3/decimator_by_4.h index 076c1688c8..9a22dfcfed 100644 --- a/webrtc/modules/audio_processing/aec3/decimator_by_4.h +++ b/webrtc/modules/audio_processing/aec3/decimator_by_4.h @@ -15,7 +15,7 @@ #include "webrtc/base/array_view.h" #include "webrtc/base/constructormagic.h" -#include "webrtc/modules/audio_processing/aec3/aec3_constants.h" +#include "webrtc/modules/audio_processing/aec3/aec3_common.h" #include "webrtc/modules/audio_processing/aec3/cascaded_biquad_filter.h" namespace webrtc { diff --git a/webrtc/modules/audio_processing/aec3/decimator_by_4_unittest.cc b/webrtc/modules/audio_processing/aec3/decimator_by_4_unittest.cc index a7699ba64f..760c5e59b7 100644 --- a/webrtc/modules/audio_processing/aec3/decimator_by_4_unittest.cc +++ b/webrtc/modules/audio_processing/aec3/decimator_by_4_unittest.cc @@ -18,7 +18,7 @@ #include #include -#include "webrtc/modules/audio_processing/aec3/aec3_constants.h" +#include "webrtc/modules/audio_processing/aec3/aec3_common.h" #include "webrtc/test/gtest.h" namespace webrtc { diff --git a/webrtc/modules/audio_processing/aec3/echo_canceller3.cc b/webrtc/modules/audio_processing/aec3/echo_canceller3.cc index a4b796d969..60efced0ec 100644 --- a/webrtc/modules/audio_processing/aec3/echo_canceller3.cc +++ b/webrtc/modules/audio_processing/aec3/echo_canceller3.cc @@ -20,7 +20,7 @@ namespace { bool DetectSaturation(rtc::ArrayView y) { for (auto y_k : y) { - if (y_k >= 32767.0f || y_k <= -32768.0f) { + if (y_k >= 32700.0f || y_k <= -32700.0f) { return true; } } @@ -107,16 +107,14 @@ bool BufferRemainingRenderFrameContent(FrameBlocker* render_blocker, return block_processor->BufferRender(block); } -void CopyAudioBufferIntoFrame(AudioBuffer* buffer, - size_t num_bands, - size_t frame_length, - std::vector>* frame) { +void CopyLowestBandIntoFrame(AudioBuffer* buffer, + size_t num_bands, + size_t frame_length, + std::vector>* frame) { RTC_DCHECK_EQ(num_bands, frame->size()); - for (size_t i = 0; i < num_bands; ++i) { - rtc::ArrayView buffer_view(&buffer->split_bands_f(0)[i][0], - frame_length); - std::copy(buffer_view.begin(), buffer_view.end(), (*frame)[i].begin()); - } + RTC_DCHECK_EQ(frame_length, (*frame)[0].size()); + rtc::ArrayView buffer_view(&buffer->channels_f()[0][0], frame_length); + std::copy(buffer_view.begin(), buffer_view.end(), (*frame)[0].begin()); } // [B,A] = butter(2,100/4000,'high') @@ -182,14 +180,13 @@ EchoCanceller3::RenderWriter::~RenderWriter() = default; bool EchoCanceller3::RenderWriter::Insert(AudioBuffer* input) { RTC_DCHECK_EQ(1, input->num_channels()); - RTC_DCHECK_EQ(num_bands_, input->num_bands()); RTC_DCHECK_EQ(frame_length_, input->num_frames_per_band()); data_dumper_->DumpWav("aec3_render_input", frame_length_, - &input->split_bands_f(0)[0][0], + &input->channels_f()[0][0], LowestBandRate(sample_rate_hz_), 1); - CopyAudioBufferIntoFrame(input, num_bands_, frame_length_, - &render_queue_input_frame_); + CopyLowestBandIntoFrame(input, num_bands_, frame_length_, + &render_queue_input_frame_); if (render_highpass_filter_) { render_highpass_filter_->Process(render_queue_input_frame_[0]); diff --git a/webrtc/modules/audio_processing/aec3/echo_canceller3_unittest.cc b/webrtc/modules/audio_processing/aec3/echo_canceller3_unittest.cc index 8ccaa51e52..1162f70e2b 100644 --- a/webrtc/modules/audio_processing/aec3/echo_canceller3_unittest.cc +++ b/webrtc/modules/audio_processing/aec3/echo_canceller3_unittest.cc @@ -17,7 +17,7 @@ #include #include -#include "webrtc/modules/audio_processing/aec3/aec3_constants.h" +#include "webrtc/modules/audio_processing/aec3/aec3_common.h" #include "webrtc/modules/audio_processing/aec3/block_processor.h" #include "webrtc/modules/audio_processing/aec3/frame_blocker.h" #include "webrtc/modules/audio_processing/aec3/mock/mock_block_processor.h" @@ -48,6 +48,17 @@ void PopulateInputFrame(size_t frame_length, } } +// Populates the frame with linearly increasing sample values. +void PopulateInputFrame(size_t frame_length, + size_t frame_index, + float* frame, + int offset) { + for (size_t i = 0; i < frame_length; ++i) { + float value = static_cast(frame_index * frame_length + i) + offset; + frame[i] = std::max(value, 0.f); + } +} + // Verifies the that samples in the output frame are identical to the samples // that were produced for the input frame, with an offset in order to compensate // for buffering delays. @@ -75,6 +86,25 @@ bool VerifyOutputFrameBitexactness(size_t frame_length, return true; } +// Verifies the that samples in the output frame are identical to the samples +// that were produced for the input frame, with an offset in order to compensate +// for buffering delays. +bool VerifyOutputFrameBitexactness(size_t frame_length, + size_t frame_index, + const float* const* frame, + int offset) { + float reference_frame[480]; + + PopulateInputFrame(frame_length, frame_index, reference_frame, offset); + for (size_t i = 0; i < frame_length; ++i) { + if (reference_frame[i] != frame[0][i]) { + return false; + } + } + + return true; +} + // Class for testing that the capture data is properly received by the block // processor and that the processor data is properly passed to the // EchoCanceller3 output. @@ -159,8 +189,8 @@ class EchoCanceller3Tester { OptionalBandSplit(); PopulateInputFrame(frame_length_, num_bands_, frame_index, &capture_buffer_.split_bands_f(0)[0], 0); - PopulateInputFrame(frame_length_, num_bands_, frame_index, - &render_buffer_.split_bands_f(0)[0], 100); + PopulateInputFrame(frame_length_, frame_index, + &render_buffer_.channels_f()[0][0], 0); EXPECT_TRUE(aec3.AnalyzeRender(&render_buffer_)); aec3.ProcessCapture(&capture_buffer_, false); @@ -184,14 +214,14 @@ class EchoCanceller3Tester { OptionalBandSplit(); PopulateInputFrame(frame_length_, num_bands_, frame_index, &capture_buffer_.split_bands_f(0)[0], 100); - PopulateInputFrame(frame_length_, num_bands_, frame_index, - &render_buffer_.split_bands_f(0)[0], 0); + PopulateInputFrame(frame_length_, frame_index, + &render_buffer_.channels_f()[0][0], 0); EXPECT_TRUE(aec3.AnalyzeRender(&render_buffer_)); aec3.ProcessCapture(&capture_buffer_, false); EXPECT_TRUE(VerifyOutputFrameBitexactness( - frame_length_, num_bands_, frame_index, - &capture_buffer_.split_bands_f(0)[0], -64)); + frame_length_, frame_index, &capture_buffer_.split_bands_f(0)[0], + -64)); } } @@ -263,8 +293,8 @@ class EchoCanceller3Tester { PopulateInputFrame(frame_length_, num_bands_, frame_index, &capture_buffer_.split_bands_f(0)[0], 0); - PopulateInputFrame(frame_length_, num_bands_, frame_index, - &render_buffer_.split_bands_f(0)[0], 0); + PopulateInputFrame(frame_length_, frame_index, + &render_buffer_.channels_f()[0][0], 0); EXPECT_TRUE(aec3.AnalyzeRender(&render_buffer_)); aec3.ProcessCapture(&capture_buffer_, echo_path_change); @@ -354,8 +384,8 @@ class EchoCanceller3Tester { PopulateInputFrame(frame_length_, num_bands_, frame_index, &capture_buffer_.split_bands_f(0)[0], 0); - PopulateInputFrame(frame_length_, num_bands_, frame_index, - &render_buffer_.split_bands_f(0)[0], 0); + PopulateInputFrame(frame_length_, frame_index, + &render_buffer_.channels_f()[0][0], 0); EXPECT_TRUE(aec3.AnalyzeRender(&render_buffer_)); aec3.ProcessCapture(&capture_buffer_, false); @@ -414,7 +444,6 @@ class EchoCanceller3Tester { EchoCanceller3 aec3(sample_rate_hz_, false, std::move(block_processor_mock)); - for (size_t frame_index = 0; frame_index < kNumFramesToProcess; ++frame_index) { for (int k = 0; k < fullband_frame_length_; ++k) { @@ -440,8 +469,8 @@ class EchoCanceller3Tester { PopulateInputFrame(frame_length_, num_bands_, frame_index, &capture_buffer_.split_bands_f(0)[0], 0); - PopulateInputFrame(frame_length_, num_bands_, frame_index, - &render_buffer_.split_bands_f(0)[0], 0); + PopulateInputFrame(frame_length_, frame_index, + &render_buffer_.channels_f()[0][0], 0); EXPECT_TRUE(aec3.AnalyzeRender(&render_buffer_)); aec3.ProcessCapture(&capture_buffer_, false); @@ -462,8 +491,8 @@ class EchoCanceller3Tester { if (sample_rate_hz_ > 16000) { render_buffer_.SplitIntoFrequencyBands(); } - PopulateInputFrame(frame_length_, num_bands_, frame_index, - &render_buffer_.split_bands_f(0)[0], 0); + PopulateInputFrame(frame_length_, frame_index, + &render_buffer_.channels_f()[0][0], 0); EXPECT_TRUE(aec3.AnalyzeRender(&render_buffer_)); } @@ -480,8 +509,8 @@ class EchoCanceller3Tester { aec3.ProcessCapture(&capture_buffer_, false); EXPECT_TRUE(VerifyOutputFrameBitexactness( - frame_length_, num_bands_, frame_index, - &capture_buffer_.split_bands_f(0)[0], -64)); + frame_length_, frame_index, &capture_buffer_.split_bands_f(0)[0], + -64)); } } @@ -497,8 +526,8 @@ class EchoCanceller3Tester { if (sample_rate_hz_ > 16000) { render_buffer_.SplitIntoFrequencyBands(); } - PopulateInputFrame(frame_length_, num_bands_, frame_index, - &render_buffer_.split_bands_f(0)[0], 0); + PopulateInputFrame(frame_length_, frame_index, + &render_buffer_.channels_f()[0][0], 0); if (k == 0) { EXPECT_TRUE(aec3.AnalyzeRender(&render_buffer_)); @@ -518,8 +547,7 @@ class EchoCanceller3Tester { // way that the number of bands for the rates are different. const int aec3_sample_rate_hz = sample_rate_hz_ == 48000 ? 32000 : 48000; EchoCanceller3 aec3(aec3_sample_rate_hz, false); - PopulateInputFrame(frame_length_, num_bands_, 0, - &render_buffer_.split_bands_f(0)[0], 0); + PopulateInputFrame(frame_length_, 0, &render_buffer_.channels_f()[0][0], 0); EXPECT_DEATH(aec3.AnalyzeRender(&render_buffer_), ""); } @@ -547,8 +575,7 @@ class EchoCanceller3Tester { EchoCanceller3 aec3(aec3_sample_rate_hz, false); OptionalBandSplit(); - PopulateInputFrame(frame_length_, num_bands_, 0, - &render_buffer_.split_bands_f(0)[0], 0); + PopulateInputFrame(frame_length_, 0, &render_buffer_.channels_f()[0][0], 0); EXPECT_DEATH(aec3.AnalyzeRender(&render_buffer_), ""); } @@ -673,12 +700,6 @@ TEST(EchoCanceller3Messaging, EchoLeakage) { } #if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) -TEST(EchoCanceller3InputCheck, WrongRenderNumBandsCheckVerification) { - for (auto rate : {8000, 16000, 32000, 48000}) { - SCOPED_TRACE(ProduceDebugText(rate)); - EchoCanceller3Tester(rate).RunAnalyzeRenderNumBandsCheckVerification(); - } -} TEST(EchoCanceller3InputCheck, WrongCaptureNumBandsCheckVerification) { for (auto rate : {8000, 16000, 32000, 48000}) { @@ -687,7 +708,10 @@ TEST(EchoCanceller3InputCheck, WrongCaptureNumBandsCheckVerification) { } } -TEST(EchoCanceller3InputCheck, WrongRenderFrameLengthCheckVerification) { +// TODO(peah): Re-enable the test once the issue with memory leaks during DEATH +// tests on test bots has been fixed. +TEST(EchoCanceller3InputCheck, + DISABLED_WrongRenderFrameLengthCheckVerification) { for (auto rate : {8000, 16000}) { SCOPED_TRACE(ProduceDebugText(rate)); EchoCanceller3Tester(rate).RunAnalyzeRenderFrameLengthCheckVerification(); diff --git a/webrtc/modules/audio_processing/aec3/echo_path_delay_estimator.cc b/webrtc/modules/audio_processing/aec3/echo_path_delay_estimator.cc index 539832df03..6472bcb1fb 100644 --- a/webrtc/modules/audio_processing/aec3/echo_path_delay_estimator.cc +++ b/webrtc/modules/audio_processing/aec3/echo_path_delay_estimator.cc @@ -13,7 +13,7 @@ #include #include "webrtc/base/checks.h" -#include "webrtc/modules/audio_processing/aec3/aec3_constants.h" +#include "webrtc/modules/audio_processing/aec3/aec3_common.h" #include "webrtc/modules/audio_processing/include/audio_processing.h" #include "webrtc/modules/audio_processing/logging/apm_data_dumper.h" @@ -32,6 +32,7 @@ constexpr int kDownSamplingFactor = 4; EchoPathDelayEstimator::EchoPathDelayEstimator(ApmDataDumper* data_dumper) : data_dumper_(data_dumper), matched_filter_(data_dumper_, + DetectOptimization(), kMatchedFilterWindowSizeSubBlocks, kNumMatchedFilters, kMatchedFilterAlignmentShiftSizeSubBlocks), diff --git a/webrtc/modules/audio_processing/aec3/echo_path_delay_estimator_unittest.cc b/webrtc/modules/audio_processing/aec3/echo_path_delay_estimator_unittest.cc index ba9ff23540..476362a8c2 100644 --- a/webrtc/modules/audio_processing/aec3/echo_path_delay_estimator_unittest.cc +++ b/webrtc/modules/audio_processing/aec3/echo_path_delay_estimator_unittest.cc @@ -15,7 +15,7 @@ #include #include "webrtc/base/random.h" -#include "webrtc/modules/audio_processing/aec3/aec3_constants.h" +#include "webrtc/modules/audio_processing/aec3/aec3_common.h" #include "webrtc/modules/audio_processing/logging/apm_data_dumper.h" #include "webrtc/modules/audio_processing/test/echo_canceller_test_tools.h" #include "webrtc/test/gtest.h" @@ -49,7 +49,7 @@ TEST(EchoPathDelayEstimator, DelayEstimation) { std::vector render(kBlockSize, 0.f); std::vector capture(kBlockSize, 0.f); ApmDataDumper data_dumper(0); - for (size_t delay_samples : {0, 64, 150, 200, 800, 4000}) { + for (size_t delay_samples : {15, 64, 150, 200, 800, 4000}) { SCOPED_TRACE(ProduceDebugText(delay_samples)); DelayBuffer signal_delay_buffer(delay_samples); EchoPathDelayEstimator estimator(&data_dumper); diff --git a/webrtc/modules/audio_processing/aec3/echo_path_variability.cc b/webrtc/modules/audio_processing/aec3/echo_path_variability.cc new file mode 100644 index 0000000000..514659205c --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/echo_path_variability.cc @@ -0,0 +1,18 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/modules/audio_processing/aec3/echo_path_variability.h" + +namespace webrtc { + +EchoPathVariability::EchoPathVariability(bool gain_change, bool delay_change) + : gain_change(gain_change), delay_change(delay_change) {} + +} // namespace webrtc diff --git a/webrtc/modules/audio_processing/aec3/echo_path_variability.h b/webrtc/modules/audio_processing/aec3/echo_path_variability.h index 070887964d..7755362787 100644 --- a/webrtc/modules/audio_processing/aec3/echo_path_variability.h +++ b/webrtc/modules/audio_processing/aec3/echo_path_variability.h @@ -14,8 +14,7 @@ namespace webrtc { struct EchoPathVariability { - EchoPathVariability(bool gain_change, bool delay_change) - : gain_change(gain_change), delay_change(delay_change) {} + EchoPathVariability(bool gain_change, bool delay_change); bool AudioPathChanged() const { return gain_change || delay_change; } bool gain_change; diff --git a/webrtc/modules/audio_processing/aec3/echo_path_variability_unittest.cc b/webrtc/modules/audio_processing/aec3/echo_path_variability_unittest.cc new file mode 100644 index 0000000000..e2e82d1f97 --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/echo_path_variability_unittest.cc @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/modules/audio_processing/aec3/echo_path_variability.h" +#include "webrtc/test/gtest.h" + +namespace webrtc { + +TEST(EchoPathVariability, CorrectBehavior) { + // Test correct passing and reporting of the gain change information. + EchoPathVariability v(true, true); + EXPECT_TRUE(v.gain_change); + EXPECT_TRUE(v.delay_change); + EXPECT_TRUE(v.AudioPathChanged()); + + v = EchoPathVariability(true, false); + EXPECT_TRUE(v.gain_change); + EXPECT_FALSE(v.delay_change); + EXPECT_TRUE(v.AudioPathChanged()); + + v = EchoPathVariability(false, true); + EXPECT_FALSE(v.gain_change); + EXPECT_TRUE(v.delay_change); + EXPECT_TRUE(v.AudioPathChanged()); + + v = EchoPathVariability(false, false); + EXPECT_FALSE(v.gain_change); + EXPECT_FALSE(v.delay_change); + EXPECT_FALSE(v.AudioPathChanged()); +} + +} // namespace webrtc diff --git a/webrtc/modules/audio_processing/aec3/echo_remover.cc b/webrtc/modules/audio_processing/aec3/echo_remover.cc index ab0b68bb16..f700389475 100644 --- a/webrtc/modules/audio_processing/aec3/echo_remover.cc +++ b/webrtc/modules/audio_processing/aec3/echo_remover.cc @@ -10,59 +10,237 @@ #include "webrtc/modules/audio_processing/aec3/echo_remover.h" #include -#include +#include +#include +#include +#include "webrtc/base/array_view.h" +#include "webrtc/base/atomicops.h" #include "webrtc/base/constructormagic.h" -#include "webrtc/base/checks.h" -#include "webrtc/base/optional.h" -#include "webrtc/modules/audio_processing/aec3/aec3_constants.h" +#include "webrtc/modules/audio_processing/aec3/aec3_common.h" +#include "webrtc/modules/audio_processing/aec3/aec_state.h" +#include "webrtc/modules/audio_processing/aec3/comfort_noise_generator.h" +#include "webrtc/modules/audio_processing/aec3/echo_path_variability.h" +#include "webrtc/modules/audio_processing/aec3/fft_buffer.h" +#include "webrtc/modules/audio_processing/aec3/fft_data.h" +#include "webrtc/modules/audio_processing/aec3/output_selector.h" +#include "webrtc/modules/audio_processing/aec3/power_echo_model.h" +#include "webrtc/modules/audio_processing/aec3/render_delay_buffer.h" +#include "webrtc/modules/audio_processing/aec3/residual_echo_estimator.h" +#include "webrtc/modules/audio_processing/aec3/subtractor.h" +#include "webrtc/modules/audio_processing/aec3/suppression_filter.h" +#include "webrtc/modules/audio_processing/aec3/suppression_gain.h" +#include "webrtc/modules/audio_processing/logging/apm_data_dumper.h" namespace webrtc { namespace { + +void LinearEchoPower(const FftData& E, + const FftData& Y, + std::array* S2) { + for (size_t k = 0; k < E.re.size(); ++k) { + (*S2)[k] = (Y.re[k] - E.re[k]) * (Y.re[k] - E.re[k]) + + (Y.im[k] - E.im[k]) * (Y.im[k] - E.im[k]); + } +} + +float BlockPower(const std::array x) { + return std::accumulate(x.begin(), x.end(), 0.f, + [](float a, float b) -> float { return a + b * b; }); +} + +// Class for removing the echo from the capture signal. class EchoRemoverImpl final : public EchoRemover { public: explicit EchoRemoverImpl(int sample_rate_hz); ~EchoRemoverImpl() override; - void ProcessBlock(const rtc::Optional& echo_path_delay_samples, - const EchoPathVariability& echo_path_variability, - bool capture_signal_saturation, - const std::vector>& render, - std::vector>* capture) override; + // Removes the echo from a block of samples from the capture signal. The + // supplied render signal is assumed to be pre-aligned with the capture + // signal. + void ProcessBlock( + const rtc::Optional& external_echo_path_delay_estimate, + const EchoPathVariability& echo_path_variability, + bool capture_signal_saturation, + const std::vector>& render, + std::vector>* capture) override; - void UpdateEchoLeakageStatus(bool leakage_detected) override; + // Updates the status on whether echo leakage is detected in the output of the + // echo remover. + void UpdateEchoLeakageStatus(bool leakage_detected) override { + echo_leakage_detected_ = leakage_detected; + } private: + static int instance_count_; + const Aec3Fft fft_; + std::unique_ptr data_dumper_; + const Aec3Optimization optimization_; const int sample_rate_hz_; + Subtractor subtractor_; + SuppressionGain suppression_gain_; + ComfortNoiseGenerator cng_; + SuppressionFilter suppression_filter_; + PowerEchoModel power_echo_model_; + FftBuffer X_buffer_; + RenderSignalAnalyzer render_signal_analyzer_; + OutputSelector output_selector_; + ResidualEchoEstimator residual_echo_estimator_; + bool echo_leakage_detected_ = false; + std::array x_old_; + AecState aec_state_; RTC_DISALLOW_COPY_AND_ASSIGN(EchoRemoverImpl); }; -// TODO(peah): Add functionality. +int EchoRemoverImpl::instance_count_ = 0; + EchoRemoverImpl::EchoRemoverImpl(int sample_rate_hz) - : sample_rate_hz_(sample_rate_hz) { - RTC_DCHECK(ValidFullBandRate(sample_rate_hz_)); + : data_dumper_( + new ApmDataDumper(rtc::AtomicOps::Increment(&instance_count_))), + optimization_(DetectOptimization()), + sample_rate_hz_(sample_rate_hz), + subtractor_(data_dumper_.get(), optimization_), + suppression_gain_(optimization_), + cng_(optimization_), + suppression_filter_(sample_rate_hz_), + X_buffer_(optimization_, + std::max(subtractor_.MinFarendBufferLength(), + power_echo_model_.MinFarendBufferLength()), + subtractor_.NumBlocksInRenderSums()) { + RTC_DCHECK(ValidFullBandRate(sample_rate_hz)); + x_old_.fill(0.f); } EchoRemoverImpl::~EchoRemoverImpl() = default; -// TODO(peah): Add functionality. void EchoRemoverImpl::ProcessBlock( const rtc::Optional& echo_path_delay_samples, const EchoPathVariability& echo_path_variability, bool capture_signal_saturation, const std::vector>& render, std::vector>* capture) { - RTC_DCHECK(capture); - RTC_DCHECK_EQ(render.size(), NumBandsForRate(sample_rate_hz_)); - RTC_DCHECK_EQ(capture->size(), NumBandsForRate(sample_rate_hz_)); - RTC_DCHECK_EQ(render[0].size(), kBlockSize); - RTC_DCHECK_EQ((*capture)[0].size(), kBlockSize); -} + const std::vector>& x = render; + std::vector>* y = capture; -// TODO(peah): Add functionality. -void EchoRemoverImpl::UpdateEchoLeakageStatus(bool leakage_detected) {} + RTC_DCHECK(y); + RTC_DCHECK_EQ(x.size(), NumBandsForRate(sample_rate_hz_)); + RTC_DCHECK_EQ(y->size(), NumBandsForRate(sample_rate_hz_)); + RTC_DCHECK_EQ(x[0].size(), kBlockSize); + RTC_DCHECK_EQ((*y)[0].size(), kBlockSize); + const std::vector& x0 = x[0]; + std::vector& y0 = (*y)[0]; + + data_dumper_->DumpWav("aec3_processblock_capture_input", kBlockSize, &y0[0], + LowestBandRate(sample_rate_hz_), 1); + data_dumper_->DumpWav("aec3_processblock_render_input", kBlockSize, &x0[0], + LowestBandRate(sample_rate_hz_), 1); + + aec_state_.UpdateCaptureSaturation(capture_signal_saturation); + + if (echo_path_variability.AudioPathChanged()) { + subtractor_.HandleEchoPathChange(echo_path_variability); + power_echo_model_.HandleEchoPathChange(echo_path_variability); + } + + std::array Y2; + std::array S2_power; + std::array R2; + std::array S2_linear; + std::array G; + FftData X; + FftData Y; + FftData comfort_noise; + FftData high_band_comfort_noise; + SubtractorOutput subtractor_output; + FftData& E_main = subtractor_output.E_main; + auto& E2_main = subtractor_output.E2_main; + auto& E2_shadow = subtractor_output.E2_shadow; + auto& e_main = subtractor_output.e_main; + auto& e_shadow = subtractor_output.e_shadow; + + // Update the render signal buffer. + fft_.PaddedFft(x0, x_old_, &X); + X_buffer_.Insert(X); + + // Analyze the render signal. + render_signal_analyzer_.Update(X_buffer_, aec_state_.FilterDelay()); + + // Perform linear echo cancellation. + subtractor_.Process(X_buffer_, y0, render_signal_analyzer_, + aec_state_.SaturatedCapture(), &subtractor_output); + + // Compute spectra. + fft_.ZeroPaddedFft(y0, &Y); + LinearEchoPower(E_main, Y, &S2_linear); + Y.Spectrum(optimization_, &Y2); + + // Update the AEC state information. + aec_state_.Update(subtractor_.FilterFrequencyResponse(), + echo_path_delay_samples, X_buffer_, E2_main, E2_shadow, Y2, + x0, echo_path_variability, echo_leakage_detected_); + + // Use the power model to estimate the echo. + power_echo_model_.EstimateEcho(X_buffer_, Y2, aec_state_, &S2_power); + + // Choose the linear output. + output_selector_.FormLinearOutput(e_main, y0); + data_dumper_->DumpWav("aec3_output_linear", kBlockSize, &y0[0], + LowestBandRate(sample_rate_hz_), 1); + const auto& E2 = output_selector_.UseSubtractorOutput() ? E2_main : Y2; + + // Estimate the residual echo power. + residual_echo_estimator_.Estimate( + output_selector_.UseSubtractorOutput(), aec_state_, X_buffer_, + subtractor_.FilterFrequencyResponse(), E2_main, E2_shadow, S2_linear, + S2_power, Y2, &R2); + + // Estimate the comfort noise. + cng_.Compute(aec_state_, Y2, &comfort_noise, &high_band_comfort_noise); + + // Detect basic doubletalk. + const bool doubletalk = BlockPower(e_shadow) < BlockPower(e_main); + + // A choose and apply echo suppression gain. + suppression_gain_.GetGain(E2, R2, cng_.NoiseSpectrum(), + doubletalk ? 0.001f : 0.0001f, &G); + suppression_filter_.ApplyGain(comfort_noise, high_band_comfort_noise, G, y); + + // Debug outputs for the purpose of development and analysis. + data_dumper_->DumpRaw("aec3_N2", cng_.NoiseSpectrum()); + data_dumper_->DumpRaw("aec3_suppressor_gain", G); + data_dumper_->DumpWav("aec3_output", + rtc::ArrayView(&y0[0], kBlockSize), + LowestBandRate(sample_rate_hz_), 1); + data_dumper_->DumpRaw("aec3_using_subtractor_output", + output_selector_.UseSubtractorOutput() ? 1 : 0); + data_dumper_->DumpRaw("aec3_doubletalk", doubletalk ? 1 : 0); + data_dumper_->DumpRaw("aec3_E2", E2); + data_dumper_->DumpRaw("aec3_E2_main", E2_main); + data_dumper_->DumpRaw("aec3_E2_shadow", E2_shadow); + data_dumper_->DumpRaw("aec3_S2_linear", S2_linear); + data_dumper_->DumpRaw("aec3_S2_power", S2_power); + data_dumper_->DumpRaw("aec3_Y2", Y2); + data_dumper_->DumpRaw("aec3_R2", R2); + data_dumper_->DumpRaw("aec3_erle", aec_state_.Erle()); + data_dumper_->DumpRaw("aec3_erl", aec_state_.Erl()); + data_dumper_->DumpRaw("aec3_reliable_filter_bands", + aec_state_.BandsWithReliableFilter()); + data_dumper_->DumpRaw("aec3_active_render", aec_state_.ActiveRender()); + data_dumper_->DumpRaw("aec3_model_based_aec_feasible", + aec_state_.ModelBasedAecFeasible()); + data_dumper_->DumpRaw("aec3_usable_linear_estimate", + aec_state_.UsableLinearEstimate()); + data_dumper_->DumpRaw( + "aec3_filter_delay", + aec_state_.FilterDelay() ? *aec_state_.FilterDelay() : -1); + data_dumper_->DumpRaw( + "aec3_external_delay", + aec_state_.ExternalDelay() ? *aec_state_.ExternalDelay() : -1); + data_dumper_->DumpRaw("aec3_capture_saturation", + aec_state_.SaturatedCapture() ? 1 : 0); +} } // namespace diff --git a/webrtc/modules/audio_processing/aec3/echo_remover_unittest.cc b/webrtc/modules/audio_processing/aec3/echo_remover_unittest.cc index 1c019937e5..29d3410a1e 100644 --- a/webrtc/modules/audio_processing/aec3/echo_remover_unittest.cc +++ b/webrtc/modules/audio_processing/aec3/echo_remover_unittest.cc @@ -10,12 +10,16 @@ #include "webrtc/modules/audio_processing/aec3/echo_remover.h" +#include #include +#include #include #include -#include "webrtc/modules/audio_processing/aec3/aec3_constants.h" +#include "webrtc/base/random.h" +#include "webrtc/modules/audio_processing/aec3/aec3_common.h" #include "webrtc/modules/audio_processing/logging/apm_data_dumper.h" +#include "webrtc/modules/audio_processing/test/echo_canceller_test_tools.h" #include "webrtc/test/gtest.h" namespace webrtc { @@ -27,12 +31,18 @@ std::string ProduceDebugText(int sample_rate_hz) { return ss.str(); } +std::string ProduceDebugText(int sample_rate_hz, int delay) { + std::ostringstream ss(ProduceDebugText(sample_rate_hz)); + ss << ", Delay: " << delay; + return ss.str(); +} + } // namespace // Verifies the basic API call sequence TEST(EchoRemover, BasicApiCalls) { for (auto rate : {8000, 16000, 32000, 48000}) { - ProduceDebugText(rate); + SCOPED_TRACE(ProduceDebugText(rate)); std::unique_ptr remover(EchoRemover::Create(rate)); std::vector> render(NumBandsForRate(rate), @@ -64,7 +74,7 @@ TEST(EchoRemover, DISABLED_WrongSampleRate) { // Verifies the check for the render block size. TEST(EchoRemover, WrongRenderBlockSize) { for (auto rate : {8000, 16000, 32000, 48000}) { - ProduceDebugText(rate); + SCOPED_TRACE(ProduceDebugText(rate)); std::unique_ptr remover(EchoRemover::Create(rate)); std::vector> render( @@ -83,7 +93,7 @@ TEST(EchoRemover, WrongRenderBlockSize) { // Verifies the check for the capture block size. TEST(EchoRemover, WrongCaptureBlockSize) { for (auto rate : {8000, 16000, 32000, 48000}) { - ProduceDebugText(rate); + SCOPED_TRACE(ProduceDebugText(rate)); std::unique_ptr remover(EchoRemover::Create(rate)); std::vector> render(NumBandsForRate(rate), @@ -102,7 +112,7 @@ TEST(EchoRemover, WrongCaptureBlockSize) { // Verifies the check for the number of render bands. TEST(EchoRemover, WrongRenderNumBands) { for (auto rate : {16000, 32000, 48000}) { - ProduceDebugText(rate); + SCOPED_TRACE(ProduceDebugText(rate)); std::unique_ptr remover(EchoRemover::Create(rate)); std::vector> render( @@ -120,9 +130,11 @@ TEST(EchoRemover, WrongRenderNumBands) { } // Verifies the check for the number of capture bands. -TEST(EchoRemover, WrongCaptureNumBands) { +// TODO(peah): Re-enable the test once the issue with memory leaks during DEATH +// tests on test bots has been fixed.c +TEST(EchoRemover, DISABLED_WrongCaptureNumBands) { for (auto rate : {16000, 32000, 48000}) { - ProduceDebugText(rate); + SCOPED_TRACE(ProduceDebugText(rate)); std::unique_ptr remover(EchoRemover::Create(rate)); std::vector> render(NumBandsForRate(rate), @@ -155,4 +167,59 @@ TEST(EchoRemover, NullCapture) { #endif +// Performs a sanity check that the echo_remover is able to properly +// remove echoes. +TEST(EchoRemover, BasicEchoRemoval) { + constexpr int kNumBlocksToProcess = 500; + Random random_generator(42U); + for (auto rate : {8000, 16000, 32000, 48000}) { + std::vector> x(NumBandsForRate(rate), + std::vector(kBlockSize, 0.f)); + std::vector> y(NumBandsForRate(rate), + std::vector(kBlockSize, 0.f)); + EchoPathVariability echo_path_variability(false, false); + for (size_t delay_samples : {0, 64, 150, 200, 301}) { + SCOPED_TRACE(ProduceDebugText(rate, delay_samples)); + std::unique_ptr remover(EchoRemover::Create(rate)); + std::vector>> delay_buffers(x.size()); + for (size_t j = 0; j < x.size(); ++j) { + delay_buffers[j].reset(new DelayBuffer(delay_samples)); + } + + float input_energy = 0.f; + float output_energy = 0.f; + for (int k = 0; k < kNumBlocksToProcess; ++k) { + const bool silence = k < 100 || (k % 100 >= 10); + + for (size_t j = 0; j < x.size(); ++j) { + if (silence) { + std::fill(x[j].begin(), x[j].end(), 0.f); + } else { + RandomizeSampleVector(&random_generator, x[j]); + } + delay_buffers[j]->Delay(x[j], y[j]); + } + + if (k > kNumBlocksToProcess / 2) { + for (size_t j = 0; j < x.size(); ++j) { + input_energy = std::inner_product(y[j].begin(), y[j].end(), + y[j].begin(), input_energy); + } + } + + remover->ProcessBlock(rtc::Optional(delay_samples), + echo_path_variability, false, x, &y); + + if (k > kNumBlocksToProcess / 2) { + for (size_t j = 0; j < x.size(); ++j) { + output_energy = std::inner_product(y[j].begin(), y[j].end(), + y[j].begin(), output_energy); + } + } + } + EXPECT_GT(input_energy, 10.f * output_energy); + } + } +} + } // namespace webrtc diff --git a/webrtc/modules/audio_processing/aec3/erl_estimator.cc b/webrtc/modules/audio_processing/aec3/erl_estimator.cc new file mode 100644 index 0000000000..6990ea3cec --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/erl_estimator.cc @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/modules/audio_processing/aec3/erl_estimator.h" + +#include + +namespace webrtc { + +namespace { + +constexpr float kMinErl = 0.01f; +constexpr float kMaxErl = 1000.f; + +} // namespace + +ErlEstimator::ErlEstimator() { + erl_.fill(kMaxErl); + hold_counters_.fill(0); +} + +ErlEstimator::~ErlEstimator() = default; + +void ErlEstimator::Update( + const std::array& render_spectrum, + const std::array& capture_spectrum) { + const auto& X2 = render_spectrum; + const auto& Y2 = capture_spectrum; + + // Corresponds to WGN of power -46 dBFS. + constexpr float kX2Min = 44015068.0f; + + // Update the estimates in a maximum statistics manner. + for (size_t k = 1; k < kFftLengthBy2; ++k) { + if (X2[k] > kX2Min) { + const float new_erl = Y2[k] / X2[k]; + if (new_erl < erl_[k]) { + hold_counters_[k - 1] = 1000; + erl_[k] += 0.1 * (new_erl - erl_[k]); + erl_[k] = std::max(erl_[k], kMinErl); + } + } + } + + std::for_each(hold_counters_.begin(), hold_counters_.end(), + [](int& a) { --a; }); + std::transform(hold_counters_.begin(), hold_counters_.end(), erl_.begin() + 1, + erl_.begin() + 1, [](int a, float b) { + return a > 0 ? b : std::min(kMaxErl, 2.f * b); + }); + + erl_[0] = erl_[1]; + erl_[kFftLengthBy2] = erl_[kFftLengthBy2 - 1]; +} + +} // namespace webrtc diff --git a/webrtc/modules/audio_processing/aec3/erl_estimator.h b/webrtc/modules/audio_processing/aec3/erl_estimator.h new file mode 100644 index 0000000000..33eba26536 --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/erl_estimator.h @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_ERL_ESTIMATOR_H_ +#define WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_ERL_ESTIMATOR_H_ + +#include + +#include "webrtc/base/constructormagic.h" +#include "webrtc/modules/audio_processing/aec3/aec3_common.h" + +namespace webrtc { + +// Estimates the echo return loss based on the signal spectra. +class ErlEstimator { + public: + ErlEstimator(); + ~ErlEstimator(); + + // Updates the ERL estimate. + void Update(const std::array& render_spectrum, + const std::array& capture_spectrum); + + // Returns the most recent ERL estimate. + const std::array& Erl() const { return erl_; } + + private: + std::array erl_; + std::array hold_counters_; + + RTC_DISALLOW_COPY_AND_ASSIGN(ErlEstimator); +}; + +} // namespace webrtc + +#endif // WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_ERL_ESTIMATOR_H_ diff --git a/webrtc/modules/audio_processing/aec3/erl_estimator_unittest.cc b/webrtc/modules/audio_processing/aec3/erl_estimator_unittest.cc new file mode 100644 index 0000000000..bf803820de --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/erl_estimator_unittest.cc @@ -0,0 +1,70 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/modules/audio_processing/aec3/erl_estimator.h" + +#include "webrtc/test/gtest.h" + +namespace webrtc { + +namespace { + +void VerifyErl(const std::array& erl, + float reference) { + std::for_each(erl.begin(), erl.end(), + [reference](float a) { EXPECT_NEAR(reference, a, 0.001); }); +} + +} // namespace + +// Verifies that the correct ERL estimates are achieved. +TEST(ErlEstimator, Estimates) { + std::array X2; + std::array Y2; + + ErlEstimator estimator; + + // Verifies that the ERL estimate is properly reduced to lower values. + X2.fill(500 * 1000.f * 1000.f); + Y2.fill(10 * X2[0]); + for (size_t k = 0; k < 200; ++k) { + estimator.Update(X2, Y2); + } + VerifyErl(estimator.Erl(), 10.f); + + // Verifies that the ERL is not immediately increased when the ERL in the data + // increases. + Y2.fill(10000 * X2[0]); + for (size_t k = 0; k < 998; ++k) { + estimator.Update(X2, Y2); + } + VerifyErl(estimator.Erl(), 10.f); + + // Verifies that the rate of increase is 3 dB. + estimator.Update(X2, Y2); + VerifyErl(estimator.Erl(), 20.f); + + // Verifies that the maximum ERL is achieved when there are no low RLE + // estimates. + for (size_t k = 0; k < 1000; ++k) { + estimator.Update(X2, Y2); + } + VerifyErl(estimator.Erl(), 1000.f); + + // Verifies that the ERL estimate is is not updated for low-level signals + X2.fill(1000.f * 1000.f); + Y2.fill(10 * X2[0]); + for (size_t k = 0; k < 200; ++k) { + estimator.Update(X2, Y2); + } + VerifyErl(estimator.Erl(), 1000.f); +} + +} // namespace webrtc diff --git a/webrtc/modules/audio_processing/aec3/erle_estimator.cc b/webrtc/modules/audio_processing/aec3/erle_estimator.cc new file mode 100644 index 0000000000..044e11ea3d --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/erle_estimator.cc @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/modules/audio_processing/aec3/erle_estimator.h" + +#include + +namespace webrtc { + +namespace { + +constexpr float kMinErle = 1.f; +constexpr float kMaxErle = 8.f; + +} // namespace + +ErleEstimator::ErleEstimator() { + erle_.fill(kMinErle); + hold_counters_.fill(0); +} + +ErleEstimator::~ErleEstimator() = default; + +void ErleEstimator::Update( + const std::array& render_spectrum, + const std::array& capture_spectrum, + const std::array& subtractor_spectrum) { + const auto& X2 = render_spectrum; + const auto& Y2 = capture_spectrum; + const auto& E2 = subtractor_spectrum; + + // Corresponds of WGN of power -46 dBFS. + constexpr float kX2Min = 44015068.0f; + + // Update the estimates in a clamped minimum statistics manner. + for (size_t k = 1; k < kFftLengthBy2; ++k) { + if (X2[k] > kX2Min && E2[k] > 0.f) { + const float new_erle = Y2[k] / E2[k]; + if (new_erle > erle_[k]) { + hold_counters_[k - 1] = 100; + erle_[k] += 0.1f * (new_erle - erle_[k]); + erle_[k] = std::max(kMinErle, std::min(erle_[k], kMaxErle)); + } + } + } + + std::for_each(hold_counters_.begin(), hold_counters_.end(), + [](int& a) { --a; }); + std::transform(hold_counters_.begin(), hold_counters_.end(), + erle_.begin() + 1, erle_.begin() + 1, [](int a, float b) { + return a > 0 ? b : std::max(kMinErle, 0.97f * b); + }); + + erle_[0] = erle_[1]; + erle_[kFftLengthBy2] = erle_[kFftLengthBy2 - 1]; +} + +} // namespace webrtc diff --git a/webrtc/modules/audio_processing/aec3/erle_estimator.h b/webrtc/modules/audio_processing/aec3/erle_estimator.h new file mode 100644 index 0000000000..d504ef2876 --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/erle_estimator.h @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_ERLE_ESTIMATOR_H_ +#define WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_ERLE_ESTIMATOR_H_ + +#include + +#include "webrtc/base/constructormagic.h" +#include "webrtc/modules/audio_processing/aec3/aec3_common.h" + +namespace webrtc { + +// Estimates the echo return loss enhancement based on the signal spectra. +class ErleEstimator { + public: + ErleEstimator(); + ~ErleEstimator(); + + // Updates the ERLE estimate. + void Update(const std::array& render_spectrum, + const std::array& capture_spectrum, + const std::array& subtractor_spectrum); + + // Returns the most recent ERLE estimate. + const std::array& Erle() const { return erle_; } + + private: + std::array erle_; + std::array hold_counters_; + + RTC_DISALLOW_COPY_AND_ASSIGN(ErleEstimator); +}; + +} // namespace webrtc + +#endif // WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_ERLE_ESTIMATOR_H_ diff --git a/webrtc/modules/audio_processing/aec3/erle_estimator_unittest.cc b/webrtc/modules/audio_processing/aec3/erle_estimator_unittest.cc new file mode 100644 index 0000000000..5fdabfa906 --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/erle_estimator_unittest.cc @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/modules/audio_processing/aec3/erle_estimator.h" +#include "webrtc/test/gtest.h" + +namespace webrtc { + +namespace { + +void VerifyErle(const std::array& erle, + float reference) { + std::for_each(erle.begin(), erle.end(), + [reference](float a) { EXPECT_NEAR(reference, a, 0.001); }); +} + +} // namespace + +// Verifies that the correct ERLE estimates are achieved. +TEST(ErleEstimator, Estimates) { + std::array X2; + std::array E2; + std::array Y2; + + ErleEstimator estimator; + + // Verifies that the ERLE estimate is properley increased to higher values. + X2.fill(500 * 1000.f * 1000.f); + E2.fill(1000.f * 1000.f); + Y2.fill(10 * E2[0]); + for (size_t k = 0; k < 200; ++k) { + estimator.Update(X2, Y2, E2); + } + VerifyErle(estimator.Erle(), 8.f); + + // Verifies that the ERLE is not immediately decreased when the ERLE in the + // data decreases. + Y2.fill(0.1f * E2[0]); + for (size_t k = 0; k < 98; ++k) { + estimator.Update(X2, Y2, E2); + } + VerifyErle(estimator.Erle(), 8.f); + + // Verifies that the minimum ERLE is eventually achieved. + for (size_t k = 0; k < 1000; ++k) { + estimator.Update(X2, Y2, E2); + } + VerifyErle(estimator.Erle(), 1.f); + + // Verifies that the ERLE estimate is is not updated for low-level render + // signals. + X2.fill(1000.f * 1000.f); + Y2.fill(10 * E2[0]); + for (size_t k = 0; k < 200; ++k) { + estimator.Update(X2, Y2, E2); + } + VerifyErle(estimator.Erle(), 1.f); +} +} // namespace webrtc diff --git a/webrtc/modules/audio_processing/aec3/fft_buffer.cc b/webrtc/modules/audio_processing/aec3/fft_buffer.cc new file mode 100644 index 0000000000..6542d108ef --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/fft_buffer.cc @@ -0,0 +1,72 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/modules/audio_processing/aec3/fft_buffer.h" + +#include + +#include "webrtc/base/checks.h" +#include "webrtc/modules/audio_processing/aec3/aec3_common.h" + +namespace webrtc { + +FftBuffer::FftBuffer(Aec3Optimization optimization, + size_t num_partitions, + const std::vector num_ffts_for_spectral_sums) + : optimization_(optimization), + fft_buffer_(num_partitions), + spectrum_buffer_(num_partitions, std::array()), + spectral_sums_(num_ffts_for_spectral_sums.size(), + std::array()) { + // Current implementation only allows a maximum of one spectral sum lengths. + RTC_DCHECK_EQ(1, num_ffts_for_spectral_sums.size()); + spectral_sums_length_ = num_ffts_for_spectral_sums[0]; + RTC_DCHECK_GE(fft_buffer_.size(), spectral_sums_length_); + + for (auto& sum : spectral_sums_) { + sum.fill(0.f); + } + + for (auto& spectrum : spectrum_buffer_) { + spectrum.fill(0.f); + } + + for (auto& fft : fft_buffer_) { + fft.Clear(); + } +} + +FftBuffer::~FftBuffer() = default; + +void FftBuffer::Insert(const FftData& fft) { + // Insert the fft into the buffer. + position_ = (position_ - 1 + fft_buffer_.size()) % fft_buffer_.size(); + fft_buffer_[position_].Assign(fft); + + // Compute and insert the spectrum for the FFT into the spectrum buffer. + fft.Spectrum(optimization_, &spectrum_buffer_[position_]); + + // Pre-compute and cachec the spectral sums. + std::copy(spectrum_buffer_[position_].begin(), + spectrum_buffer_[position_].end(), spectral_sums_[0].begin()); + size_t position = (position_ + 1) % fft_buffer_.size(); + for (size_t j = 1; j < spectral_sums_length_; ++j) { + const std::array& spectrum = + spectrum_buffer_[position]; + + for (size_t k = 0; k < spectral_sums_[0].size(); ++k) { + spectral_sums_[0][k] += spectrum[k]; + } + + position = position < (fft_buffer_.size() - 1) ? position + 1 : 0; + } +} + +} // namespace webrtc diff --git a/webrtc/modules/audio_processing/aec3/fft_buffer.h b/webrtc/modules/audio_processing/aec3/fft_buffer.h new file mode 100644 index 0000000000..c99c9570f3 --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/fft_buffer.h @@ -0,0 +1,70 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_FFT_BUFFER_H_ +#define WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_FFT_BUFFER_H_ + +#include +#include + +#include "webrtc/base/array_view.h" +#include "webrtc/base/constructormagic.h" +#include "webrtc/modules/audio_processing/aec3/fft_data.h" + +namespace webrtc { + +// Provides a circular buffer for 128 point real-valued FFT data. +class FftBuffer { + public: + // The constructor takes as parameters the size of the buffer, as well as a + // vector containing the number of FFTs that will be included in the spectral + // sums in the call to SpectralSum. + FftBuffer(Aec3Optimization optimization, + size_t size, + const std::vector num_ffts_for_spectral_sums); + ~FftBuffer(); + + // Insert an FFT into the buffer. + void Insert(const FftData& fft); + + // Get the spectrum from one of the FFTs in the buffer + const std::array& Spectrum( + size_t buffer_offset_ffts) const { + return spectrum_buffer_[(position_ + buffer_offset_ffts) % + fft_buffer_.size()]; + } + + // Returns the sum of the spectrums for a certain number of FFTs. + const std::array& SpectralSum( + size_t num_ffts) const { + RTC_DCHECK_EQ(spectral_sums_length_, num_ffts); + return spectral_sums_[0]; + } + + // Returns the circular buffer. + rtc::ArrayView Buffer() const { return fft_buffer_; } + + // Returns the current position in the circular buffer + size_t Position() const { return position_; } + + private: + const Aec3Optimization optimization_; + std::vector fft_buffer_; + std::vector> spectrum_buffer_; + size_t spectral_sums_length_; + std::vector> spectral_sums_; + size_t position_ = 0; + + RTC_DISALLOW_IMPLICIT_CONSTRUCTORS(FftBuffer); +}; + +} // namespace webrtc + +#endif // WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_FFT_BUFFER_H_ diff --git a/webrtc/modules/audio_processing/aec3/fft_buffer_unittest.cc b/webrtc/modules/audio_processing/aec3/fft_buffer_unittest.cc new file mode 100644 index 0000000000..d4854bd849 --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/fft_buffer_unittest.cc @@ -0,0 +1,76 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/modules/audio_processing/aec3/fft_buffer.h" + +#include +#include +#include + +#include "webrtc/test/gtest.h" + +namespace webrtc { +namespace {} // namespace + +#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) + +// Verifies that the check for that the provided numbers of Ffts to include in +// the spectral sum is equal to the one supported works. +TEST(FftBuffer, TooLargeNumberOfSpectralSums) { + EXPECT_DEATH(FftBuffer(Aec3Optimization::kNone, 1, std::vector(2, 1)), + ""); +} + +TEST(FftBuffer, TooSmallNumberOfSpectralSums) { + EXPECT_DEATH(FftBuffer(Aec3Optimization::kNone, 1, std::vector()), + ""); +} + +// Verifies that the check for that the provided number of Ffts to to include in +// the spectral is feasible works. +TEST(FftBuffer, FeasibleNumberOfFftsInSum) { + EXPECT_DEATH(FftBuffer(Aec3Optimization::kNone, 1, std::vector(1, 2)), + ""); +} + +#endif + +// Verify the basic usage of the FftBuffer. +TEST(FftBuffer, NormalUsage) { + constexpr int kBufferSize = 10; + FftBuffer buffer(Aec3Optimization::kNone, kBufferSize, + std::vector(1, kBufferSize)); + FftData X; + std::vector> buffer_ref(kBufferSize); + + for (int k = 0; k < 30; ++k) { + std::array X2_sum_ref; + X2_sum_ref.fill(0.f); + for (size_t j = 0; j < buffer.Buffer().size(); ++j) { + const std::array& X2 = buffer.Spectrum(j); + const std::array& X2_ref = buffer_ref[j]; + EXPECT_EQ(X2_ref, X2); + + std::transform(X2_ref.begin(), X2_ref.end(), X2_sum_ref.begin(), + X2_sum_ref.begin(), std::plus()); + } + EXPECT_EQ(X2_sum_ref, buffer.SpectralSum(kBufferSize)); + + std::array X2; + X.re.fill(k); + X.im.fill(k); + X.Spectrum(Aec3Optimization::kNone, &X2); + buffer.Insert(X); + buffer_ref.pop_back(); + buffer_ref.insert(buffer_ref.begin(), X2); + } +} + +} // namespace webrtc diff --git a/webrtc/modules/audio_processing/aec3/fft_data.h b/webrtc/modules/audio_processing/aec3/fft_data.h new file mode 100644 index 0000000000..5a92d91d62 --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/fft_data.h @@ -0,0 +1,98 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_FFT_DATA_H_ +#define WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_FFT_DATA_H_ + +#include "webrtc/typedefs.h" +#if defined(WEBRTC_ARCH_X86_FAMILY) +#include +#endif +#include +#include + +#include "webrtc/base/array_view.h" +#include "webrtc/modules/audio_processing/aec3/aec3_common.h" + +namespace webrtc { + +// Struct that holds imaginary data produced from 128 point real-valued FFTs. +struct FftData { + // Copies the data in src. + void Assign(const FftData& src) { + std::copy(src.re.begin(), src.re.end(), re.begin()); + std::copy(src.im.begin(), src.im.end(), im.begin()); + im[0] = im[kFftLengthBy2] = 0; + } + + // Clears all the imaginary. + void Clear() { + re.fill(0.f); + im.fill(0.f); + } + + // Computes the power spectrum of the data. + void Spectrum(Aec3Optimization optimization, + std::array* power_spectrum) const { + RTC_DCHECK(power_spectrum); + switch (optimization) { +#if defined(WEBRTC_ARCH_X86_FAMILY) + case Aec3Optimization::kSse2: { + constexpr int kNumFourBinBands = kFftLengthBy2 / 4; + constexpr int kLimit = kNumFourBinBands * 4; + for (size_t k = 0; k < kLimit; k += 4) { + const __m128 r = _mm_loadu_ps(&re[k]); + const __m128 i = _mm_loadu_ps(&im[k]); + const __m128 ii = _mm_mul_ps(i, i); + const __m128 rr = _mm_mul_ps(r, r); + const __m128 rrii = _mm_add_ps(rr, ii); + _mm_storeu_ps(&(*power_spectrum)[k], rrii); + } + (*power_spectrum)[kFftLengthBy2] = + re[kFftLengthBy2] * re[kFftLengthBy2] + + im[kFftLengthBy2] * im[kFftLengthBy2]; + } break; +#endif + default: + std::transform(re.begin(), re.end(), im.begin(), + power_spectrum->begin(), + [](float a, float b) { return a * a + b * b; }); + } + } + + // Copy the data from an interleaved array. + void CopyFromPackedArray(const std::array& v) { + re[0] = v[0]; + re[kFftLengthBy2] = v[1]; + im[0] = im[kFftLengthBy2] = 0; + for (size_t k = 1, j = 2; k < kFftLengthBy2; ++k) { + re[k] = v[j++]; + im[k] = v[j++]; + } + } + + // Copies the data into an interleaved array. + void CopyToPackedArray(std::array* v) const { + RTC_DCHECK(v); + (*v)[0] = re[0]; + (*v)[1] = re[kFftLengthBy2]; + for (size_t k = 1, j = 2; k < kFftLengthBy2; ++k) { + (*v)[j++] = re[k]; + (*v)[j++] = im[k]; + } + } + + std::array re; + std::array im; +}; + +} // namespace webrtc + +#endif // WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_FFT_DATA_H_ diff --git a/webrtc/modules/audio_processing/aec3/fft_data_unittest.cc b/webrtc/modules/audio_processing/aec3/fft_data_unittest.cc new file mode 100644 index 0000000000..e5881cff77 --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/fft_data_unittest.cc @@ -0,0 +1,163 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/modules/audio_processing/aec3/fft_data.h" + +#include "webrtc/system_wrappers/include/cpu_features_wrapper.h" +#include "webrtc/test/gtest.h" +#include "webrtc/typedefs.h" + +namespace webrtc { + +#if defined(WEBRTC_ARCH_X86_FAMILY) +// Verifies that the optimized methods are bitexact to their reference +// counterparts. +TEST(FftData, TestOptimizations) { + if (WebRtc_GetCPUInfo(kSSE2) != 0) { + FftData x; + + for (size_t k = 0; k < x.re.size(); ++k) { + x.re[k] = k + 1; + } + + x.im[0] = x.im[x.im.size() - 1] = 0.f; + for (size_t k = 1; k < x.im.size() - 1; ++k) { + x.im[k] = 2.f * (k + 1); + } + + std::array spectrum; + std::array spectrum_sse2; + x.Spectrum(Aec3Optimization::kNone, &spectrum); + x.Spectrum(Aec3Optimization::kSse2, &spectrum_sse2); + EXPECT_EQ(spectrum, spectrum_sse2); + } +} +#endif + +#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) + +// Verifies the check for null output in CopyToPackedArray. +TEST(FftData, NonNullCopyToPackedArrayOutput) { + EXPECT_DEATH(FftData().CopyToPackedArray(nullptr), ""); +} + +// Verifies the check for null output in Spectrum. +TEST(FftData, NonNullSpectrumOutput) { + EXPECT_DEATH(FftData().Spectrum(Aec3Optimization::kNone, nullptr), ""); +} + +#endif + +// Verifies that the Assign method properly copies the data from the source and +// ensures that the imaginary components for the DC and Nyquist bins are 0. +TEST(FftData, Assign) { + FftData x; + FftData y; + + x.re.fill(1.f); + x.im.fill(2.f); + y.Assign(x); + EXPECT_EQ(x.re, y.re); + EXPECT_EQ(0.f, y.im[0]); + EXPECT_EQ(0.f, y.im[x.im.size() - 1]); + for (size_t k = 1; k < x.im.size() - 1; ++k) { + EXPECT_EQ(x.im[k], y.im[k]); + } +} + +// Verifies that the Clear method properly clears all the data. +TEST(FftData, Clear) { + FftData x_ref; + FftData x; + + x_ref.re.fill(0.f); + x_ref.im.fill(0.f); + + x.re.fill(1.f); + x.im.fill(2.f); + x.Clear(); + + EXPECT_EQ(x_ref.re, x.re); + EXPECT_EQ(x_ref.im, x.im); +} + +// Verifies that the spectrum is correctly computed. +TEST(FftData, Spectrum) { + FftData x; + + for (size_t k = 0; k < x.re.size(); ++k) { + x.re[k] = k + 1; + } + + x.im[0] = x.im[x.im.size() - 1] = 0.f; + for (size_t k = 1; k < x.im.size() - 1; ++k) { + x.im[k] = 2.f * (k + 1); + } + + std::array spectrum; + x.Spectrum(Aec3Optimization::kNone, &spectrum); + + EXPECT_EQ(x.re[0] * x.re[0], spectrum[0]); + EXPECT_EQ(x.re[spectrum.size() - 1] * x.re[spectrum.size() - 1], + spectrum[spectrum.size() - 1]); + for (size_t k = 1; k < spectrum.size() - 1; ++k) { + EXPECT_EQ(x.re[k] * x.re[k] + x.im[k] * x.im[k], spectrum[k]); + } +} + +// Verifies that the functionality in CopyToPackedArray works as intended. +TEST(FftData, CopyToPackedArray) { + FftData x; + std::array x_packed; + + for (size_t k = 0; k < x.re.size(); ++k) { + x.re[k] = k + 1; + } + + x.im[0] = x.im[x.im.size() - 1] = 0.f; + for (size_t k = 1; k < x.im.size() - 1; ++k) { + x.im[k] = 2.f * (k + 1); + } + + x.CopyToPackedArray(&x_packed); + + EXPECT_EQ(x.re[0], x_packed[0]); + EXPECT_EQ(x.re[x.re.size() - 1], x_packed[1]); + for (size_t k = 1; k < x_packed.size() / 2; ++k) { + EXPECT_EQ(x.re[k], x_packed[2 * k]); + EXPECT_EQ(x.im[k], x_packed[2 * k + 1]); + } +} + +// Verifies that the functionality in CopyFromPackedArray works as intended +// (relies on that the functionality in CopyToPackedArray has been verified in +// the test above). +TEST(FftData, CopyFromPackedArray) { + FftData x_ref; + FftData x; + std::array x_packed; + + for (size_t k = 0; k < x_ref.re.size(); ++k) { + x_ref.re[k] = k + 1; + } + + x_ref.im[0] = x_ref.im[x_ref.im.size() - 1] = 0.f; + for (size_t k = 1; k < x_ref.im.size() - 1; ++k) { + x_ref.im[k] = 2.f * (k + 1); + } + + x_ref.CopyToPackedArray(&x_packed); + x.CopyFromPackedArray(x_packed); + + EXPECT_EQ(x_ref.re, x.re); + EXPECT_EQ(x_ref.im, x.im); +} + +} // namespace webrtc diff --git a/webrtc/modules/audio_processing/aec3/frame_blocker.cc b/webrtc/modules/audio_processing/aec3/frame_blocker.cc index b15b454384..a4e7893489 100644 --- a/webrtc/modules/audio_processing/aec3/frame_blocker.cc +++ b/webrtc/modules/audio_processing/aec3/frame_blocker.cc @@ -13,7 +13,6 @@ #include #include "webrtc/base/checks.h" -#include "webrtc/modules/audio_processing/aec3/aec3_constants.h" namespace webrtc { diff --git a/webrtc/modules/audio_processing/aec3/frame_blocker.h b/webrtc/modules/audio_processing/aec3/frame_blocker.h index 958d5f2c0c..c4217202d5 100644 --- a/webrtc/modules/audio_processing/aec3/frame_blocker.h +++ b/webrtc/modules/audio_processing/aec3/frame_blocker.h @@ -16,7 +16,7 @@ #include "webrtc/base/array_view.h" #include "webrtc/base/constructormagic.h" -#include "webrtc/modules/audio_processing/aec3/aec3_constants.h" +#include "webrtc/modules/audio_processing/aec3/aec3_common.h" namespace webrtc { diff --git a/webrtc/modules/audio_processing/aec3/frame_blocker_unittest.cc b/webrtc/modules/audio_processing/aec3/frame_blocker_unittest.cc index 498fa8eabc..217615c798 100644 --- a/webrtc/modules/audio_processing/aec3/frame_blocker_unittest.cc +++ b/webrtc/modules/audio_processing/aec3/frame_blocker_unittest.cc @@ -14,7 +14,7 @@ #include #include -#include "webrtc/modules/audio_processing/aec3/aec3_constants.h" +#include "webrtc/modules/audio_processing/aec3/aec3_common.h" #include "webrtc/modules/audio_processing/aec3/block_framer.h" #include "webrtc/test/gtest.h" diff --git a/webrtc/modules/audio_processing/aec3/main_filter_update_gain.cc b/webrtc/modules/audio_processing/aec3/main_filter_update_gain.cc new file mode 100644 index 0000000000..f3531f0622 --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/main_filter_update_gain.cc @@ -0,0 +1,117 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/modules/audio_processing/aec3/main_filter_update_gain.h" + +#include +#include + +#include "webrtc/base/atomicops.h" +#include "webrtc/base/checks.h" +#include "webrtc/modules/audio_processing/aec3/aec3_common.h" +#include "webrtc/modules/audio_processing/logging/apm_data_dumper.h" + +namespace webrtc { +namespace { + +constexpr float kHErrorInitial = 10000.f; + +} // namespace + +int MainFilterUpdateGain::instance_count_ = 0; + +MainFilterUpdateGain::MainFilterUpdateGain() + : data_dumper_( + new ApmDataDumper(rtc::AtomicOps::Increment(&instance_count_))), + poor_excitation_counter_(1000) { + H_error_.fill(kHErrorInitial); +} + +MainFilterUpdateGain::~MainFilterUpdateGain() {} + +void MainFilterUpdateGain::HandleEchoPathChange() { + H_error_.fill(kHErrorInitial); +} + +void MainFilterUpdateGain::Compute( + const FftBuffer& render_buffer, + const RenderSignalAnalyzer& render_signal_analyzer, + const SubtractorOutput& subtractor_output, + const AdaptiveFirFilter& filter, + bool saturated_capture_signal, + FftData* gain_fft) { + RTC_DCHECK(gain_fft); + // Introducing shorter notation to improve readability. + const FftBuffer& X_buffer = render_buffer; + const FftData& E_main = subtractor_output.E_main; + const auto& E2_main = subtractor_output.E2_main; + const auto& E2_shadow = subtractor_output.E2_shadow; + FftData* G = gain_fft; + const size_t size_partitions = filter.SizePartitions(); + const auto& X2 = X_buffer.SpectralSum(size_partitions); + const auto& erl = filter.Erl(); + + ++call_counter_; + + if (render_signal_analyzer.PoorSignalExcitation()) { + poor_excitation_counter_ = 0; + } + + // Do not update the filter if the render is not sufficiently excited. + if (++poor_excitation_counter_ < size_partitions || + saturated_capture_signal || call_counter_ <= size_partitions) { + G->re.fill(0.f); + G->im.fill(0.f); + } else { + // Corresponds of WGN of power -46 dBFS. + constexpr float kX2Min = 44015068.0f; + std::array mu; + // mu = H_error / (0.5* H_error* X2 + n * E2). + for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) { + mu[k] = + X2[k] > kX2Min + ? H_error_[k] / + (0.5f * H_error_[k] * X2[k] + size_partitions * E2_main[k]) + : 0.f; + } + + // Avoid updating the filter close to narrow bands in the render signals. + render_signal_analyzer.MaskRegionsAroundNarrowBands(&mu); + + // H_error = H_error - 0.5 * mu * X2 * H_error. + for (size_t k = 0; k < H_error_.size(); ++k) { + H_error_[k] -= 0.5f * mu[k] * X2[k] * H_error_[k]; + } + + // G = mu * E. + std::transform(mu.begin(), mu.end(), E_main.re.begin(), G->re.begin(), + std::multiplies()); + std::transform(mu.begin(), mu.end(), E_main.im.begin(), G->im.begin(), + std::multiplies()); + } + + // H_error = H_error + factor * erl. + std::array H_error_increase; + constexpr float kErlScaleAccurate = 1.f / 30.0f; + constexpr float kErlScaleInaccurate = 1.f / 10.0f; + std::transform(E2_shadow.begin(), E2_shadow.end(), E2_main.begin(), + H_error_increase.begin(), [&](float a, float b) { + return a >= b ? kErlScaleAccurate : kErlScaleInaccurate; + }); + std::transform(erl.begin(), erl.end(), H_error_increase.begin(), + H_error_increase.begin(), std::multiplies()); + std::transform(H_error_.begin(), H_error_.end(), H_error_increase.begin(), + H_error_.begin(), + [&](float a, float b) { return std::max(a + b, 0.1f); }); + + data_dumper_->DumpRaw("aec3_main_gain_H_error", H_error_); +} + +} // namespace webrtc diff --git a/webrtc/modules/audio_processing/aec3/main_filter_update_gain.h b/webrtc/modules/audio_processing/aec3/main_filter_update_gain.h new file mode 100644 index 0000000000..9a3d8eef92 --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/main_filter_update_gain.h @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_MAIN_FILTER_UPDATE_GAIN_H_ +#define WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_MAIN_FILTER_UPDATE_GAIN_H_ + +#include +#include + +#include "webrtc/base/constructormagic.h" +#include "webrtc/modules/audio_processing/aec3/adaptive_fir_filter.h" +#include "webrtc/modules/audio_processing/aec3/aec3_common.h" +#include "webrtc/modules/audio_processing/aec3/render_signal_analyzer.h" +#include "webrtc/modules/audio_processing/aec3/subtractor_output.h" +#include "webrtc/modules/audio_processing/aec3/fft_buffer.h" + +namespace webrtc { + +class ApmDataDumper; + +// Provides functionality for computing the adaptive gain for the main filter. +class MainFilterUpdateGain { + public: + MainFilterUpdateGain(); + ~MainFilterUpdateGain(); + + // Takes action in the case of a known echo path change. + void HandleEchoPathChange(); + + // Computes the gain. + void Compute(const FftBuffer& render_buffer, + const RenderSignalAnalyzer& render_signal_analyzer, + const SubtractorOutput& subtractor_output, + const AdaptiveFirFilter& filter, + bool saturated_capture_signal, + FftData* gain_fft); + + private: + static int instance_count_; + std::unique_ptr data_dumper_; + std::array H_error_; + size_t poor_excitation_counter_; + size_t call_counter_ = 0; + RTC_DISALLOW_COPY_AND_ASSIGN(MainFilterUpdateGain); +}; + +} // namespace webrtc + +#endif // WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_MAIN_FILTER_UPDATE_GAIN_H_ diff --git a/webrtc/modules/audio_processing/aec3/main_filter_update_gain_unittest.cc b/webrtc/modules/audio_processing/aec3/main_filter_update_gain_unittest.cc new file mode 100644 index 0000000000..92b2f9e297 --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/main_filter_update_gain_unittest.cc @@ -0,0 +1,287 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/modules/audio_processing/aec3/main_filter_update_gain.h" + +#include +#include +#include + +#include "webrtc/base/random.h" +#include "webrtc/modules/audio_processing/aec3/adaptive_fir_filter.h" +#include "webrtc/modules/audio_processing/aec3/aec_state.h" +#include "webrtc/modules/audio_processing/aec3/fft_buffer.h" +#include "webrtc/modules/audio_processing/aec3/render_signal_analyzer.h" +#include "webrtc/modules/audio_processing/aec3/shadow_filter_update_gain.h" +#include "webrtc/modules/audio_processing/aec3/subtractor_output.h" +#include "webrtc/modules/audio_processing/logging/apm_data_dumper.h" +#include "webrtc/modules/audio_processing/test/echo_canceller_test_tools.h" +#include "webrtc/test/gtest.h" + +namespace webrtc { +namespace { + +// Method for performing the simulations needed to test the main filter update +// gain functionality. +void RunFilterUpdateTest(int num_blocks_to_process, + size_t delay_samples, + const std::vector& blocks_with_echo_path_changes, + const std::vector& blocks_with_saturation, + bool use_silent_render_in_second_half, + std::array* e_last_block, + std::array* y_last_block, + FftData* G_last_block) { + ApmDataDumper data_dumper(42); + AdaptiveFirFilter main_filter(9, true, DetectOptimization(), &data_dumper); + AdaptiveFirFilter shadow_filter(9, true, DetectOptimization(), &data_dumper); + Aec3Fft fft; + FftBuffer X_buffer(Aec3Optimization::kNone, main_filter.SizePartitions(), + std::vector(1, main_filter.SizePartitions())); + std::array x_old; + x_old.fill(0.f); + ShadowFilterUpdateGain shadow_gain; + MainFilterUpdateGain main_gain; + Random random_generator(42U); + std::vector x(kBlockSize, 0.f); + std::vector y(kBlockSize, 0.f); + AecState aec_state; + RenderSignalAnalyzer render_signal_analyzer; + FftData X; + std::array s; + FftData S; + FftData G; + SubtractorOutput output; + output.Reset(); + FftData& E_main = output.E_main; + FftData& E_shadow = output.E_shadow; + std::array Y2; + std::array& E2_main = output.E2_main; + std::array& E2_shadow = output.E2_shadow; + std::array& e_main = output.e_main; + std::array& e_shadow = output.e_shadow; + Y2.fill(0.f); + + constexpr float kScale = 1.0f / kFftLengthBy2; + + DelayBuffer delay_buffer(delay_samples); + for (int k = 0; k < num_blocks_to_process; ++k) { + // Handle echo path changes. + if (std::find(blocks_with_echo_path_changes.begin(), + blocks_with_echo_path_changes.end(), + k) != blocks_with_echo_path_changes.end()) { + main_filter.HandleEchoPathChange(); + } + + // Handle saturation. + const bool saturation = + std::find(blocks_with_saturation.begin(), blocks_with_saturation.end(), + k) != blocks_with_saturation.end(); + + // Create the render signal. + if (use_silent_render_in_second_half && k > num_blocks_to_process / 2) { + std::fill(x.begin(), x.end(), 0.f); + } else { + RandomizeSampleVector(&random_generator, x); + } + delay_buffer.Delay(x, y); + fft.PaddedFft(x, x_old, &X); + X_buffer.Insert(X); + render_signal_analyzer.Update(X_buffer, aec_state.FilterDelay()); + + // Apply the main filter. + main_filter.Filter(X_buffer, &S); + fft.Ifft(S, &s); + std::transform(y.begin(), y.end(), s.begin() + kFftLengthBy2, + e_main.begin(), + [&](float a, float b) { return a - b * kScale; }); + std::for_each(e_main.begin(), e_main.end(), [](float& a) { + a = std::max(std::min(a, 32767.0f), -32768.0f); + }); + fft.ZeroPaddedFft(e_main, &E_main); + + // Apply the shadow filter. + shadow_filter.Filter(X_buffer, &S); + fft.Ifft(S, &s); + std::transform(y.begin(), y.end(), s.begin() + kFftLengthBy2, + e_shadow.begin(), + [&](float a, float b) { return a - b * kScale; }); + std::for_each(e_shadow.begin(), e_shadow.end(), [](float& a) { + a = std::max(std::min(a, 32767.0f), -32768.0f); + }); + fft.ZeroPaddedFft(e_shadow, &E_shadow); + + // Compute spectra for future use. + E_main.Spectrum(Aec3Optimization::kNone, &output.E2_main); + E_shadow.Spectrum(Aec3Optimization::kNone, &output.E2_shadow); + + // Adapt the shadow filter. + shadow_gain.Compute(X_buffer, render_signal_analyzer, E_shadow, + shadow_filter.SizePartitions(), saturation, &G); + shadow_filter.Adapt(X_buffer, G); + + // Adapt the main filter + main_gain.Compute(X_buffer, render_signal_analyzer, output, main_filter, + saturation, &G); + main_filter.Adapt(X_buffer, G); + + // Update the delay. + aec_state.Update(main_filter.FilterFrequencyResponse(), + rtc::Optional(), X_buffer, E2_main, E2_shadow, Y2, + x, EchoPathVariability(false, false), false); + } + + std::copy(e_main.begin(), e_main.end(), e_last_block->begin()); + std::copy(y.begin(), y.end(), y_last_block->begin()); + std::copy(G.re.begin(), G.re.end(), G_last_block->re.begin()); + std::copy(G.im.begin(), G.im.end(), G_last_block->im.begin()); +} + +std::string ProduceDebugText(size_t delay) { + std::ostringstream ss; + ss << "Delay: " << delay; + return ss.str(); +} + +} // namespace + +#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) + +// Verifies that the check for non-null output gain parameter works. +TEST(MainFilterUpdateGain, NullDataOutputGain) { + ApmDataDumper data_dumper(42); + AdaptiveFirFilter filter(9, true, DetectOptimization(), &data_dumper); + FftBuffer X_buffer(Aec3Optimization::kNone, filter.SizePartitions(), + std::vector(1, filter.SizePartitions())); + RenderSignalAnalyzer analyzer; + SubtractorOutput output; + MainFilterUpdateGain gain; + EXPECT_DEATH(gain.Compute(X_buffer, analyzer, output, filter, false, nullptr), + ""); +} + +#endif + +// Verifies that the gain formed causes the filter using it to converge. +TEST(MainFilterUpdateGain, GainCausesFilterToConverge) { + std::vector blocks_with_echo_path_changes; + std::vector blocks_with_saturation; + for (size_t delay_samples : {0, 64, 150, 200, 301}) { + SCOPED_TRACE(ProduceDebugText(delay_samples)); + + std::array e; + std::array y; + FftData G; + + RunFilterUpdateTest(500, delay_samples, blocks_with_echo_path_changes, + blocks_with_saturation, false, &e, &y, &G); + + // Verify that the main filter is able to perform well. + EXPECT_LT(1000 * std::inner_product(e.begin(), e.end(), e.begin(), 0.f), + std::inner_product(y.begin(), y.end(), y.begin(), 0.f)); + } +} + +// Verifies that the magnitude of the gain on average decreases for a +// persistently exciting signal. +TEST(MainFilterUpdateGain, DecreasingGain) { + std::vector blocks_with_echo_path_changes; + std::vector blocks_with_saturation; + + std::array e; + std::array y; + FftData G_a; + FftData G_b; + FftData G_c; + std::array G_a_power; + std::array G_b_power; + std::array G_c_power; + + RunFilterUpdateTest(100, 65, blocks_with_echo_path_changes, + blocks_with_saturation, false, &e, &y, &G_a); + RunFilterUpdateTest(200, 65, blocks_with_echo_path_changes, + blocks_with_saturation, false, &e, &y, &G_b); + RunFilterUpdateTest(300, 65, blocks_with_echo_path_changes, + blocks_with_saturation, false, &e, &y, &G_c); + + G_a.Spectrum(Aec3Optimization::kNone, &G_a_power); + G_b.Spectrum(Aec3Optimization::kNone, &G_b_power); + G_c.Spectrum(Aec3Optimization::kNone, &G_c_power); + + EXPECT_GT(std::accumulate(G_a_power.begin(), G_a_power.end(), 0.), + std::accumulate(G_b_power.begin(), G_b_power.end(), 0.)); + + EXPECT_GT(std::accumulate(G_b_power.begin(), G_b_power.end(), 0.), + std::accumulate(G_c_power.begin(), G_c_power.end(), 0.)); +} + +// Verifies that the gain is zero when there is saturation and that the internal +// error estimates cause the gain to increase after a period of saturation. +TEST(MainFilterUpdateGain, SaturationBehavior) { + std::vector blocks_with_echo_path_changes; + std::vector blocks_with_saturation; + for (int k = 99; k < 200; ++k) { + blocks_with_saturation.push_back(k); + } + + std::array e; + std::array y; + FftData G_a; + FftData G_b; + FftData G_a_ref; + G_a_ref.re.fill(0.f); + G_a_ref.im.fill(0.f); + + std::array G_a_power; + std::array G_b_power; + + RunFilterUpdateTest(100, 65, blocks_with_echo_path_changes, + blocks_with_saturation, false, &e, &y, &G_a); + + EXPECT_EQ(G_a_ref.re, G_a.re); + EXPECT_EQ(G_a_ref.im, G_a.im); + + RunFilterUpdateTest(99, 65, blocks_with_echo_path_changes, + blocks_with_saturation, false, &e, &y, &G_a); + RunFilterUpdateTest(201, 65, blocks_with_echo_path_changes, + blocks_with_saturation, false, &e, &y, &G_b); + + G_a.Spectrum(Aec3Optimization::kNone, &G_a_power); + G_b.Spectrum(Aec3Optimization::kNone, &G_b_power); + + EXPECT_LT(std::accumulate(G_a_power.begin(), G_a_power.end(), 0.), + std::accumulate(G_b_power.begin(), G_b_power.end(), 0.)); +} + +// Verifies that the gain increases after an echo path change. +TEST(MainFilterUpdateGain, EchoPathChangeBehavior) { + std::vector blocks_with_echo_path_changes; + std::vector blocks_with_saturation; + blocks_with_echo_path_changes.push_back(99); + + std::array e; + std::array y; + FftData G_a; + FftData G_b; + std::array G_a_power; + std::array G_b_power; + + RunFilterUpdateTest(99, 65, blocks_with_echo_path_changes, + blocks_with_saturation, false, &e, &y, &G_a); + RunFilterUpdateTest(100, 65, blocks_with_echo_path_changes, + blocks_with_saturation, false, &e, &y, &G_b); + + G_a.Spectrum(Aec3Optimization::kNone, &G_a_power); + G_b.Spectrum(Aec3Optimization::kNone, &G_b_power); + + EXPECT_LT(std::accumulate(G_a_power.begin(), G_a_power.end(), 0.), + std::accumulate(G_b_power.begin(), G_b_power.end(), 0.)); +} + +} // namespace webrtc diff --git a/webrtc/modules/audio_processing/aec3/matched_filter.cc b/webrtc/modules/audio_processing/aec3/matched_filter.cc index f187159911..64596b53c0 100644 --- a/webrtc/modules/audio_processing/aec3/matched_filter.cc +++ b/webrtc/modules/audio_processing/aec3/matched_filter.cc @@ -9,6 +9,10 @@ */ #include "webrtc/modules/audio_processing/aec3/matched_filter.h" +#include "webrtc/typedefs.h" +#if defined(WEBRTC_ARCH_X86_FAMILY) +#include +#endif #include #include @@ -16,6 +20,131 @@ #include "webrtc/modules/audio_processing/logging/apm_data_dumper.h" namespace webrtc { +namespace aec3 { + +#if defined(WEBRTC_ARCH_X86_FAMILY) + +void MatchedFilterCore_SSE2(size_t x_start_index, + float x2_sum_threshold, + rtc::ArrayView x, + rtc::ArrayView y, + rtc::ArrayView h, + bool* filters_updated, + float* error_sum) { + // Process for all samples in the sub-block. + for (size_t i = 0; i < kSubBlockSize; ++i) { + // Apply the matched filter as filter * x. and compute x * x. + float x2_sum = 0.f; + float s = 0; + size_t x_index = x_start_index; + RTC_DCHECK_EQ(0, h.size() % 4); + + __m128 s_128 = _mm_set1_ps(0); + __m128 x2_sum_128 = _mm_set1_ps(0); + + size_t k = 0; + if (h.size() > (x.size() - x_index)) { + const size_t limit = x.size() - x_index; + for (; (k + 3) < limit; k += 4, x_index += 4) { + const __m128 x_k = _mm_loadu_ps(&x[x_index]); + const __m128 h_k = _mm_loadu_ps(&h[k]); + const __m128 xx = _mm_mul_ps(x_k, x_k); + x2_sum_128 = _mm_add_ps(x2_sum_128, xx); + const __m128 hx = _mm_mul_ps(h_k, x_k); + s_128 = _mm_add_ps(s_128, hx); + } + + for (; k < limit; ++k, ++x_index) { + x2_sum += x[x_index] * x[x_index]; + s += h[k] * x[x_index]; + } + x_index = 0; + } + + for (; k + 3 < h.size(); k += 4, x_index += 4) { + const __m128 x_k = _mm_loadu_ps(&x[x_index]); + const __m128 h_k = _mm_loadu_ps(&h[k]); + const __m128 xx = _mm_mul_ps(x_k, x_k); + x2_sum_128 = _mm_add_ps(x2_sum_128, xx); + const __m128 hx = _mm_mul_ps(h_k, x_k); + s_128 = _mm_add_ps(s_128, hx); + } + + for (; k < h.size(); ++k, ++x_index) { + x2_sum += x[x_index] * x[x_index]; + s += h[k] * x[x_index]; + } + + float* v = reinterpret_cast(&x2_sum_128); + x2_sum += v[0] + v[1] + v[2] + v[3]; + v = reinterpret_cast(&s_128); + s += v[0] + v[1] + v[2] + v[3]; + + // Compute the matched filter error. + const float e = std::min(32767.f, std::max(-32768.f, y[i] - s)); + (*error_sum) += e * e; + + // Update the matched filter estimate in an NLMS manner. + if (x2_sum > x2_sum_threshold) { + RTC_DCHECK_LT(0.f, x2_sum); + const float alpha = 0.7f * e / x2_sum; + + // filter = filter + 0.7 * (y - filter * x) / x * x. + size_t x_index = x_start_index; + for (size_t k = 0; k < h.size(); ++k) { + h[k] += alpha * x[x_index]; + x_index = x_index < (x.size() - 1) ? x_index + 1 : 0; + } + *filters_updated = true; + } + + x_start_index = x_start_index > 0 ? x_start_index - 1 : x.size() - 1; + } +} +#endif + +void MatchedFilterCore(size_t x_start_index, + float x2_sum_threshold, + rtc::ArrayView x, + rtc::ArrayView y, + rtc::ArrayView h, + bool* filters_updated, + float* error_sum) { + // Process for all samples in the sub-block. + for (size_t i = 0; i < kSubBlockSize; ++i) { + // Apply the matched filter as filter * x. and compute x * x. + float x2_sum = 0.f; + float s = 0; + size_t x_index = x_start_index; + for (size_t k = 0; k < h.size(); ++k) { + x2_sum += x[x_index] * x[x_index]; + s += h[k] * x[x_index]; + x_index = x_index < (x.size() - 1) ? x_index + 1 : 0; + } + + // Compute the matched filter error. + const float e = std::min(32767.f, std::max(-32768.f, y[i] - s)); + (*error_sum) += e * e; + + // Update the matched filter estimate in an NLMS manner. + if (x2_sum > x2_sum_threshold) { + RTC_DCHECK_LT(0.f, x2_sum); + const float alpha = 0.7f * e / x2_sum; + + // filter = filter + 0.7 * (y - filter * x) / x * x. + size_t x_index = x_start_index; + for (size_t k = 0; k < h.size(); ++k) { + h[k] += alpha * x[x_index]; + x_index = x_index < (x.size() - 1) ? x_index + 1 : 0; + } + *filters_updated = true; + } + + x_start_index = x_start_index > 0 ? x_start_index - 1 : x.size() - 1; + } +} + +} // namespace aec3 MatchedFilter::IndexedBuffer::IndexedBuffer(size_t size) : data(size, 0.f) { RTC_DCHECK_EQ(0, size % kSubBlockSize); @@ -24,10 +153,12 @@ MatchedFilter::IndexedBuffer::IndexedBuffer(size_t size) : data(size, 0.f) { MatchedFilter::IndexedBuffer::~IndexedBuffer() = default; MatchedFilter::MatchedFilter(ApmDataDumper* data_dumper, + Aec3Optimization optimization, size_t window_size_sub_blocks, int num_matched_filters, size_t alignment_shift_sub_blocks) : data_dumper_(data_dumper), + optimization_(optimization), filter_intra_lag_shift_(alignment_shift_sub_blocks * kSubBlockSize), filters_(num_matched_filters, std::vector(window_size_sub_blocks * kSubBlockSize, 0.f)), @@ -65,65 +196,17 @@ void MatchedFilter::Update(const std::array& render, (x_buffer_.index + alignment_shift + kSubBlockSize - 1) % x_buffer_.data.size(); - // Process for all samples in the sub-block. - for (size_t i = 0; i < kSubBlockSize; ++i) { - // As x_buffer is a circular buffer, all of the processing is split into - // two loops around the wrapping of the buffer. - const size_t loop_size_1 = - std::min(filters_[n].size(), x_buffer_.data.size() - x_start_index); - const size_t loop_size_2 = filters_[n].size() - loop_size_1; - RTC_DCHECK_EQ(filters_[n].size(), loop_size_1 + loop_size_2); - - // x * x. - float x2_sum = std::inner_product( - x_buffer_.data.begin() + x_start_index, - x_buffer_.data.begin() + x_start_index + loop_size_1, - x_buffer_.data.begin() + x_start_index, 0.f); - // Apply the matched filter as filter * x. - float s = std::inner_product(filters_[n].begin(), - filters_[n].begin() + loop_size_1, - x_buffer_.data.begin() + x_start_index, 0.f); - - if (loop_size_2 > 0) { - // Update the cumulative sum of x * x. - x2_sum = std::inner_product(x_buffer_.data.begin(), - x_buffer_.data.begin() + loop_size_2, - x_buffer_.data.begin(), x2_sum); - - // Compute the matched filter output filter * x in a cumulative manner. - s = std::inner_product(x_buffer_.data.begin(), - x_buffer_.data.begin() + loop_size_2, - filters_[n].begin() + loop_size_1, s); - } - - // Compute the matched filter error. - const float e = std::min(32767.f, std::max(-32768.f, y[i] - s)); - error_sum += e * e; - - // Update the matched filter estimate in an NLMS manner. - if (x2_sum > x2_sum_threshold) { - filters_updated = true; - RTC_DCHECK_LT(0.f, x2_sum); - const float alpha = 0.7f * e / x2_sum; - - // filter = filter + 0.7 * (y - filter * x) / x * x. - std::transform(filters_[n].begin(), filters_[n].begin() + loop_size_1, - x_buffer_.data.begin() + x_start_index, - filters_[n].begin(), - [&](float a, float b) { return a + alpha * b; }); - - if (loop_size_2 > 0) { - // filter = filter + 0.7 * (y - filter * x) / x * x. - std::transform(x_buffer_.data.begin(), - x_buffer_.data.begin() + loop_size_2, - filters_[n].begin() + loop_size_1, - filters_[n].begin() + loop_size_1, - [&](float a, float b) { return b + alpha * a; }); - } - } - - x_start_index = - x_start_index > 0 ? x_start_index - 1 : x_buffer_.data.size() - 1; + switch (optimization_) { +#if defined(WEBRTC_ARCH_X86_FAMILY) + case Aec3Optimization::kSse2: + aec3::MatchedFilterCore_SSE2(x_start_index, x2_sum_threshold, + x_buffer_.data, y, filters_[n], + &filters_updated, &error_sum); + break; +#endif + default: + aec3::MatchedFilterCore(x_start_index, x2_sum_threshold, x_buffer_.data, + y, filters_[n], &filters_updated, &error_sum); } // Compute anchor for the matched filter error. @@ -140,11 +223,12 @@ void MatchedFilter::Update(const std::array& render, [](float a, float b) -> bool { return a * a < b * b; })); // Update the lag estimates for the matched filter. - const float kMatchingFilterThreshold = 0.3f; - lag_estimates_[n] = - LagEstimate(error_sum_anchor - error_sum, - error_sum < kMatchingFilterThreshold * error_sum_anchor, - lag_estimate + alignment_shift, filters_updated); + const float kMatchingFilterThreshold = 0.1f; + lag_estimates_[n] = LagEstimate( + error_sum_anchor - error_sum, + (lag_estimate > 2 && lag_estimate < (filters_[n].size() - 10) && + error_sum < kMatchingFilterThreshold * error_sum_anchor), + lag_estimate + alignment_shift, filters_updated); // TODO(peah): Remove once development of EchoCanceller3 is fully done. RTC_DCHECK_EQ(4, filters_.size()); diff --git a/webrtc/modules/audio_processing/aec3/matched_filter.h b/webrtc/modules/audio_processing/aec3/matched_filter.h index 3e09d4b971..4be4cc2d59 100644 --- a/webrtc/modules/audio_processing/aec3/matched_filter.h +++ b/webrtc/modules/audio_processing/aec3/matched_filter.h @@ -17,9 +17,34 @@ #include "webrtc/base/constructormagic.h" #include "webrtc/base/optional.h" -#include "webrtc/modules/audio_processing/aec3/aec3_constants.h" +#include "webrtc/modules/audio_processing/aec3/aec3_common.h" namespace webrtc { +namespace aec3 { + +#if defined(WEBRTC_ARCH_X86_FAMILY) + +// Filter core for the matched filter that is optimized for SSE2. +void MatchedFilterCore_SSE2(size_t x_start_index, + float x2_sum_threshold, + rtc::ArrayView x, + rtc::ArrayView y, + rtc::ArrayView h, + bool* filters_updated, + float* error_sum); + +#endif + +// Filter core for the matched filter. +void MatchedFilterCore(size_t x_start_index, + float x2_sum_threshold, + rtc::ArrayView x, + rtc::ArrayView y, + rtc::ArrayView h, + bool* filters_updated, + float* error_sum); + +} // namespace aec3 class ApmDataDumper; @@ -41,6 +66,7 @@ class MatchedFilter { }; MatchedFilter(ApmDataDumper* data_dumper, + Aec3Optimization optimization, size_t window_size_sub_blocks, int num_matched_filters, size_t alignment_shift_sub_blocks); @@ -71,6 +97,7 @@ class MatchedFilter { }; ApmDataDumper* const data_dumper_; + const Aec3Optimization optimization_; const size_t filter_intra_lag_shift_; std::vector> filters_; std::vector lag_estimates_; diff --git a/webrtc/modules/audio_processing/aec3/matched_filter_lag_aggregator.cc b/webrtc/modules/audio_processing/aec3/matched_filter_lag_aggregator.cc index d9176efa94..3734ed80e0 100644 --- a/webrtc/modules/audio_processing/aec3/matched_filter_lag_aggregator.cc +++ b/webrtc/modules/audio_processing/aec3/matched_filter_lag_aggregator.cc @@ -59,7 +59,7 @@ rtc::Optional MatchedFilterLagAggregator::Aggregate( candidate_ = lag_estimates[best_lag_estimate_index].lag; } - return candidate_counter_ >= 10 ? rtc::Optional(candidate_) + return candidate_counter_ >= 15 ? rtc::Optional(candidate_) : rtc::Optional(); } diff --git a/webrtc/modules/audio_processing/aec3/matched_filter_lag_aggregator_unittest.cc b/webrtc/modules/audio_processing/aec3/matched_filter_lag_aggregator_unittest.cc index b76116ba0b..a0b86e55a4 100644 --- a/webrtc/modules/audio_processing/aec3/matched_filter_lag_aggregator_unittest.cc +++ b/webrtc/modules/audio_processing/aec3/matched_filter_lag_aggregator_unittest.cc @@ -15,7 +15,7 @@ #include #include "webrtc/base/array_view.h" -#include "webrtc/modules/audio_processing/aec3/aec3_constants.h" +#include "webrtc/modules/audio_processing/aec3/aec3_common.h" #include "webrtc/modules/audio_processing/logging/apm_data_dumper.h" #include "webrtc/test/gtest.h" @@ -32,7 +32,7 @@ void VerifyNoAggregateOutputForRepeatedLagAggregation( } constexpr size_t kThresholdForRequiredLagUpdatesInARow = 10; -constexpr size_t kThresholdForRequiredIdenticalLagAggregates = 10; +constexpr size_t kThresholdForRequiredIdenticalLagAggregates = 15; } // namespace diff --git a/webrtc/modules/audio_processing/aec3/matched_filter_unittest.cc b/webrtc/modules/audio_processing/aec3/matched_filter_unittest.cc index 993ebc8b92..952b3710fe 100644 --- a/webrtc/modules/audio_processing/aec3/matched_filter_unittest.cc +++ b/webrtc/modules/audio_processing/aec3/matched_filter_unittest.cc @@ -10,16 +10,23 @@ #include "webrtc/modules/audio_processing/aec3/matched_filter.h" +#include "webrtc/typedefs.h" +#if defined(WEBRTC_ARCH_X86_FAMILY) +#include +#endif #include #include #include -#include "webrtc/modules/audio_processing/aec3/aec3_constants.h" +#include "webrtc/base/random.h" +#include "webrtc/modules/audio_processing/aec3/aec3_common.h" #include "webrtc/modules/audio_processing/logging/apm_data_dumper.h" #include "webrtc/modules/audio_processing/test/echo_canceller_test_tools.h" +#include "webrtc/system_wrappers/include/cpu_features_wrapper.h" #include "webrtc/test/gtest.h" namespace webrtc { +namespace aec3 { namespace { std::string ProduceDebugText(size_t delay) { @@ -34,6 +41,47 @@ constexpr size_t kNumMatchedFilters = 4; } // namespace +#if defined(WEBRTC_ARCH_X86_FAMILY) +// Verifies that the optimized methods are bitexact to their reference +// counterparts. +TEST(MatchedFilter, TestOptimizations) { + bool use_sse2 = (WebRtc_GetCPUInfo(kSSE2) != 0); + if (use_sse2) { + Random random_generator(42U); + std::vector x(2000); + RandomizeSampleVector(&random_generator, x); + std::vector y(kSubBlockSize); + std::vector h_SSE2(512); + std::vector h(512); + int x_index = 0; + for (int k = 0; k < 1000; ++k) { + RandomizeSampleVector(&random_generator, y); + + bool filters_updated = false; + float error_sum = 0.f; + bool filters_updated_SSE2 = false; + float error_sum_SSE2 = 0.f; + + MatchedFilterCore_SSE2(x_index, h.size() * 150.f * 150.f, x, y, h_SSE2, + &filters_updated_SSE2, &error_sum_SSE2); + + MatchedFilterCore(x_index, h.size() * 150.f * 150.f, x, y, h, + &filters_updated, &error_sum); + + EXPECT_EQ(filters_updated, filters_updated_SSE2); + EXPECT_NEAR(error_sum, error_sum_SSE2, error_sum / 100000.f); + + for (size_t j = 0; j < h.size(); ++j) { + EXPECT_NEAR(h[j], h_SSE2[j], 0.001f); + } + + x_index = (x_index + kSubBlockSize) % x.size(); + } + } +} + +#endif + // Verifies that the matched filter produces proper lag estimates for // artificially // delayed signals. @@ -44,10 +92,11 @@ TEST(MatchedFilter, LagEstimation) { render.fill(0.f); capture.fill(0.f); ApmDataDumper data_dumper(0); - for (size_t delay_samples : {0, 64, 150, 200, 800, 1000}) { + for (size_t delay_samples : {5, 64, 150, 200, 800, 1000}) { SCOPED_TRACE(ProduceDebugText(delay_samples)); DelayBuffer signal_delay_buffer(delay_samples); - MatchedFilter filter(&data_dumper, kWindowSizeSubBlocks, kNumMatchedFilters, + MatchedFilter filter(&data_dumper, DetectOptimization(), + kWindowSizeSubBlocks, kNumMatchedFilters, kAlignmentShiftSubBlocks); // Analyze the correlation between render and capture. @@ -107,8 +156,8 @@ TEST(MatchedFilter, LagNotReliableForUncorrelatedRenderAndCapture) { render.fill(0.f); capture.fill(0.f); ApmDataDumper data_dumper(0); - MatchedFilter filter(&data_dumper, kWindowSizeSubBlocks, kNumMatchedFilters, - kAlignmentShiftSubBlocks); + MatchedFilter filter(&data_dumper, DetectOptimization(), kWindowSizeSubBlocks, + kNumMatchedFilters, kAlignmentShiftSubBlocks); // Analyze the correlation between render and capture. for (size_t k = 0; k < 100; ++k) { @@ -136,8 +185,8 @@ TEST(MatchedFilter, LagNotUpdatedForLowLevelRender) { render.fill(0.f); capture.fill(0.f); ApmDataDumper data_dumper(0); - MatchedFilter filter(&data_dumper, kWindowSizeSubBlocks, kNumMatchedFilters, - kAlignmentShiftSubBlocks); + MatchedFilter filter(&data_dumper, DetectOptimization(), kWindowSizeSubBlocks, + kNumMatchedFilters, kAlignmentShiftSubBlocks); // Analyze the correlation between render and capture. for (size_t k = 0; k < 100; ++k) { @@ -167,7 +216,8 @@ TEST(MatchedFilter, NumberOfLagEstimates) { ApmDataDumper data_dumper(0); for (size_t num_matched_filters = 0; num_matched_filters < 10; ++num_matched_filters) { - MatchedFilter filter(&data_dumper, 32, num_matched_filters, 1); + MatchedFilter filter(&data_dumper, DetectOptimization(), 32, + num_matched_filters, 1); EXPECT_EQ(num_matched_filters, filter.GetLagEstimates().size()); } } @@ -177,14 +227,15 @@ TEST(MatchedFilter, NumberOfLagEstimates) { // Verifies the check for non-zero windows size. TEST(MatchedFilter, ZeroWindowSize) { ApmDataDumper data_dumper(0); - EXPECT_DEATH(MatchedFilter(&data_dumper, 0, 1, 1), ""); + EXPECT_DEATH(MatchedFilter(&data_dumper, DetectOptimization(), 0, 1, 1), ""); } // Verifies the check for non-null data dumper. TEST(MatchedFilter, NullDataDumper) { - EXPECT_DEATH(MatchedFilter(nullptr, 1, 1, 1), ""); + EXPECT_DEATH(MatchedFilter(nullptr, DetectOptimization(), 1, 1, 1), ""); } #endif +} // namespace aec3 } // namespace webrtc diff --git a/webrtc/modules/audio_processing/aec3/mock/mock_render_delay_buffer.h b/webrtc/modules/audio_processing/aec3/mock/mock_render_delay_buffer.h index 3b17dbe2fe..93c8e0d1c9 100644 --- a/webrtc/modules/audio_processing/aec3/mock/mock_render_delay_buffer.h +++ b/webrtc/modules/audio_processing/aec3/mock/mock_render_delay_buffer.h @@ -13,7 +13,7 @@ #include -#include "webrtc/modules/audio_processing/aec3/aec3_constants.h" +#include "webrtc/modules/audio_processing/aec3/aec3_common.h" #include "webrtc/modules/audio_processing/aec3/render_delay_buffer.h" #include "webrtc/test/gmock.h" diff --git a/webrtc/modules/audio_processing/aec3/output_selector.cc b/webrtc/modules/audio_processing/aec3/output_selector.cc new file mode 100644 index 0000000000..a8700cbe3e --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/output_selector.cc @@ -0,0 +1,72 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/modules/audio_processing/aec3/output_selector.h" + +#include +#include + +#include "webrtc/base/checks.h" + +namespace webrtc { +namespace { + +// Performs the transition between the signals in a smooth manner. +void SmoothFrameTransition(bool from_y_to_e, + rtc::ArrayView e, + rtc::ArrayView y) { + RTC_DCHECK_LT(0u, e.size()); + RTC_DCHECK_EQ(y.size(), e.size()); + + const float change_factor = (from_y_to_e ? 1.f : -1.f) / e.size(); + float averaging = from_y_to_e ? 0.f : 1.f; + for (size_t k = 0; k < e.size(); ++k) { + y[k] += averaging * (e[k] - y[k]); + averaging += change_factor; + } + RTC_DCHECK_EQ(from_y_to_e ? 1.f : 0.f, averaging); +} + +float BlockPower(rtc::ArrayView x) { + return std::accumulate(x.begin(), x.end(), 0.f, + [](float a, float b) -> float { return a + b * b; }); +} + +} // namespace + +OutputSelector::OutputSelector() = default; + +OutputSelector::~OutputSelector() = default; + +void OutputSelector::FormLinearOutput( + rtc::ArrayView subtractor_output, + rtc::ArrayView capture) { + RTC_DCHECK_EQ(subtractor_output.size(), capture.size()); + rtc::ArrayView& e_main = subtractor_output; + rtc::ArrayView y = capture; + + const bool subtractor_output_is_best = + BlockPower(y) > 1.5f * BlockPower(e_main); + output_change_counter_ = subtractor_output_is_best != use_subtractor_output_ + ? output_change_counter_ + 1 + : 0; + + if (subtractor_output_is_best != use_subtractor_output_ && + ((subtractor_output_is_best && output_change_counter_ > 3) || + (!subtractor_output_is_best && output_change_counter_ > 10))) { + use_subtractor_output_ = subtractor_output_is_best; + SmoothFrameTransition(use_subtractor_output_, e_main, y); + output_change_counter_ = 0; + } else if (use_subtractor_output_) { + std::copy(e_main.begin(), e_main.end(), y.begin()); + } +} + +} // namespace webrtc diff --git a/webrtc/modules/audio_processing/aec3/output_selector.h b/webrtc/modules/audio_processing/aec3/output_selector.h new file mode 100644 index 0000000000..943e547cde --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/output_selector.h @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_OUTPUT_SELECTOR_H_ +#define WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_OUTPUT_SELECTOR_H_ + +#include "webrtc/base/array_view.h" +#include "webrtc/base/constructormagic.h" + +namespace webrtc { + +// Performs the selection between which of the linear aec output and the +// microphone signal should be used as the echo suppressor output. +class OutputSelector { + public: + OutputSelector(); + ~OutputSelector(); + + // Forms the most appropriate output signal. + void FormLinearOutput(rtc::ArrayView subtractor_output, + rtc::ArrayView capture); + + // Returns true if the linear aec output is the one used. + bool UseSubtractorOutput() const { return use_subtractor_output_; } + + private: + bool use_subtractor_output_ = false; + int output_change_counter_ = 0; + RTC_DISALLOW_COPY_AND_ASSIGN(OutputSelector); +}; + +} // namespace webrtc + +#endif // WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_OUTPUT_SELECTOR_H_ diff --git a/webrtc/modules/audio_processing/aec3/output_selector_unittest.cc b/webrtc/modules/audio_processing/aec3/output_selector_unittest.cc new file mode 100644 index 0000000000..49f671d2b1 --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/output_selector_unittest.cc @@ -0,0 +1,71 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/modules/audio_processing/aec3/output_selector.h" + +#include +#include + +#include "webrtc/modules/audio_processing/aec3/aec3_common.h" +#include "webrtc/test/gtest.h" + +namespace webrtc { + +// Verifies that the switching between the signals in the output works as +// intended. +TEST(OutputSelector, ProperSwitching) { + OutputSelector selector; + + constexpr int kNumBlocksToSwitchToSubtractor = 3; + constexpr int kNumBlocksToSwitchFromSubtractor = 10; + + std::array weaker; + std::array stronger; + std::array y; + std::array e; + weaker.fill(10.f); + stronger.fill(20.f); + + bool y_is_weakest = false; + + const auto form_e_and_y = [&](bool y_equals_weaker) { + if (y_equals_weaker) { + std::copy(weaker.begin(), weaker.end(), y.begin()); + std::copy(stronger.begin(), stronger.end(), e.begin()); + } else { + std::copy(stronger.begin(), stronger.end(), y.begin()); + std::copy(weaker.begin(), weaker.end(), e.begin()); + } + }; + + for (int k = 0; k < 30; ++k) { + // Verify that it takes a while for the signals transition to take effect. + const int num_blocks_to_switch = y_is_weakest + ? kNumBlocksToSwitchFromSubtractor + : kNumBlocksToSwitchToSubtractor; + for (int j = 0; j < num_blocks_to_switch; ++j) { + form_e_and_y(y_is_weakest); + selector.FormLinearOutput(e, y); + EXPECT_EQ(stronger, y); + EXPECT_EQ(y_is_weakest, selector.UseSubtractorOutput()); + } + + // Verify that the transition block is a mix between the signals. + form_e_and_y(y_is_weakest); + selector.FormLinearOutput(e, y); + EXPECT_NE(weaker, y); + EXPECT_NE(stronger, y); + EXPECT_EQ(!y_is_weakest, selector.UseSubtractorOutput()); + + y_is_weakest = !y_is_weakest; + } +} + +} // namespace webrtc diff --git a/webrtc/modules/audio_processing/aec3/power_echo_model.cc b/webrtc/modules/audio_processing/aec3/power_echo_model.cc new file mode 100644 index 0000000000..8ad5486e07 --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/power_echo_model.cc @@ -0,0 +1,111 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "webrtc/modules/audio_processing/aec3/power_echo_model.h" + +#include +#include + +#include "webrtc/base/optional.h" + +namespace webrtc { +namespace { + +// Computes the spectral power over that last 20 frames. +void RecentMaximum(const FftBuffer& X_buffer, + std::array* R2) { + R2->fill(0.f); + for (size_t j = 0; j < 20; ++j) { + std::transform(R2->begin(), R2->end(), X_buffer.Spectrum(j).begin(), + R2->begin(), + [](float a, float b) { return std::max(a, b); }); + } +} + +constexpr float kHInitial = 10.f; +constexpr int kUpdateCounterInitial = 300; + +} // namespace + +PowerEchoModel::PowerEchoModel() { + H2_.fill(CountedFloat(kHInitial, kUpdateCounterInitial)); +} + +PowerEchoModel::~PowerEchoModel() = default; + +void PowerEchoModel::HandleEchoPathChange( + const EchoPathVariability& variability) { + if (variability.gain_change) { + H2_.fill(CountedFloat(kHInitial, kUpdateCounterInitial)); + } +} + +void PowerEchoModel::EstimateEcho( + const FftBuffer& render_buffer, + const std::array& capture_spectrum, + const AecState& aec_state, + std::array* echo_spectrum) { + RTC_DCHECK(echo_spectrum); + + const FftBuffer& X_buffer = render_buffer; + const auto& Y2 = capture_spectrum; + std::array* S2 = echo_spectrum; + + // Choose delay to use. + const rtc::Optional delay = + aec_state.FilterDelay() + ? aec_state.FilterDelay() + : (aec_state.ExternalDelay() ? rtc::Optional(std::min( + *aec_state.ExternalDelay(), + X_buffer.Buffer().size() - 1)) + : rtc::Optional()); + + // Compute R2. + std::array render_max; + if (!delay) { + RecentMaximum(render_buffer, &render_max); + } + const std::array& X2_active = + delay ? render_buffer.Spectrum(*delay) : render_max; + + if (!aec_state.SaturatedCapture()) { + // Corresponds of WGN of power -46dBFS. + constexpr float kX2Min = 44015068.0f; + const int max_update_counter_value = delay ? 300 : 500; + + std::array new_H2; + + // new_H2 = Y2 / X2. + std::transform(X2_active.begin(), X2_active.end(), Y2.begin(), + new_H2.begin(), + [&](float a, float b) { return a > kX2Min ? b / a : -1.f; }); + + // Lambda for updating H2 in a maximum statistics manner. + auto H2_updater = [&](float a, CountedFloat b) { + if (a > 0) { + if (a > b.value) { + b.counter = max_update_counter_value; + b.value = a; + } else if (--b.counter <= 0) { + b.value = std::max(b.value * 0.9f, 1.f); + } + } + return b; + }; + + std::transform(new_H2.begin(), new_H2.end(), H2_.begin(), H2_.begin(), + H2_updater); + } + + // S2 = H2*X2_active. + std::transform(H2_.begin(), H2_.end(), X2_active.begin(), S2->begin(), + [](CountedFloat a, float b) { return a.value * b; }); +} + +} // namespace webrtc diff --git a/webrtc/modules/audio_processing/aec3/power_echo_model.h b/webrtc/modules/audio_processing/aec3/power_echo_model.h new file mode 100644 index 0000000000..8df82f0982 --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/power_echo_model.h @@ -0,0 +1,61 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_POWER_ECHO_MODEL_H_ +#define WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_POWER_ECHO_MODEL_H_ + +#include + +#include "webrtc/base/constructormagic.h" +#include "webrtc/base/optional.h" +#include "webrtc/modules/audio_processing/aec3/aec3_common.h" +#include "webrtc/modules/audio_processing/aec3/aec_state.h" +#include "webrtc/modules/audio_processing/aec3/echo_path_variability.h" +#include "webrtc/modules/audio_processing/aec3/fft_buffer.h" + +namespace webrtc { + +// Provides an echo model based on power spectral estimates that estimates the +// echo spectrum. +class PowerEchoModel { + public: + PowerEchoModel(); + ~PowerEchoModel(); + + // Ajusts the model according to echo path changes. + void HandleEchoPathChange(const EchoPathVariability& variability); + + // Updates the echo model and estimates the echo spectrum. + void EstimateEcho( + const FftBuffer& render_buffer, + const std::array& capture_spectrum, + const AecState& aec_state, + std::array* echo_spectrum); + + // Returns the minimum required farend buffer length. + size_t MinFarendBufferLength() const { return kRenderBufferSize; } + + private: + // Provides a float value that is coupled with a counter. + struct CountedFloat { + CountedFloat() : value(0.f), counter(0) {} + CountedFloat(float value, int counter) : value(value), counter(counter) {} + float value; + int counter; + }; + + const size_t kRenderBufferSize = 100; + std::array H2_; + + RTC_DISALLOW_COPY_AND_ASSIGN(PowerEchoModel); +}; +} // namespace webrtc + +#endif // WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_POWER_ECHO_MODEL_H_ diff --git a/webrtc/modules/audio_processing/aec3/power_echo_model_unittest.cc b/webrtc/modules/audio_processing/aec3/power_echo_model_unittest.cc new file mode 100644 index 0000000000..019be60389 --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/power_echo_model_unittest.cc @@ -0,0 +1,133 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/modules/audio_processing/aec3/power_echo_model.h" + +#include +#include +#include + +#include "webrtc/base/random.h" +#include "webrtc/modules/audio_processing/aec3/aec_state.h" +#include "webrtc/modules/audio_processing/aec3/aec3_common.h" +#include "webrtc/modules/audio_processing/aec3/aec3_fft.h" +#include "webrtc/modules/audio_processing/aec3/echo_path_variability.h" +#include "webrtc/modules/audio_processing/test/echo_canceller_test_tools.h" + +#include "webrtc/test/gtest.h" + +namespace webrtc { +namespace { + +std::string ProduceDebugText(size_t delay, bool known_delay) { + std::ostringstream ss; + ss << "True delay: " << delay; + ss << ", Delay known: " << (known_delay ? "true" : "false"); + return ss.str(); +} + +} // namespace + +#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) + +// Verifies that the check for non-null output parameter works. +TEST(PowerEchoModel, NullEstimateEchoOutput) { + PowerEchoModel model; + std::array Y2; + AecState aec_state; + FftBuffer X_buffer(Aec3Optimization::kNone, model.MinFarendBufferLength(), + std::vector(1, model.MinFarendBufferLength())); + + EXPECT_DEATH(model.EstimateEcho(X_buffer, Y2, aec_state, nullptr), ""); +} + +#endif + +TEST(PowerEchoModel, BasicSetup) { + PowerEchoModel model; + Random random_generator(42U); + AecState aec_state; + Aec3Fft fft; + std::array Y2; + std::array S2; + std::array E2_main; + std::array E2_shadow; + std::array x_old; + std::array y; + std::vector x(kBlockSize, 0.f); + FftData X; + FftData Y; + x_old.fill(0.f); + + FftBuffer X_buffer(Aec3Optimization::kNone, model.MinFarendBufferLength(), + std::vector(1, model.MinFarendBufferLength())); + + for (size_t delay_samples : {0, 64, 301}) { + DelayBuffer delay_buffer(delay_samples); + auto model_applier = [&](int num_iterations, float y_scale, + bool known_delay) { + for (int k = 0; k < num_iterations; ++k) { + RandomizeSampleVector(&random_generator, x); + delay_buffer.Delay(x, y); + std::for_each(y.begin(), y.end(), [&](float& a) { a *= y_scale; }); + + fft.PaddedFft(x, x_old, &X); + X_buffer.Insert(X); + + fft.ZeroPaddedFft(y, &Y); + Y.Spectrum(Aec3Optimization::kNone, &Y2); + + aec_state.Update(std::vector>( + 10, std::array()), + known_delay ? rtc::Optional(delay_samples) + : rtc::Optional(), + X_buffer, E2_main, E2_shadow, Y2, x, + EchoPathVariability(false, false), false); + + model.EstimateEcho(X_buffer, Y2, aec_state, &S2); + } + }; + + for (int j = 0; j < 2; ++j) { + bool known_delay = j == 0; + SCOPED_TRACE(ProduceDebugText(delay_samples, known_delay)); + // Verify that the echo path estimates converges downwards to a fairly + // tight bound estimate. + model_applier(600, 1.f, known_delay); + for (size_t k = 1; k < S2.size() - 1; ++k) { + EXPECT_LE(Y2[k], 2.f * S2[k]); + } + + // Verify that stronger echo paths are detected immediately. + model_applier(100, 10.f, known_delay); + for (size_t k = 1; k < S2.size() - 1; ++k) { + EXPECT_LE(Y2[k], 5.f * S2[k]); + } + + // Verify that there is a delay until a weaker echo path is detected. + model_applier(50, 100.f, known_delay); + model_applier(50, 1.f, known_delay); + for (size_t k = 1; k < S2.size() - 1; ++k) { + EXPECT_LE(100.f * Y2[k], S2[k]); + } + + // Verify that an echo path change causes the echo path estimate to be + // reset. + model_applier(600, 0.1f, known_delay); + model.HandleEchoPathChange(EchoPathVariability(true, false)); + model_applier(50, 0.1f, known_delay); + for (size_t k = 1; k < S2.size() - 1; ++k) { + EXPECT_LE(10.f * Y2[k], S2[k]); + } + } + } +} + +} // namespace webrtc diff --git a/webrtc/modules/audio_processing/aec3/render_delay_buffer.cc b/webrtc/modules/audio_processing/aec3/render_delay_buffer.cc index 8fc6b926b1..d3c2aaabd9 100644 --- a/webrtc/modules/audio_processing/aec3/render_delay_buffer.cc +++ b/webrtc/modules/audio_processing/aec3/render_delay_buffer.cc @@ -15,7 +15,7 @@ #include "webrtc/base/checks.h" #include "webrtc/base/constructormagic.h" -#include "webrtc/modules/audio_processing/aec3/aec3_constants.h" +#include "webrtc/modules/audio_processing/aec3/aec3_common.h" #include "webrtc/system_wrappers/include/logging.h" namespace webrtc { diff --git a/webrtc/modules/audio_processing/aec3/render_delay_buffer_unittest.cc b/webrtc/modules/audio_processing/aec3/render_delay_buffer_unittest.cc index 448ba0b355..75a1b8ba3e 100644 --- a/webrtc/modules/audio_processing/aec3/render_delay_buffer_unittest.cc +++ b/webrtc/modules/audio_processing/aec3/render_delay_buffer_unittest.cc @@ -17,7 +17,7 @@ #include "webrtc/base/array_view.h" #include "webrtc/base/random.h" -#include "webrtc/modules/audio_processing/aec3/aec3_constants.h" +#include "webrtc/modules/audio_processing/aec3/aec3_common.h" #include "webrtc/modules/audio_processing/logging/apm_data_dumper.h" #include "webrtc/test/gtest.h" diff --git a/webrtc/modules/audio_processing/aec3/render_delay_controller.cc b/webrtc/modules/audio_processing/aec3/render_delay_controller.cc index 981b744eef..834e1e7641 100644 --- a/webrtc/modules/audio_processing/aec3/render_delay_controller.cc +++ b/webrtc/modules/audio_processing/aec3/render_delay_controller.cc @@ -16,7 +16,7 @@ #include "webrtc/base/atomicops.h" #include "webrtc/base/constructormagic.h" -#include "webrtc/modules/audio_processing/aec3/aec3_constants.h" +#include "webrtc/modules/audio_processing/aec3/aec3_common.h" #include "webrtc/modules/audio_processing/aec3/echo_path_delay_estimator.h" #include "webrtc/system_wrappers/include/logging.h" @@ -147,7 +147,7 @@ size_t RenderDelayControllerImpl::GetDelay( const int headroom = echo_path_delay_samples_ - delay_ * kBlockSize; RTC_DCHECK_LE(0, headroom); headroom_samples_ = rtc::Optional(headroom); - } else if (++blocks_since_last_delay_estimate_ > 25000) { + } else if (++blocks_since_last_delay_estimate_ > 250 * 20) { headroom_samples_ = rtc::Optional(); } diff --git a/webrtc/modules/audio_processing/aec3/render_delay_controller_unittest.cc b/webrtc/modules/audio_processing/aec3/render_delay_controller_unittest.cc index 9d382adb0f..169cb7e7fd 100644 --- a/webrtc/modules/audio_processing/aec3/render_delay_controller_unittest.cc +++ b/webrtc/modules/audio_processing/aec3/render_delay_controller_unittest.cc @@ -17,7 +17,7 @@ #include #include "webrtc/base/random.h" -#include "webrtc/modules/audio_processing/aec3/aec3_constants.h" +#include "webrtc/modules/audio_processing/aec3/aec3_common.h" #include "webrtc/modules/audio_processing/aec3/render_delay_buffer.h" #include "webrtc/modules/audio_processing/logging/apm_data_dumper.h" #include "webrtc/modules/audio_processing/test/echo_canceller_test_tools.h" @@ -33,16 +33,17 @@ std::string ProduceDebugText(int sample_rate_hz) { } std::string ProduceDebugText(int sample_rate_hz, size_t delay) { - std::ostringstream ss(ProduceDebugText(sample_rate_hz)); - ss << ", Delay: " << delay; + std::ostringstream ss; + ss << ProduceDebugText(sample_rate_hz) << ", Delay: " << delay; return ss.str(); } std::string ProduceDebugText(int sample_rate_hz, size_t delay, size_t max_jitter) { - std::ostringstream ss(ProduceDebugText(sample_rate_hz, delay)); - ss << ", Max Api call jitter: " << max_jitter; + std::ostringstream ss; + ss << ProduceDebugText(sample_rate_hz, delay) + << ", Max Api call jitter: " << max_jitter; return ss.str(); } @@ -111,7 +112,7 @@ TEST(RenderDelayController, Alignment) { std::vector capture_block(kBlockSize, 0.f); size_t delay_blocks = 0; for (auto rate : {8000, 16000, 32000, 48000}) { - for (size_t delay_samples : {0, 50, 150, 200, 800, 4000}) { + for (size_t delay_samples : {15, 50, 150, 200, 800, 4000}) { SCOPED_TRACE(ProduceDebugText(rate, delay_samples)); std::unique_ptr render_delay_buffer( RenderDelayBuffer::Create(250, NumBandsForRate(rate), @@ -119,7 +120,7 @@ TEST(RenderDelayController, Alignment) { std::unique_ptr delay_controller( RenderDelayController::Create(rate, *render_delay_buffer)); DelayBuffer signal_delay_buffer(delay_samples); - for (size_t k = 0; k < (300 + delay_samples / kBlockSize); ++k) { + for (size_t k = 0; k < (400 + delay_samples / kBlockSize); ++k) { RandomizeSampleVector(&random_generator, render_block); signal_delay_buffer.Delay(render_block, capture_block); EXPECT_TRUE(delay_controller->AnalyzeRender(render_block)); @@ -152,7 +153,7 @@ TEST(RenderDelayController, AlignmentWithJitter) { std::vector render_block(kBlockSize, 0.f); std::vector capture_block(kBlockSize, 0.f); for (auto rate : {8000, 16000, 32000, 48000}) { - for (size_t delay_samples : {0, 50, 800}) { + for (size_t delay_samples : {15, 50, 800}) { for (size_t max_jitter : {1, 9, 20}) { size_t delay_blocks = 0; SCOPED_TRACE(ProduceDebugText(rate, delay_samples, max_jitter)); diff --git a/webrtc/modules/audio_processing/aec3/render_signal_analyzer.cc b/webrtc/modules/audio_processing/aec3/render_signal_analyzer.cc new file mode 100644 index 0000000000..da1b571911 --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/render_signal_analyzer.cc @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/modules/audio_processing/aec3/render_signal_analyzer.h" + +#include + +#include "webrtc/base/checks.h" + +namespace webrtc { + +namespace { +constexpr size_t kCounterThreshold = 5; + +} // namespace + +RenderSignalAnalyzer::RenderSignalAnalyzer() { + narrow_band_counters_.fill(0); +} +RenderSignalAnalyzer::~RenderSignalAnalyzer() = default; + +void RenderSignalAnalyzer::Update( + const FftBuffer& X_buffer, + const rtc::Optional& delay_partitions) { + if (!delay_partitions) { + narrow_band_counters_.fill(0); + return; + } + + const std::array& X2 = + X_buffer.Spectrum(*delay_partitions); + + // Detect narrow band signal regions. + for (size_t k = 1; k < (X2.size() - 1); ++k) { + narrow_band_counters_[k - 1] = X2[k] > 3 * std::max(X2[k - 1], X2[k + 1]) + ? narrow_band_counters_[k - 1] + 1 + : 0; + } +} + +void RenderSignalAnalyzer::MaskRegionsAroundNarrowBands( + std::array* v) const { + RTC_DCHECK(v); + + // Set v to zero around narrow band signal regions. + if (narrow_band_counters_[0] > kCounterThreshold) { + (*v)[1] = (*v)[0] = 0.f; + } + for (size_t k = 2; k < kFftLengthBy2 - 1; ++k) { + if (narrow_band_counters_[k - 1] > kCounterThreshold) { + (*v)[k - 2] = (*v)[k - 1] = (*v)[k] = (*v)[k + 1] = (*v)[k + 2] = 0.f; + } + } + if (narrow_band_counters_[kFftLengthBy2 - 2] > kCounterThreshold) { + (*v)[kFftLengthBy2] = (*v)[kFftLengthBy2 - 1] = 0.f; + } +} + +} // namespace webrtc diff --git a/webrtc/modules/audio_processing/aec3/render_signal_analyzer.h b/webrtc/modules/audio_processing/aec3/render_signal_analyzer.h new file mode 100644 index 0000000000..1218fe6f93 --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/render_signal_analyzer.h @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_RENDER_SIGNAL_ANALYZER_H_ +#define WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_RENDER_SIGNAL_ANALYZER_H_ + +#include +#include + +#include "webrtc/base/constructormagic.h" +#include "webrtc/base/optional.h" +#include "webrtc/modules/audio_processing/aec3/aec3_common.h" +#include "webrtc/modules/audio_processing/aec3/fft_buffer.h" + +namespace webrtc { + +// Provides functionality for analyzing the properties of the render signal. +class RenderSignalAnalyzer { + public: + RenderSignalAnalyzer(); + ~RenderSignalAnalyzer(); + + // Updates the render signal analysis with the most recent render signal. + void Update(const FftBuffer& X_buffer, + const rtc::Optional& delay_partitions); + + // Returns true if the render signal is poorly exciting. + bool PoorSignalExcitation() const { + RTC_DCHECK_LT(2, narrow_band_counters_.size()); + return std::any_of(narrow_band_counters_.begin(), + narrow_band_counters_.end(), + [](size_t a) { return a > 10; }); + } + + // Zeros the array around regions with narrow bands signal characteristics. + void MaskRegionsAroundNarrowBands( + std::array* v) const; + + private: + std::array narrow_band_counters_; + + RTC_DISALLOW_COPY_AND_ASSIGN(RenderSignalAnalyzer); +}; + +} // namespace webrtc + +#endif // WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_RENDER_SIGNAL_ANALYZER_H_ diff --git a/webrtc/modules/audio_processing/aec3/render_signal_analyzer_unittest.cc b/webrtc/modules/audio_processing/aec3/render_signal_analyzer_unittest.cc new file mode 100644 index 0000000000..e8b462f048 --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/render_signal_analyzer_unittest.cc @@ -0,0 +1,123 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/modules/audio_processing/aec3/render_signal_analyzer.h" + +#include +#include +#include + +#include "webrtc/base/array_view.h" +#include "webrtc/base/random.h" +#include "webrtc/modules/audio_processing/aec3/aec3_common.h" +#include "webrtc/modules/audio_processing/aec3/aec3_fft.h" +#include "webrtc/modules/audio_processing/aec3/fft_buffer.h" +#include "webrtc/modules/audio_processing/aec3/fft_data.h" +#include "webrtc/modules/audio_processing/test/echo_canceller_test_tools.h" +#include "webrtc/test/gtest.h" + +namespace webrtc { +namespace { + +constexpr float kPi = 3.141592f; + +void ProduceSinusoid(int sample_rate_hz, + float sinusoidal_frequency_hz, + size_t* sample_counter, + rtc::ArrayView x) { + // Produce a sinusoid of the specified frequency. + for (size_t k = *sample_counter, j = 0; k < (*sample_counter + kBlockSize); + ++k, ++j) { + x[j] = + 32767.f * sin(2.f * kPi * sinusoidal_frequency_hz * k / sample_rate_hz); + } + *sample_counter = *sample_counter + kBlockSize; +} + +} // namespace + +#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) +// Verifies that the check for non-null output parameter works. +TEST(RenderSignalAnalyzer, NullMaskOutput) { + RenderSignalAnalyzer analyzer; + EXPECT_DEATH(analyzer.MaskRegionsAroundNarrowBands(nullptr), ""); +} + +#endif + +// Verify that no narrow bands are detected in a Gaussian noise signal. +TEST(RenderSignalAnalyzer, NoFalseDetectionOfNarrowBands) { + RenderSignalAnalyzer analyzer; + Random random_generator(42U); + std::vector x(kBlockSize, 0.f); + std::array x_old; + FftData X; + Aec3Fft fft; + FftBuffer X_buffer(Aec3Optimization::kNone, 1, std::vector(1, 1)); + std::array mask; + x_old.fill(0.f); + + for (size_t k = 0; k < 100; ++k) { + RandomizeSampleVector(&random_generator, x); + fft.PaddedFft(x, x_old, &X); + X_buffer.Insert(X); + analyzer.Update(X_buffer, rtc::Optional(0)); + } + + mask.fill(1.f); + analyzer.MaskRegionsAroundNarrowBands(&mask); + EXPECT_TRUE( + std::all_of(mask.begin(), mask.end(), [](float a) { return a == 1.f; })); + EXPECT_FALSE(analyzer.PoorSignalExcitation()); +} + +// Verify that a sinusiod signal is detected as narrow bands. +TEST(RenderSignalAnalyzer, NarrowBandDetection) { + RenderSignalAnalyzer analyzer; + Random random_generator(42U); + std::vector x(kBlockSize, 0.f); + std::array x_old; + FftData X; + Aec3Fft fft; + FftBuffer X_buffer(Aec3Optimization::kNone, 1, std::vector(1, 1)); + std::array mask; + x_old.fill(0.f); + constexpr int kSinusFrequencyBin = 32; + + auto generate_sinusoid_test = [&](bool known_delay) { + size_t sample_counter = 0; + for (size_t k = 0; k < 100; ++k) { + ProduceSinusoid(16000, 16000 / 2 * kSinusFrequencyBin / kFftLengthBy2, + &sample_counter, x); + fft.PaddedFft(x, x_old, &X); + X_buffer.Insert(X); + analyzer.Update( + X_buffer, + known_delay ? rtc::Optional(0) : rtc::Optional()); + } + }; + + generate_sinusoid_test(true); + mask.fill(1.f); + analyzer.MaskRegionsAroundNarrowBands(&mask); + for (int k = 0; k < static_cast(mask.size()); ++k) { + EXPECT_EQ(abs(k - kSinusFrequencyBin) <= 2 ? 0.f : 1.f, mask[k]); + } + EXPECT_TRUE(analyzer.PoorSignalExcitation()); + + // Verify that no bands are detected as narrow when the delay is unknown. + generate_sinusoid_test(false); + mask.fill(1.f); + analyzer.MaskRegionsAroundNarrowBands(&mask); + std::for_each(mask.begin(), mask.end(), [](float a) { EXPECT_EQ(1.f, a); }); + EXPECT_FALSE(analyzer.PoorSignalExcitation()); +} + +} // namespace webrtc diff --git a/webrtc/modules/audio_processing/aec3/residual_echo_estimator.cc b/webrtc/modules/audio_processing/aec3/residual_echo_estimator.cc new file mode 100644 index 0000000000..38d5beb5fb --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/residual_echo_estimator.cc @@ -0,0 +1,215 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/modules/audio_processing/aec3/residual_echo_estimator.h" + +#include +#include + +#include "webrtc/base/checks.h" + +namespace webrtc { +namespace { + +constexpr float kSaturationLeakageFactor = 10.f; +constexpr size_t kSaturationLeakageBlocks = 10; + +// Estimates the residual echo power when there is no detection correlation +// between the render and capture signals. +void InfiniteErlPowerEstimate( + size_t active_render_counter, + size_t blocks_since_last_saturation, + const std::array& S2_fallback, + std::array* R2) { + if (active_render_counter > 5 * 250) { + // After an amount of active render samples for which an echo should have + // been detected in the capture signal if the ERL was not infinite, set the + // residual echo to 0. + R2->fill(0.f); + } else { + // Before certainty has been reached about the presence of echo, use the + // fallback echo power estimate as the residual echo estimate. Add a leakage + // factor when there is saturation. + std::copy(S2_fallback.begin(), S2_fallback.end(), R2->begin()); + if (blocks_since_last_saturation < kSaturationLeakageBlocks) { + std::for_each(R2->begin(), R2->end(), + [](float& a) { a *= kSaturationLeakageFactor; }); + } + } +} + +// Estimates the echo power in an half-duplex manner. +void HalfDuplexPowerEstimate(bool active_render, + const std::array& Y2, + std::array* R2) { + // Set the residual echo power to the power of the capture signal. + if (active_render) { + std::copy(Y2.begin(), Y2.end(), R2->begin()); + } else { + R2->fill(0.f); + } +} + +// Estimates the residual echo power based on gains. +void GainBasedPowerEstimate( + size_t external_delay, + const FftBuffer& X_buffer, + size_t blocks_since_last_saturation, + const std::array& bands_with_reliable_filter, + const std::array& echo_path_gain, + const std::array& S2_fallback, + std::array* R2) { + const auto& X2 = X_buffer.Spectrum(external_delay); + + // Base the residual echo power on gain of the linear echo path estimate if + // that is reliable, otherwise use the fallback echo path estimate. Add a + // leakage factor when there is saturation. + for (size_t k = 0; k < R2->size(); ++k) { + (*R2)[k] = bands_with_reliable_filter[k] ? echo_path_gain[k] * X2[k] + : S2_fallback[k]; + } + if (blocks_since_last_saturation < kSaturationLeakageBlocks) { + std::for_each(R2->begin(), R2->end(), + [](float& a) { a *= kSaturationLeakageFactor; }); + } +} + +// Estimates the residual echo power based on the linear echo path. +void ErleBasedPowerEstimate( + bool headset_detected, + const FftBuffer& X_buffer, + bool using_subtractor_output, + size_t linear_filter_based_delay, + size_t blocks_since_last_saturation, + bool poorly_aligned_filter, + const std::array& bands_with_reliable_filter, + const std::array& echo_path_gain, + const std::array& S2_fallback, + const std::array& S2_linear, + const std::array& Y2, + const std::array& erle, + const std::array& erl, + std::array* R2) { + // Residual echo power after saturation. + if (blocks_since_last_saturation < kSaturationLeakageBlocks) { + for (size_t k = 0; k < R2->size(); ++k) { + (*R2)[k] = kSaturationLeakageFactor * + (bands_with_reliable_filter[k] && using_subtractor_output + ? S2_linear[k] + : std::min(S2_fallback[k], Y2[k])); + } + return; + } + + // Residual echo power when a headset is used. + if (headset_detected) { + const auto& X2 = X_buffer.Spectrum(linear_filter_based_delay); + for (size_t k = 0; k < R2->size(); ++k) { + RTC_DCHECK_LT(0.f, erle[k]); + (*R2)[k] = bands_with_reliable_filter[k] && using_subtractor_output + ? S2_linear[k] / erle[k] + : std::min(S2_fallback[k], Y2[k]); + (*R2)[k] = std::min((*R2)[k], X2[k] * erl[k]); + } + return; + } + + // Residual echo power when the adaptive filter is poorly aligned. + if (poorly_aligned_filter) { + for (size_t k = 0; k < R2->size(); ++k) { + (*R2)[k] = bands_with_reliable_filter[k] && using_subtractor_output + ? S2_linear[k] + : std::min(S2_fallback[k], Y2[k]); + } + return; + } + + // Residual echo power when there is no recent saturation, no headset detected + // and when the adaptive filter is well aligned. + for (size_t k = 0; k < R2->size(); ++k) { + RTC_DCHECK_LT(0.f, erle[k]); + const auto& X2 = X_buffer.Spectrum(linear_filter_based_delay); + (*R2)[k] = bands_with_reliable_filter[k] && using_subtractor_output + ? S2_linear[k] / erle[k] + : std::min(echo_path_gain[k] * X2[k], Y2[k]); + } +} + +} // namespace + +ResidualEchoEstimator::ResidualEchoEstimator() { + echo_path_gain_.fill(0.f); +} + +ResidualEchoEstimator::~ResidualEchoEstimator() = default; + +void ResidualEchoEstimator::Estimate( + bool using_subtractor_output, + const AecState& aec_state, + const FftBuffer& X_buffer, + const std::vector>& H2, + const std::array& E2_main, + const std::array& E2_shadow, + const std::array& S2_linear, + const std::array& S2_fallback, + const std::array& Y2, + std::array* R2) { + RTC_DCHECK(R2); + const rtc::Optional& linear_filter_based_delay = + aec_state.FilterDelay(); + + // Update the echo path gain. + if (linear_filter_based_delay) { + std::copy(H2[*linear_filter_based_delay].begin(), + H2[*linear_filter_based_delay].end(), echo_path_gain_.begin()); + } + + // Counts the blocks since saturation. + if (aec_state.SaturatedCapture()) { + blocks_since_last_saturation_ = 0; + } else { + ++blocks_since_last_saturation_; + } + + // Counts the number of active render blocks that are in a row. + if (aec_state.ActiveRender()) { + ++active_render_counter_; + } + + const auto& bands_with_reliable_filter = aec_state.BandsWithReliableFilter(); + + if (aec_state.UsableLinearEstimate()) { + // Residual echo power estimation when the adaptive filter is reliable. + RTC_DCHECK(linear_filter_based_delay); + ErleBasedPowerEstimate( + aec_state.HeadsetDetected(), X_buffer, using_subtractor_output, + *linear_filter_based_delay, blocks_since_last_saturation_, + aec_state.PoorlyAlignedFilter(), bands_with_reliable_filter, + echo_path_gain_, S2_fallback, S2_linear, Y2, aec_state.Erle(), + aec_state.Erl(), R2); + } else if (aec_state.ModelBasedAecFeasible()) { + // Residual echo power when the adaptive filter is not reliable but still an + // external echo path delay is provided (and hence can be estimated). + RTC_DCHECK(aec_state.ExternalDelay()); + GainBasedPowerEstimate( + *aec_state.ExternalDelay(), X_buffer, blocks_since_last_saturation_, + bands_with_reliable_filter, echo_path_gain_, S2_fallback, R2); + } else if (aec_state.EchoLeakageDetected()) { + // Residual echo power when an external residual echo detection algorithm + // has deemed the echo canceller to leak echoes. + HalfDuplexPowerEstimate(aec_state.ActiveRender(), Y2, R2); + } else { + // Residual echo power when none of the other cases are fulfilled. + InfiniteErlPowerEstimate(active_render_counter_, + blocks_since_last_saturation_, S2_fallback, R2); + } +} + +} // namespace webrtc diff --git a/webrtc/modules/audio_processing/aec3/residual_echo_estimator.h b/webrtc/modules/audio_processing/aec3/residual_echo_estimator.h new file mode 100644 index 0000000000..6c59f434b9 --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/residual_echo_estimator.h @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_RESIDUAL_ECHO_ESTIMATOR_H_ +#define WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_RESIDUAL_ECHO_ESTIMATOR_H_ + +#include +#include +#include + +#include "webrtc/base/array_view.h" +#include "webrtc/base/constructormagic.h" +#include "webrtc/modules/audio_processing/aec3/aec3_common.h" +#include "webrtc/modules/audio_processing/aec3/aec_state.h" +#include "webrtc/modules/audio_processing/aec3/fft_buffer.h" + +namespace webrtc { + +class ResidualEchoEstimator { + public: + ResidualEchoEstimator(); + ~ResidualEchoEstimator(); + + void Estimate(bool using_subtractor_output, + const AecState& aec_state, + const FftBuffer& X_buffer, + const std::vector>& H2, + const std::array& E2_main, + const std::array& E2_shadow, + const std::array& S2_linear, + const std::array& S2_fallback, + const std::array& Y2, + std::array* R2); + + private: + std::array echo_path_gain_; + size_t active_render_counter_ = 0; + size_t blocks_since_last_saturation_ = 1000; + + RTC_DISALLOW_COPY_AND_ASSIGN(ResidualEchoEstimator); +}; + +} // namespace webrtc + +#endif // WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_RESIDUAL_ECHO_ESTIMATOR_H_ diff --git a/webrtc/modules/audio_processing/aec3/residual_echo_estimator_unittest.cc b/webrtc/modules/audio_processing/aec3/residual_echo_estimator_unittest.cc new file mode 100644 index 0000000000..850ea4bed2 --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/residual_echo_estimator_unittest.cc @@ -0,0 +1,87 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/modules/audio_processing/aec3/residual_echo_estimator.h" + +#include "webrtc/base/random.h" +#include "webrtc/modules/audio_processing/aec3/aec_state.h" +#include "webrtc/modules/audio_processing/aec3/aec3_fft.h" +#include "webrtc/modules/audio_processing/test/echo_canceller_test_tools.h" +#include "webrtc/test/gtest.h" + +namespace webrtc { + +#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) + +// Verifies that the check for non-null output gains works. +TEST(ResidualEchoEstimator, NullOutputGains) { + AecState aec_state; + FftBuffer X_buffer(Aec3Optimization::kNone, 10, std::vector(1, 10)); + std::vector> H2; + std::array E2_main; + std::array E2_shadow; + std::array S2_linear; + std::array S2_fallback; + std::array Y2; + + EXPECT_DEATH(ResidualEchoEstimator().Estimate(true, aec_state, X_buffer, H2, + E2_main, E2_shadow, S2_linear, + S2_fallback, Y2, nullptr), + ""); +} + +#endif + +TEST(ResidualEchoEstimator, BasicTest) { + ResidualEchoEstimator estimator; + AecState aec_state; + FftBuffer X_buffer(Aec3Optimization::kNone, 10, std::vector(1, 10)); + std::array E2_main; + std::array E2_shadow; + std::array S2_linear; + std::array S2_fallback; + std::array Y2; + std::array R2; + EchoPathVariability echo_path_variability(false, false); + std::array x; + std::vector> H2(10); + Random random_generator(42U); + FftData X; + std::array x_old; + Aec3Fft fft; + + for (auto& H2_k : H2) { + H2_k.fill(0.01f); + } + H2[2].fill(10.f); + + constexpr float kLevel = 10.f; + E2_shadow.fill(kLevel); + E2_main.fill(kLevel); + S2_linear.fill(kLevel); + S2_fallback.fill(kLevel); + Y2.fill(kLevel); + + for (int k = 0; k < 100; ++k) { + RandomizeSampleVector(&random_generator, x); + fft.PaddedFft(x, x_old, &X); + X_buffer.Insert(X); + + aec_state.Update(H2, rtc::Optional(2), X_buffer, E2_main, E2_shadow, + Y2, x, echo_path_variability, false); + + estimator.Estimate(true, aec_state, X_buffer, H2, E2_main, E2_shadow, + S2_linear, S2_fallback, Y2, &R2); + } + std::for_each(R2.begin(), R2.end(), + [&](float a) { EXPECT_NEAR(kLevel, a, 0.1f); }); +} + +} // namespace webrtc diff --git a/webrtc/modules/audio_processing/aec3/shadow_filter_update_gain.cc b/webrtc/modules/audio_processing/aec3/shadow_filter_update_gain.cc new file mode 100644 index 0000000000..ed74b3ffb4 --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/shadow_filter_update_gain.cc @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/modules/audio_processing/aec3/shadow_filter_update_gain.h" + +#include +#include + +#include "webrtc/base/checks.h" + +namespace webrtc { + +void ShadowFilterUpdateGain::Compute( + const FftBuffer& X_buffer, + const RenderSignalAnalyzer& render_signal_analyzer, + const FftData& E_shadow, + size_t size_partitions, + bool saturated_capture_signal, + FftData* G) { + RTC_DCHECK(G); + ++call_counter_; + + if (render_signal_analyzer.PoorSignalExcitation()) { + poor_signal_excitation_counter_ = 0; + } + + // Do not update the filter if the render is not sufficiently excited. + if (++poor_signal_excitation_counter_ < size_partitions || + saturated_capture_signal || call_counter_ <= size_partitions) { + G->re.fill(0.f); + G->im.fill(0.f); + return; + } + + // Compute mu. + constexpr float kX2Min = 44015068.0f; + constexpr float kMuFixed = .5f; + std::array mu; + const auto& X2 = X_buffer.SpectralSum(size_partitions); + std::transform(X2.begin(), X2.end(), mu.begin(), + [&](float a) { return a > kX2Min ? kMuFixed / a : 0.f; }); + + // Avoid updating the filter close to narrow bands in the render signals. + render_signal_analyzer.MaskRegionsAroundNarrowBands(&mu); + + // G = mu * E * X2. + std::transform(mu.begin(), mu.end(), E_shadow.re.begin(), G->re.begin(), + std::multiplies()); + std::transform(mu.begin(), mu.end(), E_shadow.im.begin(), G->im.begin(), + std::multiplies()); +} + +} // namespace webrtc diff --git a/webrtc/modules/audio_processing/aec3/shadow_filter_update_gain.h b/webrtc/modules/audio_processing/aec3/shadow_filter_update_gain.h new file mode 100644 index 0000000000..0f414ff59f --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/shadow_filter_update_gain.h @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_SHADOW_FILTER_UPDATE_GAIN_H_ +#define WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_SHADOW_FILTER_UPDATE_GAIN_H_ + +#include "webrtc/base/constructormagic.h" +#include "webrtc/modules/audio_processing/aec3/aec3_common.h" +#include "webrtc/modules/audio_processing/aec3/fft_buffer.h" +#include "webrtc/modules/audio_processing/aec3/render_signal_analyzer.h" + +namespace webrtc { + +// Provides functionality for computing the fixed gain for the shadow filter. +class ShadowFilterUpdateGain { + public: + // Computes the gain. + void Compute(const FftBuffer& X_buffer, + const RenderSignalAnalyzer& render_signal_analyzer, + const FftData& E_shadow, + size_t size_partitions, + bool saturated_capture_signal, + FftData* G); + + private: + size_t poor_signal_excitation_counter_ = 0; + size_t call_counter_ = 0; +}; + +} // namespace webrtc + +#endif // WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_SHADOW_FILTER_UPDATE_GAIN_H_ diff --git a/webrtc/modules/audio_processing/aec3/shadow_filter_update_gain_unittest.cc b/webrtc/modules/audio_processing/aec3/shadow_filter_update_gain_unittest.cc new file mode 100644 index 0000000000..ab98eefebb --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/shadow_filter_update_gain_unittest.cc @@ -0,0 +1,187 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/modules/audio_processing/aec3/shadow_filter_update_gain.h" + +#include +#include +#include +#include + +#include "webrtc/base/random.h" +#include "webrtc/modules/audio_processing/aec3/adaptive_fir_filter.h" +#include "webrtc/modules/audio_processing/aec3/aec_state.h" +#include "webrtc/modules/audio_processing/aec3/aec3_common.h" +#include "webrtc/modules/audio_processing/test/echo_canceller_test_tools.h" +#include "webrtc/test/gtest.h" + +namespace webrtc { +namespace { + +// Method for performing the simulations needed to test the main filter update +// gain functionality. +void RunFilterUpdateTest(int num_blocks_to_process, + size_t delay_samples, + const std::vector& blocks_with_saturation, + std::array* e_last_block, + std::array* y_last_block, + FftData* G_last_block) { + ApmDataDumper data_dumper(42); + AdaptiveFirFilter main_filter(9, true, DetectOptimization(), &data_dumper); + AdaptiveFirFilter shadow_filter(9, true, DetectOptimization(), &data_dumper); + Aec3Fft fft; + FftBuffer X_buffer(Aec3Optimization::kNone, main_filter.SizePartitions(), + std::vector(1, main_filter.SizePartitions())); + std::array x_old; + x_old.fill(0.f); + ShadowFilterUpdateGain shadow_gain; + Random random_generator(42U); + std::vector x(kBlockSize, 0.f); + std::vector y(kBlockSize, 0.f); + AecState aec_state; + RenderSignalAnalyzer render_signal_analyzer; + FftData X; + std::array s; + FftData S; + FftData G; + FftData E_shadow; + std::array e_shadow; + + constexpr float kScale = 1.0f / kFftLengthBy2; + + DelayBuffer delay_buffer(delay_samples); + for (int k = 0; k < num_blocks_to_process; ++k) { + // Handle saturation. + bool saturation = + std::find(blocks_with_saturation.begin(), blocks_with_saturation.end(), + k) != blocks_with_saturation.end(); + + // Create the render signal. + RandomizeSampleVector(&random_generator, x); + delay_buffer.Delay(x, y); + fft.PaddedFft(x, x_old, &X); + X_buffer.Insert(X); + render_signal_analyzer.Update( + X_buffer, rtc::Optional(delay_samples / kBlockSize)); + + shadow_filter.Filter(X_buffer, &S); + fft.Ifft(S, &s); + std::transform(y.begin(), y.end(), s.begin() + kFftLengthBy2, + e_shadow.begin(), + [&](float a, float b) { return a - b * kScale; }); + std::for_each(e_shadow.begin(), e_shadow.end(), [](float& a) { + a = std::max(std::min(a, 32767.0f), -32768.0f); + }); + fft.ZeroPaddedFft(e_shadow, &E_shadow); + + shadow_gain.Compute(X_buffer, render_signal_analyzer, E_shadow, + shadow_filter.SizePartitions(), saturation, &G); + shadow_filter.Adapt(X_buffer, G); + } + + std::copy(e_shadow.begin(), e_shadow.end(), e_last_block->begin()); + std::copy(y.begin(), y.end(), y_last_block->begin()); + std::copy(G.re.begin(), G.re.end(), G_last_block->re.begin()); + std::copy(G.im.begin(), G.im.end(), G_last_block->im.begin()); +} + +std::string ProduceDebugText(size_t delay) { + std::ostringstream ss; + ss << ", Delay: " << delay; + return ss.str(); +} + +} // namespace + +#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) + +// Verifies that the check for non-null output gain parameter works. +TEST(ShadowFilterUpdateGain, NullDataOutputGain) { + ApmDataDumper data_dumper(42); + FftBuffer X_buffer(Aec3Optimization::kNone, 1, std::vector(1, 1)); + RenderSignalAnalyzer analyzer; + FftData E; + ShadowFilterUpdateGain gain; + EXPECT_DEATH(gain.Compute(X_buffer, analyzer, E, 1, false, nullptr), ""); +} + +#endif + +// Verifies that the gain formed causes the filter using it to converge. +TEST(ShadowFilterUpdateGain, GainCausesFilterToConverge) { + std::vector blocks_with_echo_path_changes; + std::vector blocks_with_saturation; + for (size_t delay_samples : {0, 64, 150, 200, 301}) { + SCOPED_TRACE(ProduceDebugText(delay_samples)); + + std::array e; + std::array y; + FftData G; + + RunFilterUpdateTest(500, delay_samples, blocks_with_saturation, &e, &y, &G); + + // Verify that the main filter is able to perform well. + EXPECT_LT(1000 * std::inner_product(e.begin(), e.end(), e.begin(), 0.f), + std::inner_product(y.begin(), y.end(), y.begin(), 0.f)); + } +} + +// Verifies that the magnitude of the gain on average decreases for a +// persistently exciting signal. +TEST(ShadowFilterUpdateGain, DecreasingGain) { + std::vector blocks_with_echo_path_changes; + std::vector blocks_with_saturation; + + std::array e; + std::array y; + FftData G_a; + FftData G_b; + FftData G_c; + std::array G_a_power; + std::array G_b_power; + std::array G_c_power; + + RunFilterUpdateTest(100, 65, blocks_with_saturation, &e, &y, &G_a); + RunFilterUpdateTest(200, 65, blocks_with_saturation, &e, &y, &G_b); + RunFilterUpdateTest(300, 65, blocks_with_saturation, &e, &y, &G_c); + + G_a.Spectrum(Aec3Optimization::kNone, &G_a_power); + G_b.Spectrum(Aec3Optimization::kNone, &G_b_power); + G_c.Spectrum(Aec3Optimization::kNone, &G_c_power); + + EXPECT_GT(std::accumulate(G_a_power.begin(), G_a_power.end(), 0.), + std::accumulate(G_b_power.begin(), G_b_power.end(), 0.)); + + EXPECT_GT(std::accumulate(G_b_power.begin(), G_b_power.end(), 0.), + std::accumulate(G_c_power.begin(), G_c_power.end(), 0.)); +} + +// Verifies that the gain is zero when there is saturation. +TEST(ShadowFilterUpdateGain, SaturationBehavior) { + std::vector blocks_with_echo_path_changes; + std::vector blocks_with_saturation; + for (int k = 99; k < 200; ++k) { + blocks_with_saturation.push_back(k); + } + + std::array e; + std::array y; + FftData G_a; + FftData G_a_ref; + G_a_ref.re.fill(0.f); + G_a_ref.im.fill(0.f); + + RunFilterUpdateTest(100, 65, blocks_with_saturation, &e, &y, &G_a); + + EXPECT_EQ(G_a_ref.re, G_a.re); + EXPECT_EQ(G_a_ref.im, G_a.im); +} + +} // namespace webrtc diff --git a/webrtc/modules/audio_processing/aec3/subtractor.cc b/webrtc/modules/audio_processing/aec3/subtractor.cc new file mode 100644 index 0000000000..2dcbbde0ba --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/subtractor.cc @@ -0,0 +1,117 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/modules/audio_processing/aec3/subtractor.h" + +#include + +#include "webrtc/base/array_view.h" +#include "webrtc/base/checks.h" +#include "webrtc/modules/audio_processing/logging/apm_data_dumper.h" + +namespace webrtc { + +namespace { + +void ComputeError(const Aec3Fft& fft, + const FftData& S, + rtc::ArrayView y, + std::array* e, + FftData* E) { + std::array s; + fft.Ifft(S, &s); + constexpr float kScale = 1.0f / kFftLengthBy2; + std::transform(y.begin(), y.end(), s.begin() + kFftLengthBy2, e->begin(), + [&](float a, float b) { return a - b * kScale; }); + std::for_each(e->begin(), e->end(), [](float& a) { + a = std::max(std::min(a, 32767.0f), -32768.0f); + }); + fft.ZeroPaddedFft(*e, E); +} +} // namespace + +std::vector Subtractor::NumBlocksInRenderSums() const { + if (kMainFilterSizePartitions != kShadowFilterSizePartitions) { + return {kMainFilterSizePartitions, kShadowFilterSizePartitions}; + } else { + return {kMainFilterSizePartitions}; + } +} + +Subtractor::Subtractor(ApmDataDumper* data_dumper, + Aec3Optimization optimization) + : data_dumper_(data_dumper), + optimization_(optimization), + main_filter_(kMainFilterSizePartitions, true, optimization, data_dumper_), + shadow_filter_(kShadowFilterSizePartitions, + false, + optimization, + data_dumper_) { + RTC_DCHECK(data_dumper_); +} + +Subtractor::~Subtractor() {} + +void Subtractor::HandleEchoPathChange( + const EchoPathVariability& echo_path_variability) { + if (echo_path_variability.delay_change) { + main_filter_.HandleEchoPathChange(); + shadow_filter_.HandleEchoPathChange(); + G_main_.HandleEchoPathChange(); + } +} + +void Subtractor::Process(const FftBuffer& render_buffer, + const rtc::ArrayView capture, + const RenderSignalAnalyzer& render_signal_analyzer, + bool saturation, + SubtractorOutput* output) { + RTC_DCHECK_EQ(kBlockSize, capture.size()); + rtc::ArrayView y = capture; + const FftBuffer& X_buffer = render_buffer; + FftData& E_main = output->E_main; + FftData& E_shadow = output->E_shadow; + std::array& e_main = output->e_main; + std::array& e_shadow = output->e_shadow; + + FftData S; + FftData& G = S; + + // Form and analyze the output of the main filter. + main_filter_.Filter(X_buffer, &S); + ComputeError(fft_, S, y, &e_main, &E_main); + + // Form and analyze the output of the shadow filter. + shadow_filter_.Filter(X_buffer, &S); + ComputeError(fft_, S, y, &e_shadow, &E_shadow); + + // Compute spectra for future use. + E_main.Spectrum(optimization_, &output->E2_main); + E_shadow.Spectrum(optimization_, &output->E2_shadow); + + // Update the main filter. + G_main_.Compute(X_buffer, render_signal_analyzer, *output, main_filter_, + saturation, &G); + main_filter_.Adapt(X_buffer, G); + data_dumper_->DumpRaw("aec3_subtractor_G_main", G.re); + data_dumper_->DumpRaw("aec3_subtractor_G_main", G.im); + + // Update the shadow filter. + G_shadow_.Compute(X_buffer, render_signal_analyzer, E_shadow, + shadow_filter_.SizePartitions(), saturation, &G); + shadow_filter_.Adapt(X_buffer, G); + data_dumper_->DumpRaw("aec3_subtractor_G_shadow", G.re); + data_dumper_->DumpRaw("aec3_subtractor_G_shadow", G.im); + + main_filter_.DumpFilter("aec3_subtractor_H_main"); + shadow_filter_.DumpFilter("aec3_subtractor_H_shadow"); +} + +} // namespace webrtc diff --git a/webrtc/modules/audio_processing/aec3/subtractor.h b/webrtc/modules/audio_processing/aec3/subtractor.h new file mode 100644 index 0000000000..742e57c7fe --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/subtractor.h @@ -0,0 +1,79 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_SUBTRACTOR_H_ +#define WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_SUBTRACTOR_H_ + +#include +#include +#include + +#include "webrtc/base/constructormagic.h" +#include "webrtc/modules/audio_processing/aec3/adaptive_fir_filter.h" +#include "webrtc/modules/audio_processing/aec3/aec3_common.h" +#include "webrtc/modules/audio_processing/aec3/aec3_fft.h" +#include "webrtc/modules/audio_processing/aec3/echo_path_variability.h" +#include "webrtc/modules/audio_processing/aec3/fft_buffer.h" +#include "webrtc/modules/audio_processing/aec3/main_filter_update_gain.h" +#include "webrtc/modules/audio_processing/aec3/shadow_filter_update_gain.h" +#include "webrtc/modules/audio_processing/aec3/subtractor_output.h" +#include "webrtc/modules/audio_processing/logging/apm_data_dumper.h" +#include "webrtc/modules/audio_processing/utility/ooura_fft.h" + +namespace webrtc { + +// Proves linear echo cancellation functionality +class Subtractor { + public: + Subtractor(ApmDataDumper* data_dumper, Aec3Optimization optimization); + ~Subtractor(); + + // Performs the echo subtraction. + void Process(const FftBuffer& render_buffer, + const rtc::ArrayView capture, + const RenderSignalAnalyzer& render_signal_analyzer, + bool saturation, + SubtractorOutput* output); + + // Returns a vector with the number of blocks included in the render buffer + // sums. + std::vector NumBlocksInRenderSums() const; + + // Returns the minimum required farend buffer length. + size_t MinFarendBufferLength() const { + return std::max(kMainFilterSizePartitions, kShadowFilterSizePartitions); + } + + void HandleEchoPathChange(const EchoPathVariability& echo_path_variability); + + // Returns the block-wise frequency response of the main adaptive filter. + const std::vector>& + FilterFrequencyResponse() const { + return main_filter_.FilterFrequencyResponse(); + } + + private: + const size_t kMainFilterSizePartitions = 12; + const size_t kShadowFilterSizePartitions = 12; + + const Aec3Fft fft_; + ApmDataDumper* data_dumper_; + const Aec3Optimization optimization_; + AdaptiveFirFilter main_filter_; + AdaptiveFirFilter shadow_filter_; + MainFilterUpdateGain G_main_; + ShadowFilterUpdateGain G_shadow_; + + RTC_DISALLOW_IMPLICIT_CONSTRUCTORS(Subtractor); +}; + +} // namespace webrtc + +#endif // WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_SUBTRACTOR_H_ diff --git a/webrtc/modules/audio_processing/aec3/subtractor_output.h b/webrtc/modules/audio_processing/aec3/subtractor_output.h new file mode 100644 index 0000000000..90b9065b3b --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/subtractor_output.h @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_SUBTRACTOR_OUTPUT_H_ +#define WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_SUBTRACTOR_OUTPUT_H_ + +#include + +#include "webrtc/modules/audio_processing/aec3/aec3_common.h" +#include "webrtc/modules/audio_processing/aec3/fft_data.h" + +namespace webrtc { + +// Stores the values being returned from the echo subtractor. +struct SubtractorOutput { + std::array e_main; + std::array e_shadow; + FftData E_main; + FftData E_shadow; + std::array E2_main; + std::array E2_shadow; + + void Reset() { + e_main.fill(0.f); + e_shadow.fill(0.f); + E_main.re.fill(0.f); + E_main.im.fill(0.f); + E_shadow.re.fill(0.f); + E_shadow.im.fill(0.f); + E2_main.fill(0.f); + E2_shadow.fill(0.f); + } +}; + +} // namespace webrtc + +#endif // WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_SUBTRACTOR_OUTPUT_H_ diff --git a/webrtc/modules/audio_processing/aec3/subtractor_unittest.cc b/webrtc/modules/audio_processing/aec3/subtractor_unittest.cc new file mode 100644 index 0000000000..45e2510c3e --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/subtractor_unittest.cc @@ -0,0 +1,175 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/modules/audio_processing/aec3/subtractor.h" + +#include +#include +#include + +#include "webrtc/base/random.h" +#include "webrtc/modules/audio_processing/aec3/aec_state.h" +#include "webrtc/modules/audio_processing/test/echo_canceller_test_tools.h" +#include "webrtc/test/gtest.h" + +namespace webrtc { +namespace { + +float RunSubtractorTest(int num_blocks_to_process, + int delay_samples, + bool uncorrelated_inputs, + const std::vector& blocks_with_echo_path_changes) { + ApmDataDumper data_dumper(42); + Subtractor subtractor(&data_dumper, DetectOptimization()); + std::vector x(kBlockSize, 0.f); + std::vector y(kBlockSize, 0.f); + std::array x_old; + SubtractorOutput output; + FftBuffer X_buffer( + Aec3Optimization::kNone, subtractor.MinFarendBufferLength(), + std::vector(1, subtractor.MinFarendBufferLength())); + RenderSignalAnalyzer render_signal_analyzer; + Random random_generator(42U); + Aec3Fft fft; + FftData X; + std::array Y2; + std::array E2_main; + std::array E2_shadow; + AecState aec_state; + x_old.fill(0.f); + Y2.fill(0.f); + E2_main.fill(0.f); + E2_shadow.fill(0.f); + + DelayBuffer delay_buffer(delay_samples); + for (int k = 0; k < num_blocks_to_process; ++k) { + RandomizeSampleVector(&random_generator, x); + if (uncorrelated_inputs) { + RandomizeSampleVector(&random_generator, y); + } else { + delay_buffer.Delay(x, y); + } + fft.PaddedFft(x, x_old, &X); + X_buffer.Insert(X); + render_signal_analyzer.Update(X_buffer, aec_state.FilterDelay()); + + // Handle echo path changes. + if (std::find(blocks_with_echo_path_changes.begin(), + blocks_with_echo_path_changes.end(), + k) != blocks_with_echo_path_changes.end()) { + subtractor.HandleEchoPathChange(EchoPathVariability(true, true)); + } + subtractor.Process(X_buffer, y, render_signal_analyzer, false, &output); + + aec_state.Update(subtractor.FilterFrequencyResponse(), + rtc::Optional(delay_samples / kBlockSize), + X_buffer, E2_main, E2_shadow, Y2, x, + EchoPathVariability(false, false), false); + } + + const float output_power = std::inner_product( + output.e_main.begin(), output.e_main.end(), output.e_main.begin(), 0.f); + const float y_power = std::inner_product(y.begin(), y.end(), y.begin(), 0.f); + if (y_power == 0.f) { + ADD_FAILURE(); + return -1.0; + } + return output_power / y_power; +} + +std::string ProduceDebugText(size_t delay) { + std::ostringstream ss; + ss << "Delay: " << delay; + return ss.str(); +} + +} // namespace + +#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) + +// Verifies that the check for non data dumper works. +TEST(Subtractor, NullDataDumper) { + EXPECT_DEATH(Subtractor(nullptr, DetectOptimization()), ""); +} + +// Verifies the check for null subtractor output. +// TODO(peah): Re-enable the test once the issue with memory leaks during DEATH +// tests on test bots has been fixed. +TEST(Subtractor, DISABLED_NullOutput) { + ApmDataDumper data_dumper(42); + Subtractor subtractor(&data_dumper, DetectOptimization()); + FftBuffer X_buffer( + Aec3Optimization::kNone, subtractor.MinFarendBufferLength(), + std::vector(1, subtractor.MinFarendBufferLength())); + RenderSignalAnalyzer render_signal_analyzer; + std::vector y(kBlockSize, 0.f); + + EXPECT_DEATH( + subtractor.Process(X_buffer, y, render_signal_analyzer, false, nullptr), + ""); +} + +// Verifies the check for the capture signal size. +TEST(Subtractor, WrongCaptureSize) { + ApmDataDumper data_dumper(42); + Subtractor subtractor(&data_dumper, DetectOptimization()); + FftBuffer X_buffer( + Aec3Optimization::kNone, subtractor.MinFarendBufferLength(), + std::vector(1, subtractor.MinFarendBufferLength())); + RenderSignalAnalyzer render_signal_analyzer; + std::vector y(kBlockSize - 1, 0.f); + SubtractorOutput output; + + EXPECT_DEATH( + subtractor.Process(X_buffer, y, render_signal_analyzer, false, &output), + ""); +} + +#endif + +// Verifies that the subtractor is able to converge on correlated data. +TEST(Subtractor, Convergence) { + std::vector blocks_with_echo_path_changes; + for (size_t delay_samples : {0, 64, 150, 200, 301}) { + SCOPED_TRACE(ProduceDebugText(delay_samples)); + + float echo_to_nearend_power = RunSubtractorTest( + 100, delay_samples, false, blocks_with_echo_path_changes); + EXPECT_GT(0.1f, echo_to_nearend_power); + } +} + +// Verifies that the subtractor does not converge on uncorrelated signals. +TEST(Subtractor, NonConvergenceOnUncorrelatedSignals) { + std::vector blocks_with_echo_path_changes; + for (size_t delay_samples : {0, 64, 150, 200, 301}) { + SCOPED_TRACE(ProduceDebugText(delay_samples)); + + float echo_to_nearend_power = RunSubtractorTest( + 100, delay_samples, true, blocks_with_echo_path_changes); + EXPECT_NEAR(1.f, echo_to_nearend_power, 0.05); + } +} + +// Verifies that the subtractor is properly reset when there is an echo path +// change. +TEST(Subtractor, EchoPathChangeReset) { + std::vector blocks_with_echo_path_changes; + blocks_with_echo_path_changes.push_back(99); + for (size_t delay_samples : {0, 64, 150, 200, 301}) { + SCOPED_TRACE(ProduceDebugText(delay_samples)); + + float echo_to_nearend_power = RunSubtractorTest( + 100, delay_samples, false, blocks_with_echo_path_changes); + EXPECT_NEAR(1.f, echo_to_nearend_power, 0.0000001f); + } +} + +} // namespace webrtc diff --git a/webrtc/modules/audio_processing/aec3/suppression_filter.cc b/webrtc/modules/audio_processing/aec3/suppression_filter.cc new file mode 100644 index 0000000000..f127acf9de --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/suppression_filter.cc @@ -0,0 +1,178 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/modules/audio_processing/aec3/suppression_filter.h" + +#include +#include +#include +#include +#include + +#include "webrtc/modules/audio_processing/utility/ooura_fft.h" + +namespace webrtc { +namespace { + +// Hanning window from Matlab command win = sqrt(hanning(128)). +const float kSqrtHanning[kFftLength] = { + 0.00000000000000f, 0.02454122852291f, 0.04906767432742f, 0.07356456359967f, + 0.09801714032956f, 0.12241067519922f, 0.14673047445536f, 0.17096188876030f, + 0.19509032201613f, 0.21910124015687f, 0.24298017990326f, 0.26671275747490f, + 0.29028467725446f, 0.31368174039889f, 0.33688985339222f, 0.35989503653499f, + 0.38268343236509f, 0.40524131400499f, 0.42755509343028f, 0.44961132965461f, + 0.47139673682600f, 0.49289819222978f, 0.51410274419322f, 0.53499761988710f, + 0.55557023301960f, 0.57580819141785f, 0.59569930449243f, 0.61523159058063f, + 0.63439328416365f, 0.65317284295378f, 0.67155895484702f, 0.68954054473707f, + 0.70710678118655f, 0.72424708295147f, 0.74095112535496f, 0.75720884650648f, + 0.77301045336274f, 0.78834642762661f, 0.80320753148064f, 0.81758481315158f, + 0.83146961230255f, 0.84485356524971f, 0.85772861000027f, 0.87008699110871f, + 0.88192126434835f, 0.89322430119552f, 0.90398929312344f, 0.91420975570353f, + 0.92387953251129f, 0.93299279883474f, 0.94154406518302f, 0.94952818059304f, + 0.95694033573221f, 0.96377606579544f, 0.97003125319454f, 0.97570213003853f, + 0.98078528040323f, 0.98527764238894f, 0.98917650996478f, 0.99247953459871f, + 0.99518472667220f, 0.99729045667869f, 0.99879545620517f, 0.99969881869620f, + 1.00000000000000f, 0.99969881869620f, 0.99879545620517f, 0.99729045667869f, + 0.99518472667220f, 0.99247953459871f, 0.98917650996478f, 0.98527764238894f, + 0.98078528040323f, 0.97570213003853f, 0.97003125319454f, 0.96377606579544f, + 0.95694033573221f, 0.94952818059304f, 0.94154406518302f, 0.93299279883474f, + 0.92387953251129f, 0.91420975570353f, 0.90398929312344f, 0.89322430119552f, + 0.88192126434835f, 0.87008699110871f, 0.85772861000027f, 0.84485356524971f, + 0.83146961230255f, 0.81758481315158f, 0.80320753148064f, 0.78834642762661f, + 0.77301045336274f, 0.75720884650648f, 0.74095112535496f, 0.72424708295147f, + 0.70710678118655f, 0.68954054473707f, 0.67155895484702f, 0.65317284295378f, + 0.63439328416365f, 0.61523159058063f, 0.59569930449243f, 0.57580819141785f, + 0.55557023301960f, 0.53499761988710f, 0.51410274419322f, 0.49289819222978f, + 0.47139673682600f, 0.44961132965461f, 0.42755509343028f, 0.40524131400499f, + 0.38268343236509f, 0.35989503653499f, 0.33688985339222f, 0.31368174039889f, + 0.29028467725446f, 0.26671275747490f, 0.24298017990326f, 0.21910124015687f, + 0.19509032201613f, 0.17096188876030f, 0.14673047445536f, 0.12241067519922f, + 0.09801714032956f, 0.07356456359967f, 0.04906767432742f, 0.02454122852291f}; + +} // namespace + +SuppressionFilter::SuppressionFilter(int sample_rate_hz) + : sample_rate_hz_(sample_rate_hz), + e_output_old_(NumBandsForRate(sample_rate_hz_)) { + RTC_DCHECK(ValidFullBandRate(sample_rate_hz_)); + e_input_old_.fill(0.f); + std::for_each(e_output_old_.begin(), e_output_old_.end(), + [](std::array& a) { a.fill(0.f); }); +} + +SuppressionFilter::~SuppressionFilter() = default; + +void SuppressionFilter::ApplyGain( + const FftData& comfort_noise, + const FftData& comfort_noise_high_band, + const std::array& suppression_gain, + std::vector>* e) { + RTC_DCHECK(e); + RTC_DCHECK_EQ(e->size(), NumBandsForRate(sample_rate_hz_)); + FftData E; + std::array e_extended; + constexpr float kIfftNormalization = 2.f / kFftLength; + + // Analysis filterbank. + std::transform(e_input_old_.begin(), e_input_old_.end(), + std::begin(kSqrtHanning), e_extended.begin(), + std::multiplies()); + std::transform((*e)[0].begin(), (*e)[0].end(), + std::begin(kSqrtHanning) + kFftLengthBy2, + e_extended.begin() + kFftLengthBy2, std::multiplies()); + std::copy((*e)[0].begin(), (*e)[0].end(), e_input_old_.begin()); + fft_.Fft(&e_extended, &E); + + // Apply gain. + std::transform(suppression_gain.begin(), suppression_gain.end(), E.re.begin(), + E.re.begin(), std::multiplies()); + std::transform(suppression_gain.begin(), suppression_gain.end(), E.im.begin(), + E.im.begin(), std::multiplies()); + + // Compute and add the comfort noise. + std::array scaled_comfort_noise; + std::transform(suppression_gain.begin(), suppression_gain.end(), + comfort_noise.re.begin(), scaled_comfort_noise.begin(), + [](float a, float b) { return std::max(1.f - a, 0.f) * b; }); + std::transform(scaled_comfort_noise.begin(), scaled_comfort_noise.end(), + E.re.begin(), E.re.begin(), std::plus()); + std::transform(suppression_gain.begin(), suppression_gain.end(), + comfort_noise.im.begin(), scaled_comfort_noise.begin(), + [](float a, float b) { return std::max(1.f - a, 0.f) * b; }); + std::transform(scaled_comfort_noise.begin(), scaled_comfort_noise.end(), + E.im.begin(), E.im.begin(), std::plus()); + + // Synthesis filterbank. + fft_.Ifft(E, &e_extended); + std::transform(e_output_old_[0].begin(), e_output_old_[0].end(), + std::begin(kSqrtHanning) + kFftLengthBy2, (*e)[0].begin(), + [&](float a, float b) { return kIfftNormalization * a * b; }); + std::transform(e_extended.begin(), e_extended.begin() + kFftLengthBy2, + std::begin(kSqrtHanning), e_extended.begin(), + [&](float a, float b) { return kIfftNormalization * a * b; }); + std::transform((*e)[0].begin(), (*e)[0].end(), e_extended.begin(), + (*e)[0].begin(), std::plus()); + std::for_each((*e)[0].begin(), (*e)[0].end(), [](float& x_k) { + x_k = std::max(std::min(x_k, 32767.0f), -32768.0f); + }); + std::copy(e_extended.begin() + kFftLengthBy2, e_extended.begin() + kFftLength, + std::begin(e_output_old_[0])); + + if (e->size() > 1) { + // Form time-domain high-band noise. + std::array time_domain_high_band_noise; + std::transform(comfort_noise_high_band.re.begin(), + comfort_noise_high_band.re.end(), E.re.begin(), + [&](float a) { return kIfftNormalization * a; }); + std::transform(comfort_noise_high_band.im.begin(), + comfort_noise_high_band.im.end(), E.im.begin(), + [&](float a) { return kIfftNormalization * a; }); + fft_.Ifft(E, &time_domain_high_band_noise); + + // Scale and apply the noise to the signals. + // TODO(peah): Ensure that the high bands are properly delayed. + constexpr int kNumBandsAveragingUpperGain = kFftLengthBy2 / 4; + constexpr float kOneByNumBandsAveragingUpperGain = + 1.f / kNumBandsAveragingUpperGain; + float high_bands_gain = + std::accumulate(suppression_gain.end() - kNumBandsAveragingUpperGain, + suppression_gain.end(), 0.f) * + kOneByNumBandsAveragingUpperGain; + + float high_bands_noise_scaling = + 0.4f * std::max(1.f - high_bands_gain * high_bands_gain, 0.f); + + std::transform( + (*e)[1].begin(), (*e)[1].end(), time_domain_high_band_noise.begin(), + (*e)[1].begin(), [&](float a, float b) { + return std::max( + std::min(b * high_bands_noise_scaling + high_bands_gain * a, + 32767.0f), + -32768.0f); + }); + + if (e->size() > 2) { + RTC_DCHECK_EQ(3, e->size()); + std::for_each((*e)[2].begin(), (*e)[2].end(), [&](float& a) { + a = std::max(std::min(a * high_bands_gain, 32767.0f), -32768.0f); + }); + } + + std::array tmp; + for (size_t k = 1; k < e->size(); ++k) { + std::copy((*e)[k].begin(), (*e)[k].end(), tmp.begin()); + std::copy(e_output_old_[k].begin(), e_output_old_[k].end(), + (*e)[k].begin()); + std::copy(tmp.begin(), tmp.end(), e_output_old_[k].begin()); + } + } +} + +} // namespace webrtc diff --git a/webrtc/modules/audio_processing/aec3/suppression_filter.h b/webrtc/modules/audio_processing/aec3/suppression_filter.h new file mode 100644 index 0000000000..31710475c9 --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/suppression_filter.h @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_SUPPRESSION_FILTER_H_ +#define WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_SUPPRESSION_FILTER_H_ + +#include +#include + +#include "webrtc/base/constructormagic.h" +#include "webrtc/modules/audio_processing/aec3/aec3_common.h" +#include "webrtc/modules/audio_processing/aec3/aec3_fft.h" + +namespace webrtc { + +class SuppressionFilter { + public: + explicit SuppressionFilter(int sample_rate_hz); + ~SuppressionFilter(); + void ApplyGain(const FftData& comfort_noise, + const FftData& comfort_noise_high_bands, + const std::array& suppression_gain, + std::vector>* e); + + private: + const int sample_rate_hz_; + const OouraFft ooura_fft_; + const Aec3Fft fft_; + std::array e_input_old_; + std::vector> e_output_old_; + RTC_DISALLOW_COPY_AND_ASSIGN(SuppressionFilter); +}; + +} // namespace webrtc + +#endif // WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_SUPPRESSION_FILTER_H_ diff --git a/webrtc/modules/audio_processing/aec3/suppression_filter_unittest.cc b/webrtc/modules/audio_processing/aec3/suppression_filter_unittest.cc new file mode 100644 index 0000000000..e8710b8375 --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/suppression_filter_unittest.cc @@ -0,0 +1,180 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/modules/audio_processing/aec3/suppression_filter.h" + +#include +#include +#include + +#include "webrtc/test/gtest.h" + +namespace webrtc { +namespace { + +constexpr float kPi = 3.141592f; + +void ProduceSinusoid(int sample_rate_hz, + float sinusoidal_frequency_hz, + size_t* sample_counter, + rtc::ArrayView x) { + // Produce a sinusoid of the specified frequency. + for (size_t k = *sample_counter, j = 0; k < (*sample_counter + kBlockSize); + ++k, ++j) { + x[j] = + 32767.f * sin(2.f * kPi * sinusoidal_frequency_hz * k / sample_rate_hz); + } + *sample_counter = *sample_counter + kBlockSize; +} + +} // namespace + +#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) + +// Verifies the check for null suppressor output. +TEST(SuppressionFilter, NullOutput) { + FftData cn; + FftData cn_high_bands; + std::array gain; + + EXPECT_DEATH( + SuppressionFilter(16000).ApplyGain(cn, cn_high_bands, gain, nullptr), ""); +} + +// Verifies the check for allowed sample rate. +TEST(SuppressionFilter, ProperSampleRate) { + EXPECT_DEATH(SuppressionFilter(16001), ""); +} + +#endif + +// Verifies that no comfort noise is added when the gain is 1. +TEST(SuppressionFilter, ComfortNoiseInUnityGain) { + SuppressionFilter filter(48000); + FftData cn; + FftData cn_high_bands; + std::array gain; + + gain.fill(1.f); + cn.re.fill(1.f); + cn.im.fill(1.f); + cn_high_bands.re.fill(1.f); + cn_high_bands.im.fill(1.f); + + std::vector> e(3, std::vector(kBlockSize, 0.f)); + std::vector> e_ref = e; + filter.ApplyGain(cn, cn_high_bands, gain, &e); + + for (size_t k = 0; k < e.size(); ++k) { + EXPECT_EQ(e_ref[k], e[k]); + } +} + +// Verifies that the suppressor is able to suppress a signal. +TEST(SuppressionFilter, SignalSuppression) { + SuppressionFilter filter(48000); + FftData cn; + FftData cn_high_bands; + std::array gain; + std::vector> e(3, std::vector(kBlockSize, 0.f)); + + gain.fill(1.f); + std::for_each(gain.begin() + 10, gain.end(), [](float& a) { a = 0.f; }); + + cn.re.fill(0.f); + cn.im.fill(0.f); + cn_high_bands.re.fill(0.f); + cn_high_bands.im.fill(0.f); + + size_t sample_counter = 0; + + float e0_input = 0.f; + float e0_output = 0.f; + for (size_t k = 0; k < 100; ++k) { + ProduceSinusoid(16000, 16000 * 40 / kFftLengthBy2 / 2, &sample_counter, + e[0]); + e0_input = + std::inner_product(e[0].begin(), e[0].end(), e[0].begin(), e0_input); + filter.ApplyGain(cn, cn_high_bands, gain, &e); + e0_output = + std::inner_product(e[0].begin(), e[0].end(), e[0].begin(), e0_output); + } + + EXPECT_LT(e0_output, e0_input / 1000.f); +} + +// Verifies that the suppressor is able to pass through a desired signal while +// applying suppressing for some frequencies. +TEST(SuppressionFilter, SignalTransparency) { + SuppressionFilter filter(48000); + FftData cn; + FftData cn_high_bands; + std::array gain; + std::vector> e(3, std::vector(kBlockSize, 0.f)); + + gain.fill(1.f); + std::for_each(gain.begin() + 30, gain.end(), [](float& a) { a = 0.f; }); + + cn.re.fill(0.f); + cn.im.fill(0.f); + cn_high_bands.re.fill(0.f); + cn_high_bands.im.fill(0.f); + + size_t sample_counter = 0; + + float e0_input = 0.f; + float e0_output = 0.f; + for (size_t k = 0; k < 100; ++k) { + ProduceSinusoid(16000, 16000 * 10 / kFftLengthBy2 / 2, &sample_counter, + e[0]); + e0_input = + std::inner_product(e[0].begin(), e[0].end(), e[0].begin(), e0_input); + filter.ApplyGain(cn, cn_high_bands, gain, &e); + e0_output = + std::inner_product(e[0].begin(), e[0].end(), e[0].begin(), e0_output); + } + + EXPECT_LT(0.9f * e0_input, e0_output); +} + +// Verifies that the suppressor delay. +TEST(SuppressionFilter, Delay) { + SuppressionFilter filter(48000); + FftData cn; + FftData cn_high_bands; + std::array gain; + std::vector> e(3, std::vector(kBlockSize, 0.f)); + + gain.fill(1.f); + + cn.re.fill(0.f); + cn.im.fill(0.f); + cn_high_bands.re.fill(0.f); + cn_high_bands.im.fill(0.f); + + for (size_t k = 0; k < 100; ++k) { + for (size_t j = 0; j < 3; ++j) { + for (size_t i = 0; i < kBlockSize; ++i) { + e[j][i] = k * kBlockSize + i; + } + } + + filter.ApplyGain(cn, cn_high_bands, gain, &e); + if (k > 2) { + for (size_t j = 0; j < 2; ++j) { + for (size_t i = 0; i < kBlockSize; ++i) { + EXPECT_NEAR(k * kBlockSize + i - kBlockSize, e[j][i], 0.01); + } + } + } + } +} + +} // namespace webrtc diff --git a/webrtc/modules/audio_processing/aec3/suppression_gain.cc b/webrtc/modules/audio_processing/aec3/suppression_gain.cc new file mode 100644 index 0000000000..34bb9cb390 --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/suppression_gain.cc @@ -0,0 +1,282 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/modules/audio_processing/aec3/suppression_gain.h" + +#include "webrtc/typedefs.h" +#if defined(WEBRTC_ARCH_X86_FAMILY) +#include +#endif +#include +#include +#include + +namespace webrtc { +namespace { + +constexpr int kNumIterations = 2; +constexpr float kEchoMaskingMargin = 1.f / 10.f; +constexpr float kBandMaskingFactor = 1.f / 2.f; +constexpr float kTimeMaskingFactor = 1.f / 10.f; + +} // namespace + +namespace aec3 { + +#if defined(WEBRTC_ARCH_X86_FAMILY) + +// Optimized SSE2 code for the gain computation. +// TODO(peah): Add further optimizations, in particular for the divisions. +void ComputeGains_SSE2( + const std::array& nearend_power, + const std::array& residual_echo_power, + const std::array& comfort_noise_power, + float strong_nearend_margin, + std::array* previous_gain_squared, + std::array* previous_masker, + std::array* gain) { + std::array masker; + std::array same_band_masker; + std::array one_by_residual_echo_power; + std::array strong_nearend; + std::array neighboring_bands_masker; + std::array* gain_squared = gain; + + // Precompute 1/residual_echo_power. + std::transform(residual_echo_power.begin() + 1, residual_echo_power.end() - 1, + one_by_residual_echo_power.begin(), + [](float a) { return a > 0.f ? 1.f / a : -1.f; }); + + // Precompute indicators for bands with strong nearend. + std::transform( + residual_echo_power.begin() + 1, residual_echo_power.end() - 1, + nearend_power.begin() + 1, strong_nearend.begin(), + [&](float a, float b) { return a <= strong_nearend_margin * b; }); + + // Precompute masker for the same band. + std::transform(comfort_noise_power.begin() + 1, comfort_noise_power.end() - 1, + previous_masker->begin(), same_band_masker.begin(), + [&](float a, float b) { return a + kTimeMaskingFactor * b; }); + + for (int k = 0; k < kNumIterations; ++k) { + if (k == 0) { + // Add masker from the same band. + std::copy(same_band_masker.begin(), same_band_masker.end(), + masker.begin()); + } else { + // Add masker for neighboring bands. + std::transform(nearend_power.begin(), nearend_power.end(), + gain_squared->begin(), neighboring_bands_masker.begin(), + std::multiplies()); + std::transform(neighboring_bands_masker.begin(), + neighboring_bands_masker.end(), + comfort_noise_power.begin(), + neighboring_bands_masker.begin(), std::plus()); + std::transform( + neighboring_bands_masker.begin(), neighboring_bands_masker.end() - 2, + neighboring_bands_masker.begin() + 2, masker.begin(), + [&](float a, float b) { return kBandMaskingFactor * (a + b); }); + + // Add masker from the same band. + std::transform(same_band_masker.begin(), same_band_masker.end(), + masker.begin(), masker.begin(), std::plus()); + } + + // Compute new gain as: + // G2(t,f) = (comfort_noise_power(t,f) + G2(t-1)*nearend_power(t-1)) * + // kTimeMaskingFactor + // * kEchoMaskingMargin / residual_echo_power(t,f). + // or + // G2(t,f) = ((comfort_noise_power(t,f) + G2(t-1) * + // nearend_power(t-1)) * kTimeMaskingFactor + + // (comfort_noise_power(t, f-1) + comfort_noise_power(t, f+1) + + // (G2(t,f-1)*nearend_power(t, f-1) + + // G2(t,f+1)*nearend_power(t, f+1)) * + // kTimeMaskingFactor) * kBandMaskingFactor) + // * kEchoMaskingMargin / residual_echo_power(t,f). + std::transform( + masker.begin(), masker.end(), one_by_residual_echo_power.begin(), + gain_squared->begin() + 1, [&](float a, float b) { + return b >= 0 ? std::min(kEchoMaskingMargin * a * b, 1.f) : 1.f; + }); + + // Limit gain for bands with strong nearend. + std::transform(gain_squared->begin() + 1, gain_squared->end() - 1, + strong_nearend.begin(), gain_squared->begin() + 1, + [](float a, bool b) { return b ? 1.f : a; }); + + // Limit the allowed gain update over time. + std::transform(gain_squared->begin() + 1, gain_squared->end() - 1, + previous_gain_squared->begin(), gain_squared->begin() + 1, + [](float a, float b) { + return b < 0.0001f ? std::min(a, 0.0001f) + : std::min(a, b * 2.f); + }); + + (*gain_squared)[0] = (*gain_squared)[1]; + (*gain_squared)[kFftLengthBy2] = (*gain_squared)[kFftLengthBy2Minus1]; + } + + std::copy(gain_squared->begin() + 1, gain_squared->end() - 1, + previous_gain_squared->begin()); + + std::transform(gain_squared->begin() + 1, gain_squared->end() - 1, + nearend_power.begin() + 1, previous_masker->begin(), + std::multiplies()); + std::transform(previous_masker->begin(), previous_masker->end(), + comfort_noise_power.begin() + 1, previous_masker->begin(), + std::plus()); + + for (size_t k = 0; k < kFftLengthBy2; k += 4) { + __m128 g = _mm_loadu_ps(&(*gain_squared)[k]); + g = _mm_sqrt_ps(g); + _mm_storeu_ps(&(*gain)[k], g); + } + + (*gain)[kFftLengthBy2] = sqrtf((*gain)[kFftLengthBy2]); +} + +#endif + +void ComputeGains( + const std::array& nearend_power, + const std::array& residual_echo_power, + const std::array& comfort_noise_power, + float strong_nearend_margin, + std::array* previous_gain_squared, + std::array* previous_masker, + std::array* gain) { + std::array masker; + std::array same_band_masker; + std::array one_by_residual_echo_power; + std::array strong_nearend; + std::array neighboring_bands_masker; + std::array* gain_squared = gain; + + // Precompute 1/residual_echo_power. + std::transform(residual_echo_power.begin() + 1, residual_echo_power.end() - 1, + one_by_residual_echo_power.begin(), + [](float a) { return a > 0.f ? 1.f / a : -1.f; }); + + // Precompute indicators for bands with strong nearend. + std::transform( + residual_echo_power.begin() + 1, residual_echo_power.end() - 1, + nearend_power.begin() + 1, strong_nearend.begin(), + [&](float a, float b) { return a <= strong_nearend_margin * b; }); + + // Precompute masker for the same band. + std::transform(comfort_noise_power.begin() + 1, comfort_noise_power.end() - 1, + previous_masker->begin(), same_band_masker.begin(), + [&](float a, float b) { return a + kTimeMaskingFactor * b; }); + + for (int k = 0; k < kNumIterations; ++k) { + if (k == 0) { + // Add masker from the same band. + std::copy(same_band_masker.begin(), same_band_masker.end(), + masker.begin()); + } else { + // Add masker for neightboring bands. + std::transform(nearend_power.begin(), nearend_power.end(), + gain_squared->begin(), neighboring_bands_masker.begin(), + std::multiplies()); + std::transform(neighboring_bands_masker.begin(), + neighboring_bands_masker.end(), + comfort_noise_power.begin(), + neighboring_bands_masker.begin(), std::plus()); + std::transform( + neighboring_bands_masker.begin(), neighboring_bands_masker.end() - 2, + neighboring_bands_masker.begin() + 2, masker.begin(), + [&](float a, float b) { return kBandMaskingFactor * (a + b); }); + + // Add masker from the same band. + std::transform(same_band_masker.begin(), same_band_masker.end(), + masker.begin(), masker.begin(), std::plus()); + } + + // Compute new gain as: + // G2(t,f) = (comfort_noise_power(t,f) + G2(t-1)*nearend_power(t-1)) * + // kTimeMaskingFactor + // * kEchoMaskingMargin / residual_echo_power(t,f). + // or + // G2(t,f) = ((comfort_noise_power(t,f) + G2(t-1) * + // nearend_power(t-1)) * kTimeMaskingFactor + + // (comfort_noise_power(t, f-1) + comfort_noise_power(t, f+1) + + // (G2(t,f-1)*nearend_power(t, f-1) + + // G2(t,f+1)*nearend_power(t, f+1)) * + // kTimeMaskingFactor) * kBandMaskingFactor) + // * kEchoMaskingMargin / residual_echo_power(t,f). + std::transform( + masker.begin(), masker.end(), one_by_residual_echo_power.begin(), + gain_squared->begin() + 1, [&](float a, float b) { + return b >= 0 ? std::min(kEchoMaskingMargin * a * b, 1.f) : 1.f; + }); + + // Limit gain for bands with strong nearend. + std::transform(gain_squared->begin() + 1, gain_squared->end() - 1, + strong_nearend.begin(), gain_squared->begin() + 1, + [](float a, bool b) { return b ? 1.f : a; }); + + // Limit the allowed gain update over time. + std::transform(gain_squared->begin() + 1, gain_squared->end() - 1, + previous_gain_squared->begin(), gain_squared->begin() + 1, + [](float a, float b) { + return b < 0.0001f ? std::min(a, 0.0001f) + : std::min(a, b * 2.f); + }); + + (*gain_squared)[0] = (*gain_squared)[1]; + (*gain_squared)[kFftLengthBy2] = (*gain_squared)[kFftLengthBy2Minus1]; + } + + std::copy(gain_squared->begin() + 1, gain_squared->end() - 1, + previous_gain_squared->begin()); + + std::transform(gain_squared->begin() + 1, gain_squared->end() - 1, + nearend_power.begin() + 1, previous_masker->begin(), + std::multiplies()); + std::transform(previous_masker->begin(), previous_masker->end(), + comfort_noise_power.begin() + 1, previous_masker->begin(), + std::plus()); + + std::transform(gain_squared->begin(), gain_squared->end(), gain->begin(), + [](float a) { return sqrtf(a); }); +} + +} // namespace aec3 + +SuppressionGain::SuppressionGain(Aec3Optimization optimization) + : optimization_(optimization) { + previous_gain_squared_.fill(1.f); + previous_masker_.fill(0.f); +} + +void SuppressionGain::GetGain( + const std::array& nearend_power, + const std::array& residual_echo_power, + const std::array& comfort_noise_power, + float strong_nearend_margin, + std::array* gain) { + RTC_DCHECK(gain); + switch (optimization_) { +#if defined(WEBRTC_ARCH_X86_FAMILY) + case Aec3Optimization::kSse2: + aec3::ComputeGains_SSE2(nearend_power, residual_echo_power, + comfort_noise_power, strong_nearend_margin, + &previous_gain_squared_, &previous_masker_, gain); + break; +#endif + default: + aec3::ComputeGains(nearend_power, residual_echo_power, + comfort_noise_power, strong_nearend_margin, + &previous_gain_squared_, &previous_masker_, gain); + } +} + +} // namespace webrtc diff --git a/webrtc/modules/audio_processing/aec3/suppression_gain.h b/webrtc/modules/audio_processing/aec3/suppression_gain.h new file mode 100644 index 0000000000..bccbed8897 --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/suppression_gain.h @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_SUPPRESSION_GAIN_H_ +#define WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_SUPPRESSION_GAIN_H_ + +#include + +#include "webrtc/base/constructormagic.h" +#include "webrtc/modules/audio_processing/aec3/aec3_common.h" +#include "webrtc/modules/audio_processing/aec3/fft_buffer.h" + +namespace webrtc { +namespace aec3 { +#if defined(WEBRTC_ARCH_X86_FAMILY) + +void ComputeGains_SSE2( + const std::array& nearend_power, + const std::array& residual_echo_power, + const std::array& comfort_noise_power, + float strong_nearend_margin, + std::array* previous_gain_squared, + std::array* previous_masker, + std::array* gain); + +#endif + +void ComputeGains( + const std::array& nearend_power, + const std::array& residual_echo_power, + const std::array& comfort_noise_power, + float strong_nearend_margin, + std::array* previous_gain_squared, + std::array* previous_masker, + std::array* gain); + +} // namespace aec3 + +class SuppressionGain { + public: + explicit SuppressionGain(Aec3Optimization optimization); + void GetGain(const std::array& nearend_power, + const std::array& residual_echo_power, + const std::array& comfort_noise_power, + float strong_nearend_margin, + std::array* gain); + + private: + const Aec3Optimization optimization_; + std::array previous_gain_squared_; + std::array previous_masker_; + RTC_DISALLOW_IMPLICIT_CONSTRUCTORS(SuppressionGain); +}; + +} // namespace webrtc + +#endif // WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_SUPPRESSION_GAIN_H_ diff --git a/webrtc/modules/audio_processing/aec3/suppression_gain_unittest.cc b/webrtc/modules/audio_processing/aec3/suppression_gain_unittest.cc new file mode 100644 index 0000000000..6016f182a9 --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/suppression_gain_unittest.cc @@ -0,0 +1,148 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/modules/audio_processing/aec3/suppression_gain.h" + +#include "webrtc/typedefs.h" +#include "webrtc/system_wrappers/include/cpu_features_wrapper.h" +#include "webrtc/test/gtest.h" + +namespace webrtc { +namespace aec3 { + +#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) + +// Verifies that the check for non-null output gains works. +TEST(SuppressionGain, NullOutputGains) { + std::array E2; + std::array R2; + std::array N2; + EXPECT_DEATH( + SuppressionGain(DetectOptimization()).GetGain(E2, R2, N2, 0.1f, nullptr), + ""); +} + +#endif + +#if defined(WEBRTC_ARCH_X86_FAMILY) +// Verifies that the optimized methods are bitexact to their reference +// counterparts. +TEST(SuppressionGain, TestOptimizations) { + if (WebRtc_GetCPUInfo(kSSE2) != 0) { + std::array G2_old; + std::array M2_old; + std::array G2_old_SSE2; + std::array M2_old_SSE2; + std::array E2; + std::array R2; + std::array N2; + std::array g; + std::array g_SSE2; + + G2_old.fill(1.f); + M2_old.fill(.23f); + G2_old_SSE2.fill(1.f); + M2_old_SSE2.fill(.23f); + + E2.fill(10.f); + R2.fill(0.1f); + N2.fill(100.f); + for (int k = 0; k < 10; ++k) { + ComputeGains(E2, R2, N2, 0.1f, &G2_old, &M2_old, &g); + ComputeGains_SSE2(E2, R2, N2, 0.1f, &G2_old_SSE2, &M2_old_SSE2, &g_SSE2); + for (size_t j = 0; j < G2_old.size(); ++j) { + EXPECT_NEAR(G2_old[j], G2_old_SSE2[j], 0.0000001f); + } + for (size_t j = 0; j < M2_old.size(); ++j) { + EXPECT_NEAR(M2_old[j], M2_old_SSE2[j], 0.0000001f); + } + for (size_t j = 0; j < g.size(); ++j) { + EXPECT_NEAR(g[j], g_SSE2[j], 0.0000001f); + } + } + + E2.fill(100.f); + R2.fill(0.1f); + N2.fill(0.f); + for (int k = 0; k < 10; ++k) { + ComputeGains(E2, R2, N2, 0.1f, &G2_old, &M2_old, &g); + ComputeGains_SSE2(E2, R2, N2, 0.1f, &G2_old_SSE2, &M2_old_SSE2, &g_SSE2); + for (size_t j = 0; j < G2_old.size(); ++j) { + EXPECT_NEAR(G2_old[j], G2_old_SSE2[j], 0.0000001f); + } + for (size_t j = 0; j < M2_old.size(); ++j) { + EXPECT_NEAR(M2_old[j], M2_old_SSE2[j], 0.0000001f); + } + for (size_t j = 0; j < g.size(); ++j) { + EXPECT_NEAR(g[j], g_SSE2[j], 0.0000001f); + } + } + + E2.fill(0.1f); + R2.fill(100.f); + N2.fill(0.f); + for (int k = 0; k < 10; ++k) { + ComputeGains(E2, R2, N2, 0.1f, &G2_old, &M2_old, &g); + ComputeGains_SSE2(E2, R2, N2, 0.1f, &G2_old_SSE2, &M2_old_SSE2, &g_SSE2); + for (size_t j = 0; j < G2_old.size(); ++j) { + EXPECT_NEAR(G2_old[j], G2_old_SSE2[j], 0.0000001f); + } + for (size_t j = 0; j < M2_old.size(); ++j) { + EXPECT_NEAR(M2_old[j], M2_old_SSE2[j], 0.0000001f); + } + for (size_t j = 0; j < g.size(); ++j) { + EXPECT_NEAR(g[j], g_SSE2[j], 0.0000001f); + } + } + } +} +#endif + +// Does a sanity check that the gains are correctly computed. +TEST(SuppressionGain, BasicGainComputation) { + SuppressionGain suppression_gain(DetectOptimization()); + std::array E2; + std::array R2; + std::array N2; + std::array g; + + // Ensure that a strong noise is detected to mask any echoes. + E2.fill(10.f); + R2.fill(0.1f); + N2.fill(100.f); + for (int k = 0; k < 10; ++k) { + suppression_gain.GetGain(E2, R2, N2, 0.1f, &g); + } + std::for_each(g.begin(), g.end(), + [](float a) { EXPECT_NEAR(1.f, a, 0.001); }); + + // Ensure that a strong nearend is detected to mask any echoes. + E2.fill(100.f); + R2.fill(0.1f); + N2.fill(0.f); + for (int k = 0; k < 10; ++k) { + suppression_gain.GetGain(E2, R2, N2, 0.1f, &g); + } + std::for_each(g.begin(), g.end(), + [](float a) { EXPECT_NEAR(1.f, a, 0.001); }); + + // Ensure that a strong echo is suppressed. + E2.fill(0.1f); + R2.fill(100.f); + N2.fill(0.f); + for (int k = 0; k < 10; ++k) { + suppression_gain.GetGain(E2, R2, N2, 0.1f, &g); + } + std::for_each(g.begin(), g.end(), + [](float a) { EXPECT_NEAR(0.f, a, 0.001); }); +} + +} // namespace aec3 +} // namespace webrtc diff --git a/webrtc/modules/audio_processing/audio_processing_impl.cc b/webrtc/modules/audio_processing/audio_processing_impl.cc index 14d6a26d1a..4c9188f768 100644 --- a/webrtc/modules/audio_processing/audio_processing_impl.cc +++ b/webrtc/modules/audio_processing/audio_processing_impl.cc @@ -439,18 +439,11 @@ int AudioProcessingImpl::MaybeInitialize( } int AudioProcessingImpl::InitializeLocked() { - int capture_audiobuffer_num_channels; - if (private_submodules_->echo_canceller3) { - // TODO(peah): Ensure that the echo canceller can operate on more than one - // microphone channel. - RTC_DCHECK(!capture_nonlocked_.beamformer_enabled); - capture_audiobuffer_num_channels = 1; - } else { - capture_audiobuffer_num_channels = - capture_nonlocked_.beamformer_enabled - ? formats_.api_format.input_stream().num_channels() - : formats_.api_format.output_stream().num_channels(); - } + const int capture_audiobuffer_num_channels = + capture_nonlocked_.beamformer_enabled + ? formats_.api_format.input_stream().num_channels() + : formats_.api_format.output_stream().num_channels(); + const int render_audiobuffer_num_output_frames = formats_.api_format.reverse_output_stream().num_frames() == 0 ? formats_.render_processing_format.num_frames() @@ -576,9 +569,7 @@ int AudioProcessingImpl::InitializeLocked(const ProcessingConfig& config) { submodule_states_.RenderMultiBandSubModulesActive()); // TODO(aluebs): Remove this restriction once we figure out why the 3-band // splitting filter degrades the AEC performance. - // TODO(peah): Verify that the band splitting is needed for the AEC3. - if (render_processing_rate > kSampleRate32kHz && - !capture_nonlocked_.echo_canceller3_enabled) { + if (render_processing_rate > kSampleRate32kHz) { render_processing_rate = submodule_states_.RenderMultiBandProcessingActive() ? kSampleRate32kHz : kSampleRate16kHz; @@ -1162,6 +1153,14 @@ int AudioProcessingImpl::ProcessCaptureStreamLocked() { capture_buffer->SplitIntoFrequencyBands(); } + if (private_submodules_->echo_canceller3) { + // Force down-mixing of the number of channels after the detection of + // capture signal saturation. + // TODO(peah): Look into ensuring that this kind of tampering with the + // AudioBuffer functionality should not be needed. + capture_buffer->set_num_channels(1); + } + if (capture_nonlocked_.beamformer_enabled) { private_submodules_->beamformer->AnalyzeChunk( *capture_buffer->split_data_f());