Reland "srtp: spanify Protect + Unprotect"

This is a reland of commit 9572b2fa5850da6d319b9efb5ee36290e2895f7f
that does not remove the legacy implementations yet.

Original change's description:
> srtp: spanify Protect + Unprotect
>
> Makes SrtpSession and SrtpTransport use rtc::CopyOnWriteBuffer for the Protect and Unprotect operations instead of passing around void pointers.
>
> Also updates the unit tests to use CopyOnWriteBuffer instead of char arrays with a fixed length.
>
> BUG=webrtc:357776213
> No-Iwyu: missing include is a private libsrtp header
>
> Change-Id: I02a22ceb4e183e93c4ebd8c0a9c931404e0e32f3
> Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/358442
> Reviewed-by: Henrik Boström <hbos@webrtc.org>
> Reviewed-by: Harald Alvestrand <hta@webrtc.org>
> Commit-Queue: Philipp Hancke <phancke@meta.com>
> Cr-Commit-Position: refs/heads/main@{#43601}

No-Iwyu: missing include is a private libsrtp header
Bug: webrtc:357776213
Change-Id: I93704e27a6c48e015b775712fcd848c8c0c753e5
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/372321
Commit-Queue: Philipp Hancke <phancke@meta.com>
Reviewed-by: Harald Alvestrand <hta@webrtc.org>
Reviewed-by: Henrik Boström <hbos@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#43799}
This commit is contained in:
Philipp Hancke 2025-01-22 15:27:59 -08:00 committed by WebRTC LUCI CQ
parent 4e8c984d15
commit 5090eaf363
6 changed files with 333 additions and 188 deletions

View File

@ -594,6 +594,8 @@ rtc_source_set("srtp_session") {
"../rtc_base:buffer", "../rtc_base:buffer",
"../rtc_base:byte_order", "../rtc_base:byte_order",
"../rtc_base:checks", "../rtc_base:checks",
"../rtc_base:copy_on_write_buffer",
"../rtc_base:ip_address",
"../rtc_base:logging", "../rtc_base:logging",
"../rtc_base:macromagic", "../rtc_base:macromagic",
"../rtc_base:ssl_adapter", "../rtc_base:ssl_adapter",
@ -620,6 +622,8 @@ rtc_source_set("srtp_transport") {
"../api:field_trials_view", "../api:field_trials_view",
"../api:libjingle_peerconnection_api", "../api:libjingle_peerconnection_api",
"../api:rtc_error", "../api:rtc_error",
"../api/units:timestamp",
"../call:rtp_receiver",
"../media:rtp_utils", "../media:rtp_utils",
"../modules/rtp_rtcp:rtp_rtcp_format", "../modules/rtp_rtcp:rtp_rtcp_format",
"../p2p:packet_transport_internal", "../p2p:packet_transport_internal",
@ -633,6 +637,7 @@ rtc_source_set("srtp_transport") {
"../rtc_base:safe_conversions", "../rtc_base:safe_conversions",
"../rtc_base:ssl_adapter", "../rtc_base:ssl_adapter",
"../rtc_base:zero_memory", "../rtc_base:zero_memory",
"../rtc_base/network:received_packet",
"//third_party/abseil-cpp/absl/strings", "//third_party/abseil-cpp/absl/strings",
] ]
} }

View File

@ -12,18 +12,21 @@
#include <string.h> #include <string.h>
#include <cstdint>
#include <cstring>
#include <iomanip> #include <iomanip>
#include <string> #include <vector>
#include "absl/base/attributes.h"
#include "absl/base/const_init.h"
#include "absl/strings/string_view.h" #include "absl/strings/string_view.h"
#include "api/array_view.h" #include "api/array_view.h"
#include "api/field_trials_view.h" #include "api/field_trials_view.h"
#include "modules/rtp_rtcp/source/rtp_util.h" #include "modules/rtp_rtcp/source/rtp_util.h"
#include "pc/external_hmac.h" #include "pc/external_hmac.h"
#include "rtc_base/buffer.h"
#include "rtc_base/byte_order.h" #include "rtc_base/byte_order.h"
#include "rtc_base/checks.h" #include "rtc_base/checks.h"
#include "rtc_base/copy_on_write_buffer.h"
#include "rtc_base/ip_address.h"
#include "rtc_base/logging.h" #include "rtc_base/logging.h"
#include "rtc_base/ssl_stream_adapter.h" #include "rtc_base/ssl_stream_adapter.h"
#include "rtc_base/string_encode.h" #include "rtc_base/string_encode.h"
@ -149,8 +152,6 @@ void LibSrtpInitializer::DecrementLibsrtpUsageCountAndMaybeDeinit() {
} // namespace } // namespace
using ::webrtc::ParseRtpSequenceNumber;
// One more than the maximum libsrtp error code. Required by // One more than the maximum libsrtp error code. Required by
// RTC_HISTOGRAM_ENUMERATION. Keep this in sync with srtp_error_status_t defined // RTC_HISTOGRAM_ENUMERATION. Keep this in sync with srtp_error_status_t defined
// in srtp.h. // in srtp.h.
@ -196,6 +197,42 @@ bool SrtpSession::UpdateReceive(int crypto_suite,
return UpdateKey(ssrc_any_inbound, crypto_suite, key, extension_ids); return UpdateKey(ssrc_any_inbound, crypto_suite, key, extension_ids);
} }
bool SrtpSession::ProtectRtp(rtc::CopyOnWriteBuffer& buffer) {
RTC_DCHECK(thread_checker_.IsCurrent());
if (!session_) {
RTC_LOG(LS_WARNING) << "Failed to protect SRTP packet: no SRTP Session";
return false;
}
// Note: the need_len differs from the libsrtp recommendatіon to ensure
// SRTP_MAX_TRAILER_LEN bytes of free space after the data. WebRTC
// never includes a MKI, therefore the amount of bytes added by the
// srtp_protect call is known in advance and depends on the cipher suite.
size_t need_len = buffer.size() + rtp_auth_tag_len_; // NOLINT
if (buffer.capacity() < need_len) {
RTC_LOG(LS_WARNING) << "Failed to protect SRTP packet: The buffer length "
<< buffer.capacity() << " is less than the needed "
<< need_len;
return false;
}
if (dump_plain_rtp_) {
DumpPacket(buffer, /*outbound=*/true);
}
int out_len = buffer.size();
int err = srtp_protect(session_, buffer.MutableData<char>(), &out_len);
int seq_num = webrtc::ParseRtpSequenceNumber(buffer);
if (err != srtp_err_status_ok) {
RTC_LOG(LS_WARNING) << "Failed to protect SRTP packet, seqnum=" << seq_num
<< ", err=" << err
<< ", last seqnum=" << last_send_seq_num_;
return false;
}
buffer.SetSize(out_len);
last_send_seq_num_ = seq_num;
return true;
}
bool SrtpSession::ProtectRtp(void* p, int in_len, int max_len, int* out_len) { bool SrtpSession::ProtectRtp(void* p, int in_len, int max_len, int* out_len) {
RTC_DCHECK(thread_checker_.IsCurrent()); RTC_DCHECK(thread_checker_.IsCurrent());
if (!session_) { if (!session_) {
@ -219,7 +256,7 @@ bool SrtpSession::ProtectRtp(void* p, int in_len, int max_len, int* out_len) {
*out_len = in_len; *out_len = in_len;
int err = srtp_protect(session_, p, out_len); int err = srtp_protect(session_, p, out_len);
int seq_num = ParseRtpSequenceNumber( int seq_num = webrtc::ParseRtpSequenceNumber(
rtc::MakeArrayView(reinterpret_cast<const uint8_t*>(p), in_len)); rtc::MakeArrayView(reinterpret_cast<const uint8_t*>(p), in_len));
if (err != srtp_err_status_ok) { if (err != srtp_err_status_ok) {
RTC_LOG(LS_WARNING) << "Failed to protect SRTP packet, seqnum=" << seq_num RTC_LOG(LS_WARNING) << "Failed to protect SRTP packet, seqnum=" << seq_num
@ -231,15 +268,57 @@ bool SrtpSession::ProtectRtp(void* p, int in_len, int max_len, int* out_len) {
return true; return true;
} }
bool SrtpSession::ProtectRtp(void* p, bool SrtpSession::ProtectRtp(rtc::CopyOnWriteBuffer& buffer, int64_t* index) {
if (!ProtectRtp(buffer)) {
return false;
}
return (index) ? GetSendStreamPacketIndex(buffer, index) : true;
}
bool SrtpSession::ProtectRtp(void* data,
int in_len, int in_len,
int max_len, int max_len,
int* out_len, int* out_len,
int64_t* index) { int64_t* index) {
if (!ProtectRtp(p, in_len, max_len, out_len)) { rtc::CopyOnWriteBuffer buffer(static_cast<uint8_t*>(data), in_len, max_len);
if (!ProtectRtp(buffer)) {
return false; return false;
} }
return (index) ? GetSendStreamPacketIndex(p, in_len, index) : true; *out_len = buffer.size();
return (index) ? GetSendStreamPacketIndex(buffer, index) : true;
}
bool SrtpSession::ProtectRtcp(rtc::CopyOnWriteBuffer& buffer) {
RTC_DCHECK(thread_checker_.IsCurrent());
if (!session_) {
RTC_LOG(LS_WARNING) << "Failed to protect SRTCP packet: no SRTP Session";
return false;
}
// Note: the need_len differs from the libsrtp recommendatіon to ensure
// SRTP_MAX_TRAILER_LEN bytes of free space after the data. WebRTC
// never includes a MKI, therefore the amount of bytes added by the
// srtp_protect_rtp call is known in advance and depends on the cipher suite.
size_t need_len =
buffer.size() + sizeof(uint32_t) + rtcp_auth_tag_len_; // NOLINT
if (buffer.capacity() < need_len) {
RTC_LOG(LS_WARNING)
<< "Failed to protect SRTCP packet: The buffer capacity "
<< buffer.capacity() << " is less than the needed " << need_len;
return false;
}
if (dump_plain_rtp_) {
DumpPacket(buffer, /*outbound=*/true);
}
int out_len = buffer.size();
int err = srtp_protect_rtcp(session_, buffer.MutableData<char>(), &out_len);
if (err != srtp_err_status_ok) {
RTC_LOG(LS_WARNING) << "Failed to protect SRTCP packet, err=" << err;
return false;
}
buffer.SetSize(out_len);
return true;
} }
bool SrtpSession::ProtectRtcp(void* p, int in_len, int max_len, int* out_len) { bool SrtpSession::ProtectRtcp(void* p, int in_len, int max_len, int* out_len) {
@ -272,6 +351,36 @@ bool SrtpSession::ProtectRtcp(void* p, int in_len, int max_len, int* out_len) {
return true; return true;
} }
bool SrtpSession::UnprotectRtp(rtc::CopyOnWriteBuffer& buffer) {
RTC_DCHECK(thread_checker_.IsCurrent());
if (!session_) {
RTC_LOG(LS_WARNING) << "Failed to unprotect SRTP packet: no SRTP Session";
return false;
}
int out_len = buffer.size();
int err = srtp_unprotect(session_, buffer.MutableData<char>(), &out_len);
if (err != srtp_err_status_ok) {
// Limit the error logging to avoid excessive logs when there are lots of
// bad packets.
const int kFailureLogThrottleCount = 100;
if (decryption_failure_count_ % kFailureLogThrottleCount == 0) {
RTC_LOG(LS_WARNING) << "Failed to unprotect SRTP packet, err=" << err
<< ", previous failure count: "
<< decryption_failure_count_;
}
++decryption_failure_count_;
RTC_HISTOGRAM_ENUMERATION("WebRTC.PeerConnection.SrtpUnprotectError",
static_cast<int>(err), kSrtpErrorCodeBoundary);
return false;
}
buffer.SetSize(out_len);
if (dump_plain_rtp_) {
DumpPacket(buffer, /*outbound=*/false);
}
return true;
}
bool SrtpSession::UnprotectRtp(void* p, int in_len, int* out_len) { bool SrtpSession::UnprotectRtp(void* p, int in_len, int* out_len) {
RTC_DCHECK(thread_checker_.IsCurrent()); RTC_DCHECK(thread_checker_.IsCurrent());
if (!session_) { if (!session_) {
@ -301,6 +410,28 @@ bool SrtpSession::UnprotectRtp(void* p, int in_len, int* out_len) {
return true; return true;
} }
bool SrtpSession::UnprotectRtcp(rtc::CopyOnWriteBuffer& buffer) {
RTC_DCHECK(thread_checker_.IsCurrent());
if (!session_) {
RTC_LOG(LS_WARNING) << "Failed to unprotect SRTCP packet: no SRTP Session";
return false;
}
int out_len = buffer.size();
int err = srtp_unprotect_rtcp(session_, buffer.MutableData<char>(), &out_len);
if (err != srtp_err_status_ok) {
RTC_LOG(LS_WARNING) << "Failed to unprotect SRTCP packet, err=" << err;
RTC_HISTOGRAM_ENUMERATION("WebRTC.PeerConnection.SrtcpUnprotectError",
static_cast<int>(err), kSrtpErrorCodeBoundary);
return false;
}
buffer.SetSize(out_len);
if (dump_plain_rtp_) {
DumpPacket(buffer, /*outbound=*/false);
}
return true;
}
bool SrtpSession::UnprotectRtcp(void* p, int in_len, int* out_len) { bool SrtpSession::UnprotectRtcp(void* p, int in_len, int* out_len) {
RTC_DCHECK(thread_checker_.IsCurrent()); RTC_DCHECK(thread_checker_.IsCurrent());
if (!session_) { if (!session_) {
@ -373,12 +504,12 @@ bool SrtpSession::RemoveSsrcFromSession(uint32_t ssrc) {
return srtp_remove_stream(session_, htonl(ssrc)) == srtp_err_status_ok; return srtp_remove_stream(session_, htonl(ssrc)) == srtp_err_status_ok;
} }
bool SrtpSession::GetSendStreamPacketIndex(void* p, bool SrtpSession::GetSendStreamPacketIndex(rtc::CopyOnWriteBuffer& buffer,
int in_len,
int64_t* index) { int64_t* index) {
RTC_DCHECK(thread_checker_.IsCurrent()); RTC_DCHECK(thread_checker_.IsCurrent());
srtp_hdr_t* hdr = reinterpret_cast<srtp_hdr_t*>(p); // libSRTP expects the SSRC to be in network byte order.
srtp_stream_ctx_t* stream = srtp_get_stream(session_, hdr->ssrc); srtp_stream_ctx_t* stream =
srtp_get_stream(session_, htonl(webrtc::ParseRtpSsrc(buffer)));
if (!stream) { if (!stream) {
return false; return false;
} }
@ -534,24 +665,31 @@ void SrtpSession::HandleEventThunk(srtp_event_data_t* ev) {
// extracted by searching for RTP_DUMP // extracted by searching for RTP_DUMP
// grep RTP_DUMP chrome_debug.log > in.txt // grep RTP_DUMP chrome_debug.log > in.txt
// and converted to pcap using // and converted to pcap using
// text2pcap -D -u 1000,2000 -t %H:%M:%S. in.txt out.pcap // text2pcap -D -u 1000,2000 -t %H:%M:%S.%f in.txt out.pcap
// The resulting file can be replayed using the WebRTC video_replay tool and // The resulting file can be replayed using the WebRTC video_replay tool and
// be inspected in Wireshark using the RTP, VP8 and H264 dissectors. // be inspected in Wireshark using the RTP, VP8 and H264 dissectors.
void SrtpSession::DumpPacket(const void* buf, int len, bool outbound) { void SrtpSession::DumpPacket(const rtc::CopyOnWriteBuffer& buffer,
bool outbound) {
int64_t time_of_day = rtc::TimeUTCMillis() % (24 * 3600 * 1000); int64_t time_of_day = rtc::TimeUTCMillis() % (24 * 3600 * 1000);
int64_t hours = time_of_day / (3600 * 1000); int64_t hours = time_of_day / (3600 * 1000);
int64_t minutes = (time_of_day / (60 * 1000)) % 60; int64_t minutes = (time_of_day / (60 * 1000)) % 60;
int64_t seconds = (time_of_day / 1000) % 60; int64_t seconds = (time_of_day / 1000) % 60;
int64_t millis = time_of_day % 1000; int64_t millis = time_of_day % 1000;
RTC_LOG(LS_VERBOSE) << "\n" RTC_LOG(LS_VERBOSE)
<< (outbound ? "O" : "I") << " " << std::setfill('0') << "\n"
<< std::setw(2) << hours << ":" << std::setfill('0') << (outbound ? "O" : "I") << " " << std::setfill('0') << std::setw(2)
<< std::setw(2) << minutes << ":" << std::setfill('0') << hours << ":" << std::setfill('0') << std::setw(2) << minutes << ":"
<< std::setw(2) << seconds << "." << std::setfill('0') << std::setfill('0') << std::setw(2) << seconds << "."
<< std::setw(3) << millis << " " << "000000 " << std::setfill('0') << std::setw(3) << millis << " " << "000000 "
<< rtc::hex_encode_with_delimiter( << rtc::hex_encode_with_delimiter(
absl::string_view((const char*)buf, len), ' ') absl::string_view(buffer.data<char>(), buffer.size()), ' ')
<< " # RTP_DUMP"; << " # RTP_DUMP";
} }
void SrtpSession::DumpPacket(const void* buf, int len, bool outbound) {
const rtc::CopyOnWriteBuffer buffer(static_cast<const uint8_t*>(buf), len,
len);
DumpPacket(buffer, outbound);
}
} // namespace cricket } // namespace cricket

View File

@ -17,9 +17,9 @@
#include <vector> #include <vector>
#include "api/field_trials_view.h" #include "api/field_trials_view.h"
#include "api/scoped_refptr.h"
#include "api/sequence_checker.h" #include "api/sequence_checker.h"
#include "rtc_base/buffer.h" #include "rtc_base/buffer.h"
#include "rtc_base/copy_on_write_buffer.h"
// Forward declaration to avoid pulling in libsrtp headers here // Forward declaration to avoid pulling in libsrtp headers here
struct srtp_event_data_t; struct srtp_event_data_t;
@ -62,18 +62,34 @@ class SrtpSession {
// Encrypts/signs an individual RTP/RTCP packet, in-place. // Encrypts/signs an individual RTP/RTCP packet, in-place.
// If an HMAC is used, this will increase the packet size. // If an HMAC is used, this will increase the packet size.
bool ProtectRtp(void* data, int in_len, int max_len, int* out_len); [[deprecated("Pass CopyOnWriteBuffer")]] bool ProtectRtp(void* data,
int in_len,
int max_len,
int* out_len);
bool ProtectRtp(rtc::CopyOnWriteBuffer& buffer);
// Overloaded version, outputs packet index. // Overloaded version, outputs packet index.
bool ProtectRtp(void* data, [[deprecated("Pass CopyOnWriteBuffer")]] bool ProtectRtp(void* data,
int in_len, int in_len,
int max_len, int max_len,
int* out_len, int* out_len,
int64_t* index); int64_t* index);
bool ProtectRtcp(void* data, int in_len, int max_len, int* out_len); bool ProtectRtp(rtc::CopyOnWriteBuffer& buffer, int64_t* index);
[[deprecated("Pass CopyOnWriteBuffer")]] bool ProtectRtcp(void* data,
int in_len,
int max_len,
int* out_len);
bool ProtectRtcp(rtc::CopyOnWriteBuffer& buffer);
// Decrypts/verifies an invidiual RTP/RTCP packet. // Decrypts/verifies an invidiual RTP/RTCP packet.
// If an HMAC is used, this will decrease the packet size. // If an HMAC is used, this will decrease the packet size.
bool UnprotectRtp(void* data, int in_len, int* out_len); [[deprecated("Pass CopyOnWriteBuffer")]] bool UnprotectRtp(void* data,
bool UnprotectRtcp(void* data, int in_len, int* out_len); int in_len,
int* out_len);
bool UnprotectRtp(rtc::CopyOnWriteBuffer& buffer);
[[deprecated("Pass CopyOnWriteBuffer")]] bool UnprotectRtcp(void* data,
int in_len,
int* out_len);
bool UnprotectRtcp(rtc::CopyOnWriteBuffer& buffer);
// Helper method to get authentication params. // Helper method to get authentication params.
bool GetRtpAuthParams(uint8_t** key, int* key_len, int* tag_len); bool GetRtpAuthParams(uint8_t** key, int* key_len, int* tag_len);
@ -115,11 +131,14 @@ class SrtpSession {
const rtc::ZeroOnFreeBuffer<uint8_t>& key, const rtc::ZeroOnFreeBuffer<uint8_t>& key,
const std::vector<int>& extension_ids); const std::vector<int>& extension_ids);
// Returns send stream current packet index from srtp db. // Returns send stream current packet index from srtp db.
bool GetSendStreamPacketIndex(void* data, int in_len, int64_t* index); bool GetSendStreamPacketIndex(rtc::CopyOnWriteBuffer& buffer, int64_t* index);
// Writes unencrypted packets in text2pcap format to the log file // Writes unencrypted packets in text2pcap format to the log file
// for debugging. // for debugging.
void DumpPacket(const void* buf, int len, bool outbound); void DumpPacket(const rtc::CopyOnWriteBuffer& buffer, bool outbound);
[[deprecated("Pass CopyOnWriteBuffer")]] void DumpPacket(const void* buf,
int len,
bool outbound);
void HandleEvent(const srtp_event_data_t* ev); void HandleEvent(const srtp_event_data_t* ev);
static void HandleEventThunk(srtp_event_data_t* ev); static void HandleEventThunk(srtp_event_data_t* ev);

View File

@ -12,11 +12,16 @@
#include <string.h> #include <string.h>
#include <string> #include <cstdint>
#include <cstring>
#include <limits>
#include <vector>
#include "media/base/fake_rtp.h" #include "media/base/fake_rtp.h"
#include "pc/test/srtp_test_util.h" #include "pc/test/srtp_test_util.h"
#include "rtc_base/buffer.h"
#include "rtc_base/byte_order.h" #include "rtc_base/byte_order.h"
#include "rtc_base/copy_on_write_buffer.h"
#include "rtc_base/ssl_stream_adapter.h" // For rtc::SRTP_* #include "rtc_base/ssl_stream_adapter.h" // For rtc::SRTP_*
#include "system_wrappers/include/metrics.h" #include "system_wrappers/include/metrics.h"
#include "test/gmock.h" #include "test/gmock.h"
@ -41,45 +46,45 @@ class SrtpSessionTest : public ::testing::Test {
virtual void SetUp() { virtual void SetUp() {
rtp_len_ = sizeof(kPcmuFrame); rtp_len_ = sizeof(kPcmuFrame);
rtcp_len_ = sizeof(kRtcpReport); rtcp_len_ = sizeof(kRtcpReport);
memcpy(rtp_packet_, kPcmuFrame, rtp_len_); rtp_packet_.EnsureCapacity(rtp_len_ + 10);
memcpy(rtcp_packet_, kRtcpReport, rtcp_len_); rtp_packet_.SetData(kPcmuFrame, rtp_len_);
rtcp_packet_.EnsureCapacity(rtcp_len_ + 4 + 10);
rtcp_packet_.SetData(kRtcpReport, rtcp_len_);
} }
void TestProtectRtp(int crypto_suite) { void TestProtectRtp(int crypto_suite) {
int out_len = 0; EXPECT_TRUE(s1_.ProtectRtp(rtp_packet_));
EXPECT_TRUE( EXPECT_EQ(rtp_packet_.size(), rtp_len_ + rtp_auth_tag_len(crypto_suite));
s1_.ProtectRtp(rtp_packet_, rtp_len_, sizeof(rtp_packet_), &out_len)); // Check that Protect changed the content (up to the original length).
EXPECT_EQ(out_len, rtp_len_ + rtp_auth_tag_len(crypto_suite)); EXPECT_NE(0, std::memcmp(kPcmuFrame, rtp_packet_.data(), rtp_len_));
EXPECT_NE(0, memcmp(rtp_packet_, kPcmuFrame, rtp_len_)); rtp_len_ = rtp_packet_.size();
rtp_len_ = out_len;
} }
void TestProtectRtcp(int crypto_suite) { void TestProtectRtcp(int crypto_suite) {
int out_len = 0; EXPECT_TRUE(s1_.ProtectRtcp(rtcp_packet_));
EXPECT_TRUE(s1_.ProtectRtcp(rtcp_packet_, rtcp_len_, sizeof(rtcp_packet_), EXPECT_EQ(rtcp_packet_.size(),
&out_len)); rtcp_len_ + 4 + rtcp_auth_tag_len(crypto_suite));
EXPECT_EQ(out_len, // Check that Protect changed the content (up to the original length).
rtcp_len_ + 4 + rtcp_auth_tag_len(crypto_suite)); // NOLINT EXPECT_NE(0, std::memcmp(kRtcpReport, rtcp_packet_.data(), rtcp_len_));
EXPECT_NE(0, memcmp(rtcp_packet_, kRtcpReport, rtcp_len_)); rtcp_len_ = rtcp_packet_.size();
rtcp_len_ = out_len;
} }
void TestUnprotectRtp(int crypto_suite) { void TestUnprotectRtp(int crypto_suite) {
int out_len = 0, expected_len = sizeof(kPcmuFrame); EXPECT_TRUE(s2_.UnprotectRtp(rtp_packet_));
EXPECT_TRUE(s2_.UnprotectRtp(rtp_packet_, rtp_len_, &out_len)); EXPECT_EQ(rtp_packet_.size(), sizeof(kPcmuFrame));
EXPECT_EQ(expected_len, out_len); EXPECT_EQ(0,
EXPECT_EQ(0, memcmp(rtp_packet_, kPcmuFrame, out_len)); std::memcmp(kPcmuFrame, rtp_packet_.data(), rtp_packet_.size()));
} }
void TestUnprotectRtcp(int crypto_suite) { void TestUnprotectRtcp(int crypto_suite) {
int out_len = 0, expected_len = sizeof(kRtcpReport); EXPECT_TRUE(s2_.UnprotectRtcp(rtcp_packet_));
EXPECT_TRUE(s2_.UnprotectRtcp(rtcp_packet_, rtcp_len_, &out_len)); EXPECT_EQ(rtcp_packet_.size(), sizeof(kRtcpReport));
EXPECT_EQ(expected_len, out_len); EXPECT_EQ(
EXPECT_EQ(0, memcmp(rtcp_packet_, kRtcpReport, out_len)); 0, std::memcmp(kRtcpReport, rtcp_packet_.data(), rtcp_packet_.size()));
} }
webrtc::test::ScopedKeyValueConfig field_trials_; webrtc::test::ScopedKeyValueConfig field_trials_;
cricket::SrtpSession s1_; cricket::SrtpSession s1_;
cricket::SrtpSession s2_; cricket::SrtpSession s2_;
char rtp_packet_[sizeof(kPcmuFrame) + 10]; rtc::CopyOnWriteBuffer rtp_packet_;
char rtcp_packet_[sizeof(kRtcpReport) + 4 + 10]; rtc::CopyOnWriteBuffer rtcp_packet_;
int rtp_len_; size_t rtp_len_;
int rtcp_len_; size_t rtcp_len_;
}; };
// Test that we can set up the session and keys properly. // Test that we can set up the session and keys properly.
@ -140,9 +145,7 @@ TEST_F(SrtpSessionTest, TestGetSendStreamPacketIndex) {
EXPECT_TRUE(s1_.SetSend(kSrtpAes128CmSha1_32, kTestKey1, EXPECT_TRUE(s1_.SetSend(kSrtpAes128CmSha1_32, kTestKey1,
kEncryptedHeaderExtensionIds)); kEncryptedHeaderExtensionIds));
int64_t index; int64_t index;
int out_len = 0; EXPECT_TRUE(s1_.ProtectRtp(rtp_packet_, &index));
EXPECT_TRUE(s1_.ProtectRtp(rtp_packet_, rtp_len_, sizeof(rtp_packet_),
&out_len, &index));
// `index` will be shifted by 16. // `index` will be shifted by 16.
int64_t be64_index = static_cast<int64_t>(NetworkToHost64(1 << 16)); int64_t be64_index = static_cast<int64_t>(NetworkToHost64(1 << 16));
EXPECT_EQ(be64_index, index); EXPECT_EQ(be64_index, index);
@ -150,20 +153,20 @@ TEST_F(SrtpSessionTest, TestGetSendStreamPacketIndex) {
// Test that we fail to unprotect if someone tampers with the RTP/RTCP paylaods. // Test that we fail to unprotect if someone tampers with the RTP/RTCP paylaods.
TEST_F(SrtpSessionTest, TestTamperReject) { TEST_F(SrtpSessionTest, TestTamperReject) {
int out_len;
EXPECT_TRUE(s1_.SetSend(kSrtpAes128CmSha1_80, kTestKey1, EXPECT_TRUE(s1_.SetSend(kSrtpAes128CmSha1_80, kTestKey1,
kEncryptedHeaderExtensionIds)); kEncryptedHeaderExtensionIds));
EXPECT_TRUE(s2_.SetReceive(kSrtpAes128CmSha1_80, kTestKey1, EXPECT_TRUE(s2_.SetReceive(kSrtpAes128CmSha1_80, kTestKey1,
kEncryptedHeaderExtensionIds)); kEncryptedHeaderExtensionIds));
TestProtectRtp(kSrtpAes128CmSha1_80); TestProtectRtp(kSrtpAes128CmSha1_80);
TestProtectRtcp(kSrtpAes128CmSha1_80); rtp_packet_.MutableData<uint8_t>()[0] = 0x12;
rtp_packet_[0] = 0x12; EXPECT_FALSE(s2_.UnprotectRtp(rtp_packet_));
rtcp_packet_[1] = 0x34;
EXPECT_FALSE(s2_.UnprotectRtp(rtp_packet_, rtp_len_, &out_len));
EXPECT_METRIC_THAT( EXPECT_METRIC_THAT(
webrtc::metrics::Samples("WebRTC.PeerConnection.SrtpUnprotectError"), webrtc::metrics::Samples("WebRTC.PeerConnection.SrtpUnprotectError"),
ElementsAre(Pair(srtp_err_status_bad_param, 1))); ElementsAre(Pair(srtp_err_status_bad_param, 1)));
EXPECT_FALSE(s2_.UnprotectRtcp(rtcp_packet_, rtcp_len_, &out_len));
TestProtectRtcp(kSrtpAes128CmSha1_80);
rtcp_packet_.MutableData<uint8_t>()[1] = 0x34;
EXPECT_FALSE(s2_.UnprotectRtcp(rtcp_packet_));
EXPECT_METRIC_THAT( EXPECT_METRIC_THAT(
webrtc::metrics::Samples("WebRTC.PeerConnection.SrtcpUnprotectError"), webrtc::metrics::Samples("WebRTC.PeerConnection.SrtcpUnprotectError"),
ElementsAre(Pair(srtp_err_status_auth_fail, 1))); ElementsAre(Pair(srtp_err_status_auth_fail, 1)));
@ -171,16 +174,15 @@ TEST_F(SrtpSessionTest, TestTamperReject) {
// Test that we fail to unprotect if the payloads are not authenticated. // Test that we fail to unprotect if the payloads are not authenticated.
TEST_F(SrtpSessionTest, TestUnencryptReject) { TEST_F(SrtpSessionTest, TestUnencryptReject) {
int out_len;
EXPECT_TRUE(s1_.SetSend(kSrtpAes128CmSha1_80, kTestKey1, EXPECT_TRUE(s1_.SetSend(kSrtpAes128CmSha1_80, kTestKey1,
kEncryptedHeaderExtensionIds)); kEncryptedHeaderExtensionIds));
EXPECT_TRUE(s2_.SetReceive(kSrtpAes128CmSha1_80, kTestKey1, EXPECT_TRUE(s2_.SetReceive(kSrtpAes128CmSha1_80, kTestKey1,
kEncryptedHeaderExtensionIds)); kEncryptedHeaderExtensionIds));
EXPECT_FALSE(s2_.UnprotectRtp(rtp_packet_, rtp_len_, &out_len)); EXPECT_FALSE(s2_.UnprotectRtp(rtp_packet_));
EXPECT_METRIC_THAT( EXPECT_METRIC_THAT(
webrtc::metrics::Samples("WebRTC.PeerConnection.SrtpUnprotectError"), webrtc::metrics::Samples("WebRTC.PeerConnection.SrtpUnprotectError"),
ElementsAre(Pair(srtp_err_status_auth_fail, 1))); ElementsAre(Pair(srtp_err_status_auth_fail, 1)));
EXPECT_FALSE(s2_.UnprotectRtcp(rtcp_packet_, rtcp_len_, &out_len)); EXPECT_FALSE(s2_.UnprotectRtcp(rtcp_packet_));
EXPECT_METRIC_THAT( EXPECT_METRIC_THAT(
webrtc::metrics::Samples("WebRTC.PeerConnection.SrtcpUnprotectError"), webrtc::metrics::Samples("WebRTC.PeerConnection.SrtcpUnprotectError"),
ElementsAre(Pair(srtp_err_status_cant_check, 1))); ElementsAre(Pair(srtp_err_status_cant_check, 1)));
@ -188,21 +190,23 @@ TEST_F(SrtpSessionTest, TestUnencryptReject) {
// Test that we fail when using buffers that are too small. // Test that we fail when using buffers that are too small.
TEST_F(SrtpSessionTest, TestBuffersTooSmall) { TEST_F(SrtpSessionTest, TestBuffersTooSmall) {
int out_len;
EXPECT_TRUE(s1_.SetSend(kSrtpAes128CmSha1_80, kTestKey1, EXPECT_TRUE(s1_.SetSend(kSrtpAes128CmSha1_80, kTestKey1,
kEncryptedHeaderExtensionIds)); kEncryptedHeaderExtensionIds));
EXPECT_FALSE(s1_.ProtectRtp(rtp_packet_, rtp_len_, sizeof(rtp_packet_) - 10, // This buffer does not have extra capacity which we treat as an error.
&out_len)); rtc::CopyOnWriteBuffer rtp_packet(rtp_packet_.data(), rtp_packet_.size(),
EXPECT_FALSE(s1_.ProtectRtcp(rtcp_packet_, rtcp_len_, rtp_packet_.size());
sizeof(rtcp_packet_) - 14, &out_len)); EXPECT_FALSE(s1_.ProtectRtp(rtp_packet));
// This buffer does not have extra capacity which we treat as an error.
rtc::CopyOnWriteBuffer rtcp_packet(rtcp_packet_.data(), rtcp_packet_.size(),
rtcp_packet_.size());
EXPECT_FALSE(s1_.ProtectRtcp(rtcp_packet));
} }
TEST_F(SrtpSessionTest, TestReplay) { TEST_F(SrtpSessionTest, TestReplay) {
static const uint16_t kMaxSeqnum = static_cast<uint16_t>(-1); static const uint16_t kMaxSeqnum = std::numeric_limits<uint16_t>::max() - 1;
static const uint16_t seqnum_big = 62275; static const uint16_t seqnum_big = 62275;
static const uint16_t seqnum_small = 10; static const uint16_t seqnum_small = 10;
static const uint16_t replay_window = 1024; static const uint16_t replay_window = 1024;
int out_len;
EXPECT_TRUE(s1_.SetSend(kSrtpAes128CmSha1_80, kTestKey1, EXPECT_TRUE(s1_.SetSend(kSrtpAes128CmSha1_80, kTestKey1,
kEncryptedHeaderExtensionIds)); kEncryptedHeaderExtensionIds));
@ -210,38 +214,37 @@ TEST_F(SrtpSessionTest, TestReplay) {
kEncryptedHeaderExtensionIds)); kEncryptedHeaderExtensionIds));
// Initial sequence number. // Initial sequence number.
SetBE16(reinterpret_cast<uint8_t*>(rtp_packet_) + 2, seqnum_big); SetBE16(rtp_packet_.MutableData<uint8_t>() + 2, seqnum_big);
EXPECT_TRUE( EXPECT_TRUE(s1_.ProtectRtp(rtp_packet_));
s1_.ProtectRtp(rtp_packet_, rtp_len_, sizeof(rtp_packet_), &out_len)); rtp_packet_.SetData(kPcmuFrame, sizeof(kPcmuFrame));
// Replay within the 1024 window should succeed. // Replay within the 1024 window should succeed.
SetBE16(reinterpret_cast<uint8_t*>(rtp_packet_) + 2, SetBE16(rtp_packet_.MutableData<uint8_t>() + 2,
seqnum_big - replay_window + 1); seqnum_big - replay_window + 1);
EXPECT_TRUE( EXPECT_TRUE(s1_.ProtectRtp(rtp_packet_));
s1_.ProtectRtp(rtp_packet_, rtp_len_, sizeof(rtp_packet_), &out_len)); rtp_packet_.SetData(kPcmuFrame, sizeof(kPcmuFrame));
// Replay out side of the 1024 window should fail. // Replay out side of the 1024 window should fail.
SetBE16(reinterpret_cast<uint8_t*>(rtp_packet_) + 2, SetBE16(rtp_packet_.MutableData<uint8_t>() + 2,
seqnum_big - replay_window - 1); seqnum_big - replay_window - 1);
EXPECT_FALSE( EXPECT_FALSE(s1_.ProtectRtp(rtp_packet_));
s1_.ProtectRtp(rtp_packet_, rtp_len_, sizeof(rtp_packet_), &out_len)); rtp_packet_.SetData(kPcmuFrame, sizeof(kPcmuFrame));
// Increment sequence number to a small number. // Increment sequence number to a small number.
SetBE16(reinterpret_cast<uint8_t*>(rtp_packet_) + 2, seqnum_small); SetBE16(rtp_packet_.MutableData<uint8_t>() + 2, seqnum_small);
EXPECT_TRUE( EXPECT_TRUE(s1_.ProtectRtp(rtp_packet_));
s1_.ProtectRtp(rtp_packet_, rtp_len_, sizeof(rtp_packet_), &out_len));
// Replay around 0 but out side of the 1024 window should fail. // Replay around 0 but out side of the 1024 window should fail.
SetBE16(reinterpret_cast<uint8_t*>(rtp_packet_) + 2, SetBE16(rtp_packet_.MutableData<uint8_t>() + 2,
kMaxSeqnum + seqnum_small - replay_window - 1); kMaxSeqnum + seqnum_small - replay_window - 1);
EXPECT_FALSE( EXPECT_FALSE(s1_.ProtectRtp(rtp_packet_));
s1_.ProtectRtp(rtp_packet_, rtp_len_, sizeof(rtp_packet_), &out_len)); rtp_packet_.SetData(kPcmuFrame, sizeof(kPcmuFrame));
// Replay around 0 but within the 1024 window should succeed. // Replay around 0 but within the 1024 window should succeed.
for (uint16_t seqnum = 65000; seqnum < 65003; ++seqnum) { for (uint16_t seqnum = 65000; seqnum < 65003; ++seqnum) {
SetBE16(reinterpret_cast<uint8_t*>(rtp_packet_) + 2, seqnum); SetBE16(rtp_packet_.MutableData<uint8_t>() + 2, seqnum);
EXPECT_TRUE( EXPECT_TRUE(s1_.ProtectRtp(rtp_packet_));
s1_.ProtectRtp(rtp_packet_, rtp_len_, sizeof(rtp_packet_), &out_len)); rtp_packet_.SetData(kPcmuFrame, sizeof(kPcmuFrame));
} }
// Go back to normal sequence nubmer. // Go back to normal sequence nubmer.
@ -249,9 +252,8 @@ TEST_F(SrtpSessionTest, TestReplay) {
// without the fix, the loop above would keep incrementing local sequence // without the fix, the loop above would keep incrementing local sequence
// number in libsrtp, eventually the new sequence number would go out side // number in libsrtp, eventually the new sequence number would go out side
// of the window. // of the window.
SetBE16(reinterpret_cast<uint8_t*>(rtp_packet_) + 2, seqnum_small + 1); SetBE16(rtp_packet_.MutableData<uint8_t>() + 2, seqnum_small + 1);
EXPECT_TRUE( EXPECT_TRUE(s1_.ProtectRtp(rtp_packet_));
s1_.ProtectRtp(rtp_packet_, rtp_len_, sizeof(rtp_packet_), &out_len));
} }
TEST_F(SrtpSessionTest, RemoveSsrc) { TEST_F(SrtpSessionTest, RemoveSsrc) {
@ -259,33 +261,32 @@ TEST_F(SrtpSessionTest, RemoveSsrc) {
kEncryptedHeaderExtensionIds)); kEncryptedHeaderExtensionIds));
EXPECT_TRUE(s2_.SetReceive(kSrtpAes128CmSha1_80, kTestKey1, EXPECT_TRUE(s2_.SetReceive(kSrtpAes128CmSha1_80, kTestKey1,
kEncryptedHeaderExtensionIds)); kEncryptedHeaderExtensionIds));
int out_len;
// Encrypt and decrypt the packet once. // Encrypt and decrypt the packet once.
EXPECT_TRUE( EXPECT_TRUE(s1_.ProtectRtp(rtp_packet_));
s1_.ProtectRtp(rtp_packet_, rtp_len_, sizeof(rtp_packet_), &out_len)); EXPECT_TRUE(s2_.UnprotectRtp(rtp_packet_));
EXPECT_TRUE(s2_.UnprotectRtp(rtp_packet_, out_len, &out_len)); EXPECT_EQ(sizeof(kPcmuFrame), rtp_packet_.size());
EXPECT_EQ(rtp_len_, out_len); EXPECT_EQ(0, std::memcmp(kPcmuFrame, rtp_packet_.data(), rtp_packet_.size()));
EXPECT_EQ(0, memcmp(rtp_packet_, kPcmuFrame, out_len));
// Recreate the original packet and encrypt again. // Recreate the original packet and encrypt again.
memcpy(rtp_packet_, kPcmuFrame, rtp_len_); rtp_packet_.SetData(kPcmuFrame, sizeof(kPcmuFrame));
EXPECT_TRUE( EXPECT_TRUE(s1_.ProtectRtp(rtp_packet_));
s1_.ProtectRtp(rtp_packet_, rtp_len_, sizeof(rtp_packet_), &out_len));
// Attempting to decrypt will fail as a replay attack. // Attempting to decrypt will fail as a replay attack.
// (srtp_err_status_replay_fail) since the sequence number was already seen. // (srtp_err_status_replay_fail) since the sequence number was already seen.
EXPECT_FALSE(s2_.UnprotectRtp(rtp_packet_, out_len, &out_len)); EXPECT_FALSE(s2_.UnprotectRtp(rtp_packet_));
// Remove the fake packet SSRC 1 from the session. // Remove the fake packet SSRC 1 from the session.
EXPECT_TRUE(s2_.RemoveSsrcFromSession(1)); EXPECT_TRUE(s2_.RemoveSsrcFromSession(1));
EXPECT_FALSE(s2_.RemoveSsrcFromSession(1)); EXPECT_FALSE(s2_.RemoveSsrcFromSession(1));
// Since the SRTP state was discarded, this is no longer a replay attack. // Since the SRTP state was discarded, this is no longer a replay attack.
EXPECT_TRUE(s2_.UnprotectRtp(rtp_packet_, out_len, &out_len)); EXPECT_TRUE(s2_.UnprotectRtp(rtp_packet_));
EXPECT_EQ(rtp_len_, out_len); EXPECT_EQ(sizeof(kPcmuFrame), rtp_packet_.size());
EXPECT_EQ(0, memcmp(rtp_packet_, kPcmuFrame, out_len)); EXPECT_EQ(0, std::memcmp(kPcmuFrame, rtp_packet_.data(), rtp_packet_.size()));
EXPECT_TRUE(s2_.RemoveSsrcFromSession(1)); EXPECT_TRUE(s2_.RemoveSsrcFromSession(1));
} }
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wdeprecated-declarations"
TEST_F(SrtpSessionTest, ProtectUnprotectWrapAroundRocMismatch) { TEST_F(SrtpSessionTest, ProtectUnprotectWrapAroundRocMismatch) {
// This unit tests demonstrates why you should be careful when // This unit tests demonstrates why you should be careful when
// choosing the initial RTP sequence number as there can be decryption // choosing the initial RTP sequence number as there can be decryption
@ -316,6 +317,7 @@ TEST_F(SrtpSessionTest, ProtectUnprotectWrapAroundRocMismatch) {
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
// clang-format on // clang-format on
}; };
const unsigned char kPayload[] = {0xBE, 0xEF};
int out_len; int out_len;
// Encrypt the frames in-order. There is a sequence number rollover from // Encrypt the frames in-order. There is a sequence number rollover from
@ -337,9 +339,12 @@ TEST_F(SrtpSessionTest, ProtectUnprotectWrapAroundRocMismatch) {
EXPECT_FALSE(s2_.UnprotectRtp(kFrame2, sizeof(kFrame2), &out_len)); EXPECT_FALSE(s2_.UnprotectRtp(kFrame2, sizeof(kFrame2), &out_len));
// Decrypt frame 1. // Decrypt frame 1.
EXPECT_TRUE(s2_.UnprotectRtp(kFrame1, sizeof(kFrame1), &out_len)); EXPECT_TRUE(s2_.UnprotectRtp(kFrame1, sizeof(kFrame1), &out_len));
EXPECT_EQ(0, std::memcmp(kFrame1 + 12, kPayload, sizeof(kPayload)));
// Now decrypt frame 2 again. A rollover is detected which increases // Now decrypt frame 2 again. A rollover is detected which increases
// the ROC to 1 so this succeeds. // the ROC to 1 so this succeeds.
EXPECT_TRUE(s2_.UnprotectRtp(kFrame2, sizeof(kFrame2), &out_len)); EXPECT_TRUE(s2_.UnprotectRtp(kFrame2, sizeof(kFrame2), &out_len));
EXPECT_EQ(0, std::memcmp(kFrame2 + 12, kPayload, sizeof(kPayload)));
} }
#pragma clang diagnostic pop
} // namespace rtc } // namespace rtc

View File

@ -10,25 +10,26 @@
#include "pc/srtp_transport.h" #include "pc/srtp_transport.h"
#include <string.h> #include <cstdint>
#include <optional>
#include <string>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "absl/strings/match.h" #include "api/field_trials_view.h"
#include "api/units/timestamp.h"
#include "call/rtp_demuxer.h"
#include "media/base/rtp_utils.h" #include "media/base/rtp_utils.h"
#include "modules/rtp_rtcp/source/rtp_util.h" #include "modules/rtp_rtcp/source/rtp_util.h"
#include "pc/rtp_transport.h" #include "pc/rtp_transport.h"
#include "pc/srtp_session.h" #include "pc/srtp_session.h"
#include "rtc_base/async_packet_socket.h" #include "rtc_base/async_packet_socket.h"
#include "rtc_base/buffer.h"
#include "rtc_base/checks.h" #include "rtc_base/checks.h"
#include "rtc_base/copy_on_write_buffer.h" #include "rtc_base/copy_on_write_buffer.h"
#include "rtc_base/logging.h" #include "rtc_base/logging.h"
#include "rtc_base/numerics/safe_conversions.h" #include "rtc_base/network/received_packet.h"
#include "rtc_base/ssl_stream_adapter.h" #include "rtc_base/network_route.h"
#include "rtc_base/trace_event.h" #include "rtc_base/trace_event.h"
#include "rtc_base/zero_memory.h"
namespace webrtc { namespace webrtc {
@ -40,6 +41,7 @@ SrtpTransport::SrtpTransport(bool rtcp_mux_enabled,
bool SrtpTransport::SendRtpPacket(rtc::CopyOnWriteBuffer* packet, bool SrtpTransport::SendRtpPacket(rtc::CopyOnWriteBuffer* packet,
const rtc::PacketOptions& options, const rtc::PacketOptions& options,
int flags) { int flags) {
RTC_DCHECK(packet);
if (!IsSrtpActive()) { if (!IsSrtpActive()) {
RTC_LOG(LS_ERROR) RTC_LOG(LS_ERROR)
<< "Failed to send the packet because SRTP transport is inactive."; << "Failed to send the packet because SRTP transport is inactive.";
@ -47,23 +49,21 @@ bool SrtpTransport::SendRtpPacket(rtc::CopyOnWriteBuffer* packet,
} }
rtc::PacketOptions updated_options = options; rtc::PacketOptions updated_options = options;
TRACE_EVENT0("webrtc", "SRTP Encode"); TRACE_EVENT0("webrtc", "SRTP Encode");
// If ENABLE_EXTERNAL_AUTH flag is on then packet authentication is not done
// inside libsrtp for a RTP packet. A external HMAC module will be writing
// a fake HMAC value. This is ONLY done for a RTP packet.
// Socket layer will update rtp sendtime extension header if present in
// packet with current time before updating the HMAC.
bool res; bool res;
uint8_t* data = packet->MutableData();
int len = rtc::checked_cast<int>(packet->size());
// If ENABLE_EXTERNAL_AUTH flag is on then packet authentication is not done
// inside libsrtp for a RTP packet. A external HMAC module will be writing
// a fake HMAC value. This is ONLY done for a RTP packet.
// Socket layer will update rtp sendtime extension header if present in
// packet with current time before updating the HMAC.
#if !defined(ENABLE_EXTERNAL_AUTH) #if !defined(ENABLE_EXTERNAL_AUTH)
res = ProtectRtp(data, len, static_cast<int>(packet->capacity()), &len); res = ProtectRtp(*packet);
#else #else
if (!IsExternalAuthActive()) { if (!IsExternalAuthActive()) {
res = ProtectRtp(data, len, static_cast<int>(packet->capacity()), &len); res = ProtectRtp(*packet);
} else { } else {
updated_options.packet_time_params.rtp_sendtime_extension_id = updated_options.packet_time_params.rtp_sendtime_extension_id =
rtp_abs_sendtime_extn_id_; rtp_abs_sendtime_extn_id_;
res = ProtectRtp(data, len, static_cast<int>(packet->capacity()), &len, res = ProtectRtp(*packet,
&updated_options.packet_time_params.srtp_packet_index); &updated_options.packet_time_params.srtp_packet_index);
// If protection succeeds, let's get auth params from srtp. // If protection succeeds, let's get auth params from srtp.
if (res) { if (res) {
@ -83,19 +83,18 @@ bool SrtpTransport::SendRtpPacket(rtc::CopyOnWriteBuffer* packet,
if (!res) { if (!res) {
uint16_t seq_num = ParseRtpSequenceNumber(*packet); uint16_t seq_num = ParseRtpSequenceNumber(*packet);
uint32_t ssrc = ParseRtpSsrc(*packet); uint32_t ssrc = ParseRtpSsrc(*packet);
RTC_LOG(LS_ERROR) << "Failed to protect RTP packet: size=" << len RTC_LOG(LS_ERROR) << "Failed to protect RTP packet: size=" << packet->size()
<< ", seqnum=" << seq_num << ", SSRC=" << ssrc; << ", seqnum=" << seq_num << ", SSRC=" << ssrc;
return false; return false;
} }
// Update the length of the packet now that we've added the auth tag.
packet->SetSize(len);
return SendPacket(/*rtcp=*/false, packet, updated_options, flags); return SendPacket(/*rtcp=*/false, packet, updated_options, flags);
} }
bool SrtpTransport::SendRtcpPacket(rtc::CopyOnWriteBuffer* packet, bool SrtpTransport::SendRtcpPacket(rtc::CopyOnWriteBuffer* packet,
const rtc::PacketOptions& options, const rtc::PacketOptions& options,
int flags) { int flags) {
RTC_DCHECK(packet);
if (!IsSrtpActive()) { if (!IsSrtpActive()) {
RTC_LOG(LS_ERROR) RTC_LOG(LS_ERROR)
<< "Failed to send the packet because SRTP transport is inactive."; << "Failed to send the packet because SRTP transport is inactive.";
@ -103,17 +102,13 @@ bool SrtpTransport::SendRtcpPacket(rtc::CopyOnWriteBuffer* packet,
} }
TRACE_EVENT0("webrtc", "SRTP Encode"); TRACE_EVENT0("webrtc", "SRTP Encode");
uint8_t* data = packet->MutableData(); if (!ProtectRtcp(*packet)) {
int len = rtc::checked_cast<int>(packet->size());
if (!ProtectRtcp(data, len, static_cast<int>(packet->capacity()), &len)) {
int type = -1; int type = -1;
cricket::GetRtcpType(data, len, &type); cricket::GetRtcpType(packet->data(), packet->size(), &type);
RTC_LOG(LS_ERROR) << "Failed to protect RTCP packet: size=" << len RTC_LOG(LS_ERROR) << "Failed to protect RTCP packet: size="
<< ", type=" << type; << packet->size() << ", type=" << type;
return false; return false;
} }
// Update the length of the packet now that we've added the auth tag.
packet->SetSize(len);
return SendPacket(/*rtcp=*/true, packet, options, flags); return SendPacket(/*rtcp=*/true, packet, options, flags);
} }
@ -127,14 +122,13 @@ void SrtpTransport::OnRtpPacketReceived(const rtc::ReceivedPacket& packet) {
} }
rtc::CopyOnWriteBuffer payload(packet.payload()); rtc::CopyOnWriteBuffer payload(packet.payload());
char* data = payload.MutableData<char>(); if (!UnprotectRtp(payload)) {
int len = rtc::checked_cast<int>(payload.size());
if (!UnprotectRtp(data, len, &len)) {
// Limit the error logging to avoid excessive logs when there are lots of // Limit the error logging to avoid excessive logs when there are lots of
// bad packets. // bad packets.
const int kFailureLogThrottleCount = 100; const int kFailureLogThrottleCount = 100;
if (decryption_failure_count_ % kFailureLogThrottleCount == 0) { if (decryption_failure_count_ % kFailureLogThrottleCount == 0) {
RTC_LOG(LS_ERROR) << "Failed to unprotect RTP packet: size=" << len RTC_LOG(LS_ERROR) << "Failed to unprotect RTP packet: size="
<< payload.size()
<< ", seqnum=" << ParseRtpSequenceNumber(payload) << ", seqnum=" << ParseRtpSequenceNumber(payload)
<< ", SSRC=" << ParseRtpSsrc(payload) << ", SSRC=" << ParseRtpSsrc(payload)
<< ", previous failure count: " << ", previous failure count: "
@ -143,7 +137,6 @@ void SrtpTransport::OnRtpPacketReceived(const rtc::ReceivedPacket& packet) {
++decryption_failure_count_; ++decryption_failure_count_;
return; return;
} }
payload.SetSize(len);
DemuxPacket(std::move(payload), DemuxPacket(std::move(payload),
packet.arrival_time().value_or(Timestamp::MinusInfinity()), packet.arrival_time().value_or(Timestamp::MinusInfinity()),
packet.ecn()); packet.ecn());
@ -157,16 +150,13 @@ void SrtpTransport::OnRtcpPacketReceived(const rtc::ReceivedPacket& packet) {
return; return;
} }
rtc::CopyOnWriteBuffer payload(packet.payload()); rtc::CopyOnWriteBuffer payload(packet.payload());
char* data = payload.MutableData<char>(); if (!UnprotectRtcp(payload)) {
int len = rtc::checked_cast<int>(payload.size());
if (!UnprotectRtcp(data, len, &len)) {
int type = -1; int type = -1;
cricket::GetRtcpType(data, len, &type); cricket::GetRtcpType(payload.data(), payload.size(), &type);
RTC_LOG(LS_ERROR) << "Failed to unprotect RTCP packet: size=" << len RTC_LOG(LS_ERROR) << "Failed to unprotect RTCP packet: size="
<< ", type=" << type; << payload.size() << ", type=" << type;
return; return;
} }
payload.SetSize(len);
SendRtcpPacketReceived( SendRtcpPacketReceived(
&payload, packet.arrival_time() ? packet.arrival_time()->us() : -1); &payload, packet.arrival_time() ? packet.arrival_time()->us() : -1);
} }
@ -291,63 +281,56 @@ void SrtpTransport::CreateSrtpSessions() {
} }
} }
bool SrtpTransport::ProtectRtp(void* p, int in_len, int max_len, int* out_len) { bool SrtpTransport::ProtectRtp(rtc::CopyOnWriteBuffer& buffer) {
if (!IsSrtpActive()) { if (!IsSrtpActive()) {
RTC_LOG(LS_WARNING) << "Failed to ProtectRtp: SRTP not active"; RTC_LOG(LS_WARNING) << "Failed to ProtectRtp: SRTP not active";
return false; return false;
} }
RTC_CHECK(send_session_); RTC_CHECK(send_session_);
return send_session_->ProtectRtp(p, in_len, max_len, out_len); return send_session_->ProtectRtp(buffer);
} }
bool SrtpTransport::ProtectRtp(void* p, bool SrtpTransport::ProtectRtp(rtc::CopyOnWriteBuffer& buffer, int64_t* index) {
int in_len,
int max_len,
int* out_len,
int64_t* index) {
if (!IsSrtpActive()) { if (!IsSrtpActive()) {
RTC_LOG(LS_WARNING) << "Failed to ProtectRtp: SRTP not active"; RTC_LOG(LS_WARNING) << "Failed to ProtectRtp: SRTP not active";
return false; return false;
} }
RTC_CHECK(send_session_); RTC_CHECK(send_session_);
return send_session_->ProtectRtp(p, in_len, max_len, out_len, index); return send_session_->ProtectRtp(buffer, index);
} }
bool SrtpTransport::ProtectRtcp(void* p, bool SrtpTransport::ProtectRtcp(rtc::CopyOnWriteBuffer& buffer) {
int in_len,
int max_len,
int* out_len) {
if (!IsSrtpActive()) { if (!IsSrtpActive()) {
RTC_LOG(LS_WARNING) << "Failed to ProtectRtcp: SRTP not active"; RTC_LOG(LS_WARNING) << "Failed to ProtectRtcp: SRTP not active";
return false; return false;
} }
if (send_rtcp_session_) { if (send_rtcp_session_) {
return send_rtcp_session_->ProtectRtcp(p, in_len, max_len, out_len); return send_rtcp_session_->ProtectRtcp(buffer);
} else { } else {
RTC_CHECK(send_session_); RTC_CHECK(send_session_);
return send_session_->ProtectRtcp(p, in_len, max_len, out_len); return send_session_->ProtectRtcp(buffer);
} }
} }
bool SrtpTransport::UnprotectRtp(void* p, int in_len, int* out_len) { bool SrtpTransport::UnprotectRtp(rtc::CopyOnWriteBuffer& buffer) {
if (!IsSrtpActive()) { if (!IsSrtpActive()) {
RTC_LOG(LS_WARNING) << "Failed to UnprotectRtp: SRTP not active"; RTC_LOG(LS_WARNING) << "Failed to UnprotectRtp: SRTP not active";
return false; return false;
} }
RTC_CHECK(recv_session_); RTC_CHECK(recv_session_);
return recv_session_->UnprotectRtp(p, in_len, out_len); return recv_session_->UnprotectRtp(buffer);
} }
bool SrtpTransport::UnprotectRtcp(void* p, int in_len, int* out_len) { bool SrtpTransport::UnprotectRtcp(rtc::CopyOnWriteBuffer& buffer) {
if (!IsSrtpActive()) { if (!IsSrtpActive()) {
RTC_LOG(LS_WARNING) << "Failed to UnprotectRtcp: SRTP not active"; RTC_LOG(LS_WARNING) << "Failed to UnprotectRtcp: SRTP not active";
return false; return false;
} }
if (recv_rtcp_session_) { if (recv_rtcp_session_) {
return recv_rtcp_session_->UnprotectRtcp(p, in_len, out_len); return recv_rtcp_session_->UnprotectRtcp(buffer);
} else { } else {
RTC_CHECK(recv_session_); RTC_CHECK(recv_session_);
return recv_session_->UnprotectRtcp(p, in_len, out_len); return recv_session_->UnprotectRtcp(buffer);
} }
} }

View File

@ -20,13 +20,14 @@
#include <vector> #include <vector>
#include "api/field_trials_view.h" #include "api/field_trials_view.h"
#include "api/rtc_error.h" #include "call/rtp_demuxer.h"
#include "p2p/base/packet_transport_internal.h" #include "p2p/base/packet_transport_internal.h"
#include "pc/rtp_transport.h" #include "pc/rtp_transport.h"
#include "pc/srtp_session.h" #include "pc/srtp_session.h"
#include "rtc_base/async_packet_socket.h" #include "rtc_base/async_packet_socket.h"
#include "rtc_base/buffer.h" #include "rtc_base/buffer.h"
#include "rtc_base/copy_on_write_buffer.h" #include "rtc_base/copy_on_write_buffer.h"
#include "rtc_base/network/received_packet.h"
#include "rtc_base/network_route.h" #include "rtc_base/network_route.h"
namespace webrtc { namespace webrtc {
@ -121,21 +122,15 @@ class SrtpTransport : public RtpTransport {
// Override the RtpTransport::OnWritableState. // Override the RtpTransport::OnWritableState.
void OnWritableState(rtc::PacketTransportInternal* packet_transport) override; void OnWritableState(rtc::PacketTransportInternal* packet_transport) override;
bool ProtectRtp(void* data, int in_len, int max_len, int* out_len); bool ProtectRtp(rtc::CopyOnWriteBuffer& buffer);
// Overloaded version, outputs packet index. // Overloaded version, outputs packet index.
bool ProtectRtp(void* data, bool ProtectRtp(rtc::CopyOnWriteBuffer& buffer, int64_t* index);
int in_len, bool ProtectRtcp(rtc::CopyOnWriteBuffer& buffer);
int max_len,
int* out_len,
int64_t* index);
bool ProtectRtcp(void* data, int in_len, int max_len, int* out_len);
// Decrypts/verifies an invidiual RTP/RTCP packet. // Decrypts/verifies an invidiual RTP/RTCP packet.
// If an HMAC is used, this will decrease the packet size. // If an HMAC is used, this will decrease the packet size.
bool UnprotectRtp(void* data, int in_len, int* out_len); bool UnprotectRtp(rtc::CopyOnWriteBuffer& buffer);
bool UnprotectRtcp(rtc::CopyOnWriteBuffer& buffer);
bool UnprotectRtcp(void* data, int in_len, int* out_len);
const std::string content_name_; const std::string content_name_;