diff --git a/modules/audio_processing/agc2/rnn_vad/BUILD.gn b/modules/audio_processing/agc2/rnn_vad/BUILD.gn index fcf179c338..7822901fba 100644 --- a/modules/audio_processing/agc2/rnn_vad/BUILD.gn +++ b/modules/audio_processing/agc2/rnn_vad/BUILD.gn @@ -92,6 +92,7 @@ rtc_library("rnn_vad_pitch") { ":rnn_vad_common", "../../../../api:array_view", "../../../../rtc_base:checks", + "../../../../rtc_base:safe_compare", ] } diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search.cc index 1b3b459c5f..df73274cb7 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search.cc +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search.cc @@ -35,20 +35,22 @@ PitchInfo PitchEstimator::Estimate( Decimate2x(pitch_buf, pitch_buf_decimated_view_); auto_corr_calculator_.ComputeOnPitchBuffer(pitch_buf_decimated_view_, auto_corr_view_); - std::array pitch_candidates_inv_lags = FindBestPitchPeriods( - auto_corr_view_, pitch_buf_decimated_view_, kMaxPitch12kHz); + CandidatePitchPeriods pitch_candidates_inverted_lags = + FindBestPitchPeriods(auto_corr_view_, pitch_buf_decimated_view_, + static_cast(kMaxPitch12kHz)); // Refine the pitch period estimation. // The refinement is done using the pitch buffer that contains 24 kHz samples. // Therefore, adapt the inverted lags in |pitch_candidates_inv_lags| from 12 // to 24 kHz. - pitch_candidates_inv_lags[0] *= 2; - pitch_candidates_inv_lags[1] *= 2; - size_t pitch_inv_lag_48kHz = - RefinePitchPeriod48kHz(pitch_buf, pitch_candidates_inv_lags); + pitch_candidates_inverted_lags.best *= 2; + pitch_candidates_inverted_lags.second_best *= 2; + const int pitch_inv_lag_48kHz = + RefinePitchPeriod48kHz(pitch_buf, pitch_candidates_inverted_lags); // Look for stronger harmonics to find the final pitch period and its gain. - RTC_DCHECK_LT(pitch_inv_lag_48kHz, kMaxPitch48kHz); + RTC_DCHECK_LT(pitch_inv_lag_48kHz, static_cast(kMaxPitch48kHz)); last_pitch_48kHz_ = CheckLowerPitchPeriodsAndComputePitchGain( - pitch_buf, kMaxPitch48kHz - pitch_inv_lag_48kHz, last_pitch_48kHz_); + pitch_buf, static_cast(kMaxPitch48kHz) - pitch_inv_lag_48kHz, + last_pitch_48kHz_); return last_pitch_48kHz_; } diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc index f24a76f7bd..922669a4c5 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc @@ -19,38 +19,41 @@ #include "modules/audio_processing/agc2/rnn_vad/common.h" #include "rtc_base/checks.h" +#include "rtc_base/numerics/safe_compare.h" namespace webrtc { namespace rnn_vad { namespace { +constexpr int kMaxPitch24kHzInt = static_cast(kMaxPitch24kHz); + // Converts a lag to an inverted lag (only for 24kHz). -size_t GetInvertedLag(size_t lag) { - RTC_DCHECK_LE(lag, kMaxPitch24kHz); - return kMaxPitch24kHz - lag; +int GetInvertedLag(int lag) { + RTC_DCHECK_LE(lag, kMaxPitch24kHzInt); + return kMaxPitch24kHzInt - lag; } float ComputeAutoCorrelationCoeff(rtc::ArrayView pitch_buf, - size_t inv_lag, - size_t max_pitch_period) { - RTC_DCHECK_LT(inv_lag, pitch_buf.size()); - RTC_DCHECK_LT(max_pitch_period, pitch_buf.size()); - RTC_DCHECK_LE(inv_lag, max_pitch_period); + int inv_lag, + int max_pitch_period) { + RTC_DCHECK_LT(inv_lag, static_cast(pitch_buf.size())); + RTC_DCHECK_LT(max_pitch_period, static_cast(pitch_buf.size())); + RTC_DCHECK_LE(inv_lag, static_cast(max_pitch_period)); // TODO(bugs.webrtc.org/9076): Maybe optimize using vectorization. - return std::inner_product(pitch_buf.begin() + max_pitch_period, - pitch_buf.end(), pitch_buf.begin() + inv_lag, 0.f); + return std::inner_product( + pitch_buf.begin() + static_cast(max_pitch_period), + pitch_buf.end(), pitch_buf.begin() + static_cast(inv_lag), 0.f); } -// Computes a pseudo-interpolation offset for an estimated pitch period |lag| by -// looking at the auto-correlation coefficients in the neighborhood of |lag|. -// (namely, |prev_auto_corr|, |lag_auto_corr| and |next_auto_corr|). The output -// is a lag in {-1, 0, +1}. +// Given the auto-correlation coefficients for a lag and its neighbors, computes +// a pseudo-interpolation offset to be applied to the pitch period associated to +// the central auto-correlation coefficient |lag_auto_corr|. The output is a lag +// in {-1, 0, +1}. // TODO(bugs.webrtc.org/9076): Consider removing pseudo-i since it // is relevant only if the spectral analysis works at a sample rate that is // twice as that of the pitch buffer (not so important instead for the estimated // pitch period feature fed into the RNN). -int GetPitchPseudoInterpolationOffset(size_t lag, - float prev_auto_corr, +int GetPitchPseudoInterpolationOffset(float prev_auto_corr, float lag_auto_corr, float next_auto_corr) { const float& a = prev_auto_corr; @@ -68,20 +71,19 @@ int GetPitchPseudoInterpolationOffset(size_t lag, // Refines a pitch period |lag| encoded as lag with pseudo-interpolation. The // output sample rate is twice as that of |lag|. -size_t PitchPseudoInterpolationLagPitchBuf( - size_t lag, +int PitchPseudoInterpolationLagPitchBuf( + int lag, rtc::ArrayView pitch_buf) { int offset = 0; // Cannot apply pseudo-interpolation at the boundaries. - if (lag > 0 && lag < kMaxPitch24kHz) { + if (lag > 0 && lag < kMaxPitch24kHzInt) { offset = GetPitchPseudoInterpolationOffset( - lag, ComputeAutoCorrelationCoeff(pitch_buf, GetInvertedLag(lag - 1), - kMaxPitch24kHz), + kMaxPitch24kHzInt), ComputeAutoCorrelationCoeff(pitch_buf, GetInvertedLag(lag), - kMaxPitch24kHz), + kMaxPitch24kHzInt), ComputeAutoCorrelationCoeff(pitch_buf, GetInvertedLag(lag + 1), - kMaxPitch24kHz)); + kMaxPitch24kHzInt)); } return 2 * lag + offset; } @@ -89,15 +91,14 @@ size_t PitchPseudoInterpolationLagPitchBuf( // Refines a pitch period |inv_lag| encoded as inverted lag with // pseudo-interpolation. The output sample rate is twice as that of // |inv_lag|. -size_t PitchPseudoInterpolationInvLagAutoCorr( - size_t inv_lag, +int PitchPseudoInterpolationInvLagAutoCorr( + int inv_lag, rtc::ArrayView auto_corr) { int offset = 0; // Cannot apply pseudo-interpolation at the boundaries. - if (inv_lag > 0 && inv_lag < auto_corr.size() - 1) { - offset = GetPitchPseudoInterpolationOffset(inv_lag, auto_corr[inv_lag + 1], - auto_corr[inv_lag], - auto_corr[inv_lag - 1]); + if (inv_lag > 0 && inv_lag < static_cast(auto_corr.size()) - 1) { + offset = GetPitchPseudoInterpolationOffset( + auto_corr[inv_lag + 1], auto_corr[inv_lag], auto_corr[inv_lag - 1]); } // TODO(bugs.webrtc.org/9076): When retraining, check if |offset| below should // be subtracted since |inv_lag| is an inverted lag but offset is a lag. @@ -198,8 +199,8 @@ float ComputePitchGainThreshold(int candidate_pitch_period, void ComputeSlidingFrameSquareEnergies( rtc::ArrayView pitch_buf, rtc::ArrayView yy_values) { - float yy = - ComputeAutoCorrelationCoeff(pitch_buf, kMaxPitch24kHz, kMaxPitch24kHz); + float yy = ComputeAutoCorrelationCoeff(pitch_buf, kMaxPitch24kHzInt, + kMaxPitch24kHzInt); yy_values[0] = yy; for (size_t i = 1; i < yy_values.size(); ++i) { RTC_DCHECK_LE(i, kMaxPitch24kHz + kFrameSize20ms24kHz); @@ -213,14 +214,14 @@ void ComputeSlidingFrameSquareEnergies( } } -std::array FindBestPitchPeriods( +CandidatePitchPeriods FindBestPitchPeriods( rtc::ArrayView auto_corr, rtc::ArrayView pitch_buf, - size_t max_pitch_period) { + int max_pitch_period) { // Stores a pitch candidate period and strength information. struct PitchCandidate { // Pitch period encoded as inverted lag. - size_t period_inverted_lag = 0; + int period_inverted_lag = 0; // Pitch strength encoded as a ratio. float strength_numerator = -1.f; float strength_denominator = 0.f; @@ -232,9 +233,10 @@ std::array FindBestPitchPeriods( } }; - RTC_DCHECK_GT(max_pitch_period, auto_corr.size()); - RTC_DCHECK_LT(max_pitch_period, pitch_buf.size()); - const size_t frame_size = pitch_buf.size() - max_pitch_period; + RTC_DCHECK_GT(max_pitch_period, static_cast(auto_corr.size())); + RTC_DCHECK_LT(max_pitch_period, static_cast(pitch_buf.size())); + const int frame_size = static_cast(pitch_buf.size()) - max_pitch_period; + RTC_DCHECK_GT(frame_size, 0); // TODO(bugs.webrtc.org/9076): Maybe optimize using vectorization. float yy = std::inner_product(pitch_buf.begin(), pitch_buf.begin() + frame_size + 1, @@ -245,7 +247,8 @@ std::array FindBestPitchPeriods( PitchCandidate best; PitchCandidate second_best; second_best.period_inverted_lag = 1; - for (size_t inv_lag = 0; inv_lag < auto_corr.size(); ++inv_lag) { + for (int inv_lag = 0; inv_lag < static_cast(auto_corr.size()); + ++inv_lag) { // A pitch candidate must have positive correlation. if (auto_corr[inv_lag] > 0) { candidate.period_inverted_lag = inv_lag; @@ -267,32 +270,35 @@ std::array FindBestPitchPeriods( yy += new_coeff * new_coeff; yy = std::max(0.f, yy); } - return {{best.period_inverted_lag, second_best.period_inverted_lag}}; + return {best.period_inverted_lag, second_best.period_inverted_lag}; } -size_t RefinePitchPeriod48kHz( +int RefinePitchPeriod48kHz( rtc::ArrayView pitch_buf, - rtc::ArrayView inv_lags) { + CandidatePitchPeriods pitch_candidates_inverted_lags) { // Compute the auto-correlation terms only for neighbors of the given pitch // candidates (similar to what is done in ComputePitchAutoCorrelation(), but // for a few lag values). - std::array auto_corr; - auto_corr.fill(0.f); // Zeros become ignored lags in FindBestPitchPeriods(). - auto is_neighbor = [](size_t i, size_t j) { + std::array auto_correlation; + auto_correlation.fill( + 0.f); // Zeros become ignored lags in FindBestPitchPeriods(). + auto is_neighbor = [](int i, int j) { return ((i > j) ? (i - j) : (j - i)) <= 2; }; - for (size_t inv_lag = 0; inv_lag < auto_corr.size(); ++inv_lag) { - if (is_neighbor(inv_lag, inv_lags[0]) || is_neighbor(inv_lag, inv_lags[1])) - auto_corr[inv_lag] = - ComputeAutoCorrelationCoeff(pitch_buf, inv_lag, kMaxPitch24kHz); + // TODO(https://crbug.com/webrtc/10480): Optimize by removing the loop. + for (int inverted_lag = 0; rtc::SafeLt(inverted_lag, auto_correlation.size()); + ++inverted_lag) { + if (is_neighbor(inverted_lag, pitch_candidates_inverted_lags.best) || + is_neighbor(inverted_lag, pitch_candidates_inverted_lags.second_best)) + auto_correlation[inverted_lag] = ComputeAutoCorrelationCoeff( + pitch_buf, inverted_lag, kMaxPitch24kHzInt); } // Find best pitch at 24 kHz. - const auto pitch_candidates_inv_lags = FindBestPitchPeriods( - {auto_corr.data(), auto_corr.size()}, - {pitch_buf.data(), pitch_buf.size()}, kMaxPitch24kHz); - const auto inv_lag = pitch_candidates_inv_lags[0]; // Refine the best. + const CandidatePitchPeriods pitch_candidates_24kHz = + FindBestPitchPeriods(auto_correlation, pitch_buf, kMaxPitch24kHzInt); // Pseudo-interpolation. - return PitchPseudoInterpolationInvLagAutoCorr(inv_lag, auto_corr); + return PitchPseudoInterpolationInvLagAutoCorr(pitch_candidates_24kHz.best, + auto_correlation); } PitchInfo CheckLowerPitchPeriodsAndComputePitchGain( @@ -327,15 +333,15 @@ PitchInfo CheckLowerPitchPeriodsAndComputePitchGain( }; // Initial pitch candidate gain. RefinedPitchCandidate best_pitch; - best_pitch.period_24kHz = std::min(initial_pitch_period_48kHz / 2, - static_cast(kMaxPitch24kHz - 1)); + best_pitch.period_24kHz = + std::min(initial_pitch_period_48kHz / 2, kMaxPitch24kHzInt - 1); best_pitch.xy = ComputeAutoCorrelationCoeff( - pitch_buf, GetInvertedLag(best_pitch.period_24kHz), kMaxPitch24kHz); + pitch_buf, GetInvertedLag(best_pitch.period_24kHz), kMaxPitch24kHzInt); best_pitch.yy = yy_values[best_pitch.period_24kHz]; best_pitch.gain = pitch_gain(best_pitch.xy, best_pitch.yy, xx); // Store the initial pitch period information. - const size_t initial_pitch_period = best_pitch.period_24kHz; + const int initial_pitch_period = best_pitch.period_24kHz; const float initial_pitch_gain = best_pitch.gain; // Given the initial pitch estimation, check lower periods (i.e., harmonics). @@ -343,12 +349,13 @@ PitchInfo CheckLowerPitchPeriodsAndComputePitchGain( RTC_DCHECK_GT(k, 0); return (2 * n * period + k) / (2 * k); // Same as round(n*period/k). }; - for (int k = 2; k < static_cast(kSubHarmonicMultipliers.size() + 2); - ++k) { + // |max_k| such that alternative_period(initial_pitch_period, max_k, 1) equals + // kMinPitch24kHz. + const int max_k = + (2 * initial_pitch_period) / (2 * static_cast(kMinPitch24kHz) - 1); + for (int k = 2; k <= max_k; ++k) { int candidate_pitch_period = alternative_period(initial_pitch_period, k, 1); - if (static_cast(candidate_pitch_period) < kMinPitch24kHz) { - break; - } + RTC_DCHECK_GE(candidate_pitch_period, static_cast(kMinPitch24kHz)); // When looking at |candidate_pitch_period|, we also look at one of its // sub-harmonics. |kSubHarmonicMultipliers| is used to know where to look. // |k| == 2 is a special case since |candidate_pitch_secondary_period| might @@ -356,8 +363,7 @@ PitchInfo CheckLowerPitchPeriodsAndComputePitchGain( int candidate_pitch_secondary_period = alternative_period( initial_pitch_period, k, kSubHarmonicMultipliers[k - 2]); RTC_DCHECK_GT(candidate_pitch_secondary_period, 0); - if (k == 2 && - candidate_pitch_secondary_period > static_cast(kMaxPitch24kHz)) { + if (k == 2 && candidate_pitch_secondary_period > kMaxPitch24kHzInt) { candidate_pitch_secondary_period = initial_pitch_period; } RTC_DCHECK_NE(candidate_pitch_period, candidate_pitch_secondary_period) @@ -367,10 +373,10 @@ PitchInfo CheckLowerPitchPeriodsAndComputePitchGain( // |candidate_pitch_period| by also looking at its possible sub-harmonic // |candidate_pitch_secondary_period|. float xy_primary_period = ComputeAutoCorrelationCoeff( - pitch_buf, GetInvertedLag(candidate_pitch_period), kMaxPitch24kHz); + pitch_buf, GetInvertedLag(candidate_pitch_period), kMaxPitch24kHzInt); float xy_secondary_period = ComputeAutoCorrelationCoeff( pitch_buf, GetInvertedLag(candidate_pitch_secondary_period), - kMaxPitch24kHz); + kMaxPitch24kHzInt); float xy = 0.5f * (xy_primary_period + xy_secondary_period); float yy = 0.5f * (yy_values[candidate_pitch_period] + yy_values[candidate_pitch_secondary_period]); @@ -393,7 +399,7 @@ PitchInfo CheckLowerPitchPeriodsAndComputePitchGain( : best_pitch.xy / (best_pitch.yy + 1.f); final_pitch_gain = std::min(best_pitch.gain, final_pitch_gain); int final_pitch_period_48kHz = std::max( - kMinPitch48kHz, + static_cast(kMinPitch48kHz), PitchPseudoInterpolationLagPitchBuf(best_pitch.period_24kHz, pitch_buf)); return {final_pitch_period_48kHz, final_pitch_gain}; diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h index 2cc5ce6af8..cab6286523 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h @@ -14,6 +14,7 @@ #include #include +#include #include "api/array_view.h" #include "modules/audio_processing/agc2/rnn_vad/common.h" @@ -49,20 +50,26 @@ void ComputeSlidingFrameSquareEnergies( rtc::ArrayView pitch_buf, rtc::ArrayView yy_values); -// Given the auto-correlation coefficients stored according to -// ComputePitchAutoCorrelation() (i.e., using inverted lags), returns the best -// and the second best pitch periods. -std::array FindBestPitchPeriods( +// Top-2 pitch period candidates. +struct CandidatePitchPeriods { + int best; + int second_best; +}; + +// Computes the candidate pitch periods given the auto-correlation coefficients +// stored according to ComputePitchAutoCorrelation() (i.e., using inverted +// lags). The return periods are inverted lags. +CandidatePitchPeriods FindBestPitchPeriods( rtc::ArrayView auto_corr, rtc::ArrayView pitch_buf, - size_t max_pitch_period); + int max_pitch_period); // Refines the pitch period estimation given the pitch buffer |pitch_buf| and -// the initial pitch period estimation |inv_lags|. Returns an inverted lag at -// 48 kHz. -size_t RefinePitchPeriod48kHz( +// the initial pitch period estimation |pitch_candidates_inverted_lags|. +// Returns an inverted lag at 48 kHz. +int RefinePitchPeriod48kHz( rtc::ArrayView pitch_buf, - rtc::ArrayView inv_lags); + CandidatePitchPeriods pitch_candidates_inverted_lags); // Refines the pitch period estimation and compute the pitch gain. Returns the // refined pitch estimation data at 48 kHz. diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc index 23ff49a2fc..37fb15f72e 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc @@ -104,31 +104,29 @@ TEST(RnnVadTest, FindBestPitchPeriodsBitExactness) { PitchTestData test_data; std::array pitch_buf_decimated; Decimate2x(test_data.GetPitchBufView(), pitch_buf_decimated); - std::array pitch_candidates_inv_lags; + CandidatePitchPeriods pitch_candidates; { // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. // FloatingPointExceptionObserver fpe_observer; auto auto_corr_view = test_data.GetPitchBufAutoCorrCoeffsView(); - pitch_candidates_inv_lags = - FindBestPitchPeriods({auto_corr_view.data(), auto_corr_view.size()}, - pitch_buf_decimated, kMaxPitch12kHz); + pitch_candidates = FindBestPitchPeriods(auto_corr_view, pitch_buf_decimated, + kMaxPitch12kHz); } - EXPECT_EQ(pitch_candidates_inv_lags[0], static_cast(140)); - EXPECT_EQ(pitch_candidates_inv_lags[1], static_cast(142)); + EXPECT_EQ(pitch_candidates.best, 140); + EXPECT_EQ(pitch_candidates.second_best, 142); } // Checks that the refined pitch period is bit-exact given test input data. TEST(RnnVadTest, RefinePitchPeriod48kHzBitExactness) { PitchTestData test_data; - size_t pitch_inv_lag; - { - // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. - // FloatingPointExceptionObserver fpe_observer; - const std::array pitch_candidates_inv_lags = {280, 284}; - pitch_inv_lag = RefinePitchPeriod48kHz(test_data.GetPitchBufView(), - pitch_candidates_inv_lags); - } - EXPECT_EQ(560u, pitch_inv_lag); + // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. + // FloatingPointExceptionObserver fpe_observer; + EXPECT_EQ(RefinePitchPeriod48kHz(test_data.GetPitchBufView(), + /*pitch_candidates=*/{280, 284}), + 560); + EXPECT_EQ(RefinePitchPeriod48kHz(test_data.GetPitchBufView(), + /*pitch_candidates=*/{260, 284}), + 568); } class CheckLowerPitchPeriodsAndComputePitchGainTest