From a5d1808efe3a2f58acfaed2b87c3fbef22c470bf Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Thu, 16 Mar 2017 11:14:17 -0400 Subject: [PATCH] Always include a hash of the public key and "vault" (to know where it (#2498) came from) when generating a cert for SSH. Follow on from #2494 --- builtin/logical/ssh/backend_test.go | 6 +++--- builtin/logical/ssh/path_roles.go | 1 - builtin/logical/ssh/path_sign.go | 24 +++++++++++++++++------- 3 files changed, 20 insertions(+), 11 deletions(-) diff --git a/builtin/logical/ssh/backend_test.go b/builtin/logical/ssh/backend_test.go index 37a348489e..538455c728 100644 --- a/builtin/logical/ssh/backend_test.go +++ b/builtin/logical/ssh/backend_test.go @@ -587,7 +587,7 @@ func TestBackend_ValidPrincipalsValidatedForHostCertificates(t *testing.T) { }, }), - signCertificateStep("testing", "root", ssh.HostCert, []string{"dummy.example.org", "second.example.com"}, map[string]string{ + signCertificateStep("testing", "vault-root-22608f5ef173aabf700797cb95c5641e792698ec6380e8e1eb55523e39aa5e51", ssh.HostCert, []string{"dummy.example.org", "second.example.com"}, map[string]string{ "option": "value", }, map[string]string{ "extension": "extended", @@ -632,7 +632,7 @@ func TestBackend_OptionsOverrideDefaults(t *testing.T) { }, }), - signCertificateStep("testing", "root", ssh.UserCert, []string{"tuber"}, map[string]string{ + signCertificateStep("testing", "vault-root-22608f5ef173aabf700797cb95c5641e792698ec6380e8e1eb55523e39aa5e51", ssh.UserCert, []string{"tuber"}, map[string]string{ "secondary": "value", }, map[string]string{ "additional": "value", @@ -709,7 +709,7 @@ func validateSSHCertificate(cert *ssh.Certificate, keyId string, certType int, v ttl time.Duration) error { if cert.KeyId != keyId { - return fmt.Errorf("Incorrect KeyId: %v", cert.KeyId) + return fmt.Errorf("Incorrect KeyId: %v, wanted %v", cert.KeyId, keyId) } if cert.CertType != uint32(certType) { diff --git a/builtin/logical/ssh/path_roles.go b/builtin/logical/ssh/path_roles.go index 679d38cc5e..907689f282 100644 --- a/builtin/logical/ssh/path_roles.go +++ b/builtin/logical/ssh/path_roles.go @@ -257,7 +257,6 @@ func pathRoles(b *backend) *framework.Path { }, "allow_user_key_ids": &framework.FieldSchema{ Type: framework.TypeBool, - Default: true, Description: ` [Not applicable for Dynamic type] [Not applicable for OTP type] [Optional for CA type] If true, users can override the key ID for a signed certificate with the "key_id" field. diff --git a/builtin/logical/ssh/path_sign.go b/builtin/logical/ssh/path_sign.go index 7579601797..22e2513162 100644 --- a/builtin/logical/ssh/path_sign.go +++ b/builtin/logical/ssh/path_sign.go @@ -2,6 +2,8 @@ package ssh import ( "crypto/rand" + "crypto/sha256" + "encoding/hex" "errors" "fmt" "strconv" @@ -109,7 +111,7 @@ func (b *backend) pathSignCertificate(req *logical.Request, data *framework.Fiel // Note that these various functions always return "user errors" so we pass // them as 4xx values - keyId, err := b.calculateKeyId(data, req, role) + keyId, err := b.calculateKeyId(data, req, role, userPublicKey) if err != nil { return logical.ErrorResponse(err.Error()), nil } @@ -263,17 +265,25 @@ func (b *backend) calculateCertificateType(data *framework.FieldData, role *sshR return certificateType, nil } -func (b *backend) calculateKeyId(data *framework.FieldData, req *logical.Request, role *sshRole) (string, error) { - keyId := data.Get("key_id").(string) +func (b *backend) calculateKeyId(data *framework.FieldData, req *logical.Request, role *sshRole, pubKey ssh.PublicKey) (string, error) { + reqId := data.Get("key_id").(string) - if keyId != "" && !role.AllowUserKeyIDs { - return "", fmt.Errorf("Setting key_id is not allowed by role") + if reqId != "" { + if !role.AllowUserKeyIDs { + return "", fmt.Errorf("setting key_id is not allowed by role") + } + return reqId, nil } - if keyId == "" { - keyId = req.DisplayName + keyHash := sha256.Sum256(pubKey.Marshal()) + keyId := hex.EncodeToString(keyHash[:]) + + if req.DisplayName != "" { + keyId = fmt.Sprintf("%s-%s", req.DisplayName, keyId) } + keyId = fmt.Sprintf("vault-%s", keyId) + return keyId, nil }