Refactor StunMessage a bit

* Add ctors for providing the type and transaction id at construction.
* Update tests to use them instead of SetType+SetTransactionID
* Make sure stun message enum types are based on uint16_t
* Mark SetTransactionID as deprecated.
* Mark SetStunMagicCookie as deprecated (unused in webrtc).
* Add SetTransactionIdForTest for the one test that uses it (might not
  actually need it)
* Make StunRequest::Construct() protected.
  * Add a TODO to follow up on this since construction of StunRequest
    goes through an unnecessarily complex 3-step process involving
    other classes and a virtual method.

Bug: none
Change-Id: Ib013e58f28e7b2b4fcb3b3e1034da31dfc93e9d3
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/264546
Reviewed-by: Niels Moller <nisse@webrtc.org>
Commit-Queue: Tomas Gunnarsson <tommi@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#37079}
This commit is contained in:
Tommi 2022-06-01 15:29:31 +02:00 committed by WebRTC LUCI CQ
parent 80a860532e
commit 408143d5af
17 changed files with 180 additions and 185 deletions

View File

@ -11,6 +11,7 @@
#include "api/transport/stun.h"
#include <string.h>
#include <algorithm>
#include <cstdint>
#include <iterator>
@ -20,6 +21,7 @@
#include "rtc_base/byte_order.h"
#include "rtc_base/checks.h"
#include "rtc_base/crc32.h"
#include "rtc_base/helpers.h"
#include "rtc_base/logging.h"
#include "rtc_base/message_digest.h"
@ -34,11 +36,11 @@ const int k127Utf8CharactersLengthInBytes = 508;
const int kMessageIntegrityAttributeLength = 20;
const int kTheoreticalMaximumAttributeLength = 65535;
uint32_t ReduceTransactionId(const std::string& transaction_id) {
uint32_t ReduceTransactionId(absl::string_view transaction_id) {
RTC_DCHECK(transaction_id.length() == cricket::kStunTransactionIdLength ||
transaction_id.length() ==
cricket::kStunLegacyTransactionIdLength);
ByteBufferReader reader(transaction_id.c_str(), transaction_id.length());
transaction_id.length() == cricket::kStunLegacyTransactionIdLength)
<< transaction_id.length();
ByteBufferReader reader(transaction_id.data(), transaction_id.size());
uint32_t result = 0;
uint32_t next;
while (reader.ReadUInt32(&next)) {
@ -102,10 +104,15 @@ const int SERVER_NOT_REACHABLE_ERROR = 701;
// StunMessage
StunMessage::StunMessage()
: type_(0),
length_(0),
transaction_id_(EMPTY_TRANSACTION_ID),
stun_magic_cookie_(kStunMagicCookie) {
: StunMessage(STUN_INVALID_MESSAGE_TYPE, EMPTY_TRANSACTION_ID) {}
StunMessage::StunMessage(uint16_t type)
: StunMessage(type, GenerateTransactionId()) {}
StunMessage::StunMessage(uint16_t type, absl::string_view transaction_id)
: type_(type),
transaction_id_(transaction_id),
reduced_transaction_id_(ReduceTransactionId(transaction_id_)) {
RTC_DCHECK(IsValidTransactionId(transaction_id_));
}
@ -118,15 +125,6 @@ bool StunMessage::IsLegacy() const {
return false;
}
bool StunMessage::SetTransactionID(const std::string& str) {
if (!IsValidTransactionId(str)) {
return false;
}
transaction_id_ = str;
reduced_transaction_id_ = ReduceTransactionId(transaction_id_);
return true;
}
static bool DesignatedExpertRange(int attr_type) {
return (attr_type >= 0x4000 && attr_type <= 0x7FFF) ||
(attr_type >= 0xC000 && attr_type <= 0xFFFF);
@ -442,6 +440,11 @@ bool StunMessage::ValidateFingerprint(const char* data, size_t size) {
rtc::ComputeCrc32(data, size - fingerprint_attr_size));
}
// static
std::string StunMessage::GenerateTransactionId() {
return rtc::CreateRandomString(kStunTransactionIdLength);
}
bool StunMessage::IsStunMethod(rtc::ArrayView<int> methods,
const char* data,
size_t size) {
@ -589,6 +592,12 @@ void StunMessage::SetStunMagicCookie(uint32_t val) {
stun_magic_cookie_ = val;
}
void StunMessage::SetTransactionIdForTesting(absl::string_view transaction_id) {
RTC_DCHECK(IsValidTransactionId(transaction_id));
transaction_id_ = std::string(transaction_id);
reduced_transaction_id_ = ReduceTransactionId(transaction_id_);
}
StunAttributeValueType StunMessage::GetAttributeValueType(int type) const {
switch (type) {
case STUN_ATTR_MAPPED_ADDRESS:
@ -647,7 +656,7 @@ const StunAttribute* StunMessage::GetAttribute(int type) const {
return NULL;
}
bool StunMessage::IsValidTransactionId(const std::string& transaction_id) {
bool StunMessage::IsValidTransactionId(absl::string_view transaction_id) {
return transaction_id.size() == kStunTransactionIdLength ||
transaction_id.size() == kStunLegacyTransactionIdLength;
}

View File

@ -31,7 +31,8 @@
namespace cricket {
// These are the types of STUN messages defined in RFC 5389.
enum StunMessageType {
enum StunMessageType : uint16_t {
STUN_INVALID_MESSAGE_TYPE = 0x0000,
STUN_BINDING_REQUEST = 0x0001,
STUN_BINDING_INDICATION = 0x0011,
STUN_BINDING_RESPONSE = 0x0101,
@ -144,7 +145,16 @@ class StunXorAddressAttribute;
// that attribute class.
class StunMessage {
public:
// Constructs a StunMessage with an invalid type and empty, legacy length
// (16 bytes, RFC3489) transaction id.
StunMessage();
// Construct a `StunMessage` with a specific type and generate a new
// 12 byte transaction id (RFC5389).
explicit StunMessage(uint16_t type);
StunMessage(uint16_t type, absl::string_view transaction_id);
virtual ~StunMessage();
// The verification status of the message. This is checked on parsing,
@ -169,7 +179,12 @@ class StunMessage {
bool IsLegacy() const;
void SetType(int type) { type_ = static_cast<uint16_t>(type); }
bool SetTransactionID(const std::string& str);
[[deprecated]] bool SetTransactionID(absl::string_view transaction_id) {
if (!IsValidTransactionId(transaction_id))
return false;
SetTransactionIdForTesting(transaction_id);
return true;
}
// Get a list of all of the attribute types in the "comprehension required"
// range that were not recognized.
@ -233,6 +248,9 @@ class StunMessage {
// Verifies that a given buffer is STUN by checking for a correct FINGERPRINT.
static bool ValidateFingerprint(const char* data, size_t size);
// Generates a new 12 byte (RFC5389) transaction id.
static std::string GenerateTransactionId();
// Adds a FINGERPRINT attribute that is valid for the current message.
bool AddFingerprint();
@ -249,7 +267,10 @@ class StunMessage {
// Modify the stun magic cookie used for this STUN message.
// This is used for testing.
void SetStunMagicCookie(uint32_t val);
[[deprecated]] void SetStunMagicCookie(uint32_t val);
// Change the internal transaction id. Used only for testing.
void SetTransactionIdForTesting(absl::string_view transaction_id);
// Contruct a copy of `this`.
std::unique_ptr<StunMessage> Clone() const;
@ -292,7 +313,7 @@ class StunMessage {
private:
StunAttribute* CreateAttribute(int type, size_t length) /* const*/;
const StunAttribute* GetAttribute(int type) const;
static bool IsValidTransactionId(const std::string& transaction_id);
static bool IsValidTransactionId(absl::string_view transaction_id);
bool AddMessageIntegrityOfType(int mi_attr_type,
size_t mi_attr_size,
const char* key,
@ -303,11 +324,11 @@ class StunMessage {
size_t size,
const std::string& password);
uint16_t type_;
uint16_t length_;
uint16_t type_ = STUN_INVALID_MESSAGE_TYPE;
uint16_t length_ = 0;
std::string transaction_id_;
uint32_t reduced_transaction_id_;
uint32_t stun_magic_cookie_;
uint32_t reduced_transaction_id_ = 0;
uint32_t stun_magic_cookie_ = kStunMagicCookie;
// The original buffer for messages created by Read().
std::string buffer_;
IntegrityStatus integrity_ = IntegrityStatus::kNotSet;
@ -635,13 +656,16 @@ enum RelayAttributeType {
// A "GTURN" STUN message.
class RelayMessage : public StunMessage {
public:
using StunMessage::StunMessage;
protected:
StunAttributeValueType GetAttributeValueType(int type) const override;
StunMessage* CreateNew() const override;
};
// Defined in TURN RFC 5766.
enum TurnMessageType {
enum TurnMessageType : uint16_t {
STUN_ALLOCATE_REQUEST = 0x0003,
STUN_ALLOCATE_RESPONSE = 0x0103,
STUN_ALLOCATE_ERROR_RESPONSE = 0x0113,
@ -689,6 +713,9 @@ extern const char STUN_ERROR_REASON_ALLOCATION_MISMATCH[];
extern const char STUN_ERROR_REASON_WRONG_CREDENTIALS[];
extern const char STUN_ERROR_REASON_UNSUPPORTED_PROTOCOL[];
class TurnMessage : public StunMessage {
public:
using StunMessage::StunMessage;
protected:
StunAttributeValueType GetAttributeValueType(int type) const override;
StunMessage* CreateNew() const override;
@ -747,6 +774,9 @@ extern const char STUN_ERROR_REASON_ROLE_CONFLICT[];
// A RFC 5245 ICE STUN message.
class IceMessage : public StunMessage {
public:
using StunMessage::StunMessage;
protected:
StunAttributeValueType GetAttributeValueType(int type) const override;
StunMessage* CreateNew() const override;

View File

@ -761,7 +761,6 @@ TEST_F(StunTest, ReadLegacyMessage) {
TEST_F(StunTest, SetIPv6XorAddressAttributeOwner) {
StunMessage msg;
StunMessage msg2;
size_t size = ReadStunMessage(&msg, kStunMessageWithIPv6XorMappedAddress);
rtc::IPAddress test_address(kIPv6TestAddress1);
@ -775,7 +774,7 @@ TEST_F(StunTest, SetIPv6XorAddressAttributeOwner) {
test_address);
// Owner with a different transaction ID.
msg2.SetTransactionID("ABCDABCDABCD");
StunMessage msg2(STUN_INVALID_MESSAGE_TYPE, "ABCDABCDABCD");
StunXorAddressAttribute addr2(STUN_ATTR_XOR_MAPPED_ADDRESS, 20, NULL);
addr2.SetIP(addr->ipaddr());
addr2.SetPort(addr->port());
@ -809,7 +808,6 @@ TEST_F(StunTest, SetIPv4XorAddressAttributeOwner) {
// should _not_ be affected by a change in owner. IPv4 XOR address uses the
// magic cookie value which is fixed.
StunMessage msg;
StunMessage msg2;
size_t size = ReadStunMessage(&msg, kStunMessageWithIPv4XorMappedAddress);
rtc::IPAddress test_address(kIPv4TestAddress1);
@ -823,7 +821,7 @@ TEST_F(StunTest, SetIPv4XorAddressAttributeOwner) {
test_address);
// Owner with a different transaction ID.
msg2.SetTransactionID("ABCDABCDABCD");
StunMessage msg2(STUN_INVALID_MESSAGE_TYPE, "ABCDABCDABCD");
StunXorAddressAttribute addr2(STUN_ATTR_XOR_MAPPED_ADDRESS, 20, NULL);
addr2.SetIP(addr->ipaddr());
addr2.SetPort(addr->port());
@ -893,13 +891,12 @@ TEST_F(StunTest, CreateAddressInArbitraryOrder) {
}
TEST_F(StunTest, WriteMessageWithIPv6AddressAttribute) {
StunMessage msg;
size_t size = sizeof(kStunMessageWithIPv6MappedAddress);
rtc::IPAddress test_ip(kIPv6TestAddress1);
msg.SetType(STUN_BINDING_REQUEST);
msg.SetTransactionID(
StunMessage msg(
STUN_BINDING_REQUEST,
std::string(reinterpret_cast<const char*>(kTestTransactionId1),
kStunTransactionIdLength));
CheckStunTransactionID(msg, kTestTransactionId1, kStunTransactionIdLength);
@ -922,13 +919,12 @@ TEST_F(StunTest, WriteMessageWithIPv6AddressAttribute) {
}
TEST_F(StunTest, WriteMessageWithIPv4AddressAttribute) {
StunMessage msg;
size_t size = sizeof(kStunMessageWithIPv4MappedAddress);
rtc::IPAddress test_ip(kIPv4TestAddress1);
msg.SetType(STUN_BINDING_RESPONSE);
msg.SetTransactionID(
StunMessage msg(
STUN_BINDING_RESPONSE,
std::string(reinterpret_cast<const char*>(kTestTransactionId1),
kStunTransactionIdLength));
CheckStunTransactionID(msg, kTestTransactionId1, kStunTransactionIdLength);
@ -951,13 +947,12 @@ TEST_F(StunTest, WriteMessageWithIPv4AddressAttribute) {
}
TEST_F(StunTest, WriteMessageWithIPv6XorAddressAttribute) {
StunMessage msg;
size_t size = sizeof(kStunMessageWithIPv6XorMappedAddress);
rtc::IPAddress test_ip(kIPv6TestAddress1);
msg.SetType(STUN_BINDING_RESPONSE);
msg.SetTransactionID(
StunMessage msg(
STUN_BINDING_RESPONSE,
std::string(reinterpret_cast<const char*>(kTestTransactionId2),
kStunTransactionIdLength));
CheckStunTransactionID(msg, kTestTransactionId2, kStunTransactionIdLength);
@ -981,13 +976,12 @@ TEST_F(StunTest, WriteMessageWithIPv6XorAddressAttribute) {
}
TEST_F(StunTest, WriteMessageWithIPv4XoreAddressAttribute) {
StunMessage msg;
size_t size = sizeof(kStunMessageWithIPv4XorMappedAddress);
rtc::IPAddress test_ip(kIPv4TestAddress1);
msg.SetType(STUN_BINDING_RESPONSE);
msg.SetTransactionID(
StunMessage msg(
STUN_BINDING_RESPONSE,
std::string(reinterpret_cast<const char*>(kTestTransactionId1),
kStunTransactionIdLength));
CheckStunTransactionID(msg, kTestTransactionId1, kStunTransactionIdLength);
@ -1083,11 +1077,10 @@ TEST_F(StunTest, ReadMessageWithAnUnknownAttribute) {
}
TEST_F(StunTest, WriteMessageWithAnErrorCodeAttribute) {
StunMessage msg;
size_t size = sizeof(kStunMessageWithErrorAttribute);
msg.SetType(STUN_BINDING_ERROR_RESPONSE);
msg.SetTransactionID(
StunMessage msg(
STUN_BINDING_ERROR_RESPONSE,
std::string(reinterpret_cast<const char*>(kTestTransactionId1),
kStunTransactionIdLength));
CheckStunTransactionID(msg, kTestTransactionId1, kStunTransactionIdLength);
@ -1105,11 +1098,10 @@ TEST_F(StunTest, WriteMessageWithAnErrorCodeAttribute) {
}
TEST_F(StunTest, WriteMessageWithAUInt16ListAttribute) {
StunMessage msg;
size_t size = sizeof(kStunMessageWithUInt16ListAttribute);
msg.SetType(STUN_BINDING_REQUEST);
msg.SetTransactionID(
StunMessage msg(
STUN_BINDING_REQUEST,
std::string(reinterpret_cast<const char*>(kTestTransactionId2),
kStunTransactionIdLength));
CheckStunTransactionID(msg, kTestTransactionId2, kStunTransactionIdLength);
@ -1475,7 +1467,7 @@ static const unsigned char kRelayMessage[] = {
// Test that we can read the GTURN-specific fields.
TEST_F(StunTest, ReadRelayMessage) {
RelayMessage msg, msg2;
RelayMessage msg;
const char* input = reinterpret_cast<const char*>(kRelayMessage);
size_t size = sizeof(kRelayMessage);
@ -1486,8 +1478,7 @@ TEST_F(StunTest, ReadRelayMessage) {
EXPECT_EQ(size - 20, msg.length());
EXPECT_EQ("0123456789ab", msg.transaction_id());
msg2.SetType(STUN_BINDING_REQUEST);
msg2.SetTransactionID("0123456789ab");
RelayMessage msg2(STUN_BINDING_REQUEST, "0123456789ab");
in_addr legacy_in_addr;
legacy_in_addr.s_addr = htonl(17U);
@ -1710,7 +1701,7 @@ TEST_F(StunTest, CopyAttribute) {
// Test Clone
TEST_F(StunTest, Clone) {
IceMessage msg;
IceMessage msg(0, "0123456789ab");
{
auto errorcode = StunAttribute::CreateErrorCode();
errorcode->SetCode(kTestErrorCode);
@ -1736,9 +1727,6 @@ TEST_F(StunTest, Clone) {
auto copy = msg.Clone();
ASSERT_NE(nullptr, copy.get());
msg.SetTransactionID("0123456789ab");
copy->SetTransactionID("0123456789ab");
rtc::ByteBufferWriter out1;
EXPECT_TRUE(msg.Write(&out1));
rtc::ByteBufferWriter out2;
@ -1812,21 +1800,18 @@ TEST_F(StunTest, EqualAttributes) {
}
TEST_F(StunTest, ReduceTransactionIdIsHostOrderIndependent) {
std::string transaction_id = "abcdefghijkl";
StunMessage message;
ASSERT_TRUE(message.SetTransactionID(transaction_id));
const std::string transaction_id = "abcdefghijkl";
StunMessage message(0, transaction_id);
uint32_t reduced_transaction_id = message.reduced_transaction_id();
EXPECT_EQ(reduced_transaction_id, 1835954016u);
}
TEST_F(StunTest, GoogMiscInfo) {
StunMessage msg;
StunMessage msg(STUN_BINDING_REQUEST, "ABCDEFGHIJKL");
const size_t size =
/* msg header */ 20 +
/* attr header */ 4 +
/* 3 * 2 rounded to multiple of 4 */ 8;
msg.SetType(STUN_BINDING_REQUEST);
msg.SetTransactionID("ABCDEFGH");
auto list =
StunAttribute::CreateUInt16ListAttribute(STUN_ATTR_GOOG_MISC_INFO);
list->AddTypeAtIndex(0, 0x1U);
@ -1861,9 +1846,7 @@ TEST_F(StunTest, IsStunMethod) {
}
TEST_F(StunTest, SizeRestrictionOnAttributes) {
StunMessage msg;
msg.SetType(STUN_BINDING_REQUEST);
msg.SetTransactionID("ABCDEFGH");
StunMessage msg(STUN_BINDING_REQUEST, "ABCDEFGHIJKL");
auto long_username = StunAttribute::CreateByteString(STUN_ATTR_USERNAME);
std::string long_string(509, 'x');
long_username->CopyBytes(long_string.c_str(), long_string.size());

View File

@ -179,12 +179,12 @@ class Connection::ConnectionRequest : public StunRequest {
// A ConnectionRequest is a STUN binding used to determine writability.
Connection::ConnectionRequest::ConnectionRequest(StunRequestManager& manager,
Connection* connection)
: StunRequest(manager, std::make_unique<IceMessage>()),
: StunRequest(manager, std::make_unique<IceMessage>(STUN_BINDING_REQUEST)),
connection_(connection) {}
void Connection::ConnectionRequest::Prepare(StunMessage* message) {
RTC_DCHECK_RUN_ON(connection_->network_thread_);
message->SetType(STUN_BINDING_REQUEST);
RTC_DCHECK_EQ(message->type(), STUN_BINDING_REQUEST);
std::string username;
connection_->port()->CreateStunUsername(
connection_->remote_candidate().username(), &username);
@ -728,9 +728,7 @@ void Connection::SendStunBindingResponse(const StunMessage* message) {
}
// Fill in the response.
StunMessage response;
response.SetType(STUN_BINDING_RESPONSE);
response.SetTransactionID(message->transaction_id());
StunMessage response(STUN_BINDING_RESPONSE, message->transaction_id());
const StunUInt32Attribute* retransmit_attr =
message->GetUInt32(STUN_ATTR_RETRANSMIT_COUNT);
if (retransmit_attr) {
@ -776,9 +774,7 @@ void Connection::SendGoogPingResponse(const StunMessage* message) {
RTC_DCHECK(message->type() == GOOG_PING_REQUEST);
// Fill in the response.
StunMessage response;
response.SetType(GOOG_PING_RESPONSE);
response.SetTransactionID(message->transaction_id());
StunMessage response(GOOG_PING_RESPONSE, message->transaction_id());
response.AddMessageIntegrity32(local_candidate().password());
SendResponseMessage(response);
}

View File

@ -3495,8 +3495,7 @@ class P2PTransportChannelPingTest : public ::testing::Test,
int priority,
uint32_t nomination,
const absl::optional<std::string>& piggyback_ping_id) {
IceMessage msg;
msg.SetType(STUN_BINDING_REQUEST);
IceMessage msg(STUN_BINDING_REQUEST);
msg.AddAttribute(std::make_unique<StunByteStringAttribute>(
STUN_ATTR_USERNAME,
conn->local_candidate().username() + ":" + remote_ufrag));
@ -3510,7 +3509,6 @@ class P2PTransportChannelPingTest : public ::testing::Test,
msg.AddAttribute(std::make_unique<StunByteStringAttribute>(
STUN_ATTR_GOOG_LAST_ICE_CHECK_RECEIVED, piggyback_ping_id.value()));
}
msg.SetTransactionID(rtc::CreateRandomString(kStunTransactionIdLength));
msg.AddMessageIntegrity(conn->local_candidate().password());
msg.AddFingerprint();
rtc::ByteBufferWriter buf;

View File

@ -754,13 +754,10 @@ void Port::SendBindingErrorResponse(StunMessage* message,
message->type() == GOOG_PING_REQUEST);
// Fill in the response message.
StunMessage response;
if (message->type() == STUN_BINDING_REQUEST) {
response.SetType(STUN_BINDING_ERROR_RESPONSE);
} else {
response.SetType(GOOG_PING_ERROR_RESPONSE);
}
response.SetTransactionID(message->transaction_id());
StunMessage response(message->type() == STUN_BINDING_REQUEST
? STUN_BINDING_ERROR_RESPONSE
: GOOG_PING_ERROR_RESPONSE,
message->transaction_id());
// When doing GICE, we need to write out the error code incorrectly to
// maintain backwards compatiblility.
@ -805,9 +802,7 @@ void Port::SendUnknownAttributesErrorResponse(
RTC_DCHECK(message->type() == STUN_BINDING_REQUEST);
// Fill in the response message.
StunMessage response;
response.SetType(STUN_BINDING_ERROR_RESPONSE);
response.SetTransactionID(message->transaction_id());
StunMessage response(STUN_BINDING_ERROR_RESPONSE, message->transaction_id());
auto error_attr = StunAttribute::CreateErrorCode();
error_attr->SetCode(STUN_ERROR_UNKNOWN_ATTRIBUTE);

View File

@ -756,14 +756,12 @@ class PortTest : public ::testing::Test, public sigslot::has_slots<> {
EXPECT_TRUE_WAIT(ch2.conn() == NULL, kDefaultTimeout);
}
std::unique_ptr<IceMessage> CreateStunMessage(int type) {
auto msg = std::make_unique<IceMessage>();
msg->SetType(type);
msg->SetTransactionID("TESTTESTTEST");
std::unique_ptr<IceMessage> CreateStunMessage(StunMessageType type) {
auto msg = std::make_unique<IceMessage>(type, "TESTTESTTEST");
return msg;
}
std::unique_ptr<IceMessage> CreateStunMessageWithUsername(
int type,
StunMessageType type,
const std::string& username) {
std::unique_ptr<IceMessage> msg = CreateStunMessage(type);
msg->AddAttribute(std::make_unique<StunByteStringAttribute>(
@ -2319,7 +2317,7 @@ TEST_F(PortTest, TestHandleStunMessageBadFingerprint) {
// Now, add a fingerprint, but munge the message so it's not valid.
in_msg->AddFingerprint();
in_msg->SetTransactionID("TESTTESTBADD");
in_msg->SetTransactionIdForTesting("TESTTESTBADD");
WriteStunMessage(*in_msg, buf.get());
EXPECT_FALSE(port->GetStunMessage(buf->Data(), buf->Length(), addr, &out_msg,
&username));
@ -2337,7 +2335,7 @@ TEST_F(PortTest, TestHandleStunMessageBadFingerprint) {
// Now, add a fingerprint, but munge the message so it's not valid.
in_msg->AddFingerprint();
in_msg->SetTransactionID("TESTTESTBADD");
in_msg->SetTransactionIdForTesting("TESTTESTBADD");
WriteStunMessage(*in_msg, buf.get());
EXPECT_FALSE(port->GetStunMessage(buf->Data(), buf->Length(), addr, &out_msg,
&username));
@ -2356,7 +2354,7 @@ TEST_F(PortTest, TestHandleStunMessageBadFingerprint) {
// Now, add a fingerprint, but munge the message so it's not valid.
in_msg->AddFingerprint();
in_msg->SetTransactionID("TESTTESTBADD");
in_msg->SetTransactionIdForTesting("TESTTESTBADD");
WriteStunMessage(*in_msg, buf.get());
EXPECT_FALSE(port->GetStunMessage(buf->Data(), buf->Length(), addr, &out_msg,
&username));
@ -3413,7 +3411,7 @@ TEST_F(PortTest, TestErrorResponseMakesGoogPingFallBackToStunBinding) {
// But rather than the RESPONSE...feedback an error.
StunMessage error_response;
error_response.SetType(GOOG_PING_ERROR_RESPONSE);
error_response.SetTransactionID(response2->transaction_id());
error_response.SetTransactionIdForTesting(response2->transaction_id());
error_response.AddMessageIntegrity32("rpass");
rtc::ByteBufferWriter buf;
error_response.Write(&buf);

View File

@ -193,12 +193,10 @@ void StunRequestManager::SendPacket(const void* data,
StunRequest::StunRequest(StunRequestManager& manager)
: manager_(manager),
msg_(new StunMessage()),
msg_(new StunMessage(STUN_INVALID_MESSAGE_TYPE)),
tstamp_(0),
count_(0),
timeout_(false) {
msg_->SetTransactionID(rtc::CreateRandomString(kStunTransactionIdLength));
}
timeout_(false) {}
StunRequest::StunRequest(StunRequestManager& manager,
std::unique_ptr<StunMessage> message)
@ -207,7 +205,7 @@ StunRequest::StunRequest(StunRequestManager& manager,
tstamp_(0),
count_(0),
timeout_(false) {
msg_->SetTransactionID(rtc::CreateRandomString(kStunTransactionIdLength));
RTC_DCHECK(!msg_->transaction_id().empty());
}
StunRequest::~StunRequest() {
@ -215,10 +213,15 @@ StunRequest::~StunRequest() {
}
void StunRequest::Construct() {
if (msg_->type() == 0) {
Prepare(msg_.get());
RTC_DCHECK(msg_->type() != 0);
}
// TODO(tommi): The implementation assumes that Construct() is only called
// once (see `StunRequestManager::SendDelayed` below). However, these steps to
// construct a request object are odd to have at this level (virtual method
// called from the parent class, _after_ construction) and also triggered
// from a different class... via a "Send" method.
// Remove `Prepare`, `Construct` and make construction of the message objects
// separate from the StunRequest and StunRequestManager classes.
Prepare(msg_.get());
RTC_DCHECK_NE(msg_->type(), 0);
}
int StunRequest::type() {

View File

@ -91,9 +91,6 @@ class StunRequest : public rtc::MessageHandler {
std::unique_ptr<StunMessage> message);
~StunRequest() override;
// Causes our wrapped StunMessage to be Prepared
void Construct();
// The manager handling this request (if it has been scheduled for sending).
StunRequestManager* manager() { return &manager_; }
@ -117,6 +114,11 @@ class StunRequest : public rtc::MessageHandler {
protected:
friend class StunRequestManager;
// Causes our wrapped StunMessage to be Prepared.
// Only called by StunRequestManager.
// TODO(tommi): get rid of this (see cc file).
void Construct();
// Fills in a request object to be sent. Note that request's transaction ID
// will already be set and cannot be changed.
virtual void Prepare(StunMessage* message) {}

View File

@ -15,6 +15,7 @@
#include "rtc_base/fake_clock.h"
#include "rtc_base/gunit.h"
#include "rtc_base/helpers.h"
#include "rtc_base/logging.h"
#include "rtc_base/time_utils.h"
#include "test/gtest.h"
@ -24,11 +25,8 @@ namespace {
std::unique_ptr<StunMessage> CreateStunMessage(
StunMessageType type,
const StunMessage* req = nullptr) {
std::unique_ptr<StunMessage> msg = std::make_unique<StunMessage>();
msg->SetType(type);
if (req) {
msg->SetTransactionID(req->transaction_id());
}
std::unique_ptr<StunMessage> msg = std::make_unique<StunMessage>(
type, req ? req->transaction_id() : StunMessage::GenerateTransactionId());
return msg;
}
@ -199,8 +197,7 @@ TEST_F(StunRequestTest, TestNoEmptyRequest) {
manager_.SendDelayed(request, 100);
StunMessage dummy_req;
dummy_req.SetTransactionID(request->id());
StunMessage dummy_req(0, request->id());
std::unique_ptr<StunMessage> res =
CreateStunMessage(STUN_BINDING_RESPONSE, &dummy_req);

View File

@ -53,7 +53,7 @@ void StunServer::OnPacket(rtc::AsyncPacketSocket* socket,
void StunServer::OnBindingRequest(StunMessage* msg,
const rtc::SocketAddress& remote_addr) {
StunMessage response;
StunMessage response(STUN_BINDING_RESPONSE, msg->transaction_id());
GetStunBindResponse(msg, remote_addr, &response);
SendResponse(response, remote_addr);
}
@ -62,9 +62,8 @@ void StunServer::SendErrorResponse(const StunMessage& msg,
const rtc::SocketAddress& addr,
int error_code,
const char* error_desc) {
StunMessage err_msg;
err_msg.SetType(GetStunErrorResponseType(msg.type()));
err_msg.SetTransactionID(msg.transaction_id());
StunMessage err_msg(GetStunErrorResponseType(msg.type()),
msg.transaction_id());
auto err_code = StunAttribute::CreateErrorCode();
err_code->SetCode(error_code);
@ -86,8 +85,8 @@ void StunServer::SendResponse(const StunMessage& msg,
void StunServer::GetStunBindResponse(StunMessage* message,
const rtc::SocketAddress& remote_addr,
StunMessage* response) const {
response->SetType(STUN_BINDING_RESPONSE);
response->SetTransactionID(message->transaction_id());
RTC_DCHECK_EQ(response->type(), STUN_BINDING_RESPONSE);
RTC_DCHECK_EQ(response->transaction_id(), message->transaction_id());
// Tell the user the address that we received their message from.
std::unique_ptr<StunAddressAttribute> mapped_addr;

View File

@ -76,11 +76,9 @@ class StunServerTest : public ::testing::Test {
#if !defined(THREAD_SANITIZER)
TEST_F(StunServerTest, TestGood) {
StunMessage req;
// kStunLegacyTransactionIdLength = 16 for legacy RFC 3489 request
std::string transaction_id = "0123456789abcdef";
req.SetType(STUN_BINDING_REQUEST);
req.SetTransactionID(transaction_id);
StunMessage req(STUN_BINDING_REQUEST, transaction_id);
Send(req);
StunMessage* msg = Receive();
@ -98,12 +96,10 @@ TEST_F(StunServerTest, TestGood) {
}
TEST_F(StunServerTest, TestGoodXorMappedAddr) {
StunMessage req;
// kStunTransactionIdLength = 12 for RFC 5389 request
// StunMessage::Write will automatically insert magic cookie (0x2112A442)
std::string transaction_id = "0123456789ab";
req.SetType(STUN_BINDING_REQUEST);
req.SetTransactionID(transaction_id);
StunMessage req(STUN_BINDING_REQUEST, transaction_id);
Send(req);
StunMessage* msg = Receive();
@ -122,11 +118,9 @@ TEST_F(StunServerTest, TestGoodXorMappedAddr) {
// Send legacy RFC 3489 request, should not get xor mapped addr
TEST_F(StunServerTest, TestNoXorMappedAddr) {
StunMessage req;
// kStunLegacyTransactionIdLength = 16 for legacy RFC 3489 request
std::string transaction_id = "0123456789abcdef";
req.SetType(STUN_BINDING_REQUEST);
req.SetTransactionID(transaction_id);
StunMessage req(STUN_BINDING_REQUEST, transaction_id);
Send(req);
StunMessage* msg = Receive();

View File

@ -28,7 +28,7 @@ void TestStunServer::OnBindingRequest(StunMessage* msg,
if (fake_stun_addr_.IsNil()) {
StunServer::OnBindingRequest(msg, remote_addr);
} else {
StunMessage response;
StunMessage response(STUN_BINDING_RESPONSE, msg->transaction_id());
GetStunBindResponse(msg, fake_stun_addr_, &response);
SendResponse(response, remote_addr);
}

View File

@ -1374,12 +1374,13 @@ void TurnPort::MaybeAddTurnLoggingId(StunMessage* msg) {
}
TurnAllocateRequest::TurnAllocateRequest(TurnPort* port)
: StunRequest(port->request_manager(), std::make_unique<TurnMessage>()),
: StunRequest(port->request_manager(),
std::make_unique<TurnMessage>(TURN_ALLOCATE_REQUEST)),
port_(port) {}
void TurnAllocateRequest::Prepare(StunMessage* message) {
// Create the request as indicated in RFC 5766, Section 6.1.
message->SetType(TURN_ALLOCATE_REQUEST);
RTC_DCHECK_EQ(message->type(), TURN_ALLOCATE_REQUEST);
auto transport_attr =
StunAttribute::CreateUInt32(STUN_ATTR_REQUESTED_TRANSPORT);
transport_attr->SetValue(IPPROTO_UDP << 24);
@ -1563,14 +1564,15 @@ void TurnAllocateRequest::OnTryAlternate(StunMessage* response, int code) {
}
TurnRefreshRequest::TurnRefreshRequest(TurnPort* port)
: StunRequest(port->request_manager(), std::make_unique<TurnMessage>()),
: StunRequest(port->request_manager(),
std::make_unique<TurnMessage>(TURN_REFRESH_REQUEST)),
port_(port),
lifetime_(-1) {}
void TurnRefreshRequest::Prepare(StunMessage* message) {
// Create the request as indicated in RFC 5766, Section 7.1.
// No attributes need to be included.
message->SetType(TURN_REFRESH_REQUEST);
RTC_DCHECK_EQ(message->type(), TURN_REFRESH_REQUEST);
if (lifetime_ > -1) {
message->AddAttribute(
std::make_unique<StunUInt32Attribute>(STUN_ATTR_LIFETIME, lifetime_));
@ -1646,7 +1648,9 @@ TurnCreatePermissionRequest::TurnCreatePermissionRequest(
TurnEntry* entry,
const rtc::SocketAddress& ext_addr,
const std::string& remote_ufrag)
: StunRequest(port->request_manager(), std::make_unique<TurnMessage>()),
: StunRequest(
port->request_manager(),
std::make_unique<TurnMessage>(TURN_CREATE_PERMISSION_REQUEST)),
port_(port),
entry_(entry),
ext_addr_(ext_addr),
@ -1719,7 +1723,8 @@ TurnChannelBindRequest::TurnChannelBindRequest(
TurnEntry* entry,
int channel_id,
const rtc::SocketAddress& ext_addr)
: StunRequest(port->request_manager(), std::make_unique<TurnMessage>()),
: StunRequest(port->request_manager(),
std::make_unique<TurnMessage>(TURN_CHANNEL_BIND_REQUEST)),
port_(port),
entry_(entry),
channel_id_(channel_id),
@ -1730,7 +1735,7 @@ TurnChannelBindRequest::TurnChannelBindRequest(
void TurnChannelBindRequest::Prepare(StunMessage* message) {
// Create the request as indicated in RFC5766, Section 11.1.
message->SetType(TURN_CHANNEL_BIND_REQUEST);
RTC_DCHECK_EQ(message->type(), TURN_CHANNEL_BIND_REQUEST);
message->AddAttribute(std::make_unique<StunUInt32Attribute>(
STUN_ATTR_CHANNEL_NUMBER, channel_id_ << 16));
message->AddAttribute(std::make_unique<StunXorAddressAttribute>(
@ -1824,9 +1829,7 @@ int TurnEntry::Send(const void* data,
!port_->TurnCustomizerAllowChannelData(data, size, payload)) {
// If we haven't bound the channel yet, we have to use a Send Indication.
// The turn_customizer_ can also make us use Send Indication.
TurnMessage msg;
msg.SetType(TURN_SEND_INDICATION);
msg.SetTransactionID(rtc::CreateRandomString(kStunTransactionIdLength));
TurnMessage msg(TURN_SEND_INDICATION);
msg.AddAttribute(std::make_unique<StunXorAddressAttribute>(
STUN_ATTR_XOR_PEER_ADDRESS, ext_addr_));
msg.AddAttribute(

View File

@ -100,27 +100,21 @@ class TurnServerAllocation::Channel : public rtc::MessageHandlerAutoCleanup {
rtc::SocketAddress peer_;
};
static bool InitResponse(const StunMessage* req, StunMessage* resp) {
int resp_type = (req) ? GetStunSuccessResponseType(req->type()) : -1;
if (resp_type == -1)
return false;
resp->SetType(resp_type);
resp->SetTransactionID(req->transaction_id());
return true;
int GetStunSuccessResponseTypeOrZero(const StunMessage& req) {
const int resp_type = GetStunSuccessResponseType(req.type());
return resp_type == -1 ? 0 : resp_type;
}
static bool InitErrorResponse(const StunMessage* req,
int code,
int GetStunErrorResponseTypeOrZero(const StunMessage& req) {
const int resp_type = GetStunErrorResponseType(req.type());
return resp_type == -1 ? 0 : resp_type;
}
static void InitErrorResponse(int code,
const std::string& reason,
StunMessage* resp) {
int resp_type = (req) ? GetStunErrorResponseType(req->type()) : -1;
if (resp_type == -1)
return false;
resp->SetType(resp_type);
resp->SetTransactionID(req->transaction_id());
resp->AddAttribute(std::make_unique<cricket::StunErrorCodeAttribute>(
STUN_ATTR_ERROR_CODE, code, reason));
return true;
}
TurnServer::TurnServer(rtc::Thread* thread)
@ -380,9 +374,8 @@ bool TurnServer::CheckAuthorization(TurnServerConnection* conn,
void TurnServer::HandleBindingRequest(TurnServerConnection* conn,
const StunMessage* req) {
StunMessage response;
InitResponse(req, &response);
StunMessage response(GetStunSuccessResponseTypeOrZero(*req),
req->transaction_id());
// Tell the user the address that we received their request from.
auto mapped_addr_attr = std::make_unique<StunXorAddressAttribute>(
STUN_ATTR_XOR_MAPPED_ADDRESS, conn->src());
@ -487,8 +480,9 @@ void TurnServer::SendErrorResponse(TurnServerConnection* conn,
int code,
const std::string& reason) {
RTC_DCHECK_RUN_ON(thread_);
TurnMessage resp;
InitErrorResponse(req, code, reason, &resp);
TurnMessage resp(GetStunErrorResponseTypeOrZero(*req), req->transaction_id());
InitErrorResponse(code, reason, &resp);
RTC_LOG(LS_INFO) << "Sending error response, type=" << resp.type()
<< ", code=" << code << ", reason=" << reason;
SendStun(conn, &resp);
@ -498,8 +492,8 @@ void TurnServer::SendErrorResponseWithRealmAndNonce(TurnServerConnection* conn,
const StunMessage* msg,
int code,
const std::string& reason) {
TurnMessage resp;
InitErrorResponse(msg, code, reason, &resp);
TurnMessage resp(GetStunErrorResponseTypeOrZero(*msg), msg->transaction_id());
InitErrorResponse(code, reason, &resp);
int64_t timestamp = rtc::TimeMillis();
if (ts_for_next_nonce_) {
@ -517,8 +511,8 @@ void TurnServer::SendErrorResponseWithAlternateServer(
TurnServerConnection* conn,
const StunMessage* msg,
const rtc::SocketAddress& addr) {
TurnMessage resp;
InitErrorResponse(msg, STUN_ERROR_TRY_ALTERNATE,
TurnMessage resp(GetStunErrorResponseTypeOrZero(*msg), msg->transaction_id());
InitErrorResponse(STUN_ERROR_TRY_ALTERNATE,
STUN_ERROR_REASON_TRY_ALTERNATE_SERVER, &resp);
resp.AddAttribute(
std::make_unique<StunAddressAttribute>(STUN_ATTR_ALTERNATE_SERVER, addr));
@ -671,7 +665,7 @@ void TurnServerAllocation::HandleAllocateRequest(const TurnMessage* msg) {
username_ = username_attr->GetString();
// Figure out the lifetime and start the allocation timer.
int lifetime_secs = ComputeLifetime(msg);
int lifetime_secs = ComputeLifetime(*msg);
thread_->PostDelayed(RTC_FROM_HERE, lifetime_secs * 1000, this,
MSG_ALLOCATION_TIMEOUT);
@ -679,8 +673,8 @@ void TurnServerAllocation::HandleAllocateRequest(const TurnMessage* msg) {
<< ": Created allocation with lifetime=" << lifetime_secs;
// We've already validated all the important bits; just send a response here.
TurnMessage response;
InitResponse(msg, &response);
TurnMessage response(GetStunSuccessResponseTypeOrZero(*msg),
msg->transaction_id());
auto mapped_addr_attr = std::make_unique<StunXorAddressAttribute>(
STUN_ATTR_XOR_MAPPED_ADDRESS, conn_.src());
@ -697,7 +691,7 @@ void TurnServerAllocation::HandleAllocateRequest(const TurnMessage* msg) {
void TurnServerAllocation::HandleRefreshRequest(const TurnMessage* msg) {
// Figure out the new lifetime.
int lifetime_secs = ComputeLifetime(msg);
int lifetime_secs = ComputeLifetime(*msg);
// Reset the expiration timer.
thread_->Clear(this, MSG_ALLOCATION_TIMEOUT);
@ -708,8 +702,8 @@ void TurnServerAllocation::HandleRefreshRequest(const TurnMessage* msg) {
<< ": Refreshed allocation, lifetime=" << lifetime_secs;
// Send a success response with a LIFETIME attribute.
TurnMessage response;
InitResponse(msg, &response);
TurnMessage response(GetStunSuccessResponseTypeOrZero(*msg),
msg->transaction_id());
auto lifetime_attr =
std::make_unique<StunUInt32Attribute>(STUN_ATTR_LIFETIME, lifetime_secs);
@ -763,8 +757,8 @@ void TurnServerAllocation::HandleCreatePermissionRequest(
<< peer_attr->GetAddress().ToSensitiveString();
// Send a success response.
TurnMessage response;
InitResponse(msg, &response);
TurnMessage response(GetStunSuccessResponseTypeOrZero(*msg),
msg->transaction_id());
SendResponse(&response);
}
@ -812,8 +806,8 @@ void TurnServerAllocation::HandleChannelBindRequest(const TurnMessage* msg) {
<< ", peer=" << peer_attr->GetAddress().ToSensitiveString();
// Send a success response.
TurnMessage response;
InitResponse(msg, &response);
TurnMessage response(GetStunSuccessResponseTypeOrZero(*msg),
msg->transaction_id());
SendResponse(&response);
}
@ -850,9 +844,7 @@ void TurnServerAllocation::OnExternalPacket(
} else if (!server_->enable_permission_checks_ ||
HasPermission(addr.ipaddr())) {
// No channel, but a permission exists. Send as a data indication.
TurnMessage msg;
msg.SetType(TURN_DATA_INDICATION);
msg.SetTransactionID(rtc::CreateRandomString(kStunTransactionIdLength));
TurnMessage msg(TURN_DATA_INDICATION);
msg.AddAttribute(std::make_unique<StunXorAddressAttribute>(
STUN_ATTR_XOR_PEER_ADDRESS, addr));
msg.AddAttribute(
@ -865,10 +857,10 @@ void TurnServerAllocation::OnExternalPacket(
}
}
int TurnServerAllocation::ComputeLifetime(const TurnMessage* msg) {
int TurnServerAllocation::ComputeLifetime(const TurnMessage& msg) {
// Return the smaller of our default lifetime and the requested lifetime.
int lifetime = kDefaultAllocationTimeout / 1000; // convert to seconds
const StunUInt32Attribute* lifetime_attr = msg->GetUInt32(STUN_ATTR_LIFETIME);
const StunUInt32Attribute* lifetime_attr = msg.GetUInt32(STUN_ATTR_LIFETIME);
if (lifetime_attr && static_cast<int>(lifetime_attr->value()) < lifetime) {
lifetime = static_cast<int>(lifetime_attr->value());
}

View File

@ -108,7 +108,7 @@ class TurnServerAllocation : public rtc::MessageHandlerAutoCleanup,
const rtc::SocketAddress& addr,
const int64_t& packet_time_us);
static int ComputeLifetime(const TurnMessage* msg);
static int ComputeLifetime(const TurnMessage& msg);
bool HasPermission(const rtc::IPAddress& addr);
void AddPermission(const rtc::IPAddress& addr);
Permission* FindPermission(const rtc::IPAddress& addr) const;

View File

@ -137,12 +137,8 @@ void StunProber::Requester::SendStunRequest() {
RTC_DCHECK(thread_checker_.IsCurrent());
requests_.push_back(new Request());
Request& request = *(requests_.back());
cricket::StunMessage message;
// Random transaction ID, STUN_BINDING_REQUEST
message.SetTransactionID(
rtc::CreateRandomString(cricket::kStunTransactionIdLength));
message.SetType(cricket::STUN_BINDING_REQUEST);
cricket::StunMessage message(cricket::STUN_BINDING_REQUEST);
std::unique_ptr<rtc::ByteBufferWriter> request_packet(
new rtc::ByteBufferWriter(nullptr, kMaxUdpBufferSize));