diff --git a/experiments/field_trials.py b/experiments/field_trials.py index 3c55ff193b..60ecc46455 100755 --- a/experiments/field_trials.py +++ b/experiments/field_trials.py @@ -191,6 +191,9 @@ ACTIVE_FIELD_TRIALS: FrozenSet[FieldTrial] = frozenset([ FieldTrial('WebRTC-Video-Vp9FlexibleMode', 329396373, date(2025, 6, 26)), + FieldTrial('WebRTC-IceHandshakeDtls', + 367395350, + date(2026, 1, 1)), # keep-sorted end ]) # yapf: disable diff --git a/p2p/BUILD.gn b/p2p/BUILD.gn index 7f481ce1cb..5eb374b7b0 100644 --- a/p2p/BUILD.gn +++ b/p2p/BUILD.gn @@ -118,6 +118,7 @@ rtc_library("rtc_p2p") { "../rtc_base:byte_order", "../rtc_base:callback_list", "../rtc_base:checks", + "../rtc_base:copy_on_write_buffer", "../rtc_base:crc32", "../rtc_base:dscp", "../rtc_base:event_tracer", @@ -331,6 +332,7 @@ rtc_library("connection") { "../rtc_base:byte_buffer", "../rtc_base:callback_list", "../rtc_base:checks", + "../rtc_base:copy_on_write_buffer", "../rtc_base:crc32", "../rtc_base:crypto_random", "../rtc_base:dscp", @@ -340,6 +342,7 @@ rtc_library("connection") { "../rtc_base:macromagic", "../rtc_base:mdns_responder_interface", "../rtc_base:net_helper", + "../rtc_base:net_helpers", "../rtc_base:network", "../rtc_base:network_constants", "../rtc_base:rate_tracker", @@ -358,6 +361,7 @@ rtc_library("connection") { "../rtc_base/third_party/base64", "../rtc_base/third_party/sigslot", "//third_party/abseil-cpp/absl/algorithm:container", + "//third_party/abseil-cpp/absl/base:core_headers", "//third_party/abseil-cpp/absl/functional:any_invocable", "//third_party/abseil-cpp/absl/strings:string_view", ] @@ -530,12 +534,14 @@ rtc_library("ice_transport_internal") { ":transport_description", "../api:array_view", "../api:candidate", + "../api:field_trials_view", "../api:rtc_error", "../api/transport:enums", "../rtc_base:callback_list", "../rtc_base:checks", "../rtc_base:network_constants", "../rtc_base:timeutils", + "../rtc_base/network:received_packet", "../rtc_base/system:rtc_export", "../rtc_base/third_party/sigslot", "//third_party/abseil-cpp/absl/functional:any_invocable", @@ -563,6 +569,8 @@ rtc_library("p2p_transport_channel") { ":candidate_pair_interface", ":connection", ":connection_info", + ":dtls_stun_piggyback_controller", + ":dtls_utils", ":ice_agent_interface", ":ice_controller_factory_interface", ":ice_controller_interface", @@ -591,6 +599,7 @@ rtc_library("p2p_transport_channel") { "../logging:ice_log", "../rtc_base:async_packet_socket", "../rtc_base:checks", + "../rtc_base:copy_on_write_buffer", "../rtc_base:dscp", "../rtc_base:event_tracer", "../rtc_base:ip_address", @@ -808,10 +817,12 @@ rtc_library("dtls_stun_piggyback_controller") { "../api:array_view", "../api:sequence_checker", "../api/transport:stun_types", + "../api/transport:stun_types", "../rtc_base:buffer", "../rtc_base:byte_buffer", "../rtc_base:checks", "../rtc_base:logging", + "../rtc_base:logging", "../rtc_base:macromagic", "../rtc_base:stringutils", "../rtc_base/system:no_unique_address", @@ -1165,6 +1176,7 @@ if (rtc_include_tests) { "base/turn_server_unittest.cc", "base/wrapping_active_ice_controller_unittest.cc", "client/basic_port_allocator_unittest.cc", + "dtls/dtls_ice_integrationtest.cc", "dtls/dtls_stun_piggyback_controller_unittest.cc", "dtls/dtls_transport_unittest.cc", "dtls/dtls_utils_unittest.cc", @@ -1176,14 +1188,19 @@ if (rtc_include_tests) { ":basic_ice_controller", ":basic_packet_socket_factory", ":basic_port_allocator", + ":candidate_pair_interface", ":connection", + ":connection_info", ":dtls_stun_piggyback_controller", ":dtls_transport", ":dtls_transport_internal", ":dtls_utils", ":fake_ice_transport", ":fake_port_allocator", + ":ice_controller_factory_interface", + ":ice_controller_interface", ":ice_credentials_iterator", + ":ice_switch_reason", ":ice_transport_internal", ":p2p_constants", ":p2p_server_utils", @@ -1205,9 +1222,11 @@ if (rtc_include_tests) { ":turn_port", ":wrapping_active_ice_controller", "../api:array_view", + "../api:async_dns_resolver", "../api:candidate", "../api:dtls_transport_interface", "../api:field_trials_view", + "../api:ice_transport_interface", "../api:libjingle_peerconnection_api", "../api:mock_async_dns_resolver", "../api:packet_socket_factory", @@ -1215,6 +1234,7 @@ if (rtc_include_tests) { "../api/crypto:options", "../api/task_queue", "../api/task_queue:pending_task_safety_flag", + "../api/transport:enums", "../api/transport:stun_types", "../api/units:time_delta", "../rtc_base:async_packet_socket", @@ -1236,11 +1256,13 @@ if (rtc_include_tests) { "../rtc_base:net_test_helpers", "../rtc_base:network", "../rtc_base:network_constants", + "../rtc_base:network_route", "../rtc_base:rtc_base_tests_utils", "../rtc_base:socket", "../rtc_base:socket_adapters", "../rtc_base:socket_address", "../rtc_base:socket_address_pair", + "../rtc_base:socket_server", "../rtc_base:ssl", "../rtc_base:ssl_adapter", "../rtc_base:stringutils", diff --git a/p2p/base/connection.cc b/p2p/base/connection.cc index b081d994ce..7a10154730 100644 --- a/p2p/base/connection.cc +++ b/p2p/base/connection.cc @@ -13,23 +13,43 @@ #include #include +#include #include #include #include +#include #include #include #include "absl/algorithm/container.h" +#include "absl/base/attributes.h" +#include "absl/functional/any_invocable.h" #include "absl/strings/string_view.h" #include "api/array_view.h" +#include "api/candidate.h" +#include "api/rtc_error.h" +#include "api/sequence_checker.h" +#include "api/task_queue/task_queue_base.h" +#include "api/transport/stun.h" #include "api/units/timestamp.h" +#include "logging/rtc_event_log/events/rtc_event_ice_candidate_pair.h" +#include "logging/rtc_event_log/events/rtc_event_ice_candidate_pair_config.h" +#include "logging/rtc_event_log/ice_logger.h" +#include "p2p/base/connection_info.h" #include "p2p/base/p2p_constants.h" +#include "p2p/base/p2p_transport_channel_ice_field_trials.h" +#include "p2p/base/port_interface.h" +#include "p2p/base/stun_request.h" +#include "p2p/base/transport_description.h" +#include "rtc_base/async_packet_socket.h" #include "rtc_base/byte_buffer.h" #include "rtc_base/checks.h" #include "rtc_base/crypto_random.h" #include "rtc_base/logging.h" #include "rtc_base/net_helper.h" +#include "rtc_base/net_helpers.h" #include "rtc_base/network.h" +#include "rtc_base/network/received_packet.h" #include "rtc_base/network/sent_packet.h" #include "rtc_base/network_constants.h" #include "rtc_base/numerics/safe_minmax.h" @@ -39,6 +59,7 @@ #include "rtc_base/string_utils.h" #include "rtc_base/strings/string_builder.h" #include "rtc_base/time_utils.h" +#include "rtc_base/weak_ptr.h" namespace cricket { namespace { @@ -557,6 +578,15 @@ void Connection::OnReadPacket(const rtc::ReceivedPacket& packet) { // This doesn't just check, it makes callbacks if transaction // id's match. case STUN_BINDING_RESPONSE: + if (dtls_stun_piggyback_consumer_) { + const StunByteStringAttribute* dtls_piggyback_attribute = + msg->GetByteString(STUN_ATTR_META_DTLS_IN_STUN); + const StunByteStringAttribute* dtls_piggyback_ack = + msg->GetByteString(STUN_ATTR_META_DTLS_IN_STUN_ACK); + dtls_stun_piggyback_consumer_(dtls_piggyback_attribute, + dtls_piggyback_ack); + } + ABSL_FALLTHROUGH_INTENDED; case STUN_BINDING_ERROR_RESPONSE: requests_.CheckResponse(msg.get()); break; @@ -581,6 +611,36 @@ void Connection::OnReadPacket(const rtc::ReceivedPacket& packet) { } } +void Connection::MaybeAddDtlsPiggybackingAttributes(StunMessage* msg) { + if (!(dtls_stun_piggyback_data_producer_ && + dtls_stun_piggyback_ack_producer_)) { + return; + } + std::optional dtls_piggyback_attr = + dtls_stun_piggyback_data_producer_(STUN_BINDING_RESPONSE); + std::optional dtls_piggyback_ack = + dtls_stun_piggyback_ack_producer_(STUN_BINDING_REQUEST); + + size_t need_length = + (dtls_piggyback_attr + ? dtls_piggyback_attr->length() + kStunAttributeHeaderSize + : 0) + + (dtls_piggyback_ack + ? dtls_piggyback_ack->length() + kStunAttributeHeaderSize + : 0); + if (msg->length() + need_length > kMaxStunBindingLength) { + return; + } + if (dtls_piggyback_attr) { + msg->AddAttribute(std::make_unique( + STUN_ATTR_META_DTLS_IN_STUN, *dtls_piggyback_attr)); + } + if (dtls_piggyback_ack) { + msg->AddAttribute(std::make_unique( + STUN_ATTR_META_DTLS_IN_STUN_ACK, *dtls_piggyback_ack)); + } +} + void Connection::HandleStunBindingOrGoogPingRequest(IceMessage* msg) { RTC_DCHECK_RUN_ON(network_thread_); // This connection should now be receiving. @@ -623,6 +683,14 @@ void Connection::HandleStunBindingOrGoogPingRequest(IceMessage* msg) { // This is a validated stun request from remote peer. if (msg->type() == STUN_BINDING_REQUEST) { + if (dtls_stun_piggyback_consumer_) { + const StunByteStringAttribute* dtls_piggyback_attribute = + msg->GetByteString(STUN_ATTR_META_DTLS_IN_STUN); + const StunByteStringAttribute* dtls_piggyback_ack = + msg->GetByteString(STUN_ATTR_META_DTLS_IN_STUN_ACK); + dtls_stun_piggyback_consumer_(dtls_piggyback_attribute, + dtls_piggyback_ack); + } SendStunBindingResponse(msg); } else { RTC_DCHECK(msg->type() == GOOG_PING_REQUEST); @@ -747,6 +815,8 @@ void Connection::SendStunBindingResponse(const StunMessage* message) { } } + MaybeAddDtlsPiggybackingAttributes(&response); + response.AddMessageIntegrity(local_candidate().password()); response.AddFingerprint(); @@ -1083,6 +1153,8 @@ std::unique_ptr Connection::BuildPingRequest( message->AddAttribute(std::move(delta)); } + MaybeAddDtlsPiggybackingAttributes(message.get()); + message->AddMessageIntegrity(remote_candidate_.password()); message->AddFingerprint(); @@ -1483,6 +1555,21 @@ void Connection::OnConnectionRequestResponse(StunRequest* request, } else if (delta_ack) { RTC_LOG(LS_ERROR) << "Discard GOOG_DELTA_ACK, no consumer"; } + + if (dtls_stun_piggyback_consumer_) { + const bool sent_dtls_piggyback = + request->msg()->GetByteString(STUN_ATTR_META_DTLS_IN_STUN) != nullptr; + const bool sent_dtls_piggyback_ack = + request->msg()->GetByteString(STUN_ATTR_META_DTLS_IN_STUN_ACK) != + nullptr; + const StunByteStringAttribute* dtls_piggyback_attr = + response->GetByteString(STUN_ATTR_META_DTLS_IN_STUN); + const StunByteStringAttribute* dtls_piggyback_ack = + response->GetByteString(STUN_ATTR_META_DTLS_IN_STUN_ACK); + if (sent_dtls_piggyback || sent_dtls_piggyback_ack) { + dtls_stun_piggyback_consumer_(dtls_piggyback_attr, dtls_piggyback_ack); + } + } } void Connection::OnConnectionRequestErrorResponse(ConnectionRequest* request, diff --git a/p2p/base/connection.h b/p2p/base/connection.h index 8079a56880..4f371a0973 100644 --- a/p2p/base/connection.h +++ b/p2p/base/connection.h @@ -18,7 +18,6 @@ #include #include #include -#include #include #include @@ -53,6 +52,8 @@ namespace cricket { // Version number for GOOG_PING, this is added to have the option of // adding other flavors in the future. constexpr int kGoogPingVersion = 1; +// 1200 is the "commonly used" MTU. Subtract M-I attribute (20+4) and FP (4+4). +constexpr int kMaxStunBindingLength = 1200 - 24 - 8; // Forward declaration so that a ConnectionRequest can contain a Connection. class Connection; @@ -359,6 +360,23 @@ class RTC_EXPORT Connection : public CandidatePairInterface { goog_delta_ack_consumer_ = std::nullopt; } + void RegisterDtlsPiggyback( + absl::AnyInvocable(StunMessageType)> + data_producer, + absl::AnyInvocable(StunMessageType)> + ack_producer, + absl::AnyInvocable consumer) { + dtls_stun_piggyback_data_producer_ = std::move(data_producer); + dtls_stun_piggyback_ack_producer_ = std::move(ack_producer); + dtls_stun_piggyback_consumer_ = std::move(consumer); + } + void DeregisterDtlsPiggyback() { + dtls_stun_piggyback_consumer_ = nullptr; + dtls_stun_piggyback_data_producer_ = nullptr; + dtls_stun_piggyback_ack_producer_ = nullptr; + } + protected: // A ConnectionRequest is a simple STUN ping used to determine writability. class ConnectionRequest; @@ -511,6 +529,15 @@ class RTC_EXPORT Connection : public CandidatePairInterface { goog_delta_ack_consumer_; absl::AnyInvocable received_packet_callback_; + + void MaybeAddDtlsPiggybackingAttributes(StunMessage* msg); + absl::AnyInvocable(StunMessageType)> + dtls_stun_piggyback_data_producer_ = nullptr; + absl::AnyInvocable(StunMessageType)> + dtls_stun_piggyback_ack_producer_ = nullptr; + absl::AnyInvocable + dtls_stun_piggyback_consumer_ = nullptr; }; // ProxyConnection defers all the interesting work to the port. diff --git a/p2p/base/ice_transport_internal.h b/p2p/base/ice_transport_internal.h index 851dbaf2f5..3a52073b76 100644 --- a/p2p/base/ice_transport_internal.h +++ b/p2p/base/ice_transport_internal.h @@ -23,6 +23,7 @@ #include "absl/strings/string_view.h" #include "api/array_view.h" #include "api/candidate.h" +#include "api/field_trials_view.h" #include "api/rtc_error.h" #include "api/transport/enums.h" #include "p2p/base/candidate_pair_interface.h" @@ -34,6 +35,7 @@ #include "p2p/base/transport_description.h" #include "rtc_base/callback_list.h" #include "rtc_base/checks.h" +#include "rtc_base/network/received_packet.h" #include "rtc_base/network_constants.h" #include "rtc_base/system/rtc_export.h" #include "rtc_base/third_party/sigslot/sigslot.h" @@ -201,6 +203,9 @@ struct RTC_EXPORT IceConfig { webrtc::VpnPreference vpn_preference = webrtc::VpnPreference::kDefault; + // Experimental feature to transport the DTLS handshake in STUN packets. + bool dtls_handshake_in_stun = false; + IceConfig(); IceConfig(int receiving_timeout_ms, int backup_connection_ping_interval, @@ -398,6 +403,15 @@ class RTC_EXPORT IceTransportInternal : public rtc::PacketTransportInternal { virtual const webrtc::FieldTrialsView* field_trials() const { return nullptr; } + void SetPiggybackDtlsDataCallback( + absl::AnyInvocable callback) { + RTC_DCHECK(callback == nullptr || !piggybacked_dtls_callback_); + piggybacked_dtls_callback_ = std::move(callback); + } + virtual void SetDtlsDataToPiggyback(rtc::ArrayView) {} + virtual void SetDtlsHandshakeComplete(bool is_dtls_client) {} + virtual bool IsDtlsPiggybackSupportedByPeer() { return false; } protected: void SendGatheringStateEvent() { gathering_state_callback_list_.Send(this); } @@ -419,6 +433,9 @@ class RTC_EXPORT IceTransportInternal : public rtc::PacketTransportInternal { absl::AnyInvocable candidate_pair_change_callback_; + absl::AnyInvocable + piggybacked_dtls_callback_; }; } // namespace cricket diff --git a/p2p/base/p2p_transport_channel.cc b/p2p/base/p2p_transport_channel.cc index 63d681e01a..c82e3f1a15 100644 --- a/p2p/base/p2p_transport_channel.cc +++ b/p2p/base/p2p_transport_channel.cc @@ -55,6 +55,7 @@ #include "p2p/base/regathering_controller.h" #include "p2p/base/transport_description.h" #include "p2p/base/wrapping_active_ice_controller.h" +#include "p2p/dtls/dtls_utils.h" #include "rtc_base/async_packet_socket.h" #include "rtc_base/checks.h" #include "rtc_base/dscp.h" @@ -200,7 +201,16 @@ P2PTransportChannel::P2PTransportChannel( true /* presume_writable_when_fully_relayed */, REGATHER_ON_FAILED_NETWORKS_INTERVAL, RECEIVING_SWITCHING_DELAY), - field_trials_(field_trials) { + field_trials_(field_trials), + dtls_stun_piggyback_controller_( + [this](rtc::ArrayView piggybacked_dtls_packet) { + if (piggybacked_dtls_callback_ == nullptr) { + return; + } + piggybacked_dtls_callback_( + this, rtc::ReceivedPacket(piggybacked_dtls_packet, + rtc::SocketAddress())); + }) { TRACE_EVENT0("webrtc", "P2PTransportChannel::P2PTransportChannel"); RTC_DCHECK(allocator_ != nullptr); // Validate IceConfig even for mostly built-in constant default values in case @@ -310,6 +320,22 @@ void P2PTransportChannel::AddConnection(Connection* connection) { [this](webrtc::RTCErrorOr delta_ack) { GoogDeltaAckReceived(std::move(delta_ack)); }); + if (config_.dtls_handshake_in_stun) { + connection->RegisterDtlsPiggyback( + [this](StunMessageType stun_message_type) { + return dtls_stun_piggyback_controller_.GetDataToPiggyback( + stun_message_type); + }, + [this](StunMessageType stun_message_type) { + return dtls_stun_piggyback_controller_.GetAckToPiggyback( + stun_message_type); + }, + [this](const StunByteStringAttribute* data, + const StunByteStringAttribute* ack) { + dtls_stun_piggyback_controller_.ReportDataPiggybacked(data, ack); + }); + } + LogCandidatePairConfig(connection, webrtc::IceCandidatePairConfigType::kAdded); @@ -695,6 +721,11 @@ void P2PTransportChannel::SetIceConfig(const IceConfig& config) { allocator_->SetVpnPreference(config_.vpn_preference); ice_controller_->SetIceConfig(config_); + if (config_.dtls_handshake_in_stun != config.dtls_handshake_in_stun) { + config_.dtls_handshake_in_stun = config.dtls_handshake_in_stun; + RTC_LOG(LS_INFO) << "Set DTLS handshake in STUN to " + << config.dtls_handshake_in_stun; + } RTC_DCHECK(ValidateIceConfig(config_).ok()); } @@ -1609,6 +1640,16 @@ int P2PTransportChannel::SendPacket(const char* data, error_ = ENOTCONN; return -1; } + /* + * When trying DTLS-STUN piggyback we need to drop handshake packets + * as we start fresh if this fails. + */ + if (config_.dtls_handshake_in_stun && IsDtlsPiggybackSupportedByPeer() && + IsDtlsHandshakePacket( + rtc::MakeArrayView(reinterpret_cast(data), len))) { + RTC_LOG(LS_INFO) << "Dropping DTLS handshake while attemping DTLS-in-STUN"; + return len; + } packets_sent_++; last_sent_packet_id_ = options.packet_id; @@ -2151,6 +2192,7 @@ void P2PTransportChannel::RemoveConnection(Connection* connection) { connection->DeregisterReceivedPacketCallback(); connections_.erase(it); connection->ClearStunDictConsumer(); + connection->DeregisterDtlsPiggyback(); ice_controller_->OnConnectionDestroyed(connection); } @@ -2272,6 +2314,12 @@ void P2PTransportChannel::SetWritable(bool writable) { SignalReadyToSend(this); } SignalWritableState(this); + + if (config_.dtls_handshake_in_stun && IsDtlsPiggybackSupportedByPeer()) { + // Need to STUN ping here to get the last bit of the DTLS handshake across + // as quickly as possible. + SendPingRequestInternal(selected_connection_); + } } void P2PTransportChannel::SetReceiving(bool receiving) { diff --git a/p2p/base/p2p_transport_channel.h b/p2p/base/p2p_transport_channel.h index 0bc3b84616..cef9714f71 100644 --- a/p2p/base/p2p_transport_channel.h +++ b/p2p/base/p2p_transport_channel.h @@ -58,6 +58,7 @@ #include "p2p/base/regathering_controller.h" #include "p2p/base/stun_dictionary.h" #include "p2p/base/transport_description.h" +#include "p2p/dtls/dtls_stun_piggyback_controller.h" #include "rtc_base/async_packet_socket.h" #include "rtc_base/checks.h" #include "rtc_base/dscp.h" @@ -251,6 +252,18 @@ class RTC_EXPORT P2PTransportChannel : public IceTransportInternal, const webrtc::FieldTrialsView* field_trials() const override { return field_trials_; } + void SetDtlsDataToPiggyback(rtc::ArrayView data) override { + dtls_stun_piggyback_controller_.SetDataToPiggyback(data); + } + void SetDtlsHandshakeComplete(bool is_dtls_client) override { + dtls_stun_piggyback_controller_.SetDtlsHandshakeComplete(is_dtls_client); + } + bool IsDtlsPiggybackSupportedByPeer() override { + RTC_DCHECK_RUN_ON(network_thread_); + return config_.dtls_handshake_in_stun && + dtls_stun_piggyback_controller_.state() != + DtlsStunPiggybackController::State::OFF; + } private: P2PTransportChannel( @@ -515,6 +528,9 @@ class RTC_EXPORT P2PTransportChannel : public IceTransportInternal, // A dictionary that tracks attributes from peer. StunDictionaryView stun_dict_view_; + + // A controller for piggybacking DTLS in STUN. + DtlsStunPiggybackController dtls_stun_piggyback_controller_; }; } // namespace cricket diff --git a/p2p/base/p2p_transport_channel_unittest.cc b/p2p/base/p2p_transport_channel_unittest.cc index 04c0641150..c23867ef92 100644 --- a/p2p/base/p2p_transport_channel_unittest.cc +++ b/p2p/base/p2p_transport_channel_unittest.cc @@ -10,29 +10,57 @@ #include "p2p/base/p2p_transport_channel.h" +#include +#include +#include #include +#include #include +#include #include -#include #include +#include #include "absl/algorithm/container.h" +#include "absl/functional/any_invocable.h" #include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "api/async_dns_resolver.h" +#include "api/candidate.h" +#include "api/field_trials_view.h" +#include "api/ice_transport_interface.h" +#include "api/packet_socket_factory.h" +#include "api/scoped_refptr.h" +#include "api/task_queue/pending_task_safety_flag.h" #include "api/test/mock_async_dns_resolver.h" -#include "p2p/base/active_ice_controller_factory_interface.h" -#include "p2p/base/active_ice_controller_interface.h" +#include "api/transport/enums.h" +#include "api/transport/stun.h" +#include "api/units/time_delta.h" #include "p2p/base/basic_ice_controller.h" +#include "p2p/base/basic_packet_socket_factory.h" +#include "p2p/base/candidate_pair_interface.h" #include "p2p/base/connection.h" +#include "p2p/base/connection_info.h" #include "p2p/base/fake_port_allocator.h" +#include "p2p/base/ice_controller_factory_interface.h" +#include "p2p/base/ice_controller_interface.h" +#include "p2p/base/ice_switch_reason.h" #include "p2p/base/ice_transport_internal.h" #include "p2p/base/mock_active_ice_controller.h" #include "p2p/base/mock_ice_controller.h" +#include "p2p/base/p2p_constants.h" #include "p2p/base/packet_transport_internal.h" +#include "p2p/base/port.h" +#include "p2p/base/port_allocator.h" +#include "p2p/base/port_interface.h" +#include "p2p/base/stun_dictionary.h" +#include "p2p/base/stun_server.h" #include "p2p/base/test_stun_server.h" #include "p2p/base/test_turn_server.h" +#include "p2p/base/transport_description.h" #include "p2p/client/basic_port_allocator.h" +#include "rtc_base/byte_buffer.h" #include "rtc_base/checks.h" -#include "rtc_base/crypto_random.h" #include "rtc_base/dscp.h" #include "rtc_base/fake_clock.h" #include "rtc_base/fake_mdns_responder.h" @@ -40,19 +68,28 @@ #include "rtc_base/firewall_socket_server.h" #include "rtc_base/gunit.h" #include "rtc_base/internal/default_socket_server.h" +#include "rtc_base/ip_address.h" #include "rtc_base/logging.h" #include "rtc_base/mdns_responder_interface.h" -#include "rtc_base/nat_server.h" #include "rtc_base/nat_socket_factory.h" +#include "rtc_base/nat_types.h" +#include "rtc_base/net_helper.h" +#include "rtc_base/net_helpers.h" +#include "rtc_base/network.h" #include "rtc_base/network/received_packet.h" -#include "rtc_base/proxy_server.h" +#include "rtc_base/network/sent_packet.h" +#include "rtc_base/network_constants.h" +#include "rtc_base/network_route.h" +#include "rtc_base/socket.h" #include "rtc_base/socket_address.h" -#include "rtc_base/ssl_adapter.h" -#include "rtc_base/strings/string_builder.h" +#include "rtc_base/socket_server.h" +#include "rtc_base/third_party/sigslot/sigslot.h" #include "rtc_base/thread.h" #include "rtc_base/time_utils.h" #include "rtc_base/virtual_socket_server.h" #include "system_wrappers/include/metrics.h" +#include "test/gmock.h" +#include "test/gtest.h" #include "test/scoped_key_value_config.h" namespace { @@ -6453,4 +6490,62 @@ TEST_F(P2PTransportChannelTest, TestIceNoOldCandidatesAfterIceRestart) { DestroyChannels(); } +class P2PTransportChannelTestDtlsInStun : public P2PTransportChannelTestBase { + public: + P2PTransportChannelTestDtlsInStun() : P2PTransportChannelTestBase() {} + + protected: + void Run(bool ep1_support, bool ep2_support) { + IceConfig ep1_config; + ep1_config.dtls_handshake_in_stun = ep1_support; + IceConfig ep2_config; + ep2_config.dtls_handshake_in_stun = ep2_support; + CreateChannels(ep1_config, ep2_config); + // DTLS server hello done message as test data. + std::vector dtls_data = { + 0x16, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x0c, 0x0e, 0x00, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + }; + if (ep1_support) { + ep1_ch1()->SetDtlsDataToPiggyback(dtls_data); + } + if (ep2_support) { + ep2_ch1()->SetDtlsDataToPiggyback(dtls_data); + } + EXPECT_TRUE_SIMULATED_WAIT(CheckConnected(ep1_ch1(), ep2_ch1()), + kDefaultTimeout, clock_); + } + + rtc::ScopedFakeClock clock_; +}; + +TEST_F(P2PTransportChannelTestDtlsInStun, NotSupportedByEither) { + Run(false, false); + EXPECT_FALSE(ep1_ch1()->IsDtlsPiggybackSupportedByPeer()); + EXPECT_FALSE(ep2_ch1()->IsDtlsPiggybackSupportedByPeer()); + DestroyChannels(); +} + +TEST_F(P2PTransportChannelTestDtlsInStun, SupportedByClient) { + Run(true, false); + EXPECT_FALSE(ep1_ch1()->IsDtlsPiggybackSupportedByPeer()); + EXPECT_FALSE(ep2_ch1()->IsDtlsPiggybackSupportedByPeer()); + DestroyChannels(); +} + +TEST_F(P2PTransportChannelTestDtlsInStun, SupportedByServer) { + Run(false, true); + EXPECT_FALSE(ep1_ch1()->IsDtlsPiggybackSupportedByPeer()); + EXPECT_FALSE(ep2_ch1()->IsDtlsPiggybackSupportedByPeer()); + DestroyChannels(); +} + +TEST_F(P2PTransportChannelTestDtlsInStun, SupportedByBoth) { + Run(true, true); + EXPECT_TRUE(ep1_ch1()->IsDtlsPiggybackSupportedByPeer()); + EXPECT_TRUE(ep2_ch1()->IsDtlsPiggybackSupportedByPeer()); + DestroyChannels(); +} + } // namespace cricket diff --git a/p2p/dtls/dtls_ice_integrationtest.cc b/p2p/dtls/dtls_ice_integrationtest.cc new file mode 100644 index 0000000000..c21bb3bc64 --- /dev/null +++ b/p2p/dtls/dtls_ice_integrationtest.cc @@ -0,0 +1,190 @@ +/* + * Copyright 2024 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include +#include +#include +#include +#include + +#include "api/candidate.h" +#include "api/crypto/crypto_options.h" +#include "api/scoped_refptr.h" +#include "p2p/base/basic_packet_socket_factory.h" +#include "p2p/base/ice_transport_internal.h" +#include "p2p/base/p2p_transport_channel.h" +#include "p2p/base/port_allocator.h" +#include "p2p/base/transport_description.h" +#include "p2p/client/basic_port_allocator.h" +#include "p2p/dtls/dtls_transport.h" +#include "rtc_base/fake_clock.h" +#include "rtc_base/fake_network.h" +#include "rtc_base/gunit.h" +#include "rtc_base/rtc_certificate.h" +#include "rtc_base/socket_address.h" +#include "rtc_base/ssl_fingerprint.h" +#include "rtc_base/ssl_identity.h" +#include "rtc_base/ssl_stream_adapter.h" +#include "rtc_base/third_party/sigslot/sigslot.h" +#include "rtc_base/thread.h" +#include "rtc_base/virtual_socket_server.h" +#include "test/gtest.h" + +namespace { +constexpr int kDefaultTimeout = 10000; + +void SetRemoteFingerprintFromCert( + cricket::DtlsTransport& transport, + const rtc::scoped_refptr& cert) { + std::unique_ptr fingerprint = + rtc::SSLFingerprint::CreateFromCertificate(*cert); + + transport.SetRemoteParameters( + fingerprint->algorithm, + reinterpret_cast(fingerprint->digest.data()), + fingerprint->digest.size(), std::nullopt); +} + +} // namespace + +namespace cricket { + +class DtlsIceIntegrationTest + : public ::testing::TestWithParam>, + public sigslot::has_slots<> { + public: + void CandidateC2S(IceTransportInternal*, const Candidate& c) { + thread_.PostTask([this, c = c]() { server_ice_->AddRemoteCandidate(c); }); + } + void CandidateS2C(IceTransportInternal*, const Candidate& c) { + thread_.PostTask([this, c = c]() { client_ice_->AddRemoteCandidate(c); }); + } + + protected: + DtlsIceIntegrationTest() + : ss_(std::make_unique()), + socket_factory_( + std::make_unique(ss_.get())), + thread_(ss_.get()), + client_allocator_( + std::make_unique(&network_manager_, + socket_factory_.get())), + server_allocator_( + std::make_unique(&network_manager_, + socket_factory_.get())), + client_ice_( + std::make_unique("client_transport", + 0, + client_allocator_.get())), + server_ice_( + std::make_unique("server_transport", + 0, + server_allocator_.get())), + client_dtls_(client_ice_.get(), + webrtc::CryptoOptions(), + /*event_log=*/nullptr, + rtc::SSL_PROTOCOL_DTLS_12), + server_dtls_(server_ice_.get(), + webrtc::CryptoOptions(), + /*event_log=*/nullptr, + rtc::SSL_PROTOCOL_DTLS_12), + client_ice_parameters_("c_ufrag", + "c_icepwd_something_something", + false), + server_ice_parameters_("s_ufrag", + "s_icepwd_something_something", + false), + client_dtls_stun_piggyback_(std::get<0>(GetParam())), + server_dtls_stun_piggyback_(std::get<1>(GetParam())) { + // Setup ICE. + client_ice_->SetIceParameters(client_ice_parameters_); + client_ice_->SetRemoteIceParameters(server_ice_parameters_); + client_ice_->SetIceRole(ICEROLE_CONTROLLING); + client_ice_->SignalCandidateGathered.connect( + this, &DtlsIceIntegrationTest::CandidateC2S); + server_ice_->SetIceParameters(server_ice_parameters_); + server_ice_->SetRemoteIceParameters(client_ice_parameters_); + server_ice_->SetIceRole(ICEROLE_CONTROLLED); + server_ice_->SignalCandidateGathered.connect( + this, &DtlsIceIntegrationTest::CandidateS2C); + + // Setup DTLS. + auto client_certificate = rtc::RTCCertificate::Create( + rtc::SSLIdentity::Create("test", rtc::KT_DEFAULT)); + client_dtls_.SetLocalCertificate(client_certificate); + client_dtls_.SetDtlsRole(rtc::SSL_SERVER); + auto server_certificate = rtc::RTCCertificate::Create( + rtc::SSLIdentity::Create("test", rtc::KT_DEFAULT)); + server_dtls_.SetLocalCertificate(server_certificate); + server_dtls_.SetDtlsRole(rtc::SSL_CLIENT); + + SetRemoteFingerprintFromCert(server_dtls_, client_certificate); + SetRemoteFingerprintFromCert(client_dtls_, server_certificate); + + // Setup the network. + network_manager_.AddInterface(rtc::SocketAddress("192.168.1.1", 0)); + client_allocator_->Initialize(); + server_allocator_->Initialize(); + } + + ~DtlsIceIntegrationTest() = default; + + rtc::FakeNetworkManager network_manager_; + std::unique_ptr ss_; + std::unique_ptr socket_factory_; + rtc::AutoSocketServerThread thread_; + + std::unique_ptr client_allocator_; + std::unique_ptr server_allocator_; + + std::unique_ptr client_ice_; + std::unique_ptr server_ice_; + + DtlsTransport client_dtls_; + DtlsTransport server_dtls_; + + IceParameters client_ice_parameters_; + IceParameters server_ice_parameters_; + + bool client_dtls_stun_piggyback_; + bool server_dtls_stun_piggyback_; + + rtc::ScopedFakeClock fake_clock_; +}; + +TEST_P(DtlsIceIntegrationTest, SmokeTest) { + cricket::IceConfig client_config; + client_config.dtls_handshake_in_stun = client_dtls_stun_piggyback_; + client_ice_->SetIceConfig(client_config); + + cricket::IceConfig server_config; + server_config.dtls_handshake_in_stun = server_dtls_stun_piggyback_; + server_ice_->SetIceConfig(server_config); + + client_ice_->MaybeStartGathering(); + server_ice_->MaybeStartGathering(); + + // Note: this only reaches the pending piggybacking state. + EXPECT_TRUE_SIMULATED_WAIT(client_dtls_.writable() && server_dtls_.writable(), + kDefaultTimeout, fake_clock_); + EXPECT_EQ(client_ice_->IsDtlsPiggybackSupportedByPeer(), + client_dtls_stun_piggyback_ && server_dtls_stun_piggyback_); + EXPECT_EQ(server_ice_->IsDtlsPiggybackSupportedByPeer(), + client_dtls_stun_piggyback_ && server_dtls_stun_piggyback_); +} + +INSTANTIATE_TEST_SUITE_P(DtlsStunPiggybackingIntegrationTest, + DtlsIceIntegrationTest, + ::testing::Values(std::make_pair(false, false), + std::make_pair(true, false), + std::make_pair(false, true), + std::make_pair(true, true))); + +} // namespace cricket diff --git a/p2p/dtls/dtls_transport.cc b/p2p/dtls/dtls_transport.cc index 26567fcb28..e1b5893f86 100644 --- a/p2p/dtls/dtls_transport.cc +++ b/p2p/dtls/dtls_transport.cc @@ -60,8 +60,10 @@ static const size_t kMaxPendingPackets = 2; // Minimum and maximum values for the initial DTLS handshake timeout. We'll pick // an initial timeout based on ICE RTT estimates, but clamp it to this range. -static const int kMinHandshakeTimeout = 50; -static const int kMaxHandshakeTimeout = 3000; +static const int kMinHandshakeTimeoutMs = 50; +static const int kMaxHandshakeTimeoutMs = 3000; +// This effectively disables the handshake timeout. +static const int kDisabledHandshakeTimeoutMs = 3600 * 1000 * 24; static bool IsRtpPacket(rtc::ArrayView payload) { const uint8_t* u = payload.data(); @@ -96,6 +98,13 @@ rtc::StreamResult StreamInterfaceChannel::Write( size_t& written, int& /* error */) { RTC_DCHECK_RUN_ON(&callback_sequence_); + + if (IsDtlsHandshakePacket(data) && + ice_transport_->IsDtlsPiggybackSupportedByPeer()) { + ice_transport_->SetDtlsDataToPiggyback(data); + // The ICE transport is responsible for dropping these packets. + } + // Always succeeds, since this is an unreliable transport anyway. // TODO(zhihuang): Should this block if ice_transport_'s temporarily // unwritable? @@ -150,6 +159,7 @@ DtlsTransport::DtlsTransport(IceTransportInternal* ice_transport, DtlsTransport::~DtlsTransport() { if (ice_transport_) { + ice_transport_->SetPiggybackDtlsDataCallback(nullptr); ice_transport_->DeregisterReceivedPacketCallback(this); } } @@ -531,6 +541,20 @@ void DtlsTransport::ConnectToIceTransport() { this, &DtlsTransport::OnReceivingState); ice_transport_->SignalNetworkRouteChanged.connect( this, &DtlsTransport::OnNetworkRouteChanged); + ice_transport_->SetPiggybackDtlsDataCallback( + [this](rtc::PacketTransportInternal* transport, + const rtc::ReceivedPacket& packet) { + RTC_DCHECK(dtls_active_); + RTC_DCHECK(IsDtlsHandshakePacket(packet.payload())); + if (!dtls_active_) { + // Not doing DTLS. + return; + } + if (!IsDtlsHandshakePacket(packet.payload())) { + return; + } + OnReadPacket(transport, packet); + }); } // The state transition logic here is as follows: @@ -557,11 +581,37 @@ void DtlsTransport::OnWritableState(rtc::PacketTransportInternal* transport) { return; } + // The opportunistic attempt to do DTLS piggybacking failed. + // Recreate the DTLS session. Note: this assumes we can consider + // the previous DTLS session state beyond repair and no packet + // reached the peer. + if (dtls_ && !was_ever_connected_ && + !ice_transport_->IsDtlsPiggybackSupportedByPeer() && + (dtls_state() == webrtc::DtlsTransportState::kConnecting || + dtls_state() == webrtc::DtlsTransportState::kNew)) { + RTC_LOG(LS_ERROR) << "DTLS piggybacking not supported, restarting..."; + ice_transport_->SetPiggybackDtlsDataCallback(nullptr); + + dtls_.reset(nullptr); + set_dtls_state(webrtc::DtlsTransportState::kNew); + set_writable(false); + + if (!SetupDtls()) { + RTC_LOG(LS_ERROR) + << "Failed to setup DTLS again after attempted piggybacking."; + set_dtls_state(webrtc::DtlsTransportState::kFailed); + return; + } + // SetupDtls has called MaybeStartDtls() already. + return; + } + switch (dtls_state()) { case webrtc::DtlsTransportState::kNew: MaybeStartDtls(); break; case webrtc::DtlsTransportState::kConnected: + was_ever_connected_ = true; // Note: SignalWritableState fired by set_writable. set_writable(ice_transport_->writable()); break; @@ -705,6 +755,7 @@ void DtlsTransport::OnDtlsEvent(int sig, int err) { // sure we don't accidentally frob the state if it's closed. set_dtls_state(webrtc::DtlsTransportState::kConnected); set_writable(true); + ice_transport_->SetDtlsHandshakeComplete(dtls_role_ == rtc::SSL_CLIENT); } } if (sig & rtc::SE_READ) { @@ -762,8 +813,13 @@ void DtlsTransport::OnNetworkRouteChanged( } void DtlsTransport::MaybeStartDtls() { - if (dtls_ && ice_transport_->writable()) { - ConfigureHandshakeTimeout(); + RTC_DCHECK(ice_transport_); + // When adding the DTLS handshake in STUN we want to call StartSSL even + // before the ICE transport is ready. + bool start_early_for_dtls_in_stun = + ice_transport_->config().dtls_handshake_in_stun; + if (dtls_ && (ice_transport_->writable() || start_early_for_dtls_in_stun)) { + ConfigureHandshakeTimeout(start_early_for_dtls_in_stun); if (dtls_->StartSSL()) { // This should never fail: @@ -851,18 +907,26 @@ void DtlsTransport::OnDtlsHandshakeError(rtc::SSLHandshakeError error) { SendDtlsHandshakeError(error); } -void DtlsTransport::ConfigureHandshakeTimeout() { +void DtlsTransport::ConfigureHandshakeTimeout(bool uses_dtls_in_stun) { RTC_DCHECK(dtls_); - std::optional rtt = ice_transport_->GetRttEstimate(); - if (rtt) { + std::optional rtt_ms = ice_transport_->GetRttEstimate(); + if (uses_dtls_in_stun) { + // Configure a very high timeout to effectively disable the DTLS timeout + // and avoid fragmented resends. This is ok since DTLS-in-STUN caches + // the handshake pacets and resends them using the pacing of ICE. + RTC_LOG(LS_INFO) << ToString() << ": configuring DTLS handshake timeout " + << kDisabledHandshakeTimeoutMs << "ms for DTLS-in-STUN"; + dtls_->SetInitialRetransmissionTimeout(kDisabledHandshakeTimeoutMs); + } else if (rtt_ms) { // Limit the timeout to a reasonable range in case the ICE RTT takes // extreme values. - int initial_timeout = std::max(kMinHandshakeTimeout, - std::min(kMaxHandshakeTimeout, 2 * (*rtt))); + int initial_timeout_ms = + std::max(kMinHandshakeTimeoutMs, + std::min(kMaxHandshakeTimeoutMs, 2 * (*rtt_ms))); RTC_LOG(LS_INFO) << ToString() << ": configuring DTLS handshake timeout " - << initial_timeout << " based on ICE RTT " << *rtt; + << initial_timeout_ms << "ms based on ICE RTT " << *rtt_ms; - dtls_->SetInitialRetransmissionTimeout(initial_timeout); + dtls_->SetInitialRetransmissionTimeout(initial_timeout_ms); } else { RTC_LOG(LS_INFO) << ToString() diff --git a/p2p/dtls/dtls_transport.h b/p2p/dtls/dtls_transport.h index 143f02458a..a2ee3dbf70 100644 --- a/p2p/dtls/dtls_transport.h +++ b/p2p/dtls/dtls_transport.h @@ -237,7 +237,7 @@ class DtlsTransport : public DtlsTransportInternal { void MaybeStartDtls(); bool HandleDtlsPacket(rtc::ArrayView payload); void OnDtlsHandshakeError(rtc::SSLHandshakeError error); - void ConfigureHandshakeTimeout(); + void ConfigureHandshakeTimeout(bool uses_dtls_in_stun); void set_receiving(bool receiving); void set_writable(bool writable); @@ -269,6 +269,8 @@ class DtlsTransport : public DtlsTransportInternal { bool receiving_ = false; bool writable_ = false; + bool was_ever_connected_ = false; + webrtc::RtcEventLog* const event_log_; }; diff --git a/p2p/dtls/dtls_transport_unittest.cc b/p2p/dtls/dtls_transport_unittest.cc index cc94f1881b..2eed02ee44 100644 --- a/p2p/dtls/dtls_transport_unittest.cc +++ b/p2p/dtls/dtls_transport_unittest.cc @@ -21,6 +21,8 @@ #include #include +#include "absl/functional/any_invocable.h" +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "api/array_view.h" #include "api/crypto/crypto_options.h" @@ -34,6 +36,8 @@ #include "p2p/dtls/dtls_utils.h" #include "rtc_base/buffer.h" #include "rtc_base/byte_order.h" +#include "rtc_base/checks.h" +#include "rtc_base/copy_on_write_buffer.h" #include "rtc_base/fake_clock.h" #include "rtc_base/gunit.h" #include "rtc_base/logging.h" @@ -44,7 +48,6 @@ #include "rtc_base/ssl_stream_adapter.h" #include "rtc_base/third_party/sigslot/sigslot.h" #include "rtc_base/thread.h" -#include "test/field_trial.h" #include "test/gtest.h" #define MAYBE_SKIP_TEST(feature) \ diff --git a/pc/peer_connection.cc b/pc/peer_connection.cc index cc27509ce5..5f2d7bdea8 100644 --- a/pc/peer_connection.cc +++ b/pc/peer_connection.cc @@ -301,6 +301,8 @@ cricket::IceConfig ParseIceConfig( ice_config.network_preference = config.network_preference; ice_config.stable_writable_connection_ping_interval = config.stable_writable_connection_ping_interval_ms; + ice_config.dtls_handshake_in_stun = + false; // Filled in later based on field trial. return ice_config; } @@ -916,7 +918,11 @@ JsepTransportController* PeerConnection::InitializeTransportController_n( })); }); - transport_controller_->SetIceConfig(ParseIceConfig(configuration)); + auto ice_config = ParseIceConfig(configuration); + ice_config.dtls_handshake_in_stun = + CanAttemptDtlsStunPiggybacking(configuration); + + transport_controller_->SetIceConfig(ice_config); return transport_controller_.get(); } @@ -1644,6 +1650,8 @@ RTCError PeerConnection::SetConfiguration( modified_config.GetTurnPortPrunePolicy() != configuration_.GetTurnPortPrunePolicy(); cricket::IceConfig ice_config = ParseIceConfig(modified_config); + ice_config.dtls_handshake_in_stun = + CanAttemptDtlsStunPiggybacking(modified_config); // Apply part of the configuration on the network thread. In theory this // shouldn't fail. @@ -3122,4 +3130,14 @@ PeerConnection::InitializeUnDemuxablePacketHandler() { }; } +bool PeerConnection::CanAttemptDtlsStunPiggybacking( + const RTCConfiguration& configuration) { + // Enable DTLS-in-STUN only if no certificates were passed those + // may be RSA certificates and this feature only works with small + // ECDSA certificates. Determining the type of the key is + // not trivially possible at this point. + return dtls_enabled_ && configuration.certificates.empty() && + env_.field_trials().IsEnabled("WebRTC-IceHandshakeDtls"); +} + } // namespace webrtc diff --git a/pc/peer_connection.h b/pc/peer_connection.h index bf5b7857f4..f091ae0474 100644 --- a/pc/peer_connection.h +++ b/pc/peer_connection.h @@ -722,6 +722,8 @@ class PeerConnection : public PeerConnectionInternal, PayloadTypePicker payload_type_picker_; // This variable needs to be the last one in the class. rtc::WeakPtrFactory weak_factory_; + + bool CanAttemptDtlsStunPiggybacking(const RTCConfiguration& configuration); }; } // namespace webrtc