RNN VAD: prepare for SIMD optimization
This CL adds the boilerplate for SIMD optimization of FC and GRU layers in rnn.cc. The same scheme of AEC3 has been used. Unit tests for the optimized architectures have been added (the same unoptimized implementation will run). Minor changes: - unnecessary const removed in rnn.h - FC and GRU test data in the anon namespace as constexpr Bug: webrtc:10480 Change-Id: Ifae4e970326e7e7c603d49aeaf61194b5efdabd3 Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/141419 Commit-Queue: Alessio Bazzica <alessiob@webrtc.org> Reviewed-by: Gustaf Ullberg <gustaf@webrtc.org> Cr-Commit-Position: refs/heads/master@{#29696}
This commit is contained in:
parent
ad04327df8
commit
7350a90237
@ -13,6 +13,7 @@ rtc_library("rnn_vad") {
|
||||
sources = [
|
||||
"auto_correlation.cc",
|
||||
"auto_correlation.h",
|
||||
"common.cc",
|
||||
"common.h",
|
||||
"features_extraction.cc",
|
||||
"features_extraction.h",
|
||||
@ -33,11 +34,20 @@ rtc_library("rnn_vad") {
|
||||
"spectral_features_internal.h",
|
||||
"symmetric_matrix_buffer.h",
|
||||
]
|
||||
|
||||
defines = []
|
||||
if (rtc_build_with_neon && current_cpu != "arm64") {
|
||||
suppressed_configs += [ "//build/config/compiler:compiler_arm_fpu" ]
|
||||
cflags = [ "-mfpu=neon" ]
|
||||
}
|
||||
|
||||
deps = [
|
||||
"..:biquad_filter",
|
||||
"../../../../api:array_view",
|
||||
"../../../../rtc_base:checks",
|
||||
"../../../../rtc_base:rtc_base_approved",
|
||||
"../../../../rtc_base/system:arch",
|
||||
"../../../../system_wrappers:cpu_features_api",
|
||||
"../../utility:pffft_wrapper",
|
||||
"//third_party/rnnoise:rnn_vad",
|
||||
]
|
||||
|
||||
34
modules/audio_processing/agc2/rnn_vad/common.cc
Normal file
34
modules/audio_processing/agc2/rnn_vad/common.cc
Normal file
@ -0,0 +1,34 @@
|
||||
/*
|
||||
* Copyright (c) 2019 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/rnn_vad/common.h"
|
||||
|
||||
#include "rtc_base/system/arch.h"
|
||||
#include "system_wrappers/include/cpu_features_wrapper.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
|
||||
Optimization DetectOptimization() {
|
||||
#if defined(WEBRTC_ARCH_X86_FAMILY)
|
||||
if (WebRtc_GetCPUInfo(kSSE2) != 0) {
|
||||
return Optimization::kSse2;
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(WEBRTC_HAS_NEON)
|
||||
return Optimization::kNeon;
|
||||
#endif
|
||||
|
||||
return Optimization::kNone;
|
||||
}
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
@ -11,6 +11,8 @@
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_COMMON_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_COMMON_H_
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
|
||||
@ -63,6 +65,11 @@ static_assert(kCepstralCoeffsHistorySize > 2,
|
||||
|
||||
constexpr size_t kFeatureVectorSize = 42;
|
||||
|
||||
enum class Optimization { kNone, kSse2, kNeon };
|
||||
|
||||
// Detects what kind of optimizations to use for the code.
|
||||
Optimization DetectOptimization();
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
||||
|
||||
@ -10,6 +10,15 @@
|
||||
|
||||
#include "modules/audio_processing/agc2/rnn_vad/rnn.h"
|
||||
|
||||
// Defines WEBRTC_ARCH_X86_FAMILY, used below.
|
||||
#include "rtc_base/system/arch.h"
|
||||
|
||||
#if defined(WEBRTC_HAS_NEON)
|
||||
#include <arm_neon.h>
|
||||
#endif
|
||||
#if defined(WEBRTC_ARCH_X86_FAMILY)
|
||||
#include <emmintrin.h>
|
||||
#endif
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <cmath>
|
||||
@ -69,12 +78,14 @@ FullyConnectedLayer::FullyConnectedLayer(
|
||||
const size_t output_size,
|
||||
const rtc::ArrayView<const int8_t> bias,
|
||||
const rtc::ArrayView<const int8_t> weights,
|
||||
float (*const activation_function)(float))
|
||||
float (*const activation_function)(float),
|
||||
Optimization optimization)
|
||||
: input_size_(input_size),
|
||||
output_size_(output_size),
|
||||
bias_(GetScaledParams(bias)),
|
||||
weights_(GetScaledParams(weights)),
|
||||
activation_function_(activation_function) {
|
||||
activation_function_(activation_function),
|
||||
optimization_(optimization) {
|
||||
RTC_DCHECK_LE(output_size_, kFullyConnectedLayersMaxUnits)
|
||||
<< "Static over-allocation of fully-connected layers output vectors is "
|
||||
"not sufficient.";
|
||||
@ -91,8 +102,26 @@ rtc::ArrayView<const float> FullyConnectedLayer::GetOutput() const {
|
||||
}
|
||||
|
||||
void FullyConnectedLayer::ComputeOutput(rtc::ArrayView<const float> input) {
|
||||
// TODO(bugs.chromium.org/9076): Optimize using SSE/AVX fused multiply-add
|
||||
// operations.
|
||||
switch (optimization_) {
|
||||
#if defined(WEBRTC_ARCH_X86_FAMILY)
|
||||
case Optimization::kSse2:
|
||||
// TODO(bugs.chromium.org/10480): Handle Optimization::kSse2.
|
||||
ComputeOutput_NONE(input);
|
||||
break;
|
||||
#endif
|
||||
#if defined(WEBRTC_HAS_NEON)
|
||||
case Optimization::kNeon:
|
||||
// TODO(bugs.chromium.org/10480): Handle Optimization::kNeon.
|
||||
ComputeOutput_NONE(input);
|
||||
break;
|
||||
#endif
|
||||
default:
|
||||
ComputeOutput_NONE(input);
|
||||
}
|
||||
}
|
||||
|
||||
void FullyConnectedLayer::ComputeOutput_NONE(
|
||||
rtc::ArrayView<const float> input) {
|
||||
for (size_t o = 0; o < output_size_; ++o) {
|
||||
output_[o] = bias_[o];
|
||||
// TODO(bugs.chromium.org/9076): Benchmark how different layouts for
|
||||
@ -109,12 +138,14 @@ GatedRecurrentLayer::GatedRecurrentLayer(
|
||||
const size_t output_size,
|
||||
const rtc::ArrayView<const int8_t> bias,
|
||||
const rtc::ArrayView<const int8_t> weights,
|
||||
const rtc::ArrayView<const int8_t> recurrent_weights)
|
||||
const rtc::ArrayView<const int8_t> recurrent_weights,
|
||||
Optimization optimization)
|
||||
: input_size_(input_size),
|
||||
output_size_(output_size),
|
||||
bias_(GetScaledParams(bias)),
|
||||
weights_(GetScaledParams(weights)),
|
||||
recurrent_weights_(GetScaledParams(recurrent_weights)) {
|
||||
recurrent_weights_(GetScaledParams(recurrent_weights)),
|
||||
optimization_(optimization) {
|
||||
RTC_DCHECK_LE(output_size_, kRecurrentLayersMaxUnits)
|
||||
<< "Static over-allocation of recurrent layers state vectors is not "
|
||||
<< "sufficient.";
|
||||
@ -139,6 +170,26 @@ void GatedRecurrentLayer::Reset() {
|
||||
}
|
||||
|
||||
void GatedRecurrentLayer::ComputeOutput(rtc::ArrayView<const float> input) {
|
||||
switch (optimization_) {
|
||||
#if defined(WEBRTC_ARCH_X86_FAMILY)
|
||||
case Optimization::kSse2:
|
||||
// TODO(bugs.chromium.org/10480): Handle Optimization::kSse2.
|
||||
ComputeOutput_NONE(input);
|
||||
break;
|
||||
#endif
|
||||
#if defined(WEBRTC_HAS_NEON)
|
||||
case Optimization::kNeon:
|
||||
// TODO(bugs.chromium.org/10480): Handle Optimization::kNeon.
|
||||
ComputeOutput_NONE(input);
|
||||
break;
|
||||
#endif
|
||||
default:
|
||||
ComputeOutput_NONE(input);
|
||||
}
|
||||
}
|
||||
|
||||
void GatedRecurrentLayer::ComputeOutput_NONE(
|
||||
rtc::ArrayView<const float> input) {
|
||||
// TODO(bugs.chromium.org/9076): Optimize using SSE/AVX fused multiply-add
|
||||
// operations.
|
||||
// Stride and offset used to read parameter arrays.
|
||||
@ -203,17 +254,20 @@ RnnBasedVad::RnnBasedVad()
|
||||
kInputLayerOutputSize,
|
||||
kInputDenseBias,
|
||||
kInputDenseWeights,
|
||||
TansigApproximated),
|
||||
TansigApproximated,
|
||||
DetectOptimization()),
|
||||
hidden_layer_(kInputLayerOutputSize,
|
||||
kHiddenLayerOutputSize,
|
||||
kHiddenGruBias,
|
||||
kHiddenGruWeights,
|
||||
kHiddenGruRecurrentWeights),
|
||||
kHiddenGruRecurrentWeights,
|
||||
DetectOptimization()),
|
||||
output_layer_(kHiddenLayerOutputSize,
|
||||
kOutputLayerOutputSize,
|
||||
kOutputDenseBias,
|
||||
kOutputDenseWeights,
|
||||
SigmoidApproximated) {
|
||||
SigmoidApproximated,
|
||||
DetectOptimization()) {
|
||||
// Input-output chaining size checks.
|
||||
RTC_DCHECK_EQ(input_layer_.output_size(), hidden_layer_.input_size())
|
||||
<< "The input and the hidden layers sizes do not match.";
|
||||
|
||||
@ -38,11 +38,12 @@ constexpr size_t kRecurrentLayersMaxUnits = 24;
|
||||
// Fully-connected layer.
|
||||
class FullyConnectedLayer {
|
||||
public:
|
||||
FullyConnectedLayer(const size_t input_size,
|
||||
const size_t output_size,
|
||||
const rtc::ArrayView<const int8_t> bias,
|
||||
const rtc::ArrayView<const int8_t> weights,
|
||||
float (*const activation_function)(float));
|
||||
FullyConnectedLayer(size_t input_size,
|
||||
size_t output_size,
|
||||
rtc::ArrayView<const int8_t> bias,
|
||||
rtc::ArrayView<const int8_t> weights,
|
||||
float (*const activation_function)(float),
|
||||
Optimization optimization);
|
||||
FullyConnectedLayer(const FullyConnectedLayer&) = delete;
|
||||
FullyConnectedLayer& operator=(const FullyConnectedLayer&) = delete;
|
||||
~FullyConnectedLayer();
|
||||
@ -53,11 +54,15 @@ class FullyConnectedLayer {
|
||||
void ComputeOutput(rtc::ArrayView<const float> input);
|
||||
|
||||
private:
|
||||
// No SIMD optimizations.
|
||||
void ComputeOutput_NONE(rtc::ArrayView<const float> input);
|
||||
|
||||
const size_t input_size_;
|
||||
const size_t output_size_;
|
||||
const std::vector<float> bias_;
|
||||
const std::vector<float> weights_;
|
||||
float (*const activation_function_)(float);
|
||||
const Optimization optimization_;
|
||||
// The output vector of a recurrent layer has length equal to |output_size_|.
|
||||
// However, for efficiency, over-allocation is used.
|
||||
std::array<float, kFullyConnectedLayersMaxUnits> output_;
|
||||
@ -67,11 +72,12 @@ class FullyConnectedLayer {
|
||||
// activation functions for the update/reset and output gates respectively.
|
||||
class GatedRecurrentLayer {
|
||||
public:
|
||||
GatedRecurrentLayer(const size_t input_size,
|
||||
const size_t output_size,
|
||||
const rtc::ArrayView<const int8_t> bias,
|
||||
const rtc::ArrayView<const int8_t> weights,
|
||||
const rtc::ArrayView<const int8_t> recurrent_weights);
|
||||
GatedRecurrentLayer(size_t input_size,
|
||||
size_t output_size,
|
||||
rtc::ArrayView<const int8_t> bias,
|
||||
rtc::ArrayView<const int8_t> weights,
|
||||
rtc::ArrayView<const int8_t> recurrent_weights,
|
||||
Optimization optimization);
|
||||
GatedRecurrentLayer(const GatedRecurrentLayer&) = delete;
|
||||
GatedRecurrentLayer& operator=(const GatedRecurrentLayer&) = delete;
|
||||
~GatedRecurrentLayer();
|
||||
@ -83,6 +89,9 @@ class GatedRecurrentLayer {
|
||||
void ComputeOutput(rtc::ArrayView<const float> input);
|
||||
|
||||
private:
|
||||
// No SIMD optimizations.
|
||||
void ComputeOutput_NONE(rtc::ArrayView<const float> input);
|
||||
|
||||
const size_t input_size_;
|
||||
const size_t output_size_;
|
||||
const std::vector<float> bias_;
|
||||
@ -91,6 +100,7 @@ class GatedRecurrentLayer {
|
||||
// The state vector of a recurrent layer has length equal to |output_size_|.
|
||||
// However, to avoid dynamic allocation, over-allocation is used.
|
||||
std::array<float, kRecurrentLayersMaxUnits> state_;
|
||||
const Optimization optimization_;
|
||||
};
|
||||
|
||||
// Recurrent network based VAD.
|
||||
|
||||
@ -14,6 +14,7 @@
|
||||
|
||||
#include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
|
||||
#include "rtc_base/checks.h"
|
||||
#include "rtc_base/logging.h"
|
||||
#include "test/gtest.h"
|
||||
#include "third_party/rnnoise/src/rnn_activations.h"
|
||||
#include "third_party/rnnoise/src/rnn_vad_weights.h"
|
||||
@ -60,86 +61,104 @@ void TestGatedRecurrentLayer(
|
||||
}
|
||||
}
|
||||
|
||||
// Fully connected layer test data.
|
||||
constexpr size_t kFullyConnectedInputSize = 24;
|
||||
constexpr size_t kFullyConnectedOutputSize = 1;
|
||||
constexpr std::array<int8_t, 1> kFullyConnectedBias = {-50};
|
||||
constexpr std::array<int8_t, 24> kFullyConnectedWeights = {
|
||||
127, 127, 127, 127, 127, 20, 127, -126, -126, -54, 14, 125,
|
||||
-126, -126, 127, -125, -126, 127, -127, -127, -57, -30, 127, 80};
|
||||
constexpr std::array<float, 24 * 3> kFullyConnectedInputVectors = {
|
||||
// Input 1.
|
||||
0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.215833917f, 0.290601075f, 0.238759011f,
|
||||
0.244751841f, 0.f, 0.0461241305f, 0.106401242f, 0.223070428f, 0.630603909f,
|
||||
0.690453172f, 0.f, 0.387645692f, 0.166913897f, 0.f, 0.0327451192f, 0.f,
|
||||
0.136149868f, 0.446351469f,
|
||||
// Input 2.
|
||||
0.592162728f, 0.529089332f, 1.18205106f, 1.21736848f, 0.f, 0.470851123f,
|
||||
0.130675942f, 0.320903003f, 0.305496395f, 0.0571633279f, 1.57001138f,
|
||||
0.0182026215f, 0.0977443159f, 0.347477973f, 0.493206412f, 0.9688586f,
|
||||
0.0320267938f, 0.244722098f, 0.312745273f, 0.f, 0.00650715502f,
|
||||
0.312553257f, 1.62619662f, 0.782880902f,
|
||||
// Input 3.
|
||||
0.395022154f, 0.333681047f, 0.76302278f, 0.965480626f, 0.f, 0.941198349f,
|
||||
0.0892967582f, 0.745046318f, 0.635769248f, 0.238564298f, 0.970656633f,
|
||||
0.014159563f, 0.094203949f, 0.446816623f, 0.640755892f, 1.20532358f,
|
||||
0.0254284926f, 0.283327013f, 0.726210058f, 0.0550272502f, 0.000344108557f,
|
||||
0.369803518f, 1.56680179f, 0.997883797f};
|
||||
constexpr std::array<float, 3> kFullyConnectedExpectedOutputs = {
|
||||
0.436567038f, 0.874741316f, 0.672785878f};
|
||||
|
||||
// Gated recurrent units layer test data.
|
||||
constexpr size_t kGruInputSize = 5;
|
||||
constexpr size_t kGruOutputSize = 4;
|
||||
constexpr std::array<int8_t, 12> kGruBias = {96, -99, -81, -114, 49, 119,
|
||||
-118, 68, -76, 91, 121, 125};
|
||||
constexpr std::array<int8_t, 60> kGruWeights = {
|
||||
124, 9, 1, 116, -66, -21, -118, -110, 104, 75, -23, -51,
|
||||
-72, -111, 47, 93, 77, -98, 41, -8, 40, -23, -43, -107,
|
||||
9, -73, 30, -32, -2, 64, -26, 91, -48, -24, -28, -104,
|
||||
74, -46, 116, 15, 32, 52, -126, -38, -121, 12, -16, 110,
|
||||
-95, 66, -103, -35, -38, 3, -126, -61, 28, 98, -117, -43};
|
||||
constexpr std::array<int8_t, 60> kGruRecurrentWeights = {
|
||||
-3, 87, 50, 51, -22, 27, -39, 62, 31, -83, -52, -48,
|
||||
-6, 83, -19, 104, 105, 48, 23, 68, 23, 40, 7, -120,
|
||||
64, -62, 117, 85, -51, -43, 54, -105, 120, 56, -128, -107,
|
||||
39, 50, -17, -47, -117, 14, 108, 12, -7, -72, 103, -87,
|
||||
-66, 82, 84, 100, -98, 102, -49, 44, 122, 106, -20, -69};
|
||||
constexpr std::array<float, 20> kGruInputSequence = {
|
||||
0.89395463f, 0.93224651f, 0.55788344f, 0.32341808f, 0.93355054f,
|
||||
0.13475326f, 0.97370994f, 0.14253306f, 0.93710381f, 0.76093364f,
|
||||
0.65780413f, 0.41657975f, 0.49403164f, 0.46843281f, 0.75138855f,
|
||||
0.24517593f, 0.47657707f, 0.57064998f, 0.435184f, 0.19319285f};
|
||||
constexpr std::array<float, 16> kGruExpectedOutputSequence = {
|
||||
0.0239123f, 0.5773077f, 0.f, 0.f,
|
||||
0.01282811f, 0.64330572f, 0.f, 0.04863098f,
|
||||
0.00781069f, 0.75267816f, 0.f, 0.02579715f,
|
||||
0.00471378f, 0.59162533f, 0.11087593f, 0.01334511f};
|
||||
|
||||
} // namespace
|
||||
|
||||
class OptimizationTest : public ::testing::Test,
|
||||
public ::testing::WithParamInterface<Optimization> {};
|
||||
|
||||
// Checks that the output of a fully connected layer is within tolerance given
|
||||
// test input data.
|
||||
TEST(RnnVadTest, CheckFullyConnectedLayerOutput) {
|
||||
const std::array<int8_t, 1> bias = {-50};
|
||||
const std::array<int8_t, 24> weights = {
|
||||
127, 127, 127, 127, 127, 20, 127, -126, -126, -54, 14, 125,
|
||||
-126, -126, 127, -125, -126, 127, -127, -127, -57, -30, 127, 80};
|
||||
FullyConnectedLayer fc(24, 1, bias, weights, SigmoidApproximated);
|
||||
TEST_P(OptimizationTest, CheckFullyConnectedLayerOutput) {
|
||||
const Optimization optimization = GetParam();
|
||||
RTC_LOG(LS_VERBOSE) << optimization;
|
||||
FullyConnectedLayer fc(kFullyConnectedInputSize, kFullyConnectedOutputSize,
|
||||
kFullyConnectedBias, kFullyConnectedWeights,
|
||||
SigmoidApproximated, optimization);
|
||||
// Test on different inputs.
|
||||
{
|
||||
const std::array<float, 24> input_vector = {
|
||||
0.f, 0.f, 0.f, 0.f, 0.f,
|
||||
0.f, 0.215833917f, 0.290601075f, 0.238759011f, 0.244751841f,
|
||||
0.f, 0.0461241305f, 0.106401242f, 0.223070428f, 0.630603909f,
|
||||
0.690453172f, 0.f, 0.387645692f, 0.166913897f, 0.f,
|
||||
0.0327451192f, 0.f, 0.136149868f, 0.446351469f};
|
||||
TestFullyConnectedLayer(&fc, input_vector, 0.436567038f);
|
||||
}
|
||||
{
|
||||
const std::array<float, 24> input_vector = {
|
||||
0.592162728f, 0.529089332f, 1.18205106f,
|
||||
1.21736848f, 0.f, 0.470851123f,
|
||||
0.130675942f, 0.320903003f, 0.305496395f,
|
||||
0.0571633279f, 1.57001138f, 0.0182026215f,
|
||||
0.0977443159f, 0.347477973f, 0.493206412f,
|
||||
0.9688586f, 0.0320267938f, 0.244722098f,
|
||||
0.312745273f, 0.f, 0.00650715502f,
|
||||
0.312553257f, 1.62619662f, 0.782880902f};
|
||||
TestFullyConnectedLayer(&fc, input_vector, 0.874741316f);
|
||||
}
|
||||
{
|
||||
const std::array<float, 24> input_vector = {
|
||||
0.395022154f, 0.333681047f, 0.76302278f,
|
||||
0.965480626f, 0.f, 0.941198349f,
|
||||
0.0892967582f, 0.745046318f, 0.635769248f,
|
||||
0.238564298f, 0.970656633f, 0.014159563f,
|
||||
0.094203949f, 0.446816623f, 0.640755892f,
|
||||
1.20532358f, 0.0254284926f, 0.283327013f,
|
||||
0.726210058f, 0.0550272502f, 0.000344108557f,
|
||||
0.369803518f, 1.56680179f, 0.997883797f};
|
||||
TestFullyConnectedLayer(&fc, input_vector, 0.672785878f);
|
||||
static_assert(
|
||||
kFullyConnectedInputVectors.size() % kFullyConnectedInputSize == 0, "");
|
||||
constexpr size_t kNumInputVectors =
|
||||
kFullyConnectedInputVectors.size() / kFullyConnectedInputSize;
|
||||
static_assert(kFullyConnectedExpectedOutputs.size() == kNumInputVectors, "");
|
||||
for (size_t i = 0; i < kNumInputVectors; ++i) {
|
||||
rtc::ArrayView<const float> input(
|
||||
kFullyConnectedInputVectors.data() + kFullyConnectedInputSize * i,
|
||||
kFullyConnectedInputSize);
|
||||
TestFullyConnectedLayer(&fc, input, kFullyConnectedExpectedOutputs[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// Checks that the output of a GRU layer is within tolerance given test input
|
||||
// data.
|
||||
TEST(RnnVadTest, CheckGatedRecurrentLayer) {
|
||||
const std::array<int8_t, 12> bias = {96, -99, -81, -114, 49, 119,
|
||||
-118, 68, -76, 91, 121, 125};
|
||||
const std::array<int8_t, 60> weights = {
|
||||
124, 9, 1, 116, -66, -21, -118, -110, 104, 75, -23, -51,
|
||||
-72, -111, 47, 93, 77, -98, 41, -8, 40, -23, -43, -107,
|
||||
9, -73, 30, -32, -2, 64, -26, 91, -48, -24, -28, -104,
|
||||
74, -46, 116, 15, 32, 52, -126, -38, -121, 12, -16, 110,
|
||||
-95, 66, -103, -35, -38, 3, -126, -61, 28, 98, -117, -43};
|
||||
const std::array<int8_t, 60> recurrent_weights = {
|
||||
-3, 87, 50, 51, -22, 27, -39, 62, 31, -83, -52, -48,
|
||||
-6, 83, -19, 104, 105, 48, 23, 68, 23, 40, 7, -120,
|
||||
64, -62, 117, 85, -51, -43, 54, -105, 120, 56, -128, -107,
|
||||
39, 50, -17, -47, -117, 14, 108, 12, -7, -72, 103, -87,
|
||||
-66, 82, 84, 100, -98, 102, -49, 44, 122, 106, -20, -69};
|
||||
GatedRecurrentLayer gru(5, 4, bias, weights, recurrent_weights);
|
||||
// Test on different inputs.
|
||||
{
|
||||
const std::array<float, 20> input_sequence = {
|
||||
0.89395463f, 0.93224651f, 0.55788344f, 0.32341808f, 0.93355054f,
|
||||
0.13475326f, 0.97370994f, 0.14253306f, 0.93710381f, 0.76093364f,
|
||||
0.65780413f, 0.41657975f, 0.49403164f, 0.46843281f, 0.75138855f,
|
||||
0.24517593f, 0.47657707f, 0.57064998f, 0.435184f, 0.19319285f};
|
||||
const std::array<float, 16> expected_output_sequence = {
|
||||
0.0239123f, 0.5773077f, 0.f, 0.f,
|
||||
0.01282811f, 0.64330572f, 0.f, 0.04863098f,
|
||||
0.00781069f, 0.75267816f, 0.f, 0.02579715f,
|
||||
0.00471378f, 0.59162533f, 0.11087593f, 0.01334511f};
|
||||
TestGatedRecurrentLayer(&gru, input_sequence, expected_output_sequence);
|
||||
}
|
||||
TEST_P(OptimizationTest, CheckGatedRecurrentLayer) {
|
||||
const Optimization optimization = GetParam();
|
||||
RTC_LOG(LS_VERBOSE) << optimization;
|
||||
GatedRecurrentLayer gru(kGruInputSize, kGruOutputSize, kGruBias, kGruWeights,
|
||||
kGruRecurrentWeights, optimization);
|
||||
TestGatedRecurrentLayer(&gru, kGruInputSequence, kGruExpectedOutputSequence);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(RnnVadTest,
|
||||
OptimizationTest,
|
||||
::testing::Values(Optimization::kNone,
|
||||
DetectOptimization()));
|
||||
|
||||
} // namespace test
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user