From 842b675304ad9f4854ae3decdf654bd380169e19 Mon Sep 17 00:00:00 2001 From: Alessio Bazzica Date: Thu, 29 Oct 2020 15:49:57 +0100 Subject: [PATCH] RNN VAD: Pitch periods as integers and for-if-break optimization This CL includes two changes: 1. the type for (inverted) lags and pitch periods changed from size_t to int to reduce the chance of bugs with pitch period manipulations 2. CheckLowerPitchPeriodsAndComputePitchGain() is optimized by replacing an unnecessary if statement inside the loop with the predetermined number of loops Bug: webrtc:10480 Change-Id: I38432699254b37a2c0111279c28be8dc65b87e9b Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/139252 Commit-Queue: Alessio Bazzica Reviewed-by: Gustaf Ullberg Reviewed-by: Fredrik Hernqvist Cr-Commit-Position: refs/heads/master@{#32521} --- .../audio_processing/agc2/rnn_vad/BUILD.gn | 1 + .../agc2/rnn_vad/pitch_search.cc | 18 ++- .../agc2/rnn_vad/pitch_search_internal.cc | 140 +++++++++--------- .../agc2/rnn_vad/pitch_search_internal.h | 25 ++-- .../rnn_vad/pitch_search_internal_unittest.cc | 28 ++-- 5 files changed, 113 insertions(+), 99 deletions(-) diff --git a/modules/audio_processing/agc2/rnn_vad/BUILD.gn b/modules/audio_processing/agc2/rnn_vad/BUILD.gn index fcf179c338..7822901fba 100644 --- a/modules/audio_processing/agc2/rnn_vad/BUILD.gn +++ b/modules/audio_processing/agc2/rnn_vad/BUILD.gn @@ -92,6 +92,7 @@ rtc_library("rnn_vad_pitch") { ":rnn_vad_common", "../../../../api:array_view", "../../../../rtc_base:checks", + "../../../../rtc_base:safe_compare", ] } diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search.cc index 1b3b459c5f..df73274cb7 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search.cc +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search.cc @@ -35,20 +35,22 @@ PitchInfo PitchEstimator::Estimate( Decimate2x(pitch_buf, pitch_buf_decimated_view_); auto_corr_calculator_.ComputeOnPitchBuffer(pitch_buf_decimated_view_, auto_corr_view_); - std::array pitch_candidates_inv_lags = FindBestPitchPeriods( - auto_corr_view_, pitch_buf_decimated_view_, kMaxPitch12kHz); + CandidatePitchPeriods pitch_candidates_inverted_lags = + FindBestPitchPeriods(auto_corr_view_, pitch_buf_decimated_view_, + static_cast(kMaxPitch12kHz)); // Refine the pitch period estimation. // The refinement is done using the pitch buffer that contains 24 kHz samples. // Therefore, adapt the inverted lags in |pitch_candidates_inv_lags| from 12 // to 24 kHz. - pitch_candidates_inv_lags[0] *= 2; - pitch_candidates_inv_lags[1] *= 2; - size_t pitch_inv_lag_48kHz = - RefinePitchPeriod48kHz(pitch_buf, pitch_candidates_inv_lags); + pitch_candidates_inverted_lags.best *= 2; + pitch_candidates_inverted_lags.second_best *= 2; + const int pitch_inv_lag_48kHz = + RefinePitchPeriod48kHz(pitch_buf, pitch_candidates_inverted_lags); // Look for stronger harmonics to find the final pitch period and its gain. - RTC_DCHECK_LT(pitch_inv_lag_48kHz, kMaxPitch48kHz); + RTC_DCHECK_LT(pitch_inv_lag_48kHz, static_cast(kMaxPitch48kHz)); last_pitch_48kHz_ = CheckLowerPitchPeriodsAndComputePitchGain( - pitch_buf, kMaxPitch48kHz - pitch_inv_lag_48kHz, last_pitch_48kHz_); + pitch_buf, static_cast(kMaxPitch48kHz) - pitch_inv_lag_48kHz, + last_pitch_48kHz_); return last_pitch_48kHz_; } diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc index f24a76f7bd..922669a4c5 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc @@ -19,38 +19,41 @@ #include "modules/audio_processing/agc2/rnn_vad/common.h" #include "rtc_base/checks.h" +#include "rtc_base/numerics/safe_compare.h" namespace webrtc { namespace rnn_vad { namespace { +constexpr int kMaxPitch24kHzInt = static_cast(kMaxPitch24kHz); + // Converts a lag to an inverted lag (only for 24kHz). -size_t GetInvertedLag(size_t lag) { - RTC_DCHECK_LE(lag, kMaxPitch24kHz); - return kMaxPitch24kHz - lag; +int GetInvertedLag(int lag) { + RTC_DCHECK_LE(lag, kMaxPitch24kHzInt); + return kMaxPitch24kHzInt - lag; } float ComputeAutoCorrelationCoeff(rtc::ArrayView pitch_buf, - size_t inv_lag, - size_t max_pitch_period) { - RTC_DCHECK_LT(inv_lag, pitch_buf.size()); - RTC_DCHECK_LT(max_pitch_period, pitch_buf.size()); - RTC_DCHECK_LE(inv_lag, max_pitch_period); + int inv_lag, + int max_pitch_period) { + RTC_DCHECK_LT(inv_lag, static_cast(pitch_buf.size())); + RTC_DCHECK_LT(max_pitch_period, static_cast(pitch_buf.size())); + RTC_DCHECK_LE(inv_lag, static_cast(max_pitch_period)); // TODO(bugs.webrtc.org/9076): Maybe optimize using vectorization. - return std::inner_product(pitch_buf.begin() + max_pitch_period, - pitch_buf.end(), pitch_buf.begin() + inv_lag, 0.f); + return std::inner_product( + pitch_buf.begin() + static_cast(max_pitch_period), + pitch_buf.end(), pitch_buf.begin() + static_cast(inv_lag), 0.f); } -// Computes a pseudo-interpolation offset for an estimated pitch period |lag| by -// looking at the auto-correlation coefficients in the neighborhood of |lag|. -// (namely, |prev_auto_corr|, |lag_auto_corr| and |next_auto_corr|). The output -// is a lag in {-1, 0, +1}. +// Given the auto-correlation coefficients for a lag and its neighbors, computes +// a pseudo-interpolation offset to be applied to the pitch period associated to +// the central auto-correlation coefficient |lag_auto_corr|. The output is a lag +// in {-1, 0, +1}. // TODO(bugs.webrtc.org/9076): Consider removing pseudo-i since it // is relevant only if the spectral analysis works at a sample rate that is // twice as that of the pitch buffer (not so important instead for the estimated // pitch period feature fed into the RNN). -int GetPitchPseudoInterpolationOffset(size_t lag, - float prev_auto_corr, +int GetPitchPseudoInterpolationOffset(float prev_auto_corr, float lag_auto_corr, float next_auto_corr) { const float& a = prev_auto_corr; @@ -68,20 +71,19 @@ int GetPitchPseudoInterpolationOffset(size_t lag, // Refines a pitch period |lag| encoded as lag with pseudo-interpolation. The // output sample rate is twice as that of |lag|. -size_t PitchPseudoInterpolationLagPitchBuf( - size_t lag, +int PitchPseudoInterpolationLagPitchBuf( + int lag, rtc::ArrayView pitch_buf) { int offset = 0; // Cannot apply pseudo-interpolation at the boundaries. - if (lag > 0 && lag < kMaxPitch24kHz) { + if (lag > 0 && lag < kMaxPitch24kHzInt) { offset = GetPitchPseudoInterpolationOffset( - lag, ComputeAutoCorrelationCoeff(pitch_buf, GetInvertedLag(lag - 1), - kMaxPitch24kHz), + kMaxPitch24kHzInt), ComputeAutoCorrelationCoeff(pitch_buf, GetInvertedLag(lag), - kMaxPitch24kHz), + kMaxPitch24kHzInt), ComputeAutoCorrelationCoeff(pitch_buf, GetInvertedLag(lag + 1), - kMaxPitch24kHz)); + kMaxPitch24kHzInt)); } return 2 * lag + offset; } @@ -89,15 +91,14 @@ size_t PitchPseudoInterpolationLagPitchBuf( // Refines a pitch period |inv_lag| encoded as inverted lag with // pseudo-interpolation. The output sample rate is twice as that of // |inv_lag|. -size_t PitchPseudoInterpolationInvLagAutoCorr( - size_t inv_lag, +int PitchPseudoInterpolationInvLagAutoCorr( + int inv_lag, rtc::ArrayView auto_corr) { int offset = 0; // Cannot apply pseudo-interpolation at the boundaries. - if (inv_lag > 0 && inv_lag < auto_corr.size() - 1) { - offset = GetPitchPseudoInterpolationOffset(inv_lag, auto_corr[inv_lag + 1], - auto_corr[inv_lag], - auto_corr[inv_lag - 1]); + if (inv_lag > 0 && inv_lag < static_cast(auto_corr.size()) - 1) { + offset = GetPitchPseudoInterpolationOffset( + auto_corr[inv_lag + 1], auto_corr[inv_lag], auto_corr[inv_lag - 1]); } // TODO(bugs.webrtc.org/9076): When retraining, check if |offset| below should // be subtracted since |inv_lag| is an inverted lag but offset is a lag. @@ -198,8 +199,8 @@ float ComputePitchGainThreshold(int candidate_pitch_period, void ComputeSlidingFrameSquareEnergies( rtc::ArrayView pitch_buf, rtc::ArrayView yy_values) { - float yy = - ComputeAutoCorrelationCoeff(pitch_buf, kMaxPitch24kHz, kMaxPitch24kHz); + float yy = ComputeAutoCorrelationCoeff(pitch_buf, kMaxPitch24kHzInt, + kMaxPitch24kHzInt); yy_values[0] = yy; for (size_t i = 1; i < yy_values.size(); ++i) { RTC_DCHECK_LE(i, kMaxPitch24kHz + kFrameSize20ms24kHz); @@ -213,14 +214,14 @@ void ComputeSlidingFrameSquareEnergies( } } -std::array FindBestPitchPeriods( +CandidatePitchPeriods FindBestPitchPeriods( rtc::ArrayView auto_corr, rtc::ArrayView pitch_buf, - size_t max_pitch_period) { + int max_pitch_period) { // Stores a pitch candidate period and strength information. struct PitchCandidate { // Pitch period encoded as inverted lag. - size_t period_inverted_lag = 0; + int period_inverted_lag = 0; // Pitch strength encoded as a ratio. float strength_numerator = -1.f; float strength_denominator = 0.f; @@ -232,9 +233,10 @@ std::array FindBestPitchPeriods( } }; - RTC_DCHECK_GT(max_pitch_period, auto_corr.size()); - RTC_DCHECK_LT(max_pitch_period, pitch_buf.size()); - const size_t frame_size = pitch_buf.size() - max_pitch_period; + RTC_DCHECK_GT(max_pitch_period, static_cast(auto_corr.size())); + RTC_DCHECK_LT(max_pitch_period, static_cast(pitch_buf.size())); + const int frame_size = static_cast(pitch_buf.size()) - max_pitch_period; + RTC_DCHECK_GT(frame_size, 0); // TODO(bugs.webrtc.org/9076): Maybe optimize using vectorization. float yy = std::inner_product(pitch_buf.begin(), pitch_buf.begin() + frame_size + 1, @@ -245,7 +247,8 @@ std::array FindBestPitchPeriods( PitchCandidate best; PitchCandidate second_best; second_best.period_inverted_lag = 1; - for (size_t inv_lag = 0; inv_lag < auto_corr.size(); ++inv_lag) { + for (int inv_lag = 0; inv_lag < static_cast(auto_corr.size()); + ++inv_lag) { // A pitch candidate must have positive correlation. if (auto_corr[inv_lag] > 0) { candidate.period_inverted_lag = inv_lag; @@ -267,32 +270,35 @@ std::array FindBestPitchPeriods( yy += new_coeff * new_coeff; yy = std::max(0.f, yy); } - return {{best.period_inverted_lag, second_best.period_inverted_lag}}; + return {best.period_inverted_lag, second_best.period_inverted_lag}; } -size_t RefinePitchPeriod48kHz( +int RefinePitchPeriod48kHz( rtc::ArrayView pitch_buf, - rtc::ArrayView inv_lags) { + CandidatePitchPeriods pitch_candidates_inverted_lags) { // Compute the auto-correlation terms only for neighbors of the given pitch // candidates (similar to what is done in ComputePitchAutoCorrelation(), but // for a few lag values). - std::array auto_corr; - auto_corr.fill(0.f); // Zeros become ignored lags in FindBestPitchPeriods(). - auto is_neighbor = [](size_t i, size_t j) { + std::array auto_correlation; + auto_correlation.fill( + 0.f); // Zeros become ignored lags in FindBestPitchPeriods(). + auto is_neighbor = [](int i, int j) { return ((i > j) ? (i - j) : (j - i)) <= 2; }; - for (size_t inv_lag = 0; inv_lag < auto_corr.size(); ++inv_lag) { - if (is_neighbor(inv_lag, inv_lags[0]) || is_neighbor(inv_lag, inv_lags[1])) - auto_corr[inv_lag] = - ComputeAutoCorrelationCoeff(pitch_buf, inv_lag, kMaxPitch24kHz); + // TODO(https://crbug.com/webrtc/10480): Optimize by removing the loop. + for (int inverted_lag = 0; rtc::SafeLt(inverted_lag, auto_correlation.size()); + ++inverted_lag) { + if (is_neighbor(inverted_lag, pitch_candidates_inverted_lags.best) || + is_neighbor(inverted_lag, pitch_candidates_inverted_lags.second_best)) + auto_correlation[inverted_lag] = ComputeAutoCorrelationCoeff( + pitch_buf, inverted_lag, kMaxPitch24kHzInt); } // Find best pitch at 24 kHz. - const auto pitch_candidates_inv_lags = FindBestPitchPeriods( - {auto_corr.data(), auto_corr.size()}, - {pitch_buf.data(), pitch_buf.size()}, kMaxPitch24kHz); - const auto inv_lag = pitch_candidates_inv_lags[0]; // Refine the best. + const CandidatePitchPeriods pitch_candidates_24kHz = + FindBestPitchPeriods(auto_correlation, pitch_buf, kMaxPitch24kHzInt); // Pseudo-interpolation. - return PitchPseudoInterpolationInvLagAutoCorr(inv_lag, auto_corr); + return PitchPseudoInterpolationInvLagAutoCorr(pitch_candidates_24kHz.best, + auto_correlation); } PitchInfo CheckLowerPitchPeriodsAndComputePitchGain( @@ -327,15 +333,15 @@ PitchInfo CheckLowerPitchPeriodsAndComputePitchGain( }; // Initial pitch candidate gain. RefinedPitchCandidate best_pitch; - best_pitch.period_24kHz = std::min(initial_pitch_period_48kHz / 2, - static_cast(kMaxPitch24kHz - 1)); + best_pitch.period_24kHz = + std::min(initial_pitch_period_48kHz / 2, kMaxPitch24kHzInt - 1); best_pitch.xy = ComputeAutoCorrelationCoeff( - pitch_buf, GetInvertedLag(best_pitch.period_24kHz), kMaxPitch24kHz); + pitch_buf, GetInvertedLag(best_pitch.period_24kHz), kMaxPitch24kHzInt); best_pitch.yy = yy_values[best_pitch.period_24kHz]; best_pitch.gain = pitch_gain(best_pitch.xy, best_pitch.yy, xx); // Store the initial pitch period information. - const size_t initial_pitch_period = best_pitch.period_24kHz; + const int initial_pitch_period = best_pitch.period_24kHz; const float initial_pitch_gain = best_pitch.gain; // Given the initial pitch estimation, check lower periods (i.e., harmonics). @@ -343,12 +349,13 @@ PitchInfo CheckLowerPitchPeriodsAndComputePitchGain( RTC_DCHECK_GT(k, 0); return (2 * n * period + k) / (2 * k); // Same as round(n*period/k). }; - for (int k = 2; k < static_cast(kSubHarmonicMultipliers.size() + 2); - ++k) { + // |max_k| such that alternative_period(initial_pitch_period, max_k, 1) equals + // kMinPitch24kHz. + const int max_k = + (2 * initial_pitch_period) / (2 * static_cast(kMinPitch24kHz) - 1); + for (int k = 2; k <= max_k; ++k) { int candidate_pitch_period = alternative_period(initial_pitch_period, k, 1); - if (static_cast(candidate_pitch_period) < kMinPitch24kHz) { - break; - } + RTC_DCHECK_GE(candidate_pitch_period, static_cast(kMinPitch24kHz)); // When looking at |candidate_pitch_period|, we also look at one of its // sub-harmonics. |kSubHarmonicMultipliers| is used to know where to look. // |k| == 2 is a special case since |candidate_pitch_secondary_period| might @@ -356,8 +363,7 @@ PitchInfo CheckLowerPitchPeriodsAndComputePitchGain( int candidate_pitch_secondary_period = alternative_period( initial_pitch_period, k, kSubHarmonicMultipliers[k - 2]); RTC_DCHECK_GT(candidate_pitch_secondary_period, 0); - if (k == 2 && - candidate_pitch_secondary_period > static_cast(kMaxPitch24kHz)) { + if (k == 2 && candidate_pitch_secondary_period > kMaxPitch24kHzInt) { candidate_pitch_secondary_period = initial_pitch_period; } RTC_DCHECK_NE(candidate_pitch_period, candidate_pitch_secondary_period) @@ -367,10 +373,10 @@ PitchInfo CheckLowerPitchPeriodsAndComputePitchGain( // |candidate_pitch_period| by also looking at its possible sub-harmonic // |candidate_pitch_secondary_period|. float xy_primary_period = ComputeAutoCorrelationCoeff( - pitch_buf, GetInvertedLag(candidate_pitch_period), kMaxPitch24kHz); + pitch_buf, GetInvertedLag(candidate_pitch_period), kMaxPitch24kHzInt); float xy_secondary_period = ComputeAutoCorrelationCoeff( pitch_buf, GetInvertedLag(candidate_pitch_secondary_period), - kMaxPitch24kHz); + kMaxPitch24kHzInt); float xy = 0.5f * (xy_primary_period + xy_secondary_period); float yy = 0.5f * (yy_values[candidate_pitch_period] + yy_values[candidate_pitch_secondary_period]); @@ -393,7 +399,7 @@ PitchInfo CheckLowerPitchPeriodsAndComputePitchGain( : best_pitch.xy / (best_pitch.yy + 1.f); final_pitch_gain = std::min(best_pitch.gain, final_pitch_gain); int final_pitch_period_48kHz = std::max( - kMinPitch48kHz, + static_cast(kMinPitch48kHz), PitchPseudoInterpolationLagPitchBuf(best_pitch.period_24kHz, pitch_buf)); return {final_pitch_period_48kHz, final_pitch_gain}; diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h index 2cc5ce6af8..cab6286523 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h @@ -14,6 +14,7 @@ #include #include +#include #include "api/array_view.h" #include "modules/audio_processing/agc2/rnn_vad/common.h" @@ -49,20 +50,26 @@ void ComputeSlidingFrameSquareEnergies( rtc::ArrayView pitch_buf, rtc::ArrayView yy_values); -// Given the auto-correlation coefficients stored according to -// ComputePitchAutoCorrelation() (i.e., using inverted lags), returns the best -// and the second best pitch periods. -std::array FindBestPitchPeriods( +// Top-2 pitch period candidates. +struct CandidatePitchPeriods { + int best; + int second_best; +}; + +// Computes the candidate pitch periods given the auto-correlation coefficients +// stored according to ComputePitchAutoCorrelation() (i.e., using inverted +// lags). The return periods are inverted lags. +CandidatePitchPeriods FindBestPitchPeriods( rtc::ArrayView auto_corr, rtc::ArrayView pitch_buf, - size_t max_pitch_period); + int max_pitch_period); // Refines the pitch period estimation given the pitch buffer |pitch_buf| and -// the initial pitch period estimation |inv_lags|. Returns an inverted lag at -// 48 kHz. -size_t RefinePitchPeriod48kHz( +// the initial pitch period estimation |pitch_candidates_inverted_lags|. +// Returns an inverted lag at 48 kHz. +int RefinePitchPeriod48kHz( rtc::ArrayView pitch_buf, - rtc::ArrayView inv_lags); + CandidatePitchPeriods pitch_candidates_inverted_lags); // Refines the pitch period estimation and compute the pitch gain. Returns the // refined pitch estimation data at 48 kHz. diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc index 23ff49a2fc..37fb15f72e 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc @@ -104,31 +104,29 @@ TEST(RnnVadTest, FindBestPitchPeriodsBitExactness) { PitchTestData test_data; std::array pitch_buf_decimated; Decimate2x(test_data.GetPitchBufView(), pitch_buf_decimated); - std::array pitch_candidates_inv_lags; + CandidatePitchPeriods pitch_candidates; { // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. // FloatingPointExceptionObserver fpe_observer; auto auto_corr_view = test_data.GetPitchBufAutoCorrCoeffsView(); - pitch_candidates_inv_lags = - FindBestPitchPeriods({auto_corr_view.data(), auto_corr_view.size()}, - pitch_buf_decimated, kMaxPitch12kHz); + pitch_candidates = FindBestPitchPeriods(auto_corr_view, pitch_buf_decimated, + kMaxPitch12kHz); } - EXPECT_EQ(pitch_candidates_inv_lags[0], static_cast(140)); - EXPECT_EQ(pitch_candidates_inv_lags[1], static_cast(142)); + EXPECT_EQ(pitch_candidates.best, 140); + EXPECT_EQ(pitch_candidates.second_best, 142); } // Checks that the refined pitch period is bit-exact given test input data. TEST(RnnVadTest, RefinePitchPeriod48kHzBitExactness) { PitchTestData test_data; - size_t pitch_inv_lag; - { - // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. - // FloatingPointExceptionObserver fpe_observer; - const std::array pitch_candidates_inv_lags = {280, 284}; - pitch_inv_lag = RefinePitchPeriod48kHz(test_data.GetPitchBufView(), - pitch_candidates_inv_lags); - } - EXPECT_EQ(560u, pitch_inv_lag); + // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. + // FloatingPointExceptionObserver fpe_observer; + EXPECT_EQ(RefinePitchPeriod48kHz(test_data.GetPitchBufView(), + /*pitch_candidates=*/{280, 284}), + 560); + EXPECT_EQ(RefinePitchPeriod48kHz(test_data.GetPitchBufView(), + /*pitch_candidates=*/{260, 284}), + 568); } class CheckLowerPitchPeriodsAndComputePitchGainTest