From 5e79b293137f9022322331c9644743203d246ba3 Mon Sep 17 00:00:00 2001 From: peah Date: Wed, 12 Apr 2017 01:20:45 -0700 Subject: [PATCH] Adding new functionality for SIMD optimizations in AEC3 Most of the complex functionality in AEC3 is done using vector maths. This CL adds new functionality for performing these vector operations using SIMD operations in a simple manner whenever such operations are available. The reason for putting the implementations in the header file is to allow any possible inlining. BUG=webrtc:6018 Review-Url: https://codereview.webrtc.org/2813823002 Cr-Commit-Position: refs/heads/master@{#17663} --- webrtc/modules/audio_processing/BUILD.gn | 2 + .../audio_processing/aec3/suppression_gain.cc | 171 ++---------------- .../audio_processing/aec3/suppression_gain.h | 24 --- .../aec3/suppression_gain_unittest.cc | 74 -------- .../audio_processing/aec3/vector_math.h | 128 +++++++++++++ .../aec3/vector_math_unittest.cc | 87 +++++++++ 6 files changed, 236 insertions(+), 250 deletions(-) create mode 100644 webrtc/modules/audio_processing/aec3/vector_math.h create mode 100644 webrtc/modules/audio_processing/aec3/vector_math_unittest.cc diff --git a/webrtc/modules/audio_processing/BUILD.gn b/webrtc/modules/audio_processing/BUILD.gn index a57136a82e..5686dc6690 100644 --- a/webrtc/modules/audio_processing/BUILD.gn +++ b/webrtc/modules/audio_processing/BUILD.gn @@ -94,6 +94,7 @@ rtc_static_library("audio_processing") { "aec3/suppression_filter.h", "aec3/suppression_gain.cc", "aec3/suppression_gain.h", + "aec3/vector_math.h", "aecm/aecm_core.cc", "aecm/aecm_core.h", "aecm/echo_control_mobile.cc", @@ -601,6 +602,7 @@ if (rtc_include_tests) { "aec3/subtractor_unittest.cc", "aec3/suppression_filter_unittest.cc", "aec3/suppression_gain_unittest.cc", + "aec3/vector_math_unittest.cc", "audio_processing_impl_locking_unittest.cc", "audio_processing_impl_unittest.cc", "audio_processing_unittest.cc", diff --git a/webrtc/modules/audio_processing/aec3/suppression_gain.cc 
b/webrtc/modules/audio_processing/aec3/suppression_gain.cc index 4bf452cbbe..86af60f316 100644 --- a/webrtc/modules/audio_processing/aec3/suppression_gain.cc +++ b/webrtc/modules/audio_processing/aec3/suppression_gain.cc @@ -20,6 +20,7 @@ #include #include "webrtc/base/checks.h" +#include "webrtc/modules/audio_processing/aec3/vector_math.h" namespace webrtc { namespace { @@ -48,15 +49,9 @@ constexpr float kEchoMaskingMargin = 1.f / 20.f; constexpr float kBandMaskingFactor = 1.f / 10.f; constexpr float kTimeMaskingFactor = 1.f / 10.f; -} // namespace - -namespace aec3 { - -#if defined(WEBRTC_ARCH_X86_FAMILY) - -// Optimized SSE2 code for the gain computation. // TODO(peah): Add further optimizations, in particular for the divisions. -void ComputeGains_SSE2( +void ComputeGains( + Aec3Optimization optimization, const std::array& nearend_power, const std::array& residual_echo_power, const std::array& comfort_noise_power, @@ -70,6 +65,7 @@ void ComputeGains_SSE2( std::array strong_nearend; std::array neighboring_bands_masker; std::array* gain_squared = gain; + aec3::VectorMath math(optimization); // Precompute 1/residual_echo_power. std::transform(residual_echo_power.begin() + 1, residual_echo_power.end() - 1, @@ -94,21 +90,15 @@ void ComputeGains_SSE2( masker.begin()); } else { // Add masker for neighboring bands. 
- std::transform(nearend_power.begin(), nearend_power.end(), - gain_squared->begin(), neighboring_bands_masker.begin(), - std::multiplies()); - std::transform(neighboring_bands_masker.begin(), - neighboring_bands_masker.end(), - comfort_noise_power.begin(), - neighboring_bands_masker.begin(), std::plus()); + math.Multiply(nearend_power, *gain_squared, neighboring_bands_masker); + math.Accumulate(comfort_noise_power, neighboring_bands_masker); std::transform( neighboring_bands_masker.begin(), neighboring_bands_masker.end() - 2, neighboring_bands_masker.begin() + 2, masker.begin(), [&](float a, float b) { return kBandMaskingFactor * (a + b); }); // Add masker from the same band. - std::transform(same_band_masker.begin(), same_band_masker.end(), - masker.begin(), masker.begin(), std::plus()); + math.Accumulate(same_band_masker, masker); } // Compute new gain as: @@ -150,130 +140,17 @@ void ComputeGains_SSE2( std::copy(gain_squared->begin() + 1, gain_squared->end() - 1, previous_gain_squared->begin()); - std::transform(gain_squared->begin() + 1, gain_squared->end() - 1, - nearend_power.begin() + 1, previous_masker->begin(), - std::multiplies()); - std::transform(previous_masker->begin(), previous_masker->end(), - comfort_noise_power.begin() + 1, previous_masker->begin(), - std::plus()); - - for (size_t k = 0; k < kFftLengthBy2; k += 4) { - __m128 g = _mm_loadu_ps(&(*gain_squared)[k]); - g = _mm_sqrt_ps(g); - _mm_storeu_ps(&(*gain)[k], g); - } - - (*gain)[kFftLengthBy2] = sqrtf((*gain)[kFftLengthBy2]); + math.Multiply( + rtc::ArrayView(&(*gain_squared)[1], previous_masker->size()), + rtc::ArrayView(&nearend_power[1], previous_masker->size()), + *previous_masker); + math.Accumulate(rtc::ArrayView(&comfort_noise_power[1], + previous_masker->size()), + *previous_masker); + math.Sqrt(*gain); } -#endif - -void ComputeGains( - const std::array& nearend_power, - const std::array& residual_echo_power, - const std::array& comfort_noise_power, - float strong_nearend_margin, - 
std::array* previous_gain_squared, - std::array* previous_masker, - std::array* gain) { - std::array masker; - std::array same_band_masker; - std::array one_by_residual_echo_power; - std::array strong_nearend; - std::array neighboring_bands_masker; - std::array* gain_squared = gain; - - // Precompute 1/residual_echo_power. - std::transform(residual_echo_power.begin() + 1, residual_echo_power.end() - 1, - one_by_residual_echo_power.begin(), - [](float a) { return a > 0.f ? 1.f / a : -1.f; }); - - // Precompute indicators for bands with strong nearend. - std::transform( - residual_echo_power.begin() + 1, residual_echo_power.end() - 1, - nearend_power.begin() + 1, strong_nearend.begin(), - [&](float a, float b) { return a <= strong_nearend_margin * b; }); - - // Precompute masker for the same band. - std::transform(comfort_noise_power.begin() + 1, comfort_noise_power.end() - 1, - previous_masker->begin(), same_band_masker.begin(), - [&](float a, float b) { return a + kTimeMaskingFactor * b; }); - - for (int k = 0; k < kNumIterations; ++k) { - if (k == 0) { - // Add masker from the same band. - std::copy(same_band_masker.begin(), same_band_masker.end(), - masker.begin()); - } else { - // Add masker for neightboring bands. - std::transform(nearend_power.begin(), nearend_power.end(), - gain_squared->begin(), neighboring_bands_masker.begin(), - std::multiplies()); - std::transform(neighboring_bands_masker.begin(), - neighboring_bands_masker.end(), - comfort_noise_power.begin(), - neighboring_bands_masker.begin(), std::plus()); - std::transform( - neighboring_bands_masker.begin(), neighboring_bands_masker.end() - 2, - neighboring_bands_masker.begin() + 2, masker.begin(), - [&](float a, float b) { return kBandMaskingFactor * (a + b); }); - - // Add masker from the same band. 
- std::transform(same_band_masker.begin(), same_band_masker.end(), - masker.begin(), masker.begin(), std::plus()); - } - - // Compute new gain as: - // G2(t,f) = (comfort_noise_power(t,f) + G2(t-1)*nearend_power(t-1)) * - // kTimeMaskingFactor - // * kEchoMaskingMargin / residual_echo_power(t,f). - // or - // G2(t,f) = ((comfort_noise_power(t,f) + G2(t-1) * - // nearend_power(t-1)) * kTimeMaskingFactor + - // (comfort_noise_power(t, f-1) + comfort_noise_power(t, f+1) + - // (G2(t,f-1)*nearend_power(t, f-1) + - // G2(t,f+1)*nearend_power(t, f+1)) * - // kTimeMaskingFactor) * kBandMaskingFactor) - // * kEchoMaskingMargin / residual_echo_power(t,f). - std::transform( - masker.begin(), masker.end(), one_by_residual_echo_power.begin(), - gain_squared->begin() + 1, [&](float a, float b) { - return b >= 0 ? std::min(kEchoMaskingMargin * a * b, 1.f) : 1.f; - }); - - // Limit gain for bands with strong nearend. - std::transform(gain_squared->begin() + 1, gain_squared->end() - 1, - strong_nearend.begin(), gain_squared->begin() + 1, - [](float a, bool b) { return b ? 1.f : a; }); - - // Limit the allowed gain update over time. - std::transform(gain_squared->begin() + 1, gain_squared->end() - 1, - previous_gain_squared->begin(), gain_squared->begin() + 1, - [](float a, float b) { - return b < 0.001f ? std::min(a, 0.001f) - : std::min(a, b * 2.f); - }); - - // Process the gains to avoid artefacts caused by gain realization in the - // filterbank and impact of external pre-processing of the signal. 
- GainPostProcessing(gain_squared); - } - - std::copy(gain_squared->begin() + 1, gain_squared->end() - 1, - previous_gain_squared->begin()); - - std::transform(gain_squared->begin() + 1, gain_squared->end() - 1, - nearend_power.begin() + 1, previous_masker->begin(), - std::multiplies()); - std::transform(previous_masker->begin(), previous_masker->end(), - comfort_noise_power.begin() + 1, previous_masker->begin(), - std::plus()); - - std::transform(gain_squared->begin(), gain_squared->end(), gain->begin(), - [](float a) { return sqrtf(a); }); -} - -} // namespace aec3 +} // namespace // Computes an upper bound on the gain to apply for high frequencies. float HighFrequencyGainBound(bool saturated_echo, @@ -342,19 +219,9 @@ void SuppressionGain::GetGain( // Choose margin to use. const float margin = saturated_echo ? 0.001f : 0.01f; - switch (optimization_) { -#if defined(WEBRTC_ARCH_X86_FAMILY) - case Aec3Optimization::kSse2: - aec3::ComputeGains_SSE2( - nearend_power, residual_echo_power, comfort_noise_power, margin, - &previous_gain_squared_, &previous_masker_, low_band_gain); - break; -#endif - default: - aec3::ComputeGains(nearend_power, residual_echo_power, - comfort_noise_power, margin, &previous_gain_squared_, - &previous_masker_, low_band_gain); - } + ComputeGains(optimization_, nearend_power, residual_echo_power, + comfort_noise_power, margin, &previous_gain_squared_, + &previous_masker_, low_band_gain); if (num_capture_bands > 1) { // Compute the gain for upper frequencies. 
diff --git a/webrtc/modules/audio_processing/aec3/suppression_gain.h b/webrtc/modules/audio_processing/aec3/suppression_gain.h index d0b4114393..e4ad3fc714 100644 --- a/webrtc/modules/audio_processing/aec3/suppression_gain.h +++ b/webrtc/modules/audio_processing/aec3/suppression_gain.h @@ -18,30 +18,6 @@ #include "webrtc/modules/audio_processing/aec3/aec3_common.h" namespace webrtc { -namespace aec3 { -#if defined(WEBRTC_ARCH_X86_FAMILY) - -void ComputeGains_SSE2( - const std::array& nearend_power, - const std::array& residual_echo_power, - const std::array& comfort_noise_power, - float strong_nearend_margin, - std::array* previous_gain_squared, - std::array* previous_masker, - std::array* gain); - -#endif - -void ComputeGains( - const std::array& nearend_power, - const std::array& residual_echo_power, - const std::array& comfort_noise_power, - float strong_nearend_margin, - std::array* previous_gain_squared, - std::array* previous_masker, - std::array* gain); - -} // namespace aec3 class SuppressionGain { public: diff --git a/webrtc/modules/audio_processing/aec3/suppression_gain_unittest.cc b/webrtc/modules/audio_processing/aec3/suppression_gain_unittest.cc index 83c41e1254..1fd011f10c 100644 --- a/webrtc/modules/audio_processing/aec3/suppression_gain_unittest.cc +++ b/webrtc/modules/audio_processing/aec3/suppression_gain_unittest.cc @@ -39,80 +39,6 @@ TEST(SuppressionGain, NullOutputGains) { #endif -#if defined(WEBRTC_ARCH_X86_FAMILY) -// Verifies that the optimized methods are bitexact to their reference -// counterparts. 
-TEST(SuppressionGain, TestOptimizations) { - if (WebRtc_GetCPUInfo(kSSE2) != 0) { - std::array G2_old; - std::array M2_old; - std::array G2_old_SSE2; - std::array M2_old_SSE2; - std::array E2; - std::array R2; - std::array N2; - std::array g; - std::array g_SSE2; - - G2_old.fill(1.f); - M2_old.fill(.23f); - G2_old_SSE2.fill(1.f); - M2_old_SSE2.fill(.23f); - - E2.fill(10.f); - R2.fill(0.1f); - N2.fill(100.f); - for (int k = 0; k < 10; ++k) { - ComputeGains(E2, R2, N2, 0.1f, &G2_old, &M2_old, &g); - ComputeGains_SSE2(E2, R2, N2, 0.1f, &G2_old_SSE2, &M2_old_SSE2, &g_SSE2); - for (size_t j = 0; j < G2_old.size(); ++j) { - EXPECT_NEAR(G2_old[j], G2_old_SSE2[j], 0.0000001f); - } - for (size_t j = 0; j < M2_old.size(); ++j) { - EXPECT_NEAR(M2_old[j], M2_old_SSE2[j], 0.0000001f); - } - for (size_t j = 0; j < g.size(); ++j) { - EXPECT_NEAR(g[j], g_SSE2[j], 0.0000001f); - } - } - - E2.fill(100.f); - R2.fill(0.1f); - N2.fill(0.f); - for (int k = 0; k < 10; ++k) { - ComputeGains(E2, R2, N2, 0.1f, &G2_old, &M2_old, &g); - ComputeGains_SSE2(E2, R2, N2, 0.1f, &G2_old_SSE2, &M2_old_SSE2, &g_SSE2); - for (size_t j = 0; j < G2_old.size(); ++j) { - EXPECT_NEAR(G2_old[j], G2_old_SSE2[j], 0.0000001f); - } - for (size_t j = 0; j < M2_old.size(); ++j) { - EXPECT_NEAR(M2_old[j], M2_old_SSE2[j], 0.0000001f); - } - for (size_t j = 0; j < g.size(); ++j) { - EXPECT_NEAR(g[j], g_SSE2[j], 0.0000001f); - } - } - - E2.fill(0.1f); - R2.fill(100.f); - N2.fill(0.f); - for (int k = 0; k < 10; ++k) { - ComputeGains(E2, R2, N2, 0.1f, &G2_old, &M2_old, &g); - ComputeGains_SSE2(E2, R2, N2, 0.1f, &G2_old_SSE2, &M2_old_SSE2, &g_SSE2); - for (size_t j = 0; j < G2_old.size(); ++j) { - EXPECT_NEAR(G2_old[j], G2_old_SSE2[j], 0.0000001f); - } - for (size_t j = 0; j < M2_old.size(); ++j) { - EXPECT_NEAR(M2_old[j], M2_old_SSE2[j], 0.0000001f); - } - for (size_t j = 0; j < g.size(); ++j) { - EXPECT_NEAR(g[j], g_SSE2[j], 0.0000001f); - } - } - } -} -#endif - // Does a sanity check that the gains are correctly 
computed. TEST(SuppressionGain, BasicGainComputation) { SuppressionGain suppression_gain(DetectOptimization()); diff --git a/webrtc/modules/audio_processing/aec3/vector_math.h b/webrtc/modules/audio_processing/aec3/vector_math.h new file mode 100644 index 0000000000..afd4262b6b --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/vector_math.h @@ -0,0 +1,128 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_VECTOR_MATH_H_ +#define WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_VECTOR_MATH_H_ + +#include "webrtc/typedefs.h" +#if defined(WEBRTC_ARCH_X86_FAMILY) +#include +#endif +#include +#include +#include +#include + +#include "webrtc/base/array_view.h" +#include "webrtc/base/checks.h" +#include "webrtc/modules/audio_processing/aec3/aec3_common.h" + +namespace webrtc { +namespace aec3 { + +// Provides optimizations for mathematical operations based on vectors. +class VectorMath { + public: + explicit VectorMath(Aec3Optimization optimization) + : optimization_(optimization) {} + + // Elementwise square root. 
+ void Sqrt(rtc::ArrayView x) { + switch (optimization_) { +#if defined(WEBRTC_ARCH_X86_FAMILY) + case Aec3Optimization::kSse2: { + const int x_size = static_cast(x.size()); + const int vector_limit = x_size >> 2; + + int j = 0; + for (; j < vector_limit * 4; j += 4) { + __m128 g = _mm_loadu_ps(&x[j]); + g = _mm_sqrt_ps(g); + _mm_storeu_ps(&x[j], g); + } + + for (; j < x_size; ++j) { + x[j] = sqrtf(x[j]); + } + } break; +#endif + default: + std::for_each(x.begin(), x.end(), [](float& a) { a = sqrtf(a); }); + } + } + + // Elementwise vector multiplication z = x * y. + void Multiply(rtc::ArrayView x, + rtc::ArrayView y, + rtc::ArrayView z) { + RTC_DCHECK_EQ(z.size(), x.size()); + RTC_DCHECK_EQ(z.size(), y.size()); + switch (optimization_) { +#if defined(WEBRTC_ARCH_X86_FAMILY) + case Aec3Optimization::kSse2: { + const int x_size = static_cast(x.size()); + const int vector_limit = x_size >> 2; + + int j = 0; + for (; j < vector_limit * 4; j += 4) { + const __m128 x_j = _mm_loadu_ps(&x[j]); + const __m128 y_j = _mm_loadu_ps(&y[j]); + const __m128 z_j = _mm_mul_ps(x_j, y_j); + _mm_storeu_ps(&z[j], z_j); + } + + for (; j < x_size; ++j) { + z[j] = x[j] * y[j]; + } + } break; +#endif + default: + std::transform(x.begin(), x.end(), y.begin(), z.begin(), + std::multiplies()); + } + } + + // Elementwise vector accumulation z += x. 
+ void Accumulate(rtc::ArrayView x, rtc::ArrayView z) { + RTC_DCHECK_EQ(z.size(), x.size()); + switch (optimization_) { +#if defined(WEBRTC_ARCH_X86_FAMILY) + case Aec3Optimization::kSse2: { + const int x_size = static_cast(x.size()); + const int vector_limit = x_size >> 2; + + int j = 0; + for (; j < vector_limit * 4; j += 4) { + const __m128 x_j = _mm_loadu_ps(&x[j]); + __m128 z_j = _mm_loadu_ps(&z[j]); + z_j = _mm_add_ps(x_j, z_j); + _mm_storeu_ps(&z[j], z_j); + } + + for (; j < x_size; ++j) { + z[j] += x[j]; + } + } break; +#endif + default: + std::transform(x.begin(), x.end(), z.begin(), z.begin(), + std::plus()); + } + } + + private: + Aec3Optimization optimization_; +}; + +} // namespace aec3 + +} // namespace webrtc + +#endif // WEBRTC_MODULES_AUDIO_PROCESSING_AEC3_VECTOR_MATH_H_ diff --git a/webrtc/modules/audio_processing/aec3/vector_math_unittest.cc b/webrtc/modules/audio_processing/aec3/vector_math_unittest.cc new file mode 100644 index 0000000000..b40cf8d2a5 --- /dev/null +++ b/webrtc/modules/audio_processing/aec3/vector_math_unittest.cc @@ -0,0 +1,87 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ + +#include "webrtc/modules/audio_processing/aec3/vector_math.h" + +#include + +#include "webrtc/system_wrappers/include/cpu_features_wrapper.h" +#include "webrtc/test/gtest.h" +#include "webrtc/typedefs.h" + +namespace webrtc { + +#if defined(WEBRTC_ARCH_X86_FAMILY) + +TEST(VectorMath, Sqrt) { + if (WebRtc_GetCPUInfo(kSSE2) != 0) { + std::array x; + std::array z; + std::array z_sse2; + + for (size_t k = 0; k < x.size(); ++k) { + x[k] = (2.f / 3.f) * k; + } + + std::copy(x.begin(), x.end(), z.begin()); + aec3::VectorMath(Aec3Optimization::kNone).Sqrt(z); + std::copy(x.begin(), x.end(), z_sse2.begin()); + aec3::VectorMath(Aec3Optimization::kSse2).Sqrt(z_sse2); + EXPECT_EQ(z, z_sse2); + for (size_t k = 0; k < z.size(); ++k) { + EXPECT_FLOAT_EQ(z[k], z_sse2[k]); + EXPECT_FLOAT_EQ(sqrtf(x[k]), z_sse2[k]); + } + } +} + +TEST(VectorMath, Multiply) { + if (WebRtc_GetCPUInfo(kSSE2) != 0) { + std::array x; + std::array y; + std::array z; + std::array z_sse2; + + for (size_t k = 0; k < x.size(); ++k) { + x[k] = k; + y[k] = (2.f / 3.f) * k; + } + + aec3::VectorMath(Aec3Optimization::kNone).Multiply(x, y, z); + aec3::VectorMath(Aec3Optimization::kSse2).Multiply(x, y, z_sse2); + for (size_t k = 0; k < z.size(); ++k) { + EXPECT_FLOAT_EQ(z[k], z_sse2[k]); + EXPECT_FLOAT_EQ(x[k] * y[k], z_sse2[k]); + } + } +} + +TEST(VectorMath, Accumulate) { + if (WebRtc_GetCPUInfo(kSSE2) != 0) { + std::array x; + std::array z; + std::array z_sse2; + + for (size_t k = 0; k < x.size(); ++k) { + x[k] = k; + z[k] = z_sse2[k] = 2.f * k; + } + + aec3::VectorMath(Aec3Optimization::kNone).Accumulate(x, z); + aec3::VectorMath(Aec3Optimization::kSse2).Accumulate(x, z_sse2); + for (size_t k = 0; k < z.size(); ++k) { + EXPECT_FLOAT_EQ(z[k], z_sse2[k]); + EXPECT_FLOAT_EQ(x[k] + 2.f * x[k], z_sse2[k]); + } + } +} +#endif + +} // namespace webrtc