diff --git a/modules/audio_processing/agc2/rnn_vad/fft_util.cc b/modules/audio_processing/agc2/rnn_vad/fft_util.cc index a1c5dac477..4825e2befe 100644 --- a/modules/audio_processing/agc2/rnn_vad/fft_util.cc +++ b/modules/audio_processing/agc2/rnn_vad/fft_util.cc @@ -11,6 +11,7 @@ #include "modules/audio_processing/agc2/rnn_vad/fft_util.h" #include +#include #include #include "rtc_base/checks.h" @@ -42,8 +43,8 @@ BandAnalysisFft::~BandAnalysisFft() = default; void BandAnalysisFft::ForwardFft(rtc::ArrayView samples, rtc::ArrayView> dst) { - RTC_DCHECK_EQ(input_buf_.size(), samples.size()); - RTC_DCHECK_EQ(samples.size(), dst.size()); + RTC_DCHECK_EQ(samples.size(), kFrameSize20ms24kHz); + RTC_DCHECK_EQ(dst.size(), kFrameSize20ms24kHz / 2 + 1); // Apply windowing. RTC_DCHECK_EQ(input_buf_.size(), 2 * half_window_.size()); for (size_t i = 0; i < input_buf_.size() / 2; ++i) { @@ -52,7 +53,10 @@ void BandAnalysisFft::ForwardFft(rtc::ArrayView samples, input_buf_[j].real(samples[j] * half_window_[i]); } fft_.ForwardFft(kFrameSize20ms24kHz, input_buf_.data(), kFrameSize20ms24kHz, - dst.data()); + output_buf_.data()); + // Copy the first dst.size() coefficients; the rest is conjugate-symmetric. + RTC_DCHECK_LT(dst.size(), output_buf_.size()); + std::copy(output_buf_.begin(), output_buf_.begin() + dst.size(), dst.begin()); } } // namespace rnn_vad diff --git a/modules/audio_processing/agc2/rnn_vad/fft_util.h b/modules/audio_processing/agc2/rnn_vad/fft_util.h index f4265f4e32..c744ff6c64 100644 --- a/modules/audio_processing/agc2/rnn_vad/fft_util.h +++ b/modules/audio_processing/agc2/rnn_vad/fft_util.h @@ -21,6 +21,8 @@ namespace webrtc { namespace rnn_vad { +// TODO(alessiob): Switch to PFFFT using its own wrapper. +// TODO(alessiob): Delete this class when switching to PFFFT. // FFT implementation wrapper for the band-wise analysis step in which 20 ms // frames at 24 kHz are analyzed in the frequency domain. 
The goal of this class // are (i) making easy to switch to another FFT implementation, (ii) own the @@ -34,6 +36,8 @@ class BandAnalysisFft { ~BandAnalysisFft(); // Applies a windowing function to |samples|, computes the real forward FFT // and writes the result in |dst|. + // The size of |samples| must be 480 (20 ms at 24 kHz). + // The size of |dst| must be 241 since the complex conjugate is not written. void ForwardFft(rtc::ArrayView samples, rtc::ArrayView> dst); @@ -42,6 +46,7 @@ class BandAnalysisFft { "kFrameSize20ms24kHz must be even."); const std::array half_window_; std::array, kFrameSize20ms24kHz> input_buf_{}; + std::array, kFrameSize20ms24kHz> output_buf_{}; rnnoise::KissFft fft_; }; diff --git a/modules/audio_processing/agc2/rnn_vad/fft_util_unittest.cc b/modules/audio_processing/agc2/rnn_vad/fft_util_unittest.cc index 985460051d..28f56bd069 100644 --- a/modules/audio_processing/agc2/rnn_vad/fft_util_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/fft_util_unittest.cc @@ -8,7 +8,9 @@ * be found in the AUTHORS file in the root of the source tree. */ -#include +#include +#include +#include #include "modules/audio_processing/agc2/rnn_vad/common.h" #include "modules/audio_processing/agc2/rnn_vad/fft_util.h" @@ -20,27 +22,40 @@ namespace webrtc { namespace rnn_vad { namespace test { +namespace { -TEST(RnnVadTest, CheckBandAnalysisFftOutput) { - // Input data. - std::array samples{}; - for (int i = 0; i < static_cast(kFrameSize20ms24kHz); ++i) { - samples[i] = i - static_cast(kFrameSize20ms24kHz / 2); +std::vector CreateSine(float amplitude, + float frequency_hz, + float duration_s, + int sample_rate_hz) { + size_t num_samples = static_cast(duration_s * sample_rate_hz); + std::vector signal(num_samples); + for (size_t i = 0; i < num_samples; ++i) { + signal[i] = + amplitude * std::sin(i * 2.0 * kPi * frequency_hz / sample_rate_hz); } - // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. 
- // FloatingPointExceptionObserver fpe_observer; - BandAnalysisFft fft; - std::array, kFrameSize20ms24kHz> fft_coeffs; - fft.ForwardFft(samples, fft_coeffs); - // First coefficient is DC - i.e., real number. - EXPECT_EQ(0.f, fft_coeffs[0].imag()); - // Check conjugated symmetry of the FFT output. - for (size_t i = 1; i < fft_coeffs.size() / 2; ++i) { - SCOPED_TRACE(i); - const auto& a = fft_coeffs[i]; - const auto& b = fft_coeffs[fft_coeffs.size() - i]; - EXPECT_NEAR(a.real(), b.real(), 2e-6f); - EXPECT_NEAR(a.imag(), -b.imag(), 2e-6f); + return signal; +} + +} // namespace + +TEST(RnnVadTest, BandAnalysisFftTest) { + for (float frequency_hz : {200.f, 450.f, 1500.f}) { + SCOPED_TRACE(frequency_hz); + auto x = CreateSine( + /*amplitude=*/1000.f, frequency_hz, + /*duration_s=*/0.02f, + /*sample_rate_hz=*/kSampleRate24kHz); + BandAnalysisFft analyzer; + std::vector> x_fft(x.size() / 2 + 1); + analyzer.ForwardFft(x, x_fft); + int peak_fft_bin_index = std::distance( + x_fft.begin(), + std::max_element(x_fft.begin(), x_fft.end(), + [](std::complex a, std::complex b) { + return std::abs(a) < std::abs(b); + })); + EXPECT_EQ(frequency_hz, kSampleRate24kHz * peak_fft_bin_index / x.size()); } } diff --git a/modules/audio_processing/agc2/rnn_vad/spectral_features.cc b/modules/audio_processing/agc2/rnn_vad/spectral_features.cc index 695eed57f0..84db2dfecd 100644 --- a/modules/audio_processing/agc2/rnn_vad/spectral_features.cc +++ b/modules/audio_processing/agc2/rnn_vad/spectral_features.cc @@ -68,8 +68,8 @@ SpectralFeaturesView::~SpectralFeaturesView() = default; SpectralFeaturesExtractor::SpectralFeaturesExtractor() : fft_(), - reference_frame_fft_(kFrameSize20ms24kHz), - lagged_frame_fft_(kFrameSize20ms24kHz), + reference_frame_fft_(kFrameSize20ms24kHz / 2 + 1), + lagged_frame_fft_(kFrameSize20ms24kHz / 2 + 1), band_boundaries_( ComputeBandBoundaryIndexes(kSampleRate24kHz, kFrameSize20ms24kHz)), dct_table_(ComputeDctTable()) {}