diff --git a/modules/audio_processing/agc2/BUILD.gn b/modules/audio_processing/agc2/BUILD.gn index bf09533737..c6667df420 100644 --- a/modules/audio_processing/agc2/BUILD.gn +++ b/modules/audio_processing/agc2/BUILD.gn @@ -56,6 +56,7 @@ rtc_library("adaptive_digital") { deps = [ ":common", + ":cpu_features", ":gain_applier", ":noise_level_estimator", ":rnn_vad_with_level", @@ -163,6 +164,7 @@ rtc_library("rnn_vad_with_level") { ] deps = [ ":common", + ":cpu_features", "..:audio_frame_view", "../../../api:array_view", "../../../common_audio", @@ -172,6 +174,19 @@ rtc_library("rnn_vad_with_level") { ] } +rtc_library("cpu_features") { + sources = [ + "cpu_features.cc", + "cpu_features.h", + ] + visibility = [ "./*" ] + deps = [ + "../../../rtc_base:stringutils", + "../../../rtc_base/system:arch", + "../../../system_wrappers", + ] +} + rtc_library("adaptive_digital_unittests") { testonly = true configs += [ "..:apm_debug_dump" ] diff --git a/modules/audio_processing/agc2/adaptive_agc.cc b/modules/audio_processing/agc2/adaptive_agc.cc index 0372ccf38a..4df3b58e9d 100644 --- a/modules/audio_processing/agc2/adaptive_agc.cc +++ b/modules/audio_processing/agc2/adaptive_agc.cc @@ -11,6 +11,7 @@ #include "modules/audio_processing/agc2/adaptive_agc.h" #include "common_audio/include/audio_util.h" +#include "modules/audio_processing/agc2/cpu_features.h" #include "modules/audio_processing/agc2/vad_with_level.h" #include "modules/audio_processing/logging/apm_data_dumper.h" #include "rtc_base/checks.h" @@ -32,6 +33,15 @@ constexpr int kGainApplierAdjacentSpeechFramesThreshold = 1; constexpr float kMaxGainChangePerSecondDb = 3.f; constexpr float kMaxOutputNoiseLevelDbfs = -50.f; +// Detects the available CPU features and applies a kill-switch to AVX2. +AvailableCpuFeatures GetAllowedCpuFeatures(bool avx2_allowed) { + AvailableCpuFeatures features = GetAvailableCpuFeatures(); + if (!avx2_allowed) { + features.avx2 = false; + } + return features; +} + } // namespace AdaptiveAgc::AdaptiveAgc(ApmDataDumper* apm_data_dumper) @@ -54,7 +64,8 @@ AdaptiveAgc::AdaptiveAgc(ApmDataDumper* apm_data_dumper, .level_estimator_adjacent_speech_frames_threshold, config.adaptive_digital.initial_saturation_margin_db, config.adaptive_digital.extra_saturation_margin_db), - vad_(config.adaptive_digital.vad_probability_attack), + vad_(config.adaptive_digital.vad_probability_attack, + GetAllowedCpuFeatures(config.adaptive_digital.avx2_allowed)), gain_applier_( apm_data_dumper, config.adaptive_digital.gain_applier_adjacent_speech_frames_threshold, diff --git a/modules/audio_processing/agc2/cpu_features.cc b/modules/audio_processing/agc2/cpu_features.cc new file mode 100644 index 0000000000..b4f377ffba --- /dev/null +++ b/modules/audio_processing/agc2/cpu_features.cc @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2020 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/agc2/cpu_features.h" + +#include "rtc_base/strings/string_builder.h" +#include "rtc_base/system/arch.h" +#include "system_wrappers/include/cpu_features_wrapper.h" + +namespace webrtc { + +std::string AvailableCpuFeatures::ToString() const { + char buf[64]; + rtc::SimpleStringBuilder builder(buf); + bool first = true; + if (sse2) { + builder << (first ? "SSE2" : "_SSE2"); + first = false; + } + if (avx2) { + builder << (first ? "AVX2" : "_AVX2"); + first = false; + } + if (neon) { + builder << (first ? "NEON" : "_NEON"); + first = false; + } + if (first) { + return "none"; + } + return builder.str(); +} + +// Detects available CPU features. +AvailableCpuFeatures GetAvailableCpuFeatures() { +#if defined(WEBRTC_ARCH_X86_FAMILY) + return {/*sse2=*/GetCPUInfo(kSSE2) != 0, + /*avx2=*/GetCPUInfo(kAVX2) != 0, + /*neon=*/false}; +#elif defined(WEBRTC_HAS_NEON) + return {/*sse2=*/false, + /*avx2=*/false, + /*neon=*/true}; +#endif +} + +} // namespace webrtc diff --git a/modules/audio_processing/agc2/cpu_features.h b/modules/audio_processing/agc2/cpu_features.h new file mode 100644 index 0000000000..bf73c3e562 --- /dev/null +++ b/modules/audio_processing/agc2/cpu_features.h @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2020 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AGC2_CPU_FEATURES_H_ +#define MODULES_AUDIO_PROCESSING_AGC2_CPU_FEATURES_H_ + +#include + +namespace webrtc { + +// Collection of flags indicating which CPU features are available on the +// current platform. True means available. +struct AvailableCpuFeatures { + AvailableCpuFeatures(bool sse2, bool avx2, bool neon) + : sse2(sse2), avx2(avx2), neon(neon) {} + // Intel. + bool sse2; + bool avx2; + // ARM. + bool neon; + std::string ToString() const; +}; + +// Detects what CPU features are available. +AvailableCpuFeatures GetAvailableCpuFeatures(); + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AGC2_CPU_FEATURES_H_ diff --git a/modules/audio_processing/agc2/rnn_vad/BUILD.gn b/modules/audio_processing/agc2/rnn_vad/BUILD.gn index fae1d5a572..a4285bab5a 100644 --- a/modules/audio_processing/agc2/rnn_vad/BUILD.gn +++ b/modules/audio_processing/agc2/rnn_vad/BUILD.gn @@ -29,10 +29,10 @@ rtc_library("rnn_vad") { ":rnn_vad_sequence_buffer", ":rnn_vad_spectral_features", "..:biquad_filter", + "..:cpu_features", "../../../../api:array_view", "../../../../api:function_view", "../../../../rtc_base:checks", - "../../../../rtc_base:logging", "../../../../rtc_base:safe_compare", "../../../../rtc_base:safe_conversions", "../../../../rtc_base/system:arch", @@ -53,16 +53,13 @@ rtc_library("rnn_vad_auto_correlation") { ] } -rtc_library("rnn_vad_common") { +rtc_source_set("rnn_vad_common") { # TODO(alessiob): Make this target visibility private. visibility = [ ":*", "..:rnn_vad_with_level", ] - sources = [ - "common.cc", - "common.h", - ] + sources = [ "common.h" ] deps = [ "../../../../rtc_base/system:arch", "../../../../system_wrappers", @@ -91,6 +88,7 @@ rtc_library("rnn_vad_pitch") { deps = [ ":rnn_vad_auto_correlation", ":rnn_vad_common", + "..:cpu_features", "../../../../api:array_view", "../../../../rtc_base:checks", "../../../../rtc_base:gtest_prod", @@ -156,8 +154,6 @@ if (rtc_include_tests) { "../../../../api:scoped_refptr", "../../../../rtc_base:checks", "../../../../rtc_base:safe_compare", - "../../../../rtc_base/system:arch", - "../../../../system_wrappers", "../../../../test:fileutils", "../../../../test:test_support", ] @@ -207,6 +203,7 @@ if (rtc_include_tests) { ":rnn_vad_spectral_features", ":rnn_vad_symmetric_matrix_buffer", ":test_utils", + "..:cpu_features", "../..:audioproc_test_utils", "../../../../api:array_view", "../../../../common_audio/", @@ -232,6 +229,7 @@ if (rtc_include_tests) { deps = [ ":rnn_vad", ":rnn_vad_common", + "..:cpu_features", "../../../../api:array_view", "../../../../common_audio", "../../../../rtc_base:rtc_base_approved", diff --git a/modules/audio_processing/agc2/rnn_vad/common.cc b/modules/audio_processing/agc2/rnn_vad/common.cc deleted file mode 100644 index 5d76b52e57..0000000000 --- a/modules/audio_processing/agc2/rnn_vad/common.cc +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright (c) 2019 The WebRTC project authors. All Rights Reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. - */ - -#include "modules/audio_processing/agc2/rnn_vad/common.h" - -#include "rtc_base/system/arch.h" -#include "system_wrappers/include/cpu_features_wrapper.h" - -namespace webrtc { -namespace rnn_vad { - -Optimization DetectOptimization() { -#if defined(WEBRTC_ARCH_X86_FAMILY) - if (GetCPUInfo(kSSE2) != 0) { - return Optimization::kSse2; - } -#endif - -#if defined(WEBRTC_HAS_NEON) - return Optimization::kNeon; -#endif - - return Optimization::kNone; -} - -} // namespace rnn_vad -} // namespace webrtc diff --git a/modules/audio_processing/agc2/rnn_vad/common.h b/modules/audio_processing/agc2/rnn_vad/common.h index 36b366ad1d..be5a2d58ce 100644 --- a/modules/audio_processing/agc2/rnn_vad/common.h +++ b/modules/audio_processing/agc2/rnn_vad/common.h @@ -71,11 +71,6 @@ static_assert(kCepstralCoeffsHistorySize > 2, constexpr int kFeatureVectorSize = 42; -enum class Optimization { kNone, kSse2, kNeon }; - -// Detects what kind of optimizations to use for the code. -Optimization DetectOptimization(); - } // namespace rnn_vad } // namespace webrtc diff --git a/modules/audio_processing/agc2/rnn_vad/features_extraction.cc b/modules/audio_processing/agc2/rnn_vad/features_extraction.cc index cdbbbc311d..f86eba764e 100644 --- a/modules/audio_processing/agc2/rnn_vad/features_extraction.cc +++ b/modules/audio_processing/agc2/rnn_vad/features_extraction.cc @@ -26,13 +26,13 @@ const BiQuadFilter::BiQuadCoefficients kHpfConfig24k = { } // namespace -FeaturesExtractor::FeaturesExtractor() +FeaturesExtractor::FeaturesExtractor(const AvailableCpuFeatures& cpu_features) : use_high_pass_filter_(false), pitch_buf_24kHz_(), pitch_buf_24kHz_view_(pitch_buf_24kHz_.GetBufferView()), lp_residual_(kBufSize24kHz), lp_residual_view_(lp_residual_.data(), kBufSize24kHz), - pitch_estimator_(), + pitch_estimator_(cpu_features), reference_frame_view_(pitch_buf_24kHz_.GetMostRecentValuesView()) { RTC_DCHECK_EQ(kBufSize24kHz, lp_residual_.size()); hpf_.Initialize(kHpfConfig24k); diff --git a/modules/audio_processing/agc2/rnn_vad/features_extraction.h b/modules/audio_processing/agc2/rnn_vad/features_extraction.h index e2c77d2cf8..f4cea7a83d 100644 --- a/modules/audio_processing/agc2/rnn_vad/features_extraction.h +++ b/modules/audio_processing/agc2/rnn_vad/features_extraction.h @@ -26,7 +26,7 @@ namespace rnn_vad { // Feature extractor to feed the VAD RNN. class FeaturesExtractor { public: - FeaturesExtractor(); + explicit FeaturesExtractor(const AvailableCpuFeatures& cpu_features); FeaturesExtractor(const FeaturesExtractor&) = delete; FeaturesExtractor& operator=(const FeaturesExtractor&) = delete; ~FeaturesExtractor(); diff --git a/modules/audio_processing/agc2/rnn_vad/features_extraction_unittest.cc b/modules/audio_processing/agc2/rnn_vad/features_extraction_unittest.cc index 9df52738b4..0da971e3da 100644 --- a/modules/audio_processing/agc2/rnn_vad/features_extraction_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/features_extraction_unittest.cc @@ -13,6 +13,7 @@ #include #include +#include "modules/audio_processing/agc2/cpu_features.h" #include "modules/audio_processing/agc2/rnn_vad/test_utils.h" #include "rtc_base/numerics/safe_compare.h" #include "rtc_base/numerics/safe_conversions.h" @@ -77,7 +78,8 @@ TEST(RnnVadTest, FeatureExtractionLowHighPitch) { ASSERT_TRUE(PitchIsValid(low_pitch_hz)); ASSERT_TRUE(PitchIsValid(high_pitch_hz)); - FeaturesExtractor features_extractor; + const AvailableCpuFeatures cpu_features = GetAvailableCpuFeatures(); + FeaturesExtractor features_extractor(cpu_features); std::vector samples(kNumTestDataSize); std::vector feature_vector(kFeatureVectorSize); ASSERT_EQ(kFeatureVectorSize, rtc::dchecked_cast(feature_vector.size())); diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search.cc index c6c3e1b2b5..c2e7665967 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search.cc +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search.cc @@ -18,8 +18,9 @@ namespace webrtc { namespace rnn_vad { -PitchEstimator::PitchEstimator() - : y_energy_24kHz_(kRefineNumLags24kHz, 0.f), +PitchEstimator::PitchEstimator(const AvailableCpuFeatures& cpu_features) + : cpu_features_(cpu_features), + y_energy_24kHz_(kRefineNumLags24kHz, 0.f), pitch_buffer_12kHz_(kBufSize12kHz), auto_correlation_12kHz_(kNumLags12kHz) {} @@ -35,6 +36,7 @@ int PitchEstimator::Estimate( RTC_DCHECK_EQ(auto_correlation_12kHz_.size(), auto_correlation_12kHz_view.size()); + // TODO(bugs.chromium.org/10480): Use `cpu_features_` to estimate pitch. // Perform the initial pitch search at 12 kHz. Decimate2x(pitch_buffer, pitch_buffer_12kHz_view); auto_corr_calculator_.ComputeOnPitchBuffer(pitch_buffer_12kHz_view, diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search.h b/modules/audio_processing/agc2/rnn_vad/pitch_search.h index e96a2dcaf1..42c448eb56 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search.h +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search.h @@ -15,6 +15,7 @@ #include #include "api/array_view.h" +#include "modules/audio_processing/agc2/cpu_features.h" #include "modules/audio_processing/agc2/rnn_vad/auto_correlation.h" #include "modules/audio_processing/agc2/rnn_vad/common.h" #include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h" @@ -26,7 +27,7 @@ namespace rnn_vad { // Pitch estimator. class PitchEstimator { public: - PitchEstimator(); + explicit PitchEstimator(const AvailableCpuFeatures& cpu_features); PitchEstimator(const PitchEstimator&) = delete; PitchEstimator& operator=(const PitchEstimator&) = delete; ~PitchEstimator(); @@ -39,6 +40,7 @@ class PitchEstimator { return last_pitch_48kHz_.strength; } + const AvailableCpuFeatures cpu_features_; PitchInfo last_pitch_48kHz_{}; AutoCorrelationCalculator auto_corr_calculator_; std::vector y_energy_24kHz_; diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc index 98b791e872..fe9be5dbba 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc @@ -13,6 +13,7 @@ #include #include +#include "modules/audio_processing/agc2/cpu_features.h" #include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h" #include "modules/audio_processing/agc2/rnn_vad/test_utils.h" // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. @@ -29,7 +30,8 @@ TEST(RnnVadTest, PitchSearchWithinTolerance) { const int num_frames = std::min(lp_residual_reader.second, 300); // Max 3 s. std::vector lp_residual(kBufSize24kHz); float expected_pitch_period, expected_pitch_strength; - PitchEstimator pitch_estimator; + const AvailableCpuFeatures cpu_features = GetAvailableCpuFeatures(); + PitchEstimator pitch_estimator(cpu_features); { // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. // FloatingPointExceptionObserver fpe_observer; diff --git a/modules/audio_processing/agc2/rnn_vad/rnn.cc b/modules/audio_processing/agc2/rnn_vad/rnn.cc index 2072a6854d..fb4962f724 100644 --- a/modules/audio_processing/agc2/rnn_vad/rnn.cc +++ b/modules/audio_processing/agc2/rnn_vad/rnn.cc @@ -25,7 +25,6 @@ #include #include "rtc_base/checks.h" -#include "rtc_base/logging.h" #include "rtc_base/numerics/safe_conversions.h" #include "third_party/rnnoise/src/rnn_activations.h" #include "third_party/rnnoise/src/rnn_vad_weights.h" @@ -273,13 +272,13 @@ FullyConnectedLayer::FullyConnectedLayer( const rtc::ArrayView bias, const rtc::ArrayView weights, rtc::FunctionView activation_function, - Optimization optimization) + const AvailableCpuFeatures& cpu_features) : input_size_(input_size), output_size_(output_size), bias_(GetScaledParams(bias)), weights_(GetPreprocessedFcWeights(weights, output_size)), activation_function_(activation_function), - optimization_(optimization) { + cpu_features_(cpu_features) { RTC_DCHECK_LE(output_size_, kFullyConnectedLayersMaxUnits) << "Static over-allocation of fully-connected layers output vectors is " "not sufficient."; @@ -296,25 +295,18 @@ rtc::ArrayView FullyConnectedLayer::GetOutput() const { } void FullyConnectedLayer::ComputeOutput(rtc::ArrayView input) { - switch (optimization_) { #if defined(WEBRTC_ARCH_X86_FAMILY) - case Optimization::kSse2: - ComputeFullyConnectedLayerOutputSse2(input_size_, output_size_, input, - bias_, weights_, - activation_function_, output_); - break; -#endif -#if defined(WEBRTC_HAS_NEON) - case Optimization::kNeon: - // TODO(bugs.chromium.org/10480): Handle Optimization::kNeon. - ComputeFullyConnectedLayerOutput(input_size_, output_size_, input, bias_, - weights_, activation_function_, output_); - break; -#endif - default: - ComputeFullyConnectedLayerOutput(input_size_, output_size_, input, bias_, - weights_, activation_function_, output_); + // TODO(bugs.chromium.org/10480): Add AVX2. + if (cpu_features_.sse2) { + ComputeFullyConnectedLayerOutputSse2(input_size_, output_size_, input, + bias_, weights_, activation_function_, + output_); + return; } +#endif + // TODO(bugs.chromium.org/10480): Add Neon. + ComputeFullyConnectedLayerOutput(input_size_, output_size_, input, bias_, + weights_, activation_function_, output_); } GatedRecurrentLayer::GatedRecurrentLayer( @@ -322,15 +314,13 @@ GatedRecurrentLayer::GatedRecurrentLayer( const int output_size, const rtc::ArrayView bias, const rtc::ArrayView weights, - const rtc::ArrayView recurrent_weights, - Optimization optimization) + const rtc::ArrayView recurrent_weights) : input_size_(input_size), output_size_(output_size), bias_(GetPreprocessedGruTensor(bias, output_size)), weights_(GetPreprocessedGruTensor(weights, output_size)), recurrent_weights_( - GetPreprocessedGruTensor(recurrent_weights, output_size)), - optimization_(optimization) { + GetPreprocessedGruTensor(recurrent_weights, output_size)) { RTC_DCHECK_LE(output_size_, kRecurrentLayersMaxUnits) << "Static over-allocation of recurrent layers state vectors is not " "sufficient."; @@ -356,46 +346,30 @@ void GatedRecurrentLayer::Reset() { } void GatedRecurrentLayer::ComputeOutput(rtc::ArrayView input) { - switch (optimization_) { -#if defined(WEBRTC_ARCH_X86_FAMILY) - case Optimization::kSse2: - // TODO(bugs.chromium.org/10480): Handle Optimization::kSse2. - ComputeGruLayerOutput(input_size_, output_size_, input, weights_, - recurrent_weights_, bias_, state_); - break; -#endif -#if defined(WEBRTC_HAS_NEON) - case Optimization::kNeon: - // TODO(bugs.chromium.org/10480): Handle Optimization::kNeon. - ComputeGruLayerOutput(input_size_, output_size_, input, weights_, - recurrent_weights_, bias_, state_); - break; -#endif - default: - ComputeGruLayerOutput(input_size_, output_size_, input, weights_, - recurrent_weights_, bias_, state_); - } + // TODO(bugs.chromium.org/10480): Add AVX2. + // TODO(bugs.chromium.org/10480): Add Neon. + ComputeGruLayerOutput(input_size_, output_size_, input, weights_, + recurrent_weights_, bias_, state_); } -RnnBasedVad::RnnBasedVad() +RnnBasedVad::RnnBasedVad(const AvailableCpuFeatures& cpu_features) : input_layer_(kInputLayerInputSize, kInputLayerOutputSize, kInputDenseBias, kInputDenseWeights, TansigApproximated, - DetectOptimization()), + cpu_features), hidden_layer_(kInputLayerOutputSize, kHiddenLayerOutputSize, kHiddenGruBias, kHiddenGruWeights, - kHiddenGruRecurrentWeights, - DetectOptimization()), + kHiddenGruRecurrentWeights), output_layer_(kHiddenLayerOutputSize, kOutputLayerOutputSize, kOutputDenseBias, kOutputDenseWeights, SigmoidApproximated, - DetectOptimization()) { + cpu_features) { // Input-output chaining size checks. RTC_DCHECK_EQ(input_layer_.output_size(), hidden_layer_.input_size()) << "The input and the hidden layers sizes do not match."; diff --git a/modules/audio_processing/agc2/rnn_vad/rnn.h b/modules/audio_processing/agc2/rnn_vad/rnn.h index 5b44f53047..1ef4c76c21 100644 --- a/modules/audio_processing/agc2/rnn_vad/rnn.h +++ b/modules/audio_processing/agc2/rnn_vad/rnn.h @@ -19,6 +19,7 @@ #include "api/array_view.h" #include "api/function_view.h" +#include "modules/audio_processing/agc2/cpu_features.h" #include "modules/audio_processing/agc2/rnn_vad/common.h" #include "rtc_base/system/arch.h" @@ -45,13 +46,12 @@ class FullyConnectedLayer { rtc::ArrayView bias, rtc::ArrayView weights, rtc::FunctionView activation_function, - Optimization optimization); + const AvailableCpuFeatures& cpu_features); FullyConnectedLayer(const FullyConnectedLayer&) = delete; FullyConnectedLayer& operator=(const FullyConnectedLayer&) = delete; ~FullyConnectedLayer(); int input_size() const { return input_size_; } int output_size() const { return output_size_; } - Optimization optimization() const { return optimization_; } rtc::ArrayView GetOutput() const; // Computes the fully-connected layer output. void ComputeOutput(rtc::ArrayView input); @@ -65,7 +65,7 @@ class FullyConnectedLayer { // The output vector of a recurrent layer has length equal to |output_size_|. // However, for efficiency, over-allocation is used. std::array output_; - const Optimization optimization_; + const AvailableCpuFeatures cpu_features_; }; // Recurrent layer with gated recurrent units (GRUs) with sigmoid and ReLU as @@ -76,14 +76,12 @@ class GatedRecurrentLayer { int output_size, rtc::ArrayView bias, rtc::ArrayView weights, - rtc::ArrayView recurrent_weights, - Optimization optimization); + rtc::ArrayView recurrent_weights); GatedRecurrentLayer(const GatedRecurrentLayer&) = delete; GatedRecurrentLayer& operator=(const GatedRecurrentLayer&) = delete; ~GatedRecurrentLayer(); int input_size() const { return input_size_; } int output_size() const { return output_size_; } - Optimization optimization() const { return optimization_; } rtc::ArrayView GetOutput() const; void Reset(); // Computes the recurrent layer output and updates the status. @@ -98,13 +96,12 @@ class GatedRecurrentLayer { // The state vector of a recurrent layer has length equal to |output_size_|. // However, to avoid dynamic allocation, over-allocation is used. std::array state_; - const Optimization optimization_; }; // Recurrent network based VAD. class RnnBasedVad { public: - RnnBasedVad(); + explicit RnnBasedVad(const AvailableCpuFeatures& cpu_features); RnnBasedVad(const RnnBasedVad&) = delete; RnnBasedVad& operator=(const RnnBasedVad&) = delete; ~RnnBasedVad(); diff --git a/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc b/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc index a57a899c8d..2e920e8d80 100644 --- a/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc @@ -14,6 +14,7 @@ #include #include +#include "modules/audio_processing/agc2/cpu_features.h" #include "modules/audio_processing/agc2/rnn_vad/test_utils.h" #include "modules/audio_processing/test/performance_timer.h" #include "rtc_base/checks.h" @@ -27,7 +28,6 @@ namespace webrtc { namespace rnn_vad { namespace test { - namespace { void TestFullyConnectedLayer(FullyConnectedLayer* fc, @@ -39,26 +39,25 @@ void TestFullyConnectedLayer(FullyConnectedLayer* fc, } void TestGatedRecurrentLayer( - GatedRecurrentLayer* gru, + GatedRecurrentLayer& gru, rtc::ArrayView input_sequence, rtc::ArrayView expected_output_sequence) { - RTC_CHECK(gru); - auto gru_output_view = gru->GetOutput(); + auto gru_output_view = gru.GetOutput(); const int input_sequence_length = rtc::CheckedDivExact( - rtc::dchecked_cast(input_sequence.size()), gru->input_size()); + rtc::dchecked_cast(input_sequence.size()), gru.input_size()); const int output_sequence_length = rtc::CheckedDivExact( rtc::dchecked_cast(expected_output_sequence.size()), - gru->output_size()); + gru.output_size()); ASSERT_EQ(input_sequence_length, output_sequence_length) << "The test data length is invalid."; // Feed the GRU layer and check the output at every step. - gru->Reset(); + gru.Reset(); for (int i = 0; i < input_sequence_length; ++i) { SCOPED_TRACE(i); - gru->ComputeOutput( - input_sequence.subview(i * gru->input_size(), gru->input_size())); + gru.ComputeOutput( + input_sequence.subview(i * gru.input_size(), gru.input_size())); const auto expected_output = expected_output_sequence.subview( - i * gru->output_size(), gru->output_size()); + i * gru.output_size(), gru.output_size()); ExpectNearAbsolute(expected_output, gru_output_view, 3e-6f); } } @@ -134,141 +133,94 @@ constexpr std::array kGruExpectedOutputSequence = { 0.00781069f, 0.75267816f, 0.f, 0.02579715f, 0.00471378f, 0.59162533f, 0.11087593f, 0.01334511f}; -std::string GetOptimizationName(Optimization optimization) { - switch (optimization) { - case Optimization::kSse2: - return "SSE2"; - case Optimization::kNeon: - return "NEON"; - case Optimization::kNone: - return "none"; - } -} - -struct Result { - Optimization optimization; - double average_us; - double std_dev_us; -}; - -} // namespace - -// Checks that the output of a fully connected layer is within tolerance given -// test input data. -TEST(RnnVadTest, CheckFullyConnectedLayerOutput) { - FullyConnectedLayer fc(rnnoise::kInputLayerInputSize, - rnnoise::kInputLayerOutputSize, - rnnoise::kInputDenseBias, rnnoise::kInputDenseWeights, - rnnoise::TansigApproximated, Optimization::kNone); - TestFullyConnectedLayer(&fc, kFullyConnectedInputVector, - kFullyConnectedExpectedOutput); -} - // Checks that the output of a GRU layer is within tolerance given test input // data. TEST(RnnVadTest, CheckGatedRecurrentLayer) { GatedRecurrentLayer gru(kGruInputSize, kGruOutputSize, kGruBias, kGruWeights, - kGruRecurrentWeights, Optimization::kNone); - TestGatedRecurrentLayer(&gru, kGruInputSequence, kGruExpectedOutputSequence); -} - -#if defined(WEBRTC_ARCH_X86_FAMILY) - -// Like CheckFullyConnectedLayerOutput, but testing the SSE2 implementation. -TEST(RnnVadTest, CheckFullyConnectedLayerOutputSse2) { - if (!IsOptimizationAvailable(Optimization::kSse2)) { - return; - } - - FullyConnectedLayer fc(rnnoise::kInputLayerInputSize, - rnnoise::kInputLayerOutputSize, - rnnoise::kInputDenseBias, rnnoise::kInputDenseWeights, - rnnoise::TansigApproximated, Optimization::kSse2); - TestFullyConnectedLayer(&fc, kFullyConnectedInputVector, - kFullyConnectedExpectedOutput); -} - -// Like CheckGatedRecurrentLayer, but testing the SSE2 implementation. -TEST(RnnVadTest, CheckGatedRecurrentLayerSse2) { - if (!IsOptimizationAvailable(Optimization::kSse2)) { - return; - } - - GatedRecurrentLayer gru(kGruInputSize, kGruOutputSize, kGruBias, kGruWeights, - kGruRecurrentWeights, Optimization::kSse2); - TestGatedRecurrentLayer(&gru, kGruInputSequence, kGruExpectedOutputSequence); -} - -#endif // WEBRTC_ARCH_X86_FAMILY - -TEST(RnnVadTest, DISABLED_BenchmarkFullyConnectedLayer) { - std::vector> implementations; - implementations.emplace_back(std::make_unique( - rnnoise::kInputLayerInputSize, rnnoise::kInputLayerOutputSize, - rnnoise::kInputDenseBias, rnnoise::kInputDenseWeights, - rnnoise::TansigApproximated, Optimization::kNone)); - if (IsOptimizationAvailable(Optimization::kSse2)) { - implementations.emplace_back(std::make_unique( - rnnoise::kInputLayerInputSize, rnnoise::kInputLayerOutputSize, - rnnoise::kInputDenseBias, rnnoise::kInputDenseWeights, - rnnoise::TansigApproximated, Optimization::kSse2)); - } - - std::vector results; - constexpr int number_of_tests = 10000; - for (auto& fc : implementations) { - ::webrtc::test::PerformanceTimer perf_timer(number_of_tests); - for (int k = 0; k < number_of_tests; ++k) { - perf_timer.StartTimer(); - fc->ComputeOutput(kFullyConnectedInputVector); - perf_timer.StopTimer(); - } - results.push_back({fc->optimization(), perf_timer.GetDurationAverage(), - perf_timer.GetDurationStandardDeviation()}); - } - - for (const auto& result : results) { - RTC_LOG(LS_INFO) << GetOptimizationName(result.optimization) << ": " - << (result.average_us / 1e3) << " +/- " - << (result.std_dev_us / 1e3) << " ms"; - } + kGruRecurrentWeights); + TestGatedRecurrentLayer(gru, kGruInputSequence, kGruExpectedOutputSequence); } TEST(RnnVadTest, DISABLED_BenchmarkGatedRecurrentLayer) { - std::vector> implementations; - implementations.emplace_back(std::make_unique( - kGruInputSize, kGruOutputSize, kGruBias, kGruWeights, - kGruRecurrentWeights, Optimization::kNone)); + GatedRecurrentLayer gru(kGruInputSize, kGruOutputSize, kGruBias, kGruWeights, + kGruRecurrentWeights); rtc::ArrayView input_sequence(kGruInputSequence); static_assert(kGruInputSequence.size() % kGruInputSize == 0, ""); constexpr int input_sequence_length = kGruInputSequence.size() / kGruInputSize; - std::vector results; - constexpr int number_of_tests = 10000; - for (auto& gru : implementations) { - ::webrtc::test::PerformanceTimer perf_timer(number_of_tests); - gru->Reset(); - for (int k = 0; k < number_of_tests; ++k) { - perf_timer.StartTimer(); - for (int i = 0; i < input_sequence_length; ++i) { - gru->ComputeOutput( - input_sequence.subview(i * gru->input_size(), gru->input_size())); - } - perf_timer.StopTimer(); + constexpr int kNumTests = 10000; + ::webrtc::test::PerformanceTimer perf_timer(kNumTests); + for (int k = 0; k < kNumTests; ++k) { + perf_timer.StartTimer(); + for (int i = 0; i < input_sequence_length; ++i) { + gru.ComputeOutput( + input_sequence.subview(i * gru.input_size(), gru.input_size())); } - results.push_back({gru->optimization(), perf_timer.GetDurationAverage(), - perf_timer.GetDurationStandardDeviation()}); - } - - for (const auto& result : results) { - RTC_LOG(LS_INFO) << GetOptimizationName(result.optimization) << ": " - << (result.average_us / 1e3) << " +/- " - << (result.std_dev_us / 1e3) << " ms"; + perf_timer.StopTimer(); } + RTC_LOG(LS_INFO) << (perf_timer.GetDurationAverage() / 1000) << " +/- " + << (perf_timer.GetDurationStandardDeviation() / 1000) + << " ms"; } +class RnnParametrization + : public ::testing::TestWithParam {}; + +// Checks that the output of a fully connected layer is within tolerance given +// test input data. +TEST_P(RnnParametrization, CheckFullyConnectedLayerOutput) { + FullyConnectedLayer fc( + rnnoise::kInputLayerInputSize, rnnoise::kInputLayerOutputSize, + rnnoise::kInputDenseBias, rnnoise::kInputDenseWeights, + rnnoise::TansigApproximated, /*cpu_features=*/GetParam()); + TestFullyConnectedLayer(&fc, kFullyConnectedInputVector, + kFullyConnectedExpectedOutput); +} + +TEST_P(RnnParametrization, DISABLED_BenchmarkFullyConnectedLayer) { + const AvailableCpuFeatures cpu_features = GetParam(); + FullyConnectedLayer fc(rnnoise::kInputLayerInputSize, + rnnoise::kInputLayerOutputSize, + rnnoise::kInputDenseBias, rnnoise::kInputDenseWeights, + rnnoise::TansigApproximated, cpu_features); + + constexpr int kNumTests = 10000; + ::webrtc::test::PerformanceTimer perf_timer(kNumTests); + for (int k = 0; k < kNumTests; ++k) { + perf_timer.StartTimer(); + fc.ComputeOutput(kFullyConnectedInputVector); + perf_timer.StopTimer(); + } + RTC_LOG(LS_INFO) << "CPU features: " << cpu_features.ToString() << " | " + << (perf_timer.GetDurationAverage() / 1000) << " +/- " + << (perf_timer.GetDurationStandardDeviation() / 1000) + << " ms"; +} + +// Finds the relevant CPU features combinations to test. +std::vector GetCpuFeaturesToTest() { + std::vector v; + v.push_back({/*sse2=*/false, /*avx2=*/false, /*neon=*/false}); + AvailableCpuFeatures available = GetAvailableCpuFeatures(); + if (available.sse2) { + AvailableCpuFeatures features( + {/*sse2=*/true, /*avx2=*/false, /*neon=*/false}); + v.push_back(features); + } + return v; +} + +INSTANTIATE_TEST_SUITE_P( + RnnVadTest, + RnnParametrization, + ::testing::ValuesIn(GetCpuFeaturesToTest()), + [](const ::testing::TestParamInfo& info) { + return info.param.ToString(); + }); + +} // namespace } // namespace test } // namespace rnn_vad } // namespace webrtc diff --git a/modules/audio_processing/agc2/rnn_vad/rnn_vad_tool.cc b/modules/audio_processing/agc2/rnn_vad/rnn_vad_tool.cc index 8b12b60c55..0f3ad5ce16 100644 --- a/modules/audio_processing/agc2/rnn_vad/rnn_vad_tool.cc +++ b/modules/audio_processing/agc2/rnn_vad/rnn_vad_tool.cc @@ -16,6 +16,7 @@ #include "absl/flags/parse.h" #include "common_audio/resampler/push_sinc_resampler.h" #include "common_audio/wav_file.h" +#include "modules/audio_processing/agc2/cpu_features.h" #include "modules/audio_processing/agc2/rnn_vad/common.h" #include "modules/audio_processing/agc2/rnn_vad/features_extraction.h" #include "modules/audio_processing/agc2/rnn_vad/rnn.h" @@ -63,9 +64,10 @@ int main(int argc, char* argv[]) { samples_10ms.resize(frame_size_10ms); std::array samples_10ms_24kHz; PushSincResampler resampler(frame_size_10ms, kFrameSize10ms24kHz); - FeaturesExtractor features_extractor; + const AvailableCpuFeatures cpu_features = GetAvailableCpuFeatures(); + FeaturesExtractor features_extractor(cpu_features); std::array feature_vector; - RnnBasedVad rnn_vad; + RnnBasedVad rnn_vad(cpu_features); // Compute VAD probabilities. while (true) { diff --git a/modules/audio_processing/agc2/rnn_vad/rnn_vad_unittest.cc b/modules/audio_processing/agc2/rnn_vad/rnn_vad_unittest.cc index 0916bf5b81..6036a00fd0 100644 --- a/modules/audio_processing/agc2/rnn_vad/rnn_vad_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/rnn_vad_unittest.cc @@ -13,6 +13,7 @@ #include #include "common_audio/resampler/push_sinc_resampler.h" +#include "modules/audio_processing/agc2/cpu_features.h" #include "modules/audio_processing/agc2/rnn_vad/features_extraction.h" #include "modules/audio_processing/agc2/rnn_vad/rnn.h" #include "modules/audio_processing/agc2/rnn_vad/test_utils.h" @@ -57,13 +58,17 @@ TEST(RnnVadTest, CheckWriteComputedOutputIsFalse) { << "Cannot land if kWriteComputedOutput is true."; } +class RnnVadProbabilityParametrization + : public ::testing::TestWithParam {}; + // Checks that the computed VAD probability for a test input sequence sampled at // 48 kHz is within tolerance. -TEST(RnnVadTest, RnnVadProbabilityWithinTolerance) { +TEST_P(RnnVadProbabilityParametrization, RnnVadProbabilityWithinTolerance) { // Init resampler, feature extractor and RNN. PushSincResampler decimator(kFrameSize10ms48kHz, kFrameSize10ms24kHz); - FeaturesExtractor features_extractor; - RnnBasedVad rnn_vad; + const AvailableCpuFeatures cpu_features = GetParam(); + FeaturesExtractor features_extractor(cpu_features); + RnnBasedVad rnn_vad(cpu_features); // Init input samples and expected output readers. auto samples_reader = CreatePcmSamplesReader(kFrameSize10ms48kHz); @@ -111,7 +116,7 @@ TEST(RnnVadTest, RnnVadProbabilityWithinTolerance) { // follows: // - on desktop: run the this unit test adding "--logs"; // - on android: run the this unit test adding "--logcat-output-file". -TEST(RnnVadTest, DISABLED_RnnVadPerformance) { +TEST_P(RnnVadProbabilityParametrization, DISABLED_RnnVadPerformance) { // PCM samples reader and buffers. auto samples_reader = CreatePcmSamplesReader(kFrameSize10ms48kHz); const int num_frames = samples_reader.second; @@ -127,9 +132,10 @@ TEST(RnnVadTest, DISABLED_RnnVadPerformance) { kFrameSize10ms24kHz); } // Initialize. - FeaturesExtractor features_extractor; + const AvailableCpuFeatures cpu_features = GetParam(); + FeaturesExtractor features_extractor(cpu_features); std::array feature_vector; - RnnBasedVad rnn_vad; + RnnBasedVad rnn_vad(cpu_features); constexpr int number_of_tests = 100; ::webrtc::test::PerformanceTimer perf_timer(number_of_tests); for (int k = 0; k < number_of_tests; ++k) { @@ -152,6 +158,27 @@ TEST(RnnVadTest, DISABLED_RnnVadPerformance) { perf_timer.GetDurationStandardDeviation()); } +// Finds the relevant CPU features combinations to test. +std::vector GetCpuFeaturesToTest() { + std::vector v; + v.push_back({/*sse2=*/false, /*avx2=*/false, /*neon=*/false}); + AvailableCpuFeatures available = GetAvailableCpuFeatures(); + if (available.sse2) { + AvailableCpuFeatures features( + {/*sse2=*/true, /*avx2=*/false, /*neon=*/false}); + v.push_back(features); + } + return v; +} + +INSTANTIATE_TEST_SUITE_P( + RnnVadTest, + RnnVadProbabilityParametrization, + ::testing::ValuesIn(GetCpuFeaturesToTest()), + [](const ::testing::TestParamInfo& info) { + return info.param.ToString(); + }); + } // namespace test } // namespace rnn_vad } // namespace webrtc diff --git a/modules/audio_processing/agc2/rnn_vad/test_utils.cc b/modules/audio_processing/agc2/rnn_vad/test_utils.cc index 24bbf13e31..75de1099f2 100644 --- a/modules/audio_processing/agc2/rnn_vad/test_utils.cc +++ b/modules/audio_processing/agc2/rnn_vad/test_utils.cc @@ -15,8 +15,6 @@ #include "rtc_base/checks.h" #include "rtc_base/numerics/safe_compare.h" -#include "rtc_base/system/arch.h" -#include "system_wrappers/include/cpu_features_wrapper.h" #include "test/gtest.h" #include "test/testsupport/file_utils.h" @@ -111,25 +109,6 @@ PitchTestData::GetPitchBufAutoCorrCoeffsView() const { kNumPitchBufAutoCorrCoeffs}; } -bool IsOptimizationAvailable(Optimization optimization) { - switch (optimization) { - case Optimization::kSse2: -#if defined(WEBRTC_ARCH_X86_FAMILY) - return GetCPUInfo(kSSE2) != 0; -#else - return false; -#endif - case Optimization::kNeon: -#if defined(WEBRTC_HAS_NEON) - return true; -#else - return false; -#endif - case Optimization::kNone: - return true; - } -} - } // namespace test } // namespace rnn_vad } // namespace webrtc diff --git a/modules/audio_processing/agc2/rnn_vad/test_utils.h b/modules/audio_processing/agc2/rnn_vad/test_utils.h index 23e642be81..3d1ad259db 100644 --- a/modules/audio_processing/agc2/rnn_vad/test_utils.h +++ b/modules/audio_processing/agc2/rnn_vad/test_utils.h @@ -152,9 +152,6 @@ class PitchTestData { std::array test_data_; }; -// Returns true if the given optimization is available. -bool IsOptimizationAvailable(Optimization optimization); - } // namespace test } // namespace rnn_vad } // namespace webrtc diff --git a/modules/audio_processing/agc2/vad_with_level.cc b/modules/audio_processing/agc2/vad_with_level.cc index 3dbb55732b..da3bd0a3fb 100644 --- a/modules/audio_processing/agc2/vad_with_level.cc +++ b/modules/audio_processing/agc2/vad_with_level.cc @@ -32,7 +32,8 @@ using VoiceActivityDetector = VadLevelAnalyzer::VoiceActivityDetector; // Computes the speech probability on the first channel. class Vad : public VoiceActivityDetector { public: - Vad() = default; + explicit Vad(const AvailableCpuFeatures& cpu_features) + : features_extractor_(cpu_features), rnn_vad_(cpu_features) {} Vad(const Vad&) = delete; Vad& operator=(const Vad&) = delete; ~Vad() = default; @@ -80,10 +81,12 @@ float SmoothedVadProbability(float p_old, float p_new, float attack) { VadLevelAnalyzer::VadLevelAnalyzer() : VadLevelAnalyzer(kDefaultSmoothedVadProbabilityAttack, - std::make_unique()) {} + GetAvailableCpuFeatures()) {} -VadLevelAnalyzer::VadLevelAnalyzer(float vad_probability_attack) - : VadLevelAnalyzer(vad_probability_attack, std::make_unique()) {} +VadLevelAnalyzer::VadLevelAnalyzer(float vad_probability_attack, + const AvailableCpuFeatures& cpu_features) + : VadLevelAnalyzer(vad_probability_attack, + std::make_unique(cpu_features)) {} VadLevelAnalyzer::VadLevelAnalyzer(float vad_probability_attack, std::unique_ptr vad) diff --git a/modules/audio_processing/agc2/vad_with_level.h b/modules/audio_processing/agc2/vad_with_level.h index ce72cdc754..2a6788278e 100644 --- a/modules/audio_processing/agc2/vad_with_level.h +++ b/modules/audio_processing/agc2/vad_with_level.h @@ -13,6 +13,7 @@ #include +#include "modules/audio_processing/agc2/cpu_features.h" #include "modules/audio_processing/include/audio_frame_view.h" namespace webrtc { @@ -36,7 +37,8 @@ class VadLevelAnalyzer { // Ctor. Uses the default VAD. VadLevelAnalyzer(); - explicit VadLevelAnalyzer(float vad_probability_attack); + VadLevelAnalyzer(float vad_probability_attack, + const AvailableCpuFeatures& cpu_features); // Ctor. Uses a custom `vad`. VadLevelAnalyzer(float vad_probability_attack, std::unique_ptr vad); diff --git a/modules/audio_processing/include/audio_processing.h b/modules/audio_processing/include/audio_processing.h index e85ac0c63e..b96ce926a1 100644 --- a/modules/audio_processing/include/audio_processing.h +++ b/modules/audio_processing/include/audio_processing.h @@ -350,10 +350,10 @@ class RTC_EXPORT AudioProcessing : public rtc::RefCountInterface { enum LevelEstimator { kRms, kPeak }; bool enabled = false; - struct { + struct FixedDigital { float gain_db = 0.f; } fixed_digital; - struct { + struct AdaptiveDigital { bool enabled = false; float vad_probability_attack = 1.f; LevelEstimator level_estimator = kRms; @@ -365,6 +365,7 @@ class RTC_EXPORT AudioProcessing : public rtc::RefCountInterface { int gain_applier_adjacent_speech_frames_threshold = 1; float max_gain_change_db_per_second = 3.f; float max_output_noise_level_dbfs = -50.f; + bool avx2_allowed = true; } adaptive_digital; } gain_controller2;