diff --git a/modules/audio_processing/agc2/rnn_vad/BUILD.gn b/modules/audio_processing/agc2/rnn_vad/BUILD.gn index fafea4294c..dbba6c173c 100644 --- a/modules/audio_processing/agc2/rnn_vad/BUILD.gn +++ b/modules/audio_processing/agc2/rnn_vad/BUILD.gn @@ -124,6 +124,7 @@ rtc_library("rnn_vad_pitch") { "../../../../rtc_base:gtest_prod", "../../../../rtc_base:safe_compare", "../../../../rtc_base:safe_conversions", + "../../../../rtc_base/system:arch", ] if (current_cpu == "x86" || current_cpu == "x64") { deps += [ ":vector_math_avx2" ] @@ -246,6 +247,7 @@ if (rtc_include_tests) { "../../../../rtc_base:logging", "../../../../rtc_base:safe_compare", "../../../../rtc_base:safe_conversions", + "../../../../rtc_base:stringutils", "../../../../rtc_base/system:arch", "../../../../test:test_support", "../../utility:pffft_wrapper", diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search.cc index c2e7665967..77a118853f 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search.cc +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search.cc @@ -42,7 +42,7 @@ int PitchEstimator::Estimate( auto_corr_calculator_.ComputeOnPitchBuffer(pitch_buffer_12kHz_view, auto_correlation_12kHz_view); CandidatePitchPeriods pitch_periods = ComputePitchPeriod12kHz( - pitch_buffer_12kHz_view, auto_correlation_12kHz_view); + pitch_buffer_12kHz_view, auto_correlation_12kHz_view, cpu_features_); // The refinement is done using the pitch buffer that contains 24 kHz samples. // Therefore, adapt the inverted lags in |pitch_candidates_inv_lags| from 12 // to 24 kHz. @@ -54,14 +54,15 @@ int PitchEstimator::Estimate( rtc::ArrayView y_energy_24kHz_view( y_energy_24kHz_.data(), kRefineNumLags24kHz); RTC_DCHECK_EQ(y_energy_24kHz_.size(), y_energy_24kHz_view.size()); - ComputeSlidingFrameSquareEnergies24kHz(pitch_buffer, y_energy_24kHz_view); + ComputeSlidingFrameSquareEnergies24kHz(pitch_buffer, y_energy_24kHz_view, + cpu_features_); // Estimation at 48 kHz. - const int pitch_lag_48kHz = - ComputePitchPeriod48kHz(pitch_buffer, y_energy_24kHz_view, pitch_periods); + const int pitch_lag_48kHz = ComputePitchPeriod48kHz( + pitch_buffer, y_energy_24kHz_view, pitch_periods, cpu_features_); last_pitch_48kHz_ = ComputeExtendedPitchPeriod48kHz( pitch_buffer, y_energy_24kHz_view, /*initial_pitch_period_48kHz=*/kMaxPitch48kHz - pitch_lag_48kHz, - last_pitch_48kHz_); + last_pitch_48kHz_, cpu_features_); return last_pitch_48kHz_.period; } diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc index 262c386453..0b8a77e488 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc @@ -18,9 +18,11 @@ #include #include "modules/audio_processing/agc2/rnn_vad/common.h" +#include "modules/audio_processing/agc2/rnn_vad/vector_math.h" #include "rtc_base/checks.h" #include "rtc_base/numerics/safe_compare.h" #include "rtc_base/numerics/safe_conversions.h" +#include "rtc_base/system/arch.h" namespace webrtc { namespace rnn_vad { @@ -28,14 +30,14 @@ namespace { float ComputeAutoCorrelation( int inverted_lag, - rtc::ArrayView pitch_buffer) { + rtc::ArrayView pitch_buffer, + const VectorMath& vector_math) { RTC_DCHECK_LT(inverted_lag, kBufSize24kHz); RTC_DCHECK_LT(inverted_lag, kRefineNumLags24kHz); static_assert(kMaxPitch24kHz < kBufSize24kHz, ""); - // TODO(bugs.webrtc.org/9076): Maybe optimize using vectorization. - return std::inner_product(pitch_buffer.begin() + kMaxPitch24kHz, - pitch_buffer.end(), - pitch_buffer.begin() + inverted_lag, 0.f); + return vector_math.DotProduct( + pitch_buffer.subview(/*offset=*/kMaxPitch24kHz), + pitch_buffer.subview(inverted_lag, kFrameSize20ms24kHz)); } // Given an auto-correlation coefficient `curr_auto_correlation` and its @@ -66,15 +68,16 @@ int GetPitchPseudoInterpolationOffset(float prev_auto_correlation, // output sample rate is twice as that of |lag|. int PitchPseudoInterpolationLagPitchBuf( int lag, - rtc::ArrayView pitch_buffer) { + rtc::ArrayView pitch_buffer, + const VectorMath& vector_math) { int offset = 0; // Cannot apply pseudo-interpolation at the boundaries. if (lag > 0 && lag < kMaxPitch24kHz) { const int inverted_lag = kMaxPitch24kHz - lag; offset = GetPitchPseudoInterpolationOffset( - ComputeAutoCorrelation(inverted_lag + 1, pitch_buffer), - ComputeAutoCorrelation(inverted_lag, pitch_buffer), - ComputeAutoCorrelation(inverted_lag - 1, pitch_buffer)); + ComputeAutoCorrelation(inverted_lag + 1, pitch_buffer, vector_math), + ComputeAutoCorrelation(inverted_lag, pitch_buffer, vector_math), + ComputeAutoCorrelation(inverted_lag - 1, pitch_buffer, vector_math)); } return 2 * lag + offset; } @@ -153,7 +156,8 @@ void ComputeAutoCorrelation( Range inverted_lags, rtc::ArrayView pitch_buffer, rtc::ArrayView auto_correlation, - InvertedLagsIndex& inverted_lags_index) { + InvertedLagsIndex& inverted_lags_index, + const VectorMath& vector_math) { // Check valid range. RTC_DCHECK_LE(inverted_lags.min, inverted_lags.max); // Trick to avoid zero initialization of `auto_correlation`. @@ -170,7 +174,7 @@ void ComputeAutoCorrelation( for (int inverted_lag = inverted_lags.min; inverted_lag <= inverted_lags.max; ++inverted_lag) { auto_correlation[inverted_lag] = - ComputeAutoCorrelation(inverted_lag, pitch_buffer); + ComputeAutoCorrelation(inverted_lag, pitch_buffer, vector_math); inverted_lags_index.Append(inverted_lag); } } @@ -181,7 +185,8 @@ int ComputePitchPeriod48kHz( rtc::ArrayView pitch_buffer, rtc::ArrayView inverted_lags, rtc::ArrayView auto_correlation, - rtc::ArrayView y_energy) { + rtc::ArrayView y_energy, + const VectorMath& vector_math) { static_assert(kMaxPitch24kHz > kInitialNumLags24kHz, ""); static_assert(kMaxPitch24kHz < kBufSize24kHz, ""); int best_inverted_lag = 0; // Pitch period. @@ -289,10 +294,12 @@ void Decimate2x(rtc::ArrayView src, void ComputeSlidingFrameSquareEnergies24kHz( rtc::ArrayView pitch_buffer, - rtc::ArrayView y_energy) { - float yy = std::inner_product(pitch_buffer.begin(), - pitch_buffer.begin() + kFrameSize20ms24kHz, - pitch_buffer.begin(), 0.f); + rtc::ArrayView y_energy, + AvailableCpuFeatures cpu_features) { + VectorMath vector_math(cpu_features); + static_assert(kFrameSize20ms24kHz < kBufSize24kHz, ""); + const auto frame_20ms_view = pitch_buffer.subview(0, kFrameSize20ms24kHz); + float yy = vector_math.DotProduct(frame_20ms_view, frame_20ms_view); y_energy[0] = yy; static_assert(kMaxPitch24kHz - 1 + kFrameSize20ms24kHz < kBufSize24kHz, ""); static_assert(kMaxPitch24kHz < kRefineNumLags24kHz, ""); @@ -307,7 +314,8 @@ void ComputeSlidingFrameSquareEnergies24kHz( CandidatePitchPeriods ComputePitchPeriod12kHz( rtc::ArrayView pitch_buffer, - rtc::ArrayView auto_correlation) { + rtc::ArrayView auto_correlation, + AvailableCpuFeatures cpu_features) { static_assert(kMaxPitch12kHz > kNumLags12kHz, ""); static_assert(kMaxPitch12kHz < kBufSize12kHz, ""); @@ -326,10 +334,10 @@ CandidatePitchPeriods ComputePitchPeriod12kHz( } }; - // TODO(bugs.webrtc.org/9076): Maybe optimize using vectorization. - float denominator = std::inner_product( - pitch_buffer.begin(), pitch_buffer.begin() + kFrameSize20ms12kHz + 1, - pitch_buffer.begin(), 1.f); + VectorMath vector_math(cpu_features); + static_assert(kFrameSize20ms12kHz + 1 < kBufSize12kHz, ""); + const auto frame_view = pitch_buffer.subview(0, kFrameSize20ms12kHz + 1); + float denominator = 1.f + vector_math.DotProduct(frame_view, frame_view); // Search best and second best pitches by looking at the scaled // auto-correlation. PitchCandidate best; @@ -364,7 +372,8 @@ CandidatePitchPeriods ComputePitchPeriod12kHz( int ComputePitchPeriod48kHz( rtc::ArrayView pitch_buffer, rtc::ArrayView y_energy, - CandidatePitchPeriods pitch_candidates) { + CandidatePitchPeriods pitch_candidates, + AvailableCpuFeatures cpu_features) { // Compute the auto-correlation terms only for neighbors of the two pitch // candidates (best and second best). std::array auto_correlation; @@ -382,26 +391,28 @@ int ComputePitchPeriod48kHz( // Check `r1` precedes `r2`. RTC_DCHECK_LE(r1.min, r2.min); RTC_DCHECK_LE(r1.max, r2.max); + VectorMath vector_math(cpu_features); if (r1.max + 1 >= r2.min) { // Overlapping or adjacent ranges. ComputeAutoCorrelation({r1.min, r2.max}, pitch_buffer, auto_correlation, - inverted_lags_index); + inverted_lags_index, vector_math); } else { // Disjoint ranges. ComputeAutoCorrelation(r1, pitch_buffer, auto_correlation, - inverted_lags_index); + inverted_lags_index, vector_math); ComputeAutoCorrelation(r2, pitch_buffer, auto_correlation, - inverted_lags_index); + inverted_lags_index, vector_math); } return ComputePitchPeriod48kHz(pitch_buffer, inverted_lags_index, - auto_correlation, y_energy); + auto_correlation, y_energy, vector_math); } PitchInfo ComputeExtendedPitchPeriod48kHz( rtc::ArrayView pitch_buffer, rtc::ArrayView y_energy, int initial_pitch_period_48kHz, - PitchInfo last_pitch_48kHz) { + PitchInfo last_pitch_48kHz, + AvailableCpuFeatures cpu_features) { RTC_DCHECK_LE(kMinPitch48kHz, initial_pitch_period_48kHz); RTC_DCHECK_LE(initial_pitch_period_48kHz, kMaxPitch48kHz); @@ -419,13 +430,14 @@ PitchInfo ComputeExtendedPitchPeriod48kHz( RTC_DCHECK_GE(x_energy * y_energy, 0.f); return xy / std::sqrt(1.f + x_energy * y_energy); }; + VectorMath vector_math(cpu_features); // Initialize the best pitch candidate with `initial_pitch_period_48kHz`. RefinedPitchCandidate best_pitch; best_pitch.period = std::min(initial_pitch_period_48kHz / 2, kMaxPitch24kHz - 1); - best_pitch.xy = - ComputeAutoCorrelation(kMaxPitch24kHz - best_pitch.period, pitch_buffer); + best_pitch.xy = ComputeAutoCorrelation(kMaxPitch24kHz - best_pitch.period, + pitch_buffer, vector_math); best_pitch.y_energy = y_energy[kMaxPitch24kHz - best_pitch.period]; best_pitch.strength = pitch_strength(best_pitch.xy, best_pitch.y_energy); // Keep a copy of the initial pitch candidate. @@ -463,9 +475,11 @@ PitchInfo ComputeExtendedPitchPeriod48kHz( // |alternative_pitch.period| by also looking at its possible sub-harmonic // |dual_alternative_period|. const float xy_primary_period = ComputeAutoCorrelation( - kMaxPitch24kHz - alternative_pitch.period, pitch_buffer); + kMaxPitch24kHz - alternative_pitch.period, pitch_buffer, vector_math); + // TODO(webrtc:10480): Copy `xy_primary_period` if the secondary period is + // equal to the primary one. const float xy_secondary_period = ComputeAutoCorrelation( - kMaxPitch24kHz - dual_alternative_period, pitch_buffer); + kMaxPitch24kHz - dual_alternative_period, pitch_buffer, vector_math); const float xy = 0.5f * (xy_primary_period + xy_secondary_period); const float yy = 0.5f * (y_energy[kMaxPitch24kHz - alternative_pitch.period] + @@ -489,8 +503,8 @@ PitchInfo ComputeExtendedPitchPeriod48kHz( : best_pitch.xy / (best_pitch.y_energy + 1.f); final_pitch_strength = std::min(best_pitch.strength, final_pitch_strength); int final_pitch_period_48kHz = std::max( - kMinPitch48kHz, - PitchPseudoInterpolationLagPitchBuf(best_pitch.period, pitch_buffer)); + kMinPitch48kHz, PitchPseudoInterpolationLagPitchBuf( + best_pitch.period, pitch_buffer, vector_math)); return {final_pitch_period_48kHz, final_pitch_strength}; } diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h index 0af55f8e69..aa2dd13745 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h @@ -17,6 +17,7 @@ #include #include "api/array_view.h" +#include "modules/audio_processing/agc2/cpu_features.h" #include "modules/audio_processing/agc2/rnn_vad/common.h" namespace webrtc { @@ -65,7 +66,8 @@ void Decimate2x(rtc::ArrayView src, // buffer. The indexes of `y_energy` are inverted lags. void ComputeSlidingFrameSquareEnergies24kHz( rtc::ArrayView pitch_buffer, - rtc::ArrayView y_energy); + rtc::ArrayView y_energy, + AvailableCpuFeatures cpu_features); // Top-2 pitch period candidates. Unit: number of samples - i.e., inverted lags. struct CandidatePitchPeriods { @@ -78,7 +80,8 @@ struct CandidatePitchPeriods { // indexes). CandidatePitchPeriods ComputePitchPeriod12kHz( rtc::ArrayView pitch_buffer, - rtc::ArrayView auto_correlation); + rtc::ArrayView auto_correlation, + AvailableCpuFeatures cpu_features); // Computes the pitch period at 48 kHz given a view on the 24 kHz pitch buffer, // the energies for the sliding frames `y` at 24 kHz and the pitch period @@ -86,7 +89,8 @@ CandidatePitchPeriods ComputePitchPeriod12kHz( int ComputePitchPeriod48kHz( rtc::ArrayView pitch_buffer, rtc::ArrayView y_energy, - CandidatePitchPeriods pitch_candidates_24kHz); + CandidatePitchPeriods pitch_candidates_24kHz, + AvailableCpuFeatures cpu_features); struct PitchInfo { int period; @@ -101,7 +105,8 @@ PitchInfo ComputeExtendedPitchPeriod48kHz( rtc::ArrayView pitch_buffer, rtc::ArrayView y_energy, int initial_pitch_period_48kHz, - PitchInfo last_pitch_48kHz); + PitchInfo last_pitch_48kHz, + AvailableCpuFeatures cpu_features); } // namespace rnn_vad } // namespace webrtc diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc index 152d569823..a4a4df12dc 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc @@ -11,9 +11,11 @@ #include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h" #include +#include #include #include "modules/audio_processing/agc2/rnn_vad/test_utils.h" +#include "rtc_base/strings/string_builder.h" // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. // #include "test/fpe_observer.h" #include "test/gtest.h" @@ -26,20 +28,46 @@ namespace { constexpr int kTestPitchPeriodsLow = 3 * kMinPitch48kHz / 2; constexpr int kTestPitchPeriodsHigh = (3 * kMinPitch48kHz + kMaxPitch48kHz) / 2; -constexpr float kTestPitchGainsLow = 0.35f; -constexpr float kTestPitchGainsHigh = 0.75f; +constexpr float kTestPitchStrengthLow = 0.35f; +constexpr float kTestPitchStrengthHigh = 0.75f; -} // namespace +template +std::string PrintTestIndexAndCpuFeatures( + const ::testing::TestParamInfo& info) { + rtc::StringBuilder builder; + builder << info.index << "_" << info.param.cpu_features.ToString(); + return builder.str(); +} + +// Finds the relevant CPU features combinations to test. +std::vector GetCpuFeaturesToTest() { + std::vector v; + v.push_back({/*sse2=*/false, /*avx2=*/false, /*neon=*/false}); + AvailableCpuFeatures available = GetAvailableCpuFeatures(); + if (available.avx2) { + AvailableCpuFeatures features( + {/*sse2=*/false, /*avx2=*/true, /*neon=*/false}); + v.push_back(features); + } + if (available.sse2) { + AvailableCpuFeatures features( + {/*sse2=*/true, /*avx2=*/false, /*neon=*/false}); + v.push_back(features); + } + return v; +} // Checks that the frame-wise sliding square energy function produces output // within tolerance given test input data. TEST(RnnVadTest, ComputeSlidingFrameSquareEnergies24kHzWithinTolerance) { + const AvailableCpuFeatures cpu_features = GetAvailableCpuFeatures(); + PitchTestData test_data; std::array computed_output; // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. // FloatingPointExceptionObserver fpe_observer; ComputeSlidingFrameSquareEnergies24kHz(test_data.GetPitchBufView(), - computed_output); + computed_output, cpu_features); auto square_energies_view = test_data.GetPitchBufSquareEnergiesView(); ExpectNearAbsolute({square_energies_view.data(), square_energies_view.size()}, computed_output, 1e-3f); @@ -47,6 +75,8 @@ TEST(RnnVadTest, ComputeSlidingFrameSquareEnergies24kHzWithinTolerance) { // Checks that the estimated pitch period is bit-exact given test input data. TEST(RnnVadTest, ComputePitchPeriod12kHzBitExactness) { + const AvailableCpuFeatures cpu_features = GetAvailableCpuFeatures(); + PitchTestData test_data; std::array pitch_buf_decimated; Decimate2x(test_data.GetPitchBufView(), pitch_buf_decimated); @@ -54,138 +84,141 @@ TEST(RnnVadTest, ComputePitchPeriod12kHzBitExactness) { // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. // FloatingPointExceptionObserver fpe_observer; auto auto_corr_view = test_data.GetPitchBufAutoCorrCoeffsView(); - pitch_candidates = - ComputePitchPeriod12kHz(pitch_buf_decimated, auto_corr_view); + pitch_candidates = ComputePitchPeriod12kHz(pitch_buf_decimated, + auto_corr_view, cpu_features); EXPECT_EQ(pitch_candidates.best, 140); EXPECT_EQ(pitch_candidates.second_best, 142); } // Checks that the refined pitch period is bit-exact given test input data. TEST(RnnVadTest, ComputePitchPeriod48kHzBitExactness) { + const AvailableCpuFeatures cpu_features = GetAvailableCpuFeatures(); + PitchTestData test_data; std::vector y_energy(kRefineNumLags24kHz); rtc::ArrayView y_energy_view(y_energy.data(), kRefineNumLags24kHz); ComputeSlidingFrameSquareEnergies24kHz(test_data.GetPitchBufView(), - y_energy_view); + y_energy_view, cpu_features); // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. // FloatingPointExceptionObserver fpe_observer; - EXPECT_EQ(ComputePitchPeriod48kHz(test_data.GetPitchBufView(), y_energy_view, - /*pitch_candidates=*/{280, 284}), - 560); - EXPECT_EQ(ComputePitchPeriod48kHz(test_data.GetPitchBufView(), y_energy_view, - /*pitch_candidates=*/{260, 284}), - 568); + EXPECT_EQ( + ComputePitchPeriod48kHz(test_data.GetPitchBufView(), y_energy_view, + /*pitch_candidates=*/{280, 284}, cpu_features), + 560); + EXPECT_EQ( + ComputePitchPeriod48kHz(test_data.GetPitchBufView(), y_energy_view, + /*pitch_candidates=*/{260, 284}, cpu_features), + 568); } -class PitchCandidatesParametrization - : public ::testing::TestWithParam { - protected: - CandidatePitchPeriods GetPitchCandidates() const { return GetParam(); } - CandidatePitchPeriods GetSwappedPitchCandidates() const { - CandidatePitchPeriods candidate = GetParam(); - return {candidate.second_best, candidate.best}; - } +struct PitchCandidatesParameters { + CandidatePitchPeriods pitch_candidates; + AvailableCpuFeatures cpu_features; }; +class PitchCandidatesParametrization + : public ::testing::TestWithParam {}; + // Checks that the result of `ComputePitchPeriod48kHz()` does not depend on the // order of the input pitch candidates. TEST_P(PitchCandidatesParametrization, ComputePitchPeriod48kHzOrderDoesNotMatter) { + const PitchCandidatesParameters params = GetParam(); + const CandidatePitchPeriods swapped_pitch_candidates{ + params.pitch_candidates.second_best, params.pitch_candidates.best}; + PitchTestData test_data; std::vector y_energy(kRefineNumLags24kHz); rtc::ArrayView y_energy_view(y_energy.data(), kRefineNumLags24kHz); ComputeSlidingFrameSquareEnergies24kHz(test_data.GetPitchBufView(), - y_energy_view); - EXPECT_EQ(ComputePitchPeriod48kHz(test_data.GetPitchBufView(), y_energy_view, - GetPitchCandidates()), - ComputePitchPeriod48kHz(test_data.GetPitchBufView(), y_energy_view, - GetSwappedPitchCandidates())); + y_energy_view, params.cpu_features); + EXPECT_EQ( + ComputePitchPeriod48kHz(test_data.GetPitchBufView(), y_energy_view, + params.pitch_candidates, params.cpu_features), + ComputePitchPeriod48kHz(test_data.GetPitchBufView(), y_energy_view, + swapped_pitch_candidates, params.cpu_features)); } -INSTANTIATE_TEST_SUITE_P(RnnVadTest, - PitchCandidatesParametrization, - ::testing::Values(CandidatePitchPeriods{0, 2}, - CandidatePitchPeriods{260, 284}, - CandidatePitchPeriods{280, 284}, - CandidatePitchPeriods{ - kInitialNumLags24kHz - 2, - kInitialNumLags24kHz - 1})); +std::vector CreatePitchCandidatesParameters() { + std::vector v; + for (AvailableCpuFeatures cpu_features : GetCpuFeaturesToTest()) { + v.push_back({{0, 2}, cpu_features}); + v.push_back({{260, 284}, cpu_features}); + v.push_back({{280, 284}, cpu_features}); + v.push_back( + {{kInitialNumLags24kHz - 2, kInitialNumLags24kHz - 1}, cpu_features}); + } + return v; +} + +INSTANTIATE_TEST_SUITE_P( + RnnVadTest, + PitchCandidatesParametrization, + ::testing::ValuesIn(CreatePitchCandidatesParameters()), + PrintTestIndexAndCpuFeatures); + +struct ExtendedPitchPeriodSearchParameters { + int initial_pitch_period; + PitchInfo last_pitch; + PitchInfo expected_pitch; + AvailableCpuFeatures cpu_features; +}; class ExtendedPitchPeriodSearchParametrizaion - : public ::testing::TestWithParam> { - protected: - int GetInitialPitchPeriod() const { return std::get<0>(GetParam()); } - int GetLastPitchPeriod() const { return std::get<1>(GetParam()); } - float GetLastPitchStrength() const { return std::get<2>(GetParam()); } - int GetExpectedPitchPeriod() const { return std::get<3>(GetParam()); } - float GetExpectedPitchStrength() const { return std::get<4>(GetParam()); } -}; + : public ::testing::TestWithParam {}; // Checks that the computed pitch period is bit-exact and that the computed // pitch strength is within tolerance given test input data. TEST_P(ExtendedPitchPeriodSearchParametrizaion, PeriodBitExactnessGainWithinTolerance) { + const ExtendedPitchPeriodSearchParameters params = GetParam(); + PitchTestData test_data; std::vector y_energy(kRefineNumLags24kHz); rtc::ArrayView y_energy_view(y_energy.data(), kRefineNumLags24kHz); ComputeSlidingFrameSquareEnergies24kHz(test_data.GetPitchBufView(), - y_energy_view); + y_energy_view, params.cpu_features); // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. // FloatingPointExceptionObserver fpe_observer; const auto computed_output = ComputeExtendedPitchPeriod48kHz( - test_data.GetPitchBufView(), y_energy_view, GetInitialPitchPeriod(), - {GetLastPitchPeriod(), GetLastPitchStrength()}); - EXPECT_EQ(GetExpectedPitchPeriod(), computed_output.period); - EXPECT_NEAR(GetExpectedPitchStrength(), computed_output.strength, 1e-6f); + test_data.GetPitchBufView(), y_energy_view, params.initial_pitch_period, + params.last_pitch, params.cpu_features); + EXPECT_EQ(params.expected_pitch.period, computed_output.period); + EXPECT_NEAR(params.expected_pitch.strength, computed_output.strength, 1e-6f); +} + +std::vector +CreateExtendedPitchPeriodSearchParameters() { + std::vector v; + for (AvailableCpuFeatures cpu_features : GetCpuFeaturesToTest()) { + for (int last_pitch_period : + {kTestPitchPeriodsLow, kTestPitchPeriodsHigh}) { + for (float last_pitch_strength : + {kTestPitchStrengthLow, kTestPitchStrengthHigh}) { + v.push_back({kTestPitchPeriodsLow, + {last_pitch_period, last_pitch_strength}, + {91, -0.0188608f}, + cpu_features}); + v.push_back({kTestPitchPeriodsHigh, + {last_pitch_period, last_pitch_strength}, + {475, -0.0904344f}, + cpu_features}); + } + } + } + return v; } INSTANTIATE_TEST_SUITE_P( RnnVadTest, ExtendedPitchPeriodSearchParametrizaion, - ::testing::Values(std::make_tuple(kTestPitchPeriodsLow, - kTestPitchPeriodsLow, - kTestPitchGainsLow, - 91, - -0.0188608f), - std::make_tuple(kTestPitchPeriodsLow, - kTestPitchPeriodsLow, - kTestPitchGainsHigh, - 91, - -0.0188608f), - std::make_tuple(kTestPitchPeriodsLow, - kTestPitchPeriodsHigh, - kTestPitchGainsLow, - 91, - -0.0188608f), - std::make_tuple(kTestPitchPeriodsLow, - kTestPitchPeriodsHigh, - kTestPitchGainsHigh, - 91, - -0.0188608f), - std::make_tuple(kTestPitchPeriodsHigh, - kTestPitchPeriodsLow, - kTestPitchGainsLow, - 475, - -0.0904344f), - std::make_tuple(kTestPitchPeriodsHigh, - kTestPitchPeriodsLow, - kTestPitchGainsHigh, - 475, - -0.0904344f), - std::make_tuple(kTestPitchPeriodsHigh, - kTestPitchPeriodsHigh, - kTestPitchGainsLow, - 475, - -0.0904344f), - std::make_tuple(kTestPitchPeriodsHigh, - kTestPitchPeriodsHigh, - kTestPitchGainsHigh, - 475, - -0.0904344f))); + ::testing::ValuesIn(CreateExtendedPitchPeriodSearchParameters()), + PrintTestIndexAndCpuFeatures); +} // namespace } // namespace test } // namespace rnn_vad } // namespace webrtc diff --git a/modules/audio_processing/agc2/rnn_vad/rnn_vad_unittest.cc b/modules/audio_processing/agc2/rnn_vad/rnn_vad_unittest.cc index 6036a00fd0..fa7795c20c 100644 --- a/modules/audio_processing/agc2/rnn_vad/rnn_vad_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/rnn_vad_unittest.cc @@ -163,10 +163,11 @@ std::vector GetCpuFeaturesToTest() { std::vector v; v.push_back({/*sse2=*/false, /*avx2=*/false, /*neon=*/false}); AvailableCpuFeatures available = GetAvailableCpuFeatures(); + if (available.avx2 && available.sse2) { + v.push_back({/*sse2=*/true, /*avx2=*/true, /*neon=*/false}); + } if (available.sse2) { - AvailableCpuFeatures features( - {/*sse2=*/true, /*avx2=*/false, /*neon=*/false}); - v.push_back(features); + v.push_back({/*sse2=*/true, /*avx2=*/false, /*neon=*/false}); } return v; }