Transient Suppressor (TS): add alternative VAD modes
It is now required to specify which VAD is used to compute the speech probability passed when `TransientSuppressor::Suppress()` is called. In this way, it is possible to adapt parameters and/or logic of a `TransientSuppressor` implementation to the behavior of the used VAD. This CL also adds a "no VAD" mode option, which ignores the speech probability argument passed when `Suppress()` and always applies mild suppression to preserve transparency. Finally, this CL adds a field trial to choose which VAD is used by APM for transient suppression. Wiring the RNN VAD to TS will be done in a follow-up CL. Bug: webrtc:13663 Change-Id: I21ed49f91875a4ee0f04db97ea87c0dbc3db7f8a Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/250962 Reviewed-by: Hanna Silen <silen@webrtc.org> Commit-Queue: Alessio Bazzica <alessiob@webrtc.org> Cr-Commit-Position: refs/heads/main@{#36485}
This commit is contained in:
parent
9190fef84d
commit
efbe3af366
@ -205,7 +205,10 @@ rtc_library("audio_processing") {
|
||||
"transient:transient_suppressor_api",
|
||||
"vad",
|
||||
]
|
||||
absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ]
|
||||
absl_deps = [
|
||||
"//third_party/abseil-cpp/absl/strings",
|
||||
"//third_party/abseil-cpp/absl/types:optional",
|
||||
]
|
||||
|
||||
deps += [
|
||||
"../../common_audio",
|
||||
|
||||
@ -17,6 +17,7 @@
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
|
||||
#include "absl/strings/match.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "api/array_view.h"
|
||||
#include "api/audio/audio_frame.h"
|
||||
@ -67,6 +68,29 @@ bool UseSetupSpecificDefaultAec3Congfig() {
|
||||
"WebRTC-Aec3SetupSpecificDefaultConfigDefaultsKillSwitch");
|
||||
}
|
||||
|
||||
// If the "WebRTC-Audio-TransientSuppressorVadMode" field trial is unspecified,
|
||||
// returns `TransientSuppressor::VadMode::kDefault`, otherwise parses the field
|
||||
// trial and returns the specified mode:
|
||||
// - WebRTC-Audio-TransientSuppressorVadMode/Enabled-Default returns `kDefault`;
|
||||
// - WebRTC-Audio-TransientSuppressorVadMode/Enabled-RnnVad returns `kRnnVad`;
|
||||
// - WebRTC-Audio-TransientSuppressorVadMode/Enabled-NoVad returns `kNoVad`.
|
||||
TransientSuppressor::VadMode GetTransientSuppressorVadMode() {
|
||||
constexpr char kFieldTrial[] = "WebRTC-Audio-TransientSuppressorVadMode";
|
||||
std::string full_name = webrtc::field_trial::FindFullName(kFieldTrial);
|
||||
if (full_name.empty() || absl::EndsWith(full_name, "-Default")) {
|
||||
return TransientSuppressor::VadMode::kDefault;
|
||||
}
|
||||
if (absl::EndsWith(full_name, "-RnnVad")) {
|
||||
return TransientSuppressor::VadMode::kRnnVad;
|
||||
}
|
||||
if (absl::EndsWith(full_name, "-NoVad")) {
|
||||
return TransientSuppressor::VadMode::kNoVad;
|
||||
}
|
||||
// Fallback to default.
|
||||
RTC_LOG(LS_WARNING) << "Invalid parameter for " << kFieldTrial;
|
||||
return TransientSuppressor::VadMode::kDefault;
|
||||
}
|
||||
|
||||
// Identify the native processing rate that best handles a sample rate.
|
||||
int SuitableProcessRate(int minimum_rate,
|
||||
int max_splitting_rate,
|
||||
@ -241,6 +265,7 @@ AudioProcessingImpl::AudioProcessingImpl(
|
||||
UseSetupSpecificDefaultAec3Congfig()),
|
||||
use_denormal_disabler_(
|
||||
!field_trial::IsEnabled("WebRTC-ApmDenormalDisablerKillSwitch")),
|
||||
transient_suppressor_vad_mode_(GetTransientSuppressorVadMode()),
|
||||
capture_runtime_settings_(RuntimeSettingQueueSize()),
|
||||
render_runtime_settings_(RuntimeSettingQueueSize()),
|
||||
capture_runtime_settings_enqueuer_(&capture_runtime_settings_),
|
||||
@ -1244,14 +1269,21 @@ int AudioProcessingImpl::ProcessCaptureStreamLocked() {
|
||||
capture_buffer->num_frames()));
|
||||
}
|
||||
|
||||
// TODO(aluebs): Investigate if the transient suppression placement should
|
||||
// be before or after the AGC.
|
||||
if (submodules_.transient_suppressor) {
|
||||
float voice_probability =
|
||||
submodules_.agc_manager.get()
|
||||
? submodules_.agc_manager->voice_probability()
|
||||
: 1.f;
|
||||
|
||||
float voice_probability = 1.0f;
|
||||
switch (transient_suppressor_vad_mode_) {
|
||||
case TransientSuppressor::VadMode::kDefault:
|
||||
if (submodules_.agc_manager) {
|
||||
voice_probability = submodules_.agc_manager->voice_probability();
|
||||
}
|
||||
break;
|
||||
case TransientSuppressor::VadMode::kRnnVad:
|
||||
// TODO(bugs.webrtc.org/13663): Use RNN VAD.
|
||||
break;
|
||||
case TransientSuppressor::VadMode::kNoVad:
|
||||
// The transient suppressor will ignore `voice_probability`.
|
||||
break;
|
||||
}
|
||||
submodules_.transient_suppressor->Suppress(
|
||||
capture_buffer->channels()[0], capture_buffer->num_frames(),
|
||||
capture_buffer->num_channels(),
|
||||
@ -1672,8 +1704,8 @@ void AudioProcessingImpl::InitializeTransientSuppressor() {
|
||||
!constants_.transient_suppressor_forced_off) {
|
||||
// Attempt to create a transient suppressor, if one is not already created.
|
||||
if (!submodules_.transient_suppressor) {
|
||||
submodules_.transient_suppressor =
|
||||
CreateTransientSuppressor(submodule_creation_overrides_);
|
||||
submodules_.transient_suppressor = CreateTransientSuppressor(
|
||||
submodule_creation_overrides_, transient_suppressor_vad_mode_);
|
||||
}
|
||||
if (submodules_.transient_suppressor) {
|
||||
submodules_.transient_suppressor->Initialize(
|
||||
|
||||
@ -185,6 +185,8 @@ class AudioProcessingImpl : public AudioProcessing {
|
||||
|
||||
const bool use_denormal_disabler_;
|
||||
|
||||
const TransientSuppressor::VadMode transient_suppressor_vad_mode_;
|
||||
|
||||
SwapQueue<RuntimeSetting> capture_runtime_settings_;
|
||||
SwapQueue<RuntimeSetting> render_runtime_settings_;
|
||||
|
||||
|
||||
@ -17,14 +17,15 @@
|
||||
namespace webrtc {
|
||||
|
||||
std::unique_ptr<TransientSuppressor> CreateTransientSuppressor(
|
||||
const ApmSubmoduleCreationOverrides& overrides) {
|
||||
const ApmSubmoduleCreationOverrides& overrides,
|
||||
TransientSuppressor::VadMode vad_mode) {
|
||||
#ifdef WEBRTC_EXCLUDE_TRANSIENT_SUPPRESSOR
|
||||
return nullptr;
|
||||
#else
|
||||
if (overrides.transient_suppression) {
|
||||
return nullptr;
|
||||
}
|
||||
return std::make_unique<TransientSuppressorImpl>();
|
||||
return std::make_unique<TransientSuppressorImpl>(vad_mode);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
@ -31,7 +31,8 @@ struct ApmSubmoduleCreationOverrides {
|
||||
// * WEBRTC_EXCLUDE_TRANSIENT_SUPPRESSOR is defined
|
||||
// * The corresponding override in `overrides` is enabled.
|
||||
std::unique_ptr<TransientSuppressor> CreateTransientSuppressor(
|
||||
const ApmSubmoduleCreationOverrides& overrides);
|
||||
const ApmSubmoduleCreationOverrides& overrides,
|
||||
TransientSuppressor::VadMode vad_mode);
|
||||
|
||||
} // namespace webrtc
|
||||
|
||||
|
||||
@ -73,6 +73,7 @@ if (rtc_include_tests) {
|
||||
"transient_suppression_test.cc",
|
||||
]
|
||||
deps = [
|
||||
":transient_suppressor_api",
|
||||
":transient_suppressor_impl",
|
||||
"..:audio_processing",
|
||||
"../../../common_audio",
|
||||
@ -103,6 +104,7 @@ if (rtc_include_tests) {
|
||||
"wpd_tree_unittest.cc",
|
||||
]
|
||||
deps = [
|
||||
":transient_suppressor_api",
|
||||
":transient_suppressor_impl",
|
||||
"../../../rtc_base:stringutils",
|
||||
"../../../rtc_base/system:file_wrapper",
|
||||
|
||||
@ -20,6 +20,7 @@
|
||||
#include "absl/flags/parse.h"
|
||||
#include "common_audio/include/audio_util.h"
|
||||
#include "modules/audio_processing/agc/agc.h"
|
||||
#include "modules/audio_processing/transient/transient_suppressor.h"
|
||||
#include "modules/audio_processing/transient/transient_suppressor_impl.h"
|
||||
#include "test/gtest.h"
|
||||
#include "test/testsupport/file_utils.h"
|
||||
@ -165,7 +166,7 @@ void void_main() {
|
||||
|
||||
Agc agc;
|
||||
|
||||
TransientSuppressorImpl suppressor;
|
||||
TransientSuppressorImpl suppressor(TransientSuppressor::VadMode::kDefault);
|
||||
suppressor.Initialize(absl::GetFlag(FLAGS_sample_rate_hz), detection_rate_hz,
|
||||
absl::GetFlag(FLAGS_num_channels));
|
||||
|
||||
|
||||
@ -13,6 +13,7 @@
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace webrtc {
|
||||
@ -21,6 +22,21 @@ namespace webrtc {
|
||||
// restoration algorithm that attenuates unexpected spikes in the spectrum.
|
||||
class TransientSuppressor {
|
||||
public:
|
||||
// Type of VAD used by the caller to compute the `voice_probability` argument
|
||||
// `Suppress()`.
|
||||
enum class VadMode {
|
||||
// By default, `TransientSuppressor` assumes that `voice_probability` is
|
||||
// computed by `AgcManagerDirect`.
|
||||
kDefault = 0,
|
||||
// Use this mode when `TransientSuppressor` must assume that
|
||||
// `voice_probability` is computed by the RNN VAD.
|
||||
kRnnVad,
|
||||
// Use this mode to let `TransientSuppressor::Suppressor()` ignore
|
||||
// `voice_probability` and behave as if voice information is unavailable
|
||||
// (regardless of the passed value).
|
||||
kNoVad,
|
||||
};
|
||||
|
||||
virtual ~TransientSuppressor() {}
|
||||
|
||||
virtual int Initialize(int sample_rate_hz,
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
#include <deque>
|
||||
#include <limits>
|
||||
#include <set>
|
||||
#include <string>
|
||||
|
||||
#include "common_audio/include/audio_util.h"
|
||||
#include "common_audio/signal_processing/include/signal_processing_library.h"
|
||||
@ -32,7 +33,6 @@
|
||||
namespace webrtc {
|
||||
|
||||
static const float kMeanIIRCoefficient = 0.5f;
|
||||
static const float kVoiceThreshold = 0.02f;
|
||||
|
||||
// TODO(aluebs): Check if these values work also for 48kHz.
|
||||
static const size_t kMinVoiceBin = 3;
|
||||
@ -44,10 +44,23 @@ float ComplexMagnitude(float a, float b) {
|
||||
return std::abs(a) + std::abs(b);
|
||||
}
|
||||
|
||||
std::string GetVadModeLabel(TransientSuppressor::VadMode vad_mode) {
|
||||
switch (vad_mode) {
|
||||
case TransientSuppressor::VadMode::kDefault:
|
||||
return "default";
|
||||
case TransientSuppressor::VadMode::kRnnVad:
|
||||
return "RNN VAD";
|
||||
case TransientSuppressor::VadMode::kNoVad:
|
||||
return "no VAD";
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
TransientSuppressorImpl::TransientSuppressorImpl()
|
||||
: data_length_(0),
|
||||
TransientSuppressorImpl::TransientSuppressorImpl(VadMode vad_mode)
|
||||
: vad_mode_(vad_mode),
|
||||
analyzed_audio_is_silent_(false),
|
||||
data_length_(0),
|
||||
detection_length_(0),
|
||||
analysis_length_(0),
|
||||
buffer_delay_(0),
|
||||
@ -62,7 +75,9 @@ TransientSuppressorImpl::TransientSuppressorImpl()
|
||||
use_hard_restoration_(false),
|
||||
chunks_since_voice_change_(0),
|
||||
seed_(182),
|
||||
using_reference_(false) {}
|
||||
using_reference_(false) {
|
||||
RTC_LOG(LS_INFO) << "VAD mode: " << GetVadModeLabel(vad_mode_);
|
||||
}
|
||||
|
||||
TransientSuppressorImpl::~TransientSuppressorImpl() {}
|
||||
|
||||
@ -304,16 +319,34 @@ void TransientSuppressorImpl::UpdateKeypress(bool key_pressed) {
|
||||
}
|
||||
|
||||
void TransientSuppressorImpl::UpdateRestoration(float voice_probability) {
|
||||
const int kHardRestorationOffsetDelay = 3;
|
||||
const int kHardRestorationOnsetDelay = 80;
|
||||
|
||||
bool not_voiced = voice_probability < kVoiceThreshold;
|
||||
bool not_voiced;
|
||||
switch (vad_mode_) {
|
||||
case TransientSuppressor::VadMode::kDefault: {
|
||||
constexpr float kVoiceThreshold = 0.02f;
|
||||
not_voiced = voice_probability < kVoiceThreshold;
|
||||
break;
|
||||
}
|
||||
case TransientSuppressor::VadMode::kRnnVad: {
|
||||
constexpr float kVoiceThreshold = 0.7f;
|
||||
not_voiced = voice_probability < kVoiceThreshold;
|
||||
break;
|
||||
}
|
||||
case TransientSuppressor::VadMode::kNoVad:
|
||||
// Always assume that voice is detected.
|
||||
not_voiced = false;
|
||||
break;
|
||||
}
|
||||
|
||||
if (not_voiced == use_hard_restoration_) {
|
||||
chunks_since_voice_change_ = 0;
|
||||
} else {
|
||||
++chunks_since_voice_change_;
|
||||
|
||||
// Number of 10 ms frames to wait to transition to and from hard
|
||||
// restoration.
|
||||
constexpr int kHardRestorationOffsetDelay = 3;
|
||||
constexpr int kHardRestorationOnsetDelay = 80;
|
||||
|
||||
if ((use_hard_restoration_ &&
|
||||
chunks_since_voice_change_ > kHardRestorationOffsetDelay) ||
|
||||
(!use_hard_restoration_ &&
|
||||
|
||||
@ -27,30 +27,13 @@ class TransientDetector;
|
||||
// restoration algorithm that attenuates unexpected spikes in the spectrum.
|
||||
class TransientSuppressorImpl : public TransientSuppressor {
|
||||
public:
|
||||
TransientSuppressorImpl();
|
||||
explicit TransientSuppressorImpl(VadMode vad_mode);
|
||||
~TransientSuppressorImpl() override;
|
||||
|
||||
int Initialize(int sample_rate_hz,
|
||||
int detector_rate_hz,
|
||||
int num_channels) override;
|
||||
|
||||
// Processes a `data` chunk, and returns it with keystrokes suppressed from
|
||||
// it. The float format is assumed to be int16 ranged. If there are more than
|
||||
// one channel, the chunks are concatenated one after the other in `data`.
|
||||
// `data_length` must be equal to `data_length_`.
|
||||
// `num_channels` must be equal to `num_channels_`.
|
||||
// A sub-band, ideally the higher, can be used as `detection_data`. If it is
|
||||
// NULL, `data` is used for the detection too. The `detection_data` is always
|
||||
// assumed mono.
|
||||
// If a reference signal (e.g. keyboard microphone) is available, it can be
|
||||
// passed in as `reference_data`. It is assumed mono and must have the same
|
||||
// length as `data`. NULL is accepted if unavailable.
|
||||
// This suppressor performs better if voice information is available.
|
||||
// `voice_probability` is the probability of voice being present in this chunk
|
||||
// of audio. If voice information is not available, `voice_probability` must
|
||||
// always be set to 1.
|
||||
// `key_pressed` determines if a key was pressed on this audio chunk.
|
||||
// Returns 0 on success and -1 otherwise.
|
||||
int Suppress(float* data,
|
||||
size_t data_length,
|
||||
int num_channels,
|
||||
@ -74,8 +57,12 @@ class TransientSuppressorImpl : public TransientSuppressor {
|
||||
void HardRestoration(float* spectral_mean);
|
||||
void SoftRestoration(float* spectral_mean);
|
||||
|
||||
const VadMode vad_mode_;
|
||||
|
||||
std::unique_ptr<TransientDetector> detector_;
|
||||
|
||||
bool analyzed_audio_is_silent_;
|
||||
|
||||
size_t data_length_;
|
||||
size_t detection_length_;
|
||||
size_t analysis_length_;
|
||||
|
||||
@ -8,17 +8,22 @@
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/transient/transient_suppressor_impl.h"
|
||||
#include "modules/audio_processing/transient/transient_suppressor.h"
|
||||
|
||||
#include "modules/audio_processing/transient/common.h"
|
||||
#include "modules/audio_processing/transient/transient_suppressor_impl.h"
|
||||
#include "test/gtest.h"
|
||||
|
||||
namespace webrtc {
|
||||
|
||||
TEST(TransientSuppressorImplTest, TypingDetectionLogicWorksAsExpectedForMono) {
|
||||
class TransientSuppressorImplTest
|
||||
: public ::testing::TestWithParam<TransientSuppressor::VadMode> {};
|
||||
|
||||
TEST_P(TransientSuppressorImplTest,
|
||||
TypingDetectionLogicWorksAsExpectedForMono) {
|
||||
static const int kNumChannels = 1;
|
||||
|
||||
TransientSuppressorImpl ts;
|
||||
TransientSuppressorImpl ts(GetParam());
|
||||
ts.Initialize(ts::kSampleRate16kHz, ts::kSampleRate16kHz, kNumChannels);
|
||||
|
||||
// Each key-press enables detection.
|
||||
@ -82,4 +87,11 @@ TEST(TransientSuppressorImplTest, TypingDetectionLogicWorksAsExpectedForMono) {
|
||||
}
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
,
|
||||
TransientSuppressorImplTest,
|
||||
::testing::Values(TransientSuppressor::VadMode::kDefault,
|
||||
TransientSuppressor::VadMode::kRnnVad,
|
||||
TransientSuppressor::VadMode::kNoVad));
|
||||
|
||||
} // namespace webrtc
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user