diff --git a/modules/audio_processing/agc2/rnn_vad/BUILD.gn b/modules/audio_processing/agc2/rnn_vad/BUILD.gn index 12e32e544c..814a7f5b44 100644 --- a/modules/audio_processing/agc2/rnn_vad/BUILD.gn +++ b/modules/audio_processing/agc2/rnn_vad/BUILD.gn @@ -19,6 +19,7 @@ source_set("lib") { "common.h", "ring_buffer.h", "sequence_buffer.h", + "symmetric_matrix_buffer.h", ] deps = [ "../../../../api:array_view", @@ -32,6 +33,7 @@ if (rtc_include_tests) { sources = [ "ring_buffer_unittest.cc", "sequence_buffer_unittest.cc", + "symmetric_matrix_buffer_unittest.cc", ] deps = [ ":lib", diff --git a/modules/audio_processing/agc2/rnn_vad/symmetric_matrix_buffer.h b/modules/audio_processing/agc2/rnn_vad/symmetric_matrix_buffer.h new file mode 100644 index 0000000000..f0282aaed5 --- /dev/null +++ b/modules/audio_processing/agc2/rnn_vad/symmetric_matrix_buffer.h @@ -0,0 +1,94 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SYMMETRIC_MATRIX_BUFFER_H_ +#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SYMMETRIC_MATRIX_BUFFER_H_ + +#include +#include +#include +#include + +#include "api/array_view.h" +#include "rtc_base/checks.h" + +namespace webrtc { +namespace rnn_vad { + +// Data structure to buffer the results of pair-wise comparisons between items +// stored in a ring buffer. Every time that the oldest item is replaced in the +// ring buffer, the new one is compared to the remaining items in the ring +// buffer. The results of such comparisons need to be buffered and automatically +// removed when one of the two corresponding items that have been compared is +// removed from the ring buffer. It is assumed that the comparison is symmetric +// and that comparing an item with itself is not needed. +template +class SymmetricMatrixBuffer { + static_assert(S > 2, ""); + + public: + SymmetricMatrixBuffer() = default; + SymmetricMatrixBuffer(const SymmetricMatrixBuffer&) = delete; + SymmetricMatrixBuffer& operator=(const SymmetricMatrixBuffer&) = delete; + ~SymmetricMatrixBuffer() = default; + // Sets the buffer values to zero. + void Reset() { + static_assert(std::is_arithmetic::value, + "Integral or floating point required."); + buf_.fill(0); + } + // Pushes the results from the comparison between the most recent item and + // those that are still in the ring buffer. The first element in |values| must + // correspond to the comparison between the most recent item and the second + // most recent one in the ring buffer, whereas the last element in |values| + // must correspond to the comparison between the most recent item and the + // oldest one in the ring buffer. + void Push(rtc::ArrayView values) { + // Move the lower-right sub-matrix of size (S-2) x (S-2) one row up and one + // column left. + std::memmove(buf_.data(), buf_.data() + S, (buf_.size() - S) * sizeof(T)); + // Copy new values in the last column in the right order. + for (size_t i = 0; i < values.size(); ++i) { + const size_t index = (S - 1 - i) * (S - 1) - 1; + RTC_DCHECK_LE(static_cast(0), index); + RTC_DCHECK_LT(index, buf_.size()); + buf_[index] = values[i]; + } + } + // Reads the value that corresponds to comparison of two items in the ring + // buffer having delay |delay1| and |delay2|. The two arguments must not be + // equal and both must be in {0, ..., S - 1}. + T GetValue(size_t delay1, size_t delay2) const { + int row = S - 1 - static_cast(delay1); + int col = S - 1 - static_cast(delay2); + RTC_DCHECK_NE(row, col) << "The diagonal cannot be accessed."; + if (row > col) + std::swap(row, col); // Swap to access the upper-right triangular part. + RTC_DCHECK_LE(0, row); + RTC_DCHECK_LT(row, S - 1) << "Not enforcing row < col and row != col."; + RTC_DCHECK_LE(1, col) << "Not enforcing row < col and row != col."; + RTC_DCHECK_LT(col, S); + const int index = row * (S - 1) + (col - 1); + RTC_DCHECK_LE(0, index); + RTC_DCHECK_LT(index, buf_.size()); + return buf_[index]; + } + + private: + // Encode an upper-right triangular matrix (excluding its diagonal) using a + // square matrix. This allows to move the data in Push() with one single + // operation. + std::array buf_{}; +}; + +} // namespace rnn_vad +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SYMMETRIC_MATRIX_BUFFER_H_ diff --git a/modules/audio_processing/agc2/rnn_vad/symmetric_matrix_buffer_unittest.cc b/modules/audio_processing/agc2/rnn_vad/symmetric_matrix_buffer_unittest.cc new file mode 100644 index 0000000000..408467a259 --- /dev/null +++ b/modules/audio_processing/agc2/rnn_vad/symmetric_matrix_buffer_unittest.cc @@ -0,0 +1,111 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/agc2/rnn_vad/symmetric_matrix_buffer.h" + +#include "modules/audio_processing/agc2/rnn_vad/ring_buffer.h" +#include "test/gtest.h" + +namespace webrtc { +namespace test { +namespace { + +using rnn_vad::RingBuffer; +using rnn_vad::SymmetricMatrixBuffer; + +template +void CheckSymmetry(const SymmetricMatrixBuffer* sym_matrix_buf) { + for (size_t row = 0; row < S - 1; ++row) + for (size_t col = row + 1; col < S; ++col) + EXPECT_EQ(sym_matrix_buf->GetValue(row, col), + sym_matrix_buf->GetValue(col, row)); +} + +using PairType = std::pair; + +// Checks that the symmetric matrix buffer contains any pair with a value equal +// to the given one. +template +bool CheckPairsWithValueExist( + const SymmetricMatrixBuffer* sym_matrix_buf, + const int value) { + for (size_t row = 0; row < S - 1; ++row) { + for (size_t col = row + 1; col < S; ++col) { + auto p = sym_matrix_buf->GetValue(row, col); + if (p.first == value || p.second == value) + return true; + } + } + return false; +} + +} // namespace + +// Test that shows how to combine RingBuffer and SymmetricMatrixBuffer to +// efficiently compute pair-wise scores. This test verifies that the evolution +// of a SymmetricMatrixBuffer instance follows that of RingBuffer. +TEST(RnnVadTest, SymmetricMatrixBufferUseCase) { + // Instance a ring buffer which will be fed with a series of integer values. + constexpr int kRingBufSize = 10; + RingBuffer(kRingBufSize)> ring_buf; + // Instance a symmetric matrix buffer for the ring buffer above. It stores + // pairs of integers with which this test can easily check that the evolution + // of RingBuffer and SymmetricMatrixBuffer match. + SymmetricMatrixBuffer sym_matrix_buf; + for (int t = 1; t <= 100; ++t) { // Evolution steps. + SCOPED_TRACE(t); + const int t_removed = ring_buf.GetArrayView(kRingBufSize - 1)[0]; + ring_buf.Push({&t, 1}); + // The head of the ring buffer is |t|. + ASSERT_EQ(t, ring_buf.GetArrayView(0)[0]); + // Create the comparisons between |t| and the older elements in the ring + // buffer. + std::array new_comparions; + for (int i = 0; i < kRingBufSize - 1; ++i) { + // Start comparing |t| to the second newest element in the ring buffer. + const int delay = i + 1; + const auto t_prev = ring_buf.GetArrayView(delay)[0]; + ASSERT_EQ(std::max(0, t - delay), t_prev); + // Compare the last element |t| with |t_prev|. + new_comparions[i].first = t_prev; + new_comparions[i].second = t; + } + // Push the new comparisons in the symmetric matrix buffer. + sym_matrix_buf.Push({new_comparions.data(), new_comparions.size()}); + // Tests. + CheckSymmetry(&sym_matrix_buf); + // Check that the pairs resulting from the content in the ring buffer are + // in the right position. + for (size_t delay1 = 0; delay1 < kRingBufSize - 1; ++delay1) { + for (size_t delay2 = delay1 + 1; delay2 < kRingBufSize; ++delay2) { + const auto t1 = ring_buf.GetArrayView(delay1)[0]; + const auto t2 = ring_buf.GetArrayView(delay2)[0]; + ASSERT_LE(t2, t1); + const auto p = sym_matrix_buf.GetValue(delay1, delay2); + EXPECT_EQ(p.first, t2); + EXPECT_EQ(p.second, t1); + } + } + // Check that every older element in the ring buffer still has a + // corresponding pair in the symmetric matrix buffer. + for (size_t delay = 1; delay < kRingBufSize; ++delay) { + const auto t_prev = ring_buf.GetArrayView(delay)[0]; + EXPECT_TRUE(CheckPairsWithValueExist(&sym_matrix_buf, t_prev)); + } + // Check that the element removed from the ring buffer has no corresponding + // pairs in the symmetric matrix buffer. + if (t > kRingBufSize - 1) { + EXPECT_FALSE(CheckPairsWithValueExist(&sym_matrix_buf, t_removed)); + } + } +} + +} // namespace test +} // namespace webrtc