diff --git a/builtin/logical/pki/backend.go b/builtin/logical/pki/backend.go index baca834258..210c710c2b 100644 --- a/builtin/logical/pki/backend.go +++ b/builtin/logical/pki/backend.go @@ -49,6 +49,7 @@ func Backend() *framework.Backend { pathFetchCRLViaCertPath(&b), pathFetchValid(&b), pathRevoke(&b), + pathTidy(&b), }, Secrets: []*framework.Secret{ @@ -57,7 +58,6 @@ func Backend() *framework.Backend { } b.crlLifetime = time.Hour * 72 - b.revokeStorageLock = &sync.Mutex{} return b.Backend } @@ -66,7 +66,7 @@ type backend struct { *framework.Backend crlLifetime time.Duration - revokeStorageLock *sync.Mutex + revokeStorageLock sync.RWMutex } const backendHelp = ` diff --git a/builtin/logical/pki/backend_test.go b/builtin/logical/pki/backend_test.go index b6c4dc1f9e..78091ad3be 100644 --- a/builtin/logical/pki/backend_test.go +++ b/builtin/logical/pki/backend_test.go @@ -29,7 +29,8 @@ import ( ) var ( - stepCount = 0 + stepCount = 0 + serialUnderTest string ) // Performs basic tests on CA functionality @@ -650,6 +651,11 @@ func generateCSRSteps(t *testing.T, caCert, caKey string, intdata, reqdata map[s // Generates steps to test out CA configuration -- certificates + CRL expiry, // and ensure that the certificates are readable after storing them func generateCATestingSteps(t *testing.T, caCert, caKey, otherCaCert string, intdata, reqdata map[string]interface{}) []logicaltest.TestStep { + setSerialUnderTest := func(req *logical.Request) error { + req.Path = serialUnderTest + return nil + } + ret := []logicaltest.TestStep{ logicaltest.TestStep{ Operation: logical.UpdateOperation, @@ -836,7 +842,7 @@ func generateCATestingSteps(t *testing.T, caCert, caKey, otherCaCert string, int delete(reqdata, "ttl") reqdata["csr"] = intdata["intermediatecsr"].(string) reqdata["common_name"] = "Intermediate Cert" - reqdata["ttl"] = "90h" + reqdata["ttl"] = "10s" return nil }, }, @@ -851,6 +857,7 @@ func generateCATestingSteps(t *testing.T, caCert, caKey, otherCaCert string, int delete(reqdata, "ttl") intdata["intermediatecert"] = resp.Data["certificate"].(string) reqdata["serial_number"] = resp.Data["serial_number"].(string) + reqdata["rsa_int_serial_number"] = resp.Data["serial_number"].(string) reqdata["certificate"] = resp.Data["certificate"].(string) reqdata["pem_bundle"] = intdata["intermediatekey"].(string) + "\n" + resp.Data["certificate"].(string) return nil @@ -972,7 +979,7 @@ func generateCATestingSteps(t *testing.T, caCert, caKey, otherCaCert string, int delete(reqdata, "ttl") reqdata["csr"] = intdata["intermediatecsr"].(string) reqdata["common_name"] = "Intermediate Cert" - reqdata["ttl"] = "90h" + reqdata["ttl"] = "10s" return nil }, }, @@ -987,6 +994,7 @@ func generateCATestingSteps(t *testing.T, caCert, caKey, otherCaCert string, int delete(reqdata, "ttl") intdata["intermediatecert"] = resp.Data["certificate"].(string) reqdata["serial_number"] = resp.Data["serial_number"].(string) + reqdata["ec_int_serial_number"] = resp.Data["serial_number"].(string) reqdata["certificate"] = resp.Data["certificate"].(string) reqdata["pem_bundle"] = intdata["intermediatekey"].(string) + "\n" + resp.Data["certificate"].(string) return nil @@ -1040,13 +1048,220 @@ func generateCATestingSteps(t *testing.T, caCert, caKey, otherCaCert string, int revokedString := certutil.GetOctalFormatted(revEntry.SerialNumber.Bytes(), ":") if revokedString == reqdata["serial_number"].(string) { found = true - } } if !found { t.Fatalf("did not find %s in CRL", reqdata["serial_number"].(string)) } delete(reqdata, "serial_number") + + serialUnderTest = "cert/" + reqdata["rsa_int_serial_number"].(string) + + return nil + }, + }, + + // Make sure both serial numbers we expect to find are found + logicaltest.TestStep{ + Operation: logical.ReadOperation, + PreFlight: setSerialUnderTest, + Check: func(resp *logical.Response) error { + if resp.Data["error"] != nil && resp.Data["error"].(string) != "" { + return fmt.Errorf("got an error: %s", resp.Data["error"].(string)) + } + + serialUnderTest = "cert/" + reqdata["ec_int_serial_number"].(string) + + return nil + }, + }, + + logicaltest.TestStep{ + Operation: logical.ReadOperation, + PreFlight: setSerialUnderTest, + Check: func(resp *logical.Response) error { + if resp.Data["error"] != nil && resp.Data["error"].(string) != "" { + return fmt.Errorf("got an error: %s", resp.Data["error"].(string)) + } + + // Give time for the certificates to pass the safety buffer + t.Logf("Sleeping for 15 seconds to allow safety buffer time to pass before testing tidying") + time.Sleep(15 * time.Second) + + serialUnderTest = "cert/" + reqdata["rsa_int_serial_number"].(string) + + return nil + }, + }, + + // This shouldn't do anything since the safety buffer is too long + logicaltest.TestStep{ + Operation: logical.UpdateOperation, + Path: "tidy", + Data: map[string]interface{}{ + "safety_buffer": "3h", + "tidy_cert_store": true, + "tidy_revocation_list": true, + }, + }, + + // We still expect to find these + logicaltest.TestStep{ + Operation: logical.ReadOperation, + PreFlight: setSerialUnderTest, + Check: func(resp *logical.Response) error { + if resp.Data["error"] != nil && resp.Data["error"].(string) != "" { + return fmt.Errorf("got an error: %s", resp.Data["error"].(string)) + } + + serialUnderTest = "cert/" + reqdata["ec_int_serial_number"].(string) + + return nil + }, + }, + + logicaltest.TestStep{ + Operation: logical.ReadOperation, + PreFlight: setSerialUnderTest, + Check: func(resp *logical.Response) error { + if resp.Data["error"] != nil && resp.Data["error"].(string) != "" { + return fmt.Errorf("got an error: %s", resp.Data["error"].(string)) + } + + serialUnderTest = "cert/" + reqdata["rsa_int_serial_number"].(string) + + return nil + }, + }, + + // Both should appear in the CRL + logicaltest.TestStep{ + Operation: logical.ReadOperation, + Path: "crl", + Data: reqdata, + Check: func(resp *logical.Response) error { + crlBytes := resp.Data["http_raw_body"].([]byte) + certList, err := x509.ParseCRL(crlBytes) + if err != nil { + t.Fatalf("err: %s", err) + } + revokedList := certList.TBSCertList.RevokedCertificates + if len(revokedList) != 2 { + t.Fatalf("length of revoked list not 2; %d", len(revokedList)) + } + foundRsa := false + foundEc := false + for _, revEntry := range revokedList { + revokedString := certutil.GetOctalFormatted(revEntry.SerialNumber.Bytes(), ":") + if revokedString == reqdata["rsa_int_serial_number"].(string) { + foundRsa = true + } + if revokedString == reqdata["ec_int_serial_number"].(string) { + foundEc = true + } + } + if !foundRsa || !foundEc { + t.Fatalf("did not find an expected entry in CRL") + } + + return nil + }, + }, + + // This shouldn't do anything since the boolean values default to false + logicaltest.TestStep{ + Operation: logical.UpdateOperation, + Path: "tidy", + Data: map[string]interface{}{ + "safety_buffer": "1s", + }, + }, + + // We still expect to find these + logicaltest.TestStep{ + Operation: logical.ReadOperation, + PreFlight: setSerialUnderTest, + Check: func(resp *logical.Response) error { + if resp.Data["error"] != nil && resp.Data["error"].(string) != "" { + return fmt.Errorf("got an error: %s", resp.Data["error"].(string)) + } + + serialUnderTest = "cert/" + reqdata["ec_int_serial_number"].(string) + + return nil + }, + }, + + logicaltest.TestStep{ + Operation: logical.ReadOperation, + PreFlight: setSerialUnderTest, + Check: func(resp *logical.Response) error { + if resp.Data["error"] != nil && resp.Data["error"].(string) != "" { + return fmt.Errorf("got an error: %s", resp.Data["error"].(string)) + } + + serialUnderTest = "cert/" + reqdata["rsa_int_serial_number"].(string) + + return nil + }, + }, + + // This should remove the values since the safety buffer is short + logicaltest.TestStep{ + Operation: logical.UpdateOperation, + Path: "tidy", + Data: map[string]interface{}{ + "safety_buffer": "1s", + "tidy_cert_store": true, + "tidy_revocation_list": true, + }, + }, + + // We do *not* expect to find these + logicaltest.TestStep{ + Operation: logical.ReadOperation, + PreFlight: setSerialUnderTest, + Check: func(resp *logical.Response) error { + if resp.Data["error"] == nil || resp.Data["error"].(string) == "" { + return fmt.Errorf("didn't get an expected error") + } + + serialUnderTest = "cert/" + reqdata["ec_int_serial_number"].(string) + + return nil + }, + }, + + logicaltest.TestStep{ + Operation: logical.ReadOperation, + PreFlight: setSerialUnderTest, + Check: func(resp *logical.Response) error { + if resp.Data["error"] == nil || resp.Data["error"].(string) == "" { + return fmt.Errorf("didn't get an expected error") + } + + serialUnderTest = "cert/" + reqdata["rsa_int_serial_number"].(string) + + return nil + }, + }, + + // Both should be gone from the CRL + logicaltest.TestStep{ + Operation: logical.ReadOperation, + Path: "crl", + Data: reqdata, + Check: func(resp *logical.Response) error { + crlBytes := resp.Data["http_raw_body"].([]byte) + certList, err := x509.ParseCRL(crlBytes) + if err != nil { + t.Fatalf("err: %s", err) + } + revokedList := certList.TBSCertList.RevokedCertificates + if len(revokedList) != 0 { + t.Fatalf("length of revoked list not 0; %d", len(revokedList)) + } + return nil }, }, diff --git a/builtin/logical/pki/crl_util.go b/builtin/logical/pki/crl_util.go index bdd06362fc..53459d48d0 100644 --- a/builtin/logical/pki/crl_util.go +++ b/builtin/logical/pki/crl_util.go @@ -112,9 +112,6 @@ func revokeCert(b *backend, req *logical.Request, serial string) (*logical.Respo // Builds a CRL by going through the list of revoked certificates and building // a new CRL with the stored revocation times and serial numbers. -// -// If a certificate has already expired, it will be removed entirely rather than -// become part of the new CRL. func buildCRL(b *backend, req *logical.Request) error { revokedSerials, err := req.Storage.List("revoked/") if err != nil { diff --git a/builtin/logical/pki/path_fetch.go b/builtin/logical/pki/path_fetch.go index 5d242e3771..ebde73da8e 100644 --- a/builtin/logical/pki/path_fetch.go +++ b/builtin/logical/pki/path_fetch.go @@ -114,16 +114,6 @@ func (b *backend) pathFetchRead(req *logical.Request, data *framework.FieldData) goto reply } - _, funcErr = fetchCAInfo(req) - switch funcErr.(type) { - case certutil.UserError: - response = logical.ErrorResponse(fmt.Sprintf("%s", funcErr)) - goto reply - case certutil.InternalError: - retErr = funcErr - goto reply - } - certEntry, funcErr = fetchCertBySerial(req, req.Path, serial) if funcErr != nil { switch funcErr.(type) { diff --git a/builtin/logical/pki/path_intermediate.go b/builtin/logical/pki/path_intermediate.go index c17347cb2b..667caf5cac 100644 --- a/builtin/logical/pki/path_intermediate.go +++ b/builtin/logical/pki/path_intermediate.go @@ -209,6 +209,13 @@ func (b *backend) pathSetSignedIntermediate( return nil, err } + entry.Key = "certs/" + cb.SerialNumber + entry.Value = inputBundle.CertificateBytes + err = req.Storage.Put(entry) + if err != nil { + return nil, err + } + // For ease of later use, also store just the certificate at a known // location entry.Key = "ca" diff --git a/builtin/logical/pki/path_revoke.go b/builtin/logical/pki/path_revoke.go index 280fa7cefc..b6a7e2e72f 100644 --- a/builtin/logical/pki/path_revoke.go +++ b/builtin/logical/pki/path_revoke.go @@ -54,8 +54,8 @@ func (b *backend) pathRevokeWrite(req *logical.Request, data *framework.FieldDat } func (b *backend) pathRotateCRLRead(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { - b.revokeStorageLock.Lock() - defer b.revokeStorageLock.Unlock() + b.revokeStorageLock.RLock() + defer b.revokeStorageLock.RUnlock() crlErr := buildCRL(b, req) switch crlErr.(type) { diff --git a/builtin/logical/pki/path_tidy.go b/builtin/logical/pki/path_tidy.go new file mode 100644 index 0000000000..e145fa0186 --- /dev/null +++ b/builtin/logical/pki/path_tidy.go @@ -0,0 +1,169 @@ +package pki + +import ( + "crypto/x509" + "fmt" + "time" + + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/logical/framework" +) + +func pathTidy(b *backend) *framework.Path { + return &framework.Path{ + Pattern: "tidy", + Fields: map[string]*framework.FieldSchema{ + "tidy_cert_store": &framework.FieldSchema{ + Type: framework.TypeBool, + Description: `Set to true to enable tidying up +the certificate store`, + Default: false, + }, + + "tidy_revocation_list": &framework.FieldSchema{ + Type: framework.TypeBool, + Description: `Set to true to enable tidying up +the revocation list`, + Default: false, + }, + + "safety_buffer": &framework.FieldSchema{ + Type: framework.TypeDurationSecond, + Description: `The amount of extra time that must have passed +beyond certificate expiration before it is removed +from the backend storage and/or revocation list. +Defaults to 72 hours.`, + Default: 259200, //72h, but TypeDurationSecond currently requires defaults to be int + }, + }, + + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.UpdateOperation: b.pathTidyWrite, + }, + + HelpSynopsis: pathTidyHelpSyn, + HelpDescription: pathTidyHelpDesc, + } +} + +func (b *backend) pathTidyWrite( + req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + safety_buffer := d.Get("safety_buffer").(int) + tidyCertStore := d.Get("tidy_cert_store").(bool) + tidyRevocationList := d.Get("tidy_revocation_list").(bool) + + bufferDuration := time.Duration(safety_buffer) * time.Second + + if tidyCertStore { + serials, err := req.Storage.List("certs/") + if err != nil { + return nil, fmt.Errorf("error fetching list of revoked certs: %s", err) + } + + for _, serial := range serials { + certEntry, err := req.Storage.Get("certs/" + serial) + if err != nil { + return nil, fmt.Errorf("error fetching certificate %s: %s", serial, err) + } + + if certEntry == nil { + return nil, fmt.Errorf("revoked certificate entry for serial %s is nil", serial) + } + + if certEntry.Value == nil || len(certEntry.Value) == 0 { + return nil, fmt.Errorf("found entry for serial %s but actual certificate is empty", serial) + } + + cert, err := x509.ParseCertificate(certEntry.Value) + if err != nil { + return nil, fmt.Errorf("unable to parse stored certificate with serial %s: %s", serial, err) + } + + if time.Now().After(cert.NotAfter.Add(bufferDuration)) { + if err := req.Storage.Delete("certs/" + serial); err != nil { + return nil, fmt.Errorf("error deleting serial %s from storage: %s", serial, err) + } + } + } + } + + if tidyRevocationList { + b.revokeStorageLock.Lock() + defer b.revokeStorageLock.Unlock() + + tidiedRevoked := false + + revokedSerials, err := req.Storage.List("revoked/") + if err != nil { + return nil, fmt.Errorf("error fetching list of revoked certs: %s", err) + } + + var revInfo revocationInfo + for _, serial := range revokedSerials { + revokedEntry, err := req.Storage.Get("revoked/" + serial) + if err != nil { + return nil, fmt.Errorf("unable to fetch revoked cert with serial %s: %s", serial, err) + } + if revokedEntry == nil { + return nil, fmt.Errorf("revoked certificate entry for serial %s is nil", serial) + } + if revokedEntry.Value == nil || len(revokedEntry.Value) == 0 { + // TODO: In this case, remove it and continue? How likely is this to + // happen? Alternately, could skip it entirely, or could implement a + // delete function so that there is a way to remove these + return nil, fmt.Errorf("found revoked serial but actual certificate is empty") + } + + err = revokedEntry.DecodeJSON(&revInfo) + if err != nil { + return nil, fmt.Errorf("error decoding revocation entry for serial %s: %s", serial, err) + } + + revokedCert, err := x509.ParseCertificate(revInfo.CertificateBytes) + if err != nil { + return nil, fmt.Errorf("unable to parse stored revoked certificate with serial %s: %s", serial, err) + } + + if time.Now().After(revokedCert.NotAfter.Add(bufferDuration)) { + if err := req.Storage.Delete("revoked/" + serial); err != nil { + return nil, fmt.Errorf("error deleting serial %s from revoked list: %s", serial, err) + } + tidiedRevoked = true + } + } + + if tidiedRevoked { + if err := buildCRL(b, req); err != nil { + return nil, err + } + } + } + + return nil, nil +} + +const pathTidyHelpSyn = ` +Tidy up the backend by removing expired certificates, revocation information, +or both. +` + +const pathTidyHelpDesc = ` +This endpoint allows expired certificates and/or revocation information to be +removed from the backend, freeing up storage and shortening CRLs. + +For safety, this function is a noop if called without parameters; cleanup from +normal certificate storage must be enabled with 'tidy_cert_store' and cleanup +from revocation information must be enabled with 'tidy_revocation_list'. + +The 'safety_buffer' parameter is useful to ensure that clock skew amongst your +hosts cannot lead to a certificate being removed from the CRL while it is still +considered valid by other hosts (for instance, if their clocks are a few +minutes behind). + +All certificates and/or revocation information currently stored in the backend +will be checked when this endpoint is hit. The expiration of the +certificate/revocation information of each certificate being held in +certificate storage or in revocation infomation will then be checked. If the +current time, minus the value of 'safety_buffer', is greater than the +expiration, it will be removed. +` diff --git a/logical/testing/testing.go b/logical/testing/testing.go index e720151cd2..b327c2101a 100644 --- a/logical/testing/testing.go +++ b/logical/testing/testing.go @@ -61,6 +61,10 @@ type TestStep struct { // step will be called Check TestCheckFunc + // PreFlight is called directly before execution of the request, allowing + // modification of the request paramters (e.g. Path) with dynamic values. + PreFlight PreFlightFunc + // ErrorOk, if true, will let erroneous responses through to the check ErrorOk bool @@ -77,6 +81,10 @@ type TestStep struct { // TestCheckFunc is the callback used for Check in TestStep. type TestCheckFunc func(*logical.Response) error +// PreFlightFunc is used to modify request parameters directly before execution +// in each TestStep. +type PreFlightFunc func(*logical.Request) error + // TestTeardownFunc is the callback used for Teardown in TestCase. type TestTeardownFunc func() error @@ -182,13 +190,10 @@ func Test(t TestT, c TestCase) { for i, s := range c.Steps { log.Printf("[WARN] Executing test step %d", i+1) - // Make sure to prefix the path with where we mounted the thing - path := fmt.Sprintf("%s/%s", prefix, s.Path) - // Create the request req := &logical.Request{ Operation: s.Operation, - Path: path, + Path: s.Path, Data: s.Data, } if !s.Unauthenticated { @@ -201,6 +206,19 @@ func Test(t TestT, c TestCase) { req.Connection = &logical.Connection{ConnState: s.ConnState} } + if s.PreFlight != nil { + ct := req.ClientToken + req.ClientToken = "" + if err := s.PreFlight(req); err != nil { + t.Error(fmt.Sprintf("Failed preflight for step %d: %s", i+1, err)) + break + } + req.ClientToken = ct + } + + // Make sure to prefix the path with where we mounted the thing + req.Path = fmt.Sprintf("%s/%s", prefix, req.Path) + // Make the request resp, err := core.HandleRequest(req) if resp != nil && resp.Secret != nil {