Add rnn_vad::SequenceBuffer and its unit tests.

Adds sequence_buffer.h to the rnn_vad "lib" target, introduces the
agc2/rnn_vad:unittests target (sequence_buffer_unittest.cc) and adds that
target to the list of unit-test targets in modules/audio_processing/BUILD.gn.

NOTE(review): this patch was recovered from a copy in which line breaks,
indentation and every <...> span had been lost. Line breaks follow the '+'
markers and agree with the hunk line counts. Template parameter/argument
lists, system #include targets and the arguments of the
TestSequenceBufferPushOp instantiations are reconstructed from usage and
must be confirmed against the upstream commit. Context-line indentation in
the BUILD.gn hunks is assumed; use `git apply --ignore-whitespace` if it
does not match the target tree.

diff --git a/modules/audio_processing/BUILD.gn b/modules/audio_processing/BUILD.gn
index 7f89509014..f58a04f91e 100644
--- a/modules/audio_processing/BUILD.gn
+++ b/modules/audio_processing/BUILD.gn
@@ -552,6 +552,7 @@ if (rtc_include_tests) {
       "agc2:adaptive_digital_unittests",
       "agc2:fixed_digital_unittests",
       "agc2:noise_estimator_unittests",
+      "agc2/rnn_vad:unittests",
       "test/conversational_speech:unittest",
       "vad:vad_unittests",
       "//testing/gtest",
diff --git a/modules/audio_processing/agc2/rnn_vad/BUILD.gn b/modules/audio_processing/agc2/rnn_vad/BUILD.gn
index 1941eff679..f7dee9eaa4 100644
--- a/modules/audio_processing/agc2/rnn_vad/BUILD.gn
+++ b/modules/audio_processing/agc2/rnn_vad/BUILD.gn
@@ -17,6 +17,7 @@ group("rnn_vad") {
 source_set("lib") {
   sources = [
     "common.h",
+    "sequence_buffer.h",
   ]
   deps = [
     "../../../../api:array_view",
@@ -25,6 +26,18 @@ source_set("lib") {
 }
 
 if (rtc_include_tests) {
+  rtc_source_set("unittests") {
+    testonly = true
+    sources = [
+      "sequence_buffer_unittest.cc",
+    ]
+    deps = [
+      ":lib",
+      "../../../../api:array_view",
+      "../../../../test:test_support",
+    ]
+  }
+
   rtc_executable("rnn_vad_tool") {
     testonly = true
     sources = [
diff --git a/modules/audio_processing/agc2/rnn_vad/sequence_buffer.h b/modules/audio_processing/agc2/rnn_vad/sequence_buffer.h
new file mode 100644
index 0000000000..7ae2f95a00
--- /dev/null
+++ b/modules/audio_processing/agc2/rnn_vad/sequence_buffer.h
@@ -0,0 +1,81 @@
+/*
+ *  Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
+ *
+ *  Use of this source code is governed by a BSD-style license
+ *  that can be found in the LICENSE file in the root of the source
+ *  tree. An additional intellectual property rights grant can be found
+ *  in the file PATENTS.  All contributing project authors may
+ *  be found in the AUTHORS file in the root of the source tree.
+ */
+
+#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SEQUENCE_BUFFER_H_
+#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SEQUENCE_BUFFER_H_
+
+#include <array>
+#include <cstring>
+#include <type_traits>
+
+#include "api/array_view.h"
+#include "rtc_base/checks.h"
+
+namespace webrtc {
+namespace rnn_vad {
+
+// Linear buffer implementation to (i) push fixed size chunks of sequential data
+// and (ii) view contiguous parts of the buffer. The buffer and the pushed
+// chunks have size S and N respectively. For instance, when S = 2N the first
+// half of the sequence buffer is replaced with its second half, and the new N
+// values are written at the end of the buffer.
+template <typename T, size_t S, size_t N>
+class SequenceBuffer {
+  static_assert(S >= N,
+                "The new chunk size is larger than the sequence buffer size.");
+  static_assert(std::is_arithmetic<T>::value,
+                "Integral or floating point required.");
+
+ public:
+  SequenceBuffer() { buffer_.fill(0); }
+  SequenceBuffer(const SequenceBuffer&) = delete;
+  SequenceBuffer& operator=(const SequenceBuffer&) = delete;
+  ~SequenceBuffer() = default;
+  size_t size() const { return S; }
+  size_t chunks_size() const { return N; }
+  // Sets the sequence buffer values to zero.
+  void Reset() { buffer_.fill(0); }
+  // Returns a view on the whole buffer.
+  rtc::ArrayView<const T, S> GetBufferView() const {
+    return {buffer_.data(), S};
+  }
+  // Returns a view on part of the buffer; the first element starts at the given
+  // offset and the last one is the last one in the buffer.
+  rtc::ArrayView<const T> GetBufferView(int offset) const {
+    RTC_DCHECK_LE(0, offset);
+    RTC_DCHECK_LT(offset, S);
+    return {buffer_.data() + offset, S - offset};
+  }
+  // Returns a view on part of the buffer; the first element starts at the given
+  // offset and the size of the view is |size|.
+  rtc::ArrayView<const T> GetBufferView(int offset, size_t size) const {
+    RTC_DCHECK_LE(0, offset);
+    RTC_DCHECK_LT(offset, S);
+    RTC_DCHECK_LT(0, size);
+    RTC_DCHECK_LE(size, S - offset);
+    return {buffer_.data() + offset, size};
+  }
+  // Shifts left the buffer by N items and add new N items at the end.
+  void Push(rtc::ArrayView<const T, N> new_values) {
+    // Make space for the new values.
+    if (S > N)
+      std::memmove(buffer_.data(), buffer_.data() + N, (S - N) * sizeof(T));
+    // Copy the new values at the end of the buffer.
+    std::memcpy(buffer_.data() + S - N, new_values.data(), N * sizeof(T));
+  }
+
+ private:
+  std::array<T, S> buffer_;
+};
+
+}  // namespace rnn_vad
+}  // namespace webrtc
+
+#endif  // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SEQUENCE_BUFFER_H_
diff --git a/modules/audio_processing/agc2/rnn_vad/sequence_buffer_unittest.cc b/modules/audio_processing/agc2/rnn_vad/sequence_buffer_unittest.cc
new file mode 100644
index 0000000000..f15a256e69
--- /dev/null
+++ b/modules/audio_processing/agc2/rnn_vad/sequence_buffer_unittest.cc
@@ -0,0 +1,106 @@
+/*
+ *  Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
+ *
+ *  Use of this source code is governed by a BSD-style license
+ *  that can be found in the LICENSE file in the root of the source
+ *  tree. An additional intellectual property rights grant can be found
+ *  in the file PATENTS.  All contributing project authors may
+ *  be found in the AUTHORS file in the root of the source tree.
+ */
+
+#include "modules/audio_processing/agc2/rnn_vad/sequence_buffer.h"
+
+#include <algorithm>
+
+#include "test/gtest.h"
+
+namespace webrtc {
+namespace test {
+
+using rnn_vad::SequenceBuffer;
+
+namespace {
+
+template <typename T, size_t S, size_t N>
+void TestSequenceBufferPushOp() {
+  SCOPED_TRACE(S);
+  SCOPED_TRACE(N);
+  SequenceBuffer<T, S, N> seq_buf;
+  auto seq_buf_view = seq_buf.GetBufferView();
+  std::array<T, N> chunk;
+  rtc::ArrayView<T, N> chunk_view(chunk.data(), chunk.size());
+
+  // Check that a chunk is fully gone after ceil(S / N) push ops.
+  chunk.fill(1);
+  seq_buf.Push(chunk_view);
+  chunk.fill(0);
+  constexpr size_t required_push_ops = (S % N) ? S / N + 1 : S / N;
+  for (size_t i = 0; i < required_push_ops - 1; ++i) {
+    SCOPED_TRACE(i);
+    seq_buf.Push(chunk_view);
+    // Still in the buffer.
+    const auto* m = std::max_element(seq_buf_view.begin(), seq_buf_view.end());
+    EXPECT_EQ(1, *m);
+  }
+  // Gone after another push.
+  seq_buf.Push(chunk_view);
+  const auto* m = std::max_element(seq_buf_view.begin(), seq_buf_view.end());
+  EXPECT_EQ(0, *m);
+
+  // Check that the last item moves left by N positions after a push op.
+  if (S > N) {
+    // Fill in with non-zero values.
+    for (size_t i = 0; i < N; ++i)
+      chunk[i] = static_cast<T>(i + 1);
+    seq_buf.Push(chunk_view);
+    // With the next Push(), |last| will be moved left by N positions.
+    const T last = chunk[N - 1];
+    for (size_t i = 0; i < N; ++i)
+      chunk[i] = static_cast<T>(last + i + 1);
+    seq_buf.Push(chunk_view);
+    EXPECT_EQ(last, seq_buf_view[S - N - 1]);
+  }
+}
+
+}  // namespace
+
+TEST(RnnVadTest, SequenceBufferGetters) {
+  constexpr size_t buffer_size = 8;
+  constexpr size_t chunk_size = 8;
+  SequenceBuffer<int, buffer_size, chunk_size> seq_buf;
+  EXPECT_EQ(buffer_size, seq_buf.size());
+  EXPECT_EQ(chunk_size, seq_buf.chunks_size());
+  // Test view.
+  auto seq_buf_view = seq_buf.GetBufferView();
+  EXPECT_EQ(0, seq_buf_view[0]);
+  EXPECT_EQ(0, seq_buf_view[seq_buf_view.size() - 1]);
+  constexpr std::array<int, chunk_size> chunk = {10, 20, 30, 40,
+                                                 50, 60, 70, 80};
+  seq_buf.Push({chunk.data(), chunk_size});
+  EXPECT_EQ(10, *seq_buf_view.begin());
+  EXPECT_EQ(80, *(seq_buf_view.end() - 1));
+}
+
+TEST(RnnVadTest, SequenceBufferPushOpsUnsigned) {
+  TestSequenceBufferPushOp<uint8_t, 32, 8>();   // Chunk size: 25%.
+  TestSequenceBufferPushOp<uint8_t, 32, 16>();  // Chunk size: 50%.
+  TestSequenceBufferPushOp<uint8_t, 32, 32>();  // Chunk size: 100%.
+  TestSequenceBufferPushOp<uint8_t, 23, 7>();   // Non-integer ratio.
+}
+
+TEST(RnnVadTest, SequenceBufferPushOpsSigned) {
+  TestSequenceBufferPushOp<int, 32, 8>();   // Chunk size: 25%.
+  TestSequenceBufferPushOp<int, 32, 16>();  // Chunk size: 50%.
+  TestSequenceBufferPushOp<int, 32, 32>();  // Chunk size: 100%.
+  TestSequenceBufferPushOp<int, 23, 7>();   // Non-integer ratio.
+}
+
+TEST(RnnVadTest, SequenceBufferPushOpsFloating) {
+  TestSequenceBufferPushOp<float, 32, 8>();   // Chunk size: 25%.
+  TestSequenceBufferPushOp<float, 32, 16>();  // Chunk size: 50%.
+  TestSequenceBufferPushOp<float, 32, 32>();  // Chunk size: 100%.
+  TestSequenceBufferPushOp<float, 23, 7>();   // Non-integer ratio.
+}
+
+}  // namespace test
+}  // namespace webrtc