webrtc_m130/pc/test/mock_peer_connection_observers.h
Jonas Oreland 575d323671 Fix dcsctp handling of dtls restart
dtls_transport will when detecting a new fingerprint
(e.g by usage of pranswer) signal DtlsTransportState::kNew.
When this happen, the dtls crypto state is lost, and
sctp should reconnect, srtp does this automatically
in current code base.

The existing behavior in dcsctp is that it will detect
peer sending an init, and reconnect. But any messages sent
between the dtls restart and the message arriving from the
peer will be lost.

This patch changes so that this case is gracefully handled by
a) letting dcsctp_transport listen to dtls state
this is big part of patch and involves changing the type of
the underlying dtransport from rtc::PacketTransportInternal to cricket::DtlsTransportInternal. If requested, I can put this
into a separate patch...

b) if a dtls restart happens, delete and restart socket.

Testcase that fails before patch and works after is attached.
Bonus: And include-what-you-use on patch

Bug: b/375327137
Change-Id: Ib78488ae75fd8aeb50d121adf464a33dabbf95e2
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/367202
Commit-Queue: Jonas Oreland <jonaso@webrtc.org>
Reviewed-by: Victor Boivie <boivie@webrtc.org>
Commit-Queue: Harald Alvestrand <hta@webrtc.org>
Auto-Submit: Jonas Oreland <jonaso@webrtc.org>
Reviewed-by: Harald Alvestrand <hta@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#43546}
2024-12-12 02:47:01 -08:00

611 lines
18 KiB
C++

/*
* Copyright 2012 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
// This file contains mock implementations of observers used in PeerConnection.
// TODO(steveanton): These aren't really mocks and should be renamed.
#ifndef PC_TEST_MOCK_PEER_CONNECTION_OBSERVERS_H_
#define PC_TEST_MOCK_PEER_CONNECTION_OBSERVERS_H_
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "api/data_channel_interface.h"
#include "api/jsep_ice_candidate.h"
#include "pc/stream_collection.h"
#include "rtc_base/checks.h"
namespace webrtc {
class MockPeerConnectionObserver : public PeerConnectionObserver {
public:
struct AddTrackEvent {
explicit AddTrackEvent(
rtc::scoped_refptr<RtpReceiverInterface> event_receiver,
std::vector<rtc::scoped_refptr<MediaStreamInterface>> event_streams)
: receiver(std::move(event_receiver)),
streams(std::move(event_streams)) {
for (auto stream : streams) {
std::vector<rtc::scoped_refptr<MediaStreamTrackInterface>> tracks;
for (auto audio_track : stream->GetAudioTracks()) {
tracks.push_back(audio_track);
}
for (auto video_track : stream->GetVideoTracks()) {
tracks.push_back(video_track);
}
snapshotted_stream_tracks[stream] = tracks;
}
}
rtc::scoped_refptr<RtpReceiverInterface> receiver;
std::vector<rtc::scoped_refptr<MediaStreamInterface>> streams;
// This map records the tracks present in each stream at the time the
// OnAddTrack callback was issued.
std::map<rtc::scoped_refptr<MediaStreamInterface>,
std::vector<rtc::scoped_refptr<MediaStreamTrackInterface>>>
snapshotted_stream_tracks;
};
MockPeerConnectionObserver() : remote_streams_(StreamCollection::Create()) {}
virtual ~MockPeerConnectionObserver() {}
void SetPeerConnectionInterface(PeerConnectionInterface* pc) {
pc_ = pc;
if (pc) {
state_ = pc_->signaling_state();
}
}
void OnSignalingChange(
PeerConnectionInterface::SignalingState new_state) override {
RTC_DCHECK(pc_);
RTC_DCHECK(pc_->signaling_state() == new_state);
state_ = new_state;
}
MediaStreamInterface* RemoteStream(const std::string& label) {
return remote_streams_->find(label);
}
StreamCollectionInterface* remote_streams() const {
return remote_streams_.get();
}
void OnAddStream(rtc::scoped_refptr<MediaStreamInterface> stream) override {
last_added_stream_ = stream;
remote_streams_->AddStream(stream);
}
void OnRemoveStream(
rtc::scoped_refptr<MediaStreamInterface> stream) override {
last_removed_stream_ = stream;
remote_streams_->RemoveStream(stream.get());
}
void OnRenegotiationNeeded() override { renegotiation_needed_ = true; }
void OnNegotiationNeededEvent(uint32_t event_id) override {
latest_negotiation_needed_event_ = event_id;
}
void OnDataChannel(
rtc::scoped_refptr<DataChannelInterface> data_channel) override {
last_datachannel_ = data_channel;
}
void OnIceConnectionChange(
PeerConnectionInterface::IceConnectionState new_state) override {
RTC_DCHECK(pc_);
RTC_DCHECK(pc_->ice_connection_state() == new_state);
// When ICE is finished, the caller will get to a kIceConnectionCompleted
// state, because it has the ICE controlling role, while the callee
// will get to a kIceConnectionConnected state. This means that both ICE
// and DTLS are connected.
ice_connected_ =
(new_state == PeerConnectionInterface::kIceConnectionConnected) ||
(new_state == PeerConnectionInterface::kIceConnectionCompleted);
callback_triggered_ = true;
}
void OnIceGatheringChange(
PeerConnectionInterface::IceGatheringState new_state) override {
RTC_DCHECK(pc_);
RTC_DCHECK(pc_->ice_gathering_state() == new_state);
ice_gathering_complete_ =
new_state == PeerConnectionInterface::kIceGatheringComplete;
callback_triggered_ = true;
}
void OnIceCandidate(const IceCandidateInterface* candidate) override {
RTC_DCHECK(pc_);
candidates_.push_back(std::make_unique<JsepIceCandidate>(
candidate->sdp_mid(), candidate->sdp_mline_index(),
candidate->candidate()));
callback_triggered_ = true;
}
void OnIceCandidatesRemoved(
const std::vector<cricket::Candidate>& candidates) override {
num_candidates_removed_++;
callback_triggered_ = true;
}
void OnIceConnectionReceivingChange(bool receiving) override {
callback_triggered_ = true;
}
void OnAddTrack(rtc::scoped_refptr<RtpReceiverInterface> receiver,
const std::vector<rtc::scoped_refptr<MediaStreamInterface>>&
streams) override {
RTC_DCHECK(receiver);
num_added_tracks_++;
last_added_track_label_ = receiver->id();
add_track_events_.push_back(AddTrackEvent(receiver, streams));
}
void OnTrack(
rtc::scoped_refptr<RtpTransceiverInterface> transceiver) override {
on_track_transceivers_.push_back(transceiver);
}
void OnRemoveTrack(
rtc::scoped_refptr<RtpReceiverInterface> receiver) override {
remove_track_events_.push_back(receiver);
}
std::vector<rtc::scoped_refptr<RtpReceiverInterface>> GetAddTrackReceivers() {
std::vector<rtc::scoped_refptr<RtpReceiverInterface>> receivers;
for (const AddTrackEvent& event : add_track_events_) {
receivers.push_back(event.receiver);
}
return receivers;
}
int CountAddTrackEventsForStream(const std::string& stream_id) {
int found_tracks = 0;
for (const AddTrackEvent& event : add_track_events_) {
bool has_stream_id = false;
for (auto stream : event.streams) {
if (stream->id() == stream_id) {
has_stream_id = true;
break;
}
}
if (has_stream_id) {
++found_tracks;
}
}
return found_tracks;
}
// Returns the id of the last added stream.
// Empty string if no stream have been added.
std::string GetLastAddedStreamId() {
if (last_added_stream_.get())
return last_added_stream_->id();
return "";
}
std::string GetLastRemovedStreamId() {
if (last_removed_stream_.get())
return last_removed_stream_->id();
return "";
}
IceCandidateInterface* last_candidate() {
if (candidates_.empty()) {
return nullptr;
} else {
return candidates_.back().get();
}
}
std::vector<const IceCandidateInterface*> GetAllCandidates() {
std::vector<const IceCandidateInterface*> candidates;
for (const auto& candidate : candidates_) {
candidates.push_back(candidate.get());
}
return candidates;
}
std::vector<IceCandidateInterface*> GetCandidatesByMline(int mline_index) {
std::vector<IceCandidateInterface*> candidates;
for (const auto& candidate : candidates_) {
if (candidate->sdp_mline_index() == mline_index) {
candidates.push_back(candidate.get());
}
}
return candidates;
}
bool legacy_renegotiation_needed() const { return renegotiation_needed_; }
void clear_legacy_renegotiation_needed() { renegotiation_needed_ = false; }
bool has_negotiation_needed_event() {
return latest_negotiation_needed_event_.has_value();
}
uint32_t latest_negotiation_needed_event() {
return latest_negotiation_needed_event_.value_or(0u);
}
void clear_latest_negotiation_needed_event() {
latest_negotiation_needed_event_ = std::nullopt;
}
rtc::scoped_refptr<PeerConnectionInterface> pc_;
PeerConnectionInterface::SignalingState state_;
std::vector<std::unique_ptr<IceCandidateInterface>> candidates_;
rtc::scoped_refptr<DataChannelInterface> last_datachannel_;
rtc::scoped_refptr<StreamCollection> remote_streams_;
bool renegotiation_needed_ = false;
std::optional<uint32_t> latest_negotiation_needed_event_;
bool ice_gathering_complete_ = false;
bool ice_connected_ = false;
bool callback_triggered_ = false;
int num_added_tracks_ = 0;
std::string last_added_track_label_;
std::vector<AddTrackEvent> add_track_events_;
std::vector<rtc::scoped_refptr<RtpReceiverInterface>> remove_track_events_;
std::vector<rtc::scoped_refptr<RtpTransceiverInterface>>
on_track_transceivers_;
int num_candidates_removed_ = 0;
private:
rtc::scoped_refptr<MediaStreamInterface> last_added_stream_;
rtc::scoped_refptr<MediaStreamInterface> last_removed_stream_;
};
class MockCreateSessionDescriptionObserver
: public CreateSessionDescriptionObserver {
public:
MockCreateSessionDescriptionObserver()
: called_(false),
error_("MockCreateSessionDescriptionObserver not called") {}
virtual ~MockCreateSessionDescriptionObserver() {}
void OnSuccess(SessionDescriptionInterface* desc) override {
MutexLock lock(&mutex_);
called_ = true;
error_ = "";
desc_.reset(desc);
}
void OnFailure(RTCError error) override {
MutexLock lock(&mutex_);
called_ = true;
error_ = error.message();
}
bool called() const {
MutexLock lock(&mutex_);
return called_;
}
bool result() const {
MutexLock lock(&mutex_);
return error_.empty();
}
const std::string& error() const {
MutexLock lock(&mutex_);
return error_;
}
std::unique_ptr<SessionDescriptionInterface> MoveDescription() {
MutexLock lock(&mutex_);
return std::move(desc_);
}
private:
mutable Mutex mutex_;
bool called_ RTC_GUARDED_BY(mutex_);
std::string error_ RTC_GUARDED_BY(mutex_);
std::unique_ptr<SessionDescriptionInterface> desc_ RTC_GUARDED_BY(mutex_);
};
class MockSetSessionDescriptionObserver : public SetSessionDescriptionObserver {
public:
static rtc::scoped_refptr<MockSetSessionDescriptionObserver> Create() {
return rtc::make_ref_counted<MockSetSessionDescriptionObserver>();
}
MockSetSessionDescriptionObserver()
: called_(false),
error_("MockSetSessionDescriptionObserver not called") {}
~MockSetSessionDescriptionObserver() override {}
void OnSuccess() override {
MutexLock lock(&mutex_);
called_ = true;
error_ = "";
}
void OnFailure(RTCError error) override {
MutexLock lock(&mutex_);
called_ = true;
error_ = error.message();
}
bool called() const {
MutexLock lock(&mutex_);
return called_;
}
bool result() const {
MutexLock lock(&mutex_);
return error_.empty();
}
const std::string& error() const {
MutexLock lock(&mutex_);
return error_;
}
private:
mutable Mutex mutex_;
bool called_;
std::string error_;
};
class FakeSetLocalDescriptionObserver
: public SetLocalDescriptionObserverInterface {
public:
bool called() const { return error_.has_value(); }
RTCError& error() {
RTC_DCHECK(error_.has_value());
return *error_;
}
// SetLocalDescriptionObserverInterface implementation.
void OnSetLocalDescriptionComplete(RTCError error) override {
error_ = std::move(error);
}
private:
// Set on complete, on success this is set to an RTCError::OK() error.
std::optional<RTCError> error_;
};
class FakeSetRemoteDescriptionObserver
: public SetRemoteDescriptionObserverInterface {
public:
bool called() const { return error_.has_value(); }
RTCError& error() {
RTC_DCHECK(error_.has_value());
return *error_;
}
// SetRemoteDescriptionObserverInterface implementation.
void OnSetRemoteDescriptionComplete(RTCError error) override {
error_ = std::move(error);
}
private:
// Set on complete, on success this is set to an RTCError::OK() error.
std::optional<RTCError> error_;
};
class MockDataChannelObserver : public DataChannelObserver {
public:
struct Message {
std::string data;
bool binary;
};
explicit MockDataChannelObserver(DataChannelInterface* channel)
: channel_(channel) {
channel_->RegisterObserver(this);
states_.push_back(channel_->state());
}
virtual ~MockDataChannelObserver() { channel_->UnregisterObserver(); }
void OnBufferedAmountChange(uint64_t previous_amount) override {}
void OnStateChange() override {
states_.push_back(channel_->state());
if (state_change_callback_) {
state_change_callback_(states_.back());
}
}
void OnMessage(const DataBuffer& buffer) override {
messages_.push_back(
{std::string(buffer.data.data<char>(), buffer.data.size()),
buffer.binary});
}
bool IsOpen() const { return state() == DataChannelInterface::kOpen; }
std::vector<Message> messages() const { return messages_; }
std::string last_message() const {
if (messages_.empty())
return {};
return messages_.back().data;
}
bool last_message_is_binary() const {
if (messages_.empty())
return false;
return messages_.back().binary;
}
size_t received_message_count() const { return messages_.size(); }
DataChannelInterface::DataState state() const { return states_.back(); }
const std::vector<DataChannelInterface::DataState>& states() const {
return states_;
}
void set_state_change_callback(
std::function<void(DataChannelInterface::DataState)> func) {
state_change_callback_ = std::move(func);
}
private:
rtc::scoped_refptr<DataChannelInterface> channel_;
std::vector<DataChannelInterface::DataState> states_;
std::vector<Message> messages_;
std::function<void(DataChannelInterface::DataState)> state_change_callback_;
};
class MockStatsObserver : public StatsObserver {
public:
MockStatsObserver() : called_(false), stats_() {}
virtual ~MockStatsObserver() {}
virtual void OnComplete(const StatsReports& reports) {
RTC_CHECK(!called_);
called_ = true;
stats_.Clear();
stats_.number_of_reports = reports.size();
for (const auto* r : reports) {
if (r->type() == StatsReport::kStatsReportTypeSsrc) {
stats_.timestamp = r->timestamp();
GetIntValue(r, StatsReport::kStatsValueNameAudioOutputLevel,
&stats_.audio_output_level);
GetIntValue(r, StatsReport::kStatsValueNameAudioInputLevel,
&stats_.audio_input_level);
GetIntValue(r, StatsReport::kStatsValueNameBytesReceived,
&stats_.bytes_received);
GetIntValue(r, StatsReport::kStatsValueNameBytesSent,
&stats_.bytes_sent);
GetInt64Value(r, StatsReport::kStatsValueNameCaptureStartNtpTimeMs,
&stats_.capture_start_ntp_time);
stats_.track_ids.emplace_back();
GetStringValue(r, StatsReport::kStatsValueNameTrackId,
&stats_.track_ids.back());
} else if (r->type() == StatsReport::kStatsReportTypeBwe) {
stats_.timestamp = r->timestamp();
GetIntValue(r, StatsReport::kStatsValueNameAvailableReceiveBandwidth,
&stats_.available_receive_bandwidth);
} else if (r->type() == StatsReport::kStatsReportTypeComponent) {
stats_.timestamp = r->timestamp();
GetStringValue(r, StatsReport::kStatsValueNameDtlsCipher,
&stats_.dtls_cipher);
GetStringValue(r, StatsReport::kStatsValueNameSrtpCipher,
&stats_.srtp_cipher);
}
}
}
bool called() const { return called_; }
size_t number_of_reports() const { return stats_.number_of_reports; }
double timestamp() const { return stats_.timestamp; }
int AudioOutputLevel() const {
RTC_CHECK(called_);
return stats_.audio_output_level;
}
int AudioInputLevel() const {
RTC_CHECK(called_);
return stats_.audio_input_level;
}
int BytesReceived() const {
RTC_CHECK(called_);
return stats_.bytes_received;
}
int BytesSent() const {
RTC_CHECK(called_);
return stats_.bytes_sent;
}
int64_t CaptureStartNtpTime() const {
RTC_CHECK(called_);
return stats_.capture_start_ntp_time;
}
int AvailableReceiveBandwidth() const {
RTC_CHECK(called_);
return stats_.available_receive_bandwidth;
}
std::string DtlsCipher() const {
RTC_CHECK(called_);
return stats_.dtls_cipher;
}
std::string SrtpCipher() const {
RTC_CHECK(called_);
return stats_.srtp_cipher;
}
std::vector<std::string> TrackIds() const {
RTC_CHECK(called_);
return stats_.track_ids;
}
private:
bool GetIntValue(const StatsReport* report,
StatsReport::StatsValueName name,
int* value) {
const StatsReport::Value* v = report->FindValue(name);
if (v) {
// TODO(tommi): We should really just be using an int here :-/
*value = rtc::FromString<int>(v->ToString());
}
return v != nullptr;
}
bool GetInt64Value(const StatsReport* report,
StatsReport::StatsValueName name,
int64_t* value) {
const StatsReport::Value* v = report->FindValue(name);
if (v) {
// TODO(tommi): We should really just be using an int here :-/
*value = rtc::FromString<int64_t>(v->ToString());
}
return v != nullptr;
}
bool GetStringValue(const StatsReport* report,
StatsReport::StatsValueName name,
std::string* value) {
const StatsReport::Value* v = report->FindValue(name);
if (v)
*value = v->ToString();
return v != nullptr;
}
bool called_;
struct {
void Clear() {
number_of_reports = 0;
timestamp = 0;
audio_output_level = 0;
audio_input_level = 0;
bytes_received = 0;
bytes_sent = 0;
capture_start_ntp_time = 0;
available_receive_bandwidth = 0;
dtls_cipher.clear();
srtp_cipher.clear();
track_ids.clear();
}
size_t number_of_reports;
double timestamp;
int audio_output_level;
int audio_input_level;
int bytes_received;
int bytes_sent;
int64_t capture_start_ntp_time;
int available_receive_bandwidth;
std::string dtls_cipher;
std::string srtp_cipher;
std::vector<std::string> track_ids;
} stats_;
};
// Helper class that just stores the report from the callback.
class MockRTCStatsCollectorCallback : public RTCStatsCollectorCallback {
public:
rtc::scoped_refptr<const RTCStatsReport> report() { return report_; }
bool called() const { return called_; }
protected:
void OnStatsDelivered(
const rtc::scoped_refptr<const RTCStatsReport>& report) override {
report_ = report;
called_ = true;
}
private:
bool called_ = false;
rtc::scoped_refptr<const RTCStatsReport> report_;
};
} // namespace webrtc
#endif // PC_TEST_MOCK_PEER_CONNECTION_OBSERVERS_H_