diff --git a/modules/audio_processing/agc2/rnn_vad/BUILD.gn b/modules/audio_processing/agc2/rnn_vad/BUILD.gn index 7822901fba..292caae3d2 100644 --- a/modules/audio_processing/agc2/rnn_vad/BUILD.gn +++ b/modules/audio_processing/agc2/rnn_vad/BUILD.gn @@ -33,6 +33,8 @@ rtc_library("rnn_vad") { "../../../../api:function_view", "../../../../rtc_base:checks", "../../../../rtc_base:logging", + "../../../../rtc_base:safe_compare", + "../../../../rtc_base:safe_conversions", "../../../../rtc_base/system:arch", "//third_party/rnnoise:rnn_vad", ] @@ -93,6 +95,7 @@ rtc_library("rnn_vad_pitch") { "../../../../api:array_view", "../../../../rtc_base:checks", "../../../../rtc_base:safe_compare", + "../../../../rtc_base:safe_conversions", ] } @@ -125,6 +128,7 @@ rtc_library("rnn_vad_spectral_features") { ":rnn_vad_symmetric_matrix_buffer", "../../../../api:array_view", "../../../../rtc_base:checks", + "../../../../rtc_base:safe_compare", "../../utility:pffft_wrapper", ] } @@ -134,6 +138,7 @@ rtc_source_set("rnn_vad_symmetric_matrix_buffer") { deps = [ "../../../../api:array_view", "../../../../rtc_base:checks", + "../../../../rtc_base:safe_compare", ] } @@ -150,6 +155,7 @@ if (rtc_include_tests) { "../../../../api:array_view", "../../../../api:scoped_refptr", "../../../../rtc_base:checks", + "../../../../rtc_base:safe_compare", "../../../../rtc_base/system:arch", "../../../../system_wrappers", "../../../../test:fileutils", @@ -206,6 +212,8 @@ if (rtc_include_tests) { "../../../../common_audio/", "../../../../rtc_base:checks", "../../../../rtc_base:logging", + "../../../../rtc_base:safe_compare", + "../../../../rtc_base:safe_conversions", "../../../../rtc_base/system:arch", "../../../../test:test_support", "../../utility:pffft_wrapper", @@ -227,6 +235,7 @@ if (rtc_include_tests) { "../../../../api:array_view", "../../../../common_audio", "../../../../rtc_base:rtc_base_approved", + "../../../../rtc_base:safe_compare", "../../../../test:test_support", "//third_party/abseil-cpp/absl/flags:flag", "//third_party/abseil-cpp/absl/flags:parse", diff --git a/modules/audio_processing/agc2/rnn_vad/auto_correlation.cc b/modules/audio_processing/agc2/rnn_vad/auto_correlation.cc index d932c78063..f6a4f42fd6 100644 --- a/modules/audio_processing/agc2/rnn_vad/auto_correlation.cc +++ b/modules/audio_processing/agc2/rnn_vad/auto_correlation.cc @@ -48,8 +48,8 @@ void AutoCorrelationCalculator::ComputeOnPitchBuffer( rtc::ArrayView auto_corr) { RTC_DCHECK_LT(auto_corr.size(), kMaxPitch12kHz); RTC_DCHECK_GT(pitch_buf.size(), kMaxPitch12kHz); - constexpr size_t kFftFrameSize = 1 << kAutoCorrelationFftOrder; - constexpr size_t kConvolutionLength = kBufSize12kHz - kMaxPitch12kHz; + constexpr int kFftFrameSize = 1 << kAutoCorrelationFftOrder; + constexpr int kConvolutionLength = kBufSize12kHz - kMaxPitch12kHz; static_assert(kConvolutionLength == kFrameSize20ms12kHz, "Mismatch between pitch buffer size, frame size and maximum " "pitch period."); diff --git a/modules/audio_processing/agc2/rnn_vad/auto_correlation_unittest.cc b/modules/audio_processing/agc2/rnn_vad/auto_correlation_unittest.cc index f66c0b299b..ef3748d7cf 100644 --- a/modules/audio_processing/agc2/rnn_vad/auto_correlation_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/auto_correlation_unittest.cc @@ -54,7 +54,7 @@ TEST(RnnVadTest, CheckAutoCorrelationOnConstantPitchBuffer) { } // The expected output is a vector filled with the same expected // auto-correlation value. The latter equals the length of a 20 ms frame. - constexpr size_t kFrameSize20ms12kHz = kFrameSize20ms24kHz / 2; + constexpr int kFrameSize20ms12kHz = kFrameSize20ms24kHz / 2; std::array expected_output; std::fill(expected_output.begin(), expected_output.end(), static_cast(kFrameSize20ms12kHz)); diff --git a/modules/audio_processing/agc2/rnn_vad/common.h b/modules/audio_processing/agc2/rnn_vad/common.h index c2e8df6905..d6deff1556 100644 --- a/modules/audio_processing/agc2/rnn_vad/common.h +++ b/modules/audio_processing/agc2/rnn_vad/common.h @@ -18,52 +18,52 @@ namespace rnn_vad { constexpr double kPi = 3.14159265358979323846; -constexpr size_t kSampleRate24kHz = 24000; -constexpr size_t kFrameSize10ms24kHz = kSampleRate24kHz / 100; -constexpr size_t kFrameSize20ms24kHz = kFrameSize10ms24kHz * 2; +constexpr int kSampleRate24kHz = 24000; +constexpr int kFrameSize10ms24kHz = kSampleRate24kHz / 100; +constexpr int kFrameSize20ms24kHz = kFrameSize10ms24kHz * 2; // Pitch buffer. -constexpr size_t kMinPitch24kHz = kSampleRate24kHz / 800; // 0.00125 s. -constexpr size_t kMaxPitch24kHz = kSampleRate24kHz / 62.5; // 0.016 s. -constexpr size_t kBufSize24kHz = kMaxPitch24kHz + kFrameSize20ms24kHz; +constexpr int kMinPitch24kHz = kSampleRate24kHz / 800; // 0.00125 s. +constexpr int kMaxPitch24kHz = kSampleRate24kHz / 62.5; // 0.016 s. +constexpr int kBufSize24kHz = kMaxPitch24kHz + kFrameSize20ms24kHz; static_assert((kBufSize24kHz & 1) == 0, "The buffer size must be even."); // 24 kHz analysis. // Define a higher minimum pitch period for the initial search. This is used to // avoid searching for very short periods, for which a refinement step is // responsible. -constexpr size_t kInitialMinPitch24kHz = 3 * kMinPitch24kHz; +constexpr int kInitialMinPitch24kHz = 3 * kMinPitch24kHz; static_assert(kMinPitch24kHz < kInitialMinPitch24kHz, ""); static_assert(kInitialMinPitch24kHz < kMaxPitch24kHz, ""); static_assert(kMaxPitch24kHz > kInitialMinPitch24kHz, ""); -constexpr size_t kNumInvertedLags24kHz = kMaxPitch24kHz - kInitialMinPitch24kHz; +constexpr int kNumInvertedLags24kHz = kMaxPitch24kHz - kInitialMinPitch24kHz; // 12 kHz analysis. -constexpr size_t kSampleRate12kHz = 12000; -constexpr size_t kFrameSize10ms12kHz = kSampleRate12kHz / 100; -constexpr size_t kFrameSize20ms12kHz = kFrameSize10ms12kHz * 2; -constexpr size_t kBufSize12kHz = kBufSize24kHz / 2; -constexpr size_t kInitialMinPitch12kHz = kInitialMinPitch24kHz / 2; -constexpr size_t kMaxPitch12kHz = kMaxPitch24kHz / 2; +constexpr int kSampleRate12kHz = 12000; +constexpr int kFrameSize10ms12kHz = kSampleRate12kHz / 100; +constexpr int kFrameSize20ms12kHz = kFrameSize10ms12kHz * 2; +constexpr int kBufSize12kHz = kBufSize24kHz / 2; +constexpr int kInitialMinPitch12kHz = kInitialMinPitch24kHz / 2; +constexpr int kMaxPitch12kHz = kMaxPitch24kHz / 2; static_assert(kMaxPitch12kHz > kInitialMinPitch12kHz, ""); // The inverted lags for the pitch interval [|kInitialMinPitch12kHz|, // |kMaxPitch12kHz|] are in the range [0, |kNumInvertedLags12kHz|]. -constexpr size_t kNumInvertedLags12kHz = kMaxPitch12kHz - kInitialMinPitch12kHz; +constexpr int kNumInvertedLags12kHz = kMaxPitch12kHz - kInitialMinPitch12kHz; // 48 kHz constants. -constexpr size_t kMinPitch48kHz = kMinPitch24kHz * 2; -constexpr size_t kMaxPitch48kHz = kMaxPitch24kHz * 2; +constexpr int kMinPitch48kHz = kMinPitch24kHz * 2; +constexpr int kMaxPitch48kHz = kMaxPitch24kHz * 2; // Spectral features. -constexpr size_t kNumBands = 22; -constexpr size_t kNumLowerBands = 6; +constexpr int kNumBands = 22; +constexpr int kNumLowerBands = 6; static_assert((0 < kNumLowerBands) && (kNumLowerBands < kNumBands), ""); -constexpr size_t kCepstralCoeffsHistorySize = 8; +constexpr int kCepstralCoeffsHistorySize = 8; static_assert(kCepstralCoeffsHistorySize > 2, "The history size must at least be 3 to compute first and second " "derivatives."); -constexpr size_t kFeatureVectorSize = 42; +constexpr int kFeatureVectorSize = 42; enum class Optimization { kNone, kSse2, kNeon }; diff --git a/modules/audio_processing/agc2/rnn_vad/features_extraction.cc b/modules/audio_processing/agc2/rnn_vad/features_extraction.cc index e9351797f5..c207baeec0 100644 --- a/modules/audio_processing/agc2/rnn_vad/features_extraction.cc +++ b/modules/audio_processing/agc2/rnn_vad/features_extraction.cc @@ -69,7 +69,7 @@ bool FeaturesExtractor::CheckSilenceComputeFeatures( // into the output vector (normalization based on training data stats). pitch_info_48kHz_ = pitch_estimator_.Estimate(lp_residual_view_); feature_vector[kFeatureVectorSize - 2] = - 0.01f * (static_cast(pitch_info_48kHz_.period) - 300); + 0.01f * (pitch_info_48kHz_.period - 300); // Extract lagged frames (according to the estimated pitch period). RTC_DCHECK_LE(pitch_info_48kHz_.period / 2, kMaxPitch24kHz); auto lagged_frame = pitch_buf_24kHz_view_.subview( diff --git a/modules/audio_processing/agc2/rnn_vad/features_extraction_unittest.cc b/modules/audio_processing/agc2/rnn_vad/features_extraction_unittest.cc index c00fc232eb..9df52738b4 100644 --- a/modules/audio_processing/agc2/rnn_vad/features_extraction_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/features_extraction_unittest.cc @@ -14,6 +14,8 @@ #include #include "modules/audio_processing/agc2/rnn_vad/test_utils.h" +#include "rtc_base/numerics/safe_compare.h" +#include "rtc_base/numerics/safe_conversions.h" // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. // #include "test/fpe_observer.h" #include "test/gtest.h" @@ -23,26 +25,25 @@ namespace rnn_vad { namespace test { namespace { -constexpr size_t ceil(size_t n, size_t m) { +constexpr int ceil(int n, int m) { return (n + m - 1) / m; } // Number of 10 ms frames required to fill a pitch buffer having size // |kBufSize24kHz|. -constexpr size_t kNumTestDataFrames = ceil(kBufSize24kHz, kFrameSize10ms24kHz); +constexpr int kNumTestDataFrames = ceil(kBufSize24kHz, kFrameSize10ms24kHz); // Number of samples for the test data. -constexpr size_t kNumTestDataSize = kNumTestDataFrames * kFrameSize10ms24kHz; +constexpr int kNumTestDataSize = kNumTestDataFrames * kFrameSize10ms24kHz; // Verifies that the pitch in Hz is in the detectable range. bool PitchIsValid(float pitch_hz) { - const size_t pitch_period = - static_cast(static_cast(kSampleRate24kHz) / pitch_hz); + const int pitch_period = static_cast(kSampleRate24kHz) / pitch_hz; return kInitialMinPitch24kHz <= pitch_period && pitch_period <= kMaxPitch24kHz; } void CreatePureTone(float amplitude, float freq_hz, rtc::ArrayView dst) { - for (size_t i = 0; i < dst.size(); ++i) { + for (int i = 0; rtc::SafeLt(i, dst.size()); ++i) { dst[i] = amplitude * std::sin(2.f * kPi * freq_hz * i / kSampleRate24kHz); } } @@ -56,8 +57,8 @@ bool FeedTestData(FeaturesExtractor* features_extractor, // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. // FloatingPointExceptionObserver fpe_observer; bool is_silence = true; - const size_t num_frames = samples.size() / kFrameSize10ms24kHz; - for (size_t i = 0; i < num_frames; ++i) { + const int num_frames = samples.size() / kFrameSize10ms24kHz; + for (int i = 0; i < num_frames; ++i) { is_silence = features_extractor->CheckSilenceComputeFeatures( {samples.data() + i * kFrameSize10ms24kHz, kFrameSize10ms24kHz}, feature_vector); @@ -79,13 +80,13 @@ TEST(RnnVadTest, FeatureExtractionLowHighPitch) { FeaturesExtractor features_extractor; std::vector samples(kNumTestDataSize); std::vector feature_vector(kFeatureVectorSize); - ASSERT_EQ(kFeatureVectorSize, feature_vector.size()); + ASSERT_EQ(kFeatureVectorSize, rtc::dchecked_cast(feature_vector.size())); rtc::ArrayView feature_vector_view( feature_vector.data(), kFeatureVectorSize); // Extract the normalized scalar feature that is proportional to the estimated // pitch period. - constexpr size_t pitch_feature_index = kFeatureVectorSize - 2; + constexpr int pitch_feature_index = kFeatureVectorSize - 2; // Low frequency tone - i.e., high period. CreatePureTone(amplitude, low_pitch_hz, samples); ASSERT_FALSE(FeedTestData(&features_extractor, samples, feature_vector_view)); diff --git a/modules/audio_processing/agc2/rnn_vad/lp_residual.cc b/modules/audio_processing/agc2/rnn_vad/lp_residual.cc index f732b97bcf..c553aa2ad1 100644 --- a/modules/audio_processing/agc2/rnn_vad/lp_residual.cc +++ b/modules/audio_processing/agc2/rnn_vad/lp_residual.cc @@ -28,9 +28,9 @@ namespace { void ComputeAutoCorrelation( rtc::ArrayView x, rtc::ArrayView auto_corr) { - constexpr size_t max_lag = auto_corr.size(); + constexpr int max_lag = auto_corr.size(); RTC_DCHECK_LT(max_lag, x.size()); - for (size_t lag = 0; lag < max_lag; ++lag) { + for (int lag = 0; lag < max_lag; ++lag) { auto_corr[lag] = std::inner_product(x.begin(), x.end() - lag, x.begin() + lag, 0.f); } @@ -56,9 +56,9 @@ void ComputeInitialInverseFilterCoefficients( rtc::ArrayView auto_corr, rtc::ArrayView lpc_coeffs) { float error = auto_corr[0]; - for (size_t i = 0; i < kNumLpcCoefficients - 1; ++i) { + for (int i = 0; i < kNumLpcCoefficients - 1; ++i) { float reflection_coeff = 0.f; - for (size_t j = 0; j < i; ++j) { + for (int j = 0; j < i; ++j) { reflection_coeff += lpc_coeffs[j] * auto_corr[i - j]; } reflection_coeff += auto_corr[i + 1]; @@ -72,7 +72,7 @@ void ComputeInitialInverseFilterCoefficients( reflection_coeff /= -error; // Update LPC coefficients and total error. lpc_coeffs[i] = reflection_coeff; - for (size_t j = 0; j<(i + 1)>> 1; ++j) { + for (int j = 0; j < ((i + 1) >> 1); ++j) { const float tmp1 = lpc_coeffs[j]; const float tmp2 = lpc_coeffs[i - 1 - j]; lpc_coeffs[j] = tmp1 + reflection_coeff * tmp2; diff --git a/modules/audio_processing/agc2/rnn_vad/lp_residual_unittest.cc b/modules/audio_processing/agc2/rnn_vad/lp_residual_unittest.cc index 1e80ee0631..177977688e 100644 --- a/modules/audio_processing/agc2/rnn_vad/lp_residual_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/lp_residual_unittest.cc @@ -53,14 +53,14 @@ TEST(RnnVadTest, LpResidualPipelineBitExactness) { std::vector expected_lp_residual(kBufSize24kHz); // Test length. - const size_t num_frames = std::min(pitch_buf_24kHz_reader.second, - static_cast(300)); // Max 3 s. + const int num_frames = + std::min(pitch_buf_24kHz_reader.second, 300); // Max 3 s. ASSERT_GE(lp_residual_reader.second, num_frames); { // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. // FloatingPointExceptionObserver fpe_observer; - for (size_t i = 0; i < num_frames; ++i) { + for (int i = 0; i < num_frames; ++i) { // Read input. ASSERT_TRUE(pitch_buf_24kHz_reader.first->ReadChunk(pitch_buf_data)); // Read expected output (ignore pitch gain and period). diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search.cc index df73274cb7..85f67377e4 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search.cc +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search.cc @@ -35,9 +35,8 @@ PitchInfo PitchEstimator::Estimate( Decimate2x(pitch_buf, pitch_buf_decimated_view_); auto_corr_calculator_.ComputeOnPitchBuffer(pitch_buf_decimated_view_, auto_corr_view_); - CandidatePitchPeriods pitch_candidates_inverted_lags = - FindBestPitchPeriods(auto_corr_view_, pitch_buf_decimated_view_, - static_cast(kMaxPitch12kHz)); + CandidatePitchPeriods pitch_candidates_inverted_lags = FindBestPitchPeriods( + auto_corr_view_, pitch_buf_decimated_view_, kMaxPitch12kHz); // Refine the pitch period estimation. // The refinement is done using the pitch buffer that contains 24 kHz samples. // Therefore, adapt the inverted lags in |pitch_candidates_inv_lags| from 12 @@ -47,10 +46,9 @@ PitchInfo PitchEstimator::Estimate( const int pitch_inv_lag_48kHz = RefinePitchPeriod48kHz(pitch_buf, pitch_candidates_inverted_lags); // Look for stronger harmonics to find the final pitch period and its gain. - RTC_DCHECK_LT(pitch_inv_lag_48kHz, static_cast(kMaxPitch48kHz)); + RTC_DCHECK_LT(pitch_inv_lag_48kHz, kMaxPitch48kHz); last_pitch_48kHz_ = CheckLowerPitchPeriodsAndComputePitchGain( - pitch_buf, static_cast(kMaxPitch48kHz) - pitch_inv_lag_48kHz, - last_pitch_48kHz_); + pitch_buf, kMaxPitch48kHz - pitch_inv_lag_48kHz, last_pitch_48kHz_); return last_pitch_48kHz_; } diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc index 922669a4c5..d782a18d2f 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc @@ -20,29 +20,27 @@ #include "modules/audio_processing/agc2/rnn_vad/common.h" #include "rtc_base/checks.h" #include "rtc_base/numerics/safe_compare.h" +#include "rtc_base/numerics/safe_conversions.h" namespace webrtc { namespace rnn_vad { namespace { -constexpr int kMaxPitch24kHzInt = static_cast(kMaxPitch24kHz); - // Converts a lag to an inverted lag (only for 24kHz). int GetInvertedLag(int lag) { - RTC_DCHECK_LE(lag, kMaxPitch24kHzInt); - return kMaxPitch24kHzInt - lag; + RTC_DCHECK_LE(lag, kMaxPitch24kHz); + return kMaxPitch24kHz - lag; } float ComputeAutoCorrelationCoeff(rtc::ArrayView pitch_buf, int inv_lag, int max_pitch_period) { - RTC_DCHECK_LT(inv_lag, static_cast(pitch_buf.size())); - RTC_DCHECK_LT(max_pitch_period, static_cast(pitch_buf.size())); - RTC_DCHECK_LE(inv_lag, static_cast(max_pitch_period)); + RTC_DCHECK_LT(inv_lag, pitch_buf.size()); + RTC_DCHECK_LT(max_pitch_period, pitch_buf.size()); + RTC_DCHECK_LE(inv_lag, max_pitch_period); // TODO(bugs.webrtc.org/9076): Maybe optimize using vectorization. - return std::inner_product( - pitch_buf.begin() + static_cast(max_pitch_period), - pitch_buf.end(), pitch_buf.begin() + static_cast(inv_lag), 0.f); + return std::inner_product(pitch_buf.begin() + max_pitch_period, + pitch_buf.end(), pitch_buf.begin() + inv_lag, 0.f); } // Given the auto-correlation coefficients for a lag and its neighbors, computes @@ -76,14 +74,14 @@ int PitchPseudoInterpolationLagPitchBuf( rtc::ArrayView pitch_buf) { int offset = 0; // Cannot apply pseudo-interpolation at the boundaries. - if (lag > 0 && lag < kMaxPitch24kHzInt) { + if (lag > 0 && lag < kMaxPitch24kHz) { offset = GetPitchPseudoInterpolationOffset( ComputeAutoCorrelationCoeff(pitch_buf, GetInvertedLag(lag - 1), - kMaxPitch24kHzInt), + kMaxPitch24kHz), ComputeAutoCorrelationCoeff(pitch_buf, GetInvertedLag(lag), - kMaxPitch24kHzInt), + kMaxPitch24kHz), ComputeAutoCorrelationCoeff(pitch_buf, GetInvertedLag(lag + 1), - kMaxPitch24kHzInt)); + kMaxPitch24kHz)); } return 2 * lag + offset; } @@ -96,7 +94,7 @@ int PitchPseudoInterpolationInvLagAutoCorr( rtc::ArrayView auto_corr) { int offset = 0; // Cannot apply pseudo-interpolation at the boundaries. - if (inv_lag > 0 && inv_lag < static_cast(auto_corr.size()) - 1) { + if (inv_lag > 0 && inv_lag < rtc::dchecked_cast(auto_corr.size()) - 1) { offset = GetPitchPseudoInterpolationOffset( auto_corr[inv_lag + 1], auto_corr[inv_lag], auto_corr[inv_lag - 1]); } @@ -143,7 +141,7 @@ void Decimate2x(rtc::ArrayView src, rtc::ArrayView dst) { // TODO(bugs.webrtc.org/9076): Consider adding anti-aliasing filter. static_assert(2 * dst.size() == src.size(), ""); - for (size_t i = 0; i < dst.size(); ++i) { + for (int i = 0; rtc::SafeLt(i, dst.size()); ++i) { dst[i] = src[2 * i]; } } @@ -186,10 +184,10 @@ float ComputePitchGainThreshold(int candidate_pitch_period, // reduce the chance of false positives caused by a bias towards high // frequencies (originating from short-term correlations). float threshold = std::max(0.3f, 0.7f * g0 - lower_threshold_term); - if (static_cast(t1) < 3 * kMinPitch24kHz) { + if (t1 < 3 * kMinPitch24kHz) { // High frequency. threshold = std::max(0.4f, 0.85f * g0 - lower_threshold_term); - } else if (static_cast(t1) < 2 * kMinPitch24kHz) { + } else if (t1 < 2 * kMinPitch24kHz) { // Even higher frequency. threshold = std::max(0.5f, 0.9f * g0 - lower_threshold_term); } @@ -199,10 +197,10 @@ float ComputePitchGainThreshold(int candidate_pitch_period, void ComputeSlidingFrameSquareEnergies( rtc::ArrayView pitch_buf, rtc::ArrayView yy_values) { - float yy = ComputeAutoCorrelationCoeff(pitch_buf, kMaxPitch24kHzInt, - kMaxPitch24kHzInt); + float yy = + ComputeAutoCorrelationCoeff(pitch_buf, kMaxPitch24kHz, kMaxPitch24kHz); yy_values[0] = yy; - for (size_t i = 1; i < yy_values.size(); ++i) { + for (int i = 1; rtc::SafeLt(i, yy_values.size()); ++i) { RTC_DCHECK_LE(i, kMaxPitch24kHz + kFrameSize20ms24kHz); RTC_DCHECK_LE(i, kMaxPitch24kHz); const float old_coeff = pitch_buf[kMaxPitch24kHz + kFrameSize20ms24kHz - i]; @@ -233,9 +231,10 @@ CandidatePitchPeriods FindBestPitchPeriods( } }; - RTC_DCHECK_GT(max_pitch_period, static_cast(auto_corr.size())); - RTC_DCHECK_LT(max_pitch_period, static_cast(pitch_buf.size())); - const int frame_size = static_cast(pitch_buf.size()) - max_pitch_period; + RTC_DCHECK_GT(max_pitch_period, auto_corr.size()); + RTC_DCHECK_LT(max_pitch_period, pitch_buf.size()); + const int frame_size = + rtc::dchecked_cast(pitch_buf.size()) - max_pitch_period; RTC_DCHECK_GT(frame_size, 0); // TODO(bugs.webrtc.org/9076): Maybe optimize using vectorization. float yy = @@ -247,7 +246,7 @@ CandidatePitchPeriods FindBestPitchPeriods( PitchCandidate best; PitchCandidate second_best; second_best.period_inverted_lag = 1; - for (int inv_lag = 0; inv_lag < static_cast(auto_corr.size()); + for (int inv_lag = 0; inv_lag < rtc::dchecked_cast(auto_corr.size()); ++inv_lag) { // A pitch candidate must have positive correlation. if (auto_corr[inv_lag] > 0) { @@ -290,12 +289,12 @@ int RefinePitchPeriod48kHz( ++inverted_lag) { if (is_neighbor(inverted_lag, pitch_candidates_inverted_lags.best) || is_neighbor(inverted_lag, pitch_candidates_inverted_lags.second_best)) - auto_correlation[inverted_lag] = ComputeAutoCorrelationCoeff( - pitch_buf, inverted_lag, kMaxPitch24kHzInt); + auto_correlation[inverted_lag] = + ComputeAutoCorrelationCoeff(pitch_buf, inverted_lag, kMaxPitch24kHz); } // Find best pitch at 24 kHz. const CandidatePitchPeriods pitch_candidates_24kHz = - FindBestPitchPeriods(auto_correlation, pitch_buf, kMaxPitch24kHzInt); + FindBestPitchPeriods(auto_correlation, pitch_buf, kMaxPitch24kHz); // Pseudo-interpolation. return PitchPseudoInterpolationInvLagAutoCorr(pitch_candidates_24kHz.best, auto_correlation); @@ -334,9 +333,9 @@ PitchInfo CheckLowerPitchPeriodsAndComputePitchGain( // Initial pitch candidate gain. RefinedPitchCandidate best_pitch; best_pitch.period_24kHz = - std::min(initial_pitch_period_48kHz / 2, kMaxPitch24kHzInt - 1); + std::min(initial_pitch_period_48kHz / 2, kMaxPitch24kHz - 1); best_pitch.xy = ComputeAutoCorrelationCoeff( - pitch_buf, GetInvertedLag(best_pitch.period_24kHz), kMaxPitch24kHzInt); + pitch_buf, GetInvertedLag(best_pitch.period_24kHz), kMaxPitch24kHz); best_pitch.yy = yy_values[best_pitch.period_24kHz]; best_pitch.gain = pitch_gain(best_pitch.xy, best_pitch.yy, xx); @@ -351,11 +350,10 @@ PitchInfo CheckLowerPitchPeriodsAndComputePitchGain( }; // |max_k| such that alternative_period(initial_pitch_period, max_k, 1) equals // kMinPitch24kHz. - const int max_k = - (2 * initial_pitch_period) / (2 * static_cast(kMinPitch24kHz) - 1); + const int max_k = (2 * initial_pitch_period) / (2 * kMinPitch24kHz - 1); for (int k = 2; k <= max_k; ++k) { int candidate_pitch_period = alternative_period(initial_pitch_period, k, 1); - RTC_DCHECK_GE(candidate_pitch_period, static_cast(kMinPitch24kHz)); + RTC_DCHECK_GE(candidate_pitch_period, kMinPitch24kHz); // When looking at |candidate_pitch_period|, we also look at one of its // sub-harmonics. |kSubHarmonicMultipliers| is used to know where to look. // |k| == 2 is a special case since |candidate_pitch_secondary_period| might @@ -363,7 +361,7 @@ PitchInfo CheckLowerPitchPeriodsAndComputePitchGain( int candidate_pitch_secondary_period = alternative_period( initial_pitch_period, k, kSubHarmonicMultipliers[k - 2]); RTC_DCHECK_GT(candidate_pitch_secondary_period, 0); - if (k == 2 && candidate_pitch_secondary_period > kMaxPitch24kHzInt) { + if (k == 2 && candidate_pitch_secondary_period > kMaxPitch24kHz) { candidate_pitch_secondary_period = initial_pitch_period; } RTC_DCHECK_NE(candidate_pitch_period, candidate_pitch_secondary_period) @@ -373,10 +371,10 @@ PitchInfo CheckLowerPitchPeriodsAndComputePitchGain( // |candidate_pitch_period| by also looking at its possible sub-harmonic // |candidate_pitch_secondary_period|. float xy_primary_period = ComputeAutoCorrelationCoeff( - pitch_buf, GetInvertedLag(candidate_pitch_period), kMaxPitch24kHzInt); + pitch_buf, GetInvertedLag(candidate_pitch_period), kMaxPitch24kHz); float xy_secondary_period = ComputeAutoCorrelationCoeff( pitch_buf, GetInvertedLag(candidate_pitch_secondary_period), - kMaxPitch24kHzInt); + kMaxPitch24kHz); float xy = 0.5f * (xy_primary_period + xy_secondary_period); float yy = 0.5f * (yy_values[candidate_pitch_period] + yy_values[candidate_pitch_secondary_period]); @@ -399,7 +397,7 @@ PitchInfo CheckLowerPitchPeriodsAndComputePitchGain( : best_pitch.xy / (best_pitch.yy + 1.f); final_pitch_gain = std::min(best_pitch.gain, final_pitch_gain); int final_pitch_period_48kHz = std::max( - static_cast(kMinPitch48kHz), + kMinPitch48kHz, PitchPseudoInterpolationLagPitchBuf(best_pitch.period_24kHz, pitch_buf)); return {final_pitch_period_48kHz, final_pitch_gain}; diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc index 37fb15f72e..fdbee68357 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc @@ -34,11 +34,11 @@ constexpr float kTestPitchGainsHigh = 0.75f; class ComputePitchGainThresholdTest : public ::testing::Test, public ::testing::WithParamInterface> {}; @@ -46,11 +46,11 @@ class ComputePitchGainThresholdTest // data. TEST_P(ComputePitchGainThresholdTest, WithinTolerance) { const auto params = GetParam(); - const size_t candidate_pitch_period = std::get<0>(params); - const size_t pitch_period_ratio = std::get<1>(params); - const size_t initial_pitch_period = std::get<2>(params); + const int candidate_pitch_period = std::get<0>(params); + const int pitch_period_ratio = std::get<1>(params); + const int initial_pitch_period = std::get<2>(params); const float initial_pitch_gain = std::get<3>(params); - const size_t prev_pitch_period = std::get<4>(params); + const int prev_pitch_period = std::get<4>(params); const float prev_pitch_gain = std::get<5>(params); const float threshold = std::get<6>(params); { diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc index 99c9dfa06a..fdecb92807 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc @@ -28,22 +28,21 @@ namespace test { // pitch gain is within tolerance given test input data. TEST(RnnVadTest, PitchSearchWithinTolerance) { auto lp_residual_reader = CreateLpResidualAndPitchPeriodGainReader(); - const size_t num_frames = std::min(lp_residual_reader.second, - static_cast(300)); // Max 3 s. + const int num_frames = std::min(lp_residual_reader.second, 300); // Max 3 s. std::vector lp_residual(kBufSize24kHz); float expected_pitch_period, expected_pitch_gain; PitchEstimator pitch_estimator; { // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. // FloatingPointExceptionObserver fpe_observer; - for (size_t i = 0; i < num_frames; ++i) { + for (int i = 0; i < num_frames; ++i) { SCOPED_TRACE(i); lp_residual_reader.first->ReadChunk(lp_residual); lp_residual_reader.first->ReadValue(&expected_pitch_period); lp_residual_reader.first->ReadValue(&expected_pitch_gain); PitchInfo pitch_info = pitch_estimator.Estimate({lp_residual.data(), kBufSize24kHz}); - EXPECT_EQ(static_cast(expected_pitch_period), pitch_info.period); + EXPECT_EQ(expected_pitch_period, pitch_info.period); EXPECT_NEAR(expected_pitch_gain, pitch_info.gain, 1e-5f); } } diff --git a/modules/audio_processing/agc2/rnn_vad/ring_buffer.h b/modules/audio_processing/agc2/rnn_vad/ring_buffer.h index 294b0c0ba8..f0270af918 100644 --- a/modules/audio_processing/agc2/rnn_vad/ring_buffer.h +++ b/modules/audio_processing/agc2/rnn_vad/ring_buffer.h @@ -21,7 +21,7 @@ namespace webrtc { namespace rnn_vad { // Ring buffer for N arrays of type T each one with size S. -template +template class RingBuffer { static_assert(S > 0, ""); static_assert(N > 0, ""); @@ -45,11 +45,10 @@ class RingBuffer { // Return an array view onto the array with a given delay. A view on the last // and least recently push array is returned when |delay| is 0 and N - 1 // respectively. - rtc::ArrayView GetArrayView(size_t delay) const { - const int delay_int = static_cast(delay); - RTC_DCHECK_LE(0, delay_int); - RTC_DCHECK_LT(delay_int, N); - int offset = tail_ - 1 - delay_int; + rtc::ArrayView GetArrayView(int delay) const { + RTC_DCHECK_LE(0, delay); + RTC_DCHECK_LT(delay, N); + int offset = tail_ - 1 - delay; if (offset < 0) offset += N; return {buffer_.data() + S * offset, S}; diff --git a/modules/audio_processing/agc2/rnn_vad/ring_buffer_unittest.cc b/modules/audio_processing/agc2/rnn_vad/ring_buffer_unittest.cc index f064651798..8b061a968f 100644 --- a/modules/audio_processing/agc2/rnn_vad/ring_buffer_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/ring_buffer_unittest.cc @@ -20,14 +20,14 @@ namespace { // Compare the elements of two given array views. template void ExpectEq(rtc::ArrayView a, rtc::ArrayView b) { - for (size_t i = 0; i < S; ++i) { + for (int i = 0; i < S; ++i) { SCOPED_TRACE(i); EXPECT_EQ(a[i], b[i]); } } // Test push/read sequences. -template +template void TestRingBuffer() { SCOPED_TRACE(N); SCOPED_TRACE(S); @@ -56,7 +56,7 @@ void TestRingBuffer() { } // Check buffer. - for (size_t delay = 2; delay < N; ++delay) { + for (int delay = 2; delay < N; ++delay) { SCOPED_TRACE(delay); T expected_value = N - static_cast(delay); pushed_array.fill(expected_value); @@ -68,18 +68,18 @@ void TestRingBuffer() { // Check that for different delays, different views are returned. TEST(RnnVadTest, RingBufferArrayViews) { - constexpr size_t s = 3; - constexpr size_t n = 4; + constexpr int s = 3; + constexpr int n = 4; RingBuffer ring_buf; std::array pushed_array; pushed_array.fill(1); - for (size_t k = 0; k <= n; ++k) { // Push data n + 1 times. + for (int k = 0; k <= n; ++k) { // Push data n + 1 times. SCOPED_TRACE(k); // Check array views. - for (size_t i = 0; i < n; ++i) { + for (int i = 0; i < n; ++i) { SCOPED_TRACE(i); auto view_i = ring_buf.GetArrayView(i); - for (size_t j = i + 1; j < n; ++j) { + for (int j = i + 1; j < n; ++j) { SCOPED_TRACE(j); auto view_j = ring_buf.GetArrayView(j); EXPECT_NE(view_i, view_j); diff --git a/modules/audio_processing/agc2/rnn_vad/rnn.cc b/modules/audio_processing/agc2/rnn_vad/rnn.cc index 55a51ffa43..2072a6854d 100644 --- a/modules/audio_processing/agc2/rnn_vad/rnn.cc +++ b/modules/audio_processing/agc2/rnn_vad/rnn.cc @@ -26,6 +26,7 @@ #include "rtc_base/checks.h" #include "rtc_base/logging.h" +#include "rtc_base/numerics/safe_conversions.h" #include "third_party/rnnoise/src/rnn_activations.h" #include "third_party/rnnoise/src/rnn_vad_weights.h" @@ -77,15 +78,16 @@ std::vector GetScaledParams(rtc::ArrayView params) { // Casts and scales |weights| and re-arranges the layout. std::vector GetPreprocessedFcWeights( rtc::ArrayView weights, - size_t output_size) { + int output_size) { if (output_size == 1) { return GetScaledParams(weights); } // Transpose, scale and cast. - const size_t input_size = rtc::CheckedDivExact(weights.size(), output_size); + const int input_size = rtc::CheckedDivExact( + rtc::dchecked_cast(weights.size()), output_size); std::vector w(weights.size()); - for (size_t o = 0; o < output_size; ++o) { - for (size_t i = 0; i < input_size; ++i) { + for (int o = 0; o < output_size; ++o) { + for (int i = 0; i < input_size; ++i) { w[o * input_size + i] = rnnoise::kWeightsScale * static_cast(weights[i * output_size + o]); } @@ -93,7 +95,7 @@ std::vector GetPreprocessedFcWeights( return w; } -constexpr size_t kNumGruGates = 3; // Update, reset, output. +constexpr int kNumGruGates = 3; // Update, reset, output. // TODO(bugs.chromium.org/10480): Hard-coded optimized layout and remove this // function to improve setup time. @@ -101,17 +103,17 @@ constexpr size_t kNumGruGates = 3; // Update, reset, output. // It works both for weights, recurrent weights and bias. std::vector GetPreprocessedGruTensor( rtc::ArrayView tensor_src, - size_t output_size) { + int output_size) { // Transpose, cast and scale. // |n| is the size of the first dimension of the 3-dim tensor |weights|. - const size_t n = - rtc::CheckedDivExact(tensor_src.size(), output_size * kNumGruGates); - const size_t stride_src = kNumGruGates * output_size; - const size_t stride_dst = n * output_size; + const int n = rtc::CheckedDivExact(rtc::dchecked_cast(tensor_src.size()), + output_size * kNumGruGates); + const int stride_src = kNumGruGates * output_size; + const int stride_dst = n * output_size; std::vector tensor_dst(tensor_src.size()); - for (size_t g = 0; g < kNumGruGates; ++g) { - for (size_t o = 0; o < output_size; ++o) { - for (size_t i = 0; i < n; ++i) { + for (int g = 0; g < kNumGruGates; ++g) { + for (int o = 0; o < output_size; ++o) { + for (int i = 0; i < n; ++i) { tensor_dst[g * stride_dst + o * n + i] = rnnoise::kWeightsScale * static_cast( @@ -122,28 +124,28 @@ std::vector GetPreprocessedGruTensor( return tensor_dst; } -void ComputeGruUpdateResetGates(size_t input_size, - size_t output_size, +void ComputeGruUpdateResetGates(int input_size, + int output_size, rtc::ArrayView weights, rtc::ArrayView recurrent_weights, rtc::ArrayView bias, rtc::ArrayView input, rtc::ArrayView state, rtc::ArrayView gate) { - for (size_t o = 0; o < output_size; ++o) { + for (int o = 0; o < output_size; ++o) { gate[o] = bias[o]; - for (size_t i = 0; i < input_size; ++i) { + for (int i = 0; i < input_size; ++i) { gate[o] += input[i] * weights[o * input_size + i]; } - for (size_t s = 0; s < output_size; ++s) { + for (int s = 0; s < output_size; ++s) { gate[o] += state[s] * recurrent_weights[o * output_size + s]; } gate[o] = SigmoidApproximated(gate[o]); } } -void ComputeGruOutputGate(size_t input_size, - size_t output_size, +void ComputeGruOutputGate(int input_size, + int output_size, rtc::ArrayView weights, rtc::ArrayView recurrent_weights, rtc::ArrayView bias, @@ -151,12 +153,12 @@ void ComputeGruOutputGate(size_t input_size, rtc::ArrayView state, rtc::ArrayView reset, rtc::ArrayView gate) { - for (size_t o = 0; o < output_size; ++o) { + for (int o = 0; o < output_size; ++o) { gate[o] = bias[o]; - for (size_t i = 0; i < input_size; ++i) { + for (int i = 0; i < input_size; ++i) { gate[o] += input[i] * weights[o * input_size + i]; } - for (size_t s = 0; s < output_size; ++s) { + for (int s = 0; s < output_size; ++s) { gate[o] += state[s] * recurrent_weights[o * output_size + s] * reset[s]; } gate[o] = RectifiedLinearUnit(gate[o]); @@ -164,8 +166,8 @@ void ComputeGruOutputGate(size_t input_size, } // Gated recurrent unit (GRU) layer un-optimized implementation. -void ComputeGruLayerOutput(size_t input_size, - size_t output_size, +void ComputeGruLayerOutput(int input_size, + int output_size, rtc::ArrayView input, rtc::ArrayView weights, rtc::ArrayView recurrent_weights, @@ -173,8 +175,8 @@ void ComputeGruLayerOutput(size_t input_size, rtc::ArrayView state) { RTC_DCHECK_EQ(input_size, input.size()); // Stride and offset used to read parameter arrays. - const size_t stride_in = input_size * output_size; - const size_t stride_out = output_size * output_size; + const int stride_in = input_size * output_size; + const int stride_out = output_size * output_size; // Update gate. std::array update; @@ -198,7 +200,7 @@ void ComputeGruLayerOutput(size_t input_size, bias.subview(2 * output_size, output_size), input, state, reset, output); // Update output through the update gates and update the state. - for (size_t o = 0; o < output_size; ++o) { + for (int o = 0; o < output_size; ++o) { output[o] = update[o] * state[o] + (1.f - update[o]) * output[o]; state[o] = output[o]; } @@ -206,8 +208,8 @@ void ComputeGruLayerOutput(size_t input_size, // Fully connected layer un-optimized implementation. void ComputeFullyConnectedLayerOutput( - size_t input_size, - size_t output_size, + int input_size, + int output_size, rtc::ArrayView input, rtc::ArrayView bias, rtc::ArrayView weights, @@ -216,11 +218,11 @@ void ComputeFullyConnectedLayerOutput( RTC_DCHECK_EQ(input.size(), input_size); RTC_DCHECK_EQ(bias.size(), output_size); RTC_DCHECK_EQ(weights.size(), input_size * output_size); - for (size_t o = 0; o < output_size; ++o) { + for (int o = 0; o < output_size; ++o) { output[o] = bias[o]; // TODO(bugs.chromium.org/9076): Benchmark how different layouts for // |weights_| change the performance across different platforms. - for (size_t i = 0; i < input_size; ++i) { + for (int i = 0; i < input_size; ++i) { output[o] += input[i] * weights[o * input_size + i]; } output[o] = activation_function(output[o]); @@ -230,8 +232,8 @@ void ComputeFullyConnectedLayerOutput( #if defined(WEBRTC_ARCH_X86_FAMILY) // Fully connected layer SSE2 implementation. void ComputeFullyConnectedLayerOutputSse2( - size_t input_size, - size_t output_size, + int input_size, + int output_size, rtc::ArrayView input, rtc::ArrayView bias, rtc::ArrayView weights, @@ -240,16 +242,16 @@ void ComputeFullyConnectedLayerOutputSse2( RTC_DCHECK_EQ(input.size(), input_size); RTC_DCHECK_EQ(bias.size(), output_size); RTC_DCHECK_EQ(weights.size(), input_size * output_size); - const size_t input_size_by_4 = input_size >> 2; - const size_t offset = input_size & ~3; + const int input_size_by_4 = input_size >> 2; + const int offset = input_size & ~3; __m128 sum_wx_128; const float* v = reinterpret_cast(&sum_wx_128); - for (size_t o = 0; o < output_size; ++o) { + for (int o = 0; o < output_size; ++o) { // Perform 128 bit vector operations. sum_wx_128 = _mm_set1_ps(0); const float* x_p = input.data(); const float* w_p = weights.data() + o * input_size; - for (size_t i = 0; i < input_size_by_4; ++i, x_p += 4, w_p += 4) { + for (int i = 0; i < input_size_by_4; ++i, x_p += 4, w_p += 4) { sum_wx_128 = _mm_add_ps(sum_wx_128, _mm_mul_ps(_mm_loadu_ps(x_p), _mm_loadu_ps(w_p))); } @@ -266,8 +268,8 @@ void ComputeFullyConnectedLayerOutputSse2( } // namespace FullyConnectedLayer::FullyConnectedLayer( - const size_t input_size, - const size_t output_size, + const int input_size, + const int output_size, const rtc::ArrayView bias, const rtc::ArrayView weights, rtc::FunctionView activation_function, @@ -316,8 +318,8 @@ void FullyConnectedLayer::ComputeOutput(rtc::ArrayView input) { } GatedRecurrentLayer::GatedRecurrentLayer( - const size_t input_size, - const size_t output_size, + const int input_size, + const int output_size, const rtc::ArrayView bias, const rtc::ArrayView weights, const rtc::ArrayView recurrent_weights, diff --git a/modules/audio_processing/agc2/rnn_vad/rnn.h b/modules/audio_processing/agc2/rnn_vad/rnn.h index 58274b2e1e..5b44f53047 100644 --- a/modules/audio_processing/agc2/rnn_vad/rnn.h +++ b/modules/audio_processing/agc2/rnn_vad/rnn.h @@ -29,19 +29,19 @@ namespace rnn_vad { // over-allocate space for fully-connected layers output vectors (implemented as // std::array). The value should equal the number of units of the largest // fully-connected layer. -constexpr size_t kFullyConnectedLayersMaxUnits = 24; +constexpr int kFullyConnectedLayersMaxUnits = 24; // Maximum number of units for a recurrent layer. This value is used to // over-allocate space for recurrent layers state vectors (implemented as // std::array). The value should equal the number of units of the largest // recurrent layer. -constexpr size_t kRecurrentLayersMaxUnits = 24; +constexpr int kRecurrentLayersMaxUnits = 24; // Fully-connected layer. class FullyConnectedLayer { public: - FullyConnectedLayer(size_t input_size, - size_t output_size, + FullyConnectedLayer(int input_size, + int output_size, rtc::ArrayView bias, rtc::ArrayView weights, rtc::FunctionView activation_function, @@ -49,16 +49,16 @@ class FullyConnectedLayer { FullyConnectedLayer(const FullyConnectedLayer&) = delete; FullyConnectedLayer& operator=(const FullyConnectedLayer&) = delete; ~FullyConnectedLayer(); - size_t input_size() const { return input_size_; } - size_t output_size() const { return output_size_; } + int input_size() const { return input_size_; } + int output_size() const { return output_size_; } Optimization optimization() const { return optimization_; } rtc::ArrayView GetOutput() const; // Computes the fully-connected layer output. void ComputeOutput(rtc::ArrayView input); private: - const size_t input_size_; - const size_t output_size_; + const int input_size_; + const int output_size_; const std::vector bias_; const std::vector weights_; rtc::FunctionView activation_function_; @@ -72,8 +72,8 @@ class FullyConnectedLayer { // activation functions for the update/reset and output gates respectively. class GatedRecurrentLayer { public: - GatedRecurrentLayer(size_t input_size, - size_t output_size, + GatedRecurrentLayer(int input_size, + int output_size, rtc::ArrayView bias, rtc::ArrayView weights, rtc::ArrayView recurrent_weights, @@ -81,8 +81,8 @@ class GatedRecurrentLayer { GatedRecurrentLayer(const GatedRecurrentLayer&) = delete; GatedRecurrentLayer& operator=(const GatedRecurrentLayer&) = delete; ~GatedRecurrentLayer(); - size_t input_size() const { return input_size_; } - size_t output_size() const { return output_size_; } + int input_size() const { return input_size_; } + int output_size() const { return output_size_; } Optimization optimization() const { return optimization_; } rtc::ArrayView GetOutput() const; void Reset(); @@ -90,8 +90,8 @@ class GatedRecurrentLayer { void ComputeOutput(rtc::ArrayView input); private: - const size_t input_size_; - const size_t output_size_; + const int input_size_; + const int output_size_; const std::vector bias_; const std::vector weights_; const std::vector recurrent_weights_; diff --git a/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc b/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc index 6e9f6f3690..a57a899c8d 100644 --- a/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc @@ -18,6 +18,7 @@ #include "modules/audio_processing/test/performance_timer.h" #include "rtc_base/checks.h" #include "rtc_base/logging.h" +#include "rtc_base/numerics/safe_conversions.h" #include "rtc_base/system/arch.h" #include "test/gtest.h" #include "third_party/rnnoise/src/rnn_activations.h" @@ -43,15 +44,16 @@ void TestGatedRecurrentLayer( rtc::ArrayView expected_output_sequence) { RTC_CHECK(gru); auto gru_output_view = gru->GetOutput(); - const size_t input_sequence_length = - rtc::CheckedDivExact(input_sequence.size(), gru->input_size()); - const size_t output_sequence_length = - rtc::CheckedDivExact(expected_output_sequence.size(), gru->output_size()); + const int input_sequence_length = rtc::CheckedDivExact( + rtc::dchecked_cast(input_sequence.size()), gru->input_size()); + const int output_sequence_length = rtc::CheckedDivExact( + rtc::dchecked_cast(expected_output_sequence.size()), + gru->output_size()); ASSERT_EQ(input_sequence_length, output_sequence_length) << "The test data length is invalid."; // Feed the GRU layer and check the output at every step. gru->Reset(); - for (size_t i = 0; i < input_sequence_length; ++i) { + for (int i = 0; i < input_sequence_length; ++i) { SCOPED_TRACE(i); gru->ComputeOutput( input_sequence.subview(i * gru->input_size(), gru->input_size())); @@ -77,8 +79,8 @@ constexpr std::array kFullyConnectedExpectedOutput = { 0.875092f, 0.999846f, 0.997707f, -0.999382f, 0.973153f, -0.966605f}; // Gated recurrent units layer test data. -constexpr size_t kGruInputSize = 5; -constexpr size_t kGruOutputSize = 4; +constexpr int kGruInputSize = 5; +constexpr int kGruOutputSize = 4; constexpr std::array kGruBias = {96, -99, -81, -114, 49, 119, -118, 68, -76, 91, 121, 125}; constexpr std::array kGruWeights = { @@ -213,10 +215,10 @@ TEST(RnnVadTest, DISABLED_BenchmarkFullyConnectedLayer) { } std::vector results; - constexpr size_t number_of_tests = 10000; + constexpr int number_of_tests = 10000; for (auto& fc : implementations) { ::webrtc::test::PerformanceTimer perf_timer(number_of_tests); - for (size_t k = 0; k < number_of_tests; ++k) { + for (int k = 0; k < number_of_tests; ++k) { perf_timer.StartTimer(); fc->ComputeOutput(kFullyConnectedInputVector); perf_timer.StopTimer(); @@ -240,17 +242,17 @@ TEST(RnnVadTest, DISABLED_BenchmarkGatedRecurrentLayer) { rtc::ArrayView input_sequence(kGruInputSequence); static_assert(kGruInputSequence.size() % kGruInputSize == 0, ""); - constexpr size_t input_sequence_length = + constexpr int input_sequence_length = kGruInputSequence.size() / kGruInputSize; std::vector results; - constexpr size_t number_of_tests = 10000; + constexpr int number_of_tests = 10000; for (auto& gru : implementations) { ::webrtc::test::PerformanceTimer perf_timer(number_of_tests); gru->Reset(); - for (size_t k = 0; k < number_of_tests; ++k) { + for (int k = 0; k < number_of_tests; ++k) { perf_timer.StartTimer(); - for (size_t i = 0; i < input_sequence_length; ++i) { + for (int i = 0; i < input_sequence_length; ++i) { gru->ComputeOutput( input_sequence.subview(i * gru->input_size(), gru->input_size())); } diff --git a/modules/audio_processing/agc2/rnn_vad/rnn_vad_tool.cc b/modules/audio_processing/agc2/rnn_vad/rnn_vad_tool.cc index c5293bedc7..8b12b60c55 100644 --- a/modules/audio_processing/agc2/rnn_vad/rnn_vad_tool.cc +++ b/modules/audio_processing/agc2/rnn_vad/rnn_vad_tool.cc @@ -20,6 +20,7 @@ #include "modules/audio_processing/agc2/rnn_vad/features_extraction.h" #include "modules/audio_processing/agc2/rnn_vad/rnn.h" #include "rtc_base/logging.h" +#include "rtc_base/numerics/safe_compare.h" ABSL_FLAG(std::string, i, "", "Path to the input wav file"); ABSL_FLAG(std::string, f, "", "Path to the output features file"); @@ -56,7 +57,7 @@ int main(int argc, char* argv[]) { } // Initialize. - const size_t frame_size_10ms = + const int frame_size_10ms = rtc::CheckedDivExact(wav_reader.sample_rate(), 100); std::vector samples_10ms; samples_10ms.resize(frame_size_10ms); @@ -69,9 +70,9 @@ int main(int argc, char* argv[]) { // Compute VAD probabilities. while (true) { // Read frame at the input sample rate. - const auto read_samples = + const size_t read_samples = wav_reader.ReadSamples(frame_size_10ms, samples_10ms.data()); - if (read_samples < frame_size_10ms) { + if (rtc::SafeLt(read_samples, frame_size_10ms)) { break; // EOF. } // Resample input. diff --git a/modules/audio_processing/agc2/rnn_vad/rnn_vad_unittest.cc b/modules/audio_processing/agc2/rnn_vad/rnn_vad_unittest.cc index 8583d4bc1b..0916bf5b81 100644 --- a/modules/audio_processing/agc2/rnn_vad/rnn_vad_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/rnn_vad_unittest.cc @@ -28,10 +28,10 @@ namespace rnn_vad { namespace test { namespace { -constexpr size_t kFrameSize10ms48kHz = 480; +constexpr int kFrameSize10ms48kHz = 480; -void DumpPerfStats(size_t num_samples, - size_t sample_rate, +void DumpPerfStats(int num_samples, + int sample_rate, double average_us, double standard_deviation) { float audio_track_length_ms = @@ -70,7 +70,7 @@ TEST(RnnVadTest, RnnVadProbabilityWithinTolerance) { auto expected_vad_prob_reader = CreateVadProbsReader(); // Input length. - const size_t num_frames = samples_reader.second; + const int num_frames = samples_reader.second; ASSERT_GE(expected_vad_prob_reader.second, num_frames); // Init buffers. @@ -85,7 +85,7 @@ TEST(RnnVadTest, RnnVadProbabilityWithinTolerance) { // Compute VAD probabilities on the downsampled input. float cumulative_error = 0.f; - for (size_t i = 0; i < num_frames; ++i) { + for (int i = 0; i < num_frames; ++i) { samples_reader.first->ReadChunk(samples_48k); decimator.Resample(samples_48k.data(), samples_48k.size(), samples_24k.data(), samples_24k.size()); @@ -114,13 +114,13 @@ TEST(RnnVadTest, RnnVadProbabilityWithinTolerance) { TEST(RnnVadTest, DISABLED_RnnVadPerformance) { // PCM samples reader and buffers. auto samples_reader = CreatePcmSamplesReader(kFrameSize10ms48kHz); - const size_t num_frames = samples_reader.second; + const int num_frames = samples_reader.second; std::array samples; // Pre-fetch and decimate samples. PushSincResampler decimator(kFrameSize10ms48kHz, kFrameSize10ms24kHz); std::vector prefetched_decimated_samples; prefetched_decimated_samples.resize(num_frames * kFrameSize10ms24kHz); - for (size_t i = 0; i < num_frames; ++i) { + for (int i = 0; i < num_frames; ++i) { samples_reader.first->ReadChunk(samples); decimator.Resample(samples.data(), samples.size(), &prefetched_decimated_samples[i * kFrameSize10ms24kHz], @@ -130,14 +130,14 @@ TEST(RnnVadTest, DISABLED_RnnVadPerformance) { FeaturesExtractor features_extractor; std::array feature_vector; RnnBasedVad rnn_vad; - constexpr size_t number_of_tests = 100; + constexpr int number_of_tests = 100; ::webrtc::test::PerformanceTimer perf_timer(number_of_tests); - for (size_t k = 0; k < number_of_tests; ++k) { + for (int k = 0; k < number_of_tests; ++k) { features_extractor.Reset(); rnn_vad.Reset(); // Process frames. perf_timer.StartTimer(); - for (size_t i = 0; i < num_frames; ++i) { + for (int i = 0; i < num_frames; ++i) { bool is_silence = features_extractor.CheckSilenceComputeFeatures( {&prefetched_decimated_samples[i * kFrameSize10ms24kHz], kFrameSize10ms24kHz}, diff --git a/modules/audio_processing/agc2/rnn_vad/sequence_buffer.h b/modules/audio_processing/agc2/rnn_vad/sequence_buffer.h index 75d3d9bc09..a7402788c8 100644 --- a/modules/audio_processing/agc2/rnn_vad/sequence_buffer.h +++ b/modules/audio_processing/agc2/rnn_vad/sequence_buffer.h @@ -29,7 +29,7 @@ namespace rnn_vad { // values are written at the end of the buffer. // The class also provides a view on the most recent M values, where 0 < M <= S // and by default M = N. -template +template class SequenceBuffer { static_assert(N <= S, "The new chunk size cannot be larger than the sequence buffer " @@ -45,8 +45,8 @@ class SequenceBuffer { SequenceBuffer(const SequenceBuffer&) = delete; SequenceBuffer& operator=(const SequenceBuffer&) = delete; ~SequenceBuffer() = default; - size_t size() const { return S; } - size_t chunks_size() const { return N; } + int size() const { return S; } + int chunks_size() const { return N; } // Sets the sequence buffer values to zero. void Reset() { std::fill(buffer_.begin(), buffer_.end(), 0); } // Returns a view on the whole buffer. diff --git a/modules/audio_processing/agc2/rnn_vad/sequence_buffer_unittest.cc b/modules/audio_processing/agc2/rnn_vad/sequence_buffer_unittest.cc index 9b66dcf701..125f1b821c 100644 --- a/modules/audio_processing/agc2/rnn_vad/sequence_buffer_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/sequence_buffer_unittest.cc @@ -20,7 +20,7 @@ namespace rnn_vad { namespace test { namespace { -template +template void TestSequenceBufferPushOp() { SCOPED_TRACE(S); SCOPED_TRACE(N); @@ -32,8 +32,8 @@ void TestSequenceBufferPushOp() { chunk.fill(1); seq_buf.Push(chunk); chunk.fill(0); - constexpr size_t required_push_ops = (S % N) ? S / N + 1 : S / N; - for (size_t i = 0; i < required_push_ops - 1; ++i) { + constexpr int required_push_ops = (S % N) ? S / N + 1 : S / N; + for (int i = 0; i < required_push_ops - 1; ++i) { SCOPED_TRACE(i); seq_buf.Push(chunk); // Still in the buffer. @@ -48,12 +48,12 @@ void TestSequenceBufferPushOp() { // Check that the last item moves left by N positions after a push op. if (S > N) { // Fill in with non-zero values. - for (size_t i = 0; i < N; ++i) + for (int i = 0; i < N; ++i) chunk[i] = static_cast(i + 1); seq_buf.Push(chunk); // With the next Push(), |last| will be moved left by N positions. const T last = chunk[N - 1]; - for (size_t i = 0; i < N; ++i) + for (int i = 0; i < N; ++i) chunk[i] = static_cast(last + i + 1); seq_buf.Push(chunk); EXPECT_EQ(last, seq_buf_view[S - N - 1]); @@ -63,8 +63,8 @@ void TestSequenceBufferPushOp() { } // namespace TEST(RnnVadTest, SequenceBufferGetters) { - constexpr size_t buffer_size = 8; - constexpr size_t chunk_size = 8; + constexpr int buffer_size = 8; + constexpr int chunk_size = 8; SequenceBuffer seq_buf; EXPECT_EQ(buffer_size, seq_buf.size()); EXPECT_EQ(chunk_size, seq_buf.chunks_size()); diff --git a/modules/audio_processing/agc2/rnn_vad/spectral_features.cc b/modules/audio_processing/agc2/rnn_vad/spectral_features.cc index 81e3339d70..96086babb6 100644 --- a/modules/audio_processing/agc2/rnn_vad/spectral_features.cc +++ b/modules/audio_processing/agc2/rnn_vad/spectral_features.cc @@ -16,6 +16,7 @@ #include #include "rtc_base/checks.h" +#include "rtc_base/numerics/safe_compare.h" namespace webrtc { namespace rnn_vad { @@ -32,11 +33,11 @@ void UpdateCepstralDifferenceStats( RTC_DCHECK(sym_matrix_buf); // Compute the new cepstral distance stats. std::array distances; - for (size_t i = 0; i < kCepstralCoeffsHistorySize - 1; ++i) { - const size_t delay = i + 1; + for (int i = 0; i < kCepstralCoeffsHistorySize - 1; ++i) { + const int delay = i + 1; auto old_cepstral_coeffs = ring_buf.GetArrayView(delay); distances[i] = 0.f; - for (size_t k = 0; k < kNumBands; ++k) { + for (int k = 0; k < kNumBands; ++k) { const float c = new_cepstral_coeffs[k] - old_cepstral_coeffs[k]; distances[i] += c * c; } @@ -48,9 +49,9 @@ void UpdateCepstralDifferenceStats( // Computes the first half of the Vorbis window. std::array ComputeScaledHalfVorbisWindow( float scaling = 1.f) { - constexpr size_t kHalfSize = kFrameSize20ms24kHz / 2; + constexpr int kHalfSize = kFrameSize20ms24kHz / 2; std::array half_window{}; - for (size_t i = 0; i < kHalfSize; ++i) { + for (int i = 0; i < kHalfSize; ++i) { half_window[i] = scaling * std::sin(0.5 * kPi * std::sin(0.5 * kPi * (i + 0.5) / kHalfSize) * @@ -71,8 +72,8 @@ void ComputeWindowedForwardFft( RTC_DCHECK_EQ(frame.size(), 2 * half_window.size()); // Apply windowing. auto in = fft_input_buffer->GetView(); - for (size_t i = 0, j = kFrameSize20ms24kHz - 1; i < half_window.size(); - ++i, --j) { + for (int i = 0, j = kFrameSize20ms24kHz - 1; + rtc::SafeLt(i, half_window.size()); ++i, --j) { in[i] = frame[i] * half_window[i]; in[j] = frame[j] * half_window[i]; } @@ -162,7 +163,7 @@ void SpectralFeaturesExtractor::ComputeAvgAndDerivatives( RTC_DCHECK_EQ(average.size(), first_derivative.size()); RTC_DCHECK_EQ(first_derivative.size(), second_derivative.size()); RTC_DCHECK_LE(average.size(), curr.size()); - for (size_t i = 0; i < average.size(); ++i) { + for (int i = 0; rtc::SafeLt(i, average.size()); ++i) { // Average, kernel: [1, 1, 1]. average[i] = curr[i] + prev1[i] + prev2[i]; // First derivative, kernel: [1, 0, - 1]. @@ -178,7 +179,7 @@ void SpectralFeaturesExtractor::ComputeNormalizedCepstralCorrelation( reference_frame_fft_->GetConstView(), lagged_frame_fft_->GetConstView(), bands_cross_corr_); // Normalize. - for (size_t i = 0; i < bands_cross_corr_.size(); ++i) { + for (int i = 0; rtc::SafeLt(i, bands_cross_corr_.size()); ++i) { bands_cross_corr_[i] = bands_cross_corr_[i] / std::sqrt(0.001f + reference_frame_bands_energy_[i] * @@ -194,9 +195,9 @@ void SpectralFeaturesExtractor::ComputeNormalizedCepstralCorrelation( float SpectralFeaturesExtractor::ComputeVariability() const { // Compute cepstral variability score. float variability = 0.f; - for (size_t delay1 = 0; delay1 < kCepstralCoeffsHistorySize; ++delay1) { + for (int delay1 = 0; delay1 < kCepstralCoeffsHistorySize; ++delay1) { float min_dist = std::numeric_limits::max(); - for (size_t delay2 = 0; delay2 < kCepstralCoeffsHistorySize; ++delay2) { + for (int delay2 = 0; delay2 < kCepstralCoeffsHistorySize; ++delay2) { if (delay1 == delay2) // The distance would be 0. continue; min_dist = diff --git a/modules/audio_processing/agc2/rnn_vad/spectral_features_internal.cc b/modules/audio_processing/agc2/rnn_vad/spectral_features_internal.cc index 29192a08f6..91c0086fc4 100644 --- a/modules/audio_processing/agc2/rnn_vad/spectral_features_internal.cc +++ b/modules/audio_processing/agc2/rnn_vad/spectral_features_internal.cc @@ -15,6 +15,7 @@ #include #include "rtc_base/checks.h" +#include "rtc_base/numerics/safe_compare.h" namespace webrtc { namespace rnn_vad { @@ -105,9 +106,9 @@ void SpectralCorrelator::ComputeCrossCorrelation( RTC_DCHECK_EQ(x[1], 0.f) << "The Nyquist coefficient must be zeroed."; RTC_DCHECK_EQ(y[1], 0.f) << "The Nyquist coefficient must be zeroed."; constexpr auto kOpusScaleNumBins24kHz20ms = GetOpusScaleNumBins24kHz20ms(); - size_t k = 0; // Next Fourier coefficient index. + int k = 0; // Next Fourier coefficient index. cross_corr[0] = 0.f; - for (size_t i = 0; i < kOpusBands24kHz - 1; ++i) { + for (int i = 0; i < kOpusBands24kHz - 1; ++i) { cross_corr[i + 1] = 0.f; for (int j = 0; j < kOpusScaleNumBins24kHz20ms[i]; ++j) { // Band size. const float v = x[2 * k] * y[2 * k] + x[2 * k + 1] * y[2 * k + 1]; @@ -137,11 +138,11 @@ void ComputeSmoothedLogMagnitudeSpectrum( return x; }; // Smoothing over the bands for which the band energy is defined. - for (size_t i = 0; i < bands_energy.size(); ++i) { + for (int i = 0; rtc::SafeLt(i, bands_energy.size()); ++i) { log_bands_energy[i] = smooth(std::log10(kOneByHundred + bands_energy[i])); } // Smoothing over the remaining bands (zero energy). - for (size_t i = bands_energy.size(); i < kNumBands; ++i) { + for (int i = bands_energy.size(); i < kNumBands; ++i) { log_bands_energy[i] = smooth(kLogOneByHundred); } } @@ -149,8 +150,8 @@ void ComputeSmoothedLogMagnitudeSpectrum( std::array ComputeDctTable() { std::array dct_table; const double k = std::sqrt(0.5); - for (size_t i = 0; i < kNumBands; ++i) { - for (size_t j = 0; j < kNumBands; ++j) + for (int i = 0; i < kNumBands; ++i) { + for (int j = 0; j < kNumBands; ++j) dct_table[i * kNumBands + j] = std::cos((i + 0.5) * j * kPi / kNumBands); dct_table[i * kNumBands] *= k; } @@ -173,9 +174,9 @@ void ComputeDct(rtc::ArrayView in, RTC_DCHECK_LE(in.size(), kNumBands); RTC_DCHECK_LE(1, out.size()); RTC_DCHECK_LE(out.size(), in.size()); - for (size_t i = 0; i < out.size(); ++i) { + for (int i = 0; rtc::SafeLt(i, out.size()); ++i) { out[i] = 0.f; - for (size_t j = 0; j < in.size(); ++j) { + for (int j = 0; rtc::SafeLt(j, in.size()); ++j) { out[i] += in[j] * dct_table[j * kNumBands + i]; } // TODO(bugs.webrtc.org/10480): Scaling factor in the DCT table. diff --git a/modules/audio_processing/agc2/rnn_vad/spectral_features_internal.h b/modules/audio_processing/agc2/rnn_vad/spectral_features_internal.h index ed4caad025..aa7b1c6a47 100644 --- a/modules/audio_processing/agc2/rnn_vad/spectral_features_internal.h +++ b/modules/audio_processing/agc2/rnn_vad/spectral_features_internal.h @@ -25,7 +25,7 @@ namespace rnn_vad { // At a sample rate of 24 kHz, the last 3 Opus bands are beyond the Nyquist // frequency. However, band #19 gets the contributions from band #18 because // of the symmetric triangular filter with peak response at 12 kHz. -constexpr size_t kOpusBands24kHz = 20; +constexpr int kOpusBands24kHz = 20; static_assert(kOpusBands24kHz < kNumBands, "The number of bands at 24 kHz must be less than those defined " "in the Opus scale at 48 kHz."); diff --git a/modules/audio_processing/agc2/rnn_vad/spectral_features_internal_unittest.cc b/modules/audio_processing/agc2/rnn_vad/spectral_features_internal_unittest.cc index ec81295094..461047d004 100644 --- a/modules/audio_processing/agc2/rnn_vad/spectral_features_internal_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/spectral_features_internal_unittest.cc @@ -19,6 +19,7 @@ #include "api/array_view.h" #include "modules/audio_processing/agc2/rnn_vad/test_utils.h" #include "modules/audio_processing/utility/pffft_wrapper.h" +#include "rtc_base/numerics/safe_compare.h" // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. // #include "test/fpe_observer.h" #include "test/gtest.h" @@ -34,13 +35,13 @@ namespace { std::vector ComputeTriangularFiltersWeights() { constexpr auto kOpusScaleNumBins24kHz20ms = GetOpusScaleNumBins24kHz20ms(); const auto& v = kOpusScaleNumBins24kHz20ms; // Alias. - const size_t num_weights = std::accumulate( - kOpusScaleNumBins24kHz20ms.begin(), kOpusScaleNumBins24kHz20ms.end(), 0); + const int num_weights = std::accumulate(kOpusScaleNumBins24kHz20ms.begin(), + kOpusScaleNumBins24kHz20ms.end(), 0); std::vector weights(num_weights); - size_t next_fft_coeff_index = 0; - for (size_t band = 0; band < v.size(); ++band) { - const size_t band_size = v[band]; - for (size_t j = 0; j < band_size; ++j) { + int next_fft_coeff_index = 0; + for (int band = 0; rtc::SafeLt(band, v.size()); ++band) { + const int band_size = v[band]; + for (int j = 0; rtc::SafeLt(j, band_size); ++j) { weights[next_fft_coeff_index + j] = static_cast(j) / band_size; } next_fft_coeff_index += band_size; @@ -58,7 +59,7 @@ TEST(RnnVadTest, TestOpusScaleBoundaries) { 3200, 4000, 4800, 5600, 6800, 8000, 9600, 12000, 15600, 20000}; constexpr auto kOpusScaleNumBins24kHz20ms = GetOpusScaleNumBins24kHz20ms(); int prev = 0; - for (size_t i = 0; i < kOpusScaleNumBins24kHz20ms.size(); ++i) { + for (int i = 0; rtc::SafeLt(i, kOpusScaleNumBins24kHz20ms.size()); ++i) { int boundary = kBandFrequencyBoundariesHz[i] * kFrameSize20ms24kHz / kSampleRate24kHz; EXPECT_EQ(kOpusScaleNumBins24kHz20ms[i], boundary - prev); @@ -72,8 +73,8 @@ TEST(RnnVadTest, TestOpusScaleBoundaries) { // is updated accordingly. TEST(RnnVadTest, DISABLED_TestOpusScaleWeights) { auto weights = ComputeTriangularFiltersWeights(); - size_t i = 0; - for (size_t band_size : GetOpusScaleNumBins24kHz20ms()) { + int i = 0; + for (int band_size : GetOpusScaleNumBins24kHz20ms()) { SCOPED_TRACE(band_size); rtc::ArrayView band_weights(weights.data() + i, band_size); float prev = -1.f; @@ -98,7 +99,7 @@ TEST(RnnVadTest, SpectralCorrelatorValidOutput) { // Compute and check output. SpectralCorrelator e; e.ComputeAutoCorrelation(in_view, out); - for (size_t i = 0; i < kOpusBands24kHz; ++i) { + for (int i = 0; i < kOpusBands24kHz; ++i) { SCOPED_TRACE(i); EXPECT_GT(out[i], 0.f); } diff --git a/modules/audio_processing/agc2/rnn_vad/spectral_features_unittest.cc b/modules/audio_processing/agc2/rnn_vad/spectral_features_unittest.cc index bc00e2c500..fa376f2a0a 100644 --- a/modules/audio_processing/agc2/rnn_vad/spectral_features_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/spectral_features_unittest.cc @@ -14,6 +14,7 @@ #include "modules/audio_processing/agc2/rnn_vad/test_utils.h" #include "rtc_base/checks.h" +#include "rtc_base/numerics/safe_compare.h" // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. // #include "test/fpe_observer.h" #include "test/gtest.h" @@ -23,11 +24,11 @@ namespace rnn_vad { namespace test { namespace { -constexpr size_t kTestFeatureVectorSize = kNumBands + 3 * kNumLowerBands + 1; +constexpr int kTestFeatureVectorSize = kNumBands + 3 * kNumLowerBands + 1; // Writes non-zero sample values. void WriteTestData(rtc::ArrayView samples) { - for (size_t i = 0; i < samples.size(); ++i) { + for (int i = 0; rtc::SafeLt(i, samples.size()); ++i) { samples[i] = i % 100; } } @@ -124,7 +125,7 @@ TEST(RnnVadTest, CepstralFeaturesConstantAverageZeroDerivative) { // Fill the spectral features with test data. std::array feature_vector; - for (size_t i = 0; i < kCepstralCoeffsHistorySize; ++i) { + for (int i = 0; i < kCepstralCoeffsHistorySize; ++i) { is_silence = sfe.CheckSilenceComputeFeatures( samples_view, samples_view, GetHigherBandsSpectrum(&feature_vector), GetAverage(&feature_vector), GetFirstDerivative(&feature_vector), diff --git a/modules/audio_processing/agc2/rnn_vad/symmetric_matrix_buffer.h b/modules/audio_processing/agc2/rnn_vad/symmetric_matrix_buffer.h index f0282aaed5..dd3b62a1a3 100644 --- a/modules/audio_processing/agc2/rnn_vad/symmetric_matrix_buffer.h +++ b/modules/audio_processing/agc2/rnn_vad/symmetric_matrix_buffer.h @@ -18,6 +18,7 @@ #include "api/array_view.h" #include "rtc_base/checks.h" +#include "rtc_base/numerics/safe_compare.h" namespace webrtc { namespace rnn_vad { @@ -29,7 +30,7 @@ namespace rnn_vad { // removed when one of the two corresponding items that have been compared is // removed from the ring buffer. It is assumed that the comparison is symmetric // and that comparing an item with itself is not needed. -template +template class SymmetricMatrixBuffer { static_assert(S > 2, ""); @@ -55,9 +56,9 @@ class SymmetricMatrixBuffer { // column left. std::memmove(buf_.data(), buf_.data() + S, (buf_.size() - S) * sizeof(T)); // Copy new values in the last column in the right order. - for (size_t i = 0; i < values.size(); ++i) { - const size_t index = (S - 1 - i) * (S - 1) - 1; - RTC_DCHECK_LE(static_cast(0), index); + for (int i = 0; rtc::SafeLt(i, values.size()); ++i) { + const int index = (S - 1 - i) * (S - 1) - 1; + RTC_DCHECK_GE(index, 0); RTC_DCHECK_LT(index, buf_.size()); buf_[index] = values[i]; } @@ -65,9 +66,9 @@ class SymmetricMatrixBuffer { // Reads the value that corresponds to comparison of two items in the ring // buffer having delay |delay1| and |delay2|. The two arguments must not be // equal and both must be in {0, ..., S - 1}. - T GetValue(size_t delay1, size_t delay2) const { - int row = S - 1 - static_cast(delay1); - int col = S - 1 - static_cast(delay2); + T GetValue(int delay1, int delay2) const { + int row = S - 1 - delay1; + int col = S - 1 - delay2; RTC_DCHECK_NE(row, col) << "The diagonal cannot be accessed."; if (row > col) std::swap(row, col); // Swap to access the upper-right triangular part. diff --git a/modules/audio_processing/agc2/rnn_vad/symmetric_matrix_buffer_unittest.cc b/modules/audio_processing/agc2/rnn_vad/symmetric_matrix_buffer_unittest.cc index a1b8007696..c1da8d181b 100644 --- a/modules/audio_processing/agc2/rnn_vad/symmetric_matrix_buffer_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/symmetric_matrix_buffer_unittest.cc @@ -18,10 +18,10 @@ namespace rnn_vad { namespace test { namespace { -template +template void CheckSymmetry(const SymmetricMatrixBuffer* sym_matrix_buf) { - for (size_t row = 0; row < S - 1; ++row) - for (size_t col = row + 1; col < S; ++col) + for (int row = 0; row < S - 1; ++row) + for (int col = row + 1; col < S; ++col) EXPECT_EQ(sym_matrix_buf->GetValue(row, col), sym_matrix_buf->GetValue(col, row)); } @@ -30,12 +30,12 @@ using PairType = std::pair; // Checks that the symmetric matrix buffer contains any pair with a value equal // to the given one. -template +template bool CheckPairsWithValueExist( const SymmetricMatrixBuffer* sym_matrix_buf, const int value) { - for (size_t row = 0; row < S - 1; ++row) { - for (size_t col = row + 1; col < S; ++col) { + for (int row = 0; row < S - 1; ++row) { + for (int col = row + 1; col < S; ++col) { auto p = sym_matrix_buf->GetValue(row, col); if (p.first == value || p.second == value) return true; @@ -52,7 +52,7 @@ bool CheckPairsWithValueExist( TEST(RnnVadTest, SymmetricMatrixBufferUseCase) { // Instance a ring buffer which will be fed with a series of integer values. constexpr int kRingBufSize = 10; - RingBuffer(kRingBufSize)> ring_buf; + RingBuffer ring_buf; // Instance a symmetric matrix buffer for the ring buffer above. It stores // pairs of integers with which this test can easily check that the evolution // of RingBuffer and SymmetricMatrixBuffer match. @@ -81,8 +81,8 @@ TEST(RnnVadTest, SymmetricMatrixBufferUseCase) { CheckSymmetry(&sym_matrix_buf); // Check that the pairs resulting from the content in the ring buffer are // in the right position. - for (size_t delay1 = 0; delay1 < kRingBufSize - 1; ++delay1) { - for (size_t delay2 = delay1 + 1; delay2 < kRingBufSize; ++delay2) { + for (int delay1 = 0; delay1 < kRingBufSize - 1; ++delay1) { + for (int delay2 = delay1 + 1; delay2 < kRingBufSize; ++delay2) { const auto t1 = ring_buf.GetArrayView(delay1)[0]; const auto t2 = ring_buf.GetArrayView(delay2)[0]; ASSERT_LE(t2, t1); @@ -93,7 +93,7 @@ TEST(RnnVadTest, SymmetricMatrixBufferUseCase) { } // Check that every older element in the ring buffer still has a // corresponding pair in the symmetric matrix buffer. - for (size_t delay = 1; delay < kRingBufSize; ++delay) { + for (int delay = 1; delay < kRingBufSize; ++delay) { const auto t_prev = ring_buf.GetArrayView(delay)[0]; EXPECT_TRUE(CheckPairsWithValueExist(&sym_matrix_buf, t_prev)); } diff --git a/modules/audio_processing/agc2/rnn_vad/test_utils.cc b/modules/audio_processing/agc2/rnn_vad/test_utils.cc index c7bf02e74b..74571af640 100644 --- a/modules/audio_processing/agc2/rnn_vad/test_utils.cc +++ b/modules/audio_processing/agc2/rnn_vad/test_utils.cc @@ -13,6 +13,7 @@ #include #include "rtc_base/checks.h" +#include "rtc_base/numerics/safe_compare.h" #include "rtc_base/system/arch.h" #include "system_wrappers/include/cpu_features_wrapper.h" #include "test/gtest.h" @@ -24,7 +25,7 @@ namespace test { namespace { using ReaderPairType = - std::pair>, const size_t>; + std::pair>, const int>; } // namespace @@ -33,7 +34,7 @@ using webrtc::test::ResourcePath; void ExpectEqualFloatArray(rtc::ArrayView expected, rtc::ArrayView computed) { ASSERT_EQ(expected.size(), computed.size()); - for (size_t i = 0; i < expected.size(); ++i) { + for (int i = 0; rtc::SafeLt(i, expected.size()); ++i) { SCOPED_TRACE(i); EXPECT_FLOAT_EQ(expected[i], computed[i]); } @@ -43,14 +44,14 @@ void ExpectNearAbsolute(rtc::ArrayView expected, rtc::ArrayView computed, float tolerance) { ASSERT_EQ(expected.size(), computed.size()); - for (size_t i = 0; i < expected.size(); ++i) { + for (int i = 0; rtc::SafeLt(i, expected.size()); ++i) { SCOPED_TRACE(i); EXPECT_NEAR(expected[i], computed[i], tolerance); } } -std::pair>, const size_t> -CreatePcmSamplesReader(const size_t frame_length) { +std::pair>, const int> +CreatePcmSamplesReader(const int frame_length) { auto ptr = std::make_unique>( test::ResourcePath("audio_processing/agc2/rnn_vad/samples", "pcm"), frame_length); @@ -59,14 +60,14 @@ CreatePcmSamplesReader(const size_t frame_length) { } ReaderPairType CreatePitchBuffer24kHzReader() { - constexpr size_t cols = 864; + constexpr int cols = 864; auto ptr = std::make_unique>( ResourcePath("audio_processing/agc2/rnn_vad/pitch_buf_24k", "dat"), cols); return {std::move(ptr), rtc::CheckedDivExact(ptr->data_length(), cols)}; } ReaderPairType CreateLpResidualAndPitchPeriodGainReader() { - constexpr size_t num_lp_residual_coeffs = 864; + constexpr int num_lp_residual_coeffs = 864; auto ptr = std::make_unique>( ResourcePath("audio_processing/agc2/rnn_vad/pitch_lp_res", "dat"), num_lp_residual_coeffs); @@ -83,7 +84,7 @@ ReaderPairType CreateVadProbsReader() { PitchTestData::PitchTestData() { BinaryFileReader test_data_reader( ResourcePath("audio_processing/agc2/rnn_vad/pitch_search_int", "dat"), - static_cast(1396)); + 1396); test_data_reader.ReadChunk(test_data_); } diff --git a/modules/audio_processing/agc2/rnn_vad/test_utils.h b/modules/audio_processing/agc2/rnn_vad/test_utils.h index db155e6a75..23e642be81 100644 --- a/modules/audio_processing/agc2/rnn_vad/test_utils.h +++ b/modules/audio_processing/agc2/rnn_vad/test_utils.h @@ -24,6 +24,7 @@ #include "api/array_view.h" #include "modules/audio_processing/agc2/rnn_vad/common.h" #include "rtc_base/checks.h" +#include "rtc_base/numerics/safe_compare.h" namespace webrtc { namespace rnn_vad { @@ -47,7 +48,7 @@ void ExpectNearAbsolute(rtc::ArrayView expected, template class BinaryFileReader { public: - explicit BinaryFileReader(const std::string& file_path, size_t chunk_size = 0) + BinaryFileReader(const std::string& file_path, int chunk_size = 0) : is_(file_path, std::ios::binary | std::ios::ate), data_length_(is_.tellg() / sizeof(T)), chunk_size_(chunk_size) { @@ -58,7 +59,7 @@ class BinaryFileReader { BinaryFileReader(const BinaryFileReader&) = delete; BinaryFileReader& operator=(const BinaryFileReader&) = delete; ~BinaryFileReader() = default; - size_t data_length() const { return data_length_; } + int data_length() const { return data_length_; } bool ReadValue(D* dst) { if (std::is_same::value) { is_.read(reinterpret_cast(dst), sizeof(T)); @@ -72,7 +73,7 @@ class BinaryFileReader { // If |chunk_size| was specified in the ctor, it will check that the size of // |dst| equals |chunk_size|. bool ReadChunk(rtc::ArrayView dst) { - RTC_DCHECK((chunk_size_ == 0) || (chunk_size_ == dst.size())); + RTC_DCHECK((chunk_size_ == 0) || rtc::SafeEq(chunk_size_, dst.size())); const std::streamsize bytes_to_read = dst.size() * sizeof(T); if (std::is_same::value) { is_.read(reinterpret_cast(dst.data()), bytes_to_read); @@ -83,13 +84,13 @@ class BinaryFileReader { } return is_.gcount() == bytes_to_read; } - void SeekForward(size_t items) { is_.seekg(items * sizeof(T), is_.cur); } + void SeekForward(int items) { is_.seekg(items * sizeof(T), is_.cur); } void SeekBeginning() { is_.seekg(0, is_.beg); } private: std::ifstream is_; - const size_t data_length_; - const size_t chunk_size_; + const int data_length_; + const int chunk_size_; std::vector buf_; }; @@ -117,22 +118,22 @@ class BinaryFileWriter { // pointer and the second the number of chunks that can be read from the file. // Creates a reader for the PCM samples that casts from S16 to float and reads // chunks with length |frame_length|. -std::pair>, const size_t> -CreatePcmSamplesReader(const size_t frame_length); +std::pair>, const int> +CreatePcmSamplesReader(const int frame_length); // Creates a reader for the pitch buffer content at 24 kHz. -std::pair>, const size_t> +std::pair>, const int> CreatePitchBuffer24kHzReader(); // Creates a reader for the the LP residual coefficients and the pitch period // and gain values. -std::pair>, const size_t> +std::pair>, const int> CreateLpResidualAndPitchPeriodGainReader(); // Creates a reader for the VAD probabilities. -std::pair>, const size_t> +std::pair>, const int> CreateVadProbsReader(); -constexpr size_t kNumPitchBufAutoCorrCoeffs = 147; -constexpr size_t kNumPitchBufSquareEnergies = 385; -constexpr size_t kPitchTestDataSize = +constexpr int kNumPitchBufAutoCorrCoeffs = 147; +constexpr int kNumPitchBufSquareEnergies = 385; +constexpr int kPitchTestDataSize = kBufSize24kHz + kNumPitchBufSquareEnergies + kNumPitchBufAutoCorrCoeffs; // Class to retrieve a test pitch buffer content and the expected output for the