diff --git a/net/dcsctp/common/math.h b/net/dcsctp/common/math.h index ee161d2c8a..12f690ed57 100644 --- a/net/dcsctp/common/math.h +++ b/net/dcsctp/common/math.h @@ -16,7 +16,19 @@ namespace dcsctp { // used to e.g. pad chunks or parameters to an even 32-bit offset. template IntType RoundUpTo4(IntType val) { - return (val + 3) & -4; + return (val + 3) & ~3; +} + +// Similarly, rounds down `val` to the nearest value that is divisible by four. +template +IntType RoundDownTo4(IntType val) { + return val & ~3; +} + +// Returns true if `val` is divisible by four. +template +bool IsDivisibleBy4(IntType val) { + return (val & 3) == 0; } } // namespace dcsctp diff --git a/net/dcsctp/common/math_test.cc b/net/dcsctp/common/math_test.cc index 902aefa906..f95dfbdb55 100644 --- a/net/dcsctp/common/math_test.cc +++ b/net/dcsctp/common/math_test.cc @@ -15,17 +15,101 @@ namespace dcsctp { namespace { TEST(MathUtilTest, CanRoundUpTo4) { - EXPECT_EQ(RoundUpTo4(0), 0); - EXPECT_EQ(RoundUpTo4(1), 4); - EXPECT_EQ(RoundUpTo4(2), 4); - EXPECT_EQ(RoundUpTo4(3), 4); - EXPECT_EQ(RoundUpTo4(4), 4); - EXPECT_EQ(RoundUpTo4(5), 8); - EXPECT_EQ(RoundUpTo4(6), 8); - EXPECT_EQ(RoundUpTo4(7), 8); - EXPECT_EQ(RoundUpTo4(8), 8); - EXPECT_EQ(RoundUpTo4(10000000000), 10000000000); - EXPECT_EQ(RoundUpTo4(10000000001), 10000000004); + // Signed numbers + EXPECT_EQ(RoundUpTo4(static_cast(-5)), -4); + EXPECT_EQ(RoundUpTo4(static_cast(-4)), -4); + EXPECT_EQ(RoundUpTo4(static_cast(-3)), 0); + EXPECT_EQ(RoundUpTo4(static_cast(-2)), 0); + EXPECT_EQ(RoundUpTo4(static_cast(-1)), 0); + EXPECT_EQ(RoundUpTo4(static_cast(0)), 0); + EXPECT_EQ(RoundUpTo4(static_cast(1)), 4); + EXPECT_EQ(RoundUpTo4(static_cast(2)), 4); + EXPECT_EQ(RoundUpTo4(static_cast(3)), 4); + EXPECT_EQ(RoundUpTo4(static_cast(4)), 4); + EXPECT_EQ(RoundUpTo4(static_cast(5)), 8); + EXPECT_EQ(RoundUpTo4(static_cast(6)), 8); + EXPECT_EQ(RoundUpTo4(static_cast(7)), 8); + EXPECT_EQ(RoundUpTo4(static_cast(8)), 8); + EXPECT_EQ(RoundUpTo4(static_cast(10000000000)), 10000000000); + EXPECT_EQ(RoundUpTo4(static_cast(10000000001)), 10000000004); + + // Unsigned numbers + EXPECT_EQ(RoundUpTo4(static_cast(0)), 0u); + EXPECT_EQ(RoundUpTo4(static_cast(1)), 4u); + EXPECT_EQ(RoundUpTo4(static_cast(2)), 4u); + EXPECT_EQ(RoundUpTo4(static_cast(3)), 4u); + EXPECT_EQ(RoundUpTo4(static_cast(4)), 4u); + EXPECT_EQ(RoundUpTo4(static_cast(5)), 8u); + EXPECT_EQ(RoundUpTo4(static_cast(6)), 8u); + EXPECT_EQ(RoundUpTo4(static_cast(7)), 8u); + EXPECT_EQ(RoundUpTo4(static_cast(8)), 8u); + EXPECT_EQ(RoundUpTo4(static_cast(10000000000)), 10000000000u); + EXPECT_EQ(RoundUpTo4(static_cast(10000000001)), 10000000004u); +} + +TEST(MathUtilTest, CanRoundDownTo4) { + // Signed numbers + EXPECT_EQ(RoundDownTo4(static_cast(-5)), -8); + EXPECT_EQ(RoundDownTo4(static_cast(-4)), -4); + EXPECT_EQ(RoundDownTo4(static_cast(-3)), -4); + EXPECT_EQ(RoundDownTo4(static_cast(-2)), -4); + EXPECT_EQ(RoundDownTo4(static_cast(-1)), -4); + EXPECT_EQ(RoundDownTo4(static_cast(0)), 0); + EXPECT_EQ(RoundDownTo4(static_cast(1)), 0); + EXPECT_EQ(RoundDownTo4(static_cast(2)), 0); + EXPECT_EQ(RoundDownTo4(static_cast(3)), 0); + EXPECT_EQ(RoundDownTo4(static_cast(4)), 4); + EXPECT_EQ(RoundDownTo4(static_cast(5)), 4); + EXPECT_EQ(RoundDownTo4(static_cast(6)), 4); + EXPECT_EQ(RoundDownTo4(static_cast(7)), 4); + EXPECT_EQ(RoundDownTo4(static_cast(8)), 8); + EXPECT_EQ(RoundDownTo4(static_cast(10000000000)), 10000000000); + EXPECT_EQ(RoundDownTo4(static_cast(10000000001)), 10000000000); + + // Unsigned numbers + EXPECT_EQ(RoundDownTo4(static_cast(0)), 0u); + EXPECT_EQ(RoundDownTo4(static_cast(1)), 0u); + EXPECT_EQ(RoundDownTo4(static_cast(2)), 0u); + EXPECT_EQ(RoundDownTo4(static_cast(3)), 0u); + EXPECT_EQ(RoundDownTo4(static_cast(4)), 4u); + EXPECT_EQ(RoundDownTo4(static_cast(5)), 4u); + EXPECT_EQ(RoundDownTo4(static_cast(6)), 4u); + EXPECT_EQ(RoundDownTo4(static_cast(7)), 4u); + EXPECT_EQ(RoundDownTo4(static_cast(8)), 8u); + EXPECT_EQ(RoundDownTo4(static_cast(10000000000)), 10000000000u); + EXPECT_EQ(RoundDownTo4(static_cast(10000000001)), 10000000000u); +} + +TEST(MathUtilTest, IsDivisibleBy4) { + // Signed numbers + EXPECT_EQ(IsDivisibleBy4(static_cast(-4)), true); + EXPECT_EQ(IsDivisibleBy4(static_cast(-3)), false); + EXPECT_EQ(IsDivisibleBy4(static_cast(-2)), false); + EXPECT_EQ(IsDivisibleBy4(static_cast(-1)), false); + EXPECT_EQ(IsDivisibleBy4(static_cast(0)), true); + EXPECT_EQ(IsDivisibleBy4(static_cast(1)), false); + EXPECT_EQ(IsDivisibleBy4(static_cast(2)), false); + EXPECT_EQ(IsDivisibleBy4(static_cast(3)), false); + EXPECT_EQ(IsDivisibleBy4(static_cast(4)), true); + EXPECT_EQ(IsDivisibleBy4(static_cast(5)), false); + EXPECT_EQ(IsDivisibleBy4(static_cast(6)), false); + EXPECT_EQ(IsDivisibleBy4(static_cast(7)), false); + EXPECT_EQ(IsDivisibleBy4(static_cast(8)), true); + EXPECT_EQ(IsDivisibleBy4(static_cast(10000000000)), true); + EXPECT_EQ(IsDivisibleBy4(static_cast(10000000001)), false); + + // Unsigned numbers + EXPECT_EQ(IsDivisibleBy4(static_cast(0)), true); + EXPECT_EQ(IsDivisibleBy4(static_cast(1)), false); + EXPECT_EQ(IsDivisibleBy4(static_cast(2)), false); + EXPECT_EQ(IsDivisibleBy4(static_cast(3)), false); + EXPECT_EQ(IsDivisibleBy4(static_cast(4)), true); + EXPECT_EQ(IsDivisibleBy4(static_cast(5)), false); + EXPECT_EQ(IsDivisibleBy4(static_cast(6)), false); + EXPECT_EQ(IsDivisibleBy4(static_cast(7)), false); + EXPECT_EQ(IsDivisibleBy4(static_cast(8)), true); + EXPECT_EQ(IsDivisibleBy4(static_cast(10000000000)), true); + EXPECT_EQ(IsDivisibleBy4(static_cast(10000000001)), false); } } // namespace diff --git a/net/dcsctp/packet/sctp_packet.cc b/net/dcsctp/packet/sctp_packet.cc index 1e12367263..da06ccf867 100644 --- a/net/dcsctp/packet/sctp_packet.cc +++ b/net/dcsctp/packet/sctp_packet.cc @@ -52,11 +52,11 @@ SctpPacket::Builder::Builder(VerificationTag verification_tag, : verification_tag_(verification_tag), source_port_(options.local_port), dest_port_(options.remote_port), - max_mtu_(options.mtu) {} + max_packet_size_(RoundDownTo4(options.mtu)) {} SctpPacket::Builder& SctpPacket::Builder::Add(const Chunk& chunk) { if (out_.empty()) { - out_.reserve(max_mtu_); + out_.reserve(max_packet_size_); out_.resize(SctpPacket::kHeaderSize); BoundedByteWriter buffer(out_); buffer.Store16<0>(source_port_); @@ -64,14 +64,31 @@ SctpPacket::Builder& SctpPacket::Builder::Add(const Chunk& chunk) { buffer.Store32<4>(*verification_tag_); // Checksum is at offset 8 - written when calling Build(); } + RTC_DCHECK(IsDivisibleBy4(out_.size())); + chunk.SerializeTo(out_); if (out_.size() % 4 != 0) { out_.resize(RoundUpTo4(out_.size())); } + RTC_DCHECK(out_.size() <= max_packet_size_) + << "Exceeded max size, data=" << out_.size() + << ", max_size=" << max_packet_size_; return *this; } +size_t SctpPacket::Builder::bytes_remaining() const { + if (out_.empty()) { + // The packet header (CommonHeader) hasn't been written yet: + return max_packet_size_ - kHeaderSize; + } else if (out_.size() > max_packet_size_) { + RTC_DCHECK(false) << "Exceeded max size, data=" << out_.size() + << ", max_size=" << max_packet_size_; + return 0; + } + return max_packet_size_ - out_.size(); +} + std::vector SctpPacket::Builder::Build() { std::vector out; out_.swap(out); @@ -80,6 +97,11 @@ std::vector SctpPacket::Builder::Build() { uint32_t crc = GenerateCrc32C(out); BoundedByteWriter(out).Store32<8>(crc); } + + RTC_DCHECK(out.size() <= max_packet_size_) + << "Exceeded max size, data=" << out.size() + << ", max_size=" << max_packet_size_; + return out; } diff --git a/net/dcsctp/packet/sctp_packet.h b/net/dcsctp/packet/sctp_packet.h index 927b8dbd41..2600caf7a9 100644 --- a/net/dcsctp/packet/sctp_packet.h +++ b/net/dcsctp/packet/sctp_packet.h @@ -65,10 +65,9 @@ class SctpPacket { // Adds a chunk to the to-be-built SCTP packet. Builder& Add(const Chunk& chunk); - // The number of bytes remaining in the packet, until the MTU is reached. - size_t bytes_remaining() const { - return out_.size() >= max_mtu_ ? 0 : max_mtu_ - out_.size(); - } + // The number of bytes remaining in the packet for chunk storage until the + // packet reaches its maximum size. + size_t bytes_remaining() const; // Indicates if any packets have been added to the builder. bool empty() const { return out_.empty(); } @@ -82,7 +81,9 @@ class SctpPacket { VerificationTag verification_tag_; uint16_t source_port_; uint16_t dest_port_; - size_t max_mtu_; + // The maximum packet size is always even divisible by four, as chunks are + // always padded to a size even divisible by four. + size_t max_packet_size_; std::vector out_; }; diff --git a/net/dcsctp/packet/sctp_packet_test.cc b/net/dcsctp/packet/sctp_packet_test.cc index ece1b7bbd7..7438315eec 100644 --- a/net/dcsctp/packet/sctp_packet_test.cc +++ b/net/dcsctp/packet/sctp_packet_test.cc @@ -15,6 +15,7 @@ #include "api/array_view.h" #include "net/dcsctp/common/internal_types.h" +#include "net/dcsctp/common/math.h" #include "net/dcsctp/packet/chunk/abort_chunk.h" #include "net/dcsctp/packet/chunk/cookie_ack_chunk.h" #include "net/dcsctp/packet/chunk/data_chunk.h" @@ -24,6 +25,7 @@ #include "net/dcsctp/packet/error_cause/user_initiated_abort_cause.h" #include "net/dcsctp/packet/parameter/parameter.h" #include "net/dcsctp/packet/tlv_trait.h" +#include "net/dcsctp/public/dcsctp_options.h" #include "net/dcsctp/testing/testing_macros.h" #include "rtc_base/gunit.h" #include "test/gmock.h" @@ -298,5 +300,43 @@ TEST(SctpPacketTest, DetectPacketWithZeroSizeChunk) { EXPECT_FALSE(SctpPacket::Parse(data, true).has_value()); } + +TEST(SctpPacketTest, ReturnsCorrectSpaceAvailableToStayWithinMTU) { + DcSctpOptions options; + options.mtu = 1191; + + SctpPacket::Builder builder(VerificationTag(123), options); + + // Chunks will be padded to an even 4 bytes, so the maximum packet size should + // be rounded down. + const size_t kMaxPacketSize = RoundDownTo4(options.mtu); + EXPECT_EQ(kMaxPacketSize, 1188u); + + const size_t kSctpHeaderSize = 12; + EXPECT_EQ(builder.bytes_remaining(), kMaxPacketSize - kSctpHeaderSize); + EXPECT_EQ(builder.bytes_remaining(), 1176u); + + // Add a smaller packet first. + DataChunk::Options data_options; + + std::vector payload1(183); + builder.Add( + DataChunk(TSN(1), StreamID(1), SSN(0), PPID(53), payload1, data_options)); + + size_t chunk1_size = RoundUpTo4(DataChunk::kHeaderSize + payload1.size()); + EXPECT_EQ(builder.bytes_remaining(), + kMaxPacketSize - kSctpHeaderSize - chunk1_size); + EXPECT_EQ(builder.bytes_remaining(), 976u); // Hand-calculated. + + std::vector payload2(957); + builder.Add( + DataChunk(TSN(1), StreamID(1), SSN(0), PPID(53), payload2, data_options)); + + size_t chunk2_size = RoundUpTo4(DataChunk::kHeaderSize + payload2.size()); + EXPECT_EQ(builder.bytes_remaining(), + kMaxPacketSize - kSctpHeaderSize - chunk1_size - chunk2_size); + EXPECT_EQ(builder.bytes_remaining(), 0u); // Hand-calculated. +} + } // namespace } // namespace dcsctp