From ce702dbbe4e86d5248a6ad4b13475cc166dd02a8 Mon Sep 17 00:00:00 2001 From: cschuldt Date: Wed, 8 Dec 2021 09:49:00 +0100 Subject: [PATCH] Optimize SSE2- & AVX2 parts of the matched filter further. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Manually unrolling the multiply-and-accumulate loop of the matched filter allows interleaving of instruction, which gives a significant saving. Bug: None Change-Id: Ie7a7d92bd453d81e9dd61812781a7b6d62e1f1f4 Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/240321 Reviewed-by: Ivo Creusen Commit-Queue: Christian Schuldt Reviewed-by: Per Ã…hgren Cr-Commit-Position: refs/heads/main@{#35566} --- modules/audio_processing/aec3/matched_filter.cc | 16 +++++++++++++--- .../audio_processing/aec3/matched_filter_avx2.cc | 14 +++++++++++--- 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/modules/audio_processing/aec3/matched_filter.cc b/modules/audio_processing/aec3/matched_filter.cc index c6747cb0cd..794381cc8a 100644 --- a/modules/audio_processing/aec3/matched_filter.cc +++ b/modules/audio_processing/aec3/matched_filter.cc @@ -166,7 +166,9 @@ void MatchedFilterCore_SSE2(size_t x_start_index, // Initialize values for the accumulation. __m128 s_128 = _mm_set1_ps(0); + __m128 s_128_4 = _mm_set1_ps(0); __m128 x2_sum_128 = _mm_set1_ps(0); + __m128 x2_sum_128_4 = _mm_set1_ps(0); float x2_sum = 0.f; float s = 0; @@ -179,20 +181,26 @@ void MatchedFilterCore_SSE2(size_t x_start_index, const int chunk2 = h_size - chunk1; for (int limit : {chunk1, chunk2}) { // Perform 128 bit vector operations. - const int limit_by_4 = limit >> 2; - for (int k = limit_by_4; k > 0; --k, h_p += 4, x_p += 4) { + const int limit_by_8 = limit >> 3; + for (int k = limit_by_8; k > 0; --k, h_p += 8, x_p += 8) { // Load the data into 128 bit vectors. const __m128 x_k = _mm_loadu_ps(x_p); const __m128 h_k = _mm_loadu_ps(h_p); + const __m128 x_k_4 = _mm_loadu_ps(x_p + 4); + const __m128 h_k_4 = _mm_loadu_ps(h_p + 4); const __m128 xx = _mm_mul_ps(x_k, x_k); + const __m128 xx_4 = _mm_mul_ps(x_k_4, x_k_4); // Compute and accumulate x * x and h * x. x2_sum_128 = _mm_add_ps(x2_sum_128, xx); + x2_sum_128_4 = _mm_add_ps(x2_sum_128_4, xx_4); const __m128 hx = _mm_mul_ps(h_k, x_k); + const __m128 hx_4 = _mm_mul_ps(h_k_4, x_k_4); s_128 = _mm_add_ps(s_128, hx); + s_128_4 = _mm_add_ps(s_128_4, hx_4); } // Perform non-vector operations for any remaining items. - for (int k = limit - limit_by_4 * 4; k > 0; --k, ++h_p, ++x_p) { + for (int k = limit - limit_by_8 * 8; k > 0; --k, ++h_p, ++x_p) { const float x_k = *x_p; x2_sum += x_k * x_k; s += *h_p * x_k; @@ -202,8 +210,10 @@ void MatchedFilterCore_SSE2(size_t x_start_index, } // Combine the accumulated vector and scalar values. + x2_sum_128 = _mm_add_ps(x2_sum_128, x2_sum_128_4); float* v = reinterpret_cast(&x2_sum_128); x2_sum += v[0] + v[1] + v[2] + v[3]; + s_128 = _mm_add_ps(s_128, s_128_4); v = reinterpret_cast(&s_128); s += v[0] + v[1] + v[2] + v[3]; diff --git a/modules/audio_processing/aec3/matched_filter_avx2.cc b/modules/audio_processing/aec3/matched_filter_avx2.cc index ed32102aa4..8b7010f1dc 100644 --- a/modules/audio_processing/aec3/matched_filter_avx2.cc +++ b/modules/audio_processing/aec3/matched_filter_avx2.cc @@ -39,7 +39,9 @@ void MatchedFilterCore_AVX2(size_t x_start_index, // Initialize values for the accumulation. __m256 s_256 = _mm256_set1_ps(0); + __m256 s_256_8 = _mm256_set1_ps(0); __m256 x2_sum_256 = _mm256_set1_ps(0); + __m256 x2_sum_256_8 = _mm256_set1_ps(0); float x2_sum = 0.f; float s = 0; @@ -52,18 +54,22 @@ void MatchedFilterCore_AVX2(size_t x_start_index, const int chunk2 = h_size - chunk1; for (int limit : {chunk1, chunk2}) { // Perform 256 bit vector operations. - const int limit_by_8 = limit >> 3; - for (int k = limit_by_8; k > 0; --k, h_p += 8, x_p += 8) { + const int limit_by_16 = limit >> 4; + for (int k = limit_by_16; k > 0; --k, h_p += 16, x_p += 16) { // Load the data into 256 bit vectors. __m256 x_k = _mm256_loadu_ps(x_p); __m256 h_k = _mm256_loadu_ps(h_p); + __m256 x_k_8 = _mm256_loadu_ps(x_p + 8); + __m256 h_k_8 = _mm256_loadu_ps(h_p + 8); // Compute and accumulate x * x and h * x. x2_sum_256 = _mm256_fmadd_ps(x_k, x_k, x2_sum_256); + x2_sum_256_8 = _mm256_fmadd_ps(x_k_8, x_k_8, x2_sum_256_8); s_256 = _mm256_fmadd_ps(h_k, x_k, s_256); + s_256_8 = _mm256_fmadd_ps(h_k_8, x_k_8, s_256_8); } // Perform non-vector operations for any remaining items. - for (int k = limit - limit_by_8 * 8; k > 0; --k, ++h_p, ++x_p) { + for (int k = limit - limit_by_16 * 16; k > 0; --k, ++h_p, ++x_p) { const float x_k = *x_p; x2_sum += x_k * x_k; s += *h_p * x_k; @@ -73,6 +79,8 @@ void MatchedFilterCore_AVX2(size_t x_start_index, } // Sum components together. + x2_sum_256 = _mm256_add_ps(x2_sum_256, x2_sum_256_8); + s_256 = _mm256_add_ps(s_256, s_256_8); __m128 x2_sum_128 = _mm_add_ps(_mm256_extractf128_ps(x2_sum_256, 0), _mm256_extractf128_ps(x2_sum_256, 1)); __m128 s_128 = _mm_add_ps(_mm256_extractf128_ps(s_256, 0),