diff --git a/p2p/BUILD.gn b/p2p/BUILD.gn index 5d4e706b9b..f1466d706e 100644 --- a/p2p/BUILD.gn +++ b/p2p/BUILD.gn @@ -29,6 +29,8 @@ rtc_static_library("rtc_p2p") { "base/dtlstransport.h", "base/dtlstransportinternal.cc", "base/dtlstransportinternal.h", + "base/icecredentialsiterator.cc", + "base/icecredentialsiterator.h", "base/icetransportinternal.cc", "base/icetransportinternal.h", "base/mdns_message.cc", @@ -154,6 +156,7 @@ if (rtc_include_tests) { "base/asyncstuntcpsocket_unittest.cc", "base/basicasyncresolverfactory_unittest.cc", "base/dtlstransport_unittest.cc", + "base/icecredentialsiterator_unittest.cc", "base/mdns_message_unittest.cc", "base/p2ptransportchannel_unittest.cc", "base/packetlossestimator_unittest.cc", diff --git a/p2p/base/icecredentialsiterator.cc b/p2p/base/icecredentialsiterator.cc new file mode 100644 index 0000000000..7d29653440 --- /dev/null +++ b/p2p/base/icecredentialsiterator.cc @@ -0,0 +1,36 @@ +/* + * Copyright 2018 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "p2p/base/icecredentialsiterator.h" +#include "rtc_base/helpers.h" + +namespace cricket { + +IceCredentialsIterator::IceCredentialsIterator( + const std::vector& pooled_credentials) + : pooled_ice_credentials_(pooled_credentials) {} + +IceCredentialsIterator::~IceCredentialsIterator() = default; + +IceParameters IceCredentialsIterator::CreateRandomIceCredentials() { + return IceParameters(rtc::CreateRandomString(ICE_UFRAG_LENGTH), + rtc::CreateRandomString(ICE_PWD_LENGTH), false); +} + +IceParameters IceCredentialsIterator::GetIceCredentials() { + if (pooled_ice_credentials_.empty()) { + return CreateRandomIceCredentials(); + } + IceParameters credentials = pooled_ice_credentials_.back(); + pooled_ice_credentials_.pop_back(); + return credentials; +} + +} // namespace cricket diff --git a/p2p/base/icecredentialsiterator.h b/p2p/base/icecredentialsiterator.h new file mode 100644 index 0000000000..33e1d6460a --- /dev/null +++ b/p2p/base/icecredentialsiterator.h @@ -0,0 +1,37 @@ +/* + * Copyright 2018 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef P2P_BASE_ICECREDENTIALSITERATOR_H_ +#define P2P_BASE_ICECREDENTIALSITERATOR_H_ + +#include + +#include "p2p/base/transportdescription.h" + +namespace cricket { + +class IceCredentialsIterator { + public: + explicit IceCredentialsIterator(const std::vector&); + virtual ~IceCredentialsIterator(); + + // Get next pooled ice credentials. + // Returns a new random credential if the pool is empty. + IceParameters GetIceCredentials(); + + static IceParameters CreateRandomIceCredentials(); + + private: + std::vector pooled_ice_credentials_; +}; + +} // namespace cricket + +#endif // P2P_BASE_ICECREDENTIALSITERATOR_H_ diff --git a/p2p/base/icecredentialsiterator_unittest.cc b/p2p/base/icecredentialsiterator_unittest.cc new file mode 100644 index 0000000000..00facfbb88 --- /dev/null +++ b/p2p/base/icecredentialsiterator_unittest.cc @@ -0,0 +1,49 @@ +/* + * Copyright 2018 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include +#include +#include + +#include "p2p/base/icecredentialsiterator.h" +#include "rtc_base/gunit.h" + +using cricket::IceParameters; +using cricket::IceCredentialsIterator; + +TEST(IceCredentialsIteratorTest, GetEmpty) { + std::vector empty; + IceCredentialsIterator iterator(empty); + // Verify that we can get credentials even if input is empty. + IceParameters credentials1 = iterator.GetIceCredentials(); +} + +TEST(IceCredentialsIteratorTest, GetOne) { + std::vector one = { + IceCredentialsIterator::CreateRandomIceCredentials()}; + IceCredentialsIterator iterator(one); + EXPECT_EQ(iterator.GetIceCredentials(), one[0]); + auto random = iterator.GetIceCredentials(); + EXPECT_NE(random, one[0]); + EXPECT_NE(random, iterator.GetIceCredentials()); +} + +TEST(IceCredentialsIteratorTest, GetTwo) { + std::vector two = { + IceCredentialsIterator::CreateRandomIceCredentials(), + IceCredentialsIterator::CreateRandomIceCredentials()}; + IceCredentialsIterator iterator(two); + EXPECT_EQ(iterator.GetIceCredentials(), two[1]); + EXPECT_EQ(iterator.GetIceCredentials(), two[0]); + auto random = iterator.GetIceCredentials(); + EXPECT_NE(random, two[0]); + EXPECT_NE(random, two[1]); + EXPECT_NE(random, iterator.GetIceCredentials()); +} diff --git a/p2p/base/portallocator.cc b/p2p/base/portallocator.cc index 5470b5b82e..d3b3a56edd 100644 --- a/p2p/base/portallocator.cc +++ b/p2p/base/portallocator.cc @@ -10,8 +10,10 @@ #include "p2p/base/portallocator.h" +#include #include +#include "p2p/base/icecredentialsiterator.h" #include "rtc_base/checks.h" namespace cricket { @@ -121,6 +123,10 @@ PortAllocator::~PortAllocator() { CheckRunOnValidThreadIfInitialized(); } +void PortAllocator::set_restrict_ice_credentials_change(bool value) { + restrict_ice_credentials_change_ = value; +} + bool PortAllocator::SetConfiguration( const ServerAddresses& stun_servers, const std::vector& turn_servers, @@ -166,8 +172,8 @@ bool PortAllocator::SetConfiguration( // If |candidate_pool_size_| is less than the number of pooled sessions, get // rid of the extras. while (candidate_pool_size_ < static_cast(pooled_sessions_.size())) { - pooled_sessions_.front().reset(nullptr); - pooled_sessions_.pop_front(); + pooled_sessions_.back().reset(nullptr); + pooled_sessions_.pop_back(); } // |stun_candidate_keepalive_interval_| will be used in STUN port allocation @@ -183,7 +189,11 @@ bool PortAllocator::SetConfiguration( // If |candidate_pool_size_| is greater than the number of pooled sessions, // create new sessions. while (static_cast(pooled_sessions_.size()) < candidate_pool_size_) { - PortAllocatorSession* pooled_session = CreateSessionInternal("", 0, "", ""); + IceParameters iceCredentials = + IceCredentialsIterator::CreateRandomIceCredentials(); + PortAllocatorSession* pooled_session = + CreateSessionInternal("", 0, iceCredentials.ufrag, iceCredentials.pwd); + pooled_session->set_pooled(true); pooled_session->StartGettingPorts(); pooled_sessions_.push_back( std::unique_ptr(pooled_session)); @@ -214,22 +224,50 @@ std::unique_ptr PortAllocator::TakePooledSession( if (pooled_sessions_.empty()) { return nullptr; } - std::unique_ptr ret = - std::move(pooled_sessions_.front()); + + IceParameters credentials(ice_ufrag, ice_pwd, false); + // If restrict_ice_credentials_change_ is TRUE, then call FindPooledSession + // with ice credentials. Otherwise call it with nullptr which means + // "find any" pooled session. + auto cit = FindPooledSession(restrict_ice_credentials_change_ ? &credentials + : nullptr); + if (cit == pooled_sessions_.end()) { + return nullptr; + } + + auto it = + pooled_sessions_.begin() + std::distance(pooled_sessions_.cbegin(), cit); + std::unique_ptr ret = std::move(*it); ret->SetIceParameters(content_name, component, ice_ufrag, ice_pwd); - // According to JSEP, a pooled session should filter candidates only after - // it's taken out of the pool. + ret->set_pooled(false); + // According to JSEP, a pooled session should filter candidates only + // after it's taken out of the pool. ret->SetCandidateFilter(candidate_filter()); - pooled_sessions_.pop_front(); + pooled_sessions_.erase(it); return ret; } -const PortAllocatorSession* PortAllocator::GetPooledSession() const { +const PortAllocatorSession* PortAllocator::GetPooledSession( + const IceParameters* ice_credentials) const { CheckRunOnValidThreadAndInitialized(); - if (pooled_sessions_.empty()) { + auto it = FindPooledSession(ice_credentials); + if (it == pooled_sessions_.end()) { return nullptr; + } else { + return it->get(); } - return pooled_sessions_.front().get(); +} + +std::vector>::const_iterator +PortAllocator::FindPooledSession(const IceParameters* ice_credentials) const { + for (auto it = pooled_sessions_.begin(); it != pooled_sessions_.end(); ++it) { + if (ice_credentials == nullptr || + ((*it)->ice_ufrag() == ice_credentials->ufrag && + (*it)->ice_pwd() == ice_credentials->pwd)) { + return it; + } + } + return pooled_sessions_.end(); } void PortAllocator::FreezeCandidatePool() { @@ -250,4 +288,14 @@ void PortAllocator::GetCandidateStatsFromPooledSessions( } } +std::vector PortAllocator::GetPooledIceCredentials() { + CheckRunOnValidThreadAndInitialized(); + std::vector list; + for (const auto& session : pooled_sessions_) { + list.push_back( + IceParameters(session->ice_ufrag(), session->ice_pwd(), false)); + } + return list; +} + } // namespace cricket diff --git a/p2p/base/portallocator.h b/p2p/base/portallocator.h index 8bd709642c..988447c156 100644 --- a/p2p/base/portallocator.h +++ b/p2p/base/portallocator.h @@ -201,7 +201,7 @@ class PortAllocatorSession : public sigslot::has_slots<> { int component() const { return component_; } const std::string& ice_ufrag() const { return ice_ufrag_; } const std::string& ice_pwd() const { return ice_pwd_; } - bool pooled() const { return ice_ufrag_.empty(); } + bool pooled() const { return pooled_; } // Setting this filter should affect not only candidates gathered in the // future, but candidates already gathered and ports already "ready", @@ -309,6 +309,8 @@ class PortAllocatorSession : public sigslot::has_slots<> { UpdateIceParametersInternal(); } + void set_pooled(bool value) { pooled_ = value; } + uint32_t flags_; uint32_t generation_; std::string content_name_; @@ -316,6 +318,8 @@ class PortAllocatorSession : public sigslot::has_slots<> { std::string ice_ufrag_; std::string ice_pwd_; + bool pooled_ = false; + // SetIceParameters is an implementation detail which only PortAllocator // should be able to call. friend class PortAllocator; @@ -335,6 +339,11 @@ class PortAllocator : public sigslot::has_slots<> { // constructing and configuring the PortAllocator subclasses. virtual void Initialize(); + // Set to true if some Ports need to know the ICE credentials when they are + // created. This will ensure that the PortAllocator will only match pooled + // allocator sessions to the ICE transport with the same credentials. + virtual void set_restrict_ice_credentials_change(bool value); + // Set STUN and TURN servers to be used in future sessions, and set // candidate pool size, as described in JSEP. // @@ -392,6 +401,8 @@ class PortAllocator : public sigslot::has_slots<> { // // Caller takes ownership of the returned session. // + // If restrict_ice_credentials_change is TRUE, then it will only + // return a pooled session with matching ice credentials. // If no pooled sessions are available, returns null. std::unique_ptr TakePooledSession( const std::string& content_name, @@ -399,8 +410,10 @@ class PortAllocator : public sigslot::has_slots<> { const std::string& ice_ufrag, const std::string& ice_pwd); - // Returns the next session that would be returned by TakePooledSession. - const PortAllocatorSession* GetPooledSession() const; + // Returns the next session that would be returned by TakePooledSession + // optionally restricting it to sessions with specified ice credentials. + const PortAllocatorSession* GetPooledSession( + const IceParameters* ice_credentials = nullptr) const; // After FreezeCandidatePool is called, changing the candidate pool size will // no longer be allowed, and changing ICE servers will not cause pooled @@ -548,6 +561,9 @@ class PortAllocator : public sigslot::has_slots<> { virtual void GetCandidateStatsFromPooledSessions( CandidateStatsList* candidate_stats_list); + // Return IceParameters of the pooled sessions. + std::vector GetPooledIceCredentials(); + protected: virtual PortAllocatorSession* CreateSessionInternal( const std::string& content_name, @@ -555,7 +571,7 @@ class PortAllocator : public sigslot::has_slots<> { const std::string& ice_ufrag, const std::string& ice_pwd) = 0; - const std::deque>& pooled_sessions() { + const std::vector>& pooled_sessions() { return pooled_sessions_; } @@ -586,7 +602,7 @@ class PortAllocator : public sigslot::has_slots<> { ServerAddresses stun_servers_; std::vector turn_servers_; int candidate_pool_size_ = 0; // Last value passed into SetConfiguration. - std::deque> pooled_sessions_; + std::vector> pooled_sessions_; bool candidate_pool_frozen_ = false; bool prune_turn_ports_ = false; @@ -596,6 +612,15 @@ class PortAllocator : public sigslot::has_slots<> { webrtc::TurnCustomizer* turn_customizer_ = nullptr; absl::optional stun_candidate_keepalive_interval_; + + // If true, TakePooledSession() will only return sessions that has same ice + // credentials as requested. + bool restrict_ice_credentials_change_ = false; + + // Returns iterator to pooled session with specified ice_credentials or first + // if ice_credentials is nullptr. + std::vector>::const_iterator + FindPooledSession(const IceParameters* ice_credentials = nullptr) const; }; } // namespace cricket diff --git a/p2p/base/portallocator_unittest.cc b/p2p/base/portallocator_unittest.cc index 3887a90db1..8b317f4d5a 100644 --- a/p2p/base/portallocator_unittest.cc +++ b/p2p/base/portallocator_unittest.cc @@ -71,8 +71,7 @@ class PortAllocatorTest : public testing::Test, public sigslot::has_slots<> { int GetAllPooledSessionsReturnCount() { int count = 0; - while (GetPooledSession()) { - TakePooledSession(); + while (TakePooledSession() != nullptr) { ++count; } return count; @@ -275,3 +274,29 @@ TEST_F(PortAllocatorTest, DiscardCandidatePool) { allocator_->DiscardCandidatePool(); EXPECT_EQ(0, GetAllPooledSessionsReturnCount()); } + +TEST_F(PortAllocatorTest, RestrictIceCredentialsChange) { + SetConfigurationWithPoolSize(1); + EXPECT_EQ(1, GetAllPooledSessionsReturnCount()); + allocator_->DiscardCandidatePool(); + + // Only return pooled sessions with the ice credentials that + // match those requested in TakePooledSession(). + allocator_->set_restrict_ice_credentials_change(true); + SetConfigurationWithPoolSize(1); + EXPECT_EQ(0, GetAllPooledSessionsReturnCount()); + allocator_->DiscardCandidatePool(); + + SetConfigurationWithPoolSize(1); + auto credentials = allocator_->GetPooledIceCredentials(); + ASSERT_EQ(1u, credentials.size()); + EXPECT_EQ(nullptr, + allocator_->TakePooledSession(kContentName, 0, kIceUfrag, kIcePwd)); + EXPECT_NE(nullptr, + allocator_->TakePooledSession(kContentName, 0, credentials[0].ufrag, + credentials[0].pwd)); + EXPECT_EQ(nullptr, + allocator_->TakePooledSession(kContentName, 0, credentials[0].ufrag, + credentials[0].pwd)); + allocator_->DiscardCandidatePool(); +} diff --git a/p2p/base/transportdescription.h b/p2p/base/transportdescription.h index 3bffdf971a..2ab973278f 100644 --- a/p2p/base/transportdescription.h +++ b/p2p/base/transportdescription.h @@ -67,11 +67,13 @@ struct IceParameters { bool ice_renomination) : ufrag(ice_ufrag), pwd(ice_pwd), renomination(ice_renomination) {} - bool operator==(const IceParameters& other) { + bool operator==(const IceParameters& other) const { return ufrag == other.ufrag && pwd == other.pwd && renomination == other.renomination; } - bool operator!=(const IceParameters& other) { return !(*this == other); } + bool operator!=(const IceParameters& other) const { + return !(*this == other); + } }; extern const char CONNECTIONROLE_ACTIVE_STR[]; diff --git a/p2p/base/transportdescriptionfactory.cc b/p2p/base/transportdescriptionfactory.cc index 618726e841..670950d5ec 100644 --- a/p2p/base/transportdescriptionfactory.cc +++ b/p2p/base/transportdescriptionfactory.cc @@ -27,13 +27,15 @@ TransportDescriptionFactory::~TransportDescriptionFactory() = default; TransportDescription* TransportDescriptionFactory::CreateOffer( const TransportOptions& options, - const TransportDescription* current_description) const { + const TransportDescription* current_description, + IceCredentialsIterator* ice_credentials) const { std::unique_ptr desc(new TransportDescription()); // Generate the ICE credentials if we don't already have them. if (!current_description || options.ice_restart) { - desc->ice_ufrag = rtc::CreateRandomString(ICE_UFRAG_LENGTH); - desc->ice_pwd = rtc::CreateRandomString(ICE_PWD_LENGTH); + IceParameters credentials = ice_credentials->GetIceCredentials(); + desc->ice_ufrag = credentials.ufrag; + desc->ice_pwd = credentials.pwd; } else { desc->ice_ufrag = current_description->ice_ufrag; desc->ice_pwd = current_description->ice_pwd; @@ -59,7 +61,8 @@ TransportDescription* TransportDescriptionFactory::CreateAnswer( const TransportDescription* offer, const TransportOptions& options, bool require_transport_attributes, - const TransportDescription* current_description) const { + const TransportDescription* current_description, + IceCredentialsIterator* ice_credentials) const { // TODO(juberti): Figure out why we get NULL offers, and fix this upstream. if (!offer) { RTC_LOG(LS_WARNING) << "Failed to create TransportDescription answer " @@ -71,8 +74,9 @@ TransportDescription* TransportDescriptionFactory::CreateAnswer( // Generate the ICE credentials if we don't already have them or ice is // being restarted. if (!current_description || options.ice_restart) { - desc->ice_ufrag = rtc::CreateRandomString(ICE_UFRAG_LENGTH); - desc->ice_pwd = rtc::CreateRandomString(ICE_PWD_LENGTH); + IceParameters credentials = ice_credentials->GetIceCredentials(); + desc->ice_ufrag = credentials.ufrag; + desc->ice_pwd = credentials.pwd; } else { desc->ice_ufrag = current_description->ice_ufrag; desc->ice_pwd = current_description->ice_pwd; diff --git a/p2p/base/transportdescriptionfactory.h b/p2p/base/transportdescriptionfactory.h index 937c5fa1fe..dc1476a80f 100644 --- a/p2p/base/transportdescriptionfactory.h +++ b/p2p/base/transportdescriptionfactory.h @@ -11,6 +11,7 @@ #ifndef P2P_BASE_TRANSPORTDESCRIPTIONFACTORY_H_ #define P2P_BASE_TRANSPORTDESCRIPTIONFACTORY_H_ +#include "p2p/base/icecredentialsiterator.h" #include "p2p/base/transportdescription.h" #include "rtc_base/rtccertificate.h" @@ -54,7 +55,8 @@ class TransportDescriptionFactory { // Creates a transport description suitable for use in an offer. TransportDescription* CreateOffer( const TransportOptions& options, - const TransportDescription* current_description) const; + const TransportDescription* current_description, + IceCredentialsIterator* ice_credentials) const; // Create a transport description that is a response to an offer. // // If |require_transport_attributes| is true, then TRANSPORT category @@ -66,7 +68,8 @@ class TransportDescriptionFactory { const TransportDescription* offer, const TransportOptions& options, bool require_transport_attributes, - const TransportDescription* current_description) const; + const TransportDescription* current_description, + IceCredentialsIterator* ice_credentials) const; private: bool SetSecurityInfo(TransportDescription* description, diff --git a/p2p/base/transportdescriptionfactory_unittest.cc b/p2p/base/transportdescriptionfactory_unittest.cc index a7c34b56a0..a3cdb805ee 100644 --- a/p2p/base/transportdescriptionfactory_unittest.cc +++ b/p2p/base/transportdescriptionfactory_unittest.cc @@ -26,7 +26,8 @@ using cricket::TransportOptions; class TransportDescriptionFactoryTest : public testing::Test { public: TransportDescriptionFactoryTest() - : cert1_(rtc::RTCCertificate::Create(std::unique_ptr( + : ice_credentials_({}), + cert1_(rtc::RTCCertificate::Create(std::unique_ptr( new rtc::FakeSSLIdentity("User1")))), cert2_(rtc::RTCCertificate::Create(std::unique_ptr( new rtc::FakeSSLIdentity("User2")))) {} @@ -64,21 +65,22 @@ class TransportDescriptionFactoryTest : public testing::Test { SetDtls(dtls); cricket::TransportOptions options; // The initial offer / answer exchange. - std::unique_ptr offer(f1_.CreateOffer(options, NULL)); + std::unique_ptr offer( + f1_.CreateOffer(options, NULL, &ice_credentials_)); std::unique_ptr answer( - f2_.CreateAnswer(offer.get(), options, true, NULL)); + f2_.CreateAnswer(offer.get(), options, true, NULL, &ice_credentials_)); // Create an updated offer where we restart ice. options.ice_restart = true; std::unique_ptr restart_offer( - f1_.CreateOffer(options, offer.get())); + f1_.CreateOffer(options, offer.get(), &ice_credentials_)); VerifyUfragAndPasswordChanged(dtls, offer.get(), restart_offer.get()); // Create a new answer. The transport ufrag and password is changed since // |options.ice_restart == true| - std::unique_ptr restart_answer( - f2_.CreateAnswer(restart_offer.get(), options, true, answer.get())); + std::unique_ptr restart_answer(f2_.CreateAnswer( + restart_offer.get(), options, true, answer.get(), &ice_credentials_)); ASSERT_TRUE(restart_answer.get() != NULL); VerifyUfragAndPasswordChanged(dtls, answer.get(), restart_answer.get()); @@ -108,19 +110,20 @@ class TransportDescriptionFactoryTest : public testing::Test { cricket::TransportOptions options; // The initial offer / answer exchange. std::unique_ptr offer( - f1_.CreateOffer(options, nullptr)); - std::unique_ptr answer( - f2_.CreateAnswer(offer.get(), options, true, nullptr)); + f1_.CreateOffer(options, nullptr, &ice_credentials_)); + std::unique_ptr answer(f2_.CreateAnswer( + offer.get(), options, true, nullptr, &ice_credentials_)); VerifyRenomination(offer.get(), false); VerifyRenomination(answer.get(), false); options.enable_ice_renomination = true; std::unique_ptr renomination_offer( - f1_.CreateOffer(options, offer.get())); + f1_.CreateOffer(options, offer.get(), &ice_credentials_)); VerifyRenomination(renomination_offer.get(), true); - std::unique_ptr renomination_answer(f2_.CreateAnswer( - renomination_offer.get(), options, true, answer.get())); + std::unique_ptr renomination_answer( + f2_.CreateAnswer(renomination_offer.get(), options, true, answer.get(), + &ice_credentials_)); VerifyRenomination(renomination_answer.get(), true); } @@ -145,6 +148,7 @@ class TransportDescriptionFactoryTest : public testing::Test { } } + cricket::IceCredentialsIterator ice_credentials_; TransportDescriptionFactory f1_; TransportDescriptionFactory f2_; @@ -154,7 +158,7 @@ class TransportDescriptionFactoryTest : public testing::Test { TEST_F(TransportDescriptionFactoryTest, TestOfferDefault) { std::unique_ptr desc( - f1_.CreateOffer(TransportOptions(), NULL)); + f1_.CreateOffer(TransportOptions(), NULL, &ice_credentials_)); CheckDesc(desc.get(), "", "", "", ""); } @@ -165,11 +169,11 @@ TEST_F(TransportDescriptionFactoryTest, TestOfferDtls) { ASSERT_TRUE( cert1_->ssl_certificate().GetSignatureDigestAlgorithm(&digest_alg)); std::unique_ptr desc( - f1_.CreateOffer(TransportOptions(), NULL)); + f1_.CreateOffer(TransportOptions(), NULL, &ice_credentials_)); CheckDesc(desc.get(), "", "", "", digest_alg); // Ensure it also works with SEC_REQUIRED. f1_.set_secure(cricket::SEC_REQUIRED); - desc.reset(f1_.CreateOffer(TransportOptions(), NULL)); + desc.reset(f1_.CreateOffer(TransportOptions(), NULL, &ice_credentials_)); CheckDesc(desc.get(), "", "", "", digest_alg); } @@ -177,7 +181,7 @@ TEST_F(TransportDescriptionFactoryTest, TestOfferDtls) { TEST_F(TransportDescriptionFactoryTest, TestOfferDtlsWithNoIdentity) { f1_.set_secure(cricket::SEC_ENABLED); std::unique_ptr desc( - f1_.CreateOffer(TransportOptions(), NULL)); + f1_.CreateOffer(TransportOptions(), NULL, &ice_credentials_)); ASSERT_TRUE(desc.get() == NULL); } @@ -190,34 +194,36 @@ TEST_F(TransportDescriptionFactoryTest, TestOfferDtlsReofferDtls) { ASSERT_TRUE( cert1_->ssl_certificate().GetSignatureDigestAlgorithm(&digest_alg)); std::unique_ptr old_desc( - f1_.CreateOffer(TransportOptions(), NULL)); + f1_.CreateOffer(TransportOptions(), NULL, &ice_credentials_)); ASSERT_TRUE(old_desc.get() != NULL); std::unique_ptr desc( - f1_.CreateOffer(TransportOptions(), old_desc.get())); + f1_.CreateOffer(TransportOptions(), old_desc.get(), &ice_credentials_)); CheckDesc(desc.get(), "", old_desc->ice_ufrag, old_desc->ice_pwd, digest_alg); } TEST_F(TransportDescriptionFactoryTest, TestAnswerDefault) { std::unique_ptr offer( - f1_.CreateOffer(TransportOptions(), NULL)); + f1_.CreateOffer(TransportOptions(), NULL, &ice_credentials_)); ASSERT_TRUE(offer.get() != NULL); - std::unique_ptr desc( - f2_.CreateAnswer(offer.get(), TransportOptions(), true, NULL)); + std::unique_ptr desc(f2_.CreateAnswer( + offer.get(), TransportOptions(), true, NULL, &ice_credentials_)); CheckDesc(desc.get(), "", "", "", ""); - desc.reset(f2_.CreateAnswer(offer.get(), TransportOptions(), true, NULL)); + desc.reset(f2_.CreateAnswer(offer.get(), TransportOptions(), true, NULL, + &ice_credentials_)); CheckDesc(desc.get(), "", "", "", ""); } // Test that we can update an answer properly; ICE credentials shouldn't change. TEST_F(TransportDescriptionFactoryTest, TestReanswer) { std::unique_ptr offer( - f1_.CreateOffer(TransportOptions(), NULL)); + f1_.CreateOffer(TransportOptions(), NULL, &ice_credentials_)); ASSERT_TRUE(offer.get() != NULL); - std::unique_ptr old_desc( - f2_.CreateAnswer(offer.get(), TransportOptions(), true, NULL)); + std::unique_ptr old_desc(f2_.CreateAnswer( + offer.get(), TransportOptions(), true, NULL, &ice_credentials_)); ASSERT_TRUE(old_desc.get() != NULL); std::unique_ptr desc( - f2_.CreateAnswer(offer.get(), TransportOptions(), true, old_desc.get())); + f2_.CreateAnswer(offer.get(), TransportOptions(), true, old_desc.get(), + &ice_credentials_)); ASSERT_TRUE(desc.get() != NULL); CheckDesc(desc.get(), "", old_desc->ice_ufrag, old_desc->ice_pwd, ""); } @@ -227,10 +233,10 @@ TEST_F(TransportDescriptionFactoryTest, TestAnswerDtlsToNoDtls) { f1_.set_secure(cricket::SEC_ENABLED); f1_.set_certificate(cert1_); std::unique_ptr offer( - f1_.CreateOffer(TransportOptions(), NULL)); + f1_.CreateOffer(TransportOptions(), NULL, &ice_credentials_)); ASSERT_TRUE(offer.get() != NULL); - std::unique_ptr desc( - f2_.CreateAnswer(offer.get(), TransportOptions(), true, NULL)); + std::unique_ptr desc(f2_.CreateAnswer( + offer.get(), TransportOptions(), true, NULL, &ice_credentials_)); CheckDesc(desc.get(), "", "", "", ""); } @@ -240,13 +246,14 @@ TEST_F(TransportDescriptionFactoryTest, TestAnswerNoDtlsToDtls) { f2_.set_secure(cricket::SEC_ENABLED); f2_.set_certificate(cert2_); std::unique_ptr offer( - f1_.CreateOffer(TransportOptions(), NULL)); + f1_.CreateOffer(TransportOptions(), NULL, &ice_credentials_)); ASSERT_TRUE(offer.get() != NULL); - std::unique_ptr desc( - f2_.CreateAnswer(offer.get(), TransportOptions(), true, NULL)); + std::unique_ptr desc(f2_.CreateAnswer( + offer.get(), TransportOptions(), true, NULL, &ice_credentials_)); CheckDesc(desc.get(), "", "", "", ""); f2_.set_secure(cricket::SEC_REQUIRED); - desc.reset(f2_.CreateAnswer(offer.get(), TransportOptions(), true, NULL)); + desc.reset(f2_.CreateAnswer(offer.get(), TransportOptions(), true, NULL, + &ice_credentials_)); ASSERT_TRUE(desc.get() == NULL); } @@ -265,13 +272,14 @@ TEST_F(TransportDescriptionFactoryTest, TestAnswerDtlsToDtls) { cert2_->ssl_certificate().GetSignatureDigestAlgorithm(&digest_alg2)); std::unique_ptr offer( - f1_.CreateOffer(TransportOptions(), NULL)); + f1_.CreateOffer(TransportOptions(), NULL, &ice_credentials_)); ASSERT_TRUE(offer.get() != NULL); - std::unique_ptr desc( - f2_.CreateAnswer(offer.get(), TransportOptions(), true, NULL)); + std::unique_ptr desc(f2_.CreateAnswer( + offer.get(), TransportOptions(), true, NULL, &ice_credentials_)); CheckDesc(desc.get(), "", "", "", digest_alg2); f2_.set_secure(cricket::SEC_REQUIRED); - desc.reset(f2_.CreateAnswer(offer.get(), TransportOptions(), true, NULL)); + desc.reset(f2_.CreateAnswer(offer.get(), TransportOptions(), true, NULL, + &ice_credentials_)); CheckDesc(desc.get(), "", "", "", digest_alg2); } @@ -304,9 +312,36 @@ TEST_F(TransportDescriptionFactoryTest, TestIceRenominationWithDtls) { TEST_F(TransportDescriptionFactoryTest, AddsTrickleIceOption) { cricket::TransportOptions options; std::unique_ptr offer( - f1_.CreateOffer(options, nullptr)); + f1_.CreateOffer(options, nullptr, &ice_credentials_)); EXPECT_TRUE(offer->HasOption("trickle")); std::unique_ptr answer( - f2_.CreateAnswer(offer.get(), options, true, nullptr)); + f2_.CreateAnswer(offer.get(), options, true, nullptr, &ice_credentials_)); EXPECT_TRUE(answer->HasOption("trickle")); } + +// Test CreateOffer with IceCredentialsIterator. +TEST_F(TransportDescriptionFactoryTest, CreateOfferIceCredentialsIterator) { + std::vector credentials = { + cricket::IceParameters("kalle", "anka", false)}; + cricket::IceCredentialsIterator credentialsIterator(credentials); + cricket::TransportOptions options; + std::unique_ptr offer( + f1_.CreateOffer(options, nullptr, &credentialsIterator)); + EXPECT_EQ(offer->GetIceParameters().ufrag, credentials[0].ufrag); + EXPECT_EQ(offer->GetIceParameters().pwd, credentials[0].pwd); +} + +// Test CreateAnswer with IceCredentialsIterator. +TEST_F(TransportDescriptionFactoryTest, CreateAnswerIceCredentialsIterator) { + cricket::TransportOptions options; + std::unique_ptr offer( + f1_.CreateOffer(options, nullptr, &ice_credentials_)); + + std::vector credentials = { + cricket::IceParameters("kalle", "anka", false)}; + cricket::IceCredentialsIterator credentialsIterator(credentials); + std::unique_ptr answer(f1_.CreateAnswer( + offer.get(), options, false, nullptr, &credentialsIterator)); + EXPECT_EQ(answer->GetIceParameters().ufrag, credentials[0].ufrag); + EXPECT_EQ(answer->GetIceParameters().pwd, credentials[0].pwd); +} diff --git a/pc/mediasession.cc b/pc/mediasession.cc index b75dfd6855..889576c306 100644 --- a/pc/mediasession.cc +++ b/pc/mediasession.cc @@ -1268,6 +1268,8 @@ SessionDescription* MediaSessionDescriptionFactory::CreateOffer( const SessionDescription* current_description) const { std::unique_ptr offer(new SessionDescription()); + IceCredentialsIterator ice_credentials( + session_options.pooled_ice_credentials); StreamParamsVec current_streams; GetCurrentStreamParams(current_description, ¤t_streams); @@ -1311,18 +1313,18 @@ SessionDescription* MediaSessionDescriptionFactory::CreateOffer( } switch (media_description_options.type) { case MEDIA_TYPE_AUDIO: - if (!AddAudioContentForOffer(media_description_options, session_options, - current_content, current_description, - audio_rtp_extensions, offer_audio_codecs, - ¤t_streams, offer.get())) { + if (!AddAudioContentForOffer( + media_description_options, session_options, current_content, + current_description, audio_rtp_extensions, offer_audio_codecs, + ¤t_streams, offer.get(), &ice_credentials)) { return nullptr; } break; case MEDIA_TYPE_VIDEO: - if (!AddVideoContentForOffer(media_description_options, session_options, - current_content, current_description, - video_rtp_extensions, offer_video_codecs, - ¤t_streams, offer.get())) { + if (!AddVideoContentForOffer( + media_description_options, session_options, current_content, + current_description, video_rtp_extensions, offer_video_codecs, + ¤t_streams, offer.get(), &ice_credentials)) { return nullptr; } break; @@ -1330,7 +1332,7 @@ SessionDescription* MediaSessionDescriptionFactory::CreateOffer( if (!AddDataContentForOffer(media_description_options, session_options, current_content, current_description, offer_data_codecs, ¤t_streams, - offer.get())) { + offer.get(), &ice_credentials)) { return nullptr; } break; @@ -1387,6 +1389,10 @@ SessionDescription* MediaSessionDescriptionFactory::CreateAnswer( if (!offer) { return nullptr; } + + IceCredentialsIterator ice_credentials( + session_options.pooled_ice_credentials); + // The answer contains the intersection of the codecs in the offer with the // codecs we support. As indicated by XEP-0167, we retain the same payload ids // from the offer in the answer. @@ -1449,7 +1455,7 @@ SessionDescription* MediaSessionDescriptionFactory::CreateAnswer( media_description_options, session_options, offer_content, offer, current_content, current_description, bundle_transport.get(), answer_audio_codecs, ¤t_streams, - answer.get())) { + answer.get(), &ice_credentials)) { return nullptr; } break; @@ -1458,16 +1464,16 @@ SessionDescription* MediaSessionDescriptionFactory::CreateAnswer( media_description_options, session_options, offer_content, offer, current_content, current_description, bundle_transport.get(), answer_video_codecs, ¤t_streams, - answer.get())) { + answer.get(), &ice_credentials)) { return nullptr; } break; case MEDIA_TYPE_DATA: - if (!AddDataContentForAnswer(media_description_options, session_options, - offer_content, offer, current_content, - current_description, - bundle_transport.get(), answer_data_codecs, - ¤t_streams, answer.get())) { + if (!AddDataContentForAnswer( + media_description_options, session_options, offer_content, + offer, current_content, current_description, + bundle_transport.get(), answer_data_codecs, ¤t_streams, + answer.get(), &ice_credentials)) { return nullptr; } break; @@ -1774,13 +1780,15 @@ bool MediaSessionDescriptionFactory::AddTransportOffer( const std::string& content_name, const TransportOptions& transport_options, const SessionDescription* current_desc, - SessionDescription* offer_desc) const { + SessionDescription* offer_desc, + IceCredentialsIterator* ice_credentials) const { if (!transport_desc_factory_) return false; const TransportDescription* current_tdesc = GetTransportDescription(content_name, current_desc); std::unique_ptr new_tdesc( - transport_desc_factory_->CreateOffer(transport_options, current_tdesc)); + transport_desc_factory_->CreateOffer(transport_options, current_tdesc, + ice_credentials)); bool ret = (new_tdesc.get() != NULL && offer_desc->AddTransportInfo(TransportInfo(content_name, *new_tdesc))); @@ -1796,7 +1804,8 @@ TransportDescription* MediaSessionDescriptionFactory::CreateTransportAnswer( const SessionDescription* offer_desc, const TransportOptions& transport_options, const SessionDescription* current_desc, - bool require_transport_attributes) const { + bool require_transport_attributes, + IceCredentialsIterator* ice_credentials) const { if (!transport_desc_factory_) return NULL; const TransportDescription* offer_tdesc = @@ -1805,7 +1814,7 @@ TransportDescription* MediaSessionDescriptionFactory::CreateTransportAnswer( GetTransportDescription(content_name, current_desc); return transport_desc_factory_->CreateAnswer(offer_tdesc, transport_options, require_transport_attributes, - current_tdesc); + current_tdesc, ice_credentials); } bool MediaSessionDescriptionFactory::AddTransportAnswer( @@ -1841,7 +1850,8 @@ bool MediaSessionDescriptionFactory::AddAudioContentForOffer( const RtpHeaderExtensions& audio_rtp_extensions, const AudioCodecs& audio_codecs, StreamParamsVec* current_streams, - SessionDescription* desc) const { + SessionDescription* desc, + IceCredentialsIterator* ice_credentials) const { // Filter audio_codecs (which includes all codecs, with correctly remapped // payload types) based on transceiver direction. const AudioCodecs& supported_audio_codecs = @@ -1897,7 +1907,7 @@ bool MediaSessionDescriptionFactory::AddAudioContentForOffer( media_description_options.stopped, audio.release()); if (!AddTransportOffer(media_description_options.mid, media_description_options.transport_options, - current_description, desc)) { + current_description, desc, ice_credentials)) { return false; } @@ -1912,7 +1922,8 @@ bool MediaSessionDescriptionFactory::AddVideoContentForOffer( const RtpHeaderExtensions& video_rtp_extensions, const VideoCodecs& video_codecs, StreamParamsVec* current_streams, - SessionDescription* desc) const { + SessionDescription* desc, + IceCredentialsIterator* ice_credentials) const { cricket::SecurePolicy sdes_policy = IsDtlsActive(current_content, current_description) ? cricket::SEC_DISABLED : secure(); @@ -1966,7 +1977,7 @@ bool MediaSessionDescriptionFactory::AddVideoContentForOffer( media_description_options.stopped, video.release()); if (!AddTransportOffer(media_description_options.mid, media_description_options.transport_options, - current_description, desc)) { + current_description, desc, ice_credentials)) { return false; } return true; @@ -1979,7 +1990,8 @@ bool MediaSessionDescriptionFactory::AddDataContentForOffer( const SessionDescription* current_description, const DataCodecs& data_codecs, StreamParamsVec* current_streams, - SessionDescription* desc) const { + SessionDescription* desc, + IceCredentialsIterator* ice_credentials) const { bool secure_transport = (transport_desc_factory_->secure() != SEC_DISABLED); std::unique_ptr data(new DataContentDescription()); @@ -2033,7 +2045,7 @@ bool MediaSessionDescriptionFactory::AddDataContentForOffer( } if (!AddTransportOffer(media_description_options.mid, media_description_options.transport_options, - current_description, desc)) { + current_description, desc, ice_credentials)) { return false; } return true; @@ -2061,15 +2073,16 @@ bool MediaSessionDescriptionFactory::AddAudioContentForAnswer( const TransportInfo* bundle_transport, const AudioCodecs& audio_codecs, StreamParamsVec* current_streams, - SessionDescription* answer) const { + SessionDescription* answer, + IceCredentialsIterator* ice_credentials) const { RTC_CHECK(IsMediaContentOfType(offer_content, MEDIA_TYPE_AUDIO)); const AudioContentDescription* offer_audio_description = offer_content->media_description()->as_audio(); - std::unique_ptr audio_transport( - CreateTransportAnswer(media_description_options.mid, offer_description, - media_description_options.transport_options, - current_description, bundle_transport != nullptr)); + std::unique_ptr audio_transport(CreateTransportAnswer( + media_description_options.mid, offer_description, + media_description_options.transport_options, current_description, + bundle_transport != nullptr, ice_credentials)); if (!audio_transport) { return false; } @@ -2155,15 +2168,16 @@ bool MediaSessionDescriptionFactory::AddVideoContentForAnswer( const TransportInfo* bundle_transport, const VideoCodecs& video_codecs, StreamParamsVec* current_streams, - SessionDescription* answer) const { + SessionDescription* answer, + IceCredentialsIterator* ice_credentials) const { RTC_CHECK(IsMediaContentOfType(offer_content, MEDIA_TYPE_VIDEO)); const VideoContentDescription* offer_video_description = offer_content->media_description()->as_video(); - std::unique_ptr video_transport( - CreateTransportAnswer(media_description_options.mid, offer_description, - media_description_options.transport_options, - current_description, bundle_transport != nullptr)); + std::unique_ptr video_transport(CreateTransportAnswer( + media_description_options.mid, offer_description, + media_description_options.transport_options, current_description, + bundle_transport != nullptr, ice_credentials)); if (!video_transport) { return false; } @@ -2241,11 +2255,12 @@ bool MediaSessionDescriptionFactory::AddDataContentForAnswer( const TransportInfo* bundle_transport, const DataCodecs& data_codecs, StreamParamsVec* current_streams, - SessionDescription* answer) const { - std::unique_ptr data_transport( - CreateTransportAnswer(media_description_options.mid, offer_description, - media_description_options.transport_options, - current_description, bundle_transport != nullptr)); + SessionDescription* answer, + IceCredentialsIterator* ice_credentials) const { + std::unique_ptr data_transport(CreateTransportAnswer( + media_description_options.mid, offer_description, + media_description_options.transport_options, current_description, + bundle_transport != nullptr, ice_credentials)); if (!data_transport) { return false; } diff --git a/pc/mediasession.h b/pc/mediasession.h index 54ca7c870e..b1df1dbd26 100644 --- a/pc/mediasession.h +++ b/pc/mediasession.h @@ -21,6 +21,7 @@ #include "api/mediatypes.h" #include "media/base/mediaconstants.h" #include "media/base/mediaengine.h" // For DataChannelType +#include "p2p/base/icecredentialsiterator.h" #include "p2p/base/transportdescriptionfactory.h" #include "pc/jseptransport.h" #include "pc/sessiondescription.h" @@ -98,6 +99,7 @@ struct MediaSessionOptions { // List of media description options in the same order that the media // descriptions will be generated. std::vector media_description_options; + std::vector pooled_ice_credentials; }; // Creates media session descriptions according to the supplied codecs and @@ -186,14 +188,16 @@ class MediaSessionDescriptionFactory { bool AddTransportOffer(const std::string& content_name, const TransportOptions& transport_options, const SessionDescription* current_desc, - SessionDescription* offer) const; + SessionDescription* offer, + IceCredentialsIterator* ice_credentials) const; TransportDescription* CreateTransportAnswer( const std::string& content_name, const SessionDescription* offer_desc, const TransportOptions& transport_options, const SessionDescription* current_desc, - bool require_transport_attributes) const; + bool require_transport_attributes, + IceCredentialsIterator* ice_credentials) const; bool AddTransportAnswer(const std::string& content_name, const TransportDescription& transport_desc, @@ -211,7 +215,8 @@ class MediaSessionDescriptionFactory { const RtpHeaderExtensions& audio_rtp_extensions, const AudioCodecs& audio_codecs, StreamParamsVec* current_streams, - SessionDescription* desc) const; + SessionDescription* desc, + IceCredentialsIterator* ice_credentials) const; bool AddVideoContentForOffer( const MediaDescriptionOptions& media_description_options, @@ -221,7 +226,8 @@ class MediaSessionDescriptionFactory { const RtpHeaderExtensions& video_rtp_extensions, const VideoCodecs& video_codecs, StreamParamsVec* current_streams, - SessionDescription* desc) const; + SessionDescription* desc, + IceCredentialsIterator* ice_credentials) const; bool AddDataContentForOffer( const MediaDescriptionOptions& media_description_options, @@ -230,7 +236,8 @@ class MediaSessionDescriptionFactory { const SessionDescription* current_description, const DataCodecs& data_codecs, StreamParamsVec* current_streams, - SessionDescription* desc) const; + SessionDescription* desc, + IceCredentialsIterator* ice_credentials) const; bool AddAudioContentForAnswer( const MediaDescriptionOptions& media_description_options, @@ -242,7 +249,8 @@ class MediaSessionDescriptionFactory { const TransportInfo* bundle_transport, const AudioCodecs& audio_codecs, StreamParamsVec* current_streams, - SessionDescription* answer) const; + SessionDescription* answer, + IceCredentialsIterator* ice_credentials) const; bool AddVideoContentForAnswer( const MediaDescriptionOptions& media_description_options, @@ -254,7 +262,8 @@ class MediaSessionDescriptionFactory { const TransportInfo* bundle_transport, const VideoCodecs& video_codecs, StreamParamsVec* current_streams, - SessionDescription* answer) const; + SessionDescription* answer, + IceCredentialsIterator* ice_credentials) const; bool AddDataContentForAnswer( const MediaDescriptionOptions& media_description_options, @@ -266,7 +275,8 @@ class MediaSessionDescriptionFactory { const TransportInfo* bundle_transport, const DataCodecs& data_codecs, StreamParamsVec* current_streams, - SessionDescription* answer) const; + SessionDescription* answer, + IceCredentialsIterator* ice_credentials) const; void ComputeAudioCodecsIntersectionAndUnion(); diff --git a/pc/peerconnection.cc b/pc/peerconnection.cc index 659d16b073..0861246151 100644 --- a/pc/peerconnection.cc +++ b/pc/peerconnection.cc @@ -3646,6 +3646,11 @@ void PeerConnection::GetOptionsForOffer( session_options->rtcp_cname = rtcp_cname_; session_options->crypto_options = factory_->options().crypto_options; session_options->is_unified_plan = IsUnifiedPlan(); + session_options->pooled_ice_credentials = + network_thread()->Invoke>( + RTC_FROM_HERE, + rtc::Bind(&cricket::PortAllocator::GetPooledIceCredentials, + port_allocator_.get())); } void PeerConnection::GetOptionsForPlanBOffer( @@ -3906,6 +3911,11 @@ void PeerConnection::GetOptionsForAnswer( session_options->rtcp_cname = rtcp_cname_; session_options->crypto_options = factory_->options().crypto_options; session_options->is_unified_plan = IsUnifiedPlan(); + session_options->pooled_ice_credentials = + network_thread()->Invoke>( + RTC_FROM_HERE, + rtc::Bind(&cricket::PortAllocator::GetPooledIceCredentials, + port_allocator_.get())); } void PeerConnection::GetOptionsForPlanBAnswer( diff --git a/pc/peerconnection_ice_unittest.cc b/pc/peerconnection_ice_unittest.cc index e6d6ac1550..4e25614d3c 100644 --- a/pc/peerconnection_ice_unittest.cc +++ b/pc/peerconnection_ice_unittest.cc @@ -78,6 +78,9 @@ class PeerConnectionWrapperForIceTest : public PeerConnectionWrapper { void set_network(rtc::FakeNetworkManager* network) { network_ = network; } + // The port allocator used by this PC. + cricket::PortAllocator* port_allocator_; + private: rtc::FakeNetworkManager* network_; }; @@ -115,6 +118,7 @@ class PeerConnectionIceBaseTest : public ::testing::Test { RTCConfiguration modified_config = config; modified_config.sdp_semantics = sdp_semantics_; auto observer = absl::make_unique(); + auto port_allocator_copy = port_allocator.get(); auto pc = pc_factory_->CreatePeerConnection( modified_config, std::move(port_allocator), nullptr, observer.get()); if (!pc) { @@ -124,6 +128,7 @@ class PeerConnectionIceBaseTest : public ::testing::Test { auto wrapper = absl::make_unique( pc_factory_, pc, std::move(observer)); wrapper->set_network(fake_network); + wrapper->port_allocator_ = port_allocator_copy; return wrapper; } @@ -1008,4 +1013,41 @@ TEST_F(PeerConnectionIceConfigTest, SetStunCandidateKeepaliveInterval) { EXPECT_EQ(actual_stun_keepalive_interval.value_or(-1), 321); } +TEST_P(PeerConnectionIceTest, IceCredentialsCreateOffer) { + RTCConfiguration config; + config.ice_candidate_pool_size = 1; + auto pc = CreatePeerConnectionWithAudioVideo(config); + ASSERT_NE(pc->port_allocator_, nullptr); + auto offer = pc->CreateOffer(); + auto credentials = pc->port_allocator_->GetPooledIceCredentials(); + ASSERT_EQ(1u, credentials.size()); + + auto* desc = offer->description(); + for (const auto& content : desc->contents()) { + auto* transport_info = desc->GetTransportInfoByName(content.name); + EXPECT_EQ(transport_info->description.ice_ufrag, credentials[0].ufrag); + EXPECT_EQ(transport_info->description.ice_pwd, credentials[0].pwd); + } +} + +TEST_P(PeerConnectionIceTest, IceCredentialsCreateAnswer) { + RTCConfiguration config; + config.ice_candidate_pool_size = 1; + auto pc = CreatePeerConnectionWithAudioVideo(config); + ASSERT_NE(pc->port_allocator_, nullptr); + auto offer = pc->CreateOffer(); + ASSERT_TRUE(pc->SetRemoteDescription(std::move(offer))); + auto answer = pc->CreateAnswer(); + + auto credentials = pc->port_allocator_->GetPooledIceCredentials(); + ASSERT_EQ(1u, credentials.size()); + + auto* desc = answer->description(); + for (const auto& content : desc->contents()) { + auto* transport_info = desc->GetTransportInfoByName(content.name); + EXPECT_EQ(transport_info->description.ice_ufrag, credentials[0].ufrag); + EXPECT_EQ(transport_info->description.ice_pwd, credentials[0].pwd); + } +} + } // namespace webrtc