From 92de42170c20d958bffe6fe2252e7fe4d6147d40 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 11 Apr 2018 11:32:55 -0700 Subject: [PATCH] Port some ent mount changes (#4330) --- vault/auth.go | 60 ++++++++++++++++++--------------- vault/auth_test.go | 2 +- vault/logical_system.go | 24 ++++++------- vault/logical_system_helpers.go | 4 +-- vault/mount.go | 59 ++++++++++++++++++-------------- vault/mount_test.go | 18 ++++++---- 6 files changed, 92 insertions(+), 75 deletions(-) diff --git a/vault/auth.go b/vault/auth.go index e94b85d1e1..d75475a050 100644 --- a/vault/auth.go +++ b/vault/auth.go @@ -134,7 +134,7 @@ func (c *Core) enableCredential(ctx context.Context, entry *MountEntry) error { // Update the auth table newTable := c.auth.shallowClone() newTable.Entries = append(newTable.Entries, entry) - if err := c.persistAuth(ctx, newTable, entry.Local); err != nil { + if err := c.persistAuth(ctx, newTable, &entry.Local); err != nil { return errors.New("failed to update auth table") } @@ -235,7 +235,7 @@ func (c *Core) removeCredEntry(ctx context.Context, path string) error { } // Update the auth table - if err := c.persistAuth(ctx, newTable, entry.Local); err != nil { + if err := c.persistAuth(ctx, newTable, &entry.Local); err != nil { return errors.New("failed to update auth table") } @@ -281,7 +281,7 @@ func (c *Core) taintCredEntry(ctx context.Context, path string) error { } // Update the auth table - if err := c.persistAuth(ctx, c.auth, entry.Local); err != nil { + if err := c.persistAuth(ctx, c.auth, &entry.Local); err != nil { return errors.New("failed to update auth table") } @@ -369,7 +369,7 @@ func (c *Core) loadCredentials(ctx context.Context) error { return nil } - if err := c.persistAuth(ctx, c.auth, false); err != nil { + if err := c.persistAuth(ctx, c.auth, nil); err != nil { c.logger.Error("failed to persist auth table", "error", err) return errLoadAuthFailed } @@ -377,7 +377,7 @@ func (c *Core) loadCredentials(ctx context.Context) error { } // persistAuth is used to persist the auth table after modification -func (c *Core) persistAuth(ctx context.Context, table *MountTable, localOnly bool) error { +func (c *Core) persistAuth(ctx context.Context, table *MountTable, local *bool) error { if table.Type != credentialTableType { c.logger.Error("given table to persist has wrong type", "actual_type", table.Type, "expected_type", credentialTableType) return fmt.Errorf("invalid table type given, not persisting") @@ -406,45 +406,49 @@ func (c *Core) persistAuth(ctx context.Context, table *MountTable, localOnly boo } } - if !localOnly { - // Marshal the table - compressedBytes, err := jsonutil.EncodeJSONAndCompress(nonLocalAuth, nil) + writeTable := func(mt *MountTable, path string) error { + // Encode the mount table into JSON and compress it (lzw). + compressedBytes, err := jsonutil.EncodeJSONAndCompress(mt, nil) if err != nil { - c.logger.Error("failed to encode and/or compress auth table", "error", err) + c.logger.Error("failed to encode or compress auth mount table", "error", err) return err } // Create an entry entry := &Entry{ - Key: coreAuthConfigPath, + Key: path, Value: compressedBytes, } // Write to the physical backend if err := c.barrier.Put(ctx, entry); err != nil { - c.logger.Error("failed to persist auth table", "error", err) + c.logger.Error("failed to persist auth mount table", "error", err) return err } + return nil } - // Repeat with local auth - compressedBytes, err := jsonutil.EncodeJSONAndCompress(localAuth, nil) - if err != nil { - c.logger.Error("failed to encode and/or compress local auth table", "error", err) - return err + var err error + switch { + case local == nil: + // Write non-local mounts + err := writeTable(nonLocalAuth, coreAuthConfigPath) + if err != nil { + return err + } + + // Write local mounts + err = writeTable(localAuth, coreLocalAuthConfigPath) + if err != nil { + return err + } + case *local: + err = writeTable(localAuth, coreLocalAuthConfigPath) + default: + err = writeTable(nonLocalAuth, coreAuthConfigPath) } - entry := &Entry{ - Key: coreLocalAuthConfigPath, - Value: compressedBytes, - } - - if err := c.barrier.Put(ctx, entry); err != nil { - c.logger.Error("failed to persist local auth table", "error", err) - return err - } - - return nil + return err } // setupCredentials is invoked after we've loaded the auth table to @@ -520,7 +524,7 @@ func (c *Core) setupCredentials(ctx context.Context) error { } if persistNeeded { - return c.persistAuth(ctx, c.auth, false) + return c.persistAuth(ctx, c.auth, nil) } return nil diff --git a/vault/auth_test.go b/vault/auth_test.go index 6f66b4c584..8b32275997 100644 --- a/vault/auth_test.go +++ b/vault/auth_test.go @@ -164,7 +164,7 @@ func TestCore_EnableCredential_Local(t *testing.T) { } c.auth.Entries[1].Local = true - if err := c.persistAuth(context.Background(), c.auth, false); err != nil { + if err := c.persistAuth(context.Background(), c.auth, nil); err != nil { t.Fatal(err) } diff --git a/vault/logical_system.go b/vault/logical_system.go index 298c69c070..0970075203 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -1988,9 +1988,9 @@ func (b *SystemBackend) handleTuneWriteCommon(ctx context.Context, path string, var err error switch { case strings.HasPrefix(path, "auth/"): - err = b.Core.persistAuth(ctx, b.Core.auth, mountEntry.Local) + err = b.Core.persistAuth(ctx, b.Core.auth, &mountEntry.Local) default: - err = b.Core.persistMounts(ctx, b.Core.mounts, mountEntry.Local) + err = b.Core.persistMounts(ctx, b.Core.mounts, &mountEntry.Local) } if err != nil { mountEntry.Description = oldDesc @@ -2011,9 +2011,9 @@ func (b *SystemBackend) handleTuneWriteCommon(ctx context.Context, path string, var err error switch { case strings.HasPrefix(path, "auth/"): - err = b.Core.persistAuth(ctx, b.Core.auth, mountEntry.Local) + err = b.Core.persistAuth(ctx, b.Core.auth, &mountEntry.Local) default: - err = b.Core.persistMounts(ctx, b.Core.mounts, mountEntry.Local) + err = b.Core.persistMounts(ctx, b.Core.mounts, &mountEntry.Local) } if err != nil { mountEntry.Config.AuditNonHMACRequestKeys = oldVal @@ -2037,9 +2037,9 @@ func (b *SystemBackend) handleTuneWriteCommon(ctx context.Context, path string, var err error switch { case strings.HasPrefix(path, "auth/"): - err = b.Core.persistAuth(ctx, b.Core.auth, mountEntry.Local) + err = b.Core.persistAuth(ctx, b.Core.auth, &mountEntry.Local) default: - err = b.Core.persistMounts(ctx, b.Core.mounts, mountEntry.Local) + err = b.Core.persistMounts(ctx, b.Core.mounts, &mountEntry.Local) } if err != nil { mountEntry.Config.AuditNonHMACResponseKeys = oldVal @@ -2068,9 +2068,9 @@ func (b *SystemBackend) handleTuneWriteCommon(ctx context.Context, path string, var err error switch { case strings.HasPrefix(path, "auth/"): - err = b.Core.persistAuth(ctx, b.Core.auth, mountEntry.Local) + err = b.Core.persistAuth(ctx, b.Core.auth, &mountEntry.Local) default: - err = b.Core.persistMounts(ctx, b.Core.mounts, mountEntry.Local) + err = b.Core.persistMounts(ctx, b.Core.mounts, &mountEntry.Local) } if err != nil { mountEntry.Config.ListingVisibility = oldVal @@ -2092,9 +2092,9 @@ func (b *SystemBackend) handleTuneWriteCommon(ctx context.Context, path string, var err error switch { case strings.HasPrefix(path, "auth/"): - err = b.Core.persistAuth(ctx, b.Core.auth, mountEntry.Local) + err = b.Core.persistAuth(ctx, b.Core.auth, &mountEntry.Local) default: - err = b.Core.persistMounts(ctx, b.Core.mounts, mountEntry.Local) + err = b.Core.persistMounts(ctx, b.Core.mounts, &mountEntry.Local) } if err != nil { mountEntry.Config.PassthroughRequestHeaders = oldVal @@ -2154,9 +2154,9 @@ func (b *SystemBackend) handleTuneWriteCommon(ctx context.Context, path string, // Update the mount table switch { case strings.HasPrefix(path, "auth/"): - err = b.Core.persistAuth(ctx, b.Core.auth, mountEntry.Local) + err = b.Core.persistAuth(ctx, b.Core.auth, &mountEntry.Local) default: - err = b.Core.persistMounts(ctx, b.Core.mounts, mountEntry.Local) + err = b.Core.persistMounts(ctx, b.Core.mounts, &mountEntry.Local) } if err != nil { mountEntry.Options = oldVal diff --git a/vault/logical_system_helpers.go b/vault/logical_system_helpers.go index 48cbb173c7..d9fdb046b7 100644 --- a/vault/logical_system_helpers.go +++ b/vault/logical_system_helpers.go @@ -37,9 +37,9 @@ func (b *SystemBackend) tuneMountTTLs(ctx context.Context, path string, me *Moun var err error switch { case strings.HasPrefix(path, credentialRoutePrefix): - err = b.Core.persistAuth(ctx, b.Core.auth, me.Local) + err = b.Core.persistAuth(ctx, b.Core.auth, &me.Local) default: - err = b.Core.persistMounts(ctx, b.Core.mounts, me.Local) + err = b.Core.persistMounts(ctx, b.Core.mounts, &me.Local) } if err != nil { me.Config.MaxLeaseTTL = origMax diff --git a/vault/mount.go b/vault/mount.go index 5ef79bf701..7aaf5d6a99 100644 --- a/vault/mount.go +++ b/vault/mount.go @@ -336,7 +336,7 @@ func (c *Core) mountInternal(ctx context.Context, entry *MountEntry) error { newTable := c.mounts.shallowClone() newTable.Entries = append(newTable.Entries, entry) - if err := c.persistMounts(ctx, newTable, entry.Local); err != nil { + if err := c.persistMounts(ctx, newTable, &entry.Local); err != nil { c.logger.Error("failed to update mount table", "error", err) return logical.CodedError(500, "failed to update mount table") } @@ -457,7 +457,7 @@ func (c *Core) removeMountEntry(ctx context.Context, path string) error { } // Update the mount table - if err := c.persistMounts(ctx, newTable, entry.Local); err != nil { + if err := c.persistMounts(ctx, newTable, &entry.Local); err != nil { c.logger.Error("failed to remove entry from mounts table", "error", err) return logical.CodedError(500, "failed to remove entry from mounts table") } @@ -480,7 +480,7 @@ func (c *Core) taintMountEntry(ctx context.Context, path string) error { } // Update the mount table - if err := c.persistMounts(ctx, c.mounts, entry.Local); err != nil { + if err := c.persistMounts(ctx, c.mounts, &entry.Local); err != nil { c.logger.Error("failed to taint entry in mounts table", "error", err) return logical.CodedError(500, "failed to taint entry in mounts table") } @@ -571,7 +571,7 @@ func (c *Core) remount(ctx context.Context, src, dst string) error { } // Update the mount table - if err := c.persistMounts(ctx, c.mounts, entry.Local); err != nil { + if err := c.persistMounts(ctx, c.mounts, &entry.Local); err != nil { entry.Path = src entry.Tainted = true c.mountsLock.Unlock() @@ -710,7 +710,7 @@ func (c *Core) loadMounts(ctx context.Context) error { return nil } - if err := c.persistMounts(ctx, c.mounts, false); err != nil { + if err := c.persistMounts(ctx, c.mounts, nil); err != nil { c.logger.Error("failed to persist mount table", "error", err) return errLoadMountsFailed } @@ -718,7 +718,7 @@ func (c *Core) loadMounts(ctx context.Context) error { } // persistMounts is used to persist the mount table after modification -func (c *Core) persistMounts(ctx context.Context, table *MountTable, localOnly bool) error { +func (c *Core) persistMounts(ctx context.Context, table *MountTable, local *bool) error { if table.Type != mountTableType { c.logger.Error("given table to persist has wrong type", "actual_type", table.Type, "expected_type", mountTableType) return fmt.Errorf("invalid table type given, not persisting") @@ -747,17 +747,17 @@ func (c *Core) persistMounts(ctx context.Context, table *MountTable, localOnly b } } - if !localOnly { + writeTable := func(mt *MountTable, path string) error { // Encode the mount table into JSON and compress it (lzw). - compressedBytes, err := jsonutil.EncodeJSONAndCompress(nonLocalMounts, nil) + compressedBytes, err := jsonutil.EncodeJSONAndCompress(mt, nil) if err != nil { - c.logger.Error("failed to encode and/or compress the mount table", "error", err) + c.logger.Error("failed to encode or compress mount table", "error", err) return err } // Create an entry entry := &Entry{ - Key: coreMountConfigPath, + Key: path, Value: compressedBytes, } @@ -766,26 +766,33 @@ func (c *Core) persistMounts(ctx context.Context, table *MountTable, localOnly b c.logger.Error("failed to persist mount table", "error", err) return err } + + return nil } - // Repeat with local mounts - compressedBytes, err := jsonutil.EncodeJSONAndCompress(localMounts, nil) - if err != nil { - c.logger.Error("failed to encode and/or compress the local mount table", "error", err) - return err + var err error + switch { + case local == nil: + // Write non-local mounts + err := writeTable(nonLocalMounts, coreMountConfigPath) + if err != nil { + return err + } + + // Write local mounts + err = writeTable(localMounts, coreLocalMountConfigPath) + if err != nil { + return err + } + case *local: + // Write local mounts + err = writeTable(localMounts, coreLocalMountConfigPath) + default: + // Write non-local mounts + err = writeTable(nonLocalMounts, coreMountConfigPath) } - entry := &Entry{ - Key: coreLocalMountConfigPath, - Value: compressedBytes, - } - - if err := c.barrier.Put(ctx, entry); err != nil { - c.logger.Error("failed to persist local mount table", "error", err) - return err - } - - return nil + return err } // setupMounts is invoked after we've loaded the mount table to diff --git a/vault/mount_test.go b/vault/mount_test.go index 87a0f9ed29..e773003571 100644 --- a/vault/mount_test.go +++ b/vault/mount_test.go @@ -161,7 +161,7 @@ func TestCore_Mount_Local(t *testing.T) { } c.mounts.Entries[1].Local = true - if err := c.persistMounts(context.Background(), c.mounts, false); err != nil { + if err := c.persistMounts(context.Background(), c.mounts, nil); err != nil { t.Fatal(err) } @@ -557,7 +557,7 @@ func testCore_MountTable_UpgradeToTyped_Common( t.Fatal(err) } - var persistFunc func(context.Context, *MountTable, bool) error + var persistFunc func(context.Context, *MountTable, *bool) error // It should load successfully and be upgraded and persisted switch testType { @@ -571,7 +571,13 @@ func testCore_MountTable_UpgradeToTyped_Common( mt = c.auth case "audits": err = c.loadAudits(context.Background()) - persistFunc = c.persistAudit + persistFunc = func(ctx context.Context, mt *MountTable, b *bool) error { + if b == nil { + b = new(bool) + *b = false + } + return c.persistAudit(ctx, mt, *b) + } mt = c.audit } if err != nil { @@ -600,19 +606,19 @@ func testCore_MountTable_UpgradeToTyped_Common( // Now try saving invalid versions origTableType := mt.Type mt.Type = "foo" - if err := persistFunc(context.Background(), mt, false); err == nil { + if err := persistFunc(context.Background(), mt, nil); err == nil { t.Fatal("expected error") } if len(mt.Entries) > 0 { mt.Type = origTableType mt.Entries[0].Table = "bar" - if err := persistFunc(context.Background(), mt, false); err == nil { + if err := persistFunc(context.Background(), mt, nil); err == nil { t.Fatal("expected error") } mt.Entries[0].Table = mt.Type - if err := persistFunc(context.Background(), mt, false); err != nil { + if err := persistFunc(context.Background(), mt, nil); err != nil { t.Fatal(err) } }