Transient Suppressor (TS): add alternative VAD modes

It is now required to specify which VAD is used to compute the speech
probability passed when `TransientSuppressor::Suppress()` is called.
In this way, it is possible to adapt parameters and/or logic of a
`TransientSuppressor` implementation to the behavior of the used
VAD. This CL also adds a "no VAD" mode option, which ignores the speech
probability argument passed when `Suppress()` and always applies mild
suppression to preserve transparency.

Finally, this CL adds a field trial to choose which VAD is used by
APM for transient suppression. Wiring the RNN VAD to TS will be done
in a follow-up CL.

Bug: webrtc:13663
Change-Id: I21ed49f91875a4ee0f04db97ea87c0dbc3db7f8a
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/250962
Reviewed-by: Hanna Silen <silen@webrtc.org>
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#36485}
This commit is contained in:
Alessio Bazzica 2022-03-18 12:39:00 +01:00 committed by WebRTC LUCI CQ
parent 9190fef84d
commit efbe3af366
11 changed files with 133 additions and 43 deletions

View File

@ -205,7 +205,10 @@ rtc_library("audio_processing") {
"transient:transient_suppressor_api",
"vad",
]
absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ]
absl_deps = [
"//third_party/abseil-cpp/absl/strings",
"//third_party/abseil-cpp/absl/types:optional",
]
deps += [
"../../common_audio",

View File

@ -17,6 +17,7 @@
#include <type_traits>
#include <utility>
#include "absl/strings/match.h"
#include "absl/types/optional.h"
#include "api/array_view.h"
#include "api/audio/audio_frame.h"
@ -67,6 +68,29 @@ bool UseSetupSpecificDefaultAec3Congfig() {
"WebRTC-Aec3SetupSpecificDefaultConfigDefaultsKillSwitch");
}
// If the "WebRTC-Audio-TransientSuppressorVadMode" field trial is unspecified,
// returns `TransientSuppressor::VadMode::kDefault`, otherwise parses the field
// trial and returns the specified mode:
// - WebRTC-Audio-TransientSuppressorVadMode/Enabled-Default returns `kDefault`;
// - WebRTC-Audio-TransientSuppressorVadMode/Enabled-RnnVad returns `kRnnVad`;
// - WebRTC-Audio-TransientSuppressorVadMode/Enabled-NoVad returns `kNoVad`.
TransientSuppressor::VadMode GetTransientSuppressorVadMode() {
constexpr char kFieldTrial[] = "WebRTC-Audio-TransientSuppressorVadMode";
std::string full_name = webrtc::field_trial::FindFullName(kFieldTrial);
if (full_name.empty() || absl::EndsWith(full_name, "-Default")) {
return TransientSuppressor::VadMode::kDefault;
}
if (absl::EndsWith(full_name, "-RnnVad")) {
return TransientSuppressor::VadMode::kRnnVad;
}
if (absl::EndsWith(full_name, "-NoVad")) {
return TransientSuppressor::VadMode::kNoVad;
}
// Fallback to default.
RTC_LOG(LS_WARNING) << "Invalid parameter for " << kFieldTrial;
return TransientSuppressor::VadMode::kDefault;
}
// Identify the native processing rate that best handles a sample rate.
int SuitableProcessRate(int minimum_rate,
int max_splitting_rate,
@ -241,6 +265,7 @@ AudioProcessingImpl::AudioProcessingImpl(
UseSetupSpecificDefaultAec3Congfig()),
use_denormal_disabler_(
!field_trial::IsEnabled("WebRTC-ApmDenormalDisablerKillSwitch")),
transient_suppressor_vad_mode_(GetTransientSuppressorVadMode()),
capture_runtime_settings_(RuntimeSettingQueueSize()),
render_runtime_settings_(RuntimeSettingQueueSize()),
capture_runtime_settings_enqueuer_(&capture_runtime_settings_),
@ -1244,14 +1269,21 @@ int AudioProcessingImpl::ProcessCaptureStreamLocked() {
capture_buffer->num_frames()));
}
// TODO(aluebs): Investigate if the transient suppression placement should
// be before or after the AGC.
if (submodules_.transient_suppressor) {
float voice_probability =
submodules_.agc_manager.get()
? submodules_.agc_manager->voice_probability()
: 1.f;
float voice_probability = 1.0f;
switch (transient_suppressor_vad_mode_) {
case TransientSuppressor::VadMode::kDefault:
if (submodules_.agc_manager) {
voice_probability = submodules_.agc_manager->voice_probability();
}
break;
case TransientSuppressor::VadMode::kRnnVad:
// TODO(bugs.webrtc.org/13663): Use RNN VAD.
break;
case TransientSuppressor::VadMode::kNoVad:
// The transient suppressor will ignore `voice_probability`.
break;
}
submodules_.transient_suppressor->Suppress(
capture_buffer->channels()[0], capture_buffer->num_frames(),
capture_buffer->num_channels(),
@ -1672,8 +1704,8 @@ void AudioProcessingImpl::InitializeTransientSuppressor() {
!constants_.transient_suppressor_forced_off) {
// Attempt to create a transient suppressor, if one is not already created.
if (!submodules_.transient_suppressor) {
submodules_.transient_suppressor =
CreateTransientSuppressor(submodule_creation_overrides_);
submodules_.transient_suppressor = CreateTransientSuppressor(
submodule_creation_overrides_, transient_suppressor_vad_mode_);
}
if (submodules_.transient_suppressor) {
submodules_.transient_suppressor->Initialize(

View File

@ -185,6 +185,8 @@ class AudioProcessingImpl : public AudioProcessing {
const bool use_denormal_disabler_;
const TransientSuppressor::VadMode transient_suppressor_vad_mode_;
SwapQueue<RuntimeSetting> capture_runtime_settings_;
SwapQueue<RuntimeSetting> render_runtime_settings_;

View File

@ -17,14 +17,15 @@
namespace webrtc {
std::unique_ptr<TransientSuppressor> CreateTransientSuppressor(
const ApmSubmoduleCreationOverrides& overrides) {
const ApmSubmoduleCreationOverrides& overrides,
TransientSuppressor::VadMode vad_mode) {
#ifdef WEBRTC_EXCLUDE_TRANSIENT_SUPPRESSOR
return nullptr;
#else
if (overrides.transient_suppression) {
return nullptr;
}
return std::make_unique<TransientSuppressorImpl>();
return std::make_unique<TransientSuppressorImpl>(vad_mode);
#endif
}

View File

@ -31,7 +31,8 @@ struct ApmSubmoduleCreationOverrides {
// * WEBRTC_EXCLUDE_TRANSIENT_SUPPRESSOR is defined
// * The corresponding override in `overrides` is enabled.
std::unique_ptr<TransientSuppressor> CreateTransientSuppressor(
const ApmSubmoduleCreationOverrides& overrides);
const ApmSubmoduleCreationOverrides& overrides,
TransientSuppressor::VadMode vad_mode);
} // namespace webrtc

View File

@ -73,6 +73,7 @@ if (rtc_include_tests) {
"transient_suppression_test.cc",
]
deps = [
":transient_suppressor_api",
":transient_suppressor_impl",
"..:audio_processing",
"../../../common_audio",
@ -103,6 +104,7 @@ if (rtc_include_tests) {
"wpd_tree_unittest.cc",
]
deps = [
":transient_suppressor_api",
":transient_suppressor_impl",
"../../../rtc_base:stringutils",
"../../../rtc_base/system:file_wrapper",

View File

@ -20,6 +20,7 @@
#include "absl/flags/parse.h"
#include "common_audio/include/audio_util.h"
#include "modules/audio_processing/agc/agc.h"
#include "modules/audio_processing/transient/transient_suppressor.h"
#include "modules/audio_processing/transient/transient_suppressor_impl.h"
#include "test/gtest.h"
#include "test/testsupport/file_utils.h"
@ -165,7 +166,7 @@ void void_main() {
Agc agc;
TransientSuppressorImpl suppressor;
TransientSuppressorImpl suppressor(TransientSuppressor::VadMode::kDefault);
suppressor.Initialize(absl::GetFlag(FLAGS_sample_rate_hz), detection_rate_hz,
absl::GetFlag(FLAGS_num_channels));

View File

@ -13,6 +13,7 @@
#include <stddef.h>
#include <stdint.h>
#include <memory>
namespace webrtc {
@ -21,6 +22,21 @@ namespace webrtc {
// restoration algorithm that attenuates unexpected spikes in the spectrum.
class TransientSuppressor {
public:
// Type of VAD used by the caller to compute the `voice_probability` argument
// `Suppress()`.
enum class VadMode {
// By default, `TransientSuppressor` assumes that `voice_probability` is
// computed by `AgcManagerDirect`.
kDefault = 0,
// Use this mode when `TransientSuppressor` must assume that
// `voice_probability` is computed by the RNN VAD.
kRnnVad,
// Use this mode to let `TransientSuppressor::Suppressor()` ignore
// `voice_probability` and behave as if voice information is unavailable
// (regardless of the passed value).
kNoVad,
};
virtual ~TransientSuppressor() {}
virtual int Initialize(int sample_rate_hz,

View File

@ -18,6 +18,7 @@
#include <deque>
#include <limits>
#include <set>
#include <string>
#include "common_audio/include/audio_util.h"
#include "common_audio/signal_processing/include/signal_processing_library.h"
@ -32,7 +33,6 @@
namespace webrtc {
static const float kMeanIIRCoefficient = 0.5f;
static const float kVoiceThreshold = 0.02f;
// TODO(aluebs): Check if these values work also for 48kHz.
static const size_t kMinVoiceBin = 3;
@ -44,10 +44,23 @@ float ComplexMagnitude(float a, float b) {
return std::abs(a) + std::abs(b);
}
std::string GetVadModeLabel(TransientSuppressor::VadMode vad_mode) {
switch (vad_mode) {
case TransientSuppressor::VadMode::kDefault:
return "default";
case TransientSuppressor::VadMode::kRnnVad:
return "RNN VAD";
case TransientSuppressor::VadMode::kNoVad:
return "no VAD";
}
}
} // namespace
TransientSuppressorImpl::TransientSuppressorImpl()
: data_length_(0),
TransientSuppressorImpl::TransientSuppressorImpl(VadMode vad_mode)
: vad_mode_(vad_mode),
analyzed_audio_is_silent_(false),
data_length_(0),
detection_length_(0),
analysis_length_(0),
buffer_delay_(0),
@ -62,7 +75,9 @@ TransientSuppressorImpl::TransientSuppressorImpl()
use_hard_restoration_(false),
chunks_since_voice_change_(0),
seed_(182),
using_reference_(false) {}
using_reference_(false) {
RTC_LOG(LS_INFO) << "VAD mode: " << GetVadModeLabel(vad_mode_);
}
TransientSuppressorImpl::~TransientSuppressorImpl() {}
@ -304,16 +319,34 @@ void TransientSuppressorImpl::UpdateKeypress(bool key_pressed) {
}
void TransientSuppressorImpl::UpdateRestoration(float voice_probability) {
const int kHardRestorationOffsetDelay = 3;
const int kHardRestorationOnsetDelay = 80;
bool not_voiced = voice_probability < kVoiceThreshold;
bool not_voiced;
switch (vad_mode_) {
case TransientSuppressor::VadMode::kDefault: {
constexpr float kVoiceThreshold = 0.02f;
not_voiced = voice_probability < kVoiceThreshold;
break;
}
case TransientSuppressor::VadMode::kRnnVad: {
constexpr float kVoiceThreshold = 0.7f;
not_voiced = voice_probability < kVoiceThreshold;
break;
}
case TransientSuppressor::VadMode::kNoVad:
// Always assume that voice is detected.
not_voiced = false;
break;
}
if (not_voiced == use_hard_restoration_) {
chunks_since_voice_change_ = 0;
} else {
++chunks_since_voice_change_;
// Number of 10 ms frames to wait to transition to and from hard
// restoration.
constexpr int kHardRestorationOffsetDelay = 3;
constexpr int kHardRestorationOnsetDelay = 80;
if ((use_hard_restoration_ &&
chunks_since_voice_change_ > kHardRestorationOffsetDelay) ||
(!use_hard_restoration_ &&

View File

@ -27,30 +27,13 @@ class TransientDetector;
// restoration algorithm that attenuates unexpected spikes in the spectrum.
class TransientSuppressorImpl : public TransientSuppressor {
public:
TransientSuppressorImpl();
explicit TransientSuppressorImpl(VadMode vad_mode);
~TransientSuppressorImpl() override;
int Initialize(int sample_rate_hz,
int detector_rate_hz,
int num_channels) override;
// Processes a `data` chunk, and returns it with keystrokes suppressed from
// it. The float format is assumed to be int16 ranged. If there are more than
// one channel, the chunks are concatenated one after the other in `data`.
// `data_length` must be equal to `data_length_`.
// `num_channels` must be equal to `num_channels_`.
// A sub-band, ideally the higher, can be used as `detection_data`. If it is
// NULL, `data` is used for the detection too. The `detection_data` is always
// assumed mono.
// If a reference signal (e.g. keyboard microphone) is available, it can be
// passed in as `reference_data`. It is assumed mono and must have the same
// length as `data`. NULL is accepted if unavailable.
// This suppressor performs better if voice information is available.
// `voice_probability` is the probability of voice being present in this chunk
// of audio. If voice information is not available, `voice_probability` must
// always be set to 1.
// `key_pressed` determines if a key was pressed on this audio chunk.
// Returns 0 on success and -1 otherwise.
int Suppress(float* data,
size_t data_length,
int num_channels,
@ -74,8 +57,12 @@ class TransientSuppressorImpl : public TransientSuppressor {
void HardRestoration(float* spectral_mean);
void SoftRestoration(float* spectral_mean);
const VadMode vad_mode_;
std::unique_ptr<TransientDetector> detector_;
bool analyzed_audio_is_silent_;
size_t data_length_;
size_t detection_length_;
size_t analysis_length_;

View File

@ -8,17 +8,22 @@
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/transient/transient_suppressor_impl.h"
#include "modules/audio_processing/transient/transient_suppressor.h"
#include "modules/audio_processing/transient/common.h"
#include "modules/audio_processing/transient/transient_suppressor_impl.h"
#include "test/gtest.h"
namespace webrtc {
TEST(TransientSuppressorImplTest, TypingDetectionLogicWorksAsExpectedForMono) {
class TransientSuppressorImplTest
: public ::testing::TestWithParam<TransientSuppressor::VadMode> {};
TEST_P(TransientSuppressorImplTest,
TypingDetectionLogicWorksAsExpectedForMono) {
static const int kNumChannels = 1;
TransientSuppressorImpl ts;
TransientSuppressorImpl ts(GetParam());
ts.Initialize(ts::kSampleRate16kHz, ts::kSampleRate16kHz, kNumChannels);
// Each key-press enables detection.
@ -82,4 +87,11 @@ TEST(TransientSuppressorImplTest, TypingDetectionLogicWorksAsExpectedForMono) {
}
}
INSTANTIATE_TEST_SUITE_P(
,
TransientSuppressorImplTest,
::testing::Values(TransientSuppressor::VadMode::kDefault,
TransientSuppressor::VadMode::kRnnVad,
TransientSuppressor::VadMode::kNoVad));
} // namespace webrtc