From e488a0dbe4114ce51feeaf663ad4e2a6bd4b9a2b Mon Sep 17 00:00:00 2001 From: jbauch Date: Thu, 19 Nov 2015 05:17:58 -0800 Subject: [PATCH] Fix DTLS packet boundary handling in SSLStreamAdapterTests. The tests were not honoring packet boundaries, thus causing failures in tests with dropped/broken packets. This CL fixes this and also re-enables the tests. R=torbjorng@webrtc.org,pthatcher@webrtc.org,tommi@webrtc.org,juberti@webrtc.org BUG=webrtc:5005,webrtc:5188 Review URL: https://codereview.webrtc.org/1440193002 Cr-Commit-Position: refs/heads/master@{#10709} --- webrtc/base/bufferqueue.cc | 14 +- webrtc/base/bufferqueue.h | 17 +- webrtc/base/sslstreamadapter_unittest.cc | 244 ++++++++++++++++------- 3 files changed, 190 insertions(+), 85 deletions(-) diff --git a/webrtc/base/bufferqueue.cc b/webrtc/base/bufferqueue.cc index 955af51f6b..1ac57abc0c 100644 --- a/webrtc/base/bufferqueue.cc +++ b/webrtc/base/bufferqueue.cc @@ -38,19 +38,19 @@ bool BufferQueue::ReadFront(void* buffer, size_t bytes, size_t* bytes_read) { return false; } + bool was_writable = queue_.size() < capacity_; Buffer* packet = queue_.front(); queue_.pop_front(); - size_t next_packet_size = packet->size(); - if (bytes > next_packet_size) { - bytes = next_packet_size; - } - + bytes = std::min(bytes, packet->size()); memcpy(buffer, packet->data(), bytes); if (bytes_read) { *bytes_read = bytes; } free_list_.push_back(packet); + if (!was_writable) { + NotifyWritableForTest(); + } return true; } @@ -61,6 +61,7 @@ bool BufferQueue::WriteBack(const void* buffer, size_t bytes, return false; } + bool was_readable = !queue_.empty(); Buffer* packet; if (!free_list_.empty()) { packet = free_list_.back(); @@ -74,6 +75,9 @@ bool BufferQueue::WriteBack(const void* buffer, size_t bytes, *bytes_written = bytes; } queue_.push_back(packet); + if (!was_readable) { + NotifyReadableForTest(); + } return true; } diff --git a/webrtc/base/bufferqueue.h b/webrtc/base/bufferqueue.h index 4941fccf2e..458f0189cd 100644 --- a/webrtc/base/bufferqueue.h +++ b/webrtc/base/bufferqueue.h @@ -21,26 +21,33 @@ namespace rtc { class BufferQueue { public: - // Creates a buffer queue queue with a given capacity and default buffer size. + // Creates a buffer queue with a given capacity and default buffer size. BufferQueue(size_t capacity, size_t default_size); - ~BufferQueue(); + virtual ~BufferQueue(); // Return number of queued buffers. size_t size() const; // ReadFront will only read one buffer at a time and will truncate buffers // that don't fit in the passed memory. + // Returns true unless no data could be returned. bool ReadFront(void* data, size_t bytes, size_t* bytes_read); // WriteBack always writes either the complete memory or nothing. + // Returns true unless no data could be written. bool WriteBack(const void* data, size_t bytes, size_t* bytes_written); + protected: + // These methods are called when the state of the queue changes. + virtual void NotifyReadableForTest() {} + virtual void NotifyWritableForTest() {} + private: size_t capacity_; size_t default_size_; - std::deque queue_; - std::vector free_list_; - mutable CriticalSection crit_; // object lock + mutable CriticalSection crit_; + std::deque queue_ GUARDED_BY(crit_); + std::vector free_list_ GUARDED_BY(crit_); RTC_DISALLOW_COPY_AND_ASSIGN(BufferQueue); }; diff --git a/webrtc/base/sslstreamadapter_unittest.cc b/webrtc/base/sslstreamadapter_unittest.cc index 0344bd057a..b14a88707a 100644 --- a/webrtc/base/sslstreamadapter_unittest.cc +++ b/webrtc/base/sslstreamadapter_unittest.cc @@ -13,6 +13,7 @@ #include #include +#include "webrtc/base/bufferqueue.h" #include "webrtc/base/gunit.h" #include "webrtc/base/helpers.h" #include "webrtc/base/scoped_ptr.h" @@ -72,26 +73,26 @@ static const char kCERT_PEM[] = class SSLStreamAdapterTestBase; -class SSLDummyStream : public rtc::StreamInterface, - public sigslot::has_slots<> { +class SSLDummyStreamBase : public rtc::StreamInterface, + public sigslot::has_slots<> { public: - explicit SSLDummyStream(SSLStreamAdapterTestBase *test, - const std::string &side, - rtc::FifoBuffer *in, - rtc::FifoBuffer *out) : - test_(test), + SSLDummyStreamBase(SSLStreamAdapterTestBase* test, + const std::string &side, + rtc::StreamInterface* in, + rtc::StreamInterface* out) : + test_base_(test), side_(side), in_(in), out_(out), first_packet_(true) { - in_->SignalEvent.connect(this, &SSLDummyStream::OnEventIn); - out_->SignalEvent.connect(this, &SSLDummyStream::OnEventOut); + in_->SignalEvent.connect(this, &SSLDummyStreamBase::OnEventIn); + out_->SignalEvent.connect(this, &SSLDummyStreamBase::OnEventOut); } - virtual rtc::StreamState GetState() const { return rtc::SS_OPEN; } + rtc::StreamState GetState() const override { return rtc::SS_OPEN; } - virtual rtc::StreamResult Read(void* buffer, size_t buffer_len, - size_t* read, int* error) { + rtc::StreamResult Read(void* buffer, size_t buffer_len, + size_t* read, int* error) override { rtc::StreamResult r; r = in_->Read(buffer, buffer_len, read, error); @@ -109,22 +110,20 @@ class SSLDummyStream : public rtc::StreamInterface, } // Catch readability events on in and pass them up. - virtual void OnEventIn(rtc::StreamInterface *stream, int sig, - int err) { + void OnEventIn(rtc::StreamInterface* stream, int sig, int err) { int mask = (rtc::SE_READ | rtc::SE_CLOSE); if (sig & mask) { - LOG(LS_INFO) << "SSLDummyStream::OnEvent side=" << side_ << " sig=" + LOG(LS_INFO) << "SSLDummyStreamBase::OnEvent side=" << side_ << " sig=" << sig << " forwarding upward"; PostEvent(sig & mask, 0); } } // Catch writeability events on out and pass them up. - virtual void OnEventOut(rtc::StreamInterface *stream, int sig, - int err) { + void OnEventOut(rtc::StreamInterface* stream, int sig, int err) { if (sig & rtc::SE_WRITE) { - LOG(LS_INFO) << "SSLDummyStream::OnEvent side=" << side_ << " sig=" + LOG(LS_INFO) << "SSLDummyStreamBase::OnEvent side=" << side_ << " sig=" << sig << " forwarding upward"; PostEvent(sig & rtc::SE_WRITE, 0); @@ -133,28 +132,92 @@ class SSLDummyStream : public rtc::StreamInterface, // Write to the outgoing FifoBuffer rtc::StreamResult WriteData(const void* data, size_t data_len, - size_t* written, int* error) { + size_t* written, int* error) { return out_->Write(data, data_len, written, error); } - // Defined later - virtual rtc::StreamResult Write(const void* data, size_t data_len, - size_t* written, int* error); + rtc::StreamResult Write(const void* data, size_t data_len, + size_t* written, int* error) override; - virtual void Close() { + void Close() override { LOG(LS_INFO) << "Closing outbound stream"; out_->Close(); } - private: - SSLStreamAdapterTestBase *test_; + protected: + SSLStreamAdapterTestBase* test_base_; const std::string side_; - rtc::FifoBuffer *in_; - rtc::FifoBuffer *out_; + rtc::StreamInterface* in_; + rtc::StreamInterface* out_; bool first_packet_; }; +class SSLDummyStreamTLS : public SSLDummyStreamBase { + public: + SSLDummyStreamTLS(SSLStreamAdapterTestBase* test, + const std::string& side, + rtc::FifoBuffer* in, + rtc::FifoBuffer* out) : + SSLDummyStreamBase(test, side, in, out) { + } +}; + +class BufferQueueStream : public rtc::BufferQueue, + public rtc::StreamInterface { + public: + BufferQueueStream(size_t capacity, size_t default_size) + : rtc::BufferQueue(capacity, default_size) { + } + + // Implementation of abstract StreamInterface methods. + + // A buffer queue stream is always "open". + rtc::StreamState GetState() const override { return rtc::SS_OPEN; } + + // Reading a buffer queue stream will either succeed or block. + rtc::StreamResult Read(void* buffer, size_t buffer_len, + size_t* read, int* error) override { + if (!ReadFront(buffer, buffer_len, read)) { + return rtc::SR_BLOCK; + } + return rtc::SR_SUCCESS; + } + + // Writing to a buffer queue stream will either succeed or block. + rtc::StreamResult Write(const void* data, size_t data_len, + size_t* written, int* error) override { + if (!WriteBack(data, data_len, written)) { + return rtc::SR_BLOCK; + } + return rtc::SR_SUCCESS; + } + + // A buffer queue stream can not be closed. + void Close() override {} + + protected: + void NotifyReadableForTest() override { + PostEvent(rtc::SE_READ, 0); + } + + void NotifyWritableForTest() override { + PostEvent(rtc::SE_WRITE, 0); + } +}; + +class SSLDummyStreamDTLS : public SSLDummyStreamBase { + public: + SSLDummyStreamDTLS(SSLStreamAdapterTestBase* test, + const std::string& side, + BufferQueueStream* in, + BufferQueueStream* out) : + SSLDummyStreamBase(test, side, in, out) { + } +}; + static const int kFifoBufferSize = 4096; +static const int kBufferCapacity = 1; +static const size_t kDefaultBufferSize = 2048; class SSLStreamAdapterTestBase : public testing::Test, public sigslot::has_slots<> { @@ -165,14 +228,12 @@ class SSLStreamAdapterTestBase : public testing::Test, bool dtls, rtc::KeyParams client_key_type = rtc::KeyParams(rtc::KT_DEFAULT), rtc::KeyParams server_key_type = rtc::KeyParams(rtc::KT_DEFAULT)) - : client_buffer_(kFifoBufferSize), - server_buffer_(kFifoBufferSize), - client_stream_( - new SSLDummyStream(this, "c2s", &client_buffer_, &server_buffer_)), - server_stream_( - new SSLDummyStream(this, "s2c", &server_buffer_, &client_buffer_)), - client_ssl_(rtc::SSLStreamAdapter::Create(client_stream_)), - server_ssl_(rtc::SSLStreamAdapter::Create(server_stream_)), + : client_cert_pem_(client_cert_pem), + client_private_key_pem_(client_private_key_pem), + client_key_type_(client_key_type), + server_key_type_(server_key_type), + client_stream_(NULL), + server_stream_(NULL), client_identity_(NULL), server_identity_(NULL), delay_(0), @@ -185,21 +246,6 @@ class SSLStreamAdapterTestBase : public testing::Test, identities_set_(false) { // Set use of the test RNG to get predictable loss patterns. rtc::SetRandomTestMode(true); - - // Set up the slots - client_ssl_->SignalEvent.connect(this, &SSLStreamAdapterTestBase::OnEvent); - server_ssl_->SignalEvent.connect(this, &SSLStreamAdapterTestBase::OnEvent); - - if (!client_cert_pem.empty() && !client_private_key_pem.empty()) { - client_identity_ = rtc::SSLIdentity::FromPEMStrings( - client_private_key_pem, client_cert_pem); - } else { - client_identity_ = rtc::SSLIdentity::Generate("client", client_key_type); - } - server_identity_ = rtc::SSLIdentity::Generate("server", server_key_type); - - client_ssl_->SetIdentity(client_identity_); - server_ssl_->SetIdentity(server_identity_); } ~SSLStreamAdapterTestBase() { @@ -207,14 +253,40 @@ class SSLStreamAdapterTestBase : public testing::Test, rtc::SetRandomTestMode(false); } + virtual void SetUp() override { + CreateStreams(); + + client_ssl_.reset(rtc::SSLStreamAdapter::Create(client_stream_)); + server_ssl_.reset(rtc::SSLStreamAdapter::Create(server_stream_)); + + // Set up the slots + client_ssl_->SignalEvent.connect(this, &SSLStreamAdapterTestBase::OnEvent); + server_ssl_->SignalEvent.connect(this, &SSLStreamAdapterTestBase::OnEvent); + + if (!client_cert_pem_.empty() && !client_private_key_pem_.empty()) { + client_identity_ = rtc::SSLIdentity::FromPEMStrings( + client_private_key_pem_, client_cert_pem_); + } else { + client_identity_ = rtc::SSLIdentity::Generate("client", client_key_type_); + } + server_identity_ = rtc::SSLIdentity::Generate("server", server_key_type_); + + client_ssl_->SetIdentity(client_identity_); + server_ssl_->SetIdentity(server_identity_); + } + + virtual void TearDown() override { + client_ssl_.reset(nullptr); + server_ssl_.reset(nullptr); + } + + virtual void CreateStreams() = 0; + // Recreate the client/server identities with the specified validity period. // |not_before| and |not_after| are offsets from the current time in number // of seconds. void ResetIdentitiesWithValidity(int not_before, int not_after) { - client_stream_ = - new SSLDummyStream(this, "c2s", &client_buffer_, &server_buffer_); - server_stream_ = - new SSLDummyStream(this, "s2c", &server_buffer_, &client_buffer_); + CreateStreams(); client_ssl_.reset(rtc::SSLStreamAdapter::Create(client_stream_)); server_ssl_.reset(rtc::SSLStreamAdapter::Create(server_stream_)); @@ -329,9 +401,9 @@ class SSLStreamAdapterTestBase : public testing::Test, } } - rtc::StreamResult DataWritten(SSLDummyStream *from, const void *data, - size_t data_len, size_t *written, - int *error) { + rtc::StreamResult DataWritten(SSLDummyStreamBase *from, const void *data, + size_t data_len, size_t *written, + int *error) { // Randomly drop loss_ percent of packets if (rtc::CreateRandomId() % 100 < static_cast(loss_)) { LOG(LS_INFO) << "Randomly dropping packet, size=" << data_len; @@ -440,10 +512,12 @@ class SSLStreamAdapterTestBase : public testing::Test, virtual void TestTransfer(int size) = 0; protected: - rtc::FifoBuffer client_buffer_; - rtc::FifoBuffer server_buffer_; - SSLDummyStream *client_stream_; // freed by client_ssl_ destructor - SSLDummyStream *server_stream_; // freed by server_ssl_ destructor + std::string client_cert_pem_; + std::string client_private_key_pem_; + rtc::KeyParams client_key_type_; + rtc::KeyParams server_key_type_; + SSLDummyStreamBase *client_stream_; // freed by client_ssl_ destructor + SSLDummyStreamBase *server_stream_; // freed by server_ssl_ destructor rtc::scoped_ptr client_ssl_; rtc::scoped_ptr server_ssl_; rtc::SSLIdentity *client_identity_; // freed by client_ssl_ destructor @@ -467,7 +541,17 @@ class SSLStreamAdapterTestTLS "", false, ::testing::get<0>(GetParam()), - ::testing::get<1>(GetParam())){}; + ::testing::get<1>(GetParam())), + client_buffer_(kFifoBufferSize), + server_buffer_(kFifoBufferSize) { + } + + virtual void CreateStreams() override { + client_stream_ = + new SSLDummyStreamTLS(this, "c2s", &client_buffer_, &server_buffer_); + server_stream_ = + new SSLDummyStreamTLS(this, "s2c", &server_buffer_, &client_buffer_); + } // Test data transfer for TLS virtual void TestTransfer(int size) { @@ -562,6 +646,8 @@ class SSLStreamAdapterTestTLS } private: + rtc::FifoBuffer client_buffer_; + rtc::FifoBuffer server_buffer_; rtc::MemoryStream send_stream_; rtc::MemoryStream recv_stream_; }; @@ -576,6 +662,8 @@ class SSLStreamAdapterTestDTLS true, ::testing::get<0>(GetParam()), ::testing::get<1>(GetParam())), + client_buffer_(kBufferCapacity, kDefaultBufferSize), + server_buffer_(kBufferCapacity, kDefaultBufferSize), packet_size_(1000), count_(0), sent_(0) {} @@ -583,13 +671,22 @@ class SSLStreamAdapterTestDTLS SSLStreamAdapterTestDTLS(const std::string& cert_pem, const std::string& private_key_pem) : SSLStreamAdapterTestBase(cert_pem, private_key_pem, true), + client_buffer_(kBufferCapacity, kDefaultBufferSize), + server_buffer_(kBufferCapacity, kDefaultBufferSize), packet_size_(1000), count_(0), sent_(0) { } + virtual void CreateStreams() override { + client_stream_ = + new SSLDummyStreamDTLS(this, "c2s", &client_buffer_, &server_buffer_); + server_stream_ = + new SSLDummyStreamDTLS(this, "s2c", &server_buffer_, &client_buffer_); + } + virtual void WriteData() { unsigned char *packet = new unsigned char[1600]; - do { + while (sent_ < count_) { memset(packet, sent_ & 0xff, packet_size_); *(reinterpret_cast(packet)) = sent_; @@ -605,7 +702,7 @@ class SSLStreamAdapterTestDTLS ADD_FAILURE(); break; } - } while (sent_ < count_); + } delete [] packet; } @@ -664,6 +761,8 @@ class SSLStreamAdapterTestDTLS }; private: + BufferQueueStream client_buffer_; + BufferQueueStream server_buffer_; size_t packet_size_; int count_; int sent_; @@ -671,7 +770,7 @@ class SSLStreamAdapterTestDTLS }; -rtc::StreamResult SSLDummyStream::Write(const void* data, size_t data_len, +rtc::StreamResult SSLDummyStreamBase::Write(const void* data, size_t data_len, size_t* written, int* error) { *written = data_len; @@ -679,15 +778,13 @@ rtc::StreamResult SSLDummyStream::Write(const void* data, size_t data_len, if (first_packet_) { first_packet_ = false; - if (test_->GetLoseFirstPacket()) { + if (test_base_->GetLoseFirstPacket()) { LOG(LS_INFO) << "Losing initial packet of length " << data_len; return rtc::SR_SUCCESS; } } - return test_->DataWritten(this, data, data_len, written, error); - - return rtc::SR_SUCCESS; + return test_base_->DataWritten(this, data, data_len, written, error); }; class SSLStreamAdapterTestDTLSFromPEMStrings : public SSLStreamAdapterTestDTLS { @@ -779,23 +876,20 @@ TEST_P(SSLStreamAdapterTestDTLS, DISABLED_TestDTLSConnectWithSmallMtu) { }; // Test transfer -- trivial -// Disabled due to https://code.google.com/p/webrtc/issues/detail?id=5005 -TEST_P(SSLStreamAdapterTestDTLS, DISABLED_TestDTLSTransfer) { +TEST_P(SSLStreamAdapterTestDTLS, TestDTLSTransfer) { MAYBE_SKIP_TEST(HaveDtls); TestHandshake(); TestTransfer(100); }; -// Disabled due to https://code.google.com/p/webrtc/issues/detail?id=5005 -TEST_P(SSLStreamAdapterTestDTLS, DISABLED_TestDTLSTransferWithLoss) { +TEST_P(SSLStreamAdapterTestDTLS, TestDTLSTransferWithLoss) { MAYBE_SKIP_TEST(HaveDtls); TestHandshake(); SetLoss(10); TestTransfer(100); }; -// Disabled due to https://code.google.com/p/webrtc/issues/detail?id=5005 -TEST_P(SSLStreamAdapterTestDTLS, DISABLED_TestDTLSTransferWithDamage) { +TEST_P(SSLStreamAdapterTestDTLS, TestDTLSTransferWithDamage) { MAYBE_SKIP_TEST(HaveDtls); SetDamage(); // Must be called first because first packet // write happens at end of handshake.