Implement Socket::RecvFrom(ReceiveBuffer& buffer) in PhysicalSocketServer

And RTC_CHECK(NOTREACHED) Socket::RecvFrom(void* pv..)

This cl also change usage of PhysicalSocket to use PhysicalSocket::RecvFrom(ReceivedBuffer&) in Nat and tests.
Note that Socket::RecvFrom(ReceiveBuffer& buffer) is already used in AsyncUdpSocket.( https://webrtc-review.googlesource.com/c/src/+/332200)
AsyncTCPSocket uses Socket::Recv(). Therefore, there should be no production usage left of Socket::RecvFrom(void* pv..) in open source webrtc.

Follow up cls should remove usage of Socket::RecvFrom(void* pv..) in implementations of rtc:AsyncSocketAdapter such as FirewallSocketAdapter.

Change-Id: I597dc32b14be98e954a3dc419723f043e8a7e19e

Bug: webrtc:15368
Change-Id: I597dc32b14be98e954a3dc419723f043e8a7e19e
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/332341
Reviewed-by: Harald Alvestrand <hta@webrtc.org>
Commit-Queue: Per Kjellander <perkj@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#41635}
This commit is contained in:
Per K 2024-01-30 12:32:05 +01:00 committed by WebRTC LUCI CQ
parent 59f3b35013
commit 056782c4b5
9 changed files with 86 additions and 60 deletions

View File

@ -1106,6 +1106,7 @@ rtc_library("socket") {
] ]
deps = [ deps = [
":buffer", ":buffer",
":checks",
":macromagic", ":macromagic",
":socket_address", ":socket_address",
"../api/units:timestamp", "../api/units:timestamp",
@ -1718,6 +1719,7 @@ rtc_library("rtc_base_tests_utils") {
":async_socket", ":async_socket",
":async_tcp_socket", ":async_tcp_socket",
":async_udp_socket", ":async_udp_socket",
":buffer",
":byte_buffer", ":byte_buffer",
":checks", ":checks",
":ip_address", ":ip_address",
@ -1737,6 +1739,7 @@ rtc_library("rtc_base_tests_utils") {
":stringutils", ":stringutils",
":threading", ":threading",
":timeutils", ":timeutils",
"../api:array_view",
"../api:make_ref_counted", "../api:make_ref_counted",
"../api:refcountedbase", "../api:refcountedbase",
"../api:scoped_refptr", "../api:scoped_refptr",
@ -1852,6 +1855,7 @@ if (rtc_include_tests) {
":threading", ":threading",
":timeutils", ":timeutils",
"../api/units:time_delta", "../api/units:time_delta",
"../api/units:timestamp",
"../system_wrappers", "../system_wrappers",
"../test:field_trial", "../test:field_trial",
"../test:fileutils", "../test:fileutils",

View File

@ -11,8 +11,10 @@
#include "rtc_base/nat_server.h" #include "rtc_base/nat_server.h"
#include <cstddef> #include <cstddef>
#include <cstdint>
#include <memory> #include <memory>
#include "api/array_view.h"
#include "rtc_base/checks.h" #include "rtc_base/checks.h"
#include "rtc_base/logging.h" #include "rtc_base/logging.h"
#include "rtc_base/nat_socket_factory.h" #include "rtc_base/nat_socket_factory.h"
@ -97,8 +99,9 @@ class NATProxyServerSocket : public AsyncProxyServerSocket {
} }
SocketAddress dest_addr; SocketAddress dest_addr;
size_t address_length = UnpackAddressFromNAT(data, *len, &dest_addr); size_t address_length = UnpackAddressFromNAT(
MakeArrayView(reinterpret_cast<const uint8_t*>(data), *len),
&dest_addr);
*len -= address_length; *len -= address_length;
if (*len > 0) { if (*len > 0) {
memmove(data, data + address_length, *len); memmove(data, data + address_length, *len);
@ -171,15 +174,12 @@ NATServer::~NATServer() {
void NATServer::OnInternalUDPPacket(AsyncPacketSocket* socket, void NATServer::OnInternalUDPPacket(AsyncPacketSocket* socket,
const rtc::ReceivedPacket& packet) { const rtc::ReceivedPacket& packet) {
RTC_DCHECK(internal_socket_thread_.IsCurrent()); RTC_DCHECK(internal_socket_thread_.IsCurrent());
const char* buf = reinterpret_cast<const char*>(packet.payload().data());
size_t size = packet.payload().size();
const SocketAddress& addr = packet.source_address();
// Read the intended destination from the wire. // Read the intended destination from the wire.
SocketAddress dest_addr; SocketAddress dest_addr;
size_t length = UnpackAddressFromNAT(buf, size, &dest_addr); size_t length = UnpackAddressFromNAT(packet.payload(), &dest_addr);
// Find the translation for these addresses (allocating one if necessary). // Find the translation for these addresses (allocating one if necessary).
SocketAddressPair route(addr, dest_addr); SocketAddressPair route(packet.source_address(), dest_addr);
InternalMap::iterator iter = int_map_->find(route); InternalMap::iterator iter = int_map_->find(route);
if (iter == int_map_->end()) { if (iter == int_map_->end()) {
Translate(route); Translate(route);
@ -192,6 +192,8 @@ void NATServer::OnInternalUDPPacket(AsyncPacketSocket* socket,
// Send the packet to its intended destination. // Send the packet to its intended destination.
rtc::PacketOptions options; rtc::PacketOptions options;
const char* buf = reinterpret_cast<const char*>(packet.payload().data());
size_t size = packet.payload().size();
iter->second->socket->SendTo(buf + length, size - length, dest_addr, options); iter->second->socket->SendTo(buf + length, size - length, dest_addr, options);
} }

View File

@ -10,7 +10,9 @@
#include "rtc_base/nat_socket_factory.h" #include "rtc_base/nat_socket_factory.h"
#include "api/units/timestamp.h"
#include "rtc_base/arraysize.h" #include "rtc_base/arraysize.h"
#include "rtc_base/buffer.h"
#include "rtc_base/checks.h" #include "rtc_base/checks.h"
#include "rtc_base/logging.h" #include "rtc_base/logging.h"
#include "rtc_base/nat_server.h" #include "rtc_base/nat_server.h"
@ -47,21 +49,20 @@ size_t PackAddressForNAT(char* buf,
// Decodes the remote address from a packet that has been encoded with the nat's // Decodes the remote address from a packet that has been encoded with the nat's
// quasi-STUN format. Returns the length of the address (i.e., the offset into // quasi-STUN format. Returns the length of the address (i.e., the offset into
// data where the original packet starts). // data where the original packet starts).
size_t UnpackAddressFromNAT(const char* buf, size_t UnpackAddressFromNAT(rtc::ArrayView<const uint8_t> buf,
size_t buf_size,
SocketAddress* remote_addr) { SocketAddress* remote_addr) {
RTC_DCHECK(buf_size >= 8); RTC_CHECK(buf.size() >= 8);
RTC_DCHECK(buf[0] == 0); RTC_DCHECK(buf.data()[0] == 0);
int family = buf[1]; int family = buf[1];
uint16_t port = uint16_t port =
NetworkToHost16(*(reinterpret_cast<const uint16_t*>(&buf[2]))); NetworkToHost16(*(reinterpret_cast<const uint16_t*>(&buf.data()[2])));
if (family == AF_INET) { if (family == AF_INET) {
const in_addr* v4addr = reinterpret_cast<const in_addr*>(&buf[4]); const in_addr* v4addr = reinterpret_cast<const in_addr*>(&buf.data()[4]);
*remote_addr = SocketAddress(IPAddress(*v4addr), port); *remote_addr = SocketAddress(IPAddress(*v4addr), port);
return kNATEncodedIPv4AddressSize; return kNATEncodedIPv4AddressSize;
} else if (family == AF_INET6) { } else if (family == AF_INET6) {
RTC_DCHECK(buf_size >= 20); RTC_DCHECK(buf.size() >= 20);
const in6_addr* v6addr = reinterpret_cast<const in6_addr*>(&buf[4]); const in6_addr* v6addr = reinterpret_cast<const in6_addr*>(&buf.data()[4]);
*remote_addr = SocketAddress(IPAddress(*v6addr), port); *remote_addr = SocketAddress(IPAddress(*v6addr), port);
return kNATEncodedIPv6AddressSize; return kNATEncodedIPv6AddressSize;
} }
@ -76,14 +77,9 @@ class NATSocket : public Socket, public sigslot::has_slots<> {
family_(family), family_(family),
type_(type), type_(type),
connected_(false), connected_(false),
socket_(nullptr), socket_(nullptr) {}
buf_(nullptr),
size_(0) {}
~NATSocket() override { ~NATSocket() override { delete socket_; }
delete socket_;
delete[] buf_;
}
SocketAddress GetLocalAddress() const override { SocketAddress GetLocalAddress() const override {
return (socket_) ? socket_->GetLocalAddress() : SocketAddress(); return (socket_) ? socket_->GetLocalAddress() : SocketAddress();
@ -165,23 +161,21 @@ class NATSocket : public Socket, public sigslot::has_slots<> {
} }
// Make sure we have enough room to read the requested amount plus the // Make sure we have enough room to read the requested amount plus the
// largest possible header address. // largest possible header address.
SocketAddress remote_addr; buf_.EnsureCapacity(size + kNATEncodedIPv6AddressSize);
Grow(size + kNATEncodedIPv6AddressSize);
// Read the packet from the socket. // Read the packet from the socket.
int result = socket_->RecvFrom(buf_, size_, &remote_addr, timestamp); Socket::ReceiveBuffer receive_buffer(buf_);
int result = socket_->RecvFrom(receive_buffer);
if (result >= 0) { if (result >= 0) {
RTC_DCHECK(remote_addr == server_addr_); RTC_DCHECK(receive_buffer.source_address == server_addr_);
*timestamp =
// TODO: we need better framing so we know how many bytes we can receive_buffer.arrival_time.value_or(webrtc::Timestamp::Micros(0))
// return before we need to read the next address. For UDP, this will be .us();
// fine as long as the reader always reads everything in the packet.
RTC_DCHECK((size_t)result < size_);
// Decode the wire packet into the actual results. // Decode the wire packet into the actual results.
SocketAddress real_remote_addr; SocketAddress real_remote_addr;
size_t addrlength = UnpackAddressFromNAT(buf_, result, &real_remote_addr); size_t addrlength = UnpackAddressFromNAT(buf_, &real_remote_addr);
memcpy(data, buf_ + addrlength, result - addrlength); memcpy(data, buf_.data() + addrlength, result - addrlength);
// Make sure this packet should be delivered before returning it. // Make sure this packet should be delivered before returning it.
if (!connected_ || (real_remote_addr == remote_addr_)) { if (!connected_ || (real_remote_addr == remote_addr_)) {
@ -285,15 +279,6 @@ class NATSocket : public Socket, public sigslot::has_slots<> {
return result; return result;
} }
// Makes sure the buffer is at least the given size.
void Grow(size_t new_size) {
if (size_ < new_size) {
delete[] buf_;
size_ = new_size;
buf_ = new char[size_];
}
}
// Sends the destination address to the server to tell it to connect. // Sends the destination address to the server to tell it to connect.
void SendConnectRequest() { void SendConnectRequest() {
char buf[kNATEncodedIPv6AddressSize]; char buf[kNATEncodedIPv6AddressSize];
@ -323,8 +308,7 @@ class NATSocket : public Socket, public sigslot::has_slots<> {
Socket* socket_; Socket* socket_;
// Need to hold error in case it occurs before the socket is created. // Need to hold error in case it occurs before the socket is created.
int error_ = 0; int error_ = 0;
char* buf_; Buffer buf_;
size_t size_;
}; };
// NATSocketFactory // NATSocketFactory

View File

@ -13,10 +13,13 @@
#include <stddef.h> #include <stddef.h>
#include <cstdint>
#include <map> #include <map>
#include <memory> #include <memory>
#include <set> #include <set>
#include "api/array_view.h"
#include "rtc_base/buffer.h"
#include "rtc_base/nat_server.h" #include "rtc_base/nat_server.h"
#include "rtc_base/nat_types.h" #include "rtc_base/nat_types.h"
#include "rtc_base/socket.h" #include "rtc_base/socket.h"
@ -172,8 +175,7 @@ class NATSocketServer : public SocketServer, public NATInternalSocketFactory {
size_t PackAddressForNAT(char* buf, size_t PackAddressForNAT(char* buf,
size_t buf_size, size_t buf_size,
const SocketAddress& remote_addr); const SocketAddress& remote_addr);
size_t UnpackAddressFromNAT(const char* buf, size_t UnpackAddressFromNAT(rtc::ArrayView<const uint8_t> buf,
size_t buf_size,
SocketAddress* remote_addr); SocketAddress* remote_addr);
} // namespace rtc } // namespace rtc

View File

@ -233,12 +233,12 @@ bool TestConnectivity(const SocketAddress& src, const IPAddress& dst) {
const char* buf = "hello other socket"; const char* buf = "hello other socket";
size_t len = strlen(buf); size_t len = strlen(buf);
int sent = client->SendTo(buf, len, server->GetLocalAddress()); int sent = client->SendTo(buf, len, server->GetLocalAddress());
SocketAddress addr;
const size_t kRecvBufSize = 64;
char recvbuf[kRecvBufSize];
Thread::Current()->SleepMs(100); Thread::Current()->SleepMs(100);
int received = server->RecvFrom(recvbuf, kRecvBufSize, &addr, nullptr); rtc::Buffer payload;
return received == sent && ::memcmp(buf, recvbuf, len) == 0; Socket::ReceiveBuffer receive_buffer(payload);
int received = server->RecvFrom(receive_buffer);
return received == sent && ::memcmp(buf, payload.data(), len) == 0;
} }
void TestPhysicalInternal(const SocketAddress& int_addr) { void TestPhysicalInternal(const SocketAddress& int_addr) {

View File

@ -432,6 +432,31 @@ int PhysicalSocket::RecvFrom(void* buffer,
SocketAddress* out_addr, SocketAddress* out_addr,
int64_t* timestamp) { int64_t* timestamp) {
int received = DoReadFromSocket(buffer, length, out_addr, timestamp); int received = DoReadFromSocket(buffer, length, out_addr, timestamp);
UpdateLastError();
int error = GetError();
bool success = (received >= 0) || IsBlockingError(error);
if (udp_ || success) {
EnableEvents(DE_READ);
}
if (!success) {
RTC_LOG_F(LS_VERBOSE) << "Error = " << error;
}
return received;
}
int PhysicalSocket::RecvFrom(ReceiveBuffer& buffer) {
int64_t timestamp = -1;
static constexpr int BUF_SIZE = 64 * 1024;
buffer.payload.EnsureCapacity(BUF_SIZE);
int received =
DoReadFromSocket(buffer.payload.data(), buffer.payload.capacity(),
&buffer.source_address, &timestamp);
buffer.payload.SetSize(received > 0 ? received : 0);
if (received > 0 && timestamp != -1) {
buffer.arrival_time = webrtc::Timestamp::Micros(timestamp);
}
UpdateLastError(); UpdateLastError();
int error = GetError(); int error = GetError();
bool success = (received >= 0) || IsBlockingError(error); bool success = (received >= 0) || IsBlockingError(error);

View File

@ -188,10 +188,12 @@ class PhysicalSocket : public Socket, public sigslot::has_slots<> {
const SocketAddress& addr) override; const SocketAddress& addr) override;
int Recv(void* buffer, size_t length, int64_t* timestamp) override; int Recv(void* buffer, size_t length, int64_t* timestamp) override;
// TODO(webrtc:15368): Deprecate and remove.
int RecvFrom(void* buffer, int RecvFrom(void* buffer,
size_t length, size_t length,
SocketAddress* out_addr, SocketAddress* out_addr,
int64_t* timestamp) override; int64_t* timestamp) override;
int RecvFrom(ReceiveBuffer& buffer) override;
int Listen(int backlog) override; int Listen(int backlog) override;
Socket* Accept(SocketAddress* out_addr) override; Socket* Accept(SocketAddress* out_addr) override;

View File

@ -14,6 +14,7 @@
#include <errno.h> #include <errno.h>
#include "absl/types/optional.h" #include "absl/types/optional.h"
#include "rtc_base/checks.h"
#if defined(WEBRTC_POSIX) #if defined(WEBRTC_POSIX)
#include <arpa/inet.h> #include <arpa/inet.h>
@ -86,11 +87,11 @@ inline bool IsBlockingError(int e) {
class RTC_EXPORT Socket { class RTC_EXPORT Socket {
public: public:
struct ReceiveBuffer { struct ReceiveBuffer {
ReceiveBuffer(rtc::Buffer& payload) : payload(payload) {} ReceiveBuffer(Buffer& payload) : payload(payload) {}
absl::optional<webrtc::Timestamp> arrival_time; absl::optional<webrtc::Timestamp> arrival_time;
SocketAddress source_address; SocketAddress source_address;
rtc::Buffer& payload; Buffer& payload;
}; };
virtual ~Socket() {} virtual ~Socket() {}
@ -111,10 +112,14 @@ class RTC_EXPORT Socket {
virtual int SendTo(const void* pv, size_t cb, const SocketAddress& addr) = 0; virtual int SendTo(const void* pv, size_t cb, const SocketAddress& addr) = 0;
// `timestamp` is in units of microseconds. // `timestamp` is in units of microseconds.
virtual int Recv(void* pv, size_t cb, int64_t* timestamp) = 0; virtual int Recv(void* pv, size_t cb, int64_t* timestamp) = 0;
// TODO(webrtc:15368): Deprecate and remove.
virtual int RecvFrom(void* pv, virtual int RecvFrom(void* pv,
size_t cb, size_t cb,
SocketAddress* paddr, SocketAddress* paddr,
int64_t* timestamp) = 0; int64_t* timestamp) {
// Not implemented. Use RecvFrom(ReceiveBuffer& buffer).
RTC_CHECK_NOTREACHED();
}
// Intended to replace RecvFrom(void* ...). // Intended to replace RecvFrom(void* ...).
// Default implementation calls RecvFrom(void* ...) with 64Kbyte buffer. // Default implementation calls RecvFrom(void* ...) with 64Kbyte buffer.
// Returns number of bytes received or a negative value on error. // Returns number of bytes received or a negative value on error.

View File

@ -20,6 +20,7 @@
#include "absl/memory/memory.h" #include "absl/memory/memory.h"
#include "absl/strings/string_view.h" #include "absl/strings/string_view.h"
#include "api/units/timestamp.h"
#include "rtc_base/arraysize.h" #include "rtc_base/arraysize.h"
#include "rtc_base/async_packet_socket.h" #include "rtc_base/async_packet_socket.h"
#include "rtc_base/async_udp_socket.h" #include "rtc_base/async_udp_socket.h"
@ -1092,11 +1093,11 @@ void SocketTest::SocketRecvTimestamp(const IPAddress& loopback) {
int64_t send_time_1 = TimeMicros(); int64_t send_time_1 = TimeMicros();
socket->SendTo("foo", 3, address); socket->SendTo("foo", 3, address);
int64_t recv_timestamp_1;
// Wait until data is available. // Wait until data is available.
EXPECT_TRUE_WAIT(sink.Check(socket.get(), SSE_READ), kTimeout); EXPECT_TRUE_WAIT(sink.Check(socket.get(), SSE_READ), kTimeout);
char buffer[3]; rtc::Buffer buffer;
ASSERT_GT(socket->RecvFrom(buffer, 3, nullptr, &recv_timestamp_1), 0); Socket::ReceiveBuffer receive_buffer_1(buffer);
ASSERT_GT(socket->RecvFrom(receive_buffer_1), 0);
const int64_t kTimeBetweenPacketsMs = 100; const int64_t kTimeBetweenPacketsMs = 100;
Thread::SleepMs(kTimeBetweenPacketsMs); Thread::SleepMs(kTimeBetweenPacketsMs);
@ -1105,11 +1106,12 @@ void SocketTest::SocketRecvTimestamp(const IPAddress& loopback) {
socket->SendTo("bar", 3, address); socket->SendTo("bar", 3, address);
// Wait until data is available. // Wait until data is available.
EXPECT_TRUE_WAIT(sink.Check(socket.get(), SSE_READ), kTimeout); EXPECT_TRUE_WAIT(sink.Check(socket.get(), SSE_READ), kTimeout);
int64_t recv_timestamp_2; Socket::ReceiveBuffer receive_buffer_2(buffer);
ASSERT_GT(socket->RecvFrom(buffer, 3, nullptr, &recv_timestamp_2), 0); ASSERT_GT(socket->RecvFrom(receive_buffer_2), 0);
int64_t system_time_diff = send_time_2 - send_time_1; int64_t system_time_diff = send_time_2 - send_time_1;
int64_t recv_timestamp_diff = recv_timestamp_2 - recv_timestamp_1; int64_t recv_timestamp_diff =
receive_buffer_2.arrival_time->us() - receive_buffer_1.arrival_time->us();
// Compare against the system time at the point of sending, because // Compare against the system time at the point of sending, because
// SleepMs may not sleep for exactly the requested time. // SleepMs may not sleep for exactly the requested time.
EXPECT_NEAR(system_time_diff, recv_timestamp_diff, 10000); EXPECT_NEAR(system_time_diff, recv_timestamp_diff, 10000);