From 253f8369bbf492d62900252d08688810d6a891fc Mon Sep 17 00:00:00 2001 From: Alessio Bazzica Date: Fri, 27 Nov 2020 16:02:38 +0100 Subject: [PATCH] AGC2 RNN VAD: safe SIMD optimizations scheme + AVX2 kill switch MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit In preparation for adding AVX2 code, a safe scheme to support different SIMD optimizations is added. Safety features: - AVX2 kill switch to stop using it even if supported by the architecture - struct indicating the available CPU features propagated from AGC2 to each component; in this way - better control over the unit tests - no need to propagate individual kill switches but just set to false features that are turned off Note that (i) this CL does not change the performance of the RNN VAD and (ii) no AVX2 optimization is added yet. Bug: webrtc:10480 Change-Id: I0e61f3311ecd140f38369cf68b6e5954f3dc1f5a Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/193140 Reviewed-by: Per Ã…hgren Commit-Queue: Alessio Bazzica Cr-Commit-Position: refs/heads/master@{#32739} --- modules/audio_processing/agc2/BUILD.gn | 15 ++ modules/audio_processing/agc2/adaptive_agc.cc | 13 +- modules/audio_processing/agc2/cpu_features.cc | 54 +++++ modules/audio_processing/agc2/cpu_features.h | 36 +++ .../audio_processing/agc2/rnn_vad/BUILD.gn | 14 +- .../audio_processing/agc2/rnn_vad/common.cc | 34 --- .../audio_processing/agc2/rnn_vad/common.h | 5 - .../agc2/rnn_vad/features_extraction.cc | 4 +- .../agc2/rnn_vad/features_extraction.h | 2 +- .../rnn_vad/features_extraction_unittest.cc | 4 +- .../agc2/rnn_vad/pitch_search.cc | 6 +- .../agc2/rnn_vad/pitch_search.h | 4 +- .../agc2/rnn_vad/pitch_search_unittest.cc | 4 +- modules/audio_processing/agc2/rnn_vad/rnn.cc | 70 ++---- modules/audio_processing/agc2/rnn_vad/rnn.h | 13 +- .../agc2/rnn_vad/rnn_unittest.cc | 208 +++++++----------- .../agc2/rnn_vad/rnn_vad_tool.cc | 6 +- .../agc2/rnn_vad/rnn_vad_unittest.cc | 39 +++- .../agc2/rnn_vad/test_utils.cc | 21 -- .../agc2/rnn_vad/test_utils.h | 3 - .../audio_processing/agc2/vad_with_level.cc | 11 +- .../audio_processing/agc2/vad_with_level.h | 4 +- .../include/audio_processing.h | 5 +- 23 files changed, 296 insertions(+), 279 deletions(-) create mode 100644 modules/audio_processing/agc2/cpu_features.cc create mode 100644 modules/audio_processing/agc2/cpu_features.h delete mode 100644 modules/audio_processing/agc2/rnn_vad/common.cc diff --git a/modules/audio_processing/agc2/BUILD.gn b/modules/audio_processing/agc2/BUILD.gn index bf09533737..c6667df420 100644 --- a/modules/audio_processing/agc2/BUILD.gn +++ b/modules/audio_processing/agc2/BUILD.gn @@ -56,6 +56,7 @@ rtc_library("adaptive_digital") { deps = [ ":common", + ":cpu_features", ":gain_applier", ":noise_level_estimator", ":rnn_vad_with_level", @@ -163,6 +164,7 @@ rtc_library("rnn_vad_with_level") { ] deps = [ ":common", + ":cpu_features", "..:audio_frame_view", "../../../api:array_view", "../../../common_audio", @@ -172,6 +174,19 @@ rtc_library("rnn_vad_with_level") { ] } +rtc_library("cpu_features") { + sources = [ + "cpu_features.cc", + "cpu_features.h", + ] + visibility = [ "./*" ] + deps = [ + "../../../rtc_base:stringutils", + "../../../rtc_base/system:arch", + "../../../system_wrappers", + ] +} + rtc_library("adaptive_digital_unittests") { testonly = true configs += [ "..:apm_debug_dump" ] diff --git a/modules/audio_processing/agc2/adaptive_agc.cc b/modules/audio_processing/agc2/adaptive_agc.cc index 0372ccf38a..4df3b58e9d 100644 --- a/modules/audio_processing/agc2/adaptive_agc.cc +++ b/modules/audio_processing/agc2/adaptive_agc.cc @@ -11,6 +11,7 @@ #include "modules/audio_processing/agc2/adaptive_agc.h" #include "common_audio/include/audio_util.h" +#include "modules/audio_processing/agc2/cpu_features.h" #include "modules/audio_processing/agc2/vad_with_level.h" #include "modules/audio_processing/logging/apm_data_dumper.h" #include "rtc_base/checks.h" @@ -32,6 +33,15 @@ constexpr int kGainApplierAdjacentSpeechFramesThreshold = 1; constexpr float kMaxGainChangePerSecondDb = 3.f; constexpr float kMaxOutputNoiseLevelDbfs = -50.f; +// Detects the available CPU features and applies a kill-switch to AVX2. +AvailableCpuFeatures GetAllowedCpuFeatures(bool avx2_allowed) { + AvailableCpuFeatures features = GetAvailableCpuFeatures(); + if (!avx2_allowed) { + features.avx2 = false; + } + return features; +} + } // namespace AdaptiveAgc::AdaptiveAgc(ApmDataDumper* apm_data_dumper) @@ -54,7 +64,8 @@ AdaptiveAgc::AdaptiveAgc(ApmDataDumper* apm_data_dumper, .level_estimator_adjacent_speech_frames_threshold, config.adaptive_digital.initial_saturation_margin_db, config.adaptive_digital.extra_saturation_margin_db), - vad_(config.adaptive_digital.vad_probability_attack), + vad_(config.adaptive_digital.vad_probability_attack, + GetAllowedCpuFeatures(config.adaptive_digital.avx2_allowed)), gain_applier_( apm_data_dumper, config.adaptive_digital.gain_applier_adjacent_speech_frames_threshold, diff --git a/modules/audio_processing/agc2/cpu_features.cc b/modules/audio_processing/agc2/cpu_features.cc new file mode 100644 index 0000000000..b4f377ffba --- /dev/null +++ b/modules/audio_processing/agc2/cpu_features.cc @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2020 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/agc2/cpu_features.h" + +#include "rtc_base/strings/string_builder.h" +#include "rtc_base/system/arch.h" +#include "system_wrappers/include/cpu_features_wrapper.h" + +namespace webrtc { + +std::string AvailableCpuFeatures::ToString() const { + char buf[64]; + rtc::SimpleStringBuilder builder(buf); + bool first = true; + if (sse2) { + builder << (first ? "SSE2" : "_SSE2"); + first = false; + } + if (avx2) { + builder << (first ? "AVX2" : "_AVX2"); + first = false; + } + if (neon) { + builder << (first ? "NEON" : "_NEON"); + first = false; + } + if (first) { + return "none"; + } + return builder.str(); +} + +// Detects available CPU features. +AvailableCpuFeatures GetAvailableCpuFeatures() { +#if defined(WEBRTC_ARCH_X86_FAMILY) + return {/*sse2=*/GetCPUInfo(kSSE2) != 0, + /*avx2=*/GetCPUInfo(kAVX2) != 0, + /*neon=*/false}; +#elif defined(WEBRTC_HAS_NEON) + return {/*sse2=*/false, + /*avx2=*/false, + /*neon=*/true}; +#endif +} + +} // namespace webrtc diff --git a/modules/audio_processing/agc2/cpu_features.h b/modules/audio_processing/agc2/cpu_features.h new file mode 100644 index 0000000000..bf73c3e562 --- /dev/null +++ b/modules/audio_processing/agc2/cpu_features.h @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2020 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AGC2_CPU_FEATURES_H_ +#define MODULES_AUDIO_PROCESSING_AGC2_CPU_FEATURES_H_ + +#include + +namespace webrtc { + +// Collection of flags indicating which CPU features are available on the +// current platform. True means available. +struct AvailableCpuFeatures { + AvailableCpuFeatures(bool sse2, bool avx2, bool neon) + : sse2(sse2), avx2(avx2), neon(neon) {} + // Intel. + bool sse2; + bool avx2; + // ARM. + bool neon; + std::string ToString() const; +}; + +// Detects what CPU features are available. +AvailableCpuFeatures GetAvailableCpuFeatures(); + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AGC2_CPU_FEATURES_H_ diff --git a/modules/audio_processing/agc2/rnn_vad/BUILD.gn b/modules/audio_processing/agc2/rnn_vad/BUILD.gn index fae1d5a572..a4285bab5a 100644 --- a/modules/audio_processing/agc2/rnn_vad/BUILD.gn +++ b/modules/audio_processing/agc2/rnn_vad/BUILD.gn @@ -29,10 +29,10 @@ rtc_library("rnn_vad") { ":rnn_vad_sequence_buffer", ":rnn_vad_spectral_features", "..:biquad_filter", + "..:cpu_features", "../../../../api:array_view", "../../../../api:function_view", "../../../../rtc_base:checks", - "../../../../rtc_base:logging", "../../../../rtc_base:safe_compare", "../../../../rtc_base:safe_conversions", "../../../../rtc_base/system:arch", @@ -53,16 +53,13 @@ rtc_library("rnn_vad_auto_correlation") { ] } -rtc_library("rnn_vad_common") { +rtc_source_set("rnn_vad_common") { # TODO(alessiob): Make this target visibility private. visibility = [ ":*", "..:rnn_vad_with_level", ] - sources = [ - "common.cc", - "common.h", - ] + sources = [ "common.h" ] deps = [ "../../../../rtc_base/system:arch", "../../../../system_wrappers", @@ -91,6 +88,7 @@ rtc_library("rnn_vad_pitch") { deps = [ ":rnn_vad_auto_correlation", ":rnn_vad_common", + "..:cpu_features", "../../../../api:array_view", "../../../../rtc_base:checks", "../../../../rtc_base:gtest_prod", @@ -156,8 +154,6 @@ if (rtc_include_tests) { "../../../../api:scoped_refptr", "../../../../rtc_base:checks", "../../../../rtc_base:safe_compare", - "../../../../rtc_base/system:arch", - "../../../../system_wrappers", "../../../../test:fileutils", "../../../../test:test_support", ] @@ -207,6 +203,7 @@ if (rtc_include_tests) { ":rnn_vad_spectral_features", ":rnn_vad_symmetric_matrix_buffer", ":test_utils", + "..:cpu_features", "../..:audioproc_test_utils", "../../../../api:array_view", "../../../../common_audio/", @@ -232,6 +229,7 @@ if (rtc_include_tests) { deps = [ ":rnn_vad", ":rnn_vad_common", + "..:cpu_features", "../../../../api:array_view", "../../../../common_audio", "../../../../rtc_base:rtc_base_approved", diff --git a/modules/audio_processing/agc2/rnn_vad/common.cc b/modules/audio_processing/agc2/rnn_vad/common.cc deleted file mode 100644 index 5d76b52e57..0000000000 --- a/modules/audio_processing/agc2/rnn_vad/common.cc +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright (c) 2019 The WebRTC project authors. All Rights Reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. - */ - -#include "modules/audio_processing/agc2/rnn_vad/common.h" - -#include "rtc_base/system/arch.h" -#include "system_wrappers/include/cpu_features_wrapper.h" - -namespace webrtc { -namespace rnn_vad { - -Optimization DetectOptimization() { -#if defined(WEBRTC_ARCH_X86_FAMILY) - if (GetCPUInfo(kSSE2) != 0) { - return Optimization::kSse2; - } -#endif - -#if defined(WEBRTC_HAS_NEON) - return Optimization::kNeon; -#endif - - return Optimization::kNone; -} - -} // namespace rnn_vad -} // namespace webrtc diff --git a/modules/audio_processing/agc2/rnn_vad/common.h b/modules/audio_processing/agc2/rnn_vad/common.h index 36b366ad1d..be5a2d58ce 100644 --- a/modules/audio_processing/agc2/rnn_vad/common.h +++ b/modules/audio_processing/agc2/rnn_vad/common.h @@ -71,11 +71,6 @@ static_assert(kCepstralCoeffsHistorySize > 2, constexpr int kFeatureVectorSize = 42; -enum class Optimization { kNone, kSse2, kNeon }; - -// Detects what kind of optimizations to use for the code. -Optimization DetectOptimization(); - } // namespace rnn_vad } // namespace webrtc diff --git a/modules/audio_processing/agc2/rnn_vad/features_extraction.cc b/modules/audio_processing/agc2/rnn_vad/features_extraction.cc index cdbbbc311d..f86eba764e 100644 --- a/modules/audio_processing/agc2/rnn_vad/features_extraction.cc +++ b/modules/audio_processing/agc2/rnn_vad/features_extraction.cc @@ -26,13 +26,13 @@ const BiQuadFilter::BiQuadCoefficients kHpfConfig24k = { } // namespace -FeaturesExtractor::FeaturesExtractor() +FeaturesExtractor::FeaturesExtractor(const AvailableCpuFeatures& cpu_features) : use_high_pass_filter_(false), pitch_buf_24kHz_(), pitch_buf_24kHz_view_(pitch_buf_24kHz_.GetBufferView()), lp_residual_(kBufSize24kHz), lp_residual_view_(lp_residual_.data(), kBufSize24kHz), - pitch_estimator_(), + pitch_estimator_(cpu_features), reference_frame_view_(pitch_buf_24kHz_.GetMostRecentValuesView()) { RTC_DCHECK_EQ(kBufSize24kHz, lp_residual_.size()); hpf_.Initialize(kHpfConfig24k); diff --git a/modules/audio_processing/agc2/rnn_vad/features_extraction.h b/modules/audio_processing/agc2/rnn_vad/features_extraction.h index e2c77d2cf8..f4cea7a83d 100644 --- a/modules/audio_processing/agc2/rnn_vad/features_extraction.h +++ b/modules/audio_processing/agc2/rnn_vad/features_extraction.h @@ -26,7 +26,7 @@ namespace rnn_vad { // Feature extractor to feed the VAD RNN. class FeaturesExtractor { public: - FeaturesExtractor(); + explicit FeaturesExtractor(const AvailableCpuFeatures& cpu_features); FeaturesExtractor(const FeaturesExtractor&) = delete; FeaturesExtractor& operator=(const FeaturesExtractor&) = delete; ~FeaturesExtractor(); diff --git a/modules/audio_processing/agc2/rnn_vad/features_extraction_unittest.cc b/modules/audio_processing/agc2/rnn_vad/features_extraction_unittest.cc index 9df52738b4..0da971e3da 100644 --- a/modules/audio_processing/agc2/rnn_vad/features_extraction_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/features_extraction_unittest.cc @@ -13,6 +13,7 @@ #include #include +#include "modules/audio_processing/agc2/cpu_features.h" #include "modules/audio_processing/agc2/rnn_vad/test_utils.h" #include "rtc_base/numerics/safe_compare.h" #include "rtc_base/numerics/safe_conversions.h" @@ -77,7 +78,8 @@ TEST(RnnVadTest, FeatureExtractionLowHighPitch) { ASSERT_TRUE(PitchIsValid(low_pitch_hz)); ASSERT_TRUE(PitchIsValid(high_pitch_hz)); - FeaturesExtractor features_extractor; + const AvailableCpuFeatures cpu_features = GetAvailableCpuFeatures(); + FeaturesExtractor features_extractor(cpu_features); std::vector samples(kNumTestDataSize); std::vector feature_vector(kFeatureVectorSize); ASSERT_EQ(kFeatureVectorSize, rtc::dchecked_cast(feature_vector.size())); diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search.cc index c6c3e1b2b5..c2e7665967 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search.cc +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search.cc @@ -18,8 +18,9 @@ namespace webrtc { namespace rnn_vad { -PitchEstimator::PitchEstimator() - : y_energy_24kHz_(kRefineNumLags24kHz, 0.f), +PitchEstimator::PitchEstimator(const AvailableCpuFeatures& cpu_features) + : cpu_features_(cpu_features), + y_energy_24kHz_(kRefineNumLags24kHz, 0.f), pitch_buffer_12kHz_(kBufSize12kHz), auto_correlation_12kHz_(kNumLags12kHz) {} @@ -35,6 +36,7 @@ int PitchEstimator::Estimate( RTC_DCHECK_EQ(auto_correlation_12kHz_.size(), auto_correlation_12kHz_view.size()); + // TODO(bugs.chromium.org/10480): Use `cpu_features_` to estimate pitch. // Perform the initial pitch search at 12 kHz. Decimate2x(pitch_buffer, pitch_buffer_12kHz_view); auto_corr_calculator_.ComputeOnPitchBuffer(pitch_buffer_12kHz_view, diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search.h b/modules/audio_processing/agc2/rnn_vad/pitch_search.h index e96a2dcaf1..42c448eb56 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search.h +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search.h @@ -15,6 +15,7 @@ #include #include "api/array_view.h" +#include "modules/audio_processing/agc2/cpu_features.h" #include "modules/audio_processing/agc2/rnn_vad/auto_correlation.h" #include "modules/audio_processing/agc2/rnn_vad/common.h" #include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h" @@ -26,7 +27,7 @@ namespace rnn_vad { // Pitch estimator. class PitchEstimator { public: - PitchEstimator(); + explicit PitchEstimator(const AvailableCpuFeatures& cpu_features); PitchEstimator(const PitchEstimator&) = delete; PitchEstimator& operator=(const PitchEstimator&) = delete; ~PitchEstimator(); @@ -39,6 +40,7 @@ class PitchEstimator { return last_pitch_48kHz_.strength; } + const AvailableCpuFeatures cpu_features_; PitchInfo last_pitch_48kHz_{}; AutoCorrelationCalculator auto_corr_calculator_; std::vector y_energy_24kHz_; diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc index 98b791e872..fe9be5dbba 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc @@ -13,6 +13,7 @@ #include #include +#include "modules/audio_processing/agc2/cpu_features.h" #include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h" #include "modules/audio_processing/agc2/rnn_vad/test_utils.h" // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. @@ -29,7 +30,8 @@ TEST(RnnVadTest, PitchSearchWithinTolerance) { const int num_frames = std::min(lp_residual_reader.second, 300); // Max 3 s. std::vector lp_residual(kBufSize24kHz); float expected_pitch_period, expected_pitch_strength; - PitchEstimator pitch_estimator; + const AvailableCpuFeatures cpu_features = GetAvailableCpuFeatures(); + PitchEstimator pitch_estimator(cpu_features); { // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. // FloatingPointExceptionObserver fpe_observer; diff --git a/modules/audio_processing/agc2/rnn_vad/rnn.cc b/modules/audio_processing/agc2/rnn_vad/rnn.cc index 2072a6854d..fb4962f724 100644 --- a/modules/audio_processing/agc2/rnn_vad/rnn.cc +++ b/modules/audio_processing/agc2/rnn_vad/rnn.cc @@ -25,7 +25,6 @@ #include #include "rtc_base/checks.h" -#include "rtc_base/logging.h" #include "rtc_base/numerics/safe_conversions.h" #include "third_party/rnnoise/src/rnn_activations.h" #include "third_party/rnnoise/src/rnn_vad_weights.h" @@ -273,13 +272,13 @@ FullyConnectedLayer::FullyConnectedLayer( const rtc::ArrayView bias, const rtc::ArrayView weights, rtc::FunctionView activation_function, - Optimization optimization) + const AvailableCpuFeatures& cpu_features) : input_size_(input_size), output_size_(output_size), bias_(GetScaledParams(bias)), weights_(GetPreprocessedFcWeights(weights, output_size)), activation_function_(activation_function), - optimization_(optimization) { + cpu_features_(cpu_features) { RTC_DCHECK_LE(output_size_, kFullyConnectedLayersMaxUnits) << "Static over-allocation of fully-connected layers output vectors is " "not sufficient."; @@ -296,25 +295,18 @@ rtc::ArrayView FullyConnectedLayer::GetOutput() const { } void FullyConnectedLayer::ComputeOutput(rtc::ArrayView input) { - switch (optimization_) { #if defined(WEBRTC_ARCH_X86_FAMILY) - case Optimization::kSse2: - ComputeFullyConnectedLayerOutputSse2(input_size_, output_size_, input, - bias_, weights_, - activation_function_, output_); - break; -#endif -#if defined(WEBRTC_HAS_NEON) - case Optimization::kNeon: - // TODO(bugs.chromium.org/10480): Handle Optimization::kNeon. - ComputeFullyConnectedLayerOutput(input_size_, output_size_, input, bias_, - weights_, activation_function_, output_); - break; -#endif - default: - ComputeFullyConnectedLayerOutput(input_size_, output_size_, input, bias_, - weights_, activation_function_, output_); + // TODO(bugs.chromium.org/10480): Add AVX2. + if (cpu_features_.sse2) { + ComputeFullyConnectedLayerOutputSse2(input_size_, output_size_, input, + bias_, weights_, activation_function_, + output_); + return; } +#endif + // TODO(bugs.chromium.org/10480): Add Neon. + ComputeFullyConnectedLayerOutput(input_size_, output_size_, input, bias_, + weights_, activation_function_, output_); } GatedRecurrentLayer::GatedRecurrentLayer( @@ -322,15 +314,13 @@ GatedRecurrentLayer::GatedRecurrentLayer( const int output_size, const rtc::ArrayView bias, const rtc::ArrayView weights, - const rtc::ArrayView recurrent_weights, - Optimization optimization) + const rtc::ArrayView recurrent_weights) : input_size_(input_size), output_size_(output_size), bias_(GetPreprocessedGruTensor(bias, output_size)), weights_(GetPreprocessedGruTensor(weights, output_size)), recurrent_weights_( - GetPreprocessedGruTensor(recurrent_weights, output_size)), - optimization_(optimization) { + GetPreprocessedGruTensor(recurrent_weights, output_size)) { RTC_DCHECK_LE(output_size_, kRecurrentLayersMaxUnits) << "Static over-allocation of recurrent layers state vectors is not " "sufficient."; @@ -356,46 +346,30 @@ void GatedRecurrentLayer::Reset() { } void GatedRecurrentLayer::ComputeOutput(rtc::ArrayView input) { - switch (optimization_) { -#if defined(WEBRTC_ARCH_X86_FAMILY) - case Optimization::kSse2: - // TODO(bugs.chromium.org/10480): Handle Optimization::kSse2. - ComputeGruLayerOutput(input_size_, output_size_, input, weights_, - recurrent_weights_, bias_, state_); - break; -#endif -#if defined(WEBRTC_HAS_NEON) - case Optimization::kNeon: - // TODO(bugs.chromium.org/10480): Handle Optimization::kNeon. - ComputeGruLayerOutput(input_size_, output_size_, input, weights_, - recurrent_weights_, bias_, state_); - break; -#endif - default: - ComputeGruLayerOutput(input_size_, output_size_, input, weights_, - recurrent_weights_, bias_, state_); - } + // TODO(bugs.chromium.org/10480): Add AVX2. + // TODO(bugs.chromium.org/10480): Add Neon. + ComputeGruLayerOutput(input_size_, output_size_, input, weights_, + recurrent_weights_, bias_, state_); } -RnnBasedVad::RnnBasedVad() +RnnBasedVad::RnnBasedVad(const AvailableCpuFeatures& cpu_features) : input_layer_(kInputLayerInputSize, kInputLayerOutputSize, kInputDenseBias, kInputDenseWeights, TansigApproximated, - DetectOptimization()), + cpu_features), hidden_layer_(kInputLayerOutputSize, kHiddenLayerOutputSize, kHiddenGruBias, kHiddenGruWeights, - kHiddenGruRecurrentWeights, - DetectOptimization()), + kHiddenGruRecurrentWeights), output_layer_(kHiddenLayerOutputSize, kOutputLayerOutputSize, kOutputDenseBias, kOutputDenseWeights, SigmoidApproximated, - DetectOptimization()) { + cpu_features) { // Input-output chaining size checks. RTC_DCHECK_EQ(input_layer_.output_size(), hidden_layer_.input_size()) << "The input and the hidden layers sizes do not match."; diff --git a/modules/audio_processing/agc2/rnn_vad/rnn.h b/modules/audio_processing/agc2/rnn_vad/rnn.h index 5b44f53047..1ef4c76c21 100644 --- a/modules/audio_processing/agc2/rnn_vad/rnn.h +++ b/modules/audio_processing/agc2/rnn_vad/rnn.h @@ -19,6 +19,7 @@ #include "api/array_view.h" #include "api/function_view.h" +#include "modules/audio_processing/agc2/cpu_features.h" #include "modules/audio_processing/agc2/rnn_vad/common.h" #include "rtc_base/system/arch.h" @@ -45,13 +46,12 @@ class FullyConnectedLayer { rtc::ArrayView bias, rtc::ArrayView weights, rtc::FunctionView activation_function, - Optimization optimization); + const AvailableCpuFeatures& cpu_features); FullyConnectedLayer(const FullyConnectedLayer&) = delete; FullyConnectedLayer& operator=(const FullyConnectedLayer&) = delete; ~FullyConnectedLayer(); int input_size() const { return input_size_; } int output_size() const { return output_size_; } - Optimization optimization() const { return optimization_; } rtc::ArrayView GetOutput() const; // Computes the fully-connected layer output. void ComputeOutput(rtc::ArrayView input); @@ -65,7 +65,7 @@ class FullyConnectedLayer { // The output vector of a recurrent layer has length equal to |output_size_|. // However, for efficiency, over-allocation is used. std::array output_; - const Optimization optimization_; + const AvailableCpuFeatures cpu_features_; }; // Recurrent layer with gated recurrent units (GRUs) with sigmoid and ReLU as @@ -76,14 +76,12 @@ class GatedRecurrentLayer { int output_size, rtc::ArrayView bias, rtc::ArrayView weights, - rtc::ArrayView recurrent_weights, - Optimization optimization); + rtc::ArrayView recurrent_weights); GatedRecurrentLayer(const GatedRecurrentLayer&) = delete; GatedRecurrentLayer& operator=(const GatedRecurrentLayer&) = delete; ~GatedRecurrentLayer(); int input_size() const { return input_size_; } int output_size() const { return output_size_; } - Optimization optimization() const { return optimization_; } rtc::ArrayView GetOutput() const; void Reset(); // Computes the recurrent layer output and updates the status. @@ -98,13 +96,12 @@ class GatedRecurrentLayer { // The state vector of a recurrent layer has length equal to |output_size_|. // However, to avoid dynamic allocation, over-allocation is used. std::array state_; - const Optimization optimization_; }; // Recurrent network based VAD. class RnnBasedVad { public: - RnnBasedVad(); + explicit RnnBasedVad(const AvailableCpuFeatures& cpu_features); RnnBasedVad(const RnnBasedVad&) = delete; RnnBasedVad& operator=(const RnnBasedVad&) = delete; ~RnnBasedVad(); diff --git a/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc b/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc index a57a899c8d..2e920e8d80 100644 --- a/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc @@ -14,6 +14,7 @@ #include #include +#include "modules/audio_processing/agc2/cpu_features.h" #include "modules/audio_processing/agc2/rnn_vad/test_utils.h" #include "modules/audio_processing/test/performance_timer.h" #include "rtc_base/checks.h" @@ -27,7 +28,6 @@ namespace webrtc { namespace rnn_vad { namespace test { - namespace { void TestFullyConnectedLayer(FullyConnectedLayer* fc, @@ -39,26 +39,25 @@ void TestFullyConnectedLayer(FullyConnectedLayer* fc, } void TestGatedRecurrentLayer( - GatedRecurrentLayer* gru, + GatedRecurrentLayer& gru, rtc::ArrayView input_sequence, rtc::ArrayView expected_output_sequence) { - RTC_CHECK(gru); - auto gru_output_view = gru->GetOutput(); + auto gru_output_view = gru.GetOutput(); const int input_sequence_length = rtc::CheckedDivExact( - rtc::dchecked_cast(input_sequence.size()), gru->input_size()); + rtc::dchecked_cast(input_sequence.size()), gru.input_size()); const int output_sequence_length = rtc::CheckedDivExact( rtc::dchecked_cast(expected_output_sequence.size()), - gru->output_size()); + gru.output_size()); ASSERT_EQ(input_sequence_length, output_sequence_length) << "The test data length is invalid."; // Feed the GRU layer and check the output at every step. - gru->Reset(); + gru.Reset(); for (int i = 0; i < input_sequence_length; ++i) { SCOPED_TRACE(i); - gru->ComputeOutput( - input_sequence.subview(i * gru->input_size(), gru->input_size())); + gru.ComputeOutput( + input_sequence.subview(i * gru.input_size(), gru.input_size())); const auto expected_output = expected_output_sequence.subview( - i * gru->output_size(), gru->output_size()); + i * gru.output_size(), gru.output_size()); ExpectNearAbsolute(expected_output, gru_output_view, 3e-6f); } } @@ -134,141 +133,94 @@ constexpr std::array kGruExpectedOutputSequence = { 0.00781069f, 0.75267816f, 0.f, 0.02579715f, 0.00471378f, 0.59162533f, 0.11087593f, 0.01334511f}; -std::string GetOptimizationName(Optimization optimization) { - switch (optimization) { - case Optimization::kSse2: - return "SSE2"; - case Optimization::kNeon: - return "NEON"; - case Optimization::kNone: - return "none"; - } -} - -struct Result { - Optimization optimization; - double average_us; - double std_dev_us; -}; - -} // namespace - -// Checks that the output of a fully connected layer is within tolerance given -// test input data. -TEST(RnnVadTest, CheckFullyConnectedLayerOutput) { - FullyConnectedLayer fc(rnnoise::kInputLayerInputSize, - rnnoise::kInputLayerOutputSize, - rnnoise::kInputDenseBias, rnnoise::kInputDenseWeights, - rnnoise::TansigApproximated, Optimization::kNone); - TestFullyConnectedLayer(&fc, kFullyConnectedInputVector, - kFullyConnectedExpectedOutput); -} - // Checks that the output of a GRU layer is within tolerance given test input // data. TEST(RnnVadTest, CheckGatedRecurrentLayer) { GatedRecurrentLayer gru(kGruInputSize, kGruOutputSize, kGruBias, kGruWeights, - kGruRecurrentWeights, Optimization::kNone); - TestGatedRecurrentLayer(&gru, kGruInputSequence, kGruExpectedOutputSequence); -} - -#if defined(WEBRTC_ARCH_X86_FAMILY) - -// Like CheckFullyConnectedLayerOutput, but testing the SSE2 implementation. -TEST(RnnVadTest, CheckFullyConnectedLayerOutputSse2) { - if (!IsOptimizationAvailable(Optimization::kSse2)) { - return; - } - - FullyConnectedLayer fc(rnnoise::kInputLayerInputSize, - rnnoise::kInputLayerOutputSize, - rnnoise::kInputDenseBias, rnnoise::kInputDenseWeights, - rnnoise::TansigApproximated, Optimization::kSse2); - TestFullyConnectedLayer(&fc, kFullyConnectedInputVector, - kFullyConnectedExpectedOutput); -} - -// Like CheckGatedRecurrentLayer, but testing the SSE2 implementation. -TEST(RnnVadTest, CheckGatedRecurrentLayerSse2) { - if (!IsOptimizationAvailable(Optimization::kSse2)) { - return; - } - - GatedRecurrentLayer gru(kGruInputSize, kGruOutputSize, kGruBias, kGruWeights, - kGruRecurrentWeights, Optimization::kSse2); - TestGatedRecurrentLayer(&gru, kGruInputSequence, kGruExpectedOutputSequence); -} - -#endif // WEBRTC_ARCH_X86_FAMILY - -TEST(RnnVadTest, DISABLED_BenchmarkFullyConnectedLayer) { - std::vector> implementations; - implementations.emplace_back(std::make_unique( - rnnoise::kInputLayerInputSize, rnnoise::kInputLayerOutputSize, - rnnoise::kInputDenseBias, rnnoise::kInputDenseWeights, - rnnoise::TansigApproximated, Optimization::kNone)); - if (IsOptimizationAvailable(Optimization::kSse2)) { - implementations.emplace_back(std::make_unique( - rnnoise::kInputLayerInputSize, rnnoise::kInputLayerOutputSize, - rnnoise::kInputDenseBias, rnnoise::kInputDenseWeights, - rnnoise::TansigApproximated, Optimization::kSse2)); - } - - std::vector results; - constexpr int number_of_tests = 10000; - for (auto& fc : implementations) { - ::webrtc::test::PerformanceTimer perf_timer(number_of_tests); - for (int k = 0; k < number_of_tests; ++k) { - perf_timer.StartTimer(); - fc->ComputeOutput(kFullyConnectedInputVector); - perf_timer.StopTimer(); - } - results.push_back({fc->optimization(), perf_timer.GetDurationAverage(), - perf_timer.GetDurationStandardDeviation()}); - } - - for (const auto& result : results) { - RTC_LOG(LS_INFO) << GetOptimizationName(result.optimization) << ": " - << (result.average_us / 1e3) << " +/- " - << (result.std_dev_us / 1e3) << " ms"; - } + kGruRecurrentWeights); + TestGatedRecurrentLayer(gru, kGruInputSequence, kGruExpectedOutputSequence); } TEST(RnnVadTest, DISABLED_BenchmarkGatedRecurrentLayer) { - std::vector> implementations; - implementations.emplace_back(std::make_unique( - kGruInputSize, kGruOutputSize, kGruBias, kGruWeights, - kGruRecurrentWeights, Optimization::kNone)); + GatedRecurrentLayer gru(kGruInputSize, kGruOutputSize, kGruBias, kGruWeights, + kGruRecurrentWeights); rtc::ArrayView input_sequence(kGruInputSequence); static_assert(kGruInputSequence.size() % kGruInputSize == 0, ""); constexpr int input_sequence_length = kGruInputSequence.size() / kGruInputSize; - std::vector results; - constexpr int number_of_tests = 10000; - for (auto& gru : implementations) { - ::webrtc::test::PerformanceTimer perf_timer(number_of_tests); - gru->Reset(); - for (int k = 0; k < number_of_tests; ++k) { - perf_timer.StartTimer(); - for (int i = 0; i < input_sequence_length; ++i) { - gru->ComputeOutput( - input_sequence.subview(i * gru->input_size(), gru->input_size())); - } - perf_timer.StopTimer(); + constexpr int kNumTests = 10000; + ::webrtc::test::PerformanceTimer perf_timer(kNumTests); + for (int k = 0; k < kNumTests; ++k) { + perf_timer.StartTimer(); + for (int i = 0; i < input_sequence_length; ++i) { + gru.ComputeOutput( + input_sequence.subview(i * gru.input_size(), gru.input_size())); } - results.push_back({gru->optimization(), perf_timer.GetDurationAverage(), - perf_timer.GetDurationStandardDeviation()}); - } - - for (const auto& result : results) { - RTC_LOG(LS_INFO) << GetOptimizationName(result.optimization) << ": " - << (result.average_us / 1e3) << " +/- " - << (result.std_dev_us / 1e3) << " ms"; + perf_timer.StopTimer(); } + RTC_LOG(LS_INFO) << (perf_timer.GetDurationAverage() / 1000) << " +/- " + << (perf_timer.GetDurationStandardDeviation() / 1000) + << " ms"; } +class RnnParametrization + : public ::testing::TestWithParam {}; + +// Checks that the output of a fully connected layer is within tolerance given +// test input data. +TEST_P(RnnParametrization, CheckFullyConnectedLayerOutput) { + FullyConnectedLayer fc( + rnnoise::kInputLayerInputSize, rnnoise::kInputLayerOutputSize, + rnnoise::kInputDenseBias, rnnoise::kInputDenseWeights, + rnnoise::TansigApproximated, /*cpu_features=*/GetParam()); + TestFullyConnectedLayer(&fc, kFullyConnectedInputVector, + kFullyConnectedExpectedOutput); +} + +TEST_P(RnnParametrization, DISABLED_BenchmarkFullyConnectedLayer) { + const AvailableCpuFeatures cpu_features = GetParam(); + FullyConnectedLayer fc(rnnoise::kInputLayerInputSize, + rnnoise::kInputLayerOutputSize, + rnnoise::kInputDenseBias, rnnoise::kInputDenseWeights, + rnnoise::TansigApproximated, cpu_features); + + constexpr int kNumTests = 10000; + ::webrtc::test::PerformanceTimer perf_timer(kNumTests); + for (int k = 0; k < kNumTests; ++k) { + perf_timer.StartTimer(); + fc.ComputeOutput(kFullyConnectedInputVector); + perf_timer.StopTimer(); + } + RTC_LOG(LS_INFO) << "CPU features: " << cpu_features.ToString() << " | " + << (perf_timer.GetDurationAverage() / 1000) << " +/- " + << (perf_timer.GetDurationStandardDeviation() / 1000) + << " ms"; +} + +// Finds the relevant CPU features combinations to test. +std::vector GetCpuFeaturesToTest() { + std::vector v; + v.push_back({/*sse2=*/false, /*avx2=*/false, /*neon=*/false}); + AvailableCpuFeatures available = GetAvailableCpuFeatures(); + if (available.sse2) { + AvailableCpuFeatures features( + {/*sse2=*/true, /*avx2=*/false, /*neon=*/false}); + v.push_back(features); + } + return v; +} + +INSTANTIATE_TEST_SUITE_P( + RnnVadTest, + RnnParametrization, + ::testing::ValuesIn(GetCpuFeaturesToTest()), + [](const ::testing::TestParamInfo& info) { + return info.param.ToString(); + }); + +} // namespace } // namespace test } // namespace rnn_vad } // namespace webrtc diff --git a/modules/audio_processing/agc2/rnn_vad/rnn_vad_tool.cc b/modules/audio_processing/agc2/rnn_vad/rnn_vad_tool.cc index 8b12b60c55..0f3ad5ce16 100644 --- a/modules/audio_processing/agc2/rnn_vad/rnn_vad_tool.cc +++ b/modules/audio_processing/agc2/rnn_vad/rnn_vad_tool.cc @@ -16,6 +16,7 @@ #include "absl/flags/parse.h" #include "common_audio/resampler/push_sinc_resampler.h" #include "common_audio/wav_file.h" +#include "modules/audio_processing/agc2/cpu_features.h" #include "modules/audio_processing/agc2/rnn_vad/common.h" #include "modules/audio_processing/agc2/rnn_vad/features_extraction.h" #include "modules/audio_processing/agc2/rnn_vad/rnn.h" @@ -63,9 +64,10 @@ int main(int argc, char* argv[]) { samples_10ms.resize(frame_size_10ms); std::array samples_10ms_24kHz; PushSincResampler resampler(frame_size_10ms, kFrameSize10ms24kHz); - FeaturesExtractor features_extractor; + const AvailableCpuFeatures cpu_features = GetAvailableCpuFeatures(); + FeaturesExtractor features_extractor(cpu_features); std::array feature_vector; - RnnBasedVad rnn_vad; + RnnBasedVad rnn_vad(cpu_features); // Compute VAD probabilities. while (true) { diff --git a/modules/audio_processing/agc2/rnn_vad/rnn_vad_unittest.cc b/modules/audio_processing/agc2/rnn_vad/rnn_vad_unittest.cc index 0916bf5b81..6036a00fd0 100644 --- a/modules/audio_processing/agc2/rnn_vad/rnn_vad_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/rnn_vad_unittest.cc @@ -13,6 +13,7 @@ #include #include "common_audio/resampler/push_sinc_resampler.h" +#include "modules/audio_processing/agc2/cpu_features.h" #include "modules/audio_processing/agc2/rnn_vad/features_extraction.h" #include "modules/audio_processing/agc2/rnn_vad/rnn.h" #include "modules/audio_processing/agc2/rnn_vad/test_utils.h" @@ -57,13 +58,17 @@ TEST(RnnVadTest, CheckWriteComputedOutputIsFalse) { << "Cannot land if kWriteComputedOutput is true."; } +class RnnVadProbabilityParametrization + : public ::testing::TestWithParam {}; + // Checks that the computed VAD probability for a test input sequence sampled at // 48 kHz is within tolerance. -TEST(RnnVadTest, RnnVadProbabilityWithinTolerance) { +TEST_P(RnnVadProbabilityParametrization, RnnVadProbabilityWithinTolerance) { // Init resampler, feature extractor and RNN. PushSincResampler decimator(kFrameSize10ms48kHz, kFrameSize10ms24kHz); - FeaturesExtractor features_extractor; - RnnBasedVad rnn_vad; + const AvailableCpuFeatures cpu_features = GetParam(); + FeaturesExtractor features_extractor(cpu_features); + RnnBasedVad rnn_vad(cpu_features); // Init input samples and expected output readers. auto samples_reader = CreatePcmSamplesReader(kFrameSize10ms48kHz); @@ -111,7 +116,7 @@ TEST(RnnVadTest, RnnVadProbabilityWithinTolerance) { // follows: // - on desktop: run the this unit test adding "--logs"; // - on android: run the this unit test adding "--logcat-output-file". -TEST(RnnVadTest, DISABLED_RnnVadPerformance) { +TEST_P(RnnVadProbabilityParametrization, DISABLED_RnnVadPerformance) { // PCM samples reader and buffers. auto samples_reader = CreatePcmSamplesReader(kFrameSize10ms48kHz); const int num_frames = samples_reader.second; @@ -127,9 +132,10 @@ TEST(RnnVadTest, DISABLED_RnnVadPerformance) { kFrameSize10ms24kHz); } // Initialize. - FeaturesExtractor features_extractor; + const AvailableCpuFeatures cpu_features = GetParam(); + FeaturesExtractor features_extractor(cpu_features); std::array feature_vector; - RnnBasedVad rnn_vad; + RnnBasedVad rnn_vad(cpu_features); constexpr int number_of_tests = 100; ::webrtc::test::PerformanceTimer perf_timer(number_of_tests); for (int k = 0; k < number_of_tests; ++k) { @@ -152,6 +158,27 @@ TEST(RnnVadTest, DISABLED_RnnVadPerformance) { perf_timer.GetDurationStandardDeviation()); } +// Finds the relevant CPU features combinations to test. +std::vector GetCpuFeaturesToTest() { + std::vector v; + v.push_back({/*sse2=*/false, /*avx2=*/false, /*neon=*/false}); + AvailableCpuFeatures available = GetAvailableCpuFeatures(); + if (available.sse2) { + AvailableCpuFeatures features( + {/*sse2=*/true, /*avx2=*/false, /*neon=*/false}); + v.push_back(features); + } + return v; +} + +INSTANTIATE_TEST_SUITE_P( + RnnVadTest, + RnnVadProbabilityParametrization, + ::testing::ValuesIn(GetCpuFeaturesToTest()), + [](const ::testing::TestParamInfo& info) { + return info.param.ToString(); + }); + } // namespace test } // namespace rnn_vad } // namespace webrtc diff --git a/modules/audio_processing/agc2/rnn_vad/test_utils.cc b/modules/audio_processing/agc2/rnn_vad/test_utils.cc index 24bbf13e31..75de1099f2 100644 --- a/modules/audio_processing/agc2/rnn_vad/test_utils.cc +++ b/modules/audio_processing/agc2/rnn_vad/test_utils.cc @@ -15,8 +15,6 @@ #include "rtc_base/checks.h" #include "rtc_base/numerics/safe_compare.h" -#include "rtc_base/system/arch.h" -#include "system_wrappers/include/cpu_features_wrapper.h" #include "test/gtest.h" #include "test/testsupport/file_utils.h" @@ -111,25 +109,6 @@ PitchTestData::GetPitchBufAutoCorrCoeffsView() const { kNumPitchBufAutoCorrCoeffs}; } -bool IsOptimizationAvailable(Optimization optimization) { - switch (optimization) { - case Optimization::kSse2: -#if defined(WEBRTC_ARCH_X86_FAMILY) - return GetCPUInfo(kSSE2) != 0; -#else - return false; -#endif - case Optimization::kNeon: -#if defined(WEBRTC_HAS_NEON) - return true; -#else - return false; -#endif - case Optimization::kNone: - return true; - } -} - } // namespace test } // namespace rnn_vad } // namespace webrtc diff --git a/modules/audio_processing/agc2/rnn_vad/test_utils.h b/modules/audio_processing/agc2/rnn_vad/test_utils.h index 23e642be81..3d1ad259db 100644 --- a/modules/audio_processing/agc2/rnn_vad/test_utils.h +++ b/modules/audio_processing/agc2/rnn_vad/test_utils.h @@ -152,9 +152,6 @@ class PitchTestData { std::array test_data_; }; -// Returns true if the given optimization is available. -bool IsOptimizationAvailable(Optimization optimization); - } // namespace test } // namespace rnn_vad } // namespace webrtc diff --git a/modules/audio_processing/agc2/vad_with_level.cc b/modules/audio_processing/agc2/vad_with_level.cc index 3dbb55732b..da3bd0a3fb 100644 --- a/modules/audio_processing/agc2/vad_with_level.cc +++ b/modules/audio_processing/agc2/vad_with_level.cc @@ -32,7 +32,8 @@ using VoiceActivityDetector = VadLevelAnalyzer::VoiceActivityDetector; // Computes the speech probability on the first channel. class Vad : public VoiceActivityDetector { public: - Vad() = default; + explicit Vad(const AvailableCpuFeatures& cpu_features) + : features_extractor_(cpu_features), rnn_vad_(cpu_features) {} Vad(const Vad&) = delete; Vad& operator=(const Vad&) = delete; ~Vad() = default; @@ -80,10 +81,12 @@ float SmoothedVadProbability(float p_old, float p_new, float attack) { VadLevelAnalyzer::VadLevelAnalyzer() : VadLevelAnalyzer(kDefaultSmoothedVadProbabilityAttack, - std::make_unique()) {} + GetAvailableCpuFeatures()) {} -VadLevelAnalyzer::VadLevelAnalyzer(float vad_probability_attack) - : VadLevelAnalyzer(vad_probability_attack, std::make_unique()) {} +VadLevelAnalyzer::VadLevelAnalyzer(float vad_probability_attack, + const AvailableCpuFeatures& cpu_features) + : VadLevelAnalyzer(vad_probability_attack, + std::make_unique(cpu_features)) {} VadLevelAnalyzer::VadLevelAnalyzer(float vad_probability_attack, std::unique_ptr vad) diff --git a/modules/audio_processing/agc2/vad_with_level.h b/modules/audio_processing/agc2/vad_with_level.h index ce72cdc754..2a6788278e 100644 --- a/modules/audio_processing/agc2/vad_with_level.h +++ b/modules/audio_processing/agc2/vad_with_level.h @@ -13,6 +13,7 @@ #include +#include "modules/audio_processing/agc2/cpu_features.h" #include "modules/audio_processing/include/audio_frame_view.h" namespace webrtc { @@ -36,7 +37,8 @@ class VadLevelAnalyzer { // Ctor. Uses the default VAD. VadLevelAnalyzer(); - explicit VadLevelAnalyzer(float vad_probability_attack); + VadLevelAnalyzer(float vad_probability_attack, + const AvailableCpuFeatures& cpu_features); // Ctor. Uses a custom `vad`. VadLevelAnalyzer(float vad_probability_attack, std::unique_ptr vad); diff --git a/modules/audio_processing/include/audio_processing.h b/modules/audio_processing/include/audio_processing.h index e85ac0c63e..b96ce926a1 100644 --- a/modules/audio_processing/include/audio_processing.h +++ b/modules/audio_processing/include/audio_processing.h @@ -350,10 +350,10 @@ class RTC_EXPORT AudioProcessing : public rtc::RefCountInterface { enum LevelEstimator { kRms, kPeak }; bool enabled = false; - struct { + struct FixedDigital { float gain_db = 0.f; } fixed_digital; - struct { + struct AdaptiveDigital { bool enabled = false; float vad_probability_attack = 1.f; LevelEstimator level_estimator = kRms; @@ -365,6 +365,7 @@ class RTC_EXPORT AudioProcessing : public rtc::RefCountInterface { int gain_applier_adjacent_speech_frames_threshold = 1; float max_gain_change_db_per_second = 3.f; float max_output_noise_level_dbfs = -50.f; + bool avx2_allowed = true; } adaptive_digital; } gain_controller2;