Fix client-controlled-consistency for external plugins (#12117) (#12134)

* Allow requests to external plugins that modify storage to populate the X-Vault-Index response header.
This commit is contained in:
Vault Automation 2026-02-03 13:01:09 -05:00 committed by GitHub
parent 5d869440c3
commit 375a59c4cd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 848 additions and 643 deletions

3
changelog/_12117.txt Normal file
View file

@ -0,0 +1,3 @@
```release-note:bug
plugins (enterprise): Fix bug where requests to external plugins that modify storage weren't populating the X-Vault-Index response header.
```

View file

@ -5,6 +5,8 @@ package plugin
import (
"context"
"fmt"
"strconv"
"github.com/hashicorp/vault/sdk/logical"
"google.golang.org/grpc/metadata"
@ -12,23 +14,57 @@ import (
// pbMetadataCtxToLogicalCtx extracts the snapshot ID key from an incoming GRPC
// context and adds the logical context key to the returned context
func pbMetadataCtxToLogicalCtx(ctx context.Context) context.Context {
func pbMetadataCtxToLogicalCtx(ctx context.Context) (context.Context, error) {
var snapshotID string
snapshotIDs := metadata.ValueFromIncomingContext(ctx, snapshotIDCtxKey)
if len(snapshotIDs) > 0 {
snapshotID = snapshotIDs[0]
ctx = logical.CreateContextWithSnapshotID(ctx, snapshotID)
}
return logical.CreateContextWithSnapshotID(ctx, snapshotID)
clusterID := metadata.ValueFromIncomingContext(ctx, indexStateCtxKeyClusterID)
localRaw := metadata.ValueFromIncomingContext(ctx, indexStateCtxKeyLocal)
replicatedRaw := metadata.ValueFromIncomingContext(ctx, indexStateCtxKeyReplicated)
if len(clusterID) > 0 {
local, err := strconv.ParseUint(localRaw[0], 10, 64)
if err != nil {
return nil, fmt.Errorf("error parsing local index: %w", err)
}
replicated, err := strconv.ParseUint(replicatedRaw[0], 10, 64)
if err != nil {
return nil, fmt.Errorf("error parsing replicated index: %w", err)
}
w := &logical.WALState{
ClusterID: clusterID[0],
LocalIndex: local,
ReplicatedIndex: replicated,
}
ctx = logical.IndexStateContext(ctx, w)
} else {
ctx = logical.IndexStateContext(ctx, &logical.WALState{})
}
return ctx, nil
}
// logicalCtxToPBMetadataCtx extracts the logical context snapshot ID key from
// the context and appends it to an outgoing GRPC context
func logicalCtxToPBMetadataCtx(ctx context.Context) context.Context {
snapshotID, ok := logical.ContextSnapshotIDValue(ctx)
if !ok {
return ctx
var args []string
if snapshotID, ok := logical.ContextSnapshotIDValue(ctx); ok {
args = append(args, snapshotIDCtxKey, snapshotID)
}
return metadata.AppendToOutgoingContext(ctx, snapshotIDCtxKey, snapshotID)
if index := logical.IndexStateFromContext(ctx); index != nil {
args = append(args, indexStateCtxKeyClusterID, index.ClusterID,
indexStateCtxKeyLocal, fmt.Sprintf("%d", index.LocalIndex),
indexStateCtxKeyReplicated, fmt.Sprintf("%d", index.ReplicatedIndex))
}
return metadata.AppendToOutgoingContext(ctx, args...)
}
const snapshotIDCtxKey string = "snapshot_id"
const (
snapshotIDCtxKey string = "snapshot_id"
indexStateCtxKey = "index_state"
indexStateCtxKeyClusterID = indexStateCtxKey + "_cluster_id"
indexStateCtxKeyLocal = indexStateCtxKey + "_local"
indexStateCtxKeyReplicated = indexStateCtxKey + "_replicated"
)

View file

@ -114,6 +114,13 @@ func (b *backendGRPCPluginClient) HandleRequest(ctx context.Context, req *logica
return resp, pb.ProtoErrToErr(reply.Err)
}
if reply.WalIndex != nil {
req.SetResponseState(&logical.WALState{
LocalIndex: reply.WalIndex.LocalIndex,
ReplicatedIndex: reply.WalIndex.ReplicatedIndex,
})
}
return resp, nil
}

View file

@ -143,8 +143,16 @@ func (b *backendGRPCPluginServer) HandleRequest(ctx context.Context, args *pb.Ha
logicalReq.Storage = newGRPCStorageClient(brokeredClient)
reqCtx := pbMetadataCtxToLogicalCtx(ctx)
reqCtx, err := pbMetadataCtxToLogicalCtx(ctx)
if err != nil {
return &pb.HandleRequestReply{}, err
}
resp, respErr := backend.HandleRequest(reqCtx, logicalReq)
ws := logical.IndexStateFromContext(reqCtx)
if ws == nil {
ws = &logical.WALState{}
}
pbResp, err := pb.LogicalResponseToProtoResponse(resp)
if err != nil {
@ -154,6 +162,10 @@ func (b *backendGRPCPluginServer) HandleRequest(ctx context.Context, args *pb.Ha
return &pb.HandleRequestReply{
Response: pbResp,
Err: pb.ErrToProtoErr(respErr),
WalIndex: &pb.WALIndex{
LocalIndex: ws.LocalIndex,
ReplicatedIndex: ws.ReplicatedIndex,
},
}, nil
}

View file

@ -55,8 +55,8 @@ func (s *GRPCStorageClient) Get(ctx context.Context, key string) (*logical.Stora
}
func (s *GRPCStorageClient) Put(ctx context.Context, entry *logical.StorageEntry) error {
ctx = logicalCtxToPBMetadataCtx(ctx)
reply, err := s.client.Put(ctx, &pb.StoragePutArgs{
pbmctx := logicalCtxToPBMetadataCtx(ctx)
reply, err := s.client.Put(pbmctx, &pb.StoragePutArgs{
Entry: pb.LogicalStorageEntryToProtoStorageEntry(entry),
}, largeMsgGRPCCallOpts...)
if err != nil {
@ -65,6 +65,11 @@ func (s *GRPCStorageClient) Put(ctx context.Context, entry *logical.StorageEntry
if reply.Err != "" {
return errors.New(reply.Err)
}
ws := logical.IndexStateFromContext(ctx)
if ws != nil {
ws.ReplicatedIndex = reply.WalIndex.ReplicatedIndex
ws.LocalIndex = reply.WalIndex.LocalIndex
}
return nil
}
@ -79,6 +84,11 @@ func (s *GRPCStorageClient) Delete(ctx context.Context, key string) error {
if reply.Err != "" {
return errors.New(reply.Err)
}
ws := logical.IndexStateFromContext(ctx)
if ws != nil {
ws.ReplicatedIndex = reply.WalIndex.ReplicatedIndex
ws.LocalIndex = reply.WalIndex.LocalIndex
}
return nil
}
@ -92,7 +102,11 @@ func (s *GRPCStorageServer) List(ctx context.Context, args *pb.StorageListArgs)
if s.impl == nil {
return nil, errMissingStorage
}
ctx = pbMetadataCtxToLogicalCtx(ctx)
ctx, err := pbMetadataCtxToLogicalCtx(ctx)
if err != nil {
return nil, err
}
keys, err := s.impl.List(ctx, args.Prefix)
return &pb.StorageListReply{
Keys: keys,
@ -104,7 +118,10 @@ func (s *GRPCStorageServer) Get(ctx context.Context, args *pb.StorageGetArgs) (*
if s.impl == nil {
return nil, errMissingStorage
}
ctx = pbMetadataCtxToLogicalCtx(ctx)
ctx, err := pbMetadataCtxToLogicalCtx(ctx)
if err != nil {
return nil, err
}
storageEntry, err := s.impl.Get(ctx, args.Key)
if storageEntry == nil {
return &pb.StorageGetReply{
@ -122,10 +139,18 @@ func (s *GRPCStorageServer) Put(ctx context.Context, args *pb.StoragePutArgs) (*
if s.impl == nil {
return nil, errMissingStorage
}
ctx = pbMetadataCtxToLogicalCtx(ctx)
err := s.impl.Put(ctx, pb.ProtoStorageEntryToLogicalStorageEntry(args.Entry))
ctx, err := pbMetadataCtxToLogicalCtx(ctx)
if err != nil {
return nil, err
}
err = s.impl.Put(ctx, pb.ProtoStorageEntryToLogicalStorageEntry(args.Entry))
ws := logical.IndexStateFromContext(ctx)
return &pb.StoragePutReply{
Err: pb.ErrToString(err),
WalIndex: &pb.WALIndex{
LocalIndex: ws.LocalIndex,
ReplicatedIndex: ws.ReplicatedIndex,
},
}, nil
}
@ -133,10 +158,18 @@ func (s *GRPCStorageServer) Delete(ctx context.Context, args *pb.StorageDeleteAr
if s.impl == nil {
return nil, errMissingStorage
}
ctx = pbMetadataCtxToLogicalCtx(ctx)
err := s.impl.Delete(ctx, args.Key)
ctx, err := pbMetadataCtxToLogicalCtx(ctx)
if err != nil {
return nil, err
}
err = s.impl.Delete(ctx, args.Key)
ws := logical.IndexStateFromContext(ctx)
return &pb.StorageDeleteReply{
Err: pb.ErrToString(err),
WalIndex: &pb.WALIndex{
LocalIndex: ws.LocalIndex,
ReplicatedIndex: ws.ReplicatedIndex,
},
}, nil
}

File diff suppressed because it is too large Load diff

View file

@ -399,6 +399,7 @@ message HandleRequestArgs {
message HandleRequestReply {
Response response = 1;
ProtoError err = 2;
WALIndex walIndex = 3;
}
// InitializeArgs is the args for Initialize method.
@ -524,8 +525,14 @@ message StoragePutArgs {
StorageEntry entry = 1;
}
message WALIndex {
uint64 local_index = 1;
uint64 replicated_index = 2;
}
message StoragePutReply {
string err = 1;
WALIndex walIndex = 2;
}
message StorageDeleteArgs {
@ -534,6 +541,7 @@ message StorageDeleteArgs {
message StorageDeleteReply {
string err = 1;
WALIndex walIndex = 2;
}
// Storage is the way that plugins are able read/write data. Plugins should

View file

@ -12,8 +12,10 @@ import (
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/builtin/plugin"
"github.com/hashicorp/vault/helper/constants"
"github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/helper/testhelpers/corehelpers"
"github.com/hashicorp/vault/helper/testhelpers/teststorage"
vaulthttp "github.com/hashicorp/vault/http"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/helper/pluginutil"
@ -21,6 +23,7 @@ import (
lplugin "github.com/hashicorp/vault/sdk/plugin"
"github.com/hashicorp/vault/sdk/plugin/mock"
"github.com/hashicorp/vault/vault"
"github.com/stretchr/testify/require"
)
// logicalVersionMap is a map of version to test plugin
@ -74,6 +77,18 @@ func TestSystemBackend_Plugin_secret(t *testing.T) {
t.Fatalf("bad: response should not be nil")
}
req = logical.TestRequest(t, logical.CreateOperation, "mock-0/config")
req.ClientToken = core.Client.Token()
resp, err = core.HandleRequest(namespace.RootContext(testCtx), req)
if err != nil {
t.Fatalf("err: %v", err)
}
wspost := req.ResponseState()
if constants.IsEnterprise {
require.NotZero(t, wspost.ReplicatedIndex)
require.NotZero(t, wspost.LocalIndex)
}
// Seal the cluster
cluster.EnsureCoresSealed(t)
@ -630,7 +645,7 @@ func TestSystemBackend_PluginReload_WarningIfNoneReloaded(t *testing.T) {
func testSystemBackendMock(t *testing.T, numCores, numMounts int, backendType logical.BackendType, pluginVersion string) *vault.TestCluster {
t.Helper()
pluginDir := corehelpers.MakeTestPluginDir(t)
coreConfig := &vault.CoreConfig{
conf, opts := teststorage.ClusterSetup(&vault.CoreConfig{
LogicalBackends: map[string]logical.Factory{
"plugin": plugin.Factory,
},
@ -638,15 +653,13 @@ func testSystemBackendMock(t *testing.T, numCores, numMounts int, backendType lo
"plugin": plugin.Factory,
},
PluginDirectory: pluginDir,
}
cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
HandlerFunc: vaulthttp.Handler,
}, &vault.TestClusterOptions{
KeepStandbysSealed: true,
NumCores: numCores,
TempDir: pluginDir,
})
cluster.Start()
}, teststorage.InmemBackendSetup)
cluster := vault.NewTestCluster(t, conf, opts)
core := cluster.Cores[0]
vault.TestWaitActive(t, core.Core)