DTLS 1.3 - patch 2

- add DTLS1.3 ciphers (without KeyType)
- remove code in dtls_transport.cc that tries to parse DTLS packet
- cleanup some test
- start on test for packet loss during dtls handshake (more to come!)

After this patch is submitted, it is possible
to set max version = dtls1.3 and it will active
but DON'T do it yet.

BUG=webrtc:383141571

Change-Id: I6f9a120c53415ccee7a560ea83bd0c2636702997
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/371300
Commit-Queue: Jonas Oreland <jonaso@webrtc.org>
Reviewed-by: Harald Alvestrand <hta@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#43595}
This commit is contained in:
Jonas Oreland 2024-12-18 09:18:21 +01:00 committed by WebRTC LUCI CQ
parent 486e3deba0
commit ac40185001
12 changed files with 419 additions and 73 deletions

View File

@ -1226,6 +1226,7 @@ if (rtc_include_tests) {
"../rtc_base/network:sent_packet",
"../rtc_base/third_party/sigslot",
"../system_wrappers:metrics",
"../test:field_trial",
"../test:rtc_expect_death",
"../test:scoped_key_value_config",
"../test:test_support",

View File

@ -63,6 +63,7 @@ class FakeIceTransport : public IceTransportInternal {
: rtc::Thread::Current()) {
RTC_DCHECK(network_thread_);
}
// Must be called either on the network thread, or after the network thread
// has been shut down.
~FakeIceTransport() override {
@ -289,7 +290,7 @@ class FakeIceTransport : public IceTransportInternal {
return true;
}
std::optional<int> GetRttEstimate() override { return std::nullopt; }
std::optional<int> GetRttEstimate() override { return rtt_estimate_; }
const Connection* selected_connection() const override { return nullptr; }
std::optional<const CandidatePair> GetSelectedCandidatePair() const override {
@ -314,25 +315,31 @@ class FakeIceTransport : public IceTransportInternal {
int SendPacket(const char* data,
size_t len,
const rtc::PacketOptions& options,
int /* flags */) override {
int flags) override {
RTC_DCHECK_RUN_ON(network_thread_);
if (!dest_) {
return -1;
}
send_packet_.AppendData(data, len);
if (!combine_outgoing_packets_ || send_packet_.size() > len) {
rtc::CopyOnWriteBuffer packet(std::move(send_packet_));
if (async_) {
network_thread_->PostDelayedTask(
SafeTask(task_safety_.flag(),
[this, packet] {
RTC_DCHECK_RUN_ON(network_thread_);
FakeIceTransport::SendPacketInternal(packet);
}),
TimeDelta::Millis(async_delay_ms_));
} else {
SendPacketInternal(packet);
if (packet_send_filter_func_ &&
packet_send_filter_func_(data, len, options, flags)) {
RTC_DLOG(LS_INFO) << name_ << ": dropping packet len=" << len
<< ", data[0]: " << static_cast<uint8_t>(data[0]);
} else {
send_packet_.AppendData(data, len);
if (!combine_outgoing_packets_ || send_packet_.size() > len) {
rtc::CopyOnWriteBuffer packet(std::move(send_packet_));
if (async_) {
network_thread_->PostDelayedTask(
SafeTask(task_safety_.flag(),
[this, packet] {
RTC_DCHECK_RUN_ON(network_thread_);
FakeIceTransport::SendPacketInternal(packet);
}),
TimeDelta::Millis(async_delay_ms_));
} else {
SendPacketInternal(packet);
}
}
}
rtc::SentPacket sent_packet(options.packet_id, rtc::TimeMillis());
@ -376,6 +383,38 @@ class FakeIceTransport : public IceTransportInternal {
});
}
// If `func` return TRUE means that packet will be dropped.
void set_packet_send_filter(
absl::AnyInvocable<bool(const char* data,
size_t len,
const rtc::PacketOptions& options,
int /* flags */)> func) {
RTC_DCHECK_RUN_ON(network_thread_);
RTC_DLOG(LS_INFO) << this << ": "
<< ((func == nullptr) ? "Clearing" : "Setting")
<< " packet send filter func";
packet_send_filter_func_ = std::move(func);
}
// If `func` return TRUE means that packet will be dropped.
void set_packet_recv_filter(
absl::AnyInvocable<bool(const rtc::CopyOnWriteBuffer& packet,
uint32_t time_ms)> func) {
RTC_DCHECK_RUN_ON(network_thread_);
RTC_DLOG(LS_INFO) << this << ": "
<< ((func == nullptr) ? "Clearing" : "Setting")
<< " packet recv filter func";
packet_recv_filter_func_ = std::move(func);
}
void set_rtt_estimate(std::optional<int> value, bool set_async = false) {
rtt_estimate_ = value;
if (value && set_async) {
SetAsync(true);
SetAsyncDelay(*value / 2);
}
}
private:
void set_writable(bool writable)
RTC_EXCLUSIVE_LOCKS_REQUIRED(network_thread_) {
@ -403,8 +442,21 @@ class FakeIceTransport : public IceTransportInternal {
RTC_EXCLUSIVE_LOCKS_REQUIRED(network_thread_) {
if (dest_) {
last_sent_packet_ = packet;
dest_->NotifyPacketReceived(rtc::ReceivedPacket::CreateFromLegacy(
packet.data(), packet.size(), rtc::TimeMicros()));
dest_->ReceivePacketInternal(packet);
}
}
void ReceivePacketInternal(const rtc::CopyOnWriteBuffer& packet) {
RTC_DCHECK_RUN_ON(network_thread_);
auto now = rtc::TimeMicros();
if (packet_recv_filter_func_ && packet_recv_filter_func_(packet, now)) {
RTC_DLOG(LS_INFO) << name_
<< ": dropping packet at receiver len=" << packet.size()
<< ", data[0]: "
<< static_cast<uint8_t>(packet.data()[0]);
} else {
NotifyPacketReceived(rtc::ReceivedPacket::CreateFromLegacy(
packet.data(), packet.size(), now));
}
}
@ -438,6 +490,13 @@ class FakeIceTransport : public IceTransportInternal {
rtc::CopyOnWriteBuffer last_sent_packet_ RTC_GUARDED_BY(network_thread_);
rtc::Thread* const network_thread_;
webrtc::ScopedTaskSafetyDetached task_safety_;
std::optional<int> rtt_estimate_;
// If filter func return TRUE means that packet will be dropped.
absl::AnyInvocable<bool(const char*, size_t, const rtc::PacketOptions&, int)>
packet_send_filter_func_ RTC_GUARDED_BY(network_thread_) = nullptr;
absl::AnyInvocable<bool(const rtc::CopyOnWriteBuffer&, uint64_t)>
packet_recv_filter_func_ RTC_GUARDED_BY(network_thread_) = nullptr;
};
class FakeIceTransportWrapper : public webrtc::IceTransportInterface {

View File

@ -801,24 +801,8 @@ void DtlsTransport::MaybeStartDtls() {
// Called from OnReadPacket when a DTLS packet is received.
bool DtlsTransport::HandleDtlsPacket(rtc::ArrayView<const uint8_t> payload) {
// Sanity check we're not passing junk that
// just looks like DTLS.
const uint8_t* tmp_data = payload.data();
size_t tmp_size = payload.size();
while (tmp_size > 0) {
if (tmp_size < kDtlsRecordHeaderLen)
return false; // Too short for the header
size_t record_len = (tmp_data[11] << 8) | (tmp_data[12]);
if ((record_len + kDtlsRecordHeaderLen) > tmp_size)
return false; // Body too short
tmp_data += record_len + kDtlsRecordHeaderLen;
tmp_size -= record_len + kDtlsRecordHeaderLen;
}
// Looks good. Pass to the SIC which ends up being passed to
// the DTLS stack.
// Pass to the StreamInterfaceChannel which ends up being passed to the DTLS
// stack.
return downward_->OnPacketReceived(
reinterpret_cast<const char*>(payload.data()), payload.size());
}

View File

@ -17,6 +17,7 @@
#include <optional>
#include <set>
#include <string>
#include <utility>
#include <vector>
#include "absl/strings/string_view.h"
@ -29,6 +30,7 @@
#include "p2p/base/packet_transport_internal.h"
#include "p2p/base/transport_description.h"
#include "p2p/dtls/dtls_transport_internal.h"
#include "p2p/dtls/dtls_utils.h"
#include "rtc_base/buffer.h"
#include "rtc_base/byte_order.h"
#include "rtc_base/fake_clock.h"
@ -41,6 +43,7 @@
#include "rtc_base/ssl_stream_adapter.h"
#include "rtc_base/third_party/sigslot/sigslot.h"
#include "rtc_base/thread.h"
#include "test/field_trial.h"
#include "test/gtest.h"
#define MAYBE_SKIP_TEST(feature) \
@ -56,8 +59,10 @@ static const size_t kPacketHeaderLen = 12;
static const int kFakePacketId = 0x1234;
static const int kTimeout = 10000;
const uint8_t kRtpLeadByte = 0x80;
static bool IsRtpLeadByte(uint8_t b) {
return ((b & 0xC0) == 0x80);
return b == kRtpLeadByte;
}
// `modify_digest` is used to set modified fingerprints that are meant to fail
@ -100,7 +105,8 @@ class DtlsTestClient : public sigslot::has_slots<> {
dtls_transport_ = nullptr;
fake_ice_transport_ = nullptr;
fake_ice_transport_.reset(new FakeIceTransport("fake", 0));
fake_ice_transport_.reset(
new FakeIceTransport(absl::StrCat("fake-", name_), 0));
fake_ice_transport_->SetAsync(true);
fake_ice_transport_->SetAsyncDelay(async_delay_ms);
fake_ice_transport_->SetIceRole(role);
@ -148,6 +154,14 @@ class DtlsTestClient : public sigslot::has_slots<> {
return received_dtls_server_hellos_;
}
std::optional<int> GetVersionBytes() {
int value;
if (dtls_transport_->GetSslVersionBytes(&value)) {
return value;
}
return std::nullopt;
}
void CheckRole(rtc::SSLRole role) {
if (role == rtc::SSL_CLIENT) {
ASSERT_EQ(0, received_dtls_client_hellos_);
@ -188,7 +202,7 @@ class DtlsTestClient : public sigslot::has_slots<> {
// Fill the packet with a known value and a sequence number to check
// against, and make sure that it doesn't look like DTLS.
memset(packet.get(), sent & 0xff, size);
packet[0] = (srtp) ? 0x80 : 0x00;
packet[0] = (srtp) ? kRtpLeadByte : 0x00;
rtc::SetBE32(packet.get() + kPacketNumOffset,
static_cast<uint32_t>(sent));
@ -257,9 +271,15 @@ class DtlsTestClient : public sigslot::has_slots<> {
}
// Transport callbacks
void set_writable_callback(absl::AnyInvocable<void()> func) {
writable_func_ = std::move(func);
}
void OnTransportWritableState(rtc::PacketTransportInternal* transport) {
RTC_LOG(LS_INFO) << name_ << ": Transport '" << transport->transport_name()
<< "' is writable";
if (writable_func_) {
writable_func_();
}
}
void OnTransportReadPacket(rtc::PacketTransportInternal* /* transport */,
@ -297,20 +317,19 @@ class DtlsTestClient : public sigslot::has_slots<> {
// Look at the handshake packets to see what role we played.
// Check that non-handshake packets are DTLS data or SRTP bypass.
const uint8_t* data = packet.payload().data();
size_t size = packet.payload().size();
if (data[0] == 22 && size > 17) {
if (data[13] == 1) {
if (IsDtlsHandshakePacket(packet.payload())) {
if (IsDtlsClientHelloPacket(packet.payload())) {
++received_dtls_client_hellos_;
} else if (data[13] == 2) {
++received_dtls_server_hellos_;
}
} else if (dtls_transport_->IsDtlsActive() &&
!(data[0] >= 20 && data[0] <= 22)) {
ASSERT_TRUE(data[0] == 23 || IsRtpLeadByte(data[0]));
if (data[0] == 23) {
ASSERT_TRUE(VerifyEncryptedPacket(data, size));
} else if (IsRtpLeadByte(data[0])) {
} else if (data[0] == 26) {
RTC_LOG(LS_INFO) << "Found DTLS ACK";
} else if (dtls_transport_->IsDtlsActive()) {
if (IsRtpLeadByte(data[0])) {
ASSERT_TRUE(VerifyPacket(packet.payload(), NULL));
} else if (packet_size_ && packet.payload().size() >= packet_size_) {
ASSERT_TRUE(VerifyEncryptedPacket(data, packet.payload().size()));
}
}
}
@ -326,6 +345,7 @@ class DtlsTestClient : public sigslot::has_slots<> {
int received_dtls_client_hellos_ = 0;
int received_dtls_server_hellos_ = 0;
rtc::SentPacket sent_packet_;
absl::AnyInvocable<void()> writable_func_;
};
// Base class for DtlsTransportTest and DtlsEventOrderingTest, which
@ -493,6 +513,18 @@ class DtlsTransportVersionTest
::testing::tuple<rtc::SSLProtocolVersion, rtc::SSLProtocolVersion>> {
};
// Will test every combination of 1.0/1.2/1.3 on the client and server.
// DTLS will negotiate an effective version (the min of client & sewrver).
INSTANTIATE_TEST_SUITE_P(
DtlsTransportVersionTest,
DtlsTransportVersionTest,
::testing::Combine(::testing::Values(rtc::SSL_PROTOCOL_DTLS_10,
rtc::SSL_PROTOCOL_DTLS_12,
rtc::SSL_PROTOCOL_DTLS_13),
::testing::Values(rtc::SSL_PROTOCOL_DTLS_10,
rtc::SSL_PROTOCOL_DTLS_12,
rtc::SSL_PROTOCOL_DTLS_13)));
// Test that an acceptable cipher suite is negotiated when different versions
// of DTLS are supported. Note that it's IsAcceptableCipher that does the actual
// work.
@ -503,14 +535,156 @@ TEST_P(DtlsTransportVersionTest, TestCipherSuiteNegotiation) {
ASSERT_TRUE(Connect());
}
// Will test every combination of 1.0/1.2 on the client and server.
INSTANTIATE_TEST_SUITE_P(
TestCipherSuiteNegotiation,
DtlsTransportVersionTest,
::testing::Combine(::testing::Values(rtc::SSL_PROTOCOL_DTLS_10,
rtc::SSL_PROTOCOL_DTLS_12),
::testing::Values(rtc::SSL_PROTOCOL_DTLS_10,
rtc::SSL_PROTOCOL_DTLS_12)));
enum HandshakeTestEvent {
EV_CLIENT_SEND = 0,
EV_SERVER_SEND = 1,
EV_CLIENT_RECV = 2,
EV_SERVER_RECV = 3,
EV_CLIENT_WRITABLE = 4,
EV_SERVER_WRITABLE = 5,
};
static const std::vector<HandshakeTestEvent> dtls_12_handshake_events{
// Flight 1
EV_CLIENT_SEND,
EV_SERVER_RECV,
EV_SERVER_SEND,
EV_CLIENT_RECV,
// Flight 2
EV_CLIENT_SEND,
EV_SERVER_RECV,
EV_SERVER_SEND,
EV_SERVER_WRITABLE,
EV_CLIENT_RECV,
EV_CLIENT_WRITABLE,
};
static const std::vector<HandshakeTestEvent> dtls_13_handshake_events{
// Flight 1
EV_CLIENT_SEND,
EV_SERVER_RECV,
EV_SERVER_SEND,
EV_CLIENT_RECV,
// Flight 2
EV_CLIENT_SEND,
EV_CLIENT_WRITABLE,
EV_SERVER_RECV,
EV_SERVER_SEND,
EV_SERVER_WRITABLE,
};
static const struct {
int version_bytes;
const std::vector<HandshakeTestEvent>& events;
} kEventsPerVersion[] = {
{rtc::kDtls12VersionBytes, dtls_12_handshake_events},
{rtc::kDtls13VersionBytes, dtls_13_handshake_events},
};
bool LogRecv(absl::string_view name,
const rtc::CopyOnWriteBuffer& packet,
uint64_t timestamp_ms) {
RTC_LOG(LS_INFO) << "time=" << timestamp_ms << " : " << name
<< ": ReceivePacket packet len=" << packet.size()
<< ", data[0]: " << static_cast<uint8_t>(packet.data()[0]);
return false;
}
bool LogSend(absl::string_view name,
uint64_t timestamp_ms,
bool drop,
const char* data,
size_t len) {
if (drop) {
RTC_LOG(LS_INFO) << "time=" << timestamp_ms << " : " << name
<< ": dropping packet len=" << len
<< ", data[0]: " << static_cast<uint8_t>(data[0]);
} else {
RTC_LOG(LS_INFO) << "time=" << timestamp_ms << " : " << name
<< ": SendPacket, len=" << len
<< ", data[0]: " << static_cast<uint8_t>(data[0]);
}
return drop;
}
TEST_P(DtlsTransportVersionTest, TestHandshakeFlights) {
// We can only change the retransmission schedule with a recently-added
// BoringSSL API. Skip the test if not built with BoringSSL.
MAYBE_SKIP_TEST(IsBoringSsl);
// Disable any forcing of Dtls1.3.
webrtc::test::ScopedFieldTrials trials("WebRTC-ForceDtls13/Off/");
PrepareDtls(rtc::KT_DEFAULT);
SetMaxProtocolVersions(::testing::get<0>(GetParam()),
::testing::get<1>(GetParam()));
Negotiate(/* client1_server= */ false);
std::vector<HandshakeTestEvent> events;
auto start_time_ns = fake_clock_.TimeNanos();
client1_.fake_ice_transport()->set_rtt_estimate(50, true);
client2_.fake_ice_transport()->set_rtt_estimate(50, true);
client1_.fake_ice_transport()->set_packet_recv_filter(
[&](auto packet, auto timestamp_us) {
events.push_back(EV_CLIENT_RECV);
return LogRecv("client", packet,
(timestamp_us - start_time_ns / 1000) / 1000);
});
client2_.fake_ice_transport()->set_packet_recv_filter(
[&](auto packet, auto timestamp_us) {
events.push_back(EV_SERVER_RECV);
return LogRecv("server", packet,
(timestamp_us - start_time_ns / 1000) / 1000);
});
client1_.set_writable_callback(
[&]() { events.push_back(EV_CLIENT_WRITABLE); });
client2_.set_writable_callback(
[&]() { events.push_back(EV_SERVER_WRITABLE); });
client1_.fake_ice_transport()->set_packet_send_filter(
[&](auto data, auto len, auto options, auto flags) {
events.push_back(EV_CLIENT_SEND);
bool drop = false;
auto diff_ms = (fake_clock_.TimeNanos() - start_time_ns) / 1000000;
return LogSend("client", diff_ms, drop, data, len);
});
client2_.fake_ice_transport()->set_packet_send_filter(
[&](auto data, auto len, auto options, auto flags) {
events.push_back(EV_SERVER_SEND);
bool drop = false;
auto diff_ms = (fake_clock_.TimeNanos() - start_time_ns) / 1000000;
return LogSend("server", diff_ms, drop, data, len);
});
EXPECT_TRUE(client1_.Connect(&client2_, false));
EXPECT_TRUE_SIMULATED_WAIT(client1_.dtls_transport()->writable() &&
client2_.dtls_transport()->writable(),
kTimeout, fake_clock_);
client1_.fake_ice_transport()->set_packet_send_filter(nullptr);
client2_.fake_ice_transport()->set_packet_send_filter(nullptr);
client1_.fake_ice_transport()->set_packet_recv_filter(nullptr);
client2_.fake_ice_transport()->set_packet_recv_filter(nullptr);
auto dtls_version_bytes = client1_.GetVersionBytes();
ASSERT_EQ(dtls_version_bytes, client2_.GetVersionBytes());
std::vector<HandshakeTestEvent> expect;
for (const auto e : kEventsPerVersion) {
if (e.version_bytes == dtls_version_bytes) {
expect = e.events;
break;
}
}
RTC_LOG(LS_INFO) << "Verifying events with ssl version bytes= "
<< *dtls_version_bytes;
EXPECT_EQ(events, expect);
}
// Connect with DTLS, negotiating DTLS-SRTP, and transfer SRTP using bypass.
TEST_F(DtlsTransportTest, TestTransferDtlsSrtp) {

View File

@ -416,6 +416,11 @@ TEST_P(DataChannelIntegrationTest, EndToEndCallWithSctpDataChannelHarmfulMtu) {
EXPECT_TRUE_WAIT(caller()->data_observer()->IsOpen(), kDefaultTimeout);
EXPECT_TRUE_WAIT(callee()->data_observer()->IsOpen(), kDefaultTimeout);
if (caller()->tls_version() == rtc::kDtls13VersionBytes) {
ASSERT_EQ(caller()->tls_version(), rtc::kDtls13VersionBytes);
GTEST_SKIP() << "DTLS1.3 fragments packets larger than MTU";
}
virtual_socket_server()->set_max_udp_payload(kLowestSafePayloadSizeLimit - 1);
// Probe for an undelivered or slowly delivered message. The exact
// size limit seems to be dependent on the message history, so make the

View File

@ -235,6 +235,7 @@ bool PeerConnectionIntegrationWrapper::Init(
fake_network_manager_->AddInterface(kDefaultLocalAddress);
socket_factory_.reset(new rtc::BasicPacketSocketFactory(socket_server));
network_thread_ = network_thread;
std::unique_ptr<cricket::PortAllocator> port_allocator(
new cricket::BasicPortAllocator(fake_network_manager_.get(),

View File

@ -730,6 +730,16 @@ class PeerConnectionIntegrationWrapper : public PeerConnectionObserver,
connection_change_callback_ = std::move(func);
}
std::optional<int> tls_version() {
return network_thread_->BlockingCall([&] {
return pc()
->GetSctpTransport()
->dtls_transport()
->Information()
.tls_version();
});
}
private:
// Constructor used by friend class PeerConnectionIntegrationBaseTest.
explicit PeerConnectionIntegrationWrapper(const std::string& debug_name)
@ -1098,6 +1108,8 @@ class PeerConnectionIntegrationWrapper : public PeerConnectionObserver,
std::unique_ptr<rtc::FakeNetworkManager> fake_network_manager_;
std::unique_ptr<rtc::BasicPacketSocketFactory> socket_factory_;
rtc::Thread* network_thread_;
// Reference to the mDNS responder owned by `fake_network_manager_` after set.
FakeMdnsResponder* mdns_responder_ = nullptr;

View File

@ -980,10 +980,11 @@ SSL_CTX* OpenSSLStreamAdapter::SetupSSLContext() {
return nullptr;
}
SSL_CTX_set_min_proto_version(
ctx, ssl_mode_ == SSL_MODE_DTLS ? DTLS1_2_VERSION : TLS1_2_VERSION);
SSL_CTX_set_max_proto_version(ctx,
GetMaxVersion(ssl_mode_, ssl_max_version_));
auto min_version =
ssl_mode_ == SSL_MODE_DTLS ? DTLS1_2_VERSION : TLS1_2_VERSION;
auto max_version = GetMaxVersion(ssl_mode_, ssl_max_version_);
SSL_CTX_set_min_proto_version(ctx, min_version);
SSL_CTX_set_max_proto_version(ctx, max_version);
#ifdef OPENSSL_IS_BORINGSSL
// SSL_CTX_set_current_time_cb is only supported in BoringSSL.
@ -1188,6 +1189,21 @@ static const cipher_list OK_ECDSA_ciphers[] = {
};
#undef CDEF
static const cipher_list OK_DTLS13_ciphers[] = {
#ifdef TLS1_3_CK_AES_128_GCM_SHA256 // BoringSSL TLS 1.3
{static_cast<uint16_t>(TLS1_3_CK_AES_128_GCM_SHA256 & 0xffff),
"TLS_AES_128_GCM_SHA256"},
#endif
#ifdef TLS1_3_CK_AES_256_GCM_SHA256 // BoringSSL TLS 1.3
{static_cast<uint16_t>(TLS1_3_CK_AES_256_GCM_SHA256 & 0xffff),
"TLS_AES_256_GCM_SHA256"},
#endif
#ifdef TLS1_3_CK_CHACHA20_POLY1305_SHA256 // BoringSSL TLS 1.3
{static_cast<uint16_t>(TLS1_3_CK_CHACHA20_POLY1305_SHA256 & 0xffff),
"TLS_CHACHA20_POLY1305_SHA256"},
#endif
};
bool OpenSSLStreamAdapter::IsAcceptableCipher(int cipher, KeyType key_type) {
if (key_type == KT_RSA) {
for (const cipher_list& c : OK_RSA_ciphers) {
@ -1205,6 +1221,12 @@ bool OpenSSLStreamAdapter::IsAcceptableCipher(int cipher, KeyType key_type) {
}
}
for (const cipher_list& c : OK_DTLS13_ciphers) {
if (cipher == c.cipher) {
return true;
}
}
return false;
}
@ -1226,6 +1248,12 @@ bool OpenSSLStreamAdapter::IsAcceptableCipher(absl::string_view cipher,
}
}
for (const cipher_list& c : OK_DTLS13_ciphers) {
if (cipher == c.cipher_str) {
return true;
}
}
return false;
}
@ -1235,4 +1263,12 @@ void OpenSSLStreamAdapter::EnableTimeCallbackForTesting() {
#endif
}
SSLProtocolVersion OpenSSLStreamAdapter::GetMaxSupportedDTLSProtocolVersion() {
#if defined(OPENSSL_IS_BORINGSSL) && defined(DTLS1_3_VERSION)
return SSL_PROTOCOL_DTLS_13;
#else
return SSL_PROTOCOL_DTLS_12;
#endif
}
} // namespace rtc

View File

@ -128,6 +128,9 @@ class OpenSSLStreamAdapter final : public SSLStreamAdapter {
// using a fake clock.
static void EnableTimeCallbackForTesting();
// Return max DTLS SSLProtocolVersion supported by implementation.
static SSLProtocolVersion GetMaxSupportedDTLSProtocolVersion();
private:
enum SSLState {
// Before calling one of the StartSSL methods, data flows

View File

@ -120,6 +120,10 @@ void SSLStreamAdapter::EnableTimeCallbackForTesting() {
OpenSSLStreamAdapter::EnableTimeCallbackForTesting();
}
SSLProtocolVersion SSLStreamAdapter::GetMaxSupportedDTLSProtocolVersion() {
return OpenSSLStreamAdapter::GetMaxSupportedDTLSProtocolVersion();
}
///////////////////////////////////////////////////////////////////////////////
} // namespace rtc

View File

@ -97,6 +97,12 @@ enum SSLProtocolVersion {
SSL_PROTOCOL_DTLS_12 = SSL_PROTOCOL_TLS_12,
SSL_PROTOCOL_DTLS_13 = SSL_PROTOCOL_TLS_13,
};
// Versions returned from BoringSSL.
const uint16_t kDtls10VersionBytes = 0xfeff;
const uint16_t kDtls12VersionBytes = 0xfefd;
const uint16_t kDtls13VersionBytes = 0xfefc;
enum class SSLPeerCertificateDigestError {
NONE,
UNKNOWN_ALGORITHM,
@ -240,6 +246,9 @@ class SSLStreamAdapter : public StreamInterface {
// using a fake clock.
static void EnableTimeCallbackForTesting();
// Return max DTLS SSLProtocolVersion supported by implementation.
static SSLProtocolVersion GetMaxSupportedDTLSProtocolVersion();
// Deprecated. Do not use this API outside of testing.
// Do not set this to false outside of testing.
void SetClientAuthEnabledForTesting(bool enabled) {

View File

@ -665,7 +665,7 @@ class SSLStreamAdapterTestBase : public ::testing::Test,
// SS_OPENING and writes should return SR_BLOCK.
EXPECT_EQ(rtc::SS_OPENING, client_ssl_->GetState());
EXPECT_EQ(rtc::SS_OPENING, server_ssl_->GetState());
uint8_t packet[1];
uint8_t packet[1] = {0};
size_t sent;
int error;
EXPECT_EQ(rtc::SR_BLOCK, client_ssl_->Write(packet, sent, error));
@ -753,11 +753,10 @@ class SSLStreamAdapterTestBase : public ::testing::Test,
// Optionally damage application data (type 23). Note that we don't damage
// handshake packets and we damage the last byte to keep the header
// intact but break the MAC.
if (damage_ && (*static_cast<const unsigned char*>(data) == 23)) {
uint8_t data0 = static_cast<const unsigned char*>(data)[0];
if (damage_ && (data0 == 23 || data0 == 47)) {
std::vector<uint8_t> buf(data_len);
RTC_LOG(LS_VERBOSE) << "Damaging packet";
memcpy(&buf[0], data, data_len);
buf[data_len - 1]++;
return from->WriteData(rtc::MakeArrayView(&buf[0], data_len), written,
@ -1119,8 +1118,9 @@ TEST_P(SSLStreamAdapterTestDTLSHandshake, TestDTLSConnect) {
}
// Test getting the used DTLS ciphers.
// DTLS 1.2 is max version for client and server.
// DTLS 1.2 has different cipher suite than 1.3.
TEST_P(SSLStreamAdapterTestDTLSHandshake, TestGetSslCipherSuite) {
webrtc::test::ScopedFieldTrials trials("WebRTC-ForceDtls13/Off/");
SetupProtocolVersions(rtc::SSL_PROTOCOL_DTLS_12, rtc::SSL_PROTOCOL_DTLS_12);
TestHandshake();
@ -1476,18 +1476,76 @@ TEST_F(SSLStreamAdapterTestDTLSFromPEMStrings,
}
#pragma clang diagnostic pop
// Test getting the DTLS 1.2 version.
TEST_F(SSLStreamAdapterTestDTLS, TestGetSslVersionBytes) {
// https://datatracker.ietf.org/doc/html/rfc9147#section-5.3
const int kDtls1_2 = 0xFEFD;
SetupProtocolVersions(rtc::SSL_PROTOCOL_DTLS_12, rtc::SSL_PROTOCOL_DTLS_12);
struct SSLStreamAdapterTestDTLSHandshakeVersion
: public SSLStreamAdapterTestDTLS,
public WithParamInterface<std::tuple<
/* client*/ rtc::SSLProtocolVersion,
/* server*/ rtc::SSLProtocolVersion>> {
rtc::SSLProtocolVersion GetMin(
const std::vector<rtc::SSLProtocolVersion>& array) {
rtc::SSLProtocolVersion min = array[0];
for (const auto& e : array) {
if (static_cast<int>(e) < static_cast<int>(min)) {
min = e;
}
}
return min;
}
uint16_t AsDtlsVersionBytes(rtc::SSLProtocolVersion version) {
switch (version) {
case rtc::SSL_PROTOCOL_DTLS_10:
return rtc::kDtls10VersionBytes;
case rtc::SSL_PROTOCOL_DTLS_12:
return rtc::kDtls12VersionBytes;
case rtc::SSL_PROTOCOL_DTLS_13:
return rtc::kDtls13VersionBytes;
default:
break;
}
RTC_CHECK(false) << "Unknown version: " << static_cast<int>(version);
}
};
INSTANTIATE_TEST_SUITE_P(
SSLStreamAdapterTestDTLSHandshakeVersion,
SSLStreamAdapterTestDTLSHandshakeVersion,
Combine(Values(rtc::SSL_PROTOCOL_DTLS_12, rtc::SSL_PROTOCOL_DTLS_13),
Values(rtc::SSL_PROTOCOL_DTLS_12, rtc::SSL_PROTOCOL_DTLS_13)));
TEST_P(SSLStreamAdapterTestDTLSHandshakeVersion, TestGetSslVersionBytes) {
webrtc::test::ScopedFieldTrials trials("WebRTC-ForceDtls13/Off/");
auto client = ::testing::get<0>(GetParam());
auto server = ::testing::get<1>(GetParam());
SetupProtocolVersions(client, server);
TestHandshake();
int client_version;
ASSERT_TRUE(GetSslVersionBytes(true, &client_version));
EXPECT_EQ(client_version, kDtls1_2);
int server_version;
ASSERT_TRUE(GetSslVersionBytes(true, &client_version));
ASSERT_TRUE(GetSslVersionBytes(false, &server_version));
EXPECT_EQ(server_version, kDtls1_2);
rtc::SSLProtocolVersion expect =
GetMin({client, server,
rtc::SSLStreamAdapter::GetMaxSupportedDTLSProtocolVersion()});
auto expect_bytes = AsDtlsVersionBytes(expect);
EXPECT_EQ(client_version, expect_bytes);
EXPECT_EQ(server_version, expect_bytes);
}
TEST_P(SSLStreamAdapterTestDTLSHandshakeVersion, TestGetSslCipherSuite) {
webrtc::test::ScopedFieldTrials trials("WebRTC-ForceDtls13/Off/");
auto client = ::testing::get<0>(GetParam());
auto server = ::testing::get<1>(GetParam());
SetupProtocolVersions(client, server);
TestHandshake();
int client_cipher;
ASSERT_TRUE(GetSslCipherSuite(true, &client_cipher));
int server_cipher;
ASSERT_TRUE(GetSslCipherSuite(false, &server_cipher));
ASSERT_EQ(client_cipher, server_cipher);
ASSERT_TRUE(rtc::SSLStreamAdapter::IsAcceptableCipher(server_cipher,
rtc::KT_DEFAULT));
}