From 9a6533932fd15018206f9f2dd22918f3a696083c Mon Sep 17 00:00:00 2001 From: Philipp Hancke Date: Fri, 13 Sep 2024 20:30:12 -0700 Subject: [PATCH] srtp: spanify key setters BUG=webrtc:357776213 Change-Id: I307085690588e324409bb32a3db5ec9cfa99df52 Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/362126 Reviewed-by: Harald Alvestrand Reviewed-by: Florent Castelli Commit-Queue: Philipp Hancke Cr-Commit-Position: refs/heads/main@{#43055} --- pc/BUILD.gn | 1 + pc/dtls_srtp_transport.cc | 12 ++-- pc/srtp_session.cc | 50 +++++++++++----- pc/srtp_session.h | 49 +++++++++------ pc/srtp_session_unittest.cc | 72 +++++++++++----------- pc/srtp_transport.cc | 46 +++++++------- pc/srtp_transport.h | 12 ++-- pc/srtp_transport_unittest.cc | 109 ++++++++++++++++------------------ pc/test/srtp_test_util.h | 7 ++- 9 files changed, 190 insertions(+), 168 deletions(-) diff --git a/pc/BUILD.gn b/pc/BUILD.gn index bbd339a753..706f7bda2b 100644 --- a/pc/BUILD.gn +++ b/pc/BUILD.gn @@ -595,6 +595,7 @@ rtc_source_set("srtp_session") { "../api:scoped_refptr", "../api:sequence_checker", "../modules/rtp_rtcp:rtp_rtcp_format", + "../rtc_base:buffer", "../rtc_base:byte_order", "../rtc_base:checks", "../rtc_base:logging", diff --git a/pc/dtls_srtp_transport.cc b/pc/dtls_srtp_transport.cc index d28285dc8d..96a7a09785 100644 --- a/pc/dtls_srtp_transport.cc +++ b/pc/dtls_srtp_transport.cc @@ -165,10 +165,8 @@ void DtlsSrtpTransport::SetupRtpDtlsSrtp() { if (!ExtractParams(rtp_dtls_transport_, &selected_crypto_suite, &send_key, &recv_key) || - !SetRtpParams(selected_crypto_suite, &send_key[0], - static_cast(send_key.size()), send_extension_ids, - selected_crypto_suite, &recv_key[0], - static_cast(recv_key.size()), recv_extension_ids)) { + !SetRtpParams(selected_crypto_suite, send_key, send_extension_ids, + selected_crypto_suite, recv_key, recv_extension_ids)) { RTC_LOG(LS_WARNING) << "DTLS-SRTP key installation for RTP failed"; } } @@ -195,10 +193,8 @@ void DtlsSrtpTransport::SetupRtcpDtlsSrtp() { rtc::ZeroOnFreeBuffer rtcp_recv_key; if (!ExtractParams(rtcp_dtls_transport_, &selected_crypto_suite, &rtcp_send_key, &rtcp_recv_key) || - !SetRtcpParams(selected_crypto_suite, &rtcp_send_key[0], - static_cast(rtcp_send_key.size()), send_extension_ids, - selected_crypto_suite, &rtcp_recv_key[0], - static_cast(rtcp_recv_key.size()), + !SetRtcpParams(selected_crypto_suite, rtcp_send_key, send_extension_ids, + selected_crypto_suite, rtcp_recv_key, recv_extension_ids)) { RTC_LOG(LS_WARNING) << "DTLS-SRTP key installation for RTCP failed"; } diff --git a/pc/srtp_session.cc b/pc/srtp_session.cc index ee78710709..193298400c 100644 --- a/pc/srtp_session.cc +++ b/pc/srtp_session.cc @@ -27,6 +27,7 @@ #include "rtc_base/logging.h" #include "rtc_base/ssl_stream_adapter.h" #include "rtc_base/string_encode.h" +#include "rtc_base/synchronization/mutex.h" #include "rtc_base/thread_annotations.h" #include "rtc_base/time_utils.h" #include "system_wrappers/include/metrics.h" @@ -171,28 +172,52 @@ bool SrtpSession::SetSend(int crypto_suite, const uint8_t* key, size_t len, const std::vector& extension_ids) { - return SetKey(ssrc_any_outbound, crypto_suite, key, len, extension_ids); + return SetSend(crypto_suite, {key, len}, extension_ids); +} + +bool SrtpSession::SetSend(int crypto_suite, + const rtc::ZeroOnFreeBuffer& key, + const std::vector& extension_ids) { + return SetKey(ssrc_any_outbound, crypto_suite, key, extension_ids); } bool SrtpSession::UpdateSend(int crypto_suite, const uint8_t* key, size_t len, const std::vector& extension_ids) { - return UpdateKey(ssrc_any_outbound, crypto_suite, key, len, extension_ids); + return UpdateSend(crypto_suite, {key, len}, extension_ids); +} + +bool SrtpSession::UpdateSend(int crypto_suite, + const rtc::ZeroOnFreeBuffer& key, + const std::vector& extension_ids) { + return UpdateKey(ssrc_any_outbound, crypto_suite, key, extension_ids); } bool SrtpSession::SetRecv(int crypto_suite, const uint8_t* key, size_t len, const std::vector& extension_ids) { - return SetKey(ssrc_any_inbound, crypto_suite, key, len, extension_ids); + return SetReceive(crypto_suite, {key, len}, extension_ids); +} + +bool SrtpSession::SetReceive(int crypto_suite, + const rtc::ZeroOnFreeBuffer& key, + const std::vector& extension_ids) { + return SetKey(ssrc_any_inbound, crypto_suite, key, extension_ids); } bool SrtpSession::UpdateRecv(int crypto_suite, const uint8_t* key, size_t len, const std::vector& extension_ids) { - return UpdateKey(ssrc_any_inbound, crypto_suite, key, len, extension_ids); + return UpdateReceive(crypto_suite, {key, len}, extension_ids); +} + +bool SrtpSession::UpdateReceive(int crypto_suite, + const rtc::ZeroOnFreeBuffer& key, + const std::vector& extension_ids) { + return UpdateKey(ssrc_any_inbound, crypto_suite, key, extension_ids); } bool SrtpSession::ProtectRtp(void* p, int in_len, int max_len, int* out_len) { @@ -390,8 +415,7 @@ bool SrtpSession::GetSendStreamPacketIndex(void* p, bool SrtpSession::DoSetKey(int type, int crypto_suite, - const uint8_t* key, - size_t len, + const rtc::ZeroOnFreeBuffer& key, const std::vector& extension_ids) { RTC_DCHECK(thread_checker_.IsCurrent()); @@ -408,7 +432,7 @@ bool SrtpSession::DoSetKey(int type, return false; } - if (!key || len != static_cast(policy.rtp.cipher_key_len)) { + if (key.size() != static_cast(policy.rtp.cipher_key_len)) { RTC_LOG(LS_ERROR) << "Failed to " << (session_ ? "update" : "create") << " SRTP session: invalid key"; return false; @@ -416,7 +440,7 @@ bool SrtpSession::DoSetKey(int type, policy.ssrc.type = static_cast(type); policy.ssrc.value = 0; - policy.key = const_cast(key); + policy.key = const_cast(key.data()); // TODO(astor) parse window size from WSH session-param policy.window_size = 1024; policy.allow_repeat_tx = 1; @@ -460,8 +484,7 @@ bool SrtpSession::DoSetKey(int type, bool SrtpSession::SetKey(int type, int crypto_suite, - const uint8_t* key, - size_t len, + const rtc::ZeroOnFreeBuffer& key, const std::vector& extension_ids) { RTC_DCHECK(thread_checker_.IsCurrent()); if (session_) { @@ -479,13 +502,12 @@ bool SrtpSession::SetKey(int type, return false; } - return DoSetKey(type, crypto_suite, key, len, extension_ids); + return DoSetKey(type, crypto_suite, key, extension_ids); } bool SrtpSession::UpdateKey(int type, int crypto_suite, - const uint8_t* key, - size_t len, + const rtc::ZeroOnFreeBuffer& key, const std::vector& extension_ids) { RTC_DCHECK(thread_checker_.IsCurrent()); if (!session_) { @@ -493,7 +515,7 @@ bool SrtpSession::UpdateKey(int type, return false; } - return DoSetKey(type, crypto_suite, key, len, extension_ids); + return DoSetKey(type, crypto_suite, key, extension_ids); } void ProhibitLibsrtpInitialization() { diff --git a/pc/srtp_session.h b/pc/srtp_session.h index 560f32fcd4..3bee25abdb 100644 --- a/pc/srtp_session.h +++ b/pc/srtp_session.h @@ -19,7 +19,7 @@ #include "api/field_trials_view.h" #include "api/scoped_refptr.h" #include "api/sequence_checker.h" -#include "rtc_base/synchronization/mutex.h" +#include "rtc_base/buffer.h" // Forward declaration to avoid pulling in libsrtp headers here struct srtp_event_data_t; @@ -44,25 +44,41 @@ class SrtpSession { // Configures the session for sending data using the specified // crypto suite and key. Receiving must be done by a separate session. + [[deprecated("Pass ZeroOnFreeBuffer to SetSend")]] bool SetSend( + int crypto_suite, + const uint8_t* key, + size_t len, + const std::vector& extension_ids); bool SetSend(int crypto_suite, - const uint8_t* key, - size_t len, + const rtc::ZeroOnFreeBuffer& key, const std::vector& extension_ids); + [[deprecated("Pass ZeroOnFreeBuffer to UpdateSend")]] bool UpdateSend( + int crypto_suite, + const uint8_t* key, + size_t len, + const std::vector& extension_ids); bool UpdateSend(int crypto_suite, - const uint8_t* key, - size_t len, + const rtc::ZeroOnFreeBuffer& key, const std::vector& extension_ids); // Configures the session for receiving data using the specified // crypto suite and key. Sending must be done by a separate session. - bool SetRecv(int crypto_suite, - const uint8_t* key, - size_t len, - const std::vector& extension_ids); - bool UpdateRecv(int crypto_suite, - const uint8_t* key, - size_t len, + [[deprecated("Pass ZeroOnFreeBuffer to SetReceive")]] bool SetRecv( + int crypto_suite, + const uint8_t* key, + size_t len, + const std::vector& extension_ids); + bool SetReceive(int crypto_suite, + const rtc::ZeroOnFreeBuffer& key, const std::vector& extension_ids); + [[deprecated("Pass ZeroOnFreeBuffer to UpdateReceive")]] bool UpdateRecv( + int crypto_suite, + const uint8_t* key, + size_t len, + const std::vector& extension_ids); + bool UpdateReceive(int crypto_suite, + const rtc::ZeroOnFreeBuffer& key, + const std::vector& extension_ids); // Encrypts/signs an individual RTP/RTCP packet, in-place. // If an HMAC is used, this will increase the packet size. @@ -108,18 +124,15 @@ class SrtpSession { private: bool DoSetKey(int type, int crypto_suite, - const uint8_t* key, - size_t len, + const rtc::ZeroOnFreeBuffer& key, const std::vector& extension_ids); bool SetKey(int type, int crypto_suite, - const uint8_t* key, - size_t len, + const rtc::ZeroOnFreeBuffer& key, const std::vector& extension_ids); bool UpdateKey(int type, int crypto_suite, - const uint8_t* key, - size_t len, + const rtc::ZeroOnFreeBuffer& key, const std::vector& extension_ids); // Returns send stream current packet index from srtp db. bool GetSendStreamPacketIndex(void* data, int in_len, int64_t* index); diff --git a/pc/srtp_session_unittest.cc b/pc/srtp_session_unittest.cc index eb6562392b..0785cc0c56 100644 --- a/pc/srtp_session_unittest.cc +++ b/pc/srtp_session_unittest.cc @@ -84,38 +84,40 @@ class SrtpSessionTest : public ::testing::Test { // Test that we can set up the session and keys properly. TEST_F(SrtpSessionTest, TestGoodSetup) { - EXPECT_TRUE(s1_.SetSend(kSrtpAes128CmSha1_80, kTestKey1, kTestKeyLen, - kEncryptedHeaderExtensionIds)); - EXPECT_TRUE(s2_.SetRecv(kSrtpAes128CmSha1_80, kTestKey1, kTestKeyLen, + EXPECT_TRUE(s1_.SetSend(kSrtpAes128CmSha1_80, kTestKey1, kEncryptedHeaderExtensionIds)); + EXPECT_TRUE(s2_.SetReceive(kSrtpAes128CmSha1_80, kTestKey1, + kEncryptedHeaderExtensionIds)); } // Test that we can't change the keys once set. TEST_F(SrtpSessionTest, TestBadSetup) { - EXPECT_TRUE(s1_.SetSend(kSrtpAes128CmSha1_80, kTestKey1, kTestKeyLen, + EXPECT_TRUE(s1_.SetSend(kSrtpAes128CmSha1_80, kTestKey1, kEncryptedHeaderExtensionIds)); - EXPECT_TRUE(s2_.SetRecv(kSrtpAes128CmSha1_80, kTestKey1, kTestKeyLen, - kEncryptedHeaderExtensionIds)); - EXPECT_FALSE(s1_.SetSend(kSrtpAes128CmSha1_80, kTestKey2, kTestKeyLen, - kEncryptedHeaderExtensionIds)); - EXPECT_FALSE(s2_.SetRecv(kSrtpAes128CmSha1_80, kTestKey2, kTestKeyLen, + EXPECT_TRUE(s2_.SetReceive(kSrtpAes128CmSha1_80, kTestKey1, + kEncryptedHeaderExtensionIds)); + EXPECT_FALSE(s1_.SetSend(kSrtpAes128CmSha1_80, kTestKey2, kEncryptedHeaderExtensionIds)); + EXPECT_FALSE(s2_.SetReceive(kSrtpAes128CmSha1_80, kTestKey2, + kEncryptedHeaderExtensionIds)); } // Test that we fail keys of the wrong length. TEST_F(SrtpSessionTest, TestKeysTooShort) { - EXPECT_FALSE(s1_.SetSend(kSrtpAes128CmSha1_80, kTestKey1, 1, - kEncryptedHeaderExtensionIds)); - EXPECT_FALSE(s2_.SetRecv(kSrtpAes128CmSha1_80, kTestKey1, 1, + EXPECT_FALSE(s1_.SetSend(kSrtpAes128CmSha1_80, + rtc::ZeroOnFreeBuffer(kTestKey1.data(), 1), kEncryptedHeaderExtensionIds)); + EXPECT_FALSE(s2_.SetReceive( + kSrtpAes128CmSha1_80, rtc::ZeroOnFreeBuffer(kTestKey1.data(), 1), + kEncryptedHeaderExtensionIds)); } // Test that we can encrypt and decrypt RTP/RTCP using AES_CM_128_HMAC_SHA1_80. TEST_F(SrtpSessionTest, TestProtect_AES_CM_128_HMAC_SHA1_80) { - EXPECT_TRUE(s1_.SetSend(kSrtpAes128CmSha1_80, kTestKey1, kTestKeyLen, - kEncryptedHeaderExtensionIds)); - EXPECT_TRUE(s2_.SetRecv(kSrtpAes128CmSha1_80, kTestKey1, kTestKeyLen, + EXPECT_TRUE(s1_.SetSend(kSrtpAes128CmSha1_80, kTestKey1, kEncryptedHeaderExtensionIds)); + EXPECT_TRUE(s2_.SetReceive(kSrtpAes128CmSha1_80, kTestKey1, + kEncryptedHeaderExtensionIds)); TestProtectRtp(kSrtpAes128CmSha1_80); TestProtectRtcp(kSrtpAes128CmSha1_80); TestUnprotectRtp(kSrtpAes128CmSha1_80); @@ -124,10 +126,10 @@ TEST_F(SrtpSessionTest, TestProtect_AES_CM_128_HMAC_SHA1_80) { // Test that we can encrypt and decrypt RTP/RTCP using AES_CM_128_HMAC_SHA1_32. TEST_F(SrtpSessionTest, TestProtect_AES_CM_128_HMAC_SHA1_32) { - EXPECT_TRUE(s1_.SetSend(kSrtpAes128CmSha1_32, kTestKey1, kTestKeyLen, - kEncryptedHeaderExtensionIds)); - EXPECT_TRUE(s2_.SetRecv(kSrtpAes128CmSha1_32, kTestKey1, kTestKeyLen, + EXPECT_TRUE(s1_.SetSend(kSrtpAes128CmSha1_32, kTestKey1, kEncryptedHeaderExtensionIds)); + EXPECT_TRUE(s2_.SetReceive(kSrtpAes128CmSha1_32, kTestKey1, + kEncryptedHeaderExtensionIds)); TestProtectRtp(kSrtpAes128CmSha1_32); TestProtectRtcp(kSrtpAes128CmSha1_32); TestUnprotectRtp(kSrtpAes128CmSha1_32); @@ -135,7 +137,7 @@ TEST_F(SrtpSessionTest, TestProtect_AES_CM_128_HMAC_SHA1_32) { } TEST_F(SrtpSessionTest, TestGetSendStreamPacketIndex) { - EXPECT_TRUE(s1_.SetSend(kSrtpAes128CmSha1_32, kTestKey1, kTestKeyLen, + EXPECT_TRUE(s1_.SetSend(kSrtpAes128CmSha1_32, kTestKey1, kEncryptedHeaderExtensionIds)); int64_t index; int out_len = 0; @@ -149,10 +151,10 @@ TEST_F(SrtpSessionTest, TestGetSendStreamPacketIndex) { // Test that we fail to unprotect if someone tampers with the RTP/RTCP paylaods. TEST_F(SrtpSessionTest, TestTamperReject) { int out_len; - EXPECT_TRUE(s1_.SetSend(kSrtpAes128CmSha1_80, kTestKey1, kTestKeyLen, - kEncryptedHeaderExtensionIds)); - EXPECT_TRUE(s2_.SetRecv(kSrtpAes128CmSha1_80, kTestKey1, kTestKeyLen, + EXPECT_TRUE(s1_.SetSend(kSrtpAes128CmSha1_80, kTestKey1, kEncryptedHeaderExtensionIds)); + EXPECT_TRUE(s2_.SetReceive(kSrtpAes128CmSha1_80, kTestKey1, + kEncryptedHeaderExtensionIds)); TestProtectRtp(kSrtpAes128CmSha1_80); TestProtectRtcp(kSrtpAes128CmSha1_80); rtp_packet_[0] = 0x12; @@ -170,10 +172,10 @@ TEST_F(SrtpSessionTest, TestTamperReject) { // Test that we fail to unprotect if the payloads are not authenticated. TEST_F(SrtpSessionTest, TestUnencryptReject) { int out_len; - EXPECT_TRUE(s1_.SetSend(kSrtpAes128CmSha1_80, kTestKey1, kTestKeyLen, - kEncryptedHeaderExtensionIds)); - EXPECT_TRUE(s2_.SetRecv(kSrtpAes128CmSha1_80, kTestKey1, kTestKeyLen, + EXPECT_TRUE(s1_.SetSend(kSrtpAes128CmSha1_80, kTestKey1, kEncryptedHeaderExtensionIds)); + EXPECT_TRUE(s2_.SetReceive(kSrtpAes128CmSha1_80, kTestKey1, + kEncryptedHeaderExtensionIds)); EXPECT_FALSE(s2_.UnprotectRtp(rtp_packet_, rtp_len_, &out_len)); EXPECT_METRIC_THAT( webrtc::metrics::Samples("WebRTC.PeerConnection.SrtpUnprotectError"), @@ -187,7 +189,7 @@ TEST_F(SrtpSessionTest, TestUnencryptReject) { // Test that we fail when using buffers that are too small. TEST_F(SrtpSessionTest, TestBuffersTooSmall) { int out_len; - EXPECT_TRUE(s1_.SetSend(kSrtpAes128CmSha1_80, kTestKey1, kTestKeyLen, + EXPECT_TRUE(s1_.SetSend(kSrtpAes128CmSha1_80, kTestKey1, kEncryptedHeaderExtensionIds)); EXPECT_FALSE(s1_.ProtectRtp(rtp_packet_, rtp_len_, sizeof(rtp_packet_) - 10, &out_len)); @@ -202,10 +204,10 @@ TEST_F(SrtpSessionTest, TestReplay) { static const uint16_t replay_window = 1024; int out_len; - EXPECT_TRUE(s1_.SetSend(kSrtpAes128CmSha1_80, kTestKey1, kTestKeyLen, - kEncryptedHeaderExtensionIds)); - EXPECT_TRUE(s2_.SetRecv(kSrtpAes128CmSha1_80, kTestKey1, kTestKeyLen, + EXPECT_TRUE(s1_.SetSend(kSrtpAes128CmSha1_80, kTestKey1, kEncryptedHeaderExtensionIds)); + EXPECT_TRUE(s2_.SetReceive(kSrtpAes128CmSha1_80, kTestKey1, + kEncryptedHeaderExtensionIds)); // Initial sequence number. SetBE16(reinterpret_cast(rtp_packet_) + 2, seqnum_big); @@ -253,10 +255,10 @@ TEST_F(SrtpSessionTest, TestReplay) { } TEST_F(SrtpSessionTest, RemoveSsrc) { - EXPECT_TRUE(s1_.SetSend(kSrtpAes128CmSha1_80, kTestKey1, kTestKeyLen, - kEncryptedHeaderExtensionIds)); - EXPECT_TRUE(s2_.SetRecv(kSrtpAes128CmSha1_80, kTestKey1, kTestKeyLen, + EXPECT_TRUE(s1_.SetSend(kSrtpAes128CmSha1_80, kTestKey1, kEncryptedHeaderExtensionIds)); + EXPECT_TRUE(s2_.SetReceive(kSrtpAes128CmSha1_80, kTestKey1, + kEncryptedHeaderExtensionIds)); int out_len; // Encrypt and decrypt the packet once. EXPECT_TRUE( @@ -290,10 +292,10 @@ TEST_F(SrtpSessionTest, ProtectUnprotectWrapAroundRocMismatch) { // failures when it wraps around with packet loss. Pick your starting // sequence number in the lower half of the range for robustness reasons, // see packet_sequencer.cc for the code doing so. - EXPECT_TRUE(s1_.SetSend(kSrtpAes128CmSha1_80, kTestKey1, kTestKeyLen, - kEncryptedHeaderExtensionIds)); - EXPECT_TRUE(s2_.SetRecv(kSrtpAes128CmSha1_80, kTestKey1, kTestKeyLen, + EXPECT_TRUE(s1_.SetSend(kSrtpAes128CmSha1_80, kTestKey1, kEncryptedHeaderExtensionIds)); + EXPECT_TRUE(s2_.SetReceive(kSrtpAes128CmSha1_80, kTestKey1, + kEncryptedHeaderExtensionIds)); // Buffers include enough room for the 10 byte SRTP auth tag so we can // encrypt in place. unsigned char kFrame1[] = { diff --git a/pc/srtp_transport.cc b/pc/srtp_transport.cc index 4340f82113..c603385502 100644 --- a/pc/srtp_transport.cc +++ b/pc/srtp_transport.cc @@ -191,17 +191,15 @@ void SrtpTransport::OnWritableState( } bool SrtpTransport::SetRtpParams(int send_crypto_suite, - const uint8_t* send_key, - int send_key_len, + const rtc::ZeroOnFreeBuffer& send_key, const std::vector& send_extension_ids, int recv_crypto_suite, - const uint8_t* recv_key, - int recv_key_len, + const rtc::ZeroOnFreeBuffer& recv_key, const std::vector& recv_extension_ids) { // If parameters are being set for the first time, we should create new SRTP - // sessions and call "SetSend/SetRecv". Otherwise we should call - // "UpdateSend"/"UpdateRecv" on the existing sessions, which will internally - // call "srtp_update". + // sessions and call "SetSend/SetReceive". Otherwise we should call + // "UpdateSend"/"UpdateReceive" on the existing sessions, which will + // internally call "srtp_update". bool new_sessions = false; if (!send_session_) { RTC_DCHECK(!recv_session_); @@ -210,19 +208,18 @@ bool SrtpTransport::SetRtpParams(int send_crypto_suite, } bool ret = new_sessions ? send_session_->SetSend(send_crypto_suite, send_key, - send_key_len, send_extension_ids) + send_extension_ids) : send_session_->UpdateSend(send_crypto_suite, send_key, - send_key_len, send_extension_ids); + send_extension_ids); if (!ret) { ResetParams(); return false; } - ret = new_sessions - ? recv_session_->SetRecv(recv_crypto_suite, recv_key, recv_key_len, - recv_extension_ids) - : recv_session_->UpdateRecv(recv_crypto_suite, recv_key, - recv_key_len, recv_extension_ids); + ret = new_sessions ? recv_session_->SetReceive(recv_crypto_suite, recv_key, + recv_extension_ids) + : recv_session_->UpdateReceive(recv_crypto_suite, recv_key, + recv_extension_ids); if (!ret) { ResetParams(); return false; @@ -236,14 +233,13 @@ bool SrtpTransport::SetRtpParams(int send_crypto_suite, return true; } -bool SrtpTransport::SetRtcpParams(int send_crypto_suite, - const uint8_t* send_key, - int send_key_len, - const std::vector& send_extension_ids, - int recv_crypto_suite, - const uint8_t* recv_key, - int recv_key_len, - const std::vector& recv_extension_ids) { +bool SrtpTransport::SetRtcpParams( + int send_crypto_suite, + const rtc::ZeroOnFreeBuffer& send_key, + const std::vector& send_extension_ids, + int recv_crypto_suite, + const rtc::ZeroOnFreeBuffer& recv_key, + const std::vector& recv_extension_ids) { // This can only be called once, but can be safely called after // SetRtpParams if (send_rtcp_session_ || recv_rtcp_session_) { @@ -252,14 +248,14 @@ bool SrtpTransport::SetRtcpParams(int send_crypto_suite, } send_rtcp_session_.reset(new cricket::SrtpSession(field_trials_)); - if (!send_rtcp_session_->SetSend(send_crypto_suite, send_key, send_key_len, + if (!send_rtcp_session_->SetSend(send_crypto_suite, send_key, send_extension_ids)) { return false; } recv_rtcp_session_.reset(new cricket::SrtpSession(field_trials_)); - if (!recv_rtcp_session_->SetRecv(recv_crypto_suite, recv_key, recv_key_len, - recv_extension_ids)) { + if (!recv_rtcp_session_->SetReceive(recv_crypto_suite, recv_key, + recv_extension_ids)) { return false; } diff --git a/pc/srtp_transport.h b/pc/srtp_transport.h index baa164e983..dd86006ee1 100644 --- a/pc/srtp_transport.h +++ b/pc/srtp_transport.h @@ -58,24 +58,20 @@ class SrtpTransport : public RtpTransport { // packet encryption. The keys can either come from SDES negotiation or DTLS // handshake. bool SetRtpParams(int send_crypto_suite, - const uint8_t* send_key, - int send_key_len, + const rtc::ZeroOnFreeBuffer& send_key, const std::vector& send_extension_ids, int recv_crypto_suite, - const uint8_t* recv_key, - int recv_key_len, + const rtc::ZeroOnFreeBuffer& recv_key, const std::vector& recv_extension_ids); // Create new send/recv sessions and set the negotiated crypto keys for RTCP // packet encryption. The keys can either come from SDES negotiation or DTLS // handshake. bool SetRtcpParams(int send_crypto_suite, - const uint8_t* send_key, - int send_key_len, + const rtc::ZeroOnFreeBuffer& send_key, const std::vector& send_extension_ids, int recv_crypto_suite, - const uint8_t* recv_key, - int recv_key_len, + const rtc::ZeroOnFreeBuffer& recv_key, const std::vector& recv_extension_ids); void ResetParams(); diff --git a/pc/srtp_transport_unittest.cc b/pc/srtp_transport_unittest.cc index a1f153a343..5064a9c601 100644 --- a/pc/srtp_transport_unittest.cc +++ b/pc/srtp_transport_unittest.cc @@ -32,17 +32,18 @@ using rtc::kSrtpAeadAes128Gcm; using rtc::kTestKey1; using rtc::kTestKey2; -using rtc::kTestKeyLen; namespace webrtc { -static const uint8_t kTestKeyGcm128_1[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZ12"; -static const uint8_t kTestKeyGcm128_2[] = "21ZYXWVUTSRQPONMLKJIHGFEDCBA"; -static const int kTestKeyGcm128Len = 28; // 128 bits key + 96 bits salt. -static const uint8_t kTestKeyGcm256_1[] = - "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqr"; -static const uint8_t kTestKeyGcm256_2[] = - "rqponmlkjihgfedcbaZYXWVUTSRQPONMLKJIHGFEDCBA"; -static const int kTestKeyGcm256Len = 44; // 256 bits key + 96 bits salt. +// 128 bits key + 96 bits salt. +static const rtc::ZeroOnFreeBuffer kTestKeyGcm128_1{ + "ABCDEFGHIJKLMNOPQRSTUVWXYZ12", 28}; +static const rtc::ZeroOnFreeBuffer kTestKeyGcm128_2{ + "21ZYXWVUTSRQPONMLKJIHGFEDCBA", 28}; +// 256 bits key + 96 bits salt. +static const rtc::ZeroOnFreeBuffer kTestKeyGcm256_1{ + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqr", 44}; +static const rtc::ZeroOnFreeBuffer kTestKeyGcm256_2{ + "rqponmlkjihgfedcbaZYXWVUTSRQPONMLKJIHGFEDCBA", 44}; class SrtpTransportTest : public ::testing::Test, public sigslot::has_slots<> { protected: @@ -216,28 +217,22 @@ class SrtpTransportTest : public ::testing::Test, public sigslot::has_slots<> { void TestSendRecvPacket(bool enable_external_auth, int crypto_suite, - const uint8_t* key1, - int key1_len, - const uint8_t* key2, - int key2_len) { - EXPECT_EQ(key1_len, key2_len); + const rtc::ZeroOnFreeBuffer& key1, + const rtc::ZeroOnFreeBuffer& key2) { + EXPECT_EQ(key1.size(), key2.size()); if (enable_external_auth) { srtp_transport1_->EnableExternalAuth(); srtp_transport2_->EnableExternalAuth(); } std::vector extension_ids; - EXPECT_TRUE(srtp_transport1_->SetRtpParams(crypto_suite, key1, key1_len, - extension_ids, crypto_suite, - key2, key2_len, extension_ids)); - EXPECT_TRUE(srtp_transport2_->SetRtpParams(crypto_suite, key2, key2_len, - extension_ids, crypto_suite, - key1, key1_len, extension_ids)); - EXPECT_TRUE(srtp_transport1_->SetRtcpParams(crypto_suite, key1, key1_len, - extension_ids, crypto_suite, - key2, key2_len, extension_ids)); - EXPECT_TRUE(srtp_transport2_->SetRtcpParams(crypto_suite, key2, key2_len, - extension_ids, crypto_suite, - key1, key1_len, extension_ids)); + EXPECT_TRUE(srtp_transport1_->SetRtpParams( + crypto_suite, key1, extension_ids, crypto_suite, key2, extension_ids)); + EXPECT_TRUE(srtp_transport2_->SetRtpParams( + crypto_suite, key2, extension_ids, crypto_suite, key1, extension_ids)); + EXPECT_TRUE(srtp_transport1_->SetRtcpParams( + crypto_suite, key1, extension_ids, crypto_suite, key2, extension_ids)); + EXPECT_TRUE(srtp_transport2_->SetRtcpParams( + crypto_suite, key2, extension_ids, crypto_suite, key1, extension_ids)); EXPECT_TRUE(srtp_transport1_->IsSrtpActive()); EXPECT_TRUE(srtp_transport2_->IsSrtpActive()); if (rtc::IsGcmCryptoSuite(crypto_suite)) { @@ -308,22 +303,21 @@ class SrtpTransportTest : public ::testing::Test, public sigslot::has_slots<> { original_rtp_data, rtp_len, encrypted_header_ids, false); } - void TestSendRecvEncryptedHeaderExtension(int crypto_suite, - const uint8_t* key1, - int key1_len, - const uint8_t* key2, - int key2_len) { + void TestSendRecvEncryptedHeaderExtension( + int crypto_suite, + const rtc::ZeroOnFreeBuffer& key1, + const rtc::ZeroOnFreeBuffer& key2) { std::vector encrypted_headers; encrypted_headers.push_back(kHeaderExtensionIDs[0]); // Don't encrypt header ids 2 and 3. encrypted_headers.push_back(kHeaderExtensionIDs[1]); - EXPECT_EQ(key1_len, key2_len); - EXPECT_TRUE(srtp_transport1_->SetRtpParams( - crypto_suite, key1, key1_len, encrypted_headers, crypto_suite, key2, - key2_len, encrypted_headers)); - EXPECT_TRUE(srtp_transport2_->SetRtpParams( - crypto_suite, key2, key2_len, encrypted_headers, crypto_suite, key1, - key1_len, encrypted_headers)); + EXPECT_EQ(key1.size(), key2.size()); + EXPECT_TRUE(srtp_transport1_->SetRtpParams(crypto_suite, key1, + encrypted_headers, crypto_suite, + key2, encrypted_headers)); + EXPECT_TRUE(srtp_transport2_->SetRtpParams(crypto_suite, key2, + encrypted_headers, crypto_suite, + key1, encrypted_headers)); EXPECT_TRUE(srtp_transport1_->IsSrtpActive()); EXPECT_TRUE(srtp_transport2_->IsSrtpActive()); EXPECT_FALSE(srtp_transport1_->IsExternalAuthActive()); @@ -353,56 +347,52 @@ TEST_P(SrtpTransportTestWithExternalAuth, SendAndRecvPacket_AES_CM_128_HMAC_SHA1_80) { bool enable_external_auth = GetParam(); TestSendRecvPacket(enable_external_auth, rtc::kSrtpAes128CmSha1_80, kTestKey1, - kTestKeyLen, kTestKey2, kTestKeyLen); + kTestKey2); } TEST_F(SrtpTransportTest, SendAndRecvPacketWithHeaderExtension_AES_CM_128_HMAC_SHA1_80) { TestSendRecvEncryptedHeaderExtension(rtc::kSrtpAes128CmSha1_80, kTestKey1, - kTestKeyLen, kTestKey2, kTestKeyLen); + kTestKey2); } TEST_P(SrtpTransportTestWithExternalAuth, SendAndRecvPacket_AES_CM_128_HMAC_SHA1_32) { bool enable_external_auth = GetParam(); TestSendRecvPacket(enable_external_auth, rtc::kSrtpAes128CmSha1_32, kTestKey1, - kTestKeyLen, kTestKey2, kTestKeyLen); + kTestKey2); } TEST_F(SrtpTransportTest, SendAndRecvPacketWithHeaderExtension_AES_CM_128_HMAC_SHA1_32) { TestSendRecvEncryptedHeaderExtension(rtc::kSrtpAes128CmSha1_32, kTestKey1, - kTestKeyLen, kTestKey2, kTestKeyLen); + kTestKey2); } TEST_P(SrtpTransportTestWithExternalAuth, SendAndRecvPacket_kSrtpAeadAes128Gcm) { bool enable_external_auth = GetParam(); TestSendRecvPacket(enable_external_auth, rtc::kSrtpAeadAes128Gcm, - kTestKeyGcm128_1, kTestKeyGcm128Len, kTestKeyGcm128_2, - kTestKeyGcm128Len); + kTestKeyGcm128_1, kTestKeyGcm128_2); } TEST_F(SrtpTransportTest, SendAndRecvPacketWithHeaderExtension_kSrtpAeadAes128Gcm) { TestSendRecvEncryptedHeaderExtension(rtc::kSrtpAeadAes128Gcm, - kTestKeyGcm128_1, kTestKeyGcm128Len, - kTestKeyGcm128_2, kTestKeyGcm128Len); + kTestKeyGcm128_1, kTestKeyGcm128_2); } TEST_P(SrtpTransportTestWithExternalAuth, SendAndRecvPacket_kSrtpAeadAes256Gcm) { bool enable_external_auth = GetParam(); TestSendRecvPacket(enable_external_auth, rtc::kSrtpAeadAes256Gcm, - kTestKeyGcm256_1, kTestKeyGcm256Len, kTestKeyGcm256_2, - kTestKeyGcm256Len); + kTestKeyGcm256_1, kTestKeyGcm256_2); } TEST_F(SrtpTransportTest, SendAndRecvPacketWithHeaderExtension_kSrtpAeadAes256Gcm) { TestSendRecvEncryptedHeaderExtension(rtc::kSrtpAeadAes256Gcm, - kTestKeyGcm256_1, kTestKeyGcm256Len, - kTestKeyGcm256_2, kTestKeyGcm256Len); + kTestKeyGcm256_1, kTestKeyGcm256_2); } // Run all tests both with and without external auth enabled. @@ -414,11 +404,17 @@ INSTANTIATE_TEST_SUITE_P(ExternalAuth, TEST_F(SrtpTransportTest, TestSetParamsKeyTooShort) { std::vector extension_ids; EXPECT_FALSE(srtp_transport1_->SetRtpParams( - rtc::kSrtpAes128CmSha1_80, kTestKey1, kTestKeyLen - 1, extension_ids, - rtc::kSrtpAes128CmSha1_80, kTestKey1, kTestKeyLen - 1, extension_ids)); + rtc::kSrtpAes128CmSha1_80, + rtc::ZeroOnFreeBuffer(kTestKey1.data(), kTestKey1.size() - 1), + extension_ids, rtc::kSrtpAes128CmSha1_80, + rtc::ZeroOnFreeBuffer(kTestKey1.data(), kTestKey1.size() - 1), + extension_ids)); EXPECT_FALSE(srtp_transport1_->SetRtcpParams( - rtc::kSrtpAes128CmSha1_80, kTestKey1, kTestKeyLen - 1, extension_ids, - rtc::kSrtpAes128CmSha1_80, kTestKey1, kTestKeyLen - 1, extension_ids)); + rtc::kSrtpAes128CmSha1_80, + rtc::ZeroOnFreeBuffer(kTestKey1.data(), kTestKey1.size() - 1), + extension_ids, rtc::kSrtpAes128CmSha1_80, + rtc::ZeroOnFreeBuffer(kTestKey1.data(), kTestKey1.size() - 1), + extension_ids)); } TEST_F(SrtpTransportTest, RemoveSrtpReceiveStream) { @@ -437,9 +433,8 @@ TEST_F(SrtpTransportTest, RemoveSrtpReceiveStream) { std::vector extension_ids; EXPECT_TRUE(srtp_transport->SetRtpParams( - rtc::kSrtpAeadAes128Gcm, kTestKeyGcm128_1, kTestKeyGcm128Len, - extension_ids, rtc::kSrtpAeadAes128Gcm, kTestKeyGcm128_1, - kTestKeyGcm128Len, extension_ids)); + rtc::kSrtpAeadAes128Gcm, kTestKeyGcm128_1, extension_ids, + rtc::kSrtpAeadAes128Gcm, kTestKeyGcm128_1, extension_ids)); RtpDemuxerCriteria demuxer_criteria; uint32_t ssrc = 0x1; // SSRC of kPcmuFrame diff --git a/pc/test/srtp_test_util.h b/pc/test/srtp_test_util.h index 6f74bf482b..ef6069ab6c 100644 --- a/pc/test/srtp_test_util.h +++ b/pc/test/srtp_test_util.h @@ -15,9 +15,10 @@ namespace rtc { -static const uint8_t kTestKey1[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZ1234"; -static const uint8_t kTestKey2[] = "4321ZYXWVUTSRQPONMLKJIHGFEDCBA"; -static const int kTestKeyLen = 30; +static const rtc::ZeroOnFreeBuffer kTestKey1{ + "ABCDEFGHIJKLMNOPQRSTUVWXYZ1234", 30}; +static const rtc::ZeroOnFreeBuffer kTestKey2{ + "4321ZYXWVUTSRQPONMLKJIHGFEDCBA", 30}; static int rtp_auth_tag_len(int crypto_suite) { switch (crypto_suite) {