diff --git a/command/base_predict.go b/command/base_predict.go index 5fca56c090..2347193523 100644 --- a/command/base_predict.go +++ b/command/base_predict.go @@ -15,28 +15,45 @@ import ( // that returning nothing. var defaultPredictVaultMounts = []string{"cubbyhole/"} -// PredictVaultPaths returns a predictor for Vault mounts and paths based on the +// PredictVaultFiles returns a predictor for Vault mounts and paths based on the // configured client for the base command. Unfortunately this happens pre-flag // parsing, so users must rely on environment variables for autocomplete if they // are not using Vault at the default endpoints. -func (b *BaseCommand) PredictVaultPaths() complete.Predictor { +func (b *BaseCommand) PredictVaultFiles() complete.Predictor { client, err := b.Client() if err != nil { return nil } - return PredictVaultPaths(client) + return PredictVaultFiles(client) } -// PredictVaultPaths returns a predictor for Vault paths. This is a public API -// for consumers, but you probably want BaseCommand.PredictVaultPaths instead. -func PredictVaultPaths(client *api.Client) complete.Predictor { - return predictVaultPaths(client) +// PredictVaultFolders returns a predictor for "folders". See PredictVaultFiles +// for more information and restrictions. +func (b *BaseCommand) PredictVaultFolders() complete.Predictor { + client, err := b.Client() + if err != nil { + return nil + } + return PredictVaultFolders(client) +} + +// PredictVaultFiles returns a predictor for Vault "files". This is a public API +// for consumers, but you probably want BaseCommand.PredictVaultFiles instead. +func PredictVaultFiles(client *api.Client) complete.Predictor { + return predictVaultPaths(client, true) +} + +// PredictVaultFolders returns a predictor for Vault "folders". This is a public +// API for consumers, but you probably want BaseCommand.PredictVaultFolders +// instead. +func PredictVaultFolders(client *api.Client) complete.Predictor { + return predictVaultPaths(client, false) } // predictVaultPaths parses the CLI options and returns the "best" list of // possible paths. If there are any errors, this function returns an empty // result. All errors are suppressed since this is a prediction function. -func predictVaultPaths(client *api.Client) complete.PredictFunc { +func predictVaultPaths(client *api.Client, includeFiles bool) complete.PredictFunc { return func(args complete.Args) []string { // Do not predict more than one paths if predictHasPathArg(args.All) { @@ -47,7 +64,7 @@ func predictVaultPaths(client *api.Client) complete.PredictFunc { var predictions []string if strings.Contains(path, "/") { - predictions = predictPaths(client, path) + predictions = predictPaths(client, path, includeFiles) } else { predictions = predictMounts(client, path) } @@ -70,7 +87,7 @@ func predictVaultPaths(client *api.Client) complete.PredictFunc { // Re-predict with the remaining path args.Last = predictions[0] - return predictVaultPaths(client).Predict(args) + return predictVaultPaths(client, includeFiles).Predict(args) } } @@ -90,7 +107,7 @@ func predictMounts(client *api.Client, path string) []string { } // predictPaths predicts all paths which start with the given path. -func predictPaths(client *api.Client, path string) []string { +func predictPaths(client *api.Client, path string, includeFiles bool) []string { // Vault does not support listing based on a sub-key, so we have to back-pedal // to the last "/" and return all paths on that "folder". Then we perform // client-side filtering. @@ -108,7 +125,10 @@ func predictPaths(client *api.Client, path string) []string { p = root + p if strings.HasPrefix(p, path) { - predictions = append(predictions, p) + // Ensure this is a directory or we've asked to include files. + if includeFiles || strings.HasSuffix(p, "/") { + predictions = append(predictions, p) + } } } diff --git a/command/base_predict_test.go b/command/base_predict_test.go index e579b9ec18..05af55da7c 100644 --- a/command/base_predict_test.go +++ b/command/base_predict_test.go @@ -31,12 +31,11 @@ func TestPredictVaultPaths(t *testing.T) { t.Fatal(err) } - f := predictVaultPaths(client) - cases := []struct { - name string - args complete.Args - exp []string + name string + args complete.Args + includeFiles bool + exp []string }{ { "has_args", @@ -44,6 +43,16 @@ func TestPredictVaultPaths(t *testing.T) { All: []string{"read", "secret/foo", "a=b"}, Last: "a=b", }, + true, + nil, + }, + { + "has_args_no_files", + complete.Args{ + All: []string{"read", "secret/foo", "a=b"}, + Last: "a=b", + }, + false, nil, }, { @@ -52,6 +61,16 @@ func TestPredictVaultPaths(t *testing.T) { All: []string{"read", "s"}, Last: "s", }, + true, + []string{"secret/", "sys/"}, + }, + { + "part_mount_no_files", + complete.Args{ + All: []string{"read", "s"}, + Last: "s", + }, + false, []string{"secret/", "sys/"}, }, { @@ -60,48 +79,108 @@ func TestPredictVaultPaths(t *testing.T) { All: []string{"read", "sec"}, Last: "sec", }, + true, []string{"secret/bar", "secret/foo", "secret/zip/"}, }, + { + "only_mount_no_files", + complete.Args{ + All: []string{"read", "sec"}, + Last: "sec", + }, + false, + []string{"secret/zip/"}, + }, { "full_mount", complete.Args{ All: []string{"read", "secret"}, Last: "secret", }, + true, []string{"secret/bar", "secret/foo", "secret/zip/"}, }, + { + "full_mount_no_files", + complete.Args{ + All: []string{"read", "secret"}, + Last: "secret", + }, + false, + []string{"secret/zip/"}, + }, { "full_mount_slash", complete.Args{ All: []string{"read", "secret/"}, Last: "secret/", }, + true, []string{"secret/bar", "secret/foo", "secret/zip/"}, }, + { + "full_mount_slash_no_files", + complete.Args{ + All: []string{"read", "secret/"}, + Last: "secret/", + }, + false, + []string{"secret/zip/"}, + }, { "path_partial", complete.Args{ All: []string{"read", "secret/z"}, Last: "secret/z", }, + true, []string{"secret/zip/twoot", "secret/zip/zap", "secret/zip/zonk"}, }, + { + "path_partial_no_files", + complete.Args{ + All: []string{"read", "secret/z"}, + Last: "secret/z", + }, + false, + []string{"secret/zip/"}, + }, { "subpath_partial_z", complete.Args{ All: []string{"read", "secret/zip/z"}, Last: "secret/zip/z", }, + true, []string{"secret/zip/zap", "secret/zip/zonk"}, }, + { + "subpath_partial_z_no_files", + complete.Args{ + All: []string{"read", "secret/zip/z"}, + Last: "secret/zip/z", + }, + false, + []string{"secret/zip/z"}, + }, { "subpath_partial_t", complete.Args{ All: []string{"read", "secret/zip/t"}, Last: "secret/zip/t", }, + true, []string{"secret/zip/twoot"}, }, + { + "subpath_partial_t_no_files", + complete.Args{ + All: []string{"read", "secret/zip/t"}, + Last: "secret/zip/t", + }, + false, + []string{"secret/zip/t"}, + }, } t.Run("group", func(t *testing.T) { @@ -110,6 +189,7 @@ func TestPredictVaultPaths(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() + f := predictVaultPaths(client, tc.includeFiles) act := f(tc.args) if !reflect.DeepEqual(act, tc.exp) { t.Errorf("expected %q to be %q", act, tc.exp) @@ -126,26 +206,22 @@ func TestPredictMounts(t *testing.T) { defer closer() cases := []struct { - name string - client *api.Client - path string - exp []string + name string + path string + exp []string }{ { "no_match", - client, "not-a-real-mount-seriously", nil, }, { "s", - client, "s", []string{"secret/", "sys/"}, }, { "se", - client, "se", []string{"secret/"}, }, @@ -157,7 +233,7 @@ func TestPredictMounts(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - act := predictMounts(tc.client, tc.path) + act := predictMounts(client, tc.path) if !reflect.DeepEqual(act, tc.exp) { t.Errorf("expected %q to be %q", act, tc.exp) } @@ -184,27 +260,39 @@ func TestPredictPaths(t *testing.T) { } cases := []struct { - name string - client *api.Client - path string - exp []string + name string + path string + includeFiles bool + exp []string }{ { "bad_path", - client, "nope/not/a/real/path/ever", + true, []string{"nope/not/a/real/path/ever"}, }, { "good_path", - client, "secret/", + true, []string{"secret/bar", "secret/foo", "secret/zip/"}, }, + { + "good_path_no_files", + "secret/", + false, + []string{"secret/zip/"}, + }, { "partial_match", - client, "secret/z", + true, + []string{"secret/zip/"}, + }, + { + "partial_match_no_files", + "secret/z", + false, []string{"secret/zip/"}, }, } @@ -215,7 +303,7 @@ func TestPredictPaths(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - act := predictPaths(tc.client, tc.path) + act := predictPaths(client, tc.path, tc.includeFiles) if !reflect.DeepEqual(act, tc.exp) { t.Errorf("expected %q to be %q", act, tc.exp) }