diff --git a/webrtc/base/openssladapter.cc b/webrtc/base/openssladapter.cc index bc7b99b97a..fbab6c84ef 100644 --- a/webrtc/base/openssladapter.cc +++ b/webrtc/base/openssladapter.cc @@ -154,6 +154,21 @@ static long socket_ctrl(BIO* b, int cmd, long num, void* ptr) { } } +static void LogSslError() { + // Walk down the error stack to find the SSL error. + uint32_t error_code; + const char* file; + int line; + do { + error_code = ERR_get_error_line(&file, &line); + if (ERR_GET_LIB(error_code) == ERR_LIB_SSL) { + LOG(LS_ERROR) << "ERR_LIB_SSL: " << error_code << ", " << file << ":" + << line; + break; + } + } while (error_code != 0); +} + ///////////////////////////////////////////////////////////////////////////// // OpenSSLAdapter ///////////////////////////////////////////////////////////////////////////// @@ -334,6 +349,14 @@ OpenSSLAdapter::BeginSSL() { SSL_set_app_data(ssl_, this); SSL_set_bio(ssl_, bio, bio); + // SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER allows different buffers to be passed + // into SSL_write when a record could only be partially transmitted (and thus + // requires another call to SSL_write to finish transmission). This allows us + // to copy the data into our own buffer when this occurs, since the original + // buffer can't safely be accessed after control exits Send. + // TODO(deadbeef): Do we want SSL_MODE_ENABLE_PARTIAL_WRITE? It doesn't + // appear Send handles partial writes properly, though maybe we never notice + // since we never send more than 16KB at once.. SSL_set_mode(ssl_, SSL_MODE_ENABLE_PARTIAL_WRITE | SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER); @@ -426,6 +449,7 @@ OpenSSLAdapter::Cleanup() { ssl_read_needs_write_ = false; ssl_write_needs_read_ = false; custom_verification_succeeded_ = false; + pending_data_.Clear(); if (ssl_) { SSL_free(ssl_); @@ -441,6 +465,46 @@ OpenSSLAdapter::Cleanup() { Thread::Current()->Clear(this, MSG_TIMEOUT); } +int OpenSSLAdapter::DoSslWrite(const void* pv, size_t cb, int* error) { + // If we have pending data (that was previously only partially written by + // SSL_write), we shouldn't be attempting to write anything else. + RTC_DCHECK(pending_data_.empty() || pv == pending_data_.data()); + RTC_DCHECK(error != nullptr); + + ssl_write_needs_read_ = false; + int ret = SSL_write(ssl_, pv, checked_cast(cb)); + *error = SSL_get_error(ssl_, ret); + switch (*error) { + case SSL_ERROR_NONE: + // Success! + return ret; + case SSL_ERROR_WANT_READ: + LOG(LS_INFO) << " -- error want read"; + ssl_write_needs_read_ = true; + SetError(EWOULDBLOCK); + break; + case SSL_ERROR_WANT_WRITE: + LOG(LS_INFO) << " -- error want write"; + SetError(EWOULDBLOCK); + break; + case SSL_ERROR_ZERO_RETURN: + // LOG(LS_INFO) << " -- remote side closed"; + SetError(EWOULDBLOCK); + // do we need to signal closure? + break; + case SSL_ERROR_SSL: + LogSslError(); + Error("SSL_write", ret ? ret : -1, false); + break; + default: + LOG(LS_WARNING) << "Unknown error from SSL_write: " << *error; + Error("SSL_write", ret ? ret : -1, false); + break; + } + + return SOCKET_ERROR; +} + // // AsyncSocket Implementation // @@ -466,38 +530,52 @@ OpenSSLAdapter::Send(const void* pv, size_t cb) { return SOCKET_ERROR; } + int ret; + int error; + + if (!pending_data_.empty()) { + ret = DoSslWrite(pending_data_.data(), pending_data_.size(), &error); + if (ret != static_cast(pending_data_.size())) { + // We couldn't finish sending the pending data, so we definitely can't + // send any more data. Return with an EWOULDBLOCK error. + SetError(EWOULDBLOCK); + return SOCKET_ERROR; + } + // We completed sending the data previously passed into SSL_write! Now + // we're allowed to send more data. + pending_data_.Clear(); + } + // OpenSSL will return an error if we try to write zero bytes if (cb == 0) return 0; - ssl_write_needs_read_ = false; + ret = DoSslWrite(pv, cb, &error); - int code = SSL_write(ssl_, pv, checked_cast(cb)); - switch (SSL_get_error(ssl_, code)) { - case SSL_ERROR_NONE: - //LOG(LS_INFO) << " -- success"; - return code; - case SSL_ERROR_WANT_READ: - //LOG(LS_INFO) << " -- error want read"; - ssl_write_needs_read_ = true; - SetError(EWOULDBLOCK); - break; - case SSL_ERROR_WANT_WRITE: - //LOG(LS_INFO) << " -- error want write"; - SetError(EWOULDBLOCK); - break; - case SSL_ERROR_ZERO_RETURN: - //LOG(LS_INFO) << " -- remote side closed"; - SetError(EWOULDBLOCK); - // do we need to signal closure? - break; - default: - //LOG(LS_INFO) << " -- error " << code; - Error("SSL_write", (code ? code : -1), false); - break; + // If SSL_write fails with SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE, this + // means the underlying socket is blocked on reading or (more typically) + // writing. When this happens, OpenSSL requires that the next call to + // SSL_write uses the same arguments (though, with + // SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER, the actual buffer pointer may be + // different). + // + // However, after Send exits, we will have lost access to data the user of + // this class is trying to send, and there's no guarantee that the user of + // this class will call Send with the same arguements when it fails. So, we + // buffer the data ourselves. When we know the underlying socket is writable + // again from OnWriteEvent (or if Send is called again before that happens), + // we'll retry sending this buffered data. + if ((error == SSL_ERROR_WANT_READ || error == SSL_ERROR_WANT_WRITE) && + pending_data_.empty()) { + LOG(LS_WARNING) + << "SSL_write couldn't write to the underlying socket; buffering data."; + pending_data_.SetData(static_cast(pv), cb); + // Since we're taking responsibility for sending this data, return its full + // size. The user of this class can consider it sent. + return cb; } - return SOCKET_ERROR; + return ret; } int @@ -539,28 +617,33 @@ int OpenSSLAdapter::Recv(void* pv, size_t cb, int64_t* timestamp) { ssl_read_needs_write_ = false; int code = SSL_read(ssl_, pv, checked_cast(cb)); - switch (SSL_get_error(ssl_, code)) { - case SSL_ERROR_NONE: - //LOG(LS_INFO) << " -- success"; - return code; - case SSL_ERROR_WANT_READ: - //LOG(LS_INFO) << " -- error want read"; - SetError(EWOULDBLOCK); - break; - case SSL_ERROR_WANT_WRITE: - //LOG(LS_INFO) << " -- error want write"; - ssl_read_needs_write_ = true; - SetError(EWOULDBLOCK); - break; - case SSL_ERROR_ZERO_RETURN: - //LOG(LS_INFO) << " -- remote side closed"; - SetError(EWOULDBLOCK); - // do we need to signal closure? - break; - default: - //LOG(LS_INFO) << " -- error " << code; - Error("SSL_read", (code ? code : -1), false); - break; + int error = SSL_get_error(ssl_, code); + switch (error) { + case SSL_ERROR_NONE: + // LOG(LS_INFO) << " -- success"; + return code; + case SSL_ERROR_WANT_READ: + // LOG(LS_INFO) << " -- error want read"; + SetError(EWOULDBLOCK); + break; + case SSL_ERROR_WANT_WRITE: + // LOG(LS_INFO) << " -- error want write"; + ssl_read_needs_write_ = true; + SetError(EWOULDBLOCK); + break; + case SSL_ERROR_ZERO_RETURN: + // LOG(LS_INFO) << " -- remote side closed"; + SetError(EWOULDBLOCK); + // do we need to signal closure? + break; + case SSL_ERROR_SSL: + LogSslError(); + Error("SSL_read", (code ? code : -1), false); + break; + default: + LOG(LS_WARNING) << "Unknown error from SSL_read: " << error; + Error("SSL_read", (code ? code : -1), false); + break; } return SOCKET_ERROR; @@ -682,6 +765,16 @@ OpenSSLAdapter::OnWriteEvent(AsyncSocket* socket) { AsyncSocketAdapter::OnReadEvent(socket); } + // If a previous SSL_write failed due to the underlying socket being blocked, + // this will attempt finishing the write operation. + if (!pending_data_.empty()) { + int error; + if (DoSslWrite(pending_data_.data(), pending_data_.size(), &error) == + static_cast(pending_data_.size())) { + pending_data_.Clear(); + } + } + //LOG(LS_INFO) << " -- onStreamWriteable"; AsyncSocketAdapter::OnWriteEvent(socket); } diff --git a/webrtc/base/openssladapter.h b/webrtc/base/openssladapter.h index 554627f58f..2f0150f0f9 100644 --- a/webrtc/base/openssladapter.h +++ b/webrtc/base/openssladapter.h @@ -12,6 +12,7 @@ #define WEBRTC_BASE_OPENSSLADAPTER_H__ #include +#include "webrtc/base/buffer.h" #include "webrtc/base/messagehandler.h" #include "webrtc/base/messagequeue.h" #include "webrtc/base/ssladapter.h" @@ -65,6 +66,10 @@ private: void Error(const char* context, int err, bool signal = true); void Cleanup(); + // Return value and arguments have the same meanings as for Send; |error| is + // an output parameter filled with the result of SSL_get_error. + int DoSslWrite(const void* pv, size_t cb, int* error); + void OnMessage(Message* msg) override; static bool VerifyServerName(SSL* ssl, const char* host, @@ -86,6 +91,11 @@ private: // If true, socket will retain SSL configuration after Close. bool restartable_; + // This buffer is used if SSL_write fails with SSL_ERROR_WANT_WRITE, which + // means we need to keep retrying with *the same exact data* until it + // succeeds. Afterwards it will be cleared. + Buffer pending_data_; + SSL* ssl_; SSL_CTX* ssl_ctx_; std::string ssl_host_name_; diff --git a/webrtc/base/ssladapter_unittest.cc b/webrtc/base/ssladapter_unittest.cc index ec39d949a2..0eaac17885 100644 --- a/webrtc/base/ssladapter_unittest.cc +++ b/webrtc/base/ssladapter_unittest.cc @@ -15,9 +15,10 @@ #include "webrtc/base/ipaddress.h" #include "webrtc/base/socketstream.h" #include "webrtc/base/ssladapter.h" -#include "webrtc/base/sslstreamadapter.h" #include "webrtc/base/sslidentity.h" +#include "webrtc/base/sslstreamadapter.h" #include "webrtc/base/stream.h" +#include "webrtc/base/stringencode.h" #include "webrtc/base/virtualsocketserver.h" static const int kTimeout = 5000; @@ -210,19 +211,16 @@ class SSLAdapterTestDummyServer : public sigslot::has_slots<> { void OnSSLStreamAdapterEvent(rtc::StreamInterface* stream, int sig, int err) { if (sig & rtc::SE_READ) { char buffer[4096] = ""; - size_t read; int error; // Read data received from the client and store it in our internal // buffer. - rtc::StreamResult r = stream->Read(buffer, - sizeof(buffer) - 1, &read, &error); + rtc::StreamResult r = + stream->Read(buffer, sizeof(buffer) - 1, &read, &error); if (r == rtc::SR_SUCCESS) { buffer[read] = '\0'; - LOG(LS_INFO) << "Server received '" << buffer << "'"; - data_ += buffer; } } @@ -336,7 +334,7 @@ class SSLAdapterTestBase : public testing::Test, LOG(LS_INFO) << "Transfer complete."; } - private: + protected: const rtc::SSLMode ssl_mode_; std::unique_ptr vss_; @@ -389,6 +387,47 @@ TEST_F(SSLAdapterTestTLS_RSA, TestTLSTransfer) { TestTransfer("Hello, world!"); } +TEST_F(SSLAdapterTestTLS_RSA, TestTLSTransferWithBlockedSocket) { + TestHandshake(true); + + // Tell the underlying socket to simulate being blocked. + vss_->SetSendingBlocked(true); + + std::string expected; + int rv; + // Send messages until the SSL socket adapter starts applying backpressure. + // Note that this may not occur immediately since there may be some amount of + // intermediate buffering (either in our code or in BoringSSL). + for (int i = 0; i < 1024; ++i) { + std::string message = "Hello, world: " + rtc::ToString(i); + rv = client_->Send(message); + if (rv != static_cast(message.size())) { + // This test assumes either the whole message or none of it is sent. + ASSERT_EQ(-1, rv); + break; + } + expected += message; + } + // Assert that the loop above exited due to Send returning -1. + ASSERT_EQ(-1, rv); + + // Try sending another message while blocked. -1 should be returned again and + // it shouldn't end up received by the server later. + EXPECT_EQ(-1, client_->Send("Never sent")); + + // Unblock the underlying socket. All of the buffered messages should be sent + // without any further action. + vss_->SetSendingBlocked(false); + EXPECT_EQ_WAIT(expected, server_->GetReceivedData(), kTimeout); + + // Send another message. This previously wasn't working + std::string final_message = "Fin."; + expected += final_message; + EXPECT_EQ(static_cast(final_message.size()), + client_->Send(final_message)); + EXPECT_EQ_WAIT(expected, server_->GetReceivedData(), kTimeout); +} + // Test transfer between client and server, using ECDSA TEST_F(SSLAdapterTestTLS_ECDSA, TestTLSTransfer) { TestHandshake(true); diff --git a/webrtc/base/virtualsocketserver.cc b/webrtc/base/virtualsocketserver.cc index 9ba6a2f7fa..cfb4eb294b 100644 --- a/webrtc/base/virtualsocketserver.cc +++ b/webrtc/base/virtualsocketserver.cc @@ -56,6 +56,7 @@ enum { MSG_ID_ADDRESS_BOUND, MSG_ID_CONNECT, MSG_ID_DISCONNECT, + MSG_ID_SIGNALREADEVENT, }; // Packets are passed between sockets as messages. We copy the data just like @@ -303,6 +304,14 @@ int VirtualSocket::RecvFrom(void* pv, delete packet; } + // To behave like a real socket, SignalReadEvent should fire in the next + // message loop pass if there's still data buffered. + if (!recv_buffer_.empty()) { + // Clear the message so it doesn't end up posted multiple times. + server_->msg_queue_->Clear(this, MSG_ID_SIGNALREADEVENT); + server_->msg_queue_->Post(RTC_FROM_HERE, this, MSG_ID_SIGNALREADEVENT); + } + if (SOCK_STREAM == type_) { bool was_full = (recv_buffer_size_ == server_->recv_buffer_capacity_); recv_buffer_size_ -= data_read; @@ -421,6 +430,10 @@ void VirtualSocket::OnMessage(Message* pmsg) { } } else if (pmsg->message_id == MSG_ID_ADDRESS_BOUND) { SignalAddressReady(this, GetLocalAddress()); + } else if (pmsg->message_id == MSG_ID_SIGNALREADEVENT) { + if (!recv_buffer_.empty()) { + SignalReadEvent(this); + } } else { RTC_NOTREACHED(); }