Replace use of SignalReadPacket in DtlsTransport

Instead use PacketTransportInternal::NotifyPacketReceived

Bug: webrtc:15368
Change-Id: I70a83865c9b564429366bd297abc7dbd50da02e4
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/340301
Commit-Queue: Per Kjellander <perkj@webrtc.org>
Reviewed-by: Jonas Oreland <jonaso@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#41816}
This commit is contained in:
Per K 2024-02-26 19:47:02 +01:00 committed by WebRTC LUCI CQ
parent 09e81ccb27
commit 8e137d0509
4 changed files with 77 additions and 51 deletions

View File

@ -421,10 +421,13 @@ rtc_library("dtls_transport") {
"../rtc_base:checks",
"../rtc_base:dscp",
"../rtc_base:logging",
"../rtc_base:socket_address",
"../rtc_base:ssl",
"../rtc_base:stream",
"../rtc_base:stringutils",
"../rtc_base:threading",
"../rtc_base:timeutils",
"../rtc_base/network:received_packet",
"../rtc_base/system:no_unique_address",
]
absl_deps = [

View File

@ -11,6 +11,7 @@
#include "p2p/base/dtls_transport.h"
#include <algorithm>
#include <cstdint>
#include <memory>
#include <utility>
@ -26,10 +27,13 @@
#include "rtc_base/checks.h"
#include "rtc_base/dscp.h"
#include "rtc_base/logging.h"
#include "rtc_base/network/received_packet.h"
#include "rtc_base/rtc_certificate.h"
#include "rtc_base/socket_address.h"
#include "rtc_base/ssl_stream_adapter.h"
#include "rtc_base/stream.h"
#include "rtc_base/thread.h"
#include "rtc_base/time_utils.h"
namespace cricket {
@ -50,20 +54,20 @@ static const size_t kMaxPendingPackets = 2;
static const int kMinHandshakeTimeout = 50;
static const int kMaxHandshakeTimeout = 3000;
static bool IsDtlsPacket(const char* data, size_t len) {
const uint8_t* u = reinterpret_cast<const uint8_t*>(data);
return (len >= kDtlsRecordHeaderLen && (u[0] > 19 && u[0] < 64));
static bool IsDtlsPacket(rtc::ArrayView<const uint8_t> payload) {
const uint8_t* u = payload.data();
return (payload.size() >= kDtlsRecordHeaderLen && (u[0] > 19 && u[0] < 64));
}
static bool IsDtlsClientHelloPacket(const char* data, size_t len) {
if (!IsDtlsPacket(data, len)) {
static bool IsDtlsClientHelloPacket(rtc::ArrayView<const uint8_t> payload) {
if (!IsDtlsPacket(payload)) {
return false;
}
const uint8_t* u = reinterpret_cast<const uint8_t*>(data);
return len > 17 && u[0] == 22 && u[13] == 1;
const uint8_t* u = payload.data();
return payload.size() > 17 && u[0] == 22 && u[13] == 1;
}
static bool IsRtpPacket(const char* data, size_t len) {
const uint8_t* u = reinterpret_cast<const uint8_t*>(data);
return (len >= kMinRtpPacketLen && (u[0] & 0xC0) == 0x80);
static bool IsRtpPacket(rtc::ArrayView<const uint8_t> payload) {
const uint8_t* u = payload.data();
return (payload.size() >= kMinRtpPacketLen && (u[0] & 0xC0) == 0x80);
}
StreamInterfaceChannel::StreamInterfaceChannel(
@ -146,7 +150,11 @@ DtlsTransport::DtlsTransport(IceTransportInternal* ice_transport,
ConnectToIceTransport();
}
DtlsTransport::~DtlsTransport() = default;
DtlsTransport::~DtlsTransport() {
if (ice_transport_) {
ice_transport_->DeregisterReceivedPacketCallback(this);
}
}
webrtc::DtlsTransportState DtlsTransport::dtls_state() const {
return dtls_state_;
@ -444,7 +452,8 @@ int DtlsTransport::SendPacket(const char* data,
case webrtc::DtlsTransportState::kConnected:
if (flags & PF_SRTP_BYPASS) {
RTC_DCHECK(!srtp_ciphers_.empty());
if (!IsRtpPacket(data, size)) {
if (!IsRtpPacket(rtc::MakeArrayView(
reinterpret_cast<const uint8_t*>(data), size))) {
return -1;
}
@ -513,7 +522,12 @@ void DtlsTransport::ConnectToIceTransport() {
RTC_DCHECK(ice_transport_);
ice_transport_->SignalWritableState.connect(this,
&DtlsTransport::OnWritableState);
ice_transport_->SignalReadPacket.connect(this, &DtlsTransport::OnReadPacket);
ice_transport_->RegisterReceivedPacketCallback(
this, [&](rtc::PacketTransportInternal* transport,
const rtc::ReceivedPacket& packet) {
OnReadPacket(transport, packet);
});
ice_transport_->SignalSentPacket.connect(this, &DtlsTransport::OnSentPacket);
ice_transport_->SignalReadyToSend.connect(this,
&DtlsTransport::OnReadyToSend);
@ -590,17 +604,13 @@ void DtlsTransport::OnReceivingState(rtc::PacketTransportInternal* transport) {
}
void DtlsTransport::OnReadPacket(rtc::PacketTransportInternal* transport,
const char* data,
size_t size,
const int64_t& packet_time_us,
int flags) {
const rtc::ReceivedPacket& packet) {
RTC_DCHECK_RUN_ON(&thread_checker_);
RTC_DCHECK(transport == ice_transport_);
RTC_DCHECK(flags == 0);
if (!dtls_active_) {
// Not doing DTLS.
SignalReadPacket(this, data, size, packet_time_us, 0);
NotifyPacketReceived(packet);
return;
}
@ -615,11 +625,11 @@ void DtlsTransport::OnReadPacket(rtc::PacketTransportInternal* transport,
"doing DTLS or not.";
}
// Cache a client hello packet received before DTLS has actually started.
if (IsDtlsClientHelloPacket(data, size)) {
if (IsDtlsClientHelloPacket(packet.payload())) {
RTC_LOG(LS_INFO) << ToString()
<< ": Caching DTLS ClientHello packet until DTLS is "
"started.";
cached_client_hello_.SetData(data, size);
cached_client_hello_.SetData(packet.payload());
// If we haven't started setting up DTLS yet (because we don't have a
// remote fingerprint/role), we can use the client hello as a clue that
// the peer has chosen the client role, and proceed with the handshake.
@ -638,8 +648,8 @@ void DtlsTransport::OnReadPacket(rtc::PacketTransportInternal* transport,
case webrtc::DtlsTransportState::kConnected:
// We should only get DTLS or SRTP packets; STUN's already been demuxed.
// Is this potentially a DTLS packet?
if (IsDtlsPacket(data, size)) {
if (!HandleDtlsPacket(data, size)) {
if (IsDtlsPacket(packet.payload())) {
if (!HandleDtlsPacket(packet.payload())) {
RTC_LOG(LS_ERROR) << ToString() << ": Failed to handle DTLS packet.";
return;
}
@ -653,7 +663,7 @@ void DtlsTransport::OnReadPacket(rtc::PacketTransportInternal* transport,
}
// And it had better be a SRTP packet.
if (!IsRtpPacket(data, size)) {
if (!IsRtpPacket(packet.payload())) {
RTC_LOG(LS_ERROR)
<< ToString() << ": Received unexpected non-DTLS packet.";
return;
@ -663,7 +673,8 @@ void DtlsTransport::OnReadPacket(rtc::PacketTransportInternal* transport,
RTC_DCHECK(!srtp_ciphers_.empty());
// Signal this upwards as a bypass packet.
SignalReadPacket(this, data, size, packet_time_us, PF_SRTP_BYPASS);
NotifyPacketReceived(
packet.CopyAndSet(rtc::ReceivedPacket::kSrtpEncrypted));
}
break;
case webrtc::DtlsTransportState::kFailed:
@ -710,8 +721,13 @@ void DtlsTransport::OnDtlsEvent(rtc::StreamInterface* dtls, int sig, int err) {
do {
ret = dtls_->Read(buf, read, read_error);
if (ret == rtc::SR_SUCCESS) {
SignalReadPacket(this, reinterpret_cast<const char*>(buf), read,
rtc::TimeMicros(), 0);
// TODO(bugs.webrtc.org/15368): It should be possible to use information
// from the original packet here to populate socket address and
// timestamp.
NotifyPacketReceived(rtc::ReceivedPacket(
rtc::MakeArrayView(buf, read), rtc::SocketAddress(),
webrtc::Timestamp::Micros(rtc::TimeMicros()),
rtc::ReceivedPacket::kDtlsDecrypted));
} else if (ret == rtc::SR_EOS) {
// Remote peer shut down the association with no error.
RTC_LOG(LS_INFO) << ToString() << ": DTLS transport closed by remote";
@ -775,8 +791,7 @@ void DtlsTransport::MaybeStartDtls() {
if (*dtls_role_ == rtc::SSL_SERVER) {
RTC_LOG(LS_INFO) << ToString()
<< ": Handling cached DTLS ClientHello packet.";
if (!HandleDtlsPacket(cached_client_hello_.data<char>(),
cached_client_hello_.size())) {
if (!HandleDtlsPacket(cached_client_hello_)) {
RTC_LOG(LS_ERROR) << ToString() << ": Failed to handle DTLS packet.";
}
} else {
@ -790,11 +805,11 @@ void DtlsTransport::MaybeStartDtls() {
}
// Called from OnReadPacket when a DTLS packet is received.
bool DtlsTransport::HandleDtlsPacket(const char* data, size_t size) {
bool DtlsTransport::HandleDtlsPacket(rtc::ArrayView<const uint8_t> payload) {
// Sanity check we're not passing junk that
// just looks like DTLS.
const uint8_t* tmp_data = reinterpret_cast<const uint8_t*>(data);
size_t tmp_size = size;
const uint8_t* tmp_data = payload.data();
size_t tmp_size = payload.size();
while (tmp_size > 0) {
if (tmp_size < kDtlsRecordHeaderLen)
return false; // Too short for the header
@ -809,7 +824,8 @@ bool DtlsTransport::HandleDtlsPacket(const char* data, size_t size) {
// Looks good. Pass to the SIC which ends up being passed to
// the DTLS stack.
return downward_->OnPacketReceived(data, size);
return downward_->OnPacketReceived(
reinterpret_cast<const char*>(payload.data()), payload.size());
}
void DtlsTransport::set_receiving(bool receiving) {

View File

@ -23,6 +23,7 @@
#include "p2p/base/ice_transport_internal.h"
#include "rtc_base/buffer.h"
#include "rtc_base/buffer_queue.h"
#include "rtc_base/network/received_packet.h"
#include "rtc_base/ssl_stream_adapter.h"
#include "rtc_base/stream.h"
#include "rtc_base/strings/string_builder.h"
@ -216,10 +217,7 @@ class DtlsTransport : public DtlsTransportInternal {
void OnWritableState(rtc::PacketTransportInternal* transport);
void OnReadPacket(rtc::PacketTransportInternal* transport,
const char* data,
size_t size,
const int64_t& packet_time_us,
int flags);
const rtc::ReceivedPacket& packet);
void OnSentPacket(rtc::PacketTransportInternal* transport,
const rtc::SentPacket& sent_packet);
void OnReadyToSend(rtc::PacketTransportInternal* transport);
@ -228,7 +226,7 @@ class DtlsTransport : public DtlsTransportInternal {
void OnNetworkRouteChanged(absl::optional<rtc::NetworkRoute> network_route);
bool SetupDtls();
void MaybeStartDtls();
bool HandleDtlsPacket(const char* data, size_t size);
bool HandleDtlsPacket(rtc::ArrayView<const uint8_t> payload);
void OnDtlsHandshakeError(rtc::SSLHandshakeError error);
void ConfigureHandshakeTimeout();

View File

@ -11,6 +11,8 @@
#include "p2p/base/dtls_transport.h"
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <set>
#include <utility>
@ -23,6 +25,7 @@
#include "rtc_base/dscp.h"
#include "rtc_base/gunit.h"
#include "rtc_base/helpers.h"
#include "rtc_base/network/received_packet.h"
#include "rtc_base/rtc_certificate.h"
#include "rtc_base/ssl_adapter.h"
#include "rtc_base/ssl_identity.h"
@ -82,6 +85,9 @@ class DtlsTestClient : public sigslot::has_slots<> {
}
// Set up fake ICE transport and real DTLS transport under test.
void SetupTransports(IceRole role, int async_delay_ms = 0) {
dtls_transport_ = nullptr;
fake_ice_transport_ = nullptr;
fake_ice_transport_.reset(new FakeIceTransport("fake", 0));
fake_ice_transport_->SetAsync(true);
fake_ice_transport_->SetAsyncDelay(async_delay_ms);
@ -89,8 +95,11 @@ class DtlsTestClient : public sigslot::has_slots<> {
fake_ice_transport_->SetIceTiebreaker((role == ICEROLE_CONTROLLING) ? 1
: 2);
// Hook the raw packets so that we can verify they are encrypted.
fake_ice_transport_->SignalReadPacket.connect(
this, &DtlsTestClient::OnFakeIceTransportReadPacket);
fake_ice_transport_->RegisterReceivedPacketCallback(
this, [&](rtc::PacketTransportInternal* transport,
const rtc::ReceivedPacket& packet) {
OnFakeIceTransportReadPacket(transport, packet);
});
dtls_transport_ = std::make_unique<DtlsTransport>(
fake_ice_transport_.get(), webrtc::CryptoOptions(),
@ -200,14 +209,14 @@ class DtlsTestClient : public sigslot::has_slots<> {
size_t NumPacketsReceived() { return received_.size(); }
// Inverse of SendPackets.
bool VerifyPacket(const char* data, size_t size, uint32_t* out_num) {
bool VerifyPacket(const uint8_t* data, size_t size, uint32_t* out_num) {
if (size != packet_size_ ||
(data[0] != 0 && static_cast<uint8_t>(data[0]) != 0x80)) {
return false;
}
uint32_t packet_num = rtc::GetBE32(data + kPacketNumOffset);
for (size_t i = kPacketHeaderLen; i < size; ++i) {
if (static_cast<uint8_t>(data[i]) != (packet_num & 0xff)) {
if (data[i] != (packet_num & 0xff)) {
return false;
}
}
@ -216,7 +225,7 @@ class DtlsTestClient : public sigslot::has_slots<> {
}
return true;
}
bool VerifyEncryptedPacket(const char* data, size_t size) {
bool VerifyEncryptedPacket(const uint8_t* data, size_t size) {
// This is an encrypted data packet; let's make sure it's mostly random;
// less than 10% of the bytes should be equal to the cleartext packet.
if (size <= packet_size_) {
@ -225,7 +234,7 @@ class DtlsTestClient : public sigslot::has_slots<> {
uint32_t packet_num = rtc::GetBE32(data + kPacketNumOffset);
int num_matches = 0;
for (size_t i = kPacketNumOffset; i < size; ++i) {
if (static_cast<uint8_t>(data[i]) == (packet_num & 0xff)) {
if (data[i] == (packet_num & 0xff)) {
++num_matches;
}
}
@ -244,7 +253,8 @@ class DtlsTestClient : public sigslot::has_slots<> {
const int64_t& /* packet_time_us */,
int flags) {
uint32_t packet_num = 0;
ASSERT_TRUE(VerifyPacket(data, size, &packet_num));
ASSERT_TRUE(VerifyPacket(reinterpret_cast<const uint8_t*>(data), size,
&packet_num));
received_.insert(packet_num);
// Only DTLS-SRTP packets should have the bypass flag set.
int expected_flags =
@ -261,15 +271,14 @@ class DtlsTestClient : public sigslot::has_slots<> {
// Hook into the raw packet stream to make sure DTLS packets are encrypted.
void OnFakeIceTransportReadPacket(rtc::PacketTransportInternal* transport,
const char* data,
size_t size,
const int64_t& /* packet_time_us */,
int flags) {
// Flags shouldn't be set on the underlying Transport packets.
ASSERT_EQ(0, flags);
const rtc::ReceivedPacket& packet) {
// Packets should not be decrypted on the underlying Transport packets.
ASSERT_EQ(packet.decryption_info(), rtc::ReceivedPacket::kNotDecrypted);
// Look at the handshake packets to see what role we played.
// Check that non-handshake packets are DTLS data or SRTP bypass.
const uint8_t* data = packet.payload().data();
size_t size = packet.payload().size();
if (data[0] == 22 && size > 17) {
if (data[13] == 1) {
++received_dtls_client_hellos_;