FFT-based auto correlation.
During pitch search in the RNN VAD, we calculate auto correlation. Before this CL, we computed kNumInvertedLags12kHz=147 dot products of vectors with kBufSize12kHz-kMaxPitch12kHz=240 elements. This was the most time-consuming step of the new VAD. This CL moves the computation to the frequency domain. Profiling shows a 3x speed increase. In the future, we can try using a more efficient FFT and reducing the FFT length to a smaller value such as 400, 405, or 432. # For minimal Clang plugin check change. TBR: kwiberg@webrtc.org Bug: webrtc:9076 Change-Id: I688251a415869d53175a37f390f441d4e035d954 Reviewed-on: https://webrtc-review.googlesource.com/73366 Reviewed-by: Karl Wiberg <kwiberg@webrtc.org> Reviewed-by: Alessio Bazzica <alessiob@webrtc.org> Commit-Queue: Alex Loiko <aleloi@webrtc.org> Cr-Commit-Position: refs/heads/master@{#23171}
This commit is contained in:
parent
0bd0a3fe4c
commit
0520b0eb7b
@ -45,6 +45,8 @@ RealFourierOoura::RealFourierOoura(int fft_order)
|
||||
RTC_CHECK_GE(fft_order, 1);
|
||||
}
|
||||
|
||||
RealFourierOoura::~RealFourierOoura() = default;
|
||||
|
||||
void RealFourierOoura::Forward(const float* src, complex<float>* dest) const {
|
||||
{
|
||||
// This cast is well-defined since C++11. See "Non-static data members" at:
|
||||
@ -82,4 +84,8 @@ void RealFourierOoura::Inverse(const complex<float>* src, float* dest) const {
|
||||
std::for_each(dest, dest + length_, [scale](float& v) { v *= scale; });
|
||||
}
|
||||
|
||||
// Returns order_, the FFT order held by this instance. Defined out of line;
// the class declaration only declares this override.
int RealFourierOoura::order() const {
|
||||
return order_;
|
||||
}
|
||||
|
||||
} // namespace webrtc
|
||||
|
||||
@ -21,13 +21,12 @@ namespace webrtc {
|
||||
class RealFourierOoura : public RealFourier {
|
||||
public:
|
||||
explicit RealFourierOoura(int fft_order);
|
||||
~RealFourierOoura() override;
|
||||
|
||||
void Forward(const float* src, std::complex<float>* dest) const override;
|
||||
void Inverse(const std::complex<float>* src, float* dest) const override;
|
||||
|
||||
int order() const override {
|
||||
return order_;
|
||||
}
|
||||
int order() const override;
|
||||
|
||||
private:
|
||||
const int order_;
|
||||
@ -42,4 +41,3 @@ class RealFourierOoura : public RealFourier {
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // COMMON_AUDIO_REAL_FOURIER_OOURA_H_
|
||||
|
||||
|
||||
@ -36,6 +36,7 @@ source_set("lib") {
|
||||
]
|
||||
deps = [
|
||||
"../../../../api:array_view",
|
||||
"../../../../common_audio/",
|
||||
"../../../../rtc_base:checks",
|
||||
"../../../../rtc_base:rtc_base_approved",
|
||||
"//third_party/rnnoise:kiss_fft",
|
||||
@ -95,6 +96,7 @@ if (rtc_include_tests) {
|
||||
":lib",
|
||||
":lib_test",
|
||||
"../../../../api:array_view",
|
||||
"../../../../common_audio/",
|
||||
"../../../../rtc_base:checks",
|
||||
"../../../../test:test_support",
|
||||
"//third_party/rnnoise:rnn_vad",
|
||||
|
||||
@ -15,7 +15,8 @@ namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
|
||||
PitchInfo PitchSearch(rtc::ArrayView<const float, kBufSize24kHz> pitch_buf,
|
||||
PitchInfo prev_pitch_48kHz) {
|
||||
PitchInfo prev_pitch_48kHz,
|
||||
RealFourier* fft) {
|
||||
// Perform the initial pitch search at 12 kHz.
|
||||
std::array<float, kBufSize12kHz> pitch_buf_decimated;
|
||||
Decimate2x(pitch_buf,
|
||||
@ -24,7 +25,8 @@ PitchInfo PitchSearch(rtc::ArrayView<const float, kBufSize24kHz> pitch_buf,
|
||||
std::array<float, kNumInvertedLags12kHz> auto_corr;
|
||||
ComputePitchAutoCorrelation(
|
||||
{pitch_buf_decimated.data(), pitch_buf_decimated.size()}, kMaxPitch12kHz,
|
||||
{auto_corr.data(), auto_corr.size()});
|
||||
{auto_corr.data(), auto_corr.size()}, fft);
|
||||
|
||||
// Search for pitch at 12 kHz.
|
||||
std::array<size_t, 2> pitch_candidates_inv_lags = FindBestPitchPeriods(
|
||||
{auto_corr.data(), auto_corr.size()},
|
||||
|
||||
@ -12,6 +12,7 @@
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_SEARCH_H_
|
||||
|
||||
#include "api/array_view.h"
|
||||
#include "common_audio/real_fourier.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/common.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/pitch_info.h"
|
||||
|
||||
@ -21,7 +22,8 @@ namespace rnn_vad {
|
||||
// Searches the pitch period and gain. Return the pitch estimation data for
|
||||
// 48 kHz.
|
||||
PitchInfo PitchSearch(rtc::ArrayView<const float, kBufSize24kHz> pitch_buf,
|
||||
PitchInfo prev_pitch_48kHz);
|
||||
PitchInfo prev_pitch_48kHz,
|
||||
RealFourier* fft);
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
||||
@ -205,18 +205,62 @@ void ComputeSlidingFrameSquareEnergies(
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(bugs.webrtc.org/9076): Optimize using FFT and/or vectorization.
|
||||
void ComputePitchAutoCorrelation(
|
||||
rtc::ArrayView<const float, kBufSize12kHz> pitch_buf,
|
||||
size_t max_pitch_period,
|
||||
rtc::ArrayView<float, kNumInvertedLags12kHz> auto_corr) {
|
||||
rtc::ArrayView<float, kNumInvertedLags12kHz> auto_corr,
|
||||
webrtc::RealFourier* fft) {
|
||||
RTC_DCHECK_GT(max_pitch_period, auto_corr.size());
|
||||
RTC_DCHECK_LT(max_pitch_period, pitch_buf.size());
|
||||
// Compute auto-correlation coefficients.
|
||||
for (size_t inv_lag = 0; inv_lag < auto_corr.size(); ++inv_lag) {
|
||||
auto_corr[inv_lag] =
|
||||
ComputeAutoCorrelationCoeff(pitch_buf, inv_lag, max_pitch_period);
|
||||
RTC_DCHECK(fft);
|
||||
|
||||
constexpr size_t time_domain_fft_length = 1 << kAutoCorrelationFftOrder;
|
||||
constexpr size_t freq_domain_fft_length = time_domain_fft_length / 2 + 1;
|
||||
|
||||
RTC_DCHECK_EQ(RealFourier::FftLength(fft->order()), time_domain_fft_length);
|
||||
RTC_DCHECK_EQ(RealFourier::ComplexLength(fft->order()),
|
||||
freq_domain_fft_length);
|
||||
|
||||
// Cross-correlation of y_i=pitch_buf[i:i+convolution_length] and
|
||||
// x=pitch_buf[-convolution_length:] is equivalent to convolution of
|
||||
// y_i and reversed(x). New notation: h=reversed(x), x=y.
|
||||
std::array<float, time_domain_fft_length> h{};
|
||||
std::array<float, time_domain_fft_length> x{};
|
||||
|
||||
const size_t convolution_length = kBufSize12kHz - max_pitch_period;
|
||||
// Check that the FFT-length is big enough to avoid cyclic
|
||||
// convolution errors.
|
||||
RTC_DCHECK_GT(time_domain_fft_length,
|
||||
kNumInvertedLags12kHz + convolution_length);
|
||||
|
||||
// h[0:convolution_length] is reversed pitch_buf[-convolution_length:].
|
||||
std::reverse_copy(pitch_buf.end() - convolution_length, pitch_buf.end(),
|
||||
h.begin());
|
||||
|
||||
// x is pitch_buf[:kNumInvertedLags12kHz + convolution_length].
|
||||
std::copy(pitch_buf.begin(),
|
||||
pitch_buf.begin() + kNumInvertedLags12kHz + convolution_length,
|
||||
x.begin());
|
||||
|
||||
// Shift to frequency domain.
|
||||
std::array<std::complex<float>, freq_domain_fft_length> X{};
|
||||
std::array<std::complex<float>, freq_domain_fft_length> H{};
|
||||
fft->Forward(&x[0], &X[0]);
|
||||
fft->Forward(&h[0], &H[0]);
|
||||
|
||||
// Convolve in frequency domain.
|
||||
for (size_t i = 0; i < X.size(); ++i) {
|
||||
X[i] *= H[i];
|
||||
}
|
||||
|
||||
// Shift back to time domain.
|
||||
std::array<float, time_domain_fft_length> x_conv_h;
|
||||
fft->Inverse(&X[0], &x_conv_h[0]);
|
||||
|
||||
// Collect the result.
|
||||
std::copy(x_conv_h.begin() + convolution_length - 1,
|
||||
x_conv_h.begin() + convolution_length + kNumInvertedLags12kHz - 1,
|
||||
auto_corr.begin());
|
||||
}
|
||||
|
||||
std::array<size_t, 2> FindBestPitchPeriods(
|
||||
|
||||
@ -14,6 +14,7 @@
|
||||
#include <array>
|
||||
|
||||
#include "api/array_view.h"
|
||||
#include "common_audio/real_fourier.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/common.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/pitch_info.h"
|
||||
|
||||
@ -26,6 +27,11 @@ static_assert(kMaxPitch12kHz > kInitialMinPitch12kHz, "");
|
||||
static_assert(kMaxPitch24kHz > kInitialMinPitch24kHz, "");
|
||||
constexpr size_t kNumInvertedLags12kHz = kMaxPitch12kHz - kInitialMinPitch12kHz;
|
||||
constexpr size_t kNumInvertedLags24kHz = kMaxPitch24kHz - kInitialMinPitch24kHz;
|
||||
constexpr int kAutoCorrelationFftOrder = 9; // Length-512 FFT.
|
||||
|
||||
static_assert(1 << kAutoCorrelationFftOrder >
|
||||
kNumInvertedLags12kHz + kBufSize12kHz - kMaxPitch12kHz,
|
||||
"");
|
||||
|
||||
// Performs 2x decimation without any anti-aliasing filter.
|
||||
void Decimate2x(rtc::ArrayView<const float, kBufSize24kHz> src,
|
||||
@ -70,7 +76,8 @@ void ComputeSlidingFrameSquareEnergies(
|
||||
void ComputePitchAutoCorrelation(
|
||||
rtc::ArrayView<const float, kBufSize12kHz> pitch_buf,
|
||||
size_t max_pitch_period,
|
||||
rtc::ArrayView<float, kNumInvertedLags12kHz> auto_corr);
|
||||
rtc::ArrayView<float, kNumInvertedLags12kHz> auto_corr,
|
||||
webrtc::RealFourier* fft);
|
||||
|
||||
// Given the auto-correlation coefficients stored according to
|
||||
// ComputePitchAutoCorrelation() (i.e., using inverted lags), returns the best
|
||||
|
||||
@ -9,6 +9,7 @@
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h"
|
||||
#include "common_audio/real_fourier.h"
|
||||
|
||||
#include <array>
|
||||
#include <tuple>
|
||||
@ -415,16 +416,47 @@ TEST(RnnVadTest, ComputePitchAutoCorrelationBitExactness) {
|
||||
{
|
||||
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
|
||||
// FloatingPointExceptionObserver fpe_observer;
|
||||
|
||||
std::unique_ptr<RealFourier> fft =
|
||||
RealFourier::Create(kAutoCorrelationFftOrder);
|
||||
ComputePitchAutoCorrelation(
|
||||
{pitch_buf_decimated.data(), pitch_buf_decimated.size()},
|
||||
kMaxPitch12kHz, {computed_output.data(), computed_output.size()});
|
||||
kMaxPitch12kHz, {computed_output.data(), computed_output.size()},
|
||||
fft.get());
|
||||
}
|
||||
ExpectNearAbsolute(
|
||||
{kPitchBufferAutoCorrCoeffs.data(), kPitchBufferAutoCorrCoeffs.size()},
|
||||
{computed_output.data(), computed_output.size()}, 3e-3f);
|
||||
}
|
||||
|
||||
// Check that the auto correlation function computes the right thing for a
|
||||
// simple use case.
|
||||
// Feeds an all-ones pitch buffer to ComputePitchAutoCorrelation() and expects
// every returned coefficient to equal kBufSize12kHz - kMaxPitch12kHz.
TEST(RnnVadTest, ComputePitchAutoCorrelationConstantBuffer) {
|
||||
// Create constant signal with no pitch.
|
||||
std::array<float, kBufSize12kHz> pitch_buf_decimated;
|
||||
std::fill(pitch_buf_decimated.begin(), pitch_buf_decimated.end(), 1.f);
|
||||
|
||||
std::array<float, kPitchBufferAutoCorrCoeffs.size()> computed_output;
|
||||
{
|
||||
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
|
||||
// FloatingPointExceptionObserver fpe_observer;
|
||||
|
||||
// The FFT is created with kAutoCorrelationFftOrder, the order that
// ComputePitchAutoCorrelation() DCHECKs against.
std::unique_ptr<RealFourier> fft =
|
||||
RealFourier::Create(kAutoCorrelationFftOrder);
|
||||
ComputePitchAutoCorrelation(
|
||||
{pitch_buf_decimated.data(), pitch_buf_decimated.size()},
|
||||
kMaxPitch12kHz, {computed_output.data(), computed_output.size()},
|
||||
fft.get());
|
||||
}
|
||||
|
||||
// The expected output is constantly the length of the fixed 'x'
|
||||
// array in ComputePitchAutoCorrelation.
|
||||
// With an all-ones input, each coefficient is a sum of
// kBufSize12kHz - kMaxPitch12kHz products of 1 * 1.
std::array<float, kPitchBufferAutoCorrCoeffs.size()> expected_output;
|
||||
std::fill(expected_output.begin(), expected_output.end(),
|
||||
kBufSize12kHz - kMaxPitch12kHz);
|
||||
// NOTE(review): the 4e-5f tolerance presumably absorbs FFT rounding error;
// confirm before tightening.
ExpectNearAbsolute({expected_output.data(), expected_output.size()},
|
||||
{computed_output.data(), computed_output.size()}, 4e-5f);
|
||||
}
|
||||
|
||||
TEST(RnnVadTest, FindBestPitchPeriodsBitExactness) {
|
||||
std::array<float, kBufSize12kHz> pitch_buf_decimated;
|
||||
Decimate2x({kPitchBufferData.data(), kPitchBufferData.size()},
|
||||
|
||||
@ -9,6 +9,7 @@
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/rnn_vad/pitch_search.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h"
|
||||
|
||||
#include <array>
|
||||
|
||||
@ -28,6 +29,8 @@ TEST(RnnVadTest, PitchSearchBitExactness) {
|
||||
std::array<float, 864> lp_residual;
|
||||
float expected_pitch_period, expected_pitch_gain;
|
||||
PitchInfo last_pitch;
|
||||
std::unique_ptr<RealFourier> fft =
|
||||
RealFourier::Create(kAutoCorrelationFftOrder);
|
||||
{
|
||||
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
|
||||
// FloatingPointExceptionObserver fpe_observer;
|
||||
@ -38,8 +41,8 @@ TEST(RnnVadTest, PitchSearchBitExactness) {
|
||||
{lp_residual.data(), lp_residual.size()});
|
||||
lp_residual_reader.first->ReadValue(&expected_pitch_period);
|
||||
lp_residual_reader.first->ReadValue(&expected_pitch_gain);
|
||||
last_pitch =
|
||||
PitchSearch({lp_residual.data(), lp_residual.size()}, last_pitch);
|
||||
last_pitch = PitchSearch({lp_residual.data(), lp_residual.size()},
|
||||
last_pitch, fft.get());
|
||||
EXPECT_EQ(static_cast<size_t>(expected_pitch_period), last_pitch.period);
|
||||
EXPECT_NEAR(expected_pitch_gain, last_pitch.gain, 1e-5f);
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user