diff --git a/builtin/logical/pki/cert_util.go b/builtin/logical/pki/cert_util.go index 1961ecbba8..a228ad1156 100644 --- a/builtin/logical/pki/cert_util.go +++ b/builtin/logical/pki/cert_util.go @@ -16,6 +16,7 @@ import ( "time" "github.com/hashicorp/vault/helper/certutil" + "github.com/hashicorp/vault/helper/errutil" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" ) @@ -117,31 +118,31 @@ func validateKeyTypeLength(keyType string, keyBits int) *logical.Response { func fetchCAInfo(req *logical.Request) (*caInfoBundle, error) { bundleEntry, err := req.Storage.Get("config/ca_bundle") if err != nil { - return nil, certutil.InternalError{Err: fmt.Sprintf("unable to fetch local CA certificate/key: %v", err)} + return nil, errutil.InternalError{Err: fmt.Sprintf("unable to fetch local CA certificate/key: %v", err)} } if bundleEntry == nil { - return nil, certutil.UserError{Err: "backend must be configured with a CA certificate/key"} + return nil, errutil.UserError{Err: "backend must be configured with a CA certificate/key"} } var bundle certutil.CertBundle if err := bundleEntry.DecodeJSON(&bundle); err != nil { - return nil, certutil.InternalError{Err: fmt.Sprintf("unable to decode local CA certificate/key: %v", err)} + return nil, errutil.InternalError{Err: fmt.Sprintf("unable to decode local CA certificate/key: %v", err)} } parsedBundle, err := bundle.ToParsedCertBundle() if err != nil { - return nil, certutil.InternalError{Err: err.Error()} + return nil, errutil.InternalError{Err: err.Error()} } if parsedBundle.Certificate == nil { - return nil, certutil.InternalError{Err: "stored CA information not able to be parsed"} + return nil, errutil.InternalError{Err: "stored CA information not able to be parsed"} } caInfo := &caInfoBundle{*parsedBundle, nil} entries, err := getURLs(req) if err != nil { - return nil, certutil.InternalError{Err: fmt.Sprintf("unable to fetch URL information: %v", err)} + return nil, errutil.InternalError{Err: fmt.Sprintf("unable to fetch URL information: %v", err)} } if entries == nil { entries = &urlEntries{ @@ -175,14 +176,14 @@ func fetchCertBySerial(req *logical.Request, prefix, serial string) (*logical.St certEntry, err := req.Storage.Get(path) if err != nil { - return nil, certutil.InternalError{Err: fmt.Sprintf("error fetching certificate %s: %s", serial, err)} + return nil, errutil.InternalError{Err: fmt.Sprintf("error fetching certificate %s: %s", serial, err)} } if certEntry == nil { return nil, nil } if certEntry.Value == nil || len(certEntry.Value) == 0 { - return nil, certutil.InternalError{Err: fmt.Sprintf("returned certificate bytes for serial %s were empty", serial)} + return nil, errutil.InternalError{Err: fmt.Sprintf("returned certificate bytes for serial %s were empty", serial)} } return certEntry, nil @@ -356,7 +357,7 @@ func generateCert(b *backend, data *framework.FieldData) (*certutil.ParsedCertBundle, error) { if role.KeyType == "rsa" && role.KeyBits < 2048 { - return nil, certutil.UserError{Err: "RSA keys < 2048 bits are unsafe and not supported"} + return nil, errutil.UserError{Err: "RSA keys < 2048 bits are unsafe and not supported"} } creationBundle, err := generateCreationBundle(b, role, signingBundle, nil, req, data) @@ -371,7 +372,7 @@ func generateCert(b *backend, // Generating a self-signed root certificate entries, err := getURLs(req) if err != nil { - return nil, certutil.InternalError{Err: fmt.Sprintf("unable to fetch URL information: %v", err)} + return nil, errutil.InternalError{Err: fmt.Sprintf("unable to fetch URL information: %v", err)} } if entries == nil { entries = &urlEntries{ @@ -429,40 +430,40 @@ func signCert(b *backend, csrString := data.Get("csr").(string) if csrString == "" { - return nil, certutil.UserError{Err: fmt.Sprintf("\"csr\" is empty")} + return nil, errutil.UserError{Err: fmt.Sprintf("\"csr\" is empty")} } pemBytes := []byte(csrString) pemBlock, pemBytes := pem.Decode(pemBytes) if pemBlock == nil { - return nil, certutil.UserError{Err: "csr contains no data"} + return nil, errutil.UserError{Err: "csr contains no data"} } csr, err := x509.ParseCertificateRequest(pemBlock.Bytes) if err != nil { - return nil, certutil.UserError{Err: "certificate request could not be parsed"} + return nil, errutil.UserError{Err: "certificate request could not be parsed"} } switch role.KeyType { case "rsa": // Verify that the key matches the role type if csr.PublicKeyAlgorithm != x509.RSA { - return nil, certutil.UserError{Err: fmt.Sprintf( + return nil, errutil.UserError{Err: fmt.Sprintf( "role requires keys of type %s", role.KeyType)} } pubKey, ok := csr.PublicKey.(*rsa.PublicKey) if !ok { - return nil, certutil.UserError{Err: "could not parse CSR's public key"} + return nil, errutil.UserError{Err: "could not parse CSR's public key"} } // Verify that the key is at least 2048 bits if pubKey.N.BitLen() < 2048 { - return nil, certutil.UserError{Err: "RSA keys < 2048 bits are unsafe and not supported"} + return nil, errutil.UserError{Err: "RSA keys < 2048 bits are unsafe and not supported"} } // Verify that the bit size is at least the size specified in the role if pubKey.N.BitLen() < role.KeyBits { - return nil, certutil.UserError{Err: fmt.Sprintf( + return nil, errutil.UserError{Err: fmt.Sprintf( "role requires a minimum of a %d-bit key, but CSR's key is %d bits", role.KeyBits, pubKey.N.BitLen())} @@ -471,18 +472,18 @@ func signCert(b *backend, case "ec": // Verify that the key matches the role type if csr.PublicKeyAlgorithm != x509.ECDSA { - return nil, certutil.UserError{Err: fmt.Sprintf( + return nil, errutil.UserError{Err: fmt.Sprintf( "role requires keys of type %s", role.KeyType)} } pubKey, ok := csr.PublicKey.(*ecdsa.PublicKey) if !ok { - return nil, certutil.UserError{Err: "could not parse CSR's public key"} + return nil, errutil.UserError{Err: "could not parse CSR's public key"} } // Verify that the bit size is at least the size specified in the role if pubKey.Params().BitSize < role.KeyBits { - return nil, certutil.UserError{Err: fmt.Sprintf( + return nil, errutil.UserError{Err: fmt.Sprintf( "role requires a minimum of a %d-bit key, but CSR's key is %d bits", role.KeyBits, pubKey.Params().BitSize)} @@ -498,10 +499,10 @@ func signCert(b *backend, // Run RSA < 2048 bit checks pubKey, ok := csr.PublicKey.(*rsa.PublicKey) if !ok { - return nil, certutil.UserError{Err: "could not parse CSR's public key"} + return nil, errutil.UserError{Err: "could not parse CSR's public key"} } if pubKey.N.BitLen() < 2048 { - return nil, certutil.UserError{Err: "RSA keys < 2048 bits are unsafe and not supported"} + return nil, errutil.UserError{Err: "RSA keys < 2048 bits are unsafe and not supported"} } } @@ -545,7 +546,7 @@ func generateCreationBundle(b *backend, if cn == "" { cn = data.Get("common_name").(string) if cn == "" { - return nil, certutil.UserError{Err: `the common_name field is required, or must be provided in a CSR with "use_csr_common_name" set to true`} + return nil, errutil.UserError{Err: `the common_name field is required, or must be provided in a CSR with "use_csr_common_name" set to true`} } } } @@ -583,19 +584,19 @@ func generateCreationBundle(b *backend, // Check for bad email and/or DNS names badName, err := validateNames(req, dnsNames, role) if len(badName) != 0 { - return nil, certutil.UserError{Err: fmt.Sprintf( + return nil, errutil.UserError{Err: fmt.Sprintf( "name %s not allowed by this role", badName)} } else if err != nil { - return nil, certutil.InternalError{Err: fmt.Sprintf( + return nil, errutil.InternalError{Err: fmt.Sprintf( "error validating name %s: %s", badName, err)} } badName, err = validateNames(req, emailAddresses, role) if len(badName) != 0 { - return nil, certutil.UserError{Err: fmt.Sprintf( + return nil, errutil.UserError{Err: fmt.Sprintf( "email %s not allowed by this role", badName)} } else if err != nil { - return nil, certutil.InternalError{Err: fmt.Sprintf( + return nil, errutil.InternalError{Err: fmt.Sprintf( "error validating name %s: %s", badName, err)} } } @@ -609,13 +610,13 @@ func generateCreationBundle(b *backend, ipAlt := ipAltInt.(string) if len(ipAlt) != 0 { if !role.AllowIPSANs { - return nil, certutil.UserError{Err: fmt.Sprintf( + return nil, errutil.UserError{Err: fmt.Sprintf( "IP Subject Alternative Names are not allowed in this role, but was provided %s", ipAlt)} } for _, v := range strings.Split(ipAlt, ",") { parsedIP := net.ParseIP(v) if parsedIP == nil { - return nil, certutil.UserError{Err: fmt.Sprintf( + return nil, errutil.UserError{Err: fmt.Sprintf( "the value '%s' is not a valid IP address", v)} } ipAddresses = append(ipAddresses, parsedIP) @@ -642,7 +643,7 @@ func generateCreationBundle(b *backend, } else { ttl, err = time.ParseDuration(ttlField) if err != nil { - return nil, certutil.UserError{Err: fmt.Sprintf( + return nil, errutil.UserError{Err: fmt.Sprintf( "invalid requested ttl: %s", err)} } } @@ -652,7 +653,7 @@ func generateCreationBundle(b *backend, } else { maxTTL, err = time.ParseDuration(role.MaxTTL) if err != nil { - return nil, certutil.UserError{Err: fmt.Sprintf( + return nil, errutil.UserError{Err: fmt.Sprintf( "invalid ttl: %s", err)} } } @@ -663,7 +664,7 @@ func generateCreationBundle(b *backend, if len(ttlField) == 0 { ttl = maxTTL } else { - return nil, certutil.UserError{Err: fmt.Sprintf( + return nil, errutil.UserError{Err: fmt.Sprintf( "ttl is larger than maximum allowed (%d)", maxTTL/time.Second)} } } @@ -672,7 +673,7 @@ func generateCreationBundle(b *backend, // valid past the lifetime of the CA certificate if signingBundle != nil && time.Now().Add(ttl).After(signingBundle.Certificate.NotAfter) { - return nil, certutil.UserError{Err: fmt.Sprintf( + return nil, errutil.UserError{Err: fmt.Sprintf( "cannot satisfy request, as TTL is beyond the expiration of the CA certificate")} } } @@ -782,7 +783,7 @@ func createCertificate(creationInfo *creationBundle) (*certutil.ParsedCertBundle subjKeyID, err := certutil.GetSubjKeyID(result.PrivateKey) if err != nil { - return nil, certutil.InternalError{Err: fmt.Sprintf("error getting subject key ID: %s", err)} + return nil, errutil.InternalError{Err: fmt.Sprintf("error getting subject key ID: %s", err)} } subject := pkix.Name{ @@ -845,13 +846,13 @@ func createCertificate(creationInfo *creationBundle) (*certutil.ParsedCertBundle } if err != nil { - return nil, certutil.InternalError{Err: fmt.Sprintf("unable to create certificate: %s", err)} + return nil, errutil.InternalError{Err: fmt.Sprintf("unable to create certificate: %s", err)} } result.CertificateBytes = certBytes result.Certificate, err = x509.ParseCertificate(certBytes) if err != nil { - return nil, certutil.InternalError{Err: fmt.Sprintf("unable to parse created certificate: %s", err)} + return nil, errutil.InternalError{Err: fmt.Sprintf("unable to parse created certificate: %s", err)} } if creationInfo.SigningBundle != nil { @@ -898,13 +899,13 @@ func createCSR(creationInfo *creationBundle) (*certutil.ParsedCSRBundle, error) csr, err := x509.CreateCertificateRequest(rand.Reader, csrTemplate, result.PrivateKey) if err != nil { - return nil, certutil.InternalError{Err: fmt.Sprintf("unable to create certificate: %s", err)} + return nil, errutil.InternalError{Err: fmt.Sprintf("unable to create certificate: %s", err)} } result.CSRBytes = csr result.CSR, err = x509.ParseCertificateRequest(csr) if err != nil { - return nil, certutil.InternalError{Err: fmt.Sprintf("unable to parse created certificate: %s", err)} + return nil, errutil.InternalError{Err: fmt.Sprintf("unable to parse created certificate: %s", err)} } return result, nil @@ -916,16 +917,16 @@ func signCertificate(creationInfo *creationBundle, csr *x509.CertificateRequest) (*certutil.ParsedCertBundle, error) { switch { case creationInfo == nil: - return nil, certutil.UserError{Err: "nil creation info given to signCertificate"} + return nil, errutil.UserError{Err: "nil creation info given to signCertificate"} case creationInfo.SigningBundle == nil: - return nil, certutil.UserError{Err: "nil signing bundle given to signCertificate"} + return nil, errutil.UserError{Err: "nil signing bundle given to signCertificate"} case csr == nil: - return nil, certutil.UserError{Err: "nil csr given to signCertificate"} + return nil, errutil.UserError{Err: "nil csr given to signCertificate"} } err := csr.CheckSignature() if err != nil { - return nil, certutil.UserError{Err: "request signature invalid"} + return nil, errutil.UserError{Err: "request signature invalid"} } result := &certutil.ParsedCertBundle{} @@ -937,7 +938,7 @@ func signCertificate(creationInfo *creationBundle, marshaledKey, err := x509.MarshalPKIXPublicKey(csr.PublicKey) if err != nil { - return nil, certutil.InternalError{Err: fmt.Sprintf("error marshalling public key: %s", err)} + return nil, errutil.InternalError{Err: fmt.Sprintf("error marshalling public key: %s", err)} } subjKeyID := sha1.Sum(marshaledKey) @@ -989,7 +990,7 @@ func signCertificate(creationInfo *creationBundle, if creationInfo.SigningBundle.Certificate.MaxPathLen == 0 && creationInfo.SigningBundle.Certificate.MaxPathLenZero { - return nil, certutil.UserError{Err: "signing certificate has a max path length of zero, and cannot issue further CA certificates"} + return nil, errutil.UserError{Err: "signing certificate has a max path length of zero, and cannot issue further CA certificates"} } certTemplate.MaxPathLen = creationInfo.MaxPathLength @@ -1001,13 +1002,13 @@ func signCertificate(creationInfo *creationBundle, certBytes, err = x509.CreateCertificate(rand.Reader, certTemplate, caCert, csr.PublicKey, creationInfo.SigningBundle.PrivateKey) if err != nil { - return nil, certutil.InternalError{Err: fmt.Sprintf("unable to create certificate: %s", err)} + return nil, errutil.InternalError{Err: fmt.Sprintf("unable to create certificate: %s", err)} } result.CertificateBytes = certBytes result.Certificate, err = x509.ParseCertificate(certBytes) if err != nil { - return nil, certutil.InternalError{Err: fmt.Sprintf("unable to parse created certificate: %s", err)} + return nil, errutil.InternalError{Err: fmt.Sprintf("unable to parse created certificate: %s", err)} } result.IssuingCABytes = creationInfo.SigningBundle.CertificateBytes diff --git a/builtin/logical/pki/crl_util.go b/builtin/logical/pki/crl_util.go index 2dd32b5f3a..13d0f0ff8c 100644 --- a/builtin/logical/pki/crl_util.go +++ b/builtin/logical/pki/crl_util.go @@ -7,7 +7,7 @@ import ( "fmt" "time" - "github.com/hashicorp/vault/helper/certutil" + "github.com/hashicorp/vault/helper/errutil" "github.com/hashicorp/vault/logical" ) @@ -33,9 +33,9 @@ func revokeCert(b *backend, req *logical.Request, serial string, fromLease bool) certEntry, err := fetchCertBySerial(req, "revoked/", serial) if err != nil { switch err.(type) { - case certutil.UserError: + case errutil.UserError: return logical.ErrorResponse(err.Error()), nil - case certutil.InternalError: + case errutil.InternalError: return nil, err } } @@ -58,9 +58,9 @@ func revokeCert(b *backend, req *logical.Request, serial string, fromLease bool) certEntry, err = fetchCertBySerial(req, "certs/", serial) if err != nil { switch err.(type) { - case certutil.UserError: + case errutil.UserError: return logical.ErrorResponse(err.Error()), nil - case certutil.InternalError: + case errutil.InternalError: return nil, err } } @@ -103,9 +103,9 @@ func revokeCert(b *backend, req *logical.Request, serial string, fromLease bool) crlErr := buildCRL(b, req) switch crlErr.(type) { - case certutil.UserError: + case errutil.UserError: return logical.ErrorResponse(fmt.Sprintf("Error during CRL building: %s", crlErr)), nil - case certutil.InternalError: + case errutil.InternalError: return nil, fmt.Errorf("Error encountered during CRL building: %s", crlErr) } @@ -121,7 +121,7 @@ func revokeCert(b *backend, req *logical.Request, serial string, fromLease bool) func buildCRL(b *backend, req *logical.Request) error { revokedSerials, err := req.Storage.List("revoked/") if err != nil { - return certutil.InternalError{Err: fmt.Sprintf("Error fetching list of revoked certs: %s", err)} + return errutil.InternalError{Err: fmt.Sprintf("Error fetching list of revoked certs: %s", err)} } revokedCerts := []pkix.RevokedCertificate{} @@ -129,26 +129,26 @@ func buildCRL(b *backend, req *logical.Request) error { for _, serial := range revokedSerials { revokedEntry, err := req.Storage.Get("revoked/" + serial) if err != nil { - return certutil.InternalError{Err: fmt.Sprintf("Unable to fetch revoked cert with serial %s: %s", serial, err)} + return errutil.InternalError{Err: fmt.Sprintf("Unable to fetch revoked cert with serial %s: %s", serial, err)} } if revokedEntry == nil { - return certutil.InternalError{Err: fmt.Sprintf("Revoked certificate entry for serial %s is nil", serial)} + return errutil.InternalError{Err: fmt.Sprintf("Revoked certificate entry for serial %s is nil", serial)} } if revokedEntry.Value == nil || len(revokedEntry.Value) == 0 { // TODO: In this case, remove it and continue? How likely is this to // happen? Alternately, could skip it entirely, or could implement a // delete function so that there is a way to remove these - return certutil.InternalError{Err: fmt.Sprintf("Found revoked serial but actual certificate is empty")} + return errutil.InternalError{Err: fmt.Sprintf("Found revoked serial but actual certificate is empty")} } err = revokedEntry.DecodeJSON(&revInfo) if err != nil { - return certutil.InternalError{Err: fmt.Sprintf("Error decoding revocation entry for serial %s: %s", serial, err)} + return errutil.InternalError{Err: fmt.Sprintf("Error decoding revocation entry for serial %s: %s", serial, err)} } revokedCert, err := x509.ParseCertificate(revInfo.CertificateBytes) if err != nil { - return certutil.InternalError{Err: fmt.Sprintf("Unable to parse stored revoked certificate with serial %s: %s", serial, err)} + return errutil.InternalError{Err: fmt.Sprintf("Unable to parse stored revoked certificate with serial %s: %s", serial, err)} } revokedCerts = append(revokedCerts, pkix.RevokedCertificate{ @@ -159,28 +159,28 @@ func buildCRL(b *backend, req *logical.Request) error { signingBundle, caErr := fetchCAInfo(req) switch caErr.(type) { - case certutil.UserError: - return certutil.UserError{Err: fmt.Sprintf("Could not fetch the CA certificate: %s", caErr)} - case certutil.InternalError: - return certutil.InternalError{Err: fmt.Sprintf("Error fetching CA certificate: %s", caErr)} + case errutil.UserError: + return errutil.UserError{Err: fmt.Sprintf("Could not fetch the CA certificate: %s", caErr)} + case errutil.InternalError: + return errutil.InternalError{Err: fmt.Sprintf("Error fetching CA certificate: %s", caErr)} } crlLifetime := b.crlLifetime crlInfo, err := b.CRL(req.Storage) if err != nil { - return certutil.InternalError{Err: fmt.Sprintf("Error fetching CRL config information: %s", err)} + return errutil.InternalError{Err: fmt.Sprintf("Error fetching CRL config information: %s", err)} } if crlInfo != nil { crlDur, err := time.ParseDuration(crlInfo.Expiry) if err != nil { - return certutil.InternalError{Err: fmt.Sprintf("Error parsing CRL duration of %s", crlInfo.Expiry)} + return errutil.InternalError{Err: fmt.Sprintf("Error parsing CRL duration of %s", crlInfo.Expiry)} } crlLifetime = crlDur } crlBytes, err := signingBundle.Certificate.CreateCRL(rand.Reader, signingBundle.PrivateKey, revokedCerts, time.Now(), time.Now().Add(crlLifetime)) if err != nil { - return certutil.InternalError{Err: fmt.Sprintf("Error creating new CRL: %s", err)} + return errutil.InternalError{Err: fmt.Sprintf("Error creating new CRL: %s", err)} } err = req.Storage.Put(&logical.StorageEntry{ @@ -188,7 +188,7 @@ func buildCRL(b *backend, req *logical.Request) error { Value: crlBytes, }) if err != nil { - return certutil.InternalError{Err: fmt.Sprintf("Error storing CRL: %s", err)} + return errutil.InternalError{Err: fmt.Sprintf("Error storing CRL: %s", err)} } return nil diff --git a/builtin/logical/pki/path_config_ca.go b/builtin/logical/pki/path_config_ca.go index 8e4799d708..18324d7cb3 100644 --- a/builtin/logical/pki/path_config_ca.go +++ b/builtin/logical/pki/path_config_ca.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/hashicorp/vault/helper/certutil" + "github.com/hashicorp/vault/helper/errutil" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" ) @@ -37,7 +38,7 @@ func (b *backend) pathCAWrite( parsedBundle, err := certutil.ParsePEMBundle(pemBundle) if err != nil { switch err.(type) { - case certutil.InternalError: + case errutil.InternalError: return nil, err default: return logical.ErrorResponse(err.Error()), nil diff --git a/builtin/logical/pki/path_fetch.go b/builtin/logical/pki/path_fetch.go index b37e9ff0fe..acc6d1c902 100644 --- a/builtin/logical/pki/path_fetch.go +++ b/builtin/logical/pki/path_fetch.go @@ -4,7 +4,7 @@ import ( "encoding/pem" "fmt" - "github.com/hashicorp/vault/helper/certutil" + "github.com/hashicorp/vault/helper/errutil" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" ) @@ -139,10 +139,10 @@ func (b *backend) pathFetchRead(req *logical.Request, data *framework.FieldData) certEntry, funcErr = fetchCertBySerial(req, req.Path, serial) if funcErr != nil { switch funcErr.(type) { - case certutil.UserError: + case errutil.UserError: response = logical.ErrorResponse(funcErr.Error()) goto reply - case certutil.InternalError: + case errutil.InternalError: retErr = funcErr goto reply } @@ -165,10 +165,10 @@ func (b *backend) pathFetchRead(req *logical.Request, data *framework.FieldData) revokedEntry, funcErr = fetchCertBySerial(req, "revoked/", serial) if funcErr != nil { switch funcErr.(type) { - case certutil.UserError: + case errutil.UserError: response = logical.ErrorResponse(funcErr.Error()) goto reply - case certutil.InternalError: + case errutil.InternalError: retErr = funcErr goto reply } diff --git a/builtin/logical/pki/path_intermediate.go b/builtin/logical/pki/path_intermediate.go index 667caf5cac..ede2b03df4 100644 --- a/builtin/logical/pki/path_intermediate.go +++ b/builtin/logical/pki/path_intermediate.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/hashicorp/vault/helper/certutil" + "github.com/hashicorp/vault/helper/errutil" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" ) @@ -65,9 +66,9 @@ func (b *backend) pathGenerateIntermediate( parsedBundle, err := generateIntermediateCSR(b, role, nil, req, data) if err != nil { switch err.(type) { - case certutil.UserError: + case errutil.UserError: return logical.ErrorResponse(err.Error()), nil - case certutil.InternalError: + case errutil.InternalError: return nil, err } } @@ -132,7 +133,7 @@ func (b *backend) pathSetSignedIntermediate( inputBundle, err := certutil.ParsePEMBundle(cert) if err != nil { switch err.(type) { - case certutil.InternalError: + case errutil.InternalError: return nil, err default: return logical.ErrorResponse(err.Error()), nil diff --git a/builtin/logical/pki/path_issue_sign.go b/builtin/logical/pki/path_issue_sign.go index 9d7a81984f..5cd6fdcdac 100644 --- a/builtin/logical/pki/path_issue_sign.go +++ b/builtin/logical/pki/path_issue_sign.go @@ -6,6 +6,7 @@ import ( "time" "github.com/hashicorp/vault/helper/certutil" + "github.com/hashicorp/vault/helper/errutil" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" ) @@ -140,11 +141,11 @@ func (b *backend) pathIssueSignCert( var caErr error signingBundle, caErr := fetchCAInfo(req) switch caErr.(type) { - case certutil.UserError: - return nil, certutil.UserError{Err: fmt.Sprintf( + case errutil.UserError: + return nil, errutil.UserError{Err: fmt.Sprintf( "Could not fetch the CA certificate (was one set?): %s", caErr)} - case certutil.InternalError: - return nil, certutil.InternalError{Err: fmt.Sprintf( + case errutil.InternalError: + return nil, errutil.InternalError{Err: fmt.Sprintf( "Error fetching CA certificate: %s", caErr)} } @@ -157,9 +158,9 @@ func (b *backend) pathIssueSignCert( } if err != nil { switch err.(type) { - case certutil.UserError: + case errutil.UserError: return logical.ErrorResponse(err.Error()), nil - case certutil.InternalError: + case errutil.InternalError: return nil, err } } diff --git a/builtin/logical/pki/path_revoke.go b/builtin/logical/pki/path_revoke.go index ce0ae1afe9..291199564a 100644 --- a/builtin/logical/pki/path_revoke.go +++ b/builtin/logical/pki/path_revoke.go @@ -4,7 +4,7 @@ import ( "fmt" "strings" - "github.com/hashicorp/vault/helper/certutil" + "github.com/hashicorp/vault/helper/errutil" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" ) @@ -64,9 +64,9 @@ func (b *backend) pathRotateCRLRead(req *logical.Request, data *framework.FieldD crlErr := buildCRL(b, req) switch crlErr.(type) { - case certutil.UserError: + case errutil.UserError: return logical.ErrorResponse(fmt.Sprintf("Error during CRL building: %s", crlErr)), nil - case certutil.InternalError: + case errutil.InternalError: return nil, fmt.Errorf("Error encountered during CRL building: %s", crlErr) default: return &logical.Response{ diff --git a/builtin/logical/pki/path_root.go b/builtin/logical/pki/path_root.go index b127533dc8..93ed3dd86e 100644 --- a/builtin/logical/pki/path_root.go +++ b/builtin/logical/pki/path_root.go @@ -4,7 +4,7 @@ import ( "encoding/base64" "fmt" - "github.com/hashicorp/vault/helper/certutil" + "github.com/hashicorp/vault/helper/errutil" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" ) @@ -84,9 +84,9 @@ func (b *backend) pathCAGenerateRoot( parsedBundle, err := generateCert(b, role, nil, true, req, data) if err != nil { switch err.(type) { - case certutil.UserError: + case errutil.UserError: return logical.ErrorResponse(err.Error()), nil - case certutil.InternalError: + case errutil.InternalError: return nil, err } } @@ -201,11 +201,11 @@ func (b *backend) pathCASignIntermediate( var caErr error signingBundle, caErr := fetchCAInfo(req) switch caErr.(type) { - case certutil.UserError: - return nil, certutil.UserError{Err: fmt.Sprintf( + case errutil.UserError: + return nil, errutil.UserError{Err: fmt.Sprintf( "could not fetch the CA certificate (was one set?): %s", caErr)} - case certutil.InternalError: - return nil, certutil.InternalError{Err: fmt.Sprintf( + case errutil.InternalError: + return nil, errutil.InternalError{Err: fmt.Sprintf( "error fetching CA certificate: %s", caErr)} } @@ -220,9 +220,9 @@ func (b *backend) pathCASignIntermediate( parsedBundle, err := signCert(b, role, signingBundle, true, useCSRValues, req, data) if err != nil { switch err.(type) { - case certutil.UserError: + case errutil.UserError: return logical.ErrorResponse(err.Error()), nil - case certutil.InternalError: + case errutil.InternalError: return nil, err } } diff --git a/builtin/logical/transit/path_datakey.go b/builtin/logical/transit/path_datakey.go index 817529e9dd..f96ae4adad 100644 --- a/builtin/logical/transit/path_datakey.go +++ b/builtin/logical/transit/path_datakey.go @@ -5,7 +5,7 @@ import ( "encoding/base64" "fmt" - "github.com/hashicorp/vault/helper/certutil" + "github.com/hashicorp/vault/helper/errutil" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" ) @@ -103,9 +103,9 @@ func (b *backend) pathDatakeyWrite( ciphertext, err := p.Encrypt(context, base64.StdEncoding.EncodeToString(newKey)) if err != nil { switch err.(type) { - case certutil.UserError: + case errutil.UserError: return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest - case certutil.InternalError: + case errutil.InternalError: return nil, err default: return nil, err diff --git a/builtin/logical/transit/path_decrypt.go b/builtin/logical/transit/path_decrypt.go index 254b2b5093..3e0d0e1314 100644 --- a/builtin/logical/transit/path_decrypt.go +++ b/builtin/logical/transit/path_decrypt.go @@ -4,7 +4,7 @@ import ( "encoding/base64" "fmt" - "github.com/hashicorp/vault/helper/certutil" + "github.com/hashicorp/vault/helper/errutil" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" ) @@ -72,9 +72,9 @@ func (b *backend) pathDecryptWrite( plaintext, err := p.Decrypt(context, ciphertext) if err != nil { switch err.(type) { - case certutil.UserError: + case errutil.UserError: return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest - case certutil.InternalError: + case errutil.InternalError: return nil, err default: return nil, err diff --git a/builtin/logical/transit/path_encrypt.go b/builtin/logical/transit/path_encrypt.go index 5ae3032d19..153757d833 100644 --- a/builtin/logical/transit/path_encrypt.go +++ b/builtin/logical/transit/path_encrypt.go @@ -5,7 +5,7 @@ import ( "fmt" "sync" - "github.com/hashicorp/vault/helper/certutil" + "github.com/hashicorp/vault/helper/errutil" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" ) @@ -96,9 +96,9 @@ func (b *backend) pathEncryptWrite( ciphertext, err := p.Encrypt(context, value) if err != nil { switch err.(type) { - case certutil.UserError: + case errutil.UserError: return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest - case certutil.InternalError: + case errutil.InternalError: return nil, err default: return nil, err diff --git a/builtin/logical/transit/path_rewrap.go b/builtin/logical/transit/path_rewrap.go index a5854feeea..2d40152506 100644 --- a/builtin/logical/transit/path_rewrap.go +++ b/builtin/logical/transit/path_rewrap.go @@ -4,7 +4,7 @@ import ( "encoding/base64" "fmt" - "github.com/hashicorp/vault/helper/certutil" + "github.com/hashicorp/vault/helper/errutil" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" ) @@ -74,9 +74,9 @@ func (b *backend) pathRewrapWrite( plaintext, err := p.Decrypt(context, value) if err != nil { switch err.(type) { - case certutil.UserError: + case errutil.UserError: return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest - case certutil.InternalError: + case errutil.InternalError: return nil, err default: return nil, err @@ -90,9 +90,9 @@ func (b *backend) pathRewrapWrite( ciphertext, err := p.Encrypt(context, plaintext) if err != nil { switch err.(type) { - case certutil.UserError: + case errutil.UserError: return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest - case certutil.InternalError: + case errutil.InternalError: return nil, err default: return nil, err diff --git a/builtin/logical/transit/policy.go b/builtin/logical/transit/policy.go index e8748712bc..611aa5c3af 100644 --- a/builtin/logical/transit/policy.go +++ b/builtin/logical/transit/policy.go @@ -11,7 +11,7 @@ import ( "strings" "time" - "github.com/hashicorp/vault/helper/certutil" + "github.com/hashicorp/vault/helper/errutil" "github.com/hashicorp/vault/helper/jsonutil" "github.com/hashicorp/vault/helper/kdf" "github.com/hashicorp/vault/logical" @@ -304,15 +304,15 @@ func (p *Policy) upgrade(storage logical.Storage) error { // mode is used with the context to derive the proper key. func (p *Policy) DeriveKey(context []byte, ver int) ([]byte, error) { if p.Keys == nil || p.LatestVersion == 0 { - return nil, certutil.InternalError{Err: "unable to access the key; no key versions found"} + return nil, errutil.InternalError{Err: "unable to access the key; no key versions found"} } if p.LatestVersion == 0 { - return nil, certutil.InternalError{Err: "unable to access the key; no key versions found"} + return nil, errutil.InternalError{Err: "unable to access the key; no key versions found"} } if ver <= 0 || ver > p.LatestVersion { - return nil, certutil.UserError{Err: "invalid key version"} + return nil, errutil.UserError{Err: "invalid key version"} } // Fast-path non-derived keys @@ -322,7 +322,7 @@ func (p *Policy) DeriveKey(context []byte, ver int) ([]byte, error) { // Ensure a context is provided if len(context) == 0 { - return nil, certutil.UserError{Err: "missing 'context' for key deriviation. The key was created using a derived key, which means additional, per-request information must be included in order to encrypt or decrypt information"} + return nil, errutil.UserError{Err: "missing 'context' for key deriviation. The key was created using a derived key, which means additional, per-request information must be included in order to encrypt or decrypt information"} } switch p.KDFMode { @@ -331,7 +331,7 @@ func (p *Policy) DeriveKey(context []byte, ver int) ([]byte, error) { prfLen := kdf.HMACSHA256PRFLen return kdf.CounterMode(prf, prfLen, p.Keys[ver].Key, context, 256) default: - return nil, certutil.InternalError{Err: "unsupported key derivation mode"} + return nil, errutil.InternalError{Err: "unsupported key derivation mode"} } } @@ -339,7 +339,7 @@ func (p *Policy) Encrypt(context []byte, value string) (string, error) { // Decode the plaintext value plaintext, err := base64.StdEncoding.DecodeString(value) if err != nil { - return "", certutil.UserError{Err: "failed to decode plaintext as base64"} + return "", errutil.UserError{Err: "failed to decode plaintext as base64"} } // Derive the key that should be used @@ -352,23 +352,23 @@ func (p *Policy) Encrypt(context []byte, value string) (string, error) { switch p.CipherMode { case "aes-gcm": default: - return "", certutil.InternalError{Err: "unsupported cipher mode"} + return "", errutil.InternalError{Err: "unsupported cipher mode"} } // Setup the cipher aesCipher, err := aes.NewCipher(key) if err != nil { - return "", certutil.InternalError{Err: err.Error()} + return "", errutil.InternalError{Err: err.Error()} } // Setup the GCM AEAD gcm, err := cipher.NewGCM(aesCipher) if err != nil { - return "", certutil.InternalError{Err: err.Error()} + return "", errutil.InternalError{Err: err.Error()} } if p.ConvergentEncryption && len(context) != gcm.NonceSize() { - return "", certutil.UserError{Err: fmt.Sprintf("base64-decoded context must be %d bytes long when using convergent encryption with this key", gcm.NonceSize())} + return "", errutil.UserError{Err: fmt.Sprintf("base64-decoded context must be %d bytes long when using convergent encryption with this key", gcm.NonceSize())} } // Compute random nonce @@ -379,7 +379,7 @@ func (p *Policy) Encrypt(context []byte, value string) (string, error) { nonce = make([]byte, gcm.NonceSize()) _, err = rand.Read(nonce) if err != nil { - return "", certutil.InternalError{Err: err.Error()} + return "", errutil.InternalError{Err: err.Error()} } } @@ -401,17 +401,17 @@ func (p *Policy) Encrypt(context []byte, value string) (string, error) { func (p *Policy) Decrypt(context []byte, value string) (string, error) { // Verify the prefix if !strings.HasPrefix(value, "vault:v") { - return "", certutil.UserError{Err: "invalid ciphertext: no prefix"} + return "", errutil.UserError{Err: "invalid ciphertext: no prefix"} } splitVerCiphertext := strings.SplitN(strings.TrimPrefix(value, "vault:v"), ":", 2) if len(splitVerCiphertext) != 2 { - return "", certutil.UserError{Err: "invalid ciphertext: wrong number of fields"} + return "", errutil.UserError{Err: "invalid ciphertext: wrong number of fields"} } ver, err := strconv.Atoi(splitVerCiphertext[0]) if err != nil { - return "", certutil.UserError{Err: "invalid ciphertext: version number could not be decoded"} + return "", errutil.UserError{Err: "invalid ciphertext: version number could not be decoded"} } if ver == 0 { @@ -421,11 +421,11 @@ func (p *Policy) Decrypt(context []byte, value string) (string, error) { } if ver > p.LatestVersion { - return "", certutil.UserError{Err: "invalid ciphertext: version is too new"} + return "", errutil.UserError{Err: "invalid ciphertext: version is too new"} } if p.MinDecryptionVersion > 0 && ver < p.MinDecryptionVersion { - return "", certutil.UserError{Err: ErrTooOld} + return "", errutil.UserError{Err: ErrTooOld} } // Derive the key that should be used @@ -438,25 +438,25 @@ func (p *Policy) Decrypt(context []byte, value string) (string, error) { switch p.CipherMode { case "aes-gcm": default: - return "", certutil.InternalError{Err: "unsupported cipher mode"} + return "", errutil.InternalError{Err: "unsupported cipher mode"} } // Decode the base64 decoded, err := base64.StdEncoding.DecodeString(splitVerCiphertext[1]) if err != nil { - return "", certutil.UserError{Err: "invalid ciphertext: could not decode base64"} + return "", errutil.UserError{Err: "invalid ciphertext: could not decode base64"} } // Setup the cipher aesCipher, err := aes.NewCipher(key) if err != nil { - return "", certutil.InternalError{Err: err.Error()} + return "", errutil.InternalError{Err: err.Error()} } // Setup the GCM AEAD gcm, err := cipher.NewGCM(aesCipher) if err != nil { - return "", certutil.InternalError{Err: err.Error()} + return "", errutil.InternalError{Err: err.Error()} } // Extract the nonce and ciphertext @@ -466,7 +466,7 @@ func (p *Policy) Decrypt(context []byte, value string) (string, error) { // Verify and Decrypt plain, err := gcm.Open(nil, nonce, ciphertext, nil) if err != nil { - return "", certutil.UserError{Err: "invalid ciphertext: unable to decrypt"} + return "", errutil.UserError{Err: "invalid ciphertext: unable to decrypt"} } return base64.StdEncoding.EncodeToString(plain), nil diff --git a/helper/certutil/helpers.go b/helper/certutil/helpers.go index d7408e449a..4574842cf8 100644 --- a/helper/certutil/helpers.go +++ b/helper/certutil/helpers.go @@ -15,6 +15,7 @@ import ( "strconv" "strings" + "github.com/hashicorp/vault/helper/errutil" "github.com/hashicorp/vault/helper/jsonutil" "github.com/mitchellh/mapstructure" ) @@ -52,12 +53,12 @@ func ParseHexFormatted(in, sep string) []byte { // of the marshaled public key func GetSubjKeyID(privateKey crypto.Signer) ([]byte, error) { if privateKey == nil { - return nil, InternalError{"passed-in private key is nil"} + return nil, errutil.InternalError{"passed-in private key is nil"} } marshaledKey, err := x509.MarshalPKIXPublicKey(privateKey.Public()) if err != nil { - return nil, InternalError{fmt.Sprintf("error marshalling public key: %s", err)} + return nil, errutil.InternalError{fmt.Sprintf("error marshalling public key: %s", err)} } subjKeyID := sha1.Sum(marshaledKey) @@ -71,7 +72,7 @@ func ParsePKIMap(data map[string]interface{}) (*ParsedCertBundle, error) { result := &CertBundle{} err := mapstructure.Decode(data, result) if err != nil { - return nil, UserError{err.Error()} + return nil, errutil.UserError{err.Error()} } return result.ToParsedCertBundle() @@ -97,7 +98,7 @@ func ParsePKIJSON(input []byte) (*ParsedCertBundle, error) { return ParsePKIMap(secret.Data) } - return nil, UserError{"unable to parse out of either secret data or a secret object"} + return nil, errutil.UserError{"unable to parse out of either secret data or a secret object"} } // ParsePEMBundle takes a string of concatenated PEM-format certificate @@ -106,7 +107,7 @@ func ParsePKIJSON(input []byte) (*ParsedCertBundle, error) { // issuing certificate) and one private key. func ParsePEMBundle(pemBundle string) (*ParsedCertBundle, error) { if len(pemBundle) == 0 { - return nil, UserError{"empty pem bundle"} + return nil, errutil.UserError{"empty pem bundle"} } pemBundle = strings.TrimSpace(pemBundle) @@ -118,12 +119,12 @@ func ParsePEMBundle(pemBundle string) (*ParsedCertBundle, error) { for { pemBlock, pemBytes = pem.Decode(pemBytes) if pemBlock == nil { - return nil, UserError{"no data found"} + return nil, errutil.UserError{"no data found"} } if signer, err := x509.ParseECPrivateKey(pemBlock.Bytes); err == nil { if parsedBundle.PrivateKeyType != UnknownPrivateKey { - return nil, UserError{"more than one private key given; provide only one private key in the bundle"} + return nil, errutil.UserError{"more than one private key given; provide only one private key in the bundle"} } parsedBundle.PrivateKeyFormat = ECBlock parsedBundle.PrivateKeyType = ECPrivateKey @@ -132,7 +133,7 @@ func ParsePEMBundle(pemBundle string) (*ParsedCertBundle, error) { } else if signer, err := x509.ParsePKCS1PrivateKey(pemBlock.Bytes); err == nil { if parsedBundle.PrivateKeyType != UnknownPrivateKey { - return nil, UserError{"more than one private key given; provide only one private key in the bundle"} + return nil, errutil.UserError{"more than one private key given; provide only one private key in the bundle"} } parsedBundle.PrivateKeyType = RSAPrivateKey parsedBundle.PrivateKeyFormat = PKCS1Block @@ -142,7 +143,7 @@ func ParsePEMBundle(pemBundle string) (*ParsedCertBundle, error) { parsedBundle.PrivateKeyFormat = PKCS8Block if parsedBundle.PrivateKeyType != UnknownPrivateKey { - return nil, UserError{"More than one private key given; provide only one private key in the bundle"} + return nil, errutil.UserError{"More than one private key given; provide only one private key in the bundle"} } switch signer := signer.(type) { case *rsa.PrivateKey: @@ -157,7 +158,7 @@ func ParsePEMBundle(pemBundle string) (*ParsedCertBundle, error) { } else if certificates, err := x509.ParseCertificates(pemBlock.Bytes); err == nil { switch len(certificates) { case 0: - return nil, UserError{"pem block cannot be decoded to a private key or certificate"} + return nil, errutil.UserError{"pem block cannot be decoded to a private key or certificate"} case 1: if parsedBundle.Certificate != nil { @@ -190,7 +191,7 @@ func ParsePEMBundle(pemBundle string) (*ParsedCertBundle, error) { } default: - return nil, UserError{"too many certificates given; provide a maximum of two certificates in the bundle"} + return nil, errutil.UserError{"too many certificates given; provide a maximum of two certificates in the bundle"} } } @@ -214,7 +215,7 @@ func GeneratePrivateKey(keyType string, keyBits int, container ParsedPrivateKeyC privateKeyType = RSAPrivateKey privateKey, err = rsa.GenerateKey(rand.Reader, keyBits) if err != nil { - return InternalError{Err: fmt.Sprintf("error generating RSA private key: %v", err)} + return errutil.InternalError{Err: fmt.Sprintf("error generating RSA private key: %v", err)} } privateKeyBytes = x509.MarshalPKCS1PrivateKey(privateKey.(*rsa.PrivateKey)) case "ec": @@ -230,18 +231,18 @@ func GeneratePrivateKey(keyType string, keyBits int, container ParsedPrivateKeyC case 521: curve = elliptic.P521() default: - return UserError{Err: fmt.Sprintf("unsupported bit length for EC key: %d", keyBits)} + return errutil.UserError{Err: fmt.Sprintf("unsupported bit length for EC key: %d", keyBits)} } privateKey, err = ecdsa.GenerateKey(curve, rand.Reader) if err != nil { - return InternalError{Err: fmt.Sprintf("error generating EC private key: %v", err)} + return errutil.InternalError{Err: fmt.Sprintf("error generating EC private key: %v", err)} } privateKeyBytes, err = x509.MarshalECPrivateKey(privateKey.(*ecdsa.PrivateKey)) if err != nil { - return InternalError{Err: fmt.Sprintf("error marshalling EC private key: %v", err)} + return errutil.InternalError{Err: fmt.Sprintf("error marshalling EC private key: %v", err)} } default: - return UserError{Err: fmt.Sprintf("unknown key type: %s", keyType)} + return errutil.UserError{Err: fmt.Sprintf("unknown key type: %s", keyType)} } container.SetParsedPrivateKey(privateKey, privateKeyType, privateKeyBytes) @@ -252,7 +253,7 @@ func GeneratePrivateKey(keyType string, keyBits int, container ParsedPrivateKeyC func GenerateSerialNumber() (*big.Int, error) { serial, err := rand.Int(rand.Reader, (&big.Int{}).Exp(big.NewInt(2), big.NewInt(159), nil)) if err != nil { - return nil, InternalError{Err: fmt.Sprintf("error generating serial number: %v", err)} + return nil, errutil.InternalError{Err: fmt.Sprintf("error generating serial number: %v", err)} } return serial, nil } diff --git a/helper/certutil/types.go b/helper/certutil/types.go index 6ebae6fdb8..6c40fc3c91 100644 --- a/helper/certutil/types.go +++ b/helper/certutil/types.go @@ -15,6 +15,8 @@ import ( "encoding/pem" "fmt" "strings" + + "github.com/hashicorp/vault/helper/errutil" ) // Secret is used to attempt to unmarshal a Vault secret @@ -57,25 +59,6 @@ const ( ECBlock BlockType = "EC PRIVATE KEY" ) -// UserError represents an error generated due to invalid user input -type UserError struct { - Err string -} - -func (e UserError) Error() string { - return e.Err -} - -// InternalError represents an error generated internally, -// presumably not due to invalid user input -type InternalError struct { - Err string -} - -func (e InternalError) Error() string { - return e.Err -} - //ParsedPrivateKeyContainer allows common key setting for certs and CSRs type ParsedPrivateKeyContainer interface { SetParsedPrivateKey(crypto.Signer, PrivateKeyType, []byte) @@ -133,7 +116,7 @@ func (c *CertBundle) ToParsedCertBundle() (*ParsedCertBundle, error) { if len(c.PrivateKey) > 0 { pemBlock, _ = pem.Decode([]byte(c.PrivateKey)) if pemBlock == nil { - return nil, UserError{"Error decoding private key from cert bundle"} + return nil, errutil.UserError{"Error decoding private key from cert bundle"} } result.PrivateKeyBytes = pemBlock.Bytes @@ -147,7 +130,7 @@ func (c *CertBundle) ToParsedCertBundle() (*ParsedCertBundle, error) { case PKCS8Block: t, err := getPKCS8Type(pemBlock.Bytes) if err != nil { - return nil, UserError{fmt.Sprintf("Error getting key type from pkcs#8: %v", err)} + return nil, errutil.UserError{fmt.Sprintf("Error getting key type from pkcs#8: %v", err)} } result.PrivateKeyType = t switch t { @@ -157,36 +140,36 @@ func (c *CertBundle) ToParsedCertBundle() (*ParsedCertBundle, error) { c.PrivateKeyType = RSAPrivateKey } default: - return nil, UserError{fmt.Sprintf("Unsupported key block type: %s", pemBlock.Type)} + return nil, errutil.UserError{fmt.Sprintf("Unsupported key block type: %s", pemBlock.Type)} } result.PrivateKey, err = result.getSigner() if err != nil { - return nil, UserError{fmt.Sprintf("Error getting signer: %s", err)} + return nil, errutil.UserError{fmt.Sprintf("Error getting signer: %s", err)} } } if len(c.Certificate) > 0 { pemBlock, _ = pem.Decode([]byte(c.Certificate)) if pemBlock == nil { - return nil, UserError{"Error decoding certificate from cert bundle"} + return nil, errutil.UserError{"Error decoding certificate from cert bundle"} } result.CertificateBytes = pemBlock.Bytes result.Certificate, err = x509.ParseCertificate(result.CertificateBytes) if err != nil { - return nil, UserError{"Error encountered parsing certificate bytes from raw bundle"} + return nil, errutil.UserError{"Error encountered parsing certificate bytes from raw bundle"} } } if len(c.IssuingCA) > 0 { pemBlock, _ = pem.Decode([]byte(c.IssuingCA)) if pemBlock == nil { - return nil, UserError{"Error decoding issuing CA from cert bundle"} + return nil, errutil.UserError{"Error decoding issuing CA from cert bundle"} } result.IssuingCABytes = pemBlock.Bytes result.IssuingCA, err = x509.ParseCertificate(result.IssuingCABytes) if err != nil { - return nil, UserError{fmt.Sprintf("Error parsing CA certificate: %s", err)} + return nil, errutil.UserError{fmt.Sprintf("Error parsing CA certificate: %s", err)} } } @@ -249,20 +232,20 @@ func (p *ParsedCertBundle) getSigner() (crypto.Signer, error) { var err error if p.PrivateKeyBytes == nil || len(p.PrivateKeyBytes) == 0 { - return nil, UserError{"Given parsed cert bundle does not have private key information"} + return nil, errutil.UserError{"Given parsed cert bundle does not have private key information"} } switch p.PrivateKeyFormat { case ECBlock: signer, err = x509.ParseECPrivateKey(p.PrivateKeyBytes) if err != nil { - return nil, UserError{fmt.Sprintf("Unable to parse CA's private EC key: %s", err)} + return nil, errutil.UserError{fmt.Sprintf("Unable to parse CA's private EC key: %s", err)} } case PKCS1Block: signer, err = x509.ParsePKCS1PrivateKey(p.PrivateKeyBytes) if err != nil { - return nil, UserError{fmt.Sprintf("Unable to parse CA's private RSA key: %s", err)} + return nil, errutil.UserError{fmt.Sprintf("Unable to parse CA's private RSA key: %s", err)} } case PKCS8Block: @@ -271,12 +254,12 @@ func (p *ParsedCertBundle) getSigner() (crypto.Signer, error) { case *rsa.PrivateKey, *ecdsa.PrivateKey: return k.(crypto.Signer), nil default: - return nil, UserError{"Found unknown private key type in pkcs#8 wrapping"} + return nil, errutil.UserError{"Found unknown private key type in pkcs#8 wrapping"} } } - return nil, UserError{fmt.Sprintf("Failed to parse pkcs#8 key: %v", err)} + return nil, errutil.UserError{fmt.Sprintf("Failed to parse pkcs#8 key: %v", err)} default: - return nil, UserError{"Unable to determine type of private key; only RSA and EC are supported"} + return nil, errutil.UserError{"Unable to determine type of private key; only RSA and EC are supported"} } return signer, nil } @@ -291,7 +274,7 @@ func (p *ParsedCertBundle) SetParsedPrivateKey(privateKey crypto.Signer, private func getPKCS8Type(bs []byte) (PrivateKeyType, error) { k, err := x509.ParsePKCS8PrivateKey(bs) if err != nil { - return UnknownPrivateKey, UserError{fmt.Sprintf("Failed to parse pkcs#8 key: %v", err)} + return UnknownPrivateKey, errutil.UserError{fmt.Sprintf("Failed to parse pkcs#8 key: %v", err)} } switch k.(type) { @@ -300,7 +283,7 @@ func getPKCS8Type(bs []byte) (PrivateKeyType, error) { case *rsa.PrivateKey: return RSAPrivateKey, nil default: - return UnknownPrivateKey, UserError{"Found unknown private key type in pkcs#8 wrapping"} + return UnknownPrivateKey, errutil.UserError{"Found unknown private key type in pkcs#8 wrapping"} } } @@ -314,7 +297,7 @@ func (c *CSRBundle) ToParsedCSRBundle() (*ParsedCSRBundle, error) { if len(c.PrivateKey) > 0 { pemBlock, _ = pem.Decode([]byte(c.PrivateKey)) if pemBlock == nil { - return nil, UserError{"Error decoding private key from cert bundle"} + return nil, errutil.UserError{"Error decoding private key from cert bundle"} } result.PrivateKeyBytes = pemBlock.Bytes @@ -332,25 +315,25 @@ func (c *CSRBundle) ToParsedCSRBundle() (*ParsedCSRBundle, error) { result.PrivateKeyType = RSAPrivateKey c.PrivateKeyType = "rsa" } else { - return nil, UserError{fmt.Sprintf("Unknown private key type in bundle: %s", c.PrivateKeyType)} + return nil, errutil.UserError{fmt.Sprintf("Unknown private key type in bundle: %s", c.PrivateKeyType)} } } result.PrivateKey, err = result.getSigner() if err != nil { - return nil, UserError{fmt.Sprintf("Error getting signer: %s", err)} + return nil, errutil.UserError{fmt.Sprintf("Error getting signer: %s", err)} } } if len(c.CSR) > 0 { pemBlock, _ = pem.Decode([]byte(c.CSR)) if pemBlock == nil { - return nil, UserError{"Error decoding certificate from cert bundle"} + return nil, errutil.UserError{"Error decoding certificate from cert bundle"} } result.CSRBytes = pemBlock.Bytes result.CSR, err = x509.ParseCertificateRequest(result.CSRBytes) if err != nil { - return nil, UserError{"Error encountered parsing certificate bytes from raw bundle"} + return nil, errutil.UserError{"Error encountered parsing certificate bytes from raw bundle"} } } @@ -380,7 +363,7 @@ func (p *ParsedCSRBundle) ToCSRBundle() (*CSRBundle, error) { result.PrivateKeyType = "ec" block.Type = "EC PRIVATE KEY" default: - return nil, InternalError{"Could not determine private key type when creating block"} + return nil, errutil.InternalError{"Could not determine private key type when creating block"} } result.PrivateKey = strings.TrimSpace(string(pem.EncodeToMemory(&block))) } @@ -397,24 +380,24 @@ func (p *ParsedCSRBundle) getSigner() (crypto.Signer, error) { var err error if p.PrivateKeyBytes == nil || len(p.PrivateKeyBytes) == 0 { - return nil, UserError{"Given parsed cert bundle does not have private key information"} + return nil, errutil.UserError{"Given parsed cert bundle does not have private key information"} } switch p.PrivateKeyType { case ECPrivateKey: signer, err = x509.ParseECPrivateKey(p.PrivateKeyBytes) if err != nil { - return nil, UserError{fmt.Sprintf("Unable to parse CA's private EC key: %s", err)} + return nil, errutil.UserError{fmt.Sprintf("Unable to parse CA's private EC key: %s", err)} } case RSAPrivateKey: signer, err = x509.ParsePKCS1PrivateKey(p.PrivateKeyBytes) if err != nil { - return nil, UserError{fmt.Sprintf("Unable to parse CA's private RSA key: %s", err)} + return nil, errutil.UserError{fmt.Sprintf("Unable to parse CA's private RSA key: %s", err)} } default: - return nil, UserError{"Unable to determine type of private key; only RSA and EC are supported"} + return nil, errutil.UserError{"Unable to determine type of private key; only RSA and EC are supported"} } return signer, nil } diff --git a/helper/errutil/error.go b/helper/errutil/error.go new file mode 100644 index 0000000000..0b95efb40e --- /dev/null +++ b/helper/errutil/error.go @@ -0,0 +1,20 @@ +package errutil + +// UserError represents an error generated due to invalid user input +type UserError struct { + Err string +} + +func (e UserError) Error() string { + return e.Err +} + +// InternalError represents an error generated internally, +// presumably not due to invalid user input +type InternalError struct { + Err string +} + +func (e InternalError) Error() string { + return e.Err +} diff --git a/logical/framework/backend.go b/logical/framework/backend.go index a4f6dba2ee..0d14b3de82 100644 --- a/logical/framework/backend.go +++ b/logical/framework/backend.go @@ -11,6 +11,7 @@ import ( "time" "github.com/hashicorp/go-multierror" + "github.com/hashicorp/vault/helper/errutil" "github.com/hashicorp/vault/logical" ) @@ -128,7 +129,7 @@ func (b *Backend) HandleExistenceCheck(req *logical.Request) (checkFound bool, e err = fd.Validate() if err != nil { - return false, false, err + return false, false, errutil.UserError{Err: err.Error()} } // Call the callback with the request and the data diff --git a/vault/core.go b/vault/core.go index 30a63095b8..8e3a87c362 100644 --- a/vault/core.go +++ b/vault/core.go @@ -15,6 +15,7 @@ import ( "github.com/hashicorp/go-multierror" "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/audit" + "github.com/hashicorp/vault/helper/errutil" "github.com/hashicorp/vault/helper/mlock" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/physical" @@ -461,7 +462,11 @@ func (c *Core) checkToken(req *logical.Request) (*logical.Auth, *TokenEntry, err // Continue on default: c.logger.Printf("[ERR] core: failed to run existence check: %v", err) - return nil, nil, ErrInternalError + if _, ok := err.(errutil.UserError); ok { + return nil, nil, err + } else { + return nil, nil, ErrInternalError + } } switch {