From f2a2fe84b824d2f19c89484726689d0fd281e0dc Mon Sep 17 00:00:00 2001 From: Alessio Bazzica Date: Wed, 11 Nov 2020 12:54:39 +0100 Subject: [PATCH] Reland "RNN VAD: pitch search optimizations (part 3)" This reverts commit 57e68ee1b94cce853ab4305680bbe71c01f95e82. Reason for revert: bug in ancestor CL fixed Original change's description: > Revert "RNN VAD: pitch search optimizations (part 3)" > > This reverts commit ea89f2a447c514b73da2ed6189fe4b8485f123c6. > > Reason for revert: bug in ancestor CL https://webrtc-review.googlesource.com/c/src/+/191320 > > Original change's description: > > RNN VAD: pitch search optimizations (part 3) > > > > `ComputeSlidingFrameSquareEnergies()` which computes the energy of a > > sliding 20 ms frame in the pitch buffer has been switched from backward > > to forward. > > > > The benchmark has shown a slight improvement (about +6x). > > > > This change is not bit exact but all the tolerance tests still pass > > except for one single case in `RnnVadTest,PitchSearchWithinTolerance` > > for which the tolerance has been slightly increased. Note that the pitch > > estimation is still bit-exact. > > > > Benchmarked as follows: > > ``` > > out/release/modules_unittests \ > > --gtest_filter=*RnnVadTest.DISABLED_RnnVadPerformance* \ > > --gtest_also_run_disabled_tests --logs > > ``` > > > > Results: > > > > | baseline | this CL > > ------+----------------------+------------------------ > > run 1 | 22.8319 +/- 1.46554 | 22.087 +/- 0.552932 > > | 389.367x | 402.499x > > ------+----------------------+------------------------ > > run 2 | 22.4286 +/- 0.726449 | 22.216 +/- 0.916222 > > | 396.369x | 400.162x > > ------+----------------------+------------------------ > > run 2 | 22.5688 +/- 0.831341 | 22.4902 +/- 1.04881 > > | 393.906x | 395.283x > > > > Bug: webrtc:10480 > > Change-Id: I1fd54077a32e25e46196c8e18f003cd0ffd503e1 > > Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/191703 > > Commit-Queue: Alessio Bazzica > > Reviewed-by: Karl Wiberg > > Cr-Commit-Position: refs/heads/master@{#32572} > > TBR=alessiob@webrtc.org,kwiberg@webrtc.org > > Change-Id: I57a8f937ade0a35e1ccf0e229c391cc3a10e7c48 > No-Presubmit: true > No-Tree-Checks: true > No-Try: true > Bug: webrtc:10480 > Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/192621 > Reviewed-by: Alessio Bazzica > Commit-Queue: Alessio Bazzica > Cr-Commit-Position: refs/heads/master@{#32578} TBR=alessiob@webrtc.org,kwiberg@webrtc.org # Not skipping CQ checks because this is a reland. Bug: webrtc:10480 Change-Id: I1d510697236255d8c0cca405e90781f5d8c6a3e6 Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/192783 Commit-Queue: Alessio Bazzica Reviewed-by: Alessio Bazzica Reviewed-by: Karl Wiberg Cr-Commit-Position: refs/heads/master@{#32587} --- .../agc2/rnn_vad/pitch_search_internal.cc | 35 ++++++++++--------- .../agc2/rnn_vad/pitch_search_internal.h | 4 +-- .../rnn_vad/pitch_search_internal_unittest.cc | 2 +- .../agc2/rnn_vad/pitch_search_unittest.cc | 2 +- .../agc2/rnn_vad/test_utils.cc | 5 +++ 5 files changed, 27 insertions(+), 21 deletions(-) diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc index 4de3450574..d7ba65f932 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc @@ -172,7 +172,7 @@ int ComputePitchPeriod24kHz( // Auto-correlation energy normalized by frame energy. const float numerator = auto_correlation[inverted_lag] * auto_correlation[inverted_lag]; - const float denominator = y_energy[kMaxPitch24kHz - inverted_lag]; + const float denominator = y_energy[inverted_lag]; // Compare numerator/denominator ratios without using divisions. if (numerator * best_denominator > best_numerator * denominator) { best_inverted_lag = inverted_lag; @@ -256,19 +256,19 @@ void Decimate2x(rtc::ArrayView src, void ComputeSlidingFrameSquareEnergies24kHz( rtc::ArrayView pitch_buffer, - rtc::ArrayView yy_values) { - float yy = ComputeAutoCorrelation(kMaxPitch24kHz, pitch_buffer); - yy_values[0] = yy; - static_assert(kMaxPitch24kHz - (kRefineNumLags24kHz - 1) >= 0, ""); + rtc::ArrayView y_energy) { + float yy = std::inner_product(pitch_buffer.begin(), + pitch_buffer.begin() + kFrameSize20ms24kHz, + pitch_buffer.begin(), 0.f); + y_energy[0] = yy; static_assert(kMaxPitch24kHz - 1 + kFrameSize20ms24kHz < kBufSize24kHz, ""); - for (int lag = 1; lag < kRefineNumLags24kHz; ++lag) { - const int inverted_lag = kMaxPitch24kHz - lag; - const float y_old = pitch_buffer[inverted_lag + kFrameSize20ms24kHz]; - const float y_new = pitch_buffer[inverted_lag]; - yy -= y_old * y_old; - yy += y_new * y_new; - yy = std::max(0.f, yy); - yy_values[lag] = yy; + static_assert(kMaxPitch24kHz < kRefineNumLags24kHz, ""); + for (int inverted_lag = 0; inverted_lag < kMaxPitch24kHz; ++inverted_lag) { + yy -= pitch_buffer[inverted_lag] * pitch_buffer[inverted_lag]; + yy += pitch_buffer[inverted_lag + kFrameSize20ms24kHz] * + pitch_buffer[inverted_lag + kFrameSize20ms24kHz]; + yy = std::max(1.f, yy); + y_energy[inverted_lag + 1] = yy; } } @@ -382,7 +382,7 @@ PitchInfo ComputeExtendedPitchPeriod48kHz( float y_energy; // Energy of the sliding frame `y`. }; - const float x_energy = y_energy[0]; + const float x_energy = y_energy[kMaxPitch24kHz]; const auto pitch_strength = [x_energy](float xy, float y_energy) { RTC_DCHECK_GE(x_energy * y_energy, 0.f); return xy / std::sqrt(1.f + x_energy * y_energy); @@ -394,7 +394,7 @@ PitchInfo ComputeExtendedPitchPeriod48kHz( std::min(initial_pitch_period_48kHz / 2, kMaxPitch24kHz - 1); best_pitch.xy = ComputeAutoCorrelation(kMaxPitch24kHz - best_pitch.period, pitch_buffer); - best_pitch.y_energy = y_energy[best_pitch.period]; + best_pitch.y_energy = y_energy[kMaxPitch24kHz - best_pitch.period]; best_pitch.strength = pitch_strength(best_pitch.xy, best_pitch.y_energy); // Keep a copy of the initial pitch candidate. const PitchInfo initial_pitch{best_pitch.period, best_pitch.strength}; @@ -435,8 +435,9 @@ PitchInfo ComputeExtendedPitchPeriod48kHz( const float xy_secondary_period = ComputeAutoCorrelation( kMaxPitch24kHz - dual_alternative_period, pitch_buffer); const float xy = 0.5f * (xy_primary_period + xy_secondary_period); - const float yy = 0.5f * (y_energy[alternative_pitch.period] + - y_energy[dual_alternative_period]); + const float yy = + 0.5f * (y_energy[kMaxPitch24kHz - alternative_pitch.period] + + y_energy[kMaxPitch24kHz - dual_alternative_period]); alternative_pitch.strength = pitch_strength(xy, yy); // Maybe update best period. diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h index 693ab9e5d1..0af55f8e69 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h @@ -62,10 +62,10 @@ void Decimate2x(rtc::ArrayView src, // corresponding pitch period. // Computes the sum of squared samples for every sliding frame `y` in the pitch -// buffer. The indexes of `yy_values` are lags. +// buffer. The indexes of `y_energy` are inverted lags. void ComputeSlidingFrameSquareEnergies24kHz( rtc::ArrayView pitch_buffer, - rtc::ArrayView yy_values); + rtc::ArrayView y_energy); // Top-2 pitch period candidates. Unit: number of samples - i.e., inverted lags. struct CandidatePitchPeriods { diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc index 26bc395c42..fc715c6aef 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc @@ -42,7 +42,7 @@ TEST(RnnVadTest, ComputeSlidingFrameSquareEnergies24kHzWithinTolerance) { computed_output); auto square_energies_view = test_data.GetPitchBufSquareEnergiesView(); ExpectNearAbsolute({square_energies_view.data(), square_energies_view.size()}, - computed_output, 3e-2f); + computed_output, 1e-3f); } // Checks that the estimated pitch period is bit-exact given test input data. diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc index c57c8c24db..98b791e872 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc @@ -42,7 +42,7 @@ TEST(RnnVadTest, PitchSearchWithinTolerance) { pitch_estimator.Estimate({lp_residual.data(), kBufSize24kHz}); EXPECT_EQ(expected_pitch_period, pitch_period); EXPECT_NEAR(expected_pitch_strength, - pitch_estimator.GetLastPitchStrengthForTesting(), 1e-5f); + pitch_estimator.GetLastPitchStrengthForTesting(), 15e-6f); } } } diff --git a/modules/audio_processing/agc2/rnn_vad/test_utils.cc b/modules/audio_processing/agc2/rnn_vad/test_utils.cc index 74571af640..24bbf13e31 100644 --- a/modules/audio_processing/agc2/rnn_vad/test_utils.cc +++ b/modules/audio_processing/agc2/rnn_vad/test_utils.cc @@ -10,6 +10,7 @@ #include "modules/audio_processing/agc2/rnn_vad/test_utils.h" +#include #include #include "rtc_base/checks.h" @@ -86,6 +87,10 @@ PitchTestData::PitchTestData() { ResourcePath("audio_processing/agc2/rnn_vad/pitch_search_int", "dat"), 1396); test_data_reader.ReadChunk(test_data_); + // Reverse the order of the squared energy values. + // Required after the WebRTC CL 191703 which switched to forward computation. + std::reverse(test_data_.begin() + kBufSize24kHz, + test_data_.begin() + kBufSize24kHz + kNumPitchBufSquareEnergies); } PitchTestData::~PitchTestData() = default;