AGC2: simplify the AdaptiveModeLevelEstimator interface and group its state.

What this patch does (as visible in the hunks below):
  * Renames the public API of AdaptiveModeLevelEstimator:
      UpdateEstimation()           -> Update()
      LatestLevelEstimate()        -> GetLevelDbfs()
      LevelEstimationIsConfident() -> IsConfident()  (now defined in the .cc)
    and updates the callers in adaptive_agc.cc,
    adaptive_mode_level_estimator_agc.cc and the unit test.
  * Replaces the members buffer_size_ms_, estimate_numerator_ and
    estimate_denominator_ with a private `State` struct holding
    `time_to_full_buffer_ms` (starts at kFullBufferSizeMs and is decreased by
    kFrameDurationMs per processed frame until it reaches 0) and a `Ratio`
    {numerator, denominator} with a DCHECK-guarded GetRatio().
  * Caches the last estimate in `absl::optional last_level_dbfs_`;
    GetLevelDbfs() falls back to kInitialSpeechLevelEstimateDbfs while nothing
    has been cached.
  * The delegating-target constructor now calls Reset(), and the member
    initializer order is changed to match the new declaration order.
  * DebugDumpEstimate() is now called at the end of Update() regardless of
    use_saturation_protector_, and the
    "agc2_adaptive_level_estimate_with_offset_dbfs" dump is removed.

Review notes on this copy of the patch:
  * The text was received with all line breaks collapsed; the one-record-per-
    line layout below was rebuilt and checked against every hunk's line
    counts. Indentation of the C++ payload and the position of blank context
    lines are reconstructed, not preserved, so apply with
    `git apply --ignore-whitespace` or compare against the target tree first.
  * Text inside angle brackets was stripped in extraction. Two added lines
    that were not valid C++ were restored: `absl::optional<float>` (it is
    assigned from GetRatio(), which returns float) and `static_cast<void>`
    (result deliberately discarded). Verify both against the real commit.
  * Left as received because the missing text cannot be established from
    this patch: the bare `#include` context line in the header hunk, the
    `AudioFrameView float_frame` hunk headings, and the `rtc::SafeClamp(`
    calls (these may have carried an explicit template argument).
  * In the two numerator/denominator assignments, the received text had a
    single `+` where a line-ending arithmetic `+` and the next line's diff
    marker are both required by the hunk line counts; both are written below.
  * IsConfident() and `buffer_is_full` test `time_to_full_buffer_ms == 0`;
    this presumably relies on kFullBufferSizeMs being a multiple of
    kFrameDurationMs (neither constant is visible here) - confirm in
    agc2_common.h.

diff --git a/modules/audio_processing/agc2/adaptive_agc.cc b/modules/audio_processing/agc2/adaptive_agc.cc
index 816754d08d..f1017b3a6c 100644
--- a/modules/audio_processing/agc2/adaptive_agc.cc
+++ b/modules/audio_processing/agc2/adaptive_agc.cc
@@ -51,10 +51,9 @@ void AdaptiveAgc::Process(AudioFrameView float_frame,
   apm_data_dumper_->DumpRaw("agc2_vad_peak_dbfs",
                             signal_with_levels.vad_result.peak_dbfs);
 
-  speech_level_estimator_.UpdateEstimation(signal_with_levels.vad_result);
+  speech_level_estimator_.Update(signal_with_levels.vad_result);
 
-  signal_with_levels.input_level_dbfs =
-      speech_level_estimator_.LatestLevelEstimate();
+  signal_with_levels.input_level_dbfs = speech_level_estimator_.GetLevelDbfs();
 
   signal_with_levels.input_noise_level_dbfs =
       noise_level_estimator_.Analyze(float_frame);
@@ -68,7 +67,7 @@ void AdaptiveAgc::Process(AudioFrameView float_frame,
                             signal_with_levels.limiter_audio_level_dbfs);
 
   signal_with_levels.estimate_is_confident =
-      speech_level_estimator_.LevelEstimationIsConfident();
+      speech_level_estimator_.IsConfident();
 
   // The gain applier applies the gain.
   gain_applier_.Process(signal_with_levels);
diff --git a/modules/audio_processing/agc2/adaptive_mode_level_estimator.cc b/modules/audio_processing/agc2/adaptive_mode_level_estimator.cc
index 2f5e442aac..0f839ba715 100644
--- a/modules/audio_processing/agc2/adaptive_mode_level_estimator.cc
+++ b/modules/audio_processing/agc2/adaptive_mode_level_estimator.cc
@@ -17,6 +17,11 @@
 
 namespace webrtc {
 
+float AdaptiveModeLevelEstimator::State::Ratio::GetRatio() const {
+  RTC_DCHECK_NE(denominator, 0.f);
+  return numerator / denominator;
+}
+
 AdaptiveModeLevelEstimator::AdaptiveModeLevelEstimator(
     ApmDataDumper* apm_data_dumper)
     : AdaptiveModeLevelEstimator(
@@ -43,13 +48,16 @@ AdaptiveModeLevelEstimator::AdaptiveModeLevelEstimator(
     bool use_saturation_protector,
     float initial_saturation_margin_db,
     float extra_saturation_margin_db)
-    : level_estimator_(level_estimator),
+    : apm_data_dumper_(apm_data_dumper),
+      saturation_protector_(apm_data_dumper, initial_saturation_margin_db),
+      level_estimator_type_(level_estimator),
       use_saturation_protector_(use_saturation_protector),
       extra_saturation_margin_db_(extra_saturation_margin_db),
-      saturation_protector_(apm_data_dumper, initial_saturation_margin_db),
-      apm_data_dumper_(apm_data_dumper) {}
+      last_level_dbfs_(absl::nullopt) {
+  Reset();
+}
 
-void AdaptiveModeLevelEstimator::UpdateEstimation(
+void AdaptiveModeLevelEstimator::Update(
     const VadLevelAnalyzer::Result& vad_level) {
   RTC_DCHECK_GT(vad_level.rms_dbfs, -150.f);
   RTC_DCHECK_LT(vad_level.rms_dbfs, 50.f);
@@ -63,64 +71,80 @@ void AdaptiveModeLevelEstimator::UpdateEstimation(
     return;
   }
 
-  const bool buffer_is_full = buffer_size_ms_ >= kFullBufferSizeMs;
+  // Update the state.
+  RTC_DCHECK_GE(state_.time_to_full_buffer_ms, 0);
+  const bool buffer_is_full = state_.time_to_full_buffer_ms == 0;
   if (!buffer_is_full) {
-    buffer_size_ms_ += kFrameDurationMs;
+    state_.time_to_full_buffer_ms -= kFrameDurationMs;
   }
 
-  const float leak_factor = buffer_is_full ? kFullBufferLeakFactor : 1.f;
-
-  // Read speech level estimation.
-  float speech_level_dbfs = 0.f;
+  // Read level estimation.
+  float level_dbfs = 0.f;
   using LevelEstimatorType =
       AudioProcessing::Config::GainController2::LevelEstimator;
-  switch (level_estimator_) {
+  switch (level_estimator_type_) {
     case LevelEstimatorType::kRms:
-      speech_level_dbfs = vad_level.rms_dbfs;
+      level_dbfs = vad_level.rms_dbfs;
       break;
     case LevelEstimatorType::kPeak:
-      speech_level_dbfs = vad_level.peak_dbfs;
+      level_dbfs = vad_level.peak_dbfs;
       break;
   }
 
-  // Update speech level estimation.
-  estimate_numerator_ = estimate_numerator_ * leak_factor +
-                        speech_level_dbfs * vad_level.speech_probability;
-  estimate_denominator_ =
-      estimate_denominator_ * leak_factor + vad_level.speech_probability;
-  last_estimate_with_offset_dbfs_ = estimate_numerator_ / estimate_denominator_;
+  // Update level estimation (average level weighted by speech probability).
+  RTC_DCHECK_GT(vad_level.speech_probability, 0.f);
+  const float leak_factor = buffer_is_full ? kFullBufferLeakFactor : 1.f;
+  state_.level_dbfs.numerator = state_.level_dbfs.numerator * leak_factor +
+                                level_dbfs * vad_level.speech_probability;
+  state_.level_dbfs.denominator = state_.level_dbfs.denominator * leak_factor +
+                                  vad_level.speech_probability;
 
+  // Cache level estimation.
+  last_level_dbfs_ = state_.level_dbfs.GetRatio();
+
+  // TODO(crbug.com/webrtc/7494): Update saturation protector state in `state`.
   if (use_saturation_protector_) {
-    saturation_protector_.UpdateMargin(vad_level.peak_dbfs,
-                                       last_estimate_with_offset_dbfs_);
-    DebugDumpEstimate();
+    saturation_protector_.UpdateMargin(
+        /*speech_peak_dbfs=*/vad_level.peak_dbfs,
+        /*speech_level_dbfs=*/last_level_dbfs_.value());
   }
+
+  DebugDumpEstimate();
 }
 
-float AdaptiveModeLevelEstimator::LatestLevelEstimate() const {
-  return rtc::SafeClamp(
-      last_estimate_with_offset_dbfs_ +
-          (use_saturation_protector_ ? (saturation_protector_.margin_db() +
-                                        extra_saturation_margin_db_)
-                                     : 0.f),
-      -90.f, 30.f);
+float AdaptiveModeLevelEstimator::GetLevelDbfs() const {
+  float level_dbfs = last_level_dbfs_.value_or(kInitialSpeechLevelEstimateDbfs);
+  if (use_saturation_protector_) {
+    level_dbfs += saturation_protector_.margin_db();
+    level_dbfs += extra_saturation_margin_db_;
+  }
+  return rtc::SafeClamp(level_dbfs, -90.f, 30.f);
+}
+
+bool AdaptiveModeLevelEstimator::IsConfident() const {
+  // Returns true if enough speech frames have been observed.
+  return state_.time_to_full_buffer_ms == 0;
 }
 
 void AdaptiveModeLevelEstimator::Reset() {
-  buffer_size_ms_ = 0;
-  last_estimate_with_offset_dbfs_ = kInitialSpeechLevelEstimateDbfs;
-  estimate_numerator_ = 0.f;
-  estimate_denominator_ = 0.f;
   saturation_protector_.Reset();
+  ResetState(state_);
+  last_level_dbfs_ = absl::nullopt;
+}
+
+void AdaptiveModeLevelEstimator::ResetState(State& state) {
+  state.time_to_full_buffer_ms = kFullBufferSizeMs;
+  state.level_dbfs.numerator = 0.f;
+  state.level_dbfs.denominator = 0.f;
+  // TODO(crbug.com/webrtc/7494): Reset saturation protector state in `state`.
 }
 
 void AdaptiveModeLevelEstimator::DebugDumpEstimate() {
   if (apm_data_dumper_) {
-    apm_data_dumper_->DumpRaw("agc2_adaptive_level_estimate_with_offset_dbfs",
-                              last_estimate_with_offset_dbfs_);
     apm_data_dumper_->DumpRaw("agc2_adaptive_level_estimate_dbfs",
-                              LatestLevelEstimate());
+                              GetLevelDbfs());
   }
   saturation_protector_.DebugDumpEstimate();
 }
+
 }  // namespace webrtc
diff --git a/modules/audio_processing/agc2/adaptive_mode_level_estimator.h b/modules/audio_processing/agc2/adaptive_mode_level_estimator.h
index a02641c50c..f5d6303020 100644
--- a/modules/audio_processing/agc2/adaptive_mode_level_estimator.h
+++ b/modules/audio_processing/agc2/adaptive_mode_level_estimator.h
@@ -13,6 +13,7 @@
 
 #include
 
+#include "absl/types/optional.h"
 #include "modules/audio_processing/agc2/agc2_common.h"
 #include "modules/audio_processing/agc2/saturation_protector.h"
 #include "modules/audio_processing/agc2/vad_with_level.h"
@@ -21,6 +22,7 @@ namespace webrtc {
 
 class ApmDataDumper;
 
+// Level estimator for the digital adaptive gain controller.
 class AdaptiveModeLevelEstimator {
  public:
   explicit AdaptiveModeLevelEstimator(ApmDataDumper* apm_data_dumper);
@@ -40,26 +42,42 @@ class AdaptiveModeLevelEstimator {
       bool use_saturation_protector,
       float initial_saturation_margin_db,
       float extra_saturation_margin_db);
-  void UpdateEstimation(const VadLevelAnalyzer::Result& vad_level);
-  float LatestLevelEstimate() const;
+
+  // Updates the level estimation.
+  void Update(const VadLevelAnalyzer::Result& vad_data);
+  // Returns the estimated speech plus noise level.
+  float GetLevelDbfs() const;
+  // Returns true if the estimator is confident on its current estimate.
+  bool IsConfident() const;
+
   void Reset();
-  bool LevelEstimationIsConfident() const {
-    return buffer_size_ms_ >= kFullBufferSizeMs;
-  }
 
  private:
+  // Part of the level estimator state used for check-pointing and restore ops.
+  struct State {
+    struct Ratio {
+      float numerator;
+      float denominator;
+      float GetRatio() const;
+    };
+    int time_to_full_buffer_ms;
+    Ratio level_dbfs;
+    // TODO(crbug.com/webrtc/7494): Add saturation protector state.
+  };
+
+  void ResetState(State& state);
   void DebugDumpEstimate();
 
+  ApmDataDumper* const apm_data_dumper_;
+  SaturationProtector saturation_protector_;
+
   const AudioProcessing::Config::GainController2::LevelEstimator
-      level_estimator_;
+      level_estimator_type_;
   const bool use_saturation_protector_;
   const float extra_saturation_margin_db_;
-  size_t buffer_size_ms_ = 0;
-  float last_estimate_with_offset_dbfs_ = kInitialSpeechLevelEstimateDbfs;
-  float estimate_numerator_ = 0.f;
-  float estimate_denominator_ = 0.f;
-  SaturationProtector saturation_protector_;
-  ApmDataDumper* const apm_data_dumper_;
+  // TODO(crbug.com/webrtc/7494): Add temporary state.
+  State state_;
+  absl::optional<float> last_level_dbfs_;
 };
 
 }  // namespace webrtc
diff --git a/modules/audio_processing/agc2/adaptive_mode_level_estimator_agc.cc b/modules/audio_processing/agc2/adaptive_mode_level_estimator_agc.cc
index b7c64373fc..17fa58280b 100644
--- a/modules/audio_processing/agc2/adaptive_mode_level_estimator_agc.cc
+++ b/modules/audio_processing/agc2/adaptive_mode_level_estimator_agc.cc
@@ -38,7 +38,7 @@ void AdaptiveModeLevelEstimatorAgc::Process(const int16_t* audio,
   if (latest_voice_probability_ > kVadConfidenceThreshold) {
     time_in_ms_since_last_estimate_ += kFrameDurationMs;
   }
-  level_estimator_.UpdateEstimation(vad_prob);
+  level_estimator_.Update(vad_prob);
 }
 
 // Retrieves the difference between the target RMS level and the current
@@ -48,8 +48,8 @@ bool AdaptiveModeLevelEstimatorAgc::GetRmsErrorDb(int* error) {
   if (time_in_ms_since_last_estimate_ <= kTimeUntilConfidentMs) {
     return false;
   }
-  *error = std::floor(target_level_dbfs() -
-                      level_estimator_.LatestLevelEstimate() + 0.5f);
+  *error =
+      std::floor(target_level_dbfs() - level_estimator_.GetLevelDbfs() + 0.5f);
   time_in_ms_since_last_estimate_ = 0;
   return true;
 }
diff --git a/modules/audio_processing/agc2/adaptive_mode_level_estimator_unittest.cc b/modules/audio_processing/agc2/adaptive_mode_level_estimator_unittest.cc
index be1fc9482e..6ab0655094 100644
--- a/modules/audio_processing/agc2/adaptive_mode_level_estimator_unittest.cc
+++ b/modules/audio_processing/agc2/adaptive_mode_level_estimator_unittest.cc
@@ -29,7 +29,7 @@ void RunOnConstantLevel(int num_iterations,
                         const VadLevelAnalyzer::Result& vad_level,
                         AdaptiveModeLevelEstimator& level_estimator) {
   for (int i = 0; i < num_iterations; ++i) {
-    level_estimator.UpdateEstimation(vad_level);
+    level_estimator.Update(vad_level);
   }
 }
 
@@ -54,8 +54,8 @@ TEST(AutomaticGainController2AdaptiveModeLevelEstimator,
 
   VadLevelAnalyzer::Result vad_level{kMaxSpeechProbability, /*rms_dbfs=*/-20.f,
                                      /*peak_dbfs=*/-10.f};
-  level_estimator.estimator->UpdateEstimation(vad_level);
-  static_cast(level_estimator.estimator->LatestLevelEstimate());
+  level_estimator.estimator->Update(vad_level);
+  static_cast<void>(level_estimator.estimator->GetLevelDbfs());
 }
 
 TEST(AutomaticGainController2AdaptiveModeLevelEstimator, LevelShouldStabilize) {
@@ -69,9 +69,9 @@ TEST(AutomaticGainController2AdaptiveModeLevelEstimator, LevelShouldStabilize) {
                                kSpeechPeakDbfs},
       *level_estimator.estimator);
 
-  EXPECT_NEAR(level_estimator.estimator->LatestLevelEstimate() -
-                  kExtraSaturationMarginDb,
-              kSpeechPeakDbfs, 0.1f);
+  EXPECT_NEAR(
+      level_estimator.estimator->GetLevelDbfs() - kExtraSaturationMarginDb,
+      kSpeechPeakDbfs, 0.1f);
 }
 
 TEST(AutomaticGainController2AdaptiveModeLevelEstimator,
@@ -96,9 +96,9 @@ TEST(AutomaticGainController2AdaptiveModeLevelEstimator,
       *level_estimator.estimator);
 
   // Level should not have changed.
-  EXPECT_NEAR(level_estimator.estimator->LatestLevelEstimate() -
-                  kExtraSaturationMarginDb,
-              kSpeechRmsDbfs, 0.1f);
+  EXPECT_NEAR(
+      level_estimator.estimator->GetLevelDbfs() - kExtraSaturationMarginDb,
+      kSpeechRmsDbfs, 0.1f);
 }
 
 TEST(AutomaticGainController2AdaptiveModeLevelEstimator, TimeToAdapt) {
@@ -128,7 +128,7 @@ TEST(AutomaticGainController2AdaptiveModeLevelEstimator, TimeToAdapt) {
           /*peak_dbfs=*/kDifferentSpeechRmsDbfs},
       *level_estimator.estimator);
   EXPECT_GT(std::abs(kDifferentSpeechRmsDbfs -
-                     level_estimator.estimator->LatestLevelEstimate()),
+                     level_estimator.estimator->GetLevelDbfs()),
             kMaxDifferenceDb);
 
   // Run for some more time. Afterwards, we should have adapted.
@@ -139,9 +139,9 @@ TEST(AutomaticGainController2AdaptiveModeLevelEstimator, TimeToAdapt) {
           /*rms_dbfs=*/kDifferentSpeechRmsDbfs - kInitialSaturationMarginDb,
           /*peak_dbfs=*/kDifferentSpeechRmsDbfs},
       *level_estimator.estimator);
-  EXPECT_NEAR(level_estimator.estimator->LatestLevelEstimate() -
-                  kExtraSaturationMarginDb,
-              kDifferentSpeechRmsDbfs, kMaxDifferenceDb * 0.5f);
+  EXPECT_NEAR(
+      level_estimator.estimator->GetLevelDbfs() - kExtraSaturationMarginDb,
+      kDifferentSpeechRmsDbfs, kMaxDifferenceDb * 0.5f);
 }
 
 TEST(AutomaticGainController2AdaptiveModeLevelEstimator,
@@ -175,7 +175,7 @@ TEST(AutomaticGainController2AdaptiveModeLevelEstimator,
   const float kMaxDifferenceDb =
       0.1f * std::abs(kDifferentSpeechRmsDbfs - kInitialSpeechRmsDbfs);
   EXPECT_LT(std::abs(kDifferentSpeechRmsDbfs -
-                     (level_estimator.estimator->LatestLevelEstimate() -
+                     (level_estimator.estimator->GetLevelDbfs() -
                       kExtraSaturationMarginDb)),
             kMaxDifferenceDb);
 }