diff --git a/webrtc/p2p/base/fakepackettransport.h b/webrtc/p2p/base/fakepackettransport.h index 16af9a42ae..82d39c2364 100644 --- a/webrtc/p2p/base/fakepackettransport.h +++ b/webrtc/p2p/base/fakepackettransport.h @@ -86,6 +86,8 @@ class FakePacketTransport : public PacketTransportInternal { bool GetOption(Socket::Option opt, int* value) override { return true; } int GetError() override { return 0; } + const CopyOnWriteBuffer* last_sent_packet() { return &last_sent_packet_; } + private: void set_writable(bool writable) { if (writable_ == writable) { @@ -107,12 +109,14 @@ class FakePacketTransport : public PacketTransportInternal { } void SendPacketInternal(const CopyOnWriteBuffer& packet) { + last_sent_packet_ = packet; if (dest_) { dest_->SignalReadPacket(dest_, packet.data(), packet.size(), CreatePacketTime(0), 0); } } + CopyOnWriteBuffer last_sent_packet_; AsyncInvoker invoker_; std::string debug_name_; FakePacketTransport* dest_ = nullptr; diff --git a/webrtc/pc/channel.cc b/webrtc/pc/channel.cc index 59f0869431..d99a105868 100644 --- a/webrtc/pc/channel.cc +++ b/webrtc/pc/channel.cc @@ -158,18 +158,22 @@ BaseChannel::BaseChannel(rtc::Thread* worker_thread, signaling_thread_(signaling_thread), content_name_(content_name), rtcp_mux_required_(rtcp_mux_required), - rtp_transport_( - srtp_required - ? rtc::WrapUnique( - new webrtc::SrtpTransport(rtcp_mux_required, content_name)) - : rtc::MakeUnique(rtcp_mux_required)), srtp_required_(srtp_required), media_channel_(media_channel), selected_candidate_pair_(nullptr) { RTC_DCHECK(worker_thread_ == rtc::Thread::Current()); + if (srtp_required) { + auto transport = + rtc::MakeUnique(rtcp_mux_required, content_name); + srtp_transport_ = transport.get(); + rtp_transport_ = std::move(transport); #if defined(ENABLE_EXTERNAL_AUTH) - srtp_filter_.EnableExternalAuth(); + srtp_transport_->EnableExternalAuth(); #endif + } else { + rtp_transport_ = rtc::MakeUnique(rtcp_mux_required); + srtp_transport_ = nullptr; + } rtp_transport_->SignalReadyToSend.connect( this, &BaseChannel::OnTransportReadyToSend); // TODO(zstein): RtpTransport::SignalPacketReceived will probably be replaced @@ -314,14 +318,17 @@ void BaseChannel::SetTransports_n( return; } - // When using DTLS-SRTP, we must reset the SrtpFilter every time the transport - // changes and wait until the DTLS handshake is complete to set the newly - // negotiated parameters. + // When using DTLS-SRTP, we must reset the SrtpTransport every time the + // DtlsTransport changes and wait until the DTLS handshake is complete to set + // the newly negotiated parameters. if (ShouldSetupDtlsSrtp_n()) { // Set |writable_| to false such that UpdateWritableState_w can set up // DTLS-SRTP when |writable_| becomes true again. writable_ = false; - srtp_filter_.ResetParams(); + dtls_active_ = false; + if (srtp_transport_) { + srtp_transport_->ResetParams(); + } } // If this BaseChannel doesn't require RTCP mux and we haven't fully @@ -377,8 +384,8 @@ void BaseChannel::SetTransport_n( } if (rtcp && new_dtls_transport) { - RTC_CHECK(!(ShouldSetupDtlsSrtp_n() && srtp_filter_.IsActive())) - << "Setting RTCP for DTLS/SRTP after SrtpFilter is active " + RTC_CHECK(!(ShouldSetupDtlsSrtp_n() && srtp_active())) + << "Setting RTCP for DTLS/SRTP after the DTLS is active " << "should never happen."; } @@ -529,8 +536,7 @@ bool BaseChannel::IsReadyToSendMedia_n() const { // and we have had some form of connectivity. return enabled() && IsReceiveContentDirection(remote_content_direction_) && IsSendContentDirection(local_content_direction_) && - was_ever_writable() && - (srtp_filter_.IsActive() || !ShouldSetupDtlsSrtp_n()); + was_ever_writable() && (srtp_active() || !ShouldSetupDtlsSrtp_n()); } bool BaseChannel::SendPacket(rtc::CopyOnWriteBuffer* packet, @@ -582,13 +588,16 @@ void BaseChannel::OnDtlsState(DtlsTransportInternal* transport, return; } - // Reset the srtp filter if it's not the CONNECTED state. For the CONNECTED + // Reset the SrtpTransport if it's not the CONNECTED state. For the CONNECTED // state, setting up DTLS-SRTP context is deferred to ChannelWritable_w to // cover other scenarios like the whole transport is writable (not just this // TransportChannel) or when TransportChannel is attached after DTLS is // negotiated. if (state != DTLS_TRANSPORT_CONNECTED) { - srtp_filter_.ResetParams(); + dtls_active_ = false; + if (srtp_transport_) { + srtp_transport_->ResetParams(); + } } } @@ -662,91 +671,30 @@ bool BaseChannel::SendPacket(bool rtcp, return false; } - rtc::PacketOptions updated_options; - updated_options = options; - // Protect if needed. - if (srtp_filter_.IsActive()) { - TRACE_EVENT0("webrtc", "SRTP Encode"); - bool res; - uint8_t* data = packet->data(); - int len = static_cast(packet->size()); - if (!rtcp) { - // If ENABLE_EXTERNAL_AUTH flag is on then packet authentication is not done - // inside libsrtp for a RTP packet. A external HMAC module will be writing - // a fake HMAC value. This is ONLY done for a RTP packet. - // Socket layer will update rtp sendtime extension header if present in - // packet with current time before updating the HMAC. -#if !defined(ENABLE_EXTERNAL_AUTH) - res = srtp_filter_.ProtectRtp( - data, len, static_cast(packet->capacity()), &len); -#else - if (!srtp_filter_.IsExternalAuthActive()) { - res = srtp_filter_.ProtectRtp( - data, len, static_cast(packet->capacity()), &len); - } else { - updated_options.packet_time_params.rtp_sendtime_extension_id = - rtp_abs_sendtime_extn_id_; - res = srtp_filter_.ProtectRtp( - data, len, static_cast(packet->capacity()), &len, - &updated_options.packet_time_params.srtp_packet_index); - // If protection succeeds, let's get auth params from srtp. - if (res) { - uint8_t* auth_key = NULL; - int key_len; - res = srtp_filter_.GetRtpAuthParams( - &auth_key, &key_len, - &updated_options.packet_time_params.srtp_auth_tag_len); - if (res) { - updated_options.packet_time_params.srtp_auth_key.resize(key_len); - updated_options.packet_time_params.srtp_auth_key.assign( - auth_key, auth_key + key_len); - } - } - } -#endif - if (!res) { - int seq_num = -1; - uint32_t ssrc = 0; - GetRtpSeqNum(data, len, &seq_num); - GetRtpSsrc(data, len, &ssrc); - LOG(LS_ERROR) << "Failed to protect " << content_name_ - << " RTP packet: size=" << len - << ", seqnum=" << seq_num << ", SSRC=" << ssrc; + if (!srtp_active()) { + if (srtp_required_) { + // The audio/video engines may attempt to send RTCP packets as soon as the + // streams are created, so don't treat this as an error for RTCP. + // See: https://bugs.chromium.org/p/webrtc/issues/detail?id=6809 + if (rtcp) { return false; } - } else { - res = srtp_filter_.ProtectRtcp(data, len, - static_cast(packet->capacity()), - &len); - if (!res) { - int type = -1; - GetRtcpType(data, len, &type); - LOG(LS_ERROR) << "Failed to protect " << content_name_ - << " RTCP packet: size=" << len << ", type=" << type; - return false; - } - } - - // Update the length of the packet now that we've added the auth tag. - packet->SetSize(len); - } else if (srtp_required_) { - // The audio/video engines may attempt to send RTCP packets as soon as the - // streams are created, so don't treat this as an error for RTCP. - // See: https://bugs.chromium.org/p/webrtc/issues/detail?id=6809 - if (rtcp) { + // However, there shouldn't be any RTP packets sent before SRTP is set up + // (and SetSend(true) is called). + LOG(LS_ERROR) << "Can't send outgoing RTP packet when SRTP is inactive" + << " and crypto is required"; + RTC_NOTREACHED(); return false; } - // However, there shouldn't be any RTP packets sent before SRTP is set up - // (and SetSend(true) is called). - LOG(LS_ERROR) << "Can't send outgoing RTP packet when SRTP is inactive" - << " and crypto is required"; - RTC_NOTREACHED(); - return false; + // Bon voyage. + return rtcp ? rtp_transport_->SendRtcpPacket(packet, options, PF_NORMAL) + : rtp_transport_->SendRtpPacket(packet, options, PF_NORMAL); } - + RTC_DCHECK(srtp_transport_); + RTC_DCHECK(srtp_transport_->IsActive()); // Bon voyage. - int flags = (secure() && secure_dtls()) ? PF_SRTP_BYPASS : PF_NORMAL; - return rtp_transport_->SendPacket(rtcp, packet, updated_options, flags); + return rtcp ? srtp_transport_->SendRtcpPacket(packet, options, PF_SRTP_BYPASS) + : srtp_transport_->SendRtpPacket(packet, options, PF_SRTP_BYPASS); } bool BaseChannel::HandlesPayloadType(int packet_type) const { @@ -761,37 +709,7 @@ void BaseChannel::OnPacketReceived(bool rtcp, signaling_thread()->Post(RTC_FROM_HERE, this, MSG_FIRSTPACKETRECEIVED); } - // Unprotect the packet, if needed. - if (srtp_filter_.IsActive()) { - TRACE_EVENT0("webrtc", "SRTP Decode"); - char* data = packet->data(); - int len = static_cast(packet->size()); - bool res; - if (!rtcp) { - res = srtp_filter_.UnprotectRtp(data, len, &len); - if (!res) { - int seq_num = -1; - uint32_t ssrc = 0; - GetRtpSeqNum(data, len, &seq_num); - GetRtpSsrc(data, len, &ssrc); - LOG(LS_ERROR) << "Failed to unprotect " << content_name_ - << " RTP packet: size=" << len << ", seqnum=" << seq_num - << ", SSRC=" << ssrc; - return; - } - } else { - res = srtp_filter_.UnprotectRtcp(data, len, &len); - if (!res) { - int type = -1; - GetRtcpType(data, len, &type); - LOG(LS_ERROR) << "Failed to unprotect " << content_name_ - << " RTCP packet: size=" << len << ", type=" << type; - return; - } - } - - packet->SetSize(len); - } else if (srtp_required_) { + if (!srtp_active() && srtp_required_) { // Our session description indicates that SRTP is required, but we got a // packet before our SRTP filter is active. This means either that // a) we got SRTP packets before we received the SDES keys, in which case @@ -995,43 +913,37 @@ bool BaseChannel::SetupDtlsSrtp_n(bool rtcp) { recv_key = &server_write_key; } - if (!srtp_filter_.IsActive()) { - if (rtcp) { - ret = srtp_filter_.SetRtcpParams(selected_crypto_suite, &(*send_key)[0], - static_cast(send_key->size()), - selected_crypto_suite, &(*recv_key)[0], - static_cast(recv_key->size())); + if (rtcp) { + if (!dtls_active()) { + RTC_DCHECK(srtp_transport_); + ret = srtp_transport_->SetRtcpParams( + selected_crypto_suite, &(*send_key)[0], + static_cast(send_key->size()), selected_crypto_suite, + &(*recv_key)[0], static_cast(recv_key->size())); } else { - ret = srtp_filter_.SetRtpParams(selected_crypto_suite, &(*send_key)[0], - static_cast(send_key->size()), - selected_crypto_suite, &(*recv_key)[0], - static_cast(recv_key->size())); + // RTCP doesn't need to call SetRtpParam because it is only used + // to make the updated encrypted RTP header extension IDs take effect. + ret = true; } } else { - if (rtcp) { - // RTCP doesn't need to be updated because UpdateRtpParams is only used - // to update the set of encrypted RTP header extension IDs. - ret = true; - } else { - ret = srtp_filter_.UpdateRtpParams( - selected_crypto_suite, - &(*send_key)[0], static_cast(send_key->size()), - selected_crypto_suite, - &(*recv_key)[0], static_cast(recv_key->size())); - } + RTC_DCHECK(srtp_transport_); + ret = srtp_transport_->SetRtpParams(selected_crypto_suite, &(*send_key)[0], + static_cast(send_key->size()), + selected_crypto_suite, &(*recv_key)[0], + static_cast(recv_key->size())); + dtls_active_ = ret; } if (!ret) { LOG(LS_WARNING) << "DTLS-SRTP key installation failed"; } else { - dtls_keyed_ = true; UpdateTransportOverhead(); } return ret; } void BaseChannel::MaybeSetupDtlsSrtp_n() { - if (srtp_filter_.IsActive()) { + if (dtls_active()) { return; } @@ -1039,6 +951,10 @@ void BaseChannel::MaybeSetupDtlsSrtp_n() { return; } + if (!srtp_transport_) { + EnableSrtpTransport_n(); + } + if (!SetupDtlsSrtp_n(false)) { SignalDtlsSrtpSetupFailure_n(false); return; @@ -1122,6 +1038,24 @@ bool BaseChannel::CheckSrtpConfig_n(const std::vector& cryptos, return true; } +void BaseChannel::EnableSrtpTransport_n() { + if (srtp_transport_ == nullptr) { + rtp_transport_->SignalReadyToSend.disconnect(this); + rtp_transport_->SignalPacketReceived.disconnect(this); + + auto transport = rtc::MakeUnique( + std::move(rtp_transport_), content_name_); + srtp_transport_ = transport.get(); + rtp_transport_ = std::move(transport); + + rtp_transport_->SignalReadyToSend.connect( + this, &BaseChannel::OnTransportReadyToSend); + rtp_transport_->SignalPacketReceived.connect( + this, &BaseChannel::OnPacketReceived); + LOG(LS_INFO) << "Wrapping RtpTransport in SrtpTransport."; + } +} + bool BaseChannel::SetSrtp_n(const std::vector& cryptos, ContentAction action, ContentSource src, @@ -1138,36 +1072,69 @@ bool BaseChannel::SetSrtp_n(const std::vector& cryptos, if (!ret) { return false; } - srtp_filter_.SetEncryptedHeaderExtensionIds(src, encrypted_extension_ids); + + // If SRTP was not required, but we're setting a description that uses SDES, + // we need to upgrade to an SrtpTransport. + if (!srtp_transport_ && !dtls && !cryptos.empty()) { + EnableSrtpTransport_n(); + } + if (srtp_transport_) { + srtp_transport_->SetEncryptedHeaderExtensionIds(src, + encrypted_extension_ids); + } switch (action) { case CA_OFFER: // If DTLS is already active on the channel, we could be renegotiating // here. We don't update the srtp filter. if (!dtls) { - ret = srtp_filter_.SetOffer(cryptos, src); + ret = sdes_negotiator_.SetOffer(cryptos, src); } break; case CA_PRANSWER: // If we're doing DTLS-SRTP, we don't want to update the filter // with an answer, because we already have SRTP parameters. if (!dtls) { - ret = srtp_filter_.SetProvisionalAnswer(cryptos, src); + ret = sdes_negotiator_.SetProvisionalAnswer(cryptos, src); } break; case CA_ANSWER: // If we're doing DTLS-SRTP, we don't want to update the filter // with an answer, because we already have SRTP parameters. if (!dtls) { - ret = srtp_filter_.SetAnswer(cryptos, src); + ret = sdes_negotiator_.SetAnswer(cryptos, src); } break; default: break; } + + // If setting an SDES answer succeeded, apply the negotiated parameters + // to the SRTP transport. + if ((action == CA_PRANSWER || action == CA_ANSWER) && !dtls && ret) { + if (sdes_negotiator_.send_cipher_suite() && + sdes_negotiator_.recv_cipher_suite()) { + ret = srtp_transport_->SetRtpParams( + *(sdes_negotiator_.send_cipher_suite()), + sdes_negotiator_.send_key().data(), + static_cast(sdes_negotiator_.send_key().size()), + *(sdes_negotiator_.recv_cipher_suite()), + sdes_negotiator_.recv_key().data(), + static_cast(sdes_negotiator_.recv_key().size())); + } else { + LOG(LS_INFO) << "No crypto keys are provided for SDES."; + if (action == CA_ANSWER && srtp_transport_) { + // Explicitly reset the |srtp_transport_| if no crypto param is + // provided in the answer. No need to call |ResetParams()| for + // |sdes_negotiator_| because it resets the params inside |SetAnswer|. + srtp_transport_->ResetParams(); + } + } + } + // Only update SRTP filter if using DTLS. SDES is handled internally // by the SRTP filter. // TODO(jbauch): Only update if encrypted extension ids have changed. - if (ret && dtls_keyed_ && rtp_dtls_transport_ && + if (ret && dtls_active() && rtp_dtls_transport_ && rtp_dtls_transport_->dtls_state() == DTLS_TRANSPORT_CONNECTED) { bool rtcp = false; ret = SetupDtlsSrtp_n(rtcp); @@ -1211,7 +1178,6 @@ bool BaseChannel::SetRtcpMux_n(bool enable, transport_name_.empty() ? rtp_transport_->rtp_packet_transport()->debug_name() : transport_name_; - ; LOG(LS_INFO) << "Enabling rtcp-mux for " << content_name() << "; no longer need RTCP transport for " << debug_name; if (rtp_transport_->rtcp_packet_transport()) { @@ -1440,7 +1406,13 @@ void BaseChannel::MaybeCacheRtpAbsSendTimeHeaderExtension_w( void BaseChannel::CacheRtpAbsSendTimeHeaderExtension_n( int rtp_abs_sendtime_extn_id) { - rtp_abs_sendtime_extn_id_ = rtp_abs_sendtime_extn_id; + if (srtp_transport_) { + srtp_transport_->CacheRtpAbsSendTimeHeaderExtension( + rtp_abs_sendtime_extn_id); + } else { + LOG(LS_WARNING) << "Trying to cache the Absolute Send Time extension id " + "but the SRTP is not active."; + } } void BaseChannel::OnMessage(rtc::Message *pmsg) { @@ -1724,9 +1696,9 @@ int BaseChannel::GetTransportOverheadPerPacket() const { ? kTcpOverhaed : kUdpOverhaed; - if (secure()) { + if (sdes_active()) { int srtp_overhead = 0; - if (srtp_filter_.GetSrtpOverhead(&srtp_overhead)) + if (srtp_transport_->GetSrtpOverhead(&srtp_overhead)) transport_overhead_per_packet += srtp_overhead; } diff --git a/webrtc/pc/channel.h b/webrtc/pc/channel.h index c6dc29dd08..b95bd529b0 100644 --- a/webrtc/pc/channel.h +++ b/webrtc/pc/channel.h @@ -33,7 +33,6 @@ #include "webrtc/pc/mediamonitor.h" #include "webrtc/pc/mediasession.h" #include "webrtc/pc/rtcpmuxfilter.h" -#include "webrtc/pc/rtptransportinternal.h" #include "webrtc/pc/srtpfilter.h" #include "webrtc/rtc_base/asyncinvoker.h" #include "webrtc/rtc_base/asyncudpsocket.h" @@ -44,6 +43,8 @@ namespace webrtc { class AudioSinkInterface; +class RtpTransportInternal; +class SrtpTransport; } // namespace webrtc namespace cricket { @@ -99,12 +100,12 @@ class BaseChannel const std::string& transport_name() const { return transport_name_; } bool enabled() const { return enabled_; } - // This function returns true if we are using SRTP. - bool secure() const { return srtp_filter_.IsActive(); } - // The following function returns true if we are using - // DTLS-based keying. If you turned off SRTP later, however - // you could have secure() == false and dtls_secure() == true. - bool secure_dtls() const { return dtls_keyed_; } + // This function returns true if we are using SDES. + bool sdes_active() const { return sdes_negotiator_.IsActive(); } + // The following function returns true if we are using DTLS-based keying. + bool dtls_active() const { return dtls_active_; } + // This function returns true if using SRTP (DTLS-based keying or SDES). + bool srtp_active() const { return sdes_active() || dtls_active(); } bool writable() const { return writable_; } @@ -188,8 +189,6 @@ class BaseChannel override; int SetOption_n(SocketType type, rtc::Socket::Option o, int val); - SrtpFilter* srtp_filter() { return &srtp_filter_; } - virtual cricket::MediaType media_type() = 0; // This function returns true if we require SRTP for call setup. @@ -378,6 +377,8 @@ class BaseChannel void CacheRtpAbsSendTimeHeaderExtension_n(int rtp_abs_sendtime_extn_id); int GetTransportOverheadPerPacket() const; void UpdateTransportOverhead(); + // Wraps the existing RtpTransport in an SrtpTransport. + void EnableSrtpTransport_n(); rtc::Thread* const worker_thread_; rtc::Thread* const network_thread_; @@ -398,16 +399,16 @@ class BaseChannel DtlsTransportInternal* rtp_dtls_transport_ = nullptr; DtlsTransportInternal* rtcp_dtls_transport_ = nullptr; std::unique_ptr rtp_transport_; + webrtc::SrtpTransport* srtp_transport_ = nullptr; std::vector > socket_options_; std::vector > rtcp_socket_options_; - SrtpFilter srtp_filter_; + SrtpFilter sdes_negotiator_; RtcpMuxFilter rtcp_mux_filter_; bool writable_ = false; bool was_ever_writable_ = false; bool has_received_packet_ = false; - bool dtls_keyed_ = false; + bool dtls_active_ = false; const bool srtp_required_ = true; - int rtp_abs_sendtime_extn_id_ = -1; // MediaChannel related members that should be accessed from the worker // thread. diff --git a/webrtc/pc/channel_unittest.cc b/webrtc/pc/channel_unittest.cc index 3f30201874..ae34a12de4 100644 --- a/webrtc/pc/channel_unittest.cc +++ b/webrtc/pc/channel_unittest.cc @@ -581,7 +581,7 @@ class ChannelTest : public testing::Test, public sigslot::has_slots<> { // Basic sanity check. void TestInit() { CreateChannels(0, 0); - EXPECT_FALSE(channel1_->secure()); + EXPECT_FALSE(channel1_->srtp_active()); EXPECT_FALSE(media_channel1_->sending()); if (verify_playout_) { EXPECT_FALSE(media_channel1_->playout()); @@ -896,8 +896,8 @@ class ChannelTest : public testing::Test, public sigslot::has_slots<> { EXPECT_TRUE(channel2_->SetRemoteContent(&content4, CA_ANSWER, NULL)); EXPECT_EQ(0u, media_channel2_->recv_streams().size()); - EXPECT_TRUE(channel1_->secure()); - EXPECT_TRUE(channel2_->secure()); + EXPECT_TRUE(channel1_->srtp_active()); + EXPECT_TRUE(channel2_->srtp_active()); SendCustomRtp2(kSsrc2, 0); WaitForThreads(); EXPECT_TRUE(CheckCustomRtp1(kSsrc2, 0)); @@ -1253,14 +1253,14 @@ class ChannelTest : public testing::Test, public sigslot::has_slots<> { // Test setting up a call. void TestCallSetup() { CreateChannels(0, 0); - EXPECT_FALSE(channel1_->secure()); + EXPECT_FALSE(channel1_->srtp_active()); EXPECT_TRUE(SendInitiate()); if (verify_playout_) { EXPECT_TRUE(media_channel1_->playout()); } EXPECT_FALSE(media_channel1_->sending()); EXPECT_TRUE(SendAccept()); - EXPECT_FALSE(channel1_->secure()); + EXPECT_FALSE(channel1_->srtp_active()); EXPECT_TRUE(media_channel1_->sending()); EXPECT_EQ(1U, media_channel1_->codecs().size()); if (verify_playout_) { @@ -1535,17 +1535,17 @@ class ChannelTest : public testing::Test, public sigslot::has_slots<> { bool dtls1 = !!(flags1_in & DTLS); bool dtls2 = !!(flags2_in & DTLS); CreateChannels(flags1, flags2); - EXPECT_FALSE(channel1_->secure()); - EXPECT_FALSE(channel2_->secure()); + EXPECT_FALSE(channel1_->srtp_active()); + EXPECT_FALSE(channel2_->srtp_active()); EXPECT_TRUE(SendInitiate()); WaitForThreads(); EXPECT_TRUE(channel1_->writable()); EXPECT_TRUE(channel2_->writable()); EXPECT_TRUE(SendAccept()); - EXPECT_TRUE(channel1_->secure()); - EXPECT_TRUE(channel2_->secure()); - EXPECT_EQ(dtls1 && dtls2, channel1_->secure_dtls()); - EXPECT_EQ(dtls1 && dtls2, channel2_->secure_dtls()); + EXPECT_TRUE(channel1_->srtp_active()); + EXPECT_TRUE(channel2_->srtp_active()); + EXPECT_EQ(dtls1 && dtls2, channel1_->dtls_active()); + EXPECT_EQ(dtls1 && dtls2, channel2_->dtls_active()); SendRtp1(); SendRtp2(); SendRtcp1(); @@ -1564,12 +1564,12 @@ class ChannelTest : public testing::Test, public sigslot::has_slots<> { // Test that we properly handling SRTP negotiating down to RTP. void SendSrtpToRtp() { CreateChannels(SECURE, 0); - EXPECT_FALSE(channel1_->secure()); - EXPECT_FALSE(channel2_->secure()); + EXPECT_FALSE(channel1_->srtp_active()); + EXPECT_FALSE(channel2_->srtp_active()); EXPECT_TRUE(SendInitiate()); EXPECT_TRUE(SendAccept()); - EXPECT_FALSE(channel1_->secure()); - EXPECT_FALSE(channel2_->secure()); + EXPECT_FALSE(channel1_->srtp_active()); + EXPECT_FALSE(channel2_->srtp_active()); SendRtp1(); SendRtp2(); SendRtcp1(); @@ -1594,8 +1594,8 @@ class ChannelTest : public testing::Test, public sigslot::has_slots<> { SSRC_MUX | RTCP_MUX | SECURE); EXPECT_TRUE(SendOffer()); EXPECT_TRUE(SendProvisionalAnswer()); - EXPECT_TRUE(channel1_->secure()); - EXPECT_TRUE(channel2_->secure()); + EXPECT_TRUE(channel1_->srtp_active()); + EXPECT_TRUE(channel2_->srtp_active()); EXPECT_TRUE(channel1_->NeedsRtcpTransport()); EXPECT_TRUE(channel2_->NeedsRtcpTransport()); WaitForThreads(); // Wait for 'sending' flag go through network thread. @@ -1620,8 +1620,8 @@ class ChannelTest : public testing::Test, public sigslot::has_slots<> { EXPECT_FALSE(channel2_->NeedsRtcpTransport()); EXPECT_EQ(1, rtcp_mux_activated_callbacks1_); EXPECT_EQ(1, rtcp_mux_activated_callbacks2_); - EXPECT_TRUE(channel1_->secure()); - EXPECT_TRUE(channel2_->secure()); + EXPECT_TRUE(channel1_->srtp_active()); + EXPECT_TRUE(channel2_->srtp_active()); SendCustomRtcp1(kSsrc1); SendCustomRtp1(kSsrc1, ++sequence_number1_1); SendCustomRtcp2(kSsrc2); diff --git a/webrtc/pc/rtptransport.cc b/webrtc/pc/rtptransport.cc index ac57eb887b..b26783260f 100644 --- a/webrtc/pc/rtptransport.cc +++ b/webrtc/pc/rtptransport.cc @@ -76,6 +76,18 @@ bool RtpTransport::IsWritable(bool rtcp) const { return transport && transport->writable(); } +bool RtpTransport::SendRtpPacket(rtc::CopyOnWriteBuffer* packet, + const rtc::PacketOptions& options, + int flags) { + return SendPacket(false, packet, options, flags); +} + +bool RtpTransport::SendRtcpPacket(rtc::CopyOnWriteBuffer* packet, + const rtc::PacketOptions& options, + int flags) { + return SendPacket(true, packet, options, flags); +} + bool RtpTransport::SendPacket(bool rtcp, rtc::CopyOnWriteBuffer* packet, const rtc::PacketOptions& options, diff --git a/webrtc/pc/rtptransport.h b/webrtc/pc/rtptransport.h index 8b408cabfa..94c5877142 100644 --- a/webrtc/pc/rtptransport.h +++ b/webrtc/pc/rtptransport.h @@ -56,10 +56,13 @@ class RtpTransport : public RtpTransportInternal { bool IsWritable(bool rtcp) const override; - bool SendPacket(bool rtcp, - rtc::CopyOnWriteBuffer* packet, - const rtc::PacketOptions& options, - int flags) override; + bool SendRtpPacket(rtc::CopyOnWriteBuffer* packet, + const rtc::PacketOptions& options, + int flags) override; + + bool SendRtcpPacket(rtc::CopyOnWriteBuffer* packet, + const rtc::PacketOptions& options, + int flags) override; bool HandlesPayloadType(int payload_type) const override; @@ -80,6 +83,11 @@ class RtpTransport : public RtpTransportInternal { void MaybeSignalReadyToSend(); + bool SendPacket(bool rtcp, + rtc::CopyOnWriteBuffer* packet, + const rtc::PacketOptions& options, + int flags); + void OnReadPacket(rtc::PacketTransportInternal* transport, const char* data, size_t len, diff --git a/webrtc/pc/rtptransportinternal.h b/webrtc/pc/rtptransportinternal.h index fd94d8e8f0..684afc4e67 100644 --- a/webrtc/pc/rtptransportinternal.h +++ b/webrtc/pc/rtptransportinternal.h @@ -54,10 +54,13 @@ class RtpTransportInternal : public RtpTransportInterface, virtual bool IsWritable(bool rtcp) const = 0; - virtual bool SendPacket(bool rtcp, - rtc::CopyOnWriteBuffer* packet, - const rtc::PacketOptions& options, - int flags) = 0; + virtual bool SendRtpPacket(rtc::CopyOnWriteBuffer* packet, + const rtc::PacketOptions& options, + int flags) = 0; + + virtual bool SendRtcpPacket(rtc::CopyOnWriteBuffer* packet, + const rtc::PacketOptions& options, + int flags) = 0; virtual bool HandlesPayloadType(int payload_type) const = 0; diff --git a/webrtc/pc/srtpfilter.cc b/webrtc/pc/srtpfilter.cc index dde84bc14e..560dac8f82 100644 --- a/webrtc/pc/srtpfilter.cc +++ b/webrtc/pc/srtpfilter.cc @@ -17,7 +17,6 @@ #include "webrtc/media/base/rtputils.h" #include "webrtc/pc/srtpsession.h" #include "webrtc/rtc_base/base64.h" -#include "webrtc/rtc_base/buffer.h" #include "webrtc/rtc_base/byteorder.h" #include "webrtc/rtc_base/checks.h" #include "webrtc/rtc_base/logging.h" @@ -63,209 +62,6 @@ bool SrtpFilter::SetProvisionalAnswer( return DoSetAnswer(answer_params, source, false); } -bool SrtpFilter::SetRtpParams(int send_cs, - const uint8_t* send_key, - int send_key_len, - int recv_cs, - const uint8_t* recv_key, - int recv_key_len) { - if (IsActive()) { - LOG(LS_ERROR) << "Tried to set SRTP Params when filter already active"; - return false; - } - CreateSrtpSessions(); - send_session_->SetEncryptedHeaderExtensionIds( - send_encrypted_header_extension_ids_); - if (!send_session_->SetSend(send_cs, send_key, send_key_len)) { - return false; - } - - recv_session_->SetEncryptedHeaderExtensionIds( - recv_encrypted_header_extension_ids_); - if (!recv_session_->SetRecv(recv_cs, recv_key, recv_key_len)) { - return false; - } - - state_ = ST_ACTIVE; - - LOG(LS_INFO) << "SRTP activated with negotiated parameters:" - << " send cipher_suite " << send_cs - << " recv cipher_suite " << recv_cs; - return true; -} - -bool SrtpFilter::UpdateRtpParams(int send_cs, - const uint8_t* send_key, - int send_key_len, - int recv_cs, - const uint8_t* recv_key, - int recv_key_len) { - if (!IsActive()) { - LOG(LS_ERROR) << "Tried to update SRTP Params when filter is not active"; - return false; - } - send_session_->SetEncryptedHeaderExtensionIds( - send_encrypted_header_extension_ids_); - if (!send_session_->UpdateSend(send_cs, send_key, send_key_len)) { - return false; - } - - recv_session_->SetEncryptedHeaderExtensionIds( - recv_encrypted_header_extension_ids_); - if (!recv_session_->UpdateRecv(recv_cs, recv_key, recv_key_len)) { - return false; - } - - LOG(LS_INFO) << "SRTP updated with negotiated parameters:" - << " send cipher_suite " << send_cs - << " recv cipher_suite " << recv_cs; - return true; -} - -// This function is provided separately because DTLS-SRTP behaves -// differently in RTP/RTCP mux and non-mux modes. -// -// - In the non-muxed case, RTP and RTCP are keyed with different -// keys (from different DTLS handshakes), and so we need a new -// SrtpSession. -// - In the muxed case, they are keyed with the same keys, so -// this function is not needed -bool SrtpFilter::SetRtcpParams(int send_cs, - const uint8_t* send_key, - int send_key_len, - int recv_cs, - const uint8_t* recv_key, - int recv_key_len) { - // This can only be called once, but can be safely called after - // SetRtpParams - if (send_rtcp_session_ || recv_rtcp_session_) { - LOG(LS_ERROR) << "Tried to set SRTCP Params when filter already active"; - return false; - } - - send_rtcp_session_.reset(new SrtpSession()); - if (!send_rtcp_session_->SetRecv(send_cs, send_key, send_key_len)) { - return false; - } - - recv_rtcp_session_.reset(new SrtpSession()); - if (!recv_rtcp_session_->SetRecv(recv_cs, recv_key, recv_key_len)) { - return false; - } - - LOG(LS_INFO) << "SRTCP activated with negotiated parameters:" - << " send cipher_suite " << send_cs - << " recv cipher_suite " << recv_cs; - - return true; -} - -bool SrtpFilter::ProtectRtp(void* p, int in_len, int max_len, int* out_len) { - if (!IsActive()) { - LOG(LS_WARNING) << "Failed to ProtectRtp: SRTP not active"; - return false; - } - RTC_CHECK(send_session_); - return send_session_->ProtectRtp(p, in_len, max_len, out_len); -} - -bool SrtpFilter::ProtectRtp(void* p, - int in_len, - int max_len, - int* out_len, - int64_t* index) { - if (!IsActive()) { - LOG(LS_WARNING) << "Failed to ProtectRtp: SRTP not active"; - return false; - } - RTC_CHECK(send_session_); - return send_session_->ProtectRtp(p, in_len, max_len, out_len, index); -} - -bool SrtpFilter::ProtectRtcp(void* p, int in_len, int max_len, int* out_len) { - if (!IsActive()) { - LOG(LS_WARNING) << "Failed to ProtectRtcp: SRTP not active"; - return false; - } - if (send_rtcp_session_) { - return send_rtcp_session_->ProtectRtcp(p, in_len, max_len, out_len); - } else { - RTC_CHECK(send_session_); - return send_session_->ProtectRtcp(p, in_len, max_len, out_len); - } -} - -bool SrtpFilter::UnprotectRtp(void* p, int in_len, int* out_len) { - if (!IsActive()) { - LOG(LS_WARNING) << "Failed to UnprotectRtp: SRTP not active"; - return false; - } - RTC_CHECK(recv_session_); - return recv_session_->UnprotectRtp(p, in_len, out_len); -} - -bool SrtpFilter::UnprotectRtcp(void* p, int in_len, int* out_len) { - if (!IsActive()) { - LOG(LS_WARNING) << "Failed to UnprotectRtcp: SRTP not active"; - return false; - } - if (recv_rtcp_session_) { - return recv_rtcp_session_->UnprotectRtcp(p, in_len, out_len); - } else { - RTC_CHECK(recv_session_); - return recv_session_->UnprotectRtcp(p, in_len, out_len); - } -} - -bool SrtpFilter::GetRtpAuthParams(uint8_t** key, int* key_len, int* tag_len) { - if (!IsActive()) { - LOG(LS_WARNING) << "Failed to GetRtpAuthParams: SRTP not active"; - return false; - } - - RTC_CHECK(send_session_); - return send_session_->GetRtpAuthParams(key, key_len, tag_len); -} - -bool SrtpFilter::GetSrtpOverhead(int* srtp_overhead) const { - if (!IsActive()) { - LOG(LS_WARNING) << "Failed to GetSrtpOverhead: SRTP not active"; - return false; - } - - RTC_CHECK(send_session_); - *srtp_overhead = send_session_->GetSrtpOverhead(); - return true; -} - -void SrtpFilter::EnableExternalAuth() { - RTC_DCHECK(!IsActive()); - external_auth_enabled_ = true; -} - -bool SrtpFilter::IsExternalAuthEnabled() const { - return external_auth_enabled_; -} - -bool SrtpFilter::IsExternalAuthActive() const { - if (!IsActive()) { - LOG(LS_WARNING) << "Failed to check IsExternalAuthActive: SRTP not active"; - return false; - } - - RTC_CHECK(send_session_); - return send_session_->IsExternalAuthActive(); -} - -void SrtpFilter::SetEncryptedHeaderExtensionIds(ContentSource source, - const std::vector& extension_ids) { - if (source == CS_LOCAL) { - recv_encrypted_header_extension_ids_ = extension_ids; - } else { - send_encrypted_header_extension_ids_ = extension_ids; - } -} - bool SrtpFilter::ExpectOffer(ContentSource source) { return ((state_ == ST_INIT) || (state_ == ST_ACTIVE) || @@ -323,13 +119,16 @@ bool SrtpFilter::DoSetAnswer(const std::vector& answer_params, CryptoParams selected_params; if (!NegotiateParams(answer_params, &selected_params)) return false; - const CryptoParams& send_params = + + const CryptoParams& new_send_params = (source == CS_REMOTE) ? selected_params : answer_params[0]; - const CryptoParams& recv_params = + const CryptoParams& new_recv_params = (source == CS_REMOTE) ? answer_params[0] : selected_params; - if (!ApplyParams(send_params, recv_params)) { + if (!ApplySendParams(new_send_params) || !ApplyRecvParams(new_recv_params)) { return false; } + applied_send_params_ = new_send_params; + applied_recv_params_ = new_recv_params; if (final) { offer_params_.clear(); @@ -341,17 +140,6 @@ bool SrtpFilter::DoSetAnswer(const std::vector& answer_params, return true; } -void SrtpFilter::CreateSrtpSessions() { - send_session_.reset(new SrtpSession()); - applied_send_params_ = CryptoParams(); - recv_session_.reset(new SrtpSession()); - applied_recv_params_ = CryptoParams(); - - if (external_auth_enabled_) { - send_session_->EnableExternalAuth(); - } -} - bool SrtpFilter::NegotiateParams(const std::vector& answer_params, CryptoParams* selected_params) { // We're processing an accept. We should have exactly one set of params, @@ -379,85 +167,76 @@ bool SrtpFilter::NegotiateParams(const std::vector& answer_params, return ret; } -bool SrtpFilter::ApplyParams(const CryptoParams& send_params, - const CryptoParams& recv_params) { - // TODO(jiayl): Split this method to apply send and receive CryptoParams - // independently, so that we can skip one method when either send or receive - // CryptoParams is unchanged. +bool SrtpFilter::ResetParams() { + offer_params_.clear(); + applied_send_params_ = CryptoParams(); + applied_recv_params_ = CryptoParams(); + send_cipher_suite_ = rtc::Optional(); + recv_cipher_suite_ = rtc::Optional(); + send_key_.Clear(); + recv_key_.Clear(); + state_ = ST_INIT; + return true; +} + +bool SrtpFilter::ApplySendParams(const CryptoParams& send_params) { if (applied_send_params_.cipher_suite == send_params.cipher_suite && - applied_send_params_.key_params == send_params.key_params && - applied_recv_params_.cipher_suite == recv_params.cipher_suite && - applied_recv_params_.key_params == recv_params.key_params) { - LOG(LS_INFO) << "Applying the same SRTP parameters again. No-op."; + applied_send_params_.key_params == send_params.key_params) { + LOG(LS_INFO) << "Applying the same SRTP send parameters again. No-op."; // We do not want to reset the ROC if the keys are the same. So just return. return true; } - int send_suite = rtc::SrtpCryptoSuiteFromName(send_params.cipher_suite); - int recv_suite = rtc::SrtpCryptoSuiteFromName(recv_params.cipher_suite); - if (send_suite == rtc::SRTP_INVALID_CRYPTO_SUITE || - recv_suite == rtc::SRTP_INVALID_CRYPTO_SUITE) { + send_cipher_suite_ = rtc::Optional( + rtc::SrtpCryptoSuiteFromName(send_params.cipher_suite)); + if (send_cipher_suite_ == rtc::SRTP_INVALID_CRYPTO_SUITE) { LOG(LS_WARNING) << "Unknown crypto suite(s) received:" - << " send cipher_suite " << send_params.cipher_suite - << " recv cipher_suite " << recv_params.cipher_suite; + << " send cipher_suite " << send_params.cipher_suite; return false; } int send_key_len, send_salt_len; - int recv_key_len, recv_salt_len; - if (!rtc::GetSrtpKeyAndSaltLengths(send_suite, &send_key_len, - &send_salt_len) || - !rtc::GetSrtpKeyAndSaltLengths(recv_suite, &recv_key_len, - &recv_salt_len)) { + if (!rtc::GetSrtpKeyAndSaltLengths(*send_cipher_suite_, &send_key_len, + &send_salt_len)) { LOG(LS_WARNING) << "Could not get lengths for crypto suite(s):" - << " send cipher_suite " << send_params.cipher_suite + << " send cipher_suite " << send_params.cipher_suite; + return false; + } + + send_key_ = rtc::Buffer(send_key_len + send_salt_len); + return ParseKeyParams(send_params.key_params, send_key_.data(), + send_key_.size()); +} + +bool SrtpFilter::ApplyRecvParams(const CryptoParams& recv_params) { + if (applied_recv_params_.cipher_suite == recv_params.cipher_suite && + applied_recv_params_.key_params == recv_params.key_params) { + LOG(LS_INFO) << "Applying the same SRTP recv parameters again. No-op."; + + // We do not want to reset the ROC if the keys are the same. So just return. + return true; + } + + recv_cipher_suite_ = rtc::Optional( + rtc::SrtpCryptoSuiteFromName(recv_params.cipher_suite)); + if (recv_cipher_suite_ == rtc::SRTP_INVALID_CRYPTO_SUITE) { + LOG(LS_WARNING) << "Unknown crypto suite(s) received:" << " recv cipher_suite " << recv_params.cipher_suite; return false; } - // TODO(juberti): Zero these buffers after use. - bool ret; - rtc::Buffer send_key(send_key_len + send_salt_len); - rtc::Buffer recv_key(recv_key_len + recv_salt_len); - ret = (ParseKeyParams(send_params.key_params, send_key.data(), - send_key.size()) && - ParseKeyParams(recv_params.key_params, recv_key.data(), - recv_key.size())); - if (ret) { - CreateSrtpSessions(); - send_session_->SetEncryptedHeaderExtensionIds( - send_encrypted_header_extension_ids_); - recv_session_->SetEncryptedHeaderExtensionIds( - recv_encrypted_header_extension_ids_); - ret = (send_session_->SetSend( - rtc::SrtpCryptoSuiteFromName(send_params.cipher_suite), - send_key.data(), send_key.size()) && - recv_session_->SetRecv( - rtc::SrtpCryptoSuiteFromName(recv_params.cipher_suite), - recv_key.data(), recv_key.size())); + int recv_key_len, recv_salt_len; + if (!rtc::GetSrtpKeyAndSaltLengths(*recv_cipher_suite_, &recv_key_len, + &recv_salt_len)) { + LOG(LS_WARNING) << "Could not get lengths for crypto suite(s):" + << " recv cipher_suite " << recv_params.cipher_suite; + return false; } - if (ret) { - LOG(LS_INFO) << "SRTP activated with negotiated parameters:" - << " send cipher_suite " << send_params.cipher_suite - << " recv cipher_suite " << recv_params.cipher_suite; - applied_send_params_ = send_params; - applied_recv_params_ = recv_params; - } else { - LOG(LS_WARNING) << "Failed to apply negotiated SRTP parameters"; - } - return ret; -} -bool SrtpFilter::ResetParams() { - offer_params_.clear(); - state_ = ST_INIT; - send_session_ = nullptr; - recv_session_ = nullptr; - send_rtcp_session_ = nullptr; - recv_rtcp_session_ = nullptr; - LOG(LS_INFO) << "SRTP reset to init state"; - return true; + recv_key_ = rtc::Buffer(recv_key_len + recv_salt_len); + return ParseKeyParams(recv_params.key_params, recv_key_.data(), + recv_key_.size()); } bool SrtpFilter::ParseKeyParams(const std::string& key_params, @@ -472,8 +251,9 @@ bool SrtpFilter::ParseKeyParams(const std::string& key_params, // Fail if base64 decode fails, or the key is the wrong size. std::string key_b64(key_params.substr(7)), key_str; - if (!rtc::Base64::Decode(key_b64, rtc::Base64::DO_STRICT, - &key_str, nullptr) || key_str.size() != len) { + if (!rtc::Base64::Decode(key_b64, rtc::Base64::DO_STRICT, &key_str, + nullptr) || + key_str.size() != len) { return false; } diff --git a/webrtc/pc/srtpfilter.h b/webrtc/pc/srtpfilter.h index 15fdae9582..619aaa37d5 100644 --- a/webrtc/pc/srtpfilter.h +++ b/webrtc/pc/srtpfilter.h @@ -20,8 +20,10 @@ #include "webrtc/media/base/cryptoparams.h" #include "webrtc/p2p/base/sessiondescription.h" #include "webrtc/rtc_base/basictypes.h" +#include "webrtc/rtc_base/buffer.h" #include "webrtc/rtc_base/constructormagic.h" #include "webrtc/rtc_base/criticalsection.h" +#include "webrtc/rtc_base/optional.h" #include "webrtc/rtc_base/sslstreamadapter.h" #include "webrtc/rtc_base/thread_checker.h" @@ -31,15 +33,10 @@ struct srtp_ctx_t_; namespace cricket { -class SrtpSession; - void ShutdownSrtp(); -// Class to transform SRTP to/from RTP. -// Initialize by calling SetSend with the local security params, then call -// SetRecv once the remote security params are received. At that point -// Protect/UnprotectRt(c)p can be called to encrypt/decrypt data. -// TODO: Figure out concurrency policy for SrtpFilter. +// A helper class used to negotiate SDES crypto params. +// TODO(zhihuang): Find a better name for this class, like "SdesNegotiator". class SrtpFilter { public: enum Mode { @@ -76,85 +73,38 @@ class SrtpFilter { bool SetAnswer(const std::vector& answer_params, ContentSource source); - // Set the header extension ids that should be encrypted for the given source. - void SetEncryptedHeaderExtensionIds(ContentSource source, - const std::vector& extension_ids); - - // Just set up both sets of keys directly. - // Used with DTLS-SRTP. - bool SetRtpParams(int send_cs, - const uint8_t* send_key, - int send_key_len, - int recv_cs, - const uint8_t* recv_key, - int recv_key_len); - bool UpdateRtpParams(int send_cs, - const uint8_t* send_key, - int send_key_len, - int recv_cs, - const uint8_t* recv_key, - int recv_key_len); - bool SetRtcpParams(int send_cs, - const uint8_t* send_key, - int send_key_len, - int recv_cs, - const uint8_t* recv_key, - int recv_key_len); - - // Encrypts/signs an individual RTP/RTCP packet, in-place. - // If an HMAC is used, this will increase the packet size. - bool ProtectRtp(void* data, int in_len, int max_len, int* out_len); - // Overloaded version, outputs packet index. - bool ProtectRtp(void* data, - int in_len, - int max_len, - int* out_len, - int64_t* index); - bool ProtectRtcp(void* data, int in_len, int max_len, int* out_len); - // Decrypts/verifies an invidiual RTP/RTCP packet. - // If an HMAC is used, this will decrease the packet size. - bool UnprotectRtp(void* data, int in_len, int* out_len); - bool UnprotectRtcp(void* data, int in_len, int* out_len); - - // Returns rtp auth params from srtp context. - bool GetRtpAuthParams(uint8_t** key, int* key_len, int* tag_len); - - // Returns srtp overhead for rtp packets. - bool GetSrtpOverhead(int* srtp_overhead) const; - - // If external auth is enabled, SRTP will write a dummy auth tag that then - // later must get replaced before the packet is sent out. Only supported for - // non-GCM cipher suites and can be checked through "IsExternalAuthActive" - // if it is actually used. This method is only valid before the RTP params - // have been set. - void EnableExternalAuth(); - bool IsExternalAuthEnabled() const; - - // A SRTP filter supports external creation of the auth tag if a non-GCM - // cipher is used. This method is only valid after the RTP params have - // been set. - bool IsExternalAuthActive() const; - bool ResetParams(); + rtc::Optional send_cipher_suite() { return send_cipher_suite_; } + rtc::Optional recv_cipher_suite() { return recv_cipher_suite_; } + + const rtc::Buffer& send_key() { return send_key_; } + const rtc::Buffer& recv_key() { return recv_key_; } + protected: bool ExpectOffer(ContentSource source); + bool StoreParams(const std::vector& params, ContentSource source); + bool ExpectAnswer(ContentSource source); + bool DoSetAnswer(const std::vector& answer_params, - ContentSource source, - bool final); - void CreateSrtpSessions(); + ContentSource source, + bool final); + bool NegotiateParams(const std::vector& answer_params, CryptoParams* selected_params); - bool ApplyParams(const CryptoParams& send_params, - const CryptoParams& recv_params); + + private: + bool ApplySendParams(const CryptoParams& send_params); + + bool ApplyRecvParams(const CryptoParams& recv_params); + static bool ParseKeyParams(const std::string& params, uint8_t* key, size_t len); - private: enum State { ST_INIT, // SRTP filter unused. ST_SENTOFFER, // Offer with SRTP parameters sent. @@ -179,16 +129,13 @@ class SrtpFilter { ST_RECEIVEDPRANSWER }; State state_ = ST_INIT; - bool external_auth_enabled_ = false; std::vector offer_params_; - std::unique_ptr send_session_; - std::unique_ptr recv_session_; - std::unique_ptr send_rtcp_session_; - std::unique_ptr recv_rtcp_session_; CryptoParams applied_send_params_; CryptoParams applied_recv_params_; - std::vector send_encrypted_header_extension_ids_; - std::vector recv_encrypted_header_extension_ids_; + rtc::Optional send_cipher_suite_; + rtc::Optional recv_cipher_suite_; + rtc::Buffer send_key_; + rtc::Buffer recv_key_; }; } // namespace cricket diff --git a/webrtc/pc/srtpfilter_unittest.cc b/webrtc/pc/srtpfilter_unittest.cc index 3f6f008a11..c4ad305afe 100644 --- a/webrtc/pc/srtpfilter_unittest.cc +++ b/webrtc/pc/srtpfilter_unittest.cc @@ -13,14 +13,7 @@ #include "webrtc/pc/srtpfilter.h" #include "webrtc/media/base/cryptoparams.h" -#include "webrtc/media/base/fakertp.h" -#include "webrtc/p2p/base/sessiondescription.h" -#include "webrtc/pc/srtptestutil.h" -#include "webrtc/rtc_base/buffer.h" -#include "webrtc/rtc_base/byteorder.h" -#include "webrtc/rtc_base/constructormagic.h" #include "webrtc/rtc_base/gunit.h" -#include "webrtc/rtc_base/thread.h" using cricket::CryptoParams; using cricket::CS_LOCAL; @@ -28,14 +21,6 @@ using cricket::CS_REMOTE; namespace rtc { -static const uint8_t kTestKeyGcm128_1[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZ12"; -static const uint8_t kTestKeyGcm128_2[] = "21ZYXWVUTSRQPONMLKJIHGFEDCBA"; -static const int kTestKeyGcm128Len = 28; // 128 bits key + 96 bits salt. -static const uint8_t kTestKeyGcm256_1[] = - "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqr"; -static const uint8_t kTestKeyGcm256_2[] = - "rqponmlkjihgfedcbaZYXWVUTSRQPONMLKJIHGFEDCBA"; -static const int kTestKeyGcm256Len = 44; // 256 bits key + 96 bits salt. static const std::string kTestKeyParams1 = "inline:WVNfX19zZW1jdGwgKCkgewkyMjA7fQp9CnVubGVz"; static const std::string kTestKeyParams2 = @@ -67,15 +52,13 @@ static const cricket::CryptoParams kTestCryptoParamsGcm4( class SrtpFilterTest : public testing::Test { protected: - SrtpFilterTest() - // Need to initialize |sequence_number_|, the value does not matter. - : sequence_number_(1) { - } + SrtpFilterTest() {} static std::vector MakeVector(const CryptoParams& params) { std::vector vec; vec.push_back(params); return vec; } + void TestSetParams(const std::vector& params1, const std::vector& params2) { EXPECT_TRUE(f1_.SetOffer(params1, CS_LOCAL)); @@ -87,184 +70,16 @@ class SrtpFilterTest : public testing::Test { EXPECT_TRUE(f1_.IsActive()); EXPECT_TRUE(f2_.IsActive()); } - void TestRtpAuthParams(cricket::SrtpFilter* filter, const std::string& cs) { - int overhead; - EXPECT_TRUE(filter->GetSrtpOverhead(&overhead)); - switch (SrtpCryptoSuiteFromName(cs)) { - case SRTP_AES128_CM_SHA1_32: - EXPECT_EQ(32/8, overhead); // 32-bit tag. - break; - case SRTP_AES128_CM_SHA1_80: - EXPECT_EQ(80/8, overhead); // 80-bit tag. - break; - default: - RTC_NOTREACHED(); - break; - } - uint8_t* auth_key = nullptr; - int key_len = 0; - int tag_len = 0; - EXPECT_TRUE(filter->GetRtpAuthParams(&auth_key, &key_len, &tag_len)); - EXPECT_NE(nullptr, auth_key); - EXPECT_EQ(160/8, key_len); // Length of SHA-1 is 160 bits. - EXPECT_EQ(overhead, tag_len); + void VerifyCryptoParamsMatch(const std::string& cs1, const std::string& cs2) { + EXPECT_EQ(rtc::SrtpCryptoSuiteFromName(cs1), f1_.send_cipher_suite()); + EXPECT_EQ(rtc::SrtpCryptoSuiteFromName(cs2), f2_.send_cipher_suite()); + EXPECT_TRUE(f1_.send_key() == f2_.recv_key()); + EXPECT_TRUE(f2_.send_key() == f1_.recv_key()); } - void TestProtectUnprotect(const std::string& cs1, const std::string& cs2) { - Buffer rtp_buffer(sizeof(kPcmuFrame) + rtp_auth_tag_len(cs1)); - char* rtp_packet = rtp_buffer.data(); - char original_rtp_packet[sizeof(kPcmuFrame)]; - Buffer rtcp_buffer(sizeof(kRtcpReport) + 4 + rtcp_auth_tag_len(cs2)); - char* rtcp_packet = rtcp_buffer.data(); - int rtp_len = sizeof(kPcmuFrame), rtcp_len = sizeof(kRtcpReport), out_len; - memcpy(rtp_packet, kPcmuFrame, rtp_len); - // In order to be able to run this test function multiple times we can not - // use the same sequence number twice. Increase the sequence number by one. - SetBE16(reinterpret_cast(rtp_packet) + 2, ++sequence_number_); - memcpy(original_rtp_packet, rtp_packet, rtp_len); - memcpy(rtcp_packet, kRtcpReport, rtcp_len); - EXPECT_TRUE(f1_.ProtectRtp(rtp_packet, rtp_len, - static_cast(rtp_buffer.size()), - &out_len)); - EXPECT_EQ(out_len, rtp_len + rtp_auth_tag_len(cs1)); - EXPECT_NE(0, memcmp(rtp_packet, original_rtp_packet, rtp_len)); - if (!f1_.IsExternalAuthActive()) { - EXPECT_TRUE(f2_.UnprotectRtp(rtp_packet, out_len, &out_len)); - EXPECT_EQ(rtp_len, out_len); - EXPECT_EQ(0, memcmp(rtp_packet, original_rtp_packet, rtp_len)); - } else { - // With external auth enabled, SRTP doesn't write the auth tag and - // unprotect would fail. Check accessing the information about the - // tag instead, similar to what the actual code would do that relies - // on external auth. - TestRtpAuthParams(&f1_, cs1); - } - - EXPECT_TRUE(f2_.ProtectRtp(rtp_packet, rtp_len, - static_cast(rtp_buffer.size()), - &out_len)); - EXPECT_EQ(out_len, rtp_len + rtp_auth_tag_len(cs2)); - EXPECT_NE(0, memcmp(rtp_packet, original_rtp_packet, rtp_len)); - if (!f2_.IsExternalAuthActive()) { - EXPECT_TRUE(f1_.UnprotectRtp(rtp_packet, out_len, &out_len)); - EXPECT_EQ(rtp_len, out_len); - EXPECT_EQ(0, memcmp(rtp_packet, original_rtp_packet, rtp_len)); - } else { - TestRtpAuthParams(&f2_, cs2); - } - - EXPECT_TRUE(f1_.ProtectRtcp(rtcp_packet, rtcp_len, - static_cast(rtcp_buffer.size()), - &out_len)); - EXPECT_EQ(out_len, rtcp_len + 4 + rtcp_auth_tag_len(cs1)); // NOLINT - EXPECT_NE(0, memcmp(rtcp_packet, kRtcpReport, rtcp_len)); - EXPECT_TRUE(f2_.UnprotectRtcp(rtcp_packet, out_len, &out_len)); - EXPECT_EQ(rtcp_len, out_len); - EXPECT_EQ(0, memcmp(rtcp_packet, kRtcpReport, rtcp_len)); - - EXPECT_TRUE(f2_.ProtectRtcp(rtcp_packet, rtcp_len, - static_cast(rtcp_buffer.size()), - &out_len)); - EXPECT_EQ(out_len, rtcp_len + 4 + rtcp_auth_tag_len(cs2)); // NOLINT - EXPECT_NE(0, memcmp(rtcp_packet, kRtcpReport, rtcp_len)); - EXPECT_TRUE(f1_.UnprotectRtcp(rtcp_packet, out_len, &out_len)); - EXPECT_EQ(rtcp_len, out_len); - EXPECT_EQ(0, memcmp(rtcp_packet, kRtcpReport, rtcp_len)); - } - void TestProtectUnprotectHeaderEncryption(const std::string& cs1, - const std::string& cs2, - const std::vector& encrypted_header_ids) { - Buffer rtp_buffer(sizeof(kPcmuFrameWithExtensions) + rtp_auth_tag_len(cs1)); - char* rtp_packet = rtp_buffer.data(); - size_t rtp_packet_size = rtp_buffer.size(); - char original_rtp_packet[sizeof(kPcmuFrameWithExtensions)]; - size_t original_rtp_packet_size = sizeof(original_rtp_packet); - int rtp_len = sizeof(kPcmuFrameWithExtensions), out_len; - memcpy(rtp_packet, kPcmuFrameWithExtensions, rtp_len); - // In order to be able to run this test function multiple times we can not - // use the same sequence number twice. Increase the sequence number by one. - SetBE16(reinterpret_cast(rtp_packet) + 2, ++sequence_number_); - memcpy(original_rtp_packet, rtp_packet, rtp_len); - - EXPECT_TRUE(f1_.ProtectRtp(rtp_packet, rtp_len, - static_cast(rtp_buffer.size()), - &out_len)); - EXPECT_EQ(out_len, rtp_len + rtp_auth_tag_len(cs1)); - EXPECT_NE(0, memcmp(rtp_packet, original_rtp_packet, rtp_len)); - CompareHeaderExtensions(rtp_packet, rtp_packet_size, - original_rtp_packet, original_rtp_packet_size, - encrypted_header_ids, false); - EXPECT_TRUE(f2_.UnprotectRtp(rtp_packet, out_len, &out_len)); - EXPECT_EQ(rtp_len, out_len); - EXPECT_EQ(0, memcmp(rtp_packet, original_rtp_packet, rtp_len)); - CompareHeaderExtensions(rtp_packet, rtp_packet_size, - original_rtp_packet, original_rtp_packet_size, - encrypted_header_ids, true); - - EXPECT_TRUE(f2_.ProtectRtp(rtp_packet, rtp_len, - static_cast(rtp_buffer.size()), - &out_len)); - EXPECT_EQ(out_len, rtp_len + rtp_auth_tag_len(cs2)); - EXPECT_NE(0, memcmp(rtp_packet, original_rtp_packet, rtp_len)); - CompareHeaderExtensions(rtp_packet, rtp_packet_size, - original_rtp_packet, original_rtp_packet_size, - encrypted_header_ids, false); - EXPECT_TRUE(f1_.UnprotectRtp(rtp_packet, out_len, &out_len)); - EXPECT_EQ(rtp_len, out_len); - EXPECT_EQ(0, memcmp(rtp_packet, original_rtp_packet, rtp_len)); - CompareHeaderExtensions(rtp_packet, rtp_packet_size, - original_rtp_packet, original_rtp_packet_size, - encrypted_header_ids, true); - } - void TestProtectSetParamsDirect(bool enable_external_auth, int cs, - const uint8_t* key1, int key1_len, const uint8_t* key2, int key2_len, - const std::string& cs_name) { - EXPECT_EQ(key1_len, key2_len); - EXPECT_EQ(cs_name, SrtpCryptoSuiteToName(cs)); - if (enable_external_auth) { - f1_.EnableExternalAuth(); - f2_.EnableExternalAuth(); - } - EXPECT_TRUE(f1_.SetRtpParams(cs, key1, key1_len, cs, key2, key2_len)); - EXPECT_TRUE(f2_.SetRtpParams(cs, key2, key2_len, cs, key1, key1_len)); - EXPECT_TRUE(f1_.SetRtcpParams(cs, key1, key1_len, cs, key2, key2_len)); - EXPECT_TRUE(f2_.SetRtcpParams(cs, key2, key2_len, cs, key1, key1_len)); - EXPECT_TRUE(f1_.IsActive()); - EXPECT_TRUE(f2_.IsActive()); - if (IsGcmCryptoSuite(cs)) { - EXPECT_FALSE(f1_.IsExternalAuthActive()); - EXPECT_FALSE(f2_.IsExternalAuthActive()); - } else if (enable_external_auth) { - EXPECT_TRUE(f1_.IsExternalAuthActive()); - EXPECT_TRUE(f2_.IsExternalAuthActive()); - } - TestProtectUnprotect(cs_name, cs_name); - } - void TestProtectSetParamsDirectHeaderEncryption(int cs, - const uint8_t* key1, int key1_len, const uint8_t* key2, int key2_len, - const std::string& cs_name) { - std::vector encrypted_headers; - encrypted_headers.push_back(1); - // Don't encrypt header ids 2 and 3. - encrypted_headers.push_back(4); - EXPECT_EQ(key1_len, key2_len); - EXPECT_EQ(cs_name, SrtpCryptoSuiteToName(cs)); - f1_.SetEncryptedHeaderExtensionIds(CS_LOCAL, encrypted_headers); - f1_.SetEncryptedHeaderExtensionIds(CS_REMOTE, encrypted_headers); - f2_.SetEncryptedHeaderExtensionIds(CS_LOCAL, encrypted_headers); - f2_.SetEncryptedHeaderExtensionIds(CS_REMOTE, encrypted_headers); - EXPECT_TRUE(f1_.SetRtpParams(cs, key1, key1_len, cs, key2, key2_len)); - EXPECT_TRUE(f2_.SetRtpParams(cs, key2, key2_len, cs, key1, key1_len)); - EXPECT_TRUE(f1_.IsActive()); - EXPECT_TRUE(f2_.IsActive()); - EXPECT_FALSE(f1_.IsExternalAuthActive()); - EXPECT_FALSE(f2_.IsExternalAuthActive()); - TestProtectUnprotectHeaderEncryption(cs_name, cs_name, encrypted_headers); - } cricket::SrtpFilter f1_; cricket::SrtpFilter f2_; - int sequence_number_; }; // Test that we can set up the session and keys properly. @@ -478,22 +293,6 @@ TEST_F(SrtpFilterTest, TestUnsupportedOptions) { EXPECT_FALSE(f1_.IsActive()); } -// Test that we can encrypt/decrypt after setting the same CryptoParams again on -// one side. -TEST_F(SrtpFilterTest, TestSettingSameKeyOnOneSide) { - std::vector offer(MakeVector(kTestCryptoParams1)); - std::vector answer(MakeVector(kTestCryptoParams2)); - TestSetParams(offer, answer); - - TestProtectUnprotect(CS_AES_CM_128_HMAC_SHA1_80, - CS_AES_CM_128_HMAC_SHA1_80); - - // Re-applying the same keys on one end and it should not reset the ROC. - EXPECT_TRUE(f2_.SetOffer(offer, CS_REMOTE)); - EXPECT_TRUE(f2_.SetAnswer(answer, CS_LOCAL)); - TestProtectUnprotect(CS_AES_CM_128_HMAC_SHA1_80, CS_AES_CM_128_HMAC_SHA1_80); -} - // Test that we can encrypt/decrypt after negotiating AES_CM_128_HMAC_SHA1_80. TEST_F(SrtpFilterTest, TestProtect_AES_CM_128_HMAC_SHA1_80) { std::vector offer(MakeVector(kTestCryptoParams1)); @@ -502,7 +301,8 @@ TEST_F(SrtpFilterTest, TestProtect_AES_CM_128_HMAC_SHA1_80) { offer[1].tag = 2; offer[1].cipher_suite = CS_AES_CM_128_HMAC_SHA1_32; TestSetParams(offer, answer); - TestProtectUnprotect(CS_AES_CM_128_HMAC_SHA1_80, CS_AES_CM_128_HMAC_SHA1_80); + VerifyCryptoParamsMatch(CS_AES_CM_128_HMAC_SHA1_80, + CS_AES_CM_128_HMAC_SHA1_80); } // Test that we can encrypt/decrypt after negotiating AES_CM_128_HMAC_SHA1_32. @@ -515,7 +315,8 @@ TEST_F(SrtpFilterTest, TestProtect_AES_CM_128_HMAC_SHA1_32) { answer[0].tag = 2; answer[0].cipher_suite = CS_AES_CM_128_HMAC_SHA1_32; TestSetParams(offer, answer); - TestProtectUnprotect(CS_AES_CM_128_HMAC_SHA1_32, CS_AES_CM_128_HMAC_SHA1_32); + VerifyCryptoParamsMatch(CS_AES_CM_128_HMAC_SHA1_32, + CS_AES_CM_128_HMAC_SHA1_32); } // Test that we can change encryption parameters. @@ -524,7 +325,8 @@ TEST_F(SrtpFilterTest, TestChangeParameters) { std::vector answer(MakeVector(kTestCryptoParams2)); TestSetParams(offer, answer); - TestProtectUnprotect(CS_AES_CM_128_HMAC_SHA1_80, CS_AES_CM_128_HMAC_SHA1_80); + VerifyCryptoParamsMatch(CS_AES_CM_128_HMAC_SHA1_80, + CS_AES_CM_128_HMAC_SHA1_80); // Change the key parameters and cipher_suite. offer[0].key_params = kTestKeyParams3; @@ -538,13 +340,15 @@ TEST_F(SrtpFilterTest, TestChangeParameters) { EXPECT_TRUE(f1_.IsActive()); // Test that the old keys are valid until the negotiation is complete. - TestProtectUnprotect(CS_AES_CM_128_HMAC_SHA1_80, CS_AES_CM_128_HMAC_SHA1_80); + VerifyCryptoParamsMatch(CS_AES_CM_128_HMAC_SHA1_80, + CS_AES_CM_128_HMAC_SHA1_80); // Complete the negotiation and test that we can still understand each other. EXPECT_TRUE(f2_.SetAnswer(answer, CS_LOCAL)); EXPECT_TRUE(f1_.SetAnswer(answer, CS_REMOTE)); - TestProtectUnprotect(CS_AES_CM_128_HMAC_SHA1_32, CS_AES_CM_128_HMAC_SHA1_32); + VerifyCryptoParamsMatch(CS_AES_CM_128_HMAC_SHA1_32, + CS_AES_CM_128_HMAC_SHA1_32); } // Test that we can send and receive provisional answers with crypto enabled. @@ -564,7 +368,8 @@ TEST_F(SrtpFilterTest, TestProvisionalAnswer) { EXPECT_TRUE(f1_.SetProvisionalAnswer(answer, CS_REMOTE)); EXPECT_TRUE(f1_.IsActive()); EXPECT_TRUE(f2_.IsActive()); - TestProtectUnprotect(CS_AES_CM_128_HMAC_SHA1_80, CS_AES_CM_128_HMAC_SHA1_80); + VerifyCryptoParamsMatch(CS_AES_CM_128_HMAC_SHA1_80, + CS_AES_CM_128_HMAC_SHA1_80); answer[0].key_params = kTestKeyParams4; answer[0].tag = 2; @@ -573,7 +378,8 @@ TEST_F(SrtpFilterTest, TestProvisionalAnswer) { EXPECT_TRUE(f1_.SetAnswer(answer, CS_REMOTE)); EXPECT_TRUE(f1_.IsActive()); EXPECT_TRUE(f2_.IsActive()); - TestProtectUnprotect(CS_AES_CM_128_HMAC_SHA1_32, CS_AES_CM_128_HMAC_SHA1_32); + VerifyCryptoParamsMatch(CS_AES_CM_128_HMAC_SHA1_32, + CS_AES_CM_128_HMAC_SHA1_32); } // Test that a provisional answer doesn't need to contain a crypto. @@ -595,7 +401,8 @@ TEST_F(SrtpFilterTest, TestProvisionalAnswerWithoutCrypto) { EXPECT_TRUE(f1_.SetAnswer(answer, CS_REMOTE)); EXPECT_TRUE(f1_.IsActive()); EXPECT_TRUE(f2_.IsActive()); - TestProtectUnprotect(CS_AES_CM_128_HMAC_SHA1_80, CS_AES_CM_128_HMAC_SHA1_80); + VerifyCryptoParamsMatch(CS_AES_CM_128_HMAC_SHA1_80, + CS_AES_CM_128_HMAC_SHA1_80); } // Test that if we get a new local offer after a provisional answer @@ -622,7 +429,8 @@ TEST_F(SrtpFilterTest, TestLocalOfferAfterProvisionalAnswerWithoutCrypto) { EXPECT_TRUE(f1_.SetAnswer(answer, CS_REMOTE)); EXPECT_TRUE(f1_.IsActive()); EXPECT_TRUE(f2_.IsActive()); - TestProtectUnprotect(CS_AES_CM_128_HMAC_SHA1_80, CS_AES_CM_128_HMAC_SHA1_80); + VerifyCryptoParamsMatch(CS_AES_CM_128_HMAC_SHA1_80, + CS_AES_CM_128_HMAC_SHA1_80); } // Test that we can disable encryption. @@ -631,7 +439,8 @@ TEST_F(SrtpFilterTest, TestDisableEncryption) { std::vector answer(MakeVector(kTestCryptoParams2)); TestSetParams(offer, answer); - TestProtectUnprotect(CS_AES_CM_128_HMAC_SHA1_80, CS_AES_CM_128_HMAC_SHA1_80); + VerifyCryptoParamsMatch(CS_AES_CM_128_HMAC_SHA1_80, + CS_AES_CM_128_HMAC_SHA1_80); offer.clear(); answer.clear(); @@ -641,7 +450,8 @@ TEST_F(SrtpFilterTest, TestDisableEncryption) { EXPECT_TRUE(f2_.IsActive()); // Test that the old keys are valid until the negotiation is complete. - TestProtectUnprotect(CS_AES_CM_128_HMAC_SHA1_80, CS_AES_CM_128_HMAC_SHA1_80); + VerifyCryptoParamsMatch(CS_AES_CM_128_HMAC_SHA1_80, + CS_AES_CM_128_HMAC_SHA1_80); // Complete the negotiation. EXPECT_TRUE(f2_.SetAnswer(answer, CS_LOCAL)); @@ -651,86 +461,4 @@ TEST_F(SrtpFilterTest, TestDisableEncryption) { EXPECT_FALSE(f2_.IsActive()); } -class SrtpFilterProtectSetParamsDirectTest - : public SrtpFilterTest, - public testing::WithParamInterface { -}; - -// Test directly setting the params with AES_CM_128_HMAC_SHA1_80. -TEST_P(SrtpFilterProtectSetParamsDirectTest, Test_AES_CM_128_HMAC_SHA1_80) { - bool enable_external_auth = GetParam(); - TestProtectSetParamsDirect(enable_external_auth, SRTP_AES128_CM_SHA1_80, - kTestKey1, kTestKeyLen, kTestKey2, kTestKeyLen, - CS_AES_CM_128_HMAC_SHA1_80); -} - -TEST_F(SrtpFilterTest, - TestProtectSetParamsDirectHeaderEncryption_AES_CM_128_HMAC_SHA1_80) { - TestProtectSetParamsDirectHeaderEncryption( - SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen, kTestKey2, kTestKeyLen, - CS_AES_CM_128_HMAC_SHA1_80); -} - -// Test directly setting the params with AES_CM_128_HMAC_SHA1_32. -TEST_P(SrtpFilterProtectSetParamsDirectTest, Test_AES_CM_128_HMAC_SHA1_32) { - bool enable_external_auth = GetParam(); - TestProtectSetParamsDirect(enable_external_auth, SRTP_AES128_CM_SHA1_32, - kTestKey1, kTestKeyLen, kTestKey2, kTestKeyLen, - CS_AES_CM_128_HMAC_SHA1_32); -} - -TEST_F(SrtpFilterTest, - TestProtectSetParamsDirectHeaderEncryption_AES_CM_128_HMAC_SHA1_32) { - TestProtectSetParamsDirectHeaderEncryption( - SRTP_AES128_CM_SHA1_32, kTestKey1, kTestKeyLen, kTestKey2, kTestKeyLen, - CS_AES_CM_128_HMAC_SHA1_32); -} - -// Test directly setting the params with SRTP_AEAD_AES_128_GCM. -TEST_P(SrtpFilterProtectSetParamsDirectTest, Test_SRTP_AEAD_AES_128_GCM) { - bool enable_external_auth = GetParam(); - TestProtectSetParamsDirect(enable_external_auth, SRTP_AEAD_AES_128_GCM, - kTestKeyGcm128_1, kTestKeyGcm128Len, - kTestKeyGcm128_2, kTestKeyGcm128Len, - CS_AEAD_AES_128_GCM); -} - -TEST_F(SrtpFilterTest, - TestProtectSetParamsDirectHeaderEncryption_SRTP_AEAD_AES_128_GCM) { - TestProtectSetParamsDirectHeaderEncryption( - SRTP_AEAD_AES_128_GCM, kTestKeyGcm128_1, kTestKeyGcm128Len, - kTestKeyGcm128_2, kTestKeyGcm128Len, CS_AEAD_AES_128_GCM); -} - -// Test directly setting the params with SRTP_AEAD_AES_256_GCM. -TEST_P(SrtpFilterProtectSetParamsDirectTest, Test_SRTP_AEAD_AES_256_GCM) { - bool enable_external_auth = GetParam(); - TestProtectSetParamsDirect(enable_external_auth, SRTP_AEAD_AES_256_GCM, - kTestKeyGcm256_1, kTestKeyGcm256Len, - kTestKeyGcm256_2, kTestKeyGcm256Len, - CS_AEAD_AES_256_GCM); -} - -TEST_F(SrtpFilterTest, - TestProtectSetParamsDirectHeaderEncryption_SRTP_AEAD_AES_256_GCM) { - TestProtectSetParamsDirectHeaderEncryption( - SRTP_AEAD_AES_256_GCM, kTestKeyGcm256_1, kTestKeyGcm256Len, - kTestKeyGcm256_2, kTestKeyGcm256Len, CS_AEAD_AES_256_GCM); -} - -// Run all tests both with and without external auth enabled. -INSTANTIATE_TEST_CASE_P(ExternalAuth, - SrtpFilterProtectSetParamsDirectTest, - ::testing::Values(true, false)); - -// Test directly setting the params with bogus keys. -TEST_F(SrtpFilterTest, TestSetParamsKeyTooShort) { - EXPECT_FALSE(f1_.SetRtpParams(SRTP_AES128_CM_SHA1_80, kTestKey1, - kTestKeyLen - 1, SRTP_AES128_CM_SHA1_80, - kTestKey1, kTestKeyLen - 1)); - EXPECT_FALSE(f1_.SetRtcpParams(SRTP_AES128_CM_SHA1_80, kTestKey1, - kTestKeyLen - 1, SRTP_AES128_CM_SHA1_80, - kTestKey1, kTestKeyLen - 1)); -} - } // namespace rtc diff --git a/webrtc/pc/srtptransport.cc b/webrtc/pc/srtptransport.cc index 6e6ff06274..d9f054be7e 100644 --- a/webrtc/pc/srtptransport.cc +++ b/webrtc/pc/srtptransport.cc @@ -16,6 +16,7 @@ #include "webrtc/pc/rtptransport.h" #include "webrtc/pc/srtpsession.h" #include "webrtc/rtc_base/asyncpacketsocket.h" +#include "webrtc/rtc_base/base64.h" #include "webrtc/rtc_base/copyonwritebuffer.h" #include "webrtc/rtc_base/ptr_util.h" #include "webrtc/rtc_base/trace_event.h" @@ -42,21 +43,322 @@ void SrtpTransport::ConnectToRtpTransport() { &SrtpTransport::OnReadyToSend); } +bool SrtpTransport::SendRtpPacket(rtc::CopyOnWriteBuffer* packet, + const rtc::PacketOptions& options, + int flags) { + return SendPacket(false, packet, options, flags); +} + +bool SrtpTransport::SendRtcpPacket(rtc::CopyOnWriteBuffer* packet, + const rtc::PacketOptions& options, + int flags) { + return SendPacket(true, packet, options, flags); +} + bool SrtpTransport::SendPacket(bool rtcp, rtc::CopyOnWriteBuffer* packet, const rtc::PacketOptions& options, int flags) { - // TODO(zstein): Protect packet. + if (!IsActive()) { + LOG(LS_ERROR) + << "Failed to send the packet because SRTP transport is inactive."; + return false; + } - return rtp_transport_->SendPacket(rtcp, packet, options, flags); + rtc::PacketOptions updated_options = options; + rtc::CopyOnWriteBuffer cp = *packet; + TRACE_EVENT0("webrtc", "SRTP Encode"); + bool res; + uint8_t* data = packet->data(); + int len = static_cast(packet->size()); + if (!rtcp) { +// If ENABLE_EXTERNAL_AUTH flag is on then packet authentication is not done +// inside libsrtp for a RTP packet. A external HMAC module will be writing +// a fake HMAC value. This is ONLY done for a RTP packet. +// Socket layer will update rtp sendtime extension header if present in +// packet with current time before updating the HMAC. +#if !defined(ENABLE_EXTERNAL_AUTH) + res = ProtectRtp(data, len, static_cast(packet->capacity()), &len); +#else + if (!IsExternalAuthActive()) { + res = ProtectRtp(data, len, static_cast(packet->capacity()), &len); + } else { + updated_options.packet_time_params.rtp_sendtime_extension_id = + rtp_abs_sendtime_extn_id_; + res = ProtectRtp(data, len, static_cast(packet->capacity()), &len, + &updated_options.packet_time_params.srtp_packet_index); + // If protection succeeds, let's get auth params from srtp. + if (res) { + uint8_t* auth_key = NULL; + int key_len; + res = GetRtpAuthParams( + &auth_key, &key_len, + &updated_options.packet_time_params.srtp_auth_tag_len); + if (res) { + updated_options.packet_time_params.srtp_auth_key.resize(key_len); + updated_options.packet_time_params.srtp_auth_key.assign( + auth_key, auth_key + key_len); + } + } + } +#endif + if (!res) { + int seq_num = -1; + uint32_t ssrc = 0; + cricket::GetRtpSeqNum(data, len, &seq_num); + cricket::GetRtpSsrc(data, len, &ssrc); + LOG(LS_ERROR) << "Failed to protect " << content_name_ + << " RTP packet: size=" << len << ", seqnum=" << seq_num + << ", SSRC=" << ssrc; + return false; + } + } else { + res = ProtectRtcp(data, len, static_cast(packet->capacity()), &len); + if (!res) { + int type = -1; + cricket::GetRtcpType(data, len, &type); + LOG(LS_ERROR) << "Failed to protect " << content_name_ + << " RTCP packet: size=" << len << ", type=" << type; + return false; + } + } + + // Update the length of the packet now that we've added the auth tag. + packet->SetSize(len); + return rtcp ? rtp_transport_->SendRtcpPacket(packet, updated_options, flags) + : rtp_transport_->SendRtpPacket(packet, updated_options, flags); } void SrtpTransport::OnPacketReceived(bool rtcp, rtc::CopyOnWriteBuffer* packet, const rtc::PacketTime& packet_time) { - // TODO(zstein): Unprotect packet. + if (!IsActive()) { + LOG(LS_WARNING) << "Inactive SRTP transport received a packet. Drop it."; + return; + } + TRACE_EVENT0("webrtc", "SRTP Decode"); + char* data = packet->data(); + int len = static_cast(packet->size()); + bool res; + if (!rtcp) { + res = UnprotectRtp(data, len, &len); + if (!res) { + int seq_num = -1; + uint32_t ssrc = 0; + cricket::GetRtpSeqNum(data, len, &seq_num); + cricket::GetRtpSsrc(data, len, &ssrc); + LOG(LS_ERROR) << "Failed to unprotect " << content_name_ + << " RTP packet: size=" << len << ", seqnum=" << seq_num + << ", SSRC=" << ssrc; + return; + } + } else { + res = UnprotectRtcp(data, len, &len); + if (!res) { + int type = -1; + cricket::GetRtcpType(data, len, &type); + LOG(LS_ERROR) << "Failed to unprotect " << content_name_ + << " RTCP packet: size=" << len << ", type=" << type; + return; + } + } + + packet->SetSize(len); SignalPacketReceived(rtcp, packet, packet_time); } +bool SrtpTransport::SetRtpParams(int send_cs, + const uint8_t* send_key, + int send_key_len, + int recv_cs, + const uint8_t* recv_key, + int recv_key_len) { + CreateSrtpSessions(); + send_session_->SetEncryptedHeaderExtensionIds( + send_encrypted_header_extension_ids_); + if (external_auth_enabled_) { + send_session_->EnableExternalAuth(); + } + if (!send_session_->SetSend(send_cs, send_key, send_key_len)) { + ResetParams(); + return false; + } + + recv_session_->SetEncryptedHeaderExtensionIds( + recv_encrypted_header_extension_ids_); + if (!recv_session_->SetRecv(recv_cs, recv_key, recv_key_len)) { + ResetParams(); + return false; + } + + LOG(LS_INFO) << "SRTP activated with negotiated parameters:" + << " send cipher_suite " << send_cs << " recv cipher_suite " + << recv_cs; + return true; +} + +bool SrtpTransport::SetRtcpParams(int send_cs, + const uint8_t* send_key, + int send_key_len, + int recv_cs, + const uint8_t* recv_key, + int recv_key_len) { + // This can only be called once, but can be safely called after + // SetRtpParams + if (send_rtcp_session_ || recv_rtcp_session_) { + LOG(LS_ERROR) << "Tried to set SRTCP Params when filter already active"; + return false; + } + + send_rtcp_session_.reset(new cricket::SrtpSession()); + if (!send_rtcp_session_->SetRecv(send_cs, send_key, send_key_len)) { + return false; + } + + recv_rtcp_session_.reset(new cricket::SrtpSession()); + if (!recv_rtcp_session_->SetRecv(recv_cs, recv_key, recv_key_len)) { + return false; + } + + LOG(LS_INFO) << "SRTCP activated with negotiated parameters:" + << " send cipher_suite " << send_cs << " recv cipher_suite " + << recv_cs; + + return true; +} + +bool SrtpTransport::IsActive() const { + return send_session_ && recv_session_; +} + +void SrtpTransport::ResetParams() { + send_session_ = nullptr; + recv_session_ = nullptr; + send_rtcp_session_ = nullptr; + recv_rtcp_session_ = nullptr; + LOG(LS_INFO) << "The params in SRTP transport are reset."; +} + +void SrtpTransport::SetEncryptedHeaderExtensionIds( + cricket::ContentSource source, + const std::vector& extension_ids) { + if (source == cricket::CS_LOCAL) { + recv_encrypted_header_extension_ids_ = extension_ids; + } else { + send_encrypted_header_extension_ids_ = extension_ids; + } +} + +void SrtpTransport::CreateSrtpSessions() { + send_session_.reset(new cricket::SrtpSession()); + recv_session_.reset(new cricket::SrtpSession()); + + if (external_auth_enabled_) { + send_session_->EnableExternalAuth(); + } +} + +bool SrtpTransport::ProtectRtp(void* p, int in_len, int max_len, int* out_len) { + if (!IsActive()) { + LOG(LS_WARNING) << "Failed to ProtectRtp: SRTP not active"; + return false; + } + RTC_CHECK(send_session_); + return send_session_->ProtectRtp(p, in_len, max_len, out_len); +} + +bool SrtpTransport::ProtectRtp(void* p, + int in_len, + int max_len, + int* out_len, + int64_t* index) { + if (!IsActive()) { + LOG(LS_WARNING) << "Failed to ProtectRtp: SRTP not active"; + return false; + } + RTC_CHECK(send_session_); + return send_session_->ProtectRtp(p, in_len, max_len, out_len, index); +} + +bool SrtpTransport::ProtectRtcp(void* p, + int in_len, + int max_len, + int* out_len) { + if (!IsActive()) { + LOG(LS_WARNING) << "Failed to ProtectRtcp: SRTP not active"; + return false; + } + if (send_rtcp_session_) { + return send_rtcp_session_->ProtectRtcp(p, in_len, max_len, out_len); + } else { + RTC_CHECK(send_session_); + return send_session_->ProtectRtcp(p, in_len, max_len, out_len); + } +} + +bool SrtpTransport::UnprotectRtp(void* p, int in_len, int* out_len) { + if (!IsActive()) { + LOG(LS_WARNING) << "Failed to UnprotectRtp: SRTP not active"; + return false; + } + RTC_CHECK(recv_session_); + return recv_session_->UnprotectRtp(p, in_len, out_len); +} + +bool SrtpTransport::UnprotectRtcp(void* p, int in_len, int* out_len) { + if (!IsActive()) { + LOG(LS_WARNING) << "Failed to UnprotectRtcp: SRTP not active"; + return false; + } + if (recv_rtcp_session_) { + return recv_rtcp_session_->UnprotectRtcp(p, in_len, out_len); + } else { + RTC_CHECK(recv_session_); + return recv_session_->UnprotectRtcp(p, in_len, out_len); + } +} + +bool SrtpTransport::GetRtpAuthParams(uint8_t** key, + int* key_len, + int* tag_len) { + if (!IsActive()) { + LOG(LS_WARNING) << "Failed to GetRtpAuthParams: SRTP not active"; + return false; + } + + RTC_CHECK(send_session_); + return send_session_->GetRtpAuthParams(key, key_len, tag_len); +} + +bool SrtpTransport::GetSrtpOverhead(int* srtp_overhead) const { + if (!IsActive()) { + LOG(LS_WARNING) << "Failed to GetSrtpOverhead: SRTP not active"; + return false; + } + + RTC_CHECK(send_session_); + *srtp_overhead = send_session_->GetSrtpOverhead(); + return true; +} + +void SrtpTransport::EnableExternalAuth() { + RTC_DCHECK(!IsActive()); + external_auth_enabled_ = true; +} + +bool SrtpTransport::IsExternalAuthEnabled() const { + return external_auth_enabled_; +} + +bool SrtpTransport::IsExternalAuthActive() const { + if (!IsActive()) { + LOG(LS_WARNING) << "Failed to check IsExternalAuthActive: SRTP not active"; + return false; + } + + RTC_CHECK(send_session_); + return send_session_->IsExternalAuthActive(); +} + } // namespace webrtc diff --git a/webrtc/pc/srtptransport.h b/webrtc/pc/srtptransport.h index 58ef205f87..769df0620a 100644 --- a/webrtc/pc/srtptransport.h +++ b/webrtc/pc/srtptransport.h @@ -17,20 +17,17 @@ #include "webrtc/pc/rtptransportinternal.h" #include "webrtc/pc/srtpfilter.h" +#include "webrtc/pc/srtpsession.h" #include "webrtc/rtc_base/checks.h" namespace webrtc { // This class will eventually be a wrapper around RtpTransportInternal -// that protects and unprotects sent and received RTP packets. This -// functionality is currently implemented by SrtpFilter and BaseChannel, but -// will be moved here in the future. +// that protects and unprotects sent and received RTP packets. class SrtpTransport : public RtpTransportInternal { public: SrtpTransport(bool rtcp_mux_enabled, const std::string& content_name); - // TODO(zstein): Consider taking an RtpTransport instead of an - // RtpTransportInternal. SrtpTransport(std::unique_ptr transport, const std::string& content_name); @@ -61,14 +58,21 @@ class SrtpTransport : public RtpTransportInternal { return rtp_transport_->GetRtcpPacketTransport(); } + bool SendRtpPacket(rtc::CopyOnWriteBuffer* packet, + const rtc::PacketOptions& options, + int flags) override; + + bool SendRtcpPacket(rtc::CopyOnWriteBuffer* packet, + const rtc::PacketOptions& options, + int flags) override; + bool IsWritable(bool rtcp) const override { return rtp_transport_->IsWritable(rtcp); } - bool SendPacket(bool rtcp, - rtc::CopyOnWriteBuffer* packet, - const rtc::PacketOptions& options, - int flags) override; + // The transport becomes active if the send_session_ and recv_session_ are + // created. + bool IsActive() const; bool HandlesPayloadType(int payload_type) const override { return rtp_transport_->HandlesPayloadType(payload_type); @@ -89,18 +93,104 @@ class SrtpTransport : public RtpTransportInternal { // TODO(zstein): Remove this when we remove RtpTransportAdapter. RtpTransportAdapter* GetInternal() override { return nullptr; } + // Create new send/recv sessions and set the negotiated crypto keys for RTP + // packet encryption. The keys can either come from SDES negotiation or DTLS + // handshake. + bool SetRtpParams(int send_cs, + const uint8_t* send_key, + int send_key_len, + int recv_cs, + const uint8_t* recv_key, + int recv_key_len); + + // Create new send/recv sessions and set the negotiated crypto keys for RTCP + // packet encryption. The keys can either come from SDES negotiation or DTLS + // handshake. + bool SetRtcpParams(int send_cs, + const uint8_t* send_key, + int send_key_len, + int recv_cs, + const uint8_t* recv_key, + int recv_key_len); + + void ResetParams(); + + // Set the header extension ids that should be encrypted for the given source. + // This method doesn't immediately update the SRTP session with the new IDs, + // and you need to call SetRtpParams for that to happen. + void SetEncryptedHeaderExtensionIds(cricket::ContentSource source, + const std::vector& extension_ids); + + // If external auth is enabled, SRTP will write a dummy auth tag that then + // later must get replaced before the packet is sent out. Only supported for + // non-GCM cipher suites and can be checked through "IsExternalAuthActive" + // if it is actually used. This method is only valid before the RTP params + // have been set. + void EnableExternalAuth(); + bool IsExternalAuthEnabled() const; + + // A SrtpTransport supports external creation of the auth tag if a non-GCM + // cipher is used. This method is only valid after the RTP params have + // been set. + bool IsExternalAuthActive() const; + + // Returns srtp overhead for rtp packets. + bool GetSrtpOverhead(int* srtp_overhead) const; + + // Returns rtp auth params from srtp context. + bool GetRtpAuthParams(uint8_t** key, int* key_len, int* tag_len); + + // Helper method to get RTP Absoulute SendTime extension header id if + // present in remote supported extensions list. + void CacheRtpAbsSendTimeHeaderExtension(int rtp_abs_sendtime_extn_id) { + rtp_abs_sendtime_extn_id_ = rtp_abs_sendtime_extn_id; + } + private: + void CreateSrtpSessions(); + void ConnectToRtpTransport(); + bool SendPacket(bool rtcp, + rtc::CopyOnWriteBuffer* packet, + const rtc::PacketOptions& options, + int flags); + void OnPacketReceived(bool rtcp, rtc::CopyOnWriteBuffer* packet, const rtc::PacketTime& packet_time); void OnReadyToSend(bool ready) { SignalReadyToSend(ready); } - const std::string content_name_; + bool ProtectRtp(void* data, int in_len, int max_len, int* out_len); + // Overloaded version, outputs packet index. + bool ProtectRtp(void* data, + int in_len, + int max_len, + int* out_len, + int64_t* index); + bool ProtectRtcp(void* data, int in_len, int max_len, int* out_len); + + // Decrypts/verifies an invidiual RTP/RTCP packet. + // If an HMAC is used, this will decrease the packet size. + bool UnprotectRtp(void* data, int in_len, int* out_len); + + bool UnprotectRtcp(void* data, int in_len, int* out_len); + + const std::string content_name_; std::unique_ptr rtp_transport_; + + std::unique_ptr send_session_; + std::unique_ptr recv_session_; + std::unique_ptr send_rtcp_session_; + std::unique_ptr recv_rtcp_session_; + + std::vector send_encrypted_header_extension_ids_; + std::vector recv_encrypted_header_extension_ids_; + bool external_auth_enabled_ = false; + + int rtp_abs_sendtime_extn_id_ = -1; }; } // namespace webrtc diff --git a/webrtc/pc/srtptransport_unittest.cc b/webrtc/pc/srtptransport_unittest.cc index e54dac3ea2..e097ff17cd 100644 --- a/webrtc/pc/srtptransport_unittest.cc +++ b/webrtc/pc/srtptransport_unittest.cc @@ -10,67 +10,413 @@ #include "webrtc/pc/srtptransport.h" +#include "webrtc/media/base/fakertp.h" +#include "webrtc/p2p/base/dtlstransportinternal.h" +#include "webrtc/p2p/base/fakepackettransport.h" #include "webrtc/pc/rtptransport.h" #include "webrtc/pc/rtptransporttestutil.h" +#include "webrtc/pc/srtptestutil.h" #include "webrtc/rtc_base/asyncpacketsocket.h" #include "webrtc/rtc_base/gunit.h" #include "webrtc/rtc_base/ptr_util.h" -#include "webrtc/test/gmock.h" +#include "webrtc/rtc_base/sslstreamadapter.h" + +using rtc::kTestKey1; +using rtc::kTestKey2; +using rtc::kTestKeyLen; +using rtc::SRTP_AEAD_AES_128_GCM; namespace webrtc { +static const uint8_t kTestKeyGcm128_1[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZ12"; +static const uint8_t kTestKeyGcm128_2[] = "21ZYXWVUTSRQPONMLKJIHGFEDCBA"; +static const int kTestKeyGcm128Len = 28; // 128 bits key + 96 bits salt. +static const uint8_t kTestKeyGcm256_1[] = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqr"; +static const uint8_t kTestKeyGcm256_2[] = + "rqponmlkjihgfedcbaZYXWVUTSRQPONMLKJIHGFEDCBA"; +static const int kTestKeyGcm256Len = 44; // 256 bits key + 96 bits salt. -using testing::_; -using testing::Return; +class SrtpTransportTest : public testing::Test, public sigslot::has_slots<> { + protected: + SrtpTransportTest() { + bool rtcp_mux_enabled = true; + auto rtp_transport1 = rtc::MakeUnique(rtcp_mux_enabled); + auto rtp_transport2 = rtc::MakeUnique(rtcp_mux_enabled); -class MockRtpTransport : public RtpTransport { - public: - MockRtpTransport() : RtpTransport(true) {} + rtp_packet_transport1_ = + rtc::MakeUnique("fake_packet_transport1"); + rtp_packet_transport2_ = + rtc::MakeUnique("fake_packet_transport2"); - MOCK_METHOD4(SendPacket, - bool(bool rtcp, - rtc::CopyOnWriteBuffer* packet, - const rtc::PacketOptions& options, - int flags)); + bool asymmetric = false; + rtp_packet_transport1_->SetDestination(rtp_packet_transport2_.get(), + asymmetric); - void PretendReceivedPacket() { - bool rtcp = false; - rtc::CopyOnWriteBuffer buffer; - rtc::PacketTime time; - SignalPacketReceived(rtcp, &buffer, time); + rtp_transport1->SetRtpPacketTransport(rtp_packet_transport1_.get()); + rtp_transport2->SetRtpPacketTransport(rtp_packet_transport2_.get()); + + // Add payload type for RTP packet and RTCP packet. + rtp_transport1->AddHandledPayloadType(0x00); + rtp_transport2->AddHandledPayloadType(0x00); + rtp_transport1->AddHandledPayloadType(0xc9); + rtp_transport2->AddHandledPayloadType(0xc9); + + srtp_transport1_ = + rtc::MakeUnique(std::move(rtp_transport1), "content"); + srtp_transport2_ = + rtc::MakeUnique(std::move(rtp_transport2), "content"); + + srtp_transport1_->SignalPacketReceived.connect( + this, &SrtpTransportTest::OnPacketReceived1); + srtp_transport2_->SignalPacketReceived.connect( + this, &SrtpTransportTest::OnPacketReceived2); } + + void OnPacketReceived1(bool rtcp, + rtc::CopyOnWriteBuffer* packet, + const rtc::PacketTime& packet_time) { + LOG(LS_INFO) << "SrtpTransport1 Received a packet."; + last_recv_packet1_ = *packet; + } + + void OnPacketReceived2(bool rtcp, + rtc::CopyOnWriteBuffer* packet, + const rtc::PacketTime& packet_time) { + LOG(LS_INFO) << "SrtpTransport2 Received a packet."; + last_recv_packet2_ = *packet; + } + + // With external auth enabled, SRTP doesn't write the auth tag and + // unprotect would fail. Check accessing the information about the + // tag instead, similar to what the actual code would do that relies + // on external auth. + void TestRtpAuthParams(SrtpTransport* transport, const std::string& cs) { + int overhead; + EXPECT_TRUE(transport->GetSrtpOverhead(&overhead)); + switch (rtc::SrtpCryptoSuiteFromName(cs)) { + case rtc::SRTP_AES128_CM_SHA1_32: + EXPECT_EQ(32 / 8, overhead); // 32-bit tag. + break; + case rtc::SRTP_AES128_CM_SHA1_80: + EXPECT_EQ(80 / 8, overhead); // 80-bit tag. + break; + default: + RTC_NOTREACHED(); + break; + } + + uint8_t* auth_key = nullptr; + int key_len = 0; + int tag_len = 0; + EXPECT_TRUE(transport->GetRtpAuthParams(&auth_key, &key_len, &tag_len)); + EXPECT_NE(nullptr, auth_key); + EXPECT_EQ(160 / 8, key_len); // Length of SHA-1 is 160 bits. + EXPECT_EQ(overhead, tag_len); + } + + void TestSendRecvRtpPacket(const std::string& cipher_suite_name) { + size_t rtp_len = sizeof(kPcmuFrame); + size_t packet_size = rtp_len + rtc::rtp_auth_tag_len(cipher_suite_name); + rtc::Buffer rtp_packet_buffer(packet_size); + char* rtp_packet_data = rtp_packet_buffer.data(); + memcpy(rtp_packet_data, kPcmuFrame, rtp_len); + // In order to be able to run this test function multiple times we can not + // use the same sequence number twice. Increase the sequence number by one. + rtc::SetBE16(reinterpret_cast(rtp_packet_data) + 2, + ++sequence_number_); + rtc::CopyOnWriteBuffer rtp_packet1to2(rtp_packet_data, rtp_len, + packet_size); + rtc::CopyOnWriteBuffer rtp_packet2to1(rtp_packet_data, rtp_len, + packet_size); + + char original_rtp_data[sizeof(kPcmuFrame)]; + memcpy(original_rtp_data, rtp_packet_data, rtp_len); + + rtc::PacketOptions options; + // Send a packet from |srtp_transport1_| to |srtp_transport2_| and verify + // that the packet can be successfully received and decrypted. + ASSERT_TRUE(srtp_transport1_->SendRtpPacket(&rtp_packet1to2, options, + cricket::PF_SRTP_BYPASS)); + if (srtp_transport1_->IsExternalAuthActive()) { + TestRtpAuthParams(srtp_transport1_.get(), cipher_suite_name); + } else { + ASSERT_TRUE(last_recv_packet2_.data()); + EXPECT_TRUE( + memcmp(last_recv_packet2_.data(), original_rtp_data, rtp_len) == 0); + // Get the encrypted packet from underneath packet transport and verify + // the data is actually encrypted. + auto fake_rtp_packet_transport = static_cast( + srtp_transport1_->rtp_packet_transport()); + EXPECT_FALSE(memcmp(fake_rtp_packet_transport->last_sent_packet()->data(), + original_rtp_data, rtp_len) == 0); + } + + // Do the same thing in the opposite direction; + ASSERT_TRUE(srtp_transport2_->SendRtpPacket(&rtp_packet2to1, options, + cricket::PF_SRTP_BYPASS)); + if (srtp_transport2_->IsExternalAuthActive()) { + TestRtpAuthParams(srtp_transport2_.get(), cipher_suite_name); + } else { + ASSERT_TRUE(last_recv_packet1_.data()); + EXPECT_TRUE( + memcmp(last_recv_packet1_.data(), original_rtp_data, rtp_len) == 0); + auto fake_rtp_packet_transport = static_cast( + srtp_transport2_->rtp_packet_transport()); + EXPECT_FALSE(memcmp(fake_rtp_packet_transport->last_sent_packet()->data(), + original_rtp_data, rtp_len) == 0); + } + } + + void TestSendRecvRtcpPacket(const std::string& cipher_suite_name) { + size_t rtcp_len = sizeof(kRtcpReport); + size_t packet_size = + rtcp_len + 4 + rtc::rtcp_auth_tag_len(cipher_suite_name); + rtc::Buffer rtcp_packet_buffer(packet_size); + char* rtcp_packet_data = rtcp_packet_buffer.data(); + memcpy(rtcp_packet_data, kRtcpReport, rtcp_len); + + rtc::CopyOnWriteBuffer rtcp_packet1to2(rtcp_packet_data, rtcp_len, + packet_size); + rtc::CopyOnWriteBuffer rtcp_packet2to1(rtcp_packet_data, rtcp_len, + packet_size); + + rtc::PacketOptions options; + // Send a packet from |srtp_transport1_| to |srtp_transport2_| and verify + // that the packet can be successfully received and decrypted. + ASSERT_TRUE(srtp_transport1_->SendRtcpPacket(&rtcp_packet1to2, options, + cricket::PF_SRTP_BYPASS)); + ASSERT_TRUE(last_recv_packet2_.data()); + EXPECT_TRUE(memcmp(last_recv_packet2_.data(), rtcp_packet_data, rtcp_len) == + 0); + // Get the encrypted packet from underneath packet transport and verify the + // data is actually encrypted. + auto fake_rtp_packet_transport = static_cast( + srtp_transport1_->rtp_packet_transport()); + EXPECT_FALSE(memcmp(fake_rtp_packet_transport->last_sent_packet()->data(), + rtcp_packet_data, rtcp_len) == 0); + + // Do the same thing in the opposite direction; + ASSERT_TRUE(srtp_transport2_->SendRtcpPacket(&rtcp_packet2to1, options, + cricket::PF_SRTP_BYPASS)); + ASSERT_TRUE(last_recv_packet1_.data()); + EXPECT_TRUE(memcmp(last_recv_packet1_.data(), rtcp_packet_data, rtcp_len) == + 0); + fake_rtp_packet_transport = static_cast( + srtp_transport2_->rtp_packet_transport()); + EXPECT_FALSE(memcmp(fake_rtp_packet_transport->last_sent_packet()->data(), + rtcp_packet_data, rtcp_len) == 0); + } + + void TestSendRecvPacket(bool enable_external_auth, + int cs, + const uint8_t* key1, + int key1_len, + const uint8_t* key2, + int key2_len, + const std::string& cipher_suite_name) { + EXPECT_EQ(key1_len, key2_len); + EXPECT_EQ(cipher_suite_name, rtc::SrtpCryptoSuiteToName(cs)); + if (enable_external_auth) { + srtp_transport1_->EnableExternalAuth(); + srtp_transport2_->EnableExternalAuth(); + } + EXPECT_TRUE( + srtp_transport1_->SetRtpParams(cs, key1, key1_len, cs, key2, key2_len)); + EXPECT_TRUE( + srtp_transport2_->SetRtpParams(cs, key2, key2_len, cs, key1, key1_len)); + EXPECT_TRUE(srtp_transport1_->SetRtcpParams(cs, key1, key1_len, cs, key2, + key2_len)); + EXPECT_TRUE(srtp_transport2_->SetRtcpParams(cs, key2, key2_len, cs, key1, + key1_len)); + EXPECT_TRUE(srtp_transport1_->IsActive()); + EXPECT_TRUE(srtp_transport2_->IsActive()); + if (rtc::IsGcmCryptoSuite(cs)) { + EXPECT_FALSE(srtp_transport1_->IsExternalAuthActive()); + EXPECT_FALSE(srtp_transport2_->IsExternalAuthActive()); + } else if (enable_external_auth) { + EXPECT_TRUE(srtp_transport1_->IsExternalAuthActive()); + EXPECT_TRUE(srtp_transport2_->IsExternalAuthActive()); + } + TestSendRecvRtpPacket(cipher_suite_name); + TestSendRecvRtcpPacket(cipher_suite_name); + } + + void TestSendRecvPacketWithEncryptedHeaderExtension( + const std::string& cs, + const std::vector& encrypted_header_ids) { + size_t rtp_len = sizeof(kPcmuFrameWithExtensions); + size_t packet_size = rtp_len + rtc::rtp_auth_tag_len(cs); + rtc::Buffer rtp_packet_buffer(packet_size); + char* rtp_packet_data = rtp_packet_buffer.data(); + memcpy(rtp_packet_data, kPcmuFrameWithExtensions, rtp_len); + // In order to be able to run this test function multiple times we can not + // use the same sequence number twice. Increase the sequence number by one. + rtc::SetBE16(reinterpret_cast(rtp_packet_data) + 2, + ++sequence_number_); + rtc::CopyOnWriteBuffer rtp_packet1to2(rtp_packet_data, rtp_len, + packet_size); + rtc::CopyOnWriteBuffer rtp_packet2to1(rtp_packet_data, rtp_len, + packet_size); + + char original_rtp_data[sizeof(kPcmuFrameWithExtensions)]; + memcpy(original_rtp_data, rtp_packet_data, rtp_len); + + rtc::PacketOptions options; + // Send a packet from |srtp_transport1_| to |srtp_transport2_| and verify + // that the packet can be successfully received and decrypted. + ASSERT_TRUE(srtp_transport1_->SendRtpPacket(&rtp_packet1to2, options, + cricket::PF_SRTP_BYPASS)); + ASSERT_TRUE(last_recv_packet2_.data()); + EXPECT_TRUE(memcmp(last_recv_packet2_.data(), original_rtp_data, rtp_len) == + 0); + // Get the encrypted packet from underneath packet transport and verify the + // data and header extension are actually encrypted. + auto fake_rtp_packet_transport = static_cast( + srtp_transport1_->rtp_packet_transport()); + EXPECT_FALSE(memcmp(fake_rtp_packet_transport->last_sent_packet()->data(), + original_rtp_data, rtp_len) == 0); + CompareHeaderExtensions( + reinterpret_cast( + fake_rtp_packet_transport->last_sent_packet()->data()), + fake_rtp_packet_transport->last_sent_packet()->size(), + original_rtp_data, rtp_len, encrypted_header_ids, false); + + // Do the same thing in the opposite direction; + ASSERT_TRUE(srtp_transport2_->SendRtpPacket(&rtp_packet2to1, options, + cricket::PF_SRTP_BYPASS)); + ASSERT_TRUE(last_recv_packet1_.data()); + EXPECT_TRUE(memcmp(last_recv_packet1_.data(), original_rtp_data, rtp_len) == + 0); + fake_rtp_packet_transport = static_cast( + srtp_transport2_->rtp_packet_transport()); + EXPECT_FALSE(memcmp(fake_rtp_packet_transport->last_sent_packet()->data(), + original_rtp_data, rtp_len) == 0); + CompareHeaderExtensions( + reinterpret_cast( + fake_rtp_packet_transport->last_sent_packet()->data()), + fake_rtp_packet_transport->last_sent_packet()->size(), + original_rtp_data, rtp_len, encrypted_header_ids, false); + } + + void TestSendRecvEncryptedHeaderExtension(int cs, + const uint8_t* key1, + int key1_len, + const uint8_t* key2, + int key2_len, + const std::string& cs_name) { + std::vector encrypted_headers; + encrypted_headers.push_back(1); + // Don't encrypt header ids 2 and 3. + encrypted_headers.push_back(4); + EXPECT_EQ(key1_len, key2_len); + EXPECT_EQ(cs_name, rtc::SrtpCryptoSuiteToName(cs)); + srtp_transport1_->SetEncryptedHeaderExtensionIds(cricket::CS_LOCAL, + encrypted_headers); + srtp_transport1_->SetEncryptedHeaderExtensionIds(cricket::CS_REMOTE, + encrypted_headers); + srtp_transport2_->SetEncryptedHeaderExtensionIds(cricket::CS_LOCAL, + encrypted_headers); + srtp_transport2_->SetEncryptedHeaderExtensionIds(cricket::CS_REMOTE, + encrypted_headers); + EXPECT_TRUE( + srtp_transport1_->SetRtpParams(cs, key1, key1_len, cs, key2, key2_len)); + EXPECT_TRUE( + srtp_transport2_->SetRtpParams(cs, key2, key2_len, cs, key1, key1_len)); + EXPECT_TRUE(srtp_transport1_->IsActive()); + EXPECT_TRUE(srtp_transport2_->IsActive()); + EXPECT_FALSE(srtp_transport1_->IsExternalAuthActive()); + EXPECT_FALSE(srtp_transport2_->IsExternalAuthActive()); + TestSendRecvPacketWithEncryptedHeaderExtension(cs_name, encrypted_headers); + } + + std::unique_ptr srtp_transport1_; + std::unique_ptr srtp_transport2_; + + std::unique_ptr rtp_packet_transport1_; + std::unique_ptr rtp_packet_transport2_; + + rtc::CopyOnWriteBuffer last_recv_packet1_; + rtc::CopyOnWriteBuffer last_recv_packet2_; + int sequence_number_ = 0; }; -TEST(SrtpTransportTest, SendPacket) { - auto rtp_transport = rtc::MakeUnique(); - EXPECT_CALL(*rtp_transport, SendPacket(_, _, _, _)).WillOnce(Return(true)); +class SrtpTransportTestWithExternalAuth + : public SrtpTransportTest, + public testing::WithParamInterface {}; - SrtpTransport srtp_transport(std::move(rtp_transport), "a"); - - const bool rtcp = false; - rtc::CopyOnWriteBuffer packet; - rtc::PacketOptions options; - int flags = 0; - EXPECT_TRUE(srtp_transport.SendPacket(rtcp, &packet, options, flags)); - - // TODO(zstein): Also verify that the packet received by RtpTransport has been - // protected once SrtpTransport handles that. +TEST_P(SrtpTransportTestWithExternalAuth, + SendAndRecvPacket_AES_CM_128_HMAC_SHA1_80) { + bool enable_external_auth = GetParam(); + TestSendRecvPacket(enable_external_auth, rtc::SRTP_AES128_CM_SHA1_80, + kTestKey1, kTestKeyLen, kTestKey2, kTestKeyLen, + rtc::CS_AES_CM_128_HMAC_SHA1_80); } -// Test that SrtpTransport fires SignalPacketReceived when the underlying -// RtpTransport fires SignalPacketReceived. -TEST(SrtpTransportTest, SignalPacketReceived) { - auto rtp_transport = rtc::MakeUnique(); - MockRtpTransport* rtp_transport_raw = rtp_transport.get(); - SrtpTransport srtp_transport(std::move(rtp_transport), "a"); +TEST_F(SrtpTransportTest, + SendAndRecvPacketWithHeaderExtension_AES_CM_128_HMAC_SHA1_80) { + TestSendRecvEncryptedHeaderExtension(rtc::SRTP_AES128_CM_SHA1_80, kTestKey1, + kTestKeyLen, kTestKey2, kTestKeyLen, + rtc::CS_AES_CM_128_HMAC_SHA1_80); +} - SignalPacketReceivedCounter counter(&srtp_transport); +TEST_P(SrtpTransportTestWithExternalAuth, + SendAndRecvPacket_AES_CM_128_HMAC_SHA1_32) { + bool enable_external_auth = GetParam(); + TestSendRecvPacket(enable_external_auth, rtc::SRTP_AES128_CM_SHA1_32, + kTestKey1, kTestKeyLen, kTestKey2, kTestKeyLen, + rtc::CS_AES_CM_128_HMAC_SHA1_32); +} - rtp_transport_raw->PretendReceivedPacket(); +TEST_F(SrtpTransportTest, + SendAndRecvPacketWithHeaderExtension_AES_CM_128_HMAC_SHA1_32) { + TestSendRecvEncryptedHeaderExtension(rtc::SRTP_AES128_CM_SHA1_32, kTestKey1, + kTestKeyLen, kTestKey2, kTestKeyLen, + rtc::CS_AES_CM_128_HMAC_SHA1_32); +} - EXPECT_EQ(1, counter.rtp_count()); +TEST_P(SrtpTransportTestWithExternalAuth, + SendAndRecvPacket_SRTP_AEAD_AES_128_GCM) { + bool enable_external_auth = GetParam(); + TestSendRecvPacket(enable_external_auth, rtc::SRTP_AEAD_AES_128_GCM, + kTestKeyGcm128_1, kTestKeyGcm128Len, kTestKeyGcm128_2, + kTestKeyGcm128Len, rtc::CS_AEAD_AES_128_GCM); +} - // TODO(zstein): Also verify that the packet is unprotected once SrtpTransport - // handles that. +TEST_F(SrtpTransportTest, + SendAndRecvPacketWithHeaderExtension_SRTP_AEAD_AES_128_GCM) { + TestSendRecvEncryptedHeaderExtension( + rtc::SRTP_AEAD_AES_128_GCM, kTestKeyGcm128_1, kTestKeyGcm128Len, + kTestKeyGcm128_2, kTestKeyGcm128Len, rtc::CS_AEAD_AES_128_GCM); +} + +TEST_P(SrtpTransportTestWithExternalAuth, + SendAndRecvPacket_SRTP_AEAD_AES_256_GCM) { + bool enable_external_auth = GetParam(); + TestSendRecvPacket(enable_external_auth, rtc::SRTP_AEAD_AES_256_GCM, + kTestKeyGcm256_1, kTestKeyGcm256Len, kTestKeyGcm256_2, + kTestKeyGcm256Len, rtc::CS_AEAD_AES_256_GCM); +} + +TEST_F(SrtpTransportTest, + SendAndRecvPacketWithHeaderExtension_SRTP_AEAD_AES_256_GCM) { + TestSendRecvEncryptedHeaderExtension( + rtc::SRTP_AEAD_AES_256_GCM, kTestKeyGcm256_1, kTestKeyGcm256Len, + kTestKeyGcm256_2, kTestKeyGcm256Len, rtc::CS_AEAD_AES_256_GCM); +} + +// Run all tests both with and without external auth enabled. +INSTANTIATE_TEST_CASE_P(ExternalAuth, + SrtpTransportTestWithExternalAuth, + ::testing::Values(true, false)); + +// Test directly setting the params with bogus keys. +TEST_F(SrtpTransportTest, TestSetParamsKeyTooShort) { + EXPECT_FALSE(srtp_transport1_->SetRtpParams( + rtc::SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen - 1, + rtc::SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen - 1)); + EXPECT_FALSE(srtp_transport1_->SetRtcpParams( + rtc::SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen - 1, + rtc::SRTP_AES128_CM_SHA1_80, kTestKey1, kTestKeyLen - 1)); } } // namespace webrtc