From f22550175b9604d0f9f00ac461bd9b5f25e1277a Mon Sep 17 00:00:00 2001 From: Alessio Bazzica Date: Fri, 27 Apr 2018 16:44:11 +0200 Subject: [PATCH] AGC2 RNN VAD: Pitch Search MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Functions to estimate pitch period and gain. Bug: webrtc:9076 Change-Id: Icfe9430dcae11bdb96165c5bfe6e2b1d3bf848ab Reviewed-on: https://webrtc-review.googlesource.com/70382 Commit-Queue: Alex Loiko Reviewed-by: Per Åhgren Cr-Commit-Position: refs/heads/master@{#23066} --- .../audio_processing/agc2/rnn_vad/BUILD.gn | 7 + .../audio_processing/agc2/rnn_vad/common.h | 25 +- .../agc2/rnn_vad/lp_residual.cc | 26 +- .../agc2/rnn_vad/lp_residual_unittest.cc | 16 +- .../agc2/rnn_vad/pitch_info.h | 29 + .../agc2/rnn_vad/pitch_search.cc | 49 ++ .../agc2/rnn_vad/pitch_search.h | 29 + .../agc2/rnn_vad/pitch_search_internal.cc | 407 ++++++++++++++ .../agc2/rnn_vad/pitch_search_internal.h | 100 ++++ .../rnn_vad/pitch_search_internal_unittest.cc | 531 ++++++++++++++++++ .../agc2/rnn_vad/pitch_search_unittest.cc | 51 ++ .../agc2/rnn_vad/ring_buffer_unittest.cc | 5 +- .../agc2/rnn_vad/rnn_vad_tool.cc | 10 +- .../agc2/rnn_vad/sequence_buffer_unittest.cc | 5 +- .../symmetric_matrix_buffer_unittest.cc | 5 +- .../agc2/rnn_vad/test_utils.cc | 11 +- .../agc2/rnn_vad/test_utils.h | 7 +- 17 files changed, 1264 insertions(+), 49 deletions(-) create mode 100644 modules/audio_processing/agc2/rnn_vad/pitch_info.h create mode 100644 modules/audio_processing/agc2/rnn_vad/pitch_search.cc create mode 100644 modules/audio_processing/agc2/rnn_vad/pitch_search.h create mode 100644 modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc create mode 100644 modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h create mode 100644 modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc create mode 100644 modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc diff --git 
a/modules/audio_processing/agc2/rnn_vad/BUILD.gn b/modules/audio_processing/agc2/rnn_vad/BUILD.gn index bd7ff492b9..e05dcab604 100644 --- a/modules/audio_processing/agc2/rnn_vad/BUILD.gn +++ b/modules/audio_processing/agc2/rnn_vad/BUILD.gn @@ -19,6 +19,11 @@ source_set("lib") { "common.h", "lp_residual.cc", "lp_residual.h", + "pitch_info.h", + "pitch_search.cc", + "pitch_search.h", + "pitch_search_internal.cc", + "pitch_search_internal.h", "ring_buffer.h", "sequence_buffer.h", "symmetric_matrix_buffer.h", @@ -64,6 +69,8 @@ if (rtc_include_tests) { testonly = true sources = [ "lp_residual_unittest.cc", + "pitch_search_internal_unittest.cc", + "pitch_search_unittest.cc", "ring_buffer_unittest.cc", "sequence_buffer_unittest.cc", "symmetric_matrix_buffer_unittest.cc", diff --git a/modules/audio_processing/agc2/rnn_vad/common.h b/modules/audio_processing/agc2/rnn_vad/common.h index ec42a7b4cf..252bf8472c 100644 --- a/modules/audio_processing/agc2/rnn_vad/common.h +++ b/modules/audio_processing/agc2/rnn_vad/common.h @@ -19,11 +19,30 @@ constexpr size_t kFrameSize10ms24kHz = kSampleRate24kHz / 100; constexpr size_t kFrameSize20ms24kHz = kFrameSize10ms24kHz * 2; // Pitch analysis params. -constexpr size_t kPitchMinPeriod24kHz = kSampleRate24kHz / 800; // 0.00125 s. -constexpr size_t kPitchMaxPeriod24kHz = kSampleRate24kHz / 62.5; // 0.016 s. -constexpr size_t kBufSize24kHz = kPitchMaxPeriod24kHz + kFrameSize20ms24kHz; +constexpr size_t kMinPitch24kHz = kSampleRate24kHz / 800; // 0.00125 s. +constexpr size_t kMaxPitch24kHz = kSampleRate24kHz / 62.5; // 0.016 s. +constexpr size_t kBufSize24kHz = kMaxPitch24kHz + kFrameSize20ms24kHz; static_assert((kBufSize24kHz & 1) == 0, "The buffer size must be even."); +// Define a higher minimum pitch period for the initial search. This is used to +// avoid searching for very short periods, for which a refinement step is +// responsible. 
+constexpr size_t kInitialMinPitch24kHz = 3 * kMinPitch24kHz; +static_assert(kMinPitch24kHz < kInitialMinPitch24kHz, ""); +static_assert(kInitialMinPitch24kHz < kMaxPitch24kHz, ""); + +// 12 kHz analysis. +constexpr size_t kSampleRate12kHz = 12000; +constexpr size_t kFrameSize10ms12kHz = kSampleRate12kHz / 100; +constexpr size_t kFrameSize20ms12kHz = kFrameSize10ms12kHz * 2; +constexpr size_t kBufSize12kHz = kBufSize24kHz / 2; +constexpr size_t kInitialMinPitch12kHz = kInitialMinPitch24kHz / 2; +constexpr size_t kMaxPitch12kHz = kMaxPitch24kHz / 2; + +// 48 kHz constants. +constexpr size_t kMinPitch48kHz = kMinPitch24kHz * 2; +constexpr size_t kMaxPitch48kHz = kMaxPitch24kHz * 2; + } // namespace rnn_vad } // namespace webrtc diff --git a/modules/audio_processing/agc2/rnn_vad/lp_residual.cc b/modules/audio_processing/agc2/rnn_vad/lp_residual.cc index 63aec6b0b9..483336de93 100644 --- a/modules/audio_processing/agc2/rnn_vad/lp_residual.cc +++ b/modules/audio_processing/agc2/rnn_vad/lp_residual.cc @@ -90,22 +90,18 @@ void ComputeAndPostProcessLpcCoefficients( {auto_corr.data(), auto_corr.size()}, {lpc_coeffs_pre.data(), lpc_coeffs_pre.size()}); // LPC coefficients post-processing. - // TODO(https://bugs.webrtc.org/9076): Consider removing these steps. - { - float c = 1.f; - for (size_t i = 0; i < kNumLpcCoefficients - 1; ++i) { - c *= 0.9f; - lpc_coeffs_pre[i] *= c; - } - } - { - const float c = 0.8f; - lpc_coeffs[0] = lpc_coeffs_pre[0] + c; - lpc_coeffs[1] = lpc_coeffs_pre[1] + c * lpc_coeffs_pre[0]; - lpc_coeffs[2] = lpc_coeffs_pre[2] + c * lpc_coeffs_pre[1]; - lpc_coeffs[3] = lpc_coeffs_pre[3] + c * lpc_coeffs_pre[2]; - lpc_coeffs[4] = c * lpc_coeffs_pre[3]; + // TODO(bugs.webrtc.org/9076): Consider removing these steps. 
+ float c1 = 1.f; + for (size_t i = 0; i < kNumLpcCoefficients - 1; ++i) { + c1 *= 0.9f; + lpc_coeffs_pre[i] *= c1; } + const float c2 = 0.8f; + lpc_coeffs[0] = lpc_coeffs_pre[0] + c2; + lpc_coeffs[1] = lpc_coeffs_pre[1] + c2 * lpc_coeffs_pre[0]; + lpc_coeffs[2] = lpc_coeffs_pre[2] + c2 * lpc_coeffs_pre[1]; + lpc_coeffs[3] = lpc_coeffs_pre[3] + c2 * lpc_coeffs_pre[2]; + lpc_coeffs[4] = c2 * lpc_coeffs_pre[3]; } void ComputeLpResidual( diff --git a/modules/audio_processing/agc2/rnn_vad/lp_residual_unittest.cc b/modules/audio_processing/agc2/rnn_vad/lp_residual_unittest.cc index 41ffe68e0a..23f1e14c22 100644 --- a/modules/audio_processing/agc2/rnn_vad/lp_residual_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/lp_residual_unittest.cc @@ -14,21 +14,16 @@ #include "modules/audio_processing/agc2/rnn_vad/common.h" #include "modules/audio_processing/agc2/rnn_vad/test_utils.h" -// TODO(https://bugs.webrtc.org/8948): Add when the issue is fixed. +// TODO(bugs.webrtc.org/8948): Add when the issue is fixed. // #include "test/fpe_observer.h" #include "test/gtest.h" namespace webrtc { +namespace rnn_vad { namespace test { -using rnn_vad::ComputeAndPostProcessLpcCoefficients; -using rnn_vad::ComputeLpResidual; -using rnn_vad::kBufSize24kHz; -using rnn_vad::kFrameSize10ms24kHz; -using rnn_vad::kNumLpcCoefficients; - TEST(RnnVadTest, LpResidualOfEmptyFrame) { - // TODO(https://bugs.webrtc.org/8948): Add when the issue is fixed. + // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. // FloatingPointExceptionObserver fpe_observer; // Input frame (empty, i.e., all samples set to 0). @@ -44,7 +39,7 @@ TEST(RnnVadTest, LpResidualOfEmptyFrame) { {lp_residual}); } -// TODO(https://bugs.webrtc.org/9076): Remove when the issue is fixed. +// TODO(bugs.webrtc.org/9076): Remove when the issue is fixed. TEST(RnnVadTest, LpResidualPipelineBitExactness) { // Pitch buffer 24 kHz data reader. 
auto pitch_buf_24kHz_reader = CreatePitchBuffer24kHzReader(); @@ -66,7 +61,7 @@ TEST(RnnVadTest, LpResidualPipelineBitExactness) { rtc::ArrayView computed_lp_residual_view( computed_lp_residual.data(), computed_lp_residual.size()); { - // TODO(https://bugs.webrtc.org/8948): Add when the issue is fixed. + // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. // FloatingPointExceptionObserver fpe_observer; for (size_t i = 0; i < num_frames; ++i) { @@ -91,4 +86,5 @@ TEST(RnnVadTest, LpResidualPipelineBitExactness) { } } // namespace test +} // namespace rnn_vad } // namespace webrtc diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_info.h b/modules/audio_processing/agc2/rnn_vad/pitch_info.h new file mode 100644 index 0000000000..f0998d1fad --- /dev/null +++ b/modules/audio_processing/agc2/rnn_vad/pitch_info.h @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_INFO_H_ +#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_INFO_H_ + +namespace webrtc { +namespace rnn_vad { + +// Stores pitch period and gain information. The pitch gain measures the +// strength of the pitch (the higher, the stronger). 
+struct PitchInfo { + PitchInfo() : period(0), gain(0.f) {} + PitchInfo(size_t p, float g) : period(p), gain(g) {} + size_t period; + float gain; +}; + +} // namespace rnn_vad +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_INFO_H_ diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search.cc new file mode 100644 index 0000000000..4d83588cb3 --- /dev/null +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search.cc @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/agc2/rnn_vad/pitch_search.h" +#include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h" + +namespace webrtc { +namespace rnn_vad { + +PitchInfo PitchSearch(rtc::ArrayView pitch_buf, + PitchInfo prev_pitch_48kHz) { + // Perform the initial pitch search at 12 kHz. + std::array pitch_buf_decimated; + Decimate2x(pitch_buf, + {pitch_buf_decimated.data(), pitch_buf_decimated.size()}); + // Compute auto-correlation terms. + std::array auto_corr; + ComputePitchAutoCorrelation( + {pitch_buf_decimated.data(), pitch_buf_decimated.size()}, kMaxPitch12kHz, + {auto_corr.data(), auto_corr.size()}); + // Search for pitch at 12 kHz. + std::array pitch_candidates_inv_lags = FindBestPitchPeriods( + {auto_corr.data(), auto_corr.size()}, + {pitch_buf_decimated.data(), pitch_buf_decimated.size()}, kMaxPitch12kHz); + + // Refine the pitch period estimation. + // The refinement is done using the pitch buffer that contains 24 kHz samples. 
+ // Therefore, adapt the inverted lags in |pitch_candidates_inv_lags| from 12 + // to 24 kHz. + for (size_t i = 0; i < pitch_candidates_inv_lags.size(); ++i) + pitch_candidates_inv_lags[i] *= 2; + size_t pitch_inv_lag_48kHz = RefinePitchPeriod48kHz( + pitch_buf, + {pitch_candidates_inv_lags.data(), pitch_candidates_inv_lags.size()}); + // Look for stronger harmonics to find the final pitch period and its gain. + RTC_DCHECK_LT(pitch_inv_lag_48kHz, kMaxPitch48kHz); + return CheckLowerPitchPeriodsAndComputePitchGain( + pitch_buf, kMaxPitch48kHz - pitch_inv_lag_48kHz, prev_pitch_48kHz); +} + +} // namespace rnn_vad +} // namespace webrtc diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search.h b/modules/audio_processing/agc2/rnn_vad/pitch_search.h new file mode 100644 index 0000000000..a0af0ebfa2 --- /dev/null +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search.h @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_SEARCH_H_ +#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_SEARCH_H_ + +#include "api/array_view.h" +#include "modules/audio_processing/agc2/rnn_vad/common.h" +#include "modules/audio_processing/agc2/rnn_vad/pitch_info.h" + +namespace webrtc { +namespace rnn_vad { + +// Searches the pitch period and gain. Return the pitch estimation data for +// 48 kHz. 
+PitchInfo PitchSearch(rtc::ArrayView pitch_buf, + PitchInfo prev_pitch_48kHz); + +} // namespace rnn_vad +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_SEARCH_H_ diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc new file mode 100644 index 0000000000..1ff4621b28 --- /dev/null +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc @@ -0,0 +1,407 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h" + +#include +#include +#include +#include + +#include "modules/audio_processing/agc2/rnn_vad/common.h" +#include "rtc_base/checks.h" + +namespace webrtc { +namespace rnn_vad { +namespace { + +// Converts a lag to an inverted lag (only for 24kHz). +size_t GetInvertedLag(size_t lag) { + RTC_DCHECK_LE(lag, kMaxPitch24kHz); + return kMaxPitch24kHz - lag; +} + +float ComputeAutoCorrelationCoeff(rtc::ArrayView pitch_buf, + size_t inv_lag, + size_t max_pitch_period) { + RTC_DCHECK_LT(inv_lag, pitch_buf.size()); + RTC_DCHECK_LT(max_pitch_period, pitch_buf.size()); + RTC_DCHECK_LE(inv_lag, max_pitch_period); + // TODO(bugs.webrtc.org/9076): Maybe optimize using vectorization. + return std::inner_product(pitch_buf.begin() + max_pitch_period, + pitch_buf.end(), pitch_buf.begin() + inv_lag, 0.f); +} + +// Computes a pseudo-interpolation offset for an estimated pitch period |lag| by +// looking at the auto-correlation coefficients in the neighborhood of |lag|. 
+// (namely, |prev_auto_corr|, |lag_auto_corr| and |next_auto_corr|). The output +// is a lag offset in {-1, 0, +1}. +// TODO(bugs.webrtc.org/9076): Consider removing pseudo-interpolation since it +// is relevant only if the spectral analysis works at a sample rate that is +// twice that of the pitch buffer (not so important instead for the estimated +// pitch period feature fed into the RNN). +int GetPitchPseudoInterpolationOffset(size_t lag, + float prev_auto_corr, + float lag_auto_corr, + float next_auto_corr) { + const float& a = prev_auto_corr; + const float& b = lag_auto_corr; + const float& c = next_auto_corr; + + int offset = 0; + if ((c - a) > 0.7f * (b - a)) { + offset = 1; // |c| is the largest auto-correlation coefficient. + } else if ((a - c) > 0.7f * (b - c)) { + offset = -1; // |a| is the largest auto-correlation coefficient. + } + return offset; +} + +// Refines a pitch period |lag| encoded as lag with pseudo-interpolation. The +// output sample rate is twice that of |lag|. +size_t PitchPseudoInterpolationLagPitchBuf( + size_t lag, + rtc::ArrayView pitch_buf) { + int offset = 0; + // Cannot apply pseudo-interpolation at the boundaries. + if (lag > 0 && lag < kMaxPitch24kHz) { + offset = GetPitchPseudoInterpolationOffset( + lag, + ComputeAutoCorrelationCoeff(pitch_buf, GetInvertedLag(lag - 1), + kMaxPitch24kHz), + ComputeAutoCorrelationCoeff(pitch_buf, GetInvertedLag(lag), + kMaxPitch24kHz), + ComputeAutoCorrelationCoeff(pitch_buf, GetInvertedLag(lag + 1), + kMaxPitch24kHz)); + } + return 2 * lag + offset; +} + +// Refines a pitch period |inv_lag| encoded as inverted lag with +// pseudo-interpolation. The output sample rate is twice that of +// |inv_lag|. +size_t PitchPseudoInterpolationInvLagAutoCorr( + size_t inv_lag, + rtc::ArrayView auto_corr) { + int offset = 0; + // Cannot apply pseudo-interpolation at the boundaries. 
+ if (inv_lag > 0 && inv_lag < auto_corr.size() - 1) { + offset = GetPitchPseudoInterpolationOffset(inv_lag, auto_corr[inv_lag + 1], + auto_corr[inv_lag], + auto_corr[inv_lag - 1]); + } + // TODO(bugs.webrtc.org/9076): When retraining, check if |offset| below should + // be subtracted since |inv_lag| is an inverted lag but offset is a lag. + return 2 * inv_lag + offset; +} + +// Integer multipliers used in CheckLowerPitchPeriodsAndComputePitchGain() when +// looking for sub-harmonics. +// The values have been chosen to serve the following algorithm. Given the +// initial pitch period T, we examine whether one of its harmonics is the true +// fundamental frequency. We consider T/k with k in {2, ..., 15}. For each of +// these harmonics, in addition to the pitch gain of itself, we choose one +// multiple of its pitch period, n*T/k, to validate it (by averaging their pitch +// gains). The multiplier n is chosen so that n*T/k is used only one time over +// all k. When for example k = 4, we should also expect a peak at 3*T/4. When +// k = 8 instead we don't want to look at 2*T/8, since we have already checked +// T/4 before. Instead, we look at T*3/8. +// The array can be generated in Python as follows: +// from fractions import Fraction +// # Smallest positive integer not in X. +// def mex(X): +// for i in range(1, int(max(X)+2)): +// if i not in X: +// return i +// # Visited multiples of the period. +// S = {1} +// for n in range(2, 16): +// sn = mex({n * i for i in S} | {1}) +// S = S | {Fraction(1, n), Fraction(sn, n)} +// print(sn, end=', ') +constexpr std::array kSubHarmonicMultipliers = { + {3, 2, 3, 2, 5, 2, 3, 2, 3, 2, 5, 2, 3, 2}}; + +// Initial pitch period candidate thresholds for ComputePitchGainThreshold() for +// a sample rate of 24 kHz. Computed as [5*k*k for k in range(2, 16)]. 
+constexpr std::array kInitialPitchPeriodThresholds = { + {20, 45, 80, 125, 180, 245, 320, 405, 500, 605, 720, 845, 980, 1125}}; + +} // namespace + +void Decimate2x(rtc::ArrayView src, + rtc::ArrayView dst) { + // TODO(bugs.webrtc.org/9076): Consider adding anti-aliasing filter. + static_assert(2 * dst.size() == src.size(), ""); + for (size_t i = 0; i < dst.size(); ++i) + dst[i] = src[2 * i]; +} + +float ComputePitchGainThreshold(size_t candidate_pitch_period, + size_t pitch_period_ratio, + size_t initial_pitch_period, + float initial_pitch_gain, + size_t prev_pitch_period, + size_t prev_pitch_gain) { + // Map arguments to more compact aliases. NOTE(review): |prev_pitch_gain| is a size_t although PitchInfo::gain is a float, so the gain passed by CheckLowerPitchPeriodsAndComputePitchGain() is truncated to an integer before it is used as |g_prev| - confirm whether float was intended (the declaration in pitch_search_internal.h would have to change as well). + const size_t& t1 = candidate_pitch_period; + const size_t& k = pitch_period_ratio; + const size_t& t0 = initial_pitch_period; + const float& g0 = initial_pitch_gain; + const size_t& t_prev = prev_pitch_period; + const size_t& g_prev = prev_pitch_gain; + + // Validate input. + RTC_DCHECK_GE(k, 2); + + // Compute a term that lowers the threshold when |t1| is close to the last + // estimated period |t_prev| - i.e., pitch tracking. + float lower_threshold_term = 0; + if (abs(static_cast(t1) - static_cast(t_prev)) <= 1) { + // The candidate pitch period is within 1 sample from the previous one. + // Make the candidate at |t1| very easy to be accepted. + lower_threshold_term = g_prev; + } else if (abs(static_cast(t1) - static_cast(t_prev)) == 2 && + t0 > kInitialPitchPeriodThresholds[k - 2]) { + // The candidate pitch period is 2 samples far from the previous one and the + // period |t0| (from which |t1| has been derived) is greater than a + // threshold. Make |t1| easy to be accepted. + lower_threshold_term = 0.5f * g_prev; + } + // Set the threshold based on the gain of the initial estimate |t0|. Also + // reduce the chance of false positives caused by a bias towards high + // frequencies (originating from short-term correlations). 
+ float threshold = std::max(0.3f, 0.7f * g0 - lower_threshold_term); + if (t1 < 3 * kMinPitch24kHz) { // High frequency. + threshold = std::max(0.4f, 0.85f * g0 - lower_threshold_term); + } else if (t1 < 2 * kMinPitch24kHz) { // Even higher frequency. NOTE(review): unreachable as written, since t1 < 2 * kMinPitch24kHz implies t1 < 3 * kMinPitch24kHz and the previous branch is taken first; swapping the two checks would change the returned threshold, so confirm the intended behavior before reordering. + threshold = std::max(0.5f, 0.9f * g0 - lower_threshold_term); + } + return threshold; +} + +void ComputeSlidingFrameSquareEnergies( + rtc::ArrayView pitch_buf, + rtc::ArrayView yy_values) { + float yy = + ComputeAutoCorrelationCoeff(pitch_buf, kMaxPitch24kHz, kMaxPitch24kHz); + yy_values[0] = yy; + for (size_t i = 1; i < yy_values.size(); ++i) { + RTC_DCHECK_LE(i, kMaxPitch24kHz + kFrameSize20ms24kHz); + RTC_DCHECK_LE(i, kMaxPitch24kHz); + const float old_coeff = pitch_buf[kMaxPitch24kHz + kFrameSize20ms24kHz - i]; + const float new_coeff = pitch_buf[kMaxPitch24kHz - i]; + yy -= old_coeff * old_coeff; + yy += new_coeff * new_coeff; + yy = std::max(0.f, yy); + yy_values[i] = yy; + } +} + +// TODO(bugs.webrtc.org/9076): Optimize using FFT and/or vectorization. +void ComputePitchAutoCorrelation( + rtc::ArrayView pitch_buf, + size_t max_pitch_period, + rtc::ArrayView auto_corr) { + RTC_DCHECK_GT(max_pitch_period, auto_corr.size()); + RTC_DCHECK_LT(max_pitch_period, pitch_buf.size()); + // Compute auto-correlation coefficients. + for (size_t inv_lag = 0; inv_lag < auto_corr.size(); ++inv_lag) { + auto_corr[inv_lag] = + ComputeAutoCorrelationCoeff(pitch_buf, inv_lag, max_pitch_period); + } +} + +std::array FindBestPitchPeriods( + rtc::ArrayView auto_corr, + rtc::ArrayView pitch_buf, + size_t max_pitch_period) { + // Stores a pitch candidate period and strength information. + struct PitchCandidate { + // Pitch period encoded as inverted lag. + size_t period_inverted_lag = 0; + // Pitch strength encoded as a ratio. + float strength_numerator = -1.f; + float strength_denominator = 0.f; + // Compare the strength of two pitch candidates. 
+ bool HasStrongerPitchThan(const PitchCandidate& b) const { + // Comparing the numerator/denominator ratios without using divisions. + return strength_numerator * b.strength_denominator > + b.strength_numerator * strength_denominator; + } + }; + + RTC_DCHECK_GT(max_pitch_period, auto_corr.size()); + RTC_DCHECK_LT(max_pitch_period, pitch_buf.size()); + const size_t frame_size = pitch_buf.size() - max_pitch_period; + // TODO(bugs.webrtc.org/9076): Maybe optimize using vectorization. + float yy = + std::inner_product(pitch_buf.begin(), pitch_buf.begin() + frame_size + 1, + pitch_buf.begin(), 1.f); + // Search best and second best pitches by looking at the scaled + // auto-correlation. + PitchCandidate candidate; + PitchCandidate best; + PitchCandidate second_best; + second_best.period_inverted_lag = 1; + for (size_t inv_lag = 0; inv_lag < auto_corr.size(); ++inv_lag) { + // A pitch candidate must have positive correlation. + if (auto_corr[inv_lag] > 0) { + candidate.period_inverted_lag = inv_lag; + candidate.strength_numerator = auto_corr[inv_lag] * auto_corr[inv_lag]; + candidate.strength_denominator = yy; + if (candidate.HasStrongerPitchThan(second_best)) { + if (candidate.HasStrongerPitchThan(best)) { + second_best = best; + best = candidate; + } else { + second_best = candidate; + } + } + } + // Update |yy| (sliding frame energy) for the next inverted lag. + const float old_coeff = pitch_buf[inv_lag]; + const float new_coeff = pitch_buf[inv_lag + frame_size]; + yy -= old_coeff * old_coeff; + yy += new_coeff * new_coeff; + yy = std::max(0.f, yy); + } + return {{best.period_inverted_lag, second_best.period_inverted_lag}}; +} + +size_t RefinePitchPeriod48kHz( + rtc::ArrayView pitch_buf, + rtc::ArrayView inv_lags) { + // Compute the auto-correlation terms only for neighbors of the given pitch + // candidates (similar to what is done in ComputePitchAutoCorrelation(), but + // for a few lag values). 
+ std::array auto_corr; + auto_corr.fill(0.f); // Zeros become ignored lags in FindBestPitchPeriods(). + auto is_neighbor = [](size_t i, size_t j) { + return ((i > j) ? (i - j) : (j - i)) <= 2; + }; + for (size_t inv_lag = 0; inv_lag < auto_corr.size(); ++inv_lag) { + if (is_neighbor(inv_lag, inv_lags[0]) || is_neighbor(inv_lag, inv_lags[1])) + auto_corr[inv_lag] = + ComputeAutoCorrelationCoeff(pitch_buf, inv_lag, kMaxPitch24kHz); + } + // Find best pitch at 24 kHz. + const auto pitch_candidates_inv_lags = FindBestPitchPeriods( + {auto_corr.data(), auto_corr.size()}, + {pitch_buf.data(), pitch_buf.size()}, kMaxPitch24kHz); + const auto inv_lag = pitch_candidates_inv_lags[0]; // Refine the best. + // Pseudo-interpolation. + return PitchPseudoInterpolationInvLagAutoCorr( + inv_lag, {auto_corr.data(), auto_corr.size()}); +} + +PitchInfo CheckLowerPitchPeriodsAndComputePitchGain( + rtc::ArrayView pitch_buf, + size_t initial_pitch_period_48kHz, + PitchInfo prev_pitch_48kHz) { + RTC_DCHECK_LE(kMinPitch48kHz, initial_pitch_period_48kHz); + RTC_DCHECK_LE(initial_pitch_period_48kHz, kMaxPitch48kHz); + // Stores information for a refined pitch candidate. + struct RefinedPitchCandidate { + RefinedPitchCandidate() {} + RefinedPitchCandidate(size_t period_24kHz, float gain, float xy, float yy) + : period_24kHz(period_24kHz), gain(gain), xy(xy), yy(yy) {} + size_t period_24kHz; + // Pitch strength information. + float gain; + // Additional pitch strength information used for the final estimation of + // pitch gain. + float xy; // Cross-correlation. + float yy; // Auto-correlation. + }; + + // Initialize. + std::array yy_values; + ComputeSlidingFrameSquareEnergies(pitch_buf, + {yy_values.data(), yy_values.size()}); + const float xx = yy_values[0]; + // Helper lambdas. + const auto pitch_gain = [](float xy, float yy, float xx) { + RTC_DCHECK_LE(0.f, xx * yy); + return xy / std::sqrt(1.f + xx * yy); + }; + // Initial pitch candidate gain. 
+ RefinedPitchCandidate best_pitch; + best_pitch.period_24kHz = + std::min(initial_pitch_period_48kHz / 2, kMaxPitch24kHz - 1); + best_pitch.xy = ComputeAutoCorrelationCoeff( + pitch_buf, GetInvertedLag(best_pitch.period_24kHz), kMaxPitch24kHz); + best_pitch.yy = yy_values[best_pitch.period_24kHz]; + best_pitch.gain = pitch_gain(best_pitch.xy, best_pitch.yy, xx); + + // Store the initial pitch period information. + const size_t initial_pitch_period = best_pitch.period_24kHz; + const float initial_pitch_gain = best_pitch.gain; + + // Given the initial pitch estimation, check lower periods (i.e., harmonics). + const auto alternative_period = [](size_t period, size_t k, + size_t n) -> size_t { + RTC_DCHECK_LT(0, k); + return (2 * n * period + k) / (2 * k); // Same as round(n*period/k). + }; + for (size_t k = 2; k < kSubHarmonicMultipliers.size() + 2; ++k) { + size_t candidate_pitch_period = + alternative_period(initial_pitch_period, k, 1); + if (candidate_pitch_period < kMinPitch24kHz) + break; + // When looking at |candidate_pitch_period|, we also look at one of its + // sub-harmonics. |kSubHarmonicMultipliers| is used to know where to look. + // |k| == 2 is a special case since |candidate_pitch_secondary_period| might + // be greater than the maximum pitch period. + size_t candidate_pitch_secondary_period = alternative_period( + initial_pitch_period, k, kSubHarmonicMultipliers[k - 2]); + if (k == 2 && candidate_pitch_secondary_period > kMaxPitch24kHz) + candidate_pitch_secondary_period = initial_pitch_period; + RTC_DCHECK_NE(candidate_pitch_period, candidate_pitch_secondary_period) + << "The lower pitch period and the additional sub-harmonic must not " + << "coincide."; + // Compute an auto-correlation score for the primary pitch candidate + // |candidate_pitch_period| by also looking at its possible sub-harmonic + // |candidate_pitch_secondary_period|. 
+ float xy_primary_period = ComputeAutoCorrelationCoeff( + pitch_buf, GetInvertedLag(candidate_pitch_period), kMaxPitch24kHz); + float xy_secondary_period = ComputeAutoCorrelationCoeff( + pitch_buf, GetInvertedLag(candidate_pitch_secondary_period), + kMaxPitch24kHz); + float xy = 0.5f * (xy_primary_period + xy_secondary_period); + float yy = 0.5f * (yy_values[candidate_pitch_period] + + yy_values[candidate_pitch_secondary_period]); + float candidate_pitch_gain = pitch_gain(xy, yy, xx); + + // Maybe update best period. + float threshold = ComputePitchGainThreshold( + candidate_pitch_period, k, initial_pitch_period, initial_pitch_gain, + prev_pitch_48kHz.period / 2, prev_pitch_48kHz.gain); + if (candidate_pitch_gain > threshold) { + best_pitch = {candidate_pitch_period, candidate_pitch_gain, xy, yy}; + } + } + + // Final pitch gain and period. + best_pitch.xy = std::max(0.f, best_pitch.xy); + RTC_DCHECK_LE(0.f, best_pitch.yy); + float final_pitch_gain = (best_pitch.yy <= best_pitch.xy) + ? 1.f + : best_pitch.xy / (best_pitch.yy + 1.f); + final_pitch_gain = std::min(best_pitch.gain, final_pitch_gain); + size_t final_pitch_period_48kHz = std::max( + kMinPitch48kHz, + PitchPseudoInterpolationLagPitchBuf(best_pitch.period_24kHz, pitch_buf)); + + return {final_pitch_period_48kHz, final_pitch_gain}; +} + +} // namespace rnn_vad +} // namespace webrtc diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h new file mode 100644 index 0000000000..dfe1b35ff7 --- /dev/null +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h @@ -0,0 +1,100 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. 
All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_SEARCH_INTERNAL_H_ +#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_SEARCH_INTERNAL_H_ + +#include + +#include "api/array_view.h" +#include "modules/audio_processing/agc2/rnn_vad/common.h" +#include "modules/audio_processing/agc2/rnn_vad/pitch_info.h" + +namespace webrtc { +namespace rnn_vad { + +// The inverted lags for the pitch interval [|kInitialMinPitch12kHz|, +// |kMaxPitch12kHz|] are in the range [0, |kNumInvertedLags|]. +static_assert(kMaxPitch12kHz > kInitialMinPitch12kHz, ""); +static_assert(kMaxPitch24kHz > kInitialMinPitch24kHz, ""); +constexpr size_t kNumInvertedLags12kHz = kMaxPitch12kHz - kInitialMinPitch12kHz; +constexpr size_t kNumInvertedLags24kHz = kMaxPitch24kHz - kInitialMinPitch24kHz; + +// Performs 2x decimation without any anti-aliasing filter. +void Decimate2x(rtc::ArrayView src, + rtc::ArrayView dst); + +// Computes a gain threshold for a candidate pitch period given the initial and +// the previous pitch period and gain estimates and the pitch period ratio used +// to derive the candidate pitch period from the initial period. +float ComputePitchGainThreshold(size_t candidate_pitch_period, + size_t pitch_period_ratio, + size_t initial_pitch_period, + float initial_pitch_gain, + size_t prev_pitch_period, + size_t prev_pitch_gain); + +// Computes the sum of squared samples for every sliding frame in the pitch +// buffer. |yy_values| indexes are lags. +// +// The pitch buffer is structured as depicted below: +// |.........|...........| +// a b +// The part on the left, named "a" contains the oldest samples, whereas "b" the +// most recent ones. The size of "a" corresponds to the maximum pitch period, +// that of "b" to the frame size (e.g., 16 ms and 20 ms respectively). 
+void ComputeSlidingFrameSquareEnergies( + rtc::ArrayView pitch_buf, + rtc::ArrayView yy_values); + +// Computes the auto-correlation coefficients for a given pitch interval. +// |auto_corr| indexes are inverted lags. +// +// The auto-correlations coefficients are computed as follows: +// |.........|...........| <- pitch buffer +// [ x (fixed) ] +// [ y_0 ] +// [ y_{m-1} ] +// x and y are sub-array of equal length; x is never moved, whereas y slides. +// The cross-correlation between y_0 and x corresponds to the auto-correlation +// for the maximum pitch period. Hence, the first value in |auto_corr| has an +// inverted lag equal to 0 that corresponds to a lag equal to the maximum pitch +// period. +void ComputePitchAutoCorrelation( + rtc::ArrayView pitch_buf, + size_t max_pitch_period, + rtc::ArrayView auto_corr); + +// Given the auto-correlation coefficients stored according to +// ComputePitchAutoCorrelation() (i.e., using inverted lags), returns the best +// and the second best pitch periods. +std::array FindBestPitchPeriods( + rtc::ArrayView auto_corr, + rtc::ArrayView pitch_buf, + size_t max_pitch_period); + +// Refines the pitch period estimation given the pitch buffer |pitch_buf| and +// the initial pitch period estimation |inv_lags|. Returns an inverted lag at +// 48 kHz. +size_t RefinePitchPeriod48kHz( + rtc::ArrayView pitch_buf, + rtc::ArrayView inv_lags); + +// Refines the pitch period estimation and compute the pitch gain. Returns the +// refined pitch estimation data at 48 kHz. 
+PitchInfo CheckLowerPitchPeriodsAndComputePitchGain( + rtc::ArrayView pitch_buf, + size_t initial_pitch_period_48kHz, + PitchInfo prev_pitch_48kHz); + +} // namespace rnn_vad +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_SEARCH_INTERNAL_H_ diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc new file mode 100644 index 0000000000..9a6a2676c6 --- /dev/null +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc @@ -0,0 +1,531 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h" + +#include +#include + +#include "modules/audio_processing/agc2/rnn_vad/test_utils.h" +// TODO(bugs.webrtc.org/8948): Add when the issue is fixed. +// #include "test/fpe_observer.h" +#include "test/gtest.h" + +namespace webrtc { +namespace rnn_vad { +namespace test { +namespace { + +// TODO(bugs.webrtc.org/9076): Move to resource file. 
+constexpr std::array kPitchBufferData = { + -35.248100281f, -25.836528778f, 5.682674408f, 2.880297661f, -1.648161888f, + -4.094896793f, -3.500580072f, -0.896141529f, -2.989939451f, -4.608089447f, + -3.721750736f, -2.290785789f, -3.326566458f, -4.370154381f, -3.221047878f, + -4.049056530f, -2.846302271f, -1.805017233f, -1.547624588f, -0.809937477f, + -1.446955442f, -3.258146763f, -1.849959373f, 0.005283833f, -0.571619749f, + -0.630573988f, -0.162780523f, -2.699024916f, -0.856231451f, 2.748089552f, + 2.026614428f, -0.474685907f, -0.571918726f, 1.186420918f, 1.770769954f, + 2.017296791f, 1.154794335f, 1.082345366f, 1.954892635f, 2.249727726f, + 2.643483400f, 1.857815385f, 0.064472735f, 0.015978813f, 0.301099658f, + 0.478950322f, -0.669701457f, -0.654453993f, 1.338572979f, -0.493052602f, + -1.763812065f, 0.524392128f, 0.010438919f, -1.726593733f, -2.866710663f, + -2.065258503f, -3.010460854f, -3.994765282f, -4.102010250f, -3.135548830f, + -2.597487926f, -2.255330563f, -1.002008915f, 0.523116589f, 1.430158496f, + -1.655169368f, -2.263641357f, 0.766040802f, 1.166070461f, 0.002490997f, + 0.401043415f, -0.158550858f, -0.572042346f, 1.365390539f, -1.397871614f, + -2.020734787f, -1.979169965f, -1.025816441f, 0.012545407f, -1.042758584f, + -1.206598401f, -1.140330791f, -3.060853720f, -3.530077934f, -1.774474382f, + -1.342000484f, -3.171817064f, -2.489153862f, -1.593364000f, -2.552185535f, + -2.899760723f, -4.698278427f, -4.123534203f, -2.613421679f, -2.061793327f, + -4.113687515f, -3.174087524f, -2.367874622f, -4.523970604f, -4.250762939f, + -2.752931118f, -1.547106743f, -4.109455109f, -3.893044472f, -2.348384857f, + -3.194510698f, -3.502159357f, -2.785978794f, -1.981978416f, -3.279178143f, + -3.007923365f, -1.801304340f, -1.839247227f, -1.003675938f, -0.985928297f, + -1.647925615f, -2.166392088f, -1.947163343f, 0.488545895f, 1.567199469f, + -1.179960012f, -2.710370064f, -2.613196850f, -3.205850124f, -2.796218395f, + -0.715085745f, 1.406243801f, -0.779834270f, -2.075612307f, 
-0.922246933f, + -1.849850416f, 0.979040504f, 3.570628166f, 0.945924520f, -2.821768284f, + -6.262358189f, -6.154916763f, -0.567943573f, 2.386518955f, 1.673806906f, + -3.676584721f, -7.129202843f, -3.311969519f, 1.126702785f, 3.218248606f, + 1.600885630f, -1.709451079f, -6.822564125f, -6.011950970f, -0.671678543f, + 1.080205441f, -1.342422366f, -3.589303732f, -3.586701870f, -3.425134897f, + -1.078015327f, 2.556719542f, 0.469867468f, 0.139251709f, -0.118916273f, + -1.284181952f, 0.941113472f, 0.550188303f, -1.767568469f, -5.429461956f, + -5.065113068f, -2.111886740f, -3.606999397f, -2.410579205f, 1.013466120f, + 1.057218194f, 0.305267453f, 2.898609161f, 5.776575565f, 4.792305946f, + -0.863526106f, -2.439013481f, -0.825202525f, -2.297998428f, -0.520106375f, + -0.653605103f, -3.204111576f, -2.455038786f, -2.160304308f, 0.622359931f, + 3.803062916f, 4.340928555f, 2.390868664f, 1.645600080f, 0.405841053f, + -0.153203994f, 3.438643217f, 4.752261162f, 1.552502871f, 1.947945356f, + 0.856451511f, -0.606808305f, -1.223945618f, -1.845071912f, -0.204472303f, + 1.750840783f, 2.435559034f, -1.253612280f, -2.675215721f, 1.614801407f, + 3.002861023f, 1.743503809f, 3.409059286f, 4.303173542f, 2.441751957f, + 1.752274275f, 1.874113560f, 2.070837736f, 1.401355743f, -0.330647945f, + -0.664121151f, 1.196543574f, 1.506967187f, 0.985752344f, -1.265938520f, + -1.433794141f, 0.380195618f, 0.061504841f, 1.079771042f, 1.773771763f, + 3.226663589f, 4.170571804f, 4.220288277f, 3.619904041f, 2.316211224f, + 2.012817860f, 0.370972633f, 0.517094851f, 1.869508862f, 0.357770681f, + -2.991472483f, -3.216646433f, 0.232109070f, 1.803660274f, 2.928784370f, + 4.909455776f, 5.913621426f, 4.653719902f, 4.387111187f, 4.793289661f, + 4.744520187f, 5.214610100f, 3.996322632f, 2.619040728f, 0.758128643f, + -0.092789888f, 0.070066452f, 0.704165459f, 2.042234898f, 2.768569231f, + 3.340583324f, 3.212181091f, 2.748130322f, 3.077554941f, 2.189792156f, + 2.646749735f, 2.817450523f, 1.611892223f, 1.981805444f, 
-1.088236094f, + -2.187484741f, -0.654897690f, -0.900939941f, 0.148309708f, 1.498139143f, + -0.261296749f, -3.220157146f, -1.727450609f, 0.807144105f, -0.809251904f, + -2.361308336f, -1.421746969f, -0.793132067f, -0.313778281f, -0.641793191f, + -0.999286890f, 0.219423503f, 0.976444781f, 0.152786255f, -0.405437022f, + 0.120257735f, -0.392024517f, -0.019678771f, 1.492373466f, 0.926774263f, + 0.566291928f, 1.307234287f, 1.496955752f, 1.448441863f, 2.212901354f, + 1.314700723f, 0.213681281f, 1.011370897f, 1.827155828f, 0.250772655f, + -0.429592669f, 0.435638547f, 1.506532907f, 1.350761652f, -0.387142301f, + -1.770648122f, -2.690037489f, -1.788924456f, -2.023291588f, -2.354584694f, + -2.587521076f, -2.002159595f, -0.355855435f, 0.825611115f, 3.075081587f, + 2.687968254f, 0.074088633f, 0.439936757f, 1.214704275f, 2.670343399f, + 1.567362547f, -1.573154926f, -3.216549397f, -3.596383333f, -3.893716335f, + -2.456265688f, -4.313135624f, -5.783064842f, -5.344826221f, -3.484399319f, + -2.235594273f, -3.568959475f, -2.447141886f, -0.755384564f, -1.178364277f, + 1.034289122f, 1.746821165f, -1.159413576f, -2.569937706f, -1.742212296f, + -0.270784855f, 1.886857986f, 0.831889153f, 0.636521816f, -0.067433357f, + -0.256595969f, 0.907287478f, 1.575596929f, 0.393882513f, -0.510042071f, + 0.507258415f, 0.059408009f, 1.776192427f, 1.664948106f, -0.341539711f, + -0.072047889f, -0.795555651f, 0.704908550f, 2.127685547f, 1.486027241f, + 1.973046541f, 2.456688404f, 2.871328354f, 4.989626408f, 5.076294422f, + 4.262395859f, 3.622689009f, 3.241683960f, 4.222597599f, 3.575423479f, + 1.997965097f, 1.391216874f, 2.329971790f, 2.898612261f, 3.871258736f, + 2.857767582f, 2.960238218f, 3.047467470f, 2.790968180f, 2.183730364f, + 1.991029263f, 2.727865934f, 1.561259747f, 0.787606239f, 3.036532879f, + 2.430759192f, 1.475822210f, 2.307994127f, 1.857011318f, 1.538355589f, + 2.320549965f, 3.305005074f, 2.554165363f, 2.630100727f, 3.506094217f, + 4.454113483f, 2.894124269f, 4.061129570f, 4.425602436f, 
3.218537807f, + 2.712452173f, 5.546891212f, 6.138017654f, 5.897895813f, 5.698192596f, + 4.096743584f, 2.661385298f, 3.646550655f, 4.626225948f, 5.025664330f, + 3.861543894f, 4.374861717f, 5.388185978f, 3.376737356f, 2.751175404f, + 3.299628258f, 2.025987387f, 1.094563961f, 0.128147125f, -4.321690559f, + -6.165239811f, -4.245608330f, -2.974690914f, -5.110438824f, -6.619713306f, + -6.594148636f, -7.972207069f, -8.034727097f, -7.296438217f, -6.822746754f, + -6.375267029f, -7.629575729f, -8.404177666f, -5.002337456f, -7.024040699f, + -7.799823761f, -5.423873901f, -4.861459732f, -2.772324085f, 0.002551556f, + -1.445306778f, -1.726813316f, 0.889497757f, 1.760663986f, 2.722227097f, + 4.755805969f, 4.188167572f, 1.547533512f, 2.444593906f, 1.612852097f, + -0.508655310f, 0.046535015f, 1.720140934f, 1.265070438f, 0.976964772f, + 2.446830273f, 6.308787823f, 7.798269272f, 5.347163200f, 3.540414810f, + 3.510186911f, 4.305843830f, 5.957427025f, 7.200410843f, 7.049768448f, + 7.179680824f, 8.508881569f, 9.094768524f, 12.307214737f, 14.215225220f, + 11.316717148f, 8.660657883f, 7.528784275f, 7.616339207f, 6.968524933f, + 4.246424198f, 0.214603424f, 0.449179649f, 1.695000648f, 0.110423088f, + -0.304885864f, -2.038585663f, -5.223299980f, -5.486608505f, -5.728059292f, + -4.866038799f, -2.678806305f, -3.464673519f, -3.407086372f, -2.490849733f, + -0.161162257f, 0.118952155f, 0.312392950f, -0.341049194f, 0.013419867f, + 3.722306252f, 3.901551247f, 1.781876802f, 2.446551561f, 3.659160852f, + 2.530288696f, 3.577404499f, 3.201550961f, 0.281389952f, -0.291333675f, + 1.386508465f, 2.181721210f, -2.802821159f, -1.531007886f, 1.608560324f, + -0.523656845f, -0.281057000f, 0.571323991f, 0.668095112f, -1.637194037f, + -2.756963253f, -1.340666890f, -2.180127621f, -1.874165773f, 0.660111070f, + 0.197176635f, 0.781580091f, 1.749967933f, 0.674724638f, -2.082683325f, + -3.159717083f, -2.898023844f, -4.691623211f, -5.614190102f, -6.157790661f, + -7.776132584f, -8.029224396f, -6.940879345f, 
-7.065263271f, -7.003522396f, + -5.691181183f, -7.872379780f, -7.614178658f, -5.778759003f, -4.605045319f, + -4.695390224f, -5.865473270f, -5.825413227f, -4.648111820f, -2.193091869f, + -0.172003269f, 1.482686043f, -0.915655136f, -2.626194954f, 1.852293015f, + 4.184171677f, 4.083235264f, 1.048256874f, -1.361350536f, 0.438748837f, + 1.716395378f, 2.916294813f, 2.639499664f, 0.059617281f, -1.883811951f, + 2.136622429f, 6.641947269f, 5.951328754f, 3.875293493f, 3.003573895f, + 2.687273264f, 4.843512535f, 6.420391560f, 6.014624596f, 3.444208860f, + 0.717782736f, 2.659932613f, 5.204012871f, 5.516477585f, 3.315031528f, + 0.454023123f, -0.026421070f, 0.802503586f, 2.606507778f, 1.679640770f, + -1.917723656f, -3.348850250f, -2.580049515f, -1.783200264f, -0.810425520f, + -0.374402523f, -3.705567360f, -5.367071629f, -4.344952106f, -0.968293428f, + 1.147591949f, -1.240655184f, -2.621209621f, -2.452539444f, -1.543132067f, + 0.422753096f, 1.026433110f, 0.858573675f, -0.695377707f, -0.242624998f, + 3.892488956f, 4.100893021f, 3.498974323f, 1.744507313f, -0.912925899f, + 0.929271877f, 3.531583786f, 4.938030243f, 4.081199646f, 0.061933577f, + -2.232783318f, -1.356980443f, 1.794556737f, 3.510458231f, 1.323192716f, + -0.505770206f, 2.126557350f, 2.507567406f, 2.232018232f, 1.872283101f, + 1.265762568f, 0.577634692f, 0.021484375f, 3.114191532f, 1.579384208f, + 0.930754900f, 0.308351398f, -0.425426602f, 3.359810352f, 2.437057972f, + 1.210662127f, 0.708607912f, -1.576705575f, 0.007833481f, -0.178357601f, + -0.880272985f, 0.078738928f, 0.339336634f, -0.763550043f, -1.669098496f, + -2.083987713f, -1.946106076f, -0.953974366f, -0.856883168f, -1.282670021f, + -1.551425457f, -2.249363184f, -2.555188894f, -1.254808664f, -1.368662596f, + -1.839509130f, -0.839046180f, -0.452676475f, 0.721064806f, 1.988085508f, + 0.456556678f, -0.255003691f, 0.384676337f, 1.075410485f, 0.617453933f, + 1.470067143f, 1.493275523f, 0.954153359f, 1.027234554f, -0.434967309f, + -0.694453120f, 0.477285773f, 
0.436861426f, 1.486879349f, -0.158989906f, + 0.361879885f, 3.234876394f, 1.105287671f, -0.982552111f, 1.514200211f, + 0.821707547f, -1.142312169f, 1.845819831f, 3.934516191f, 2.251807690f, + 0.530044913f, -1.043874860f, -0.891365111f, -0.264675498f, 0.288083673f, + 0.606682122f, -1.132072091f, -3.530973911f, -2.005296707f, 0.335011721f, + -0.240332901f, -2.763209343f, -2.148519516f, -1.864180326f, -0.814615071f, + -1.589591861f, -2.455522776f, -0.756391644f, 0.689822078f, 0.171640277f, + -0.225937843f, 0.363246441f, 0.098157287f, -1.638891220f, -0.400456548f, + 1.076233864f, 2.288599968f, 2.716089964f, 1.585703373f, 0.846301913f, + 0.887506902f, -0.439320147f, -0.823126972f, 0.712436378f, 1.027045608f, + 0.360925227f, -2.289939404f, -1.035227180f, 0.931313038f, -0.133454978f, + 0.160856903f, 0.700653732f, 0.817580283f, -0.223383546f, 0.713623106f, + 1.327106714f, 1.558022618f, 1.346337557f, -0.661301017f, 0.707845926f, + 2.435726643f, 0.763329387f, 0.485213757f, 2.295393229f, 4.525130272f, + 2.354229450f, -0.043517172f, 1.635316610f, 1.651852608f, 1.240020633f, + 0.320237398f, -0.571269870f, -0.686546564f, -1.796948791f, -0.966899753f, + -0.404109240f, -1.295783877f, -2.058131218f, -2.279026985f, -2.183017731f, + -2.516988277f, -0.276667058f, -0.475267202f, -2.645681143f, -0.504431605f, + -1.031255722f, -3.401877880f, -1.075011969f, -0.667404234f, -2.419279575f, + -1.230643749f, 1.151491284f, 0.374734998f, -2.004124880f, -1.923788905f, + -0.767004371f, 0.512374282f, 2.254727125f, 1.373157024f, 0.633022547f, + 0.194831967f, 0.226476192f, 1.294842482f, 0.838023365f, 1.291390896f, + 0.128176212f, -1.109287858f, 0.166733295f, 0.847469866f, -0.662097514f, + -0.489783406f, 1.523754478f, 1.903803706f, -0.748670340f, 0.721136212f, + 1.627746105f, -0.731291413f, 0.646574259f, 1.722917080f, 0.372141778f, + -0.063563704f, 0.916404963f, 2.092662811f, 1.699481010f, 0.181074798f, + -1.361395121f, 0.581034362f, 1.451567292f, 0.526586652f, 1.206429839f, + -1.041464567f, 
-2.891606331f, 0.638695598f, 1.198848009f, -0.771047413f, + -1.074250221f, -0.500067651f, 0.308775485f, 0.552724898f, 1.083443999f, + 1.371356130f, 0.360372365f, 3.391613960f, 2.896605730f, 0.799045980f, + 0.922905385f, 3.240214348f, 4.740911484f, 2.945639610f, 2.544054747f, + 3.048654795f, 3.541822433f, 4.390746117f, 5.632675171f, 7.721554756f, + 6.390114784f, 5.962307930f, 5.873732567f, 5.625522137f, 4.857854843f, + 3.148367405f, 3.966898203f, 4.309705257f, 3.543770313f, 2.427399397f, + 0.324177742f, -1.809771061f, -2.191485405f, 0.006873131f, -0.876847267f, + -0.928904057f, 0.889565945f, -0.127671242f, -1.695463657f, -1.193793774f, + -1.452976227f, -3.406696558f, -2.564189196f, -2.136555195f, -2.374645710f, + -3.230790854f, -3.076714516f, -3.245117664f, -2.254387617f, -0.245034039f, + -1.072510719f, -1.887740970f, 0.431427240f, 1.132410765f, -1.015120149f, + -0.274977922f, -1.910447717f, -2.865208864f, -0.131696820f}; + +// TODO(bugs.webrtc.org/9076): Move to resource file. +constexpr std::array kPitchBufferFrameSquareEnergies = { + 5150.291992188f, 5150.894531250f, 5145.122558594f, 5148.914062500f, + 5152.802734375f, 5156.541015625f, 5163.048339844f, 5172.149414062f, + 5177.349121094f, 5184.365722656f, 5199.292480469f, 5202.612304688f, + 5197.510253906f, 5189.979492188f, 5183.533203125f, 5190.677734375f, + 5203.943359375f, 5207.876464844f, 5209.395019531f, 5225.451660156f, + 5249.794921875f, 5271.816894531f, 5280.045410156f, 5285.289062500f, + 5288.319335938f, 5289.758789062f, 5294.285644531f, 5289.979980469f, + 5287.337402344f, 5287.237792969f, 5281.462402344f, 5271.676269531f, + 5256.257324219f, 5240.524414062f, 5230.869628906f, 5207.531250000f, + 5176.040039062f, 5144.021484375f, 5109.295410156f, 5068.527832031f, + 5008.909667969f, 4977.587890625f, 4959.000976562f, 4950.016601562f, + 4940.795410156f, 4937.358398438f, 4935.286132812f, 4914.154296875f, + 4906.706542969f, 4906.924804688f, 4907.674804688f, 4899.855468750f, + 4894.340820312f, 4906.948242188f, 
4910.065429688f, 4921.032714844f, + 4949.294433594f, 4982.643066406f, 5000.996093750f, 5005.875488281f, + 5020.441894531f, 5031.938964844f, 5041.877441406f, 5035.990722656f, + 5037.362792969f, 5043.038085938f, 5044.236328125f, 5042.322753906f, + 5041.990722656f, 5047.362304688f, 5056.785644531f, 5054.579101562f, + 5050.326171875f, 5053.495117188f, 5060.186523438f, 5065.591796875f, + 5066.717285156f, 5069.499511719f, 5076.201171875f, 5076.687011719f, + 5076.316894531f, 5077.581054688f, 5076.226074219f, 5074.094238281f, + 5074.039062500f, 5073.663574219f, 5076.283691406f, 5077.278808594f, + 5076.094238281f, 5077.806152344f, 5081.035644531f, 5082.431640625f, + 5082.995605469f, 5084.653320312f, 5084.936035156f, 5085.394042969f, + 5085.735351562f, 5080.651855469f, 5080.542968750f, 5079.969238281f, + 5076.432617188f, 5072.439453125f, 5073.252441406f, 5071.974609375f, + 5071.458496094f, 5066.017578125f, 5065.670898438f, 5065.144042969f, + 5055.592773438f, 5060.104980469f, 5060.505371094f, 5054.157226562f, + 5056.915039062f, 5067.208007812f, 5060.940917969f, 5058.419921875f, + 5053.248046875f, 5049.823730469f, 5048.573242188f, 5053.195312500f, + 5053.444335938f, 5054.143066406f, 5056.270019531f, 5063.881835938f, + 5070.784667969f, 5074.042480469f, 5080.785156250f, 5085.663085938f, + 5095.979003906f, 5101.596191406f, 5088.784667969f, 5087.686523438f, + 5087.946777344f, 5087.369140625f, 5081.445312500f, 5081.519042969f, + 5087.940917969f, 5102.099121094f, 5126.864257812f, 5147.613281250f, + 5170.079589844f, 5189.276367188f, 5210.265136719f, 5244.745117188f, + 5268.821777344f, 5277.381835938f, 5279.768066406f, 5278.750000000f, + 5283.853027344f, 5292.671386719f, 5291.744628906f, 5294.732421875f, + 5294.322265625f, 5294.267089844f, 5297.530761719f, 5302.179199219f, + 5312.768066406f, 5323.202148438f, 5335.357910156f, 5344.610839844f, + 5347.597167969f, 5346.077148438f, 5346.071289062f, 5346.083984375f, + 5348.088378906f, 5349.661621094f, 5350.157226562f, 5351.855957031f, + 
5347.257812500f, 5345.171875000f, 5344.617675781f, 5343.106445312f, + 5342.778808594f, 5338.655761719f, 5341.668457031f, 5347.518066406f, + 5362.014160156f, 5361.167968750f, 5362.926269531f, 5371.575195312f, + 5374.099609375f, 5381.186523438f, 5381.963867188f, 5386.806152344f, + 5389.590820312f, 5384.562011719f, 5372.485839844f, 5370.576660156f, + 5369.640136719f, 5369.698242188f, 5371.199707031f, 5372.644531250f, + 5394.006835938f, 5395.366699219f, 5395.259277344f, 5395.398437500f, + 5395.895507812f, 5401.420898438f, 5420.036621094f, 5434.017578125f, + 5434.215820312f, 5437.827636719f, 5442.944335938f, 5450.980468750f, + 5449.246582031f, 5449.135742188f, 5453.259765625f, 5453.792968750f, + 5459.676757812f, 5460.213867188f, 5479.227539062f, 5512.076171875f, + 5520.272949219f, 5519.662109375f, 5517.395996094f, 5516.550292969f, + 5520.786621094f, 5527.268066406f, 5526.668457031f, 5549.916992188f, + 5577.750976562f, 5580.141113281f, 5579.533691406f, 5576.632324219f, + 5573.938476562f, 5571.166503906f, 5570.603027344f, 5570.708496094f, + 5577.238769531f, 5577.625976562f, 5589.325683594f, 5602.189941406f, + 5612.587402344f, 5613.887695312f, 5613.588867188f, 5608.100585938f, + 5632.956054688f, 5679.322265625f, 5682.149414062f, 5683.846191406f, + 5691.708496094f, 5683.279785156f, 5694.248535156f, 5744.740722656f, + 5756.655761719f, 5755.952148438f, 5756.665527344f, 5750.700195312f, + 5784.060546875f, 5823.021972656f, 5829.233398438f, 5817.804687500f, + 5827.333984375f, 5826.451171875f, 5824.887695312f, 5825.734375000f, + 5813.386230469f, 5789.609863281f, 5779.115234375f, 5778.762695312f, + 5785.748046875f, 5792.981933594f, 5787.567871094f, 5778.096679688f, + 5764.337402344f, 5766.734375000f, 5766.489746094f, 5769.543945312f, + 5773.183593750f, 5775.720703125f, 5774.311523438f, 5769.303710938f, + 5765.815917969f, 5767.521484375f, 5775.251953125f, 5785.067382812f, + 5770.117187500f, 5749.073242188f, 5747.606933594f, 5757.671875000f, + 5762.530273438f, 5774.506347656f, 
5784.737304688f, 5775.916015625f, + 5779.816894531f, 5795.064453125f, 5808.736816406f, 5813.699707031f, + 5823.773925781f, 5840.490234375f, 5833.751953125f, 5810.150390625f, + 5800.072265625f, 5815.070800781f, 5822.964355469f, 5817.615234375f, + 5783.978027344f, 5748.952636719f, 5735.553710938f, 5730.132812500f, + 5724.260253906f, 5721.703613281f, 5695.653808594f, 5652.838867188f, + 5649.729980469f, 5647.268554688f, 5647.265136719f, 5641.350585938f, + 5636.762695312f, 5637.900390625f, 5639.662109375f, 5639.672851562f, + 5638.901367188f, 5622.253417969f, 5604.906738281f, 5601.475585938f, + 5595.938476562f, 5595.687011719f, 5598.612792969f, 5601.322753906f, + 5598.558593750f, 5577.227050781f, 5544.295410156f, 5514.978027344f, + 5499.678222656f, 5488.303222656f, 5471.735839844f, 5429.718261719f, + 5376.806640625f, 5348.682128906f, 5307.851074219f, 5260.914062500f, + 5212.738281250f, 5148.544921875f, 5091.187500000f, 5053.512207031f, + 5023.785156250f, 5002.202148438f, 4994.252441406f, 4984.498046875f, + 4980.251464844f, 4979.796875000f, 4976.738769531f, 4979.579589844f, + 4986.528320312f, 4991.153808594f, 4991.462890625f, 4987.881347656f, + 4987.417480469f, 4983.885742188f, 4984.341308594f, 4985.302734375f, + 4985.303710938f, 4985.449707031f, 4989.282226562f, 4994.246582031f, + 4992.635742188f, 4992.064453125f, 4987.331054688f, 4985.806152344f, + 4986.047851562f, 4985.968750000f, 4979.141113281f, 4976.958984375f, + 4972.650390625f, 4959.916503906f, 4956.325683594f, 4956.408691406f, + 4949.288085938f, 4951.827636719f, 4962.202636719f, 4981.184570312f, + 4992.152832031f, 4997.386230469f, 5011.211914062f, 5026.242187500f, + 5023.573730469f, 5012.373046875f, 5017.451171875f, 5010.541015625f, + 4980.446777344f, 4958.639648438f, 4963.649902344f, 5627.020507812f, + 6869.356445312f}; + +// TODO(bugs.webrtc.org/9076): Move to resource file. 
+constexpr std::array kPitchBufferAutoCorrCoeffs = { + -423.526794434f, -260.724456787f, -173.558380127f, -71.720344543f, + -1.149698257f, 71.451370239f, 71.455848694f, 149.755233765f, + 199.401885986f, 243.961334229f, 269.339721680f, 243.776992798f, + 294.753814697f, 209.465484619f, 139.224700928f, 131.474136353f, + 42.872886658f, -32.431114197f, -90.191261292f, -94.912338257f, + -172.627227783f, -138.089843750f, -89.236648560f, -69.348426819f, + 25.044368744f, 44.184486389f, 61.602676392f, 150.157394409f, + 185.254760742f, 233.352676392f, 296.255371094f, 292.464141846f, + 256.903472900f, 250.926574707f, 174.207122803f, 130.214172363f, + 65.655899048f, -68.448402405f, -147.239669800f, -230.553405762f, + -311.217895508f, -447.173889160f, -509.306060791f, -551.155822754f, + -580.678405762f, -658.902709961f, -697.141967773f, -751.233032227f, + -690.860351562f, -571.689575195f, -521.124572754f, -429.477294922f, + -375.685913086f, -277.387329102f, -154.100753784f, -105.723197937f, + 117.502632141f, 219.290512085f, 255.376770020f, 444.264831543f, + 470.727416992f, 460.139129639f, 494.179931641f, 389.801116943f, + 357.082763672f, 222.748138428f, 179.100601196f, -26.893497467f, + -85.033767700f, -223.577529907f, -247.136367798f, -223.011428833f, + -292.724914551f, -246.538131714f, -247.388458252f, -228.452484131f, + -30.476575851f, 4.652336121f, 64.730491638f, 156.081161499f, + 177.569305420f, 261.671569824f, 336.274414062f, 424.203369141f, + 564.190734863f, 608.841796875f, 671.252136230f, 712.249877930f, + 623.135498047f, 564.775695801f, 576.405639648f, 380.181854248f, + 306.687164307f, 180.344757080f, -41.317466736f, -183.548736572f, + -223.835021973f, -273.299652100f, -235.727813721f, -276.899627686f, + -302.224975586f, -349.227142334f, -370.935058594f, -364.022613525f, + -287.682952881f, -273.828704834f, -156.869720459f, -88.654510498f, + 14.299798012f, 137.048034668f, 260.182342529f, 423.380767822f, + 591.277282715f, 581.151306152f, 643.898864746f, 547.919006348f, + 
355.534271240f, 238.222915649f, 4.463035583f, -193.763305664f, + -281.212432861f, -546.399353027f, -615.602600098f, -574.225891113f, + -726.701843262f, -564.840942383f, -588.488037109f, -651.052551270f, + -453.769104004f, -502.886627197f, -463.373016357f, -291.709564209f, + -288.857421875f, -152.114242554f, 105.401855469f, 211.479980469f, + 468.501983643f, 796.984985352f, 880.254089355f, 1114.614379883f, + 1219.664794922f, 1093.687377930f, 1125.042602539f, 1020.942382812f, + 794.315246582f, 772.126831055f, 447.410736084f}; + +constexpr std::array kTestPitchPeriods = { + 3 * kMinPitch48kHz / 2, (3 * kMinPitch48kHz + kMaxPitch48kHz) / 2, +}; +constexpr std::array kTestPitchGains = {0.35f, 0.75f}; + +} // namespace + +class ComputePitchGainThresholdTest + : public testing::Test, + public ::testing::WithParamInterface< + std::tuple> {}; + +TEST_P(ComputePitchGainThresholdTest, BitExactness) { + const auto params = GetParam(); + const size_t candidate_pitch_period = std::get<0>(params); + const size_t pitch_period_ratio = std::get<1>(params); + const size_t initial_pitch_period = std::get<2>(params); + const float initial_pitch_gain = std::get<3>(params); + const size_t prev_pitch_period = std::get<4>(params); + const size_t prev_pitch_gain = std::get<5>(params); + const float threshold = std::get<6>(params); + + { + // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. 
+ // FloatingPointExceptionObserver fpe_observer; + + // NOTE(review): |prev_pitch_gain| is held in a size_t in this test (and is + // declared size_t by ComputePitchGainThreshold()), so the fractional gains + // in the tuples below are truncated before the call; presumably a float was + // intended - confirm. + EXPECT_NEAR( + threshold, + ComputePitchGainThreshold(candidate_pitch_period, pitch_period_ratio, + initial_pitch_period, initial_pitch_gain, + prev_pitch_period, prev_pitch_gain), + 3e-6f); + } +} + +// Tuple layout: candidate pitch period, pitch period ratio, initial pitch +// period, initial pitch gain, previous pitch period, previous pitch gain, +// expected threshold. +INSTANTIATE_TEST_CASE_P( + RnnVadTest, + ComputePitchGainThresholdTest, + ::testing::Values( + std::make_tuple(31, 7, 219, 0.45649201f, 199, 0.604747f, 0.40000001f), + std::make_tuple(113, + 2, + 226, + 0.20967799f, + 219, + 0.40392199f, + 0.30000001f), + std::make_tuple(63, 2, 126, 0.210788f, 364, 0.098519f, 0.40000001f), + std::make_tuple(30, 5, 152, 0.82356697f, 149, 0.55535901f, 0.700032f), + std::make_tuple(76, 2, 151, 0.79522997f, 151, 0.82356697f, 0.675946f), + std::make_tuple(31, 5, 153, 0.85069299f, 150, 0.79073799f, 0.72308898f), + std::make_tuple(78, 2, 156, 0.72750503f, 153, 0.85069299f, 0.618379f))); + +TEST(RnnVadTest, ComputeSlidingFrameSquareEnergiesBitExactness) { + std::array computed_output; + { + // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. + // FloatingPointExceptionObserver fpe_observer; + + ComputeSlidingFrameSquareEnergies( + {kPitchBufferData.data(), kPitchBufferData.size()}, + {computed_output.data(), computed_output.size()}); + } + ExpectNearAbsolute({kPitchBufferFrameSquareEnergies.data(), + kPitchBufferFrameSquareEnergies.size()}, + {computed_output.data(), computed_output.size()}, 3e-2f); +} + +TEST(RnnVadTest, ComputePitchAutoCorrelationBitExactness) { + std::array pitch_buf_decimated; + Decimate2x({kPitchBufferData.data(), kPitchBufferData.size()}, + {pitch_buf_decimated.data(), pitch_buf_decimated.size()}); + std::array computed_output; + { + // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
+ // FloatingPointExceptionObserver fpe_observer; + + ComputePitchAutoCorrelation( + {pitch_buf_decimated.data(), pitch_buf_decimated.size()}, + kMaxPitch12kHz, {computed_output.data(), computed_output.size()}); + } + ExpectNearAbsolute( + {kPitchBufferAutoCorrCoeffs.data(), kPitchBufferAutoCorrCoeffs.size()}, + {computed_output.data(), computed_output.size()}, 3e-3f); +} + +TEST(RnnVadTest, FindBestPitchPeriodsBitExactness) { + std::array pitch_buf_decimated; + Decimate2x({kPitchBufferData.data(), kPitchBufferData.size()}, + {pitch_buf_decimated.data(), pitch_buf_decimated.size()}); + std::array pitch_candidates_inv_lags; + { + // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. + // FloatingPointExceptionObserver fpe_observer; + + pitch_candidates_inv_lags = FindBestPitchPeriods( + {kPitchBufferAutoCorrCoeffs}, {pitch_buf_decimated}, kMaxPitch12kHz); + } + const std::array expected_output = {140, 142}; + EXPECT_EQ(expected_output, pitch_candidates_inv_lags); +} + +TEST(RnnVadTest, RefinePitchPeriod48kHzBitExactness) { + std::array pitch_buf_decimated; + Decimate2x({kPitchBufferData.data(), kPitchBufferData.size()}, + {pitch_buf_decimated.data(), pitch_buf_decimated.size()}); + size_t pitch_inv_lag; + { + // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. 
+ // FloatingPointExceptionObserver fpe_observer; + + // NOTE(review): |pitch_buf_decimated| computed above is not used by this + // test; RefinePitchPeriod48kHz() is given |kPitchBufferData| directly. + const std::array pitch_candidates_inv_lags = {280, 284}; + pitch_inv_lag = RefinePitchPeriod48kHz( + {kPitchBufferData.data(), kPitchBufferData.size()}, + {pitch_candidates_inv_lags.data(), pitch_candidates_inv_lags.size()}); + } + EXPECT_EQ(560u, pitch_inv_lag); +} + +class CheckLowerPitchPeriodsAndComputePitchGainTest + : public testing::Test, + public ::testing::WithParamInterface< + std::tuple> {}; + +TEST_P(CheckLowerPitchPeriodsAndComputePitchGainTest, BitExactness) { + const auto params = GetParam(); + const size_t initial_pitch_period = std::get<0>(params); + const size_t prev_pitch_period = std::get<1>(params); + const float prev_pitch_gain = std::get<2>(params); + const size_t expected_pitch_period = std::get<3>(params); + const float expected_pitch_gain = std::get<4>(params); + + { + // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. + // FloatingPointExceptionObserver fpe_observer; + + const auto computed_output = CheckLowerPitchPeriodsAndComputePitchGain( + {kPitchBufferData.data(), kPitchBufferData.size()}, + initial_pitch_period, {prev_pitch_period, prev_pitch_gain}); + EXPECT_EQ(expected_pitch_period, computed_output.period); + EXPECT_NEAR(expected_pitch_gain, computed_output.gain, 1e-6f); + } +} + +// Tuple layout: initial pitch period, previous pitch period, previous pitch +// gain, expected pitch period, expected pitch gain. +INSTANTIATE_TEST_CASE_P(RnnVadTest, + CheckLowerPitchPeriodsAndComputePitchGainTest, + ::testing::Values(std::make_tuple(kTestPitchPeriods[0], + kTestPitchPeriods[0], + kTestPitchGains[0], + 91, + -0.0188608f), + std::make_tuple(kTestPitchPeriods[0], + kTestPitchPeriods[0], + kTestPitchGains[1], + 91, + -0.0188608f), + std::make_tuple(kTestPitchPeriods[0], + kTestPitchPeriods[1], + kTestPitchGains[0], + 91, + -0.0188608f), + std::make_tuple(kTestPitchPeriods[0], + kTestPitchPeriods[1], + kTestPitchGains[1], + 91, + -0.0188608f), + std::make_tuple(kTestPitchPeriods[1], + kTestPitchPeriods[0], + kTestPitchGains[0], + 475, + -0.0904344f), + std::make_tuple(kTestPitchPeriods[1], +
kTestPitchPeriods[0], + kTestPitchGains[1], + 475, + -0.0904344f), + std::make_tuple(kTestPitchPeriods[1], + kTestPitchPeriods[1], + kTestPitchGains[0], + 475, + -0.0904344f), + std::make_tuple(kTestPitchPeriods[1], + kTestPitchPeriods[1], + kTestPitchGains[1], + 475, + -0.0904344f))); + +} // namespace test +} // namespace rnn_vad +} // namespace webrtc diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc new file mode 100644 index 0000000000..441776465d --- /dev/null +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/agc2/rnn_vad/pitch_search.h" + +#include + +#include "modules/audio_processing/agc2/rnn_vad/test_utils.h" +// TODO(bugs.webrtc.org/8948): Add when the issue is fixed. +// #include "test/fpe_observer.h" +#include "test/gtest.h" + +namespace webrtc { +namespace rnn_vad { +namespace test { + +// TODO(bugs.webrtc.org/9076): Remove when the issue is fixed. +TEST(RnnVadTest, PitchSearchBitExactness) { + auto lp_residual_reader = CreateLpResidualAndPitchPeriodGainReader(); + const size_t num_frames = lp_residual_reader.second; + std::array lp_residual; + float expected_pitch_period, expected_pitch_gain; + PitchInfo last_pitch; + { + // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. 
+ // FloatingPointExceptionObserver fpe_observer; + + for (size_t i = 0; i < num_frames; ++i) { + SCOPED_TRACE(i); + lp_residual_reader.first->ReadChunk( + {lp_residual.data(), lp_residual.size()}); + lp_residual_reader.first->ReadValue(&expected_pitch_period); + lp_residual_reader.first->ReadValue(&expected_pitch_gain); + last_pitch = + PitchSearch({lp_residual.data(), lp_residual.size()}, last_pitch); + EXPECT_EQ(static_cast(expected_pitch_period), last_pitch.period); + EXPECT_NEAR(expected_pitch_gain, last_pitch.gain, 1e-5f); + } + } +} + +} // namespace test +} // namespace rnn_vad +} // namespace webrtc diff --git a/modules/audio_processing/agc2/rnn_vad/ring_buffer_unittest.cc b/modules/audio_processing/agc2/rnn_vad/ring_buffer_unittest.cc index 0848f8d56c..91383d11bf 100644 --- a/modules/audio_processing/agc2/rnn_vad/ring_buffer_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/ring_buffer_unittest.cc @@ -13,10 +13,8 @@ #include "test/gtest.h" namespace webrtc { +namespace rnn_vad { namespace test { - -using rnn_vad::RingBuffer; - namespace { // Compare the elements of two given array views. @@ -113,4 +111,5 @@ TEST(RnnVadTest, RingBufferFloating) { } } // namespace test +} // namespace rnn_vad } // namespace webrtc diff --git a/modules/audio_processing/agc2/rnn_vad/rnn_vad_tool.cc b/modules/audio_processing/agc2/rnn_vad/rnn_vad_tool.cc index af8905285c..6ab932c652 100644 --- a/modules/audio_processing/agc2/rnn_vad/rnn_vad_tool.cc +++ b/modules/audio_processing/agc2/rnn_vad/rnn_vad_tool.cc @@ -79,7 +79,7 @@ int main(int argc, char* argv[]) { std::array samples_10ms_24kHz; PushSincResampler resampler(frame_size_10ms, kFrameSize10ms24kHz); - // TODO(alessiob): Init feature extractor and RNN-based VAD. + // TODO(bugs.webrtc.org/9076): Init feature extractor and RNN-based VAD. // Compute VAD probabilities. 
while (true) { @@ -93,7 +93,7 @@ int main(int argc, char* argv[]) { resampler.Resample(samples_10ms.data(), samples_10ms.size(), samples_10ms_24kHz.data(), samples_10ms_24kHz.size()); - // TODO(alessiob): Extract features. + // TODO(bugs.webrtc.org/9076): Extract features. float vad_probability; bool is_silence = true; @@ -101,15 +101,15 @@ int main(int argc, char* argv[]) { if (features_file) { const float float_is_silence = is_silence ? 1.f : 0.f; fwrite(&float_is_silence, sizeof(float), 1, features_file); - // TODO(alessiob): Write feature vector. + // TODO(bugs.webrtc.org/9076): Write feature vector. } // Compute VAD probability. if (is_silence) { vad_probability = 0.f; - // TODO(alessiob): Reset VAD. + // TODO(bugs.webrtc.org/9076): Reset VAD. } else { - // TODO(alessiob): Compute VAD probability. + // TODO(bugs.webrtc.org/9076): Compute VAD probability. } RTC_DCHECK_GE(vad_probability, 0.f); RTC_DCHECK_GE(1.f, vad_probability); diff --git a/modules/audio_processing/agc2/rnn_vad/sequence_buffer_unittest.cc b/modules/audio_processing/agc2/rnn_vad/sequence_buffer_unittest.cc index f15a256e69..7628c17689 100644 --- a/modules/audio_processing/agc2/rnn_vad/sequence_buffer_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/sequence_buffer_unittest.cc @@ -15,10 +15,8 @@ #include "test/gtest.h" namespace webrtc { +namespace rnn_vad { namespace test { - -using rnn_vad::SequenceBuffer; - namespace { template @@ -103,4 +101,5 @@ TEST(RnnVadTest, SequenceBufferPushOpsFloating) { } } // namespace test +} // namespace rnn_vad } // namespace webrtc diff --git a/modules/audio_processing/agc2/rnn_vad/symmetric_matrix_buffer_unittest.cc b/modules/audio_processing/agc2/rnn_vad/symmetric_matrix_buffer_unittest.cc index 408467a259..a1b8007696 100644 --- a/modules/audio_processing/agc2/rnn_vad/symmetric_matrix_buffer_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/symmetric_matrix_buffer_unittest.cc @@ -14,12 +14,10 @@ #include "test/gtest.h" namespace webrtc { 
+namespace rnn_vad { namespace test { namespace { -using rnn_vad::RingBuffer; -using rnn_vad::SymmetricMatrixBuffer; - template void CheckSymmetry(const SymmetricMatrixBuffer* sym_matrix_buf) { for (size_t row = 0; row < S - 1; ++row) @@ -108,4 +106,5 @@ TEST(RnnVadTest, SymmetricMatrixBufferUseCase) { } } // namespace test +} // namespace rnn_vad } // namespace webrtc diff --git a/modules/audio_processing/agc2/rnn_vad/test_utils.cc b/modules/audio_processing/agc2/rnn_vad/test_utils.cc index 374f0b6a8a..c6cf21e61b 100644 --- a/modules/audio_processing/agc2/rnn_vad/test_utils.cc +++ b/modules/audio_processing/agc2/rnn_vad/test_utils.cc @@ -16,6 +16,7 @@ #include "test/testsupport/fileutils.h" namespace webrtc { +namespace rnn_vad { namespace test { namespace { @@ -24,9 +25,11 @@ using ReaderPairType = } // namespace +using webrtc::test::ResourcePath; + void ExpectNearAbsolute(rtc::ArrayView expected, rtc::ArrayView computed, - const float tolerance) { + float tolerance) { ASSERT_EQ(expected.size(), computed.size()); for (size_t i = 0; i < expected.size(); ++i) { SCOPED_TRACE(i); @@ -36,8 +39,7 @@ void ExpectNearAbsolute(rtc::ArrayView expected, ReaderPairType CreatePitchBuffer24kHzReader() { auto ptr = rtc::MakeUnique>( - test::ResourcePath("audio_processing/agc2/rnn_vad/pitch_buf_24k", "dat"), - 864); + ResourcePath("audio_processing/agc2/rnn_vad/pitch_buf_24k", "dat"), 864); return {std::move(ptr), rtc::CheckedDivExact(ptr->data_length(), static_cast(864))}; } @@ -45,11 +47,12 @@ ReaderPairType CreatePitchBuffer24kHzReader() { ReaderPairType CreateLpResidualAndPitchPeriodGainReader() { constexpr size_t num_lp_residual_coeffs = 864; auto ptr = rtc::MakeUnique>( - test::ResourcePath("audio_processing/agc2/rnn_vad/pitch_lp_res", "dat"), + ResourcePath("audio_processing/agc2/rnn_vad/pitch_lp_res", "dat"), num_lp_residual_coeffs); return {std::move(ptr), rtc::CheckedDivExact(ptr->data_length(), 2 + num_lp_residual_coeffs)}; } } // namespace test +} // namespace rnn_vad 
} // namespace webrtc diff --git a/modules/audio_processing/agc2/rnn_vad/test_utils.h b/modules/audio_processing/agc2/rnn_vad/test_utils.h index 6c16447889..3f580ab48c 100644 --- a/modules/audio_processing/agc2/rnn_vad/test_utils.h +++ b/modules/audio_processing/agc2/rnn_vad/test_utils.h @@ -23,6 +23,7 @@ #include "rtc_base/checks.h" namespace webrtc { +namespace rnn_vad { namespace test { constexpr float kFloatMin = std::numeric_limits::min(); @@ -31,15 +32,14 @@ constexpr float kFloatMin = std::numeric_limits::min(); // that their absolute error is above a given threshold. void ExpectNearAbsolute(rtc::ArrayView expected, rtc::ArrayView computed, - const float tolerance); + float tolerance); // Reader for binary files consisting of an arbitrary long sequence of elements // having type T. It is possible to read and cast to another type D at once. template class BinaryFileReader { public: - explicit BinaryFileReader(const std::string& file_path, - const size_t chunk_size = 1) + explicit BinaryFileReader(const std::string& file_path, size_t chunk_size = 1) : is_(file_path, std::ios::binary | std::ios::ate), data_length_(is_.tellg() / sizeof(T)), chunk_size_(chunk_size) { @@ -97,6 +97,7 @@ std::pair>, const size_t> CreateLpResidualAndPitchPeriodGainReader(); } // namespace test +} // namespace rnn_vad } // namespace webrtc #endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_TEST_UTILS_H_