Fix dcsctp handling of dtls restart

dtls_transport will when detecting a new fingerprint
(e.g by usage of pranswer) signal DtlsTransportState::kNew.
When this happen, the dtls crypto state is lost, and
sctp should reconnect, srtp does this automatically
in current code base.

The existing behavior in dcsctp is that it will detect
peer sending an init, and reconnect. But any messages sent
between the dtls restart and the message arriving from the
peer will be lost.

This patch changes so that this case is gracefully handled by
a) letting dcsctp_transport listen to dtls state
this is big part of patch and involves changing the type of
the underlying dtransport from rtc::PacketTransportInternal to cricket::DtlsTransportInternal. If requested, I can put this
into a separate patch...

b) if a dtls restart happens, delete and restart socket.

Testcase that fails before patch and works after is attached.
Bonus: And include-what-you-use on patch

Bug: b/375327137
Change-Id: Ib78488ae75fd8aeb50d121adf464a33dabbf95e2
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/367202
Commit-Queue: Jonas Oreland <jonaso@webrtc.org>
Reviewed-by: Victor Boivie <boivie@webrtc.org>
Commit-Queue: Harald Alvestrand <hta@webrtc.org>
Auto-Submit: Jonas Oreland <jonaso@webrtc.org>
Reviewed-by: Harald Alvestrand <hta@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#43546}
This commit is contained in:
Jonas Oreland 2024-12-12 09:12:56 +01:00 committed by WebRTC LUCI CQ
parent 15fc228ea7
commit 575d323671
13 changed files with 240 additions and 40 deletions

View File

@ -18,12 +18,9 @@
// These classes are not part of the API, and are treated as opaque pointers.
namespace cricket {
class SctpTransportInternal;
class DtlsTransportInternal;
} // namespace cricket
namespace rtc {
class PacketTransportInternal;
} // namespace rtc
namespace webrtc {
// Factory class which can be used to allow fake SctpTransports to be injected
@ -37,7 +34,7 @@ class SctpTransportFactoryInterface {
// Create an SCTP transport using `channel` for the underlying transport.
virtual std::unique_ptr<cricket::SctpTransportInternal> CreateSctpTransport(
const Environment& env,
rtc::PacketTransportInternal* channel) = 0;
cricket::DtlsTransportInternal* channel) = 0;
};
} // namespace webrtc

View File

@ -698,6 +698,7 @@ rtc_source_set("rtc_data_sctp_transport_internal") {
"../api:priority",
"../api:rtc_error",
"../api/transport:datagram_transport_interface",
"../p2p:dtls_transport_internal",
"../p2p:packet_transport_internal",
"../rtc_base:copy_on_write_buffer",
"../rtc_base:threading",
@ -714,16 +715,21 @@ if (rtc_build_dcsctp) {
":media_channel",
":rtc_data_sctp_transport_internal",
"../api:array_view",
"../api:dtls_transport_interface",
"../api:libjingle_peerconnection_api",
"../api:priority",
"../api:rtc_error",
"../api:sequence_checker",
"../api/environment",
"../api/task_queue:pending_task_safety_flag",
"../api/task_queue:task_queue",
"../api/transport:datagram_transport_interface",
"../net/dcsctp/public:factory",
"../net/dcsctp/public:socket",
"../net/dcsctp/public:types",
"../net/dcsctp/public:utils",
"../net/dcsctp/timer:task_queue_timeout",
"../p2p:dtls_transport_internal",
"../p2p:packet_transport_internal",
"../rtc_base:checks",
"../rtc_base:copy_on_write_buffer",
@ -753,6 +759,7 @@ rtc_library("rtc_data_sctp_transport_factory") {
":rtc_data_sctp_transport_internal",
"../api/environment",
"../api/transport:sctp_transport_factory_interface",
"../p2p:dtls_transport_internal",
"../rtc_base:threading",
"../rtc_base/system:unused",
]
@ -953,6 +960,7 @@ if (rtc_include_tests) {
"../api/task_queue",
"../api/test/video:function_video_factory",
"../api/transport:bitrate_settings",
"../api/transport:datagram_transport_interface",
"../api/transport:field_trial_based_config",
"../api/transport/rtp:rtp_source",
"../api/units:data_rate",

View File

@ -11,24 +11,39 @@
#include "media/sctp/dcsctp_transport.h"
#include <atomic>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <limits>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>
#include "absl/strings/string_view.h"
#include "api/array_view.h"
#include "api/data_channel_interface.h"
#include "api/dtls_transport_interface.h"
#include "api/environment/environment.h"
#include "api/priority.h"
#include "media/base/media_channel.h"
#include "api/rtc_error.h"
#include "api/sequence_checker.h"
#include "api/task_queue/task_queue_base.h"
#include "api/transport/data_channel_transport_interface.h"
#include "media/sctp/sctp_transport_internal.h"
#include "net/dcsctp/public/dcsctp_message.h"
#include "net/dcsctp/public/dcsctp_options.h"
#include "net/dcsctp/public/dcsctp_socket.h"
#include "net/dcsctp/public/dcsctp_socket_factory.h"
#include "net/dcsctp/public/packet_observer.h"
#include "net/dcsctp/public/text_pcap_packet_observer.h"
#include "net/dcsctp/public/timeout.h"
#include "net/dcsctp/public/types.h"
#include "p2p/base/packet_transport_internal.h"
#include "p2p/dtls/dtls_transport_internal.h"
#include "rtc_base/checks.h"
#include "rtc_base/copy_on_write_buffer.h"
#include "rtc_base/logging.h"
#include "rtc_base/network/received_packet.h"
#include "rtc_base/socket.h"
@ -119,15 +134,16 @@ bool IsEmptyPPID(dcsctp::PPID ppid) {
DcSctpTransport::DcSctpTransport(const Environment& env,
rtc::Thread* network_thread,
rtc::PacketTransportInternal* transport)
cricket::DtlsTransportInternal* transport)
: DcSctpTransport(env,
network_thread,
transport,
std::make_unique<dcsctp::DcSctpSocketFactory>()) {}
DcSctpTransport::DcSctpTransport(
const Environment& env,
rtc::Thread* network_thread,
rtc::PacketTransportInternal* transport,
cricket::DtlsTransportInternal* transport,
std::unique_ptr<dcsctp::DcSctpSocketFactory> socket_factory)
: network_thread_(network_thread),
transport_(transport),
@ -168,7 +184,7 @@ void DcSctpTransport::SetDataChannelSink(DataChannelSink* sink) {
}
void DcSctpTransport::SetDtlsTransport(
rtc::PacketTransportInternal* transport) {
cricket::DtlsTransportInternal* transport) {
RTC_DCHECK_RUN_ON(network_thread_);
DisconnectTransportSignals();
transport_ = transport;
@ -662,6 +678,11 @@ void DcSctpTransport::ConnectTransportSignals() {
data_channel_sink_->OnTransportClosed({});
}
});
transport_->SubscribeDtlsTransportState(
this, [this](cricket::DtlsTransportInternal* transport,
DtlsTransportState state) {
OnDtlsTransportState(transport, state);
});
}
void DcSctpTransport::DisconnectTransportSignals() {
@ -672,6 +693,7 @@ void DcSctpTransport::DisconnectTransportSignals() {
transport_->SignalWritableState.disconnect(this);
transport_->DeregisterReceivedPacketCallback(this);
transport_->SetOnCloseCallback(nullptr);
transport_->UnsubscribeDtlsTransportState(this);
}
void DcSctpTransport::OnTransportWritableState(
@ -680,10 +702,30 @@ void DcSctpTransport::OnTransportWritableState(
RTC_DCHECK_EQ(transport_, transport);
RTC_DLOG(LS_VERBOSE) << debug_name_
<< "->OnTransportWritableState(), writable="
<< transport->writable();
<< transport->writable() << " socket: "
<< (socket_ ? std::to_string(
static_cast<int>(socket_->state()))
: "UNSET");
MaybeConnectSocket();
}
void DcSctpTransport::OnDtlsTransportState(
cricket::DtlsTransportInternal* transport,
webrtc::DtlsTransportState state) {
if (state == DtlsTransportState::kNew && socket_) {
// IF DTLS restart (DtlsTransportState::kNew)
// THEN
// restart socket so that we send an SCPT init
// before any outgoing messages. This is needed
// after DTLS fingerprint changed since peer will discard
// messages with crypto derived from old fingerprint.
RTC_DLOG(LS_INFO) << debug_name_ << " DTLS restart";
dcsctp::DcSctpOptions options = socket_->options();
socket_.reset();
Start(options.local_port, options.remote_port, options.max_message_size);
}
}
void DcSctpTransport::OnTransportReadPacket(
rtc::PacketTransportInternal* /* transport */,
const rtc::ReceivedPacket& packet) {

View File

@ -27,6 +27,7 @@
#include "net/dcsctp/public/types.h"
#include "net/dcsctp/timer/task_queue_timeout.h"
#include "p2p/base/packet_transport_internal.h"
#include "p2p/dtls/dtls_transport_internal.h"
#include "rtc_base/containers/flat_map.h"
#include "rtc_base/copy_on_write_buffer.h"
#include "rtc_base/network/received_packet.h"
@ -44,17 +45,17 @@ class DcSctpTransport : public cricket::SctpTransportInternal,
public:
DcSctpTransport(const Environment& env,
rtc::Thread* network_thread,
rtc::PacketTransportInternal* transport);
cricket::DtlsTransportInternal* transport);
DcSctpTransport(const Environment& env,
rtc::Thread* network_thread,
rtc::PacketTransportInternal* transport,
cricket::DtlsTransportInternal* transport,
std::unique_ptr<dcsctp::DcSctpSocketFactory> socket_factory);
~DcSctpTransport() override;
// cricket::SctpTransportInternal
void SetOnConnectedCallback(std::function<void()> callback) override;
void SetDataChannelSink(DataChannelSink* sink) override;
void SetDtlsTransport(rtc::PacketTransportInternal* transport) override;
void SetDtlsTransport(cricket::DtlsTransportInternal* transport) override;
bool Start(int local_sctp_port,
int remote_sctp_port,
int max_message_size) override;
@ -102,10 +103,12 @@ class DcSctpTransport : public cricket::SctpTransportInternal,
void OnTransportWritableState(rtc::PacketTransportInternal* transport);
void OnTransportReadPacket(rtc::PacketTransportInternal* transport,
const rtc::ReceivedPacket& packet);
void OnDtlsTransportState(cricket::DtlsTransportInternal* transport,
webrtc::DtlsTransportState);
void MaybeConnectSocket();
rtc::Thread* network_thread_;
rtc::PacketTransportInternal* transport_;
cricket::DtlsTransportInternal* transport_;
Environment env_;
Random random_;

View File

@ -11,20 +11,27 @@
#include "media/sctp/dcsctp_transport.h"
#include <memory>
#include <type_traits>
#include <utility>
#include "api/environment/environment.h"
#include "api/environment/environment_factory.h"
#include "api/priority.h"
#include "api/rtc_error.h"
#include "api/transport/data_channel_transport_interface.h"
#include "net/dcsctp/public/dcsctp_options.h"
#include "net/dcsctp/public/dcsctp_socket.h"
#include "net/dcsctp/public/mock_dcsctp_socket.h"
#include "net/dcsctp/public/mock_dcsctp_socket_factory.h"
#include "net/dcsctp/public/types.h"
#include "p2p/base/fake_packet_transport.h"
#include "p2p/dtls/fake_dtls_transport.h"
#include "rtc_base/copy_on_write_buffer.h"
#include "rtc_base/thread.h"
#include "test/gmock.h"
#include "test/gtest.h"
using ::testing::_;
using ::testing::ByMove;
using ::testing::DoAll;
using ::testing::ElementsAre;
using ::testing::InSequence;
using ::testing::Invoke;
@ -36,6 +43,9 @@ namespace webrtc {
namespace {
constexpr char kTransportName[] = "transport";
constexpr int kComponent = 77;
const PriorityValue kDefaultPriority = PriorityValue(Priority::kLow);
class MockDataChannelSink : public DataChannelSink {
@ -58,7 +68,7 @@ static_assert(!std::is_abstract_v<MockDataChannelSink>);
class Peer {
public:
Peer()
: fake_packet_transport_("transport"),
: fake_dtls_transport_(kTransportName, kComponent),
simulated_clock_(1000),
env_(CreateEnvironment(&simulated_clock_)) {
auto socket_ptr = std::make_unique<dcsctp::MockDcSctpSocket>();
@ -71,13 +81,13 @@ class Peer {
.WillOnce(Return(ByMove(std::move(socket_ptr))));
sctp_transport_ = std::make_unique<webrtc::DcSctpTransport>(
env_, rtc::Thread::Current(), &fake_packet_transport_,
env_, rtc::Thread::Current(), &fake_dtls_transport_,
std::move(mock_dcsctp_socket_factory));
sctp_transport_->SetDataChannelSink(&sink_);
sctp_transport_->SetOnConnectedCallback([this]() { sink_.OnConnected(); });
}
rtc::FakePacketTransport fake_packet_transport_;
cricket::FakeDtlsTransport fake_dtls_transport_;
webrtc::SimulatedClock simulated_clock_;
Environment env_;
dcsctp::MockDcSctpSocket* socket_;
@ -89,7 +99,7 @@ class Peer {
TEST(DcSctpTransportTest, OpenSequence) {
rtc::AutoThread main_thread;
Peer peer_a;
peer_a.fake_packet_transport_.SetWritable(true);
peer_a.fake_dtls_transport_.SetWritable(true);
EXPECT_CALL(*peer_a.socket_, Connect)
.Times(1)
@ -107,7 +117,7 @@ TEST(DcSctpTransportTest, CloseSequence) {
rtc::AutoThread main_thread;
Peer peer_a;
Peer peer_b;
peer_a.fake_packet_transport_.SetDestination(&peer_b.fake_packet_transport_,
peer_a.fake_dtls_transport_.SetDestination(&peer_b.fake_dtls_transport_,
false);
{
InSequence sequence;
@ -153,7 +163,7 @@ TEST(DcSctpTransportTest, CloseSequenceSimultaneous) {
rtc::AutoThread main_thread;
Peer peer_a;
Peer peer_b;
peer_a.fake_packet_transport_.SetDestination(&peer_b.fake_packet_transport_,
peer_a.fake_dtls_transport_.SetDestination(&peer_b.fake_dtls_transport_,
false);
{
InSequence sequence;

View File

@ -11,6 +11,7 @@
#include "media/sctp/sctp_transport_factory.h"
#include "api/environment/environment.h"
#include "p2p/dtls/dtls_transport_internal.h"
#include "rtc_base/system/unused.h"
#ifdef WEBRTC_HAVE_DCSCTP
@ -25,9 +26,8 @@ SctpTransportFactory::SctpTransportFactory(rtc::Thread* network_thread)
}
std::unique_ptr<SctpTransportInternal>
SctpTransportFactory::CreateSctpTransport(
const webrtc::Environment& env,
rtc::PacketTransportInternal* transport) {
SctpTransportFactory::CreateSctpTransport(const webrtc::Environment& env,
DtlsTransportInternal* transport) {
std::unique_ptr<SctpTransportInternal> result;
#ifdef WEBRTC_HAVE_DCSCTP
result = std::unique_ptr<SctpTransportInternal>(

View File

@ -26,7 +26,7 @@ class SctpTransportFactory : public webrtc::SctpTransportFactoryInterface {
std::unique_ptr<SctpTransportInternal> CreateSctpTransport(
const webrtc::Environment& env,
rtc::PacketTransportInternal* transport) override;
DtlsTransportInternal* transport) override;
private:
rtc::Thread* network_thread_;

View File

@ -23,6 +23,7 @@
#include "api/transport/data_channel_transport_interface.h"
#include "media/base/media_channel.h"
#include "p2p/base/packet_transport_internal.h"
#include "p2p/dtls/dtls_transport_internal.h"
#include "rtc_base/copy_on_write_buffer.h"
#include "rtc_base/thread.h"
@ -83,7 +84,7 @@ class SctpTransportInternal {
// Changes what underlying DTLS transport is uses. Used when switching which
// bundled transport the SctpTransport uses.
virtual void SetDtlsTransport(rtc::PacketTransportInternal* transport) = 0;
virtual void SetDtlsTransport(cricket::DtlsTransportInternal* transport) = 0;
// When Start is called, connects as soon as possible; this can be called
// before DTLS completes, in which case the connection will begin when DTLS

View File

@ -10,17 +10,22 @@
#include <stdint.h>
#include <atomic>
#include <cstdlib>
#include <iterator>
#include <memory>
#include <optional>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include "absl/algorithm/container.h"
#include "api/data_channel_interface.h"
#include "api/dtls_transport_interface.h"
#include "api/jsep.h"
#include "api/peer_connection_interface.h"
#include "api/rtc_error.h"
#include "api/scoped_refptr.h"
#include "api/sctp_transport_interface.h"
#include "api/stats/rtc_stats_report.h"
@ -30,6 +35,7 @@
#include "p2p/base/transport_info.h"
#include "pc/media_session.h"
#include "pc/session_description.h"
#include "pc/test/fake_rtc_certificate_generator.h"
#include "pc/test/integration_test_helpers.h"
#include "pc/test/mock_peer_connection_observers.h"
#include "rtc_base/copy_on_write_buffer.h"
@ -38,6 +44,7 @@
#include "rtc_base/gunit.h"
#include "rtc_base/logging.h"
#include "rtc_base/numerics/safe_conversions.h"
#include "rtc_base/strings/string_builder.h"
#include "rtc_base/virtual_socket_server.h"
#include "test/gmock.h"
#include "test/gtest.h"
@ -1217,6 +1224,111 @@ TEST_F(DataChannelIntegrationTestUnifiedPlan,
ASSERT_TRUE_WAIT(!callee()->data_observer()->IsOpen(), kDefaultTimeout);
}
TEST_F(DataChannelIntegrationTestUnifiedPlan, DtlsRestart) {
RTCConfiguration config;
ASSERT_TRUE(CreatePeerConnectionWrappersWithConfig(config, config));
PeerConnectionDependencies dependencies(nullptr);
std::unique_ptr<FakeRTCCertificateGenerator> cert_generator(
new FakeRTCCertificateGenerator());
cert_generator->use_alternate_key();
dependencies.cert_generator = std::move(cert_generator);
auto callee2 = CreatePeerConnectionWrapper("Callee2", nullptr, &config,
std::move(dependencies), nullptr,
/*reset_encoder_factory=*/false,
/*reset_decoder_factory=*/false);
ConnectFakeSignaling();
DataChannelInit dc_init;
dc_init.negotiated = true;
dc_init.id = 77;
caller()->CreateDataChannel("label", &dc_init);
callee()->CreateDataChannel("label", &dc_init);
callee2->CreateDataChannel("label", &dc_init);
std::unique_ptr<SessionDescriptionInterface> offer;
callee()->SetReceivedSdpMunger(
[&](std::unique_ptr<SessionDescriptionInterface>& sdp) {
offer = sdp->Clone();
});
callee()->SetGeneratedSdpMunger(
[](std::unique_ptr<SessionDescriptionInterface>& sdp) {
SetSdpType(sdp, SdpType::kPrAnswer);
});
std::unique_ptr<SessionDescriptionInterface> answer;
caller()->SetReceivedSdpMunger(
[&](std::unique_ptr<SessionDescriptionInterface>& sdp) {
answer = sdp->Clone();
});
caller()->CreateAndSetAndSignalOffer();
ASSERT_FALSE(HasFailure());
EXPECT_EQ(caller()->pc()->signaling_state(),
PeerConnectionInterface::kHaveRemotePrAnswer);
EXPECT_EQ(callee()->pc()->signaling_state(),
PeerConnectionInterface::kHaveLocalPrAnswer);
EXPECT_EQ_WAIT(DataChannelInterface::kOpen, caller()->data_channel()->state(),
kDefaultTimeout);
EXPECT_EQ_WAIT(DataChannelInterface::kOpen, callee()->data_channel()->state(),
kDefaultTimeout);
callee2->set_signaling_message_receiver(caller());
std::atomic<int> caller_sent_on_dc(0);
caller()->set_connection_change_callback(
[&](PeerConnectionInterface::PeerConnectionState new_state) {
if (new_state ==
PeerConnectionInterface::PeerConnectionState::kConnected) {
caller()->data_channel()->SendAsync(
DataBuffer("KESO"), [&](RTCError err) {
caller_sent_on_dc.store(err.ok() ? 1 : -1);
});
}
});
std::atomic<int> callee2_sent_on_dc(0);
callee2->set_connection_change_callback(
[&](PeerConnectionInterface::PeerConnectionState new_state) {
if (new_state ==
PeerConnectionInterface::PeerConnectionState::kConnected &&
callee2->data_channel()->state() == DataChannelInterface::kOpen) {
callee2->data_channel()->SendAsync(
DataBuffer("KENT"), [&](RTCError err) {
callee2_sent_on_dc.store(err.ok() ? 1 : -1);
});
}
});
callee2->data_observer()->set_state_change_callback(
[&](DataChannelInterface::DataState new_state) {
if (callee2->pc()->peer_connection_state() ==
PeerConnectionInterface::PeerConnectionState::kConnected &&
new_state == DataChannelInterface::kOpen) {
callee2->data_channel()->SendAsync(
DataBuffer("KENT"), [&](RTCError err) {
callee2_sent_on_dc.store(err.ok() ? 1 : -1);
});
}
});
std::string offer_sdp;
EXPECT_TRUE(offer->ToString(&offer_sdp));
callee2->ReceiveSdpMessage(SdpType::kOffer, offer_sdp);
EXPECT_EQ(caller()->pc()->signaling_state(),
PeerConnectionInterface::kStable);
EXPECT_EQ(callee2->pc()->signaling_state(), PeerConnectionInterface::kStable);
EXPECT_EQ_WAIT(PeerConnectionInterface::PeerConnectionState::kConnected,
caller()->pc()->peer_connection_state(), kDefaultTimeout);
EXPECT_EQ_WAIT(PeerConnectionInterface::PeerConnectionState::kConnected,
callee2->pc()->peer_connection_state(), kDefaultTimeout);
ASSERT_TRUE_WAIT(caller_sent_on_dc.load() != 0, kDefaultTimeout);
ASSERT_TRUE_WAIT(callee2_sent_on_dc.load() != 0, kDefaultTimeout);
EXPECT_EQ_WAIT("KENT", caller()->data_observer()->last_message(),
kDefaultTimeout);
EXPECT_EQ_WAIT("KESO", callee2->data_observer()->last_message(),
kDefaultTimeout);
}
#endif // WEBRTC_HAVE_SCTP
} // namespace

View File

@ -53,7 +53,7 @@ class FakeCricketSctpTransport : public cricket::SctpTransportInternal {
on_connected_callback_ = std::move(callback);
}
void SetDataChannelSink(DataChannelSink* sink) override {}
void SetDtlsTransport(rtc::PacketTransportInternal* transport) override {}
void SetDtlsTransport(cricket::DtlsTransportInternal* transport) override {}
bool Start(int local_port, int remote_port, int max_message_size) override {
return true;
}
@ -114,7 +114,7 @@ class TestSctpTransportObserver : public SctpTransportObserverInterface {
const std::vector<SctpTransportState>& States() { return states_; }
const SctpTransportInformation LastReceivedInformation() { return info_; }
SctpTransportInformation LastReceivedInformation() { return info_; }
private:
std::vector<SctpTransportState> states_;

View File

@ -454,7 +454,7 @@ class PeerConnectionIntegrationWrapper : public PeerConnectionObserver,
return data_channels_;
}
const MockDataChannelObserver* data_observer() const {
MockDataChannelObserver* data_observer() const {
if (data_observers_.size() == 0) {
return nullptr;
}
@ -740,6 +740,11 @@ class PeerConnectionIntegrationWrapper : public PeerConnectionObserver,
return 0;
}
void set_connection_change_callback(
std::function<void(PeerConnectionInterface::PeerConnectionState)> func) {
connection_change_callback_ = std::move(func);
}
private:
// Constructor used by friend class PeerConnectionIntegrationBaseTest.
explicit PeerConnectionIntegrationWrapper(const std::string& debug_name)
@ -780,11 +785,6 @@ class PeerConnectionIntegrationWrapper : public PeerConnectionObserver,
: nullptr;
}
void set_signaling_message_receiver(
SignalingMessageReceiver* signaling_message_receiver) {
signaling_message_receiver_ = signaling_message_receiver;
}
void set_signaling_delay_ms(int delay_ms) { signaling_delay_ms_ = delay_ms; }
void set_signal_ice_candidates(bool signal) {
@ -962,6 +962,12 @@ class PeerConnectionIntegrationWrapper : public PeerConnectionObserver,
}
// SignalingMessageReceiver callbacks.
public:
void set_signaling_message_receiver(
SignalingMessageReceiver* signaling_message_receiver) {
signaling_message_receiver_ = signaling_message_receiver;
}
void ReceiveSdpMessage(SdpType type, const std::string& msg) override {
if (type == SdpType::kOffer) {
HandleIncomingOffer(msg);
@ -982,6 +988,7 @@ class PeerConnectionIntegrationWrapper : public PeerConnectionObserver,
EXPECT_TRUE(result.value().ok());
}
private:
// PeerConnectionObserver callbacks.
void OnSignalingChange(
PeerConnectionInterface::SignalingState new_state) override {
@ -1021,9 +1028,13 @@ class PeerConnectionIntegrationWrapper : public PeerConnectionObserver,
PeerConnectionInterface::IceConnectionState new_state) override {
standardized_ice_connection_state_history_.push_back(new_state);
}
void OnConnectionChange(
PeerConnectionInterface::PeerConnectionState new_state) override {
peer_connection_state_history_.push_back(new_state);
if (connection_change_callback_) {
connection_change_callback_(new_state);
}
}
void OnIceGatheringChange(
@ -1072,6 +1083,7 @@ class PeerConnectionIntegrationWrapper : public PeerConnectionObserver,
SendIceMessage(candidate->sdp_mid(), candidate->sdp_mline_index(), ice_sdp);
last_candidate_gathered_ = candidate->candidate();
}
void OnIceCandidateError(const std::string& address,
int port,
const std::string& url,
@ -1172,6 +1184,9 @@ class PeerConnectionIntegrationWrapper : public PeerConnectionObserver,
uint64_t audio_concealed_stat_ = 0;
std::string rtp_stats_id_;
std::function<void(PeerConnectionInterface::PeerConnectionState)>
connection_change_callback_ = nullptr;
ScopedTaskSafety task_safety_;
friend class PeerConnectionIntegrationBaseTest;

View File

@ -390,7 +390,13 @@ class MockDataChannelObserver : public DataChannelObserver {
void OnBufferedAmountChange(uint64_t previous_amount) override {}
void OnStateChange() override { states_.push_back(channel_->state()); }
void OnStateChange() override {
states_.push_back(channel_->state());
if (state_change_callback_) {
state_change_callback_(states_.back());
}
}
void OnMessage(const DataBuffer& buffer) override {
messages_.push_back(
{std::string(buffer.data.data<char>(), buffer.data.size()),
@ -417,10 +423,16 @@ class MockDataChannelObserver : public DataChannelObserver {
return states_;
}
void set_state_change_callback(
std::function<void(DataChannelInterface::DataState)> func) {
state_change_callback_ = std::move(func);
}
private:
rtc::scoped_refptr<DataChannelInterface> channel_;
std::vector<DataChannelInterface::DataState> states_;
std::vector<Message> messages_;
std::function<void(DataChannelInterface::DataState)> state_change_callback_;
};
class MockStatsObserver : public StatsObserver {

View File

@ -25,7 +25,7 @@ class FakeSctpTransport : public cricket::SctpTransportInternal {
public:
void SetOnConnectedCallback(std::function<void()> callback) override {}
void SetDataChannelSink(webrtc::DataChannelSink* sink) override {}
void SetDtlsTransport(rtc::PacketTransportInternal* transport) override {}
void SetDtlsTransport(cricket::DtlsTransportInternal* transport) override {}
bool Start(int local_port, int remote_port, int max_message_size) override {
local_port_.emplace(local_port);
remote_port_.emplace(remote_port);
@ -73,7 +73,7 @@ class FakeSctpTransportFactory : public webrtc::SctpTransportFactoryInterface {
public:
std::unique_ptr<cricket::SctpTransportInternal> CreateSctpTransport(
const webrtc::Environment& env,
rtc::PacketTransportInternal*) override {
cricket::DtlsTransportInternal*) override {
last_fake_sctp_transport_ = new FakeSctpTransport();
return std::unique_ptr<cricket::SctpTransportInternal>(
last_fake_sctp_transport_);