diff --git a/p2p/BUILD.gn b/p2p/BUILD.gn index c3858cba0d..9d25078bd8 100644 --- a/p2p/BUILD.gn +++ b/p2p/BUILD.gn @@ -421,10 +421,13 @@ rtc_library("dtls_transport") { "../rtc_base:checks", "../rtc_base:dscp", "../rtc_base:logging", + "../rtc_base:socket_address", "../rtc_base:ssl", "../rtc_base:stream", "../rtc_base:stringutils", "../rtc_base:threading", + "../rtc_base:timeutils", + "../rtc_base/network:received_packet", "../rtc_base/system:no_unique_address", ] absl_deps = [ diff --git a/p2p/base/dtls_transport.cc b/p2p/base/dtls_transport.cc index a9ff9d3784..6f30c6dbdd 100644 --- a/p2p/base/dtls_transport.cc +++ b/p2p/base/dtls_transport.cc @@ -11,6 +11,7 @@ #include "p2p/base/dtls_transport.h" #include +#include #include #include @@ -26,10 +27,13 @@ #include "rtc_base/checks.h" #include "rtc_base/dscp.h" #include "rtc_base/logging.h" +#include "rtc_base/network/received_packet.h" #include "rtc_base/rtc_certificate.h" +#include "rtc_base/socket_address.h" #include "rtc_base/ssl_stream_adapter.h" #include "rtc_base/stream.h" #include "rtc_base/thread.h" +#include "rtc_base/time_utils.h" namespace cricket { @@ -50,20 +54,20 @@ static const size_t kMaxPendingPackets = 2; static const int kMinHandshakeTimeout = 50; static const int kMaxHandshakeTimeout = 3000; -static bool IsDtlsPacket(const char* data, size_t len) { - const uint8_t* u = reinterpret_cast(data); - return (len >= kDtlsRecordHeaderLen && (u[0] > 19 && u[0] < 64)); +static bool IsDtlsPacket(rtc::ArrayView payload) { + const uint8_t* u = payload.data(); + return (payload.size() >= kDtlsRecordHeaderLen && (u[0] > 19 && u[0] < 64)); } -static bool IsDtlsClientHelloPacket(const char* data, size_t len) { - if (!IsDtlsPacket(data, len)) { +static bool IsDtlsClientHelloPacket(rtc::ArrayView payload) { + if (!IsDtlsPacket(payload)) { return false; } - const uint8_t* u = reinterpret_cast(data); - return len > 17 && u[0] == 22 && u[13] == 1; + const uint8_t* u = payload.data(); + return payload.size() > 17 && u[0] == 22 && u[13] == 1; } -static bool IsRtpPacket(const char* data, size_t len) { - const uint8_t* u = reinterpret_cast(data); - return (len >= kMinRtpPacketLen && (u[0] & 0xC0) == 0x80); +static bool IsRtpPacket(rtc::ArrayView payload) { + const uint8_t* u = payload.data(); + return (payload.size() >= kMinRtpPacketLen && (u[0] & 0xC0) == 0x80); } StreamInterfaceChannel::StreamInterfaceChannel( @@ -146,7 +150,11 @@ DtlsTransport::DtlsTransport(IceTransportInternal* ice_transport, ConnectToIceTransport(); } -DtlsTransport::~DtlsTransport() = default; +DtlsTransport::~DtlsTransport() { + if (ice_transport_) { + ice_transport_->DeregisterReceivedPacketCallback(this); + } +} webrtc::DtlsTransportState DtlsTransport::dtls_state() const { return dtls_state_; @@ -444,7 +452,8 @@ int DtlsTransport::SendPacket(const char* data, case webrtc::DtlsTransportState::kConnected: if (flags & PF_SRTP_BYPASS) { RTC_DCHECK(!srtp_ciphers_.empty()); - if (!IsRtpPacket(data, size)) { + if (!IsRtpPacket(rtc::MakeArrayView( + reinterpret_cast(data), size))) { return -1; } @@ -513,7 +522,12 @@ void DtlsTransport::ConnectToIceTransport() { RTC_DCHECK(ice_transport_); ice_transport_->SignalWritableState.connect(this, &DtlsTransport::OnWritableState); - ice_transport_->SignalReadPacket.connect(this, &DtlsTransport::OnReadPacket); + ice_transport_->RegisterReceivedPacketCallback( + this, [&](rtc::PacketTransportInternal* transport, + const rtc::ReceivedPacket& packet) { + OnReadPacket(transport, packet); + }); + ice_transport_->SignalSentPacket.connect(this, &DtlsTransport::OnSentPacket); ice_transport_->SignalReadyToSend.connect(this, &DtlsTransport::OnReadyToSend); @@ -590,17 +604,13 @@ void DtlsTransport::OnReceivingState(rtc::PacketTransportInternal* transport) { } void DtlsTransport::OnReadPacket(rtc::PacketTransportInternal* transport, - const char* data, - size_t size, - const int64_t& packet_time_us, - int flags) { + const rtc::ReceivedPacket& packet) { RTC_DCHECK_RUN_ON(&thread_checker_); RTC_DCHECK(transport == ice_transport_); - RTC_DCHECK(flags == 0); if (!dtls_active_) { // Not doing DTLS. - SignalReadPacket(this, data, size, packet_time_us, 0); + NotifyPacketReceived(packet); return; } @@ -615,11 +625,11 @@ void DtlsTransport::OnReadPacket(rtc::PacketTransportInternal* transport, "doing DTLS or not."; } // Cache a client hello packet received before DTLS has actually started. - if (IsDtlsClientHelloPacket(data, size)) { + if (IsDtlsClientHelloPacket(packet.payload())) { RTC_LOG(LS_INFO) << ToString() << ": Caching DTLS ClientHello packet until DTLS is " "started."; - cached_client_hello_.SetData(data, size); + cached_client_hello_.SetData(packet.payload()); // If we haven't started setting up DTLS yet (because we don't have a // remote fingerprint/role), we can use the client hello as a clue that // the peer has chosen the client role, and proceed with the handshake. @@ -638,8 +648,8 @@ void DtlsTransport::OnReadPacket(rtc::PacketTransportInternal* transport, case webrtc::DtlsTransportState::kConnected: // We should only get DTLS or SRTP packets; STUN's already been demuxed. // Is this potentially a DTLS packet? - if (IsDtlsPacket(data, size)) { - if (!HandleDtlsPacket(data, size)) { + if (IsDtlsPacket(packet.payload())) { + if (!HandleDtlsPacket(packet.payload())) { RTC_LOG(LS_ERROR) << ToString() << ": Failed to handle DTLS packet."; return; } @@ -653,7 +663,7 @@ void DtlsTransport::OnReadPacket(rtc::PacketTransportInternal* transport, } // And it had better be a SRTP packet. - if (!IsRtpPacket(data, size)) { + if (!IsRtpPacket(packet.payload())) { RTC_LOG(LS_ERROR) << ToString() << ": Received unexpected non-DTLS packet."; return; @@ -663,7 +673,8 @@ void DtlsTransport::OnReadPacket(rtc::PacketTransportInternal* transport, RTC_DCHECK(!srtp_ciphers_.empty()); // Signal this upwards as a bypass packet. - SignalReadPacket(this, data, size, packet_time_us, PF_SRTP_BYPASS); + NotifyPacketReceived( + packet.CopyAndSet(rtc::ReceivedPacket::kSrtpEncrypted)); } break; case webrtc::DtlsTransportState::kFailed: @@ -710,8 +721,13 @@ void DtlsTransport::OnDtlsEvent(rtc::StreamInterface* dtls, int sig, int err) { do { ret = dtls_->Read(buf, read, read_error); if (ret == rtc::SR_SUCCESS) { - SignalReadPacket(this, reinterpret_cast(buf), read, - rtc::TimeMicros(), 0); + // TODO(bugs.webrtc.org/15368): It should be possible to use information + // from the original packet here to populate socket address and + // timestamp. + NotifyPacketReceived(rtc::ReceivedPacket( + rtc::MakeArrayView(buf, read), rtc::SocketAddress(), + webrtc::Timestamp::Micros(rtc::TimeMicros()), + rtc::ReceivedPacket::kDtlsDecrypted)); } else if (ret == rtc::SR_EOS) { // Remote peer shut down the association with no error. RTC_LOG(LS_INFO) << ToString() << ": DTLS transport closed by remote"; @@ -775,8 +791,7 @@ void DtlsTransport::MaybeStartDtls() { if (*dtls_role_ == rtc::SSL_SERVER) { RTC_LOG(LS_INFO) << ToString() << ": Handling cached DTLS ClientHello packet."; - if (!HandleDtlsPacket(cached_client_hello_.data(), - cached_client_hello_.size())) { + if (!HandleDtlsPacket(cached_client_hello_)) { RTC_LOG(LS_ERROR) << ToString() << ": Failed to handle DTLS packet."; } } else { @@ -790,11 +805,11 @@ void DtlsTransport::MaybeStartDtls() { } // Called from OnReadPacket when a DTLS packet is received. -bool DtlsTransport::HandleDtlsPacket(const char* data, size_t size) { +bool DtlsTransport::HandleDtlsPacket(rtc::ArrayView payload) { // Sanity check we're not passing junk that // just looks like DTLS. - const uint8_t* tmp_data = reinterpret_cast(data); - size_t tmp_size = size; + const uint8_t* tmp_data = payload.data(); + size_t tmp_size = payload.size(); while (tmp_size > 0) { if (tmp_size < kDtlsRecordHeaderLen) return false; // Too short for the header @@ -809,7 +824,8 @@ bool DtlsTransport::HandleDtlsPacket(const char* data, size_t size) { // Looks good. Pass to the SIC which ends up being passed to // the DTLS stack. - return downward_->OnPacketReceived(data, size); + return downward_->OnPacketReceived( + reinterpret_cast(payload.data()), payload.size()); } void DtlsTransport::set_receiving(bool receiving) { diff --git a/p2p/base/dtls_transport.h b/p2p/base/dtls_transport.h index 9408025be5..f479325258 100644 --- a/p2p/base/dtls_transport.h +++ b/p2p/base/dtls_transport.h @@ -23,6 +23,7 @@ #include "p2p/base/ice_transport_internal.h" #include "rtc_base/buffer.h" #include "rtc_base/buffer_queue.h" +#include "rtc_base/network/received_packet.h" #include "rtc_base/ssl_stream_adapter.h" #include "rtc_base/stream.h" #include "rtc_base/strings/string_builder.h" @@ -216,10 +217,7 @@ class DtlsTransport : public DtlsTransportInternal { void OnWritableState(rtc::PacketTransportInternal* transport); void OnReadPacket(rtc::PacketTransportInternal* transport, - const char* data, - size_t size, - const int64_t& packet_time_us, - int flags); + const rtc::ReceivedPacket& packet); void OnSentPacket(rtc::PacketTransportInternal* transport, const rtc::SentPacket& sent_packet); void OnReadyToSend(rtc::PacketTransportInternal* transport); @@ -228,7 +226,7 @@ class DtlsTransport : public DtlsTransportInternal { void OnNetworkRouteChanged(absl::optional network_route); bool SetupDtls(); void MaybeStartDtls(); - bool HandleDtlsPacket(const char* data, size_t size); + bool HandleDtlsPacket(rtc::ArrayView payload); void OnDtlsHandshakeError(rtc::SSLHandshakeError error); void ConfigureHandshakeTimeout(); diff --git a/p2p/base/dtls_transport_unittest.cc b/p2p/base/dtls_transport_unittest.cc index e338ab6a49..ddf18746d8 100644 --- a/p2p/base/dtls_transport_unittest.cc +++ b/p2p/base/dtls_transport_unittest.cc @@ -11,6 +11,8 @@ #include "p2p/base/dtls_transport.h" #include +#include +#include #include #include #include @@ -23,6 +25,7 @@ #include "rtc_base/dscp.h" #include "rtc_base/gunit.h" #include "rtc_base/helpers.h" +#include "rtc_base/network/received_packet.h" #include "rtc_base/rtc_certificate.h" #include "rtc_base/ssl_adapter.h" #include "rtc_base/ssl_identity.h" @@ -82,6 +85,9 @@ class DtlsTestClient : public sigslot::has_slots<> { } // Set up fake ICE transport and real DTLS transport under test. void SetupTransports(IceRole role, int async_delay_ms = 0) { + dtls_transport_ = nullptr; + fake_ice_transport_ = nullptr; + fake_ice_transport_.reset(new FakeIceTransport("fake", 0)); fake_ice_transport_->SetAsync(true); fake_ice_transport_->SetAsyncDelay(async_delay_ms); @@ -89,8 +95,11 @@ class DtlsTestClient : public sigslot::has_slots<> { fake_ice_transport_->SetIceTiebreaker((role == ICEROLE_CONTROLLING) ? 1 : 2); // Hook the raw packets so that we can verify they are encrypted. - fake_ice_transport_->SignalReadPacket.connect( - this, &DtlsTestClient::OnFakeIceTransportReadPacket); + fake_ice_transport_->RegisterReceivedPacketCallback( + this, [&](rtc::PacketTransportInternal* transport, + const rtc::ReceivedPacket& packet) { + OnFakeIceTransportReadPacket(transport, packet); + }); dtls_transport_ = std::make_unique( fake_ice_transport_.get(), webrtc::CryptoOptions(), @@ -200,14 +209,14 @@ class DtlsTestClient : public sigslot::has_slots<> { size_t NumPacketsReceived() { return received_.size(); } // Inverse of SendPackets. - bool VerifyPacket(const char* data, size_t size, uint32_t* out_num) { + bool VerifyPacket(const uint8_t* data, size_t size, uint32_t* out_num) { if (size != packet_size_ || (data[0] != 0 && static_cast(data[0]) != 0x80)) { return false; } uint32_t packet_num = rtc::GetBE32(data + kPacketNumOffset); for (size_t i = kPacketHeaderLen; i < size; ++i) { - if (static_cast(data[i]) != (packet_num & 0xff)) { + if (data[i] != (packet_num & 0xff)) { return false; } } @@ -216,7 +225,7 @@ class DtlsTestClient : public sigslot::has_slots<> { } return true; } - bool VerifyEncryptedPacket(const char* data, size_t size) { + bool VerifyEncryptedPacket(const uint8_t* data, size_t size) { // This is an encrypted data packet; let's make sure it's mostly random; // less than 10% of the bytes should be equal to the cleartext packet. if (size <= packet_size_) { @@ -225,7 +234,7 @@ class DtlsTestClient : public sigslot::has_slots<> { uint32_t packet_num = rtc::GetBE32(data + kPacketNumOffset); int num_matches = 0; for (size_t i = kPacketNumOffset; i < size; ++i) { - if (static_cast(data[i]) == (packet_num & 0xff)) { + if (data[i] == (packet_num & 0xff)) { ++num_matches; } } @@ -244,7 +253,8 @@ class DtlsTestClient : public sigslot::has_slots<> { const int64_t& /* packet_time_us */, int flags) { uint32_t packet_num = 0; - ASSERT_TRUE(VerifyPacket(data, size, &packet_num)); + ASSERT_TRUE(VerifyPacket(reinterpret_cast(data), size, + &packet_num)); received_.insert(packet_num); // Only DTLS-SRTP packets should have the bypass flag set. int expected_flags = @@ -261,15 +271,14 @@ class DtlsTestClient : public sigslot::has_slots<> { // Hook into the raw packet stream to make sure DTLS packets are encrypted. void OnFakeIceTransportReadPacket(rtc::PacketTransportInternal* transport, - const char* data, - size_t size, - const int64_t& /* packet_time_us */, - int flags) { - // Flags shouldn't be set on the underlying Transport packets. - ASSERT_EQ(0, flags); + const rtc::ReceivedPacket& packet) { + // Packets should not be decrypted on the underlying Transport packets. + ASSERT_EQ(packet.decryption_info(), rtc::ReceivedPacket::kNotDecrypted); // Look at the handshake packets to see what role we played. // Check that non-handshake packets are DTLS data or SRTP bypass. + const uint8_t* data = packet.payload().data(); + size_t size = packet.payload().size(); if (data[0] == 22 && size > 17) { if (data[13] == 1) { ++received_dtls_client_hellos_;