diff --git a/net/dcsctp/timer/fake_timeout.h b/net/dcsctp/timer/fake_timeout.h index ada4754a27..927e6b2808 100644 --- a/net/dcsctp/timer/fake_timeout.h +++ b/net/dcsctp/timer/fake_timeout.h @@ -20,6 +20,7 @@ #include "absl/types/optional.h" #include "net/dcsctp/public/timeout.h" +#include "rtc_base/checks.h" namespace dcsctp { @@ -33,14 +34,18 @@ class FakeTimeout : public Timeout { ~FakeTimeout() override { on_delete_(this); } void Start(DurationMs duration_ms, TimeoutID timeout_id) override { + RTC_DCHECK(expiry_ == TimeMs::InfiniteFuture()); timeout_id_ = timeout_id; expiry_ = get_time_() + duration_ms; } - void Stop() override { expiry_ = InfiniteFuture(); } + void Stop() override { + RTC_DCHECK(expiry_ != TimeMs::InfiniteFuture()); + expiry_ = TimeMs::InfiniteFuture(); + } bool EvaluateHasExpired(TimeMs now) { if (now >= expiry_) { - expiry_ = InfiniteFuture(); + expiry_ = TimeMs::InfiniteFuture(); return true; } return false; @@ -49,15 +54,11 @@ class FakeTimeout : public Timeout { TimeoutID timeout_id() const { return timeout_id_; } private: - static constexpr TimeMs InfiniteFuture() { - return TimeMs(std::numeric_limits::max()); - } - const std::function get_time_; const std::function on_delete_; TimeoutID timeout_id_ = TimeoutID(0); - TimeMs expiry_ = InfiniteFuture(); + TimeMs expiry_ = TimeMs::InfiniteFuture(); }; class FakeTimeoutManager { diff --git a/net/dcsctp/timer/timer.cc b/net/dcsctp/timer/timer.cc index f3c33ea971..593d639fa7 100644 --- a/net/dcsctp/timer/timer.cc +++ b/net/dcsctp/timer/timer.cc @@ -93,23 +93,32 @@ void Timer::Stop() { void Timer::Trigger(TimerGeneration generation) { if (is_running_ && generation == generation_) { ++expiration_count_; - if (options_.max_restarts >= 0 && - expiration_count_ > options_.max_restarts) { - is_running_ = false; - } - - absl::optional new_duration = on_expired_(); - if (new_duration.has_value()) { - duration_ = new_duration.value(); - } - - if (is_running_) { - // Restart it with new duration. + is_running_ = false; + if (options_.max_restarts < 0 || + expiration_count_ <= options_.max_restarts) { + // The timer should still be running after this triggers. Start a new + // timer. Note that it might be very quickly restarted again, if the + // `on_expired_` callback returns a new duration. + is_running_ = true; DurationMs duration = GetBackoffDuration(options_.backoff_algorithm, duration_, expiration_count_); generation_ = TimerGeneration(*generation_ + 1); timeout_->Start(duration, MakeTimeoutId(id_, generation_)); } + + absl::optional new_duration = on_expired_(); + if (new_duration.has_value() && new_duration != duration_) { + duration_ = new_duration.value(); + if (is_running_) { + // Restart it with new duration. + timeout_->Stop(); + + DurationMs duration = GetBackoffDuration(options_.backoff_algorithm, + duration_, expiration_count_); + generation_ = TimerGeneration(*generation_ + 1); + timeout_->Start(duration, MakeTimeoutId(id_, generation_)); + } + } } } diff --git a/net/dcsctp/timer/timer_test.cc b/net/dcsctp/timer/timer_test.cc index 82b92ef395..a403bb6b4b 100644 --- a/net/dcsctp/timer/timer_test.cc +++ b/net/dcsctp/timer/timer_test.cc @@ -351,5 +351,40 @@ TEST_F(TimerTest, TimersHaveMaximumBackoffDuration) { AdvanceTimeAndRunTimers(Timer::kMaxTimerDuration); } +TEST_F(TimerTest, TimerCanBeStartedFromWithinExpirationHandler) { + std::unique_ptr t1 = manager_.CreateTimer( + "t1", on_expired_.AsStdFunction(), + TimerOptions(DurationMs(1000), TimerBackoffAlgorithm::kFixed)); + + t1->Start(); + + // Start a timer, but don't return any new duration in callback. + EXPECT_CALL(on_expired_, Call).WillOnce([&]() { + EXPECT_TRUE(t1->is_running()); + t1->set_duration(DurationMs(5000)); + t1->Start(); + return absl::nullopt; + }); + AdvanceTimeAndRunTimers(DurationMs(1000)); + + EXPECT_CALL(on_expired_, Call).Times(0); + AdvanceTimeAndRunTimers(DurationMs(4999)); + + // Start a timer, and return any new duration in callback. + EXPECT_CALL(on_expired_, Call).WillOnce([&]() { + EXPECT_TRUE(t1->is_running()); + t1->set_duration(DurationMs(5000)); + t1->Start(); + return absl::make_optional(DurationMs(8000)); + }); + AdvanceTimeAndRunTimers(DurationMs(1)); + + EXPECT_CALL(on_expired_, Call).Times(0); + AdvanceTimeAndRunTimers(DurationMs(7999)); + + EXPECT_CALL(on_expired_, Call).Times(1); + AdvanceTimeAndRunTimers(DurationMs(1)); +} + } // namespace } // namespace dcsctp