Cache a ClientHello received before the DTLS handshake has started.
In some cases, the DTLS ClientHello may arrive before the server's transport is writable (before it receives a STUN ping response), or even before it receives a remote fingerprint. If this packet is discarded, it may take a second for a it to be sent again. So, this CL caches it instead of dropping it, and feeds it into the SSL library once the handshake has been started. BUG=webrtc:5789 Review-Url: https://codereview.webrtc.org/1912323002 Cr-Commit-Position: refs/heads/master@{#12634}
This commit is contained in:
parent
fac23f00ef
commit
e84cd2eaca
@ -37,6 +37,13 @@ static bool IsDtlsPacket(const char* data, size_t len) {
|
||||
const uint8_t* u = reinterpret_cast<const uint8_t*>(data);
|
||||
return (len >= kDtlsRecordHeaderLen && (u[0] > 19 && u[0] < 64));
|
||||
}
|
||||
static bool IsDtlsClientHelloPacket(const char* data, size_t len) {
|
||||
if (!IsDtlsPacket(data, len)) {
|
||||
return false;
|
||||
}
|
||||
const uint8_t* u = reinterpret_cast<const uint8_t*>(data);
|
||||
return len > 17 && u[0] == 22 && u[13] == 1;
|
||||
}
|
||||
static bool IsRtpPacket(const char* data, size_t len) {
|
||||
const uint8_t* u = reinterpret_cast<const uint8_t*>(data);
|
||||
return (len >= kMinRtpPacketLen && (u[0] & 0xC0) == 0x80);
|
||||
@ -470,15 +477,18 @@ void DtlsTransportChannelWrapper::OnReadPacket(
|
||||
switch (dtls_state()) {
|
||||
case DTLS_TRANSPORT_NEW:
|
||||
if (dtls_) {
|
||||
// Drop packets received before DTLS has actually started.
|
||||
LOG_J(LS_INFO, this) << "Dropping packet received before DTLS started.";
|
||||
LOG_J(LS_INFO, this) << "Packet received before DTLS started.";
|
||||
} else {
|
||||
// Currently drop the packet, but we might in future
|
||||
// decide to take this as evidence that the other
|
||||
// side is ready to do DTLS and start the handshake
|
||||
// on our end.
|
||||
LOG_J(LS_WARNING, this) << "Received packet before we know if we are "
|
||||
<< "doing DTLS or not; dropping.";
|
||||
LOG_J(LS_WARNING, this) << "Packet received before we know if we are "
|
||||
<< "doing DTLS or not.";
|
||||
}
|
||||
// Cache a client hello packet received before DTLS has actually started.
|
||||
if (IsDtlsClientHelloPacket(data, size)) {
|
||||
LOG_J(LS_INFO, this) << "Caching DTLS ClientHello packet until DTLS is "
|
||||
<< "started.";
|
||||
cached_client_hello_.SetData(data, size);
|
||||
} else {
|
||||
LOG_J(LS_INFO, this) << "Not a DTLS ClientHello packet; dropping.";
|
||||
}
|
||||
break;
|
||||
|
||||
@ -577,6 +587,21 @@ bool DtlsTransportChannelWrapper::MaybeStartDtls() {
|
||||
LOG_J(LS_INFO, this)
|
||||
<< "DtlsTransportChannelWrapper: Started DTLS handshake";
|
||||
set_dtls_state(DTLS_TRANSPORT_CONNECTING);
|
||||
// Now that the handshake has started, we can process a cached ClientHello
|
||||
// (if one exists).
|
||||
if (cached_client_hello_.size()) {
|
||||
if (ssl_role_ == rtc::SSL_SERVER) {
|
||||
LOG_J(LS_INFO, this) << "Handling cached DTLS ClientHello packet.";
|
||||
if (!HandleDtlsPacket(cached_client_hello_.data<char>(),
|
||||
cached_client_hello_.size())) {
|
||||
LOG_J(LS_ERROR, this) << "Failed to handle DTLS packet.";
|
||||
}
|
||||
} else {
|
||||
LOG_J(LS_WARNING, this) << "Discarding cached DTLS ClientHello packet "
|
||||
<< "because we don't have the server role.";
|
||||
}
|
||||
cached_client_hello_.Clear();
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -237,6 +237,12 @@ class DtlsTransportChannelWrapper : public TransportChannelImpl {
|
||||
rtc::Buffer remote_fingerprint_value_;
|
||||
std::string remote_fingerprint_algorithm_;
|
||||
|
||||
// Cached DTLS ClientHello packet that was received before we started the
|
||||
// DTLS handshake. This could happen if the hello was received before the
|
||||
// transport channel became writable, or before a remote fingerprint was
|
||||
// received.
|
||||
rtc::Buffer cached_client_hello_;
|
||||
|
||||
RTC_DISALLOW_COPY_AND_ASSIGN(DtlsTransportChannelWrapper);
|
||||
};
|
||||
|
||||
|
||||
@ -33,25 +33,34 @@ static const char kIcePwd1[] = "TESTICEPWD00000000000001";
|
||||
static const size_t kPacketNumOffset = 8;
|
||||
static const size_t kPacketHeaderLen = 12;
|
||||
static const int kFakePacketId = 0x1234;
|
||||
static const int kTimeout = 10000;
|
||||
|
||||
static bool IsRtpLeadByte(uint8_t b) {
|
||||
return ((b & 0xC0) == 0x80);
|
||||
}
|
||||
|
||||
cricket::TransportDescription MakeTransportDescription(
|
||||
const rtc::scoped_refptr<rtc::RTCCertificate>& cert,
|
||||
cricket::ConnectionRole role) {
|
||||
rtc::scoped_ptr<rtc::SSLFingerprint> fingerprint;
|
||||
if (cert) {
|
||||
std::string digest_algorithm;
|
||||
cert->ssl_certificate().GetSignatureDigestAlgorithm(&digest_algorithm);
|
||||
fingerprint.reset(
|
||||
rtc::SSLFingerprint::Create(digest_algorithm, cert->identity()));
|
||||
}
|
||||
return cricket::TransportDescription(std::vector<std::string>(), kIceUfrag1,
|
||||
kIcePwd1, cricket::ICEMODE_FULL, role,
|
||||
fingerprint.get());
|
||||
}
|
||||
|
||||
using cricket::ConnectionRole;
|
||||
|
||||
enum Flags { NF_REOFFER = 0x1, NF_EXPECT_FAILURE = 0x2 };
|
||||
|
||||
class DtlsTestClient : public sigslot::has_slots<> {
|
||||
public:
|
||||
DtlsTestClient(const std::string& name)
|
||||
: name_(name),
|
||||
packet_size_(0),
|
||||
use_dtls_srtp_(false),
|
||||
ssl_max_version_(rtc::SSL_PROTOCOL_DTLS_12),
|
||||
negotiated_dtls_(false),
|
||||
received_dtls_client_hello_(false),
|
||||
received_dtls_server_hello_(false) {}
|
||||
DtlsTestClient(const std::string& name) : name_(name) {}
|
||||
void CreateCertificate(rtc::KeyType key_type) {
|
||||
certificate_ =
|
||||
rtc::RTCCertificate::Create(std::unique_ptr<rtc::SSLIdentity>(
|
||||
@ -185,9 +194,8 @@ class DtlsTestClient : public sigslot::has_slots<> {
|
||||
negotiated_dtls_ = (local_cert && remote_cert);
|
||||
}
|
||||
|
||||
bool Connect(DtlsTestClient* peer) {
|
||||
transport_->ConnectChannels();
|
||||
transport_->SetDestination(peer->transport_.get());
|
||||
bool Connect(DtlsTestClient* peer, bool asymmetric) {
|
||||
transport_->SetDestination(peer->transport_.get(), asymmetric);
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -203,13 +211,29 @@ class DtlsTestClient : public sigslot::has_slots<> {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool all_raw_channels_writable() const {
|
||||
if (channels_.empty()) {
|
||||
return false;
|
||||
}
|
||||
for (cricket::DtlsTransportChannelWrapper* channel : channels_) {
|
||||
if (!channel->channel()->writable()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
int received_dtls_client_hellos() const {
|
||||
return received_dtls_client_hellos_;
|
||||
}
|
||||
|
||||
void CheckRole(rtc::SSLRole role) {
|
||||
if (role == rtc::SSL_CLIENT) {
|
||||
ASSERT_FALSE(received_dtls_client_hello_);
|
||||
ASSERT_TRUE(received_dtls_server_hello_);
|
||||
ASSERT_EQ(0, received_dtls_client_hellos_);
|
||||
ASSERT_GT(received_dtls_server_hellos_, 0);
|
||||
} else {
|
||||
ASSERT_TRUE(received_dtls_client_hello_);
|
||||
ASSERT_FALSE(received_dtls_server_hello_);
|
||||
ASSERT_GT(received_dtls_client_hellos_, 0);
|
||||
ASSERT_EQ(0, received_dtls_server_hellos_);
|
||||
}
|
||||
}
|
||||
|
||||
@ -358,20 +382,18 @@ class DtlsTestClient : public sigslot::has_slots<> {
|
||||
|
||||
// Look at the handshake packets to see what role we played.
|
||||
// Check that non-handshake packets are DTLS data or SRTP bypass.
|
||||
if (negotiated_dtls_) {
|
||||
if (data[0] == 22 && size > 17) {
|
||||
if (data[13] == 1) {
|
||||
received_dtls_client_hello_ = true;
|
||||
} else if (data[13] == 2) {
|
||||
received_dtls_server_hello_ = true;
|
||||
}
|
||||
} else if (!(data[0] >= 20 && data[0] <= 22)) {
|
||||
ASSERT_TRUE(data[0] == 23 || IsRtpLeadByte(data[0]));
|
||||
if (data[0] == 23) {
|
||||
ASSERT_TRUE(VerifyEncryptedPacket(data, size));
|
||||
} else if (IsRtpLeadByte(data[0])) {
|
||||
ASSERT_TRUE(VerifyPacket(data, size, NULL));
|
||||
}
|
||||
if (data[0] == 22 && size > 17) {
|
||||
if (data[13] == 1) {
|
||||
++received_dtls_client_hellos_;
|
||||
} else if (data[13] == 2) {
|
||||
++received_dtls_server_hellos_;
|
||||
}
|
||||
} else if (negotiated_dtls_ && !(data[0] >= 20 && data[0] <= 22)) {
|
||||
ASSERT_TRUE(data[0] == 23 || IsRtpLeadByte(data[0]));
|
||||
if (data[0] == 23) {
|
||||
ASSERT_TRUE(VerifyEncryptedPacket(data, size));
|
||||
} else if (IsRtpLeadByte(data[0])) {
|
||||
ASSERT_TRUE(VerifyPacket(data, size, NULL));
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -381,13 +403,13 @@ class DtlsTestClient : public sigslot::has_slots<> {
|
||||
rtc::scoped_refptr<rtc::RTCCertificate> certificate_;
|
||||
std::unique_ptr<cricket::FakeTransport> transport_;
|
||||
std::vector<cricket::DtlsTransportChannelWrapper*> channels_;
|
||||
size_t packet_size_;
|
||||
size_t packet_size_ = 0u;
|
||||
std::set<int> received_;
|
||||
bool use_dtls_srtp_;
|
||||
rtc::SSLProtocolVersion ssl_max_version_;
|
||||
bool negotiated_dtls_;
|
||||
bool received_dtls_client_hello_;
|
||||
bool received_dtls_server_hello_;
|
||||
bool use_dtls_srtp_ = false;
|
||||
rtc::SSLProtocolVersion ssl_max_version_ = rtc::SSL_PROTOCOL_DTLS_12;
|
||||
bool negotiated_dtls_ = false;
|
||||
int received_dtls_client_hellos_ = 0;
|
||||
int received_dtls_server_hellos_ = 0;
|
||||
rtc::SentPacket sent_packet_;
|
||||
};
|
||||
|
||||
@ -437,14 +459,14 @@ class DtlsTransportChannelTest : public testing::Test {
|
||||
bool Connect(ConnectionRole client1_role, ConnectionRole client2_role) {
|
||||
Negotiate(client1_role, client2_role);
|
||||
|
||||
bool rv = client1_.Connect(&client2_);
|
||||
bool rv = client1_.Connect(&client2_, false);
|
||||
EXPECT_TRUE(rv);
|
||||
if (!rv)
|
||||
return false;
|
||||
|
||||
EXPECT_TRUE_WAIT(
|
||||
client1_.all_channels_writable() && client2_.all_channels_writable(),
|
||||
10000);
|
||||
kTimeout);
|
||||
if (!client1_.all_channels_writable() || !client2_.all_channels_writable())
|
||||
return false;
|
||||
|
||||
@ -535,7 +557,7 @@ class DtlsTransportChannelTest : public testing::Test {
|
||||
LOG(LS_INFO) << "Expect packets, size=" << size;
|
||||
client2_.ExpectPackets(channel, size);
|
||||
client1_.SendPackets(channel, size, count, srtp);
|
||||
EXPECT_EQ_WAIT(count, client2_.NumPacketsReceived(), 10000);
|
||||
EXPECT_EQ_WAIT(count, client2_.NumPacketsReceived(), kTimeout);
|
||||
}
|
||||
|
||||
protected:
|
||||
@ -828,11 +850,11 @@ TEST_F(DtlsTransportChannelTest, TestRenegotiateBeforeConnect) {
|
||||
|
||||
Renegotiate(&client1_, cricket::CONNECTIONROLE_ACTPASS,
|
||||
cricket::CONNECTIONROLE_ACTIVE, NF_REOFFER);
|
||||
bool rv = client1_.Connect(&client2_);
|
||||
bool rv = client1_.Connect(&client2_, false);
|
||||
EXPECT_TRUE(rv);
|
||||
EXPECT_TRUE_WAIT(
|
||||
client1_.all_channels_writable() && client2_.all_channels_writable(),
|
||||
10000);
|
||||
kTimeout);
|
||||
|
||||
TestTransfer(0, 1000, 100, true);
|
||||
TestTransfer(1, 1000, 100, true);
|
||||
@ -886,3 +908,69 @@ TEST_F(DtlsTransportChannelTest, TestCertificatesAfterConnect) {
|
||||
ASSERT_EQ(remote_cert2->ToPEMString(),
|
||||
certificate1->ssl_certificate().ToPEMString());
|
||||
}
|
||||
|
||||
// Test that DTLS completes promptly if a ClientHello is received before the
|
||||
// transport channel is writable (allowing a ServerHello to be sent).
|
||||
TEST_F(DtlsTransportChannelTest, TestReceiveClientHelloBeforeWritable) {
|
||||
MAYBE_SKIP_TEST(HaveDtls);
|
||||
PrepareDtls(true, true, rtc::KT_DEFAULT);
|
||||
// Exchange transport descriptions.
|
||||
Negotiate(cricket::CONNECTIONROLE_ACTPASS, cricket::CONNECTIONROLE_ACTIVE);
|
||||
|
||||
// Make client2_ writable, but not client1_.
|
||||
EXPECT_TRUE(client2_.Connect(&client1_, true));
|
||||
EXPECT_TRUE_WAIT(client2_.all_raw_channels_writable(), kTimeout);
|
||||
|
||||
// Expect a DTLS ClientHello to be sent even while client1_ isn't writable.
|
||||
EXPECT_EQ_WAIT(1, client1_.received_dtls_client_hellos(), kTimeout);
|
||||
EXPECT_FALSE(client1_.all_raw_channels_writable());
|
||||
|
||||
// Now make client1_ writable and expect the handshake to complete
|
||||
// without client2_ needing to retransmit the ClientHello.
|
||||
EXPECT_TRUE(client1_.Connect(&client2_, true));
|
||||
EXPECT_TRUE_WAIT(
|
||||
client1_.all_channels_writable() && client2_.all_channels_writable(),
|
||||
kTimeout);
|
||||
EXPECT_EQ(1, client1_.received_dtls_client_hellos());
|
||||
}
|
||||
|
||||
// Test that DTLS completes promptly if a ClientHello is received before the
|
||||
// transport channel has a remote fingerprint (allowing a ServerHello to be
|
||||
// sent).
|
||||
TEST_F(DtlsTransportChannelTest,
|
||||
TestReceiveClientHelloBeforeRemoteFingerprint) {
|
||||
MAYBE_SKIP_TEST(HaveDtls);
|
||||
PrepareDtls(true, true, rtc::KT_DEFAULT);
|
||||
client1_.SetupChannels(channel_ct_, cricket::ICEROLE_CONTROLLING);
|
||||
client2_.SetupChannels(channel_ct_, cricket::ICEROLE_CONTROLLED);
|
||||
|
||||
// Make client2_ writable and give it local/remote certs, but don't yet give
|
||||
// client1_ a remote fingerprint.
|
||||
client1_.transport()->SetLocalTransportDescription(
|
||||
MakeTransportDescription(client1_.certificate(),
|
||||
cricket::CONNECTIONROLE_ACTPASS),
|
||||
cricket::CA_OFFER, nullptr);
|
||||
client2_.Negotiate(&client1_, cricket::CA_ANSWER,
|
||||
cricket::CONNECTIONROLE_ACTIVE,
|
||||
cricket::CONNECTIONROLE_ACTPASS, 0);
|
||||
EXPECT_TRUE(client2_.Connect(&client1_, true));
|
||||
EXPECT_TRUE_WAIT(client2_.all_raw_channels_writable(), kTimeout);
|
||||
|
||||
// Expect a DTLS ClientHello to be sent even while client1_ doesn't have a
|
||||
// remote fingerprint.
|
||||
EXPECT_EQ_WAIT(1, client1_.received_dtls_client_hellos(), kTimeout);
|
||||
EXPECT_FALSE(client1_.all_raw_channels_writable());
|
||||
|
||||
// Now make give client1_ its remote fingerprint and make it writable, and
|
||||
// expect the handshake to complete without client2_ needing to retransmit
|
||||
// the ClientHello.
|
||||
client1_.transport()->SetRemoteTransportDescription(
|
||||
MakeTransportDescription(client2_.certificate(),
|
||||
cricket::CONNECTIONROLE_ACTIVE),
|
||||
cricket::CA_ANSWER, nullptr);
|
||||
EXPECT_TRUE(client1_.Connect(&client2_, true));
|
||||
EXPECT_TRUE_WAIT(
|
||||
client1_.all_channels_writable() && client2_.all_channels_writable(),
|
||||
kTimeout);
|
||||
EXPECT_EQ(1, client1_.received_dtls_client_hellos());
|
||||
}
|
||||
|
||||
@ -141,20 +141,22 @@ class FakeTransportChannel : public TransportChannelImpl,
|
||||
|
||||
void SetWritable(bool writable) { set_writable(writable); }
|
||||
|
||||
void SetDestination(FakeTransportChannel* dest) {
|
||||
// Simulates the two transport channels connecting to each other.
|
||||
// If |asymmetric| is true this method only affects this FakeTransportChannel.
|
||||
// If false, it affects |dest| as well.
|
||||
void SetDestination(FakeTransportChannel* dest, bool asymmetric = false) {
|
||||
if (state_ == STATE_CONNECTING && dest) {
|
||||
// This simulates the delivery of candidates.
|
||||
dest_ = dest;
|
||||
dest_->dest_ = this;
|
||||
if (local_cert_ && dest_->local_cert_) {
|
||||
do_dtls_ = true;
|
||||
dest_->do_dtls_ = true;
|
||||
NegotiateSrtpCiphers();
|
||||
}
|
||||
state_ = STATE_CONNECTED;
|
||||
dest_->state_ = STATE_CONNECTED;
|
||||
set_writable(true);
|
||||
dest_->set_writable(true);
|
||||
if (!asymmetric) {
|
||||
dest->SetDestination(this, true);
|
||||
}
|
||||
} else if (state_ == STATE_CONNECTED && !dest) {
|
||||
// Simulates loss of connectivity, by asymmetrically forgetting dest_.
|
||||
dest_ = nullptr;
|
||||
@ -282,20 +284,6 @@ class FakeTransportChannel : public TransportChannelImpl,
|
||||
return false;
|
||||
}
|
||||
|
||||
void NegotiateSrtpCiphers() {
|
||||
for (std::vector<int>::const_iterator it1 = srtp_ciphers_.begin();
|
||||
it1 != srtp_ciphers_.end(); ++it1) {
|
||||
for (std::vector<int>::const_iterator it2 = dest_->srtp_ciphers_.begin();
|
||||
it2 != dest_->srtp_ciphers_.end(); ++it2) {
|
||||
if (*it1 == *it2) {
|
||||
chosen_crypto_suite_ = *it1;
|
||||
dest_->chosen_crypto_suite_ = *it2;
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool GetStats(ConnectionInfos* infos) override {
|
||||
ConnectionInfo info;
|
||||
infos->clear();
|
||||
@ -311,6 +299,19 @@ class FakeTransportChannel : public TransportChannelImpl,
|
||||
}
|
||||
|
||||
private:
|
||||
void NegotiateSrtpCiphers() {
|
||||
for (std::vector<int>::const_iterator it1 = srtp_ciphers_.begin();
|
||||
it1 != srtp_ciphers_.end(); ++it1) {
|
||||
for (std::vector<int>::const_iterator it2 = dest_->srtp_ciphers_.begin();
|
||||
it2 != dest_->srtp_ciphers_.end(); ++it2) {
|
||||
if (*it1 == *it2) {
|
||||
chosen_crypto_suite_ = *it1;
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
enum State { STATE_INIT, STATE_CONNECTING, STATE_CONNECTED };
|
||||
FakeTransportChannel* dest_ = nullptr;
|
||||
State state_ = STATE_INIT;
|
||||
@ -359,11 +360,14 @@ class FakeTransport : public Transport {
|
||||
// If async, will send packets by "Post"-ing to message queue instead of
|
||||
// synchronously "Send"-ing.
|
||||
void SetAsync(bool async) { async_ = async; }
|
||||
void SetDestination(FakeTransport* dest) {
|
||||
|
||||
// If |asymmetric| is true, only set the destination for this transport, and
|
||||
// not |dest|.
|
||||
void SetDestination(FakeTransport* dest, bool asymmetric = false) {
|
||||
dest_ = dest;
|
||||
for (const auto& kv : channels_) {
|
||||
kv.second->SetLocalCertificate(certificate_);
|
||||
SetChannelDestination(kv.first, kv.second);
|
||||
SetChannelDestination(kv.first, kv.second, asymmetric);
|
||||
}
|
||||
}
|
||||
|
||||
@ -417,7 +421,7 @@ class FakeTransport : public Transport {
|
||||
FakeTransportChannel* channel = new FakeTransportChannel(name(), component);
|
||||
channel->set_ssl_max_protocol_version(ssl_max_version_);
|
||||
channel->SetAsync(async_);
|
||||
SetChannelDestination(component, channel);
|
||||
SetChannelDestination(component, channel, false);
|
||||
channels_[component] = channel;
|
||||
return channel;
|
||||
}
|
||||
@ -433,15 +437,17 @@ class FakeTransport : public Transport {
|
||||
return (it != channels_.end()) ? it->second : nullptr;
|
||||
}
|
||||
|
||||
void SetChannelDestination(int component, FakeTransportChannel* channel) {
|
||||
void SetChannelDestination(int component,
|
||||
FakeTransportChannel* channel,
|
||||
bool asymmetric) {
|
||||
FakeTransportChannel* dest_channel = nullptr;
|
||||
if (dest_) {
|
||||
dest_channel = dest_->GetFakeChannel(component);
|
||||
if (dest_channel) {
|
||||
if (dest_channel && !asymmetric) {
|
||||
dest_channel->SetLocalCertificate(dest_->certificate_);
|
||||
}
|
||||
}
|
||||
channel->SetDestination(dest_channel);
|
||||
channel->SetDestination(dest_channel, asymmetric);
|
||||
}
|
||||
|
||||
// Note, this is distinct from the Channel map owned by Transport.
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user