diff --git a/net/dcsctp/socket/dcsctp_socket_test.cc b/net/dcsctp/socket/dcsctp_socket_test.cc index c542a9cd1a..33af2a15de 100644 --- a/net/dcsctp/socket/dcsctp_socket_test.cc +++ b/net/dcsctp/socket/dcsctp_socket_test.cc @@ -208,13 +208,19 @@ class DcSctpSocketTest : public testing::Test { } while (delivered_packet); } + void RunTimers(MockDcSctpSocketCallbacks& cb, DcSctpSocket& socket) { + for (;;) { + absl::optional timeout_id = cb.GetNextExpiredTimeout(); + if (!timeout_id.has_value()) { + break; + } + socket.HandleTimeout(*timeout_id); + } + } + void RunTimers() { - for (const auto timeout_id : cb_a_.RunTimers()) { - sock_a_.HandleTimeout(timeout_id); - } - for (const auto timeout_id : cb_z_.RunTimers()) { - sock_z_.HandleTimeout(timeout_id); - } + RunTimers(cb_a_, sock_a_); + RunTimers(cb_z_, sock_z_); } const DcSctpOptions options_; @@ -1025,9 +1031,7 @@ TEST_F(DcSctpSocketTest, PassingHighWatermarkWillOnlyAcceptCumAckTsn) { // The receiver might have moved into delayed ack mode. cb_z2.AdvanceTime(options.rto_initial); - for (const auto timeout_id : cb_z2.RunTimers()) { - sock_z2.HandleTimeout(timeout_id); - } + RunTimers(cb_z2, sock_z2); EXPECT_THAT( cb_z2.ConsumeSentPacket(), @@ -1066,9 +1070,7 @@ TEST_F(DcSctpSocketTest, PassingHighWatermarkWillOnlyAcceptCumAckTsn) { // The receiver might have moved into delayed ack mode. cb_z2.AdvanceTime(options.rto_initial); - for (const auto timeout_id : cb_z2.RunTimers()) { - sock_z2.HandleTimeout(timeout_id); - } + RunTimers(cb_z2, sock_z2); EXPECT_THAT( cb_z2.ConsumeSentPacket(), diff --git a/net/dcsctp/socket/heartbeat_handler_test.cc b/net/dcsctp/socket/heartbeat_handler_test.cc index 58dbcff4b2..20c1d465db 100644 --- a/net/dcsctp/socket/heartbeat_handler_test.cc +++ b/net/dcsctp/socket/heartbeat_handler_test.cc @@ -45,6 +45,17 @@ class HeartbeatHandlerTest : public testing::Test { timer_manager_([this]() { return callbacks_.CreateTimeout(); }), handler_("log: ", options_, &context_, &timer_manager_) {} + void AdvanceTime(DurationMs duration) { + callbacks_.AdvanceTime(duration); + for (;;) { + absl::optional timeout_id = callbacks_.GetNextExpiredTimeout(); + if (!timeout_id.has_value()) { + break; + } + timer_manager_.HandleTimeout(*timeout_id); + } + } + const DcSctpOptions options_; NiceMock callbacks_; NiceMock context_; @@ -75,10 +86,7 @@ TEST_F(HeartbeatHandlerTest, RepliesToHeartbeatRequests) { } TEST_F(HeartbeatHandlerTest, SendsHeartbeatRequestsOnIdleChannel) { - callbacks_.AdvanceTime(options_.heartbeat_interval); - for (TimeoutID id : callbacks_.RunTimers()) { - timer_manager_.HandleTimeout(id); - } + AdvanceTime(options_.heartbeat_interval); // Grab the request, and make a response. std::vector payload = callbacks_.ConsumeSentPacket(); @@ -101,22 +109,15 @@ TEST_F(HeartbeatHandlerTest, SendsHeartbeatRequestsOnIdleChannel) { } TEST_F(HeartbeatHandlerTest, IncreasesErrorIfNotAckedInTime) { - callbacks_.AdvanceTime(options_.heartbeat_interval); - DurationMs rto(105); EXPECT_CALL(context_, current_rto).WillOnce(Return(rto)); - for (TimeoutID id : callbacks_.RunTimers()) { - timer_manager_.HandleTimeout(id); - } + AdvanceTime(options_.heartbeat_interval); // Validate that a request was sent. EXPECT_THAT(callbacks_.ConsumeSentPacket(), Not(IsEmpty())); EXPECT_CALL(context_, IncrementTxErrorCounter).Times(1); - callbacks_.AdvanceTime(rto); - for (TimeoutID id : callbacks_.RunTimers()) { - timer_manager_.HandleTimeout(id); - } + AdvanceTime(rto); } } // namespace diff --git a/net/dcsctp/socket/mock_dcsctp_socket_callbacks.h b/net/dcsctp/socket/mock_dcsctp_socket_callbacks.h index bad1aa697d..289da7a4d1 100644 --- a/net/dcsctp/socket/mock_dcsctp_socket_callbacks.h +++ b/net/dcsctp/socket/mock_dcsctp_socket_callbacks.h @@ -134,7 +134,9 @@ class MockDcSctpSocketCallbacks : public DcSctpSocketCallbacks { void AdvanceTime(DurationMs duration_ms) { now_ = now_ + duration_ms; } void SetTime(TimeMs now) { now_ = now; } - std::vector RunTimers() { return timeout_manager_.RunTimers(); } + absl::optional GetNextExpiredTimeout() { + return timeout_manager_.GetNextExpiredTimeout(); + } private: TimeMs now_ = TimeMs(0); diff --git a/net/dcsctp/socket/stream_reset_handler_test.cc b/net/dcsctp/socket/stream_reset_handler_test.cc index 4f9b7434d6..6168f16312 100644 --- a/net/dcsctp/socket/stream_reset_handler_test.cc +++ b/net/dcsctp/socket/stream_reset_handler_test.cc @@ -119,8 +119,12 @@ class StreamResetHandlerTest : public testing::Test { void AdvanceTime(DurationMs duration) { callbacks_.AdvanceTime(kRto); - for (TimeoutID timeout_id : callbacks_.RunTimers()) { - timer_manager_.HandleTimeout(timeout_id); + for (;;) { + absl::optional timeout_id = callbacks_.GetNextExpiredTimeout(); + if (!timeout_id.has_value()) { + break; + } + timer_manager_.HandleTimeout(*timeout_id); } } diff --git a/net/dcsctp/timer/fake_timeout.h b/net/dcsctp/timer/fake_timeout.h index 265b34edfa..ada4754a27 100644 --- a/net/dcsctp/timer/fake_timeout.h +++ b/net/dcsctp/timer/fake_timeout.h @@ -18,6 +18,7 @@ #include #include +#include "absl/types/optional.h" #include "net/dcsctp/public/timeout.h" namespace dcsctp { @@ -73,15 +74,20 @@ class FakeTimeoutManager { return timer; } - std::vector RunTimers() { + // NOTE: This can't return a vector, as calling EvaluateHasExpired requires + // calling socket->HandleTimeout directly afterwards, as the owning Timer + // still believes it's running, and it needs to be updated to set + // Timer::is_running_ to false before you operate on the Timer or Timeout + // again. + absl::optional GetNextExpiredTimeout() { TimeMs now = get_time_(); std::vector expired_timers; for (auto& timer : timers_) { if (timer->EvaluateHasExpired(now)) { - expired_timers.push_back(timer->timeout_id()); + return timer->timeout_id(); } } - return expired_timers; + return absl::nullopt; } private: diff --git a/net/dcsctp/timer/timer_test.cc b/net/dcsctp/timer/timer_test.cc index 719d73e891..82b92ef395 100644 --- a/net/dcsctp/timer/timer_test.cc +++ b/net/dcsctp/timer/timer_test.cc @@ -32,8 +32,13 @@ class TimerTest : public testing::Test { void AdvanceTimeAndRunTimers(DurationMs duration) { now_ = now_ + duration; - for (TimeoutID timeout_id : timeout_manager_.RunTimers()) { - manager_.HandleTimeout(timeout_id); + for (;;) { + absl::optional timeout_id = + timeout_manager_.GetNextExpiredTimeout(); + if (!timeout_id.has_value()) { + break; + } + manager_.HandleTimeout(*timeout_id); } }