Fix race between destroying SctpTransport and receiving notification on timer thread.

This gets rid of the SctpTransportMap::Retrieve method and forces
everything to go through PostToTransportThread, which behaves safely
with relation to the transport's destruction.

Bug: webrtc:12467
Change-Id: Id4a723c2c985be2a368d2cc5c5e62deb04c509ab
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/208800
Reviewed-by: Niels Moller <nisse@webrtc.org>
Commit-Queue: Taylor <deadbeef@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#33364}
This commit is contained in:
Taylor Brandstetter 2021-02-26 12:56:36 -08:00 committed by Commit Bot
parent 88a51b2902
commit a88fe7be14
3 changed files with 107 additions and 95 deletions

View File

@ -20,6 +20,7 @@ enum PreservedErrno {
// Successful return value from usrsctp callbacks. Is not actually used by
// usrsctp, but all example programs for usrsctp use 1 as their return value.
constexpr int kSctpSuccessReturn = 1;
constexpr int kSctpErrorReturn = 0;
} // namespace
@ -27,7 +28,6 @@ constexpr int kSctpSuccessReturn = 1;
#include <stdio.h>
#include <usrsctp.h>
#include <functional>
#include <memory>
#include <unordered_map>
@ -252,31 +252,20 @@ class SctpTransportMap {
return map_.erase(id) > 0;
}
// Must be called on the transport's network thread to protect against
// simultaneous deletion/deregistration of the transport; if that's not
// guaranteed, use ExecuteWithLock.
SctpTransport* Retrieve(uintptr_t id) const {
webrtc::MutexLock lock(&lock_);
SctpTransport* transport = RetrieveWhileHoldingLock(id);
if (transport) {
RTC_DCHECK_RUN_ON(transport->network_thread());
}
return transport;
}
// Posts |action| to the network thread of the transport identified by |id|
// and returns true if found, all while holding a lock to protect against the
// transport being simultaneously deleted/deregistered, or returns false if
// not found.
bool PostToTransportThread(uintptr_t id,
std::function<void(SctpTransport*)> action) const {
template <typename F>
bool PostToTransportThread(uintptr_t id, F action) const {
webrtc::MutexLock lock(&lock_);
SctpTransport* transport = RetrieveWhileHoldingLock(id);
if (!transport) {
return false;
}
transport->network_thread_->PostTask(ToQueuedTask(
transport->task_safety_, [transport, action]() { action(transport); }));
transport->task_safety_,
[transport, action{std::move(action)}]() { action(transport); }));
return true;
}
@ -429,7 +418,7 @@ class SctpTransport::UsrSctpWrapper {
if (!found) {
RTC_LOG(LS_ERROR)
<< "OnSctpOutboundPacket: Failed to get transport for socket ID "
<< addr;
<< addr << "; possibly was already destroyed.";
return EINVAL;
}
@ -447,28 +436,49 @@ class SctpTransport::UsrSctpWrapper {
struct sctp_rcvinfo rcv,
int flags,
void* ulp_info) {
SctpTransport* transport = GetTransportFromSocket(sock);
if (!transport) {
struct DeleteByFree {
void operator()(void* p) const { free(p); }
};
std::unique_ptr<void, DeleteByFree> owned_data(data, DeleteByFree());
absl::optional<uintptr_t> id = GetTransportIdFromSocket(sock);
if (!id) {
RTC_LOG(LS_ERROR)
<< "OnSctpInboundPacket: Failed to get transport for socket " << sock
<< "; possibly was already destroyed.";
free(data);
return 0;
<< "OnSctpInboundPacket: Failed to get transport ID from socket "
<< sock;
return kSctpErrorReturn;
}
// Sanity check that both methods of getting the SctpTransport pointer
// yield the same result.
RTC_CHECK_EQ(transport, static_cast<SctpTransport*>(ulp_info));
int result =
transport->OnDataOrNotificationFromSctp(data, length, rcv, flags);
free(data);
return result;
if (!g_transport_map_) {
RTC_LOG(LS_ERROR)
<< "OnSctpInboundPacket called after usrsctp uninitialized?";
return kSctpErrorReturn;
}
// PostsToTransportThread protects against the transport being
// simultaneously deregistered/deleted, since this callback may come from
// the SCTP timer thread and thus race with the network thread.
bool found = g_transport_map_->PostToTransportThread(
*id, [owned_data{std::move(owned_data)}, length, rcv,
flags](SctpTransport* transport) {
transport->OnDataOrNotificationFromSctp(owned_data.get(), length, rcv,
flags);
});
if (!found) {
RTC_LOG(LS_ERROR)
<< "OnSctpInboundPacket: Failed to get transport for socket ID "
<< *id << "; possibly was already destroyed.";
return kSctpErrorReturn;
}
return kSctpSuccessReturn;
}
static SctpTransport* GetTransportFromSocket(struct socket* sock) {
static absl::optional<uintptr_t> GetTransportIdFromSocket(
struct socket* sock) {
absl::optional<uintptr_t> ret;
struct sockaddr* addrs = nullptr;
int naddrs = usrsctp_getladdrs(sock, 0, &addrs);
if (naddrs <= 0 || addrs[0].sa_family != AF_CONN) {
return nullptr;
return ret;
}
// usrsctp_getladdrs() returns the addresses bound to this socket, which
// contains the SctpTransport id as sconn_addr. Read the id,
@ -477,17 +487,10 @@ class SctpTransport::UsrSctpWrapper {
// id of the transport that created them, so [0] is as good as any other.
struct sockaddr_conn* sconn =
reinterpret_cast<struct sockaddr_conn*>(&addrs[0]);
if (!g_transport_map_) {
RTC_LOG(LS_ERROR)
<< "GetTransportFromSocket called after usrsctp uninitialized?";
usrsctp_freeladdrs(addrs);
return nullptr;
}
SctpTransport* transport = g_transport_map_->Retrieve(
reinterpret_cast<uintptr_t>(sconn->sconn_addr));
ret = reinterpret_cast<uintptr_t>(sconn->sconn_addr);
usrsctp_freeladdrs(addrs);
return transport;
return ret;
}
// TODO(crbug.com/webrtc/11899): This is a legacy callback signature, remove
@ -496,14 +499,26 @@ class SctpTransport::UsrSctpWrapper {
// Fired on our I/O thread. SctpTransport::OnPacketReceived() gets
// a packet containing acknowledgments, which goes into usrsctp_conninput,
// and then back here.
SctpTransport* transport = GetTransportFromSocket(sock);
if (!transport) {
absl::optional<uintptr_t> id = GetTransportIdFromSocket(sock);
if (!id) {
RTC_LOG(LS_ERROR)
<< "SendThresholdCallback: Failed to get transport for socket "
<< sock << "; possibly was already destroyed.";
<< "SendThresholdCallback: Failed to get transport ID from socket "
<< sock;
return 0;
}
transport->OnSendThresholdCallback();
if (!g_transport_map_) {
RTC_LOG(LS_ERROR)
<< "SendThresholdCallback called after usrsctp uninitialized?";
return 0;
}
bool found = g_transport_map_->PostToTransportThread(
*id,
[](SctpTransport* transport) { transport->OnSendThresholdCallback(); });
if (!found) {
RTC_LOG(LS_ERROR)
<< "SendThresholdCallback: Failed to get transport for socket ID "
<< *id << "; possibly was already destroyed.";
}
return 0;
}
@ -513,17 +528,26 @@ class SctpTransport::UsrSctpWrapper {
// Fired on our I/O thread. SctpTransport::OnPacketReceived() gets
// a packet containing acknowledgments, which goes into usrsctp_conninput,
// and then back here.
SctpTransport* transport = GetTransportFromSocket(sock);
if (!transport) {
absl::optional<uintptr_t> id = GetTransportIdFromSocket(sock);
if (!id) {
RTC_LOG(LS_ERROR)
<< "SendThresholdCallback: Failed to get transport for socket "
<< sock << "; possibly was already destroyed.";
<< "SendThresholdCallback: Failed to get transport ID from socket "
<< sock;
return 0;
}
// Sanity check that both methods of getting the SctpTransport pointer
// yield the same result.
RTC_CHECK_EQ(transport, static_cast<SctpTransport*>(ulp_info));
transport->OnSendThresholdCallback();
if (!g_transport_map_) {
RTC_LOG(LS_ERROR)
<< "SendThresholdCallback called after usrsctp uninitialized?";
return 0;
}
bool found = g_transport_map_->PostToTransportThread(
*id,
[](SctpTransport* transport) { transport->OnSendThresholdCallback(); });
if (!found) {
RTC_LOG(LS_ERROR)
<< "SendThresholdCallback: Failed to get transport for socket ID "
<< *id << "; possibly was already destroyed.";
}
return 0;
}
};
@ -1175,24 +1199,25 @@ void SctpTransport::OnPacketFromSctpToNetwork(
rtc::PacketOptions(), PF_NORMAL);
}
int SctpTransport::InjectDataOrNotificationFromSctpForTesting(
void SctpTransport::InjectDataOrNotificationFromSctpForTesting(
const void* data,
size_t length,
struct sctp_rcvinfo rcv,
int flags) {
return OnDataOrNotificationFromSctp(data, length, rcv, flags);
OnDataOrNotificationFromSctp(data, length, rcv, flags);
}
int SctpTransport::OnDataOrNotificationFromSctp(const void* data,
size_t length,
struct sctp_rcvinfo rcv,
int flags) {
void SctpTransport::OnDataOrNotificationFromSctp(const void* data,
size_t length,
struct sctp_rcvinfo rcv,
int flags) {
RTC_DCHECK_RUN_ON(network_thread_);
// If data is NULL, the SCTP association has been closed.
if (!data) {
RTC_LOG(LS_INFO) << debug_name_
<< "->OnDataOrNotificationFromSctp(...): "
"No data; association closed.";
return kSctpSuccessReturn;
return;
}
// Handle notifications early.
@ -1205,14 +1230,10 @@ int SctpTransport::OnDataOrNotificationFromSctp(const void* data,
<< "->OnDataOrNotificationFromSctp(...): SCTP notification"
<< " length=" << length;
// Copy and dispatch asynchronously
rtc::CopyOnWriteBuffer notification(reinterpret_cast<const uint8_t*>(data),
length);
network_thread_->PostTask(ToQueuedTask(
task_safety_, [this, notification = std::move(notification)]() {
OnNotificationFromSctp(notification);
}));
return kSctpSuccessReturn;
OnNotificationFromSctp(notification);
return;
}
// Log data chunk
@ -1230,7 +1251,7 @@ int SctpTransport::OnDataOrNotificationFromSctp(const void* data,
// Unexpected PPID, dropping
RTC_LOG(LS_ERROR) << "Received an unknown PPID " << ppid
<< " on an SCTP packet. Dropping.";
return kSctpSuccessReturn;
return;
}
// Expect only continuation messages belonging to the same SID. The SCTP
@ -1266,7 +1287,7 @@ int SctpTransport::OnDataOrNotificationFromSctp(const void* data,
if (partial_incoming_message_.size() < kSctpSendBufferSize) {
// We still have space in the buffer. Continue buffering chunks until
// the message is complete before handing it out.
return kSctpSuccessReturn;
return;
} else {
// The sender is exceeding the maximum message size that we announced.
// Spit out a warning but still hand out the partial message. Note that
@ -1280,18 +1301,9 @@ int SctpTransport::OnDataOrNotificationFromSctp(const void* data,
}
}
// Dispatch the complete message.
// The ownership of the packet transfers to |invoker_|. Using
// CopyOnWriteBuffer is the most convenient way to do this.
network_thread_->PostTask(webrtc::ToQueuedTask(
task_safety_, [this, params = std::move(params),
message = partial_incoming_message_]() {
OnDataFromSctpToTransport(params, message);
}));
// Reset the message buffer
// Dispatch the complete message and reset the message buffer.
OnDataFromSctpToTransport(params, partial_incoming_message_);
partial_incoming_message_.Clear();
return kSctpSuccessReturn;
}
void SctpTransport::OnDataFromSctpToTransport(

View File

@ -96,10 +96,10 @@ class SctpTransport : public SctpTransportInternal,
void set_debug_name_for_testing(const char* debug_name) override {
debug_name_ = debug_name;
}
int InjectDataOrNotificationFromSctpForTesting(const void* data,
size_t length,
struct sctp_rcvinfo rcv,
int flags);
void InjectDataOrNotificationFromSctpForTesting(const void* data,
size_t length,
struct sctp_rcvinfo rcv,
int flags);
// Exposed to allow Post call from c-callbacks.
// TODO(deadbeef): Remove this or at least make it return a const pointer.
@ -180,12 +180,12 @@ class SctpTransport : public SctpTransportInternal,
// Called using |invoker_| to send packet on the network.
void OnPacketFromSctpToNetwork(const rtc::CopyOnWriteBuffer& buffer);
// Called on the SCTP thread.
// Called on the network thread.
// Flags are standard socket API flags (RFC 6458).
int OnDataOrNotificationFromSctp(const void* data,
size_t length,
struct sctp_rcvinfo rcv,
int flags);
void OnDataOrNotificationFromSctp(const void* data,
size_t length,
struct sctp_rcvinfo rcv,
int flags);
// Called using |invoker_| to decide what to do with the data.
void OnDataFromSctpToTransport(const ReceiveDataParams& params,
const rtc::CopyOnWriteBuffer& buffer);

View File

@ -282,8 +282,8 @@ TEST_F(SctpTransportTest, MessageInterleavedWithNotification) {
meta.rcv_tsn = 42;
meta.rcv_cumtsn = 42;
chunk.SetData("meow?", 5);
EXPECT_EQ(1, transport1->InjectDataOrNotificationFromSctpForTesting(
chunk.data(), chunk.size(), meta, 0));
transport1->InjectDataOrNotificationFromSctpForTesting(chunk.data(),
chunk.size(), meta, 0);
// Inject a notification in between chunks.
union sctp_notification notification;
@ -292,15 +292,15 @@ TEST_F(SctpTransportTest, MessageInterleavedWithNotification) {
notification.sn_header.sn_type = SCTP_PEER_ADDR_CHANGE;
notification.sn_header.sn_flags = 0;
notification.sn_header.sn_length = sizeof(notification);
EXPECT_EQ(1, transport1->InjectDataOrNotificationFromSctpForTesting(
&notification, sizeof(notification), {0}, MSG_NOTIFICATION));
transport1->InjectDataOrNotificationFromSctpForTesting(
&notification, sizeof(notification), {0}, MSG_NOTIFICATION);
// Inject chunk 2/2
meta.rcv_tsn = 42;
meta.rcv_cumtsn = 43;
chunk.SetData(" rawr!", 6);
EXPECT_EQ(1, transport1->InjectDataOrNotificationFromSctpForTesting(
chunk.data(), chunk.size(), meta, MSG_EOR));
transport1->InjectDataOrNotificationFromSctpForTesting(
chunk.data(), chunk.size(), meta, MSG_EOR);
// Expect the message to contain both chunks.
EXPECT_TRUE_WAIT(ReceivedData(&recv1, 1, "meow? rawr!"), kDefaultTimeout);