From 9572b2fa5850da6d319b9efb5ee36290e2895f7f Mon Sep 17 00:00:00 2001 From: Philipp Hancke Date: Mon, 16 Dec 2024 10:13:11 -0800 Subject: [PATCH] srtp: spanify Protect + Unprotect MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Makes SrtpSession and SrtpTransport use rtc::CopyOnWriteBuffer for the Protect and Unprotect operations instead of passing around void pointers. Also updates the unit tests to use CopyOnWriteBuffer instead of char arrays with a fixed length. BUG=webrtc:357776213 No-Iwyu: missing include is a private libsrtp header Change-Id: I02a22ceb4e183e93c4ebd8c0a9c931404e0e32f3 Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/358442 Reviewed-by: Henrik Boström Reviewed-by: Harald Alvestrand Commit-Queue: Philipp Hancke Cr-Commit-Position: refs/heads/main@{#43601} --- pc/BUILD.gn | 5 ++ pc/srtp_session.cc | 164 ++++++++++++++++++++++++----------- pc/srtp_session.h | 40 ++++++--- pc/srtp_session_unittest.cc | 167 +++++++++++++++++++----------------- pc/srtp_transport.cc | 101 +++++++++------------- pc/srtp_transport.h | 19 ++-- 6 files changed, 283 insertions(+), 213 deletions(-) diff --git a/pc/BUILD.gn b/pc/BUILD.gn index e13a04628c..d33b3f492f 100644 --- a/pc/BUILD.gn +++ b/pc/BUILD.gn @@ -594,6 +594,8 @@ rtc_source_set("srtp_session") { "../rtc_base:buffer", "../rtc_base:byte_order", "../rtc_base:checks", + "../rtc_base:copy_on_write_buffer", + "../rtc_base:ip_address", "../rtc_base:logging", "../rtc_base:macromagic", "../rtc_base:ssl_adapter", @@ -620,6 +622,8 @@ rtc_source_set("srtp_transport") { "../api:field_trials_view", "../api:libjingle_peerconnection_api", "../api:rtc_error", + "../api/units:timestamp", + "../call:rtp_receiver", "../media:rtp_utils", "../modules/rtp_rtcp:rtp_rtcp_format", "../p2p:packet_transport_internal", @@ -633,6 +637,7 @@ rtc_source_set("srtp_transport") { "../rtc_base:safe_conversions", "../rtc_base:ssl_adapter", "../rtc_base:zero_memory", + "../rtc_base/network:received_packet", "//third_party/abseil-cpp/absl/strings", ] } diff --git a/pc/srtp_session.cc b/pc/srtp_session.cc index bf27e63d39..755cf215ea 100644 --- a/pc/srtp_session.cc +++ b/pc/srtp_session.cc @@ -12,18 +12,21 @@ #include +#include +#include #include -#include +#include -#include "absl/base/attributes.h" -#include "absl/base/const_init.h" #include "absl/strings/string_view.h" #include "api/array_view.h" #include "api/field_trials_view.h" #include "modules/rtp_rtcp/source/rtp_util.h" #include "pc/external_hmac.h" +#include "rtc_base/buffer.h" #include "rtc_base/byte_order.h" #include "rtc_base/checks.h" +#include "rtc_base/copy_on_write_buffer.h" +#include "rtc_base/ip_address.h" #include "rtc_base/logging.h" #include "rtc_base/ssl_stream_adapter.h" #include "rtc_base/string_encode.h" @@ -149,8 +152,6 @@ void LibSrtpInitializer::DecrementLibsrtpUsageCountAndMaybeDeinit() { } // namespace -using ::webrtc::ParseRtpSequenceNumber; - // One more than the maximum libsrtp error code. Required by // RTC_HISTOGRAM_ENUMERATION. Keep this in sync with srtp_error_status_t defined // in srtp.h. @@ -196,7 +197,22 @@ bool SrtpSession::UpdateReceive(int crypto_suite, return UpdateKey(ssrc_any_inbound, crypto_suite, key, extension_ids); } -bool SrtpSession::ProtectRtp(void* p, int in_len, int max_len, int* out_len) { +bool SrtpSession::ProtectRtp(void* data, + int in_len, + int max_len, + int* out_len) { + // Note: this creates a copy of data, then needs to memcpy back. + // Do not use this variant. + rtc::CopyOnWriteBuffer buffer(static_cast(data), in_len, max_len); + bool ok = ProtectRtp(buffer); + if (ok) { + *out_len = buffer.size(); + std::memcpy(data, buffer.data(), *out_len); + } + return ok; +} + +bool SrtpSession::ProtectRtp(rtc::CopyOnWriteBuffer& buffer) { RTC_DCHECK(thread_checker_.IsCurrent()); if (!session_) { RTC_LOG(LS_WARNING) << "Failed to protect SRTP packet: no SRTP Session"; @@ -207,42 +223,52 @@ bool SrtpSession::ProtectRtp(void* p, int in_len, int max_len, int* out_len) { // SRTP_MAX_TRAILER_LEN bytes of free space after the data. WebRTC // never includes a MKI, therefore the amount of bytes added by the // srtp_protect call is known in advance and depends on the cipher suite. - int need_len = in_len + rtp_auth_tag_len_; // NOLINT - if (max_len < need_len) { + size_t need_len = buffer.size() + rtp_auth_tag_len_; // NOLINT + if (buffer.capacity() < need_len) { RTC_LOG(LS_WARNING) << "Failed to protect SRTP packet: The buffer length " - << max_len << " is less than the needed " << need_len; + << buffer.capacity() << " is less than the needed " + << need_len; return false; } if (dump_plain_rtp_) { - DumpPacket(p, in_len, /*outbound=*/true); + DumpPacket(buffer, /*outbound=*/true); } - *out_len = in_len; - int err = srtp_protect(session_, p, out_len); - int seq_num = ParseRtpSequenceNumber( - rtc::MakeArrayView(reinterpret_cast(p), in_len)); + int out_len = buffer.size(); + int err = srtp_protect(session_, buffer.MutableData(), &out_len); + int seq_num = webrtc::ParseRtpSequenceNumber(buffer); if (err != srtp_err_status_ok) { RTC_LOG(LS_WARNING) << "Failed to protect SRTP packet, seqnum=" << seq_num << ", err=" << err << ", last seqnum=" << last_send_seq_num_; return false; } + buffer.SetSize(out_len); last_send_seq_num_ = seq_num; return true; } -bool SrtpSession::ProtectRtp(void* p, +bool SrtpSession::ProtectRtp(rtc::CopyOnWriteBuffer& buffer, int64_t* index) { + if (!ProtectRtp(buffer)) { + return false; + } + return (index) ? GetSendStreamPacketIndex(buffer, index) : true; +} + +bool SrtpSession::ProtectRtp(void* data, int in_len, int max_len, int* out_len, int64_t* index) { - if (!ProtectRtp(p, in_len, max_len, out_len)) { + rtc::CopyOnWriteBuffer buffer(static_cast(data), in_len, max_len); + if (!ProtectRtp(buffer)) { return false; } - return (index) ? GetSendStreamPacketIndex(p, in_len, index) : true; + *out_len = buffer.size(); + return (index) ? GetSendStreamPacketIndex(buffer, index) : true; } -bool SrtpSession::ProtectRtcp(void* p, int in_len, int max_len, int* out_len) { +bool SrtpSession::ProtectRtcp(rtc::CopyOnWriteBuffer& buffer) { RTC_DCHECK(thread_checker_.IsCurrent()); if (!session_) { RTC_LOG(LS_WARNING) << "Failed to protect SRTCP packet: no SRTP Session"; @@ -253,34 +279,52 @@ bool SrtpSession::ProtectRtcp(void* p, int in_len, int max_len, int* out_len) { // SRTP_MAX_TRAILER_LEN bytes of free space after the data. WebRTC // never includes a MKI, therefore the amount of bytes added by the // srtp_protect_rtp call is known in advance and depends on the cipher suite. - int need_len = in_len + sizeof(uint32_t) + rtcp_auth_tag_len_; // NOLINT - if (max_len < need_len) { - RTC_LOG(LS_WARNING) << "Failed to protect SRTCP packet: The buffer length " - << max_len << " is less than the needed " << need_len; + size_t need_len = + buffer.size() + sizeof(uint32_t) + rtcp_auth_tag_len_; // NOLINT + if (buffer.capacity() < need_len) { + RTC_LOG(LS_WARNING) + << "Failed to protect SRTCP packet: The buffer capacity " + << buffer.capacity() << " is less than the needed " << need_len; return false; } if (dump_plain_rtp_) { - DumpPacket(p, in_len, /*outbound=*/true); + DumpPacket(buffer, /*outbound=*/true); } - *out_len = in_len; - int err = srtp_protect_rtcp(session_, p, out_len); + int out_len = buffer.size(); + int err = srtp_protect_rtcp(session_, buffer.MutableData(), &out_len); if (err != srtp_err_status_ok) { RTC_LOG(LS_WARNING) << "Failed to protect SRTCP packet, err=" << err; return false; } + buffer.SetSize(out_len); return true; } -bool SrtpSession::UnprotectRtp(void* p, int in_len, int* out_len) { +bool SrtpSession::ProtectRtcp(void* data, + int in_len, + int max_len, + int* out_len) { + // Note: this creates a copy of data, then needs to memcpy back. + // Do not use this variant. + rtc::CopyOnWriteBuffer buffer(static_cast(data), in_len, max_len); + bool result = ProtectRtcp(buffer); + if (result) { + *out_len = buffer.size(); + std::memcpy(data, buffer.data(), *out_len); + } + return result; +} + +bool SrtpSession::UnprotectRtp(rtc::CopyOnWriteBuffer& buffer) { RTC_DCHECK(thread_checker_.IsCurrent()); if (!session_) { RTC_LOG(LS_WARNING) << "Failed to unprotect SRTP packet: no SRTP Session"; return false; } + int out_len = buffer.size(); - *out_len = in_len; - int err = srtp_unprotect(session_, p, out_len); + int err = srtp_unprotect(session_, buffer.MutableData(), &out_len); if (err != srtp_err_status_ok) { // Limit the error logging to avoid excessive logs when there are lots of // bad packets. @@ -295,33 +339,55 @@ bool SrtpSession::UnprotectRtp(void* p, int in_len, int* out_len) { static_cast(err), kSrtpErrorCodeBoundary); return false; } + buffer.SetSize(out_len); if (dump_plain_rtp_) { - DumpPacket(p, *out_len, /*outbound=*/false); + DumpPacket(buffer, /*outbound=*/false); } return true; } -bool SrtpSession::UnprotectRtcp(void* p, int in_len, int* out_len) { +bool SrtpSession::UnprotectRtp(void* data, int in_len, int* out_len) { + rtc::CopyOnWriteBuffer buffer(static_cast(data), in_len); + bool ok = UnprotectRtp(buffer); + if (ok) { + *out_len = buffer.size(); + std::memcpy(data, buffer.data(), *out_len); + } + return ok; +} + +bool SrtpSession::UnprotectRtcp(rtc::CopyOnWriteBuffer& buffer) { RTC_DCHECK(thread_checker_.IsCurrent()); if (!session_) { RTC_LOG(LS_WARNING) << "Failed to unprotect SRTCP packet: no SRTP Session"; return false; } - *out_len = in_len; - int err = srtp_unprotect_rtcp(session_, p, out_len); + int out_len = buffer.size(); + int err = srtp_unprotect_rtcp(session_, buffer.MutableData(), &out_len); if (err != srtp_err_status_ok) { RTC_LOG(LS_WARNING) << "Failed to unprotect SRTCP packet, err=" << err; RTC_HISTOGRAM_ENUMERATION("WebRTC.PeerConnection.SrtcpUnprotectError", static_cast(err), kSrtpErrorCodeBoundary); return false; } + buffer.SetSize(out_len); if (dump_plain_rtp_) { - DumpPacket(p, *out_len, /*outbound=*/false); + DumpPacket(buffer, /*outbound=*/false); } return true; } +bool SrtpSession::UnprotectRtcp(void* data, int in_len, int* out_len) { + rtc::CopyOnWriteBuffer buffer(static_cast(data), in_len); + bool ok = UnprotectRtp(buffer); + if (ok) { + *out_len = buffer.size(); + std::memcpy(data, buffer.data(), *out_len); + } + return ok; +} + bool SrtpSession::GetRtpAuthParams(uint8_t** key, int* key_len, int* tag_len) { RTC_DCHECK(thread_checker_.IsCurrent()); RTC_DCHECK(IsExternalAuthActive()); @@ -373,12 +439,12 @@ bool SrtpSession::RemoveSsrcFromSession(uint32_t ssrc) { return srtp_remove_stream(session_, htonl(ssrc)) == srtp_err_status_ok; } -bool SrtpSession::GetSendStreamPacketIndex(void* p, - int in_len, +bool SrtpSession::GetSendStreamPacketIndex(rtc::CopyOnWriteBuffer& buffer, int64_t* index) { RTC_DCHECK(thread_checker_.IsCurrent()); - srtp_hdr_t* hdr = reinterpret_cast(p); - srtp_stream_ctx_t* stream = srtp_get_stream(session_, hdr->ssrc); + // libSRTP expects the SSRC to be in network byte order. + srtp_stream_ctx_t* stream = + srtp_get_stream(session_, htonl(webrtc::ParseRtpSsrc(buffer))); if (!stream) { return false; } @@ -534,25 +600,25 @@ void SrtpSession::HandleEventThunk(srtp_event_data_t* ev) { // extracted by searching for RTP_DUMP // grep RTP_DUMP chrome_debug.log > in.txt // and converted to pcap using -// text2pcap -D -u 1000,2000 -t %H:%M:%S. in.txt out.pcap +// text2pcap -D -u 1000,2000 -t %H:%M:%S.%f in.txt out.pcap // The resulting file can be replayed using the WebRTC video_replay tool and // be inspected in Wireshark using the RTP, VP8 and H264 dissectors. -void SrtpSession::DumpPacket(const void* buf, int len, bool outbound) { +void SrtpSession::DumpPacket(const rtc::CopyOnWriteBuffer& buffer, + bool outbound) { int64_t time_of_day = rtc::TimeUTCMillis() % (24 * 3600 * 1000); int64_t hours = time_of_day / (3600 * 1000); int64_t minutes = (time_of_day / (60 * 1000)) % 60; int64_t seconds = (time_of_day / 1000) % 60; int64_t millis = time_of_day % 1000; - RTC_LOG(LS_VERBOSE) << "\n" - << (outbound ? "O" : "I") << " " << std::setfill('0') - << std::setw(2) << hours << ":" << std::setfill('0') - << std::setw(2) << minutes << ":" << std::setfill('0') - << std::setw(2) << seconds << "." << std::setfill('0') - << std::setw(3) << millis << " " - << "000000 " - << rtc::hex_encode_with_delimiter( - absl::string_view((const char*)buf, len), ' ') - << " # RTP_DUMP"; + RTC_LOG(LS_VERBOSE) + << "\n" + << (outbound ? "O" : "I") << " " << std::setfill('0') << std::setw(2) + << hours << ":" << std::setfill('0') << std::setw(2) << minutes << ":" + << std::setfill('0') << std::setw(2) << seconds << "." + << std::setfill('0') << std::setw(3) << millis << " " << "000000 " + << rtc::hex_encode_with_delimiter( + absl::string_view(buffer.data(), buffer.size()), ' ') + << " # RTP_DUMP"; } } // namespace cricket diff --git a/pc/srtp_session.h b/pc/srtp_session.h index f9e7ae5d40..4dde928318 100644 --- a/pc/srtp_session.h +++ b/pc/srtp_session.h @@ -17,9 +17,9 @@ #include #include "api/field_trials_view.h" -#include "api/scoped_refptr.h" #include "api/sequence_checker.h" #include "rtc_base/buffer.h" +#include "rtc_base/copy_on_write_buffer.h" // Forward declaration to avoid pulling in libsrtp headers here struct srtp_event_data_t; @@ -62,18 +62,34 @@ class SrtpSession { // Encrypts/signs an individual RTP/RTCP packet, in-place. // If an HMAC is used, this will increase the packet size. - bool ProtectRtp(void* data, int in_len, int max_len, int* out_len); + [[deprecated("Pass CopyOnWriteBuffer")]] bool ProtectRtp(void* data, + int in_len, + int max_len, + int* out_len); + bool ProtectRtp(rtc::CopyOnWriteBuffer& buffer); // Overloaded version, outputs packet index. - bool ProtectRtp(void* data, - int in_len, - int max_len, - int* out_len, - int64_t* index); - bool ProtectRtcp(void* data, int in_len, int max_len, int* out_len); + [[deprecated("Pass CopyOnWriteBuffer")]] bool ProtectRtp(void* data, + int in_len, + int max_len, + int* out_len, + int64_t* index); + bool ProtectRtp(rtc::CopyOnWriteBuffer& buffer, int64_t* index); + + [[deprecated("Pass CopyOnWriteBuffer")]] bool ProtectRtcp(void* data, + int in_len, + int max_len, + int* out_len); + bool ProtectRtcp(rtc::CopyOnWriteBuffer& buffer); // Decrypts/verifies an invidiual RTP/RTCP packet. // If an HMAC is used, this will decrease the packet size. - bool UnprotectRtp(void* data, int in_len, int* out_len); - bool UnprotectRtcp(void* data, int in_len, int* out_len); + [[deprecated("Pass CopyOnWriteBuffer")]] bool UnprotectRtp(void* data, + int in_len, + int* out_len); + bool UnprotectRtp(rtc::CopyOnWriteBuffer& buffer); + [[deprecated("Pass CopyOnWriteBuffer")]] bool UnprotectRtcp(void* data, + int in_len, + int* out_len); + bool UnprotectRtcp(rtc::CopyOnWriteBuffer& buffer); // Helper method to get authentication params. bool GetRtpAuthParams(uint8_t** key, int* key_len, int* tag_len); @@ -115,11 +131,11 @@ class SrtpSession { const rtc::ZeroOnFreeBuffer& key, const std::vector& extension_ids); // Returns send stream current packet index from srtp db. - bool GetSendStreamPacketIndex(void* data, int in_len, int64_t* index); + bool GetSendStreamPacketIndex(rtc::CopyOnWriteBuffer& buffer, int64_t* index); // Writes unencrypted packets in text2pcap format to the log file // for debugging. - void DumpPacket(const void* buf, int len, bool outbound); + void DumpPacket(const rtc::CopyOnWriteBuffer& buffer, bool outbound); void HandleEvent(const srtp_event_data_t* ev); static void HandleEventThunk(srtp_event_data_t* ev); diff --git a/pc/srtp_session_unittest.cc b/pc/srtp_session_unittest.cc index 0785cc0c56..35d4108afc 100644 --- a/pc/srtp_session_unittest.cc +++ b/pc/srtp_session_unittest.cc @@ -12,11 +12,16 @@ #include -#include +#include +#include +#include +#include #include "media/base/fake_rtp.h" #include "pc/test/srtp_test_util.h" +#include "rtc_base/buffer.h" #include "rtc_base/byte_order.h" +#include "rtc_base/copy_on_write_buffer.h" #include "rtc_base/ssl_stream_adapter.h" // For rtc::SRTP_* #include "system_wrappers/include/metrics.h" #include "test/gmock.h" @@ -41,45 +46,45 @@ class SrtpSessionTest : public ::testing::Test { virtual void SetUp() { rtp_len_ = sizeof(kPcmuFrame); rtcp_len_ = sizeof(kRtcpReport); - memcpy(rtp_packet_, kPcmuFrame, rtp_len_); - memcpy(rtcp_packet_, kRtcpReport, rtcp_len_); + rtp_packet_.EnsureCapacity(rtp_len_ + 10); + rtp_packet_.SetData(kPcmuFrame, rtp_len_); + rtcp_packet_.EnsureCapacity(rtcp_len_ + 4 + 10); + rtcp_packet_.SetData(kRtcpReport, rtcp_len_); } void TestProtectRtp(int crypto_suite) { - int out_len = 0; - EXPECT_TRUE( - s1_.ProtectRtp(rtp_packet_, rtp_len_, sizeof(rtp_packet_), &out_len)); - EXPECT_EQ(out_len, rtp_len_ + rtp_auth_tag_len(crypto_suite)); - EXPECT_NE(0, memcmp(rtp_packet_, kPcmuFrame, rtp_len_)); - rtp_len_ = out_len; + EXPECT_TRUE(s1_.ProtectRtp(rtp_packet_)); + EXPECT_EQ(rtp_packet_.size(), rtp_len_ + rtp_auth_tag_len(crypto_suite)); + // Check that Protect changed the content (up to the original length). + EXPECT_NE(0, std::memcmp(kPcmuFrame, rtp_packet_.data(), rtp_len_)); + rtp_len_ = rtp_packet_.size(); } void TestProtectRtcp(int crypto_suite) { - int out_len = 0; - EXPECT_TRUE(s1_.ProtectRtcp(rtcp_packet_, rtcp_len_, sizeof(rtcp_packet_), - &out_len)); - EXPECT_EQ(out_len, - rtcp_len_ + 4 + rtcp_auth_tag_len(crypto_suite)); // NOLINT - EXPECT_NE(0, memcmp(rtcp_packet_, kRtcpReport, rtcp_len_)); - rtcp_len_ = out_len; + EXPECT_TRUE(s1_.ProtectRtcp(rtcp_packet_)); + EXPECT_EQ(rtcp_packet_.size(), + rtcp_len_ + 4 + rtcp_auth_tag_len(crypto_suite)); + // Check that Protect changed the content (up to the original length). + EXPECT_NE(0, std::memcmp(kRtcpReport, rtcp_packet_.data(), rtcp_len_)); + rtcp_len_ = rtcp_packet_.size(); } void TestUnprotectRtp(int crypto_suite) { - int out_len = 0, expected_len = sizeof(kPcmuFrame); - EXPECT_TRUE(s2_.UnprotectRtp(rtp_packet_, rtp_len_, &out_len)); - EXPECT_EQ(expected_len, out_len); - EXPECT_EQ(0, memcmp(rtp_packet_, kPcmuFrame, out_len)); + EXPECT_TRUE(s2_.UnprotectRtp(rtp_packet_)); + EXPECT_EQ(rtp_packet_.size(), sizeof(kPcmuFrame)); + EXPECT_EQ(0, + std::memcmp(kPcmuFrame, rtp_packet_.data(), rtp_packet_.size())); } void TestUnprotectRtcp(int crypto_suite) { - int out_len = 0, expected_len = sizeof(kRtcpReport); - EXPECT_TRUE(s2_.UnprotectRtcp(rtcp_packet_, rtcp_len_, &out_len)); - EXPECT_EQ(expected_len, out_len); - EXPECT_EQ(0, memcmp(rtcp_packet_, kRtcpReport, out_len)); + EXPECT_TRUE(s2_.UnprotectRtcp(rtcp_packet_)); + EXPECT_EQ(rtcp_packet_.size(), sizeof(kRtcpReport)); + EXPECT_EQ( + 0, std::memcmp(kRtcpReport, rtcp_packet_.data(), rtcp_packet_.size())); } webrtc::test::ScopedKeyValueConfig field_trials_; cricket::SrtpSession s1_; cricket::SrtpSession s2_; - char rtp_packet_[sizeof(kPcmuFrame) + 10]; - char rtcp_packet_[sizeof(kRtcpReport) + 4 + 10]; - int rtp_len_; - int rtcp_len_; + rtc::CopyOnWriteBuffer rtp_packet_; + rtc::CopyOnWriteBuffer rtcp_packet_; + size_t rtp_len_; + size_t rtcp_len_; }; // Test that we can set up the session and keys properly. @@ -140,9 +145,7 @@ TEST_F(SrtpSessionTest, TestGetSendStreamPacketIndex) { EXPECT_TRUE(s1_.SetSend(kSrtpAes128CmSha1_32, kTestKey1, kEncryptedHeaderExtensionIds)); int64_t index; - int out_len = 0; - EXPECT_TRUE(s1_.ProtectRtp(rtp_packet_, rtp_len_, sizeof(rtp_packet_), - &out_len, &index)); + EXPECT_TRUE(s1_.ProtectRtp(rtp_packet_, &index)); // `index` will be shifted by 16. int64_t be64_index = static_cast(NetworkToHost64(1 << 16)); EXPECT_EQ(be64_index, index); @@ -150,20 +153,20 @@ TEST_F(SrtpSessionTest, TestGetSendStreamPacketIndex) { // Test that we fail to unprotect if someone tampers with the RTP/RTCP paylaods. TEST_F(SrtpSessionTest, TestTamperReject) { - int out_len; EXPECT_TRUE(s1_.SetSend(kSrtpAes128CmSha1_80, kTestKey1, kEncryptedHeaderExtensionIds)); EXPECT_TRUE(s2_.SetReceive(kSrtpAes128CmSha1_80, kTestKey1, kEncryptedHeaderExtensionIds)); TestProtectRtp(kSrtpAes128CmSha1_80); - TestProtectRtcp(kSrtpAes128CmSha1_80); - rtp_packet_[0] = 0x12; - rtcp_packet_[1] = 0x34; - EXPECT_FALSE(s2_.UnprotectRtp(rtp_packet_, rtp_len_, &out_len)); + rtp_packet_.MutableData()[0] = 0x12; + EXPECT_FALSE(s2_.UnprotectRtp(rtp_packet_)); EXPECT_METRIC_THAT( webrtc::metrics::Samples("WebRTC.PeerConnection.SrtpUnprotectError"), ElementsAre(Pair(srtp_err_status_bad_param, 1))); - EXPECT_FALSE(s2_.UnprotectRtcp(rtcp_packet_, rtcp_len_, &out_len)); + + TestProtectRtcp(kSrtpAes128CmSha1_80); + rtcp_packet_.MutableData()[1] = 0x34; + EXPECT_FALSE(s2_.UnprotectRtcp(rtcp_packet_)); EXPECT_METRIC_THAT( webrtc::metrics::Samples("WebRTC.PeerConnection.SrtcpUnprotectError"), ElementsAre(Pair(srtp_err_status_auth_fail, 1))); @@ -171,16 +174,15 @@ TEST_F(SrtpSessionTest, TestTamperReject) { // Test that we fail to unprotect if the payloads are not authenticated. TEST_F(SrtpSessionTest, TestUnencryptReject) { - int out_len; EXPECT_TRUE(s1_.SetSend(kSrtpAes128CmSha1_80, kTestKey1, kEncryptedHeaderExtensionIds)); EXPECT_TRUE(s2_.SetReceive(kSrtpAes128CmSha1_80, kTestKey1, kEncryptedHeaderExtensionIds)); - EXPECT_FALSE(s2_.UnprotectRtp(rtp_packet_, rtp_len_, &out_len)); + EXPECT_FALSE(s2_.UnprotectRtp(rtp_packet_)); EXPECT_METRIC_THAT( webrtc::metrics::Samples("WebRTC.PeerConnection.SrtpUnprotectError"), ElementsAre(Pair(srtp_err_status_auth_fail, 1))); - EXPECT_FALSE(s2_.UnprotectRtcp(rtcp_packet_, rtcp_len_, &out_len)); + EXPECT_FALSE(s2_.UnprotectRtcp(rtcp_packet_)); EXPECT_METRIC_THAT( webrtc::metrics::Samples("WebRTC.PeerConnection.SrtcpUnprotectError"), ElementsAre(Pair(srtp_err_status_cant_check, 1))); @@ -188,21 +190,23 @@ TEST_F(SrtpSessionTest, TestUnencryptReject) { // Test that we fail when using buffers that are too small. TEST_F(SrtpSessionTest, TestBuffersTooSmall) { - int out_len; EXPECT_TRUE(s1_.SetSend(kSrtpAes128CmSha1_80, kTestKey1, kEncryptedHeaderExtensionIds)); - EXPECT_FALSE(s1_.ProtectRtp(rtp_packet_, rtp_len_, sizeof(rtp_packet_) - 10, - &out_len)); - EXPECT_FALSE(s1_.ProtectRtcp(rtcp_packet_, rtcp_len_, - sizeof(rtcp_packet_) - 14, &out_len)); + // This buffer does not have extra capacity which we treat as an error. + rtc::CopyOnWriteBuffer rtp_packet(rtp_packet_.data(), rtp_packet_.size(), + rtp_packet_.size()); + EXPECT_FALSE(s1_.ProtectRtp(rtp_packet)); + // This buffer does not have extra capacity which we treat as an error. + rtc::CopyOnWriteBuffer rtcp_packet(rtcp_packet_.data(), rtcp_packet_.size(), + rtcp_packet_.size()); + EXPECT_FALSE(s1_.ProtectRtcp(rtcp_packet)); } TEST_F(SrtpSessionTest, TestReplay) { - static const uint16_t kMaxSeqnum = static_cast(-1); + static const uint16_t kMaxSeqnum = std::numeric_limits::max() - 1; static const uint16_t seqnum_big = 62275; static const uint16_t seqnum_small = 10; static const uint16_t replay_window = 1024; - int out_len; EXPECT_TRUE(s1_.SetSend(kSrtpAes128CmSha1_80, kTestKey1, kEncryptedHeaderExtensionIds)); @@ -210,38 +214,37 @@ TEST_F(SrtpSessionTest, TestReplay) { kEncryptedHeaderExtensionIds)); // Initial sequence number. - SetBE16(reinterpret_cast(rtp_packet_) + 2, seqnum_big); - EXPECT_TRUE( - s1_.ProtectRtp(rtp_packet_, rtp_len_, sizeof(rtp_packet_), &out_len)); + SetBE16(rtp_packet_.MutableData() + 2, seqnum_big); + EXPECT_TRUE(s1_.ProtectRtp(rtp_packet_)); + rtp_packet_.SetData(kPcmuFrame, sizeof(kPcmuFrame)); // Replay within the 1024 window should succeed. - SetBE16(reinterpret_cast(rtp_packet_) + 2, + SetBE16(rtp_packet_.MutableData() + 2, seqnum_big - replay_window + 1); - EXPECT_TRUE( - s1_.ProtectRtp(rtp_packet_, rtp_len_, sizeof(rtp_packet_), &out_len)); + EXPECT_TRUE(s1_.ProtectRtp(rtp_packet_)); + rtp_packet_.SetData(kPcmuFrame, sizeof(kPcmuFrame)); // Replay out side of the 1024 window should fail. - SetBE16(reinterpret_cast(rtp_packet_) + 2, + SetBE16(rtp_packet_.MutableData() + 2, seqnum_big - replay_window - 1); - EXPECT_FALSE( - s1_.ProtectRtp(rtp_packet_, rtp_len_, sizeof(rtp_packet_), &out_len)); + EXPECT_FALSE(s1_.ProtectRtp(rtp_packet_)); + rtp_packet_.SetData(kPcmuFrame, sizeof(kPcmuFrame)); // Increment sequence number to a small number. - SetBE16(reinterpret_cast(rtp_packet_) + 2, seqnum_small); - EXPECT_TRUE( - s1_.ProtectRtp(rtp_packet_, rtp_len_, sizeof(rtp_packet_), &out_len)); + SetBE16(rtp_packet_.MutableData() + 2, seqnum_small); + EXPECT_TRUE(s1_.ProtectRtp(rtp_packet_)); // Replay around 0 but out side of the 1024 window should fail. - SetBE16(reinterpret_cast(rtp_packet_) + 2, + SetBE16(rtp_packet_.MutableData() + 2, kMaxSeqnum + seqnum_small - replay_window - 1); - EXPECT_FALSE( - s1_.ProtectRtp(rtp_packet_, rtp_len_, sizeof(rtp_packet_), &out_len)); + EXPECT_FALSE(s1_.ProtectRtp(rtp_packet_)); + rtp_packet_.SetData(kPcmuFrame, sizeof(kPcmuFrame)); // Replay around 0 but within the 1024 window should succeed. for (uint16_t seqnum = 65000; seqnum < 65003; ++seqnum) { - SetBE16(reinterpret_cast(rtp_packet_) + 2, seqnum); - EXPECT_TRUE( - s1_.ProtectRtp(rtp_packet_, rtp_len_, sizeof(rtp_packet_), &out_len)); + SetBE16(rtp_packet_.MutableData() + 2, seqnum); + EXPECT_TRUE(s1_.ProtectRtp(rtp_packet_)); + rtp_packet_.SetData(kPcmuFrame, sizeof(kPcmuFrame)); } // Go back to normal sequence nubmer. @@ -249,9 +252,8 @@ TEST_F(SrtpSessionTest, TestReplay) { // without the fix, the loop above would keep incrementing local sequence // number in libsrtp, eventually the new sequence number would go out side // of the window. - SetBE16(reinterpret_cast(rtp_packet_) + 2, seqnum_small + 1); - EXPECT_TRUE( - s1_.ProtectRtp(rtp_packet_, rtp_len_, sizeof(rtp_packet_), &out_len)); + SetBE16(rtp_packet_.MutableData() + 2, seqnum_small + 1); + EXPECT_TRUE(s1_.ProtectRtp(rtp_packet_)); } TEST_F(SrtpSessionTest, RemoveSsrc) { @@ -259,33 +261,32 @@ TEST_F(SrtpSessionTest, RemoveSsrc) { kEncryptedHeaderExtensionIds)); EXPECT_TRUE(s2_.SetReceive(kSrtpAes128CmSha1_80, kTestKey1, kEncryptedHeaderExtensionIds)); - int out_len; // Encrypt and decrypt the packet once. - EXPECT_TRUE( - s1_.ProtectRtp(rtp_packet_, rtp_len_, sizeof(rtp_packet_), &out_len)); - EXPECT_TRUE(s2_.UnprotectRtp(rtp_packet_, out_len, &out_len)); - EXPECT_EQ(rtp_len_, out_len); - EXPECT_EQ(0, memcmp(rtp_packet_, kPcmuFrame, out_len)); + EXPECT_TRUE(s1_.ProtectRtp(rtp_packet_)); + EXPECT_TRUE(s2_.UnprotectRtp(rtp_packet_)); + EXPECT_EQ(sizeof(kPcmuFrame), rtp_packet_.size()); + EXPECT_EQ(0, std::memcmp(kPcmuFrame, rtp_packet_.data(), rtp_packet_.size())); // Recreate the original packet and encrypt again. - memcpy(rtp_packet_, kPcmuFrame, rtp_len_); - EXPECT_TRUE( - s1_.ProtectRtp(rtp_packet_, rtp_len_, sizeof(rtp_packet_), &out_len)); + rtp_packet_.SetData(kPcmuFrame, sizeof(kPcmuFrame)); + EXPECT_TRUE(s1_.ProtectRtp(rtp_packet_)); // Attempting to decrypt will fail as a replay attack. // (srtp_err_status_replay_fail) since the sequence number was already seen. - EXPECT_FALSE(s2_.UnprotectRtp(rtp_packet_, out_len, &out_len)); + EXPECT_FALSE(s2_.UnprotectRtp(rtp_packet_)); // Remove the fake packet SSRC 1 from the session. EXPECT_TRUE(s2_.RemoveSsrcFromSession(1)); EXPECT_FALSE(s2_.RemoveSsrcFromSession(1)); // Since the SRTP state was discarded, this is no longer a replay attack. - EXPECT_TRUE(s2_.UnprotectRtp(rtp_packet_, out_len, &out_len)); - EXPECT_EQ(rtp_len_, out_len); - EXPECT_EQ(0, memcmp(rtp_packet_, kPcmuFrame, out_len)); + EXPECT_TRUE(s2_.UnprotectRtp(rtp_packet_)); + EXPECT_EQ(sizeof(kPcmuFrame), rtp_packet_.size()); + EXPECT_EQ(0, std::memcmp(kPcmuFrame, rtp_packet_.data(), rtp_packet_.size())); EXPECT_TRUE(s2_.RemoveSsrcFromSession(1)); } +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wdeprecated-declarations" TEST_F(SrtpSessionTest, ProtectUnprotectWrapAroundRocMismatch) { // This unit tests demonstrates why you should be careful when // choosing the initial RTP sequence number as there can be decryption @@ -316,6 +317,7 @@ TEST_F(SrtpSessionTest, ProtectUnprotectWrapAroundRocMismatch) { 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // clang-format on }; + const unsigned char kPayload[] = {0xBE, 0xEF}; int out_len; // Encrypt the frames in-order. There is a sequence number rollover from @@ -337,9 +339,12 @@ TEST_F(SrtpSessionTest, ProtectUnprotectWrapAroundRocMismatch) { EXPECT_FALSE(s2_.UnprotectRtp(kFrame2, sizeof(kFrame2), &out_len)); // Decrypt frame 1. EXPECT_TRUE(s2_.UnprotectRtp(kFrame1, sizeof(kFrame1), &out_len)); + EXPECT_EQ(0, std::memcmp(kFrame1 + 12, kPayload, sizeof(kPayload))); // Now decrypt frame 2 again. A rollover is detected which increases // the ROC to 1 so this succeeds. EXPECT_TRUE(s2_.UnprotectRtp(kFrame2, sizeof(kFrame2), &out_len)); + EXPECT_EQ(0, std::memcmp(kFrame2 + 12, kPayload, sizeof(kPayload))); } +#pragma clang diagnostic pop } // namespace rtc diff --git a/pc/srtp_transport.cc b/pc/srtp_transport.cc index 0f4acea090..90992c1e1a 100644 --- a/pc/srtp_transport.cc +++ b/pc/srtp_transport.cc @@ -10,25 +10,26 @@ #include "pc/srtp_transport.h" -#include - -#include +#include +#include #include #include -#include "absl/strings/match.h" +#include "api/field_trials_view.h" +#include "api/units/timestamp.h" +#include "call/rtp_demuxer.h" #include "media/base/rtp_utils.h" #include "modules/rtp_rtcp/source/rtp_util.h" #include "pc/rtp_transport.h" #include "pc/srtp_session.h" #include "rtc_base/async_packet_socket.h" +#include "rtc_base/buffer.h" #include "rtc_base/checks.h" #include "rtc_base/copy_on_write_buffer.h" #include "rtc_base/logging.h" -#include "rtc_base/numerics/safe_conversions.h" -#include "rtc_base/ssl_stream_adapter.h" +#include "rtc_base/network/received_packet.h" +#include "rtc_base/network_route.h" #include "rtc_base/trace_event.h" -#include "rtc_base/zero_memory.h" namespace webrtc { @@ -40,6 +41,7 @@ SrtpTransport::SrtpTransport(bool rtcp_mux_enabled, bool SrtpTransport::SendRtpPacket(rtc::CopyOnWriteBuffer* packet, const rtc::PacketOptions& options, int flags) { + RTC_DCHECK(packet); if (!IsSrtpActive()) { RTC_LOG(LS_ERROR) << "Failed to send the packet because SRTP transport is inactive."; @@ -47,23 +49,21 @@ bool SrtpTransport::SendRtpPacket(rtc::CopyOnWriteBuffer* packet, } rtc::PacketOptions updated_options = options; TRACE_EVENT0("webrtc", "SRTP Encode"); + // If ENABLE_EXTERNAL_AUTH flag is on then packet authentication is not done + // inside libsrtp for a RTP packet. A external HMAC module will be writing + // a fake HMAC value. This is ONLY done for a RTP packet. + // Socket layer will update rtp sendtime extension header if present in + // packet with current time before updating the HMAC. bool res; - uint8_t* data = packet->MutableData(); - int len = rtc::checked_cast(packet->size()); -// If ENABLE_EXTERNAL_AUTH flag is on then packet authentication is not done -// inside libsrtp for a RTP packet. A external HMAC module will be writing -// a fake HMAC value. This is ONLY done for a RTP packet. -// Socket layer will update rtp sendtime extension header if present in -// packet with current time before updating the HMAC. #if !defined(ENABLE_EXTERNAL_AUTH) - res = ProtectRtp(data, len, static_cast(packet->capacity()), &len); + res = ProtectRtp(*packet); #else if (!IsExternalAuthActive()) { - res = ProtectRtp(data, len, static_cast(packet->capacity()), &len); + res = ProtectRtp(*packet); } else { updated_options.packet_time_params.rtp_sendtime_extension_id = rtp_abs_sendtime_extn_id_; - res = ProtectRtp(data, len, static_cast(packet->capacity()), &len, + res = ProtectRtp(*packet, &updated_options.packet_time_params.srtp_packet_index); // If protection succeeds, let's get auth params from srtp. if (res) { @@ -83,19 +83,18 @@ bool SrtpTransport::SendRtpPacket(rtc::CopyOnWriteBuffer* packet, if (!res) { uint16_t seq_num = ParseRtpSequenceNumber(*packet); uint32_t ssrc = ParseRtpSsrc(*packet); - RTC_LOG(LS_ERROR) << "Failed to protect RTP packet: size=" << len + RTC_LOG(LS_ERROR) << "Failed to protect RTP packet: size=" << packet->size() << ", seqnum=" << seq_num << ", SSRC=" << ssrc; return false; } - // Update the length of the packet now that we've added the auth tag. - packet->SetSize(len); return SendPacket(/*rtcp=*/false, packet, updated_options, flags); } bool SrtpTransport::SendRtcpPacket(rtc::CopyOnWriteBuffer* packet, const rtc::PacketOptions& options, int flags) { + RTC_DCHECK(packet); if (!IsSrtpActive()) { RTC_LOG(LS_ERROR) << "Failed to send the packet because SRTP transport is inactive."; @@ -103,17 +102,13 @@ bool SrtpTransport::SendRtcpPacket(rtc::CopyOnWriteBuffer* packet, } TRACE_EVENT0("webrtc", "SRTP Encode"); - uint8_t* data = packet->MutableData(); - int len = rtc::checked_cast(packet->size()); - if (!ProtectRtcp(data, len, static_cast(packet->capacity()), &len)) { + if (!ProtectRtcp(*packet)) { int type = -1; - cricket::GetRtcpType(data, len, &type); - RTC_LOG(LS_ERROR) << "Failed to protect RTCP packet: size=" << len - << ", type=" << type; + cricket::GetRtcpType(packet->data(), packet->size(), &type); + RTC_LOG(LS_ERROR) << "Failed to protect RTCP packet: size=" + << packet->size() << ", type=" << type; return false; } - // Update the length of the packet now that we've added the auth tag. - packet->SetSize(len); return SendPacket(/*rtcp=*/true, packet, options, flags); } @@ -127,14 +122,13 @@ void SrtpTransport::OnRtpPacketReceived(const rtc::ReceivedPacket& packet) { } rtc::CopyOnWriteBuffer payload(packet.payload()); - char* data = payload.MutableData(); - int len = rtc::checked_cast(payload.size()); - if (!UnprotectRtp(data, len, &len)) { + if (!UnprotectRtp(payload)) { // Limit the error logging to avoid excessive logs when there are lots of // bad packets. const int kFailureLogThrottleCount = 100; if (decryption_failure_count_ % kFailureLogThrottleCount == 0) { - RTC_LOG(LS_ERROR) << "Failed to unprotect RTP packet: size=" << len + RTC_LOG(LS_ERROR) << "Failed to unprotect RTP packet: size=" + << payload.size() << ", seqnum=" << ParseRtpSequenceNumber(payload) << ", SSRC=" << ParseRtpSsrc(payload) << ", previous failure count: " @@ -143,7 +137,6 @@ void SrtpTransport::OnRtpPacketReceived(const rtc::ReceivedPacket& packet) { ++decryption_failure_count_; return; } - payload.SetSize(len); DemuxPacket(std::move(payload), packet.arrival_time().value_or(Timestamp::MinusInfinity()), packet.ecn()); @@ -157,16 +150,13 @@ void SrtpTransport::OnRtcpPacketReceived(const rtc::ReceivedPacket& packet) { return; } rtc::CopyOnWriteBuffer payload(packet.payload()); - char* data = payload.MutableData(); - int len = rtc::checked_cast(payload.size()); - if (!UnprotectRtcp(data, len, &len)) { + if (!UnprotectRtcp(payload)) { int type = -1; - cricket::GetRtcpType(data, len, &type); - RTC_LOG(LS_ERROR) << "Failed to unprotect RTCP packet: size=" << len - << ", type=" << type; + cricket::GetRtcpType(payload.data(), payload.size(), &type); + RTC_LOG(LS_ERROR) << "Failed to unprotect RTCP packet: size=" + << payload.size() << ", type=" << type; return; } - payload.SetSize(len); SendRtcpPacketReceived( &payload, packet.arrival_time() ? packet.arrival_time()->us() : -1); } @@ -291,63 +281,56 @@ void SrtpTransport::CreateSrtpSessions() { } } -bool SrtpTransport::ProtectRtp(void* p, int in_len, int max_len, int* out_len) { +bool SrtpTransport::ProtectRtp(rtc::CopyOnWriteBuffer& buffer) { if (!IsSrtpActive()) { RTC_LOG(LS_WARNING) << "Failed to ProtectRtp: SRTP not active"; return false; } RTC_CHECK(send_session_); - return send_session_->ProtectRtp(p, in_len, max_len, out_len); + return send_session_->ProtectRtp(buffer); } -bool SrtpTransport::ProtectRtp(void* p, - int in_len, - int max_len, - int* out_len, - int64_t* index) { +bool SrtpTransport::ProtectRtp(rtc::CopyOnWriteBuffer& buffer, int64_t* index) { if (!IsSrtpActive()) { RTC_LOG(LS_WARNING) << "Failed to ProtectRtp: SRTP not active"; return false; } RTC_CHECK(send_session_); - return send_session_->ProtectRtp(p, in_len, max_len, out_len, index); + return send_session_->ProtectRtp(buffer, index); } -bool SrtpTransport::ProtectRtcp(void* p, - int in_len, - int max_len, - int* out_len) { +bool SrtpTransport::ProtectRtcp(rtc::CopyOnWriteBuffer& buffer) { if (!IsSrtpActive()) { RTC_LOG(LS_WARNING) << "Failed to ProtectRtcp: SRTP not active"; return false; } if (send_rtcp_session_) { - return send_rtcp_session_->ProtectRtcp(p, in_len, max_len, out_len); + return send_rtcp_session_->ProtectRtcp(buffer); } else { RTC_CHECK(send_session_); - return send_session_->ProtectRtcp(p, in_len, max_len, out_len); + return send_session_->ProtectRtcp(buffer); } } -bool SrtpTransport::UnprotectRtp(void* p, int in_len, int* out_len) { +bool SrtpTransport::UnprotectRtp(rtc::CopyOnWriteBuffer& buffer) { if (!IsSrtpActive()) { RTC_LOG(LS_WARNING) << "Failed to UnprotectRtp: SRTP not active"; return false; } RTC_CHECK(recv_session_); - return recv_session_->UnprotectRtp(p, in_len, out_len); + return recv_session_->UnprotectRtp(buffer); } -bool SrtpTransport::UnprotectRtcp(void* p, int in_len, int* out_len) { +bool SrtpTransport::UnprotectRtcp(rtc::CopyOnWriteBuffer& buffer) { if (!IsSrtpActive()) { RTC_LOG(LS_WARNING) << "Failed to UnprotectRtcp: SRTP not active"; return false; } if (recv_rtcp_session_) { - return recv_rtcp_session_->UnprotectRtcp(p, in_len, out_len); + return recv_rtcp_session_->UnprotectRtcp(buffer); } else { RTC_CHECK(recv_session_); - return recv_session_->UnprotectRtcp(p, in_len, out_len); + return recv_session_->UnprotectRtcp(buffer); } } diff --git a/pc/srtp_transport.h b/pc/srtp_transport.h index dd86006ee1..f5ddc46faa 100644 --- a/pc/srtp_transport.h +++ b/pc/srtp_transport.h @@ -20,13 +20,14 @@ #include #include "api/field_trials_view.h" -#include "api/rtc_error.h" +#include "call/rtp_demuxer.h" #include "p2p/base/packet_transport_internal.h" #include "pc/rtp_transport.h" #include "pc/srtp_session.h" #include "rtc_base/async_packet_socket.h" #include "rtc_base/buffer.h" #include "rtc_base/copy_on_write_buffer.h" +#include "rtc_base/network/received_packet.h" #include "rtc_base/network_route.h" namespace webrtc { @@ -121,21 +122,15 @@ class SrtpTransport : public RtpTransport { // Override the RtpTransport::OnWritableState. void OnWritableState(rtc::PacketTransportInternal* packet_transport) override; - bool ProtectRtp(void* data, int in_len, int max_len, int* out_len); - + bool ProtectRtp(rtc::CopyOnWriteBuffer& buffer); // Overloaded version, outputs packet index. - bool ProtectRtp(void* data, - int in_len, - int max_len, - int* out_len, - int64_t* index); - bool ProtectRtcp(void* data, int in_len, int max_len, int* out_len); + bool ProtectRtp(rtc::CopyOnWriteBuffer& buffer, int64_t* index); + bool ProtectRtcp(rtc::CopyOnWriteBuffer& buffer); // Decrypts/verifies an invidiual RTP/RTCP packet. // If an HMAC is used, this will decrease the packet size. - bool UnprotectRtp(void* data, int in_len, int* out_len); - - bool UnprotectRtcp(void* data, int in_len, int* out_len); + bool UnprotectRtp(rtc::CopyOnWriteBuffer& buffer); + bool UnprotectRtcp(rtc::CopyOnWriteBuffer& buffer); const std::string content_name_;