From 43afc09fc5b95bd47feb94988658f04c246388a9 Mon Sep 17 00:00:00 2001 From: Alessio Bazzica Date: Wed, 6 Nov 2019 14:42:32 +0000 Subject: [PATCH] Reland "RNN VAD: prepare for SIMD optimization" This reverts commit 5ab21f8853892205594ae8559a00b431f30a8a06. Reason for revert: downstream fixed Original change's description: > Revert "RNN VAD: prepare for SIMD optimization" > > This reverts commit 7350a902374c796dec8ce583cfaf4b9697f3a525. > > Reason for revert: possibly breaking downstream projects > > Original change's description: > > RNN VAD: prepare for SIMD optimization > > > > This CL adds the boilerplate for SIMD optimization of FC and GRU layers > > in rnn.cc. The same scheme of AEC3 has been used. Unit tests for the > > optimized architectures have been added (the same unoptimized > > implementation will run). > > > > Minor changes: > > - unnecessary const removed in rnn.h > > - FC and GRU test data in the anon namespace as constexpr > > > > Bug: webrtc:10480 > > Change-Id: Ifae4e970326e7e7c603d49aeaf61194b5efdabd3 > > Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/141419 > > Commit-Queue: Alessio Bazzica > > Reviewed-by: Gustaf Ullberg > > Cr-Commit-Position: refs/heads/master@{#29696} > > TBR=gustaf@webrtc.org,alessiob@webrtc.org,fhernqvist@webrtc.org > > Change-Id: I9ae82f4bd2d30797646fabfb5ad16bea378208b8 > No-Presubmit: true > No-Tree-Checks: true > No-Try: true > Bug: webrtc:10480 > Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/158893 > Reviewed-by: Alessio Bazzica > Commit-Queue: Alessio Bazzica > Cr-Commit-Position: refs/heads/master@{#29699} TBR=gustaf@webrtc.org,alessiob@webrtc.org,fhernqvist@webrtc.org Change-Id: I33edd144f7ac795bf472aae9fa5a79c326000443 No-Presubmit: true No-Tree-Checks: true No-Try: true Bug: webrtc:10480 Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/159001 Reviewed-by: Alessio Bazzica Commit-Queue: Alessio Bazzica Cr-Commit-Position: refs/heads/master@{#29708} --- .../audio_processing/agc2/rnn_vad/BUILD.gn | 10 ++ .../audio_processing/agc2/rnn_vad/common.cc | 34 ++++ .../audio_processing/agc2/rnn_vad/common.h | 7 + modules/audio_processing/agc2/rnn_vad/rnn.cc | 72 +++++++- modules/audio_processing/agc2/rnn_vad/rnn.h | 30 ++-- .../agc2/rnn_vad/rnn_unittest.cc | 155 ++++++++++-------- 6 files changed, 221 insertions(+), 87 deletions(-) create mode 100644 modules/audio_processing/agc2/rnn_vad/common.cc diff --git a/modules/audio_processing/agc2/rnn_vad/BUILD.gn b/modules/audio_processing/agc2/rnn_vad/BUILD.gn index 71e02fb575..852abd88bf 100644 --- a/modules/audio_processing/agc2/rnn_vad/BUILD.gn +++ b/modules/audio_processing/agc2/rnn_vad/BUILD.gn @@ -13,6 +13,7 @@ rtc_library("rnn_vad") { sources = [ "auto_correlation.cc", "auto_correlation.h", + "common.cc", "common.h", "features_extraction.cc", "features_extraction.h", @@ -33,11 +34,20 @@ rtc_library("rnn_vad") { "spectral_features_internal.h", "symmetric_matrix_buffer.h", ] + + defines = [] + if (rtc_build_with_neon && current_cpu != "arm64") { + suppressed_configs += [ "//build/config/compiler:compiler_arm_fpu" ] + cflags = [ "-mfpu=neon" ] + } + deps = [ "..:biquad_filter", "../../../../api:array_view", "../../../../rtc_base:checks", "../../../../rtc_base:rtc_base_approved", + "../../../../rtc_base/system:arch", + "../../../../system_wrappers:cpu_features_api", "../../utility:pffft_wrapper", "//third_party/rnnoise:rnn_vad", ] diff --git a/modules/audio_processing/agc2/rnn_vad/common.cc b/modules/audio_processing/agc2/rnn_vad/common.cc new file mode 100644 index 0000000000..744c87fea2 --- /dev/null +++ b/modules/audio_processing/agc2/rnn_vad/common.cc @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2019 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/agc2/rnn_vad/common.h" + +#include "rtc_base/system/arch.h" +#include "system_wrappers/include/cpu_features_wrapper.h" + +namespace webrtc { +namespace rnn_vad { + +Optimization DetectOptimization() { +#if defined(WEBRTC_ARCH_X86_FAMILY) + if (WebRtc_GetCPUInfo(kSSE2) != 0) { + return Optimization::kSse2; + } +#endif + +#if defined(WEBRTC_HAS_NEON) + return Optimization::kNeon; +#endif + + return Optimization::kNone; +} + +} // namespace rnn_vad +} // namespace webrtc diff --git a/modules/audio_processing/agc2/rnn_vad/common.h b/modules/audio_processing/agc2/rnn_vad/common.h index 6b434d2171..c2e8df6905 100644 --- a/modules/audio_processing/agc2/rnn_vad/common.h +++ b/modules/audio_processing/agc2/rnn_vad/common.h @@ -11,6 +11,8 @@ #ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_COMMON_H_ #define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_COMMON_H_ +#include + namespace webrtc { namespace rnn_vad { @@ -63,6 +65,11 @@ static_assert(kCepstralCoeffsHistorySize > 2, constexpr size_t kFeatureVectorSize = 42; +enum class Optimization { kNone, kSse2, kNeon }; + +// Detects what kind of optimizations to use for the code. +Optimization DetectOptimization(); + } // namespace rnn_vad } // namespace webrtc diff --git a/modules/audio_processing/agc2/rnn_vad/rnn.cc b/modules/audio_processing/agc2/rnn_vad/rnn.cc index 94cc254045..e6ef2f3a41 100644 --- a/modules/audio_processing/agc2/rnn_vad/rnn.cc +++ b/modules/audio_processing/agc2/rnn_vad/rnn.cc @@ -10,6 +10,15 @@ #include "modules/audio_processing/agc2/rnn_vad/rnn.h" +// Defines WEBRTC_ARCH_X86_FAMILY, used below. +#include "rtc_base/system/arch.h" + +#if defined(WEBRTC_HAS_NEON) +#include +#endif +#if defined(WEBRTC_ARCH_X86_FAMILY) +#include +#endif #include #include #include @@ -69,12 +78,14 @@ FullyConnectedLayer::FullyConnectedLayer( const size_t output_size, const rtc::ArrayView bias, const rtc::ArrayView weights, - float (*const activation_function)(float)) + float (*const activation_function)(float), + Optimization optimization) : input_size_(input_size), output_size_(output_size), bias_(GetScaledParams(bias)), weights_(GetScaledParams(weights)), - activation_function_(activation_function) { + activation_function_(activation_function), + optimization_(optimization) { RTC_DCHECK_LE(output_size_, kFullyConnectedLayersMaxUnits) << "Static over-allocation of fully-connected layers output vectors is " "not sufficient."; @@ -91,8 +102,26 @@ rtc::ArrayView FullyConnectedLayer::GetOutput() const { } void FullyConnectedLayer::ComputeOutput(rtc::ArrayView input) { - // TODO(bugs.chromium.org/9076): Optimize using SSE/AVX fused multiply-add - // operations. + switch (optimization_) { +#if defined(WEBRTC_ARCH_X86_FAMILY) + case Optimization::kSse2: + // TODO(bugs.chromium.org/10480): Handle Optimization::kSse2. + ComputeOutput_NONE(input); + break; +#endif +#if defined(WEBRTC_HAS_NEON) + case Optimization::kNeon: + // TODO(bugs.chromium.org/10480): Handle Optimization::kNeon. + ComputeOutput_NONE(input); + break; +#endif + default: + ComputeOutput_NONE(input); + } +} + +void FullyConnectedLayer::ComputeOutput_NONE( + rtc::ArrayView input) { for (size_t o = 0; o < output_size_; ++o) { output_[o] = bias_[o]; // TODO(bugs.chromium.org/9076): Benchmark how different layouts for @@ -109,12 +138,14 @@ GatedRecurrentLayer::GatedRecurrentLayer( const size_t output_size, const rtc::ArrayView bias, const rtc::ArrayView weights, - const rtc::ArrayView recurrent_weights) + const rtc::ArrayView recurrent_weights, + Optimization optimization) : input_size_(input_size), output_size_(output_size), bias_(GetScaledParams(bias)), weights_(GetScaledParams(weights)), - recurrent_weights_(GetScaledParams(recurrent_weights)) { + recurrent_weights_(GetScaledParams(recurrent_weights)), + optimization_(optimization) { RTC_DCHECK_LE(output_size_, kRecurrentLayersMaxUnits) << "Static over-allocation of recurrent layers state vectors is not " << "sufficient."; @@ -139,6 +170,26 @@ void GatedRecurrentLayer::Reset() { } void GatedRecurrentLayer::ComputeOutput(rtc::ArrayView input) { + switch (optimization_) { +#if defined(WEBRTC_ARCH_X86_FAMILY) + case Optimization::kSse2: + // TODO(bugs.chromium.org/10480): Handle Optimization::kSse2. + ComputeOutput_NONE(input); + break; +#endif +#if defined(WEBRTC_HAS_NEON) + case Optimization::kNeon: + // TODO(bugs.chromium.org/10480): Handle Optimization::kNeon. + ComputeOutput_NONE(input); + break; +#endif + default: + ComputeOutput_NONE(input); + } +} + +void GatedRecurrentLayer::ComputeOutput_NONE( + rtc::ArrayView input) { // TODO(bugs.chromium.org/9076): Optimize using SSE/AVX fused multiply-add // operations. // Stride and offset used to read parameter arrays. @@ -203,17 +254,20 @@ RnnBasedVad::RnnBasedVad() kInputLayerOutputSize, kInputDenseBias, kInputDenseWeights, - TansigApproximated), + TansigApproximated, + DetectOptimization()), hidden_layer_(kInputLayerOutputSize, kHiddenLayerOutputSize, kHiddenGruBias, kHiddenGruWeights, - kHiddenGruRecurrentWeights), + kHiddenGruRecurrentWeights, + DetectOptimization()), output_layer_(kHiddenLayerOutputSize, kOutputLayerOutputSize, kOutputDenseBias, kOutputDenseWeights, - SigmoidApproximated) { + SigmoidApproximated, + DetectOptimization()) { // Input-output chaining size checks. RTC_DCHECK_EQ(input_layer_.output_size(), hidden_layer_.input_size()) << "The input and the hidden layers sizes do not match."; diff --git a/modules/audio_processing/agc2/rnn_vad/rnn.h b/modules/audio_processing/agc2/rnn_vad/rnn.h index c38ff01b3e..f53a09379d 100644 --- a/modules/audio_processing/agc2/rnn_vad/rnn.h +++ b/modules/audio_processing/agc2/rnn_vad/rnn.h @@ -38,11 +38,12 @@ constexpr size_t kRecurrentLayersMaxUnits = 24; // Fully-connected layer. class FullyConnectedLayer { public: - FullyConnectedLayer(const size_t input_size, - const size_t output_size, - const rtc::ArrayView bias, - const rtc::ArrayView weights, - float (*const activation_function)(float)); + FullyConnectedLayer(size_t input_size, + size_t output_size, + rtc::ArrayView bias, + rtc::ArrayView weights, + float (*const activation_function)(float), + Optimization optimization); FullyConnectedLayer(const FullyConnectedLayer&) = delete; FullyConnectedLayer& operator=(const FullyConnectedLayer&) = delete; ~FullyConnectedLayer(); @@ -53,11 +54,15 @@ class FullyConnectedLayer { void ComputeOutput(rtc::ArrayView input); private: + // No SIMD optimizations. + void ComputeOutput_NONE(rtc::ArrayView input); + const size_t input_size_; const size_t output_size_; const std::vector bias_; const std::vector weights_; float (*const activation_function_)(float); + const Optimization optimization_; // The output vector of a recurrent layer has length equal to |output_size_|. // However, for efficiency, over-allocation is used. std::array output_; @@ -67,11 +72,12 @@ class FullyConnectedLayer { // activation functions for the update/reset and output gates respectively. class GatedRecurrentLayer { public: - GatedRecurrentLayer(const size_t input_size, - const size_t output_size, - const rtc::ArrayView bias, - const rtc::ArrayView weights, - const rtc::ArrayView recurrent_weights); + GatedRecurrentLayer(size_t input_size, + size_t output_size, + rtc::ArrayView bias, + rtc::ArrayView weights, + rtc::ArrayView recurrent_weights, + Optimization optimization); GatedRecurrentLayer(const GatedRecurrentLayer&) = delete; GatedRecurrentLayer& operator=(const GatedRecurrentLayer&) = delete; ~GatedRecurrentLayer(); @@ -83,6 +89,9 @@ class GatedRecurrentLayer { void ComputeOutput(rtc::ArrayView input); private: + // No SIMD optimizations. + void ComputeOutput_NONE(rtc::ArrayView input); + const size_t input_size_; const size_t output_size_; const std::vector bias_; @@ -91,6 +100,7 @@ class GatedRecurrentLayer { // The state vector of a recurrent layer has length equal to |output_size_|. // However, to avoid dynamic allocation, over-allocation is used. std::array state_; + const Optimization optimization_; }; // Recurrent network based VAD. diff --git a/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc b/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc index 61e6f2670e..97ede1811a 100644 --- a/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc @@ -14,6 +14,7 @@ #include "modules/audio_processing/agc2/rnn_vad/test_utils.h" #include "rtc_base/checks.h" +#include "rtc_base/logging.h" #include "test/gtest.h" #include "third_party/rnnoise/src/rnn_activations.h" #include "third_party/rnnoise/src/rnn_vad_weights.h" @@ -60,86 +61,104 @@ void TestGatedRecurrentLayer( } } +// Fully connected layer test data. +constexpr size_t kFullyConnectedInputSize = 24; +constexpr size_t kFullyConnectedOutputSize = 1; +constexpr std::array kFullyConnectedBias = {-50}; +constexpr std::array kFullyConnectedWeights = { + 127, 127, 127, 127, 127, 20, 127, -126, -126, -54, 14, 125, + -126, -126, 127, -125, -126, 127, -127, -127, -57, -30, 127, 80}; +constexpr std::array kFullyConnectedInputVectors = { + // Input 1. + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.215833917f, 0.290601075f, 0.238759011f, + 0.244751841f, 0.f, 0.0461241305f, 0.106401242f, 0.223070428f, 0.630603909f, + 0.690453172f, 0.f, 0.387645692f, 0.166913897f, 0.f, 0.0327451192f, 0.f, + 0.136149868f, 0.446351469f, + // Input 2. + 0.592162728f, 0.529089332f, 1.18205106f, 1.21736848f, 0.f, 0.470851123f, + 0.130675942f, 0.320903003f, 0.305496395f, 0.0571633279f, 1.57001138f, + 0.0182026215f, 0.0977443159f, 0.347477973f, 0.493206412f, 0.9688586f, + 0.0320267938f, 0.244722098f, 0.312745273f, 0.f, 0.00650715502f, + 0.312553257f, 1.62619662f, 0.782880902f, + // Input 3. + 0.395022154f, 0.333681047f, 0.76302278f, 0.965480626f, 0.f, 0.941198349f, + 0.0892967582f, 0.745046318f, 0.635769248f, 0.238564298f, 0.970656633f, + 0.014159563f, 0.094203949f, 0.446816623f, 0.640755892f, 1.20532358f, + 0.0254284926f, 0.283327013f, 0.726210058f, 0.0550272502f, 0.000344108557f, + 0.369803518f, 1.56680179f, 0.997883797f}; +constexpr std::array kFullyConnectedExpectedOutputs = { + 0.436567038f, 0.874741316f, 0.672785878f}; + +// Gated recurrent units layer test data. +constexpr size_t kGruInputSize = 5; +constexpr size_t kGruOutputSize = 4; +constexpr std::array kGruBias = {96, -99, -81, -114, 49, 119, + -118, 68, -76, 91, 121, 125}; +constexpr std::array kGruWeights = { + 124, 9, 1, 116, -66, -21, -118, -110, 104, 75, -23, -51, + -72, -111, 47, 93, 77, -98, 41, -8, 40, -23, -43, -107, + 9, -73, 30, -32, -2, 64, -26, 91, -48, -24, -28, -104, + 74, -46, 116, 15, 32, 52, -126, -38, -121, 12, -16, 110, + -95, 66, -103, -35, -38, 3, -126, -61, 28, 98, -117, -43}; +constexpr std::array kGruRecurrentWeights = { + -3, 87, 50, 51, -22, 27, -39, 62, 31, -83, -52, -48, + -6, 83, -19, 104, 105, 48, 23, 68, 23, 40, 7, -120, + 64, -62, 117, 85, -51, -43, 54, -105, 120, 56, -128, -107, + 39, 50, -17, -47, -117, 14, 108, 12, -7, -72, 103, -87, + -66, 82, 84, 100, -98, 102, -49, 44, 122, 106, -20, -69}; +constexpr std::array kGruInputSequence = { + 0.89395463f, 0.93224651f, 0.55788344f, 0.32341808f, 0.93355054f, + 0.13475326f, 0.97370994f, 0.14253306f, 0.93710381f, 0.76093364f, + 0.65780413f, 0.41657975f, 0.49403164f, 0.46843281f, 0.75138855f, + 0.24517593f, 0.47657707f, 0.57064998f, 0.435184f, 0.19319285f}; +constexpr std::array kGruExpectedOutputSequence = { + 0.0239123f, 0.5773077f, 0.f, 0.f, + 0.01282811f, 0.64330572f, 0.f, 0.04863098f, + 0.00781069f, 0.75267816f, 0.f, 0.02579715f, + 0.00471378f, 0.59162533f, 0.11087593f, 0.01334511f}; + } // namespace +class OptimizationTest : public ::testing::Test, + public ::testing::WithParamInterface {}; + // Checks that the output of a fully connected layer is within tolerance given // test input data. -TEST(RnnVadTest, CheckFullyConnectedLayerOutput) { - const std::array bias = {-50}; - const std::array weights = { - 127, 127, 127, 127, 127, 20, 127, -126, -126, -54, 14, 125, - -126, -126, 127, -125, -126, 127, -127, -127, -57, -30, 127, 80}; - FullyConnectedLayer fc(24, 1, bias, weights, SigmoidApproximated); +TEST_P(OptimizationTest, CheckFullyConnectedLayerOutput) { + const Optimization optimization = GetParam(); + RTC_LOG(LS_VERBOSE) << optimization; + FullyConnectedLayer fc(kFullyConnectedInputSize, kFullyConnectedOutputSize, + kFullyConnectedBias, kFullyConnectedWeights, + SigmoidApproximated, optimization); // Test on different inputs. - { - const std::array input_vector = { - 0.f, 0.f, 0.f, 0.f, 0.f, - 0.f, 0.215833917f, 0.290601075f, 0.238759011f, 0.244751841f, - 0.f, 0.0461241305f, 0.106401242f, 0.223070428f, 0.630603909f, - 0.690453172f, 0.f, 0.387645692f, 0.166913897f, 0.f, - 0.0327451192f, 0.f, 0.136149868f, 0.446351469f}; - TestFullyConnectedLayer(&fc, input_vector, 0.436567038f); - } - { - const std::array input_vector = { - 0.592162728f, 0.529089332f, 1.18205106f, - 1.21736848f, 0.f, 0.470851123f, - 0.130675942f, 0.320903003f, 0.305496395f, - 0.0571633279f, 1.57001138f, 0.0182026215f, - 0.0977443159f, 0.347477973f, 0.493206412f, - 0.9688586f, 0.0320267938f, 0.244722098f, - 0.312745273f, 0.f, 0.00650715502f, - 0.312553257f, 1.62619662f, 0.782880902f}; - TestFullyConnectedLayer(&fc, input_vector, 0.874741316f); - } - { - const std::array input_vector = { - 0.395022154f, 0.333681047f, 0.76302278f, - 0.965480626f, 0.f, 0.941198349f, - 0.0892967582f, 0.745046318f, 0.635769248f, - 0.238564298f, 0.970656633f, 0.014159563f, - 0.094203949f, 0.446816623f, 0.640755892f, - 1.20532358f, 0.0254284926f, 0.283327013f, - 0.726210058f, 0.0550272502f, 0.000344108557f, - 0.369803518f, 1.56680179f, 0.997883797f}; - TestFullyConnectedLayer(&fc, input_vector, 0.672785878f); + static_assert( + kFullyConnectedInputVectors.size() % kFullyConnectedInputSize == 0, ""); + constexpr size_t kNumInputVectors = + kFullyConnectedInputVectors.size() / kFullyConnectedInputSize; + static_assert(kFullyConnectedExpectedOutputs.size() == kNumInputVectors, ""); + for (size_t i = 0; i < kNumInputVectors; ++i) { + rtc::ArrayView input( + kFullyConnectedInputVectors.data() + kFullyConnectedInputSize * i, + kFullyConnectedInputSize); + TestFullyConnectedLayer(&fc, input, kFullyConnectedExpectedOutputs[i]); } } // Checks that the output of a GRU layer is within tolerance given test input // data. -TEST(RnnVadTest, CheckGatedRecurrentLayer) { - const std::array bias = {96, -99, -81, -114, 49, 119, - -118, 68, -76, 91, 121, 125}; - const std::array weights = { - 124, 9, 1, 116, -66, -21, -118, -110, 104, 75, -23, -51, - -72, -111, 47, 93, 77, -98, 41, -8, 40, -23, -43, -107, - 9, -73, 30, -32, -2, 64, -26, 91, -48, -24, -28, -104, - 74, -46, 116, 15, 32, 52, -126, -38, -121, 12, -16, 110, - -95, 66, -103, -35, -38, 3, -126, -61, 28, 98, -117, -43}; - const std::array recurrent_weights = { - -3, 87, 50, 51, -22, 27, -39, 62, 31, -83, -52, -48, - -6, 83, -19, 104, 105, 48, 23, 68, 23, 40, 7, -120, - 64, -62, 117, 85, -51, -43, 54, -105, 120, 56, -128, -107, - 39, 50, -17, -47, -117, 14, 108, 12, -7, -72, 103, -87, - -66, 82, 84, 100, -98, 102, -49, 44, 122, 106, -20, -69}; - GatedRecurrentLayer gru(5, 4, bias, weights, recurrent_weights); - // Test on different inputs. - { - const std::array input_sequence = { - 0.89395463f, 0.93224651f, 0.55788344f, 0.32341808f, 0.93355054f, - 0.13475326f, 0.97370994f, 0.14253306f, 0.93710381f, 0.76093364f, - 0.65780413f, 0.41657975f, 0.49403164f, 0.46843281f, 0.75138855f, - 0.24517593f, 0.47657707f, 0.57064998f, 0.435184f, 0.19319285f}; - const std::array expected_output_sequence = { - 0.0239123f, 0.5773077f, 0.f, 0.f, - 0.01282811f, 0.64330572f, 0.f, 0.04863098f, - 0.00781069f, 0.75267816f, 0.f, 0.02579715f, - 0.00471378f, 0.59162533f, 0.11087593f, 0.01334511f}; - TestGatedRecurrentLayer(&gru, input_sequence, expected_output_sequence); - } +TEST_P(OptimizationTest, CheckGatedRecurrentLayer) { + const Optimization optimization = GetParam(); + RTC_LOG(LS_VERBOSE) << optimization; + GatedRecurrentLayer gru(kGruInputSize, kGruOutputSize, kGruBias, kGruWeights, + kGruRecurrentWeights, optimization); + TestGatedRecurrentLayer(&gru, kGruInputSequence, kGruExpectedOutputSequence); } +INSTANTIATE_TEST_SUITE_P(RnnVadTest, + OptimizationTest, + ::testing::Values(Optimization::kNone, + DetectOptimization())); + } // namespace test } // namespace rnn_vad } // namespace webrtc