Finalized the SSE2 optimizations for the matched filter in AEC3
The SSE2 optimizations of the filter core in the matched filter was only half-done. This CL finalizes those. In particular: -It adds finalization of updating of the filter. -It removes the manual loop unrolling in order to reduce and simplify the code. Note that the changes pass the bitexactness tests in an external AEC3 test suite, and the test MatchedFilter.TestOptimizations succeed. BUG=webrtc:6018 Review-Url: https://codereview.webrtc.org/2813563003 Cr-Commit-Position: refs/heads/master@{#17655}
This commit is contained in:
parent
c0d74d9684
commit
b213a16b28
@ -31,50 +31,56 @@ void MatchedFilterCore_SSE2(size_t x_start_index,
|
||||
rtc::ArrayView<float> h,
|
||||
bool* filters_updated,
|
||||
float* error_sum) {
|
||||
const int h_size = static_cast<int>(h.size());
|
||||
const int x_size = static_cast<int>(x.size());
|
||||
RTC_DCHECK_EQ(0, h_size % 4);
|
||||
|
||||
// Process for all samples in the sub-block.
|
||||
for (size_t i = 0; i < kSubBlockSize; ++i) {
|
||||
// Apply the matched filter as filter * x. and compute x * x.
|
||||
float x2_sum = 0.f;
|
||||
float s = 0;
|
||||
size_t x_index = x_start_index;
|
||||
RTC_DCHECK_EQ(0, h.size() % 4);
|
||||
// Apply the matched filter as filter * x, and compute x * x.
|
||||
|
||||
RTC_DCHECK_GT(x_size, x_start_index);
|
||||
const float* x_p = &x[x_start_index];
|
||||
const float* h_p = &h[0];
|
||||
|
||||
// Initialize values for the accumulation.
|
||||
__m128 s_128 = _mm_set1_ps(0);
|
||||
__m128 x2_sum_128 = _mm_set1_ps(0);
|
||||
float x2_sum = 0.f;
|
||||
float s = 0;
|
||||
|
||||
size_t k = 0;
|
||||
if (h.size() > (x.size() - x_index)) {
|
||||
const size_t limit = x.size() - x_index;
|
||||
for (; (k + 3) < limit; k += 4, x_index += 4) {
|
||||
const __m128 x_k = _mm_loadu_ps(&x[x_index]);
|
||||
const __m128 h_k = _mm_loadu_ps(&h[k]);
|
||||
// Compute loop chunk sizes until, and after, the wraparound of the circular
|
||||
// buffer for x.
|
||||
const int chunk1 =
|
||||
std::min(h_size, static_cast<int>(x_size - x_start_index));
|
||||
|
||||
// Perform the loop in two chunks.
|
||||
const int chunk2 = h_size - chunk1;
|
||||
for (int limit : {chunk1, chunk2}) {
|
||||
// Perform 128 bit vector operations.
|
||||
const int limit_by_4 = limit >> 2;
|
||||
for (int k = limit_by_4; k > 0; --k, h_p += 4, x_p += 4) {
|
||||
// Load the data into 128 bit vectors.
|
||||
const __m128 x_k = _mm_loadu_ps(x_p);
|
||||
const __m128 h_k = _mm_loadu_ps(h_p);
|
||||
const __m128 xx = _mm_mul_ps(x_k, x_k);
|
||||
// Compute and accumulate x * x and h * x.
|
||||
x2_sum_128 = _mm_add_ps(x2_sum_128, xx);
|
||||
const __m128 hx = _mm_mul_ps(h_k, x_k);
|
||||
s_128 = _mm_add_ps(s_128, hx);
|
||||
}
|
||||
|
||||
for (; k < limit; ++k, ++x_index) {
|
||||
x2_sum += x[x_index] * x[x_index];
|
||||
s += h[k] * x[x_index];
|
||||
// Perform non-vector operations for any remaining items.
|
||||
for (int k = limit - limit_by_4 * 4; k > 0; --k, ++h_p, ++x_p) {
|
||||
const float x_k = *x_p;
|
||||
x2_sum += x_k * x_k;
|
||||
s += *h_p * x_k;
|
||||
}
|
||||
x_index = 0;
|
||||
}
|
||||
|
||||
for (; k + 3 < h.size(); k += 4, x_index += 4) {
|
||||
const __m128 x_k = _mm_loadu_ps(&x[x_index]);
|
||||
const __m128 h_k = _mm_loadu_ps(&h[k]);
|
||||
const __m128 xx = _mm_mul_ps(x_k, x_k);
|
||||
x2_sum_128 = _mm_add_ps(x2_sum_128, xx);
|
||||
const __m128 hx = _mm_mul_ps(h_k, x_k);
|
||||
s_128 = _mm_add_ps(s_128, hx);
|
||||
}
|
||||
|
||||
for (; k < h.size(); ++k, ++x_index) {
|
||||
x2_sum += x[x_index] * x[x_index];
|
||||
s += h[k] * x[x_index];
|
||||
|
||||
x_p = &x[0];
|
||||
}
|
||||
|
||||
// Combine the accumulated vector and scalar values.
|
||||
float* v = reinterpret_cast<float*>(&x2_sum_128);
|
||||
x2_sum += v[0] + v[1] + v[2] + v[3];
|
||||
v = reinterpret_cast<float*>(&s_128);
|
||||
@ -82,23 +88,47 @@ void MatchedFilterCore_SSE2(size_t x_start_index,
|
||||
|
||||
// Compute the matched filter error.
|
||||
const float e = std::min(32767.f, std::max(-32768.f, y[i] - s));
|
||||
(*error_sum) += e * e;
|
||||
*error_sum += e * e;
|
||||
|
||||
// Update the matched filter estimate in an NLMS manner.
|
||||
if (x2_sum > x2_sum_threshold) {
|
||||
RTC_DCHECK_LT(0.f, x2_sum);
|
||||
const float alpha = 0.7f * e / x2_sum;
|
||||
const __m128 alpha_128 = _mm_set1_ps(alpha);
|
||||
|
||||
// filter = filter + 0.7 * (y - filter * x) / x * x.
|
||||
size_t x_index = x_start_index;
|
||||
for (size_t k = 0; k < h.size(); ++k) {
|
||||
h[k] += alpha * x[x_index];
|
||||
x_index = x_index < (x.size() - 1) ? x_index + 1 : 0;
|
||||
float* h_p = &h[0];
|
||||
x_p = &x[x_start_index];
|
||||
|
||||
// Perform the loop in two chunks.
|
||||
for (int limit : {chunk1, chunk2}) {
|
||||
// Perform 128 bit vector operations.
|
||||
const int limit_by_4 = limit >> 2;
|
||||
for (int k = limit_by_4; k > 0; --k, h_p += 4, x_p += 4) {
|
||||
// Load the data into 128 bit vectors.
|
||||
__m128 h_k = _mm_loadu_ps(h_p);
|
||||
const __m128 x_k = _mm_loadu_ps(x_p);
|
||||
|
||||
// Compute h = h + alpha * x.
|
||||
const __m128 alpha_x = _mm_mul_ps(alpha_128, x_k);
|
||||
h_k = _mm_add_ps(h_k, alpha_x);
|
||||
|
||||
// Store the result.
|
||||
_mm_storeu_ps(h_p, h_k);
|
||||
}
|
||||
|
||||
// Perform non-vector operations for any remaining items.
|
||||
for (int k = limit - limit_by_4 * 4; k > 0; --k, ++h_p, ++x_p) {
|
||||
*h_p += alpha * *x_p;
|
||||
}
|
||||
|
||||
x_p = &x[0];
|
||||
}
|
||||
|
||||
*filters_updated = true;
|
||||
}
|
||||
|
||||
x_start_index = x_start_index > 0 ? x_start_index - 1 : x.size() - 1;
|
||||
x_start_index = x_start_index > 0 ? x_start_index - 1 : x_size - 1;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
@ -112,7 +142,7 @@ void MatchedFilterCore(size_t x_start_index,
|
||||
float* error_sum) {
|
||||
// Process for all samples in the sub-block.
|
||||
for (size_t i = 0; i < kSubBlockSize; ++i) {
|
||||
// Apply the matched filter as filter * x. and compute x * x.
|
||||
// Apply the matched filter as filter * x, and compute x * x.
|
||||
float x2_sum = 0.f;
|
||||
float s = 0;
|
||||
size_t x_index = x_start_index;
|
||||
|
||||
@ -74,7 +74,7 @@ TEST(MatchedFilter, TestOptimizations) {
|
||||
EXPECT_NEAR(error_sum, error_sum_SSE2, error_sum / 100000.f);
|
||||
|
||||
for (size_t j = 0; j < h.size(); ++j) {
|
||||
EXPECT_NEAR(h[j], h_SSE2[j], 0.001f);
|
||||
EXPECT_NEAR(h[j], h_SSE2[j], 0.00001f);
|
||||
}
|
||||
|
||||
x_index = (x_index + kSubBlockSize) % x.size();
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user