From a4f6303c5db11ecca599d999fbc341263be73a63 Mon Sep 17 00:00:00 2001
From: "cduvivier@google.com"
Date: Thu, 2 Jun 2011 23:50:06 +0000
Subject: [PATCH] Vectorization of "FilterAdaptation":

* 1.0% AEC overall speedup for straight C path.
* 6.2% AEC overall speedup for SSE2 path.
* fix warnings, make code compile with "-std=gnu89 -Wstrict-prototypes -Wold-style-definition -Wmissing-prototypes -Wmissing-declarations -Wdeclaration-after-statement -Wextra -Wall -Werror"

Review URL: http://webrtc-codereview.appspot.com/24012

git-svn-id: http://webrtc.googlecode.com/svn/trunk@38 4adac7df-926f-26a2-2b94-8c16560cd09d
---
Reviewer notes (this section is ignored by "git am"):

* What the patch does: the per-partition filter adaptation loop is moved
  out of ProcessBlock() into FilterAdaptation() (plain C) and
  FilterAdaptationSSE2() (SSE2), selected at init time through the new
  WebRtcAec_FilterAdaptation function pointer. IP_LEN / W_LEN move to
  aec_core.h because the pointer type's signature needs them.
* For each of the NR_PART partitions the constrained (#else) path builds
  conj(xfBuf) * ef into fft[] in interleaved re/im order, stores the real
  value of bin PART_LEN in fft[1], runs rdft(-1), zeroes the upper
  PART_LEN samples, scales by 2 / PART_LEN2, runs rdft(+1) and accumulates
  the result into wfBuf.
* SSE2 path: the vector loops also process j == 0, so fft[1] is fixed up
  afterwards and wfBuf[1][pos] is saved in wt1 and restored after the
  accumulation loop.
* Fix relative to the originally posted patch: the "#ifdef UNCONSTR"
  branch of both added functions indexed the buffers as [N][2]
  (aec->wfBuf[pos + j][0], aec->xfBuf[xPos + j][1], ef[j][0]), but ef is
  declared float ef[2][PART_LEN1] and the #else branch uses
  aec->xfBuf[0][xPos + j]. The added lines now use the [0][...]/[1][...]
  layout. The removed lines in ProcessBlock() are left as they were, since
  they must match the pre-image.
* Line counts per hunk are unchanged. The post-image blob ids on the
  "index" lines are those of the original commit. Leading whitespace of
  context lines was reconstructed; apply with --ignore-whitespace if needed.

 .../aec/main/source/aec_core.c                | 125 +++++++++---------
 .../aec/main/source/aec_core.h                |   7 +
 .../aec/main/source/aec_core_sse2.c           | 102 +++++++++++++-
 3 files changed, 168 insertions(+), 66 deletions(-)

diff --git a/modules/audio_processing/aec/main/source/aec_core.c b/modules/audio_processing/aec/main/source/aec_core.c
index 6534eb588e..3f8a088fd9 100644
--- a/modules/audio_processing/aec/main/source/aec_core.c
+++ b/modules/audio_processing/aec/main/source/aec_core.c
@@ -20,9 +20,6 @@
 #include "ring_buffer.h"
 #include "system_wrappers/interface/cpu_features_wrapper.h"
 
-#define IP_LEN PART_LEN // this must be at least ceil(2 + sqrt(PART_LEN))
-#define W_LEN PART_LEN
-
 // Noise suppression
 static const int converged = 250;
 
@@ -218,15 +215,16 @@ int WebRtcAec_FreeAec(aec_t *aec)
 
 static void FilterFar(aec_t *aec, float yf[2][PART_LEN1])
 {
-  int i, j, pos;
+  int i;
   for (i = 0; i < NR_PART; i++) {
+    int j;
     int xPos = (i + aec->xfBufBlockPos) * PART_LEN1;
+    int pos = i * PART_LEN1;
     // Check for wrap
     if (i + aec->xfBufBlockPos >= NR_PART) {
       xPos -= NR_PART*(PART_LEN1);
     }
 
-    pos = i * PART_LEN1;
     for (j = 0; j < PART_LEN1; j++) {
       yf[0][j] += MulRe(aec->xfBuf[0][xPos + j], aec->xfBuf[1][xPos + j],
                         aec->wfBuf[0][ pos + j], aec->wfBuf[1][ pos + j]);
@@ -257,10 +255,68 @@ static void ScaleErrorSignal(aec_t *aec, float ef[2][PART_LEN1])
   }
 }
 
+static void FilterAdaptation(aec_t *aec, float *fft, float ef[2][PART_LEN1],
+                             int ip[IP_LEN], float wfft[W_LEN]) {
+  int i, j;
+  for (i = 0; i < NR_PART; i++) {
+    int xPos = (i + aec->xfBufBlockPos)*(PART_LEN1);
+    int pos;
+    // Check for wrap
+    if (i + aec->xfBufBlockPos >= NR_PART) {
+      xPos -= NR_PART * PART_LEN1;
+    }
+
+    pos = i * PART_LEN1;
+
+#ifdef UNCONSTR
+    for (j = 0; j < PART_LEN1; j++) {
+      aec->wfBuf[0][pos + j] += MulRe(aec->xfBuf[0][xPos + j],
+                                      -aec->xfBuf[1][xPos + j],
+                                      ef[0][j], ef[1][j]);
+      aec->wfBuf[1][pos + j] += MulIm(aec->xfBuf[0][xPos + j],
+                                      -aec->xfBuf[1][xPos + j],
+                                      ef[0][j], ef[1][j]);
+    }
+#else
+    for (j = 0; j < PART_LEN; j++) {
+
+      fft[2 * j] = MulRe(aec->xfBuf[0][xPos + j],
+                         -aec->xfBuf[1][xPos + j],
+                         ef[0][j], ef[1][j]);
+      fft[2 * j + 1] = MulIm(aec->xfBuf[0][xPos + j],
+                             -aec->xfBuf[1][xPos + j],
+                             ef[0][j], ef[1][j]);
+    }
+    fft[1] = MulRe(aec->xfBuf[0][xPos + PART_LEN],
+                   -aec->xfBuf[1][xPos + PART_LEN],
+                   ef[0][PART_LEN], ef[1][PART_LEN]);
+
+    rdft(PART_LEN2, -1, fft, ip, wfft);
+    memset(fft + PART_LEN, 0, sizeof(float) * PART_LEN);
+
+    // fft scaling
+    {
+      float scale = 2.0f / PART_LEN2;
+      for (j = 0; j < PART_LEN; j++) {
+        fft[j] *= scale;
+      }
+    }
+    rdft(PART_LEN2, 1, fft, ip, wfft);
+
+    aec->wfBuf[0][pos] += fft[0];
+    aec->wfBuf[0][pos + PART_LEN] += fft[1];
+
+    for (j = 1; j < PART_LEN; j++) {
+      aec->wfBuf[0][pos + j] += fft[2 * j];
+      aec->wfBuf[1][pos + j] += fft[2 * j + 1];
+    }
+#endif // UNCONSTR
+  }
+}
+
 WebRtcAec_FilterFar_t WebRtcAec_FilterFar;
 WebRtcAec_ScaleErrorSignal_t WebRtcAec_ScaleErrorSignal;
-
-extern void WebRtcAec_InitAec_SSE2(void);
+WebRtcAec_FilterAdaptation_t WebRtcAec_FilterAdaptation;
 
 int WebRtcAec_InitAec(aec_t *aec, int sampFreq)
 {
@@ -387,6 +443,7 @@ int WebRtcAec_InitAec(aec_t *aec, int sampFreq)
     // Assembly optimization
     WebRtcAec_FilterFar = FilterFar;
     WebRtcAec_ScaleErrorSignal = ScaleErrorSignal;
+    WebRtcAec_FilterAdaptation = FilterAdaptation;
     if (WebRtc_GetCPUInfo(kSSE2)) {
 #if defined(__SSE2__)
         WebRtcAec_InitAec_SSE2();
@@ -483,11 +540,10 @@ static void ProcessBlock(aec_t *aec, const short *farend,
                               const short *nearend, const short *nearendH,
                               short *output, short *outputH)
 {
-    int i, j, pos;
+    int i;
     float d[PART_LEN], y[PART_LEN], e[PART_LEN], dH[PART_LEN];
     short eInt16[PART_LEN];
     float scale;
-    int xPos;
 
     float fft[PART_LEN2];
     float xf[2][PART_LEN1], yf[2][PART_LEN1], ef[2][PART_LEN1];
@@ -656,56 +712,7 @@ static void ProcessBlock(aec_t *aec, const short *farend,
     if (aec->adaptToggle) {
 #endif
     // Filter adaptation
-    for (i = 0; i < NR_PART; i++) {
-        xPos = (i + aec->xfBufBlockPos)*(PART_LEN1);
-        // Check for wrap
-        if (i + aec->xfBufBlockPos >= NR_PART) {
-            xPos -= NR_PART*(PART_LEN1);
-        }
-
-        pos = i * PART_LEN1;
-
-#ifdef UNCONSTR
-        for (j = 0; j < PART_LEN1; j++) {
-            aec->wfBuf[pos + j][0] += MulRe(aec->xfBuf[xPos + j][0],
-                -aec->xfBuf[xPos + j][1], ef[j][0], ef[j][1]);
-            aec->wfBuf[pos + j][1] += MulIm(aec->xfBuf[xPos + j][0],
-                -aec->xfBuf[xPos + j][1], ef[j][0], ef[j][1]);
-        }
-#else
-        fft[0] = MulRe(aec->xfBuf[0][xPos], -aec->xfBuf[1][xPos],
-            ef[0][0], ef[1][0]);
-        fft[1] = MulRe(aec->xfBuf[0][xPos + PART_LEN],
-            -aec->xfBuf[1][xPos + PART_LEN],
-            ef[0][PART_LEN], ef[1][PART_LEN]);
-
-        for (j = 1; j < PART_LEN; j++) {
-
-            fft[2 * j] = MulRe(aec->xfBuf[0][xPos + j],
-                -aec->xfBuf[1][xPos + j],
-                ef[0][j], ef[1][j]);
-            fft[2 * j + 1] = MulIm(aec->xfBuf[0][xPos + j],
-                -aec->xfBuf[1][xPos + j],
-                ef[0][j], ef[1][j]);
-        }
-        rdft(PART_LEN2, -1, fft, ip, wfft);
-        memset(fft + PART_LEN, 0, sizeof(float)*PART_LEN);
-
-        scale = 2.0f / PART_LEN2;
-        for (j = 0; j < PART_LEN; j++) {
-            fft[j] *= scale; // fft scaling
-        }
-        rdft(PART_LEN2, 1, fft, ip, wfft);
-
-        aec->wfBuf[0][pos] += fft[0];
-        aec->wfBuf[0][pos + PART_LEN] += fft[1];
-
-        for (j = 1; j < PART_LEN; j++) {
-            aec->wfBuf[0][pos + j] += fft[2 * j];
-            aec->wfBuf[1][pos + j] += fft[2 * j + 1];
-        }
-#endif // UNCONSTR
-    }
+    WebRtcAec_FilterAdaptation(aec, fft, ef, ip, wfft);
 #ifdef G167
     }
 #endif
diff --git a/modules/audio_processing/aec/main/source/aec_core.h b/modules/audio_processing/aec/main/source/aec_core.h
index 35d2f6b709..80d492f14a 100644
--- a/modules/audio_processing/aec/main/source/aec_core.h
+++ b/modules/audio_processing/aec/main/source/aec_core.h
@@ -170,10 +170,17 @@ typedef void (*WebRtcAec_FilterFar_t)(aec_t *aec, float yf[2][PART_LEN1]);
 extern WebRtcAec_FilterFar_t WebRtcAec_FilterFar;
 typedef void (*WebRtcAec_ScaleErrorSignal_t)(aec_t *aec, float ef[2][PART_LEN1]);
 extern WebRtcAec_ScaleErrorSignal_t WebRtcAec_ScaleErrorSignal;
+#define IP_LEN PART_LEN // this must be at least ceil(2 + sqrt(PART_LEN))
+#define W_LEN PART_LEN
+typedef void (*WebRtcAec_FilterAdaptation_t)
+  (aec_t *aec, float *fft, float ef[2][PART_LEN1], int ip[IP_LEN],
+   float wfft[W_LEN]);
+extern WebRtcAec_FilterAdaptation_t WebRtcAec_FilterAdaptation;
 
 int WebRtcAec_CreateAec(aec_t **aec);
 int WebRtcAec_FreeAec(aec_t *aec);
 int WebRtcAec_InitAec(aec_t *aec, int sampFreq);
+void WebRtcAec_InitAec_SSE2(void);
 void WebRtcAec_InitMetrics(aec_t *aec);
 
 void WebRtcAec_ProcessFrame(aec_t *aec, const short *farend,
diff --git a/modules/audio_processing/aec/main/source/aec_core_sse2.c b/modules/audio_processing/aec/main/source/aec_core_sse2.c
index 6cdada748b..8dfd118710 100644
--- a/modules/audio_processing/aec/main/source/aec_core_sse2.c
+++ b/modules/audio_processing/aec/main/source/aec_core_sse2.c
@@ -30,15 +30,16 @@ __inline static float MulIm(float aRe, float aIm, float bRe, float bIm)
 
 static void FilterFarSSE2(aec_t *aec, float yf[2][PART_LEN1])
 {
-  int i, j, pos;
+  int i;
   for (i = 0; i < NR_PART; i++) {
+    int j;
     int xPos = (i + aec->xfBufBlockPos) * PART_LEN1;
+    int pos = i * PART_LEN1;
     // Check for wrap
     if (i + aec->xfBufBlockPos >= NR_PART) {
       xPos -= NR_PART*(PART_LEN1);
     }
 
-    pos = i * PART_LEN1;
     // vectorized code (four at once)
     for (j = 0; j + 3 < PART_LEN1; j += 4) {
       const __m128 xfBuf_re = _mm_loadu_ps(&aec->xfBuf[0][xPos + j]);
@@ -78,12 +79,12 @@ static void ScaleErrorSignalSSE2(aec_t *aec, float ef[2][PART_LEN1])
   // vectorized code (four at once)
   for (i = 0; i + 3 < PART_LEN1; i += 4) {
     const __m128 xPow = _mm_loadu_ps(&aec->xPow[i]);
-    __m128 ef_re = _mm_loadu_ps(&ef[0][i]);
-    __m128 ef_im = _mm_loadu_ps(&ef[1][i]);
+    const __m128 ef_re_base = _mm_loadu_ps(&ef[0][i]);
+    const __m128 ef_im_base = _mm_loadu_ps(&ef[1][i]);
 
     const __m128 xPowPlus = _mm_add_ps(xPow, k1e_10f);
-    ef_re = _mm_div_ps(ef_re, xPowPlus);
-    ef_im = _mm_div_ps(ef_im, xPowPlus);
+    __m128 ef_re = _mm_div_ps(ef_re_base, xPowPlus);
+    __m128 ef_im = _mm_div_ps(ef_im_base, xPowPlus);
     const __m128 ef_re2 = _mm_mul_ps(ef_re, ef_re);
     const __m128 ef_im2 = _mm_mul_ps(ef_im, ef_im);
     const __m128 ef_sum2 = _mm_add_ps(ef_re2, ef_im2);
@@ -107,9 +108,10 @@ static void ScaleErrorSignalSSE2(aec_t *aec, float ef[2][PART_LEN1])
   }
   // scalar code for the remaining items.
   for (; i < (PART_LEN1); i++) {
+    float absEf;
     ef[0][i] /= (aec->xPow[i] + 1e-10f);
     ef[1][i] /= (aec->xPow[i] + 1e-10f);
-    float absEf = sqrtf(ef[0][i] * ef[0][i] + ef[1][i] * ef[1][i]);
+    absEf = sqrtf(ef[0][i] * ef[0][i] + ef[1][i] * ef[1][i]);
 
     if (absEf > aec->errThresh) {
       absEf = aec->errThresh / (absEf + 1e-10f);
@@ -123,9 +125,95 @@ static void ScaleErrorSignalSSE2(aec_t *aec, float ef[2][PART_LEN1])
   }
 }
 
+static void FilterAdaptationSSE2(aec_t *aec, float *fft, float ef[2][PART_LEN1],
+                                 int ip[IP_LEN], float wfft[W_LEN]) {
+  int i, j;
+  for (i = 0; i < NR_PART; i++) {
+    int xPos = (i + aec->xfBufBlockPos)*(PART_LEN1);
+    int pos = i * PART_LEN1;
+    // Check for wrap
+    if (i + aec->xfBufBlockPos >= NR_PART) {
+      xPos -= NR_PART * PART_LEN1;
+    }
+
+#ifdef UNCONSTR
+    for (j = 0; j < PART_LEN1; j++) {
+      aec->wfBuf[0][pos + j] += MulRe(aec->xfBuf[0][xPos + j],
+                                      -aec->xfBuf[1][xPos + j],
+                                      ef[0][j], ef[1][j]);
+      aec->wfBuf[1][pos + j] += MulIm(aec->xfBuf[0][xPos + j],
+                                      -aec->xfBuf[1][xPos + j],
+                                      ef[0][j], ef[1][j]);
+    }
+#else
+    // Process the whole array...
+    for (j = 0; j < PART_LEN; j += 4) {
+      // Load xfBuf and ef.
+      const __m128 xfBuf_re = _mm_loadu_ps(&aec->xfBuf[0][xPos + j]);
+      const __m128 xfBuf_im = _mm_loadu_ps(&aec->xfBuf[1][xPos + j]);
+      const __m128 ef_re = _mm_loadu_ps(&ef[0][j]);
+      const __m128 ef_im = _mm_loadu_ps(&ef[1][j]);
+      // Calculate the product of conjugate(xfBuf) by ef.
+      //   re(conjugate(a) * b) = aRe * bRe + aIm * bIm
+      //   im(conjugate(a) * b) = aRe * bIm - aIm * bRe
+      const __m128 a = _mm_mul_ps(xfBuf_re, ef_re);
+      const __m128 b = _mm_mul_ps(xfBuf_im, ef_im);
+      const __m128 c = _mm_mul_ps(xfBuf_re, ef_im);
+      const __m128 d = _mm_mul_ps(xfBuf_im, ef_re);
+      const __m128 e = _mm_add_ps(a, b);
+      const __m128 f = _mm_sub_ps(c, d);
+      // Interleave real and imaginary parts.
+      const __m128 g = _mm_unpacklo_ps(e, f);
+      const __m128 h = _mm_unpackhi_ps(e, f);
+      // Store
+      _mm_storeu_ps(&fft[2*j + 0], g);
+      _mm_storeu_ps(&fft[2*j + 4], h);
+    }
+    // ... and fix up fft[1], which holds the real value for bin PART_LEN.
+    fft[1] = MulRe(aec->xfBuf[0][xPos + PART_LEN],
+                   -aec->xfBuf[1][xPos + PART_LEN],
+                   ef[0][PART_LEN], ef[1][PART_LEN]);
+
+    rdft(PART_LEN2, -1, fft, ip, wfft);
+    memset(fft + PART_LEN, 0, sizeof(float)*PART_LEN);
+
+    // fft scaling
+    {
+      float scale = 2.0f / PART_LEN2;
+      const __m128 scale_ps = _mm_load_ps1(&scale);
+      for (j = 0; j < PART_LEN; j += 4) {
+        const __m128 fft_ps = _mm_loadu_ps(&fft[j]);
+        const __m128 fft_scale = _mm_mul_ps(fft_ps, scale_ps);
+        _mm_storeu_ps(&fft[j], fft_scale);
+      }
+    }
+    rdft(PART_LEN2, 1, fft, ip, wfft);
+
+    {
+      float wt1 = aec->wfBuf[1][pos];
+      aec->wfBuf[0][pos + PART_LEN] += fft[1];
+      for (j = 0; j < PART_LEN; j += 4) {
+        __m128 wtBuf_re = _mm_loadu_ps(&aec->wfBuf[0][pos + j]);
+        __m128 wtBuf_im = _mm_loadu_ps(&aec->wfBuf[1][pos + j]);
+        const __m128 fft0 = _mm_loadu_ps(&fft[2 * j + 0]);
+        const __m128 fft4 = _mm_loadu_ps(&fft[2 * j + 4]);
+        const __m128 fft_re = _mm_shuffle_ps(fft0, fft4, _MM_SHUFFLE(2, 0, 2, 0));
+        const __m128 fft_im = _mm_shuffle_ps(fft0, fft4, _MM_SHUFFLE(3, 1, 3, 1));
+        wtBuf_re = _mm_add_ps(wtBuf_re, fft_re);
+        wtBuf_im = _mm_add_ps(wtBuf_im, fft_im);
+        _mm_storeu_ps(&aec->wfBuf[0][pos + j], wtBuf_re);
+        _mm_storeu_ps(&aec->wfBuf[1][pos + j], wtBuf_im);
+      }
+      aec->wfBuf[1][pos] = wt1;
+    }
+#endif // UNCONSTR
+  }
+}
+
 void WebRtcAec_InitAec_SSE2(void) {
   WebRtcAec_FilterFar = FilterFarSSE2;
   WebRtcAec_ScaleErrorSignal = ScaleErrorSignalSSE2;
+  WebRtcAec_FilterAdaptation = FilterAdaptationSSE2;
 }
 
 #endif   //__SSE2__