FFT-based auto correlation.

During pitch search in the RNN VAD, we calculate auto
correlation. Before this CL, we computed kNumInvertedLags12kHz=147 dot
products of vectors with kBufSize12kHz-kMaxPitch12kHz=240
elements. This was the most time consuming step of the new VAD.

This CL makes the computation happen in frequency domain. Profiling
shows a 3x speed increase. In future, we can try using a more efficient
FFT and to reduce the FFT length to some of e.g. 400, 405, 432.

# For minimal Clang plugin check change.
TBR: kwiberg@webrtc.org

Bug: webrtc:9076
Change-Id: I688251a415869d53175a37f390f441d4e035d954
Reviewed-on: https://webrtc-review.googlesource.com/73366
Reviewed-by: Karl Wiberg <kwiberg@webrtc.org>
Reviewed-by: Alessio Bazzica <alessiob@webrtc.org>
Commit-Queue: Alex Loiko <aleloi@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#23171}
This commit is contained in:
Alex Loiko 2018-05-08 13:11:12 +02:00 committed by Commit Bot
parent 0bd0a3fe4c
commit 0520b0eb7b
9 changed files with 114 additions and 18 deletions

View File

@ -45,6 +45,8 @@ RealFourierOoura::RealFourierOoura(int fft_order)
RTC_CHECK_GE(fft_order, 1);
}
RealFourierOoura::~RealFourierOoura() = default;
void RealFourierOoura::Forward(const float* src, complex<float>* dest) const {
{
// This cast is well-defined since C++11. See "Non-static data members" at:
@ -82,4 +84,8 @@ void RealFourierOoura::Inverse(const complex<float>* src, float* dest) const {
std::for_each(dest, dest + length_, [scale](float& v) { v *= scale; });
}
int RealFourierOoura::order() const {
return order_;
}
} // namespace webrtc

View File

@ -21,13 +21,12 @@ namespace webrtc {
class RealFourierOoura : public RealFourier {
public:
explicit RealFourierOoura(int fft_order);
~RealFourierOoura() override;
void Forward(const float* src, std::complex<float>* dest) const override;
void Inverse(const std::complex<float>* src, float* dest) const override;
int order() const override {
return order_;
}
int order() const override;
private:
const int order_;
@ -42,4 +41,3 @@ class RealFourierOoura : public RealFourier {
} // namespace webrtc
#endif // COMMON_AUDIO_REAL_FOURIER_OOURA_H_

View File

@ -36,6 +36,7 @@ source_set("lib") {
]
deps = [
"../../../../api:array_view",
"../../../../common_audio/",
"../../../../rtc_base:checks",
"../../../../rtc_base:rtc_base_approved",
"//third_party/rnnoise:kiss_fft",
@ -95,6 +96,7 @@ if (rtc_include_tests) {
":lib",
":lib_test",
"../../../../api:array_view",
"../../../../common_audio/",
"../../../../rtc_base:checks",
"../../../../test:test_support",
"//third_party/rnnoise:rnn_vad",

View File

@ -15,7 +15,8 @@ namespace webrtc {
namespace rnn_vad {
PitchInfo PitchSearch(rtc::ArrayView<const float, kBufSize24kHz> pitch_buf,
PitchInfo prev_pitch_48kHz) {
PitchInfo prev_pitch_48kHz,
RealFourier* fft) {
// Perform the initial pitch search at 12 kHz.
std::array<float, kBufSize12kHz> pitch_buf_decimated;
Decimate2x(pitch_buf,
@ -24,7 +25,8 @@ PitchInfo PitchSearch(rtc::ArrayView<const float, kBufSize24kHz> pitch_buf,
std::array<float, kNumInvertedLags12kHz> auto_corr;
ComputePitchAutoCorrelation(
{pitch_buf_decimated.data(), pitch_buf_decimated.size()}, kMaxPitch12kHz,
{auto_corr.data(), auto_corr.size()});
{auto_corr.data(), auto_corr.size()}, fft);
// Search for pitch at 12 kHz.
std::array<size_t, 2> pitch_candidates_inv_lags = FindBestPitchPeriods(
{auto_corr.data(), auto_corr.size()},

View File

@ -12,6 +12,7 @@
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_SEARCH_H_
#include "api/array_view.h"
#include "common_audio/real_fourier.h"
#include "modules/audio_processing/agc2/rnn_vad/common.h"
#include "modules/audio_processing/agc2/rnn_vad/pitch_info.h"
@ -21,7 +22,8 @@ namespace rnn_vad {
// Searches the pitch period and gain. Return the pitch estimation data for
// 48 kHz.
PitchInfo PitchSearch(rtc::ArrayView<const float, kBufSize24kHz> pitch_buf,
PitchInfo prev_pitch_48kHz);
PitchInfo prev_pitch_48kHz,
RealFourier* fft);
} // namespace rnn_vad
} // namespace webrtc

View File

@ -205,18 +205,62 @@ void ComputeSlidingFrameSquareEnergies(
}
}
// TODO(bugs.webrtc.org/9076): Optimize using FFT and/or vectorization.
void ComputePitchAutoCorrelation(
rtc::ArrayView<const float, kBufSize12kHz> pitch_buf,
size_t max_pitch_period,
rtc::ArrayView<float, kNumInvertedLags12kHz> auto_corr) {
rtc::ArrayView<float, kNumInvertedLags12kHz> auto_corr,
webrtc::RealFourier* fft) {
RTC_DCHECK_GT(max_pitch_period, auto_corr.size());
RTC_DCHECK_LT(max_pitch_period, pitch_buf.size());
// Compute auto-correlation coefficients.
for (size_t inv_lag = 0; inv_lag < auto_corr.size(); ++inv_lag) {
auto_corr[inv_lag] =
ComputeAutoCorrelationCoeff(pitch_buf, inv_lag, max_pitch_period);
RTC_DCHECK(fft);
constexpr size_t time_domain_fft_length = 1 << kAutoCorrelationFftOrder;
constexpr size_t freq_domain_fft_length = time_domain_fft_length / 2 + 1;
RTC_DCHECK_EQ(RealFourier::FftLength(fft->order()), time_domain_fft_length);
RTC_DCHECK_EQ(RealFourier::ComplexLength(fft->order()),
freq_domain_fft_length);
// Cross-correlation of y_i=pitch_buf[i:i+convolution_length] and
// x=pitch_buf[-convolution_length:] is equivalent to convolution of
// y_i and reversed(x). New notation: h=reversed(x), x=y.
std::array<float, time_domain_fft_length> h{};
std::array<float, time_domain_fft_length> x{};
const size_t convolution_length = kBufSize12kHz - max_pitch_period;
// Check that the FFT-length is big enough to avoid cyclic
// convolution errors.
RTC_DCHECK_GT(time_domain_fft_length,
kNumInvertedLags12kHz + convolution_length);
// h[0:convolution_length] is reversed pitch_buf[-convolution_length:].
std::reverse_copy(pitch_buf.end() - convolution_length, pitch_buf.end(),
h.begin());
// x is pitch_buf[:kNumInvertedLags12kHz + convolution_length].
std::copy(pitch_buf.begin(),
pitch_buf.begin() + kNumInvertedLags12kHz + convolution_length,
x.begin());
// Shift to frequency domain.
std::array<std::complex<float>, freq_domain_fft_length> X{};
std::array<std::complex<float>, freq_domain_fft_length> H{};
fft->Forward(&x[0], &X[0]);
fft->Forward(&h[0], &H[0]);
// Convolve in frequency domain.
for (size_t i = 0; i < X.size(); ++i) {
X[i] *= H[i];
}
// Shift back to time domain.
std::array<float, time_domain_fft_length> x_conv_h;
fft->Inverse(&X[0], &x_conv_h[0]);
// Collect the result.
std::copy(x_conv_h.begin() + convolution_length - 1,
x_conv_h.begin() + convolution_length + kNumInvertedLags12kHz - 1,
auto_corr.begin());
}
std::array<size_t, 2> FindBestPitchPeriods(

View File

@ -14,6 +14,7 @@
#include <array>
#include "api/array_view.h"
#include "common_audio/real_fourier.h"
#include "modules/audio_processing/agc2/rnn_vad/common.h"
#include "modules/audio_processing/agc2/rnn_vad/pitch_info.h"
@ -26,6 +27,11 @@ static_assert(kMaxPitch12kHz > kInitialMinPitch12kHz, "");
static_assert(kMaxPitch24kHz > kInitialMinPitch24kHz, "");
constexpr size_t kNumInvertedLags12kHz = kMaxPitch12kHz - kInitialMinPitch12kHz;
constexpr size_t kNumInvertedLags24kHz = kMaxPitch24kHz - kInitialMinPitch24kHz;
constexpr int kAutoCorrelationFftOrder = 9; // Length-512 FFT.
static_assert(1 << kAutoCorrelationFftOrder >
kNumInvertedLags12kHz + kBufSize12kHz - kMaxPitch12kHz,
"");
// Performs 2x decimation without any anti-aliasing filter.
void Decimate2x(rtc::ArrayView<const float, kBufSize24kHz> src,
@ -70,7 +76,8 @@ void ComputeSlidingFrameSquareEnergies(
void ComputePitchAutoCorrelation(
rtc::ArrayView<const float, kBufSize12kHz> pitch_buf,
size_t max_pitch_period,
rtc::ArrayView<float, kNumInvertedLags12kHz> auto_corr);
rtc::ArrayView<float, kNumInvertedLags12kHz> auto_corr,
webrtc::RealFourier* fft);
// Given the auto-correlation coefficients stored according to
// ComputePitchAutoCorrelation() (i.e., using inverted lags), returns the best

View File

@ -9,6 +9,7 @@
*/
#include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h"
#include "common_audio/real_fourier.h"
#include <array>
#include <tuple>
@ -415,16 +416,47 @@ TEST(RnnVadTest, ComputePitchAutoCorrelationBitExactness) {
{
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
// FloatingPointExceptionObserver fpe_observer;
std::unique_ptr<RealFourier> fft =
RealFourier::Create(kAutoCorrelationFftOrder);
ComputePitchAutoCorrelation(
{pitch_buf_decimated.data(), pitch_buf_decimated.size()},
kMaxPitch12kHz, {computed_output.data(), computed_output.size()});
kMaxPitch12kHz, {computed_output.data(), computed_output.size()},
fft.get());
}
ExpectNearAbsolute(
{kPitchBufferAutoCorrCoeffs.data(), kPitchBufferAutoCorrCoeffs.size()},
{computed_output.data(), computed_output.size()}, 3e-3f);
}
// Check that the auto correlation function computes the right thing for a
// simple use case.
TEST(RnnVadTest, ComputePitchAutoCorrelationConstantBuffer) {
// Create constant signal with no pitch.
std::array<float, kBufSize12kHz> pitch_buf_decimated;
std::fill(pitch_buf_decimated.begin(), pitch_buf_decimated.end(), 1.f);
std::array<float, kPitchBufferAutoCorrCoeffs.size()> computed_output;
{
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
// FloatingPointExceptionObserver fpe_observer;
std::unique_ptr<RealFourier> fft =
RealFourier::Create(kAutoCorrelationFftOrder);
ComputePitchAutoCorrelation(
{pitch_buf_decimated.data(), pitch_buf_decimated.size()},
kMaxPitch12kHz, {computed_output.data(), computed_output.size()},
fft.get());
}
// The expected output is constantly the length of the fixed 'x'
// array in ComputePitchAutoCorrelation.
std::array<float, kPitchBufferAutoCorrCoeffs.size()> expected_output;
std::fill(expected_output.begin(), expected_output.end(),
kBufSize12kHz - kMaxPitch12kHz);
ExpectNearAbsolute({expected_output.data(), expected_output.size()},
{computed_output.data(), computed_output.size()}, 4e-5f);
}
TEST(RnnVadTest, FindBestPitchPeriodsBitExactness) {
std::array<float, kBufSize12kHz> pitch_buf_decimated;
Decimate2x({kPitchBufferData.data(), kPitchBufferData.size()},

View File

@ -9,6 +9,7 @@
*/
#include "modules/audio_processing/agc2/rnn_vad/pitch_search.h"
#include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h"
#include <array>
@ -28,6 +29,8 @@ TEST(RnnVadTest, PitchSearchBitExactness) {
std::array<float, 864> lp_residual;
float expected_pitch_period, expected_pitch_gain;
PitchInfo last_pitch;
std::unique_ptr<RealFourier> fft =
RealFourier::Create(kAutoCorrelationFftOrder);
{
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
// FloatingPointExceptionObserver fpe_observer;
@ -38,8 +41,8 @@ TEST(RnnVadTest, PitchSearchBitExactness) {
{lp_residual.data(), lp_residual.size()});
lp_residual_reader.first->ReadValue(&expected_pitch_period);
lp_residual_reader.first->ReadValue(&expected_pitch_gain);
last_pitch =
PitchSearch({lp_residual.data(), lp_residual.size()}, last_pitch);
last_pitch = PitchSearch({lp_residual.data(), lp_residual.size()},
last_pitch, fft.get());
EXPECT_EQ(static_cast<size_t>(expected_pitch_period), last_pitch.period);
EXPECT_NEAR(expected_pitch_gain, last_pitch.gain, 1e-5f);
}