diff --git a/p2p/dtls/dtls_transport_unittest.cc b/p2p/dtls/dtls_transport_unittest.cc index 14e774d4ff..cc94f1881b 100644 --- a/p2p/dtls/dtls_transport_unittest.cc +++ b/p2p/dtls/dtls_transport_unittest.cc @@ -10,6 +10,7 @@ #include "p2p/dtls/dtls_transport.h" +#include #include #include #include @@ -507,34 +508,6 @@ TEST_F(DtlsTransportTest, KeyingMaterialExporter) { EXPECT_EQ(client1_out, client2_out); } -class DtlsTransportVersionTest - : public DtlsTransportTestBase, - public ::testing::TestWithParam< - ::testing::tuple> { -}; - -// Will test every combination of 1.0/1.2/1.3 on the client and server. -// DTLS will negotiate an effective version (the min of client & sewrver). -INSTANTIATE_TEST_SUITE_P( - DtlsTransportVersionTest, - DtlsTransportVersionTest, - ::testing::Combine(::testing::Values(rtc::SSL_PROTOCOL_DTLS_10, - rtc::SSL_PROTOCOL_DTLS_12, - rtc::SSL_PROTOCOL_DTLS_13), - ::testing::Values(rtc::SSL_PROTOCOL_DTLS_10, - rtc::SSL_PROTOCOL_DTLS_12, - rtc::SSL_PROTOCOL_DTLS_13))); - -// Test that an acceptable cipher suite is negotiated when different versions -// of DTLS are supported. Note that it's IsAcceptableCipher that does the actual -// work. -TEST_P(DtlsTransportVersionTest, TestCipherSuiteNegotiation) { - PrepareDtls(rtc::KT_DEFAULT); - SetMaxProtocolVersions(::testing::get<0>(GetParam()), - ::testing::get<1>(GetParam())); - ASSERT_TRUE(Connect()); -} - enum HandshakeTestEvent { EV_CLIENT_SEND = 0, EV_SERVER_SEND = 1, @@ -542,6 +515,9 @@ enum HandshakeTestEvent { EV_SERVER_RECV = 3, EV_CLIENT_WRITABLE = 4, EV_SERVER_WRITABLE = 5, + + EV_CLIENT_SEND_DROPPED = 6, + EV_SERVER_SEND_DROPPED = 7, }; static const std::vector dtls_12_handshake_events{ @@ -583,106 +559,243 @@ static const struct { {rtc::kDtls13VersionBytes, dtls_13_handshake_events}, }; -bool LogRecv(absl::string_view name, - const rtc::CopyOnWriteBuffer& packet, - uint64_t timestamp_ms) { - RTC_LOG(LS_INFO) << "time=" << timestamp_ms << " : " << name - << ": ReceivePacket packet len=" << packet.size() - << ", data[0]: " << static_cast(packet.data()[0]); - return false; -} - -bool LogSend(absl::string_view name, - uint64_t timestamp_ms, - bool drop, - const char* data, - size_t len) { - if (drop) { - RTC_LOG(LS_INFO) << "time=" << timestamp_ms << " : " << name - << ": dropping packet len=" << len - << ", data[0]: " << static_cast(data[0]); - } else { - RTC_LOG(LS_INFO) << "time=" << timestamp_ms << " : " << name - << ": SendPacket, len=" << len - << ", data[0]: " << static_cast(data[0]); +class DtlsTransportVersionTest + : public DtlsTransportTestBase, + public ::testing::TestWithParam< + ::testing::tuple> { + public: + void Prepare() { + PrepareDtls(rtc::KT_DEFAULT); + SetMaxProtocolVersions(::testing::get<0>(GetParam()), + ::testing::get<1>(GetParam())); } - return drop; -} -TEST_P(DtlsTransportVersionTest, TestHandshakeFlights) { - // We can only change the retransmission schedule with a recently-added - // BoringSSL API. Skip the test if not built with BoringSSL. - MAYBE_SKIP_TEST(IsBoringSsl); + // Run DTLS handshake. + // - store events in `events` + // - drop packets as specified in `packets_to_drop` + std::pair> + RunHandshake(std::set packets_to_drop) { + Negotiate(/* client1_server= */ false); - // Disable any forcing of Dtls1.3. - webrtc::test::ScopedFieldTrials trials("WebRTC-ForceDtls13/Off/"); - PrepareDtls(rtc::KT_DEFAULT); - SetMaxProtocolVersions(::testing::get<0>(GetParam()), - ::testing::get<1>(GetParam())); + std::vector events; + auto start_time_ns = fake_clock_.TimeNanos(); + client1_.fake_ice_transport()->set_rtt_estimate(50, true); + client2_.fake_ice_transport()->set_rtt_estimate(50, true); - Negotiate(/* client1_server= */ false); + client1_.fake_ice_transport()->set_packet_recv_filter( + [&](auto packet, auto timestamp_us) { + events.push_back(EV_CLIENT_RECV); + return LogRecv("client", packet, + (timestamp_us - start_time_ns / 1000) / 1000); + }); + client2_.fake_ice_transport()->set_packet_recv_filter( + [&](auto packet, auto timestamp_us) { + events.push_back(EV_SERVER_RECV); + return LogRecv("server", packet, + (timestamp_us - start_time_ns / 1000) / 1000); + }); + client1_.set_writable_callback( + [&]() { events.push_back(EV_CLIENT_WRITABLE); }); + client2_.set_writable_callback( + [&]() { events.push_back(EV_SERVER_WRITABLE); }); - std::vector events; + unsigned packet_num = 0; + client1_.fake_ice_transport()->set_packet_send_filter( + [&](auto data, auto len, auto options, auto flags) { + bool drop = packets_to_drop.find(packet_num) != packets_to_drop.end(); + packet_num++; + if (!drop) { + events.push_back(EV_CLIENT_SEND); + } else { + events.push_back(EV_CLIENT_SEND_DROPPED); + } + auto diff_ms = (fake_clock_.TimeNanos() - start_time_ns) / 1000000; + return LogSend("client", diff_ms, drop, data, len); + }); + client2_.fake_ice_transport()->set_packet_send_filter( + [&](auto data, auto len, auto options, auto flags) { + bool drop = packets_to_drop.find(packet_num) != packets_to_drop.end(); + packet_num++; + if (!drop) { + events.push_back(EV_SERVER_SEND); + } else { + events.push_back(EV_SERVER_SEND_DROPPED); + } + auto diff_ms = (fake_clock_.TimeNanos() - start_time_ns) / 1000000; + return LogSend("server", diff_ms, drop, data, len); + }); - auto start_time_ns = fake_clock_.TimeNanos(); - client1_.fake_ice_transport()->set_rtt_estimate(50, true); - client2_.fake_ice_transport()->set_rtt_estimate(50, true); + EXPECT_TRUE(client1_.Connect(&client2_, false)); - client1_.fake_ice_transport()->set_packet_recv_filter( - [&](auto packet, auto timestamp_us) { - events.push_back(EV_CLIENT_RECV); - return LogRecv("client", packet, - (timestamp_us - start_time_ns / 1000) / 1000); - }); - client2_.fake_ice_transport()->set_packet_recv_filter( - [&](auto packet, auto timestamp_us) { - events.push_back(EV_SERVER_RECV); - return LogRecv("server", packet, - (timestamp_us - start_time_ns / 1000) / 1000); - }); - client1_.set_writable_callback( - [&]() { events.push_back(EV_CLIENT_WRITABLE); }); - client2_.set_writable_callback( - [&]() { events.push_back(EV_SERVER_WRITABLE); }); + EXPECT_TRUE_SIMULATED_WAIT(client1_.dtls_transport()->writable() && + client2_.dtls_transport()->writable(), + kTimeout, fake_clock_); - client1_.fake_ice_transport()->set_packet_send_filter( - [&](auto data, auto len, auto options, auto flags) { - events.push_back(EV_CLIENT_SEND); - bool drop = false; - auto diff_ms = (fake_clock_.TimeNanos() - start_time_ns) / 1000000; - return LogSend("client", diff_ms, drop, data, len); - }); - client2_.fake_ice_transport()->set_packet_send_filter( - [&](auto data, auto len, auto options, auto flags) { - events.push_back(EV_SERVER_SEND); - bool drop = false; - auto diff_ms = (fake_clock_.TimeNanos() - start_time_ns) / 1000000; - return LogSend("server", diff_ms, drop, data, len); - }); + client1_.fake_ice_transport()->set_packet_send_filter(nullptr); + client2_.fake_ice_transport()->set_packet_send_filter(nullptr); + client1_.fake_ice_transport()->set_packet_recv_filter(nullptr); + client2_.fake_ice_transport()->set_packet_recv_filter(nullptr); - EXPECT_TRUE(client1_.Connect(&client2_, false)); + auto dtls_version_bytes = client1_.GetVersionBytes(); + EXPECT_EQ(dtls_version_bytes, client2_.GetVersionBytes()); + return std::make_pair(*dtls_version_bytes, std::move(events)); + } - EXPECT_TRUE_SIMULATED_WAIT(client1_.dtls_transport()->writable() && - client2_.dtls_transport()->writable(), - kTimeout, fake_clock_); - - client1_.fake_ice_transport()->set_packet_send_filter(nullptr); - client2_.fake_ice_transport()->set_packet_send_filter(nullptr); - client1_.fake_ice_transport()->set_packet_recv_filter(nullptr); - client2_.fake_ice_transport()->set_packet_recv_filter(nullptr); - - auto dtls_version_bytes = client1_.GetVersionBytes(); - ASSERT_EQ(dtls_version_bytes, client2_.GetVersionBytes()); - - std::vector expect; - for (const auto e : kEventsPerVersion) { - if (e.version_bytes == dtls_version_bytes) { - expect = e.events; - break; + int GetExpectedDtlsVersionBytes() { + int version = std::min(static_cast(::testing::get<0>(GetParam())), + static_cast(::testing::get<1>(GetParam()))); + if (version == rtc::SSL_PROTOCOL_DTLS_13) { + return rtc::kDtls13VersionBytes; + } else { + return rtc::kDtls12VersionBytes; } } + + std::vector GetExpectedEvents(int dtls_version_bytes) { + for (const auto e : kEventsPerVersion) { + if (e.version_bytes == dtls_version_bytes) { + return e.events; + } + } + return {}; + } + + private: + bool LogRecv(absl::string_view name, + const rtc::CopyOnWriteBuffer& packet, + uint64_t timestamp_ms) { + RTC_LOG(LS_INFO) << "time=" << timestamp_ms << " : " << name + << ": ReceivePacket packet len=" << packet.size() + << ", data[0]: " << static_cast(packet.data()[0]); + return false; + } + + bool LogSend(absl::string_view name, + uint64_t timestamp_ms, + bool drop, + const char* data, + size_t len) { + if (drop) { + RTC_LOG(LS_INFO) << "time=" << timestamp_ms << " : " << name + << ": dropping packet len=" << len + << ", data[0]: " << static_cast(data[0]); + } else { + RTC_LOG(LS_INFO) << "time=" << timestamp_ms << " : " << name + << ": SendPacket, len=" << len + << ", data[0]: " << static_cast(data[0]); + } + return drop; + } +}; + +// Will test every combination of 1.0/1.2/1.3 on the client and server. +// DTLS will negotiate an effective version (the min of client & sewrver). +INSTANTIATE_TEST_SUITE_P( + DtlsTransportVersionTest, + DtlsTransportVersionTest, + ::testing::Combine(::testing::Values(rtc::SSL_PROTOCOL_DTLS_10, + rtc::SSL_PROTOCOL_DTLS_12, + rtc::SSL_PROTOCOL_DTLS_13), + ::testing::Values(rtc::SSL_PROTOCOL_DTLS_10, + rtc::SSL_PROTOCOL_DTLS_12, + rtc::SSL_PROTOCOL_DTLS_13))); + +// Test that an acceptable cipher suite is negotiated when different versions +// of DTLS are supported. Note that it's IsAcceptableCipher that does the actual +// work. +TEST_P(DtlsTransportVersionTest, CipherSuiteNegotiation) { + Prepare(); + ASSERT_TRUE(Connect()); +} + +TEST_P(DtlsTransportVersionTest, HandshakeFlights) { + Prepare(); + auto [dtls_version_bytes, events] = RunHandshake({}); + RTC_LOG(LS_INFO) << "Verifying events with ssl version bytes= " - << *dtls_version_bytes; + << dtls_version_bytes; + auto expect = GetExpectedEvents(dtls_version_bytes); + EXPECT_EQ(events, expect); +} + +TEST_P(DtlsTransportVersionTest, HandshakeLoseFirstClientPacket) { + MAYBE_SKIP_TEST(IsBoringSsl); + + Prepare(); + auto [dtls_version_bytes, events] = RunHandshake({/* packet_num= */ 0}); + + auto expect = GetExpectedEvents(dtls_version_bytes); + + // If first packet is lost...it is simply retransmitted by client, + // nothing else changes. + expect.insert(expect.begin(), EV_CLIENT_SEND_DROPPED); + + EXPECT_EQ(events, expect); +} + +TEST_P(DtlsTransportVersionTest, HandshakeLoseSecondClientPacket) { + MAYBE_SKIP_TEST(IsBoringSsl); + + Prepare(); + auto [dtls_version_bytes, events] = RunHandshake({/* packet_num= */ 2}); + + std::vector expect; + + switch (dtls_version_bytes) { + case rtc::kDtls12VersionBytes: + expect = { + // Flight 1 + EV_CLIENT_SEND, + EV_SERVER_RECV, + EV_SERVER_SEND, + EV_CLIENT_RECV, + + // Flight 2 + EV_CLIENT_SEND_DROPPED, + + // Server retransmit. + EV_SERVER_SEND, + // Client retransmit. + EV_CLIENT_SEND, + // Client receive retransmit => Do nothing, has already retransmitted. + EV_CLIENT_RECV, + // Handshake resume. + EV_SERVER_RECV, + EV_SERVER_SEND, + EV_SERVER_WRITABLE, + EV_CLIENT_RECV, + EV_CLIENT_WRITABLE, + }; + break; + case rtc::kDtls13VersionBytes: + expect = { + // Flight 1 + EV_CLIENT_SEND, + EV_SERVER_RECV, + EV_SERVER_SEND, + EV_CLIENT_RECV, + + // Flight 2 + EV_CLIENT_SEND_DROPPED, + // Client doesn't know packet it is dropped, so it becomes writable. + EV_CLIENT_WRITABLE, + + // Server retransmit. + EV_SERVER_SEND, + // Client retransmit. + EV_CLIENT_SEND, + + // Client receive retransmit => Do nothing, has already retransmitted. + EV_CLIENT_RECV, + // Handshake resume. + EV_SERVER_RECV, + EV_SERVER_SEND, + EV_SERVER_WRITABLE, + }; + break; + default: + RTC_CHECK(false) << "Unknown dtls version bytes: " << dtls_version_bytes; + } EXPECT_EQ(events, expect); } diff --git a/rtc_base/openssl_stream_adapter.cc b/rtc_base/openssl_stream_adapter.cc index d99fd2179a..cd08c54568 100644 --- a/rtc_base/openssl_stream_adapter.cc +++ b/rtc_base/openssl_stream_adapter.cc @@ -797,7 +797,11 @@ void OpenSSLStreamAdapter::SetTimeout(int delay_ms) { Error("DTLSv1_handle_timeout", res, -1, true); return webrtc::TimeDelta::PlusInfinity(); } - ContinueSSL(); + // We check the timer even after SSL_CONNECTED, + // but ContinueSSL() is only needed when SSL_CONNECTING + if (state_ == SSL_CONNECTING) { + ContinueSSL(); + } } else { RTC_DCHECK_NOTREACHED(); } @@ -860,7 +864,7 @@ int OpenSSLStreamAdapter::ContinueSSL() { switch (ssl_error) { case SSL_ERROR_NONE: - RTC_DLOG(LS_VERBOSE) << " -- success"; + RTC_DLOG(LS_INFO) << " -- success"; // By this point, OpenSSL should have given us a certificate, or errored // out if one was missing. RTC_DCHECK(peer_cert_chain_ || !GetClientAuthEnabled()); @@ -878,16 +882,11 @@ int OpenSSLStreamAdapter::ContinueSSL() { FireEvent(SE_OPEN | SE_READ | SE_WRITE, 0); } break; - case SSL_ERROR_WANT_READ: { - RTC_DLOG(LS_VERBOSE) << " -- error want read"; - struct timeval timeout; - if (DTLSv1_get_timeout(ssl_, &timeout)) { - int delay = timeout.tv_sec * 1000 + timeout.tv_usec / 1000; - SetTimeout(delay); - } - } break; + case SSL_ERROR_WANT_READ: + RTC_DLOG(LS_INFO) << " -- error when we want to read"; + break; case SSL_ERROR_WANT_WRITE: - RTC_DLOG(LS_VERBOSE) << " -- error want write"; + RTC_DLOG(LS_INFO) << " -- error when we want to write"; break; case SSL_ERROR_ZERO_RETURN: default: { @@ -905,6 +904,12 @@ int OpenSSLStreamAdapter::ContinueSSL() { } } + struct timeval timeout; + if (DTLSv1_get_timeout(ssl_, &timeout)) { + int delay = timeout.tv_sec * 1000 + timeout.tv_usec / 1000; + SetTimeout(delay); + } + return 0; }