AGC2 RNN VAD: Ring buffer

Ring buffer template for a finite number of arrays of given type and size.

Bug: webrtc:9076
Change-Id: Ia6c2065b0013f4a00f693966641f9aebe09f6f5c
Reviewed-on: https://webrtc-review.googlesource.com/70161
Reviewed-by: Alex Loiko <aleloi@webrtc.org>
Reviewed-by: Sam Zackrisson <saza@webrtc.org>
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#22939}
This commit is contained in:
Alessio Bazzica 2018-04-19 14:28:00 +02:00 committed by Commit Bot
parent 8d7393bb28
commit adbd808e0a
3 changed files with 184 additions and 0 deletions

View File

@ -17,6 +17,7 @@ group("rnn_vad") {
source_set("lib") {
sources = [
"common.h",
"ring_buffer.h",
"sequence_buffer.h",
]
deps = [
@ -29,6 +30,7 @@ if (rtc_include_tests) {
rtc_source_set("unittests") {
testonly = true
sources = [
"ring_buffer_unittest.cc",
"sequence_buffer_unittest.cc",
]
deps = [

View File

@ -0,0 +1,66 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RING_BUFFER_H_
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RING_BUFFER_H_
#include <array>
#include <cstring>
#include <type_traits>
#include "api/array_view.h"
namespace webrtc {
namespace rnn_vad {
// Ring buffer for N arrays of type T each one with size S.
template <typename T, size_t S, size_t N>
class RingBuffer {
static_assert(S > 0, "");
static_assert(N > 0, "");
static_assert(std::is_arithmetic<T>::value,
"Integral or floating point required.");
public:
RingBuffer() : tail_(0) {}
RingBuffer(const RingBuffer&) = delete;
RingBuffer& operator=(const RingBuffer&) = delete;
~RingBuffer() = default;
// Set the ring buffer values to zero.
void Reset() { buffer_.fill(0); }
// Replace the least recently pushed array in the buffer with |new_values|.
void Push(rtc::ArrayView<const T, S> new_values) {
std::memcpy(buffer_.data() + S * tail_, new_values.data(), S * sizeof(T));
tail_ += 1;
if (tail_ == N)
tail_ = 0;
}
// Return an array view onto the array with a given delay. A view on the last
// and least recently push array is returned when |delay| is 0 and N - 1
// respectively.
rtc::ArrayView<const T, S> GetArrayView(size_t delay) const {
const int delay_int = static_cast<int>(delay);
RTC_DCHECK_LE(0, delay_int);
RTC_DCHECK_LT(delay_int, N);
int offset = tail_ - 1 - delay_int;
if (offset < 0)
offset += N;
return {buffer_.data() + S * offset, S};
}
private:
int tail_; // Index of the least recently pushed sub-array.
std::array<T, S * N> buffer_{};
};
} // namespace rnn_vad
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RING_BUFFER_H_

View File

@ -0,0 +1,116 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/rnn_vad/ring_buffer.h"
#include "test/gtest.h"
namespace webrtc {
namespace test {
using rnn_vad::RingBuffer;
namespace {
// Compare the elements of two given array views.
template <typename T, std::ptrdiff_t S>
void ExpectEq(rtc::ArrayView<const T, S> a, rtc::ArrayView<const T, S> b) {
for (size_t i = 0; i < S; ++i) {
SCOPED_TRACE(i);
EXPECT_EQ(a[i], b[i]);
}
}
// Test push/read sequences.
template <typename T, size_t S, size_t N>
void TestRingBuffer() {
SCOPED_TRACE(N);
SCOPED_TRACE(S);
std::array<T, S> prev_pushed_array;
std::array<T, S> pushed_array;
rtc::ArrayView<const T, S> pushed_array_view(pushed_array.data(), S);
// Init.
RingBuffer<T, S, N> ring_buf;
ring_buf.GetArrayView(0);
pushed_array.fill(0);
ring_buf.Push(pushed_array_view);
ExpectEq(pushed_array_view, ring_buf.GetArrayView(0));
// Push N times and check most recent and second most recent.
for (T v = 1; v <= static_cast<T>(N); ++v) {
SCOPED_TRACE(v);
prev_pushed_array = pushed_array;
pushed_array.fill(v);
ring_buf.Push(pushed_array_view);
ExpectEq(pushed_array_view, ring_buf.GetArrayView(0));
if (N > 1) {
pushed_array.fill(v - 1);
ExpectEq(pushed_array_view, ring_buf.GetArrayView(1));
}
}
// Check buffer.
for (size_t delay = 2; delay < N; ++delay) {
SCOPED_TRACE(delay);
T expected_value = N - static_cast<T>(delay);
pushed_array.fill(expected_value);
ExpectEq(pushed_array_view, ring_buf.GetArrayView(delay));
}
}
} // namespace
// Check that for different delays, different views are returned.
TEST(RnnVadTest, RingBufferArrayViews) {
constexpr size_t s = 3;
constexpr size_t n = 4;
RingBuffer<int, s, n> ring_buf;
std::array<int, s> pushed_array;
pushed_array.fill(1);
for (size_t k = 0; k <= n; ++k) { // Push data n + 1 times.
SCOPED_TRACE(k);
// Check array views.
for (size_t i = 0; i < n; ++i) {
SCOPED_TRACE(i);
auto view_i = ring_buf.GetArrayView(i);
for (size_t j = i + 1; j < n; ++j) {
SCOPED_TRACE(j);
auto view_j = ring_buf.GetArrayView(j);
EXPECT_NE(view_i, view_j);
}
}
ring_buf.Push({pushed_array.data(), pushed_array.size()});
}
}
TEST(RnnVadTest, RingBufferUnsigned) {
TestRingBuffer<uint8_t, 1, 1>();
TestRingBuffer<uint8_t, 2, 5>();
TestRingBuffer<uint8_t, 5, 2>();
TestRingBuffer<uint8_t, 5, 5>();
}
TEST(RnnVadTest, RingBufferSigned) {
TestRingBuffer<int, 1, 1>();
TestRingBuffer<int, 2, 5>();
TestRingBuffer<int, 5, 2>();
TestRingBuffer<int, 5, 5>();
}
TEST(RnnVadTest, RingBufferFloating) {
TestRingBuffer<float, 1, 1>();
TestRingBuffer<float, 2, 5>();
TestRingBuffer<float, 5, 2>();
TestRingBuffer<float, 5, 5>();
}
} // namespace test
} // namespace webrtc