diff --git a/modules/audio_processing/agc2/rnn_vad/rnn.cc b/modules/audio_processing/agc2/rnn_vad/rnn.cc
index a5b34c479d..94cc254045 100644
--- a/modules/audio_processing/agc2/rnn_vad/rnn.cc
+++ b/modules/audio_processing/agc2/rnn_vad/rnn.cc
@@ -44,10 +44,26 @@ using rnnoise::kOutputLayerOutputSize;
 static_assert(kOutputLayerOutputSize <= kFullyConnectedLayersMaxUnits,
               "Increase kFullyConnectedLayersMaxUnits.");
 
-using rnnoise::RectifiedLinearUnit;
 using rnnoise::SigmoidApproximated;
 using rnnoise::TansigApproximated;
 
+namespace {
+
+inline float RectifiedLinearUnit(float x) {
+  return x < 0.f ? 0.f : x;
+}
+
+std::vector<float> GetScaledParams(rtc::ArrayView<const int8_t> params) {
+  std::vector<float> scaled_params(params.size());
+  std::transform(params.begin(), params.end(), scaled_params.begin(),
+                 [](int8_t x) -> float {
+                   return rnnoise::kWeightsScale * static_cast<float>(x);
+                 });
+  return scaled_params;
+}
+
+}  // namespace
+
 FullyConnectedLayer::FullyConnectedLayer(
     const size_t input_size,
     const size_t output_size,
@@ -56,8 +72,8 @@ FullyConnectedLayer::FullyConnectedLayer(
     float (*const activation_function)(float))
     : input_size_(input_size),
       output_size_(output_size),
-      bias_(bias),
-      weights_(weights),
+      bias_(GetScaledParams(bias)),
+      weights_(GetScaledParams(weights)),
       activation_function_(activation_function) {
   RTC_DCHECK_LE(output_size_, kFullyConnectedLayersMaxUnits)
       << "Static over-allocation of fully-connected layers output vectors is "
@@ -84,7 +100,7 @@ void FullyConnectedLayer::ComputeOutput(rtc::ArrayView<const float> input) {
     for (size_t i = 0; i < input_size_; ++i) {
       output_[o] += input[i] * weights_[i * output_size_ + o];
     }
-    output_[o] = (*activation_function_)(kWeightsScale * output_[o]);
+    output_[o] = (*activation_function_)(output_[o]);
   }
 }
 
@@ -93,14 +109,12 @@ GatedRecurrentLayer::GatedRecurrentLayer(
     const size_t output_size,
     const rtc::ArrayView<const int8_t> bias,
     const rtc::ArrayView<const int8_t> weights,
-    const rtc::ArrayView<const int8_t> recurrent_weights,
-    float (*const activation_function)(float))
+    const rtc::ArrayView<const int8_t> recurrent_weights)
     : input_size_(input_size),
       output_size_(output_size),
-      bias_(bias),
-      weights_(weights),
-      recurrent_weights_(recurrent_weights),
-      activation_function_(activation_function) {
+      bias_(GetScaledParams(bias)),
+      weights_(GetScaledParams(weights)),
+      recurrent_weights_(GetScaledParams(recurrent_weights)) {
   RTC_DCHECK_LE(output_size_, kRecurrentLayersMaxUnits)
       << "Static over-allocation of recurrent layers state vectors is not "
       << "sufficient.";
@@ -144,7 +158,7 @@ void GatedRecurrentLayer::ComputeOutput(rtc::ArrayView<const float> input) {
     for (size_t s = 0; s < output_size_; ++s) {
       update[o] += state_[s] * recurrent_weights_[s * stride + o];
     }  // Add state.
-    update[o] = SigmoidApproximated(kWeightsScale * update[o]);
+    update[o] = SigmoidApproximated(update[o]);
   }
 
   // Compute reset gates.
@@ -158,7 +172,7 @@ void GatedRecurrentLayer::ComputeOutput(rtc::ArrayView<const float> input) {
     for (size_t s = 0; s < output_size_; ++s) {  // Add state.
       reset[o] += state_[s] * recurrent_weights_[offset + s * stride + o];
     }
-    reset[o] = SigmoidApproximated(kWeightsScale * reset[o]);
+    reset[o] = SigmoidApproximated(reset[o]);
   }
 
   // Compute output.
@@ -174,7 +188,7 @@ void GatedRecurrentLayer::ComputeOutput(rtc::ArrayView<const float> input) {
       output[o] += state_[s] * recurrent_weights_[offset + s * stride + o] *
                    reset[s];
     }
-    output[o] = (*activation_function_)(kWeightsScale * output[o]);
+    output[o] = RectifiedLinearUnit(output[o]);
     // Update output through the update gates.
     output[o] = update[o] * state_[o] + (1.f - update[o]) * output[o];
   }
@@ -194,8 +208,7 @@ RnnBasedVad::RnnBasedVad()
                    kHiddenLayerOutputSize,
                    kHiddenGruBias,
                    kHiddenGruWeights,
-                   kHiddenGruRecurrentWeights,
-                   RectifiedLinearUnit),
+                   kHiddenGruRecurrentWeights),
       output_layer_(kHiddenLayerOutputSize,
                     kOutputLayerOutputSize,
                     kOutputDenseBias,
diff --git a/modules/audio_processing/agc2/rnn_vad/rnn.h b/modules/audio_processing/agc2/rnn_vad/rnn.h
index 1129464939..c38ff01b3e 100644
--- a/modules/audio_processing/agc2/rnn_vad/rnn.h
+++ b/modules/audio_processing/agc2/rnn_vad/rnn.h
@@ -15,6 +15,7 @@
 
 #include <array>
 #include <cstddef>
+#include <vector>
 
 #include "api/array_view.h"
 #include "modules/audio_processing/agc2/rnn_vad/common.h"
@@ -54,23 +55,23 @@ class FullyConnectedLayer {
  private:
   const size_t input_size_;
   const size_t output_size_;
-  const rtc::ArrayView<const int8_t> bias_;
-  const rtc::ArrayView<const int8_t> weights_;
+  const std::vector<float> bias_;
+  const std::vector<float> weights_;
   float (*const activation_function_)(float);
   // The output vector of a recurrent layer has length equal to |output_size_|.
   // However, for efficiency, over-allocation is used.
   std::array<float, kFullyConnectedLayersMaxUnits> output_;
 };
 
-// Recurrent layer with gated recurrent units (GRUs).
+// Recurrent layer with gated recurrent units (GRUs) with sigmoid and ReLU as
+// activation functions for the update/reset and output gates respectively.
 class GatedRecurrentLayer {
  public:
   GatedRecurrentLayer(const size_t input_size,
                       const size_t output_size,
                       const rtc::ArrayView<const int8_t> bias,
                       const rtc::ArrayView<const int8_t> weights,
-                      const rtc::ArrayView<const int8_t> recurrent_weights,
-                      float (*const activation_function)(float));
+                      const rtc::ArrayView<const int8_t> recurrent_weights);
   GatedRecurrentLayer(const GatedRecurrentLayer&) = delete;
   GatedRecurrentLayer& operator=(const GatedRecurrentLayer&) = delete;
   ~GatedRecurrentLayer();
@@ -84,10 +85,9 @@ class GatedRecurrentLayer {
  private:
   const size_t input_size_;
   const size_t output_size_;
-  const rtc::ArrayView<const int8_t> bias_;
-  const rtc::ArrayView<const int8_t> weights_;
-  const rtc::ArrayView<const int8_t> recurrent_weights_;
-  float (*const activation_function_)(float);
+  const std::vector<float> bias_;
+  const std::vector<float> weights_;
+  const std::vector<float> recurrent_weights_;
   // The state vector of a recurrent layer has length equal to |output_size_|.
   // However, to avoid dynamic allocation, over-allocation is used.
   std::array<float, kRecurrentLayersMaxUnits> state_;
diff --git a/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc b/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc
index 40ac70baf5..61e6f2670e 100644
--- a/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc
+++ b/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc
@@ -123,8 +123,7 @@ TEST(RnnVadTest, CheckGatedRecurrentLayer) {
       64,   -62,  117,  85,   -51,  -43,  54,   -105, 120,  56,  -128, -107,
       39,   50,   -17,  -47,  -117, 14,   108,  12,   -7,   -72, 103,  -87,
       -66,  82,   84,   100,  -98,  102,  -49,  44,   122,  106, -20,  -69};
-  GatedRecurrentLayer gru(5, 4, bias, weights, recurrent_weights,
-                          RectifiedLinearUnit);
+  GatedRecurrentLayer gru(5, 4, bias, weights, recurrent_weights);
   // Test on different inputs.
   {
     const std::array<float, 20> input_sequence = {