diff --git a/webrtc/p2p/base/dtlstransportchannel.cc b/webrtc/p2p/base/dtlstransportchannel.cc index f9baa37562..1e3df13493 100644 --- a/webrtc/p2p/base/dtlstransportchannel.cc +++ b/webrtc/p2p/base/dtlstransportchannel.cc @@ -37,6 +37,13 @@ static bool IsDtlsPacket(const char* data, size_t len) { const uint8_t* u = reinterpret_cast(data); return (len >= kDtlsRecordHeaderLen && (u[0] > 19 && u[0] < 64)); } +static bool IsDtlsClientHelloPacket(const char* data, size_t len) { + if (!IsDtlsPacket(data, len)) { + return false; + } + const uint8_t* u = reinterpret_cast(data); + return len > 17 && u[0] == 22 && u[13] == 1; +} static bool IsRtpPacket(const char* data, size_t len) { const uint8_t* u = reinterpret_cast(data); return (len >= kMinRtpPacketLen && (u[0] & 0xC0) == 0x80); @@ -470,15 +477,18 @@ void DtlsTransportChannelWrapper::OnReadPacket( switch (dtls_state()) { case DTLS_TRANSPORT_NEW: if (dtls_) { - // Drop packets received before DTLS has actually started. - LOG_J(LS_INFO, this) << "Dropping packet received before DTLS started."; + LOG_J(LS_INFO, this) << "Packet received before DTLS started."; } else { - // Currently drop the packet, but we might in future - // decide to take this as evidence that the other - // side is ready to do DTLS and start the handshake - // on our end. - LOG_J(LS_WARNING, this) << "Received packet before we know if we are " - << "doing DTLS or not; dropping."; + LOG_J(LS_WARNING, this) << "Packet received before we know if we are " + << "doing DTLS or not."; + } + // Cache a client hello packet received before DTLS has actually started. + if (IsDtlsClientHelloPacket(data, size)) { + LOG_J(LS_INFO, this) << "Caching DTLS ClientHello packet until DTLS is " + << "started."; + cached_client_hello_.SetData(data, size); + } else { + LOG_J(LS_INFO, this) << "Not a DTLS ClientHello packet; dropping."; } break; @@ -577,6 +587,21 @@ bool DtlsTransportChannelWrapper::MaybeStartDtls() { LOG_J(LS_INFO, this) << "DtlsTransportChannelWrapper: Started DTLS handshake"; set_dtls_state(DTLS_TRANSPORT_CONNECTING); + // Now that the handshake has started, we can process a cached ClientHello + // (if one exists). + if (cached_client_hello_.size()) { + if (ssl_role_ == rtc::SSL_SERVER) { + LOG_J(LS_INFO, this) << "Handling cached DTLS ClientHello packet."; + if (!HandleDtlsPacket(cached_client_hello_.data(), + cached_client_hello_.size())) { + LOG_J(LS_ERROR, this) << "Failed to handle DTLS packet."; + } + } else { + LOG_J(LS_WARNING, this) << "Discarding cached DTLS ClientHello packet " + << "because we don't have the server role."; + } + cached_client_hello_.Clear(); + } } return true; } diff --git a/webrtc/p2p/base/dtlstransportchannel.h b/webrtc/p2p/base/dtlstransportchannel.h index f004bb1460..c5f55469f2 100644 --- a/webrtc/p2p/base/dtlstransportchannel.h +++ b/webrtc/p2p/base/dtlstransportchannel.h @@ -237,6 +237,12 @@ class DtlsTransportChannelWrapper : public TransportChannelImpl { rtc::Buffer remote_fingerprint_value_; std::string remote_fingerprint_algorithm_; + // Cached DTLS ClientHello packet that was received before we started the + // DTLS handshake. This could happen if the hello was received before the + // transport channel became writable, or before a remote fingerprint was + // received. + rtc::Buffer cached_client_hello_; + RTC_DISALLOW_COPY_AND_ASSIGN(DtlsTransportChannelWrapper); }; diff --git a/webrtc/p2p/base/dtlstransportchannel_unittest.cc b/webrtc/p2p/base/dtlstransportchannel_unittest.cc index 2c97ac6c6e..3e34affb4b 100644 --- a/webrtc/p2p/base/dtlstransportchannel_unittest.cc +++ b/webrtc/p2p/base/dtlstransportchannel_unittest.cc @@ -33,25 +33,34 @@ static const char kIcePwd1[] = "TESTICEPWD00000000000001"; static const size_t kPacketNumOffset = 8; static const size_t kPacketHeaderLen = 12; static const int kFakePacketId = 0x1234; +static const int kTimeout = 10000; static bool IsRtpLeadByte(uint8_t b) { return ((b & 0xC0) == 0x80); } +cricket::TransportDescription MakeTransportDescription( + const rtc::scoped_refptr& cert, + cricket::ConnectionRole role) { + rtc::scoped_ptr fingerprint; + if (cert) { + std::string digest_algorithm; + cert->ssl_certificate().GetSignatureDigestAlgorithm(&digest_algorithm); + fingerprint.reset( + rtc::SSLFingerprint::Create(digest_algorithm, cert->identity())); + } + return cricket::TransportDescription(std::vector(), kIceUfrag1, + kIcePwd1, cricket::ICEMODE_FULL, role, + fingerprint.get()); +} + using cricket::ConnectionRole; enum Flags { NF_REOFFER = 0x1, NF_EXPECT_FAILURE = 0x2 }; class DtlsTestClient : public sigslot::has_slots<> { public: - DtlsTestClient(const std::string& name) - : name_(name), - packet_size_(0), - use_dtls_srtp_(false), - ssl_max_version_(rtc::SSL_PROTOCOL_DTLS_12), - negotiated_dtls_(false), - received_dtls_client_hello_(false), - received_dtls_server_hello_(false) {} + DtlsTestClient(const std::string& name) : name_(name) {} void CreateCertificate(rtc::KeyType key_type) { certificate_ = rtc::RTCCertificate::Create(std::unique_ptr( @@ -185,9 +194,8 @@ class DtlsTestClient : public sigslot::has_slots<> { negotiated_dtls_ = (local_cert && remote_cert); } - bool Connect(DtlsTestClient* peer) { - transport_->ConnectChannels(); - transport_->SetDestination(peer->transport_.get()); + bool Connect(DtlsTestClient* peer, bool asymmetric) { + transport_->SetDestination(peer->transport_.get(), asymmetric); return true; } @@ -203,13 +211,29 @@ class DtlsTestClient : public sigslot::has_slots<> { return true; } + bool all_raw_channels_writable() const { + if (channels_.empty()) { + return false; + } + for (cricket::DtlsTransportChannelWrapper* channel : channels_) { + if (!channel->channel()->writable()) { + return false; + } + } + return true; + } + + int received_dtls_client_hellos() const { + return received_dtls_client_hellos_; + } + void CheckRole(rtc::SSLRole role) { if (role == rtc::SSL_CLIENT) { - ASSERT_FALSE(received_dtls_client_hello_); - ASSERT_TRUE(received_dtls_server_hello_); + ASSERT_EQ(0, received_dtls_client_hellos_); + ASSERT_GT(received_dtls_server_hellos_, 0); } else { - ASSERT_TRUE(received_dtls_client_hello_); - ASSERT_FALSE(received_dtls_server_hello_); + ASSERT_GT(received_dtls_client_hellos_, 0); + ASSERT_EQ(0, received_dtls_server_hellos_); } } @@ -358,20 +382,18 @@ class DtlsTestClient : public sigslot::has_slots<> { // Look at the handshake packets to see what role we played. // Check that non-handshake packets are DTLS data or SRTP bypass. - if (negotiated_dtls_) { - if (data[0] == 22 && size > 17) { - if (data[13] == 1) { - received_dtls_client_hello_ = true; - } else if (data[13] == 2) { - received_dtls_server_hello_ = true; - } - } else if (!(data[0] >= 20 && data[0] <= 22)) { - ASSERT_TRUE(data[0] == 23 || IsRtpLeadByte(data[0])); - if (data[0] == 23) { - ASSERT_TRUE(VerifyEncryptedPacket(data, size)); - } else if (IsRtpLeadByte(data[0])) { - ASSERT_TRUE(VerifyPacket(data, size, NULL)); - } + if (data[0] == 22 && size > 17) { + if (data[13] == 1) { + ++received_dtls_client_hellos_; + } else if (data[13] == 2) { + ++received_dtls_server_hellos_; + } + } else if (negotiated_dtls_ && !(data[0] >= 20 && data[0] <= 22)) { + ASSERT_TRUE(data[0] == 23 || IsRtpLeadByte(data[0])); + if (data[0] == 23) { + ASSERT_TRUE(VerifyEncryptedPacket(data, size)); + } else if (IsRtpLeadByte(data[0])) { + ASSERT_TRUE(VerifyPacket(data, size, NULL)); } } } @@ -381,13 +403,13 @@ class DtlsTestClient : public sigslot::has_slots<> { rtc::scoped_refptr certificate_; std::unique_ptr transport_; std::vector channels_; - size_t packet_size_; + size_t packet_size_ = 0u; std::set received_; - bool use_dtls_srtp_; - rtc::SSLProtocolVersion ssl_max_version_; - bool negotiated_dtls_; - bool received_dtls_client_hello_; - bool received_dtls_server_hello_; + bool use_dtls_srtp_ = false; + rtc::SSLProtocolVersion ssl_max_version_ = rtc::SSL_PROTOCOL_DTLS_12; + bool negotiated_dtls_ = false; + int received_dtls_client_hellos_ = 0; + int received_dtls_server_hellos_ = 0; rtc::SentPacket sent_packet_; }; @@ -437,14 +459,14 @@ class DtlsTransportChannelTest : public testing::Test { bool Connect(ConnectionRole client1_role, ConnectionRole client2_role) { Negotiate(client1_role, client2_role); - bool rv = client1_.Connect(&client2_); + bool rv = client1_.Connect(&client2_, false); EXPECT_TRUE(rv); if (!rv) return false; EXPECT_TRUE_WAIT( client1_.all_channels_writable() && client2_.all_channels_writable(), - 10000); + kTimeout); if (!client1_.all_channels_writable() || !client2_.all_channels_writable()) return false; @@ -535,7 +557,7 @@ class DtlsTransportChannelTest : public testing::Test { LOG(LS_INFO) << "Expect packets, size=" << size; client2_.ExpectPackets(channel, size); client1_.SendPackets(channel, size, count, srtp); - EXPECT_EQ_WAIT(count, client2_.NumPacketsReceived(), 10000); + EXPECT_EQ_WAIT(count, client2_.NumPacketsReceived(), kTimeout); } protected: @@ -828,11 +850,11 @@ TEST_F(DtlsTransportChannelTest, TestRenegotiateBeforeConnect) { Renegotiate(&client1_, cricket::CONNECTIONROLE_ACTPASS, cricket::CONNECTIONROLE_ACTIVE, NF_REOFFER); - bool rv = client1_.Connect(&client2_); + bool rv = client1_.Connect(&client2_, false); EXPECT_TRUE(rv); EXPECT_TRUE_WAIT( client1_.all_channels_writable() && client2_.all_channels_writable(), - 10000); + kTimeout); TestTransfer(0, 1000, 100, true); TestTransfer(1, 1000, 100, true); @@ -886,3 +908,69 @@ TEST_F(DtlsTransportChannelTest, TestCertificatesAfterConnect) { ASSERT_EQ(remote_cert2->ToPEMString(), certificate1->ssl_certificate().ToPEMString()); } + +// Test that DTLS completes promptly if a ClientHello is received before the +// transport channel is writable (allowing a ServerHello to be sent). +TEST_F(DtlsTransportChannelTest, TestReceiveClientHelloBeforeWritable) { + MAYBE_SKIP_TEST(HaveDtls); + PrepareDtls(true, true, rtc::KT_DEFAULT); + // Exchange transport descriptions. + Negotiate(cricket::CONNECTIONROLE_ACTPASS, cricket::CONNECTIONROLE_ACTIVE); + + // Make client2_ writable, but not client1_. + EXPECT_TRUE(client2_.Connect(&client1_, true)); + EXPECT_TRUE_WAIT(client2_.all_raw_channels_writable(), kTimeout); + + // Expect a DTLS ClientHello to be sent even while client1_ isn't writable. + EXPECT_EQ_WAIT(1, client1_.received_dtls_client_hellos(), kTimeout); + EXPECT_FALSE(client1_.all_raw_channels_writable()); + + // Now make client1_ writable and expect the handshake to complete + // without client2_ needing to retransmit the ClientHello. + EXPECT_TRUE(client1_.Connect(&client2_, true)); + EXPECT_TRUE_WAIT( + client1_.all_channels_writable() && client2_.all_channels_writable(), + kTimeout); + EXPECT_EQ(1, client1_.received_dtls_client_hellos()); +} + +// Test that DTLS completes promptly if a ClientHello is received before the +// transport channel has a remote fingerprint (allowing a ServerHello to be +// sent). +TEST_F(DtlsTransportChannelTest, + TestReceiveClientHelloBeforeRemoteFingerprint) { + MAYBE_SKIP_TEST(HaveDtls); + PrepareDtls(true, true, rtc::KT_DEFAULT); + client1_.SetupChannels(channel_ct_, cricket::ICEROLE_CONTROLLING); + client2_.SetupChannels(channel_ct_, cricket::ICEROLE_CONTROLLED); + + // Make client2_ writable and give it local/remote certs, but don't yet give + // client1_ a remote fingerprint. + client1_.transport()->SetLocalTransportDescription( + MakeTransportDescription(client1_.certificate(), + cricket::CONNECTIONROLE_ACTPASS), + cricket::CA_OFFER, nullptr); + client2_.Negotiate(&client1_, cricket::CA_ANSWER, + cricket::CONNECTIONROLE_ACTIVE, + cricket::CONNECTIONROLE_ACTPASS, 0); + EXPECT_TRUE(client2_.Connect(&client1_, true)); + EXPECT_TRUE_WAIT(client2_.all_raw_channels_writable(), kTimeout); + + // Expect a DTLS ClientHello to be sent even while client1_ doesn't have a + // remote fingerprint. + EXPECT_EQ_WAIT(1, client1_.received_dtls_client_hellos(), kTimeout); + EXPECT_FALSE(client1_.all_raw_channels_writable()); + + // Now make give client1_ its remote fingerprint and make it writable, and + // expect the handshake to complete without client2_ needing to retransmit + // the ClientHello. + client1_.transport()->SetRemoteTransportDescription( + MakeTransportDescription(client2_.certificate(), + cricket::CONNECTIONROLE_ACTIVE), + cricket::CA_ANSWER, nullptr); + EXPECT_TRUE(client1_.Connect(&client2_, true)); + EXPECT_TRUE_WAIT( + client1_.all_channels_writable() && client2_.all_channels_writable(), + kTimeout); + EXPECT_EQ(1, client1_.received_dtls_client_hellos()); +} diff --git a/webrtc/p2p/base/faketransportcontroller.h b/webrtc/p2p/base/faketransportcontroller.h index d2fdb3c904..1ed3fd85e6 100644 --- a/webrtc/p2p/base/faketransportcontroller.h +++ b/webrtc/p2p/base/faketransportcontroller.h @@ -141,20 +141,22 @@ class FakeTransportChannel : public TransportChannelImpl, void SetWritable(bool writable) { set_writable(writable); } - void SetDestination(FakeTransportChannel* dest) { + // Simulates the two transport channels connecting to each other. + // If |asymmetric| is true this method only affects this FakeTransportChannel. + // If false, it affects |dest| as well. + void SetDestination(FakeTransportChannel* dest, bool asymmetric = false) { if (state_ == STATE_CONNECTING && dest) { // This simulates the delivery of candidates. dest_ = dest; - dest_->dest_ = this; if (local_cert_ && dest_->local_cert_) { do_dtls_ = true; - dest_->do_dtls_ = true; NegotiateSrtpCiphers(); } state_ = STATE_CONNECTED; - dest_->state_ = STATE_CONNECTED; set_writable(true); - dest_->set_writable(true); + if (!asymmetric) { + dest->SetDestination(this, true); + } } else if (state_ == STATE_CONNECTED && !dest) { // Simulates loss of connectivity, by asymmetrically forgetting dest_. dest_ = nullptr; @@ -282,20 +284,6 @@ class FakeTransportChannel : public TransportChannelImpl, return false; } - void NegotiateSrtpCiphers() { - for (std::vector::const_iterator it1 = srtp_ciphers_.begin(); - it1 != srtp_ciphers_.end(); ++it1) { - for (std::vector::const_iterator it2 = dest_->srtp_ciphers_.begin(); - it2 != dest_->srtp_ciphers_.end(); ++it2) { - if (*it1 == *it2) { - chosen_crypto_suite_ = *it1; - dest_->chosen_crypto_suite_ = *it2; - return; - } - } - } - } - bool GetStats(ConnectionInfos* infos) override { ConnectionInfo info; infos->clear(); @@ -311,6 +299,19 @@ class FakeTransportChannel : public TransportChannelImpl, } private: + void NegotiateSrtpCiphers() { + for (std::vector::const_iterator it1 = srtp_ciphers_.begin(); + it1 != srtp_ciphers_.end(); ++it1) { + for (std::vector::const_iterator it2 = dest_->srtp_ciphers_.begin(); + it2 != dest_->srtp_ciphers_.end(); ++it2) { + if (*it1 == *it2) { + chosen_crypto_suite_ = *it1; + return; + } + } + } + } + enum State { STATE_INIT, STATE_CONNECTING, STATE_CONNECTED }; FakeTransportChannel* dest_ = nullptr; State state_ = STATE_INIT; @@ -359,11 +360,14 @@ class FakeTransport : public Transport { // If async, will send packets by "Post"-ing to message queue instead of // synchronously "Send"-ing. void SetAsync(bool async) { async_ = async; } - void SetDestination(FakeTransport* dest) { + + // If |asymmetric| is true, only set the destination for this transport, and + // not |dest|. + void SetDestination(FakeTransport* dest, bool asymmetric = false) { dest_ = dest; for (const auto& kv : channels_) { kv.second->SetLocalCertificate(certificate_); - SetChannelDestination(kv.first, kv.second); + SetChannelDestination(kv.first, kv.second, asymmetric); } } @@ -417,7 +421,7 @@ class FakeTransport : public Transport { FakeTransportChannel* channel = new FakeTransportChannel(name(), component); channel->set_ssl_max_protocol_version(ssl_max_version_); channel->SetAsync(async_); - SetChannelDestination(component, channel); + SetChannelDestination(component, channel, false); channels_[component] = channel; return channel; } @@ -433,15 +437,17 @@ class FakeTransport : public Transport { return (it != channels_.end()) ? it->second : nullptr; } - void SetChannelDestination(int component, FakeTransportChannel* channel) { + void SetChannelDestination(int component, + FakeTransportChannel* channel, + bool asymmetric) { FakeTransportChannel* dest_channel = nullptr; if (dest_) { dest_channel = dest_->GetFakeChannel(component); - if (dest_channel) { + if (dest_channel && !asymmetric) { dest_channel->SetLocalCertificate(dest_->certificate_); } } - channel->SetDestination(dest_channel); + channel->SetDestination(dest_channel, asymmetric); } // Note, this is distinct from the Channel map owned by Transport.