diff --git a/webrtc/base/opensslstreamadapter.cc b/webrtc/base/opensslstreamadapter.cc index 4b40c38904..6a42003968 100644 --- a/webrtc/base/opensslstreamadapter.cc +++ b/webrtc/base/opensslstreamadapter.cc @@ -25,6 +25,7 @@ #include #include +#include "webrtc/base/checks.h" #include "webrtc/base/common.h" #include "webrtc/base/logging.h" #include "webrtc/base/safe_conversions.h" @@ -290,11 +291,11 @@ OpenSSLStreamAdapter::OpenSSLStreamAdapter(StreamInterface* stream) ssl_max_version_(SSL_PROTOCOL_TLS_12) {} OpenSSLStreamAdapter::~OpenSSLStreamAdapter() { - Cleanup(); + Cleanup(0); } void OpenSSLStreamAdapter::SetIdentity(SSLIdentity* identity) { - ASSERT(!identity_); + RTC_DCHECK(!identity_); identity_.reset(static_cast(identity)); } @@ -309,25 +310,56 @@ std::unique_ptr OpenSSLStreamAdapter::GetPeerCertificate() : nullptr; } -bool OpenSSLStreamAdapter::SetPeerCertificateDigest(const std::string - &digest_alg, - const unsigned char* - digest_val, - size_t digest_len) { - ASSERT(!peer_certificate_); - ASSERT(peer_certificate_digest_algorithm_.size() == 0); +bool OpenSSLStreamAdapter::SetPeerCertificateDigest( + const std::string& digest_alg, + const unsigned char* digest_val, + size_t digest_len, + SSLPeerCertificateDigestError* error) { + RTC_DCHECK(!peer_certificate_verified_); + RTC_DCHECK(!has_peer_certificate_digest()); size_t expected_len; + if (error) { + *error = SSLPeerCertificateDigestError::NONE; + } if (!OpenSSLDigest::GetDigestSize(digest_alg, &expected_len)) { LOG(LS_WARNING) << "Unknown digest algorithm: " << digest_alg; + if (error) { + *error = SSLPeerCertificateDigestError::UNKNOWN_ALGORITHM; + } return false; } - if (expected_len != digest_len) + if (expected_len != digest_len) { + if (error) { + *error = SSLPeerCertificateDigestError::INVALID_LENGTH; + } return false; + } peer_certificate_digest_value_.SetData(digest_val, digest_len); peer_certificate_digest_algorithm_ = digest_alg; + if (!peer_certificate_) { + // Normal case, where the digest is set before we obtain the certificate + // from the handshake. + return true; + } + + if (!VerifyPeerCertificate()) { + Error("SetPeerCertificateDigest", -1, SSL_AD_BAD_CERTIFICATE, false); + if (error) { + *error = SSLPeerCertificateDigestError::VERIFICATION_FAILED; + } + return false; + } + + if (state_ == SSL_CONNECTED) { + // Post the event asynchronously to unwind the stack. The caller + // of ContinueSSL may be the same object listening for these + // events and may not be prepared for reentrancy. + PostEvent(SE_OPEN | SE_READ | SE_WRITE, 0); + } + return true; } @@ -450,7 +482,7 @@ bool OpenSSLStreamAdapter::SetDtlsSrtpCryptoSuites( bool OpenSSLStreamAdapter::GetDtlsSrtpCryptoSuite(int* crypto_suite) { #ifdef HAVE_DTLS_SRTP - ASSERT(state_ == SSL_CONNECTED); + RTC_DCHECK(state_ == SSL_CONNECTED); if (state_ != SSL_CONNECTED) return false; @@ -461,15 +493,22 @@ bool OpenSSLStreamAdapter::GetDtlsSrtpCryptoSuite(int* crypto_suite) { return false; *crypto_suite = srtp_profile->id; - ASSERT(!SrtpCryptoSuiteToName(*crypto_suite).empty()); + RTC_DCHECK(!SrtpCryptoSuiteToName(*crypto_suite).empty()); return true; #else return false; #endif } +bool OpenSSLStreamAdapter::IsTlsConnected() { + return state_ == SSL_CONNECTED; +} + int OpenSSLStreamAdapter::StartSSL() { - ASSERT(state_ == SSL_NONE); + if (state_ != SSL_NONE) { + // Don't allow StartSSL to be called twice. + return -1; + } if (StreamAdapterInterface::GetState() != SS_OPEN) { state_ = SSL_WAIT; @@ -478,7 +517,7 @@ int OpenSSLStreamAdapter::StartSSL() { state_ = SSL_CONNECTING; if (int err = BeginSSL()) { - Error("BeginSSL", err, false); + Error("BeginSSL", err, 0, false); return err; } @@ -486,12 +525,12 @@ int OpenSSLStreamAdapter::StartSSL() { } void OpenSSLStreamAdapter::SetMode(SSLMode mode) { - ASSERT(state_ == SSL_NONE); + RTC_DCHECK(state_ == SSL_NONE); ssl_mode_ = mode; } void OpenSSLStreamAdapter::SetMaxProtocolVersion(SSLProtocolVersion version) { - ASSERT(ssl_ctx_ == NULL); + RTC_DCHECK(ssl_ctx_ == NULL); ssl_max_version_ = version; } @@ -513,6 +552,9 @@ StreamResult OpenSSLStreamAdapter::Write(const void* data, size_t data_len, return SR_BLOCK; case SSL_CONNECTED: + if (waiting_to_verify_peer_certificate()) { + return SR_BLOCK; + } break; case SSL_ERROR: @@ -537,7 +579,7 @@ StreamResult OpenSSLStreamAdapter::Write(const void* data, size_t data_len, switch (ssl_error) { case SSL_ERROR_NONE: LOG(LS_VERBOSE) << " -- success"; - ASSERT(0 < code && static_cast(code) <= data_len); + RTC_DCHECK(0 < code && static_cast(code) <= data_len); if (written) *written = code; return SR_SUCCESS; @@ -551,7 +593,7 @@ StreamResult OpenSSLStreamAdapter::Write(const void* data, size_t data_len, case SSL_ERROR_ZERO_RETURN: default: - Error("SSL_write", (ssl_error ? ssl_error : -1), false); + Error("SSL_write", (ssl_error ? ssl_error : -1), 0, false); if (error) *error = ssl_error_code_; return SR_ERROR; @@ -572,6 +614,9 @@ StreamResult OpenSSLStreamAdapter::Read(void* data, size_t data_len, return SR_BLOCK; case SSL_CONNECTED: + if (waiting_to_verify_peer_certificate()) { + return SR_BLOCK; + } break; case SSL_CLOSED: @@ -598,7 +643,7 @@ StreamResult OpenSSLStreamAdapter::Read(void* data, size_t data_len, switch (ssl_error) { case SSL_ERROR_NONE: LOG(LS_VERBOSE) << " -- success"; - ASSERT(0 < code && static_cast(code) <= data_len); + RTC_DCHECK(0 < code && static_cast(code) <= data_len); if (read) *read = code; @@ -624,15 +669,12 @@ StreamResult OpenSSLStreamAdapter::Read(void* data, size_t data_len, return SR_BLOCK; case SSL_ERROR_ZERO_RETURN: LOG(LS_VERBOSE) << " -- remote side closed"; - // When we're closed at SSL layer, also close the stream level which - // performs necessary clean up. Otherwise, a new incoming packet after - // this could overflow the stream buffer. - this->stream()->Close(); + Close(); return SR_EOS; break; default: LOG(LS_VERBOSE) << " -- error " << code; - Error("SSL_read", (ssl_error ? ssl_error : -1), false); + Error("SSL_read", (ssl_error ? ssl_error : -1), 0, false); if (error) *error = ssl_error_code_; return SR_ERROR; @@ -649,11 +691,11 @@ void OpenSSLStreamAdapter::FlushInput(unsigned int left) { int code = SSL_read(ssl_, buf, toread); int ssl_error = SSL_get_error(ssl_, code); - ASSERT(ssl_error == SSL_ERROR_NONE); + RTC_DCHECK(ssl_error == SSL_ERROR_NONE); if (ssl_error != SSL_ERROR_NONE) { LOG(LS_VERBOSE) << " -- error " << code; - Error("SSL_read", (ssl_error ? ssl_error : -1), false); + Error("SSL_read", (ssl_error ? ssl_error : -1), 0, false); return; } @@ -663,8 +705,11 @@ void OpenSSLStreamAdapter::FlushInput(unsigned int left) { } void OpenSSLStreamAdapter::Close() { - Cleanup(); - ASSERT(state_ == SSL_CLOSED || state_ == SSL_ERROR); + Cleanup(0); + RTC_DCHECK(state_ == SSL_CLOSED || state_ == SSL_ERROR); + // When we're closed at SSL layer, also close the stream level which + // performs necessary clean up. Otherwise, a new incoming packet after + // this could overflow the stream buffer. StreamAdapterInterface::Close(); } @@ -674,6 +719,9 @@ StreamState OpenSSLStreamAdapter::GetState() const { case SSL_CONNECTING: return SS_OPENING; case SSL_CONNECTED: + if (waiting_to_verify_peer_certificate()) { + return SS_OPENING; + } return SS_OPEN; default: return SS_CLOSED; @@ -685,16 +733,16 @@ void OpenSSLStreamAdapter::OnEvent(StreamInterface* stream, int events, int err) { int events_to_signal = 0; int signal_error = 0; - ASSERT(stream == this->stream()); + RTC_DCHECK(stream == this->stream()); if ((events & SE_OPEN)) { LOG(LS_VERBOSE) << "OpenSSLStreamAdapter::OnEvent SE_OPEN"; if (state_ != SSL_WAIT) { - ASSERT(state_ == SSL_NONE); + RTC_DCHECK(state_ == SSL_NONE); events_to_signal |= SE_OPEN; } else { state_ = SSL_CONNECTING; if (int err = BeginSSL()) { - Error("BeginSSL", err, true); + Error("BeginSSL", err, 0, true); return; } } @@ -707,7 +755,7 @@ void OpenSSLStreamAdapter::OnEvent(StreamInterface* stream, int events, events_to_signal |= events & (SE_READ|SE_WRITE); } else if (state_ == SSL_CONNECTING) { if (int err = ContinueSSL()) { - Error("ContinueSSL", err, true); + Error("ContinueSSL", err, 0, true); return; } } else if (state_ == SSL_CONNECTED) { @@ -725,10 +773,10 @@ void OpenSSLStreamAdapter::OnEvent(StreamInterface* stream, int events, } if ((events & SE_CLOSE)) { LOG(LS_VERBOSE) << "OpenSSLStreamAdapter::OnEvent(SE_CLOSE, " << err << ")"; - Cleanup(); + Cleanup(0); events_to_signal |= SE_CLOSE; // SE_CLOSE is the only event that uses the final parameter to OnEvent(). - ASSERT(signal_error == 0); + RTC_DCHECK(signal_error == 0); signal_error = err; } if (events_to_signal) @@ -736,16 +784,14 @@ void OpenSSLStreamAdapter::OnEvent(StreamInterface* stream, int events, } int OpenSSLStreamAdapter::BeginSSL() { - ASSERT(state_ == SSL_CONNECTING); + RTC_DCHECK(state_ == SSL_CONNECTING); // The underlying stream has opened. - // A peer certificate digest must have been specified by now. - ASSERT(!peer_certificate_digest_algorithm_.empty()); LOG(LS_INFO) << "BeginSSL with peer."; BIO* bio = NULL; // First set up the context. - ASSERT(ssl_ctx_ == NULL); + RTC_DCHECK(ssl_ctx_ == NULL); ssl_ctx_ = SetupSSLContext(); if (!ssl_ctx_) return -1; @@ -799,7 +845,7 @@ int OpenSSLStreamAdapter::BeginSSL() { int OpenSSLStreamAdapter::ContinueSSL() { LOG(LS_VERBOSE) << "ContinueSSL"; - ASSERT(state_ == SSL_CONNECTING); + RTC_DCHECK(state_ == SSL_CONNECTING); // Clear the DTLS timer Thread::Current()->Clear(this, MSG_TIMEOUT); @@ -809,15 +855,21 @@ int OpenSSLStreamAdapter::ContinueSSL() { switch (ssl_error = SSL_get_error(ssl_, code)) { case SSL_ERROR_NONE: LOG(LS_VERBOSE) << " -- success"; - - if (!SSLPostConnectionCheck(ssl_, NULL, - peer_certificate_digest_algorithm_)) { - LOG(LS_ERROR) << "TLS post connection check failed"; - return -1; - } + // By this point, OpenSSL should have given us a certificate, or errored + // out if one was missing. + RTC_DCHECK(peer_certificate_ || !client_auth_enabled()); state_ = SSL_CONNECTED; - StreamAdapterInterface::OnEvent(stream(), SE_OPEN|SE_READ|SE_WRITE, 0); + if (!waiting_to_verify_peer_certificate()) { + // We have everything we need to start the connection, so signal + // SE_OPEN. If we need a client certificate fingerprint and don't have + // it yet, we'll instead signal SE_OPEN in SetPeerCertificateDigest. + // + // Post the event asynchronously to unwind the stack. The + // caller of ContinueSSL may be the same object listening + // for these events and may not be prepared for reentrancy. + PostEvent(SE_OPEN | SE_READ | SE_WRITE, 0); + } break; case SSL_ERROR_WANT_READ: { @@ -851,17 +903,20 @@ int OpenSSLStreamAdapter::ContinueSSL() { return 0; } -void OpenSSLStreamAdapter::Error(const char* context, int err, bool signal) { - LOG(LS_WARNING) << "OpenSSLStreamAdapter::Error(" - << context << ", " << err << ")"; +void OpenSSLStreamAdapter::Error(const char* context, + int err, + uint8_t alert, + bool signal) { + LOG(LS_WARNING) << "OpenSSLStreamAdapter::Error(" << context << ", " << err + << ", " << static_cast(alert) << ")"; state_ = SSL_ERROR; ssl_error_code_ = err; - Cleanup(); + Cleanup(alert); if (signal) StreamAdapterInterface::OnEvent(stream(), SE_CLOSE, err); } -void OpenSSLStreamAdapter::Cleanup() { +void OpenSSLStreamAdapter::Cleanup(uint8_t alert) { LOG(LS_INFO) << "Cleanup"; if (state_ != SSL_ERROR) { @@ -870,12 +925,25 @@ void OpenSSLStreamAdapter::Cleanup() { } if (ssl_) { - int ret = SSL_shutdown(ssl_); - if (ret < 0) { - LOG(LS_WARNING) << "SSL_shutdown failed, error = " - << SSL_get_error(ssl_, ret); + int ret; +// SSL_send_fatal_alert is only available in BoringSSL. +#ifdef OPENSSL_IS_BORINGSSL + if (alert) { + ret = SSL_send_fatal_alert(ssl_, alert); + if (ret < 0) { + LOG(LS_WARNING) << "SSL_send_fatal_alert failed, error = " + << SSL_get_error(ssl_, ret); + } + } else { +#endif + ret = SSL_shutdown(ssl_); + if (ret < 0) { + LOG(LS_WARNING) << "SSL_shutdown failed, error = " + << SSL_get_error(ssl_, ret); + } +#ifdef OPENSSL_IS_BORINGSSL } - +#endif SSL_free(ssl_); ssl_ = NULL; } @@ -1033,21 +1101,42 @@ SSL_CTX* OpenSSLStreamAdapter::SetupSSLContext() { return ctx; } -int OpenSSLStreamAdapter::SSLVerifyCallback(int ok, X509_STORE_CTX* store) { - // Get our SSL structure from the store - SSL* ssl = reinterpret_cast(X509_STORE_CTX_get_ex_data( - store, - SSL_get_ex_data_X509_STORE_CTX_idx())); - OpenSSLStreamAdapter* stream = - reinterpret_cast(SSL_get_app_data(ssl)); +bool OpenSSLStreamAdapter::VerifyPeerCertificate() { + if (!has_peer_certificate_digest() || !peer_certificate_) { + LOG(LS_WARNING) << "Missing digest or peer certificate."; + return false; + } - if (stream->peer_certificate_digest_algorithm_.empty()) { + unsigned char digest[EVP_MAX_MD_SIZE]; + size_t digest_length; + if (!OpenSSLCertificate::ComputeDigest( + peer_certificate_->x509(), peer_certificate_digest_algorithm_, digest, + sizeof(digest), &digest_length)) { + LOG(LS_WARNING) << "Failed to compute peer cert digest."; + return false; + } + + Buffer computed_digest(digest, digest_length); + if (computed_digest != peer_certificate_digest_value_) { + LOG(LS_WARNING) << "Rejected peer certificate due to mismatched digest."; return 0; } + // Ignore any verification error if the digest matches, since there is no + // value in checking the validity of a self-signed cert issued by untrusted + // sources. + LOG(LS_INFO) << "Accepted peer certificate."; + peer_certificate_verified_ = true; + return true; +} + +int OpenSSLStreamAdapter::SSLVerifyCallback(int ok, X509_STORE_CTX* store) { + // Get our SSL structure from the store + SSL* ssl = reinterpret_cast( + X509_STORE_CTX_get_ex_data(store, SSL_get_ex_data_X509_STORE_CTX_idx())); X509* cert = X509_STORE_CTX_get_current_cert(store); int depth = X509_STORE_CTX_get_error_depth(store); - // For now We ignore the parent certificates and verify the leaf against + // For now we ignore the parent certificates and verify the leaf against // the digest. // // TODO(jiayl): Verify the chain is a proper chain and report the chain to @@ -1057,38 +1146,20 @@ int OpenSSLStreamAdapter::SSLVerifyCallback(int ok, X509_STORE_CTX* store) { return 1; } - unsigned char digest[EVP_MAX_MD_SIZE]; - size_t digest_length; - if (!OpenSSLCertificate::ComputeDigest( - cert, - stream->peer_certificate_digest_algorithm_, - digest, sizeof(digest), - &digest_length)) { - LOG(LS_WARNING) << "Failed to compute peer cert digest."; - return 0; - } - - Buffer computed_digest(digest, digest_length); - if (computed_digest != stream->peer_certificate_digest_value_) { - LOG(LS_WARNING) << "Rejected peer certificate due to mismatched digest."; - return 0; - } - // Ignore any verification error if the digest matches, since there is no - // value in checking the validity of a self-signed cert issued by untrusted - // sources. - LOG(LS_INFO) << "Accepted peer certificate."; + OpenSSLStreamAdapter* stream = + reinterpret_cast(SSL_get_app_data(ssl)); // Record the peer's certificate. stream->peer_certificate_.reset(new OpenSSLCertificate(cert)); - return 1; -} -bool OpenSSLStreamAdapter::SSLPostConnectionCheck(SSL* ssl, - const X509* peer_cert, - const std::string - &peer_digest) { - ASSERT((peer_cert != NULL) || (!peer_digest.empty())); - return true; + // If the peer certificate digest isn't known yet, we'll wait to verify + // until it's known, and for now just return a success status. + if (stream->peer_certificate_digest_algorithm_.empty()) { + LOG(LS_INFO) << "Waiting to verify certificate until digest is known."; + return 1; + } + + return stream->VerifyPeerCertificate(); } bool OpenSSLStreamAdapter::HaveDtls() { diff --git a/webrtc/base/opensslstreamadapter.h b/webrtc/base/opensslstreamadapter.h index 76dbad24a0..76ff2183a0 100644 --- a/webrtc/base/opensslstreamadapter.h +++ b/webrtc/base/opensslstreamadapter.h @@ -63,9 +63,11 @@ class OpenSSLStreamAdapter : public SSLStreamAdapter { // Default argument is for compatibility void SetServerRole(SSLRole role = SSL_SERVER) override; - bool SetPeerCertificateDigest(const std::string& digest_alg, - const unsigned char* digest_val, - size_t digest_len) override; + bool SetPeerCertificateDigest( + const std::string& digest_alg, + const unsigned char* digest_val, + size_t digest_len, + SSLPeerCertificateDigestError* error = nullptr) override; std::unique_ptr GetPeerCertificate() const override; @@ -105,6 +107,8 @@ class OpenSSLStreamAdapter : public SSLStreamAdapter { bool SetDtlsSrtpCryptoSuites(const std::vector& crypto_suites) override; bool GetDtlsSrtpCryptoSuite(int* crypto_suite) override; + bool IsTlsConnected() override; + // Capabilities interfaces static bool HaveDtls(); static bool HaveDtlsSrtp(); @@ -147,8 +151,11 @@ class OpenSSLStreamAdapter : public SSLStreamAdapter { // raised on the stream with the specified error. // A 0 error means a graceful close, otherwise there is not really enough // context to interpret the error code. - void Error(const char* context, int err, bool signal); - void Cleanup(); + // |alert| indicates an alert description (one of the SSL_AD constants) to + // send to the remote endpoint when closing the association. If 0, a normal + // shutdown will be performed. + void Error(const char* context, int err, uint8_t alert, bool signal); + void Cleanup(uint8_t alert); // Override MessageHandler void OnMessage(Message* msg) override; @@ -158,16 +165,23 @@ class OpenSSLStreamAdapter : public SSLStreamAdapter { // SSL library configuration SSL_CTX* SetupSSLContext(); - // SSL verification check - bool SSLPostConnectionCheck(SSL* ssl, - const X509* peer_cert, - const std::string& peer_digest); + // Verify the peer certificate matches the signaled digest. + bool VerifyPeerCertificate(); // SSL certification verification error handler, called back from // the openssl library. Returns an int interpreted as a boolean in // the C style: zero means verification failure, non-zero means // passed. static int SSLVerifyCallback(int ok, X509_STORE_CTX* store); + bool waiting_to_verify_peer_certificate() const { + return client_auth_enabled() && !peer_certificate_verified_; + } + + bool has_peer_certificate_digest() const { + return !peer_certificate_digest_algorithm_.empty() && + !peer_certificate_digest_value_.empty(); + } + SSLState state_; SSLRole role_; int ssl_error_code_; // valid when state_ == SSL_ERROR or SSL_CLOSED @@ -184,6 +198,7 @@ class OpenSSLStreamAdapter : public SSLStreamAdapter { // The certificate that the peer presented. Initially null, until the // connection is established. std::unique_ptr peer_certificate_; + bool peer_certificate_verified_ = false; // The digest of the certificate that the peer must present. Buffer peer_certificate_digest_value_; std::string peer_certificate_digest_algorithm_; diff --git a/webrtc/base/sslstreamadapter.h b/webrtc/base/sslstreamadapter.h index a7ef23fc79..2b99f00cb3 100644 --- a/webrtc/base/sslstreamadapter.h +++ b/webrtc/base/sslstreamadapter.h @@ -106,6 +106,12 @@ enum SSLProtocolVersion { SSL_PROTOCOL_DTLS_10 = SSL_PROTOCOL_TLS_11, SSL_PROTOCOL_DTLS_12 = SSL_PROTOCOL_TLS_12, }; +enum class SSLPeerCertificateDigestError { + NONE, + UNKNOWN_ALGORITHM, + INVALID_LENGTH, + VERIFICATION_FAILED, +}; // Errors for Read -- in the high range so no conflict with OpenSSL. enum { SSE_MSG_TRUNC = 0xff0001 }; @@ -173,9 +179,14 @@ class SSLStreamAdapter : public StreamAdapterInterface { // certificate is assumed to have been obtained through some other secure // channel (such as the signaling channel). This must specify the terminal // certificate, not just a CA. SSLStream makes a copy of the digest value. - virtual bool SetPeerCertificateDigest(const std::string& digest_alg, - const unsigned char* digest_val, - size_t digest_len) = 0; + // + // Returns true if successful. + // |error| is optional and provides more information about the failure. + virtual bool SetPeerCertificateDigest( + const std::string& digest_alg, + const unsigned char* digest_val, + size_t digest_len, + SSLPeerCertificateDigestError* error = nullptr) = 0; // Retrieves the peer's X.509 certificate, if a connection has been // established. It returns the transmitted over SSL, including the entire @@ -211,6 +222,12 @@ class SSLStreamAdapter : public StreamAdapterInterface { virtual bool SetDtlsSrtpCryptoSuites(const std::vector& crypto_suites); virtual bool GetDtlsSrtpCryptoSuite(int* crypto_suite); + // Returns true if a TLS connection has been established. + // The only difference between this and "GetState() == SE_OPEN" is that if + // the peer certificate digest hasn't been verified, the state will still be + // SS_OPENING but IsTlsConnected should return true. + virtual bool IsTlsConnected() = 0; + // Capabilities testing static bool HaveDtls(); static bool HaveDtlsSrtp(); diff --git a/webrtc/base/sslstreamadapter_unittest.cc b/webrtc/base/sslstreamadapter_unittest.cc index 341e09ff1e..9e156c0b3c 100644 --- a/webrtc/base/sslstreamadapter_unittest.cc +++ b/webrtc/base/sslstreamadapter_unittest.cc @@ -325,36 +325,44 @@ class SSLStreamAdapterTestBase : public testing::Test, } } - void SetPeerIdentitiesByDigest(bool correct) { - unsigned char digest[20]; - size_t digest_len; + void SetPeerIdentitiesByDigest(bool correct, bool expect_success) { + unsigned char server_digest[20]; + size_t server_digest_len; + unsigned char client_digest[20]; + size_t client_digest_len; bool rv; + rtc::SSLPeerCertificateDigestError err; + rtc::SSLPeerCertificateDigestError expected_err = + expect_success + ? rtc::SSLPeerCertificateDigestError::NONE + : rtc::SSLPeerCertificateDigestError::VERIFICATION_FAILED; LOG(LS_INFO) << "Setting peer identities by digest"; - rv = server_identity_->certificate().ComputeDigest(rtc::DIGEST_SHA_1, - digest, 20, - &digest_len); + rv = server_identity_->certificate().ComputeDigest( + rtc::DIGEST_SHA_1, server_digest, 20, &server_digest_len); ASSERT_TRUE(rv); + rv = client_identity_->certificate().ComputeDigest( + rtc::DIGEST_SHA_1, client_digest, 20, &client_digest_len); + ASSERT_TRUE(rv); + if (!correct) { LOG(LS_INFO) << "Setting bogus digest for server cert"; - digest[0]++; + server_digest[0]++; } - rv = client_ssl_->SetPeerCertificateDigest(rtc::DIGEST_SHA_1, digest, - digest_len); - ASSERT_TRUE(rv); + rv = client_ssl_->SetPeerCertificateDigest(rtc::DIGEST_SHA_1, server_digest, + server_digest_len, &err); + EXPECT_EQ(expected_err, err); + EXPECT_EQ(expect_success, rv); - - rv = client_identity_->certificate().ComputeDigest(rtc::DIGEST_SHA_1, - digest, 20, &digest_len); - ASSERT_TRUE(rv); if (!correct) { LOG(LS_INFO) << "Setting bogus digest for client cert"; - digest[0]++; + client_digest[0]++; } - rv = server_ssl_->SetPeerCertificateDigest(rtc::DIGEST_SHA_1, digest, - digest_len); - ASSERT_TRUE(rv); + rv = server_ssl_->SetPeerCertificateDigest(rtc::DIGEST_SHA_1, client_digest, + client_digest_len, &err); + EXPECT_EQ(expected_err, err); + EXPECT_EQ(expect_success, rv); identities_set_ = true; } @@ -379,7 +387,7 @@ class SSLStreamAdapterTestBase : public testing::Test, } if (!identities_set_) - SetPeerIdentitiesByDigest(true); + SetPeerIdentitiesByDigest(true, true); // Start the handshake int rv; @@ -402,6 +410,57 @@ class SSLStreamAdapterTestBase : public testing::Test, } } + // This tests that the handshake can complete before the identity is + // verified, and the identity will be verified after the fact. + void TestHandshakeWithDelayedIdentity(bool valid_identity) { + server_ssl_->SetMode(dtls_ ? rtc::SSL_MODE_DTLS : rtc::SSL_MODE_TLS); + client_ssl_->SetMode(dtls_ ? rtc::SSL_MODE_DTLS : rtc::SSL_MODE_TLS); + + if (!dtls_) { + // Make sure we simulate a reliable network for TLS. + // This is just a check to make sure that people don't write wrong + // tests. + ASSERT((mtu_ == 1460) && (loss_ == 0) && (lose_first_packet_ == 0)); + } + + // Start the handshake + int rv; + + server_ssl_->SetServerRole(); + rv = server_ssl_->StartSSL(); + ASSERT_EQ(0, rv); + + rv = client_ssl_->StartSSL(); + ASSERT_EQ(0, rv); + + // Now run the handshake. + EXPECT_TRUE_WAIT( + client_ssl_->IsTlsConnected() && server_ssl_->IsTlsConnected(), + handshake_wait_); + + // Until the identity has been verified, the state should still be + // SS_OPENING and writes should return SR_BLOCK. + EXPECT_EQ(rtc::SS_OPENING, client_ssl_->GetState()); + EXPECT_EQ(rtc::SS_OPENING, server_ssl_->GetState()); + unsigned char packet[1]; + size_t sent; + EXPECT_EQ(rtc::SR_BLOCK, client_ssl_->Write(&packet, 1, &sent, 0)); + EXPECT_EQ(rtc::SR_BLOCK, server_ssl_->Write(&packet, 1, &sent, 0)); + + // If we set an invalid identity at this point, SetPeerCertificateDigest + // should return false. + SetPeerIdentitiesByDigest(valid_identity, valid_identity); + // State should then transition to SS_OPEN or SS_CLOSED based on validation + // of the identity. + if (valid_identity) { + EXPECT_EQ(rtc::SS_OPEN, client_ssl_->GetState()); + EXPECT_EQ(rtc::SS_OPEN, server_ssl_->GetState()); + } else { + EXPECT_EQ(rtc::SS_CLOSED, client_ssl_->GetState()); + EXPECT_EQ(rtc::SS_CLOSED, server_ssl_->GetState()); + } + } + rtc::StreamResult DataWritten(SSLDummyStreamBase *from, const void *data, size_t data_len, size_t *written, int *error) { @@ -849,10 +908,55 @@ TEST_P(SSLStreamAdapterTestTLS, ReadWriteAfterClose) { // Test a handshake with a bogus peer digest TEST_P(SSLStreamAdapterTestTLS, TestTLSBogusDigest) { - SetPeerIdentitiesByDigest(false); + SetPeerIdentitiesByDigest(false, true); TestHandshake(false); }; +TEST_P(SSLStreamAdapterTestTLS, TestTLSDelayedIdentity) { + TestHandshakeWithDelayedIdentity(true); +}; + +TEST_P(SSLStreamAdapterTestTLS, TestTLSDelayedIdentityWithBogusDigest) { + TestHandshakeWithDelayedIdentity(false); +}; + +// Test that the correct error is returned when SetPeerCertificateDigest is +// called with an unknown algorithm. +TEST_P(SSLStreamAdapterTestTLS, + TestSetPeerCertificateDigestWithUnknownAlgorithm) { + unsigned char server_digest[20]; + size_t server_digest_len; + bool rv; + rtc::SSLPeerCertificateDigestError err; + + rv = server_identity_->certificate().ComputeDigest( + rtc::DIGEST_SHA_1, server_digest, 20, &server_digest_len); + ASSERT_TRUE(rv); + + rv = client_ssl_->SetPeerCertificateDigest("unknown algorithm", server_digest, + server_digest_len, &err); + EXPECT_EQ(rtc::SSLPeerCertificateDigestError::UNKNOWN_ALGORITHM, err); + EXPECT_FALSE(rv); +} + +// Test that the correct error is returned when SetPeerCertificateDigest is +// called with an invalid digest length. +TEST_P(SSLStreamAdapterTestTLS, TestSetPeerCertificateDigestWithInvalidLength) { + unsigned char server_digest[20]; + size_t server_digest_len; + bool rv; + rtc::SSLPeerCertificateDigestError err; + + rv = server_identity_->certificate().ComputeDigest( + rtc::DIGEST_SHA_1, server_digest, 20, &server_digest_len); + ASSERT_TRUE(rv); + + rv = client_ssl_->SetPeerCertificateDigest(rtc::DIGEST_SHA_1, server_digest, + server_digest_len - 1, &err); + EXPECT_EQ(rtc::SSLPeerCertificateDigestError::INVALID_LENGTH, err); + EXPECT_FALSE(rv); +} + // Test moving a bunch of data // Basic tests: DTLS @@ -911,6 +1015,14 @@ TEST_P(SSLStreamAdapterTestDTLS, TestDTLSTransferWithDamage) { TestTransfer(100); }; +TEST_P(SSLStreamAdapterTestDTLS, TestDTLSDelayedIdentity) { + TestHandshakeWithDelayedIdentity(true); +}; + +TEST_P(SSLStreamAdapterTestDTLS, TestDTLSDelayedIdentityWithBogusDigest) { + TestHandshakeWithDelayedIdentity(false); +}; + // Test DTLS-SRTP with all high ciphers TEST_P(SSLStreamAdapterTestDTLS, TestDTLSSrtpHigh) { MAYBE_SKIP_TEST(HaveDtlsSrtp); diff --git a/webrtc/p2p/base/dtlstransportchannel.cc b/webrtc/p2p/base/dtlstransportchannel.cc index 3dde8fea05..d95bdcd916 100644 --- a/webrtc/p2p/base/dtlstransportchannel.cc +++ b/webrtc/p2p/base/dtlstransportchannel.cc @@ -179,7 +179,7 @@ bool DtlsTransportChannelWrapper::SetSslMaxProtocolVersion( } bool DtlsTransportChannelWrapper::SetSslRole(rtc::SSLRole role) { - if (dtls_state() == DTLS_TRANSPORT_CONNECTED) { + if (dtls_) { if (ssl_role_ != role) { LOG(LS_ERROR) << "SSL Role can't be reversed after the session is setup."; return false; @@ -235,12 +235,33 @@ bool DtlsTransportChannelWrapper::SetRemoteFingerprint( } // At this point we know we are doing DTLS + bool fingerprint_changing = remote_fingerprint_value_.size() > 0u; remote_fingerprint_value_ = std::move(remote_fingerprint_value); remote_fingerprint_algorithm_ = digest_alg; - if (dtls_) { - // If the fingerprint is changing, we'll tear down the DTLS association and - // create a new one, resetting our state. + if (dtls_ && !fingerprint_changing) { + // This can occur if DTLS is set up before a remote fingerprint is + // received. For instance, if we set up DTLS due to receiving an early + // ClientHello. + rtc::SSLPeerCertificateDigestError err; + if (!dtls_->SetPeerCertificateDigest( + remote_fingerprint_algorithm_, + reinterpret_cast(remote_fingerprint_value_.data()), + remote_fingerprint_value_.size(), &err)) { + LOG_J(LS_ERROR, this) << "Couldn't set DTLS certificate digest."; + set_dtls_state(DTLS_TRANSPORT_FAILED); + // If the error is "verification failed", don't return false, because + // this means the fingerprint was formatted correctly but didn't match + // the certificate from the DTLS handshake. Thus the DTLS state should go + // to "failed", but SetRemoteDescription shouldn't fail. + return err == rtc::SSLPeerCertificateDigestError::VERIFICATION_FAILED; + } + return true; + } + + // If the fingerprint is changing, we'll tear down the DTLS association and + // create a new one, resetting our state. + if (dtls_ && fingerprint_changing) { dtls_.reset(nullptr); set_dtls_state(DTLS_TRANSPORT_NEW); set_writable(false); @@ -282,7 +303,8 @@ bool DtlsTransportChannelWrapper::SetupDtls() { dtls_->SignalEvent.connect(this, &DtlsTransportChannelWrapper::OnDtlsEvent); dtls_->SignalSSLHandshakeError.connect( this, &DtlsTransportChannelWrapper::OnDtlsHandshakeError); - if (!dtls_->SetPeerCertificateDigest( + if (remote_fingerprint_value_.size() && + !dtls_->SetPeerCertificateDigest( remote_fingerprint_algorithm_, reinterpret_cast(remote_fingerprint_value_.data()), remote_fingerprint_value_.size())) { @@ -401,6 +423,10 @@ int DtlsTransportChannelWrapper::SendPacket( } } +bool DtlsTransportChannelWrapper::IsDtlsConnected() { + return dtls_ && dtls_->IsTlsConnected(); +} + // The state transition logic here is as follows: // (1) If we're not doing DTLS-SRTP, then the state is just the // state of the underlying impl() @@ -481,6 +507,14 @@ void DtlsTransportChannelWrapper::OnReadPacket( LOG_J(LS_INFO, this) << "Caching DTLS ClientHello packet until DTLS is " << "started."; cached_client_hello_.SetData(data, size); + // If we haven't started setting up DTLS yet (because we don't have a + // remote fingerprint/role), we can use the client hello as a clue that + // the peer has chosen the client role, and proceed with the handshake. + // The fingerprint will be verified when it's set. + if (!dtls_ && local_certificate_) { + SetSslRole(rtc::SSL_SERVER); + SetupDtls(); + } } else { LOG_J(LS_INFO, this) << "Not a DTLS ClientHello packet; dropping."; } @@ -554,8 +588,20 @@ void DtlsTransportChannelWrapper::OnDtlsEvent(rtc::StreamInterface* dtls, if (sig & rtc::SE_READ) { char buf[kMaxDtlsPacketLen]; size_t read; - if (dtls_->Read(buf, sizeof(buf), &read, NULL) == rtc::SR_SUCCESS) { + int read_error; + rtc::StreamResult ret = dtls_->Read(buf, sizeof(buf), &read, &read_error); + if (ret == rtc::SR_SUCCESS) { SignalReadPacket(this, buf, read, rtc::CreatePacketTime(0), 0); + } else if (ret == rtc::SR_EOS) { + // Remote peer shut down the association with no error. + LOG_J(LS_INFO, this) << "DTLS channel closed"; + set_writable(false); + set_dtls_state(DTLS_TRANSPORT_CLOSED); + } else if (ret == rtc::SR_ERROR) { + // Remote peer shut down the association with an error. + LOG_J(LS_INFO, this) << "DTLS channel error, code=" << read_error; + set_writable(false); + set_dtls_state(DTLS_TRANSPORT_FAILED); } } if (sig & rtc::SE_CLOSE) { diff --git a/webrtc/p2p/base/dtlstransportchannel.h b/webrtc/p2p/base/dtlstransportchannel.h index 19823f8c34..a07c605e05 100644 --- a/webrtc/p2p/base/dtlstransportchannel.h +++ b/webrtc/p2p/base/dtlstransportchannel.h @@ -193,8 +193,12 @@ class DtlsTransportChannelWrapper : public TransportChannelImpl { // Needed by DtlsTransport. TransportChannelImpl* channel() { return channel_; } + // For informational purposes. Tells if the DTLS handshake has finished. + // This may be true even if writable() is false, if the remote fingerprint + // has not yet been verified. + bool IsDtlsConnected(); + private: - void OnReadableState(TransportChannel* channel); void OnWritableState(TransportChannel* channel); void OnReadPacket(TransportChannel* channel, const char* data, size_t size, const rtc::PacketTime& packet_time, int flags); diff --git a/webrtc/p2p/base/dtlstransportchannel_unittest.cc b/webrtc/p2p/base/dtlstransportchannel_unittest.cc index 6eb0f0e3f1..0a6e254e8f 100644 --- a/webrtc/p2p/base/dtlstransportchannel_unittest.cc +++ b/webrtc/p2p/base/dtlstransportchannel_unittest.cc @@ -81,10 +81,11 @@ class DtlsTestClient : public sigslot::has_slots<> { ASSERT(!transport_); ssl_max_version_ = version; } - void SetupChannels(int count, cricket::IceRole role) { + void SetupChannels(int count, cricket::IceRole role, int async_delay_ms = 0) { transport_.reset(new cricket::DtlsTransport( "dtls content name", nullptr, certificate_)); transport_->SetAsync(true); + transport_->SetAsyncDelay(async_delay_ms); transport_->SetIceRole(role); transport_->SetIceTiebreaker( (role == cricket::ICEROLE_CONTROLLING) ? 1 : 2); @@ -119,6 +120,11 @@ class DtlsTestClient : public sigslot::has_slots<> { static_cast(wrapper->channel()) : NULL; } + cricket::DtlsTransportChannelWrapper* GetDtlsChannel(int component) { + cricket::TransportChannelImpl* ch = transport_->GetChannel(component); + return static_cast(ch); + } + // Offer DTLS if we have an identity; pass in a remote fingerprint only if // both sides support DTLS. void Negotiate(DtlsTestClient* peer, cricket::ContentAction action, @@ -152,7 +158,6 @@ class DtlsTestClient : public sigslot::has_slots<> { EXPECT_EQ(expect_success, transport_->SetLocalTransportDescription( MakeTransportDescription(cert, role), action, nullptr)); - set_local_cert_ = (cert != nullptr); } void SetRemoteTransportDescription( @@ -167,7 +172,6 @@ class DtlsTestClient : public sigslot::has_slots<> { EXPECT_EQ(expect_success, transport_->SetRemoteTransportDescription( MakeTransportDescription(cert, role), action, nullptr)); - set_remote_cert_ = (cert != nullptr); } // Allow any DTLS configuration to be specified (including invalid ones). @@ -229,7 +233,16 @@ class DtlsTestClient : public sigslot::has_slots<> { return received_dtls_client_hellos_; } - bool negotiated_dtls() const { return set_local_cert_ && set_remote_cert_; } + int received_dtls_server_hellos() const { + return received_dtls_server_hellos_; + } + + bool negotiated_dtls() const { + return transport_->local_description() && + transport_->local_description()->identity_fingerprint && + transport_->remote_description() && + transport_->remote_description()->identity_fingerprint; + } void CheckRole(rtc::SSLRole role) { if (role == rtc::SSL_CLIENT) { @@ -411,18 +424,19 @@ class DtlsTestClient : public sigslot::has_slots<> { std::set received_; bool use_dtls_srtp_ = false; rtc::SSLProtocolVersion ssl_max_version_ = rtc::SSL_PROTOCOL_DTLS_12; - bool set_local_cert_ = false; - bool set_remote_cert_ = false; int received_dtls_client_hellos_ = 0; int received_dtls_server_hellos_ = 0; rtc::SentPacket sent_packet_; }; +// Base class for DtlsTransportChannelTest and DtlsEventOrderingTest, which +// inherit from different variants of testing::Test. +// // Note that this test always uses a FakeClock, due to the |fake_clock_| member // variable. -class DtlsTransportChannelTest : public testing::Test { +class DtlsTransportChannelTestBase { public: - DtlsTransportChannelTest() + DtlsTransportChannelTestBase() : client1_("P1"), client2_("P2"), channel_ct_(1), @@ -495,9 +509,9 @@ class DtlsTransportChannelTest : public testing::Test { if (!rv) return false; - EXPECT_TRUE_WAIT( + EXPECT_TRUE_SIMULATED_WAIT( client1_.all_channels_writable() && client2_.all_channels_writable(), - kTimeout); + kTimeout, fake_clock_); if (!client1_.all_channels_writable() || !client2_.all_channels_writable()) return false; @@ -588,7 +602,8 @@ class DtlsTransportChannelTest : public testing::Test { LOG(LS_INFO) << "Expect packets, size=" << size; client2_.ExpectPackets(channel, size); client1_.SendPackets(channel, size, count, srtp); - EXPECT_EQ_WAIT(count, client2_.NumPacketsReceived(), kTimeout); + EXPECT_EQ_SIMULATED_WAIT(count, client2_.NumPacketsReceived(), kTimeout, + fake_clock_); } protected: @@ -601,6 +616,9 @@ class DtlsTransportChannelTest : public testing::Test { rtc::SSLProtocolVersion ssl_expected_version_; }; +class DtlsTransportChannelTest : public DtlsTransportChannelTestBase, + public ::testing::Test {}; + // Test that transport negotiation of ICE, no DTLS works properly. TEST_F(DtlsTransportChannelTest, TestChannelSetupIce) { Negotiate(); @@ -884,9 +902,9 @@ TEST_F(DtlsTransportChannelTest, TestRenegotiateBeforeConnect) { cricket::CONNECTIONROLE_ACTIVE, NF_REOFFER); bool rv = client1_.Connect(&client2_, false); EXPECT_TRUE(rv); - EXPECT_TRUE_WAIT( + EXPECT_TRUE_SIMULATED_WAIT( client1_.all_channels_writable() && client2_.all_channels_writable(), - kTimeout); + kTimeout, fake_clock_); TestTransfer(0, 1000, 100, true); TestTransfer(1, 1000, 100, true); @@ -941,72 +959,6 @@ TEST_F(DtlsTransportChannelTest, TestCertificatesAfterConnect) { certificate1->ssl_certificate().ToPEMString()); } -// Test that DTLS completes promptly if a ClientHello is received before the -// transport channel is writable (allowing a ServerHello to be sent). -TEST_F(DtlsTransportChannelTest, TestReceiveClientHelloBeforeWritable) { - MAYBE_SKIP_TEST(HaveDtls); - PrepareDtls(true, true, rtc::KT_DEFAULT); - // Exchange transport descriptions. - Negotiate(cricket::CONNECTIONROLE_ACTPASS, cricket::CONNECTIONROLE_ACTIVE); - - // Make client2_ writable, but not client1_. - EXPECT_TRUE(client2_.Connect(&client1_, true)); - EXPECT_TRUE_WAIT(client2_.all_raw_channels_writable(), kTimeout); - - // Expect a DTLS ClientHello to be sent even while client1_ isn't writable. - EXPECT_EQ_WAIT(1, client1_.received_dtls_client_hellos(), kTimeout); - EXPECT_FALSE(client1_.all_raw_channels_writable()); - - // Now make client1_ writable and expect the handshake to complete - // without client2_ needing to retransmit the ClientHello. - EXPECT_TRUE(client1_.Connect(&client2_, true)); - EXPECT_TRUE_WAIT( - client1_.all_channels_writable() && client2_.all_channels_writable(), - kTimeout); - EXPECT_EQ(1, client1_.received_dtls_client_hellos()); -} - -// Test that DTLS completes promptly if a ClientHello is received before the -// transport channel has a remote fingerprint (allowing a ServerHello to be -// sent). -TEST_F(DtlsTransportChannelTest, - TestReceiveClientHelloBeforeRemoteFingerprint) { - MAYBE_SKIP_TEST(HaveDtls); - PrepareDtls(true, true, rtc::KT_DEFAULT); - client1_.SetupChannels(channel_ct_, cricket::ICEROLE_CONTROLLING); - client2_.SetupChannels(channel_ct_, cricket::ICEROLE_CONTROLLED); - - // Make client2_ writable and give it local/remote certs, but don't yet give - // client1_ a remote fingerprint. - client1_.transport()->SetLocalTransportDescription( - MakeTransportDescription(client1_.certificate(), - cricket::CONNECTIONROLE_ACTPASS), - cricket::CA_OFFER, nullptr); - client2_.Negotiate(&client1_, cricket::CA_ANSWER, - cricket::CONNECTIONROLE_ACTIVE, - cricket::CONNECTIONROLE_ACTPASS, 0); - EXPECT_TRUE(client2_.Connect(&client1_, true)); - EXPECT_TRUE_WAIT(client2_.all_raw_channels_writable(), kTimeout); - - // Expect a DTLS ClientHello to be sent even while client1_ doesn't have a - // remote fingerprint. - EXPECT_EQ_WAIT(1, client1_.received_dtls_client_hellos(), kTimeout); - EXPECT_FALSE(client1_.all_raw_channels_writable()); - - // Now make give client1_ its remote fingerprint and make it writable, and - // expect the handshake to complete without client2_ needing to retransmit - // the ClientHello. - client1_.transport()->SetRemoteTransportDescription( - MakeTransportDescription(client2_.certificate(), - cricket::CONNECTIONROLE_ACTIVE), - cricket::CA_ANSWER, nullptr); - EXPECT_TRUE(client1_.Connect(&client2_, true)); - EXPECT_TRUE_WAIT( - client1_.all_channels_writable() && client2_.all_channels_writable(), - kTimeout); - EXPECT_EQ(1, client1_.received_dtls_client_hellos()); -} - // Test that packets are retransmitted according to the expected schedule. // Each time a timeout occurs, the retransmission timer should be doubled up to // 60 seconds. The timer defaults to 1 second, but for WebRTC we should be @@ -1024,7 +976,8 @@ TEST_F(DtlsTransportChannelTest, TestRetransmissionSchedule) { // Make client2_ writable, but not client1_. // This means client1_ will send DTLS client hellos but get no response. EXPECT_TRUE(client2_.Connect(&client1_, true)); - EXPECT_TRUE_WAIT(client2_.all_raw_channels_writable(), kTimeout); + EXPECT_TRUE_SIMULATED_WAIT(client2_.all_raw_channels_writable(), kTimeout, + fake_clock_); // Wait for the first client hello to be sent. EXPECT_EQ_WAIT(1, client1_.received_dtls_client_hellos(), kTimeout); @@ -1059,3 +1012,164 @@ TEST_F(DtlsTransportChannelTest, TestConnectBeforeNegotiate) { CONNECT_BEFORE_NEGOTIATE)); TestTransfer(0, 1000, 100, false); } + +// The following events can occur in many different orders: +// 1. Caller receives remote fingerprint. +// 2. Caller is writable. +// 3. Caller receives ClientHello. +// 4. DTLS handshake finishes. +// +// The tests below cover all causally consistent permutations of these events; +// the caller must be writable and receive a ClientHello before the handshake +// finishes, but otherwise any ordering is possible. +// +// For each permutation, the test verifies that a connection is established and +// fingerprint verified without any DTLS packet needing to be retransmitted. +// +// Each permutation is also tested with valid and invalid fingerprints, +// ensuring that the handshake fails with an invalid fingerprint. +enum DtlsTransportEvent { + CALLER_RECEIVES_FINGERPRINT, + CALLER_WRITABLE, + CALLER_RECEIVES_CLIENTHELLO, + HANDSHAKE_FINISHES +}; + +class DtlsEventOrderingTest + : public DtlsTransportChannelTestBase, + public ::testing::TestWithParam< + ::testing::tuple, bool>> { + protected: + // If |valid_fingerprint| is false, the caller will receive a fingerprint + // that doesn't match the callee's certificate, so the handshake should fail. + void TestEventOrdering(const std::vector& events, + bool valid_fingerprint) { + // Pre-setup: Set local certificate on both caller and callee, and + // remote fingerprint on callee, but neither is writable and the caller + // doesn't have the callee's fingerprint. + PrepareDtls(true, true, rtc::KT_DEFAULT); + // Simulate packets being sent and arriving asynchronously. + // Otherwise the entire DTLS handshake would occur in one clock tick, and + // we couldn't inject method calls in the middle of it. + int simulated_delay_ms = 10; + client1_.SetupChannels(channel_ct_, cricket::ICEROLE_CONTROLLING, + simulated_delay_ms); + client2_.SetupChannels(channel_ct_, cricket::ICEROLE_CONTROLLED, + simulated_delay_ms); + client1_.SetLocalTransportDescription(client1_.certificate(), + cricket::CA_OFFER, + cricket::CONNECTIONROLE_ACTPASS, 0); + client2_.Negotiate(&client1_, cricket::CA_ANSWER, + cricket::CONNECTIONROLE_ACTIVE, + cricket::CONNECTIONROLE_ACTPASS, 0); + + for (DtlsTransportEvent e : events) { + switch (e) { + case CALLER_RECEIVES_FINGERPRINT: + if (valid_fingerprint) { + client1_.SetRemoteTransportDescription( + client2_.certificate(), cricket::CA_ANSWER, + cricket::CONNECTIONROLE_ACTIVE, 0); + } else { + // Create a fingerprint with a correct algorithm but an invalid + // digest. + cricket::TransportDescription remote_desc = + MakeTransportDescription(client2_.certificate(), + cricket::CONNECTIONROLE_ACTIVE); + ++(remote_desc.identity_fingerprint->digest[0]); + // Even if certificate verification fails inside this method, + // it should return true as long as the fingerprint was formatted + // correctly. + EXPECT_TRUE(client1_.transport()->SetRemoteTransportDescription( + remote_desc, cricket::CA_ANSWER, nullptr)); + } + break; + case CALLER_WRITABLE: + EXPECT_TRUE(client1_.Connect(&client2_, true)); + EXPECT_TRUE_SIMULATED_WAIT(client1_.all_raw_channels_writable(), + kTimeout, fake_clock_); + break; + case CALLER_RECEIVES_CLIENTHELLO: + // Sanity check that a ClientHello hasn't already been received. + EXPECT_EQ(0, client1_.received_dtls_client_hellos()); + // Making client2_ writable will cause it to send the ClientHello. + EXPECT_TRUE(client2_.Connect(&client1_, true)); + EXPECT_TRUE_SIMULATED_WAIT(client2_.all_raw_channels_writable(), + kTimeout, fake_clock_); + EXPECT_EQ_SIMULATED_WAIT(1, client1_.received_dtls_client_hellos(), + kTimeout, fake_clock_); + break; + case HANDSHAKE_FINISHES: + // Sanity check that the handshake hasn't already finished. + EXPECT_FALSE(client1_.GetDtlsChannel(0)->IsDtlsConnected() || + client1_.GetDtlsChannel(0)->dtls_state() == + cricket::DTLS_TRANSPORT_FAILED); + EXPECT_TRUE_SIMULATED_WAIT( + client1_.GetDtlsChannel(0)->IsDtlsConnected() || + client1_.GetDtlsChannel(0)->dtls_state() == + cricket::DTLS_TRANSPORT_FAILED, + kTimeout, fake_clock_); + break; + } + } + + cricket::DtlsTransportState expected_final_state = + valid_fingerprint ? cricket::DTLS_TRANSPORT_CONNECTED + : cricket::DTLS_TRANSPORT_FAILED; + EXPECT_EQ_SIMULATED_WAIT(expected_final_state, + client1_.GetDtlsChannel(0)->dtls_state(), kTimeout, + fake_clock_); + EXPECT_EQ_SIMULATED_WAIT(expected_final_state, + client2_.GetDtlsChannel(0)->dtls_state(), kTimeout, + fake_clock_); + + // Channel should be writable iff there was a valid fingerprint. + EXPECT_EQ(valid_fingerprint, client1_.GetDtlsChannel(0)->writable()); + EXPECT_EQ(valid_fingerprint, client2_.GetDtlsChannel(0)->writable()); + + // Check that no hello needed to be retransmitted. + EXPECT_EQ(1, client1_.received_dtls_client_hellos()); + EXPECT_EQ(1, client2_.received_dtls_server_hellos()); + + if (valid_fingerprint) { + TestTransfer(0, 1000, 100, false); + } + } +}; + +TEST_P(DtlsEventOrderingTest, TestEventOrdering) { + MAYBE_SKIP_TEST(HaveDtls); + TestEventOrdering(::testing::get<0>(GetParam()), + ::testing::get<1>(GetParam())); +} + +INSTANTIATE_TEST_CASE_P( + TestEventOrdering, + DtlsEventOrderingTest, + ::testing::Combine( + ::testing::Values( + std::vector{ + CALLER_RECEIVES_FINGERPRINT, CALLER_WRITABLE, + CALLER_RECEIVES_CLIENTHELLO, HANDSHAKE_FINISHES}, + std::vector{ + CALLER_WRITABLE, CALLER_RECEIVES_FINGERPRINT, + CALLER_RECEIVES_CLIENTHELLO, HANDSHAKE_FINISHES}, + std::vector{ + CALLER_WRITABLE, CALLER_RECEIVES_CLIENTHELLO, + CALLER_RECEIVES_FINGERPRINT, HANDSHAKE_FINISHES}, + std::vector{ + CALLER_WRITABLE, CALLER_RECEIVES_CLIENTHELLO, + HANDSHAKE_FINISHES, CALLER_RECEIVES_FINGERPRINT}, + std::vector{ + CALLER_RECEIVES_FINGERPRINT, CALLER_RECEIVES_CLIENTHELLO, + CALLER_WRITABLE, HANDSHAKE_FINISHES}, + std::vector{ + CALLER_RECEIVES_CLIENTHELLO, CALLER_RECEIVES_FINGERPRINT, + CALLER_WRITABLE, HANDSHAKE_FINISHES}, + std::vector{ + CALLER_RECEIVES_CLIENTHELLO, CALLER_WRITABLE, + CALLER_RECEIVES_FINGERPRINT, HANDSHAKE_FINISHES}, + std::vector{CALLER_RECEIVES_CLIENTHELLO, + CALLER_WRITABLE, HANDSHAKE_FINISHES, + CALLER_RECEIVES_FINGERPRINT}), + ::testing::Bool())); diff --git a/webrtc/p2p/base/faketransportcontroller.h b/webrtc/p2p/base/faketransportcontroller.h index 153aa56790..5d0ceb434d 100644 --- a/webrtc/p2p/base/faketransportcontroller.h +++ b/webrtc/p2p/base/faketransportcontroller.h @@ -69,6 +69,7 @@ class FakeTransportChannel : public TransportChannelImpl, // If async, will send packets by "Post"-ing to message queue instead of // synchronously "Send"-ing. void SetAsync(bool async) { async_ = async; } + void SetAsyncDelay(int delay_ms) { async_delay_ms_ = delay_ms; } TransportChannelState GetState() const override { if (connection_count_ == 0) { @@ -200,7 +201,12 @@ class FakeTransportChannel : public TransportChannelImpl, PacketMessageData* packet = new PacketMessageData(data, len); if (async_) { - rtc::Thread::Current()->Post(RTC_FROM_HERE, this, 0, packet); + if (async_delay_ms_) { + rtc::Thread::Current()->PostDelayed(RTC_FROM_HERE, async_delay_ms_, + this, 0, packet); + } else { + rtc::Thread::Current()->Post(RTC_FROM_HERE, this, 0, packet); + } } else { rtc::Thread::Current()->Send(RTC_FROM_HERE, this, 0, packet); } @@ -311,6 +317,7 @@ class FakeTransportChannel : public TransportChannelImpl, FakeTransportChannel* dest_ = nullptr; State state_ = STATE_INIT; bool async_ = false; + int async_delay_ms_ = 0; Candidates remote_candidates_; rtc::scoped_refptr local_cert_; rtc::FakeSSLCertificate* remote_cert_ = nullptr; @@ -354,6 +361,7 @@ class FakeTransport : public Transport { // If async, will send packets by "Post"-ing to message queue instead of // synchronously "Send"-ing. void SetAsync(bool async) { async_ = async; } + void SetAsyncDelay(int delay_ms) { async_delay_ms_ = delay_ms; } // If |asymmetric| is true, only set the destination for this transport, and // not |dest|. @@ -415,6 +423,7 @@ class FakeTransport : public Transport { FakeTransportChannel* channel = new FakeTransportChannel(name(), component); channel->set_ssl_max_protocol_version(ssl_max_version_); channel->SetAsync(async_); + channel->SetAsyncDelay(async_delay_ms_); SetChannelDestination(component, channel, false); channels_[component] = channel; return channel; @@ -451,6 +460,7 @@ class FakeTransport : public Transport { ChannelMap channels_; FakeTransport* dest_ = nullptr; bool async_ = false; + int async_delay_ms_ = 0; rtc::scoped_refptr certificate_; rtc::SSLProtocolVersion ssl_max_version_ = rtc::SSL_PROTOCOL_DTLS_12; };