From 37bb54eb2079bacc5003eda1b5b69126a9e3e516 Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Wed, 29 Jun 2016 10:41:00 +0200 Subject: [PATCH] Reland: Remove global list of SRTP sessions. Instead save a reference to the SrtpSession inside the srtp_ctx_t. The original CL was https://codereview.webrtc.org/1416093010 and should be good to reland now that internal projects are using a more recent version of libsrtp. BUG=webrtc:5133 R=mattdr@webrtc.org, pthatcher@webrtc.org Review URL: https://codereview.webrtc.org/2109893002 . Cr-Commit-Position: refs/heads/master@{#13318} --- webrtc/pc/srtpfilter.cc | 70 +++++++++++++++++++---------------------- webrtc/pc/srtpfilter.h | 4 +-- 2 files changed, 35 insertions(+), 39 deletions(-) diff --git a/webrtc/pc/srtpfilter.cc b/webrtc/pc/srtpfilter.cc index e4796fd672..60dd4f1c60 100644 --- a/webrtc/pc/srtpfilter.cc +++ b/webrtc/pc/srtpfilter.cc @@ -16,6 +16,7 @@ #include "webrtc/base/base64.h" #include "webrtc/base/byteorder.h" +#include "webrtc/base/checks.h" #include "webrtc/base/common.h" #include "webrtc/base/logging.h" #include "webrtc/base/stringencode.h" @@ -197,7 +198,7 @@ bool SrtpFilter::ProtectRtp(void* p, int in_len, int max_len, int* out_len) { LOG(LS_WARNING) << "Failed to ProtectRtp: SRTP not active"; return false; } - ASSERT(send_session_ != NULL); + RTC_CHECK(send_session_); return send_session_->ProtectRtp(p, in_len, max_len, out_len); } @@ -210,7 +211,7 @@ bool SrtpFilter::ProtectRtp(void* p, LOG(LS_WARNING) << "Failed to ProtectRtp: SRTP not active"; return false; } - ASSERT(send_session_ != NULL); + RTC_CHECK(send_session_); return send_session_->ProtectRtp(p, in_len, max_len, out_len, index); } @@ -222,7 +223,7 @@ bool SrtpFilter::ProtectRtcp(void* p, int in_len, int max_len, int* out_len) { if (send_rtcp_session_) { return send_rtcp_session_->ProtectRtcp(p, in_len, max_len, out_len); } else { - ASSERT(send_session_ != NULL); + RTC_CHECK(send_session_); return send_session_->ProtectRtcp(p, in_len, max_len, out_len); } } @@ -232,7 +233,7 @@ bool SrtpFilter::UnprotectRtp(void* p, int in_len, int* out_len) { LOG(LS_WARNING) << "Failed to UnprotectRtp: SRTP not active"; return false; } - ASSERT(recv_session_ != NULL); + RTC_CHECK(recv_session_); return recv_session_->UnprotectRtp(p, in_len, out_len); } @@ -244,7 +245,7 @@ bool SrtpFilter::UnprotectRtcp(void* p, int in_len, int* out_len) { if (recv_rtcp_session_) { return recv_rtcp_session_->UnprotectRtcp(p, in_len, out_len); } else { - ASSERT(recv_session_ != NULL); + RTC_CHECK(recv_session_); return recv_session_->UnprotectRtcp(p, in_len, out_len); } } @@ -255,16 +256,16 @@ bool SrtpFilter::GetRtpAuthParams(uint8_t** key, int* key_len, int* tag_len) { return false; } - ASSERT(send_session_ != NULL); + RTC_CHECK(send_session_); return send_session_->GetRtpAuthParams(key, key_len, tag_len); } void SrtpFilter::set_signal_silent_time(int signal_silent_time_in_ms) { signal_silent_time_in_ms_ = signal_silent_time_in_ms; if (IsActive()) { - ASSERT(send_session_ != NULL); + RTC_CHECK(send_session_); send_session_->set_signal_silent_time(signal_silent_time_in_ms); - ASSERT(recv_session_ != NULL); + RTC_CHECK(recv_session_); recv_session_->set_signal_silent_time(signal_silent_time_in_ms); if (send_rtcp_session_) send_rtcp_session_->set_signal_silent_time(signal_silent_time_in_ms); @@ -452,7 +453,7 @@ bool SrtpFilter::ParseKeyParams(const std::string& key_params, // Fail if base64 decode fails, or the key is the wrong size. std::string key_b64(key_params.substr(7)), key_str; if (!rtc::Base64::Decode(key_b64, rtc::Base64::DO_STRICT, - &key_str, NULL) || + &key_str, nullptr) || static_cast(key_str.size()) != len) { return false; } @@ -468,28 +469,21 @@ bool SrtpFilter::ParseKeyParams(const std::string& key_params, bool SrtpSession::inited_ = false; -// This lock protects SrtpSession::inited_ and SrtpSession::sessions_. +// This lock protects SrtpSession::inited_. rtc::GlobalLockPod SrtpSession::lock_; SrtpSession::SrtpSession() - : session_(NULL), + : session_(nullptr), rtp_auth_tag_len_(0), rtcp_auth_tag_len_(0), srtp_stat_(new SrtpStat()), last_send_seq_num_(-1) { - { - rtc::GlobalLockScope ls(&lock_); - sessions()->push_back(this); - } SignalSrtpError.repeat(srtp_stat_->SignalSrtpError); } SrtpSession::~SrtpSession() { - { - rtc::GlobalLockScope ls(&lock_); - sessions()->erase(std::find(sessions()->begin(), sessions()->end(), this)); - } if (session_) { + srtp_set_user_data(session_, nullptr); srtp_dealloc(session_); } } @@ -503,6 +497,7 @@ bool SrtpSession::SetRecv(int cs, const uint8_t* key, int len) { } bool SrtpSession::ProtectRtp(void* p, int in_len, int max_len, int* out_len) { + RTC_DCHECK(thread_checker_.CalledOnValidThread()); if (!session_) { LOG(LS_WARNING) << "Failed to protect SRTP packet: no SRTP Session"; return false; @@ -545,6 +540,7 @@ bool SrtpSession::ProtectRtp(void* p, } bool SrtpSession::ProtectRtcp(void* p, int in_len, int max_len, int* out_len) { + RTC_DCHECK(thread_checker_.CalledOnValidThread()); if (!session_) { LOG(LS_WARNING) << "Failed to protect SRTCP packet: no SRTP Session"; return false; @@ -568,6 +564,7 @@ bool SrtpSession::ProtectRtcp(void* p, int in_len, int max_len, int* out_len) { } bool SrtpSession::UnprotectRtp(void* p, int in_len, int* out_len) { + RTC_DCHECK(thread_checker_.CalledOnValidThread()); if (!session_) { LOG(LS_WARNING) << "Failed to unprotect SRTP packet: no SRTP Session"; return false; @@ -587,6 +584,7 @@ bool SrtpSession::UnprotectRtp(void* p, int in_len, int* out_len) { } bool SrtpSession::UnprotectRtcp(void* p, int in_len, int* out_len) { + RTC_DCHECK(thread_checker_.CalledOnValidThread()); if (!session_) { LOG(LS_WARNING) << "Failed to unprotect SRTCP packet: no SRTP Session"; return false; @@ -604,7 +602,8 @@ bool SrtpSession::UnprotectRtcp(void* p, int in_len, int* out_len) { bool SrtpSession::GetRtpAuthParams(uint8_t** key, int* key_len, int* tag_len) { #if defined(ENABLE_EXTERNAL_AUTH) - ExternalHmacContext* external_hmac = NULL; + RTC_DCHECK(thread_checker_.CalledOnValidThread()); + ExternalHmacContext* external_hmac = nullptr; // stream_template will be the reference context for other streams. // Let's use it for getting the keys. srtp_stream_ctx_t* srtp_context = session_->stream_template; @@ -630,10 +629,12 @@ bool SrtpSession::GetRtpAuthParams(uint8_t** key, int* key_len, int* tag_len) { bool SrtpSession::GetSendStreamPacketIndex(void* p, int in_len, int64_t* index) { + RTC_DCHECK(thread_checker_.CalledOnValidThread()); srtp_hdr_t* hdr = reinterpret_cast(p); srtp_stream_ctx_t* stream = srtp_get_stream(session_, hdr->ssrc); - if (stream == NULL) + if (!stream) { return false; + } // Shift packet index, put into network byte order *index = static_cast( @@ -646,6 +647,7 @@ void SrtpSession::set_signal_silent_time(int signal_silent_time_in_ms) { } bool SrtpSession::SetKey(int type, int cs, const uint8_t* key, int len) { + RTC_DCHECK(thread_checker_.CalledOnValidThread()); if (session_) { LOG(LS_ERROR) << "Failed to create SRTP session: " << "SRTP session already created"; @@ -692,16 +694,16 @@ bool SrtpSession::SetKey(int type, int cs, const uint8_t* key, int len) { policy.rtp.auth_type = EXTERNAL_HMAC_SHA1; } #endif - policy.next = NULL; + policy.next = nullptr; int err = srtp_create(&session_, &policy); if (err != err_status_ok) { - session_ = NULL; + session_ = nullptr; LOG(LS_ERROR) << "Failed to create SRTP session, err=" << err; return false; } - + srtp_set_user_data(session_, this); rtp_auth_tag_len_ = policy.rtp.auth_tag_len; rtcp_auth_tag_len_ = policy.rtcp.auth_tag_len; return true; @@ -750,6 +752,7 @@ void SrtpSession::Terminate() { } void SrtpSession::HandleEvent(const srtp_event_data_t* ev) { + RTC_DCHECK(thread_checker_.CalledOnValidThread()); switch (ev->event) { case event_ssrc_collision: LOG(LS_INFO) << "SRTP event: SSRC collision"; @@ -770,22 +773,15 @@ void SrtpSession::HandleEvent(const srtp_event_data_t* ev) { } void SrtpSession::HandleEventThunk(srtp_event_data_t* ev) { - rtc::GlobalLockScope ls(&lock_); - - for (std::list::iterator it = sessions()->begin(); - it != sessions()->end(); ++it) { - if ((*it)->session_ == ev->session) { - (*it)->HandleEvent(ev); - break; - } + // Callback will be executed from same thread that calls the "srtp_protect" + // and "srtp_unprotect" functions. + SrtpSession* session = static_cast( + srtp_get_user_data(ev->session)); + if (session) { + session->HandleEvent(ev); } } -std::list* SrtpSession::sessions() { - RTC_DEFINE_STATIC_LOCAL(std::list, sessions, ()); - return &sessions; -} - #else // !HAVE_SRTP // On some systems, SRTP is not (yet) available. diff --git a/webrtc/pc/srtpfilter.h b/webrtc/pc/srtpfilter.h index b54eb8bc86..cde9ad7e09 100644 --- a/webrtc/pc/srtpfilter.h +++ b/webrtc/pc/srtpfilter.h @@ -22,6 +22,7 @@ #include "webrtc/base/criticalsection.h" #include "webrtc/base/sigslotrepeater.h" #include "webrtc/base/sslstreamadapter.h" +#include "webrtc/base/thread_checker.h" #include "webrtc/media/base/cryptoparams.h" #include "webrtc/p2p/base/sessiondescription.h" @@ -225,8 +226,7 @@ class SrtpSession { void HandleEvent(const srtp_event_data_t* ev); static void HandleEventThunk(srtp_event_data_t* ev); - static std::list* sessions(); - + rtc::ThreadChecker thread_checker_; srtp_ctx_t* session_; int rtp_auth_tag_len_; int rtcp_auth_tag_len_;