diff --git a/p2p/base/dtls_transport.cc b/p2p/base/dtls_transport.cc index 9d49c09894..a5e29d0103 100644 --- a/p2p/base/dtls_transport.cc +++ b/p2p/base/dtls_transport.cc @@ -14,6 +14,7 @@ #include #include +#include "absl/memory/memory.h" #include "api/rtc_event_log/rtc_event_log.h" #include "logging/rtc_event_log/events/rtc_event_dtls_transport_state.h" #include "logging/rtc_event_log/events/rtc_event_dtls_writable_state.h" @@ -325,18 +326,19 @@ bool DtlsTransport::ExportKeyingMaterial(const std::string& label, bool DtlsTransport::SetupDtls() { RTC_DCHECK(dtls_role_); - StreamInterfaceChannel* downward = new StreamInterfaceChannel(ice_transport_); + { + auto downward = std::make_unique(ice_transport_); + StreamInterfaceChannel* downward_ptr = downward.get(); - dtls_.reset(rtc::SSLStreamAdapter::Create(downward)); - if (!dtls_) { - RTC_LOG(LS_ERROR) << ToString() << ": Failed to create DTLS adapter."; - delete downward; - return false; + dtls_ = rtc::SSLStreamAdapter::Create(std::move(downward)); + if (!dtls_) { + RTC_LOG(LS_ERROR) << ToString() << ": Failed to create DTLS adapter."; + return false; + } + downward_ = downward_ptr; } - downward_ = downward; - - dtls_->SetIdentity(local_certificate_->identity()->GetReference()); + dtls_->SetIdentity(local_certificate_->identity()->Clone()); dtls_->SetMode(rtc::SSL_MODE_DTLS); dtls_->SetMaxProtocolVersion(ssl_max_version_); dtls_->SetServerRole(*dtls_role_); diff --git a/p2p/base/dtls_transport_unittest.cc b/p2p/base/dtls_transport_unittest.cc index 8ac6e9b8a6..c31062dd94 100644 --- a/p2p/base/dtls_transport_unittest.cc +++ b/p2p/base/dtls_transport_unittest.cc @@ -66,8 +66,7 @@ class DtlsTestClient : public sigslot::has_slots<> { explicit DtlsTestClient(const std::string& name) : name_(name) {} void CreateCertificate(rtc::KeyType key_type) { certificate_ = - rtc::RTCCertificate::Create(std::unique_ptr( - rtc::SSLIdentity::Generate(name_, key_type))); + rtc::RTCCertificate::Create(rtc::SSLIdentity::Create(name_, key_type)); } const rtc::scoped_refptr& certificate() { return certificate_; diff --git a/p2p/base/test_turn_server.h b/p2p/base/test_turn_server.h index 3a9da85f08..d438a83301 100644 --- a/p2p/base/test_turn_server.h +++ b/p2p/base/test_turn_server.h @@ -109,7 +109,7 @@ class TestTurnServer : public TurnAuthInterface { rtc::SSLAdapter* adapter = rtc::SSLAdapter::Create(socket); adapter->SetRole(rtc::SSL_SERVER); adapter->SetIdentity( - rtc::SSLIdentity::Generate(common_name, rtc::KeyParams())); + rtc::SSLIdentity::Create(common_name, rtc::KeyParams())); adapter->SetIgnoreBadCert(ignore_bad_cert); socket = adapter; } diff --git a/pc/channel_unittest.cc b/pc/channel_unittest.cc index c1037f7193..a3fe3f68de 100644 --- a/pc/channel_unittest.cc +++ b/pc/channel_unittest.cc @@ -179,9 +179,8 @@ class ChannelTest : public ::testing::Test, public sigslot::has_slots<> { rtcp1 = fake_rtcp_dtls_transport1_.get(); } if (flags1 & DTLS) { - auto cert1 = - rtc::RTCCertificate::Create(std::unique_ptr( - rtc::SSLIdentity::Generate("session1", rtc::KT_DEFAULT))); + auto cert1 = rtc::RTCCertificate::Create( + rtc::SSLIdentity::Create("session1", rtc::KT_DEFAULT)); fake_rtp_dtls_transport1_->SetLocalCertificate(cert1); if (fake_rtcp_dtls_transport1_) { fake_rtcp_dtls_transport1_->SetLocalCertificate(cert1); @@ -209,9 +208,8 @@ class ChannelTest : public ::testing::Test, public sigslot::has_slots<> { rtcp2 = fake_rtcp_dtls_transport2_.get(); } if (flags2 & DTLS) { - auto cert2 = - rtc::RTCCertificate::Create(std::unique_ptr( - rtc::SSLIdentity::Generate("session2", rtc::KT_DEFAULT))); + auto cert2 = rtc::RTCCertificate::Create( + rtc::SSLIdentity::Create("session2", rtc::KT_DEFAULT)); fake_rtp_dtls_transport2_->SetLocalCertificate(cert2); if (fake_rtcp_dtls_transport2_) { fake_rtcp_dtls_transport2_->SetLocalCertificate(cert2); diff --git a/pc/dtls_srtp_transport_unittest.cc b/pc/dtls_srtp_transport_unittest.cc index 770c140ce7..6952159a01 100644 --- a/pc/dtls_srtp_transport_unittest.cc +++ b/pc/dtls_srtp_transport_unittest.cc @@ -97,11 +97,11 @@ class DtlsSrtpTransportTest : public ::testing::Test, void CompleteDtlsHandshake(FakeDtlsTransport* fake_dtls1, FakeDtlsTransport* fake_dtls2) { - auto cert1 = rtc::RTCCertificate::Create(std::unique_ptr( - rtc::SSLIdentity::Generate("session1", rtc::KT_DEFAULT))); + auto cert1 = rtc::RTCCertificate::Create( + rtc::SSLIdentity::Create("session1", rtc::KT_DEFAULT)); fake_dtls1->SetLocalCertificate(cert1); - auto cert2 = rtc::RTCCertificate::Create(std::unique_ptr( - rtc::SSLIdentity::Generate("session1", rtc::KT_DEFAULT))); + auto cert2 = rtc::RTCCertificate::Create( + rtc::SSLIdentity::Create("session1", rtc::KT_DEFAULT)); fake_dtls2->SetLocalCertificate(cert2); fake_dtls1->SetDestination(fake_dtls2); } diff --git a/pc/dtls_transport_unittest.cc b/pc/dtls_transport_unittest.cc index f7d7a88d1e..a3f0a7ce8b 100644 --- a/pc/dtls_transport_unittest.cc +++ b/pc/dtls_transport_unittest.cc @@ -70,11 +70,11 @@ class DtlsTransportTest : public ::testing::Test { auto fake_dtls1 = static_cast(transport_->internal()); auto fake_dtls2 = std::make_unique( "audio", cricket::ICE_CANDIDATE_COMPONENT_RTP); - auto cert1 = rtc::RTCCertificate::Create(absl::WrapUnique( - rtc::SSLIdentity::Generate("session1", rtc::KT_DEFAULT))); + auto cert1 = rtc::RTCCertificate::Create( + rtc::SSLIdentity::Create("session1", rtc::KT_DEFAULT)); fake_dtls1->SetLocalCertificate(cert1); - auto cert2 = rtc::RTCCertificate::Create(absl::WrapUnique( - rtc::SSLIdentity::Generate("session1", rtc::KT_DEFAULT))); + auto cert2 = rtc::RTCCertificate::Create( + rtc::SSLIdentity::Create("session1", rtc::KT_DEFAULT)); fake_dtls2->SetLocalCertificate(cert2); fake_dtls1->SetDestination(fake_dtls2.get()); } diff --git a/pc/jsep_transport_controller_unittest.cc b/pc/jsep_transport_controller_unittest.cc index 18fdc209d1..7b18be8809 100644 --- a/pc/jsep_transport_controller_unittest.cc +++ b/pc/jsep_transport_controller_unittest.cc @@ -637,8 +637,8 @@ TEST_F(JsepTransportControllerTest, SetAndGetLocalCertificate) { CreateJsepTransportController(JsepTransportController::Config()); rtc::scoped_refptr certificate1 = - rtc::RTCCertificate::Create(std::unique_ptr( - rtc::SSLIdentity::Generate("session1", rtc::KT_DEFAULT))); + rtc::RTCCertificate::Create( + rtc::SSLIdentity::Create("session1", rtc::KT_DEFAULT)); rtc::scoped_refptr returned_certificate; auto description = std::make_unique(); @@ -662,8 +662,8 @@ TEST_F(JsepTransportControllerTest, SetAndGetLocalCertificate) { // Shouldn't be able to change the identity once set. rtc::scoped_refptr certificate2 = - rtc::RTCCertificate::Create(std::unique_ptr( - rtc::SSLIdentity::Generate("session2", rtc::KT_DEFAULT))); + rtc::RTCCertificate::Create( + rtc::SSLIdentity::Create("session2", rtc::KT_DEFAULT)); EXPECT_FALSE(transport_controller_->SetLocalCertificate(certificate2)); } @@ -691,12 +691,10 @@ TEST_F(JsepTransportControllerTest, GetRemoteSSLCertChain) { TEST_F(JsepTransportControllerTest, GetDtlsRole) { CreateJsepTransportController(JsepTransportController::Config()); - auto offer_certificate = - rtc::RTCCertificate::Create(std::unique_ptr( - rtc::SSLIdentity::Generate("offer", rtc::KT_DEFAULT))); - auto answer_certificate = - rtc::RTCCertificate::Create(std::unique_ptr( - rtc::SSLIdentity::Generate("answer", rtc::KT_DEFAULT))); + auto offer_certificate = rtc::RTCCertificate::Create( + rtc::SSLIdentity::Create("offer", rtc::KT_DEFAULT)); + auto answer_certificate = rtc::RTCCertificate::Create( + rtc::SSLIdentity::Create("answer", rtc::KT_DEFAULT)); transport_controller_->SetLocalCertificate(offer_certificate); auto offer_desc = std::make_unique(); diff --git a/pc/jsep_transport_unittest.cc b/pc/jsep_transport_unittest.cc index ccaf01b9a4..a4b1d5593e 100644 --- a/pc/jsep_transport_unittest.cc +++ b/pc/jsep_transport_unittest.cc @@ -225,11 +225,11 @@ TEST_P(JsepTransport2WithRtcpMux, SetDtlsParameters) { // Create certificates. rtc::scoped_refptr local_cert = - rtc::RTCCertificate::Create(std::unique_ptr( - rtc::SSLIdentity::Generate("local", rtc::KT_DEFAULT))); + rtc::RTCCertificate::Create( + rtc::SSLIdentity::Create("local", rtc::KT_DEFAULT)); rtc::scoped_refptr remote_cert = - rtc::RTCCertificate::Create(std::unique_ptr( - rtc::SSLIdentity::Generate("remote", rtc::KT_DEFAULT))); + rtc::RTCCertificate::Create( + rtc::SSLIdentity::Create("remote", rtc::KT_DEFAULT)); jsep_transport_->SetLocalCertificate(local_cert); // Apply offer. @@ -276,11 +276,11 @@ TEST_P(JsepTransport2WithRtcpMux, SetDtlsParametersWithPassiveAnswer) { // Create certificates. rtc::scoped_refptr local_cert = - rtc::RTCCertificate::Create(std::unique_ptr( - rtc::SSLIdentity::Generate("local", rtc::KT_DEFAULT))); + rtc::RTCCertificate::Create( + rtc::SSLIdentity::Create("local", rtc::KT_DEFAULT)); rtc::scoped_refptr remote_cert = - rtc::RTCCertificate::Create(std::unique_ptr( - rtc::SSLIdentity::Generate("remote", rtc::KT_DEFAULT))); + rtc::RTCCertificate::Create( + rtc::SSLIdentity::Create("remote", rtc::KT_DEFAULT)); jsep_transport_->SetLocalCertificate(local_cert); // Apply offer. @@ -393,8 +393,8 @@ TEST_P(JsepTransport2WithRtcpMux, VerifyCertificateFingerprint) { for (auto& key_type : key_types) { rtc::scoped_refptr certificate = - rtc::RTCCertificate::Create(std::unique_ptr( - rtc::SSLIdentity::Generate("testing", key_type))); + rtc::RTCCertificate::Create( + rtc::SSLIdentity::Create("testing", key_type)); ASSERT_NE(nullptr, certificate); std::string digest_algorithm; @@ -433,8 +433,8 @@ TEST_P(JsepTransport2WithRtcpMux, ValidDtlsRoleNegotiation) { // Just use the same certificate for both sides; doesn't really matter in a // non end-to-end test. rtc::scoped_refptr certificate = - rtc::RTCCertificate::Create(std::unique_ptr( - rtc::SSLIdentity::Generate("testing", rtc::KT_ECDSA))); + rtc::RTCCertificate::Create( + rtc::SSLIdentity::Create("testing", rtc::KT_ECDSA)); JsepTransportDescription local_description = MakeJsepTransportDescription( rtcp_mux_enabled, kIceUfrag1, kIcePwd1, certificate); @@ -532,8 +532,8 @@ TEST_P(JsepTransport2WithRtcpMux, InvalidDtlsRoleNegotiation) { // Just use the same certificate for both sides; doesn't really matter in a // non end-to-end test. rtc::scoped_refptr certificate = - rtc::RTCCertificate::Create(std::unique_ptr( - rtc::SSLIdentity::Generate("testing", rtc::KT_ECDSA))); + rtc::RTCCertificate::Create( + rtc::SSLIdentity::Create("testing", rtc::KT_ECDSA)); JsepTransportDescription local_description = MakeJsepTransportDescription( rtcp_mux_enabled, kIceUfrag1, kIcePwd1, certificate); @@ -663,8 +663,8 @@ TEST_F(JsepTransport2Test, ValidDtlsReofferFromAnswerer) { // Just use the same certificate for both sides; doesn't really matter in a // non end-to-end test. rtc::scoped_refptr certificate = - rtc::RTCCertificate::Create(std::unique_ptr( - rtc::SSLIdentity::Generate("testing", rtc::KT_ECDSA))); + rtc::RTCCertificate::Create( + rtc::SSLIdentity::Create("testing", rtc::KT_ECDSA)); bool rtcp_mux_enabled = true; jsep_transport_ = CreateJsepTransport2(rtcp_mux_enabled, SrtpMode::kDtlsSrtp); jsep_transport_->SetLocalCertificate(certificate); @@ -710,8 +710,8 @@ TEST_F(JsepTransport2Test, InvalidDtlsReofferFromAnswerer) { // Just use the same certificate for both sides; doesn't really matter in a // non end-to-end test. rtc::scoped_refptr certificate = - rtc::RTCCertificate::Create(std::unique_ptr( - rtc::SSLIdentity::Generate("testing", rtc::KT_ECDSA))); + rtc::RTCCertificate::Create( + rtc::SSLIdentity::Create("testing", rtc::KT_ECDSA)); bool rtcp_mux_enabled = true; jsep_transport_ = CreateJsepTransport2(rtcp_mux_enabled, SrtpMode::kDtlsSrtp); jsep_transport_->SetLocalCertificate(certificate); @@ -756,8 +756,8 @@ TEST_F(JsepTransport2Test, InvalidDtlsReofferFromAnswerer) { // since JSEP requires generating "actpass". TEST_F(JsepTransport2Test, RemoteOfferWithCurrentNegotiatedDtlsRole) { rtc::scoped_refptr certificate = - rtc::RTCCertificate::Create(std::unique_ptr( - rtc::SSLIdentity::Generate("testing", rtc::KT_ECDSA))); + rtc::RTCCertificate::Create( + rtc::SSLIdentity::Create("testing", rtc::KT_ECDSA)); bool rtcp_mux_enabled = true; jsep_transport_ = CreateJsepTransport2(rtcp_mux_enabled, SrtpMode::kDtlsSrtp); jsep_transport_->SetLocalCertificate(certificate); @@ -801,8 +801,8 @@ TEST_F(JsepTransport2Test, RemoteOfferWithCurrentNegotiatedDtlsRole) { // role is rejected. TEST_F(JsepTransport2Test, RemoteOfferThatChangesNegotiatedDtlsRole) { rtc::scoped_refptr certificate = - rtc::RTCCertificate::Create(std::unique_ptr( - rtc::SSLIdentity::Generate("testing", rtc::KT_ECDSA))); + rtc::RTCCertificate::Create( + rtc::SSLIdentity::Create("testing", rtc::KT_ECDSA)); bool rtcp_mux_enabled = true; jsep_transport_ = CreateJsepTransport2(rtcp_mux_enabled, SrtpMode::kDtlsSrtp); jsep_transport_->SetLocalCertificate(certificate); @@ -846,8 +846,8 @@ TEST_F(JsepTransport2Test, RemoteOfferThatChangesNegotiatedDtlsRole) { // interpreted as having an active role. TEST_F(JsepTransport2Test, DtlsSetupWithLegacyAsAnswerer) { rtc::scoped_refptr certificate = - rtc::RTCCertificate::Create(std::unique_ptr( - rtc::SSLIdentity::Generate("testing", rtc::KT_ECDSA))); + rtc::RTCCertificate::Create( + rtc::SSLIdentity::Create("testing", rtc::KT_ECDSA)); bool rtcp_mux_enabled = true; jsep_transport_ = CreateJsepTransport2(rtcp_mux_enabled, SrtpMode::kDtlsSrtp); jsep_transport_->SetLocalCertificate(certificate); @@ -1052,13 +1052,11 @@ class JsepTransport2HeaderExtensionTest this, &JsepTransport2HeaderExtensionTest::OnReadPacket2); if (mode == SrtpMode::kDtlsSrtp) { - auto cert1 = - rtc::RTCCertificate::Create(std::unique_ptr( - rtc::SSLIdentity::Generate("session1", rtc::KT_DEFAULT))); + auto cert1 = rtc::RTCCertificate::Create( + rtc::SSLIdentity::Create("session1", rtc::KT_DEFAULT)); jsep_transport1_->rtp_dtls_transport()->SetLocalCertificate(cert1); - auto cert2 = - rtc::RTCCertificate::Create(std::unique_ptr( - rtc::SSLIdentity::Generate("session1", rtc::KT_DEFAULT))); + auto cert2 = rtc::RTCCertificate::Create( + rtc::SSLIdentity::Create("session1", rtc::KT_DEFAULT)); jsep_transport2_->rtp_dtls_transport()->SetLocalCertificate(cert2); } } diff --git a/rtc_base/BUILD.gn b/rtc_base/BUILD.gn index 2c6dd3c56b..5167e5a4a5 100644 --- a/rtc_base/BUILD.gn +++ b/rtc_base/BUILD.gn @@ -757,6 +757,7 @@ rtc_library("rtc_base") { defines = [] deps = [ ":checks", + ":deprecation", ":stringutils", "../api:array_view", "../api:function_view", diff --git a/rtc_base/fake_ssl_identity.cc b/rtc_base/fake_ssl_identity.cc index 309708f62d..7374d2ebdb 100644 --- a/rtc_base/fake_ssl_identity.cc +++ b/rtc_base/fake_ssl_identity.cc @@ -94,6 +94,10 @@ FakeSSLIdentity* FakeSSLIdentity::GetReference() const { return new FakeSSLIdentity(*this); } +std::unique_ptr FakeSSLIdentity::CloneInternal() const { + return std::make_unique(*this); +} + const SSLCertificate& FakeSSLIdentity::certificate() const { return cert_chain_->Get(0); } diff --git a/rtc_base/fake_ssl_identity.h b/rtc_base/fake_ssl_identity.h index c3a8d1f171..a592154953 100644 --- a/rtc_base/fake_ssl_identity.h +++ b/rtc_base/fake_ssl_identity.h @@ -73,6 +73,8 @@ class FakeSSLIdentity : public SSLIdentity { virtual bool operator==(const SSLIdentity& other) const; private: + std::unique_ptr CloneInternal() const override; + std::unique_ptr cert_chain_; }; diff --git a/rtc_base/openssl_adapter.cc b/rtc_base/openssl_adapter.cc index 07c2b818cf..e71758b66c 100644 --- a/rtc_base/openssl_adapter.cc +++ b/rtc_base/openssl_adapter.cc @@ -20,6 +20,7 @@ #include +#include "absl/memory/memory.h" #include "rtc_base/checks.h" #include "rtc_base/location.h" #include "rtc_base/logging.h" @@ -226,6 +227,12 @@ void OpenSSLAdapter::SetIdentity(SSLIdentity* identity) { identity_.reset(static_cast(identity)); } +void OpenSSLAdapter::SetIdentity(std::unique_ptr identity) { + RTC_DCHECK(!identity_); + identity_ = + absl::WrapUnique(static_cast(identity.release())); +} + void OpenSSLAdapter::SetRole(SSLRole role) { role_ = role; } @@ -238,7 +245,7 @@ AsyncSocket* OpenSSLAdapter::Accept(SocketAddress* paddr) { } SSLAdapter* adapter = SSLAdapter::Create(socket); - adapter->SetIdentity(identity_->GetReference()); + adapter->SetIdentity(identity_->Clone()); adapter->SetRole(rtc::SSL_SERVER); adapter->SetIgnoreBadCert(ignore_bad_cert_); adapter->StartSSL("", false); diff --git a/rtc_base/openssl_adapter.h b/rtc_base/openssl_adapter.h index c3cab2fd78..7079fe39b4 100644 --- a/rtc_base/openssl_adapter.h +++ b/rtc_base/openssl_adapter.h @@ -54,6 +54,7 @@ class OpenSSLAdapter final : public SSLAdapter, public MessageHandler { void SetMode(SSLMode mode) override; void SetCertVerifier(SSLCertificateVerifier* ssl_cert_verifier) override; void SetIdentity(SSLIdentity* identity) override; + void SetIdentity(std::unique_ptr identity) override; void SetRole(SSLRole role) override; AsyncSocket* Accept(SocketAddress* paddr) override; int StartSSL(const char* hostname, bool restartable) override; diff --git a/rtc_base/openssl_identity.cc b/rtc_base/openssl_identity.cc index 8d12c07816..5b23df9f53 100644 --- a/rtc_base/openssl_identity.cc +++ b/rtc_base/openssl_identity.cc @@ -209,21 +209,24 @@ OpenSSLIdentity::OpenSSLIdentity(std::unique_ptr key_pair, OpenSSLIdentity::~OpenSSLIdentity() = default; -OpenSSLIdentity* OpenSSLIdentity::GenerateInternal( +std::unique_ptr OpenSSLIdentity::CreateInternal( const SSLIdentityParams& params) { std::unique_ptr key_pair( OpenSSLKeyPair::Generate(params.key_params)); if (key_pair) { std::unique_ptr certificate( OpenSSLCertificate::Generate(key_pair.get(), params)); - if (certificate != nullptr) - return new OpenSSLIdentity(std::move(key_pair), std::move(certificate)); + if (certificate != nullptr) { + return absl::WrapUnique( + new OpenSSLIdentity(std::move(key_pair), std::move(certificate))); + } } RTC_LOG(LS_INFO) << "Identity generation failed"; return nullptr; } -OpenSSLIdentity* OpenSSLIdentity::GenerateWithExpiration( +// static +std::unique_ptr OpenSSLIdentity::CreateWithExpiration( const std::string& common_name, const KeyParams& key_params, time_t certificate_lifetime) { @@ -235,16 +238,30 @@ OpenSSLIdentity* OpenSSLIdentity::GenerateWithExpiration( params.not_after = now + certificate_lifetime; if (params.not_before > params.not_after) return nullptr; - return GenerateInternal(params); + return CreateInternal(params); +} + +OpenSSLIdentity* OpenSSLIdentity::GenerateWithExpiration( + const std::string& common_name, + const KeyParams& key_params, + time_t certificate_lifetime) { + return CreateWithExpiration(common_name, key_params, certificate_lifetime) + .release(); +} + +std::unique_ptr OpenSSLIdentity::CreateForTest( + const SSLIdentityParams& params) { + return CreateInternal(params); } OpenSSLIdentity* OpenSSLIdentity::GenerateForTest( const SSLIdentityParams& params) { - return GenerateInternal(params); + return CreateInternal(params).release(); } -SSLIdentity* OpenSSLIdentity::FromPEMStrings(const std::string& private_key, - const std::string& certificate) { +std::unique_ptr OpenSSLIdentity::CreateFromPEMStrings( + const std::string& private_key, + const std::string& certificate) { std::unique_ptr cert( OpenSSLCertificate::FromPEMString(certificate)); if (!cert) { @@ -259,10 +276,16 @@ SSLIdentity* OpenSSLIdentity::FromPEMStrings(const std::string& private_key, return nullptr; } - return new OpenSSLIdentity(std::move(key_pair), std::move(cert)); + return absl::WrapUnique( + new OpenSSLIdentity(std::move(key_pair), std::move(cert))); } -SSLIdentity* OpenSSLIdentity::FromPEMChainStrings( +SSLIdentity* OpenSSLIdentity::FromPEMStrings(const std::string& private_key, + const std::string& certificate) { + return CreateFromPEMStrings(private_key, certificate).release(); +} + +std::unique_ptr OpenSSLIdentity::CreateFromPEMChainStrings( const std::string& private_key, const std::string& certificate_chain) { BIO* bio = BIO_new_mem_buf(certificate_chain.data(), @@ -300,8 +323,14 @@ SSLIdentity* OpenSSLIdentity::FromPEMChainStrings( return nullptr; } - return new OpenSSLIdentity(std::move(key_pair), - std::make_unique(std::move(certs))); + return absl::WrapUnique(new OpenSSLIdentity( + std::move(key_pair), std::make_unique(std::move(certs)))); +} + +SSLIdentity* OpenSSLIdentity::FromPEMChainStrings( + const std::string& private_key, + const std::string& certificate_chain) { + return CreateFromPEMChainStrings(private_key, certificate_chain).release(); } const OpenSSLCertificate& OpenSSLIdentity::certificate() const { @@ -313,8 +342,14 @@ const SSLCertChain& OpenSSLIdentity::cert_chain() const { } OpenSSLIdentity* OpenSSLIdentity::GetReference() const { - return new OpenSSLIdentity(absl::WrapUnique(key_pair_->GetReference()), - cert_chain_->Clone()); + return static_cast(CloneInternal().release()); +} + +std::unique_ptr OpenSSLIdentity::CloneInternal() const { + // We cannot use std::make_unique here because the referenced OpenSSLIdentity + // constructor is private. + return absl::WrapUnique(new OpenSSLIdentity( + absl::WrapUnique(key_pair_->GetReference()), cert_chain_->Clone())); } bool OpenSSLIdentity::ConfigureIdentity(SSL_CTX* ctx) { diff --git a/rtc_base/openssl_identity.h b/rtc_base/openssl_identity.h index f0c4fb895d..c499b06332 100644 --- a/rtc_base/openssl_identity.h +++ b/rtc_base/openssl_identity.h @@ -60,6 +60,19 @@ class OpenSSLKeyPair final { // them consistently. class OpenSSLIdentity final : public SSLIdentity { public: + static std::unique_ptr CreateWithExpiration( + const std::string& common_name, + const KeyParams& key_params, + time_t certificate_lifetime); + static std::unique_ptr CreateForTest( + const SSLIdentityParams& params); + static std::unique_ptr CreateFromPEMStrings( + const std::string& private_key, + const std::string& certificate); + static std::unique_ptr CreateFromPEMChainStrings( + const std::string& private_key, + const std::string& certificate_chain); + // Deprecated versions static OpenSSLIdentity* GenerateWithExpiration(const std::string& common_name, const KeyParams& key_params, time_t certificate_lifetime); @@ -72,7 +85,7 @@ class OpenSSLIdentity final : public SSLIdentity { const OpenSSLCertificate& certificate() const override; const SSLCertChain& cert_chain() const override; - OpenSSLIdentity* GetReference() const override; + RTC_DEPRECATED OpenSSLIdentity* GetReference() const override; // Configure an SSL context object to use our key and certificate. bool ConfigureIdentity(SSL_CTX* ctx); @@ -87,8 +100,10 @@ class OpenSSLIdentity final : public SSLIdentity { std::unique_ptr certificate); OpenSSLIdentity(std::unique_ptr key_pair, std::unique_ptr cert_chain); + std::unique_ptr CloneInternal() const override; - static OpenSSLIdentity* GenerateInternal(const SSLIdentityParams& params); + static std::unique_ptr CreateInternal( + const SSLIdentityParams& params); std::unique_ptr key_pair_; std::unique_ptr cert_chain_; diff --git a/rtc_base/openssl_stream_adapter.cc b/rtc_base/openssl_stream_adapter.cc index 32af96b65f..3fa42af6e9 100644 --- a/rtc_base/openssl_stream_adapter.cc +++ b/rtc_base/openssl_stream_adapter.cc @@ -265,8 +265,9 @@ static long stream_ctrl(BIO* b, int cmd, long num, void* ptr) { // OpenSSLStreamAdapter ///////////////////////////////////////////////////////////////////////////// -OpenSSLStreamAdapter::OpenSSLStreamAdapter(StreamInterface* stream) - : SSLStreamAdapter(stream), +OpenSSLStreamAdapter::OpenSSLStreamAdapter( + std::unique_ptr stream) + : SSLStreamAdapter(std::move(stream)), state_(SSL_NONE), role_(SSL_CLIENT), ssl_read_needs_write_(false), @@ -284,9 +285,13 @@ OpenSSLStreamAdapter::~OpenSSLStreamAdapter() { Cleanup(0); } -void OpenSSLStreamAdapter::SetIdentity(SSLIdentity* identity) { +void OpenSSLStreamAdapter::SetIdentity(std::unique_ptr identity) { RTC_DCHECK(!identity_); - identity_.reset(static_cast(identity)); + identity_.reset(static_cast(identity.release())); +} + +OpenSSLIdentity* OpenSSLStreamAdapter::GetIdentityForTesting() const { + return identity_.get(); } void OpenSSLStreamAdapter::SetServerRole(SSLRole role) { diff --git a/rtc_base/openssl_stream_adapter.h b/rtc_base/openssl_stream_adapter.h index f8dd5b1358..7ea324321b 100644 --- a/rtc_base/openssl_stream_adapter.h +++ b/rtc_base/openssl_stream_adapter.h @@ -57,10 +57,11 @@ class SSLCertChain; class OpenSSLStreamAdapter final : public SSLStreamAdapter { public: - explicit OpenSSLStreamAdapter(StreamInterface* stream); + explicit OpenSSLStreamAdapter(std::unique_ptr stream); ~OpenSSLStreamAdapter() override; - void SetIdentity(SSLIdentity* identity) override; + void SetIdentity(std::unique_ptr identity) override; + OpenSSLIdentity* GetIdentityForTesting() const override; // Default argument is for compatibility void SetServerRole(SSLRole role = SSL_SERVER) override; diff --git a/rtc_base/rtc_certificate.cc b/rtc_base/rtc_certificate.cc index 1edc393e6c..04ae99685d 100644 --- a/rtc_base/rtc_certificate.cc +++ b/rtc_base/rtc_certificate.cc @@ -64,7 +64,7 @@ RTCCertificatePEM RTCCertificate::ToPEM() const { scoped_refptr RTCCertificate::FromPEM( const RTCCertificatePEM& pem) { std::unique_ptr identity( - SSLIdentity::FromPEMStrings(pem.private_key(), pem.certificate())); + SSLIdentity::CreateFromPEMStrings(pem.private_key(), pem.certificate())); if (!identity) return nullptr; return new RefCountedObject(identity.release()); diff --git a/rtc_base/rtc_certificate_generator.cc b/rtc_base/rtc_certificate_generator.cc index cd9cccedf7..4c9d378dd2 100644 --- a/rtc_base/rtc_certificate_generator.cc +++ b/rtc_base/rtc_certificate_generator.cc @@ -109,9 +109,9 @@ scoped_refptr RTCCertificateGenerator::GenerateCertificate( return nullptr; } - SSLIdentity* identity = nullptr; + std::unique_ptr identity; if (!expires_ms) { - identity = SSLIdentity::Generate(kIdentityName, key_params); + identity = SSLIdentity::Create(kIdentityName, key_params); } else { uint64_t expires_s = *expires_ms / 1000; // Limit the expiration time to something reasonable (a year). This was @@ -123,14 +123,12 @@ scoped_refptr RTCCertificateGenerator::GenerateCertificate( // |SSLIdentity::Generate| should stop relying on |time_t|. // See bugs.webrtc.org/5720. time_t cert_lifetime_s = static_cast(expires_s); - identity = SSLIdentity::GenerateWithExpiration(kIdentityName, key_params, - cert_lifetime_s); + identity = SSLIdentity::Create(kIdentityName, key_params, cert_lifetime_s); } if (!identity) { return nullptr; } - std::unique_ptr identity_sptr(identity); - return RTCCertificate::Create(std::move(identity_sptr)); + return RTCCertificate::Create(std::move(identity)); } RTCCertificateGenerator::RTCCertificateGenerator(Thread* signaling_thread, diff --git a/rtc_base/rtc_certificate_unittest.cc b/rtc_base/rtc_certificate_unittest.cc index 1150eee0ab..96bd67ba85 100644 --- a/rtc_base/rtc_certificate_unittest.cc +++ b/rtc_base/rtc_certificate_unittest.cc @@ -33,7 +33,7 @@ class RTCCertificateTest : public ::testing::Test { protected: scoped_refptr GenerateECDSA() { std::unique_ptr identity( - SSLIdentity::Generate(kTestCertCommonName, KeyParams::ECDSA())); + SSLIdentity::Create(kTestCertCommonName, KeyParams::ECDSA())); RTC_CHECK(identity); return RTCCertificate::Create(std::move(identity)); } @@ -78,7 +78,7 @@ class RTCCertificateTest : public ::testing::Test { // is fast to generate. params.key_params = KeyParams::ECDSA(); - std::unique_ptr identity(SSLIdentity::GenerateForTest(params)); + std::unique_ptr identity(SSLIdentity::CreateForTest(params)); return RTCCertificate::Create(std::move(identity)); } }; diff --git a/rtc_base/ssl_adapter.h b/rtc_base/ssl_adapter.h index e0ed81eaf3..f72871af31 100644 --- a/rtc_base/ssl_adapter.h +++ b/rtc_base/ssl_adapter.h @@ -69,7 +69,9 @@ class SSLAdapter : public AsyncSocketAdapter { virtual void SetCertVerifier(SSLCertificateVerifier* ssl_cert_verifier) = 0; // Set the certificate this socket will present to incoming clients. - virtual void SetIdentity(SSLIdentity* identity) = 0; + // Takes ownership of |identity|. + RTC_DEPRECATED virtual void SetIdentity(SSLIdentity* identity) = 0; + virtual void SetIdentity(std::unique_ptr identity) = 0; // Choose whether the socket acts as a server socket or client socket. virtual void SetRole(SSLRole role) = 0; diff --git a/rtc_base/ssl_adapter_unittest.cc b/rtc_base/ssl_adapter_unittest.cc index 3fa12217f7..fbbde78a5a 100644 --- a/rtc_base/ssl_adapter_unittest.cc +++ b/rtc_base/ssl_adapter_unittest.cc @@ -12,6 +12,7 @@ #include #include +#include "absl/memory/memory.h" #include "rtc_base/gunit.h" #include "rtc_base/ip_address.h" #include "rtc_base/message_digest.h" @@ -163,7 +164,7 @@ class SSLAdapterTestDummyServer : public sigslot::has_slots<> { const rtc::KeyParams& key_params) : ssl_mode_(ssl_mode) { // Generate a key pair and a certificate for this host. - ssl_identity_.reset(rtc::SSLIdentity::Generate(GetHostname(), key_params)); + ssl_identity_ = rtc::SSLIdentity::Create(GetHostname(), key_params); server_socket_.reset(CreateSocket(ssl_mode_)); @@ -254,9 +255,8 @@ class SSLAdapterTestDummyServer : public sigslot::has_slots<> { private: void DoHandshake(rtc::AsyncSocket* socket) { - rtc::SocketStream* stream = new rtc::SocketStream(socket); - - ssl_stream_adapter_.reset(rtc::SSLStreamAdapter::Create(stream)); + ssl_stream_adapter_ = rtc::SSLStreamAdapter::Create( + std::make_unique(socket)); ssl_stream_adapter_->SetMode(ssl_mode_); ssl_stream_adapter_->SetServerRole(); @@ -268,7 +268,7 @@ class SSLAdapterTestDummyServer : public sigslot::has_slots<> { // Accordingly, we must disable client authentication here. ssl_stream_adapter_->SetClientAuthEnabledForTesting(false); - ssl_stream_adapter_->SetIdentity(ssl_identity_->GetReference()); + ssl_stream_adapter_->SetIdentity(ssl_identity_->Clone()); // Set a bogus peer certificate digest. unsigned char digest[20]; diff --git a/rtc_base/ssl_identity.cc b/rtc_base/ssl_identity.cc index 64c0f67297..410bb6118e 100644 --- a/rtc_base/ssl_identity.cc +++ b/rtc_base/ssl_identity.cc @@ -209,6 +209,51 @@ std::string SSLIdentity::DerToPem(const std::string& pem_type, return result.Release(); } +// static +std::unique_ptr SSLIdentity::Create(const std::string& common_name, + const KeyParams& key_param, + time_t certificate_lifetime) { + return OpenSSLIdentity::CreateWithExpiration(common_name, key_param, + certificate_lifetime); +} + +// static +std::unique_ptr SSLIdentity::Create(const std::string& common_name, + const KeyParams& key_param) { + return OpenSSLIdentity::CreateWithExpiration( + common_name, key_param, kDefaultCertificateLifetimeInSeconds); +} + +// static +std::unique_ptr SSLIdentity::Create(const std::string& common_name, + KeyType key_type) { + return OpenSSLIdentity::CreateWithExpiration( + common_name, KeyParams(key_type), kDefaultCertificateLifetimeInSeconds); +} + +// static +std::unique_ptr SSLIdentity::CreateForTest( + const SSLIdentityParams& params) { + return OpenSSLIdentity::CreateForTest(params); +} + +// Construct an identity from a private key and a certificate. +// static +std::unique_ptr SSLIdentity::CreateFromPEMStrings( + const std::string& private_key, + const std::string& certificate) { + return OpenSSLIdentity::CreateFromPEMStrings(private_key, certificate); +} + +// Construct an identity from a private key and a certificate chain. +// static +std::unique_ptr SSLIdentity::CreateFromPEMChainStrings( + const std::string& private_key, + const std::string& certificate_chain) { + return OpenSSLIdentity::CreateFromPEMChainStrings(private_key, + certificate_chain); +} + // static SSLIdentity* SSLIdentity::GenerateWithExpiration(const std::string& common_name, const KeyParams& key_params, diff --git a/rtc_base/ssl_identity.h b/rtc_base/ssl_identity.h index 30e456b24e..ae4dbea99b 100644 --- a/rtc_base/ssl_identity.h +++ b/rtc_base/ssl_identity.h @@ -15,8 +15,10 @@ #include #include +#include #include +#include "rtc_base/deprecation.h" #include "rtc_base/system/rtc_export.h" namespace rtc { @@ -107,26 +109,52 @@ class RTC_EXPORT SSLIdentity { // should be a non-negative number. // Returns null on failure. // Caller is responsible for freeing the returned object. - static SSLIdentity* GenerateWithExpiration(const std::string& common_name, + static std::unique_ptr Create(const std::string& common_name, const KeyParams& key_param, time_t certificate_lifetime); - static SSLIdentity* Generate(const std::string& common_name, - const KeyParams& key_param); - static SSLIdentity* Generate(const std::string& common_name, - KeyType key_type); + static std::unique_ptr Create(const std::string& common_name, + const KeyParams& key_param); + static std::unique_ptr Create(const std::string& common_name, + KeyType key_type); + static std::unique_ptr CreateForTest( + const SSLIdentityParams& params); + + // Construct an identity from a private key and a certificate. + static std::unique_ptr CreateFromPEMStrings( + const std::string& private_key, + const std::string& certificate); + + // Construct an identity from a private key and a certificate chain. + static std::unique_ptr CreateFromPEMChainStrings( + const std::string& private_key, + const std::string& certificate_chain); + + // Old versions of Create(). These return a pointer, but still require the + // caller to take ownership. + RTC_DEPRECATED static SSLIdentity* GenerateWithExpiration( + const std::string& common_name, + const KeyParams& key_param, + time_t certificate_lifetime); + RTC_DEPRECATED static SSLIdentity* Generate(const std::string& common_name, + const KeyParams& key_param); + RTC_DEPRECATED static SSLIdentity* Generate(const std::string& common_name, + KeyType key_type); // Generates an identity with the specified validity period. // TODO(torbjorng): Now that Generate() accepts relevant params, make tests // use that instead of this function. - static SSLIdentity* GenerateForTest(const SSLIdentityParams& params); + RTC_DEPRECATED static SSLIdentity* GenerateForTest( + const SSLIdentityParams& params); // Construct an identity from a private key and a certificate. - static SSLIdentity* FromPEMStrings(const std::string& private_key, - const std::string& certificate); + RTC_DEPRECATED static SSLIdentity* FromPEMStrings( + const std::string& private_key, + const std::string& certificate); // Construct an identity from a private key and a certificate chain. - static SSLIdentity* FromPEMChainStrings(const std::string& private_key, - const std::string& certificate_chain); + RTC_DEPRECATED static SSLIdentity* FromPEMChainStrings( + const std::string& private_key, + const std::string& certificate_chain); virtual ~SSLIdentity() {} @@ -134,7 +162,8 @@ class RTC_EXPORT SSLIdentity { // identity information. // Caller is responsible for freeing the returned object. // TODO(hbos,torbjorng): Rename to a less confusing name. - virtual SSLIdentity* GetReference() const = 0; + RTC_DEPRECATED virtual SSLIdentity* GetReference() const = 0; + std::unique_ptr Clone() const { return CloneInternal(); } // Returns a temporary reference to the end-entity (leaf) certificate. virtual const SSLCertificate& certificate() const = 0; @@ -150,6 +179,9 @@ class RTC_EXPORT SSLIdentity { static std::string DerToPem(const std::string& pem_type, const unsigned char* data, size_t length); + + protected: + virtual std::unique_ptr CloneInternal() const = 0; }; bool operator==(const SSLIdentity& a, const SSLIdentity& b); diff --git a/rtc_base/ssl_identity_unittest.cc b/rtc_base/ssl_identity_unittest.cc index 8e4d02db41..0d9d0fd859 100644 --- a/rtc_base/ssl_identity_unittest.cc +++ b/rtc_base/ssl_identity_unittest.cc @@ -194,10 +194,10 @@ IdentityAndInfo CreateFakeIdentityAndInfoFromDers( class SSLIdentityTest : public ::testing::Test { public: void SetUp() override { - identity_rsa1_.reset(SSLIdentity::Generate("test1", rtc::KT_RSA)); - identity_rsa2_.reset(SSLIdentity::Generate("test2", rtc::KT_RSA)); - identity_ecdsa1_.reset(SSLIdentity::Generate("test3", rtc::KT_ECDSA)); - identity_ecdsa2_.reset(SSLIdentity::Generate("test4", rtc::KT_ECDSA)); + identity_rsa1_ = SSLIdentity::Create("test1", rtc::KT_RSA); + identity_rsa2_ = SSLIdentity::Create("test2", rtc::KT_RSA); + identity_ecdsa1_ = SSLIdentity::Create("test3", rtc::KT_ECDSA); + identity_ecdsa2_ = SSLIdentity::Create("test4", rtc::KT_ECDSA); ASSERT_TRUE(identity_rsa1_); ASSERT_TRUE(identity_rsa2_); @@ -303,8 +303,8 @@ class SSLIdentityTest : public ::testing::Test { std::string priv_pem = identity.PrivateKeyToPEMString(); std::string publ_pem = identity.PublicKeyToPEMString(); std::string cert_pem = identity.certificate().ToPEMString(); - std::unique_ptr clone( - SSLIdentity::FromPEMStrings(priv_pem, cert_pem)); + std::unique_ptr clone = + SSLIdentity::CreateFromPEMStrings(priv_pem, cert_pem); EXPECT_TRUE(clone); // Make sure the clone is identical to the original. @@ -390,7 +390,7 @@ TEST_F(SSLIdentityTest, IdentityComparison) { TEST_F(SSLIdentityTest, FromPEMStringsRSA) { std::unique_ptr identity( - SSLIdentity::FromPEMStrings(kRSA_PRIVATE_KEY_PEM, kRSA_CERT_PEM)); + SSLIdentity::CreateFromPEMStrings(kRSA_PRIVATE_KEY_PEM, kRSA_CERT_PEM)); EXPECT_TRUE(identity); EXPECT_EQ(kRSA_PRIVATE_KEY_PEM, identity->PrivateKeyToPEMString()); EXPECT_EQ(kRSA_PUBLIC_KEY_PEM, identity->PublicKeyToPEMString()); @@ -398,8 +398,8 @@ TEST_F(SSLIdentityTest, FromPEMStringsRSA) { } TEST_F(SSLIdentityTest, FromPEMStringsEC) { - std::unique_ptr identity( - SSLIdentity::FromPEMStrings(kECDSA_PRIVATE_KEY_PEM, kECDSA_CERT_PEM)); + std::unique_ptr identity(SSLIdentity::CreateFromPEMStrings( + kECDSA_PRIVATE_KEY_PEM, kECDSA_CERT_PEM)); EXPECT_TRUE(identity); EXPECT_EQ(kECDSA_PRIVATE_KEY_PEM, identity->PrivateKeyToPEMString()); EXPECT_EQ(kECDSA_PUBLIC_KEY_PEM, identity->PublicKeyToPEMString()); @@ -433,7 +433,7 @@ TEST_F(SSLIdentityTest, GetSignatureDigestAlgorithm) { TEST_F(SSLIdentityTest, SSLCertificateGetStatsRSA) { std::unique_ptr identity( - SSLIdentity::FromPEMStrings(kRSA_PRIVATE_KEY_PEM, kRSA_CERT_PEM)); + SSLIdentity::CreateFromPEMStrings(kRSA_PRIVATE_KEY_PEM, kRSA_CERT_PEM)); std::unique_ptr stats = identity->certificate().GetStats(); EXPECT_EQ(stats->fingerprint, kRSA_FINGERPRINT); @@ -443,8 +443,8 @@ TEST_F(SSLIdentityTest, SSLCertificateGetStatsRSA) { } TEST_F(SSLIdentityTest, SSLCertificateGetStatsECDSA) { - std::unique_ptr identity( - SSLIdentity::FromPEMStrings(kECDSA_PRIVATE_KEY_PEM, kECDSA_CERT_PEM)); + std::unique_ptr identity(SSLIdentity::CreateFromPEMStrings( + kECDSA_PRIVATE_KEY_PEM, kECDSA_CERT_PEM)); std::unique_ptr stats = identity->certificate().GetStats(); EXPECT_EQ(stats->fingerprint, kECDSA_FINGERPRINT); @@ -580,14 +580,13 @@ class SSLIdentityExpirationTest : public ::testing::Test { time_t lifetime = rtc::CreateRandomId() % (0x80000000 - time_before_generation); rtc::KeyParams key_params = rtc::KeyParams::ECDSA(rtc::EC_NIST_P256); - SSLIdentity* identity = - rtc::SSLIdentity::GenerateWithExpiration("", key_params, lifetime); + auto identity = + rtc::SSLIdentity::Create("", key_params, lifetime); time_t time_after_generation = time(nullptr); EXPECT_LE(time_before_generation + lifetime, identity->certificate().CertificateExpirationTime()); EXPECT_GE(time_after_generation + lifetime, identity->certificate().CertificateExpirationTime()); - delete identity; } } }; diff --git a/rtc_base/ssl_stream_adapter.cc b/rtc_base/ssl_stream_adapter.cc index 372c37ff0d..354622e6f0 100644 --- a/rtc_base/ssl_stream_adapter.cc +++ b/rtc_base/ssl_stream_adapter.cc @@ -10,6 +10,7 @@ #include "rtc_base/ssl_stream_adapter.h" +#include "absl/memory/memory.h" #include "rtc_base/openssl_stream_adapter.h" /////////////////////////////////////////////////////////////////////////////// @@ -89,12 +90,13 @@ bool IsGcmCryptoSuiteName(const std::string& crypto_suite) { crypto_suite == CS_AEAD_AES_128_GCM); } -SSLStreamAdapter* SSLStreamAdapter::Create(StreamInterface* stream) { - return new OpenSSLStreamAdapter(stream); +std::unique_ptr SSLStreamAdapter::Create( + std::unique_ptr stream) { + return std::make_unique(std::move(stream)); } -SSLStreamAdapter::SSLStreamAdapter(StreamInterface* stream) - : StreamAdapterInterface(stream) {} +SSLStreamAdapter::SSLStreamAdapter(std::unique_ptr stream) + : StreamAdapterInterface(stream.release()) {} SSLStreamAdapter::~SSLStreamAdapter() {} diff --git a/rtc_base/ssl_stream_adapter.h b/rtc_base/ssl_stream_adapter.h index 2c317110a3..b5756a4322 100644 --- a/rtc_base/ssl_stream_adapter.h +++ b/rtc_base/ssl_stream_adapter.h @@ -17,6 +17,8 @@ #include #include +#include "absl/memory/memory.h" +#include "rtc_base/deprecation.h" #include "rtc_base/ssl_certificate.h" #include "rtc_base/ssl_identity.h" #include "rtc_base/stream.h" @@ -122,15 +124,23 @@ class SSLStreamAdapter : public StreamAdapterInterface { // Instantiate an SSLStreamAdapter wrapping the given stream, // (using the selected implementation for the platform). // Caller is responsible for freeing the returned object. - static SSLStreamAdapter* Create(StreamInterface* stream); + static std::unique_ptr Create( + std::unique_ptr stream); + RTC_DEPRECATED static SSLStreamAdapter* Create(StreamInterface* stream) { + return Create(absl::WrapUnique(stream)).release(); + } - explicit SSLStreamAdapter(StreamInterface* stream); + explicit SSLStreamAdapter(std::unique_ptr stream); ~SSLStreamAdapter() override; // Specify our SSL identity: key and certificate. SSLStream takes ownership // of the SSLIdentity object and will free it when appropriate. Should be // called no more than once on a given SSLStream instance. - virtual void SetIdentity(SSLIdentity* identity) = 0; + virtual void SetIdentity(std::unique_ptr identity) = 0; + RTC_DEPRECATED virtual void SetIdentity(SSLIdentity* identity) { + SetIdentity(absl::WrapUnique(identity)); + } + virtual SSLIdentity* GetIdentityForTesting() const = 0; // Call this to indicate that we are to play the server role (or client role, // if the default argument is replaced by SSL_CLIENT). diff --git a/rtc_base/ssl_stream_adapter_unittest.cc b/rtc_base/ssl_stream_adapter_unittest.cc index e0ddafcec2..f6d20d1607 100644 --- a/rtc_base/ssl_stream_adapter_unittest.cc +++ b/rtc_base/ssl_stream_adapter_unittest.cc @@ -13,6 +13,7 @@ #include #include +#include "absl/memory/memory.h" #include "rtc_base/buffer_queue.h" #include "rtc_base/checks.h" #include "rtc_base/gunit.h" @@ -298,8 +299,6 @@ class SSLStreamAdapterTestBase : public ::testing::Test, server_key_type_(server_key_type), client_stream_(nullptr), server_stream_(nullptr), - client_identity_(nullptr), - server_identity_(nullptr), delay_(0), mtu_(1460), loss_(0), @@ -320,23 +319,26 @@ class SSLStreamAdapterTestBase : public ::testing::Test, void SetUp() override { CreateStreams(); - client_ssl_.reset(rtc::SSLStreamAdapter::Create(client_stream_)); - server_ssl_.reset(rtc::SSLStreamAdapter::Create(server_stream_)); + client_ssl_ = + rtc::SSLStreamAdapter::Create(absl::WrapUnique(client_stream_)); + server_ssl_ = + rtc::SSLStreamAdapter::Create(absl::WrapUnique(server_stream_)); // Set up the slots client_ssl_->SignalEvent.connect(this, &SSLStreamAdapterTestBase::OnEvent); server_ssl_->SignalEvent.connect(this, &SSLStreamAdapterTestBase::OnEvent); + std::unique_ptr client_identity; if (!client_cert_pem_.empty() && !client_private_key_pem_.empty()) { - client_identity_ = rtc::SSLIdentity::FromPEMStrings( + client_identity = rtc::SSLIdentity::CreateFromPEMStrings( client_private_key_pem_, client_cert_pem_); } else { - client_identity_ = rtc::SSLIdentity::Generate("client", client_key_type_); + client_identity = rtc::SSLIdentity::Create("client", client_key_type_); } - server_identity_ = rtc::SSLIdentity::Generate("server", server_key_type_); + auto server_identity = rtc::SSLIdentity::Create("server", server_key_type_); - client_ssl_->SetIdentity(client_identity_); - server_ssl_->SetIdentity(server_identity_); + client_ssl_->SetIdentity(std::move(client_identity)); + server_ssl_->SetIdentity(std::move(server_identity)); } void TearDown() override { @@ -352,8 +354,10 @@ class SSLStreamAdapterTestBase : public ::testing::Test, void ResetIdentitiesWithValidity(int not_before, int not_after) { CreateStreams(); - client_ssl_.reset(rtc::SSLStreamAdapter::Create(client_stream_)); - server_ssl_.reset(rtc::SSLStreamAdapter::Create(server_stream_)); + client_ssl_ = + rtc::SSLStreamAdapter::Create(absl::WrapUnique(client_stream_)); + server_ssl_ = + rtc::SSLStreamAdapter::Create(absl::WrapUnique(server_stream_)); client_ssl_->SignalEvent.connect(this, &SSLStreamAdapterTestBase::OnEvent); server_ssl_->SignalEvent.connect(this, &SSLStreamAdapterTestBase::OnEvent); @@ -365,17 +369,17 @@ class SSLStreamAdapterTestBase : public ::testing::Test, client_params.common_name = "client"; client_params.not_before = now + not_before; client_params.not_after = now + not_after; - client_identity_ = rtc::SSLIdentity::GenerateForTest(client_params); + auto client_identity = rtc::SSLIdentity::CreateForTest(client_params); rtc::SSLIdentityParams server_params; server_params.key_params = rtc::KeyParams(rtc::KT_DEFAULT); server_params.common_name = "server"; server_params.not_before = now + not_before; server_params.not_after = now + not_after; - server_identity_ = rtc::SSLIdentity::GenerateForTest(server_params); + auto server_identity = rtc::SSLIdentity::CreateForTest(server_params); - client_ssl_->SetIdentity(client_identity_); - server_ssl_->SetIdentity(server_identity_); + client_ssl_->SetIdentity(std::move(client_identity)); + server_ssl_->SetIdentity(std::move(server_identity)); } virtual void OnEvent(rtc::StreamInterface* stream, int sig, int err) { @@ -404,10 +408,10 @@ class SSLStreamAdapterTestBase : public ::testing::Test, RTC_LOG(LS_INFO) << "Setting peer identities by digest"; - rv = server_identity_->certificate().ComputeDigest( + rv = server_identity()->certificate().ComputeDigest( rtc::DIGEST_SHA_1, server_digest, 20, &server_digest_len); ASSERT_TRUE(rv); - rv = client_identity_->certificate().ComputeDigest( + rv = client_identity()->certificate().ComputeDigest( rtc::DIGEST_SHA_1, client_digest, 20, &client_digest_len); ASSERT_TRUE(rv); @@ -634,6 +638,19 @@ class SSLStreamAdapterTestBase : public ::testing::Test, virtual void TestTransfer(int size) = 0; protected: + rtc::SSLIdentity* client_identity() const { + if (!client_ssl_) { + return nullptr; + } + return client_ssl_->GetIdentityForTesting(); + } + rtc::SSLIdentity* server_identity() const { + if (!server_ssl_) { + return nullptr; + } + return server_ssl_->GetIdentityForTesting(); + } + std::string client_cert_pem_; std::string client_private_key_pem_; rtc::KeyParams client_key_type_; @@ -642,8 +659,6 @@ class SSLStreamAdapterTestBase : public ::testing::Test, SSLDummyStreamBase* server_stream_; // freed by server_ssl_ destructor std::unique_ptr client_ssl_; std::unique_ptr server_ssl_; - rtc::SSLIdentity* client_identity_; // freed by client_ssl_ destructor - rtc::SSLIdentity* server_identity_; // freed by server_ssl_ destructor int delay_; size_t mtu_; int loss_; @@ -939,8 +954,10 @@ class SSLStreamAdapterTestDTLSCertChain : public SSLStreamAdapterTestDTLS { void SetUp() override { CreateStreams(); - client_ssl_.reset(rtc::SSLStreamAdapter::Create(client_stream_)); - server_ssl_.reset(rtc::SSLStreamAdapter::Create(server_stream_)); + client_ssl_ = + rtc::SSLStreamAdapter::Create(absl::WrapUnique(client_stream_)); + server_ssl_ = + rtc::SSLStreamAdapter::Create(absl::WrapUnique(server_stream_)); // Set up the slots client_ssl_->SignalEvent.connect( @@ -950,14 +967,15 @@ class SSLStreamAdapterTestDTLSCertChain : public SSLStreamAdapterTestDTLS { reinterpret_cast(this), &SSLStreamAdapterTestBase::OnEvent); + std::unique_ptr client_identity; if (!client_cert_pem_.empty() && !client_private_key_pem_.empty()) { - client_identity_ = rtc::SSLIdentity::FromPEMStrings( + client_identity = rtc::SSLIdentity::CreateFromPEMStrings( client_private_key_pem_, client_cert_pem_); } else { - client_identity_ = rtc::SSLIdentity::Generate("client", client_key_type_); + client_identity = rtc::SSLIdentity::Create("client", client_key_type_); } - client_ssl_->SetIdentity(client_identity_); + client_ssl_->SetIdentity(std::move(client_identity)); } }; @@ -975,13 +993,13 @@ TEST_P(SSLStreamAdapterTestTLS, GetPeerCertChainWithOneCertificate) { ASSERT_NE(nullptr, cert_chain); EXPECT_EQ(1u, cert_chain->GetSize()); EXPECT_EQ(cert_chain->Get(0).ToPEMString(), - server_identity_->certificate().ToPEMString()); + server_identity()->certificate().ToPEMString()); } TEST_F(SSLStreamAdapterTestDTLSCertChain, TwoCertHandshake) { - server_identity_ = rtc::SSLIdentity::FromPEMChainStrings( + auto server_identity = rtc::SSLIdentity::CreateFromPEMChainStrings( kRSA_PRIVATE_KEY_PEM, std::string(kCERT_PEM) + kCACert); - server_ssl_->SetIdentity(server_identity_); + server_ssl_->SetIdentity(std::move(server_identity)); TestHandshake(); std::unique_ptr peer_cert_chain = client_ssl_->GetPeerSSLCertChain(); @@ -992,11 +1010,8 @@ TEST_F(SSLStreamAdapterTestDTLSCertChain, TwoCertHandshake) { } TEST_F(SSLStreamAdapterTestDTLSCertChain, TwoCertHandshakeWithCopy) { - std::unique_ptr identity( - rtc::SSLIdentity::FromPEMChainStrings(kRSA_PRIVATE_KEY_PEM, - std::string(kCERT_PEM) + kCACert)); - server_identity_ = identity->GetReference(); - server_ssl_->SetIdentity(server_identity_); + server_ssl_->SetIdentity(rtc::SSLIdentity::CreateFromPEMChainStrings( + kRSA_PRIVATE_KEY_PEM, std::string(kCERT_PEM) + kCACert)); TestHandshake(); std::unique_ptr peer_cert_chain = client_ssl_->GetPeerSSLCertChain(); @@ -1007,9 +1022,8 @@ TEST_F(SSLStreamAdapterTestDTLSCertChain, TwoCertHandshakeWithCopy) { } TEST_F(SSLStreamAdapterTestDTLSCertChain, ThreeCertHandshake) { - server_identity_ = rtc::SSLIdentity::FromPEMChainStrings( - kRSA_PRIVATE_KEY_PEM, std::string(kCERT_PEM) + kIntCert1 + kCACert); - server_ssl_->SetIdentity(server_identity_); + server_ssl_->SetIdentity(rtc::SSLIdentity::CreateFromPEMChainStrings( + kRSA_PRIVATE_KEY_PEM, std::string(kCERT_PEM) + kIntCert1 + kCACert)); TestHandshake(); std::unique_ptr peer_cert_chain = client_ssl_->GetPeerSSLCertChain(); @@ -1075,7 +1089,7 @@ TEST_P(SSLStreamAdapterTestTLS, bool rv; rtc::SSLPeerCertificateDigestError err; - rv = server_identity_->certificate().ComputeDigest( + rv = server_identity()->certificate().ComputeDigest( rtc::DIGEST_SHA_1, server_digest, 20, &server_digest_len); ASSERT_TRUE(rv); @@ -1093,7 +1107,7 @@ TEST_P(SSLStreamAdapterTestTLS, TestSetPeerCertificateDigestWithInvalidLength) { bool rv; rtc::SSLPeerCertificateDigestError err; - rv = server_identity_->certificate().ComputeDigest( + rv = server_identity()->certificate().ComputeDigest( rtc::DIGEST_SHA_1, server_digest, 20, &server_digest_len); ASSERT_TRUE(rv); @@ -1476,24 +1490,26 @@ class SSLStreamAdapterTestDTLSLegacyProtocols webrtc::test::ScopedFieldTrials trial(experiment); client_stream_ = new SSLDummyStreamDTLS(this, "c2s", &client_buffer_, &server_buffer_); - client_ssl_.reset(rtc::SSLStreamAdapter::Create(client_stream_)); + client_ssl_ = + rtc::SSLStreamAdapter::Create(absl::WrapUnique(client_stream_)); client_ssl_->SignalEvent.connect( static_cast(this), &SSLStreamAdapterTestBase::OnEvent); - client_identity_ = rtc::SSLIdentity::Generate("client", client_key_type_); - client_ssl_->SetIdentity(client_identity_); + auto client_identity = rtc::SSLIdentity::Create("client", client_key_type_); + client_ssl_->SetIdentity(std::move(client_identity)); } void ConfigureServer(std::string experiment) { // webrtc::test::ScopedFieldTrials trial(experiment); server_stream_ = new SSLDummyStreamDTLS(this, "s2c", &server_buffer_, &client_buffer_); - server_ssl_.reset(rtc::SSLStreamAdapter::Create(server_stream_)); + server_ssl_ = + rtc::SSLStreamAdapter::Create(absl::WrapUnique(server_stream_)); server_ssl_->SignalEvent.connect( static_cast(this), &SSLStreamAdapterTestBase::OnEvent); - server_identity_ = rtc::SSLIdentity::Generate("server", server_key_type_); - server_ssl_->SetIdentity(server_identity_); + server_ssl_->SetIdentity( + rtc::SSLIdentity::Create("server", server_key_type_)); } }; diff --git a/test/peer_scenario/scenario_connection.cc b/test/peer_scenario/scenario_connection.cc index d6d2880920..92082f5097 100644 --- a/test/peer_scenario/scenario_connection.cc +++ b/test/peer_scenario/scenario_connection.cc @@ -85,7 +85,7 @@ ScenarioIceConnectionImpl::ScenarioIceConnectionImpl( signaling_thread_(rtc::Thread::Current()), network_thread_(manager_->network_thread()), certificate_(rtc::RTCCertificate::Create( - absl::WrapUnique(rtc::SSLIdentity::Generate("", ::rtc::KT_DEFAULT)))), + rtc::SSLIdentity::Create("", ::rtc::KT_DEFAULT))), transport_description_( /*transport_options*/ {}, rtc::CreateRandomString(cricket::ICE_UFRAG_LENGTH),