AGC2 RNN VAD: Ring buffer
Ring buffer template for a finite number of arrays of given type and size. Bug: webrtc:9076 Change-Id: Ia6c2065b0013f4a00f693966641f9aebe09f6f5c Reviewed-on: https://webrtc-review.googlesource.com/70161 Reviewed-by: Alex Loiko <aleloi@webrtc.org> Reviewed-by: Sam Zackrisson <saza@webrtc.org> Commit-Queue: Alessio Bazzica <alessiob@webrtc.org> Cr-Commit-Position: refs/heads/master@{#22939}
This commit is contained in:
parent
8d7393bb28
commit
adbd808e0a
@ -17,6 +17,7 @@ group("rnn_vad") {
|
||||
source_set("lib") {
|
||||
sources = [
|
||||
"common.h",
|
||||
"ring_buffer.h",
|
||||
"sequence_buffer.h",
|
||||
]
|
||||
deps = [
|
||||
@ -29,6 +30,7 @@ if (rtc_include_tests) {
|
||||
rtc_source_set("unittests") {
|
||||
testonly = true
|
||||
sources = [
|
||||
"ring_buffer_unittest.cc",
|
||||
"sequence_buffer_unittest.cc",
|
||||
]
|
||||
deps = [
|
||||
|
||||
66
modules/audio_processing/agc2/rnn_vad/ring_buffer.h
Normal file
66
modules/audio_processing/agc2/rnn_vad/ring_buffer.h
Normal file
@ -0,0 +1,66 @@
|
||||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RING_BUFFER_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RING_BUFFER_H_
|
||||
|
||||
#include <array>
|
||||
#include <cstring>
|
||||
#include <type_traits>
|
||||
|
||||
#include "api/array_view.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
|
||||
// Ring buffer for N arrays of type T each one with size S.
|
||||
template <typename T, size_t S, size_t N>
|
||||
class RingBuffer {
|
||||
static_assert(S > 0, "");
|
||||
static_assert(N > 0, "");
|
||||
static_assert(std::is_arithmetic<T>::value,
|
||||
"Integral or floating point required.");
|
||||
|
||||
public:
|
||||
RingBuffer() : tail_(0) {}
|
||||
RingBuffer(const RingBuffer&) = delete;
|
||||
RingBuffer& operator=(const RingBuffer&) = delete;
|
||||
~RingBuffer() = default;
|
||||
// Set the ring buffer values to zero.
|
||||
void Reset() { buffer_.fill(0); }
|
||||
// Replace the least recently pushed array in the buffer with |new_values|.
|
||||
void Push(rtc::ArrayView<const T, S> new_values) {
|
||||
std::memcpy(buffer_.data() + S * tail_, new_values.data(), S * sizeof(T));
|
||||
tail_ += 1;
|
||||
if (tail_ == N)
|
||||
tail_ = 0;
|
||||
}
|
||||
// Return an array view onto the array with a given delay. A view on the last
|
||||
// and least recently push array is returned when |delay| is 0 and N - 1
|
||||
// respectively.
|
||||
rtc::ArrayView<const T, S> GetArrayView(size_t delay) const {
|
||||
const int delay_int = static_cast<int>(delay);
|
||||
RTC_DCHECK_LE(0, delay_int);
|
||||
RTC_DCHECK_LT(delay_int, N);
|
||||
int offset = tail_ - 1 - delay_int;
|
||||
if (offset < 0)
|
||||
offset += N;
|
||||
return {buffer_.data() + S * offset, S};
|
||||
}
|
||||
|
||||
private:
|
||||
int tail_; // Index of the least recently pushed sub-array.
|
||||
std::array<T, S * N> buffer_{};
|
||||
};
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RING_BUFFER_H_
|
||||
116
modules/audio_processing/agc2/rnn_vad/ring_buffer_unittest.cc
Normal file
116
modules/audio_processing/agc2/rnn_vad/ring_buffer_unittest.cc
Normal file
@ -0,0 +1,116 @@
|
||||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/rnn_vad/ring_buffer.h"
|
||||
|
||||
#include "test/gtest.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace test {
|
||||
|
||||
using rnn_vad::RingBuffer;
|
||||
|
||||
namespace {
|
||||
|
||||
// Compare the elements of two given array views.
|
||||
template <typename T, std::ptrdiff_t S>
|
||||
void ExpectEq(rtc::ArrayView<const T, S> a, rtc::ArrayView<const T, S> b) {
|
||||
for (size_t i = 0; i < S; ++i) {
|
||||
SCOPED_TRACE(i);
|
||||
EXPECT_EQ(a[i], b[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// Test push/read sequences.
|
||||
template <typename T, size_t S, size_t N>
|
||||
void TestRingBuffer() {
|
||||
SCOPED_TRACE(N);
|
||||
SCOPED_TRACE(S);
|
||||
std::array<T, S> prev_pushed_array;
|
||||
std::array<T, S> pushed_array;
|
||||
rtc::ArrayView<const T, S> pushed_array_view(pushed_array.data(), S);
|
||||
|
||||
// Init.
|
||||
RingBuffer<T, S, N> ring_buf;
|
||||
ring_buf.GetArrayView(0);
|
||||
pushed_array.fill(0);
|
||||
ring_buf.Push(pushed_array_view);
|
||||
ExpectEq(pushed_array_view, ring_buf.GetArrayView(0));
|
||||
|
||||
// Push N times and check most recent and second most recent.
|
||||
for (T v = 1; v <= static_cast<T>(N); ++v) {
|
||||
SCOPED_TRACE(v);
|
||||
prev_pushed_array = pushed_array;
|
||||
pushed_array.fill(v);
|
||||
ring_buf.Push(pushed_array_view);
|
||||
ExpectEq(pushed_array_view, ring_buf.GetArrayView(0));
|
||||
if (N > 1) {
|
||||
pushed_array.fill(v - 1);
|
||||
ExpectEq(pushed_array_view, ring_buf.GetArrayView(1));
|
||||
}
|
||||
}
|
||||
|
||||
// Check buffer.
|
||||
for (size_t delay = 2; delay < N; ++delay) {
|
||||
SCOPED_TRACE(delay);
|
||||
T expected_value = N - static_cast<T>(delay);
|
||||
pushed_array.fill(expected_value);
|
||||
ExpectEq(pushed_array_view, ring_buf.GetArrayView(delay));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Check that for different delays, different views are returned.
|
||||
TEST(RnnVadTest, RingBufferArrayViews) {
|
||||
constexpr size_t s = 3;
|
||||
constexpr size_t n = 4;
|
||||
RingBuffer<int, s, n> ring_buf;
|
||||
std::array<int, s> pushed_array;
|
||||
pushed_array.fill(1);
|
||||
for (size_t k = 0; k <= n; ++k) { // Push data n + 1 times.
|
||||
SCOPED_TRACE(k);
|
||||
// Check array views.
|
||||
for (size_t i = 0; i < n; ++i) {
|
||||
SCOPED_TRACE(i);
|
||||
auto view_i = ring_buf.GetArrayView(i);
|
||||
for (size_t j = i + 1; j < n; ++j) {
|
||||
SCOPED_TRACE(j);
|
||||
auto view_j = ring_buf.GetArrayView(j);
|
||||
EXPECT_NE(view_i, view_j);
|
||||
}
|
||||
}
|
||||
ring_buf.Push({pushed_array.data(), pushed_array.size()});
|
||||
}
|
||||
}
|
||||
|
||||
TEST(RnnVadTest, RingBufferUnsigned) {
|
||||
TestRingBuffer<uint8_t, 1, 1>();
|
||||
TestRingBuffer<uint8_t, 2, 5>();
|
||||
TestRingBuffer<uint8_t, 5, 2>();
|
||||
TestRingBuffer<uint8_t, 5, 5>();
|
||||
}
|
||||
|
||||
TEST(RnnVadTest, RingBufferSigned) {
|
||||
TestRingBuffer<int, 1, 1>();
|
||||
TestRingBuffer<int, 2, 5>();
|
||||
TestRingBuffer<int, 5, 2>();
|
||||
TestRingBuffer<int, 5, 5>();
|
||||
}
|
||||
|
||||
TEST(RnnVadTest, RingBufferFloating) {
|
||||
TestRingBuffer<float, 1, 1>();
|
||||
TestRingBuffer<float, 2, 5>();
|
||||
TestRingBuffer<float, 5, 2>();
|
||||
TestRingBuffer<float, 5, 5>();
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace webrtc
|
||||
Loading…
x
Reference in New Issue
Block a user