From e63d38ba342a13acafd6f273f21e5d55d5db8b13 Mon Sep 17 00:00:00 2001 From: Alessio Bazzica Date: Thu, 19 Apr 2018 17:56:55 +0200 Subject: [PATCH] AGC2 RNN VAD: Linear Prediction Residual Functions to estimate the inverse filter via LPC and compute the LP residual applying the inverse filter. This CL also includes test utilities, in particular BinaryFileReader, used to read chunks of data and optionally cast them on the fly, and Create*Reader() functions to read resource files available at test time. Bug: webrtc:9076 Change-Id: Ia4793b8ad6a63cb3089ed11ddad89d1aa0b840f6 Reviewed-on: https://webrtc-review.googlesource.com/70244 Commit-Queue: Alessio Bazzica Reviewed-by: Jesus de Vicente Pena Reviewed-by: Alex Loiko Cr-Commit-Position: refs/heads/master@{#22946} --- .../audio_processing/agc2/rnn_vad/BUILD.gn | 38 +++++ .../audio_processing/agc2/rnn_vad/common.h | 9 +- .../agc2/rnn_vad/lp_residual.cc | 132 ++++++++++++++++++ .../agc2/rnn_vad/lp_residual.h | 39 ++++++ .../agc2/rnn_vad/lp_residual_unittest.cc | 94 +++++++++++++ .../agc2/rnn_vad/test_utils.cc | 55 ++++++++ .../agc2/rnn_vad/test_utils.h | 102 ++++++++++++++ .../agc2/rnn_vad/pitch_buf_24k.dat.sha1 | 1 + .../agc2/rnn_vad/pitch_lp_res.dat.sha1 | 1 + 9 files changed, 470 insertions(+), 1 deletion(-) create mode 100644 modules/audio_processing/agc2/rnn_vad/lp_residual.cc create mode 100644 modules/audio_processing/agc2/rnn_vad/lp_residual.h create mode 100644 modules/audio_processing/agc2/rnn_vad/lp_residual_unittest.cc create mode 100644 modules/audio_processing/agc2/rnn_vad/test_utils.cc create mode 100644 modules/audio_processing/agc2/rnn_vad/test_utils.h create mode 100644 resources/audio_processing/agc2/rnn_vad/pitch_buf_24k.dat.sha1 create mode 100644 resources/audio_processing/agc2/rnn_vad/pitch_lp_res.dat.sha1 diff --git a/modules/audio_processing/agc2/rnn_vad/BUILD.gn b/modules/audio_processing/agc2/rnn_vad/BUILD.gn index 814a7f5b44..bd7ff492b9 100644 --- 
a/modules/audio_processing/agc2/rnn_vad/BUILD.gn +++ b/modules/audio_processing/agc2/rnn_vad/BUILD.gn @@ -17,6 +17,8 @@ group("rnn_vad") { source_set("lib") { sources = [ "common.h", + "lp_residual.cc", + "lp_residual.h", "ring_buffer.h", "sequence_buffer.h", "symmetric_matrix_buffer.h", @@ -28,18 +30,54 @@ source_set("lib") { } if (rtc_include_tests) { + source_set("lib_test") { + testonly = true + sources = [ + "test_utils.cc", + "test_utils.h", + ] + deps = [ + "../../../../api:array_view", + "../../../../rtc_base:checks", + "../../../../rtc_base:ptr_util", + "../../../../test:fileutils", + "../../../../test:test_support", + ] + } + + unittest_resources = [ + "../../../../resources/audio_processing/agc2/rnn_vad/pitch_buf_24k.dat", + "../../../../resources/audio_processing/agc2/rnn_vad/pitch_lp_res.dat", + ] + + if (is_ios) { + bundle_data("unittests_bundle_data") { + testonly = true + sources = unittest_resources + outputs = [ + "{{bundle_resources_dir}}/{{source_file_part}}", + ] + } + } + rtc_source_set("unittests") { testonly = true sources = [ + "lp_residual_unittest.cc", "ring_buffer_unittest.cc", "sequence_buffer_unittest.cc", "symmetric_matrix_buffer_unittest.cc", ] deps = [ ":lib", + ":lib_test", "../../../../api:array_view", "../../../../test:test_support", ] + data = unittest_resources + if (is_ios) { + deps += [ ":unittests_bundle_data" ] + } } rtc_executable("rnn_vad_tool") { diff --git a/modules/audio_processing/agc2/rnn_vad/common.h b/modules/audio_processing/agc2/rnn_vad/common.h index 93569ff1e9..ec42a7b4cf 100644 --- a/modules/audio_processing/agc2/rnn_vad/common.h +++ b/modules/audio_processing/agc2/rnn_vad/common.h @@ -15,7 +15,14 @@ namespace webrtc { namespace rnn_vad { constexpr size_t kSampleRate24kHz = 24000; -constexpr size_t kFrameSize10ms24kHz = 240; +constexpr size_t kFrameSize10ms24kHz = kSampleRate24kHz / 100; +constexpr size_t kFrameSize20ms24kHz = kFrameSize10ms24kHz * 2; + +// Pitch analysis params. 
+constexpr size_t kPitchMinPeriod24kHz = kSampleRate24kHz / 800; // 0.00125 s. +constexpr size_t kPitchMaxPeriod24kHz = kSampleRate24kHz / 62.5; // 0.016 s. +constexpr size_t kBufSize24kHz = kPitchMaxPeriod24kHz + kFrameSize20ms24kHz; +static_assert((kBufSize24kHz & 1) == 0, "The buffer size must be even."); } // namespace rnn_vad } // namespace webrtc diff --git a/modules/audio_processing/agc2/rnn_vad/lp_residual.cc b/modules/audio_processing/agc2/rnn_vad/lp_residual.cc new file mode 100644 index 0000000000..63aec6b0b9 --- /dev/null +++ b/modules/audio_processing/agc2/rnn_vad/lp_residual.cc @@ -0,0 +1,132 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/agc2/rnn_vad/lp_residual.h" + +#include +#include +#include + +#include "rtc_base/checks.h" + +namespace webrtc { +namespace rnn_vad { +namespace { + +// Computes cross-correlation coefficients between |x| and |y| and writes them +// in |x_corr|. The lag values are in {0, ..., max_lag - 1}, where max_lag +// equals the size of |x_corr|. +// The |x| and |y| sub-arrays used to compute a cross-correlation coefficient +// for a lag l have both size "size of |x| - l" - i.e., the longest sub-array is +// used. |x| and |y| must have the same size. 
+void ComputeCrossCorrelation( + rtc::ArrayView x, + rtc::ArrayView y, + rtc::ArrayView x_corr) { + constexpr size_t max_lag = x_corr.size(); + RTC_DCHECK_EQ(x.size(), y.size()); + RTC_DCHECK_LT(max_lag, x.size()); + for (size_t lag = 0; lag < max_lag; ++lag) + x_corr[lag] = + std::inner_product(x.begin(), x.end() - lag, y.begin() + lag, 0.f); +} + +// Applies denoising to the auto-correlation coefficients. +void DenoiseAutoCorrelation( + rtc::ArrayView auto_corr) { + // Assume -40 dB white noise floor. + auto_corr[0] *= 1.0001f; + for (size_t i = 1; i < kNumLpcCoefficients; ++i) + auto_corr[i] -= auto_corr[i] * (0.008f * i) * (0.008f * i); +} + +// Computes the initial inverse filter coefficients given the auto-correlation +// coefficients of an input frame. +void ComputeInitialInverseFilterCoefficients( + rtc::ArrayView auto_corr, + rtc::ArrayView lpc_coeffs) { + float error = auto_corr[0]; + for (size_t i = 0; i < kNumLpcCoefficients - 1; ++i) { + float reflection_coeff = 0.f; + for (size_t j = 0; j < i; ++j) + reflection_coeff += lpc_coeffs[j] * auto_corr[i - j]; + reflection_coeff += auto_corr[i + 1]; + reflection_coeff /= -error; + // Update LPC coefficients and total error. + lpc_coeffs[i] = reflection_coeff; + for (size_t j = 0; j<(i + 1)>> 1; ++j) { + const float tmp1 = lpc_coeffs[j]; + const float tmp2 = lpc_coeffs[i - 1 - j]; + lpc_coeffs[j] = tmp1 + reflection_coeff * tmp2; + lpc_coeffs[i - 1 - j] = tmp2 + reflection_coeff * tmp1; + } + error -= reflection_coeff * reflection_coeff * error; + if (error < 0.001f * auto_corr[0]) + break; + } +} + +} // namespace + +void ComputeAndPostProcessLpcCoefficients( + rtc::ArrayView x, + rtc::ArrayView lpc_coeffs) { + std::array auto_corr; + ComputeCrossCorrelation(x, x, {auto_corr.data(), auto_corr.size()}); + if (auto_corr[0] == 0.f) { // Empty frame. 
+ std::fill(lpc_coeffs.begin(), lpc_coeffs.end(), 0); + return; + } + DenoiseAutoCorrelation({auto_corr.data(), auto_corr.size()}); + std::array lpc_coeffs_pre{}; + ComputeInitialInverseFilterCoefficients( + {auto_corr.data(), auto_corr.size()}, + {lpc_coeffs_pre.data(), lpc_coeffs_pre.size()}); + // LPC coefficients post-processing. + // TODO(https://bugs.webrtc.org/9076): Consider removing these steps. + { + float c = 1.f; + for (size_t i = 0; i < kNumLpcCoefficients - 1; ++i) { + c *= 0.9f; + lpc_coeffs_pre[i] *= c; + } + } + { + const float c = 0.8f; + lpc_coeffs[0] = lpc_coeffs_pre[0] + c; + lpc_coeffs[1] = lpc_coeffs_pre[1] + c * lpc_coeffs_pre[0]; + lpc_coeffs[2] = lpc_coeffs_pre[2] + c * lpc_coeffs_pre[1]; + lpc_coeffs[3] = lpc_coeffs_pre[3] + c * lpc_coeffs_pre[2]; + lpc_coeffs[4] = c * lpc_coeffs_pre[3]; + } +} + +void ComputeLpResidual( + rtc::ArrayView lpc_coeffs, + rtc::ArrayView x, + rtc::ArrayView y) { + RTC_DCHECK_LT(kNumLpcCoefficients, x.size()); + RTC_DCHECK_EQ(x.size(), y.size()); + std::array input_chunk; + input_chunk.fill(0.f); + for (size_t i = 0; i < y.size(); ++i) { + const float sum = std::inner_product(input_chunk.begin(), input_chunk.end(), + lpc_coeffs.begin(), x[i]); + // Circular shift and add a new sample. + for (size_t j = kNumLpcCoefficients - 1; j > 0; --j) + input_chunk[j] = input_chunk[j - 1]; + input_chunk[0] = x[i]; + // Copy result. + y[i] = sum; + } +} + +} // namespace rnn_vad +} // namespace webrtc diff --git a/modules/audio_processing/agc2/rnn_vad/lp_residual.h b/modules/audio_processing/agc2/rnn_vad/lp_residual.h new file mode 100644 index 0000000000..bffafd291f --- /dev/null +++ b/modules/audio_processing/agc2/rnn_vad/lp_residual.h @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. 
An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_LP_RESIDUAL_H_ +#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_LP_RESIDUAL_H_ + +#include "api/array_view.h" + +namespace webrtc { +namespace rnn_vad { + +// LPC inverse filter length. +constexpr size_t kNumLpcCoefficients = 5; + +// Given a frame |x|, computes a post-processed version of LPC coefficients +// tailored for pitch estimation. +void ComputeAndPostProcessLpcCoefficients( + rtc::ArrayView x, + rtc::ArrayView lpc_coeffs); + +// Computes the LP residual for the input frame |x| and the LPC coefficients +// |lpc_coeffs|. |y| and |x| can point to the same array for in-place +// computation. +void ComputeLpResidual( + rtc::ArrayView lpc_coeffs, + rtc::ArrayView x, + rtc::ArrayView y); + +} // namespace rnn_vad +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_LP_RESIDUAL_H_ diff --git a/modules/audio_processing/agc2/rnn_vad/lp_residual_unittest.cc b/modules/audio_processing/agc2/rnn_vad/lp_residual_unittest.cc new file mode 100644 index 0000000000..41ffe68e0a --- /dev/null +++ b/modules/audio_processing/agc2/rnn_vad/lp_residual_unittest.cc @@ -0,0 +1,94 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ + +#include "modules/audio_processing/agc2/rnn_vad/lp_residual.h" + +#include + +#include "modules/audio_processing/agc2/rnn_vad/common.h" +#include "modules/audio_processing/agc2/rnn_vad/test_utils.h" +// TODO(https://bugs.webrtc.org/8948): Add when the issue is fixed. +// #include "test/fpe_observer.h" +#include "test/gtest.h" + +namespace webrtc { +namespace test { + +using rnn_vad::ComputeAndPostProcessLpcCoefficients; +using rnn_vad::ComputeLpResidual; +using rnn_vad::kBufSize24kHz; +using rnn_vad::kFrameSize10ms24kHz; +using rnn_vad::kNumLpcCoefficients; + +TEST(RnnVadTest, LpResidualOfEmptyFrame) { + // TODO(https://bugs.webrtc.org/8948): Add when the issue is fixed. + // FloatingPointExceptionObserver fpe_observer; + + // Input frame (empty, i.e., all samples set to 0). + std::array empty_frame; + empty_frame.fill(0.f); + // Compute inverse filter coefficients. + std::array lpc_coeffs; + ComputeAndPostProcessLpcCoefficients({empty_frame}, + {lpc_coeffs.data(), lpc_coeffs.size()}); + // Compute LP residual. + std::array lp_residual; + ComputeLpResidual({lpc_coeffs.data(), lpc_coeffs.size()}, {empty_frame}, + {lp_residual}); +} + +// TODO(https://bugs.webrtc.org/9076): Remove when the issue is fixed. +TEST(RnnVadTest, LpResidualPipelineBitExactness) { + // Pitch buffer 24 kHz data reader. + auto pitch_buf_24kHz_reader = CreatePitchBuffer24kHzReader(); + const size_t num_frames = pitch_buf_24kHz_reader.second; + std::array pitch_buf_data; + rtc::ArrayView pitch_buf_data_view( + pitch_buf_data.data(), pitch_buf_data.size()); + // Read ground-truth. + auto lp_residual_reader = CreateLpResidualAndPitchPeriodGainReader(); + ASSERT_EQ(num_frames, lp_residual_reader.second); + std::array expected_lp_residual; + rtc::ArrayView expected_lp_residual_view( + expected_lp_residual.data(), expected_lp_residual.size()); + // Init pipeline. 
+ std::array lpc_coeffs; + rtc::ArrayView lpc_coeffs_view( + lpc_coeffs.data(), kNumLpcCoefficients); + std::array computed_lp_residual; + rtc::ArrayView computed_lp_residual_view( + computed_lp_residual.data(), computed_lp_residual.size()); + { + // TODO(https://bugs.webrtc.org/8948): Add when the issue is fixed. + // FloatingPointExceptionObserver fpe_observer; + + for (size_t i = 0; i < num_frames; ++i) { + SCOPED_TRACE(i); + // Read input and expected output. + pitch_buf_24kHz_reader.first->ReadChunk(pitch_buf_data_view); + lp_residual_reader.first->ReadChunk(expected_lp_residual_view); + // Skip pitch gain and period. + float unused; + lp_residual_reader.first->ReadValue(&unused); + lp_residual_reader.first->ReadValue(&unused); + // Run pipeline. + ComputeAndPostProcessLpcCoefficients(pitch_buf_data_view, + lpc_coeffs_view); + ComputeLpResidual(lpc_coeffs_view, pitch_buf_data_view, + computed_lp_residual_view); + // Compare. + ExpectNearAbsolute(expected_lp_residual_view, computed_lp_residual_view, + kFloatMin); + } + } +} + +} // namespace test +} // namespace webrtc diff --git a/modules/audio_processing/agc2/rnn_vad/test_utils.cc b/modules/audio_processing/agc2/rnn_vad/test_utils.cc new file mode 100644 index 0000000000..374f0b6a8a --- /dev/null +++ b/modules/audio_processing/agc2/rnn_vad/test_utils.cc @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ + +#include "modules/audio_processing/agc2/rnn_vad/test_utils.h" + +#include "rtc_base/checks.h" +#include "rtc_base/ptr_util.h" +#include "test/gtest.h" +#include "test/testsupport/fileutils.h" + +namespace webrtc { +namespace test { +namespace { + +using ReaderPairType = + std::pair>, const size_t>; + +} // namespace + +void ExpectNearAbsolute(rtc::ArrayView expected, + rtc::ArrayView computed, + const float tolerance) { + ASSERT_EQ(expected.size(), computed.size()); + for (size_t i = 0; i < expected.size(); ++i) { + SCOPED_TRACE(i); + EXPECT_NEAR(expected[i], computed[i], tolerance); + } +} + +ReaderPairType CreatePitchBuffer24kHzReader() { + auto ptr = rtc::MakeUnique>( + test::ResourcePath("audio_processing/agc2/rnn_vad/pitch_buf_24k", "dat"), + 864); + return {std::move(ptr), + rtc::CheckedDivExact(ptr->data_length(), static_cast(864))}; +} + +ReaderPairType CreateLpResidualAndPitchPeriodGainReader() { + constexpr size_t num_lp_residual_coeffs = 864; + auto ptr = rtc::MakeUnique>( + test::ResourcePath("audio_processing/agc2/rnn_vad/pitch_lp_res", "dat"), + num_lp_residual_coeffs); + return {std::move(ptr), + rtc::CheckedDivExact(ptr->data_length(), 2 + num_lp_residual_coeffs)}; +} + +} // namespace test +} // namespace webrtc diff --git a/modules/audio_processing/agc2/rnn_vad/test_utils.h b/modules/audio_processing/agc2/rnn_vad/test_utils.h new file mode 100644 index 0000000000..6c16447889 --- /dev/null +++ b/modules/audio_processing/agc2/rnn_vad/test_utils.h @@ -0,0 +1,102 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ + +#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_TEST_UTILS_H_ +#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_TEST_UTILS_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "api/array_view.h" +#include "rtc_base/checks.h" + +namespace webrtc { +namespace test { + +constexpr float kFloatMin = std::numeric_limits::min(); + +// Fail for every pair from two equally sized rtc::ArrayView views such +// that their absolute error is above a given threshold. +void ExpectNearAbsolute(rtc::ArrayView expected, + rtc::ArrayView computed, + const float tolerance); + +// Reader for binary files consisting of an arbitrarily long sequence of elements +// having type T. It is possible to read and cast to another type D at once. +template +class BinaryFileReader { + public: + explicit BinaryFileReader(const std::string& file_path, + const size_t chunk_size = 1) + : is_(file_path, std::ios::binary | std::ios::ate), + data_length_(is_.tellg() / sizeof(T)), + chunk_size_(chunk_size) { + RTC_CHECK_LT(0, chunk_size_); + RTC_CHECK(is_); + SeekBeginning(); + buf_.resize(chunk_size_); + } + BinaryFileReader(const BinaryFileReader&) = delete; + BinaryFileReader& operator=(const BinaryFileReader&) = delete; + ~BinaryFileReader() = default; + size_t data_length() const { return data_length_; } + bool ReadValue(D* dst) { + if (std::is_same::value) { + is_.read(reinterpret_cast(dst), sizeof(T)); + } else { + T v; + is_.read(reinterpret_cast(&v), sizeof(T)); + *dst = static_cast(v); + } + return is_.gcount() == sizeof(T); + } + bool ReadChunk(rtc::ArrayView dst) { + RTC_DCHECK_EQ(chunk_size_, dst.size()); + const std::streamsize bytes_to_read = chunk_size_ * sizeof(T); + if (std::is_same::value) { + is_.read(reinterpret_cast(dst.data()), bytes_to_read); + } else { + is_.read(reinterpret_cast(buf_.data()), bytes_to_read); + std::transform(buf_.begin(), buf_.end(), dst.begin(), + [](const T& v) -> D { return static_cast(v); }); + } + return is_.gcount() == 
bytes_to_read; + } + void SeekForward(size_t items) { is_.seekg(items * sizeof(T), is_.cur); } + void SeekBeginning() { is_.seekg(0, is_.beg); } + + private: + std::ifstream is_; + const size_t data_length_; + const size_t chunk_size_; + std::vector buf_; +}; + +// Factories for resource file readers; the functions below return a pair where +// the first item is a reader unique pointer and the second the number of chunks +// that can be read from the file. + +// Creates a reader for the pitch buffer content at 24 kHz. +std::pair>, const size_t> +CreatePitchBuffer24kHzReader(); +// Creates a reader for the LP residual coefficients and the pitch period +// and gain values. +std::pair>, const size_t> +CreateLpResidualAndPitchPeriodGainReader(); + +} // namespace test +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_TEST_UTILS_H_ diff --git a/resources/audio_processing/agc2/rnn_vad/pitch_buf_24k.dat.sha1 b/resources/audio_processing/agc2/rnn_vad/pitch_buf_24k.dat.sha1 new file mode 100644 index 0000000000..1bcf6397e1 --- /dev/null +++ b/resources/audio_processing/agc2/rnn_vad/pitch_buf_24k.dat.sha1 @@ -0,0 +1 @@ +ea598d4ef3f4e34bce4c4c5d0791a588517582b9 \ No newline at end of file diff --git a/resources/audio_processing/agc2/rnn_vad/pitch_lp_res.dat.sha1 b/resources/audio_processing/agc2/rnn_vad/pitch_lp_res.dat.sha1 new file mode 100644 index 0000000000..48533190ab --- /dev/null +++ b/resources/audio_processing/agc2/rnn_vad/pitch_lp_res.dat.sha1 @@ -0,0 +1 @@ +03069dd7d5cf7ba63c0b6c90f7f7a283e7488ea0 \ No newline at end of file