Fix invalid input getting marked as internal error

This commit is contained in:
vishalnayak 2016-07-28 15:19:27 -04:00
parent 4fd83816bf
commit ddb6ae18a0
18 changed files with 208 additions and 194 deletions

View file

@ -16,6 +16,7 @@ import (
"time"
"github.com/hashicorp/vault/helper/certutil"
"github.com/hashicorp/vault/helper/errutil"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
@ -117,31 +118,31 @@ func validateKeyTypeLength(keyType string, keyBits int) *logical.Response {
func fetchCAInfo(req *logical.Request) (*caInfoBundle, error) {
bundleEntry, err := req.Storage.Get("config/ca_bundle")
if err != nil {
return nil, certutil.InternalError{Err: fmt.Sprintf("unable to fetch local CA certificate/key: %v", err)}
return nil, errutil.InternalError{Err: fmt.Sprintf("unable to fetch local CA certificate/key: %v", err)}
}
if bundleEntry == nil {
return nil, certutil.UserError{Err: "backend must be configured with a CA certificate/key"}
return nil, errutil.UserError{Err: "backend must be configured with a CA certificate/key"}
}
var bundle certutil.CertBundle
if err := bundleEntry.DecodeJSON(&bundle); err != nil {
return nil, certutil.InternalError{Err: fmt.Sprintf("unable to decode local CA certificate/key: %v", err)}
return nil, errutil.InternalError{Err: fmt.Sprintf("unable to decode local CA certificate/key: %v", err)}
}
parsedBundle, err := bundle.ToParsedCertBundle()
if err != nil {
return nil, certutil.InternalError{Err: err.Error()}
return nil, errutil.InternalError{Err: err.Error()}
}
if parsedBundle.Certificate == nil {
return nil, certutil.InternalError{Err: "stored CA information not able to be parsed"}
return nil, errutil.InternalError{Err: "stored CA information not able to be parsed"}
}
caInfo := &caInfoBundle{*parsedBundle, nil}
entries, err := getURLs(req)
if err != nil {
return nil, certutil.InternalError{Err: fmt.Sprintf("unable to fetch URL information: %v", err)}
return nil, errutil.InternalError{Err: fmt.Sprintf("unable to fetch URL information: %v", err)}
}
if entries == nil {
entries = &urlEntries{
@ -175,14 +176,14 @@ func fetchCertBySerial(req *logical.Request, prefix, serial string) (*logical.St
certEntry, err := req.Storage.Get(path)
if err != nil {
return nil, certutil.InternalError{Err: fmt.Sprintf("error fetching certificate %s: %s", serial, err)}
return nil, errutil.InternalError{Err: fmt.Sprintf("error fetching certificate %s: %s", serial, err)}
}
if certEntry == nil {
return nil, nil
}
if certEntry.Value == nil || len(certEntry.Value) == 0 {
return nil, certutil.InternalError{Err: fmt.Sprintf("returned certificate bytes for serial %s were empty", serial)}
return nil, errutil.InternalError{Err: fmt.Sprintf("returned certificate bytes for serial %s were empty", serial)}
}
return certEntry, nil
@ -356,7 +357,7 @@ func generateCert(b *backend,
data *framework.FieldData) (*certutil.ParsedCertBundle, error) {
if role.KeyType == "rsa" && role.KeyBits < 2048 {
return nil, certutil.UserError{Err: "RSA keys < 2048 bits are unsafe and not supported"}
return nil, errutil.UserError{Err: "RSA keys < 2048 bits are unsafe and not supported"}
}
creationBundle, err := generateCreationBundle(b, role, signingBundle, nil, req, data)
@ -371,7 +372,7 @@ func generateCert(b *backend,
// Generating a self-signed root certificate
entries, err := getURLs(req)
if err != nil {
return nil, certutil.InternalError{Err: fmt.Sprintf("unable to fetch URL information: %v", err)}
return nil, errutil.InternalError{Err: fmt.Sprintf("unable to fetch URL information: %v", err)}
}
if entries == nil {
entries = &urlEntries{
@ -429,40 +430,40 @@ func signCert(b *backend,
csrString := data.Get("csr").(string)
if csrString == "" {
return nil, certutil.UserError{Err: fmt.Sprintf("\"csr\" is empty")}
return nil, errutil.UserError{Err: fmt.Sprintf("\"csr\" is empty")}
}
pemBytes := []byte(csrString)
pemBlock, pemBytes := pem.Decode(pemBytes)
if pemBlock == nil {
return nil, certutil.UserError{Err: "csr contains no data"}
return nil, errutil.UserError{Err: "csr contains no data"}
}
csr, err := x509.ParseCertificateRequest(pemBlock.Bytes)
if err != nil {
return nil, certutil.UserError{Err: "certificate request could not be parsed"}
return nil, errutil.UserError{Err: "certificate request could not be parsed"}
}
switch role.KeyType {
case "rsa":
// Verify that the key matches the role type
if csr.PublicKeyAlgorithm != x509.RSA {
return nil, certutil.UserError{Err: fmt.Sprintf(
return nil, errutil.UserError{Err: fmt.Sprintf(
"role requires keys of type %s",
role.KeyType)}
}
pubKey, ok := csr.PublicKey.(*rsa.PublicKey)
if !ok {
return nil, certutil.UserError{Err: "could not parse CSR's public key"}
return nil, errutil.UserError{Err: "could not parse CSR's public key"}
}
// Verify that the key is at least 2048 bits
if pubKey.N.BitLen() < 2048 {
return nil, certutil.UserError{Err: "RSA keys < 2048 bits are unsafe and not supported"}
return nil, errutil.UserError{Err: "RSA keys < 2048 bits are unsafe and not supported"}
}
// Verify that the bit size is at least the size specified in the role
if pubKey.N.BitLen() < role.KeyBits {
return nil, certutil.UserError{Err: fmt.Sprintf(
return nil, errutil.UserError{Err: fmt.Sprintf(
"role requires a minimum of a %d-bit key, but CSR's key is %d bits",
role.KeyBits,
pubKey.N.BitLen())}
@ -471,18 +472,18 @@ func signCert(b *backend,
case "ec":
// Verify that the key matches the role type
if csr.PublicKeyAlgorithm != x509.ECDSA {
return nil, certutil.UserError{Err: fmt.Sprintf(
return nil, errutil.UserError{Err: fmt.Sprintf(
"role requires keys of type %s",
role.KeyType)}
}
pubKey, ok := csr.PublicKey.(*ecdsa.PublicKey)
if !ok {
return nil, certutil.UserError{Err: "could not parse CSR's public key"}
return nil, errutil.UserError{Err: "could not parse CSR's public key"}
}
// Verify that the bit size is at least the size specified in the role
if pubKey.Params().BitSize < role.KeyBits {
return nil, certutil.UserError{Err: fmt.Sprintf(
return nil, errutil.UserError{Err: fmt.Sprintf(
"role requires a minimum of a %d-bit key, but CSR's key is %d bits",
role.KeyBits,
pubKey.Params().BitSize)}
@ -498,10 +499,10 @@ func signCert(b *backend,
// Run RSA < 2048 bit checks
pubKey, ok := csr.PublicKey.(*rsa.PublicKey)
if !ok {
return nil, certutil.UserError{Err: "could not parse CSR's public key"}
return nil, errutil.UserError{Err: "could not parse CSR's public key"}
}
if pubKey.N.BitLen() < 2048 {
return nil, certutil.UserError{Err: "RSA keys < 2048 bits are unsafe and not supported"}
return nil, errutil.UserError{Err: "RSA keys < 2048 bits are unsafe and not supported"}
}
}
@ -545,7 +546,7 @@ func generateCreationBundle(b *backend,
if cn == "" {
cn = data.Get("common_name").(string)
if cn == "" {
return nil, certutil.UserError{Err: `the common_name field is required, or must be provided in a CSR with "use_csr_common_name" set to true`}
return nil, errutil.UserError{Err: `the common_name field is required, or must be provided in a CSR with "use_csr_common_name" set to true`}
}
}
}
@ -583,19 +584,19 @@ func generateCreationBundle(b *backend,
// Check for bad email and/or DNS names
badName, err := validateNames(req, dnsNames, role)
if len(badName) != 0 {
return nil, certutil.UserError{Err: fmt.Sprintf(
return nil, errutil.UserError{Err: fmt.Sprintf(
"name %s not allowed by this role", badName)}
} else if err != nil {
return nil, certutil.InternalError{Err: fmt.Sprintf(
return nil, errutil.InternalError{Err: fmt.Sprintf(
"error validating name %s: %s", badName, err)}
}
badName, err = validateNames(req, emailAddresses, role)
if len(badName) != 0 {
return nil, certutil.UserError{Err: fmt.Sprintf(
return nil, errutil.UserError{Err: fmt.Sprintf(
"email %s not allowed by this role", badName)}
} else if err != nil {
return nil, certutil.InternalError{Err: fmt.Sprintf(
return nil, errutil.InternalError{Err: fmt.Sprintf(
"error validating name %s: %s", badName, err)}
}
}
@ -609,13 +610,13 @@ func generateCreationBundle(b *backend,
ipAlt := ipAltInt.(string)
if len(ipAlt) != 0 {
if !role.AllowIPSANs {
return nil, certutil.UserError{Err: fmt.Sprintf(
return nil, errutil.UserError{Err: fmt.Sprintf(
"IP Subject Alternative Names are not allowed in this role, but was provided %s", ipAlt)}
}
for _, v := range strings.Split(ipAlt, ",") {
parsedIP := net.ParseIP(v)
if parsedIP == nil {
return nil, certutil.UserError{Err: fmt.Sprintf(
return nil, errutil.UserError{Err: fmt.Sprintf(
"the value '%s' is not a valid IP address", v)}
}
ipAddresses = append(ipAddresses, parsedIP)
@ -642,7 +643,7 @@ func generateCreationBundle(b *backend,
} else {
ttl, err = time.ParseDuration(ttlField)
if err != nil {
return nil, certutil.UserError{Err: fmt.Sprintf(
return nil, errutil.UserError{Err: fmt.Sprintf(
"invalid requested ttl: %s", err)}
}
}
@ -652,7 +653,7 @@ func generateCreationBundle(b *backend,
} else {
maxTTL, err = time.ParseDuration(role.MaxTTL)
if err != nil {
return nil, certutil.UserError{Err: fmt.Sprintf(
return nil, errutil.UserError{Err: fmt.Sprintf(
"invalid ttl: %s", err)}
}
}
@ -663,7 +664,7 @@ func generateCreationBundle(b *backend,
if len(ttlField) == 0 {
ttl = maxTTL
} else {
return nil, certutil.UserError{Err: fmt.Sprintf(
return nil, errutil.UserError{Err: fmt.Sprintf(
"ttl is larger than maximum allowed (%d)", maxTTL/time.Second)}
}
}
@ -672,7 +673,7 @@ func generateCreationBundle(b *backend,
// valid past the lifetime of the CA certificate
if signingBundle != nil &&
time.Now().Add(ttl).After(signingBundle.Certificate.NotAfter) {
return nil, certutil.UserError{Err: fmt.Sprintf(
return nil, errutil.UserError{Err: fmt.Sprintf(
"cannot satisfy request, as TTL is beyond the expiration of the CA certificate")}
}
}
@ -782,7 +783,7 @@ func createCertificate(creationInfo *creationBundle) (*certutil.ParsedCertBundle
subjKeyID, err := certutil.GetSubjKeyID(result.PrivateKey)
if err != nil {
return nil, certutil.InternalError{Err: fmt.Sprintf("error getting subject key ID: %s", err)}
return nil, errutil.InternalError{Err: fmt.Sprintf("error getting subject key ID: %s", err)}
}
subject := pkix.Name{
@ -845,13 +846,13 @@ func createCertificate(creationInfo *creationBundle) (*certutil.ParsedCertBundle
}
if err != nil {
return nil, certutil.InternalError{Err: fmt.Sprintf("unable to create certificate: %s", err)}
return nil, errutil.InternalError{Err: fmt.Sprintf("unable to create certificate: %s", err)}
}
result.CertificateBytes = certBytes
result.Certificate, err = x509.ParseCertificate(certBytes)
if err != nil {
return nil, certutil.InternalError{Err: fmt.Sprintf("unable to parse created certificate: %s", err)}
return nil, errutil.InternalError{Err: fmt.Sprintf("unable to parse created certificate: %s", err)}
}
if creationInfo.SigningBundle != nil {
@ -898,13 +899,13 @@ func createCSR(creationInfo *creationBundle) (*certutil.ParsedCSRBundle, error)
csr, err := x509.CreateCertificateRequest(rand.Reader, csrTemplate, result.PrivateKey)
if err != nil {
return nil, certutil.InternalError{Err: fmt.Sprintf("unable to create certificate: %s", err)}
return nil, errutil.InternalError{Err: fmt.Sprintf("unable to create certificate: %s", err)}
}
result.CSRBytes = csr
result.CSR, err = x509.ParseCertificateRequest(csr)
if err != nil {
return nil, certutil.InternalError{Err: fmt.Sprintf("unable to parse created certificate: %s", err)}
return nil, errutil.InternalError{Err: fmt.Sprintf("unable to parse created certificate: %s", err)}
}
return result, nil
@ -916,16 +917,16 @@ func signCertificate(creationInfo *creationBundle,
csr *x509.CertificateRequest) (*certutil.ParsedCertBundle, error) {
switch {
case creationInfo == nil:
return nil, certutil.UserError{Err: "nil creation info given to signCertificate"}
return nil, errutil.UserError{Err: "nil creation info given to signCertificate"}
case creationInfo.SigningBundle == nil:
return nil, certutil.UserError{Err: "nil signing bundle given to signCertificate"}
return nil, errutil.UserError{Err: "nil signing bundle given to signCertificate"}
case csr == nil:
return nil, certutil.UserError{Err: "nil csr given to signCertificate"}
return nil, errutil.UserError{Err: "nil csr given to signCertificate"}
}
err := csr.CheckSignature()
if err != nil {
return nil, certutil.UserError{Err: "request signature invalid"}
return nil, errutil.UserError{Err: "request signature invalid"}
}
result := &certutil.ParsedCertBundle{}
@ -937,7 +938,7 @@ func signCertificate(creationInfo *creationBundle,
marshaledKey, err := x509.MarshalPKIXPublicKey(csr.PublicKey)
if err != nil {
return nil, certutil.InternalError{Err: fmt.Sprintf("error marshalling public key: %s", err)}
return nil, errutil.InternalError{Err: fmt.Sprintf("error marshalling public key: %s", err)}
}
subjKeyID := sha1.Sum(marshaledKey)
@ -989,7 +990,7 @@ func signCertificate(creationInfo *creationBundle,
if creationInfo.SigningBundle.Certificate.MaxPathLen == 0 &&
creationInfo.SigningBundle.Certificate.MaxPathLenZero {
return nil, certutil.UserError{Err: "signing certificate has a max path length of zero, and cannot issue further CA certificates"}
return nil, errutil.UserError{Err: "signing certificate has a max path length of zero, and cannot issue further CA certificates"}
}
certTemplate.MaxPathLen = creationInfo.MaxPathLength
@ -1001,13 +1002,13 @@ func signCertificate(creationInfo *creationBundle,
certBytes, err = x509.CreateCertificate(rand.Reader, certTemplate, caCert, csr.PublicKey, creationInfo.SigningBundle.PrivateKey)
if err != nil {
return nil, certutil.InternalError{Err: fmt.Sprintf("unable to create certificate: %s", err)}
return nil, errutil.InternalError{Err: fmt.Sprintf("unable to create certificate: %s", err)}
}
result.CertificateBytes = certBytes
result.Certificate, err = x509.ParseCertificate(certBytes)
if err != nil {
return nil, certutil.InternalError{Err: fmt.Sprintf("unable to parse created certificate: %s", err)}
return nil, errutil.InternalError{Err: fmt.Sprintf("unable to parse created certificate: %s", err)}
}
result.IssuingCABytes = creationInfo.SigningBundle.CertificateBytes

View file

@ -7,7 +7,7 @@ import (
"fmt"
"time"
"github.com/hashicorp/vault/helper/certutil"
"github.com/hashicorp/vault/helper/errutil"
"github.com/hashicorp/vault/logical"
)
@ -33,9 +33,9 @@ func revokeCert(b *backend, req *logical.Request, serial string, fromLease bool)
certEntry, err := fetchCertBySerial(req, "revoked/", serial)
if err != nil {
switch err.(type) {
case certutil.UserError:
case errutil.UserError:
return logical.ErrorResponse(err.Error()), nil
case certutil.InternalError:
case errutil.InternalError:
return nil, err
}
}
@ -58,9 +58,9 @@ func revokeCert(b *backend, req *logical.Request, serial string, fromLease bool)
certEntry, err = fetchCertBySerial(req, "certs/", serial)
if err != nil {
switch err.(type) {
case certutil.UserError:
case errutil.UserError:
return logical.ErrorResponse(err.Error()), nil
case certutil.InternalError:
case errutil.InternalError:
return nil, err
}
}
@ -103,9 +103,9 @@ func revokeCert(b *backend, req *logical.Request, serial string, fromLease bool)
crlErr := buildCRL(b, req)
switch crlErr.(type) {
case certutil.UserError:
case errutil.UserError:
return logical.ErrorResponse(fmt.Sprintf("Error during CRL building: %s", crlErr)), nil
case certutil.InternalError:
case errutil.InternalError:
return nil, fmt.Errorf("Error encountered during CRL building: %s", crlErr)
}
@ -121,7 +121,7 @@ func revokeCert(b *backend, req *logical.Request, serial string, fromLease bool)
func buildCRL(b *backend, req *logical.Request) error {
revokedSerials, err := req.Storage.List("revoked/")
if err != nil {
return certutil.InternalError{Err: fmt.Sprintf("Error fetching list of revoked certs: %s", err)}
return errutil.InternalError{Err: fmt.Sprintf("Error fetching list of revoked certs: %s", err)}
}
revokedCerts := []pkix.RevokedCertificate{}
@ -129,26 +129,26 @@ func buildCRL(b *backend, req *logical.Request) error {
for _, serial := range revokedSerials {
revokedEntry, err := req.Storage.Get("revoked/" + serial)
if err != nil {
return certutil.InternalError{Err: fmt.Sprintf("Unable to fetch revoked cert with serial %s: %s", serial, err)}
return errutil.InternalError{Err: fmt.Sprintf("Unable to fetch revoked cert with serial %s: %s", serial, err)}
}
if revokedEntry == nil {
return certutil.InternalError{Err: fmt.Sprintf("Revoked certificate entry for serial %s is nil", serial)}
return errutil.InternalError{Err: fmt.Sprintf("Revoked certificate entry for serial %s is nil", serial)}
}
if revokedEntry.Value == nil || len(revokedEntry.Value) == 0 {
// TODO: In this case, remove it and continue? How likely is this to
// happen? Alternately, could skip it entirely, or could implement a
// delete function so that there is a way to remove these
return certutil.InternalError{Err: fmt.Sprintf("Found revoked serial but actual certificate is empty")}
return errutil.InternalError{Err: fmt.Sprintf("Found revoked serial but actual certificate is empty")}
}
err = revokedEntry.DecodeJSON(&revInfo)
if err != nil {
return certutil.InternalError{Err: fmt.Sprintf("Error decoding revocation entry for serial %s: %s", serial, err)}
return errutil.InternalError{Err: fmt.Sprintf("Error decoding revocation entry for serial %s: %s", serial, err)}
}
revokedCert, err := x509.ParseCertificate(revInfo.CertificateBytes)
if err != nil {
return certutil.InternalError{Err: fmt.Sprintf("Unable to parse stored revoked certificate with serial %s: %s", serial, err)}
return errutil.InternalError{Err: fmt.Sprintf("Unable to parse stored revoked certificate with serial %s: %s", serial, err)}
}
revokedCerts = append(revokedCerts, pkix.RevokedCertificate{
@ -159,28 +159,28 @@ func buildCRL(b *backend, req *logical.Request) error {
signingBundle, caErr := fetchCAInfo(req)
switch caErr.(type) {
case certutil.UserError:
return certutil.UserError{Err: fmt.Sprintf("Could not fetch the CA certificate: %s", caErr)}
case certutil.InternalError:
return certutil.InternalError{Err: fmt.Sprintf("Error fetching CA certificate: %s", caErr)}
case errutil.UserError:
return errutil.UserError{Err: fmt.Sprintf("Could not fetch the CA certificate: %s", caErr)}
case errutil.InternalError:
return errutil.InternalError{Err: fmt.Sprintf("Error fetching CA certificate: %s", caErr)}
}
crlLifetime := b.crlLifetime
crlInfo, err := b.CRL(req.Storage)
if err != nil {
return certutil.InternalError{Err: fmt.Sprintf("Error fetching CRL config information: %s", err)}
return errutil.InternalError{Err: fmt.Sprintf("Error fetching CRL config information: %s", err)}
}
if crlInfo != nil {
crlDur, err := time.ParseDuration(crlInfo.Expiry)
if err != nil {
return certutil.InternalError{Err: fmt.Sprintf("Error parsing CRL duration of %s", crlInfo.Expiry)}
return errutil.InternalError{Err: fmt.Sprintf("Error parsing CRL duration of %s", crlInfo.Expiry)}
}
crlLifetime = crlDur
}
crlBytes, err := signingBundle.Certificate.CreateCRL(rand.Reader, signingBundle.PrivateKey, revokedCerts, time.Now(), time.Now().Add(crlLifetime))
if err != nil {
return certutil.InternalError{Err: fmt.Sprintf("Error creating new CRL: %s", err)}
return errutil.InternalError{Err: fmt.Sprintf("Error creating new CRL: %s", err)}
}
err = req.Storage.Put(&logical.StorageEntry{
@ -188,7 +188,7 @@ func buildCRL(b *backend, req *logical.Request) error {
Value: crlBytes,
})
if err != nil {
return certutil.InternalError{Err: fmt.Sprintf("Error storing CRL: %s", err)}
return errutil.InternalError{Err: fmt.Sprintf("Error storing CRL: %s", err)}
}
return nil

View file

@ -4,6 +4,7 @@ import (
"fmt"
"github.com/hashicorp/vault/helper/certutil"
"github.com/hashicorp/vault/helper/errutil"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
@ -37,7 +38,7 @@ func (b *backend) pathCAWrite(
parsedBundle, err := certutil.ParsePEMBundle(pemBundle)
if err != nil {
switch err.(type) {
case certutil.InternalError:
case errutil.InternalError:
return nil, err
default:
return logical.ErrorResponse(err.Error()), nil

View file

@ -4,7 +4,7 @@ import (
"encoding/pem"
"fmt"
"github.com/hashicorp/vault/helper/certutil"
"github.com/hashicorp/vault/helper/errutil"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
@ -139,10 +139,10 @@ func (b *backend) pathFetchRead(req *logical.Request, data *framework.FieldData)
certEntry, funcErr = fetchCertBySerial(req, req.Path, serial)
if funcErr != nil {
switch funcErr.(type) {
case certutil.UserError:
case errutil.UserError:
response = logical.ErrorResponse(funcErr.Error())
goto reply
case certutil.InternalError:
case errutil.InternalError:
retErr = funcErr
goto reply
}
@ -165,10 +165,10 @@ func (b *backend) pathFetchRead(req *logical.Request, data *framework.FieldData)
revokedEntry, funcErr = fetchCertBySerial(req, "revoked/", serial)
if funcErr != nil {
switch funcErr.(type) {
case certutil.UserError:
case errutil.UserError:
response = logical.ErrorResponse(funcErr.Error())
goto reply
case certutil.InternalError:
case errutil.InternalError:
retErr = funcErr
goto reply
}

View file

@ -5,6 +5,7 @@ import (
"fmt"
"github.com/hashicorp/vault/helper/certutil"
"github.com/hashicorp/vault/helper/errutil"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
@ -65,9 +66,9 @@ func (b *backend) pathGenerateIntermediate(
parsedBundle, err := generateIntermediateCSR(b, role, nil, req, data)
if err != nil {
switch err.(type) {
case certutil.UserError:
case errutil.UserError:
return logical.ErrorResponse(err.Error()), nil
case certutil.InternalError:
case errutil.InternalError:
return nil, err
}
}
@ -132,7 +133,7 @@ func (b *backend) pathSetSignedIntermediate(
inputBundle, err := certutil.ParsePEMBundle(cert)
if err != nil {
switch err.(type) {
case certutil.InternalError:
case errutil.InternalError:
return nil, err
default:
return logical.ErrorResponse(err.Error()), nil

View file

@ -6,6 +6,7 @@ import (
"time"
"github.com/hashicorp/vault/helper/certutil"
"github.com/hashicorp/vault/helper/errutil"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
@ -140,11 +141,11 @@ func (b *backend) pathIssueSignCert(
var caErr error
signingBundle, caErr := fetchCAInfo(req)
switch caErr.(type) {
case certutil.UserError:
return nil, certutil.UserError{Err: fmt.Sprintf(
case errutil.UserError:
return nil, errutil.UserError{Err: fmt.Sprintf(
"Could not fetch the CA certificate (was one set?): %s", caErr)}
case certutil.InternalError:
return nil, certutil.InternalError{Err: fmt.Sprintf(
case errutil.InternalError:
return nil, errutil.InternalError{Err: fmt.Sprintf(
"Error fetching CA certificate: %s", caErr)}
}
@ -157,9 +158,9 @@ func (b *backend) pathIssueSignCert(
}
if err != nil {
switch err.(type) {
case certutil.UserError:
case errutil.UserError:
return logical.ErrorResponse(err.Error()), nil
case certutil.InternalError:
case errutil.InternalError:
return nil, err
}
}

View file

@ -4,7 +4,7 @@ import (
"fmt"
"strings"
"github.com/hashicorp/vault/helper/certutil"
"github.com/hashicorp/vault/helper/errutil"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
@ -64,9 +64,9 @@ func (b *backend) pathRotateCRLRead(req *logical.Request, data *framework.FieldD
crlErr := buildCRL(b, req)
switch crlErr.(type) {
case certutil.UserError:
case errutil.UserError:
return logical.ErrorResponse(fmt.Sprintf("Error during CRL building: %s", crlErr)), nil
case certutil.InternalError:
case errutil.InternalError:
return nil, fmt.Errorf("Error encountered during CRL building: %s", crlErr)
default:
return &logical.Response{

View file

@ -4,7 +4,7 @@ import (
"encoding/base64"
"fmt"
"github.com/hashicorp/vault/helper/certutil"
"github.com/hashicorp/vault/helper/errutil"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
@ -84,9 +84,9 @@ func (b *backend) pathCAGenerateRoot(
parsedBundle, err := generateCert(b, role, nil, true, req, data)
if err != nil {
switch err.(type) {
case certutil.UserError:
case errutil.UserError:
return logical.ErrorResponse(err.Error()), nil
case certutil.InternalError:
case errutil.InternalError:
return nil, err
}
}
@ -201,11 +201,11 @@ func (b *backend) pathCASignIntermediate(
var caErr error
signingBundle, caErr := fetchCAInfo(req)
switch caErr.(type) {
case certutil.UserError:
return nil, certutil.UserError{Err: fmt.Sprintf(
case errutil.UserError:
return nil, errutil.UserError{Err: fmt.Sprintf(
"could not fetch the CA certificate (was one set?): %s", caErr)}
case certutil.InternalError:
return nil, certutil.InternalError{Err: fmt.Sprintf(
case errutil.InternalError:
return nil, errutil.InternalError{Err: fmt.Sprintf(
"error fetching CA certificate: %s", caErr)}
}
@ -220,9 +220,9 @@ func (b *backend) pathCASignIntermediate(
parsedBundle, err := signCert(b, role, signingBundle, true, useCSRValues, req, data)
if err != nil {
switch err.(type) {
case certutil.UserError:
case errutil.UserError:
return logical.ErrorResponse(err.Error()), nil
case certutil.InternalError:
case errutil.InternalError:
return nil, err
}
}

View file

@ -5,7 +5,7 @@ import (
"encoding/base64"
"fmt"
"github.com/hashicorp/vault/helper/certutil"
"github.com/hashicorp/vault/helper/errutil"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
@ -103,9 +103,9 @@ func (b *backend) pathDatakeyWrite(
ciphertext, err := p.Encrypt(context, base64.StdEncoding.EncodeToString(newKey))
if err != nil {
switch err.(type) {
case certutil.UserError:
case errutil.UserError:
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
case certutil.InternalError:
case errutil.InternalError:
return nil, err
default:
return nil, err

View file

@ -4,7 +4,7 @@ import (
"encoding/base64"
"fmt"
"github.com/hashicorp/vault/helper/certutil"
"github.com/hashicorp/vault/helper/errutil"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
@ -72,9 +72,9 @@ func (b *backend) pathDecryptWrite(
plaintext, err := p.Decrypt(context, ciphertext)
if err != nil {
switch err.(type) {
case certutil.UserError:
case errutil.UserError:
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
case certutil.InternalError:
case errutil.InternalError:
return nil, err
default:
return nil, err

View file

@ -5,7 +5,7 @@ import (
"fmt"
"sync"
"github.com/hashicorp/vault/helper/certutil"
"github.com/hashicorp/vault/helper/errutil"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
@ -96,9 +96,9 @@ func (b *backend) pathEncryptWrite(
ciphertext, err := p.Encrypt(context, value)
if err != nil {
switch err.(type) {
case certutil.UserError:
case errutil.UserError:
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
case certutil.InternalError:
case errutil.InternalError:
return nil, err
default:
return nil, err

View file

@ -4,7 +4,7 @@ import (
"encoding/base64"
"fmt"
"github.com/hashicorp/vault/helper/certutil"
"github.com/hashicorp/vault/helper/errutil"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
@ -74,9 +74,9 @@ func (b *backend) pathRewrapWrite(
plaintext, err := p.Decrypt(context, value)
if err != nil {
switch err.(type) {
case certutil.UserError:
case errutil.UserError:
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
case certutil.InternalError:
case errutil.InternalError:
return nil, err
default:
return nil, err
@ -90,9 +90,9 @@ func (b *backend) pathRewrapWrite(
ciphertext, err := p.Encrypt(context, plaintext)
if err != nil {
switch err.(type) {
case certutil.UserError:
case errutil.UserError:
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
case certutil.InternalError:
case errutil.InternalError:
return nil, err
default:
return nil, err

View file

@ -11,7 +11,7 @@ import (
"strings"
"time"
"github.com/hashicorp/vault/helper/certutil"
"github.com/hashicorp/vault/helper/errutil"
"github.com/hashicorp/vault/helper/jsonutil"
"github.com/hashicorp/vault/helper/kdf"
"github.com/hashicorp/vault/logical"
@ -304,15 +304,15 @@ func (p *Policy) upgrade(storage logical.Storage) error {
// mode is used with the context to derive the proper key.
func (p *Policy) DeriveKey(context []byte, ver int) ([]byte, error) {
if p.Keys == nil || p.LatestVersion == 0 {
return nil, certutil.InternalError{Err: "unable to access the key; no key versions found"}
return nil, errutil.InternalError{Err: "unable to access the key; no key versions found"}
}
if p.LatestVersion == 0 {
return nil, certutil.InternalError{Err: "unable to access the key; no key versions found"}
return nil, errutil.InternalError{Err: "unable to access the key; no key versions found"}
}
if ver <= 0 || ver > p.LatestVersion {
return nil, certutil.UserError{Err: "invalid key version"}
return nil, errutil.UserError{Err: "invalid key version"}
}
// Fast-path non-derived keys
@ -322,7 +322,7 @@ func (p *Policy) DeriveKey(context []byte, ver int) ([]byte, error) {
// Ensure a context is provided
if len(context) == 0 {
return nil, certutil.UserError{Err: "missing 'context' for key deriviation. The key was created using a derived key, which means additional, per-request information must be included in order to encrypt or decrypt information"}
return nil, errutil.UserError{Err: "missing 'context' for key deriviation. The key was created using a derived key, which means additional, per-request information must be included in order to encrypt or decrypt information"}
}
switch p.KDFMode {
@ -331,7 +331,7 @@ func (p *Policy) DeriveKey(context []byte, ver int) ([]byte, error) {
prfLen := kdf.HMACSHA256PRFLen
return kdf.CounterMode(prf, prfLen, p.Keys[ver].Key, context, 256)
default:
return nil, certutil.InternalError{Err: "unsupported key derivation mode"}
return nil, errutil.InternalError{Err: "unsupported key derivation mode"}
}
}
@ -339,7 +339,7 @@ func (p *Policy) Encrypt(context []byte, value string) (string, error) {
// Decode the plaintext value
plaintext, err := base64.StdEncoding.DecodeString(value)
if err != nil {
return "", certutil.UserError{Err: "failed to decode plaintext as base64"}
return "", errutil.UserError{Err: "failed to decode plaintext as base64"}
}
// Derive the key that should be used
@ -352,23 +352,23 @@ func (p *Policy) Encrypt(context []byte, value string) (string, error) {
switch p.CipherMode {
case "aes-gcm":
default:
return "", certutil.InternalError{Err: "unsupported cipher mode"}
return "", errutil.InternalError{Err: "unsupported cipher mode"}
}
// Setup the cipher
aesCipher, err := aes.NewCipher(key)
if err != nil {
return "", certutil.InternalError{Err: err.Error()}
return "", errutil.InternalError{Err: err.Error()}
}
// Setup the GCM AEAD
gcm, err := cipher.NewGCM(aesCipher)
if err != nil {
return "", certutil.InternalError{Err: err.Error()}
return "", errutil.InternalError{Err: err.Error()}
}
if p.ConvergentEncryption && len(context) != gcm.NonceSize() {
return "", certutil.UserError{Err: fmt.Sprintf("base64-decoded context must be %d bytes long when using convergent encryption with this key", gcm.NonceSize())}
return "", errutil.UserError{Err: fmt.Sprintf("base64-decoded context must be %d bytes long when using convergent encryption with this key", gcm.NonceSize())}
}
// Compute random nonce
@ -379,7 +379,7 @@ func (p *Policy) Encrypt(context []byte, value string) (string, error) {
nonce = make([]byte, gcm.NonceSize())
_, err = rand.Read(nonce)
if err != nil {
return "", certutil.InternalError{Err: err.Error()}
return "", errutil.InternalError{Err: err.Error()}
}
}
@ -401,17 +401,17 @@ func (p *Policy) Encrypt(context []byte, value string) (string, error) {
func (p *Policy) Decrypt(context []byte, value string) (string, error) {
// Verify the prefix
if !strings.HasPrefix(value, "vault:v") {
return "", certutil.UserError{Err: "invalid ciphertext: no prefix"}
return "", errutil.UserError{Err: "invalid ciphertext: no prefix"}
}
splitVerCiphertext := strings.SplitN(strings.TrimPrefix(value, "vault:v"), ":", 2)
if len(splitVerCiphertext) != 2 {
return "", certutil.UserError{Err: "invalid ciphertext: wrong number of fields"}
return "", errutil.UserError{Err: "invalid ciphertext: wrong number of fields"}
}
ver, err := strconv.Atoi(splitVerCiphertext[0])
if err != nil {
return "", certutil.UserError{Err: "invalid ciphertext: version number could not be decoded"}
return "", errutil.UserError{Err: "invalid ciphertext: version number could not be decoded"}
}
if ver == 0 {
@ -421,11 +421,11 @@ func (p *Policy) Decrypt(context []byte, value string) (string, error) {
}
if ver > p.LatestVersion {
return "", certutil.UserError{Err: "invalid ciphertext: version is too new"}
return "", errutil.UserError{Err: "invalid ciphertext: version is too new"}
}
if p.MinDecryptionVersion > 0 && ver < p.MinDecryptionVersion {
return "", certutil.UserError{Err: ErrTooOld}
return "", errutil.UserError{Err: ErrTooOld}
}
// Derive the key that should be used
@ -438,25 +438,25 @@ func (p *Policy) Decrypt(context []byte, value string) (string, error) {
switch p.CipherMode {
case "aes-gcm":
default:
return "", certutil.InternalError{Err: "unsupported cipher mode"}
return "", errutil.InternalError{Err: "unsupported cipher mode"}
}
// Decode the base64
decoded, err := base64.StdEncoding.DecodeString(splitVerCiphertext[1])
if err != nil {
return "", certutil.UserError{Err: "invalid ciphertext: could not decode base64"}
return "", errutil.UserError{Err: "invalid ciphertext: could not decode base64"}
}
// Setup the cipher
aesCipher, err := aes.NewCipher(key)
if err != nil {
return "", certutil.InternalError{Err: err.Error()}
return "", errutil.InternalError{Err: err.Error()}
}
// Setup the GCM AEAD
gcm, err := cipher.NewGCM(aesCipher)
if err != nil {
return "", certutil.InternalError{Err: err.Error()}
return "", errutil.InternalError{Err: err.Error()}
}
// Extract the nonce and ciphertext
@ -466,7 +466,7 @@ func (p *Policy) Decrypt(context []byte, value string) (string, error) {
// Verify and Decrypt
plain, err := gcm.Open(nil, nonce, ciphertext, nil)
if err != nil {
return "", certutil.UserError{Err: "invalid ciphertext: unable to decrypt"}
return "", errutil.UserError{Err: "invalid ciphertext: unable to decrypt"}
}
return base64.StdEncoding.EncodeToString(plain), nil

View file

@ -15,6 +15,7 @@ import (
"strconv"
"strings"
"github.com/hashicorp/vault/helper/errutil"
"github.com/hashicorp/vault/helper/jsonutil"
"github.com/mitchellh/mapstructure"
)
@ -52,12 +53,12 @@ func ParseHexFormatted(in, sep string) []byte {
// of the marshaled public key
func GetSubjKeyID(privateKey crypto.Signer) ([]byte, error) {
if privateKey == nil {
return nil, InternalError{"passed-in private key is nil"}
return nil, errutil.InternalError{"passed-in private key is nil"}
}
marshaledKey, err := x509.MarshalPKIXPublicKey(privateKey.Public())
if err != nil {
return nil, InternalError{fmt.Sprintf("error marshalling public key: %s", err)}
return nil, errutil.InternalError{fmt.Sprintf("error marshalling public key: %s", err)}
}
subjKeyID := sha1.Sum(marshaledKey)
@ -71,7 +72,7 @@ func ParsePKIMap(data map[string]interface{}) (*ParsedCertBundle, error) {
result := &CertBundle{}
err := mapstructure.Decode(data, result)
if err != nil {
return nil, UserError{err.Error()}
return nil, errutil.UserError{err.Error()}
}
return result.ToParsedCertBundle()
@ -97,7 +98,7 @@ func ParsePKIJSON(input []byte) (*ParsedCertBundle, error) {
return ParsePKIMap(secret.Data)
}
return nil, UserError{"unable to parse out of either secret data or a secret object"}
return nil, errutil.UserError{"unable to parse out of either secret data or a secret object"}
}
// ParsePEMBundle takes a string of concatenated PEM-format certificate
@ -106,7 +107,7 @@ func ParsePKIJSON(input []byte) (*ParsedCertBundle, error) {
// issuing certificate) and one private key.
func ParsePEMBundle(pemBundle string) (*ParsedCertBundle, error) {
if len(pemBundle) == 0 {
return nil, UserError{"empty pem bundle"}
return nil, errutil.UserError{"empty pem bundle"}
}
pemBundle = strings.TrimSpace(pemBundle)
@ -118,12 +119,12 @@ func ParsePEMBundle(pemBundle string) (*ParsedCertBundle, error) {
for {
pemBlock, pemBytes = pem.Decode(pemBytes)
if pemBlock == nil {
return nil, UserError{"no data found"}
return nil, errutil.UserError{"no data found"}
}
if signer, err := x509.ParseECPrivateKey(pemBlock.Bytes); err == nil {
if parsedBundle.PrivateKeyType != UnknownPrivateKey {
return nil, UserError{"more than one private key given; provide only one private key in the bundle"}
return nil, errutil.UserError{"more than one private key given; provide only one private key in the bundle"}
}
parsedBundle.PrivateKeyFormat = ECBlock
parsedBundle.PrivateKeyType = ECPrivateKey
@ -132,7 +133,7 @@ func ParsePEMBundle(pemBundle string) (*ParsedCertBundle, error) {
} else if signer, err := x509.ParsePKCS1PrivateKey(pemBlock.Bytes); err == nil {
if parsedBundle.PrivateKeyType != UnknownPrivateKey {
return nil, UserError{"more than one private key given; provide only one private key in the bundle"}
return nil, errutil.UserError{"more than one private key given; provide only one private key in the bundle"}
}
parsedBundle.PrivateKeyType = RSAPrivateKey
parsedBundle.PrivateKeyFormat = PKCS1Block
@ -142,7 +143,7 @@ func ParsePEMBundle(pemBundle string) (*ParsedCertBundle, error) {
parsedBundle.PrivateKeyFormat = PKCS8Block
if parsedBundle.PrivateKeyType != UnknownPrivateKey {
return nil, UserError{"More than one private key given; provide only one private key in the bundle"}
return nil, errutil.UserError{"More than one private key given; provide only one private key in the bundle"}
}
switch signer := signer.(type) {
case *rsa.PrivateKey:
@ -157,7 +158,7 @@ func ParsePEMBundle(pemBundle string) (*ParsedCertBundle, error) {
} else if certificates, err := x509.ParseCertificates(pemBlock.Bytes); err == nil {
switch len(certificates) {
case 0:
return nil, UserError{"pem block cannot be decoded to a private key or certificate"}
return nil, errutil.UserError{"pem block cannot be decoded to a private key or certificate"}
case 1:
if parsedBundle.Certificate != nil {
@ -190,7 +191,7 @@ func ParsePEMBundle(pemBundle string) (*ParsedCertBundle, error) {
}
default:
return nil, UserError{"too many certificates given; provide a maximum of two certificates in the bundle"}
return nil, errutil.UserError{"too many certificates given; provide a maximum of two certificates in the bundle"}
}
}
@ -214,7 +215,7 @@ func GeneratePrivateKey(keyType string, keyBits int, container ParsedPrivateKeyC
privateKeyType = RSAPrivateKey
privateKey, err = rsa.GenerateKey(rand.Reader, keyBits)
if err != nil {
return InternalError{Err: fmt.Sprintf("error generating RSA private key: %v", err)}
return errutil.InternalError{Err: fmt.Sprintf("error generating RSA private key: %v", err)}
}
privateKeyBytes = x509.MarshalPKCS1PrivateKey(privateKey.(*rsa.PrivateKey))
case "ec":
@ -230,18 +231,18 @@ func GeneratePrivateKey(keyType string, keyBits int, container ParsedPrivateKeyC
case 521:
curve = elliptic.P521()
default:
return UserError{Err: fmt.Sprintf("unsupported bit length for EC key: %d", keyBits)}
return errutil.UserError{Err: fmt.Sprintf("unsupported bit length for EC key: %d", keyBits)}
}
privateKey, err = ecdsa.GenerateKey(curve, rand.Reader)
if err != nil {
return InternalError{Err: fmt.Sprintf("error generating EC private key: %v", err)}
return errutil.InternalError{Err: fmt.Sprintf("error generating EC private key: %v", err)}
}
privateKeyBytes, err = x509.MarshalECPrivateKey(privateKey.(*ecdsa.PrivateKey))
if err != nil {
return InternalError{Err: fmt.Sprintf("error marshalling EC private key: %v", err)}
return errutil.InternalError{Err: fmt.Sprintf("error marshalling EC private key: %v", err)}
}
default:
return UserError{Err: fmt.Sprintf("unknown key type: %s", keyType)}
return errutil.UserError{Err: fmt.Sprintf("unknown key type: %s", keyType)}
}
container.SetParsedPrivateKey(privateKey, privateKeyType, privateKeyBytes)
@ -252,7 +253,7 @@ func GeneratePrivateKey(keyType string, keyBits int, container ParsedPrivateKeyC
func GenerateSerialNumber() (*big.Int, error) {
serial, err := rand.Int(rand.Reader, (&big.Int{}).Exp(big.NewInt(2), big.NewInt(159), nil))
if err != nil {
return nil, InternalError{Err: fmt.Sprintf("error generating serial number: %v", err)}
return nil, errutil.InternalError{Err: fmt.Sprintf("error generating serial number: %v", err)}
}
return serial, nil
}

View file

@ -15,6 +15,8 @@ import (
"encoding/pem"
"fmt"
"strings"
"github.com/hashicorp/vault/helper/errutil"
)
// Secret is used to attempt to unmarshal a Vault secret
@ -57,25 +59,6 @@ const (
ECBlock BlockType = "EC PRIVATE KEY"
)
// UserError represents an error generated due to invalid user input
type UserError struct {
Err string
}
func (e UserError) Error() string {
return e.Err
}
// InternalError represents an error generated internally,
// presumably not due to invalid user input
type InternalError struct {
Err string
}
func (e InternalError) Error() string {
return e.Err
}
//ParsedPrivateKeyContainer allows common key setting for certs and CSRs
type ParsedPrivateKeyContainer interface {
SetParsedPrivateKey(crypto.Signer, PrivateKeyType, []byte)
@ -133,7 +116,7 @@ func (c *CertBundle) ToParsedCertBundle() (*ParsedCertBundle, error) {
if len(c.PrivateKey) > 0 {
pemBlock, _ = pem.Decode([]byte(c.PrivateKey))
if pemBlock == nil {
return nil, UserError{"Error decoding private key from cert bundle"}
return nil, errutil.UserError{"Error decoding private key from cert bundle"}
}
result.PrivateKeyBytes = pemBlock.Bytes
@ -147,7 +130,7 @@ func (c *CertBundle) ToParsedCertBundle() (*ParsedCertBundle, error) {
case PKCS8Block:
t, err := getPKCS8Type(pemBlock.Bytes)
if err != nil {
return nil, UserError{fmt.Sprintf("Error getting key type from pkcs#8: %v", err)}
return nil, errutil.UserError{fmt.Sprintf("Error getting key type from pkcs#8: %v", err)}
}
result.PrivateKeyType = t
switch t {
@ -157,36 +140,36 @@ func (c *CertBundle) ToParsedCertBundle() (*ParsedCertBundle, error) {
c.PrivateKeyType = RSAPrivateKey
}
default:
return nil, UserError{fmt.Sprintf("Unsupported key block type: %s", pemBlock.Type)}
return nil, errutil.UserError{fmt.Sprintf("Unsupported key block type: %s", pemBlock.Type)}
}
result.PrivateKey, err = result.getSigner()
if err != nil {
return nil, UserError{fmt.Sprintf("Error getting signer: %s", err)}
return nil, errutil.UserError{fmt.Sprintf("Error getting signer: %s", err)}
}
}
if len(c.Certificate) > 0 {
pemBlock, _ = pem.Decode([]byte(c.Certificate))
if pemBlock == nil {
return nil, UserError{"Error decoding certificate from cert bundle"}
return nil, errutil.UserError{"Error decoding certificate from cert bundle"}
}
result.CertificateBytes = pemBlock.Bytes
result.Certificate, err = x509.ParseCertificate(result.CertificateBytes)
if err != nil {
return nil, UserError{"Error encountered parsing certificate bytes from raw bundle"}
return nil, errutil.UserError{"Error encountered parsing certificate bytes from raw bundle"}
}
}
if len(c.IssuingCA) > 0 {
pemBlock, _ = pem.Decode([]byte(c.IssuingCA))
if pemBlock == nil {
return nil, UserError{"Error decoding issuing CA from cert bundle"}
return nil, errutil.UserError{"Error decoding issuing CA from cert bundle"}
}
result.IssuingCABytes = pemBlock.Bytes
result.IssuingCA, err = x509.ParseCertificate(result.IssuingCABytes)
if err != nil {
return nil, UserError{fmt.Sprintf("Error parsing CA certificate: %s", err)}
return nil, errutil.UserError{fmt.Sprintf("Error parsing CA certificate: %s", err)}
}
}
@ -249,20 +232,20 @@ func (p *ParsedCertBundle) getSigner() (crypto.Signer, error) {
var err error
if p.PrivateKeyBytes == nil || len(p.PrivateKeyBytes) == 0 {
return nil, UserError{"Given parsed cert bundle does not have private key information"}
return nil, errutil.UserError{"Given parsed cert bundle does not have private key information"}
}
switch p.PrivateKeyFormat {
case ECBlock:
signer, err = x509.ParseECPrivateKey(p.PrivateKeyBytes)
if err != nil {
return nil, UserError{fmt.Sprintf("Unable to parse CA's private EC key: %s", err)}
return nil, errutil.UserError{fmt.Sprintf("Unable to parse CA's private EC key: %s", err)}
}
case PKCS1Block:
signer, err = x509.ParsePKCS1PrivateKey(p.PrivateKeyBytes)
if err != nil {
return nil, UserError{fmt.Sprintf("Unable to parse CA's private RSA key: %s", err)}
return nil, errutil.UserError{fmt.Sprintf("Unable to parse CA's private RSA key: %s", err)}
}
case PKCS8Block:
@ -271,12 +254,12 @@ func (p *ParsedCertBundle) getSigner() (crypto.Signer, error) {
case *rsa.PrivateKey, *ecdsa.PrivateKey:
return k.(crypto.Signer), nil
default:
return nil, UserError{"Found unknown private key type in pkcs#8 wrapping"}
return nil, errutil.UserError{"Found unknown private key type in pkcs#8 wrapping"}
}
}
return nil, UserError{fmt.Sprintf("Failed to parse pkcs#8 key: %v", err)}
return nil, errutil.UserError{fmt.Sprintf("Failed to parse pkcs#8 key: %v", err)}
default:
return nil, UserError{"Unable to determine type of private key; only RSA and EC are supported"}
return nil, errutil.UserError{"Unable to determine type of private key; only RSA and EC are supported"}
}
return signer, nil
}
@ -291,7 +274,7 @@ func (p *ParsedCertBundle) SetParsedPrivateKey(privateKey crypto.Signer, private
func getPKCS8Type(bs []byte) (PrivateKeyType, error) {
k, err := x509.ParsePKCS8PrivateKey(bs)
if err != nil {
return UnknownPrivateKey, UserError{fmt.Sprintf("Failed to parse pkcs#8 key: %v", err)}
return UnknownPrivateKey, errutil.UserError{fmt.Sprintf("Failed to parse pkcs#8 key: %v", err)}
}
switch k.(type) {
@ -300,7 +283,7 @@ func getPKCS8Type(bs []byte) (PrivateKeyType, error) {
case *rsa.PrivateKey:
return RSAPrivateKey, nil
default:
return UnknownPrivateKey, UserError{"Found unknown private key type in pkcs#8 wrapping"}
return UnknownPrivateKey, errutil.UserError{"Found unknown private key type in pkcs#8 wrapping"}
}
}
@ -314,7 +297,7 @@ func (c *CSRBundle) ToParsedCSRBundle() (*ParsedCSRBundle, error) {
if len(c.PrivateKey) > 0 {
pemBlock, _ = pem.Decode([]byte(c.PrivateKey))
if pemBlock == nil {
return nil, UserError{"Error decoding private key from cert bundle"}
return nil, errutil.UserError{"Error decoding private key from cert bundle"}
}
result.PrivateKeyBytes = pemBlock.Bytes
@ -332,25 +315,25 @@ func (c *CSRBundle) ToParsedCSRBundle() (*ParsedCSRBundle, error) {
result.PrivateKeyType = RSAPrivateKey
c.PrivateKeyType = "rsa"
} else {
return nil, UserError{fmt.Sprintf("Unknown private key type in bundle: %s", c.PrivateKeyType)}
return nil, errutil.UserError{fmt.Sprintf("Unknown private key type in bundle: %s", c.PrivateKeyType)}
}
}
result.PrivateKey, err = result.getSigner()
if err != nil {
return nil, UserError{fmt.Sprintf("Error getting signer: %s", err)}
return nil, errutil.UserError{fmt.Sprintf("Error getting signer: %s", err)}
}
}
if len(c.CSR) > 0 {
pemBlock, _ = pem.Decode([]byte(c.CSR))
if pemBlock == nil {
return nil, UserError{"Error decoding certificate from cert bundle"}
return nil, errutil.UserError{"Error decoding certificate from cert bundle"}
}
result.CSRBytes = pemBlock.Bytes
result.CSR, err = x509.ParseCertificateRequest(result.CSRBytes)
if err != nil {
return nil, UserError{"Error encountered parsing certificate bytes from raw bundle"}
return nil, errutil.UserError{"Error encountered parsing certificate bytes from raw bundle"}
}
}
@ -380,7 +363,7 @@ func (p *ParsedCSRBundle) ToCSRBundle() (*CSRBundle, error) {
result.PrivateKeyType = "ec"
block.Type = "EC PRIVATE KEY"
default:
return nil, InternalError{"Could not determine private key type when creating block"}
return nil, errutil.InternalError{"Could not determine private key type when creating block"}
}
result.PrivateKey = strings.TrimSpace(string(pem.EncodeToMemory(&block)))
}
@ -397,24 +380,24 @@ func (p *ParsedCSRBundle) getSigner() (crypto.Signer, error) {
var err error
if p.PrivateKeyBytes == nil || len(p.PrivateKeyBytes) == 0 {
return nil, UserError{"Given parsed cert bundle does not have private key information"}
return nil, errutil.UserError{"Given parsed cert bundle does not have private key information"}
}
switch p.PrivateKeyType {
case ECPrivateKey:
signer, err = x509.ParseECPrivateKey(p.PrivateKeyBytes)
if err != nil {
return nil, UserError{fmt.Sprintf("Unable to parse CA's private EC key: %s", err)}
return nil, errutil.UserError{fmt.Sprintf("Unable to parse CA's private EC key: %s", err)}
}
case RSAPrivateKey:
signer, err = x509.ParsePKCS1PrivateKey(p.PrivateKeyBytes)
if err != nil {
return nil, UserError{fmt.Sprintf("Unable to parse CA's private RSA key: %s", err)}
return nil, errutil.UserError{fmt.Sprintf("Unable to parse CA's private RSA key: %s", err)}
}
default:
return nil, UserError{"Unable to determine type of private key; only RSA and EC are supported"}
return nil, errutil.UserError{"Unable to determine type of private key; only RSA and EC are supported"}
}
return signer, nil
}

20
helper/errutil/error.go Normal file
View file

@ -0,0 +1,20 @@
package errutil
// UserError represents an error generated due to invalid user input
type UserError struct {
Err string
}
func (e UserError) Error() string {
return e.Err
}
// InternalError represents an error generated internally,
// presumably not due to invalid user input
type InternalError struct {
Err string
}
func (e InternalError) Error() string {
return e.Err
}

View file

@ -11,6 +11,7 @@ import (
"time"
"github.com/hashicorp/go-multierror"
"github.com/hashicorp/vault/helper/errutil"
"github.com/hashicorp/vault/logical"
)
@ -128,7 +129,7 @@ func (b *Backend) HandleExistenceCheck(req *logical.Request) (checkFound bool, e
err = fd.Validate()
if err != nil {
return false, false, err
return false, false, errutil.UserError{Err: err.Error()}
}
// Call the callback with the request and the data

View file

@ -15,6 +15,7 @@ import (
"github.com/hashicorp/go-multierror"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/audit"
"github.com/hashicorp/vault/helper/errutil"
"github.com/hashicorp/vault/helper/mlock"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/physical"
@ -461,7 +462,11 @@ func (c *Core) checkToken(req *logical.Request) (*logical.Auth, *TokenEntry, err
// Continue on
default:
c.logger.Printf("[ERR] core: failed to run existence check: %v", err)
return nil, nil, ErrInternalError
if _, ok := err.(errutil.UserError); ok {
return nil, nil, err
} else {
return nil, nil, ErrInternalError
}
}
switch {