Add support for corruption classification.

This class calculates the corruption score based on the given samples from two frames.

Bug: webrtc:358039777
Change-Id: Ib036f91ec16609e827137cc35d342a2c49764737
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/362801
Reviewed-by: Erik Språng <sprang@webrtc.org>
Reviewed-by: Fanny Linderborg <linderborg@webrtc.org>
Commit-Queue: Emil Vardar (xWF) <vardar@google.com>
Cr-Commit-Position: refs/heads/main@{#43043}
This commit is contained in:
Emil Vardar 2024-09-18 11:59:41 +00:00 committed by WebRTC LUCI CQ
parent f045dbd67c
commit 4a201de10d
4 changed files with 475 additions and 0 deletions

View File

@ -8,6 +8,19 @@
import("../../webrtc.gni")
rtc_library("corruption_classifier") {
sources = [
"corruption_classifier.cc",
"corruption_classifier.h",
]
deps = [
":halton_frame_sampler",
"../../api:array_view",
"../../rtc_base:checks",
"../../rtc_base:logging",
]
}
rtc_library("frame_instrumentation_generator") {
sources = [
"frame_instrumentation_generator.cc",
@ -66,6 +79,16 @@ rtc_library("halton_sequence") {
}
if (rtc_include_tests) {
rtc_library("corruption_classifier_unittest") {
testonly = true
sources = [ "corruption_classifier_unittest.cc" ]
deps = [
":corruption_classifier",
":halton_frame_sampler",
"../../test:test_support",
]
}
rtc_library("frame_instrumentation_generator_unittest") {
testonly = true
sources = [ "frame_instrumentation_generator_unittest.cc" ]
@ -115,6 +138,7 @@ if (rtc_include_tests) {
testonly = true
sources = []
deps = [
":corruption_classifier_unittest",
":frame_instrumentation_generator_unittest",
":generic_mapping_functions_unittest",
":halton_frame_sampler_unittest",

View File

@ -0,0 +1,107 @@
/*
* Copyright 2024 The WebRTC project authors. All rights reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "video/corruption_detection/corruption_classifier.h"
#include <algorithm>
#include <cmath>
#include <variant>
#include "api/array_view.h"
#include "rtc_base/checks.h"
#include "rtc_base/logging.h"
#include "video/corruption_detection/halton_frame_sampler.h"
namespace webrtc {
CorruptionClassifier::CorruptionClassifier(float scale_factor)
: config_(ScalarConfig{.scale_factor = scale_factor}) {
RTC_CHECK_GT(scale_factor, 0) << "The scale factor must be positive.";
RTC_LOG(LS_INFO) << "Calculating corruption probability using scale factor.";
}
CorruptionClassifier::CorruptionClassifier(float growth_rate, float midpoint)
: config_(LogisticFunctionConfig{.growth_rate = growth_rate,
.midpoint = midpoint}) {
RTC_CHECK_GT(growth_rate, 0)
<< "As the `score` is defined now (low score means probably not "
"corrupted and vice versa), the growth rate must be positive to have "
"a logistic function that is monotonically increasing.";
RTC_LOG(LS_INFO)
<< "Calculating corruption probability using logistic function.";
}
double CorruptionClassifier::CalculateCorruptionProbablility(
rtc::ArrayView<const FilteredSample> filtered_original_samples,
rtc::ArrayView<const FilteredSample> filtered_compressed_samples,
int luma_threshold,
int chroma_threshold) const {
RTC_DCHECK_GT(luma_threshold, 0) << "Luma threshold must be positive.";
RTC_DCHECK_GT(chroma_threshold, 0) << "Chroma threshold must be positive.";
RTC_DCHECK_EQ(filtered_original_samples.size(),
filtered_compressed_samples.size())
<< "The original and compressed frame have a different amount of "
"filtered samples.";
double loss = GetScore(filtered_original_samples, filtered_compressed_samples,
luma_threshold, chroma_threshold);
if (const auto* scalar_config = std::get_if<ScalarConfig>(&config_)) {
// Fitting the unbounded loss to the interval of [0, 1] using a simple scale
// factor and capping the loss to 1.
return std::min(loss / scalar_config->scale_factor, 1.0);
}
const auto config = std::get_if<LogisticFunctionConfig>(&config_);
RTC_DCHECK(config);
// Fitting the unbounded loss to the interval of [0, 1] using the logistic
// function.
return 1 / (1 + std::exp(-config->growth_rate * (loss - config->midpoint)));
}
// The score is calculated according to the following formula :
//
// score = (sum_i max{(|original_i - compressed_i| - threshold, 0)^2}) / N
//
// where N is the number of samples, i in [0, N), and the threshold is
// either `luma_threshold` or `chroma_threshold` depending on whether the
// sample is luma or chroma.
double CorruptionClassifier::GetScore(
rtc::ArrayView<const FilteredSample> filtered_original_samples,
rtc::ArrayView<const FilteredSample> filtered_compressed_samples,
int luma_threshold,
int chroma_threshold) const {
RTC_CHECK_EQ(filtered_original_samples.size(),
filtered_compressed_samples.size());
const int num_samples = filtered_original_samples.size();
double sum = 0.0;
for (int i = 0; i < num_samples; ++i) {
RTC_CHECK_EQ(filtered_original_samples[i].plane,
filtered_compressed_samples[i].plane);
double abs_diff = std::abs(filtered_original_samples[i].value -
filtered_compressed_samples[i].value);
switch (filtered_original_samples[i].plane) {
case ImagePlane::kLuma:
if (abs_diff > luma_threshold) {
sum += std::pow(abs_diff - luma_threshold, 2);
}
break;
case ImagePlane::kChroma:
if (abs_diff > chroma_threshold) {
sum += std::pow(abs_diff - chroma_threshold, 2);
}
break;
}
}
return sum / num_samples;
}
} // namespace webrtc

View File

@ -0,0 +1,75 @@
/*
* Copyright 2024 The WebRTC project authors. All rights reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef VIDEO_CORRUPTION_DETECTION_CORRUPTION_CLASSIFIER_H_
#define VIDEO_CORRUPTION_DETECTION_CORRUPTION_CLASSIFIER_H_
#include <variant>
#include "api/array_view.h"
#include "video/corruption_detection/halton_frame_sampler.h"
namespace webrtc {
// Based on the given filtered samples to `CalculateCorruptionProbablility` this
// class calculates a probability to indicate whether the frame is corrupted.
// The classification is done either by scaling the loss to the interval of [0,
// 1] using a simple `scale_factor` or by applying a logistic function to the
// loss. The logistic function is constructed based on `growth_rate` and
// `midpoint`, to the score between the original and the compressed frames'
// samples. This score is calculated using `GetScore`.
//
// TODO: bugs.webrtc.org/358039777 - Remove one of the constructors based on
// which mapping function works best in practice.
class CorruptionClassifier {
public:
// Calculates the corruption probability using a simple scale factor.
explicit CorruptionClassifier(float scale_factor);
// Calculates the corruption probability using a logistic function.
CorruptionClassifier(float growth_rate, float midpoint);
~CorruptionClassifier() = default;
// This function calculates and returns the probability (in the interval [0,
// 1] that a frame is corrupted. The probability is determined either by
// scaling the loss to the interval of [0, 1] using a simple `scale_factor`
// or by applying a logistic function to the loss. The method is chosen
// depending on the used constructor.
double CalculateCorruptionProbablility(
rtc::ArrayView<const FilteredSample> filtered_original_samples,
rtc::ArrayView<const FilteredSample> filtered_compressed_samples,
int luma_threshold,
int chroma_threshold) const;
private:
struct ScalarConfig {
float scale_factor;
};
// Logistic function parameters. See
// https://en.wikipedia.org/wiki/Logistic_function.
struct LogisticFunctionConfig {
float growth_rate;
float midpoint;
};
// Returns the non-normalized score between the original and the compressed
// frames' samples.
double GetScore(
rtc::ArrayView<const FilteredSample> filtered_original_samples,
rtc::ArrayView<const FilteredSample> filtered_compressed_samples,
int luma_threshold,
int chroma_threshold) const;
const std::variant<ScalarConfig, LogisticFunctionConfig> config_;
};
} // namespace webrtc
#endif // VIDEO_CORRUPTION_DETECTION_CORRUPTION_CLASSIFIER_H_

View File

@ -0,0 +1,269 @@
/*
* Copyright 2024 The WebRTC project authors. All rights reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "video/corruption_detection/corruption_classifier.h"
#include <vector>
#include "test/gmock.h"
#include "test/gtest.h"
#include "video/corruption_detection/halton_frame_sampler.h"
namespace webrtc {
namespace {
using ::testing::DoubleNear;
constexpr int kLumaThreshold = 3;
constexpr int kChromaThreshold = 2;
constexpr double kMaxAbsoluteError = 1e-4;
// Arbitrary values for testing.
constexpr double kBaseOriginalLumaSampleValue1 = 1.0;
constexpr double kBaseOriginalLumaSampleValue2 = 2.5;
constexpr double kBaseOriginalChromaSampleValue1 = 0.5;
constexpr FilteredSample kFilteredOriginalSampleValues[] = {
{.value = kBaseOriginalLumaSampleValue1, .plane = ImagePlane::kLuma},
{.value = kBaseOriginalLumaSampleValue2, .plane = ImagePlane::kLuma},
{.value = kBaseOriginalChromaSampleValue1, .plane = ImagePlane::kChroma}};
// The value 14.0 corresponds to the corruption probability being on the same
// side of 0.5 in the `ScalarConfig` and `LogisticFunctionConfig`.
constexpr float kScaleFactor = 14.0;
constexpr float kGrowthRate = 1.0;
constexpr float kMidpoint = 7.0;
// Helper function to create fake compressed sample values.
std::vector<FilteredSample> GetCompressedSampleValues(
double increase_value_luma,
double increase_value_chroma) {
return std::vector<FilteredSample>{
{.value = kBaseOriginalLumaSampleValue1 + increase_value_luma,
.plane = ImagePlane::kLuma},
{.value = kBaseOriginalLumaSampleValue2 + increase_value_luma,
.plane = ImagePlane::kLuma},
{.value = kBaseOriginalChromaSampleValue1 + increase_value_chroma,
.plane = ImagePlane::kChroma}};
}
TEST(CorruptionClassifierTest,
SameSampleValuesShouldResultInNoCorruptionScalarConfig) {
float kIncreaseValue = 0.0;
const std::vector<FilteredSample> kFilteredCompressedSampleValues =
GetCompressedSampleValues(kIncreaseValue, kIncreaseValue);
CorruptionClassifier corruption_classifier(kScaleFactor);
// Expected: score = 0.
// Note that the `score` above corresponds to the value returned by the
// `GetScore` function. Then this value should be passed through the Scalar or
// Logistic function giving the expected result inside DoubleNear. This
// applies for all the following tests.
EXPECT_THAT(
corruption_classifier.CalculateCorruptionProbablility(
kFilteredOriginalSampleValues, kFilteredCompressedSampleValues,
kLumaThreshold, kChromaThreshold),
DoubleNear(0.0, kMaxAbsoluteError));
}
TEST(CorruptionClassifierTest,
SameSampleValuesShouldResultInNoCorruptionLogisticFunctionConfig) {
float kIncreaseValue = 0.0;
const std::vector<FilteredSample> kFilteredCompressedSampleValues =
GetCompressedSampleValues(kIncreaseValue, kIncreaseValue);
CorruptionClassifier corruption_classifier(kGrowthRate, kMidpoint);
// Expected: score = 0. See above for explanation why we have `0.0009` below.
EXPECT_THAT(
corruption_classifier.CalculateCorruptionProbablility(
kFilteredOriginalSampleValues, kFilteredCompressedSampleValues,
kLumaThreshold, kChromaThreshold),
DoubleNear(0.0009, kMaxAbsoluteError));
}
TEST(CorruptionClassifierTest,
NoCorruptionWhenAllSampleDifferencesBelowThresholdScalarConfig) {
// Following value should be < `kLumaThreshold` and `kChromaThreshold`.
const double kIncreaseValue = 1;
const std::vector<FilteredSample> kFilteredCompressedSampleValues =
GetCompressedSampleValues(kIncreaseValue, kIncreaseValue);
CorruptionClassifier corruption_classifier(kScaleFactor);
// Expected: score = 0.
EXPECT_THAT(
corruption_classifier.CalculateCorruptionProbablility(
kFilteredOriginalSampleValues, kFilteredCompressedSampleValues,
kLumaThreshold, kChromaThreshold),
DoubleNear(0.0, kMaxAbsoluteError));
}
TEST(CorruptionClassifierTest,
NoCorruptionWhenAllSampleDifferencesBelowThresholdLogisticFunctionConfig) {
// Following value should be < `kLumaThreshold` and `kChromaThreshold`.
const double kIncreaseValue = 1;
const std::vector<FilteredSample> kFilteredCompressedSampleValues =
GetCompressedSampleValues(kIncreaseValue, kIncreaseValue);
CorruptionClassifier corruption_classifier(kGrowthRate, kMidpoint);
// Expected: score = 0.
EXPECT_THAT(
corruption_classifier.CalculateCorruptionProbablility(
kFilteredOriginalSampleValues, kFilteredCompressedSampleValues,
kLumaThreshold, kChromaThreshold),
DoubleNear(0.0009, kMaxAbsoluteError));
}
TEST(CorruptionClassifierTest,
NoCorruptionWhenSmallPartOfSamplesAboveThresholdScalarConfig) {
const double kIncreaseValueLuma = 1;
const double kIncreaseValueChroma = 2.5; // Above `kChromaThreshold`.
const std::vector<FilteredSample> kFilteredCompressedSampleValues =
GetCompressedSampleValues(kIncreaseValueLuma, kIncreaseValueChroma);
CorruptionClassifier corruption_classifier(kScaleFactor);
// Expected: score = (0.5)^2 / 3.
EXPECT_THAT(
corruption_classifier.CalculateCorruptionProbablility(
kFilteredOriginalSampleValues, kFilteredCompressedSampleValues,
kLumaThreshold, kChromaThreshold),
DoubleNear(0.0060, kMaxAbsoluteError));
}
TEST(CorruptionClassifierTest,
NoCorruptionWhenSmallPartOfSamplesAboveThresholdLogisticFunctionConfig) {
const double kIncreaseValueLuma = 1;
const double kIncreaseValueChroma = 2.5; // Above `kChromaThreshold`.
const std::vector<FilteredSample> kFilteredCompressedSampleValues =
GetCompressedSampleValues(kIncreaseValueLuma, kIncreaseValueChroma);
CorruptionClassifier corruption_classifier(kGrowthRate, kMidpoint);
// Expected: score = (0.5)^2 / 3.
EXPECT_THAT(
corruption_classifier.CalculateCorruptionProbablility(
kFilteredOriginalSampleValues, kFilteredCompressedSampleValues,
kLumaThreshold, kChromaThreshold),
DoubleNear(0.001, kMaxAbsoluteError));
}
TEST(CorruptionClassifierTest,
NoCorruptionWhenAllSamplesSlightlyAboveThresholdScalarConfig) {
const double kIncreaseValueLuma = 4.2; // Above `kLumaThreshold`.
const double kIncreaseValueChroma = 2.5; // Above `kChromaThreshold`.
const std::vector<FilteredSample> kFilteredCompressedSampleValues =
GetCompressedSampleValues(kIncreaseValueLuma, kIncreaseValueChroma);
CorruptionClassifier corruption_classifier(kScaleFactor);
// Expected: score = ((0.5)^2 + 2*(1.2)^2) / 3.
EXPECT_THAT(
corruption_classifier.CalculateCorruptionProbablility(
kFilteredOriginalSampleValues, kFilteredCompressedSampleValues,
kLumaThreshold, kChromaThreshold),
DoubleNear(0.07452, kMaxAbsoluteError));
}
TEST(CorruptionClassifierTest,
NoCorruptionWhenAllSamplesSlightlyAboveThresholdLogisticFunctionConfig) {
const double kIncreaseValueLuma = 4.2; // Above `kLumaThreshold`.
const double kIncreaseValueChroma = 2.5; // Above `kChromaThreshold`.
const std::vector<FilteredSample> kFilteredCompressedSampleValues =
GetCompressedSampleValues(kIncreaseValueLuma, kIncreaseValueChroma);
CorruptionClassifier corruption_classifier(kGrowthRate, kMidpoint);
// Expected: score = ((0.5)^2 + 2*(1.2)^2) / 3.
EXPECT_THAT(
corruption_classifier.CalculateCorruptionProbablility(
kFilteredOriginalSampleValues, kFilteredCompressedSampleValues,
kLumaThreshold, kChromaThreshold),
DoubleNear(0.0026, kMaxAbsoluteError));
}
// Observe that the following 2 tests in practice could be classified as
// corrupted, if so wanted. However, with the `kGrowthRate`, `kMidpoint` and
// `kScaleFactor` values chosen in these tests, the score is not high enough to
// be classified as corrupted.
TEST(CorruptionClassifierTest,
NoCorruptionWhenAllSamplesSomewhatAboveThresholdScalarConfig) {
const double kIncreaseValue = 5.0;
const std::vector<FilteredSample> kFilteredCompressedSampleValues =
GetCompressedSampleValues(kIncreaseValue, kIncreaseValue);
CorruptionClassifier corruption_classifier(kScaleFactor);
// Expected: score = ((3)^2 + 2*(2)^2) / 3.
EXPECT_THAT(
corruption_classifier.CalculateCorruptionProbablility(
kFilteredOriginalSampleValues, kFilteredCompressedSampleValues,
kLumaThreshold, kChromaThreshold),
DoubleNear(0.4048, kMaxAbsoluteError));
}
TEST(CorruptionClassifierTest,
NoCorruptionWhenAllSamplesSomewhatAboveThresholdLogisticFunctionConfig) {
// Somewhat above `kLumaThreshold` and `kChromaThreshold`.
const double kIncreaseValue = 5.0;
const std::vector<FilteredSample> kFilteredCompressedSampleValues =
GetCompressedSampleValues(kIncreaseValue, kIncreaseValue);
CorruptionClassifier corruption_classifier(kGrowthRate, kMidpoint);
// Expected: score = ((3)^2 + 2*(2)^2) / 3.
EXPECT_THAT(
corruption_classifier.CalculateCorruptionProbablility(
kFilteredOriginalSampleValues, kFilteredCompressedSampleValues,
kLumaThreshold, kChromaThreshold),
DoubleNear(0.2086, kMaxAbsoluteError));
}
TEST(CorruptionClassifierTest,
CorruptionWhenAllSamplesWellAboveThresholdScalarConfig) {
// Well above `kLumaThreshold` and `kChromaThreshold`.
const double kIncreaseValue = 7.0;
const std::vector<FilteredSample> kFilteredCompressedSampleValues =
GetCompressedSampleValues(kIncreaseValue, kIncreaseValue);
CorruptionClassifier corruption_classifier(kScaleFactor);
// Expected: score = ((5)^2 + 2*(4)^2) / 3. Expected 1 because of capping.
EXPECT_THAT(
corruption_classifier.CalculateCorruptionProbablility(
kFilteredOriginalSampleValues, kFilteredCompressedSampleValues,
kLumaThreshold, kChromaThreshold),
DoubleNear(1, kMaxAbsoluteError));
}
TEST(CorruptionClassifierTest,
CorruptionWhenAllSamplesWellAboveThresholdLogisticFunctionConfig) {
// Well above `kLumaThreshold` and `kChromaThreshold`.
const double kIncreaseValue = 7.0;
const std::vector<FilteredSample> kFilteredCompressedSampleValues =
GetCompressedSampleValues(kIncreaseValue, kIncreaseValue);
CorruptionClassifier corruption_classifier(kGrowthRate, kMidpoint);
// Expected: score = ((5)^2 + 2*(4)^2) / 3.
EXPECT_THAT(
corruption_classifier.CalculateCorruptionProbablility(
kFilteredOriginalSampleValues, kFilteredCompressedSampleValues,
kLumaThreshold, kChromaThreshold),
DoubleNear(1, kMaxAbsoluteError));
}
} // namespace
} // namespace webrtc