diff --git a/command/auth_enable.go b/command/auth_enable.go index e6b7f20f24..a5ba67f27d 100644 --- a/command/auth_enable.go +++ b/command/auth_enable.go @@ -5,137 +5,162 @@ import ( "strings" "github.com/hashicorp/vault/api" - "github.com/hashicorp/vault/meta" + "github.com/mitchellh/cli" "github.com/posener/complete" ) +// Ensure we are implementing the right interfaces. +var _ cli.Command = (*AuthEnableCommand)(nil) +var _ cli.CommandAutocomplete = (*AuthEnableCommand)(nil) + // AuthEnableCommand is a Command that enables a new endpoint. type AuthEnableCommand struct { - meta.Meta -} + *BaseCommand -func (c *AuthEnableCommand) Run(args []string) int { - var description, path, pluginName string - var local bool - flags := c.Meta.FlagSet("auth-enable", meta.FlagSetDefault) - flags.StringVar(&description, "description", "", "") - flags.StringVar(&path, "path", "", "") - flags.StringVar(&pluginName, "plugin-name", "", "") - flags.BoolVar(&local, "local", false, "") - flags.Usage = func() { c.Ui.Error(c.Help()) } - if err := flags.Parse(args); err != nil { - return 1 - } - - args = flags.Args() - if len(args) != 1 { - flags.Usage() - c.Ui.Error(fmt.Sprintf( - "\nauth-enable expects one argument: the type to enable.")) - return 1 - } - - authType := args[0] - - // If no path is specified, we default the path to the backend type - // or use the plugin name if it's a plugin backend - if path == "" { - if authType == "plugin" { - path = pluginName - } else { - path = authType - } - } - - client, err := c.Client() - if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error initializing client: %s", err)) - return 2 - } - - if err := client.Sys().EnableAuthWithOptions(path, &api.EnableAuthOptions{ - Type: authType, - Description: description, - Config: api.AuthConfigInput{ - PluginName: pluginName, - }, - Local: local, - }); err != nil { - c.Ui.Error(fmt.Sprintf( - "Error: %s", err)) - return 2 - } - - authTypeOutput := fmt.Sprintf("'%s'", authType) - if authType == "plugin" { - authTypeOutput = fmt.Sprintf("plugin '%s'", pluginName) - } - - c.Ui.Output(fmt.Sprintf( - "Successfully enabled %s at '%s'!", - authTypeOutput, path)) - - return 0 + flagDescription string + flagPath string + flagPluginName string + flagLocal bool } func (c *AuthEnableCommand) Synopsis() string { - return "Enable a new auth provider" + return "Enables a new auth provider" } func (c *AuthEnableCommand) Help() string { helpText := ` -Usage: vault auth-enable [options] type +Usage: vault auth-enable [options] TYPE - Enable a new auth provider. + Enables a new authentication provider. An authentication provider is + responsible for authenticating users or machiens and assigning them + policies with which they can access Vault. - This command enables a new auth provider. An auth provider is responsible - for authenticating a user and assigning them policies with which they can - access Vault. + Enable the userpass auth provider at userpass/: -General Options: -` + meta.GeneralOptionsUsage() + ` -Auth Enable Options: + $ vault auth-enable userpass - -description= Human-friendly description of the purpose of the - auth provider. This shows up in the auth -methods command. + Enable the LDAP auth provider at auth-prod/: - -path= Mount point for the auth provider. This defaults - to the type of the mount. This will make the auth - provider available at "/auth/" + $ vault auth-enable -path=auth-prod ldap - -plugin-name Name of the auth plugin to use based from the name - in the plugin catalog. + Enable a custom auth plugin (after it is registered in the plugin registry): + + $ vault auth-enable -path=my-auth -plugin-name=my-auth-plugin plugin + + For a full list of examples, please see the documentation. + +` + c.Flags().Help() - -local Mark the mount as a local mount. Local mounts - are not replicated nor (if a secondary) - removed by replication. -` return strings.TrimSpace(helpText) } -func (c *AuthEnableCommand) AutocompleteArgs() complete.Predictor { - return complete.PredictSet( - "approle", - "cert", - "aws", - "app-id", - "gcp", - "github", - "userpass", - "ldap", - "okta", - "radius", - "plugin", - ) +func (c *AuthEnableCommand) Flags() *FlagSets { + set := c.flagSet(FlagSetHTTP) + f := set.NewFlagSet("Command Options") + + f.StringVar(&StringVar{ + Name: "description", + Target: &c.flagDescription, + Completion: complete.PredictAnything, + Usage: "Human-friendly description for the purpose of this " + + "authentication provider.", + }) + + f.StringVar(&StringVar{ + Name: "path", + Target: &c.flagPath, + Default: "", // The default is complex, so we have to manually document + Completion: complete.PredictAnything, + Usage: "Place where the auth provider will be accessible. This must be " + + "unique across all auth providers. This defaults to the \"type\" of " + + "the mount. The auth provider will be accessible at \"/auth/\".", + }) + + f.StringVar(&StringVar{ + Name: "plugin-name", + Target: &c.flagPluginName, + Completion: complete.PredictAnything, + Usage: "Name of the auth provider plugin. This plugin name must already " + + "exist in the Vault server's plugin catalog.", + }) + + f.BoolVar(&BoolVar{ + Name: "local", + Target: &c.flagLocal, + Default: false, + Usage: "Mark the auth provider as local-only. Local auth providers are " + + "not replicated nor removed by replication.", + }) + + return set +} + +func (c *AuthEnableCommand) AutocompleteArgs() complete.Predictor { + return c.PredictVaultAvailableAuths() } func (c *AuthEnableCommand) AutocompleteFlags() complete.Flags { - return complete.Flags{ - "-description": complete.PredictNothing, - "-path": complete.PredictNothing, - "-plugin-name": complete.PredictNothing, - "-local": complete.PredictNothing, - } + return c.Flags().Completions() +} + +func (c *AuthEnableCommand) Run(args []string) int { + f := c.Flags() + + if err := f.Parse(args); err != nil { + c.UI.Error(err.Error()) + return 1 + } + + args = f.Args() + switch { + case len(args) < 1: + c.UI.Error(fmt.Sprintf("Not enough arguments (expected 1, got %d)", len(args))) + return 1 + case len(args) > 1: + c.UI.Error(fmt.Sprintf("Too many arguments (expected 1, got %d)", len(args))) + return 1 + } + + client, err := c.Client() + if err != nil { + c.UI.Error(err.Error()) + return 2 + } + + authType := strings.TrimSpace(args[0]) + + // If no path is specified, we default the path to the backend type + // or use the plugin name if it's a plugin backend + authPath := c.flagPath + if authPath == "" { + if authType == "plugin" { + authPath = c.flagPluginName + } else { + authPath = authType + } + } + + // Append a trailing slash to indicate it's a path in output + authPath = ensureTrailingSlash(authPath) + + if err := client.Sys().EnableAuthWithOptions(authPath, &api.EnableAuthOptions{ + Type: authType, + Description: c.flagDescription, + Config: api.AuthConfigInput{ + PluginName: c.flagPluginName, + }, + Local: c.flagLocal, + }); err != nil { + c.UI.Error(fmt.Sprintf("Error enabling %s auth: %s", authType, err)) + return 2 + } + + authThing := authType + " auth provider" + if authType == "plugin" { + authThing = c.flagPluginName + " plugin" + } + + c.UI.Output(fmt.Sprintf("Success! Enabled %s at: %s", authThing, authPath)) + return 0 } diff --git a/command/auth_enable_test.go b/command/auth_enable_test.go index 0f8348700f..00016d0138 100644 --- a/command/auth_enable_test.go +++ b/command/auth_enable_test.go @@ -1,50 +1,144 @@ package command import ( + "strings" "testing" - "github.com/hashicorp/vault/http" - "github.com/hashicorp/vault/meta" - "github.com/hashicorp/vault/vault" "github.com/mitchellh/cli" ) -func TestAuthEnable(t *testing.T) { - core, _, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() +func testAuthEnableCommand(tb testing.TB) (*cli.MockUi, *AuthEnableCommand) { + tb.Helper() - ui := new(cli.MockUi) - c := &AuthEnableCommand{ - Meta: meta.Meta{ - ClientToken: token, - Ui: ui, + ui := cli.NewMockUi() + return ui, &AuthEnableCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + } +} + +func TestAuthEnableCommand_Run(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + args []string + out string + code int + }{ + { + "not_enough_args", + nil, + "Not enough arguments", + 1, + }, + { + "too_many_args", + []string{"foo", "bar"}, + "Too many arguments", + 1, + }, + { + "not_a_valid_auth", + []string{"nope_definitely_not_a_valid_mount_like_ever"}, + "", + 2, }, } - args := []string{ - "-address", addr, - "noop", - } - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + ui, cmd := testAuthEnableCommand(t) + cmd.client = client + + code := cmd.Run(tc.args) + if code != tc.code { + t.Errorf("expected %d to be %d", code, tc.code) + } + + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, tc.out) { + t.Errorf("expected %q to contain %q", combined, tc.out) + } + }) } - client, err := c.Client() - if err != nil { - t.Fatalf("err: %s", err) - } + t.Run("integration", func(t *testing.T) { + t.Parallel() - mounts, err := client.Sys().ListAuth() - if err != nil { - t.Fatalf("err: %s", err) - } + client, closer := testVaultServer(t) + defer closer() - mount, ok := mounts["noop/"] - if !ok { - t.Fatal("should have noop mount") - } - if mount.Type != "noop" { - t.Fatal("should be noop type") - } + ui, cmd := testAuthEnableCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "-path", "auth_integration/", + "-description", "The best kind of test", + "userpass", + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Success! Enabled userpass auth provider at:" + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + + auths, err := client.Sys().ListAuth() + if err != nil { + t.Fatal(err) + } + + authInfo, ok := auths["auth_integration/"] + if !ok { + t.Fatalf("expected mount to exist") + } + if exp := "userpass"; authInfo.Type != exp { + t.Errorf("expected %q to be %q", authInfo.Type, exp) + } + if exp := "The best kind of test"; authInfo.Description != exp { + t.Errorf("expected %q to be %q", authInfo.Description, exp) + } + }) + + t.Run("communication_failure", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServerBad(t) + defer closer() + + ui, cmd := testAuthEnableCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "userpass", + }) + if exp := 2; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Error enabling userpass auth: " + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + }) + + t.Run("no_tabs", func(t *testing.T) { + t.Parallel() + + _, cmd := testAuthEnableCommand(t) + assertNoTabs(t, cmd) + }) }