diff --git a/modules/audio_processing/transient/transient_suppression_test.cc b/modules/audio_processing/transient/transient_suppression_test.cc index 9864253d6e..2d8baf9416 100644 --- a/modules/audio_processing/transient/transient_suppression_test.cc +++ b/modules/audio_processing/transient/transient_suppression_test.cc @@ -200,12 +200,11 @@ void void_main() { audio_buffer_f[i] = audio_buffer_i[i]; } - ASSERT_EQ(0, suppressor.Suppress( - audio_buffer_f.get(), audio_buffer_size, - absl::GetFlag(FLAGS_num_channels), detection_buffer.get(), - detection_buffer_size, reference_buffer.get(), - audio_buffer_size, agc.voice_probability(), true)) - << "The transient suppressor could not suppress the frame"; + suppressor.Suppress(audio_buffer_f.get(), audio_buffer_size, + absl::GetFlag(FLAGS_num_channels), + detection_buffer.get(), detection_buffer_size, + reference_buffer.get(), audio_buffer_size, + agc.voice_probability(), true); // Write result to out file. WritePCM(out_file, audio_buffer_size, absl::GetFlag(FLAGS_num_channels), diff --git a/modules/audio_processing/transient/transient_suppressor.h b/modules/audio_processing/transient/transient_suppressor.h index b6cb61f13a..dd998a1154 100644 --- a/modules/audio_processing/transient/transient_suppressor.h +++ b/modules/audio_processing/transient/transient_suppressor.h @@ -11,10 +11,7 @@ #ifndef MODULES_AUDIO_PROCESSING_TRANSIENT_TRANSIENT_SUPPRESSOR_H_ #define MODULES_AUDIO_PROCESSING_TRANSIENT_TRANSIENT_SUPPRESSOR_H_ -#include -#include - -#include +#include namespace webrtc { @@ -39,9 +36,9 @@ class TransientSuppressor { virtual ~TransientSuppressor() {} - virtual int Initialize(int sample_rate_hz, - int detector_rate_hz, - int num_channels) = 0; + virtual void Initialize(int sample_rate_hz, + int detector_rate_hz, + int num_channels) = 0; // Processes a `data` chunk, and returns it with keystrokes suppressed from // it. The float format is assumed to be int16 ranged. If there are more than @@ -59,16 +56,15 @@ class TransientSuppressor { // of audio. If voice information is not available, `voice_probability` must // always be set to 1. // `key_pressed` determines if a key was pressed on this audio chunk. - // Returns 0 on success and -1 otherwise. - virtual int Suppress(float* data, - size_t data_length, - int num_channels, - const float* detection_data, - size_t detection_length, - const float* reference_data, - size_t reference_length, - float voice_probability, - bool key_pressed) = 0; + virtual void Suppress(float* data, + size_t data_length, + int num_channels, + const float* detection_data, + size_t detection_length, + const float* reference_data, + size_t reference_length, + float voice_probability, + bool key_pressed) = 0; }; } // namespace webrtc diff --git a/modules/audio_processing/transient/transient_suppressor_impl.cc b/modules/audio_processing/transient/transient_suppressor_impl.cc index b2a389eb4b..f3fbf09240 100644 --- a/modules/audio_processing/transient/transient_suppressor_impl.cc +++ b/modules/audio_processing/transient/transient_suppressor_impl.cc @@ -85,9 +85,19 @@ TransientSuppressorImpl::TransientSuppressorImpl(VadMode vad_mode, TransientSuppressorImpl::~TransientSuppressorImpl() {} -int TransientSuppressorImpl::Initialize(int sample_rate_hz, - int detection_rate_hz, - int num_channels) { +void TransientSuppressorImpl::Initialize(int sample_rate_hz, + int detection_rate_hz, + int num_channels) { + RTC_DCHECK(sample_rate_hz == ts::kSampleRate8kHz || + sample_rate_hz == ts::kSampleRate16kHz || + sample_rate_hz == ts::kSampleRate32kHz || + sample_rate_hz == ts::kSampleRate48kHz); + RTC_DCHECK(detection_rate_hz == ts::kSampleRate8kHz || + detection_rate_hz == ts::kSampleRate16kHz || + detection_rate_hz == ts::kSampleRate32kHz || + detection_rate_hz == ts::kSampleRate48kHz); + RTC_DCHECK_GT(num_channels, 0); + switch (sample_rate_hz) { case ts::kSampleRate8kHz: analysis_length_ = 128u; @@ -106,24 +116,13 @@ int TransientSuppressorImpl::Initialize(int sample_rate_hz, window_ = kBlocks480w1024; break; default: - return -1; - } - if (detection_rate_hz != ts::kSampleRate8kHz && - detection_rate_hz != ts::kSampleRate16kHz && - detection_rate_hz != ts::kSampleRate32kHz && - detection_rate_hz != ts::kSampleRate48kHz) { - return -1; - } - if (num_channels <= 0) { - return -1; + RTC_DCHECK_NOTREACHED(); + return; } detector_.reset(new TransientDetector(detection_rate_hz)); data_length_ = sample_rate_hz * ts::kChunkSizeMs / 1000; - if (data_length_ > analysis_length_) { - RTC_DCHECK_NOTREACHED(); - return -1; - } + RTC_DCHECK_LE(data_length_, analysis_length_); buffer_delay_ = analysis_length_ - data_length_; complex_analysis_length_ = analysis_length_ / 2 + 1; @@ -174,28 +173,26 @@ int TransientSuppressorImpl::Initialize(int sample_rate_hz, chunks_since_voice_change_ = 0; seed_ = 182; using_reference_ = false; - return 0; } -int TransientSuppressorImpl::Suppress(float* data, - size_t data_length, - int num_channels, - const float* detection_data, - size_t detection_length, - const float* reference_data, - size_t reference_length, - float voice_probability, - bool key_pressed) { +void TransientSuppressorImpl::Suppress(float* data, + size_t data_length, + int num_channels, + const float* detection_data, + size_t detection_length, + const float* reference_data, + size_t reference_length, + float voice_probability, + bool key_pressed) { if (!data || data_length != data_length_ || num_channels != num_channels_ || detection_length != detection_length_ || voice_probability < 0 || voice_probability > 1) { - return -1; + return; } UpdateKeypress(key_pressed); UpdateBuffers(data); - int result = 0; if (detection_enabled_) { UpdateRestoration(voice_probability); @@ -208,7 +205,7 @@ int TransientSuppressorImpl::Suppress(float* data, float detector_result = detector_->Detect(detection_data, detection_length, reference_data, reference_length); if (detector_result < 0) { - return -1; + return; } using_reference_ = detector_->using_reference(); @@ -238,7 +235,6 @@ int TransientSuppressorImpl::Suppress(float* data, : &in_buffer_[i * analysis_length_], data_length_ * sizeof(*data)); } - return result; } // This should only be called when detection is enabled. UpdateBuffers() must diff --git a/modules/audio_processing/transient/transient_suppressor_impl.h b/modules/audio_processing/transient/transient_suppressor_impl.h index ceb9cce38e..75caf5b813 100644 --- a/modules/audio_processing/transient/transient_suppressor_impl.h +++ b/modules/audio_processing/transient/transient_suppressor_impl.h @@ -33,19 +33,19 @@ class TransientSuppressorImpl : public TransientSuppressor { int num_channels); ~TransientSuppressorImpl() override; - int Initialize(int sample_rate_hz, - int detector_rate_hz, - int num_channels) override; + void Initialize(int sample_rate_hz, + int detector_rate_hz, + int num_channels) override; - int Suppress(float* data, - size_t data_length, - int num_channels, - const float* detection_data, - size_t detection_length, - const float* reference_data, - size_t reference_length, - float voice_probability, - bool key_pressed) override; + void Suppress(float* data, + size_t data_length, + int num_channels, + const float* detection_data, + size_t detection_length, + const float* reference_data, + size_t reference_length, + float voice_probability, + bool key_pressed) override; private: FRIEND_TEST_ALL_PREFIXES(TransientSuppressorImplTest,