diff --git a/builtin/logical/pki/acme_jws.go b/builtin/logical/pki/acme_jws.go index 034c379c8b..c456008090 100644 --- a/builtin/logical/pki/acme_jws.go +++ b/builtin/logical/pki/acme_jws.go @@ -26,20 +26,20 @@ var AllowedOuterJWSTypes = map[string]interface{}{ type jwsCtx struct { Algo string `json:"alg"` Kid string `json:"kid"` - jwk json.RawMessage `json:"jwk"` + Jwk json.RawMessage `json:"jwk"` Nonce string `json:"nonce"` Url string `json:"url"` - key jose.JSONWebKey `json:"-"` + Key jose.JSONWebKey `json:"-"` Existing bool `json:"-"` } -func (c *jwsCtx) UnmarshalJSON(a *acmeState, jws []byte) error { +func (c *jwsCtx) UnmarshalJSON(a *acmeState, ac *acmeContext, jws []byte) error { var err error if err = json.Unmarshal(jws, c); err != nil { return err } - if c.Kid != "" && len(c.jwk) > 0 { + if c.Kid != "" && len(c.Jwk) > 0 { // See RFC 8555 Section 6.2. Request Authentication: // // > The "jwk" and "kid" fields are mutually exclusive. Servers MUST @@ -47,7 +47,7 @@ func (c *jwsCtx) UnmarshalJSON(a *acmeState, jws []byte) error { return fmt.Errorf("invalid header: got both account 'kid' and 'jwk' in the same message; expected only one: %w", ErrMalformed) } - if c.Kid == "" && len(c.jwk) == 0 { + if c.Kid == "" && len(c.Jwk) == 0 { // See RFC 8555 Section 6.2. Request Authentication: // // > Either "jwk" (JSON Web Key) or "kid" (Key ID) as specified @@ -70,24 +70,24 @@ func (c *jwsCtx) UnmarshalJSON(a *acmeState, jws []byte) error { if c.Kid != "" { // Load KID from storage first. - c.jwk, err = a.LoadJWK(c.Kid) + c.Jwk, err = a.LoadJWK(ac, c.Kid) if err != nil { return err } c.Existing = true } - if err = c.key.UnmarshalJSON(c.jwk); err != nil { + if err = c.Key.UnmarshalJSON(c.Jwk); err != nil { return err } - if !c.key.Valid() { + if !c.Key.Valid() { return fmt.Errorf("received invalid jwk: %w", ErrMalformed) } - if c.Kid != "" { + if c.Kid == "" { // Create a key ID - kid, err := c.key.Thumbprint(crypto.SHA256) + kid, err := c.Key.Thumbprint(crypto.SHA256) if err != nil { return fmt.Errorf("failed creating thumbprint: %w", err) } @@ -128,7 +128,7 @@ func (c *jwsCtx) VerifyJWS(signature string) (map[string]interface{}, error) { return nil, fmt.Errorf("request had unprotected headers: %w", ErrMalformed) } - payload, err := sig.Verify(c.key) + payload, err := sig.Verify(c.Key) if err != nil { return nil, err } diff --git a/builtin/logical/pki/acme_state.go b/builtin/logical/pki/acme_state.go index 36715e1689..5fef9b3297 100644 --- a/builtin/logical/pki/acme_state.go +++ b/builtin/logical/pki/acme_state.go @@ -5,15 +5,23 @@ import ( "encoding/base64" "fmt" "io" + "strings" "sync" "sync/atomic" "time" "github.com/hashicorp/vault/sdk/framework" + "github.com/hashicorp/vault/sdk/logical" ) -// How long nonces are considered valid. -const nonceExpiry = 15 * time.Minute +const ( + // How long nonces are considered valid. + nonceExpiry = 15 * time.Minute + + // Path Prefixes + acmePathPrefix = "acme/" + acmeAccountPrefix = acmePathPrefix + "accounts/" +) type acmeState struct { nextExpiry *atomic.Int64 @@ -99,36 +107,86 @@ func (a *acmeState) TidyNonces() { a.nextExpiry.Store(nextRun.Unix()) } -func (a *acmeState) CreateAccount(c *jwsCtx, contact []string, termsOfServiceAgreed bool) (map[string]interface{}, error) { - // TODO - return nil, nil +type ACMEStates string + +const ( + StatusValid = "valid" + StatusDeactivated = "deactivated" + StatusRevoked = "revoked" +) + +type acmeAccount struct { + KeyId string `json:"-"` + Status ACMEStates `json:"state"` + Contact []string `json:"contact"` + TermsOfServiceAgreed bool `json:"termsOfServiceAgreed"` + Jwk []byte `json:"jwk"` } -func (a *acmeState) LoadAccount(keyID string) (map[string]interface{}, error) { - // TODO - return nil, nil +func (a *acmeState) CreateAccount(ac *acmeContext, c *jwsCtx, contact []string, termsOfServiceAgreed bool) (*acmeAccount, error) { + acct := &acmeAccount{ + KeyId: c.Kid, + Contact: contact, + TermsOfServiceAgreed: termsOfServiceAgreed, + Jwk: c.Jwk, + } + + json, err := logical.StorageEntryJSON(acmeAccountPrefix+c.Kid, acct) + if err != nil { + return nil, fmt.Errorf("error creating account entry: %w", err) + } + + if err := ac.sc.Storage.Put(ac.sc.Context, json); err != nil { + return nil, fmt.Errorf("error writing account entry: %w", err) + } + + return acct, nil } -func (a *acmeState) DoesAccountExist(keyId string) bool { - account, err := a.LoadAccount(keyId) - return err == nil && len(account) > 0 +func cleanKid(keyID string) string { + pieces := strings.Split(keyID, "/") + return pieces[len(pieces)-1] } -func (a *acmeState) LoadJWK(keyID string) ([]byte, error) { - key, err := a.LoadAccount(keyID) +func (a *acmeState) LoadAccount(ac *acmeContext, keyID string) (*acmeAccount, error) { + kid := cleanKid(keyID) + + entry, err := ac.sc.Storage.Get(ac.sc.Context, acmeAccountPrefix+kid) + if err != nil { + return nil, fmt.Errorf("error loading account: %w", err) + } + if entry == nil { + return nil, fmt.Errorf("account not found: %w", ErrMalformed) + } + + var acct acmeAccount + err = entry.DecodeJSON(&acct) + if err != nil { + return nil, fmt.Errorf("error loading account: %w", err) + } + + return &acct, nil +} + +func (a *acmeState) DoesAccountExist(ac *acmeContext, keyId string) bool { + account, err := a.LoadAccount(ac, keyId) + return err == nil && account != nil +} + +func (a *acmeState) LoadJWK(ac *acmeContext, keyID string) ([]byte, error) { + key, err := a.LoadAccount(ac, keyID) if err != nil { return nil, err } - jwk, present := key["jwk"] - if !present { + if len(key.Jwk) == 0 { return nil, fmt.Errorf("malformed key entry lacks JWK") } - return jwk.([]byte), nil + return key.Jwk, nil } -func (a *acmeState) ParseRequestParams(data *framework.FieldData) (*jwsCtx, map[string]interface{}, error) { +func (a *acmeState) ParseRequestParams(ac *acmeContext, data *framework.FieldData) (*jwsCtx, map[string]interface{}, error) { var c jwsCtx var m map[string]interface{} @@ -143,7 +201,7 @@ func (a *acmeState) ParseRequestParams(data *framework.FieldData) (*jwsCtx, map[ if err != nil { return nil, nil, fmt.Errorf("failed to base64 parse 'protected': %s: %w", err, ErrMalformed) } - if err = c.UnmarshalJSON(a, jwkBytes); err != nil { + if err = c.UnmarshalJSON(a, ac, jwkBytes); err != nil { return nil, nil, fmt.Errorf("failed to json unmarshal 'protected': %w", err) } diff --git a/builtin/logical/pki/path_acme_directory.go b/builtin/logical/pki/path_acme_directory.go index a8329a0258..bd16c51e16 100644 --- a/builtin/logical/pki/path_acme_directory.go +++ b/builtin/logical/pki/path_acme_directory.go @@ -55,7 +55,7 @@ func patternAcmeDirectory(b *backend, pattern string) *framework.Path { } } -type acmeOperation func(acmeCtx acmeContext, r *logical.Request, _ *framework.FieldData) (*logical.Response, error) +type acmeOperation func(acmeCtx *acmeContext, r *logical.Request, _ *framework.FieldData) (*logical.Response, error) type acmeContext struct { baseUrl *url.URL @@ -76,7 +76,7 @@ func (b *backend) acmeWrapper(op acmeOperation) framework.OperationFunc { return nil, err } - acmeCtx := acmeContext{ + acmeCtx := &acmeContext{ baseUrl: baseUrl, sc: sc, } @@ -120,7 +120,7 @@ func acmeErrorWrapper(op framework.OperationFunc) framework.OperationFunc { } } -func (b *backend) acmeDirectoryHandler(acmeCtx acmeContext, r *logical.Request, _ *framework.FieldData) (*logical.Response, error) { +func (b *backend) acmeDirectoryHandler(acmeCtx *acmeContext, r *logical.Request, _ *framework.FieldData) (*logical.Response, error) { rawBody, err := json.Marshal(map[string]interface{}{ "newNonce": acmeCtx.baseUrl.JoinPath("new-nonce").String(), "newAccount": acmeCtx.baseUrl.JoinPath("new-account").String(), diff --git a/builtin/logical/pki/path_acme_new_account.go b/builtin/logical/pki/path_acme_new_account.go index cc922f3670..67cfe094df 100644 --- a/builtin/logical/pki/path_acme_new_account.go +++ b/builtin/logical/pki/path_acme_new_account.go @@ -1,7 +1,9 @@ package pki import ( + "encoding/json" "fmt" + "net/http" "strings" "github.com/hashicorp/vault/sdk/framework" @@ -88,31 +90,102 @@ func patternAcmeNewAccount(b *backend, pattern string) *framework.Path { } } -type acmeParsedOperation func(acmeCtx acmeContext, r *logical.Request, fields *framework.FieldData, userCtx *jwsCtx, data map[string]interface{}) (*logical.Response, error) +type acmeParsedOperation func(acmeCtx *acmeContext, r *logical.Request, fields *framework.FieldData, userCtx *jwsCtx, data map[string]interface{}) (*logical.Response, error) func (b *backend) acmeParsedWrapper(op acmeParsedOperation) framework.OperationFunc { - return b.acmeWrapper(func(acmeCtx acmeContext, r *logical.Request, fields *framework.FieldData) (*logical.Response, error) { - user, data, err := b.acmeState.ParseRequestParams(fields) + return b.acmeWrapper(func(acmeCtx *acmeContext, r *logical.Request, fields *framework.FieldData) (*logical.Response, error) { + user, data, err := b.acmeState.ParseRequestParams(acmeCtx, fields) if err != nil { return nil, err } - return op(acmeCtx, r, fields, user, data) + resp, err := op(acmeCtx, r, fields, user, data) + + // Our response handlers might not add the necessary headers. + if resp != nil { + if resp.Headers == nil { + resp.Headers = map[string][]string{} + } + + if _, ok := resp.Headers["Replay-Nonce"]; !ok { + nonce, _, err := b.acmeState.GetNonce() + if err != nil { + return nil, err + } + + resp.Headers["Replay-Nonce"] = []string{nonce} + } + + if _, ok := resp.Headers["Link"]; !ok { + resp.Headers["Link"] = genAcmeLinkHeader(acmeCtx) + } else { + directory := genAcmeLinkHeader(acmeCtx)[0] + addDirectory := true + for _, item := range resp.Headers["Link"] { + if item == directory { + addDirectory = false + break + } + } + if addDirectory { + resp.Headers["Link"] = append(resp.Headers["Link"], directory) + } + } + + // ACME responses don't understand Vault's default encoding + // format. Rather than expecting everything to handle creating + // ACME-formatted responses, do the marshaling in one place. + if _, ok := resp.Data[logical.HTTPRawBody]; !ok { + ignored_values := map[string]bool{logical.HTTPContentType: true, logical.HTTPStatusCode: true} + fields := map[string]interface{}{} + body := map[string]interface{}{ + logical.HTTPContentType: "application/json", + logical.HTTPStatusCode: http.StatusOK, + } + + for key, value := range resp.Data { + if _, present := ignored_values[key]; !present { + fields[key] = value + } else { + body[key] = value + } + } + + rawBody, err := json.Marshal(fields) + if err != nil { + return nil, fmt.Errorf("Error marshaling JSON body: %w", err) + } + + body[logical.HTTPRawBody] = rawBody + resp.Data = body + } + } + + return resp, err }) } -func (b *backend) acmeNewAccountHandler(acmeCtx acmeContext, r *logical.Request, fields *framework.FieldData, userCtx *jwsCtx, data map[string]interface{}) (*logical.Response, error) { +func (b *backend) acmeNewAccountHandler(acmeCtx *acmeContext, r *logical.Request, fields *framework.FieldData, userCtx *jwsCtx, data map[string]interface{}) (*logical.Response, error) { // Parameters var ok bool var onlyReturnExisting bool - var contact []string + var contacts []string var termsOfServiceAgreed bool rawContact, present := data["contact"] if present { - contact, ok = rawContact.([]string) + listContact, ok := rawContact.([]interface{}) if !ok { - return nil, fmt.Errorf("invalid type for field 'contact': %w", ErrMalformed) + return nil, fmt.Errorf("invalid type (%T) for field 'contact': %w", rawContact, ErrMalformed) + } + + for index, singleContact := range listContact { + contact, ok := singleContact.(string) + if !ok { + return nil, fmt.Errorf("invalid type (%T) for field 'contact' item %d: %w", singleContact, index, ErrMalformed) + } + + contacts = append(contacts, contact) } } @@ -120,7 +193,7 @@ func (b *backend) acmeNewAccountHandler(acmeCtx acmeContext, r *logical.Request, if present { termsOfServiceAgreed, ok = rawTermsOfServiceAgreed.(bool) if !ok { - return nil, fmt.Errorf("invalid type for field 'termsOfServiceAgreed': %w", ErrMalformed) + return nil, fmt.Errorf("invalid type (%T) for field 'termsOfServiceAgreed': %w", rawTermsOfServiceAgreed, ErrMalformed) } } @@ -128,7 +201,7 @@ func (b *backend) acmeNewAccountHandler(acmeCtx acmeContext, r *logical.Request, if present { onlyReturnExisting, ok = rawOnlyReturnExisting.(bool) if !ok { - return nil, fmt.Errorf("invalid type for field 'onlyReturnExisting': %w", ErrMalformed) + return nil, fmt.Errorf("invalid type (%T) for field 'onlyReturnExisting': %w", rawOnlyReturnExisting, ErrMalformed) } } @@ -139,38 +212,39 @@ func (b *backend) acmeNewAccountHandler(acmeCtx acmeContext, r *logical.Request, return b.acmeNewAccountSearchHandler(acmeCtx, r, fields, userCtx, data) } - return b.acmeNewAccountCreateHandler(acmeCtx, r, fields, userCtx, data, contact, termsOfServiceAgreed) + return b.acmeNewAccountCreateHandler(acmeCtx, r, fields, userCtx, data, contacts, termsOfServiceAgreed) } -func formatAccountResponse(location string, status string, contact []string) *logical.Response { +func formatAccountResponse(location string, acct *acmeAccount) *logical.Response { resp := &logical.Response{ Data: map[string]interface{}{ - "status": status, + "status": acct.Status, "orders": location + "/orders", }, + Headers: map[string][]string{ + "Location": {location}, + }, } - if len(contact) > 0 { - resp.Data["contact"] = contact + if len(acct.Contact) > 0 { + resp.Data["contact"] = acct.Contact } - resp.Headers["Location"] = []string{location} - return resp } -func (b *backend) acmeNewAccountSearchHandler(acmeCtx acmeContext, r *logical.Request, fields *framework.FieldData, userCtx *jwsCtx, data map[string]interface{}) (*logical.Response, error) { - if userCtx.Existing || b.acmeState.DoesAccountExist(userCtx.Kid) { +func (b *backend) acmeNewAccountSearchHandler(acmeCtx *acmeContext, r *logical.Request, fields *framework.FieldData, userCtx *jwsCtx, data map[string]interface{}) (*logical.Response, error) { + if userCtx.Existing || b.acmeState.DoesAccountExist(acmeCtx, userCtx.Kid) { // This account exists; return its details. It would be slightly // weird to specify a kid in the request (and not use an explicit // jwk here), but we might as well support it too. - account, err := b.acmeState.LoadAccount(userCtx.Kid) + account, err := b.acmeState.LoadAccount(acmeCtx, userCtx.Kid) if err != nil { return nil, fmt.Errorf("error loading account: %w", err) } location := acmeCtx.baseUrl.String() + "account/" + userCtx.Kid - return formatAccountResponse(location, account["status"].(string), account["contact"].([]string)), nil + return formatAccountResponse(location, account), nil } // Per RFC 8555 Section 7.3.1. Finding an Account URL Given a Key: @@ -181,13 +255,13 @@ func (b *backend) acmeNewAccountSearchHandler(acmeCtx acmeContext, r *logical.Re return nil, fmt.Errorf("An account with this key does not exist: %w", ErrAccountDoesNotExist) } -func (b *backend) acmeNewAccountCreateHandler(acmeCtx acmeContext, r *logical.Request, fields *framework.FieldData, userCtx *jwsCtx, data map[string]interface{}, contact []string, termsOfServiceAgreed bool) (*logical.Response, error) { +func (b *backend) acmeNewAccountCreateHandler(acmeCtx *acmeContext, r *logical.Request, fields *framework.FieldData, userCtx *jwsCtx, data map[string]interface{}, contact []string, termsOfServiceAgreed bool) (*logical.Response, error) { if userCtx.Existing { return nil, fmt.Errorf("cannot submit to newAccount with 'kid': %w", ErrMalformed) } // If the account already exists, return the existing one. - if b.acmeState.DoesAccountExist(userCtx.Kid) { + if b.acmeState.DoesAccountExist(acmeCtx, userCtx.Kid) { return b.acmeNewAccountSearchHandler(acmeCtx, r, fields, userCtx, data) } @@ -196,11 +270,18 @@ func (b *backend) acmeNewAccountCreateHandler(acmeCtx acmeContext, r *logical.Re return nil, fmt.Errorf("terms of service not agreed to: %w", ErrUserActionRequired) } - account, err := b.acmeState.CreateAccount(userCtx, contact, termsOfServiceAgreed) + account, err := b.acmeState.CreateAccount(acmeCtx, userCtx, contact, termsOfServiceAgreed) if err != nil { return nil, fmt.Errorf("failed to create account: %w", err) } location := acmeCtx.baseUrl.String() + "account/" + userCtx.Kid - return formatAccountResponse(location, account["status"].(string), account["contact"].([]string)), nil + resp := formatAccountResponse(location, account) + + // Per RFC 8555 Section 7.3. Account Management: + // + // > The server returns this account object in a 201 (Created) response, + // > with the account URL in a Location header field. + resp.Data[logical.HTTPStatusCode] = http.StatusCreated + return resp, nil } diff --git a/builtin/logical/pki/path_acme_nonce.go b/builtin/logical/pki/path_acme_nonce.go index 4a6d227ea7..bfe9ad80ea 100644 --- a/builtin/logical/pki/path_acme_nonce.go +++ b/builtin/logical/pki/path_acme_nonce.go @@ -51,7 +51,7 @@ func patternAcmeNonce(b *backend, pattern string) *framework.Path { } } -func (b *backend) acmeNonceHandler(ctx acmeContext, r *logical.Request, _ *framework.FieldData) (*logical.Response, error) { +func (b *backend) acmeNonceHandler(ctx *acmeContext, r *logical.Request, _ *framework.FieldData) (*logical.Response, error) { nonce, _, err := b.acmeState.GetNonce() if err != nil { return nil, err @@ -78,7 +78,7 @@ func (b *backend) acmeNonceHandler(ctx acmeContext, r *logical.Request, _ *frame }, nil } -func genAcmeLinkHeader(ctx acmeContext) []string { +func genAcmeLinkHeader(ctx *acmeContext) []string { path := fmt.Sprintf("<%s>;rel=\"index\"", ctx.baseUrl.JoinPath("directory").String()) return []string{path} }