diff --git a/audit/entry_formatter.go b/audit/entry_formatter.go index c4a1987a7e..ebf53088ef 100644 --- a/audit/entry_formatter.go +++ b/audit/entry_formatter.go @@ -242,12 +242,12 @@ func (f *entryFormatter) formatRequest(ctx context.Context, in *logical.LogInput if !f.config.raw { var err error - auth, err = hashAuth(ctx, f.salter, auth, f.config.hmacAccessor) + err = hashAuth(ctx, f.salter, auth, f.config.hmacAccessor) if err != nil { return nil, err } - req, err = hashRequest(ctx, f.salter, req, f.config.hmacAccessor, in.NonHMACReqDataKeys) + err = hashRequest(ctx, f.salter, req, f.config.hmacAccessor, in.NonHMACReqDataKeys) if err != nil { return nil, err } @@ -394,17 +394,17 @@ func (f *entryFormatter) formatResponse(ctx context.Context, in *logical.LogInpu } } else { var err error - auth, err = hashAuth(ctx, f.salter, auth, f.config.hmacAccessor) + err = hashAuth(ctx, f.salter, auth, f.config.hmacAccessor) if err != nil { return nil, err } - req, err = hashRequest(ctx, f.salter, req, f.config.hmacAccessor, in.NonHMACReqDataKeys) + err = hashRequest(ctx, f.salter, req, f.config.hmacAccessor, in.NonHMACReqDataKeys) if err != nil { return nil, err } - resp, err = hashResponse(ctx, f.salter, resp, f.config.hmacAccessor, in.NonHMACRespDataKeys, elideListResponseData) + err = hashResponse(ctx, f.salter, resp, f.config.hmacAccessor, in.NonHMACRespDataKeys, elideListResponseData) if err != nil { return nil, err } diff --git a/audit/hashstructure.go b/audit/hashstructure.go index 95b970e1ae..6a11908010 100644 --- a/audit/hashstructure.go +++ b/audit/hashstructure.go @@ -13,11 +13,10 @@ import ( "github.com/hashicorp/go-secure-stdlib/strutil" "github.com/hashicorp/vault/sdk/helper/wrapping" "github.com/hashicorp/vault/sdk/logical" - "github.com/mitchellh/copystructure" "github.com/mitchellh/reflectwalk" ) -// hashString hashes the given opaque string and returns it +// hashString uses the Salter to hash the supplied opaque string and returns it. func hashString(ctx context.Context, salter Salter, data string) (string, error) { salt, err := salter.Salt(ctx) if err != nil { @@ -27,76 +26,68 @@ func hashString(ctx context.Context, salter Salter, data string) (string, error) return salt.GetIdentifiedHMAC(data), nil } -// hashAuth returns a hashed copy of the logical.Auth input. -func hashAuth(ctx context.Context, salter Salter, in *logical.Auth, HMACAccessor bool) (*logical.Auth, error) { - if in == nil { - return nil, nil +// hashAuth uses the Salter to hash the supplied Auth (modifying it). +// hmacAccessor is used to indicate whether the accessor should also be HMAC'd +// when present. +func hashAuth(ctx context.Context, salter Salter, auth *logical.Auth, hmacAccessor bool) error { + if auth == nil { + return nil } salt, err := salter.Salt(ctx) if err != nil { - return nil, err + return err } fn := salt.GetIdentifiedHMAC - auth := *in if auth.ClientToken != "" { auth.ClientToken = fn(auth.ClientToken) } - if HMACAccessor && auth.Accessor != "" { + if hmacAccessor && auth.Accessor != "" { auth.Accessor = fn(auth.Accessor) } - return &auth, nil + + return nil } -// hashRequest returns a hashed copy of the logical.Request input. -func hashRequest(ctx context.Context, salter Salter, in *logical.Request, HMACAccessor bool, nonHMACDataKeys []string) (*logical.Request, error) { - if in == nil { - return nil, nil +// hashRequest uses the Salter to hash the supplied Request (modifying it). +// nonHMACDataKeys is used when hashing any 'Data' field within the Request which +// prevents those specific keys from HMAC'd. +// hmacAccessor is used to indicate whether some accessors should also be HMAC'd +// when present. +func hashRequest(ctx context.Context, salter Salter, req *logical.Request, hmacAccessor bool, nonHMACDataKeys []string) error { + if req == nil { + return nil } salt, err := salter.Salt(ctx) if err != nil { - return nil, err + return err } fn := salt.GetIdentifiedHMAC - req := *in - if req.Auth != nil { - cp, err := copystructure.Copy(req.Auth) - if err != nil { - return nil, err - } - - req.Auth, err = hashAuth(ctx, salter, cp.(*logical.Auth), HMACAccessor) - if err != nil { - return nil, err - } + err = hashAuth(ctx, salter, req.Auth, hmacAccessor) + if err != nil { + return err } if req.ClientToken != "" { req.ClientToken = fn(req.ClientToken) } - if HMACAccessor && req.ClientTokenAccessor != "" { + if hmacAccessor && req.ClientTokenAccessor != "" { req.ClientTokenAccessor = fn(req.ClientTokenAccessor) } if req.Data != nil { - copy, err := copystructure.Copy(req.Data) + err = hashMap(fn, req.Data, nonHMACDataKeys) if err != nil { - return nil, err + return err } - - err = hashMap(fn, copy.(map[string]interface{}), nonHMACDataKeys) - if err != nil { - return nil, err - } - req.Data = copy.(map[string]interface{}) } - return &req, nil + return nil } func hashMap(hashFunc hashCallback, data map[string]interface{}, nonHMACDataKeys []string) error { @@ -112,102 +103,96 @@ func hashMap(hashFunc hashCallback, data map[string]interface{}, nonHMACDataKeys } } - return HashStructure(data, hashFunc, nonHMACDataKeys) + return hashStructure(data, hashFunc, nonHMACDataKeys) } -// hashResponse returns a hashed copy of the logical.Request input. -func hashResponse(ctx context.Context, salter Salter, in *logical.Response, HMACAccessor bool, nonHMACDataKeys []string, elideListResponseData bool) (*logical.Response, error) { - if in == nil { - return nil, nil +// hashResponse uses the Salter to hash the supplied Response (modifying it). +// nonHMACDataKeys is used when hashing any 'Data' field within the Request which +// prevents those specific keys from HMAC'd. +// hmacAccessor is used to indicate whether some accessors should also be HMAC'd +// when present. +// elideListResponseData indicates whether any 'keys' or 'key_info' data present in +// the Response should be elided (when the request was a LIST operation). +// See: /vault/docs/audit#eliding-list-response-bodies +func hashResponse(ctx context.Context, salter Salter, resp *logical.Response, hmacAccessor bool, nonHMACDataKeys []string, elideListResponseData bool) error { + if resp == nil { + return nil } salt, err := salter.Salt(ctx) if err != nil { - return nil, err + return err } fn := salt.GetIdentifiedHMAC - resp := *in - if resp.Auth != nil { - cp, err := copystructure.Copy(resp.Auth) - if err != nil { - return nil, err - } - - resp.Auth, err = hashAuth(ctx, salter, cp.(*logical.Auth), HMACAccessor) - if err != nil { - return nil, err - } + err = hashAuth(ctx, salter, resp.Auth, hmacAccessor) + if err != nil { + return err } if resp.Data != nil { - copy, err := copystructure.Copy(resp.Data) - if err != nil { - return nil, err - } - - mapCopy := copy.(map[string]interface{}) - if b, ok := mapCopy[logical.HTTPRawBody].([]byte); ok { - mapCopy[logical.HTTPRawBody] = string(b) + if b, ok := resp.Data[logical.HTTPRawBody].([]byte); ok { + resp.Data[logical.HTTPRawBody] = string(b) } // Processing list response data elision takes place at this point in the code for performance reasons: // - take advantage of the deep copy of resp.Data that was going to be done anyway for hashing // - but elide data before potentially spending time hashing it if elideListResponseData { - doElideListResponseData(mapCopy) + doElideListResponseData(resp.Data) } - err = hashMap(fn, mapCopy, nonHMACDataKeys) + err = hashMap(fn, resp.Data, nonHMACDataKeys) if err != nil { - return nil, err + return err } - resp.Data = mapCopy } if resp.WrapInfo != nil { var err error - resp.WrapInfo, err = hashWrapInfo(fn, resp.WrapInfo, HMACAccessor) + err = hashWrapInfo(fn, resp.WrapInfo, hmacAccessor) if err != nil { - return nil, err + return err } } - return &resp, nil + return nil } -// hashWrapInfo returns a hashed copy of the wrapping.ResponseWrapInfo input. -func hashWrapInfo(hashFunc hashCallback, in *wrapping.ResponseWrapInfo, HMACAccessor bool) (*wrapping.ResponseWrapInfo, error) { - if in == nil { - return nil, nil +// hashWrapInfo returns a hashed copy of the ResponseWrapInfo input. + +// hashWrapInfo uses the supplied hashing function to hash ResponseWrapInfo (modifying it). +// hmacAccessor is used to indicate whether some accessors should also be HMAC'd +// when present. +func hashWrapInfo(hashFunc hashCallback, wrapInfo *wrapping.ResponseWrapInfo, hmacAccessor bool) error { + if wrapInfo == nil { + return nil } - wrapinfo := *in + wrapInfo.Token = hashFunc(wrapInfo.Token) - wrapinfo.Token = hashFunc(wrapinfo.Token) + if hmacAccessor { + wrapInfo.Accessor = hashFunc(wrapInfo.Accessor) - if HMACAccessor { - wrapinfo.Accessor = hashFunc(wrapinfo.Accessor) - - if wrapinfo.WrappedAccessor != "" { - wrapinfo.WrappedAccessor = hashFunc(wrapinfo.WrappedAccessor) + if wrapInfo.WrappedAccessor != "" { + wrapInfo.WrappedAccessor = hashFunc(wrapInfo.WrappedAccessor) } } - return &wrapinfo, nil + return nil } -// HashStructure takes an interface and hashes all the values within +// hashStructure takes an interface and hashes all the values within // the structure. Only _values_ are hashed: keys of objects are not. // // For the hashCallback, see the built-in HashCallbacks below. -func HashStructure(s interface{}, cb hashCallback, ignoredKeys []string) error { +func hashStructure(s interface{}, cb hashCallback, ignoredKeys []string) error { walker := &hashWalker{Callback: cb, IgnoredKeys: ignoredKeys} return reflectwalk.Walk(s, walker) } -// hashCallback is the callback called for HashStructure to hash +// hashCallback is the callback called for hashStructure to hash // a value. type hashCallback func(string) string @@ -219,20 +204,26 @@ type hashWalker struct { // to be hashed. If there is an error, walking will be halted // immediately and the error returned. Callback hashCallback - // IgnoreKeys are the keys that wont have the hashCallback applied + + // IgnoreKeys are the keys that won't have the hashCallback applied IgnoredKeys []string + // MapElem appends the key itself (not the reflect.Value) to key. // The last element in key is the most recently entered map key. // Since Exit pops the last element of key, only nesting to another // structure increases the size of this slice. - key []string + key []string + lastValue reflect.Value + // Enter appends to loc and exit pops loc. The last element of loc is thus // the current location. loc []reflectwalk.Location + // Map and Slice append to cs, Exit pops the last element off cs. // The last element in cs is the most recently entered map or slice. cs []reflect.Value + // MapElem and SliceElem append to csKey. The last element in csKey is the // most recently entered map key or slice index. Since Exit pops the last // element of csKey, only nesting to another structure increases the size of diff --git a/audit/hashstructure_test.go b/audit/hashstructure_test.go index 5468cbfa7b..fe2eda1555 100644 --- a/audit/hashstructure_test.go +++ b/audit/hashstructure_test.go @@ -18,6 +18,7 @@ import ( "github.com/hashicorp/vault/sdk/helper/wrapping" "github.com/hashicorp/vault/sdk/logical" "github.com/mitchellh/copystructure" + "github.com/stretchr/testify/require" ) func TestCopy_auth(t *testing.T) { @@ -105,10 +106,13 @@ type TestSalter struct{} // storage instance. func (*TestSalter) Salt(ctx context.Context) (*salt.Salt, error) { inmemStorage := &logical.InmemStorage{} - inmemStorage.Put(context.Background(), &logical.StorageEntry{ + err := inmemStorage.Put(context.Background(), &logical.StorageEntry{ Key: "salt", Value: []byte("foo"), }) + if err != nil { + return nil, err + } return salt.NewSalt(context.Background(), inmemStorage, &salt.Config{ HMAC: sha256.New, @@ -159,19 +163,20 @@ func TestHashAuth(t *testing.T) { } inmemStorage := &logical.InmemStorage{} - inmemStorage.Put(context.Background(), &logical.StorageEntry{ + err := inmemStorage.Put(context.Background(), &logical.StorageEntry{ Key: "salt", Value: []byte("foo"), }) + require.NoError(t, err) salter := &TestSalter{} for _, tc := range cases { input := fmt.Sprintf("%#v", tc.Input) - out, err := hashAuth(context.Background(), salter, tc.Input, tc.HMACAccessor) + err := hashAuth(context.Background(), salter, tc.Input, tc.HMACAccessor) if err != nil { t.Fatalf("err: %s\n\n%s", err, input) } - if !reflect.DeepEqual(out, tc.Output) { - t.Fatalf("bad:\nInput:\n%s\nOutput:\n%#v\nExpected output:\n%#v", input, out, tc.Output) + if !reflect.DeepEqual(tc.Input, tc.Output) { + t.Fatalf("bad:\nInput:\n%s\nOutput:\n%#v\nExpected output:\n%#v", input, tc.Input, tc.Output) } } } @@ -217,18 +222,19 @@ func TestHashRequest(t *testing.T) { } inmemStorage := &logical.InmemStorage{} - inmemStorage.Put(context.Background(), &logical.StorageEntry{ + err := inmemStorage.Put(context.Background(), &logical.StorageEntry{ Key: "salt", Value: []byte("foo"), }) + require.NoError(t, err) salter := &TestSalter{} for _, tc := range cases { input := fmt.Sprintf("%#v", tc.Input) - out, err := hashRequest(context.Background(), salter, tc.Input, tc.HMACAccessor, tc.NonHMACDataKeys) + err := hashRequest(context.Background(), salter, tc.Input, tc.HMACAccessor, tc.NonHMACDataKeys) if err != nil { t.Fatalf("err: %s\n\n%s", err, input) } - if diff := deep.Equal(out, tc.Output); len(diff) > 0 { + if diff := deep.Equal(tc.Input, tc.Output); len(diff) > 0 { t.Fatalf("bad:\nInput:\n%s\nDiff:\n%#v", input, diff) } } @@ -282,20 +288,19 @@ func TestHashResponse(t *testing.T) { } inmemStorage := &logical.InmemStorage{} - inmemStorage.Put(context.Background(), &logical.StorageEntry{ + err := inmemStorage.Put(context.Background(), &logical.StorageEntry{ Key: "salt", Value: []byte("foo"), }) + require.NoError(t, err) salter := &TestSalter{} for _, tc := range cases { input := fmt.Sprintf("%#v", tc.Input) - out, err := hashResponse(context.Background(), salter, tc.Input, tc.HMACAccessor, tc.NonHMACDataKeys, false) + err := hashResponse(context.Background(), salter, tc.Input, tc.HMACAccessor, tc.NonHMACDataKeys, false) if err != nil { t.Fatalf("err: %s\n\n%s", err, input) } - if diff := deep.Equal(out, tc.Output); len(diff) > 0 { - t.Fatalf("bad:\nInput:\n%s\nDiff:\n%#v", input, diff) - } + require.Equal(t, tc.Output, tc.Input) } } @@ -326,7 +331,7 @@ func TestHashWalker(t *testing.T) { } for _, tc := range cases { - err := HashStructure(tc.Input, func(string) string { + err := hashStructure(tc.Input, func(string) string { return replaceText }, nil) if err != nil { @@ -380,7 +385,7 @@ func TestHashWalker_TimeStructs(t *testing.T) { } for _, tc := range cases { - err := HashStructure(tc.Input, func(s string) string { + err := hashStructure(tc.Input, func(s string) string { return s + replaceText }, nil) if err != nil {