diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search.cc index c6c3e1b2b5..9d4c5a2d81 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search.cc +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search.cc @@ -19,46 +19,37 @@ namespace webrtc { namespace rnn_vad { PitchEstimator::PitchEstimator() - : y_energy_24kHz_(kRefineNumLags24kHz, 0.f), - pitch_buffer_12kHz_(kBufSize12kHz), - auto_correlation_12kHz_(kNumLags12kHz) {} + : pitch_buf_decimated_(kBufSize12kHz), + pitch_buf_decimated_view_(pitch_buf_decimated_.data(), kBufSize12kHz), + auto_corr_(kNumLags12kHz), + auto_corr_view_(auto_corr_.data(), kNumLags12kHz) { + RTC_DCHECK_EQ(kBufSize12kHz, pitch_buf_decimated_.size()); + RTC_DCHECK_EQ(kNumLags12kHz, auto_corr_view_.size()); +} PitchEstimator::~PitchEstimator() = default; int PitchEstimator::Estimate( rtc::ArrayView pitch_buffer) { - rtc::ArrayView pitch_buffer_12kHz_view( - pitch_buffer_12kHz_.data(), kBufSize12kHz); - RTC_DCHECK_EQ(pitch_buffer_12kHz_.size(), pitch_buffer_12kHz_view.size()); - rtc::ArrayView auto_correlation_12kHz_view( - auto_correlation_12kHz_.data(), kNumLags12kHz); - RTC_DCHECK_EQ(auto_correlation_12kHz_.size(), - auto_correlation_12kHz_view.size()); - // Perform the initial pitch search at 12 kHz. - Decimate2x(pitch_buffer, pitch_buffer_12kHz_view); - auto_corr_calculator_.ComputeOnPitchBuffer(pitch_buffer_12kHz_view, - auto_correlation_12kHz_view); - CandidatePitchPeriods pitch_periods = ComputePitchPeriod12kHz( - pitch_buffer_12kHz_view, auto_correlation_12kHz_view); + Decimate2x(pitch_buffer, pitch_buf_decimated_view_); + auto_corr_calculator_.ComputeOnPitchBuffer(pitch_buf_decimated_view_, + auto_corr_view_); + CandidatePitchPeriods pitch_candidates_inverted_lags = + ComputePitchPeriod12kHz(pitch_buf_decimated_view_, auto_corr_view_); + // Refine the pitch period estimation. // The refinement is done using the pitch buffer that contains 24 kHz samples. // Therefore, adapt the inverted lags in |pitch_candidates_inv_lags| from 12 // to 24 kHz. - pitch_periods.best *= 2; - pitch_periods.second_best *= 2; - - // Refine the initial pitch period estimation from 12 kHz to 48 kHz. - // Pre-compute frame energies at 24 kHz. - rtc::ArrayView y_energy_24kHz_view( - y_energy_24kHz_.data(), kRefineNumLags24kHz); - RTC_DCHECK_EQ(y_energy_24kHz_.size(), y_energy_24kHz_view.size()); - ComputeSlidingFrameSquareEnergies24kHz(pitch_buffer, y_energy_24kHz_view); - // Estimation at 48 kHz. - const int pitch_lag_48kHz = - ComputePitchPeriod48kHz(pitch_buffer, y_energy_24kHz_view, pitch_periods); + pitch_candidates_inverted_lags.best *= 2; + pitch_candidates_inverted_lags.second_best *= 2; + const int pitch_inv_lag_48kHz = + ComputePitchPeriod48kHz(pitch_buffer, pitch_candidates_inverted_lags); + // Look for stronger harmonics to find the final pitch period and its gain. + RTC_DCHECK_LT(pitch_inv_lag_48kHz, kMaxPitch48kHz); last_pitch_48kHz_ = ComputeExtendedPitchPeriod48kHz( - pitch_buffer, y_energy_24kHz_view, - /*initial_pitch_period_48kHz=*/kMaxPitch48kHz - pitch_lag_48kHz, + pitch_buffer, + /*initial_pitch_period_48kHz=*/kMaxPitch48kHz - pitch_inv_lag_48kHz, last_pitch_48kHz_); return last_pitch_48kHz_.period; } diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search.h b/modules/audio_processing/agc2/rnn_vad/pitch_search.h index e96a2dcaf1..1e6b9ad706 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search.h +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search.h @@ -41,9 +41,10 @@ class PitchEstimator { PitchInfo last_pitch_48kHz_{}; AutoCorrelationCalculator auto_corr_calculator_; - std::vector y_energy_24kHz_; - std::vector pitch_buffer_12kHz_; - std::vector auto_correlation_12kHz_; + std::vector pitch_buf_decimated_; + rtc::ArrayView pitch_buf_decimated_view_; + std::vector auto_corr_; + rtc::ArrayView auto_corr_view_; }; } // namespace rnn_vad diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc index d62cddf067..8179dbd965 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc @@ -153,12 +153,17 @@ void ComputeAutoCorrelation( } } -int ComputePitchPeriod24kHz( - rtc::ArrayView pitch_buffer, +int FindBestPitchPeriods24kHz( rtc::ArrayView auto_correlation, - rtc::ArrayView y_energy) { + rtc::ArrayView pitch_buffer) { static_assert(kMaxPitch24kHz > kInitialNumLags24kHz, ""); static_assert(kMaxPitch24kHz < kBufSize24kHz, ""); + // Initialize the sliding 20 ms frame energy. + // TODO(bugs.webrtc.org/9076): Maybe optimize using vectorization. + float denominator = std::inner_product( + pitch_buffer.begin(), pitch_buffer.begin() + kFrameSize20ms24kHz + 1, + pitch_buffer.begin(), 1.f); + // Search best pitch by looking at the scaled auto-correlation. int best_inverted_lag = 0; // Pitch period. float best_numerator = -1.f; // Pitch strength numerator. float best_denominator = 0.f; // Pitch strength denominator. @@ -166,10 +171,8 @@ int ComputePitchPeriod24kHz( ++inverted_lag) { // A pitch candidate must have positive correlation. if (auto_correlation[inverted_lag] > 0.f) { - // Auto-correlation energy normalized by frame energy. const float numerator = auto_correlation[inverted_lag] * auto_correlation[inverted_lag]; - const float denominator = y_energy[kMaxPitch24kHz - inverted_lag]; // Compare numerator/denominator ratios without using divisions. if (numerator * best_denominator > best_numerator * denominator) { best_inverted_lag = inverted_lag; @@ -177,6 +180,14 @@ int ComputePitchPeriod24kHz( best_denominator = denominator; } } + // Update |denominator| for the next inverted lag. + static_assert(kInitialNumLags24kHz + kFrameSize20ms24kHz < kBufSize24kHz, + ""); + const float y_old = pitch_buffer[inverted_lag]; + const float y_new = pitch_buffer[inverted_lag + kFrameSize20ms24kHz]; + denominator -= y_old * y_old; + denominator += y_new * y_new; + denominator = std::max(0.f, denominator); } return best_inverted_lag; } @@ -327,7 +338,6 @@ CandidatePitchPeriods ComputePitchPeriod12kHz( int ComputePitchPeriod48kHz( rtc::ArrayView pitch_buffer, - rtc::ArrayView y_energy, CandidatePitchPeriods pitch_candidates) { // Compute the auto-correlation terms only for neighbors of the given pitch // candidates (similar to what is done in ComputePitchAutoCorrelation(), but @@ -352,7 +362,7 @@ int ComputePitchPeriod48kHz( } // Find best pitch at 24 kHz. const int pitch_candidate_24kHz = - ComputePitchPeriod24kHz(pitch_buffer, auto_correlation, y_energy); + FindBestPitchPeriods24kHz(auto_correlation, pitch_buffer); // Pseudo-interpolation. return PitchPseudoInterpolationInvLagAutoCorr(pitch_candidate_24kHz, auto_correlation); @@ -360,7 +370,6 @@ int ComputePitchPeriod48kHz( PitchInfo ComputeExtendedPitchPeriod48kHz( rtc::ArrayView pitch_buffer, - rtc::ArrayView y_energy, int initial_pitch_period_48kHz, PitchInfo last_pitch_48kHz) { RTC_DCHECK_LE(kMinPitch48kHz, initial_pitch_period_48kHz); @@ -370,30 +379,34 @@ PitchInfo ComputeExtendedPitchPeriod48kHz( struct RefinedPitchCandidate { int period; float strength; - // Additional strength data used for the final pitch estimation. - float xy; // Auto-correlation. - float y_energy; // Energy of the sliding frame `y`. + // Additional strength data used for the final estimation of the strength. + float xy; // Cross-correlation. + float yy; // Auto-correlation. }; - const float x_energy = y_energy[0]; - const auto pitch_strength = [x_energy](float xy, float y_energy) { - RTC_DCHECK_GE(x_energy * y_energy, 0.f); - return xy / std::sqrt(1.f + x_energy * y_energy); + // Initialize. + std::array yy_values; + // TODO(bugs.webrtc.org/9076): Reuse values from FindBestPitchPeriods24kHz(). + ComputeSlidingFrameSquareEnergies24kHz(pitch_buffer, yy_values); + const float xx = yy_values[0]; + const auto pitch_strength = [](float xy, float yy, float xx) { + RTC_DCHECK_GE(xx * yy, 0.f); + return xy / std::sqrt(1.f + xx * yy); }; - - // Initialize the best pitch candidate with `initial_pitch_period_48kHz`. + // Initial pitch candidate. RefinedPitchCandidate best_pitch; best_pitch.period = std::min(initial_pitch_period_48kHz / 2, kMaxPitch24kHz - 1); best_pitch.xy = ComputeAutoCorrelation(kMaxPitch24kHz - best_pitch.period, pitch_buffer); - best_pitch.y_energy = y_energy[best_pitch.period]; - best_pitch.strength = pitch_strength(best_pitch.xy, best_pitch.y_energy); - // Keep a copy of the initial pitch candidate. - const PitchInfo initial_pitch{best_pitch.period, best_pitch.strength}; - // 24 kHz version of the last estimated pitch. + best_pitch.yy = yy_values[best_pitch.period]; + best_pitch.strength = pitch_strength(best_pitch.xy, best_pitch.yy, xx); + + // 24 kHz version of the last estimated pitch and copy of the initial + // estimation. const PitchInfo last_pitch{last_pitch_48kHz.period / 2, last_pitch_48kHz.strength}; + const PitchInfo initial_pitch{best_pitch.period, best_pitch.strength}; // Find `max_period_divisor` such that the result of // `GetAlternativePitchPeriod(initial_pitch_period, 1, max_period_divisor)` @@ -423,14 +436,14 @@ PitchInfo ComputeExtendedPitchPeriod48kHz( // Compute an auto-correlation score for the primary pitch candidate // |alternative_pitch.period| by also looking at its possible sub-harmonic // |dual_alternative_period|. - const float xy_primary_period = ComputeAutoCorrelation( + float xy_primary_period = ComputeAutoCorrelation( kMaxPitch24kHz - alternative_pitch.period, pitch_buffer); - const float xy_secondary_period = ComputeAutoCorrelation( + float xy_secondary_period = ComputeAutoCorrelation( kMaxPitch24kHz - dual_alternative_period, pitch_buffer); - const float xy = 0.5f * (xy_primary_period + xy_secondary_period); - const float yy = 0.5f * (y_energy[alternative_pitch.period] + - y_energy[dual_alternative_period]); - alternative_pitch.strength = pitch_strength(xy, yy); + float xy = 0.5f * (xy_primary_period + xy_secondary_period); + float yy = 0.5f * (yy_values[alternative_pitch.period] + + yy_values[dual_alternative_period]); + alternative_pitch.strength = pitch_strength(xy, yy, xx); // Maybe update best period. if (IsAlternativePitchStrongerThanInitial( @@ -442,11 +455,10 @@ PitchInfo ComputeExtendedPitchPeriod48kHz( // Final pitch strength and period. best_pitch.xy = std::max(0.f, best_pitch.xy); - RTC_DCHECK_LE(0.f, best_pitch.y_energy); - float final_pitch_strength = - (best_pitch.y_energy <= best_pitch.xy) - ? 1.f - : best_pitch.xy / (best_pitch.y_energy + 1.f); + RTC_DCHECK_LE(0.f, best_pitch.yy); + float final_pitch_strength = (best_pitch.yy <= best_pitch.xy) + ? 1.f + : best_pitch.xy / (best_pitch.yy + 1.f); final_pitch_strength = std::min(best_pitch.strength, final_pitch_strength); int final_pitch_period_48kHz = std::max( kMinPitch48kHz, diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h index 693ab9e5d1..b16a2f438d 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h @@ -80,12 +80,10 @@ CandidatePitchPeriods ComputePitchPeriod12kHz( rtc::ArrayView pitch_buffer, rtc::ArrayView auto_correlation); -// Computes the pitch period at 48 kHz given a view on the 24 kHz pitch buffer, -// the energies for the sliding frames `y` at 24 kHz and the pitch period -// candidates at 24 kHz (encoded as inverted lag). +// Computes the pitch period at 48 kHz given a view on the 24 kHz pitch buffer +// and the pitch period candidates at 24 kHz (encoded as inverted lag). int ComputePitchPeriod48kHz( rtc::ArrayView pitch_buffer, - rtc::ArrayView y_energy, CandidatePitchPeriods pitch_candidates_24kHz); struct PitchInfo { @@ -94,12 +92,10 @@ struct PitchInfo { }; // Computes the pitch period at 48 kHz searching in an extended pitch range -// given a view on the 24 kHz pitch buffer, the energies for the sliding frames -// `y` at 24 kHz, the initial 48 kHz estimation (computed by -// `ComputePitchPeriod48kHz()`) and the last estimated pitch. +// given a view on the 24 kHz pitch buffer, the initial 48 kHz estimation +// (computed by `ComputePitchPeriod48kHz()`) and the last estimated pitch. PitchInfo ComputeExtendedPitchPeriod48kHz( rtc::ArrayView pitch_buffer, - rtc::ArrayView y_energy, int initial_pitch_period_48kHz, PitchInfo last_pitch_48kHz); diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc index e5826d02af..7acb046db1 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc @@ -63,17 +63,12 @@ TEST(RnnVadTest, ComputePitchPeriod12kHzBitExactness) { // Checks that the refined pitch period is bit-exact given test input data. TEST(RnnVadTest, ComputePitchPeriod48kHzBitExactness) { PitchTestData test_data; - std::vector y_energy(kMaxPitch24kHz + 1); - rtc::ArrayView y_energy_view(y_energy.data(), - kMaxPitch24kHz + 1); - ComputeSlidingFrameSquareEnergies24kHz(test_data.GetPitchBufView(), - y_energy_view); // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. // FloatingPointExceptionObserver fpe_observer; - EXPECT_EQ(ComputePitchPeriod48kHz(test_data.GetPitchBufView(), y_energy_view, + EXPECT_EQ(ComputePitchPeriod48kHz(test_data.GetPitchBufView(), /*pitch_candidates=*/{280, 284}), 560); - EXPECT_EQ(ComputePitchPeriod48kHz(test_data.GetPitchBufView(), y_energy_view, + EXPECT_EQ(ComputePitchPeriod48kHz(test_data.GetPitchBufView(), /*pitch_candidates=*/{260, 284}), 568); } @@ -95,15 +90,10 @@ class ComputeExtendedPitchPeriod48kHzTest TEST_P(ComputeExtendedPitchPeriod48kHzTest, PeriodBitExactnessGainWithinTolerance) { PitchTestData test_data; - std::vector y_energy(kMaxPitch24kHz + 1); - rtc::ArrayView y_energy_view(y_energy.data(), - kMaxPitch24kHz + 1); - ComputeSlidingFrameSquareEnergies24kHz(test_data.GetPitchBufView(), - y_energy_view); // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. // FloatingPointExceptionObserver fpe_observer; const auto computed_output = ComputeExtendedPitchPeriod48kHz( - test_data.GetPitchBufView(), y_energy_view, GetInitialPitchPeriod(), + test_data.GetPitchBufView(), GetInitialPitchPeriod(), {GetLastPitchPeriod(), GetLastPitchStrength()}); EXPECT_EQ(GetExpectedPitchPeriod(), computed_output.period); EXPECT_NEAR(GetExpectedPitchStrength(), computed_output.strength, 1e-6f);