diff --git a/api/BUILD.gn b/api/BUILD.gn index 5300a6a128..eca66e51aa 100644 --- a/api/BUILD.gn +++ b/api/BUILD.gn @@ -345,6 +345,7 @@ rtc_source_set("async_dns_resolver") { visibility = [ "*" ] sources = [ "async_dns_resolver.h" ] deps = [ + "../rtc_base:checks", "../rtc_base:socket_address", "../rtc_base/system:rtc_export", ] diff --git a/api/async_dns_resolver.h b/api/async_dns_resolver.h index 138503b59f..82d80de2c3 100644 --- a/api/async_dns_resolver.h +++ b/api/async_dns_resolver.h @@ -14,6 +14,7 @@ #include #include +#include "rtc_base/checks.h" #include "rtc_base/socket_address.h" #include "rtc_base/system/rtc_export.h" @@ -63,6 +64,10 @@ class RTC_EXPORT AsyncDnsResolverInterface { // Start address resolution of the hostname in `addr`. virtual void Start(const rtc::SocketAddress& addr, std::function callback) = 0; + // Start address resolution of the hostname in `addr` matching `family`. + virtual void Start(const rtc::SocketAddress& addr, + int family, + std::function callback) = 0; virtual const AsyncDnsResolverResult& result() const = 0; }; @@ -79,6 +84,14 @@ class AsyncDnsResolverFactoryInterface { virtual std::unique_ptr CreateAndResolve( const rtc::SocketAddress& addr, std::function callback) = 0; + // Creates an AsyncDnsResolver and starts resolving the name to an address + // matching the specified family. The callback will be called when resolution + // is finished. The callback will be called on the sequence that the caller + // runs on. + virtual std::unique_ptr CreateAndResolve( + const rtc::SocketAddress& addr, + int family, + std::function callback) = 0; // Creates an AsyncDnsResolver and does not start it. // For backwards compatibility, will be deprecated and removed. // One has to do a separate Start() call on the diff --git a/api/test/mock_async_dns_resolver.h b/api/test/mock_async_dns_resolver.h index 7cc17a8427..81132c96a5 100644 --- a/api/test/mock_async_dns_resolver.h +++ b/api/test/mock_async_dns_resolver.h @@ -34,6 +34,10 @@ class MockAsyncDnsResolver : public AsyncDnsResolverInterface { Start, (const rtc::SocketAddress&, std::function), (override)); + MOCK_METHOD(void, + Start, + (const rtc::SocketAddress&, int family, std::function), + (override)); MOCK_METHOD(AsyncDnsResolverResult&, result, (), (const, override)); }; @@ -43,6 +47,10 @@ class MockAsyncDnsResolverFactory : public AsyncDnsResolverFactoryInterface { CreateAndResolve, (const rtc::SocketAddress&, std::function), (override)); + MOCK_METHOD(std::unique_ptr, + CreateAndResolve, + (const rtc::SocketAddress&, int, std::function), + (override)); MOCK_METHOD(std::unique_ptr, Create, (), diff --git a/api/wrapping_async_dns_resolver.h b/api/wrapping_async_dns_resolver.h index 80da206e75..5155b0f528 100644 --- a/api/wrapping_async_dns_resolver.h +++ b/api/wrapping_async_dns_resolver.h @@ -13,6 +13,7 @@ #include #include +#include #include "absl/memory/memory.h" #include "api/async_dns_resolver.h" @@ -68,14 +69,18 @@ class RTC_EXPORT WrappingAsyncDnsResolver : public AsyncDnsResolverInterface, void Start(const rtc::SocketAddress& addr, std::function callback) override { RTC_DCHECK_RUN_ON(&sequence_checker_); - RTC_DCHECK_EQ(State::kNotStarted, state_); - state_ = State::kStarted; - callback_ = callback; - wrapped_->SignalDone.connect(this, - &WrappingAsyncDnsResolver::OnResolveResult); + PrepareToResolve(std::move(callback)); wrapped_->Start(addr); } + void Start(const rtc::SocketAddress& addr, + int family, + std::function callback) override { + RTC_DCHECK_RUN_ON(&sequence_checker_); + PrepareToResolve(std::move(callback)); + wrapped_->Start(addr, family); + } + const AsyncDnsResolverResult& result() const override { RTC_DCHECK_RUN_ON(&sequence_checker_); RTC_DCHECK_EQ(State::kResolved, state_); @@ -92,6 +97,15 @@ class RTC_EXPORT WrappingAsyncDnsResolver : public AsyncDnsResolverInterface, return wrapped_.get(); } + void PrepareToResolve(std::function callback) { + RTC_DCHECK_RUN_ON(&sequence_checker_); + RTC_DCHECK_EQ(State::kNotStarted, state_); + state_ = State::kStarted; + callback_ = std::move(callback); + wrapped_->SignalDone.connect(this, + &WrappingAsyncDnsResolver::OnResolveResult); + } + void OnResolveResult(rtc::AsyncResolverInterface* ref) { RTC_DCHECK_RUN_ON(&sequence_checker_); RTC_DCHECK(state_ == State::kStarted); diff --git a/p2p/base/basic_async_resolver_factory.cc b/p2p/base/basic_async_resolver_factory.cc index 6824357821..3fdf75b12f 100644 --- a/p2p/base/basic_async_resolver_factory.cc +++ b/p2p/base/basic_async_resolver_factory.cc @@ -36,7 +36,17 @@ WrappingAsyncDnsResolverFactory::CreateAndResolve( const rtc::SocketAddress& addr, std::function callback) { std::unique_ptr resolver = Create(); - resolver->Start(addr, callback); + resolver->Start(addr, std::move(callback)); + return resolver; +} + +std::unique_ptr +WrappingAsyncDnsResolverFactory::CreateAndResolve( + const rtc::SocketAddress& addr, + int family, + std::function callback) { + std::unique_ptr resolver = Create(); + resolver->Start(addr, family, std::move(callback)); return resolver; } diff --git a/p2p/base/basic_async_resolver_factory.h b/p2p/base/basic_async_resolver_factory.h index c988913068..9a0ba1ab28 100644 --- a/p2p/base/basic_async_resolver_factory.h +++ b/p2p/base/basic_async_resolver_factory.h @@ -45,6 +45,11 @@ class WrappingAsyncDnsResolverFactory final const rtc::SocketAddress& addr, std::function callback) override; + std::unique_ptr CreateAndResolve( + const rtc::SocketAddress& addr, + int family, + std::function callback) override; + std::unique_ptr Create() override; private: diff --git a/p2p/base/mock_async_resolver.h b/p2p/base/mock_async_resolver.h index 8bc0eb9cff..44164716b2 100644 --- a/p2p/base/mock_async_resolver.h +++ b/p2p/base/mock_async_resolver.h @@ -30,6 +30,7 @@ class MockAsyncResolver : public AsyncResolverInterface { ~MockAsyncResolver() = default; MOCK_METHOD(void, Start, (const rtc::SocketAddress&), (override)); + MOCK_METHOD(void, Start, (const rtc::SocketAddress&, int family), (override)); MOCK_METHOD(bool, GetResolvedAddress, (int family, SocketAddress* addr), diff --git a/rtc_base/BUILD.gn b/rtc_base/BUILD.gn index ee8d6cd991..2476bca1f1 100644 --- a/rtc_base/BUILD.gn +++ b/rtc_base/BUILD.gn @@ -845,6 +845,7 @@ rtc_library("async_resolver_interface") { "async_resolver_interface.h", ] deps = [ + ":checks", ":socket_address", "system:rtc_export", "third_party/sigslot", diff --git a/rtc_base/async_resolver.cc b/rtc_base/async_resolver.cc index 198013c2bc..7c1a6fe78d 100644 --- a/rtc_base/async_resolver.cc +++ b/rtc_base/async_resolver.cc @@ -145,14 +145,18 @@ void RunResolution(void* obj) { } void AsyncResolver::Start(const SocketAddress& addr) { + Start(addr, addr.family()); +} + +void AsyncResolver::Start(const SocketAddress& addr, int family) { RTC_DCHECK_RUN_ON(&sequence_checker_); RTC_DCHECK(!destroy_called_); addr_ = addr; auto thread_function = - [this, addr, caller_task_queue = webrtc::TaskQueueBase::Current(), + [this, addr, family, caller_task_queue = webrtc::TaskQueueBase::Current(), state = state_] { std::vector addresses; - int error = ResolveHostname(addr.hostname(), addr.family(), &addresses); + int error = ResolveHostname(addr.hostname(), family, &addresses); webrtc::MutexLock lock(&state->mutex); if (state->status == State::Status::kLive) { caller_task_queue->PostTask( diff --git a/rtc_base/async_resolver.h b/rtc_base/async_resolver.h index b7125ba7cf..46be43860e 100644 --- a/rtc_base/async_resolver.h +++ b/rtc_base/async_resolver.h @@ -45,6 +45,7 @@ class RTC_EXPORT AsyncResolver : public AsyncResolverInterface { ~AsyncResolver() override; void Start(const SocketAddress& addr) override; + void Start(const SocketAddress& addr, int family) override; bool GetResolvedAddress(int family, SocketAddress* addr) const override; int GetError() const override; void Destroy(bool wait) override; diff --git a/rtc_base/async_resolver_interface.h b/rtc_base/async_resolver_interface.h index 6916ea4860..998ebd800d 100644 --- a/rtc_base/async_resolver_interface.h +++ b/rtc_base/async_resolver_interface.h @@ -11,6 +11,7 @@ #ifndef RTC_BASE_ASYNC_RESOLVER_INTERFACE_H_ #define RTC_BASE_ASYNC_RESOLVER_INTERFACE_H_ +#include "rtc_base/checks.h" #include "rtc_base/socket_address.h" #include "rtc_base/system/rtc_export.h" #include "rtc_base/third_party/sigslot/sigslot.h" @@ -25,6 +26,12 @@ class RTC_EXPORT AsyncResolverInterface { // Start address resolution of the hostname in `addr`. virtual void Start(const SocketAddress& addr) = 0; + // Start address resolution of the hostname in `addr` matching `family`. + virtual void Start(const SocketAddress& addr, int family) { + // TODO(webrtc:14319) make pure virtual when all subclasses have been + // updated. + RTC_DCHECK_NOTREACHED(); + } // Returns true iff the address from `Start` was successfully resolved. // If the address was successfully resolved, sets `addr` to a copy of the // address from `Start` with the IP address set to the top most resolved