diff --git a/webrtc/modules/audio_processing/audio_processing_impl.cc b/webrtc/modules/audio_processing/audio_processing_impl.cc index af6d9b7a21..d89cc33931 100644 --- a/webrtc/modules/audio_processing/audio_processing_impl.cc +++ b/webrtc/modules/audio_processing/audio_processing_impl.cc @@ -706,8 +706,10 @@ int AudioProcessingImpl::ProcessStreamLocked() { public_submodules_->noise_suppression->ProcessCaptureAudio(ca); if (constants_.intelligibility_enabled) { RTC_DCHECK(public_submodules_->noise_suppression->is_enabled()); + RTC_DCHECK(public_submodules_->gain_control->is_enabled()); public_submodules_->intelligibility_enhancer->SetCaptureNoiseEstimate( - public_submodules_->noise_suppression->NoiseEstimate()); + public_submodules_->noise_suppression->NoiseEstimate(), + public_submodules_->gain_control->compression_gain_db()); } // Ensure that the stream delay was set before the call to the diff --git a/webrtc/modules/audio_processing/gain_control_impl.cc b/webrtc/modules/audio_processing/gain_control_impl.cc index 9f381d26f1..2461f72ad3 100644 --- a/webrtc/modules/audio_processing/gain_control_impl.cc +++ b/webrtc/modules/audio_processing/gain_control_impl.cc @@ -275,6 +275,11 @@ int GainControlImpl::ProcessCaptureAudio(AudioBuffer* audio, return AudioProcessing::kNoError; } +int GainControlImpl::compression_gain_db() const { + rtc::CritScope cs(crit_capture_); + return compression_gain_db_; +} + // TODO(ajm): ensure this is called under kAdaptiveAnalog. 
int GainControlImpl::set_stream_analog_level(int level) { rtc::CritScope cs(crit_capture_); @@ -414,11 +419,6 @@ int GainControlImpl::set_compression_gain_db(int gain) { return Configure(); } -int GainControlImpl::compression_gain_db() const { - rtc::CritScope cs(crit_capture_); - return compression_gain_db_; -} - int GainControlImpl::enable_limiter(bool enable) { { rtc::CritScope cs(crit_capture_); diff --git a/webrtc/modules/audio_processing/gain_control_impl.h b/webrtc/modules/audio_processing/gain_control_impl.h index 9498ac60b5..2459ce3b4b 100644 --- a/webrtc/modules/audio_processing/gain_control_impl.h +++ b/webrtc/modules/audio_processing/gain_control_impl.h @@ -51,6 +51,8 @@ class GainControlImpl : public GainControl { // Reads render side data that has been queued on the render call. void ReadQueuedRenderData(); + int compression_gain_db() const override; + private: class GainController; @@ -61,7 +63,6 @@ class GainControlImpl : public GainControl { int set_target_level_dbfs(int level) override; int target_level_dbfs() const override; int set_compression_gain_db(int gain) override; - int compression_gain_db() const override; int enable_limiter(bool enable) override; int set_analog_level_limits(int minimum, int maximum) override; int analog_level_minimum() const override; diff --git a/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.cc b/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.cc index 33de5c1f95..ae7f911921 100644 --- a/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.cc +++ b/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.cc @@ -109,8 +109,12 @@ IntelligibilityEnhancer::IntelligibilityEnhancer(int sample_rate_hz, } void IntelligibilityEnhancer::SetCaptureNoiseEstimate( - std::vector<float> noise) { + std::vector<float> noise, int gain_db) { RTC_DCHECK_EQ(noise.size(), num_noise_bins_); + const float gain = std::pow(10.f, gain_db / 20.f); + for (auto& bin : noise) 
{ + bin *= gain; + } // Disregarding return value since buffer overflow is acceptable, because it // is not critical to get each noise estimate. if (noise_estimation_queue_.Insert(&noise)) { diff --git a/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.h b/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.h index 111b765f97..63ae80e2c4 100644 --- a/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.h +++ b/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.h @@ -36,7 +36,7 @@ class IntelligibilityEnhancer : public LappedTransform::Callback { size_t num_noise_bins); // Sets the capture noise magnitude spectrum estimate. - void SetCaptureNoiseEstimate(std::vector<float> noise); + void SetCaptureNoiseEstimate(std::vector<float> noise, int gain_db); // Reads chunk of speech in time domain and updates with modified signal. void ProcessRenderAudio(float* const* audio, @@ -56,6 +56,8 @@ class IntelligibilityEnhancer : public LappedTransform::Callback { private: FRIEND_TEST_ALL_PREFIXES(IntelligibilityEnhancerTest, TestErbCreation); FRIEND_TEST_ALL_PREFIXES(IntelligibilityEnhancerTest, TestSolveForGains); + FRIEND_TEST_ALL_PREFIXES(IntelligibilityEnhancerTest, + TestNoiseGainHasExpectedResult); // Updates the SNR estimation and enables or disables this component using a // hysteresis. 
diff --git a/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer_unittest.cc b/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer_unittest.cc index 080e228cb8..30035ab16e 100644 --- a/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer_unittest.cc +++ b/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer_unittest.cc @@ -237,7 +237,7 @@ void ProcessOneFrame(int sample_rate_hz, noise_suppressor->ProcessCaptureAudio(capture_audio_buffer); intelligibility_enhancer->SetCaptureNoiseEstimate( - noise_suppressor->NoiseEstimate()); + noise_suppressor->NoiseEstimate(), 0); if (sample_rate_hz > AudioProcessing::kSampleRate16kHz) { render_audio_buffer->MergeFrequencyBands(); @@ -311,12 +311,17 @@ void RunBitexactnessTest(int sample_rate_hz, output_reference, render_output, kElementErrorBound)); } +float float_rand() { + return std::rand() * 2.f / RAND_MAX - 1; +} + } // namespace class IntelligibilityEnhancerTest : public ::testing::Test { protected: IntelligibilityEnhancerTest() : clear_data_(kSamples), noise_data_(kSamples), orig_data_(kSamples) { + std::srand(1); enh_.reset( new IntelligibilityEnhancer(kSampleRate, kNumChannels, kNumNoiseBins)); } @@ -352,8 +357,6 @@ TEST_F(IntelligibilityEnhancerTest, TestRenderUpdate) { std::fill(orig_data_.begin(), orig_data_.end(), 0.f); std::fill(clear_data_.begin(), clear_data_.end(), 0.f); EXPECT_FALSE(CheckUpdate()); - std::srand(1); - auto float_rand = []() { return std::rand() * 2.f / RAND_MAX - 1; }; std::generate(noise_data_.begin(), noise_data_.end(), float_rand); EXPECT_FALSE(CheckUpdate()); std::generate(clear_data_.begin(), clear_data_.end(), float_rand); @@ -403,6 +406,29 @@ TEST_F(IntelligibilityEnhancerTest, TestSolveForGains) { } } +TEST_F(IntelligibilityEnhancerTest, TestNoiseGainHasExpectedResult) { + const int kGainDB = 6; + const float kGainFactor = std::pow(10.f, kGainDB / 20.f); + const float kTolerance = 0.003f; + std::vector<float> 
noise(kNumNoiseBins); + std::vector<float> noise_psd(kNumNoiseBins); + std::generate(noise.begin(), noise.end(), float_rand); + for (size_t i = 0; i < kNumNoiseBins; ++i) { + noise_psd[i] = kGainFactor * kGainFactor * noise[i] * noise[i]; + } + float* clear_cursor = clear_data_.data(); + for (size_t i = 0; i < kNumFramesToProcess; ++i) { + enh_->SetCaptureNoiseEstimate(noise, kGainDB); + enh_->ProcessRenderAudio(&clear_cursor, kSampleRate, kNumChannels); + } + const std::vector<float>& estimated_psd = + enh_->noise_power_estimator_.power(); + for (size_t i = 0; i < kNumNoiseBins; ++i) { + EXPECT_LT(std::abs(estimated_psd[i] - noise_psd[i]) / noise_psd[i], + kTolerance); + } +} + TEST(IntelligibilityEnhancerBitExactnessTest, DISABLED_Mono8kHz) { const float kOutputReference[] = {-0.001892f, -0.003296f, -0.001953f}; diff --git a/webrtc/modules/audio_processing/intelligibility/test/intelligibility_proc.cc b/webrtc/modules/audio_processing/intelligibility/test/intelligibility_proc.cc index 64ccfd96ef..abd10d8516 100644 --- a/webrtc/modules/audio_processing/intelligibility/test/intelligibility_proc.cc +++ b/webrtc/modules/audio_processing/intelligibility/test/intelligibility_proc.cc @@ -64,7 +64,7 @@ void void_main(int argc, char* argv[]) { capture_audio.CopyFrom(noise_buf.channels(), stream_config); ns.AnalyzeCaptureAudio(&capture_audio); ns.ProcessCaptureAudio(&capture_audio); - enh.SetCaptureNoiseEstimate(ns.NoiseEstimate()); + enh.SetCaptureNoiseEstimate(ns.NoiseEstimate(), 0); enh.ProcessRenderAudio(in_buf.channels(), in_file.sample_rate(), in_file.num_channels()); Interleave(in_buf.channels(), in_buf.num_frames(), in_buf.num_channels(),