feat: add health tracker to ratelimit middleware

This commit is contained in:
deivid.garcia.garcia 2025-09-05 09:13:15 +02:00
parent 9b42b5b930
commit c76deea8a9
5 changed files with 485 additions and 1 deletions

View file

@ -571,6 +571,18 @@ type RateLimit struct {
// Redis stores the configuration for using Redis as a bucket in the rate-limiting algorithm.
// If not specified, Traefik will default to an in-memory bucket for the algorithm.
Redis *Redis `json:"redis,omitempty" toml:"redis,omitempty" yaml:"redis,omitempty" export:"true"`
// UnhealthyLimiterBackoffTimeout is the duration for which the rate limiter will be disabled
// after detecting unhealthy conditions. Defaults to 30 seconds.
UnhealthyLimiterBackoffTimeout *ptypes.Duration `json:"unhealthyLimiterBackoffTimeout,omitempty" toml:"unhealthyLimiterBackoffTimeout,omitempty" yaml:"unhealthyLimiterBackoffTimeout,omitempty" export:"true"`
// UnhealthyLimiterBackoffDuration is the time window during which failures are counted
// to determine if the limiter should be shut down. Defaults to 10 seconds.
UnhealthyLimiterBackoffDuration *ptypes.Duration `json:"unhealthyLimiterBackoffDuration,omitempty" toml:"unhealthyLimiterBackoffDuration,omitempty" yaml:"unhealthyLimiterBackoffDuration,omitempty" export:"true"`
// UnhealthyLimiterBackoffThreshold is the number of failures within the backoff duration
// that will trigger the limiter to be shut down. Defaults to 5.
UnhealthyLimiterBackoffThreshold *int `json:"unhealthyLimiterBackoffThreshold,omitempty" toml:"unhealthyLimiterBackoffThreshold,omitempty" yaml:"unhealthyLimiterBackoffThreshold,omitempty" export:"true"`
}
// SetDefaults sets the default values on a RateLimit.

View file

@ -0,0 +1,101 @@
package ratelimiter
import (
"sync"
"time"
"github.com/rs/zerolog"
)
// healthTracker counts rate limiter failures inside a time window and, once
// enough of them pile up, marks the limiter as shut down for a fixed period.
type healthTracker struct {
	// Configuration, assigned once by newHealthTracker.
	backoffTimeout   time.Duration // how long a shutdown lasts
	backoffDuration  time.Duration // length of the failure-counting window
	backoffThreshold int           // failures within the window that trigger a shutdown
	logger           *zerolog.Logger

	// Mutable state; every write happens with mu held.
	mu               sync.RWMutex
	isShutdown       bool
	shutdownUntil    time.Time // end of the current shutdown period
	failureCount     int       // failures seen in the current window
	lastFailureReset time.Time // start of the current window
}
// newHealthTracker builds a tracker in the healthy state (not shut down, zero
// failures) that uses the supplied timeout, counting window, threshold and
// logger.
func newHealthTracker(backoffTimeout, backoffDuration time.Duration, backoffThreshold int, logger *zerolog.Logger) *healthTracker {
	ht := new(healthTracker)
	ht.backoffTimeout = backoffTimeout
	ht.backoffDuration = backoffDuration
	ht.backoffThreshold = backoffThreshold
	ht.logger = logger
	return ht
}
// recordFailure registers one limiter failure and reports whether that failure
// put (or kept) the failure counter at or above the shutdown threshold.
//
// Failures are counted in a window of backoffDuration: if more than
// backoffDuration has elapsed since the window started, the counter restarts
// from zero before this failure is added. When the counter reaches
// backoffThreshold the tracker is marked as shut down until now+backoffTimeout
// and true is returned; a later failure that still satisfies the threshold
// re-arms that deadline. A negative threshold never shuts down; a threshold of
// zero shuts down on the first failure.
func (ht *healthTracker) recordFailure() bool {
	ht.mu.Lock()
	defer ht.mu.Unlock()

	now := time.Now()

	// Start a new counting window once the previous one has elapsed.
	if now.Sub(ht.lastFailureReset) > ht.backoffDuration {
		ht.failureCount = 0
		ht.lastFailureReset = now
	}
	ht.failureCount++

	if ht.backoffThreshold < 0 || ht.failureCount < ht.backoffThreshold {
		return false
	}

	ht.isShutdown = true
	ht.shutdownUntil = now.Add(ht.backoffTimeout)

	// Log the actual deadline under "shutdownUntil"; the configured timeout is
	// reported separately so the field name matches its value.
	ht.logger.Warn().
		Int("failureCount", ht.failureCount).
		Dur("backoffTimeout", ht.backoffTimeout).
		Time("shutdownUntil", ht.shutdownUntil).
		Msg("Rate limiter shut down due to repeated failures")

	return true
}
// isShutdownNow reports whether the limiter is currently shut down, and
// performs the recovery transition once the shutdown deadline has passed.
//
// The state is read under the read lock: isShutdown and shutdownUntil are
// written by recordFailure under mu, and shutdownUntil is a multi-word
// time.Time, so an unsynchronized read would be a data race.
//
// When the deadline has expired, the write lock is taken and the condition is
// evaluated again, because another goroutine may have recovered the tracker or
// extended the deadline in between. On recovery the failure counter and the
// counting window are reset.
func (ht *healthTracker) isShutdownNow() bool {
	ht.mu.RLock()
	shutdown, until := ht.isShutdown, ht.shutdownUntil
	ht.mu.RUnlock()

	if !shutdown {
		return false
	}
	if !time.Now().After(until) {
		return true
	}

	// The deadline looks expired: upgrade to the write lock and re-check.
	ht.mu.Lock()
	defer ht.mu.Unlock()

	if !ht.isShutdown {
		// Another goroutine already performed the recovery.
		return false
	}

	now := time.Now()
	if !now.After(ht.shutdownUntil) {
		// A concurrent failure extended the shutdown period.
		return true
	}

	ht.isShutdown = false
	ht.failureCount = 0
	ht.lastFailureReset = now
	ht.logger.Info().Msg("Rate limiter recovered from shutdown")

	return false
}
// getStatus returns a snapshot of the shutdown flag, the failure counter and
// the shutdown deadline, taken under the read lock. Intended for tests.
func (ht *healthTracker) getStatus() (isShutdown bool, failureCount int, shutdownUntil time.Time) {
	ht.mu.RLock()
	isShutdown, failureCount, shutdownUntil = ht.isShutdown, ht.failureCount, ht.shutdownUntil
	ht.mu.RUnlock()

	return isShutdown, failureCount, shutdownUntil
}
// getThreshold reads the configured failure threshold under the read lock.
// Intended for tests.
func (ht *healthTracker) getThreshold() int {
	ht.mu.RLock()
	threshold := ht.backoffThreshold
	ht.mu.RUnlock()

	return threshold
}

View file

@ -0,0 +1,210 @@
package ratelimiter
import (
"testing"
"time"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestNewHealthTracker verifies that the constructor stores its arguments
// unchanged and starts healthy with no recorded failures.
func TestNewHealthTracker(t *testing.T) {
	const (
		timeout   = 30 * time.Second
		window    = 10 * time.Second
		threshold = 5
	)
	nop := zerolog.Nop()

	ht := newHealthTracker(timeout, window, threshold, &nop)

	assert.Equal(t, timeout, ht.backoffTimeout)
	assert.Equal(t, window, ht.backoffDuration)
	assert.Equal(t, threshold, ht.backoffThreshold)
	assert.Equal(t, &nop, ht.logger)
	assert.False(t, ht.isShutdown)
	assert.Equal(t, 0, ht.failureCount)
}
// TestRecordFailure_UnderThreshold records two failures against a threshold of
// three and expects the tracker to stay healthy while counting each one.
func TestRecordFailure_UnderThreshold(t *testing.T) {
	nop := zerolog.Nop()
	ht := newHealthTracker(30*time.Second, 10*time.Second, 3, &nop)

	for wantCount := 1; wantCount <= 2; wantCount++ {
		triggered := ht.recordFailure()

		assert.False(t, triggered)
		assert.False(t, ht.isShutdown)
		assert.Equal(t, wantCount, ht.failureCount)
	}
}
// TestRecordFailure_AtThreshold expects the third failure, with a threshold of
// three, to be the one that shuts the tracker down.
func TestRecordFailure_AtThreshold(t *testing.T) {
	nop := zerolog.Nop()
	ht := newHealthTracker(30*time.Second, 10*time.Second, 3, &nop)

	// The first two failures stay below the threshold.
	for wantCount := 1; wantCount < 3; wantCount++ {
		assert.False(t, ht.recordFailure())
		assert.Equal(t, wantCount, ht.failureCount)
	}

	// The third failure reaches it.
	triggered := ht.recordFailure()
	assert.True(t, triggered)
	assert.True(t, ht.isShutdown)
	assert.Equal(t, 3, ht.failureCount)
}
// TestRecordFailure_OverThreshold checks that a failure recorded after the
// threshold was already reached still reports a shutdown.
func TestRecordFailure_OverThreshold(t *testing.T) {
	nop := zerolog.Nop()
	ht := newHealthTracker(30*time.Second, 10*time.Second, 2, &nop)

	// Below the threshold of two.
	assert.False(t, ht.recordFailure())

	// Second failure reaches the threshold.
	assert.True(t, ht.recordFailure())
	assert.True(t, ht.isShutdown)

	// A third failure, past the threshold, keeps reporting the shutdown.
	assert.True(t, ht.recordFailure())
	assert.True(t, ht.isShutdown)
}
// TestRecordFailure_ResetAfterPeriod checks that a failure arriving after the
// counting window has elapsed starts a fresh count instead of adding to the
// previous one.
func TestRecordFailure_ResetAfterPeriod(t *testing.T) {
	const window = 100 * time.Millisecond

	nop := zerolog.Nop()
	ht := newHealthTracker(30*time.Second, window, 2, &nop)

	// First failure opens the window.
	assert.False(t, ht.recordFailure())
	assert.Equal(t, 1, ht.failureCount)

	// Let the window run out.
	time.Sleep(window + 10*time.Millisecond)

	// The next failure lands in a new window, so the count is 1 again, not 2.
	assert.False(t, ht.recordFailure())
	assert.Equal(t, 1, ht.failureCount)
}
// TestIsShutdownNow_NotShutdown expects a freshly built tracker to report that
// it is not shut down.
func TestIsShutdownNow_NotShutdown(t *testing.T) {
	nop := zerolog.Nop()
	ht := newHealthTracker(30*time.Second, 10*time.Second, 2, &nop)

	down := ht.isShutdownNow()

	assert.False(t, down)
}
// TestIsShutdownNow_CurrentlyShutdown triggers a shutdown with a threshold of
// one and expects isShutdownNow to report it immediately afterwards.
func TestIsShutdownNow_CurrentlyShutdown(t *testing.T) {
	const timeout = 100 * time.Millisecond

	nop := zerolog.Nop()
	ht := newHealthTracker(timeout, 10*time.Second, 1, &nop)

	// A single failure is enough to shut down.
	require.True(t, ht.recordFailure())
	require.True(t, ht.isShutdown)

	// The timeout has not elapsed yet, so the shutdown must still be in force.
	assert.True(t, ht.isShutdownNow())
}
// TestIsShutdownNow_RecoveryAfterTimeout checks that once the backoff timeout
// has elapsed, isShutdownNow reports recovery and the internal state is reset.
func TestIsShutdownNow_RecoveryAfterTimeout(t *testing.T) {
	const timeout = 50 * time.Millisecond

	nop := zerolog.Nop()
	ht := newHealthTracker(timeout, 10*time.Second, 1, &nop)

	// A single failure is enough to shut down.
	require.True(t, ht.recordFailure())
	require.True(t, ht.isShutdown)

	// Outlast the shutdown period.
	time.Sleep(timeout + 10*time.Millisecond)

	assert.False(t, ht.isShutdownNow())

	// Recovery must also clear the flag and the failure counter.
	down, failures, _ := ht.getStatus()
	assert.False(t, down)
	assert.Equal(t, 0, failures)
}
// TestConcurrentAccess records ten failures from ten goroutines against a
// threshold of ten and expects the tracker to end up shut down.
func TestConcurrentAccess(t *testing.T) {
	const workers = 10

	nop := zerolog.Nop()
	ht := newHealthTracker(30*time.Second, 10*time.Second, workers, &nop)

	finished := make(chan bool, workers)
	for w := 0; w < workers; w++ {
		go func() {
			ht.recordFailure()
			finished <- true
		}()
	}

	// Block until every goroutine has reported in.
	for w := 0; w < workers; w++ {
		<-finished
	}

	assert.True(t, ht.isShutdownNow())
}
// TestGetStatus follows the values returned by getStatus from a fresh tracker,
// through one failure, to the failure that triggers the shutdown.
func TestGetStatus(t *testing.T) {
	nop := zerolog.Nop()
	ht := newHealthTracker(30*time.Second, 10*time.Second, 2, &nop)

	// Fresh tracker: healthy, nothing counted, no deadline.
	down, failures, until := ht.getStatus()
	assert.False(t, down)
	assert.Equal(t, 0, failures)
	assert.True(t, until.IsZero())

	// One failure: counted, but still healthy.
	ht.recordFailure()
	down, failures, until = ht.getStatus()
	assert.False(t, down)
	assert.Equal(t, 1, failures)
	assert.True(t, until.IsZero())

	// Second failure reaches the threshold and sets a future deadline.
	ht.recordFailure()
	down, failures, until = ht.getStatus()
	assert.True(t, down)
	assert.Equal(t, 2, failures)
	assert.False(t, until.IsZero())
	assert.True(t, until.After(time.Now()))
}
// TestEdgeCase_ZeroThreshold expects a threshold of zero to shut the tracker
// down on the very first failure.
func TestEdgeCase_ZeroThreshold(t *testing.T) {
	nop := zerolog.Nop()
	ht := newHealthTracker(30*time.Second, 10*time.Second, 0, &nop)

	triggered := ht.recordFailure()

	assert.True(t, triggered)
	assert.True(t, ht.isShutdown)
}
// TestEdgeCase_NegativeThreshold expects a negative threshold to disable the
// shutdown: a recorded failure leaves the tracker healthy.
func TestEdgeCase_NegativeThreshold(t *testing.T) {
	nop := zerolog.Nop()
	ht := newHealthTracker(30*time.Second, 10*time.Second, -1, &nop)

	triggered := ht.recordFailure()

	assert.False(t, triggered)
	assert.False(t, ht.isShutdown)
}

View file

@ -38,7 +38,8 @@ type rateLimiter struct {
next http.Handler
logger *zerolog.Logger
limiter limiter
limiter limiter
healthTracker *healthTracker
}
// New returns a rate limiter middleware.
@ -115,6 +116,29 @@ func New(ctx context.Context, next http.Handler, config dynamic.RateLimit, name
}
}
// Initialize health tracker with configuration
// Only create health tracker if ALL three resilience parameters are configured
var healthTracker *healthTracker
hasBackoffTimeout := config.UnhealthyLimiterBackoffTimeout != nil
hasBackoffDuration := config.UnhealthyLimiterBackoffDuration != nil
hasBackoffThreshold := config.UnhealthyLimiterBackoffThreshold != nil
if hasBackoffTimeout && hasBackoffDuration && hasBackoffThreshold {
// All three parameters provided, create health tracker
backoffTimeout := time.Duration(*config.UnhealthyLimiterBackoffTimeout)
backoffDuration := time.Duration(*config.UnhealthyLimiterBackoffDuration)
backoffThreshold := *config.UnhealthyLimiterBackoffThreshold
healthTracker = newHealthTracker(backoffTimeout, backoffDuration, backoffThreshold, logger)
} else if hasBackoffTimeout || hasBackoffDuration || hasBackoffThreshold {
// Only some parameters provided, warn and don't create health tracker
logger.Warn().
Bool("hasBackoffTimeout", hasBackoffTimeout).
Bool("hasBackoffDuration", hasBackoffDuration).
Bool("hasBackoffThreshold", hasBackoffThreshold).
Msg("Resilience parameters must all be provided together. Health tracker not created.")
}
return &rateLimiter{
logger: logger,
name: name,
@ -123,6 +147,7 @@ func New(ctx context.Context, next http.Handler, config dynamic.RateLimit, name
next: next,
sourceMatcher: sourceMatcher,
limiter: limiter,
healthTracker: healthTracker,
}, nil
}
@ -145,6 +170,13 @@ func (rl *rateLimiter) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
logger.Info().Msgf("ignoring token bucket amount > 1: %d", amount)
}
// Check if the limiter is currently shut down due to health issues
if rl.healthTracker != nil && rl.healthTracker.isShutdownNow() {
// If shut down, bypass rate limiting and let the request through
rl.next.ServeHTTP(rw, req)
return
}
// Each rate limiter has its own source space,
// ensuring independence between rate limiters,
// i.e., rate limit rules are only applied based on traffic
@ -154,6 +186,19 @@ func (rl *rateLimiter) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
if err != nil {
rl.logger.Error().Err(err).Msg("Could not insert/update bucket")
observability.SetStatusErrorf(ctx, "Could not insert/update bucket")
// If health tracker is configured, record the failure and check if this should trigger a shutdown
if rl.healthTracker != nil {
shouldShutdown := rl.healthTracker.recordFailure()
// If this failure triggers a shutdown, let the current request through
if shouldShutdown {
rl.next.ServeHTTP(rw, req)
return
}
}
// Otherwise, return 500 as before (default behavior)
http.Error(rw, "Could not insert/update bucket", http.StatusInternalServerError)
return
}

View file

@ -102,6 +102,13 @@ func TestNewRateLimiter(t *testing.T) {
},
},
},
{
desc: "Default behavior - no health tracker when no resilience params",
config: dynamic.RateLimit{
Average: 200,
Burst: 10,
},
},
}
for _, test := range testCases {
@ -150,6 +157,115 @@ func TestNewRateLimiter(t *testing.T) {
if test.expectedRTL != 0 {
assert.InDelta(t, float64(test.expectedRTL), float64(rtl.rate), delta)
}
// Test default behavior - no health tracker when no resilience params
if test.desc == "Default behavior - no health tracker when no resilience params" {
assert.Nil(t, rtl.healthTracker, "Health tracker should be nil when no resilience parameters are provided")
}
})
}
}
func TestRateLimiterWithResilience(t *testing.T) {
testCases := []struct {
desc string
config dynamic.RateLimit
expectedHealthTracker bool
expectedShutdownThreshold int
expectWarning bool
}{
{
desc: "No health tracker when no resilience parameters provided",
config: dynamic.RateLimit{
Average: 100,
Burst: 10,
},
expectedHealthTracker: false,
expectedShutdownThreshold: 0,
expectWarning: false,
},
{
desc: "No health tracker when only backoff timeout is provided",
config: dynamic.RateLimit{
Average: 100,
Burst: 10,
UnhealthyLimiterBackoffTimeout: func() *ptypes.Duration { d := ptypes.Duration(30 * time.Second); return &d }(),
},
expectedHealthTracker: false,
expectedShutdownThreshold: 0,
expectWarning: true,
},
{
desc: "No health tracker when only shutdown period is provided",
config: dynamic.RateLimit{
Average: 100,
Burst: 10,
UnhealthyLimiterBackoffDuration: func() *ptypes.Duration { d := ptypes.Duration(10 * time.Second); return &d }(),
},
expectedHealthTracker: false,
expectedShutdownThreshold: 0,
expectWarning: true,
},
{
desc: "No health tracker when only shutdown threshold is provided",
config: dynamic.RateLimit{
Average: 100,
Burst: 10,
UnhealthyLimiterBackoffThreshold: func() *int { v := 3; return &v }(),
},
expectedHealthTracker: false,
expectedShutdownThreshold: 0,
expectWarning: true,
},
{
desc: "No health tracker when only two parameters are provided",
config: dynamic.RateLimit{
Average: 100,
Burst: 10,
UnhealthyLimiterBackoffTimeout: func() *ptypes.Duration { d := ptypes.Duration(30 * time.Second); return &d }(),
UnhealthyLimiterBackoffDuration: func() *ptypes.Duration { d := ptypes.Duration(10 * time.Second); return &d }(),
},
expectedHealthTracker: false,
expectedShutdownThreshold: 0,
expectWarning: true,
},
{
desc: "Health tracker created when all three parameters are provided",
config: dynamic.RateLimit{
Average: 100,
Burst: 10,
UnhealthyLimiterBackoffTimeout: func() *ptypes.Duration { d := ptypes.Duration(60 * time.Second); return &d }(),
UnhealthyLimiterBackoffDuration: func() *ptypes.Duration { d := ptypes.Duration(5 * time.Second); return &d }(),
UnhealthyLimiterBackoffThreshold: func() *int { v := 2; return &v }(),
},
expectedHealthTracker: true,
expectedShutdownThreshold: 2,
expectWarning: false,
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
h, err := New(t.Context(), next, test.config, "rate-limiter")
require.NoError(t, err)
rtl, _ := h.(*rateLimiter)
if test.expectedHealthTracker {
assert.NotNil(t, rtl.healthTracker, "Health tracker should be created when all resilience parameters are provided")
if rtl.healthTracker != nil {
// Test that the health tracker has the correct threshold
threshold := rtl.healthTracker.getThreshold()
assert.Equal(t, test.expectedShutdownThreshold, threshold)
}
} else {
assert.Nil(t, rtl.healthTracker, "Health tracker should not be created when resilience parameters are not properly configured")
}
// Note: Warning testing would require capturing log output, which is complex
// For now, we just verify the behavior (health tracker creation/not creation)
})
}
}