mirror of
https://github.com/hashicorp/vault.git
synced 2026-06-08 16:24:51 -04:00
X-Forwarded-For (#4380)
This commit is contained in:
parent
f7e886f29d
commit
80b17705a9
8 changed files with 493 additions and 42 deletions
|
|
@ -32,6 +32,7 @@ import (
|
|||
"github.com/hashicorp/errwrap"
|
||||
log "github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
sockaddr "github.com/hashicorp/go-sockaddr"
|
||||
"github.com/hashicorp/vault/audit"
|
||||
"github.com/hashicorp/vault/command/server"
|
||||
"github.com/hashicorp/vault/helper/gated-writer"
|
||||
|
|
@ -92,6 +93,11 @@ type ServerCommand struct {
|
|||
flagTestVerifyOnly bool
|
||||
}
|
||||
|
||||
type ServerListener struct {
|
||||
net.Listener
|
||||
config map[string]interface{}
|
||||
}
|
||||
|
||||
func (c *ServerCommand) Synopsis() string {
|
||||
return "Start a Vault server"
|
||||
}
|
||||
|
|
@ -670,8 +676,8 @@ CLUSTER_SYNTHESIS_COMPLETE:
|
|||
clusterAddrs := []*net.TCPAddr{}
|
||||
|
||||
// Initialize the listeners
|
||||
lns := make([]ServerListener, 0, len(config.Listeners))
|
||||
c.reloadFuncsLock.Lock()
|
||||
lns := make([]net.Listener, 0, len(config.Listeners))
|
||||
for i, lnConfig := range config.Listeners {
|
||||
ln, props, reloadFunc, err := server.NewListener(lnConfig.Type, lnConfig.Config, c.logGate, c.UI)
|
||||
if err != nil {
|
||||
|
|
@ -679,7 +685,10 @@ CLUSTER_SYNTHESIS_COMPLETE:
|
|||
return 1
|
||||
}
|
||||
|
||||
lns = append(lns, ln)
|
||||
lns = append(lns, ServerListener{
|
||||
Listener: ln,
|
||||
config: lnConfig.Config,
|
||||
})
|
||||
|
||||
if reloadFunc != nil {
|
||||
relSlice := (*c.reloadFuncs)["listener|"+lnConfig.Type]
|
||||
|
|
@ -738,7 +747,7 @@ CLUSTER_SYNTHESIS_COMPLETE:
|
|||
// Make sure we close all listeners from this point on
|
||||
listenerCloseFunc := func() {
|
||||
for _, ln := range lns {
|
||||
ln.Close()
|
||||
ln.Listener.Close()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -776,12 +785,10 @@ CLUSTER_SYNTHESIS_COMPLETE:
|
|||
return 0
|
||||
}
|
||||
|
||||
handler := vaulthttp.Handler(core)
|
||||
|
||||
// This needs to happen before we first unseal, so before we trigger dev
|
||||
// mode if it's set
|
||||
core.SetClusterListenerAddrs(clusterAddrs)
|
||||
core.SetClusterHandler(handler)
|
||||
core.SetClusterHandler(vaulthttp.Handler(core))
|
||||
|
||||
err = core.UnsealWithStoredKeys(context.Background())
|
||||
if err != nil {
|
||||
|
|
@ -914,10 +921,23 @@ CLUSTER_SYNTHESIS_COMPLETE:
|
|||
|
||||
// Initialize the HTTP servers
|
||||
for _, ln := range lns {
|
||||
handler := vaulthttp.Handler(core)
|
||||
|
||||
// We perform validation on the config earlier, we can just cast here
|
||||
if _, ok := ln.config["x_forwarded_for_authorized_addrs"]; ok {
|
||||
hopSkips := ln.config["x_forwarded_for_hop_skips"].(int)
|
||||
authzdAddrs := ln.config["x_forwarded_for_authorized_addrs"].([]*sockaddr.SockAddrMarshaler)
|
||||
rejectNotPresent := ln.config["x_forwarded_for_reject_not_present"].(bool)
|
||||
rejectNonAuthz := ln.config["x_forwarded_for_reject_not_authorized"].(bool)
|
||||
if len(authzdAddrs) > 0 {
|
||||
handler = vaulthttp.WrapForwardedForHandler(handler, authzdAddrs, rejectNotPresent, rejectNonAuthz, hopSkips)
|
||||
}
|
||||
}
|
||||
|
||||
server := &http.Server{
|
||||
Handler: handler,
|
||||
}
|
||||
go server.Serve(ln)
|
||||
go server.Serve(ln.Listener)
|
||||
}
|
||||
|
||||
if newCoreError != nil {
|
||||
|
|
|
|||
|
|
@ -761,6 +761,10 @@ func parseListeners(result *Config, list *ast.ObjectList) error {
|
|||
"address",
|
||||
"cluster_address",
|
||||
"endpoint",
|
||||
"x_forwarded_for_authorized_addrs",
|
||||
"x_forwarded_for_hop_skips",
|
||||
"x_forwarded_for_reject_not_authorized",
|
||||
"x_forwarded_for_reject_not_present",
|
||||
"infrastructure",
|
||||
"node_id",
|
||||
"proxy_protocol_behavior",
|
||||
|
|
|
|||
|
|
@ -1,11 +1,15 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/errwrap"
|
||||
"github.com/hashicorp/vault/helper/parseutil"
|
||||
"github.com/hashicorp/vault/helper/reload"
|
||||
"github.com/mitchellh/cli"
|
||||
)
|
||||
|
|
@ -39,6 +43,57 @@ func tcpListenerFactory(config map[string]interface{}, _ io.Writer, ui cli.Ui) (
|
|||
}
|
||||
|
||||
props := map[string]string{"addr": addr}
|
||||
|
||||
ffAllowedRaw, ffAllowedOK := config["x_forwarded_for_authorized_addrs"]
|
||||
if ffAllowedOK {
|
||||
ffAllowed, err := parseutil.ParseAddrs(ffAllowedRaw)
|
||||
if err != nil {
|
||||
return nil, nil, nil, errwrap.Wrapf("error parsing \"x_forwarded_for_authorized_addrs\": {{err}}", err)
|
||||
}
|
||||
props["x_forwarded_for_authorized_addrs"] = fmt.Sprintf("%v", ffAllowed)
|
||||
config["x_forwarded_for_authorized_addrs"] = ffAllowed
|
||||
}
|
||||
|
||||
if ffHopsRaw, ok := config["x_forwarded_for_hop_skips"]; ok {
|
||||
ffHops64, err := parseutil.ParseInt(ffHopsRaw)
|
||||
if err != nil {
|
||||
return nil, nil, nil, errwrap.Wrapf("error parsing \"x_forwarded_for_hop_skips\": {{err}}", err)
|
||||
}
|
||||
if ffHops64 < 0 {
|
||||
return nil, nil, nil, fmt.Errorf("\"x_forwarded_for_hop_skips\" cannot be negative")
|
||||
}
|
||||
ffHops := int(ffHops64)
|
||||
props["x_forwarded_for_hop_skips"] = strconv.Itoa(ffHops)
|
||||
config["x_forwarded_for_hop_skips"] = ffHops
|
||||
} else if ffAllowedOK {
|
||||
props["x_forwarded_for_hop_skips"] = "0"
|
||||
config["x_forwarded_for_hop_skips"] = int(0)
|
||||
}
|
||||
|
||||
if ffRejectNotPresentRaw, ok := config["x_forwarded_for_reject_not_present"]; ok {
|
||||
ffRejectNotPresent, err := parseutil.ParseBool(ffRejectNotPresentRaw)
|
||||
if err != nil {
|
||||
return nil, nil, nil, errwrap.Wrapf("error parsing \"x_forwarded_for_reject_not_present\": {{err}}", err)
|
||||
}
|
||||
props["x_forwarded_for_reject_not_present"] = strconv.FormatBool(ffRejectNotPresent)
|
||||
config["x_forwarded_for_reject_not_present"] = ffRejectNotPresent
|
||||
} else if ffAllowedOK {
|
||||
props["x_forwarded_for_reject_not_present"] = "true"
|
||||
config["x_forwarded_for_reject_not_present"] = true
|
||||
}
|
||||
|
||||
if ffRejectNonAuthorizedRaw, ok := config["x_forwarded_for_reject_not_authorized"]; ok {
|
||||
ffRejectNonAuthorized, err := parseutil.ParseBool(ffRejectNonAuthorizedRaw)
|
||||
if err != nil {
|
||||
return nil, nil, nil, errwrap.Wrapf("error parsing \"x_forwarded_for_reject_not_authorized\": {{err}}", err)
|
||||
}
|
||||
props["x_forwarded_for_reject_not_authorized"] = strconv.FormatBool(ffRejectNonAuthorized)
|
||||
config["x_forwarded_for_reject_not_authorized"] = ffRejectNonAuthorized
|
||||
} else if ffAllowedOK {
|
||||
props["x_forwarded_for_reject_not_authorized"] = "true"
|
||||
config["x_forwarded_for_reject_not_authorized"] = true
|
||||
}
|
||||
|
||||
return listenerWrapTLS(ln, props, config, ui)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -3,10 +3,13 @@ package parseutil
|
|||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/errwrap"
|
||||
sockaddr "github.com/hashicorp/go-sockaddr"
|
||||
"github.com/hashicorp/vault/helper/strutil"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
)
|
||||
|
|
@ -118,3 +121,43 @@ func ParseCommaStringSlice(in interface{}) ([]string, error) {
|
|||
}
|
||||
return strutil.TrimStrings(result), nil
|
||||
}
|
||||
|
||||
func ParseAddrs(addrs interface{}) ([]*sockaddr.SockAddrMarshaler, error) {
|
||||
out := make([]*sockaddr.SockAddrMarshaler, 0)
|
||||
stringAddrs := make([]string, 0)
|
||||
|
||||
switch addrs.(type) {
|
||||
case string:
|
||||
stringAddrs = strutil.ParseArbitraryStringSlice(addrs.(string), ",")
|
||||
if len(stringAddrs) == 0 {
|
||||
return nil, fmt.Errorf("unable to parse addresses from %v", addrs)
|
||||
}
|
||||
|
||||
case []string:
|
||||
stringAddrs = addrs.([]string)
|
||||
|
||||
case []interface{}:
|
||||
for _, v := range addrs.([]interface{}) {
|
||||
stringAddr, ok := v.(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("error parsing %v as string", v)
|
||||
}
|
||||
stringAddrs = append(stringAddrs, stringAddr)
|
||||
}
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown address input type %T", addrs)
|
||||
}
|
||||
|
||||
for _, addr := range stringAddrs {
|
||||
sa, err := sockaddr.NewSockAddr(addr)
|
||||
if err != nil {
|
||||
return nil, errwrap.Wrapf(fmt.Sprintf("error parsing address %q: {{err}}", addr), err)
|
||||
}
|
||||
out = append(out, &sockaddr.SockAddrMarshaler{
|
||||
SockAddr: sa,
|
||||
})
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ import (
|
|||
proxyproto "github.com/armon/go-proxyproto"
|
||||
"github.com/hashicorp/errwrap"
|
||||
sockaddr "github.com/hashicorp/go-sockaddr"
|
||||
"github.com/hashicorp/vault/helper/strutil"
|
||||
"github.com/hashicorp/vault/helper/parseutil"
|
||||
)
|
||||
|
||||
// ProxyProtoConfig contains configuration for the PROXY protocol
|
||||
|
|
@ -19,42 +19,12 @@ type ProxyProtoConfig struct {
|
|||
}
|
||||
|
||||
func (p *ProxyProtoConfig) SetAuthorizedAddrs(addrs interface{}) error {
|
||||
p.AuthorizedAddrs = make([]*sockaddr.SockAddrMarshaler, 0)
|
||||
stringAddrs := make([]string, 0)
|
||||
|
||||
switch addrs.(type) {
|
||||
case string:
|
||||
stringAddrs = strutil.ParseArbitraryStringSlice(addrs.(string), ",")
|
||||
if len(stringAddrs) == 0 {
|
||||
return fmt.Errorf("unable to parse addresses from %v", addrs)
|
||||
}
|
||||
|
||||
case []string:
|
||||
stringAddrs = addrs.([]string)
|
||||
|
||||
case []interface{}:
|
||||
for _, v := range addrs.([]interface{}) {
|
||||
stringAddr, ok := v.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("error parsing %v as string", v)
|
||||
}
|
||||
stringAddrs = append(stringAddrs, stringAddr)
|
||||
}
|
||||
|
||||
default:
|
||||
return fmt.Errorf("unknown address input type %T", addrs)
|
||||
}
|
||||
|
||||
for _, addr := range stringAddrs {
|
||||
sa, err := sockaddr.NewSockAddr(addr)
|
||||
if err != nil {
|
||||
return errwrap.Wrapf("error parsing authorized address: {{err}}", err)
|
||||
}
|
||||
p.AuthorizedAddrs = append(p.AuthorizedAddrs, &sockaddr.SockAddrMarshaler{
|
||||
SockAddr: sa,
|
||||
})
|
||||
aa, err := parseutil.ParseAddrs(addrs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.AuthorizedAddrs = aa
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
|||
249
http/forwarded_for_test.go
Normal file
249
http/forwarded_for_test.go
Normal file
|
|
@ -0,0 +1,249 @@
|
|||
package http
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
sockaddr "github.com/hashicorp/go-sockaddr"
|
||||
"github.com/hashicorp/vault/vault"
|
||||
)
|
||||
|
||||
func TestHandler_XForwardedFor(t *testing.T) {
|
||||
goodAddr, err := sockaddr.NewIPAddr("127.0.0.1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
badAddr, err := sockaddr.NewIPAddr("1.2.3.4")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// First: test reject not present
|
||||
t.Run("reject_not_present", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
testHandler := func(c *vault.Core) http.Handler {
|
||||
origHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(r.RemoteAddr))
|
||||
})
|
||||
return WrapForwardedForHandler(origHandler, []*sockaddr.SockAddrMarshaler{
|
||||
&sockaddr.SockAddrMarshaler{
|
||||
SockAddr: goodAddr,
|
||||
},
|
||||
}, true, false, 0)
|
||||
}
|
||||
|
||||
cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{
|
||||
HandlerFunc: testHandler,
|
||||
})
|
||||
cluster.Start()
|
||||
defer cluster.Cleanup()
|
||||
client := cluster.Cores[0].Client
|
||||
|
||||
req := client.NewRequest("GET", "/")
|
||||
_, err = client.RawRequest(req)
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "missing x-forwarded-for") {
|
||||
t.Fatalf("bad error message: %v", err)
|
||||
}
|
||||
req = client.NewRequest("GET", "/")
|
||||
req.Headers = make(http.Header)
|
||||
req.Headers.Set("x-forwarded-for", "1.2.3.4")
|
||||
resp, err := client.RawRequest(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
buf := bytes.NewBuffer(nil)
|
||||
buf.ReadFrom(resp.Body)
|
||||
if !strings.HasPrefix(buf.String(), "1.2.3.4:") {
|
||||
t.Fatalf("bad body: %s", buf.String())
|
||||
}
|
||||
})
|
||||
|
||||
// Next: test allow unauth
|
||||
t.Run("allow_unauth", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
testHandler := func(c *vault.Core) http.Handler {
|
||||
origHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(r.RemoteAddr))
|
||||
})
|
||||
return WrapForwardedForHandler(origHandler, []*sockaddr.SockAddrMarshaler{
|
||||
&sockaddr.SockAddrMarshaler{
|
||||
SockAddr: badAddr,
|
||||
},
|
||||
}, true, false, 0)
|
||||
}
|
||||
|
||||
cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{
|
||||
HandlerFunc: testHandler,
|
||||
})
|
||||
cluster.Start()
|
||||
defer cluster.Cleanup()
|
||||
client := cluster.Cores[0].Client
|
||||
|
||||
req := client.NewRequest("GET", "/")
|
||||
req.Headers = make(http.Header)
|
||||
req.Headers.Set("x-forwarded-for", "5.6.7.8")
|
||||
resp, err := client.RawRequest(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
buf := bytes.NewBuffer(nil)
|
||||
buf.ReadFrom(resp.Body)
|
||||
if !strings.HasPrefix(buf.String(), "127.0.0.1:") {
|
||||
t.Fatalf("bad body: %s", buf.String())
|
||||
}
|
||||
})
|
||||
|
||||
// Next: test fail unauth
|
||||
t.Run("fail_unauth", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
testHandler := func(c *vault.Core) http.Handler {
|
||||
origHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(r.RemoteAddr))
|
||||
})
|
||||
return WrapForwardedForHandler(origHandler, []*sockaddr.SockAddrMarshaler{
|
||||
&sockaddr.SockAddrMarshaler{
|
||||
SockAddr: badAddr,
|
||||
},
|
||||
}, true, true, 0)
|
||||
}
|
||||
|
||||
cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{
|
||||
HandlerFunc: testHandler,
|
||||
})
|
||||
cluster.Start()
|
||||
defer cluster.Cleanup()
|
||||
client := cluster.Cores[0].Client
|
||||
|
||||
req := client.NewRequest("GET", "/")
|
||||
req.Headers = make(http.Header)
|
||||
req.Headers.Set("x-forwarded-for", "5.6.7.8")
|
||||
_, err = client.RawRequest(req)
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "not authorized for x-forwarded-for") {
|
||||
t.Fatalf("bad error message: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
// Next: test bad hops (too many)
|
||||
t.Run("too_many_hops", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
testHandler := func(c *vault.Core) http.Handler {
|
||||
origHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(r.RemoteAddr))
|
||||
})
|
||||
return WrapForwardedForHandler(origHandler, []*sockaddr.SockAddrMarshaler{
|
||||
&sockaddr.SockAddrMarshaler{
|
||||
SockAddr: goodAddr,
|
||||
},
|
||||
}, true, true, 4)
|
||||
}
|
||||
|
||||
cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{
|
||||
HandlerFunc: testHandler,
|
||||
})
|
||||
cluster.Start()
|
||||
defer cluster.Cleanup()
|
||||
client := cluster.Cores[0].Client
|
||||
|
||||
req := client.NewRequest("GET", "/")
|
||||
req.Headers = make(http.Header)
|
||||
req.Headers.Set("x-forwarded-for", "2.3.4.5,3.4.5.6")
|
||||
_, err = client.RawRequest(req)
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "would skip before earliest") {
|
||||
t.Fatalf("bad error message: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
// Next: test picking correct value
|
||||
t.Run("correct_hop_skipping", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
testHandler := func(c *vault.Core) http.Handler {
|
||||
origHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(r.RemoteAddr))
|
||||
})
|
||||
return WrapForwardedForHandler(origHandler, []*sockaddr.SockAddrMarshaler{
|
||||
&sockaddr.SockAddrMarshaler{
|
||||
SockAddr: goodAddr,
|
||||
},
|
||||
}, true, true, 1)
|
||||
}
|
||||
|
||||
cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{
|
||||
HandlerFunc: testHandler,
|
||||
})
|
||||
cluster.Start()
|
||||
defer cluster.Cleanup()
|
||||
client := cluster.Cores[0].Client
|
||||
|
||||
req := client.NewRequest("GET", "/")
|
||||
req.Headers = make(http.Header)
|
||||
req.Headers.Set("x-forwarded-for", "2.3.4.5,3.4.5.6,4.5.6.7,5.6.7.8")
|
||||
resp, err := client.RawRequest(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
buf := bytes.NewBuffer(nil)
|
||||
buf.ReadFrom(resp.Body)
|
||||
if !strings.HasPrefix(buf.String(), "4.5.6.7:") {
|
||||
t.Fatalf("bad body: %s", buf.String())
|
||||
}
|
||||
})
|
||||
|
||||
// Next: multi-header approach
|
||||
t.Run("correct_hop_skipping_multi_header", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
testHandler := func(c *vault.Core) http.Handler {
|
||||
origHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(r.RemoteAddr))
|
||||
})
|
||||
return WrapForwardedForHandler(origHandler, []*sockaddr.SockAddrMarshaler{
|
||||
&sockaddr.SockAddrMarshaler{
|
||||
SockAddr: goodAddr,
|
||||
},
|
||||
}, true, true, 1)
|
||||
}
|
||||
|
||||
cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{
|
||||
HandlerFunc: testHandler,
|
||||
})
|
||||
cluster.Start()
|
||||
defer cluster.Cleanup()
|
||||
client := cluster.Cores[0].Client
|
||||
|
||||
req := client.NewRequest("GET", "/")
|
||||
req.Headers = make(http.Header)
|
||||
req.Headers.Add("x-forwarded-for", "2.3.4.5")
|
||||
req.Headers.Add("x-forwarded-for", "3.4.5.6,4.5.6.7")
|
||||
req.Headers.Add("x-forwarded-for", "5.6.7.8")
|
||||
resp, err := client.RawRequest(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
buf := bytes.NewBuffer(nil)
|
||||
buf.ReadFrom(resp.Body)
|
||||
if !strings.HasPrefix(buf.String(), "4.5.6.7:") {
|
||||
t.Fatalf("bad body: %s", buf.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
@ -4,7 +4,9 @@ import (
|
|||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
|
|
@ -13,6 +15,7 @@ import (
|
|||
"github.com/elazarl/go-bindata-assetfs"
|
||||
"github.com/hashicorp/errwrap"
|
||||
cleanhttp "github.com/hashicorp/go-cleanhttp"
|
||||
sockaddr "github.com/hashicorp/go-sockaddr"
|
||||
"github.com/hashicorp/vault/helper/consts"
|
||||
"github.com/hashicorp/vault/helper/jsonutil"
|
||||
"github.com/hashicorp/vault/helper/parseutil"
|
||||
|
|
@ -124,6 +127,94 @@ func wrapGenericHandler(h http.Handler) http.Handler {
|
|||
})
|
||||
}
|
||||
|
||||
func WrapForwardedForHandler(h http.Handler, authorizedAddrs []*sockaddr.SockAddrMarshaler, rejectNotPresent, rejectNonAuthz bool, hopSkips int) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
headers, headersOK := r.Header[textproto.CanonicalMIMEHeaderKey("X-Forwarded-For")]
|
||||
if !headersOK || len(headers) == 0 {
|
||||
if !rejectNotPresent {
|
||||
h.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
respondError(w, http.StatusBadRequest, fmt.Errorf("missing x-forwarded-for header and configured to reject when not present"))
|
||||
return
|
||||
}
|
||||
|
||||
host, port, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
// If not rejecting treat it like we just don't have a valid
|
||||
// header because we can't do a comparison against an address we
|
||||
// can't understand
|
||||
if !rejectNotPresent {
|
||||
h.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
respondError(w, http.StatusBadRequest, errwrap.Wrapf("error parsing client hostport: {{err}}", err))
|
||||
return
|
||||
}
|
||||
|
||||
addr, err := sockaddr.NewIPAddr(host)
|
||||
if err != nil {
|
||||
// We treat this the same as the case above
|
||||
if !rejectNotPresent {
|
||||
h.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
respondError(w, http.StatusBadRequest, errwrap.Wrapf("error parsing client address: {{err}}", err))
|
||||
return
|
||||
}
|
||||
|
||||
var found bool
|
||||
for _, authz := range authorizedAddrs {
|
||||
if authz.Contains(addr) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
// If we didn't find it and aren't configured to reject, simply
|
||||
// don't trust it
|
||||
if !rejectNonAuthz {
|
||||
h.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
respondError(w, http.StatusBadRequest, fmt.Errorf("client address not authorized for x-forwarded-for and configured to reject connection"))
|
||||
return
|
||||
}
|
||||
|
||||
// At this point we have at least one value and it's authorized
|
||||
|
||||
// Split comma separated ones, which are common. This brings it in line
|
||||
// to the multiple-header case.
|
||||
var acc []string
|
||||
for _, header := range headers {
|
||||
vals := strings.Split(header, ",")
|
||||
for _, v := range vals {
|
||||
acc = append(acc, strings.TrimSpace(v))
|
||||
}
|
||||
}
|
||||
|
||||
indexToUse := len(acc) - 1 - hopSkips
|
||||
if indexToUse < 0 {
|
||||
// This is likely an error in either configuration or other
|
||||
// infrastructure. We could either deny the request, or we
|
||||
// could simply not trust the value. Denying the request is
|
||||
// "safer" since if this logic is configured at all there may
|
||||
// be an assumption it can always be trusted. Given that we can
|
||||
// deny accepting the request at all if it's not from an
|
||||
// authorized address, if we're at this point the address is
|
||||
// authorized (or we've turned off explicit rejection) and we
|
||||
// should assume that what comes in should be properly
|
||||
// formatted.
|
||||
respondError(w, http.StatusBadRequest, fmt.Errorf("malformed x-forwarded-for configuration or request, hops to skip (%d) would skip before earliest chain link (chain length %d)", hopSkips, len(headers)))
|
||||
return
|
||||
}
|
||||
|
||||
r.RemoteAddr = net.JoinHostPort(acc[indexToUse], port)
|
||||
h.ServeHTTP(w, r)
|
||||
return
|
||||
})
|
||||
}
|
||||
|
||||
// A lookup on a token that is about to expire returns nil, which means by the
|
||||
// time we can validate a wrapping token lookup will return nil since it will
|
||||
// be revoked after the call. So we have to do the validation here.
|
||||
|
|
|
|||
|
|
@ -85,6 +85,25 @@ listener "tcp" {
|
|||
authentication for this listener. The default behavior (when this is false)
|
||||
is for Vault to request client certificates when available.
|
||||
|
||||
- `x_forwarded_for_authorized_addrs` `(string: <required-to-enable>)` –
|
||||
Specifies the list of source IP addresses for which an X-Forwarded-For header
|
||||
will be trusted. Comma-separated list or JSON array. This turns on
|
||||
X-Forwarded-For support.
|
||||
|
||||
- `x_forwarded_for_hop_skips` `(string: "0")` – The number of addresses that will be
|
||||
skipped from the *rear* of the set of hops. For instance, for a header value
|
||||
of `1.2.3.4, 2.3.4.5, 3.4.5.6`, if this value is set to `"1"`, the address that
|
||||
will be used as the originating client IP is `2.3.4.5`.
|
||||
|
||||
- `x_forwarded_for_reject_not_authorized` `(string: "true")` – If set false,
|
||||
if there is an X-Forwarded-For header in a connection from an unauthorized
|
||||
address, the header will be ignored and the client connection used as-is,
|
||||
rather than the client connection rejected.
|
||||
|
||||
- `x_forwarded_for_reject_not_present` `(string: "true")` – If set false, if
|
||||
there is no X-Forwarded-For header or it is empty, the client address will be
|
||||
used as-is, rather than the client connection rejected.
|
||||
|
||||
## `tcp` Listener Examples
|
||||
|
||||
### Configuring TLS
|
||||
|
|
|
|||
Loading…
Reference in a new issue