RNN VAD: Switch to PFFFT

This CL replaces KissFFT with PFFFT for the spectral features
computation.

Remarks:
- Extra FFT output vector copy eliminated
- Scaling and windowing merged into a single vector for efficiency
- Nyquist frequency hack to keep the iteration in
  BandFeaturesExtractor::ComputeSpectralCrossCorrelation simple

Bug: webrtc:9577, webrtc:10480
Change-Id: I436563bd257f66a243f5402be270ffcf859bd184
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/130221
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Reviewed-by: Per Åhgren <peah@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#27536}
This commit is contained in:
Alessio Bazzica 2019-04-10 11:11:09 +02:00 committed by Commit Bot
parent 4a53766c84
commit 4d4cce8b3f
11 changed files with 91 additions and 215 deletions

View File

@ -16,8 +16,6 @@ rtc_source_set("rnn_vad") {
"common.h",
"features_extraction.cc",
"features_extraction.h",
"fft_util.cc",
"fft_util.h",
"lp_residual.cc",
"lp_residual.h",
"pitch_info.h",
@ -90,7 +88,6 @@ if (rtc_include_tests) {
sources = [
"auto_correlation_unittest.cc",
"features_extraction_unittest.cc",
"fft_util_unittest.cc",
"lp_residual_unittest.cc",
"pitch_search_internal_unittest.cc",
"pitch_search_unittest.cc",
@ -111,6 +108,7 @@ if (rtc_include_tests) {
"../../../../rtc_base:checks",
"../../../../rtc_base:logging",
"../../../../test:test_support",
"../../utility:pffft_wrapper",
"//third_party/rnnoise:rnn_vad",
]
data = unittest_resources

View File

@ -53,7 +53,6 @@ constexpr size_t kMinPitch48kHz = kMinPitch24kHz * 2;
constexpr size_t kMaxPitch48kHz = kMaxPitch24kHz * 2;
// Spectral features.
constexpr size_t kFftSizeBy2Plus1 = kFrameSize20ms24kHz / 2 + 1;
constexpr size_t kNumBands = 22;
constexpr size_t kNumLowerBands = 6;
static_assert((0 < kNumLowerBands) && (kNumLowerBands < kNumBands), "");

View File

@ -1,63 +0,0 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/rnn_vad/fft_util.h"
#include <stddef.h>
#include <algorithm>
#include <cmath>
#include "rtc_base/checks.h"
namespace webrtc {
namespace rnn_vad {
namespace {
constexpr size_t kHalfFrameSize = kFrameSize20ms24kHz / 2;
// Computes the first half of the Vorbis window.
std::array<float, kHalfFrameSize> ComputeHalfVorbisWindow() {
std::array<float, kHalfFrameSize> half_window{};
for (size_t i = 0; i < kHalfFrameSize; ++i) {
half_window[i] =
std::sin(0.5 * kPi * std::sin(0.5 * kPi * (i + 0.5) / kHalfFrameSize) *
std::sin(0.5 * kPi * (i + 0.5) / kHalfFrameSize));
}
return half_window;
}
} // namespace
FftUtil::FftUtil()
: half_window_(ComputeHalfVorbisWindow()),
fft_(static_cast<int>(input_buf_.size())) {}
FftUtil::~FftUtil() = default;
void FftUtil::WindowedFft(rtc::ArrayView<const float> samples,
rtc::ArrayView<std::complex<float>> dst) {
RTC_DCHECK_EQ(samples.size(), kFrameSize20ms24kHz);
RTC_DCHECK_EQ(dst.size(), kFftSizeBy2Plus1);
// Apply windowing.
RTC_DCHECK_EQ(input_buf_.size(), 2 * half_window_.size());
for (size_t i = 0; i < input_buf_.size() / 2; ++i) {
input_buf_[i].real(samples[i] * half_window_[i]);
size_t j = kFrameSize20ms24kHz - i - 1;
input_buf_[j].real(samples[j] * half_window_[i]);
}
fft_.ForwardFft(kFrameSize20ms24kHz, input_buf_.data(), kFrameSize20ms24kHz,
output_buf_.data());
// Copy the first symmetric conjugate part.
RTC_DCHECK_LT(dst.size(), output_buf_.size());
std::copy(output_buf_.begin(), output_buf_.begin() + dst.size(), dst.begin());
}
} // namespace rnn_vad
} // namespace webrtc

View File

@ -1,55 +0,0 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_FFT_UTIL_H_
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_FFT_UTIL_H_
#include <array>
#include <complex>
#include "api/array_view.h"
#include "modules/audio_processing/agc2/rnn_vad/common.h"
#include "third_party/rnnoise/src/kiss_fft.h"
namespace webrtc {
namespace rnn_vad {
// TODO(alessiob): Switch to PFFFT and remove this class.
// FFT implementation wrapper for the band-wise analysis step in which 20 ms
// frames at 24 kHz are analyzed in the frequency domain. The goal of this class
// are (i) making easy to switch to another FFT implementation, (ii) own the
// input buffer for the FFT and (iii) apply a windowing function before
// computing the FFT.
class FftUtil {
public:
FftUtil();
FftUtil(const FftUtil&) = delete;
FftUtil& operator=(const FftUtil&) = delete;
~FftUtil();
// Applies a windowing function to |samples|, computes the real forward FFT
// and writes the result in |dst|.
// The size of |samples| must be 480 (20 ms at 24 kHz).
// The size of |dst| must be 241 since the complex conjugate is not written.
void WindowedFft(rtc::ArrayView<const float> samples,
rtc::ArrayView<std::complex<float>> dst);
private:
static_assert((kFrameSize20ms24kHz & 1) == 0,
"kFrameSize20ms24kHz must be even.");
const std::array<float, kFrameSize20ms24kHz / 2> half_window_;
std::array<std::complex<float>, kFrameSize20ms24kHz> input_buf_;
std::array<std::complex<float>, kFrameSize20ms24kHz> output_buf_;
rnnoise::KissFft fft_;
};
} // namespace rnn_vad
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_FFT_UTIL_H_

View File

@ -1,64 +0,0 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include <algorithm>
#include <cmath>
#include <vector>
#include "modules/audio_processing/agc2/rnn_vad/common.h"
#include "modules/audio_processing/agc2/rnn_vad/fft_util.h"
#include "rtc_base/checks.h"
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
// #include "test/fpe_observer.h"
#include "test/gtest.h"
namespace webrtc {
namespace rnn_vad {
namespace test {
namespace {
std::vector<float> CreateSine(float amplitude,
float frequency_hz,
float duration_s,
int sample_rate_hz) {
size_t num_samples = static_cast<size_t>(duration_s * sample_rate_hz);
std::vector<float> signal(num_samples);
for (size_t i = 0; i < num_samples; ++i) {
signal[i] =
amplitude * std::sin(i * 2.0 * kPi * frequency_hz / sample_rate_hz);
}
return signal;
}
} // namespace
TEST(RnnVadTest, FftUtilTest) {
for (float frequency_hz : {200.f, 450.f, 1500.f}) {
SCOPED_TRACE(frequency_hz);
auto x = CreateSine(
/*amplitude=*/1000.f, frequency_hz,
/*duration_s=*/0.02f,
/*sample_rate_hz=*/kSampleRate24kHz);
FftUtil analyzer;
std::vector<std::complex<float>> x_fft(x.size() / 2 + 1);
analyzer.WindowedFft(x, x_fft);
int peak_fft_bin_index = std::distance(
x_fft.begin(),
std::max_element(x_fft.begin(), x_fft.end(),
[](std::complex<float> a, std::complex<float> b) {
return std::abs(a) < std::abs(b);
}));
EXPECT_EQ(frequency_hz, kSampleRate24kHz * peak_fft_bin_index / x.size());
}
}
} // namespace test
} // namespace rnn_vad
} // namespace webrtc

View File

@ -108,6 +108,10 @@ int main(int argc, char* argv[]) {
if (features_file) {
const float float_is_silence = is_silence ? 1.f : 0.f;
fwrite(&float_is_silence, sizeof(float), 1, features_file);
if (is_silence) {
// Do not write uninitialized values.
feature_vector.fill(0.f);
}
fwrite(feature_vector.data(), sizeof(float), kFeatureVectorSize,
features_file);
}

View File

@ -45,12 +45,52 @@ void UpdateCepstralDifferenceStats(
sym_matrix_buf->Push(distances);
}
// Computes the first half of the Vorbis window.
std::array<float, kFrameSize20ms24kHz / 2> ComputeScaledHalfVorbisWindow(
float scaling = 1.f) {
constexpr size_t kHalfSize = kFrameSize20ms24kHz / 2;
std::array<float, kHalfSize> half_window{};
for (size_t i = 0; i < kHalfSize; ++i) {
half_window[i] =
scaling *
std::sin(0.5 * kPi * std::sin(0.5 * kPi * (i + 0.5) / kHalfSize) *
std::sin(0.5 * kPi * (i + 0.5) / kHalfSize));
}
return half_window;
}
// Computes the forward FFT on a 20 ms frame to which a given window function is
// applied. The Fourier coefficient corresponding to the Nyquist frequency is
// set to zero (it is never used and this allows to simplify the code).
void ComputeWindowedForwardFft(
rtc::ArrayView<const float, kFrameSize20ms24kHz> frame,
const std::array<float, kFrameSize20ms24kHz / 2>& half_window,
Pffft::FloatBuffer* fft_input_buffer,
Pffft::FloatBuffer* fft_output_buffer,
Pffft* fft) {
RTC_DCHECK_EQ(frame.size(), 2 * half_window.size());
// Apply windowing.
auto in = fft_input_buffer->GetView();
for (size_t i = 0, j = kFrameSize20ms24kHz - 1; i < half_window.size();
++i, --j) {
in[i] = frame[i] * half_window[i];
in[j] = frame[j] * half_window[i];
}
fft->ForwardTransform(*fft_input_buffer, fft_output_buffer, /*ordered=*/true);
// Set the Nyquist frequency coefficient to zero.
auto out = fft_output_buffer->GetView();
out[1] = 0.f;
}
} // namespace
SpectralFeaturesExtractor::SpectralFeaturesExtractor()
: fft_(),
reference_frame_fft_(kFftSizeBy2Plus1),
lagged_frame_fft_(kFftSizeBy2Plus1),
: half_window_(ComputeScaledHalfVorbisWindow(
1.f / static_cast<float>(kFrameSize20ms24kHz))),
fft_(kFrameSize20ms24kHz, Pffft::FftType::kReal),
fft_buffer_(fft_.CreateBuffer()),
reference_frame_fft_(fft_.CreateBuffer()),
lagged_frame_fft_(fft_.CreateBuffer()),
dct_table_(ComputeDctTable()) {}
SpectralFeaturesExtractor::~SpectralFeaturesExtractor() = default;
@ -70,10 +110,10 @@ bool SpectralFeaturesExtractor::CheckSilenceComputeFeatures(
rtc::ArrayView<float, kNumLowerBands> bands_cross_corr,
float* variability) {
// Compute the Opus band energies for the reference frame.
fft_.WindowedFft(reference_frame, reference_frame_fft_);
ComputeWindowedForwardFft(reference_frame, half_window_, fft_buffer_.get(),
reference_frame_fft_.get(), &fft_);
spectral_correlator_.ComputeAutoCorrelation(
{reference_frame_fft_.data(), kFftSizeBy2Plus1},
reference_frame_bands_energy_);
reference_frame_fft_->GetConstView(), reference_frame_bands_energy_);
// Check if the reference frame has silence.
const float tot_energy =
std::accumulate(reference_frame_bands_energy_.begin(),
@ -82,9 +122,10 @@ bool SpectralFeaturesExtractor::CheckSilenceComputeFeatures(
return true;
}
// Compute the Opus band energies for the lagged frame.
fft_.WindowedFft(lagged_frame, lagged_frame_fft_);
spectral_correlator_.ComputeAutoCorrelation(
{lagged_frame_fft_.data(), kFftSizeBy2Plus1}, lagged_frame_bands_energy_);
ComputeWindowedForwardFft(lagged_frame, half_window_, fft_buffer_.get(),
lagged_frame_fft_.get(), &fft_);
spectral_correlator_.ComputeAutoCorrelation(lagged_frame_fft_->GetConstView(),
lagged_frame_bands_energy_);
// Log of the band energies for the reference frame.
std::array<float, kNumBands> log_bands_energy;
ComputeSmoothedLogMagnitudeSpectrum(reference_frame_bands_energy_,
@ -134,8 +175,8 @@ void SpectralFeaturesExtractor::ComputeAvgAndDerivatives(
void SpectralFeaturesExtractor::ComputeNormalizedCepstralCorrelation(
rtc::ArrayView<float, kNumLowerBands> bands_cross_corr) {
spectral_correlator_.ComputeCrossCorrelation(
{reference_frame_fft_.data(), kFftSizeBy2Plus1},
{lagged_frame_fft_.data(), kFftSizeBy2Plus1}, bands_cross_corr_);
reference_frame_fft_->GetConstView(), lagged_frame_fft_->GetConstView(),
bands_cross_corr_);
// Normalize.
for (size_t i = 0; i < bands_cross_corr_.size(); ++i) {
bands_cross_corr_[i] =

View File

@ -12,16 +12,16 @@
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SPECTRAL_FEATURES_H_
#include <array>
#include <complex>
#include <cstddef>
#include <memory>
#include <vector>
#include "api/array_view.h"
#include "modules/audio_processing/agc2/rnn_vad/common.h"
#include "modules/audio_processing/agc2/rnn_vad/fft_util.h"
#include "modules/audio_processing/agc2/rnn_vad/ring_buffer.h"
#include "modules/audio_processing/agc2/rnn_vad/spectral_features_internal.h"
#include "modules/audio_processing/agc2/rnn_vad/symmetric_matrix_buffer.h"
#include "modules/audio_processing/utility/pffft_wrapper.h"
namespace webrtc {
namespace rnn_vad {
@ -58,9 +58,11 @@ class SpectralFeaturesExtractor {
rtc::ArrayView<float, kNumLowerBands> bands_cross_corr);
float ComputeVariability() const;
FftUtil fft_;
std::vector<std::complex<float>> reference_frame_fft_;
std::vector<std::complex<float>> lagged_frame_fft_;
const std::array<float, kFrameSize20ms24kHz / 2> half_window_;
Pffft fft_;
std::unique_ptr<Pffft::FloatBuffer> fft_buffer_;
std::unique_ptr<Pffft::FloatBuffer> reference_frame_fft_;
std::unique_ptr<Pffft::FloatBuffer> lagged_frame_fft_;
SpectralCorrelator spectral_correlator_;
std::array<float, kOpusBands24kHz> reference_frame_bands_energy_;
std::array<float, kOpusBands24kHz> lagged_frame_bands_energy_;

View File

@ -91,22 +91,26 @@ SpectralCorrelator::SpectralCorrelator()
SpectralCorrelator::~SpectralCorrelator() = default;
void SpectralCorrelator::ComputeAutoCorrelation(
rtc::ArrayView<const std::complex<float>, kFftSizeBy2Plus1> x,
rtc::ArrayView<const float> x,
rtc::ArrayView<float, kOpusBands24kHz> auto_corr) const {
ComputeCrossCorrelation(x, x, auto_corr);
}
void SpectralCorrelator::ComputeCrossCorrelation(
rtc::ArrayView<const std::complex<float>, kFftSizeBy2Plus1> x,
rtc::ArrayView<const std::complex<float>, kFftSizeBy2Plus1> y,
rtc::ArrayView<const float> x,
rtc::ArrayView<const float> y,
rtc::ArrayView<float, kOpusBands24kHz> cross_corr) const {
RTC_DCHECK_EQ(x.size(), kFrameSize20ms24kHz);
RTC_DCHECK_EQ(x.size(), y.size());
RTC_DCHECK_EQ(x[1], 0.f) << "The Nyquist coefficient must be zeroed.";
RTC_DCHECK_EQ(y[1], 0.f) << "The Nyquist coefficient must be zeroed.";
constexpr auto kOpusScaleNumBins24kHz20ms = GetOpusScaleNumBins24kHz20ms();
size_t k = 0; // Next Fourier coefficient index.
cross_corr[0] = 0.f;
for (size_t i = 0; i < kOpusBands24kHz - 1; ++i) {
cross_corr[i + 1] = 0.f;
for (int j = 0; j < kOpusScaleNumBins24kHz20ms[i]; ++j) { // Band size.
const float v = x[k].real() * y[k].real() + x[k].imag() * y[k].imag();
const float v = x[2 * k] * y[2 * k] + x[2 * k + 1] * y[2 * k + 1];
const float tmp = weights_[k] * v;
cross_corr[i] += v - tmp;
cross_corr[i + 1] += tmp;
@ -114,8 +118,7 @@ void SpectralCorrelator::ComputeCrossCorrelation(
}
}
cross_corr[0] *= 2.f; // The first band only gets half contribution.
// The Nyquist coefficient is never used.
RTC_DCHECK_EQ(k, kFftSizeBy2Plus1 - 1);
RTC_DCHECK_EQ(k, kFrameSize20ms24kHz / 2); // Nyquist coefficient never used.
}
void ComputeSmoothedLogMagnitudeSpectrum(

View File

@ -13,7 +13,6 @@
#include <stddef.h>
#include <array>
#include <complex>
#include <vector>
#include "api/array_view.h"
@ -50,14 +49,22 @@ class SpectralCorrelator {
~SpectralCorrelator();
// Computes the band-wise spectral auto-correlations.
// |x| must:
// - have size equal to |kFrameSize20ms24kHz|;
// - be encoded as vectors of interleaved real-complex FFT coefficients
// where x[1] = y[1] = 0 (the Nyquist frequency coefficient is omitted).
void ComputeAutoCorrelation(
rtc::ArrayView<const std::complex<float>, kFftSizeBy2Plus1> x,
rtc::ArrayView<const float> x,
rtc::ArrayView<float, kOpusBands24kHz> auto_corr) const;
// Computes the band-wise spectral cross-correlations.
// |x| and |y| must:
// - have size equal to |kFrameSize20ms24kHz|;
// - be encoded as vectors of interleaved real-complex FFT coefficients where
// x[1] = y[1] = 0 (the Nyquist frequency coefficient is omitted).
void ComputeCrossCorrelation(
rtc::ArrayView<const std::complex<float>, kFftSizeBy2Plus1> x,
rtc::ArrayView<const std::complex<float>, kFftSizeBy2Plus1> y,
rtc::ArrayView<const float> x,
rtc::ArrayView<const float> y,
rtc::ArrayView<float, kOpusBands24kHz> cross_corr) const;
private:

View File

@ -18,6 +18,7 @@
#include "api/array_view.h"
#include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
#include "modules/audio_processing/utility/pffft_wrapper.h"
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
// #include "test/fpe_observer.h"
#include "test/gtest.h"
@ -86,10 +87,13 @@ TEST(RnnVadTest, DISABLED_TestOpusScaleWeights) {
TEST(RnnVadTest, SpectralCorrelatorValidOutput) {
SpectralCorrelator e;
std::array<std::complex<float>, kFftSizeBy2Plus1> in;
Pffft fft(kFrameSize20ms24kHz, Pffft::FftType::kReal);
auto in = fft.CreateBuffer();
std::array<float, kOpusBands24kHz> out;
in.fill({1.f, 1.f});
e.ComputeAutoCorrelation(in, out);
auto in_view = in->GetView();
std::fill(in_view.begin(), in_view.end(), 1.f);
in_view[1] = 0.f; // Nyquist frequency.
e.ComputeAutoCorrelation(in_view, out);
for (size_t i = 0; i < kOpusBands24kHz; ++i) {
SCOPED_TRACE(i);
EXPECT_GT(out[i], 0.f);