Further SSE2 optimizations for the AEC3 adaptive filter.

This CL adds further SSE2 optimizations for the AEC3
adaptive filter.

The changes are bitexact

BUG=webrtc:6018

Review-Url: https://codereview.webrtc.org/2810133002
Cr-Commit-Position: refs/heads/master@{#17667}
This commit is contained in:
peah 2017-04-12 03:04:09 -07:00 committed by Commit bot
parent 1c223b2f75
commit 69ffdf4938
3 changed files with 132 additions and 8 deletions

View File

@ -36,10 +36,15 @@ void Constrain(const Aec3Fft& fft, FftData* H) {
fft.Fft(&h, H);
}
} // namespace
namespace aec3 {
// Computes and stores the frequency response of the filter.
void UpdateFrequencyResponse(
rtc::ArrayView<const FftData> H,
std::vector<std::array<float, kFftLengthBy2Plus1>>* H2) {
RTC_DCHECK_EQ(H.size(), H2->size());
for (size_t k = 0; k < H.size(); ++k) {
std::transform(H[k].re.begin(), H[k].re.end(), H[k].im.begin(),
(*H2)[k].begin(),
@ -47,6 +52,27 @@ void UpdateFrequencyResponse(
}
}
#if defined(WEBRTC_ARCH_X86_FAMILY)
// Computes and stores the frequency response of the filter.
void UpdateFrequencyResponse_SSE2(
rtc::ArrayView<const FftData> H,
std::vector<std::array<float, kFftLengthBy2Plus1>>* H2) {
RTC_DCHECK_EQ(H.size(), H2->size());
for (size_t k = 0; k < H.size(); ++k) {
for (size_t j = 0; j < kFftLengthBy2; j += 4) {
const __m128 re = _mm_loadu_ps(&H[k].re[j]);
const __m128 re2 = _mm_mul_ps(re, re);
const __m128 im = _mm_loadu_ps(&H[k].im[j]);
const __m128 im2 = _mm_mul_ps(im, im);
const __m128 H2_k_j = _mm_add_ps(re2, im2);
_mm_storeu_ps(&(*H2)[k][j], H2_k_j);
}
(*H2)[k][kFftLengthBy2] = H[k].re[kFftLengthBy2] * H[k].re[kFftLengthBy2] +
H[k].im[kFftLengthBy2] * H[k].im[kFftLengthBy2];
}
}
#endif
// Computes and stores the echo return loss estimate of the filter, which is the
// sum of the partition frequency responses.
void UpdateErlEstimator(
@ -59,9 +85,24 @@ void UpdateErlEstimator(
}
}
} // namespace
namespace aec3 {
#if defined(WEBRTC_ARCH_X86_FAMILY)
// Computes and stores the echo return loss estimate of the filter, which is the
// sum of the partition frequency responses.
void UpdateErlEstimator_SSE2(
const std::vector<std::array<float, kFftLengthBy2Plus1>>& H2,
std::array<float, kFftLengthBy2Plus1>* erl) {
erl->fill(0.f);
for (auto& H2_j : H2) {
for (size_t k = 0; k < kFftLengthBy2; k += 4) {
const __m128 H2_j_k = _mm_loadu_ps(&H2_j[k]);
__m128 erl_k = _mm_loadu_ps(&(*erl)[k]);
erl_k = _mm_add_ps(erl_k, H2_j_k);
_mm_storeu_ps(&(*erl)[k], erl_k);
}
(*erl)[kFftLengthBy2] += H2_j[kFftLengthBy2];
}
}
#endif
// Adapts the filter partitions as H(t+1)=H(t)+G(t)*conj(X(t)).
void AdaptPartitions(const RenderBuffer& render_buffer,
@ -290,8 +331,17 @@ void AdaptiveFirFilter::Adapt(const RenderBuffer& render_buffer,
: 0;
// Update the frequency response and echo return loss for the filter.
UpdateFrequencyResponse(H_, &H2_);
UpdateErlEstimator(H2_, &erl_);
switch (optimization_) {
#if defined(WEBRTC_ARCH_X86_FAMILY)
case Aec3Optimization::kSse2:
aec3::UpdateFrequencyResponse_SSE2(H_, &H2_);
aec3::UpdateErlEstimator_SSE2(H2_, &erl_);
break;
#endif
default:
aec3::UpdateFrequencyResponse(H_, &H2_);
aec3::UpdateErlEstimator(H2_, &erl_);
}
}
} // namespace webrtc

View File

@ -25,6 +25,27 @@
namespace webrtc {
namespace aec3 {
// Computes and stores the frequency response of the filter.
void UpdateFrequencyResponse(
rtc::ArrayView<const FftData> H,
std::vector<std::array<float, kFftLengthBy2Plus1>>* H2);
#if defined(WEBRTC_ARCH_X86_FAMILY)
void UpdateFrequencyResponse_SSE2(
rtc::ArrayView<const FftData> H,
std::vector<std::array<float, kFftLengthBy2Plus1>>* H2);
#endif
// Computes and stores the echo return loss estimate of the filter, which is the
// sum of the partition frequency responses.
void UpdateErlEstimator(
const std::vector<std::array<float, kFftLengthBy2Plus1>>& H2,
std::array<float, kFftLengthBy2Plus1>* erl);
#if defined(WEBRTC_ARCH_X86_FAMILY)
void UpdateErlEstimator_SSE2(
const std::vector<std::array<float, kFftLengthBy2Plus1>>& H2,
std::array<float, kFftLengthBy2Plus1>* erl);
#endif
// Adapts the filter partitions.
void AdaptPartitions(const RenderBuffer& render_buffer,
const FftData& G,

View File

@ -42,9 +42,9 @@ std::string ProduceDebugText(size_t delay) {
} // namespace
#if defined(WEBRTC_ARCH_X86_FAMILY)
// Verifies that the optimized methods are bitexact to their reference
// counterparts.
TEST(AdaptiveFirFilter, TestOptimizations) {
// Verifies that the optimized methods for filter adaptation are bitexact to
// their reference counterparts.
TEST(AdaptiveFirFilter, FilterAdaptationOptimizations) {
bool use_sse2 = (WebRtc_GetCPUInfo(kSSE2) != 0);
if (use_sse2) {
RenderBuffer render_buffer(Aec3Optimization::kNone, 3, 12,
@ -93,6 +93,59 @@ TEST(AdaptiveFirFilter, TestOptimizations) {
}
}
// Verifies that the optimized method for frequency response computation is
// bitexact to the reference counterpart.
TEST(AdaptiveFirFilter, UpdateFrequencyResponseOptimization) {
bool use_sse2 = (WebRtc_GetCPUInfo(kSSE2) != 0);
if (use_sse2) {
const size_t kNumPartitions = 12;
std::vector<FftData> H(kNumPartitions);
std::vector<std::array<float, kFftLengthBy2Plus1>> H2(kNumPartitions);
std::vector<std::array<float, kFftLengthBy2Plus1>> H2_SSE2(kNumPartitions);
for (size_t j = 0; j < H.size(); ++j) {
for (size_t k = 0; k < H[j].re.size(); ++k) {
H[j].re[k] = k + j / 3.f;
H[j].im[k] = j + k / 7.f;
}
}
UpdateFrequencyResponse(H, &H2);
UpdateFrequencyResponse_SSE2(H, &H2_SSE2);
for (size_t j = 0; j < H2.size(); ++j) {
for (size_t k = 0; k < H[j].re.size(); ++k) {
EXPECT_FLOAT_EQ(H2[j][k], H2_SSE2[j][k]);
}
}
}
}
// Verifies that the optimized method for echo return loss computation is
// bitexact to the reference counterpart.
TEST(AdaptiveFirFilter, UpdateErlOptimization) {
bool use_sse2 = (WebRtc_GetCPUInfo(kSSE2) != 0);
if (use_sse2) {
const size_t kNumPartitions = 12;
std::vector<std::array<float, kFftLengthBy2Plus1>> H2(kNumPartitions);
std::array<float, kFftLengthBy2Plus1> erl;
std::array<float, kFftLengthBy2Plus1> erl_SSE2;
for (size_t j = 0; j < H2.size(); ++j) {
for (size_t k = 0; k < H2[j].size(); ++k) {
H2[j][k] = k + j / 3.f;
}
}
UpdateErlEstimator(H2, &erl);
UpdateErlEstimator_SSE2(H2, &erl_SSE2);
for (size_t j = 0; j < erl.size(); ++j) {
EXPECT_FLOAT_EQ(erl[j], erl_SSE2[j]);
}
}
}
#endif
#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID)