diff --git a/modules/audio_processing/aec3/matched_filter.cc b/modules/audio_processing/aec3/matched_filter.cc index c6747cb0cd..794381cc8a 100644 --- a/modules/audio_processing/aec3/matched_filter.cc +++ b/modules/audio_processing/aec3/matched_filter.cc @@ -166,7 +166,9 @@ void MatchedFilterCore_SSE2(size_t x_start_index, // Initialize values for the accumulation. __m128 s_128 = _mm_set1_ps(0); + __m128 s_128_4 = _mm_set1_ps(0); __m128 x2_sum_128 = _mm_set1_ps(0); + __m128 x2_sum_128_4 = _mm_set1_ps(0); float x2_sum = 0.f; float s = 0; @@ -179,20 +181,26 @@ void MatchedFilterCore_SSE2(size_t x_start_index, const int chunk2 = h_size - chunk1; for (int limit : {chunk1, chunk2}) { // Perform 128 bit vector operations. - const int limit_by_4 = limit >> 2; - for (int k = limit_by_4; k > 0; --k, h_p += 4, x_p += 4) { + const int limit_by_8 = limit >> 3; + for (int k = limit_by_8; k > 0; --k, h_p += 8, x_p += 8) { // Load the data into 128 bit vectors. const __m128 x_k = _mm_loadu_ps(x_p); const __m128 h_k = _mm_loadu_ps(h_p); + const __m128 x_k_4 = _mm_loadu_ps(x_p + 4); + const __m128 h_k_4 = _mm_loadu_ps(h_p + 4); const __m128 xx = _mm_mul_ps(x_k, x_k); + const __m128 xx_4 = _mm_mul_ps(x_k_4, x_k_4); // Compute and accumulate x * x and h * x. x2_sum_128 = _mm_add_ps(x2_sum_128, xx); + x2_sum_128_4 = _mm_add_ps(x2_sum_128_4, xx_4); const __m128 hx = _mm_mul_ps(h_k, x_k); + const __m128 hx_4 = _mm_mul_ps(h_k_4, x_k_4); s_128 = _mm_add_ps(s_128, hx); + s_128_4 = _mm_add_ps(s_128_4, hx_4); } // Perform non-vector operations for any remaining items. - for (int k = limit - limit_by_4 * 4; k > 0; --k, ++h_p, ++x_p) { + for (int k = limit - limit_by_8 * 8; k > 0; --k, ++h_p, ++x_p) { const float x_k = *x_p; x2_sum += x_k * x_k; s += *h_p * x_k; @@ -202,8 +210,10 @@ void MatchedFilterCore_SSE2(size_t x_start_index, } // Combine the accumulated vector and scalar values. 
+ x2_sum_128 = _mm_add_ps(x2_sum_128, x2_sum_128_4); float* v = reinterpret_cast<float*>(&x2_sum_128); x2_sum += v[0] + v[1] + v[2] + v[3]; + s_128 = _mm_add_ps(s_128, s_128_4); v = reinterpret_cast<float*>(&s_128); s += v[0] + v[1] + v[2] + v[3]; diff --git a/modules/audio_processing/aec3/matched_filter_avx2.cc b/modules/audio_processing/aec3/matched_filter_avx2.cc index ed32102aa4..8b7010f1dc 100644 --- a/modules/audio_processing/aec3/matched_filter_avx2.cc +++ b/modules/audio_processing/aec3/matched_filter_avx2.cc @@ -39,7 +39,9 @@ void MatchedFilterCore_AVX2(size_t x_start_index, // Initialize values for the accumulation. __m256 s_256 = _mm256_set1_ps(0); + __m256 s_256_8 = _mm256_set1_ps(0); __m256 x2_sum_256 = _mm256_set1_ps(0); + __m256 x2_sum_256_8 = _mm256_set1_ps(0); float x2_sum = 0.f; float s = 0; @@ -52,18 +54,22 @@ void MatchedFilterCore_AVX2(size_t x_start_index, const int chunk2 = h_size - chunk1; for (int limit : {chunk1, chunk2}) { // Perform 256 bit vector operations. - const int limit_by_8 = limit >> 3; - for (int k = limit_by_8; k > 0; --k, h_p += 8, x_p += 8) { + const int limit_by_16 = limit >> 4; + for (int k = limit_by_16; k > 0; --k, h_p += 16, x_p += 16) { // Load the data into 256 bit vectors. __m256 x_k = _mm256_loadu_ps(x_p); __m256 h_k = _mm256_loadu_ps(h_p); + __m256 x_k_8 = _mm256_loadu_ps(x_p + 8); + __m256 h_k_8 = _mm256_loadu_ps(h_p + 8); // Compute and accumulate x * x and h * x. x2_sum_256 = _mm256_fmadd_ps(x_k, x_k, x2_sum_256); + x2_sum_256_8 = _mm256_fmadd_ps(x_k_8, x_k_8, x2_sum_256_8); s_256 = _mm256_fmadd_ps(h_k, x_k, s_256); + s_256_8 = _mm256_fmadd_ps(h_k_8, x_k_8, s_256_8); } // Perform non-vector operations for any remaining items. - for (int k = limit - limit_by_8 * 8; k > 0; --k, ++h_p, ++x_p) { + for (int k = limit - limit_by_16 * 16; k > 0; --k, ++h_p, ++x_p) { const float x_k = *x_p; x2_sum += x_k * x_k; s += *h_p * x_k; @@ -73,6 +79,8 @@ void MatchedFilterCore_AVX2(size_t x_start_index, } // Sum components together. 
+ x2_sum_256 = _mm256_add_ps(x2_sum_256, x2_sum_256_8); + s_256 = _mm256_add_ps(s_256, s_256_8); __m128 x2_sum_128 = _mm_add_ps(_mm256_extractf128_ps(x2_sum_256, 0), _mm256_extractf128_ps(x2_sum_256, 1)); __m128 s_128 = _mm_add_ps(_mm256_extractf128_ps(s_256, 0),