diff --git a/modules/audio_processing/aec3/BUILD.gn b/modules/audio_processing/aec3/BUILD.gn index f6b8ad0b51..292242f868 100644 --- a/modules/audio_processing/aec3/BUILD.gn +++ b/modules/audio_processing/aec3/BUILD.gn @@ -85,6 +85,8 @@ rtc_static_library("aec3") { "residual_echo_estimator.h", "reverb_model.cc", "reverb_model.h", + "reverb_model_estimator.cc", + "reverb_model_estimator.h", "reverb_model_fallback.cc", "reverb_model_fallback.h", "shadow_filter_update_gain.cc", @@ -187,6 +189,7 @@ if (rtc_include_tests) { "render_delay_controller_unittest.cc", "render_signal_analyzer_unittest.cc", "residual_echo_estimator_unittest.cc", + "reverb_model_estimator_unittest.cc", "shadow_filter_update_gain_unittest.cc", "skew_estimator_unittest.cc", "subtractor_unittest.cc", diff --git a/modules/audio_processing/aec3/aec3_common.cc b/modules/audio_processing/aec3/aec3_common.cc index 8e130551e9..3e60b46a47 100644 --- a/modules/audio_processing/aec3/aec3_common.cc +++ b/modules/audio_processing/aec3/aec3_common.cc @@ -13,6 +13,8 @@ #include "system_wrappers/include/cpu_features_wrapper.h" #include "typedefs.h" // NOLINT(build/include) +#include "rtc_base/checks.h" + namespace webrtc { Aec3Optimization DetectOptimization() { @@ -29,4 +31,25 @@ Aec3Optimization DetectOptimization() { return Aec3Optimization::kNone; } +float FastApproxLog2f(const float in) { + RTC_DCHECK_GT(in, .0f); + // Read and interpret float as uint32_t and then cast to float. + // This is done to extract the exponent (bits 30 - 23). + // "Right shift" of the exponent is then performed by multiplying + // with the constant (1/2^23). Finally, we subtract a constant to + // remove the bias (https://en.wikipedia.org/wiki/Exponent_bias). + union { + float dummy; + uint32_t a; + } x = {in}; + float out = x.a; + out *= 1.1920929e-7f; // 1/2^23 + out -= 126.942695f; // Remove bias. 
+ return out; +} + +float Log2TodB(const float in_log2) { + return 3.0102999566398121 * in_log2; +} + +} // namespace webrtc diff --git a/modules/audio_processing/aec3/aec3_common.h b/modules/audio_processing/aec3/aec3_common.h index 47f078415a..6422a5247d 100644 --- a/modules/audio_processing/aec3/aec3_common.h +++ b/modules/audio_processing/aec3/aec3_common.h @@ -88,6 +88,12 @@ constexpr size_t GetRenderDelayBufferSize(size_t down_sampling_factor, // Detects what kind of optimizations to use for the code. Aec3Optimization DetectOptimization(); +// Computes the log2 of the input in a fast and approximate manner. +float FastApproxLog2f(const float in); + +// Returns dB from a power quantity expressed in log2. +float Log2TodB(const float in_log2); + static_assert(1 << kBlockSizeLog2 == kBlockSize, "Proper number of shifts for blocksize"); diff --git a/modules/audio_processing/aec3/aec_state.cc b/modules/audio_processing/aec3/aec_state.cc index 9c6314eccf..b03a121555 100644 --- a/modules/audio_processing/aec3/aec_state.cc +++ b/modules/audio_processing/aec3/aec_state.cc @@ -15,7 +15,9 @@ #include #include +#include "absl/types/optional.h" #include "api/array_view.h" +#include "modules/audio_processing/aec3/aec3_common.h" #include "modules/audio_processing/logging/apm_data_dumper.h" #include "rtc_base/atomicops.h" #include "rtc_base/checks.h" @@ -68,13 +70,13 @@ AecState::AecState(const EchoCanceller3Config& config) EnableLinearModeWithDivergedFilter()), erle_estimator_(config.erle.min, config.erle.max_l, config.erle.max_h), max_render_(config_.filter.main.length_blocks, 0.f), - reverb_decay_(fabsf(config_.ep_strength.default_len)), gain_rampup_increase_(ComputeGainRampupIncrease(config_)), suppression_gain_limiter_(config_), filter_analyzer_(config_), blocks_since_converged_filter_(kBlocksSinceConvergencedFilterInit), active_blocks_since_consistent_filter_estimate_( - kBlocksSinceConsistentEstimateInit) {} + kBlocksSinceConsistentEstimateInit), +
reverb_model_estimator_(config) {} AecState::~AecState() = default; @@ -285,12 +287,15 @@ void AecState::Update( use_linear_filter_output_ = usable_linear_estimate_ && !TransparentMode(); diverged_linear_filter_ = diverged_filter; - UpdateReverb(adaptive_filter_impulse_response); + reverb_model_estimator_.Update( + adaptive_filter_impulse_response, adaptive_filter_frequency_response, + erle_estimator_.GetInstLinearQualityEstimate(), filter_delay_blocks_, + usable_linear_estimate_, config_.ep_strength.default_len, + IsBlockStationary()); - data_dumper_->DumpRaw("aec3_erle", Erle()); - data_dumper_->DumpRaw("aec3_erle_onset", erle_estimator_.ErleOnsets()); + erle_estimator_.Dump(data_dumper_); + reverb_model_estimator_.Dump(data_dumper_); data_dumper_->DumpRaw("aec3_erl", Erl()); - data_dumper_->DumpRaw("aec3_erle_time_domain", ErleTimeDomain()); data_dumper_->DumpRaw("aec3_erl_time_domain", ErlTimeDomain()); data_dumper_->DumpRaw("aec3_usable_linear_estimate", UsableLinearEstimate()); data_dumper_->DumpRaw("aec3_transparent_mode", transparent_mode_); @@ -320,192 +325,7 @@ void AecState::Update( data_dumper_->DumpRaw("aec3_suppresion_gain_limiter_running", IsSuppressionGainLimitActive()); data_dumper_->DumpRaw("aec3_filter_tail_freq_resp_est", GetFreqRespTail()); -} -void AecState::UpdateReverb(const std::vector& impulse_response) { - // Echo tail estimation enabled if the below variable is set as negative. - if (config_.ep_strength.default_len >= 0.f) { - return; - } - - if ((!(filter_delay_blocks_ && usable_linear_estimate_)) || - (filter_delay_blocks_ > - static_cast(config_.filter.main.length_blocks) - 4)) { - return; - } - - constexpr float kOneByFftLengthBy2 = 1.f / kFftLengthBy2; - - // Form the data to match against by squaring the impulse response - // coefficients. 
- std::array - matching_data_data; - RTC_DCHECK_LE(GetTimeDomainLength(config_.filter.main.length_blocks), - matching_data_data.size()); - rtc::ArrayView matching_data( - matching_data_data.data(), - GetTimeDomainLength(config_.filter.main.length_blocks)); - std::transform(impulse_response.begin(), impulse_response.end(), - matching_data.begin(), [](float a) { return a * a; }); - - if (current_reverb_decay_section_ < config_.filter.main.length_blocks) { - // Update accumulated variables for the current filter section. - - const size_t start_index = current_reverb_decay_section_ * kFftLengthBy2; - - RTC_DCHECK_GT(matching_data.size(), start_index); - RTC_DCHECK_GE(matching_data.size(), start_index + kFftLengthBy2); - float section_energy = - std::accumulate(matching_data.begin() + start_index, - matching_data.begin() + start_index + kFftLengthBy2, - 0.f) * - kOneByFftLengthBy2; - - section_energy = std::max( - section_energy, 1e-32f); // Regularization to avoid division by 0. - - RTC_DCHECK_LT(current_reverb_decay_section_, block_energies_.size()); - const float energy_ratio = - block_energies_[current_reverb_decay_section_] / section_energy; - - main_filter_is_adapting_ = main_filter_is_adapting_ || - (energy_ratio > 1.1f || energy_ratio < 0.9f); - - // Count consecutive number of "good" filter sections, where "good" means: - // 1) energy is above noise floor. - // 2) energy of current section has not changed too much from last check. - if (!found_end_of_reverb_decay_ && section_energy > tail_energy_ && - !main_filter_is_adapting_) { - ++num_reverb_decay_sections_next_; - } else { - found_end_of_reverb_decay_ = true; - } - - block_energies_[current_reverb_decay_section_] = section_energy; - - if (num_reverb_decay_sections_ > 0) { - // Linear regression of log squared magnitude of impulse response. 
- for (size_t i = 0; i < kFftLengthBy2; i++) { - auto fast_approx_log2f = [](const float in) { - RTC_DCHECK_GT(in, .0f); - // Read and interpret float as uint32_t and then cast to float. - // This is done to extract the exponent (bits 30 - 23). - // "Right shift" of the exponent is then performed by multiplying - // with the constant (1/2^23). Finally, we subtract a constant to - // remove the bias (https://en.wikipedia.org/wiki/Exponent_bias). - union { - float dummy; - uint32_t a; - } x = {in}; - float out = x.a; - out *= 1.1920929e-7f; // 1/2^23 - out -= 126.942695f; // Remove bias. - return out; - }; - RTC_DCHECK_GT(matching_data.size(), start_index + i); - float z = fast_approx_log2f(matching_data[start_index + i]); - accumulated_nz_ += accumulated_count_ * z; - ++accumulated_count_; - } - } - - num_reverb_decay_sections_ = - num_reverb_decay_sections_ > 0 ? num_reverb_decay_sections_ - 1 : 0; - ++current_reverb_decay_section_; - - } else { - constexpr float kMaxDecay = 0.95f; // ~1 sec min RT60. - constexpr float kMinDecay = 0.02f; // ~15 ms max RT60. - - // Accumulated variables throughout whole filter. - - // Solve for decay rate. - - float decay = reverb_decay_; - - if (accumulated_nn_ != 0.f) { - const float exp_candidate = -accumulated_nz_ / accumulated_nn_; - decay = powf(2.0f, -exp_candidate * kFftLengthBy2); - decay = std::min(decay, kMaxDecay); - decay = std::max(decay, kMinDecay); - } - - // Filter tail energy (assumed to be noise). - - constexpr size_t kTailLength = kFftLength; - constexpr float k1ByTailLength = 1.f / kTailLength; - const size_t tail_index = - GetTimeDomainLength(config_.filter.main.length_blocks) - kTailLength; - - RTC_DCHECK_GT(matching_data.size(), tail_index); - tail_energy_ = std::accumulate(matching_data.begin() + tail_index, - matching_data.end(), 0.f) * - k1ByTailLength; - - // Update length of decay. 
- num_reverb_decay_sections_ = num_reverb_decay_sections_next_; - num_reverb_decay_sections_next_ = 0; - // Must have enough data (number of sections) in order - // to estimate decay rate. - if (num_reverb_decay_sections_ < 5) { - num_reverb_decay_sections_ = 0; - } - - const float N = num_reverb_decay_sections_ * kFftLengthBy2; - accumulated_nz_ = 0.f; - const float k1By12 = 1.f / 12.f; - // Arithmetic sum $2 \sum_{i=0.5}^{(N-1)/2}i^2$ calculated directly. - accumulated_nn_ = N * (N * N - 1.0f) * k1By12; - accumulated_count_ = -N * 0.5f; - // Linear regression approach assumes symmetric index around 0. - accumulated_count_ += 0.5f; - - // Identify the peak index of the impulse response. - const size_t peak_index = std::distance( - matching_data.begin(), - std::max_element(matching_data.begin(), matching_data.end())); - - current_reverb_decay_section_ = peak_index * kOneByFftLengthBy2 + 3; - // Make sure we're not out of bounds. - if (current_reverb_decay_section_ + 1 >= - config_.filter.main.length_blocks) { - current_reverb_decay_section_ = config_.filter.main.length_blocks; - } - size_t start_index = current_reverb_decay_section_ * kFftLengthBy2; - float first_section_energy = - std::accumulate(matching_data.begin() + start_index, - matching_data.begin() + start_index + kFftLengthBy2, - 0.f) * - kOneByFftLengthBy2; - - // To estimate the reverb decay, the energy of the first filter section - // must be substantially larger than the last. - // Also, the first filter section energy must not deviate too much - // from the max peak. - bool main_filter_has_reverb = first_section_energy > 4.f * tail_energy_; - bool main_filter_is_sane = first_section_energy > 2.f * tail_energy_ && - matching_data[peak_index] < 100.f; - - // Not detecting any decay, but tail is over noise - assume max decay. 
- if (num_reverb_decay_sections_ == 0 && main_filter_is_sane && - main_filter_has_reverb) { - decay = kMaxDecay; - } - - if (!main_filter_is_adapting_ && main_filter_is_sane && - num_reverb_decay_sections_ > 0) { - decay = std::max(.97f * reverb_decay_, decay); - reverb_decay_ -= .1f * (reverb_decay_ - decay); - } - - found_end_of_reverb_decay_ = - !(main_filter_is_sane && main_filter_has_reverb); - main_filter_is_adapting_ = false; - } - - data_dumper_->DumpRaw("aec3_reverb_decay", reverb_decay_); - data_dumper_->DumpRaw("aec3_reverb_tail_energy", tail_energy_); - data_dumper_->DumpRaw("aec3_suppression_gain_limit", SuppressionGainLimit()); } bool AecState::DetectActiveRender(rtc::ArrayView x) const { diff --git a/modules/audio_processing/aec3/aec_state.h b/modules/audio_processing/aec3/aec_state.h index 5c90128f29..caccdf7412 100644 --- a/modules/audio_processing/aec3/aec_state.h +++ b/modules/audio_processing/aec3/aec_state.h @@ -28,6 +28,7 @@ #include "modules/audio_processing/aec3/erle_estimator.h" #include "modules/audio_processing/aec3/filter_analyzer.h" #include "modules/audio_processing/aec3/render_buffer.h" +#include "modules/audio_processing/aec3/reverb_model_estimator.h" #include "modules/audio_processing/aec3/suppression_gain_limiter.h" #include "rtc_base/constructormagic.h" @@ -64,6 +65,17 @@ class AecState { // aec. bool UseStationaryProperties() const { return use_stationary_properties_; } + // Returns true if the current render block is estimated as stationary. + bool IsBlockStationary() const { + if (UseStationaryProperties()) { + return echo_audibility_.IsBlockStationary(); + } else { + // Assume a non-stationary block when the use of + // stationary properties is not enabled. + return false; + } + } + // Returns the ERLE. const std::array& Erle() const { return erle_estimator_.Erle(); @@ -77,8 +89,10 @@ class AecState { return absl::nullopt; } - // Returns the time-domain ERLE.
- float ErleTimeDomain() const { return erle_estimator_.ErleTimeDomain(); } + // Returns the time-domain ERLE in log2 units. + float ErleTimeDomainLog2() const { + return erle_estimator_.ErleTimeDomainLog2(); + } // Returns the ERL. const std::array& Erl() const { @@ -112,7 +126,7 @@ class AecState { void HandleEchoPathChange(const EchoPathVariability& echo_path_variability); // Returns the decay factor for the echo reverberation. - float ReverbDecay() const { return reverb_decay_; } + float ReverbDecay() const { return reverb_model_estimator_.ReverbDecay(); } // Returns the upper limit for the echo suppression gain. float SuppressionGainLimit() const { @@ -146,7 +160,7 @@ class AecState { // Returns the tail freq. response of the linear filter. rtc::ArrayView GetFreqRespTail() const { - return filter_analyzer_.GetFreqRespTail(); + return reverb_model_estimator_.GetFreqRespTail(); } // Returns filter length in blocks. @@ -155,7 +169,6 @@ class AecState { } private: - void UpdateReverb(const std::vector& impulse_response); bool DetectActiveRender(rtc::ArrayView x) const; void UpdateSuppressorGainLimit(bool render_activity); bool DetectEchoSaturation(rtc::ArrayView x, @@ -182,18 +195,8 @@ class AecState { bool render_received_ = false; int filter_delay_blocks_ = 0; size_t blocks_since_last_saturation_ = 1000; - float tail_energy_ = 0.f; - float accumulated_nz_ = 0.f; - float accumulated_nn_ = 0.f; - float accumulated_count_ = 0.f; - size_t current_reverb_decay_section_ = 0; - size_t num_reverb_decay_sections_ = 0; - size_t num_reverb_decay_sections_next_ = 0; - bool found_end_of_reverb_decay_ = false; - bool main_filter_is_adapting_ = true; - std::array block_energies_; + std::vector max_render_; - float reverb_decay_ = fabsf(config_.ep_strength.default_len); bool filter_has_had_time_to_converge_ = false; bool initial_state_ = true; const float gain_rampup_increase_; @@ -214,6 +217,7 @@ class AecState { bool finite_erl_ = false; size_t 
active_blocks_since_converged_filter_ = 0; EchoAudibility echo_audibility_; + ReverbModelEstimator reverb_model_estimator_; RTC_DISALLOW_COPY_AND_ASSIGN(AecState); }; diff --git a/modules/audio_processing/aec3/echo_audibility.h b/modules/audio_processing/aec3/echo_audibility.h index 141ac620cb..4650fa5386 100644 --- a/modules/audio_processing/aec3/echo_audibility.h +++ b/modules/audio_processing/aec3/echo_audibility.h @@ -51,6 +51,11 @@ class EchoAudibility { } } + // Returns true if the current render block is estimated as stationary. + bool IsBlockStationary() const { + return render_stationarity_.IsBlockStationary(); + } + private: // Reset the EchoAudibility class. void Reset(); diff --git a/modules/audio_processing/aec3/echo_remover.cc b/modules/audio_processing/aec3/echo_remover.cc index deae3a3986..d382d93704 100644 --- a/modules/audio_processing/aec3/echo_remover.cc +++ b/modules/audio_processing/aec3/echo_remover.cc @@ -137,7 +137,7 @@ void EchoRemoverImpl::GetMetrics(EchoControl::Metrics* metrics) const { // Echo return loss (ERL) is inverted to go from gain to attenuation. 
metrics->echo_return_loss = -10.0 * log10(aec_state_.ErlTimeDomain()); metrics->echo_return_loss_enhancement = - 10.0 * log10(aec_state_.ErleTimeDomain()); + Log2TodB(aec_state_.ErleTimeDomainLog2()); } void EchoRemoverImpl::ProcessCapture( diff --git a/modules/audio_processing/aec3/echo_remover_metrics.cc b/modules/audio_processing/aec3/echo_remover_metrics.cc index c970649844..8592a93b65 100644 --- a/modules/audio_processing/aec3/echo_remover_metrics.cc +++ b/modules/audio_processing/aec3/echo_remover_metrics.cc @@ -67,7 +67,7 @@ void EchoRemoverMetrics::Update( aec3::UpdateDbMetric(aec_state.Erl(), &erl_); erl_time_domain_.UpdateInstant(aec_state.ErlTimeDomain()); aec3::UpdateDbMetric(aec_state.Erle(), &erle_); - erle_time_domain_.UpdateInstant(aec_state.ErleTimeDomain()); + erle_time_domain_.UpdateInstant(aec_state.ErleTimeDomainLog2()); aec3::UpdateDbMetric(comfort_noise_spectrum, &comfort_noise_); aec3::UpdateDbMetric(suppressor_gain, &suppressor_gain_); active_render_count_ += (aec_state.ActiveRender() ? 
1 : 0); diff --git a/modules/audio_processing/aec3/erle_estimator.cc b/modules/audio_processing/aec3/erle_estimator.cc index ab6c1c7278..52ef8edf2f 100644 --- a/modules/audio_processing/aec3/erle_estimator.cc +++ b/modules/audio_processing/aec3/erle_estimator.cc @@ -13,29 +13,157 @@ #include #include +#include "absl/types/optional.h" +#include "modules/audio_processing/aec3/aec3_common.h" +#include "modules/audio_processing/logging/apm_data_dumper.h" #include "rtc_base/numerics/safe_minmax.h" namespace webrtc { +namespace { +constexpr int kPointsToAccumulate = 6; +constexpr float kEpsilon = 1e-3f; +} // namespace + ErleEstimator::ErleEstimator(float min_erle, float max_erle_lf, float max_erle_hf) : min_erle_(min_erle), + min_erle_log2_(FastApproxLog2f(min_erle_ + kEpsilon)), max_erle_lf_(max_erle_lf), - max_erle_hf_(max_erle_hf) { + max_erle_lf_log2(FastApproxLog2f(max_erle_lf_ + kEpsilon)), + max_erle_hf_(max_erle_hf), + erle_freq_inst_(kPointsToAccumulate), + erle_time_inst_(kPointsToAccumulate) { erle_.fill(min_erle_); erle_onsets_.fill(min_erle_); - Y2_acum_.fill(0.f); - E2_acum_.fill(0.f); - num_points_.fill(0); hold_counters_.fill(0); coming_onset_.fill(true); - erle_time_domain_ = min_erle_; + erle_time_domain_log2_ = min_erle_log2_; hold_counter_time_domain_ = 0; } ErleEstimator::~ErleEstimator() = default; +ErleEstimator::ErleTimeInstantaneous::ErleTimeInstantaneous( + int points_to_accumulate) + : points_to_accumulate_(points_to_accumulate) { + Reset(); +} +ErleEstimator::ErleTimeInstantaneous::~ErleTimeInstantaneous() = default; + +bool ErleEstimator::ErleTimeInstantaneous::Update(const float Y2_sum, + const float E2_sum) { + bool ret = false; + E2_acum_ += E2_sum; + Y2_acum_ += Y2_sum; + num_points_++; + if (num_points_ == points_to_accumulate_) { + if (E2_acum_ > 0.f) { + ret = true; + erle_log2_ = FastApproxLog2f(Y2_acum_ / E2_acum_ + kEpsilon); + } + num_points_ = 0; + E2_acum_ = 0.f; + Y2_acum_ = 0.f; + } + + if (ret) { + UpdateMaxMin(); + 
UpdateQualityEstimate(); + } + return ret; +} + +void ErleEstimator::ErleTimeInstantaneous::Reset() { + ResetAccumulators(); + max_erle_log2_ = -10.f; // -30 dB. + min_erle_log2_ = 33.f; // 100 dB. + inst_quality_estimate_ = 0.f; +} + +void ErleEstimator::ErleTimeInstantaneous::ResetAccumulators() { + erle_log2_ = absl::nullopt; + inst_quality_estimate_ = 0.f; + num_points_ = 0; + E2_acum_ = 0.f; + Y2_acum_ = 0.f; +} + +void ErleEstimator::ErleTimeInstantaneous::Dump( + const std::unique_ptr& data_dumper) { + data_dumper->DumpRaw("aec3_erle_time_inst_log2", + erle_log2_ ? *erle_log2_ : -10.f); + data_dumper->DumpRaw( + "aec3_erle_time_quality", + GetInstQualityEstimate() ? GetInstQualityEstimate().value() : 0.f); + data_dumper->DumpRaw("aec3_erle_time_max_log2", max_erle_log2_); + data_dumper->DumpRaw("aec3_erle_time_min_log2", min_erle_log2_); +} + +void ErleEstimator::ErleTimeInstantaneous::UpdateMaxMin() { + RTC_DCHECK(erle_log2_); + if (erle_log2_.value() > max_erle_log2_) { + max_erle_log2_ = erle_log2_.value(); + } else { + max_erle_log2_ -= 0.0004; // Forget factor, approx 1dB every 3 sec. + } + + if (erle_log2_.value() < min_erle_log2_) { + min_erle_log2_ = erle_log2_.value(); + } else { + min_erle_log2_ += 0.0004; // Forget factor, approx 1dB every 3 sec. 
+ } +} + +void ErleEstimator::ErleTimeInstantaneous::UpdateQualityEstimate() { + const float alpha = 0.07f; + float quality_estimate = 0.f; + RTC_DCHECK(erle_log2_); + if (max_erle_log2_ > min_erle_log2_) { + quality_estimate = (erle_log2_.value() - min_erle_log2_) / + (max_erle_log2_ - min_erle_log2_); + } + if (quality_estimate > inst_quality_estimate_) { + inst_quality_estimate_ = quality_estimate; + } else { + inst_quality_estimate_ += + alpha * (quality_estimate - inst_quality_estimate_); + } +} + +ErleEstimator::ErleFreqInstantaneous::ErleFreqInstantaneous( + int points_to_accumulate) + : points_to_accumulate_(points_to_accumulate) { + Reset(); +} + +ErleEstimator::ErleFreqInstantaneous::~ErleFreqInstantaneous() = default; + +absl::optional +ErleEstimator::ErleFreqInstantaneous::Update(float Y2, float E2, size_t band) { + absl::optional ret = absl::nullopt; + RTC_DCHECK_LT(band, kFftLengthBy2Plus1); + Y2_acum_[band] += Y2; + E2_acum_[band] += E2; + if (++num_points_[band] == points_to_accumulate_) { + if (E2_acum_[band]) { + ret = Y2_acum_[band] / E2_acum_[band]; + } + num_points_[band] = 0; + Y2_acum_[band] = 0.f; + E2_acum_[band] = 0.f; + } + + return ret; +} + +void ErleEstimator::ErleFreqInstantaneous::Reset() { + Y2_acum_.fill(0.f); + E2_acum_.fill(0.f); + num_points_.fill(0); +} + void ErleEstimator::Update(rtc::ArrayView render_spectrum, rtc::ArrayView capture_spectrum, rtc::ArrayView subtractor_spectrum, @@ -49,7 +177,7 @@ void ErleEstimator::Update(rtc::ArrayView render_spectrum, // Corresponds of WGN of power -46 dBFS. 
constexpr float kX2Min = 44015068.0f; - constexpr int kPointsToAccumulate = 6; + constexpr int kErleHold = 100; constexpr int kBlocksForOnsetDetection = kErleHold + 150; @@ -66,24 +194,18 @@ void ErleEstimator::Update(rtc::ArrayView render_spectrum, auto erle_update = [&](size_t start, size_t stop, float max_erle) { for (size_t k = start; k < stop; ++k) { if (X2[k] > kX2Min) { - ++num_points_[k]; - Y2_acum_[k] += Y2[k]; - E2_acum_[k] += E2[k]; - if (num_points_[k] == kPointsToAccumulate) { - if (E2_acum_[k] > 0) { - const float new_erle = Y2_acum_[k] / E2_acum_[k]; - if (coming_onset_[k]) { - coming_onset_[k] = false; - erle_onsets_[k] = erle_band_update( - erle_onsets_[k], new_erle, 0.15f, 0.3f, min_erle_, max_erle); - } - hold_counters_[k] = kBlocksForOnsetDetection; - erle_[k] = erle_band_update(erle_[k], new_erle, 0.05f, 0.1f, - min_erle_, max_erle); + absl::optional new_erle = + erle_freq_inst_.Update(Y2[k], E2[k], k); + if (new_erle) { + if (coming_onset_[k]) { + coming_onset_[k] = false; + erle_onsets_[k] = + erle_band_update(erle_onsets_[k], new_erle.value(), 0.15f, 0.3f, + min_erle_, max_erle); } - num_points_[k] = 0; - Y2_acum_[k] = 0.f; - E2_acum_[k] = 0.f; + hold_counters_[k] = kBlocksForOnsetDetection; + erle_[k] = erle_band_update(erle_[k], new_erle.value(), 0.05f, 0.1f, + min_erle_, max_erle); } } } @@ -118,22 +240,34 @@ void ErleEstimator::Update(rtc::ArrayView render_spectrum, if (converged_filter) { // Compute ERLE over all frequency bins. 
const float X2_sum = std::accumulate(X2.begin(), X2.end(), 0.0f); - const float E2_sum = std::accumulate(E2.begin(), E2.end(), 0.0f); - if (X2_sum > kX2Min * X2.size() && E2_sum > 0.f) { + if (X2_sum > kX2Min * X2.size()) { const float Y2_sum = std::accumulate(Y2.begin(), Y2.end(), 0.0f); - const float new_erle = Y2_sum / E2_sum; - if (new_erle > erle_time_domain_) { + const float E2_sum = std::accumulate(E2.begin(), E2.end(), 0.0f); + if (erle_time_inst_.Update(Y2_sum, E2_sum)) { hold_counter_time_domain_ = kErleHold; - erle_time_domain_ += 0.1f * (new_erle - erle_time_domain_); - erle_time_domain_ = - rtc::SafeClamp(erle_time_domain_, min_erle_, max_erle_lf_); + erle_time_domain_log2_ += + 0.1f * ((erle_time_inst_.GetInstErle_log2().value()) - + erle_time_domain_log2_); + erle_time_domain_log2_ = rtc::SafeClamp( + erle_time_domain_log2_, min_erle_log2_, max_erle_lf_log2); } } } --hold_counter_time_domain_; - erle_time_domain_ = (hold_counter_time_domain_ > 0) - ? erle_time_domain_ - : std::max(min_erle_, 0.97f * erle_time_domain_); + if (hold_counter_time_domain_ <= 0) { + erle_time_domain_log2_ = + std::max(min_erle_log2_, erle_time_domain_log2_ - 0.044f); + } + if (hold_counter_time_domain_ == 0) { + erle_time_inst_.ResetAccumulators(); + } +} + +void ErleEstimator::Dump(const std::unique_ptr& data_dumper) { + data_dumper->DumpRaw("aec3_erle", Erle()); + data_dumper->DumpRaw("aec3_erle_onset", ErleOnsets()); + data_dumper->DumpRaw("aec3_erle_time_domain_log2", ErleTimeDomainLog2()); + erle_time_inst_.Dump(data_dumper); } } // namespace webrtc diff --git a/modules/audio_processing/aec3/erle_estimator.h b/modules/audio_processing/aec3/erle_estimator.h index cdfbf7f2f8..19bc7b407d 100644 --- a/modules/audio_processing/aec3/erle_estimator.h +++ b/modules/audio_processing/aec3/erle_estimator.h @@ -13,8 +13,10 @@ #include +#include "absl/types/optional.h" #include "api/array_view.h" #include "modules/audio_processing/aec3/aec3_common.h" +#include 
"modules/audio_processing/logging/apm_data_dumper.h" #include "rtc_base/constructormagic.h" namespace webrtc { @@ -37,22 +39,80 @@ class ErleEstimator { const std::array& ErleOnsets() const { return erle_onsets_; } - float ErleTimeDomain() const { return erle_time_domain_; } + float ErleTimeDomainLog2() const { return erle_time_domain_log2_; } + + absl::optional GetInstLinearQualityEstimate() const { + return erle_time_inst_.GetInstQualityEstimate(); + } + + void Dump(const std::unique_ptr& data_dumper); + + class ErleTimeInstantaneous { + public: + ErleTimeInstantaneous(int points_to_accumulate); + ~ErleTimeInstantaneous(); + // Update the estimator with a new point, returns true + // if the instantaneous erle was updated due to having enough + // points for performing the estimate. + bool Update(const float Y2_sum, const float E2_sum); + // Reset all the members of the class. + void Reset(); + // Reset the members related to an instantaneous estimate. + void ResetAccumulators(); + // Returns the instantaneous ERLE in log2 units. + absl::optional GetInstErle_log2() const { return erle_log2_; } + // Get an indication between 0 and 1 of the performance of the linear filter + // for the current time instant. + absl::optional GetInstQualityEstimate() const { + return erle_log2_ ? absl::optional(inst_quality_estimate_) + : absl::nullopt; + } + void Dump(const std::unique_ptr& data_dumper); + + private: + void UpdateMaxMin(); + void UpdateQualityEstimate(); + absl::optional erle_log2_; + float inst_quality_estimate_; + float max_erle_log2_; + float min_erle_log2_; + float Y2_acum_; + float E2_acum_; + int num_points_; + const int points_to_accumulate_; + }; + + class ErleFreqInstantaneous { + public: + ErleFreqInstantaneous(int points_to_accumulate); + ~ErleFreqInstantaneous(); + // Updates the ERLE for a band with a new block. Returns absl::nullopt + // if not enough points were accumulated for doing the estimation.
+ absl::optional Update(float Y2, float E2, size_t band); + // Reset all the members of the class. + void Reset(); + + private: + std::array Y2_acum_; + std::array E2_acum_; + std::array num_points_; + const int points_to_accumulate_; + }; private: std::array erle_; std::array erle_onsets_; - std::array Y2_acum_; - std::array E2_acum_; - std::array num_points_; std::array coming_onset_; std::array hold_counters_; - float erle_time_domain_; int hold_counter_time_domain_; + float erle_time_domain_log2_; const float min_erle_; + const float min_erle_log2_; const float max_erle_lf_; + const float max_erle_lf_log2; const float max_erle_hf_; - + ErleFreqInstantaneous erle_freq_inst_; + ErleTimeInstantaneous erle_time_inst_; RTC_DISALLOW_COPY_AND_ASSIGN(ErleEstimator); }; diff --git a/modules/audio_processing/aec3/erle_estimator_unittest.cc b/modules/audio_processing/aec3/erle_estimator_unittest.cc index ca812a5e44..86ac5dfd1f 100644 --- a/modules/audio_processing/aec3/erle_estimator_unittest.cc +++ b/modules/audio_processing/aec3/erle_estimator_unittest.cc @@ -8,6 +8,8 @@ * be found in the AUTHORS file in the root of the source tree.
*/ +#include + #include "modules/audio_processing/aec3/erle_estimator.h" #include "api/array_view.h" #include "test/gtest.h" @@ -39,7 +41,7 @@ void VerifyErle(rtc::ArrayView erle, float reference_lf, float reference_hf) { VerifyErleBands(erle, reference_lf, reference_hf); - EXPECT_NEAR(reference_lf, erle_time_domain, 0.001); + EXPECT_NEAR(reference_lf, erle_time_domain, 0.5); } void FormFarendFrame(std::array* X2, @@ -74,7 +76,8 @@ TEST(ErleEstimator, VerifyErleIncreaseAndHold) { for (size_t k = 0; k < 200; ++k) { estimator.Update(X2, Y2, E2, true); } - VerifyErle(estimator.Erle(), estimator.ErleTimeDomain(), 8.f, 1.5f); + VerifyErle(estimator.Erle(), std::pow(2.f, estimator.ErleTimeDomainLog2()), + kMaxErleLf, kMaxErleHf); FormNearendFrame(&X2, &E2, &Y2); // Verifies that the ERLE is not immediately decreased during nearend @@ -82,7 +85,8 @@ TEST(ErleEstimator, VerifyErleIncreaseAndHold) { for (size_t k = 0; k < 50; ++k) { estimator.Update(X2, Y2, E2, true); } - VerifyErle(estimator.Erle(), estimator.ErleTimeDomain(), 8.f, 1.5f); + VerifyErle(estimator.Erle(), std::pow(2.f, estimator.ErleTimeDomainLog2()), + kMaxErleLf, kMaxErleHf); } TEST(ErleEstimator, VerifyErleTrackingOnOnsets) { @@ -112,7 +116,8 @@ TEST(ErleEstimator, VerifyErleTrackingOnOnsets) { estimator.Update(X2, Y2, E2, true); } // Verifies that during ne activity, Erle converges to the Erle for onsets. 
- VerifyErle(estimator.Erle(), estimator.ErleTimeDomain(), kMinErle, kMinErle); + VerifyErle(estimator.Erle(), std::pow(2.f, estimator.ErleTimeDomainLog2()), + kMinErle, kMinErle); } TEST(ErleEstimator, VerifyNoErleUpdateDuringLowActivity) { @@ -128,7 +133,8 @@ TEST(ErleEstimator, VerifyNoErleUpdateDuringLowActivity) { for (size_t k = 0; k < 200; ++k) { estimator.Update(X2, Y2, E2, true); } - VerifyErle(estimator.Erle(), estimator.ErleTimeDomain(), kMinErle, kMinErle); + VerifyErle(estimator.Erle(), std::pow(2.f, estimator.ErleTimeDomainLog2()), + kMinErle, kMinErle); } } // namespace webrtc diff --git a/modules/audio_processing/aec3/filter_analyzer.cc b/modules/audio_processing/aec3/filter_analyzer.cc index a48a36254e..ab2c4f25ef 100644 --- a/modules/audio_processing/aec3/filter_analyzer.cc +++ b/modules/audio_processing/aec3/filter_analyzer.cc @@ -43,28 +43,6 @@ bool EnableFilterPreprocessing() { "WebRTC-Aec3FilterAnalyzerPreprocessorKillSwitch"); } -// Computes the ratio of the energies between the direct path and the tail. The -// energy is computed in the power spectrum domain discarding the DC -// contributions. 
-float ComputeRatioEnergies(rtc::ArrayView& freq_resp_direct_path, - rtc::ArrayView& freq_resp_tail) { - // Skipping the DC for the ratio computation - constexpr size_t n_skip_bins = 1; - RTC_CHECK_EQ(freq_resp_direct_path.size(), freq_resp_tail.size()); - - float direct_path_energy = - std::accumulate(freq_resp_direct_path.begin() + n_skip_bins, - freq_resp_direct_path.end(), 0.f); - - float tail_energy = std::accumulate(freq_resp_tail.begin() + n_skip_bins, - freq_resp_tail.end(), 0.f); - - if (direct_path_energy > 0) { - return tail_energy / direct_path_energy; - } else { - return 0.f; - } -} } // namespace int FilterAnalyzer::instance_count_ = 0; @@ -108,7 +86,6 @@ void FilterAnalyzer::Reset() { consistent_estimate_counter_ = 0; consistent_delay_reference_ = -10; gain_ = default_gain_; - freq_resp_tail_.fill(0.f); } void FilterAnalyzer::Update( @@ -169,7 +146,7 @@ void FilterAnalyzer::Update( consistent_estimate_ = consistent_estimate_counter_ > 1.5f * kNumBlocksPerSecond; - UpdateFreqRespTail(filter_freq_response); + filter_length_blocks_ = filter_time_domain.size() * (1.f / kBlockSize); } @@ -192,29 +169,4 @@ void FilterAnalyzer::UpdateFilterGain( } } -// Updates the estimation of the frequency response at the filter tail. 
-void FilterAnalyzer::UpdateFreqRespTail( - const std::vector>& - filter_freq_response) { - size_t num_blocks = filter_freq_response.size(); - rtc::ArrayView freq_resp_tail( - filter_freq_response[num_blocks - 1]); - rtc::ArrayView freq_resp_direct_path( - filter_freq_response[DelayBlocks()]); - float ratio_energies = - ComputeRatioEnergies(freq_resp_direct_path, freq_resp_tail); - ratio_tail_to_direct_path_ += - 0.1f * (ratio_energies - ratio_tail_to_direct_path_); - - for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) { - freq_resp_tail_[k] = freq_resp_direct_path[k] * ratio_tail_to_direct_path_; - } - - for (size_t k = 1; k < kFftLengthBy2; ++k) { - float avg_neighbour = - 0.5f * (freq_resp_tail_[k - 1] + freq_resp_tail_[k + 1]); - freq_resp_tail_[k] = std::max(freq_resp_tail_[k], avg_neighbour); - } -} - } // namespace webrtc diff --git a/modules/audio_processing/aec3/filter_analyzer.h b/modules/audio_processing/aec3/filter_analyzer.h index 712e46aa8e..627341d71e 100644 --- a/modules/audio_processing/aec3/filter_analyzer.h +++ b/modules/audio_processing/aec3/filter_analyzer.h @@ -51,11 +51,6 @@ class FilterAnalyzer { // Returns the estimated filter gain. float Gain() const { return gain_; } - // Return the estimated freq. response of the tail of the filter. - rtc::ArrayView GetFreqRespTail() const { - return freq_resp_tail_; - } - // Returns the number of blocks for the current used filter. 
float FilterLengthBlocks() const { return filter_length_blocks_; } @@ -82,8 +77,6 @@ class FilterAnalyzer { size_t consistent_estimate_counter_ = 0; int consistent_delay_reference_ = -10; float gain_; - std::array freq_resp_tail_; - float ratio_tail_to_direct_path_ = 0.f; int filter_length_blocks_; RTC_DISALLOW_COPY_AND_ASSIGN(FilterAnalyzer); }; diff --git a/modules/audio_processing/aec3/reverb_model_estimator.cc b/modules/audio_processing/aec3/reverb_model_estimator.cc new file mode 100644 index 0000000000..18b2a845da --- /dev/null +++ b/modules/audio_processing/aec3/reverb_model_estimator.cc @@ -0,0 +1,321 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/reverb_model_estimator.h" + +#include +#include +#include +#include + +#include "api/array_view.h" +#include "api/audio/echo_canceller3_config.h" +#include "modules/audio_processing/aec3/aec3_common.h" +#include "rtc_base/checks.h" +#include "system_wrappers/include/field_trial.h" + +namespace webrtc { + +namespace { + +bool EnableSmoothUpdatesTailFreqResp() { + return !field_trial::IsEnabled( + "WebRTC-Aec3SmoothUpdatesTailFreqRespKillSwitch"); +} + +// Computes the ratio of the energies between the direct path and the tail. The +// energy is computed in the power spectrum domain discarding the DC +// contributions. 
+float ComputeRatioEnergies( + const rtc::ArrayView& freq_resp_direct_path, + const rtc::ArrayView& freq_resp_tail) { + // Skipping the DC for the ratio computation + constexpr size_t n_skip_bins = 1; + RTC_CHECK_EQ(freq_resp_direct_path.size(), freq_resp_tail.size()); + + float direct_path_energy = + std::accumulate(freq_resp_direct_path.begin() + n_skip_bins, + freq_resp_direct_path.end(), 0.f); + + float tail_energy = std::accumulate(freq_resp_tail.begin() + n_skip_bins, + freq_resp_tail.end(), 0.f); + + if (direct_path_energy > 0) { + return tail_energy / direct_path_energy; + } else { + return 0.f; + } +} + +} // namespace + +ReverbModelEstimator::ReverbModelEstimator(const EchoCanceller3Config& config) + : filter_main_length_blocks_(config.filter.main.length_blocks), + reverb_decay_(fabsf(config.ep_strength.default_len)), + enable_smooth_freq_resp_tail_updates_(EnableSmoothUpdatesTailFreqResp()) { + block_energies_.fill(0.f); + freq_resp_tail_.fill(0.f); +} + +ReverbModelEstimator::~ReverbModelEstimator() = default; + +bool ReverbModelEstimator::IsAGoodFilterForDecayEstimation( + int filter_delay_blocks, + bool usable_linear_estimate, + size_t length_filter) { + if ((filter_delay_blocks && usable_linear_estimate) && + (filter_delay_blocks <= + static_cast(filter_main_length_blocks_) - 4) && + (length_filter >= + static_cast(GetTimeDomainLength(filter_main_length_blocks_)))) { + return true; + } else { + return false; + } +} + +void ReverbModelEstimator::Update( + const std::vector& impulse_response, + const std::vector>& + filter_freq_response, + const absl::optional& quality_linear, + int filter_delay_blocks, + bool usable_linear_estimate, + float default_decay, + bool stationary_block) { + if (enable_smooth_freq_resp_tail_updates_) { + if (!stationary_block) { + float alpha = 0; + if (quality_linear) { + alpha = 0.2f * quality_linear.value(); + UpdateFreqRespTail(filter_freq_response, filter_delay_blocks, alpha); + } + if 
(IsAGoodFilterForDecayEstimation(filter_delay_blocks, + usable_linear_estimate, + impulse_response.size())) { + alpha_ = std::max(alpha, alpha_); + if ((alpha_ > 0.f) && (default_decay < 0.f)) { + // Echo tail decay estimation if default_decay is negative. + UpdateReverbDecay(impulse_response); + } + } else { + ResetDecayEstimation(); + } + } + } else { + UpdateFreqRespTail(filter_freq_response, filter_delay_blocks, 0.1f); + } +} + +void ReverbModelEstimator::ResetDecayEstimation() { + accumulated_nz_ = 0.f; + accumulated_nn_ = 0.f; + accumulated_count_ = 0.f; + current_reverb_decay_section_ = 0; + num_reverb_decay_sections_ = 0; + num_reverb_decay_sections_next_ = 0; + found_end_of_reverb_decay_ = false; + alpha_ = 0.f; +} + +void ReverbModelEstimator::UpdateReverbDecay( + const std::vector& impulse_response) { + constexpr float kOneByFftLengthBy2 = 1.f / kFftLengthBy2; + + // Form the data to match against by squaring the impulse response + // coefficients. + std::array + matching_data_data; + RTC_DCHECK_LE(GetTimeDomainLength(filter_main_length_blocks_), + matching_data_data.size()); + rtc::ArrayView matching_data( + matching_data_data.data(), + GetTimeDomainLength(filter_main_length_blocks_)); + std::transform( + impulse_response.begin(), impulse_response.end(), matching_data.begin(), + [](float a) { return a * a; }); // TODO(devicentepena) check if focusing + // on one block would be enough. + + if (current_reverb_decay_section_ < filter_main_length_blocks_) { + // Update accumulated variables for the current filter section. 
+ + const size_t start_index = current_reverb_decay_section_ * kFftLengthBy2; + + RTC_DCHECK_GT(matching_data.size(), start_index); + RTC_DCHECK_GE(matching_data.size(), start_index + kFftLengthBy2); + float section_energy = + std::accumulate(matching_data.begin() + start_index, + matching_data.begin() + start_index + kFftLengthBy2, + 0.f) * + kOneByFftLengthBy2; + + section_energy = std::max( + section_energy, 1e-32f); // Regularization to avoid division by 0. + + RTC_DCHECK_LT(current_reverb_decay_section_, block_energies_.size()); + const float energy_ratio = + block_energies_[current_reverb_decay_section_] / section_energy; + + found_end_of_reverb_decay_ = found_end_of_reverb_decay_ || + (energy_ratio > 1.1f || energy_ratio < 0.9f); + + // Count consecutive number of "good" filter sections, where "good" means: + // 1) energy is above noise floor. + // 2) energy of current section has not changed too much from last check. + if (!found_end_of_reverb_decay_ && section_energy > tail_energy_) { + ++num_reverb_decay_sections_next_; + } else { + found_end_of_reverb_decay_ = true; + } + + block_energies_[current_reverb_decay_section_] = section_energy; + + if (num_reverb_decay_sections_ > 0) { + // Linear regression of log squared magnitude of impulse response. + for (size_t i = 0; i < kFftLengthBy2; i++) { + RTC_DCHECK_GT(matching_data.size(), start_index + i); + float z = FastApproxLog2f(matching_data[start_index + i] + 1e-10); + accumulated_nz_ += accumulated_count_ * z; + ++accumulated_count_; + } + } + + num_reverb_decay_sections_ = + num_reverb_decay_sections_ > 0 ? num_reverb_decay_sections_ - 1 : 0; + ++current_reverb_decay_section_; + + } else { + constexpr float kMaxDecay = 0.95f; // ~1 sec min RT60. + constexpr float kMinDecay = 0.02f; // ~15 ms max RT60. + + // Accumulated variables throughout whole filter. + + // Solve for decay rate. 
+ + float decay = reverb_decay_; + + if (accumulated_nn_ != 0.f) { + const float exp_candidate = -accumulated_nz_ / accumulated_nn_; + decay = powf(2.0f, -exp_candidate * kFftLengthBy2); + decay = std::min(decay, kMaxDecay); + decay = std::max(decay, kMinDecay); + } + + // Filter tail energy (assumed to be noise). + constexpr size_t kTailLength = kFftLengthBy2; + + constexpr float k1ByTailLength = 1.f / kTailLength; + const size_t tail_index = + GetTimeDomainLength(filter_main_length_blocks_) - kTailLength; + + RTC_DCHECK_GT(matching_data.size(), tail_index); + + tail_energy_ = std::accumulate(matching_data.begin() + tail_index, + matching_data.end(), 0.f) * + k1ByTailLength; + + // Update length of decay. + num_reverb_decay_sections_ = num_reverb_decay_sections_next_; + num_reverb_decay_sections_next_ = 0; + // Must have enough data (number of sections) in order + // to estimate decay rate. + if (num_reverb_decay_sections_ < 5) { + num_reverb_decay_sections_ = 0; + } + + const float N = num_reverb_decay_sections_ * kFftLengthBy2; + accumulated_nz_ = 0.f; + const float k1By12 = 1.f / 12.f; + // Arithmetic sum $2 \sum_{i=0.5}^{(N-1)/2}i^2$ calculated directly. + accumulated_nn_ = N * (N * N - 1.0f) * k1By12; + accumulated_count_ = -N * 0.5f; + // Linear regression approach assumes symmetric index around 0. + accumulated_count_ += 0.5f; + + // Identify the peak index of the impulse response. + const size_t peak_index = std::distance( + matching_data.begin(), + std::max_element(matching_data.begin(), matching_data.end())); + + current_reverb_decay_section_ = peak_index * kOneByFftLengthBy2 + 3; + // Make sure we're not out of bounds. 
+ if (current_reverb_decay_section_ + 1 >= filter_main_length_blocks_) { + current_reverb_decay_section_ = filter_main_length_blocks_; + } + size_t start_index = current_reverb_decay_section_ * kFftLengthBy2; + float first_section_energy = + std::accumulate(matching_data.begin() + start_index, + matching_data.begin() + start_index + kFftLengthBy2, + 0.f) * + kOneByFftLengthBy2; + + // To estimate the reverb decay, the energy of the first filter section + // must be substantially larger than the last. + // Also, the first filter section energy must not deviate too much + // from the max peak. + bool main_filter_has_reverb = first_section_energy > 4.f * tail_energy_; + bool main_filter_is_sane = first_section_energy > 2.f * tail_energy_ && + matching_data[peak_index] < 100.f; + + // Not detecting any decay, but tail is over noise - assume max decay. + if (num_reverb_decay_sections_ == 0 && main_filter_is_sane && + main_filter_has_reverb) { + decay = kMaxDecay; + } + + if (main_filter_is_sane && num_reverb_decay_sections_ > 0) { + decay = std::max(.97f * reverb_decay_, decay); + reverb_decay_ -= alpha_ * (reverb_decay_ - decay); + } + + found_end_of_reverb_decay_ = + !(main_filter_is_sane && main_filter_has_reverb); + alpha_ = 0.f; // Stop estimation of the decay until another good filter is + // received + } +} + +// Updates the estimation of the frequency response at the filter tail. 
+void ReverbModelEstimator::UpdateFreqRespTail( + const std::vector>& + filter_freq_response, + int filter_delay_blocks, + float alpha) { + size_t num_blocks = filter_freq_response.size(); + rtc::ArrayView freq_resp_tail( + filter_freq_response[num_blocks - 1]); + rtc::ArrayView freq_resp_direct_path( + filter_freq_response[filter_delay_blocks]); + float ratio_energies = + ComputeRatioEnergies(freq_resp_direct_path, freq_resp_tail); + ratio_tail_to_direct_path_ += + alpha * (ratio_energies - ratio_tail_to_direct_path_); + + for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) { + freq_resp_tail_[k] = freq_resp_direct_path[k] * ratio_tail_to_direct_path_; + } + + for (size_t k = 1; k < kFftLengthBy2; ++k) { + float avg_neighbour = + 0.5f * (freq_resp_tail_[k - 1] + freq_resp_tail_[k + 1]); + freq_resp_tail_[k] = std::max(freq_resp_tail_[k], avg_neighbour); + } +} + +void ReverbModelEstimator::Dump( + const std::unique_ptr& data_dumper) { + data_dumper->DumpRaw("aec3_reverb_decay", reverb_decay_); + data_dumper->DumpRaw("aec3_reverb_tail_energy", tail_energy_); + data_dumper->DumpRaw("aec3_reverb_alpha", alpha_); + data_dumper->DumpRaw("aec3_num_reverb_decay_sections", + static_cast(num_reverb_decay_sections_)); +} + +} // namespace webrtc diff --git a/modules/audio_processing/aec3/reverb_model_estimator.h b/modules/audio_processing/aec3/reverb_model_estimator.h new file mode 100644 index 0000000000..d2015aa886 --- /dev/null +++ b/modules/audio_processing/aec3/reverb_model_estimator.h @@ -0,0 +1,84 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_REVERB_MODEL_ESTIMATOR_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_REVERB_MODEL_ESTIMATOR_H_ + +#include +#include + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "api/audio/echo_canceller3_config.h" +#include "modules/audio_processing/aec3/aec3_common.h" +#include "modules/audio_processing/logging/apm_data_dumper.h" + +namespace webrtc { + +// The ReverbModelEstimator class describes an estimator of the parameters +// that are used for the reverberant model. +class ReverbModelEstimator { + public: + explicit ReverbModelEstimator(const EchoCanceller3Config& config); + ~ReverbModelEstimator(); + // Updates the model. + void Update(const std::vector& impulse_response, + const std::vector>& + filter_freq_response, + const absl::optional& quality_linear, + int filter_delay_blocks, + bool usable_linear_estimate, + float default_decay, + bool stationary_block); + // Returns the decay for the exponential model. + float ReverbDecay() const { return reverb_decay_; } + + void Dump(const std::unique_ptr& data_dumper); + + // Return the estimated freq. response of the tail of the filter. 
+ rtc::ArrayView GetFreqRespTail() const { + return freq_resp_tail_; + } + + private: + bool IsAGoodFilterForDecayEstimation(int filter_delay_blocks, + bool usable_linear_estimate, + size_t length_filter); + void UpdateReverbDecay(const std::vector& impulse_response); + + void UpdateFreqRespTail( + const std::vector>& + filter_freq_response, + int filter_delay_blocks, + float alpha); + + void ResetDecayEstimation(); + + const size_t filter_main_length_blocks_; + + float accumulated_nz_ = 0.f; + float accumulated_nn_ = 0.f; + float accumulated_count_ = 0.f; + size_t current_reverb_decay_section_ = 0; + size_t num_reverb_decay_sections_ = 0; + size_t num_reverb_decay_sections_next_ = 0; + bool found_end_of_reverb_decay_ = false; + std::array block_energies_; + float reverb_decay_; + float tail_energy_ = 0.f; + float alpha_ = 0.f; + std::array freq_resp_tail_; + float ratio_tail_to_direct_path_ = 0.f; + bool enable_smooth_freq_resp_tail_updates_; +}; + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_REVERB_MODEL_ESTIMATOR_H_ diff --git a/modules/audio_processing/aec3/reverb_model_estimator_unittest.cc b/modules/audio_processing/aec3/reverb_model_estimator_unittest.cc new file mode 100644 index 0000000000..9667f4fcba --- /dev/null +++ b/modules/audio_processing/aec3/reverb_model_estimator_unittest.cc @@ -0,0 +1,123 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ + +#include "modules/audio_processing/aec3/reverb_model_estimator.h" + +#include +#include +#include + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "api/audio/echo_canceller3_config.h" +#include "modules/audio_processing/aec3/aec3_common.h" +#include "modules/audio_processing/aec3/aec3_fft.h" +#include "modules/audio_processing/aec3/fft_data.h" +#include "rtc_base/checks.h" + +#include "test/gtest.h" + +namespace webrtc { + +class ReverbModelEstimatorTest { + public: + explicit ReverbModelEstimatorTest(float default_decay) + : default_decay_(default_decay), + estimated_decay_(default_decay), + h_(aec3_config_.filter.main.length_blocks * kBlockSize, 0.f), + H2_(aec3_config_.filter.main.length_blocks) { + aec3_config_.ep_strength.default_len = default_decay_; + CreateImpulseResponseWithDecay(); + } + void RunEstimator(); + float GetDecay() { return estimated_decay_; } + float GetTrueDecay() { return true_power_decay_; } + float GetPowerTailDb() { return 10.f * log10(estimated_power_tail_); } + float GetTruePowerTailDb() { return 10.f * log10(true_power_tail_); } + + private: + void CreateImpulseResponseWithDecay(); + + absl::optional quality_linear_ = 1.0f; + static constexpr int filter_delay_blocks_ = 2; + static constexpr bool usable_linear_estimate_ = true; + static constexpr bool stationary_block_ = false; + static constexpr float true_power_decay_ = 0.5f; + EchoCanceller3Config aec3_config_; + float default_decay_; + float estimated_decay_; + float estimated_power_tail_ = 0.f; + float true_power_tail_ = 0.f; + std::vector h_; + std::vector> H2_; +}; + +void ReverbModelEstimatorTest::CreateImpulseResponseWithDecay() { + const Aec3Fft fft; + RTC_DCHECK_EQ(h_.size(), aec3_config_.filter.main.length_blocks * kBlockSize); + RTC_DCHECK_EQ(H2_.size(), aec3_config_.filter.main.length_blocks); + RTC_DCHECK_EQ(filter_delay_blocks_, 2); + const float peak = 1.0f; + float decay_power_sample = std::sqrt(true_power_decay_); + for (size_t k = 
1; k < kBlockSizeLog2; k++) { + decay_power_sample = std::sqrt(decay_power_sample); + } + h_[filter_delay_blocks_ * kBlockSize] = peak; + for (size_t k = filter_delay_blocks_ * kBlockSize + 1; k < h_.size(); ++k) { + h_[k] = h_[k - 1] * std::sqrt(decay_power_sample); + } + + for (size_t block = 0; block < H2_.size(); ++block) { + std::array h_block; + h_block.fill(0.f); + FftData H_block; + rtc::ArrayView H2_block(H2_[block]); + std::copy(h_.begin() + block * kBlockSize, + h_.begin() + block * (kBlockSize + 1), h_block.begin()); + + fft.Fft(&h_block, &H_block); + for (size_t k = 0; k < H2_block.size(); ++k) { + H2_block[k] = + H_block.re[k] * H_block.re[k] + H_block.im[k] * H_block.im[k]; + } + } + rtc::ArrayView H2_tail(H2_[H2_.size() - 1]); + true_power_tail_ = std::accumulate(H2_tail.begin(), H2_tail.end(), 0.f); +} +void ReverbModelEstimatorTest::RunEstimator() { + ReverbModelEstimator estimator(aec3_config_); + for (size_t k = 0; k < 1000; ++k) { + estimator.Update(h_, H2_, quality_linear_, filter_delay_blocks_, + usable_linear_estimate_, default_decay_, + stationary_block_); + } + estimated_decay_ = estimator.ReverbDecay(); + rtc::ArrayView freq_resp_tail = estimator.GetFreqRespTail(); + estimated_power_tail_ = + std::accumulate(freq_resp_tail.begin(), freq_resp_tail.end(), 0.f); +} + +TEST(ReverbModelEstimatorTests, NotChangingDecay) { + constexpr float default_decay = 0.9f; + ReverbModelEstimatorTest test(default_decay); + test.RunEstimator(); + EXPECT_EQ(test.GetDecay(), default_decay); + EXPECT_NEAR(test.GetPowerTailDb(), test.GetTruePowerTailDb(), 5.f); +} + +TEST(ReverbModelEstimatorTests, ChangingDecay) { + constexpr float default_decay = -0.9f; + ReverbModelEstimatorTest test(default_decay); + test.RunEstimator(); + EXPECT_NEAR(test.GetDecay(), test.GetTrueDecay(), 0.1); + EXPECT_NEAR(test.GetPowerTailDb(), test.GetTruePowerTailDb(), 5.f); +} + +} // namespace webrtc diff --git a/modules/audio_processing/aec3/stationarity_estimator.cc 
b/modules/audio_processing/aec3/stationarity_estimator.cc index 2ab0eb4fe5..efeabf1117 100644 --- a/modules/audio_processing/aec3/stationarity_estimator.cc +++ b/modules/audio_processing/aec3/stationarity_estimator.cc @@ -48,6 +48,8 @@ void StationarityEstimator::UpdateNoiseEstimator( rtc::ArrayView spectrum) { noise_.Update(spectrum); data_dumper_->DumpRaw("aec3_stationarity_noise_spectrum", noise_.Spectrum()); + data_dumper_->DumpRaw("aec3_stationarity_is_block_stationary", + IsBlockStationary()); } void StationarityEstimator::UpdateStationarityFlags( @@ -88,6 +90,16 @@ void StationarityEstimator::UpdateStationarityFlags( SmoothStationaryPerFreq(); } +bool StationarityEstimator::IsBlockStationary() const { + float acum_stationarity = 0.f; + RTC_DCHECK_EQ(stationarity_flags_.size(), kFftLengthBy2Plus1); + for (size_t band = 0; band < stationarity_flags_.size(); ++band) { + bool st = IsBandStationary(band); + acum_stationarity += static_cast(st); + } + return ((acum_stationarity * (1.f / kFftLengthBy2Plus1)) > 0.75f); +} + bool StationarityEstimator::EstimateBandStationarity( const VectorBuffer& spectrum_buffer, const std::array& reverb, diff --git a/modules/audio_processing/aec3/stationarity_estimator.h b/modules/audio_processing/aec3/stationarity_estimator.h index d5fcd007bf..e2c5a62534 100644 --- a/modules/audio_processing/aec3/stationarity_estimator.h +++ b/modules/audio_processing/aec3/stationarity_estimator.h @@ -47,6 +47,9 @@ class StationarityEstimator { return stationarity_flags_[band] && (hangovers_[band] == 0); } + // Returns true if the current block is estimated as stationary. + bool IsBlockStationary() const; + private: static constexpr int kWindowLength = 13; // Returns the power of the stationary noise spectrum at a band.