From 1dca9d513ada149c444a7316e16e91d53635ab5b Mon Sep 17 00:00:00 2001 From: Diogo Real Date: Tue, 29 Aug 2017 12:18:32 -0700 Subject: [PATCH] Support a user-provided string for the TLS ALPN extension. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fix source formatting Add TLS ALPN extension. Bug: webrtc:8086 Change-Id: I1f28ccd78760d3415e465f734744d2c2f93845e2 Reviewed-on: https://chromium-review.googlesource.com/611150 Reviewed-by: Sami Kalliomäki Reviewed-by: Magnus Jedvert Reviewed-by: Justin Uberti Reviewed-by: Taylor Brandstetter Reviewed-by: Peter Thatcher Commit-Queue: Diogo Real Cr-Commit-Position: refs/heads/master@{#19588} --- webrtc/api/peerconnectioninterface.h | 5 +- webrtc/p2p/base/basicpacketsocketfactory.cc | 19 ++++--- webrtc/p2p/base/basicpacketsocketfactory.h | 13 ++++- webrtc/p2p/base/packetsocketfactory.h | 22 ++++++- webrtc/p2p/base/port_unittest.cc | 8 +-- webrtc/p2p/base/portallocator.h | 1 + webrtc/p2p/base/testturnserver.h | 2 +- webrtc/p2p/base/turnport.cc | 9 ++- webrtc/p2p/base/turnport.h | 13 ++++- webrtc/p2p/base/turnport_unittest.cc | 6 +- webrtc/p2p/client/basicportallocator.cc | 3 +- webrtc/pc/iceserverparsing.cc | 2 + webrtc/rtc_base/BUILD.gn | 9 +++ webrtc/rtc_base/openssladapter.cc | 48 ++++++++++++++-- webrtc/rtc_base/openssladapter.h | 10 ++++ webrtc/rtc_base/openssladapter_unittest.cc | 41 +++++++++++++ webrtc/rtc_base/ssladapter.h | 8 +-- webrtc/rtc_base/ssladapter_unittest.cc | 18 +++++- .../api/org/webrtc/PeerConnection.java | 57 ++++++++++++++++++- webrtc/sdk/android/src/jni/jni_helpers.cc | 12 ++++ webrtc/sdk/android/src/jni/jni_helpers.h | 5 ++ .../src/jni/pc/java_native_conversion.cc | 5 ++ .../Classes/PeerConnection/RTCIceServer.mm | 33 ++++++++++- .../Framework/Headers/WebRTC/RTCIceServer.h | 17 +++++- .../Framework/UnitTests/RTCIceServerTest.mm | 19 +++++++ 25 files changed, 344 insertions(+), 41 deletions(-) create mode 100644 webrtc/rtc_base/openssladapter_unittest.cc diff --git a/webrtc/api/peerconnectioninterface.h b/webrtc/api/peerconnectioninterface.h index 632fc3ca9f..6a2b904453 100644 --- a/webrtc/api/peerconnectioninterface.h +++ b/webrtc/api/peerconnectioninterface.h @@ -193,11 +193,14 @@ class PeerConnectionInterface : public rtc::RefCountInterface { // extension). If |urls| itself contains the hostname, this isn't // necessary. std::string hostname; + // List of protocols to be used in the TLS ALPN extension. + std::vector tls_alpn_protocols; bool operator==(const IceServer& o) const { return uri == o.uri && urls == o.urls && username == o.username && password == o.password && tls_cert_policy == o.tls_cert_policy && - hostname == o.hostname; + hostname == o.hostname && + tls_alpn_protocols == o.tls_alpn_protocols; } bool operator!=(const IceServer& o) const { return !(*this == o); } }; diff --git a/webrtc/p2p/base/basicpacketsocketfactory.cc b/webrtc/p2p/base/basicpacketsocketfactory.cc index e911f2fe22..fe9eb4bd93 100644 --- a/webrtc/p2p/base/basicpacketsocketfactory.cc +++ b/webrtc/p2p/base/basicpacketsocketfactory.cc @@ -105,8 +105,11 @@ AsyncPacketSocket* BasicPacketSocketFactory::CreateServerTcpSocket( } AsyncPacketSocket* BasicPacketSocketFactory::CreateClientTcpSocket( - const SocketAddress& local_address, const SocketAddress& remote_address, - const ProxyInfo& proxy_info, const std::string& user_agent, int opts) { + const SocketAddress& local_address, + const SocketAddress& remote_address, + const ProxyInfo& proxy_info, + const std::string& user_agent, + const PacketSocketTcpOptions& tcp_options) { AsyncSocket* socket = socket_factory()->CreateAsyncSocket(local_address.family(), SOCK_STREAM); if (!socket) { @@ -138,9 +141,9 @@ AsyncPacketSocket* BasicPacketSocketFactory::CreateClientTcpSocket( } // Assert that at most one TLS option is used. - int tlsOpts = - opts & (PacketSocketFactory::OPT_TLS | PacketSocketFactory::OPT_TLS_FAKE | - PacketSocketFactory::OPT_TLS_INSECURE); + int tlsOpts = tcp_options.opts & (PacketSocketFactory::OPT_TLS | + PacketSocketFactory::OPT_TLS_FAKE | + PacketSocketFactory::OPT_TLS_INSECURE); RTC_DCHECK((tlsOpts & (tlsOpts - 1)) == 0); if ((tlsOpts & PacketSocketFactory::OPT_TLS) || @@ -152,9 +155,11 @@ AsyncPacketSocket* BasicPacketSocketFactory::CreateClientTcpSocket( } if (tlsOpts & PacketSocketFactory::OPT_TLS_INSECURE) { - ssl_adapter->set_ignore_bad_cert(true); + ssl_adapter->SetIgnoreBadCert(true); } + ssl_adapter->SetAlpnProtocols(tcp_options.tls_alpn_protocols); + socket = ssl_adapter; if (ssl_adapter->StartSSL(remote_address.hostname().c_str(), false) != 0) { @@ -176,7 +181,7 @@ AsyncPacketSocket* BasicPacketSocketFactory::CreateClientTcpSocket( // Finally, wrap that socket in a TCP or STUN TCP packet socket. AsyncPacketSocket* tcp_socket; - if (opts & PacketSocketFactory::OPT_STUN) { + if (tcp_options.opts & PacketSocketFactory::OPT_STUN) { tcp_socket = new cricket::AsyncStunTCPSocket(socket, false); } else { tcp_socket = new AsyncTCPSocket(socket, false); diff --git a/webrtc/p2p/base/basicpacketsocketfactory.h b/webrtc/p2p/base/basicpacketsocketfactory.h index 5046e0f518..3cb3304adf 100644 --- a/webrtc/p2p/base/basicpacketsocketfactory.h +++ b/webrtc/p2p/base/basicpacketsocketfactory.h @@ -37,7 +37,18 @@ class BasicPacketSocketFactory : public PacketSocketFactory { const SocketAddress& remote_address, const ProxyInfo& proxy_info, const std::string& user_agent, - int opts) override; + int opts) override { + PacketSocketTcpOptions tcp_options; + tcp_options.opts = opts; + return CreateClientTcpSocket(local_address, remote_address, proxy_info, + user_agent, tcp_options); + } + AsyncPacketSocket* CreateClientTcpSocket( + const SocketAddress& local_address, + const SocketAddress& remote_address, + const ProxyInfo& proxy_info, + const std::string& user_agent, + const PacketSocketTcpOptions& tcp_options) override; AsyncResolverInterface* CreateAsyncResolver() override; diff --git a/webrtc/p2p/base/packetsocketfactory.h b/webrtc/p2p/base/packetsocketfactory.h index 60f0ae0580..34f568c708 100644 --- a/webrtc/p2p/base/packetsocketfactory.h +++ b/webrtc/p2p/base/packetsocketfactory.h @@ -16,6 +16,12 @@ namespace rtc { +// This structure contains options required to create TCP packet sockets. +struct PacketSocketTcpOptions { + int opts; + std::vector tls_alpn_protocols; +}; + class AsyncPacketSocket; class AsyncResolverInterface; @@ -45,7 +51,7 @@ class PacketSocketFactory { uint16_t max_port, int opts) = 0; - // TODO: |proxy_info| and |user_agent| should be set + // TODO(deadbeef): |proxy_info| and |user_agent| should be set // per-factory and not when socket is created. virtual AsyncPacketSocket* CreateClientTcpSocket( const SocketAddress& local_address, @@ -54,6 +60,20 @@ class PacketSocketFactory { const std::string& user_agent, int opts) = 0; + // TODO(deadbeef): |proxy_info|, |user_agent| and |tcp_options| should + // be set per-factory and not when socket is created. + // TODO(deadbeef): Implement this method in all subclasses (namely those in + // Chromium), make pure virtual, and remove the old CreateClientTcpSocket. + virtual AsyncPacketSocket* CreateClientTcpSocket( + const SocketAddress& local_address, + const SocketAddress& remote_address, + const ProxyInfo& proxy_info, + const std::string& user_agent, + const PacketSocketTcpOptions& tcp_options) { + return CreateClientTcpSocket(local_address, remote_address, proxy_info, + user_agent, tcp_options.opts); + } + virtual AsyncResolverInterface* CreateAsyncResolver() = 0; private: diff --git a/webrtc/p2p/base/port_unittest.cc b/webrtc/p2p/base/port_unittest.cc index 757c1b31bf..0a5360991b 100644 --- a/webrtc/p2p/base/port_unittest.cc +++ b/webrtc/p2p/base/port_unittest.cc @@ -533,10 +533,10 @@ class PortTest : public testing::Test, public sigslot::has_slots<> { PacketSocketFactory* socket_factory, ProtocolType int_proto, ProtocolType ext_proto, const rtc::SocketAddress& server_addr) { - return TurnPort::Create(&main_, socket_factory, MakeNetwork(addr), 0, 0, - username_, password_, - ProtocolAddress(server_addr, int_proto), - kRelayCredentials, 0, std::string()); + return TurnPort::Create( + &main_, socket_factory, MakeNetwork(addr), 0, 0, username_, password_, + ProtocolAddress(server_addr, int_proto), kRelayCredentials, 0, + std::string(), std::vector()); } RelayPort* CreateGturnPort(const SocketAddress& addr, ProtocolType int_proto, ProtocolType ext_proto) { diff --git a/webrtc/p2p/base/portallocator.h b/webrtc/p2p/base/portallocator.h index 45a941a9b7..eef4305591 100644 --- a/webrtc/p2p/base/portallocator.h +++ b/webrtc/p2p/base/portallocator.h @@ -191,6 +191,7 @@ struct RelayServerConfig { RelayCredentials credentials; int priority = 0; TlsCertPolicy tls_cert_policy = TlsCertPolicy::TLS_CERT_POLICY_SECURE; + std::vector tls_alpn_protocols; }; class PortAllocatorSession : public sigslot::has_slots<> { diff --git a/webrtc/p2p/base/testturnserver.h b/webrtc/p2p/base/testturnserver.h index 80f259fdc9..4333be0965 100644 --- a/webrtc/p2p/base/testturnserver.h +++ b/webrtc/p2p/base/testturnserver.h @@ -97,7 +97,7 @@ class TestTurnServer : public TurnAuthInterface { adapter->SetRole(rtc::SSL_SERVER); adapter->SetIdentity( rtc::SSLIdentity::Generate("test turn server", rtc::KeyParams())); - adapter->set_ignore_bad_cert(true); + adapter->SetIgnoreBadCert(true); socket = adapter; } socket->Bind(int_addr); diff --git a/webrtc/p2p/base/turnport.cc b/webrtc/p2p/base/turnport.cc index 4dfe06d13a..68dae82583 100644 --- a/webrtc/p2p/base/turnport.cc +++ b/webrtc/p2p/base/turnport.cc @@ -221,7 +221,8 @@ TurnPort::TurnPort(rtc::Thread* thread, const ProtocolAddress& server_address, const RelayCredentials& credentials, int server_priority, - const std::string& origin) + const std::string& origin, + const std::vector& tls_alpn_protocols) : Port(thread, RELAY_PORT_TYPE, factory, @@ -231,6 +232,7 @@ TurnPort::TurnPort(rtc::Thread* thread, username, password), server_address_(server_address), + tls_alpn_protocols_(tls_alpn_protocols), credentials_(credentials), socket_(NULL), resolver_(NULL), @@ -336,9 +338,12 @@ bool TurnPort::CreateTurnClientSocket() { } } + rtc::PacketSocketTcpOptions tcp_options; + tcp_options.opts = opts; + tcp_options.tls_alpn_protocols = tls_alpn_protocols_; socket_ = socket_factory()->CreateClientTcpSocket( rtc::SocketAddress(Network()->GetBestIP(), 0), server_address_.address, - proxy(), user_agent(), opts); + proxy(), user_agent(), tcp_options); } if (!socket_) { diff --git a/webrtc/p2p/base/turnport.h b/webrtc/p2p/base/turnport.h index abdaa3dcb4..46200a988c 100644 --- a/webrtc/p2p/base/turnport.h +++ b/webrtc/p2p/base/turnport.h @@ -69,10 +69,11 @@ class TurnPort : public Port { const ProtocolAddress& server_address, const RelayCredentials& credentials, int server_priority, - const std::string& origin) { + const std::string& origin, + const std::vector& tls_alpn_protocols) { return new TurnPort(thread, factory, network, min_port, max_port, username, password, server_address, credentials, server_priority, - origin); + origin, tls_alpn_protocols); } virtual ~TurnPort(); @@ -95,6 +96,10 @@ class TurnPort : public Port { tls_cert_policy_ = tls_cert_policy; } + virtual std::vector GetTlsAlpnProtocols() const { + return tls_alpn_protocols_; + } + virtual void PrepareAddress(); virtual Connection* CreateConnection( const Candidate& c, PortInterface::CandidateOrigin origin); @@ -186,7 +191,8 @@ class TurnPort : public Port { const ProtocolAddress& server_address, const RelayCredentials& credentials, int server_priority, - const std::string& origin); + const std::string& origin, + const std::vector& alpn_protocols); private: enum { @@ -266,6 +272,7 @@ class TurnPort : public Port { ProtocolAddress server_address_; TlsCertPolicy tls_cert_policy_ = TlsCertPolicy::TLS_CERT_POLICY_SECURE; + std::vector tls_alpn_protocols_; RelayCredentials credentials_; AttemptedServerSet attempted_server_addresses_; diff --git a/webrtc/p2p/base/turnport_unittest.cc b/webrtc/p2p/base/turnport_unittest.cc index 7c6f72bda7..546c21e782 100644 --- a/webrtc/p2p/base/turnport_unittest.cc +++ b/webrtc/p2p/base/turnport_unittest.cc @@ -261,9 +261,9 @@ class TurnPortTest : public testing::Test, const ProtocolAddress& server_address, const std::string& origin) { RelayCredentials credentials(username, password); - turn_port_.reset(TurnPort::Create(&main_, &socket_factory_, network, 0, 0, - kIceUfrag1, kIcePwd1, server_address, - credentials, 0, origin)); + turn_port_.reset(TurnPort::Create( + &main_, &socket_factory_, network, 0, 0, kIceUfrag1, kIcePwd1, + server_address, credentials, 0, origin, std::vector())); // This TURN port will be the controlling. turn_port_->SetIceRole(ICEROLE_CONTROLLING); ConnectSignals(); diff --git a/webrtc/p2p/client/basicportallocator.cc b/webrtc/p2p/client/basicportallocator.cc index 66ee9a4ecf..fe2dfb6d1e 100644 --- a/webrtc/p2p/client/basicportallocator.cc +++ b/webrtc/p2p/client/basicportallocator.cc @@ -1444,7 +1444,8 @@ void AllocationSequence::CreateTurnPort(const RelayServerConfig& config) { session_->network_thread(), session_->socket_factory(), network_, session_->allocator()->min_port(), session_->allocator()->max_port(), session_->username(), session_->password(), *relay_port, - config.credentials, config.priority, session_->allocator()->origin()); + config.credentials, config.priority, session_->allocator()->origin(), + config.tls_alpn_protocols); } RTC_DCHECK(port != NULL); port->SetTlsCertPolicy(config.tls_cert_policy); diff --git a/webrtc/pc/iceserverparsing.cc b/webrtc/pc/iceserverparsing.cc index d9f4885c2b..00e9f2590f 100644 --- a/webrtc/pc/iceserverparsing.cc +++ b/webrtc/pc/iceserverparsing.cc @@ -257,6 +257,8 @@ static RTCErrorType ParseIceServerUrl( config.tls_cert_policy = cricket::TlsCertPolicy::TLS_CERT_POLICY_INSECURE_NO_CHECK; } + config.tls_alpn_protocols = server.tls_alpn_protocols; + turn_servers->push_back(config); break; } diff --git a/webrtc/rtc_base/BUILD.gn b/webrtc/rtc_base/BUILD.gn index 01a5679b35..03e8ea608a 100644 --- a/webrtc/rtc_base/BUILD.gn +++ b/webrtc/rtc_base/BUILD.gn @@ -1038,6 +1038,7 @@ if (rtc_include_tests) { } if (is_posix) { sources += [ + "openssladapter_unittest.cc", "ssladapter_unittest.cc", "sslidentity_unittest.cc", "sslstreamadapter_unittest.cc", @@ -1056,6 +1057,14 @@ if (rtc_include_tests) { # Suppress warnings from the Chromium Clang plugin (bugs.webrtc.org/163). suppressed_configs += [ "//build/config/clang:find_bad_constructs" ] } + if (build_with_chromium) { + include_dirs = [ "../../boringssl/src/include" ] + } + if (rtc_build_ssl) { + deps += [ "//third_party/boringssl" ] + } else { + configs += [ ":external_ssl_library" ] + } } } diff --git a/webrtc/rtc_base/openssladapter.cc b/webrtc/rtc_base/openssladapter.cc index 64eb0ab77e..9164258692 100644 --- a/webrtc/rtc_base/openssladapter.cc +++ b/webrtc/rtc_base/openssladapter.cc @@ -286,6 +286,7 @@ OpenSSLAdapter::OpenSSLAdapter(AsyncSocket* socket, ssl_(nullptr), ssl_ctx_(nullptr), ssl_mode_(SSL_MODE_TLS), + ignore_bad_cert_(false), custom_verification_succeeded_(false) { // If a factory is used, take a reference on the factory's SSL_CTX. // Otherwise, we'll create our own later. @@ -302,6 +303,14 @@ OpenSSLAdapter::~OpenSSLAdapter() { Cleanup(); } +void OpenSSLAdapter::SetIgnoreBadCert(bool ignore) { + ignore_bad_cert_ = ignore; +} + +void OpenSSLAdapter::SetAlpnProtocols(const std::vector& protos) { + alpn_protocols_ = protos; +} + void OpenSSLAdapter::SetMode(SSLMode mode) { RTC_DCHECK(!ssl_ctx_); RTC_DCHECK(state_ == SSL_NONE); @@ -327,7 +336,7 @@ AsyncSocket* OpenSSLAdapter::Accept(SocketAddress* paddr) { SSLAdapter* adapter = SSLAdapter::Create(socket); adapter->SetIdentity(identity_->GetReference()); adapter->SetRole(rtc::SSL_SERVER); - adapter->set_ignore_bad_cert(ignore_bad_cert()); + adapter->SetIgnoreBadCert(ignore_bad_cert_); adapter->StartSSL("", false); return adapter; } @@ -424,10 +433,18 @@ int OpenSSLAdapter::BeginSSL() { } // Set a couple common TLS extensions; even though we don't use them yet. - // TODO(emadomara) Add ALPN extension. SSL_enable_ocsp_stapling(ssl_); SSL_enable_signed_cert_timestamps(ssl_); + if (!alpn_protocols_.empty()) { + std::string tls_alpn_string = TransformAlpnProtocols(alpn_protocols_); + if (!tls_alpn_string.empty()) { + SSL_set_alpn_protos( + ssl_, reinterpret_cast(tls_alpn_string.data()), + tls_alpn_string.size()); + } + } + // Now that the initial config is done, transfer ownership of |bio| to the // SSL object. If ContinueSSL() fails, the bio will be freed in Cleanup(). SSL_set_bio(ssl_, bio, bio); @@ -927,14 +944,14 @@ bool OpenSSLAdapter::VerifyServerName(SSL* ssl, const char* host, } bool OpenSSLAdapter::SSLPostConnectionCheck(SSL* ssl, const char* host) { - bool ok = VerifyServerName(ssl, host, ignore_bad_cert()); + bool ok = VerifyServerName(ssl, host, ignore_bad_cert_); if (ok) { ok = (SSL_get_verify_result(ssl) == X509_V_OK || custom_verification_succeeded_); } - if (!ok && ignore_bad_cert()) { + if (!ok && ignore_bad_cert_) { LOG(LS_INFO) << "Other TLS post connection checks failed."; ok = true; } @@ -1009,7 +1026,7 @@ int OpenSSLAdapter::SSLVerifyCallback(int ok, X509_STORE_CTX* store) { } // Should only be used for debugging and development. - if (!ok && stream->ignore_bad_cert()) { + if (!ok && stream->ignore_bad_cert_) { LOG(LS_WARNING) << "Ignoring cert error while verifying cert chain"; ok = 1; } @@ -1096,6 +1113,27 @@ SSL_CTX* OpenSSLAdapter::CreateContext(SSLMode mode, bool enable_cache) { return ctx; } +std::string TransformAlpnProtocols( + const std::vector& alpn_protocols) { + // Transforms the alpn_protocols list to the format expected by + // Open/BoringSSL. This requires joining the protocols into a single string + // and prepending a character with the size of the protocol string before + // each protocol. + std::string transformed_alpn; + for (const std::string& proto : alpn_protocols) { + if (proto.size() == 0 || proto.size() > 0xFF) { + LOG(LS_ERROR) << "OpenSSLAdapter::Error(" + << "TransformAlpnProtocols received proto with size " + << proto.size() << ")"; + return ""; + } + transformed_alpn += static_cast(proto.size()); + transformed_alpn += proto; + LOG(LS_VERBOSE) << "TransformAlpnProtocols: Adding proto: " << proto; + } + return transformed_alpn; +} + ////////////////////////////////////////////////////////////////////// // OpenSSLAdapterFactory ////////////////////////////////////////////////////////////////////// diff --git a/webrtc/rtc_base/openssladapter.h b/webrtc/rtc_base/openssladapter.h index b57ea8fd33..9c6c34479b 100644 --- a/webrtc/rtc_base/openssladapter.h +++ b/webrtc/rtc_base/openssladapter.h @@ -38,6 +38,9 @@ class OpenSSLAdapter : public SSLAdapter, public MessageHandler { OpenSSLAdapterFactory* factory = nullptr); ~OpenSSLAdapter() override; + void SetIgnoreBadCert(bool ignore) override; + void SetAlpnProtocols(const std::vector& protos) override; + void SetMode(SSLMode mode) override; void SetIdentity(SSLIdentity* identity) override; void SetRole(SSLRole role) override; @@ -129,10 +132,17 @@ class OpenSSLAdapter : public SSLAdapter, public MessageHandler { std::string ssl_host_name_; // Do DTLS or not SSLMode ssl_mode_; + // If true, the server certificate need not match the configured hostname. + bool ignore_bad_cert_; + // List of protocols to be used in the TLS ALPN extension. + std::vector alpn_protocols_; bool custom_verification_succeeded_; }; +std::string TransformAlpnProtocols(const std::vector& protos); + +///////////////////////////////////////////////////////////////////////////// class OpenSSLAdapterFactory : public SSLAdapterFactory { public: OpenSSLAdapterFactory(); diff --git a/webrtc/rtc_base/openssladapter_unittest.cc b/webrtc/rtc_base/openssladapter_unittest.cc new file mode 100644 index 0000000000..e4432664fb --- /dev/null +++ b/webrtc/rtc_base/openssladapter_unittest.cc @@ -0,0 +1,41 @@ +/* + * Copyright 2017 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include +#include +#include + +#include "webrtc/rtc_base/gunit.h" +#include "webrtc/rtc_base/openssladapter.h" + +namespace rtc { + +TEST(OpenSSLAdapterTest, TestTransformAlpnProtocols) { + EXPECT_EQ("", TransformAlpnProtocols(std::vector())); + + // Protocols larger than 255 characters (whose size can't be fit in a byte), + // can't be converted, and an empty string will be returned. + std::string large_protocol(256, 'a'); + EXPECT_EQ("", + TransformAlpnProtocols(std::vector{large_protocol})); + + // One protocol test. + std::vector alpn_protos{"h2"}; + std::stringstream expected_response; + expected_response << static_cast(2) << "h2"; + EXPECT_EQ(expected_response.str(), TransformAlpnProtocols(alpn_protos)); + + // Standard protocols test (h2,http/1.1). + alpn_protos.push_back("http/1.1"); + expected_response << static_cast(8) << "http/1.1"; + EXPECT_EQ(expected_response.str(), TransformAlpnProtocols(alpn_protos)); +} + +} // namespace rtc diff --git a/webrtc/rtc_base/ssladapter.h b/webrtc/rtc_base/ssladapter.h index 87e7debc76..b30e176c45 100644 --- a/webrtc/rtc_base/ssladapter.h +++ b/webrtc/rtc_base/ssladapter.h @@ -47,8 +47,8 @@ class SSLAdapter : public AsyncSocketAdapter { // Do not call these methods in production code. // TODO(juberti): Remove the opportunistic encryption mechanism in // BasicPacketSocketFactory that uses this function. - bool ignore_bad_cert() const { return ignore_bad_cert_; } - void set_ignore_bad_cert(bool ignore) { ignore_bad_cert_ = ignore; } + virtual void SetIgnoreBadCert(bool ignore) = 0; + virtual void SetAlpnProtocols(const std::vector& protos) = 0; // Do DTLS or TLS (default is TLS, if unspecified) virtual void SetMode(SSLMode mode) = 0; @@ -76,10 +76,6 @@ class SSLAdapter : public AsyncSocketAdapter { // and deletes |socket|. Otherwise, the returned SSLAdapter takes ownership // of |socket|. static SSLAdapter* Create(AsyncSocket* socket); - - private: - // If true, the server certificate need not match the configured hostname. - bool ignore_bad_cert_ = false; }; /////////////////////////////////////////////////////////////////////////////// diff --git a/webrtc/rtc_base/ssladapter_unittest.cc b/webrtc/rtc_base/ssladapter_unittest.cc index 929b14f87b..5c61f6a2f7 100644 --- a/webrtc/rtc_base/ssladapter_unittest.cc +++ b/webrtc/rtc_base/ssladapter_unittest.cc @@ -52,7 +52,7 @@ class SSLAdapterTestDummyClient : public sigslot::has_slots<> { // Ignore any certificate errors for the purpose of testing. // Note: We do this only because we don't have a real certificate. // NEVER USE THIS IN PRODUCTION CODE! - ssl_adapter_->set_ignore_bad_cert(true); + ssl_adapter_->SetIgnoreBadCert(true); ssl_adapter_->SignalReadEvent.connect(this, &SSLAdapterTestDummyClient::OnSSLAdapterReadEvent); @@ -60,6 +60,10 @@ class SSLAdapterTestDummyClient : public sigslot::has_slots<> { &SSLAdapterTestDummyClient::OnSSLAdapterCloseEvent); } + void SetAlpnProtocols(const std::vector& protos) { + ssl_adapter_->SetAlpnProtocols(protos); + } + rtc::SocketAddress GetAddress() const { return ssl_adapter_->GetLocalAddress(); } @@ -282,6 +286,10 @@ class SSLAdapterTestBase : public testing::Test, handshake_wait_ = wait; } + void SetAlpnProtocols(const std::vector& protos) { + client_->SetAlpnProtocols(protos); + } + void TestHandshake(bool expect_success) { int rv; @@ -434,6 +442,14 @@ TEST_F(SSLAdapterTestTLS_ECDSA, TestTLSTransfer) { TestTransfer("Hello, world!"); } +// Test transfer using ALPN with protos as h2 and http/1.1 +TEST_F(SSLAdapterTestTLS_ECDSA, TestTLSALPN) { + std::vector alpn_protos{"h2", "http/1.1"}; + SetAlpnProtocols(alpn_protos); + TestHandshake(true); + TestTransfer("Hello, world!"); +} + // Basic tests: DTLS // Test that handshake works, using RSA diff --git a/webrtc/sdk/android/api/org/webrtc/PeerConnection.java b/webrtc/sdk/android/api/org/webrtc/PeerConnection.java index 88610a43c0..ce9a110c8f 100644 --- a/webrtc/sdk/android/api/org/webrtc/PeerConnection.java +++ b/webrtc/sdk/android/api/org/webrtc/PeerConnection.java @@ -110,6 +110,9 @@ public class PeerConnection { // necessary. public final String hostname; + // List of protocols to be used in the TLS ALPN extension. + public final List tlsAlpnProtocols; + /** Convenience constructor for STUN servers. */ public IceServer(String uri) { this(uri, "", ""); @@ -125,16 +128,68 @@ public class PeerConnection { public IceServer(String uri, String username, String password, TlsCertPolicy tlsCertPolicy, String hostname) { + this(uri, username, password, tlsCertPolicy, hostname, null); + } + + private IceServer(String uri, String username, String password, TlsCertPolicy tlsCertPolicy, + String hostname, List tlsAlpnProtocols) { this.uri = uri; this.username = username; this.password = password; this.tlsCertPolicy = tlsCertPolicy; this.hostname = hostname; + this.tlsAlpnProtocols = tlsAlpnProtocols; } public String toString() { return uri + " [" + username + ":" + password + "] [" + tlsCertPolicy + "] [" + hostname - + "]"; + + "] [" + tlsAlpnProtocols + "]"; + } + + public static Builder builder(String uri) { + return new Builder(uri); + } + + public static class Builder { + private String uri; + private String username = ""; + private String password = ""; + private TlsCertPolicy tlsCertPolicy = TlsCertPolicy.TLS_CERT_POLICY_SECURE; + private String hostname = ""; + private List tlsAlpnProtocols; + + private Builder(String uri) { + this.uri = uri; + } + + public Builder setUsername(String username) { + this.username = username; + return this; + } + + public Builder setPassword(String password) { + this.password = password; + return this; + } + + public Builder setTlsCertPolicy(TlsCertPolicy tlsCertPolicy) { + this.tlsCertPolicy = tlsCertPolicy; + return this; + } + + public Builder setHostname(String hostname) { + this.hostname = hostname; + return this; + } + + public Builder setTlsAlpnProtocols(List tlsAlpnProtocols) { + this.tlsAlpnProtocols = tlsAlpnProtocols; + return this; + } + + public IceServer createIceServer() { + return new IceServer(uri, username, password, tlsCertPolicy, hostname, tlsAlpnProtocols); + } } } diff --git a/webrtc/sdk/android/src/jni/jni_helpers.cc b/webrtc/sdk/android/src/jni/jni_helpers.cc index bb9bbf4703..f6b4b6f6e5 100644 --- a/webrtc/sdk/android/src/jni/jni_helpers.cc +++ b/webrtc/sdk/android/src/jni/jni_helpers.cc @@ -261,6 +261,18 @@ std::string JavaToStdString(JNIEnv* jni, const jstring& j_string) { return std::string(buf.begin(), buf.end()); } +// Given a list of jstrings, reinterprets it to a new vector of native strings. +std::vector JavaToStdVectorStrings(JNIEnv* jni, jobject list) { + std::vector converted_list; + if (list != nullptr) { + for (jobject str : Iterable(jni, list)) { + converted_list.push_back( + JavaToStdString(jni, reinterpret_cast(str))); + } + } + return converted_list; +} + // Return the (singleton) Java Enum object corresponding to |index|; jobject JavaEnumFromIndex(JNIEnv* jni, jclass state_class, const std::string& state_class_name, int index) { diff --git a/webrtc/sdk/android/src/jni/jni_helpers.h b/webrtc/sdk/android/src/jni/jni_helpers.h index 618c8f6240..cc04f8bb55 100644 --- a/webrtc/sdk/android/src/jni/jni_helpers.h +++ b/webrtc/sdk/android/src/jni/jni_helpers.h @@ -16,6 +16,7 @@ #include #include +#include #include "webrtc/rtc_base/checks.h" #include "webrtc/rtc_base/constructormagic.h" @@ -99,6 +100,10 @@ jstring JavaStringFromStdString(JNIEnv* jni, const std::string& native); // Given a (UTF-16) jstring return a new UTF-8 native string. std::string JavaToStdString(JNIEnv* jni, const jstring& j_string); +// Given a List of (UTF-16) jstrings +// return a new vector of UTF-8 native strings. +std::vector JavaToStdVectorStrings(JNIEnv* jni, jobject list); + // Return the (singleton) Java Enum object corresponding to |index|; jobject JavaEnumFromIndex(JNIEnv* jni, jclass state_class, const std::string& state_class_name, int index); diff --git a/webrtc/sdk/android/src/jni/pc/java_native_conversion.cc b/webrtc/sdk/android/src/jni/pc/java_native_conversion.cc index 43c1636548..799f67dbce 100644 --- a/webrtc/sdk/android/src/jni/pc/java_native_conversion.cc +++ b/webrtc/sdk/android/src/jni/pc/java_native_conversion.cc @@ -362,6 +362,8 @@ void JavaToNativeIceServers(JNIEnv* jni, GetObjectField(jni, j_ice_server, j_ice_server_tls_cert_policy_id); jfieldID j_ice_server_hostname_id = GetFieldID(jni, j_ice_server_class, "hostname", "Ljava/lang/String;"); + jfieldID j_ice_server_tls_alpn_protocols_id = GetFieldID( + jni, j_ice_server_class, "tlsAlpnProtocols", "Ljava/util/List;"); jstring uri = reinterpret_cast( GetObjectField(jni, j_ice_server, j_ice_server_uri_id)); jstring username = reinterpret_cast( @@ -372,12 +374,15 @@ void JavaToNativeIceServers(JNIEnv* jni, JavaToNativeTlsCertPolicy(jni, j_ice_server_tls_cert_policy); jstring hostname = reinterpret_cast( GetObjectField(jni, j_ice_server, j_ice_server_hostname_id)); + jobject tls_alpn_protocols = GetNullableObjectField( + jni, j_ice_server, j_ice_server_tls_alpn_protocols_id); PeerConnectionInterface::IceServer server; server.uri = JavaToStdString(jni, uri); server.username = JavaToStdString(jni, username); server.password = JavaToStdString(jni, password); server.tls_cert_policy = tls_cert_policy; server.hostname = JavaToStdString(jni, hostname); + server.tls_alpn_protocols = JavaToStdVectorStrings(jni, tls_alpn_protocols); ice_servers->push_back(server); } } diff --git a/webrtc/sdk/objc/Framework/Classes/PeerConnection/RTCIceServer.mm b/webrtc/sdk/objc/Framework/Classes/PeerConnection/RTCIceServer.mm index 7f276580b5..a28e237623 100644 --- a/webrtc/sdk/objc/Framework/Classes/PeerConnection/RTCIceServer.mm +++ b/webrtc/sdk/objc/Framework/Classes/PeerConnection/RTCIceServer.mm @@ -19,6 +19,7 @@ @synthesize credential = _credential; @synthesize tlsCertPolicy = _tlsCertPolicy; @synthesize hostname = _hostname; +@synthesize tlsAlpnProtocols = _tlsAlpnProtocols; - (instancetype)initWithURLStrings:(NSArray *)urlStrings { return [self initWithURLStrings:urlStrings @@ -51,6 +52,20 @@ credential:(NSString *)credential tlsCertPolicy:(RTCTlsCertPolicy)tlsCertPolicy hostname:(NSString *)hostname { + return [self initWithURLStrings:urlStrings + username:username + credential:credential + tlsCertPolicy:tlsCertPolicy + hostname:hostname + tlsAlpnProtocols:[NSMutableArray new]]; +} + +- (instancetype)initWithURLStrings:(NSArray *)urlStrings + username:(NSString *)username + credential:(NSString *)credential + tlsCertPolicy:(RTCTlsCertPolicy)tlsCertPolicy + hostname:(NSString *)hostname + tlsAlpnProtocols:(NSArray *)tlsAlpnProtocols { NSParameterAssert(urlStrings.count); if (self = [super init]) { _urlStrings = [[NSArray alloc] initWithArray:urlStrings copyItems:YES]; @@ -58,17 +73,19 @@ _credential = [credential copy]; _tlsCertPolicy = tlsCertPolicy; _hostname = [hostname copy]; + _tlsAlpnProtocols = [[NSArray alloc] initWithArray:tlsAlpnProtocols copyItems:YES]; } return self; } - (NSString *)description { - return [NSString stringWithFormat:@"RTCIceServer:\n%@\n%@\n%@\n%@\n%@", + return [NSString stringWithFormat:@"RTCIceServer:\n%@\n%@\n%@\n%@\n%@\n%@", _urlStrings, _username, _credential, [self stringForTlsCertPolicy:_tlsCertPolicy], - _hostname]; + _hostname, + _tlsAlpnProtocols]; } #pragma mark - Private @@ -89,6 +106,10 @@ iceServer.password = [NSString stdStringForString:_credential]; iceServer.hostname = [NSString stdStringForString:_hostname]; + [_tlsAlpnProtocols enumerateObjectsUsingBlock:^(NSString *proto, NSUInteger idx, BOOL *stop) { + iceServer.tls_alpn_protocols.push_back(proto.stdString); + }]; + [_urlStrings enumerateObjectsUsingBlock:^(NSString *url, NSUInteger idx, BOOL *stop) { @@ -118,6 +139,11 @@ NSString *username = [NSString stringForStdString:nativeServer.username]; NSString *credential = [NSString stringForStdString:nativeServer.password]; NSString *hostname = [NSString stringForStdString:nativeServer.hostname]; + NSMutableArray *tlsAlpnProtocols = + [NSMutableArray arrayWithCapacity:nativeServer.tls_alpn_protocols.size()]; + for (auto const &proto : nativeServer.tls_alpn_protocols) { + [tlsAlpnProtocols addObject:[NSString stringForStdString:proto]]; + } RTCTlsCertPolicy tlsCertPolicy; switch (nativeServer.tls_cert_policy) { @@ -133,7 +159,8 @@ username:username credential:credential tlsCertPolicy:tlsCertPolicy - hostname:hostname]; + hostname:hostname + tlsAlpnProtocols:tlsAlpnProtocols]; return self; } diff --git a/webrtc/sdk/objc/Framework/Headers/WebRTC/RTCIceServer.h b/webrtc/sdk/objc/Framework/Headers/WebRTC/RTCIceServer.h index 1fa006f82e..e9baa8feb4 100644 --- a/webrtc/sdk/objc/Framework/Headers/WebRTC/RTCIceServer.h +++ b/webrtc/sdk/objc/Framework/Headers/WebRTC/RTCIceServer.h @@ -43,6 +43,9 @@ RTC_EXPORT */ @property(nonatomic, readonly, nullable) NSString *hostname; +/** List of protocols to be used in the TLS ALPN extension. */ +@property(nonatomic, readonly) NSArray *tlsAlpnProtocols; + - (nonnull instancetype)init NS_UNAVAILABLE; /** Convenience initializer for a server with no authentication (e.g. STUN). */ @@ -73,7 +76,19 @@ RTC_EXPORT username:(nullable NSString *)username credential:(nullable NSString *)credential tlsCertPolicy:(RTCTlsCertPolicy)tlsCertPolicy - hostname:(nullable NSString *)hostname NS_DESIGNATED_INITIALIZER; + hostname:(nullable NSString *)hostname; + +/** + * Initialize an RTCIceServer with its associated URLs, optional username, + * optional credential, TLS cert policy, hostname and ALPN protocols. + */ +- (instancetype)initWithURLStrings:(NSArray *)urlStrings + username:(nullable NSString *)username + credential:(nullable NSString *)credential + tlsCertPolicy:(RTCTlsCertPolicy)tlsCertPolicy + hostname:(nullable NSString *)hostname + tlsAlpnProtocols:(NSArray *)tlsAlpnProtocols + NS_DESIGNATED_INITIALIZER; @end diff --git a/webrtc/sdk/objc/Framework/UnitTests/RTCIceServerTest.mm b/webrtc/sdk/objc/Framework/UnitTests/RTCIceServerTest.mm index fb25eb38c9..9d42c0768d 100644 --- a/webrtc/sdk/objc/Framework/UnitTests/RTCIceServerTest.mm +++ b/webrtc/sdk/objc/Framework/UnitTests/RTCIceServerTest.mm @@ -76,12 +76,30 @@ EXPECT_EQ("hostname", iceStruct.hostname); } +- (void)testTlsAlpnProtocols { + RTCIceServer *server = [[RTCIceServer alloc] initWithURLStrings:@[ @"turn1:turn1.example.net" ] + username:@"username" + credential:@"credential" + tlsCertPolicy:RTCTlsCertPolicySecure + hostname:@"hostname" + tlsAlpnProtocols:@[ @"proto1", @"proto2" ]]; + webrtc::PeerConnectionInterface::IceServer iceStruct = server.nativeServer; + EXPECT_EQ(1u, iceStruct.urls.size()); + EXPECT_EQ("turn1:turn1.example.net", iceStruct.urls.front()); + EXPECT_EQ("username", iceStruct.username); + EXPECT_EQ("credential", iceStruct.password); + EXPECT_EQ("hostname", iceStruct.hostname); + EXPECT_EQ(2u, iceStruct.tls_alpn_protocols.size()); +} + - (void)testInitFromNativeServer { webrtc::PeerConnectionInterface::IceServer nativeServer; nativeServer.username = "username"; nativeServer.password = "password"; nativeServer.urls.push_back("stun:stun.example.net"); nativeServer.hostname = "hostname"; + nativeServer.tls_alpn_protocols.push_back("proto1"); + nativeServer.tls_alpn_protocols.push_back("proto2"); RTCIceServer *iceServer = [[RTCIceServer alloc] initWithNativeServer:nativeServer]; @@ -91,6 +109,7 @@ EXPECT_EQ("username", [NSString stdStringForString:iceServer.username]); EXPECT_EQ("password", [NSString stdStringForString:iceServer.credential]); EXPECT_EQ("hostname", [NSString stdStringForString:iceServer.hostname]); + EXPECT_EQ(2u, iceServer.tlsAlpnProtocols.count); } @end