RNN VAD: use VectorMath::DotProduct() for pitch search
This CL brings a large improvement to the RNN VAD CPU performance
by finally using `VectorMath::DotProduct()` for pitch search.
The realtime factor improved from about 390x to 570x for SSE2
(+180x, 45% faster) and to 610x for AVX2 (+220x, 55% faster).
RNN VAD benchmark results:
```
+-----+-------+------+------+
| run | none* | SSE2 | AVX2 |
+-----+-------+------+------+
| 1 | 393x | 572x | 618x |
| 2 | 388x | 568x | 607x |
| 3 | 393x | 564x | 599x |
+-----+-------+------+------+
```
*: baseline, no SIMD used for pitch search, but SSE2 used for the RNN
Results obtained as follows:
1. Force SSE2 in `DISABLED_RnnVadPerformance` for the RNN part in
order to measure the baseline correctly:
```
RnnBasedVad rnn_vad({/*sse2=*/true, /*avx2=*/true, /*neon=*/false});
```
2. Run the test:
```
$ ./out/release/modules_unittests \
--gtest_filter=*RnnVadTest*DISABLED_RnnVadPerformance* \
--gtest_also_run_disabled_tests --logs
```
Bug: webrtc:10480
Change-Id: I89a2bd420265540026944b9c0f1fdd4bfda7f475
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/195001
Reviewed-by: Gustaf Ullberg <gustaf@webrtc.org>
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#32755}
This commit is contained in:
parent
be810cba19
commit
fd5dadbea9
@ -124,6 +124,7 @@ rtc_library("rnn_vad_pitch") {
|
||||
"../../../../rtc_base:gtest_prod",
|
||||
"../../../../rtc_base:safe_compare",
|
||||
"../../../../rtc_base:safe_conversions",
|
||||
"../../../../rtc_base/system:arch",
|
||||
]
|
||||
if (current_cpu == "x86" || current_cpu == "x64") {
|
||||
deps += [ ":vector_math_avx2" ]
|
||||
@ -246,6 +247,7 @@ if (rtc_include_tests) {
|
||||
"../../../../rtc_base:logging",
|
||||
"../../../../rtc_base:safe_compare",
|
||||
"../../../../rtc_base:safe_conversions",
|
||||
"../../../../rtc_base:stringutils",
|
||||
"../../../../rtc_base/system:arch",
|
||||
"../../../../test:test_support",
|
||||
"../../utility:pffft_wrapper",
|
||||
|
||||
@ -42,7 +42,7 @@ int PitchEstimator::Estimate(
|
||||
auto_corr_calculator_.ComputeOnPitchBuffer(pitch_buffer_12kHz_view,
|
||||
auto_correlation_12kHz_view);
|
||||
CandidatePitchPeriods pitch_periods = ComputePitchPeriod12kHz(
|
||||
pitch_buffer_12kHz_view, auto_correlation_12kHz_view);
|
||||
pitch_buffer_12kHz_view, auto_correlation_12kHz_view, cpu_features_);
|
||||
// The refinement is done using the pitch buffer that contains 24 kHz samples.
|
||||
// Therefore, adapt the inverted lags in |pitch_candidates_inv_lags| from 12
|
||||
// to 24 kHz.
|
||||
@ -54,14 +54,15 @@ int PitchEstimator::Estimate(
|
||||
rtc::ArrayView<float, kRefineNumLags24kHz> y_energy_24kHz_view(
|
||||
y_energy_24kHz_.data(), kRefineNumLags24kHz);
|
||||
RTC_DCHECK_EQ(y_energy_24kHz_.size(), y_energy_24kHz_view.size());
|
||||
ComputeSlidingFrameSquareEnergies24kHz(pitch_buffer, y_energy_24kHz_view);
|
||||
ComputeSlidingFrameSquareEnergies24kHz(pitch_buffer, y_energy_24kHz_view,
|
||||
cpu_features_);
|
||||
// Estimation at 48 kHz.
|
||||
const int pitch_lag_48kHz =
|
||||
ComputePitchPeriod48kHz(pitch_buffer, y_energy_24kHz_view, pitch_periods);
|
||||
const int pitch_lag_48kHz = ComputePitchPeriod48kHz(
|
||||
pitch_buffer, y_energy_24kHz_view, pitch_periods, cpu_features_);
|
||||
last_pitch_48kHz_ = ComputeExtendedPitchPeriod48kHz(
|
||||
pitch_buffer, y_energy_24kHz_view,
|
||||
/*initial_pitch_period_48kHz=*/kMaxPitch48kHz - pitch_lag_48kHz,
|
||||
last_pitch_48kHz_);
|
||||
last_pitch_48kHz_, cpu_features_);
|
||||
return last_pitch_48kHz_.period;
|
||||
}
|
||||
|
||||
|
||||
@ -18,9 +18,11 @@
|
||||
#include <numeric>
|
||||
|
||||
#include "modules/audio_processing/agc2/rnn_vad/common.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/vector_math.h"
|
||||
#include "rtc_base/checks.h"
|
||||
#include "rtc_base/numerics/safe_compare.h"
|
||||
#include "rtc_base/numerics/safe_conversions.h"
|
||||
#include "rtc_base/system/arch.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
@ -28,14 +30,14 @@ namespace {
|
||||
|
||||
float ComputeAutoCorrelation(
|
||||
int inverted_lag,
|
||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer) {
|
||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
|
||||
const VectorMath& vector_math) {
|
||||
RTC_DCHECK_LT(inverted_lag, kBufSize24kHz);
|
||||
RTC_DCHECK_LT(inverted_lag, kRefineNumLags24kHz);
|
||||
static_assert(kMaxPitch24kHz < kBufSize24kHz, "");
|
||||
// TODO(bugs.webrtc.org/9076): Maybe optimize using vectorization.
|
||||
return std::inner_product(pitch_buffer.begin() + kMaxPitch24kHz,
|
||||
pitch_buffer.end(),
|
||||
pitch_buffer.begin() + inverted_lag, 0.f);
|
||||
return vector_math.DotProduct(
|
||||
pitch_buffer.subview(/*offset=*/kMaxPitch24kHz),
|
||||
pitch_buffer.subview(inverted_lag, kFrameSize20ms24kHz));
|
||||
}
|
||||
|
||||
// Given an auto-correlation coefficient `curr_auto_correlation` and its
|
||||
@ -66,15 +68,16 @@ int GetPitchPseudoInterpolationOffset(float prev_auto_correlation,
|
||||
// output sample rate is twice that of |lag|.
|
||||
int PitchPseudoInterpolationLagPitchBuf(
|
||||
int lag,
|
||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer) {
|
||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
|
||||
const VectorMath& vector_math) {
|
||||
int offset = 0;
|
||||
// Cannot apply pseudo-interpolation at the boundaries.
|
||||
if (lag > 0 && lag < kMaxPitch24kHz) {
|
||||
const int inverted_lag = kMaxPitch24kHz - lag;
|
||||
offset = GetPitchPseudoInterpolationOffset(
|
||||
ComputeAutoCorrelation(inverted_lag + 1, pitch_buffer),
|
||||
ComputeAutoCorrelation(inverted_lag, pitch_buffer),
|
||||
ComputeAutoCorrelation(inverted_lag - 1, pitch_buffer));
|
||||
ComputeAutoCorrelation(inverted_lag + 1, pitch_buffer, vector_math),
|
||||
ComputeAutoCorrelation(inverted_lag, pitch_buffer, vector_math),
|
||||
ComputeAutoCorrelation(inverted_lag - 1, pitch_buffer, vector_math));
|
||||
}
|
||||
return 2 * lag + offset;
|
||||
}
|
||||
@ -153,7 +156,8 @@ void ComputeAutoCorrelation(
|
||||
Range inverted_lags,
|
||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
|
||||
rtc::ArrayView<float, kInitialNumLags24kHz> auto_correlation,
|
||||
InvertedLagsIndex& inverted_lags_index) {
|
||||
InvertedLagsIndex& inverted_lags_index,
|
||||
const VectorMath& vector_math) {
|
||||
// Check valid range.
|
||||
RTC_DCHECK_LE(inverted_lags.min, inverted_lags.max);
|
||||
// Trick to avoid zero initialization of `auto_correlation`.
|
||||
@ -170,7 +174,7 @@ void ComputeAutoCorrelation(
|
||||
for (int inverted_lag = inverted_lags.min; inverted_lag <= inverted_lags.max;
|
||||
++inverted_lag) {
|
||||
auto_correlation[inverted_lag] =
|
||||
ComputeAutoCorrelation(inverted_lag, pitch_buffer);
|
||||
ComputeAutoCorrelation(inverted_lag, pitch_buffer, vector_math);
|
||||
inverted_lags_index.Append(inverted_lag);
|
||||
}
|
||||
}
|
||||
@ -181,7 +185,8 @@ int ComputePitchPeriod48kHz(
|
||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
|
||||
rtc::ArrayView<const int> inverted_lags,
|
||||
rtc::ArrayView<const float, kInitialNumLags24kHz> auto_correlation,
|
||||
rtc::ArrayView<const float, kRefineNumLags24kHz> y_energy) {
|
||||
rtc::ArrayView<const float, kRefineNumLags24kHz> y_energy,
|
||||
const VectorMath& vector_math) {
|
||||
static_assert(kMaxPitch24kHz > kInitialNumLags24kHz, "");
|
||||
static_assert(kMaxPitch24kHz < kBufSize24kHz, "");
|
||||
int best_inverted_lag = 0; // Pitch period.
|
||||
@ -289,10 +294,12 @@ void Decimate2x(rtc::ArrayView<const float, kBufSize24kHz> src,
|
||||
|
||||
void ComputeSlidingFrameSquareEnergies24kHz(
|
||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
|
||||
rtc::ArrayView<float, kRefineNumLags24kHz> y_energy) {
|
||||
float yy = std::inner_product(pitch_buffer.begin(),
|
||||
pitch_buffer.begin() + kFrameSize20ms24kHz,
|
||||
pitch_buffer.begin(), 0.f);
|
||||
rtc::ArrayView<float, kRefineNumLags24kHz> y_energy,
|
||||
AvailableCpuFeatures cpu_features) {
|
||||
VectorMath vector_math(cpu_features);
|
||||
static_assert(kFrameSize20ms24kHz < kBufSize24kHz, "");
|
||||
const auto frame_20ms_view = pitch_buffer.subview(0, kFrameSize20ms24kHz);
|
||||
float yy = vector_math.DotProduct(frame_20ms_view, frame_20ms_view);
|
||||
y_energy[0] = yy;
|
||||
static_assert(kMaxPitch24kHz - 1 + kFrameSize20ms24kHz < kBufSize24kHz, "");
|
||||
static_assert(kMaxPitch24kHz < kRefineNumLags24kHz, "");
|
||||
@ -307,7 +314,8 @@ void ComputeSlidingFrameSquareEnergies24kHz(
|
||||
|
||||
CandidatePitchPeriods ComputePitchPeriod12kHz(
|
||||
rtc::ArrayView<const float, kBufSize12kHz> pitch_buffer,
|
||||
rtc::ArrayView<const float, kNumLags12kHz> auto_correlation) {
|
||||
rtc::ArrayView<const float, kNumLags12kHz> auto_correlation,
|
||||
AvailableCpuFeatures cpu_features) {
|
||||
static_assert(kMaxPitch12kHz > kNumLags12kHz, "");
|
||||
static_assert(kMaxPitch12kHz < kBufSize12kHz, "");
|
||||
|
||||
@ -326,10 +334,10 @@ CandidatePitchPeriods ComputePitchPeriod12kHz(
|
||||
}
|
||||
};
|
||||
|
||||
// TODO(bugs.webrtc.org/9076): Maybe optimize using vectorization.
|
||||
float denominator = std::inner_product(
|
||||
pitch_buffer.begin(), pitch_buffer.begin() + kFrameSize20ms12kHz + 1,
|
||||
pitch_buffer.begin(), 1.f);
|
||||
VectorMath vector_math(cpu_features);
|
||||
static_assert(kFrameSize20ms12kHz + 1 < kBufSize12kHz, "");
|
||||
const auto frame_view = pitch_buffer.subview(0, kFrameSize20ms12kHz + 1);
|
||||
float denominator = 1.f + vector_math.DotProduct(frame_view, frame_view);
|
||||
// Search best and second best pitches by looking at the scaled
|
||||
// auto-correlation.
|
||||
PitchCandidate best;
|
||||
@ -364,7 +372,8 @@ CandidatePitchPeriods ComputePitchPeriod12kHz(
|
||||
int ComputePitchPeriod48kHz(
|
||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
|
||||
rtc::ArrayView<const float, kRefineNumLags24kHz> y_energy,
|
||||
CandidatePitchPeriods pitch_candidates) {
|
||||
CandidatePitchPeriods pitch_candidates,
|
||||
AvailableCpuFeatures cpu_features) {
|
||||
// Compute the auto-correlation terms only for neighbors of the two pitch
|
||||
// candidates (best and second best).
|
||||
std::array<float, kInitialNumLags24kHz> auto_correlation;
|
||||
@ -382,26 +391,28 @@ int ComputePitchPeriod48kHz(
|
||||
// Check `r1` precedes `r2`.
|
||||
RTC_DCHECK_LE(r1.min, r2.min);
|
||||
RTC_DCHECK_LE(r1.max, r2.max);
|
||||
VectorMath vector_math(cpu_features);
|
||||
if (r1.max + 1 >= r2.min) {
|
||||
// Overlapping or adjacent ranges.
|
||||
ComputeAutoCorrelation({r1.min, r2.max}, pitch_buffer, auto_correlation,
|
||||
inverted_lags_index);
|
||||
inverted_lags_index, vector_math);
|
||||
} else {
|
||||
// Disjoint ranges.
|
||||
ComputeAutoCorrelation(r1, pitch_buffer, auto_correlation,
|
||||
inverted_lags_index);
|
||||
inverted_lags_index, vector_math);
|
||||
ComputeAutoCorrelation(r2, pitch_buffer, auto_correlation,
|
||||
inverted_lags_index);
|
||||
inverted_lags_index, vector_math);
|
||||
}
|
||||
return ComputePitchPeriod48kHz(pitch_buffer, inverted_lags_index,
|
||||
auto_correlation, y_energy);
|
||||
auto_correlation, y_energy, vector_math);
|
||||
}
|
||||
|
||||
PitchInfo ComputeExtendedPitchPeriod48kHz(
|
||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
|
||||
rtc::ArrayView<const float, kRefineNumLags24kHz> y_energy,
|
||||
int initial_pitch_period_48kHz,
|
||||
PitchInfo last_pitch_48kHz) {
|
||||
PitchInfo last_pitch_48kHz,
|
||||
AvailableCpuFeatures cpu_features) {
|
||||
RTC_DCHECK_LE(kMinPitch48kHz, initial_pitch_period_48kHz);
|
||||
RTC_DCHECK_LE(initial_pitch_period_48kHz, kMaxPitch48kHz);
|
||||
|
||||
@ -419,13 +430,14 @@ PitchInfo ComputeExtendedPitchPeriod48kHz(
|
||||
RTC_DCHECK_GE(x_energy * y_energy, 0.f);
|
||||
return xy / std::sqrt(1.f + x_energy * y_energy);
|
||||
};
|
||||
VectorMath vector_math(cpu_features);
|
||||
|
||||
// Initialize the best pitch candidate with `initial_pitch_period_48kHz`.
|
||||
RefinedPitchCandidate best_pitch;
|
||||
best_pitch.period =
|
||||
std::min(initial_pitch_period_48kHz / 2, kMaxPitch24kHz - 1);
|
||||
best_pitch.xy =
|
||||
ComputeAutoCorrelation(kMaxPitch24kHz - best_pitch.period, pitch_buffer);
|
||||
best_pitch.xy = ComputeAutoCorrelation(kMaxPitch24kHz - best_pitch.period,
|
||||
pitch_buffer, vector_math);
|
||||
best_pitch.y_energy = y_energy[kMaxPitch24kHz - best_pitch.period];
|
||||
best_pitch.strength = pitch_strength(best_pitch.xy, best_pitch.y_energy);
|
||||
// Keep a copy of the initial pitch candidate.
|
||||
@ -463,9 +475,11 @@ PitchInfo ComputeExtendedPitchPeriod48kHz(
|
||||
// |alternative_pitch.period| by also looking at its possible sub-harmonic
|
||||
// |dual_alternative_period|.
|
||||
const float xy_primary_period = ComputeAutoCorrelation(
|
||||
kMaxPitch24kHz - alternative_pitch.period, pitch_buffer);
|
||||
kMaxPitch24kHz - alternative_pitch.period, pitch_buffer, vector_math);
|
||||
// TODO(webrtc:10480): Copy `xy_primary_period` if the secondary period is
|
||||
// equal to the primary one.
|
||||
const float xy_secondary_period = ComputeAutoCorrelation(
|
||||
kMaxPitch24kHz - dual_alternative_period, pitch_buffer);
|
||||
kMaxPitch24kHz - dual_alternative_period, pitch_buffer, vector_math);
|
||||
const float xy = 0.5f * (xy_primary_period + xy_secondary_period);
|
||||
const float yy =
|
||||
0.5f * (y_energy[kMaxPitch24kHz - alternative_pitch.period] +
|
||||
@ -489,8 +503,8 @@ PitchInfo ComputeExtendedPitchPeriod48kHz(
|
||||
: best_pitch.xy / (best_pitch.y_energy + 1.f);
|
||||
final_pitch_strength = std::min(best_pitch.strength, final_pitch_strength);
|
||||
int final_pitch_period_48kHz = std::max(
|
||||
kMinPitch48kHz,
|
||||
PitchPseudoInterpolationLagPitchBuf(best_pitch.period, pitch_buffer));
|
||||
kMinPitch48kHz, PitchPseudoInterpolationLagPitchBuf(
|
||||
best_pitch.period, pitch_buffer, vector_math));
|
||||
|
||||
return {final_pitch_period_48kHz, final_pitch_strength};
|
||||
}
|
||||
|
||||
@ -17,6 +17,7 @@
|
||||
#include <utility>
|
||||
|
||||
#include "api/array_view.h"
|
||||
#include "modules/audio_processing/agc2/cpu_features.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/common.h"
|
||||
|
||||
namespace webrtc {
|
||||
@ -65,7 +66,8 @@ void Decimate2x(rtc::ArrayView<const float, kBufSize24kHz> src,
|
||||
// buffer. The indexes of `y_energy` are inverted lags.
|
||||
void ComputeSlidingFrameSquareEnergies24kHz(
|
||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
|
||||
rtc::ArrayView<float, kRefineNumLags24kHz> y_energy);
|
||||
rtc::ArrayView<float, kRefineNumLags24kHz> y_energy,
|
||||
AvailableCpuFeatures cpu_features);
|
||||
|
||||
// Top-2 pitch period candidates. Unit: number of samples - i.e., inverted lags.
|
||||
struct CandidatePitchPeriods {
|
||||
@ -78,7 +80,8 @@ struct CandidatePitchPeriods {
|
||||
// indexes).
|
||||
CandidatePitchPeriods ComputePitchPeriod12kHz(
|
||||
rtc::ArrayView<const float, kBufSize12kHz> pitch_buffer,
|
||||
rtc::ArrayView<const float, kNumLags12kHz> auto_correlation);
|
||||
rtc::ArrayView<const float, kNumLags12kHz> auto_correlation,
|
||||
AvailableCpuFeatures cpu_features);
|
||||
|
||||
// Computes the pitch period at 48 kHz given a view on the 24 kHz pitch buffer,
|
||||
// the energies for the sliding frames `y` at 24 kHz and the pitch period
|
||||
@ -86,7 +89,8 @@ CandidatePitchPeriods ComputePitchPeriod12kHz(
|
||||
int ComputePitchPeriod48kHz(
|
||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
|
||||
rtc::ArrayView<const float, kRefineNumLags24kHz> y_energy,
|
||||
CandidatePitchPeriods pitch_candidates_24kHz);
|
||||
CandidatePitchPeriods pitch_candidates_24kHz,
|
||||
AvailableCpuFeatures cpu_features);
|
||||
|
||||
struct PitchInfo {
|
||||
int period;
|
||||
@ -101,7 +105,8 @@ PitchInfo ComputeExtendedPitchPeriod48kHz(
|
||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
|
||||
rtc::ArrayView<const float, kRefineNumLags24kHz> y_energy,
|
||||
int initial_pitch_period_48kHz,
|
||||
PitchInfo last_pitch_48kHz);
|
||||
PitchInfo last_pitch_48kHz,
|
||||
AvailableCpuFeatures cpu_features);
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
||||
@ -11,9 +11,11 @@
|
||||
#include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h"
|
||||
|
||||
#include <array>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
#include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
|
||||
#include "rtc_base/strings/string_builder.h"
|
||||
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
|
||||
// #include "test/fpe_observer.h"
|
||||
#include "test/gtest.h"
|
||||
@ -26,20 +28,46 @@ namespace {
|
||||
constexpr int kTestPitchPeriodsLow = 3 * kMinPitch48kHz / 2;
|
||||
constexpr int kTestPitchPeriodsHigh = (3 * kMinPitch48kHz + kMaxPitch48kHz) / 2;
|
||||
|
||||
constexpr float kTestPitchGainsLow = 0.35f;
|
||||
constexpr float kTestPitchGainsHigh = 0.75f;
|
||||
constexpr float kTestPitchStrengthLow = 0.35f;
|
||||
constexpr float kTestPitchStrengthHigh = 0.75f;
|
||||
|
||||
} // namespace
|
||||
template <class T>
|
||||
std::string PrintTestIndexAndCpuFeatures(
|
||||
const ::testing::TestParamInfo<T>& info) {
|
||||
rtc::StringBuilder builder;
|
||||
builder << info.index << "_" << info.param.cpu_features.ToString();
|
||||
return builder.str();
|
||||
}
|
||||
|
||||
// Finds the relevant CPU feature combinations to test.
|
||||
std::vector<AvailableCpuFeatures> GetCpuFeaturesToTest() {
|
||||
std::vector<AvailableCpuFeatures> v;
|
||||
v.push_back({/*sse2=*/false, /*avx2=*/false, /*neon=*/false});
|
||||
AvailableCpuFeatures available = GetAvailableCpuFeatures();
|
||||
if (available.avx2) {
|
||||
AvailableCpuFeatures features(
|
||||
{/*sse2=*/false, /*avx2=*/true, /*neon=*/false});
|
||||
v.push_back(features);
|
||||
}
|
||||
if (available.sse2) {
|
||||
AvailableCpuFeatures features(
|
||||
{/*sse2=*/true, /*avx2=*/false, /*neon=*/false});
|
||||
v.push_back(features);
|
||||
}
|
||||
return v;
|
||||
}
|
||||
|
||||
// Checks that the frame-wise sliding square energy function produces output
|
||||
// within tolerance given test input data.
|
||||
TEST(RnnVadTest, ComputeSlidingFrameSquareEnergies24kHzWithinTolerance) {
|
||||
const AvailableCpuFeatures cpu_features = GetAvailableCpuFeatures();
|
||||
|
||||
PitchTestData test_data;
|
||||
std::array<float, kNumPitchBufSquareEnergies> computed_output;
|
||||
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
|
||||
// FloatingPointExceptionObserver fpe_observer;
|
||||
ComputeSlidingFrameSquareEnergies24kHz(test_data.GetPitchBufView(),
|
||||
computed_output);
|
||||
computed_output, cpu_features);
|
||||
auto square_energies_view = test_data.GetPitchBufSquareEnergiesView();
|
||||
ExpectNearAbsolute({square_energies_view.data(), square_energies_view.size()},
|
||||
computed_output, 1e-3f);
|
||||
@ -47,6 +75,8 @@ TEST(RnnVadTest, ComputeSlidingFrameSquareEnergies24kHzWithinTolerance) {
|
||||
|
||||
// Checks that the estimated pitch period is bit-exact given test input data.
|
||||
TEST(RnnVadTest, ComputePitchPeriod12kHzBitExactness) {
|
||||
const AvailableCpuFeatures cpu_features = GetAvailableCpuFeatures();
|
||||
|
||||
PitchTestData test_data;
|
||||
std::array<float, kBufSize12kHz> pitch_buf_decimated;
|
||||
Decimate2x(test_data.GetPitchBufView(), pitch_buf_decimated);
|
||||
@ -54,138 +84,141 @@ TEST(RnnVadTest, ComputePitchPeriod12kHzBitExactness) {
|
||||
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
|
||||
// FloatingPointExceptionObserver fpe_observer;
|
||||
auto auto_corr_view = test_data.GetPitchBufAutoCorrCoeffsView();
|
||||
pitch_candidates =
|
||||
ComputePitchPeriod12kHz(pitch_buf_decimated, auto_corr_view);
|
||||
pitch_candidates = ComputePitchPeriod12kHz(pitch_buf_decimated,
|
||||
auto_corr_view, cpu_features);
|
||||
EXPECT_EQ(pitch_candidates.best, 140);
|
||||
EXPECT_EQ(pitch_candidates.second_best, 142);
|
||||
}
|
||||
|
||||
// Checks that the refined pitch period is bit-exact given test input data.
|
||||
TEST(RnnVadTest, ComputePitchPeriod48kHzBitExactness) {
|
||||
const AvailableCpuFeatures cpu_features = GetAvailableCpuFeatures();
|
||||
|
||||
PitchTestData test_data;
|
||||
std::vector<float> y_energy(kRefineNumLags24kHz);
|
||||
rtc::ArrayView<float, kRefineNumLags24kHz> y_energy_view(y_energy.data(),
|
||||
kRefineNumLags24kHz);
|
||||
ComputeSlidingFrameSquareEnergies24kHz(test_data.GetPitchBufView(),
|
||||
y_energy_view);
|
||||
y_energy_view, cpu_features);
|
||||
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
|
||||
// FloatingPointExceptionObserver fpe_observer;
|
||||
EXPECT_EQ(ComputePitchPeriod48kHz(test_data.GetPitchBufView(), y_energy_view,
|
||||
/*pitch_candidates=*/{280, 284}),
|
||||
560);
|
||||
EXPECT_EQ(ComputePitchPeriod48kHz(test_data.GetPitchBufView(), y_energy_view,
|
||||
/*pitch_candidates=*/{260, 284}),
|
||||
568);
|
||||
EXPECT_EQ(
|
||||
ComputePitchPeriod48kHz(test_data.GetPitchBufView(), y_energy_view,
|
||||
/*pitch_candidates=*/{280, 284}, cpu_features),
|
||||
560);
|
||||
EXPECT_EQ(
|
||||
ComputePitchPeriod48kHz(test_data.GetPitchBufView(), y_energy_view,
|
||||
/*pitch_candidates=*/{260, 284}, cpu_features),
|
||||
568);
|
||||
}
|
||||
|
||||
class PitchCandidatesParametrization
|
||||
: public ::testing::TestWithParam<CandidatePitchPeriods> {
|
||||
protected:
|
||||
CandidatePitchPeriods GetPitchCandidates() const { return GetParam(); }
|
||||
CandidatePitchPeriods GetSwappedPitchCandidates() const {
|
||||
CandidatePitchPeriods candidate = GetParam();
|
||||
return {candidate.second_best, candidate.best};
|
||||
}
|
||||
struct PitchCandidatesParameters {
|
||||
CandidatePitchPeriods pitch_candidates;
|
||||
AvailableCpuFeatures cpu_features;
|
||||
};
|
||||
|
||||
class PitchCandidatesParametrization
|
||||
: public ::testing::TestWithParam<PitchCandidatesParameters> {};
|
||||
|
||||
// Checks that the result of `ComputePitchPeriod48kHz()` does not depend on the
|
||||
// order of the input pitch candidates.
|
||||
TEST_P(PitchCandidatesParametrization,
|
||||
ComputePitchPeriod48kHzOrderDoesNotMatter) {
|
||||
const PitchCandidatesParameters params = GetParam();
|
||||
const CandidatePitchPeriods swapped_pitch_candidates{
|
||||
params.pitch_candidates.second_best, params.pitch_candidates.best};
|
||||
|
||||
PitchTestData test_data;
|
||||
std::vector<float> y_energy(kRefineNumLags24kHz);
|
||||
rtc::ArrayView<float, kRefineNumLags24kHz> y_energy_view(y_energy.data(),
|
||||
kRefineNumLags24kHz);
|
||||
ComputeSlidingFrameSquareEnergies24kHz(test_data.GetPitchBufView(),
|
||||
y_energy_view);
|
||||
EXPECT_EQ(ComputePitchPeriod48kHz(test_data.GetPitchBufView(), y_energy_view,
|
||||
GetPitchCandidates()),
|
||||
ComputePitchPeriod48kHz(test_data.GetPitchBufView(), y_energy_view,
|
||||
GetSwappedPitchCandidates()));
|
||||
y_energy_view, params.cpu_features);
|
||||
EXPECT_EQ(
|
||||
ComputePitchPeriod48kHz(test_data.GetPitchBufView(), y_energy_view,
|
||||
params.pitch_candidates, params.cpu_features),
|
||||
ComputePitchPeriod48kHz(test_data.GetPitchBufView(), y_energy_view,
|
||||
swapped_pitch_candidates, params.cpu_features));
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(RnnVadTest,
|
||||
PitchCandidatesParametrization,
|
||||
::testing::Values(CandidatePitchPeriods{0, 2},
|
||||
CandidatePitchPeriods{260, 284},
|
||||
CandidatePitchPeriods{280, 284},
|
||||
CandidatePitchPeriods{
|
||||
kInitialNumLags24kHz - 2,
|
||||
kInitialNumLags24kHz - 1}));
|
||||
std::vector<PitchCandidatesParameters> CreatePitchCandidatesParameters() {
|
||||
std::vector<PitchCandidatesParameters> v;
|
||||
for (AvailableCpuFeatures cpu_features : GetCpuFeaturesToTest()) {
|
||||
v.push_back({{0, 2}, cpu_features});
|
||||
v.push_back({{260, 284}, cpu_features});
|
||||
v.push_back({{280, 284}, cpu_features});
|
||||
v.push_back(
|
||||
{{kInitialNumLags24kHz - 2, kInitialNumLags24kHz - 1}, cpu_features});
|
||||
}
|
||||
return v;
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
RnnVadTest,
|
||||
PitchCandidatesParametrization,
|
||||
::testing::ValuesIn(CreatePitchCandidatesParameters()),
|
||||
PrintTestIndexAndCpuFeatures<PitchCandidatesParameters>);
|
||||
|
||||
struct ExtendedPitchPeriodSearchParameters {
|
||||
int initial_pitch_period;
|
||||
PitchInfo last_pitch;
|
||||
PitchInfo expected_pitch;
|
||||
AvailableCpuFeatures cpu_features;
|
||||
};
|
||||
|
||||
class ExtendedPitchPeriodSearchParametrizaion
|
||||
: public ::testing::TestWithParam<std::tuple<int, int, float, int, float>> {
|
||||
protected:
|
||||
int GetInitialPitchPeriod() const { return std::get<0>(GetParam()); }
|
||||
int GetLastPitchPeriod() const { return std::get<1>(GetParam()); }
|
||||
float GetLastPitchStrength() const { return std::get<2>(GetParam()); }
|
||||
int GetExpectedPitchPeriod() const { return std::get<3>(GetParam()); }
|
||||
float GetExpectedPitchStrength() const { return std::get<4>(GetParam()); }
|
||||
};
|
||||
: public ::testing::TestWithParam<ExtendedPitchPeriodSearchParameters> {};
|
||||
|
||||
// Checks that the computed pitch period is bit-exact and that the computed
|
||||
// pitch strength is within tolerance given test input data.
|
||||
TEST_P(ExtendedPitchPeriodSearchParametrizaion,
|
||||
PeriodBitExactnessGainWithinTolerance) {
|
||||
const ExtendedPitchPeriodSearchParameters params = GetParam();
|
||||
|
||||
PitchTestData test_data;
|
||||
std::vector<float> y_energy(kRefineNumLags24kHz);
|
||||
rtc::ArrayView<float, kRefineNumLags24kHz> y_energy_view(y_energy.data(),
|
||||
kRefineNumLags24kHz);
|
||||
ComputeSlidingFrameSquareEnergies24kHz(test_data.GetPitchBufView(),
|
||||
y_energy_view);
|
||||
y_energy_view, params.cpu_features);
|
||||
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
|
||||
// FloatingPointExceptionObserver fpe_observer;
|
||||
const auto computed_output = ComputeExtendedPitchPeriod48kHz(
|
||||
test_data.GetPitchBufView(), y_energy_view, GetInitialPitchPeriod(),
|
||||
{GetLastPitchPeriod(), GetLastPitchStrength()});
|
||||
EXPECT_EQ(GetExpectedPitchPeriod(), computed_output.period);
|
||||
EXPECT_NEAR(GetExpectedPitchStrength(), computed_output.strength, 1e-6f);
|
||||
test_data.GetPitchBufView(), y_energy_view, params.initial_pitch_period,
|
||||
params.last_pitch, params.cpu_features);
|
||||
EXPECT_EQ(params.expected_pitch.period, computed_output.period);
|
||||
EXPECT_NEAR(params.expected_pitch.strength, computed_output.strength, 1e-6f);
|
||||
}
|
||||
|
||||
std::vector<ExtendedPitchPeriodSearchParameters>
|
||||
CreateExtendedPitchPeriodSearchParameters() {
|
||||
std::vector<ExtendedPitchPeriodSearchParameters> v;
|
||||
for (AvailableCpuFeatures cpu_features : GetCpuFeaturesToTest()) {
|
||||
for (int last_pitch_period :
|
||||
{kTestPitchPeriodsLow, kTestPitchPeriodsHigh}) {
|
||||
for (float last_pitch_strength :
|
||||
{kTestPitchStrengthLow, kTestPitchStrengthHigh}) {
|
||||
v.push_back({kTestPitchPeriodsLow,
|
||||
{last_pitch_period, last_pitch_strength},
|
||||
{91, -0.0188608f},
|
||||
cpu_features});
|
||||
v.push_back({kTestPitchPeriodsHigh,
|
||||
{last_pitch_period, last_pitch_strength},
|
||||
{475, -0.0904344f},
|
||||
cpu_features});
|
||||
}
|
||||
}
|
||||
}
|
||||
return v;
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
RnnVadTest,
|
||||
ExtendedPitchPeriodSearchParametrizaion,
|
||||
::testing::Values(std::make_tuple(kTestPitchPeriodsLow,
|
||||
kTestPitchPeriodsLow,
|
||||
kTestPitchGainsLow,
|
||||
91,
|
||||
-0.0188608f),
|
||||
std::make_tuple(kTestPitchPeriodsLow,
|
||||
kTestPitchPeriodsLow,
|
||||
kTestPitchGainsHigh,
|
||||
91,
|
||||
-0.0188608f),
|
||||
std::make_tuple(kTestPitchPeriodsLow,
|
||||
kTestPitchPeriodsHigh,
|
||||
kTestPitchGainsLow,
|
||||
91,
|
||||
-0.0188608f),
|
||||
std::make_tuple(kTestPitchPeriodsLow,
|
||||
kTestPitchPeriodsHigh,
|
||||
kTestPitchGainsHigh,
|
||||
91,
|
||||
-0.0188608f),
|
||||
std::make_tuple(kTestPitchPeriodsHigh,
|
||||
kTestPitchPeriodsLow,
|
||||
kTestPitchGainsLow,
|
||||
475,
|
||||
-0.0904344f),
|
||||
std::make_tuple(kTestPitchPeriodsHigh,
|
||||
kTestPitchPeriodsLow,
|
||||
kTestPitchGainsHigh,
|
||||
475,
|
||||
-0.0904344f),
|
||||
std::make_tuple(kTestPitchPeriodsHigh,
|
||||
kTestPitchPeriodsHigh,
|
||||
kTestPitchGainsLow,
|
||||
475,
|
||||
-0.0904344f),
|
||||
std::make_tuple(kTestPitchPeriodsHigh,
|
||||
kTestPitchPeriodsHigh,
|
||||
kTestPitchGainsHigh,
|
||||
475,
|
||||
-0.0904344f)));
|
||||
::testing::ValuesIn(CreateExtendedPitchPeriodSearchParameters()),
|
||||
PrintTestIndexAndCpuFeatures<ExtendedPitchPeriodSearchParameters>);
|
||||
|
||||
} // namespace
|
||||
} // namespace test
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
||||
@ -163,10 +163,11 @@ std::vector<AvailableCpuFeatures> GetCpuFeaturesToTest() {
|
||||
std::vector<AvailableCpuFeatures> v;
|
||||
v.push_back({/*sse2=*/false, /*avx2=*/false, /*neon=*/false});
|
||||
AvailableCpuFeatures available = GetAvailableCpuFeatures();
|
||||
if (available.avx2 && available.sse2) {
|
||||
v.push_back({/*sse2=*/true, /*avx2=*/true, /*neon=*/false});
|
||||
}
|
||||
if (available.sse2) {
|
||||
AvailableCpuFeatures features(
|
||||
{/*sse2=*/true, /*avx2=*/false, /*neon=*/false});
|
||||
v.push_back(features);
|
||||
v.push_back({/*sse2=*/true, /*avx2=*/false, /*neon=*/false});
|
||||
}
|
||||
return v;
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user