RNN VAD: pitch search optimizations (part 4)
Add inverted lags index to simplify the loop in
`FindBestPitchPeriod48kHz()`. Instead of looping over 294 items,
only loop over the relevant ones (up to 10) by keeping track of
the relevant indexes.
The benchmark has shown a slight improvement (about +6x).
Benchmarked as follows:
```
out/release/modules_unittests \
--gtest_filter=*RnnVadTest.DISABLED_RnnVadPerformance* \
--gtest_also_run_disabled_tests --logs
```
Results:
| baseline | this CL
------+----------------------+------------------------
run 1 | 22.8319 +/- 1.46554 | 22.1951 +/- 0.747611
| 389.367x | 400.539x
------+----------------------+------------------------
run 2 | 22.4286 +/- 0.726449 | 22.2718 +/- 0.963738
| 396.369x | 399.16x
------+----------------------+------------------------
run 2 | 22.5688 +/- 0.831341 | 22.4166 +/- 0.953362
| 393.906x | 396.581x
This CL also moved `PitchPseudoInterpolationInvLagAutoCorr()`
into `FindBestPitchPeriod48kHz()`.
Bug: webrtc:10480
Change-Id: Id4e6d755045c3198a80fa94a0a7463577d909b7e
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/191764
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Reviewed-by: Karl Wiberg <kwiberg@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#32590}
This commit is contained in:
parent
ccbc216ac5
commit
05f5d636e5
@ -79,24 +79,6 @@ int PitchPseudoInterpolationLagPitchBuf(
|
||||
return 2 * lag + offset;
|
||||
}
|
||||
|
||||
// Refines a pitch period |inverted_lag| encoded as inverted lag with
|
||||
// pseudo-interpolation. The output sample rate is twice as that of
|
||||
// |inverted_lag|.
|
||||
int PitchPseudoInterpolationInvLagAutoCorr(
|
||||
int inverted_lag,
|
||||
rtc::ArrayView<const float, kInitialNumLags24kHz> auto_correlation) {
|
||||
int offset = 0;
|
||||
// Cannot apply pseudo-interpolation at the boundaries.
|
||||
if (inverted_lag > 0 && inverted_lag < kInitialNumLags24kHz - 1) {
|
||||
offset = GetPitchPseudoInterpolationOffset(
|
||||
auto_correlation[inverted_lag + 1], auto_correlation[inverted_lag],
|
||||
auto_correlation[inverted_lag - 1]);
|
||||
}
|
||||
// TODO(bugs.webrtc.org/9076): When retraining, check if |offset| below should
|
||||
// be subtracted since |inverted_lag| is an inverted lag but offset is a lag.
|
||||
return 2 * inverted_lag + offset;
|
||||
}
|
||||
|
||||
// Integer multipliers used in ComputeExtendedPitchPeriod48kHz() when
|
||||
// looking for sub-harmonics.
|
||||
// The values have been chosen to serve the following algorithm. Given the
|
||||
@ -129,35 +111,75 @@ struct Range {
|
||||
int max;
|
||||
};
|
||||
|
||||
// Number of analyzed pitches to the left(right) of a pitch candidate.
|
||||
constexpr int kPitchNeighborhoodRadius = 2;
|
||||
|
||||
// Creates a pitch period interval centered in `inverted_lag` with hard-coded
|
||||
// radius. Clipping is applied so that the interval is always valid for a 24 kHz
|
||||
// pitch buffer.
|
||||
Range CreateInvertedLagRange(int inverted_lag) {
|
||||
constexpr int kRadius = 2;
|
||||
return {std::max(inverted_lag - kRadius, 0),
|
||||
std::min(inverted_lag + kRadius, kInitialNumLags24kHz - 1)};
|
||||
return {std::max(inverted_lag - kPitchNeighborhoodRadius, 0),
|
||||
std::min(inverted_lag + kPitchNeighborhoodRadius,
|
||||
kInitialNumLags24kHz - 1)};
|
||||
}
|
||||
|
||||
constexpr int kNumPitchCandidates = 2; // Best and second best.
|
||||
// Maximum number of analyzed pitch periods.
|
||||
constexpr int kMaxPitchPeriods24kHz =
|
||||
kNumPitchCandidates * (2 * kPitchNeighborhoodRadius + 1);
|
||||
|
||||
// Collection of inverted lags.
|
||||
class InvertedLagsIndex {
|
||||
public:
|
||||
InvertedLagsIndex() : num_entries_(0) {}
|
||||
// Adds an inverted lag to the index. Cannot add more than
|
||||
// `kMaxPitchPeriods24kHz` values.
|
||||
void Append(int inverted_lag) {
|
||||
RTC_DCHECK_LT(num_entries_, kMaxPitchPeriods24kHz);
|
||||
inverted_lags_[num_entries_++] = inverted_lag;
|
||||
}
|
||||
const int* data() const { return inverted_lags_.data(); }
|
||||
int size() const { return num_entries_; }
|
||||
|
||||
private:
|
||||
std::array<int, kMaxPitchPeriods24kHz> inverted_lags_;
|
||||
int num_entries_;
|
||||
};
|
||||
|
||||
// Computes the auto correlation coefficients for the inverted lags in the
|
||||
// closed interval `inverted_lags`.
|
||||
// closed interval `inverted_lags`. Updates `inverted_lags_index` by appending
|
||||
// the inverted lags for the computed auto correlation values.
|
||||
void ComputeAutoCorrelation(
|
||||
Range inverted_lags,
|
||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
|
||||
rtc::ArrayView<float, kInitialNumLags24kHz> auto_correlation) {
|
||||
rtc::ArrayView<float, kInitialNumLags24kHz> auto_correlation,
|
||||
InvertedLagsIndex& inverted_lags_index) {
|
||||
// Check valid range.
|
||||
RTC_DCHECK_LE(inverted_lags.min, inverted_lags.max);
|
||||
// Trick to avoid zero initialization of `auto_correlation`.
|
||||
// Needed by the pseudo-interpolation.
|
||||
if (inverted_lags.min > 0) {
|
||||
auto_correlation[inverted_lags.min - 1] = 0.f;
|
||||
}
|
||||
if (inverted_lags.max < kInitialNumLags24kHz - 1) {
|
||||
auto_correlation[inverted_lags.max + 1] = 0.f;
|
||||
}
|
||||
// Check valid `inverted_lag` indexes.
|
||||
RTC_DCHECK_GE(inverted_lags.min, 0);
|
||||
RTC_DCHECK_LT(inverted_lags.max, auto_correlation.size());
|
||||
RTC_DCHECK_LT(inverted_lags.max, kInitialNumLags24kHz);
|
||||
for (int inverted_lag = inverted_lags.min; inverted_lag <= inverted_lags.max;
|
||||
++inverted_lag) {
|
||||
auto_correlation[inverted_lag] =
|
||||
ComputeAutoCorrelation(inverted_lag, pitch_buffer);
|
||||
inverted_lags_index.Append(inverted_lag);
|
||||
}
|
||||
}
|
||||
|
||||
int ComputePitchPeriod24kHz(
|
||||
// Searches the strongest pitch period at 24 kHz and returns its inverted lag at
|
||||
// 48 kHz.
|
||||
int ComputePitchPeriod48kHz(
|
||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
|
||||
rtc::ArrayView<const int> inverted_lags,
|
||||
rtc::ArrayView<const float, kInitialNumLags24kHz> auto_correlation,
|
||||
rtc::ArrayView<const float, kRefineNumLags24kHz> y_energy) {
|
||||
static_assert(kMaxPitch24kHz > kInitialNumLags24kHz, "");
|
||||
@ -165,8 +187,7 @@ int ComputePitchPeriod24kHz(
|
||||
int best_inverted_lag = 0; // Pitch period.
|
||||
float best_numerator = -1.f; // Pitch strength numerator.
|
||||
float best_denominator = 0.f; // Pitch strength denominator.
|
||||
for (int inverted_lag = 0; inverted_lag < kInitialNumLags24kHz;
|
||||
++inverted_lag) {
|
||||
for (int inverted_lag : inverted_lags) {
|
||||
// A pitch candidate must have positive correlation.
|
||||
if (auto_correlation[inverted_lag] > 0.f) {
|
||||
// Auto-correlation energy normalized by frame energy.
|
||||
@ -181,7 +202,19 @@ int ComputePitchPeriod24kHz(
|
||||
}
|
||||
}
|
||||
}
|
||||
return best_inverted_lag;
|
||||
// Pseudo-interpolation to transform `best_inverted_lag` (24 kHz pitch) to a
|
||||
// 48 kHz pitch period.
|
||||
if (best_inverted_lag == 0 || best_inverted_lag >= kInitialNumLags24kHz - 1) {
|
||||
// Cannot apply pseudo-interpolation at the boundaries.
|
||||
return best_inverted_lag * 2;
|
||||
}
|
||||
int offset = GetPitchPseudoInterpolationOffset(
|
||||
auto_correlation[best_inverted_lag + 1],
|
||||
auto_correlation[best_inverted_lag],
|
||||
auto_correlation[best_inverted_lag - 1]);
|
||||
// TODO(bugs.webrtc.org/9076): When retraining, check if |offset| below should
|
||||
// be subtracted since |inverted_lag| is an inverted lag but offset is a lag.
|
||||
return 2 * best_inverted_lag + offset;
|
||||
}
|
||||
|
||||
// Returns an alternative pitch period for `pitch_period` given a `multiplier`
|
||||
@ -332,10 +365,10 @@ int ComputePitchPeriod48kHz(
|
||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
|
||||
rtc::ArrayView<const float, kRefineNumLags24kHz> y_energy,
|
||||
CandidatePitchPeriods pitch_candidates) {
|
||||
// Compute the auto-correlation terms only for neighbors of the given pitch
|
||||
// candidates (similar to what is done in ComputePitchAutoCorrelation(), but
|
||||
// for a few lag values).
|
||||
std::array<float, kInitialNumLags24kHz> auto_correlation{};
|
||||
// Compute the auto-correlation terms only for neighbors of the two pitch
|
||||
// candidates (best and second best).
|
||||
std::array<float, kInitialNumLags24kHz> auto_correlation;
|
||||
InvertedLagsIndex inverted_lags_index;
|
||||
// Create two inverted lag ranges so that `r1` precedes `r2`.
|
||||
const bool swap_candidates =
|
||||
pitch_candidates.best > pitch_candidates.second_best;
|
||||
@ -351,18 +384,17 @@ int ComputePitchPeriod48kHz(
|
||||
RTC_DCHECK_LE(r1.max, r2.max);
|
||||
if (r1.max + 1 >= r2.min) {
|
||||
// Overlapping or adjacent ranges.
|
||||
ComputeAutoCorrelation({r1.min, r2.max}, pitch_buffer, auto_correlation);
|
||||
ComputeAutoCorrelation({r1.min, r2.max}, pitch_buffer, auto_correlation,
|
||||
inverted_lags_index);
|
||||
} else {
|
||||
// Disjoint ranges.
|
||||
ComputeAutoCorrelation(r1, pitch_buffer, auto_correlation);
|
||||
ComputeAutoCorrelation(r2, pitch_buffer, auto_correlation);
|
||||
ComputeAutoCorrelation(r1, pitch_buffer, auto_correlation,
|
||||
inverted_lags_index);
|
||||
ComputeAutoCorrelation(r2, pitch_buffer, auto_correlation,
|
||||
inverted_lags_index);
|
||||
}
|
||||
// Find best pitch at 24 kHz.
|
||||
const int pitch_candidate_24kHz =
|
||||
ComputePitchPeriod24kHz(pitch_buffer, auto_correlation, y_energy);
|
||||
// Pseudo-interpolation.
|
||||
return PitchPseudoInterpolationInvLagAutoCorr(pitch_candidate_24kHz,
|
||||
auto_correlation);
|
||||
return ComputePitchPeriod48kHz(pitch_buffer, inverted_lags_index,
|
||||
auto_correlation, y_energy);
|
||||
}
|
||||
|
||||
PitchInfo ComputeExtendedPitchPeriod48kHz(
|
||||
|
||||
@ -128,9 +128,9 @@ class ExtendedPitchPeriodSearchParametrizaion
|
||||
TEST_P(ExtendedPitchPeriodSearchParametrizaion,
|
||||
PeriodBitExactnessGainWithinTolerance) {
|
||||
PitchTestData test_data;
|
||||
std::vector<float> y_energy(kMaxPitch24kHz + 1);
|
||||
rtc::ArrayView<float, kMaxPitch24kHz + 1> y_energy_view(y_energy.data(),
|
||||
kMaxPitch24kHz + 1);
|
||||
std::vector<float> y_energy(kRefineNumLags24kHz);
|
||||
rtc::ArrayView<float, kRefineNumLags24kHz> y_energy_view(y_energy.data(),
|
||||
kRefineNumLags24kHz);
|
||||
ComputeSlidingFrameSquareEnergies24kHz(test_data.GetPitchBufView(),
|
||||
y_energy_view);
|
||||
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user