Introduce helper to guard an invocable with a safety flag

This helper suppose to replace ToQueuedTask when calls to TaskQueueBase interfaces are converted to PostTask variants that take absl::AnyInvocable.

Bug: webrtc:14245
Change-Id: I590a6ca068cf5e682ffb34770bd54cf5ce37d826
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/267706
Reviewed-by: Tomas Gunnarsson <tommi@webrtc.org>
Reviewed-by: Harald Alvestrand <hta@webrtc.org>
Commit-Queue: Danil Chapovalov <danilchap@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#37449}
This commit is contained in:
Danil Chapovalov 2022-07-05 16:03:03 +02:00 committed by WebRTC LUCI CQ
parent 4a93da315b
commit a7e15a2b7e
5 changed files with 60 additions and 42 deletions

View File

@ -121,6 +121,7 @@ rtc_library("pending_task_safety_flag") {
"../../rtc_base:checks",
"../../rtc_base/system:no_unique_address",
]
absl_deps = [ "//third_party/abseil-cpp/absl/functional:any_invocable" ]
}
if (rtc_include_tests) {

View File

@ -13,6 +13,7 @@
#include <utility>
#include "absl/functional/any_invocable.h"
#include "api/ref_counted_base.h"
#include "api/scoped_refptr.h"
#include "api/sequence_checker.h"
@ -37,25 +38,25 @@ namespace webrtc {
//
// class ExampleClass {
// ....
// my_task_queue_->PostTask(ToQueuedTask(
// [safety = pending_task_safety_flag_, this]() {
// rtc::scoped_refptr<PendingTaskSafetyFlag> flag = safety_flag_;
// my_task_queue_->PostTask(
// [flag = std::move(flag), this] {
// // Now running on the main thread.
// if (!safety->alive())
// if (!flag->alive())
// return;
// MyMethod();
// }));
// });
// ....
// ~ExampleClass() {
// pending_task_safety_flag_->SetNotAlive();
// safety_flag_->SetNotAlive();
// }
// scoped_refptr<PendingTaskSafetyFlag> pending_task_safety_flag_
// scoped_refptr<PendingTaskSafetyFlag> safety_flag_
// = PendingTaskSafetyFlag::Create();
// }
//
// ToQueuedTask has an overload that makes this check automatic:
// SafeTask makes this check automatic:
//
// my_task_queue_->PostTask(ToQueuedTask(pending_task_safety_flag_,
// [this]() { MyMethod(); }));
// my_task_queue_->PostTask(SafeTask(safety_flag_, [this] { MyMethod(); }));
//
class PendingTaskSafetyFlag final
: public rtc::RefCountedNonVirtual<PendingTaskSafetyFlag> {
@ -105,13 +106,10 @@ class PendingTaskSafetyFlag final
// It does automatic PTSF creation and signalling of destruction when the
// ScopedTaskSafety instance goes out of scope.
//
// ToQueuedTask has an overload that takes a ScopedTaskSafety too, so there
// is no need to explicitly call the "flag" method.
//
// Example usage:
//
// my_task_queue->PostTask(ToQueuedTask(scoped_task_safety,
// [this]() {
// my_task_queue->PostTask(SafeTask(scoped_task_safety.flag(),
// [this] {
// // task goes here
// }
//
@ -155,6 +153,16 @@ class ScopedTaskSafetyDetached final {
PendingTaskSafetyFlag::CreateDetached();
};
inline absl::AnyInvocable<void() &&> SafeTask(
rtc::scoped_refptr<PendingTaskSafetyFlag> flag,
absl::AnyInvocable<void() &&> task) {
return [flag = std::move(flag), task = std::move(task)]() mutable {
if (flag->alive()) {
std::move(task)();
}
};
}
} // namespace webrtc
#endif // API_TASK_QUEUE_PENDING_TASK_SAFETY_FLAG_H_

View File

@ -12,7 +12,6 @@
#include <memory>
#include "api/task_queue/to_queued_task.h"
#include "rtc_base/event.h"
#include "rtc_base/logging.h"
#include "rtc_base/task_queue_for_test.h"
@ -20,13 +19,6 @@
#include "test/gtest.h"
namespace webrtc {
namespace {
using ::testing::AtLeast;
using ::testing::Invoke;
using ::testing::MockFunction;
using ::testing::NiceMock;
using ::testing::Return;
} // namespace
TEST(PendingTaskSafetyFlagTest, Basic) {
rtc::scoped_refptr<PendingTaskSafetyFlag> safety_flag;
@ -75,11 +67,12 @@ TEST(PendingTaskSafetyFlagTest, PendingTaskSuccess) {
void DoStuff() {
RTC_DCHECK(!tq_main_->IsCurrent());
tq_main_->PostTask(ToQueuedTask([safe = flag_, this]() {
rtc::scoped_refptr<PendingTaskSafetyFlag> safe = flag_;
tq_main_->PostTask([safe = std::move(safe), this]() {
if (!safe->alive())
return;
stuff_done_ = true;
}));
});
}
bool stuff_done() const { return stuff_done_; }
@ -87,14 +80,14 @@ TEST(PendingTaskSafetyFlagTest, PendingTaskSuccess) {
private:
TaskQueueBase* const tq_main_;
bool stuff_done_ = false;
rtc::scoped_refptr<PendingTaskSafetyFlag> flag_{
PendingTaskSafetyFlag::Create()};
rtc::scoped_refptr<PendingTaskSafetyFlag> flag_ =
PendingTaskSafetyFlag::Create();
};
std::unique_ptr<Owner> owner;
tq1.SendTask(
[&owner]() {
owner.reset(new Owner());
owner = std::make_unique<Owner>();
EXPECT_FALSE(owner->stuff_done());
},
RTC_FROM_HERE);
@ -125,7 +118,7 @@ TEST(PendingTaskSafetyFlagTest, PendingTaskDropped) {
void DoStuff() {
RTC_DCHECK(!tq_main_->IsCurrent());
tq_main_->PostTask(
ToQueuedTask(safety_, [this]() { *stuff_done_ = true; }));
SafeTask(safety_.flag(), [this]() { *stuff_done_ = true; }));
}
private:
@ -136,8 +129,9 @@ TEST(PendingTaskSafetyFlagTest, PendingTaskDropped) {
std::unique_ptr<Owner> owner;
bool stuff_done = false;
tq1.SendTask([&owner, &stuff_done]() { owner.reset(new Owner(&stuff_done)); },
RTC_FROM_HERE);
tq1.SendTask(
[&owner, &stuff_done]() { owner = std::make_unique<Owner>(&stuff_done); },
RTC_FROM_HERE);
ASSERT_TRUE(owner);
// Queue up a task on tq1 that will execute before the 'DoStuff' task
// can, and delete the `owner` before the 'stuff' task can execute.
@ -168,13 +162,31 @@ TEST(PendingTaskSafetyFlagTest, PendingTaskNotAliveInitialized) {
bool task_1_ran = false;
bool task_2_ran = false;
tq.PostTask(ToQueuedTask(flag, [&task_1_ran]() { task_1_ran = true; }));
tq.PostTask(SafeTask(flag, [&task_1_ran]() { task_1_ran = true; }));
tq.PostTask([&flag]() { flag->SetAlive(); });
tq.PostTask(ToQueuedTask(flag, [&task_2_ran]() { task_2_ran = true; }));
tq.PostTask(SafeTask(flag, [&task_2_ran]() { task_2_ran = true; }));
tq.WaitForPreviouslyPostedTasks();
EXPECT_FALSE(task_1_ran);
EXPECT_TRUE(task_2_ran);
}
TEST(PendingTaskSafetyFlagTest, SafeTask) {
rtc::scoped_refptr<PendingTaskSafetyFlag> flag =
PendingTaskSafetyFlag::Create();
int count = 0;
// Create two identical tasks that increment the `count`.
auto task1 = SafeTask(flag, [&count] { ++count; });
auto task2 = SafeTask(flag, [&count] { ++count; });
EXPECT_EQ(count, 0);
std::move(task1)();
EXPECT_EQ(count, 1);
flag->SetNotAlive();
// Now task2 should actually not run.
std::move(task2)();
EXPECT_EQ(count, 1);
}
} // namespace webrtc

View File

@ -604,7 +604,10 @@ rtc_library("rtc_task_queue") {
"../api/task_queue:to_queued_task",
"system:rtc_export",
]
absl_deps = [ "//third_party/abseil-cpp/absl/memory" ]
absl_deps = [
"//third_party/abseil-cpp/absl/functional:any_invocable",
"//third_party/abseil-cpp/absl/memory",
]
}
rtc_source_set("rtc_operations_chain") {

View File

@ -16,6 +16,7 @@
#include <memory>
#include <utility>
#include "absl/functional/any_invocable.h"
#include "absl/memory/memory.h"
#include "api/task_queue/queued_task.h"
#include "api/task_queue/task_queue_base.h"
@ -105,15 +106,8 @@ class RTC_LOCKABLE RTC_EXPORT TaskQueue {
std::unique_ptr<webrtc::QueuedTask> task,
uint32_t milliseconds);
// std::enable_if is used here to make sure that calls to PostTask() with
// std::unique_ptr<SomeClassDerivedFromQueuedTask> would not end up being
// caught by this template.
template <class Closure,
typename std::enable_if<!std::is_convertible<
Closure,
std::unique_ptr<webrtc::QueuedTask>>::value>::type* = nullptr>
void PostTask(Closure&& closure) {
PostTask(webrtc::ToQueuedTask(std::forward<Closure>(closure)));
void PostTask(absl::AnyInvocable<void() &&> task) {
impl_->PostTask(std::move(task));
}
private: