AGC2 RNN VAD: safe SIMD optimizations scheme + AVX2 kill switch
In preparation for adding AVX2 code, a safe scheme to support
different SIMD optimizations is added.
Safety features:
- AVX2 kill switch to stop using AVX2 even if it is supported by the
  architecture
- struct indicating the available CPU features, propagated from
  AGC2 to each component; this gives:
  - better control over the unit tests
  - no need to propagate individual kill switches: features that are
    turned off are simply set to false
Note that (i) this CL does not change the performance of the RNN VAD
and (ii) no AVX2 optimization is added yet.
Bug: webrtc:10480
Change-Id: I0e61f3311ecd140f38369cf68b6e5954f3dc1f5a
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/193140
Reviewed-by: Per Åhgren <peah@webrtc.org>
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#32739}
This commit is contained in:
parent
8d4cdd11d8
commit
253f8369bb
@ -56,6 +56,7 @@ rtc_library("adaptive_digital") {
|
||||
|
||||
deps = [
|
||||
":common",
|
||||
":cpu_features",
|
||||
":gain_applier",
|
||||
":noise_level_estimator",
|
||||
":rnn_vad_with_level",
|
||||
@ -163,6 +164,7 @@ rtc_library("rnn_vad_with_level") {
|
||||
]
|
||||
deps = [
|
||||
":common",
|
||||
":cpu_features",
|
||||
"..:audio_frame_view",
|
||||
"../../../api:array_view",
|
||||
"../../../common_audio",
|
||||
@ -172,6 +174,19 @@ rtc_library("rnn_vad_with_level") {
|
||||
]
|
||||
}
|
||||
|
||||
rtc_library("cpu_features") {
|
||||
sources = [
|
||||
"cpu_features.cc",
|
||||
"cpu_features.h",
|
||||
]
|
||||
visibility = [ "./*" ]
|
||||
deps = [
|
||||
"../../../rtc_base:stringutils",
|
||||
"../../../rtc_base/system:arch",
|
||||
"../../../system_wrappers",
|
||||
]
|
||||
}
|
||||
|
||||
rtc_library("adaptive_digital_unittests") {
|
||||
testonly = true
|
||||
configs += [ "..:apm_debug_dump" ]
|
||||
|
||||
@ -11,6 +11,7 @@
|
||||
#include "modules/audio_processing/agc2/adaptive_agc.h"
|
||||
|
||||
#include "common_audio/include/audio_util.h"
|
||||
#include "modules/audio_processing/agc2/cpu_features.h"
|
||||
#include "modules/audio_processing/agc2/vad_with_level.h"
|
||||
#include "modules/audio_processing/logging/apm_data_dumper.h"
|
||||
#include "rtc_base/checks.h"
|
||||
@ -32,6 +33,15 @@ constexpr int kGainApplierAdjacentSpeechFramesThreshold = 1;
|
||||
constexpr float kMaxGainChangePerSecondDb = 3.f;
|
||||
constexpr float kMaxOutputNoiseLevelDbfs = -50.f;
|
||||
|
||||
// Detects the available CPU features and applies a kill-switch to AVX2.
|
||||
AvailableCpuFeatures GetAllowedCpuFeatures(bool avx2_allowed) {
|
||||
AvailableCpuFeatures features = GetAvailableCpuFeatures();
|
||||
if (!avx2_allowed) {
|
||||
features.avx2 = false;
|
||||
}
|
||||
return features;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
AdaptiveAgc::AdaptiveAgc(ApmDataDumper* apm_data_dumper)
|
||||
@ -54,7 +64,8 @@ AdaptiveAgc::AdaptiveAgc(ApmDataDumper* apm_data_dumper,
|
||||
.level_estimator_adjacent_speech_frames_threshold,
|
||||
config.adaptive_digital.initial_saturation_margin_db,
|
||||
config.adaptive_digital.extra_saturation_margin_db),
|
||||
vad_(config.adaptive_digital.vad_probability_attack),
|
||||
vad_(config.adaptive_digital.vad_probability_attack,
|
||||
GetAllowedCpuFeatures(config.adaptive_digital.avx2_allowed)),
|
||||
gain_applier_(
|
||||
apm_data_dumper,
|
||||
config.adaptive_digital.gain_applier_adjacent_speech_frames_threshold,
|
||||
|
||||
54
modules/audio_processing/agc2/cpu_features.cc
Normal file
54
modules/audio_processing/agc2/cpu_features.cc
Normal file
@ -0,0 +1,54 @@
|
||||
/*
|
||||
* Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/cpu_features.h"
|
||||
|
||||
#include "rtc_base/strings/string_builder.h"
|
||||
#include "rtc_base/system/arch.h"
|
||||
#include "system_wrappers/include/cpu_features_wrapper.h"
|
||||
|
||||
namespace webrtc {
|
||||
|
||||
std::string AvailableCpuFeatures::ToString() const {
|
||||
char buf[64];
|
||||
rtc::SimpleStringBuilder builder(buf);
|
||||
bool first = true;
|
||||
if (sse2) {
|
||||
builder << (first ? "SSE2" : "_SSE2");
|
||||
first = false;
|
||||
}
|
||||
if (avx2) {
|
||||
builder << (first ? "AVX2" : "_AVX2");
|
||||
first = false;
|
||||
}
|
||||
if (neon) {
|
||||
builder << (first ? "NEON" : "_NEON");
|
||||
first = false;
|
||||
}
|
||||
if (first) {
|
||||
return "none";
|
||||
}
|
||||
return builder.str();
|
||||
}
|
||||
|
||||
// Detects available CPU features.
|
||||
AvailableCpuFeatures GetAvailableCpuFeatures() {
|
||||
#if defined(WEBRTC_ARCH_X86_FAMILY)
|
||||
return {/*sse2=*/GetCPUInfo(kSSE2) != 0,
|
||||
/*avx2=*/GetCPUInfo(kAVX2) != 0,
|
||||
/*neon=*/false};
|
||||
#elif defined(WEBRTC_HAS_NEON)
|
||||
return {/*sse2=*/false,
|
||||
/*avx2=*/false,
|
||||
/*neon=*/true};
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace webrtc
|
||||
36
modules/audio_processing/agc2/cpu_features.h
Normal file
36
modules/audio_processing/agc2/cpu_features.h
Normal file
@ -0,0 +1,36 @@
|
||||
/*
|
||||
* Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_CPU_FEATURES_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_CPU_FEATURES_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace webrtc {
|
||||
|
||||
// Set of flags describing which SIMD CPU features may be used on the current
// platform. A flag set to true means that the feature is available.
struct AvailableCpuFeatures {
  // Initializes every flag explicitly; there is no default constructor.
  AvailableCpuFeatures(bool sse2, bool avx2, bool neon)
      : sse2{sse2}, avx2{avx2}, neon{neon} {}

  // Returns a textual description of the flags that are set.
  std::string ToString() const;

  // Intel.
  bool sse2;
  bool avx2;
  // ARM.
  bool neon;
};
|
||||
|
||||
// Detects what CPU features are available.
|
||||
AvailableCpuFeatures GetAvailableCpuFeatures();
|
||||
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_CPU_FEATURES_H_
|
||||
@ -29,10 +29,10 @@ rtc_library("rnn_vad") {
|
||||
":rnn_vad_sequence_buffer",
|
||||
":rnn_vad_spectral_features",
|
||||
"..:biquad_filter",
|
||||
"..:cpu_features",
|
||||
"../../../../api:array_view",
|
||||
"../../../../api:function_view",
|
||||
"../../../../rtc_base:checks",
|
||||
"../../../../rtc_base:logging",
|
||||
"../../../../rtc_base:safe_compare",
|
||||
"../../../../rtc_base:safe_conversions",
|
||||
"../../../../rtc_base/system:arch",
|
||||
@ -53,16 +53,13 @@ rtc_library("rnn_vad_auto_correlation") {
|
||||
]
|
||||
}
|
||||
|
||||
rtc_library("rnn_vad_common") {
|
||||
rtc_source_set("rnn_vad_common") {
|
||||
# TODO(alessiob): Make this target visibility private.
|
||||
visibility = [
|
||||
":*",
|
||||
"..:rnn_vad_with_level",
|
||||
]
|
||||
sources = [
|
||||
"common.cc",
|
||||
"common.h",
|
||||
]
|
||||
sources = [ "common.h" ]
|
||||
deps = [
|
||||
"../../../../rtc_base/system:arch",
|
||||
"../../../../system_wrappers",
|
||||
@ -91,6 +88,7 @@ rtc_library("rnn_vad_pitch") {
|
||||
deps = [
|
||||
":rnn_vad_auto_correlation",
|
||||
":rnn_vad_common",
|
||||
"..:cpu_features",
|
||||
"../../../../api:array_view",
|
||||
"../../../../rtc_base:checks",
|
||||
"../../../../rtc_base:gtest_prod",
|
||||
@ -156,8 +154,6 @@ if (rtc_include_tests) {
|
||||
"../../../../api:scoped_refptr",
|
||||
"../../../../rtc_base:checks",
|
||||
"../../../../rtc_base:safe_compare",
|
||||
"../../../../rtc_base/system:arch",
|
||||
"../../../../system_wrappers",
|
||||
"../../../../test:fileutils",
|
||||
"../../../../test:test_support",
|
||||
]
|
||||
@ -207,6 +203,7 @@ if (rtc_include_tests) {
|
||||
":rnn_vad_spectral_features",
|
||||
":rnn_vad_symmetric_matrix_buffer",
|
||||
":test_utils",
|
||||
"..:cpu_features",
|
||||
"../..:audioproc_test_utils",
|
||||
"../../../../api:array_view",
|
||||
"../../../../common_audio/",
|
||||
@ -232,6 +229,7 @@ if (rtc_include_tests) {
|
||||
deps = [
|
||||
":rnn_vad",
|
||||
":rnn_vad_common",
|
||||
"..:cpu_features",
|
||||
"../../../../api:array_view",
|
||||
"../../../../common_audio",
|
||||
"../../../../rtc_base:rtc_base_approved",
|
||||
|
||||
@ -1,34 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2019 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/rnn_vad/common.h"
|
||||
|
||||
#include "rtc_base/system/arch.h"
|
||||
#include "system_wrappers/include/cpu_features_wrapper.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
|
||||
Optimization DetectOptimization() {
|
||||
#if defined(WEBRTC_ARCH_X86_FAMILY)
|
||||
if (GetCPUInfo(kSSE2) != 0) {
|
||||
return Optimization::kSse2;
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(WEBRTC_HAS_NEON)
|
||||
return Optimization::kNeon;
|
||||
#endif
|
||||
|
||||
return Optimization::kNone;
|
||||
}
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
@ -71,11 +71,6 @@ static_assert(kCepstralCoeffsHistorySize > 2,
|
||||
|
||||
constexpr int kFeatureVectorSize = 42;
|
||||
|
||||
enum class Optimization { kNone, kSse2, kNeon };
|
||||
|
||||
// Detects what kind of optimizations to use for the code.
|
||||
Optimization DetectOptimization();
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
||||
|
||||
@ -26,13 +26,13 @@ const BiQuadFilter::BiQuadCoefficients kHpfConfig24k = {
|
||||
|
||||
} // namespace
|
||||
|
||||
FeaturesExtractor::FeaturesExtractor()
|
||||
FeaturesExtractor::FeaturesExtractor(const AvailableCpuFeatures& cpu_features)
|
||||
: use_high_pass_filter_(false),
|
||||
pitch_buf_24kHz_(),
|
||||
pitch_buf_24kHz_view_(pitch_buf_24kHz_.GetBufferView()),
|
||||
lp_residual_(kBufSize24kHz),
|
||||
lp_residual_view_(lp_residual_.data(), kBufSize24kHz),
|
||||
pitch_estimator_(),
|
||||
pitch_estimator_(cpu_features),
|
||||
reference_frame_view_(pitch_buf_24kHz_.GetMostRecentValuesView()) {
|
||||
RTC_DCHECK_EQ(kBufSize24kHz, lp_residual_.size());
|
||||
hpf_.Initialize(kHpfConfig24k);
|
||||
|
||||
@ -26,7 +26,7 @@ namespace rnn_vad {
|
||||
// Feature extractor to feed the VAD RNN.
|
||||
class FeaturesExtractor {
|
||||
public:
|
||||
FeaturesExtractor();
|
||||
explicit FeaturesExtractor(const AvailableCpuFeatures& cpu_features);
|
||||
FeaturesExtractor(const FeaturesExtractor&) = delete;
|
||||
FeaturesExtractor& operator=(const FeaturesExtractor&) = delete;
|
||||
~FeaturesExtractor();
|
||||
|
||||
@ -13,6 +13,7 @@
|
||||
#include <cmath>
|
||||
#include <vector>
|
||||
|
||||
#include "modules/audio_processing/agc2/cpu_features.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
|
||||
#include "rtc_base/numerics/safe_compare.h"
|
||||
#include "rtc_base/numerics/safe_conversions.h"
|
||||
@ -77,7 +78,8 @@ TEST(RnnVadTest, FeatureExtractionLowHighPitch) {
|
||||
ASSERT_TRUE(PitchIsValid(low_pitch_hz));
|
||||
ASSERT_TRUE(PitchIsValid(high_pitch_hz));
|
||||
|
||||
FeaturesExtractor features_extractor;
|
||||
const AvailableCpuFeatures cpu_features = GetAvailableCpuFeatures();
|
||||
FeaturesExtractor features_extractor(cpu_features);
|
||||
std::vector<float> samples(kNumTestDataSize);
|
||||
std::vector<float> feature_vector(kFeatureVectorSize);
|
||||
ASSERT_EQ(kFeatureVectorSize, rtc::dchecked_cast<int>(feature_vector.size()));
|
||||
|
||||
@ -18,8 +18,9 @@
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
|
||||
PitchEstimator::PitchEstimator()
|
||||
: y_energy_24kHz_(kRefineNumLags24kHz, 0.f),
|
||||
PitchEstimator::PitchEstimator(const AvailableCpuFeatures& cpu_features)
|
||||
: cpu_features_(cpu_features),
|
||||
y_energy_24kHz_(kRefineNumLags24kHz, 0.f),
|
||||
pitch_buffer_12kHz_(kBufSize12kHz),
|
||||
auto_correlation_12kHz_(kNumLags12kHz) {}
|
||||
|
||||
@ -35,6 +36,7 @@ int PitchEstimator::Estimate(
|
||||
RTC_DCHECK_EQ(auto_correlation_12kHz_.size(),
|
||||
auto_correlation_12kHz_view.size());
|
||||
|
||||
// TODO(bugs.webrtc.org/10480): Use `cpu_features_` to estimate pitch.
|
||||
// Perform the initial pitch search at 12 kHz.
|
||||
Decimate2x(pitch_buffer, pitch_buffer_12kHz_view);
|
||||
auto_corr_calculator_.ComputeOnPitchBuffer(pitch_buffer_12kHz_view,
|
||||
|
||||
@ -15,6 +15,7 @@
|
||||
#include <vector>
|
||||
|
||||
#include "api/array_view.h"
|
||||
#include "modules/audio_processing/agc2/cpu_features.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/auto_correlation.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/common.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h"
|
||||
@ -26,7 +27,7 @@ namespace rnn_vad {
|
||||
// Pitch estimator.
|
||||
class PitchEstimator {
|
||||
public:
|
||||
PitchEstimator();
|
||||
explicit PitchEstimator(const AvailableCpuFeatures& cpu_features);
|
||||
PitchEstimator(const PitchEstimator&) = delete;
|
||||
PitchEstimator& operator=(const PitchEstimator&) = delete;
|
||||
~PitchEstimator();
|
||||
@ -39,6 +40,7 @@ class PitchEstimator {
|
||||
return last_pitch_48kHz_.strength;
|
||||
}
|
||||
|
||||
const AvailableCpuFeatures cpu_features_;
|
||||
PitchInfo last_pitch_48kHz_{};
|
||||
AutoCorrelationCalculator auto_corr_calculator_;
|
||||
std::vector<float> y_energy_24kHz_;
|
||||
|
||||
@ -13,6 +13,7 @@
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
|
||||
#include "modules/audio_processing/agc2/cpu_features.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
|
||||
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
|
||||
@ -29,7 +30,8 @@ TEST(RnnVadTest, PitchSearchWithinTolerance) {
|
||||
const int num_frames = std::min(lp_residual_reader.second, 300); // Max 3 s.
|
||||
std::vector<float> lp_residual(kBufSize24kHz);
|
||||
float expected_pitch_period, expected_pitch_strength;
|
||||
PitchEstimator pitch_estimator;
|
||||
const AvailableCpuFeatures cpu_features = GetAvailableCpuFeatures();
|
||||
PitchEstimator pitch_estimator(cpu_features);
|
||||
{
|
||||
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
|
||||
// FloatingPointExceptionObserver fpe_observer;
|
||||
|
||||
@ -25,7 +25,6 @@
|
||||
#include <numeric>
|
||||
|
||||
#include "rtc_base/checks.h"
|
||||
#include "rtc_base/logging.h"
|
||||
#include "rtc_base/numerics/safe_conversions.h"
|
||||
#include "third_party/rnnoise/src/rnn_activations.h"
|
||||
#include "third_party/rnnoise/src/rnn_vad_weights.h"
|
||||
@ -273,13 +272,13 @@ FullyConnectedLayer::FullyConnectedLayer(
|
||||
const rtc::ArrayView<const int8_t> bias,
|
||||
const rtc::ArrayView<const int8_t> weights,
|
||||
rtc::FunctionView<float(float)> activation_function,
|
||||
Optimization optimization)
|
||||
const AvailableCpuFeatures& cpu_features)
|
||||
: input_size_(input_size),
|
||||
output_size_(output_size),
|
||||
bias_(GetScaledParams(bias)),
|
||||
weights_(GetPreprocessedFcWeights(weights, output_size)),
|
||||
activation_function_(activation_function),
|
||||
optimization_(optimization) {
|
||||
cpu_features_(cpu_features) {
|
||||
RTC_DCHECK_LE(output_size_, kFullyConnectedLayersMaxUnits)
|
||||
<< "Static over-allocation of fully-connected layers output vectors is "
|
||||
"not sufficient.";
|
||||
@ -296,25 +295,18 @@ rtc::ArrayView<const float> FullyConnectedLayer::GetOutput() const {
|
||||
}
|
||||
|
||||
void FullyConnectedLayer::ComputeOutput(rtc::ArrayView<const float> input) {
|
||||
switch (optimization_) {
|
||||
#if defined(WEBRTC_ARCH_X86_FAMILY)
|
||||
case Optimization::kSse2:
|
||||
ComputeFullyConnectedLayerOutputSse2(input_size_, output_size_, input,
|
||||
bias_, weights_,
|
||||
activation_function_, output_);
|
||||
break;
|
||||
#endif
|
||||
#if defined(WEBRTC_HAS_NEON)
|
||||
case Optimization::kNeon:
|
||||
// TODO(bugs.chromium.org/10480): Handle Optimization::kNeon.
|
||||
ComputeFullyConnectedLayerOutput(input_size_, output_size_, input, bias_,
|
||||
weights_, activation_function_, output_);
|
||||
break;
|
||||
#endif
|
||||
default:
|
||||
ComputeFullyConnectedLayerOutput(input_size_, output_size_, input, bias_,
|
||||
weights_, activation_function_, output_);
|
||||
// TODO(bugs.webrtc.org/10480): Add AVX2.
|
||||
if (cpu_features_.sse2) {
|
||||
ComputeFullyConnectedLayerOutputSse2(input_size_, output_size_, input,
|
||||
bias_, weights_, activation_function_,
|
||||
output_);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
// TODO(bugs.webrtc.org/10480): Add Neon.
|
||||
ComputeFullyConnectedLayerOutput(input_size_, output_size_, input, bias_,
|
||||
weights_, activation_function_, output_);
|
||||
}
|
||||
|
||||
GatedRecurrentLayer::GatedRecurrentLayer(
|
||||
@ -322,15 +314,13 @@ GatedRecurrentLayer::GatedRecurrentLayer(
|
||||
const int output_size,
|
||||
const rtc::ArrayView<const int8_t> bias,
|
||||
const rtc::ArrayView<const int8_t> weights,
|
||||
const rtc::ArrayView<const int8_t> recurrent_weights,
|
||||
Optimization optimization)
|
||||
const rtc::ArrayView<const int8_t> recurrent_weights)
|
||||
: input_size_(input_size),
|
||||
output_size_(output_size),
|
||||
bias_(GetPreprocessedGruTensor(bias, output_size)),
|
||||
weights_(GetPreprocessedGruTensor(weights, output_size)),
|
||||
recurrent_weights_(
|
||||
GetPreprocessedGruTensor(recurrent_weights, output_size)),
|
||||
optimization_(optimization) {
|
||||
GetPreprocessedGruTensor(recurrent_weights, output_size)) {
|
||||
RTC_DCHECK_LE(output_size_, kRecurrentLayersMaxUnits)
|
||||
<< "Static over-allocation of recurrent layers state vectors is not "
|
||||
"sufficient.";
|
||||
@ -356,46 +346,30 @@ void GatedRecurrentLayer::Reset() {
|
||||
}
|
||||
|
||||
void GatedRecurrentLayer::ComputeOutput(rtc::ArrayView<const float> input) {
|
||||
switch (optimization_) {
|
||||
#if defined(WEBRTC_ARCH_X86_FAMILY)
|
||||
case Optimization::kSse2:
|
||||
// TODO(bugs.chromium.org/10480): Handle Optimization::kSse2.
|
||||
ComputeGruLayerOutput(input_size_, output_size_, input, weights_,
|
||||
recurrent_weights_, bias_, state_);
|
||||
break;
|
||||
#endif
|
||||
#if defined(WEBRTC_HAS_NEON)
|
||||
case Optimization::kNeon:
|
||||
// TODO(bugs.chromium.org/10480): Handle Optimization::kNeon.
|
||||
ComputeGruLayerOutput(input_size_, output_size_, input, weights_,
|
||||
recurrent_weights_, bias_, state_);
|
||||
break;
|
||||
#endif
|
||||
default:
|
||||
ComputeGruLayerOutput(input_size_, output_size_, input, weights_,
|
||||
recurrent_weights_, bias_, state_);
|
||||
}
|
||||
// TODO(bugs.webrtc.org/10480): Add AVX2.
|
||||
// TODO(bugs.webrtc.org/10480): Add Neon.
|
||||
ComputeGruLayerOutput(input_size_, output_size_, input, weights_,
|
||||
recurrent_weights_, bias_, state_);
|
||||
}
|
||||
|
||||
RnnBasedVad::RnnBasedVad()
|
||||
RnnBasedVad::RnnBasedVad(const AvailableCpuFeatures& cpu_features)
|
||||
: input_layer_(kInputLayerInputSize,
|
||||
kInputLayerOutputSize,
|
||||
kInputDenseBias,
|
||||
kInputDenseWeights,
|
||||
TansigApproximated,
|
||||
DetectOptimization()),
|
||||
cpu_features),
|
||||
hidden_layer_(kInputLayerOutputSize,
|
||||
kHiddenLayerOutputSize,
|
||||
kHiddenGruBias,
|
||||
kHiddenGruWeights,
|
||||
kHiddenGruRecurrentWeights,
|
||||
DetectOptimization()),
|
||||
kHiddenGruRecurrentWeights),
|
||||
output_layer_(kHiddenLayerOutputSize,
|
||||
kOutputLayerOutputSize,
|
||||
kOutputDenseBias,
|
||||
kOutputDenseWeights,
|
||||
SigmoidApproximated,
|
||||
DetectOptimization()) {
|
||||
cpu_features) {
|
||||
// Input-output chaining size checks.
|
||||
RTC_DCHECK_EQ(input_layer_.output_size(), hidden_layer_.input_size())
|
||||
<< "The input and the hidden layers sizes do not match.";
|
||||
|
||||
@ -19,6 +19,7 @@
|
||||
|
||||
#include "api/array_view.h"
|
||||
#include "api/function_view.h"
|
||||
#include "modules/audio_processing/agc2/cpu_features.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/common.h"
|
||||
#include "rtc_base/system/arch.h"
|
||||
|
||||
@ -45,13 +46,12 @@ class FullyConnectedLayer {
|
||||
rtc::ArrayView<const int8_t> bias,
|
||||
rtc::ArrayView<const int8_t> weights,
|
||||
rtc::FunctionView<float(float)> activation_function,
|
||||
Optimization optimization);
|
||||
const AvailableCpuFeatures& cpu_features);
|
||||
FullyConnectedLayer(const FullyConnectedLayer&) = delete;
|
||||
FullyConnectedLayer& operator=(const FullyConnectedLayer&) = delete;
|
||||
~FullyConnectedLayer();
|
||||
int input_size() const { return input_size_; }
|
||||
int output_size() const { return output_size_; }
|
||||
Optimization optimization() const { return optimization_; }
|
||||
rtc::ArrayView<const float> GetOutput() const;
|
||||
// Computes the fully-connected layer output.
|
||||
void ComputeOutput(rtc::ArrayView<const float> input);
|
||||
@ -65,7 +65,7 @@ class FullyConnectedLayer {
|
||||
// The output vector of a recurrent layer has length equal to |output_size_|.
|
||||
// However, for efficiency, over-allocation is used.
|
||||
std::array<float, kFullyConnectedLayersMaxUnits> output_;
|
||||
const Optimization optimization_;
|
||||
const AvailableCpuFeatures cpu_features_;
|
||||
};
|
||||
|
||||
// Recurrent layer with gated recurrent units (GRUs) with sigmoid and ReLU as
|
||||
@ -76,14 +76,12 @@ class GatedRecurrentLayer {
|
||||
int output_size,
|
||||
rtc::ArrayView<const int8_t> bias,
|
||||
rtc::ArrayView<const int8_t> weights,
|
||||
rtc::ArrayView<const int8_t> recurrent_weights,
|
||||
Optimization optimization);
|
||||
rtc::ArrayView<const int8_t> recurrent_weights);
|
||||
GatedRecurrentLayer(const GatedRecurrentLayer&) = delete;
|
||||
GatedRecurrentLayer& operator=(const GatedRecurrentLayer&) = delete;
|
||||
~GatedRecurrentLayer();
|
||||
int input_size() const { return input_size_; }
|
||||
int output_size() const { return output_size_; }
|
||||
Optimization optimization() const { return optimization_; }
|
||||
rtc::ArrayView<const float> GetOutput() const;
|
||||
void Reset();
|
||||
// Computes the recurrent layer output and updates the status.
|
||||
@ -98,13 +96,12 @@ class GatedRecurrentLayer {
|
||||
// The state vector of a recurrent layer has length equal to |output_size_|.
|
||||
// However, to avoid dynamic allocation, over-allocation is used.
|
||||
std::array<float, kRecurrentLayersMaxUnits> state_;
|
||||
const Optimization optimization_;
|
||||
};
|
||||
|
||||
// Recurrent network based VAD.
|
||||
class RnnBasedVad {
|
||||
public:
|
||||
RnnBasedVad();
|
||||
explicit RnnBasedVad(const AvailableCpuFeatures& cpu_features);
|
||||
RnnBasedVad(const RnnBasedVad&) = delete;
|
||||
RnnBasedVad& operator=(const RnnBasedVad&) = delete;
|
||||
~RnnBasedVad();
|
||||
|
||||
@ -14,6 +14,7 @@
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "modules/audio_processing/agc2/cpu_features.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
|
||||
#include "modules/audio_processing/test/performance_timer.h"
|
||||
#include "rtc_base/checks.h"
|
||||
@ -27,7 +28,6 @@
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
namespace test {
|
||||
|
||||
namespace {
|
||||
|
||||
void TestFullyConnectedLayer(FullyConnectedLayer* fc,
|
||||
@ -39,26 +39,25 @@ void TestFullyConnectedLayer(FullyConnectedLayer* fc,
|
||||
}
|
||||
|
||||
void TestGatedRecurrentLayer(
|
||||
GatedRecurrentLayer* gru,
|
||||
GatedRecurrentLayer& gru,
|
||||
rtc::ArrayView<const float> input_sequence,
|
||||
rtc::ArrayView<const float> expected_output_sequence) {
|
||||
RTC_CHECK(gru);
|
||||
auto gru_output_view = gru->GetOutput();
|
||||
auto gru_output_view = gru.GetOutput();
|
||||
const int input_sequence_length = rtc::CheckedDivExact(
|
||||
rtc::dchecked_cast<int>(input_sequence.size()), gru->input_size());
|
||||
rtc::dchecked_cast<int>(input_sequence.size()), gru.input_size());
|
||||
const int output_sequence_length = rtc::CheckedDivExact(
|
||||
rtc::dchecked_cast<int>(expected_output_sequence.size()),
|
||||
gru->output_size());
|
||||
gru.output_size());
|
||||
ASSERT_EQ(input_sequence_length, output_sequence_length)
|
||||
<< "The test data length is invalid.";
|
||||
// Feed the GRU layer and check the output at every step.
|
||||
gru->Reset();
|
||||
gru.Reset();
|
||||
for (int i = 0; i < input_sequence_length; ++i) {
|
||||
SCOPED_TRACE(i);
|
||||
gru->ComputeOutput(
|
||||
input_sequence.subview(i * gru->input_size(), gru->input_size()));
|
||||
gru.ComputeOutput(
|
||||
input_sequence.subview(i * gru.input_size(), gru.input_size()));
|
||||
const auto expected_output = expected_output_sequence.subview(
|
||||
i * gru->output_size(), gru->output_size());
|
||||
i * gru.output_size(), gru.output_size());
|
||||
ExpectNearAbsolute(expected_output, gru_output_view, 3e-6f);
|
||||
}
|
||||
}
|
||||
@ -134,141 +133,94 @@ constexpr std::array<float, 16> kGruExpectedOutputSequence = {
|
||||
0.00781069f, 0.75267816f, 0.f, 0.02579715f,
|
||||
0.00471378f, 0.59162533f, 0.11087593f, 0.01334511f};
|
||||
|
||||
std::string GetOptimizationName(Optimization optimization) {
|
||||
switch (optimization) {
|
||||
case Optimization::kSse2:
|
||||
return "SSE2";
|
||||
case Optimization::kNeon:
|
||||
return "NEON";
|
||||
case Optimization::kNone:
|
||||
return "none";
|
||||
}
|
||||
}
|
||||
|
||||
struct Result {
|
||||
Optimization optimization;
|
||||
double average_us;
|
||||
double std_dev_us;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
// Checks that the output of a fully connected layer is within tolerance given
|
||||
// test input data.
|
||||
TEST(RnnVadTest, CheckFullyConnectedLayerOutput) {
|
||||
FullyConnectedLayer fc(rnnoise::kInputLayerInputSize,
|
||||
rnnoise::kInputLayerOutputSize,
|
||||
rnnoise::kInputDenseBias, rnnoise::kInputDenseWeights,
|
||||
rnnoise::TansigApproximated, Optimization::kNone);
|
||||
TestFullyConnectedLayer(&fc, kFullyConnectedInputVector,
|
||||
kFullyConnectedExpectedOutput);
|
||||
}
|
||||
|
||||
// Checks that the output of a GRU layer is within tolerance given test input
|
||||
// data.
|
||||
TEST(RnnVadTest, CheckGatedRecurrentLayer) {
|
||||
GatedRecurrentLayer gru(kGruInputSize, kGruOutputSize, kGruBias, kGruWeights,
|
||||
kGruRecurrentWeights, Optimization::kNone);
|
||||
TestGatedRecurrentLayer(&gru, kGruInputSequence, kGruExpectedOutputSequence);
|
||||
}
|
||||
|
||||
#if defined(WEBRTC_ARCH_X86_FAMILY)
|
||||
|
||||
// Like CheckFullyConnectedLayerOutput, but testing the SSE2 implementation.
|
||||
TEST(RnnVadTest, CheckFullyConnectedLayerOutputSse2) {
|
||||
if (!IsOptimizationAvailable(Optimization::kSse2)) {
|
||||
return;
|
||||
}
|
||||
|
||||
FullyConnectedLayer fc(rnnoise::kInputLayerInputSize,
|
||||
rnnoise::kInputLayerOutputSize,
|
||||
rnnoise::kInputDenseBias, rnnoise::kInputDenseWeights,
|
||||
rnnoise::TansigApproximated, Optimization::kSse2);
|
||||
TestFullyConnectedLayer(&fc, kFullyConnectedInputVector,
|
||||
kFullyConnectedExpectedOutput);
|
||||
}
|
||||
|
||||
// Like CheckGatedRecurrentLayer, but testing the SSE2 implementation.
|
||||
TEST(RnnVadTest, CheckGatedRecurrentLayerSse2) {
|
||||
if (!IsOptimizationAvailable(Optimization::kSse2)) {
|
||||
return;
|
||||
}
|
||||
|
||||
GatedRecurrentLayer gru(kGruInputSize, kGruOutputSize, kGruBias, kGruWeights,
|
||||
kGruRecurrentWeights, Optimization::kSse2);
|
||||
TestGatedRecurrentLayer(&gru, kGruInputSequence, kGruExpectedOutputSequence);
|
||||
}
|
||||
|
||||
#endif // WEBRTC_ARCH_X86_FAMILY
|
||||
|
||||
TEST(RnnVadTest, DISABLED_BenchmarkFullyConnectedLayer) {
|
||||
std::vector<std::unique_ptr<FullyConnectedLayer>> implementations;
|
||||
implementations.emplace_back(std::make_unique<FullyConnectedLayer>(
|
||||
rnnoise::kInputLayerInputSize, rnnoise::kInputLayerOutputSize,
|
||||
rnnoise::kInputDenseBias, rnnoise::kInputDenseWeights,
|
||||
rnnoise::TansigApproximated, Optimization::kNone));
|
||||
if (IsOptimizationAvailable(Optimization::kSse2)) {
|
||||
implementations.emplace_back(std::make_unique<FullyConnectedLayer>(
|
||||
rnnoise::kInputLayerInputSize, rnnoise::kInputLayerOutputSize,
|
||||
rnnoise::kInputDenseBias, rnnoise::kInputDenseWeights,
|
||||
rnnoise::TansigApproximated, Optimization::kSse2));
|
||||
}
|
||||
|
||||
std::vector<Result> results;
|
||||
constexpr int number_of_tests = 10000;
|
||||
for (auto& fc : implementations) {
|
||||
::webrtc::test::PerformanceTimer perf_timer(number_of_tests);
|
||||
for (int k = 0; k < number_of_tests; ++k) {
|
||||
perf_timer.StartTimer();
|
||||
fc->ComputeOutput(kFullyConnectedInputVector);
|
||||
perf_timer.StopTimer();
|
||||
}
|
||||
results.push_back({fc->optimization(), perf_timer.GetDurationAverage(),
|
||||
perf_timer.GetDurationStandardDeviation()});
|
||||
}
|
||||
|
||||
for (const auto& result : results) {
|
||||
RTC_LOG(LS_INFO) << GetOptimizationName(result.optimization) << ": "
|
||||
<< (result.average_us / 1e3) << " +/- "
|
||||
<< (result.std_dev_us / 1e3) << " ms";
|
||||
}
|
||||
kGruRecurrentWeights);
|
||||
TestGatedRecurrentLayer(gru, kGruInputSequence, kGruExpectedOutputSequence);
|
||||
}
|
||||
|
||||
TEST(RnnVadTest, DISABLED_BenchmarkGatedRecurrentLayer) {
|
||||
std::vector<std::unique_ptr<GatedRecurrentLayer>> implementations;
|
||||
implementations.emplace_back(std::make_unique<GatedRecurrentLayer>(
|
||||
kGruInputSize, kGruOutputSize, kGruBias, kGruWeights,
|
||||
kGruRecurrentWeights, Optimization::kNone));
|
||||
GatedRecurrentLayer gru(kGruInputSize, kGruOutputSize, kGruBias, kGruWeights,
|
||||
kGruRecurrentWeights);
|
||||
|
||||
rtc::ArrayView<const float> input_sequence(kGruInputSequence);
|
||||
static_assert(kGruInputSequence.size() % kGruInputSize == 0, "");
|
||||
constexpr int input_sequence_length =
|
||||
kGruInputSequence.size() / kGruInputSize;
|
||||
|
||||
std::vector<Result> results;
|
||||
constexpr int number_of_tests = 10000;
|
||||
for (auto& gru : implementations) {
|
||||
::webrtc::test::PerformanceTimer perf_timer(number_of_tests);
|
||||
gru->Reset();
|
||||
for (int k = 0; k < number_of_tests; ++k) {
|
||||
perf_timer.StartTimer();
|
||||
for (int i = 0; i < input_sequence_length; ++i) {
|
||||
gru->ComputeOutput(
|
||||
input_sequence.subview(i * gru->input_size(), gru->input_size()));
|
||||
}
|
||||
perf_timer.StopTimer();
|
||||
constexpr int kNumTests = 10000;
|
||||
::webrtc::test::PerformanceTimer perf_timer(kNumTests);
|
||||
for (int k = 0; k < kNumTests; ++k) {
|
||||
perf_timer.StartTimer();
|
||||
for (int i = 0; i < input_sequence_length; ++i) {
|
||||
gru.ComputeOutput(
|
||||
input_sequence.subview(i * gru.input_size(), gru.input_size()));
|
||||
}
|
||||
results.push_back({gru->optimization(), perf_timer.GetDurationAverage(),
|
||||
perf_timer.GetDurationStandardDeviation()});
|
||||
}
|
||||
|
||||
for (const auto& result : results) {
|
||||
RTC_LOG(LS_INFO) << GetOptimizationName(result.optimization) << ": "
|
||||
<< (result.average_us / 1e3) << " +/- "
|
||||
<< (result.std_dev_us / 1e3) << " ms";
|
||||
perf_timer.StopTimer();
|
||||
}
|
||||
RTC_LOG(LS_INFO) << (perf_timer.GetDurationAverage() / 1000) << " +/- "
|
||||
<< (perf_timer.GetDurationStandardDeviation() / 1000)
|
||||
<< " ms";
|
||||
}
|
||||
|
||||
class RnnParametrization
|
||||
: public ::testing::TestWithParam<AvailableCpuFeatures> {};
|
||||
|
||||
// Checks that the output of a fully connected layer is within tolerance given
|
||||
// test input data.
|
||||
TEST_P(RnnParametrization, CheckFullyConnectedLayerOutput) {
|
||||
FullyConnectedLayer fc(
|
||||
rnnoise::kInputLayerInputSize, rnnoise::kInputLayerOutputSize,
|
||||
rnnoise::kInputDenseBias, rnnoise::kInputDenseWeights,
|
||||
rnnoise::TansigApproximated, /*cpu_features=*/GetParam());
|
||||
TestFullyConnectedLayer(&fc, kFullyConnectedInputVector,
|
||||
kFullyConnectedExpectedOutput);
|
||||
}
|
||||
|
||||
TEST_P(RnnParametrization, DISABLED_BenchmarkFullyConnectedLayer) {
|
||||
const AvailableCpuFeatures cpu_features = GetParam();
|
||||
FullyConnectedLayer fc(rnnoise::kInputLayerInputSize,
|
||||
rnnoise::kInputLayerOutputSize,
|
||||
rnnoise::kInputDenseBias, rnnoise::kInputDenseWeights,
|
||||
rnnoise::TansigApproximated, cpu_features);
|
||||
|
||||
constexpr int kNumTests = 10000;
|
||||
::webrtc::test::PerformanceTimer perf_timer(kNumTests);
|
||||
for (int k = 0; k < kNumTests; ++k) {
|
||||
perf_timer.StartTimer();
|
||||
fc.ComputeOutput(kFullyConnectedInputVector);
|
||||
perf_timer.StopTimer();
|
||||
}
|
||||
RTC_LOG(LS_INFO) << "CPU features: " << cpu_features.ToString() << " | "
|
||||
<< (perf_timer.GetDurationAverage() / 1000) << " +/- "
|
||||
<< (perf_timer.GetDurationStandardDeviation() / 1000)
|
||||
<< " ms";
|
||||
}
|
||||
|
||||
// Finds the relevant CPU features combinations to test.
|
||||
std::vector<AvailableCpuFeatures> GetCpuFeaturesToTest() {
|
||||
std::vector<AvailableCpuFeatures> v;
|
||||
v.push_back({/*sse2=*/false, /*avx2=*/false, /*neon=*/false});
|
||||
AvailableCpuFeatures available = GetAvailableCpuFeatures();
|
||||
if (available.sse2) {
|
||||
AvailableCpuFeatures features(
|
||||
{/*sse2=*/true, /*avx2=*/false, /*neon=*/false});
|
||||
v.push_back(features);
|
||||
}
|
||||
return v;
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
RnnVadTest,
|
||||
RnnParametrization,
|
||||
::testing::ValuesIn(GetCpuFeaturesToTest()),
|
||||
[](const ::testing::TestParamInfo<AvailableCpuFeatures>& info) {
|
||||
return info.param.ToString();
|
||||
});
|
||||
|
||||
} // namespace
|
||||
} // namespace test
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
||||
@ -16,6 +16,7 @@
|
||||
#include "absl/flags/parse.h"
|
||||
#include "common_audio/resampler/push_sinc_resampler.h"
|
||||
#include "common_audio/wav_file.h"
|
||||
#include "modules/audio_processing/agc2/cpu_features.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/common.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/features_extraction.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/rnn.h"
|
||||
@ -63,9 +64,10 @@ int main(int argc, char* argv[]) {
|
||||
samples_10ms.resize(frame_size_10ms);
|
||||
std::array<float, kFrameSize10ms24kHz> samples_10ms_24kHz;
|
||||
PushSincResampler resampler(frame_size_10ms, kFrameSize10ms24kHz);
|
||||
FeaturesExtractor features_extractor;
|
||||
const AvailableCpuFeatures cpu_features = GetAvailableCpuFeatures();
|
||||
FeaturesExtractor features_extractor(cpu_features);
|
||||
std::array<float, kFeatureVectorSize> feature_vector;
|
||||
RnnBasedVad rnn_vad;
|
||||
RnnBasedVad rnn_vad(cpu_features);
|
||||
|
||||
// Compute VAD probabilities.
|
||||
while (true) {
|
||||
|
||||
@ -13,6 +13,7 @@
|
||||
#include <vector>
|
||||
|
||||
#include "common_audio/resampler/push_sinc_resampler.h"
|
||||
#include "modules/audio_processing/agc2/cpu_features.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/features_extraction.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/rnn.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
|
||||
@ -57,13 +58,17 @@ TEST(RnnVadTest, CheckWriteComputedOutputIsFalse) {
|
||||
<< "Cannot land if kWriteComputedOutput is true.";
|
||||
}
|
||||
|
||||
class RnnVadProbabilityParametrization
|
||||
: public ::testing::TestWithParam<AvailableCpuFeatures> {};
|
||||
|
||||
// Checks that the computed VAD probability for a test input sequence sampled at
|
||||
// 48 kHz is within tolerance.
|
||||
TEST(RnnVadTest, RnnVadProbabilityWithinTolerance) {
|
||||
TEST_P(RnnVadProbabilityParametrization, RnnVadProbabilityWithinTolerance) {
|
||||
// Init resampler, feature extractor and RNN.
|
||||
PushSincResampler decimator(kFrameSize10ms48kHz, kFrameSize10ms24kHz);
|
||||
FeaturesExtractor features_extractor;
|
||||
RnnBasedVad rnn_vad;
|
||||
const AvailableCpuFeatures cpu_features = GetParam();
|
||||
FeaturesExtractor features_extractor(cpu_features);
|
||||
RnnBasedVad rnn_vad(cpu_features);
|
||||
|
||||
// Init input samples and expected output readers.
|
||||
auto samples_reader = CreatePcmSamplesReader(kFrameSize10ms48kHz);
|
||||
@ -111,7 +116,7 @@ TEST(RnnVadTest, RnnVadProbabilityWithinTolerance) {
|
||||
// follows:
|
||||
// - on desktop: run the this unit test adding "--logs";
|
||||
// - on android: run the this unit test adding "--logcat-output-file".
|
||||
TEST(RnnVadTest, DISABLED_RnnVadPerformance) {
|
||||
TEST_P(RnnVadProbabilityParametrization, DISABLED_RnnVadPerformance) {
|
||||
// PCM samples reader and buffers.
|
||||
auto samples_reader = CreatePcmSamplesReader(kFrameSize10ms48kHz);
|
||||
const int num_frames = samples_reader.second;
|
||||
@ -127,9 +132,10 @@ TEST(RnnVadTest, DISABLED_RnnVadPerformance) {
|
||||
kFrameSize10ms24kHz);
|
||||
}
|
||||
// Initialize.
|
||||
FeaturesExtractor features_extractor;
|
||||
const AvailableCpuFeatures cpu_features = GetParam();
|
||||
FeaturesExtractor features_extractor(cpu_features);
|
||||
std::array<float, kFeatureVectorSize> feature_vector;
|
||||
RnnBasedVad rnn_vad;
|
||||
RnnBasedVad rnn_vad(cpu_features);
|
||||
constexpr int number_of_tests = 100;
|
||||
::webrtc::test::PerformanceTimer perf_timer(number_of_tests);
|
||||
for (int k = 0; k < number_of_tests; ++k) {
|
||||
@ -152,6 +158,27 @@ TEST(RnnVadTest, DISABLED_RnnVadPerformance) {
|
||||
perf_timer.GetDurationStandardDeviation());
|
||||
}
|
||||
|
||||
// Finds the relevant CPU features combinations to test.
|
||||
std::vector<AvailableCpuFeatures> GetCpuFeaturesToTest() {
|
||||
std::vector<AvailableCpuFeatures> v;
|
||||
v.push_back({/*sse2=*/false, /*avx2=*/false, /*neon=*/false});
|
||||
AvailableCpuFeatures available = GetAvailableCpuFeatures();
|
||||
if (available.sse2) {
|
||||
AvailableCpuFeatures features(
|
||||
{/*sse2=*/true, /*avx2=*/false, /*neon=*/false});
|
||||
v.push_back(features);
|
||||
}
|
||||
return v;
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
RnnVadTest,
|
||||
RnnVadProbabilityParametrization,
|
||||
::testing::ValuesIn(GetCpuFeaturesToTest()),
|
||||
[](const ::testing::TestParamInfo<AvailableCpuFeatures>& info) {
|
||||
return info.param.ToString();
|
||||
});
|
||||
|
||||
} // namespace test
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
||||
@ -15,8 +15,6 @@
|
||||
|
||||
#include "rtc_base/checks.h"
|
||||
#include "rtc_base/numerics/safe_compare.h"
|
||||
#include "rtc_base/system/arch.h"
|
||||
#include "system_wrappers/include/cpu_features_wrapper.h"
|
||||
#include "test/gtest.h"
|
||||
#include "test/testsupport/file_utils.h"
|
||||
|
||||
@ -111,25 +109,6 @@ PitchTestData::GetPitchBufAutoCorrCoeffsView() const {
|
||||
kNumPitchBufAutoCorrCoeffs};
|
||||
}
|
||||
|
||||
bool IsOptimizationAvailable(Optimization optimization) {
|
||||
switch (optimization) {
|
||||
case Optimization::kSse2:
|
||||
#if defined(WEBRTC_ARCH_X86_FAMILY)
|
||||
return GetCPUInfo(kSSE2) != 0;
|
||||
#else
|
||||
return false;
|
||||
#endif
|
||||
case Optimization::kNeon:
|
||||
#if defined(WEBRTC_HAS_NEON)
|
||||
return true;
|
||||
#else
|
||||
return false;
|
||||
#endif
|
||||
case Optimization::kNone:
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
||||
@ -152,9 +152,6 @@ class PitchTestData {
|
||||
std::array<float, kPitchTestDataSize> test_data_;
|
||||
};
|
||||
|
||||
// Returns true if the given optimization is available.
|
||||
bool IsOptimizationAvailable(Optimization optimization);
|
||||
|
||||
} // namespace test
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
||||
@ -32,7 +32,8 @@ using VoiceActivityDetector = VadLevelAnalyzer::VoiceActivityDetector;
|
||||
// Computes the speech probability on the first channel.
|
||||
class Vad : public VoiceActivityDetector {
|
||||
public:
|
||||
Vad() = default;
|
||||
explicit Vad(const AvailableCpuFeatures& cpu_features)
|
||||
: features_extractor_(cpu_features), rnn_vad_(cpu_features) {}
|
||||
Vad(const Vad&) = delete;
|
||||
Vad& operator=(const Vad&) = delete;
|
||||
~Vad() = default;
|
||||
@ -80,10 +81,12 @@ float SmoothedVadProbability(float p_old, float p_new, float attack) {
|
||||
|
||||
VadLevelAnalyzer::VadLevelAnalyzer()
|
||||
: VadLevelAnalyzer(kDefaultSmoothedVadProbabilityAttack,
|
||||
std::make_unique<Vad>()) {}
|
||||
GetAvailableCpuFeatures()) {}
|
||||
|
||||
VadLevelAnalyzer::VadLevelAnalyzer(float vad_probability_attack)
|
||||
: VadLevelAnalyzer(vad_probability_attack, std::make_unique<Vad>()) {}
|
||||
VadLevelAnalyzer::VadLevelAnalyzer(float vad_probability_attack,
|
||||
const AvailableCpuFeatures& cpu_features)
|
||||
: VadLevelAnalyzer(vad_probability_attack,
|
||||
std::make_unique<Vad>(cpu_features)) {}
|
||||
|
||||
VadLevelAnalyzer::VadLevelAnalyzer(float vad_probability_attack,
|
||||
std::unique_ptr<VoiceActivityDetector> vad)
|
||||
|
||||
@ -13,6 +13,7 @@
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "modules/audio_processing/agc2/cpu_features.h"
|
||||
#include "modules/audio_processing/include/audio_frame_view.h"
|
||||
|
||||
namespace webrtc {
|
||||
@ -36,7 +37,8 @@ class VadLevelAnalyzer {
|
||||
|
||||
// Ctor. Uses the default VAD.
|
||||
VadLevelAnalyzer();
|
||||
explicit VadLevelAnalyzer(float vad_probability_attack);
|
||||
VadLevelAnalyzer(float vad_probability_attack,
|
||||
const AvailableCpuFeatures& cpu_features);
|
||||
// Ctor. Uses a custom `vad`.
|
||||
VadLevelAnalyzer(float vad_probability_attack,
|
||||
std::unique_ptr<VoiceActivityDetector> vad);
|
||||
|
||||
@ -350,10 +350,10 @@ class RTC_EXPORT AudioProcessing : public rtc::RefCountInterface {
|
||||
|
||||
enum LevelEstimator { kRms, kPeak };
|
||||
bool enabled = false;
|
||||
struct {
|
||||
struct FixedDigital {
|
||||
float gain_db = 0.f;
|
||||
} fixed_digital;
|
||||
struct {
|
||||
struct AdaptiveDigital {
|
||||
bool enabled = false;
|
||||
float vad_probability_attack = 1.f;
|
||||
LevelEstimator level_estimator = kRms;
|
||||
@ -365,6 +365,7 @@ class RTC_EXPORT AudioProcessing : public rtc::RefCountInterface {
|
||||
int gain_applier_adjacent_speech_frames_threshold = 1;
|
||||
float max_gain_change_db_per_second = 3.f;
|
||||
float max_output_noise_level_dbfs = -50.f;
|
||||
bool avx2_allowed = true;
|
||||
} adaptive_digital;
|
||||
} gain_controller2;
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user