diff --git a/webrtc/p2p/base/dtlstransportchannel.cc b/webrtc/p2p/base/dtlstransportchannel.cc index d25de260f3..a6b06361f4 100644 --- a/webrtc/p2p/base/dtlstransportchannel.cc +++ b/webrtc/p2p/base/dtlstransportchannel.cc @@ -238,17 +238,19 @@ bool DtlsTransportChannelWrapper::SetRemoteFingerprint( remote_fingerprint_value_ = std::move(remote_fingerprint_value); remote_fingerprint_algorithm_ = digest_alg; - bool reconnect = (dtls_ != nullptr); + if (dtls_) { + // If the fingerprint is changing, we'll tear down the DTLS association and + // create a new one, resetting our state. + dtls_.reset(nullptr); + set_dtls_state(DTLS_TRANSPORT_NEW); + set_writable(false); + } if (!SetupDtls()) { set_dtls_state(DTLS_TRANSPORT_FAILED); return false; } - if (reconnect) { - Reconnect(); - } - return true; } @@ -297,6 +299,10 @@ bool DtlsTransportChannelWrapper::SetupDtls() { } LOG_J(LS_INFO, this) << "DTLS setup complete."; + + // If the underlying channel is already writable at this point, we may be + // able to start DTLS right away. + MaybeStartDtls(); return true; } @@ -419,15 +425,7 @@ void DtlsTransportChannelWrapper::OnWritableState(TransportChannel* channel) { switch (dtls_state()) { case DTLS_TRANSPORT_NEW: - // This should never fail: - // Because we are operating in a nonblocking mode and all - // incoming packets come in via OnReadPacket(), which rejects - // packets in this state, the incoming queue must be empty. We - // ignore write errors, thus any errors must be because of - // configuration and therefore are our fault. - // Note that in non-debug configurations, failure in - // MaybeStartDtls() changes the state to DTLS_TRANSPORT_FAILED. - VERIFY(MaybeStartDtls()); + MaybeStartDtls(); break; case DTLS_TRANSPORT_CONNECTED: // Note: SignalWritableState fired by set_writable. @@ -571,12 +569,19 @@ void DtlsTransportChannelWrapper::OnDtlsEvent(rtc::StreamInterface* dtls, } } -bool DtlsTransportChannelWrapper::MaybeStartDtls() { +void DtlsTransportChannelWrapper::MaybeStartDtls() { if (dtls_ && channel_->writable()) { if (dtls_->StartSSLWithPeer()) { + // This should never fail: + // Because we are operating in a nonblocking mode and all + // incoming packets come in via OnReadPacket(), which rejects + // packets in this state, the incoming queue must be empty. We + // ignore write errors, thus any errors must be because of + // configuration and therefore are our fault. + RTC_DCHECK(false) << "StartSSLWithPeer failed."; LOG_J(LS_ERROR, this) << "Couldn't start DTLS handshake"; set_dtls_state(DTLS_TRANSPORT_FAILED); - return false; + return; } LOG_J(LS_INFO, this) << "DtlsTransportChannelWrapper: Started DTLS handshake"; @@ -597,7 +602,6 @@ bool DtlsTransportChannelWrapper::MaybeStartDtls() { cached_client_hello_.Clear(); } } - return true; } // Called from OnReadPacket when a DTLS packet is received. @@ -672,12 +676,4 @@ void DtlsTransportChannelWrapper::OnChannelStateChanged( SignalStateChanged(this); } -void DtlsTransportChannelWrapper::Reconnect() { - set_dtls_state(DTLS_TRANSPORT_NEW); - set_writable(false); - if (channel_->writable()) { - OnWritableState(channel_); - } -} - } // namespace cricket diff --git a/webrtc/p2p/base/dtlstransportchannel.h b/webrtc/p2p/base/dtlstransportchannel.h index c8b76d3303..7b8247a449 100644 --- a/webrtc/p2p/base/dtlstransportchannel.h +++ b/webrtc/p2p/base/dtlstransportchannel.h @@ -207,7 +207,7 @@ class DtlsTransportChannelWrapper : public TransportChannelImpl { void OnReceivingState(TransportChannel* channel); void OnDtlsEvent(rtc::StreamInterface* stream_, int sig, int err); bool SetupDtls(); - bool MaybeStartDtls(); + void MaybeStartDtls(); bool HandleDtlsPacket(const char* data, size_t size); void OnGatheringState(TransportChannelImpl* channel); void OnCandidateGathered(TransportChannelImpl* channel, const Candidate& c); @@ -221,7 +221,6 @@ class DtlsTransportChannelWrapper : public TransportChannelImpl { int last_sent_packet_id, bool ready_to_send); void OnChannelStateChanged(TransportChannelImpl* channel); - void Reconnect(); rtc::Thread* worker_thread_; // Everything should occur on this thread. // Underlying channel, not owned by this class. diff --git a/webrtc/p2p/base/dtlstransportchannel_unittest.cc b/webrtc/p2p/base/dtlstransportchannel_unittest.cc index 705df2d95f..6eb0f0e3f1 100644 --- a/webrtc/p2p/base/dtlstransportchannel_unittest.cc +++ b/webrtc/p2p/base/dtlstransportchannel_unittest.cc @@ -45,9 +45,13 @@ cricket::TransportDescription MakeTransportDescription( std::unique_ptr fingerprint; if (cert) { std::string digest_algorithm; - cert->ssl_certificate().GetSignatureDigestAlgorithm(&digest_algorithm); + EXPECT_TRUE( + cert->ssl_certificate().GetSignatureDigestAlgorithm(&digest_algorithm)); + EXPECT_FALSE(digest_algorithm.empty()); fingerprint.reset( rtc::SSLFingerprint::Create(digest_algorithm, cert->identity())); + EXPECT_TRUE(fingerprint.get() != NULL); + EXPECT_EQ(rtc::DIGEST_SHA_256, digest_algorithm); } return cricket::TransportDescription(std::vector(), kIceUfrag1, kIcePwd1, cricket::ICEMODE_FULL, role, @@ -124,6 +128,48 @@ class DtlsTestClient : public sigslot::has_slots<> { local_role, remote_role, flags); } + void MaybeSetSrtpCryptoSuites() { + if (!use_dtls_srtp_) { + return; + } + std::vector ciphers; + ciphers.push_back(rtc::SRTP_AES128_CM_SHA1_80); + // SRTP ciphers will be set only in the beginning. + for (cricket::DtlsTransportChannelWrapper* channel : channels_) { + EXPECT_TRUE(channel->SetSrtpCryptoSuites(ciphers)); + } + } + + void SetLocalTransportDescription( + const rtc::scoped_refptr& cert, + cricket::ContentAction action, + ConnectionRole role, + int flags) { + // If |NF_EXPECT_FAILURE| is set, expect SRTD or SLTD to fail when + // content action is CA_ANSWER. + bool expect_success = + !((action == cricket::CA_ANSWER) && (flags & NF_EXPECT_FAILURE)); + EXPECT_EQ(expect_success, + transport_->SetLocalTransportDescription( + MakeTransportDescription(cert, role), action, nullptr)); + set_local_cert_ = (cert != nullptr); + } + + void SetRemoteTransportDescription( + const rtc::scoped_refptr& cert, + cricket::ContentAction action, + ConnectionRole role, + int flags) { + // If |NF_EXPECT_FAILURE| is set, expect SRTD or SLTD to fail when + // content action is CA_ANSWER. + bool expect_success = + !((action == cricket::CA_ANSWER) && (flags & NF_EXPECT_FAILURE)); + EXPECT_EQ(expect_success, + transport_->SetRemoteTransportDescription( + MakeTransportDescription(cert, role), action, nullptr)); + set_remote_cert_ = (cert != nullptr); + } + // Allow any DTLS configuration to be specified (including invalid ones). void Negotiate(const rtc::scoped_refptr& local_cert, const rtc::scoped_refptr& remote_cert, @@ -131,67 +177,23 @@ class DtlsTestClient : public sigslot::has_slots<> { ConnectionRole local_role, ConnectionRole remote_role, int flags) { - std::unique_ptr local_fingerprint; - std::unique_ptr remote_fingerprint; - if (local_cert) { - std::string digest_algorithm; - ASSERT_TRUE(local_cert->ssl_certificate().GetSignatureDigestAlgorithm( - &digest_algorithm)); - ASSERT_FALSE(digest_algorithm.empty()); - local_fingerprint.reset(rtc::SSLFingerprint::Create( - digest_algorithm, local_cert->identity())); - ASSERT_TRUE(local_fingerprint.get() != NULL); - EXPECT_EQ(rtc::DIGEST_SHA_256, digest_algorithm); - } - if (remote_cert) { - std::string digest_algorithm; - ASSERT_TRUE(remote_cert->ssl_certificate().GetSignatureDigestAlgorithm( - &digest_algorithm)); - ASSERT_FALSE(digest_algorithm.empty()); - remote_fingerprint.reset(rtc::SSLFingerprint::Create( - digest_algorithm, remote_cert->identity())); - ASSERT_TRUE(remote_fingerprint.get() != NULL); - EXPECT_EQ(rtc::DIGEST_SHA_256, digest_algorithm); - } - - if (use_dtls_srtp_ && !(flags & NF_REOFFER)) { + if (!(flags & NF_REOFFER)) { // SRTP ciphers will be set only in the beginning. - for (std::vector::iterator it = - channels_.begin(); it != channels_.end(); ++it) { - std::vector ciphers; - ciphers.push_back(rtc::SRTP_AES128_CM_SHA1_80); - ASSERT_TRUE((*it)->SetSrtpCryptoSuites(ciphers)); - } + MaybeSetSrtpCryptoSuites(); } - - cricket::TransportDescription local_desc( - std::vector(), kIceUfrag1, kIcePwd1, cricket::ICEMODE_FULL, - local_role, - // If remote if the offerer and has no DTLS support, answer will be - // without any fingerprint. - (action == cricket::CA_ANSWER && !remote_cert) - ? nullptr - : local_fingerprint.get()); - - cricket::TransportDescription remote_desc( - std::vector(), kIceUfrag1, kIcePwd1, cricket::ICEMODE_FULL, - remote_role, remote_fingerprint.get()); - - bool expect_success = (flags & NF_EXPECT_FAILURE) ? false : true; - // If |expect_success| is false, expect SRTD or SLTD to fail when - // content action is CA_ANSWER. if (action == cricket::CA_OFFER) { - ASSERT_TRUE(transport_->SetLocalTransportDescription( - local_desc, cricket::CA_OFFER, NULL)); - ASSERT_EQ(expect_success, transport_->SetRemoteTransportDescription( - remote_desc, cricket::CA_ANSWER, NULL)); + SetLocalTransportDescription(local_cert, cricket::CA_OFFER, local_role, + flags); + SetRemoteTransportDescription(remote_cert, cricket::CA_ANSWER, + remote_role, flags); } else { - ASSERT_TRUE(transport_->SetRemoteTransportDescription( - remote_desc, cricket::CA_OFFER, NULL)); - ASSERT_EQ(expect_success, transport_->SetLocalTransportDescription( - local_desc, cricket::CA_ANSWER, NULL)); + SetRemoteTransportDescription(remote_cert, cricket::CA_OFFER, remote_role, + flags); + // If remote if the offerer and has no DTLS support, answer will be + // without any fingerprint. + SetLocalTransportDescription(remote_cert ? local_cert : nullptr, + cricket::CA_ANSWER, local_role, flags); } - negotiated_dtls_ = (local_cert && remote_cert); } bool Connect(DtlsTestClient* peer, bool asymmetric) { @@ -227,6 +229,8 @@ class DtlsTestClient : public sigslot::has_slots<> { return received_dtls_client_hellos_; } + bool negotiated_dtls() const { return set_local_cert_ && set_remote_cert_; } + void CheckRole(rtc::SSLRole role) { if (role == rtc::SSL_CLIENT) { ASSERT_EQ(0, received_dtls_client_hellos_); @@ -243,7 +247,7 @@ class DtlsTestClient : public sigslot::has_slots<> { int crypto_suite; bool rv = (*it)->GetSrtpCryptoSuite(&crypto_suite); - if (negotiated_dtls_ && expected_crypto_suite) { + if (negotiated_dtls() && expected_crypto_suite) { ASSERT_TRUE(rv); ASSERT_EQ(crypto_suite, expected_crypto_suite); @@ -259,7 +263,7 @@ class DtlsTestClient : public sigslot::has_slots<> { int cipher; bool rv = (*it)->GetSslCipherSuite(&cipher); - if (negotiated_dtls_) { + if (negotiated_dtls()) { ASSERT_TRUE(rv); EXPECT_TRUE( @@ -388,7 +392,7 @@ class DtlsTestClient : public sigslot::has_slots<> { } else if (data[13] == 2) { ++received_dtls_server_hellos_; } - } else if (negotiated_dtls_ && !(data[0] >= 20 && data[0] <= 22)) { + } else if (negotiated_dtls() && !(data[0] >= 20 && data[0] <= 22)) { ASSERT_TRUE(data[0] == 23 || IsRtpLeadByte(data[0])); if (data[0] == 23) { ASSERT_TRUE(VerifyEncryptedPacket(data, size)); @@ -407,7 +411,8 @@ class DtlsTestClient : public sigslot::has_slots<> { std::set received_; bool use_dtls_srtp_ = false; rtc::SSLProtocolVersion ssl_max_version_ = rtc::SSL_PROTOCOL_DTLS_12; - bool negotiated_dtls_ = false; + bool set_local_cert_ = false; + bool set_remote_cert_ = false; int received_dtls_client_hellos_ = 0; int received_dtls_server_hellos_ = 0; rtc::SentPacket sent_packet_; @@ -457,10 +462,35 @@ class DtlsTransportChannelTest : public testing::Test { use_dtls_srtp_ = true; } - bool Connect(ConnectionRole client1_role, ConnectionRole client2_role) { - Negotiate(client1_role, client2_role); + // Negotiate local/remote fingerprint before or after the underlying + // tranpsort is connected? + enum NegotiateOrdering { NEGOTIATE_BEFORE_CONNECT, CONNECT_BEFORE_NEGOTIATE }; + bool Connect(ConnectionRole client1_role, + ConnectionRole client2_role, + NegotiateOrdering ordering = NEGOTIATE_BEFORE_CONNECT) { + bool rv; + if (ordering == NEGOTIATE_BEFORE_CONNECT) { + Negotiate(client1_role, client2_role); + rv = client1_.Connect(&client2_, false); + } else { + client1_.SetupChannels(channel_ct_, cricket::ICEROLE_CONTROLLING); + client2_.SetupChannels(channel_ct_, cricket::ICEROLE_CONTROLLED); + client1_.MaybeSetSrtpCryptoSuites(); + client2_.MaybeSetSrtpCryptoSuites(); + // This is equivalent to an offer being processed on both sides, but an + // answer not yet being received on the initiating side. So the + // connection will be made before negotiation has finished on both sides. + client1_.SetLocalTransportDescription(client1_.certificate(), + cricket::CA_OFFER, client1_role, 0); + client2_.SetRemoteTransportDescription( + client1_.certificate(), cricket::CA_OFFER, client1_role, 0); + client2_.SetLocalTransportDescription( + client2_.certificate(), cricket::CA_ANSWER, client2_role, 0); + rv = client1_.Connect(&client2_, false); + client1_.SetRemoteTransportDescription( + client2_.certificate(), cricket::CA_ANSWER, client2_role, 0); + } - bool rv = client1_.Connect(&client2_, false); EXPECT_TRUE(rv); if (!rv) return false; @@ -1018,3 +1048,14 @@ TEST_F(DtlsTransportChannelTest, TestRetransmissionSchedule) { EXPECT_EQ(++expected_hellos, client1_.received_dtls_client_hellos()); } } + +// Test that a DTLS connection can be made even if the underlying transport +// is connected before DTLS fingerprints/roles have been negotiated. +TEST_F(DtlsTransportChannelTest, TestConnectBeforeNegotiate) { + MAYBE_SKIP_TEST(HaveDtls); + PrepareDtls(true, true, rtc::KT_DEFAULT); + ASSERT_TRUE(Connect(cricket::CONNECTIONROLE_ACTPASS, + cricket::CONNECTIONROLE_ACTIVE, + CONNECT_BEFORE_NEGOTIATE)); + TestTransfer(0, 1000, 100, false); +}