mirror of
https://github.com/traefik/traefik.git
synced 2026-02-18 18:20:28 -05:00
Failover according to response status code
Some checks are pending
CodeQL / Analyze (push) Waiting to run
Build and Publish Documentation / Doc Process (push) Waiting to run
Build experimental image on branch / build-webui (push) Waiting to run
Build experimental image on branch / Build experimental image on branch (push) Waiting to run
Some checks are pending
CodeQL / Analyze (push) Waiting to run
Build and Publish Documentation / Doc Process (push) Waiting to run
Build experimental image on branch / build-webui (push) Waiting to run
Build experimental image on branch / Build experimental image on branch (push) Waiting to run
Co-authored-by: juliens <julien.salleyron@gmail.com>
This commit is contained in:
parent
a4a91344ed
commit
34ae66b9ab
14 changed files with 998 additions and 67 deletions
|
|
@ -760,7 +760,7 @@ The `mirroring` service type mirrors requests sent to a service to other service
|
|||
!!! info "Supported Providers"
|
||||
|
||||
This service type can be defined currently with the [File](../../../install-configuration/providers/others/file.md) provider or [IngressRoute](../../../routing-configuration/kubernetes/crd/http/ingressroute.md).
|
||||
|
||||
|
||||
```yaml tab="Structured (YAML)"
|
||||
## Routing configuration
|
||||
http:
|
||||
|
|
@ -887,9 +887,13 @@ http:
|
|||
url = "http://private-ip-server-2/"
|
||||
```
|
||||
|
||||
### Failover
|
||||
### Failover
|
||||
|
||||
The `failover` service type forwards all requests to a fallback service when the main service becomes unreachable.
|
||||
The `failover` service type forwards requests to a fallback service when the main service is unavailable.
|
||||
Failover can be triggered in two ways:
|
||||
|
||||
- **Health check-based**: When the main service becomes unreachable based on [health checks](#health-check).
|
||||
- **Status code-based**: When the main service responds with specific HTTP status codes defined in the [errors](#errors) configuration.
|
||||
|
||||
!!! info "Relation to HealthCheck"
|
||||
The failover service relies on the HealthCheck system to get notified when its main service becomes unreachable, which means HealthCheck needs to be enabled and functional on the main service. However, HealthCheck does not need to be enabled on the failover service itself for it to be functional. It is only required in order to propagate upwards the information when the failover itself becomes down (i.e. both its main and its fallback are down too).
|
||||
|
|
@ -940,15 +944,15 @@ http:
|
|||
## Routing configuration
|
||||
[http.services]
|
||||
[http.services.app]
|
||||
[http.services.app.failover.healthCheck]
|
||||
[http.services.app.failover]
|
||||
service = "main"
|
||||
fallback = "backup"
|
||||
[http.services.app.failover.healthCheck]
|
||||
|
||||
[http.services.main]
|
||||
[http.services.main.loadBalancer]
|
||||
[http.services.main.loadBalancer.healthCheck]
|
||||
path = "/health"
|
||||
path = "/status"
|
||||
interval = "10s"
|
||||
timeout = "3s"
|
||||
[[http.services.main.loadBalancer.servers]]
|
||||
|
|
@ -957,9 +961,163 @@ http:
|
|||
[http.services.backup]
|
||||
[http.services.backup.loadBalancer]
|
||||
[http.services.backup.loadBalancer.healthCheck]
|
||||
path = "/health"
|
||||
path = "/status"
|
||||
interval = "10s"
|
||||
timeout = "3s"
|
||||
[[http.services.backup.loadBalancer.servers]]
|
||||
url = "http://private-ip-server-2/"
|
||||
```
|
||||
|
||||
#### Errors
|
||||
|
||||
The `errors` option enables status code-based failover.
|
||||
When the main service responds with an HTTP status code matching one of the configured ranges, Traefik automatically retries the request on the fallback service.
|
||||
|
||||
To support request replay, the request body is buffered up to `maxRequestBodyBytes`.
|
||||
Requests with bodies larger than this limit receive a `413 Request Entity Too Large` response.
|
||||
|
||||
Below is a list of options available for the `errors` option and an example of how to configure it for a failover service:
|
||||
|
||||
| Field | Description | Default |
|
||||
|-----------------------|-------------------------------------------------------------------------------------------------------------------|---------|
|
||||
| <a id="opt-status-2" href="#opt-status-2" title="#opt-status-2">`status`</a> | List of HTTP status code ranges that trigger failover. Supports single codes (`"500"`) and ranges (`"500-504"`). | None |
|
||||
| <a id="opt-maxRequestBodyBytes" href="#opt-maxRequestBodyBytes" title="#opt-maxRequestBodyBytes">`maxRequestBodyBytes`</a> | Maximum request body size (in bytes) to buffer for replay to the fallback service. Set to `-1` for no limit. | `-1` |
|
||||
|
||||
```yaml tab="Structured (YAML)"
|
||||
## Routing configuration
|
||||
http:
|
||||
services:
|
||||
app:
|
||||
failover:
|
||||
service: main
|
||||
fallback: backup
|
||||
errors:
|
||||
status:
|
||||
- "500-504"
|
||||
maxRequestBodyBytes: 1048576
|
||||
|
||||
main:
|
||||
loadBalancer:
|
||||
servers:
|
||||
- url: "http://private-ip-server-1/"
|
||||
|
||||
backup:
|
||||
loadBalancer:
|
||||
servers:
|
||||
- url: "http://private-ip-server-2/"
|
||||
```
|
||||
|
||||
```toml tab="Structured (TOML)"
|
||||
## Routing configuration
|
||||
[http.services]
|
||||
[http.services.app]
|
||||
[http.services.app.failover]
|
||||
service = "main"
|
||||
fallback = "backup"
|
||||
[http.services.app.failover.errors]
|
||||
status = ["500-504"]
|
||||
maxRequestBodyBytes = 1048576
|
||||
|
||||
[http.services.main]
|
||||
[http.services.main.loadBalancer]
|
||||
[[http.services.main.loadBalancer.servers]]
|
||||
url = "http://private-ip-server-1/"
|
||||
|
||||
[http.services.backup]
|
||||
[http.services.backup.loadBalancer]
|
||||
[[http.services.backup.loadBalancer.servers]]
|
||||
url = "http://private-ip-server-2/"
|
||||
```
|
||||
|
||||
#### Chaining Failover Services
|
||||
|
||||
Failover services can be chained together for multi-level redundancy.
|
||||
In the following example, if the primary service fails, traffic goes to the secondary service.
|
||||
If both primary and secondary fail, traffic goes to the tertiary service.
|
||||
|
||||
```yaml tab="Structured (YAML)"
|
||||
## Routing configuration
|
||||
http:
|
||||
services:
|
||||
app:
|
||||
failover:
|
||||
healthCheck: {}
|
||||
service: primary-failover
|
||||
fallback: tertiary
|
||||
|
||||
primary-failover:
|
||||
failover:
|
||||
healthCheck: {}
|
||||
service: primary
|
||||
fallback: secondary
|
||||
|
||||
primary:
|
||||
loadBalancer:
|
||||
healthCheck:
|
||||
path: /health
|
||||
interval: 10s
|
||||
timeout: 3s
|
||||
servers:
|
||||
- url: "http://primary-server/"
|
||||
|
||||
secondary:
|
||||
loadBalancer:
|
||||
healthCheck:
|
||||
path: /health
|
||||
interval: 10s
|
||||
timeout: 3s
|
||||
servers:
|
||||
- url: "http://secondary-server/"
|
||||
|
||||
tertiary:
|
||||
loadBalancer:
|
||||
healthCheck:
|
||||
path: /health
|
||||
interval: 10s
|
||||
timeout: 3s
|
||||
servers:
|
||||
- url: "http://tertiary-server/"
|
||||
```
|
||||
|
||||
```toml tab="Structured (TOML)"
|
||||
## Routing configuration
|
||||
[http.services]
|
||||
[http.services.app]
|
||||
[http.services.app.failover]
|
||||
service = "primary-failover"
|
||||
fallback = "tertiary"
|
||||
[http.services.app.failover.healthCheck]
|
||||
|
||||
[http.services.primary-failover]
|
||||
[http.services.primary-failover.failover]
|
||||
service = "primary"
|
||||
fallback = "secondary"
|
||||
[http.services.primary-failover.failover.healthCheck]
|
||||
|
||||
[http.services.primary]
|
||||
[http.services.primary.loadBalancer]
|
||||
[http.services.primary.loadBalancer.healthCheck]
|
||||
path = "/health"
|
||||
interval = "10s"
|
||||
timeout = "3s"
|
||||
[[http.services.primary.loadBalancer.servers]]
|
||||
url = "http://primary-server/"
|
||||
|
||||
[http.services.secondary]
|
||||
[http.services.secondary.loadBalancer]
|
||||
[http.services.secondary.loadBalancer.healthCheck]
|
||||
path = "/health"
|
||||
interval = "10s"
|
||||
timeout = "3s"
|
||||
[[http.services.secondary.loadBalancer.servers]]
|
||||
url = "http://secondary-server/"
|
||||
|
||||
[http.services.tertiary]
|
||||
[http.services.tertiary.loadBalancer]
|
||||
[http.services.tertiary.loadBalancer.healthCheck]
|
||||
path = "/health"
|
||||
interval = "10s"
|
||||
timeout = "3s"
|
||||
[[http.services.tertiary.loadBalancer.servers]]
|
||||
url = "http://tertiary-server/"
|
||||
```
|
||||
|
|
|
|||
|
|
@ -56,6 +56,9 @@
|
|||
service = "foobar"
|
||||
fallback = "foobar"
|
||||
[http.services.Service01.failover.healthCheck]
|
||||
[http.services.Service01.failover.errors]
|
||||
maxRequestBodyBytes = 42
|
||||
status = ["foobar", "foobar"]
|
||||
[http.services.Service02]
|
||||
[http.services.Service02.highestRandomWeight]
|
||||
|
||||
|
|
|
|||
|
|
@ -70,6 +70,11 @@ http:
|
|||
service: foobar
|
||||
fallback: foobar
|
||||
healthCheck: {}
|
||||
errors:
|
||||
maxRequestBodyBytes: 42
|
||||
status:
|
||||
- foobar
|
||||
- foobar
|
||||
Service02:
|
||||
highestRandomWeight:
|
||||
services:
|
||||
|
|
|
|||
53
integration/fixtures/failover.toml
Normal file
53
integration/fixtures/failover.toml
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
[global]
|
||||
checkNewVersion = false
|
||||
sendAnonymousUsage = false
|
||||
|
||||
[api]
|
||||
insecure = true
|
||||
|
||||
[log]
|
||||
level = "DEBUG"
|
||||
noColor = true
|
||||
|
||||
[entryPoints]
|
||||
[entryPoints.web]
|
||||
address = ":8000"
|
||||
|
||||
[providers.file]
|
||||
filename = "{{ .SelfFilename }}"
|
||||
|
||||
## dynamic configuration ##
|
||||
|
||||
[http.routers]
|
||||
[http.routers.router1]
|
||||
entrypoints = ["web"]
|
||||
service = "failover-service"
|
||||
rule = "Path(`/whoami`)"
|
||||
|
||||
[http.services]
|
||||
# Failover service with health check
|
||||
[http.services.failover-service]
|
||||
[http.services.failover-service.failover]
|
||||
service = "main-service"
|
||||
fallback = "fallback-service"
|
||||
[http.services.failover-service.failover.healthCheck]
|
||||
|
||||
# Main service with health check enabled
|
||||
[http.services.main-service]
|
||||
[http.services.main-service.loadBalancer]
|
||||
[http.services.main-service.loadBalancer.healthCheck]
|
||||
path = "/health"
|
||||
interval = "1s"
|
||||
timeout = "0.9s"
|
||||
[[http.services.main-service.loadBalancer.servers]]
|
||||
url = "http://{{ .MainServer }}:80"
|
||||
|
||||
# Fallback service with health check enabled
|
||||
[http.services.fallback-service]
|
||||
[http.services.fallback-service.loadBalancer]
|
||||
[http.services.fallback-service.loadBalancer.healthCheck]
|
||||
path = "/health"
|
||||
interval = "1s"
|
||||
timeout = "0.9s"
|
||||
[[http.services.fallback-service.loadBalancer.servers]]
|
||||
url = "http://{{ .FallbackServer }}:80"
|
||||
47
integration/fixtures/failover_statuscode.toml
Normal file
47
integration/fixtures/failover_statuscode.toml
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
[global]
|
||||
checkNewVersion = false
|
||||
sendAnonymousUsage = false
|
||||
|
||||
[api]
|
||||
insecure = true
|
||||
|
||||
[log]
|
||||
level = "DEBUG"
|
||||
noColor = true
|
||||
|
||||
[entryPoints]
|
||||
[entryPoints.web]
|
||||
address = ":8000"
|
||||
|
||||
[providers.file]
|
||||
filename = "{{ .SelfFilename }}"
|
||||
|
||||
## dynamic configuration ##
|
||||
|
||||
[http.routers]
|
||||
[http.routers.router1]
|
||||
entrypoints = ["web"]
|
||||
service = "failover-service"
|
||||
rule = "PathPrefix(`/`)"
|
||||
|
||||
[http.services]
|
||||
# Failover service with error-based failover (status codes)
|
||||
[http.services.failover-service]
|
||||
[http.services.failover-service.failover]
|
||||
service = "main-service"
|
||||
fallback = "fallback-service"
|
||||
[http.services.failover-service.failover.errors]
|
||||
status = ["500-504"]
|
||||
maxRequestBodyBytes = 1048576 # 1MB
|
||||
|
||||
# Main service (no health check - failover based on status codes only)
|
||||
[http.services.main-service]
|
||||
[http.services.main-service.loadBalancer]
|
||||
[[http.services.main-service.loadBalancer.servers]]
|
||||
url = "{{ .MainServer }}"
|
||||
|
||||
# Fallback service
|
||||
[http.services.fallback-service]
|
||||
[http.services.fallback-service.loadBalancer]
|
||||
[[http.services.fallback-service.loadBalancer.servers]]
|
||||
url = "{{ .FallbackServer }}"
|
||||
|
|
@ -2365,6 +2365,173 @@ func (s *SimpleSuite) TestEncodedCharactersDifferentEntryPoints() {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *SimpleSuite) TestFailoverService() {
|
||||
s.createComposeProject("base")
|
||||
|
||||
s.composeUp()
|
||||
defer s.composeDown()
|
||||
|
||||
whoami1IP := s.getComposeServiceIP("whoami1")
|
||||
whoami2IP := s.getComposeServiceIP("whoami2")
|
||||
|
||||
file := s.adaptFile("fixtures/failover.toml", struct {
|
||||
MainServer string
|
||||
FallbackServer string
|
||||
}{
|
||||
MainServer: whoami1IP,
|
||||
FallbackServer: whoami2IP,
|
||||
})
|
||||
|
||||
s.traefikCmd(withConfigFile(file))
|
||||
|
||||
// Wait for Traefik to be ready
|
||||
err := try.GetRequest("http://127.0.0.1:8080/api/http/services", 2*time.Second, try.BodyContains("failover-service"))
|
||||
require.NoError(s.T(), err)
|
||||
|
||||
// Test 1: When main service is healthy, traffic should go to main
|
||||
var primaryCount, fallbackCount int
|
||||
for range 5 {
|
||||
req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:8000/whoami", nil)
|
||||
require.NoError(s.T(), err)
|
||||
|
||||
response, err := http.DefaultClient.Do(req)
|
||||
require.NoError(s.T(), err)
|
||||
assert.Equal(s.T(), http.StatusOK, response.StatusCode)
|
||||
|
||||
body, err := io.ReadAll(response.Body)
|
||||
require.NoError(s.T(), err)
|
||||
|
||||
if strings.Contains(string(body), whoami1IP) {
|
||||
primaryCount++
|
||||
}
|
||||
if strings.Contains(string(body), whoami2IP) {
|
||||
fallbackCount++
|
||||
}
|
||||
}
|
||||
|
||||
// All requests should go to the main service (whoami1)
|
||||
assert.Equal(s.T(), 5, primaryCount, "Expected all requests to go to main service")
|
||||
assert.Equal(s.T(), 0, fallbackCount, "Expected no requests to go to fallback service")
|
||||
|
||||
// Test 2: Stop the main service to trigger failover via health check
|
||||
s.composeStop("whoami1")
|
||||
|
||||
// Wait for health check to detect the main service is down
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
// Now all traffic should go to the fallback service (whoami2)
|
||||
primaryCount, fallbackCount = 0, 0
|
||||
for range 5 {
|
||||
req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:8000/whoami", nil)
|
||||
require.NoError(s.T(), err)
|
||||
|
||||
response, err := http.DefaultClient.Do(req)
|
||||
require.NoError(s.T(), err)
|
||||
assert.Equal(s.T(), http.StatusOK, response.StatusCode)
|
||||
|
||||
body, err := io.ReadAll(response.Body)
|
||||
require.NoError(s.T(), err)
|
||||
|
||||
if strings.Contains(string(body), whoami1IP) {
|
||||
primaryCount++
|
||||
}
|
||||
if strings.Contains(string(body), whoami2IP) {
|
||||
fallbackCount++
|
||||
}
|
||||
}
|
||||
|
||||
assert.Equal(s.T(), 0, primaryCount, "Expected no requests to go to main service when down")
|
||||
assert.Equal(s.T(), 5, fallbackCount, "Expected all requests to go to fallback service")
|
||||
|
||||
// Test 3: Restart main service and verify traffic returns to main
|
||||
s.composeUp("whoami1")
|
||||
|
||||
// Wait for health check to detect the main service is back up
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
primaryCount, fallbackCount = 0, 0
|
||||
for range 5 {
|
||||
req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:8000/whoami", nil)
|
||||
require.NoError(s.T(), err)
|
||||
|
||||
response, err := http.DefaultClient.Do(req)
|
||||
require.NoError(s.T(), err)
|
||||
assert.Equal(s.T(), http.StatusOK, response.StatusCode)
|
||||
|
||||
body, err := io.ReadAll(response.Body)
|
||||
require.NoError(s.T(), err)
|
||||
|
||||
if strings.Contains(string(body), whoami1IP) {
|
||||
primaryCount++
|
||||
}
|
||||
if strings.Contains(string(body), whoami2IP) {
|
||||
fallbackCount++
|
||||
}
|
||||
}
|
||||
|
||||
// Traffic should return to the main service
|
||||
assert.Equal(s.T(), 5, primaryCount, "Expected all requests to return to main service when back up")
|
||||
assert.Equal(s.T(), 0, fallbackCount, "Expected no requests to go to fallback service")
|
||||
|
||||
// Test 4: Stop both services and verify we get 503
|
||||
s.composeStop("whoami1")
|
||||
s.composeStop("whoami2")
|
||||
|
||||
// Wait for health checks to detect both services are down
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
// Request should return 503 Service Unavailable when both services are down
|
||||
req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:8000/whoami", nil)
|
||||
require.NoError(s.T(), err)
|
||||
|
||||
response, err := http.DefaultClient.Do(req)
|
||||
require.NoError(s.T(), err)
|
||||
assert.Equal(s.T(), http.StatusServiceUnavailable, response.StatusCode)
|
||||
}
|
||||
|
||||
func (s *SimpleSuite) TestFailoverServiceWithStatusCode() {
|
||||
var mainCallCount, fallbackCallCount atomic.Int32
|
||||
|
||||
// Create a test server that returns 503 to trigger error-based failover
|
||||
mainServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
mainCallCount.Add(1)
|
||||
rw.WriteHeader(http.StatusServiceUnavailable)
|
||||
_, _ = rw.Write([]byte("main service unavailable"))
|
||||
}))
|
||||
defer mainServer.Close()
|
||||
|
||||
// Create a fallback server that returns 200
|
||||
fallbackServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
fallbackCallCount.Add(1)
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
|
||||
_, _ = rw.Write([]byte("fallback service"))
|
||||
}))
|
||||
defer fallbackServer.Close()
|
||||
|
||||
file := s.adaptFile("fixtures/failover_statuscode.toml", struct {
|
||||
MainServer string
|
||||
FallbackServer string
|
||||
}{
|
||||
MainServer: mainServer.URL,
|
||||
FallbackServer: fallbackServer.URL,
|
||||
})
|
||||
|
||||
s.traefikCmd(withConfigFile(file))
|
||||
|
||||
// Wait for Traefik to be ready and verify the configuration is loaded
|
||||
err := try.GetRequest("http://127.0.0.1:8080/api/rawdata", 10*time.Second, try.BodyContains("PathPrefix"))
|
||||
require.NoError(s.T(), err)
|
||||
|
||||
// Make a request - should failover to fallback because main returns 503
|
||||
err = try.GetRequest("http://127.0.0.1:8000/", 5*time.Second, try.BodyContains("fallback service"))
|
||||
require.NoError(s.T(), err)
|
||||
|
||||
// Main was called but returned 503, triggering failover to fallback
|
||||
assert.GreaterOrEqual(s.T(), mainCallCount.Load(), int32(1), "Main service should have been called at least once")
|
||||
assert.GreaterOrEqual(s.T(), fallbackCallCount.Load(), int32(1), "Fallback service should have been called at least once")
|
||||
}
|
||||
|
||||
func (s *SimpleSuite) TestServiceMiddleware() {
|
||||
s.createComposeProject("base")
|
||||
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ import (
|
|||
traefiktls "github.com/traefik/traefik/v3/pkg/tls"
|
||||
"github.com/traefik/traefik/v3/pkg/types"
|
||||
"google.golang.org/grpc/codes"
|
||||
"k8s.io/utils/ptr"
|
||||
)
|
||||
|
||||
const (
|
||||
|
|
@ -28,6 +29,8 @@ const (
|
|||
MirroringDefaultMirrorBody = true
|
||||
// MirroringDefaultMaxBodySize is the Mirroring.MaxBodySize option default value.
|
||||
MirroringDefaultMaxBodySize int64 = -1
|
||||
// FailoverErrorsDefaultMaxRequestBodyBytes is the Failover.Errors.MaxBodySize option default value.
|
||||
FailoverErrorsDefaultMaxRequestBodyBytes int64 = -1
|
||||
)
|
||||
|
||||
// +k8s:deepcopy-gen=true
|
||||
|
|
@ -195,9 +198,23 @@ func (m *Mirroring) SetDefaults() {
|
|||
|
||||
// Failover holds the Failover configuration.
|
||||
type Failover struct {
|
||||
Service string `json:"service,omitempty" toml:"service,omitempty" yaml:"service,omitempty" export:"true"`
|
||||
Fallback string `json:"fallback,omitempty" toml:"fallback,omitempty" yaml:"fallback,omitempty" export:"true"`
|
||||
HealthCheck *HealthCheck `json:"healthCheck,omitempty" toml:"healthCheck,omitempty" yaml:"healthCheck,omitempty" label:"allowEmpty" file:"allowEmpty" export:"true"`
|
||||
Service string `json:"service,omitempty" toml:"service,omitempty" yaml:"service,omitempty" export:"true"`
|
||||
Fallback string `json:"fallback,omitempty" toml:"fallback,omitempty" yaml:"fallback,omitempty" export:"true"`
|
||||
HealthCheck *HealthCheck `json:"healthCheck,omitempty" toml:"healthCheck,omitempty" yaml:"healthCheck,omitempty" label:"allowEmpty" file:"allowEmpty" export:"true"`
|
||||
Errors *FailoverError `json:"errors,omitempty" toml:"errors,omitempty" yaml:"errors,omitempty" export:"true"`
|
||||
}
|
||||
|
||||
// +k8s:deepcopy-gen=true
|
||||
|
||||
// FailoverError holds errors configuration.
|
||||
type FailoverError struct {
|
||||
MaxRequestBodyBytes *int64 `json:"maxRequestBodyBytes,omitempty" toml:"maxRequestBodyBytes,omitempty" yaml:"maxRequestBodyBytes,omitempty" export:"true"`
|
||||
Status []string `json:"status,omitempty" toml:"status,omitempty" yaml:"status,omitempty" export:"true"`
|
||||
}
|
||||
|
||||
// SetDefaults Default values for a WRRService.
|
||||
func (m *FailoverError) SetDefaults() {
|
||||
m.MaxRequestBodyBytes = ptr.To(FailoverErrorsDefaultMaxRequestBodyBytes)
|
||||
}
|
||||
|
||||
// +k8s:deepcopy-gen=true
|
||||
|
|
|
|||
|
|
@ -358,6 +358,11 @@ func (in *Failover) DeepCopyInto(out *Failover) {
|
|||
*out = new(HealthCheck)
|
||||
**out = **in
|
||||
}
|
||||
if in.Errors != nil {
|
||||
in, out := &in.Errors, &out.Errors
|
||||
*out = new(FailoverError)
|
||||
(*in).DeepCopyInto(*out)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -371,6 +376,32 @@ func (in *Failover) DeepCopy() *Failover {
|
|||
return out
|
||||
}
|
||||
|
||||
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
|
||||
func (in *FailoverError) DeepCopyInto(out *FailoverError) {
|
||||
*out = *in
|
||||
if in.MaxRequestBodyBytes != nil {
|
||||
in, out := &in.MaxRequestBodyBytes, &out.MaxRequestBodyBytes
|
||||
*out = new(int64)
|
||||
**out = **in
|
||||
}
|
||||
if in.Status != nil {
|
||||
in, out := &in.Status, &out.Status
|
||||
*out = make([]string, len(*in))
|
||||
copy(*out, *in)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new FailoverError.
|
||||
func (in *FailoverError) DeepCopy() *FailoverError {
|
||||
if in == nil {
|
||||
return nil
|
||||
}
|
||||
out := new(FailoverError)
|
||||
in.DeepCopyInto(out)
|
||||
return out
|
||||
}
|
||||
|
||||
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
|
||||
func (in *ForwardAuth) DeepCopyInto(out *ForwardAuth) {
|
||||
*out = *in
|
||||
|
|
|
|||
|
|
@ -3,11 +3,14 @@ package failover
|
|||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/traefik/traefik/v3/pkg/config/dynamic"
|
||||
"github.com/traefik/traefik/v3/pkg/server/service/loadbalancer/mirror"
|
||||
"github.com/traefik/traefik/v3/pkg/types"
|
||||
)
|
||||
|
||||
// Failover is an http.Handler that can forward requests to the fallback handler
|
||||
|
|
@ -25,13 +28,33 @@ type Failover struct {
|
|||
|
||||
fallbackStatusMu sync.RWMutex
|
||||
fallbackStatus bool
|
||||
|
||||
statusCode types.HTTPCodeRanges
|
||||
maxRequestBodyBytes int64
|
||||
}
|
||||
|
||||
// New creates a new Failover handler.
|
||||
func New(hc *dynamic.HealthCheck) *Failover {
|
||||
return &Failover{
|
||||
wantsHealthCheck: hc != nil,
|
||||
func New(config *dynamic.Failover) (*Failover, error) {
|
||||
f := &Failover{wantsHealthCheck: config.HealthCheck != nil}
|
||||
|
||||
if config.Errors != nil {
|
||||
if len(config.Errors.Status) > 0 {
|
||||
httpCodeRanges, err := types.NewHTTPCodeRanges(config.Errors.Status)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating HTTP code ranges: %w", err)
|
||||
}
|
||||
f.statusCode = httpCodeRanges
|
||||
}
|
||||
|
||||
maxRequestBodyBytes := dynamic.FailoverErrorsDefaultMaxRequestBodyBytes
|
||||
if config.Errors.MaxRequestBodyBytes != nil {
|
||||
maxRequestBodyBytes = *config.Errors.MaxRequestBodyBytes
|
||||
}
|
||||
|
||||
f.maxRequestBodyBytes = maxRequestBodyBytes
|
||||
}
|
||||
|
||||
return f, nil
|
||||
}
|
||||
|
||||
// RegisterStatusUpdater adds fn to the list of hooks that are run when the
|
||||
|
|
@ -53,8 +76,33 @@ func (f *Failover) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
|||
f.handlerStatusMu.RUnlock()
|
||||
|
||||
if handlerStatus {
|
||||
f.handler.ServeHTTP(w, req)
|
||||
return
|
||||
if len(f.statusCode) == 0 {
|
||||
f.handler.ServeHTTP(w, req)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: move reusable request to a common package at some point.
|
||||
rr, _, err := mirror.NewReusableRequest(req, f.maxRequestBodyBytes)
|
||||
if err != nil && !errors.Is(err, mirror.ErrBodyTooLarge) {
|
||||
log.Ctx(req.Context()).Debug().Err(err).Msg("Error while creating reusable request for failover")
|
||||
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if errors.Is(err, mirror.ErrBodyTooLarge) {
|
||||
http.Error(w, "Request body too large", http.StatusRequestEntityTooLarge)
|
||||
return
|
||||
}
|
||||
|
||||
rw := &responseWriter{ResponseWriter: w, statusCodeRange: f.statusCode}
|
||||
f.handler.ServeHTTP(rw, rr.Clone(req.Context()))
|
||||
|
||||
if !rw.needFallback {
|
||||
return
|
||||
}
|
||||
|
||||
req = rr.Clone(req.Context())
|
||||
}
|
||||
|
||||
f.fallbackStatusMu.RLock()
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
package failover
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
|
@ -26,7 +28,10 @@ func (r *responseRecorder) WriteHeader(statusCode int) {
|
|||
}
|
||||
|
||||
func TestFailover(t *testing.T) {
|
||||
failover := New(&dynamic.HealthCheck{})
|
||||
failover, err := New(&dynamic.Failover{
|
||||
HealthCheck: &dynamic.HealthCheck{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
status := true
|
||||
require.NoError(t, failover.RegisterStatusUpdater(func(up bool) {
|
||||
|
|
@ -73,7 +78,8 @@ func TestFailover(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestFailoverDownThenUp(t *testing.T) {
|
||||
failover := New(nil)
|
||||
failover, err := New(&dynamic.Failover{})
|
||||
require.NoError(t, err)
|
||||
|
||||
failover.SetHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.Header().Set("server", "handler")
|
||||
|
|
@ -112,7 +118,10 @@ func TestFailoverDownThenUp(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestFailoverPropagate(t *testing.T) {
|
||||
failover := New(&dynamic.HealthCheck{})
|
||||
failover, err := New(&dynamic.Failover{
|
||||
HealthCheck: &dynamic.HealthCheck{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
failover.SetHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.Header().Set("server", "handler")
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
|
|
@ -122,13 +131,14 @@ func TestFailoverPropagate(t *testing.T) {
|
|||
rw.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
topFailover := New(nil)
|
||||
topFailover, err := New(&dynamic.Failover{})
|
||||
require.NoError(t, err)
|
||||
topFailover.SetHandler(failover)
|
||||
topFailover.SetFallbackHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.Header().Set("server", "topFailover")
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
err := failover.RegisterStatusUpdater(func(up bool) {
|
||||
err = failover.RegisterStatusUpdater(func(up bool) {
|
||||
topFailover.SetHandlerStatus(t.Context(), up)
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
|
@ -161,3 +171,301 @@ func TestFailoverPropagate(t *testing.T) {
|
|||
assert.Equal(t, 1, recorder.save["topFailover"])
|
||||
assert.Equal(t, []int{200}, recorder.status)
|
||||
}
|
||||
|
||||
func TestFailoverStatusCode(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
statusCode []string
|
||||
handlerStatusCode int
|
||||
expectedHandler string
|
||||
expectedStatusCode int
|
||||
expectedResponseBody string
|
||||
}{
|
||||
{
|
||||
desc: "main handler returns 503, failover triggered",
|
||||
statusCode: []string{"503"},
|
||||
handlerStatusCode: 503,
|
||||
expectedHandler: "fallback",
|
||||
expectedStatusCode: 200,
|
||||
expectedResponseBody: "fallback response",
|
||||
},
|
||||
{
|
||||
desc: "main handler returns 200, failover not triggered",
|
||||
statusCode: []string{"503"},
|
||||
handlerStatusCode: 200,
|
||||
expectedHandler: "handler",
|
||||
expectedStatusCode: 200,
|
||||
expectedResponseBody: "main response",
|
||||
},
|
||||
{
|
||||
desc: "main handler returns 500, failover not triggered",
|
||||
statusCode: []string{"503"},
|
||||
handlerStatusCode: 500,
|
||||
expectedHandler: "handler",
|
||||
expectedStatusCode: 500,
|
||||
expectedResponseBody: "main response",
|
||||
},
|
||||
{
|
||||
desc: "multiple status codes, 503 triggers failover",
|
||||
statusCode: []string{"500", "502", "503", "504"},
|
||||
handlerStatusCode: 503,
|
||||
expectedHandler: "fallback",
|
||||
expectedStatusCode: 200,
|
||||
expectedResponseBody: "fallback response",
|
||||
},
|
||||
{
|
||||
desc: "multiple status codes, 502 triggers failover",
|
||||
statusCode: []string{"500", "502", "503", "504"},
|
||||
handlerStatusCode: 502,
|
||||
expectedHandler: "fallback",
|
||||
expectedStatusCode: 200,
|
||||
expectedResponseBody: "fallback response",
|
||||
},
|
||||
{
|
||||
desc: "multiple status codes, 404 does not trigger failover",
|
||||
statusCode: []string{"500", "502", "503", "504"},
|
||||
handlerStatusCode: 404,
|
||||
expectedHandler: "handler",
|
||||
expectedStatusCode: 404,
|
||||
expectedResponseBody: "main response",
|
||||
},
|
||||
{
|
||||
desc: "status code range 500-504, 503 triggers failover",
|
||||
statusCode: []string{"500-504"},
|
||||
handlerStatusCode: 503,
|
||||
expectedHandler: "fallback",
|
||||
expectedStatusCode: 200,
|
||||
expectedResponseBody: "fallback response",
|
||||
},
|
||||
{
|
||||
desc: "status code range 500-504, 404 does not trigger failover",
|
||||
statusCode: []string{"500-504"},
|
||||
handlerStatusCode: 404,
|
||||
expectedHandler: "handler",
|
||||
expectedStatusCode: 404,
|
||||
expectedResponseBody: "main response",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
failover, err := New(&dynamic.Failover{
|
||||
Errors: &dynamic.FailoverError{
|
||||
Status: test.statusCode,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
failover.SetHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.Header().Set("server", "handler")
|
||||
rw.WriteHeader(test.handlerStatusCode)
|
||||
_, err := rw.Write([]byte("main response"))
|
||||
require.NoError(t, err)
|
||||
}))
|
||||
|
||||
failover.SetFallbackHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.Header().Set("server", "fallback")
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
_, err := rw.Write([]byte("fallback response"))
|
||||
require.NoError(t, err)
|
||||
}))
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
failover.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil))
|
||||
|
||||
assert.Equal(t, test.expectedHandler, recorder.Header().Get("server"))
|
||||
assert.Equal(t, test.expectedStatusCode, recorder.Code)
|
||||
assert.Equal(t, test.expectedResponseBody, recorder.Body.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFailoverStatusCodeWithRequestBody(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
statusCode []string
|
||||
handlerStatusCode int
|
||||
requestBody string
|
||||
expectedHandler string
|
||||
}{
|
||||
{
|
||||
desc: "request body replayed to fallback handler",
|
||||
statusCode: []string{"503"},
|
||||
handlerStatusCode: 503,
|
||||
requestBody: "test request body",
|
||||
expectedHandler: "fallback",
|
||||
},
|
||||
{
|
||||
desc: "request body used by main handler",
|
||||
statusCode: []string{"503"},
|
||||
handlerStatusCode: 200,
|
||||
requestBody: "test request body",
|
||||
expectedHandler: "handler",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
var receivedBody string
|
||||
|
||||
maxBody := int64(-1)
|
||||
failover, err := New(&dynamic.Failover{
|
||||
Errors: &dynamic.FailoverError{
|
||||
Status: test.statusCode,
|
||||
MaxRequestBodyBytes: &maxBody,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
failover.SetHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
body, _ := io.ReadAll(req.Body)
|
||||
receivedBody = string(body)
|
||||
rw.Header().Set("server", "handler")
|
||||
rw.WriteHeader(test.handlerStatusCode)
|
||||
}))
|
||||
|
||||
failover.SetFallbackHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
body, _ := io.ReadAll(req.Body)
|
||||
receivedBody = string(body)
|
||||
rw.Header().Set("server", "fallback")
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewBufferString(test.requestBody))
|
||||
failover.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, test.expectedHandler, recorder.Header().Get("server"))
|
||||
assert.Equal(t, test.requestBody, receivedBody)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFailoverStatusCodeMaxBodySize(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
maxBodySize int64
|
||||
requestBody string
|
||||
expectedStatusCode int
|
||||
expectedMessage string
|
||||
}{
|
||||
{
|
||||
desc: "request body within limit",
|
||||
maxBodySize: 100,
|
||||
requestBody: "small body",
|
||||
expectedStatusCode: 200,
|
||||
expectedMessage: "",
|
||||
},
|
||||
{
|
||||
desc: "request body exceeds limit",
|
||||
maxBodySize: 5,
|
||||
requestBody: "this body is too large",
|
||||
expectedStatusCode: http.StatusRequestEntityTooLarge,
|
||||
expectedMessage: "Request body too large\n",
|
||||
},
|
||||
{
|
||||
desc: "zero body size limit",
|
||||
maxBodySize: 0,
|
||||
requestBody: "any size body",
|
||||
expectedStatusCode: http.StatusRequestEntityTooLarge,
|
||||
expectedMessage: "Request body too large\n",
|
||||
},
|
||||
{
|
||||
desc: "no body size limit",
|
||||
maxBodySize: -1,
|
||||
requestBody: "any size body should work",
|
||||
expectedStatusCode: 200,
|
||||
expectedMessage: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
maxBody := test.maxBodySize
|
||||
failover, err := New(&dynamic.Failover{
|
||||
Errors: &dynamic.FailoverError{
|
||||
Status: []string{"503"},
|
||||
MaxRequestBodyBytes: &maxBody,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
failover.SetHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.WriteHeader(http.StatusServiceUnavailable)
|
||||
}))
|
||||
|
||||
failover.SetFallbackHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewBufferString(test.requestBody))
|
||||
req.ContentLength = int64(len(test.requestBody))
|
||||
failover.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, test.expectedStatusCode, recorder.Code)
|
||||
if test.expectedMessage != "" {
|
||||
assert.Equal(t, test.expectedMessage, recorder.Body.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFailoverStatusCodeWithHealthCheck(t *testing.T) {
|
||||
failover, err := New(&dynamic.Failover{
|
||||
HealthCheck: &dynamic.HealthCheck{},
|
||||
Errors: &dynamic.FailoverError{
|
||||
Status: []string{"503"},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
failover.SetHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.Header().Set("server", "handler")
|
||||
rw.WriteHeader(http.StatusServiceUnavailable)
|
||||
}))
|
||||
|
||||
failover.SetFallbackHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.Header().Set("server", "fallback")
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// Test 1: Handler is up, returns 503, should failover based on status code
|
||||
recorder := httptest.NewRecorder()
|
||||
failover.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil))
|
||||
|
||||
assert.Equal(t, "fallback", recorder.Header().Get("server"))
|
||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||
|
||||
// Test 2: Handler is marked down via health check, should failover
|
||||
failover.SetHandlerStatus(t.Context(), false)
|
||||
|
||||
recorder = httptest.NewRecorder()
|
||||
failover.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil))
|
||||
|
||||
assert.Equal(t, "fallback", recorder.Header().Get("server"))
|
||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||
|
||||
// Test 3: Handler is back up but returns non-503 status, should not failover
|
||||
failover.SetHandlerStatus(t.Context(), true)
|
||||
failover.SetHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.Header().Set("server", "handler")
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
recorder = httptest.NewRecorder()
|
||||
failover.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil))
|
||||
|
||||
assert.Equal(t, "handler", recorder.Header().Get("server"))
|
||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||
}
|
||||
|
||||
func TestFailoverInvalidStatusCodeRange(t *testing.T) {
|
||||
_, err := New(&dynamic.Failover{
|
||||
Errors: &dynamic.FailoverError{
|
||||
Status: []string{"invalid"},
|
||||
},
|
||||
})
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "strconv")
|
||||
}
|
||||
|
|
|
|||
77
pkg/server/service/loadbalancer/failover/response_writer.go
Normal file
77
pkg/server/service/loadbalancer/failover/response_writer.go
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
package failover
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/traefik/traefik/v3/pkg/types"
|
||||
)
|
||||
|
||||
type responseWriter struct {
|
||||
http.ResponseWriter
|
||||
|
||||
needFallback bool
|
||||
written bool
|
||||
header http.Header
|
||||
statusCodeRange types.HTTPCodeRanges
|
||||
}
|
||||
|
||||
func (r *responseWriter) Write(b []byte) (int, error) {
|
||||
if !r.written {
|
||||
r.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
if r.needFallback {
|
||||
// As we will fallback, we can discard the response body.
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
return r.ResponseWriter.Write(b)
|
||||
}
|
||||
|
||||
func (r *responseWriter) Header() http.Header {
|
||||
if r.header == nil {
|
||||
r.header = make(http.Header)
|
||||
}
|
||||
|
||||
return r.header
|
||||
}
|
||||
|
||||
func (r *responseWriter) WriteHeader(statusCode int) {
|
||||
if statusCode >= 100 && statusCode <= 199 && statusCode != http.StatusSwitchingProtocols {
|
||||
clear(r.header)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if !r.written {
|
||||
r.written = true
|
||||
r.needFallback = r.statusCodeRange.Contains(statusCode)
|
||||
|
||||
if !r.needFallback {
|
||||
for k, v := range r.header {
|
||||
for _, vv := range v {
|
||||
r.ResponseWriter.Header().Add(k, vv)
|
||||
}
|
||||
}
|
||||
|
||||
r.ResponseWriter.WriteHeader(statusCode)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *responseWriter) Flush() {
|
||||
if flusher, ok := r.ResponseWriter.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
func (r *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
if h, ok := r.ResponseWriter.(http.Hijacker); ok {
|
||||
return h.Hijack()
|
||||
}
|
||||
|
||||
return nil, nil, fmt.Errorf("not a hijacker: %T", r.ResponseWriter)
|
||||
}
|
||||
|
|
@ -54,6 +54,10 @@ type mirrorHandler struct {
|
|||
count uint64
|
||||
}
|
||||
|
||||
type clonableRequest interface {
|
||||
Clone(ctx context.Context) *http.Request
|
||||
}
|
||||
|
||||
func (m *Mirroring) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
mirrors := m.getActiveMirrors()
|
||||
if len(mirrors) == 0 {
|
||||
|
|
@ -62,21 +66,28 @@ func (m *Mirroring) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
|||
}
|
||||
|
||||
logger := log.Ctx(req.Context())
|
||||
rr, bytesRead, err := newReusableRequest(req, m.mirrorBody, m.maxBodySize)
|
||||
if err != nil && !errors.Is(err, errBodyTooLarge) {
|
||||
http.Error(rw, fmt.Sprintf("%s: creating reusable request: %v",
|
||||
http.StatusText(http.StatusInternalServerError), err), http.StatusInternalServerError)
|
||||
return
|
||||
|
||||
var rr clonableRequest = req
|
||||
|
||||
if m.mirrorBody {
|
||||
var err error
|
||||
var bytesRead []byte
|
||||
rr, bytesRead, err = NewReusableRequest(req, m.maxBodySize)
|
||||
if err != nil && !errors.Is(err, ErrBodyTooLarge) {
|
||||
logger.Debug().Err(err).Msg("Error while creating reusable request for mirroring")
|
||||
http.Error(rw, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if errors.Is(err, ErrBodyTooLarge) {
|
||||
req.Body = io.NopCloser(io.MultiReader(bytes.NewReader(bytesRead), req.Body))
|
||||
m.handler.ServeHTTP(rw, req)
|
||||
logger.Debug().Msg("No mirroring, request body larger than allowed size")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if errors.Is(err, errBodyTooLarge) {
|
||||
req.Body = io.NopCloser(io.MultiReader(bytes.NewReader(bytesRead), req.Body))
|
||||
m.handler.ServeHTTP(rw, req)
|
||||
logger.Debug().Msg("No mirroring, request body larger than allowed size")
|
||||
return
|
||||
}
|
||||
|
||||
m.handler.ServeHTTP(rw, rr.clone(req.Context()))
|
||||
m.handler.ServeHTTP(rw, rr.Clone(req.Context()))
|
||||
|
||||
select {
|
||||
case <-req.Context().Done():
|
||||
|
|
@ -89,7 +100,7 @@ func (m *Mirroring) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
|||
m.routinePool.GoCtx(func(_ context.Context) {
|
||||
for _, handler := range mirrors {
|
||||
// prepare request, update body from buffer
|
||||
r := rr.clone(req.Context())
|
||||
r := rr.Clone(req.Context())
|
||||
|
||||
// In ServeHTTP, we rely on the presence of the accessLog datatable found in the request's context
|
||||
// to know whether we should mutate said datatable (and contribute some fields to the log).
|
||||
|
|
@ -192,23 +203,23 @@ func (c contextStopPropagation) Done() <-chan struct{} {
|
|||
return make(chan struct{})
|
||||
}
|
||||
|
||||
// reusableRequest keeps in memory the body of the given request,
|
||||
// ReusableRequest keeps in memory the body of the given request,
|
||||
// so that the request can be fully cloned by each mirror.
|
||||
type reusableRequest struct {
|
||||
type ReusableRequest struct {
|
||||
req *http.Request
|
||||
body []byte
|
||||
}
|
||||
|
||||
var errBodyTooLarge = errors.New("request body too large")
|
||||
var ErrBodyTooLarge = errors.New("request body too large")
|
||||
|
||||
// if the returned error is errBodyTooLarge, newReusableRequest also returns the
|
||||
// bytes that were already consumed from the request's body.
|
||||
func newReusableRequest(req *http.Request, mirrorBody bool, maxBodySize int64) (*reusableRequest, []byte, error) {
|
||||
// NewReusableRequest returns a new reusable request. If the returned error is ErrBodyTooLarge, NewReusableRequest also returns the
|
||||
// bytes already consumed from the request's body.
|
||||
func NewReusableRequest(req *http.Request, maxBodySize int64) (*ReusableRequest, []byte, error) {
|
||||
if req == nil {
|
||||
return nil, nil, errors.New("nil input request")
|
||||
}
|
||||
if req.Body == nil || req.ContentLength == 0 || !mirrorBody {
|
||||
return &reusableRequest{req: req}, nil, nil
|
||||
if req.Body == nil || req.ContentLength == 0 {
|
||||
return &ReusableRequest{req: req}, nil, nil
|
||||
}
|
||||
|
||||
// unbounded body size
|
||||
|
|
@ -217,7 +228,7 @@ func newReusableRequest(req *http.Request, mirrorBody bool, maxBodySize int64) (
|
|||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return &reusableRequest{
|
||||
return &ReusableRequest{
|
||||
req: req,
|
||||
body: body,
|
||||
}, nil, nil
|
||||
|
|
@ -234,17 +245,18 @@ func newReusableRequest(req *http.Request, mirrorBody bool, maxBodySize int64) (
|
|||
// we got an ErrUnexpectedEOF, which means there was less than maxBodySize data to read,
|
||||
// which permits us sending also to all the mirrors later.
|
||||
if errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
return &reusableRequest{
|
||||
return &ReusableRequest{
|
||||
req: req,
|
||||
body: body[:n],
|
||||
}, nil, nil
|
||||
}
|
||||
|
||||
// err == nil , which means data size > maxBodySize
|
||||
return nil, body[:n], errBodyTooLarge
|
||||
return nil, body[:n], ErrBodyTooLarge
|
||||
}
|
||||
|
||||
func (rr reusableRequest) clone(ctx context.Context) *http.Request {
|
||||
// Clone clones the request.
|
||||
func (rr ReusableRequest) Clone(ctx context.Context) *http.Request {
|
||||
req := rr.req.Clone(ctx)
|
||||
|
||||
if rr.body != nil {
|
||||
|
|
|
|||
|
|
@ -224,16 +224,16 @@ func TestCloneRequest(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
|
||||
ctx := req.Context()
|
||||
rr, _, err := newReusableRequest(req, true, defaultMaxBodySize)
|
||||
rr, _, err := NewReusableRequest(req, defaultMaxBodySize)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// first call
|
||||
cloned := rr.clone(ctx)
|
||||
cloned := rr.Clone(ctx)
|
||||
assert.Equal(t, cloned, req)
|
||||
assert.Nil(t, cloned.Body)
|
||||
|
||||
// second call
|
||||
cloned = rr.clone(ctx)
|
||||
cloned = rr.Clone(ctx)
|
||||
assert.Equal(t, cloned, req)
|
||||
assert.Nil(t, cloned.Body)
|
||||
})
|
||||
|
|
@ -249,17 +249,17 @@ func TestCloneRequest(t *testing.T) {
|
|||
ctx := req.Context()
|
||||
req.ContentLength = int64(contentLength)
|
||||
|
||||
rr, _, err := newReusableRequest(req, true, defaultMaxBodySize)
|
||||
rr, _, err := NewReusableRequest(req, defaultMaxBodySize)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// first call
|
||||
cloned := rr.clone(ctx)
|
||||
cloned := rr.Clone(ctx)
|
||||
body, err := io.ReadAll(cloned.Body)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, bb, body)
|
||||
|
||||
// second call
|
||||
cloned = rr.clone(ctx)
|
||||
cloned = rr.Clone(ctx)
|
||||
body, err = io.ReadAll(cloned.Body)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, bb, body)
|
||||
|
|
@ -272,7 +272,7 @@ func TestCloneRequest(t *testing.T) {
|
|||
req, err := http.NewRequest(http.MethodPost, "/", buf)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, expectedBytes, err := newReusableRequest(req, true, 2)
|
||||
_, expectedBytes, err := NewReusableRequest(req, 2)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, expectedBytes, bb[:3])
|
||||
})
|
||||
|
|
@ -284,7 +284,7 @@ func TestCloneRequest(t *testing.T) {
|
|||
req, err := http.NewRequest(http.MethodPost, "/", buf)
|
||||
assert.NoError(t, err)
|
||||
|
||||
rr, expectedBytes, err := newReusableRequest(req, true, 20)
|
||||
rr, expectedBytes, err := NewReusableRequest(req, 20)
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, expectedBytes)
|
||||
assert.Len(t, rr.body, 10)
|
||||
|
|
@ -296,14 +296,14 @@ func TestCloneRequest(t *testing.T) {
|
|||
req, err := http.NewRequest(http.MethodGet, "/", buf)
|
||||
assert.NoError(t, err)
|
||||
|
||||
rr, expectedBytes, err := newReusableRequest(req, true, 20)
|
||||
rr, expectedBytes, err := NewReusableRequest(req, 20)
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, expectedBytes)
|
||||
assert.Empty(t, rr.body)
|
||||
})
|
||||
|
||||
t.Run("no request given", func(t *testing.T) {
|
||||
_, _, err := newReusableRequest(nil, true, defaultMaxBodySize)
|
||||
_, _, err := NewReusableRequest(nil, defaultMaxBodySize)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -210,24 +210,29 @@ func (m *Manager) LaunchHealthCheck(ctx context.Context) {
|
|||
}
|
||||
|
||||
func (m *Manager) getFailoverServiceHandler(ctx context.Context, serviceName string, config *dynamic.Failover) (http.Handler, error) {
|
||||
f := failover.New(config.HealthCheck)
|
||||
|
||||
serviceHandler, err := m.BuildHTTP(ctx, config.Service)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f.SetHandler(serviceHandler)
|
||||
|
||||
updater, ok := serviceHandler.(healthcheck.StatusUpdater)
|
||||
if !ok {
|
||||
updater, implementUpdater := serviceHandler.(healthcheck.StatusUpdater)
|
||||
isErrorDefined := config.Errors != nil && len(config.Errors.Status) > 0
|
||||
if !implementUpdater && !isErrorDefined {
|
||||
return nil, fmt.Errorf("child service %v of %v not a healthcheck.StatusUpdater (%T)", config.Service, serviceName, serviceHandler)
|
||||
}
|
||||
|
||||
if err := updater.RegisterStatusUpdater(func(up bool) {
|
||||
f.SetHandlerStatus(ctx, up)
|
||||
}); err != nil {
|
||||
return nil, fmt.Errorf("cannot register %v as updater for %v: %w", config.Service, serviceName, err)
|
||||
f, err := failover.New(config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating failover service %v: %w", serviceName, err)
|
||||
}
|
||||
f.SetHandler(serviceHandler)
|
||||
|
||||
if implementUpdater {
|
||||
if err := updater.RegisterStatusUpdater(func(up bool) {
|
||||
f.SetHandlerStatus(ctx, up)
|
||||
}); err != nil && !isErrorDefined {
|
||||
return nil, fmt.Errorf("cannot register %v as updater for %v: %w", config.Service, serviceName, err)
|
||||
}
|
||||
}
|
||||
|
||||
fallbackHandler, err := m.BuildHTTP(ctx, config.Fallback)
|
||||
|
|
@ -242,8 +247,8 @@ func (m *Manager) getFailoverServiceHandler(ctx context.Context, serviceName str
|
|||
return f, nil
|
||||
}
|
||||
|
||||
fallbackUpdater, ok := fallbackHandler.(healthcheck.StatusUpdater)
|
||||
if !ok {
|
||||
fallbackUpdater, implementUpdater := fallbackHandler.(healthcheck.StatusUpdater)
|
||||
if !implementUpdater {
|
||||
return nil, fmt.Errorf("child service %v of %v not a healthcheck.StatusUpdater (%T)", config.Fallback, serviceName, fallbackHandler)
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue