diff --git a/common_audio/real_fourier_ooura.cc b/common_audio/real_fourier_ooura.cc index 5d75717bc7..ca043e4e13 100644 --- a/common_audio/real_fourier_ooura.cc +++ b/common_audio/real_fourier_ooura.cc @@ -45,6 +45,8 @@ RealFourierOoura::RealFourierOoura(int fft_order) RTC_CHECK_GE(fft_order, 1); } +RealFourierOoura::~RealFourierOoura() = default; + void RealFourierOoura::Forward(const float* src, complex* dest) const { { // This cast is well-defined since C++11. See "Non-static data members" at: @@ -82,4 +84,8 @@ void RealFourierOoura::Inverse(const complex* src, float* dest) const { std::for_each(dest, dest + length_, [scale](float& v) { v *= scale; }); } +int RealFourierOoura::order() const { + return order_; +} + } // namespace webrtc diff --git a/common_audio/real_fourier_ooura.h b/common_audio/real_fourier_ooura.h index f885a34f58..bb8eef96df 100644 --- a/common_audio/real_fourier_ooura.h +++ b/common_audio/real_fourier_ooura.h @@ -21,13 +21,12 @@ namespace webrtc { class RealFourierOoura : public RealFourier { public: explicit RealFourierOoura(int fft_order); + ~RealFourierOoura() override; void Forward(const float* src, std::complex* dest) const override; void Inverse(const std::complex* src, float* dest) const override; - int order() const override { - return order_; - } + int order() const override; private: const int order_; @@ -42,4 +41,3 @@ class RealFourierOoura : public RealFourier { } // namespace webrtc #endif // COMMON_AUDIO_REAL_FOURIER_OOURA_H_ - diff --git a/modules/audio_processing/agc2/rnn_vad/BUILD.gn b/modules/audio_processing/agc2/rnn_vad/BUILD.gn index b0ca347f69..f35d5c383f 100644 --- a/modules/audio_processing/agc2/rnn_vad/BUILD.gn +++ b/modules/audio_processing/agc2/rnn_vad/BUILD.gn @@ -36,6 +36,7 @@ source_set("lib") { ] deps = [ "../../../../api:array_view", + "../../../../common_audio/", "../../../../rtc_base:checks", "../../../../rtc_base:rtc_base_approved", "//third_party/rnnoise:kiss_fft", @@ -95,6 +96,7 @@ if (rtc_include_tests) { ":lib", ":lib_test", "../../../../api:array_view", + "../../../../common_audio/", "../../../../rtc_base:checks", "../../../../test:test_support", "//third_party/rnnoise:rnn_vad", diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search.cc index 4d83588cb3..9261935ca8 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search.cc +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search.cc @@ -15,7 +15,8 @@ namespace webrtc { namespace rnn_vad { PitchInfo PitchSearch(rtc::ArrayView pitch_buf, - PitchInfo prev_pitch_48kHz) { + PitchInfo prev_pitch_48kHz, + RealFourier* fft) { // Perform the initial pitch search at 12 kHz. std::array pitch_buf_decimated; Decimate2x(pitch_buf, @@ -24,7 +25,8 @@ PitchInfo PitchSearch(rtc::ArrayView pitch_buf, std::array auto_corr; ComputePitchAutoCorrelation( {pitch_buf_decimated.data(), pitch_buf_decimated.size()}, kMaxPitch12kHz, - {auto_corr.data(), auto_corr.size()}); + {auto_corr.data(), auto_corr.size()}, fft); + // Search for pitch at 12 kHz. std::array pitch_candidates_inv_lags = FindBestPitchPeriods( {auto_corr.data(), auto_corr.size()}, diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search.h b/modules/audio_processing/agc2/rnn_vad/pitch_search.h index a0af0ebfa2..21e7a05b9e 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search.h +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search.h @@ -12,6 +12,7 @@ #define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_SEARCH_H_ #include "api/array_view.h" +#include "common_audio/real_fourier.h" #include "modules/audio_processing/agc2/rnn_vad/common.h" #include "modules/audio_processing/agc2/rnn_vad/pitch_info.h" @@ -21,7 +22,8 @@ namespace rnn_vad { // Searches the pitch period and gain. Return the pitch estimation data for // 48 kHz. PitchInfo PitchSearch(rtc::ArrayView pitch_buf, - PitchInfo prev_pitch_48kHz); + PitchInfo prev_pitch_48kHz, + RealFourier* fft); } // namespace rnn_vad } // namespace webrtc diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc index 1ff4621b28..99600e047d 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc @@ -205,18 +205,62 @@ void ComputeSlidingFrameSquareEnergies( } } -// TODO(bugs.webrtc.org/9076): Optimize using FFT and/or vectorization. void ComputePitchAutoCorrelation( rtc::ArrayView pitch_buf, size_t max_pitch_period, - rtc::ArrayView auto_corr) { + rtc::ArrayView auto_corr, + webrtc::RealFourier* fft) { RTC_DCHECK_GT(max_pitch_period, auto_corr.size()); RTC_DCHECK_LT(max_pitch_period, pitch_buf.size()); - // Compute auto-correlation coefficients. - for (size_t inv_lag = 0; inv_lag < auto_corr.size(); ++inv_lag) { - auto_corr[inv_lag] = - ComputeAutoCorrelationCoeff(pitch_buf, inv_lag, max_pitch_period); + RTC_DCHECK(fft); + + constexpr size_t time_domain_fft_length = 1 << kAutoCorrelationFftOrder; + constexpr size_t freq_domain_fft_length = time_domain_fft_length / 2 + 1; + + RTC_DCHECK_EQ(RealFourier::FftLength(fft->order()), time_domain_fft_length); + RTC_DCHECK_EQ(RealFourier::ComplexLength(fft->order()), + freq_domain_fft_length); + + // Cross-correlation of y_i=pitch_buf[i:i+convolution_length] and + // x=pitch_buf[-convolution_length:] is equivalent to convolution of + // y_i and reversed(x). New notation: h=reversed(x), x=y. + std::array h{}; + std::array x{}; + + const size_t convolution_length = kBufSize12kHz - max_pitch_period; + // Check that the FFT-length is big enough to avoid cyclic + // convolution errors. + RTC_DCHECK_GT(time_domain_fft_length, + kNumInvertedLags12kHz + convolution_length); + + // h[0:convolution_length] is reversed pitch_buf[-convolution_length:]. + std::reverse_copy(pitch_buf.end() - convolution_length, pitch_buf.end(), + h.begin()); + + // x is pitch_buf[:kNumInvertedLags12kHz + convolution_length]. + std::copy(pitch_buf.begin(), + pitch_buf.begin() + kNumInvertedLags12kHz + convolution_length, + x.begin()); + + // Shift to frequency domain. + std::array, freq_domain_fft_length> X{}; + std::array, freq_domain_fft_length> H{}; + fft->Forward(&x[0], &X[0]); + fft->Forward(&h[0], &H[0]); + + // Convolve in frequency domain. + for (size_t i = 0; i < X.size(); ++i) { + X[i] *= H[i]; } + + // Shift back to time domain. + std::array x_conv_h; + fft->Inverse(&X[0], &x_conv_h[0]); + + // Collect the result. + std::copy(x_conv_h.begin() + convolution_length - 1, + x_conv_h.begin() + convolution_length + kNumInvertedLags12kHz - 1, + auto_corr.begin()); } std::array FindBestPitchPeriods( diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h index dfe1b35ff7..75f7f17a42 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h @@ -14,6 +14,7 @@ #include #include "api/array_view.h" +#include "common_audio/real_fourier.h" #include "modules/audio_processing/agc2/rnn_vad/common.h" #include "modules/audio_processing/agc2/rnn_vad/pitch_info.h" @@ -26,6 +27,11 @@ static_assert(kMaxPitch12kHz > kInitialMinPitch12kHz, ""); static_assert(kMaxPitch24kHz > kInitialMinPitch24kHz, ""); constexpr size_t kNumInvertedLags12kHz = kMaxPitch12kHz - kInitialMinPitch12kHz; constexpr size_t kNumInvertedLags24kHz = kMaxPitch24kHz - kInitialMinPitch24kHz; +constexpr int kAutoCorrelationFftOrder = 9; // Length-512 FFT. + +static_assert(1 << kAutoCorrelationFftOrder > + kNumInvertedLags12kHz + kBufSize12kHz - kMaxPitch12kHz, + ""); // Performs 2x decimation without any anti-aliasing filter. void Decimate2x(rtc::ArrayView src, @@ -70,7 +76,8 @@ void ComputeSlidingFrameSquareEnergies( void ComputePitchAutoCorrelation( rtc::ArrayView pitch_buf, size_t max_pitch_period, - rtc::ArrayView auto_corr); + rtc::ArrayView auto_corr, + webrtc::RealFourier* fft); // Given the auto-correlation coefficients stored according to // ComputePitchAutoCorrelation() (i.e., using inverted lags), returns the best diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc index 9a6a2676c6..4b1be0d11f 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc @@ -9,6 +9,7 @@ */ #include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h" +#include "common_audio/real_fourier.h" #include #include @@ -415,16 +416,47 @@ TEST(RnnVadTest, ComputePitchAutoCorrelationBitExactness) { { // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. // FloatingPointExceptionObserver fpe_observer; - + std::unique_ptr fft = + RealFourier::Create(kAutoCorrelationFftOrder); ComputePitchAutoCorrelation( {pitch_buf_decimated.data(), pitch_buf_decimated.size()}, - kMaxPitch12kHz, {computed_output.data(), computed_output.size()}); + kMaxPitch12kHz, {computed_output.data(), computed_output.size()}, + fft.get()); } ExpectNearAbsolute( {kPitchBufferAutoCorrCoeffs.data(), kPitchBufferAutoCorrCoeffs.size()}, {computed_output.data(), computed_output.size()}, 3e-3f); } +// Check that the auto correlation function computes the right thing for a +// simple use case. +TEST(RnnVadTest, ComputePitchAutoCorrelationConstantBuffer) { + // Create constant signal with no pitch. + std::array pitch_buf_decimated; + std::fill(pitch_buf_decimated.begin(), pitch_buf_decimated.end(), 1.f); + + std::array computed_output; + { + // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. + // FloatingPointExceptionObserver fpe_observer; + + std::unique_ptr fft = + RealFourier::Create(kAutoCorrelationFftOrder); + ComputePitchAutoCorrelation( + {pitch_buf_decimated.data(), pitch_buf_decimated.size()}, + kMaxPitch12kHz, {computed_output.data(), computed_output.size()}, + fft.get()); + } + + // The expected output is constantly the length of the fixed 'x' + // array in ComputePitchAutoCorrelation. + std::array expected_output; + std::fill(expected_output.begin(), expected_output.end(), + kBufSize12kHz - kMaxPitch12kHz); + ExpectNearAbsolute({expected_output.data(), expected_output.size()}, + {computed_output.data(), computed_output.size()}, 4e-5f); +} + TEST(RnnVadTest, FindBestPitchPeriodsBitExactness) { std::array pitch_buf_decimated; Decimate2x({kPitchBufferData.data(), kPitchBufferData.size()}, diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc index 441776465d..b25aba393e 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc @@ -9,6 +9,7 @@ */ #include "modules/audio_processing/agc2/rnn_vad/pitch_search.h" +#include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h" #include @@ -28,6 +29,8 @@ TEST(RnnVadTest, PitchSearchBitExactness) { std::array lp_residual; float expected_pitch_period, expected_pitch_gain; PitchInfo last_pitch; + std::unique_ptr fft = + RealFourier::Create(kAutoCorrelationFftOrder); { // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. // FloatingPointExceptionObserver fpe_observer; @@ -38,8 +41,8 @@ TEST(RnnVadTest, PitchSearchBitExactness) { {lp_residual.data(), lp_residual.size()}); lp_residual_reader.first->ReadValue(&expected_pitch_period); lp_residual_reader.first->ReadValue(&expected_pitch_gain); - last_pitch = - PitchSearch({lp_residual.data(), lp_residual.size()}, last_pitch); + last_pitch = PitchSearch({lp_residual.data(), lp_residual.size()}, + last_pitch, fft.get()); EXPECT_EQ(static_cast(expected_pitch_period), last_pitch.period); EXPECT_NEAR(expected_pitch_gain, last_pitch.gain, 1e-5f); }