APM Transient Suppressor (TS): initialization params in ctor

More robust API option that allows to fully initialize TS when created.

Bug: webrtc:13663
Change-Id: I42c38612ef772eb6d0bbde49d04ea39332a0e3c7
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/255821
Reviewed-by: Sam Zackrisson <saza@webrtc.org>
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#36490}
This commit is contained in:
Alessio Bazzica 2022-04-08 09:54:27 +02:00 committed by WebRTC LUCI CQ
parent 093ce288fd
commit 080006b42f
7 changed files with 33 additions and 16 deletions

View File

@ -1705,15 +1705,17 @@ void AudioProcessingImpl::InitializeTransientSuppressor() {
// Attempt to create a transient suppressor, if one is not already created.
if (!submodules_.transient_suppressor) {
submodules_.transient_suppressor = CreateTransientSuppressor(
submodule_creation_overrides_, transient_suppressor_vad_mode_);
}
if (submodules_.transient_suppressor) {
submodule_creation_overrides_, transient_suppressor_vad_mode_,
proc_fullband_sample_rate_hz(), capture_nonlocked_.split_rate,
num_proc_channels());
if (!submodules_.transient_suppressor) {
RTC_LOG(LS_WARNING)
<< "No transient suppressor created (probably disabled)";
}
} else {
submodules_.transient_suppressor->Initialize(
proc_fullband_sample_rate_hz(), capture_nonlocked_.split_rate,
num_proc_channels());
} else {
RTC_LOG(LS_WARNING)
<< "No transient suppressor created (probably disabled)";
}
} else {
submodules_.transient_suppressor.reset();

View File

@ -18,14 +18,18 @@ namespace webrtc {
std::unique_ptr<TransientSuppressor> CreateTransientSuppressor(
const ApmSubmoduleCreationOverrides& overrides,
TransientSuppressor::VadMode vad_mode) {
TransientSuppressor::VadMode vad_mode,
int sample_rate_hz,
int detection_rate_hz,
int num_channels) {
#ifdef WEBRTC_EXCLUDE_TRANSIENT_SUPPRESSOR
return nullptr;
#else
if (overrides.transient_suppression) {
return nullptr;
}
return std::make_unique<TransientSuppressorImpl>(vad_mode);
return std::make_unique<TransientSuppressorImpl>(
vad_mode, sample_rate_hz, detection_rate_hz, num_channels);
#endif
}

View File

@ -32,7 +32,10 @@ struct ApmSubmoduleCreationOverrides {
// * The corresponding override in `overrides` is enabled.
std::unique_ptr<TransientSuppressor> CreateTransientSuppressor(
const ApmSubmoduleCreationOverrides& overrides,
TransientSuppressor::VadMode vad_mode);
TransientSuppressor::VadMode vad_mode,
int sample_rate_hz,
int detection_rate_hz,
int num_channels);
} // namespace webrtc

View File

@ -166,9 +166,10 @@ void void_main() {
Agc agc;
TransientSuppressorImpl suppressor(TransientSuppressor::VadMode::kDefault);
suppressor.Initialize(absl::GetFlag(FLAGS_sample_rate_hz), detection_rate_hz,
absl::GetFlag(FLAGS_num_channels));
TransientSuppressorImpl suppressor(TransientSuppressor::VadMode::kDefault,
absl::GetFlag(FLAGS_sample_rate_hz),
detection_rate_hz,
absl::GetFlag(FLAGS_num_channels));
const size_t audio_buffer_size = absl::GetFlag(FLAGS_chunk_size_ms) *
absl::GetFlag(FLAGS_sample_rate_hz) / 1000;

View File

@ -57,7 +57,10 @@ std::string GetVadModeLabel(TransientSuppressor::VadMode vad_mode) {
} // namespace
TransientSuppressorImpl::TransientSuppressorImpl(VadMode vad_mode)
TransientSuppressorImpl::TransientSuppressorImpl(VadMode vad_mode,
int sample_rate_hz,
int detector_rate_hz,
int num_channels)
: vad_mode_(vad_mode),
analyzed_audio_is_silent_(false),
data_length_(0),
@ -77,6 +80,7 @@ TransientSuppressorImpl::TransientSuppressorImpl(VadMode vad_mode)
seed_(182),
using_reference_(false) {
RTC_LOG(LS_INFO) << "VAD mode: " << GetVadModeLabel(vad_mode_);
Initialize(sample_rate_hz, detector_rate_hz, num_channels);
}
TransientSuppressorImpl::~TransientSuppressorImpl() {}

View File

@ -27,7 +27,10 @@ class TransientDetector;
// restoration algorithm that attenuates unexpected spikes in the spectrum.
class TransientSuppressorImpl : public TransientSuppressor {
public:
explicit TransientSuppressorImpl(VadMode vad_mode);
TransientSuppressorImpl(VadMode vad_mode,
int sample_rate_hz,
int detector_rate_hz,
int num_channels);
~TransientSuppressorImpl() override;
int Initialize(int sample_rate_hz,

View File

@ -23,8 +23,8 @@ TEST_P(TransientSuppressorImplTest,
TypingDetectionLogicWorksAsExpectedForMono) {
static const int kNumChannels = 1;
TransientSuppressorImpl ts(GetParam());
ts.Initialize(ts::kSampleRate16kHz, ts::kSampleRate16kHz, kNumChannels);
TransientSuppressorImpl ts(GetParam(), ts::kSampleRate16kHz,
ts::kSampleRate16kHz, kNumChannels);
// Each key-press enables detection.
EXPECT_FALSE(ts.detection_enabled_);