From 4736d4e524567eda5c4d13ac7fb317802030b438 Mon Sep 17 00:00:00 2001 From: Alessio Bazzica Date: Tue, 17 Apr 2018 16:46:45 +0200 Subject: [PATCH] AGC2 RNN VAD: Sequence buffer The SequenceBuffer class template implements a linear buffer with a Push operation that is used to add a fixed size chunk of new samples into the buffer. Its properties are its size and the size of the chunks that are pushed. It is used to implement the pitch buffer in the RNN VAD feature extractor, for which a ring buffer would be a painful choice. Bug: webrtc:9076 Change-Id: I4767bf06d5a414dbed724a96ea4186ef013a1e30 Reviewed-on: https://webrtc-review.googlesource.com/70204 Commit-Queue: Alessio Bazzica Reviewed-by: Gustaf Ullberg Cr-Commit-Position: refs/heads/master@{#22919} --- modules/audio_processing/BUILD.gn | 1 + .../audio_processing/agc2/rnn_vad/BUILD.gn | 13 +++ .../agc2/rnn_vad/sequence_buffer.h | 81 +++++++++++++ .../agc2/rnn_vad/sequence_buffer_unittest.cc | 106 ++++++++++++++++++ 4 files changed, 201 insertions(+) create mode 100644 modules/audio_processing/agc2/rnn_vad/sequence_buffer.h create mode 100644 modules/audio_processing/agc2/rnn_vad/sequence_buffer_unittest.cc diff --git a/modules/audio_processing/BUILD.gn b/modules/audio_processing/BUILD.gn index 7f89509014..f58a04f91e 100644 --- a/modules/audio_processing/BUILD.gn +++ b/modules/audio_processing/BUILD.gn @@ -552,6 +552,7 @@ if (rtc_include_tests) { "agc2:adaptive_digital_unittests", "agc2:fixed_digital_unittests", "agc2:noise_estimator_unittests", + "agc2/rnn_vad:unittests", "test/conversational_speech:unittest", "vad:vad_unittests", "//testing/gtest", diff --git a/modules/audio_processing/agc2/rnn_vad/BUILD.gn b/modules/audio_processing/agc2/rnn_vad/BUILD.gn index 1941eff679..f7dee9eaa4 100644 --- a/modules/audio_processing/agc2/rnn_vad/BUILD.gn +++ b/modules/audio_processing/agc2/rnn_vad/BUILD.gn @@ -17,6 +17,7 @@ group("rnn_vad") { source_set("lib") { sources = [ "common.h", + "sequence_buffer.h", ] deps = [ "../../../../api:array_view", @@ -25,6 +26,18 @@ source_set("lib") { } if (rtc_include_tests) { + rtc_source_set("unittests") { + testonly = true + sources = [ + "sequence_buffer_unittest.cc", + ] + deps = [ + ":lib", + "../../../../api:array_view", + "../../../../test:test_support", + ] + } + rtc_executable("rnn_vad_tool") { testonly = true sources = [ diff --git a/modules/audio_processing/agc2/rnn_vad/sequence_buffer.h b/modules/audio_processing/agc2/rnn_vad/sequence_buffer.h new file mode 100644 index 0000000000..7ae2f95a00 --- /dev/null +++ b/modules/audio_processing/agc2/rnn_vad/sequence_buffer.h @@ -0,0 +1,81 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SEQUENCE_BUFFER_H_ +#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SEQUENCE_BUFFER_H_ + +#include +#include +#include + +#include "api/array_view.h" +#include "rtc_base/checks.h" + +namespace webrtc { +namespace rnn_vad { + +// Linear buffer implementation to (i) push fixed size chunks of sequential data +// and (ii) view contiguous parts of the buffer. The buffer and the pushed +// chunks have size S and N respectively. For instance, when S = 2N the first +// half of the sequence buffer is replaced with its second half, and the new N +// values are written at the end of the buffer. +template +class SequenceBuffer { + static_assert(S >= N, + "The new chunk size is larger than the sequence buffer size."); + static_assert(std::is_arithmetic::value, + "Integral or floating point required."); + + public: + SequenceBuffer() { buffer_.fill(0); } + SequenceBuffer(const SequenceBuffer&) = delete; + SequenceBuffer& operator=(const SequenceBuffer&) = delete; + ~SequenceBuffer() = default; + size_t size() const { return S; } + size_t chunks_size() const { return N; } + // Sets the sequence buffer values to zero. + void Reset() { buffer_.fill(0); } + // Returns a view on the whole buffer. + rtc::ArrayView GetBufferView() const { + return {buffer_.data(), S}; + } + // Returns a view on part of the buffer; the first element starts at the given + // offset and the last one is the last one in the buffer. + rtc::ArrayView GetBufferView(int offset) const { + RTC_DCHECK_LE(0, offset); + RTC_DCHECK_LT(offset, S); + return {buffer_.data() + offset, S - offset}; + } + // Returns a view on part of the buffer; the first element starts at the given + // offset and the size of the view is |size|. + rtc::ArrayView GetBufferView(int offset, size_t size) const { + RTC_DCHECK_LE(0, offset); + RTC_DCHECK_LT(offset, S); + RTC_DCHECK_LT(0, size); + RTC_DCHECK_LE(size, S - offset); + return {buffer_.data() + offset, size}; + } + // Shifts left the buffer by N items and add new N items at the end. + void Push(rtc::ArrayView new_values) { + // Make space for the new values. + if (S > N) + std::memmove(buffer_.data(), buffer_.data() + N, (S - N) * sizeof(T)); + // Copy the new values at the end of the buffer. + std::memcpy(buffer_.data() + S - N, new_values.data(), N * sizeof(T)); + } + + private: + std::array buffer_; +}; + +} // namespace rnn_vad +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SEQUENCE_BUFFER_H_ diff --git a/modules/audio_processing/agc2/rnn_vad/sequence_buffer_unittest.cc b/modules/audio_processing/agc2/rnn_vad/sequence_buffer_unittest.cc new file mode 100644 index 0000000000..f15a256e69 --- /dev/null +++ b/modules/audio_processing/agc2/rnn_vad/sequence_buffer_unittest.cc @@ -0,0 +1,106 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/agc2/rnn_vad/sequence_buffer.h" + +#include + +#include "test/gtest.h" + +namespace webrtc { +namespace test { + +using rnn_vad::SequenceBuffer; + +namespace { + +template +void TestSequenceBufferPushOp() { + SCOPED_TRACE(S); + SCOPED_TRACE(N); + SequenceBuffer seq_buf; + auto seq_buf_view = seq_buf.GetBufferView(); + std::array chunk; + rtc::ArrayView chunk_view(chunk.data(), chunk.size()); + + // Check that a chunk is fully gone after ceil(S / N) push ops. + chunk.fill(1); + seq_buf.Push(chunk_view); + chunk.fill(0); + constexpr size_t required_push_ops = (S % N) ? S / N + 1 : S / N; + for (size_t i = 0; i < required_push_ops - 1; ++i) { + SCOPED_TRACE(i); + seq_buf.Push(chunk_view); + // Still in the buffer. + const auto* m = std::max_element(seq_buf_view.begin(), seq_buf_view.end()); + EXPECT_EQ(1, *m); + } + // Gone after another push. + seq_buf.Push(chunk_view); + const auto* m = std::max_element(seq_buf_view.begin(), seq_buf_view.end()); + EXPECT_EQ(0, *m); + + // Check that the last item moves left by N positions after a push op. + if (S > N) { + // Fill in with non-zero values. + for (size_t i = 0; i < N; ++i) + chunk[i] = static_cast(i + 1); + seq_buf.Push(chunk_view); + // With the next Push(), |last| will be moved left by N positions. + const T last = chunk[N - 1]; + for (size_t i = 0; i < N; ++i) + chunk[i] = static_cast(last + i + 1); + seq_buf.Push(chunk_view); + EXPECT_EQ(last, seq_buf_view[S - N - 1]); + } +} + +} // namespace + +TEST(RnnVadTest, SequenceBufferGetters) { + constexpr size_t buffer_size = 8; + constexpr size_t chunk_size = 8; + SequenceBuffer seq_buf; + EXPECT_EQ(buffer_size, seq_buf.size()); + EXPECT_EQ(chunk_size, seq_buf.chunks_size()); + // Test view. + auto seq_buf_view = seq_buf.GetBufferView(); + EXPECT_EQ(0, seq_buf_view[0]); + EXPECT_EQ(0, seq_buf_view[seq_buf_view.size() - 1]); + constexpr std::array chunk = {10, 20, 30, 40, + 50, 60, 70, 80}; + seq_buf.Push({chunk.data(), chunk_size}); + EXPECT_EQ(10, *seq_buf_view.begin()); + EXPECT_EQ(80, *(seq_buf_view.end() - 1)); +} + +TEST(RnnVadTest, SequenceBufferPushOpsUnsigned) { + TestSequenceBufferPushOp(); // Chunk size: 25%. + TestSequenceBufferPushOp(); // Chunk size: 50%. + TestSequenceBufferPushOp(); // Chunk size: 100%. + TestSequenceBufferPushOp(); // Non-integer ratio. +} + +TEST(RnnVadTest, SequenceBufferPushOpsSigned) { + TestSequenceBufferPushOp(); // Chunk size: 25%. + TestSequenceBufferPushOp(); // Chunk size: 50%. + TestSequenceBufferPushOp(); // Chunk size: 100%. + TestSequenceBufferPushOp(); // Non-integer ratio. +} + +TEST(RnnVadTest, SequenceBufferPushOpsFloating) { + TestSequenceBufferPushOp(); // Chunk size: 25%. + TestSequenceBufferPushOp(); // Chunk size: 50%. + TestSequenceBufferPushOp(); // Chunk size: 100%. + TestSequenceBufferPushOp(); // Non-integer ratio. +} + +} // namespace test +} // namespace webrtc