Split SSLAdapter/SSLStreamAdapter and deprecate SSL(Stream)Adapter::SetMode

since we do not need two adapters with TLS and DTLS modes.
SSLAdapter is the TLS adapter,
SSLStreamAdapter is the DTLS adapter.

BUG=webrtc:353750117

Change-Id: I223917c71c88437339380e1f196dcf3c0e2021c8
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/354940
Reviewed-by: Harald Alvestrand <hta@webrtc.org>
Commit-Queue: Philipp Hancke <phancke@meta.com>
Reviewed-by: Mirko Bonadei <mbonadei@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#42713}
This commit is contained in:
Philipp Hancke 2024-07-30 16:42:49 -07:00 committed by WebRTC LUCI CQ
parent 7c793a7dbe
commit 5d6fa7d2fc
8 changed files with 58 additions and 589 deletions

View File

@ -376,7 +376,6 @@ bool DtlsTransport::SetupDtls() {
} }
dtls_->SetIdentity(local_certificate_->identity()->Clone()); dtls_->SetIdentity(local_certificate_->identity()->Clone());
dtls_->SetMode(rtc::SSL_MODE_DTLS);
dtls_->SetMaxProtocolVersion(ssl_max_version_); dtls_->SetMaxProtocolVersion(ssl_max_version_);
dtls_->SetServerRole(*dtls_role_); dtls_->SetServerRole(*dtls_role_);
dtls_->SetEventCallback( dtls_->SetEventCallback(

View File

@ -56,7 +56,7 @@ class OpenSSLAdapter final : public SSLAdapter {
void SetIgnoreBadCert(bool ignore) override; void SetIgnoreBadCert(bool ignore) override;
void SetAlpnProtocols(const std::vector<std::string>& protos) override; void SetAlpnProtocols(const std::vector<std::string>& protos) override;
void SetEllipticCurves(const std::vector<std::string>& curves) override; void SetEllipticCurves(const std::vector<std::string>& curves) override;
void SetMode(SSLMode mode) override; [[deprecated]] void SetMode(SSLMode mode) override;
void SetCertVerifier(SSLCertificateVerifier* ssl_cert_verifier) override; void SetCertVerifier(SSLCertificateVerifier* ssl_cert_verifier) override;
void SetIdentity(std::unique_ptr<SSLIdentity> identity) override; void SetIdentity(std::unique_ptr<SSLIdentity> identity) override;
void SetRole(SSLRole role) override; void SetRole(SSLRole role) override;

View File

@ -292,7 +292,7 @@ OpenSSLStreamAdapter::OpenSSLStreamAdapter(
permute_extension_( permute_extension_(
!webrtc::field_trial::IsDisabled("WebRTC-PermuteTlsClientHello")), !webrtc::field_trial::IsDisabled("WebRTC-PermuteTlsClientHello")),
#endif #endif
ssl_mode_(SSL_MODE_TLS), ssl_mode_(SSL_MODE_DTLS),
ssl_max_version_(SSL_PROTOCOL_TLS_12) { ssl_max_version_(SSL_PROTOCOL_TLS_12) {
stream_->SetEventCallback( stream_->SetEventCallback(
[this](int events, int err) { OnEvent(events, err); }); [this](int events, int err) { OnEvent(events, err); });

View File

@ -90,7 +90,7 @@ class OpenSSLStreamAdapter final : public SSLStreamAdapter,
// Goes from state SSL_NONE to either SSL_CONNECTING or SSL_WAIT, depending // Goes from state SSL_NONE to either SSL_CONNECTING or SSL_WAIT, depending
// on whether the underlying stream is already open or not. // on whether the underlying stream is already open or not.
int StartSSL() override; int StartSSL() override;
void SetMode(SSLMode mode) override; [[deprecated]] void SetMode(SSLMode mode) override;
void SetMaxProtocolVersion(SSLProtocolVersion version) override; void SetMaxProtocolVersion(SSLProtocolVersion version) override;
void SetInitialRetransmissionTimeout(int timeout_ms) override; void SetInitialRetransmissionTimeout(int timeout_ms) override;

View File

@ -80,8 +80,8 @@ class SSLAdapter : public AsyncSocketAdapter {
virtual void SetAlpnProtocols(const std::vector<std::string>& protos) = 0; virtual void SetAlpnProtocols(const std::vector<std::string>& protos) = 0;
virtual void SetEllipticCurves(const std::vector<std::string>& curves) = 0; virtual void SetEllipticCurves(const std::vector<std::string>& curves) = 0;
// Do DTLS or TLS (default is TLS, if unspecified) [[deprecated("Only TLS is supported by the adapter")]] virtual void SetMode(
virtual void SetMode(SSLMode mode) = 0; SSLMode mode) = 0;
// Specify a custom certificate verifier for SSL. // Specify a custom certificate verifier for SSL.
virtual void SetCertVerifier(SSLCertificateVerifier* ssl_cert_verifier) = 0; virtual void SetCertVerifier(SSLCertificateVerifier* ssl_cert_verifier) = 0;

View File

@ -20,7 +20,6 @@
#include "rtc_base/ip_address.h" #include "rtc_base/ip_address.h"
#include "rtc_base/message_digest.h" #include "rtc_base/message_digest.h"
#include "rtc_base/ssl_identity.h" #include "rtc_base/ssl_identity.h"
#include "rtc_base/ssl_stream_adapter.h"
#include "rtc_base/stream.h" #include "rtc_base/stream.h"
#include "rtc_base/string_encode.h" #include "rtc_base/string_encode.h"
#include "rtc_base/virtual_socket_server.h" #include "rtc_base/virtual_socket_server.h"
@ -31,21 +30,16 @@ using ::testing::Return;
static const int kTimeout = 5000; static const int kTimeout = 5000;
static rtc::Socket* CreateSocket(const rtc::SSLMode& ssl_mode) { static rtc::Socket* CreateSocket() {
rtc::SocketAddress address(rtc::IPAddress(INADDR_ANY), 0); rtc::SocketAddress address(rtc::IPAddress(INADDR_ANY), 0);
rtc::Socket* socket = rtc::Thread::Current()->socketserver()->CreateSocket( rtc::Socket* socket = rtc::Thread::Current()->socketserver()->CreateSocket(
address.family(), address.family(), SOCK_STREAM);
(ssl_mode == rtc::SSL_MODE_DTLS) ? SOCK_DGRAM : SOCK_STREAM);
socket->Bind(address); socket->Bind(address);
return socket; return socket;
} }
static std::string GetSSLProtocolName(const rtc::SSLMode& ssl_mode) {
return (ssl_mode == rtc::SSL_MODE_DTLS) ? "DTLS" : "TLS";
}
// Simple mock for the certificate verifier. // Simple mock for the certificate verifier.
class MockCertVerifier : public rtc::SSLCertificateVerifier { class MockCertVerifier : public rtc::SSLCertificateVerifier {
public: public:
@ -55,25 +49,24 @@ class MockCertVerifier : public rtc::SSLCertificateVerifier {
// TODO(benwright) - Move to using INSTANTIATE_TEST_SUITE_P instead of using // TODO(benwright) - Move to using INSTANTIATE_TEST_SUITE_P instead of using
// duplicate test cases for simple parameter changes. // duplicate test cases for simple parameter changes.
class SSLAdapterTestDummyClient : public sigslot::has_slots<> { class SSLAdapterTestDummy : public sigslot::has_slots<> {
public: public:
explicit SSLAdapterTestDummyClient(const rtc::SSLMode& ssl_mode) explicit SSLAdapterTestDummy() : socket_(CreateSocket()) {}
: ssl_mode_(ssl_mode) { virtual ~SSLAdapterTestDummy() = default;
rtc::Socket* socket = CreateSocket(ssl_mode_);
void CreateSSLAdapter(rtc::Socket* socket, rtc::SSLRole role) {
ssl_adapter_.reset(rtc::SSLAdapter::Create(socket)); ssl_adapter_.reset(rtc::SSLAdapter::Create(socket));
ssl_adapter_->SetMode(ssl_mode_);
// Ignore any certificate errors for the purpose of testing. // Ignore any certificate errors for the purpose of testing.
// Note: We do this only because we don't have a real certificate. // Note: We do this only because we don't have a real certificate.
// NEVER USE THIS IN PRODUCTION CODE! // NEVER USE THIS IN PRODUCTION CODE!
ssl_adapter_->SetIgnoreBadCert(true); ssl_adapter_->SetIgnoreBadCert(true);
ssl_adapter_->SignalReadEvent.connect( ssl_adapter_->SignalReadEvent.connect(
this, &SSLAdapterTestDummyClient::OnSSLAdapterReadEvent); this, &SSLAdapterTestDummy::OnSSLAdapterReadEvent);
ssl_adapter_->SignalCloseEvent.connect( ssl_adapter_->SignalCloseEvent.connect(
this, &SSLAdapterTestDummyClient::OnSSLAdapterCloseEvent); this, &SSLAdapterTestDummy::OnSSLAdapterCloseEvent);
ssl_adapter_->SetRole(role);
} }
void SetIgnoreBadCert(bool ignore_bad_cert) { void SetIgnoreBadCert(bool ignore_bad_cert) {
@ -100,27 +93,10 @@ class SSLAdapterTestDummyClient : public sigslot::has_slots<> {
const std::string& GetReceivedData() const { return data_; } const std::string& GetReceivedData() const { return data_; }
int Connect(absl::string_view hostname, const rtc::SocketAddress& address) {
RTC_LOG(LS_INFO) << "Initiating connection with " << address.ToString();
int rv = ssl_adapter_->Connect(address);
if (rv == 0) {
RTC_LOG(LS_INFO) << "Starting " << GetSSLProtocolName(ssl_mode_)
<< " handshake with " << hostname;
if (ssl_adapter_->StartSSL(hostname) != 0) {
return -1;
}
}
return rv;
}
int Close() { return ssl_adapter_->Close(); } int Close() { return ssl_adapter_->Close(); }
int Send(absl::string_view message) { int Send(absl::string_view message) {
RTC_LOG(LS_INFO) << "Client sending '" << message << "'"; RTC_LOG(LS_INFO) << "Sending '" << message << "'";
return ssl_adapter_->Send(message.data(), message.length()); return ssl_adapter_->Send(message.data(), message.length());
} }
@ -133,7 +109,7 @@ class SSLAdapterTestDummyClient : public sigslot::has_slots<> {
if (read != -1) { if (read != -1) {
buffer[read] = '\0'; buffer[read] = '\0';
RTC_LOG(LS_INFO) << "Client received '" << buffer << "'"; RTC_LOG(LS_INFO) << "Received '" << buffer << "'";
data_ += buffer; data_ += buffer;
} }
@ -148,125 +124,50 @@ class SSLAdapterTestDummyClient : public sigslot::has_slots<> {
} }
} }
private: protected:
const rtc::SSLMode ssl_mode_;
std::unique_ptr<rtc::SSLAdapter> ssl_adapter_; std::unique_ptr<rtc::SSLAdapter> ssl_adapter_;
std::unique_ptr<rtc::Socket> socket_;
private:
std::string data_; std::string data_;
}; };
namespace { class SSLAdapterTestDummyClient : public SSLAdapterTestDummy {
class SocketStream : public rtc::StreamInterface, public sigslot::has_slots<> {
public: public:
explicit SocketStream(rtc::Socket* socket) : socket_(socket) { explicit SSLAdapterTestDummyClient() : SSLAdapterTestDummy() {
socket_->SignalConnectEvent.connect(this, &SocketStream::OnConnectEvent); CreateSSLAdapter(socket_.release(), rtc::SSL_CLIENT);
socket_->SignalReadEvent.connect(this, &SocketStream::OnReadEvent);
socket_->SignalWriteEvent.connect(this, &SocketStream::OnWriteEvent);
socket_->SignalCloseEvent.connect(this, &SocketStream::OnCloseEvent);
} }
~SocketStream() override = default; int Connect(absl::string_view hostname, const rtc::SocketAddress& address) {
RTC_LOG(LS_INFO) << "Initiating connection with " << address.ToString();
int rv = ssl_adapter_->Connect(address);
rtc::StreamState GetState() const override { if (rv == 0) {
switch (socket_->GetState()) { RTC_LOG(LS_INFO) << "Starting TLS handshake with " << hostname;
case rtc::Socket::CS_CONNECTED:
return rtc::SS_OPEN; if (ssl_adapter_->StartSSL(hostname) != 0) {
case rtc::Socket::CS_CONNECTING: return -1;
return rtc::SS_OPENING; }
case rtc::Socket::CS_CLOSED:
default:
return rtc::SS_CLOSED;
} }
}
rtc::StreamResult Read(rtc::ArrayView<uint8_t> buffer, return rv;
size_t& read,
int& error) override {
int result = socket_->Recv(buffer.data(), buffer.size(), nullptr);
if (result < 0) {
if (socket_->IsBlocking())
return rtc::SR_BLOCK;
error = socket_->GetError();
return rtc::SR_ERROR;
}
if ((result > 0) || (buffer.size() == 0)) {
read = result;
return rtc::SR_SUCCESS;
}
return rtc::SR_EOS;
} }
rtc::StreamResult Write(rtc::ArrayView<const uint8_t> data,
size_t& written,
int& error) override {
int result = socket_->Send(data.data(), data.size());
if (result < 0) {
if (socket_->IsBlocking())
return rtc::SR_BLOCK;
error = socket_->GetError();
return rtc::SR_ERROR;
}
written = result;
return rtc::SR_SUCCESS;
}
void Close() override { socket_->Close(); }
private:
void OnConnectEvent(rtc::Socket* socket) {
RTC_DCHECK_RUN_ON(&callback_sequence_);
RTC_DCHECK_EQ(socket, socket_.get());
FireEvent(rtc::SE_OPEN | rtc::SE_READ | rtc::SE_WRITE, 0);
}
void OnReadEvent(rtc::Socket* socket) {
RTC_DCHECK_RUN_ON(&callback_sequence_);
RTC_DCHECK_EQ(socket, socket_.get());
FireEvent(rtc::SE_READ, 0);
}
void OnWriteEvent(rtc::Socket* socket) {
RTC_DCHECK_RUN_ON(&callback_sequence_);
RTC_DCHECK_EQ(socket, socket_.get());
FireEvent(rtc::SE_WRITE, 0);
}
void OnCloseEvent(rtc::Socket* socket, int err) {
RTC_DCHECK_RUN_ON(&callback_sequence_);
RTC_DCHECK_EQ(socket, socket_.get());
FireEvent(rtc::SE_CLOSE, err);
}
std::unique_ptr<rtc::Socket> socket_;
}; };
} // namespace class SSLAdapterTestDummyServer : public SSLAdapterTestDummy {
class SSLAdapterTestDummyServer : public sigslot::has_slots<> {
public: public:
explicit SSLAdapterTestDummyServer(const rtc::SSLMode& ssl_mode, explicit SSLAdapterTestDummyServer(const rtc::KeyParams& key_params)
const rtc::KeyParams& key_params) : SSLAdapterTestDummy(),
: ssl_mode_(ssl_mode) { ssl_identity_(rtc::SSLIdentity::Create(GetHostname(), key_params)) {
// Generate a key pair and a certificate for this host. socket_->Listen(1);
ssl_identity_ = rtc::SSLIdentity::Create(GetHostname(), key_params); socket_->SignalReadEvent.connect(this,
&SSLAdapterTestDummyServer::OnReadEvent);
server_socket_.reset(CreateSocket(ssl_mode_)); RTC_LOG(LS_INFO) << "TCP server listening on "
<< socket_->GetLocalAddress().ToString();
if (ssl_mode_ == rtc::SSL_MODE_TLS) {
server_socket_->SignalReadEvent.connect(
this, &SSLAdapterTestDummyServer::OnServerSocketReadEvent);
server_socket_->Listen(1);
}
RTC_LOG(LS_INFO) << ((ssl_mode_ == rtc::SSL_MODE_DTLS) ? "UDP" : "TCP")
<< " server listening on "
<< server_socket_->GetLocalAddress().ToString();
} }
rtc::SocketAddress GetAddress() const { rtc::SocketAddress GetAddress() const { return socket_->GetLocalAddress(); }
return server_socket_->GetLocalAddress();
}
std::string GetHostname() const { std::string GetHostname() const {
// Since we don't have a real certificate anyway, the value here doesn't // Since we don't have a real certificate anyway, the value here doesn't
@ -274,120 +175,26 @@ class SSLAdapterTestDummyServer : public sigslot::has_slots<> {
return "example.com"; return "example.com";
} }
const std::string& GetReceivedData() const { return data_; } protected:
void OnReadEvent(rtc::Socket* socket) {
int Send(absl::string_view message) { CreateSSLAdapter(socket_->Accept(nullptr), rtc::SSL_SERVER);
if (ssl_stream_adapter_ == nullptr || ssl_adapter_->SetIdentity(ssl_identity_->Clone());
ssl_stream_adapter_->GetState() != rtc::SS_OPEN) { if (ssl_adapter_->StartSSL(GetHostname()) != 0) {
// No connection yet. RTC_LOG(LS_ERROR) << "Starting SSL from server failed.";
return -1;
}
RTC_LOG(LS_INFO) << "Server sending '" << message << "'";
size_t written;
int error;
rtc::StreamResult r = ssl_stream_adapter_->Write(
rtc::MakeArrayView(reinterpret_cast<const uint8_t*>(message.data()),
message.size()),
written, error);
if (r == rtc::SR_SUCCESS) {
return written;
} else {
return -1;
}
}
void AcceptConnection(const rtc::SocketAddress& address) {
// Only a single connection is supported.
ASSERT_TRUE(ssl_stream_adapter_ == nullptr);
// This is only for DTLS.
ASSERT_EQ(rtc::SSL_MODE_DTLS, ssl_mode_);
// Transfer ownership of the socket to the SSLStreamAdapter object.
rtc::Socket* socket = server_socket_.release();
socket->Connect(address);
DoHandshake(socket);
}
void OnServerSocketReadEvent(rtc::Socket* socket) {
// Only a single connection is supported.
ASSERT_TRUE(ssl_stream_adapter_ == nullptr);
DoHandshake(server_socket_->Accept(nullptr));
}
void OnSSLStreamAdapterEvent(int sig, int err) {
if (sig & rtc::SE_READ) {
uint8_t buffer[4096] = "";
size_t read;
int error;
// Read data received from the client and store it in our internal
// buffer.
rtc::StreamResult r = ssl_stream_adapter_->Read(buffer, read, error);
if (r == rtc::SR_SUCCESS) {
buffer[read] = '\0';
// Here we assume that the buffer is interpretable as string.
char* buffer_as_char = reinterpret_cast<char*>(buffer);
RTC_LOG(LS_INFO) << "Server received '" << buffer_as_char << "'";
data_ += buffer_as_char;
}
} }
} }
private: private:
void DoHandshake(rtc::Socket* socket) {
ssl_stream_adapter_ =
rtc::SSLStreamAdapter::Create(std::make_unique<SocketStream>(socket));
ssl_stream_adapter_->SetMode(ssl_mode_);
ssl_stream_adapter_->SetServerRole();
// SSLStreamAdapter is normally used for peer-to-peer communication, but
// here we're testing communication between a client and a server
// (e.g. a WebRTC-based application and an RFC 5766 TURN server), where
// clients are not required to provide a certificate during handshake.
// Accordingly, we must disable client authentication here.
ssl_stream_adapter_->SetClientAuthEnabledForTesting(false);
ssl_stream_adapter_->SetIdentity(ssl_identity_->Clone());
// Set a bogus peer certificate digest.
unsigned char digest[20];
size_t digest_len = sizeof(digest);
ssl_stream_adapter_->SetPeerCertificateDigest(rtc::DIGEST_SHA_1, digest,
digest_len);
ssl_stream_adapter_->StartSSL();
ssl_stream_adapter_->SetEventCallback(
[this](int events, int err) { OnSSLStreamAdapterEvent(events, err); });
}
const rtc::SSLMode ssl_mode_;
std::unique_ptr<rtc::Socket> server_socket_;
std::unique_ptr<rtc::SSLStreamAdapter> ssl_stream_adapter_;
std::unique_ptr<rtc::SSLIdentity> ssl_identity_; std::unique_ptr<rtc::SSLIdentity> ssl_identity_;
std::string data_;
}; };
class SSLAdapterTestBase : public ::testing::Test, public sigslot::has_slots<> { class SSLAdapterTestBase : public ::testing::Test, public sigslot::has_slots<> {
public: public:
explicit SSLAdapterTestBase(const rtc::SSLMode& ssl_mode, explicit SSLAdapterTestBase(const rtc::KeyParams& key_params)
const rtc::KeyParams& key_params) : vss_(new rtc::VirtualSocketServer()),
: ssl_mode_(ssl_mode),
vss_(new rtc::VirtualSocketServer()),
thread_(vss_.get()), thread_(vss_.get()),
server_(new SSLAdapterTestDummyServer(ssl_mode_, key_params)), server_(new SSLAdapterTestDummyServer(key_params)),
client_(new SSLAdapterTestDummyClient(ssl_mode_)), client_(new SSLAdapterTestDummyClient()),
handshake_wait_(kTimeout) {} handshake_wait_(kTimeout) {}
void SetHandshakeWait(int wait) { handshake_wait_ = wait; } void SetHandshakeWait(int wait) { handshake_wait_ = wait; }
@ -430,26 +237,20 @@ class SSLAdapterTestBase : public ::testing::Test, public sigslot::has_slots<> {
// Now the state should be CS_CONNECTING // Now the state should be CS_CONNECTING
ASSERT_EQ(rtc::Socket::CS_CONNECTING, client_->GetState()); ASSERT_EQ(rtc::Socket::CS_CONNECTING, client_->GetState());
if (ssl_mode_ == rtc::SSL_MODE_DTLS) {
// For DTLS, call AcceptConnection() with the client's address.
server_->AcceptConnection(client_->GetAddress());
}
if (expect_success) { if (expect_success) {
// If expecting success, the client should end up in the CS_CONNECTED // If expecting success, the client should end up in the CS_CONNECTED
// state after handshake. // state after handshake.
EXPECT_EQ_WAIT(rtc::Socket::CS_CONNECTED, client_->GetState(), EXPECT_EQ_WAIT(rtc::Socket::CS_CONNECTED, client_->GetState(),
handshake_wait_); handshake_wait_);
RTC_LOG(LS_INFO) << GetSSLProtocolName(ssl_mode_) RTC_LOG(LS_INFO) << "TLS handshake complete.";
<< " handshake complete.";
} else { } else {
// On handshake failure the client should end up in the CS_CLOSED state. // On handshake failure the client should end up in the CS_CLOSED state.
EXPECT_EQ_WAIT(rtc::Socket::CS_CLOSED, client_->GetState(), EXPECT_EQ_WAIT(rtc::Socket::CS_CLOSED, client_->GetState(),
handshake_wait_); handshake_wait_);
RTC_LOG(LS_INFO) << GetSSLProtocolName(ssl_mode_) << " handshake failed."; RTC_LOG(LS_INFO) << "TLS handshake failed.";
} }
} }
@ -472,8 +273,6 @@ class SSLAdapterTestBase : public ::testing::Test, public sigslot::has_slots<> {
} }
protected: protected:
const rtc::SSLMode ssl_mode_;
std::unique_ptr<rtc::VirtualSocketServer> vss_; std::unique_ptr<rtc::VirtualSocketServer> vss_;
rtc::AutoSocketServerThread thread_; rtc::AutoSocketServerThread thread_;
std::unique_ptr<SSLAdapterTestDummyServer> server_; std::unique_ptr<SSLAdapterTestDummyServer> server_;
@ -485,30 +284,14 @@ class SSLAdapterTestBase : public ::testing::Test, public sigslot::has_slots<> {
class SSLAdapterTestTLS_RSA : public SSLAdapterTestBase { class SSLAdapterTestTLS_RSA : public SSLAdapterTestBase {
public: public:
SSLAdapterTestTLS_RSA() SSLAdapterTestTLS_RSA() : SSLAdapterTestBase(rtc::KeyParams::RSA()) {}
: SSLAdapterTestBase(rtc::SSL_MODE_TLS, rtc::KeyParams::RSA()) {}
}; };
class SSLAdapterTestTLS_ECDSA : public SSLAdapterTestBase { class SSLAdapterTestTLS_ECDSA : public SSLAdapterTestBase {
public: public:
SSLAdapterTestTLS_ECDSA() SSLAdapterTestTLS_ECDSA() : SSLAdapterTestBase(rtc::KeyParams::ECDSA()) {}
: SSLAdapterTestBase(rtc::SSL_MODE_TLS, rtc::KeyParams::ECDSA()) {}
}; };
class SSLAdapterTestDTLS_RSA : public SSLAdapterTestBase {
public:
SSLAdapterTestDTLS_RSA()
: SSLAdapterTestBase(rtc::SSL_MODE_DTLS, rtc::KeyParams::RSA()) {}
};
class SSLAdapterTestDTLS_ECDSA : public SSLAdapterTestBase {
public:
SSLAdapterTestDTLS_ECDSA()
: SSLAdapterTestBase(rtc::SSL_MODE_DTLS, rtc::KeyParams::ECDSA()) {}
};
// Basic tests: TLS
// Test that handshake works, using RSA // Test that handshake works, using RSA
TEST_F(SSLAdapterTestTLS_RSA, TestTLSConnect) { TEST_F(SSLAdapterTestTLS_RSA, TestTLSConnect) {
TestHandshake(true); TestHandshake(true);
@ -627,69 +410,3 @@ TEST_F(SSLAdapterTestTLS_ECDSA, TestTLSEllipticCurves) {
TestHandshake(true); TestHandshake(true);
TestTransfer("Hello, world!"); TestTransfer("Hello, world!");
} }
// Basic tests: DTLS
// Test that handshake works, using RSA
TEST_F(SSLAdapterTestDTLS_RSA, TestDTLSConnect) {
TestHandshake(true);
}
// Test that handshake works with a custom verifier that returns true. DTLS_RSA.
TEST_F(SSLAdapterTestDTLS_RSA, TestDTLSConnectCustomCertVerifierSucceeds) {
SetMockCertVerifier(/*return_value=*/true);
TestHandshake(/*expect_success=*/true);
}
// Test that handshake fails with a custom verifier that returns false.
// DTLS_RSA.
TEST_F(SSLAdapterTestDTLS_RSA, TestTLSConnectCustomCertVerifierFails) {
SetMockCertVerifier(/*return_value=*/false);
TestHandshake(/*expect_success=*/false);
}
// Test that handshake works, using ECDSA
TEST_F(SSLAdapterTestDTLS_ECDSA, TestDTLSConnect) {
TestHandshake(true);
}
// Test that handshake works with a custom verifier that returns true.
// DTLS_ECDSA.
TEST_F(SSLAdapterTestDTLS_ECDSA, TestDTLSConnectCustomCertVerifierSucceeds) {
SetMockCertVerifier(/*return_value=*/true);
TestHandshake(/*expect_success=*/true);
}
// Test that handshake fails with a custom verifier that returns false.
// DTLS_ECDSA.
TEST_F(SSLAdapterTestDTLS_ECDSA, TestTLSConnectCustomCertVerifierFails) {
SetMockCertVerifier(/*return_value=*/false);
TestHandshake(/*expect_success=*/false);
}
// Test transfer between client and server, using RSA
TEST_F(SSLAdapterTestDTLS_RSA, TestDTLSTransfer) {
TestHandshake(true);
TestTransfer("Hello, world!");
}
// Test transfer between client and server, using RSA with custom cert verifier.
TEST_F(SSLAdapterTestDTLS_RSA, TestDTLSTransferCustomCertVerifier) {
SetMockCertVerifier(/*return_value=*/true);
TestHandshake(/*expect_success=*/true);
TestTransfer("Hello, world!");
}
// Test transfer between client and server, using ECDSA
TEST_F(SSLAdapterTestDTLS_ECDSA, TestDTLSTransfer) {
TestHandshake(true);
TestTransfer("Hello, world!");
}
// Test transfer between client and server, using ECDSA with custom cert
// verifier.
TEST_F(SSLAdapterTestDTLS_ECDSA, TestDTLSTransferCustomCertVerifier) {
SetMockCertVerifier(/*return_value=*/true);
TestHandshake(/*expect_success=*/true);
TestTransfer("Hello, world!");
}

View File

@ -130,8 +130,8 @@ class SSLStreamAdapter : public StreamInterface {
// TODO(ekr@rtfm.com): rename this SetRole to reflect its new function // TODO(ekr@rtfm.com): rename this SetRole to reflect its new function
virtual void SetServerRole(SSLRole role = SSL_SERVER) = 0; virtual void SetServerRole(SSLRole role = SSL_SERVER) = 0;
// Do DTLS or TLS. [[deprecated("Only DTLS is supported by the stream adapter")]] virtual void
virtual void SetMode(SSLMode mode) = 0; SetMode(SSLMode mode) = 0;
// Set maximum supported protocol version. The highest version supported by // Set maximum supported protocol version. The highest version supported by
// both ends will be used for the connection, i.e. if one party supports // both ends will be used for the connection, i.e. if one party supports

View File

@ -42,7 +42,6 @@ using ::testing::Values;
using ::testing::WithParamInterface; using ::testing::WithParamInterface;
using ::webrtc::SafeTask; using ::webrtc::SafeTask;
static const int kBlockSize = 4096;
static const char kExporterLabel[] = "label"; static const char kExporterLabel[] = "label";
static const unsigned char kExporterContext[] = "context"; static const unsigned char kExporterContext[] = "context";
static int kExporterContextLen = sizeof(kExporterContext); static int kExporterContextLen = sizeof(kExporterContext);
@ -354,7 +353,6 @@ class BufferQueueStream : public rtc::StreamInterface {
rtc::BufferQueue buffer_; rtc::BufferQueue buffer_;
}; };
static const int kFifoBufferSize = 4096;
static const int kBufferCapacity = 1; static const int kBufferCapacity = 1;
static const size_t kDefaultBufferSize = 2048; static const size_t kDefaultBufferSize = 2048;
@ -513,8 +511,6 @@ class SSLStreamAdapterTestBase : public ::testing::Test,
} }
void TestHandshake(bool expect_success = true) { void TestHandshake(bool expect_success = true) {
server_ssl_->SetMode(dtls_ ? rtc::SSL_MODE_DTLS : rtc::SSL_MODE_TLS);
client_ssl_->SetMode(dtls_ ? rtc::SSL_MODE_DTLS : rtc::SSL_MODE_TLS);
if (!dtls_) { if (!dtls_) {
// Make sure we simulate a reliable network for TLS. // Make sure we simulate a reliable network for TLS.
@ -554,8 +550,6 @@ class SSLStreamAdapterTestBase : public ::testing::Test,
rtc::ScopedFakeClock clock; rtc::ScopedFakeClock clock;
int64_t time_start = clock.TimeNanos(); int64_t time_start = clock.TimeNanos();
webrtc::TimeDelta time_increment = webrtc::TimeDelta::Millis(1000); webrtc::TimeDelta time_increment = webrtc::TimeDelta::Millis(1000);
server_ssl_->SetMode(dtls_ ? rtc::SSL_MODE_DTLS : rtc::SSL_MODE_TLS);
client_ssl_->SetMode(dtls_ ? rtc::SSL_MODE_DTLS : rtc::SSL_MODE_TLS);
if (!dtls_) { if (!dtls_) {
// Make sure we simulate a reliable network for TLS. // Make sure we simulate a reliable network for TLS.
@ -596,9 +590,6 @@ class SSLStreamAdapterTestBase : public ::testing::Test,
// and the identity will be verified after the fact. It also verifies that // and the identity will be verified after the fact. It also verifies that
// packets can't be read or written before the identity has been verified. // packets can't be read or written before the identity has been verified.
void TestHandshakeWithDelayedIdentity(bool valid_identity) { void TestHandshakeWithDelayedIdentity(bool valid_identity) {
server_ssl_->SetMode(dtls_ ? rtc::SSL_MODE_DTLS : rtc::SSL_MODE_TLS);
client_ssl_->SetMode(dtls_ ? rtc::SSL_MODE_DTLS : rtc::SSL_MODE_TLS);
if (!dtls_) { if (!dtls_) {
// Make sure we simulate a reliable network for TLS. // Make sure we simulate a reliable network for TLS.
// This is just a check to make sure that people don't write wrong // This is just a check to make sure that people don't write wrong
@ -848,132 +839,6 @@ class SSLStreamAdapterTestBase : public ::testing::Test,
bool identities_set_; bool identities_set_;
}; };
class SSLStreamAdapterTestTLS
: public SSLStreamAdapterTestBase,
public WithParamInterface<tuple<rtc::KeyParams, rtc::KeyParams>> {
public:
SSLStreamAdapterTestTLS()
: SSLStreamAdapterTestBase("",
"",
false,
::testing::get<0>(GetParam()),
::testing::get<1>(GetParam())) {}
std::unique_ptr<rtc::StreamInterface> CreateClientStream() override final {
return absl::WrapUnique(
new SSLDummyStream(this, "c2s", &client_buffer_, &server_buffer_));
}
std::unique_ptr<rtc::StreamInterface> CreateServerStream() override final {
return absl::WrapUnique(
new SSLDummyStream(this, "s2c", &server_buffer_, &client_buffer_));
}
// Test data transfer for TLS
void TestTransfer(int size) override {
RTC_LOG(LS_INFO) << "Starting transfer test with " << size << " bytes";
// Create some dummy data to send.
size_t received;
send_stream_.ReserveSize(size);
for (int i = 0; i < size; ++i) {
uint8_t ch = static_cast<uint8_t>(i);
size_t written;
int error;
send_stream_.Write(rtc::MakeArrayView(&ch, 1), written, error);
}
send_stream_.Rewind();
// Prepare the receive stream.
recv_stream_.ReserveSize(size);
// Start sending
WriteData();
// Wait for the client to close
EXPECT_TRUE_WAIT(server_ssl_->GetState() == rtc::SS_CLOSED, 10000);
// Now check the data
recv_stream_.GetSize(&received);
EXPECT_EQ(static_cast<size_t>(size), received);
EXPECT_EQ(0,
memcmp(send_stream_.GetBuffer(), recv_stream_.GetBuffer(), size));
}
void WriteData() override {
size_t position, tosend, size;
rtc::StreamResult rv;
size_t sent;
uint8_t block[kBlockSize];
send_stream_.GetSize(&size);
if (!size)
return;
for (;;) {
send_stream_.GetPosition(&position);
int dummy_error;
if (send_stream_.Read(block, tosend, dummy_error) != rtc::SR_EOS) {
int error;
rv = client_ssl_->Write(rtc::MakeArrayView(block, tosend), sent, error);
if (rv == rtc::SR_SUCCESS) {
send_stream_.SetPosition(position + sent);
RTC_LOG(LS_VERBOSE) << "Sent: " << position + sent;
} else if (rv == rtc::SR_BLOCK) {
RTC_LOG(LS_VERBOSE) << "Blocked...";
send_stream_.SetPosition(position);
break;
} else {
ADD_FAILURE();
break;
}
} else {
// Now close
RTC_LOG(LS_INFO) << "Wrote " << position << " bytes. Closing";
client_ssl_->Close();
break;
}
}
}
void ReadData(rtc::StreamInterface* stream) override final {
uint8_t buffer[1600];
size_t bread;
int err2;
rtc::StreamResult r;
for (;;) {
r = stream->Read(buffer, bread, err2);
if (r == rtc::SR_ERROR || r == rtc::SR_EOS) {
// Unfortunately, errors are the way that the stream adapter
// signals close in OpenSSL.
stream->Close();
return;
}
if (r == rtc::SR_BLOCK)
break;
ASSERT_EQ(rtc::SR_SUCCESS, r);
RTC_LOG(LS_VERBOSE) << "Read " << bread;
size_t written;
int error;
recv_stream_.Write(rtc::MakeArrayView(buffer, bread), written, error);
}
}
private:
StreamWrapper client_buffer_{
std::make_unique<rtc::FifoBuffer>(kFifoBufferSize)};
StreamWrapper server_buffer_{
std::make_unique<rtc::FifoBuffer>(kFifoBufferSize)};
rtc::MemoryStream send_stream_;
rtc::MemoryStream recv_stream_;
};
class SSLStreamAdapterTestDTLSBase : public SSLStreamAdapterTestBase { class SSLStreamAdapterTestDTLSBase : public SSLStreamAdapterTestBase {
public: public:
SSLStreamAdapterTestDTLSBase(rtc::KeyParams param1, rtc::KeyParams param2) SSLStreamAdapterTestDTLSBase(rtc::KeyParams param1, rtc::KeyParams param2)
@ -1155,23 +1020,6 @@ class SSLStreamAdapterTestDTLSCertChain : public SSLStreamAdapterTestDTLS {
} }
}; };
// Basic tests: TLS
// Test that we can make a handshake work
TEST_P(SSLStreamAdapterTestTLS, TestTLSConnect) {
TestHandshake();
}
TEST_P(SSLStreamAdapterTestTLS, GetPeerCertChainWithOneCertificate) {
TestHandshake();
std::unique_ptr<rtc::SSLCertChain> cert_chain =
client_ssl_->GetPeerSSLCertChain();
ASSERT_NE(nullptr, cert_chain);
EXPECT_EQ(1u, cert_chain->GetSize());
EXPECT_EQ(cert_chain->Get(0).ToPEMString(),
server_identity()->certificate().ToPEMString());
}
TEST_F(SSLStreamAdapterTestDTLSCertChain, TwoCertHandshake) { TEST_F(SSLStreamAdapterTestDTLSCertChain, TwoCertHandshake) {
auto server_identity = rtc::SSLIdentity::CreateFromPEMChainStrings( auto server_identity = rtc::SSLIdentity::CreateFromPEMChainStrings(
kRSA_PRIVATE_KEY_PEM, std::string(kCERT_PEM) + kCACert); kRSA_PRIVATE_KEY_PEM, std::string(kCERT_PEM) + kCACert);
@ -1222,92 +1070,6 @@ TEST_F(SSLStreamAdapterTestDTLSCertChain, ThreeCertHandshake) {
#endif #endif
} }
// Test that closing the connection on one side updates the other side.
TEST_P(SSLStreamAdapterTestTLS, TestTLSClose) {
TestHandshake();
client_ssl_->Close();
EXPECT_EQ_WAIT(rtc::SS_CLOSED, server_ssl_->GetState(), handshake_wait_);
}
// Test transfer -- trivial
TEST_P(SSLStreamAdapterTestTLS, TestTLSTransfer) {
TestHandshake();
TestTransfer(100000);
}
// Test read-write after close.
TEST_P(SSLStreamAdapterTestTLS, ReadWriteAfterClose) {
TestHandshake();
TestTransfer(100000);
client_ssl_->Close();
rtc::StreamResult rv;
uint8_t block[kBlockSize];
size_t dummy;
int error;
// It's an error to write after closed.
rv = client_ssl_->Write(block, dummy, error);
ASSERT_EQ(rtc::SR_ERROR, rv);
// But after closed read gives you EOS.
rv = client_ssl_->Read(block, dummy, error);
ASSERT_EQ(rtc::SR_EOS, rv);
}
// Test a handshake with a bogus peer digest
TEST_P(SSLStreamAdapterTestTLS, TestTLSBogusDigest) {
SetPeerIdentitiesByDigest(false, true);
TestHandshake(false);
}
TEST_P(SSLStreamAdapterTestTLS, TestTLSDelayedIdentity) {
TestHandshakeWithDelayedIdentity(true);
}
TEST_P(SSLStreamAdapterTestTLS, TestTLSDelayedIdentityWithBogusDigest) {
TestHandshakeWithDelayedIdentity(false);
}
// Test that the correct error is returned when SetPeerCertificateDigest is
// called with an unknown algorithm.
TEST_P(SSLStreamAdapterTestTLS,
TestSetPeerCertificateDigestWithUnknownAlgorithm) {
unsigned char server_digest[20];
size_t server_digest_len;
bool rv;
rtc::SSLPeerCertificateDigestError err;
rv = server_identity()->certificate().ComputeDigest(
rtc::DIGEST_SHA_1, server_digest, 20, &server_digest_len);
ASSERT_TRUE(rv);
rv = client_ssl_->SetPeerCertificateDigest("unknown algorithm", server_digest,
server_digest_len, &err);
EXPECT_EQ(rtc::SSLPeerCertificateDigestError::UNKNOWN_ALGORITHM, err);
EXPECT_FALSE(rv);
}
// Test that the correct error is returned when SetPeerCertificateDigest is
// called with an invalid digest length.
TEST_P(SSLStreamAdapterTestTLS, TestSetPeerCertificateDigestWithInvalidLength) {
unsigned char server_digest[20];
size_t server_digest_len;
bool rv;
rtc::SSLPeerCertificateDigestError err;
rv = server_identity()->certificate().ComputeDigest(
rtc::DIGEST_SHA_1, server_digest, 20, &server_digest_len);
ASSERT_TRUE(rv);
rv = client_ssl_->SetPeerCertificateDigest(rtc::DIGEST_SHA_1, server_digest,
server_digest_len - 1, &err);
EXPECT_EQ(rtc::SSLPeerCertificateDigestError::INVALID_LENGTH, err);
EXPECT_FALSE(rv);
}
// Test moving a bunch of data
// Basic tests: DTLS // Basic tests: DTLS
// Test that we can make a handshake work // Test that we can make a handshake work
TEST_P(SSLStreamAdapterTestDTLS, TestDTLSConnect) { TEST_P(SSLStreamAdapterTestDTLS, TestDTLSConnect) {
@ -1626,15 +1388,6 @@ TEST_P(SSLStreamAdapterTestDTLS, TestGetSslCipherSuite) {
// The RSA keysizes here might look strange, why not include the RFC's size // The RSA keysizes here might look strange, why not include the RFC's size
// 2048?. The reason is test case slowness; testing two sizes to exercise // 2048?. The reason is test case slowness; testing two sizes to exercise
// parametrization is sufficient. // parametrization is sufficient.
INSTANTIATE_TEST_SUITE_P(
SSLStreamAdapterTestsTLS,
SSLStreamAdapterTestTLS,
Combine(Values(rtc::KeyParams::RSA(1024, 65537),
rtc::KeyParams::RSA(1152, 65537),
rtc::KeyParams::ECDSA(rtc::EC_NIST_P256)),
Values(rtc::KeyParams::RSA(1024, 65537),
rtc::KeyParams::RSA(1152, 65537),
rtc::KeyParams::ECDSA(rtc::EC_NIST_P256))));
INSTANTIATE_TEST_SUITE_P( INSTANTIATE_TEST_SUITE_P(
SSLStreamAdapterTestsDTLS, SSLStreamAdapterTestsDTLS,
SSLStreamAdapterTestDTLS, SSLStreamAdapterTestDTLS,