From 56edb780e8b7ff6b84e221ea1b666bd27ddb9469 Mon Sep 17 00:00:00 2001 From: Becca Petrin Date: Thu, 9 Jan 2020 14:56:34 -0800 Subject: [PATCH] Add Kerberos auth agent (#7999) * add kerberos auth agent * strip old comment * changes from feedback * strip appengine indirect dependency --- api/client.go | 20 ++- api/client_test.go | 49 +++++ builtin/credential/aws/path_role_test.go | 2 +- builtin/logical/pki/backend_test.go | 2 +- command/agent.go | 3 + command/agent/auth/alicloud/alicloud.go | 7 +- command/agent/auth/approle/approle.go | 43 ++--- command/agent/auth/auth.go | 12 +- command/agent/auth/auth_test.go | 7 +- command/agent/auth/aws/aws.go | 4 +- command/agent/auth/azure/azure.go | 4 +- command/agent/auth/cert/cert.go | 7 +- command/agent/auth/cf/cf.go | 13 +- command/agent/auth/gcp/gcp.go | 4 +- command/agent/auth/jwt/jwt.go | 7 +- .../kerberos/integtest/integrationtest.sh | 170 ++++++++++++++++++ command/agent/auth/kerberos/kerberos.go | 91 ++++++++++ command/agent/auth/kerberos/kerberos_test.go | 67 +++++++ command/agent/auth/kubernetes/kubernetes.go | 7 +- .../agent/auth/kubernetes/kubernetes_test.go | 2 +- command/agent/aws_end_to_end_test.go | 2 - helper/metricsutil/metricsutil_test.go | 3 +- .../github.com/hashicorp/vault/api/client.go | 14 +- 23 files changed, 475 insertions(+), 65 deletions(-) create mode 100755 command/agent/auth/kerberos/integtest/integrationtest.sh create mode 100644 command/agent/auth/kerberos/kerberos.go create mode 100644 command/agent/auth/kerberos/kerberos_test.go diff --git a/api/client.go b/api/client.go index a6a8bb6091..2f4f90ee35 100644 --- a/api/client.go +++ b/api/client.go @@ -603,7 +603,7 @@ func (c *Client) ClearToken() { } // Headers gets the current set of headers used for requests. This returns a -// copy; to modify it make modifications locally and use SetHeaders. +// copy; to modify it call AddHeader or SetHeaders. func (c *Client) Headers() http.Header { c.modifyLock.RLock() defer c.modifyLock.RUnlock() @@ -622,11 +622,19 @@ func (c *Client) Headers() http.Header { return ret } -// SetHeaders sets the headers to be used for future requests. +// AddHeader allows a single header key/value pair to be added +// in a race-safe fashion. +func (c *Client) AddHeader(key, value string) { + c.modifyLock.Lock() + defer c.modifyLock.Unlock() + c.headers.Add(key, value) +} + +// SetHeaders clears all previous headers and uses only the given +// ones going forward. func (c *Client) SetHeaders(headers http.Header) { c.modifyLock.Lock() defer c.modifyLock.Unlock() - c.headers = headers } @@ -680,8 +688,8 @@ func (c *Client) SetPolicyOverride(override bool) { // portMap defines the standard port map var portMap = map[string]string{ - "http": "80", - "https": "443", + "http": "80", + "https": "443", } // NewRequest creates a new raw request object to query the Vault server @@ -703,7 +711,7 @@ func (c *Client) NewRequest(method, requestPath string) *Request { // Avoid lookup of SRV record if scheme is known port, ok := portMap[addr.Scheme] if ok { - host = net.JoinHostPort(host, port) + host = net.JoinHostPort(host, port) } else { // Internet Draft specifies that the SRV record is ignored if a port is given _, addrs, err := net.LookupSRV("http", "tcp", addr.Hostname()) diff --git a/api/client_test.go b/api/client_test.go index 2bd4ca9146..63c822fd21 100644 --- a/api/client_test.go +++ b/api/client_test.go @@ -325,3 +325,52 @@ func TestClone(t *testing.T) { _ = client2 } + +func TestSetHeadersRaceSafe(t *testing.T) { + client, err1 := NewClient(nil) + if err1 != nil { + t.Fatalf("NewClient failed: %v", err1) + } + + start := make(chan interface{}) + done := make(chan interface{}) + + testPairs := map[string]string{ + "soda": "rootbeer", + "veggie": "carrots", + "fruit": "apples", + "color": "red", + "protein": "egg", + } + + for key, value := range testPairs { + tmpKey := key + tmpValue := value + go func() { + <-start + // This test fails if here, you replace client.AddHeader(tmpKey, tmpValue) with: + // headerCopy := client.Header() + // headerCopy.AddHeader(tmpKey, tmpValue) + // client.SetHeader(headerCopy) + client.AddHeader(tmpKey, tmpValue) + done <- true + }() + } + + // Start everyone at once. + close(start) + + // Wait until everyone is done. + for i := 0; i < len(testPairs); i++ { + <-done + } + + // Check that all the test pairs are in the resulting + // headers. + resultingHeaders := client.Headers() + for key, value := range testPairs { + if resultingHeaders.Get(key) != value { + t.Fatal("expected " + value + " for " + key) + } + } +} diff --git a/builtin/credential/aws/path_role_test.go b/builtin/credential/aws/path_role_test.go index f3daec40b5..36118fa838 100644 --- a/builtin/credential/aws/path_role_test.go +++ b/builtin/credential/aws/path_role_test.go @@ -1002,7 +1002,7 @@ func TestRoleResolutionWithSTSEndpointConfigured(t *testing.T) { /* ARN of an AWS role that Vault can query during testing. This role should exist in your current AWS account and your credentials should have iam:GetRole permissions to query it. - */ + */ assumableRoleArn := os.Getenv("AWS_ASSUMABLE_ROLE_ARN") if assumableRoleArn == "" { t.Skip("skipping because AWS_ASSUMABLE_ROLE_ARN is unset") diff --git a/builtin/logical/pki/backend_test.go b/builtin/logical/pki/backend_test.go index 330f4abee8..4cc576b75c 100644 --- a/builtin/logical/pki/backend_test.go +++ b/builtin/logical/pki/backend_test.go @@ -13,7 +13,6 @@ import ( "encoding/base64" "encoding/pem" "fmt" - "github.com/go-test/deep" "math" "math/big" mathrand "math/rand" @@ -29,6 +28,7 @@ import ( "time" "github.com/fatih/structs" + "github.com/go-test/deep" "github.com/hashicorp/vault/api" logicaltest "github.com/hashicorp/vault/helper/testhelpers/logical" vaulthttp "github.com/hashicorp/vault/http" diff --git a/command/agent.go b/command/agent.go index 111345b4d8..3f9110512d 100644 --- a/command/agent.go +++ b/command/agent.go @@ -27,6 +27,7 @@ import ( "github.com/hashicorp/vault/command/agent/auth/cf" "github.com/hashicorp/vault/command/agent/auth/gcp" "github.com/hashicorp/vault/command/agent/auth/jwt" + "github.com/hashicorp/vault/command/agent/auth/kerberos" "github.com/hashicorp/vault/command/agent/auth/kubernetes" "github.com/hashicorp/vault/command/agent/cache" agentConfig "github.com/hashicorp/vault/command/agent/config" @@ -385,6 +386,8 @@ func (c *AgentCommand) Run(args []string) int { method, err = gcp.NewGCPAuthMethod(authConfig) case "jwt": method, err = jwt.NewJWTAuthMethod(authConfig) + case "kerberos": + method, err = kerberos.NewKerberosAuthMethod(authConfig) case "kubernetes": method, err = kubernetes.NewKubernetesAuthMethod(authConfig) case "approle": diff --git a/command/agent/auth/alicloud/alicloud.go b/command/agent/auth/alicloud/alicloud.go index dbccdd57c3..ff9a4341f2 100644 --- a/command/agent/auth/alicloud/alicloud.go +++ b/command/agent/auth/alicloud/alicloud.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "net/http" "reflect" "sync" "time" @@ -174,16 +175,16 @@ type alicloudMethod struct { stopCh chan struct{} } -func (a *alicloudMethod) Authenticate(context.Context, *api.Client) (string, map[string]interface{}, error) { +func (a *alicloudMethod) Authenticate(context.Context, *api.Client) (string, http.Header, map[string]interface{}, error) { a.credLock.Lock() defer a.credLock.Unlock() a.logger.Trace("beginning authentication") data, err := tools.GenerateLoginData(a.role, a.lastCreds, a.region) if err != nil { - return "", nil, err + return "", nil, nil, err } - return fmt.Sprintf("%s/login", a.mountPath), data, nil + return fmt.Sprintf("%s/login", a.mountPath), nil, data, nil } func (a *alicloudMethod) NewCreds() chan struct{} { diff --git a/command/agent/auth/approle/approle.go b/command/agent/auth/approle/approle.go index 5c319fafaa..f9970348b7 100644 --- a/command/agent/auth/approle/approle.go +++ b/command/agent/auth/approle/approle.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "io/ioutil" + "net/http" "os" "strings" @@ -87,18 +88,18 @@ func NewApproleAuthMethod(conf *auth.AuthConfig) (auth.AuthMethod, error) { return a, nil } -func (a *approleMethod) Authenticate(ctx context.Context, client *api.Client) (string, map[string]interface{}, error) { +func (a *approleMethod) Authenticate(ctx context.Context, client *api.Client) (string, http.Header, map[string]interface{}, error) { if _, err := os.Stat(a.roleIDFilePath); err == nil { roleID, err := ioutil.ReadFile(a.roleIDFilePath) if err != nil { if a.cachedRoleID == "" { - return "", nil, errwrap.Wrapf("error reading role ID file and no cached role ID known: {{err}}", err) + return "", nil, nil, errwrap.Wrapf("error reading role ID file and no cached role ID known: {{err}}", err) } a.logger.Warn("error reading role ID file", "error", err) } if len(roleID) == 0 { if a.cachedRoleID == "" { - return "", nil, errors.New("role ID file empty and no cached role ID known") + return "", nil, nil, errors.New("role ID file empty and no cached role ID known") } a.logger.Warn("role ID file exists but read empty value, re-using cached value") } else { @@ -107,11 +108,11 @@ func (a *approleMethod) Authenticate(ctx context.Context, client *api.Client) (s } if a.cachedRoleID == "" { - return "", nil, errors.New("no known role ID") + return "", nil, nil, errors.New("no known role ID") } if a.secretIDFilePath == "" { - return fmt.Sprintf("%s/login", a.mountPath), map[string]interface{}{ + return fmt.Sprintf("%s/login", a.mountPath), nil, map[string]interface{}{ "role_id": a.cachedRoleID, }, nil } @@ -120,13 +121,13 @@ func (a *approleMethod) Authenticate(ctx context.Context, client *api.Client) (s secretID, err := ioutil.ReadFile(a.secretIDFilePath) if err != nil { if a.cachedSecretID == "" { - return "", nil, errwrap.Wrapf("error reading secret ID file and no cached secret ID known: {{err}}", err) + return "", nil, nil, errwrap.Wrapf("error reading secret ID file and no cached secret ID known: {{err}}", err) } a.logger.Warn("error reading secret ID file", "error", err) } if len(secretID) == 0 { if a.cachedSecretID == "" { - return "", nil, errors.New("secret ID file empty and no cached secret ID known") + return "", nil, nil, errors.New("secret ID file empty and no cached secret ID known") } a.logger.Warn("secret ID file exists but read empty value, re-using cached value") } else { @@ -134,50 +135,50 @@ func (a *approleMethod) Authenticate(ctx context.Context, client *api.Client) (s if a.secretIDResponseWrappingPath != "" { clonedClient, err := client.Clone() if err != nil { - return "", nil, errwrap.Wrapf("error cloning client to unwrap secret ID: {{err}}", err) + return "", nil, nil, errwrap.Wrapf("error cloning client to unwrap secret ID: {{err}}", err) } clonedClient.SetToken(stringSecretID) // Validate the creation path resp, err := clonedClient.Logical().Read("sys/wrapping/lookup") if err != nil { - return "", nil, errwrap.Wrapf("error looking up wrapped secret ID: {{err}}", err) + return "", nil, nil, errwrap.Wrapf("error looking up wrapped secret ID: {{err}}", err) } if resp == nil { - return "", nil, errors.New("response nil when looking up wrapped secret ID") + return "", nil, nil, errors.New("response nil when looking up wrapped secret ID") } if resp.Data == nil { - return "", nil, errors.New("data in response nil when looking up wrapped secret ID") + return "", nil, nil, errors.New("data in response nil when looking up wrapped secret ID") } creationPathRaw, ok := resp.Data["creation_path"] if !ok { - return "", nil, errors.New("creation_path in response nil when looking up wrapped secret ID") + return "", nil, nil, errors.New("creation_path in response nil when looking up wrapped secret ID") } creationPath, ok := creationPathRaw.(string) if !ok { - return "", nil, errors.New("creation_path in response could not be parsed as string when looking up wrapped secret ID") + return "", nil, nil, errors.New("creation_path in response could not be parsed as string when looking up wrapped secret ID") } if creationPath != a.secretIDResponseWrappingPath { a.logger.Error("SECURITY: unable to validate wrapping token creation path", "expected", a.secretIDResponseWrappingPath, "found", creationPath) - return "", nil, errors.New("unable to validate wrapping token creation path") + return "", nil, nil, errors.New("unable to validate wrapping token creation path") } // Now get the secret ID resp, err = clonedClient.Logical().Unwrap("") if err != nil { - return "", nil, errwrap.Wrapf("error unwrapping secret ID: {{err}}", err) + return "", nil, nil, errwrap.Wrapf("error unwrapping secret ID: {{err}}", err) } if resp == nil { - return "", nil, errors.New("response nil when unwrapping secret ID") + return "", nil, nil, errors.New("response nil when unwrapping secret ID") } if resp.Data == nil { - return "", nil, errors.New("data in response nil when unwrapping secret ID") + return "", nil, nil, errors.New("data in response nil when unwrapping secret ID") } secretIDRaw, ok := resp.Data["secret_id"] if !ok { - return "", nil, errors.New("secret_id in response nil when unwrapping secret ID") + return "", nil, nil, errors.New("secret_id in response nil when unwrapping secret ID") } secretID, ok := secretIDRaw.(string) if !ok { - return "", nil, errors.New("secret_id in response could not be parsed as string when unwrapping secret ID") + return "", nil, nil, errors.New("secret_id in response could not be parsed as string when unwrapping secret ID") } stringSecretID = secretID } @@ -191,10 +192,10 @@ func (a *approleMethod) Authenticate(ctx context.Context, client *api.Client) (s } if a.cachedSecretID == "" { - return "", nil, errors.New("no known secret ID") + return "", nil, nil, errors.New("no known secret ID") } - return fmt.Sprintf("%s/login", a.mountPath), map[string]interface{}{ + return fmt.Sprintf("%s/login", a.mountPath), nil, map[string]interface{}{ "role_id": a.cachedRoleID, "secret_id": a.cachedSecretID, }, nil diff --git a/command/agent/auth/auth.go b/command/agent/auth/auth.go index 3b5460df35..b1e33bb961 100644 --- a/command/agent/auth/auth.go +++ b/command/agent/auth/auth.go @@ -3,6 +3,7 @@ package auth import ( "context" "math/rand" + "net/http" "time" hclog "github.com/hashicorp/go-hclog" @@ -11,7 +12,9 @@ import ( ) type AuthMethod interface { - Authenticate(context.Context, *api.Client) (string, map[string]interface{}, error) + // Authenticate returns a mount path, header, request body, and error. + // The header may be nil if no special header is needed. + Authenticate(context.Context, *api.Client) (string, http.Header, map[string]interface{}, error) NewCreds() chan struct{} CredSuccess() Shutdown() @@ -119,7 +122,7 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) { backoff := 2*time.Second + time.Duration(ah.random.Int63()%int64(time.Second*2)-int64(time.Second)) ah.logger.Info("authenticating") - path, data, err := am.Authenticate(ctx, ah.client) + path, header, data, err := am.Authenticate(ctx, ah.client) if err != nil { ah.logger.Error("error getting path or data from method", "error", err, "backoff", backoff.Seconds()) backoffOrQuit(ctx, backoff) @@ -139,6 +142,11 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) { }) clientToUse = wrapClient } + for key, values := range header { + for _, value := range values { + clientToUse.AddHeader(key, value) + } + } secret, err := clientToUse.Logical().Write(path, data) // Check errors/sanity diff --git a/command/agent/auth/auth_test.go b/command/agent/auth/auth_test.go index 297b518257..c08880141b 100644 --- a/command/agent/auth/auth_test.go +++ b/command/agent/auth/auth_test.go @@ -2,6 +2,7 @@ package auth import ( "context" + "net/http" "testing" "time" @@ -31,14 +32,14 @@ func newUserpassTestMethod(t *testing.T, client *api.Client) AuthMethod { return &userpassTestMethod{} } -func (u *userpassTestMethod) Authenticate(_ context.Context, client *api.Client) (string, map[string]interface{}, error) { +func (u *userpassTestMethod) Authenticate(_ context.Context, client *api.Client) (string, http.Header, map[string]interface{}, error) { _, err := client.Logical().Write("auth/userpass/users/foo", map[string]interface{}{ "password": "bar", }) if err != nil { - return "", nil, err + return "", nil, nil, err } - return "auth/userpass/login/foo", map[string]interface{}{ + return "auth/userpass/login/foo", nil, map[string]interface{}{ "password": "bar", }, nil } diff --git a/command/agent/auth/aws/aws.go b/command/agent/auth/aws/aws.go index 306d45f527..4330148dfd 100644 --- a/command/agent/auth/aws/aws.go +++ b/command/agent/auth/aws/aws.go @@ -179,7 +179,7 @@ func NewAWSAuthMethod(conf *auth.AuthConfig) (auth.AuthMethod, error) { return a, nil } -func (a *awsMethod) Authenticate(ctx context.Context, client *api.Client) (retToken string, retData map[string]interface{}, retErr error) { +func (a *awsMethod) Authenticate(ctx context.Context, client *api.Client) (retToken string, header http.Header, retData map[string]interface{}, retErr error) { a.logger.Trace("beginning authentication") data := make(map[string]interface{}) @@ -266,7 +266,7 @@ func (a *awsMethod) Authenticate(ctx context.Context, client *api.Client) (retTo data["role"] = a.role - return fmt.Sprintf("%s/login", a.mountPath), data, nil + return fmt.Sprintf("%s/login", a.mountPath), nil, data, nil } func (a *awsMethod) NewCreds() chan struct{} { diff --git a/command/agent/auth/azure/azure.go b/command/agent/auth/azure/azure.go index 01a0ada341..26a4e4af5e 100644 --- a/command/agent/auth/azure/azure.go +++ b/command/agent/auth/azure/azure.go @@ -74,7 +74,7 @@ func NewAzureAuthMethod(conf *auth.AuthConfig) (auth.AuthMethod, error) { return a, nil } -func (a *azureMethod) Authenticate(ctx context.Context, client *api.Client) (retPath string, retData map[string]interface{}, retErr error) { +func (a *azureMethod) Authenticate(ctx context.Context, client *api.Client) (retPath string, header http.Header, retData map[string]interface{}, retErr error) { a.logger.Trace("beginning authentication") // Fetch instance data @@ -126,7 +126,7 @@ func (a *azureMethod) Authenticate(ctx context.Context, client *api.Client) (ret "jwt": identity.AccessToken, } - return fmt.Sprintf("%s/login", a.mountPath), data, nil + return fmt.Sprintf("%s/login", a.mountPath), nil, data, nil } func (a *azureMethod) NewCreds() chan struct{} { diff --git a/command/agent/auth/cert/cert.go b/command/agent/auth/cert/cert.go index fc1f42606d..2265387948 100644 --- a/command/agent/auth/cert/cert.go +++ b/command/agent/auth/cert/cert.go @@ -4,8 +4,9 @@ import ( "context" "errors" "fmt" + "net/http" - hclog "github.com/hashicorp/go-hclog" + "github.com/hashicorp/go-hclog" "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/command/agent/auth" ) @@ -44,7 +45,7 @@ func NewCertAuthMethod(conf *auth.AuthConfig) (auth.AuthMethod, error) { return c, nil } -func (c *certMethod) Authenticate(_ context.Context, client *api.Client) (string, map[string]interface{}, error) { +func (c *certMethod) Authenticate(_ context.Context, client *api.Client) (string, http.Header, map[string]interface{}, error) { c.logger.Trace("beginning authentication") authMap := map[string]interface{}{} @@ -53,7 +54,7 @@ func (c *certMethod) Authenticate(_ context.Context, client *api.Client) (string authMap["name"] = c.name } - return fmt.Sprintf("%s/login", c.mountPath), authMap, nil + return fmt.Sprintf("%s/login", c.mountPath), nil, authMap, nil } func (c *certMethod) NewCreds() chan struct{} { diff --git a/command/agent/auth/cf/cf.go b/command/agent/auth/cf/cf.go index dcd17dbde6..9508b7164f 100644 --- a/command/agent/auth/cf/cf.go +++ b/command/agent/auth/cf/cf.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "io/ioutil" + "net/http" "os" "time" @@ -41,18 +42,18 @@ func NewCFAuthMethod(conf *auth.AuthConfig) (auth.AuthMethod, error) { return a, nil } -func (p *cfMethod) Authenticate(ctx context.Context, client *api.Client) (string, map[string]interface{}, error) { +func (p *cfMethod) Authenticate(ctx context.Context, client *api.Client) (string, http.Header, map[string]interface{}, error) { pathToClientCert := os.Getenv(cf.EnvVarInstanceCertificate) if pathToClientCert == "" { - return "", nil, fmt.Errorf("missing %q value", cf.EnvVarInstanceCertificate) + return "", nil, nil, fmt.Errorf("missing %q value", cf.EnvVarInstanceCertificate) } certBytes, err := ioutil.ReadFile(pathToClientCert) if err != nil { - return "", nil, err + return "", nil, nil, err } pathToClientKey := os.Getenv(cf.EnvVarInstanceKey) if pathToClientKey == "" { - return "", nil, fmt.Errorf("missing %q value", cf.EnvVarInstanceKey) + return "", nil, nil, fmt.Errorf("missing %q value", cf.EnvVarInstanceKey) } signingTime := time.Now().UTC() signatureData := &signatures.SignatureData{ @@ -62,7 +63,7 @@ func (p *cfMethod) Authenticate(ctx context.Context, client *api.Client) (string } signature, err := signatures.Sign(pathToClientKey, signatureData) if err != nil { - return "", nil, err + return "", nil, nil, err } data := map[string]interface{}{ "role": p.roleName, @@ -70,7 +71,7 @@ func (p *cfMethod) Authenticate(ctx context.Context, client *api.Client) (string "signing_time": signingTime.Format(signatures.TimeFormat), "signature": signature, } - return fmt.Sprintf("%s/login", p.mountPath), data, nil + return fmt.Sprintf("%s/login", p.mountPath), nil, data, nil } func (p *cfMethod) NewCreds() chan struct{} { diff --git a/command/agent/auth/gcp/gcp.go b/command/agent/auth/gcp/gcp.go index f4a7d92dce..666593c99b 100644 --- a/command/agent/auth/gcp/gcp.go +++ b/command/agent/auth/gcp/gcp.go @@ -116,7 +116,7 @@ func NewGCPAuthMethod(conf *auth.AuthConfig) (auth.AuthMethod, error) { return g, nil } -func (g *gcpMethod) Authenticate(ctx context.Context, client *api.Client) (retPath string, retData map[string]interface{}, retErr error) { +func (g *gcpMethod) Authenticate(ctx context.Context, client *api.Client) (retPath string, header http.Header, retData map[string]interface{}, retErr error) { g.logger.Trace("beginning authentication") data := make(map[string]interface{}) @@ -227,7 +227,7 @@ func (g *gcpMethod) Authenticate(ctx context.Context, client *api.Client) (retPa data["role"] = g.role data["jwt"] = jwt - return fmt.Sprintf("%s/login", g.mountPath), data, nil + return fmt.Sprintf("%s/login", g.mountPath), nil, data, nil } func (g *gcpMethod) NewCreds() chan struct{} { diff --git a/command/agent/auth/jwt/jwt.go b/command/agent/auth/jwt/jwt.go index e77265536a..403a57fae1 100644 --- a/command/agent/auth/jwt/jwt.go +++ b/command/agent/auth/jwt/jwt.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "io/ioutil" + "net/http" "os" "sync" "sync/atomic" @@ -85,17 +86,17 @@ func NewJWTAuthMethod(conf *auth.AuthConfig) (auth.AuthMethod, error) { return j, nil } -func (j *jwtMethod) Authenticate(_ context.Context, client *api.Client) (string, map[string]interface{}, error) { +func (j *jwtMethod) Authenticate(_ context.Context, client *api.Client) (string, http.Header, map[string]interface{}, error) { j.logger.Trace("beginning authentication") j.ingressToken() latestToken := j.latestToken.Load().(string) if latestToken == "" { - return "", nil, errors.New("latest known jwt is empty, cannot authenticate") + return "", nil, nil, errors.New("latest known jwt is empty, cannot authenticate") } - return fmt.Sprintf("%s/login", j.mountPath), map[string]interface{}{ + return fmt.Sprintf("%s/login", j.mountPath), nil, map[string]interface{}{ "role": j.role, "jwt": latestToken, }, nil diff --git a/command/agent/auth/kerberos/integtest/integrationtest.sh b/command/agent/auth/kerberos/integtest/integrationtest.sh new file mode 100755 index 0000000000..28da55f599 --- /dev/null +++ b/command/agent/auth/kerberos/integtest/integrationtest.sh @@ -0,0 +1,170 @@ +#!/bin/bash +# Instructions +# This integration test is for the Vault Kerberos agent. +# Before running, execute: +# pip install --quiet requests-kerberos +# Then run this test from Vault's home directory. +# ./command/agent/auth/kerberos/integtest/integrationtest.sh + +if [[ "$OSTYPE" == "darwin"* ]]; then + base64cmd="base64 -D" +else + base64cmd="base64 -d" +fi + +VAULT_PORT=8200 +SAMBA_VER=4.8.12 + +export VAULT_TOKEN=${VAULT_TOKEN:-myroot} +DOMAIN_ADMIN_PASS=Pa55word! +DOMAIN_VAULT_ACCOUNT=vault_svc +DOMAIN_VAULT_PASS=vaultPa55word! +DOMAIN_USER_ACCOUNT=grace +DOMAIN_USER_PASS=gracePa55word! + +SAMBA_CONF_FILE=/srv/etc/smb.conf +DOMAIN_NAME=matrix +DNS_NAME=host +REALM_NAME=MATRIX.LAN +DOMAIN_DN=DC=MATRIX,DC=LAN +TESTS_DIR=/tmp/vault_plugin_tests + +function add_user() { + + username="${1}" + password="${2}" + + if [[ $(check_user ${username}) -eq 0 ]] + then + echo "add user '${username}'" + + docker exec $SAMBA_CONTAINER \ + /usr/bin/samba-tool user create \ + ${username} \ + ${password}\ + --configfile=${SAMBA_CONF_FILE} + fi +} + +function check_user() { + + username="${1}" + + docker exec $SAMBA_CONTAINER \ + /usr/bin/samba-tool user list \ + --configfile=${SAMBA_CONF_FILE} \ + | grep -c ${username} +} + +function create_keytab() { + + username="${1}" + password="${2}" + + user_kvno=$(docker exec $SAMBA_CONTAINER \ + bash -c "ldapsearch -H ldaps://localhost -D \"Administrator@${REALM_NAME}\" -w \"${DOMAIN_ADMIN_PASS}\" -b \"CN=Users,${DOMAIN_DN}\" -LLL \"(&(objectClass=user)(sAMAccountName=${username}))\" msDS-KeyVersionNumber | sed -n 's/^[ \t]*msDS-KeyVersionNumber:[ \t]*\(.*\)/\1/p'") + + docker exec $SAMBA_CONTAINER \ + bash -c "printf \"%b\" \"addent -password -p \"${username}@${REALM_NAME}\" -k ${user_kvno} -e rc4-hmac\n${password}\nwrite_kt ${username}.keytab\" | ktutil" + + docker exec $SAMBA_CONTAINER \ + bash -c "printf \"%b\" \"read_kt ${username}.keytab\nlist\" | ktutil" + + docker exec $SAMBA_CONTAINER \ + base64 ${username}.keytab > ${TESTS_DIR}/integration/${username}.keytab.base64 + + docker cp $SAMBA_CONTAINER:/${username}.keytab ${TESTS_DIR}/integration/ +} + +function main() { + # make and start vault + make dev + vault server -dev -dev-root-token-id=root & + + # start our domain controller + SAMBA_CONTAINER=$(docker run --net=${DNS_NAME} -d -ti --privileged -e "SAMBA_DC_ADMIN_PASSWD=${DOMAIN_ADMIN_PASS}" -e "KERBEROS_PASSWORD=${DOMAIN_ADMIN_PASS}" -e SAMBA_DC_DOMAIN=${DOMAIN_NAME} -e SAMBA_DC_REALM=${REALM_NAME} "bodsch/docker-samba4:${SAMBA_VER}") + sleep 15 + + # set up users + add_user $DOMAIN_VAULT_ACCOUNT $DOMAIN_VAULT_PASS + create_keytab $DOMAIN_VAULT_ACCOUNT $DOMAIN_VAULT_PASS + + add_user $DOMAIN_USER_ACCOUNT $DOMAIN_USER_PASS + create_keytab $DOMAIN_USER_ACCOUNT $DOMAIN_USER_PASS + + # add the service principals we'll need + docker exec $SAMBA_CONTAINER \ + samba-tool spn add HTTP/localhost ${DOMAIN_VAULT_ACCOUNT} --configfile=${SAMBA_CONF_FILE} + docker exec $SAMBA_CONTAINER \ + samba-tool spn add HTTP/localhost:${VAULT_PORT} ${DOMAIN_VAULT_ACCOUNT} --configfile=${SAMBA_CONF_FILE} + docker exec $SAMBA_CONTAINER \ + samba-tool spn add HTTP/localhost.${DNS_NAME} ${DOMAIN_VAULT_ACCOUNT} --configfile=${SAMBA_CONF_FILE} + docker exec $SAMBA_CONTAINER \ + samba-tool spn add HTTP/localhost.${DNS_NAME}:${VAULT_PORT} ${DOMAIN_VAULT_ACCOUNT} --configfile=${SAMBA_CONF_FILE} + + # enable and configure the kerberos plugin in Vault + vault auth enable -passthrough-request-headers=Authorization -allowed-response-headers=www-authenticate kerberos + vault write auth/kerberos/config keytab=@${TESTS_DIR}/integration/vault_svc.keytab.base64 service_account="vault_svc" + vault write auth/kerberos/config/ldap binddn=${DOMAIN_VAULT_ACCOUNT}@${REALM_NAME} bindpass=${DOMAIN_VAULT_PASS} groupattr=sAMAccountName groupdn="${DOMAIN_DN}" groupfilter="(&(objectClass=group)(member:1.2.840.113556.1.4.1941:={{.UserDN}}))" insecure_tls=true starttls=true userdn="CN=Users,${DOMAIN_DN}" userattr=sAMAccountName upndomain=${REALM_NAME} url=ldaps://localhost:636 + + mkdir -p ${TESTS_DIR}/integration + + echo " +[libdefaults] + default_realm = ${REALM_NAME} + dns_lookup_realm = false + dns_lookup_kdc = true + ticket_lifetime = 24h + renew_lifetime = 7d + forwardable = true + rdns = false + preferred_preauth_types = 23 +[realms] + ${REALM_NAME} = { + kdc = localhost + admin_server = localhost + master_kdc = localhost + default_domain = localhost + } +" > ${TESTS_DIR}/integration/krb5.conf + + echo " +auto_auth { + method \"kerberos\" { + mount_path = \"auth/kerberos\" + config = { + username = \"$DOMAIN_USER_ACCOUNT\" + service = \"HTTP/localhost:8200\" + realm = \"$REALM_NAME\" + keytab_path = \"$TESTS_DIR/integration/grace.keytab\" + krb5conf_path = \"$TESTS_DIR/integration/krb5.conf\" + } + } + sink \"file\" { + config = { + path = \"$TESTS_DIR/integration/agent-token.txt\" + } + } +} +" > ${TESTS_DIR}/integration/agent.conf + + vault agent -config=${TESTS_DIR}/integration/agent.conf & + sleep 10 + token=$(cat $TESTS_DIR/integration/agent-token.txt) + + # clean up: kill vault and stop the docker container we started + kill -9 $(ps aux | grep vault | awk '{print $2}' | head -1) # kill vault server + kill -9 $(ps aux | grep vault | awk '{print $2}' | head -1) # kill vault agent + docker rm -f ${SAMBA_CONTAINER} + + # a valid Vault token starts with "s.", check for that + if [[ $token != s.* ]]; then + echo "received invalid token: $token" + return 1 + fi + + echo "vault kerberos agent obtained auth token: $token" + echo "exiting successfully!" + return 0 +} +main diff --git a/command/agent/auth/kerberos/kerberos.go b/command/agent/auth/kerberos/kerberos.go new file mode 100644 index 0000000000..5f5ff034b9 --- /dev/null +++ b/command/agent/auth/kerberos/kerberos.go @@ -0,0 +1,91 @@ +package kerberos + +import ( + "context" + "errors" + "fmt" + "net/http" + + "github.com/hashicorp/go-hclog" + "github.com/hashicorp/gokrb5/spnego" + kerberos "github.com/hashicorp/vault-plugin-auth-kerberos" + "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/command/agent/auth" +) + +type kerberosMethod struct { + logger hclog.Logger + mountPath string + loginCfg *kerberos.LoginCfg +} + +func NewKerberosAuthMethod(conf *auth.AuthConfig) (auth.AuthMethod, error) { + if conf == nil { + return nil, errors.New("empty config") + } + if conf.Config == nil { + return nil, errors.New("empty config data") + } + username, err := read("username", conf.Config) + if err != nil { + return nil, err + } + service, err := read("service", conf.Config) + if err != nil { + return nil, err + } + realm, err := read("realm", conf.Config) + if err != nil { + return nil, err + } + keytabPath, err := read("keytab_path", conf.Config) + if err != nil { + return nil, err + } + krb5ConfPath, err := read("krb5conf_path", conf.Config) + if err != nil { + return nil, err + } + return &kerberosMethod{ + logger: conf.Logger, + mountPath: conf.MountPath, + loginCfg: &kerberos.LoginCfg{ + Username: username, + Service: service, + Realm: realm, + KeytabPath: keytabPath, + Krb5ConfPath: krb5ConfPath, + }, + }, nil +} + +func (k *kerberosMethod) Authenticate(context.Context, *api.Client) (string, http.Header, map[string]interface{}, error) { + k.logger.Trace("beginning authentication") + authHeaderVal, err := kerberos.GetAuthHeaderVal(k.loginCfg) + if err != nil { + return "", nil, nil, err + } + var header http.Header + header = make(map[string][]string) + header.Set(spnego.HTTPHeaderAuthRequest, authHeaderVal) + return k.mountPath + "/login", header, make(map[string]interface{}), nil +} + +// These functions are implemented to meet the AuthHandler interface, +// but we don't need to take advantage of them. +func (k *kerberosMethod) NewCreds() chan struct{} { return nil } +func (k *kerberosMethod) CredSuccess() {} +func (k *kerberosMethod) Shutdown() {} + +// read reads a key from a map and convert its value to a string. +func read(key string, m map[string]interface{}) (string, error) { + raw, ok := m[key] + if !ok { + return "", fmt.Errorf("%q is required", key) + } + v, ok := raw.(string) + if !ok { + return "", fmt.Errorf("%q must be a string", key) + } + return v, nil +} diff --git a/command/agent/auth/kerberos/kerberos_test.go b/command/agent/auth/kerberos/kerberos_test.go new file mode 100644 index 0000000000..01d4a82882 --- /dev/null +++ b/command/agent/auth/kerberos/kerberos_test.go @@ -0,0 +1,67 @@ +package kerberos + +import ( + "testing" + + "github.com/hashicorp/go-hclog" + "github.com/hashicorp/vault/command/agent/auth" +) + +func TestNewKerberosAuthMethod(t *testing.T) { + if _, err := NewKerberosAuthMethod(nil); err == nil { + t.Fatal("err should be returned for nil input") + } + if _, err := NewKerberosAuthMethod(&auth.AuthConfig{}); err == nil { + t.Fatal("err should be returned for nil config map") + } + + authConfig := simpleAuthConfig() + delete(authConfig.Config, "username") + if _, err := NewKerberosAuthMethod(authConfig); err == nil { + t.Fatal("err should be returned for missing username") + } + + authConfig = simpleAuthConfig() + delete(authConfig.Config, "service") + if _, err := NewKerberosAuthMethod(authConfig); err == nil { + t.Fatal("err should be returned for missing service") + } + + authConfig = simpleAuthConfig() + delete(authConfig.Config, "realm") + if _, err := NewKerberosAuthMethod(authConfig); err == nil { + t.Fatal("err should be returned for missing realm") + } + + authConfig = simpleAuthConfig() + delete(authConfig.Config, "keytab_path") + if _, err := NewKerberosAuthMethod(authConfig); err == nil { + t.Fatal("err should be returned for missing keytab_path") + } + + authConfig = simpleAuthConfig() + delete(authConfig.Config, "krb5conf_path") + if _, err := NewKerberosAuthMethod(authConfig); err == nil { + t.Fatal("err should be returned for missing krb5conf_path") + } + + authConfig = simpleAuthConfig() + if _, err := NewKerberosAuthMethod(authConfig); err != nil { + t.Fatal(err) + } +} + +func simpleAuthConfig() *auth.AuthConfig { + return &auth.AuthConfig{ + Logger: hclog.NewNullLogger(), + MountPath: "kerberos", + WrapTTL: 20, + Config: map[string]interface{}{ + "username": "grace", + "service": "HTTP/05a65fad28ef.matrix.lan:8200", + "realm": "MATRIX.LAN", + "keytab_path": "grace.keytab", + "krb5conf_path": "krb5.conf", + }, + } +} diff --git a/command/agent/auth/kubernetes/kubernetes.go b/command/agent/auth/kubernetes/kubernetes.go index 21542bfc66..8b35b30ae6 100644 --- a/command/agent/auth/kubernetes/kubernetes.go +++ b/command/agent/auth/kubernetes/kubernetes.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "io/ioutil" + "net/http" "os" "strings" @@ -72,15 +73,15 @@ func NewKubernetesAuthMethod(conf *auth.AuthConfig) (auth.AuthMethod, error) { return k, nil } -func (k *kubernetesMethod) Authenticate(ctx context.Context, client *api.Client) (string, map[string]interface{}, error) { +func (k *kubernetesMethod) Authenticate(ctx context.Context, client *api.Client) (string, http.Header, map[string]interface{}, error) { k.logger.Trace("beginning authentication") jwtString, err := k.readJWT() if err != nil { - return "", nil, errwrap.Wrapf("error reading JWT with Kubernetes Auth: {{err}}", err) + return "", nil, nil, errwrap.Wrapf("error reading JWT with Kubernetes Auth: {{err}}", err) } - return fmt.Sprintf("%s/login", k.mountPath), map[string]interface{}{ + return fmt.Sprintf("%s/login", k.mountPath), nil, map[string]interface{}{ "role": k.role, "jwt": jwtString, }, nil diff --git a/command/agent/auth/kubernetes/kubernetes_test.go b/command/agent/auth/kubernetes/kubernetes_test.go index 6ea560075e..8032d11e09 100644 --- a/command/agent/auth/kubernetes/kubernetes_test.go +++ b/command/agent/auth/kubernetes/kubernetes_test.go @@ -61,7 +61,7 @@ func TestKubernetesAuth_basic(t *testing.T) { k.jwtData = tc.data } - _, data, err := k.Authenticate(context.Background(), nil) + _, _, data, err := k.Authenticate(context.Background(), nil) if err != nil && tc.e == nil { t.Fatal(err) } diff --git a/command/agent/aws_end_to_end_test.go b/command/agent/aws_end_to_end_test.go index 95b465b0e8..abb2921e7b 100644 --- a/command/agent/aws_end_to_end_test.go +++ b/command/agent/aws_end_to_end_test.go @@ -2,7 +2,6 @@ package agent import ( "context" - "fmt" "io/ioutil" "os" "testing" @@ -77,7 +76,6 @@ func TestAWSEndToEnd(t *testing.T) { // Retain thru the account number of the given arn and wildcard the rest. "bound_iam_principal_arn": os.Getenv(envVarAwsTestRoleArn)[:25] + "*", }); err != nil { - fmt.Println(err) t.Fatal(err) } diff --git a/helper/metricsutil/metricsutil_test.go b/helper/metricsutil/metricsutil_test.go index bf33b76045..1b817ddad1 100644 --- a/helper/metricsutil/metricsutil_test.go +++ b/helper/metricsutil/metricsutil_test.go @@ -1,8 +1,9 @@ package metricsutil import ( - "github.com/hashicorp/vault/sdk/logical" "testing" + + "github.com/hashicorp/vault/sdk/logical" ) func TestFormatFromRequest(t *testing.T) { diff --git a/vendor/github.com/hashicorp/vault/api/client.go b/vendor/github.com/hashicorp/vault/api/client.go index a6a8bb6091..434663b12a 100644 --- a/vendor/github.com/hashicorp/vault/api/client.go +++ b/vendor/github.com/hashicorp/vault/api/client.go @@ -603,7 +603,7 @@ func (c *Client) ClearToken() { } // Headers gets the current set of headers used for requests. This returns a -// copy; to modify it make modifications locally and use SetHeaders. +// copy; to modify it call AddHeader or SetHeaders. func (c *Client) Headers() http.Header { c.modifyLock.RLock() defer c.modifyLock.RUnlock() @@ -622,11 +622,19 @@ func (c *Client) Headers() http.Header { return ret } -// SetHeaders sets the headers to be used for future requests. +// AddHeader allows a single header key/value pair to be added +// in a race-safe fashion. +func (c *Client) AddHeader(key, value string) { + c.modifyLock.Lock() + defer c.modifyLock.Unlock() + c.headers.Add(key, value) +} + +// SetHeaders clears all previous headers and uses only the given +// ones going forward. func (c *Client) SetHeaders(headers http.Header) { c.modifyLock.Lock() defer c.modifyLock.Unlock() - c.headers = headers }