diff --git a/vault/token_store.go b/vault/token_store.go index 14d0c601e2..26dfbd4428 100644 --- a/vault/token_store.go +++ b/vault/token_store.go @@ -1404,8 +1404,8 @@ func (ts *TokenStore) revokeTree(ctx context.Context, le *leaseEntry) error { // Updated to be non-recursive and revoke child tokens // before parent tokens(DFS). func (ts *TokenStore) revokeTreeInternal(ctx context.Context, id string) error { - var dfs []string - dfs = append(dfs, id) + dfs := []string{id} + seenIDs := make(map[string]struct{}) var ns *namespace.Namespace @@ -1429,7 +1429,8 @@ func (ts *TokenStore) revokeTreeInternal(ctx context.Context, id string) error { } for l := len(dfs); l > 0; l = len(dfs) { - id := dfs[0] + id := dfs[len(dfs)-1] + seenIDs[id] = struct{}{} saltedCtx := ctx saltedNS := ns @@ -1444,11 +1445,26 @@ func (ts *TokenStore) revokeTreeInternal(ctx context.Context, id string) error { } path := saltedID + "/" - children, err := ts.parentView(saltedNS).List(saltedCtx, path) + childrenRaw, err := ts.parentView(saltedNS).List(saltedCtx, path) if err != nil { return errwrap.Wrapf("failed to scan for children: {{err}}", err) } + // Filter the child list to remove any items that have ever been in the dfs stack. + // This is a robustness check, as a parent/child cycle can lead to an OOM crash. + children := make([]string, 0, len(childrenRaw)) + for _, child := range childrenRaw { + if _, seen := seenIDs[child]; !seen { + children = append(children, child) + } else { + if err = ts.parentView(saltedNS).Delete(saltedCtx, path+child); err != nil { + return errwrap.Wrapf("failed to delete entry: {{err}}", err) + } + + ts.Logger().Warn("token cycle found", "token", child) + } + } + // If the length of the children array is zero, // then we are at a leaf node. if len(children) == 0 { @@ -1464,11 +1480,10 @@ func (ts *TokenStore) revokeTreeInternal(ctx context.Context, id string) error { if l == 1 { return nil } - dfs = dfs[1:] + dfs = dfs[:len(dfs)-1] } else { - // If we make it here, there are children and they must - // be prepended. - dfs = append(children, dfs...) + // If we make it here, there are children and they must be appended. + dfs = append(dfs, children...) } } diff --git a/vault/token_store_test.go b/vault/token_store_test.go index 49a7c7150b..26de3a0051 100644 --- a/vault/token_store_test.go +++ b/vault/token_store_test.go @@ -1004,17 +1004,46 @@ func TestTokenStore_Revoke_Orphan(t *testing.T) { // This was the original function name, and now it just calls // the non recursive version for a variety of depths. func TestTokenStore_RevokeTree(t *testing.T) { - testTokenStore_RevokeTree_NonRecursive(t, 1) - testTokenStore_RevokeTree_NonRecursive(t, 2) - testTokenStore_RevokeTree_NonRecursive(t, 10) + testTokenStore_RevokeTree_NonRecursive(t, 1, false) + testTokenStore_RevokeTree_NonRecursive(t, 2, false) + testTokenStore_RevokeTree_NonRecursive(t, 10, false) + + // corrupted trees with cycles + testTokenStore_RevokeTree_NonRecursive(t, 1, true) + testTokenStore_RevokeTree_NonRecursive(t, 10, true) } // Revokes a given Token Store tree non recursively. // The second parameter refers to the depth of the tree. -func testTokenStore_RevokeTree_NonRecursive(t testing.TB, depth uint64) { +func testTokenStore_RevokeTree_NonRecursive(t testing.TB, depth uint64, injectCycles bool) { c, _, _ := TestCoreUnsealed(t) ts := c.tokenStore root, children := buildTokenTree(t, ts, depth) + + var cyclePaths []string + if injectCycles { + // Make the root the parent of itself + saltedRoot, _ := ts.SaltID(namespace.TestContext(), root.ID) + key := fmt.Sprintf("%s/%s", saltedRoot, saltedRoot) + cyclePaths = append(cyclePaths, key) + le := &logical.StorageEntry{Key: key} + + if err := ts.parentView(namespace.TestNamespace()).Put(namespace.TestContext(), le); err != nil { + t.Fatalf("err: %v", err) + } + + // Make a deep child the parent of a shallow child + shallow, _ := ts.SaltID(namespace.TestContext(), children[0].ID) + deep, _ := ts.SaltID(namespace.TestContext(), children[len(children)-1].ID) + key = fmt.Sprintf("%s/%s", deep, shallow) + cyclePaths = append(cyclePaths, key) + le = &logical.StorageEntry{Key: key} + + if err := ts.parentView(namespace.TestNamespace()).Put(namespace.TestContext(), le); err != nil { + t.Fatalf("err: %v", err) + } + } + err := ts.revokeTree(namespace.TestContext(), &leaseEntry{}) if err.Error() != "cannot tree-revoke blank token" { t.Fatal(err) @@ -1049,6 +1078,16 @@ func testTokenStore_RevokeTree_NonRecursive(t testing.TB, depth uint64) { t.Fatalf("bad: %#v", out) } } + + for _, path := range cyclePaths { + entry, err := ts.parentView(namespace.TestNamespace()).Get(namespace.TestContext(), path) + if err != nil { + t.Fatalf("err: %v", err) + } + if entry != nil { + t.Fatalf("expected reference to be deleted: %v", entry) + } + } } // A benchmark function that tests testTokenStore_RevokeTree_NonRecursive @@ -1058,7 +1097,7 @@ func BenchmarkTokenStore_RevokeTree(b *testing.B) { for _, depth := range benchmarks { b.Run(fmt.Sprintf("Tree of Depth %d", depth), func(b *testing.B) { for i := 0; i < b.N; i++ { - testTokenStore_RevokeTree_NonRecursive(b, depth) + testTokenStore_RevokeTree_NonRecursive(b, depth, false) } }) }