AGC2 RNN VAD: Feature extraction.

This CL finalizes the feature extraction part of the RNN VAD by adding a class that combines a high-pass filter, LP residual computation, pitch estimation and spectral feature extraction. This CL also includes a minor refactoring of the pitch estimation library.

Bug: webrtc:9076
Change-Id: I918b9e143bc6dd2bf508a891446067258a68a777
Reviewed-on: https://webrtc-review.googlesource.com/75504
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Reviewed-by: Alex Loiko <aleloi@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#23235}
This commit is contained in:
parent
bfe3d854cd
commit
beb1d34729
@ -48,7 +48,7 @@ rtc_source_set("adaptive_digital") {
|
||||
}
|
||||
|
||||
rtc_source_set("biquad_filter") {
|
||||
visibility = [ ":*" ]
|
||||
visibility = [ "./*" ]
|
||||
sources = [
|
||||
"biquad_filter.cc",
|
||||
"biquad_filter.h",
|
||||
|
||||
@ -34,13 +34,17 @@ class BiQuadFilter {
|
||||
coefficients_ = coefficients;
|
||||
}
|
||||
|
||||
void Reset() { biquad_state_.Reset(); }
|
||||
|
||||
// Produces a filtered output y of the input x. Both x and y need to
|
||||
// have the same length. In-place modification is allowed.
|
||||
void Process(rtc::ArrayView<const float> x, rtc::ArrayView<float> y);
|
||||
|
||||
private:
|
||||
struct BiQuadState {
|
||||
BiQuadState() {
|
||||
BiQuadState() { Reset(); }
|
||||
|
||||
void Reset() {
|
||||
std::fill(b, b + arraysize(b), 0.f);
|
||||
std::fill(a, a + arraysize(a), 0.f);
|
||||
}
|
||||
|
||||
@ -113,5 +113,24 @@ TEST(BiQuadFilterTest, FilterInPlace) {
|
||||
ExpectNearRelative(kBiQuadOutputSeq[i], samples, 2e-4f);
|
||||
}
|
||||
}
|
||||
|
||||
// Checks that Reset() restores the initial filter state: processing the same
// input sequence before and after a reset must leave identical samples in the
// output frame.
TEST(BiQuadFilterTest, Reset) {
  BiQuadFilter filter;
  filter.Initialize(kBiQuadConfig);

  // First pass over the input sequence; only the last output frame is kept.
  std::array<float, kFrameSize> output_before_reset;
  for (size_t frame = 0; frame < kNumFrames; ++frame) {
    filter.Process(kBiQuadInputSeq[frame], output_before_reset);
  }

  filter.Reset();

  // Second pass over the same input, starting from the reset state.
  std::array<float, kFrameSize> output_after_reset;
  for (size_t frame = 0; frame < kNumFrames; ++frame) {
    filter.Process(kBiQuadInputSeq[frame], output_after_reset);
  }

  EXPECT_EQ(output_before_reset, output_after_reset);
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace webrtc
|
||||
|
||||
@ -17,6 +17,8 @@ group("rnn_vad") {
|
||||
source_set("lib") {
|
||||
sources = [
|
||||
"common.h",
|
||||
"features_extraction.cc",
|
||||
"features_extraction.h",
|
||||
"fft_util.cc",
|
||||
"fft_util.h",
|
||||
"lp_residual.cc",
|
||||
@ -37,6 +39,7 @@ source_set("lib") {
|
||||
"symmetric_matrix_buffer.h",
|
||||
]
|
||||
deps = [
|
||||
"..:biquad_filter",
|
||||
"../../../../api:array_view",
|
||||
"../../../../common_audio/",
|
||||
"../../../../rtc_base:checks",
|
||||
@ -84,6 +87,7 @@ if (rtc_include_tests) {
|
||||
rtc_source_set("unittests") {
|
||||
testonly = true
|
||||
sources = [
|
||||
"features_extraction_unittest.cc",
|
||||
"fft_util_unittest.cc",
|
||||
"lp_residual_unittest.cc",
|
||||
"pitch_search_internal_unittest.cc",
|
||||
|
||||
90
modules/audio_processing/agc2/rnn_vad/features_extraction.cc
Normal file
90
modules/audio_processing/agc2/rnn_vad/features_extraction.cc
Normal file
@ -0,0 +1,90 @@
|
||||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/rnn_vad/features_extraction.h"
|
||||
|
||||
#include "modules/audio_processing/agc2/rnn_vad/lp_residual.h"
|
||||
#include "rtc_base/checks.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
namespace {
|
||||
|
||||
// Generated via "B, A = scipy.signal.butter(2, 30/12000, btype='highpass')"
|
||||
const BiQuadFilter::BiQuadCoefficients kHpfConfig24k = {
|
||||
{0.99446179f, -1.98892358f, 0.99446179f},
|
||||
{-1.98889291f, 0.98895425f}};
|
||||
|
||||
} // namespace
|
||||
|
||||
// Sets up the buffers and the fixed-size views on them, configures the
// high-pass filter and clears the internal state. The high-pass filter is
// configured but disabled (|use_high_pass_filter_| is false).
FeaturesExtractor::FeaturesExtractor()
    : use_high_pass_filter_(false),
      // View on the whole pitch buffer.
      pitch_buf_24kHz_view_(pitch_buf_24kHz_.GetBufferView()),
      lp_residual_(kBufSize24kHz),
      lp_residual_view_(lp_residual_.data(), kBufSize24kHz),
      // View on the most recent values of the pitch buffer.
      reference_frame_view_(pitch_buf_24kHz_.GetMostRecentValuesView()) {
  // |lp_residual_view_| assumes that the backing vector holds exactly
  // |kBufSize24kHz| items.
  RTC_DCHECK_EQ(kBufSize24kHz, lp_residual_.size());
  hpf_.Initialize(kHpfConfig24k);
  Reset();
}
|
||||
|
||||
FeaturesExtractor::~FeaturesExtractor() = default;
|
||||
|
||||
void FeaturesExtractor::Reset() {
|
||||
pitch_buf_24kHz_.Reset();
|
||||
spectral_features_extractor_.Reset();
|
||||
if (use_high_pass_filter_)
|
||||
hpf_.Reset();
|
||||
}
|
||||
|
||||
// Pushes |samples| into the pitch buffer (optionally high-pass filtered),
// computes the LP residual, estimates the pitch on it and delegates the silence
// check and the spectral features computation to
// |spectral_features_extractor_|, whose result is returned.
// The normalized pitch period is written at index |kFeatureVectorSize| - 2;
// the remaining items of |feature_vector| are handed to the spectral features
// extractor as output sub-views.
bool FeaturesExtractor::CheckSilenceComputeFeatures(
    rtc::ArrayView<const float, kFrameSize10ms24kHz> samples,
    rtc::ArrayView<float, kFeatureVectorSize> feature_vector) {
  // Pre-processing: feed the pitch buffer either with |samples| as they are or
  // with their high-pass filtered version.
  if (!use_high_pass_filter_) {
    pitch_buf_24kHz_.Push(samples);
  } else {
    std::array<float, kFrameSize10ms24kHz> filtered_frame;
    hpf_.Process(samples, filtered_frame);
    pitch_buf_24kHz_.Push({filtered_frame.data(), filtered_frame.size()});
  }

  // LP residual of the whole pitch buffer.
  float lpc[kNumLpcCoefficients];
  ComputeAndPostProcessLpcCoefficients(pitch_buf_24kHz_view_,
                                       {lpc, kNumLpcCoefficients});
  ComputeLpResidual({lpc, kNumLpcCoefficients}, pitch_buf_24kHz_view_,
                    lp_residual_view_);

  // Pitch estimation on the LP residual. The pitch period is normalized
  // (normalization based on training data stats) before being written into
  // the output vector.
  pitch_info_48kHz_ = pitch_estimator_.Estimate(lp_residual_view_);
  const int pitch_period_48kHz = static_cast<int>(pitch_info_48kHz_.period);
  feature_vector[kFeatureVectorSize - 2] = 0.01f * (pitch_period_48kHz - 300);

  // Lagged frame: the frame that starts one (24 kHz) pitch period before the
  // position |kMaxPitch24kHz| in the pitch buffer.
  const auto pitch_period_24kHz = pitch_info_48kHz_.period / 2;
  RTC_DCHECK_LE(pitch_period_24kHz, kMaxPitch24kHz);
  auto lagged_frame = pitch_buf_24kHz_view_.subview(
      kMaxPitch24kHz - pitch_period_24kHz, kFrameSize20ms24kHz);

  // Analyze the reference and the lagged frames: detect silence and write the
  // remaining part of the feature vector.
  float* const features = feature_vector.data();
  return spectral_features_extractor_.CheckSilenceComputeFeatures(
      reference_frame_view_, {lagged_frame.data(), kFrameSize20ms24kHz},
      {{features + kNumLowerBands, kNumBands - kNumLowerBands},
       {features, kNumLowerBands},
       {features + kNumBands, kNumLowerBands},
       {features + kNumBands + kNumLowerBands, kNumLowerBands},
       {features + kNumBands + 2 * kNumLowerBands, kNumLowerBands},
       &feature_vector[kFeatureVectorSize - 1]});
}
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
62
modules/audio_processing/agc2/rnn_vad/features_extraction.h
Normal file
62
modules/audio_processing/agc2/rnn_vad/features_extraction.h
Normal file
@ -0,0 +1,62 @@
|
||||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_FEATURES_EXTRACTION_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_FEATURES_EXTRACTION_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "api/array_view.h"
|
||||
#include "modules/audio_processing/agc2/biquad_filter.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/common.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/pitch_search.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/sequence_buffer.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/spectral_features.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
|
||||
// Feature extractor to feed the VAD RNN.
|
||||
class FeaturesExtractor {
|
||||
public:
|
||||
FeaturesExtractor();
|
||||
FeaturesExtractor(const FeaturesExtractor&) = delete;
|
||||
FeaturesExtractor& operator=(const FeaturesExtractor&) = delete;
|
||||
~FeaturesExtractor();
|
||||
void Reset();
|
||||
// Analyzes the samples, computes the feature vector and returns true if
|
||||
// silence is detected (false if not). When silence is detected,
|
||||
// |feature_vector| is partially written and therefore must not be used to
|
||||
// feed the VAD RNN.
|
||||
bool CheckSilenceComputeFeatures(
|
||||
rtc::ArrayView<const float, kFrameSize10ms24kHz> samples,
|
||||
rtc::ArrayView<float, kFeatureVectorSize> feature_vector);
|
||||
|
||||
private:
|
||||
const bool use_high_pass_filter_;
|
||||
// TODO(bugs.webrtc.org/7494): Remove HPF depending on how AGC2 is used in APM
|
||||
// and on whether an HPF is already used as pre-processing step in APM.
|
||||
BiQuadFilter hpf_;
|
||||
SequenceBuffer<float, kBufSize24kHz, kFrameSize10ms24kHz, kFrameSize20ms24kHz>
|
||||
pitch_buf_24kHz_;
|
||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buf_24kHz_view_;
|
||||
std::vector<float> lp_residual_;
|
||||
rtc::ArrayView<float, kBufSize24kHz> lp_residual_view_;
|
||||
PitchEstimator pitch_estimator_;
|
||||
rtc::ArrayView<const float, kFrameSize20ms24kHz> reference_frame_view_;
|
||||
SpectralFeaturesExtractor spectral_features_extractor_;
|
||||
PitchInfo pitch_info_48kHz_;
|
||||
};
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_FEATURES_EXTRACTION_H_
|
||||
@ -0,0 +1,103 @@
|
||||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/rnn_vad/features_extraction.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <vector>
|
||||
|
||||
#include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
|
||||
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
|
||||
// #include "test/fpe_observer.h"
|
||||
#include "test/gtest.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
namespace test {
|
||||
namespace {
|
||||
|
||||
// Returns |n| divided by |m| rounded up to the next integer. |m| must not be
// zero. The result is built from quotient and remainder instead of
// (n + m - 1) / m, since the latter wraps around when |n| is close to the
// maximum value of size_t.
constexpr size_t ceil(size_t n, size_t m) {
  return n / m + (n % m != 0 ? 1 : 0);
}
|
||||
|
||||
// Number of 10 ms frames required to fill a pitch buffer having size
|
||||
// |kBufSize24kHz|.
|
||||
constexpr size_t kNumTestDataFrames = ceil(kBufSize24kHz, kFrameSize10ms24kHz);
|
||||
// Number of samples for the test data.
|
||||
constexpr size_t kNumTestDataSize = kNumTestDataFrames * kFrameSize10ms24kHz;
|
||||
|
||||
// Verifies that the pitch in Hz is in the detectable range.
|
||||
bool PitchIsValid(float pitch_hz) {
|
||||
const size_t pitch_period =
|
||||
static_cast<size_t>(static_cast<float>(kSampleRate24kHz) / pitch_hz);
|
||||
return kInitialMinPitch24kHz <= pitch_period &&
|
||||
pitch_period <= kMaxPitch24kHz;
|
||||
}
|
||||
|
||||
void CreatePureTone(float amplitude, float freq_hz, rtc::ArrayView<float> dst) {
|
||||
for (size_t i = 0; i < dst.size(); ++i)
|
||||
dst[i] = amplitude * std::sin(2.f * kPi * freq_hz * i / kSampleRate24kHz);
|
||||
}
|
||||
|
||||
// Feeds |features_extractor| with |samples| splitting it in 10 ms frames.
|
||||
// For every frame, the output is written into |feature_vector|. Returns true
|
||||
// if silence is detected in the last frame.
|
||||
bool FeedTestData(FeaturesExtractor* features_extractor,
|
||||
rtc::ArrayView<const float> samples,
|
||||
rtc::ArrayView<float, kFeatureVectorSize> feature_vector) {
|
||||
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
|
||||
// FloatingPointExceptionObserver fpe_observer;
|
||||
bool is_silence = true;
|
||||
const size_t num_frames = samples.size() / kFrameSize10ms24kHz;
|
||||
for (size_t i = 0; i < num_frames; ++i) {
|
||||
is_silence = features_extractor->CheckSilenceComputeFeatures(
|
||||
{samples.data() + i * kFrameSize10ms24kHz, kFrameSize10ms24kHz},
|
||||
feature_vector);
|
||||
}
|
||||
return is_silence;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Extracts the features for two pure tones and verifies that the pitch field
|
||||
// values reflect the known tone frequencies.
|
||||
TEST(RnnVadTest, FeatureExtractionLowHighPitch) {
|
||||
constexpr float amplitude = 1000.f;
|
||||
constexpr float low_pitch_hz = 150.f;
|
||||
constexpr float high_pitch_hz = 250.f;
|
||||
ASSERT_TRUE(PitchIsValid(low_pitch_hz));
|
||||
ASSERT_TRUE(PitchIsValid(high_pitch_hz));
|
||||
|
||||
FeaturesExtractor features_extractor;
|
||||
std::vector<float> samples(kNumTestDataSize);
|
||||
std::vector<float> feature_vector(kFeatureVectorSize);
|
||||
ASSERT_EQ(kFeatureVectorSize, feature_vector.size());
|
||||
rtc::ArrayView<float, kFeatureVectorSize> feature_vector_view(
|
||||
feature_vector.data(), kFeatureVectorSize);
|
||||
|
||||
// Extract the normalized scalar feature that is proportional to the estimated
|
||||
// pitch period.
|
||||
constexpr size_t pitch_feature_index = kFeatureVectorSize - 2;
|
||||
// Low frequency tone - i.e., high period.
|
||||
CreatePureTone(amplitude, low_pitch_hz, samples);
|
||||
ASSERT_FALSE(FeedTestData(&features_extractor, samples, feature_vector_view));
|
||||
float high_pitch_period = feature_vector_view[pitch_feature_index];
|
||||
// High frequency tone - i.e., low period.
|
||||
features_extractor.Reset();
|
||||
CreatePureTone(amplitude, high_pitch_hz, samples);
|
||||
ASSERT_FALSE(FeedTestData(&features_extractor, samples, feature_vector_view));
|
||||
float low_pitch_period = feature_vector_view[pitch_feature_index];
|
||||
// Check.
|
||||
EXPECT_LT(low_pitch_period, high_pitch_period);
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
@ -9,32 +9,33 @@
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/rnn_vad/pitch_search.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
|
||||
// TODO(bugs.webrtc.org/9076): To decrease the stack size, add a class that uses
|
||||
// std::vector instances instead of the local arrays used in PitchSearch(). It
|
||||
// is also useful once https://webrtc-review.googlesource.com/c/src/+/73366
|
||||
// lands.
|
||||
PitchInfo PitchSearch(rtc::ArrayView<const float, kBufSize24kHz> pitch_buf,
|
||||
PitchInfo prev_pitch_48kHz,
|
||||
RealFourier* fft) {
|
||||
PitchEstimator::PitchEstimator()
|
||||
: fft_(RealFourier::Create(kAutoCorrelationFftOrder)),
|
||||
pitch_buf_decimated_(kBufSize12kHz),
|
||||
pitch_buf_decimated_view_(pitch_buf_decimated_.data(), kBufSize12kHz),
|
||||
auto_corr_(kNumInvertedLags12kHz),
|
||||
auto_corr_view_(auto_corr_.data(), kNumInvertedLags12kHz) {
|
||||
RTC_DCHECK_EQ(kBufSize12kHz, pitch_buf_decimated_.size());
|
||||
RTC_DCHECK_EQ(kNumInvertedLags12kHz, auto_corr_view_.size());
|
||||
}
|
||||
|
||||
PitchEstimator::~PitchEstimator() = default;
|
||||
|
||||
PitchInfo PitchEstimator::Estimate(
|
||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buf) {
|
||||
// Perform the initial pitch search at 12 kHz.
|
||||
std::array<float, kBufSize12kHz> pitch_buf_decimated;
|
||||
Decimate2x(pitch_buf,
|
||||
{pitch_buf_decimated.data(), pitch_buf_decimated.size()});
|
||||
Decimate2x(pitch_buf, pitch_buf_decimated_view_);
|
||||
// Compute auto-correlation terms.
|
||||
std::array<float, kNumInvertedLags12kHz> auto_corr;
|
||||
ComputePitchAutoCorrelation(
|
||||
{pitch_buf_decimated.data(), pitch_buf_decimated.size()}, kMaxPitch12kHz,
|
||||
{auto_corr.data(), auto_corr.size()}, fft);
|
||||
ComputePitchAutoCorrelation(pitch_buf_decimated_view_, kMaxPitch12kHz,
|
||||
auto_corr_view_, fft_.get());
|
||||
|
||||
// Search for pitch at 12 kHz.
|
||||
std::array<size_t, 2> pitch_candidates_inv_lags = FindBestPitchPeriods(
|
||||
{auto_corr.data(), auto_corr.size()},
|
||||
{pitch_buf_decimated.data(), pitch_buf_decimated.size()}, kMaxPitch12kHz);
|
||||
auto_corr_view_, pitch_buf_decimated_view_, kMaxPitch12kHz);
|
||||
|
||||
// Refine the pitch period estimation.
|
||||
// The refinement is done using the pitch buffer that contains 24 kHz samples.
|
||||
@ -47,8 +48,9 @@ PitchInfo PitchSearch(rtc::ArrayView<const float, kBufSize24kHz> pitch_buf,
|
||||
{pitch_candidates_inv_lags.data(), pitch_candidates_inv_lags.size()});
|
||||
// Look for stronger harmonics to find the final pitch period and its gain.
|
||||
RTC_DCHECK_LT(pitch_inv_lag_48kHz, kMaxPitch48kHz);
|
||||
return CheckLowerPitchPeriodsAndComputePitchGain(
|
||||
pitch_buf, kMaxPitch48kHz - pitch_inv_lag_48kHz, prev_pitch_48kHz);
|
||||
last_pitch_48kHz_ = CheckLowerPitchPeriodsAndComputePitchGain(
|
||||
pitch_buf, kMaxPitch48kHz - pitch_inv_lag_48kHz, last_pitch_48kHz_);
|
||||
return last_pitch_48kHz_;
|
||||
}
|
||||
|
||||
} // namespace rnn_vad
|
||||
|
||||
@ -11,19 +11,37 @@
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_SEARCH_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_SEARCH_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "api/array_view.h"
|
||||
#include "common_audio/real_fourier.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/common.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/pitch_info.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
|
||||
// Searches the pitch period and gain. Return the pitch estimation data for
|
||||
// 48 kHz.
|
||||
PitchInfo PitchSearch(rtc::ArrayView<const float, kBufSize24kHz> pitch_buf,
|
||||
PitchInfo prev_pitch_48kHz,
|
||||
RealFourier* fft);
|
||||
// Pitch estimator.
|
||||
class PitchEstimator {
|
||||
public:
|
||||
PitchEstimator();
|
||||
PitchEstimator(const PitchEstimator&) = delete;
|
||||
PitchEstimator& operator=(const PitchEstimator&) = delete;
|
||||
~PitchEstimator();
|
||||
// Estimates the pitch period and gain. Returns the pitch estimation data for
|
||||
// 48 kHz.
|
||||
PitchInfo Estimate(rtc::ArrayView<const float, kBufSize24kHz> pitch_buf);
|
||||
|
||||
private:
|
||||
PitchInfo last_pitch_48kHz_;
|
||||
std::unique_ptr<RealFourier> fft_;
|
||||
std::vector<float> pitch_buf_decimated_;
|
||||
rtc::ArrayView<float, kBufSize12kHz> pitch_buf_decimated_view_;
|
||||
std::vector<float> auto_corr_;
|
||||
rtc::ArrayView<float, kNumInvertedLags12kHz> auto_corr_view_;
|
||||
};
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
||||
@ -9,6 +9,7 @@
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/rnn_vad/pitch_search.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/pitch_info.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h"
|
||||
|
||||
#include <array>
|
||||
@ -28,9 +29,7 @@ TEST(RnnVadTest, PitchSearchBitExactness) {
|
||||
const size_t num_frames = lp_residual_reader.second;
|
||||
std::array<float, 864> lp_residual;
|
||||
float expected_pitch_period, expected_pitch_gain;
|
||||
PitchInfo last_pitch;
|
||||
std::unique_ptr<RealFourier> fft =
|
||||
RealFourier::Create(kAutoCorrelationFftOrder);
|
||||
PitchEstimator pitch_estimator;
|
||||
{
|
||||
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
|
||||
// FloatingPointExceptionObserver fpe_observer;
|
||||
@ -41,10 +40,10 @@ TEST(RnnVadTest, PitchSearchBitExactness) {
|
||||
{lp_residual.data(), lp_residual.size()});
|
||||
lp_residual_reader.first->ReadValue(&expected_pitch_period);
|
||||
lp_residual_reader.first->ReadValue(&expected_pitch_gain);
|
||||
last_pitch = PitchSearch({lp_residual.data(), lp_residual.size()},
|
||||
last_pitch, fft.get());
|
||||
EXPECT_EQ(static_cast<size_t>(expected_pitch_period), last_pitch.period);
|
||||
EXPECT_NEAR(expected_pitch_gain, last_pitch.gain, 1e-5f);
|
||||
PitchInfo pitch_info =
|
||||
pitch_estimator.Estimate({lp_residual.data(), lp_residual.size()});
|
||||
EXPECT_EQ(static_cast<size_t>(expected_pitch_period), pitch_info.period);
|
||||
EXPECT_NEAR(expected_pitch_gain, pitch_info.gain, 1e-5f);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -11,9 +11,10 @@
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SEQUENCE_BUFFER_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SEQUENCE_BUFFER_H_
|
||||
|
||||
#include <array>
|
||||
#include <algorithm>
|
||||
#include <cstring>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
#include "api/array_view.h"
|
||||
#include "rtc_base/checks.h"
|
||||
@ -26,41 +27,38 @@ namespace rnn_vad {
|
||||
// chunks have size S and N respectively. For instance, when S = 2N the first
|
||||
// half of the sequence buffer is replaced with its second half, and the new N
|
||||
// values are written at the end of the buffer.
|
||||
template <typename T, size_t S, size_t N>
|
||||
// The class also provides a view on the most recent M values, where 0 < M <= S
|
||||
// and by default M = N.
|
||||
template <typename T, size_t S, size_t N, size_t M = N>
|
||||
class SequenceBuffer {
|
||||
static_assert(S >= N,
|
||||
"The new chunk size is larger than the sequence buffer size.");
|
||||
static_assert(N <= S,
|
||||
"The new chunk size cannot be larger than the sequence buffer "
|
||||
"size.");
|
||||
static_assert(std::is_arithmetic<T>::value,
|
||||
"Integral or floating point required.");
|
||||
|
||||
public:
|
||||
SequenceBuffer() { buffer_.fill(0); }
|
||||
SequenceBuffer() : buffer_(S) {
|
||||
RTC_DCHECK_EQ(S, buffer_.size());
|
||||
Reset();
|
||||
}
|
||||
SequenceBuffer(const SequenceBuffer&) = delete;
|
||||
SequenceBuffer& operator=(const SequenceBuffer&) = delete;
|
||||
~SequenceBuffer() = default;
|
||||
size_t size() const { return S; }
|
||||
size_t chunks_size() const { return N; }
|
||||
// Sets the sequence buffer values to zero.
|
||||
void Reset() { buffer_.fill(0); }
|
||||
void Reset() { std::fill(buffer_.begin(), buffer_.end(), 0); }
|
||||
// Returns a view on the whole buffer.
|
||||
rtc::ArrayView<const T, S> GetBufferView() const {
|
||||
return {buffer_.data(), S};
|
||||
}
|
||||
// Returns a view on part of the buffer; the first element starts at the given
|
||||
// offset and the last one is the last one in the buffer.
|
||||
rtc::ArrayView<const T> GetBufferView(int offset) const {
|
||||
RTC_DCHECK_LE(0, offset);
|
||||
RTC_DCHECK_LT(offset, S);
|
||||
return {buffer_.data() + offset, S - offset};
|
||||
}
|
||||
// Returns a view on part of the buffer; the first element starts at the given
|
||||
// offset and the size of the view is |size|.
|
||||
rtc::ArrayView<const T> GetBufferView(int offset, size_t size) const {
|
||||
RTC_DCHECK_LE(0, offset);
|
||||
RTC_DCHECK_LT(offset, S);
|
||||
RTC_DCHECK_LT(0, size);
|
||||
RTC_DCHECK_LE(size, S - offset);
|
||||
return {buffer_.data() + offset, size};
|
||||
// Returns a view on the M most recent values of the buffer.
|
||||
rtc::ArrayView<const T, M> GetMostRecentValuesView() const {
|
||||
static_assert(M <= S,
|
||||
"The number of most recent values cannot be larger than the "
|
||||
"sequence buffer size.");
|
||||
return {buffer_.data() + S - M, M};
|
||||
}
|
||||
// Shifts left the buffer by N items and add new N items at the end.
|
||||
void Push(rtc::ArrayView<const T, N> new_values) {
|
||||
@ -72,8 +70,7 @@ class SequenceBuffer {
|
||||
}
|
||||
|
||||
private:
|
||||
// TODO(bugs.webrtc.org/9076): Switch to std::vector to decrease stack size.
|
||||
std::array<T, S> buffer_;
|
||||
std::vector<T> buffer_;
|
||||
};
|
||||
|
||||
} // namespace rnn_vad
|
||||
|
||||
@ -11,6 +11,7 @@
|
||||
#include "modules/audio_processing/agc2/rnn_vad/sequence_buffer.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
|
||||
#include "test/gtest.h"
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user