From 4a201de10d0cfe81c1d2d9bc8fdf3a08025c93cc Mon Sep 17 00:00:00 2001 From: Emil Vardar Date: Wed, 18 Sep 2024 11:59:41 +0000 Subject: [PATCH] Add support for corruption classification. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This class calculates the corruption score based on the given samples from two frames. Bug: webrtc:358039777 Change-Id: Ib036f91ec16609e827137cc35d342a2c49764737 Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/362801 Reviewed-by: Erik Språng Reviewed-by: Fanny Linderborg Commit-Queue: Emil Vardar (xWF) Cr-Commit-Position: refs/heads/main@{#43043} --- video/corruption_detection/BUILD.gn | 24 ++ .../corruption_classifier.cc | 107 +++++++ .../corruption_classifier.h | 75 +++++ .../corruption_classifier_unittest.cc | 269 ++++++++++++++++++ 4 files changed, 475 insertions(+) create mode 100644 video/corruption_detection/corruption_classifier.cc create mode 100644 video/corruption_detection/corruption_classifier.h create mode 100644 video/corruption_detection/corruption_classifier_unittest.cc diff --git a/video/corruption_detection/BUILD.gn b/video/corruption_detection/BUILD.gn index bbcc122711..d37e35947f 100644 --- a/video/corruption_detection/BUILD.gn +++ b/video/corruption_detection/BUILD.gn @@ -8,6 +8,19 @@ import("../../webrtc.gni") +rtc_library("corruption_classifier") { + sources = [ + "corruption_classifier.cc", + "corruption_classifier.h", + ] + deps = [ + ":halton_frame_sampler", + "../../api:array_view", + "../../rtc_base:checks", + "../../rtc_base:logging", + ] +} + rtc_library("frame_instrumentation_generator") { sources = [ "frame_instrumentation_generator.cc", @@ -66,6 +79,16 @@ rtc_library("halton_sequence") { } if (rtc_include_tests) { + rtc_library("corruption_classifier_unittest") { + testonly = true + sources = [ "corruption_classifier_unittest.cc" ] + deps = [ + ":corruption_classifier", + ":halton_frame_sampler", + "../../test:test_support", + ] + } + rtc_library("frame_instrumentation_generator_unittest") { testonly = true sources = [ "frame_instrumentation_generator_unittest.cc" ] @@ -115,6 +138,7 @@ if (rtc_include_tests) { testonly = true sources = [] deps = [ + ":corruption_classifier_unittest", ":frame_instrumentation_generator_unittest", ":generic_mapping_functions_unittest", ":halton_frame_sampler_unittest", diff --git a/video/corruption_detection/corruption_classifier.cc b/video/corruption_detection/corruption_classifier.cc new file mode 100644 index 0000000000..a4fc167cea --- /dev/null +++ b/video/corruption_detection/corruption_classifier.cc @@ -0,0 +1,107 @@ +/* + * Copyright 2024 The WebRTC project authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "video/corruption_detection/corruption_classifier.h" + +#include +#include +#include + +#include "api/array_view.h" +#include "rtc_base/checks.h" +#include "rtc_base/logging.h" +#include "video/corruption_detection/halton_frame_sampler.h" + +namespace webrtc { + +CorruptionClassifier::CorruptionClassifier(float scale_factor) + : config_(ScalarConfig{.scale_factor = scale_factor}) { + RTC_CHECK_GT(scale_factor, 0) << "The scale factor must be positive."; + RTC_LOG(LS_INFO) << "Calculating corruption probability using scale factor."; +} + +CorruptionClassifier::CorruptionClassifier(float growth_rate, float midpoint) + : config_(LogisticFunctionConfig{.growth_rate = growth_rate, + .midpoint = midpoint}) { + RTC_CHECK_GT(growth_rate, 0) + << "As the `score` is defined now (low score means probably not " + "corrupted and vice versa), the growth rate must be positive to have " + "a logistic function that is monotonically increasing."; + RTC_LOG(LS_INFO) + << "Calculating corruption probability using logistic function."; +} + +double CorruptionClassifier::CalculateCorruptionProbablility( + rtc::ArrayView filtered_original_samples, + rtc::ArrayView filtered_compressed_samples, + int luma_threshold, + int chroma_threshold) const { + RTC_DCHECK_GT(luma_threshold, 0) << "Luma threshold must be positive."; + RTC_DCHECK_GT(chroma_threshold, 0) << "Chroma threshold must be positive."; + RTC_DCHECK_EQ(filtered_original_samples.size(), + filtered_compressed_samples.size()) + << "The original and compressed frame have a different amount of " + "filtered samples."; + + double loss = GetScore(filtered_original_samples, filtered_compressed_samples, + luma_threshold, chroma_threshold); + + if (const auto* scalar_config = std::get_if(&config_)) { + // Fitting the unbounded loss to the interval of [0, 1] using a simple scale + // factor and capping the loss to 1. + return std::min(loss / scalar_config->scale_factor, 1.0); + } + + const auto config = std::get_if(&config_); + RTC_DCHECK(config); + // Fitting the unbounded loss to the interval of [0, 1] using the logistic + // function. + return 1 / (1 + std::exp(-config->growth_rate * (loss - config->midpoint))); +} + +// The score is calculated according to the following formula : +// +// score = (sum_i max{(|original_i - compressed_i| - threshold, 0)^2}) / N +// +// where N is the number of samples, i in [0, N), and the threshold is +// either `luma_threshold` or `chroma_threshold` depending on whether the +// sample is luma or chroma. +double CorruptionClassifier::GetScore( + rtc::ArrayView filtered_original_samples, + rtc::ArrayView filtered_compressed_samples, + int luma_threshold, + int chroma_threshold) const { + RTC_CHECK_EQ(filtered_original_samples.size(), + filtered_compressed_samples.size()); + const int num_samples = filtered_original_samples.size(); + double sum = 0.0; + for (int i = 0; i < num_samples; ++i) { + RTC_CHECK_EQ(filtered_original_samples[i].plane, + filtered_compressed_samples[i].plane); + double abs_diff = std::abs(filtered_original_samples[i].value - + filtered_compressed_samples[i].value); + switch (filtered_original_samples[i].plane) { + case ImagePlane::kLuma: + if (abs_diff > luma_threshold) { + sum += std::pow(abs_diff - luma_threshold, 2); + } + break; + case ImagePlane::kChroma: + if (abs_diff > chroma_threshold) { + sum += std::pow(abs_diff - chroma_threshold, 2); + } + break; + } + } + + return sum / num_samples; +} + +} // namespace webrtc diff --git a/video/corruption_detection/corruption_classifier.h b/video/corruption_detection/corruption_classifier.h new file mode 100644 index 0000000000..8e0c061fe0 --- /dev/null +++ b/video/corruption_detection/corruption_classifier.h @@ -0,0 +1,75 @@ +/* + * Copyright 2024 The WebRTC project authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef VIDEO_CORRUPTION_DETECTION_CORRUPTION_CLASSIFIER_H_ +#define VIDEO_CORRUPTION_DETECTION_CORRUPTION_CLASSIFIER_H_ + +#include + +#include "api/array_view.h" +#include "video/corruption_detection/halton_frame_sampler.h" + +namespace webrtc { + +// Based on the given filtered samples to `CalculateCorruptionProbablility` this +// class calculates a probability to indicate whether the frame is corrupted. +// The classification is done either by scaling the loss to the interval of [0, +// 1] using a simple `scale_factor` or by applying a logistic function to the +// loss. The logistic function is constructed based on `growth_rate` and +// `midpoint`, to the score between the original and the compressed frames' +// samples. This score is calculated using `GetScore`. +// +// TODO: bugs.webrtc.org/358039777 - Remove one of the constructors based on +// which mapping function works best in practice. +class CorruptionClassifier { + public: + // Calculates the corruption probability using a simple scale factor. + explicit CorruptionClassifier(float scale_factor); + // Calculates the corruption probability using a logistic function. + CorruptionClassifier(float growth_rate, float midpoint); + ~CorruptionClassifier() = default; + + // This function calculates and returns the probability (in the interval [0, + // 1] that a frame is corrupted. The probability is determined either by + // scaling the loss to the interval of [0, 1] using a simple `scale_factor` + // or by applying a logistic function to the loss. The method is chosen + // depending on the used constructor. + double CalculateCorruptionProbablility( + rtc::ArrayView filtered_original_samples, + rtc::ArrayView filtered_compressed_samples, + int luma_threshold, + int chroma_threshold) const; + + private: + struct ScalarConfig { + float scale_factor; + }; + + // Logistic function parameters. See + // https://en.wikipedia.org/wiki/Logistic_function. + struct LogisticFunctionConfig { + float growth_rate; + float midpoint; + }; + + // Returns the non-normalized score between the original and the compressed + // frames' samples. + double GetScore( + rtc::ArrayView filtered_original_samples, + rtc::ArrayView filtered_compressed_samples, + int luma_threshold, + int chroma_threshold) const; + + const std::variant config_; +}; + +} // namespace webrtc + +#endif // VIDEO_CORRUPTION_DETECTION_CORRUPTION_CLASSIFIER_H_ diff --git a/video/corruption_detection/corruption_classifier_unittest.cc b/video/corruption_detection/corruption_classifier_unittest.cc new file mode 100644 index 0000000000..1fbdb29cda --- /dev/null +++ b/video/corruption_detection/corruption_classifier_unittest.cc @@ -0,0 +1,269 @@ +/* + * Copyright 2024 The WebRTC project authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "video/corruption_detection/corruption_classifier.h" + +#include + +#include "test/gmock.h" +#include "test/gtest.h" +#include "video/corruption_detection/halton_frame_sampler.h" + +namespace webrtc { +namespace { + +using ::testing::DoubleNear; + +constexpr int kLumaThreshold = 3; +constexpr int kChromaThreshold = 2; + +constexpr double kMaxAbsoluteError = 1e-4; + +// Arbitrary values for testing. +constexpr double kBaseOriginalLumaSampleValue1 = 1.0; +constexpr double kBaseOriginalLumaSampleValue2 = 2.5; +constexpr double kBaseOriginalChromaSampleValue1 = 0.5; + +constexpr FilteredSample kFilteredOriginalSampleValues[] = { + {.value = kBaseOriginalLumaSampleValue1, .plane = ImagePlane::kLuma}, + {.value = kBaseOriginalLumaSampleValue2, .plane = ImagePlane::kLuma}, + {.value = kBaseOriginalChromaSampleValue1, .plane = ImagePlane::kChroma}}; + +// The value 14.0 corresponds to the corruption probability being on the same +// side of 0.5 in the `ScalarConfig` and `LogisticFunctionConfig`. +constexpr float kScaleFactor = 14.0; + +constexpr float kGrowthRate = 1.0; +constexpr float kMidpoint = 7.0; + +// Helper function to create fake compressed sample values. +std::vector GetCompressedSampleValues( + double increase_value_luma, + double increase_value_chroma) { + return std::vector{ + {.value = kBaseOriginalLumaSampleValue1 + increase_value_luma, + .plane = ImagePlane::kLuma}, + {.value = kBaseOriginalLumaSampleValue2 + increase_value_luma, + .plane = ImagePlane::kLuma}, + {.value = kBaseOriginalChromaSampleValue1 + increase_value_chroma, + .plane = ImagePlane::kChroma}}; +} + +TEST(CorruptionClassifierTest, + SameSampleValuesShouldResultInNoCorruptionScalarConfig) { + float kIncreaseValue = 0.0; + const std::vector kFilteredCompressedSampleValues = + GetCompressedSampleValues(kIncreaseValue, kIncreaseValue); + + CorruptionClassifier corruption_classifier(kScaleFactor); + + // Expected: score = 0. + // Note that the `score` above corresponds to the value returned by the + // `GetScore` function. Then this value should be passed through the Scalar or + // Logistic function giving the expected result inside DoubleNear. This + // applies for all the following tests. + EXPECT_THAT( + corruption_classifier.CalculateCorruptionProbablility( + kFilteredOriginalSampleValues, kFilteredCompressedSampleValues, + kLumaThreshold, kChromaThreshold), + DoubleNear(0.0, kMaxAbsoluteError)); +} + +TEST(CorruptionClassifierTest, + SameSampleValuesShouldResultInNoCorruptionLogisticFunctionConfig) { + float kIncreaseValue = 0.0; + const std::vector kFilteredCompressedSampleValues = + GetCompressedSampleValues(kIncreaseValue, kIncreaseValue); + + CorruptionClassifier corruption_classifier(kGrowthRate, kMidpoint); + + // Expected: score = 0. See above for explanation why we have `0.0009` below. + EXPECT_THAT( + corruption_classifier.CalculateCorruptionProbablility( + kFilteredOriginalSampleValues, kFilteredCompressedSampleValues, + kLumaThreshold, kChromaThreshold), + DoubleNear(0.0009, kMaxAbsoluteError)); +} + +TEST(CorruptionClassifierTest, + NoCorruptionWhenAllSampleDifferencesBelowThresholdScalarConfig) { + // Following value should be < `kLumaThreshold` and `kChromaThreshold`. + const double kIncreaseValue = 1; + const std::vector kFilteredCompressedSampleValues = + GetCompressedSampleValues(kIncreaseValue, kIncreaseValue); + + CorruptionClassifier corruption_classifier(kScaleFactor); + + // Expected: score = 0. + EXPECT_THAT( + corruption_classifier.CalculateCorruptionProbablility( + kFilteredOriginalSampleValues, kFilteredCompressedSampleValues, + kLumaThreshold, kChromaThreshold), + DoubleNear(0.0, kMaxAbsoluteError)); +} + +TEST(CorruptionClassifierTest, + NoCorruptionWhenAllSampleDifferencesBelowThresholdLogisticFunctionConfig) { + // Following value should be < `kLumaThreshold` and `kChromaThreshold`. + const double kIncreaseValue = 1; + const std::vector kFilteredCompressedSampleValues = + GetCompressedSampleValues(kIncreaseValue, kIncreaseValue); + + CorruptionClassifier corruption_classifier(kGrowthRate, kMidpoint); + + // Expected: score = 0. + EXPECT_THAT( + corruption_classifier.CalculateCorruptionProbablility( + kFilteredOriginalSampleValues, kFilteredCompressedSampleValues, + kLumaThreshold, kChromaThreshold), + DoubleNear(0.0009, kMaxAbsoluteError)); +} + +TEST(CorruptionClassifierTest, + NoCorruptionWhenSmallPartOfSamplesAboveThresholdScalarConfig) { + const double kIncreaseValueLuma = 1; + const double kIncreaseValueChroma = 2.5; // Above `kChromaThreshold`. + const std::vector kFilteredCompressedSampleValues = + GetCompressedSampleValues(kIncreaseValueLuma, kIncreaseValueChroma); + + CorruptionClassifier corruption_classifier(kScaleFactor); + + // Expected: score = (0.5)^2 / 3. + EXPECT_THAT( + corruption_classifier.CalculateCorruptionProbablility( + kFilteredOriginalSampleValues, kFilteredCompressedSampleValues, + kLumaThreshold, kChromaThreshold), + DoubleNear(0.0060, kMaxAbsoluteError)); +} + +TEST(CorruptionClassifierTest, + NoCorruptionWhenSmallPartOfSamplesAboveThresholdLogisticFunctionConfig) { + const double kIncreaseValueLuma = 1; + const double kIncreaseValueChroma = 2.5; // Above `kChromaThreshold`. + const std::vector kFilteredCompressedSampleValues = + GetCompressedSampleValues(kIncreaseValueLuma, kIncreaseValueChroma); + + CorruptionClassifier corruption_classifier(kGrowthRate, kMidpoint); + + // Expected: score = (0.5)^2 / 3. + EXPECT_THAT( + corruption_classifier.CalculateCorruptionProbablility( + kFilteredOriginalSampleValues, kFilteredCompressedSampleValues, + kLumaThreshold, kChromaThreshold), + DoubleNear(0.001, kMaxAbsoluteError)); +} + +TEST(CorruptionClassifierTest, + NoCorruptionWhenAllSamplesSlightlyAboveThresholdScalarConfig) { + const double kIncreaseValueLuma = 4.2; // Above `kLumaThreshold`. + const double kIncreaseValueChroma = 2.5; // Above `kChromaThreshold`. + const std::vector kFilteredCompressedSampleValues = + GetCompressedSampleValues(kIncreaseValueLuma, kIncreaseValueChroma); + + CorruptionClassifier corruption_classifier(kScaleFactor); + + // Expected: score = ((0.5)^2 + 2*(1.2)^2) / 3. + EXPECT_THAT( + corruption_classifier.CalculateCorruptionProbablility( + kFilteredOriginalSampleValues, kFilteredCompressedSampleValues, + kLumaThreshold, kChromaThreshold), + DoubleNear(0.07452, kMaxAbsoluteError)); +} + +TEST(CorruptionClassifierTest, + NoCorruptionWhenAllSamplesSlightlyAboveThresholdLogisticFunctionConfig) { + const double kIncreaseValueLuma = 4.2; // Above `kLumaThreshold`. + const double kIncreaseValueChroma = 2.5; // Above `kChromaThreshold`. + const std::vector kFilteredCompressedSampleValues = + GetCompressedSampleValues(kIncreaseValueLuma, kIncreaseValueChroma); + + CorruptionClassifier corruption_classifier(kGrowthRate, kMidpoint); + + // Expected: score = ((0.5)^2 + 2*(1.2)^2) / 3. + EXPECT_THAT( + corruption_classifier.CalculateCorruptionProbablility( + kFilteredOriginalSampleValues, kFilteredCompressedSampleValues, + kLumaThreshold, kChromaThreshold), + DoubleNear(0.0026, kMaxAbsoluteError)); +} + +// Observe that the following 2 tests in practice could be classified as +// corrupted, if so wanted. However, with the `kGrowthRate`, `kMidpoint` and +// `kScaleFactor` values chosen in these tests, the score is not high enough to +// be classified as corrupted. +TEST(CorruptionClassifierTest, + NoCorruptionWhenAllSamplesSomewhatAboveThresholdScalarConfig) { + const double kIncreaseValue = 5.0; + const std::vector kFilteredCompressedSampleValues = + GetCompressedSampleValues(kIncreaseValue, kIncreaseValue); + + CorruptionClassifier corruption_classifier(kScaleFactor); + + // Expected: score = ((3)^2 + 2*(2)^2) / 3. + EXPECT_THAT( + corruption_classifier.CalculateCorruptionProbablility( + kFilteredOriginalSampleValues, kFilteredCompressedSampleValues, + kLumaThreshold, kChromaThreshold), + DoubleNear(0.4048, kMaxAbsoluteError)); +} + +TEST(CorruptionClassifierTest, + NoCorruptionWhenAllSamplesSomewhatAboveThresholdLogisticFunctionConfig) { + // Somewhat above `kLumaThreshold` and `kChromaThreshold`. + const double kIncreaseValue = 5.0; + const std::vector kFilteredCompressedSampleValues = + GetCompressedSampleValues(kIncreaseValue, kIncreaseValue); + + CorruptionClassifier corruption_classifier(kGrowthRate, kMidpoint); + + // Expected: score = ((3)^2 + 2*(2)^2) / 3. + EXPECT_THAT( + corruption_classifier.CalculateCorruptionProbablility( + kFilteredOriginalSampleValues, kFilteredCompressedSampleValues, + kLumaThreshold, kChromaThreshold), + DoubleNear(0.2086, kMaxAbsoluteError)); +} + +TEST(CorruptionClassifierTest, + CorruptionWhenAllSamplesWellAboveThresholdScalarConfig) { + // Well above `kLumaThreshold` and `kChromaThreshold`. + const double kIncreaseValue = 7.0; + const std::vector kFilteredCompressedSampleValues = + GetCompressedSampleValues(kIncreaseValue, kIncreaseValue); + + CorruptionClassifier corruption_classifier(kScaleFactor); + + // Expected: score = ((5)^2 + 2*(4)^2) / 3. Expected 1 because of capping. + EXPECT_THAT( + corruption_classifier.CalculateCorruptionProbablility( + kFilteredOriginalSampleValues, kFilteredCompressedSampleValues, + kLumaThreshold, kChromaThreshold), + DoubleNear(1, kMaxAbsoluteError)); +} + +TEST(CorruptionClassifierTest, + CorruptionWhenAllSamplesWellAboveThresholdLogisticFunctionConfig) { + // Well above `kLumaThreshold` and `kChromaThreshold`. + const double kIncreaseValue = 7.0; + const std::vector kFilteredCompressedSampleValues = + GetCompressedSampleValues(kIncreaseValue, kIncreaseValue); + + CorruptionClassifier corruption_classifier(kGrowthRate, kMidpoint); + + // Expected: score = ((5)^2 + 2*(4)^2) / 3. + EXPECT_THAT( + corruption_classifier.CalculateCorruptionProbablility( + kFilteredOriginalSampleValues, kFilteredCompressedSampleValues, + kLumaThreshold, kChromaThreshold), + DoubleNear(1, kMaxAbsoluteError)); +} + +} // namespace +} // namespace webrtc