diff --git a/modules/audio_processing/agc2/rnn_vad/BUILD.gn b/modules/audio_processing/agc2/rnn_vad/BUILD.gn
index 814a7f5b44..bd7ff492b9 100644
--- a/modules/audio_processing/agc2/rnn_vad/BUILD.gn
+++ b/modules/audio_processing/agc2/rnn_vad/BUILD.gn
@@ -17,6 +17,8 @@ group("rnn_vad") {
 source_set("lib") {
   sources = [
     "common.h",
+    "lp_residual.cc",
+    "lp_residual.h",
     "ring_buffer.h",
     "sequence_buffer.h",
     "symmetric_matrix_buffer.h",
@@ -28,18 +30,54 @@ source_set("lib") {
 }
 
 if (rtc_include_tests) {
+  source_set("lib_test") {
+    testonly = true
+    sources = [
+      "test_utils.cc",
+      "test_utils.h",
+    ]
+    deps = [
+      "../../../../api:array_view",
+      "../../../../rtc_base:checks",
+      "../../../../rtc_base:ptr_util",
+      "../../../../test:fileutils",
+      "../../../../test:test_support",
+    ]
+  }
+
+  unittest_resources = [
+    "../../../../resources/audio_processing/agc2/rnn_vad/pitch_buf_24k.dat",
+    "../../../../resources/audio_processing/agc2/rnn_vad/pitch_lp_res.dat",
+  ]
+
+  if (is_ios) {
+    bundle_data("unittests_bundle_data") {
+      testonly = true
+      sources = unittest_resources
+      outputs = [
+        "{{bundle_resources_dir}}/{{source_file_part}}",
+      ]
+    }
+  }
+
   rtc_source_set("unittests") {
     testonly = true
     sources = [
+      "lp_residual_unittest.cc",
       "ring_buffer_unittest.cc",
       "sequence_buffer_unittest.cc",
       "symmetric_matrix_buffer_unittest.cc",
     ]
     deps = [
       ":lib",
+      ":lib_test",
       "../../../../api:array_view",
       "../../../../test:test_support",
     ]
+    data = unittest_resources
+    if (is_ios) {
+      deps += [ ":unittests_bundle_data" ]
+    }
   }
 
   rtc_executable("rnn_vad_tool") {
diff --git a/modules/audio_processing/agc2/rnn_vad/common.h b/modules/audio_processing/agc2/rnn_vad/common.h
index 93569ff1e9..ec42a7b4cf 100644
--- a/modules/audio_processing/agc2/rnn_vad/common.h
+++ b/modules/audio_processing/agc2/rnn_vad/common.h
@@ -15,7 +15,14 @@ namespace webrtc {
 namespace rnn_vad {
 
 constexpr size_t kSampleRate24kHz = 24000;
-constexpr size_t kFrameSize10ms24kHz = 240;
+constexpr size_t kFrameSize10ms24kHz = kSampleRate24kHz / 100;
+constexpr size_t kFrameSize20ms24kHz = kFrameSize10ms24kHz * 2;
+
+// Pitch analysis params (periods are expressed in samples at 24 kHz).
+constexpr size_t kPitchMinPeriod24kHz = kSampleRate24kHz / 800;   // 0.00125 s.
+constexpr size_t kPitchMaxPeriod24kHz = kSampleRate24kHz / 62.5;  // 0.016 s.
+constexpr size_t kBufSize24kHz = kPitchMaxPeriod24kHz + kFrameSize20ms24kHz;
+static_assert((kBufSize24kHz & 1) == 0, "The buffer size must be even.");
 
 }  // namespace rnn_vad
 }  // namespace webrtc
diff --git a/modules/audio_processing/agc2/rnn_vad/lp_residual.cc b/modules/audio_processing/agc2/rnn_vad/lp_residual.cc
new file mode 100644
index 0000000000..63aec6b0b9
--- /dev/null
+++ b/modules/audio_processing/agc2/rnn_vad/lp_residual.cc
@@ -0,0 +1,132 @@
+/*
+ *  Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
+ *
+ *  Use of this source code is governed by a BSD-style license
+ *  that can be found in the LICENSE file in the root of the source
+ *  tree. An additional intellectual property rights grant can be found
+ *  in the file PATENTS.  All contributing project authors may
+ *  be found in the AUTHORS file in the root of the source tree.
+ */
+
+#include "modules/audio_processing/agc2/rnn_vad/lp_residual.h"
+
+#include <algorithm>
+#include <array>
+#include <numeric>
+
+#include "rtc_base/checks.h"
+
+namespace webrtc {
+namespace rnn_vad {
+namespace {
+
+// Computes cross-correlation coefficients between |x| and |y| and writes them
+// in |x_corr|. The lag values are in {0, ..., max_lag - 1}, where max_lag
+// equals the size of |x_corr|.
+// The |x| and |y| sub-arrays used to compute the cross-correlation coefficient
+// for a lag l both have size "size of |x| - l" - i.e., the longest sub-array is
+// used. |x| and |y| must have the same size.
+void ComputeCrossCorrelation(
+    rtc::ArrayView<const float> x,
+    rtc::ArrayView<const float> y,
+    rtc::ArrayView<float, kNumLpcCoefficients> x_corr) {
+  constexpr size_t max_lag = x_corr.size();
+  RTC_DCHECK_EQ(x.size(), y.size());
+  RTC_DCHECK_LT(max_lag, x.size());
+  for (size_t lag = 0; lag < max_lag; ++lag)
+    x_corr[lag] =
+        std::inner_product(x.begin(), x.end() - lag, y.begin() + lag, 0.f);
+}
+
+// Applies denoising to the auto-correlation coefficients (in place).
+void DenoiseAutoCorrelation(
+    rtc::ArrayView<float, kNumLpcCoefficients> auto_corr) {
+  // Assume -40 dB white noise floor.
+  auto_corr[0] *= 1.0001f;
+  for (size_t i = 1; i < kNumLpcCoefficients; ++i)  // Scale by 1-(0.008*i)^2.
+    auto_corr[i] -= auto_corr[i] * (0.008f * i) * (0.008f * i);
+}
+
+// Computes the initial inverse filter coefficients given the auto-correlation
+// coefficients of an input frame. |lpc_coeffs| must be zero-initialized.
+void ComputeInitialInverseFilterCoefficients(
+    rtc::ArrayView<const float, kNumLpcCoefficients> auto_corr,
+    rtc::ArrayView<float, kNumLpcCoefficients - 1> lpc_coeffs) {
+  float error = auto_corr[0];  // Prediction error (Levinson-Durbin recursion).
+  for (size_t i = 0; i < kNumLpcCoefficients - 1; ++i) {
+    float reflection_coeff = 0.f;
+    for (size_t j = 0; j < i; ++j)
+      reflection_coeff += lpc_coeffs[j] * auto_corr[i - j];
+    reflection_coeff += auto_corr[i + 1];
+    reflection_coeff /= -error;
+    // Update LPC coefficients and total error.
+    lpc_coeffs[i] = reflection_coeff;
+    for (size_t j = 0; j < ((i + 1) >> 1); ++j) {  // Symmetric pairs (j, i-1-j).
+      const float tmp1 = lpc_coeffs[j];
+      const float tmp2 = lpc_coeffs[i - 1 - j];
+      lpc_coeffs[j] = tmp1 + reflection_coeff * tmp2;
+      lpc_coeffs[i - 1 - j] = tmp2 + reflection_coeff * tmp1;
+    }
+    error -= reflection_coeff * reflection_coeff * error;
+    if (error < 0.001f * auto_corr[0])  // Below 0.1% of |auto_corr[0]|: stop.
+      break;
+  }
+}
+
+}  // namespace
+
+void ComputeAndPostProcessLpcCoefficients(
+    rtc::ArrayView<const float> x,
+    rtc::ArrayView<float, kNumLpcCoefficients> lpc_coeffs) {
+  std::array<float, kNumLpcCoefficients> auto_corr;
+  ComputeCrossCorrelation(x, x, {auto_corr.data(), auto_corr.size()});
+  if (auto_corr[0] == 0.f) {  // Empty frame: avoid dividing by a zero error.
+    std::fill(lpc_coeffs.begin(), lpc_coeffs.end(), 0.f);
+    return;
+  }
+  DenoiseAutoCorrelation({auto_corr.data(), auto_corr.size()});
+  std::array<float, kNumLpcCoefficients - 1> lpc_coeffs_pre{};
+  ComputeInitialInverseFilterCoefficients(
+      {auto_corr.data(), auto_corr.size()},
+      {lpc_coeffs_pre.data(), lpc_coeffs_pre.size()});
+  // LPC coefficients post-processing.
+  // TODO(https://bugs.webrtc.org/9076): Consider removing these steps.
+  {
+    float c = 1.f;
+    for (size_t i = 0; i < kNumLpcCoefficients - 1; ++i) {
+      c *= 0.9f;  // c = 0.9^(i + 1).
+      lpc_coeffs_pre[i] *= c;
+    }
+  }
+  {
+    const float c = 0.8f;  // Multiplies the filter by (1 + c * z^-1).
+    lpc_coeffs[0] = lpc_coeffs_pre[0] + c;
+    lpc_coeffs[1] = lpc_coeffs_pre[1] + c * lpc_coeffs_pre[0];
+    lpc_coeffs[2] = lpc_coeffs_pre[2] + c * lpc_coeffs_pre[1];
+    lpc_coeffs[3] = lpc_coeffs_pre[3] + c * lpc_coeffs_pre[2];
+    lpc_coeffs[4] = c * lpc_coeffs_pre[3];
+  }
+}
+
+void ComputeLpResidual(
+    rtc::ArrayView<const float, kNumLpcCoefficients> lpc_coeffs,
+    rtc::ArrayView<const float> x,
+    rtc::ArrayView<float> y) {
+  RTC_DCHECK_LT(kNumLpcCoefficients, x.size());
+  RTC_DCHECK_EQ(x.size(), y.size());
+  std::array<float, kNumLpcCoefficients> input_chunk;
+  input_chunk.fill(0.f);
+  for (size_t i = 0; i < y.size(); ++i) {
+    const float sum = std::inner_product(input_chunk.begin(), input_chunk.end(),
+                                         lpc_coeffs.begin(), x[i]);
+    // Shift the delay line by one position and insert the new sample.
+    for (size_t j = kNumLpcCoefficients - 1; j > 0; --j)
+      input_chunk[j] = input_chunk[j - 1];
+    input_chunk[0] = x[i];
+    // Written after |x[i]| has been consumed, so |y| may alias |x|.
+    y[i] = sum;
+  }
+}
+
+}  // namespace rnn_vad
+}  // namespace webrtc
diff --git a/modules/audio_processing/agc2/rnn_vad/lp_residual.h b/modules/audio_processing/agc2/rnn_vad/lp_residual.h
new file mode 100644
index 0000000000..bffafd291f
--- /dev/null
+++ b/modules/audio_processing/agc2/rnn_vad/lp_residual.h
@@ -0,0 +1,39 @@
+/*
+ *  Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
+ *
+ *  Use of this source code is governed by a BSD-style license
+ *  that can be found in the LICENSE file in the root of the source
+ *  tree.
An additional intellectual property rights grant can be found
+ *  in the file PATENTS.  All contributing project authors may
+ *  be found in the AUTHORS file in the root of the source tree.
+ */
+
+#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_LP_RESIDUAL_H_
+#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_LP_RESIDUAL_H_
+
+#include "api/array_view.h"
+
+namespace webrtc {
+namespace rnn_vad {
+
+// LPC inverse filter length (number of coefficients written and consumed).
+constexpr size_t kNumLpcCoefficients = 5;
+
+// Given a frame |x|, computes a post-processed version of LPC coefficients
+// tailored for pitch estimation; all zeros are written if |x| has no energy.
+void ComputeAndPostProcessLpcCoefficients(
+    rtc::ArrayView<const float> x,
+    rtc::ArrayView<float, kNumLpcCoefficients> lpc_coeffs);
+
+// Computes the LP residual for the input frame |x| and the LPC coefficients
+// |lpc_coeffs|. |x| and |y| must have the same size; they can point to the
+// same array for in-place computation.
+void ComputeLpResidual(
+    rtc::ArrayView<const float, kNumLpcCoefficients> lpc_coeffs,
+    rtc::ArrayView<const float> x,
+    rtc::ArrayView<float> y);
+
+}  // namespace rnn_vad
+}  // namespace webrtc
+
+#endif  // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_LP_RESIDUAL_H_
diff --git a/modules/audio_processing/agc2/rnn_vad/lp_residual_unittest.cc b/modules/audio_processing/agc2/rnn_vad/lp_residual_unittest.cc
new file mode 100644
index 0000000000..41ffe68e0a
--- /dev/null
+++ b/modules/audio_processing/agc2/rnn_vad/lp_residual_unittest.cc
@@ -0,0 +1,94 @@
+/*
+ *  Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
+ *
+ *  Use of this source code is governed by a BSD-style license
+ *  that can be found in the LICENSE file in the root of the source
+ *  tree. An additional intellectual property rights grant can be found
+ *  in the file PATENTS.  All contributing project authors may
+ *  be found in the AUTHORS file in the root of the source tree.
+ */
+
+#include "modules/audio_processing/agc2/rnn_vad/lp_residual.h"
+
+#include <array>
+
+#include "modules/audio_processing/agc2/rnn_vad/common.h"
+#include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
+// TODO(https://bugs.webrtc.org/8948): Add when the issue is fixed.
+// #include "test/fpe_observer.h"
+#include "test/gtest.h"
+
+namespace webrtc {
+namespace test {
+
+using rnn_vad::ComputeAndPostProcessLpcCoefficients;
+using rnn_vad::ComputeLpResidual;
+using rnn_vad::kBufSize24kHz;
+using rnn_vad::kFrameSize10ms24kHz;
+using rnn_vad::kNumLpcCoefficients;
+
+TEST(RnnVadTest, LpResidualOfEmptyFrame) {
+  // TODO(https://bugs.webrtc.org/8948): Add when the issue is fixed.
+  // FloatingPointExceptionObserver fpe_observer;
+
+  // Input frame (empty, i.e., all samples set to 0).
+  std::array<float, kFrameSize10ms24kHz> empty_frame;
+  empty_frame.fill(0.f);
+  // Compute inverse filter coefficients.
+  std::array<float, kNumLpcCoefficients> lpc_coeffs;
+  ComputeAndPostProcessLpcCoefficients({empty_frame},
+                                       {lpc_coeffs.data(), lpc_coeffs.size()});
+  // Compute LP residual; the test only checks that the calls complete.
+  std::array<float, kFrameSize10ms24kHz> lp_residual;
+  ComputeLpResidual({lpc_coeffs.data(), lpc_coeffs.size()}, {empty_frame},
+                    {lp_residual});
+}
+
+// TODO(https://bugs.webrtc.org/9076): Remove when the issue is fixed.
+TEST(RnnVadTest, LpResidualPipelineBitExactness) {
+  // Pitch buffer 24 kHz data reader.
+  auto pitch_buf_24kHz_reader = CreatePitchBuffer24kHzReader();
+  const size_t num_frames = pitch_buf_24kHz_reader.second;
+  std::array<float, kBufSize24kHz> pitch_buf_data;
+  rtc::ArrayView<float, kBufSize24kHz> pitch_buf_data_view(
+      pitch_buf_data.data(), pitch_buf_data.size());
+  // Read ground-truth.
+  auto lp_residual_reader = CreateLpResidualAndPitchPeriodGainReader();
+  ASSERT_EQ(num_frames, lp_residual_reader.second);
+  std::array<float, kBufSize24kHz> expected_lp_residual;
+  rtc::ArrayView<float, kBufSize24kHz> expected_lp_residual_view(
+      expected_lp_residual.data(), expected_lp_residual.size());
+  // Init pipeline.
+  std::array<float, kNumLpcCoefficients> lpc_coeffs;
+  rtc::ArrayView<float, kNumLpcCoefficients> lpc_coeffs_view(
+      lpc_coeffs.data(), kNumLpcCoefficients);
+  std::array<float, kBufSize24kHz> computed_lp_residual;
+  rtc::ArrayView<float, kBufSize24kHz> computed_lp_residual_view(
+      computed_lp_residual.data(), computed_lp_residual.size());
+  {
+    // TODO(https://bugs.webrtc.org/8948): Add when the issue is fixed.
+    // FloatingPointExceptionObserver fpe_observer;
+
+    for (size_t i = 0; i < num_frames; ++i) {
+      SCOPED_TRACE(i);
+      // Read input and expected output.
+      pitch_buf_24kHz_reader.first->ReadChunk(pitch_buf_data_view);
+      lp_residual_reader.first->ReadChunk(expected_lp_residual_view);
+      // Skip the two values stored after each residual (pitch gain, period).
+      float unused;
+      lp_residual_reader.first->ReadValue(&unused);
+      lp_residual_reader.first->ReadValue(&unused);
+      // Run pipeline.
+      ComputeAndPostProcessLpcCoefficients(pitch_buf_data_view,
+                                           lpc_coeffs_view);
+      ComputeLpResidual(lpc_coeffs_view, pitch_buf_data_view,
+                        computed_lp_residual_view);
+      // Compare.
+      ExpectNearAbsolute(expected_lp_residual_view, computed_lp_residual_view,
+                         kFloatMin);
+    }
+  }
+}
+
+}  // namespace test
+}  // namespace webrtc
diff --git a/modules/audio_processing/agc2/rnn_vad/test_utils.cc b/modules/audio_processing/agc2/rnn_vad/test_utils.cc
new file mode 100644
index 0000000000..374f0b6a8a
--- /dev/null
+++ b/modules/audio_processing/agc2/rnn_vad/test_utils.cc
@@ -0,0 +1,55 @@
+/*
+ *  Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
+ *
+ *  Use of this source code is governed by a BSD-style license
+ *  that can be found in the LICENSE file in the root of the source
+ *  tree. An additional intellectual property rights grant can be found
+ *  in the file PATENTS.  All contributing project authors may
+ *  be found in the AUTHORS file in the root of the source tree.
+ */
+
+#include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
+
+#include "rtc_base/checks.h"
+#include "rtc_base/ptr_util.h"
+#include "test/gtest.h"
+#include "test/testsupport/fileutils.h"
+
+namespace webrtc {
+namespace test {
+namespace {
+
+using ReaderPairType =
+    std::pair<std::unique_ptr<BinaryFileReader<float>>, const size_t>;
+
+}  // namespace
+
+void ExpectNearAbsolute(rtc::ArrayView<const float> expected,
+                        rtc::ArrayView<const float> computed,
+                        const float tolerance) {
+  ASSERT_EQ(expected.size(), computed.size());
+  for (size_t i = 0; i < expected.size(); ++i) {
+    SCOPED_TRACE(i);  // Reports the failing index.
+    EXPECT_NEAR(expected[i], computed[i], tolerance);
+  }
+}
+
+ReaderPairType CreatePitchBuffer24kHzReader() {
+  auto ptr = rtc::MakeUnique<BinaryFileReader<float>>(
+      test::ResourcePath("audio_processing/agc2/rnn_vad/pitch_buf_24k", "dat"),
+      864);
+  const size_t length = ptr->data_length();  // Read before |ptr| is moved.
+  return {std::move(ptr), rtc::CheckedDivExact(length, size_t{864})};
+}
+
+ReaderPairType CreateLpResidualAndPitchPeriodGainReader() {
+  constexpr size_t lp_residual_size = 864;  // Two more values follow per frame.
+  auto ptr = rtc::MakeUnique<BinaryFileReader<float>>(
+      test::ResourcePath("audio_processing/agc2/rnn_vad/pitch_lp_res", "dat"),
+      lp_residual_size);
+  const size_t length = ptr->data_length();  // Read before |ptr| is moved.
+  return {std::move(ptr), rtc::CheckedDivExact(length, 2 + lp_residual_size)};
+}
+
+}  // namespace test
+}  // namespace webrtc
diff --git a/modules/audio_processing/agc2/rnn_vad/test_utils.h b/modules/audio_processing/agc2/rnn_vad/test_utils.h
new file mode 100644
index 0000000000..6c16447889
--- /dev/null
+++ b/modules/audio_processing/agc2/rnn_vad/test_utils.h
@@ -0,0 +1,102 @@
+/*
+ *  Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
+ *
+ *  Use of this source code is governed by a BSD-style license
+ *  that can be found in the LICENSE file in the root of the source
+ *  tree. An additional intellectual property rights grant can be found
+ *  in the file PATENTS.  All contributing project authors may
+ *  be found in the AUTHORS file in the root of the source tree.
+ */
+
+#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_TEST_UTILS_H_
+#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_TEST_UTILS_H_
+
+#include <algorithm>
+#include <fstream>
+#include <limits>
+#include <memory>
+#include <string>
+#include <type_traits>
+#include <vector>
+
+#include "api/array_view.h"
+#include "rtc_base/checks.h"
+
+namespace webrtc {
+namespace test {
+
+constexpr float kFloatMin = std::numeric_limits<float>::min();
+
+// Requires |expected| and |computed| to have the same size and adds a failure
+// for every index at which they differ by more than |tolerance|.
+void ExpectNearAbsolute(rtc::ArrayView<const float> expected,
+                        rtc::ArrayView<const float> computed,
+                        const float tolerance);
+
+// Reader for binary files consisting of an arbitrarily long sequence of items
+// having type T. It is possible to read and cast to another type D at once.
+template <typename T, typename D = T>
+class BinaryFileReader {
+ public:
+  explicit BinaryFileReader(const std::string& file_path,
+                            const size_t chunk_size = 1)
+      : is_(file_path, std::ios::binary | std::ios::ate),
+        data_length_(is_.tellg() / sizeof(T)),  // Opened at the end of file.
+        chunk_size_(chunk_size) {
+    RTC_CHECK_LT(0, chunk_size_);
+    RTC_CHECK(is_);
+    SeekBeginning();
+    buf_.resize(chunk_size_);
+  }
+  BinaryFileReader(const BinaryFileReader&) = delete;
+  BinaryFileReader& operator=(const BinaryFileReader&) = delete;
+  ~BinaryFileReader() = default;
+  size_t data_length() const { return data_length_; }  // Number of T items.
+  bool ReadValue(D* dst) {  // Reads one item; false on a short read.
+    if (std::is_same<T, D>::value) {
+      is_.read(reinterpret_cast<char*>(dst), sizeof(T));
+    } else {
+      T v;
+      is_.read(reinterpret_cast<char*>(&v), sizeof(T));
+      *dst = static_cast<D>(v);
+    }
+    return is_.gcount() == static_cast<std::streamsize>(sizeof(T));
+  }
+  bool ReadChunk(rtc::ArrayView<D> dst) {  // Reads |chunk_size_| items.
+    RTC_DCHECK_EQ(chunk_size_, dst.size());
+    const std::streamsize bytes_to_read = chunk_size_ * sizeof(T);
+    if (std::is_same<T, D>::value) {
+      is_.read(reinterpret_cast<char*>(dst.data()), bytes_to_read);
+    } else {
+      is_.read(reinterpret_cast<char*>(buf_.data()), bytes_to_read);
+      std::transform(buf_.begin(), buf_.end(), dst.begin(),
+                     [](const T& v) -> D { return static_cast<D>(v); });
+    }
+    return is_.gcount() == bytes_to_read;
+  }
+  void SeekForward(size_t items) { is_.seekg(items * sizeof(T), is_.cur); }
+  void SeekBeginning() { is_.clear(); is_.seekg(0, is_.beg); }  // Resets errors.
+
+ private:
+  std::ifstream is_;
+  const size_t data_length_;
+  const size_t chunk_size_;
+  std::vector<T> buf_;  // Staging buffer used when T and D differ.
+};
+
+// Factories for resource file readers; the functions below return a pair where
+// the first item is a reader unique pointer and the second the number of chunks
+// that can be read from the file.
+
+// Creates a reader for the pitch buffer content at 24 kHz.
+std::pair<std::unique_ptr<BinaryFileReader<float>>, const size_t>
+CreatePitchBuffer24kHzReader();
+// Creates a reader for the LP residual coefficients and the pitch period
+// and gain values.
+std::pair<std::unique_ptr<BinaryFileReader<float>>, const size_t>
+CreateLpResidualAndPitchPeriodGainReader();
+
+}  // namespace test
+}  // namespace webrtc
+
+#endif  // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_TEST_UTILS_H_
diff --git a/resources/audio_processing/agc2/rnn_vad/pitch_buf_24k.dat.sha1 b/resources/audio_processing/agc2/rnn_vad/pitch_buf_24k.dat.sha1
new file mode 100644
index 0000000000..1bcf6397e1
--- /dev/null
+++ b/resources/audio_processing/agc2/rnn_vad/pitch_buf_24k.dat.sha1
@@ -0,0 +1 @@
+ea598d4ef3f4e34bce4c4c5d0791a588517582b9
\ No newline at end of file
diff --git a/resources/audio_processing/agc2/rnn_vad/pitch_lp_res.dat.sha1 b/resources/audio_processing/agc2/rnn_vad/pitch_lp_res.dat.sha1
new file mode 100644
index 0000000000..48533190ab
--- /dev/null
+++ b/resources/audio_processing/agc2/rnn_vad/pitch_lp_res.dat.sha1
@@ -0,0 +1 @@
+03069dd7d5cf7ba63c0b6c90f7f7a283e7488ea0
\ No newline at end of file