From 056782c4b5c682ed8bf8881a31430717ab719122 Mon Sep 17 00:00:00 2001 From: Per K Date: Tue, 30 Jan 2024 12:32:05 +0100 Subject: [PATCH] Implement Socket::RecvFrom(ReceiveBuffer& buffer) in PhysicalSocketServer And RTC_CHECK(NOTREACHED) Socket::RecvFrom(void* pv..) This cl also change usage of PhysicalSocket to use PhysicalSocket::RecvFrom(ReceivedBuffer&) in Nat and tests. Note that Socket::RecvFrom(ReceiveBuffer& buffer) is already used in AsyncUdpSocket.( https://webrtc-review.googlesource.com/c/src/+/332200) AsyncTCPSocket uses Socket::Recv(). Therefore, there should be no production usage left of Socket::RecvFrom(void* pv..) in open source webrtc. Follow up cls should remove usage of Socket::RecvFrom(void* pv..) in implementations of rtc:AsyncSocketAdapter such as FirewallSocketAdapter. Change-Id: I597dc32b14be98e954a3dc419723f043e8a7e19e Bug: webrtc:15368 Change-Id: I597dc32b14be98e954a3dc419723f043e8a7e19e Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/332341 Reviewed-by: Harald Alvestrand Commit-Queue: Per Kjellander Cr-Commit-Position: refs/heads/main@{#41635} --- rtc_base/BUILD.gn | 4 +++ rtc_base/nat_server.cc | 16 +++++---- rtc_base/nat_socket_factory.cc | 58 +++++++++++------------------- rtc_base/nat_socket_factory.h | 6 ++-- rtc_base/nat_unittest.cc | 10 +++--- rtc_base/physical_socket_server.cc | 25 +++++++++++++ rtc_base/physical_socket_server.h | 2 ++ rtc_base/socket.h | 11 ++++-- rtc_base/socket_unittest.cc | 14 ++++---- 9 files changed, 86 insertions(+), 60 deletions(-) diff --git a/rtc_base/BUILD.gn b/rtc_base/BUILD.gn index ac30d8708b..8d108bbdec 100644 --- a/rtc_base/BUILD.gn +++ b/rtc_base/BUILD.gn @@ -1106,6 +1106,7 @@ rtc_library("socket") { ] deps = [ ":buffer", + ":checks", ":macromagic", ":socket_address", "../api/units:timestamp", @@ -1718,6 +1719,7 @@ rtc_library("rtc_base_tests_utils") { ":async_socket", ":async_tcp_socket", ":async_udp_socket", + ":buffer", ":byte_buffer", ":checks", ":ip_address", @@ -1737,6 +1739,7 @@ rtc_library("rtc_base_tests_utils") { ":stringutils", ":threading", ":timeutils", + "../api:array_view", "../api:make_ref_counted", "../api:refcountedbase", "../api:scoped_refptr", @@ -1852,6 +1855,7 @@ if (rtc_include_tests) { ":threading", ":timeutils", "../api/units:time_delta", + "../api/units:timestamp", "../system_wrappers", "../test:field_trial", "../test:fileutils", diff --git a/rtc_base/nat_server.cc b/rtc_base/nat_server.cc index c274cedf18..f21d404bd3 100644 --- a/rtc_base/nat_server.cc +++ b/rtc_base/nat_server.cc @@ -11,8 +11,10 @@ #include "rtc_base/nat_server.h" #include +#include #include +#include "api/array_view.h" #include "rtc_base/checks.h" #include "rtc_base/logging.h" #include "rtc_base/nat_socket_factory.h" @@ -97,8 +99,9 @@ class NATProxyServerSocket : public AsyncProxyServerSocket { } SocketAddress dest_addr; - size_t address_length = UnpackAddressFromNAT(data, *len, &dest_addr); - + size_t address_length = UnpackAddressFromNAT( + MakeArrayView(reinterpret_cast(data), *len), + &dest_addr); *len -= address_length; if (*len > 0) { memmove(data, data + address_length, *len); @@ -171,15 +174,12 @@ NATServer::~NATServer() { void NATServer::OnInternalUDPPacket(AsyncPacketSocket* socket, const rtc::ReceivedPacket& packet) { RTC_DCHECK(internal_socket_thread_.IsCurrent()); - const char* buf = reinterpret_cast(packet.payload().data()); - size_t size = packet.payload().size(); - const SocketAddress& addr = packet.source_address(); // Read the intended destination from the wire. SocketAddress dest_addr; - size_t length = UnpackAddressFromNAT(buf, size, &dest_addr); + size_t length = UnpackAddressFromNAT(packet.payload(), &dest_addr); // Find the translation for these addresses (allocating one if necessary). - SocketAddressPair route(addr, dest_addr); + SocketAddressPair route(packet.source_address(), dest_addr); InternalMap::iterator iter = int_map_->find(route); if (iter == int_map_->end()) { Translate(route); @@ -192,6 +192,8 @@ void NATServer::OnInternalUDPPacket(AsyncPacketSocket* socket, // Send the packet to its intended destination. rtc::PacketOptions options; + const char* buf = reinterpret_cast(packet.payload().data()); + size_t size = packet.payload().size(); iter->second->socket->SendTo(buf + length, size - length, dest_addr, options); } diff --git a/rtc_base/nat_socket_factory.cc b/rtc_base/nat_socket_factory.cc index 83ec2bc327..66e4f84cd7 100644 --- a/rtc_base/nat_socket_factory.cc +++ b/rtc_base/nat_socket_factory.cc @@ -10,7 +10,9 @@ #include "rtc_base/nat_socket_factory.h" +#include "api/units/timestamp.h" #include "rtc_base/arraysize.h" +#include "rtc_base/buffer.h" #include "rtc_base/checks.h" #include "rtc_base/logging.h" #include "rtc_base/nat_server.h" @@ -47,21 +49,20 @@ size_t PackAddressForNAT(char* buf, // Decodes the remote address from a packet that has been encoded with the nat's // quasi-STUN format. Returns the length of the address (i.e., the offset into // data where the original packet starts). -size_t UnpackAddressFromNAT(const char* buf, - size_t buf_size, +size_t UnpackAddressFromNAT(rtc::ArrayView buf, SocketAddress* remote_addr) { - RTC_DCHECK(buf_size >= 8); - RTC_DCHECK(buf[0] == 0); + RTC_CHECK(buf.size() >= 8); + RTC_DCHECK(buf.data()[0] == 0); int family = buf[1]; uint16_t port = - NetworkToHost16(*(reinterpret_cast(&buf[2]))); + NetworkToHost16(*(reinterpret_cast(&buf.data()[2]))); if (family == AF_INET) { - const in_addr* v4addr = reinterpret_cast(&buf[4]); + const in_addr* v4addr = reinterpret_cast(&buf.data()[4]); *remote_addr = SocketAddress(IPAddress(*v4addr), port); return kNATEncodedIPv4AddressSize; } else if (family == AF_INET6) { - RTC_DCHECK(buf_size >= 20); - const in6_addr* v6addr = reinterpret_cast(&buf[4]); + RTC_DCHECK(buf.size() >= 20); + const in6_addr* v6addr = reinterpret_cast(&buf.data()[4]); *remote_addr = SocketAddress(IPAddress(*v6addr), port); return kNATEncodedIPv6AddressSize; } @@ -76,14 +77,9 @@ class NATSocket : public Socket, public sigslot::has_slots<> { family_(family), type_(type), connected_(false), - socket_(nullptr), - buf_(nullptr), - size_(0) {} + socket_(nullptr) {} - ~NATSocket() override { - delete socket_; - delete[] buf_; - } + ~NATSocket() override { delete socket_; } SocketAddress GetLocalAddress() const override { return (socket_) ? socket_->GetLocalAddress() : SocketAddress(); @@ -165,23 +161,21 @@ class NATSocket : public Socket, public sigslot::has_slots<> { } // Make sure we have enough room to read the requested amount plus the // largest possible header address. - SocketAddress remote_addr; - Grow(size + kNATEncodedIPv6AddressSize); + buf_.EnsureCapacity(size + kNATEncodedIPv6AddressSize); // Read the packet from the socket. - int result = socket_->RecvFrom(buf_, size_, &remote_addr, timestamp); + Socket::ReceiveBuffer receive_buffer(buf_); + int result = socket_->RecvFrom(receive_buffer); if (result >= 0) { - RTC_DCHECK(remote_addr == server_addr_); - - // TODO: we need better framing so we know how many bytes we can - // return before we need to read the next address. For UDP, this will be - // fine as long as the reader always reads everything in the packet. - RTC_DCHECK((size_t)result < size_); + RTC_DCHECK(receive_buffer.source_address == server_addr_); + *timestamp = + receive_buffer.arrival_time.value_or(webrtc::Timestamp::Micros(0)) + .us(); // Decode the wire packet into the actual results. SocketAddress real_remote_addr; - size_t addrlength = UnpackAddressFromNAT(buf_, result, &real_remote_addr); - memcpy(data, buf_ + addrlength, result - addrlength); + size_t addrlength = UnpackAddressFromNAT(buf_, &real_remote_addr); + memcpy(data, buf_.data() + addrlength, result - addrlength); // Make sure this packet should be delivered before returning it. if (!connected_ || (real_remote_addr == remote_addr_)) { @@ -285,15 +279,6 @@ class NATSocket : public Socket, public sigslot::has_slots<> { return result; } - // Makes sure the buffer is at least the given size. - void Grow(size_t new_size) { - if (size_ < new_size) { - delete[] buf_; - size_ = new_size; - buf_ = new char[size_]; - } - } - // Sends the destination address to the server to tell it to connect. void SendConnectRequest() { char buf[kNATEncodedIPv6AddressSize]; @@ -323,8 +308,7 @@ class NATSocket : public Socket, public sigslot::has_slots<> { Socket* socket_; // Need to hold error in case it occurs before the socket is created. int error_ = 0; - char* buf_; - size_t size_; + Buffer buf_; }; // NATSocketFactory diff --git a/rtc_base/nat_socket_factory.h b/rtc_base/nat_socket_factory.h index f803496b05..5adcaa5dfd 100644 --- a/rtc_base/nat_socket_factory.h +++ b/rtc_base/nat_socket_factory.h @@ -13,10 +13,13 @@ #include +#include #include #include #include +#include "api/array_view.h" +#include "rtc_base/buffer.h" #include "rtc_base/nat_server.h" #include "rtc_base/nat_types.h" #include "rtc_base/socket.h" @@ -172,8 +175,7 @@ class NATSocketServer : public SocketServer, public NATInternalSocketFactory { size_t PackAddressForNAT(char* buf, size_t buf_size, const SocketAddress& remote_addr); -size_t UnpackAddressFromNAT(const char* buf, - size_t buf_size, +size_t UnpackAddressFromNAT(rtc::ArrayView buf, SocketAddress* remote_addr); } // namespace rtc diff --git a/rtc_base/nat_unittest.cc b/rtc_base/nat_unittest.cc index 742e0d6ee7..978a30aefe 100644 --- a/rtc_base/nat_unittest.cc +++ b/rtc_base/nat_unittest.cc @@ -233,12 +233,12 @@ bool TestConnectivity(const SocketAddress& src, const IPAddress& dst) { const char* buf = "hello other socket"; size_t len = strlen(buf); int sent = client->SendTo(buf, len, server->GetLocalAddress()); - SocketAddress addr; - const size_t kRecvBufSize = 64; - char recvbuf[kRecvBufSize]; + Thread::Current()->SleepMs(100); - int received = server->RecvFrom(recvbuf, kRecvBufSize, &addr, nullptr); - return received == sent && ::memcmp(buf, recvbuf, len) == 0; + rtc::Buffer payload; + Socket::ReceiveBuffer receive_buffer(payload); + int received = server->RecvFrom(receive_buffer); + return received == sent && ::memcmp(buf, payload.data(), len) == 0; } void TestPhysicalInternal(const SocketAddress& int_addr) { diff --git a/rtc_base/physical_socket_server.cc b/rtc_base/physical_socket_server.cc index b0af1c20ce..c3bc1814a1 100644 --- a/rtc_base/physical_socket_server.cc +++ b/rtc_base/physical_socket_server.cc @@ -432,6 +432,31 @@ int PhysicalSocket::RecvFrom(void* buffer, SocketAddress* out_addr, int64_t* timestamp) { int received = DoReadFromSocket(buffer, length, out_addr, timestamp); + + UpdateLastError(); + int error = GetError(); + bool success = (received >= 0) || IsBlockingError(error); + if (udp_ || success) { + EnableEvents(DE_READ); + } + if (!success) { + RTC_LOG_F(LS_VERBOSE) << "Error = " << error; + } + return received; +} + +int PhysicalSocket::RecvFrom(ReceiveBuffer& buffer) { + int64_t timestamp = -1; + static constexpr int BUF_SIZE = 64 * 1024; + buffer.payload.EnsureCapacity(BUF_SIZE); + + int received = + DoReadFromSocket(buffer.payload.data(), buffer.payload.capacity(), + &buffer.source_address, ×tamp); + buffer.payload.SetSize(received > 0 ? received : 0); + if (received > 0 && timestamp != -1) { + buffer.arrival_time = webrtc::Timestamp::Micros(timestamp); + } UpdateLastError(); int error = GetError(); bool success = (received >= 0) || IsBlockingError(error); diff --git a/rtc_base/physical_socket_server.h b/rtc_base/physical_socket_server.h index 584f42a188..2af563a3ca 100644 --- a/rtc_base/physical_socket_server.h +++ b/rtc_base/physical_socket_server.h @@ -188,10 +188,12 @@ class PhysicalSocket : public Socket, public sigslot::has_slots<> { const SocketAddress& addr) override; int Recv(void* buffer, size_t length, int64_t* timestamp) override; + // TODO(webrtc:15368): Deprecate and remove. int RecvFrom(void* buffer, size_t length, SocketAddress* out_addr, int64_t* timestamp) override; + int RecvFrom(ReceiveBuffer& buffer) override; int Listen(int backlog) override; Socket* Accept(SocketAddress* out_addr) override; diff --git a/rtc_base/socket.h b/rtc_base/socket.h index fac75aca94..98e468e754 100644 --- a/rtc_base/socket.h +++ b/rtc_base/socket.h @@ -14,6 +14,7 @@ #include #include "absl/types/optional.h" +#include "rtc_base/checks.h" #if defined(WEBRTC_POSIX) #include @@ -86,11 +87,11 @@ inline bool IsBlockingError(int e) { class RTC_EXPORT Socket { public: struct ReceiveBuffer { - ReceiveBuffer(rtc::Buffer& payload) : payload(payload) {} + ReceiveBuffer(Buffer& payload) : payload(payload) {} absl::optional arrival_time; SocketAddress source_address; - rtc::Buffer& payload; + Buffer& payload; }; virtual ~Socket() {} @@ -111,10 +112,14 @@ class RTC_EXPORT Socket { virtual int SendTo(const void* pv, size_t cb, const SocketAddress& addr) = 0; // `timestamp` is in units of microseconds. virtual int Recv(void* pv, size_t cb, int64_t* timestamp) = 0; + // TODO(webrtc:15368): Deprecate and remove. virtual int RecvFrom(void* pv, size_t cb, SocketAddress* paddr, - int64_t* timestamp) = 0; + int64_t* timestamp) { + // Not implemented. Use RecvFrom(ReceiveBuffer& buffer). + RTC_CHECK_NOTREACHED(); + } // Intended to replace RecvFrom(void* ...). // Default implementation calls RecvFrom(void* ...) with 64Kbyte buffer. // Returns number of bytes received or a negative value on error. diff --git a/rtc_base/socket_unittest.cc b/rtc_base/socket_unittest.cc index f5ef2a33fc..5314128d0a 100644 --- a/rtc_base/socket_unittest.cc +++ b/rtc_base/socket_unittest.cc @@ -20,6 +20,7 @@ #include "absl/memory/memory.h" #include "absl/strings/string_view.h" +#include "api/units/timestamp.h" #include "rtc_base/arraysize.h" #include "rtc_base/async_packet_socket.h" #include "rtc_base/async_udp_socket.h" @@ -1092,11 +1093,11 @@ void SocketTest::SocketRecvTimestamp(const IPAddress& loopback) { int64_t send_time_1 = TimeMicros(); socket->SendTo("foo", 3, address); - int64_t recv_timestamp_1; // Wait until data is available. EXPECT_TRUE_WAIT(sink.Check(socket.get(), SSE_READ), kTimeout); - char buffer[3]; - ASSERT_GT(socket->RecvFrom(buffer, 3, nullptr, &recv_timestamp_1), 0); + rtc::Buffer buffer; + Socket::ReceiveBuffer receive_buffer_1(buffer); + ASSERT_GT(socket->RecvFrom(receive_buffer_1), 0); const int64_t kTimeBetweenPacketsMs = 100; Thread::SleepMs(kTimeBetweenPacketsMs); @@ -1105,11 +1106,12 @@ void SocketTest::SocketRecvTimestamp(const IPAddress& loopback) { socket->SendTo("bar", 3, address); // Wait until data is available. EXPECT_TRUE_WAIT(sink.Check(socket.get(), SSE_READ), kTimeout); - int64_t recv_timestamp_2; - ASSERT_GT(socket->RecvFrom(buffer, 3, nullptr, &recv_timestamp_2), 0); + Socket::ReceiveBuffer receive_buffer_2(buffer); + ASSERT_GT(socket->RecvFrom(receive_buffer_2), 0); int64_t system_time_diff = send_time_2 - send_time_1; - int64_t recv_timestamp_diff = recv_timestamp_2 - recv_timestamp_1; + int64_t recv_timestamp_diff = + receive_buffer_2.arrival_time->us() - receive_buffer_1.arrival_time->us(); // Compare against the system time at the point of sending, because // SleepMs may not sleep for exactly the requested time. EXPECT_NEAR(system_time_diff, recv_timestamp_diff, 10000);