diff --git a/rtc_base/ssl_stream_adapter_unittest.cc b/rtc_base/ssl_stream_adapter_unittest.cc index def4c47c0d..338921824d 100644 --- a/rtc_base/ssl_stream_adapter_unittest.cc +++ b/rtc_base/ssl_stream_adapter_unittest.cc @@ -157,6 +157,7 @@ class SSLDummyStreamBase : public rtc::StreamInterface, rtc::StreamInterface* in, rtc::StreamInterface* out) : test_base_(test), side_(side), in_(in), out_(out), first_packet_(true) { + RTC_DCHECK_NE(in, out); in_->SignalEvent.connect(this, &SSLDummyStreamBase::OnEventIn); out_->SignalEvent.connect(this, &SSLDummyStreamBase::OnEventOut); } @@ -187,7 +188,7 @@ class SSLDummyStreamBase : public rtc::StreamInterface, int mask = (rtc::SE_READ | rtc::SE_CLOSE); if (sig & mask) { - RTC_LOG(LS_VERBOSE) << "SSLDummyStreamBase::OnEvent side=" << side_ + RTC_LOG(LS_VERBOSE) << "SSLDummyStreamBase::OnEventIn side=" << side_ << " sig=" << sig << " forwarding upward"; PostEvent(sig & mask, 0); } @@ -196,7 +197,7 @@ class SSLDummyStreamBase : public rtc::StreamInterface, // Catch writeability events on out and pass them up. void OnEventOut(rtc::StreamInterface* stream, int sig, int err) { if (sig & rtc::SE_WRITE) { - RTC_LOG(LS_VERBOSE) << "SSLDummyStreamBase::OnEvent side=" << side_ + RTC_LOG(LS_VERBOSE) << "SSLDummyStreamBase::OnEventOut side=" << side_ << " sig=" << sig << " forwarding upward"; PostEvent(sig & rtc::SE_WRITE, 0); @@ -327,8 +328,6 @@ class SSLStreamAdapterTestBase : public ::testing::Test, client_private_key_pem_(client_private_key_pem), client_key_type_(client_key_type), server_key_type_(server_key_type), - client_stream_(nullptr), - server_stream_(nullptr), delay_(0), mtu_(1460), loss_(0), @@ -347,16 +346,7 @@ class SSLStreamAdapterTestBase : public ::testing::Test, } void SetUp() override { - CreateStreams(); - - client_ssl_ = - rtc::SSLStreamAdapter::Create(absl::WrapUnique(client_stream_)); - server_ssl_ = - rtc::SSLStreamAdapter::Create(absl::WrapUnique(server_stream_)); - - // Set up the slots - client_ssl_->SignalEvent.connect(this, &SSLStreamAdapterTestBase::OnEvent); - server_ssl_->SignalEvent.connect(this, &SSLStreamAdapterTestBase::OnEvent); + InitializeClientAndServerStreams(); std::unique_ptr client_identity; if (!client_cert_pem_.empty() && !client_private_key_pem_.empty()) { @@ -376,21 +366,41 @@ class SSLStreamAdapterTestBase : public ::testing::Test, server_ssl_.reset(nullptr); } - virtual void CreateStreams() = 0; + virtual std::unique_ptr CreateClientStream() = 0; + virtual std::unique_ptr CreateServerStream() = 0; + + void InitializeClientAndServerStreams( + absl::string_view client_experiment = "", + absl::string_view server_experiment = "") { + // Note: `client_ssl_` and `server_ssl_` may be non-nullptr. + + // The legacy TLS protocols flag is read when the OpenSSLStreamAdapter is + // initialized, so we set the field trials while constructing the adapters. + using webrtc::test::ScopedFieldTrials; + { + std::unique_ptr trial( + client_experiment.empty() ? nullptr + : new ScopedFieldTrials(client_experiment)); + client_ssl_ = rtc::SSLStreamAdapter::Create(CreateClientStream()); + } + { + std::unique_ptr trial( + server_experiment.empty() ? nullptr + : new ScopedFieldTrials(server_experiment)); + server_ssl_ = rtc::SSLStreamAdapter::Create(CreateServerStream()); + } + + client_ssl_->SignalEvent.connect(this, + &SSLStreamAdapterTestBase::OnClientEvent); + server_ssl_->SignalEvent.connect(this, + &SSLStreamAdapterTestBase::OnServerEvent); + } // Recreate the client/server identities with the specified validity period. // `not_before` and `not_after` are offsets from the current time in number // of seconds. void ResetIdentitiesWithValidity(int not_before, int not_after) { - CreateStreams(); - - client_ssl_ = - rtc::SSLStreamAdapter::Create(absl::WrapUnique(client_stream_)); - server_ssl_ = - rtc::SSLStreamAdapter::Create(absl::WrapUnique(server_stream_)); - - client_ssl_->SignalEvent.connect(this, &SSLStreamAdapterTestBase::OnEvent); - server_ssl_->SignalEvent.connect(this, &SSLStreamAdapterTestBase::OnEvent); + InitializeClientAndServerStreams(); time_t now = time(nullptr); @@ -412,18 +422,6 @@ class SSLStreamAdapterTestBase : public ::testing::Test, server_ssl_->SetIdentity(std::move(server_identity)); } - virtual void OnEvent(rtc::StreamInterface* stream, int sig, int err) { - RTC_LOG(LS_VERBOSE) << "SSLStreamAdapterTestBase::OnEvent sig=" << sig; - - if (sig & rtc::SE_READ) { - ReadData(stream); - } - - if ((stream == client_ssl_.get()) && (sig & rtc::SE_WRITE)) { - WriteData(); - } - } - void SetPeerIdentitiesByDigest(bool correct, bool expect_success) { unsigned char server_digest[20]; size_t server_digest_len; @@ -755,6 +753,30 @@ class SSLStreamAdapterTestBase : public ::testing::Test, virtual void ReadData(rtc::StreamInterface* stream) = 0; virtual void TestTransfer(int size) = 0; + private: + void OnClientEvent(rtc::StreamInterface* stream, int sig, int err) { + RTC_DCHECK_EQ(stream, client_ssl_.get()); + RTC_LOG(LS_VERBOSE) << "SSLStreamAdapterTestBase::OnClientEvent sig=" + << sig; + + if (sig & rtc::SE_READ) { + ReadData(stream); + } + + if (sig & rtc::SE_WRITE) { + WriteData(); + } + } + + void OnServerEvent(rtc::StreamInterface* stream, int sig, int err) { + RTC_DCHECK_EQ(stream, server_ssl_.get()); + RTC_LOG(LS_VERBOSE) << "SSLStreamAdapterTestBase::OnServerEvent sig=" + << sig; + if (sig & rtc::SE_READ) { + ReadData(stream); + } + } + protected: rtc::SSLIdentity* client_identity() const { if (!client_ssl_) { @@ -774,8 +796,6 @@ class SSLStreamAdapterTestBase : public ::testing::Test, std::string client_private_key_pem_; rtc::KeyParams client_key_type_; rtc::KeyParams server_key_type_; - SSLDummyStreamBase* client_stream_; // freed by client_ssl_ destructor - SSLDummyStreamBase* server_stream_; // freed by server_ssl_ destructor std::unique_ptr client_ssl_; std::unique_ptr server_ssl_; int delay_; @@ -801,11 +821,14 @@ class SSLStreamAdapterTestTLS client_buffer_(kFifoBufferSize), server_buffer_(kFifoBufferSize) {} - void CreateStreams() override { - client_stream_ = - new SSLDummyStreamTLS(this, "c2s", &client_buffer_, &server_buffer_); - server_stream_ = - new SSLDummyStreamTLS(this, "s2c", &server_buffer_, &client_buffer_); + std::unique_ptr CreateClientStream() override final { + return absl::WrapUnique( + new SSLDummyStreamTLS(this, "c2s", &client_buffer_, &server_buffer_)); + } + + std::unique_ptr CreateServerStream() override final { + return absl::WrapUnique( + new SSLDummyStreamTLS(this, "s2c", &server_buffer_, &client_buffer_)); } // Test data transfer for TLS @@ -877,7 +900,7 @@ class SSLStreamAdapterTestTLS } } - void ReadData(rtc::StreamInterface* stream) override { + void ReadData(rtc::StreamInterface* stream) override final { uint8_t buffer[1600]; size_t bread; int err2; @@ -930,11 +953,14 @@ class SSLStreamAdapterTestDTLSBase : public SSLStreamAdapterTestBase { count_(0), sent_(0) {} - void CreateStreams() override { - client_stream_ = - new SSLDummyStreamDTLS(this, "c2s", &client_buffer_, &server_buffer_); - server_stream_ = - new SSLDummyStreamDTLS(this, "s2c", &server_buffer_, &client_buffer_); + std::unique_ptr CreateClientStream() override final { + return absl::WrapUnique( + new SSLDummyStreamDTLS(this, "c2s", &client_buffer_, &server_buffer_)); + } + + std::unique_ptr CreateServerStream() override final { + return absl::WrapUnique( + new SSLDummyStreamDTLS(this, "s2c", &server_buffer_, &client_buffer_)); } void WriteData() override { @@ -968,7 +994,7 @@ class SSLStreamAdapterTestDTLSBase : public SSLStreamAdapterTestBase { delete[] packet; } - void ReadData(rtc::StreamInterface* stream) override { + void ReadData(rtc::StreamInterface* stream) override final { uint8_t buffer[2000]; size_t bread; int err2; @@ -1077,20 +1103,7 @@ class SSLStreamAdapterTestDTLSCertChain : public SSLStreamAdapterTestDTLS { public: SSLStreamAdapterTestDTLSCertChain() : SSLStreamAdapterTestDTLS("", "") {} void SetUp() override { - CreateStreams(); - - client_ssl_ = - rtc::SSLStreamAdapter::Create(absl::WrapUnique(client_stream_)); - server_ssl_ = - rtc::SSLStreamAdapter::Create(absl::WrapUnique(server_stream_)); - - // Set up the slots - client_ssl_->SignalEvent.connect( - reinterpret_cast(this), - &SSLStreamAdapterTestBase::OnEvent); - server_ssl_->SignalEvent.connect( - reinterpret_cast(this), - &SSLStreamAdapterTestBase::OnEvent); + InitializeClientAndServerStreams(); std::unique_ptr client_identity; if (!client_cert_pem_.empty() && !client_private_key_pem_.empty()) { @@ -1625,35 +1638,11 @@ class SSLStreamAdapterTestDTLSExtensionPermutation rtc::KeyParams::ECDSA(rtc::EC_NIST_P256)) { } - // Do not use the SetUp version from the parent class. - void SetUp() override {} - - // The legacy TLS protocols flag is read when the OpenSSLStreamAdapter is - // initialized, so we set the experiment while creationg client_ssl_ - // and server_ssl_. - - void ConfigureClient(absl::string_view experiment) { - webrtc::test::ScopedFieldTrials trial{std::string(experiment)}; - client_stream_ = - new SSLDummyStreamDTLS(this, "c2s", &client_buffer_, &server_buffer_); - client_ssl_ = - rtc::SSLStreamAdapter::Create(absl::WrapUnique(client_stream_)); - client_ssl_->SignalEvent.connect( - static_cast(this), - &SSLStreamAdapterTestBase::OnEvent); - auto client_identity = rtc::SSLIdentity::Create("client", client_key_type_); - client_ssl_->SetIdentity(std::move(client_identity)); - } - - void ConfigureServer(absl::string_view experiment) { - webrtc::test::ScopedFieldTrials trial{std::string(experiment)}; - server_stream_ = - new SSLDummyStreamDTLS(this, "s2c", &server_buffer_, &client_buffer_); - server_ssl_ = - rtc::SSLStreamAdapter::Create(absl::WrapUnique(server_stream_)); - server_ssl_->SignalEvent.connect( - static_cast(this), - &SSLStreamAdapterTestBase::OnEvent); + void Initialize(absl::string_view client_experiment, + absl::string_view server_experiment) { + InitializeClientAndServerStreams(client_experiment, server_experiment); + client_ssl_->SetIdentity( + rtc::SSLIdentity::Create("client", client_key_type_)); server_ssl_->SetIdentity( rtc::SSLIdentity::Create("server", server_key_type_)); } @@ -1661,29 +1650,26 @@ class SSLStreamAdapterTestDTLSExtensionPermutation TEST_F(SSLStreamAdapterTestDTLSExtensionPermutation, ClientDefaultServerDefault) { - ConfigureClient(""); - ConfigureServer(""); + Initialize("", ""); TestHandshake(); } TEST_F(SSLStreamAdapterTestDTLSExtensionPermutation, ClientDefaultServerPermute) { - ConfigureClient(""); - ConfigureServer("WebRTC-PermuteTlsClientHello/Enabled/"); + Initialize("", "WebRTC-PermuteTlsClientHello/Enabled/"); TestHandshake(); } TEST_F(SSLStreamAdapterTestDTLSExtensionPermutation, ClientPermuteServerDefault) { - ConfigureClient("WebRTC-PermuteTlsClientHello/Enabled/"); - ConfigureServer(""); + Initialize("WebRTC-PermuteTlsClientHello/Enabled/", ""); TestHandshake(); } TEST_F(SSLStreamAdapterTestDTLSExtensionPermutation, ClientPermuteServerPermute) { - ConfigureClient("WebRTC-PermuteTlsClientHello/Enabled/"); - ConfigureServer("WebRTC-PermuteTlsClientHello/Enabled/"); + Initialize("WebRTC-PermuteTlsClientHello/Enabled/", + "WebRTC-PermuteTlsClientHello/Enabled/"); TestHandshake(); } #endif // OPENSSL_IS_BORINGSSL