AGC2 RNN VAD: Sequence buffer
The SequenceBuffer class template implements a linear buffer with a Push operation that adds a fixed-size chunk of new samples to the buffer. It is parameterized by the buffer size and by the size of the chunks that are pushed. It is used to implement the pitch buffer in the RNN VAD feature extractor, for which a ring buffer would be a painful choice. Bug: webrtc:9076 Change-Id: I4767bf06d5a414dbed724a96ea4186ef013a1e30 Reviewed-on: https://webrtc-review.googlesource.com/70204 Commit-Queue: Alessio Bazzica <alessiob@webrtc.org> Reviewed-by: Gustaf Ullberg <gustaf@webrtc.org> Cr-Commit-Position: refs/heads/master@{#22919}
This commit is contained in:
parent
a44ab181bf
commit
4736d4e524
@ -552,6 +552,7 @@ if (rtc_include_tests) {
|
||||
"agc2:adaptive_digital_unittests",
|
||||
"agc2:fixed_digital_unittests",
|
||||
"agc2:noise_estimator_unittests",
|
||||
"agc2/rnn_vad:unittests",
|
||||
"test/conversational_speech:unittest",
|
||||
"vad:vad_unittests",
|
||||
"//testing/gtest",
|
||||
|
||||
@ -17,6 +17,7 @@ group("rnn_vad") {
|
||||
source_set("lib") {
|
||||
sources = [
|
||||
"common.h",
|
||||
"sequence_buffer.h",
|
||||
]
|
||||
deps = [
|
||||
"../../../../api:array_view",
|
||||
@ -25,6 +26,18 @@ source_set("lib") {
|
||||
}
|
||||
|
||||
if (rtc_include_tests) {
|
||||
rtc_source_set("unittests") {
|
||||
testonly = true
|
||||
sources = [
|
||||
"sequence_buffer_unittest.cc",
|
||||
]
|
||||
deps = [
|
||||
":lib",
|
||||
"../../../../api:array_view",
|
||||
"../../../../test:test_support",
|
||||
]
|
||||
}
|
||||
|
||||
rtc_executable("rnn_vad_tool") {
|
||||
testonly = true
|
||||
sources = [
|
||||
|
||||
81
modules/audio_processing/agc2/rnn_vad/sequence_buffer.h
Normal file
81
modules/audio_processing/agc2/rnn_vad/sequence_buffer.h
Normal file
@ -0,0 +1,81 @@
|
||||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SEQUENCE_BUFFER_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SEQUENCE_BUFFER_H_
|
||||
|
||||
#include <array>
|
||||
#include <cstring>
|
||||
#include <type_traits>
|
||||
|
||||
#include "api/array_view.h"
|
||||
#include "rtc_base/checks.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
|
||||
// Linear buffer implementation to (i) push fixed size chunks of sequential data
|
||||
// and (ii) view contiguous parts of the buffer. The buffer and the pushed
|
||||
// chunks have size S and N respectively. For instance, when S = 2N the first
|
||||
// half of the sequence buffer is replaced with its second half, and the new N
|
||||
// values are written at the end of the buffer.
|
||||
template <typename T, size_t S, size_t N>
|
||||
class SequenceBuffer {
|
||||
static_assert(S >= N,
|
||||
"The new chunk size is larger than the sequence buffer size.");
|
||||
static_assert(std::is_arithmetic<T>::value,
|
||||
"Integral or floating point required.");
|
||||
|
||||
public:
|
||||
SequenceBuffer() { buffer_.fill(0); }
|
||||
SequenceBuffer(const SequenceBuffer&) = delete;
|
||||
SequenceBuffer& operator=(const SequenceBuffer&) = delete;
|
||||
~SequenceBuffer() = default;
|
||||
size_t size() const { return S; }
|
||||
size_t chunks_size() const { return N; }
|
||||
// Sets the sequence buffer values to zero.
|
||||
void Reset() { buffer_.fill(0); }
|
||||
// Returns a view on the whole buffer.
|
||||
rtc::ArrayView<const T, S> GetBufferView() const {
|
||||
return {buffer_.data(), S};
|
||||
}
|
||||
// Returns a view on part of the buffer; the first element starts at the given
|
||||
// offset and the last one is the last one in the buffer.
|
||||
rtc::ArrayView<const T> GetBufferView(int offset) const {
|
||||
RTC_DCHECK_LE(0, offset);
|
||||
RTC_DCHECK_LT(offset, S);
|
||||
return {buffer_.data() + offset, S - offset};
|
||||
}
|
||||
// Returns a view on part of the buffer; the first element starts at the given
|
||||
// offset and the size of the view is |size|.
|
||||
rtc::ArrayView<const T> GetBufferView(int offset, size_t size) const {
|
||||
RTC_DCHECK_LE(0, offset);
|
||||
RTC_DCHECK_LT(offset, S);
|
||||
RTC_DCHECK_LT(0, size);
|
||||
RTC_DCHECK_LE(size, S - offset);
|
||||
return {buffer_.data() + offset, size};
|
||||
}
|
||||
// Shifts left the buffer by N items and add new N items at the end.
|
||||
void Push(rtc::ArrayView<const T, N> new_values) {
|
||||
// Make space for the new values.
|
||||
if (S > N)
|
||||
std::memmove(buffer_.data(), buffer_.data() + N, (S - N) * sizeof(T));
|
||||
// Copy the new values at the end of the buffer.
|
||||
std::memcpy(buffer_.data() + S - N, new_values.data(), N * sizeof(T));
|
||||
}
|
||||
|
||||
private:
|
||||
std::array<T, S> buffer_;
|
||||
};
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SEQUENCE_BUFFER_H_
|
||||
@ -0,0 +1,106 @@
|
||||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/rnn_vad/sequence_buffer.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include "test/gtest.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace test {
|
||||
|
||||
using rnn_vad::SequenceBuffer;
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename T, size_t S, size_t N>
|
||||
void TestSequenceBufferPushOp() {
|
||||
SCOPED_TRACE(S);
|
||||
SCOPED_TRACE(N);
|
||||
SequenceBuffer<T, S, N> seq_buf;
|
||||
auto seq_buf_view = seq_buf.GetBufferView();
|
||||
std::array<T, N> chunk;
|
||||
rtc::ArrayView<T, N> chunk_view(chunk.data(), chunk.size());
|
||||
|
||||
// Check that a chunk is fully gone after ceil(S / N) push ops.
|
||||
chunk.fill(1);
|
||||
seq_buf.Push(chunk_view);
|
||||
chunk.fill(0);
|
||||
constexpr size_t required_push_ops = (S % N) ? S / N + 1 : S / N;
|
||||
for (size_t i = 0; i < required_push_ops - 1; ++i) {
|
||||
SCOPED_TRACE(i);
|
||||
seq_buf.Push(chunk_view);
|
||||
// Still in the buffer.
|
||||
const auto* m = std::max_element(seq_buf_view.begin(), seq_buf_view.end());
|
||||
EXPECT_EQ(1, *m);
|
||||
}
|
||||
// Gone after another push.
|
||||
seq_buf.Push(chunk_view);
|
||||
const auto* m = std::max_element(seq_buf_view.begin(), seq_buf_view.end());
|
||||
EXPECT_EQ(0, *m);
|
||||
|
||||
// Check that the last item moves left by N positions after a push op.
|
||||
if (S > N) {
|
||||
// Fill in with non-zero values.
|
||||
for (size_t i = 0; i < N; ++i)
|
||||
chunk[i] = static_cast<T>(i + 1);
|
||||
seq_buf.Push(chunk_view);
|
||||
// With the next Push(), |last| will be moved left by N positions.
|
||||
const T last = chunk[N - 1];
|
||||
for (size_t i = 0; i < N; ++i)
|
||||
chunk[i] = static_cast<T>(last + i + 1);
|
||||
seq_buf.Push(chunk_view);
|
||||
EXPECT_EQ(last, seq_buf_view[S - N - 1]);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Checks size(), chunks_size() and that the view returned by GetBufferView()
// exposes zeros for a new buffer and the pushed values after a Push() call.
TEST(RnnVadTest, SequenceBufferGetters) {
  constexpr size_t kBufferSize = 8;
  constexpr size_t kChunkSize = 8;
  SequenceBuffer<int, kBufferSize, kChunkSize> seq_buf;
  EXPECT_EQ(kBufferSize, seq_buf.size());
  EXPECT_EQ(kChunkSize, seq_buf.chunks_size());
  // A newly created buffer holds zeros at both ends.
  auto view = seq_buf.GetBufferView();
  EXPECT_EQ(0, *view.begin());
  EXPECT_EQ(0, *(view.end() - 1));
  // The chunk is as large as the buffer, hence a single push replaces the
  // whole content; the view obtained before the push must reflect that.
  constexpr std::array<int, kChunkSize> kChunk = {10, 20, 30, 40,
                                                  50, 60, 70, 80};
  seq_buf.Push({kChunk.data(), kChunk.size()});
  EXPECT_EQ(10, view[0]);
  EXPECT_EQ(80, view[view.size() - 1]);
}
|
||||
|
||||
TEST(RnnVadTest, SequenceBufferPushOpsUnsigned) {
|
||||
TestSequenceBufferPushOp<uint8_t, 32, 8>(); // Chunk size: 25%.
|
||||
TestSequenceBufferPushOp<uint8_t, 32, 16>(); // Chunk size: 50%.
|
||||
TestSequenceBufferPushOp<uint8_t, 32, 32>(); // Chunk size: 100%.
|
||||
TestSequenceBufferPushOp<uint8_t, 23, 7>(); // Non-integer ratio.
|
||||
}
|
||||
|
||||
TEST(RnnVadTest, SequenceBufferPushOpsSigned) {
|
||||
TestSequenceBufferPushOp<int, 32, 8>(); // Chunk size: 25%.
|
||||
TestSequenceBufferPushOp<int, 32, 16>(); // Chunk size: 50%.
|
||||
TestSequenceBufferPushOp<int, 32, 32>(); // Chunk size: 100%.
|
||||
TestSequenceBufferPushOp<int, 23, 7>(); // Non-integer ratio.
|
||||
}
|
||||
|
||||
TEST(RnnVadTest, SequenceBufferPushOpsFloating) {
|
||||
TestSequenceBufferPushOp<float, 32, 8>(); // Chunk size: 25%.
|
||||
TestSequenceBufferPushOp<float, 32, 16>(); // Chunk size: 50%.
|
||||
TestSequenceBufferPushOp<float, 32, 32>(); // Chunk size: 100%.
|
||||
TestSequenceBufferPushOp<float, 23, 7>(); // Non-integer ratio.
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace webrtc
|
||||
Loading…
x
Reference in New Issue
Block a user