From 4dde3df3b5ac205b77dab67ad28802ae1ed834ba Mon Sep 17 00:00:00 2001 From: zstein Date: Fri, 7 Jul 2017 14:26:25 -0700 Subject: [PATCH] Move SrtpSession and tests to their own files. BUG=None Review-Url: https://codereview.webrtc.org/2976443002 Cr-Commit-Position: refs/heads/master@{#18935} --- webrtc/pc/BUILD.gn | 4 + webrtc/pc/DEPS | 2 +- webrtc/pc/srtpfilter.cc | 395 +---------------------------- webrtc/pc/srtpfilter.h | 81 ------ webrtc/pc/srtpfilter_unittest.cc | 291 ++++----------------- webrtc/pc/srtpsession.cc | 408 ++++++++++++++++++++++++++++++ webrtc/pc/srtpsession.h | 109 ++++++++ webrtc/pc/srtpsession_unittest.cc | 204 +++++++++++++++ webrtc/pc/srtptestutil.h | 45 ++++ 9 files changed, 815 insertions(+), 724 deletions(-) create mode 100644 webrtc/pc/srtpsession.cc create mode 100644 webrtc/pc/srtpsession.h create mode 100644 webrtc/pc/srtpsession_unittest.cc create mode 100644 webrtc/pc/srtptestutil.h diff --git a/webrtc/pc/BUILD.gn b/webrtc/pc/BUILD.gn index 3bbaa57e19..2ff1a0af50 100644 --- a/webrtc/pc/BUILD.gn +++ b/webrtc/pc/BUILD.gn @@ -50,6 +50,8 @@ rtc_static_library("rtc_pc_base") { "rtptransport.h", "srtpfilter.cc", "srtpfilter.h", + "srtpsession.cc", + "srtpsession.h", "voicechannel.h", ] @@ -257,6 +259,8 @@ if (rtc_include_tests) { "rtcpmuxfilter_unittest.cc", "rtptransport_unittest.cc", "srtpfilter_unittest.cc", + "srtpsession_unittest.cc", + "srtptestutil.h", ] include_dirs = [ "//third_party/libsrtp/srtp" ] diff --git a/webrtc/pc/DEPS b/webrtc/pc/DEPS index d77d279769..ee87af146e 100644 --- a/webrtc/pc/DEPS +++ b/webrtc/pc/DEPS @@ -1,5 +1,5 @@ include_rules = [ - "+third_party/libsrtp" + "+third_party/libsrtp", "+webrtc/api", "+webrtc/base", "+webrtc/call", diff --git a/webrtc/pc/srtpfilter.cc b/webrtc/pc/srtpfilter.cc index a7634757d6..dde84bc14e 100644 --- a/webrtc/pc/srtpfilter.cc +++ b/webrtc/pc/srtpfilter.cc @@ -14,16 +14,13 @@ #include -#include "third_party/libsrtp/include/srtp.h" -#include "third_party/libsrtp/include/srtp_priv.h" #include "webrtc/media/base/rtputils.h" -#include "webrtc/pc/externalhmac.h" +#include "webrtc/pc/srtpsession.h" #include "webrtc/rtc_base/base64.h" #include "webrtc/rtc_base/buffer.h" #include "webrtc/rtc_base/byteorder.h" #include "webrtc/rtc_base/checks.h" #include "webrtc/rtc_base/logging.h" -#include "webrtc/rtc_base/sslstreamadapter.h" #include "webrtc/rtc_base/stringencode.h" #include "webrtc/rtc_base/timeutils.h" @@ -484,394 +481,4 @@ bool SrtpFilter::ParseKeyParams(const std::string& key_params, return true; } -/////////////////////////////////////////////////////////////////////////////// -// SrtpSession - -bool SrtpSession::inited_ = false; - -// This lock protects SrtpSession::inited_. -rtc::GlobalLockPod SrtpSession::lock_; - -SrtpSession::SrtpSession() {} - -SrtpSession::~SrtpSession() { - if (session_) { - srtp_set_user_data(session_, nullptr); - srtp_dealloc(session_); - } -} - -bool SrtpSession::SetSend(int cs, const uint8_t* key, size_t len) { - return SetKey(ssrc_any_outbound, cs, key, len); -} - -bool SrtpSession::UpdateSend(int cs, const uint8_t* key, size_t len) { - return UpdateKey(ssrc_any_outbound, cs, key, len); -} - -bool SrtpSession::SetRecv(int cs, const uint8_t* key, size_t len) { - return SetKey(ssrc_any_inbound, cs, key, len); -} - -bool SrtpSession::UpdateRecv(int cs, const uint8_t* key, size_t len) { - return UpdateKey(ssrc_any_inbound, cs, key, len); -} - -bool SrtpSession::ProtectRtp(void* p, int in_len, int max_len, int* out_len) { - RTC_DCHECK(thread_checker_.CalledOnValidThread()); - if (!session_) { - LOG(LS_WARNING) << "Failed to protect SRTP packet: no SRTP Session"; - return false; - } - - int need_len = in_len + rtp_auth_tag_len_; // NOLINT - if (max_len < need_len) { - LOG(LS_WARNING) << "Failed to protect SRTP packet: The buffer length " - << max_len << " is less than the needed " << need_len; - return false; - } - - *out_len = in_len; - int err = srtp_protect(session_, p, out_len); - int seq_num; - GetRtpSeqNum(p, in_len, &seq_num); - if (err != srtp_err_status_ok) { - LOG(LS_WARNING) << "Failed to protect SRTP packet, seqnum=" - << seq_num << ", err=" << err << ", last seqnum=" - << last_send_seq_num_; - return false; - } - last_send_seq_num_ = seq_num; - return true; -} - -bool SrtpSession::ProtectRtp(void* p, - int in_len, - int max_len, - int* out_len, - int64_t* index) { - if (!ProtectRtp(p, in_len, max_len, out_len)) { - return false; - } - return (index) ? GetSendStreamPacketIndex(p, in_len, index) : true; -} - -bool SrtpSession::ProtectRtcp(void* p, int in_len, int max_len, int* out_len) { - RTC_DCHECK(thread_checker_.CalledOnValidThread()); - if (!session_) { - LOG(LS_WARNING) << "Failed to protect SRTCP packet: no SRTP Session"; - return false; - } - - int need_len = in_len + sizeof(uint32_t) + rtcp_auth_tag_len_; // NOLINT - if (max_len < need_len) { - LOG(LS_WARNING) << "Failed to protect SRTCP packet: The buffer length " - << max_len << " is less than the needed " << need_len; - return false; - } - - *out_len = in_len; - int err = srtp_protect_rtcp(session_, p, out_len); - if (err != srtp_err_status_ok) { - LOG(LS_WARNING) << "Failed to protect SRTCP packet, err=" << err; - return false; - } - return true; -} - -bool SrtpSession::UnprotectRtp(void* p, int in_len, int* out_len) { - RTC_DCHECK(thread_checker_.CalledOnValidThread()); - if (!session_) { - LOG(LS_WARNING) << "Failed to unprotect SRTP packet: no SRTP Session"; - return false; - } - - *out_len = in_len; - int err = srtp_unprotect(session_, p, out_len); - if (err != srtp_err_status_ok) { - LOG(LS_WARNING) << "Failed to unprotect SRTP packet, err=" << err; - return false; - } - return true; -} - -bool SrtpSession::UnprotectRtcp(void* p, int in_len, int* out_len) { - RTC_DCHECK(thread_checker_.CalledOnValidThread()); - if (!session_) { - LOG(LS_WARNING) << "Failed to unprotect SRTCP packet: no SRTP Session"; - return false; - } - - *out_len = in_len; - int err = srtp_unprotect_rtcp(session_, p, out_len); - if (err != srtp_err_status_ok) { - LOG(LS_WARNING) << "Failed to unprotect SRTCP packet, err=" << err; - return false; - } - return true; -} - -bool SrtpSession::GetRtpAuthParams(uint8_t** key, int* key_len, int* tag_len) { - RTC_DCHECK(thread_checker_.CalledOnValidThread()); - RTC_DCHECK(IsExternalAuthActive()); - if (!IsExternalAuthActive()) { - return false; - } - - ExternalHmacContext* external_hmac = nullptr; - // stream_template will be the reference context for other streams. - // Let's use it for getting the keys. - srtp_stream_ctx_t* srtp_context = session_->stream_template; -#if defined(SRTP_MAX_MKI_LEN) - // libsrtp 2.1.0 - if (srtp_context && srtp_context->session_keys && - srtp_context->session_keys->rtp_auth) { - external_hmac = reinterpret_cast( - srtp_context->session_keys->rtp_auth->state); - } -#else - // libsrtp 2.0.0 - // TODO(jbauch): Remove after switching to libsrtp 2.1.0 - if (srtp_context && srtp_context->rtp_auth) { - external_hmac = reinterpret_cast( - srtp_context->rtp_auth->state); - } -#endif - - if (!external_hmac) { - LOG(LS_ERROR) << "Failed to get auth keys from libsrtp!."; - return false; - } - - *key = external_hmac->key; - *key_len = external_hmac->key_length; - *tag_len = rtp_auth_tag_len_; - return true; -} - -int SrtpSession::GetSrtpOverhead() const { - return rtp_auth_tag_len_; -} - -void SrtpSession::EnableExternalAuth() { - RTC_DCHECK(!session_); - external_auth_enabled_ = true; -} - -bool SrtpSession::IsExternalAuthEnabled() const { - return external_auth_enabled_; -} - -bool SrtpSession::IsExternalAuthActive() const { - return external_auth_active_; -} - -bool SrtpSession::GetSendStreamPacketIndex(void* p, - int in_len, - int64_t* index) { - RTC_DCHECK(thread_checker_.CalledOnValidThread()); - srtp_hdr_t* hdr = reinterpret_cast(p); - srtp_stream_ctx_t* stream = srtp_get_stream(session_, hdr->ssrc); - if (!stream) { - return false; - } - - // Shift packet index, put into network byte order - *index = static_cast( - rtc::NetworkToHost64( - srtp_rdbx_get_packet_index(&stream->rtp_rdbx) << 16)); - return true; -} - - -bool SrtpSession::DoSetKey(int type, int cs, const uint8_t* key, size_t len) { - RTC_DCHECK(thread_checker_.CalledOnValidThread()); - - srtp_policy_t policy; - memset(&policy, 0, sizeof(policy)); - if (cs == rtc::SRTP_AES128_CM_SHA1_80) { - srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&policy.rtp); - srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&policy.rtcp); - } else if (cs == rtc::SRTP_AES128_CM_SHA1_32) { - // RTP HMAC is shortened to 32 bits, but RTCP remains 80 bits. - srtp_crypto_policy_set_aes_cm_128_hmac_sha1_32(&policy.rtp); - srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&policy.rtcp); - } else if (cs == rtc::SRTP_AEAD_AES_128_GCM) { - srtp_crypto_policy_set_aes_gcm_128_16_auth(&policy.rtp); - srtp_crypto_policy_set_aes_gcm_128_16_auth(&policy.rtcp); - } else if (cs == rtc::SRTP_AEAD_AES_256_GCM) { - srtp_crypto_policy_set_aes_gcm_256_16_auth(&policy.rtp); - srtp_crypto_policy_set_aes_gcm_256_16_auth(&policy.rtcp); - } else { - LOG(LS_WARNING) << "Failed to " << (session_ ? "update" : "create") - << " SRTP session: unsupported cipher_suite " << cs; - return false; - } - - int expected_key_len; - int expected_salt_len; - if (!rtc::GetSrtpKeyAndSaltLengths(cs, &expected_key_len, - &expected_salt_len)) { - // This should never happen. - LOG(LS_WARNING) << "Failed to " << (session_ ? "update" : "create") - << " SRTP session: unsupported cipher_suite without length information" - << cs; - return false; - } - - if (!key || - len != static_cast(expected_key_len + expected_salt_len)) { - LOG(LS_WARNING) << "Failed to " << (session_ ? "update" : "create") - << " SRTP session: invalid key"; - return false; - } - - policy.ssrc.type = static_cast(type); - policy.ssrc.value = 0; - policy.key = const_cast(key); - // TODO(astor) parse window size from WSH session-param - policy.window_size = 1024; - policy.allow_repeat_tx = 1; - // If external authentication option is enabled, supply custom auth module - // id EXTERNAL_HMAC_SHA1 in the policy structure. - // We want to set this option only for rtp packets. - // By default policy structure is initialized to HMAC_SHA1. - // Enable external HMAC authentication only for outgoing streams and only - // for cipher suites that support it (i.e. only non-GCM cipher suites). - if (type == ssrc_any_outbound && IsExternalAuthEnabled() && - !rtc::IsGcmCryptoSuite(cs)) { - policy.rtp.auth_type = EXTERNAL_HMAC_SHA1; - } - if (!encrypted_header_extension_ids_.empty()) { - policy.enc_xtn_hdr = const_cast(&encrypted_header_extension_ids_[0]); - policy.enc_xtn_hdr_count = - static_cast(encrypted_header_extension_ids_.size()); - } - policy.next = nullptr; - - if (!session_) { - int err = srtp_create(&session_, &policy); - if (err != srtp_err_status_ok) { - session_ = nullptr; - LOG(LS_ERROR) << "Failed to create SRTP session, err=" << err; - return false; - } - srtp_set_user_data(session_, this); - } else { - int err = srtp_update(session_, &policy); - if (err != srtp_err_status_ok) { - LOG(LS_ERROR) << "Failed to update SRTP session, err=" << err; - return false; - } - } - - rtp_auth_tag_len_ = policy.rtp.auth_tag_len; - rtcp_auth_tag_len_ = policy.rtcp.auth_tag_len; - external_auth_active_ = (policy.rtp.auth_type == EXTERNAL_HMAC_SHA1); - return true; -} - -bool SrtpSession::SetKey(int type, int cs, const uint8_t* key, size_t len) { - RTC_DCHECK(thread_checker_.CalledOnValidThread()); - if (session_) { - LOG(LS_ERROR) << "Failed to create SRTP session: " - << "SRTP session already created"; - return false; - } - - if (!Init()) { - return false; - } - - return DoSetKey(type, cs, key, len); -} - -bool SrtpSession::UpdateKey(int type, int cs, const uint8_t* key, size_t len) { - RTC_DCHECK(thread_checker_.CalledOnValidThread()); - if (!session_) { - LOG(LS_ERROR) << "Failed to update non-existing SRTP session"; - return false; - } - - return DoSetKey(type, cs, key, len); -} - -void SrtpSession::SetEncryptedHeaderExtensionIds( - const std::vector& encrypted_header_extension_ids) { - RTC_DCHECK(thread_checker_.CalledOnValidThread()); - encrypted_header_extension_ids_ = encrypted_header_extension_ids; -} - -bool SrtpSession::Init() { - rtc::GlobalLockScope ls(&lock_); - - if (!inited_) { - int err; - err = srtp_init(); - if (err != srtp_err_status_ok) { - LOG(LS_ERROR) << "Failed to init SRTP, err=" << err; - return false; - } - - err = srtp_install_event_handler(&SrtpSession::HandleEventThunk); - if (err != srtp_err_status_ok) { - LOG(LS_ERROR) << "Failed to install SRTP event handler, err=" << err; - return false; - } - - err = external_crypto_init(); - if (err != srtp_err_status_ok) { - LOG(LS_ERROR) << "Failed to initialize fake auth, err=" << err; - return false; - } - inited_ = true; - } - - return true; -} - -void SrtpSession::Terminate() { - rtc::GlobalLockScope ls(&lock_); - - if (inited_) { - int err = srtp_shutdown(); - if (err) { - LOG(LS_ERROR) << "srtp_shutdown failed. err=" << err; - return; - } - inited_ = false; - } -} - -void SrtpSession::HandleEvent(const srtp_event_data_t* ev) { - RTC_DCHECK(thread_checker_.CalledOnValidThread()); - switch (ev->event) { - case event_ssrc_collision: - LOG(LS_INFO) << "SRTP event: SSRC collision"; - break; - case event_key_soft_limit: - LOG(LS_INFO) << "SRTP event: reached soft key usage limit"; - break; - case event_key_hard_limit: - LOG(LS_INFO) << "SRTP event: reached hard key usage limit"; - break; - case event_packet_index_limit: - LOG(LS_INFO) << "SRTP event: reached hard packet limit (2^48 packets)"; - break; - default: - LOG(LS_INFO) << "SRTP event: unknown " << ev->event; - break; - } -} - -void SrtpSession::HandleEventThunk(srtp_event_data_t* ev) { - // Callback will be executed from same thread that calls the "srtp_protect" - // and "srtp_unprotect" functions. - SrtpSession* session = static_cast( - srtp_get_user_data(ev->session)); - if (session) { - session->HandleEvent(ev); - } -} - } // namespace cricket diff --git a/webrtc/pc/srtpfilter.h b/webrtc/pc/srtpfilter.h index 97ae26b00d..15fdae9582 100644 --- a/webrtc/pc/srtpfilter.h +++ b/webrtc/pc/srtpfilter.h @@ -191,87 +191,6 @@ class SrtpFilter { std::vector recv_encrypted_header_extension_ids_; }; -// Class that wraps a libSRTP session. -class SrtpSession { - public: - SrtpSession(); - ~SrtpSession(); - - // Configures the session for sending data using the specified - // cipher-suite and key. Receiving must be done by a separate session. - bool SetSend(int cs, const uint8_t* key, size_t len); - bool UpdateSend(int cs, const uint8_t* key, size_t len); - - // Configures the session for receiving data using the specified - // cipher-suite and key. Sending must be done by a separate session. - bool SetRecv(int cs, const uint8_t* key, size_t len); - bool UpdateRecv(int cs, const uint8_t* key, size_t len); - - void SetEncryptedHeaderExtensionIds( - const std::vector& encrypted_header_extension_ids); - - // Encrypts/signs an individual RTP/RTCP packet, in-place. - // If an HMAC is used, this will increase the packet size. - bool ProtectRtp(void* data, int in_len, int max_len, int* out_len); - // Overloaded version, outputs packet index. - bool ProtectRtp(void* data, - int in_len, - int max_len, - int* out_len, - int64_t* index); - bool ProtectRtcp(void* data, int in_len, int max_len, int* out_len); - // Decrypts/verifies an invidiual RTP/RTCP packet. - // If an HMAC is used, this will decrease the packet size. - bool UnprotectRtp(void* data, int in_len, int* out_len); - bool UnprotectRtcp(void* data, int in_len, int* out_len); - - // Helper method to get authentication params. - bool GetRtpAuthParams(uint8_t** key, int* key_len, int* tag_len); - - int GetSrtpOverhead() const; - - // If external auth is enabled, SRTP will write a dummy auth tag that then - // later must get replaced before the packet is sent out. Only supported for - // non-GCM cipher suites and can be checked through "IsExternalAuthActive" - // if it is actually used. This method is only valid before the RTP params - // have been set. - void EnableExternalAuth(); - bool IsExternalAuthEnabled() const; - - // A SRTP session supports external creation of the auth tag if a non-GCM - // cipher is used. This method is only valid after the RTP params have - // been set. - bool IsExternalAuthActive() const; - - // Calls srtp_shutdown if it's initialized. - static void Terminate(); - - private: - bool DoSetKey(int type, int cs, const uint8_t* key, size_t len); - bool SetKey(int type, int cs, const uint8_t* key, size_t len); - bool UpdateKey(int type, int cs, const uint8_t* key, size_t len); - bool SetEncryptedHeaderExtensionIds(int type, - const std::vector& encrypted_header_extension_ids); - // Returns send stream current packet index from srtp db. - bool GetSendStreamPacketIndex(void* data, int in_len, int64_t* index); - - static bool Init(); - void HandleEvent(const srtp_event_data_t* ev); - static void HandleEventThunk(srtp_event_data_t* ev); - - rtc::ThreadChecker thread_checker_; - srtp_ctx_t_* session_ = nullptr; - int rtp_auth_tag_len_ = 0; - int rtcp_auth_tag_len_ = 0; - static bool inited_; - static rtc::GlobalLockPod lock_; - int last_send_seq_num_ = -1; - bool external_auth_active_ = false; - bool external_auth_enabled_ = false; - std::vector encrypted_header_extension_ids_; - RTC_DISALLOW_COPY_AND_ASSIGN(SrtpSession); -}; - } // namespace cricket #endif // WEBRTC_PC_SRTPFILTER_H_ diff --git a/webrtc/pc/srtpfilter_unittest.cc b/webrtc/pc/srtpfilter_unittest.cc index a200de2e0c..3f6f008a11 100644 --- a/webrtc/pc/srtpfilter_unittest.cc +++ b/webrtc/pc/srtpfilter_unittest.cc @@ -12,27 +12,22 @@ #include "webrtc/pc/srtpfilter.h" -#include "third_party/libsrtp/include/srtp.h" #include "webrtc/media/base/cryptoparams.h" #include "webrtc/media/base/fakertp.h" #include "webrtc/p2p/base/sessiondescription.h" +#include "webrtc/pc/srtptestutil.h" #include "webrtc/rtc_base/buffer.h" #include "webrtc/rtc_base/byteorder.h" #include "webrtc/rtc_base/constructormagic.h" #include "webrtc/rtc_base/gunit.h" #include "webrtc/rtc_base/thread.h" -using rtc::CS_AES_CM_128_HMAC_SHA1_80; -using rtc::CS_AES_CM_128_HMAC_SHA1_32; -using rtc::CS_AEAD_AES_128_GCM; -using rtc::CS_AEAD_AES_256_GCM; using cricket::CryptoParams; using cricket::CS_LOCAL; using cricket::CS_REMOTE; -static const uint8_t kTestKey1[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZ1234"; -static const uint8_t kTestKey2[] = "4321ZYXWVUTSRQPONMLKJIHGFEDCBA"; -static const int kTestKeyLen = 30; +namespace rtc { + static const uint8_t kTestKeyGcm128_1[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZ12"; static const uint8_t kTestKeyGcm128_2[] = "21ZYXWVUTSRQPONMLKJIHGFEDCBA"; static const int kTestKeyGcm128Len = 28; // 128 bits key + 96 bits salt. @@ -70,23 +65,6 @@ static const cricket::CryptoParams kTestCryptoParamsGcm3( static const cricket::CryptoParams kTestCryptoParamsGcm4( 1, "AEAD_AES_128_GCM", kTestKeyParamsGcm4, ""); -static int rtp_auth_tag_len(const std::string& cs) { - if (cs == CS_AES_CM_128_HMAC_SHA1_32) { - return 4; - } else if (cs == CS_AEAD_AES_128_GCM || cs == CS_AEAD_AES_256_GCM) { - return 16; - } else { - return 10; - } -} -static int rtcp_auth_tag_len(const std::string& cs) { - if (cs == CS_AEAD_AES_128_GCM || cs == CS_AEAD_AES_256_GCM) { - return 16; - } else { - return 10; - } -} - class SrtpFilterTest : public testing::Test { protected: SrtpFilterTest() @@ -112,11 +90,11 @@ class SrtpFilterTest : public testing::Test { void TestRtpAuthParams(cricket::SrtpFilter* filter, const std::string& cs) { int overhead; EXPECT_TRUE(filter->GetSrtpOverhead(&overhead)); - switch (rtc::SrtpCryptoSuiteFromName(cs)) { - case rtc::SRTP_AES128_CM_SHA1_32: + switch (SrtpCryptoSuiteFromName(cs)) { + case SRTP_AES128_CM_SHA1_32: EXPECT_EQ(32/8, overhead); // 32-bit tag. break; - case rtc::SRTP_AES128_CM_SHA1_80: + case SRTP_AES128_CM_SHA1_80: EXPECT_EQ(80/8, overhead); // 80-bit tag. break; default: @@ -133,17 +111,16 @@ class SrtpFilterTest : public testing::Test { EXPECT_EQ(overhead, tag_len); } void TestProtectUnprotect(const std::string& cs1, const std::string& cs2) { - rtc::Buffer rtp_buffer(sizeof(kPcmuFrame) + rtp_auth_tag_len(cs1)); + Buffer rtp_buffer(sizeof(kPcmuFrame) + rtp_auth_tag_len(cs1)); char* rtp_packet = rtp_buffer.data(); char original_rtp_packet[sizeof(kPcmuFrame)]; - rtc::Buffer rtcp_buffer(sizeof(kRtcpReport) + 4 + rtcp_auth_tag_len(cs2)); + Buffer rtcp_buffer(sizeof(kRtcpReport) + 4 + rtcp_auth_tag_len(cs2)); char* rtcp_packet = rtcp_buffer.data(); int rtp_len = sizeof(kPcmuFrame), rtcp_len = sizeof(kRtcpReport), out_len; memcpy(rtp_packet, kPcmuFrame, rtp_len); // In order to be able to run this test function multiple times we can not // use the same sequence number twice. Increase the sequence number by one. - rtc::SetBE16(reinterpret_cast(rtp_packet) + 2, - ++sequence_number_); + SetBE16(reinterpret_cast(rtp_packet) + 2, ++sequence_number_); memcpy(original_rtp_packet, rtp_packet, rtp_len); memcpy(rtcp_packet, kRtcpReport, rtcp_len); @@ -198,8 +175,7 @@ class SrtpFilterTest : public testing::Test { void TestProtectUnprotectHeaderEncryption(const std::string& cs1, const std::string& cs2, const std::vector& encrypted_header_ids) { - rtc::Buffer rtp_buffer(sizeof(kPcmuFrameWithExtensions) + - rtp_auth_tag_len(cs1)); + Buffer rtp_buffer(sizeof(kPcmuFrameWithExtensions) + rtp_auth_tag_len(cs1)); char* rtp_packet = rtp_buffer.data(); size_t rtp_packet_size = rtp_buffer.size(); char original_rtp_packet[sizeof(kPcmuFrameWithExtensions)]; @@ -208,8 +184,7 @@ class SrtpFilterTest : public testing::Test { memcpy(rtp_packet, kPcmuFrameWithExtensions, rtp_len); // In order to be able to run this test function multiple times we can not // use the same sequence number twice. Increase the sequence number by one. - rtc::SetBE16(reinterpret_cast(rtp_packet) + 2, - ++sequence_number_); + SetBE16(reinterpret_cast(rtp_packet) + 2, ++sequence_number_); memcpy(original_rtp_packet, rtp_packet, rtp_len); EXPECT_TRUE(f1_.ProtectRtp(rtp_packet, rtp_len, @@ -246,7 +221,7 @@ class SrtpFilterTest : public testing::Test { const uint8_t* key1, int key1_len, const uint8_t* key2, int key2_len, const std::string& cs_name) { EXPECT_EQ(key1_len, key2_len); - EXPECT_EQ(cs_name, rtc::SrtpCryptoSuiteToName(cs)); + EXPECT_EQ(cs_name, SrtpCryptoSuiteToName(cs)); if (enable_external_auth) { f1_.EnableExternalAuth(); f2_.EnableExternalAuth(); @@ -257,7 +232,7 @@ class SrtpFilterTest : public testing::Test { EXPECT_TRUE(f2_.SetRtcpParams(cs, key2, key2_len, cs, key1, key1_len)); EXPECT_TRUE(f1_.IsActive()); EXPECT_TRUE(f2_.IsActive()); - if (rtc::IsGcmCryptoSuite(cs)) { + if (IsGcmCryptoSuite(cs)) { EXPECT_FALSE(f1_.IsExternalAuthActive()); EXPECT_FALSE(f2_.IsExternalAuthActive()); } else if (enable_external_auth) { @@ -274,7 +249,7 @@ class SrtpFilterTest : public testing::Test { // Don't encrypt header ids 2 and 3. encrypted_headers.push_back(4); EXPECT_EQ(key1_len, key2_len); - EXPECT_EQ(cs_name, rtc::SrtpCryptoSuiteToName(cs)); + EXPECT_EQ(cs_name, SrtpCryptoSuiteToName(cs)); f1_.SetEncryptedHeaderExtensionIds(CS_LOCAL, encrypted_headers); f1_.SetEncryptedHeaderExtensionIds(CS_REMOTE, encrypted_headers); f2_.SetEncryptedHeaderExtensionIds(CS_LOCAL, encrypted_headers); @@ -684,61 +659,63 @@ class SrtpFilterProtectSetParamsDirectTest // Test directly setting the params with AES_CM_128_HMAC_SHA1_80. TEST_P(SrtpFilterProtectSetParamsDirectTest, Test_AES_CM_128_HMAC_SHA1_80) { bool enable_external_auth = GetParam(); - TestProtectSetParamsDirect(enable_external_auth, rtc::SRTP_AES128_CM_SHA1_80, - kTestKey1, kTestKeyLen, kTestKey2, kTestKeyLen, - CS_AES_CM_128_HMAC_SHA1_80); + TestProtectSetParamsDirect(enable_external_auth, SRTP_AES128_CM_SHA1_80, + kTestKey1, kTestKeyLen, kTestKey2, kTestKeyLen, + CS_AES_CM_128_HMAC_SHA1_80); } TEST_F(SrtpFilterTest, TestProtectSetParamsDirectHeaderEncryption_AES_CM_128_HMAC_SHA1_80) { - TestProtectSetParamsDirectHeaderEncryption(rtc::SRTP_AES128_CM_SHA1_80, - kTestKey1, kTestKeyLen, kTestKey2, kTestKeyLen, + TestProtectSetParamsDirectHeaderEncryption( + SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen, kTestKey2, kTestKeyLen, CS_AES_CM_128_HMAC_SHA1_80); } // Test directly setting the params with AES_CM_128_HMAC_SHA1_32. TEST_P(SrtpFilterProtectSetParamsDirectTest, Test_AES_CM_128_HMAC_SHA1_32) { bool enable_external_auth = GetParam(); - TestProtectSetParamsDirect(enable_external_auth, rtc::SRTP_AES128_CM_SHA1_32, - kTestKey1, kTestKeyLen, kTestKey2, kTestKeyLen, - CS_AES_CM_128_HMAC_SHA1_32); + TestProtectSetParamsDirect(enable_external_auth, SRTP_AES128_CM_SHA1_32, + kTestKey1, kTestKeyLen, kTestKey2, kTestKeyLen, + CS_AES_CM_128_HMAC_SHA1_32); } TEST_F(SrtpFilterTest, TestProtectSetParamsDirectHeaderEncryption_AES_CM_128_HMAC_SHA1_32) { - TestProtectSetParamsDirectHeaderEncryption(rtc::SRTP_AES128_CM_SHA1_32, - kTestKey1, kTestKeyLen, kTestKey2, kTestKeyLen, + TestProtectSetParamsDirectHeaderEncryption( + SRTP_AES128_CM_SHA1_32, kTestKey1, kTestKeyLen, kTestKey2, kTestKeyLen, CS_AES_CM_128_HMAC_SHA1_32); } // Test directly setting the params with SRTP_AEAD_AES_128_GCM. TEST_P(SrtpFilterProtectSetParamsDirectTest, Test_SRTP_AEAD_AES_128_GCM) { bool enable_external_auth = GetParam(); - TestProtectSetParamsDirect(enable_external_auth, rtc::SRTP_AEAD_AES_128_GCM, - kTestKeyGcm128_1, kTestKeyGcm128Len, kTestKeyGcm128_2, kTestKeyGcm128Len, - CS_AEAD_AES_128_GCM); + TestProtectSetParamsDirect(enable_external_auth, SRTP_AEAD_AES_128_GCM, + kTestKeyGcm128_1, kTestKeyGcm128Len, + kTestKeyGcm128_2, kTestKeyGcm128Len, + CS_AEAD_AES_128_GCM); } TEST_F(SrtpFilterTest, TestProtectSetParamsDirectHeaderEncryption_SRTP_AEAD_AES_128_GCM) { - TestProtectSetParamsDirectHeaderEncryption(rtc::SRTP_AEAD_AES_128_GCM, - kTestKeyGcm128_1, kTestKeyGcm128Len, kTestKeyGcm128_2, kTestKeyGcm128Len, - CS_AEAD_AES_128_GCM); + TestProtectSetParamsDirectHeaderEncryption( + SRTP_AEAD_AES_128_GCM, kTestKeyGcm128_1, kTestKeyGcm128Len, + kTestKeyGcm128_2, kTestKeyGcm128Len, CS_AEAD_AES_128_GCM); } // Test directly setting the params with SRTP_AEAD_AES_256_GCM. TEST_P(SrtpFilterProtectSetParamsDirectTest, Test_SRTP_AEAD_AES_256_GCM) { bool enable_external_auth = GetParam(); - TestProtectSetParamsDirect(enable_external_auth, rtc::SRTP_AEAD_AES_256_GCM, - kTestKeyGcm256_1, kTestKeyGcm256Len, kTestKeyGcm256_2, kTestKeyGcm256Len, - CS_AEAD_AES_256_GCM); + TestProtectSetParamsDirect(enable_external_auth, SRTP_AEAD_AES_256_GCM, + kTestKeyGcm256_1, kTestKeyGcm256Len, + kTestKeyGcm256_2, kTestKeyGcm256Len, + CS_AEAD_AES_256_GCM); } TEST_F(SrtpFilterTest, TestProtectSetParamsDirectHeaderEncryption_SRTP_AEAD_AES_256_GCM) { - TestProtectSetParamsDirectHeaderEncryption(rtc::SRTP_AEAD_AES_256_GCM, - kTestKeyGcm256_1, kTestKeyGcm256Len, kTestKeyGcm256_2, kTestKeyGcm256Len, - CS_AEAD_AES_256_GCM); + TestProtectSetParamsDirectHeaderEncryption( + SRTP_AEAD_AES_256_GCM, kTestKeyGcm256_1, kTestKeyGcm256Len, + kTestKeyGcm256_2, kTestKeyGcm256Len, CS_AEAD_AES_256_GCM); } // Run all tests both with and without external auth enabled. @@ -748,194 +725,12 @@ INSTANTIATE_TEST_CASE_P(ExternalAuth, // Test directly setting the params with bogus keys. TEST_F(SrtpFilterTest, TestSetParamsKeyTooShort) { - EXPECT_FALSE(f1_.SetRtpParams(rtc::SRTP_AES128_CM_SHA1_80, kTestKey1, - kTestKeyLen - 1, rtc::SRTP_AES128_CM_SHA1_80, + EXPECT_FALSE(f1_.SetRtpParams(SRTP_AES128_CM_SHA1_80, kTestKey1, + kTestKeyLen - 1, SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen - 1)); - EXPECT_FALSE(f1_.SetRtcpParams(rtc::SRTP_AES128_CM_SHA1_80, kTestKey1, - kTestKeyLen - 1, rtc::SRTP_AES128_CM_SHA1_80, + EXPECT_FALSE(f1_.SetRtcpParams(SRTP_AES128_CM_SHA1_80, kTestKey1, + kTestKeyLen - 1, SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen - 1)); } -class SrtpSessionTest : public testing::Test { - protected: - virtual void SetUp() { - rtp_len_ = sizeof(kPcmuFrame); - rtcp_len_ = sizeof(kRtcpReport); - memcpy(rtp_packet_, kPcmuFrame, rtp_len_); - memcpy(rtcp_packet_, kRtcpReport, rtcp_len_); - } - void TestProtectRtp(const std::string& cs) { - int out_len = 0; - EXPECT_TRUE(s1_.ProtectRtp(rtp_packet_, rtp_len_, - sizeof(rtp_packet_), &out_len)); - EXPECT_EQ(out_len, rtp_len_ + rtp_auth_tag_len(cs)); - EXPECT_NE(0, memcmp(rtp_packet_, kPcmuFrame, rtp_len_)); - rtp_len_ = out_len; - } - void TestProtectRtcp(const std::string& cs) { - int out_len = 0; - EXPECT_TRUE(s1_.ProtectRtcp(rtcp_packet_, rtcp_len_, - sizeof(rtcp_packet_), &out_len)); - EXPECT_EQ(out_len, rtcp_len_ + 4 + rtcp_auth_tag_len(cs)); // NOLINT - EXPECT_NE(0, memcmp(rtcp_packet_, kRtcpReport, rtcp_len_)); - rtcp_len_ = out_len; - } - void TestUnprotectRtp(const std::string& cs) { - int out_len = 0, expected_len = sizeof(kPcmuFrame); - EXPECT_TRUE(s2_.UnprotectRtp(rtp_packet_, rtp_len_, &out_len)); - EXPECT_EQ(expected_len, out_len); - EXPECT_EQ(0, memcmp(rtp_packet_, kPcmuFrame, out_len)); - } - void TestUnprotectRtcp(const std::string& cs) { - int out_len = 0, expected_len = sizeof(kRtcpReport); - EXPECT_TRUE(s2_.UnprotectRtcp(rtcp_packet_, rtcp_len_, &out_len)); - EXPECT_EQ(expected_len, out_len); - EXPECT_EQ(0, memcmp(rtcp_packet_, kRtcpReport, out_len)); - } - cricket::SrtpSession s1_; - cricket::SrtpSession s2_; - char rtp_packet_[sizeof(kPcmuFrame) + 10]; - char rtcp_packet_[sizeof(kRtcpReport) + 4 + 10]; - int rtp_len_; - int rtcp_len_; -}; - -// Test that we can set up the session and keys properly. -TEST_F(SrtpSessionTest, TestGoodSetup) { - EXPECT_TRUE(s1_.SetSend(rtc::SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen)); - EXPECT_TRUE(s2_.SetRecv(rtc::SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen)); -} - -// Test that we can't change the keys once set. -TEST_F(SrtpSessionTest, TestBadSetup) { - EXPECT_TRUE(s1_.SetSend(rtc::SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen)); - EXPECT_TRUE(s2_.SetRecv(rtc::SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen)); - EXPECT_FALSE( - s1_.SetSend(rtc::SRTP_AES128_CM_SHA1_80, kTestKey2, kTestKeyLen)); - EXPECT_FALSE( - s2_.SetRecv(rtc::SRTP_AES128_CM_SHA1_80, kTestKey2, kTestKeyLen)); -} - -// Test that we fail keys of the wrong length. -TEST_F(SrtpSessionTest, TestKeysTooShort) { - EXPECT_FALSE(s1_.SetSend(rtc::SRTP_AES128_CM_SHA1_80, kTestKey1, 1)); - EXPECT_FALSE(s2_.SetRecv(rtc::SRTP_AES128_CM_SHA1_80, kTestKey1, 1)); -} - -// Test that we can encrypt and decrypt RTP/RTCP using AES_CM_128_HMAC_SHA1_80. -TEST_F(SrtpSessionTest, TestProtect_AES_CM_128_HMAC_SHA1_80) { - EXPECT_TRUE(s1_.SetSend(rtc::SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen)); - EXPECT_TRUE(s2_.SetRecv(rtc::SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen)); - TestProtectRtp(CS_AES_CM_128_HMAC_SHA1_80); - TestProtectRtcp(CS_AES_CM_128_HMAC_SHA1_80); - TestUnprotectRtp(CS_AES_CM_128_HMAC_SHA1_80); - TestUnprotectRtcp(CS_AES_CM_128_HMAC_SHA1_80); -} - -// Test that we can encrypt and decrypt RTP/RTCP using AES_CM_128_HMAC_SHA1_32. -TEST_F(SrtpSessionTest, TestProtect_AES_CM_128_HMAC_SHA1_32) { - EXPECT_TRUE(s1_.SetSend(rtc::SRTP_AES128_CM_SHA1_32, kTestKey1, kTestKeyLen)); - EXPECT_TRUE(s2_.SetRecv(rtc::SRTP_AES128_CM_SHA1_32, kTestKey1, kTestKeyLen)); - TestProtectRtp(CS_AES_CM_128_HMAC_SHA1_32); - TestProtectRtcp(CS_AES_CM_128_HMAC_SHA1_32); - TestUnprotectRtp(CS_AES_CM_128_HMAC_SHA1_32); - TestUnprotectRtcp(CS_AES_CM_128_HMAC_SHA1_32); -} - -TEST_F(SrtpSessionTest, TestGetSendStreamPacketIndex) { - EXPECT_TRUE(s1_.SetSend(rtc::SRTP_AES128_CM_SHA1_32, kTestKey1, kTestKeyLen)); - int64_t index; - int out_len = 0; - EXPECT_TRUE(s1_.ProtectRtp(rtp_packet_, rtp_len_, - sizeof(rtp_packet_), &out_len, &index)); - // |index| will be shifted by 16. - int64_t be64_index = static_cast(rtc::NetworkToHost64(1 << 16)); - EXPECT_EQ(be64_index, index); -} - -// Test that we fail to unprotect if someone tampers with the RTP/RTCP paylaods. -TEST_F(SrtpSessionTest, TestTamperReject) { - int out_len; - EXPECT_TRUE(s1_.SetSend(rtc::SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen)); - EXPECT_TRUE(s2_.SetRecv(rtc::SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen)); - TestProtectRtp(CS_AES_CM_128_HMAC_SHA1_80); - TestProtectRtcp(CS_AES_CM_128_HMAC_SHA1_80); - rtp_packet_[0] = 0x12; - rtcp_packet_[1] = 0x34; - EXPECT_FALSE(s2_.UnprotectRtp(rtp_packet_, rtp_len_, &out_len)); - EXPECT_FALSE(s2_.UnprotectRtcp(rtcp_packet_, rtcp_len_, &out_len)); -} - -// Test that we fail to unprotect if the payloads are not authenticated. -TEST_F(SrtpSessionTest, TestUnencryptReject) { - int out_len; - EXPECT_TRUE(s1_.SetSend(rtc::SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen)); - EXPECT_TRUE(s2_.SetRecv(rtc::SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen)); - EXPECT_FALSE(s2_.UnprotectRtp(rtp_packet_, rtp_len_, &out_len)); - EXPECT_FALSE(s2_.UnprotectRtcp(rtcp_packet_, rtcp_len_, &out_len)); -} - -// Test that we fail when using buffers that are too small. -TEST_F(SrtpSessionTest, TestBuffersTooSmall) { - int out_len; - EXPECT_TRUE(s1_.SetSend(rtc::SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen)); - EXPECT_FALSE(s1_.ProtectRtp(rtp_packet_, rtp_len_, - sizeof(rtp_packet_) - 10, &out_len)); - EXPECT_FALSE(s1_.ProtectRtcp(rtcp_packet_, rtcp_len_, - sizeof(rtcp_packet_) - 14, &out_len)); -} - -TEST_F(SrtpSessionTest, TestReplay) { - static const uint16_t kMaxSeqnum = static_cast(-1); - static const uint16_t seqnum_big = 62275; - static const uint16_t seqnum_small = 10; - static const uint16_t replay_window = 1024; - int out_len; - - EXPECT_TRUE(s1_.SetSend(rtc::SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen)); - EXPECT_TRUE(s2_.SetRecv(rtc::SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen)); - - // Initial sequence number. - rtc::SetBE16(reinterpret_cast(rtp_packet_) + 2, seqnum_big); - EXPECT_TRUE(s1_.ProtectRtp(rtp_packet_, rtp_len_, sizeof(rtp_packet_), - &out_len)); - - // Replay within the 1024 window should succeed. - rtc::SetBE16(reinterpret_cast(rtp_packet_) + 2, - seqnum_big - replay_window + 1); - EXPECT_TRUE(s1_.ProtectRtp(rtp_packet_, rtp_len_, sizeof(rtp_packet_), - &out_len)); - - // Replay out side of the 1024 window should fail. - rtc::SetBE16(reinterpret_cast(rtp_packet_) + 2, - seqnum_big - replay_window - 1); - EXPECT_FALSE(s1_.ProtectRtp(rtp_packet_, rtp_len_, sizeof(rtp_packet_), - &out_len)); - - // Increment sequence number to a small number. - rtc::SetBE16(reinterpret_cast(rtp_packet_) + 2, seqnum_small); - EXPECT_TRUE(s1_.ProtectRtp(rtp_packet_, rtp_len_, sizeof(rtp_packet_), - &out_len)); - - // Replay around 0 but out side of the 1024 window should fail. - rtc::SetBE16(reinterpret_cast(rtp_packet_) + 2, - kMaxSeqnum + seqnum_small - replay_window - 1); - EXPECT_FALSE(s1_.ProtectRtp(rtp_packet_, rtp_len_, sizeof(rtp_packet_), - &out_len)); - - // Replay around 0 but within the 1024 window should succeed. - for (uint16_t seqnum = 65000; seqnum < 65003; ++seqnum) { - rtc::SetBE16(reinterpret_cast(rtp_packet_) + 2, seqnum); - EXPECT_TRUE(s1_.ProtectRtp(rtp_packet_, rtp_len_, sizeof(rtp_packet_), - &out_len)); - } - - // Go back to normal sequence nubmer. - // NOTE: without the fix in libsrtp, this would fail. This is because - // without the fix, the loop above would keep incrementing local sequence - // number in libsrtp, eventually the new sequence number would go out side - // of the window. - rtc::SetBE16(reinterpret_cast(rtp_packet_) + 2, seqnum_small + 1); - EXPECT_TRUE(s1_.ProtectRtp(rtp_packet_, rtp_len_, sizeof(rtp_packet_), - &out_len)); -} +} // namespace rtc diff --git a/webrtc/pc/srtpsession.cc b/webrtc/pc/srtpsession.cc new file mode 100644 index 0000000000..e8b29460a7 --- /dev/null +++ b/webrtc/pc/srtpsession.cc @@ -0,0 +1,408 @@ +/* + * Copyright 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/pc/srtpsession.h" + +#include "third_party/libsrtp/include/srtp.h" +#include "third_party/libsrtp/include/srtp_priv.h" +#include "webrtc/media/base/rtputils.h" +#include "webrtc/pc/externalhmac.h" +#include "webrtc/rtc_base/logging.h" +#include "webrtc/rtc_base/sslstreamadapter.h" + +namespace cricket { + +bool SrtpSession::inited_ = false; + +// This lock protects SrtpSession::inited_. +rtc::GlobalLockPod SrtpSession::lock_; + +SrtpSession::SrtpSession() {} + +SrtpSession::~SrtpSession() { + if (session_) { + srtp_set_user_data(session_, nullptr); + srtp_dealloc(session_); + } +} + +bool SrtpSession::SetSend(int cs, const uint8_t* key, size_t len) { + return SetKey(ssrc_any_outbound, cs, key, len); +} + +bool SrtpSession::UpdateSend(int cs, const uint8_t* key, size_t len) { + return UpdateKey(ssrc_any_outbound, cs, key, len); +} + +bool SrtpSession::SetRecv(int cs, const uint8_t* key, size_t len) { + return SetKey(ssrc_any_inbound, cs, key, len); +} + +bool SrtpSession::UpdateRecv(int cs, const uint8_t* key, size_t len) { + return UpdateKey(ssrc_any_inbound, cs, key, len); +} + +bool SrtpSession::ProtectRtp(void* p, int in_len, int max_len, int* out_len) { + RTC_DCHECK(thread_checker_.CalledOnValidThread()); + if (!session_) { + LOG(LS_WARNING) << "Failed to protect SRTP packet: no SRTP Session"; + return false; + } + + int need_len = in_len + rtp_auth_tag_len_; // NOLINT + if (max_len < need_len) { + LOG(LS_WARNING) << "Failed to protect SRTP packet: The buffer length " + << max_len << " is less than the needed " << need_len; + return false; + } + + *out_len = in_len; + int err = srtp_protect(session_, p, out_len); + int seq_num; + GetRtpSeqNum(p, in_len, &seq_num); + if (err != srtp_err_status_ok) { + LOG(LS_WARNING) << "Failed to protect SRTP packet, seqnum=" << seq_num + << ", err=" << err + << ", last seqnum=" << last_send_seq_num_; + return false; + } + last_send_seq_num_ = seq_num; + return true; +} + +bool SrtpSession::ProtectRtp(void* p, + int in_len, + int max_len, + int* out_len, + int64_t* index) { + if (!ProtectRtp(p, in_len, max_len, out_len)) { + return false; + } + return (index) ? GetSendStreamPacketIndex(p, in_len, index) : true; +} + +bool SrtpSession::ProtectRtcp(void* p, int in_len, int max_len, int* out_len) { + RTC_DCHECK(thread_checker_.CalledOnValidThread()); + if (!session_) { + LOG(LS_WARNING) << "Failed to protect SRTCP packet: no SRTP Session"; + return false; + } + + int need_len = in_len + sizeof(uint32_t) + rtcp_auth_tag_len_; // NOLINT + if (max_len < need_len) { + LOG(LS_WARNING) << "Failed to protect SRTCP packet: The buffer length " + << max_len << " is less than the needed " << need_len; + return false; + } + + *out_len = in_len; + int err = srtp_protect_rtcp(session_, p, out_len); + if (err != srtp_err_status_ok) { + LOG(LS_WARNING) << "Failed to protect SRTCP packet, err=" << err; + return false; + } + return true; +} + +bool SrtpSession::UnprotectRtp(void* p, int in_len, int* out_len) { + RTC_DCHECK(thread_checker_.CalledOnValidThread()); + if (!session_) { + LOG(LS_WARNING) << "Failed to unprotect SRTP packet: no SRTP Session"; + return false; + } + + *out_len = in_len; + int err = srtp_unprotect(session_, p, out_len); + if (err != srtp_err_status_ok) { + LOG(LS_WARNING) << "Failed to unprotect SRTP packet, err=" << err; + return false; + } + return true; +} + +bool SrtpSession::UnprotectRtcp(void* p, int in_len, int* out_len) { + RTC_DCHECK(thread_checker_.CalledOnValidThread()); + if (!session_) { + LOG(LS_WARNING) << "Failed to unprotect SRTCP packet: no SRTP Session"; + return false; + } + + *out_len = in_len; + int err = srtp_unprotect_rtcp(session_, p, out_len); + if (err != srtp_err_status_ok) { + LOG(LS_WARNING) << "Failed to unprotect SRTCP packet, err=" << err; + return false; + } + return true; +} + +bool SrtpSession::GetRtpAuthParams(uint8_t** key, int* key_len, int* tag_len) { + RTC_DCHECK(thread_checker_.CalledOnValidThread()); + RTC_DCHECK(IsExternalAuthActive()); + if (!IsExternalAuthActive()) { + return false; + } + + ExternalHmacContext* external_hmac = nullptr; + // stream_template will be the reference context for other streams. + // Let's use it for getting the keys. + srtp_stream_ctx_t* srtp_context = session_->stream_template; +#if defined(SRTP_MAX_MKI_LEN) + // libsrtp 2.1.0 + if (srtp_context && srtp_context->session_keys && + srtp_context->session_keys->rtp_auth) { + external_hmac = reinterpret_cast( + srtp_context->session_keys->rtp_auth->state); + } +#else + // libsrtp 2.0.0 + // TODO(jbauch): Remove after switching to libsrtp 2.1.0 + if (srtp_context && srtp_context->rtp_auth) { + external_hmac = + reinterpret_cast(srtp_context->rtp_auth->state); + } +#endif + + if (!external_hmac) { + LOG(LS_ERROR) << "Failed to get auth keys from libsrtp!."; + return false; + } + + *key = external_hmac->key; + *key_len = external_hmac->key_length; + *tag_len = rtp_auth_tag_len_; + return true; +} + +int SrtpSession::GetSrtpOverhead() const { + return rtp_auth_tag_len_; +} + +void SrtpSession::EnableExternalAuth() { + RTC_DCHECK(!session_); + external_auth_enabled_ = true; +} + +bool SrtpSession::IsExternalAuthEnabled() const { + return external_auth_enabled_; +} + +bool SrtpSession::IsExternalAuthActive() const { + return external_auth_active_; +} + +bool SrtpSession::GetSendStreamPacketIndex(void* p, + int in_len, + int64_t* index) { + RTC_DCHECK(thread_checker_.CalledOnValidThread()); + srtp_hdr_t* hdr = reinterpret_cast(p); + srtp_stream_ctx_t* stream = srtp_get_stream(session_, hdr->ssrc); + if (!stream) { + return false; + } + + // Shift packet index, put into network byte order + *index = static_cast(rtc::NetworkToHost64( + srtp_rdbx_get_packet_index(&stream->rtp_rdbx) << 16)); + return true; +} + +bool SrtpSession::DoSetKey(int type, int cs, const uint8_t* key, size_t len) { + RTC_DCHECK(thread_checker_.CalledOnValidThread()); + + srtp_policy_t policy; + memset(&policy, 0, sizeof(policy)); + if (cs == rtc::SRTP_AES128_CM_SHA1_80) { + srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&policy.rtp); + srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&policy.rtcp); + } else if (cs == rtc::SRTP_AES128_CM_SHA1_32) { + // RTP HMAC is shortened to 32 bits, but RTCP remains 80 bits. + srtp_crypto_policy_set_aes_cm_128_hmac_sha1_32(&policy.rtp); + srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&policy.rtcp); + } else if (cs == rtc::SRTP_AEAD_AES_128_GCM) { + srtp_crypto_policy_set_aes_gcm_128_16_auth(&policy.rtp); + srtp_crypto_policy_set_aes_gcm_128_16_auth(&policy.rtcp); + } else if (cs == rtc::SRTP_AEAD_AES_256_GCM) { + srtp_crypto_policy_set_aes_gcm_256_16_auth(&policy.rtp); + srtp_crypto_policy_set_aes_gcm_256_16_auth(&policy.rtcp); + } else { + LOG(LS_WARNING) << "Failed to " << (session_ ? "update" : "create") + << " SRTP session: unsupported cipher_suite " << cs; + return false; + } + + int expected_key_len; + int expected_salt_len; + if (!rtc::GetSrtpKeyAndSaltLengths(cs, &expected_key_len, + &expected_salt_len)) { + // This should never happen. + LOG(LS_WARNING) + << "Failed to " << (session_ ? "update" : "create") + << " SRTP session: unsupported cipher_suite without length information" + << cs; + return false; + } + + if (!key || + len != static_cast(expected_key_len + expected_salt_len)) { + LOG(LS_WARNING) << "Failed to " << (session_ ? "update" : "create") + << " SRTP session: invalid key"; + return false; + } + + policy.ssrc.type = static_cast(type); + policy.ssrc.value = 0; + policy.key = const_cast(key); + // TODO(astor) parse window size from WSH session-param + policy.window_size = 1024; + policy.allow_repeat_tx = 1; + // If external authentication option is enabled, supply custom auth module + // id EXTERNAL_HMAC_SHA1 in the policy structure. + // We want to set this option only for rtp packets. + // By default policy structure is initialized to HMAC_SHA1. + // Enable external HMAC authentication only for outgoing streams and only + // for cipher suites that support it (i.e. only non-GCM cipher suites). + if (type == ssrc_any_outbound && IsExternalAuthEnabled() && + !rtc::IsGcmCryptoSuite(cs)) { + policy.rtp.auth_type = EXTERNAL_HMAC_SHA1; + } + if (!encrypted_header_extension_ids_.empty()) { + policy.enc_xtn_hdr = const_cast(&encrypted_header_extension_ids_[0]); + policy.enc_xtn_hdr_count = + static_cast(encrypted_header_extension_ids_.size()); + } + policy.next = nullptr; + + if (!session_) { + int err = srtp_create(&session_, &policy); + if (err != srtp_err_status_ok) { + session_ = nullptr; + LOG(LS_ERROR) << "Failed to create SRTP session, err=" << err; + return false; + } + srtp_set_user_data(session_, this); + } else { + int err = srtp_update(session_, &policy); + if (err != srtp_err_status_ok) { + LOG(LS_ERROR) << "Failed to update SRTP session, err=" << err; + return false; + } + } + + rtp_auth_tag_len_ = policy.rtp.auth_tag_len; + rtcp_auth_tag_len_ = policy.rtcp.auth_tag_len; + external_auth_active_ = (policy.rtp.auth_type == EXTERNAL_HMAC_SHA1); + return true; +} + +bool SrtpSession::SetKey(int type, int cs, const uint8_t* key, size_t len) { + RTC_DCHECK(thread_checker_.CalledOnValidThread()); + if (session_) { + LOG(LS_ERROR) << "Failed to create SRTP session: " + << "SRTP session already created"; + return false; + } + + if (!Init()) { + return false; + } + + return DoSetKey(type, cs, key, len); +} + +bool SrtpSession::UpdateKey(int type, int cs, const uint8_t* key, size_t len) { + RTC_DCHECK(thread_checker_.CalledOnValidThread()); + if (!session_) { + LOG(LS_ERROR) << "Failed to update non-existing SRTP session"; + return false; + } + + return DoSetKey(type, cs, key, len); +} + +void SrtpSession::SetEncryptedHeaderExtensionIds( + const std::vector& encrypted_header_extension_ids) { + RTC_DCHECK(thread_checker_.CalledOnValidThread()); + encrypted_header_extension_ids_ = encrypted_header_extension_ids; +} + +bool SrtpSession::Init() { + rtc::GlobalLockScope ls(&lock_); + + if (!inited_) { + int err; + err = srtp_init(); + if (err != srtp_err_status_ok) { + LOG(LS_ERROR) << "Failed to init SRTP, err=" << err; + return false; + } + + err = srtp_install_event_handler(&SrtpSession::HandleEventThunk); + if (err != srtp_err_status_ok) { + LOG(LS_ERROR) << "Failed to install SRTP event handler, err=" << err; + return false; + } + + err = external_crypto_init(); + if (err != srtp_err_status_ok) { + LOG(LS_ERROR) << "Failed to initialize fake auth, err=" << err; + return false; + } + inited_ = true; + } + + return true; +} + +void SrtpSession::Terminate() { + rtc::GlobalLockScope ls(&lock_); + + if (inited_) { + int err = srtp_shutdown(); + if (err) { + LOG(LS_ERROR) << "srtp_shutdown failed. err=" << err; + return; + } + inited_ = false; + } +} + +void SrtpSession::HandleEvent(const srtp_event_data_t* ev) { + RTC_DCHECK(thread_checker_.CalledOnValidThread()); + switch (ev->event) { + case event_ssrc_collision: + LOG(LS_INFO) << "SRTP event: SSRC collision"; + break; + case event_key_soft_limit: + LOG(LS_INFO) << "SRTP event: reached soft key usage limit"; + break; + case event_key_hard_limit: + LOG(LS_INFO) << "SRTP event: reached hard key usage limit"; + break; + case event_packet_index_limit: + LOG(LS_INFO) << "SRTP event: reached hard packet limit (2^48 packets)"; + break; + default: + LOG(LS_INFO) << "SRTP event: unknown " << ev->event; + break; + } +} + +void SrtpSession::HandleEventThunk(srtp_event_data_t* ev) { + // Callback will be executed from same thread that calls the "srtp_protect" + // and "srtp_unprotect" functions. + SrtpSession* session = + static_cast(srtp_get_user_data(ev->session)); + if (session) { + session->HandleEvent(ev); + } +} + +} // namespace cricket diff --git a/webrtc/pc/srtpsession.h b/webrtc/pc/srtpsession.h new file mode 100644 index 0000000000..c490c48f9d --- /dev/null +++ b/webrtc/pc/srtpsession.h @@ -0,0 +1,109 @@ +/* + * Copyright 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef WEBRTC_PC_SRTPSESSION_H_ +#define WEBRTC_PC_SRTPSESSION_H_ + +#include + +#include "webrtc/rtc_base/basictypes.h" +#include "webrtc/rtc_base/thread_checker.h" + +// Forward declaration to avoid pulling in libsrtp headers here +struct srtp_event_data_t; +struct srtp_ctx_t_; + +namespace cricket { + +// Class that wraps a libSRTP session. +class SrtpSession { + public: + SrtpSession(); + ~SrtpSession(); + + // Configures the session for sending data using the specified + // cipher-suite and key. Receiving must be done by a separate session. + bool SetSend(int cs, const uint8_t* key, size_t len); + bool UpdateSend(int cs, const uint8_t* key, size_t len); + + // Configures the session for receiving data using the specified + // cipher-suite and key. Sending must be done by a separate session. + bool SetRecv(int cs, const uint8_t* key, size_t len); + bool UpdateRecv(int cs, const uint8_t* key, size_t len); + + void SetEncryptedHeaderExtensionIds( + const std::vector& encrypted_header_extension_ids); + + // Encrypts/signs an individual RTP/RTCP packet, in-place. + // If an HMAC is used, this will increase the packet size. + bool ProtectRtp(void* data, int in_len, int max_len, int* out_len); + // Overloaded version, outputs packet index. + bool ProtectRtp(void* data, + int in_len, + int max_len, + int* out_len, + int64_t* index); + bool ProtectRtcp(void* data, int in_len, int max_len, int* out_len); + // Decrypts/verifies an invidiual RTP/RTCP packet. + // If an HMAC is used, this will decrease the packet size. + bool UnprotectRtp(void* data, int in_len, int* out_len); + bool UnprotectRtcp(void* data, int in_len, int* out_len); + + // Helper method to get authentication params. + bool GetRtpAuthParams(uint8_t** key, int* key_len, int* tag_len); + + int GetSrtpOverhead() const; + + // If external auth is enabled, SRTP will write a dummy auth tag that then + // later must get replaced before the packet is sent out. Only supported for + // non-GCM cipher suites and can be checked through "IsExternalAuthActive" + // if it is actually used. This method is only valid before the RTP params + // have been set. + void EnableExternalAuth(); + bool IsExternalAuthEnabled() const; + + // A SRTP session supports external creation of the auth tag if a non-GCM + // cipher is used. This method is only valid after the RTP params have + // been set. + bool IsExternalAuthActive() const; + + // Calls srtp_shutdown if it's initialized. + static void Terminate(); + + private: + bool DoSetKey(int type, int cs, const uint8_t* key, size_t len); + bool SetKey(int type, int cs, const uint8_t* key, size_t len); + bool UpdateKey(int type, int cs, const uint8_t* key, size_t len); + bool SetEncryptedHeaderExtensionIds( + int type, + const std::vector& encrypted_header_extension_ids); + // Returns send stream current packet index from srtp db. + bool GetSendStreamPacketIndex(void* data, int in_len, int64_t* index); + + static bool Init(); + void HandleEvent(const srtp_event_data_t* ev); + static void HandleEventThunk(srtp_event_data_t* ev); + + rtc::ThreadChecker thread_checker_; + srtp_ctx_t_* session_ = nullptr; + int rtp_auth_tag_len_ = 0; + int rtcp_auth_tag_len_ = 0; + static bool inited_; + static rtc::GlobalLockPod lock_; + int last_send_seq_num_ = -1; + bool external_auth_active_ = false; + bool external_auth_enabled_ = false; + std::vector encrypted_header_extension_ids_; + RTC_DISALLOW_COPY_AND_ASSIGN(SrtpSession); +}; + +} // namespace cricket + +#endif // WEBRTC_PC_SRTPSESSION_H_ diff --git a/webrtc/pc/srtpsession_unittest.cc b/webrtc/pc/srtpsession_unittest.cc new file mode 100644 index 0000000000..d10989d4dd --- /dev/null +++ b/webrtc/pc/srtpsession_unittest.cc @@ -0,0 +1,204 @@ +/* + * Copyright 2004 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/pc/srtpsession.h" + +#include + +#include "webrtc/media/base/fakertp.h" +#include "webrtc/pc/srtptestutil.h" +#include "webrtc/rtc_base/gunit.h" +#include "webrtc/rtc_base/sslstreamadapter.h" // For rtc::SRTP_* + +namespace rtc { + +class SrtpSessionTest : public testing::Test { + protected: + virtual void SetUp() { + rtp_len_ = sizeof(kPcmuFrame); + rtcp_len_ = sizeof(kRtcpReport); + memcpy(rtp_packet_, kPcmuFrame, rtp_len_); + memcpy(rtcp_packet_, kRtcpReport, rtcp_len_); + } + void TestProtectRtp(const std::string& cs) { + int out_len = 0; + EXPECT_TRUE( + s1_.ProtectRtp(rtp_packet_, rtp_len_, sizeof(rtp_packet_), &out_len)); + EXPECT_EQ(out_len, rtp_len_ + rtp_auth_tag_len(cs)); + EXPECT_NE(0, memcmp(rtp_packet_, kPcmuFrame, rtp_len_)); + rtp_len_ = out_len; + } + void TestProtectRtcp(const std::string& cs) { + int out_len = 0; + EXPECT_TRUE(s1_.ProtectRtcp(rtcp_packet_, rtcp_len_, sizeof(rtcp_packet_), + &out_len)); + EXPECT_EQ(out_len, rtcp_len_ + 4 + rtcp_auth_tag_len(cs)); // NOLINT + EXPECT_NE(0, memcmp(rtcp_packet_, kRtcpReport, rtcp_len_)); + rtcp_len_ = out_len; + } + void TestUnprotectRtp(const std::string& cs) { + int out_len = 0, expected_len = sizeof(kPcmuFrame); + EXPECT_TRUE(s2_.UnprotectRtp(rtp_packet_, rtp_len_, &out_len)); + EXPECT_EQ(expected_len, out_len); + EXPECT_EQ(0, memcmp(rtp_packet_, kPcmuFrame, out_len)); + } + void TestUnprotectRtcp(const std::string& cs) { + int out_len = 0, expected_len = sizeof(kRtcpReport); + EXPECT_TRUE(s2_.UnprotectRtcp(rtcp_packet_, rtcp_len_, &out_len)); + EXPECT_EQ(expected_len, out_len); + EXPECT_EQ(0, memcmp(rtcp_packet_, kRtcpReport, out_len)); + } + cricket::SrtpSession s1_; + cricket::SrtpSession s2_; + char rtp_packet_[sizeof(kPcmuFrame) + 10]; + char rtcp_packet_[sizeof(kRtcpReport) + 4 + 10]; + int rtp_len_; + int rtcp_len_; +}; + +// Test that we can set up the session and keys properly. +TEST_F(SrtpSessionTest, TestGoodSetup) { + EXPECT_TRUE(s1_.SetSend(SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen)); + EXPECT_TRUE(s2_.SetRecv(SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen)); +} + +// Test that we can't change the keys once set. +TEST_F(SrtpSessionTest, TestBadSetup) { + EXPECT_TRUE(s1_.SetSend(SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen)); + EXPECT_TRUE(s2_.SetRecv(SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen)); + EXPECT_FALSE(s1_.SetSend(SRTP_AES128_CM_SHA1_80, kTestKey2, kTestKeyLen)); + EXPECT_FALSE(s2_.SetRecv(SRTP_AES128_CM_SHA1_80, kTestKey2, kTestKeyLen)); +} + +// Test that we fail keys of the wrong length. +TEST_F(SrtpSessionTest, TestKeysTooShort) { + EXPECT_FALSE(s1_.SetSend(SRTP_AES128_CM_SHA1_80, kTestKey1, 1)); + EXPECT_FALSE(s2_.SetRecv(SRTP_AES128_CM_SHA1_80, kTestKey1, 1)); +} + +// Test that we can encrypt and decrypt RTP/RTCP using AES_CM_128_HMAC_SHA1_80. +TEST_F(SrtpSessionTest, TestProtect_AES_CM_128_HMAC_SHA1_80) { + EXPECT_TRUE(s1_.SetSend(SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen)); + EXPECT_TRUE(s2_.SetRecv(SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen)); + TestProtectRtp(CS_AES_CM_128_HMAC_SHA1_80); + TestProtectRtcp(CS_AES_CM_128_HMAC_SHA1_80); + TestUnprotectRtp(CS_AES_CM_128_HMAC_SHA1_80); + TestUnprotectRtcp(CS_AES_CM_128_HMAC_SHA1_80); +} + +// Test that we can encrypt and decrypt RTP/RTCP using AES_CM_128_HMAC_SHA1_32. +TEST_F(SrtpSessionTest, TestProtect_AES_CM_128_HMAC_SHA1_32) { + EXPECT_TRUE(s1_.SetSend(SRTP_AES128_CM_SHA1_32, kTestKey1, kTestKeyLen)); + EXPECT_TRUE(s2_.SetRecv(SRTP_AES128_CM_SHA1_32, kTestKey1, kTestKeyLen)); + TestProtectRtp(CS_AES_CM_128_HMAC_SHA1_32); + TestProtectRtcp(CS_AES_CM_128_HMAC_SHA1_32); + TestUnprotectRtp(CS_AES_CM_128_HMAC_SHA1_32); + TestUnprotectRtcp(CS_AES_CM_128_HMAC_SHA1_32); +} + +TEST_F(SrtpSessionTest, TestGetSendStreamPacketIndex) { + EXPECT_TRUE(s1_.SetSend(SRTP_AES128_CM_SHA1_32, kTestKey1, kTestKeyLen)); + int64_t index; + int out_len = 0; + EXPECT_TRUE(s1_.ProtectRtp(rtp_packet_, rtp_len_, sizeof(rtp_packet_), + &out_len, &index)); + // |index| will be shifted by 16. + int64_t be64_index = static_cast(NetworkToHost64(1 << 16)); + EXPECT_EQ(be64_index, index); +} + +// Test that we fail to unprotect if someone tampers with the RTP/RTCP paylaods. +TEST_F(SrtpSessionTest, TestTamperReject) { + int out_len; + EXPECT_TRUE(s1_.SetSend(SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen)); + EXPECT_TRUE(s2_.SetRecv(SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen)); + TestProtectRtp(CS_AES_CM_128_HMAC_SHA1_80); + TestProtectRtcp(CS_AES_CM_128_HMAC_SHA1_80); + rtp_packet_[0] = 0x12; + rtcp_packet_[1] = 0x34; + EXPECT_FALSE(s2_.UnprotectRtp(rtp_packet_, rtp_len_, &out_len)); + EXPECT_FALSE(s2_.UnprotectRtcp(rtcp_packet_, rtcp_len_, &out_len)); +} + +// Test that we fail to unprotect if the payloads are not authenticated. +TEST_F(SrtpSessionTest, TestUnencryptReject) { + int out_len; + EXPECT_TRUE(s1_.SetSend(SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen)); + EXPECT_TRUE(s2_.SetRecv(SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen)); + EXPECT_FALSE(s2_.UnprotectRtp(rtp_packet_, rtp_len_, &out_len)); + EXPECT_FALSE(s2_.UnprotectRtcp(rtcp_packet_, rtcp_len_, &out_len)); +} + +// Test that we fail when using buffers that are too small. +TEST_F(SrtpSessionTest, TestBuffersTooSmall) { + int out_len; + EXPECT_TRUE(s1_.SetSend(SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen)); + EXPECT_FALSE(s1_.ProtectRtp(rtp_packet_, rtp_len_, sizeof(rtp_packet_) - 10, + &out_len)); + EXPECT_FALSE(s1_.ProtectRtcp(rtcp_packet_, rtcp_len_, + sizeof(rtcp_packet_) - 14, &out_len)); +} + +TEST_F(SrtpSessionTest, TestReplay) { + static const uint16_t kMaxSeqnum = static_cast(-1); + static const uint16_t seqnum_big = 62275; + static const uint16_t seqnum_small = 10; + static const uint16_t replay_window = 1024; + int out_len; + + EXPECT_TRUE(s1_.SetSend(SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen)); + EXPECT_TRUE(s2_.SetRecv(SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen)); + + // Initial sequence number. + SetBE16(reinterpret_cast(rtp_packet_) + 2, seqnum_big); + EXPECT_TRUE( + s1_.ProtectRtp(rtp_packet_, rtp_len_, sizeof(rtp_packet_), &out_len)); + + // Replay within the 1024 window should succeed. + SetBE16(reinterpret_cast(rtp_packet_) + 2, + seqnum_big - replay_window + 1); + EXPECT_TRUE( + s1_.ProtectRtp(rtp_packet_, rtp_len_, sizeof(rtp_packet_), &out_len)); + + // Replay out side of the 1024 window should fail. + SetBE16(reinterpret_cast(rtp_packet_) + 2, + seqnum_big - replay_window - 1); + EXPECT_FALSE( + s1_.ProtectRtp(rtp_packet_, rtp_len_, sizeof(rtp_packet_), &out_len)); + + // Increment sequence number to a small number. + SetBE16(reinterpret_cast(rtp_packet_) + 2, seqnum_small); + EXPECT_TRUE( + s1_.ProtectRtp(rtp_packet_, rtp_len_, sizeof(rtp_packet_), &out_len)); + + // Replay around 0 but out side of the 1024 window should fail. + SetBE16(reinterpret_cast(rtp_packet_) + 2, + kMaxSeqnum + seqnum_small - replay_window - 1); + EXPECT_FALSE( + s1_.ProtectRtp(rtp_packet_, rtp_len_, sizeof(rtp_packet_), &out_len)); + + // Replay around 0 but within the 1024 window should succeed. + for (uint16_t seqnum = 65000; seqnum < 65003; ++seqnum) { + SetBE16(reinterpret_cast(rtp_packet_) + 2, seqnum); + EXPECT_TRUE( + s1_.ProtectRtp(rtp_packet_, rtp_len_, sizeof(rtp_packet_), &out_len)); + } + + // Go back to normal sequence nubmer. + // NOTE: without the fix in libsrtp, this would fail. This is because + // without the fix, the loop above would keep incrementing local sequence + // number in libsrtp, eventually the new sequence number would go out side + // of the window. + SetBE16(reinterpret_cast(rtp_packet_) + 2, seqnum_small + 1); + EXPECT_TRUE( + s1_.ProtectRtp(rtp_packet_, rtp_len_, sizeof(rtp_packet_), &out_len)); +} + +} // namespace rtc diff --git a/webrtc/pc/srtptestutil.h b/webrtc/pc/srtptestutil.h new file mode 100644 index 0000000000..daf7eb03ad --- /dev/null +++ b/webrtc/pc/srtptestutil.h @@ -0,0 +1,45 @@ +/* + * Copyright 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef WEBRTC_PC_SRTPTESTUTIL_H_ +#define WEBRTC_PC_SRTPTESTUTIL_H_ + +#include + +namespace rtc { + +extern const char CS_AES_CM_128_HMAC_SHA1_32[]; +extern const char CS_AEAD_AES_128_GCM[]; +extern const char CS_AEAD_AES_256_GCM[]; + +static const uint8_t kTestKey1[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZ1234"; +static const uint8_t kTestKey2[] = "4321ZYXWVUTSRQPONMLKJIHGFEDCBA"; +static const int kTestKeyLen = 30; + +static int rtp_auth_tag_len(const std::string& cs) { + if (cs == CS_AES_CM_128_HMAC_SHA1_32) { + return 4; + } else if (cs == CS_AEAD_AES_128_GCM || cs == CS_AEAD_AES_256_GCM) { + return 16; + } else { + return 10; + } +} +static int rtcp_auth_tag_len(const std::string& cs) { + if (cs == CS_AEAD_AES_128_GCM || cs == CS_AEAD_AES_256_GCM) { + return 16; + } else { + return 10; + } +} + +} // namespace rtc + +#endif // WEBRTC_PC_SRTPTESTUTIL_H_