APM Transient Suppressor (TS): initialization params in ctor

More robust API option that allows to fully initialize TS when created.

Bug: webrtc:13663
Change-Id: I42c38612ef772eb6d0bbde49d04ea39332a0e3c7
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/255821
Reviewed-by: Sam Zackrisson <saza@webrtc.org>
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#36490}
This commit is contained in:
Alessio Bazzica 2022-04-08 09:54:27 +02:00 committed by WebRTC LUCI CQ
parent 093ce288fd
commit 080006b42f
7 changed files with 33 additions and 16 deletions

View File

@ -1705,15 +1705,17 @@ void AudioProcessingImpl::InitializeTransientSuppressor() {
// Attempt to create a transient suppressor, if one is not already created. // Attempt to create a transient suppressor, if one is not already created.
if (!submodules_.transient_suppressor) { if (!submodules_.transient_suppressor) {
submodules_.transient_suppressor = CreateTransientSuppressor( submodules_.transient_suppressor = CreateTransientSuppressor(
submodule_creation_overrides_, transient_suppressor_vad_mode_); submodule_creation_overrides_, transient_suppressor_vad_mode_,
} proc_fullband_sample_rate_hz(), capture_nonlocked_.split_rate,
if (submodules_.transient_suppressor) { num_proc_channels());
if (!submodules_.transient_suppressor) {
RTC_LOG(LS_WARNING)
<< "No transient suppressor created (probably disabled)";
}
} else {
submodules_.transient_suppressor->Initialize( submodules_.transient_suppressor->Initialize(
proc_fullband_sample_rate_hz(), capture_nonlocked_.split_rate, proc_fullband_sample_rate_hz(), capture_nonlocked_.split_rate,
num_proc_channels()); num_proc_channels());
} else {
RTC_LOG(LS_WARNING)
<< "No transient suppressor created (probably disabled)";
} }
} else { } else {
submodules_.transient_suppressor.reset(); submodules_.transient_suppressor.reset();

View File

@ -18,14 +18,18 @@ namespace webrtc {
std::unique_ptr<TransientSuppressor> CreateTransientSuppressor( std::unique_ptr<TransientSuppressor> CreateTransientSuppressor(
const ApmSubmoduleCreationOverrides& overrides, const ApmSubmoduleCreationOverrides& overrides,
TransientSuppressor::VadMode vad_mode) { TransientSuppressor::VadMode vad_mode,
int sample_rate_hz,
int detection_rate_hz,
int num_channels) {
#ifdef WEBRTC_EXCLUDE_TRANSIENT_SUPPRESSOR #ifdef WEBRTC_EXCLUDE_TRANSIENT_SUPPRESSOR
return nullptr; return nullptr;
#else #else
if (overrides.transient_suppression) { if (overrides.transient_suppression) {
return nullptr; return nullptr;
} }
return std::make_unique<TransientSuppressorImpl>(vad_mode); return std::make_unique<TransientSuppressorImpl>(
vad_mode, sample_rate_hz, detection_rate_hz, num_channels);
#endif #endif
} }

View File

@ -32,7 +32,10 @@ struct ApmSubmoduleCreationOverrides {
// * The corresponding override in `overrides` is enabled. // * The corresponding override in `overrides` is enabled.
std::unique_ptr<TransientSuppressor> CreateTransientSuppressor( std::unique_ptr<TransientSuppressor> CreateTransientSuppressor(
const ApmSubmoduleCreationOverrides& overrides, const ApmSubmoduleCreationOverrides& overrides,
TransientSuppressor::VadMode vad_mode); TransientSuppressor::VadMode vad_mode,
int sample_rate_hz,
int detection_rate_hz,
int num_channels);
} // namespace webrtc } // namespace webrtc

View File

@ -166,9 +166,10 @@ void void_main() {
Agc agc; Agc agc;
TransientSuppressorImpl suppressor(TransientSuppressor::VadMode::kDefault); TransientSuppressorImpl suppressor(TransientSuppressor::VadMode::kDefault,
suppressor.Initialize(absl::GetFlag(FLAGS_sample_rate_hz), detection_rate_hz, absl::GetFlag(FLAGS_sample_rate_hz),
absl::GetFlag(FLAGS_num_channels)); detection_rate_hz,
absl::GetFlag(FLAGS_num_channels));
const size_t audio_buffer_size = absl::GetFlag(FLAGS_chunk_size_ms) * const size_t audio_buffer_size = absl::GetFlag(FLAGS_chunk_size_ms) *
absl::GetFlag(FLAGS_sample_rate_hz) / 1000; absl::GetFlag(FLAGS_sample_rate_hz) / 1000;

View File

@ -57,7 +57,10 @@ std::string GetVadModeLabel(TransientSuppressor::VadMode vad_mode) {
} // namespace } // namespace
TransientSuppressorImpl::TransientSuppressorImpl(VadMode vad_mode) TransientSuppressorImpl::TransientSuppressorImpl(VadMode vad_mode,
int sample_rate_hz,
int detector_rate_hz,
int num_channels)
: vad_mode_(vad_mode), : vad_mode_(vad_mode),
analyzed_audio_is_silent_(false), analyzed_audio_is_silent_(false),
data_length_(0), data_length_(0),
@ -77,6 +80,7 @@ TransientSuppressorImpl::TransientSuppressorImpl(VadMode vad_mode)
seed_(182), seed_(182),
using_reference_(false) { using_reference_(false) {
RTC_LOG(LS_INFO) << "VAD mode: " << GetVadModeLabel(vad_mode_); RTC_LOG(LS_INFO) << "VAD mode: " << GetVadModeLabel(vad_mode_);
Initialize(sample_rate_hz, detector_rate_hz, num_channels);
} }
TransientSuppressorImpl::~TransientSuppressorImpl() {} TransientSuppressorImpl::~TransientSuppressorImpl() {}

View File

@ -27,7 +27,10 @@ class TransientDetector;
// restoration algorithm that attenuates unexpected spikes in the spectrum. // restoration algorithm that attenuates unexpected spikes in the spectrum.
class TransientSuppressorImpl : public TransientSuppressor { class TransientSuppressorImpl : public TransientSuppressor {
public: public:
explicit TransientSuppressorImpl(VadMode vad_mode); TransientSuppressorImpl(VadMode vad_mode,
int sample_rate_hz,
int detector_rate_hz,
int num_channels);
~TransientSuppressorImpl() override; ~TransientSuppressorImpl() override;
int Initialize(int sample_rate_hz, int Initialize(int sample_rate_hz,

View File

@ -23,8 +23,8 @@ TEST_P(TransientSuppressorImplTest,
TypingDetectionLogicWorksAsExpectedForMono) { TypingDetectionLogicWorksAsExpectedForMono) {
static const int kNumChannels = 1; static const int kNumChannels = 1;
TransientSuppressorImpl ts(GetParam()); TransientSuppressorImpl ts(GetParam(), ts::kSampleRate16kHz,
ts.Initialize(ts::kSampleRate16kHz, ts::kSampleRate16kHz, kNumChannels); ts::kSampleRate16kHz, kNumChannels);
// Each key-press enables detection. // Each key-press enables detection.
EXPECT_FALSE(ts.detection_enabled_); EXPECT_FALSE(ts.detection_enabled_);