diff --git a/modules/audio_processing/agc2/rnn_vad/BUILD.gn b/modules/audio_processing/agc2/rnn_vad/BUILD.gn index f7dee9eaa4..12e32e544c 100644 --- a/modules/audio_processing/agc2/rnn_vad/BUILD.gn +++ b/modules/audio_processing/agc2/rnn_vad/BUILD.gn @@ -17,6 +17,7 @@ group("rnn_vad") { source_set("lib") { sources = [ "common.h", + "ring_buffer.h", "sequence_buffer.h", ] deps = [ @@ -29,6 +30,7 @@ if (rtc_include_tests) { rtc_source_set("unittests") { testonly = true sources = [ + "ring_buffer_unittest.cc", "sequence_buffer_unittest.cc", ] deps = [ diff --git a/modules/audio_processing/agc2/rnn_vad/ring_buffer.h b/modules/audio_processing/agc2/rnn_vad/ring_buffer.h new file mode 100644 index 0000000000..294b0c0ba8 --- /dev/null +++ b/modules/audio_processing/agc2/rnn_vad/ring_buffer.h @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RING_BUFFER_H_ +#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RING_BUFFER_H_ + +#include +#include +#include + +#include "api/array_view.h" + +namespace webrtc { +namespace rnn_vad { + +// Ring buffer for N arrays of type T each one with size S. +template +class RingBuffer { + static_assert(S > 0, ""); + static_assert(N > 0, ""); + static_assert(std::is_arithmetic::value, + "Integral or floating point required."); + + public: + RingBuffer() : tail_(0) {} + RingBuffer(const RingBuffer&) = delete; + RingBuffer& operator=(const RingBuffer&) = delete; + ~RingBuffer() = default; + // Set the ring buffer values to zero. + void Reset() { buffer_.fill(0); } + // Replace the least recently pushed array in the buffer with |new_values|. + void Push(rtc::ArrayView new_values) { + std::memcpy(buffer_.data() + S * tail_, new_values.data(), S * sizeof(T)); + tail_ += 1; + if (tail_ == N) + tail_ = 0; + } + // Return an array view onto the array with a given delay. A view on the last + // and least recently push array is returned when |delay| is 0 and N - 1 + // respectively. + rtc::ArrayView GetArrayView(size_t delay) const { + const int delay_int = static_cast(delay); + RTC_DCHECK_LE(0, delay_int); + RTC_DCHECK_LT(delay_int, N); + int offset = tail_ - 1 - delay_int; + if (offset < 0) + offset += N; + return {buffer_.data() + S * offset, S}; + } + + private: + int tail_; // Index of the least recently pushed sub-array. + std::array buffer_{}; +}; + +} // namespace rnn_vad +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RING_BUFFER_H_ diff --git a/modules/audio_processing/agc2/rnn_vad/ring_buffer_unittest.cc b/modules/audio_processing/agc2/rnn_vad/ring_buffer_unittest.cc new file mode 100644 index 0000000000..0848f8d56c --- /dev/null +++ b/modules/audio_processing/agc2/rnn_vad/ring_buffer_unittest.cc @@ -0,0 +1,116 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/agc2/rnn_vad/ring_buffer.h" + +#include "test/gtest.h" + +namespace webrtc { +namespace test { + +using rnn_vad::RingBuffer; + +namespace { + +// Compare the elements of two given array views. +template +void ExpectEq(rtc::ArrayView a, rtc::ArrayView b) { + for (size_t i = 0; i < S; ++i) { + SCOPED_TRACE(i); + EXPECT_EQ(a[i], b[i]); + } +} + +// Test push/read sequences. +template +void TestRingBuffer() { + SCOPED_TRACE(N); + SCOPED_TRACE(S); + std::array prev_pushed_array; + std::array pushed_array; + rtc::ArrayView pushed_array_view(pushed_array.data(), S); + + // Init. + RingBuffer ring_buf; + ring_buf.GetArrayView(0); + pushed_array.fill(0); + ring_buf.Push(pushed_array_view); + ExpectEq(pushed_array_view, ring_buf.GetArrayView(0)); + + // Push N times and check most recent and second most recent. + for (T v = 1; v <= static_cast(N); ++v) { + SCOPED_TRACE(v); + prev_pushed_array = pushed_array; + pushed_array.fill(v); + ring_buf.Push(pushed_array_view); + ExpectEq(pushed_array_view, ring_buf.GetArrayView(0)); + if (N > 1) { + pushed_array.fill(v - 1); + ExpectEq(pushed_array_view, ring_buf.GetArrayView(1)); + } + } + + // Check buffer. + for (size_t delay = 2; delay < N; ++delay) { + SCOPED_TRACE(delay); + T expected_value = N - static_cast(delay); + pushed_array.fill(expected_value); + ExpectEq(pushed_array_view, ring_buf.GetArrayView(delay)); + } +} + +} // namespace + +// Check that for different delays, different views are returned. +TEST(RnnVadTest, RingBufferArrayViews) { + constexpr size_t s = 3; + constexpr size_t n = 4; + RingBuffer ring_buf; + std::array pushed_array; + pushed_array.fill(1); + for (size_t k = 0; k <= n; ++k) { // Push data n + 1 times. + SCOPED_TRACE(k); + // Check array views. + for (size_t i = 0; i < n; ++i) { + SCOPED_TRACE(i); + auto view_i = ring_buf.GetArrayView(i); + for (size_t j = i + 1; j < n; ++j) { + SCOPED_TRACE(j); + auto view_j = ring_buf.GetArrayView(j); + EXPECT_NE(view_i, view_j); + } + } + ring_buf.Push({pushed_array.data(), pushed_array.size()}); + } +} + +TEST(RnnVadTest, RingBufferUnsigned) { + TestRingBuffer(); + TestRingBuffer(); + TestRingBuffer(); + TestRingBuffer(); +} + +TEST(RnnVadTest, RingBufferSigned) { + TestRingBuffer(); + TestRingBuffer(); + TestRingBuffer(); + TestRingBuffer(); +} + +TEST(RnnVadTest, RingBufferFloating) { + TestRingBuffer(); + TestRingBuffer(); + TestRingBuffer(); + TestRingBuffer(); +} + +} // namespace test +} // namespace webrtc