diff --git a/p2p/base/fakedtlstransport.h b/p2p/base/fakedtlstransport.h index 44a12a597b..c4f9d2c239 100644 --- a/p2p/base/fakedtlstransport.h +++ b/p2p/base/fakedtlstransport.h @@ -179,10 +179,8 @@ class FakeDtlsTransport : public DtlsTransportInternal { return local_cert_; } std::unique_ptr GetRemoteSSLCertChain() const override { - if (!remote_cert_) { - return nullptr; - } - return absl::make_unique(remote_cert_->Clone()); + return remote_cert_ ? absl::make_unique(remote_cert_) + : nullptr; } bool ExportKeyingMaterial(const std::string& label, const uint8_t* context, diff --git a/pc/peerconnection.cc b/pc/peerconnection.cc index cce75286f2..0861246151 100644 --- a/pc/peerconnection.cc +++ b/pc/peerconnection.cc @@ -3144,7 +3144,7 @@ PeerConnection::GetRemoteAudioSSLCertificate() { if (!chain || !chain->GetSize()) { return nullptr; } - return chain->Get(0).Clone(); + return chain->Get(0).GetUniqueReference(); } std::unique_ptr diff --git a/pc/rtcstatscollector_unittest.cc b/pc/rtcstatscollector_unittest.cc index e565e4610f..7404d492ef 100644 --- a/pc/rtcstatscollector_unittest.cc +++ b/pc/rtcstatscollector_unittest.cc @@ -703,7 +703,8 @@ TEST_F(RTCStatsCollectorTest, CollectRTCCertificateStatsSingle) { CreateFakeCertificateAndInfoFromDers( std::vector({"(remote) single certificate"})); pc_->SetRemoteCertChain( - kTransportName, remote_certinfo->certificate->ssl_cert_chain().Clone()); + kTransportName, + remote_certinfo->certificate->ssl_cert_chain().UniqueCopy()); rtc::scoped_refptr report = stats_->GetStatsReport(); @@ -817,7 +818,7 @@ TEST_F(RTCStatsCollectorTest, CollectRTCCertificateStatsMultiple) { std::vector({"(remote) audio"})); pc_->SetRemoteCertChain( kAudioTransport, - audio_remote_certinfo->certificate->ssl_cert_chain().Clone()); + audio_remote_certinfo->certificate->ssl_cert_chain().UniqueCopy()); pc_->AddVideoChannel("video", kVideoTransport); std::unique_ptr video_local_certinfo = @@ -829,7 +830,7 @@ TEST_F(RTCStatsCollectorTest, CollectRTCCertificateStatsMultiple) { std::vector({"(remote) video"})); pc_->SetRemoteCertChain( kVideoTransport, - video_remote_certinfo->certificate->ssl_cert_chain().Clone()); + video_remote_certinfo->certificate->ssl_cert_chain().UniqueCopy()); rtc::scoped_refptr report = stats_->GetStatsReport(); ExpectReportContainsCertificateInfo(report, *audio_local_certinfo); @@ -853,7 +854,8 @@ TEST_F(RTCStatsCollectorTest, CollectRTCCertificateStatsChain) { "(remote) another", "(remote) chain"}); pc_->SetRemoteCertChain( - kTransportName, remote_certinfo->certificate->ssl_cert_chain().Clone()); + kTransportName, + remote_certinfo->certificate->ssl_cert_chain().UniqueCopy()); rtc::scoped_refptr report = stats_->GetStatsReport(); ExpectReportContainsCertificateInfo(report, *local_certinfo); @@ -1954,7 +1956,8 @@ TEST_F(RTCStatsCollectorTest, CollectRTCTransportStats) { CreateFakeCertificateAndInfoFromDers( {"(remote) local", "(remote) chain"}); pc_->SetRemoteCertChain( - kTransportName, remote_certinfo->certificate->ssl_cert_chain().Clone()); + kTransportName, + remote_certinfo->certificate->ssl_cert_chain().UniqueCopy()); report = stats_->GetFreshStatsReport(); diff --git a/pc/statscollector_unittest.cc b/pc/statscollector_unittest.cc index cbd7cc37dd..7c61b385a9 100644 --- a/pc/statscollector_unittest.cc +++ b/pc/statscollector_unittest.cc @@ -642,7 +642,7 @@ class StatsCollectorTest : public testing::Test { std::unique_ptr(local_identity.GetReference()))); pc->SetLocalCertificate(kTransportName, local_certificate); pc->SetRemoteCertChain(kTransportName, - remote_identity.cert_chain().Clone()); + remote_identity.cert_chain().UniqueCopy()); stats->UpdateStats(PeerConnectionInterface::kStatsOutputLevelStandard); diff --git a/pc/test/fakepeerconnectionforstats.h b/pc/test/fakepeerconnectionforstats.h index af86639eca..ae329e4450 100644 --- a/pc/test/fakepeerconnectionforstats.h +++ b/pc/test/fakepeerconnectionforstats.h @@ -319,7 +319,7 @@ class FakePeerConnectionForStats : public FakePeerConnectionBase { const std::string& transport_name) override { auto it = remote_cert_chains_by_transport_.find(transport_name); if (it != remote_cert_chains_by_transport_.end()) { - return it->second->Clone(); + return it->second->UniqueCopy(); } else { return nullptr; } diff --git a/rtc_base/fakesslidentity.cc b/rtc_base/fakesslidentity.cc index 62ac9dd020..80a3e78887 100644 --- a/rtc_base/fakesslidentity.cc +++ b/rtc_base/fakesslidentity.cc @@ -29,8 +29,8 @@ FakeSSLCertificate::FakeSSLCertificate(const FakeSSLCertificate&) = default; FakeSSLCertificate::~FakeSSLCertificate() = default; -std::unique_ptr FakeSSLCertificate::Clone() const { - return absl::make_unique(*this); +FakeSSLCertificate* FakeSSLCertificate::GetReference() const { + return new FakeSSLCertificate(*this); } std::string FakeSSLCertificate::ToPEMString() const { @@ -83,10 +83,10 @@ FakeSSLIdentity::FakeSSLIdentity(const std::vector& pem_strings) { } FakeSSLIdentity::FakeSSLIdentity(const FakeSSLCertificate& cert) - : cert_chain_(absl::make_unique(cert.Clone())) {} + : cert_chain_(absl::make_unique(&cert)) {} FakeSSLIdentity::FakeSSLIdentity(const FakeSSLIdentity& o) - : cert_chain_(o.cert_chain_->Clone()) {} + : cert_chain_(o.cert_chain_->UniqueCopy()) {} FakeSSLIdentity::~FakeSSLIdentity() = default; diff --git a/rtc_base/fakesslidentity.h b/rtc_base/fakesslidentity.h index 9d5770ce28..4494a524ef 100644 --- a/rtc_base/fakesslidentity.h +++ b/rtc_base/fakesslidentity.h @@ -28,7 +28,7 @@ class FakeSSLCertificate : public SSLCertificate { ~FakeSSLCertificate() override; // SSLCertificate implementation. - std::unique_ptr Clone() const override; + FakeSSLCertificate* GetReference() const override; std::string ToPEMString() const override; void ToDER(Buffer* der_buffer) const override; int64_t CertificateExpirationTime() const override; diff --git a/rtc_base/opensslcertificate.cc b/rtc_base/opensslcertificate.cc index 92443a4458..ed67a8938e 100644 --- a/rtc_base/opensslcertificate.cc +++ b/rtc_base/opensslcertificate.cc @@ -130,11 +130,10 @@ static void PrintCert(X509* x509) { #endif OpenSSLCertificate::OpenSSLCertificate(X509* x509) : x509_(x509) { - RTC_DCHECK(x509_ != nullptr); - X509_up_ref(x509_); + AddReference(); } -std::unique_ptr OpenSSLCertificate::Generate( +OpenSSLCertificate* OpenSSLCertificate::Generate( OpenSSLKeyPair* key_pair, const SSLIdentityParams& params) { SSLIdentityParams actual_params(params); @@ -150,12 +149,12 @@ std::unique_ptr OpenSSLCertificate::Generate( #if !defined(NDEBUG) PrintCert(x509); #endif - auto ret = absl::make_unique(x509); + OpenSSLCertificate* ret = new OpenSSLCertificate(x509); X509_free(x509); return ret; } -std::unique_ptr OpenSSLCertificate::FromPEMString( +OpenSSLCertificate* OpenSSLCertificate::FromPEMString( const std::string& pem_string) { BIO* bio = BIO_new_mem_buf(const_cast(pem_string.c_str()), -1); if (!bio) @@ -168,7 +167,7 @@ std::unique_ptr OpenSSLCertificate::FromPEMString( if (!x509) return nullptr; - auto ret = absl::make_unique(x509); + OpenSSLCertificate* ret = new OpenSSLCertificate(x509); X509_free(x509); return ret; } @@ -250,8 +249,8 @@ OpenSSLCertificate::~OpenSSLCertificate() { X509_free(x509_); } -std::unique_ptr OpenSSLCertificate::Clone() const { - return absl::make_unique(x509_); +OpenSSLCertificate* OpenSSLCertificate::GetReference() const { + return new OpenSSLCertificate(x509_); } std::string OpenSSLCertificate::ToPEMString() const { @@ -290,6 +289,11 @@ void OpenSSLCertificate::ToDER(Buffer* der_buffer) const { BIO_free(bio); } +void OpenSSLCertificate::AddReference() const { + RTC_DCHECK(x509_ != nullptr); + X509_up_ref(x509_); +} + bool OpenSSLCertificate::operator==(const OpenSSLCertificate& other) const { return X509_cmp(x509_, other.x509_) == 0; } diff --git a/rtc_base/opensslcertificate.h b/rtc_base/opensslcertificate.h index 3b49f93ef5..b7ecc3b78d 100644 --- a/rtc_base/opensslcertificate.h +++ b/rtc_base/opensslcertificate.h @@ -36,15 +36,13 @@ class OpenSSLCertificate : public SSLCertificate { // OpenSSLCertificate share ownership. explicit OpenSSLCertificate(X509* x509); - static std::unique_ptr Generate( - OpenSSLKeyPair* key_pair, - const SSLIdentityParams& params); - static std::unique_ptr FromPEMString( - const std::string& pem_string); + static OpenSSLCertificate* Generate(OpenSSLKeyPair* key_pair, + const SSLIdentityParams& params); + static OpenSSLCertificate* FromPEMString(const std::string& pem_string); ~OpenSSLCertificate() override; - std::unique_ptr Clone() const override; + OpenSSLCertificate* GetReference() const override; X509* x509() const { return x509_; } @@ -71,6 +69,8 @@ class OpenSSLCertificate : public SSLCertificate { int64_t CertificateExpirationTime() const override; private: + void AddReference() const; + X509* x509_; // NOT OWNED RTC_DISALLOW_COPY_AND_ASSIGN(OpenSSLCertificate); }; diff --git a/rtc_base/opensslidentity.cc b/rtc_base/opensslidentity.cc index a5bbd5d72d..a8c6919779 100644 --- a/rtc_base/opensslidentity.cc +++ b/rtc_base/opensslidentity.cc @@ -316,7 +316,7 @@ const SSLCertChain& OpenSSLIdentity::cert_chain() const { OpenSSLIdentity* OpenSSLIdentity::GetReference() const { return new OpenSSLIdentity(absl::WrapUnique(key_pair_->GetReference()), - cert_chain_->Clone()); + absl::WrapUnique(cert_chain_->Copy())); } bool OpenSSLIdentity::ConfigureIdentity(SSL_CTX* ctx) { diff --git a/rtc_base/opensslstreamadapter.cc b/rtc_base/opensslstreamadapter.cc index 38b4bb670d..fd54a082a7 100644 --- a/rtc_base/opensslstreamadapter.cc +++ b/rtc_base/opensslstreamadapter.cc @@ -1091,7 +1091,7 @@ bool OpenSSLStreamAdapter::VerifyPeerCertificate() { std::unique_ptr OpenSSLStreamAdapter::GetPeerSSLCertChain() const { - return peer_cert_chain_ ? peer_cert_chain_->Clone() : nullptr; + return peer_cert_chain_ ? peer_cert_chain_->UniqueCopy() : nullptr; } int OpenSSLStreamAdapter::SSLVerifyCallback(X509_STORE_CTX* store, void* arg) { diff --git a/rtc_base/sslcertificate.cc b/rtc_base/sslcertificate.cc index 142543fe51..e40feec219 100644 --- a/rtc_base/sslcertificate.cc +++ b/rtc_base/sslcertificate.cc @@ -30,7 +30,7 @@ SSLCertificateStats::SSLCertificateStats( std::string&& fingerprint, std::string&& fingerprint_algorithm, std::string&& base64_certificate, - std::unique_ptr issuer) + std::unique_ptr&& issuer) : fingerprint(std::move(fingerprint)), fingerprint_algorithm(std::move(fingerprint_algorithm)), base64_certificate(std::move(base64_certificate)), @@ -70,30 +70,49 @@ std::unique_ptr SSLCertificate::GetStats() const { std::move(der_base64), nullptr); } +std::unique_ptr SSLCertificate::GetUniqueReference() const { + return absl::WrapUnique(GetReference()); +} + ////////////////////////////////////////////////////////////////////// // SSLCertChain ////////////////////////////////////////////////////////////////////// -SSLCertChain::SSLCertChain(std::unique_ptr single_cert) { - certs_.push_back(std::move(single_cert)); -} - SSLCertChain::SSLCertChain(std::vector> certs) : certs_(std::move(certs)) {} +SSLCertChain::SSLCertChain(const std::vector& certs) { + RTC_DCHECK(!certs.empty()); + certs_.resize(certs.size()); + std::transform( + certs.begin(), certs.end(), certs_.begin(), + [](const SSLCertificate* cert) -> std::unique_ptr { + return cert->GetUniqueReference(); + }); +} + +SSLCertChain::SSLCertChain(const SSLCertificate* cert) { + certs_.push_back(cert->GetUniqueReference()); +} + SSLCertChain::SSLCertChain(SSLCertChain&& rhs) = default; SSLCertChain& SSLCertChain::operator=(SSLCertChain&&) = default; -SSLCertChain::~SSLCertChain() = default; +SSLCertChain::~SSLCertChain() {} -std::unique_ptr SSLCertChain::Clone() const { +SSLCertChain* SSLCertChain::Copy() const { std::vector> new_certs(certs_.size()); - std::transform( - certs_.begin(), certs_.end(), new_certs.begin(), - [](const std::unique_ptr& cert) - -> std::unique_ptr { return cert->Clone(); }); - return absl::make_unique(std::move(new_certs)); + std::transform(certs_.begin(), certs_.end(), new_certs.begin(), + [](const std::unique_ptr& cert) + -> std::unique_ptr { + return cert->GetUniqueReference(); + }); + return new SSLCertChain(std::move(new_certs)); +} + +std::unique_ptr SSLCertChain::UniqueCopy() const { + return absl::WrapUnique(Copy()); } std::unique_ptr SSLCertChain::GetStats() const { @@ -115,8 +134,7 @@ std::unique_ptr SSLCertChain::GetStats() const { } // static -std::unique_ptr SSLCertificate::FromPEMString( - const std::string& pem_string) { +SSLCertificate* SSLCertificate::FromPEMString(const std::string& pem_string) { return OpenSSLCertificate::FromPEMString(pem_string); } diff --git a/rtc_base/sslcertificate.h b/rtc_base/sslcertificate.h index c04852f2f2..029404cf3e 100644 --- a/rtc_base/sslcertificate.h +++ b/rtc_base/sslcertificate.h @@ -28,7 +28,7 @@ struct SSLCertificateStats { SSLCertificateStats(std::string&& fingerprint, std::string&& fingerprint_algorithm, std::string&& base64_certificate, - std::unique_ptr issuer); + std::unique_ptr&& issuer); ~SSLCertificateStats(); std::string fingerprint; std::string fingerprint_algorithm; @@ -51,13 +51,17 @@ class SSLCertificate { // The length of the string representation of the certificate is // stored in *pem_length if it is non-null, and only if // parsing was successful. - static std::unique_ptr FromPEMString( - const std::string& pem_string); - virtual ~SSLCertificate() = default; + // Caller is responsible for freeing the returned object. + static SSLCertificate* FromPEMString(const std::string& pem_string); + virtual ~SSLCertificate() {} // Returns a new SSLCertificate object instance wrapping the same - // underlying certificate, including its chain if present. - virtual std::unique_ptr Clone() const = 0; + // underlying certificate, including its chain if present. Caller is + // responsible for freeing the returned object. Use GetUniqueReference + // instead. + virtual SSLCertificate* GetReference() const = 0; + + std::unique_ptr GetUniqueReference() const; // Returns a PEM encoded string representation of the certificate. virtual std::string ToPEMString() const = 0; @@ -90,8 +94,11 @@ class SSLCertificate { // SSLCertificate pointers. class SSLCertChain { public: - explicit SSLCertChain(std::unique_ptr single_cert); explicit SSLCertChain(std::vector> certs); + // These constructors copy the provided SSLCertificate(s), so the caller + // retains ownership. + explicit SSLCertChain(const std::vector& certs); + explicit SSLCertChain(const SSLCertificate* cert); // Allow move semantics for the object. SSLCertChain(SSLCertChain&&); SSLCertChain& operator=(SSLCertChain&&); @@ -105,8 +112,10 @@ class SSLCertChain { const SSLCertificate& Get(size_t pos) const { return *(certs_[pos]); } // Returns a new SSLCertChain object instance wrapping the same underlying - // certificate chain. - std::unique_ptr Clone() const; + // certificate chain. Caller is responsible for freeing the returned object. + SSLCertChain* Copy() const; + // Same as above, but returning a unique_ptr for convenience. + std::unique_ptr UniqueCopy() const; // Gets information (fingerprint, etc.) about this certificate chain. This is // used for certificate stats, see diff --git a/rtc_base/sslidentity_unittest.cc b/rtc_base/sslidentity_unittest.cc index 9560aaee4e..68b582839f 100644 --- a/rtc_base/sslidentity_unittest.cc +++ b/rtc_base/sslidentity_unittest.cc @@ -201,7 +201,7 @@ class SSLIdentityTest : public testing::Test { ASSERT_TRUE(identity_ecdsa1_); ASSERT_TRUE(identity_ecdsa2_); - test_cert_ = rtc::SSLCertificate::FromPEMString(kTestCertificate); + test_cert_.reset(rtc::SSLCertificate::FromPEMString(kTestCertificate)); ASSERT_TRUE(test_cert_); } diff --git a/rtc_base/sslstreamadapter_unittest.cc b/rtc_base/sslstreamadapter_unittest.cc index ff4c7a0d92..389b0eaaf1 100644 --- a/rtc_base/sslstreamadapter_unittest.cc +++ b/rtc_base/sslstreamadapter_unittest.cc @@ -588,7 +588,8 @@ class SSLStreamAdapterTestBase : public testing::Test, chain = client_ssl_->GetPeerSSLCertChain(); else chain = server_ssl_->GetPeerSSLCertChain(); - return (chain && chain->GetSize()) ? chain->Get(0).Clone() : nullptr; + return (chain && chain->GetSize()) ? chain->Get(0).GetUniqueReference() + : nullptr; } bool GetSslCipherSuite(bool client, int* retval) {