From 2f7d1c62e21e2f3786c0803c973d71b414726d8d Mon Sep 17 00:00:00 2001 From: Alessio Bazzica Date: Mon, 9 Nov 2020 15:40:14 +0100 Subject: [PATCH] RNN VAD: pitch search optimizations (part 2) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This CL brings a large improvement to the VAD by precomputing the energy for the sliding frame `y` in the pitch buffer instead of computing them twice in two different places. The realtime factor has improved by about +16x. There is room for additional improvement (TODOs added), but that will be done in a follow up CL since the change won't be bit-exact and careful testing is needed. Benchmarked as follows: ``` out/release/modules_unittests \ --gtest_filter=*RnnVadTest.DISABLED_RnnVadPerformance* \ --gtest_also_run_disabled_tests --logs ``` Results: | baseline | this CL ------+----------------------+------------------------ run 1 | 23.568 +/- 0.990788 | 22.8319 +/- 1.46554 | 377.207x | 389.367x ------+----------------------+------------------------ run 2 | 23.3714 +/- 0.857523 | 22.4286 +/- 0.726449 | 380.379x | 396.369x ------+----------------------+------------------------ run 2 | 23.709 +/- 1.04477 | 22.5688 +/- 0.831341 | 374.963x | 393.906x Bug: webrtc:10480 Change-Id: I599a4dda2bde16dc6c2f42cf89e96afbd4630311 Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/191484 Reviewed-by: Per Ã…hgren Commit-Queue: Alessio Bazzica Cr-Commit-Position: refs/heads/master@{#32571} --- .../agc2/rnn_vad/pitch_search.cc | 51 +++++++----- .../agc2/rnn_vad/pitch_search.h | 7 +- .../agc2/rnn_vad/pitch_search_internal.cc | 78 ++++++++----------- .../agc2/rnn_vad/pitch_search_internal.h | 12 ++- .../rnn_vad/pitch_search_internal_unittest.cc | 16 +++- 5 files changed, 87 insertions(+), 77 deletions(-) diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search.cc index 9d4c5a2d81..c6c3e1b2b5 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search.cc +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search.cc @@ -19,37 +19,46 @@ namespace webrtc { namespace rnn_vad { PitchEstimator::PitchEstimator() - : pitch_buf_decimated_(kBufSize12kHz), - pitch_buf_decimated_view_(pitch_buf_decimated_.data(), kBufSize12kHz), - auto_corr_(kNumLags12kHz), - auto_corr_view_(auto_corr_.data(), kNumLags12kHz) { - RTC_DCHECK_EQ(kBufSize12kHz, pitch_buf_decimated_.size()); - RTC_DCHECK_EQ(kNumLags12kHz, auto_corr_view_.size()); -} + : y_energy_24kHz_(kRefineNumLags24kHz, 0.f), + pitch_buffer_12kHz_(kBufSize12kHz), + auto_correlation_12kHz_(kNumLags12kHz) {} PitchEstimator::~PitchEstimator() = default; int PitchEstimator::Estimate( rtc::ArrayView pitch_buffer) { + rtc::ArrayView pitch_buffer_12kHz_view( + pitch_buffer_12kHz_.data(), kBufSize12kHz); + RTC_DCHECK_EQ(pitch_buffer_12kHz_.size(), pitch_buffer_12kHz_view.size()); + rtc::ArrayView auto_correlation_12kHz_view( + auto_correlation_12kHz_.data(), kNumLags12kHz); + RTC_DCHECK_EQ(auto_correlation_12kHz_.size(), + auto_correlation_12kHz_view.size()); + // Perform the initial pitch search at 12 kHz. - Decimate2x(pitch_buffer, pitch_buf_decimated_view_); - auto_corr_calculator_.ComputeOnPitchBuffer(pitch_buf_decimated_view_, - auto_corr_view_); - CandidatePitchPeriods pitch_candidates_inverted_lags = - ComputePitchPeriod12kHz(pitch_buf_decimated_view_, auto_corr_view_); - // Refine the pitch period estimation. + Decimate2x(pitch_buffer, pitch_buffer_12kHz_view); + auto_corr_calculator_.ComputeOnPitchBuffer(pitch_buffer_12kHz_view, + auto_correlation_12kHz_view); + CandidatePitchPeriods pitch_periods = ComputePitchPeriod12kHz( + pitch_buffer_12kHz_view, auto_correlation_12kHz_view); // The refinement is done using the pitch buffer that contains 24 kHz samples. // Therefore, adapt the inverted lags in |pitch_candidates_inv_lags| from 12 // to 24 kHz. - pitch_candidates_inverted_lags.best *= 2; - pitch_candidates_inverted_lags.second_best *= 2; - const int pitch_inv_lag_48kHz = - ComputePitchPeriod48kHz(pitch_buffer, pitch_candidates_inverted_lags); - // Look for stronger harmonics to find the final pitch period and its gain. - RTC_DCHECK_LT(pitch_inv_lag_48kHz, kMaxPitch48kHz); + pitch_periods.best *= 2; + pitch_periods.second_best *= 2; + + // Refine the initial pitch period estimation from 12 kHz to 48 kHz. + // Pre-compute frame energies at 24 kHz. + rtc::ArrayView y_energy_24kHz_view( + y_energy_24kHz_.data(), kRefineNumLags24kHz); + RTC_DCHECK_EQ(y_energy_24kHz_.size(), y_energy_24kHz_view.size()); + ComputeSlidingFrameSquareEnergies24kHz(pitch_buffer, y_energy_24kHz_view); + // Estimation at 48 kHz. + const int pitch_lag_48kHz = + ComputePitchPeriod48kHz(pitch_buffer, y_energy_24kHz_view, pitch_periods); last_pitch_48kHz_ = ComputeExtendedPitchPeriod48kHz( - pitch_buffer, - /*initial_pitch_period_48kHz=*/kMaxPitch48kHz - pitch_inv_lag_48kHz, + pitch_buffer, y_energy_24kHz_view, + /*initial_pitch_period_48kHz=*/kMaxPitch48kHz - pitch_lag_48kHz, last_pitch_48kHz_); return last_pitch_48kHz_.period; } diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search.h b/modules/audio_processing/agc2/rnn_vad/pitch_search.h index 1e6b9ad706..e96a2dcaf1 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search.h +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search.h @@ -41,10 +41,9 @@ class PitchEstimator { PitchInfo last_pitch_48kHz_{}; AutoCorrelationCalculator auto_corr_calculator_; - std::vector pitch_buf_decimated_; - rtc::ArrayView pitch_buf_decimated_view_; - std::vector auto_corr_; - rtc::ArrayView auto_corr_view_; + std::vector y_energy_24kHz_; + std::vector pitch_buffer_12kHz_; + std::vector auto_correlation_12kHz_; }; } // namespace rnn_vad diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc index 8179dbd965..d62cddf067 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc @@ -153,17 +153,12 @@ void ComputeAutoCorrelation( } } -int FindBestPitchPeriods24kHz( +int ComputePitchPeriod24kHz( + rtc::ArrayView pitch_buffer, rtc::ArrayView auto_correlation, - rtc::ArrayView pitch_buffer) { + rtc::ArrayView y_energy) { static_assert(kMaxPitch24kHz > kInitialNumLags24kHz, ""); static_assert(kMaxPitch24kHz < kBufSize24kHz, ""); - // Initialize the sliding 20 ms frame energy. - // TODO(bugs.webrtc.org/9076): Maybe optimize using vectorization. - float denominator = std::inner_product( - pitch_buffer.begin(), pitch_buffer.begin() + kFrameSize20ms24kHz + 1, - pitch_buffer.begin(), 1.f); - // Search best pitch by looking at the scaled auto-correlation. int best_inverted_lag = 0; // Pitch period. float best_numerator = -1.f; // Pitch strength numerator. float best_denominator = 0.f; // Pitch strength denominator. @@ -171,8 +166,10 @@ int FindBestPitchPeriods24kHz( ++inverted_lag) { // A pitch candidate must have positive correlation. if (auto_correlation[inverted_lag] > 0.f) { + // Auto-correlation energy normalized by frame energy. const float numerator = auto_correlation[inverted_lag] * auto_correlation[inverted_lag]; + const float denominator = y_energy[kMaxPitch24kHz - inverted_lag]; // Compare numerator/denominator ratios without using divisions. if (numerator * best_denominator > best_numerator * denominator) { best_inverted_lag = inverted_lag; @@ -180,14 +177,6 @@ int FindBestPitchPeriods24kHz( best_denominator = denominator; } } - // Update |denominator| for the next inverted lag. - static_assert(kInitialNumLags24kHz + kFrameSize20ms24kHz < kBufSize24kHz, - ""); - const float y_old = pitch_buffer[inverted_lag]; - const float y_new = pitch_buffer[inverted_lag + kFrameSize20ms24kHz]; - denominator -= y_old * y_old; - denominator += y_new * y_new; - denominator = std::max(0.f, denominator); } return best_inverted_lag; } @@ -338,6 +327,7 @@ CandidatePitchPeriods ComputePitchPeriod12kHz( int ComputePitchPeriod48kHz( rtc::ArrayView pitch_buffer, + rtc::ArrayView y_energy, CandidatePitchPeriods pitch_candidates) { // Compute the auto-correlation terms only for neighbors of the given pitch // candidates (similar to what is done in ComputePitchAutoCorrelation(), but @@ -362,7 +352,7 @@ int ComputePitchPeriod48kHz( } // Find best pitch at 24 kHz. const int pitch_candidate_24kHz = - FindBestPitchPeriods24kHz(auto_correlation, pitch_buffer); + ComputePitchPeriod24kHz(pitch_buffer, auto_correlation, y_energy); // Pseudo-interpolation. return PitchPseudoInterpolationInvLagAutoCorr(pitch_candidate_24kHz, auto_correlation); @@ -370,6 +360,7 @@ int ComputePitchPeriod48kHz( PitchInfo ComputeExtendedPitchPeriod48kHz( rtc::ArrayView pitch_buffer, + rtc::ArrayView y_energy, int initial_pitch_period_48kHz, PitchInfo last_pitch_48kHz) { RTC_DCHECK_LE(kMinPitch48kHz, initial_pitch_period_48kHz); @@ -379,34 +370,30 @@ PitchInfo ComputeExtendedPitchPeriod48kHz( struct RefinedPitchCandidate { int period; float strength; - // Additional strength data used for the final estimation of the strength. - float xy; // Cross-correlation. - float yy; // Auto-correlation. + // Additional strength data used for the final pitch estimation. + float xy; // Auto-correlation. + float y_energy; // Energy of the sliding frame `y`. }; - // Initialize. - std::array yy_values; - // TODO(bugs.webrtc.org/9076): Reuse values from FindBestPitchPeriods24kHz(). - ComputeSlidingFrameSquareEnergies24kHz(pitch_buffer, yy_values); - const float xx = yy_values[0]; - const auto pitch_strength = [](float xy, float yy, float xx) { - RTC_DCHECK_GE(xx * yy, 0.f); - return xy / std::sqrt(1.f + xx * yy); + const float x_energy = y_energy[0]; + const auto pitch_strength = [x_energy](float xy, float y_energy) { + RTC_DCHECK_GE(x_energy * y_energy, 0.f); + return xy / std::sqrt(1.f + x_energy * y_energy); }; - // Initial pitch candidate. + + // Initialize the best pitch candidate with `initial_pitch_period_48kHz`. RefinedPitchCandidate best_pitch; best_pitch.period = std::min(initial_pitch_period_48kHz / 2, kMaxPitch24kHz - 1); best_pitch.xy = ComputeAutoCorrelation(kMaxPitch24kHz - best_pitch.period, pitch_buffer); - best_pitch.yy = yy_values[best_pitch.period]; - best_pitch.strength = pitch_strength(best_pitch.xy, best_pitch.yy, xx); - - // 24 kHz version of the last estimated pitch and copy of the initial - // estimation. + best_pitch.y_energy = y_energy[best_pitch.period]; + best_pitch.strength = pitch_strength(best_pitch.xy, best_pitch.y_energy); + // Keep a copy of the initial pitch candidate. + const PitchInfo initial_pitch{best_pitch.period, best_pitch.strength}; + // 24 kHz version of the last estimated pitch. const PitchInfo last_pitch{last_pitch_48kHz.period / 2, last_pitch_48kHz.strength}; - const PitchInfo initial_pitch{best_pitch.period, best_pitch.strength}; // Find `max_period_divisor` such that the result of // `GetAlternativePitchPeriod(initial_pitch_period, 1, max_period_divisor)` @@ -436,14 +423,14 @@ PitchInfo ComputeExtendedPitchPeriod48kHz( // Compute an auto-correlation score for the primary pitch candidate // |alternative_pitch.period| by also looking at its possible sub-harmonic // |dual_alternative_period|. - float xy_primary_period = ComputeAutoCorrelation( + const float xy_primary_period = ComputeAutoCorrelation( kMaxPitch24kHz - alternative_pitch.period, pitch_buffer); - float xy_secondary_period = ComputeAutoCorrelation( + const float xy_secondary_period = ComputeAutoCorrelation( kMaxPitch24kHz - dual_alternative_period, pitch_buffer); - float xy = 0.5f * (xy_primary_period + xy_secondary_period); - float yy = 0.5f * (yy_values[alternative_pitch.period] + - yy_values[dual_alternative_period]); - alternative_pitch.strength = pitch_strength(xy, yy, xx); + const float xy = 0.5f * (xy_primary_period + xy_secondary_period); + const float yy = 0.5f * (y_energy[alternative_pitch.period] + + y_energy[dual_alternative_period]); + alternative_pitch.strength = pitch_strength(xy, yy); // Maybe update best period. if (IsAlternativePitchStrongerThanInitial( @@ -455,10 +442,11 @@ PitchInfo ComputeExtendedPitchPeriod48kHz( // Final pitch strength and period. best_pitch.xy = std::max(0.f, best_pitch.xy); - RTC_DCHECK_LE(0.f, best_pitch.yy); - float final_pitch_strength = (best_pitch.yy <= best_pitch.xy) - ? 1.f - : best_pitch.xy / (best_pitch.yy + 1.f); + RTC_DCHECK_LE(0.f, best_pitch.y_energy); + float final_pitch_strength = + (best_pitch.y_energy <= best_pitch.xy) + ? 1.f + : best_pitch.xy / (best_pitch.y_energy + 1.f); final_pitch_strength = std::min(best_pitch.strength, final_pitch_strength); int final_pitch_period_48kHz = std::max( kMinPitch48kHz, diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h index b16a2f438d..693ab9e5d1 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h @@ -80,10 +80,12 @@ CandidatePitchPeriods ComputePitchPeriod12kHz( rtc::ArrayView pitch_buffer, rtc::ArrayView auto_correlation); -// Computes the pitch period at 48 kHz given a view on the 24 kHz pitch buffer -// and the pitch period candidates at 24 kHz (encoded as inverted lag). +// Computes the pitch period at 48 kHz given a view on the 24 kHz pitch buffer, +// the energies for the sliding frames `y` at 24 kHz and the pitch period +// candidates at 24 kHz (encoded as inverted lag). int ComputePitchPeriod48kHz( rtc::ArrayView pitch_buffer, + rtc::ArrayView y_energy, CandidatePitchPeriods pitch_candidates_24kHz); struct PitchInfo { @@ -92,10 +94,12 @@ struct PitchInfo { }; // Computes the pitch period at 48 kHz searching in an extended pitch range -// given a view on the 24 kHz pitch buffer, the initial 48 kHz estimation -// (computed by `ComputePitchPeriod48kHz()`) and the last estimated pitch. +// given a view on the 24 kHz pitch buffer, the energies for the sliding frames +// `y` at 24 kHz, the initial 48 kHz estimation (computed by +// `ComputePitchPeriod48kHz()`) and the last estimated pitch. PitchInfo ComputeExtendedPitchPeriod48kHz( rtc::ArrayView pitch_buffer, + rtc::ArrayView y_energy, int initial_pitch_period_48kHz, PitchInfo last_pitch_48kHz); diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc index 7acb046db1..e5826d02af 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc @@ -63,12 +63,17 @@ TEST(RnnVadTest, ComputePitchPeriod12kHzBitExactness) { // Checks that the refined pitch period is bit-exact given test input data. TEST(RnnVadTest, ComputePitchPeriod48kHzBitExactness) { PitchTestData test_data; + std::vector y_energy(kMaxPitch24kHz + 1); + rtc::ArrayView y_energy_view(y_energy.data(), + kMaxPitch24kHz + 1); + ComputeSlidingFrameSquareEnergies24kHz(test_data.GetPitchBufView(), + y_energy_view); // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. // FloatingPointExceptionObserver fpe_observer; - EXPECT_EQ(ComputePitchPeriod48kHz(test_data.GetPitchBufView(), + EXPECT_EQ(ComputePitchPeriod48kHz(test_data.GetPitchBufView(), y_energy_view, /*pitch_candidates=*/{280, 284}), 560); - EXPECT_EQ(ComputePitchPeriod48kHz(test_data.GetPitchBufView(), + EXPECT_EQ(ComputePitchPeriod48kHz(test_data.GetPitchBufView(), y_energy_view, /*pitch_candidates=*/{260, 284}), 568); } @@ -90,10 +95,15 @@ class ComputeExtendedPitchPeriod48kHzTest TEST_P(ComputeExtendedPitchPeriod48kHzTest, PeriodBitExactnessGainWithinTolerance) { PitchTestData test_data; + std::vector y_energy(kMaxPitch24kHz + 1); + rtc::ArrayView y_energy_view(y_energy.data(), + kMaxPitch24kHz + 1); + ComputeSlidingFrameSquareEnergies24kHz(test_data.GetPitchBufView(), + y_energy_view); // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. // FloatingPointExceptionObserver fpe_observer; const auto computed_output = ComputeExtendedPitchPeriod48kHz( - test_data.GetPitchBufView(), GetInitialPitchPeriod(), + test_data.GetPitchBufView(), y_energy_view, GetInitialPitchPeriod(), {GetLastPitchPeriod(), GetLastPitchStrength()}); EXPECT_EQ(GetExpectedPitchPeriod(), computed_output.period); EXPECT_NEAR(GetExpectedPitchStrength(), computed_output.strength, 1e-6f);