diff --git a/modules/video_coding/codecs/av1/BUILD.gn b/modules/video_coding/codecs/av1/BUILD.gn index 2f7d438063..41fd163427 100644 --- a/modules/video_coding/codecs/av1/BUILD.gn +++ b/modules/video_coding/codecs/av1/BUILD.gn @@ -110,6 +110,9 @@ if (rtc_include_tests) { deps = [ ":scalability_structures", ":scalable_video_controller", + "../..:chain_diff_calculator", + "../..:frame_dependencies_calculator", + "../../../../api/video:video_frame_type", "../../../../test:test_support", "//third_party/abseil-cpp/absl/types:optional", ] diff --git a/modules/video_coding/codecs/av1/scalability_structure_l2t1.cc b/modules/video_coding/codecs/av1/scalability_structure_l2t1.cc index 0a397d963b..2c97ef75aa 100644 --- a/modules/video_coding/codecs/av1/scalability_structure_l2t1.cc +++ b/modules/video_coding/codecs/av1/scalability_structure_l2t1.cc @@ -52,7 +52,7 @@ FrameDependencyStructure ScalabilityStructureL2T1::DependencyStructure() const { structure.templates = { Builder().S(0).Dtis("SR").Fdiffs({2}).ChainDiffs({2, 1}).Build(), Builder().S(0).Dtis("SS").ChainDiffs({0, 0}).Build(), - Builder().S(1).Dtis("-R").Fdiffs({1, 2}).ChainDiffs({1, 2}).Build(), + Builder().S(1).Dtis("-R").Fdiffs({2, 1}).ChainDiffs({1, 1}).Build(), Builder().S(1).Dtis("-S").Fdiffs({1}).ChainDiffs({1, 1}).Build(), }; return structure; diff --git a/modules/video_coding/codecs/av1/scalability_structure_unittest.cc b/modules/video_coding/codecs/av1/scalability_structure_unittest.cc index 2f4342fd0b..37e9d716d3 100644 --- a/modules/video_coding/codecs/av1/scalability_structure_unittest.cc +++ b/modules/video_coding/codecs/av1/scalability_structure_unittest.cc @@ -17,11 +17,14 @@ #include #include "absl/types/optional.h" +#include "api/video/video_frame_type.h" +#include "modules/video_coding/chain_diff_calculator.h" #include "modules/video_coding/codecs/av1/scalability_structure_l1t2.h" #include "modules/video_coding/codecs/av1/scalability_structure_l2t1.h" #include "modules/video_coding/codecs/av1/scalability_structure_l2t1_key.h" #include "modules/video_coding/codecs/av1/scalability_structure_s2t1.h" #include "modules/video_coding/codecs/av1/scalable_video_controller.h" +#include "modules/video_coding/frame_dependencies_calculator.h" #include "test/gmock.h" #include "test/gtest.h" @@ -29,6 +32,7 @@ namespace webrtc { namespace { using ::testing::AllOf; +using ::testing::Contains; using ::testing::Each; using ::testing::Field; using ::testing::Ge; @@ -47,9 +51,48 @@ struct SvcTestParam { std::string name; std::function()> svc_factory; + int num_temporal_units; }; -class ScalabilityStructureTest : public TestWithParam {}; +class ScalabilityStructureTest : public TestWithParam { + public: + std::vector GenerateAllFrames() { + std::vector frames; + + FrameDependenciesCalculator frame_deps_calculator; + ChainDiffCalculator chain_diff_calculator; + std::unique_ptr structure_controller = + GetParam().svc_factory(); + FrameDependencyStructure structure = + structure_controller->DependencyStructure(); + for (int i = 0; i < GetParam().num_temporal_units; ++i) { + for (auto& layer_frame : + structure_controller->NextFrameConfig(/*reset=*/false)) { + int64_t frame_id = static_cast(frames.size()); + bool is_keyframe = layer_frame.is_keyframe; + absl::optional frame_info = + structure_controller->OnEncodeDone(std::move(layer_frame)); + EXPECT_TRUE(frame_info.has_value()); + if (is_keyframe) { + chain_diff_calculator.Reset(frame_info->part_of_chain); + } + frame_info->chain_diffs = + chain_diff_calculator.From(frame_id, frame_info->part_of_chain); + for (int64_t base_frame_id : frame_deps_calculator.FromBuffersUsage( + is_keyframe ? VideoFrameType::kVideoFrameKey + : VideoFrameType::kVideoFrameDelta, + frame_id, frame_info->encoder_buffers)) { + EXPECT_LT(base_frame_id, frame_id); + EXPECT_GE(base_frame_id, 0); + frame_info->frame_diffs.push_back(frame_id - base_frame_id); + } + + frames.push_back(*std::move(frame_info)); + } + } + return frames; + } +}; TEST_P(ScalabilityStructureTest, NumberOfDecodeTargetsAndChainsAreInRangeAndConsistent) { @@ -111,14 +154,140 @@ TEST_P(ScalabilityStructureTest, TemplatesMatchNumberOfDecodeTargetsAndChains) { SizeIs(structure.num_chains))))); } +TEST_P(ScalabilityStructureTest, FrameInfoMatchesFrameDependencyStructure) { + FrameDependencyStructure structure = + GetParam().svc_factory()->DependencyStructure(); + std::vector frame_infos = GenerateAllFrames(); + for (size_t frame_id = 0; frame_id < frame_infos.size(); ++frame_id) { + const auto& frame = frame_infos[frame_id]; + EXPECT_GE(frame.spatial_id, 0) << " for frame " << frame_id; + EXPECT_GE(frame.temporal_id, 0) << " for frame " << frame_id; + EXPECT_THAT(frame.decode_target_indications, + SizeIs(structure.num_decode_targets)) + << " for frame " << frame_id; + EXPECT_THAT(frame.part_of_chain, SizeIs(structure.num_chains)) + << " for frame " << frame_id; + } +} + +TEST_P(ScalabilityStructureTest, ThereIsAPerfectTemplateForEachFrame) { + FrameDependencyStructure structure = + GetParam().svc_factory()->DependencyStructure(); + std::vector frame_infos = GenerateAllFrames(); + for (size_t frame_id = 0; frame_id < frame_infos.size(); ++frame_id) { + EXPECT_THAT(structure.templates, Contains(frame_infos[frame_id])) + << " for frame " << frame_id; + } +} + +TEST_P(ScalabilityStructureTest, FrameDependsOnSameOrLowerLayer) { + std::vector frame_infos = GenerateAllFrames(); + int64_t num_frames = frame_infos.size(); + + for (int64_t frame_id = 0; frame_id < num_frames; ++frame_id) { + const auto& frame = frame_infos[frame_id]; + for (int frame_diff : frame.frame_diffs) { + int64_t base_frame_id = frame_id - frame_diff; + const auto& base_frame = frame_infos[base_frame_id]; + EXPECT_GE(frame.spatial_id, base_frame.spatial_id) + << "Frame " << frame_id << " depends on frame " << base_frame_id; + EXPECT_GE(frame.temporal_id, base_frame.temporal_id) + << "Frame " << frame_id << " depends on frame " << base_frame_id; + } + } +} + +TEST_P(ScalabilityStructureTest, NoFrameDependsOnDiscardableOrNotPresent) { + std::vector frame_infos = GenerateAllFrames(); + int64_t num_frames = frame_infos.size(); + FrameDependencyStructure structure = + GetParam().svc_factory()->DependencyStructure(); + + for (int dt = 0; dt < structure.num_decode_targets; ++dt) { + for (int64_t frame_id = 0; frame_id < num_frames; ++frame_id) { + const auto& frame = frame_infos[frame_id]; + if (frame.decode_target_indications[dt] == + DecodeTargetIndication::kNotPresent) { + continue; + } + for (int frame_diff : frame.frame_diffs) { + int64_t base_frame_id = frame_id - frame_diff; + const auto& base_frame = frame_infos[base_frame_id]; + EXPECT_NE(base_frame.decode_target_indications[dt], + DecodeTargetIndication::kNotPresent) + << "Frame " << frame_id << " depends on frame " << base_frame_id + << " that is not part of decode target#" << dt; + EXPECT_NE(base_frame.decode_target_indications[dt], + DecodeTargetIndication::kDiscardable) + << "Frame " << frame_id << " depends on frame " << base_frame_id + << " that is discardable for decode target#" << dt; + } + } + } +} + +TEST_P(ScalabilityStructureTest, NoFrameDependsThroughSwitchIndication) { + FrameDependencyStructure structure = + GetParam().svc_factory()->DependencyStructure(); + std::vector frame_infos = GenerateAllFrames(); + int64_t num_frames = frame_infos.size(); + std::vector> full_deps(num_frames); + + // For each frame calculate set of all frames it depends on, both directly and + // indirectly. + for (int64_t frame_id = 0; frame_id < num_frames; ++frame_id) { + std::set all_base_frames; + for (int frame_diff : frame_infos[frame_id].frame_diffs) { + int64_t base_frame_id = frame_id - frame_diff; + all_base_frames.insert(base_frame_id); + const auto& indirect = full_deps[base_frame_id]; + all_base_frames.insert(indirect.begin(), indirect.end()); + } + full_deps[frame_id] = std::move(all_base_frames); + } + + // Now check the switch indication: frames after the switch indication mustn't + // depend on any addition frames before the switch indications. + for (int dt = 0; dt < structure.num_decode_targets; ++dt) { + for (int64_t switch_frame_id = 0; switch_frame_id < num_frames; + ++switch_frame_id) { + if (frame_infos[switch_frame_id].decode_target_indications[dt] != + DecodeTargetIndication::kSwitch) { + continue; + } + for (int64_t later_frame_id = switch_frame_id + 1; + later_frame_id < num_frames; ++later_frame_id) { + if (frame_infos[later_frame_id].decode_target_indications[dt] == + DecodeTargetIndication::kNotPresent) { + continue; + } + for (int frame_diff : frame_infos[later_frame_id].frame_diffs) { + int64_t early_frame_id = later_frame_id - frame_diff; + if (early_frame_id < switch_frame_id) { + EXPECT_THAT(full_deps[switch_frame_id], Contains(early_frame_id)) + << "For decode target #" << dt << " frame " << later_frame_id + << " depends on the frame " << early_frame_id + << " that switch indication frame " << switch_frame_id + << " doesn't directly on indirectly depend on."; + } + } + } + } + } +} + INSTANTIATE_TEST_SUITE_P( Svc, ScalabilityStructureTest, - Values(SvcTestParam{"L1T2", std::make_unique}, - SvcTestParam{"L2T1", std::make_unique}, + Values(SvcTestParam{"L1T2", std::make_unique, + /*num_temporal_units=*/4}, + SvcTestParam{"L2T1", std::make_unique, + /*num_temporal_units=*/3}, SvcTestParam{"L2T1Key", - std::make_unique}, - SvcTestParam{"S2T1", std::make_unique}), + std::make_unique, + /*num_temporal_units=*/3}, + SvcTestParam{"S2T1", std::make_unique, + /*num_temporal_units=*/3}), [](const testing::TestParamInfo& info) { return info.param.name; });