diff --git a/talk/app/webrtc/statscollector.cc b/talk/app/webrtc/statscollector.cc index 06c4b44b08..f5ff708a02 100644 --- a/talk/app/webrtc/statscollector.cc +++ b/talk/app/webrtc/statscollector.cc @@ -30,6 +30,8 @@ #include #include +#include "talk/base/base64.h" +#include "talk/base/scoped_ptr.h" #include "talk/session/media/channel.h" namespace webrtc { @@ -52,6 +54,7 @@ const char StatsReport::kStatsValueNameChannelId[] = "googChannelId"; const char StatsReport::kStatsValueNameCodecName[] = "googCodecName"; const char StatsReport::kStatsValueNameComponent[] = "googComponent"; const char StatsReport::kStatsValueNameContentName[] = "googContentName"; +const char StatsReport::kStatsValueNameDer[] = "googDerBase64"; // Echo metrics from the audio processing module. const char StatsReport::kStatsValueNameEchoCancellationQualityMin[] = "googEchoCancellationQualityMin"; @@ -64,6 +67,7 @@ const char StatsReport::kStatsValueNameEchoReturnLoss[] = const char StatsReport::kStatsValueNameEchoReturnLossEnhancement[] = "googEchoCancellationReturnLossEnhancement"; +const char StatsReport::kStatsValueNameFingerprint[] = "googFingerprint"; const char StatsReport::kStatsValueNameFirsReceived[] = "googFirsReceived"; const char StatsReport::kStatsValueNameFirsSent[] = "googFirsSent"; const char StatsReport::kStatsValueNameFrameHeightReceived[] = @@ -82,8 +86,11 @@ const char StatsReport::kStatsValueNameFrameWidthReceived[] = "googFrameWidthReceived"; const char StatsReport::kStatsValueNameFrameWidthSent[] = "googFrameWidthSent"; const char StatsReport::kStatsValueNameInitiator[] = "googInitiator"; +const char StatsReport::kStatsValueNameIssuerId[] = "googIssuerId"; const char StatsReport::kStatsValueNameJitterReceived[] = "googJitterReceived"; const char StatsReport::kStatsValueNameLocalAddress[] = "googLocalAddress"; +const char StatsReport::kStatsValueNameLocalCertificateId[] = + "googLocalCertificateId"; const char StatsReport::kStatsValueNameNacksReceived[] = "googNacksReceived"; const char StatsReport::kStatsValueNameNacksSent[] = "googNacksSent"; const char StatsReport::kStatsValueNamePacketsReceived[] = "packetsReceived"; @@ -91,6 +98,8 @@ const char StatsReport::kStatsValueNamePacketsSent[] = "packetsSent"; const char StatsReport::kStatsValueNamePacketsLost[] = "packetsLost"; const char StatsReport::kStatsValueNameReadable[] = "googReadable"; const char StatsReport::kStatsValueNameRemoteAddress[] = "googRemoteAddress"; +const char StatsReport::kStatsValueNameRemoteCertificateId[] = + "googRemoteCertificateId"; const char StatsReport::kStatsValueNameRetransmitBitrate[] = "googRetransmitBitrate"; const char StatsReport::kStatsValueNameRtt[] = "googRtt"; @@ -114,6 +123,7 @@ const char StatsReport::kStatsReportTypeIceCandidate[] = "iceCandidate"; const char StatsReport::kStatsReportTypeTransport[] = "googTransport"; const char StatsReport::kStatsReportTypeComponent[] = "googComponent"; const char StatsReport::kStatsReportTypeCandidatePair[] = "googCandidatePair"; +const char StatsReport::kStatsReportTypeCertificate[] = "googCertificate"; const char StatsReport::kStatsReportVideoBweId[] = "bweforvideo"; @@ -434,6 +444,58 @@ StatsReport* StatsCollector::PrepareReport(uint32 ssrc, return report; } +std::string StatsCollector::AddOneCertificateReport( + const talk_base::SSLCertificate* cert, const std::string& issuer_id) { + // TODO(bemasc): Move this computation to a helper class that caches these + // values to reduce CPU use in GetStats. This will require adding a fast + // SSLCertificate::Equals() method to detect certificate changes. + talk_base::scoped_ptr ssl_fingerprint( + talk_base::SSLFingerprint::Create(talk_base::DIGEST_SHA_256, cert)); + std::string fingerprint = ssl_fingerprint->GetRfc4572Fingerprint(); + + talk_base::Buffer der_buffer; + cert->ToDER(&der_buffer); + std::string der_base64; + talk_base::Base64::EncodeFromArray( + der_buffer.data(), der_buffer.length(), &der_base64); + + StatsReport report; + report.type = StatsReport::kStatsReportTypeCertificate; + report.id = StatsId(report.type, fingerprint); + report.timestamp = stats_gathering_started_; + report.AddValue(StatsReport::kStatsValueNameFingerprint, fingerprint); + report.AddValue(StatsReport::kStatsValueNameDer, der_base64); + if (!issuer_id.empty()) + report.AddValue(StatsReport::kStatsValueNameIssuerId, issuer_id); + reports_[report.id] = report; + return report.id; +} + +std::string StatsCollector::AddCertificateReports( + const talk_base::SSLCertificate* cert) { + // Produces a chain of StatsReports representing this certificate and the rest + // of its chain, and adds those reports to |reports_|. The return value is + // the id of the leaf report. The provided cert must be non-null, so at least + // one report will always be provided and the returned string will never be + // empty. + ASSERT(cert != NULL); + + std::string issuer_id; + talk_base::scoped_ptr chain; + if (cert->GetChain(chain.accept())) { + // This loop runs in reverse, i.e. from root to leaf, so that each + // certificate's issuer's report ID is known before the child certificate's + // report is generated. The root certificate does not have an issuer ID + // value. + for (ptrdiff_t i = chain->GetSize() - 1; i >= 0; --i) { + const talk_base::SSLCertificate& cert_i = chain->Get(i); + issuer_id = AddOneCertificateReport(&cert_i, issuer_id); + } + } + // Add the leaf certificate. + return AddOneCertificateReport(cert, issuer_id); +} + void StatsCollector::ExtractSessionInfo() { // Extract information from the base session. StatsReport report; @@ -454,6 +516,22 @@ void StatsCollector::ExtractSessionInfo() { for (cricket::TransportStatsMap::iterator transport_iter = stats.transport_stats.begin(); transport_iter != stats.transport_stats.end(); ++transport_iter) { + // Attempt to get a copy of the certificates from the transport and + // expose them in stats reports. All channels in a transport share the + // same local and remote certificates. + std::string local_cert_report_id, remote_cert_report_id; + cricket::Transport* transport = + session_->GetTransport(transport_iter->second.content_name); + if (transport) { + talk_base::scoped_ptr identity; + if (transport->GetIdentity(identity.accept())) + local_cert_report_id = AddCertificateReports( + &(identity->certificate())); + + talk_base::scoped_ptr cert; + if (transport->GetRemoteCertificate(cert.accept())) + remote_cert_report_id = AddCertificateReports(cert.get()); + } for (cricket::TransportChannelStatsList::iterator channel_iter = transport_iter->second.channel_stats.begin(); channel_iter != transport_iter->second.channel_stats.end(); @@ -467,6 +545,14 @@ void StatsCollector::ExtractSessionInfo() { channel_report.timestamp = stats_gathering_started_; channel_report.AddValue(StatsReport::kStatsValueNameComponent, channel_iter->component); + if (!local_cert_report_id.empty()) + channel_report.AddValue( + StatsReport::kStatsValueNameLocalCertificateId, + local_cert_report_id); + if (!remote_cert_report_id.empty()) + channel_report.AddValue( + StatsReport::kStatsValueNameRemoteCertificateId, + remote_cert_report_id); reports_[channel_report.id] = channel_report; for (size_t i = 0; i < channel_iter->connection_infos.size(); diff --git a/talk/app/webrtc/statscollector.h b/talk/app/webrtc/statscollector.h index 03a32c4934..c34b5a0b6f 100644 --- a/talk/app/webrtc/statscollector.h +++ b/talk/app/webrtc/statscollector.h @@ -75,6 +75,14 @@ class StatsCollector { private: bool CopySelectedReports(const std::string& selector, StatsReports* reports); + // Helper method for AddCertificateReports. + std::string AddOneCertificateReport( + const talk_base::SSLCertificate* cert, const std::string& issuer_id); + + // Adds a report for this certificate and every certificate in its chain, and + // returns the leaf certificate's report's ID. + std::string AddCertificateReports(const talk_base::SSLCertificate* cert); + void ExtractSessionInfo(); void ExtractVoiceInfo(); void ExtractVideoInfo(); diff --git a/talk/app/webrtc/statscollector_unittest.cc b/talk/app/webrtc/statscollector_unittest.cc index cce1645bca..982983217c 100644 --- a/talk/app/webrtc/statscollector_unittest.cc +++ b/talk/app/webrtc/statscollector_unittest.cc @@ -30,6 +30,8 @@ #include "talk/app/webrtc/mediastream.h" #include "talk/app/webrtc/videotrack.h" +#include "talk/base/base64.h" +#include "talk/base/fakesslidentity.h" #include "talk/base/gunit.h" #include "talk/media/base/fakemediaengine.h" #include "talk/media/devices/fakedevicemanager.h" @@ -60,11 +62,12 @@ class MockWebRtcSession : public webrtc::WebRtcSession { public: explicit MockWebRtcSession(cricket::ChannelManager* channel_manager) : WebRtcSession(channel_manager, talk_base::Thread::Current(), - NULL, NULL, NULL) { + talk_base::Thread::Current(), NULL, NULL) { } MOCK_METHOD0(video_channel, cricket::VideoChannel*()); MOCK_METHOD2(GetTrackIdBySsrc, bool(uint32, std::string*)); MOCK_METHOD1(GetStats, bool(cricket::SessionStats*)); + MOCK_METHOD1(GetTransport, cricket::Transport*(const std::string&)); }; class MockVideoMediaChannel : public cricket::FakeVideoMediaChannel { @@ -76,8 +79,21 @@ class MockVideoMediaChannel : public cricket::FakeVideoMediaChannel { MOCK_METHOD1(GetStats, bool(cricket::VideoMediaInfo*)); }; +bool GetValue(const webrtc::StatsReport* report, + const std::string& name, + std::string* value) { + webrtc::StatsReport::Values::const_iterator it = report->values.begin(); + for (; it != report->values.end(); ++it) { + if (it->name == name) { + *value = it->value; + return true; + } + } + return false; +} + std::string ExtractStatsValue(const std::string& type, - webrtc::StatsReports reports, + const webrtc::StatsReports& reports, const std::string name) { if (reports.empty()) { return kNoReports; @@ -85,12 +101,9 @@ std::string ExtractStatsValue(const std::string& type, for (size_t i = 0; i < reports.size(); ++i) { if (reports[i].type != type) continue; - webrtc::StatsReport::Values::const_iterator it = - reports[i].values.begin(); - for (; it != reports[i].values.end(); ++it) { - if (it->name == name) { - return it->value; - } + std::string ret; + if (GetValue(&reports[i], name, &ret)) { + return ret; } } @@ -99,9 +112,8 @@ std::string ExtractStatsValue(const std::string& type, // Finds the |n|-th report of type |type| in |reports|. // |n| starts from 1 for finding the first report. -const webrtc::StatsReport* FindNthReportByType(webrtc::StatsReports reports, - const std::string& type, - int n) { +const webrtc::StatsReport* FindNthReportByType( + const webrtc::StatsReports& reports, const std::string& type, int n) { for (size_t i = 0; i < reports.size(); ++i) { if (reports[i].type == type) { n--; @@ -112,7 +124,7 @@ const webrtc::StatsReport* FindNthReportByType(webrtc::StatsReports reports, return NULL; } -const webrtc::StatsReport* FindReportById(webrtc::StatsReports reports, +const webrtc::StatsReport* FindReportById(const webrtc::StatsReports& reports, const std::string& id) { for (size_t i = 0; i < reports.size(); ++i) { if (reports[i].id == id) { @@ -134,6 +146,42 @@ std::string ExtractBweStatsValue(webrtc::StatsReports reports, webrtc::StatsReport::kStatsReportTypeBwe, reports, name); } +std::string DerToPem(const std::string& der) { + return talk_base::SSLIdentity::DerToPem( + talk_base::kPemTypeCertificate, + reinterpret_cast(der.c_str()), + der.length()); +} + +std::vector DersToPems( + const std::vector& ders) { + std::vector pems(ders.size()); + std::transform(ders.begin(), ders.end(), pems.begin(), DerToPem); + return pems; +} + +void CheckCertChainReports(const webrtc::StatsReports& reports, + const std::vector& ders, + const std::string& start_id) { + std::string certificate_id = start_id; + size_t i = 0; + while (true) { + const webrtc::StatsReport* report = FindReportById(reports, certificate_id); + ASSERT_TRUE(report != NULL); + std::string der_base64; + EXPECT_TRUE(GetValue( + report, webrtc::StatsReport::kStatsValueNameDer, &der_base64)); + std::string der = talk_base::Base64::Decode(der_base64, + talk_base::Base64::DO_STRICT); + EXPECT_EQ(ders[i], der); + ++i; + if (!GetValue( + report, webrtc::StatsReport::kStatsValueNameIssuerId, &certificate_id)) + break; + } + EXPECT_EQ(ders.size(), i); +} + class StatsCollectorTest : public testing::Test { protected: StatsCollectorTest() @@ -147,6 +195,77 @@ class StatsCollectorTest : public testing::Test { EXPECT_CALL(session_, GetStats(_)).WillRepeatedly(Return(false)); } + void TestCertificateReports(const talk_base::FakeSSLCertificate& local_cert, + const std::vector& local_ders, + const talk_base::FakeSSLCertificate& remote_cert, + const std::vector& remote_ders) { + webrtc::StatsCollector stats; // Implementation under test. + webrtc::StatsReports reports; // returned values. + stats.set_session(&session_); + + // Fake stats to process. + cricket::TransportChannelStats channel_stats; + channel_stats.component = 1; + + cricket::TransportStats transport_stats; + transport_stats.content_name = "audio"; + transport_stats.channel_stats.push_back(channel_stats); + + cricket::SessionStats session_stats; + session_stats.transport_stats[transport_stats.content_name] = + transport_stats; + + // Fake certificates to report. + talk_base::FakeSSLIdentity local_identity(local_cert); + talk_base::scoped_ptr remote_cert_copy( + remote_cert.GetReference()); + + // Fake transport object. + talk_base::scoped_ptr transport( + new cricket::FakeTransport( + session_.signaling_thread(), + session_.worker_thread(), + transport_stats.content_name)); + transport->SetIdentity(&local_identity); + cricket::FakeTransportChannel* channel = + static_cast( + transport->CreateChannel(channel_stats.component)); + EXPECT_FALSE(channel == NULL); + channel->SetRemoteCertificate(remote_cert_copy.get()); + + // Configure MockWebRtcSession + EXPECT_CALL(session_, GetTransport(transport_stats.content_name)) + .WillOnce(Return(transport.get())); + EXPECT_CALL(session_, GetStats(_)) + .WillOnce(DoAll(SetArgPointee<0>(session_stats), + Return(true))); + EXPECT_CALL(session_, video_channel()) + .WillRepeatedly(ReturnNull()); + + stats.UpdateStats(); + + stats.GetStats(NULL, &reports); + + const webrtc::StatsReport* channel_report = FindNthReportByType( + reports, webrtc::StatsReport::kStatsReportTypeComponent, 1); + EXPECT_TRUE(channel_report != NULL); + + // Check local certificate chain. + std::string local_certificate_id = ExtractStatsValue( + webrtc::StatsReport::kStatsReportTypeComponent, + reports, + webrtc::StatsReport::kStatsValueNameLocalCertificateId); + EXPECT_NE(kNotFound, local_certificate_id); + CheckCertChainReports(reports, local_ders, local_certificate_id); + + // Check remote certificate chain. + std::string remote_certificate_id = ExtractStatsValue( + webrtc::StatsReport::kStatsReportTypeComponent, + reports, + webrtc::StatsReport::kStatsValueNameRemoteCertificateId); + EXPECT_NE(kNotFound, remote_certificate_id); + CheckCertChainReports(reports, remote_ders, remote_certificate_id); + } cricket::FakeMediaEngine* media_engine_; talk_base::scoped_ptr channel_manager_; MockWebRtcSession session_; @@ -439,4 +558,142 @@ TEST_F(StatsCollectorTest, TransportObjectLinkedFromSsrcObject) { ASSERT_FALSE(transport_report == NULL); } +// This test verifies that all chained certificates are correctly +// reported +TEST_F(StatsCollectorTest, DISABLED_ChainedCertificateReportsCreated) { + // Build local certificate chain. + std::vector local_ders(5); + local_ders[0] = "These"; + local_ders[1] = "are"; + local_ders[2] = "some"; + local_ders[3] = "der"; + local_ders[4] = "values"; + talk_base::FakeSSLCertificate local_cert(DersToPems(local_ders)); + + // Build remote certificate chain + std::vector remote_ders(4); + remote_ders[0] = "A"; + remote_ders[1] = "non-"; + remote_ders[2] = "intersecting"; + remote_ders[3] = "set"; + talk_base::FakeSSLCertificate remote_cert(DersToPems(remote_ders)); + + TestCertificateReports(local_cert, local_ders, remote_cert, remote_ders); +} + +// This test verifies that all certificates without chains are correctly +// reported. +TEST_F(StatsCollectorTest, DISABLED_ChainlessCertificateReportsCreated) { + // Build local certificate. + std::string local_der = "This is the local der."; + talk_base::FakeSSLCertificate local_cert(DerToPem(local_der)); + + // Build remote certificate. + std::string remote_der = "This is somebody else's der."; + talk_base::FakeSSLCertificate remote_cert(DerToPem(remote_der)); + + TestCertificateReports(local_cert, std::vector(1, local_der), + remote_cert, std::vector(1, remote_der)); +} + +// This test verifies that the stats are generated correctly when no +// transport is present. +TEST_F(StatsCollectorTest, DISABLED_NoTransport) { + webrtc::StatsCollector stats; // Implementation under test. + webrtc::StatsReports reports; // returned values. + stats.set_session(&session_); + + // Fake stats to process. + cricket::TransportChannelStats channel_stats; + channel_stats.component = 1; + + cricket::TransportStats transport_stats; + transport_stats.content_name = "audio"; + transport_stats.channel_stats.push_back(channel_stats); + + cricket::SessionStats session_stats; + session_stats.transport_stats[transport_stats.content_name] = + transport_stats; + + // Configure MockWebRtcSession + EXPECT_CALL(session_, GetTransport(transport_stats.content_name)) + .WillOnce(ReturnNull()); + EXPECT_CALL(session_, GetStats(_)) + .WillOnce(DoAll(SetArgPointee<0>(session_stats), + Return(true))); + EXPECT_CALL(session_, video_channel()) + .WillRepeatedly(ReturnNull()); + + stats.UpdateStats(); + stats.GetStats(NULL, &reports); + + // Check that the local certificate is absent. + std::string local_certificate_id = ExtractStatsValue( + webrtc::StatsReport::kStatsReportTypeComponent, + reports, + webrtc::StatsReport::kStatsValueNameLocalCertificateId); + ASSERT_EQ(kNotFound, local_certificate_id); + + // Check that the remote certificate is absent. + std::string remote_certificate_id = ExtractStatsValue( + webrtc::StatsReport::kStatsReportTypeComponent, + reports, + webrtc::StatsReport::kStatsValueNameRemoteCertificateId); + ASSERT_EQ(kNotFound, remote_certificate_id); +} + +// This test verifies that the stats are generated correctly when the transport +// does not have any certificates. +TEST_F(StatsCollectorTest, DISABLED_NoCertificates) { + webrtc::StatsCollector stats; // Implementation under test. + webrtc::StatsReports reports; // returned values. + stats.set_session(&session_); + + // Fake stats to process. + cricket::TransportChannelStats channel_stats; + channel_stats.component = 1; + + cricket::TransportStats transport_stats; + transport_stats.content_name = "audio"; + transport_stats.channel_stats.push_back(channel_stats); + + cricket::SessionStats session_stats; + session_stats.transport_stats[transport_stats.content_name] = + transport_stats; + + // Fake transport object. + talk_base::scoped_ptr transport( + new cricket::FakeTransport( + session_.signaling_thread(), + session_.worker_thread(), + transport_stats.content_name)); + + // Configure MockWebRtcSession + EXPECT_CALL(session_, GetTransport(transport_stats.content_name)) + .WillOnce(Return(transport.get())); + EXPECT_CALL(session_, GetStats(_)) + .WillOnce(DoAll(SetArgPointee<0>(session_stats), + Return(true))); + EXPECT_CALL(session_, video_channel()) + .WillRepeatedly(ReturnNull()); + + stats.UpdateStats(); + stats.GetStats(NULL, &reports); + + // Check that the local certificate is absent. + std::string local_certificate_id = ExtractStatsValue( + webrtc::StatsReport::kStatsReportTypeComponent, + reports, + webrtc::StatsReport::kStatsValueNameLocalCertificateId); + ASSERT_EQ(kNotFound, local_certificate_id); + + // Check that the remote certificate is absent. + std::string remote_certificate_id = ExtractStatsValue( + webrtc::StatsReport::kStatsReportTypeComponent, + reports, + webrtc::StatsReport::kStatsValueNameRemoteCertificateId); + ASSERT_EQ(kNotFound, remote_certificate_id); +} + + } // namespace diff --git a/talk/app/webrtc/statstypes.h b/talk/app/webrtc/statstypes.h index 30a8b84165..fe368859fc 100644 --- a/talk/app/webrtc/statstypes.h +++ b/talk/app/webrtc/statstypes.h @@ -99,6 +99,14 @@ class StatsReport { // The id of StatsReport of type VideoBWE. static const char kStatsReportVideoBweId[]; + // A StatsReport of |type| = "googCertificate" contains an SSL certificate + // transmitted by one of the endpoints of this connection. The |id| is + // controlled by the fingerprint, and is used to identify the certificate in + // the Channel stats (as "googLocalCertificateId" or + // "googRemoteCertificateId") and in any child certificates (as + // "googIssuerId"). + static const char kStatsReportTypeCertificate[]; + // StatsValue names static const char kStatsValueNameAudioOutputLevel[]; static const char kStatsValueNameAudioInputLevel[]; @@ -152,6 +160,11 @@ class StatsReport { static const char kStatsValueNameTrackId[]; static const char kStatsValueNameSsrc[]; static const char kStatsValueNameTypingNoiseState[]; + static const char kStatsValueNameDer[]; + static const char kStatsValueNameFingerprint[]; + static const char kStatsValueNameIssuerId[]; + static const char kStatsValueNameLocalCertificateId[]; + static const char kStatsValueNameRemoteCertificateId[]; }; typedef std::vector StatsReports; diff --git a/talk/base/fakesslidentity.h b/talk/base/fakesslidentity.h index f3c44e4225..5efa268f8f 100644 --- a/talk/base/fakesslidentity.h +++ b/talk/base/fakesslidentity.h @@ -28,6 +28,9 @@ #ifndef TALK_BASE_FAKESSLIDENTITY_H_ #define TALK_BASE_FAKESSLIDENTITY_H_ +#include +#include + #include "talk/base/messagedigest.h" #include "talk/base/sslidentity.h" @@ -36,12 +39,25 @@ namespace talk_base { class FakeSSLCertificate : public talk_base::SSLCertificate { public: explicit FakeSSLCertificate(const std::string& data) : data_(data) {} + explicit FakeSSLCertificate(const std::vector& certs) + : data_(certs.front()) { + std::vector::const_iterator it; + // Skip certs[0]. + for (it = certs.begin() + 1; it != certs.end(); ++it) { + certs_.push_back(FakeSSLCertificate(*it)); + } + } virtual FakeSSLCertificate* GetReference() const { return new FakeSSLCertificate(*this); } virtual std::string ToPEMString() const { return data_; } + virtual void ToDER(Buffer* der_buffer) const { + std::string der_string; + VERIFY(SSLIdentity::PemToDer(kPemTypeCertificate, data_, &der_string)); + der_buffer->SetData(der_string.c_str(), der_string.size()); + } virtual bool ComputeDigest(const std::string &algorithm, unsigned char *digest, std::size_t size, std::size_t *length) const { @@ -49,13 +65,27 @@ class FakeSSLCertificate : public talk_base::SSLCertificate { digest, size); return (*length != 0); } + virtual bool GetChain(SSLCertChain** chain) const { + if (certs_.empty()) + return false; + std::vector new_certs(certs_.size()); + std::transform(certs_.begin(), certs_.end(), new_certs.begin(), DupCert); + *chain = new SSLCertChain(new_certs); + return true; + } + private: + static FakeSSLCertificate* DupCert(FakeSSLCertificate cert) { + return cert.GetReference(); + } std::string data_; + std::vector certs_; }; class FakeSSLIdentity : public talk_base::SSLIdentity { public: explicit FakeSSLIdentity(const std::string& data) : cert_(data) {} + explicit FakeSSLIdentity(const FakeSSLCertificate& cert) : cert_(cert) {} virtual FakeSSLIdentity* GetReference() const { return new FakeSSLIdentity(*this); } diff --git a/talk/base/nssidentity.cc b/talk/base/nssidentity.cc index c660aee0a5..96bfcc3b09 100644 --- a/talk/base/nssidentity.cc +++ b/talk/base/nssidentity.cc @@ -26,6 +26,10 @@ * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. */ +#include +#include +#include + #if HAVE_CONFIG_H #include "config.h" #endif // HAVE_CONFIG_H @@ -34,8 +38,6 @@ #include "talk/base/nssidentity.h" -#include - #include "cert.h" #include "cryptohi.h" #include "keyhi.h" @@ -90,6 +92,43 @@ NSSKeyPair *NSSKeyPair::GetReference() { return new NSSKeyPair(privkey, pubkey); } +NSSCertificate::NSSCertificate(CERTCertificate* cert) + : certificate_(CERT_DupCertificate(cert)) { + ASSERT(certificate_ != NULL); +} + +static void DeleteCert(SSLCertificate* cert) { + delete cert; +} + +NSSCertificate::NSSCertificate(CERTCertList* cert_list) { + // Copy the first cert into certificate_. + CERTCertListNode* node = CERT_LIST_HEAD(cert_list); + certificate_ = CERT_DupCertificate(node->cert); + + // Put any remaining certificates into the chain. + node = CERT_LIST_NEXT(node); + std::vector certs; + for (; !CERT_LIST_END(node, cert_list); node = CERT_LIST_NEXT(node)) { + certs.push_back(new NSSCertificate(node->cert)); + } + + if (!certs.empty()) + chain_.reset(new SSLCertChain(certs)); + + // The SSLCertChain constructor copies its input, so now we have to delete + // the originals. + std::for_each(certs.begin(), certs.end(), DeleteCert); +} + +NSSCertificate::NSSCertificate(CERTCertificate* cert, SSLCertChain* chain) + : certificate_(CERT_DupCertificate(cert)) { + ASSERT(certificate_ != NULL); + if (chain) + chain_.reset(chain->Copy()); +} + + NSSCertificate *NSSCertificate::FromPEMString(const std::string &pem_string) { std::string der; if (!SSLIdentity::PemToDer(kPemTypeCertificate, pem_string, &der)) @@ -105,15 +144,13 @@ NSSCertificate *NSSCertificate::FromPEMString(const std::string &pem_string) { if (!cert) return NULL; - return new NSSCertificate(cert); + NSSCertificate* ret = new NSSCertificate(cert); + CERT_DestroyCertificate(cert); + return ret; } NSSCertificate *NSSCertificate::GetReference() const { - CERTCertificate *certificate = CERT_DupCertificate(certificate_); - if (!certificate) - return NULL; - - return new NSSCertificate(certificate); + return new NSSCertificate(certificate_, chain_.get()); } std::string NSSCertificate::ToPEMString() const { @@ -122,6 +159,10 @@ std::string NSSCertificate::ToPEMString() const { certificate_->derCert.len); } +void NSSCertificate::ToDER(Buffer* der_buffer) const { + der_buffer->SetData(certificate_->derCert.data, certificate_->derCert.len); +} + bool NSSCertificate::GetDigestLength(const std::string &algorithm, std::size_t *length) { const SECHashObject *ho; @@ -156,6 +197,14 @@ bool NSSCertificate::ComputeDigest(const std::string &algorithm, return true; } +bool NSSCertificate::GetChain(SSLCertChain** chain) const { + if (!chain_) + return false; + + *chain = chain_->Copy(); + return true; +} + bool NSSCertificate::Equals(const NSSCertificate *tocompare) const { if (!certificate_->derCert.len) return false; @@ -301,9 +350,9 @@ NSSIdentity *NSSIdentity::Generate(const std::string &common_name) { fail: delete keypair; - CERT_DestroyCertificate(certificate); done: + if (certificate) CERT_DestroyCertificate(certificate); if (subject_name) CERT_DestroyName(subject_name); if (spki) SECKEY_DestroySubjectPublicKeyInfo(spki); if (certreq) CERT_DestroyCertificateRequest(certreq); diff --git a/talk/base/nssidentity.h b/talk/base/nssidentity.h index 725c546277..f4bfc8bcc1 100644 --- a/talk/base/nssidentity.h +++ b/talk/base/nssidentity.h @@ -66,7 +66,10 @@ class NSSKeyPair { class NSSCertificate : public SSLCertificate { public: static NSSCertificate* FromPEMString(const std::string& pem_string); - explicit NSSCertificate(CERTCertificate* cert) : certificate_(cert) {} + // The caller retains ownership of the argument to all the constructors, + // and the constructor makes a copy. + explicit NSSCertificate(CERTCertificate* cert); + explicit NSSCertificate(CERTCertList* cert_list); virtual ~NSSCertificate() { if (certificate_) CERT_DestroyCertificate(certificate_); @@ -76,24 +79,30 @@ class NSSCertificate : public SSLCertificate { virtual std::string ToPEMString() const; + virtual void ToDER(Buffer* der_buffer) const; + virtual bool ComputeDigest(const std::string& algorithm, unsigned char* digest, std::size_t size, std::size_t* length) const; + virtual bool GetChain(SSLCertChain** chain) const; + CERTCertificate* certificate() { return certificate_; } // Helper function to get the length of a digest static bool GetDigestLength(const std::string& algorithm, std::size_t* length); - // Comparison + // Comparison. Only the certificate itself is considered, not the chain. bool Equals(const NSSCertificate* tocompare) const; private: + NSSCertificate(CERTCertificate* cert, SSLCertChain* chain); static bool GetDigestObject(const std::string& algorithm, const SECHashObject** hash_object); CERTCertificate* certificate_; + scoped_ptr chain_; DISALLOW_EVIL_CONSTRUCTORS(NSSCertificate); }; diff --git a/talk/base/nssstreamadapter.cc b/talk/base/nssstreamadapter.cc index c9a540d521..185c243f5e 100644 --- a/talk/base/nssstreamadapter.cc +++ b/talk/base/nssstreamadapter.cc @@ -821,6 +821,13 @@ SECStatus NSSStreamAdapter::AuthCertificateHook(void *arg, if (ok) { stream->cert_ok_ = true; + + // Record the peer's certificate chain. + CERTCertList* cert_list = SSL_PeerCertificateChain(fd); + ASSERT(cert_list != NULL); + + stream->peer_certificate_.reset(new NSSCertificate(cert_list)); + CERT_DestroyCertList(cert_list); return SECSuccess; } diff --git a/talk/base/opensslidentity.cc b/talk/base/opensslidentity.cc index a48c94fd75..7408af1ab6 100644 --- a/talk/base/opensslidentity.cc +++ b/talk/base/opensslidentity.cc @@ -40,6 +40,7 @@ #include #include +#include "talk/base/checks.h" #include "talk/base/helpers.h" #include "talk/base/logging.h" #include "talk/base/openssldigest.h" @@ -211,7 +212,9 @@ OpenSSLCertificate* OpenSSLCertificate::Generate( #ifdef _DEBUG PrintCert(x509); #endif - return new OpenSSLCertificate(x509); + OpenSSLCertificate* ret = new OpenSSLCertificate(x509); + X509_free(x509); + return ret; } OpenSSLCertificate* OpenSSLCertificate::FromPEMString( @@ -224,10 +227,12 @@ OpenSSLCertificate* OpenSSLCertificate::FromPEMString( X509 *x509 = PEM_read_bio_X509(bio, NULL, NULL, const_cast("\0")); BIO_free(bio); - if (x509) - return new OpenSSLCertificate(x509); - else + if (!x509) return NULL; + + OpenSSLCertificate* ret = new OpenSSLCertificate(x509); + X509_free(x509); + return ret; } bool OpenSSLCertificate::ComputeDigest(const std::string &algorithm, @@ -264,11 +269,14 @@ OpenSSLCertificate::~OpenSSLCertificate() { std::string OpenSSLCertificate::ToPEMString() const { BIO* bio = BIO_new(BIO_s_mem()); - if (!bio) - return NULL; + if (!bio) { + UNREACHABLE(); + return std::string(); + } if (!PEM_write_bio_X509(bio, x509_)) { BIO_free(bio); - return NULL; + UNREACHABLE(); + return std::string(); } BIO_write(bio, "\0", 1); char* buffer; @@ -278,7 +286,29 @@ std::string OpenSSLCertificate::ToPEMString() const { return ret; } +void OpenSSLCertificate::ToDER(Buffer* der_buffer) const { + // In case of failure, make sure to leave the buffer empty. + der_buffer->SetData(NULL, 0); + + // Calculates the DER representation of the certificate, from scratch. + BIO* bio = BIO_new(BIO_s_mem()); + if (!bio) { + UNREACHABLE(); + return; + } + if (!i2d_X509_bio(bio, x509_)) { + BIO_free(bio); + UNREACHABLE(); + return; + } + char* data; + size_t length = BIO_get_mem_data(bio, &data); + der_buffer->SetData(data, length); + BIO_free(bio); +} + void OpenSSLCertificate::AddReference() const { + ASSERT(x509_ != NULL); CRYPTO_add(&x509_->references, 1, CRYPTO_LOCK_X509); } diff --git a/talk/base/opensslidentity.h b/talk/base/opensslidentity.h index ca001b5cfa..0d1bf73b07 100644 --- a/talk/base/opensslidentity.h +++ b/talk/base/opensslidentity.h @@ -25,8 +25,8 @@ * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. */ -#ifndef TALK_BASE_OPENSSLIDENTITY_H__ -#define TALK_BASE_OPENSSLIDENTITY_H__ +#ifndef TALK_BASE_OPENSSLIDENTITY_H_ +#define TALK_BASE_OPENSSLIDENTITY_H_ #include #include @@ -72,6 +72,11 @@ class OpenSSLKeyPair { // which is also reference counted inside the OpenSSL library. class OpenSSLCertificate : public SSLCertificate { public: + // Caller retains ownership of the X509 object. + explicit OpenSSLCertificate(X509* x509) : x509_(x509) { + AddReference(); + } + static OpenSSLCertificate* Generate(OpenSSLKeyPair* key_pair, const std::string& common_name); static OpenSSLCertificate* FromPEMString(const std::string& pem_string); @@ -79,7 +84,6 @@ class OpenSSLCertificate : public SSLCertificate { virtual ~OpenSSLCertificate(); virtual OpenSSLCertificate* GetReference() const { - AddReference(); return new OpenSSLCertificate(x509_); } @@ -87,6 +91,8 @@ class OpenSSLCertificate : public SSLCertificate { virtual std::string ToPEMString() const; + virtual void ToDER(Buffer* der_buffer) const; + // Compute the digest of the certificate given algorithm virtual bool ComputeDigest(const std::string &algorithm, unsigned char *digest, std::size_t size, @@ -99,10 +105,14 @@ class OpenSSLCertificate : public SSLCertificate { std::size_t size, std::size_t *length); - private: - explicit OpenSSLCertificate(X509* x509) : x509_(x509) { - ASSERT(x509_ != NULL); + virtual bool GetChain(SSLCertChain** chain) const { + // Chains are not yet supported when using OpenSSL. + // OpenSSLStreamAdapter::SSLVerifyCallback currently requires the remote + // certificate to be self-signed. + return false; } + + private: void AddReference() const; X509* x509_; @@ -148,4 +158,4 @@ class OpenSSLIdentity : public SSLIdentity { } // namespace talk_base -#endif // TALK_BASE_OPENSSLIDENTITY_H__ +#endif // TALK_BASE_OPENSSLIDENTITY_H_ diff --git a/talk/base/opensslstreamadapter.cc b/talk/base/opensslstreamadapter.cc index 16021a96bd..034dfcf926 100644 --- a/talk/base/opensslstreamadapter.cc +++ b/talk/base/opensslstreamadapter.cc @@ -217,6 +217,14 @@ void OpenSSLStreamAdapter::SetPeerCertificate(SSLCertificate* cert) { peer_certificate_.reset(static_cast(cert)); } +bool OpenSSLStreamAdapter::GetPeerCertificate(SSLCertificate** cert) const { + if (!peer_certificate_) + return false; + + *cert = peer_certificate_->GetReference(); + return true; +} + bool OpenSSLStreamAdapter::SetPeerCertificateDigest(const std::string &digest_alg, const unsigned char* @@ -857,6 +865,9 @@ int OpenSSLStreamAdapter::SSLVerifyCallback(int ok, X509_STORE_CTX* store) { LOG(LS_INFO) << "Accepted self-signed peer certificate authority"; ok = 1; + + // Record the peer's certificate. + stream->peer_certificate_.reset(new OpenSSLCertificate(cert)); } } } diff --git a/talk/base/opensslstreamadapter.h b/talk/base/opensslstreamadapter.h index 8e92a10a5f..3c478187fc 100644 --- a/talk/base/opensslstreamadapter.h +++ b/talk/base/opensslstreamadapter.h @@ -86,6 +86,8 @@ class OpenSSLStreamAdapter : public SSLStreamAdapter { const unsigned char* digest_val, size_t digest_len); + virtual bool GetPeerCertificate(SSLCertificate** cert) const; + virtual int StartSSLWithServer(const char* server_name); virtual int StartSSLWithPeer(); virtual void SetMode(SSLMode mode); @@ -190,8 +192,8 @@ class OpenSSLStreamAdapter : public SSLStreamAdapter { // in traditional mode, the server name that the server's certificate // must specify. Empty in peer-to-peer mode. std::string ssl_server_name_; - // In peer-to-peer mode, the certificate that the peer must - // present. Empty in traditional mode. + // The certificate that the peer must present or did present. Initially + // null in traditional mode, until the connection is established. scoped_ptr peer_certificate_; // In peer-to-peer mode, the digest of the certificate that // the peer must present. diff --git a/talk/base/sslfingerprint.h b/talk/base/sslfingerprint.h index 4d41156f86..b85778947e 100644 --- a/talk/base/sslfingerprint.h +++ b/talk/base/sslfingerprint.h @@ -47,9 +47,14 @@ struct SSLFingerprint { return NULL; } + return Create(algorithm, &(identity->certificate())); + } + + static SSLFingerprint* Create(const std::string& algorithm, + const talk_base::SSLCertificate* cert) { uint8 digest_val[64]; size_t digest_len; - bool ret = identity->certificate().ComputeDigest( + bool ret = cert->ComputeDigest( algorithm, digest_val, sizeof(digest_val), &digest_len); if (!ret) { return NULL; diff --git a/talk/base/sslidentity.h b/talk/base/sslidentity.h index b9425f78fd..345691c398 100644 --- a/talk/base/sslidentity.h +++ b/talk/base/sslidentity.h @@ -30,11 +30,18 @@ #ifndef TALK_BASE_SSLIDENTITY_H_ #define TALK_BASE_SSLIDENTITY_H_ +#include #include +#include + +#include "talk/base/buffer.h" #include "talk/base/messagedigest.h" namespace talk_base { +// Forward declaration due to circular dependency with SSLCertificate. +class SSLCertChain; + // Abstract interface overridden by SSL library specific // implementations. @@ -55,19 +62,72 @@ class SSLCertificate { virtual ~SSLCertificate() {} // Returns a new SSLCertificate object instance wrapping the same - // underlying certificate. + // underlying certificate, including its chain if present. // Caller is responsible for freeing the returned object. virtual SSLCertificate* GetReference() const = 0; + // Provides the cert chain, or returns false. The caller owns the chain. + // The chain includes a copy of each certificate, excluding the leaf. + virtual bool GetChain(SSLCertChain** chain) const = 0; + // Returns a PEM encoded string representation of the certificate. virtual std::string ToPEMString() const = 0; + // Provides a DER encoded binary representation of the certificate. + virtual void ToDER(Buffer* der_buffer) const = 0; + // Compute the digest of the certificate given algorithm virtual bool ComputeDigest(const std::string &algorithm, unsigned char* digest, std::size_t size, std::size_t* length) const = 0; }; +// SSLCertChain is a simple wrapper for a vector of SSLCertificates. It serves +// primarily to ensure proper memory management (especially deletion) of the +// SSLCertificate pointers. +class SSLCertChain { + public: + // These constructors copy the provided SSLCertificate(s), so the caller + // retains ownership. + explicit SSLCertChain(const std::vector& certs) { + ASSERT(!certs.empty()); + certs_.resize(certs.size()); + std::transform(certs.begin(), certs.end(), certs_.begin(), DupCert); + } + explicit SSLCertChain(const SSLCertificate* cert) { + certs_.push_back(cert->GetReference()); + } + + ~SSLCertChain() { + std::for_each(certs_.begin(), certs_.end(), DeleteCert); + } + + // Vector access methods. + size_t GetSize() const { return certs_.size(); } + + // Returns a temporary reference, only valid until the chain is destroyed. + const SSLCertificate& Get(size_t pos) const { return *(certs_[pos]); } + + // Returns a new SSLCertChain object instance wrapping the same underlying + // certificate chain. Caller is responsible for freeing the returned object. + SSLCertChain* Copy() const { + return new SSLCertChain(certs_); + } + + private: + // Helper function for duplicating a vector of certificates. + static SSLCertificate* DupCert(const SSLCertificate* cert) { + return cert->GetReference(); + } + + // Helper function for deleting a vector of certificates. + static void DeleteCert(SSLCertificate* cert) { delete cert; } + + std::vector certs_; + + DISALLOW_COPY_AND_ASSIGN(SSLCertChain); +}; + // Our identity in an SSL negotiation: a keypair and certificate (both // with the same public key). // This too is pretty much immutable once created. @@ -108,4 +168,4 @@ extern const char kPemTypeRsaPrivateKey[]; } // namespace talk_base -#endif // TALK_BASE_SSLIDENTITY_H__ +#endif // TALK_BASE_SSLIDENTITY_H_ diff --git a/talk/base/sslstreamadapter.h b/talk/base/sslstreamadapter.h index 2afe1daf18..3a7797370c 100644 --- a/talk/base/sslstreamadapter.h +++ b/talk/base/sslstreamadapter.h @@ -25,8 +25,8 @@ * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. */ -#ifndef TALK_BASE_SSLSTREAMADAPTER_H__ -#define TALK_BASE_SSLSTREAMADAPTER_H__ +#ifndef TALK_BASE_SSLSTREAMADAPTER_H_ +#define TALK_BASE_SSLSTREAMADAPTER_H_ #include #include @@ -111,9 +111,9 @@ class SSLStreamAdapter : public StreamAdapterInterface { // mode. // Generally, SetIdentity() and possibly SetServerRole() should have // been called before this. - // SetPeerCertificate() must also be called. It may be called after - // StartSSLWithPeer() but must be called before the underlying - // stream opens. + // SetPeerCertificate() or SetPeerCertificateDigest() must also be called. + // It may be called after StartSSLWithPeer() but must be called before the + // underlying stream opens. virtual int StartSSLWithPeer() = 0; // Specify the certificate that our peer is expected to use in @@ -138,6 +138,13 @@ class SSLStreamAdapter : public StreamAdapterInterface { const unsigned char* digest_val, size_t digest_len) = 0; + // Retrieves the peer's X.509 certificate, if a certificate has been + // provided by SetPeerCertificate or a connection has been established. If + // a connection has been established, this returns the + // certificate transmitted over SSL, including the entire chain. + // The returned certificate is owned by the caller. + virtual bool GetPeerCertificate(SSLCertificate** cert) const = 0; + // Key Exporter interface from RFC 5705 // Arguments are: // label -- the exporter label. @@ -182,4 +189,4 @@ class SSLStreamAdapter : public StreamAdapterInterface { } // namespace talk_base -#endif // TALK_BASE_SSLSTREAMADAPTER_H__ +#endif // TALK_BASE_SSLSTREAMADAPTER_H_ diff --git a/talk/base/sslstreamadapter_unittest.cc b/talk/base/sslstreamadapter_unittest.cc index 1fe1a66348..e7335be48f 100644 --- a/talk/base/sslstreamadapter_unittest.cc +++ b/talk/base/sslstreamadapter_unittest.cc @@ -33,6 +33,7 @@ #include "talk/base/gunit.h" #include "talk/base/helpers.h" +#include "talk/base/scoped_ptr.h" #include "talk/base/ssladapter.h" #include "talk/base/sslconfig.h" #include "talk/base/sslidentity.h" @@ -388,6 +389,13 @@ class SSLStreamAdapterTestBase : public testing::Test, return server_ssl_->GetDtlsSrtpCipher(retval); } + bool GetPeerCertificate(bool client, talk_base::SSLCertificate** cert) { + if (client) + return client_ssl_->GetPeerCertificate(cert); + else + return server_ssl_->GetPeerCertificate(cert); + } + bool ExportKeyingMaterial(const char *label, const unsigned char *context, size_t context_len, @@ -885,3 +893,42 @@ TEST_F(SSLStreamAdapterTestDTLSFromPEMStrings, TestTransfer) { TestHandshake(); TestTransfer(100); } + +// Test getting the remote certificate. +TEST_F(SSLStreamAdapterTestDTLSFromPEMStrings, TestDTLSGetPeerCertificate) { + MAYBE_SKIP_TEST(HaveDtls); + + // Peer certificates haven't been received yet. + talk_base::scoped_ptr client_peer_cert; + ASSERT_FALSE(GetPeerCertificate(true, client_peer_cert.accept())); + ASSERT_FALSE(client_peer_cert != NULL); + + talk_base::scoped_ptr server_peer_cert; + ASSERT_FALSE(GetPeerCertificate(false, server_peer_cert.accept())); + ASSERT_FALSE(server_peer_cert != NULL); + + TestHandshake(); + + // The client should have a peer certificate after the handshake. + ASSERT_TRUE(GetPeerCertificate(true, client_peer_cert.accept())); + ASSERT_TRUE(client_peer_cert != NULL); + + // It's not kCERT_PEM. + std::string client_peer_string = client_peer_cert->ToPEMString(); + ASSERT_NE(kCERT_PEM, client_peer_string); + + // It must not have a chain, because the test certs are self-signed. + talk_base::SSLCertChain* client_peer_chain; + ASSERT_FALSE(client_peer_cert->GetChain(&client_peer_chain)); + + // The server should have a peer certificate after the handshake. + ASSERT_TRUE(GetPeerCertificate(false, server_peer_cert.accept())); + ASSERT_TRUE(server_peer_cert != NULL); + + // It's kCERT_PEM + ASSERT_EQ(kCERT_PEM, server_peer_cert->ToPEMString()); + + // It must not have a chain, because the test certs are self-signed. + talk_base::SSLCertChain* server_peer_chain; + ASSERT_FALSE(server_peer_cert->GetChain(&server_peer_chain)); +} diff --git a/talk/base/sslstreamadapterhelper.cc b/talk/base/sslstreamadapterhelper.cc index 5a1a255505..b42faa80c6 100644 --- a/talk/base/sslstreamadapterhelper.cc +++ b/talk/base/sslstreamadapterhelper.cc @@ -87,6 +87,14 @@ void SSLStreamAdapterHelper::SetPeerCertificate(SSLCertificate* cert) { peer_certificate_.reset(cert); } +bool SSLStreamAdapterHelper::GetPeerCertificate(SSLCertificate** cert) const { + if (!peer_certificate_) + return false; + + *cert = peer_certificate_->GetReference(); + return true; +} + bool SSLStreamAdapterHelper::SetPeerCertificateDigest( const std::string &digest_alg, const unsigned char* digest_val, diff --git a/talk/base/sslstreamadapterhelper.h b/talk/base/sslstreamadapterhelper.h index e8cb3b08b1..7c28056612 100644 --- a/talk/base/sslstreamadapterhelper.h +++ b/talk/base/sslstreamadapterhelper.h @@ -2,26 +2,26 @@ * libjingle * Copyright 2004--2008, Google Inc. * - * Redistribution and use in source and binary forms, with or without + * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, + * 1. Redistributions of source code must retain the above copyright notice, * this list of conditions and the following disclaimer. * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. - * 3. The name of the author may not be used to endorse or promote products + * 3. The name of the author may not be used to endorse or promote products * derived from this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR IMPLIED - * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF + * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO - * EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, - * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR - * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF + * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR + * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. */ @@ -63,6 +63,7 @@ class SSLStreamAdapterHelper : public SSLStreamAdapter { virtual bool SetPeerCertificateDigest(const std::string& digest_alg, const unsigned char* digest_val, size_t digest_len); + virtual bool GetPeerCertificate(SSLCertificate** cert) const; virtual StreamState GetState() const; virtual void Close(); diff --git a/talk/media/webrtc/webrtcvoiceengine.cc b/talk/media/webrtc/webrtcvoiceengine.cc index f3dcf3b858..5c16d6e626 100644 --- a/talk/media/webrtc/webrtcvoiceengine.cc +++ b/talk/media/webrtc/webrtcvoiceengine.cc @@ -234,6 +234,9 @@ class WebRtcSoundclipMedia : public SoundclipMedia { } bool Init() { + if (!engine_->voe_sc()) { + return false; + } webrtc_channel_ = engine_->voe_sc()->base()->CreateChannel(); if (webrtc_channel_ == -1) { LOG_RTCERR0(CreateChannel); @@ -300,6 +303,7 @@ class WebRtcSoundclipMedia : public SoundclipMedia { WebRtcVoiceEngine::WebRtcVoiceEngine() : voe_wrapper_(new VoEWrapper()), voe_wrapper_sc_(new VoEWrapper()), + voe_wrapper_sc_initialized_(false), tracing_(new VoETraceWrapper()), adm_(NULL), adm_sc_(NULL), @@ -316,6 +320,7 @@ WebRtcVoiceEngine::WebRtcVoiceEngine(VoEWrapper* voe_wrapper, VoETraceWrapper* tracing) : voe_wrapper_(voe_wrapper), voe_wrapper_sc_(voe_wrapper_sc), + voe_wrapper_sc_initialized_(false), tracing_(tracing), adm_(NULL), adm_sc_(NULL), @@ -539,6 +544,23 @@ bool WebRtcVoiceEngine::InitInternal() { LOG(LS_INFO) << ToString(*it); } + // Disable the DTMF playout when a tone is sent. + // PlayDtmfTone will be used if local playout is needed. + if (voe_wrapper_->dtmf()->SetDtmfFeedbackStatus(false) == -1) { + LOG_RTCERR1(SetDtmfFeedbackStatus, false); + } + + initialized_ = true; + return true; +} + +bool WebRtcVoiceEngine::EnsureSoundclipEngineInit() { + if (voe_wrapper_sc_initialized_) { + return true; + } + // Note that, if initialization fails, voe_wrapper_sc_initialized_ will still + // be false, so subsequent calls to EnsureSoundclipEngineInit will + // probably just fail again. That's acceptable behavior. #if defined(LINUX) && !defined(HAVE_LIBPULSE) voe_wrapper_sc_->hw()->SetAudioDeviceLayer(webrtc::kAudioLinuxAlsa); #endif @@ -572,14 +594,8 @@ bool WebRtcVoiceEngine::InitInternal() { } } #endif - - // Disable the DTMF playout when a tone is sent. - // PlayDtmfTone will be used if local playout is needed. - if (voe_wrapper_->dtmf()->SetDtmfFeedbackStatus(false) == -1) { - LOG_RTCERR1(SetDtmfFeedbackStatus, false); - } - - initialized_ = true; + voe_wrapper_sc_initialized_ = true; + LOG(LS_INFO) << "Initialized WebRtc soundclip engine."; return true; } @@ -589,7 +605,10 @@ void WebRtcVoiceEngine::Terminate() { StopAecDump(); - voe_wrapper_sc_->base()->Terminate(); + if (voe_wrapper_sc_) { + voe_wrapper_sc_initialized_ = false; + voe_wrapper_sc_->base()->Terminate(); + } voe_wrapper_->base()->Terminate(); desired_local_monitor_enable_ = false; } @@ -608,6 +627,11 @@ VoiceMediaChannel *WebRtcVoiceEngine::CreateChannel() { } SoundclipMedia *WebRtcVoiceEngine::CreateSoundclip() { + if (!EnsureSoundclipEngineInit()) { + LOG(LS_ERROR) << "Unable to create soundclip: soundclip engine failed to " + << "initialize."; + return NULL; + } WebRtcSoundclipMedia *soundclip = new WebRtcSoundclipMedia(this); if (!soundclip->Init() || !soundclip->Enable()) { delete soundclip; diff --git a/talk/media/webrtc/webrtcvoiceengine.h b/talk/media/webrtc/webrtcvoiceengine.h index 9e8ef86ee6..7809706a8c 100644 --- a/talk/media/webrtc/webrtcvoiceengine.h +++ b/talk/media/webrtc/webrtcvoiceengine.h @@ -184,6 +184,7 @@ class WebRtcVoiceEngine void Construct(); void ConstructCodecs(); bool InitInternal(); + bool EnsureSoundclipEngineInit(); void SetTraceFilter(int filter); void SetTraceOptions(const std::string& options); // Applies either options or overrides. Every option that is "set" @@ -227,6 +228,7 @@ class WebRtcVoiceEngine talk_base::scoped_ptr voe_wrapper_; // A secondary instance, for playing out soundclips (on the 'ring' device). talk_base::scoped_ptr voe_wrapper_sc_; + bool voe_wrapper_sc_initialized_; talk_base::scoped_ptr tracing_; // The external audio device manager webrtc::AudioDeviceModule* adm_; diff --git a/talk/media/webrtc/webrtcvoiceengine_unittest.cc b/talk/media/webrtc/webrtcvoiceengine_unittest.cc index 883ff48563..8dc0dffbc7 100644 --- a/talk/media/webrtc/webrtcvoiceengine_unittest.cc +++ b/talk/media/webrtc/webrtcvoiceengine_unittest.cc @@ -294,7 +294,8 @@ TEST_F(WebRtcVoiceEngineTestFake, StartupShutdown) { EXPECT_FALSE(voe_sc_.IsInited()); EXPECT_TRUE(engine_.Init(talk_base::Thread::Current())); EXPECT_TRUE(voe_.IsInited()); - EXPECT_TRUE(voe_sc_.IsInited()); + // The soundclip engine is lazily initialized. + EXPECT_FALSE(voe_sc_.IsInited()); engine_.Terminate(); EXPECT_FALSE(voe_.IsInited()); EXPECT_FALSE(voe_sc_.IsInited()); @@ -2142,7 +2143,9 @@ TEST_F(WebRtcVoiceEngineTestFake, PlayRingbackWithMultipleStreams) { // Tests creating soundclips, and make sure they come from the right engine. TEST_F(WebRtcVoiceEngineTestFake, CreateSoundclip) { EXPECT_TRUE(engine_.Init(talk_base::Thread::Current())); + EXPECT_FALSE(voe_sc_.IsInited()); soundclip_ = engine_.CreateSoundclip(); + EXPECT_TRUE(voe_sc_.IsInited()); ASSERT_TRUE(soundclip_ != NULL); EXPECT_EQ(0, voe_.GetNumChannels()); EXPECT_EQ(1, voe_sc_.GetNumChannels()); @@ -2151,6 +2154,10 @@ TEST_F(WebRtcVoiceEngineTestFake, CreateSoundclip) { delete soundclip_; soundclip_ = NULL; EXPECT_EQ(0, voe_sc_.GetNumChannels()); + // Make sure the soundclip engine is uninitialized on shutdown, now that + // we've initialized it by creating a soundclip. + engine_.Terminate(); + EXPECT_FALSE(voe_sc_.IsInited()); } // Tests playing out a fake sound. diff --git a/talk/p2p/base/dtlstransport.h b/talk/p2p/base/dtlstransport.h index 93da1033e8..7492171ee1 100644 --- a/talk/p2p/base/dtlstransport.h +++ b/talk/p2p/base/dtlstransport.h @@ -58,6 +58,13 @@ class DtlsTransport : public Base { virtual void SetIdentity_w(talk_base::SSLIdentity* identity) { identity_ = identity; } + virtual bool GetIdentity_w(talk_base::SSLIdentity** identity) { + if (!identity_) + return false; + + *identity = identity_->GetReference(); + return true; + } virtual bool ApplyLocalTransportDescription_w(TransportChannelImpl* channel) { diff --git a/talk/p2p/base/dtlstransportchannel.cc b/talk/p2p/base/dtlstransportchannel.cc index dead3a550b..a92e7ccad4 100644 --- a/talk/p2p/base/dtlstransportchannel.cc +++ b/talk/p2p/base/dtlstransportchannel.cc @@ -173,6 +173,15 @@ bool DtlsTransportChannelWrapper::SetLocalIdentity( return true; } +bool DtlsTransportChannelWrapper::GetLocalIdentity( + talk_base::SSLIdentity** identity) const { + if (!local_identity_) + return false; + + *identity = local_identity_->GetReference(); + return true; +} + bool DtlsTransportChannelWrapper::SetSslRole(talk_base::SSLRole role) { if (dtls_state_ == STATE_OPEN) { if (ssl_role_ != role) { @@ -230,6 +239,14 @@ bool DtlsTransportChannelWrapper::SetRemoteFingerprint( return true; } +bool DtlsTransportChannelWrapper::GetRemoteCertificate( + talk_base::SSLCertificate** cert) const { + if (!dtls_) + return false; + + return dtls_->GetPeerCertificate(cert); +} + bool DtlsTransportChannelWrapper::SetupDtls() { StreamInterfaceChannel* downward = new StreamInterfaceChannel(worker_thread_, channel_); diff --git a/talk/p2p/base/dtlstransportchannel.h b/talk/p2p/base/dtlstransportchannel.h index aec8c7ac42..29d97a2977 100644 --- a/talk/p2p/base/dtlstransportchannel.h +++ b/talk/p2p/base/dtlstransportchannel.h @@ -128,6 +128,7 @@ class DtlsTransportChannelWrapper : public TransportChannelImpl { return channel_->GetIceRole(); } virtual bool SetLocalIdentity(talk_base::SSLIdentity *identity); + virtual bool GetLocalIdentity(talk_base::SSLIdentity** identity) const; virtual bool SetRemoteFingerprint(const std::string& digest_alg, const uint8* digest, @@ -164,6 +165,10 @@ class DtlsTransportChannelWrapper : public TransportChannelImpl { virtual bool GetSslRole(talk_base::SSLRole* role) const; virtual bool SetSslRole(talk_base::SSLRole role); + // Once DTLS has been established, this method retrieves the certificate in + // use by the remote peer, for use in external identity verification. + virtual bool GetRemoteCertificate(talk_base::SSLCertificate** cert) const; + // Once DTLS has established (i.e., this channel is writable), this method // extracts the keys negotiated during the DTLS handshake, for use in external // encryption. DTLS-SRTP uses this to extract the needed SRTP keys. diff --git a/talk/p2p/base/dtlstransportchannel_unittest.cc b/talk/p2p/base/dtlstransportchannel_unittest.cc index 267d60be16..6a4e5ade90 100644 --- a/talk/p2p/base/dtlstransportchannel_unittest.cc +++ b/talk/p2p/base/dtlstransportchannel_unittest.cc @@ -751,3 +751,56 @@ TEST_F(DtlsTransportChannelTest, TestDtlsReOfferWithDifferentSetupAttr) { TestTransfer(0, 1000, 100, true); TestTransfer(1, 1000, 100, true); } + +// Test Certificates state after negotiation but before connection. +TEST_F(DtlsTransportChannelTest, TestCertificatesBeforeConnect) { + MAYBE_SKIP_TEST(HaveDtls); + PrepareDtls(true, true); + Negotiate(); + + talk_base::scoped_ptr identity1; + talk_base::scoped_ptr identity2; + talk_base::scoped_ptr remote_cert1; + talk_base::scoped_ptr remote_cert2; + + // After negotiation, each side has a distinct local certificate, but still no + // remote certificate, because connection has not yet occurred. + ASSERT_TRUE(client1_.transport()->GetIdentity(identity1.accept())); + ASSERT_TRUE(client2_.transport()->GetIdentity(identity2.accept())); + ASSERT_NE(identity1->certificate().ToPEMString(), + identity2->certificate().ToPEMString()); + ASSERT_FALSE( + client1_.transport()->GetRemoteCertificate(remote_cert1.accept())); + ASSERT_FALSE(remote_cert1 != NULL); + ASSERT_FALSE( + client2_.transport()->GetRemoteCertificate(remote_cert2.accept())); + ASSERT_FALSE(remote_cert2 != NULL); +} + +// Test Certificates state after connection. +TEST_F(DtlsTransportChannelTest, TestCertificatesAfterConnect) { + MAYBE_SKIP_TEST(HaveDtls); + PrepareDtls(true, true); + ASSERT_TRUE(Connect()); + + talk_base::scoped_ptr identity1; + talk_base::scoped_ptr identity2; + talk_base::scoped_ptr remote_cert1; + talk_base::scoped_ptr remote_cert2; + + // After connection, each side has a distinct local certificate. + ASSERT_TRUE(client1_.transport()->GetIdentity(identity1.accept())); + ASSERT_TRUE(client2_.transport()->GetIdentity(identity2.accept())); + ASSERT_NE(identity1->certificate().ToPEMString(), + identity2->certificate().ToPEMString()); + + // Each side's remote certificate is the other side's local certificate. + ASSERT_TRUE( + client1_.transport()->GetRemoteCertificate(remote_cert1.accept())); + ASSERT_EQ(remote_cert1->ToPEMString(), + identity2->certificate().ToPEMString()); + ASSERT_TRUE( + client2_.transport()->GetRemoteCertificate(remote_cert2.accept())); + ASSERT_EQ(remote_cert2->ToPEMString(), + identity1->certificate().ToPEMString()); +} diff --git a/talk/p2p/base/fakesession.h b/talk/p2p/base/fakesession.h index bc05ce2e13..9a8fadaf2c 100644 --- a/talk/p2p/base/fakesession.h +++ b/talk/p2p/base/fakesession.h @@ -33,6 +33,7 @@ #include #include "talk/base/buffer.h" +#include "talk/base/fakesslidentity.h" #include "talk/base/sigslot.h" #include "talk/base/sslfingerprint.h" #include "talk/base/messagequeue.h" @@ -212,11 +213,16 @@ class FakeTransportChannel : public TransportChannelImpl, return true; } - bool IsDtlsActive() const { + + void SetRemoteCertificate(talk_base::FakeSSLCertificate* cert) { + remote_cert_ = cert; + } + + virtual bool IsDtlsActive() const { return do_dtls_; } - bool SetSrtpCiphers(const std::vector& ciphers) { + virtual bool SetSrtpCiphers(const std::vector& ciphers) { srtp_ciphers_ = ciphers; return true; } @@ -229,6 +235,22 @@ class FakeTransportChannel : public TransportChannelImpl, return false; } + virtual bool GetLocalIdentity(talk_base::SSLIdentity** identity) const { + if (!identity_) + return false; + + *identity = identity_->GetReference(); + return true; + } + + virtual bool GetRemoteCertificate(talk_base::SSLCertificate** cert) const { + if (!remote_cert_) + return false; + + *cert = remote_cert_->GetReference(); + return true; + } + virtual bool ExportKeyingMaterial(const std::string& label, const uint8* context, size_t context_len, @@ -272,6 +294,7 @@ class FakeTransportChannel : public TransportChannelImpl, State state_; bool async_; talk_base::SSLIdentity* identity_; + talk_base::FakeSSLCertificate* remote_cert_; bool do_dtls_; std::vector srtp_ciphers_; std::string chosen_srtp_cipher_; @@ -349,6 +372,16 @@ class FakeTransport : public Transport { channels_.erase(channel->component()); delete channel; } + virtual void SetIdentity_w(talk_base::SSLIdentity* identity) { + identity_ = identity; + } + virtual bool GetIdentity_w(talk_base::SSLIdentity** identity) { + if (!identity_) + return false; + + *identity = identity_->GetReference(); + return true; + } private: FakeTransportChannel* GetFakeChannel(int component) { diff --git a/talk/p2p/base/p2ptransportchannel.h b/talk/p2p/base/p2ptransportchannel.h index 2fc718641f..63ec6aa28b 100644 --- a/talk/p2p/base/p2ptransportchannel.h +++ b/talk/p2p/base/p2ptransportchannel.h @@ -127,6 +127,15 @@ class P2PTransportChannel : public TransportChannelImpl, return false; } + // Returns false because the channel is not encrypted by default. + virtual bool GetLocalIdentity(talk_base::SSLIdentity** identity) const { + return false; + } + + virtual bool GetRemoteCertificate(talk_base::SSLCertificate** cert) const { + return false; + } + // Allows key material to be extracted for external encryption. virtual bool ExportKeyingMaterial( const std::string& label, diff --git a/talk/p2p/base/rawtransportchannel.h b/talk/p2p/base/rawtransportchannel.h index 2aac2b5edf..ed38952d56 100644 --- a/talk/p2p/base/rawtransportchannel.h +++ b/talk/p2p/base/rawtransportchannel.h @@ -128,6 +128,15 @@ class RawTransportChannel : public TransportChannelImpl, return false; } + // Returns false because the channel is not DTLS. + virtual bool GetLocalIdentity(talk_base::SSLIdentity** identity) const { + return false; + } + + virtual bool GetRemoteCertificate(talk_base::SSLCertificate** cert) const { + return false; + } + // Allows key material to be extracted for external encryption. virtual bool ExportKeyingMaterial( const std::string& label, diff --git a/talk/p2p/base/session.h b/talk/p2p/base/session.h index 12310bc8e1..292e7a50f4 100644 --- a/talk/p2p/base/session.h +++ b/talk/p2p/base/session.h @@ -350,7 +350,7 @@ class BaseSession : public sigslot::has_slots<>, // Returns the transport that has been negotiated or NULL if // negotiation is still in progress. - Transport* GetTransport(const std::string& content_name); + virtual Transport* GetTransport(const std::string& content_name); // Creates a new channel with the given names. This method may be called // immediately after creating the session. However, the actual diff --git a/talk/p2p/base/transport.cc b/talk/p2p/base/transport.cc index 3e4ad70406..4404c081a8 100644 --- a/talk/p2p/base/transport.cc +++ b/talk/p2p/base/transport.cc @@ -107,6 +107,29 @@ void Transport::SetIdentity(talk_base::SSLIdentity* identity) { worker_thread_->Invoke(Bind(&Transport::SetIdentity_w, this, identity)); } +bool Transport::GetIdentity(talk_base::SSLIdentity** identity) { + // The identity is set on the worker thread, so for safety it must also be + // acquired on the worker thread. + return worker_thread_->Invoke( + Bind(&Transport::GetIdentity_w, this, identity)); +} + +bool Transport::GetRemoteCertificate(talk_base::SSLCertificate** cert) { + // Channels can be deleted on the worker thread, so for safety the remote + // certificate is acquired on the worker thread. + return worker_thread_->Invoke( + Bind(&Transport::GetRemoteCertificate_w, this, cert)); +} + +bool Transport::GetRemoteCertificate_w(talk_base::SSLCertificate** cert) { + ASSERT(worker_thread()->IsCurrent()); + if (channels_.empty()) + return false; + + ChannelMap::iterator iter = channels_.begin(); + return iter->second->GetRemoteCertificate(cert); +} + bool Transport::SetLocalTransportDescription( const TransportDescription& description, ContentAction action) { return worker_thread_->Invoke(Bind( diff --git a/talk/p2p/base/transport.h b/talk/p2p/base/transport.h index 381215f5d1..f9e9d88745 100644 --- a/talk/p2p/base/transport.h +++ b/talk/p2p/base/transport.h @@ -247,6 +247,12 @@ class Transport : public talk_base::MessageHandler, // Must be called before applying local session description. void SetIdentity(talk_base::SSLIdentity* identity); + // Get a copy of the local identity provided by SetIdentity. + bool GetIdentity(talk_base::SSLIdentity** identity); + + // Get a copy of the remote certificate in use by the specified channel. + bool GetRemoteCertificate(talk_base::SSLCertificate** cert); + TransportProtocol protocol() const { return protocol_; } // Create, destroy, and lookup the channels of this type by their components. @@ -349,6 +355,10 @@ class Transport : public talk_base::MessageHandler, virtual void SetIdentity_w(talk_base::SSLIdentity* identity) {} + virtual bool GetIdentity_w(talk_base::SSLIdentity** identity) { + return false; + } + // Pushes down the transport parameters from the local description, such // as the ICE ufrag and pwd. // Derived classes can override, but must call the base as well. @@ -462,6 +472,8 @@ class Transport : public talk_base::MessageHandler, bool SetRemoteTransportDescription_w(const TransportDescription& desc, ContentAction action); bool GetStats_w(TransportStats* infos); + bool GetRemoteCertificate_w(talk_base::SSLCertificate** cert); + talk_base::Thread* signaling_thread_; talk_base::Thread* worker_thread_; diff --git a/talk/p2p/base/transportchannel.h b/talk/p2p/base/transportchannel.h index 85fff7a9f9..c48e1a5424 100644 --- a/talk/p2p/base/transportchannel.h +++ b/talk/p2p/base/transportchannel.h @@ -101,12 +101,18 @@ class TransportChannel : public sigslot::has_slots<> { // Default implementation. virtual bool GetSslRole(talk_base::SSLRole* role) const = 0; - // Set up the ciphers to use for DTLS-SRTP. + // Sets up the ciphers to use for DTLS-SRTP. virtual bool SetSrtpCiphers(const std::vector& ciphers) = 0; - // Find out which DTLS-SRTP cipher was negotiated + // Finds out which DTLS-SRTP cipher was negotiated virtual bool GetSrtpCipher(std::string* cipher) = 0; + // Gets a copy of the local SSL identity, owned by the caller. + virtual bool GetLocalIdentity(talk_base::SSLIdentity** identity) const = 0; + + // Gets a copy of the remote side's SSL certificate, owned by the caller. + virtual bool GetRemoteCertificate(talk_base::SSLCertificate** cert) const = 0; + // Allows key material to be extracted for external encryption. virtual bool ExportKeyingMaterial(const std::string& label, const uint8* context, diff --git a/talk/p2p/base/transportchannelimpl.h b/talk/p2p/base/transportchannelimpl.h index cde2441307..d8432b7323 100644 --- a/talk/p2p/base/transportchannelimpl.h +++ b/talk/p2p/base/transportchannelimpl.h @@ -93,7 +93,10 @@ class TransportChannelImpl : public TransportChannel { virtual void OnCandidate(const Candidate& candidate) = 0; // DTLS methods - // Set DTLS local identity. + // Set DTLS local identity. The identity object is not copied, but the caller + // retains ownership and must delete it after this TransportChannelImpl is + // destroyed. + // TODO(bemasc): Fix the ownership semantics of this method. virtual bool SetLocalIdentity(talk_base::SSLIdentity* identity) = 0; // Set DTLS Remote fingerprint. Must be after local identity set. diff --git a/talk/p2p/base/transportchannelproxy.cc b/talk/p2p/base/transportchannelproxy.cc index 318d133a04..9a10603e7c 100644 --- a/talk/p2p/base/transportchannelproxy.cc +++ b/talk/p2p/base/transportchannelproxy.cc @@ -180,6 +180,24 @@ bool TransportChannelProxy::GetSrtpCipher(std::string* cipher) { return impl_->GetSrtpCipher(cipher); } +bool TransportChannelProxy::GetLocalIdentity( + talk_base::SSLIdentity** identity) const { + ASSERT(talk_base::Thread::Current() == worker_thread_); + if (!impl_) { + return false; + } + return impl_->GetLocalIdentity(identity); +} + +bool TransportChannelProxy::GetRemoteCertificate( + talk_base::SSLCertificate** cert) const { + ASSERT(talk_base::Thread::Current() == worker_thread_); + if (!impl_) { + return false; + } + return impl_->GetRemoteCertificate(cert); +} + bool TransportChannelProxy::ExportKeyingMaterial(const std::string& label, const uint8* context, size_t context_len, diff --git a/talk/p2p/base/transportchannelproxy.h b/talk/p2p/base/transportchannelproxy.h index 29f4663419..3559ed5883 100644 --- a/talk/p2p/base/transportchannelproxy.h +++ b/talk/p2p/base/transportchannelproxy.h @@ -75,6 +75,8 @@ class TransportChannelProxy : public TransportChannel, virtual bool SetSslRole(talk_base::SSLRole role); virtual bool SetSrtpCiphers(const std::vector& ciphers); virtual bool GetSrtpCipher(std::string* cipher); + virtual bool GetLocalIdentity(talk_base::SSLIdentity** identity) const; + virtual bool GetRemoteCertificate(talk_base::SSLCertificate** cert) const; virtual bool ExportKeyingMaterial(const std::string& label, const uint8* context, size_t context_len,