diff --git a/webrtc/modules/audio_processing/aec3/adaptive_fir_filter.cc b/webrtc/modules/audio_processing/aec3/adaptive_fir_filter.cc index 7f66ce5c94..3174fa762e 100644 --- a/webrtc/modules/audio_processing/aec3/adaptive_fir_filter.cc +++ b/webrtc/modules/audio_processing/aec3/adaptive_fir_filter.cc @@ -36,10 +36,15 @@ void Constrain(const Aec3Fft& fft, FftData* H) { fft.Fft(&h, H); } +} // namespace + +namespace aec3 { + // Computes and stores the frequency response of the filter. void UpdateFrequencyResponse( rtc::ArrayView H, std::vector>* H2) { + RTC_DCHECK_EQ(H.size(), H2->size()); for (size_t k = 0; k < H.size(); ++k) { std::transform(H[k].re.begin(), H[k].re.end(), H[k].im.begin(), (*H2)[k].begin(), @@ -47,6 +52,27 @@ void UpdateFrequencyResponse( } } +#if defined(WEBRTC_ARCH_X86_FAMILY) +// Computes and stores the frequency response of the filter. +void UpdateFrequencyResponse_SSE2( + rtc::ArrayView H, + std::vector>* H2) { + RTC_DCHECK_EQ(H.size(), H2->size()); + for (size_t k = 0; k < H.size(); ++k) { + for (size_t j = 0; j < kFftLengthBy2; j += 4) { + const __m128 re = _mm_loadu_ps(&H[k].re[j]); + const __m128 re2 = _mm_mul_ps(re, re); + const __m128 im = _mm_loadu_ps(&H[k].im[j]); + const __m128 im2 = _mm_mul_ps(im, im); + const __m128 H2_k_j = _mm_add_ps(re2, im2); + _mm_storeu_ps(&(*H2)[k][j], H2_k_j); + } + (*H2)[k][kFftLengthBy2] = H[k].re[kFftLengthBy2] * H[k].re[kFftLengthBy2] + + H[k].im[kFftLengthBy2] * H[k].im[kFftLengthBy2]; + } +} +#endif + // Computes and stores the echo return loss estimate of the filter, which is the // sum of the partition frequency responses. void UpdateErlEstimator( @@ -59,9 +85,24 @@ void UpdateErlEstimator( } } -} // namespace - -namespace aec3 { +#if defined(WEBRTC_ARCH_X86_FAMILY) +// Computes and stores the echo return loss estimate of the filter, which is the +// sum of the partition frequency responses. +void UpdateErlEstimator_SSE2( + const std::vector>& H2, + std::array* erl) { + erl->fill(0.f); + for (auto& H2_j : H2) { + for (size_t k = 0; k < kFftLengthBy2; k += 4) { + const __m128 H2_j_k = _mm_loadu_ps(&H2_j[k]); + __m128 erl_k = _mm_loadu_ps(&(*erl)[k]); + erl_k = _mm_add_ps(erl_k, H2_j_k); + _mm_storeu_ps(&(*erl)[k], erl_k); + } + (*erl)[kFftLengthBy2] += H2_j[kFftLengthBy2]; + } +} +#endif // Adapts the filter partitions as H(t+1)=H(t)+G(t)*conj(X(t)). void AdaptPartitions(const RenderBuffer& render_buffer, @@ -290,8 +331,17 @@ void AdaptiveFirFilter::Adapt(const RenderBuffer& render_buffer, : 0; // Update the frequency response and echo return loss for the filter. - UpdateFrequencyResponse(H_, &H2_); - UpdateErlEstimator(H2_, &erl_); + switch (optimization_) { +#if defined(WEBRTC_ARCH_X86_FAMILY) + case Aec3Optimization::kSse2: + aec3::UpdateFrequencyResponse_SSE2(H_, &H2_); + aec3::UpdateErlEstimator_SSE2(H2_, &erl_); + break; +#endif + default: + aec3::UpdateFrequencyResponse(H_, &H2_); + aec3::UpdateErlEstimator(H2_, &erl_); + } } } // namespace webrtc diff --git a/webrtc/modules/audio_processing/aec3/adaptive_fir_filter.h b/webrtc/modules/audio_processing/aec3/adaptive_fir_filter.h index 4fe10eabbb..78b64225bd 100644 --- a/webrtc/modules/audio_processing/aec3/adaptive_fir_filter.h +++ b/webrtc/modules/audio_processing/aec3/adaptive_fir_filter.h @@ -25,6 +25,27 @@ namespace webrtc { namespace aec3 { +// Computes and stores the frequency response of the filter. +void UpdateFrequencyResponse( + rtc::ArrayView H, + std::vector>* H2); +#if defined(WEBRTC_ARCH_X86_FAMILY) +void UpdateFrequencyResponse_SSE2( + rtc::ArrayView H, + std::vector>* H2); +#endif + +// Computes and stores the echo return loss estimate of the filter, which is the +// sum of the partition frequency responses. +void UpdateErlEstimator( + const std::vector>& H2, + std::array* erl); +#if defined(WEBRTC_ARCH_X86_FAMILY) +void UpdateErlEstimator_SSE2( + const std::vector>& H2, + std::array* erl); +#endif + // Adapts the filter partitions. void AdaptPartitions(const RenderBuffer& render_buffer, const FftData& G, diff --git a/webrtc/modules/audio_processing/aec3/adaptive_fir_filter_unittest.cc b/webrtc/modules/audio_processing/aec3/adaptive_fir_filter_unittest.cc index 85d9769bf3..6d1a5820fe 100644 --- a/webrtc/modules/audio_processing/aec3/adaptive_fir_filter_unittest.cc +++ b/webrtc/modules/audio_processing/aec3/adaptive_fir_filter_unittest.cc @@ -42,9 +42,9 @@ std::string ProduceDebugText(size_t delay) { } // namespace #if defined(WEBRTC_ARCH_X86_FAMILY) -// Verifies that the optimized methods are bitexact to their reference -// counterparts. -TEST(AdaptiveFirFilter, TestOptimizations) { +// Verifies that the optimized methods for filter adaptation are bitexact to +// their reference counterparts. +TEST(AdaptiveFirFilter, FilterAdaptationOptimizations) { bool use_sse2 = (WebRtc_GetCPUInfo(kSSE2) != 0); if (use_sse2) { RenderBuffer render_buffer(Aec3Optimization::kNone, 3, 12, @@ -93,6 +93,59 @@ TEST(AdaptiveFirFilter, TestOptimizations) { } } +// Verifies that the optimized method for frequency response computation is +// bitexact to the reference counterpart. +TEST(AdaptiveFirFilter, UpdateFrequencyResponseOptimization) { + bool use_sse2 = (WebRtc_GetCPUInfo(kSSE2) != 0); + if (use_sse2) { + const size_t kNumPartitions = 12; + std::vector H(kNumPartitions); + std::vector> H2(kNumPartitions); + std::vector> H2_SSE2(kNumPartitions); + + for (size_t j = 0; j < H.size(); ++j) { + for (size_t k = 0; k < H[j].re.size(); ++k) { + H[j].re[k] = k + j / 3.f; + H[j].im[k] = j + k / 7.f; + } + } + + UpdateFrequencyResponse(H, &H2); + UpdateFrequencyResponse_SSE2(H, &H2_SSE2); + + for (size_t j = 0; j < H2.size(); ++j) { + for (size_t k = 0; k < H[j].re.size(); ++k) { + EXPECT_FLOAT_EQ(H2[j][k], H2_SSE2[j][k]); + } + } + } +} + +// Verifies that the optimized method for echo return loss computation is +// bitexact to the reference counterpart. +TEST(AdaptiveFirFilter, UpdateErlOptimization) { + bool use_sse2 = (WebRtc_GetCPUInfo(kSSE2) != 0); + if (use_sse2) { + const size_t kNumPartitions = 12; + std::vector> H2(kNumPartitions); + std::array erl; + std::array erl_SSE2; + + for (size_t j = 0; j < H2.size(); ++j) { + for (size_t k = 0; k < H2[j].size(); ++k) { + H2[j][k] = k + j / 3.f; + } + } + + UpdateErlEstimator(H2, &erl); + UpdateErlEstimator_SSE2(H2, &erl_SSE2); + + for (size_t j = 0; j < erl.size(); ++j) { + EXPECT_FLOAT_EQ(erl[j], erl_SSE2[j]); + } + } +} + #endif #if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID)