diff --git a/rtc_base/BUILD.gn b/rtc_base/BUILD.gn index db26b5072a..30722a5f54 100644 --- a/rtc_base/BUILD.gn +++ b/rtc_base/BUILD.gn @@ -935,7 +935,9 @@ rtc_library("async_dns_resolver") { ":logging", ":macromagic", ":platform_thread", + ":refcount", "../api:async_dns_resolver", + "../api:make_ref_counted", "../api:sequence_checker", "../api/task_queue:pending_task_safety_flag", ] diff --git a/rtc_base/async_dns_resolver.cc b/rtc_base/async_dns_resolver.cc index 918eeb74f9..8cbd21cb6d 100644 --- a/rtc_base/async_dns_resolver.cc +++ b/rtc_base/async_dns_resolver.cc @@ -15,6 +15,7 @@ #include #include +#include "api/make_ref_counted.h" #include "rtc_base/logging.h" #include "rtc_base/platform_thread.h" @@ -98,6 +99,42 @@ void PostTaskToGlobalQueue( } // namespace +class AsyncDnsResolver::State : public rtc::RefCountedBase { + public: + enum class Status { + kActive, // Running request, or able to be passed one + kFinished, // Request has finished processing + kDead // The owning AsyncDnsResolver has been deleted + }; + static rtc::scoped_refptr Create() { + return rtc::make_ref_counted(); + } + + // Execute the passed function if the state is Active. + void Finish(absl::AnyInvocable function) { + webrtc::MutexLock lock(&mutex_); + if (status_ != Status::kActive) { + return; + } + status_ = Status::kFinished; + function(); + } + void Kill() { + webrtc::MutexLock lock(&mutex_); + status_ = Status::kDead; + } + + private: + webrtc::Mutex mutex_; + Status status_ RTC_GUARDED_BY(mutex_) = Status::kActive; +}; + +AsyncDnsResolver::AsyncDnsResolver() : state_(State::Create()) {} + +AsyncDnsResolver::~AsyncDnsResolver() { + state_->Kill(); +} + void AsyncDnsResolver::Start(const rtc::SocketAddress& addr, absl::AnyInvocable callback) { Start(addr, addr.family(), std::move(callback)); @@ -111,17 +148,22 @@ void AsyncDnsResolver::Start(const rtc::SocketAddress& addr, result_.addr_ = addr; callback_ = std::move(callback); auto thread_function = [this, addr, family, flag = safety_.flag(), - caller_task_queue = - webrtc::TaskQueueBase::Current()] { + caller_task_queue = webrtc::TaskQueueBase::Current(), + state = state_] { std::vector addresses; int error = ResolveHostname(addr.hostname(), family, addresses); - caller_task_queue->PostTask( - SafeTask(flag, [this, error, addresses = std::move(addresses)] { - RTC_DCHECK_RUN_ON(&result_.sequence_checker_); - result_.addresses_ = addresses; - result_.error_ = error; - callback_(); - })); + // We assume that the caller task queue is still around if the + // AsyncDnsResolver has not been destroyed. + state->Finish([this, error, flag, caller_task_queue, + addresses = std::move(addresses)]() { + caller_task_queue->PostTask( + SafeTask(flag, [this, error, addresses = std::move(addresses)] { + RTC_DCHECK_RUN_ON(&result_.sequence_checker_); + result_.addresses_ = addresses; + result_.error_ = error; + callback_(); + })); + }); }; #if defined(WEBRTC_MAC) || defined(WEBRTC_IOS) PostTaskToGlobalQueue( diff --git a/rtc_base/async_dns_resolver.h b/rtc_base/async_dns_resolver.h index f8d60b90e2..c15af7a1cb 100644 --- a/rtc_base/async_dns_resolver.h +++ b/rtc_base/async_dns_resolver.h @@ -15,6 +15,7 @@ #include "api/async_dns_resolver.h" #include "api/sequence_checker.h" #include "api/task_queue/pending_task_safety_flag.h" +#include "rtc_base/ref_counted_object.h" #include "rtc_base/thread_annotations.h" namespace webrtc { @@ -38,6 +39,8 @@ class AsyncDnsResolverResultImpl : public AsyncDnsResolverResult { class AsyncDnsResolver : public AsyncDnsResolverInterface { public: + AsyncDnsResolver(); + ~AsyncDnsResolver(); // Start address resolution of the hostname in `addr`. void Start(const rtc::SocketAddress& addr, absl::AnyInvocable callback) override; @@ -48,7 +51,9 @@ class AsyncDnsResolver : public AsyncDnsResolverInterface { const AsyncDnsResolverResult& result() const override; private: - ScopedTaskSafety safety_; + class State; + ScopedTaskSafety safety_; // To check for client going away + rtc::scoped_refptr state_; // To check for "this" going away AsyncDnsResolverResultImpl result_; absl::AnyInvocable callback_; }; diff --git a/rtc_base/async_dns_resolver_unittest.cc b/rtc_base/async_dns_resolver_unittest.cc index e6a6d9415d..11f8b1b6f4 100644 --- a/rtc_base/async_dns_resolver_unittest.cc +++ b/rtc_base/async_dns_resolver_unittest.cc @@ -40,5 +40,19 @@ TEST(AsyncDnsResolver, ResolvingLocalhostWorks) { } } +TEST(AsyncDnsResolver, ResolveAfterDeleteDoesNotReturn) { + test::RunLoop loop; + std::unique_ptr resolver = + std::make_unique(); + rtc::SocketAddress address("localhost", + kPortNumber); // Port number does not matter + rtc::SocketAddress resolved_address; + bool done = false; + resolver->Start(address, [&done] { done = true; }); + resolver.reset(); // Deletes resolver. + rtc::Thread::Current()->SleepMs(1); // Allows callback to execute + EXPECT_FALSE(done); // Expect no result. +} + } // namespace } // namespace webrtc