From a7e15a2b7e9acc361c9233f901ab16a318e69b9d Mon Sep 17 00:00:00 2001 From: Danil Chapovalov Date: Tue, 5 Jul 2022 16:03:03 +0200 Subject: [PATCH] Introduce helper to guard an invocable with a safety flag This helper suppose to replace ToQueuedTask when calls to TaskQueueBase interfaces are converted to PostTask variants that take absl::AnyInvocable. Bug: webrtc:14245 Change-Id: I590a6ca068cf5e682ffb34770bd54cf5ce37d826 Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/267706 Reviewed-by: Tomas Gunnarsson Reviewed-by: Harald Alvestrand Commit-Queue: Danil Chapovalov Cr-Commit-Position: refs/heads/main@{#37449} --- api/task_queue/BUILD.gn | 1 + api/task_queue/pending_task_safety_flag.h | 36 ++++++++------ .../pending_task_safety_flag_unittest.cc | 48 ++++++++++++------- rtc_base/BUILD.gn | 5 +- rtc_base/task_queue.h | 12 ++--- 5 files changed, 60 insertions(+), 42 deletions(-) diff --git a/api/task_queue/BUILD.gn b/api/task_queue/BUILD.gn index 368b3b0ed7..62a4a61595 100644 --- a/api/task_queue/BUILD.gn +++ b/api/task_queue/BUILD.gn @@ -121,6 +121,7 @@ rtc_library("pending_task_safety_flag") { "../../rtc_base:checks", "../../rtc_base/system:no_unique_address", ] + absl_deps = [ "//third_party/abseil-cpp/absl/functional:any_invocable" ] } if (rtc_include_tests) { diff --git a/api/task_queue/pending_task_safety_flag.h b/api/task_queue/pending_task_safety_flag.h index edfd8e93dd..3b948ca8f1 100644 --- a/api/task_queue/pending_task_safety_flag.h +++ b/api/task_queue/pending_task_safety_flag.h @@ -13,6 +13,7 @@ #include +#include "absl/functional/any_invocable.h" #include "api/ref_counted_base.h" #include "api/scoped_refptr.h" #include "api/sequence_checker.h" @@ -37,25 +38,25 @@ namespace webrtc { // // class ExampleClass { // .... -// my_task_queue_->PostTask(ToQueuedTask( -// [safety = pending_task_safety_flag_, this]() { +// rtc::scoped_refptr flag = safety_flag_; +// my_task_queue_->PostTask( +// [flag = std::move(flag), this] { // // Now running on the main thread. -// if (!safety->alive()) +// if (!flag->alive()) // return; // MyMethod(); -// })); +// }); // .... // ~ExampleClass() { -// pending_task_safety_flag_->SetNotAlive(); +// safety_flag_->SetNotAlive(); // } -// scoped_refptr pending_task_safety_flag_ +// scoped_refptr safety_flag_ // = PendingTaskSafetyFlag::Create(); // } // -// ToQueuedTask has an overload that makes this check automatic: +// SafeTask makes this check automatic: // -// my_task_queue_->PostTask(ToQueuedTask(pending_task_safety_flag_, -// [this]() { MyMethod(); })); +// my_task_queue_->PostTask(SafeTask(safety_flag_, [this] { MyMethod(); })); // class PendingTaskSafetyFlag final : public rtc::RefCountedNonVirtual { @@ -105,13 +106,10 @@ class PendingTaskSafetyFlag final // It does automatic PTSF creation and signalling of destruction when the // ScopedTaskSafety instance goes out of scope. // -// ToQueuedTask has an overload that takes a ScopedTaskSafety too, so there -// is no need to explicitly call the "flag" method. -// // Example usage: // -// my_task_queue->PostTask(ToQueuedTask(scoped_task_safety, -// [this]() { +// my_task_queue->PostTask(SafeTask(scoped_task_safety.flag(), +// [this] { // // task goes here // } // @@ -155,6 +153,16 @@ class ScopedTaskSafetyDetached final { PendingTaskSafetyFlag::CreateDetached(); }; +inline absl::AnyInvocable SafeTask( + rtc::scoped_refptr flag, + absl::AnyInvocable task) { + return [flag = std::move(flag), task = std::move(task)]() mutable { + if (flag->alive()) { + std::move(task)(); + } + }; +} + } // namespace webrtc #endif // API_TASK_QUEUE_PENDING_TASK_SAFETY_FLAG_H_ diff --git a/api/task_queue/pending_task_safety_flag_unittest.cc b/api/task_queue/pending_task_safety_flag_unittest.cc index f045cabb90..49cbb6a32b 100644 --- a/api/task_queue/pending_task_safety_flag_unittest.cc +++ b/api/task_queue/pending_task_safety_flag_unittest.cc @@ -12,7 +12,6 @@ #include -#include "api/task_queue/to_queued_task.h" #include "rtc_base/event.h" #include "rtc_base/logging.h" #include "rtc_base/task_queue_for_test.h" @@ -20,13 +19,6 @@ #include "test/gtest.h" namespace webrtc { -namespace { -using ::testing::AtLeast; -using ::testing::Invoke; -using ::testing::MockFunction; -using ::testing::NiceMock; -using ::testing::Return; -} // namespace TEST(PendingTaskSafetyFlagTest, Basic) { rtc::scoped_refptr safety_flag; @@ -75,11 +67,12 @@ TEST(PendingTaskSafetyFlagTest, PendingTaskSuccess) { void DoStuff() { RTC_DCHECK(!tq_main_->IsCurrent()); - tq_main_->PostTask(ToQueuedTask([safe = flag_, this]() { + rtc::scoped_refptr safe = flag_; + tq_main_->PostTask([safe = std::move(safe), this]() { if (!safe->alive()) return; stuff_done_ = true; - })); + }); } bool stuff_done() const { return stuff_done_; } @@ -87,14 +80,14 @@ TEST(PendingTaskSafetyFlagTest, PendingTaskSuccess) { private: TaskQueueBase* const tq_main_; bool stuff_done_ = false; - rtc::scoped_refptr flag_{ - PendingTaskSafetyFlag::Create()}; + rtc::scoped_refptr flag_ = + PendingTaskSafetyFlag::Create(); }; std::unique_ptr owner; tq1.SendTask( [&owner]() { - owner.reset(new Owner()); + owner = std::make_unique(); EXPECT_FALSE(owner->stuff_done()); }, RTC_FROM_HERE); @@ -125,7 +118,7 @@ TEST(PendingTaskSafetyFlagTest, PendingTaskDropped) { void DoStuff() { RTC_DCHECK(!tq_main_->IsCurrent()); tq_main_->PostTask( - ToQueuedTask(safety_, [this]() { *stuff_done_ = true; })); + SafeTask(safety_.flag(), [this]() { *stuff_done_ = true; })); } private: @@ -136,8 +129,9 @@ TEST(PendingTaskSafetyFlagTest, PendingTaskDropped) { std::unique_ptr owner; bool stuff_done = false; - tq1.SendTask([&owner, &stuff_done]() { owner.reset(new Owner(&stuff_done)); }, - RTC_FROM_HERE); + tq1.SendTask( + [&owner, &stuff_done]() { owner = std::make_unique(&stuff_done); }, + RTC_FROM_HERE); ASSERT_TRUE(owner); // Queue up a task on tq1 that will execute before the 'DoStuff' task // can, and delete the `owner` before the 'stuff' task can execute. @@ -168,13 +162,31 @@ TEST(PendingTaskSafetyFlagTest, PendingTaskNotAliveInitialized) { bool task_1_ran = false; bool task_2_ran = false; - tq.PostTask(ToQueuedTask(flag, [&task_1_ran]() { task_1_ran = true; })); + tq.PostTask(SafeTask(flag, [&task_1_ran]() { task_1_ran = true; })); tq.PostTask([&flag]() { flag->SetAlive(); }); - tq.PostTask(ToQueuedTask(flag, [&task_2_ran]() { task_2_ran = true; })); + tq.PostTask(SafeTask(flag, [&task_2_ran]() { task_2_ran = true; })); tq.WaitForPreviouslyPostedTasks(); EXPECT_FALSE(task_1_ran); EXPECT_TRUE(task_2_ran); } +TEST(PendingTaskSafetyFlagTest, SafeTask) { + rtc::scoped_refptr flag = + PendingTaskSafetyFlag::Create(); + + int count = 0; + // Create two identical tasks that increment the `count`. + auto task1 = SafeTask(flag, [&count] { ++count; }); + auto task2 = SafeTask(flag, [&count] { ++count; }); + + EXPECT_EQ(count, 0); + std::move(task1)(); + EXPECT_EQ(count, 1); + flag->SetNotAlive(); + // Now task2 should actually not run. + std::move(task2)(); + EXPECT_EQ(count, 1); +} + } // namespace webrtc diff --git a/rtc_base/BUILD.gn b/rtc_base/BUILD.gn index c997073898..73446952b2 100644 --- a/rtc_base/BUILD.gn +++ b/rtc_base/BUILD.gn @@ -604,7 +604,10 @@ rtc_library("rtc_task_queue") { "../api/task_queue:to_queued_task", "system:rtc_export", ] - absl_deps = [ "//third_party/abseil-cpp/absl/memory" ] + absl_deps = [ + "//third_party/abseil-cpp/absl/functional:any_invocable", + "//third_party/abseil-cpp/absl/memory", + ] } rtc_source_set("rtc_operations_chain") { diff --git a/rtc_base/task_queue.h b/rtc_base/task_queue.h index 882f751f4f..b2a08f8fa0 100644 --- a/rtc_base/task_queue.h +++ b/rtc_base/task_queue.h @@ -16,6 +16,7 @@ #include #include +#include "absl/functional/any_invocable.h" #include "absl/memory/memory.h" #include "api/task_queue/queued_task.h" #include "api/task_queue/task_queue_base.h" @@ -105,15 +106,8 @@ class RTC_LOCKABLE RTC_EXPORT TaskQueue { std::unique_ptr task, uint32_t milliseconds); - // std::enable_if is used here to make sure that calls to PostTask() with - // std::unique_ptr would not end up being - // caught by this template. - template >::value>::type* = nullptr> - void PostTask(Closure&& closure) { - PostTask(webrtc::ToQueuedTask(std::forward(closure))); + void PostTask(absl::AnyInvocable task) { + impl_->PostTask(std::move(task)); } private: