Add support for corruption classification.
This class calculates the corruption score based on the given samples from two frames. Bug: webrtc:358039777 Change-Id: Ib036f91ec16609e827137cc35d342a2c49764737 Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/362801 Reviewed-by: Erik Språng <sprang@webrtc.org> Reviewed-by: Fanny Linderborg <linderborg@webrtc.org> Commit-Queue: Emil Vardar (xWF) <vardar@google.com> Cr-Commit-Position: refs/heads/main@{#43043}
This commit is contained in:
parent
f045dbd67c
commit
4a201de10d
@ -8,6 +8,19 @@
|
|||||||
|
|
||||||
import("../../webrtc.gni")
|
import("../../webrtc.gni")
|
||||||
|
|
||||||
|
rtc_library("corruption_classifier") {
|
||||||
|
sources = [
|
||||||
|
"corruption_classifier.cc",
|
||||||
|
"corruption_classifier.h",
|
||||||
|
]
|
||||||
|
deps = [
|
||||||
|
":halton_frame_sampler",
|
||||||
|
"../../api:array_view",
|
||||||
|
"../../rtc_base:checks",
|
||||||
|
"../../rtc_base:logging",
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
rtc_library("frame_instrumentation_generator") {
|
rtc_library("frame_instrumentation_generator") {
|
||||||
sources = [
|
sources = [
|
||||||
"frame_instrumentation_generator.cc",
|
"frame_instrumentation_generator.cc",
|
||||||
@ -66,6 +79,16 @@ rtc_library("halton_sequence") {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (rtc_include_tests) {
|
if (rtc_include_tests) {
|
||||||
|
rtc_library("corruption_classifier_unittest") {
|
||||||
|
testonly = true
|
||||||
|
sources = [ "corruption_classifier_unittest.cc" ]
|
||||||
|
deps = [
|
||||||
|
":corruption_classifier",
|
||||||
|
":halton_frame_sampler",
|
||||||
|
"../../test:test_support",
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
rtc_library("frame_instrumentation_generator_unittest") {
|
rtc_library("frame_instrumentation_generator_unittest") {
|
||||||
testonly = true
|
testonly = true
|
||||||
sources = [ "frame_instrumentation_generator_unittest.cc" ]
|
sources = [ "frame_instrumentation_generator_unittest.cc" ]
|
||||||
@ -115,6 +138,7 @@ if (rtc_include_tests) {
|
|||||||
testonly = true
|
testonly = true
|
||||||
sources = []
|
sources = []
|
||||||
deps = [
|
deps = [
|
||||||
|
":corruption_classifier_unittest",
|
||||||
":frame_instrumentation_generator_unittest",
|
":frame_instrumentation_generator_unittest",
|
||||||
":generic_mapping_functions_unittest",
|
":generic_mapping_functions_unittest",
|
||||||
":halton_frame_sampler_unittest",
|
":halton_frame_sampler_unittest",
|
||||||
|
|||||||
107
video/corruption_detection/corruption_classifier.cc
Normal file
107
video/corruption_detection/corruption_classifier.cc
Normal file
@ -0,0 +1,107 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2024 The WebRTC project authors. All rights reserved.
|
||||||
|
*
|
||||||
|
* Use of this source code is governed by a BSD-style license
|
||||||
|
* that can be found in the LICENSE file in the root of the source
|
||||||
|
* tree. An additional intellectual property rights grant can be found
|
||||||
|
* in the file PATENTS. All contributing project authors may
|
||||||
|
* be found in the AUTHORS file in the root of the source tree.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "video/corruption_detection/corruption_classifier.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <cmath>
|
||||||
|
#include <variant>
|
||||||
|
|
||||||
|
#include "api/array_view.h"
|
||||||
|
#include "rtc_base/checks.h"
|
||||||
|
#include "rtc_base/logging.h"
|
||||||
|
#include "video/corruption_detection/halton_frame_sampler.h"
|
||||||
|
|
||||||
|
namespace webrtc {
|
||||||
|
|
||||||
|
CorruptionClassifier::CorruptionClassifier(float scale_factor)
|
||||||
|
: config_(ScalarConfig{.scale_factor = scale_factor}) {
|
||||||
|
RTC_CHECK_GT(scale_factor, 0) << "The scale factor must be positive.";
|
||||||
|
RTC_LOG(LS_INFO) << "Calculating corruption probability using scale factor.";
|
||||||
|
}
|
||||||
|
|
||||||
|
CorruptionClassifier::CorruptionClassifier(float growth_rate, float midpoint)
|
||||||
|
: config_(LogisticFunctionConfig{.growth_rate = growth_rate,
|
||||||
|
.midpoint = midpoint}) {
|
||||||
|
RTC_CHECK_GT(growth_rate, 0)
|
||||||
|
<< "As the `score` is defined now (low score means probably not "
|
||||||
|
"corrupted and vice versa), the growth rate must be positive to have "
|
||||||
|
"a logistic function that is monotonically increasing.";
|
||||||
|
RTC_LOG(LS_INFO)
|
||||||
|
<< "Calculating corruption probability using logistic function.";
|
||||||
|
}
|
||||||
|
|
||||||
|
double CorruptionClassifier::CalculateCorruptionProbablility(
|
||||||
|
rtc::ArrayView<const FilteredSample> filtered_original_samples,
|
||||||
|
rtc::ArrayView<const FilteredSample> filtered_compressed_samples,
|
||||||
|
int luma_threshold,
|
||||||
|
int chroma_threshold) const {
|
||||||
|
RTC_DCHECK_GT(luma_threshold, 0) << "Luma threshold must be positive.";
|
||||||
|
RTC_DCHECK_GT(chroma_threshold, 0) << "Chroma threshold must be positive.";
|
||||||
|
RTC_DCHECK_EQ(filtered_original_samples.size(),
|
||||||
|
filtered_compressed_samples.size())
|
||||||
|
<< "The original and compressed frame have a different amount of "
|
||||||
|
"filtered samples.";
|
||||||
|
|
||||||
|
double loss = GetScore(filtered_original_samples, filtered_compressed_samples,
|
||||||
|
luma_threshold, chroma_threshold);
|
||||||
|
|
||||||
|
if (const auto* scalar_config = std::get_if<ScalarConfig>(&config_)) {
|
||||||
|
// Fitting the unbounded loss to the interval of [0, 1] using a simple scale
|
||||||
|
// factor and capping the loss to 1.
|
||||||
|
return std::min(loss / scalar_config->scale_factor, 1.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto config = std::get_if<LogisticFunctionConfig>(&config_);
|
||||||
|
RTC_DCHECK(config);
|
||||||
|
// Fitting the unbounded loss to the interval of [0, 1] using the logistic
|
||||||
|
// function.
|
||||||
|
return 1 / (1 + std::exp(-config->growth_rate * (loss - config->midpoint)));
|
||||||
|
}
|
||||||
|
|
||||||
|
// The score is calculated according to the following formula :
|
||||||
|
//
|
||||||
|
// score = (sum_i max{(|original_i - compressed_i| - threshold, 0)^2}) / N
|
||||||
|
//
|
||||||
|
// where N is the number of samples, i in [0, N), and the threshold is
|
||||||
|
// either `luma_threshold` or `chroma_threshold` depending on whether the
|
||||||
|
// sample is luma or chroma.
|
||||||
|
double CorruptionClassifier::GetScore(
|
||||||
|
rtc::ArrayView<const FilteredSample> filtered_original_samples,
|
||||||
|
rtc::ArrayView<const FilteredSample> filtered_compressed_samples,
|
||||||
|
int luma_threshold,
|
||||||
|
int chroma_threshold) const {
|
||||||
|
RTC_CHECK_EQ(filtered_original_samples.size(),
|
||||||
|
filtered_compressed_samples.size());
|
||||||
|
const int num_samples = filtered_original_samples.size();
|
||||||
|
double sum = 0.0;
|
||||||
|
for (int i = 0; i < num_samples; ++i) {
|
||||||
|
RTC_CHECK_EQ(filtered_original_samples[i].plane,
|
||||||
|
filtered_compressed_samples[i].plane);
|
||||||
|
double abs_diff = std::abs(filtered_original_samples[i].value -
|
||||||
|
filtered_compressed_samples[i].value);
|
||||||
|
switch (filtered_original_samples[i].plane) {
|
||||||
|
case ImagePlane::kLuma:
|
||||||
|
if (abs_diff > luma_threshold) {
|
||||||
|
sum += std::pow(abs_diff - luma_threshold, 2);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case ImagePlane::kChroma:
|
||||||
|
if (abs_diff > chroma_threshold) {
|
||||||
|
sum += std::pow(abs_diff - chroma_threshold, 2);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return sum / num_samples;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace webrtc
|
||||||
75
video/corruption_detection/corruption_classifier.h
Normal file
75
video/corruption_detection/corruption_classifier.h
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2024 The WebRTC project authors. All rights reserved.
|
||||||
|
*
|
||||||
|
* Use of this source code is governed by a BSD-style license
|
||||||
|
* that can be found in the LICENSE file in the root of the source
|
||||||
|
* tree. An additional intellectual property rights grant can be found
|
||||||
|
* in the file PATENTS. All contributing project authors may
|
||||||
|
* be found in the AUTHORS file in the root of the source tree.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef VIDEO_CORRUPTION_DETECTION_CORRUPTION_CLASSIFIER_H_
|
||||||
|
#define VIDEO_CORRUPTION_DETECTION_CORRUPTION_CLASSIFIER_H_
|
||||||
|
|
||||||
|
#include <variant>
|
||||||
|
|
||||||
|
#include "api/array_view.h"
|
||||||
|
#include "video/corruption_detection/halton_frame_sampler.h"
|
||||||
|
|
||||||
|
namespace webrtc {
|
||||||
|
|
||||||
|
// Based on the given filtered samples to `CalculateCorruptionProbablility` this
|
||||||
|
// class calculates a probability to indicate whether the frame is corrupted.
|
||||||
|
// The classification is done either by scaling the loss to the interval of [0,
|
||||||
|
// 1] using a simple `scale_factor` or by applying a logistic function to the
|
||||||
|
// loss. The logistic function is constructed based on `growth_rate` and
|
||||||
|
// `midpoint`, to the score between the original and the compressed frames'
|
||||||
|
// samples. This score is calculated using `GetScore`.
|
||||||
|
//
|
||||||
|
// TODO: bugs.webrtc.org/358039777 - Remove one of the constructors based on
|
||||||
|
// which mapping function works best in practice.
|
||||||
|
class CorruptionClassifier {
|
||||||
|
public:
|
||||||
|
// Calculates the corruption probability using a simple scale factor.
|
||||||
|
explicit CorruptionClassifier(float scale_factor);
|
||||||
|
// Calculates the corruption probability using a logistic function.
|
||||||
|
CorruptionClassifier(float growth_rate, float midpoint);
|
||||||
|
~CorruptionClassifier() = default;
|
||||||
|
|
||||||
|
// This function calculates and returns the probability (in the interval [0,
|
||||||
|
// 1] that a frame is corrupted. The probability is determined either by
|
||||||
|
// scaling the loss to the interval of [0, 1] using a simple `scale_factor`
|
||||||
|
// or by applying a logistic function to the loss. The method is chosen
|
||||||
|
// depending on the used constructor.
|
||||||
|
double CalculateCorruptionProbablility(
|
||||||
|
rtc::ArrayView<const FilteredSample> filtered_original_samples,
|
||||||
|
rtc::ArrayView<const FilteredSample> filtered_compressed_samples,
|
||||||
|
int luma_threshold,
|
||||||
|
int chroma_threshold) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
struct ScalarConfig {
|
||||||
|
float scale_factor;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Logistic function parameters. See
|
||||||
|
// https://en.wikipedia.org/wiki/Logistic_function.
|
||||||
|
struct LogisticFunctionConfig {
|
||||||
|
float growth_rate;
|
||||||
|
float midpoint;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Returns the non-normalized score between the original and the compressed
|
||||||
|
// frames' samples.
|
||||||
|
double GetScore(
|
||||||
|
rtc::ArrayView<const FilteredSample> filtered_original_samples,
|
||||||
|
rtc::ArrayView<const FilteredSample> filtered_compressed_samples,
|
||||||
|
int luma_threshold,
|
||||||
|
int chroma_threshold) const;
|
||||||
|
|
||||||
|
const std::variant<ScalarConfig, LogisticFunctionConfig> config_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace webrtc
|
||||||
|
|
||||||
|
#endif // VIDEO_CORRUPTION_DETECTION_CORRUPTION_CLASSIFIER_H_
|
||||||
269
video/corruption_detection/corruption_classifier_unittest.cc
Normal file
269
video/corruption_detection/corruption_classifier_unittest.cc
Normal file
@ -0,0 +1,269 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2024 The WebRTC project authors. All rights reserved.
|
||||||
|
*
|
||||||
|
* Use of this source code is governed by a BSD-style license
|
||||||
|
* that can be found in the LICENSE file in the root of the source
|
||||||
|
* tree. An additional intellectual property rights grant can be found
|
||||||
|
* in the file PATENTS. All contributing project authors may
|
||||||
|
* be found in the AUTHORS file in the root of the source tree.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "video/corruption_detection/corruption_classifier.h"
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "test/gmock.h"
|
||||||
|
#include "test/gtest.h"
|
||||||
|
#include "video/corruption_detection/halton_frame_sampler.h"
|
||||||
|
|
||||||
|
namespace webrtc {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using ::testing::DoubleNear;
|
||||||
|
|
||||||
|
constexpr int kLumaThreshold = 3;
|
||||||
|
constexpr int kChromaThreshold = 2;
|
||||||
|
|
||||||
|
constexpr double kMaxAbsoluteError = 1e-4;
|
||||||
|
|
||||||
|
// Arbitrary values for testing.
|
||||||
|
constexpr double kBaseOriginalLumaSampleValue1 = 1.0;
|
||||||
|
constexpr double kBaseOriginalLumaSampleValue2 = 2.5;
|
||||||
|
constexpr double kBaseOriginalChromaSampleValue1 = 0.5;
|
||||||
|
|
||||||
|
constexpr FilteredSample kFilteredOriginalSampleValues[] = {
|
||||||
|
{.value = kBaseOriginalLumaSampleValue1, .plane = ImagePlane::kLuma},
|
||||||
|
{.value = kBaseOriginalLumaSampleValue2, .plane = ImagePlane::kLuma},
|
||||||
|
{.value = kBaseOriginalChromaSampleValue1, .plane = ImagePlane::kChroma}};
|
||||||
|
|
||||||
|
// The value 14.0 corresponds to the corruption probability being on the same
|
||||||
|
// side of 0.5 in the `ScalarConfig` and `LogisticFunctionConfig`.
|
||||||
|
constexpr float kScaleFactor = 14.0;
|
||||||
|
|
||||||
|
constexpr float kGrowthRate = 1.0;
|
||||||
|
constexpr float kMidpoint = 7.0;
|
||||||
|
|
||||||
|
// Helper function to create fake compressed sample values.
|
||||||
|
std::vector<FilteredSample> GetCompressedSampleValues(
|
||||||
|
double increase_value_luma,
|
||||||
|
double increase_value_chroma) {
|
||||||
|
return std::vector<FilteredSample>{
|
||||||
|
{.value = kBaseOriginalLumaSampleValue1 + increase_value_luma,
|
||||||
|
.plane = ImagePlane::kLuma},
|
||||||
|
{.value = kBaseOriginalLumaSampleValue2 + increase_value_luma,
|
||||||
|
.plane = ImagePlane::kLuma},
|
||||||
|
{.value = kBaseOriginalChromaSampleValue1 + increase_value_chroma,
|
||||||
|
.plane = ImagePlane::kChroma}};
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(CorruptionClassifierTest,
|
||||||
|
SameSampleValuesShouldResultInNoCorruptionScalarConfig) {
|
||||||
|
float kIncreaseValue = 0.0;
|
||||||
|
const std::vector<FilteredSample> kFilteredCompressedSampleValues =
|
||||||
|
GetCompressedSampleValues(kIncreaseValue, kIncreaseValue);
|
||||||
|
|
||||||
|
CorruptionClassifier corruption_classifier(kScaleFactor);
|
||||||
|
|
||||||
|
// Expected: score = 0.
|
||||||
|
// Note that the `score` above corresponds to the value returned by the
|
||||||
|
// `GetScore` function. Then this value should be passed through the Scalar or
|
||||||
|
// Logistic function giving the expected result inside DoubleNear. This
|
||||||
|
// applies for all the following tests.
|
||||||
|
EXPECT_THAT(
|
||||||
|
corruption_classifier.CalculateCorruptionProbablility(
|
||||||
|
kFilteredOriginalSampleValues, kFilteredCompressedSampleValues,
|
||||||
|
kLumaThreshold, kChromaThreshold),
|
||||||
|
DoubleNear(0.0, kMaxAbsoluteError));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(CorruptionClassifierTest,
|
||||||
|
SameSampleValuesShouldResultInNoCorruptionLogisticFunctionConfig) {
|
||||||
|
float kIncreaseValue = 0.0;
|
||||||
|
const std::vector<FilteredSample> kFilteredCompressedSampleValues =
|
||||||
|
GetCompressedSampleValues(kIncreaseValue, kIncreaseValue);
|
||||||
|
|
||||||
|
CorruptionClassifier corruption_classifier(kGrowthRate, kMidpoint);
|
||||||
|
|
||||||
|
// Expected: score = 0. See above for explanation why we have `0.0009` below.
|
||||||
|
EXPECT_THAT(
|
||||||
|
corruption_classifier.CalculateCorruptionProbablility(
|
||||||
|
kFilteredOriginalSampleValues, kFilteredCompressedSampleValues,
|
||||||
|
kLumaThreshold, kChromaThreshold),
|
||||||
|
DoubleNear(0.0009, kMaxAbsoluteError));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(CorruptionClassifierTest,
|
||||||
|
NoCorruptionWhenAllSampleDifferencesBelowThresholdScalarConfig) {
|
||||||
|
// Following value should be < `kLumaThreshold` and `kChromaThreshold`.
|
||||||
|
const double kIncreaseValue = 1;
|
||||||
|
const std::vector<FilteredSample> kFilteredCompressedSampleValues =
|
||||||
|
GetCompressedSampleValues(kIncreaseValue, kIncreaseValue);
|
||||||
|
|
||||||
|
CorruptionClassifier corruption_classifier(kScaleFactor);
|
||||||
|
|
||||||
|
// Expected: score = 0.
|
||||||
|
EXPECT_THAT(
|
||||||
|
corruption_classifier.CalculateCorruptionProbablility(
|
||||||
|
kFilteredOriginalSampleValues, kFilteredCompressedSampleValues,
|
||||||
|
kLumaThreshold, kChromaThreshold),
|
||||||
|
DoubleNear(0.0, kMaxAbsoluteError));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(CorruptionClassifierTest,
|
||||||
|
NoCorruptionWhenAllSampleDifferencesBelowThresholdLogisticFunctionConfig) {
|
||||||
|
// Following value should be < `kLumaThreshold` and `kChromaThreshold`.
|
||||||
|
const double kIncreaseValue = 1;
|
||||||
|
const std::vector<FilteredSample> kFilteredCompressedSampleValues =
|
||||||
|
GetCompressedSampleValues(kIncreaseValue, kIncreaseValue);
|
||||||
|
|
||||||
|
CorruptionClassifier corruption_classifier(kGrowthRate, kMidpoint);
|
||||||
|
|
||||||
|
// Expected: score = 0.
|
||||||
|
EXPECT_THAT(
|
||||||
|
corruption_classifier.CalculateCorruptionProbablility(
|
||||||
|
kFilteredOriginalSampleValues, kFilteredCompressedSampleValues,
|
||||||
|
kLumaThreshold, kChromaThreshold),
|
||||||
|
DoubleNear(0.0009, kMaxAbsoluteError));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(CorruptionClassifierTest,
|
||||||
|
NoCorruptionWhenSmallPartOfSamplesAboveThresholdScalarConfig) {
|
||||||
|
const double kIncreaseValueLuma = 1;
|
||||||
|
const double kIncreaseValueChroma = 2.5; // Above `kChromaThreshold`.
|
||||||
|
const std::vector<FilteredSample> kFilteredCompressedSampleValues =
|
||||||
|
GetCompressedSampleValues(kIncreaseValueLuma, kIncreaseValueChroma);
|
||||||
|
|
||||||
|
CorruptionClassifier corruption_classifier(kScaleFactor);
|
||||||
|
|
||||||
|
// Expected: score = (0.5)^2 / 3.
|
||||||
|
EXPECT_THAT(
|
||||||
|
corruption_classifier.CalculateCorruptionProbablility(
|
||||||
|
kFilteredOriginalSampleValues, kFilteredCompressedSampleValues,
|
||||||
|
kLumaThreshold, kChromaThreshold),
|
||||||
|
DoubleNear(0.0060, kMaxAbsoluteError));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(CorruptionClassifierTest,
|
||||||
|
NoCorruptionWhenSmallPartOfSamplesAboveThresholdLogisticFunctionConfig) {
|
||||||
|
const double kIncreaseValueLuma = 1;
|
||||||
|
const double kIncreaseValueChroma = 2.5; // Above `kChromaThreshold`.
|
||||||
|
const std::vector<FilteredSample> kFilteredCompressedSampleValues =
|
||||||
|
GetCompressedSampleValues(kIncreaseValueLuma, kIncreaseValueChroma);
|
||||||
|
|
||||||
|
CorruptionClassifier corruption_classifier(kGrowthRate, kMidpoint);
|
||||||
|
|
||||||
|
// Expected: score = (0.5)^2 / 3.
|
||||||
|
EXPECT_THAT(
|
||||||
|
corruption_classifier.CalculateCorruptionProbablility(
|
||||||
|
kFilteredOriginalSampleValues, kFilteredCompressedSampleValues,
|
||||||
|
kLumaThreshold, kChromaThreshold),
|
||||||
|
DoubleNear(0.001, kMaxAbsoluteError));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(CorruptionClassifierTest,
|
||||||
|
NoCorruptionWhenAllSamplesSlightlyAboveThresholdScalarConfig) {
|
||||||
|
const double kIncreaseValueLuma = 4.2; // Above `kLumaThreshold`.
|
||||||
|
const double kIncreaseValueChroma = 2.5; // Above `kChromaThreshold`.
|
||||||
|
const std::vector<FilteredSample> kFilteredCompressedSampleValues =
|
||||||
|
GetCompressedSampleValues(kIncreaseValueLuma, kIncreaseValueChroma);
|
||||||
|
|
||||||
|
CorruptionClassifier corruption_classifier(kScaleFactor);
|
||||||
|
|
||||||
|
// Expected: score = ((0.5)^2 + 2*(1.2)^2) / 3.
|
||||||
|
EXPECT_THAT(
|
||||||
|
corruption_classifier.CalculateCorruptionProbablility(
|
||||||
|
kFilteredOriginalSampleValues, kFilteredCompressedSampleValues,
|
||||||
|
kLumaThreshold, kChromaThreshold),
|
||||||
|
DoubleNear(0.07452, kMaxAbsoluteError));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(CorruptionClassifierTest,
|
||||||
|
NoCorruptionWhenAllSamplesSlightlyAboveThresholdLogisticFunctionConfig) {
|
||||||
|
const double kIncreaseValueLuma = 4.2; // Above `kLumaThreshold`.
|
||||||
|
const double kIncreaseValueChroma = 2.5; // Above `kChromaThreshold`.
|
||||||
|
const std::vector<FilteredSample> kFilteredCompressedSampleValues =
|
||||||
|
GetCompressedSampleValues(kIncreaseValueLuma, kIncreaseValueChroma);
|
||||||
|
|
||||||
|
CorruptionClassifier corruption_classifier(kGrowthRate, kMidpoint);
|
||||||
|
|
||||||
|
// Expected: score = ((0.5)^2 + 2*(1.2)^2) / 3.
|
||||||
|
EXPECT_THAT(
|
||||||
|
corruption_classifier.CalculateCorruptionProbablility(
|
||||||
|
kFilteredOriginalSampleValues, kFilteredCompressedSampleValues,
|
||||||
|
kLumaThreshold, kChromaThreshold),
|
||||||
|
DoubleNear(0.0026, kMaxAbsoluteError));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Observe that the following 2 tests in practice could be classified as
|
||||||
|
// corrupted, if so wanted. However, with the `kGrowthRate`, `kMidpoint` and
|
||||||
|
// `kScaleFactor` values chosen in these tests, the score is not high enough to
|
||||||
|
// be classified as corrupted.
|
||||||
|
TEST(CorruptionClassifierTest,
|
||||||
|
NoCorruptionWhenAllSamplesSomewhatAboveThresholdScalarConfig) {
|
||||||
|
const double kIncreaseValue = 5.0;
|
||||||
|
const std::vector<FilteredSample> kFilteredCompressedSampleValues =
|
||||||
|
GetCompressedSampleValues(kIncreaseValue, kIncreaseValue);
|
||||||
|
|
||||||
|
CorruptionClassifier corruption_classifier(kScaleFactor);
|
||||||
|
|
||||||
|
// Expected: score = ((3)^2 + 2*(2)^2) / 3.
|
||||||
|
EXPECT_THAT(
|
||||||
|
corruption_classifier.CalculateCorruptionProbablility(
|
||||||
|
kFilteredOriginalSampleValues, kFilteredCompressedSampleValues,
|
||||||
|
kLumaThreshold, kChromaThreshold),
|
||||||
|
DoubleNear(0.4048, kMaxAbsoluteError));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(CorruptionClassifierTest,
|
||||||
|
NoCorruptionWhenAllSamplesSomewhatAboveThresholdLogisticFunctionConfig) {
|
||||||
|
// Somewhat above `kLumaThreshold` and `kChromaThreshold`.
|
||||||
|
const double kIncreaseValue = 5.0;
|
||||||
|
const std::vector<FilteredSample> kFilteredCompressedSampleValues =
|
||||||
|
GetCompressedSampleValues(kIncreaseValue, kIncreaseValue);
|
||||||
|
|
||||||
|
CorruptionClassifier corruption_classifier(kGrowthRate, kMidpoint);
|
||||||
|
|
||||||
|
// Expected: score = ((3)^2 + 2*(2)^2) / 3.
|
||||||
|
EXPECT_THAT(
|
||||||
|
corruption_classifier.CalculateCorruptionProbablility(
|
||||||
|
kFilteredOriginalSampleValues, kFilteredCompressedSampleValues,
|
||||||
|
kLumaThreshold, kChromaThreshold),
|
||||||
|
DoubleNear(0.2086, kMaxAbsoluteError));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(CorruptionClassifierTest,
|
||||||
|
CorruptionWhenAllSamplesWellAboveThresholdScalarConfig) {
|
||||||
|
// Well above `kLumaThreshold` and `kChromaThreshold`.
|
||||||
|
const double kIncreaseValue = 7.0;
|
||||||
|
const std::vector<FilteredSample> kFilteredCompressedSampleValues =
|
||||||
|
GetCompressedSampleValues(kIncreaseValue, kIncreaseValue);
|
||||||
|
|
||||||
|
CorruptionClassifier corruption_classifier(kScaleFactor);
|
||||||
|
|
||||||
|
// Expected: score = ((5)^2 + 2*(4)^2) / 3. Expected 1 because of capping.
|
||||||
|
EXPECT_THAT(
|
||||||
|
corruption_classifier.CalculateCorruptionProbablility(
|
||||||
|
kFilteredOriginalSampleValues, kFilteredCompressedSampleValues,
|
||||||
|
kLumaThreshold, kChromaThreshold),
|
||||||
|
DoubleNear(1, kMaxAbsoluteError));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(CorruptionClassifierTest,
|
||||||
|
CorruptionWhenAllSamplesWellAboveThresholdLogisticFunctionConfig) {
|
||||||
|
// Well above `kLumaThreshold` and `kChromaThreshold`.
|
||||||
|
const double kIncreaseValue = 7.0;
|
||||||
|
const std::vector<FilteredSample> kFilteredCompressedSampleValues =
|
||||||
|
GetCompressedSampleValues(kIncreaseValue, kIncreaseValue);
|
||||||
|
|
||||||
|
CorruptionClassifier corruption_classifier(kGrowthRate, kMidpoint);
|
||||||
|
|
||||||
|
// Expected: score = ((5)^2 + 2*(4)^2) / 3.
|
||||||
|
EXPECT_THAT(
|
||||||
|
corruption_classifier.CalculateCorruptionProbablility(
|
||||||
|
kFilteredOriginalSampleValues, kFilteredCompressedSampleValues,
|
||||||
|
kLumaThreshold, kChromaThreshold),
|
||||||
|
DoubleNear(1, kMaxAbsoluteError));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace webrtc
|
||||||
Loading…
x
Reference in New Issue
Block a user