From c99b6c793653b57cdf462e37109b1ed4d3addeea Mon Sep 17 00:00:00 2001 From: Zhi Huang Date: Fri, 10 Nov 2017 16:44:46 -0800 Subject: [PATCH] Remove the SetEncryptedHeaderExtensionIds methods. The existing methods SetEncrypedHeaderExtensionIds in SrtpTransport and SrtpSession are removed because those methods could be confusing. When these methods are called the head extension IDs are not actually updated and the user need to call SetRtpParams again to make that happen. The existing setter just caches the new IDs. To make it less confusing, the SetEncryptedHeaderExtensionIds is removed and the new extension IDs will be set immediately when setting the crypto params. For SDES, the crypto params and the header extension IDs will be set at the same time. For DTLS, the new header extensions are cached in BaseChannel and will be set when the DTLS handshake is completed. Another major change is that when doing DTLS-SRTP, the encrypted header extension IDs will be updated only when they are changed. Bug: webrtc:7013 Change-Id: Ib70d4797456ae5ecb61b3dfff15c7e3e7ede89bd Reviewed-on: https://webrtc-review.googlesource.com/15860 Commit-Queue: Zhi Huang Reviewed-by: Peter Thatcher Cr-Commit-Position: refs/heads/master@{#20639} --- pc/channel.cc | 73 +++++++++++++++++++++++++++--------- pc/channel.h | 16 ++++++++ pc/srtpsession.cc | 63 +++++++++++++++++++------------ pc/srtpsession.h | 45 +++++++++++++++------- pc/srtpsession_unittest.cc | 62 ++++++++++++++++++++---------- pc/srtptransport.cc | 41 +++++++++----------- pc/srtptransport.h | 16 +++----- pc/srtptransport_unittest.cc | 44 ++++++++++------------ 8 files changed, 229 insertions(+), 131 deletions(-) diff --git a/pc/channel.cc b/pc/channel.cc index 21e666cf7d..3950c03a10 100644 --- a/pc/channel.cc +++ b/pc/channel.cc @@ -878,13 +878,26 @@ bool BaseChannel::SetupDtlsSrtp_n(bool rtcp) { recv_key = &server_write_key; } + // Use an empty encrypted header extension ID vector if not set. This could + // happen when the DTLS handshake is completed before processing the + // Offer/Answer which contains the encrypted header extension IDs. + std::vector send_extension_ids; + std::vector recv_extension_ids; + if (catched_send_extension_ids_) { + send_extension_ids = *catched_send_extension_ids_; + } + if (catched_recv_extension_ids_) { + recv_extension_ids = *catched_recv_extension_ids_; + } + if (rtcp) { if (!dtls_active()) { RTC_DCHECK(srtp_transport_); ret = srtp_transport_->SetRtcpParams( selected_crypto_suite, &(*send_key)[0], - static_cast(send_key->size()), selected_crypto_suite, - &(*recv_key)[0], static_cast(recv_key->size())); + static_cast(send_key->size()), send_extension_ids, + selected_crypto_suite, &(*recv_key)[0], + static_cast(recv_key->size()), recv_extension_ids); } else { // RTCP doesn't need to call SetRtpParam because it is only used // to make the updated encrypted RTP header extension IDs take effect. @@ -892,10 +905,11 @@ bool BaseChannel::SetupDtlsSrtp_n(bool rtcp) { } } else { RTC_DCHECK(srtp_transport_); - ret = srtp_transport_->SetRtpParams(selected_crypto_suite, &(*send_key)[0], - static_cast(send_key->size()), - selected_crypto_suite, &(*recv_key)[0], - static_cast(recv_key->size())); + ret = srtp_transport_->SetRtpParams( + selected_crypto_suite, &(*send_key)[0], + static_cast(send_key->size()), send_extension_ids, + selected_crypto_suite, &(*recv_key)[0], + static_cast(recv_key->size()), recv_extension_ids); dtls_active_ = ret; } @@ -1043,10 +1057,11 @@ bool BaseChannel::SetSrtp_n(const std::vector& cryptos, if (!srtp_transport_ && !dtls && !cryptos.empty()) { EnableSrtpTransport_n(); } - if (srtp_transport_) { - srtp_transport_->SetEncryptedHeaderExtensionIds(src, - encrypted_extension_ids); - } + + bool encrypted_header_extensions_id_changed = + EncryptedHeaderExtensionIdsChanged(src, encrypted_extension_ids); + CacheEncryptedHeaderExtensionIds(src, encrypted_extension_ids); + switch (action) { case CA_OFFER: // If DTLS is already active on the channel, we could be renegotiating @@ -1078,13 +1093,17 @@ bool BaseChannel::SetSrtp_n(const std::vector& cryptos, if ((action == CA_PRANSWER || action == CA_ANSWER) && !dtls && ret) { if (sdes_negotiator_.send_cipher_suite() && sdes_negotiator_.recv_cipher_suite()) { + RTC_DCHECK(catched_send_extension_ids_); + RTC_DCHECK(catched_recv_extension_ids_); ret = srtp_transport_->SetRtpParams( *(sdes_negotiator_.send_cipher_suite()), sdes_negotiator_.send_key().data(), static_cast(sdes_negotiator_.send_key().size()), + *(catched_send_extension_ids_), *(sdes_negotiator_.recv_cipher_suite()), sdes_negotiator_.recv_key().data(), - static_cast(sdes_negotiator_.recv_key().size())); + static_cast(sdes_negotiator_.recv_key().size()), + *(catched_recv_extension_ids_)); } else { RTC_LOG(LS_INFO) << "No crypto keys are provided for SDES."; if (action == CA_ANSWER && srtp_transport_) { @@ -1096,16 +1115,16 @@ bool BaseChannel::SetSrtp_n(const std::vector& cryptos, } } - // Only update SRTP filter if using DTLS. SDES is handled internally + // Only update SRTP transport if using DTLS. SDES is handled internally // by the SRTP filter. - // TODO(jbauch): Only update if encrypted extension ids have changed. if (ret && dtls_active() && rtp_dtls_transport_ && - rtp_dtls_transport_->dtls_state() == DTLS_TRANSPORT_CONNECTED) { - bool rtcp = false; - ret = SetupDtlsSrtp_n(rtcp); + rtp_dtls_transport_->dtls_state() == DTLS_TRANSPORT_CONNECTED && + encrypted_header_extensions_id_changed) { + ret = SetupDtlsSrtp_n(/*rtcp=*/false); } + if (!ret) { - SafeSetError("Failed to setup SRTP filter.", error_desc); + SafeSetError("Failed to setup SRTP.", error_desc); return false; } return true; @@ -1433,6 +1452,26 @@ void BaseChannel::SignalSentPacket_w(const rtc::SentPacket& sent_packet) { SignalSentPacket(sent_packet); } +void BaseChannel::CacheEncryptedHeaderExtensionIds( + cricket::ContentSource source, + const std::vector& extension_ids) { + source == ContentSource::CS_LOCAL + ? catched_recv_extension_ids_.emplace(extension_ids) + : catched_send_extension_ids_.emplace(extension_ids); +} + +bool BaseChannel::EncryptedHeaderExtensionIdsChanged( + cricket::ContentSource source, + const std::vector& new_extension_ids) { + if (source == ContentSource::CS_LOCAL) { + return !catched_recv_extension_ids_ || + (*catched_recv_extension_ids_) != new_extension_ids; + } else { + return !catched_send_extension_ids_ || + (*catched_send_extension_ids_) != new_extension_ids; + } +} + VoiceChannel::VoiceChannel(rtc::Thread* worker_thread, rtc::Thread* network_thread, rtc::Thread* signaling_thread, diff --git a/pc/channel.h b/pc/channel.h index 5689338230..ec13f07cf3 100644 --- a/pc/channel.h +++ b/pc/channel.h @@ -368,6 +368,18 @@ class BaseChannel // Wraps the existing RtpTransport in an SrtpTransport. void EnableSrtpTransport_n(); + // Cache the encrypted header extension IDs when setting the local/remote + // description and use them later together with other crypto parameters from + // DtlsTransport. + void CacheEncryptedHeaderExtensionIds(cricket::ContentSource source, + const std::vector& extension_ids); + + // Return true if the new header extension IDs are different from the existing + // ones. + bool EncryptedHeaderExtensionIdsChanged( + cricket::ContentSource source, + const std::vector& new_extension_ids); + rtc::Thread* const worker_thread_; rtc::Thread* const network_thread_; rtc::Thread* const signaling_thread_; @@ -410,6 +422,10 @@ class BaseChannel MediaContentDirection local_content_direction_ = MD_INACTIVE; MediaContentDirection remote_content_direction_ = MD_INACTIVE; CandidatePairInterface* selected_candidate_pair_; + + // The cached encrypted header extension IDs. + rtc::Optional> catched_send_extension_ids_; + rtc::Optional> catched_recv_extension_ids_; }; // VoiceChannel is a specialization that adds support for early media, DTMF, diff --git a/pc/srtpsession.cc b/pc/srtpsession.cc index 8fe8dc0f8f..a07848d475 100644 --- a/pc/srtpsession.cc +++ b/pc/srtpsession.cc @@ -32,20 +32,32 @@ SrtpSession::~SrtpSession() { } } -bool SrtpSession::SetSend(int cs, const uint8_t* key, size_t len) { - return SetKey(ssrc_any_outbound, cs, key, len); +bool SrtpSession::SetSend(int cs, + const uint8_t* key, + size_t len, + const std::vector& extension_ids) { + return SetKey(ssrc_any_outbound, cs, key, len, extension_ids); } -bool SrtpSession::UpdateSend(int cs, const uint8_t* key, size_t len) { - return UpdateKey(ssrc_any_outbound, cs, key, len); +bool SrtpSession::UpdateSend(int cs, + const uint8_t* key, + size_t len, + const std::vector& extension_ids) { + return UpdateKey(ssrc_any_outbound, cs, key, len, extension_ids); } -bool SrtpSession::SetRecv(int cs, const uint8_t* key, size_t len) { - return SetKey(ssrc_any_inbound, cs, key, len); +bool SrtpSession::SetRecv(int cs, + const uint8_t* key, + size_t len, + const std::vector& extension_ids) { + return SetKey(ssrc_any_inbound, cs, key, len, extension_ids); } -bool SrtpSession::UpdateRecv(int cs, const uint8_t* key, size_t len) { - return UpdateKey(ssrc_any_inbound, cs, key, len); +bool SrtpSession::UpdateRecv(int cs, + const uint8_t* key, + size_t len, + const std::vector& extension_ids) { + return UpdateKey(ssrc_any_inbound, cs, key, len, extension_ids); } bool SrtpSession::ProtectRtp(void* p, int in_len, int max_len, int* out_len) { @@ -203,7 +215,11 @@ bool SrtpSession::GetSendStreamPacketIndex(void* p, return true; } -bool SrtpSession::DoSetKey(int type, int cs, const uint8_t* key, size_t len) { +bool SrtpSession::DoSetKey(int type, + int cs, + const uint8_t* key, + size_t len, + const std::vector& extension_ids) { RTC_DCHECK(thread_checker_.CalledOnValidThread()); srtp_policy_t policy; @@ -262,10 +278,9 @@ bool SrtpSession::DoSetKey(int type, int cs, const uint8_t* key, size_t len) { !rtc::IsGcmCryptoSuite(cs)) { policy.rtp.auth_type = EXTERNAL_HMAC_SHA1; } - if (!encrypted_header_extension_ids_.empty()) { - policy.enc_xtn_hdr = const_cast(&encrypted_header_extension_ids_[0]); - policy.enc_xtn_hdr_count = - static_cast(encrypted_header_extension_ids_.size()); + if (!extension_ids.empty()) { + policy.enc_xtn_hdr = const_cast(&extension_ids[0]); + policy.enc_xtn_hdr_count = static_cast(extension_ids.size()); } policy.next = nullptr; @@ -291,7 +306,11 @@ bool SrtpSession::DoSetKey(int type, int cs, const uint8_t* key, size_t len) { return true; } -bool SrtpSession::SetKey(int type, int cs, const uint8_t* key, size_t len) { +bool SrtpSession::SetKey(int type, + int cs, + const uint8_t* key, + size_t len, + const std::vector& extension_ids) { RTC_DCHECK(thread_checker_.CalledOnValidThread()); if (session_) { RTC_LOG(LS_ERROR) << "Failed to create SRTP session: " @@ -307,23 +326,21 @@ bool SrtpSession::SetKey(int type, int cs, const uint8_t* key, size_t len) { return false; } - return DoSetKey(type, cs, key, len); + return DoSetKey(type, cs, key, len, extension_ids); } -bool SrtpSession::UpdateKey(int type, int cs, const uint8_t* key, size_t len) { +bool SrtpSession::UpdateKey(int type, + int cs, + const uint8_t* key, + size_t len, + const std::vector& extension_ids) { RTC_DCHECK(thread_checker_.CalledOnValidThread()); if (!session_) { RTC_LOG(LS_ERROR) << "Failed to update non-existing SRTP session"; return false; } - return DoSetKey(type, cs, key, len); -} - -void SrtpSession::SetEncryptedHeaderExtensionIds( - const std::vector& encrypted_header_extension_ids) { - RTC_DCHECK(thread_checker_.CalledOnValidThread()); - encrypted_header_extension_ids_ = encrypted_header_extension_ids; + return DoSetKey(type, cs, key, len, extension_ids); } int g_libsrtp_usage_count = 0; diff --git a/pc/srtpsession.h b/pc/srtpsession.h index 94702da130..a6e78fab6b 100644 --- a/pc/srtpsession.h +++ b/pc/srtpsession.h @@ -30,16 +30,25 @@ class SrtpSession { // Configures the session for sending data using the specified // cipher-suite and key. Receiving must be done by a separate session. - bool SetSend(int cs, const uint8_t* key, size_t len); - bool UpdateSend(int cs, const uint8_t* key, size_t len); + bool SetSend(int cs, + const uint8_t* key, + size_t len, + const std::vector& extension_ids); + bool UpdateSend(int cs, + const uint8_t* key, + size_t len, + const std::vector& extension_ids); // Configures the session for receiving data using the specified // cipher-suite and key. Sending must be done by a separate session. - bool SetRecv(int cs, const uint8_t* key, size_t len); - bool UpdateRecv(int cs, const uint8_t* key, size_t len); - - void SetEncryptedHeaderExtensionIds( - const std::vector& encrypted_header_extension_ids); + bool SetRecv(int cs, + const uint8_t* key, + size_t len, + const std::vector& extension_ids); + bool UpdateRecv(int cs, + const uint8_t* key, + size_t len, + const std::vector& extension_ids); // Encrypts/signs an individual RTP/RTCP packet, in-place. // If an HMAC is used, this will increase the packet size. @@ -75,12 +84,21 @@ class SrtpSession { bool IsExternalAuthActive() const; private: - bool DoSetKey(int type, int cs, const uint8_t* key, size_t len); - bool SetKey(int type, int cs, const uint8_t* key, size_t len); - bool UpdateKey(int type, int cs, const uint8_t* key, size_t len); - bool SetEncryptedHeaderExtensionIds( - int type, - const std::vector& encrypted_header_extension_ids); + bool DoSetKey(int type, + int cs, + const uint8_t* key, + size_t len, + const std::vector& extension_ids); + bool SetKey(int type, + int cs, + const uint8_t* key, + size_t len, + const std::vector& extension_ids); + bool UpdateKey(int type, + int cs, + const uint8_t* key, + size_t len, + const std::vector& extension_ids); // Returns send stream current packet index from srtp db. bool GetSendStreamPacketIndex(void* data, int in_len, int64_t* index); @@ -104,7 +122,6 @@ class SrtpSession { int last_send_seq_num_ = -1; bool external_auth_active_ = false; bool external_auth_enabled_ = false; - std::vector encrypted_header_extension_ids_; RTC_DISALLOW_COPY_AND_ASSIGN(SrtpSession); }; diff --git a/pc/srtpsession_unittest.cc b/pc/srtpsession_unittest.cc index b89b3ad55a..dc325739e8 100644 --- a/pc/srtpsession_unittest.cc +++ b/pc/srtpsession_unittest.cc @@ -19,6 +19,8 @@ namespace rtc { +std::vector kEncryptedHeaderExtensionIds; + class SrtpSessionTest : public testing::Test { protected: virtual void SetUp() { @@ -65,28 +67,38 @@ class SrtpSessionTest : public testing::Test { // Test that we can set up the session and keys properly. TEST_F(SrtpSessionTest, TestGoodSetup) { - EXPECT_TRUE(s1_.SetSend(SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen)); - EXPECT_TRUE(s2_.SetRecv(SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen)); + EXPECT_TRUE(s1_.SetSend(SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen, + kEncryptedHeaderExtensionIds)); + EXPECT_TRUE(s2_.SetRecv(SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen, + kEncryptedHeaderExtensionIds)); } // Test that we can't change the keys once set. TEST_F(SrtpSessionTest, TestBadSetup) { - EXPECT_TRUE(s1_.SetSend(SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen)); - EXPECT_TRUE(s2_.SetRecv(SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen)); - EXPECT_FALSE(s1_.SetSend(SRTP_AES128_CM_SHA1_80, kTestKey2, kTestKeyLen)); - EXPECT_FALSE(s2_.SetRecv(SRTP_AES128_CM_SHA1_80, kTestKey2, kTestKeyLen)); + EXPECT_TRUE(s1_.SetSend(SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen, + kEncryptedHeaderExtensionIds)); + EXPECT_TRUE(s2_.SetRecv(SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen, + kEncryptedHeaderExtensionIds)); + EXPECT_FALSE(s1_.SetSend(SRTP_AES128_CM_SHA1_80, kTestKey2, kTestKeyLen, + kEncryptedHeaderExtensionIds)); + EXPECT_FALSE(s2_.SetRecv(SRTP_AES128_CM_SHA1_80, kTestKey2, kTestKeyLen, + kEncryptedHeaderExtensionIds)); } // Test that we fail keys of the wrong length. TEST_F(SrtpSessionTest, TestKeysTooShort) { - EXPECT_FALSE(s1_.SetSend(SRTP_AES128_CM_SHA1_80, kTestKey1, 1)); - EXPECT_FALSE(s2_.SetRecv(SRTP_AES128_CM_SHA1_80, kTestKey1, 1)); + EXPECT_FALSE(s1_.SetSend(SRTP_AES128_CM_SHA1_80, kTestKey1, 1, + kEncryptedHeaderExtensionIds)); + EXPECT_FALSE(s2_.SetRecv(SRTP_AES128_CM_SHA1_80, kTestKey1, 1, + kEncryptedHeaderExtensionIds)); } // Test that we can encrypt and decrypt RTP/RTCP using AES_CM_128_HMAC_SHA1_80. TEST_F(SrtpSessionTest, TestProtect_AES_CM_128_HMAC_SHA1_80) { - EXPECT_TRUE(s1_.SetSend(SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen)); - EXPECT_TRUE(s2_.SetRecv(SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen)); + EXPECT_TRUE(s1_.SetSend(SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen, + kEncryptedHeaderExtensionIds)); + EXPECT_TRUE(s2_.SetRecv(SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen, + kEncryptedHeaderExtensionIds)); TestProtectRtp(CS_AES_CM_128_HMAC_SHA1_80); TestProtectRtcp(CS_AES_CM_128_HMAC_SHA1_80); TestUnprotectRtp(CS_AES_CM_128_HMAC_SHA1_80); @@ -95,8 +107,10 @@ TEST_F(SrtpSessionTest, TestProtect_AES_CM_128_HMAC_SHA1_80) { // Test that we can encrypt and decrypt RTP/RTCP using AES_CM_128_HMAC_SHA1_32. TEST_F(SrtpSessionTest, TestProtect_AES_CM_128_HMAC_SHA1_32) { - EXPECT_TRUE(s1_.SetSend(SRTP_AES128_CM_SHA1_32, kTestKey1, kTestKeyLen)); - EXPECT_TRUE(s2_.SetRecv(SRTP_AES128_CM_SHA1_32, kTestKey1, kTestKeyLen)); + EXPECT_TRUE(s1_.SetSend(SRTP_AES128_CM_SHA1_32, kTestKey1, kTestKeyLen, + kEncryptedHeaderExtensionIds)); + EXPECT_TRUE(s2_.SetRecv(SRTP_AES128_CM_SHA1_32, kTestKey1, kTestKeyLen, + kEncryptedHeaderExtensionIds)); TestProtectRtp(CS_AES_CM_128_HMAC_SHA1_32); TestProtectRtcp(CS_AES_CM_128_HMAC_SHA1_32); TestUnprotectRtp(CS_AES_CM_128_HMAC_SHA1_32); @@ -104,7 +118,8 @@ TEST_F(SrtpSessionTest, TestProtect_AES_CM_128_HMAC_SHA1_32) { } TEST_F(SrtpSessionTest, TestGetSendStreamPacketIndex) { - EXPECT_TRUE(s1_.SetSend(SRTP_AES128_CM_SHA1_32, kTestKey1, kTestKeyLen)); + EXPECT_TRUE(s1_.SetSend(SRTP_AES128_CM_SHA1_32, kTestKey1, kTestKeyLen, + kEncryptedHeaderExtensionIds)); int64_t index; int out_len = 0; EXPECT_TRUE(s1_.ProtectRtp(rtp_packet_, rtp_len_, sizeof(rtp_packet_), @@ -117,8 +132,10 @@ TEST_F(SrtpSessionTest, TestGetSendStreamPacketIndex) { // Test that we fail to unprotect if someone tampers with the RTP/RTCP paylaods. TEST_F(SrtpSessionTest, TestTamperReject) { int out_len; - EXPECT_TRUE(s1_.SetSend(SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen)); - EXPECT_TRUE(s2_.SetRecv(SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen)); + EXPECT_TRUE(s1_.SetSend(SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen, + kEncryptedHeaderExtensionIds)); + EXPECT_TRUE(s2_.SetRecv(SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen, + kEncryptedHeaderExtensionIds)); TestProtectRtp(CS_AES_CM_128_HMAC_SHA1_80); TestProtectRtcp(CS_AES_CM_128_HMAC_SHA1_80); rtp_packet_[0] = 0x12; @@ -130,8 +147,10 @@ TEST_F(SrtpSessionTest, TestTamperReject) { // Test that we fail to unprotect if the payloads are not authenticated. TEST_F(SrtpSessionTest, TestUnencryptReject) { int out_len; - EXPECT_TRUE(s1_.SetSend(SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen)); - EXPECT_TRUE(s2_.SetRecv(SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen)); + EXPECT_TRUE(s1_.SetSend(SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen, + kEncryptedHeaderExtensionIds)); + EXPECT_TRUE(s2_.SetRecv(SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen, + kEncryptedHeaderExtensionIds)); EXPECT_FALSE(s2_.UnprotectRtp(rtp_packet_, rtp_len_, &out_len)); EXPECT_FALSE(s2_.UnprotectRtcp(rtcp_packet_, rtcp_len_, &out_len)); } @@ -139,7 +158,8 @@ TEST_F(SrtpSessionTest, TestUnencryptReject) { // Test that we fail when using buffers that are too small. TEST_F(SrtpSessionTest, TestBuffersTooSmall) { int out_len; - EXPECT_TRUE(s1_.SetSend(SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen)); + EXPECT_TRUE(s1_.SetSend(SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen, + kEncryptedHeaderExtensionIds)); EXPECT_FALSE(s1_.ProtectRtp(rtp_packet_, rtp_len_, sizeof(rtp_packet_) - 10, &out_len)); EXPECT_FALSE(s1_.ProtectRtcp(rtcp_packet_, rtcp_len_, @@ -153,8 +173,10 @@ TEST_F(SrtpSessionTest, TestReplay) { static const uint16_t replay_window = 1024; int out_len; - EXPECT_TRUE(s1_.SetSend(SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen)); - EXPECT_TRUE(s2_.SetRecv(SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen)); + EXPECT_TRUE(s1_.SetSend(SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen, + kEncryptedHeaderExtensionIds)); + EXPECT_TRUE(s2_.SetRecv(SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen, + kEncryptedHeaderExtensionIds)); // Initial sequence number. SetBE16(reinterpret_cast(rtp_packet_) + 2, seqnum_big); diff --git a/pc/srtptransport.cc b/pc/srtptransport.cc index b71276cac1..1343fd0cac 100644 --- a/pc/srtptransport.cc +++ b/pc/srtptransport.cc @@ -173,9 +173,11 @@ void SrtpTransport::OnPacketReceived(bool rtcp, bool SrtpTransport::SetRtpParams(int send_cs, const uint8_t* send_key, int send_key_len, + const std::vector& send_extension_ids, int recv_cs, const uint8_t* recv_key, - int recv_key_len) { + int recv_key_len, + const std::vector& recv_extension_ids) { // If parameters are being set for the first time, we should create new SRTP // sessions and call "SetSend/SetRecv". Otherwise we should call // "UpdateSend"/"UpdateRecv" on the existing sessions, which will internally @@ -186,21 +188,20 @@ bool SrtpTransport::SetRtpParams(int send_cs, CreateSrtpSessions(); new_sessions = true; } - send_session_->SetEncryptedHeaderExtensionIds( - send_encrypted_header_extension_ids_); bool ret = new_sessions - ? send_session_->SetSend(send_cs, send_key, send_key_len) - : send_session_->UpdateSend(send_cs, send_key, send_key_len); + ? send_session_->SetSend(send_cs, send_key, send_key_len, + send_extension_ids) + : send_session_->UpdateSend(send_cs, send_key, send_key_len, + send_extension_ids); if (!ret) { ResetParams(); return false; } - recv_session_->SetEncryptedHeaderExtensionIds( - recv_encrypted_header_extension_ids_); - ret = new_sessions - ? recv_session_->SetRecv(recv_cs, recv_key, recv_key_len) - : recv_session_->UpdateRecv(recv_cs, recv_key, recv_key_len); + ret = new_sessions ? recv_session_->SetRecv(recv_cs, recv_key, recv_key_len, + recv_extension_ids) + : recv_session_->UpdateRecv( + recv_cs, recv_key, recv_key_len, recv_extension_ids); if (!ret) { ResetParams(); return false; @@ -216,9 +217,11 @@ bool SrtpTransport::SetRtpParams(int send_cs, bool SrtpTransport::SetRtcpParams(int send_cs, const uint8_t* send_key, int send_key_len, + const std::vector& send_extension_ids, int recv_cs, const uint8_t* recv_key, - int recv_key_len) { + int recv_key_len, + const std::vector& recv_extension_ids) { // This can only be called once, but can be safely called after // SetRtpParams if (send_rtcp_session_ || recv_rtcp_session_) { @@ -227,12 +230,14 @@ bool SrtpTransport::SetRtcpParams(int send_cs, } send_rtcp_session_.reset(new cricket::SrtpSession()); - if (!send_rtcp_session_->SetSend(send_cs, send_key, send_key_len)) { + if (!send_rtcp_session_->SetSend(send_cs, send_key, send_key_len, + send_extension_ids)) { return false; } recv_rtcp_session_.reset(new cricket::SrtpSession()); - if (!recv_rtcp_session_->SetRecv(recv_cs, recv_key, recv_key_len)) { + if (!recv_rtcp_session_->SetRecv(recv_cs, recv_key, recv_key_len, + recv_extension_ids)) { return false; } @@ -255,16 +260,6 @@ void SrtpTransport::ResetParams() { RTC_LOG(LS_INFO) << "The params in SRTP transport are reset."; } -void SrtpTransport::SetEncryptedHeaderExtensionIds( - cricket::ContentSource source, - const std::vector& extension_ids) { - if (source == cricket::CS_LOCAL) { - recv_encrypted_header_extension_ids_ = extension_ids; - } else { - send_encrypted_header_extension_ids_ = extension_ids; - } -} - void SrtpTransport::CreateSrtpSessions() { send_session_.reset(new cricket::SrtpSession()); recv_session_.reset(new cricket::SrtpSession()); diff --git a/pc/srtptransport.h b/pc/srtptransport.h index 03c353c530..13abd6b47d 100644 --- a/pc/srtptransport.h +++ b/pc/srtptransport.h @@ -100,9 +100,11 @@ class SrtpTransport : public RtpTransportInternal { bool SetRtpParams(int send_cs, const uint8_t* send_key, int send_key_len, + const std::vector& send_extension_ids, int recv_cs, const uint8_t* recv_key, - int recv_key_len); + int recv_key_len, + const std::vector& recv_extension_ids); // Create new send/recv sessions and set the negotiated crypto keys for RTCP // packet encryption. The keys can either come from SDES negotiation or DTLS @@ -110,18 +112,14 @@ class SrtpTransport : public RtpTransportInternal { bool SetRtcpParams(int send_cs, const uint8_t* send_key, int send_key_len, + const std::vector& send_extension_ids, int recv_cs, const uint8_t* recv_key, - int recv_key_len); + int recv_key_len, + const std::vector& recv_extension_ids); void ResetParams(); - // Set the header extension ids that should be encrypted for the given source. - // This method doesn't immediately update the SRTP session with the new IDs, - // and you need to call SetRtpParams for that to happen. - void SetEncryptedHeaderExtensionIds(cricket::ContentSource source, - const std::vector& extension_ids); - // If external auth is enabled, SRTP will write a dummy auth tag that then // later must get replaced before the packet is sent out. Only supported for // non-GCM cipher suites and can be checked through "IsExternalAuthActive" @@ -187,8 +185,6 @@ class SrtpTransport : public RtpTransportInternal { std::unique_ptr send_rtcp_session_; std::unique_ptr recv_rtcp_session_; - std::vector send_encrypted_header_extension_ids_; - std::vector recv_encrypted_header_extension_ids_; bool external_auth_enabled_ = false; int rtp_abs_sendtime_extn_id_ = -1; diff --git a/pc/srtptransport_unittest.cc b/pc/srtptransport_unittest.cc index 35a792ddb3..3533863852 100644 --- a/pc/srtptransport_unittest.cc +++ b/pc/srtptransport_unittest.cc @@ -220,14 +220,15 @@ class SrtpTransportTest : public testing::Test, public sigslot::has_slots<> { srtp_transport1_->EnableExternalAuth(); srtp_transport2_->EnableExternalAuth(); } - EXPECT_TRUE( - srtp_transport1_->SetRtpParams(cs, key1, key1_len, cs, key2, key2_len)); - EXPECT_TRUE( - srtp_transport2_->SetRtpParams(cs, key2, key2_len, cs, key1, key1_len)); - EXPECT_TRUE(srtp_transport1_->SetRtcpParams(cs, key1, key1_len, cs, key2, - key2_len)); - EXPECT_TRUE(srtp_transport2_->SetRtcpParams(cs, key2, key2_len, cs, key1, - key1_len)); + std::vector extension_ids; + EXPECT_TRUE(srtp_transport1_->SetRtpParams( + cs, key1, key1_len, extension_ids, cs, key2, key2_len, extension_ids)); + EXPECT_TRUE(srtp_transport2_->SetRtpParams( + cs, key2, key2_len, extension_ids, cs, key1, key1_len, extension_ids)); + EXPECT_TRUE(srtp_transport1_->SetRtcpParams( + cs, key1, key1_len, extension_ids, cs, key2, key2_len, extension_ids)); + EXPECT_TRUE(srtp_transport2_->SetRtcpParams( + cs, key2, key2_len, extension_ids, cs, key1, key1_len, extension_ids)); EXPECT_TRUE(srtp_transport1_->IsActive()); EXPECT_TRUE(srtp_transport2_->IsActive()); if (rtc::IsGcmCryptoSuite(cs)) { @@ -308,18 +309,12 @@ class SrtpTransportTest : public testing::Test, public sigslot::has_slots<> { encrypted_headers.push_back(4); EXPECT_EQ(key1_len, key2_len); EXPECT_EQ(cs_name, rtc::SrtpCryptoSuiteToName(cs)); - srtp_transport1_->SetEncryptedHeaderExtensionIds(cricket::CS_LOCAL, - encrypted_headers); - srtp_transport1_->SetEncryptedHeaderExtensionIds(cricket::CS_REMOTE, - encrypted_headers); - srtp_transport2_->SetEncryptedHeaderExtensionIds(cricket::CS_LOCAL, - encrypted_headers); - srtp_transport2_->SetEncryptedHeaderExtensionIds(cricket::CS_REMOTE, - encrypted_headers); - EXPECT_TRUE( - srtp_transport1_->SetRtpParams(cs, key1, key1_len, cs, key2, key2_len)); - EXPECT_TRUE( - srtp_transport2_->SetRtpParams(cs, key2, key2_len, cs, key1, key1_len)); + EXPECT_TRUE(srtp_transport1_->SetRtpParams(cs, key1, key1_len, + encrypted_headers, cs, key2, + key2_len, encrypted_headers)); + EXPECT_TRUE(srtp_transport2_->SetRtpParams(cs, key2, key2_len, + encrypted_headers, cs, key1, + key1_len, encrypted_headers)); EXPECT_TRUE(srtp_transport1_->IsActive()); EXPECT_TRUE(srtp_transport2_->IsActive()); EXPECT_FALSE(srtp_transport1_->IsExternalAuthActive()); @@ -409,12 +404,13 @@ INSTANTIATE_TEST_CASE_P(ExternalAuth, // Test directly setting the params with bogus keys. TEST_F(SrtpTransportTest, TestSetParamsKeyTooShort) { + std::vector extension_ids; EXPECT_FALSE(srtp_transport1_->SetRtpParams( - rtc::SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen - 1, - rtc::SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen - 1)); + rtc::SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen - 1, extension_ids, + rtc::SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen - 1, extension_ids)); EXPECT_FALSE(srtp_transport1_->SetRtcpParams( - rtc::SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen - 1, - rtc::SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen - 1)); + rtc::SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen - 1, extension_ids, + rtc::SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen - 1, extension_ids)); } } // namespace webrtc