From beb1d34729f794fa57cb70b62f63d706ad69c7e8 Mon Sep 17 00:00:00 2001 From: Alessio Bazzica Date: Mon, 14 May 2018 20:22:18 +0200 Subject: [PATCH] AGC2 RNN VAD: Feature extraction. This CL finalizes the feature extraction part for the RNN VAD adding a class that combines a high-pass filter, LP residual computation, pitch estimation and spectral features extraction. This CL also includes a minor refactoring of the pitch estimation library. Bug: webrtc:9076 Change-Id: I918b9e143bc6dd2bf508a891446067258a68a777 Reviewed-on: https://webrtc-review.googlesource.com/75504 Commit-Queue: Alessio Bazzica Reviewed-by: Alex Loiko Cr-Commit-Position: refs/heads/master@{#23235} --- modules/audio_processing/agc2/BUILD.gn | 2 +- modules/audio_processing/agc2/biquad_filter.h | 6 +- .../agc2/biquad_filter_unittest.cc | 19 ++++ .../audio_processing/agc2/rnn_vad/BUILD.gn | 4 + .../agc2/rnn_vad/features_extraction.cc | 90 +++++++++++++++ .../agc2/rnn_vad/features_extraction.h | 62 +++++++++++ .../rnn_vad/features_extraction_unittest.cc | 103 ++++++++++++++++++ .../agc2/rnn_vad/pitch_search.cc | 40 +++---- .../agc2/rnn_vad/pitch_search.h | 28 ++++- .../agc2/rnn_vad/pitch_search_unittest.cc | 13 +-- .../agc2/rnn_vad/sequence_buffer.h | 43 ++++---- .../agc2/rnn_vad/sequence_buffer_unittest.cc | 1 + 12 files changed, 355 insertions(+), 56 deletions(-) create mode 100644 modules/audio_processing/agc2/rnn_vad/features_extraction.cc create mode 100644 modules/audio_processing/agc2/rnn_vad/features_extraction.h create mode 100644 modules/audio_processing/agc2/rnn_vad/features_extraction_unittest.cc diff --git a/modules/audio_processing/agc2/BUILD.gn b/modules/audio_processing/agc2/BUILD.gn index 9ea1d44a60..e0ed2bb765 100644 --- a/modules/audio_processing/agc2/BUILD.gn +++ b/modules/audio_processing/agc2/BUILD.gn @@ -48,7 +48,7 @@ rtc_source_set("adaptive_digital") { } rtc_source_set("biquad_filter") { - visibility = [ ":*" ] + visibility = [ "./*" ] sources = [ "biquad_filter.cc", "biquad_filter.h", diff --git a/modules/audio_processing/agc2/biquad_filter.h b/modules/audio_processing/agc2/biquad_filter.h index 523d5822d3..284930c595 100644 --- a/modules/audio_processing/agc2/biquad_filter.h +++ b/modules/audio_processing/agc2/biquad_filter.h @@ -34,13 +34,17 @@ class BiQuadFilter { coefficients_ = coefficients; } + void Reset() { biquad_state_.Reset(); } + // Produces a filtered output y of the input x. Both x and y need to // have the same length. In-place modification is allowed. void Process(rtc::ArrayView x, rtc::ArrayView y); private: struct BiQuadState { - BiQuadState() { + BiQuadState() { Reset(); } + + void Reset() { std::fill(b, b + arraysize(b), 0.f); std::fill(a, a + arraysize(a), 0.f); } diff --git a/modules/audio_processing/agc2/biquad_filter_unittest.cc b/modules/audio_processing/agc2/biquad_filter_unittest.cc index 2fa161ea4a..cd9a272787 100644 --- a/modules/audio_processing/agc2/biquad_filter_unittest.cc +++ b/modules/audio_processing/agc2/biquad_filter_unittest.cc @@ -113,5 +113,24 @@ TEST(BiQuadFilterTest, FilterInPlace) { ExpectNearRelative(kBiQuadOutputSeq[i], samples, 2e-4f); } } + +TEST(BiQuadFilterTest, Reset) { + BiQuadFilter filter; + filter.Initialize(kBiQuadConfig); + + std::array samples1; + for (size_t i = 0; i < kNumFrames; ++i) { + filter.Process(kBiQuadInputSeq[i], samples1); + } + + filter.Reset(); + std::array samples2; + for (size_t i = 0; i < kNumFrames; ++i) { + filter.Process(kBiQuadInputSeq[i], samples2); + } + + EXPECT_EQ(samples1, samples2); +} + } // namespace test } // namespace webrtc diff --git a/modules/audio_processing/agc2/rnn_vad/BUILD.gn b/modules/audio_processing/agc2/rnn_vad/BUILD.gn index 83ab7c8dd7..f387ed31b6 100644 --- a/modules/audio_processing/agc2/rnn_vad/BUILD.gn +++ b/modules/audio_processing/agc2/rnn_vad/BUILD.gn @@ -17,6 +17,8 @@ group("rnn_vad") { source_set("lib") { sources = [ "common.h", + "features_extraction.cc", + "features_extraction.h", "fft_util.cc", "fft_util.h", "lp_residual.cc", @@ -37,6 +39,7 @@ source_set("lib") { "symmetric_matrix_buffer.h", ] deps = [ + "..:biquad_filter", "../../../../api:array_view", "../../../../common_audio/", "../../../../rtc_base:checks", @@ -84,6 +87,7 @@ if (rtc_include_tests) { rtc_source_set("unittests") { testonly = true sources = [ + "features_extraction_unittest.cc", "fft_util_unittest.cc", "lp_residual_unittest.cc", "pitch_search_internal_unittest.cc", diff --git a/modules/audio_processing/agc2/rnn_vad/features_extraction.cc b/modules/audio_processing/agc2/rnn_vad/features_extraction.cc new file mode 100644 index 0000000000..01dcae7a3c --- /dev/null +++ b/modules/audio_processing/agc2/rnn_vad/features_extraction.cc @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/agc2/rnn_vad/features_extraction.h" + +#include "modules/audio_processing/agc2/rnn_vad/lp_residual.h" +#include "rtc_base/checks.h" + +namespace webrtc { +namespace rnn_vad { +namespace { + +// Generated via "B, A = scipy.signal.butter(2, 30/12000, btype='highpass')" +const BiQuadFilter::BiQuadCoefficients kHpfConfig24k = { + {0.99446179f, -1.98892358f, 0.99446179f}, + {-1.98889291f, 0.98895425f}}; + +} // namespace + +FeaturesExtractor::FeaturesExtractor() + : use_high_pass_filter_(false), + pitch_buf_24kHz_(), + pitch_buf_24kHz_view_(pitch_buf_24kHz_.GetBufferView()), + lp_residual_(kBufSize24kHz), + lp_residual_view_(lp_residual_.data(), kBufSize24kHz), + pitch_estimator_(), + reference_frame_view_(pitch_buf_24kHz_.GetMostRecentValuesView()) { + RTC_DCHECK_EQ(kBufSize24kHz, lp_residual_.size()); + hpf_.Initialize(kHpfConfig24k); + Reset(); +} + +FeaturesExtractor::~FeaturesExtractor() = default; + +void FeaturesExtractor::Reset() { + pitch_buf_24kHz_.Reset(); + spectral_features_extractor_.Reset(); + if (use_high_pass_filter_) + hpf_.Reset(); +} + +bool FeaturesExtractor::CheckSilenceComputeFeatures( + rtc::ArrayView samples, + rtc::ArrayView feature_vector) { + // Pre-processing. + if (use_high_pass_filter_) { + std::array samples_filtered; + hpf_.Process(samples, samples_filtered); + // Feed buffer with the pre-processed version of |samples|. + pitch_buf_24kHz_.Push({samples_filtered.data(), samples_filtered.size()}); + } else { + // Feed buffer with |samples|. + pitch_buf_24kHz_.Push(samples); + } + // Extract the LP residual. + float lpc_coeffs[kNumLpcCoefficients]; + ComputeAndPostProcessLpcCoefficients(pitch_buf_24kHz_view_, + {lpc_coeffs, kNumLpcCoefficients}); + ComputeLpResidual({lpc_coeffs, kNumLpcCoefficients}, pitch_buf_24kHz_view_, + lp_residual_view_); + // Estimate pitch on the LP-residual and write the normalized pitch period + // into the output vector (normalization based on training data stats). + pitch_info_48kHz_ = pitch_estimator_.Estimate(lp_residual_view_); + feature_vector[kFeatureVectorSize - 2] = + 0.01f * (static_cast(pitch_info_48kHz_.period) - 300); + // Extract lagged frames (according to the estimated pitch period). + RTC_DCHECK_LE(pitch_info_48kHz_.period / 2, kMaxPitch24kHz); + auto lagged_frame = pitch_buf_24kHz_view_.subview( + kMaxPitch24kHz - pitch_info_48kHz_.period / 2, kFrameSize20ms24kHz); + // Analyze reference and lagged frames checking if silence has been detected + // and write the feature vector. + return spectral_features_extractor_.CheckSilenceComputeFeatures( + reference_frame_view_, {lagged_frame.data(), kFrameSize20ms24kHz}, + {{feature_vector.data() + kNumLowerBands, kNumBands - kNumLowerBands}, + {feature_vector.data(), kNumLowerBands}, + {feature_vector.data() + kNumBands, kNumLowerBands}, + {feature_vector.data() + kNumBands + kNumLowerBands, kNumLowerBands}, + {feature_vector.data() + kNumBands + 2 * kNumLowerBands, kNumLowerBands}, + &feature_vector[kFeatureVectorSize - 1]}); +} + +} // namespace rnn_vad +} // namespace webrtc diff --git a/modules/audio_processing/agc2/rnn_vad/features_extraction.h b/modules/audio_processing/agc2/rnn_vad/features_extraction.h new file mode 100644 index 0000000000..1f63885c4e --- /dev/null +++ b/modules/audio_processing/agc2/rnn_vad/features_extraction.h @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_FEATURES_EXTRACTION_H_ +#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_FEATURES_EXTRACTION_H_ + +#include +#include + +#include "api/array_view.h" +#include "modules/audio_processing/agc2/biquad_filter.h" +#include "modules/audio_processing/agc2/rnn_vad/common.h" +#include "modules/audio_processing/agc2/rnn_vad/pitch_search.h" +#include "modules/audio_processing/agc2/rnn_vad/sequence_buffer.h" +#include "modules/audio_processing/agc2/rnn_vad/spectral_features.h" + +namespace webrtc { +namespace rnn_vad { + +// Feature extractor to feed the VAD RNN. +class FeaturesExtractor { + public: + FeaturesExtractor(); + FeaturesExtractor(const FeaturesExtractor&) = delete; + FeaturesExtractor& operator=(const FeaturesExtractor&) = delete; + ~FeaturesExtractor(); + void Reset(); + // Analyzes the samples, computes the feature vector and returns true if + // silence is detected (false if not). When silence is detected, + // |feature_vector| is partially written and therefore must not be used to + // feed the VAD RNN. + bool CheckSilenceComputeFeatures( + rtc::ArrayView samples, + rtc::ArrayView feature_vector); + + private: + const bool use_high_pass_filter_; + // TODO(bugs.webrtc.org/7494): Remove HPF depending on how AGC2 is used in APM + // and on whether an HPF is already used as pre-processing step in APM. + BiQuadFilter hpf_; + SequenceBuffer + pitch_buf_24kHz_; + rtc::ArrayView pitch_buf_24kHz_view_; + std::vector lp_residual_; + rtc::ArrayView lp_residual_view_; + PitchEstimator pitch_estimator_; + rtc::ArrayView reference_frame_view_; + SpectralFeaturesExtractor spectral_features_extractor_; + PitchInfo pitch_info_48kHz_; +}; + +} // namespace rnn_vad +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_FEATURES_EXTRACTION_H_ diff --git a/modules/audio_processing/agc2/rnn_vad/features_extraction_unittest.cc b/modules/audio_processing/agc2/rnn_vad/features_extraction_unittest.cc new file mode 100644 index 0000000000..3405c9080c --- /dev/null +++ b/modules/audio_processing/agc2/rnn_vad/features_extraction_unittest.cc @@ -0,0 +1,103 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/agc2/rnn_vad/features_extraction.h" + +#include +#include + +#include "modules/audio_processing/agc2/rnn_vad/test_utils.h" +// TODO(bugs.webrtc.org/8948): Add when the issue is fixed. +// #include "test/fpe_observer.h" +#include "test/gtest.h" + +namespace webrtc { +namespace rnn_vad { +namespace test { +namespace { + +constexpr size_t ceil(size_t n, size_t m) { + return (n + m - 1) / m; +} + +// Number of 10 ms frames required to fill a pitch buffer having size +// |kBufSize24kHz|. +constexpr size_t kNumTestDataFrames = ceil(kBufSize24kHz, kFrameSize10ms24kHz); +// Number of samples for the test data. +constexpr size_t kNumTestDataSize = kNumTestDataFrames * kFrameSize10ms24kHz; + +// Verifies that the pitch in Hz is in the detectable range. +bool PitchIsValid(float pitch_hz) { + const size_t pitch_period = + static_cast(static_cast(kSampleRate24kHz) / pitch_hz); + return kInitialMinPitch24kHz <= pitch_period && + pitch_period <= kMaxPitch24kHz; +} + +void CreatePureTone(float amplitude, float freq_hz, rtc::ArrayView dst) { + for (size_t i = 0; i < dst.size(); ++i) + dst[i] = amplitude * std::sin(2.f * kPi * freq_hz * i / kSampleRate24kHz); +} + +// Feeds |features_extractor| with |samples| splitting it in 10 ms frames. +// For every frame, the output is written into |feature_vector|. Returns true +// if silence is detected in the last frame. +bool FeedTestData(FeaturesExtractor* features_extractor, + rtc::ArrayView samples, + rtc::ArrayView feature_vector) { + // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. + // FloatingPointExceptionObserver fpe_observer; + bool is_silence = true; + const size_t num_frames = samples.size() / kFrameSize10ms24kHz; + for (size_t i = 0; i < num_frames; ++i) { + is_silence = features_extractor->CheckSilenceComputeFeatures( + {samples.data() + i * kFrameSize10ms24kHz, kFrameSize10ms24kHz}, + feature_vector); + } + return is_silence; +} + +} // namespace + +// Extracts the features for two pure tones and verifies that the pitch field +// values reflect the known tone frequencies. +TEST(RnnVadTest, FeatureExtractionLowHighPitch) { + constexpr float amplitude = 1000.f; + constexpr float low_pitch_hz = 150.f; + constexpr float high_pitch_hz = 250.f; + ASSERT_TRUE(PitchIsValid(low_pitch_hz)); + ASSERT_TRUE(PitchIsValid(high_pitch_hz)); + + FeaturesExtractor features_extractor; + std::vector samples(kNumTestDataSize); + std::vector feature_vector(kFeatureVectorSize); + ASSERT_EQ(kFeatureVectorSize, feature_vector.size()); + rtc::ArrayView feature_vector_view( + feature_vector.data(), kFeatureVectorSize); + + // Extract the normalized scalar feature that is proportional to the estimated + // pitch period. + constexpr size_t pitch_feature_index = kFeatureVectorSize - 2; + // Low frequency tone - i.e., high period. + CreatePureTone(amplitude, low_pitch_hz, samples); + ASSERT_FALSE(FeedTestData(&features_extractor, samples, feature_vector_view)); + float high_pitch_period = feature_vector_view[pitch_feature_index]; + // High frequency tone - i.e., low period. + features_extractor.Reset(); + CreatePureTone(amplitude, high_pitch_hz, samples); + ASSERT_FALSE(FeedTestData(&features_extractor, samples, feature_vector_view)); + float low_pitch_period = feature_vector_view[pitch_feature_index]; + // Check. + EXPECT_LT(low_pitch_period, high_pitch_period); +} + +} // namespace test +} // namespace rnn_vad +} // namespace webrtc diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search.cc index 28959dda78..7596065299 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search.cc +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search.cc @@ -9,32 +9,33 @@ */ #include "modules/audio_processing/agc2/rnn_vad/pitch_search.h" -#include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h" namespace webrtc { namespace rnn_vad { -// TODO(bugs.webrtc.org/9076): To decrease the stack size, add a class that uses -// std::vector instances instead of the local arrays used in PitchSearch(). It -// is also useful once https://webrtc-review.googlesource.com/c/src/+/73366 -// lands. -PitchInfo PitchSearch(rtc::ArrayView pitch_buf, - PitchInfo prev_pitch_48kHz, - RealFourier* fft) { +PitchEstimator::PitchEstimator() + : fft_(RealFourier::Create(kAutoCorrelationFftOrder)), + pitch_buf_decimated_(kBufSize12kHz), + pitch_buf_decimated_view_(pitch_buf_decimated_.data(), kBufSize12kHz), + auto_corr_(kNumInvertedLags12kHz), + auto_corr_view_(auto_corr_.data(), kNumInvertedLags12kHz) { + RTC_DCHECK_EQ(kBufSize12kHz, pitch_buf_decimated_.size()); + RTC_DCHECK_EQ(kNumInvertedLags12kHz, auto_corr_view_.size()); +} + +PitchEstimator::~PitchEstimator() = default; + +PitchInfo PitchEstimator::Estimate( + rtc::ArrayView pitch_buf) { // Perform the initial pitch search at 12 kHz. - std::array pitch_buf_decimated; - Decimate2x(pitch_buf, - {pitch_buf_decimated.data(), pitch_buf_decimated.size()}); + Decimate2x(pitch_buf, pitch_buf_decimated_view_); // Compute auto-correlation terms. - std::array auto_corr; - ComputePitchAutoCorrelation( - {pitch_buf_decimated.data(), pitch_buf_decimated.size()}, kMaxPitch12kHz, - {auto_corr.data(), auto_corr.size()}, fft); + ComputePitchAutoCorrelation(pitch_buf_decimated_view_, kMaxPitch12kHz, + auto_corr_view_, fft_.get()); // Search for pitch at 12 kHz. std::array pitch_candidates_inv_lags = FindBestPitchPeriods( - {auto_corr.data(), auto_corr.size()}, - {pitch_buf_decimated.data(), pitch_buf_decimated.size()}, kMaxPitch12kHz); + auto_corr_view_, pitch_buf_decimated_view_, kMaxPitch12kHz); // Refine the pitch period estimation. // The refinement is done using the pitch buffer that contains 24 kHz samples. @@ -47,8 +48,9 @@ PitchInfo PitchSearch(rtc::ArrayView pitch_buf, {pitch_candidates_inv_lags.data(), pitch_candidates_inv_lags.size()}); // Look for stronger harmonics to find the final pitch period and its gain. RTC_DCHECK_LT(pitch_inv_lag_48kHz, kMaxPitch48kHz); - return CheckLowerPitchPeriodsAndComputePitchGain( - pitch_buf, kMaxPitch48kHz - pitch_inv_lag_48kHz, prev_pitch_48kHz); + last_pitch_48kHz_ = CheckLowerPitchPeriodsAndComputePitchGain( + pitch_buf, kMaxPitch48kHz - pitch_inv_lag_48kHz, last_pitch_48kHz_); + return last_pitch_48kHz_; } } // namespace rnn_vad diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search.h b/modules/audio_processing/agc2/rnn_vad/pitch_search.h index 21e7a05b9e..59145353c1 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search.h +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search.h @@ -11,19 +11,37 @@ #ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_SEARCH_H_ #define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_SEARCH_H_ +#include +#include + #include "api/array_view.h" #include "common_audio/real_fourier.h" #include "modules/audio_processing/agc2/rnn_vad/common.h" #include "modules/audio_processing/agc2/rnn_vad/pitch_info.h" +#include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h" namespace webrtc { namespace rnn_vad { -// Searches the pitch period and gain. Return the pitch estimation data for -// 48 kHz. -PitchInfo PitchSearch(rtc::ArrayView pitch_buf, - PitchInfo prev_pitch_48kHz, - RealFourier* fft); +// Pitch estimator. +class PitchEstimator { + public: + PitchEstimator(); + PitchEstimator(const PitchEstimator&) = delete; + PitchEstimator& operator=(const PitchEstimator&) = delete; + ~PitchEstimator(); + // Estimates the pitch period and gain. Returns the pitch estimation data for + // 48 kHz. + PitchInfo Estimate(rtc::ArrayView pitch_buf); + + private: + PitchInfo last_pitch_48kHz_; + std::unique_ptr fft_; + std::vector pitch_buf_decimated_; + rtc::ArrayView pitch_buf_decimated_view_; + std::vector auto_corr_; + rtc::ArrayView auto_corr_view_; +}; } // namespace rnn_vad } // namespace webrtc diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc index b25aba393e..9e69b25393 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc @@ -9,6 +9,7 @@ */ #include "modules/audio_processing/agc2/rnn_vad/pitch_search.h" +#include "modules/audio_processing/agc2/rnn_vad/pitch_info.h" #include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h" #include @@ -28,9 +29,7 @@ TEST(RnnVadTest, PitchSearchBitExactness) { const size_t num_frames = lp_residual_reader.second; std::array lp_residual; float expected_pitch_period, expected_pitch_gain; - PitchInfo last_pitch; - std::unique_ptr fft = - RealFourier::Create(kAutoCorrelationFftOrder); + PitchEstimator pitch_estimator; { // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. // FloatingPointExceptionObserver fpe_observer; @@ -41,10 +40,10 @@ TEST(RnnVadTest, PitchSearchBitExactness) { {lp_residual.data(), lp_residual.size()}); lp_residual_reader.first->ReadValue(&expected_pitch_period); lp_residual_reader.first->ReadValue(&expected_pitch_gain); - last_pitch = PitchSearch({lp_residual.data(), lp_residual.size()}, - last_pitch, fft.get()); - EXPECT_EQ(static_cast(expected_pitch_period), last_pitch.period); - EXPECT_NEAR(expected_pitch_gain, last_pitch.gain, 1e-5f); + PitchInfo pitch_info = + pitch_estimator.Estimate({lp_residual.data(), lp_residual.size()}); + EXPECT_EQ(static_cast(expected_pitch_period), pitch_info.period); + EXPECT_NEAR(expected_pitch_gain, pitch_info.gain, 1e-5f); } } } diff --git a/modules/audio_processing/agc2/rnn_vad/sequence_buffer.h b/modules/audio_processing/agc2/rnn_vad/sequence_buffer.h index df92e23b81..75d3d9bc09 100644 --- a/modules/audio_processing/agc2/rnn_vad/sequence_buffer.h +++ b/modules/audio_processing/agc2/rnn_vad/sequence_buffer.h @@ -11,9 +11,10 @@ #ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SEQUENCE_BUFFER_H_ #define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SEQUENCE_BUFFER_H_ -#include +#include #include #include +#include #include "api/array_view.h" #include "rtc_base/checks.h" @@ -26,41 +27,38 @@ namespace rnn_vad { // chunks have size S and N respectively. For instance, when S = 2N the first // half of the sequence buffer is replaced with its second half, and the new N // values are written at the end of the buffer. -template +// The class also provides a view on the most recent M values, where 0 < M <= S +// and by default M = N. +template class SequenceBuffer { - static_assert(S >= N, - "The new chunk size is larger than the sequence buffer size."); + static_assert(N <= S, + "The new chunk size cannot be larger than the sequence buffer " + "size."); static_assert(std::is_arithmetic::value, "Integral or floating point required."); public: - SequenceBuffer() { buffer_.fill(0); } + SequenceBuffer() : buffer_(S) { + RTC_DCHECK_EQ(S, buffer_.size()); + Reset(); + } SequenceBuffer(const SequenceBuffer&) = delete; SequenceBuffer& operator=(const SequenceBuffer&) = delete; ~SequenceBuffer() = default; size_t size() const { return S; } size_t chunks_size() const { return N; } // Sets the sequence buffer values to zero. - void Reset() { buffer_.fill(0); } + void Reset() { std::fill(buffer_.begin(), buffer_.end(), 0); } // Returns a view on the whole buffer. rtc::ArrayView GetBufferView() const { return {buffer_.data(), S}; } - // Returns a view on part of the buffer; the first element starts at the given - // offset and the last one is the last one in the buffer. - rtc::ArrayView GetBufferView(int offset) const { - RTC_DCHECK_LE(0, offset); - RTC_DCHECK_LT(offset, S); - return {buffer_.data() + offset, S - offset}; - } - // Returns a view on part of the buffer; the first element starts at the given - // offset and the size of the view is |size|. - rtc::ArrayView GetBufferView(int offset, size_t size) const { - RTC_DCHECK_LE(0, offset); - RTC_DCHECK_LT(offset, S); - RTC_DCHECK_LT(0, size); - RTC_DCHECK_LE(size, S - offset); - return {buffer_.data() + offset, size}; + // Returns a view on the M most recent values of the buffer. + rtc::ArrayView GetMostRecentValuesView() const { + static_assert(M <= S, + "The number of most recent values cannot be larger than the " + "sequence buffer size."); + return {buffer_.data() + S - M, M}; } // Shifts left the buffer by N items and add new N items at the end. void Push(rtc::ArrayView new_values) { @@ -72,8 +70,7 @@ class SequenceBuffer { } private: - // TODO(bugs.webrtc.org/9076): Switch to std::vector to decrease stack size. - std::array buffer_; + std::vector buffer_; }; } // namespace rnn_vad diff --git a/modules/audio_processing/agc2/rnn_vad/sequence_buffer_unittest.cc b/modules/audio_processing/agc2/rnn_vad/sequence_buffer_unittest.cc index 7628c17689..900941b678 100644 --- a/modules/audio_processing/agc2/rnn_vad/sequence_buffer_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/sequence_buffer_unittest.cc @@ -11,6 +11,7 @@ #include "modules/audio_processing/agc2/rnn_vad/sequence_buffer.h" #include +#include #include "test/gtest.h"