From 78aa5cd35958e7db8dbe0fa1fa8d41223e93dbab Mon Sep 17 00:00:00 2001 From: Victor Boivie Date: Tue, 13 Apr 2021 23:42:39 +0200 Subject: [PATCH] dcsctp: Ensure packet size doesn't exceed MTU Due to a previous refactoring, the SCTP packet header is only added when the first chunk is written. This wasn't reflected in the `bytes_remaining`, which made it add more than could fit within the MTU. Additionally, the maximum packet size must be even divisible by four as padding will be added to chunks that are not even divisble by four (up to three bytes of padding). So compensate for that. Bug: webrtc:12614 Change-Id: I6b57dfbf88d1fcfcbf443038915dd180e796191a Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/215145 Reviewed-by: Tommi Reviewed-by: Florent Castelli Commit-Queue: Victor Boivie Cr-Commit-Position: refs/heads/master@{#33760} --- net/dcsctp/common/math.h | 14 +++- net/dcsctp/common/math_test.cc | 106 +++++++++++++++++++++++--- net/dcsctp/packet/sctp_packet.cc | 26 ++++++- net/dcsctp/packet/sctp_packet.h | 11 +-- net/dcsctp/packet/sctp_packet_test.cc | 40 ++++++++++ 5 files changed, 178 insertions(+), 19 deletions(-) diff --git a/net/dcsctp/common/math.h b/net/dcsctp/common/math.h index ee161d2c8a..12f690ed57 100644 --- a/net/dcsctp/common/math.h +++ b/net/dcsctp/common/math.h @@ -16,7 +16,19 @@ namespace dcsctp { // used to e.g. pad chunks or parameters to an even 32-bit offset. template IntType RoundUpTo4(IntType val) { - return (val + 3) & -4; + return (val + 3) & ~3; +} + +// Similarly, rounds down `val` to the nearest value that is divisible by four. +template +IntType RoundDownTo4(IntType val) { + return val & ~3; +} + +// Returns true if `val` is divisible by four. +template +bool IsDivisibleBy4(IntType val) { + return (val & 3) == 0; } } // namespace dcsctp diff --git a/net/dcsctp/common/math_test.cc b/net/dcsctp/common/math_test.cc index 902aefa906..f95dfbdb55 100644 --- a/net/dcsctp/common/math_test.cc +++ b/net/dcsctp/common/math_test.cc @@ -15,17 +15,101 @@ namespace dcsctp { namespace { TEST(MathUtilTest, CanRoundUpTo4) { - EXPECT_EQ(RoundUpTo4(0), 0); - EXPECT_EQ(RoundUpTo4(1), 4); - EXPECT_EQ(RoundUpTo4(2), 4); - EXPECT_EQ(RoundUpTo4(3), 4); - EXPECT_EQ(RoundUpTo4(4), 4); - EXPECT_EQ(RoundUpTo4(5), 8); - EXPECT_EQ(RoundUpTo4(6), 8); - EXPECT_EQ(RoundUpTo4(7), 8); - EXPECT_EQ(RoundUpTo4(8), 8); - EXPECT_EQ(RoundUpTo4(10000000000), 10000000000); - EXPECT_EQ(RoundUpTo4(10000000001), 10000000004); + // Signed numbers + EXPECT_EQ(RoundUpTo4(static_cast(-5)), -4); + EXPECT_EQ(RoundUpTo4(static_cast(-4)), -4); + EXPECT_EQ(RoundUpTo4(static_cast(-3)), 0); + EXPECT_EQ(RoundUpTo4(static_cast(-2)), 0); + EXPECT_EQ(RoundUpTo4(static_cast(-1)), 0); + EXPECT_EQ(RoundUpTo4(static_cast(0)), 0); + EXPECT_EQ(RoundUpTo4(static_cast(1)), 4); + EXPECT_EQ(RoundUpTo4(static_cast(2)), 4); + EXPECT_EQ(RoundUpTo4(static_cast(3)), 4); + EXPECT_EQ(RoundUpTo4(static_cast(4)), 4); + EXPECT_EQ(RoundUpTo4(static_cast(5)), 8); + EXPECT_EQ(RoundUpTo4(static_cast(6)), 8); + EXPECT_EQ(RoundUpTo4(static_cast(7)), 8); + EXPECT_EQ(RoundUpTo4(static_cast(8)), 8); + EXPECT_EQ(RoundUpTo4(static_cast(10000000000)), 10000000000); + EXPECT_EQ(RoundUpTo4(static_cast(10000000001)), 10000000004); + + // Unsigned numbers + EXPECT_EQ(RoundUpTo4(static_cast(0)), 0u); + EXPECT_EQ(RoundUpTo4(static_cast(1)), 4u); + EXPECT_EQ(RoundUpTo4(static_cast(2)), 4u); + EXPECT_EQ(RoundUpTo4(static_cast(3)), 4u); + EXPECT_EQ(RoundUpTo4(static_cast(4)), 4u); + EXPECT_EQ(RoundUpTo4(static_cast(5)), 8u); + EXPECT_EQ(RoundUpTo4(static_cast(6)), 8u); + EXPECT_EQ(RoundUpTo4(static_cast(7)), 8u); + EXPECT_EQ(RoundUpTo4(static_cast(8)), 8u); + EXPECT_EQ(RoundUpTo4(static_cast(10000000000)), 10000000000u); + EXPECT_EQ(RoundUpTo4(static_cast(10000000001)), 10000000004u); +} + +TEST(MathUtilTest, CanRoundDownTo4) { + // Signed numbers + EXPECT_EQ(RoundDownTo4(static_cast(-5)), -8); + EXPECT_EQ(RoundDownTo4(static_cast(-4)), -4); + EXPECT_EQ(RoundDownTo4(static_cast(-3)), -4); + EXPECT_EQ(RoundDownTo4(static_cast(-2)), -4); + EXPECT_EQ(RoundDownTo4(static_cast(-1)), -4); + EXPECT_EQ(RoundDownTo4(static_cast(0)), 0); + EXPECT_EQ(RoundDownTo4(static_cast(1)), 0); + EXPECT_EQ(RoundDownTo4(static_cast(2)), 0); + EXPECT_EQ(RoundDownTo4(static_cast(3)), 0); + EXPECT_EQ(RoundDownTo4(static_cast(4)), 4); + EXPECT_EQ(RoundDownTo4(static_cast(5)), 4); + EXPECT_EQ(RoundDownTo4(static_cast(6)), 4); + EXPECT_EQ(RoundDownTo4(static_cast(7)), 4); + EXPECT_EQ(RoundDownTo4(static_cast(8)), 8); + EXPECT_EQ(RoundDownTo4(static_cast(10000000000)), 10000000000); + EXPECT_EQ(RoundDownTo4(static_cast(10000000001)), 10000000000); + + // Unsigned numbers + EXPECT_EQ(RoundDownTo4(static_cast(0)), 0u); + EXPECT_EQ(RoundDownTo4(static_cast(1)), 0u); + EXPECT_EQ(RoundDownTo4(static_cast(2)), 0u); + EXPECT_EQ(RoundDownTo4(static_cast(3)), 0u); + EXPECT_EQ(RoundDownTo4(static_cast(4)), 4u); + EXPECT_EQ(RoundDownTo4(static_cast(5)), 4u); + EXPECT_EQ(RoundDownTo4(static_cast(6)), 4u); + EXPECT_EQ(RoundDownTo4(static_cast(7)), 4u); + EXPECT_EQ(RoundDownTo4(static_cast(8)), 8u); + EXPECT_EQ(RoundDownTo4(static_cast(10000000000)), 10000000000u); + EXPECT_EQ(RoundDownTo4(static_cast(10000000001)), 10000000000u); +} + +TEST(MathUtilTest, IsDivisibleBy4) { + // Signed numbers + EXPECT_EQ(IsDivisibleBy4(static_cast(-4)), true); + EXPECT_EQ(IsDivisibleBy4(static_cast(-3)), false); + EXPECT_EQ(IsDivisibleBy4(static_cast(-2)), false); + EXPECT_EQ(IsDivisibleBy4(static_cast(-1)), false); + EXPECT_EQ(IsDivisibleBy4(static_cast(0)), true); + EXPECT_EQ(IsDivisibleBy4(static_cast(1)), false); + EXPECT_EQ(IsDivisibleBy4(static_cast(2)), false); + EXPECT_EQ(IsDivisibleBy4(static_cast(3)), false); + EXPECT_EQ(IsDivisibleBy4(static_cast(4)), true); + EXPECT_EQ(IsDivisibleBy4(static_cast(5)), false); + EXPECT_EQ(IsDivisibleBy4(static_cast(6)), false); + EXPECT_EQ(IsDivisibleBy4(static_cast(7)), false); + EXPECT_EQ(IsDivisibleBy4(static_cast(8)), true); + EXPECT_EQ(IsDivisibleBy4(static_cast(10000000000)), true); + EXPECT_EQ(IsDivisibleBy4(static_cast(10000000001)), false); + + // Unsigned numbers + EXPECT_EQ(IsDivisibleBy4(static_cast(0)), true); + EXPECT_EQ(IsDivisibleBy4(static_cast(1)), false); + EXPECT_EQ(IsDivisibleBy4(static_cast(2)), false); + EXPECT_EQ(IsDivisibleBy4(static_cast(3)), false); + EXPECT_EQ(IsDivisibleBy4(static_cast(4)), true); + EXPECT_EQ(IsDivisibleBy4(static_cast(5)), false); + EXPECT_EQ(IsDivisibleBy4(static_cast(6)), false); + EXPECT_EQ(IsDivisibleBy4(static_cast(7)), false); + EXPECT_EQ(IsDivisibleBy4(static_cast(8)), true); + EXPECT_EQ(IsDivisibleBy4(static_cast(10000000000)), true); + EXPECT_EQ(IsDivisibleBy4(static_cast(10000000001)), false); } } // namespace diff --git a/net/dcsctp/packet/sctp_packet.cc b/net/dcsctp/packet/sctp_packet.cc index 1e12367263..da06ccf867 100644 --- a/net/dcsctp/packet/sctp_packet.cc +++ b/net/dcsctp/packet/sctp_packet.cc @@ -52,11 +52,11 @@ SctpPacket::Builder::Builder(VerificationTag verification_tag, : verification_tag_(verification_tag), source_port_(options.local_port), dest_port_(options.remote_port), - max_mtu_(options.mtu) {} + max_packet_size_(RoundDownTo4(options.mtu)) {} SctpPacket::Builder& SctpPacket::Builder::Add(const Chunk& chunk) { if (out_.empty()) { - out_.reserve(max_mtu_); + out_.reserve(max_packet_size_); out_.resize(SctpPacket::kHeaderSize); BoundedByteWriter buffer(out_); buffer.Store16<0>(source_port_); @@ -64,14 +64,31 @@ SctpPacket::Builder& SctpPacket::Builder::Add(const Chunk& chunk) { buffer.Store32<4>(*verification_tag_); // Checksum is at offset 8 - written when calling Build(); } + RTC_DCHECK(IsDivisibleBy4(out_.size())); + chunk.SerializeTo(out_); if (out_.size() % 4 != 0) { out_.resize(RoundUpTo4(out_.size())); } + RTC_DCHECK(out_.size() <= max_packet_size_) + << "Exceeded max size, data=" << out_.size() + << ", max_size=" << max_packet_size_; return *this; } +size_t SctpPacket::Builder::bytes_remaining() const { + if (out_.empty()) { + // The packet header (CommonHeader) hasn't been written yet: + return max_packet_size_ - kHeaderSize; + } else if (out_.size() > max_packet_size_) { + RTC_DCHECK(false) << "Exceeded max size, data=" << out_.size() + << ", max_size=" << max_packet_size_; + return 0; + } + return max_packet_size_ - out_.size(); +} + std::vector SctpPacket::Builder::Build() { std::vector out; out_.swap(out); @@ -80,6 +97,11 @@ std::vector SctpPacket::Builder::Build() { uint32_t crc = GenerateCrc32C(out); BoundedByteWriter(out).Store32<8>(crc); } + + RTC_DCHECK(out.size() <= max_packet_size_) + << "Exceeded max size, data=" << out.size() + << ", max_size=" << max_packet_size_; + return out; } diff --git a/net/dcsctp/packet/sctp_packet.h b/net/dcsctp/packet/sctp_packet.h index 927b8dbd41..2600caf7a9 100644 --- a/net/dcsctp/packet/sctp_packet.h +++ b/net/dcsctp/packet/sctp_packet.h @@ -65,10 +65,9 @@ class SctpPacket { // Adds a chunk to the to-be-built SCTP packet. Builder& Add(const Chunk& chunk); - // The number of bytes remaining in the packet, until the MTU is reached. - size_t bytes_remaining() const { - return out_.size() >= max_mtu_ ? 0 : max_mtu_ - out_.size(); - } + // The number of bytes remaining in the packet for chunk storage until the + // packet reaches its maximum size. + size_t bytes_remaining() const; // Indicates if any packets have been added to the builder. bool empty() const { return out_.empty(); } @@ -82,7 +81,9 @@ class SctpPacket { VerificationTag verification_tag_; uint16_t source_port_; uint16_t dest_port_; - size_t max_mtu_; + // The maximum packet size is always even divisible by four, as chunks are + // always padded to a size even divisible by four. + size_t max_packet_size_; std::vector out_; }; diff --git a/net/dcsctp/packet/sctp_packet_test.cc b/net/dcsctp/packet/sctp_packet_test.cc index ece1b7bbd7..7438315eec 100644 --- a/net/dcsctp/packet/sctp_packet_test.cc +++ b/net/dcsctp/packet/sctp_packet_test.cc @@ -15,6 +15,7 @@ #include "api/array_view.h" #include "net/dcsctp/common/internal_types.h" +#include "net/dcsctp/common/math.h" #include "net/dcsctp/packet/chunk/abort_chunk.h" #include "net/dcsctp/packet/chunk/cookie_ack_chunk.h" #include "net/dcsctp/packet/chunk/data_chunk.h" @@ -24,6 +25,7 @@ #include "net/dcsctp/packet/error_cause/user_initiated_abort_cause.h" #include "net/dcsctp/packet/parameter/parameter.h" #include "net/dcsctp/packet/tlv_trait.h" +#include "net/dcsctp/public/dcsctp_options.h" #include "net/dcsctp/testing/testing_macros.h" #include "rtc_base/gunit.h" #include "test/gmock.h" @@ -298,5 +300,43 @@ TEST(SctpPacketTest, DetectPacketWithZeroSizeChunk) { EXPECT_FALSE(SctpPacket::Parse(data, true).has_value()); } + +TEST(SctpPacketTest, ReturnsCorrectSpaceAvailableToStayWithinMTU) { + DcSctpOptions options; + options.mtu = 1191; + + SctpPacket::Builder builder(VerificationTag(123), options); + + // Chunks will be padded to an even 4 bytes, so the maximum packet size should + // be rounded down. + const size_t kMaxPacketSize = RoundDownTo4(options.mtu); + EXPECT_EQ(kMaxPacketSize, 1188u); + + const size_t kSctpHeaderSize = 12; + EXPECT_EQ(builder.bytes_remaining(), kMaxPacketSize - kSctpHeaderSize); + EXPECT_EQ(builder.bytes_remaining(), 1176u); + + // Add a smaller packet first. + DataChunk::Options data_options; + + std::vector payload1(183); + builder.Add( + DataChunk(TSN(1), StreamID(1), SSN(0), PPID(53), payload1, data_options)); + + size_t chunk1_size = RoundUpTo4(DataChunk::kHeaderSize + payload1.size()); + EXPECT_EQ(builder.bytes_remaining(), + kMaxPacketSize - kSctpHeaderSize - chunk1_size); + EXPECT_EQ(builder.bytes_remaining(), 976u); // Hand-calculated. + + std::vector payload2(957); + builder.Add( + DataChunk(TSN(1), StreamID(1), SSN(0), PPID(53), payload2, data_options)); + + size_t chunk2_size = RoundUpTo4(DataChunk::kHeaderSize + payload2.size()); + EXPECT_EQ(builder.bytes_remaining(), + kMaxPacketSize - kSctpHeaderSize - chunk1_size - chunk2_size); + EXPECT_EQ(builder.bytes_remaining(), 0u); // Hand-calculated. +} + } // namespace } // namespace dcsctp