diff --git a/webrtc/BUILD.gn b/webrtc/BUILD.gn index accd2c91d7..fc7885ed63 100644 --- a/webrtc/BUILD.gn +++ b/webrtc/BUILD.gn @@ -449,7 +449,6 @@ if (rtc_include_tests) { "p2p/base/dtlstransportchannel_unittest.cc", "p2p/base/fakeportallocator.h", "p2p/base/faketransportcontroller.h", - "p2p/base/jseptransport_unittest.cc", "p2p/base/p2ptransportchannel_unittest.cc", "p2p/base/port_unittest.cc", "p2p/base/portallocator_unittest.cc", @@ -464,6 +463,7 @@ if (rtc_include_tests) { "p2p/base/testrelayserver.h", "p2p/base/teststunserver.h", "p2p/base/testturnserver.h", + "p2p/base/transport_unittest.cc", "p2p/base/transportcontroller_unittest.cc", "p2p/base/transportdescriptionfactory_unittest.cc", "p2p/base/turnport_unittest.cc", diff --git a/webrtc/api/statscollector_unittest.cc b/webrtc/api/statscollector_unittest.cc index baa35d3659..10fe1e3392 100644 --- a/webrtc/api/statscollector_unittest.cc +++ b/webrtc/api/statscollector_unittest.cc @@ -1477,6 +1477,10 @@ TEST_F(StatsCollectorTest, NoCertificates) { session_stats.transport_stats[transport_stats.transport_name] = transport_stats; + // Fake transport object. + std::unique_ptr transport( + new cricket::FakeTransport(transport_stats.transport_name)); + // Configure MockWebRtcSession EXPECT_CALL(session_, GetTransportStats(_)) .WillOnce(DoAll(SetArgPointee<0>(session_stats), diff --git a/webrtc/p2p/BUILD.gn b/webrtc/p2p/BUILD.gn index a63af7ba8e..ee796b06a6 100644 --- a/webrtc/p2p/BUILD.gn +++ b/webrtc/p2p/BUILD.gn @@ -27,12 +27,13 @@ rtc_static_library("rtc_p2p") { "base/basicpacketsocketfactory.h", "base/candidate.h", "base/common.h", + "base/dtlstransport.h", "base/dtlstransportchannel.cc", "base/dtlstransportchannel.h", - "base/jseptransport.cc", - "base/jseptransport.h", "base/p2pconstants.cc", "base/p2pconstants.h", + "base/p2ptransport.cc", + "base/p2ptransport.h", "base/p2ptransportchannel.cc", "base/p2ptransportchannel.h", "base/packetsocketfactory.h", @@ -58,6 +59,8 @@ rtc_static_library("rtc_p2p") { "base/stunrequest.h", "base/tcpport.cc", "base/tcpport.h", + "base/transport.cc", + "base/transport.h", "base/transportchannel.cc", "base/transportchannel.h", "base/transportchannelimpl.h", diff --git a/webrtc/p2p/base/dtlstransport.h b/webrtc/p2p/base/dtlstransport.h new file mode 100644 index 0000000000..e59472d7f9 --- /dev/null +++ b/webrtc/p2p/base/dtlstransport.h @@ -0,0 +1,163 @@ +/* + * Copyright 2012 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef WEBRTC_P2P_BASE_DTLSTRANSPORT_H_ +#define WEBRTC_P2P_BASE_DTLSTRANSPORT_H_ + +#include + +#include "webrtc/p2p/base/dtlstransportchannel.h" +#include "webrtc/p2p/base/transport.h" + +namespace rtc { +class SSLIdentity; +} + +namespace cricket { + +class PortAllocator; + +// Base should be a descendant of cricket::Transport and have a constructor +// that takes a transport name and PortAllocator. +// +// Everything in this class should be called on the network thread. +template +class DtlsTransport : public Base { + public: + DtlsTransport(const std::string& name, + PortAllocator* allocator, + const rtc::scoped_refptr& certificate) + : Base(name, allocator), + certificate_(certificate), + secure_role_(rtc::SSL_CLIENT), + ssl_max_version_(rtc::SSL_PROTOCOL_DTLS_12) {} + + ~DtlsTransport() { + Base::DestroyAllChannels(); + } + + void SetLocalCertificate( + const rtc::scoped_refptr& certificate) override { + certificate_ = certificate; + } + bool GetLocalCertificate( + rtc::scoped_refptr* certificate) override { + if (!certificate_) + return false; + + *certificate = certificate_; + return true; + } + + bool SetSslMaxProtocolVersion(rtc::SSLProtocolVersion version) override { + ssl_max_version_ = version; + return true; + } + + bool ApplyLocalTransportDescription(TransportChannelImpl* channel, + std::string* error_desc) override { + rtc::SSLFingerprint* local_fp = + Base::local_description()->identity_fingerprint.get(); + + if (!local_fp) { + certificate_ = nullptr; + } else if (!Base::VerifyCertificateFingerprint(certificate_.get(), local_fp, + error_desc)) { + return false; + } + + if (!channel->SetLocalCertificate(certificate_)) { + return BadTransportDescription("Failed to set local identity.", + error_desc); + } + + // Apply the description in the base class. + return Base::ApplyLocalTransportDescription(channel, error_desc); + } + + bool NegotiateTransportDescription(ContentAction local_role, + std::string* error_desc) override { + if (!Base::local_description() || !Base::remote_description()) { + const std::string msg = "Local and Remote description must be set before " + "transport descriptions are negotiated"; + return BadTransportDescription(msg, error_desc); + } + rtc::SSLFingerprint* local_fp = + Base::local_description()->identity_fingerprint.get(); + rtc::SSLFingerprint* remote_fp = + Base::remote_description()->identity_fingerprint.get(); + if (remote_fp && local_fp) { + remote_fingerprint_.reset(new rtc::SSLFingerprint(*remote_fp)); + if (!Base::NegotiateRole(local_role, &secure_role_, error_desc)) { + return false; + } + } else if (local_fp && (local_role == CA_ANSWER)) { + return BadTransportDescription( + "Local fingerprint supplied when caller didn't offer DTLS.", + error_desc); + } else { + // We are not doing DTLS + remote_fingerprint_.reset(new rtc::SSLFingerprint("", nullptr, 0)); + } + // Now run the negotiation for the base class. + return Base::NegotiateTransportDescription(local_role, error_desc); + } + + DtlsTransportChannelWrapper* CreateTransportChannel(int component) override { + DtlsTransportChannelWrapper* channel = new DtlsTransportChannelWrapper( + Base::CreateTransportChannel(component)); + channel->SetSslMaxProtocolVersion(ssl_max_version_); + return channel; + } + + void DestroyTransportChannel(TransportChannelImpl* channel) override { + // Kind of ugly, but this lets us do the exact inverse of the create. + DtlsTransportChannelWrapper* dtls_channel = + static_cast(channel); + TransportChannelImpl* base_channel = dtls_channel->channel(); + delete dtls_channel; + Base::DestroyTransportChannel(base_channel); + } + + bool GetSslRole(rtc::SSLRole* ssl_role) const override { + ASSERT(ssl_role != NULL); + *ssl_role = secure_role_; + return true; + } + + private: + bool ApplyNegotiatedTransportDescription(TransportChannelImpl* channel, + std::string* error_desc) override { + // Set ssl role. Role must be set before fingerprint is applied, which + // initiates DTLS setup. + if (!channel->SetSslRole(secure_role_)) { + return BadTransportDescription("Failed to set ssl role for the channel.", + error_desc); + } + // Apply remote fingerprint. + if (!channel->SetRemoteFingerprint(remote_fingerprint_->algorithm, + reinterpret_cast( + remote_fingerprint_->digest.data()), + remote_fingerprint_->digest.size())) { + return BadTransportDescription("Failed to apply remote fingerprint.", + error_desc); + } + return Base::ApplyNegotiatedTransportDescription(channel, error_desc); + } + + rtc::scoped_refptr certificate_; + rtc::SSLRole secure_role_; + rtc::SSLProtocolVersion ssl_max_version_; + std::unique_ptr remote_fingerprint_; +}; + +} // namespace cricket + +#endif // WEBRTC_P2P_BASE_DTLSTRANSPORT_H_ diff --git a/webrtc/p2p/base/dtlstransportchannel_unittest.cc b/webrtc/p2p/base/dtlstransportchannel_unittest.cc index 3efc1e7df9..1d8802dd55 100644 --- a/webrtc/p2p/base/dtlstransportchannel_unittest.cc +++ b/webrtc/p2p/base/dtlstransportchannel_unittest.cc @@ -11,7 +11,7 @@ #include #include -#include "webrtc/p2p/base/dtlstransportchannel.h" +#include "webrtc/p2p/base/dtlstransport.h" #include "webrtc/p2p/base/faketransportcontroller.h" #include "webrtc/p2p/base/packettransportinterface.h" #include "webrtc/base/common.h" @@ -63,9 +63,6 @@ using cricket::ConnectionRole; enum Flags { NF_REOFFER = 0x1, NF_EXPECT_FAILURE = 0x2 }; -// TODO(deadbeef): Remove the dependency on JsepTransport. This test should be -// testing DtlsTransportChannel by itself, calling methods to set the -// configuration directly instead of negotiating TransportDescriptions. class DtlsTestClient : public sigslot::has_slots<> { public: DtlsTestClient(const std::string& name) : name_(name) {} @@ -82,25 +79,23 @@ class DtlsTestClient : public sigslot::has_slots<> { use_dtls_srtp_ = true; } void SetupMaxProtocolVersion(rtc::SSLProtocolVersion version) { + ASSERT(!transport_); ssl_max_version_ = version; } void SetupChannels(int count, cricket::IceRole role, int async_delay_ms = 0) { - transport_.reset( - new cricket::JsepTransport("dtls content name", certificate_)); - for (int i = 0; i < count; ++i) { - cricket::FakeTransportChannel* fake_ice_channel = - new cricket::FakeTransportChannel(transport_->mid(), i); - fake_ice_channel->SetAsync(true); - fake_ice_channel->SetAsyncDelay(async_delay_ms); - // Hook the raw packets so that we can verify they are encrypted. - fake_ice_channel->SignalReadPacket.connect( - this, &DtlsTestClient::OnFakeTransportChannelReadPacket); + transport_.reset(new cricket::DtlsTransport( + "dtls content name", nullptr, certificate_)); + transport_->SetAsync(true); + transport_->SetAsyncDelay(async_delay_ms); + transport_->SetIceRole(role); + transport_->SetIceTiebreaker( + (role == cricket::ICEROLE_CONTROLLING) ? 1 : 2); + for (int i = 0; i < count; ++i) { cricket::DtlsTransportChannelWrapper* channel = - new cricket::DtlsTransportChannelWrapper(fake_ice_channel); - channel->SetLocalCertificate(certificate_); - channel->SetIceRole(role); - channel->SetIceTiebreaker((role == cricket::ICEROLE_CONTROLLING) ? 1 : 2); + static_cast( + transport_->CreateChannel(i)); + ASSERT_TRUE(channel != NULL); channel->SetSslMaxProtocolVersion(ssl_max_version_); channel->SignalWritableState.connect(this, &DtlsTestClient::OnTransportChannelWritableState); @@ -108,32 +103,27 @@ class DtlsTestClient : public sigslot::has_slots<> { &DtlsTestClient::OnTransportChannelReadPacket); channel->SignalSentPacket.connect( this, &DtlsTestClient::OnTransportChannelSentPacket); - channels_.push_back( - std::unique_ptr(channel)); - fake_channels_.push_back( - std::unique_ptr(fake_ice_channel)); - transport_->AddChannel(channel, i); + channels_.push_back(channel); + + // Hook the raw packets so that we can verify they are encrypted. + channel->channel()->SignalReadPacket.connect( + this, &DtlsTestClient::OnFakeTransportChannelReadPacket); } } - cricket::JsepTransport* transport() { return transport_.get(); } + cricket::Transport* transport() { return transport_.get(); } cricket::FakeTransportChannel* GetFakeChannel(int component) { - for (const auto& ch : fake_channels_) { - if (ch->component() == component) { - return ch.get(); - } - } - return nullptr; + cricket::TransportChannelImpl* ch = transport_->GetChannel(component); + cricket::DtlsTransportChannelWrapper* wrapper = + static_cast(ch); + return (wrapper) ? + static_cast(wrapper->channel()) : NULL; } cricket::DtlsTransportChannelWrapper* GetDtlsChannel(int component) { - for (const auto& ch : channels_) { - if (ch->component() == component) { - return ch.get(); - } - } - return nullptr; + cricket::TransportChannelImpl* ch = transport_->GetChannel(component); + return static_cast(ch); } // Offer DTLS if we have an identity; pass in a remote fingerprint only if @@ -152,7 +142,7 @@ class DtlsTestClient : public sigslot::has_slots<> { std::vector ciphers; ciphers.push_back(rtc::SRTP_AES128_CM_SHA1_80); // SRTP ciphers will be set only in the beginning. - for (const auto& channel : channels_) { + for (cricket::DtlsTransportChannelWrapper* channel : channels_) { EXPECT_TRUE(channel->SetSrtpCryptoSuites(ciphers)); } } @@ -212,10 +202,7 @@ class DtlsTestClient : public sigslot::has_slots<> { } bool Connect(DtlsTestClient* peer, bool asymmetric) { - for (auto& channel : fake_channels_) { - channel->SetDestination(peer->GetFakeChannel(channel->component()), - asymmetric); - } + transport_->SetDestination(peer->transport_.get(), asymmetric); return true; } @@ -223,7 +210,7 @@ class DtlsTestClient : public sigslot::has_slots<> { if (channels_.empty()) { return false; } - for (const auto& channel : channels_) { + for (cricket::DtlsTransportChannelWrapper* channel : channels_) { if (!channel->writable()) { return false; } @@ -235,7 +222,7 @@ class DtlsTestClient : public sigslot::has_slots<> { if (channels_.empty()) { return false; } - for (const auto& channel : channels_) { + for (cricket::DtlsTransportChannelWrapper* channel : channels_) { if (!channel->channel()->writable()) { return false; } @@ -269,10 +256,11 @@ class DtlsTestClient : public sigslot::has_slots<> { } void CheckSrtp(int expected_crypto_suite) { - for (const auto& channel : channels_) { + for (std::vector::iterator it = + channels_.begin(); it != channels_.end(); ++it) { int crypto_suite; - bool rv = channel->GetSrtpCryptoSuite(&crypto_suite); + bool rv = (*it)->GetSrtpCryptoSuite(&crypto_suite); if (negotiated_dtls() && expected_crypto_suite) { ASSERT_TRUE(rv); @@ -284,10 +272,11 @@ class DtlsTestClient : public sigslot::has_slots<> { } void CheckSsl() { - for (const auto& channel : channels_) { + for (std::vector::iterator it = + channels_.begin(); it != channels_.end(); ++it) { int cipher; - bool rv = channel->GetSslCipherSuite(&cipher); + bool rv = (*it)->GetSslCipherSuite(&cipher); if (negotiated_dtls()) { ASSERT_TRUE(rv); @@ -434,9 +423,8 @@ class DtlsTestClient : public sigslot::has_slots<> { private: std::string name_; rtc::scoped_refptr certificate_; - std::vector> fake_channels_; - std::vector> channels_; - std::unique_ptr transport_; + std::unique_ptr transport_; + std::vector channels_; size_t packet_size_ = 0u; std::set received_; bool use_dtls_srtp_ = false; @@ -833,8 +821,8 @@ TEST_F(DtlsTransportChannelTest, TestDtlsSetupWithLegacyAsAnswerer) { NegotiateWithLegacy(); rtc::SSLRole channel1_role; rtc::SSLRole channel2_role; - client1_.transport()->GetSslRole(&channel1_role); - client2_.transport()->GetSslRole(&channel2_role); + EXPECT_TRUE(client1_.transport()->GetSslRole(&channel1_role)); + EXPECT_TRUE(client2_.transport()->GetSslRole(&channel2_role)); EXPECT_EQ(rtc::SSL_SERVER, channel1_role); EXPECT_EQ(rtc::SSL_CLIENT, channel2_role); } @@ -944,8 +932,8 @@ TEST_F(DtlsTransportChannelTest, TestCertificatesBeforeConnect) { ASSERT_TRUE(client2_.transport()->GetLocalCertificate(&certificate2)); ASSERT_NE(certificate1->ssl_certificate().ToPEMString(), certificate2->ssl_certificate().ToPEMString()); - ASSERT_FALSE(client1_.GetDtlsChannel(0)->GetRemoteSSLCertificate()); - ASSERT_FALSE(client2_.GetDtlsChannel(0)->GetRemoteSSLCertificate()); + ASSERT_FALSE(client1_.transport()->GetRemoteSSLCertificate()); + ASSERT_FALSE(client2_.transport()->GetRemoteSSLCertificate()); } // Test Certificates state after connection. @@ -965,12 +953,12 @@ TEST_F(DtlsTransportChannelTest, TestCertificatesAfterConnect) { // Each side's remote certificate is the other side's local certificate. std::unique_ptr remote_cert1 = - client1_.GetDtlsChannel(0)->GetRemoteSSLCertificate(); + client1_.transport()->GetRemoteSSLCertificate(); ASSERT_TRUE(remote_cert1); ASSERT_EQ(remote_cert1->ToPEMString(), certificate2->ssl_certificate().ToPEMString()); std::unique_ptr remote_cert2 = - client2_.GetDtlsChannel(0)->GetRemoteSSLCertificate(); + client2_.transport()->GetRemoteSSLCertificate(); ASSERT_TRUE(remote_cert2); ASSERT_EQ(remote_cert2->ToPEMString(), certificate1->ssl_certificate().ToPEMString()); diff --git a/webrtc/p2p/base/faketransportcontroller.h b/webrtc/p2p/base/faketransportcontroller.h index d42a93ddae..9598b822ce 100644 --- a/webrtc/p2p/base/faketransportcontroller.h +++ b/webrtc/p2p/base/faketransportcontroller.h @@ -17,6 +17,7 @@ #include #include "webrtc/p2p/base/candidatepairinterface.h" +#include "webrtc/p2p/base/transport.h" #include "webrtc/p2p/base/transportchannel.h" #include "webrtc/p2p/base/transportcontroller.h" #include "webrtc/p2p/base/transportchannelimpl.h" @@ -34,6 +35,8 @@ namespace cricket { +class FakeTransport; + namespace { struct PacketMessageData : public rtc::MessageData { PacketMessageData(const char* data, size_t len) : packet(data, len) {} @@ -340,6 +343,146 @@ class FakeTransportChannel : public TransportChannelImpl, bool had_connection_ = false; }; +// Fake transport class, which can be passed to anything that needs a Transport. +// Can be informed of another FakeTransport via SetDestination (low-tech way +// of doing candidates) +class FakeTransport : public Transport { + public: + typedef std::map ChannelMap; + + explicit FakeTransport(const std::string& name) : Transport(name, nullptr) {} + + // Note that we only have a constructor with the allocator parameter so it can + // be wrapped by a DtlsTransport. + FakeTransport(const std::string& name, PortAllocator* allocator) + : Transport(name, nullptr) {} + + ~FakeTransport() { DestroyAllChannels(); } + + const ChannelMap& channels() const { return channels_; } + + // If async, will send packets by "Post"-ing to message queue instead of + // synchronously "Send"-ing. + void SetAsync(bool async) { async_ = async; } + void SetAsyncDelay(int delay_ms) { async_delay_ms_ = delay_ms; } + + // If |asymmetric| is true, only set the destination for this transport, and + // not |dest|. + void SetDestination(FakeTransport* dest, bool asymmetric = false) { + dest_ = dest; + for (const auto& kv : channels_) { + kv.second->SetLocalCertificate(certificate_); + SetChannelDestination(kv.first, kv.second, asymmetric); + } + } + + void SetWritable(bool writable) { + for (const auto& kv : channels_) { + kv.second->SetWritable(writable); + } + } + + void SetLocalCertificate( + const rtc::scoped_refptr& certificate) override { + certificate_ = certificate; + } + bool GetLocalCertificate( + rtc::scoped_refptr* certificate) override { + if (!certificate_) + return false; + + *certificate = certificate_; + return true; + } + + bool GetSslRole(rtc::SSLRole* role) const override { + if (channels_.empty()) { + return false; + } + return channels_.begin()->second->GetSslRole(role); + } + + bool SetSslMaxProtocolVersion(rtc::SSLProtocolVersion version) override { + ssl_max_version_ = version; + for (const auto& kv : channels_) { + kv.second->set_ssl_max_protocol_version(ssl_max_version_); + } + return true; + } + rtc::SSLProtocolVersion ssl_max_protocol_version() const { + return ssl_max_version_; + } + + using Transport::local_description; + using Transport::remote_description; + using Transport::VerifyCertificateFingerprint; + using Transport::NegotiateRole; + + protected: + TransportChannelImpl* CreateTransportChannel(int component) override { + if (channels_.find(component) != channels_.end()) { + return nullptr; + } + FakeTransportChannel* channel = new FakeTransportChannel(name(), component); + channel->set_ssl_max_protocol_version(ssl_max_version_); + channel->SetAsync(async_); + channel->SetAsyncDelay(async_delay_ms_); + SetChannelDestination(component, channel, false); + channels_[component] = channel; + return channel; + } + + void DestroyTransportChannel(TransportChannelImpl* channel) override { + channels_.erase(channel->component()); + delete channel; + } + + private: + FakeTransportChannel* GetFakeChannel(int component) { + auto it = channels_.find(component); + return (it != channels_.end()) ? it->second : nullptr; + } + + void SetChannelDestination(int component, + FakeTransportChannel* channel, + bool asymmetric) { + FakeTransportChannel* dest_channel = nullptr; + if (dest_) { + dest_channel = dest_->GetFakeChannel(component); + if (dest_channel && !asymmetric) { + dest_channel->SetLocalCertificate(dest_->certificate_); + } + } + channel->SetDestination(dest_channel, asymmetric); + } + + // Note, this is distinct from the Channel map owned by Transport. + // This map just tracks the FakeTransportChannels created by this class. + // It's mainly needed so that we can access a FakeTransportChannel directly, + // even if wrapped by a DtlsTransportChannelWrapper. + ChannelMap channels_; + FakeTransport* dest_ = nullptr; + bool async_ = false; + int async_delay_ms_ = 0; + rtc::scoped_refptr certificate_; + rtc::SSLProtocolVersion ssl_max_version_ = rtc::SSL_PROTOCOL_DTLS_12; +}; + +#ifdef HAVE_QUIC +class FakeQuicTransport : public QuicTransport { + public: + FakeQuicTransport(const std::string& transport_name) + : QuicTransport(transport_name, nullptr, nullptr) {} + + protected: + QuicTransportChannel* CreateTransportChannel(int component) override { + FakeTransportChannel* fake_ice_transport_channel = + new FakeTransportChannel(name(), component); + return new QuicTransportChannel(fake_ice_transport_channel); + } +}; +#endif + // Fake candidate pair class, which can be passed to BaseChannel for testing // purposes. class FakeCandidatePair : public CandidatePairInterface { @@ -369,66 +512,52 @@ class FakeTransportController : public TransportController { FakeTransportController() : TransportController(rtc::Thread::Current(), rtc::Thread::Current(), - nullptr) {} + nullptr), + fail_create_channel_(false) {} explicit FakeTransportController(bool redetermine_role_on_ice_restart) : TransportController(rtc::Thread::Current(), rtc::Thread::Current(), nullptr, - redetermine_role_on_ice_restart) {} + redetermine_role_on_ice_restart), + fail_create_channel_(false) {} explicit FakeTransportController(IceRole role) : TransportController(rtc::Thread::Current(), rtc::Thread::Current(), - nullptr) { + nullptr), + fail_create_channel_(false) { SetIceRole(role); } explicit FakeTransportController(rtc::Thread* network_thread) - : TransportController(rtc::Thread::Current(), network_thread, nullptr) {} + : TransportController(rtc::Thread::Current(), network_thread, nullptr), + fail_create_channel_(false) {} FakeTransportController(rtc::Thread* network_thread, IceRole role) - : TransportController(rtc::Thread::Current(), network_thread, nullptr) { + : TransportController(rtc::Thread::Current(), network_thread, nullptr), + fail_create_channel_(false) { SetIceRole(role); } - FakeTransportChannel* GetFakeTransportChannel_n( - const std::string& transport_name, - int component) { - return static_cast( - get_channel_for_testing(transport_name, component)); + FakeTransport* GetTransport_n(const std::string& transport_name) { + return static_cast( + TransportController::GetTransport_n(transport_name)); } - // Simulate the exchange of transport descriptions, and the gathering and - // exchange of ICE candidates. void Connect(FakeTransportController* dest) { - for (const std::string& transport_name : transport_names_for_testing()) { - TransportDescription local_desc( - std::vector(), - rtc::CreateRandomString(cricket::ICE_UFRAG_LENGTH), - rtc::CreateRandomString(cricket::ICE_PWD_LENGTH), - cricket::ICEMODE_FULL, cricket::CONNECTIONROLE_NONE, nullptr); - TransportDescription remote_desc( - std::vector(), - rtc::CreateRandomString(cricket::ICE_UFRAG_LENGTH), - rtc::CreateRandomString(cricket::ICE_PWD_LENGTH), - cricket::ICEMODE_FULL, cricket::CONNECTIONROLE_NONE, nullptr); - std::string err; - SetLocalTransportDescription(transport_name, local_desc, - cricket::CA_OFFER, &err); - dest->SetRemoteTransportDescription(transport_name, local_desc, - cricket::CA_OFFER, &err); - dest->SetLocalTransportDescription(transport_name, remote_desc, - cricket::CA_ANSWER, &err); - SetRemoteTransportDescription(transport_name, remote_desc, - cricket::CA_ANSWER, &err); - } - MaybeStartGathering(); - dest->MaybeStartGathering(); network_thread()->Invoke( RTC_FROM_HERE, - rtc::Bind(&FakeTransportController::SetChannelDestinations_n, this, - dest)); + rtc::Bind(&FakeTransportController::Connect_n, this, dest)); + } + + TransportChannel* CreateTransportChannel_n(const std::string& transport_name, + int component) override { + if (fail_create_channel_) { + return nullptr; + } + return TransportController::CreateTransportChannel_n(transport_name, + component); } FakeCandidatePair* CreateFakeCandidatePair( @@ -443,35 +572,50 @@ class FakeTransportController : public TransportController { return new FakeCandidatePair(local_candidate, remote_candidate); } - protected: - // The ICE channel is never actually used by TransportController directly, - // since (currently) the DTLS channel pretends to be both ICE + DTLS. This - // will change when we get rid of TransportChannelImpl. - TransportChannelImpl* CreateIceTransportChannel_n( - const std::string& transport_name, - int component) override { - return nullptr; + void set_fail_channel_creation(bool fail_channel_creation) { + fail_create_channel_ = fail_channel_creation; } - TransportChannelImpl* CreateDtlsTransportChannel_n( - const std::string& transport_name, - int component, - TransportChannelImpl*) override { - return new FakeTransportChannel(transport_name, component); + protected: + Transport* CreateTransport_n(const std::string& transport_name) override { +#ifdef HAVE_QUIC + if (quic()) { + return new FakeQuicTransport(transport_name); + } +#endif + return new FakeTransport(transport_name); + } + + void Connect_n(FakeTransportController* dest) { + // Simulate the exchange of candidates. + ConnectChannels_n(); + dest->ConnectChannels_n(); + for (auto& kv : transports()) { + FakeTransport* transport = static_cast(kv.second); + transport->SetDestination(dest->GetTransport_n(kv.first)); + } + } + + void ConnectChannels_n() { + TransportDescription faketransport_desc( + std::vector(), + rtc::CreateRandomString(cricket::ICE_UFRAG_LENGTH), + rtc::CreateRandomString(cricket::ICE_PWD_LENGTH), cricket::ICEMODE_FULL, + cricket::CONNECTIONROLE_NONE, nullptr); + for (auto& kv : transports()) { + FakeTransport* transport = static_cast(kv.second); + // Set local transport description for FakeTransport before connecting. + // Otherwise, the RTC_CHECK in Transport.ConnectChannel will fail. + if (!transport->local_description()) { + transport->SetLocalTransportDescription(faketransport_desc, + cricket::CA_OFFER, nullptr); + } + transport->MaybeStartGathering(); + } } private: - void SetChannelDestinations_n(FakeTransportController* dest) { - for (TransportChannelImpl* tc : channels_for_testing()) { - FakeTransportChannel* local = static_cast(tc); - FakeTransportChannel* remote = dest->GetFakeTransportChannel_n( - local->transport_name(), local->component()); - if (remote) { - bool asymmetric = false; - local->SetDestination(remote, asymmetric); - } - } - } + bool fail_create_channel_; }; } // namespace cricket diff --git a/webrtc/p2p/base/p2ptransport.cc b/webrtc/p2p/base/p2ptransport.cc new file mode 100644 index 0000000000..1ad2a6faa3 --- /dev/null +++ b/webrtc/p2p/base/p2ptransport.cc @@ -0,0 +1,38 @@ +/* + * Copyright 2004 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/p2p/base/p2ptransport.h" + +#include + +#include "webrtc/base/base64.h" +#include "webrtc/base/common.h" +#include "webrtc/base/stringencode.h" +#include "webrtc/base/stringutils.h" +#include "webrtc/p2p/base/p2ptransportchannel.h" + +namespace cricket { + +P2PTransport::P2PTransport(const std::string& name, PortAllocator* allocator) + : Transport(name, allocator) {} + +P2PTransport::~P2PTransport() { + DestroyAllChannels(); +} + +TransportChannelImpl* P2PTransport::CreateTransportChannel(int component) { + return new P2PTransportChannel(name(), component, port_allocator()); +} + +void P2PTransport::DestroyTransportChannel(TransportChannelImpl* channel) { + delete channel; +} + +} // namespace cricket diff --git a/webrtc/p2p/base/p2ptransport.h b/webrtc/p2p/base/p2ptransport.h new file mode 100644 index 0000000000..87353356e2 --- /dev/null +++ b/webrtc/p2p/base/p2ptransport.h @@ -0,0 +1,39 @@ +/* + * Copyright 2004 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef WEBRTC_P2P_BASE_P2PTRANSPORT_H_ +#define WEBRTC_P2P_BASE_P2PTRANSPORT_H_ + +#include + +#include "webrtc/base/constructormagic.h" +#include "webrtc/p2p/base/transport.h" + +namespace cricket { + +// Everything in this class should be called on the network thread. +class P2PTransport : public Transport { + public: + P2PTransport(const std::string& name, PortAllocator* allocator); + virtual ~P2PTransport(); + + protected: + // Creates and destroys P2PTransportChannel. + virtual TransportChannelImpl* CreateTransportChannel(int component); + virtual void DestroyTransportChannel(TransportChannelImpl* channel); + + friend class P2PTransportChannel; + + RTC_DISALLOW_COPY_AND_ASSIGN(P2PTransport); +}; + +} // namespace cricket + +#endif // WEBRTC_P2P_BASE_P2PTRANSPORT_H_ diff --git a/webrtc/p2p/base/p2ptransportchannel.cc b/webrtc/p2p/base/p2ptransportchannel.cc index 5f858c4858..e1a955c696 100644 --- a/webrtc/p2p/base/p2ptransportchannel.cc +++ b/webrtc/p2p/base/p2ptransportchannel.cc @@ -97,6 +97,12 @@ static constexpr int DEFAULT_BACKUP_CONNECTION_PING_INTERVAL = 25 * 1000; static constexpr int a_is_better = 1; static constexpr int b_is_better = -1; +P2PTransportChannel::P2PTransportChannel(const std::string& transport_name, + int component, + P2PTransport* transport, + PortAllocator* allocator) + : P2PTransportChannel(transport_name, component, allocator) {} + P2PTransportChannel::P2PTransportChannel(const std::string& transport_name, int component, PortAllocator* allocator) diff --git a/webrtc/p2p/base/p2ptransportchannel.h b/webrtc/p2p/base/p2ptransportchannel.h index e538dc2c97..cdb83005b6 100644 --- a/webrtc/p2p/base/p2ptransportchannel.h +++ b/webrtc/p2p/base/p2ptransportchannel.h @@ -29,6 +29,7 @@ #include "webrtc/base/constructormagic.h" #include "webrtc/p2p/base/candidate.h" #include "webrtc/p2p/base/candidatepairinterface.h" +#include "webrtc/p2p/base/p2ptransport.h" #include "webrtc/p2p/base/portallocator.h" #include "webrtc/p2p/base/portinterface.h" #include "webrtc/p2p/base/transportchannelimpl.h" @@ -66,6 +67,12 @@ class P2PTransportChannel : public TransportChannelImpl, P2PTransportChannel(const std::string& transport_name, int component, PortAllocator* allocator); + // TODO(mikescarlett): Deprecated. Remove when Chromium's + // IceTransportChannel does not depend on this. + P2PTransportChannel(const std::string& transport_name, + int component, + P2PTransport* transport, + PortAllocator* allocator); virtual ~P2PTransportChannel(); // From TransportChannelImpl: diff --git a/webrtc/p2p/base/p2ptransportchannel_unittest.cc b/webrtc/p2p/base/p2ptransportchannel_unittest.cc index 4a16e2f4e8..fa6fac17b5 100644 --- a/webrtc/p2p/base/p2ptransportchannel_unittest.cc +++ b/webrtc/p2p/base/p2ptransportchannel_unittest.cc @@ -4201,7 +4201,7 @@ class P2PTransportChannelMostLikelyToWorkFirstTest P2PTransportChannel& StartTransportChannel( bool prioritize_most_likely_to_work, int stable_writable_connection_ping_interval) { - channel_.reset(new P2PTransportChannel("checks", 1, allocator())); + channel_.reset(new P2PTransportChannel("checks", 1, nullptr, allocator())); IceConfig config = channel_->config(); config.prioritize_most_likely_candidate_pairs = prioritize_most_likely_to_work; diff --git a/webrtc/p2p/base/port.h b/webrtc/p2p/base/port.h index c7e783721b..3c84a14efd 100644 --- a/webrtc/p2p/base/port.h +++ b/webrtc/p2p/base/port.h @@ -19,11 +19,11 @@ #include "webrtc/p2p/base/candidate.h" #include "webrtc/p2p/base/candidatepairinterface.h" -#include "webrtc/p2p/base/jseptransport.h" #include "webrtc/p2p/base/packetsocketfactory.h" #include "webrtc/p2p/base/portinterface.h" #include "webrtc/p2p/base/stun.h" #include "webrtc/p2p/base/stunrequest.h" +#include "webrtc/p2p/base/transport.h" #include "webrtc/base/asyncpacketsocket.h" #include "webrtc/base/network.h" #include "webrtc/base/proxyinfo.h" diff --git a/webrtc/p2p/base/port_unittest.cc b/webrtc/p2p/base/port_unittest.cc index 028830e62c..d5767ddd2b 100644 --- a/webrtc/p2p/base/port_unittest.cc +++ b/webrtc/p2p/base/port_unittest.cc @@ -11,13 +11,13 @@ #include #include "webrtc/p2p/base/basicpacketsocketfactory.h" -#include "webrtc/p2p/base/jseptransport.h" #include "webrtc/p2p/base/relayport.h" #include "webrtc/p2p/base/stunport.h" #include "webrtc/p2p/base/tcpport.h" #include "webrtc/p2p/base/testrelayserver.h" #include "webrtc/p2p/base/teststunserver.h" #include "webrtc/p2p/base/testturnserver.h" +#include "webrtc/p2p/base/transport.h" #include "webrtc/p2p/base/turnport.h" #include "webrtc/base/arraysize.h" #include "webrtc/base/buffer.h" diff --git a/webrtc/p2p/base/portinterface.h b/webrtc/p2p/base/portinterface.h index e08e791f1a..38945f9e13 100644 --- a/webrtc/p2p/base/portinterface.h +++ b/webrtc/p2p/base/portinterface.h @@ -13,9 +13,9 @@ #include +#include "webrtc/p2p/base/transport.h" #include "webrtc/base/asyncpacketsocket.h" #include "webrtc/base/socketaddress.h" -#include "webrtc/p2p/base/jseptransport.h" namespace rtc { class Network; diff --git a/webrtc/p2p/base/jseptransport.cc b/webrtc/p2p/base/transport.cc similarity index 62% rename from webrtc/p2p/base/jseptransport.cc rename to webrtc/p2p/base/transport.cc index abfe0449a4..7bd14c4cad 100644 --- a/webrtc/p2p/base/jseptransport.cc +++ b/webrtc/p2p/base/transport.cc @@ -11,12 +11,10 @@ #include #include // for std::pair -#include "webrtc/p2p/base/jseptransport.h" +#include "webrtc/p2p/base/transport.h" #include "webrtc/p2p/base/candidate.h" -#include "webrtc/p2p/base/dtlstransportchannel.h" #include "webrtc/p2p/base/p2pconstants.h" -#include "webrtc/p2p/base/p2ptransportchannel.h" #include "webrtc/p2p/base/port.h" #include "webrtc/p2p/base/transportchannelimpl.h" #include "webrtc/base/bind.h" @@ -60,7 +58,169 @@ bool IceCredentialsChanged(const std::string& old_ufrag, return (old_ufrag != new_ufrag) || (old_pwd != new_pwd); } -bool VerifyCandidate(const Candidate& cand, std::string* error) { +Transport::Transport(const std::string& name, PortAllocator* allocator) + : name_(name), allocator_(allocator) {} + +Transport::~Transport() { + RTC_DCHECK(channels_destroyed_); +} + +void Transport::SetIceRole(IceRole role) { + ice_role_ = role; + for (const auto& kv : channels_) { + kv.second->SetIceRole(ice_role_); + } +} + +std::unique_ptr Transport::GetRemoteSSLCertificate() { + if (channels_.empty()) { + return nullptr; + } + + auto iter = channels_.begin(); + return iter->second->GetRemoteSSLCertificate(); +} + +void Transport::SetIceConfig(const IceConfig& config) { + ice_config_ = config; + for (const auto& kv : channels_) { + kv.second->SetIceConfig(ice_config_); + } +} + +bool Transport::SetLocalTransportDescription( + const TransportDescription& description, + ContentAction action, + std::string* error_desc) { + bool ret = true; + + if (!VerifyIceParams(description)) { + return BadTransportDescription("Invalid ice-ufrag or ice-pwd length", + error_desc); + } + + local_description_.reset(new TransportDescription(description)); + + for (const auto& kv : channels_) { + ret &= ApplyLocalTransportDescription(kv.second, error_desc); + } + if (!ret) { + return false; + } + + // If PRANSWER/ANSWER is set, we should decide transport protocol type. + if (action == CA_PRANSWER || action == CA_ANSWER) { + ret &= NegotiateTransportDescription(action, error_desc); + } + if (ret) { + local_description_set_ = true; + } + + return ret; +} + +bool Transport::SetRemoteTransportDescription( + const TransportDescription& description, + ContentAction action, + std::string* error_desc) { + bool ret = true; + + if (!VerifyIceParams(description)) { + return BadTransportDescription("Invalid ice-ufrag or ice-pwd length", + error_desc); + } + + remote_description_.reset(new TransportDescription(description)); + for (const auto& kv : channels_) { + ret &= ApplyRemoteTransportDescription(kv.second, error_desc); + } + + // If PRANSWER/ANSWER is set, we should decide transport protocol type. + if (action == CA_PRANSWER || action == CA_ANSWER) { + ret = NegotiateTransportDescription(CA_OFFER, error_desc); + } + if (ret) { + remote_description_set_ = true; + } + + return ret; +} + +TransportChannelImpl* Transport::CreateChannel(int component) { + TransportChannelImpl* channel; + + // Create the entry if it does not exist. + bool channel_exists = false; + auto iter = channels_.find(component); + if (iter == channels_.end()) { + channel = CreateTransportChannel(component); + channels_.insert(std::pair(component, channel)); + } else { + channel = iter->second; + channel_exists = true; + } + + channels_destroyed_ = false; + + if (channel_exists) { + // If this is an existing channel, we should just return it. + return channel; + } + + // Push down our transport state to the new channel. + channel->SetIceRole(ice_role_); + channel->SetIceTiebreaker(tiebreaker_); + channel->SetIceConfig(ice_config_); + // TODO(ronghuawu): Change CreateChannel to be able to return error since + // below Apply**Description calls can fail. + if (local_description_) + ApplyLocalTransportDescription(channel, nullptr); + if (remote_description_) + ApplyRemoteTransportDescription(channel, nullptr); + if (local_description_ && remote_description_) + ApplyNegotiatedTransportDescription(channel, nullptr); + + return channel; +} + +TransportChannelImpl* Transport::GetChannel(int component) { + auto iter = channels_.find(component); + return (iter != channels_.end()) ? iter->second : nullptr; +} + +bool Transport::HasChannels() { + return !channels_.empty(); +} + +void Transport::DestroyChannel(int component) { + auto iter = channels_.find(component); + if (iter == channels_.end()) + return; + + TransportChannelImpl* channel = iter->second; + channels_.erase(iter); + DestroyTransportChannel(channel); +} + +void Transport::MaybeStartGathering() { + CallChannels(&TransportChannelImpl::MaybeStartGathering); +} + +void Transport::DestroyAllChannels() { + for (const auto& kv : channels_) { + DestroyTransportChannel(kv.second); + } + channels_.clear(); + channels_destroyed_ = true; +} + +void Transport::CallChannels(TransportChannelFunc func) { + for (const auto& kv : channels_) { + (kv.second->*func)(); + } +} + +bool Transport::VerifyCandidate(const Candidate& cand, std::string* error) { // No address zero. if (cand.address().IsNil() || cand.address().IsAnyIP()) { *error = "candidate has address of zero"; @@ -88,10 +248,17 @@ bool VerifyCandidate(const Candidate& cand, std::string* error) { } } + if (!HasChannel(cand.component())) { + *error = "Candidate has an unknown component: " + cand.ToString() + + " for content: " + name(); + return false; + } + return true; } -bool VerifyCandidates(const Candidates& candidates, std::string* error) { +bool Transport::VerifyCandidates(const Candidates& candidates, + std::string* error) { for (const Candidate& candidate : candidates) { if (!VerifyCandidate(candidate, error)) { return false; @@ -100,151 +267,14 @@ bool VerifyCandidates(const Candidates& candidates, std::string* error) { return true; } -JsepTransport::JsepTransport( - const std::string& mid, - const rtc::scoped_refptr& certificate) - : mid_(mid), certificate_(certificate) {} -bool JsepTransport::AddChannel(TransportChannelImpl* dtls, int component) { - if (channels_.find(component) != channels_.end()) { - LOG(LS_ERROR) << "Adding channel for component " << component << " twice."; - return false; - } - channels_[component] = dtls; - // Something's wrong if a channel is being added after a description is set. - // This may currently occur if rtcp-mux is negotiated, then a new m= section - // is added in a later offer/answer. But this is suboptimal and should be - // changed; we shouldn't support going from muxed to non-muxed. - // TODO(deadbeef): Once this is fixed, make the warning an error, and remove - // the calls to "ApplyXTransportDescription" below. - if (local_description_set_ || remote_description_set_) { - LOG(LS_WARNING) << "Adding new transport channel after " - "transport description already applied."; - } - bool ret = true; - std::string err; - if (local_description_set_) { - ret &= ApplyLocalTransportDescription(channels_[component], &err); - } - if (remote_description_set_) { - ret &= ApplyRemoteTransportDescription(channels_[component], &err); - } - if (local_description_set_ && remote_description_set_) { - ret &= ApplyNegotiatedTransportDescription(channels_[component], &err); - } - return ret; -} - -bool JsepTransport::RemoveChannel(int component) { - auto it = channels_.find(component); - if (it == channels_.end()) { - LOG(LS_ERROR) << "Trying to remove channel for component " << component - << ", which doesn't exist."; - return false; - } - channels_.erase(component); - return true; -} - -bool JsepTransport::HasChannels() const { - return !channels_.empty(); -} - -void JsepTransport::SetLocalCertificate( - const rtc::scoped_refptr& certificate) { - certificate_ = certificate; -} - -bool JsepTransport::GetLocalCertificate( - rtc::scoped_refptr* certificate) const { - if (!certificate_) { - return false; - } - - *certificate = certificate_; - return true; -} - -bool JsepTransport::SetLocalTransportDescription( - const TransportDescription& description, - ContentAction action, - std::string* error_desc) { - bool ret = true; - - if (!VerifyIceParams(description)) { - return BadTransportDescription("Invalid ice-ufrag or ice-pwd length", - error_desc); - } - - local_description_.reset(new TransportDescription(description)); - - rtc::SSLFingerprint* local_fp = - local_description_->identity_fingerprint.get(); - - if (!local_fp) { - certificate_ = nullptr; - } else if (!VerifyCertificateFingerprint(certificate_.get(), local_fp, - error_desc)) { - return false; - } - - for (const auto& kv : channels_) { - ret &= ApplyLocalTransportDescription(kv.second, error_desc); - } - if (!ret) { - return false; - } - - // If PRANSWER/ANSWER is set, we should decide transport protocol type. - if (action == CA_PRANSWER || action == CA_ANSWER) { - ret &= NegotiateTransportDescription(action, error_desc); - } - if (ret) { - local_description_set_ = true; - } - - return ret; -} - -bool JsepTransport::SetRemoteTransportDescription( - const TransportDescription& description, - ContentAction action, - std::string* error_desc) { - bool ret = true; - - if (!VerifyIceParams(description)) { - return BadTransportDescription("Invalid ice-ufrag or ice-pwd length", - error_desc); - } - - remote_description_.reset(new TransportDescription(description)); - for (const auto& kv : channels_) { - ret &= ApplyRemoteTransportDescription(kv.second, error_desc); - } - - // If PRANSWER/ANSWER is set, we should decide transport protocol type. - if (action == CA_PRANSWER || action == CA_ANSWER) { - ret = NegotiateTransportDescription(CA_OFFER, error_desc); - } - if (ret) { - remote_description_set_ = true; - } - - return ret; -} - -void JsepTransport::GetSslRole(rtc::SSLRole* ssl_role) const { - RTC_DCHECK(ssl_role); - *ssl_role = secure_role_; -} - -bool JsepTransport::GetStats(TransportStats* stats) { - stats->transport_name = mid(); +bool Transport::GetStats(TransportStats* stats) { + stats->transport_name = name(); stats->channel_stats.clear(); - for (auto& kv : channels_) { + for (auto kv : channels_) { TransportChannelImpl* channel = kv.second; TransportChannelStats substats; - substats.component = kv.first; + substats.component = channel->component(); channel->GetSrtpCryptoSuite(&substats.srtp_crypto_suite); channel->GetSslCipherSuite(&substats.ssl_cipher_suite); if (!channel->GetStats(&substats.connection_infos)) { @@ -255,7 +285,88 @@ bool JsepTransport::GetStats(TransportStats* stats) { return true; } -bool JsepTransport::VerifyCertificateFingerprint( +bool Transport::AddRemoteCandidates(const std::vector& candidates, + std::string* error) { + ASSERT(!channels_destroyed_); + // Verify each candidate before passing down to the transport layer. + if (!VerifyCandidates(candidates, error)) { + return false; + } + + for (const Candidate& candidate : candidates) { + TransportChannelImpl* channel = GetChannel(candidate.component()); + if (channel != nullptr) { + channel->AddRemoteCandidate(candidate); + } + } + return true; +} + +bool Transport::RemoveRemoteCandidates(const std::vector& candidates, + std::string* error) { + ASSERT(!channels_destroyed_); + // Verify each candidate before passing down to the transport layer. + if (!VerifyCandidates(candidates, error)) { + return false; + } + + for (const Candidate& candidate : candidates) { + TransportChannelImpl* channel = GetChannel(candidate.component()); + if (channel != nullptr) { + channel->RemoveRemoteCandidate(candidate); + } + } + return true; +} + +bool Transport::ApplyLocalTransportDescription(TransportChannelImpl* ch, + std::string* error_desc) { + ch->SetIceParameters(local_description_->GetIceParameters()); + return true; +} + +bool Transport::ApplyRemoteTransportDescription(TransportChannelImpl* ch, + std::string* error_desc) { + ch->SetRemoteIceParameters(remote_description_->GetIceParameters()); + return true; +} + +bool Transport::ApplyNegotiatedTransportDescription( + TransportChannelImpl* channel, + std::string* error_desc) { + channel->SetRemoteIceMode(remote_ice_mode_); + return true; +} + +bool Transport::NegotiateTransportDescription(ContentAction local_role, + std::string* error_desc) { + // TODO(ekr@rtfm.com): This is ICE-specific stuff. Refactor into + // P2PTransport. + + // If transport is in ICEROLE_CONTROLLED and remote end point supports only + // ice_lite, this local end point should take CONTROLLING role. + if (ice_role_ == ICEROLE_CONTROLLED && + remote_description_->ice_mode == ICEMODE_LITE) { + SetIceRole(ICEROLE_CONTROLLING); + } + + // Update remote ice_mode to all existing channels. + remote_ice_mode_ = remote_description_->ice_mode; + + // Now that we have negotiated everything, push it downward. + // Note that we cache the result so that if we have race conditions + // between future SetRemote/SetLocal invocations and new channel + // creation, we have the negotiation state saved until a new + // negotiation happens. + for (const auto& kv : channels_) { + if (!ApplyNegotiatedTransportDescription(kv.second, error_desc)) { + return false; + } + } + return true; +} + +bool Transport::VerifyCertificateFingerprint( const rtc::RTCCertificate* certificate, const rtc::SSLFingerprint* fingerprint, std::string* error_desc) const { @@ -279,89 +390,11 @@ bool JsepTransport::VerifyCertificateFingerprint( return BadTransportDescription(desc.str(), error_desc); } -bool JsepTransport::ApplyLocalTransportDescription( - TransportChannelImpl* channel, - std::string* error_desc) { - channel->SetIceParameters(local_description_->GetIceParameters()); - return true; -} - -bool JsepTransport::ApplyRemoteTransportDescription( - TransportChannelImpl* channel, - std::string* error_desc) { - // Currently, all ICE-related calls still go through this DTLS channel. But - // that will change once we get rid of TransportChannelImpl, and the DTLS - // channel interface no longer includes ICE-specific methods. Then this class - // will need to call dtls->ice()->SetIceRole(), for example, assuming the Dtls - // interface will expose its inner ICE channel. - channel->SetRemoteIceParameters(remote_description_->GetIceParameters()); - channel->SetRemoteIceMode(remote_description_->ice_mode); - return true; -} - -bool JsepTransport::ApplyNegotiatedTransportDescription( - TransportChannelImpl* channel, - std::string* error_desc) { - // Set SSL role. Role must be set before fingerprint is applied, which - // initiates DTLS setup. - if (!channel->SetSslRole(secure_role_)) { - return BadTransportDescription("Failed to set SSL role for the channel.", - error_desc); - } - // Apply remote fingerprint. - if (!channel->SetRemoteFingerprint( - remote_fingerprint_->algorithm, - reinterpret_cast(remote_fingerprint_->digest.data()), - remote_fingerprint_->digest.size())) { - return BadTransportDescription("Failed to apply remote fingerprint.", - error_desc); - } - return true; -} - -bool JsepTransport::NegotiateTransportDescription(ContentAction local_role, - std::string* error_desc) { - if (!local_description_ || !remote_description_) { - const std::string msg = - "Applying an answer transport description " - "without applying any offer."; - return BadTransportDescription(msg, error_desc); - } - rtc::SSLFingerprint* local_fp = - local_description_->identity_fingerprint.get(); - rtc::SSLFingerprint* remote_fp = - remote_description_->identity_fingerprint.get(); - if (remote_fp && local_fp) { - remote_fingerprint_.reset(new rtc::SSLFingerprint(*remote_fp)); - if (!NegotiateRole(local_role, &secure_role_, error_desc)) { - return false; - } - } else if (local_fp && (local_role == CA_ANSWER)) { - return BadTransportDescription( - "Local fingerprint supplied when caller didn't offer DTLS.", - error_desc); - } else { - // We are not doing DTLS - remote_fingerprint_.reset(new rtc::SSLFingerprint("", nullptr, 0)); - } - // Now that we have negotiated everything, push it downward. - // Note that we cache the result so that if we have race conditions - // between future SetRemote/SetLocal invocations and new channel - // creation, we have the negotiation state saved until a new - // negotiation happens. - for (const auto& kv : channels_) { - if (!ApplyNegotiatedTransportDescription(kv.second, error_desc)) { - return false; - } - } - return true; -} - -bool JsepTransport::NegotiateRole(ContentAction local_role, - rtc::SSLRole* ssl_role, - std::string* error_desc) const { +bool Transport::NegotiateRole(ContentAction local_role, + rtc::SSLRole* ssl_role, + std::string* error_desc) const { RTC_DCHECK(ssl_role); - if (!local_description_ || !remote_description_) { + if (!local_description() || !remote_description()) { const std::string msg = "Local and Remote description must be set before " "transport descriptions are negotiated"; @@ -391,8 +424,8 @@ bool JsepTransport::NegotiateRole(ContentAction local_role, // ClientHello over each flow (host/port quartet). // IOW - actpass and passive modes should be treated as server and // active as client. - ConnectionRole local_connection_role = local_description_->connection_role; - ConnectionRole remote_connection_role = remote_description_->connection_role; + ConnectionRole local_connection_role = local_description()->connection_role; + ConnectionRole remote_connection_role = remote_description()->connection_role; bool is_remote_server = false; if (local_role == CA_OFFER) { diff --git a/webrtc/p2p/base/jseptransport.h b/webrtc/p2p/base/transport.h similarity index 66% rename from webrtc/p2p/base/jseptransport.h rename to webrtc/p2p/base/transport.h index 537ddea01a..adef6533ea 100644 --- a/webrtc/p2p/base/jseptransport.h +++ b/webrtc/p2p/base/transport.h @@ -22,8 +22,8 @@ // It is not possible to do so here because the subclass destructor will // already have run. -#ifndef WEBRTC_P2P_BASE_JSEPTRANSPORT_H_ -#define WEBRTC_P2P_BASE_JSEPTRANSPORT_H_ +#ifndef WEBRTC_P2P_BASE_TRANSPORT_H_ +#define WEBRTC_P2P_BASE_TRANSPORT_H_ #include #include @@ -43,14 +43,12 @@ namespace cricket { -class TransportChannelImpl; +class PortAllocator; +class TransportChannel; class TransportChannelImpl; typedef std::vector Candidates; -// TODO(deadbeef): Move all of these enums, POD types and utility methods to -// another header file. - // TODO(deadbeef): Unify with PeerConnectionInterface::IceConnectionState // once /talk/ and /webrtc/ are combined, and also switch to ENUM_NAME naming // style. @@ -117,14 +115,14 @@ struct ConnectionInfo { recv_ping_responses(0), key(NULL) {} - bool best_connection; // Is this the best connection we have? - bool writable; // Has this connection received a STUN response? - bool receiving; // Has this connection received anything? - bool timeout; // Has this connection timed out? - bool new_connection; // Is this a newly created connection? - size_t rtt; // The STUN RTT for this connection. - size_t sent_total_bytes; // Total bytes sent on this connection. - size_t sent_bytes_second; // Bps over the last measurement interval. + bool best_connection; // Is this the best connection we have? + bool writable; // Has this connection received a STUN response? + bool receiving; // Has this connection received anything? + bool timeout; // Has this connection timed out? + bool new_connection; // Is this a newly created connection? + size_t rtt; // The STUN RTT for this connection. + size_t sent_total_bytes; // Total bytes sent on this connection. + size_t sent_bytes_second; // Bps over the last measurement interval. size_t sent_discarded_packets; // Number of outgoing packets discarded due to // socket errors. size_t sent_total_packets; // Number of total outgoing packets attempted for @@ -244,134 +242,182 @@ bool IceCredentialsChanged(const std::string& old_ufrag, const std::string& new_ufrag, const std::string& new_pwd); -// If a candidate is not acceptable, returns false and sets error. -bool VerifyCandidate(const Candidate& candidate, std::string* error); -bool VerifyCandidates(const Candidates& candidates, std::string* error); - -// Helper class used by TransportController that processes -// TransportDescriptions. A TransportDescription represents the -// transport-specific properties of an SDP m= section, processed according to -// JSEP. Each transport consists of DTLS and ICE transport channels for RTP -// (and possibly RTCP, if rtcp-mux isn't used). -// TODO(deadbeef): Move this into /pc/ and out of /p2p/base/, since it's -// PeerConnection-specific. -class JsepTransport : public sigslot::has_slots<> { +class Transport : public sigslot::has_slots<> { public: - // |mid| is just used for log statements in order to identify the Transport. - // Note that |certificate| is allowed to be null since a remote description - // may be set before a local certificate is generated. - JsepTransport(const std::string& mid, - const rtc::scoped_refptr& certificate); + Transport(const std::string& name, PortAllocator* allocator); + virtual ~Transport(); - // Returns the MID of this transport. - const std::string& mid() const { return mid_; } + // Returns the name of this transport. + const std::string& name() const { return name_; } - // Add or remove channel that is affected when a local/remote transport - // description is set on this transport. Need to add all channels before - // setting a transport description. - bool AddChannel(TransportChannelImpl* dtls, int component); - bool RemoveChannel(int component); - bool HasChannels() const; + // Returns the port allocator object for this transport. + PortAllocator* port_allocator() { return allocator_; } bool ready_for_remote_candidates() const { return local_description_set_ && remote_description_set_; } + void SetIceRole(IceRole role); + IceRole ice_role() const { return ice_role_; } + + void SetIceTiebreaker(uint64_t IceTiebreaker) { tiebreaker_ = IceTiebreaker; } + uint64_t IceTiebreaker() { return tiebreaker_; } + + void SetIceConfig(const IceConfig& config); + // Must be called before applying local session description. - // Needed in order to verify the local fingerprint. - void SetLocalCertificate( - const rtc::scoped_refptr& certificate); + virtual void SetLocalCertificate( + const rtc::scoped_refptr& certificate) {} // Get a copy of the local certificate provided by SetLocalCertificate. - bool GetLocalCertificate( - rtc::scoped_refptr* certificate) const; + virtual bool GetLocalCertificate( + rtc::scoped_refptr* certificate) { + return false; + } - // Set the local TransportDescription to be used by DTLS and ICE channels - // that are part of this Transport. + // Get a copy of the remote certificate in use by the specified channel. + std::unique_ptr GetRemoteSSLCertificate(); + + // Create, destroy, and lookup the channels of this type by their components. + TransportChannelImpl* CreateChannel(int component); + + TransportChannelImpl* GetChannel(int component); + + bool HasChannel(int component) { + return (NULL != GetChannel(component)); + } + bool HasChannels(); + + void DestroyChannel(int component); + + // Set the local TransportDescription to be used by TransportChannels. bool SetLocalTransportDescription(const TransportDescription& description, ContentAction action, std::string* error_desc); - // Set the remote TransportDescription to be used by DTLS and ICE channels - // that are part of this Transport. + // Set the remote TransportDescription to be used by TransportChannels. bool SetRemoteTransportDescription(const TransportDescription& description, ContentAction action, std::string* error_desc); - void GetSslRole(rtc::SSLRole* ssl_role) const; + // Tells channels to start gathering candidates if necessary. + // Should be called after ConnectChannels() has been called at least once, + // which will happen in SetLocalTransportDescription. + void MaybeStartGathering(); + + // Resets all of the channels back to their initial state. They are no + // longer connecting. + void ResetChannels(); + + // Destroys every channel created so far. + void DestroyAllChannels(); - // TODO(deadbeef): Make this const. See comment in transportcontroller.h. bool GetStats(TransportStats* stats); - // The current local transport description, possibly used + // Called when one or more candidates are ready from the remote peer. + bool AddRemoteCandidates(const std::vector& candidates, + std::string* error); + bool RemoveRemoteCandidates(const std::vector& candidates, + std::string* error); + + virtual bool GetSslRole(rtc::SSLRole* ssl_role) const { return false; } + + // Must be called before channel is starting to connect. + virtual bool SetSslMaxProtocolVersion(rtc::SSLProtocolVersion version) { + return false; + } + + // The current local transport description, for use by derived classes + // when performing transport description negotiation, and possibly used // by the transport controller. const TransportDescription* local_description() const { return local_description_.get(); } - // The current remote transport description, possibly used + // The current remote transport description, for use by derived classes + // when performing transport description negotiation, and possibly used // by the transport controller. const TransportDescription* remote_description() const { return remote_description_.get(); } - // TODO(deadbeef): The methods below are only public for testing. Should make - // them utility functions or objects so they can be tested independently from - // this class. + protected: + // These are called by Create/DestroyChannel above in order to create or + // destroy the appropriate type of channel. + virtual TransportChannelImpl* CreateTransportChannel(int component) = 0; + virtual void DestroyTransportChannel(TransportChannelImpl* channel) = 0; - // Returns false if the certificate's identity does not match the fingerprint, - // or either is NULL. - bool VerifyCertificateFingerprint(const rtc::RTCCertificate* certificate, - const rtc::SSLFingerprint* fingerprint, - std::string* error_desc) const; + // Pushes down the transport parameters from the local description, such + // as the ICE ufrag and pwd. + // Derived classes can override, but must call the base as well. + virtual bool ApplyLocalTransportDescription(TransportChannelImpl* channel, + std::string* error_desc); - // Negotiates the SSL role based off the offer and answer as specified by - // RFC 4145, section-4.1. Returns false if the SSL role cannot be determined - // from the local description and remote description. - bool NegotiateRole(ContentAction local_role, - rtc::SSLRole* ssl_role, - std::string* error_desc) const; - - private: - TransportChannelImpl* GetChannel(int component); + // Pushes down remote ice credentials from the remote description to the + // transport channel. + virtual bool ApplyRemoteTransportDescription(TransportChannelImpl* ch, + std::string* error_desc); // Negotiates the transport parameters based on the current local and remote // transport description, such as the ICE role to use, and whether DTLS // should be activated. - // - // Called when an answer TransportDescription is applied. - bool NegotiateTransportDescription(ContentAction local_role, - std::string* error_desc); - - // Pushes down the transport parameters from the local description, such - // as the ICE ufrag and pwd. - bool ApplyLocalTransportDescription(TransportChannelImpl* channel, - std::string* error_desc); - - // Pushes down the transport parameters from the remote description to the - // transport channel. - bool ApplyRemoteTransportDescription(TransportChannelImpl* channel, - std::string* error_desc); + // Derived classes can negotiate their specific parameters here, but must call + // the base as well. + virtual bool NegotiateTransportDescription(ContentAction local_role, + std::string* error_desc); // Pushes down the transport parameters obtained via negotiation. - bool ApplyNegotiatedTransportDescription(TransportChannelImpl* channel, - std::string* error_desc); + // Derived classes can set their specific parameters here, but must call the + // base as well. + virtual bool ApplyNegotiatedTransportDescription( + TransportChannelImpl* channel, + std::string* error_desc); - const std::string mid_; - rtc::scoped_refptr certificate_; - rtc::SSLRole secure_role_ = rtc::SSL_CLIENT; - std::unique_ptr remote_fingerprint_; + // Returns false if the certificate's identity does not match the fingerprint, + // or either is NULL. + virtual bool VerifyCertificateFingerprint( + const rtc::RTCCertificate* certificate, + const rtc::SSLFingerprint* fingerprint, + std::string* error_desc) const; + + // Negotiates the SSL role based off the offer and answer as specified by + // RFC 4145, section-4.1. Returns false if the SSL role cannot be determined + // from the local description and remote description. + virtual bool NegotiateRole(ContentAction local_role, + rtc::SSLRole* ssl_role, + std::string* error_desc) const; + + private: + // If a candidate is not acceptable, returns false and sets error. + // Call this before calling OnRemoteCandidates. + bool VerifyCandidate(const Candidate& candidate, std::string* error); + bool VerifyCandidates(const Candidates& candidates, std::string* error); + + // Candidate component => TransportChannelImpl* + typedef std::map ChannelMap; + + // Helper function that invokes the given function on every channel. + typedef void (TransportChannelImpl::* TransportChannelFunc)(); + void CallChannels(TransportChannelFunc func); + + const std::string name_; + PortAllocator* const allocator_; + bool channels_destroyed_ = false; + IceRole ice_role_ = ICEROLE_UNKNOWN; + uint64_t tiebreaker_ = 0; + IceMode remote_ice_mode_ = ICEMODE_FULL; + IceConfig ice_config_; std::unique_ptr local_description_; std::unique_ptr remote_description_; bool local_description_set_ = false; bool remote_description_set_ = false; - // Candidate component => DTLS channel - std::map channels_; + ChannelMap channels_; - RTC_DISALLOW_COPY_AND_ASSIGN(JsepTransport); + RTC_DISALLOW_COPY_AND_ASSIGN(Transport); }; + } // namespace cricket -#endif // WEBRTC_P2P_BASE_JSEPTRANSPORT_H_ +#endif // WEBRTC_P2P_BASE_TRANSPORT_H_ diff --git a/webrtc/p2p/base/jseptransport_unittest.cc b/webrtc/p2p/base/transport_unittest.cc similarity index 54% rename from webrtc/p2p/base/jseptransport_unittest.cc rename to webrtc/p2p/base/transport_unittest.cc index 2f2510c476..d119e83367 100644 --- a/webrtc/p2p/base/jseptransport_unittest.cc +++ b/webrtc/p2p/base/transport_unittest.cc @@ -14,8 +14,10 @@ #include "webrtc/base/gunit.h" #include "webrtc/base/network.h" #include "webrtc/p2p/base/faketransportcontroller.h" +#include "webrtc/p2p/base/p2ptransport.h" -using cricket::JsepTransport; +using cricket::Transport; +using cricket::FakeTransport; using cricket::TransportChannel; using cricket::FakeTransportChannel; using cricket::IceRole; @@ -28,52 +30,201 @@ static const char kIcePwd1[] = "TESTICEPWD00000000000001"; static const char kIceUfrag2[] = "TESTICEUFRAG0002"; static const char kIcePwd2[] = "TESTICEPWD00000000000002"; -class JsepTransportTest : public testing::Test, public sigslot::has_slots<> { +class TransportTest : public testing::Test, + public sigslot::has_slots<> { public: - JsepTransportTest() - : transport_(new JsepTransport("test content name", nullptr)) {} - bool SetupChannel() { - fake_ice_channel_.reset(new FakeTransportChannel(transport_->mid(), 1)); - fake_dtls_channel_.reset(new FakeTransportChannel(transport_->mid(), 1)); - return transport_->AddChannel(fake_dtls_channel_.get(), 1); + TransportTest() + : transport_(new FakeTransport("test content name")), channel_(NULL) {} + ~TransportTest() { + transport_->DestroyAllChannels(); + } + bool SetupChannel() { + channel_ = CreateChannel(1); + return (channel_ != NULL); + } + FakeTransportChannel* CreateChannel(int component) { + return static_cast( + transport_->CreateChannel(component)); + } + void DestroyChannel() { + transport_->DestroyChannel(1); + channel_ = NULL; } - void DestroyChannel() { transport_->RemoveChannel(1); } protected: - std::unique_ptr fake_dtls_channel_; - std::unique_ptr fake_ice_channel_; - std::unique_ptr transport_; + std::unique_ptr transport_; + FakeTransportChannel* channel_; }; // This test verifies channels are created with proper ICE -// ufrag/password after a transport description is applied. -TEST_F(JsepTransportTest, TestChannelIceParameters) { +// role, tiebreaker and remote ice mode and credentials after offer and +// answer negotiations. +TEST_F(TransportTest, TestChannelIceParameters) { + transport_->SetIceRole(cricket::ICEROLE_CONTROLLING); + transport_->SetIceTiebreaker(99U); cricket::TransportDescription local_desc(kIceUfrag1, kIcePwd1); - ASSERT_TRUE(transport_->SetLocalTransportDescription( - local_desc, cricket::CA_OFFER, NULL)); + ASSERT_TRUE(transport_->SetLocalTransportDescription(local_desc, + cricket::CA_OFFER, + NULL)); + EXPECT_EQ(cricket::ICEROLE_CONTROLLING, transport_->ice_role()); EXPECT_TRUE(SetupChannel()); - EXPECT_EQ(cricket::ICEMODE_FULL, fake_dtls_channel_->remote_ice_mode()); - EXPECT_EQ(kIceUfrag1, fake_dtls_channel_->ice_ufrag()); - EXPECT_EQ(kIcePwd1, fake_dtls_channel_->ice_pwd()); + EXPECT_EQ(cricket::ICEROLE_CONTROLLING, channel_->GetIceRole()); + EXPECT_EQ(cricket::ICEMODE_FULL, channel_->remote_ice_mode()); + EXPECT_EQ(kIceUfrag1, channel_->ice_ufrag()); + EXPECT_EQ(kIcePwd1, channel_->ice_pwd()); cricket::TransportDescription remote_desc(kIceUfrag1, kIcePwd1); - ASSERT_TRUE(transport_->SetRemoteTransportDescription( - remote_desc, cricket::CA_ANSWER, NULL)); - EXPECT_EQ(cricket::ICEMODE_FULL, fake_dtls_channel_->remote_ice_mode()); - EXPECT_EQ(kIceUfrag1, fake_dtls_channel_->remote_ice_ufrag()); - EXPECT_EQ(kIcePwd1, fake_dtls_channel_->remote_ice_pwd()); + ASSERT_TRUE(transport_->SetRemoteTransportDescription(remote_desc, + cricket::CA_ANSWER, + NULL)); + EXPECT_EQ(cricket::ICEROLE_CONTROLLING, channel_->GetIceRole()); + EXPECT_EQ(99U, channel_->IceTiebreaker()); + EXPECT_EQ(cricket::ICEMODE_FULL, channel_->remote_ice_mode()); + // Changing the transport role from CONTROLLING to CONTROLLED. + transport_->SetIceRole(cricket::ICEROLE_CONTROLLED); + EXPECT_EQ(cricket::ICEROLE_CONTROLLED, channel_->GetIceRole()); + EXPECT_EQ(cricket::ICEMODE_FULL, channel_->remote_ice_mode()); + EXPECT_EQ(kIceUfrag1, channel_->remote_ice_ufrag()); + EXPECT_EQ(kIcePwd1, channel_->remote_ice_pwd()); } // Verifies that IceCredentialsChanged returns true when either ufrag or pwd // changed, and false in other cases. -TEST_F(JsepTransportTest, TestIceCredentialsChanged) { +TEST_F(TransportTest, TestIceCredentialsChanged) { EXPECT_TRUE(cricket::IceCredentialsChanged("u1", "p1", "u2", "p2")); EXPECT_TRUE(cricket::IceCredentialsChanged("u1", "p1", "u2", "p1")); EXPECT_TRUE(cricket::IceCredentialsChanged("u1", "p1", "u1", "p2")); EXPECT_FALSE(cricket::IceCredentialsChanged("u1", "p1", "u1", "p1")); } -TEST_F(JsepTransportTest, TestGetStats) { +// This test verifies that the callee's ICE role remains the same when the +// callee triggers an ICE restart. +// +// RFC5245 currently says that the role *should* change on an ICE restart, +// but this rule was intended for an ICE restart that occurs when an endpoint +// is changing to ICE lite (which we already handle). See discussion here: +// https://mailarchive.ietf.org/arch/msg/ice/C0_QRCTNcwtvUF12y28jQicPR10 +TEST_F(TransportTest, TestIceControlledToControllingOnIceRestart) { + EXPECT_TRUE(SetupChannel()); + transport_->SetIceRole(cricket::ICEROLE_CONTROLLED); + + cricket::TransportDescription desc(kIceUfrag1, kIcePwd1); + ASSERT_TRUE(transport_->SetRemoteTransportDescription(desc, + cricket::CA_OFFER, + NULL)); + ASSERT_TRUE(transport_->SetLocalTransportDescription(desc, + cricket::CA_ANSWER, + NULL)); + EXPECT_EQ(cricket::ICEROLE_CONTROLLED, transport_->ice_role()); + + cricket::TransportDescription new_local_desc(kIceUfrag2, kIcePwd2); + ASSERT_TRUE(transport_->SetLocalTransportDescription(new_local_desc, + cricket::CA_OFFER, + NULL)); + EXPECT_EQ(cricket::ICEROLE_CONTROLLED, transport_->ice_role()); + EXPECT_EQ(cricket::ICEROLE_CONTROLLED, channel_->GetIceRole()); +} + +// This test verifies that the caller's ICE role remains the same when the +// callee triggers an ICE restart. +// +// RFC5245 currently says that the role *should* change on an ICE restart, +// but this rule was intended for an ICE restart that occurs when an endpoint +// is changing to ICE lite (which we already handle). See discussion here: +// https://mailarchive.ietf.org/arch/msg/ice/C0_QRCTNcwtvUF12y28jQicPR10 +TEST_F(TransportTest, TestIceControllingToControlledOnIceRestart) { + EXPECT_TRUE(SetupChannel()); + transport_->SetIceRole(cricket::ICEROLE_CONTROLLING); + + cricket::TransportDescription desc(kIceUfrag1, kIcePwd1); + ASSERT_TRUE(transport_->SetLocalTransportDescription(desc, + cricket::CA_OFFER, + NULL)); + ASSERT_TRUE(transport_->SetRemoteTransportDescription(desc, + cricket::CA_ANSWER, + NULL)); + EXPECT_EQ(cricket::ICEROLE_CONTROLLING, transport_->ice_role()); + + cricket::TransportDescription new_local_desc(kIceUfrag2, kIcePwd2); + ASSERT_TRUE(transport_->SetLocalTransportDescription(new_local_desc, + cricket::CA_ANSWER, + NULL)); + EXPECT_EQ(cricket::ICEROLE_CONTROLLING, transport_->ice_role()); + EXPECT_EQ(cricket::ICEROLE_CONTROLLING, channel_->GetIceRole()); +} + +// This test verifies that the caller's ICE role is still controlling after the +// callee triggers ICE restart if the callee's ICE mode is LITE. +TEST_F(TransportTest, TestIceControllingOnIceRestartIfRemoteIsIceLite) { + EXPECT_TRUE(SetupChannel()); + transport_->SetIceRole(cricket::ICEROLE_CONTROLLING); + + cricket::TransportDescription desc(kIceUfrag1, kIcePwd1); + ASSERT_TRUE(transport_->SetLocalTransportDescription(desc, + cricket::CA_OFFER, + NULL)); + + cricket::TransportDescription remote_desc( + std::vector(), kIceUfrag1, kIcePwd1, cricket::ICEMODE_LITE, + cricket::CONNECTIONROLE_NONE, NULL); + ASSERT_TRUE(transport_->SetRemoteTransportDescription(remote_desc, + cricket::CA_ANSWER, + NULL)); + + EXPECT_EQ(cricket::ICEROLE_CONTROLLING, transport_->ice_role()); + + cricket::TransportDescription new_local_desc(kIceUfrag2, kIcePwd2); + ASSERT_TRUE(transport_->SetLocalTransportDescription(new_local_desc, + cricket::CA_ANSWER, + NULL)); + EXPECT_EQ(cricket::ICEROLE_CONTROLLING, transport_->ice_role()); + EXPECT_EQ(cricket::ICEROLE_CONTROLLING, channel_->GetIceRole()); +} + +// Tests channel role is reversed after receiving ice-lite from remote. +TEST_F(TransportTest, TestSetRemoteIceLiteInOffer) { + transport_->SetIceRole(cricket::ICEROLE_CONTROLLED); + cricket::TransportDescription remote_desc( + std::vector(), kIceUfrag1, kIcePwd1, cricket::ICEMODE_LITE, + cricket::CONNECTIONROLE_ACTPASS, NULL); + ASSERT_TRUE(transport_->SetRemoteTransportDescription(remote_desc, + cricket::CA_OFFER, + NULL)); + cricket::TransportDescription local_desc(kIceUfrag1, kIcePwd1); + ASSERT_TRUE(transport_->SetLocalTransportDescription(local_desc, + cricket::CA_ANSWER, + NULL)); + EXPECT_EQ(cricket::ICEROLE_CONTROLLING, transport_->ice_role()); + EXPECT_TRUE(SetupChannel()); + EXPECT_EQ(cricket::ICEROLE_CONTROLLING, channel_->GetIceRole()); + EXPECT_EQ(cricket::ICEMODE_LITE, channel_->remote_ice_mode()); +} + +// Tests ice-lite in remote answer. +TEST_F(TransportTest, TestSetRemoteIceLiteInAnswer) { + transport_->SetIceRole(cricket::ICEROLE_CONTROLLING); + cricket::TransportDescription local_desc(kIceUfrag1, kIcePwd1); + ASSERT_TRUE(transport_->SetLocalTransportDescription(local_desc, + cricket::CA_OFFER, + NULL)); + EXPECT_EQ(cricket::ICEROLE_CONTROLLING, transport_->ice_role()); + EXPECT_TRUE(SetupChannel()); + EXPECT_EQ(cricket::ICEROLE_CONTROLLING, channel_->GetIceRole()); + // Channels will be created in ICEFULL_MODE. + EXPECT_EQ(cricket::ICEMODE_FULL, channel_->remote_ice_mode()); + cricket::TransportDescription remote_desc( + std::vector(), kIceUfrag1, kIcePwd1, cricket::ICEMODE_LITE, + cricket::CONNECTIONROLE_NONE, NULL); + ASSERT_TRUE(transport_->SetRemoteTransportDescription(remote_desc, + cricket::CA_ANSWER, + NULL)); + EXPECT_EQ(cricket::ICEROLE_CONTROLLING, channel_->GetIceRole()); + // After receiving remote description with ICEMODE_LITE, channel should + // have mode set to ICEMODE_LITE. + EXPECT_EQ(cricket::ICEMODE_LITE, channel_->remote_ice_mode()); +} + +TEST_F(TransportTest, TestGetStats) { EXPECT_TRUE(SetupChannel()); cricket::TransportStats stats; EXPECT_TRUE(transport_->GetStats(&stats)); @@ -95,7 +246,7 @@ TEST_F(JsepTransportTest, TestGetStats) { // Tests that VerifyCertificateFingerprint only returns true when the // certificate matches the fingerprint. -TEST_F(JsepTransportTest, TestVerifyCertificateFingerprint) { +TEST_F(TransportTest, TestVerifyCertificateFingerprint) { std::string error_desc; EXPECT_FALSE( transport_->VerifyCertificateFingerprint(nullptr, nullptr, &error_desc)); @@ -130,7 +281,7 @@ TEST_F(JsepTransportTest, TestVerifyCertificateFingerprint) { } // Tests that NegotiateRole sets the SSL role correctly. -TEST_F(JsepTransportTest, TestNegotiateRole) { +TEST_F(TransportTest, TestNegotiateRole) { TransportDescription local_desc(kIceUfrag1, kIcePwd1); TransportDescription remote_desc(kIceUfrag2, kIcePwd2); diff --git a/webrtc/p2p/base/transportchannel.h b/webrtc/p2p/base/transportchannel.h index c7c3c755f3..2f43e8f4f7 100644 --- a/webrtc/p2p/base/transportchannel.h +++ b/webrtc/p2p/base/transportchannel.h @@ -19,7 +19,7 @@ #include "webrtc/p2p/base/candidate.h" #include "webrtc/p2p/base/candidatepairinterface.h" #include "webrtc/p2p/base/packettransportinterface.h" -#include "webrtc/p2p/base/jseptransport.h" +#include "webrtc/p2p/base/transport.h" #include "webrtc/p2p/base/transportdescription.h" #include "webrtc/base/asyncpacketsocket.h" #include "webrtc/base/basictypes.h" diff --git a/webrtc/p2p/base/transportcontroller.cc b/webrtc/p2p/base/transportcontroller.cc index e4dc5afa21..8dd748647b 100644 --- a/webrtc/p2p/base/transportcontroller.cc +++ b/webrtc/p2p/base/transportcontroller.cc @@ -16,8 +16,14 @@ #include "webrtc/base/bind.h" #include "webrtc/base/checks.h" #include "webrtc/base/thread.h" +#include "webrtc/p2p/base/dtlstransport.h" +#include "webrtc/p2p/base/p2ptransport.h" #include "webrtc/p2p/base/port.h" +#ifdef HAVE_QUIC +#include "webrtc/p2p/quic/quictransport.h" +#endif // HAVE_QUIC + namespace cricket { enum { @@ -54,11 +60,10 @@ TransportController::TransportController(rtc::Thread* signaling_thread, true) {} TransportController::~TransportController() { - // Channel destructors may try to send packets, so this needs to happen on - // the network thread. network_thread_->Invoke( RTC_FROM_HERE, - rtc::Bind(&TransportController::DestroyAllChannels_n, this)); + rtc::Bind(&TransportController::DestroyAllTransports_n, this)); + signaling_thread_->Clear(this); } bool TransportController::SetSslMaxProtocolVersion( @@ -81,7 +86,7 @@ void TransportController::SetIceRole(IceRole ice_role) { } bool TransportController::GetSslRole(const std::string& transport_name, - rtc::SSLRole* role) const { + rtc::SSLRole* role) { return network_thread_->Invoke( RTC_FROM_HERE, rtc::Bind(&TransportController::GetSslRole_n, this, transport_name, role)); @@ -96,7 +101,7 @@ bool TransportController::SetLocalCertificate( bool TransportController::GetLocalCertificate( const std::string& transport_name, - rtc::scoped_refptr* certificate) const { + rtc::scoped_refptr* certificate) { return network_thread_->Invoke( RTC_FROM_HERE, rtc::Bind(&TransportController::GetLocalCertificate_n, this, transport_name, certificate)); @@ -104,7 +109,7 @@ bool TransportController::GetLocalCertificate( std::unique_ptr TransportController::GetRemoteSSLCertificate( - const std::string& transport_name) const { + const std::string& transport_name) { return network_thread_->Invoke>( RTC_FROM_HERE, rtc::Bind(&TransportController::GetRemoteSSLCertificate_n, this, transport_name)); @@ -154,7 +159,7 @@ bool TransportController::RemoveRemoteCandidates(const Candidates& candidates, } bool TransportController::ReadyForRemoteCandidates( - const std::string& transport_name) const { + const std::string& transport_name) { return network_thread_->Invoke( RTC_FROM_HERE, rtc::Bind(&TransportController::ReadyForRemoteCandidates_n, this, transport_name)); @@ -167,70 +172,42 @@ bool TransportController::GetStats(const std::string& transport_name, rtc::Bind(&TransportController::GetStats_n, this, transport_name, stats)); } -void TransportController::SetMetricsObserver( - webrtc::MetricsObserverInterface* metrics_observer) { - return network_thread_->Invoke( - RTC_FROM_HERE, rtc::Bind(&TransportController::SetMetricsObserver_n, this, - metrics_observer)); -} - TransportChannel* TransportController::CreateTransportChannel_n( const std::string& transport_name, int component) { RTC_DCHECK(network_thread_->IsCurrent()); - RefCountedChannel* existing_channel = GetChannel_n(transport_name, component); - if (existing_channel) { + auto it = FindChannel_n(transport_name, component); + if (it != channels_.end()) { // Channel already exists; increment reference count and return. - existing_channel->AddRef(); - return existing_channel->dtls(); + it->AddRef(); + return it->get(); } // Need to create a new channel. - JsepTransport* transport = GetOrCreateJsepTransport_n(transport_name); - - // Create DTLS channel wrapping ICE channel, and configure it. - TransportChannelImpl* ice = - CreateIceTransportChannel_n(transport_name, component); - // TODO(deadbeef): To support QUIC, would need to create a - // QuicTransportChannel here. What is "dtls" in this file would then become - // "dtls or quic". - TransportChannelImpl* dtls = - CreateDtlsTransportChannel_n(transport_name, component, ice); - dtls->SetMetricsObserver(metrics_observer_); - dtls->SetIceRole(ice_role_); - dtls->SetIceTiebreaker(ice_tiebreaker_); - dtls->SetIceConfig(ice_config_); - if (certificate_) { - bool set_cert_success = dtls->SetLocalCertificate(certificate_); - RTC_DCHECK(set_cert_success); - } - - // Connect to signals offered by the channels. Currently, the DTLS channel - // forwards signals from the ICE channel, so we only need to connect to the - // DTLS channel. In the future this won't be the case. - dtls->SignalWritableState.connect( + Transport* transport = GetOrCreateTransport_n(transport_name); + TransportChannelImpl* channel = transport->CreateChannel(component); + channel->SetMetricsObserver(metrics_observer_); + channel->SignalWritableState.connect( this, &TransportController::OnChannelWritableState_n); - dtls->SignalReceivingState.connect( + channel->SignalReceivingState.connect( this, &TransportController::OnChannelReceivingState_n); - dtls->SignalGatheringState.connect( + channel->SignalGatheringState.connect( this, &TransportController::OnChannelGatheringState_n); - dtls->SignalCandidateGathered.connect( + channel->SignalCandidateGathered.connect( this, &TransportController::OnChannelCandidateGathered_n); - dtls->SignalCandidatesRemoved.connect( + channel->SignalCandidatesRemoved.connect( this, &TransportController::OnChannelCandidatesRemoved_n); - dtls->SignalRoleConflict.connect( + channel->SignalRoleConflict.connect( this, &TransportController::OnChannelRoleConflict_n); - dtls->SignalStateChanged.connect( + channel->SignalStateChanged.connect( this, &TransportController::OnChannelStateChanged_n); - dtls->SignalDtlsHandshakeError.connect( + channel->SignalDtlsHandshakeError.connect( this, &TransportController::OnDtlsHandshakeError); - channels_.insert(channels_.end(), RefCountedChannel(dtls, ice))->AddRef(); - bool channel_added = transport->AddChannel(dtls, component); - RTC_DCHECK(channel_added); + channels_.insert(channels_.end(), RefCountedChannel(channel))->AddRef(); // Adding a channel could cause aggregate state to change. UpdateAggregateStates_n(); - return dtls; + return channel; } void TransportController::DestroyTransportChannel_n( @@ -238,68 +215,56 @@ void TransportController::DestroyTransportChannel_n( int component) { RTC_DCHECK(network_thread_->IsCurrent()); - auto it = GetChannelIterator_n(transport_name, component); + auto it = FindChannel_n(transport_name, component); if (it == channels_.end()) { LOG(LS_WARNING) << "Attempting to delete " << transport_name << " TransportChannel " << component << ", which doesn't exist."; return; } + it->DecRef(); if (it->ref() > 0) { return; } - channels_.erase(it); - JsepTransport* t = GetJsepTransport_n(transport_name); - bool channel_removed = t->RemoveChannel(component); - RTC_DCHECK(channel_removed); + channels_.erase(it); + Transport* transport = GetTransport_n(transport_name); + transport->DestroyChannel(component); // Just as we create a Transport when its first channel is created, // we delete it when its last channel is deleted. - if (!t->HasChannels()) { - transports_.erase(transport_name); + if (!transport->HasChannels()) { + DestroyTransport_n(transport_name); } - // Removing a channel could cause aggregate state to change. UpdateAggregateStates_n(); } -std::vector TransportController::transport_names_for_testing() { - std::vector ret; - for (const auto& kv : transports_) { - ret.push_back(kv.first); +const rtc::scoped_refptr& +TransportController::certificate_for_testing() { + return certificate_; +} + +Transport* TransportController::CreateTransport_n( + const std::string& transport_name) { + RTC_DCHECK(network_thread_->IsCurrent()); + +#ifdef HAVE_QUIC + if (quic_) { + return new QuicTransport(transport_name, port_allocator(), certificate_); } - return ret; +#endif // HAVE_QUIC + Transport* transport = new DtlsTransport( + transport_name, port_allocator(), certificate_); + return transport; } -std::vector TransportController::channels_for_testing() { - std::vector ret; - for (RefCountedChannel& channel : channels_) { - ret.push_back(channel.dtls()); - } - return ret; -} +Transport* TransportController::GetTransport_n( + const std::string& transport_name) { + RTC_DCHECK(network_thread_->IsCurrent()); -TransportChannelImpl* TransportController::get_channel_for_testing( - const std::string& transport_name, - int component) { - RefCountedChannel* ch = GetChannel_n(transport_name, component); - return ch ? ch->dtls() : nullptr; -} - -TransportChannelImpl* TransportController::CreateIceTransportChannel_n( - const std::string& transport_name, - int component) { - return new P2PTransportChannel(transport_name, component, port_allocator_); -} - -TransportChannelImpl* TransportController::CreateDtlsTransportChannel_n( - const std::string&, - int, - TransportChannelImpl* ice) { - DtlsTransportChannelWrapper* dtls = new DtlsTransportChannelWrapper(ice); - dtls->SetSslMaxProtocolVersion(ssl_max_version_); - return dtls; + auto iter = transports_.find(transport_name); + return (iter != transports_.end()) ? iter->second : nullptr; } void TransportController::OnMessage(rtc::Message* pmsg) { @@ -339,77 +304,58 @@ void TransportController::OnMessage(rtc::Message* pmsg) { } std::vector::iterator -TransportController::GetChannelIterator_n(const std::string& transport_name, - int component) { - RTC_DCHECK(network_thread_->IsCurrent()); +TransportController::FindChannel_n(const std::string& transport_name, + int component) { return std::find_if( channels_.begin(), channels_.end(), [transport_name, component](const RefCountedChannel& channel) { - return channel.dtls()->transport_name() == transport_name && - channel.dtls()->component() == component; + return channel->transport_name() == transport_name && + channel->component() == component; }); } -std::vector::const_iterator -TransportController::GetChannelIterator_n(const std::string& transport_name, - int component) const { - RTC_DCHECK(network_thread_->IsCurrent()); - return std::find_if( - channels_.begin(), channels_.end(), - [transport_name, component](const RefCountedChannel& channel) { - return channel.dtls()->transport_name() == transport_name && - channel.dtls()->component() == component; - }); -} - -const JsepTransport* TransportController::GetJsepTransport_n( - const std::string& transport_name) const { - RTC_DCHECK(network_thread_->IsCurrent()); - auto it = transports_.find(transport_name); - return (it == transports_.end()) ? nullptr : it->second.get(); -} - -JsepTransport* TransportController::GetJsepTransport_n( - const std::string& transport_name) { - RTC_DCHECK(network_thread_->IsCurrent()); - auto it = transports_.find(transport_name); - return (it == transports_.end()) ? nullptr : it->second.get(); -} - -const TransportController::RefCountedChannel* TransportController::GetChannel_n( - const std::string& transport_name, - int component) const { - RTC_DCHECK(network_thread_->IsCurrent()); - auto it = GetChannelIterator_n(transport_name, component); - return (it == channels_.end()) ? nullptr : &(*it); -} - -TransportController::RefCountedChannel* TransportController::GetChannel_n( - const std::string& transport_name, - int component) { - RTC_DCHECK(network_thread_->IsCurrent()); - auto it = GetChannelIterator_n(transport_name, component); - return (it == channels_.end()) ? nullptr : &(*it); -} - -JsepTransport* TransportController::GetOrCreateJsepTransport_n( +Transport* TransportController::GetOrCreateTransport_n( const std::string& transport_name) { RTC_DCHECK(network_thread_->IsCurrent()); - JsepTransport* transport = GetJsepTransport_n(transport_name); + Transport* transport = GetTransport_n(transport_name); if (transport) { return transport; } - transport = new JsepTransport(transport_name, certificate_); - transports_[transport_name] = std::unique_ptr(transport); + transport = CreateTransport_n(transport_name); + // The stuff below happens outside of CreateTransport_w so that unit tests + // can override CreateTransport_w to return a different type of transport. + transport->SetSslMaxProtocolVersion(ssl_max_version_); + transport->SetIceConfig(ice_config_); + transport->SetIceRole(ice_role_); + transport->SetIceTiebreaker(ice_tiebreaker_); + if (certificate_) { + transport->SetLocalCertificate(certificate_); + } + transports_[transport_name] = transport; + return transport; } -void TransportController::DestroyAllChannels_n() { +void TransportController::DestroyTransport_n( + const std::string& transport_name) { RTC_DCHECK(network_thread_->IsCurrent()); + + auto iter = transports_.find(transport_name); + if (iter != transports_.end()) { + delete iter->second; + transports_.erase(transport_name); + } +} + +void TransportController::DestroyAllTransports_n() { + RTC_DCHECK(network_thread_->IsCurrent()); + + for (const auto& kv : transports_) { + delete kv.second; + } transports_.clear(); - channels_.clear(); } bool TransportController::SetSslMaxProtocolVersion_n( @@ -427,82 +373,74 @@ bool TransportController::SetSslMaxProtocolVersion_n( void TransportController::SetIceConfig_n(const IceConfig& config) { RTC_DCHECK(network_thread_->IsCurrent()); - ice_config_ = config; - for (auto& channel : channels_) { - channel.dtls()->SetIceConfig(ice_config_); + for (const auto& kv : transports_) { + kv.second->SetIceConfig(ice_config_); } } void TransportController::SetIceRole_n(IceRole ice_role) { RTC_DCHECK(network_thread_->IsCurrent()); - ice_role_ = ice_role; - for (auto& channel : channels_) { - channel.dtls()->SetIceRole(ice_role_); + for (const auto& kv : transports_) { + kv.second->SetIceRole(ice_role_); } } bool TransportController::GetSslRole_n(const std::string& transport_name, - rtc::SSLRole* role) const { + rtc::SSLRole* role) { RTC_DCHECK(network_thread_->IsCurrent()); - const JsepTransport* t = GetJsepTransport_n(transport_name); + Transport* t = GetTransport_n(transport_name); if (!t) { return false; } - t->GetSslRole(role); - return true; + + return t->GetSslRole(role); } bool TransportController::SetLocalCertificate_n( const rtc::scoped_refptr& certificate) { RTC_DCHECK(network_thread_->IsCurrent()); - // Can't change a certificate, or set a null certificate. - if (certificate_ || !certificate) { + if (certificate_) { + return false; + } + if (!certificate) { return false; } certificate_ = certificate; - // Set certificate both for Transport, which verifies it matches the - // fingerprint in SDP... - for (auto& kv : transports_) { + for (const auto& kv : transports_) { kv.second->SetLocalCertificate(certificate_); } - // ... and for the DTLS channel, which needs it for the DTLS handshake. - for (auto& channel : channels_) { - bool set_cert_success = channel.dtls()->SetLocalCertificate(certificate); - RTC_DCHECK(set_cert_success); - } return true; } bool TransportController::GetLocalCertificate_n( const std::string& transport_name, - rtc::scoped_refptr* certificate) const { + rtc::scoped_refptr* certificate) { RTC_DCHECK(network_thread_->IsCurrent()); - const JsepTransport* t = GetJsepTransport_n(transport_name); + Transport* t = GetTransport_n(transport_name); if (!t) { return false; } + return t->GetLocalCertificate(certificate); } std::unique_ptr TransportController::GetRemoteSSLCertificate_n( - const std::string& transport_name) const { + const std::string& transport_name) { RTC_DCHECK(network_thread_->IsCurrent()); - // Get the certificate from the RTP channel's DTLS handshake. Should be - // identical to the RTCP channel's, since they were given the same remote - // fingerprint. - const RefCountedChannel* ch = GetChannel_n(transport_name, 1); - if (!ch) { + Transport* t = GetTransport_n(transport_name); + if (!t) { return nullptr; } - return ch->dtls()->GetRemoteSSLCertificate(); + + return t->GetRemoteSSLCertificate(); } bool TransportController::SetLocalTransportDescription_n( @@ -512,7 +450,7 @@ bool TransportController::SetLocalTransportDescription_n( std::string* err) { RTC_DCHECK(network_thread_->IsCurrent()); - JsepTransport* transport = GetJsepTransport_n(transport_name); + Transport* transport = GetTransport_n(transport_name); if (!transport) { // If we didn't find a transport, that's not an error; // it could have been deleted as a result of bundling. @@ -548,15 +486,7 @@ bool TransportController::SetRemoteTransportDescription_n( std::string* err) { RTC_DCHECK(network_thread_->IsCurrent()); - // If our role is ICEROLE_CONTROLLED and the remote endpoint supports only - // ice_lite, this local endpoint should take the CONTROLLING role. - // TODO(deadbeef): This is a session-level attribute, so it really shouldn't - // be in a TransportDescription in the first place... - if (ice_role_ == ICEROLE_CONTROLLED && tdesc.ice_mode == ICEMODE_LITE) { - SetIceRole_n(ICEROLE_CONTROLLING); - } - - JsepTransport* transport = GetJsepTransport_n(transport_name); + Transport* transport = GetTransport_n(transport_name); if (!transport) { // If we didn't find a transport, that's not an error; // it could have been deleted as a result of bundling. @@ -570,8 +500,8 @@ bool TransportController::SetRemoteTransportDescription_n( } void TransportController::MaybeStartGathering_n() { - for (auto& channel : channels_) { - channel.dtls()->MaybeStartGathering(); + for (const auto& kv : transports_) { + kv.second->MaybeStartGathering(); } } @@ -581,40 +511,19 @@ bool TransportController::AddRemoteCandidates_n( std::string* err) { RTC_DCHECK(network_thread_->IsCurrent()); - // Verify each candidate before passing down to the transport layer. - if (!VerifyCandidates(candidates, err)) { - return false; - } - - JsepTransport* transport = GetJsepTransport_n(transport_name); + Transport* transport = GetTransport_n(transport_name); if (!transport) { // If we didn't find a transport, that's not an error; // it could have been deleted as a result of bundling. return true; } - for (const Candidate& candidate : candidates) { - RefCountedChannel* channel = - GetChannel_n(transport_name, candidate.component()); - if (!channel) { - *err = "Candidate has an unknown component: " + candidate.ToString() + - " for content: " + transport_name; - return false; - } - channel->dtls()->AddRemoteCandidate(candidate); - } - return true; + return transport->AddRemoteCandidates(candidates, err); } bool TransportController::RemoveRemoteCandidates_n(const Candidates& candidates, std::string* err) { RTC_DCHECK(network_thread_->IsCurrent()); - - // Verify each candidate before passing down to the transport layer. - if (!VerifyCandidates(candidates, err)) { - return false; - } - std::map candidates_by_transport_name; for (const Candidate& cand : candidates) { RTC_DCHECK(!cand.transport_name().empty()); @@ -622,31 +531,23 @@ bool TransportController::RemoveRemoteCandidates_n(const Candidates& candidates, } bool result = true; - for (const auto& kv : candidates_by_transport_name) { - const std::string& transport_name = kv.first; - const Candidates& candidates = kv.second; - JsepTransport* transport = GetJsepTransport_n(transport_name); + for (auto kv : candidates_by_transport_name) { + Transport* transport = GetTransport_n(kv.first); if (!transport) { // If we didn't find a transport, that's not an error; // it could have been deleted as a result of bundling. continue; } - for (const Candidate& candidate : candidates) { - RefCountedChannel* channel = - GetChannel_n(transport_name, candidate.component()); - if (channel) { - channel->dtls()->RemoveRemoteCandidate(candidate); - } - } + result &= transport->RemoveRemoteCandidates(kv.second, err); } return result; } bool TransportController::ReadyForRemoteCandidates_n( - const std::string& transport_name) const { + const std::string& transport_name) { RTC_DCHECK(network_thread_->IsCurrent()); - const JsepTransport* transport = GetJsepTransport_n(transport_name); + Transport* transport = GetTransport_n(transport_name); if (!transport) { return false; } @@ -657,22 +558,13 @@ bool TransportController::GetStats_n(const std::string& transport_name, TransportStats* stats) { RTC_DCHECK(network_thread_->IsCurrent()); - JsepTransport* transport = GetJsepTransport_n(transport_name); + Transport* transport = GetTransport_n(transport_name); if (!transport) { return false; } return transport->GetStats(stats); } -void TransportController::SetMetricsObserver_n( - webrtc::MetricsObserverInterface* metrics_observer) { - RTC_DCHECK(network_thread_->IsCurrent()); - metrics_observer_ = metrics_observer; - for (auto& channel : channels_) { - channel.dtls()->SetMetricsObserver(metrics_observer); - } -} - void TransportController::OnChannelWritableState_n( rtc::PacketTransportInterface* transport) { RTC_DCHECK(network_thread_->IsCurrent()); @@ -762,21 +654,19 @@ void TransportController::UpdateAggregateStates_n() { bool any_gathering = false; bool all_done_gathering = !channels_.empty(); for (const auto& channel : channels_) { - any_receiving = any_receiving || channel.dtls()->receiving(); - any_failed = - any_failed || - channel.dtls()->GetState() == TransportChannelState::STATE_FAILED; - all_connected = all_connected && channel.dtls()->writable(); + any_receiving = any_receiving || channel->receiving(); + any_failed = any_failed || + channel->GetState() == TransportChannelState::STATE_FAILED; + all_connected = all_connected && channel->writable(); all_completed = - all_completed && channel.dtls()->writable() && - channel.dtls()->GetState() == TransportChannelState::STATE_COMPLETED && - channel.dtls()->GetIceRole() == ICEROLE_CONTROLLING && - channel.dtls()->gathering_state() == kIceGatheringComplete; + all_completed && channel->writable() && + channel->GetState() == TransportChannelState::STATE_COMPLETED && + channel->GetIceRole() == ICEROLE_CONTROLLING && + channel->gathering_state() == kIceGatheringComplete; any_gathering = - any_gathering || channel.dtls()->gathering_state() != kIceGatheringNew; - all_done_gathering = - all_done_gathering && - channel.dtls()->gathering_state() == kIceGatheringComplete; + any_gathering || channel->gathering_state() != kIceGatheringNew; + all_done_gathering = all_done_gathering && + channel->gathering_state() == kIceGatheringComplete; } if (any_failed) { @@ -816,4 +706,12 @@ void TransportController::OnDtlsHandshakeError(rtc::SSLHandshakeError error) { SignalDtlsHandshakeError(error); } +void TransportController::SetMetricsObserver( + webrtc::MetricsObserverInterface* metrics_observer) { + metrics_observer_ = metrics_observer; + for (auto channel : channels_) { + channel->SetMetricsObserver(metrics_observer); + } +} + } // namespace cricket diff --git a/webrtc/p2p/base/transportcontroller.h b/webrtc/p2p/base/transportcontroller.h index cdac4b6dcb..a408421ffb 100644 --- a/webrtc/p2p/base/transportcontroller.h +++ b/webrtc/p2p/base/transportcontroller.h @@ -20,9 +20,7 @@ #include "webrtc/base/sigslot.h" #include "webrtc/base/sslstreamadapter.h" #include "webrtc/p2p/base/candidate.h" -#include "webrtc/p2p/base/dtlstransportchannel.h" -#include "webrtc/p2p/base/jseptransport.h" -#include "webrtc/p2p/base/p2ptransportchannel.h" +#include "webrtc/p2p/base/transport.h" namespace rtc { class Thread; @@ -64,7 +62,7 @@ class TransportController : public sigslot::has_slots<>, void SetIceConfig(const IceConfig& config); void SetIceRole(IceRole ice_role); - bool GetSslRole(const std::string& transport_name, rtc::SSLRole* role) const; + bool GetSslRole(const std::string& transport_name, rtc::SSLRole* role); // Specifies the identity to use in this session. // Can only be called once. @@ -72,11 +70,10 @@ class TransportController : public sigslot::has_slots<>, const rtc::scoped_refptr& certificate); bool GetLocalCertificate( const std::string& transport_name, - rtc::scoped_refptr* certificate) const; - // Caller owns returned certificate. This method mainly exists for stats - // reporting. + rtc::scoped_refptr* certificate); + // Caller owns returned certificate std::unique_ptr GetRemoteSSLCertificate( - const std::string& transport_name) const; + const std::string& transport_name); bool SetLocalTransportDescription(const std::string& transport_name, const TransportDescription& tdesc, ContentAction action, @@ -92,12 +89,8 @@ class TransportController : public sigslot::has_slots<>, const Candidates& candidates, std::string* err); bool RemoveRemoteCandidates(const Candidates& candidates, std::string* err); - bool ReadyForRemoteCandidates(const std::string& transport_name) const; - // TODO(deadbeef): GetStats isn't const because all the way down to - // OpenSSLStreamAdapter, - // GetSslCipherSuite and GetDtlsSrtpCryptoSuite are not const. Fix this. + bool ReadyForRemoteCandidates(const std::string& transport_name); bool GetStats(const std::string& transport_name, TransportStats* stats); - void SetMetricsObserver(webrtc::MetricsObserverInterface* metrics_observer); // Creates a channel if it doesn't exist. Otherwise, increments a reference // count and returns an existing channel. @@ -113,17 +106,6 @@ class TransportController : public sigslot::has_slots<>, void use_quic() { quic_ = true; } bool quic() const { return quic_; } - // TODO(deadbeef): Remove all for_testing methods! - const rtc::scoped_refptr& certificate_for_testing() - const { - return certificate_; - } - std::vector transport_names_for_testing(); - std::vector channels_for_testing(); - TransportChannelImpl* get_channel_for_testing( - const std::string& transport_name, - int component); - // All of these signals are fired on the signalling thread. // If any transport failed => failed, @@ -146,33 +128,31 @@ class TransportController : public sigslot::has_slots<>, sigslot::signal1 SignalCandidatesRemoved; + // for unit test + const rtc::scoped_refptr& certificate_for_testing(); + sigslot::signal1 SignalDtlsHandshakeError; + void SetMetricsObserver(webrtc::MetricsObserverInterface* metrics_observer); + protected: - // TODO(deadbeef): Get rid of these virtual methods. Used by - // FakeTransportController currently, but FakeTransportController shouldn't - // even be functioning by subclassing TransportController. - virtual TransportChannelImpl* CreateIceTransportChannel_n( - const std::string& transport_name, - int component); - virtual TransportChannelImpl* CreateDtlsTransportChannel_n( - const std::string& transport_name, - int component, - TransportChannelImpl* ice); + // Protected and virtual so we can override it in unit tests. + virtual Transport* CreateTransport_n(const std::string& transport_name); + + // For unit tests + const std::map& transports() { return transports_; } + Transport* GetTransport_n(const std::string& transport_name); private: void OnMessage(rtc::Message* pmsg) override; - // This structure groups the DTLS and ICE channels, and helps keep track of - // how many external objects (BaseChannels) reference each channel. + // It's the Transport that's currently responsible for creating/destroying + // channels, but the TransportController keeps track of how many external + // objects (BaseChannels) reference each channel. struct RefCountedChannel { - RefCountedChannel() = default; - // TODO(deadbeef): Change the types of |dtls| and |ice| to - // DtlsTransportChannelWrapper and P2PTransportChannelWrapper, - // once TransportChannelImpl is removed. - explicit RefCountedChannel(TransportChannelImpl* dtls, - TransportChannelImpl* ice) - : ice_(ice), dtls_(dtls), ref_(0) {} + RefCountedChannel() : impl_(nullptr), ref_(0) {} + explicit RefCountedChannel(TransportChannelImpl* impl) + : impl_(impl), ref_(0) {} void AddRef() { ++ref_; } void DecRef() { @@ -181,51 +161,33 @@ class TransportController : public sigslot::has_slots<>, } int ref() const { return ref_; } - // Currently, all ICE-related calls still go through this DTLS channel. But - // that will change once we get rid of TransportChannelImpl, and the DTLS - // channel interface no longer includes ICE-specific methods. - const TransportChannelImpl* dtls() const { return dtls_.get(); } - TransportChannelImpl* dtls() { return dtls_.get(); } - const TransportChannelImpl* ice() const { return ice_.get(); } - TransportChannelImpl* ice() { return ice_.get(); } + TransportChannelImpl* get() const { return impl_; } + TransportChannelImpl* operator->() const { return impl_; } private: - std::unique_ptr ice_; - std::unique_ptr dtls_; - int ref_ = 0; + TransportChannelImpl* impl_; + int ref_; }; - // Helper functions to get a channel or transport, or iterator to it (in case - // it needs to be erased). - std::vector::iterator GetChannelIterator_n( + std::vector::iterator FindChannel_n( const std::string& transport_name, int component); - std::vector::const_iterator GetChannelIterator_n( - const std::string& transport_name, - int component) const; - const JsepTransport* GetJsepTransport_n( - const std::string& transport_name) const; - JsepTransport* GetJsepTransport_n(const std::string& transport_name); - const RefCountedChannel* GetChannel_n(const std::string& transport_name, - int component) const; - RefCountedChannel* GetChannel_n(const std::string& transport_name, - int component); - JsepTransport* GetOrCreateJsepTransport_n(const std::string& transport_name); - void DestroyAllChannels_n(); + Transport* GetOrCreateTransport_n(const std::string& transport_name); + void DestroyTransport_n(const std::string& transport_name); + void DestroyAllTransports_n(); bool SetSslMaxProtocolVersion_n(rtc::SSLProtocolVersion version); void SetIceConfig_n(const IceConfig& config); void SetIceRole_n(IceRole ice_role); - bool GetSslRole_n(const std::string& transport_name, - rtc::SSLRole* role) const; + bool GetSslRole_n(const std::string& transport_name, rtc::SSLRole* role); bool SetLocalCertificate_n( const rtc::scoped_refptr& certificate); bool GetLocalCertificate_n( const std::string& transport_name, - rtc::scoped_refptr* certificate) const; + rtc::scoped_refptr* certificate); std::unique_ptr GetRemoteSSLCertificate_n( - const std::string& transport_name) const; + const std::string& transport_name); bool SetLocalTransportDescription_n(const std::string& transport_name, const TransportDescription& tdesc, ContentAction action, @@ -239,9 +201,8 @@ class TransportController : public sigslot::has_slots<>, const Candidates& candidates, std::string* err); bool RemoveRemoteCandidates_n(const Candidates& candidates, std::string* err); - bool ReadyForRemoteCandidates_n(const std::string& transport_name) const; + bool ReadyForRemoteCandidates_n(const std::string& transport_name); bool GetStats_n(const std::string& transport_name, TransportStats* stats); - void SetMetricsObserver_n(webrtc::MetricsObserverInterface* metrics_observer); // Handlers for signals from Transport. void OnChannelWritableState_n(rtc::PacketTransportInterface* transport); @@ -261,21 +222,24 @@ class TransportController : public sigslot::has_slots<>, rtc::Thread* const signaling_thread_ = nullptr; rtc::Thread* const network_thread_ = nullptr; - PortAllocator* const port_allocator_ = nullptr; + typedef std::map TransportMap; + TransportMap transports_; - std::map> transports_; std::vector channels_; + PortAllocator* const port_allocator_ = nullptr; + rtc::SSLProtocolVersion ssl_max_version_ = rtc::SSL_PROTOCOL_DTLS_12; + // Aggregate state for TransportChannelImpls. IceConnectionState connection_state_ = kIceConnectionConnecting; bool receiving_ = false; IceGatheringState gathering_state_ = kIceGatheringNew; + // TODO(deadbeef): Move the fields below down to the transports themselves IceConfig ice_config_; IceRole ice_role_ = ICEROLE_CONTROLLING; bool redetermine_role_on_ice_restart_; uint64_t ice_tiebreaker_ = rtc::CreateRandomId64(); - rtc::SSLProtocolVersion ssl_max_version_ = rtc::SSL_PROTOCOL_DTLS_12; rtc::scoped_refptr certificate_; rtc::AsyncInvoker invoker_; // True if QUIC is used instead of DTLS. diff --git a/webrtc/p2p/base/transportcontroller_unittest.cc b/webrtc/p2p/base/transportcontroller_unittest.cc index 9f30518af7..d1f3816af9 100644 --- a/webrtc/p2p/base/transportcontroller_unittest.cc +++ b/webrtc/p2p/base/transportcontroller_unittest.cc @@ -749,48 +749,4 @@ TEST_F(TransportControllerTest, IceRoleNotRedetermined) { EXPECT_EQ(ICEROLE_CONTROLLED, channel->GetIceRole()); } -// Tests channel role is reversed after receiving ice-lite from remote. -TEST_F(TransportControllerTest, TestSetRemoteIceLiteInOffer) { - FakeTransportChannel* channel = CreateChannel("audio", 1); - ASSERT_NE(nullptr, channel); - std::string err; - - transport_controller_->SetIceRole(ICEROLE_CONTROLLED); - TransportDescription remote_desc(std::vector(), kIceUfrag1, - kIcePwd1, ICEMODE_LITE, - CONNECTIONROLE_ACTPASS, nullptr); - EXPECT_TRUE(transport_controller_->SetRemoteTransportDescription( - "audio", remote_desc, CA_OFFER, &err)); - TransportDescription local_desc(kIceUfrag1, kIcePwd1); - ASSERT_TRUE(transport_controller_->SetLocalTransportDescription( - "audio", local_desc, CA_ANSWER, nullptr)); - - EXPECT_EQ(ICEROLE_CONTROLLING, channel->GetIceRole()); - EXPECT_EQ(ICEMODE_LITE, channel->remote_ice_mode()); -} - -// Tests ice-lite in remote answer. -TEST_F(TransportControllerTest, TestSetRemoteIceLiteInAnswer) { - FakeTransportChannel* channel = CreateChannel("audio", 1); - ASSERT_NE(nullptr, channel); - std::string err; - - transport_controller_->SetIceRole(ICEROLE_CONTROLLING); - TransportDescription local_desc(kIceUfrag1, kIcePwd1); - ASSERT_TRUE(transport_controller_->SetLocalTransportDescription( - "audio", local_desc, CA_OFFER, nullptr)); - EXPECT_EQ(ICEROLE_CONTROLLING, channel->GetIceRole()); - // Channels will be created in ICEFULL_MODE. - EXPECT_EQ(ICEMODE_FULL, channel->remote_ice_mode()); - TransportDescription remote_desc(std::vector(), kIceUfrag1, - kIcePwd1, ICEMODE_LITE, CONNECTIONROLE_NONE, - nullptr); - ASSERT_TRUE(transport_controller_->SetRemoteTransportDescription( - "audio", remote_desc, CA_ANSWER, nullptr)); - EXPECT_EQ(ICEROLE_CONTROLLING, channel->GetIceRole()); - // After receiving remote description with ICEMODE_LITE, channel should - // have mode set to ICEMODE_LITE. - EXPECT_EQ(ICEMODE_LITE, channel->remote_ice_mode()); -} - } // namespace cricket { diff --git a/webrtc/p2p/client/socketmonitor.h b/webrtc/p2p/client/socketmonitor.h index b13be74735..00190d9d44 100644 --- a/webrtc/p2p/client/socketmonitor.h +++ b/webrtc/p2p/client/socketmonitor.h @@ -16,7 +16,7 @@ #include "webrtc/base/criticalsection.h" #include "webrtc/base/sigslot.h" #include "webrtc/base/thread.h" -#include "webrtc/p2p/base/jseptransport.h" // for ConnectionInfos +#include "webrtc/p2p/base/transport.h" // for ConnectionInfos // TODO(pthatcher): Move these to connectionmonitor.h and // connectionmonitor.cc, or just move them into channel.cc diff --git a/webrtc/p2p/quic/quictransport.h b/webrtc/p2p/quic/quictransport.h index 5e834e02f5..14bd13f3b2 100644 --- a/webrtc/p2p/quic/quictransport.h +++ b/webrtc/p2p/quic/quictransport.h @@ -15,7 +15,7 @@ #include #include -#include "webrtc/p2p/base/jseptransport.h" +#include "webrtc/p2p/base/transport.h" #include "webrtc/p2p/quic/quictransportchannel.h" namespace cricket { @@ -23,11 +23,7 @@ namespace cricket { class P2PTransportChannel; class PortAllocator; -// TODO(deadbeef): To get QUIC working with TransportController again, would -// need to merge this class with Transport (or make separate DTLS/QUIC -// subclasses). The only difference between the two (as of typing this) is that -// the QUIC channel *requires* a fingerprint, whereas the DTLS channel can -// operate in a passthrough mode when SDES is used. +// TODO(mikescarlett): Refactor to avoid code duplication with DtlsTransport. class QuicTransport : public Transport { public: QuicTransport(const std::string& name, diff --git a/webrtc/pc/channel_unittest.cc b/webrtc/pc/channel_unittest.cc index 85a7e697a8..52bf8f53dc 100644 --- a/webrtc/pc/channel_unittest.cc +++ b/webrtc/pc/channel_unittest.cc @@ -21,7 +21,6 @@ #include "webrtc/media/base/mediachannel.h" #include "webrtc/media/base/testutils.h" #include "webrtc/p2p/base/faketransportcontroller.h" -#include "webrtc/p2p/base/transportchannelimpl.h" #include "webrtc/pc/channel.h" #define MAYBE_SKIP_TEST(feature) \ @@ -272,22 +271,17 @@ class ChannelTest : public testing::Test, public sigslot::has_slots<> { return channel1_->RemoveRecvStream(id); } - std::vector GetChannels1() { - return transport_controller1_->channels_for_testing(); + cricket::FakeTransport* GetTransport1() { + std::string name = channel1_->content_name(); + return network_thread_->Invoke( + RTC_FROM_HERE, + [this, name] { return transport_controller1_->GetTransport_n(name); }); } - - std::vector GetChannels2() { - return transport_controller2_->channels_for_testing(); - } - - cricket::FakeTransportChannel* GetFakeChannel1(int component) { - return transport_controller1_->GetFakeTransportChannel_n( - channel1_->content_name(), component); - } - - cricket::FakeTransportChannel* GetFakeChannel2(int component) { - return transport_controller2_->GetFakeTransportChannel_n( - channel2_->content_name(), component); + cricket::FakeTransport* GetTransport2() { + std::string name = channel2_->content_name(); + return network_thread_->Invoke( + RTC_FROM_HERE, + [this, name] { return transport_controller2_->GetTransport_n(name); }); } void SendRtp1() { @@ -1020,8 +1014,10 @@ class ChannelTest : public testing::Test, public sigslot::has_slots<> { CreateChannels(0, 0); EXPECT_TRUE(SendInitiate()); EXPECT_TRUE(SendAccept()); - EXPECT_EQ(1U, GetChannels1().size()); - EXPECT_EQ(1U, GetChannels2().size()); + ASSERT_TRUE(GetTransport1()); + ASSERT_TRUE(GetTransport2()); + EXPECT_EQ(1U, GetTransport1()->channels().size()); + EXPECT_EQ(1U, GetTransport2()->channels().size()); SendRtp1(); SendRtp2(); WaitForThreads(); @@ -1049,8 +1045,10 @@ class ChannelTest : public testing::Test, public sigslot::has_slots<> { CreateChannels(0, 0); EXPECT_TRUE(SendInitiate()); EXPECT_TRUE(SendAccept()); - EXPECT_EQ(1U, GetChannels1().size()); - EXPECT_EQ(1U, GetChannels2().size()); + ASSERT_TRUE(GetTransport1()); + ASSERT_TRUE(GetTransport2()); + EXPECT_EQ(1U, GetTransport1()->channels().size()); + EXPECT_EQ(1U, GetTransport2()->channels().size()); SendRtcp1(); SendRtcp2(); WaitForThreads(); @@ -1063,8 +1061,10 @@ class ChannelTest : public testing::Test, public sigslot::has_slots<> { CreateChannels(0, RTCP); EXPECT_TRUE(SendInitiate()); EXPECT_TRUE(SendAccept()); - EXPECT_EQ(1U, GetChannels1().size()); - EXPECT_EQ(2U, GetChannels2().size()); + ASSERT_TRUE(GetTransport1()); + ASSERT_TRUE(GetTransport2()); + EXPECT_EQ(1U, GetTransport1()->channels().size()); + EXPECT_EQ(2U, GetTransport2()->channels().size()); SendRtcp1(); SendRtcp2(); WaitForThreads(); @@ -1077,8 +1077,10 @@ class ChannelTest : public testing::Test, public sigslot::has_slots<> { CreateChannels(RTCP, 0); EXPECT_TRUE(SendInitiate()); EXPECT_TRUE(SendAccept()); - EXPECT_EQ(2U, GetChannels1().size()); - EXPECT_EQ(1U, GetChannels2().size()); + ASSERT_TRUE(GetTransport1()); + ASSERT_TRUE(GetTransport2()); + EXPECT_EQ(2U, GetTransport1()->channels().size()); + EXPECT_EQ(1U, GetTransport2()->channels().size()); SendRtcp1(); SendRtcp2(); WaitForThreads(); @@ -1091,8 +1093,10 @@ class ChannelTest : public testing::Test, public sigslot::has_slots<> { CreateChannels(RTCP, RTCP); EXPECT_TRUE(SendInitiate()); EXPECT_TRUE(SendAccept()); - EXPECT_EQ(2U, GetChannels1().size()); - EXPECT_EQ(2U, GetChannels2().size()); + ASSERT_TRUE(GetTransport1()); + ASSERT_TRUE(GetTransport2()); + EXPECT_EQ(2U, GetTransport1()->channels().size()); + EXPECT_EQ(2U, GetTransport2()->channels().size()); SendRtcp1(); SendRtcp2(); WaitForThreads(); @@ -1107,8 +1111,10 @@ class ChannelTest : public testing::Test, public sigslot::has_slots<> { CreateChannels(RTCP | RTCP_MUX, RTCP); EXPECT_TRUE(SendInitiate()); EXPECT_TRUE(SendAccept()); - EXPECT_EQ(2U, GetChannels1().size()); - EXPECT_EQ(2U, GetChannels2().size()); + ASSERT_TRUE(GetTransport1()); + ASSERT_TRUE(GetTransport2()); + EXPECT_EQ(2U, GetTransport1()->channels().size()); + EXPECT_EQ(2U, GetTransport2()->channels().size()); SendRtcp1(); SendRtcp2(); WaitForThreads(); @@ -1122,10 +1128,12 @@ class ChannelTest : public testing::Test, public sigslot::has_slots<> { void SendRtcpMuxToRtcpMux() { CreateChannels(RTCP | RTCP_MUX, RTCP | RTCP_MUX); EXPECT_TRUE(SendInitiate()); - EXPECT_EQ(2U, GetChannels1().size()); - EXPECT_EQ(1U, GetChannels2().size()); + ASSERT_TRUE(GetTransport1()); + ASSERT_TRUE(GetTransport2()); + EXPECT_EQ(2U, GetTransport1()->channels().size()); + EXPECT_EQ(1U, GetTransport2()->channels().size()); EXPECT_TRUE(SendAccept()); - EXPECT_EQ(1U, GetChannels1().size()); + EXPECT_EQ(1U, GetTransport1()->channels().size()); SendRtp1(); SendRtp2(); SendRtcp1(); @@ -1147,8 +1155,10 @@ class ChannelTest : public testing::Test, public sigslot::has_slots<> { CreateChannels(RTCP | RTCP_MUX, RTCP | RTCP_MUX); channel1_->ActivateRtcpMux(); EXPECT_TRUE(SendInitiate()); - EXPECT_EQ(1U, GetChannels1().size()); - EXPECT_EQ(1U, GetChannels2().size()); + ASSERT_TRUE(GetTransport1()); + ASSERT_TRUE(GetTransport2()); + EXPECT_EQ(1U, GetTransport1()->channels().size()); + EXPECT_EQ(1U, GetTransport2()->channels().size()); EXPECT_TRUE(SendAccept()); SendRtp1(); SendRtp2(); @@ -1171,10 +1181,12 @@ class ChannelTest : public testing::Test, public sigslot::has_slots<> { CreateChannels(RTCP | RTCP_MUX, RTCP | RTCP_MUX); channel2_->ActivateRtcpMux(); EXPECT_TRUE(SendInitiate()); - EXPECT_EQ(2U, GetChannels1().size()); - EXPECT_EQ(1U, GetChannels2().size()); + ASSERT_TRUE(GetTransport1()); + ASSERT_TRUE(GetTransport2()); + EXPECT_EQ(2U, GetTransport1()->channels().size()); + EXPECT_EQ(1U, GetTransport2()->channels().size()); EXPECT_TRUE(SendAccept()); - EXPECT_EQ(1U, GetChannels1().size()); + EXPECT_EQ(1U, GetTransport1()->channels().size()); SendRtp1(); SendRtp2(); SendRtcp1(); @@ -1197,10 +1209,12 @@ class ChannelTest : public testing::Test, public sigslot::has_slots<> { channel1_->ActivateRtcpMux(); channel2_->ActivateRtcpMux(); EXPECT_TRUE(SendInitiate()); - EXPECT_EQ(1U, GetChannels1().size()); - EXPECT_EQ(1U, GetChannels2().size()); + ASSERT_TRUE(GetTransport1()); + ASSERT_TRUE(GetTransport2()); + EXPECT_EQ(1U, GetTransport1()->channels().size()); + EXPECT_EQ(1U, GetTransport2()->channels().size()); EXPECT_TRUE(SendAccept()); - EXPECT_EQ(1U, GetChannels1().size()); + EXPECT_EQ(1U, GetTransport1()->channels().size()); SendRtp1(); SendRtp2(); SendRtcp1(); @@ -1222,8 +1236,10 @@ class ChannelTest : public testing::Test, public sigslot::has_slots<> { CreateChannels(RTCP | RTCP_MUX, RTCP); channel1_->ActivateRtcpMux(); EXPECT_TRUE(SendInitiate()); - EXPECT_EQ(1U, GetChannels1().size()); - EXPECT_EQ(2U, GetChannels2().size()); + ASSERT_TRUE(GetTransport1()); + ASSERT_TRUE(GetTransport2()); + EXPECT_EQ(1U, GetTransport1()->channels().size()); + EXPECT_EQ(2U, GetTransport2()->channels().size()); EXPECT_FALSE(SendAccept()); } @@ -1231,8 +1247,10 @@ class ChannelTest : public testing::Test, public sigslot::has_slots<> { void SendEarlyRtcpMuxToRtcp() { CreateChannels(RTCP | RTCP_MUX, RTCP); EXPECT_TRUE(SendInitiate()); - EXPECT_EQ(2U, GetChannels1().size()); - EXPECT_EQ(2U, GetChannels2().size()); + ASSERT_TRUE(GetTransport1()); + ASSERT_TRUE(GetTransport2()); + EXPECT_EQ(2U, GetTransport1()->channels().size()); + EXPECT_EQ(2U, GetTransport2()->channels().size()); // RTCP can be sent before the call is accepted, if the transport is ready. // It should not be muxed though, as the remote side doesn't support mux. @@ -1249,7 +1267,7 @@ class ChannelTest : public testing::Test, public sigslot::has_slots<> { // Complete call setup and ensure everything is still OK. EXPECT_TRUE(SendAccept()); - EXPECT_EQ(2U, GetChannels1().size()); + EXPECT_EQ(2U, GetTransport1()->channels().size()); SendRtcp1(); SendRtcp2(); WaitForThreads(); @@ -1264,8 +1282,10 @@ class ChannelTest : public testing::Test, public sigslot::has_slots<> { void SendEarlyRtcpMuxToRtcpMux() { CreateChannels(RTCP | RTCP_MUX, RTCP | RTCP_MUX); EXPECT_TRUE(SendInitiate()); - EXPECT_EQ(2U, GetChannels1().size()); - EXPECT_EQ(1U, GetChannels2().size()); + ASSERT_TRUE(GetTransport1()); + ASSERT_TRUE(GetTransport2()); + EXPECT_EQ(2U, GetTransport1()->channels().size()); + EXPECT_EQ(1U, GetTransport2()->channels().size()); // RTCP can't be sent yet, since the RTCP transport isn't writable, and // we haven't yet received the accept that says we should mux. @@ -1281,7 +1301,7 @@ class ChannelTest : public testing::Test, public sigslot::has_slots<> { // Complete call setup and ensure everything is still OK. EXPECT_TRUE(SendAccept()); - EXPECT_EQ(1U, GetChannels1().size()); + EXPECT_EQ(1U, GetTransport1()->channels().size()); SendRtcp1(); SendRtcp2(); WaitForThreads(); @@ -1369,8 +1389,10 @@ class ChannelTest : public testing::Test, public sigslot::has_slots<> { EXPECT_TRUE(SendProvisionalAnswer()); EXPECT_TRUE(channel1_->secure()); EXPECT_TRUE(channel2_->secure()); - EXPECT_EQ(2U, GetChannels1().size()); - EXPECT_EQ(2U, GetChannels2().size()); + ASSERT_TRUE(GetTransport1()); + ASSERT_TRUE(GetTransport2()); + EXPECT_EQ(2U, GetTransport1()->channels().size()); + EXPECT_EQ(2U, GetTransport2()->channels().size()); WaitForThreads(); // Wait for 'sending' flag go through network thread. SendCustomRtcp1(kSsrc1); SendCustomRtp1(kSsrc1, ++sequence_number1_1); @@ -1387,8 +1409,8 @@ class ChannelTest : public testing::Test, public sigslot::has_slots<> { // Complete call setup and ensure everything is still OK. EXPECT_TRUE(SendFinalAnswer()); - EXPECT_EQ(1U, GetChannels1().size()); - EXPECT_EQ(1U, GetChannels2().size()); + EXPECT_EQ(1U, GetTransport1()->channels().size()); + EXPECT_EQ(1U, GetTransport2()->channels().size()); EXPECT_TRUE(channel1_->secure()); EXPECT_TRUE(channel2_->secure()); SendCustomRtcp1(kSsrc1); @@ -1454,8 +1476,10 @@ class ChannelTest : public testing::Test, public sigslot::has_slots<> { CreateChannels(0, 0); EXPECT_TRUE(SendInitiate()); EXPECT_TRUE(SendAccept()); - EXPECT_EQ(1U, GetChannels1().size()); - EXPECT_EQ(1U, GetChannels2().size()); + ASSERT_TRUE(GetTransport1()); + ASSERT_TRUE(GetTransport2()); + EXPECT_EQ(1U, GetTransport1()->channels().size()); + EXPECT_EQ(1U, GetTransport2()->channels().size()); SendRtp1(); SendRtp2(); WaitForThreads(); @@ -1466,7 +1490,7 @@ class ChannelTest : public testing::Test, public sigslot::has_slots<> { // Lose writability, which should fail. network_thread_->Invoke( - RTC_FROM_HERE, [this] { GetFakeChannel1(1)->SetWritable(false); }); + RTC_FROM_HERE, [this] { GetTransport1()->SetWritable(false); }); SendRtp1(); SendRtp2(); WaitForThreads(); @@ -1475,7 +1499,7 @@ class ChannelTest : public testing::Test, public sigslot::has_slots<> { // Regain writability network_thread_->Invoke( - RTC_FROM_HERE, [this] { GetFakeChannel1(1)->SetWritable(true); }); + RTC_FROM_HERE, [this] { GetTransport1()->SetWritable(true); }); EXPECT_TRUE(media_channel1_->sending()); SendRtp1(); SendRtp2(); @@ -1487,7 +1511,7 @@ class ChannelTest : public testing::Test, public sigslot::has_slots<> { // Lose writability completely network_thread_->Invoke( - RTC_FROM_HERE, [this] { GetFakeChannel1(1)->SetDestination(nullptr); }); + RTC_FROM_HERE, [this] { GetTransport1()->SetDestination(NULL); }); EXPECT_TRUE(media_channel1_->sending()); // Should fail also. @@ -1499,7 +1523,7 @@ class ChannelTest : public testing::Test, public sigslot::has_slots<> { // Gain writability back network_thread_->Invoke(RTC_FROM_HERE, [this] { - GetFakeChannel1(1)->SetDestination(GetFakeChannel2(1)); + GetTransport1()->SetDestination(GetTransport2()); }); EXPECT_TRUE(media_channel1_->sending()); SendRtp1(); @@ -1528,11 +1552,13 @@ class ChannelTest : public testing::Test, public sigslot::has_slots<> { } CreateChannels(flags, flags); EXPECT_TRUE(SendInitiate()); - EXPECT_EQ(2U, GetChannels1().size()); - EXPECT_EQ(expected_channels, GetChannels2().size()); + ASSERT_TRUE(GetTransport1()); + ASSERT_TRUE(GetTransport2()); + EXPECT_EQ(2U, GetTransport1()->channels().size()); + EXPECT_EQ(expected_channels, GetTransport2()->channels().size()); EXPECT_TRUE(SendAccept()); - EXPECT_EQ(expected_channels, GetChannels1().size()); - EXPECT_EQ(expected_channels, GetChannels2().size()); + EXPECT_EQ(expected_channels, GetTransport1()->channels().size()); + EXPECT_EQ(expected_channels, GetTransport2()->channels().size()); EXPECT_TRUE(channel1_->bundle_filter()->FindPayloadType(pl_type1)); EXPECT_TRUE(channel2_->bundle_filter()->FindPayloadType(pl_type1)); EXPECT_FALSE(channel1_->bundle_filter()->FindPayloadType(pl_type2)); @@ -1712,8 +1738,10 @@ class ChannelTest : public testing::Test, public sigslot::has_slots<> { CreateChannels(RTCP, RTCP); EXPECT_TRUE(SendInitiate()); EXPECT_TRUE(SendAccept()); - EXPECT_EQ(2U, GetChannels1().size()); - EXPECT_EQ(2U, GetChannels2().size()); + ASSERT_TRUE(GetTransport1()); + ASSERT_TRUE(GetTransport2()); + EXPECT_EQ(2U, GetTransport1()->channels().size()); + EXPECT_EQ(2U, GetTransport2()->channels().size()); // Send RTCP1 from a different thread. ScopedCallThread send_rtcp([this] { SendRtcp1(); }); diff --git a/webrtc/pc/channelmanager_unittest.cc b/webrtc/pc/channelmanager_unittest.cc index 4e14453c63..174d064b05 100644 --- a/webrtc/pc/channelmanager_unittest.cc +++ b/webrtc/pc/channelmanager_unittest.cc @@ -143,6 +143,31 @@ TEST_F(ChannelManagerTest, CreateDestroyChannelsOnThread) { cm_->Terminate(); } +// Test that we fail to create a voice/video channel if the session is unable +// to create a cricket::TransportChannel +TEST_F(ChannelManagerTest, NoTransportChannelTest) { + EXPECT_TRUE(cm_->Init()); + transport_controller_->set_fail_channel_creation(true); + // The test is useless unless the session does not fail creating + // cricket::TransportChannel. + ASSERT_TRUE(transport_controller_->CreateTransportChannel_n( + "audio", cricket::ICE_CANDIDATE_COMPONENT_RTP) == nullptr); + + cricket::VoiceChannel* voice_channel = cm_->CreateVoiceChannel( + &fake_mc_, transport_controller_, cricket::CN_AUDIO, nullptr, false, + AudioOptions()); + EXPECT_TRUE(voice_channel == nullptr); + cricket::VideoChannel* video_channel = cm_->CreateVideoChannel( + &fake_mc_, transport_controller_, cricket::CN_VIDEO, nullptr, false, + VideoOptions()); + EXPECT_TRUE(video_channel == nullptr); + cricket::DataChannel* data_channel = + cm_->CreateDataChannel(transport_controller_, cricket::CN_DATA, nullptr, + false, cricket::DCT_RTP); + EXPECT_TRUE(data_channel == nullptr); + cm_->Terminate(); +} + TEST_F(ChannelManagerTest, SetVideoRtxEnabled) { std::vector codecs; const VideoCodec rtx_codec(96, "rtx"); diff --git a/webrtc/pc/mediasession.h b/webrtc/pc/mediasession.h index ee2126d912..289bc35642 100644 --- a/webrtc/pc/mediasession.h +++ b/webrtc/pc/mediasession.h @@ -25,7 +25,7 @@ #include "webrtc/media/base/mediaengine.h" // For DataChannelType #include "webrtc/media/base/streamparams.h" #include "webrtc/p2p/base/sessiondescription.h" -#include "webrtc/p2p/base/jseptransport.h" +#include "webrtc/p2p/base/transport.h" #include "webrtc/p2p/base/transportdescriptionfactory.h" namespace cricket {