RNN VAD: pitch search optimizations (part 4)

Add inverted lags index to simplify the loop in
`FindBestPitchPeriod48kHz()`. Instead of looping over 294 items,
only loop over the relevant ones (up to 10) by keeping track of
the relevant indexes.

The benchmark has shown a slight improvement (about +6x on the "x" speed factor reported in the results below).

Benchmarked as follows:
```
out/release/modules_unittests \
  --gtest_filter=*RnnVadTest.DISABLED_RnnVadPerformance* \
  --gtest_also_run_disabled_tests --logs
```

Results:

      | baseline             | this CL
------+----------------------+------------------------
run 1 | 22.8319 +/- 1.46554  | 22.1951 +/- 0.747611
      | 389.367x             | 400.539x
------+----------------------+------------------------
run 2 | 22.4286 +/- 0.726449 | 22.2718 +/- 0.963738
      | 396.369x             | 399.16x
------+----------------------+------------------------
run 3 | 22.5688 +/- 0.831341 | 22.4166 +/- 0.953362
      | 393.906x             | 396.581x

This CL also moved `PitchPseudoInterpolationInvLagAutoCorr()`
into `FindBestPitchPeriod48kHz()`.

Bug: webrtc:10480
Change-Id: Id4e6d755045c3198a80fa94a0a7463577d909b7e
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/191764
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Reviewed-by: Karl Wiberg <kwiberg@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#32590}
This commit is contained in:
Alessio Bazzica 2020-11-11 14:59:40 +01:00 committed by Commit Bot
parent ccbc216ac5
commit 05f5d636e5
2 changed files with 76 additions and 44 deletions

View File

@ -79,24 +79,6 @@ int PitchPseudoInterpolationLagPitchBuf(
return 2 * lag + offset;
}
// Refines a pitch period |inverted_lag| encoded as inverted lag with
// pseudo-interpolation. The output sample rate is twice as that of
// |inverted_lag|.
int PitchPseudoInterpolationInvLagAutoCorr(
int inverted_lag,
rtc::ArrayView<const float, kInitialNumLags24kHz> auto_correlation) {
int offset = 0;
// Cannot apply pseudo-interpolation at the boundaries.
if (inverted_lag > 0 && inverted_lag < kInitialNumLags24kHz - 1) {
offset = GetPitchPseudoInterpolationOffset(
auto_correlation[inverted_lag + 1], auto_correlation[inverted_lag],
auto_correlation[inverted_lag - 1]);
}
// TODO(bugs.webrtc.org/9076): When retraining, check if |offset| below should
// be subtracted since |inverted_lag| is an inverted lag but offset is a lag.
return 2 * inverted_lag + offset;
}
// Integer multipliers used in ComputeExtendedPitchPeriod48kHz() when
// looking for sub-harmonics.
// The values have been chosen to serve the following algorithm. Given the
@ -129,35 +111,75 @@ struct Range {
int max;
};
// Number of analyzed pitches to the left(right) of a pitch candidate.
constexpr int kPitchNeighborhoodRadius = 2;
// Creates a pitch period interval centered in `inverted_lag` with hard-coded
// radius. Clipping is applied so that the interval is always valid for a 24 kHz
// pitch buffer.
Range CreateInvertedLagRange(int inverted_lag) {
constexpr int kRadius = 2;
return {std::max(inverted_lag - kRadius, 0),
std::min(inverted_lag + kRadius, kInitialNumLags24kHz - 1)};
return {std::max(inverted_lag - kPitchNeighborhoodRadius, 0),
std::min(inverted_lag + kPitchNeighborhoodRadius,
kInitialNumLags24kHz - 1)};
}
constexpr int kNumPitchCandidates = 2; // Best and second best.
// Maximum number of analyzed pitch periods.
constexpr int kMaxPitchPeriods24kHz =
kNumPitchCandidates * (2 * kPitchNeighborhoodRadius + 1);
// Collection of inverted lags.
// Fixed-capacity, append-only list holding at most `kMaxPitchPeriods24kHz`
// inverted lags in insertion order. It exposes `data()` and `size()` so that
// callers can iterate only over the appended values.
// NOTE(review): presumably `data()`/`size()` are what allow an instance to be
// passed where an `rtc::ArrayView<const int>` is expected - confirm.
class InvertedLagsIndex {
public:
// Creates an empty index. The backing array is not value-initialized; only
// the entries written via `Append()` are meaningful.
InvertedLagsIndex() : num_entries_(0) {}
// Adds an inverted lag to the index. Cannot add more than
// `kMaxPitchPeriods24kHz` values.
void Append(int inverted_lag) {
// The capacity is only guarded by this check; the write below is unchecked.
RTC_DCHECK_LT(num_entries_, kMaxPitchPeriods24kHz);
// Store in the next free slot and advance the entry count.
inverted_lags_[num_entries_++] = inverted_lag;
}
// Returns a pointer to the first element of the backing array. Only the first
// `size()` elements have been written.
const int* data() const { return inverted_lags_.data(); }
// Returns the number of inverted lags appended so far.
int size() const { return num_entries_; }
private:
// Backing storage; entries at positions >= `num_entries_` are unspecified.
std::array<int, kMaxPitchPeriods24kHz> inverted_lags_;
// Number of valid entries in `inverted_lags_`.
int num_entries_;
};
// Computes the auto correlation coefficients for the inverted lags in the
// closed interval `inverted_lags`.
// closed interval `inverted_lags`. Updates `inverted_lags_index` by appending
// the inverted lags for the computed auto correlation values.
void ComputeAutoCorrelation(
Range inverted_lags,
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
rtc::ArrayView<float, kInitialNumLags24kHz> auto_correlation) {
rtc::ArrayView<float, kInitialNumLags24kHz> auto_correlation,
InvertedLagsIndex& inverted_lags_index) {
// Check valid range.
RTC_DCHECK_LE(inverted_lags.min, inverted_lags.max);
// Trick to avoid zero initialization of `auto_correlation`.
// Needed by the pseudo-interpolation.
if (inverted_lags.min > 0) {
auto_correlation[inverted_lags.min - 1] = 0.f;
}
if (inverted_lags.max < kInitialNumLags24kHz - 1) {
auto_correlation[inverted_lags.max + 1] = 0.f;
}
// Check valid `inverted_lag` indexes.
RTC_DCHECK_GE(inverted_lags.min, 0);
RTC_DCHECK_LT(inverted_lags.max, auto_correlation.size());
RTC_DCHECK_LT(inverted_lags.max, kInitialNumLags24kHz);
for (int inverted_lag = inverted_lags.min; inverted_lag <= inverted_lags.max;
++inverted_lag) {
auto_correlation[inverted_lag] =
ComputeAutoCorrelation(inverted_lag, pitch_buffer);
inverted_lags_index.Append(inverted_lag);
}
}
int ComputePitchPeriod24kHz(
// Searches the strongest pitch period at 24 kHz and returns its inverted lag at
// 48 kHz.
int ComputePitchPeriod48kHz(
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
rtc::ArrayView<const int> inverted_lags,
rtc::ArrayView<const float, kInitialNumLags24kHz> auto_correlation,
rtc::ArrayView<const float, kRefineNumLags24kHz> y_energy) {
static_assert(kMaxPitch24kHz > kInitialNumLags24kHz, "");
@ -165,8 +187,7 @@ int ComputePitchPeriod24kHz(
int best_inverted_lag = 0; // Pitch period.
float best_numerator = -1.f; // Pitch strength numerator.
float best_denominator = 0.f; // Pitch strength denominator.
for (int inverted_lag = 0; inverted_lag < kInitialNumLags24kHz;
++inverted_lag) {
for (int inverted_lag : inverted_lags) {
// A pitch candidate must have positive correlation.
if (auto_correlation[inverted_lag] > 0.f) {
// Auto-correlation energy normalized by frame energy.
@ -181,7 +202,19 @@ int ComputePitchPeriod24kHz(
}
}
}
return best_inverted_lag;
// Pseudo-interpolation to transform `best_inverted_lag` (24 kHz pitch) to a
// 48 kHz pitch period.
if (best_inverted_lag == 0 || best_inverted_lag >= kInitialNumLags24kHz - 1) {
// Cannot apply pseudo-interpolation at the boundaries.
return best_inverted_lag * 2;
}
int offset = GetPitchPseudoInterpolationOffset(
auto_correlation[best_inverted_lag + 1],
auto_correlation[best_inverted_lag],
auto_correlation[best_inverted_lag - 1]);
// TODO(bugs.webrtc.org/9076): When retraining, check if |offset| below should
// be subtracted since |inverted_lag| is an inverted lag but offset is a lag.
return 2 * best_inverted_lag + offset;
}
// Returns an alternative pitch period for `pitch_period` given a `multiplier`
@ -332,10 +365,10 @@ int ComputePitchPeriod48kHz(
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
rtc::ArrayView<const float, kRefineNumLags24kHz> y_energy,
CandidatePitchPeriods pitch_candidates) {
// Compute the auto-correlation terms only for neighbors of the given pitch
// candidates (similar to what is done in ComputePitchAutoCorrelation(), but
// for a few lag values).
std::array<float, kInitialNumLags24kHz> auto_correlation{};
// Compute the auto-correlation terms only for neighbors of the two pitch
// candidates (best and second best).
std::array<float, kInitialNumLags24kHz> auto_correlation;
InvertedLagsIndex inverted_lags_index;
// Create two inverted lag ranges so that `r1` precedes `r2`.
const bool swap_candidates =
pitch_candidates.best > pitch_candidates.second_best;
@ -351,18 +384,17 @@ int ComputePitchPeriod48kHz(
RTC_DCHECK_LE(r1.max, r2.max);
if (r1.max + 1 >= r2.min) {
// Overlapping or adjacent ranges.
ComputeAutoCorrelation({r1.min, r2.max}, pitch_buffer, auto_correlation);
ComputeAutoCorrelation({r1.min, r2.max}, pitch_buffer, auto_correlation,
inverted_lags_index);
} else {
// Disjoint ranges.
ComputeAutoCorrelation(r1, pitch_buffer, auto_correlation);
ComputeAutoCorrelation(r2, pitch_buffer, auto_correlation);
ComputeAutoCorrelation(r1, pitch_buffer, auto_correlation,
inverted_lags_index);
ComputeAutoCorrelation(r2, pitch_buffer, auto_correlation,
inverted_lags_index);
}
// Find best pitch at 24 kHz.
const int pitch_candidate_24kHz =
ComputePitchPeriod24kHz(pitch_buffer, auto_correlation, y_energy);
// Pseudo-interpolation.
return PitchPseudoInterpolationInvLagAutoCorr(pitch_candidate_24kHz,
auto_correlation);
return ComputePitchPeriod48kHz(pitch_buffer, inverted_lags_index,
auto_correlation, y_energy);
}
PitchInfo ComputeExtendedPitchPeriod48kHz(

View File

@ -128,9 +128,9 @@ class ExtendedPitchPeriodSearchParametrizaion
TEST_P(ExtendedPitchPeriodSearchParametrizaion,
PeriodBitExactnessGainWithinTolerance) {
PitchTestData test_data;
std::vector<float> y_energy(kMaxPitch24kHz + 1);
rtc::ArrayView<float, kMaxPitch24kHz + 1> y_energy_view(y_energy.data(),
kMaxPitch24kHz + 1);
std::vector<float> y_energy(kRefineNumLags24kHz);
rtc::ArrayView<float, kRefineNumLags24kHz> y_energy_view(y_energy.data(),
kRefineNumLags24kHz);
ComputeSlidingFrameSquareEnergies24kHz(test_data.GetPitchBufView(),
y_energy_view);
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.