diff --git a/CHANGELOG.md b/CHANGELOG.md index 838ac76fe4..98666ff7d6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,12 +1,20 @@ ## 0.10.1 (Unreleased) +FEATURES: + + * X-Forwarded-For support: `X-Forwarded-For` headers can now be used to set + the client IP seen by Vault. See the [TCP listener configuration + page](https://www.vaultproject.io/docs/configuration/listener/tcp.html) for + details. + IMPROVEMENTS: + * auth/token: Add to the token lookup response, the policies inherited due to + identity associations [GH-4366] + * core: Add X-Forwarded-For support [GH-4380] * identity: Add the ability to disable an entity. Disabling an entity does not revoke associated tokens, but while the entity is disabled they cannot be used. [GH-4353] - * auth/token: Add to the token lookup response, the policies inherited due to - identity associations [GH-4366] BUG FIXES: diff --git a/builtin/logical/mssql/path_creds_create.go b/builtin/logical/mssql/path_creds_create.go index 7e26937016..1a954c87b4 100644 --- a/builtin/logical/mssql/path_creds_create.go +++ b/builtin/logical/mssql/path_creds_create.go @@ -6,6 +6,7 @@ import ( "strings" "github.com/hashicorp/go-uuid" + "github.com/hashicorp/vault/helper/dbtxn" "github.com/hashicorp/vault/helper/strutil" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" @@ -90,15 +91,11 @@ func (b *backend) pathCredsCreateRead(ctx context.Context, req *logical.Request, continue } - stmt, err := tx.Prepare(Query(query, map[string]string{ + m := map[string]string{ "name": username, "password": password, - })) - if err != nil { - return nil, err } - defer stmt.Close() - if _, err := stmt.Exec(); err != nil { + if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil { return nil, err } } diff --git a/builtin/logical/mssql/secret_creds.go b/builtin/logical/mssql/secret_creds.go index ce93643eec..5edac67be5 100644 --- a/builtin/logical/mssql/secret_creds.go +++ b/builtin/logical/mssql/secret_creds.go @@ -6,6 +6,7 @@ import ( "fmt" "github.com/hashicorp/errwrap" + "github.com/hashicorp/vault/helper/dbtxn" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" ) @@ -130,16 +131,11 @@ func (b *backend) secretCredsRevoke(ctx context.Context, req *logical.Request, d // many permissions as possible right now var lastStmtError error for _, query := range revokeStmts { - stmt, err := db.Prepare(query) - if err != nil { + + if err := dbtxn.ExecuteDBQuery(ctx, db, nil, query); err != nil { lastStmtError = err continue } - defer stmt.Close() - _, err = stmt.Exec() - if err != nil { - lastStmtError = err - } } // can't drop if not all database users are dropped diff --git a/builtin/logical/mysql/path_role_create.go b/builtin/logical/mysql/path_role_create.go index ae184bdb34..135587575d 100644 --- a/builtin/logical/mysql/path_role_create.go +++ b/builtin/logical/mysql/path_role_create.go @@ -6,6 +6,7 @@ import ( "strings" "github.com/hashicorp/go-uuid" + "github.com/hashicorp/vault/helper/dbtxn" "github.com/hashicorp/vault/helper/strutil" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" @@ -103,15 +104,11 @@ func (b *backend) pathRoleCreateRead(ctx context.Context, req *logical.Request, continue } - stmt, err := tx.Prepare(Query(query, map[string]string{ + m := map[string]string{ "name": username, "password": password, - })) - if err != nil { - return nil, err } - defer stmt.Close() - if _, err := stmt.Exec(); err != nil { + if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil { return nil, err } } diff --git a/builtin/logical/postgresql/path_role_create.go b/builtin/logical/postgresql/path_role_create.go index 7fed904667..113e3ab896 100644 --- a/builtin/logical/postgresql/path_role_create.go +++ b/builtin/logical/postgresql/path_role_create.go @@ -7,6 +7,7 @@ import ( "time" "github.com/hashicorp/go-uuid" + "github.com/hashicorp/vault/helper/dbtxn" "github.com/hashicorp/vault/helper/strutil" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" @@ -106,16 +107,13 @@ func (b *backend) pathRoleCreateRead(ctx context.Context, req *logical.Request, continue } - stmt, err := tx.Prepare(Query(query, map[string]string{ + m := map[string]string{ "name": username, "password": password, "expiration": expiration, - })) - if err != nil { - return nil, err } - defer stmt.Close() - if _, err := stmt.Exec(); err != nil { + + if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil { return nil, err } } diff --git a/builtin/logical/postgresql/secret_creds.go b/builtin/logical/postgresql/secret_creds.go index 87748d1ff9..d00beacb65 100644 --- a/builtin/logical/postgresql/secret_creds.go +++ b/builtin/logical/postgresql/secret_creds.go @@ -8,6 +8,7 @@ import ( "time" "github.com/hashicorp/errwrap" + "github.com/hashicorp/vault/helper/dbtxn" "github.com/hashicorp/vault/helper/strutil" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" @@ -211,14 +212,7 @@ func (b *backend) secretCredsRevoke(ctx context.Context, req *logical.Request, d // many permissions as possible right now var lastStmtError error for _, query := range revocationStmts { - stmt, err := db.Prepare(query) - if err != nil { - lastStmtError = err - continue - } - defer stmt.Close() - _, err = stmt.Exec() - if err != nil { + if err := dbtxn.ExecuteDBQuery(ctx, db, nil, query); err != nil { lastStmtError = err } } @@ -258,15 +252,10 @@ func (b *backend) secretCredsRevoke(ctx context.Context, req *logical.Request, d continue } - stmt, err := tx.Prepare(Query(query, map[string]string{ + m := map[string]string{ "name": username, - })) - if err != nil { - return nil, err } - defer stmt.Close() - - if _, err := stmt.Exec(); err != nil { + if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil { return nil, err } } diff --git a/command/server.go b/command/server.go index cb5bd9b7b5..630f9aada7 100644 --- a/command/server.go +++ b/command/server.go @@ -32,6 +32,7 @@ import ( "github.com/hashicorp/errwrap" log "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-multierror" + sockaddr "github.com/hashicorp/go-sockaddr" "github.com/hashicorp/vault/audit" "github.com/hashicorp/vault/command/server" "github.com/hashicorp/vault/helper/gated-writer" @@ -92,6 +93,11 @@ type ServerCommand struct { flagTestVerifyOnly bool } +type ServerListener struct { + net.Listener + config map[string]interface{} +} + func (c *ServerCommand) Synopsis() string { return "Start a Vault server" } @@ -670,8 +676,8 @@ CLUSTER_SYNTHESIS_COMPLETE: clusterAddrs := []*net.TCPAddr{} // Initialize the listeners + lns := make([]ServerListener, 0, len(config.Listeners)) c.reloadFuncsLock.Lock() - lns := make([]net.Listener, 0, len(config.Listeners)) for i, lnConfig := range config.Listeners { ln, props, reloadFunc, err := server.NewListener(lnConfig.Type, lnConfig.Config, c.logGate, c.UI) if err != nil { @@ -679,7 +685,10 @@ CLUSTER_SYNTHESIS_COMPLETE: return 1 } - lns = append(lns, ln) + lns = append(lns, ServerListener{ + Listener: ln, + config: lnConfig.Config, + }) if reloadFunc != nil { relSlice := (*c.reloadFuncs)["listener|"+lnConfig.Type] @@ -738,7 +747,7 @@ CLUSTER_SYNTHESIS_COMPLETE: // Make sure we close all listeners from this point on listenerCloseFunc := func() { for _, ln := range lns { - ln.Close() + ln.Listener.Close() } } @@ -776,12 +785,10 @@ CLUSTER_SYNTHESIS_COMPLETE: return 0 } - handler := vaulthttp.Handler(core) - // This needs to happen before we first unseal, so before we trigger dev // mode if it's set core.SetClusterListenerAddrs(clusterAddrs) - core.SetClusterHandler(handler) + core.SetClusterHandler(vaulthttp.Handler(core)) err = core.UnsealWithStoredKeys(context.Background()) if err != nil { @@ -914,10 +921,23 @@ CLUSTER_SYNTHESIS_COMPLETE: // Initialize the HTTP servers for _, ln := range lns { + handler := vaulthttp.Handler(core) + + // We perform validation on the config earlier, we can just cast here + if _, ok := ln.config["x_forwarded_for_authorized_addrs"]; ok { + hopSkips := ln.config["x_forwarded_for_hop_skips"].(int) + authzdAddrs := ln.config["x_forwarded_for_authorized_addrs"].([]*sockaddr.SockAddrMarshaler) + rejectNotPresent := ln.config["x_forwarded_for_reject_not_present"].(bool) + rejectNonAuthz := ln.config["x_forwarded_for_reject_not_authorized"].(bool) + if len(authzdAddrs) > 0 { + handler = vaulthttp.WrapForwardedForHandler(handler, authzdAddrs, rejectNotPresent, rejectNonAuthz, hopSkips) + } + } + server := &http.Server{ Handler: handler, } - go server.Serve(ln) + go server.Serve(ln.Listener) } if newCoreError != nil { diff --git a/command/server/config.go b/command/server/config.go index b150b83bd8..21520ce094 100644 --- a/command/server/config.go +++ b/command/server/config.go @@ -761,6 +761,10 @@ func parseListeners(result *Config, list *ast.ObjectList) error { "address", "cluster_address", "endpoint", + "x_forwarded_for_authorized_addrs", + "x_forwarded_for_hop_skips", + "x_forwarded_for_reject_not_authorized", + "x_forwarded_for_reject_not_present", "infrastructure", "node_id", "proxy_protocol_behavior", diff --git a/command/server/listener_tcp.go b/command/server/listener_tcp.go index bf39615a69..201e124f3a 100644 --- a/command/server/listener_tcp.go +++ b/command/server/listener_tcp.go @@ -1,11 +1,15 @@ package server import ( + "fmt" "io" "net" + "strconv" "strings" "time" + "github.com/hashicorp/errwrap" + "github.com/hashicorp/vault/helper/parseutil" "github.com/hashicorp/vault/helper/reload" "github.com/mitchellh/cli" ) @@ -39,6 +43,57 @@ func tcpListenerFactory(config map[string]interface{}, _ io.Writer, ui cli.Ui) ( } props := map[string]string{"addr": addr} + + ffAllowedRaw, ffAllowedOK := config["x_forwarded_for_authorized_addrs"] + if ffAllowedOK { + ffAllowed, err := parseutil.ParseAddrs(ffAllowedRaw) + if err != nil { + return nil, nil, nil, errwrap.Wrapf("error parsing \"x_forwarded_for_authorized_addrs\": {{err}}", err) + } + props["x_forwarded_for_authorized_addrs"] = fmt.Sprintf("%v", ffAllowed) + config["x_forwarded_for_authorized_addrs"] = ffAllowed + } + + if ffHopsRaw, ok := config["x_forwarded_for_hop_skips"]; ok { + ffHops64, err := parseutil.ParseInt(ffHopsRaw) + if err != nil { + return nil, nil, nil, errwrap.Wrapf("error parsing \"x_forwarded_for_hop_skips\": {{err}}", err) + } + if ffHops64 < 0 { + return nil, nil, nil, fmt.Errorf("\"x_forwarded_for_hop_skips\" cannot be negative") + } + ffHops := int(ffHops64) + props["x_forwarded_for_hop_skips"] = strconv.Itoa(ffHops) + config["x_forwarded_for_hop_skips"] = ffHops + } else if ffAllowedOK { + props["x_forwarded_for_hop_skips"] = "0" + config["x_forwarded_for_hop_skips"] = int(0) + } + + if ffRejectNotPresentRaw, ok := config["x_forwarded_for_reject_not_present"]; ok { + ffRejectNotPresent, err := parseutil.ParseBool(ffRejectNotPresentRaw) + if err != nil { + return nil, nil, nil, errwrap.Wrapf("error parsing \"x_forwarded_for_reject_not_present\": {{err}}", err) + } + props["x_forwarded_for_reject_not_present"] = strconv.FormatBool(ffRejectNotPresent) + config["x_forwarded_for_reject_not_present"] = ffRejectNotPresent + } else if ffAllowedOK { + props["x_forwarded_for_reject_not_present"] = "true" + config["x_forwarded_for_reject_not_present"] = true + } + + if ffRejectNonAuthorizedRaw, ok := config["x_forwarded_for_reject_not_authorized"]; ok { + ffRejectNonAuthorized, err := parseutil.ParseBool(ffRejectNonAuthorizedRaw) + if err != nil { + return nil, nil, nil, errwrap.Wrapf("error parsing \"x_forwarded_for_reject_not_authorized\": {{err}}", err) + } + props["x_forwarded_for_reject_not_authorized"] = strconv.FormatBool(ffRejectNonAuthorized) + config["x_forwarded_for_reject_not_authorized"] = ffRejectNonAuthorized + } else if ffAllowedOK { + props["x_forwarded_for_reject_not_authorized"] = "true" + config["x_forwarded_for_reject_not_authorized"] = true + } + return listenerWrapTLS(ln, props, config, ui) } diff --git a/helper/dbtxn/dbtxn.go b/helper/dbtxn/dbtxn.go new file mode 100644 index 0000000000..3337bd97b2 --- /dev/null +++ b/helper/dbtxn/dbtxn.go @@ -0,0 +1,63 @@ +package dbtxn + +import ( + "context" + "database/sql" + "fmt" + "strings" +) + +// ExecuteDBQuery handles executing one single statement, while properly releasing its resources. +// - ctx: Required +// - db: Required +// - config: Optional, may be nil +// - query: Required +func ExecuteDBQuery(ctx context.Context, db *sql.DB, params map[string]string, query string) error { + + parsedQuery := parseQuery(params, query) + + stmt, err := db.PrepareContext(ctx, parsedQuery) + if err != nil { + return err + } + defer stmt.Close() + + return execute(ctx, stmt) +} + +// ExecuteTxQuery handles executing one single statement, while properly releasing its resources. +// - ctx: Required +// - tx: Required +// - config: Optional, may be nil +// - query: Required +func ExecuteTxQuery(ctx context.Context, tx *sql.Tx, params map[string]string, query string) error { + + parsedQuery := parseQuery(params, query) + + stmt, err := tx.PrepareContext(ctx, parsedQuery) + if err != nil { + return err + } + defer stmt.Close() + + return execute(ctx, stmt) +} + +func execute(ctx context.Context, stmt *sql.Stmt) error { + if _, err := stmt.ExecContext(ctx); err != nil { + return err + } + return nil +} + +func parseQuery(m map[string]string, tpl string) string { + + if m == nil || len(m) <= 0 { + return tpl + } + + for k, v := range m { + tpl = strings.Replace(tpl, fmt.Sprintf("{{%s}}", k), v, -1) + } + return tpl +} diff --git a/helper/parseutil/parseutil.go b/helper/parseutil/parseutil.go index 464b50899c..ae8c58ba78 100644 --- a/helper/parseutil/parseutil.go +++ b/helper/parseutil/parseutil.go @@ -3,10 +3,13 @@ package parseutil import ( "encoding/json" "errors" + "fmt" "strconv" "strings" "time" + "github.com/hashicorp/errwrap" + sockaddr "github.com/hashicorp/go-sockaddr" "github.com/hashicorp/vault/helper/strutil" "github.com/mitchellh/mapstructure" ) @@ -118,3 +121,43 @@ func ParseCommaStringSlice(in interface{}) ([]string, error) { } return strutil.TrimStrings(result), nil } + +func ParseAddrs(addrs interface{}) ([]*sockaddr.SockAddrMarshaler, error) { + out := make([]*sockaddr.SockAddrMarshaler, 0) + stringAddrs := make([]string, 0) + + switch addrs.(type) { + case string: + stringAddrs = strutil.ParseArbitraryStringSlice(addrs.(string), ",") + if len(stringAddrs) == 0 { + return nil, fmt.Errorf("unable to parse addresses from %v", addrs) + } + + case []string: + stringAddrs = addrs.([]string) + + case []interface{}: + for _, v := range addrs.([]interface{}) { + stringAddr, ok := v.(string) + if !ok { + return nil, fmt.Errorf("error parsing %v as string", v) + } + stringAddrs = append(stringAddrs, stringAddr) + } + + default: + return nil, fmt.Errorf("unknown address input type %T", addrs) + } + + for _, addr := range stringAddrs { + sa, err := sockaddr.NewSockAddr(addr) + if err != nil { + return nil, errwrap.Wrapf(fmt.Sprintf("error parsing address %q: {{err}}", addr), err) + } + out = append(out, &sockaddr.SockAddrMarshaler{ + SockAddr: sa, + }) + } + + return out, nil +} diff --git a/helper/proxyutil/proxyutil.go b/helper/proxyutil/proxyutil.go index 06371b29e5..875e74831c 100644 --- a/helper/proxyutil/proxyutil.go +++ b/helper/proxyutil/proxyutil.go @@ -8,7 +8,7 @@ import ( proxyproto "github.com/armon/go-proxyproto" "github.com/hashicorp/errwrap" sockaddr "github.com/hashicorp/go-sockaddr" - "github.com/hashicorp/vault/helper/strutil" + "github.com/hashicorp/vault/helper/parseutil" ) // ProxyProtoConfig contains configuration for the PROXY protocol @@ -19,42 +19,12 @@ type ProxyProtoConfig struct { } func (p *ProxyProtoConfig) SetAuthorizedAddrs(addrs interface{}) error { - p.AuthorizedAddrs = make([]*sockaddr.SockAddrMarshaler, 0) - stringAddrs := make([]string, 0) - - switch addrs.(type) { - case string: - stringAddrs = strutil.ParseArbitraryStringSlice(addrs.(string), ",") - if len(stringAddrs) == 0 { - return fmt.Errorf("unable to parse addresses from %v", addrs) - } - - case []string: - stringAddrs = addrs.([]string) - - case []interface{}: - for _, v := range addrs.([]interface{}) { - stringAddr, ok := v.(string) - if !ok { - return fmt.Errorf("error parsing %v as string", v) - } - stringAddrs = append(stringAddrs, stringAddr) - } - - default: - return fmt.Errorf("unknown address input type %T", addrs) - } - - for _, addr := range stringAddrs { - sa, err := sockaddr.NewSockAddr(addr) - if err != nil { - return errwrap.Wrapf("error parsing authorized address: {{err}}", err) - } - p.AuthorizedAddrs = append(p.AuthorizedAddrs, &sockaddr.SockAddrMarshaler{ - SockAddr: sa, - }) + aa, err := parseutil.ParseAddrs(addrs) + if err != nil { + return err } + p.AuthorizedAddrs = aa return nil } diff --git a/http/forwarded_for_test.go b/http/forwarded_for_test.go new file mode 100644 index 0000000000..5d60391353 --- /dev/null +++ b/http/forwarded_for_test.go @@ -0,0 +1,249 @@ +package http + +import ( + "bytes" + "net/http" + "strings" + "testing" + + sockaddr "github.com/hashicorp/go-sockaddr" + "github.com/hashicorp/vault/vault" +) + +func TestHandler_XForwardedFor(t *testing.T) { + goodAddr, err := sockaddr.NewIPAddr("127.0.0.1") + if err != nil { + t.Fatal(err) + } + + badAddr, err := sockaddr.NewIPAddr("1.2.3.4") + if err != nil { + t.Fatal(err) + } + + // First: test reject not present + t.Run("reject_not_present", func(t *testing.T) { + t.Parallel() + testHandler := func(c *vault.Core) http.Handler { + origHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(r.RemoteAddr)) + }) + return WrapForwardedForHandler(origHandler, []*sockaddr.SockAddrMarshaler{ + &sockaddr.SockAddrMarshaler{ + SockAddr: goodAddr, + }, + }, true, false, 0) + } + + cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{ + HandlerFunc: testHandler, + }) + cluster.Start() + defer cluster.Cleanup() + client := cluster.Cores[0].Client + + req := client.NewRequest("GET", "/") + _, err = client.RawRequest(req) + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "missing x-forwarded-for") { + t.Fatalf("bad error message: %v", err) + } + req = client.NewRequest("GET", "/") + req.Headers = make(http.Header) + req.Headers.Set("x-forwarded-for", "1.2.3.4") + resp, err := client.RawRequest(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + buf := bytes.NewBuffer(nil) + buf.ReadFrom(resp.Body) + if !strings.HasPrefix(buf.String(), "1.2.3.4:") { + t.Fatalf("bad body: %s", buf.String()) + } + }) + + // Next: test allow unauth + t.Run("allow_unauth", func(t *testing.T) { + t.Parallel() + testHandler := func(c *vault.Core) http.Handler { + origHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(r.RemoteAddr)) + }) + return WrapForwardedForHandler(origHandler, []*sockaddr.SockAddrMarshaler{ + &sockaddr.SockAddrMarshaler{ + SockAddr: badAddr, + }, + }, true, false, 0) + } + + cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{ + HandlerFunc: testHandler, + }) + cluster.Start() + defer cluster.Cleanup() + client := cluster.Cores[0].Client + + req := client.NewRequest("GET", "/") + req.Headers = make(http.Header) + req.Headers.Set("x-forwarded-for", "5.6.7.8") + resp, err := client.RawRequest(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + buf := bytes.NewBuffer(nil) + buf.ReadFrom(resp.Body) + if !strings.HasPrefix(buf.String(), "127.0.0.1:") { + t.Fatalf("bad body: %s", buf.String()) + } + }) + + // Next: test fail unauth + t.Run("fail_unauth", func(t *testing.T) { + t.Parallel() + testHandler := func(c *vault.Core) http.Handler { + origHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(r.RemoteAddr)) + }) + return WrapForwardedForHandler(origHandler, []*sockaddr.SockAddrMarshaler{ + &sockaddr.SockAddrMarshaler{ + SockAddr: badAddr, + }, + }, true, true, 0) + } + + cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{ + HandlerFunc: testHandler, + }) + cluster.Start() + defer cluster.Cleanup() + client := cluster.Cores[0].Client + + req := client.NewRequest("GET", "/") + req.Headers = make(http.Header) + req.Headers.Set("x-forwarded-for", "5.6.7.8") + _, err = client.RawRequest(req) + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "not authorized for x-forwarded-for") { + t.Fatalf("bad error message: %v", err) + } + }) + + // Next: test bad hops (too many) + t.Run("too_many_hops", func(t *testing.T) { + t.Parallel() + testHandler := func(c *vault.Core) http.Handler { + origHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(r.RemoteAddr)) + }) + return WrapForwardedForHandler(origHandler, []*sockaddr.SockAddrMarshaler{ + &sockaddr.SockAddrMarshaler{ + SockAddr: goodAddr, + }, + }, true, true, 4) + } + + cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{ + HandlerFunc: testHandler, + }) + cluster.Start() + defer cluster.Cleanup() + client := cluster.Cores[0].Client + + req := client.NewRequest("GET", "/") + req.Headers = make(http.Header) + req.Headers.Set("x-forwarded-for", "2.3.4.5,3.4.5.6") + _, err = client.RawRequest(req) + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "would skip before earliest") { + t.Fatalf("bad error message: %v", err) + } + }) + + // Next: test picking correct value + t.Run("correct_hop_skipping", func(t *testing.T) { + t.Parallel() + testHandler := func(c *vault.Core) http.Handler { + origHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(r.RemoteAddr)) + }) + return WrapForwardedForHandler(origHandler, []*sockaddr.SockAddrMarshaler{ + &sockaddr.SockAddrMarshaler{ + SockAddr: goodAddr, + }, + }, true, true, 1) + } + + cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{ + HandlerFunc: testHandler, + }) + cluster.Start() + defer cluster.Cleanup() + client := cluster.Cores[0].Client + + req := client.NewRequest("GET", "/") + req.Headers = make(http.Header) + req.Headers.Set("x-forwarded-for", "2.3.4.5,3.4.5.6,4.5.6.7,5.6.7.8") + resp, err := client.RawRequest(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + buf := bytes.NewBuffer(nil) + buf.ReadFrom(resp.Body) + if !strings.HasPrefix(buf.String(), "4.5.6.7:") { + t.Fatalf("bad body: %s", buf.String()) + } + }) + + // Next: multi-header approach + t.Run("correct_hop_skipping_multi_header", func(t *testing.T) { + t.Parallel() + testHandler := func(c *vault.Core) http.Handler { + origHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(r.RemoteAddr)) + }) + return WrapForwardedForHandler(origHandler, []*sockaddr.SockAddrMarshaler{ + &sockaddr.SockAddrMarshaler{ + SockAddr: goodAddr, + }, + }, true, true, 1) + } + + cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{ + HandlerFunc: testHandler, + }) + cluster.Start() + defer cluster.Cleanup() + client := cluster.Cores[0].Client + + req := client.NewRequest("GET", "/") + req.Headers = make(http.Header) + req.Headers.Add("x-forwarded-for", "2.3.4.5") + req.Headers.Add("x-forwarded-for", "3.4.5.6,4.5.6.7") + req.Headers.Add("x-forwarded-for", "5.6.7.8") + resp, err := client.RawRequest(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + buf := bytes.NewBuffer(nil) + buf.ReadFrom(resp.Body) + if !strings.HasPrefix(buf.String(), "4.5.6.7:") { + t.Fatalf("bad body: %s", buf.String()) + } + }) +} diff --git a/http/handler.go b/http/handler.go index a4e284dc36..72294b2bda 100644 --- a/http/handler.go +++ b/http/handler.go @@ -4,7 +4,9 @@ import ( "encoding/json" "fmt" "io" + "net" "net/http" + "net/textproto" "net/url" "os" "strings" @@ -13,6 +15,7 @@ import ( "github.com/elazarl/go-bindata-assetfs" "github.com/hashicorp/errwrap" cleanhttp "github.com/hashicorp/go-cleanhttp" + sockaddr "github.com/hashicorp/go-sockaddr" "github.com/hashicorp/vault/helper/consts" "github.com/hashicorp/vault/helper/jsonutil" "github.com/hashicorp/vault/helper/parseutil" @@ -124,6 +127,94 @@ func wrapGenericHandler(h http.Handler) http.Handler { }) } +func WrapForwardedForHandler(h http.Handler, authorizedAddrs []*sockaddr.SockAddrMarshaler, rejectNotPresent, rejectNonAuthz bool, hopSkips int) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + headers, headersOK := r.Header[textproto.CanonicalMIMEHeaderKey("X-Forwarded-For")] + if !headersOK || len(headers) == 0 { + if !rejectNotPresent { + h.ServeHTTP(w, r) + return + } + respondError(w, http.StatusBadRequest, fmt.Errorf("missing x-forwarded-for header and configured to reject when not present")) + return + } + + host, port, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + // If not rejecting treat it like we just don't have a valid + // header because we can't do a comparison against an address we + // can't understand + if !rejectNotPresent { + h.ServeHTTP(w, r) + return + } + respondError(w, http.StatusBadRequest, errwrap.Wrapf("error parsing client hostport: {{err}}", err)) + return + } + + addr, err := sockaddr.NewIPAddr(host) + if err != nil { + // We treat this the same as the case above + if !rejectNotPresent { + h.ServeHTTP(w, r) + return + } + respondError(w, http.StatusBadRequest, errwrap.Wrapf("error parsing client address: {{err}}", err)) + return + } + + var found bool + for _, authz := range authorizedAddrs { + if authz.Contains(addr) { + found = true + break + } + } + if !found { + // If we didn't find it and aren't configured to reject, simply + // don't trust it + if !rejectNonAuthz { + h.ServeHTTP(w, r) + return + } + respondError(w, http.StatusBadRequest, fmt.Errorf("client address not authorized for x-forwarded-for and configured to reject connection")) + return + } + + // At this point we have at least one value and it's authorized + + // Split comma separated ones, which are common. This brings it in line + // to the multiple-header case. + var acc []string + for _, header := range headers { + vals := strings.Split(header, ",") + for _, v := range vals { + acc = append(acc, strings.TrimSpace(v)) + } + } + + indexToUse := len(acc) - 1 - hopSkips + if indexToUse < 0 { + // This is likely an error in either configuration or other + // infrastructure. We could either deny the request, or we + // could simply not trust the value. Denying the request is + // "safer" since if this logic is configured at all there may + // be an assumption it can always be trusted. Given that we can + // deny accepting the request at all if it's not from an + // authorized address, if we're at this point the address is + // authorized (or we've turned off explicit rejection) and we + // should assume that what comes in should be properly + // formatted. + respondError(w, http.StatusBadRequest, fmt.Errorf("malformed x-forwarded-for configuration or request, hops to skip (%d) would skip before earliest chain link (chain length %d)", hopSkips, len(headers))) + return + } + + r.RemoteAddr = net.JoinHostPort(acc[indexToUse], port) + h.ServeHTTP(w, r) + return + }) +} + // A lookup on a token that is about to expire returns nil, which means by the // time we can validate a wrapping token lookup will return nil since it will // be revoked after the call. So we have to do the validation here. diff --git a/plugins/database/hana/hana.go b/plugins/database/hana/hana.go index 1fdafe77ad..62e739a669 100644 --- a/plugins/database/hana/hana.go +++ b/plugins/database/hana/hana.go @@ -11,6 +11,7 @@ import ( _ "github.com/SAP/go-hdb/driver" "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/builtin/logical/database/dbplugin" + "github.com/hashicorp/vault/helper/dbtxn" "github.com/hashicorp/vault/helper/strutil" "github.com/hashicorp/vault/plugins" "github.com/hashicorp/vault/plugins/helper/database/connutil" @@ -143,16 +144,12 @@ func (h *HANA) CreateUser(ctx context.Context, statements dbplugin.Statements, u continue } - stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{ + m := map[string]string{ "name": username, "password": password, "expiration": expirationStr, - })) - if err != nil { - return "", "", err } - defer stmt.Close() - if _, err := stmt.ExecContext(ctx); err != nil { + if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil { return "", "", err } } @@ -238,14 +235,10 @@ func (h *HANA) RevokeUser(ctx context.Context, statements dbplugin.Statements, u continue } - stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{ + m := map[string]string{ "name": username, - })) - if err != nil { - return err } - defer stmt.Close() - if _, err := stmt.ExecContext(ctx); err != nil { + if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil { return err } } diff --git a/plugins/database/mssql/mssql.go b/plugins/database/mssql/mssql.go index 84f7e1462a..9b0a78c0ae 100644 --- a/plugins/database/mssql/mssql.go +++ b/plugins/database/mssql/mssql.go @@ -12,6 +12,7 @@ import ( "github.com/hashicorp/errwrap" "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/builtin/logical/database/dbplugin" + "github.com/hashicorp/vault/helper/dbtxn" "github.com/hashicorp/vault/helper/strutil" "github.com/hashicorp/vault/plugins" "github.com/hashicorp/vault/plugins/helper/database/connutil" @@ -129,16 +130,13 @@ func (m *MSSQL) CreateUser(ctx context.Context, statements dbplugin.Statements, continue } - stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{ + m := map[string]string{ "name": username, "password": password, "expiration": expirationStr, - })) - if err != nil { - return "", "", err } - defer stmt.Close() - if _, err := stmt.ExecContext(ctx); err != nil { + + if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil { return "", "", err } } @@ -189,14 +187,10 @@ func (m *MSSQL) RevokeUser(ctx context.Context, statements dbplugin.Statements, continue } - stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{ + m := map[string]string{ "name": username, - })) - if err != nil { - return err } - defer stmt.Close() - if _, err := stmt.ExecContext(ctx); err != nil { + if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil { return err } } @@ -285,14 +279,7 @@ func (m *MSSQL) revokeUserDefault(ctx context.Context, username string) error { // many permissions as possible right now var lastStmtError error for _, query := range revokeStmts { - stmt, err := db.PrepareContext(ctx, query) - if err != nil { - lastStmtError = err - continue - } - defer stmt.Close() - _, err = stmt.ExecContext(ctx) - if err != nil { + if err := dbtxn.ExecuteDBQuery(ctx, db, nil, query); err != nil { lastStmtError = err } } @@ -355,16 +342,12 @@ func (m *MSSQL) RotateRootCredentials(ctx context.Context, statements []string) if len(query) == 0 { continue } - stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{ + + m := map[string]string{ "username": m.Username, "password": password, - })) - if err != nil { - return nil, err } - - defer stmt.Close() - if _, err := stmt.ExecContext(ctx); err != nil { + if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil { return nil, err } } diff --git a/plugins/database/mysql/mysql.go b/plugins/database/mysql/mysql.go index 00fe475045..a36f1a8686 100644 --- a/plugins/database/mysql/mysql.go +++ b/plugins/database/mysql/mysql.go @@ -10,6 +10,7 @@ import ( stdmysql "github.com/go-sql-driver/mysql" "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/builtin/logical/database/dbplugin" + "github.com/hashicorp/vault/helper/dbtxn" "github.com/hashicorp/vault/helper/strutil" "github.com/hashicorp/vault/plugins" "github.com/hashicorp/vault/plugins/helper/database/connutil" @@ -182,10 +183,11 @@ func (m *MySQL) CreateUser(ctx context.Context, statements dbplugin.Statements, return "", "", err } - defer stmt.Close() if _, err := stmt.ExecContext(ctx); err != nil { + stmt.Close() return "", "", err } + stmt.Close() } } @@ -291,16 +293,12 @@ func (m *MySQL) RotateRootCredentials(ctx context.Context, statements []string) if len(query) == 0 { continue } - stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{ + + m := map[string]string{ "username": m.Username, "password": password, - })) - if err != nil { - return nil, err } - - defer stmt.Close() - if _, err := stmt.ExecContext(ctx); err != nil { + if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil { return nil, err } } diff --git a/plugins/database/postgresql/postgresql.go b/plugins/database/postgresql/postgresql.go index c56f9ed02d..36dd0036a9 100644 --- a/plugins/database/postgresql/postgresql.go +++ b/plugins/database/postgresql/postgresql.go @@ -11,6 +11,7 @@ import ( "github.com/hashicorp/errwrap" "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/builtin/logical/database/dbplugin" + "github.com/hashicorp/vault/helper/dbtxn" "github.com/hashicorp/vault/helper/strutil" "github.com/hashicorp/vault/plugins" "github.com/hashicorp/vault/plugins/helper/database/connutil" @@ -139,16 +140,12 @@ func (p *PostgreSQL) CreateUser(ctx context.Context, statements dbplugin.Stateme continue } - stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{ + m := map[string]string{ "name": username, "password": password, "expiration": expirationStr, - })) - if err != nil { - return "", "", err } - defer stmt.Close() - if _, err := stmt.ExecContext(ctx); err != nil { + if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil { return "", "", err } } @@ -157,7 +154,6 @@ func (p *PostgreSQL) CreateUser(ctx context.Context, statements dbplugin.Stateme // Commit the transaction if err := tx.Commit(); err != nil { return "", "", err - } return username, password, nil @@ -198,16 +194,12 @@ func (p *PostgreSQL) RenewUser(ctx context.Context, statements dbplugin.Statemen if len(query) == 0 { continue } - stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{ + + m := map[string]string{ "name": username, "expiration": expirationStr, - })) - if err != nil { - return err } - - defer stmt.Close() - if _, err := stmt.ExecContext(ctx); err != nil { + if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil { return err } } @@ -251,15 +243,10 @@ func (p *PostgreSQL) customRevokeUser(ctx context.Context, username string, revo continue } - stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{ + m := map[string]string{ "name": username, - })) - if err != nil { - return err } - defer stmt.Close() - - if _, err := stmt.ExecContext(ctx); err != nil { + if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil { return err } } @@ -352,14 +339,7 @@ func (p *PostgreSQL) defaultRevokeUser(ctx context.Context, username string) err // many permissions as possible right now var lastStmtError error for _, query := range revocationStmts { - stmt, err := db.PrepareContext(ctx, query) - if err != nil { - lastStmtError = err - continue - } - defer stmt.Close() - _, err = stmt.ExecContext(ctx) - if err != nil { + if err := dbtxn.ExecuteDBQuery(ctx, db, nil, query); err != nil { lastStmtError = err } } @@ -423,16 +403,11 @@ func (p *PostgreSQL) RotateRootCredentials(ctx context.Context, statements []str if len(query) == 0 { continue } - stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{ + m := map[string]string{ "username": p.Username, "password": password, - })) - if err != nil { - return nil, err } - - defer stmt.Close() - if _, err := stmt.ExecContext(ctx); err != nil { + if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil { return nil, err } } diff --git a/ui/app/components/auth-form.js b/ui/app/components/auth-form.js index a2ebe58449..2790508a4e 100644 --- a/ui/app/components/auth-form.js +++ b/ui/app/components/auth-form.js @@ -3,7 +3,21 @@ import { supportedAuthBackends } from 'vault/helpers/supported-auth-backends'; const BACKENDS = supportedAuthBackends(); const { computed, inject } = Ember; -export default Ember.Component.extend({ +const attributesForSelectedAuthBackend = { + token: ['token'], + userpass: ['username', 'password'], + ldap: ['username', 'password'], + github: ['username', 'password'], + okta: ['username', 'password'], +}; + +const DEFAULTS = { + token: null, + username: null, + password: null, +}; + +export default Ember.Component.extend(DEFAULTS, { classNames: ['auth-form'], routing: inject.service('-routing'), auth: inject.service(), @@ -14,6 +28,21 @@ export default Ember.Component.extend({ this.$('li.is-active').get(0).scrollIntoView(); }, + didReceiveAttrs() { + this._super(...arguments); + let newMethod = this.get('selectedAuthType'); + let oldMethod = this.get('oldSelectedAuthType'); + + if (oldMethod && oldMethod !== newMethod) { + this.resetDefaults(); + } + this.set('oldSelectedAuthType', newMethod); + }, + + resetDefaults() { + this.setProperties(DEFAULTS); + }, + cluster: null, redirectTo: null, @@ -22,9 +51,9 @@ export default Ember.Component.extend({ return BACKENDS.findBy('type', this.get('selectedAuthType')); }), - providerComponentName: Ember.computed('selectedAuthBackend.type', function() { - const type = Ember.String.dasherize(this.get('selectedAuthBackend.type')); - return `auth-form/${type}`; + providerPartialName: Ember.computed('selectedAuthType', function() { + const type = Ember.String.dasherize(this.get('selectedAuthType')); + return `partials/auth-form/${type}`; }), hasCSPError: computed.alias('csp.connectionViolations.firstObject'), @@ -45,15 +74,18 @@ export default Ember.Component.extend({ }, actions: { - doSubmit(data) { + doSubmit() { + let data = {}; this.setProperties({ loading: true, error: null, }); - const targetRoute = this.get('redirectTo') || 'vault.cluster'; - //const {password, token, username} = data; - const backend = this.get('selectedAuthBackend.type'); - const path = this.get('customPath'); + let targetRoute = this.get('redirectTo') || 'vault.cluster'; + let backend = this.get('selectedAuthBackend.type'); + let path = this.get('customPath'); + let attributes = attributesForSelectedAuthBackend[backend]; + + data = Ember.assign(data, this.getProperties(...attributes)); if (this.get('useCustomPath') && path) { data.path = path; } diff --git a/ui/app/components/download-button.js b/ui/app/components/download-button.js index 72fee1106e..2e2847c887 100644 --- a/ui/app/components/download-button.js +++ b/ui/app/components/download-button.js @@ -12,7 +12,7 @@ export default Ember.Component.extend({ return `${this.get('filename')}-${new Date().toISOString()}.${this.get('extension')}`; }), - fileLike: computed('data', 'mime', 'strigify', 'download', function() { + fileLike: computed('data', 'mime', 'stringify', 'download', function() { let file; let data = this.get('data'); let filename = this.get('download'); diff --git a/ui/app/styles/components/shamir-progress.scss b/ui/app/styles/components/shamir-progress.scss index 4bb418328a..1169b45a4d 100644 --- a/ui/app/styles/components/shamir-progress.scss +++ b/ui/app/styles/components/shamir-progress.scss @@ -1,11 +1,12 @@ .shamir-progress { .shamir-progress-progress { display: inline-block; + margin-top: $size-10; margin-right: $size-8; } .progress { box-shadow: 0 0 0 4px $progress-bar-background-color; - display: inline; - width: 150px; + margin-top: $size-10; + min-width: 90px; } } diff --git a/ui/app/templates/components/auth-form.hbs b/ui/app/templates/components/auth-form.hbs index 73b35345a4..2d94e536ad 100644 --- a/ui/app/templates/components/auth-form.hbs +++ b/ui/app/templates/components/auth-form.hbs @@ -1,7 +1,7 @@ -
- If this backend was mounted using a non-default path, enter it here. -
- {{/if}} -