From 5d3bda58fd4bb9fc515491f848abbfeca8f27cba Mon Sep 17 00:00:00 2001 From: Victor Boivie Date: Mon, 12 Apr 2021 21:59:19 +0200 Subject: [PATCH] dcsctp: Add timer safeguards and sanity checks Ensuring that timer durations never go beyond a safe maximum duration and that timer IDs are not re-used. Bug: webrtc:12614 Change-Id: I227a2e9933da16669dc6ea0a39c570892010ba2c Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/215063 Commit-Queue: Victor Boivie Reviewed-by: Tommi Cr-Commit-Position: refs/heads/master@{#33860} --- net/dcsctp/timer/BUILD.gn | 1 + net/dcsctp/timer/timer.cc | 45 +++++++++++++++++++++++++--------- net/dcsctp/timer/timer.h | 34 +++++++++++++++++++------ net/dcsctp/timer/timer_test.cc | 36 +++++++++++++++++++++++++++ 4 files changed, 96 insertions(+), 20 deletions(-) diff --git a/net/dcsctp/timer/BUILD.gn b/net/dcsctp/timer/BUILD.gn index d92aca8f5a..8eec923a2b 100644 --- a/net/dcsctp/timer/BUILD.gn +++ b/net/dcsctp/timer/BUILD.gn @@ -14,6 +14,7 @@ rtc_library("timer") { "../../../rtc_base", "../../../rtc_base:checks", "../../../rtc_base:rtc_base_approved", + "../public:strong_alias", "../public:types", ] sources = [ diff --git a/net/dcsctp/timer/timer.cc b/net/dcsctp/timer/timer.cc index 2376e7aecb..f3c33ea971 100644 --- a/net/dcsctp/timer/timer.cc +++ b/net/dcsctp/timer/timer.cc @@ -9,7 +9,9 @@ */ #include "net/dcsctp/timer/timer.h" +#include #include +#include #include #include #include @@ -17,11 +19,12 @@ #include "absl/memory/memory.h" #include "absl/strings/string_view.h" #include "net/dcsctp/public/timeout.h" +#include "rtc_base/checks.h" namespace dcsctp { namespace { -TimeoutID MakeTimeoutId(uint32_t timer_id, uint32_t generation) { - return TimeoutID(static_cast(timer_id) << 32 | generation); +TimeoutID MakeTimeoutId(TimerID timer_id, TimerGeneration generation) { + return TimeoutID(static_cast(*timer_id) << 32 | *generation); } DurationMs GetBackoffDuration(TimerBackoffAlgorithm algorithm, @@ -30,13 +33,23 @@ DurationMs GetBackoffDuration(TimerBackoffAlgorithm algorithm, switch (algorithm) { case TimerBackoffAlgorithm::kFixed: return base_duration; - case TimerBackoffAlgorithm::kExponential: - return DurationMs(*base_duration * (1 << expiration_count)); + case TimerBackoffAlgorithm::kExponential: { + int32_t duration_ms = *base_duration; + + while (expiration_count > 0 && duration_ms < *Timer::kMaxTimerDuration) { + duration_ms *= 2; + --expiration_count; + } + + return DurationMs(std::min(duration_ms, *Timer::kMaxTimerDuration)); + } } } } // namespace -Timer::Timer(uint32_t id, +constexpr DurationMs Timer::kMaxTimerDuration; + +Timer::Timer(TimerID id, absl::string_view name, OnExpired on_expired, UnregisterHandler unregister_handler, @@ -59,11 +72,13 @@ void Timer::Start() { expiration_count_ = 0; if (!is_running()) { is_running_ = true; - timeout_->Start(duration_, MakeTimeoutId(id_, ++generation_)); + generation_ = TimerGeneration(*generation_ + 1); + timeout_->Start(duration_, MakeTimeoutId(id_, generation_)); } else { // Timer was running - stop and restart it, to make it expire in `duration_` // from now. - timeout_->Restart(duration_, MakeTimeoutId(id_, ++generation_)); + generation_ = TimerGeneration(*generation_ + 1); + timeout_->Restart(duration_, MakeTimeoutId(id_, generation_)); } } @@ -75,7 +90,7 @@ void Timer::Stop() { } } -void Timer::Trigger(uint32_t generation) { +void Timer::Trigger(TimerGeneration generation) { if (is_running_ && generation == generation_) { ++expiration_count_; if (options_.max_restarts >= 0 && @@ -92,14 +107,15 @@ void Timer::Trigger(uint32_t generation) { // Restart it with new duration. DurationMs duration = GetBackoffDuration(options_.backoff_algorithm, duration_, expiration_count_); - timeout_->Start(duration, MakeTimeoutId(id_, ++generation_)); + generation_ = TimerGeneration(*generation_ + 1); + timeout_->Start(duration, MakeTimeoutId(id_, generation_)); } } } void TimerManager::HandleTimeout(TimeoutID timeout_id) { - uint32_t timer_id = *timeout_id >> 32; - uint32_t generation = *timeout_id; + TimerID timer_id(*timeout_id >> 32); + TimerGeneration generation(*timeout_id); auto it = timers_.find(timer_id); if (it != timers_.end()) { it->second->Trigger(generation); @@ -109,7 +125,12 @@ void TimerManager::HandleTimeout(TimeoutID timeout_id) { std::unique_ptr TimerManager::CreateTimer(absl::string_view name, Timer::OnExpired on_expired, const TimerOptions& options) { - uint32_t id = ++next_id_; + next_id_ = TimerID(*next_id_ + 1); + TimerID id = next_id_; + // This would overflow after 4 billion timers created, which in SCTP would be + // after 800 million reconnections on a single socket. Ensure this will never + // happen. + RTC_CHECK_NE(*id, std::numeric_limits::max()); auto timer = absl::WrapUnique(new Timer( id, name, std::move(on_expired), [this, id]() { timers_.erase(id); }, create_timeout_(), options)); diff --git a/net/dcsctp/timer/timer.h b/net/dcsctp/timer/timer.h index 6b68c98374..bf923ea4ca 100644 --- a/net/dcsctp/timer/timer.h +++ b/net/dcsctp/timer/timer.h @@ -12,6 +12,7 @@ #include +#include #include #include #include @@ -20,10 +21,14 @@ #include "absl/strings/string_view.h" #include "absl/types/optional.h" +#include "net/dcsctp/public/strong_alias.h" #include "net/dcsctp/public/timeout.h" namespace dcsctp { +using TimerID = StrongAlias; +using TimerGeneration = StrongAlias; + enum class TimerBackoffAlgorithm { // The base duration will be used for any restart. kFixed, @@ -68,6 +73,9 @@ struct TimerOptions { // backoff algorithm). class Timer { public: + // The maximum timer duration - one day. + static constexpr DurationMs kMaxTimerDuration = DurationMs(24 * 3600 * 1000); + // When expired, the timer handler can optionally return a new duration which // will be set as `duration` and used as base duration when the timer is // restarted and as input to the backoff algorithm. @@ -89,7 +97,9 @@ class Timer { // Sets the base duration. The actual timer duration may be larger depending // on the backoff algorithm. - void set_duration(DurationMs duration) { duration_ = duration; } + void set_duration(DurationMs duration) { + duration_ = std::min(duration, kMaxTimerDuration); + } // Retrieves the base duration. The actual timer duration may be larger // depending on the backoff algorithm. @@ -110,7 +120,7 @@ class Timer { private: friend class TimerManager; using UnregisterHandler = std::function; - Timer(uint32_t id, + Timer(TimerID id, absl::string_view name, OnExpired on_expired, UnregisterHandler unregister, @@ -122,9 +132,9 @@ class Timer { // duration as decided by the backoff algorithm, unless the // `TimerOptions::max_restarts` has been reached and then it will be stopped // and `is_running()` will return false. - void Trigger(uint32_t generation); + void Trigger(TimerGeneration generation); - const uint32_t id_; + const TimerID id_; const std::string name_; const TimerOptions options_; const OnExpired on_expired_; @@ -133,8 +143,16 @@ class Timer { DurationMs duration_; - // Increased on each start, and is matched on Trigger, to avoid races. - uint32_t generation_ = 0; + // Increased on each start, and is matched on Trigger, to avoid races. And by + // race, meaning that a timeout - which may be evaluated/expired on a + // different thread while this thread has stopped that timer already. Note + // that the entire socket is not thread-safe, so `TimerManager::HandleTimeout` + // is never executed concurrently with any timer starting/stopping. + // + // This will wrap around after 4 billion timer restarts, and if it wraps + // around, it would just trigger _this_ timer in advance (but it's hard to + // restart it 4 billion times within its duration). + TimerGeneration generation_ = TimerGeneration(0); bool is_running_ = false; // Incremented each time time has expired and reset when stopped or restarted. int expiration_count_ = 0; @@ -158,8 +176,8 @@ class TimerManager { private: const std::function()> create_timeout_; - std::unordered_map timers_; - uint32_t next_id_ = 0; + std::unordered_map timers_; + TimerID next_id_ = TimerID(0); }; } // namespace dcsctp diff --git a/net/dcsctp/timer/timer_test.cc b/net/dcsctp/timer/timer_test.cc index 9533234895..719d73e891 100644 --- a/net/dcsctp/timer/timer_test.cc +++ b/net/dcsctp/timer/timer_test.cc @@ -310,5 +310,41 @@ TEST_F(TimerTest, ReturningNewDurationWhenExpired) { AdvanceTimeAndRunTimers(DurationMs(1000)); } +TEST_F(TimerTest, TimersHaveMaximumDuration) { + std::unique_ptr t1 = manager_.CreateTimer( + "t1", on_expired_.AsStdFunction(), + TimerOptions(DurationMs(1000), TimerBackoffAlgorithm::kExponential)); + + t1->set_duration(DurationMs(2 * *Timer::kMaxTimerDuration)); + EXPECT_EQ(t1->duration(), Timer::kMaxTimerDuration); +} + +TEST_F(TimerTest, TimersHaveMaximumBackoffDuration) { + std::unique_ptr t1 = manager_.CreateTimer( + "t1", on_expired_.AsStdFunction(), + TimerOptions(DurationMs(1000), TimerBackoffAlgorithm::kExponential)); + + t1->Start(); + + int max_exponent = static_cast(log2(*Timer::kMaxTimerDuration / 1000)); + for (int i = 0; i < max_exponent; ++i) { + EXPECT_CALL(on_expired_, Call).Times(1); + AdvanceTimeAndRunTimers(DurationMs(1000 * (1 << i))); + } + + // Reached the maximum duration. + EXPECT_CALL(on_expired_, Call).Times(1); + AdvanceTimeAndRunTimers(Timer::kMaxTimerDuration); + + EXPECT_CALL(on_expired_, Call).Times(1); + AdvanceTimeAndRunTimers(Timer::kMaxTimerDuration); + + EXPECT_CALL(on_expired_, Call).Times(1); + AdvanceTimeAndRunTimers(Timer::kMaxTimerDuration); + + EXPECT_CALL(on_expired_, Call).Times(1); + AdvanceTimeAndRunTimers(Timer::kMaxTimerDuration); +} + } // namespace } // namespace dcsctp