diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc index d7ba65f932..262c386453 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc @@ -79,24 +79,6 @@ int PitchPseudoInterpolationLagPitchBuf( return 2 * lag + offset; } -// Refines a pitch period |inverted_lag| encoded as inverted lag with -// pseudo-interpolation. The output sample rate is twice as that of -// |inverted_lag|. -int PitchPseudoInterpolationInvLagAutoCorr( - int inverted_lag, - rtc::ArrayView auto_correlation) { - int offset = 0; - // Cannot apply pseudo-interpolation at the boundaries. - if (inverted_lag > 0 && inverted_lag < kInitialNumLags24kHz - 1) { - offset = GetPitchPseudoInterpolationOffset( - auto_correlation[inverted_lag + 1], auto_correlation[inverted_lag], - auto_correlation[inverted_lag - 1]); - } - // TODO(bugs.webrtc.org/9076): When retraining, check if |offset| below should - // be subtracted since |inverted_lag| is an inverted lag but offset is a lag. - return 2 * inverted_lag + offset; -} - // Integer multipliers used in ComputeExtendedPitchPeriod48kHz() when // looking for sub-harmonics. // The values have been chosen to serve the following algorithm. Given the @@ -129,35 +111,75 @@ struct Range { int max; }; +// Number of pitches analyzed on each side (left and right) of a candidate. +constexpr int kPitchNeighborhoodRadius = 2; + // Creates a pitch period interval centered in `inverted_lag` with hard-coded // radius. Clipping is applied so that the interval is always valid for a 24 kHz // pitch buffer. 
Range CreateInvertedLagRange(int inverted_lag) { - constexpr int kRadius = 2; - return {std::max(inverted_lag - kRadius, 0), - std::min(inverted_lag + kRadius, kInitialNumLags24kHz - 1)}; + return {std::max(inverted_lag - kPitchNeighborhoodRadius, 0), + std::min(inverted_lag + kPitchNeighborhoodRadius, + kInitialNumLags24kHz - 1)}; } +constexpr int kNumPitchCandidates = 2; // Best and second best. +// Maximum number of analyzed pitch periods. +constexpr int kMaxPitchPeriods24kHz = + kNumPitchCandidates * (2 * kPitchNeighborhoodRadius + 1); + +// Fixed-capacity, append-only collection of inverted lags. +class InvertedLagsIndex { + public: + InvertedLagsIndex() : num_entries_(0) {} + // Adds an inverted lag to the index. Cannot add more than + // `kMaxPitchPeriods24kHz` values. + void Append(int inverted_lag) { + RTC_DCHECK_LT(num_entries_, kMaxPitchPeriods24kHz); + inverted_lags_[num_entries_++] = inverted_lag; + } + const int* data() const { return inverted_lags_.data(); } + int size() const { return num_entries_; } + + private: + std::array inverted_lags_; + int num_entries_; +}; + // Computes the auto correlation coefficients for the inverted lags in the -// closed interval `inverted_lags`. +// closed interval `inverted_lags`. Updates `inverted_lags_index` by appending +// the inverted lags for the computed auto correlation values. void ComputeAutoCorrelation( Range inverted_lags, rtc::ArrayView pitch_buffer, - rtc::ArrayView auto_correlation) { + rtc::ArrayView auto_correlation, + InvertedLagsIndex& inverted_lags_index) { // Check valid range. RTC_DCHECK_LE(inverted_lags.min, inverted_lags.max); + // Zero only the entries adjacent to the range rather than the whole + // `auto_correlation` array; the pseudo-interpolation reads them. + if (inverted_lags.min > 0) { + auto_correlation[inverted_lags.min - 1] = 0.f; + } + if (inverted_lags.max < kInitialNumLags24kHz - 1) { + auto_correlation[inverted_lags.max + 1] = 0.f; + } // Check valid `inverted_lag` indexes. 
RTC_DCHECK_GE(inverted_lags.min, 0); - RTC_DCHECK_LT(inverted_lags.max, auto_correlation.size()); + RTC_DCHECK_LT(inverted_lags.max, kInitialNumLags24kHz); for (int inverted_lag = inverted_lags.min; inverted_lag <= inverted_lags.max; ++inverted_lag) { auto_correlation[inverted_lag] = ComputeAutoCorrelation(inverted_lag, pitch_buffer); + inverted_lags_index.Append(inverted_lag); } } -int ComputePitchPeriod24kHz( +// Searches the strongest pitch period at 24 kHz among `inverted_lags` and +// returns its inverted lag at 48 kHz. +int ComputePitchPeriod48kHz( rtc::ArrayView pitch_buffer, + rtc::ArrayView inverted_lags, rtc::ArrayView auto_correlation, rtc::ArrayView y_energy) { static_assert(kMaxPitch24kHz > kInitialNumLags24kHz, ""); @@ -165,8 +187,7 @@ int ComputePitchPeriod24kHz( int best_inverted_lag = 0; // Pitch period. float best_numerator = -1.f; // Pitch strength numerator. float best_denominator = 0.f; // Pitch strength denominator. - for (int inverted_lag = 0; inverted_lag < kInitialNumLags24kHz; - ++inverted_lag) { + for (int inverted_lag : inverted_lags) { // A pitch candidate must have positive correlation. if (auto_correlation[inverted_lag] > 0.f) { // Auto-correlation energy normalized by frame energy. @@ -181,7 +202,19 @@ int ComputePitchPeriod24kHz( } } } - return best_inverted_lag; + // Pseudo-interpolation to transform `best_inverted_lag` (24 kHz pitch) to a + // 48 kHz pitch period. + if (best_inverted_lag == 0 || best_inverted_lag >= kInitialNumLags24kHz - 1) { + // Cannot apply pseudo-interpolation at the boundaries. + return best_inverted_lag * 2; + } + int offset = GetPitchPseudoInterpolationOffset( + auto_correlation[best_inverted_lag + 1], + auto_correlation[best_inverted_lag], + auto_correlation[best_inverted_lag - 1]); + // TODO(bugs.webrtc.org/9076): When retraining, check if `offset` below should + // be subtracted since `best_inverted_lag` is an inverted lag but `offset` is a lag. 
+ return 2 * best_inverted_lag + offset; } // Returns an alternative pitch period for `pitch_period` given a `multiplier` @@ -332,10 +365,10 @@ int ComputePitchPeriod48kHz( rtc::ArrayView pitch_buffer, rtc::ArrayView y_energy, CandidatePitchPeriods pitch_candidates) { - // Compute the auto-correlation terms only for neighbors of the given pitch - // candidates (similar to what is done in ComputePitchAutoCorrelation(), but - // for a few lag values). - std::array auto_correlation{}; + // Compute the auto-correlation terms only for neighbors of the two pitch + // candidates (best and second best). + std::array auto_correlation; + InvertedLagsIndex inverted_lags_index; // Create two inverted lag ranges so that `r1` precedes `r2`. const bool swap_candidates = pitch_candidates.best > pitch_candidates.second_best; @@ -351,18 +384,17 @@ int ComputePitchPeriod48kHz( RTC_DCHECK_LE(r1.max, r2.max); if (r1.max + 1 >= r2.min) { // Overlapping or adjacent ranges. - ComputeAutoCorrelation({r1.min, r2.max}, pitch_buffer, auto_correlation); + ComputeAutoCorrelation({r1.min, r2.max}, pitch_buffer, auto_correlation, + inverted_lags_index); } else { // Disjoint ranges. - ComputeAutoCorrelation(r1, pitch_buffer, auto_correlation); - ComputeAutoCorrelation(r2, pitch_buffer, auto_correlation); + ComputeAutoCorrelation(r1, pitch_buffer, auto_correlation, + inverted_lags_index); + ComputeAutoCorrelation(r2, pitch_buffer, auto_correlation, + inverted_lags_index); } - // Find best pitch at 24 kHz. - const int pitch_candidate_24kHz = - ComputePitchPeriod24kHz(pitch_buffer, auto_correlation, y_energy); - // Pseudo-interpolation. 
- return PitchPseudoInterpolationInvLagAutoCorr(pitch_candidate_24kHz, - auto_correlation); + return ComputePitchPeriod48kHz(pitch_buffer, inverted_lags_index, + auto_correlation, y_energy); } PitchInfo ComputeExtendedPitchPeriod48kHz( diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc index fc715c6aef..152d569823 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc @@ -128,9 +128,9 @@ class ExtendedPitchPeriodSearchParametrizaion TEST_P(ExtendedPitchPeriodSearchParametrizaion, PeriodBitExactnessGainWithinTolerance) { PitchTestData test_data; - std::vector y_energy(kMaxPitch24kHz + 1); - rtc::ArrayView y_energy_view(y_energy.data(), - kMaxPitch24kHz + 1); + std::vector y_energy(kRefineNumLags24kHz); + rtc::ArrayView y_energy_view(y_energy.data(), + kRefineNumLags24kHz); ComputeSlidingFrameSquareEnergies24kHz(test_data.GetPitchBufView(), y_energy_view); // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.