diff --git a/CHANGELOG.md b/CHANGELOG.md index 733764c655..2d3c4dc83a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -293,14 +293,14 @@ IMPROVEMENTS: (PID) in a file [GH-3321] * mfa (Enterprise): Add the ability to use identity metadata in username format * mfa/okta (Enterprise): Add support for configuring base_url for API calls - * secret/pki: `sign-intermediate` will now allow specifying a `ttl` value + * secret/pki: `sign-intermediate` will now allow specifying a `ttl` value longer than the signing CA certificate's NotAfter value. [GH-3325] * sys/raw: Raw storage access is now disabled by default [GH-3329] BUG FIXES: * auth/okta: Fix regression that removed the ability to set base_url [GH-3313] - * core: Fix panic while loading leases at startup on ARM processors + * core: Fix panic while loading leases at startup on ARM processors [GH-3314] * secret/pki: Fix `sign-self-issued` encoding the wrong subject public key [GH-3325] @@ -350,7 +350,7 @@ IMPROVEMENTS: * auth/okta: Compare groups case-insensitively since Okta is only case-preserving [GH-3240] * auth/okta: Standardize Okta configuration APIs across backends [GH-3245] - * cli: Add subcommand autocompletion that can be enabled with + * cli: Add subcommand autocompletion that can be enabled with `vault -autocomplete-install` [GH-3223] * cli: Add ability to handle wrapped responses when using `vault auth`. What is output depends on the other given flags; see the help output for that diff --git a/Makefile b/Makefile index 60fcc39c4d..f6f07c5001 100644 --- a/Makefile +++ b/Makefile @@ -31,7 +31,12 @@ dev-dynamic: prep # test runs the unit tests and vets the code test: prep - CGO_ENABLED=0 VAULT_TOKEN= VAULT_ACC= go test -tags='$(BUILD_TAGS)' $(TEST) $(TESTARGS) -timeout=20m -parallel=4 + @CGO_ENABLED=0 \ + VAULT_ADDR= \ + VAULT_TOKEN= \ + VAULT_DEV_ROOT_TOKEN_ID= \ + VAULT_ACC= \ + go test -tags='$(BUILD_TAGS)' $(TEST) $(TESTARGS) -timeout=20m -parallel=20 testcompile: prep @for pkg in $(TEST) ; do \ @@ -48,7 +53,12 @@ testacc: prep # testrace runs the race checker testrace: prep - CGO_ENABLED=1 VAULT_TOKEN= VAULT_ACC= go test -tags='$(BUILD_TAGS)' -race $(TEST) $(TESTARGS) -timeout=45m -parallel=4 + @CGO_ENABLED=1 \ + VAULT_ADDR= \ + VAULT_TOKEN= \ + VAULT_DEV_ROOT_TOKEN_ID= \ + VAULT_ACC= \ + go test -tags='$(BUILD_TAGS)' -race $(TEST) $(TESTARGS) -timeout=45m -parallel=20 cover: ./scripts/coverage.sh --html diff --git a/README.md b/README.md index 94e0e0cb42..d3fa4a931d 100644 --- a/README.md +++ b/README.md @@ -102,9 +102,9 @@ $ make test TEST=./vault ### Acceptance Tests Vault has comprehensive [acceptance tests](https://en.wikipedia.org/wiki/Acceptance_testing) -covering most of the features of the secret and auth backends. +covering most of the features of the secret and auth methods. -If you're working on a feature of a secret or auth backend and want to +If you're working on a feature of a secret or auth method and want to verify it is functioning (and also hasn't broken anything else), we recommend running the acceptance tests. diff --git a/api/SPEC.md b/api/SPEC.md deleted file mode 100644 index 15345f3905..0000000000 --- a/api/SPEC.md +++ /dev/null @@ -1,611 +0,0 @@ -FORMAT: 1A - -# vault - -The Vault API gives you full access to the Vault project. - -If you're browsing this API specifiction in GitHub or in raw -format, please excuse some of the odd formatting. This document -is in api-blueprint format that is read by viewers such as -Apiary. - -## Sealed vs. Unsealed - -Whenever an individual Vault server is started, it is started -in the _sealed_ state. In this state, it knows where its data -is located, but the data is encrypted and Vault doesn't have the -encryption keys to access it. Before Vault can operate, it must -be _unsealed_. - -**Note:** Sealing/unsealing has no relationship to _authentication_ -which is separate and still required once the Vault is unsealed. - -Instead of being sealed with a single key, we utilize -[Shamir's Secret Sharing](http://en.wikipedia.org/wiki/Shamir%27s_Secret_Sharing) -to shard a key into _n_ parts such that _t_ parts are required -to reconstruct the original key, where `t <= n`. This means that -Vault itself doesn't know the original key, and no single person -has the original key (unless `n = 1`, or `t` parts are given to -a single person). - -Unsealing is done via an unauthenticated -[unseal API](#reference/seal/unseal/unseal). This API takes a single -master shard and progresses the unsealing process. Once all shards -are given, the Vault is either unsealed or resets the unsealing -process if the key was invalid. - -The entire seal/unseal state is server-wide. This allows multiple -distinct operators to use the unseal API (or more likely the -`vault unseal` command) from separate computers/networks and never -have to transmit their key in order to unseal the vault in a -distributed fashion. - -## Transport - -The API is expected to be accessed over a TLS connection at -all times, with a valid certificate that is verified by a well -behaved client. - -## Authentication - -Once the Vault is unsealed, every other operation requires -authentication. There are multiple methods for authentication -that can be enabled (see -[authentication](#reference/authentication)). - -Authentication is done with the login endpoint. The login endpoint -returns an access token that is set as the `X-Vault-Token` header. - -## Help - -To retrieve the help for any API within Vault, including mounted -backends, credential providers, etc. then append `?help=1` to any -URL. If you have valid permission to access the path, then the help text -will be returned with the following structure: - - { - "help": "help text" - } - -## Error Response - -A common JSON structure is always returned to return errors: - - { - "errors": [ - "message", - "another message" - ] - } - -This structure will be sent down for any non-20x HTTP status. - -## HTTP Status Codes - -The following HTTP status codes are used throughout the API. - -- `200` - Success with data. -- `204` - Success, no data returned. -- `400` - Invalid request, missing or invalid data. -- `403` - Forbidden, your authentication details are either - incorrect or you don't have access to this feature. -- `404` - Invalid path. This can both mean that the path truly - doesn't exist or that you don't have permission to view a - specific path. We use 404 in some cases to avoid state leakage. -- `429` - Rate limit exceeded. Try again after waiting some period - of time. -- `500` - Internal server error. An internal error has occurred, - try again later. If the error persists, report a bug. -- `503` - Vault is down for maintenance or is currently sealed. - Try again later. - -# Group Initialization - -## Initialization [/sys/init] -### Initialization Status [GET] -Returns the status of whether the vault is initialized or not. The -vault doesn't have to be unsealed for this operation. - -+ Response 200 (application/json) - - { - "initialized": true - } - -### Initialize [POST] -Initialize the vault. This is an unauthenticated request to initially -setup a new vault. Although this is unauthenticated, it is still safe: -data cannot be in vault prior to initialization, and any future -authentication will fail if you didn't initialize it yourself. -Additionally, once initialized, a vault cannot be reinitialized. - -This API is the only time Vault will ever be aware of your keys, and -the only time the keys will ever be returned in one unit. Care should -be taken to ensure that the output of this request is never logged, -and that the keys are properly distributed. - -The response also contains the initial root token that can be used -as authentication in order to initially configure Vault once it is -unsealed. Just as with the unseal keys, this is the only time Vault is -ever aware of this token. - -+ Request (application/json) - - { - "secret_shares": 5, - "secret_threshold": 3, - } - -+ Response 200 (application/json) - - { - "keys": ["one", "two", "three"], - "root_token": "foo" - } - -# Group Seal/Unseal - -## Seal Status [/sys/seal-status] -### Seal Status [GET] -Returns the status of whether the vault is currently -sealed or not, as well as the progress of unsealing. - -The response has the following attributes: - -- sealed (boolean) - If true, the vault is sealed. Otherwise, - it is unsealed. -- t (int) - The "t" value for the master key, or the number - of shards needed total to unseal the vault. -- n (int) - The "n" value for the master key, or the total - number of shards of the key distributed. -- progress (int) - The number of master key shards that have - been entered so far towards unsealing the vault. - -+ Response 200 (application/json) - - { - "sealed": true, - "t": 3, - "n": 5, - "progress": 1 - } - -## Seal [/sys/seal] -### Seal [PUT] -Seal the vault. - -Sealing the vault locks Vault from any future operations on any -secrets or system configuration until the vault is once again -unsealed. Internally, sealing throws away the keys to access the -encrypted vault data, so Vault is unable to access the data without -unsealing to get the encryption keys. - -+ Response 204 - -## Unseal [/sys/unseal] -### Unseal [PUT] -Unseal the vault. - -Unseal the vault by entering a portion of the master key. The -response object will tell you if the unseal is complete or -only partial. - -If the vault is already unsealed, this does nothing. It is -not an error, the return value just says the vault is unsealed. -Due to the architecture of Vault, we cannot validate whether -any portion of the unseal key given is valid until all keys -are inputted, therefore unsealing an already unsealed vault -is still a success even if the input key is invalid. - -+ Request (application/json) - - { - "key": "value" - } - -+ Response 200 (application/json) - - { - "sealed": true, - "t": 3, - "n": 5, - "progress": 1 - } - -# Group Authentication - -## List Auth Methods [/sys/auth] -### List all auth methods [GET] -Lists all available authentication methods. - -This returns the name of the authentication method as well as -a human-friendly long-form help text for the method that can be -shown to the user as documentation. - -+ Response 200 (application/json) - - { - "token": { - "type": "token", - "description": "Token authentication" - }, - "oauth": { - "type": "oauth", - "description": "OAuth authentication" - } - } - -## Single Auth Method [/sys/auth/{id}] - -+ Parameters - + id (required, string) ... The ID of the auth method. - -### Enable an auth method [PUT] -Enables an authentication method. - -The body of the request depends on the authentication method -being used. Please reference the documentation for the specific -authentication method you're enabling in order to determine what -parameters you must give it. - -If an authentication method is already enabled, then this can be -used to change the configuration, including even the type of -the configuration. - -+ Request (application/json) - - { - "type": "type", - "key": "value", - "key2": "value2" - } - -+ Response 204 - -### Disable an auth method [DELETE] -Disables an authentication method. Previously authenticated sessions -are immediately invalidated. - -+ Response 204 - -# Group Policies - -Policies are named permission sets that identities returned by -credential stores are bound to. This separates _authentication_ -from _authorization_. - -## Policies [/sys/policy] -### List all Policies [GET] - -List all the policies. - -+ Response 200 (application/json) - - { - "policies": ["root"] - } - -## Single Policy [/sys/policy/{id}] - -+ Parameters - + id (required, string) ... The name of the policy - -### Upsert [PUT] - -Create or update a policy with the given ID. - -+ Request (application/json) - - { - "rules": "HCL" - } - -+ Response 204 - -### Delete [DELETE] - -Delete a policy with the given ID. Any identities bound to this -policy will immediately become "deny all" despite already being -authenticated. - -+ Response 204 - -# Group Mounts - -Logical backends are mounted at _mount points_, similar to -filesystems. This allows you to mount the "aws" logical backend -at the "aws-us-east" path, so all access is at `/aws-us-east/keys/foo` -for example. This enables multiple logical backends to be enabled. - -## Mounts [/sys/mounts] -### List all mounts [GET] - -Lists all the active mount points. - -+ Response 200 (application/json) - - { - "aws": { - "type": "aws", - "description": "AWS" - }, - "pg": { - "type": "postgresql", - "description": "PostgreSQL dynamic users" - } - } - -## Single Mount [/sys/mounts/{path}] -### New Mount [POST] - -Mount a logical backend to a new path. - -Configuration for this new backend is done via the normal -read/write mechanism once it is mounted. - -+ Request (application/json) - - { - "type": "aws", - "description": "EU AWS tokens" - } - -+ Response 204 - -### Unmount [DELETE] - -Unmount a mount point. - -+ Response 204 - -## Remount [/sys/remount] -### Remount [POST] - -Move an already-mounted backend to a new path. - -+ Request (application/json) - - { - "from": "aws", - "to": "aws-east" - } - -+ Response 204 - -# Group Audit Backends - -Audit backends are responsible for shuttling the audit logs that -Vault generates to a durable system for future querying. By default, -audit logs are not stored anywhere. - -## Audit Backends [/sys/audit] -### List Enabled Audit Backends [GET] - -List all the enabled audit backends - -+ Response 200 (application/json) - - { - "file": { - "type": "file", - "description": "Send audit logs to a file", - "options": {} - } - } - -## Single Audit Backend [/sys/audit/{path}] - -+ Parameters - + path (required, string) ... The path where the audit backend is mounted - -### Enable [PUT] - -Enable an audit backend. - -+ Request (application/json) - - { - "type": "file", - "description": "send to a file", - "options": { - "path": "/var/log/vault.audit.log" - } - } - -+ Response 204 - -### Disable [DELETE] - -Disable an audit backend. - -+ Request (application/json) - -+ Response 204 - -# Group Secrets - -## Generic [/{mount}/{path}] - -This group documents the general format of reading and writing -to Vault. The exact structure of the keyspace is defined by the -logical backends in use, so documentation related to -a specific backend should be referenced for details on what keys -and routes are expected. - -The path for examples are `/prefix/path`, but in practice -these will be defined by the backends that are mounted. For -example, reading an AWS key might be at the `/aws/root` path. -These paths are defined by the logical backends. - -+ Parameters - + mount (required, string) ... The mount point for the - logical backend. Example: `aws`. - + path (optional, string) ... The path within the backend - to read or write data. - -### Read [GET] - -Read data from vault. - -The data read from the vault can either be a secret or -arbitrary configuration data. The type of data returned -depends on the path, and is defined by the logical backend. - -If the return value is a secret, then the return structure -is a mixture of arbitrary key/value along with the following -fields which are guaranteed to exist: - -- `lease_id` (string) - A unique ID used for renewal and - revocation. - -- `renewable` (bool) - If true, then this key can be renewed. - If a key can't be renewed, then a new key must be requested - after the lease duration period. - -- `lease_duration` (int) - The time in seconds that a secret is - valid for before it must be renewed. - -- `lease_duration_max` (int) - The maximum amount of time in - seconds that a secret is valid for. This will always be - greater than or equal to `lease_duration`. The difference - between this and `lease_duration` is an overlap window - where multiple keys may be valid. - -If the return value is not a secret, then the return structure -is an arbitrary JSON object. - -+ Response 200 (application/json) - - { - "lease_id": "UUID", - "lease_duration": 3600, - "key": "value" - } - -### Write [PUT] - -Write data to vault. - -The behavior and arguments to the write are defined by -the logical backend. - -+ Request (application/json) - - { - "key": "value" - } - -+ Response 204 - -# Group Lease Management - -## Renew Key [/sys/renew/{id}] - -+ Parameters - + id (required, string) ... The `lease_id` of the secret - to renew. - -### Renew [PUT] - -+ Response 200 (application/json) - - { - "lease_id": "...", - "lease_duration": 3600, - "access_key": "foo", - "secret_key": "bar" - } - -## Revoke Key [/sys/revoke/{id}] - -+ Parameters - + id (required, string) ... The `lease_id` of the secret - to revoke. - -### Revoke [PUT] - -+ Response 204 - -# Group Backend: AWS - -## Root Key [/aws/root] -### Set the Key [PUT] - -Set the root key that the logical backend will use to create -new secrets, IAM policies, etc. - -+ Request (application/json) - - { - "access_key": "key", - "secret_key": "key", - "region": "us-east-1" - } - -+ Response 204 - -## Policies [/aws/policies] -### List Policies [GET] - -List all the policies that can be used to create keys. - -+ Response 200 (application/json) - - [{ - "name": "root", - "description": "Root access" - }, { - "name": "web-deploy", - "description": "Enough permissions to deploy the web app." - }] - -## Single Policy [/aws/policies/{name}] - -+ Parameters - + name (required, string) ... Name of the policy. - -### Read [GET] - -Read a policy. - -+ Response 200 (application/json) - - { - "policy": "base64-encoded policy" - } - -### Upsert [PUT] - -Create or update a policy. - -+ Request (application/json) - - { - "policy": "base64-encoded policy" - } - -+ Response 204 - -### Delete [DELETE] - -Delete the policy with the given name. - -+ Response 204 - -## Generate Access Keys [/aws/keys/{policy}] -### Create [GET] - -This generates a new keypair for the given policy. - -+ Parameters - + policy (required, string) ... The policy under which to create - the key pair. - -+ Response 200 (application/json) - - { - "lease_id": "...", - "lease_duration": 3600, - "access_key": "foo", - "secret_key": "bar" - } diff --git a/api/api_integration_test.go b/api/api_integration_test.go index c4e1a1d807..a9e4409ae1 100644 --- a/api/api_integration_test.go +++ b/api/api_integration_test.go @@ -1,59 +1,131 @@ package api_test import ( + "context" "database/sql" + "encoding/base64" "fmt" + "net" + "net/http" "testing" + "time" "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/audit" + "github.com/hashicorp/vault/builtin/logical/database" "github.com/hashicorp/vault/builtin/logical/pki" "github.com/hashicorp/vault/builtin/logical/transit" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/vault" + auditFile "github.com/hashicorp/vault/builtin/audit/file" + credUserpass "github.com/hashicorp/vault/builtin/credential/userpass" vaulthttp "github.com/hashicorp/vault/http" logxi "github.com/mgutz/logxi/v1" dockertest "gopkg.in/ory-am/dockertest.v3" ) -var testVaultServerDefaultBackends = map[string]logical.Factory{ - "transit": transit.Factory, - "pki": pki.Factory, -} - +// testVaultServer creates a test vault cluster and returns a configured API +// client and closer function. func testVaultServer(t testing.TB) (*api.Client, func()) { - return testVaultServerBackends(t, testVaultServerDefaultBackends) + t.Helper() + + client, _, closer := testVaultServerUnseal(t) + return client, closer } -func testVaultServerBackends(t testing.TB, backends map[string]logical.Factory) (*api.Client, func()) { - coreConfig := &vault.CoreConfig{ - DisableMlock: true, - DisableCache: true, - Logger: logxi.NullLog, - LogicalBackends: backends, - } +// testVaultServerUnseal creates a test vault cluster and returns a configured +// API client, list of unseal keys (as strings), and a closer function. +func testVaultServerUnseal(t testing.TB) (*api.Client, []string, func()) { + t.Helper() + + return testVaultServerCoreConfig(t, &vault.CoreConfig{ + DisableMlock: true, + DisableCache: true, + Logger: logxi.NullLog, + CredentialBackends: map[string]logical.Factory{ + "userpass": credUserpass.Factory, + }, + AuditBackends: map[string]audit.Factory{ + "file": auditFile.Factory, + }, + LogicalBackends: map[string]logical.Factory{ + "database": database.Factory, + "generic-leased": vault.LeasedPassthroughBackendFactory, + "pki": pki.Factory, + "transit": transit.Factory, + }, + }) +} + +// testVaultServerCoreConfig creates a new vault cluster with the given core +// configuration. This is a lower-level test helper. +func testVaultServerCoreConfig(t testing.TB, coreConfig *vault.CoreConfig) (*api.Client, []string, func()) { + t.Helper() cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{ HandlerFunc: vaulthttp.Handler, }) cluster.Start() - // make it easy to get access to the active + // Make it easy to get access to the active core := cluster.Cores[0].Core vault.TestWaitActive(t, core) + // Get the client already setup for us! client := cluster.Cores[0].Client client.SetToken(cluster.RootToken) - // Sanity check - secret, err := client.Auth().Token().LookupSelf() + // Convert the unseal keys to base64 encoded, since these are how the user + // will get them. + unsealKeys := make([]string, len(cluster.BarrierKeys)) + for i := range unsealKeys { + unsealKeys[i] = base64.StdEncoding.EncodeToString(cluster.BarrierKeys[i]) + } + + return client, unsealKeys, func() { defer cluster.Cleanup() } +} + +// testVaultServerBad creates an http server that returns a 500 on each request +// to simulate failures. +func testVaultServerBad(t testing.TB) (*api.Client, func()) { + t.Helper() + + listener, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatal(err) } - if secret == nil || secret.Data["id"].(string) != cluster.RootToken { - t.Fatalf("token mismatch: %#v vs %q", secret, cluster.RootToken) + + server := &http.Server{ + Addr: "127.0.0.1:0", + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "500 internal server error", http.StatusInternalServerError) + }), + ReadTimeout: 1 * time.Second, + ReadHeaderTimeout: 1 * time.Second, + WriteTimeout: 1 * time.Second, + IdleTimeout: 1 * time.Second, + } + + go func() { + if err := server.Serve(listener); err != nil && err != http.ErrServerClosed { + t.Fatal(err) + } + }() + + client, err := api.NewClient(&api.Config{ + Address: "http://" + listener.Addr().String(), + }) + if err != nil { + t.Fatal(err) + } + + return client, func() { + ctx, done := context.WithTimeout(context.Background(), 5*time.Second) + defer done() + + server.Shutdown(ctx) } - return client, func() { defer cluster.Cleanup() } } // testPostgresDB creates a testing postgres database in a Docker container, diff --git a/api/renewer_integration_test.go b/api/renewer_integration_test.go index 7011c7d10a..50a775e126 100644 --- a/api/renewer_integration_test.go +++ b/api/renewer_integration_test.go @@ -5,20 +5,12 @@ import ( "time" "github.com/hashicorp/vault/api" - "github.com/hashicorp/vault/builtin/logical/database" - "github.com/hashicorp/vault/builtin/logical/pki" - "github.com/hashicorp/vault/builtin/logical/transit" - "github.com/hashicorp/vault/logical" ) func TestRenewer_Renew(t *testing.T) { t.Parallel() - client, vaultDone := testVaultServerBackends(t, map[string]logical.Factory{ - "database": database.Factory, - "pki": pki.Factory, - "transit": transit.Factory, - }) + client, vaultDone := testVaultServer(t) defer vaultDone() pgURL, pgDone := testPostgresDB(t) diff --git a/api/secret.go b/api/secret.go index b111501490..4891651622 100644 --- a/api/secret.go +++ b/api/secret.go @@ -1,10 +1,12 @@ package api import ( + "fmt" "io" "time" "github.com/hashicorp/vault/helper/jsonutil" + "github.com/hashicorp/vault/helper/parseutil" ) // Secret is the structure returned for every secret within Vault. @@ -35,6 +37,188 @@ type Secret struct { WrapInfo *SecretWrapInfo `json:"wrap_info,omitempty"` } +// TokenID returns the standardized token ID (token) for the given secret. +func (s *Secret) TokenID() (string, error) { + if s == nil { + return "", nil + } + + if s.Auth != nil && len(s.Auth.ClientToken) > 0 { + return s.Auth.ClientToken, nil + } + + if s.Data == nil || s.Data["id"] == nil { + return "", nil + } + + id, ok := s.Data["id"].(string) + if !ok { + return "", fmt.Errorf("token found but in the wrong format") + } + + return id, nil +} + +// TokenAccessor returns the standardized token accessor for the given secret. +// If the secret is nil or does not contain an accessor, this returns the empty +// string. +func (s *Secret) TokenAccessor() (string, error) { + if s == nil { + return "", nil + } + + if s.Auth != nil && len(s.Auth.Accessor) > 0 { + return s.Auth.Accessor, nil + } + + if s.Data == nil || s.Data["accessor"] == nil { + return "", nil + } + + accessor, ok := s.Data["accessor"].(string) + if !ok { + return "", fmt.Errorf("token found but in the wrong format") + } + + return accessor, nil +} + +// TokenRemainingUses returns the standardized remaining uses for the given +// secret. If the secret is nil or does not contain the "num_uses", this +// returns -1. On error, this will return -1 and a non-nil error. +func (s *Secret) TokenRemainingUses() (int, error) { + if s == nil || s.Data == nil || s.Data["num_uses"] == nil { + return -1, nil + } + + uses, err := parseutil.ParseInt(s.Data["num_uses"]) + if err != nil { + return 0, err + } + + return int(uses), nil +} + +// TokenPolicies returns the standardized list of policies for the given secret. +// If the secret is nil or does not contain any policies, this returns nil. +func (s *Secret) TokenPolicies() ([]string, error) { + if s == nil { + return nil, nil + } + + if s.Auth != nil && len(s.Auth.Policies) > 0 { + return s.Auth.Policies, nil + } + + if s.Data == nil || s.Data["policies"] == nil { + return nil, nil + } + + sList, ok := s.Data["policies"].([]string) + if ok { + return sList, nil + } + + list, ok := s.Data["policies"].([]interface{}) + if !ok { + return nil, fmt.Errorf("unable to convert token policies to expected format") + } + + policies := make([]string, len(list)) + for i := range list { + p, ok := list[i].(string) + if !ok { + return nil, fmt.Errorf("unable to convert policy %v to string", list[i]) + } + policies[i] = p + } + + return policies, nil +} + +// TokenMetadata returns the map of metadata associated with this token, if any +// exists. If the secret is nil or does not contain the "metadata" key, this +// returns nil. +func (s *Secret) TokenMetadata() (map[string]string, error) { + if s == nil { + return nil, nil + } + + if s.Auth != nil && len(s.Auth.Metadata) > 0 { + return s.Auth.Metadata, nil + } + + if s.Data == nil || (s.Data["metadata"] == nil && s.Data["meta"] == nil) { + return nil, nil + } + + data, ok := s.Data["metadata"].(map[string]interface{}) + if !ok { + data, ok = s.Data["meta"].(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("unable to convert metadata field to expected format") + } + } + + metadata := make(map[string]string, len(data)) + for k, v := range data { + typed, ok := v.(string) + if !ok { + return nil, fmt.Errorf("unable to convert metadata value %v to string", v) + } + metadata[k] = typed + } + + return metadata, nil +} + +// TokenIsRenewable returns the standardized token renewability for the given +// secret. If the secret is nil or does not contain the "renewable" key, this +// returns false. +func (s *Secret) TokenIsRenewable() (bool, error) { + if s == nil { + return false, nil + } + + if s.Auth != nil && s.Auth.Renewable { + return s.Auth.Renewable, nil + } + + if s.Data == nil || s.Data["renewable"] == nil { + return false, nil + } + + renewable, err := parseutil.ParseBool(s.Data["renewable"]) + if err != nil { + return false, fmt.Errorf("could not convert renewable value to a boolean: %v", err) + } + + return renewable, nil +} + +// TokenTTL returns the standardized remaining token TTL for the given secret. +// If the secret is nil or does not contain a TTL, this returns 0. +func (s *Secret) TokenTTL() (time.Duration, error) { + if s == nil { + return 0, nil + } + + if s.Auth != nil && s.Auth.LeaseDuration > 0 { + return time.Duration(s.Auth.LeaseDuration) * time.Second, nil + } + + if s.Data == nil || s.Data["ttl"] == nil { + return 0, nil + } + + ttl, err := parseutil.ParseDurationSecond(s.Data["ttl"]) + if err != nil { + return 0, err + } + + return ttl, nil +} + // SecretWrapInfo contains wrapping information if we have it. If what is // contained is an authentication token, the accessor for the token will be // available in WrappedAccessor. diff --git a/api/secret_test.go b/api/secret_test.go index 327de46c11..b8e690de37 100644 --- a/api/secret_test.go +++ b/api/secret_test.go @@ -1,13 +1,18 @@ -package api +package api_test import ( + "encoding/json" "reflect" "strings" "testing" "time" + + "github.com/hashicorp/vault/api" ) func TestParseSecret(t *testing.T) { + t.Parallel() + raw := strings.TrimSpace(` { "lease_id": "foo", @@ -30,12 +35,12 @@ func TestParseSecret(t *testing.T) { rawTime, _ := time.Parse(time.RFC3339, "2016-06-07T15:52:10-04:00") - secret, err := ParseSecret(strings.NewReader(raw)) + secret, err := api.ParseSecret(strings.NewReader(raw)) if err != nil { t.Fatalf("err: %s", err) } - expected := &Secret{ + expected := &api.Secret{ LeaseID: "foo", Renewable: true, LeaseDuration: 10, @@ -45,7 +50,7 @@ func TestParseSecret(t *testing.T) { Warnings: []string{ "a warning!", }, - WrapInfo: &SecretWrapInfo{ + WrapInfo: &api.SecretWrapInfo{ Token: "token", Accessor: "accessor", TTL: 60, @@ -57,3 +62,1962 @@ func TestParseSecret(t *testing.T) { t.Fatalf("bad:\ngot\n%#v\nexpected\n%#v\n", secret, expected) } } + +func TestSecret_TokenID(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + secret *api.Secret + exp string + err bool + }{ + { + "nil", + nil, + "", + false, + }, + { + "nil_auth", + &api.Secret{ + Auth: nil, + }, + "", + false, + }, + { + "empty_auth_client_token", + &api.Secret{ + Auth: &api.SecretAuth{ + ClientToken: "", + }, + }, + "", + false, + }, + { + "real_auth_client_token", + &api.Secret{ + Auth: &api.SecretAuth{ + ClientToken: "my-token", + }, + }, + "my-token", + false, + }, + { + "nil_data", + &api.Secret{ + Data: nil, + }, + "", + false, + }, + { + "empty_data", + &api.Secret{ + Data: map[string]interface{}{}, + }, + "", + false, + }, + { + "data_not_string", + &api.Secret{ + Data: map[string]interface{}{ + "id": 123, + }, + }, + "", + true, + }, + { + "data_string", + &api.Secret{ + Data: map[string]interface{}{ + "id": "my-token", + }, + }, + "my-token", + false, + }, + } + + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + act, err := tc.secret.TokenID() + if err != nil && !tc.err { + t.Fatal(err) + } + if act != tc.exp { + t.Errorf("expected %q to be %q", act, tc.exp) + } + }) + } + + t.Run("auth", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + if err := client.Sys().EnableAuth("userpass", "userpass", ""); err != nil { + t.Fatal(err) + } + if _, err := client.Logical().Write("auth/userpass/users/test", map[string]interface{}{ + "password": "test", + "policies": "default", + }); err != nil { + t.Fatal(err) + } + + secret, err := client.Logical().Write("auth/userpass/login/test", map[string]interface{}{ + "password": "test", + }) + if err != nil || secret == nil { + t.Fatal(err) + } + token := secret.Auth.ClientToken + + tokenID, err := secret.TokenID() + if err != nil { + t.Fatal(err) + } + if tokenID != token { + t.Errorf("expected %q to be %q", tokenID, token) + } + }) + + t.Run("token-create", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + secret, err := client.Auth().Token().Create(&api.TokenCreateRequest{ + Policies: []string{"default"}, + }) + if err != nil { + t.Fatal(err) + } + token := secret.Auth.ClientToken + + tokenID, err := secret.TokenID() + if err != nil { + t.Fatal(err) + } + if tokenID != token { + t.Errorf("expected %q to be %q", tokenID, token) + } + }) + + t.Run("token-lookup", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + secret, err := client.Auth().Token().Create(&api.TokenCreateRequest{ + Policies: []string{"default"}, + }) + if err != nil { + t.Fatal(err) + } + token := secret.Auth.ClientToken + + secret, err = client.Auth().Token().Lookup(token) + if err != nil { + t.Fatal(err) + } + + tokenID, err := secret.TokenID() + if err != nil { + t.Fatal(err) + } + if tokenID != token { + t.Errorf("expected %q to be %q", tokenID, token) + } + }) + + t.Run("token-lookup-self", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + secret, err := client.Auth().Token().Create(&api.TokenCreateRequest{ + Policies: []string{"default"}, + }) + if err != nil { + t.Fatal(err) + } + token := secret.Auth.ClientToken + + client.SetToken(token) + secret, err = client.Auth().Token().LookupSelf() + if err != nil { + t.Fatal(err) + } + + tokenID, err := secret.TokenID() + if err != nil { + t.Fatal(err) + } + if tokenID != token { + t.Errorf("expected %q to be %q", tokenID, token) + } + }) + + t.Run("token-renew", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + secret, err := client.Auth().Token().Create(&api.TokenCreateRequest{ + Policies: []string{"default"}, + }) + if err != nil { + t.Fatal(err) + } + token := secret.Auth.ClientToken + + secret, err = client.Auth().Token().Renew(token, 0) + if err != nil { + t.Fatal(err) + } + + tokenID, err := secret.TokenID() + if err != nil { + t.Fatal(err) + } + if tokenID != token { + t.Errorf("expected %q to be %q", tokenID, token) + } + }) + + t.Run("token-renew-self", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + secret, err := client.Auth().Token().Create(&api.TokenCreateRequest{ + Policies: []string{"default"}, + }) + if err != nil { + t.Fatal(err) + } + token := secret.Auth.ClientToken + + client.SetToken(token) + secret, err = client.Auth().Token().RenewSelf(0) + if err != nil { + t.Fatal(err) + } + + tokenID, err := secret.TokenID() + if err != nil { + t.Fatal(err) + } + if tokenID != token { + t.Errorf("expected %q to be %q", tokenID, token) + } + }) +} + +func TestSecret_TokenAccessor(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + secret *api.Secret + exp string + err bool + }{ + { + "nil", + nil, + "", + false, + }, + { + "nil_auth", + &api.Secret{ + Auth: nil, + }, + "", + false, + }, + { + "empty_auth_accessor", + &api.Secret{ + Auth: &api.SecretAuth{ + Accessor: "", + }, + }, + "", + false, + }, + { + "real_auth_accessor", + &api.Secret{ + Auth: &api.SecretAuth{ + Accessor: "my-accessor", + }, + }, + "my-accessor", + false, + }, + { + "nil_data", + &api.Secret{ + Data: nil, + }, + "", + false, + }, + { + "empty_data", + &api.Secret{ + Data: map[string]interface{}{}, + }, + "", + false, + }, + { + "data_not_string", + &api.Secret{ + Data: map[string]interface{}{ + "accessor": 123, + }, + }, + "", + true, + }, + { + "data_string", + &api.Secret{ + Data: map[string]interface{}{ + "accessor": "my-accessor", + }, + }, + "my-accessor", + false, + }, + } + + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + act, err := tc.secret.TokenAccessor() + if err != nil && !tc.err { + t.Fatal(err) + } + if act != tc.exp { + t.Errorf("expected %q to be %q", act, tc.exp) + } + }) + } + + t.Run("auth", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + if err := client.Sys().EnableAuth("userpass", "userpass", ""); err != nil { + t.Fatal(err) + } + if _, err := client.Logical().Write("auth/userpass/users/test", map[string]interface{}{ + "password": "test", + "policies": "default", + }); err != nil { + t.Fatal(err) + } + + secret, err := client.Logical().Write("auth/userpass/login/test", map[string]interface{}{ + "password": "test", + }) + if err != nil || secret == nil { + t.Fatal(err) + } + _, accessor := secret.Auth.ClientToken, secret.Auth.Accessor + + newAccessor, err := secret.TokenAccessor() + if err != nil { + t.Fatal(err) + } + if newAccessor != accessor { + t.Errorf("expected %q to be %q", newAccessor, accessor) + } + }) + + t.Run("token-create", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + secret, err := client.Auth().Token().Create(&api.TokenCreateRequest{ + Policies: []string{"default"}, + }) + if err != nil { + t.Fatal(err) + } + _, accessor := secret.Auth.ClientToken, secret.Auth.Accessor + + newAccessor, err := secret.TokenAccessor() + if err != nil { + t.Fatal(err) + } + if newAccessor != accessor { + t.Errorf("expected %q to be %q", newAccessor, accessor) + } + }) + + t.Run("token-lookup", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + secret, err := client.Auth().Token().Create(&api.TokenCreateRequest{ + Policies: []string{"default"}, + }) + if err != nil { + t.Fatal(err) + } + token, accessor := secret.Auth.ClientToken, secret.Auth.Accessor + + secret, err = client.Auth().Token().Lookup(token) + if err != nil { + t.Fatal(err) + } + + newAccessor, err := secret.TokenAccessor() + if err != nil { + t.Fatal(err) + } + if newAccessor != accessor { + t.Errorf("expected %q to be %q", newAccessor, accessor) + } + }) + + t.Run("token-lookup-self", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + secret, err := client.Auth().Token().Create(&api.TokenCreateRequest{ + Policies: []string{"default"}, + }) + if err != nil { + t.Fatal(err) + } + token, accessor := secret.Auth.ClientToken, secret.Auth.Accessor + + client.SetToken(token) + secret, err = client.Auth().Token().LookupSelf() + if err != nil { + t.Fatal(err) + } + + newAccessor, err := secret.TokenAccessor() + if err != nil { + t.Fatal(err) + } + if newAccessor != accessor { + t.Errorf("expected %q to be %q", newAccessor, accessor) + } + }) + + t.Run("token-renew", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + secret, err := client.Auth().Token().Create(&api.TokenCreateRequest{ + Policies: []string{"default"}, + }) + if err != nil { + t.Fatal(err) + } + token, accessor := secret.Auth.ClientToken, secret.Auth.Accessor + + secret, err = client.Auth().Token().Renew(token, 0) + if err != nil { + t.Fatal(err) + } + + newAccessor, err := secret.TokenAccessor() + if err != nil { + t.Fatal(err) + } + if newAccessor != accessor { + t.Errorf("expected %q to be %q", newAccessor, accessor) + } + }) + + t.Run("token-renew-self", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + secret, err := client.Auth().Token().Create(&api.TokenCreateRequest{ + Policies: []string{"default"}, + }) + if err != nil { + t.Fatal(err) + } + token, accessor := secret.Auth.ClientToken, secret.Auth.Accessor + + client.SetToken(token) + secret, err = client.Auth().Token().RenewSelf(0) + if err != nil { + t.Fatal(err) + } + + newAccessor, err := secret.TokenAccessor() + if err != nil { + t.Fatal(err) + } + if newAccessor != accessor { + t.Errorf("expected %q to be %q", newAccessor, accessor) + } + }) +} + +func TestSecret_TokenRemainingUses(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + secret *api.Secret + exp int + }{ + { + "nil", + nil, + -1, + }, + { + "nil_data", + &api.Secret{ + Data: nil, + }, + -1, + }, + { + "empty_data", + &api.Secret{ + Data: map[string]interface{}{}, + }, + -1, + }, + { + "data_not_json_number", + &api.Secret{ + Data: map[string]interface{}{ + "num_uses": 123, + }, + }, + 123, + }, + { + "data_json_number", + &api.Secret{ + Data: map[string]interface{}{ + "num_uses": json.Number("123"), + }, + }, + 123, + }, + } + + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + act, err := tc.secret.TokenRemainingUses() + if tc.exp != -1 && err != nil { + t.Fatal(err) + } + if act != tc.exp { + t.Errorf("expected %d to be %d", act, tc.exp) + } + }) + } + + t.Run("auth", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + uses := 5 + + if err := client.Sys().EnableAuth("userpass", "userpass", ""); err != nil { + t.Fatal(err) + } + if _, err := client.Logical().Write("auth/userpass/users/test", map[string]interface{}{ + "password": "test", + "policies": "default", + "num_uses": uses, + }); err != nil { + t.Fatal(err) + } + + secret, err := client.Logical().Write("auth/userpass/login/test", map[string]interface{}{ + "password": "test", + }) + if err != nil || secret == nil { + t.Fatal(err) + } + + // Remaining uses is not returned from this API + uses = -1 + remaining, err := secret.TokenRemainingUses() + if err != nil { + t.Fatal(err) + } + if remaining != uses { + t.Errorf("expected %d to be %d", remaining, uses) + } + }) + + t.Run("token-create", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + uses := 5 + + secret, err := client.Auth().Token().Create(&api.TokenCreateRequest{ + Policies: []string{"default"}, + NumUses: uses, + }) + if err != nil { + t.Fatal(err) + } + + // /auth/token/create does not return the number of uses + uses = -1 + remaining, err := secret.TokenRemainingUses() + if err != nil { + t.Fatal(err) + } + if remaining != uses { + t.Errorf("expected %d to be %d", remaining, uses) + } + }) + + t.Run("token-lookup", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + uses := 5 + + secret, err := client.Auth().Token().Create(&api.TokenCreateRequest{ + Policies: []string{"default"}, + NumUses: uses, + }) + if err != nil { + t.Fatal(err) + } + token := secret.Auth.ClientToken + + secret, err = client.Auth().Token().Lookup(token) + if err != nil { + t.Fatal(err) + } + + remaining, err := secret.TokenRemainingUses() + if err != nil { + t.Fatal(err) + } + if remaining != uses { + t.Errorf("expected %d to be %d", remaining, uses) + } + }) + + t.Run("token-lookup-self", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + uses := 5 + + secret, err := client.Auth().Token().Create(&api.TokenCreateRequest{ + Policies: []string{"default"}, + NumUses: uses, + }) + if err != nil { + t.Fatal(err) + } + token := secret.Auth.ClientToken + + client.SetToken(token) + secret, err = client.Auth().Token().LookupSelf() + if err != nil { + t.Fatal(err) + } + + uses = uses - 1 // we just used it + remaining, err := secret.TokenRemainingUses() + if err != nil { + t.Fatal(err) + } + if remaining != uses { + t.Errorf("expected %d to be %d", remaining, uses) + } + }) + + t.Run("token-renew", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + uses := 5 + + secret, err := client.Auth().Token().Create(&api.TokenCreateRequest{ + Policies: []string{"default"}, + NumUses: uses, + }) + if err != nil { + t.Fatal(err) + } + token := secret.Auth.ClientToken + + secret, err = client.Auth().Token().Renew(token, 0) + if err != nil { + t.Fatal(err) + } + + // /auth/token/renew does not return the number of uses + uses = -1 + remaining, err := secret.TokenRemainingUses() + if err != nil { + t.Fatal(err) + } + if remaining != uses { + t.Errorf("expected %d to be %d", remaining, uses) + } + }) + + t.Run("token-renew-self", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + uses := 5 + + secret, err := client.Auth().Token().Create(&api.TokenCreateRequest{ + Policies: []string{"default"}, + NumUses: uses, + }) + if err != nil { + t.Fatal(err) + } + token := secret.Auth.ClientToken + + client.SetToken(token) + secret, err = client.Auth().Token().RenewSelf(0) + if err != nil { + t.Fatal(err) + } + + // /auth/token/renew-self does not return the number of uses + uses = -1 + remaining, err := secret.TokenRemainingUses() + if err != nil { + t.Fatal(err) + } + if remaining != uses { + t.Errorf("expected %d to be %d", remaining, uses) + } + }) +} + +func TestSecret_TokenPolicies(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + secret *api.Secret + exp []string + err bool + }{ + { + "nil", + nil, + nil, + false, + }, + { + "nil_auth", + &api.Secret{ + Auth: nil, + }, + nil, + false, + }, + { + "nil_auth_policies", + &api.Secret{ + Auth: &api.SecretAuth{ + Policies: nil, + }, + }, + nil, + false, + }, + { + "empty_auth_policies", + &api.Secret{ + Auth: &api.SecretAuth{ + Policies: []string{}, + }, + }, + nil, + false, + }, + { + "real_auth_policies", + &api.Secret{ + Auth: &api.SecretAuth{ + Policies: []string{"foo"}, + }, + }, + []string{"foo"}, + false, + }, + { + "nil_data", + &api.Secret{ + Data: nil, + }, + nil, + false, + }, + { + "empty_data", + &api.Secret{ + Data: map[string]interface{}{}, + }, + nil, + false, + }, + { + "data_not_slice", + &api.Secret{ + Data: map[string]interface{}{ + "policies": 123, + }, + }, + nil, + true, + }, + { + "data_slice", + &api.Secret{ + Data: map[string]interface{}{ + "policies": []interface{}{"foo"}, + }, + }, + []string{"foo"}, + false, + }, + } + + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + act, err := tc.secret.TokenPolicies() + if err != nil && !tc.err { + t.Fatal(err) + } + if !reflect.DeepEqual(act, tc.exp) { + t.Errorf("expected %#v to be %#v", act, tc.exp) + } + }) + } + + t.Run("auth", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + policies := []string{"bar", "default", "foo"} + + if err := client.Sys().EnableAuth("userpass", "userpass", ""); err != nil { + t.Fatal(err) + } + if _, err := client.Logical().Write("auth/userpass/users/test", map[string]interface{}{ + "password": "test", + "policies": strings.Join(policies, ","), + }); err != nil { + t.Fatal(err) + } + + secret, err := client.Logical().Write("auth/userpass/login/test", map[string]interface{}{ + "password": "test", + }) + if err != nil || secret == nil { + t.Fatal(err) + } + + tPol, err := secret.TokenPolicies() + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(tPol, policies) { + t.Errorf("expected %#v to be %#v", tPol, policies) + } + }) + + t.Run("token-create", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + policies := []string{"bar", "default", "foo"} + + secret, err := client.Auth().Token().Create(&api.TokenCreateRequest{ + Policies: policies, + }) + if err != nil { + t.Fatal(err) + } + + tPol, err := secret.TokenPolicies() + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(tPol, policies) { + t.Errorf("expected %#v to be %#v", tPol, policies) + } + }) + + t.Run("token-lookup", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + policies := []string{"bar", "default", "foo"} + + secret, err := client.Auth().Token().Create(&api.TokenCreateRequest{ + Policies: policies, + }) + if err != nil { + t.Fatal(err) + } + token := secret.Auth.ClientToken + + secret, err = client.Auth().Token().Lookup(token) + if err != nil { + t.Fatal(err) + } + + tPol, err := secret.TokenPolicies() + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(tPol, policies) { + t.Errorf("expected %#v to be %#v", tPol, policies) + } + }) + + t.Run("token-lookup-self", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + policies := []string{"bar", "default", "foo"} + + secret, err := client.Auth().Token().Create(&api.TokenCreateRequest{ + Policies: policies, + }) + if err != nil { + t.Fatal(err) + } + token := secret.Auth.ClientToken + + client.SetToken(token) + secret, err = client.Auth().Token().LookupSelf() + if err != nil { + t.Fatal(err) + } + + tPol, err := secret.TokenPolicies() + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(tPol, policies) { + t.Errorf("expected %#v to be %#v", tPol, policies) + } + }) + + t.Run("token-renew", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + policies := []string{"bar", "default", "foo"} + + secret, err := client.Auth().Token().Create(&api.TokenCreateRequest{ + Policies: policies, + }) + if err != nil { + t.Fatal(err) + } + token := secret.Auth.ClientToken + + secret, err = client.Auth().Token().Renew(token, 0) + if err != nil { + t.Fatal(err) + } + + tPol, err := secret.TokenPolicies() + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(tPol, policies) { + t.Errorf("expected %#v to be %#v", tPol, policies) + } + }) + + t.Run("token-renew-self", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + policies := []string{"bar", "default", "foo"} + + secret, err := client.Auth().Token().Create(&api.TokenCreateRequest{ + Policies: policies, + }) + if err != nil { + t.Fatal(err) + } + token := secret.Auth.ClientToken + + client.SetToken(token) + secret, err = client.Auth().Token().RenewSelf(0) + if err != nil { + t.Fatal(err) + } + + tPol, err := secret.TokenPolicies() + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(tPol, policies) { + t.Errorf("expected %#v to be %#v", tPol, policies) + } + }) +} + +func TestSecret_TokenMetadata(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + secret *api.Secret + exp map[string]string + err bool + }{ + { + "nil", + nil, + nil, + false, + }, + { + "nil_auth", + &api.Secret{ + Auth: nil, + }, + nil, + false, + }, + { + "nil_auth_metadata", + &api.Secret{ + Auth: &api.SecretAuth{ + Metadata: nil, + }, + }, + nil, + false, + }, + { + "empty_auth_metadata", + &api.Secret{ + Auth: &api.SecretAuth{ + Metadata: map[string]string{}, + }, + }, + nil, + false, + }, + { + "real_auth_metdata", + &api.Secret{ + Auth: &api.SecretAuth{ + Metadata: map[string]string{"foo": "bar"}, + }, + }, + map[string]string{"foo": "bar"}, + false, + }, + { + "nil_data", + &api.Secret{ + Data: nil, + }, + nil, + false, + }, + { + "empty_data", + &api.Secret{ + Data: map[string]interface{}{}, + }, + nil, + false, + }, + { + "data_not_map", + &api.Secret{ + Data: map[string]interface{}{ + "metadata": 123, + }, + }, + nil, + true, + }, + { + "data_map", + &api.Secret{ + Data: map[string]interface{}{ + "metadata": map[string]interface{}{"foo": "bar"}, + }, + }, + map[string]string{"foo": "bar"}, + false, + }, + { + "data_map_bad_type", + &api.Secret{ + Data: map[string]interface{}{ + "metadata": map[string]interface{}{"foo": 123}, + }, + }, + nil, + true, + }, + } + + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + act, err := tc.secret.TokenMetadata() + if err != nil && !tc.err { + t.Fatal(err) + } + if !reflect.DeepEqual(act, tc.exp) { + t.Errorf("expected %#v to be %#v", act, tc.exp) + } + }) + } + + t.Run("auth", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + metadata := map[string]string{"username": "test"} + + if err := client.Sys().EnableAuth("userpass", "userpass", ""); err != nil { + t.Fatal(err) + } + if _, err := client.Logical().Write("auth/userpass/users/test", map[string]interface{}{ + "password": "test", + "policies": "default", + }); err != nil { + t.Fatal(err) + } + + secret, err := client.Logical().Write("auth/userpass/login/test", map[string]interface{}{ + "password": "test", + }) + if err != nil || secret == nil { + t.Fatal(err) + } + + tMeta, err := secret.TokenMetadata() + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(tMeta, metadata) { + t.Errorf("expected %#v to be %#v", tMeta, metadata) + } + }) + + t.Run("token-create", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + metadata := map[string]string{"username": "test"} + + secret, err := client.Auth().Token().Create(&api.TokenCreateRequest{ + Metadata: metadata, + Policies: []string{"default"}, + }) + if err != nil { + t.Fatal(err) + } + + tMeta, err := secret.TokenMetadata() + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(tMeta, metadata) { + t.Errorf("expected %#v to be %#v", tMeta, metadata) + } + }) + + t.Run("token-lookup", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + metadata := map[string]string{"username": "test"} + + secret, err := client.Auth().Token().Create(&api.TokenCreateRequest{ + Metadata: metadata, + Policies: []string{"default"}, + }) + if err != nil { + t.Fatal(err) + } + token := secret.Auth.ClientToken + + secret, err = client.Auth().Token().Lookup(token) + if err != nil { + t.Fatal(err) + } + + tMeta, err := secret.TokenMetadata() + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(tMeta, metadata) { + t.Errorf("expected %#v to be %#v", tMeta, metadata) + } + }) + + t.Run("token-lookup-self", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + metadata := map[string]string{"username": "test"} + + secret, err := client.Auth().Token().Create(&api.TokenCreateRequest{ + Metadata: metadata, + Policies: []string{"default"}, + }) + if err != nil { + t.Fatal(err) + } + token := secret.Auth.ClientToken + + client.SetToken(token) + secret, err = client.Auth().Token().LookupSelf() + if err != nil { + t.Fatal(err) + } + + tMeta, err := secret.TokenMetadata() + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(tMeta, metadata) { + t.Errorf("expected %#v to be %#v", tMeta, metadata) + } + }) + + t.Run("token-renew", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + metadata := map[string]string{"username": "test"} + + secret, err := client.Auth().Token().Create(&api.TokenCreateRequest{ + Metadata: metadata, + Policies: []string{"default"}, + }) + if err != nil { + t.Fatal(err) + } + token := secret.Auth.ClientToken + + secret, err = client.Auth().Token().Renew(token, 0) + if err != nil { + t.Fatal(err) + } + + tMeta, err := secret.TokenMetadata() + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(tMeta, metadata) { + t.Errorf("expected %#v to be %#v", tMeta, metadata) + } + }) + + t.Run("token-renew-self", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + metadata := map[string]string{"username": "test"} + + secret, err := client.Auth().Token().Create(&api.TokenCreateRequest{ + Metadata: metadata, + Policies: []string{"default"}, + }) + if err != nil { + t.Fatal(err) + } + token := secret.Auth.ClientToken + + client.SetToken(token) + secret, err = client.Auth().Token().RenewSelf(0) + if err != nil { + t.Fatal(err) + } + + tMeta, err := secret.TokenMetadata() + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(tMeta, metadata) { + t.Errorf("expected %#v to be %#v", tMeta, metadata) + } + }) +} + +func TestSecret_TokenIsRenewable(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + secret *api.Secret + exp bool + }{ + { + "nil", + nil, + false, + }, + { + "nil_auth", + &api.Secret{ + Auth: nil, + }, + false, + }, + { + "auth_renewable_false", + &api.Secret{ + Auth: &api.SecretAuth{ + Renewable: false, + }, + }, + false, + }, + { + "auth_renewable_true", + &api.Secret{ + Auth: &api.SecretAuth{ + Renewable: true, + }, + }, + true, + }, + { + "nil_data", + &api.Secret{ + Data: nil, + }, + false, + }, + { + "empty_data", + &api.Secret{ + Data: map[string]interface{}{}, + }, + false, + }, + { + "data_not_bool", + &api.Secret{ + Data: map[string]interface{}{ + "renewable": 123, + }, + }, + true, + }, + { + "data_bool_string", + &api.Secret{ + Data: map[string]interface{}{ + "renewable": "true", + }, + }, + true, + }, + { + "data_bool_true", + &api.Secret{ + Data: map[string]interface{}{ + "renewable": true, + }, + }, + true, + }, + { + "data_bool_false", + &api.Secret{ + Data: map[string]interface{}{ + "renewable": true, + }, + }, + true, + }, + } + + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + act, err := tc.secret.TokenIsRenewable() + if err != nil { + t.Fatal(err) + } + if act != tc.exp { + t.Errorf("expected %t to be %t", act, tc.exp) + } + }) + } + + t.Run("auth", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + renewable := true + + if err := client.Sys().EnableAuth("userpass", "userpass", ""); err != nil { + t.Fatal(err) + } + if _, err := client.Logical().Write("auth/userpass/users/test", map[string]interface{}{ + "password": "test", + "policies": "default", + }); err != nil { + t.Fatal(err) + } + + secret, err := client.Logical().Write("auth/userpass/login/test", map[string]interface{}{ + "password": "test", + }) + if err != nil || secret == nil { + t.Fatal(err) + } + + tRenew, err := secret.TokenIsRenewable() + if err != nil { + t.Fatal(err) + } + if tRenew != renewable { + t.Errorf("expected %t to be %t", tRenew, renewable) + } + }) + + t.Run("token-create", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + renewable := true + + secret, err := client.Auth().Token().Create(&api.TokenCreateRequest{ + Policies: []string{"default"}, + Renewable: &renewable, + }) + if err != nil { + t.Fatal(err) + } + + tRenew, err := secret.TokenIsRenewable() + if err != nil { + t.Fatal(err) + } + if tRenew != renewable { + t.Errorf("expected %t to be %t", tRenew, renewable) + } + }) + + t.Run("token-lookup", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + renewable := true + + secret, err := client.Auth().Token().Create(&api.TokenCreateRequest{ + Policies: []string{"default"}, + Renewable: &renewable, + }) + if err != nil { + t.Fatal(err) + } + token := secret.Auth.ClientToken + + secret, err = client.Auth().Token().Lookup(token) + if err != nil { + t.Fatal(err) + } + + tRenew, err := secret.TokenIsRenewable() + if err != nil { + t.Fatal(err) + } + if tRenew != renewable { + t.Errorf("expected %t to be %t", tRenew, renewable) + } + }) + + t.Run("token-lookup-self", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + renewable := true + + secret, err := client.Auth().Token().Create(&api.TokenCreateRequest{ + Policies: []string{"default"}, + Renewable: &renewable, + }) + if err != nil { + t.Fatal(err) + } + token := secret.Auth.ClientToken + + client.SetToken(token) + secret, err = client.Auth().Token().LookupSelf() + if err != nil { + t.Fatal(err) + } + + tRenew, err := secret.TokenIsRenewable() + if err != nil { + t.Fatal(err) + } + if tRenew != renewable { + t.Errorf("expected %t to be %t", tRenew, renewable) + } + }) + + t.Run("token-renew", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + renewable := true + + secret, err := client.Auth().Token().Create(&api.TokenCreateRequest{ + Policies: []string{"default"}, + Renewable: &renewable, + }) + if err != nil { + t.Fatal(err) + } + token := secret.Auth.ClientToken + + secret, err = client.Auth().Token().Renew(token, 0) + if err != nil { + t.Fatal(err) + } + + tRenew, err := secret.TokenIsRenewable() + if err != nil { + t.Fatal(err) + } + if tRenew != renewable { + t.Errorf("expected %t to be %t", tRenew, renewable) + } + }) + + t.Run("token-renew-self", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + renewable := true + + secret, err := client.Auth().Token().Create(&api.TokenCreateRequest{ + Policies: []string{"default"}, + Renewable: &renewable, + }) + if err != nil { + t.Fatal(err) + } + token := secret.Auth.ClientToken + + client.SetToken(token) + secret, err = client.Auth().Token().RenewSelf(0) + if err != nil { + t.Fatal(err) + } + + tRenew, err := secret.TokenIsRenewable() + if err != nil { + t.Fatal(err) + } + if tRenew != renewable { + t.Errorf("expected %t to be %t", tRenew, renewable) + } + }) +} + +func TestSecret_TokenTTL(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + secret *api.Secret + exp time.Duration + }{ + { + "nil", + nil, + 0, + }, + { + "nil_auth", + &api.Secret{ + Auth: nil, + }, + 0, + }, + { + "nil_auth_lease_duration", + &api.Secret{ + Auth: &api.SecretAuth{ + LeaseDuration: 0, + }, + }, + 0, + }, + { + "real_auth_lease_duration", + &api.Secret{ + Auth: &api.SecretAuth{ + LeaseDuration: 3600, + }, + }, + 1 * time.Hour, + }, + { + "nil_data", + &api.Secret{ + Data: nil, + }, + 0, + }, + { + "empty_data", + &api.Secret{ + Data: map[string]interface{}{}, + }, + 0, + }, + { + "data_not_json_number", + &api.Secret{ + Data: map[string]interface{}{ + "ttl": 123, + }, + }, + 123 * time.Second, + }, + { + "data_json_number", + &api.Secret{ + Data: map[string]interface{}{ + "ttl": json.Number("3600"), + }, + }, + 1 * time.Hour, + }, + } + + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + act, err := tc.secret.TokenTTL() + if err != nil { + t.Fatal(err) + } + if act != tc.exp { + t.Errorf("expected %q to be %q", act, tc.exp) + } + }) + } + + t.Run("auth", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + ttl := 30 * time.Minute + + if err := client.Sys().EnableAuth("userpass", "userpass", ""); err != nil { + t.Fatal(err) + } + if _, err := client.Logical().Write("auth/userpass/users/test", map[string]interface{}{ + "password": "test", + "policies": "default", + "ttl": ttl.String(), + "explicit_max_ttl": ttl.String(), + }); err != nil { + t.Fatal(err) + } + + secret, err := client.Logical().Write("auth/userpass/login/test", map[string]interface{}{ + "password": "test", + }) + if err != nil || secret == nil { + t.Fatal(err) + } + + tokenTTL, err := secret.TokenTTL() + if err != nil { + t.Fatal(err) + } + if tokenTTL == 0 || tokenTTL > ttl { + t.Errorf("expected %q to non-zero and less than %q", tokenTTL, ttl) + } + }) + + t.Run("token-create", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + ttl := 30 * time.Minute + + secret, err := client.Auth().Token().Create(&api.TokenCreateRequest{ + Policies: []string{"default"}, + TTL: ttl.String(), + ExplicitMaxTTL: ttl.String(), + }) + if err != nil { + t.Fatal(err) + } + + tokenTTL, err := secret.TokenTTL() + if err != nil { + t.Fatal(err) + } + if tokenTTL == 0 || tokenTTL > ttl { + t.Errorf("expected %q to non-zero and less than %q", tokenTTL, ttl) + } + }) + + t.Run("token-lookup", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + ttl := 30 * time.Minute + + secret, err := client.Auth().Token().Create(&api.TokenCreateRequest{ + Policies: []string{"default"}, + TTL: ttl.String(), + ExplicitMaxTTL: ttl.String(), + }) + if err != nil { + t.Fatal(err) + } + token := secret.Auth.ClientToken + + secret, err = client.Auth().Token().Lookup(token) + if err != nil { + t.Fatal(err) + } + + tokenTTL, err := secret.TokenTTL() + if err != nil { + t.Fatal(err) + } + if tokenTTL == 0 || tokenTTL > ttl { + t.Errorf("expected %q to non-zero and less than %q", tokenTTL, ttl) + } + }) + + t.Run("token-lookup-self", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + ttl := 30 * time.Minute + + secret, err := client.Auth().Token().Create(&api.TokenCreateRequest{ + Policies: []string{"default"}, + TTL: ttl.String(), + ExplicitMaxTTL: ttl.String(), + }) + if err != nil { + t.Fatal(err) + } + token := secret.Auth.ClientToken + + client.SetToken(token) + secret, err = client.Auth().Token().LookupSelf() + if err != nil { + t.Fatal(err) + } + + tokenTTL, err := secret.TokenTTL() + if err != nil { + t.Fatal(err) + } + if tokenTTL == 0 || tokenTTL > ttl { + t.Errorf("expected %q to non-zero and less than %q", tokenTTL, ttl) + } + }) + + t.Run("token-renew", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + ttl := 30 * time.Minute + + secret, err := client.Auth().Token().Create(&api.TokenCreateRequest{ + Policies: []string{"default"}, + TTL: ttl.String(), + ExplicitMaxTTL: ttl.String(), + }) + if err != nil { + t.Fatal(err) + } + token := secret.Auth.ClientToken + + secret, err = client.Auth().Token().Renew(token, 0) + if err != nil { + t.Fatal(err) + } + + tokenTTL, err := secret.TokenTTL() + if err != nil { + t.Fatal(err) + } + if tokenTTL == 0 || tokenTTL > ttl { + t.Errorf("expected %q to non-zero and less than %q", tokenTTL, ttl) + } + }) + + t.Run("token-renew-self", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + ttl := 30 * time.Minute + + secret, err := client.Auth().Token().Create(&api.TokenCreateRequest{ + Policies: []string{"default"}, + TTL: ttl.String(), + ExplicitMaxTTL: ttl.String(), + }) + if err != nil { + t.Fatal(err) + } + token := secret.Auth.ClientToken + + client.SetToken(token) + secret, err = client.Auth().Token().RenewSelf(0) + if err != nil { + t.Fatal(err) + } + + tokenTTL, err := secret.TokenTTL() + if err != nil { + t.Fatal(err) + } + if tokenTTL == 0 || tokenTTL > ttl { + t.Errorf("expected %q to non-zero and less than %q", tokenTTL, ttl) + } + }) +} diff --git a/builtin/credential/aws/backend.go b/builtin/credential/aws/backend.go index cb1c97017e..5019c33f5e 100644 --- a/builtin/credential/aws/backend.go +++ b/builtin/credential/aws/backend.go @@ -278,7 +278,7 @@ func getAnyRegionForAwsPartition(partitionId string) *endpoints.Region { } const backendHelp = ` -aws-ec2 auth backend takes in PKCS#7 signature of an AWS EC2 instance and a client +aws-ec2 auth method takes in PKCS#7 signature of an AWS EC2 instance and a client created nonce to authenticates the EC2 instance with Vault. Authentication is backed by a preconfigured role in the backend. The role diff --git a/builtin/credential/aws/cli.go b/builtin/credential/aws/cli.go index e139230be9..e6330ca128 100644 --- a/builtin/credential/aws/cli.go +++ b/builtin/credential/aws/cli.go @@ -113,29 +113,51 @@ func (h *CLIHandler) Auth(c *api.Client, m map[string]string) (*api.Secret, erro func (h *CLIHandler) Help() string { help := ` -The AWS credential provider allows you to authenticate with -AWS IAM credentials. To use it, you specify valid AWS IAM credentials -in one of a number of ways. They can be specified explicitly on the -command line (which in general you should not do), via the standard AWS -environment variables (AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, and -AWS_SECURITY_TOKEN), via the ~/.aws/credentials file, or via an EC2 -instance profile (in that order). +Usage: vault login -method=aws [CONFIG K=V...] - Example: vault auth -method=aws + The AWS auth method allows users to authenticate with AWS IAM + credentials. The AWS IAM credentials may be specified in a number of ways, + listed in order of precedence below: -If you need to explicitly pass in credentials, you would do it like this: - Example: vault auth -method=aws aws_access_key_id= aws_secret_access_key= aws_security_token= + 1. Explicitly via the command line (not recommended) -Key/Value Pairs: + 2. Via the standard AWS environment variables (AWS_ACCESS_KEY, etc.) - mount=aws The mountpoint for the AWS credential provider. - Defaults to "aws" - aws_access_key_id= Explicitly specified AWS access key - aws_secret_access_key= Explicitly specified AWS secret key - aws_security_token= Security token for temporary credentials - header_value The Value of the X-Vault-AWS-IAM-Server-ID header. - role The name of the role you're requesting a token for - ` + 3. Via the ~/.aws/credentials file + + 4. Via EC2 instance profile + + Authenticate using locally stored credentials: + + $ vault login -method=aws + + Authenticate by passing keys: + + $ vault login -method=aws aws_access_key_id=... aws_secret_access_key=... + +Configuration: + + aws_access_key_id= + Explicit AWS access key ID + + aws_secret_access_key= + Explicit AWS secret access key + + aws_security_token= + Explicit AWS security token for temporary credentials + + header_value= + Value for the x-vault-aws-iam-server-id header in requests + + mount= + Path where the AWS credential method is mounted. This is usually provided + via the -path flag in the "vault login" command, but it can be specified + here as well. If specified here, it takes precedence over the value for + -path. The default value is "aws". + + role= + Name of the role to request a token against +` return strings.TrimSpace(help) } diff --git a/builtin/credential/aws/path_config_client.go b/builtin/credential/aws/path_config_client.go index 558f6f31fc..64b3a5b8ef 100644 --- a/builtin/credential/aws/path_config_client.go +++ b/builtin/credential/aws/path_config_client.go @@ -261,7 +261,7 @@ Configure AWS IAM credentials that are used to query instance and role details f ` const pathConfigClientHelpDesc = ` -The aws-ec2 auth backend makes AWS API queries to retrieve information +The aws-ec2 auth method makes AWS API queries to retrieve information regarding EC2 instances that perform login operations. The 'aws_secret_key' and 'aws_access_key' parameters configured here should map to an AWS IAM user that has permission to make the following API queries: diff --git a/builtin/credential/cert/cli.go b/builtin/credential/cert/cli.go index a1071fcd38..3c45b855d1 100644 --- a/builtin/credential/cert/cli.go +++ b/builtin/credential/cert/cli.go @@ -40,17 +40,22 @@ func (h *CLIHandler) Auth(c *api.Client, m map[string]string) (*api.Secret, erro func (h *CLIHandler) Help() string { help := ` -The "cert" credential provider allows you to authenticate with a -client certificate. No other authentication materials are needed. -Optionally, you may specify the specific certificate role to -authenticate against with the "name" parameter. +Usage: vault login -method=cert [CONFIG K=V...] - Example: vault auth -method=cert \ - -client-cert=/path/to/cert.pem \ - -client-key=/path/to/key.pem - name=cert1 + The certificate auth method allows uers to authenticate with a + client certificate passed with the request. The -client-cert and -client-key + flags are included with the "vault login" command, NOT as configuration to the + auth method. - ` + Authenticate using a local client certificate: + + $ vault login -method=cert -client-cert=cert.pem -client-key=key.pem + +Configuration: + + name= + Certificate role to authenticate against. +` return strings.TrimSpace(help) } diff --git a/builtin/credential/github/cli.go b/builtin/credential/github/cli.go index 557939b209..595f20fcbd 100644 --- a/builtin/credential/github/cli.go +++ b/builtin/credential/github/cli.go @@ -2,13 +2,18 @@ package github import ( "fmt" + "io" "os" "strings" "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/helper/password" ) -type CLIHandler struct{} +type CLIHandler struct { + // for tests + testStdout io.Writer +} func (h *CLIHandler) Auth(c *api.Client, m map[string]string) (*api.Secret, error) { mount, ok := m["mount"] @@ -16,16 +21,39 @@ func (h *CLIHandler) Auth(c *api.Client, m map[string]string) (*api.Secret, erro mount = "github" } - token, ok := m["token"] - if !ok { - if token = os.Getenv("VAULT_AUTH_GITHUB_TOKEN"); token == "" { - return nil, fmt.Errorf("GitHub token should be provided either as 'value' for 'token' key,\nor via an env var VAULT_AUTH_GITHUB_TOKEN") + // Extract or prompt for token + token := m["token"] + if token == "" { + token = os.Getenv("VAULT_AUTH_GITHUB_TOKEN") + } + if token == "" { + // Override the output + stdout := h.testStdout + if stdout == nil { + stdout = os.Stdout + } + + var err error + fmt.Fprintf(stdout, "GitHub Personal Access Token (will be hidden): ") + token, err = password.Read(os.Stdin) + fmt.Fprintf(stdout, "\n") + if err != nil { + if err == password.ErrInterrupted { + return nil, fmt.Errorf("user interrupted") + } + + return nil, fmt.Errorf("An error occurred attempting to "+ + "ask for a token. The raw error message is shown below, but usually "+ + "this is because you attempted to pipe a value into the command or "+ + "you are executing outside of a terminal (tty). If you want to pipe "+ + "the value, pass \"-\" as the argument to read from stdin. The raw "+ + "error was: %s", err) } } path := fmt.Sprintf("auth/%s/login", mount) secret, err := c.Logical().Write(path, map[string]interface{}{ - "token": token, + "token": strings.TrimSpace(token), }) if err != nil { return nil, err @@ -39,20 +67,28 @@ func (h *CLIHandler) Auth(c *api.Client, m map[string]string) (*api.Secret, erro func (h *CLIHandler) Help() string { help := ` -The GitHub credential provider allows you to authenticate with GitHub. -To use it, specify the "token" parameter. The value should be a personal access -token for your GitHub account. You can generate a personal access token on your -account settings page on GitHub. +Usage: vault login -method=github [CONFIG K=V...] - Example: vault auth -method=github token= + The GitHub auth method allows users to authenticate using a GitHub + personal access token. Users can generate a personal access token from the + settings page on their GitHub account. -Key/Value Pairs: + Authenticate using a GitHub token: - mount=github The mountpoint for the GitHub credential provider. - Defaults to "github" + $ vault login -method=github token=abcd1234 - token= The GitHub personal access token for authentication. - ` +Configuration: + + mount= + Path where the GitHub credential method is mounted. This is usually + provided via the -path flag in the "vault login" command, but it can be + specified here as well. If specified here, it takes precedence over the + value for -path. The default value is "github". + + token= + GitHub personal access token to use for authentication. If not provided, + Vault will prompt for the value. +` return strings.TrimSpace(help) } diff --git a/builtin/credential/ldap/backend_test.go b/builtin/credential/ldap/backend_test.go index 83e241faab..0ba292047f 100644 --- a/builtin/credential/ldap/backend_test.go +++ b/builtin/credential/ldap/backend_test.go @@ -103,7 +103,7 @@ func TestLdapAuthBackend_UserPolicies(t *testing.T) { } /* - * Acceptance test for LDAP Auth Backend + * Acceptance test for LDAP Auth Method * * The tests here rely on a public LDAP server: * [http://www.forumsys.com/tutorials/integration-how-to/ldap/online-ldap-test-server/] diff --git a/builtin/credential/ldap/cli.go b/builtin/credential/ldap/cli.go index 262bc998e1..5ac19a727d 100644 --- a/builtin/credential/ldap/cli.go +++ b/builtin/credential/ldap/cli.go @@ -62,18 +62,40 @@ func (h *CLIHandler) Auth(c *api.Client, m map[string]string) (*api.Secret, erro func (h *CLIHandler) Help() string { help := ` -The LDAP credential provider allows you to authenticate with LDAP. -To use it, first configure it through the "config" endpoint, and then -login by specifying username and password. If password is not provided -on the command line, it will be read from stdin. +Usage: vault login -method=ldap [CONFIG K=V...] -If multi-factor authentication (MFA) is enabled, a "method" and/or "passcode" -may be provided depending on the MFA backend enabled. To check -which MFA backend is in use, read "auth/[mount]/mfa_config". + The LDAP auth method allows users to authenticate using LDAP or + Active Directory. - Example: vault auth -method=ldap username=john + If MFA is enabled, a "method" and/or "passcode" may be required depending on + the MFA method. To check which MFA is in use, run: - ` + $ vault read auth//mfa_config + + Authenticate as "sally": + + $ vault login -method=ldap username=sally + Password (will be hidden): + + Authenticate as "bob": + + $ vault login -method=ldap username=bob password=password + +Configuration: + + method= + MFA method. + + passcode= + MFA OTP/passcode. + + password= + LDAP password to use for authentication. If not provided, the CLI will + prompt for this on stdin. + + username= + LDAP username to use for authentication. +` return strings.TrimSpace(help) } diff --git a/builtin/credential/okta/backend.go b/builtin/credential/okta/backend.go index ad96013a77..45eb341188 100644 --- a/builtin/credential/okta/backend.go +++ b/builtin/credential/okta/backend.go @@ -60,7 +60,7 @@ func (b *backend) Login(req *logical.Request, username string, password string) return nil, nil, nil, err } if cfg == nil { - return nil, logical.ErrorResponse("Okta backend not configured"), nil, nil + return nil, logical.ErrorResponse("Okta auth method not configured"), nil, nil } client := cfg.OktaClient() @@ -87,7 +87,7 @@ func (b *backend) Login(req *logical.Request, username string, password string) return nil, logical.ErrorResponse(fmt.Sprintf("Okta auth failed: %v", err)), nil, nil } if rsp == nil { - return nil, logical.ErrorResponse("okta auth backend unexpected failure"), nil, nil + return nil, logical.ErrorResponse("okta auth method unexpected failure"), nil, nil } oktaResponse := &logical.Response{ @@ -161,7 +161,7 @@ func (b *backend) getOktaGroups(client *okta.Client, user *okta.User) ([]string, return nil, err } if rsp == nil { - return nil, fmt.Errorf("okta auth backend unexpected failure") + return nil, fmt.Errorf("okta auth method unexpected failure") } oktaGroups := make([]string, 0, len(user.Groups)) for _, group := range user.Groups { diff --git a/builtin/credential/okta/cli.go b/builtin/credential/okta/cli.go index 8939e2a19f..110b184e92 100644 --- a/builtin/credential/okta/cli.go +++ b/builtin/credential/okta/cli.go @@ -62,14 +62,28 @@ func (h *CLIHandler) Auth(c *api.Client, m map[string]string) (*api.Secret, erro // Help method for okta cli func (h *CLIHandler) Help() string { help := ` -The Okta credential provider allows you to authenticate with Okta. -To use it, first configure it through the "config" endpoint, and then -login by specifying username and password. If password is not provided -on the command line, it will be read from stdin. +Usage: vault login -method=okta [CONFIG K=V...] - Example: vault auth -method=okta username=john + The Okta auth method allows users to authenticate using Okta. - ` + Authenticate as "sally": + + $ vault login -method=okta username=sally + Password (will be hidden): + + Authenticate as "bob": + + $ vault login -method=okta username=bob password=password + +Configuration: + + password= + Okta password to use for authentication. If not provided, the CLI will + prompt for this on stdin. + + username= + Okta username to use for authentication. +` return strings.TrimSpace(help) } diff --git a/builtin/credential/radius/path_config.go b/builtin/credential/radius/path_config.go index 1a0870366c..573788b498 100644 --- a/builtin/credential/radius/path_config.go +++ b/builtin/credential/radius/path_config.go @@ -156,7 +156,7 @@ func (b *backend) pathConfigCreateUpdate(ctx context.Context, req *logical.Reque policies = strings.Split(unregisteredUserPoliciesStr, ",") for _, policy := range policies { if policy == "root" { - return logical.ErrorResponse("root policy cannot be granted by an authentication backend"), nil + return logical.ErrorResponse("root policy cannot be granted by an auth method"), nil } } } diff --git a/builtin/credential/radius/path_users.go b/builtin/credential/radius/path_users.go index 4a7ec0c734..88c2101de0 100644 --- a/builtin/credential/radius/path_users.go +++ b/builtin/credential/radius/path_users.go @@ -112,7 +112,7 @@ func (b *backend) pathUserWrite(ctx context.Context, req *logical.Request, d *fr var policies = policyutil.ParsePolicies(d.Get("policies")) for _, policy := range policies { if policy == "root" { - return logical.ErrorResponse("root policy cannot be granted by an authentication backend"), nil + return logical.ErrorResponse("root policy cannot be granted by an auth method"), nil } } diff --git a/builtin/credential/token/cli.go b/builtin/credential/token/cli.go new file mode 100644 index 0000000000..b300128c28 --- /dev/null +++ b/builtin/credential/token/cli.go @@ -0,0 +1,166 @@ +package token + +import ( + "fmt" + "io" + "os" + "strconv" + "strings" + + "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/helper/password" +) + +type CLIHandler struct { + // for tests + testStdin io.Reader + testStdout io.Writer +} + +func (h *CLIHandler) Auth(c *api.Client, m map[string]string) (*api.Secret, error) { + // Parse "lookup" first - we want to return an early error if the user + // supplied an invalid value here before we prompt them for a token. It would + // be annoying to type your token and then be told you supplied an invalid + // value that we could have known in advance. + lookup := true + if x, ok := m["lookup"]; ok { + parsed, err := strconv.ParseBool(x) + if err != nil { + return nil, fmt.Errorf("Failed to parse \"lookup\" as boolean: %s", err) + } + lookup = parsed + } + + // Parse the token. + token, ok := m["token"] + if !ok { + // Override the output + stdout := h.testStdout + if stdout == nil { + stdout = os.Stdout + } + + // No arguments given, read the token from user input + fmt.Fprintf(stdout, "Token (will be hidden): ") + var err error + token, err = password.Read(os.Stdin) + fmt.Fprintf(stdout, "\n") + + if err != nil { + if err == password.ErrInterrupted { + return nil, fmt.Errorf("user interrupted") + } + + return nil, fmt.Errorf("An error occurred attempting to "+ + "ask for a token. The raw error message is shown below, but usually "+ + "this is because you attempted to pipe a value into the command or "+ + "you are executing outside of a terminal (tty). If you want to pipe "+ + "the value, pass \"-\" as the argument to read from stdin. The raw "+ + "error was: %s", err) + } + } + + // Remove any whitespace, etc. + token = strings.TrimSpace(token) + + if token == "" { + return nil, fmt.Errorf( + "A token must be passed to auth. Please view the help for more " + + "information.") + } + + // If the user declined verification, return now. Note that we will not have + // a lot of information about the token. + if !lookup { + return &api.Secret{ + Auth: &api.SecretAuth{ + ClientToken: token, + }, + }, nil + } + + // If we got this far, we want to lookup and lookup the token and pull it's + // list of policies an metadata. + c.SetToken(token) + c.SetWrappingLookupFunc(func(string, string) string { return "" }) + + secret, err := c.Auth().Token().LookupSelf() + if err != nil { + return nil, fmt.Errorf("Error looking up token: %s", err) + } + if secret == nil { + return nil, fmt.Errorf("Empty response from lookup-self") + } + + // Return an auth struct that "looks" like the response from an auth method. + // lookup and lookup-self return their data in data, not auth. We try to + // mirror that data here. + id, err := secret.TokenID() + if err != nil { + return nil, fmt.Errorf("Error accessing token ID: %s", err) + } + accessor, err := secret.TokenAccessor() + if err != nil { + return nil, fmt.Errorf("Error accessing token accessor: %s", err) + } + policies, err := secret.TokenPolicies() + if err != nil { + return nil, fmt.Errorf("Error accessing token policies: %s", err) + } + metadata, err := secret.TokenMetadata() + if err != nil { + return nil, fmt.Errorf("Error accessing token metadata: %s", err) + } + dur, err := secret.TokenTTL() + if err != nil { + return nil, fmt.Errorf("Error converting token TTL: %s", err) + } + renewable, err := secret.TokenIsRenewable() + if err != nil { + return nil, fmt.Errorf("Error checking if token is renewable: %s", err) + } + return &api.Secret{ + Auth: &api.SecretAuth{ + ClientToken: id, + Accessor: accessor, + Policies: policies, + Metadata: metadata, + + LeaseDuration: int(dur.Seconds()), + Renewable: renewable, + }, + }, nil + +} + +func (h *CLIHandler) Help() string { + help := ` +Usage: vault login TOKEN [CONFIG K=V...] + + The token auth method allows logging in directly with a token. This + can be a token from the "token-create" command or API. There are no + configuration options for this auth method. + + Authenticate using a token: + + $ vault login 96ddf4bc-d217-f3ba-f9bd-017055595017 + + Authenticate but do not lookup information about the token: + + $ vault login token=96ddf4bc-d217-f3ba-f9bd-017055595017 lookup=false + + This token usually comes from a different source such as the API or via the + built-in "vault token-create" command. + +Configuration: + + token= + The token to use for authentication. This is usually provided directly + via the "vault login" command. + + lookup= + Perform a lookup of the token's metadata and policies. +` + + return strings.TrimSpace(help) +} diff --git a/builtin/credential/userpass/cli.go b/builtin/credential/userpass/cli.go index 4433c0e706..4984cb4602 100644 --- a/builtin/credential/userpass/cli.go +++ b/builtin/credential/userpass/cli.go @@ -66,20 +66,40 @@ func (h *CLIHandler) Auth(c *api.Client, m map[string]string) (*api.Secret, erro func (h *CLIHandler) Help() string { help := ` -The "userpass"/"radius" credential provider allows you to authenticate with -a username and password. To use it, specify the "username" and "password" -parameters. If password is not provided on the command line, it will be -read from stdin. +Usage: vault login -method=userpass [CONFIG K=V...] -If multi-factor authentication (MFA) is enabled, a "method" and/or "passcode" -may be provided depending on the MFA backend enabled. To check -which MFA backend is in use, read "auth/[mount]/mfa_config". + The userpass auth method allows users to authenticate using Vault's + internal user database. - Example: vault auth -method=userpass \ - username= \ - password= + If MFA is enabled, a "method" and/or "passcode" may be required depending on + the MFA method. To check which MFA is in use, run: - ` + $ vault read auth//mfa_config + + Authenticate as "sally": + + $ vault login -method=userpass username=sally + Password (will be hidden): + + Authenticate as "bob": + + $ vault login -method=userpass username=bob password=password + +Configuration: + + method= + MFA method. + + passcode= + MFA OTP/passcode. + + password= + Password to use for authentication. If not provided, the CLI will prompt + for this on stdin. + + username= + Username to use for authentication. +` return strings.TrimSpace(help) } diff --git a/cli/commands.go b/cli/commands.go deleted file mode 100644 index 83e50b6ed2..0000000000 --- a/cli/commands.go +++ /dev/null @@ -1,391 +0,0 @@ -package cli - -import ( - "os" - - auditFile "github.com/hashicorp/vault/builtin/audit/file" - auditSocket "github.com/hashicorp/vault/builtin/audit/socket" - auditSyslog "github.com/hashicorp/vault/builtin/audit/syslog" - "github.com/hashicorp/vault/physical" - "github.com/hashicorp/vault/version" - - credGcp "github.com/hashicorp/vault-plugin-auth-gcp/plugin" - credKube "github.com/hashicorp/vault-plugin-auth-kubernetes" - credAppId "github.com/hashicorp/vault/builtin/credential/app-id" - credAppRole "github.com/hashicorp/vault/builtin/credential/approle" - credAws "github.com/hashicorp/vault/builtin/credential/aws" - credCert "github.com/hashicorp/vault/builtin/credential/cert" - credGitHub "github.com/hashicorp/vault/builtin/credential/github" - credLdap "github.com/hashicorp/vault/builtin/credential/ldap" - credOkta "github.com/hashicorp/vault/builtin/credential/okta" - credRadius "github.com/hashicorp/vault/builtin/credential/radius" - credUserpass "github.com/hashicorp/vault/builtin/credential/userpass" - - physAzure "github.com/hashicorp/vault/physical/azure" - physCassandra "github.com/hashicorp/vault/physical/cassandra" - physCockroachDB "github.com/hashicorp/vault/physical/cockroachdb" - physConsul "github.com/hashicorp/vault/physical/consul" - physCouchDB "github.com/hashicorp/vault/physical/couchdb" - physDynamoDB "github.com/hashicorp/vault/physical/dynamodb" - physEtcd "github.com/hashicorp/vault/physical/etcd" - physFile "github.com/hashicorp/vault/physical/file" - physGCS "github.com/hashicorp/vault/physical/gcs" - physInmem "github.com/hashicorp/vault/physical/inmem" - physMSSQL "github.com/hashicorp/vault/physical/mssql" - physMySQL "github.com/hashicorp/vault/physical/mysql" - physPostgreSQL "github.com/hashicorp/vault/physical/postgresql" - physS3 "github.com/hashicorp/vault/physical/s3" - physSwift "github.com/hashicorp/vault/physical/swift" - physZooKeeper "github.com/hashicorp/vault/physical/zookeeper" - - "github.com/hashicorp/vault/builtin/logical/aws" - "github.com/hashicorp/vault/builtin/logical/cassandra" - "github.com/hashicorp/vault/builtin/logical/consul" - "github.com/hashicorp/vault/builtin/logical/database" - "github.com/hashicorp/vault/builtin/logical/mongodb" - "github.com/hashicorp/vault/builtin/logical/mssql" - "github.com/hashicorp/vault/builtin/logical/mysql" - "github.com/hashicorp/vault/builtin/logical/nomad" - "github.com/hashicorp/vault/builtin/logical/pki" - "github.com/hashicorp/vault/builtin/logical/postgresql" - "github.com/hashicorp/vault/builtin/logical/rabbitmq" - "github.com/hashicorp/vault/builtin/logical/ssh" - "github.com/hashicorp/vault/builtin/logical/totp" - "github.com/hashicorp/vault/builtin/logical/transit" - "github.com/hashicorp/vault/builtin/plugin" - - "github.com/hashicorp/vault/audit" - "github.com/hashicorp/vault/command" - "github.com/hashicorp/vault/logical" - "github.com/hashicorp/vault/meta" - "github.com/mitchellh/cli" -) - -// Commands returns the mapping of CLI commands for Vault. The meta -// parameter lets you set meta options for all commands. -func Commands(metaPtr *meta.Meta) map[string]cli.CommandFactory { - if metaPtr == nil { - metaPtr = &meta.Meta{ - TokenHelper: command.DefaultTokenHelper, - } - } - - if metaPtr.Ui == nil { - metaPtr.Ui = &cli.BasicUi{ - Writer: os.Stdout, - ErrorWriter: os.Stderr, - } - } - - return map[string]cli.CommandFactory{ - "init": func() (cli.Command, error) { - return &command.InitCommand{ - Meta: *metaPtr, - }, nil - }, - "server": func() (cli.Command, error) { - c := &command.ServerCommand{ - Meta: *metaPtr, - AuditBackends: map[string]audit.Factory{ - "file": auditFile.Factory, - "syslog": auditSyslog.Factory, - "socket": auditSocket.Factory, - }, - CredentialBackends: map[string]logical.Factory{ - "approle": credAppRole.Factory, - "cert": credCert.Factory, - "aws": credAws.Factory, - "app-id": credAppId.Factory, - "gcp": credGcp.Factory, - "github": credGitHub.Factory, - "userpass": credUserpass.Factory, - "ldap": credLdap.Factory, - "okta": credOkta.Factory, - "radius": credRadius.Factory, - "kubernetes": credKube.Factory, - "plugin": plugin.Factory, - }, - LogicalBackends: map[string]logical.Factory{ - "aws": aws.Factory, - "consul": consul.Factory, - "nomad": nomad.Factory, - "postgresql": postgresql.Factory, - "cassandra": cassandra.Factory, - "pki": pki.Factory, - "transit": transit.Factory, - "mongodb": mongodb.Factory, - "mssql": mssql.Factory, - "mysql": mysql.Factory, - "ssh": ssh.Factory, - "rabbitmq": rabbitmq.Factory, - "database": database.Factory, - "totp": totp.Factory, - "plugin": plugin.Factory, - }, - - ShutdownCh: command.MakeShutdownCh(), - SighupCh: command.MakeSighupCh(), - } - - c.PhysicalBackends = map[string]physical.Factory{ - "azure": physAzure.NewAzureBackend, - "cassandra": physCassandra.NewCassandraBackend, - "cockroachdb": physCockroachDB.NewCockroachDBBackend, - "consul": physConsul.NewConsulBackend, - "couchdb": physCouchDB.NewCouchDBBackend, - "couchdb_transactional": physCouchDB.NewTransactionalCouchDBBackend, - "dynamodb": physDynamoDB.NewDynamoDBBackend, - "etcd": physEtcd.NewEtcdBackend, - "file": physFile.NewFileBackend, - "file_transactional": physFile.NewTransactionalFileBackend, - "gcs": physGCS.NewGCSBackend, - "inmem": physInmem.NewInmem, - "inmem_ha": physInmem.NewInmemHA, - "inmem_transactional": physInmem.NewTransactionalInmem, - "inmem_transactional_ha": physInmem.NewTransactionalInmemHA, - "mssql": physMSSQL.NewMSSQLBackend, - "mysql": physMySQL.NewMySQLBackend, - "postgresql": physPostgreSQL.NewPostgreSQLBackend, - "s3": physS3.NewS3Backend, - "swift": physSwift.NewSwiftBackend, - "zookeeper": physZooKeeper.NewZooKeeperBackend, - } - - return c, nil - }, - - "ssh": func() (cli.Command, error) { - return &command.SSHCommand{ - Meta: *metaPtr, - }, nil - }, - - "path-help": func() (cli.Command, error) { - return &command.PathHelpCommand{ - Meta: *metaPtr, - }, nil - }, - - "auth": func() (cli.Command, error) { - return &command.AuthCommand{ - Meta: *metaPtr, - Handlers: map[string]command.AuthHandler{ - "github": &credGitHub.CLIHandler{}, - "userpass": &credUserpass.CLIHandler{DefaultMount: "userpass"}, - "ldap": &credLdap.CLIHandler{}, - "okta": &credOkta.CLIHandler{}, - "cert": &credCert.CLIHandler{}, - "aws": &credAws.CLIHandler{}, - "radius": &credUserpass.CLIHandler{DefaultMount: "radius"}, - }, - }, nil - }, - - "auth-enable": func() (cli.Command, error) { - return &command.AuthEnableCommand{ - Meta: *metaPtr, - }, nil - }, - - "auth-disable": func() (cli.Command, error) { - return &command.AuthDisableCommand{ - Meta: *metaPtr, - }, nil - }, - - "audit-list": func() (cli.Command, error) { - return &command.AuditListCommand{ - Meta: *metaPtr, - }, nil - }, - - "audit-disable": func() (cli.Command, error) { - return &command.AuditDisableCommand{ - Meta: *metaPtr, - }, nil - }, - - "audit-enable": func() (cli.Command, error) { - return &command.AuditEnableCommand{ - Meta: *metaPtr, - }, nil - }, - - "key-status": func() (cli.Command, error) { - return &command.KeyStatusCommand{ - Meta: *metaPtr, - }, nil - }, - - "policies": func() (cli.Command, error) { - return &command.PolicyListCommand{ - Meta: *metaPtr, - }, nil - }, - - "policy-delete": func() (cli.Command, error) { - return &command.PolicyDeleteCommand{ - Meta: *metaPtr, - }, nil - }, - - "policy-write": func() (cli.Command, error) { - return &command.PolicyWriteCommand{ - Meta: *metaPtr, - }, nil - }, - - "read": func() (cli.Command, error) { - return &command.ReadCommand{ - Meta: *metaPtr, - }, nil - }, - - "unwrap": func() (cli.Command, error) { - return &command.UnwrapCommand{ - Meta: *metaPtr, - }, nil - }, - - "list": func() (cli.Command, error) { - return &command.ListCommand{ - Meta: *metaPtr, - }, nil - }, - - "write": func() (cli.Command, error) { - return &command.WriteCommand{ - Meta: *metaPtr, - }, nil - }, - - "delete": func() (cli.Command, error) { - return &command.DeleteCommand{ - Meta: *metaPtr, - }, nil - }, - - "rekey": func() (cli.Command, error) { - return &command.RekeyCommand{ - Meta: *metaPtr, - }, nil - }, - - "generate-root": func() (cli.Command, error) { - return &command.GenerateRootCommand{ - Meta: *metaPtr, - }, nil - }, - - "renew": func() (cli.Command, error) { - return &command.RenewCommand{ - Meta: *metaPtr, - }, nil - }, - - "revoke": func() (cli.Command, error) { - return &command.RevokeCommand{ - Meta: *metaPtr, - }, nil - }, - - "seal": func() (cli.Command, error) { - return &command.SealCommand{ - Meta: *metaPtr, - }, nil - }, - - "status": func() (cli.Command, error) { - return &command.StatusCommand{ - Meta: *metaPtr, - }, nil - }, - - "unseal": func() (cli.Command, error) { - return &command.UnsealCommand{ - Meta: *metaPtr, - }, nil - }, - - "step-down": func() (cli.Command, error) { - return &command.StepDownCommand{ - Meta: *metaPtr, - }, nil - }, - - "mount": func() (cli.Command, error) { - return &command.MountCommand{ - Meta: *metaPtr, - }, nil - }, - - "mounts": func() (cli.Command, error) { - return &command.MountsCommand{ - Meta: *metaPtr, - }, nil - }, - - "mount-tune": func() (cli.Command, error) { - return &command.MountTuneCommand{ - Meta: *metaPtr, - }, nil - }, - - "remount": func() (cli.Command, error) { - return &command.RemountCommand{ - Meta: *metaPtr, - }, nil - }, - - "rotate": func() (cli.Command, error) { - return &command.RotateCommand{ - Meta: *metaPtr, - }, nil - }, - - "unmount": func() (cli.Command, error) { - return &command.UnmountCommand{ - Meta: *metaPtr, - }, nil - }, - - "token-create": func() (cli.Command, error) { - return &command.TokenCreateCommand{ - Meta: *metaPtr, - }, nil - }, - - "token-lookup": func() (cli.Command, error) { - return &command.TokenLookupCommand{ - Meta: *metaPtr, - }, nil - }, - - "token-renew": func() (cli.Command, error) { - return &command.TokenRenewCommand{ - Meta: *metaPtr, - }, nil - }, - - "token-revoke": func() (cli.Command, error) { - return &command.TokenRevokeCommand{ - Meta: *metaPtr, - }, nil - }, - - "capabilities": func() (cli.Command, error) { - return &command.CapabilitiesCommand{ - Meta: *metaPtr, - }, nil - }, - - "version": func() (cli.Command, error) { - versionInfo := version.GetVersion() - - return &command.VersionCommand{ - VersionInfo: versionInfo, - Ui: metaPtr.Ui, - }, nil - }, - } -} diff --git a/cli/help.go b/cli/help.go deleted file mode 100644 index bd66e335a3..0000000000 --- a/cli/help.go +++ /dev/null @@ -1,82 +0,0 @@ -package cli - -import ( - "bytes" - "fmt" - "sort" - "strings" - - "github.com/mitchellh/cli" -) - -// HelpFunc is a cli.HelpFunc that can is used to output the help for Vault. -func HelpFunc(commands map[string]cli.CommandFactory) string { - commonNames := map[string]struct{}{ - "delete": struct{}{}, - "path-help": struct{}{}, - "read": struct{}{}, - "renew": struct{}{}, - "revoke": struct{}{}, - "write": struct{}{}, - "server": struct{}{}, - "status": struct{}{}, - "unwrap": struct{}{}, - } - - // Determine the maximum key length, and classify based on type - commonCommands := make(map[string]cli.CommandFactory) - otherCommands := make(map[string]cli.CommandFactory) - maxKeyLen := 0 - for key, f := range commands { - if len(key) > maxKeyLen { - maxKeyLen = len(key) - } - - if _, ok := commonNames[key]; ok { - commonCommands[key] = f - } else { - otherCommands[key] = f - } - } - - var buf bytes.Buffer - buf.WriteString("usage: vault [-version] [-help] [args]\n\n") - buf.WriteString("Common commands:\n") - buf.WriteString(listCommands(commonCommands, maxKeyLen)) - buf.WriteString("\nAll other commands:\n") - buf.WriteString(listCommands(otherCommands, maxKeyLen)) - return buf.String() -} - -// listCommands just lists the commands in the map with the -// given maximum key length. -func listCommands(commands map[string]cli.CommandFactory, maxKeyLen int) string { - var buf bytes.Buffer - - // Get the list of keys so we can sort them, and also get the maximum - // key length so they can be aligned properly. - keys := make([]string, 0, len(commands)) - for key, _ := range commands { - keys = append(keys, key) - } - sort.Strings(keys) - - for _, key := range keys { - commandFunc, ok := commands[key] - if !ok { - // This should never happen since we JUST built the list of - // keys. - panic("command not found: " + key) - } - - command, err := commandFunc() - if err != nil { - panic(fmt.Sprintf("command '%s' failed to load: %s", key, err)) - } - - key = fmt.Sprintf("%s%s", key, strings.Repeat(" ", maxKeyLen-len(key))) - buf.WriteString(fmt.Sprintf(" %s %s\n", key, command.Synopsis())) - } - - return buf.String() -} diff --git a/cli/main.go b/cli/main.go deleted file mode 100644 index 000e1e9a4e..0000000000 --- a/cli/main.go +++ /dev/null @@ -1,53 +0,0 @@ -package cli - -import ( - "fmt" - "os" - - "github.com/mitchellh/cli" -) - -func Run(args []string) int { - return RunCustom(args, Commands(nil)) -} - -func RunCustom(args []string, commands map[string]cli.CommandFactory) int { - // Get the command line args. We shortcut "--version" and "-v" to - // just show the version. - for _, arg := range args { - if arg == "-v" || arg == "-version" || arg == "--version" { - newArgs := make([]string, len(args)+1) - newArgs[0] = "version" - copy(newArgs[1:], args) - args = newArgs - break - } - } - - // Build the commands to include in the help now. This is pretty... - // tedious, but we don't have a better way at the moment. - commandsInclude := make([]string, 0, len(commands)) - for k, _ := range commands { - switch k { - case "token-disk": - default: - commandsInclude = append(commandsInclude, k) - } - } - - cli := &cli.CLI{ - Args: args, - Commands: commands, - Name: "vault", - Autocomplete: true, - HelpFunc: cli.FilteredHelpFunc(commandsInclude, HelpFunc), - } - - exitCode, err := cli.Run() - if err != nil { - fmt.Fprintf(os.Stderr, "Error executing CLI: %s\n", err.Error()) - return 1 - } - - return exitCode -} diff --git a/command/audit.go b/command/audit.go new file mode 100644 index 0000000000..0e59357794 --- /dev/null +++ b/command/audit.go @@ -0,0 +1,42 @@ +package command + +import ( + "strings" + + "github.com/mitchellh/cli" +) + +var _ cli.Command = (*AuditCommand)(nil) + +type AuditCommand struct { + *BaseCommand +} + +func (c *AuditCommand) Synopsis() string { + return "Interact with audit devices" +} + +func (c *AuditCommand) Help() string { + helpText := ` +Usage: vault audit [options] [args] + + This command groups subcommands for interacting with Vault's audit devices. + Users can list, enable, and disable audit devices. + + List all enabled audit devices: + + $ vault audit list + + Enable a new audit device "userpass"; + + $ vault audit enable file file_path=/var/log/audit.log + + Please see the individual subcommand help for detailed usage information. +` + + return strings.TrimSpace(helpText) +} + +func (c *AuditCommand) Run(args []string) int { + return cli.RunResultHelp +} diff --git a/command/audit_disable.go b/command/audit_disable.go index 31c4457287..1025a0ba27 100644 --- a/command/audit_disable.go +++ b/command/audit_disable.go @@ -4,68 +4,84 @@ import ( "fmt" "strings" - "github.com/hashicorp/vault/meta" + "github.com/mitchellh/cli" + "github.com/posener/complete" ) -// AuditDisableCommand is a Command that mounts a new mount. +var _ cli.Command = (*AuditDisableCommand)(nil) +var _ cli.CommandAutocomplete = (*AuditDisableCommand)(nil) + type AuditDisableCommand struct { - meta.Meta -} - -func (c *AuditDisableCommand) Run(args []string) int { - flags := c.Meta.FlagSet("mount", meta.FlagSetDefault) - flags.Usage = func() { c.Ui.Error(c.Help()) } - if err := flags.Parse(args); err != nil { - return 1 - } - - args = flags.Args() - if len(args) != 1 { - flags.Usage() - c.Ui.Error(fmt.Sprintf( - "\naudit-disable expects one argument: the id to disable")) - return 1 - } - - id := args[0] - - client, err := c.Client() - if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error initializing client: %s", err)) - return 2 - } - - if err := client.Sys().DisableAudit(id); err != nil { - c.Ui.Error(fmt.Sprintf( - "Error disabling audit backend: %s", err)) - return 2 - } - - c.Ui.Output(fmt.Sprintf( - "Successfully disabled audit backend '%s' if it was enabled", id)) - - return 0 + *BaseCommand } func (c *AuditDisableCommand) Synopsis() string { - return "Disable an audit backend" + return "Disables an audit device" } func (c *AuditDisableCommand) Help() string { helpText := ` -Usage: vault audit-disable [options] id +Usage: vault audit disable [options] PATH - Disable an audit backend. + Disables an audit device. Once an audit device is disabled, no future audit + logs are dispatched to it. The data associated with the audit device is not + affected. - Once the audit backend is disabled no more audit logs will be sent to - it. The data associated with the audit backend isn't affected. + The argument corresponds to the PATH of audit device, not the TYPE! - The "id" parameter should map to the "path" used in "audit-enable". If - no path was provided to "audit-enable" you should use the backend - type (e.g. "file"). + Disable the audit device enabled at "file/": + + $ vault audit disable file/ + +` + c.Flags().Help() -General Options: -` + meta.GeneralOptionsUsage() return strings.TrimSpace(helpText) } + +func (c *AuditDisableCommand) Flags() *FlagSets { + return c.flagSet(FlagSetHTTP) +} + +func (c *AuditDisableCommand) AutocompleteArgs() complete.Predictor { + return c.PredictVaultAudits() +} + +func (c *AuditDisableCommand) AutocompleteFlags() complete.Flags { + return c.Flags().Completions() +} + +func (c *AuditDisableCommand) Run(args []string) int { + f := c.Flags() + + if err := f.Parse(args); err != nil { + c.UI.Error(err.Error()) + return 1 + } + + args = f.Args() + switch { + case len(args) < 1: + c.UI.Error(fmt.Sprintf("Not enough arguments (expected 1, got %d)", len(args))) + return 1 + case len(args) > 1: + c.UI.Error(fmt.Sprintf("Too many arguments (expected 1, got %d)", len(args))) + return 1 + } + + path := ensureTrailingSlash(sanitizePath(args[0])) + + client, err := c.Client() + if err != nil { + c.UI.Error(err.Error()) + return 2 + } + + if err := client.Sys().DisableAudit(path); err != nil { + c.UI.Error(fmt.Sprintf("Error disabling audit device: %s", err)) + return 2 + } + + c.UI.Output(fmt.Sprintf("Success! Disabled audit device (if it was enabled) at: %s", path)) + + return 0 +} diff --git a/command/audit_disable_test.go b/command/audit_disable_test.go index 500ee9ccb1..0a7e8e4dcd 100644 --- a/command/audit_disable_test.go +++ b/command/audit_disable_test.go @@ -1,86 +1,160 @@ package command import ( + "strings" "testing" "github.com/hashicorp/vault/api" - "github.com/hashicorp/vault/http" - "github.com/hashicorp/vault/meta" - "github.com/hashicorp/vault/vault" "github.com/mitchellh/cli" ) -func TestAuditDisable(t *testing.T) { - core, _, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() +func testAuditDisableCommand(tb testing.TB) (*cli.MockUi, *AuditDisableCommand) { + tb.Helper() - ui := new(cli.MockUi) - c := &AuditDisableCommand{ - Meta: meta.Meta{ - ClientToken: token, - Ui: ui, + ui := cli.NewMockUi() + return ui, &AuditDisableCommand{ + BaseCommand: &BaseCommand{ + UI: ui, }, } - - args := []string{ - "-address", addr, - "noop", - } - - // Run once to get the client - c.Run(args) - - // Get the client - client, err := c.Client() - if err != nil { - t.Fatalf("err: %#v", err) - } - if err := client.Sys().EnableAudit("noop", "noop", "", nil); err != nil { - t.Fatalf("err: %#v", err) - } - - // Run again - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } } -func TestAuditDisableWithOptions(t *testing.T) { - core, _, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() +func TestAuditDisableCommand_Run(t *testing.T) { + t.Parallel() - ui := new(cli.MockUi) - c := &AuditDisableCommand{ - Meta: meta.Meta{ - ClientToken: token, - Ui: ui, + cases := []struct { + name string + args []string + out string + code int + }{ + { + "not_enough_args", + nil, + "Not enough arguments", + 1, + }, + { + "too_many_args", + []string{"foo", "bar", "baz"}, + "Too many arguments", + 1, + }, + { + "not_real", + []string{"not_real"}, + "Success! Disabled audit device (if it was enabled) at: not_real/", + 0, + }, + { + "default", + []string{"file"}, + "Success! Disabled audit device (if it was enabled) at: file/", + 0, }, } - args := []string{ - "-address", addr, - "noop", + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + if err := client.Sys().EnableAuditWithOptions("file", &api.EnableAuditOptions{ + Type: "file", + Options: map[string]string{ + "file_path": "discard", + }, + }); err != nil { + t.Fatal(err) + } + + ui, cmd := testAuditDisableCommand(t) + cmd.client = client + + code := cmd.Run(tc.args) + if code != tc.code { + t.Errorf("expected %d to be %d", code, tc.code) + } + + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, tc.out) { + t.Errorf("expected %q to contain %q", combined, tc.out) + } + }) } - // Run once to get the client - c.Run(args) + t.Run("integration", func(t *testing.T) { + t.Parallel() - // Get the client - client, err := c.Client() - if err != nil { - t.Fatalf("err: %#v", err) - } - if err := client.Sys().EnableAuditWithOptions("noop", &api.EnableAuditOptions{ - Type: "noop", - Description: "noop", - }); err != nil { - t.Fatalf("err: %#v", err) - } + client, closer := testVaultServer(t) + defer closer() - // Run again - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } + if err := client.Sys().EnableAuditWithOptions("integration_audit_disable", &api.EnableAuditOptions{ + Type: "file", + Options: map[string]string{ + "file_path": "discard", + }, + }); err != nil { + t.Fatal(err) + } + + ui, cmd := testAuditDisableCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "integration_audit_disable/", + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Success! Disabled audit device (if it was enabled) at: integration_audit_disable/" + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + + mounts, err := client.Sys().ListMounts() + if err != nil { + t.Fatal(err) + } + + if _, ok := mounts["integration_audit_disable"]; ok { + t.Errorf("expected mount to not exist: %#v", mounts) + } + }) + + t.Run("communication_failure", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServerBad(t) + defer closer() + + ui, cmd := testAuditDisableCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "file", + }) + if exp := 2; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Error disabling audit device: " + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + }) + + t.Run("no_tabs", func(t *testing.T) { + t.Parallel() + + _, cmd := testAuditDisableCommand(t) + assertNoTabs(t, cmd) + }) } diff --git a/command/audit_enable.go b/command/audit_enable.go index 680a94ed19..85b3bac9aa 100644 --- a/command/audit_enable.go +++ b/command/audit_enable.go @@ -7,128 +7,85 @@ import ( "strings" "github.com/hashicorp/vault/api" - "github.com/hashicorp/vault/helper/kv-builder" - "github.com/hashicorp/vault/meta" - "github.com/mitchellh/mapstructure" + "github.com/mitchellh/cli" "github.com/posener/complete" ) -// AuditEnableCommand is a Command that mounts a new mount. +var _ cli.Command = (*AuditEnableCommand)(nil) +var _ cli.CommandAutocomplete = (*AuditEnableCommand)(nil) + type AuditEnableCommand struct { - meta.Meta + *BaseCommand - // A test stdin that can be used for tests - testStdin io.Reader -} + flagDescription string + flagPath string + flagLocal bool -func (c *AuditEnableCommand) Run(args []string) int { - var desc, path string - var local bool - flags := c.Meta.FlagSet("audit-enable", meta.FlagSetDefault) - flags.StringVar(&desc, "description", "", "") - flags.StringVar(&path, "path", "", "") - flags.BoolVar(&local, "local", false, "") - flags.Usage = func() { c.Ui.Error(c.Help()) } - if err := flags.Parse(args); err != nil { - return 1 - } - - args = flags.Args() - if len(args) < 1 { - flags.Usage() - c.Ui.Error(fmt.Sprintf( - "\naudit-enable expects at least one argument: the type to enable")) - return 1 - } - - auditType := args[0] - if path == "" { - path = auditType - } - - // Build the options - var stdin io.Reader = os.Stdin - if c.testStdin != nil { - stdin = c.testStdin - } - builder := &kvbuilder.Builder{Stdin: stdin} - if err := builder.Add(args[1:]...); err != nil { - c.Ui.Error(fmt.Sprintf( - "Error parsing options: %s", err)) - return 1 - } - - var opts map[string]string - if err := mapstructure.WeakDecode(builder.Map(), &opts); err != nil { - c.Ui.Error(fmt.Sprintf( - "Error parsing options: %s", err)) - return 1 - } - - client, err := c.Client() - if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error initializing client: %s", err)) - return 1 - } - - err = client.Sys().EnableAuditWithOptions(path, &api.EnableAuditOptions{ - Type: auditType, - Description: desc, - Options: opts, - Local: local, - }) - if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error enabling audit backend: %s", err)) - return 1 - } - - c.Ui.Output(fmt.Sprintf( - "Successfully enabled audit backend '%s' with path '%s'!", auditType, path)) - return 0 + testStdin io.Reader // For tests } func (c *AuditEnableCommand) Synopsis() string { - return "Enable an audit backend" + return "Enables an audit device" } func (c *AuditEnableCommand) Help() string { helpText := ` -Usage: vault audit-enable [options] type [config...] +Usage: vault audit enable [options] TYPE [CONFIG K=V...] - Enable an audit backend. + Enables an audit device at a given path. - This command enables an audit backend of type "type". Additional - options for configuring the audit backend can be specified after the - type in the same format as the "vault write" command in key/value pairs. + This command enables an audit device of TYPE. Additional options for + configuring the audit device can be specified after the type in the same + format as the "vault write" command in key/value pairs. - For example, to configure the file audit backend to write audit logs at - the path /var/log/audit.log: + For example, to configure the file audit device to write audit logs at the + path "/var/log/audit.log": - $ vault audit-enable file file_path=/var/log/audit.log + $ vault audit enable file file_path=/var/log/audit.log - For information on available configuration options, please see the - documentation. +` + c.Flags().Help() -General Options: -` + meta.GeneralOptionsUsage() + ` -Audit Enable Options: - - -description= A human-friendly description for the backend. This - shows up only when querying the enabled backends. - - -path= Specify a unique path for this audit backend. This - is purely for referencing this audit backend. By - default this will be the backend type. - - -local Mark the mount as a local mount. Local mounts - are not replicated nor (if a secondary) - removed by replication. -` return strings.TrimSpace(helpText) } +func (c *AuditEnableCommand) Flags() *FlagSets { + set := c.flagSet(FlagSetHTTP) + + f := set.NewFlagSet("Command Options") + + f.StringVar(&StringVar{ + Name: "description", + Target: &c.flagDescription, + Default: "", + EnvVar: "", + Completion: complete.PredictAnything, + Usage: "Human-friendly description for the purpose of this audit " + + "device.", + }) + + f.StringVar(&StringVar{ + Name: "path", + Target: &c.flagPath, + Default: "", // The default is complex, so we have to manually document + EnvVar: "", + Completion: complete.PredictAnything, + Usage: "Place where the audit device will be accessible. This must be " + + "unique across all audit devices. This defaults to the \"type\" of the " + + "audit device.", + }) + + f.BoolVar(&BoolVar{ + Name: "local", + Target: &c.flagLocal, + Default: false, + EnvVar: "", + Usage: "Mark the audit device as a local-only device. Local devices " + + "are not replicated or removed by replication.", + }) + + return set +} + func (c *AuditEnableCommand) AutocompleteArgs() complete.Predictor { return complete.PredictSet( "file", @@ -138,9 +95,60 @@ func (c *AuditEnableCommand) AutocompleteArgs() complete.Predictor { } func (c *AuditEnableCommand) AutocompleteFlags() complete.Flags { - return complete.Flags{ - "-description": complete.PredictNothing, - "-path": complete.PredictNothing, - "-local": complete.PredictNothing, - } + return c.Flags().Completions() +} + +func (c *AuditEnableCommand) Run(args []string) int { + f := c.Flags() + + if err := f.Parse(args); err != nil { + c.UI.Error(err.Error()) + return 1 + } + + args = f.Args() + if len(args) < 1 { + c.UI.Error("Missing TYPE!") + return 1 + } + + // Grab the type + auditType := strings.TrimSpace(args[0]) + + auditPath := c.flagPath + if auditPath == "" { + auditPath = auditType + } + auditPath = ensureTrailingSlash(auditPath) + + // Pull our fake stdin if needed + stdin := (io.Reader)(os.Stdin) + if c.testStdin != nil { + stdin = c.testStdin + } + + options, err := parseArgsDataString(stdin, args[1:]) + if err != nil { + c.UI.Error(fmt.Sprintf("Failed to parse K=V data: %s", err)) + return 1 + } + + client, err := c.Client() + if err != nil { + c.UI.Error(err.Error()) + return 2 + } + + if err := client.Sys().EnableAuditWithOptions(auditPath, &api.EnableAuditOptions{ + Type: auditType, + Description: c.flagDescription, + Options: options, + Local: c.flagLocal, + }); err != nil { + c.UI.Error(fmt.Sprintf("Error enabling audit device: %s", err)) + return 2 + } + + c.UI.Output(fmt.Sprintf("Success! Enabled the %s audit device at: %s", auditType, auditPath)) + return 0 } diff --git a/command/audit_enable_test.go b/command/audit_enable_test.go index 118f103d3e..c2fe43e84f 100644 --- a/command/audit_enable_test.go +++ b/command/audit_enable_test.go @@ -1,56 +1,160 @@ package command import ( - "reflect" + "strings" "testing" - "github.com/hashicorp/vault/http" - "github.com/hashicorp/vault/meta" - "github.com/hashicorp/vault/vault" "github.com/mitchellh/cli" ) -func TestAuditEnable(t *testing.T) { - core, _, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() +func testAuditEnableCommand(tb testing.TB) (*cli.MockUi, *AuditEnableCommand) { + tb.Helper() - ui := new(cli.MockUi) - c := &AuditEnableCommand{ - Meta: meta.Meta{ - ClientToken: token, - Ui: ui, + ui := cli.NewMockUi() + return ui, &AuditEnableCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + } +} + +func TestAuditEnableCommand_Run(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + args []string + out string + code int + }{ + { + "empty", + nil, + "Missing TYPE!", + 1, + }, + { + "not_a_valid_type", + []string{"nope_definitely_not_a_valid_type_like_ever"}, + "", + 2, + }, + { + "enable", + []string{"file", "file_path=discard"}, + "Success! Enabled the file audit device at: file/", + 0, + }, + { + "enable_path", + []string{ + "-path", "audit_path", + "file", + "file_path=discard", + }, + "Success! Enabled the file audit device at: audit_path/", + 0, }, } - args := []string{ - "-address", addr, - "noop", - "foo=bar", + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + ui, cmd := testAuditEnableCommand(t) + cmd.client = client + + code := cmd.Run(tc.args) + if code != tc.code { + t.Errorf("expected %d to be %d", code, tc.code) + } + + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, tc.out) { + t.Errorf("expected %q to contain %q", combined, tc.out) + } + }) } - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } + t.Run("integration", func(t *testing.T) { + t.Parallel() - // Get the client - client, err := c.Client() - if err != nil { - t.Fatalf("err: %#v", err) - } + client, closer := testVaultServer(t) + defer closer() - audits, err := client.Sys().ListAudit() - if err != nil { - t.Fatalf("err: %#v", err) - } + ui, cmd := testAuditEnableCommand(t) + cmd.client = client - audit, ok := audits["noop/"] - if !ok { - t.Fatalf("err: %#v", audits) - } + code := cmd.Run([]string{ + "-path", "audit_enable_integration/", + "-description", "The best kind of test", + "file", + "file_path=discard", + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } - expected := map[string]string{"foo": "bar"} - if !reflect.DeepEqual(audit.Options, expected) { - t.Fatalf("err: %#v", audit) - } + expected := "Success! Enabled the file audit device at: audit_enable_integration/" + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + + audits, err := client.Sys().ListAudit() + if err != nil { + t.Fatal(err) + } + + auditInfo, ok := audits["audit_enable_integration/"] + if !ok { + t.Fatalf("expected audit to exist") + } + if exp := "file"; auditInfo.Type != exp { + t.Errorf("expected %q to be %q", auditInfo.Type, exp) + } + if exp := "The best kind of test"; auditInfo.Description != exp { + t.Errorf("expected %q to be %q", auditInfo.Description, exp) + } + + filePath, ok := auditInfo.Options["file_path"] + if !ok || filePath != "discard" { + t.Errorf("missing some options: %#v", auditInfo) + } + }) + + t.Run("communication_failure", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServerBad(t) + defer closer() + + ui, cmd := testAuditEnableCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "pki", + }) + if exp := 2; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Error enabling audit device: " + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + }) + + t.Run("no_tabs", func(t *testing.T) { + t.Parallel() + + _, cmd := testAuditEnableCommand(t) + assertNoTabs(t, cmd) + }) } diff --git a/command/audit_list.go b/command/audit_list.go index b9914eb929..0012426e41 100644 --- a/command/audit_list.go +++ b/command/audit_list.go @@ -5,83 +5,158 @@ import ( "sort" "strings" - "github.com/hashicorp/vault/meta" - "github.com/ryanuber/columnize" + "github.com/hashicorp/vault/api" + "github.com/mitchellh/cli" + "github.com/posener/complete" ) -// AuditListCommand is a Command that lists the enabled audits. +var _ cli.Command = (*AuditListCommand)(nil) +var _ cli.CommandAutocomplete = (*AuditListCommand)(nil) + type AuditListCommand struct { - meta.Meta + *BaseCommand + + flagDetailed bool +} + +func (c *AuditListCommand) Synopsis() string { + return "Lists enabled audit devices" +} + +func (c *AuditListCommand) Help() string { + helpText := ` +Usage: vault audit list [options] + + Lists the enabled audit devices in the Vault server. The output lists the + enabled audit devices and the options for those devices. + + List all audit devices: + + $ vault audit list + + List detailed output about the audit devices: + + $ vault audit list -detailed + +` + c.Flags().Help() + + return strings.TrimSpace(helpText) +} + +func (c *AuditListCommand) Flags() *FlagSets { + set := c.flagSet(FlagSetHTTP) + + f := set.NewFlagSet("Command Options") + + f.BoolVar(&BoolVar{ + Name: "detailed", + Target: &c.flagDetailed, + Default: false, + EnvVar: "", + Usage: "Print detailed information such as options and replication " + + "status about each auth device.", + }) + + return set +} + +func (c *AuditListCommand) AutocompleteArgs() complete.Predictor { + return nil +} + +func (c *AuditListCommand) AutocompleteFlags() complete.Flags { + return c.Flags().Completions() } func (c *AuditListCommand) Run(args []string) int { - flags := c.Meta.FlagSet("audit-list", meta.FlagSetDefault) - flags.Usage = func() { c.Ui.Error(c.Help()) } - if err := flags.Parse(args); err != nil { + f := c.Flags() + + if err := f.Parse(args); err != nil { + c.UI.Error(err.Error()) + return 1 + } + + args = f.Args() + if len(args) > 0 { + c.UI.Error(fmt.Sprintf("Too many arguments (expected 0, got %d)", len(args))) return 1 } client, err := c.Client() if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error initializing client: %s", err)) + c.UI.Error(err.Error()) return 2 } audits, err := client.Sys().ListAudit() if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error reading audits: %s", err)) + c.UI.Error(fmt.Sprintf("Error listing audits: %s", err)) return 2 } if len(audits) == 0 { - c.Ui.Error(fmt.Sprintf( - "No audit backends are enabled. Use `vault audit-enable` to\n" + - "enable an audit backend.")) - return 1 + c.UI.Output(fmt.Sprintf("No audit devices are enabled.")) + return 0 } + if c.flagDetailed { + c.UI.Output(tableOutput(c.detailedAudits(audits), nil)) + return 0 + } + + c.UI.Output(tableOutput(c.simpleAudits(audits), nil)) + return 0 +} + +func (c *AuditListCommand) simpleAudits(audits map[string]*api.Audit) []string { paths := make([]string, 0, len(audits)) for path, _ := range audits { paths = append(paths, path) } sort.Strings(paths) - columns := []string{"Path | Type | Description | Replication Behavior | Options"} + columns := []string{"Path | Type | Description"} for _, path := range paths { audit := audits[path] + columns = append(columns, fmt.Sprintf("%s | %s | %s", + audit.Path, + audit.Type, + audit.Description, + )) + } + + return columns +} + +func (c *AuditListCommand) detailedAudits(audits map[string]*api.Audit) []string { + paths := make([]string, 0, len(audits)) + for path, _ := range audits { + paths = append(paths, path) + } + sort.Strings(paths) + + columns := []string{"Path | Type | Description | Replication | Options"} + for _, path := range paths { + audit := audits[path] + opts := make([]string, 0, len(audit.Options)) for k, v := range audit.Options { opts = append(opts, k+"="+v) } - replicatedBehavior := "replicated" + + replication := "replicated" if audit.Local { - replicatedBehavior = "local" + replication = "local" } - columns = append(columns, fmt.Sprintf( - "%s | %s | %s | %s | %s", audit.Path, audit.Type, audit.Description, replicatedBehavior, strings.Join(opts, " "))) + + columns = append(columns, fmt.Sprintf("%s | %s | %s | %s | %s", + audit.Path, + audit.Type, + audit.Description, + replication, + strings.Join(opts, " "), + )) } - c.Ui.Output(columnize.SimpleFormat(columns)) - return 0 -} - -func (c *AuditListCommand) Synopsis() string { - return "Lists enabled audit backends in Vault" -} - -func (c *AuditListCommand) Help() string { - helpText := ` -Usage: vault audit-list [options] - - List the enabled audit backends. - - The output lists the enabled audit backends and the options for those - backends. The options may contain sensitive information, and therefore - only a root Vault user can view this. - -General Options: -` + meta.GeneralOptionsUsage() - return strings.TrimSpace(helpText) + return columns } diff --git a/command/audit_list_test.go b/command/audit_list_test.go index 01d4f83ed0..9cbb0af5ee 100644 --- a/command/audit_list_test.go +++ b/command/audit_list_test.go @@ -1,50 +1,111 @@ package command import ( + "strings" "testing" "github.com/hashicorp/vault/api" - "github.com/hashicorp/vault/http" - "github.com/hashicorp/vault/meta" - "github.com/hashicorp/vault/vault" "github.com/mitchellh/cli" ) -func TestAuditList(t *testing.T) { - core, _, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() +func testAuditListCommand(tb testing.TB) (*cli.MockUi, *AuditListCommand) { + tb.Helper() - ui := new(cli.MockUi) - c := &AuditListCommand{ - Meta: meta.Meta{ - ClientToken: token, - Ui: ui, + ui := cli.NewMockUi() + return ui, &AuditListCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + } +} + +func TestAuditListCommand_Run(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + args []string + out string + code int + }{ + { + "too_many_args", + []string{"foo"}, + "Too many arguments", + 1, + }, + { + "lists", + nil, + "Path", + 0, + }, + { + "detailed", + []string{"-detailed"}, + "Options", + 0, }, } - args := []string{ - "-address", addr, + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + if err := client.Sys().EnableAuditWithOptions("file", &api.EnableAuditOptions{ + Type: "file", + Options: map[string]string{ + "file_path": "discard", + }, + }); err != nil { + t.Fatal(err) + } + + ui, cmd := testAuditListCommand(t) + cmd.client = client + + code := cmd.Run(tc.args) + if code != tc.code { + t.Errorf("expected %d to be %d", code, tc.code) + } + + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, tc.out) { + t.Errorf("expected %q to contain %q", combined, tc.out) + } + }) } - // Run once to get the client - c.Run(args) + t.Run("communication_failure", func(t *testing.T) { + t.Parallel() - // Get the client - client, err := c.Client() - if err != nil { - t.Fatalf("err: %#v", err) - } - if err := client.Sys().EnableAuditWithOptions("foo", &api.EnableAuditOptions{ - Type: "noop", - Description: "noop", - Options: nil, - }); err != nil { - t.Fatalf("err: %#v", err) - } + client, closer := testVaultServerBad(t) + defer closer() - // Run again - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } + ui, cmd := testAuditListCommand(t) + cmd.client = client + + code := cmd.Run([]string{}) + if exp := 2; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Error listing audits: " + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + }) + + t.Run("no_tabs", func(t *testing.T) { + t.Parallel() + + _, cmd := testAuditListCommand(t) + assertNoTabs(t, cmd) + }) } diff --git a/command/auth.go b/command/auth.go index 5beabff5cb..fa8cc66268 100644 --- a/command/auth.go +++ b/command/auth.go @@ -1,557 +1,117 @@ package command import ( - "bufio" - "encoding/json" - "fmt" + "flag" "io" - "os" - "sort" - "strconv" + "io/ioutil" "strings" - "github.com/hashicorp/vault/api" - "github.com/hashicorp/vault/helper/kv-builder" - "github.com/hashicorp/vault/helper/password" - "github.com/hashicorp/vault/meta" - "github.com/mitchellh/mapstructure" - "github.com/posener/complete" - "github.com/ryanuber/columnize" + "github.com/mitchellh/cli" ) -// AuthHandler is the interface that any auth handlers must implement -// to enable auth via the CLI. -type AuthHandler interface { - Auth(*api.Client, map[string]string) (*api.Secret, error) - Help() string -} +var _ cli.Command = (*AuthCommand)(nil) -// AuthCommand is a Command that handles authentication. type AuthCommand struct { - meta.Meta + *BaseCommand - Handlers map[string]AuthHandler + Handlers map[string]LoginHandler - // The fields below can be overwritten for tests - testStdin io.Reader -} - -func (c *AuthCommand) Run(args []string) int { - var method, authPath string - var methods, methodHelp, noVerify, noStore, tokenOnly bool - flags := c.Meta.FlagSet("auth", meta.FlagSetDefault) - flags.BoolVar(&methods, "methods", false, "") - flags.BoolVar(&methodHelp, "method-help", false, "") - flags.BoolVar(&noVerify, "no-verify", false, "") - flags.BoolVar(&noStore, "no-store", false, "") - flags.BoolVar(&tokenOnly, "token-only", false, "") - flags.StringVar(&method, "method", "", "method") - flags.StringVar(&authPath, "path", "", "") - flags.Usage = func() { c.Ui.Error(c.Help()) } - if err := flags.Parse(args); err != nil { - return 1 - } - - if methods { - return c.listMethods() - } - - args = flags.Args() - - tokenHelper, err := c.TokenHelper() - if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error initializing token helper: %s\n\n"+ - "Please verify that the token helper is available and properly\n"+ - "configured for your system. Please refer to the documentation\n"+ - "on token helpers for more information.", - err)) - return 1 - } - - // token is where the final token will go - handler := c.Handlers[method] - - // Read token from stdin if first arg is exactly "-" - var stdin io.Reader = os.Stdin - if c.testStdin != nil { - stdin = c.testStdin - } - - if len(args) > 0 && args[0] == "-" { - stdinR := bufio.NewReader(stdin) - args[0], err = stdinR.ReadString('\n') - if err != nil && err != io.EOF { - c.Ui.Error(fmt.Sprintf("Error reading from stdin: %s", err)) - return 1 - } - args[0] = strings.TrimSpace(args[0]) - } - - if method == "" { - token := "" - if len(args) > 0 { - token = args[0] - } - - handler = &tokenAuthHandler{Token: token} - args = nil - - switch authPath { - case "", "auth/token": - default: - c.Ui.Error("Token authentication does not support custom paths") - return 1 - } - } - - if handler == nil { - methods := make([]string, 0, len(c.Handlers)) - for k := range c.Handlers { - methods = append(methods, k) - } - sort.Strings(methods) - - c.Ui.Error(fmt.Sprintf( - "Unknown authentication method: %s\n\n"+ - "Please use a supported authentication method. The list of supported\n"+ - "authentication methods is shown below. Note that this list may not\n"+ - "be exhaustive: Vault may support other auth methods. For auth methods\n"+ - "unsupported by the CLI, please use the HTTP API.\n\n"+ - "%s", - method, - strings.Join(methods, ", "))) - return 1 - } - - if methodHelp { - c.Ui.Output(handler.Help()) - return 0 - } - - // Warn if the VAULT_TOKEN environment variable is set, as that will take - // precedence. Don't output on token-only since we're likely piping output. - if os.Getenv("VAULT_TOKEN") != "" && !tokenOnly { - c.Ui.Output("==> WARNING: VAULT_TOKEN environment variable set!\n") - c.Ui.Output(" The environment variable takes precedence over the value") - c.Ui.Output(" set by the auth command. Either update the value of the") - c.Ui.Output(" environment variable or unset it to use the new token.\n") - } - - var vars map[string]string - if len(args) > 0 { - builder := kvbuilder.Builder{Stdin: os.Stdin} - if err := builder.Add(args...); err != nil { - c.Ui.Error(err.Error()) - return 1 - } - - if err := mapstructure.Decode(builder.Map(), &vars); err != nil { - c.Ui.Error(fmt.Sprintf("Error parsing options: %s", err)) - return 1 - } - } else { - vars = make(map[string]string) - } - - // Build the client so we can auth - client, err := c.Client() - if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error initializing client to auth: %s", err)) - return 1 - } - - if authPath != "" { - vars["mount"] = authPath - } - - // Authenticate - secret, err := handler.Auth(client, vars) - if err != nil { - c.Ui.Error(err.Error()) - return 1 - } - if secret == nil { - c.Ui.Error("Empty response from auth helper") - return 1 - } - - // If we had requested a wrapped token, we want to unset that request - // before performing further functions - client.SetWrappingLookupFunc(func(string, string) string { - return "" - }) - -CHECK_TOKEN: - var token string - switch { - case secret == nil: - c.Ui.Error("Empty response from auth helper") - return 1 - - case secret.Auth != nil: - token = secret.Auth.ClientToken - - case secret.WrapInfo != nil: - if secret.WrapInfo.WrappedAccessor == "" { - c.Ui.Error("Got a wrapped response from Vault but wrapped reply does not seem to contain a token") - return 1 - } - if tokenOnly { - c.Ui.Output(secret.WrapInfo.Token) - return 0 - } - if noStore { - return OutputSecret(c.Ui, "table", secret) - } - client.SetToken(secret.WrapInfo.Token) - secret, err = client.Logical().Unwrap("") - goto CHECK_TOKEN - - default: - c.Ui.Error("No auth or wrapping info in auth helper response") - return 1 - } - - // Cache the previous token so that it can be restored if authentication fails - var previousToken string - if previousToken, err = tokenHelper.Get(); err != nil { - c.Ui.Error(fmt.Sprintf("Error caching the previous token: %s\n\n", err)) - return 1 - } - - if tokenOnly { - c.Ui.Output(token) - return 0 - } - - // Store the token! - if !noStore { - if err := tokenHelper.Store(token); err != nil { - c.Ui.Error(fmt.Sprintf( - "Error storing token: %s\n\n"+ - "Authentication was not successful and did not persist.\n"+ - "Please reauthenticate, or fix the issue above if possible.", - err)) - return 1 - } - } - - if noVerify { - c.Ui.Output(fmt.Sprintf( - "Authenticated - no token verification has been performed.", - )) - - if noStore { - if err := tokenHelper.Erase(); err != nil { - c.Ui.Error(fmt.Sprintf( - "Error removing prior token: %s\n\n"+ - "Authentication was successful, but unable to remove the\n"+ - "previous token.", - err)) - return 1 - } - } - return 0 - } - - // Build the client again so it can read the token we just wrote - client, err = c.Client() - if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error initializing client to verify the token: %s", err)) - if !noStore { - if err := tokenHelper.Store(previousToken); err != nil { - c.Ui.Error(fmt.Sprintf( - "Error restoring the previous token: %s\n\n"+ - "Please reauthenticate with a valid token.", - err)) - } - } - return 1 - } - client.SetWrappingLookupFunc(func(string, string) string { - return "" - }) - - // If in no-store mode it won't have read the token from a token-helper (or - // will read an old one) so set it explicitly - if noStore { - client.SetToken(token) - } - - // Verify the token - secret, err = client.Auth().Token().LookupSelf() - if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error validating token: %s", err)) - if err := tokenHelper.Store(previousToken); err != nil { - c.Ui.Error(fmt.Sprintf( - "Error restoring the previous token: %s\n\n"+ - "Please reauthenticate with a valid token.", - err)) - } - return 1 - } - if secret == nil && !noStore { - c.Ui.Error(fmt.Sprintf("Error: Invalid token")) - if err := tokenHelper.Store(previousToken); err != nil { - c.Ui.Error(fmt.Sprintf( - "Error restoring the previous token: %s\n\n"+ - "Please reauthenticate with a valid token.", - err)) - } - return 1 - } - - if noStore { - if err := tokenHelper.Erase(); err != nil { - c.Ui.Error(fmt.Sprintf( - "Error removing prior token: %s\n\n"+ - "Authentication was successful, but unable to remove the\n"+ - "previous token.", - err)) - return 1 - } - } - - // Get the policies we have - policiesRaw, ok := secret.Data["policies"] - if !ok || policiesRaw == nil { - policiesRaw = []interface{}{"unknown"} - } - var policies []string - for _, v := range policiesRaw.([]interface{}) { - policies = append(policies, v.(string)) - } - - output := "Successfully authenticated! You are now logged in." - if noStore { - output += "\nThe token has not been stored to the configured token helper." - } - if method != "" { - output += "\nThe token below is already saved in the session. You do not" - output += "\nneed to \"vault auth\" again with the token." - } - output += fmt.Sprintf("\ntoken: %s", secret.Data["id"]) - output += fmt.Sprintf("\ntoken_duration: %s", secret.Data["ttl"].(json.Number).String()) - if len(policies) > 0 { - output += fmt.Sprintf("\ntoken_policies: %v", policies) - } - - c.Ui.Output(output) - - return 0 - -} - -func (c *AuthCommand) getMethods() (map[string]*api.AuthMount, error) { - client, err := c.Client() - if err != nil { - return nil, err - } - client.SetWrappingLookupFunc(func(string, string) string { - return "" - }) - - auth, err := client.Sys().ListAuth() - if err != nil { - return nil, err - } - - return auth, nil -} - -func (c *AuthCommand) listMethods() int { - auth, err := c.getMethods() - if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error reading auth table: %s", err)) - return 1 - } - - paths := make([]string, 0, len(auth)) - for path := range auth { - paths = append(paths, path) - } - sort.Strings(paths) - - columns := []string{"Path | Type | Accessor | Default TTL | Max TTL | Replication Behavior | Seal Wrap | Description"} - for _, path := range paths { - auth := auth[path] - defTTL := "system" - if auth.Config.DefaultLeaseTTL != 0 { - defTTL = strconv.Itoa(auth.Config.DefaultLeaseTTL) - } - maxTTL := "system" - if auth.Config.MaxLeaseTTL != 0 { - maxTTL = strconv.Itoa(auth.Config.MaxLeaseTTL) - } - replicatedBehavior := "replicated" - if auth.Local { - replicatedBehavior = "local" - } - columns = append(columns, fmt.Sprintf( - "%s | %s | %s | %s | %s | %s | %t | %s", path, auth.Type, auth.Accessor, defTTL, maxTTL, replicatedBehavior, auth.SealWrap, auth.Description)) - } - - c.Ui.Output(columnize.SimpleFormat(columns)) - return 0 + testStdin io.Reader // for tests } func (c *AuthCommand) Synopsis() string { - return "Prints information about how to authenticate with Vault" + return "Interact with auth methods" } func (c *AuthCommand) Help() string { - helpText := ` -Usage: vault auth [options] [auth-information] + return strings.TrimSpace(` +Usage: vault auth [options] [args] - Authenticate with Vault using the given token or via any supported - authentication backend. + This command groups subcommands for interacting with Vault's auth methods. + Users can list, enable, disable, and get help for different auth methods. - By default, the -method is assumed to be token. If not supplied via the - command-line, a prompt for input will be shown. If the authentication - information is "-", it will be read from stdin. + To authenticate to Vault as a user or machine, use the "vault login" command + instead. This command is for interacting with the auth methods themselves, not + authenticating to Vault. - The -method option allows alternative authentication methods to be used, - such as userpass, GitHub, or TLS certificates. For these, additional - values as "key=value" pairs may be required. For example, to authenticate - to the userpass auth backend: + List all enabled auth methods: - $ vault auth -method=userpass username=my-username + $ vault auth list - Use "-method-help" to get help for a specific method. + Enable a new auth method "userpass"; - If an auth backend is enabled at a different path, the "-method" flag - should still point to the canonical name, and the "-path" flag should be - used. If a GitHub auth backend was mounted as "github-private", one would - authenticate to this backend via: + $ vault auth enable userpass - $ vault auth -method=github -path=github-private + Get detailed help information about how to authenticate to a particular auth + method: - The value of the "-path" flag is supplied to auth providers as the "mount" - option in the payload to specify the mount point. + $ vault auth help github - If response wrapping is used (via -wrap-ttl), the returned token will be - automatically unwrapped unless: - * -token-only is used, in which case the wrapping token will be output - * -no-store is used, in which case the details of the wrapping token - will be printed - -General Options: - - ` + meta.GeneralOptionsUsage() + ` - -Auth Options: - - -method=name Use the method given here, which is a type of backend, not - the path. If this authentication method is not available, - exit with code 1. - - -method-help If set, the help for the selected method will be shown. - - -methods List the available auth methods. - - -no-verify Do not verify the token after creation; avoids a use count - decrement. - - -no-store Do not store the token after creation; it will only be - displayed in the command output. - - -token-only Output only the token to stdout. This implies -no-verify - and -no-store. - - -path The path at which the auth backend is enabled. If an auth - backend is mounted at multiple paths, this option can be - used to authenticate against specific paths. -` - return strings.TrimSpace(helpText) + Please see the individual subcommand help for detailed usage information. +`) } -// tokenAuthHandler handles retrieving the token from the command-line. -type tokenAuthHandler struct { - Token string -} +func (c *AuthCommand) Run(args []string) int { + // If we entered the run method, none of the subcommands picked up. This + // means the user is still trying to use auth as "vault auth TOKEN" or + // similar, so direct them to vault login instead. + // + // This run command is a bit messy to maintain BC for a bit. In the future, + // it will just be a tiny function, but for now we have to maintain bc. + // + // Deprecation + // TODO: remove in 0.9.0 -func (h *tokenAuthHandler) Auth(*api.Client, map[string]string) (*api.Secret, error) { - token := h.Token - if token == "" { - var err error + // Parse the args for our deprecations and defer to the proper areas. + for _, arg := range args { + switch { + case strings.HasPrefix(arg, "-methods"): + c.UI.Warn(wrapAtLength( + "WARNING! The -methods flag is deprecated. Please use "+ + "\"vault auth list\" instead. This flag will be removed in the "+ + "next major release of Vault.") + "\n") + return (&AuthListCommand{ + BaseCommand: &BaseCommand{ + UI: c.UI, + client: c.client, + }, + }).Run(nil) + case strings.HasPrefix(arg, "-method-help"): + c.UI.Warn(wrapAtLength( + "WARNING! The -method-help flag is deprecated. Please use "+ + "\"vault auth help\" instead. This flag will be removed in the "+ + "next major release of Vault.") + "\n") + // Parse the args to pull out the method, surpressing any errors because + // there could be other flags that we don't care about. + f := flag.NewFlagSet("", flag.ContinueOnError) + f.Usage = func() {} + f.SetOutput(ioutil.Discard) + flagMethod := f.String("method", "", "") + f.Parse(args) - // No arguments given, read the token from user input - fmt.Printf("Token (will be hidden): ") - token, err = password.Read(os.Stdin) - fmt.Printf("\n") - if err != nil { - return nil, fmt.Errorf( - "Error attempting to ask for token. The raw error message\n"+ - "is shown below, but the most common reason for this error is\n"+ - "that you attempted to pipe a value into auth. If you want to\n"+ - "pipe the token, please pass '-' as the token argument.\n\n"+ - "Raw error: %s", err) + return (&AuthHelpCommand{ + BaseCommand: &BaseCommand{ + UI: c.UI, + client: c.client, + }, + Handlers: c.Handlers, + }).Run([]string{*flagMethod}) } } - if token == "" { - return nil, fmt.Errorf( - "A token must be passed to auth. Please view the help\n" + - "for more information.") - } - - return &api.Secret{ - Auth: &api.SecretAuth{ - ClientToken: token, + // If we got this far, we have an arg or a series of args that should be + // passed directly to the new "vault login" command. + c.UI.Warn(wrapAtLength( + "WARNING! The \"vault auth ARG\" command is deprecated and is now a "+ + "subcommand for interacting with auth methods. To "+ + "authenticate locally to Vault, use \"vault login\" instead. This "+ + "backwards compatability will be removed in the next major release of "+ + "Vault.") + "\n") + return (&LoginCommand{ + BaseCommand: &BaseCommand{ + UI: c.UI, + client: c.client, }, - }, nil -} - -func (h *tokenAuthHandler) Help() string { - help := ` -No method selected with the "-method" flag, so the "auth" command assumes -you'll be using raw token authentication. For this, specify the token to -authenticate as the parameter to "vault auth". Example: - - vault auth 123456 - -The token used to authenticate must come from some other source. A root -token is created when Vault is first initialized. After that, subsequent -tokens are created via the API or command line interface (with the -"token"-prefixed commands). -` - - return strings.TrimSpace(help) -} - -func (c *AuthCommand) AutocompleteArgs() complete.Predictor { - return complete.PredictNothing -} - -func (c *AuthCommand) AutocompleteFlags() complete.Flags { - var predictFunc complete.PredictFunc = func(a complete.Args) []string { - auths, err := c.getMethods() - if err != nil { - return []string{} - } - - methods := make([]string, 0, len(auths)) - for _, auth := range auths { - if strings.HasPrefix(auth.Type, a.Last) { - methods = append(methods, auth.Type) - } - } - - return methods - } - - return complete.Flags{ - "-method": predictFunc, - "-methods": complete.PredictNothing, - "-method-help": complete.PredictNothing, - "-no-verify": complete.PredictNothing, - "-no-store": complete.PredictNothing, - "-token-only": complete.PredictNothing, - "-path": complete.PredictNothing, - } + Handlers: c.Handlers, + }).Run(args) } diff --git a/command/auth_disable.go b/command/auth_disable.go index 621ce5907c..afcfe747df 100644 --- a/command/auth_disable.go +++ b/command/auth_disable.go @@ -4,66 +4,84 @@ import ( "fmt" "strings" - "github.com/hashicorp/vault/meta" + "github.com/mitchellh/cli" + "github.com/posener/complete" ) -// AuthDisableCommand is a Command that enables a new endpoint. +var _ cli.Command = (*AuthDisableCommand)(nil) +var _ cli.CommandAutocomplete = (*AuthDisableCommand)(nil) + type AuthDisableCommand struct { - meta.Meta -} - -func (c *AuthDisableCommand) Run(args []string) int { - flags := c.Meta.FlagSet("auth-disable", meta.FlagSetDefault) - flags.Usage = func() { c.Ui.Error(c.Help()) } - if err := flags.Parse(args); err != nil { - return 1 - } - - args = flags.Args() - if len(args) != 1 { - flags.Usage() - c.Ui.Error(fmt.Sprintf( - "\nauth-disable expects one argument: the path to disable.")) - return 1 - } - - path := args[0] - - client, err := c.Client() - if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error initializing client: %s", err)) - return 2 - } - - if err := client.Sys().DisableAuth(path); err != nil { - c.Ui.Error(fmt.Sprintf( - "Error: %s", err)) - return 2 - } - - c.Ui.Output(fmt.Sprintf( - "Disabled auth provider at path '%s' if it was enabled", path)) - - return 0 + *BaseCommand } func (c *AuthDisableCommand) Synopsis() string { - return "Disable an auth provider" + return "Disables an auth method" } func (c *AuthDisableCommand) Help() string { helpText := ` -Usage: vault auth-disable [options] path +Usage: vault auth disable [options] PATH - Disable an already-enabled auth provider. + Disables an existing auth method at the given PATH. The argument corresponds + to the PATH of the mount, not the TYPE!. Once the auth method is disabled its + path can no longer be used to authenticate. - Once the auth provider is disabled its path can no longer be used - to authenticate. All access tokens generated via the disabled auth provider - will be revoked. This command will block until all tokens are revoked. - If the command is exited early the tokens will still be revoked. + All access tokens generated via the disabled auth method are immediately + revoked. This command will block until all tokens are revoked. + + Disable the auth method at userpass/: + + $ vault auth disable userpass/ + +` + c.Flags().Help() -General Options: -` + meta.GeneralOptionsUsage() return strings.TrimSpace(helpText) } + +func (c *AuthDisableCommand) Flags() *FlagSets { + return c.flagSet(FlagSetHTTP) +} + +func (c *AuthDisableCommand) AutocompleteArgs() complete.Predictor { + return c.PredictVaultAuths() +} + +func (c *AuthDisableCommand) AutocompleteFlags() complete.Flags { + return c.Flags().Completions() +} + +func (c *AuthDisableCommand) Run(args []string) int { + f := c.Flags() + + if err := f.Parse(args); err != nil { + c.UI.Error(err.Error()) + return 1 + } + + args = f.Args() + switch { + case len(args) < 1: + c.UI.Error(fmt.Sprintf("Not enough arguments (expected 1, got %d)", len(args))) + return 1 + case len(args) > 1: + c.UI.Error(fmt.Sprintf("Too many arguments (expected 1, got %d)", len(args))) + return 1 + } + + path := ensureTrailingSlash(sanitizePath(args[0])) + + client, err := c.Client() + if err != nil { + c.UI.Error(err.Error()) + return 2 + } + + if err := client.Sys().DisableAuth(path); err != nil { + c.UI.Error(fmt.Sprintf("Error disabling auth method at %s: %s", path, err)) + return 2 + } + + c.UI.Output(fmt.Sprintf("Success! Disabled the auth method (if it existed) at: %s", path)) + return 0 +} diff --git a/command/auth_disable_test.go b/command/auth_disable_test.go index fb2b91fb23..dbe2e776f9 100644 --- a/command/auth_disable_test.go +++ b/command/auth_disable_test.go @@ -1,102 +1,133 @@ package command import ( + "strings" "testing" - "github.com/hashicorp/vault/api" - "github.com/hashicorp/vault/http" - "github.com/hashicorp/vault/meta" - "github.com/hashicorp/vault/vault" "github.com/mitchellh/cli" ) -func TestAuthDisable(t *testing.T) { - core, _, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() +func testAuthDisableCommand(tb testing.TB) (*cli.MockUi, *AuthDisableCommand) { + tb.Helper() - ui := new(cli.MockUi) - c := &AuthDisableCommand{ - Meta: meta.Meta{ - ClientToken: token, - Ui: ui, + ui := cli.NewMockUi() + return ui, &AuthDisableCommand{ + BaseCommand: &BaseCommand{ + UI: ui, }, } - - args := []string{ - "-address", addr, - "noop", - } - - // Run the command once to setup the client, it will fail - c.Run(args) - - client, err := c.Client() - if err != nil { - t.Fatalf("err: %s", err) - } - - if err := client.Sys().EnableAuth("noop", "noop", ""); err != nil { - t.Fatalf("err: %s", err) - } - - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } - - mounts, err := client.Sys().ListAuth() - if err != nil { - t.Fatalf("err: %s", err) - } - - if _, ok := mounts["noop"]; ok { - t.Fatal("should not have noop mount") - } } -func TestAuthDisableWithOptions(t *testing.T) { - core, _, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() +func TestAuthDisableCommand_Run(t *testing.T) { + t.Parallel() - ui := new(cli.MockUi) - c := &AuthDisableCommand{ - Meta: meta.Meta{ - ClientToken: token, - Ui: ui, + cases := []struct { + name string + args []string + out string + code int + }{ + { + "not_enough_args", + nil, + "Not enough arguments", + 1, + }, + { + "too_many_args", + []string{"foo", "bar"}, + "Too many arguments", + 1, }, } - args := []string{ - "-address", addr, - "noop", - } + t.Run("validations", func(t *testing.T) { + t.Parallel() - // Run the command once to setup the client, it will fail - c.Run(args) + for _, tc := range cases { + tc := tc - client, err := c.Client() - if err != nil { - t.Fatalf("err: %s", err) - } + t.Run(tc.name, func(t *testing.T) { + t.Parallel() - if err := client.Sys().EnableAuthWithOptions("noop", &api.EnableAuthOptions{ - Type: "noop", - Description: "", - }); err != nil { - t.Fatalf("err: %#v", err) - } + ui, cmd := testAuthDisableCommand(t) - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } + code := cmd.Run(tc.args) + if code != tc.code { + t.Errorf("expected %d to be %d", code, tc.code) + } - mounts, err := client.Sys().ListAuth() - if err != nil { - t.Fatalf("err: %s", err) - } + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, tc.out) { + t.Errorf("expected %q to contain %q", combined, tc.out) + } + }) + } + }) - if _, ok := mounts["noop"]; ok { - t.Fatal("should not have noop mount") - } + t.Run("integration", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + if err := client.Sys().EnableAuth("my-auth", "userpass", ""); err != nil { + t.Fatal(err) + } + + ui, cmd := testAuthDisableCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "my-auth", + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Success! Disabled the auth method" + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + + auths, err := client.Sys().ListAuth() + if err != nil { + t.Fatal(err) + } + + if auth, ok := auths["my-auth/"]; ok { + t.Errorf("expected auth to be disabled: %#v", auth) + } + }) + + t.Run("communication_failure", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServerBad(t) + defer closer() + + ui, cmd := testAuthDisableCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "my-auth", + }) + if exp := 2; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Error disabling auth method at my-auth/: " + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + }) + + t.Run("no_tabs", func(t *testing.T) { + t.Parallel() + + _, cmd := testAuthDisableCommand(t) + assertNoTabs(t, cmd) + }) } diff --git a/command/auth_enable.go b/command/auth_enable.go index 09227a749d..1cc8dbc1a7 100644 --- a/command/auth_enable.go +++ b/command/auth_enable.go @@ -5,142 +5,168 @@ import ( "strings" "github.com/hashicorp/vault/api" - "github.com/hashicorp/vault/meta" + "github.com/mitchellh/cli" "github.com/posener/complete" ) -// AuthEnableCommand is a Command that enables a new endpoint. +var _ cli.Command = (*AuthEnableCommand)(nil) +var _ cli.CommandAutocomplete = (*AuthEnableCommand)(nil) + type AuthEnableCommand struct { - meta.Meta -} + *BaseCommand -func (c *AuthEnableCommand) Run(args []string) int { - var description, path, pluginName string - var local, sealWrap bool - flags := c.Meta.FlagSet("auth-enable", meta.FlagSetDefault) - flags.StringVar(&description, "description", "", "") - flags.StringVar(&path, "path", "", "") - flags.StringVar(&pluginName, "plugin-name", "", "") - flags.BoolVar(&local, "local", false, "") - flags.BoolVar(&sealWrap, "seal-wrap", false, "") - flags.Usage = func() { c.Ui.Error(c.Help()) } - if err := flags.Parse(args); err != nil { - return 1 - } - - args = flags.Args() - if len(args) != 1 { - flags.Usage() - c.Ui.Error(fmt.Sprintf( - "\nauth-enable expects one argument: the type to enable.")) - return 1 - } - - authType := args[0] - - // If no path is specified, we default the path to the backend type - // or use the plugin name if it's a plugin backend - if path == "" { - if authType == "plugin" { - path = pluginName - } else { - path = authType - } - } - - client, err := c.Client() - if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error initializing client: %s", err)) - return 2 - } - - if err := client.Sys().EnableAuthWithOptions(path, &api.EnableAuthOptions{ - Type: authType, - Description: description, - Config: api.AuthConfigInput{ - PluginName: pluginName, - }, - Local: local, - SealWrap: sealWrap, - }); err != nil { - c.Ui.Error(fmt.Sprintf( - "Error: %s", err)) - return 2 - } - - authTypeOutput := fmt.Sprintf("'%s'", authType) - if authType == "plugin" { - authTypeOutput = fmt.Sprintf("plugin '%s'", pluginName) - } - - c.Ui.Output(fmt.Sprintf( - "Successfully enabled %s at '%s'!", - authTypeOutput, path)) - - return 0 + flagDescription string + flagPath string + flagPluginName string + flagLocal bool + flagSealWrap bool } func (c *AuthEnableCommand) Synopsis() string { - return "Enable a new auth provider" + return "Enables a new auth method" } func (c *AuthEnableCommand) Help() string { helpText := ` -Usage: vault auth-enable [options] type +Usage: vault auth enable [options] TYPE - Enable a new auth provider. + Enables a new auth method. An auth method is responsible for authenticating + users or machines and assigning them policies with which they can access + Vault. - This command enables a new auth provider. An auth provider is responsible - for authenticating a user and assigning them policies with which they can - access Vault. + Enable the userpass auth method at userpass/: -General Options: -` + meta.GeneralOptionsUsage() + ` -Auth Enable Options: + $ vault auth enable userpass - -description= Human-friendly description of the purpose of the - auth provider. This shows up in the auth -methods command. + Enable the LDAP auth method at auth-prod/: - -path= Mount point for the auth provider. This defaults - to the type of the mount. This will make the auth - provider available at "/auth/" + $ vault auth enable -path=auth-prod ldap - -plugin-name Name of the auth plugin to use based from the name - in the plugin catalog. + Enable a custom auth plugin (after it's registered in the plugin registry): - -local Mark the mount as a local mount. Local mounts - are not replicated nor (if a secondary) - removed by replication. + $ vault auth enable -path=my-auth -plugin-name=my-auth-plugin plugin + +` + c.Flags().Help() - -seal-wrap Turn on seal wrapping for the mount. -` return strings.TrimSpace(helpText) } -func (c *AuthEnableCommand) AutocompleteArgs() complete.Predictor { - return complete.PredictSet( - "approle", - "cert", - "aws", - "app-id", - "gcp", - "github", - "userpass", - "ldap", - "okta", - "radius", - "plugin", - ) +func (c *AuthEnableCommand) Flags() *FlagSets { + set := c.flagSet(FlagSetHTTP) + f := set.NewFlagSet("Command Options") + + f.StringVar(&StringVar{ + Name: "description", + Target: &c.flagDescription, + Completion: complete.PredictAnything, + Usage: "Human-friendly description for the purpose of this " + + "auth method.", + }) + + f.StringVar(&StringVar{ + Name: "path", + Target: &c.flagPath, + Default: "", // The default is complex, so we have to manually document + Completion: complete.PredictAnything, + Usage: "Place where the auth method will be accessible. This must be " + + "unique across all auth methods. This defaults to the \"type\" of " + + "the auth method. The auth method will be accessible at " + + "\"/auth/\".", + }) + + f.StringVar(&StringVar{ + Name: "plugin-name", + Target: &c.flagPluginName, + Completion: complete.PredictAnything, + Usage: "Name of the auth method plugin. This plugin name must already " + + "exist in the Vault server's plugin catalog.", + }) + + f.BoolVar(&BoolVar{ + Name: "local", + Target: &c.flagLocal, + Default: false, + Usage: "Mark the auth method as local-only. Local auth methods are " + + "not replicated nor removed by replication.", + }) + + f.BoolVar(&BoolVar{ + Name: "seal-wrap", + Target: &c.flagSealWrap, + Default: false, + Usage: "Enable seal wrapping of critical values in the secrets engine.", + }) + + return set +} + +func (c *AuthEnableCommand) AutocompleteArgs() complete.Predictor { + return c.PredictVaultAvailableAuths() } func (c *AuthEnableCommand) AutocompleteFlags() complete.Flags { - return complete.Flags{ - "-description": complete.PredictNothing, - "-path": complete.PredictNothing, - "-plugin-name": complete.PredictNothing, - "-local": complete.PredictNothing, - "-seal-wrap": complete.PredictNothing, - } + return c.Flags().Completions() +} + +func (c *AuthEnableCommand) Run(args []string) int { + f := c.Flags() + + if err := f.Parse(args); err != nil { + c.UI.Error(err.Error()) + return 1 + } + + args = f.Args() + switch { + case len(args) < 1: + c.UI.Error(fmt.Sprintf("Not enough arguments (expected 1, got %d)", len(args))) + return 1 + case len(args) > 1: + c.UI.Error(fmt.Sprintf("Too many arguments (expected 1, got %d)", len(args))) + return 1 + } + + client, err := c.Client() + if err != nil { + c.UI.Error(err.Error()) + return 2 + } + + authType := strings.TrimSpace(args[0]) + + // If no path is specified, we default the path to the backend type + // or use the plugin name if it's a plugin backend + authPath := c.flagPath + if authPath == "" { + if authType == "plugin" { + authPath = c.flagPluginName + } else { + authPath = authType + } + } + + // Append a trailing slash to indicate it's a path in output + authPath = ensureTrailingSlash(authPath) + + if err := client.Sys().EnableAuthWithOptions(authPath, &api.EnableAuthOptions{ + Type: authType, + Description: c.flagDescription, + Local: c.flagLocal, + SealWrap: c.flagSealWrap, + Config: api.AuthConfigInput{ + PluginName: c.flagPluginName, + }, + }); err != nil { + c.UI.Error(fmt.Sprintf("Error enabling %s auth: %s", authType, err)) + return 2 + } + + authThing := authType + " auth method" + if authType == "plugin" { + authThing = c.flagPluginName + " plugin" + } + + c.UI.Output(fmt.Sprintf("Success! Enabled %s at: %s", authThing, authPath)) + return 0 } diff --git a/command/auth_enable_test.go b/command/auth_enable_test.go index 0f8348700f..e4308f9934 100644 --- a/command/auth_enable_test.go +++ b/command/auth_enable_test.go @@ -1,50 +1,144 @@ package command import ( + "strings" "testing" - "github.com/hashicorp/vault/http" - "github.com/hashicorp/vault/meta" - "github.com/hashicorp/vault/vault" "github.com/mitchellh/cli" ) -func TestAuthEnable(t *testing.T) { - core, _, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() +func testAuthEnableCommand(tb testing.TB) (*cli.MockUi, *AuthEnableCommand) { + tb.Helper() - ui := new(cli.MockUi) - c := &AuthEnableCommand{ - Meta: meta.Meta{ - ClientToken: token, - Ui: ui, + ui := cli.NewMockUi() + return ui, &AuthEnableCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + } +} + +func TestAuthEnableCommand_Run(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + args []string + out string + code int + }{ + { + "not_enough_args", + nil, + "Not enough arguments", + 1, + }, + { + "too_many_args", + []string{"foo", "bar"}, + "Too many arguments", + 1, + }, + { + "not_a_valid_auth", + []string{"nope_definitely_not_a_valid_mount_like_ever"}, + "", + 2, }, } - args := []string{ - "-address", addr, - "noop", - } - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + ui, cmd := testAuthEnableCommand(t) + cmd.client = client + + code := cmd.Run(tc.args) + if code != tc.code { + t.Errorf("expected %d to be %d", code, tc.code) + } + + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, tc.out) { + t.Errorf("expected %q to contain %q", combined, tc.out) + } + }) } - client, err := c.Client() - if err != nil { - t.Fatalf("err: %s", err) - } + t.Run("integration", func(t *testing.T) { + t.Parallel() - mounts, err := client.Sys().ListAuth() - if err != nil { - t.Fatalf("err: %s", err) - } + client, closer := testVaultServer(t) + defer closer() - mount, ok := mounts["noop/"] - if !ok { - t.Fatal("should have noop mount") - } - if mount.Type != "noop" { - t.Fatal("should be noop type") - } + ui, cmd := testAuthEnableCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "-path", "auth_integration/", + "-description", "The best kind of test", + "userpass", + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Success! Enabled userpass auth method at:" + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + + auths, err := client.Sys().ListAuth() + if err != nil { + t.Fatal(err) + } + + authInfo, ok := auths["auth_integration/"] + if !ok { + t.Fatalf("expected mount to exist") + } + if exp := "userpass"; authInfo.Type != exp { + t.Errorf("expected %q to be %q", authInfo.Type, exp) + } + if exp := "The best kind of test"; authInfo.Description != exp { + t.Errorf("expected %q to be %q", authInfo.Description, exp) + } + }) + + t.Run("communication_failure", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServerBad(t) + defer closer() + + ui, cmd := testAuthEnableCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "userpass", + }) + if exp := 2; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Error enabling userpass auth: " + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + }) + + t.Run("no_tabs", func(t *testing.T) { + t.Parallel() + + _, cmd := testAuthEnableCommand(t) + assertNoTabs(t, cmd) + }) } diff --git a/command/auth_help.go b/command/auth_help.go new file mode 100644 index 0000000000..d18d1bf095 --- /dev/null +++ b/command/auth_help.go @@ -0,0 +1,125 @@ +package command + +import ( + "fmt" + "strings" + + "github.com/mitchellh/cli" + "github.com/posener/complete" +) + +var _ cli.Command = (*AuthHelpCommand)(nil) +var _ cli.CommandAutocomplete = (*AuthHelpCommand)(nil) + +type AuthHelpCommand struct { + *BaseCommand + + Handlers map[string]LoginHandler +} + +func (c *AuthHelpCommand) Synopsis() string { + return "Prints usage for an auth method" +} + +func (c *AuthHelpCommand) Help() string { + helpText := ` +Usage: vault auth help [options] TYPE | PATH + + Prints usage and help for an auth method. + + - If given a TYPE, this command prints the default help for the + auth method of that type. + + - If given a PATH, this command prints the help output for the + auth method enabled at that path. This path must already + exist. + + Get usage instructions for the userpass auth method: + + $ vault auth help userpass + + Print usage for the auth method enabled at my-method/: + + $ vault auth help my-method/ + + Each auth method produces its own help output. + +` + c.Flags().Help() + + return strings.TrimSpace(helpText) +} + +func (c *AuthHelpCommand) Flags() *FlagSets { + return c.flagSet(FlagSetHTTP) +} + +func (c *AuthHelpCommand) AutocompleteArgs() complete.Predictor { + handlers := make([]string, 0, len(c.Handlers)) + for k := range c.Handlers { + handlers = append(handlers, k) + } + return complete.PredictSet(handlers...) +} + +func (c *AuthHelpCommand) AutocompleteFlags() complete.Flags { + return c.Flags().Completions() +} + +func (c *AuthHelpCommand) Run(args []string) int { + f := c.Flags() + + if err := f.Parse(args); err != nil { + c.UI.Error(err.Error()) + return 1 + } + + args = f.Args() + switch { + case len(args) < 1: + c.UI.Error(fmt.Sprintf("Not enough arguments (expected 1, got %d)", len(args))) + return 1 + case len(args) > 1: + c.UI.Error(fmt.Sprintf("Too many arguments (expected 1, got %d)", len(args))) + return 1 + } + + client, err := c.Client() + if err != nil { + c.UI.Error(err.Error()) + return 2 + } + + // Start with the assumption that we have an auth type, not a path. + authType := strings.TrimSpace(args[0]) + + authHandler, ok := c.Handlers[authType] + if !ok { + // There was no auth type by that name, see if it's a mount + auths, err := client.Sys().ListAuth() + if err != nil { + c.UI.Error(fmt.Sprintf("Error listing auth methods: %s", err)) + return 2 + } + + authPath := ensureTrailingSlash(sanitizePath(args[0])) + auth, ok := auths[authPath] + if !ok { + c.UI.Error(fmt.Sprintf( + "Error retrieving help: unknown auth method: %s", authType)) + return 1 + } + + authHandler, ok = c.Handlers[auth.Type] + if !ok { + c.UI.Error(wrapAtLength(fmt.Sprintf( + "INTERNAL ERROR! Found an auth method enabled at %s, but "+ + "its type %q is not registered in Vault. This is a bug and should "+ + "be reported. Please open an issue at github.com/hashicorp/vault.", + authPath, authType))) + return 2 + } + } + + c.UI.Output(authHandler.Help()) + return 0 +} diff --git a/command/auth_help_test.go b/command/auth_help_test.go new file mode 100644 index 0000000000..9457bea0ec --- /dev/null +++ b/command/auth_help_test.go @@ -0,0 +1,152 @@ +package command + +import ( + "strings" + "testing" + + "github.com/mitchellh/cli" + + credUserpass "github.com/hashicorp/vault/builtin/credential/userpass" +) + +func testAuthHelpCommand(tb testing.TB) (*cli.MockUi, *AuthHelpCommand) { + tb.Helper() + + ui := cli.NewMockUi() + return ui, &AuthHelpCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + Handlers: map[string]LoginHandler{ + "userpass": &credUserpass.CLIHandler{ + DefaultMount: "userpass", + }, + }, + } +} + +func TestAuthHelpCommand_Run(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + args []string + out string + code int + }{ + { + "too_many_args", + []string{"foo", "bar"}, + "Too many arguments", + 1, + }, + { + "not_enough_args", + nil, + "Not enough arguments", + 1, + }, + } + + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + ui, cmd := testAuthHelpCommand(t) + + code := cmd.Run(tc.args) + if code != tc.code { + t.Errorf("expected %d to be %d", code, tc.code) + } + + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, tc.out) { + t.Errorf("expected %q to contain %q", combined, tc.out) + } + }) + } + + t.Run("path", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + if err := client.Sys().EnableAuth("foo", "userpass", ""); err != nil { + t.Fatal(err) + } + + ui, cmd := testAuthHelpCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "foo/", + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Usage: vault login -method=userpass" + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + }) + + t.Run("type", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + // No mounted auth methods + + ui, cmd := testAuthHelpCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "userpass", + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Usage: vault login -method=userpass" + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + }) + + t.Run("communication_failure", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServerBad(t) + defer closer() + + ui, cmd := testAuthHelpCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "sys/mounts", + }) + if exp := 2; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Error listing auth methods: " + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + }) + + t.Run("no_tabs", func(t *testing.T) { + t.Parallel() + + _, cmd := testAuthHelpCommand(t) + assertNoTabs(t, cmd) + }) +} diff --git a/command/auth_list.go b/command/auth_list.go new file mode 100644 index 0000000000..ff56b8e022 --- /dev/null +++ b/command/auth_list.go @@ -0,0 +1,167 @@ +package command + +import ( + "fmt" + "sort" + "strconv" + "strings" + + "github.com/hashicorp/vault/api" + "github.com/mitchellh/cli" + "github.com/posener/complete" +) + +var _ cli.Command = (*AuthListCommand)(nil) +var _ cli.CommandAutocomplete = (*AuthListCommand)(nil) + +type AuthListCommand struct { + *BaseCommand + + flagDetailed bool +} + +func (c *AuthListCommand) Synopsis() string { + return "Lists enabled auth methods" +} + +func (c *AuthListCommand) Help() string { + helpText := ` +Usage: vault auth list [options] + + Lists the enabled auth methods on the Vault server. This command also outputs + information about the method including configuration and human-friendly + descriptions. A TTL of "system" indicates that the system default is in use. + + List all enabled auth methods: + + $ vault auth list + + List all enabled auth methods with detailed output: + + $ vault auth list -detailed + +` + c.Flags().Help() + + return strings.TrimSpace(helpText) +} + +func (c *AuthListCommand) Flags() *FlagSets { + set := c.flagSet(FlagSetHTTP) + + f := set.NewFlagSet("Command Options") + + f.BoolVar(&BoolVar{ + Name: "detailed", + Target: &c.flagDetailed, + Default: false, + Usage: "Print detailed information such as configuration and replication " + + "status about each auth method.", + }) + + return set +} + +func (c *AuthListCommand) AutocompleteArgs() complete.Predictor { + return nil +} + +func (c *AuthListCommand) AutocompleteFlags() complete.Flags { + return c.Flags().Completions() +} + +func (c *AuthListCommand) Run(args []string) int { + f := c.Flags() + + if err := f.Parse(args); err != nil { + c.UI.Error(err.Error()) + return 1 + } + + args = f.Args() + if len(args) > 0 { + c.UI.Error(fmt.Sprintf("Too many arguments (expected 0, got %d)", len(args))) + return 1 + } + + client, err := c.Client() + if err != nil { + c.UI.Error(err.Error()) + return 2 + } + + auths, err := client.Sys().ListAuth() + if err != nil { + c.UI.Error(fmt.Sprintf("Error listing enabled authentications: %s", err)) + return 2 + } + + if c.flagDetailed { + c.UI.Output(tableOutput(c.detailedMounts(auths), nil)) + return 0 + } + + c.UI.Output(tableOutput(c.simpleMounts(auths), nil)) + return 0 +} + +func (c *AuthListCommand) simpleMounts(auths map[string]*api.AuthMount) []string { + paths := make([]string, 0, len(auths)) + for path := range auths { + paths = append(paths, path) + } + sort.Strings(paths) + + out := []string{"Path | Type | Description"} + for _, path := range paths { + mount := auths[path] + out = append(out, fmt.Sprintf("%s | %s | %s", path, mount.Type, mount.Description)) + } + + return out +} + +func (c *AuthListCommand) detailedMounts(auths map[string]*api.AuthMount) []string { + paths := make([]string, 0, len(auths)) + for path := range auths { + paths = append(paths, path) + } + sort.Strings(paths) + + calcTTL := func(typ string, ttl int) string { + switch { + case typ == "system", typ == "cubbyhole": + return "" + case ttl != 0: + return strconv.Itoa(ttl) + default: + return "system" + } + } + + out := []string{"Path | Type | Accessor | Plugin | Default TTL | Max TTL | Replication | Seal Wrap | Description"} + for _, path := range paths { + mount := auths[path] + + defaultTTL := calcTTL(mount.Type, mount.Config.DefaultLeaseTTL) + maxTTL := calcTTL(mount.Type, mount.Config.MaxLeaseTTL) + + replication := "replicated" + if mount.Local { + replication = "local" + } + + out = append(out, fmt.Sprintf("%s | %s | %s | %s | %s | %s | %s | %t | %s", + path, + mount.Type, + mount.Accessor, + mount.Config.PluginName, + defaultTTL, + maxTTL, + replication, + mount.SealWrap, + mount.Description, + )) + } + + return out +} diff --git a/command/auth_list_test.go b/command/auth_list_test.go new file mode 100644 index 0000000000..decf6e9b06 --- /dev/null +++ b/command/auth_list_test.go @@ -0,0 +1,105 @@ +package command + +import ( + "strings" + "testing" + + "github.com/mitchellh/cli" +) + +func testAuthListCommand(tb testing.TB) (*cli.MockUi, *AuthListCommand) { + tb.Helper() + + ui := cli.NewMockUi() + return ui, &AuthListCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + } +} + +func TestAuthListCommand_Run(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + args []string + out string + code int + }{ + { + "too_many_args", + []string{"foo"}, + "Too many arguments", + 1, + }, + { + "lists", + nil, + "Path", + 0, + }, + { + "detailed", + []string{"-detailed"}, + "Default TTL", + 0, + }, + } + + t.Run("validations", func(t *testing.T) { + t.Parallel() + + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + ui, cmd := testAuthListCommand(t) + cmd.client = client + + code := cmd.Run(tc.args) + if code != tc.code { + t.Errorf("expected %d to be %d", code, tc.code) + } + + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, tc.out) { + t.Errorf("expected %q to contain %q", combined, tc.out) + } + }) + } + }) + + t.Run("communication_failure", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServerBad(t) + defer closer() + + ui, cmd := testAuthListCommand(t) + cmd.client = client + + code := cmd.Run([]string{}) + if exp := 2; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Error listing enabled authentications: " + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + }) + + t.Run("no_tabs", func(t *testing.T) { + t.Parallel() + + _, cmd := testAuthListCommand(t) + assertNoTabs(t, cmd) + }) +} diff --git a/command/auth_test.go b/command/auth_test.go index 8243129083..5ec0cf60d3 100644 --- a/command/auth_test.go +++ b/command/auth_test.go @@ -1,400 +1,135 @@ package command import ( - "fmt" - "io" - "io/ioutil" - "os" - "path/filepath" "strings" "testing" - credUserpass "github.com/hashicorp/vault/builtin/credential/userpass" - "github.com/hashicorp/vault/logical" - - "github.com/hashicorp/vault/api" - "github.com/hashicorp/vault/http" - "github.com/hashicorp/vault/meta" - "github.com/hashicorp/vault/vault" "github.com/mitchellh/cli" + + credToken "github.com/hashicorp/vault/builtin/credential/token" + credUserpass "github.com/hashicorp/vault/builtin/credential/userpass" + "github.com/hashicorp/vault/command/token" ) -func TestAuth_methods(t *testing.T) { - core, _, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() +func testAuthCommand(tb testing.TB) (*cli.MockUi, *AuthCommand) { + tb.Helper() - testAuthInit(t) + ui := cli.NewMockUi() + return ui, &AuthCommand{ + BaseCommand: &BaseCommand{ + UI: ui, - ui := new(cli.MockUi) - c := &AuthCommand{ - Meta: meta.Meta{ - ClientToken: token, - Ui: ui, - TokenHelper: DefaultTokenHelper, + // Override to our own token helper + tokenHelper: token.NewTestingTokenHelper(), + }, + Handlers: map[string]LoginHandler{ + "token": &credToken.CLIHandler{}, + "userpass": &credUserpass.CLIHandler{}, }, - } - - args := []string{ - "-address", addr, - "-methods", - } - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } - - output := ui.OutputWriter.String() - if !strings.Contains(output, "token") { - t.Fatalf("bad: %#v", output) } } -func TestAuth_token(t *testing.T) { - core, _, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() +func TestAuthCommand_Run(t *testing.T) { + t.Parallel() - testAuthInit(t) + // TODO: remove in 0.9.0 + t.Run("deprecated_methods", func(t *testing.T) { + t.Parallel() - ui := new(cli.MockUi) - c := &AuthCommand{ - Meta: meta.Meta{ - Ui: ui, - TokenHelper: DefaultTokenHelper, - }, - } + client, closer := testVaultServer(t) + defer closer() - args := []string{ - "-address", addr, - token, - } - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } + ui, cmd := testAuthCommand(t) + cmd.client = client - helper, err := c.TokenHelper() - if err != nil { - t.Fatalf("err: %s", err) - } + // vault auth -methods -> vault auth list + code := cmd.Run([]string{"-methods"}) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d: %s", code, exp, ui.ErrorWriter.String()) + } + stdout, stderr := ui.OutputWriter.String(), ui.ErrorWriter.String() - actual, err := helper.Get() - if err != nil { - t.Fatalf("err: %s", err) - } + if expected := "WARNING!"; !strings.Contains(stderr, expected) { + t.Errorf("expected %q to contain %q", stderr, expected) + } - if actual != token { - t.Fatalf("bad: %s", actual) - } -} - -func TestAuth_wrapping(t *testing.T) { - baseConfig := &vault.CoreConfig{ - CredentialBackends: map[string]logical.Factory{ - "userpass": credUserpass.Factory, - }, - } - cluster := vault.NewTestCluster(t, baseConfig, &vault.TestClusterOptions{ - HandlerFunc: http.Handler, - BaseListenAddress: "127.0.0.1:8200", + if expected := "token/"; !strings.Contains(stdout, expected) { + t.Errorf("expected %q to contain %q", stdout, expected) + } }) - cluster.Start() - defer cluster.Cleanup() - testAuthInit(t) + t.Run("deprecated_method_help", func(t *testing.T) { + t.Parallel() - client := cluster.Cores[0].Client - err := client.Sys().EnableAuthWithOptions("userpass", &api.EnableAuthOptions{ - Type: "userpass", + client, closer := testVaultServer(t) + defer closer() + + ui, cmd := testAuthCommand(t) + cmd.client = client + + // vault auth -method=foo -method-help -> vault auth help foo + code := cmd.Run([]string{ + "-method=userpass", + "-method-help", + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d: %s", code, exp, ui.ErrorWriter.String()) + } + stdout, stderr := ui.OutputWriter.String(), ui.ErrorWriter.String() + + if expected := "WARNING!"; !strings.Contains(stderr, expected) { + t.Errorf("expected %q to contain %q", stderr, expected) + } + + if expected := "vault login"; !strings.Contains(stdout, expected) { + t.Errorf("expected %q to contain %q", stdout, expected) + } }) - if err != nil { - t.Fatal(err) - } - _, err = client.Logical().Write("auth/userpass/users/foo", map[string]interface{}{ - "password": "bar", - "policies": "zip,zap", + + t.Run("deprecated_login", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + if err := client.Sys().EnableAuth("my-auth", "userpass", ""); err != nil { + t.Fatal(err) + } + if _, err := client.Logical().Write("auth/my-auth/users/test", map[string]interface{}{ + "password": "test", + "policies": "default", + }); err != nil { + t.Fatal(err) + } + + ui, cmd := testAuthCommand(t) + cmd.client = client + + // vault auth ARGS -> vault login ARGS + code := cmd.Run([]string{ + "-method", "userpass", + "-path", "my-auth", + "username=test", + "password=test", + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d: %s", code, exp, ui.ErrorWriter.String()) + } + stdout, stderr := ui.OutputWriter.String(), ui.ErrorWriter.String() + + if expected := "WARNING!"; !strings.Contains(stderr, expected) { + t.Errorf("expected %q to contain %q", stderr, expected) + } + + if expected := "Success! You are now authenticated."; !strings.Contains(stdout, expected) { + t.Errorf("expected %q to contain %q", stdout, expected) + } }) - if err != nil { - t.Fatal(err) - } - ui := new(cli.MockUi) - c := &AuthCommand{ - Meta: meta.Meta{ - Ui: ui, - TokenHelper: DefaultTokenHelper, - }, - Handlers: map[string]AuthHandler{ - "userpass": &credUserpass.CLIHandler{DefaultMount: "userpass"}, - }, - } + t.Run("no_tabs", func(t *testing.T) { + t.Parallel() - args := []string{ - "-address", - "https://127.0.0.1:8200", - "-tls-skip-verify", - "-method", - "userpass", - "username=foo", - "password=bar", - } - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } - - // Test again with wrapping - ui = new(cli.MockUi) - c = &AuthCommand{ - Meta: meta.Meta{ - Ui: ui, - TokenHelper: DefaultTokenHelper, - }, - Handlers: map[string]AuthHandler{ - "userpass": &credUserpass.CLIHandler{DefaultMount: "userpass"}, - }, - } - - args = []string{ - "-address", - "https://127.0.0.1:8200", - "-tls-skip-verify", - "-wrap-ttl", - "5m", - "-method", - "userpass", - "username=foo", - "password=bar", - } - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } - - // Test again with no-store - ui = new(cli.MockUi) - c = &AuthCommand{ - Meta: meta.Meta{ - Ui: ui, - TokenHelper: DefaultTokenHelper, - }, - Handlers: map[string]AuthHandler{ - "userpass": &credUserpass.CLIHandler{DefaultMount: "userpass"}, - }, - } - - args = []string{ - "-address", - "https://127.0.0.1:8200", - "-tls-skip-verify", - "-wrap-ttl", - "5m", - "-no-store", - "-method", - "userpass", - "username=foo", - "password=bar", - } - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } - - // Test again with wrapping and token-only - ui = new(cli.MockUi) - c = &AuthCommand{ - Meta: meta.Meta{ - Ui: ui, - TokenHelper: DefaultTokenHelper, - }, - Handlers: map[string]AuthHandler{ - "userpass": &credUserpass.CLIHandler{DefaultMount: "userpass"}, - }, - } - - args = []string{ - "-address", - "https://127.0.0.1:8200", - "-tls-skip-verify", - "-wrap-ttl", - "5m", - "-token-only", - "-method", - "userpass", - "username=foo", - "password=bar", - } - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } - token := strings.TrimSpace(ui.OutputWriter.String()) - if token == "" { - t.Fatal("expected to find token in output") - } - secret, err := client.Logical().Unwrap(token) - if err != nil { - t.Fatal(err) - } - if secret.Auth.ClientToken == "" { - t.Fatal("no client token found") - } + _, cmd := testAuthCommand(t) + assertNoTabs(t, cmd) + }) } - -func TestAuth_token_nostore(t *testing.T) { - core, _, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() - - testAuthInit(t) - - ui := new(cli.MockUi) - c := &AuthCommand{ - Meta: meta.Meta{ - Ui: ui, - TokenHelper: DefaultTokenHelper, - }, - } - - args := []string{ - "-address", addr, - "-no-store", - token, - } - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } - - helper, err := c.TokenHelper() - if err != nil { - t.Fatalf("err: %s", err) - } - - actual, err := helper.Get() - if err != nil { - t.Fatalf("err: %s", err) - } - - if actual != "" { - t.Fatalf("bad: %s", actual) - } -} - -func TestAuth_stdin(t *testing.T) { - core, _, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() - - testAuthInit(t) - - stdinR, stdinW := io.Pipe() - ui := new(cli.MockUi) - c := &AuthCommand{ - Meta: meta.Meta{ - Ui: ui, - TokenHelper: DefaultTokenHelper, - }, - testStdin: stdinR, - } - - go func() { - stdinW.Write([]byte(token)) - stdinW.Close() - }() - - args := []string{ - "-address", addr, - "-", - } - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } -} - -func TestAuth_badToken(t *testing.T) { - core, _, _ := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() - - testAuthInit(t) - - ui := new(cli.MockUi) - c := &AuthCommand{ - Meta: meta.Meta{ - Ui: ui, - TokenHelper: DefaultTokenHelper, - }, - } - - args := []string{ - "-address", addr, - "not-a-valid-token", - } - if code := c.Run(args); code != 1 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } -} - -func TestAuth_method(t *testing.T) { - core, _, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() - - testAuthInit(t) - - ui := new(cli.MockUi) - c := &AuthCommand{ - Handlers: map[string]AuthHandler{ - "test": &testAuthHandler{}, - }, - Meta: meta.Meta{ - Ui: ui, - TokenHelper: DefaultTokenHelper, - }, - } - - args := []string{ - "-address", addr, - "-method=test", - "foo=" + token, - } - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } - - helper, err := c.TokenHelper() - if err != nil { - t.Fatalf("err: %s", err) - } - - actual, err := helper.Get() - if err != nil { - t.Fatalf("err: %s", err) - } - - if actual != token { - t.Fatalf("bad: %s", actual) - } -} - -func testAuthInit(t *testing.T) { - td, err := ioutil.TempDir("", "vault") - if err != nil { - t.Fatalf("err: %s", err) - } - - // Set the HOME env var so we get that right - os.Setenv("HOME", td) - - // Write a .vault config to use our custom token helper - config := fmt.Sprintf( - "token_helper = \"\"\n") - ioutil.WriteFile(filepath.Join(td, ".vault"), []byte(config), 0644) -} - -type testAuthHandler struct{} - -func (h *testAuthHandler) Auth(c *api.Client, m map[string]string) (*api.Secret, error) { - return &api.Secret{ - Auth: &api.SecretAuth{ - ClientToken: m["foo"], - }, - }, nil -} - -func (h *testAuthHandler) Help() string { return "" } diff --git a/command/auth_tune.go b/command/auth_tune.go new file mode 100644 index 0000000000..958d11bd1f --- /dev/null +++ b/command/auth_tune.go @@ -0,0 +1,120 @@ +package command + +import ( + "fmt" + "strings" + "time" + + "github.com/hashicorp/vault/api" + "github.com/mitchellh/cli" + "github.com/posener/complete" +) + +var _ cli.Command = (*AuthTuneCommand)(nil) +var _ cli.CommandAutocomplete = (*AuthTuneCommand)(nil) + +type AuthTuneCommand struct { + *BaseCommand + + flagDefaultLeaseTTL time.Duration + flagMaxLeaseTTL time.Duration +} + +func (c *AuthTuneCommand) Synopsis() string { + return "Tunes an auth method configuration" +} + +func (c *AuthTuneCommand) Help() string { + helpText := ` +Usage: vault auth tune [options] PATH + + Tunes the configuration options for the auth method at the given PATH. The + argument corresponds to the PATH where the auth method is enabled, not the + TYPE! + + Tune the default lease for the github auth method: + + $ vault auth tune -default-lease-ttl=72h github/ + +` + c.Flags().Help() + + return strings.TrimSpace(helpText) +} + +func (c *AuthTuneCommand) Flags() *FlagSets { + set := c.flagSet(FlagSetHTTP) + + f := set.NewFlagSet("Command Options") + + f.DurationVar(&DurationVar{ + Name: "default-lease-ttl", + Target: &c.flagDefaultLeaseTTL, + Default: 0, + EnvVar: "", + Completion: complete.PredictAnything, + Usage: "The default lease TTL for this auth method. If unspecified, this " + + "defaults to the Vault server's globally configured default lease TTL, " + + "or a previously configured value for the auth method.", + }) + + f.DurationVar(&DurationVar{ + Name: "max-lease-ttl", + Target: &c.flagMaxLeaseTTL, + Default: 0, + EnvVar: "", + Completion: complete.PredictAnything, + Usage: "The maximum lease TTL for this auth method. If unspecified, this " + + "defaults to the Vault server's globally configured maximum lease TTL, " + + "or a previously configured value for the auth method.", + }) + + return set +} + +func (c *AuthTuneCommand) AutocompleteArgs() complete.Predictor { + return c.PredictVaultAuths() +} + +func (c *AuthTuneCommand) AutocompleteFlags() complete.Flags { + return c.Flags().Completions() +} + +func (c *AuthTuneCommand) Run(args []string) int { + f := c.Flags() + + if err := f.Parse(args); err != nil { + c.UI.Error(err.Error()) + return 1 + } + + args = f.Args() + switch { + case len(args) < 1: + c.UI.Error(fmt.Sprintf("Not enough arguments (expected 1, got %d)", len(args))) + return 1 + case len(args) > 1: + c.UI.Error(fmt.Sprintf("Too many arguments (expected 1, got %d)", len(args))) + return 1 + } + + client, err := c.Client() + if err != nil { + c.UI.Error(err.Error()) + return 2 + } + + // Append /auth (since that's where auths live) and a trailing slash to + // indicate it's a path in output + mountPath := ensureTrailingSlash(sanitizePath(args[0])) + + if err := client.Sys().TuneMount("/auth/"+mountPath, api.MountConfigInput{ + DefaultLeaseTTL: ttlToAPI(c.flagDefaultLeaseTTL), + MaxLeaseTTL: ttlToAPI(c.flagMaxLeaseTTL), + }); err != nil { + c.UI.Error(fmt.Sprintf("Error tuning auth method %s: %s", mountPath, err)) + return 2 + } + + c.UI.Output(fmt.Sprintf("Success! Tuned the auth method at: %s", mountPath)) + return 0 +} diff --git a/command/auth_tune_test.go b/command/auth_tune_test.go new file mode 100644 index 0000000000..61a36441d7 --- /dev/null +++ b/command/auth_tune_test.go @@ -0,0 +1,149 @@ +package command + +import ( + "strings" + "testing" + + "github.com/hashicorp/vault/api" + "github.com/mitchellh/cli" +) + +func testAuthTuneCommand(tb testing.TB) (*cli.MockUi, *AuthTuneCommand) { + tb.Helper() + + ui := cli.NewMockUi() + return ui, &AuthTuneCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + } +} + +func TestAuthTuneCommand_Run(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + args []string + out string + code int + }{ + { + "not_enough_args", + []string{}, + "Not enough arguments", + 1, + }, + { + "too_many_args", + []string{"foo", "bar"}, + "Too many arguments", + 1, + }, + } + + t.Run("validations", func(t *testing.T) { + t.Parallel() + + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + ui, cmd := testAuthTuneCommand(t) + + code := cmd.Run(tc.args) + if code != tc.code { + t.Errorf("expected %d to be %d", code, tc.code) + } + + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, tc.out) { + t.Errorf("expected %q to contain %q", combined, tc.out) + } + }) + } + }) + + t.Run("integration", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + ui, cmd := testAuthTuneCommand(t) + cmd.client = client + + // Mount + if err := client.Sys().EnableAuthWithOptions("my-auth", &api.EnableAuthOptions{ + Type: "userpass", + }); err != nil { + t.Fatal(err) + } + + code := cmd.Run([]string{ + "-default-lease-ttl", "30m", + "-max-lease-ttl", "1h", + "my-auth/", + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Success! Tuned the auth method at: my-auth/" + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + + auths, err := client.Sys().ListAuth() + if err != nil { + t.Fatal(err) + } + + mountInfo, ok := auths["my-auth/"] + if !ok { + t.Fatalf("expected auth to exist") + } + if exp := "userpass"; mountInfo.Type != exp { + t.Errorf("expected %q to be %q", mountInfo.Type, exp) + } + if exp := 1800; mountInfo.Config.DefaultLeaseTTL != exp { + t.Errorf("expected %d to be %d", mountInfo.Config.DefaultLeaseTTL, exp) + } + if exp := 3600; mountInfo.Config.MaxLeaseTTL != exp { + t.Errorf("expected %d to be %d", mountInfo.Config.MaxLeaseTTL, exp) + } + }) + + t.Run("communication_failure", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServerBad(t) + defer closer() + + ui, cmd := testAuthTuneCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "userpass/", + }) + if exp := 2; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Error tuning auth method userpass/: " + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + }) + + t.Run("no_tabs", func(t *testing.T) { + t.Parallel() + + _, cmd := testAuthTuneCommand(t) + assertNoTabs(t, cmd) + }) +} diff --git a/command/base.go b/command/base.go new file mode 100644 index 0000000000..7dcca7671b --- /dev/null +++ b/command/base.go @@ -0,0 +1,399 @@ +package command + +import ( + "bytes" + "flag" + "fmt" + "io" + "io/ioutil" + "regexp" + "strings" + "sync" + "time" + + "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/command/token" + "github.com/mitchellh/cli" + "github.com/pkg/errors" + "github.com/posener/complete" +) + +// maxLineLength is the maximum width of any line. +const maxLineLength int = 78 + +// reRemoveWhitespace is a regular expression for stripping whitespace from +// a string. +var reRemoveWhitespace = regexp.MustCompile(`[\s]+`) + +type BaseCommand struct { + UI cli.Ui + + flags *FlagSets + flagsOnce sync.Once + + flagAddress string + flagCACert string + flagCAPath string + flagClientCert string + flagClientKey string + flagTLSServerName string + flagTLSSkipVerify bool + flagWrapTTL time.Duration + + flagFormat string + flagField string + + tokenHelper token.TokenHelper + + // For testing + client *api.Client +} + +// Client returns the HTTP API client. The client is cached on the command to +// save performance on future calls. +func (c *BaseCommand) Client() (*api.Client, error) { + // Read the test client if present + if c.client != nil { + return c.client, nil + } + + config := api.DefaultConfig() + + if err := config.ReadEnvironment(); err != nil { + return nil, errors.Wrap(err, "failed to read environment") + } + + if c.flagAddress != "" { + config.Address = c.flagAddress + } + + // If we need custom TLS configuration, then set it + if c.flagCACert != "" || c.flagCAPath != "" || c.flagClientCert != "" || + c.flagClientKey != "" || c.flagTLSServerName != "" || c.flagTLSSkipVerify { + t := &api.TLSConfig{ + CACert: c.flagCACert, + CAPath: c.flagCAPath, + ClientCert: c.flagClientCert, + ClientKey: c.flagClientKey, + TLSServerName: c.flagTLSServerName, + Insecure: c.flagTLSSkipVerify, + } + config.ConfigureTLS(t) + } + + // Build the client + client, err := api.NewClient(config) + if err != nil { + return nil, errors.Wrap(err, "failed to create client") + } + + // Set the wrapping function + client.SetWrappingLookupFunc(c.DefaultWrappingLookupFunc) + + // Get the token if it came in from the environment + token := client.Token() + + // If we don't have a token, check the token helper + if token == "" { + helper, err := c.TokenHelper() + if err != nil { + return nil, errors.Wrap(err, "failed to get token helper") + } + token, err = helper.Get() + if err != nil { + return nil, errors.Wrap(err, "failed to get token from token helper") + } + } + + // Set the token + if token != "" { + client.SetToken(token) + } + + return client, nil +} + +// TokenHelper returns the token helper attached to the command. +func (c *BaseCommand) TokenHelper() (token.TokenHelper, error) { + if c.tokenHelper != nil { + return c.tokenHelper, nil + } + + helper, err := DefaultTokenHelper() + if err != nil { + return nil, err + } + return helper, nil +} + +// DefaultWrappingLookupFunc is the default wrapping function based on the +// CLI flag. +func (c *BaseCommand) DefaultWrappingLookupFunc(operation, path string) string { + if c.flagWrapTTL != 0 { + return c.flagWrapTTL.String() + } + + return api.DefaultWrappingLookupFunc(operation, path) +} + +type FlagSetBit uint + +const ( + FlagSetNone FlagSetBit = 1 << iota + FlagSetHTTP + FlagSetOutputField + FlagSetOutputFormat +) + +// flagSet creates the flags for this command. The result is cached on the +// command to save performance on future calls. +func (c *BaseCommand) flagSet(bit FlagSetBit) *FlagSets { + c.flagsOnce.Do(func() { + set := NewFlagSets(c.UI) + + if bit&FlagSetHTTP != 0 { + f := set.NewFlagSet("HTTP Options") + + f.StringVar(&StringVar{ + Name: "address", + Target: &c.flagAddress, + Default: "https://127.0.0.1:8200", + EnvVar: "VAULT_ADDR", + Completion: complete.PredictAnything, + Usage: "Address of the Vault server.", + }) + + f.StringVar(&StringVar{ + Name: "ca-cert", + Target: &c.flagCACert, + Default: "", + EnvVar: "VAULT_CACERT", + Completion: complete.PredictFiles("*"), + Usage: "Path on the local disk to a single PEM-encoded CA " + + "certificate to verify the Vault server's SSL certificate. This " + + "takes precendence over -ca-path.", + }) + + f.StringVar(&StringVar{ + Name: "ca-path", + Target: &c.flagCAPath, + Default: "", + EnvVar: "VAULT_CAPATH", + Completion: complete.PredictDirs("*"), + Usage: "Path on the local disk to a directory of PEM-encoded CA " + + "certificates to verify the Vault server's SSL certificate.", + }) + + f.StringVar(&StringVar{ + Name: "client-cert", + Target: &c.flagClientCert, + Default: "", + EnvVar: "VAULT_CLIENT_CERT", + Completion: complete.PredictFiles("*"), + Usage: "Path on the local disk to a single PEM-encoded CA " + + "certificate to use for TLS authentication to the Vault server. If " + + "this flag is specified, -client-key is also required.", + }) + + f.StringVar(&StringVar{ + Name: "client-key", + Target: &c.flagClientKey, + Default: "", + EnvVar: "VAULT_CLIENT_KEY", + Completion: complete.PredictFiles("*"), + Usage: "Path on the local disk to a single PEM-encoded private key " + + "matching the client certificate from -client-cert.", + }) + + f.StringVar(&StringVar{ + Name: "tls-server-name", + Target: &c.flagTLSServerName, + Default: "", + EnvVar: "VAULT_TLS_SERVER_NAME", + Completion: complete.PredictAnything, + Usage: "Name to use as the SNI host when connecting to the Vault " + + "server via TLS.", + }) + + f.BoolVar(&BoolVar{ + Name: "tls-skip-verify", + Target: &c.flagTLSSkipVerify, + Default: false, + EnvVar: "VAULT_SKIP_VERIFY", + Usage: "Disable verification of TLS certificates. Using this option " + + "is highly discouraged and decreases the security of data " + + "transmissions to and from the Vault server.", + }) + + f.DurationVar(&DurationVar{ + Name: "wrap-ttl", + Target: &c.flagWrapTTL, + Default: 0, + EnvVar: "VAULT_WRAP_TTL", + Completion: complete.PredictAnything, + Usage: "Wraps the response in a cubbyhole token with the requested " + + "TTL. The response is available via the \"vault unwrap\" command. " + + "The TTL is specified as a numeric string with suffix like \"30s\" " + + "or \"5m\".", + }) + } + + if bit&(FlagSetOutputField|FlagSetOutputFormat) != 0 { + f := set.NewFlagSet("Output Options") + + if bit&FlagSetOutputField != 0 { + f.StringVar(&StringVar{ + Name: "field", + Target: &c.flagField, + Default: "", + Completion: complete.PredictAnything, + Usage: "Print only the field with the given name. Specifying " + + "this option will take precedence over other formatting " + + "directives. The result will not have a trailing newline " + + "making it idea for piping to other processes.", + }) + } + + if bit&FlagSetOutputFormat != 0 { + f.StringVar(&StringVar{ + Name: "format", + Target: &c.flagFormat, + Default: "table", + EnvVar: "VAULT_FORMAT", + Completion: complete.PredictSet("table", "json", "yaml"), + Usage: "Print the output in the given format. Valid formats " + + "are \"table\", \"json\", or \"yaml\".", + }) + } + } + + c.flags = set + }) + + return c.flags +} + +// FlagSets is a group of flag sets. +type FlagSets struct { + flagSets []*FlagSet + mainSet *flag.FlagSet + hiddens map[string]struct{} + completions complete.Flags +} + +// NewFlagSets creates a new flag sets. +func NewFlagSets(ui cli.Ui) *FlagSets { + mainSet := flag.NewFlagSet("", flag.ContinueOnError) + + // Errors and usage are controlled by the CLI. + mainSet.Usage = func() {} + mainSet.SetOutput(ioutil.Discard) + + return &FlagSets{ + flagSets: make([]*FlagSet, 0, 6), + mainSet: mainSet, + hiddens: make(map[string]struct{}), + completions: complete.Flags{}, + } +} + +// NewFlagSet creates a new flag set from the given flag sets. +func (f *FlagSets) NewFlagSet(name string) *FlagSet { + flagSet := NewFlagSet(name) + flagSet.mainSet = f.mainSet + flagSet.completions = f.completions + f.flagSets = append(f.flagSets, flagSet) + return flagSet +} + +// Completions returns the completions for this flag set. +func (f *FlagSets) Completions() complete.Flags { + return f.completions +} + +// Parse parses the given flags, returning any errors. +func (f *FlagSets) Parse(args []string) error { + return f.mainSet.Parse(args) +} + +// Args returns the remaining args after parsing. +func (f *FlagSets) Args() []string { + return f.mainSet.Args() +} + +// Help builds custom help for this command, grouping by flag set. +func (fs *FlagSets) Help() string { + var out bytes.Buffer + + for _, set := range fs.flagSets { + printFlagTitle(&out, set.name+":") + set.VisitAll(func(f *flag.Flag) { + // Skip any hidden flags + if v, ok := f.Value.(FlagVisibility); ok && v.Hidden() { + return + } + printFlagDetail(&out, f) + }) + } + + return strings.TrimRight(out.String(), "\n") +} + +// FlagSet is a grouped wrapper around a real flag set and a grouped flag set. +type FlagSet struct { + name string + flagSet *flag.FlagSet + mainSet *flag.FlagSet + completions complete.Flags +} + +// NewFlagSet creates a new flag set. +func NewFlagSet(name string) *FlagSet { + return &FlagSet{ + name: name, + flagSet: flag.NewFlagSet(name, flag.ContinueOnError), + } +} + +// Name returns the name of this flag set. +func (f *FlagSet) Name() string { + return f.name +} + +func (f *FlagSet) Visit(fn func(*flag.Flag)) { + f.flagSet.Visit(fn) +} + +func (f *FlagSet) VisitAll(fn func(*flag.Flag)) { + f.flagSet.VisitAll(fn) +} + +// printFlagTitle prints a consistently-formatted title to the given writer. +func printFlagTitle(w io.Writer, s string) { + fmt.Fprintf(w, "%s\n\n", s) +} + +// printFlagDetail prints a single flag to the given writer. +func printFlagDetail(w io.Writer, f *flag.Flag) { + // Check if the flag is hidden - do not print any flag detail or help output + // if it is hidden. + if h, ok := f.Value.(FlagVisibility); ok && h.Hidden() { + return + } + + // Check for a detailed example + example := "" + if t, ok := f.Value.(FlagExample); ok { + example = t.Example() + } + + if example != "" { + fmt.Fprintf(w, " -%s=<%s>\n", f.Name, example) + } else { + fmt.Fprintf(w, " -%s\n", f.Name) + } + + usage := reRemoveWhitespace.ReplaceAllString(f.Usage, " ") + indented := wrapAtLengthWithPadding(usage, 6) + fmt.Fprintf(w, "%s\n\n", indented) +} diff --git a/command/base_flags.go b/command/base_flags.go new file mode 100644 index 0000000000..57c251d200 --- /dev/null +++ b/command/base_flags.go @@ -0,0 +1,780 @@ +package command + +import ( + "flag" + "fmt" + "os" + "sort" + "strconv" + "strings" + "time" + + "github.com/posener/complete" +) + +// FlagExample is an interface which declares an example value. +type FlagExample interface { + Example() string +} + +// FlagVisibility is an interface which declares whether a flag should be +// hidden from help and completions. This is usually used for deprecations +// on "internal-only" flags. +type FlagVisibility interface { + Hidden() bool +} + +// FlagBool is an interface which boolean flags implement. +type FlagBool interface { + IsBoolFlag() bool +} + +// -- BoolVar and boolValue +type BoolVar struct { + Name string + Aliases []string + Usage string + Default bool + Hidden bool + EnvVar string + Target *bool + Completion complete.Predictor +} + +func (f *FlagSet) BoolVar(i *BoolVar) { + def := i.Default + if v := os.Getenv(i.EnvVar); v != "" { + if b, err := strconv.ParseBool(v); err != nil { + def = b + } + } + + f.VarFlag(&VarFlag{ + Name: i.Name, + Aliases: i.Aliases, + Usage: i.Usage, + Default: strconv.FormatBool(i.Default), + EnvVar: i.EnvVar, + Value: newBoolValue(def, i.Target, i.Hidden), + Completion: i.Completion, + }) +} + +type boolValue struct { + hidden bool + target *bool +} + +func newBoolValue(def bool, target *bool, hidden bool) *boolValue { + *target = def + + return &boolValue{ + hidden: hidden, + target: target, + } +} + +func (b *boolValue) Set(s string) error { + v, err := strconv.ParseBool(s) + if err != nil { + return err + } + + *b.target = v + return nil +} + +func (b *boolValue) Get() interface{} { return *b.target } +func (b *boolValue) String() string { return strconv.FormatBool(*b.target) } +func (b *boolValue) Example() string { return "" } +func (b *boolValue) Hidden() bool { return b.hidden } +func (b *boolValue) IsBoolFlag() bool { return true } + +// -- IntVar and intValue +type IntVar struct { + Name string + Aliases []string + Usage string + Default int + Hidden bool + EnvVar string + Target *int + Completion complete.Predictor +} + +func (f *FlagSet) IntVar(i *IntVar) { + initial := i.Default + if v := os.Getenv(i.EnvVar); v != "" { + if i, err := strconv.ParseInt(v, 0, 64); err != nil { + initial = int(i) + } + } + + def := "" + if i.Default != 0 { + def = strconv.FormatInt(int64(i.Default), 10) + } + + f.VarFlag(&VarFlag{ + Name: i.Name, + Aliases: i.Aliases, + Usage: i.Usage, + Default: def, + EnvVar: i.EnvVar, + Value: newIntValue(initial, i.Target, i.Hidden), + Completion: i.Completion, + }) +} + +type intValue struct { + hidden bool + target *int +} + +func newIntValue(def int, target *int, hidden bool) *intValue { + *target = def + return &intValue{ + hidden: hidden, + target: target, + } +} + +func (i *intValue) Set(s string) error { + v, err := strconv.ParseInt(s, 0, 64) + if err != nil { + return err + } + + *i.target = int(v) + return nil +} + +func (i *intValue) Get() interface{} { return int(*i.target) } +func (i *intValue) String() string { return strconv.Itoa(int(*i.target)) } +func (i *intValue) Example() string { return "int" } +func (i *intValue) Hidden() bool { return i.hidden } + +// -- Int64Var and int64Value +type Int64Var struct { + Name string + Aliases []string + Usage string + Default int64 + Hidden bool + EnvVar string + Target *int64 + Completion complete.Predictor +} + +func (f *FlagSet) Int64Var(i *Int64Var) { + initial := i.Default + if v := os.Getenv(i.EnvVar); v != "" { + if i, err := strconv.ParseInt(v, 0, 64); err != nil { + initial = i + } + } + + def := "" + if i.Default != 0 { + def = strconv.FormatInt(int64(i.Default), 10) + } + + f.VarFlag(&VarFlag{ + Name: i.Name, + Aliases: i.Aliases, + Usage: i.Usage, + Default: def, + EnvVar: i.EnvVar, + Value: newInt64Value(initial, i.Target, i.Hidden), + Completion: i.Completion, + }) +} + +type int64Value struct { + hidden bool + target *int64 +} + +func newInt64Value(def int64, target *int64, hidden bool) *int64Value { + *target = def + return &int64Value{ + hidden: hidden, + target: target, + } +} + +func (i *int64Value) Set(s string) error { + v, err := strconv.ParseInt(s, 0, 64) + if err != nil { + return err + } + + *i.target = v + return nil +} + +func (i *int64Value) Get() interface{} { return int64(*i.target) } +func (i *int64Value) String() string { return strconv.FormatInt(int64(*i.target), 10) } +func (i *int64Value) Example() string { return "int" } +func (i *int64Value) Hidden() bool { return i.hidden } + +// -- UintVar && uintValue +type UintVar struct { + Name string + Aliases []string + Usage string + Default uint + Hidden bool + EnvVar string + Target *uint + Completion complete.Predictor +} + +func (f *FlagSet) UintVar(i *UintVar) { + initial := i.Default + if v := os.Getenv(i.EnvVar); v != "" { + if i, err := strconv.ParseUint(v, 0, 64); err != nil { + initial = uint(i) + } + } + + def := "" + if i.Default != 0 { + def = strconv.FormatUint(uint64(i.Default), 10) + } + + f.VarFlag(&VarFlag{ + Name: i.Name, + Aliases: i.Aliases, + Usage: i.Usage, + Default: def, + EnvVar: i.EnvVar, + Value: newUintValue(initial, i.Target, i.Hidden), + Completion: i.Completion, + }) +} + +type uintValue struct { + hidden bool + target *uint +} + +func newUintValue(def uint, target *uint, hidden bool) *uintValue { + *target = def + return &uintValue{ + hidden: hidden, + target: target, + } +} + +func (i *uintValue) Set(s string) error { + v, err := strconv.ParseUint(s, 0, 64) + if err != nil { + return err + } + + *i.target = uint(v) + return nil +} + +func (i *uintValue) Get() interface{} { return uint(*i.target) } +func (i *uintValue) String() string { return strconv.FormatUint(uint64(*i.target), 10) } +func (i *uintValue) Example() string { return "uint" } +func (i *uintValue) Hidden() bool { return i.hidden } + +// -- Uint64Var and uint64Value +type Uint64Var struct { + Name string + Aliases []string + Usage string + Default uint64 + Hidden bool + EnvVar string + Target *uint64 + Completion complete.Predictor +} + +func (f *FlagSet) Uint64Var(i *Uint64Var) { + initial := i.Default + if v := os.Getenv(i.EnvVar); v != "" { + if i, err := strconv.ParseUint(v, 0, 64); err != nil { + initial = i + } + } + + def := "" + if i.Default != 0 { + strconv.FormatUint(i.Default, 10) + } + + f.VarFlag(&VarFlag{ + Name: i.Name, + Aliases: i.Aliases, + Usage: i.Usage, + Default: def, + EnvVar: i.EnvVar, + Value: newUint64Value(initial, i.Target, i.Hidden), + Completion: i.Completion, + }) +} + +type uint64Value struct { + hidden bool + target *uint64 +} + +func newUint64Value(def uint64, target *uint64, hidden bool) *uint64Value { + *target = def + return &uint64Value{ + hidden: hidden, + target: target, + } +} + +func (i *uint64Value) Set(s string) error { + v, err := strconv.ParseUint(s, 0, 64) + if err != nil { + return err + } + + *i.target = v + return nil +} + +func (i *uint64Value) Get() interface{} { return uint64(*i.target) } +func (i *uint64Value) String() string { return strconv.FormatUint(uint64(*i.target), 10) } +func (i *uint64Value) Example() string { return "uint" } +func (i *uint64Value) Hidden() bool { return i.hidden } + +// -- StringVar and stringValue +type StringVar struct { + Name string + Aliases []string + Usage string + Default string + Hidden bool + EnvVar string + Target *string + Completion complete.Predictor +} + +func (f *FlagSet) StringVar(i *StringVar) { + initial := i.Default + if v := os.Getenv(i.EnvVar); v != "" { + initial = v + } + + def := "" + if i.Default != "" { + def = i.Default + } + + f.VarFlag(&VarFlag{ + Name: i.Name, + Aliases: i.Aliases, + Usage: i.Usage, + Default: def, + EnvVar: i.EnvVar, + Value: newStringValue(initial, i.Target, i.Hidden), + Completion: i.Completion, + }) +} + +type stringValue struct { + hidden bool + target *string +} + +func newStringValue(def string, target *string, hidden bool) *stringValue { + *target = def + return &stringValue{ + hidden: hidden, + target: target, + } +} + +func (s *stringValue) Set(val string) error { + *s.target = val + return nil +} + +func (s *stringValue) Get() interface{} { return *s.target } +func (s *stringValue) String() string { return *s.target } +func (s *stringValue) Example() string { return "string" } +func (s *stringValue) Hidden() bool { return s.hidden } + +// -- Float64Var and float64Value +type Float64Var struct { + Name string + Aliases []string + Usage string + Default float64 + Hidden bool + EnvVar string + Target *float64 + Completion complete.Predictor +} + +func (f *FlagSet) Float64Var(i *Float64Var) { + initial := i.Default + if v := os.Getenv(i.EnvVar); v != "" { + if i, err := strconv.ParseFloat(v, 64); err != nil { + initial = i + } + } + + def := "" + if i.Default != 0 { + def = strconv.FormatFloat(i.Default, 'e', -1, 64) + } + + f.VarFlag(&VarFlag{ + Name: i.Name, + Aliases: i.Aliases, + Usage: i.Usage, + Default: def, + EnvVar: i.EnvVar, + Value: newFloat64Value(initial, i.Target, i.Hidden), + Completion: i.Completion, + }) +} + +type float64Value struct { + hidden bool + target *float64 +} + +func newFloat64Value(def float64, target *float64, hidden bool) *float64Value { + *target = def + return &float64Value{ + hidden: hidden, + target: target, + } +} + +func (f *float64Value) Set(s string) error { + v, err := strconv.ParseFloat(s, 64) + if err != nil { + return err + } + + *f.target = v + return nil +} + +func (f *float64Value) Get() interface{} { return float64(*f.target) } +func (f *float64Value) String() string { return strconv.FormatFloat(float64(*f.target), 'g', -1, 64) } +func (f *float64Value) Example() string { return "float" } +func (f *float64Value) Hidden() bool { return f.hidden } + +// -- DurationVar and durationValue +type DurationVar struct { + Name string + Aliases []string + Usage string + Default time.Duration + Hidden bool + EnvVar string + Target *time.Duration + Completion complete.Predictor +} + +func (f *FlagSet) DurationVar(i *DurationVar) { + initial := i.Default + if v := os.Getenv(i.EnvVar); v != "" { + if d, err := time.ParseDuration(appendDurationSuffix(v)); err != nil { + initial = d + } + } + + def := "" + if i.Default != 0 { + def = i.Default.String() + } + + f.VarFlag(&VarFlag{ + Name: i.Name, + Aliases: i.Aliases, + Usage: i.Usage, + Default: def, + EnvVar: i.EnvVar, + Value: newDurationValue(initial, i.Target, i.Hidden), + Completion: i.Completion, + }) +} + +type durationValue struct { + hidden bool + target *time.Duration +} + +func newDurationValue(def time.Duration, target *time.Duration, hidden bool) *durationValue { + *target = def + return &durationValue{ + hidden: hidden, + target: target, + } +} + +func (d *durationValue) Set(s string) error { + // Maintain bc for people specifying "system" as the value. + if s == "system" { + s = "-1" + } + + v, err := time.ParseDuration(appendDurationSuffix(s)) + if err != nil { + return err + } + *d.target = v + return nil +} + +func (d *durationValue) Get() interface{} { return time.Duration(*d.target) } +func (d *durationValue) String() string { return (*d.target).String() } +func (d *durationValue) Example() string { return "duration" } +func (d *durationValue) Hidden() bool { return d.hidden } + +// appendDurationSuffix is used as a backwards-compat tool for assuming users +// meant "seconds" when they do not provide a suffixed duration value. +func appendDurationSuffix(s string) string { + if strings.HasSuffix(s, "s") || strings.HasSuffix(s, "m") || strings.HasSuffix(s, "h") { + return s + } + return s + "s" +} + +// -- StringSliceVar and stringSliceValue +type StringSliceVar struct { + Name string + Aliases []string + Usage string + Default []string + Hidden bool + EnvVar string + Target *[]string + Completion complete.Predictor +} + +func (f *FlagSet) StringSliceVar(i *StringSliceVar) { + initial := i.Default + if v := os.Getenv(i.EnvVar); v != "" { + parts := strings.Split(v, ",") + for i := range parts { + parts[i] = strings.TrimSpace(parts[i]) + } + initial = parts + } + + def := "" + if i.Default != nil { + def = strings.Join(i.Default, ",") + } + + f.VarFlag(&VarFlag{ + Name: i.Name, + Aliases: i.Aliases, + Usage: i.Usage, + Default: def, + EnvVar: i.EnvVar, + Value: newStringSliceValue(initial, i.Target, i.Hidden), + Completion: i.Completion, + }) +} + +type stringSliceValue struct { + hidden bool + target *[]string +} + +func newStringSliceValue(def []string, target *[]string, hidden bool) *stringSliceValue { + *target = def + return &stringSliceValue{ + hidden: hidden, + target: target, + } +} + +func (s *stringSliceValue) Set(val string) error { + *s.target = append(*s.target, strings.TrimSpace(val)) + return nil +} + +func (s *stringSliceValue) Get() interface{} { return *s.target } +func (s *stringSliceValue) String() string { return strings.Join(*s.target, ",") } +func (s *stringSliceValue) Example() string { return "string" } +func (s *stringSliceValue) Hidden() bool { return s.hidden } + +// -- StringMapVar and stringMapValue +type StringMapVar struct { + Name string + Aliases []string + Usage string + Default map[string]string + Hidden bool + Target *map[string]string + Completion complete.Predictor +} + +func (f *FlagSet) StringMapVar(i *StringMapVar) { + def := "" + if i.Default != nil { + def = mapToKV(i.Default) + } + + f.VarFlag(&VarFlag{ + Name: i.Name, + Aliases: i.Aliases, + Usage: i.Usage, + Default: def, + Value: newStringMapValue(i.Default, i.Target, i.Hidden), + Completion: i.Completion, + }) +} + +type stringMapValue struct { + hidden bool + target *map[string]string +} + +func newStringMapValue(def map[string]string, target *map[string]string, hidden bool) *stringMapValue { + *target = def + return &stringMapValue{ + hidden: hidden, + target: target, + } +} + +func (s *stringMapValue) Set(val string) error { + idx := strings.Index(val, "=") + if idx == -1 { + return fmt.Errorf("Missing = in KV pair: %s", val) + } + + if *s.target == nil { + *s.target = make(map[string]string) + } + + k, v := val[0:idx], val[idx+1:] + (*s.target)[k] = v + return nil +} + +func (s *stringMapValue) Get() interface{} { return *s.target } +func (s *stringMapValue) String() string { return mapToKV(*s.target) } +func (s *stringMapValue) Example() string { return "key=value" } +func (s *stringMapValue) Hidden() bool { return s.hidden } + +func mapToKV(m map[string]string) string { + list := make([]string, 0, len(m)) + for k, _ := range m { + list = append(list, k) + } + sort.Strings(list) + + for i, k := range list { + list[i] = k + "=" + m[k] + } + + return strings.Join(list, ",") +} + +// -- VarFlag +type VarFlag struct { + Name string + Aliases []string + Usage string + Default string + EnvVar string + Value flag.Value + Completion complete.Predictor +} + +func (f *FlagSet) VarFlag(i *VarFlag) { + // If the flag is marked as hidden, just add it to the set and return to + // avoid unnecessary computations here. We do not want to add completions or + // generate help output for hidden flags. + if v, ok := i.Value.(FlagVisibility); ok && v.Hidden() { + f.Var(i.Value, i.Name, "") + return + } + + // Calculate the full usage + usage := i.Usage + + if len(i.Aliases) > 0 { + sentence := make([]string, len(i.Aliases)) + for i, a := range i.Aliases { + sentence[i] = fmt.Sprintf(`"-%s"`, a) + } + + aliases := "" + switch len(sentence) { + case 0: + // impossible... + case 1: + aliases = sentence[0] + case 2: + aliases = sentence[0] + " and " + sentence[1] + default: + sentence[len(sentence)-1] = "and " + sentence[len(sentence)-1] + aliases = strings.Join(sentence, ", ") + } + + usage += fmt.Sprintf(" This is aliased as %s.", aliases) + } + + if i.Default != "" { + usage += fmt.Sprintf(" The default is %s.", i.Default) + } + + if i.EnvVar != "" { + usage += fmt.Sprintf(" This can also be specified via the %s "+ + "environment variable.", i.EnvVar) + } + + // Add aliases to the main set + for _, a := range i.Aliases { + f.mainSet.Var(i.Value, a, "") + } + + f.Var(i.Value, i.Name, usage) + f.completions["-"+i.Name] = i.Completion +} + +// Var is a lower-level API for adding something to the flags. It should be used +// wtih caution, since it bypasses all validation. Consider VarFlag instead. +func (f *FlagSet) Var(value flag.Value, name, usage string) { + f.mainSet.Var(value, name, usage) + f.flagSet.Var(value, name, usage) +} + +// -- helpers +func envDefault(key, def string) string { + if v := os.Getenv(key); v != "" { + return v + } + return def +} + +func envBoolDefault(key string, def bool) bool { + if v := os.Getenv(key); v != "" { + b, err := strconv.ParseBool(v) + if err != nil { + panic(err) + } + return b + } + return def +} + +func envDurationDefault(key string, def time.Duration) time.Duration { + if v := os.Getenv(key); v != "" { + d, err := time.ParseDuration(v) + if err != nil { + panic(err) + } + return d + } + return def +} diff --git a/command/base_helpers.go b/command/base_helpers.go new file mode 100644 index 0000000000..75c0e80998 --- /dev/null +++ b/command/base_helpers.go @@ -0,0 +1,243 @@ +package command + +import ( + "fmt" + "io" + "strings" + "time" + + "github.com/hashicorp/vault/api" + kvbuilder "github.com/hashicorp/vault/helper/kv-builder" + "github.com/kr/text" + homedir "github.com/mitchellh/go-homedir" + "github.com/mitchellh/mapstructure" + "github.com/pkg/errors" + "github.com/ryanuber/columnize" +) + +// extractListData reads the secret and returns a typed list of data and a +// boolean indicating whether the extraction was successful. +func extractListData(secret *api.Secret) ([]interface{}, bool) { + if secret == nil || secret.Data == nil { + return nil, false + } + + k, ok := secret.Data["keys"] + if !ok || k == nil { + return nil, false + } + + i, ok := k.([]interface{}) + return i, ok +} + +// sanitizePath removes any leading or trailing things from a "path". +func sanitizePath(s string) string { + return ensureNoTrailingSlash(ensureNoLeadingSlash(strings.TrimSpace(s))) +} + +// ensureTrailingSlash ensures the given string has a trailing slash. +func ensureTrailingSlash(s string) string { + s = strings.TrimSpace(s) + if s == "" { + return "" + } + + for len(s) > 0 && s[len(s)-1] != '/' { + s = s + "/" + } + return s +} + +// ensureNoTrailingSlash ensures the given string has a trailing slash. +func ensureNoTrailingSlash(s string) string { + s = strings.TrimSpace(s) + if s == "" { + return "" + } + + for len(s) > 0 && s[len(s)-1] == '/' { + s = s[:len(s)-1] + } + return s +} + +// ensureNoLeadingSlash ensures the given string has a trailing slash. +func ensureNoLeadingSlash(s string) string { + s = strings.TrimSpace(s) + if s == "" { + return "" + } + + for len(s) > 0 && s[0] == '/' { + s = s[1:] + } + return s +} + +// columnOuput prints the list of items as a table with no headers. +func columnOutput(list []string, c *columnize.Config) string { + if len(list) == 0 { + return "" + } + + if c == nil { + c = &columnize.Config{} + } + if c.Glue == "" { + c.Glue = " " + } + if c.Empty == "" { + c.Empty = "n/a" + } + + return columnize.Format(list, c) +} + +// tableOutput prints the list of items as columns, where the first row is +// the list of headers. +func tableOutput(list []string, c *columnize.Config) string { + if len(list) == 0 { + return "" + } + + delim := "|" + if c != nil && c.Delim != "" { + delim = c.Delim + } + + underline := "" + headers := strings.Split(list[0], delim) + for i, h := range headers { + h = strings.TrimSpace(h) + u := strings.Repeat("-", len(h)) + + underline = underline + u + if i != len(headers)-1 { + underline = underline + delim + } + } + + list = append(list, "") + copy(list[2:], list[1:]) + list[1] = underline + + return columnOutput(list, c) +} + +// parseArgsData parses the given args in the format key=value into a map of +// the provided arguments. The given reader can also supply key=value pairs. +func parseArgsData(stdin io.Reader, args []string) (map[string]interface{}, error) { + builder := &kvbuilder.Builder{Stdin: stdin} + if err := builder.Add(args...); err != nil { + return nil, err + } + + return builder.Map(), nil +} + +// parseArgsDataString parses the args data and returns the values as strings. +// If the values cannot be represented as strings, an error is returned. +func parseArgsDataString(stdin io.Reader, args []string) (map[string]string, error) { + raw, err := parseArgsData(stdin, args) + if err != nil { + return nil, err + } + + var result map[string]string + if err := mapstructure.WeakDecode(raw, &result); err != nil { + return nil, errors.Wrap(err, "failed to convert values to strings") + } + return result, nil +} + +// truncateToSeconds truncates the given duaration to the number of seconds. If +// the duration is less than 1s, it is returned as 0. The integer represents +// the whole number unit of seconds for the duration. +func truncateToSeconds(d time.Duration) int { + d = d.Truncate(1 * time.Second) + + // Handle the case where someone requested a ridiculously short increment - + // incremenents must be larger than a second. + if d < 1*time.Second { + return 0 + } + + return int(d.Seconds()) +} + +// printKeyStatus prints the KeyStatus response from the API. +func printKeyStatus(ks *api.KeyStatus) string { + return columnOutput([]string{ + fmt.Sprintf("Key Term | %d", ks.Term), + fmt.Sprintf("Install Time | %s", ks.InstallTime.UTC().Format(time.RFC822)), + }, nil) +} + +// expandPath takes a filepath and returns the full expanded path, accounting +// for user-relative things like ~/. +func expandPath(s string) string { + if s == "" { + return "" + } + + e, err := homedir.Expand(s) + if err != nil { + return s + } + return e +} + +// wrapAtLengthWithPadding wraps the given text at the maxLineLength, taking +// into account any provided left padding. +func wrapAtLengthWithPadding(s string, pad int) string { + wrapped := text.Wrap(s, maxLineLength-pad) + lines := strings.Split(wrapped, "\n") + for i, line := range lines { + lines[i] = strings.Repeat(" ", pad) + line + } + return strings.Join(lines, "\n") +} + +// wrapAtLength wraps the given text to maxLineLength. +func wrapAtLength(s string) string { + return wrapAtLengthWithPadding(s, 0) +} + +// ttlToAPI converts a user-supplied ttl into an API-compatible string. If +// the TTL is 0, this returns the empty string. If the TTL is negative, this +// returns "system" to indicate to use the system values. Otherwise, the +// time.Duration ttl is used. +func ttlToAPI(d time.Duration) string { + if d == 0 { + return "" + } + + if d < 0 { + return "system" + } + + return d.String() +} + +// humanDuration prints the time duration without those pesky zeros. +func humanDuration(d time.Duration) string { + if d == 0 { + return "0s" + } + + s := d.String() + if strings.HasSuffix(s, "m0s") { + s = s[:len(s)-2] + } + if idx := strings.Index(s, "h0m"); idx > 0 { + s = s[:idx+1] + s[idx+3:] + } + return s +} + +// humanDurationInt prints the given int as if it were a time.Duration number +// of seconds. +func humanDurationInt(i int) string { + return humanDuration(time.Duration(i) * time.Second) +} diff --git a/command/base_helpers_test.go b/command/base_helpers_test.go new file mode 100644 index 0000000000..87c0bff695 --- /dev/null +++ b/command/base_helpers_test.go @@ -0,0 +1,162 @@ +package command + +import ( + "fmt" + "io" + "io/ioutil" + "os" + "testing" + "time" +) + +func TestParseArgsData(t *testing.T) { + t.Parallel() + + t.Run("stdin_full", func(t *testing.T) { + t.Parallel() + + stdinR, stdinW := io.Pipe() + go func() { + stdinW.Write([]byte(`{"foo":"bar"}`)) + stdinW.Close() + }() + + m, err := parseArgsData(stdinR, []string{"-"}) + if err != nil { + t.Fatal(err) + } + + if v, ok := m["foo"]; !ok || v != "bar" { + t.Errorf("expected %q to be %q", v, "bar") + } + }) + + t.Run("stdin_value", func(t *testing.T) { + t.Parallel() + + stdinR, stdinW := io.Pipe() + go func() { + stdinW.Write([]byte(`bar`)) + stdinW.Close() + }() + + m, err := parseArgsData(stdinR, []string{"foo=-"}) + if err != nil { + t.Fatal(err) + } + + if v, ok := m["foo"]; !ok || v != "bar" { + t.Errorf("expected %q to be %q", v, "bar") + } + }) + + t.Run("file_full", func(t *testing.T) { + t.Parallel() + + f, err := ioutil.TempFile("", "vault") + if err != nil { + t.Fatal(err) + } + f.Write([]byte(`{"foo":"bar"}`)) + f.Close() + defer os.Remove(f.Name()) + + m, err := parseArgsData(os.Stdin, []string{"@" + f.Name()}) + if err != nil { + t.Fatal(err) + } + + if v, ok := m["foo"]; !ok || v != "bar" { + t.Errorf("expected %q to be %q", v, "bar") + } + }) + + t.Run("file_value", func(t *testing.T) { + t.Parallel() + + f, err := ioutil.TempFile("", "vault") + if err != nil { + t.Fatal(err) + } + f.Write([]byte(`bar`)) + f.Close() + defer os.Remove(f.Name()) + + m, err := parseArgsData(os.Stdin, []string{"foo=@" + f.Name()}) + if err != nil { + t.Fatal(err) + } + + if v, ok := m["foo"]; !ok || v != "bar" { + t.Errorf("expected %q to be %q", v, "bar") + } + }) + + t.Run("file_value_escaped", func(t *testing.T) { + t.Parallel() + + m, err := parseArgsData(os.Stdin, []string{`foo=\@`}) + if err != nil { + t.Fatal(err) + } + + if v, ok := m["foo"]; !ok || v != "@" { + t.Errorf("expected %q to be %q", v, "@") + } + }) +} + +func TestTruncateToSeconds(t *testing.T) { + t.Parallel() + + cases := []struct { + d time.Duration + exp int + }{ + { + 10 * time.Nanosecond, + 0, + }, + { + 10 * time.Microsecond, + 0, + }, + { + 10 * time.Millisecond, + 0, + }, + { + 1 * time.Second, + 1, + }, + { + 10 * time.Second, + 10, + }, + { + 100 * time.Second, + 100, + }, + { + 3 * time.Minute, + 180, + }, + { + 3 * time.Hour, + 10800, + }, + } + + for _, tc := range cases { + tc := tc + + t.Run(fmt.Sprintf("%s", tc.d), func(t *testing.T) { + t.Parallel() + + act := truncateToSeconds(tc.d) + if act != tc.exp { + t.Errorf("expected %d to be %d", act, tc.exp) + } + }) + } +} diff --git a/command/base_predict.go b/command/base_predict.go new file mode 100644 index 0000000000..8f53c0444a --- /dev/null +++ b/command/base_predict.go @@ -0,0 +1,417 @@ +package command + +import ( + "sort" + "strings" + "sync" + + "github.com/hashicorp/vault/api" + "github.com/posener/complete" +) + +type Predict struct { + client *api.Client + clientOnce sync.Once +} + +func NewPredict() *Predict { + return &Predict{} +} + +func (p *Predict) Client() *api.Client { + p.clientOnce.Do(func() { + if p.client == nil { // For tests + client, _ := api.NewClient(nil) + + if client.Token() == "" { + helper, err := DefaultTokenHelper() + if err != nil { + return + } + token, err := helper.Get() + if err != nil { + return + } + client.SetToken(token) + } + + p.client = client + } + }) + return p.client +} + +// defaultPredictVaultMounts is the default list of mounts to return to the +// user. This is a best-guess, given we haven't communicated with the Vault +// server. If the user has no token or if the token does not have the default +// policy attached, it won't be able to read cubbyhole/, but it's a better UX +// that returning nothing. +var defaultPredictVaultMounts = []string{"cubbyhole/"} + +// predictClient is the API client to use for prediction. We create this at the +// beginning once, because completions are generated for each command (and this +// doesn't change), and the only way to configure the predict/autocomplete +// client is via environment variables. Even if the user specifies a flag, we +// can't parse that flag until after the command is submitted. +var predictClient *api.Client +var predictClientOnce sync.Once + +// PredictClient returns the cached API client for the predictor. +func PredictClient() *api.Client { + predictClientOnce.Do(func() { + if predictClient == nil { // For tests + predictClient, _ = api.NewClient(nil) + } + }) + return predictClient +} + +// PredictVaultAvailableMounts returns a predictor for the available mounts in +// Vault. For now, there is no way to programatically get this list. If, in the +// future, such a list exists, we can adapt it here. Until then, it's +// hard-coded. +func (b *BaseCommand) PredictVaultAvailableMounts() complete.Predictor { + // This list does not contain deprecated backends. At present, there is no + // API that lists all available secret backends, so this is hard-coded :(. + return complete.PredictSet( + "aws", + "consul", + "database", + "generic", + "pki", + "plugin", + "rabbitmq", + "ssh", + "totp", + "transit", + ) +} + +// PredictVaultAvailableAuths returns a predictor for the available auths in +// Vault. For now, there is no way to programatically get this list. If, in the +// future, such a list exists, we can adapt it here. Until then, it's +// hard-coded. +func (b *BaseCommand) PredictVaultAvailableAuths() complete.Predictor { + return complete.PredictSet( + "app-id", + "approle", + "aws", + "cert", + "gcp", + "github", + "ldap", + "okta", + "plugin", + "radius", + "userpass", + ) +} + +// PredictVaultFiles returns a predictor for Vault mounts and paths based on the +// configured client for the base command. Unfortunately this happens pre-flag +// parsing, so users must rely on environment variables for autocomplete if they +// are not using Vault at the default endpoints. +func (b *BaseCommand) PredictVaultFiles() complete.Predictor { + return NewPredict().VaultFiles() +} + +// PredictVaultFolders returns a predictor for "folders". See PredictVaultFiles +// for more information and restrictions. +func (b *BaseCommand) PredictVaultFolders() complete.Predictor { + return NewPredict().VaultFolders() +} + +// PredictVaultMounts returns a predictor for "folders". See PredictVaultFiles +// for more information and restrictions. +func (b *BaseCommand) PredictVaultMounts() complete.Predictor { + return NewPredict().VaultMounts() +} + +// PredictVaultAudits returns a predictor for "folders". See PredictVaultFiles +// for more information and restrictions. +func (b *BaseCommand) PredictVaultAudits() complete.Predictor { + return NewPredict().VaultAudits() +} + +// PredictVaultAuths returns a predictor for "folders". See PredictVaultFiles +// for more information and restrictions. +func (b *BaseCommand) PredictVaultAuths() complete.Predictor { + return NewPredict().VaultAuths() +} + +// PredictVaultPolicies returns a predictor for "folders". See PredictVaultFiles +// for more information and restrictions. +func (b *BaseCommand) PredictVaultPolicies() complete.Predictor { + return NewPredict().VaultPolicies() +} + +// VaultFiles returns a predictor for Vault "files". This is a public API for +// consumers, but you probably want BaseCommand.PredictVaultFiles instead. +func (p *Predict) VaultFiles() complete.Predictor { + return p.vaultPaths(true) +} + +// VaultFolders returns a predictor for Vault "folders". This is a public +// API for consumers, but you probably want BaseCommand.PredictVaultFolders +// instead. +func (p *Predict) VaultFolders() complete.Predictor { + return p.vaultPaths(false) +} + +// VaultMounts returns a predictor for Vault "folders". This is a public +// API for consumers, but you probably want BaseCommand.PredictVaultMounts +// instead. +func (p *Predict) VaultMounts() complete.Predictor { + return p.filterFunc(p.mounts) +} + +// VaultAudits returns a predictor for Vault "folders". This is a public API for +// consumers, but you probably want BaseCommand.PredictVaultAudits instead. +func (p *Predict) VaultAudits() complete.Predictor { + return p.filterFunc(p.audits) +} + +// VaultAuths returns a predictor for Vault "folders". This is a public API for +// consumers, but you probably want BaseCommand.PredictVaultAuths instead. +func (p *Predict) VaultAuths() complete.Predictor { + return p.filterFunc(p.auths) +} + +// VaultPolicies returns a predictor for Vault "folders". This is a public API for +// consumers, but you probably want BaseCommand.PredictVaultPolicies instead. +func (p *Predict) VaultPolicies() complete.Predictor { + return p.filterFunc(p.policies) +} + +// vaultPaths parses the CLI options and returns the "best" list of possible +// paths. If there are any errors, this function returns an empty result. All +// errors are suppressed since this is a prediction function. +func (p *Predict) vaultPaths(includeFiles bool) complete.PredictFunc { + return func(args complete.Args) []string { + // Do not predict more than one paths + if p.hasPathArg(args.All) { + return nil + } + + client := p.Client() + if client == nil { + return nil + } + + path := args.Last + + var predictions []string + if strings.Contains(path, "/") { + predictions = p.paths(path, includeFiles) + } else { + predictions = p.filter(p.mounts(), path) + } + + // Either no results or many results, so return. + if len(predictions) != 1 { + return predictions + } + + // If this is not a "folder", do not try to recurse. + if !strings.HasSuffix(predictions[0], "/") { + return predictions + } + + // If the prediction is the same as the last guess, return it (we have no + // new information and we won't get anymore). + if predictions[0] == args.Last { + return predictions + } + + // Re-predict with the remaining path + args.Last = predictions[0] + return p.vaultPaths(includeFiles).Predict(args) + } +} + +// paths predicts all paths which start with the given path. +func (p *Predict) paths(path string, includeFiles bool) []string { + client := p.Client() + if client == nil { + return nil + } + + // Vault does not support listing based on a sub-key, so we have to back-pedal + // to the last "/" and return all paths on that "folder". Then we perform + // client-side filtering. + root := path + idx := strings.LastIndex(root, "/") + if idx > 0 && idx < len(root) { + root = root[:idx+1] + } + + paths := p.listPaths(root) + + var predictions []string + for _, p := range paths { + // Calculate the absolute "path" for matching. + p = root + p + + if strings.HasPrefix(p, path) { + // Ensure this is a directory or we've asked to include files. + if includeFiles || strings.HasSuffix(p, "/") { + predictions = append(predictions, p) + } + } + } + + // Add root to the path + if len(predictions) == 0 { + predictions = append(predictions, path) + } + + return predictions +} + +// audits returns a sorted list of the audit backends for Vault server for +// which the client is configured to communicate with. +func (p *Predict) audits() []string { + client := p.Client() + if client == nil { + return nil + } + + audits, err := client.Sys().ListAudit() + if err != nil { + return nil + } + + list := make([]string, 0, len(audits)) + for m := range audits { + list = append(list, m) + } + sort.Strings(list) + return list +} + +// auths returns a sorted list of the enabled auth provides for Vault server for +// which the client is configured to communicate with. +func (p *Predict) auths() []string { + client := p.Client() + if client == nil { + return nil + } + + auths, err := client.Sys().ListAuth() + if err != nil { + return nil + } + + list := make([]string, 0, len(auths)) + for m := range auths { + list = append(list, m) + } + sort.Strings(list) + return list +} + +// policies returns a sorted list of the policies stored in this Vault +// server. +func (p *Predict) policies() []string { + client := p.Client() + if client == nil { + return nil + } + + policies, err := client.Sys().ListPolicies() + if err != nil { + return nil + } + sort.Strings(policies) + return policies +} + +// mounts returns a sorted list of the mount paths for Vault server for +// which the client is configured to communicate with. This function returns the +// default list of mounts if an error occurs. +func (p *Predict) mounts() []string { + client := p.Client() + if client == nil { + return nil + } + + mounts, err := client.Sys().ListMounts() + if err != nil { + return defaultPredictVaultMounts + } + + list := make([]string, 0, len(mounts)) + for m := range mounts { + list = append(list, m) + } + sort.Strings(list) + return list +} + +// listPaths returns a list of paths (HTTP LIST) for the given path. This +// function returns an empty list of any errors occur. +func (p *Predict) listPaths(path string) []string { + client := p.Client() + if client == nil { + return nil + } + + secret, err := client.Logical().List(path) + if err != nil || secret == nil || secret.Data == nil { + return nil + } + + paths, ok := secret.Data["keys"].([]interface{}) + if !ok { + return nil + } + + list := make([]string, 0, len(paths)) + for _, p := range paths { + if str, ok := p.(string); ok { + list = append(list, str) + } + } + sort.Strings(list) + return list +} + +// hasPathArg determines if the args have already accepted a path. +func (p *Predict) hasPathArg(args []string) bool { + var nonFlags []string + for _, a := range args { + if !strings.HasPrefix(a, "-") { + nonFlags = append(nonFlags, a) + } + } + + return len(nonFlags) > 2 +} + +// filterFunc is used to compose a complete predictor that filters an array +// of strings as per the filter function. +func (p *Predict) filterFunc(f func() []string) complete.Predictor { + return complete.PredictFunc(func(args complete.Args) []string { + if p.hasPathArg(args.All) { + return nil + } + + client := p.Client() + if client == nil { + return nil + } + + return p.filter(f(), args.Last) + }) +} + +// filter filters the given list for items that start with the prefix. +func (p *Predict) filter(list []string, prefix string) []string { + var predictions []string + for _, item := range list { + if strings.HasPrefix(item, prefix) { + predictions = append(predictions, item) + } + } + return predictions +} diff --git a/command/base_predict_test.go b/command/base_predict_test.go new file mode 100644 index 0000000000..93e291e502 --- /dev/null +++ b/command/base_predict_test.go @@ -0,0 +1,526 @@ +package command + +import ( + "reflect" + "testing" + + "github.com/hashicorp/vault/api" + "github.com/posener/complete" +) + +func TestPredictVaultPaths(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + data := map[string]interface{}{"a": "b"} + if _, err := client.Logical().Write("secret/bar", data); err != nil { + t.Fatal(err) + } + if _, err := client.Logical().Write("secret/foo", data); err != nil { + t.Fatal(err) + } + if _, err := client.Logical().Write("secret/zip/zap", data); err != nil { + t.Fatal(err) + } + if _, err := client.Logical().Write("secret/zip/zonk", data); err != nil { + t.Fatal(err) + } + if _, err := client.Logical().Write("secret/zip/twoot", data); err != nil { + t.Fatal(err) + } + + cases := []struct { + name string + args complete.Args + includeFiles bool + exp []string + }{ + { + "has_args", + complete.Args{ + All: []string{"read", "secret/foo", "a=b"}, + Last: "a=b", + }, + true, + nil, + }, + { + "has_args_no_files", + complete.Args{ + All: []string{"read", "secret/foo", "a=b"}, + Last: "a=b", + }, + false, + nil, + }, + { + "part_mount", + complete.Args{ + All: []string{"read", "s"}, + Last: "s", + }, + true, + []string{"secret/", "sys/"}, + }, + { + "part_mount_no_files", + complete.Args{ + All: []string{"read", "s"}, + Last: "s", + }, + false, + []string{"secret/", "sys/"}, + }, + { + "only_mount", + complete.Args{ + All: []string{"read", "sec"}, + Last: "sec", + }, + true, + []string{"secret/bar", "secret/foo", "secret/zip/"}, + }, + { + "only_mount_no_files", + complete.Args{ + All: []string{"read", "sec"}, + Last: "sec", + }, + false, + []string{"secret/zip/"}, + }, + { + "full_mount", + complete.Args{ + All: []string{"read", "secret"}, + Last: "secret", + }, + true, + []string{"secret/bar", "secret/foo", "secret/zip/"}, + }, + { + "full_mount_no_files", + complete.Args{ + All: []string{"read", "secret"}, + Last: "secret", + }, + false, + []string{"secret/zip/"}, + }, + { + "full_mount_slash", + complete.Args{ + All: []string{"read", "secret/"}, + Last: "secret/", + }, + true, + []string{"secret/bar", "secret/foo", "secret/zip/"}, + }, + { + "full_mount_slash_no_files", + complete.Args{ + All: []string{"read", "secret/"}, + Last: "secret/", + }, + false, + []string{"secret/zip/"}, + }, + { + "path_partial", + complete.Args{ + All: []string{"read", "secret/z"}, + Last: "secret/z", + }, + true, + []string{"secret/zip/twoot", "secret/zip/zap", "secret/zip/zonk"}, + }, + { + "path_partial_no_files", + complete.Args{ + All: []string{"read", "secret/z"}, + Last: "secret/z", + }, + false, + []string{"secret/zip/"}, + }, + { + "subpath_partial_z", + complete.Args{ + All: []string{"read", "secret/zip/z"}, + Last: "secret/zip/z", + }, + true, + []string{"secret/zip/zap", "secret/zip/zonk"}, + }, + { + "subpath_partial_z_no_files", + complete.Args{ + All: []string{"read", "secret/zip/z"}, + Last: "secret/zip/z", + }, + false, + []string{"secret/zip/z"}, + }, + { + "subpath_partial_t", + complete.Args{ + All: []string{"read", "secret/zip/t"}, + Last: "secret/zip/t", + }, + true, + []string{"secret/zip/twoot"}, + }, + { + "subpath_partial_t_no_files", + complete.Args{ + All: []string{"read", "secret/zip/t"}, + Last: "secret/zip/t", + }, + false, + []string{"secret/zip/t"}, + }, + } + + t.Run("group", func(t *testing.T) { + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + p := NewPredict() + p.client = client + + f := p.vaultPaths(tc.includeFiles) + act := f(tc.args) + if !reflect.DeepEqual(act, tc.exp) { + t.Errorf("expected %q to be %q", act, tc.exp) + } + }) + } + }) +} + +func TestPredict_Audits(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + badClient, badCloser := testVaultServerBad(t) + defer badCloser() + + if err := client.Sys().EnableAuditWithOptions("file", &api.EnableAuditOptions{ + Type: "file", + Options: map[string]string{ + "file_path": "discard", + }, + }); err != nil { + t.Fatal(err) + } + + cases := []struct { + name string + client *api.Client + exp []string + }{ + { + "not_connected_client", + badClient, + nil, + }, + { + "good_path", + client, + []string{"file/"}, + }, + } + + t.Run("group", func(t *testing.T) { + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + p := NewPredict() + p.client = tc.client + + act := p.audits() + if !reflect.DeepEqual(act, tc.exp) { + t.Errorf("expected %q to be %q", act, tc.exp) + } + }) + } + }) +} + +func TestPredict_Mounts(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + badClient, badCloser := testVaultServerBad(t) + defer badCloser() + + cases := []struct { + name string + client *api.Client + exp []string + }{ + { + "not_connected_client", + badClient, + defaultPredictVaultMounts, + }, + { + "good_path", + client, + []string{"cubbyhole/", "identity/", "secret/", "sys/"}, + }, + } + + t.Run("group", func(t *testing.T) { + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + p := NewPredict() + p.client = tc.client + + act := p.mounts() + if !reflect.DeepEqual(act, tc.exp) { + t.Errorf("expected %q to be %q", act, tc.exp) + } + }) + } + }) +} + +func TestPredict_Policies(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + badClient, badCloser := testVaultServerBad(t) + defer badCloser() + + cases := []struct { + name string + client *api.Client + exp []string + }{ + { + "not_connected_client", + badClient, + nil, + }, + { + "good_path", + client, + []string{"default", "root"}, + }, + } + + t.Run("group", func(t *testing.T) { + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + p := NewPredict() + p.client = tc.client + + act := p.policies() + if !reflect.DeepEqual(act, tc.exp) { + t.Errorf("expected %q to be %q", act, tc.exp) + } + }) + } + }) +} + +func TestPredict_Paths(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + data := map[string]interface{}{"a": "b"} + if _, err := client.Logical().Write("secret/bar", data); err != nil { + t.Fatal(err) + } + if _, err := client.Logical().Write("secret/foo", data); err != nil { + t.Fatal(err) + } + if _, err := client.Logical().Write("secret/zip/zap", data); err != nil { + t.Fatal(err) + } + + cases := []struct { + name string + path string + includeFiles bool + exp []string + }{ + { + "bad_path", + "nope/not/a/real/path/ever", + true, + []string{"nope/not/a/real/path/ever"}, + }, + { + "good_path", + "secret/", + true, + []string{"secret/bar", "secret/foo", "secret/zip/"}, + }, + { + "good_path_no_files", + "secret/", + false, + []string{"secret/zip/"}, + }, + { + "partial_match", + "secret/z", + true, + []string{"secret/zip/"}, + }, + { + "partial_match_no_files", + "secret/z", + false, + []string{"secret/zip/"}, + }, + } + + t.Run("group", func(t *testing.T) { + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + p := NewPredict() + p.client = client + + act := p.paths(tc.path, tc.includeFiles) + if !reflect.DeepEqual(act, tc.exp) { + t.Errorf("expected %q to be %q", act, tc.exp) + } + }) + } + }) +} + +func TestPredict_ListPaths(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + badClient, badCloser := testVaultServerBad(t) + defer badCloser() + + data := map[string]interface{}{"a": "b"} + if _, err := client.Logical().Write("secret/bar", data); err != nil { + t.Fatal(err) + } + if _, err := client.Logical().Write("secret/foo", data); err != nil { + t.Fatal(err) + } + + cases := []struct { + name string + client *api.Client + path string + exp []string + }{ + { + "bad_path", + client, + "nope/not/a/real/path/ever", + nil, + }, + { + "good_path", + client, + "secret/", + []string{"bar", "foo"}, + }, + { + "not_connected_client", + badClient, + "secret/", + nil, + }, + } + + t.Run("group", func(t *testing.T) { + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + p := NewPredict() + p.client = tc.client + + act := p.listPaths(tc.path) + if !reflect.DeepEqual(act, tc.exp) { + t.Errorf("expected %q to be %q", act, tc.exp) + } + }) + } + }) +} + +func TestPredict_HasPathArg(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + args []string + exp bool + }{ + { + "nil", + nil, + false, + }, + { + "empty", + []string{}, + false, + }, + { + "empty_string", + []string{""}, + false, + }, + { + "single", + []string{"foo"}, + false, + }, + { + "multiple", + []string{"foo", "bar", "baz"}, + true, + }, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + p := NewPredict() + if act := p.hasPathArg(tc.args); act != tc.exp { + t.Errorf("expected %t to be %t", act, tc.exp) + } + }) + } +} diff --git a/command/capabilities.go b/command/capabilities.go deleted file mode 100644 index bb60bd4ea8..0000000000 --- a/command/capabilities.go +++ /dev/null @@ -1,87 +0,0 @@ -package command - -import ( - "fmt" - "strings" - - "github.com/hashicorp/vault/meta" -) - -// CapabilitiesCommand is a Command that enables a new endpoint. -type CapabilitiesCommand struct { - meta.Meta -} - -func (c *CapabilitiesCommand) Run(args []string) int { - flags := c.Meta.FlagSet("capabilities", meta.FlagSetDefault) - flags.Usage = func() { c.Ui.Error(c.Help()) } - if err := flags.Parse(args); err != nil { - return 1 - } - - args = flags.Args() - if len(args) > 2 { - flags.Usage() - c.Ui.Error(fmt.Sprintf( - "\ncapabilities expects at most two arguments")) - return 1 - } - - var token string - var path string - switch { - case len(args) == 1: - path = args[0] - case len(args) == 2: - token = args[0] - path = args[1] - default: - flags.Usage() - c.Ui.Error(fmt.Sprintf("\ncapabilities expects at least one argument")) - return 1 - } - - client, err := c.Client() - if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error initializing client: %s", err)) - return 2 - } - - var capabilities []string - if token == "" { - capabilities, err = client.Sys().CapabilitiesSelf(path) - } else { - capabilities, err = client.Sys().Capabilities(token, path) - } - if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error retrieving capabilities: %s", err)) - return 1 - } - - c.Ui.Output(fmt.Sprintf("Capabilities: %s", capabilities)) - return 0 -} - -func (c *CapabilitiesCommand) Synopsis() string { - return "Fetch the capabilities of a token on a given path" -} - -func (c *CapabilitiesCommand) Help() string { - helpText := ` -Usage: vault capabilities [options] [token] path - - Fetch the capabilities of a token on a given path. - If a token is provided as an argument, the '/sys/capabilities' endpoint will be invoked - with the given token; otherwise the '/sys/capabilities-self' endpoint will be invoked - with the client token. - - If a token does not have any capability on a given path, or if any of the policies - belonging to the token explicitly have ["deny"] capability, or if the argument path - is invalid, this command will respond with a ["deny"]. - -General Options: -` + meta.GeneralOptionsUsage() - return strings.TrimSpace(helpText) -} diff --git a/command/capabilities_test.go b/command/capabilities_test.go deleted file mode 100644 index 5d106a14e9..0000000000 --- a/command/capabilities_test.go +++ /dev/null @@ -1,45 +0,0 @@ -package command - -import ( - "testing" - - "github.com/hashicorp/vault/http" - "github.com/hashicorp/vault/meta" - "github.com/hashicorp/vault/vault" - "github.com/mitchellh/cli" -) - -func TestCapabilities_Basic(t *testing.T) { - core, _, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() - ui := new(cli.MockUi) - c := &CapabilitiesCommand{ - Meta: meta.Meta{ - ClientToken: token, - Ui: ui, - }, - } - - var args []string - - args = []string{"-address", addr} - if code := c.Run(args); code == 0 { - t.Fatalf("expected failure due to no args") - } - - args = []string{"-address", addr, "testpath"} - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } - - args = []string{"-address", addr, token, "test"} - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } - - args = []string{"-address", addr, "invalidtoken", "test"} - if code := c.Run(args); code == 0 { - t.Fatalf("expected failure due to invalid token") - } -} diff --git a/command/command_test.go b/command/command_test.go index 763587a05f..0ff084f45b 100644 --- a/command/command_test.go +++ b/command/command_test.go @@ -1,17 +1,212 @@ package command import ( + "context" + "encoding/base64" + "net" + "net/http" + "strings" "testing" + "time" "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/audit" + "github.com/hashicorp/vault/builtin/logical/pki" + "github.com/hashicorp/vault/builtin/logical/ssh" + "github.com/hashicorp/vault/builtin/logical/transit" + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/physical/inmem" + "github.com/hashicorp/vault/vault" + "github.com/mitchellh/cli" + + auditFile "github.com/hashicorp/vault/builtin/audit/file" + credUserpass "github.com/hashicorp/vault/builtin/credential/userpass" + vaulthttp "github.com/hashicorp/vault/http" + logxi "github.com/mgutz/logxi/v1" ) -func testClient(t *testing.T, addr string, token string) *api.Client { +var ( + defaultVaultLogger = logxi.NullLog + + defaultVaultCredentialBackends = map[string]logical.Factory{ + "userpass": credUserpass.Factory, + } + + defaultVaultAuditBackends = map[string]audit.Factory{ + "file": auditFile.Factory, + } + + defaultVaultLogicalBackends = map[string]logical.Factory{ + "generic-leased": vault.LeasedPassthroughBackendFactory, + "pki": pki.Factory, + "ssh": ssh.Factory, + "transit": transit.Factory, + } +) + +// assertNoTabs asserts the CLI help has no tab characters. +func assertNoTabs(tb testing.TB, c cli.Command) { + tb.Helper() + + if strings.ContainsRune(c.Help(), '\t') { + tb.Errorf("%#v help output contains tabs", c) + } +} + +// testVaultServer creates a test vault cluster and returns a configured API +// client and closer function. +func testVaultServer(tb testing.TB) (*api.Client, func()) { + tb.Helper() + + client, _, closer := testVaultServerUnseal(tb) + return client, closer +} + +// testVaultServerUnseal creates a test vault cluster and returns a configured +// API client, list of unseal keys (as strings), and a closer function. +func testVaultServerUnseal(tb testing.TB) (*api.Client, []string, func()) { + tb.Helper() + + return testVaultServerCoreConfig(tb, &vault.CoreConfig{ + DisableMlock: true, + DisableCache: true, + Logger: defaultVaultLogger, + CredentialBackends: defaultVaultCredentialBackends, + AuditBackends: defaultVaultAuditBackends, + LogicalBackends: defaultVaultLogicalBackends, + }) +} + +// testVaultServerCoreConfig creates a new vault cluster with the given core +// configuration. This is a lower-level test helper. +func testVaultServerCoreConfig(tb testing.TB, coreConfig *vault.CoreConfig) (*api.Client, []string, func()) { + tb.Helper() + + cluster := vault.NewTestCluster(tb, coreConfig, &vault.TestClusterOptions{ + HandlerFunc: vaulthttp.Handler, + NumCores: 1, // Default is 3, but we don't need that many + }) + cluster.Start() + + // Make it easy to get access to the active + core := cluster.Cores[0].Core + vault.TestWaitActive(tb, core) + + // Get the client already setup for us! + client := cluster.Cores[0].Client + client.SetToken(cluster.RootToken) + + // Convert the unseal keys to base64 encoded, since these are how the user + // will get them. + unsealKeys := make([]string, len(cluster.BarrierKeys)) + for i := range unsealKeys { + unsealKeys[i] = base64.StdEncoding.EncodeToString(cluster.BarrierKeys[i]) + } + + return client, unsealKeys, func() { defer cluster.Cleanup() } +} + +// testVaultServerUninit creates an uninitialized server. +func testVaultServerUninit(tb testing.TB) (*api.Client, func()) { + tb.Helper() + + inm, err := inmem.NewInmem(nil, defaultVaultLogger) + if err != nil { + tb.Fatal(err) + } + + core, err := vault.NewCore(&vault.CoreConfig{ + DisableMlock: true, + DisableCache: true, + Logger: defaultVaultLogger, + Physical: inm, + CredentialBackends: defaultVaultCredentialBackends, + AuditBackends: defaultVaultAuditBackends, + LogicalBackends: defaultVaultLogicalBackends, + }) + if err != nil { + tb.Fatal(err) + } + + ln, addr := vaulthttp.TestServer(tb, core) + + client, err := api.NewClient(&api.Config{ + Address: addr, + }) + if err != nil { + tb.Fatal(err) + } + + return client, func() { ln.Close() } +} + +// testVaultServerBad creates an http server that returns a 500 on each request +// to simulate failures. +func testVaultServerBad(tb testing.TB) (*api.Client, func()) { + tb.Helper() + + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + tb.Fatal(err) + } + + server := &http.Server{ + Addr: "127.0.0.1:0", + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "500 internal server error", http.StatusInternalServerError) + }), + ReadTimeout: 1 * time.Second, + ReadHeaderTimeout: 1 * time.Second, + WriteTimeout: 1 * time.Second, + IdleTimeout: 1 * time.Second, + } + + go func() { + if err := server.Serve(listener); err != nil && err != http.ErrServerClosed { + tb.Fatal(err) + } + }() + + client, err := api.NewClient(&api.Config{ + Address: "http://" + listener.Addr().String(), + }) + if err != nil { + tb.Fatal(err) + } + + return client, func() { + ctx, done := context.WithTimeout(context.Background(), 5*time.Second) + defer done() + + server.Shutdown(ctx) + } +} + +// testTokenAndAccessor creates a new authentication token capable of being renewed with +// the default policy attached. It returns the token and it's accessor. +func testTokenAndAccessor(tb testing.TB, client *api.Client) (string, string) { + tb.Helper() + + secret, err := client.Auth().Token().Create(&api.TokenCreateRequest{ + Policies: []string{"default"}, + TTL: "30m", + }) + if err != nil { + tb.Fatal(err) + } + if secret == nil || secret.Auth == nil || secret.Auth.ClientToken == "" { + tb.Fatalf("missing auth data: %#v", secret) + } + return secret.Auth.ClientToken, secret.Auth.Accessor +} + +func testClient(tb testing.TB, addr string, token string) *api.Client { + tb.Helper() config := api.DefaultConfig() config.Address = addr client, err := api.NewClient(config) if err != nil { - t.Fatalf("err: %s", err) + tb.Fatal(err) } client.SetToken(token) diff --git a/command/commands.go b/command/commands.go new file mode 100644 index 0000000000..465dae748e --- /dev/null +++ b/command/commands.go @@ -0,0 +1,960 @@ +package command + +import ( + "fmt" + "os" + "os/signal" + "syscall" + + "github.com/hashicorp/vault/audit" + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/physical" + "github.com/hashicorp/vault/version" + "github.com/mitchellh/cli" + + "github.com/hashicorp/vault/builtin/logical/aws" + "github.com/hashicorp/vault/builtin/logical/cassandra" + "github.com/hashicorp/vault/builtin/logical/consul" + "github.com/hashicorp/vault/builtin/logical/database" + "github.com/hashicorp/vault/builtin/logical/mongodb" + "github.com/hashicorp/vault/builtin/logical/mssql" + "github.com/hashicorp/vault/builtin/logical/mysql" + "github.com/hashicorp/vault/builtin/logical/pki" + "github.com/hashicorp/vault/builtin/logical/postgresql" + "github.com/hashicorp/vault/builtin/logical/rabbitmq" + "github.com/hashicorp/vault/builtin/logical/ssh" + "github.com/hashicorp/vault/builtin/logical/totp" + "github.com/hashicorp/vault/builtin/logical/transit" + "github.com/hashicorp/vault/builtin/plugin" + + auditFile "github.com/hashicorp/vault/builtin/audit/file" + auditSocket "github.com/hashicorp/vault/builtin/audit/socket" + auditSyslog "github.com/hashicorp/vault/builtin/audit/syslog" + + credGcp "github.com/hashicorp/vault-plugin-auth-gcp/plugin" + credKube "github.com/hashicorp/vault-plugin-auth-kubernetes" + credAppId "github.com/hashicorp/vault/builtin/credential/app-id" + credAppRole "github.com/hashicorp/vault/builtin/credential/approle" + credAws "github.com/hashicorp/vault/builtin/credential/aws" + credCert "github.com/hashicorp/vault/builtin/credential/cert" + credGitHub "github.com/hashicorp/vault/builtin/credential/github" + credLdap "github.com/hashicorp/vault/builtin/credential/ldap" + credOkta "github.com/hashicorp/vault/builtin/credential/okta" + credRadius "github.com/hashicorp/vault/builtin/credential/radius" + credToken "github.com/hashicorp/vault/builtin/credential/token" + credUserpass "github.com/hashicorp/vault/builtin/credential/userpass" + + physAzure "github.com/hashicorp/vault/physical/azure" + physCassandra "github.com/hashicorp/vault/physical/cassandra" + physCockroachDB "github.com/hashicorp/vault/physical/cockroachdb" + physConsul "github.com/hashicorp/vault/physical/consul" + physCouchDB "github.com/hashicorp/vault/physical/couchdb" + physDynamoDB "github.com/hashicorp/vault/physical/dynamodb" + physEtcd "github.com/hashicorp/vault/physical/etcd" + physFile "github.com/hashicorp/vault/physical/file" + physGCS "github.com/hashicorp/vault/physical/gcs" + physInmem "github.com/hashicorp/vault/physical/inmem" + physMSSQL "github.com/hashicorp/vault/physical/mssql" + physMySQL "github.com/hashicorp/vault/physical/mysql" + physPostgreSQL "github.com/hashicorp/vault/physical/postgresql" + physS3 "github.com/hashicorp/vault/physical/s3" + physSwift "github.com/hashicorp/vault/physical/swift" + physZooKeeper "github.com/hashicorp/vault/physical/zookeeper" +) + +// DeprecatedCommand is a command that wraps an existing command and prints a +// deprecation notice and points the user to the new command. Deprecated +// commands are always hidden from help output. +type DeprecatedCommand struct { + cli.Command + UI cli.Ui + + // Old is the old command name, New is the new command name. + Old, New string +} + +// Help wraps the embedded Help command and prints a warning about deprecations. +func (c *DeprecatedCommand) Help() string { + c.warn() + return c.Command.Help() +} + +// Run wraps the embedded Run command and prints a warning about deprecation. +func (c *DeprecatedCommand) Run(args []string) int { + c.warn() + return c.Command.Run(args) +} + +func (c *DeprecatedCommand) warn() { + c.UI.Warn(wrapAtLength(fmt.Sprintf( + "WARNING! The \"vault %s\" command is deprecated. Please use \"vault %s\" "+ + "instead. This command will be removed in the next major release of "+ + "Vault.", + c.Old, + c.New))) + c.UI.Warn("") +} + +// Commands is the mapping of all the available commands. +var Commands map[string]cli.CommandFactory +var DeprecatedCommands map[string]cli.CommandFactory + +func init() { + ui := &cli.ColoredUi{ + ErrorColor: cli.UiColorRed, + WarnColor: cli.UiColorYellow, + Ui: &cli.BasicUi{ + Writer: os.Stdout, + ErrorWriter: os.Stderr, + }, + } + + loginHandlers := map[string]LoginHandler{ + "aws": &credAws.CLIHandler{}, + "cert": &credCert.CLIHandler{}, + "github": &credGitHub.CLIHandler{}, + "ldap": &credLdap.CLIHandler{}, + "okta": &credOkta.CLIHandler{}, + "radius": &credUserpass.CLIHandler{ + DefaultMount: "radius", + }, + "token": &credToken.CLIHandler{}, + "userpass": &credUserpass.CLIHandler{ + DefaultMount: "userpass", + }, + } + + Commands = map[string]cli.CommandFactory{ + "audit": func() (cli.Command, error) { + return &AuditCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, nil + }, + "audit disable": func() (cli.Command, error) { + return &AuditDisableCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, nil + }, + "audit enable": func() (cli.Command, error) { + return &AuditEnableCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, nil + }, + "audit list": func() (cli.Command, error) { + return &AuditListCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, nil + }, + "auth tune": func() (cli.Command, error) { + return &AuthTuneCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, nil + }, + "auth": func() (cli.Command, error) { + return &AuthCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + Handlers: loginHandlers, + }, nil + }, + "auth disable": func() (cli.Command, error) { + return &AuthDisableCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, nil + }, + "auth enable": func() (cli.Command, error) { + return &AuthEnableCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, nil + }, + "auth help": func() (cli.Command, error) { + return &AuthHelpCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + Handlers: loginHandlers, + }, nil + }, + "auth list": func() (cli.Command, error) { + return &AuthListCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, nil + }, + "delete": func() (cli.Command, error) { + return &DeleteCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, nil + }, + "lease": func() (cli.Command, error) { + return &LeaseCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, nil + }, + "lease renew": func() (cli.Command, error) { + return &LeaseRenewCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, nil + }, + "lease revoke": func() (cli.Command, error) { + return &LeaseRevokeCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, nil + }, + "list": func() (cli.Command, error) { + return &ListCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, nil + }, + "login": func() (cli.Command, error) { + return &LoginCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + Handlers: loginHandlers, + }, nil + }, + "operator": func() (cli.Command, error) { + return &OperatorCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, nil + }, + "operator generate-root": func() (cli.Command, error) { + return &OperatorGenerateRootCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, nil + }, + "operator init": func() (cli.Command, error) { + return &OperatorInitCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, nil + }, + "operator key-status": func() (cli.Command, error) { + return &OperatorKeyStatusCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, nil + }, + "operator rekey": func() (cli.Command, error) { + return &OperatorRekeyCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, nil + }, + "operator rotate": func() (cli.Command, error) { + return &OperatorRotateCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, nil + }, + "operator seal": func() (cli.Command, error) { + return &OperatorSealCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, nil + }, + "operator step-down": func() (cli.Command, error) { + return &OperatorStepDownCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, nil + }, + "operator unseal": func() (cli.Command, error) { + return &OperatorUnsealCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, nil + }, + "path-help": func() (cli.Command, error) { + return &PathHelpCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, nil + }, + "policy": func() (cli.Command, error) { + return &PolicyCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, nil + }, + "policy delete": func() (cli.Command, error) { + return &PolicyDeleteCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, nil + }, + "policy fmt": func() (cli.Command, error) { + return &PolicyFmtCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, nil + }, + "policy list": func() (cli.Command, error) { + return &PolicyListCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, nil + }, + "policy read": func() (cli.Command, error) { + return &PolicyReadCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, nil + }, + "policy write": func() (cli.Command, error) { + return &PolicyWriteCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, nil + }, + "read": func() (cli.Command, error) { + return &ReadCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, nil + }, + "secrets": func() (cli.Command, error) { + return &SecretsCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, nil + }, + "secrets disable": func() (cli.Command, error) { + return &SecretsDisableCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, nil + }, + "secrets enable": func() (cli.Command, error) { + return &SecretsEnableCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, nil + }, + "secrets list": func() (cli.Command, error) { + return &SecretsListCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, nil + }, + "secrets move": func() (cli.Command, error) { + return &SecretsMoveCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, nil + }, + "secrets tune": func() (cli.Command, error) { + return &SecretsTuneCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, nil + }, + "server": func() (cli.Command, error) { + return &ServerCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + AuditBackends: map[string]audit.Factory{ + "file": auditFile.Factory, + "socket": auditSocket.Factory, + "syslog": auditSyslog.Factory, + }, + CredentialBackends: map[string]logical.Factory{ + "app-id": credAppId.Factory, + "approle": credAppRole.Factory, + "aws": credAws.Factory, + "cert": credCert.Factory, + "gcp": credGcp.Factory, + "github": credGitHub.Factory, + "kubernetes": credKube.Factory, + "ldap": credLdap.Factory, + "okta": credOkta.Factory, + "plugin": plugin.Factory, + "radius": credRadius.Factory, + "userpass": credUserpass.Factory, + }, + LogicalBackends: map[string]logical.Factory{ + "aws": aws.Factory, + "cassandra": cassandra.Factory, + "consul": consul.Factory, + "database": database.Factory, + "mongodb": mongodb.Factory, + "mssql": mssql.Factory, + "mysql": mysql.Factory, + "pki": pki.Factory, + "plugin": plugin.Factory, + "postgresql": postgresql.Factory, + "rabbitmq": rabbitmq.Factory, + "ssh": ssh.Factory, + "totp": totp.Factory, + "transit": transit.Factory, + }, + PhysicalBackends: map[string]physical.Factory{ + "azure": physAzure.NewAzureBackend, + "cassandra": physCassandra.NewCassandraBackend, + "cockroachdb": physCockroachDB.NewCockroachDBBackend, + "consul": physConsul.NewConsulBackend, + "couchdb_transactional": physCouchDB.NewTransactionalCouchDBBackend, + "couchdb": physCouchDB.NewCouchDBBackend, + "dynamodb": physDynamoDB.NewDynamoDBBackend, + "etcd": physEtcd.NewEtcdBackend, + "file_transactional": physFile.NewTransactionalFileBackend, + "file": physFile.NewFileBackend, + "gcs": physGCS.NewGCSBackend, + "inmem_ha": physInmem.NewInmemHA, + "inmem_transactional_ha": physInmem.NewTransactionalInmemHA, + "inmem_transactional": physInmem.NewTransactionalInmem, + "inmem": physInmem.NewInmem, + "mssql": physMSSQL.NewMSSQLBackend, + "mysql": physMySQL.NewMySQLBackend, + "postgresql": physPostgreSQL.NewPostgreSQLBackend, + "s3": physS3.NewS3Backend, + "swift": physSwift.NewSwiftBackend, + "zookeeper": physZooKeeper.NewZooKeeperBackend, + }, + ShutdownCh: MakeShutdownCh(), + SighupCh: MakeSighupCh(), + }, nil + }, + "ssh": func() (cli.Command, error) { + return &SSHCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, nil + }, + "status": func() (cli.Command, error) { + return &StatusCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, nil + }, + "token": func() (cli.Command, error) { + return &TokenCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, nil + }, + "token create": func() (cli.Command, error) { + return &TokenCreateCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, nil + }, + "token capabilities": func() (cli.Command, error) { + return &TokenCapabilitiesCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, nil + }, + "token lookup": func() (cli.Command, error) { + return &TokenLookupCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, nil + }, + "token renew": func() (cli.Command, error) { + return &TokenRenewCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, nil + }, + "token revoke": func() (cli.Command, error) { + return &TokenRevokeCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, nil + }, + "unwrap": func() (cli.Command, error) { + return &UnwrapCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, nil + }, + "version": func() (cli.Command, error) { + return &VersionCommand{ + VersionInfo: version.GetVersion(), + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, nil + }, + "write": func() (cli.Command, error) { + return &WriteCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, nil + }, + } + + // Deprecated commands + // + // TODO: Remove in 0.9.0 + DeprecatedCommands = map[string]cli.CommandFactory{ + "audit-disable": func() (cli.Command, error) { + return &DeprecatedCommand{ + Old: "audit-disable", + New: "audit disable", + UI: ui, + Command: &AuditDisableCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, + }, nil + }, + + "audit-enable": func() (cli.Command, error) { + return &DeprecatedCommand{ + Old: "audit-enable", + New: "audit enable", + UI: ui, + Command: &AuditEnableCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, + }, nil + }, + + "audit-list": func() (cli.Command, error) { + return &DeprecatedCommand{ + Old: "audit-list", + New: "audit list", + UI: ui, + Command: &AuditListCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, + }, nil + }, + + "auth-disable": func() (cli.Command, error) { + return &DeprecatedCommand{ + Old: "auth-disable", + New: "auth disable", + UI: ui, + Command: &AuthDisableCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, + }, nil + }, + + "auth-enable": func() (cli.Command, error) { + return &DeprecatedCommand{ + Old: "auth-enable", + New: "auth enable", + UI: ui, + Command: &AuthEnableCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, + }, nil + }, + + "capabilities": func() (cli.Command, error) { + return &DeprecatedCommand{ + Old: "capabilities", + New: "token capabilities", + UI: ui, + Command: &TokenCapabilitiesCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, + }, nil + }, + + "generate-root": func() (cli.Command, error) { + return &DeprecatedCommand{ + Old: "generate-root", + New: "operator generate-root", + UI: ui, + Command: &OperatorGenerateRootCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, + }, nil + }, + + "init": func() (cli.Command, error) { + return &DeprecatedCommand{ + Old: "init", + New: "operator init", + UI: ui, + Command: &OperatorInitCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, + }, nil + }, + + "key-status": func() (cli.Command, error) { + return &DeprecatedCommand{ + Old: "key-status", + New: "operator key-status", + UI: ui, + Command: &OperatorKeyStatusCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, + }, nil + }, + + "renew": func() (cli.Command, error) { + return &DeprecatedCommand{ + Old: "renew", + New: "lease renew", + UI: ui, + Command: &LeaseRenewCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, + }, nil + }, + + "revoke": func() (cli.Command, error) { + return &DeprecatedCommand{ + Old: "revoke", + New: "lease revoke", + UI: ui, + Command: &LeaseRevokeCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, + }, nil + }, + + "mount": func() (cli.Command, error) { + return &DeprecatedCommand{ + Old: "mount", + New: "secrets enable", + UI: ui, + Command: &SecretsEnableCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, + }, nil + }, + + "mount-tune": func() (cli.Command, error) { + return &DeprecatedCommand{ + Old: "mount-tune", + New: "secrets tune", + UI: ui, + Command: &SecretsTuneCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, + }, nil + }, + + "mounts": func() (cli.Command, error) { + return &DeprecatedCommand{ + Old: "mounts", + New: "secrets list", + UI: ui, + Command: &SecretsListCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, + }, nil + }, + + "policies": func() (cli.Command, error) { + return &DeprecatedCommand{ + Old: "policies", + New: "policy read\" or \"vault policy list", // lol + UI: ui, + Command: &PoliciesDeprecatedCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, + }, nil + }, + + "policy-delete": func() (cli.Command, error) { + return &DeprecatedCommand{ + Old: "policy-delete", + New: "policy delete", + UI: ui, + Command: &PolicyDeleteCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, + }, nil + }, + + "policy-write": func() (cli.Command, error) { + return &DeprecatedCommand{ + Old: "policy-write", + New: "policy write", + UI: ui, + Command: &PolicyWriteCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, + }, nil + }, + + "rekey": func() (cli.Command, error) { + return &DeprecatedCommand{ + Old: "rekey", + New: "operator rekey", + UI: ui, + Command: &OperatorRekeyCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, + }, nil + }, + + "remount": func() (cli.Command, error) { + return &DeprecatedCommand{ + Old: "remount", + New: "secrets move", + UI: ui, + Command: &SecretsMoveCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, + }, nil + }, + + "rotate": func() (cli.Command, error) { + return &DeprecatedCommand{ + Old: "rotate", + New: "operator rotate", + UI: ui, + Command: &OperatorRotateCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, + }, nil + }, + + "seal": func() (cli.Command, error) { + return &DeprecatedCommand{ + Old: "seal", + New: "operator seal", + UI: ui, + Command: &OperatorSealCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, + }, nil + }, + + "step-down": func() (cli.Command, error) { + return &DeprecatedCommand{ + Old: "step-down", + New: "operator step-down", + UI: ui, + Command: &OperatorStepDownCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, + }, nil + }, + + "token-create": func() (cli.Command, error) { + return &DeprecatedCommand{ + Old: "token-create", + New: "token create", + UI: ui, + Command: &TokenCreateCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, + }, nil + }, + + "token-lookup": func() (cli.Command, error) { + return &DeprecatedCommand{ + Old: "token-lookup", + New: "token lookup", + UI: ui, + Command: &TokenLookupCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, + }, nil + }, + + "token-renew": func() (cli.Command, error) { + return &DeprecatedCommand{ + Old: "token-renew", + New: "token renew", + UI: ui, + Command: &TokenRenewCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, + }, nil + }, + + "token-revoke": func() (cli.Command, error) { + return &DeprecatedCommand{ + Old: "token-revoke", + New: "token revoke", + UI: ui, + Command: &TokenRevokeCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, + }, nil + }, + + "unmount": func() (cli.Command, error) { + return &DeprecatedCommand{ + Old: "unmount", + New: "secrets disable", + UI: ui, + Command: &SecretsDisableCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, + }, nil + }, + + "unseal": func() (cli.Command, error) { + return &DeprecatedCommand{ + Old: "unseal", + New: "operator unseal", + UI: ui, + Command: &OperatorUnsealCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + }, + }, nil + }, + } + + // Add deprecated commands back to the main commands so they parse. + for k, v := range DeprecatedCommands { + if _, ok := Commands[k]; ok { + // Can't deprecate an existing command... + panic(fmt.Sprintf("command %q defined as deprecated and not at the same time!", k)) + } + Commands[k] = v + } +} + +// MakeShutdownCh returns a channel that can be used for shutdown +// notifications for commands. This channel will send a message for every +// SIGINT or SIGTERM received. +func MakeShutdownCh() chan struct{} { + resultCh := make(chan struct{}) + + shutdownCh := make(chan os.Signal, 4) + signal.Notify(shutdownCh, os.Interrupt, syscall.SIGTERM) + go func() { + <-shutdownCh + close(resultCh) + }() + return resultCh +} + +// MakeSighupCh returns a channel that can be used for SIGHUP +// reloading. This channel will send a message for every +// SIGHUP received. +func MakeSighupCh() chan struct{} { + resultCh := make(chan struct{}) + + signalCh := make(chan os.Signal, 4) + signal.Notify(signalCh, syscall.SIGHUP) + go func() { + for { + <-signalCh + resultCh <- struct{}{} + } + }() + return resultCh +} diff --git a/command/delete.go b/command/delete.go index d9a8ee8a66..12c7f96117 100644 --- a/command/delete.go +++ b/command/delete.go @@ -4,64 +4,91 @@ import ( "fmt" "strings" - "github.com/hashicorp/vault/meta" + "github.com/mitchellh/cli" + "github.com/posener/complete" ) -// DeleteCommand is a Command that puts data into the Vault. +var _ cli.Command = (*DeleteCommand)(nil) +var _ cli.CommandAutocomplete = (*DeleteCommand)(nil) + type DeleteCommand struct { - meta.Meta -} - -func (c *DeleteCommand) Run(args []string) int { - flags := c.Meta.FlagSet("delete", meta.FlagSetDefault) - flags.Usage = func() { c.Ui.Error(c.Help()) } - if err := flags.Parse(args); err != nil { - return 1 - } - - args = flags.Args() - if len(args) != 1 { - c.Ui.Error("delete expects one argument") - flags.Usage() - return 1 - } - - path := args[0] - - client, err := c.Client() - if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error initializing client: %s", err)) - return 2 - } - - if _, err := client.Logical().Delete(path); err != nil { - c.Ui.Error(fmt.Sprintf( - "Error deleting '%s': %s", path, err)) - return 1 - } - - c.Ui.Output(fmt.Sprintf("Success! Deleted '%s' if it existed.", path)) - return 0 + *BaseCommand } func (c *DeleteCommand) Synopsis() string { - return "Delete operation on secrets in Vault" + return "Delete secrets and configuration" } func (c *DeleteCommand) Help() string { helpText := ` -Usage: vault delete [options] path +Usage: vault delete [options] PATH - Delete data (secrets or configuration) from Vault. + Deletes secrets and configuration from Vault at the given path. The behavior + of "delete" is delegated to the backend corresponding to the given path. - Delete sends a delete operation request to the given path. The - behavior of the delete is determined by the backend at the given - path. For example, deleting "aws/policy/ops" will delete the "ops" - policy for the AWS backend. Use "vault help" for more details on - whether delete is supported for a path and what the behavior is. + Remove data in the status secret backend: + + $ vault delete secret/my-secret + + Uninstall an encryption key in the transit backend: + + $ vault delete transit/keys/my-key + + Delete an IAM role: + + $ vault delete aws/roles/ops + + For a full list of examples and paths, please see the documentation that + corresponds to the secret backend in use. + +` + c.Flags().Help() -General Options: -` + meta.GeneralOptionsUsage() return strings.TrimSpace(helpText) } + +func (c *DeleteCommand) Flags() *FlagSets { + return c.flagSet(FlagSetHTTP) +} + +func (c *DeleteCommand) AutocompleteArgs() complete.Predictor { + return c.PredictVaultFiles() +} + +func (c *DeleteCommand) AutocompleteFlags() complete.Flags { + return c.Flags().Completions() +} + +func (c *DeleteCommand) Run(args []string) int { + f := c.Flags() + + if err := f.Parse(args); err != nil { + c.UI.Error(err.Error()) + return 1 + } + + args = f.Args() + switch { + case len(args) < 1: + c.UI.Error(fmt.Sprintf("Not enough arguments (expected 1, got %d)", len(args))) + return 1 + case len(args) > 1: + c.UI.Error(fmt.Sprintf("Too many arguments (expected 1, got %d)", len(args))) + return 1 + } + + client, err := c.Client() + if err != nil { + c.UI.Error(err.Error()) + return 2 + } + + path := sanitizePath(args[0]) + + if _, err := client.Logical().Delete(path); err != nil { + c.UI.Error(fmt.Sprintf("Error deleting %s: %s", path, err)) + return 2 + } + + c.UI.Info(fmt.Sprintf("Success! Data deleted (if it existed) at: %s", path)) + return 0 +} diff --git a/command/delete_test.go b/command/delete_test.go index c5efc415b8..44970f57af 100644 --- a/command/delete_test.go +++ b/command/delete_test.go @@ -1,56 +1,131 @@ package command import ( + "strings" "testing" - "github.com/hashicorp/vault/http" - "github.com/hashicorp/vault/meta" - "github.com/hashicorp/vault/vault" "github.com/mitchellh/cli" ) -func TestDelete(t *testing.T) { - core, _, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() +func testDeleteCommand(tb testing.TB) (*cli.MockUi, *DeleteCommand) { + tb.Helper() - ui := new(cli.MockUi) - c := &DeleteCommand{ - Meta: meta.Meta{ - ClientToken: token, - Ui: ui, + ui := cli.NewMockUi() + return ui, &DeleteCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + } +} + +func TestDeleteCommand_Run(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + args []string + out string + code int + }{ + { + "not_enough_args", + []string{}, + "Not enough arguments", + 1, + }, + { + "too_many_args", + []string{"foo", "bar"}, + "Too many arguments", + 1, }, } - args := []string{ - "-address", addr, - "secret/foo", - } + t.Run("validations", func(t *testing.T) { + t.Parallel() - // Run once so the client is setup, ignore errors - c.Run(args) + for _, tc := range cases { + tc := tc - // Get the client so we can write data - client, err := c.Client() - if err != nil { - t.Fatalf("err: %s", err) - } + t.Run(tc.name, func(t *testing.T) { + t.Parallel() - data := map[string]interface{}{"value": "bar"} - if _, err := client.Logical().Write("secret/foo", data); err != nil { - t.Fatalf("err: %s", err) - } + ui, cmd := testDeleteCommand(t) - // Run the delete - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } + code := cmd.Run(tc.args) + if code != tc.code { + t.Errorf("expected %d to be %d", code, tc.code) + } - resp, err := client.Logical().Read("secret/foo") - if err != nil { - t.Fatalf("err: %s", err) - } - if resp != nil { - t.Fatalf("bad: %#v", resp) - } + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, tc.out) { + t.Errorf("expected %q to contain %q", combined, tc.out) + } + }) + } + }) + + t.Run("integration", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + if _, err := client.Logical().Write("secret/delete/foo", map[string]interface{}{ + "foo": "bar", + }); err != nil { + t.Fatal(err) + } + + ui, cmd := testDeleteCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "secret/delete/foo", + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Success! Data deleted (if it existed) at: secret/delete/foo" + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + + secret, _ := client.Logical().Read("secret/delete/foo") + if secret != nil { + t.Errorf("expected deletion: %#v", secret) + } + }) + + t.Run("communication_failure", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServerBad(t) + defer closer() + + ui, cmd := testDeleteCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "secret/delete/foo", + }) + if exp := 2; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Error deleting secret/delete/foo: " + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + }) + + t.Run("no_tabs", func(t *testing.T) { + t.Parallel() + + _, cmd := testDeleteCommand(t) + assertNoTabs(t, cmd) + }) } diff --git a/command/format.go b/command/format.go index aab5a856f2..c1397b3383 100644 --- a/command/format.go +++ b/command/format.go @@ -1,24 +1,23 @@ package command import ( - "bytes" "encoding/json" "errors" "fmt" "sort" - "strconv" "strings" - "sync" - "time" "github.com/ghodss/yaml" "github.com/hashicorp/vault/api" "github.com/mitchellh/cli" - "github.com/posener/complete" "github.com/ryanuber/columnize" ) -var predictFormat complete.Predictor = complete.PredictSet("json", "yaml") +const ( + // hopeDelim is the delimiter to use when splitting columns. We call it a + // hopeDelim because we hope that it's never contained in a secret. + hopeDelim = "♨" +) func OutputSecret(ui cli.Ui, format string, secret *api.Secret) int { return outputWithFormat(ui, format, secret, secret) @@ -29,6 +28,13 @@ func OutputList(ui cli.Ui, format string, secret *api.Secret) int { } func outputWithFormat(ui cli.Ui, format string, secret *api.Secret, data interface{}) int { + // If we had a colored UI, pull out the nested ui so we don't add escape + // sequences for outputting json, etc. + colorUI, ok := ui.(*cli.ColoredUi) + if ok { + ui = colorUI.Ui + } + formatter, ok := Formatters[strings.ToLower(format)] if !ok { ui.Error(fmt.Sprintf("Invalid output format: %s", format)) @@ -53,17 +59,15 @@ var Formatters = map[string]Formatter{ } // An output formatter for json output of an object -type JsonFormatter struct { -} +type JsonFormatter struct{} func (j JsonFormatter) Output(ui cli.Ui, secret *api.Secret, data interface{}) error { - b, err := json.Marshal(data) - if err == nil { - var out bytes.Buffer - json.Indent(&out, b, "", "\t") - ui.Output(out.String()) + b, err := json.MarshalIndent(data, "", " ") + if err != nil { + return err } - return err + ui.Output(string(b)) + return nil } // An output formatter for yaml output format of an object @@ -85,7 +89,7 @@ type TableFormatter struct { func (t TableFormatter) Output(ui cli.Ui, secret *api.Secret, data interface{}) error { // TODO: this should really use reflection like the other formatters do if s, ok := data.(*api.Secret); ok { - return t.OutputSecret(ui, secret, s) + return t.OutputSecret(ui, s) } if s, ok := data.([]interface{}); ok { return t.OutputList(ui, secret, s) @@ -94,133 +98,166 @@ func (t TableFormatter) Output(ui cli.Ui, secret *api.Secret, data interface{}) } func (t TableFormatter) OutputList(ui cli.Ui, secret *api.Secret, list []interface{}) error { - config := columnize.DefaultConfig() - config.Delim = "♨" - config.Glue = "\t" - config.Prefix = "" - - input := make([]string, 0, 5) + t.printWarnings(ui, secret) if len(list) > 0 { - input = append(input, "Keys") - input = append(input, "----") - - keys := make([]string, 0, len(list)) - for _, k := range list { - keys = append(keys, k.(string)) + keys := make([]string, len(list)) + for i, v := range list { + typed, ok := v.(string) + if !ok { + return fmt.Errorf("Error: %v is not a string", v) + } + keys[i] = typed } sort.Strings(keys) - for _, k := range keys { - input = append(input, fmt.Sprintf("%s", k)) - } + // Prepend the header + keys = append([]string{"Keys"}, keys...) + + ui.Output(tableOutput(keys, &columnize.Config{ + Delim: hopeDelim, + })) } - tableOutputStr := columnize.Format(input, config) - - // Print the warning separately because the length of first - // column in the output will be increased by the length of - // the longest warning string making the output look bad. - warningsInput := make([]string, 0, 5) - if len(secret.Warnings) != 0 { - warningsInput = append(warningsInput, "") - warningsInput = append(warningsInput, "The following warnings were returned from the Vault server:") - for _, warning := range secret.Warnings { - warningsInput = append(warningsInput, fmt.Sprintf("* %s", warning)) - } - } - - warningsOutputStr := columnize.Format(warningsInput, config) - - ui.Output(fmt.Sprintf("%s\n%s", tableOutputStr, warningsOutputStr)) - return nil } -func (t TableFormatter) OutputSecret(ui cli.Ui, secret, s *api.Secret) error { - config := columnize.DefaultConfig() - config.Delim = "♨" - config.Glue = "\t" - config.Prefix = "" +// printWarnings prints any warnings in the secret. +func (t TableFormatter) printWarnings(ui cli.Ui, secret *api.Secret) { + if secret != nil && len(secret.Warnings) > 0 { + ui.Warn("WARNING! The following warnings were returned from Vault:\n") + for _, warning := range secret.Warnings { + ui.Warn(wrapAtLengthWithPadding(fmt.Sprintf("* %s", warning), 2)) + } + ui.Warn("") + } +} - input := make([]string, 0, 5) - - onceHeader := &sync.Once{} - headerFunc := func() { - input = append(input, fmt.Sprintf("Key %s Value", config.Delim)) - input = append(input, fmt.Sprintf("--- %s -----", config.Delim)) +func (t TableFormatter) OutputSecret(ui cli.Ui, secret *api.Secret) error { + if secret == nil { + return nil } - if s.LeaseDuration > 0 { - onceHeader.Do(headerFunc) - if s.LeaseID != "" { - input = append(input, fmt.Sprintf("lease_id %s %s", config.Delim, s.LeaseID)) - input = append(input, fmt.Sprintf( - "lease_duration %s %s", config.Delim, (time.Second*time.Duration(s.LeaseDuration)).String())) + t.printWarnings(ui, secret) + + out := make([]string, 0, 8) + if secret.LeaseDuration > 0 { + if secret.LeaseID != "" { + out = append(out, fmt.Sprintf("lease_id %s %s", hopeDelim, secret.LeaseID)) + out = append(out, fmt.Sprintf("lease_duration %s %s", hopeDelim, humanDurationInt(secret.LeaseDuration))) + out = append(out, fmt.Sprintf("lease_renewable %s %t", hopeDelim, secret.Renewable)) } else { - input = append(input, fmt.Sprintf( - "refresh_interval %s %s", config.Delim, (time.Second*time.Duration(s.LeaseDuration)).String())) - } - if s.LeaseID != "" { - input = append(input, fmt.Sprintf( - "lease_renewable %s %s", config.Delim, strconv.FormatBool(s.Renewable))) + // This is probably the generic secret backend which has leases, but we + // print them as refresh_interval to reduce confusion. + out = append(out, fmt.Sprintf("refresh_interval %s %s", hopeDelim, humanDurationInt(secret.LeaseDuration))) } } - if s.Auth != nil { - onceHeader.Do(headerFunc) - input = append(input, fmt.Sprintf("token %s %s", config.Delim, s.Auth.ClientToken)) - input = append(input, fmt.Sprintf("token_accessor %s %s", config.Delim, s.Auth.Accessor)) - input = append(input, fmt.Sprintf("token_duration %s %s", config.Delim, (time.Second*time.Duration(s.Auth.LeaseDuration)).String())) - input = append(input, fmt.Sprintf("token_renewable %s %v", config.Delim, s.Auth.Renewable)) - input = append(input, fmt.Sprintf("token_policies %s %v", config.Delim, s.Auth.Policies)) - for k, v := range s.Auth.Metadata { - input = append(input, fmt.Sprintf("token_meta_%s %s %#v", k, config.Delim, v)) + if secret.Auth != nil { + out = append(out, fmt.Sprintf("token %s %s", hopeDelim, secret.Auth.ClientToken)) + out = append(out, fmt.Sprintf("token_accessor %s %s", hopeDelim, secret.Auth.Accessor)) + // If the lease duration is 0, it's likely a root token, so output the + // duration as "infinity" to clear things up. + if secret.Auth.LeaseDuration == 0 { + out = append(out, fmt.Sprintf("token_duration %s %s", hopeDelim, "∞")) + } else { + out = append(out, fmt.Sprintf("token_duration %s %s", hopeDelim, humanDurationInt(secret.Auth.LeaseDuration))) + } + out = append(out, fmt.Sprintf("token_renewable %s %t", hopeDelim, secret.Auth.Renewable)) + out = append(out, fmt.Sprintf("token_policies %s %v", hopeDelim, secret.Auth.Policies)) + for k, v := range secret.Auth.Metadata { + out = append(out, fmt.Sprintf("token_meta_%s %s %v", k, hopeDelim, v)) } } - if s.WrapInfo != nil { - onceHeader.Do(headerFunc) - input = append(input, fmt.Sprintf("wrapping_token: %s %s", config.Delim, s.WrapInfo.Token)) - input = append(input, fmt.Sprintf("wrapping_accessor: %s %s", config.Delim, s.WrapInfo.Accessor)) - input = append(input, fmt.Sprintf("wrapping_token_ttl: %s %s", config.Delim, (time.Second*time.Duration(s.WrapInfo.TTL)).String())) - input = append(input, fmt.Sprintf("wrapping_token_creation_time: %s %s", config.Delim, s.WrapInfo.CreationTime.String())) - input = append(input, fmt.Sprintf("wrapping_token_creation_path: %s %s", config.Delim, s.WrapInfo.CreationPath)) - if s.WrapInfo.WrappedAccessor != "" { - input = append(input, fmt.Sprintf("wrapped_accessor: %s %s", config.Delim, s.WrapInfo.WrappedAccessor)) + if secret.WrapInfo != nil { + out = append(out, fmt.Sprintf("wrapping_token: %s %s", hopeDelim, secret.WrapInfo.Token)) + out = append(out, fmt.Sprintf("wrapping_accessor: %s %s", hopeDelim, secret.WrapInfo.Accessor)) + out = append(out, fmt.Sprintf("wrapping_token_ttl: %s %s", hopeDelim, humanDurationInt(secret.WrapInfo.TTL))) + out = append(out, fmt.Sprintf("wrapping_token_creation_time: %s %s", hopeDelim, secret.WrapInfo.CreationTime.String())) + out = append(out, fmt.Sprintf("wrapping_token_creation_path: %s %s", hopeDelim, secret.WrapInfo.CreationPath)) + if secret.WrapInfo.WrappedAccessor != "" { + out = append(out, fmt.Sprintf("wrapped_accessor: %s %s", hopeDelim, secret.WrapInfo.WrappedAccessor)) } } - if s.Data != nil && len(s.Data) > 0 { - onceHeader.Do(headerFunc) - keys := make([]string, 0, len(s.Data)) - for k := range s.Data { + if len(secret.Data) > 0 { + keys := make([]string, 0, len(secret.Data)) + for k := range secret.Data { keys = append(keys, k) } sort.Strings(keys) for _, k := range keys { - input = append(input, fmt.Sprintf("%s %s %v", k, config.Delim, s.Data[k])) + out = append(out, fmt.Sprintf("%s %s %v", k, hopeDelim, secret.Data[k])) } } - tableOutputStr := columnize.Format(input, config) - - // Print the warning separately because the length of first - // column in the output will be increased by the length of - // the longest warning string making the output look bad. - warningsInput := make([]string, 0, 5) - if len(s.Warnings) != 0 { - warningsInput = append(warningsInput, "") - warningsInput = append(warningsInput, "The following warnings were returned from the Vault server:") - for _, warning := range s.Warnings { - warningsInput = append(warningsInput, fmt.Sprintf("* %s", warning)) - } + // If we got this far and still don't have any data, there's nothing to print, + // sorry. + if len(out) == 0 { + return nil } - warningsOutputStr := columnize.Format(warningsInput, config) - - ui.Output(fmt.Sprintf("%s\n%s", tableOutputStr, warningsOutputStr)) + // Prepend the header + out = append([]string{"Key" + hopeDelim + "Value"}, out...) + ui.Output(tableOutput(out, &columnize.Config{ + Delim: hopeDelim, + })) return nil } + +func OutputSealStatus(ui cli.Ui, client *api.Client, status *api.SealStatusResponse) int { + var sealPrefix string + if status.RecoverySeal { + sealPrefix = "Recovery " + } + + out := []string{} + out = append(out, "Key | Value") + out = append(out, fmt.Sprintf("%sSeal Type | %s", sealPrefix, status.Type)) + out = append(out, fmt.Sprintf("Sealed | %t", status.Sealed)) + out = append(out, fmt.Sprintf("Total %sShares | %d", sealPrefix, status.N)) + out = append(out, fmt.Sprintf("Threshold | %d", status.T)) + + if status.Sealed { + out = append(out, fmt.Sprintf("Unseal Progress | %d/%d", status.Progress, status.T)) + out = append(out, fmt.Sprintf("Unseal Nonce | %s", status.Nonce)) + } + + out = append(out, fmt.Sprintf("Version | %s", status.Version)) + + if status.ClusterName != "" && status.ClusterID != "" { + out = append(out, fmt.Sprintf("Cluster Name | %s", status.ClusterName)) + out = append(out, fmt.Sprintf("Cluster ID | %s", status.ClusterID)) + } + + // Mask the 'Vault is sealed' error, since this means HA is enabled, but that + // we cannot query for the leader since we are sealed. + leaderStatus, err := client.Sys().Leader() + if err != nil && strings.Contains(err.Error(), "Vault is sealed") { + leaderStatus = &api.LeaderResponse{HAEnabled: true} + } + + // Output if HA is enabled + out = append(out, fmt.Sprintf("HA Enabled | %t", leaderStatus.HAEnabled)) + if leaderStatus.HAEnabled { + mode := "sealed" + if !status.Sealed { + mode = "standby" + if leaderStatus.IsSelf { + mode = "active" + } + } + + out = append(out, fmt.Sprintf("HA Mode | %s", mode)) + + if !status.Sealed { + out = append(out, fmt.Sprintf("HA Cluster | %s", leaderStatus.LeaderClusterAddress)) + } + } + + ui.Output(tableOutput(out, nil)) + return 0 +} diff --git a/command/format_test.go b/command/format_test.go index 8e32d2419c..44020a1aa6 100644 --- a/command/format_test.go +++ b/command/format_test.go @@ -24,18 +24,10 @@ func (m mockUi) AskSecret(_ string) (string, error) { m.t.FailNow() return "", nil } -func (m mockUi) Output(s string) { - output = s -} -func (m mockUi) Info(s string) { - m.t.Log(s) -} -func (m mockUi) Error(s string) { - m.t.Log(s) -} -func (m mockUi) Warn(s string) { - m.t.Log(s) -} +func (m mockUi) Output(s string) { output = s } +func (m mockUi) Info(s string) { m.t.Log(s) } +func (m mockUi) Error(s string) { m.t.Log(s) } +func (m mockUi) Warn(s string) { m.t.Log(s) } func TestJsonFormatter(t *testing.T) { ui := mockUi{t: t, SampleData: "something"} diff --git a/command/generate-root.go b/command/generate-root.go deleted file mode 100644 index 955e3d8bdf..0000000000 --- a/command/generate-root.go +++ /dev/null @@ -1,405 +0,0 @@ -package command - -import ( - "crypto/rand" - "encoding/base64" - "fmt" - "os" - "strings" - - "github.com/hashicorp/go-uuid" - "github.com/hashicorp/vault/api" - "github.com/hashicorp/vault/helper/password" - "github.com/hashicorp/vault/helper/pgpkeys" - "github.com/hashicorp/vault/helper/xor" - "github.com/hashicorp/vault/meta" - "github.com/posener/complete" -) - -// GenerateRootCommand is a Command that generates a new root token. -type GenerateRootCommand struct { - meta.Meta - - // Key can be used to pre-seed the key. If it is set, it will not - // be asked with the `password` helper. - Key string - - // The nonce for the rekey request to send along - Nonce string -} - -func (c *GenerateRootCommand) Run(args []string) int { - var init, cancel, status, genotp, drToken bool - var nonce, decode, otp, pgpKey string - var pgpKeyArr pgpkeys.PubKeyFilesFlag - flags := c.Meta.FlagSet("generate-root", meta.FlagSetDefault) - flags.BoolVar(&init, "init", false, "") - flags.BoolVar(&drToken, "dr-token", false, "") - flags.BoolVar(&cancel, "cancel", false, "") - flags.BoolVar(&status, "status", false, "") - flags.BoolVar(&genotp, "genotp", false, "") - flags.StringVar(&decode, "decode", "", "") - flags.StringVar(&otp, "otp", "", "") - flags.StringVar(&nonce, "nonce", "", "") - flags.Var(&pgpKeyArr, "pgp-key", "") - flags.Usage = func() { c.Ui.Error(c.Help()) } - if err := flags.Parse(args); err != nil { - return 1 - } - - if genotp { - buf := make([]byte, 16) - readLen, err := rand.Read(buf) - if err != nil { - c.Ui.Error(fmt.Sprintf("Error reading random bytes: %s", err)) - return 1 - } - if readLen != 16 { - c.Ui.Error(fmt.Sprintf("Read %d bytes when we should have read 16", readLen)) - return 1 - } - c.Ui.Output(fmt.Sprintf("OTP: %s", base64.StdEncoding.EncodeToString(buf))) - return 0 - } - - if len(decode) > 0 { - if len(otp) == 0 { - c.Ui.Error("Both the value to decode and the OTP must be passed in") - return 1 - } - return c.decode(decode, otp) - } - - client, err := c.Client() - if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error initializing client: %s", err)) - return 2 - } - - // Check if the root generation is started - f := client.Sys().GenerateRootStatus - if drToken { - f = client.Sys().GenerateDROperationTokenStatus - } - rootGenerationStatus, err := f() - if err != nil { - c.Ui.Error(fmt.Sprintf("Error reading root generation status: %s", err)) - return 1 - } - - // If we are initing, or if we are not started but are not running a - // special function, check otp and pgpkey - checkOtpPgp := false - switch { - case init: - checkOtpPgp = true - case cancel: - case status: - case genotp: - case len(decode) != 0: - case rootGenerationStatus.Started: - default: - checkOtpPgp = true - } - if checkOtpPgp { - switch { - case len(otp) == 0 && (pgpKeyArr == nil || len(pgpKeyArr) == 0): - c.Ui.Error(c.Help()) - return 1 - case len(otp) != 0 && pgpKeyArr != nil && len(pgpKeyArr) != 0: - c.Ui.Error(c.Help()) - return 1 - case len(otp) != 0: - err := c.verifyOTP(otp) - if err != nil { - c.Ui.Error(fmt.Sprintf("Error verifying the provided OTP: %s", err)) - return 1 - } - case pgpKeyArr != nil: - if len(pgpKeyArr) != 1 { - c.Ui.Error("Could not parse PGP key") - return 1 - } - if len(pgpKeyArr[0]) == 0 { - c.Ui.Error("Got an empty PGP key") - return 1 - } - pgpKey = pgpKeyArr[0] - default: - panic("unreachable case") - } - } - - if nonce != "" { - c.Nonce = nonce - } - - // Check if we are running doing any restricted variants - switch { - case init: - return c.initGenerateRoot(client, otp, pgpKey, drToken) - case cancel: - return c.cancelGenerateRoot(client, drToken) - case status: - return c.rootGenerationStatus(client, drToken) - } - - // Start the root generation process if not started - if !rootGenerationStatus.Started { - f := client.Sys().GenerateRootInit - if drToken { - f = client.Sys().GenerateDROperationTokenInit - } - rootGenerationStatus, err = f(otp, pgpKey) - if err != nil { - c.Ui.Error(fmt.Sprintf("Error initializing root generation: %s", err)) - return 1 - } - c.Nonce = rootGenerationStatus.Nonce - } - - serverNonce := rootGenerationStatus.Nonce - - // Get the unseal key - args = flags.Args() - key := c.Key - if len(args) > 0 { - key = args[0] - } - if key == "" { - c.Nonce = serverNonce - fmt.Printf("Root generation operation nonce: %s\n", serverNonce) - fmt.Printf("Key (will be hidden): ") - key, err = password.Read(os.Stdin) - fmt.Printf("\n") - if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error attempting to ask for password. The raw error message\n"+ - "is shown below, but the most common reason for this error is\n"+ - "that you attempted to pipe a value into unseal or you're\n"+ - "executing `vault generate-root` from outside of a terminal.\n\n"+ - "You should use `vault generate-root` from a terminal for maximum\n"+ - "security. If this isn't an option, the unseal key can be passed\n"+ - "in using the first parameter.\n\n"+ - "Raw error: %s", err)) - return 1 - } - } - - // Provide the key, this may potentially complete the update - { - f := client.Sys().GenerateRootUpdate - if drToken { - f = client.Sys().GenerateDROperationTokenUpdate - } - statusResp, err := f(strings.TrimSpace(key), c.Nonce) - if err != nil { - c.Ui.Error(fmt.Sprintf("Error attempting generate-root update: %s", err)) - return 1 - } - - c.dumpStatus(statusResp) - } - return 0 -} - -func (c *GenerateRootCommand) verifyOTP(otp string) error { - if len(otp) == 0 { - return fmt.Errorf("No OTP passed in") - } - otpBytes, err := base64.StdEncoding.DecodeString(otp) - if err != nil { - return fmt.Errorf("Error decoding base64 OTP value: %s", err) - } - if otpBytes == nil || len(otpBytes) != 16 { - return fmt.Errorf("Decoded OTP value is invalid or wrong length") - } - - return nil -} - -func (c *GenerateRootCommand) decode(encodedVal, otp string) int { - tokenBytes, err := xor.XORBase64(encodedVal, otp) - if err != nil { - c.Ui.Error(err.Error()) - return 1 - } - - token, err := uuid.FormatUUID(tokenBytes) - if err != nil { - c.Ui.Error(fmt.Sprintf("Error formatting base64 token value: %v", err)) - return 1 - } - - c.Ui.Output(fmt.Sprintf("Root token: %s", token)) - - return 0 -} - -// initGenerateRoot is used to start the generation process -func (c *GenerateRootCommand) initGenerateRoot(client *api.Client, otp string, pgpKey string, drToken bool) int { - // Start the rekey - f := client.Sys().GenerateRootInit - if drToken { - f = client.Sys().GenerateDROperationTokenInit - } - - status, err := f(otp, pgpKey) - if err != nil { - c.Ui.Error(fmt.Sprintf("Error initializing root generation: %s", err)) - return 1 - } - - c.dumpStatus(status) - - return 0 -} - -// cancelGenerateRoot is used to abort the generation process -func (c *GenerateRootCommand) cancelGenerateRoot(client *api.Client, drToken bool) int { - f := client.Sys().GenerateRootCancel - if drToken { - f = client.Sys().GenerateDROperationTokenCancel - } - err := f() - if err != nil { - c.Ui.Error(fmt.Sprintf("Failed to cancel root generation: %s", err)) - return 1 - } - c.Ui.Output("Root generation canceled.") - return 0 -} - -// rootGenerationStatus is used just to fetch and dump the status -func (c *GenerateRootCommand) rootGenerationStatus(client *api.Client, drToken bool) int { - // Check the status - f := client.Sys().GenerateRootStatus - if drToken { - f = client.Sys().GenerateDROperationTokenStatus - } - status, err := f() - if err != nil { - c.Ui.Error(fmt.Sprintf("Error reading root generation status: %s", err)) - return 1 - } - - c.dumpStatus(status) - - return 0 -} - -// dumpStatus dumps the status to output -func (c *GenerateRootCommand) dumpStatus(status *api.GenerateRootStatusResponse) { - // Dump the status - statString := fmt.Sprintf( - "Nonce: %s\n"+ - "Started: %v\n"+ - "Generate Root Progress: %d\n"+ - "Required Keys: %d\n"+ - "Complete: %t", - status.Nonce, - status.Started, - status.Progress, - status.Required, - status.Complete, - ) - if len(status.PGPFingerprint) > 0 { - statString = fmt.Sprintf("%s\nPGP Fingerprint: %s", statString, status.PGPFingerprint) - } - if len(status.EncodedRootToken) > 0 { - statString = fmt.Sprintf("%s\n\nEncoded root token: %s", statString, status.EncodedRootToken) - } else if len(status.EncodedToken) > 0 { - statString = fmt.Sprintf("%s\n\nEncoded token: %s", statString, status.EncodedToken) - } - c.Ui.Output(statString) -} - -func (c *GenerateRootCommand) Synopsis() string { - return "Generates a new root token" -} - -func (c *GenerateRootCommand) Help() string { - helpText := ` -Usage: vault generate-root [options] [key] - - 'generate-root' is used to create a new root token. - - Root generation can only be done when the vault is already unsealed. The - operation is done online, but requires that a threshold of the current unseal - keys be provided. - - One (and only one) of the following must be provided when initializing the - root generation attempt: - - 1) A 16-byte, base64-encoded One Time Password (OTP) provided in the '-otp' - flag; the token is XOR'd with this value before it is returned once the final - unseal key has been provided. The '-decode' operation can be used with this - value and the OTP to output the final token value. The '-genotp' flag can be - used to generate a suitable value. - - or - - 2) A file containing a PGP key (binary or base64-encoded) or a Keybase.io - username in the format of "keybase:" in the '-pgp-key' flag. The - final token value will be encrypted with this public key and base64-encoded. - -General Options: -` + meta.GeneralOptionsUsage() + ` -Generate Root Options: - - -init Initialize the root generation attempt. This can only - be done if no generation is already initiated. - - -cancel Reset the root generation process by throwing away - prior unseal keys and the configuration. - - -status Prints the status of the current attempt. This can be - used to see the status without attempting to provide - an unseal key. - - -decode=abcd Decodes and outputs the generated root token. The OTP - used at '-init' time must be provided in the '-otp' - parameter. - - -genotp Returns a high-quality OTP suitable for passing into - the '-init' method. - - -otp=abcd The base64-encoded 16-byte OTP for use with the - '-init' or '-decode' methods. - - -pgp-key A file on disk containing a binary- or base64-format - public PGP key, or a Keybase username specified as - "keybase:". The output root token will be - encrypted and base64-encoded, in order, with the given - public key. - - -nonce=abcd The nonce provided at initialization time. This same - nonce value must be provided with each unseal key. If - the unseal key is not being passed in via the command - line the nonce parameter is not required, and will - instead be displayed with the key prompt. - - -dr-token Generate a Disaster Recovery operation token. This flag - should be set on '-init', '-cancel', and every time a - key is provided to specify the type of token to generate. -` - return strings.TrimSpace(helpText) -} - -func (c *GenerateRootCommand) AutocompleteArgs() complete.Predictor { - return complete.PredictNothing -} - -func (c *GenerateRootCommand) AutocompleteFlags() complete.Flags { - return complete.Flags{ - "-init": complete.PredictNothing, - "-cancel": complete.PredictNothing, - "-status": complete.PredictNothing, - "-decode": complete.PredictNothing, - "-genotp": complete.PredictNothing, - "-otp": complete.PredictNothing, - "-pgp-key": complete.PredictNothing, - "-nonce": complete.PredictNothing, - } -} diff --git a/command/generate-root_test.go b/command/generate-root_test.go deleted file mode 100644 index 3aa2bd968e..0000000000 --- a/command/generate-root_test.go +++ /dev/null @@ -1,295 +0,0 @@ -package command - -import ( - "context" - "encoding/base64" - "encoding/hex" - "os" - "strings" - "testing" - - "github.com/hashicorp/go-uuid" - "github.com/hashicorp/vault/helper/pgpkeys" - "github.com/hashicorp/vault/helper/xor" - "github.com/hashicorp/vault/http" - "github.com/hashicorp/vault/logical" - "github.com/hashicorp/vault/meta" - "github.com/hashicorp/vault/vault" - "github.com/mitchellh/cli" -) - -func TestGenerateRoot_Cancel(t *testing.T) { - core, _, _ := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() - - ui := new(cli.MockUi) - c := &GenerateRootCommand{ - Meta: meta.Meta{ - Ui: ui, - }, - } - - otpBytes, err := vault.GenerateRandBytes(16) - if err != nil { - t.Fatal(err) - } - otp := base64.StdEncoding.EncodeToString(otpBytes) - - args := []string{"-address", addr, "-init", "-otp", otp} - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } - - args = []string{"-address", addr, "-cancel"} - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } - - config, err := core.GenerateRootConfiguration() - if err != nil { - t.Fatalf("err: %s", err) - } - if config != nil { - t.Fatal("should not have a config for root generation") - } -} - -func TestGenerateRoot_status(t *testing.T) { - core, _, _ := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() - - ui := new(cli.MockUi) - c := &GenerateRootCommand{ - Meta: meta.Meta{ - Ui: ui, - }, - } - - otpBytes, err := vault.GenerateRandBytes(16) - if err != nil { - t.Fatal(err) - } - otp := base64.StdEncoding.EncodeToString(otpBytes) - - args := []string{"-address", addr, "-init", "-otp", otp} - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } - - args = []string{"-address", addr, "-status"} - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } - - if !strings.Contains(ui.OutputWriter.String(), "Started: true") { - t.Fatalf("bad: %s", ui.OutputWriter.String()) - } -} - -func TestGenerateRoot_OTP(t *testing.T) { - core, ts, keys, _ := vault.TestCoreWithTokenStore(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() - - ui := new(cli.MockUi) - c := &GenerateRootCommand{ - Meta: meta.Meta{ - Ui: ui, - }, - } - - // Generate an OTP - otpBytes, err := vault.GenerateRandBytes(16) - if err != nil { - t.Fatal(err) - } - otp := base64.StdEncoding.EncodeToString(otpBytes) - - // Init the attempt - args := []string{ - "-address", addr, - "-init", - "-otp", otp, - } - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } - - config, err := core.GenerateRootConfiguration() - if err != nil { - t.Fatalf("err: %v", err) - } - - for _, key := range keys { - ui = new(cli.MockUi) - c = &GenerateRootCommand{ - Key: hex.EncodeToString(key), - Meta: meta.Meta{ - Ui: ui, - }, - } - - c.Nonce = config.Nonce - - // Provide the key - args = []string{ - "-address", addr, - } - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } - } - - beforeNAfter := strings.Split(ui.OutputWriter.String(), "Encoded root token: ") - if len(beforeNAfter) != 2 { - t.Fatalf("did not find encoded root token in %s", ui.OutputWriter.String()) - } - encodedToken := strings.TrimSpace(beforeNAfter[1]) - - decodedToken, err := xor.XORBase64(encodedToken, otp) - if err != nil { - t.Fatal(err) - } - - token, err := uuid.FormatUUID(decodedToken) - if err != nil { - t.Fatal(err) - } - - req := logical.TestRequest(t, logical.ReadOperation, "lookup-self") - req.ClientToken = token - - resp, err := ts.HandleRequest(context.Background(), req) - if err != nil { - t.Fatalf("error running token lookup-self: %v", err) - } - if resp == nil { - t.Fatalf("got nil resp with token lookup-self") - } - if resp.Data == nil { - t.Fatalf("got nil resp.Data with token lookup-self") - } - - if resp.Data["orphan"].(bool) != true || - resp.Data["ttl"].(int64) != 0 || - resp.Data["num_uses"].(int) != 0 || - resp.Data["meta"].(map[string]string) != nil || - len(resp.Data["policies"].([]string)) != 1 || - resp.Data["policies"].([]string)[0] != "root" { - t.Fatalf("bad: %#v", resp.Data) - } - - // Clear the output and run a decode to verify we get the same result - ui.OutputWriter.Reset() - args = []string{ - "-address", addr, - "-decode", encodedToken, - "-otp", otp, - } - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } - beforeNAfter = strings.Split(ui.OutputWriter.String(), "Root token: ") - if len(beforeNAfter) != 2 { - t.Fatalf("did not find decoded root token in %s", ui.OutputWriter.String()) - } - - outToken := strings.TrimSpace(beforeNAfter[1]) - if outToken != token { - t.Fatalf("tokens do not match:\n%s\n%s", token, outToken) - } -} - -func TestGenerateRoot_PGP(t *testing.T) { - core, ts, keys, _ := vault.TestCoreWithTokenStore(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() - - ui := new(cli.MockUi) - c := &GenerateRootCommand{ - Meta: meta.Meta{ - Ui: ui, - }, - } - - tempDir, pubFiles, err := getPubKeyFiles(t) - if err != nil { - t.Fatal(err) - } - defer os.RemoveAll(tempDir) - - // Init the attempt - args := []string{ - "-address", addr, - "-init", - "-pgp-key", pubFiles[0], - } - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } - - config, err := core.GenerateRootConfiguration() - if err != nil { - t.Fatalf("err: %v", err) - } - - for _, key := range keys { - c = &GenerateRootCommand{ - Key: hex.EncodeToString(key), - Meta: meta.Meta{ - Ui: ui, - }, - } - - c.Nonce = config.Nonce - - // Provide the key - args = []string{ - "-address", addr, - } - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } - } - - beforeNAfter := strings.Split(ui.OutputWriter.String(), "Encoded root token: ") - if len(beforeNAfter) != 2 { - t.Fatalf("did not find encoded root token in %s", ui.OutputWriter.String()) - } - encodedToken := strings.TrimSpace(beforeNAfter[1]) - - ptBuf, err := pgpkeys.DecryptBytes(encodedToken, pgpkeys.TestPrivKey1) - if err != nil { - t.Fatal(err) - } - if ptBuf == nil { - t.Fatal("returned plain text buffer is nil") - } - - token := ptBuf.String() - - req := logical.TestRequest(t, logical.ReadOperation, "lookup-self") - req.ClientToken = token - - resp, err := ts.HandleRequest(context.Background(), req) - if err != nil { - t.Fatalf("error running token lookup-self: %v", err) - } - if resp == nil { - t.Fatalf("got nil resp with token lookup-self") - } - if resp.Data == nil { - t.Fatalf("got nil resp.Data with token lookup-self") - } - - if resp.Data["orphan"].(bool) != true || - resp.Data["ttl"].(int64) != 0 || - resp.Data["num_uses"].(int) != 0 || - resp.Data["meta"].(map[string]string) != nil || - len(resp.Data["policies"].([]string)) != 1 || - resp.Data["policies"].([]string)[0] != "root" { - t.Fatalf("bad: %#v", resp.Data) - } -} diff --git a/command/init.go b/command/init.go deleted file mode 100644 index 470c325107..0000000000 --- a/command/init.go +++ /dev/null @@ -1,406 +0,0 @@ -package command - -import ( - "fmt" - "net/url" - "os" - "runtime" - "strings" - - consulapi "github.com/hashicorp/consul/api" - "github.com/hashicorp/vault/api" - "github.com/hashicorp/vault/helper/pgpkeys" - "github.com/hashicorp/vault/meta" - "github.com/hashicorp/vault/physical/consul" - "github.com/posener/complete" -) - -// InitCommand is a Command that initializes a new Vault server. -type InitCommand struct { - meta.Meta -} - -func (c *InitCommand) Run(args []string) int { - var threshold, shares, storedShares, recoveryThreshold, recoveryShares int - var pgpKeys, recoveryPgpKeys, rootTokenPgpKey pgpkeys.PubKeyFilesFlag - var auto, check bool - var consulServiceName string - flags := c.Meta.FlagSet("init", meta.FlagSetDefault) - flags.Usage = func() { c.Ui.Error(c.Help()) } - flags.IntVar(&shares, "key-shares", 5, "") - flags.IntVar(&threshold, "key-threshold", 3, "") - flags.IntVar(&storedShares, "stored-shares", 0, "") - flags.Var(&pgpKeys, "pgp-keys", "") - flags.Var(&rootTokenPgpKey, "root-token-pgp-key", "") - flags.IntVar(&recoveryShares, "recovery-shares", 5, "") - flags.IntVar(&recoveryThreshold, "recovery-threshold", 3, "") - flags.Var(&recoveryPgpKeys, "recovery-pgp-keys", "") - flags.BoolVar(&check, "check", false, "") - flags.BoolVar(&auto, "auto", false, "") - flags.StringVar(&consulServiceName, "consul-service", consul.DefaultServiceName, "") - if err := flags.Parse(args); err != nil { - return 1 - } - - initRequest := &api.InitRequest{ - SecretShares: shares, - SecretThreshold: threshold, - StoredShares: storedShares, - PGPKeys: pgpKeys, - RecoveryShares: recoveryShares, - RecoveryThreshold: recoveryThreshold, - RecoveryPGPKeys: recoveryPgpKeys, - } - - switch len(rootTokenPgpKey) { - case 0: - case 1: - initRequest.RootTokenPGPKey = rootTokenPgpKey[0] - default: - c.Ui.Error("Only one PGP key can be specified for encrypting the root token") - return 1 - } - - // If running in 'auto' mode, run service discovery based on environment - // variables of Consul. - if auto { - - // Create configuration for Consul - consulConfig := consulapi.DefaultConfig() - - // Create a client to communicate with Consul - consulClient, err := consulapi.NewClient(consulConfig) - if err != nil { - c.Ui.Error(fmt.Sprintf("Failed to create Consul client:%v", err)) - return 1 - } - - // Fetch Vault's protocol scheme from the client - vaultclient, err := c.Client() - if err != nil { - c.Ui.Error(fmt.Sprintf("Failed to fetch Vault client: %v", err)) - return 1 - } - - if vaultclient.Address() == "" { - c.Ui.Error("Failed to fetch Vault client address") - return 1 - } - - clientURL, err := url.Parse(vaultclient.Address()) - if err != nil { - c.Ui.Error(fmt.Sprintf("Failed to parse Vault address: %v", err)) - return 1 - } - - if clientURL == nil { - c.Ui.Error("Failed to parse Vault client address") - return 1 - } - - var uninitializedVaults []string - var initializedVault string - - // Query the nodes belonging to the cluster - if services, _, err := consulClient.Catalog().Service(consulServiceName, "", &consulapi.QueryOptions{AllowStale: true}); err == nil { - Loop: - for _, service := range services { - vaultAddress := &url.URL{ - Scheme: clientURL.Scheme, - Host: fmt.Sprintf("%s:%d", service.ServiceAddress, service.ServicePort), - } - - // Set VAULT_ADDR to the discovered node - os.Setenv(api.EnvVaultAddress, vaultAddress.String()) - - // Create a client to communicate with the discovered node - client, err := c.Client() - if err != nil { - c.Ui.Error(fmt.Sprintf("Error initializing client: %v", err)) - return 1 - } - - // Check the initialization status of the discovered node - inited, err := client.Sys().InitStatus() - switch { - case err != nil: - c.Ui.Error(fmt.Sprintf("Error checking initialization status of discovered node: %+q. Err: %v", vaultAddress.String(), err)) - return 1 - case inited: - // One of the nodes in the cluster is initialized. Break out. - initializedVault = vaultAddress.String() - break Loop - default: - // Vault is uninitialized. - uninitializedVaults = append(uninitializedVaults, vaultAddress.String()) - } - } - } - - export := "export" - quote := "'" - if runtime.GOOS == "windows" { - export = "set" - quote = "" - } - - if initializedVault != "" { - vaultURL, err := url.Parse(initializedVault) - if err != nil { - c.Ui.Error(fmt.Sprintf("Failed to parse Vault address: %+q. Err: %v", initializedVault, err)) - } - c.Ui.Output(fmt.Sprintf("Discovered an initialized Vault node at %+q, using Consul service name %+q", vaultURL.String(), consulServiceName)) - c.Ui.Output("\nSet the following environment variable to operate on the discovered Vault:\n") - c.Ui.Output(fmt.Sprintf("\t%s VAULT_ADDR=%s%s%s", export, quote, vaultURL.String(), quote)) - return 0 - } - - switch len(uninitializedVaults) { - case 0: - c.Ui.Error(fmt.Sprintf("Failed to discover Vault nodes using Consul service name %+q", consulServiceName)) - return 1 - case 1: - // There was only one node found in the Vault cluster and it - // was uninitialized. - - vaultURL, err := url.Parse(uninitializedVaults[0]) - if err != nil { - c.Ui.Error(fmt.Sprintf("Failed to parse Vault address: %+q. Err: %v", uninitializedVaults[0], err)) - } - - // Set the VAULT_ADDR to the discovered node. This will ensure - // that the client created will operate on the discovered node. - os.Setenv(api.EnvVaultAddress, vaultURL.String()) - - // Let the client know that initialization is perfomed on the - // discovered node. - c.Ui.Output(fmt.Sprintf("Discovered Vault at %+q using Consul service name %+q\n", vaultURL.String(), consulServiceName)) - - // Attempt initializing it - ret := c.runInit(check, initRequest) - - // Regardless of success or failure, instruct client to update VAULT_ADDR - c.Ui.Output("\nSet the following environment variable to operate on the discovered Vault:\n") - c.Ui.Output(fmt.Sprintf("\t%s VAULT_ADDR=%s%s%s", export, quote, vaultURL.String(), quote)) - - return ret - default: - // If more than one Vault node were discovered, print out all of them, - // requiring the client to update VAULT_ADDR and to run init again. - c.Ui.Output(fmt.Sprintf("Discovered more than one uninitialized Vaults using Consul service name %+q\n", consulServiceName)) - c.Ui.Output("To initialize these Vaults, set any *one* of the following environment variables and run 'vault init':") - - // Print valid commands to make setting the variables easier - for _, vaultNode := range uninitializedVaults { - vaultURL, err := url.Parse(vaultNode) - if err != nil { - c.Ui.Error(fmt.Sprintf("Failed to parse Vault address: %+q. Err: %v", vaultNode, err)) - } - c.Ui.Output(fmt.Sprintf("\t%s VAULT_ADDR=%s%s%s", export, quote, vaultURL.String(), quote)) - - } - return 0 - } - } - - return c.runInit(check, initRequest) -} - -func (c *InitCommand) runInit(check bool, initRequest *api.InitRequest) int { - client, err := c.Client() - if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error initializing client: %s", err)) - return 1 - } - - if check { - return c.checkStatus(client) - } - - resp, err := client.Sys().Init(initRequest) - if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error initializing Vault: %s", err)) - return 1 - } - - for i, key := range resp.Keys { - if resp.KeysB64 != nil && len(resp.KeysB64) == len(resp.Keys) { - c.Ui.Output(fmt.Sprintf("Unseal Key %d: %s", i+1, resp.KeysB64[i])) - } else { - c.Ui.Output(fmt.Sprintf("Unseal Key %d: %s", i+1, key)) - } - } - for i, key := range resp.RecoveryKeys { - if resp.RecoveryKeysB64 != nil && len(resp.RecoveryKeysB64) == len(resp.RecoveryKeys) { - c.Ui.Output(fmt.Sprintf("Recovery Key %d: %s", i+1, resp.RecoveryKeysB64[i])) - } else { - c.Ui.Output(fmt.Sprintf("Recovery Key %d: %s", i+1, key)) - } - } - - c.Ui.Output(fmt.Sprintf("Initial Root Token: %s", resp.RootToken)) - - if initRequest.StoredShares < 1 { - c.Ui.Output(fmt.Sprintf( - "\n"+ - "Vault initialized with %d keys and a key threshold of %d. Please\n"+ - "securely distribute the above keys. When the vault is re-sealed,\n"+ - "restarted, or stopped, you must provide at least %d of these keys\n"+ - "to unseal it again.\n\n"+ - "Vault does not store the master key. Without at least %d keys,\n"+ - "your vault will remain permanently sealed.", - initRequest.SecretShares, - initRequest.SecretThreshold, - initRequest.SecretThreshold, - initRequest.SecretThreshold, - )) - } else { - c.Ui.Output( - "\n" + - "Vault initialized successfully.", - ) - } - if len(resp.RecoveryKeys) > 0 { - c.Ui.Output(fmt.Sprintf( - "\n"+ - "Recovery key initialized with %d keys and a key threshold of %d. Please\n"+ - "securely distribute the above keys.", - initRequest.RecoveryShares, - initRequest.RecoveryThreshold, - )) - } - - return 0 -} - -func (c *InitCommand) checkStatus(client *api.Client) int { - inited, err := client.Sys().InitStatus() - switch { - case err != nil: - c.Ui.Error(fmt.Sprintf( - "Error checking initialization status: %s", err)) - return 1 - case inited: - c.Ui.Output("Vault has been initialized") - return 0 - default: - c.Ui.Output("Vault is not initialized") - return 2 - } -} - -func (c *InitCommand) Synopsis() string { - return "Initialize a new Vault server" -} - -func (c *InitCommand) Help() string { - helpText := ` -Usage: vault init [options] - - Initialize a new Vault server. - - This command connects to a Vault server and initializes it for the - first time. This sets up the initial set of master keys and the - backend data store structure. - - This command can't be called on an already-initialized Vault server. - -General Options: -` + meta.GeneralOptionsUsage() + ` -Init Options: - - -check Don't actually initialize, just check if Vault is - already initialized. A return code of 0 means Vault - is initialized; a return code of 2 means Vault is not - initialized; a return code of 1 means an error was - encountered. - - -key-shares=5 The number of key shares to split the master key - into. - - -key-threshold=3 The number of key shares required to reconstruct - the master key. - - -stored-shares=0 The number of unseal keys to store. Only used with - Vault HSM. Must currently be equivalent to the - number of shares. - - -pgp-keys If provided, must be a comma-separated list of - files on disk containing binary- or base64-format - public PGP keys, or Keybase usernames specified as - "keybase:". The output unseal keys will - be encrypted and base64-encoded, in order, with the - given public keys. If you want to use them with the - 'vault unseal' command, you will need to base64- - decode and decrypt; this will be the plaintext - unseal key. When 'stored-shares' are not used, the - number of entries in this field must match 'key-shares'. - When 'stored-shares' are used, the number of entries - should match the difference between 'key-shares' - and 'stored-shares'. - - -root-token-pgp-key If provided, a file on disk with a binary- or - base64-format public PGP key, or a Keybase username - specified as "keybase:". The output root - token will be encrypted and base64-encoded, in - order, with the given public key. You will need - to base64-decode and decrypt the result. - - -recovery-shares=5 The number of key shares to split the recovery key - into. Only used with Vault HSM. - - -recovery-threshold=3 The number of key shares required to reconstruct - the recovery key. Only used with Vault HSM. - - -recovery-pgp-keys If provided, behaves like "pgp-keys" but for the - recovery key shares. Only used with Vault HSM. - - -auto If set, performs service discovery using Consul. - When all the nodes of a Vault cluster are - registered with Consul, setting this flag will - trigger service discovery using the service name - with which Vault nodes are registered. This option - works well when each Vault cluster is registered - under a unique service name. Note that, when Consul - is serving as Vault's HA backend, Vault nodes are - registered with Consul by default. The service name - can be changed using 'consul-service' flag. Ensure - that environment variables required to communicate - with Consul, like (CONSUL_HTTP_ADDR, - CONSUL_HTTP_TOKEN, CONSUL_HTTP_SSL, et al) are - properly set. When only one Vault node is - discovered, it will be initialized and when more - than one Vault node is discovered, they will be - output for easy selection. - - -consul-service Service name under which all the nodes of a Vault - cluster are registered with Consul. Note that, when - Vault uses Consul as its HA backend, by default, - Vault will register itself as a service with Consul - with the service name "vault". This name can be - modified in Vault's configuration file, using the - "service" option for the Consul backend. -` - return strings.TrimSpace(helpText) -} - -func (c *InitCommand) AutocompleteArgs() complete.Predictor { - return complete.PredictNothing -} - -func (c *InitCommand) AutocompleteFlags() complete.Flags { - return complete.Flags{ - "-check": complete.PredictNothing, - "-key-shares": complete.PredictNothing, - "-key-threshold": complete.PredictNothing, - "-pgp-keys": complete.PredictNothing, - "-root-token-pgp-key": complete.PredictNothing, - "-recovery-shares": complete.PredictNothing, - "-recovery-threshold": complete.PredictNothing, - "-recovery-pgp-keys": complete.PredictNothing, - "-auto": complete.PredictNothing, - "-consul-service": complete.PredictNothing, - } -} diff --git a/command/init_test.go b/command/init_test.go deleted file mode 100644 index e09ba80dc0..0000000000 --- a/command/init_test.go +++ /dev/null @@ -1,343 +0,0 @@ -package command - -import ( - "bytes" - "encoding/base64" - "os" - "reflect" - "regexp" - "strings" - "testing" - - "github.com/hashicorp/vault/helper/pgpkeys" - "github.com/hashicorp/vault/http" - "github.com/hashicorp/vault/meta" - "github.com/hashicorp/vault/vault" - "github.com/keybase/go-crypto/openpgp" - "github.com/keybase/go-crypto/openpgp/packet" - "github.com/mitchellh/cli" -) - -func TestInit(t *testing.T) { - ui := new(cli.MockUi) - c := &InitCommand{ - Meta: meta.Meta{ - Ui: ui, - }, - } - - core := vault.TestCore(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() - - init, err := core.Initialized() - if err != nil { - t.Fatalf("err: %s", err) - } - if init { - t.Fatal("should not be initialized") - } - - args := []string{"-address", addr} - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } - - init, err = core.Initialized() - if err != nil { - t.Fatalf("err: %s", err) - } - if !init { - t.Fatal("should be initialized") - } - - sealConf, err := core.SealAccess().BarrierConfig() - if err != nil { - t.Fatalf("err: %s", err) - } - expected := &vault.SealConfig{ - Type: "shamir", - SecretShares: 5, - SecretThreshold: 3, - } - if !reflect.DeepEqual(expected, sealConf) { - t.Fatalf("expected:\n%#v\ngot:\n%#v\n", expected, sealConf) - } -} - -func TestInit_Check(t *testing.T) { - ui := new(cli.MockUi) - c := &InitCommand{ - Meta: meta.Meta{ - Ui: ui, - }, - } - - core := vault.TestCore(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() - - // Should return 2, not initialized - args := []string{"-address", addr, "-check"} - if code := c.Run(args); code != 2 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } - - // Now initialize it - args = []string{"-address", addr} - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } - - // Should return 0, initialized - args = []string{"-address", addr, "-check"} - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } - - init, err := core.Initialized() - if err != nil { - t.Fatalf("err: %s", err) - } - if !init { - t.Fatal("should be initialized") - } -} - -func TestInit_custom(t *testing.T) { - ui := new(cli.MockUi) - c := &InitCommand{ - Meta: meta.Meta{ - Ui: ui, - }, - } - - core := vault.TestCore(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() - - init, err := core.Initialized() - if err != nil { - t.Fatalf("err: %s", err) - } - if init { - t.Fatal("should not be initialized") - } - - args := []string{ - "-address", addr, - "-key-shares", "7", - "-key-threshold", "3", - } - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } - - init, err = core.Initialized() - if err != nil { - t.Fatalf("err: %s", err) - } - if !init { - t.Fatal("should be initialized") - } - - sealConf, err := core.SealAccess().BarrierConfig() - if err != nil { - t.Fatalf("err: %s", err) - } - expected := &vault.SealConfig{ - Type: "shamir", - SecretShares: 7, - SecretThreshold: 3, - } - if !reflect.DeepEqual(expected, sealConf) { - t.Fatalf("expected:\n%#v\ngot:\n%#v\n", expected, sealConf) - } - - re, err := regexp.Compile("\\s+Initial Root Token:\\s+(.*)") - if err != nil { - t.Fatalf("Error compiling regex: %s", err) - } - matches := re.FindAllStringSubmatch(ui.OutputWriter.String(), -1) - if len(matches) != 1 { - t.Fatalf("Unexpected number of tokens found, got %d", len(matches)) - } - - rootToken := matches[0][1] - - client, err := c.Client() - if err != nil { - t.Fatalf("Error fetching client: %v", err) - } - - client.SetToken(rootToken) - - re, err = regexp.Compile("\\s*Unseal Key \\d+: (.*)") - if err != nil { - t.Fatalf("Error compiling regex: %s", err) - } - matches = re.FindAllStringSubmatch(ui.OutputWriter.String(), -1) - if len(matches) != 7 { - t.Fatalf("Unexpected number of keys returned, got %d, matches was \n\n%#v\n\n, input was \n\n%s\n\n", len(matches), matches, ui.OutputWriter.String()) - } - - var unsealed bool - for i := 0; i < 3; i++ { - decodedKey, err := base64.StdEncoding.DecodeString(strings.TrimSpace(matches[i][1])) - if err != nil { - t.Fatalf("err decoding key %v: %v", matches[i][1], err) - } - unsealed, err = core.Unseal(decodedKey) - if err != nil { - t.Fatalf("err during unseal: %v; key was %v", err, matches[i][1]) - } - } - if !unsealed { - t.Fatal("expected to be unsealed") - } - - tokenInfo, err := client.Auth().Token().LookupSelf() - if err != nil { - t.Fatalf("Error looking up root token info: %v", err) - } - - if tokenInfo.Data["policies"].([]interface{})[0].(string) != "root" { - t.Fatalf("expected root policy") - } -} - -func TestInit_PGP(t *testing.T) { - ui := new(cli.MockUi) - c := &InitCommand{ - Meta: meta.Meta{ - Ui: ui, - }, - } - - core := vault.TestCore(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() - - init, err := core.Initialized() - if err != nil { - t.Fatalf("err: %s", err) - } - if init { - t.Fatal("should not be initialized") - } - - tempDir, pubFiles, err := getPubKeyFiles(t) - if err != nil { - t.Fatal(err) - } - defer os.RemoveAll(tempDir) - - args := []string{ - "-address", addr, - "-key-shares", "2", - "-pgp-keys", pubFiles[0] + ",@" + pubFiles[1] + "," + pubFiles[2], - "-key-threshold", "2", - "-root-token-pgp-key", pubFiles[0], - } - - // This should fail, as key-shares does not match pgp-keys size - if code := c.Run(args); code == 0 { - t.Fatalf("bad (command should have failed): %d\n\n%s", code, ui.ErrorWriter.String()) - } - - args = []string{ - "-address", addr, - "-key-shares", "4", - "-pgp-keys", pubFiles[0] + ",@" + pubFiles[1] + "," + pubFiles[2] + "," + pubFiles[3], - "-key-threshold", "2", - "-root-token-pgp-key", pubFiles[0], - } - - ui.OutputWriter.Reset() - - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } - - init, err = core.Initialized() - if err != nil { - t.Fatalf("err: %s", err) - } - if !init { - t.Fatal("should be initialized") - } - - sealConf, err := core.SealAccess().BarrierConfig() - if err != nil { - t.Fatalf("err: %s", err) - } - - pgpKeys := []string{} - for _, pubFile := range pubFiles { - pub, err := pgpkeys.ReadPGPFile(pubFile) - if err != nil { - t.Fatalf("bad: %v", err) - } - pgpKeys = append(pgpKeys, pub) - } - - expected := &vault.SealConfig{ - Type: "shamir", - SecretShares: 4, - SecretThreshold: 2, - PGPKeys: pgpKeys, - } - if !reflect.DeepEqual(expected, sealConf) { - t.Fatalf("expected:\n%#v\ngot:\n%#v\n", expected, sealConf) - } - - re, err := regexp.Compile("\\s+Initial Root Token:\\s+(.*)") - if err != nil { - t.Fatalf("Error compiling regex: %s", err) - } - matches := re.FindAllStringSubmatch(ui.OutputWriter.String(), -1) - if len(matches) != 1 { - t.Fatalf("Unexpected number of tokens found, got %d", len(matches)) - } - - encRootToken := matches[0][1] - privKeyBytes, err := base64.StdEncoding.DecodeString(pgpkeys.TestPrivKey1) - if err != nil { - t.Fatalf("error decoding private key: %v", err) - } - ptBuf := bytes.NewBuffer(nil) - entity, err := openpgp.ReadEntity(packet.NewReader(bytes.NewBuffer(privKeyBytes))) - if err != nil { - t.Fatalf("Error parsing private key: %s", err) - } - var rootBytes []byte - rootBytes, err = base64.StdEncoding.DecodeString(encRootToken) - if err != nil { - t.Fatalf("Error decoding root token: %s", err) - } - entityList := &openpgp.EntityList{entity} - md, err := openpgp.ReadMessage(bytes.NewBuffer(rootBytes), entityList, nil, nil) - if err != nil { - t.Fatalf("Error decrypting root token: %s", err) - } - ptBuf.ReadFrom(md.UnverifiedBody) - rootToken := ptBuf.String() - - parseDecryptAndTestUnsealKeys(t, ui.OutputWriter.String(), rootToken, false, nil, nil, core) - - client, err := c.Client() - if err != nil { - t.Fatalf("Error fetching client: %v", err) - } - - client.SetToken(rootToken) - - tokenInfo, err := client.Auth().Token().LookupSelf() - if err != nil { - t.Fatalf("Error looking up root token info: %v", err) - } - - if tokenInfo.Data["policies"].([]interface{})[0].(string) != "root" { - t.Fatalf("expected root policy") - } -} diff --git a/command/key_status.go b/command/key_status.go deleted file mode 100644 index ff1b0860c6..0000000000 --- a/command/key_status.go +++ /dev/null @@ -1,55 +0,0 @@ -package command - -import ( - "fmt" - "strings" - - "github.com/hashicorp/vault/meta" -) - -// KeyStatusCommand is a Command that provides information about the key status -type KeyStatusCommand struct { - meta.Meta -} - -func (c *KeyStatusCommand) Run(args []string) int { - flags := c.Meta.FlagSet("key-status", meta.FlagSetDefault) - flags.Usage = func() { c.Ui.Error(c.Help()) } - if err := flags.Parse(args); err != nil { - return 1 - } - - client, err := c.Client() - if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error initializing client: %s", err)) - return 2 - } - - status, err := client.Sys().KeyStatus() - if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error reading audits: %s", err)) - return 2 - } - - c.Ui.Output(fmt.Sprintf("Key Term: %d", status.Term)) - c.Ui.Output(fmt.Sprintf("Installation Time: %v", status.InstallTime)) - return 0 -} - -func (c *KeyStatusCommand) Synopsis() string { - return "Provides information about the active encryption key" -} - -func (c *KeyStatusCommand) Help() string { - helpText := ` -Usage: vault key-status [options] - - Provides information about the active encryption key. Specifically, - the current key term and the key installation time. - -General Options: -` + meta.GeneralOptionsUsage() - return strings.TrimSpace(helpText) -} diff --git a/command/key_status_test.go b/command/key_status_test.go deleted file mode 100644 index 0adcefa7f4..0000000000 --- a/command/key_status_test.go +++ /dev/null @@ -1,31 +0,0 @@ -package command - -import ( - "testing" - - "github.com/hashicorp/vault/http" - "github.com/hashicorp/vault/meta" - "github.com/hashicorp/vault/vault" - "github.com/mitchellh/cli" -) - -func TestKeyStatus(t *testing.T) { - core, _, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() - - ui := new(cli.MockUi) - c := &KeyStatusCommand{ - Meta: meta.Meta{ - ClientToken: token, - Ui: ui, - }, - } - - args := []string{ - "-address", addr, - } - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } -} diff --git a/command/lease.go b/command/lease.go new file mode 100644 index 0000000000..76f6cc174c --- /dev/null +++ b/command/lease.go @@ -0,0 +1,40 @@ +package command + +import ( + "strings" + + "github.com/mitchellh/cli" +) + +var _ cli.Command = (*LeaseCommand)(nil) + +type LeaseCommand struct { + *BaseCommand +} + +func (c *LeaseCommand) Synopsis() string { + return "Interact with leases" +} + +func (c *LeaseCommand) Help() string { + helpText := ` +Usage: vault lease [options] [args] + + This command groups subcommands for interacting with leases. Users can revoke + or renew leases. + + Renew a lease: + + $ vault lease renew database/creds/readonly/2f6a614c... + + Revoke a lease: + + $ vault lease revoke database/creds/readonly/2f6a614c... +` + + return strings.TrimSpace(helpText) +} + +func (c *LeaseCommand) Run(args []string) int { + return cli.RunResultHelp +} diff --git a/command/lease_renew.go b/command/lease_renew.go new file mode 100644 index 0000000000..4dd2e1c573 --- /dev/null +++ b/command/lease_renew.go @@ -0,0 +1,127 @@ +package command + +import ( + "fmt" + "strings" + "time" + + "github.com/mitchellh/cli" + "github.com/posener/complete" +) + +var _ cli.Command = (*LeaseRenewCommand)(nil) +var _ cli.CommandAutocomplete = (*LeaseRenewCommand)(nil) + +type LeaseRenewCommand struct { + *BaseCommand + + flagIncrement time.Duration +} + +func (c *LeaseRenewCommand) Synopsis() string { + return "Renews the lease of a secret" +} + +func (c *LeaseRenewCommand) Help() string { + helpText := ` +Usage: vault lease renew [options] ID + + Renews the lease on a secret, extending the time that it can be used before + it is revoked by Vault. + + Every secret in Vault has a lease associated with it. If the owner of the + secret wants to use it longer than the lease, then it must be renewed. + Renewing the lease does not change the contents of the secret. The ID is the + full path lease ID. + + Renew a secret: + + $ vault lease renew database/creds/readonly/2f6a614c... + + Lease renewal will fail if the secret is not renewable, the secret has already + been revoked, or if the secret has already reached its maximum TTL. + + For a full list of examples, please see the documentation. + +` + c.Flags().Help() + + return strings.TrimSpace(helpText) +} + +func (c *LeaseRenewCommand) Flags() *FlagSets { + set := c.flagSet(FlagSetHTTP | FlagSetOutputFormat) + f := set.NewFlagSet("Command Options") + + f.DurationVar(&DurationVar{ + Name: "increment", + Target: &c.flagIncrement, + Default: 0, + EnvVar: "", + Completion: complete.PredictAnything, + Usage: "Request a specific increment in seconds. Vault is not required " + + "to honor this request.", + }) + + return set +} + +func (c *LeaseRenewCommand) AutocompleteArgs() complete.Predictor { + return complete.PredictAnything +} + +func (c *LeaseRenewCommand) AutocompleteFlags() complete.Flags { + return c.Flags().Completions() +} + +func (c *LeaseRenewCommand) Run(args []string) int { + f := c.Flags() + + if err := f.Parse(args); err != nil { + c.UI.Error(err.Error()) + return 1 + } + + leaseID := "" + increment := c.flagIncrement + + args = f.Args() + switch len(args) { + case 0: + c.UI.Error("Missing ID!") + return 1 + case 1: + leaseID = strings.TrimSpace(args[0]) + case 2: + // Deprecation + // TODO: remove in 0.9.0 + c.UI.Warn(wrapAtLength( + "WARNING! Specifying INCREMENT as a second argument is deprecated. " + + "Please use -increment instead. This will be removed in the next " + + "major release of Vault.")) + + leaseID = strings.TrimSpace(args[0]) + parsed, err := time.ParseDuration(appendDurationSuffix(args[1])) + if err != nil { + c.UI.Error(fmt.Sprintf("Invalid increment: %s", err)) + return 1 + } + increment = parsed + default: + c.UI.Error(fmt.Sprintf("Too many arguments (expected 1-2, got %d)", len(args))) + return 1 + } + + client, err := c.Client() + if err != nil { + c.UI.Error(err.Error()) + return 2 + } + + secret, err := client.Sys().Renew(leaseID, truncateToSeconds(increment)) + if err != nil { + c.UI.Error(fmt.Sprintf("Error renewing %s: %s", leaseID, err)) + return 2 + } + + return OutputSecret(c.UI, c.flagFormat, secret) +} diff --git a/command/lease_renew_test.go b/command/lease_renew_test.go new file mode 100644 index 0000000000..166e5156be --- /dev/null +++ b/command/lease_renew_test.go @@ -0,0 +1,170 @@ +package command + +import ( + "strings" + "testing" + + "github.com/hashicorp/vault/api" + "github.com/mitchellh/cli" +) + +func testLeaseRenewCommand(tb testing.TB) (*cli.MockUi, *LeaseRenewCommand) { + tb.Helper() + + ui := cli.NewMockUi() + return ui, &LeaseRenewCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + } +} + +// testLeaseRenewCommandMountAndLease mounts a leased secret backend and returns +// the leaseID of an item. +func testLeaseRenewCommandMountAndLease(tb testing.TB, client *api.Client) string { + if err := client.Sys().Mount("testing", &api.MountInput{ + Type: "generic-leased", + }); err != nil { + tb.Fatal(err) + } + + if _, err := client.Logical().Write("testing/foo", map[string]interface{}{ + "key": "value", + "lease": "5m", + }); err != nil { + tb.Fatal(err) + } + + // Read the secret back to get the leaseID + secret, err := client.Logical().Read("testing/foo") + if err != nil { + tb.Fatal(err) + } + if secret == nil || secret.LeaseID == "" { + tb.Fatalf("missing secret or lease: %#v", secret) + } + + return secret.LeaseID +} + +func TestLeaseRenewCommand_Run(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + args []string + out string + code int + }{ + { + "empty", + nil, + "Missing ID!", + 1, + }, + { + "increment", + []string{"-increment", "60s"}, + "foo", + 0, + }, + { + "increment_no_suffix", + []string{"-increment", "60"}, + "foo", + 0, + }, + { + "format", + []string{"-format", "json"}, + "{", + 0, + }, + { + "format_bad", + []string{"-format", "nope-not-real"}, + "Invalid output format", + 1, + }, + } + + t.Run("group", func(t *testing.T) { + t.Parallel() + + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + leaseID := testLeaseRenewCommandMountAndLease(t, client) + + ui, cmd := testLeaseRenewCommand(t) + cmd.client = client + + if tc.args != nil { + tc.args = append(tc.args, leaseID) + } + code := cmd.Run(tc.args) + if code != tc.code { + t.Errorf("expected %d to be %d", code, tc.code) + } + + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, tc.out) { + t.Errorf("expected %q to contain %q", combined, tc.out) + } + }) + } + }) + + t.Run("integration", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + leaseID := testLeaseRenewCommandMountAndLease(t, client) + + _, cmd := testLeaseRenewCommand(t) + cmd.client = client + + code := cmd.Run([]string{leaseID}) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + }) + + t.Run("communication_failure", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServerBad(t) + defer closer() + + ui, cmd := testLeaseRenewCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "foo/bar", + }) + if exp := 2; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Error renewing foo/bar: " + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + }) + + t.Run("no_tabs", func(t *testing.T) { + t.Parallel() + + _, cmd := testLeaseRenewCommand(t) + assertNoTabs(t, cmd) + }) +} diff --git a/command/lease_revoke.go b/command/lease_revoke.go new file mode 100644 index 0000000000..45f5dc3a76 --- /dev/null +++ b/command/lease_revoke.go @@ -0,0 +1,142 @@ +package command + +import ( + "fmt" + "strings" + + "github.com/mitchellh/cli" + "github.com/posener/complete" +) + +var _ cli.Command = (*LeaseRevokeCommand)(nil) +var _ cli.CommandAutocomplete = (*LeaseRevokeCommand)(nil) + +type LeaseRevokeCommand struct { + *BaseCommand + + flagForce bool + flagPrefix bool +} + +func (c *LeaseRevokeCommand) Synopsis() string { + return "Revokes leases and secrets" +} + +func (c *LeaseRevokeCommand) Help() string { + helpText := ` +Usage: vault lease revoke [options] ID + + Revokes secrets by their lease ID. This command can revoke a single secret + or multiple secrets based on a path-matched prefix. + + Revoke a single lease: + + $ vault lease revoke database/creds/readonly/2f6a614c... + + Revoke all leases for a role: + + $ vault lease revoke -prefix aws/creds/deploy + + Force delete leases from Vault even if secret engine revocation fails: + + $ vault lease revoke -force -prefix consul/creds + + For a full list of examples and paths, please see the documentation that + corresponds to the secret engine in use. + +` + c.Flags().Help() + + return strings.TrimSpace(helpText) +} + +func (c *LeaseRevokeCommand) Flags() *FlagSets { + set := c.flagSet(FlagSetHTTP) + f := set.NewFlagSet("Command Options") + + f.BoolVar(&BoolVar{ + Name: "force", + Aliases: []string{"f"}, + Target: &c.flagForce, + Default: false, + Usage: "Delete the lease from Vault even if the secret engine revocation " + + "fails. This is meant for recovery situations where the secret " + + "in the target secret engine was manually removed. If this flag is " + + "specified, -prefix is also required.", + }) + + f.BoolVar(&BoolVar{ + Name: "prefix", + Target: &c.flagPrefix, + Default: false, + Usage: "Treat the ID as a prefix instead of an exact lease ID. This can " + + "revoke multiple leases simultaneously.", + }) + + return set +} + +func (c *LeaseRevokeCommand) AutocompleteArgs() complete.Predictor { + return c.PredictVaultFiles() +} + +func (c *LeaseRevokeCommand) AutocompleteFlags() complete.Flags { + return c.Flags().Completions() +} + +func (c *LeaseRevokeCommand) Run(args []string) int { + f := c.Flags() + + if err := f.Parse(args); err != nil { + c.UI.Error(err.Error()) + return 1 + } + + args = f.Args() + switch { + case len(args) < 1: + c.UI.Error(fmt.Sprintf("Not enough arguments (expected 1, got %d)", len(args))) + return 1 + case len(args) > 1: + c.UI.Error(fmt.Sprintf("Too many arguments (expected 1, got %d)", len(args))) + return 1 + } + + if c.flagForce && !c.flagPrefix { + c.UI.Error("Specifying -force requires also specifying -prefix") + return 1 + } + + client, err := c.Client() + if err != nil { + c.UI.Error(err.Error()) + return 2 + } + + leaseID := strings.TrimSpace(args[0]) + + switch { + case c.flagForce && c.flagPrefix: + c.UI.Warn(wrapAtLength("Warning! Force-removing leases can cause Vault " + + "to become out of sync with secret engines!")) + if err := client.Sys().RevokeForce(leaseID); err != nil { + c.UI.Error(fmt.Sprintf("Error force revoking leases with prefix %s: %s", leaseID, err)) + return 2 + } + c.UI.Output(fmt.Sprintf("Success! Force revoked any leases with prefix: %s", leaseID)) + return 0 + case c.flagPrefix: + if err := client.Sys().RevokePrefix(leaseID); err != nil { + c.UI.Error(fmt.Sprintf("Error revoking leases with prefix %s: %s", leaseID, err)) + return 2 + } + c.UI.Output(fmt.Sprintf("Success! Revoked any leases with prefix: %s", leaseID)) + return 0 + default: + if err := client.Sys().Revoke(leaseID); err != nil { + c.UI.Error(fmt.Sprintf("Error revoking lease %s: %s", leaseID, err)) + return 2 + } + c.UI.Output(fmt.Sprintf("Success! Revoked lease: %s", leaseID)) + return 0 + } +} diff --git a/command/lease_revoke_test.go b/command/lease_revoke_test.go new file mode 100644 index 0000000000..97904a4d7f --- /dev/null +++ b/command/lease_revoke_test.go @@ -0,0 +1,134 @@ +package command + +import ( + "strings" + "testing" + + "github.com/hashicorp/vault/api" + "github.com/mitchellh/cli" +) + +func testLeaseRevokeCommand(tb testing.TB) (*cli.MockUi, *LeaseRevokeCommand) { + tb.Helper() + + ui := cli.NewMockUi() + return ui, &LeaseRevokeCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + } +} + +func TestLeaseRevokeCommand_Run(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + args []string + out string + code int + }{ + { + "force_without_prefix", + []string{"-force"}, + "requires also specifying -prefix", + 1, + }, + { + "single", + nil, + "Success", + 0, + }, + { + "force_prefix", + []string{"-force", "-prefix"}, + "Success", + 0, + }, + { + "prefix", + []string{"-prefix"}, + "Success", + 0, + }, + } + + t.Run("validations", func(t *testing.T) { + t.Parallel() + + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + if err := client.Sys().Mount("secret-leased", &api.MountInput{ + Type: "generic-leased", + }); err != nil { + t.Fatal(err) + } + + path := "secret-leased/revoke/" + tc.name + data := map[string]interface{}{ + "key": "value", + "lease": "1m", + } + if _, err := client.Logical().Write(path, data); err != nil { + t.Fatal(err) + } + secret, err := client.Logical().Read(path) + if err != nil { + t.Fatal(err) + } + + ui, cmd := testLeaseRevokeCommand(t) + cmd.client = client + + tc.args = append(tc.args, secret.LeaseID) + code := cmd.Run(tc.args) + if code != tc.code { + t.Errorf("expected %d to be %d", code, tc.code) + } + + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, tc.out) { + t.Errorf("expected %q to contain %q", combined, tc.out) + } + }) + } + }) + + t.Run("communication_failure", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServerBad(t) + defer closer() + + ui, cmd := testLeaseRevokeCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "foo/bar", + }) + if exp := 2; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Error revoking lease foo/bar: " + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + }) + + t.Run("no_tabs", func(t *testing.T) { + t.Parallel() + + _, cmd := testLeaseRevokeCommand(t) + assertNoTabs(t, cmd) + }) +} diff --git a/command/list.go b/command/list.go index 71bf388c90..ecfa0acb56 100644 --- a/command/list.go +++ b/command/list.go @@ -1,97 +1,101 @@ package command import ( - "flag" "fmt" "strings" - "github.com/hashicorp/vault/api" - "github.com/hashicorp/vault/meta" + "github.com/mitchellh/cli" + "github.com/posener/complete" ) -// ListCommand is a Command that lists data from the Vault. +var _ cli.Command = (*ListCommand)(nil) +var _ cli.CommandAutocomplete = (*ListCommand)(nil) + type ListCommand struct { - meta.Meta + *BaseCommand +} + +func (c *ListCommand) Synopsis() string { + return "List data or secrets" +} + +func (c *ListCommand) Help() string { + helpText := ` + +Usage: vault list [options] PATH + + Lists data from Vault at the given path. This can be used to list keys in a, + given secret engine. + + List values under the "my-app" folder of the generic secret engine: + + $ vault list secret/my-app/ + + For a full list of examples and paths, please see the documentation that + corresponds to the secret engine in use. Not all engines support listing. + +` + c.Flags().Help() + + return strings.TrimSpace(helpText) +} + +func (c *ListCommand) Flags() *FlagSets { + return c.flagSet(FlagSetHTTP | FlagSetOutputFormat) +} + +func (c *ListCommand) AutocompleteArgs() complete.Predictor { + return c.PredictVaultFolders() +} + +func (c *ListCommand) AutocompleteFlags() complete.Flags { + return c.Flags().Completions() } func (c *ListCommand) Run(args []string) int { - var format string - var err error - var secret *api.Secret - var flags *flag.FlagSet - flags = c.Meta.FlagSet("list", meta.FlagSetDefault) - flags.StringVar(&format, "format", "table", "") - flags.Usage = func() { c.Ui.Error(c.Help()) } - if err := flags.Parse(args); err != nil { + f := c.Flags() + + if err := f.Parse(args); err != nil { + c.UI.Error(err.Error()) return 1 } - args = flags.Args() - if len(args) != 1 || len(args[0]) == 0 { - c.Ui.Error("list expects one argument") - flags.Usage() + args = f.Args() + switch { + case len(args) < 1: + c.UI.Error(fmt.Sprintf("Not enough arguments (expected 1, got %d)", len(args))) + return 1 + case len(args) > 1: + c.UI.Error(fmt.Sprintf("Too many arguments (expected 1, got %d)", len(args))) return 1 - } - - path := args[0] - if path[0] == '/' { - path = path[1:] - } - - if !strings.HasSuffix(path, "/") { - path = path + "/" } client, err := c.Client() if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error initializing client: %s", err)) + c.UI.Error(err.Error()) return 2 } - secret, err = client.Logical().List(path) + path := ensureTrailingSlash(sanitizePath(args[0])) + + secret, err := client.Logical().List(path) if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error reading %s: %s", path, err)) - return 1 + c.UI.Error(fmt.Sprintf("Error listing %s: %s", path, err)) + return 2 } - if secret == nil { - c.Ui.Error(fmt.Sprintf( - "No value found at %s", path)) - return 1 + if secret == nil || secret.Data == nil { + c.UI.Error(fmt.Sprintf("No value found at %s", path)) + return 2 } + + // If the secret is wrapped, return the wrapped response. if secret.WrapInfo != nil && secret.WrapInfo.TTL != 0 { - return OutputSecret(c.Ui, format, secret) + return OutputSecret(c.UI, c.flagFormat, secret) } - if secret.Data["keys"] == nil { - c.Ui.Error("No entries found") - return 0 + if _, ok := extractListData(secret); !ok { + c.UI.Error(fmt.Sprintf("No entries found at %s", path)) + return 2 } - return OutputList(c.Ui, format, secret) -} - -func (c *ListCommand) Synopsis() string { - return "List data or secrets in Vault" -} - -func (c *ListCommand) Help() string { - helpText := - ` -Usage: vault list [options] path - - List data from Vault. - - Retrieve a listing of available data. The data returned, if any, is backend- - and endpoint-specific. - -General Options: -` + meta.GeneralOptionsUsage() + ` -Read Options: - - -format=table The format for output. By default it is a whitespace- - delimited table. This can also be json or yaml. -` - return strings.TrimSpace(helpText) + return OutputList(c.UI, c.flagFormat, secret) } diff --git a/command/list_test.go b/command/list_test.go index 1f75c0b25b..e5a77c532b 100644 --- a/command/list_test.go +++ b/command/list_test.go @@ -1,71 +1,150 @@ package command import ( - "reflect" + "strings" "testing" - "github.com/hashicorp/vault/http" - "github.com/hashicorp/vault/meta" - "github.com/hashicorp/vault/vault" "github.com/mitchellh/cli" ) -func TestList(t *testing.T) { - core, _, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() +func testListCommand(tb testing.TB) (*cli.MockUi, *ListCommand) { + tb.Helper() - ui := new(cli.MockUi) - c := &ReadCommand{ - Meta: meta.Meta{ - ClientToken: token, - Ui: ui, + ui := cli.NewMockUi() + return ui, &ListCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + } +} + +func TestListCommand_Run(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + args []string + out string + code int + }{ + { + "not_enough_args", + []string{}, + "Not enough arguments", + 1, + }, + { + "too_many_args", + []string{"foo", "bar"}, + "Too many arguments", + 1, + }, + { + "not_found", + []string{"nope/not/once/never"}, + "", + 2, + }, + { + "default", + []string{"secret/list"}, + "bar\nbaz\nfoo", + 0, + }, + { + "default_slash", + []string{"secret/list/"}, + "bar\nbaz\nfoo", + 0, + }, + { + "format", + []string{ + "-format", "json", + "secret/list/", + }, + "[", + 0, + }, + { + "format_bad", + []string{ + "-format", "nope-not-real", + "secret/list/", + }, + "Invalid output format", + 1, }, } - args := []string{ - "-address", addr, - "-format", "json", - "secret", - } + t.Run("validations", func(t *testing.T) { + t.Parallel() - // Run once so the client is setup, ignore errors - c.Run(args) + for _, tc := range cases { + tc := tc - // Get the client so we can write data - client, err := c.Client() - if err != nil { - t.Fatalf("err: %s", err) - } + t.Run(tc.name, func(t *testing.T) { + t.Parallel() - data := map[string]interface{}{"value": "bar"} - if _, err := client.Logical().Write("secret/foo", data); err != nil { - t.Fatalf("err: %s", err) - } + client, closer := testVaultServer(t) + defer closer() - data = map[string]interface{}{"value": "bar"} - if _, err := client.Logical().Write("secret/foo/bar", data); err != nil { - t.Fatalf("err: %s", err) - } + keys := []string{ + "secret/list/foo", + "secret/list/bar", + "secret/list/baz", + } + for _, k := range keys { + if _, err := client.Logical().Write(k, map[string]interface{}{ + "foo": "bar", + }); err != nil { + t.Fatal(err) + } + } - secret, err := client.Logical().List("secret/") - if err != nil { - t.Fatalf("err: %s", err) - } + ui, cmd := testListCommand(t) + cmd.client = client - if secret == nil { - t.Fatalf("err: No value found at secret/") - } + code := cmd.Run(tc.args) + if code != tc.code { + t.Errorf("expected %d to be %d", code, tc.code) + } - if secret.Data == nil { - t.Fatalf("err: Data not found") - } + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, tc.out) { + t.Errorf("expected %q to contain %q", combined, tc.out) + } + }) + } + }) - exp := map[string]interface{}{ - "keys": []interface{}{"foo", "foo/"}, - } + t.Run("communication_failure", func(t *testing.T) { + t.Parallel() - if !reflect.DeepEqual(secret.Data, exp) { - t.Fatalf("err: expected %#v, got %#v", exp, secret.Data) - } + client, closer := testVaultServerBad(t) + defer closer() + + ui, cmd := testListCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "secret/list", + }) + if exp := 2; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Error listing secret/list/: " + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + }) + + t.Run("no_tabs", func(t *testing.T) { + t.Parallel() + + _, cmd := testListCommand(t) + assertNoTabs(t, cmd) + }) } diff --git a/command/login.go b/command/login.go new file mode 100644 index 0000000000..564b893bb8 --- /dev/null +++ b/command/login.go @@ -0,0 +1,379 @@ +package command + +import ( + "fmt" + "io" + "os" + "strings" + + "github.com/hashicorp/vault/api" + "github.com/posener/complete" +) + +// LoginHandler is the interface that any auth handlers must implement to enable +// auth via the CLI. +type LoginHandler interface { + Auth(*api.Client, map[string]string) (*api.Secret, error) + Help() string +} + +type LoginCommand struct { + *BaseCommand + + Handlers map[string]LoginHandler + + flagMethod string + flagPath string + flagNoStore bool + flagTokenOnly bool + + // Deprecations + // TODO: remove in 0.9.0 + flagNoVerify bool + + testStdin io.Reader // for tests +} + +func (c *LoginCommand) Synopsis() string { + return "Authenticate locally" +} + +func (c *LoginCommand) Help() string { + helpText := ` +Usage: vault login [options] [AUTH K=V...] + + Authenticates users or machines to Vault using the provided arguments. A + successful authentication results in a Vault token - conceptually similar to + a session token on a website. By default, this token is cached on the local + machine for future requests. + + The default auth method is "token". If not supplied via the CLI, + Vault will prompt for input. If the argument is "-", the values are read + from stdin. + + The -method flag allows using other auth methods, such as userpass, github, or + cert. For these, additional "K=V" pairs may be required. For example, to + authenticate to the userpass auth method: + + $ vault login -method=userpass username=my-username + + For more information about the list of configuration parameters available for + a given auth method, use the "vault auth help TYPE". You can also use "vault + auth list" to see the list of enabled auth methods. + + If an auth method is enabled at a non-standard path, the -method flag still + refers to the canonical type, but the -path flag refers to the enabled path. + If a github auth method was enabled at "github-ent", authenticate like this: + + $ vault login -method=github -path=github-prod + + If the authentication is requested with response wrapping (via -wrap-ttl), + the returned token is automatically unwrapped unless: + + - The -token-only flag is used, in which case this command will output + the wrapping token. + + - The -no-store flag is used, in which case this command will output the + details of the wrapping token. + +` + c.Flags().Help() + + return strings.TrimSpace(helpText) +} + +func (c *LoginCommand) Flags() *FlagSets { + set := c.flagSet(FlagSetHTTP | FlagSetOutputField | FlagSetOutputFormat) + + f := set.NewFlagSet("Command Options") + + f.StringVar(&StringVar{ + Name: "method", + Target: &c.flagMethod, + Default: "token", + Completion: c.PredictVaultAvailableAuths(), + Usage: "Type of authentication to use such as \"userpass\" or " + + "\"ldap\". Note this corresponds to the TYPE, not the enabled path. " + + "Use -path to specify the path where the authentication is enabled.", + }) + + f.StringVar(&StringVar{ + Name: "path", + Target: &c.flagPath, + Default: "", + Completion: c.PredictVaultAuths(), + Usage: "Remote path in Vault where the auth method is enabled. " + + "This defaults to the TYPE of method (e.g. userpass -> userpass/).", + }) + + f.BoolVar(&BoolVar{ + Name: "no-store", + Target: &c.flagNoStore, + Default: false, + Usage: "Do not persist the token to the token helper (usually the " + + "local filesystem) after authentication for use in future requests. " + + "The token will only be displayed in the command output.", + }) + + f.BoolVar(&BoolVar{ + Name: "token-only", + Target: &c.flagTokenOnly, + Default: false, + Usage: "Output only the token with no verification. This flag is a " + + "shortcut for \"-field=token -no-store\". Setting those flags to other " + + "values will have no affect.", + }) + + // Deprecations + // TODO: remove in 0.9.0 + f.BoolVar(&BoolVar{ + Name: "no-verify", + Target: &c.flagNoVerify, + Hidden: true, + Default: false, + Usage: "", + }) + + return set +} + +func (c *LoginCommand) AutocompleteArgs() complete.Predictor { + return nil +} + +func (c *LoginCommand) AutocompleteFlags() complete.Flags { + return c.Flags().Completions() +} + +func (c *LoginCommand) Run(args []string) int { + f := c.Flags() + + if err := f.Parse(args); err != nil { + c.UI.Error(err.Error()) + return 1 + } + + args = f.Args() + + // Deprecations + // TODO: remove in 0.10.0 + switch { + case c.flagNoVerify: + c.UI.Warn(wrapAtLength( + "WARNING! The -no-verify flag is deprecated. In the past, Vault " + + "performed a lookup on a token after authentication. This is no " + + "longer the case for all auth methods except \"token\". Vault will " + + "still attempt to perform a lookup when given a token directly " + + "because that is how it gets the list of policies, ttl, and other " + + "metadata. To disable this lookup, specify \"lookup=false\" as a " + + "configuration option to the token auth method, like this:")) + c.UI.Warn("") + c.UI.Warn(" $ vault auth token=ABCD lookup=false") + c.UI.Warn("") + c.UI.Warn("Or omit the token and Vault will prompt for it:") + c.UI.Warn("") + c.UI.Warn(" $ vault auth lookup=false") + c.UI.Warn(" Token (will be hidden): ...") + c.UI.Warn("") + c.UI.Warn(wrapAtLength( + "If you are not using token authentication, you can safely omit this " + + "flag. Vault will not perform a lookup after authentication.")) + c.UI.Warn("") + + // There's no point in passing this to other auth handlers... + if c.flagMethod == "token" { + args = append(args, "lookup=false") + } + } + + // Set the right flags if the user requested token-only - this overrides + // any previously configured values, as documented. + if c.flagTokenOnly { + c.flagNoStore = true + c.flagField = "token" + } + + // Get the auth method + authMethod := sanitizePath(c.flagMethod) + if authMethod == "" { + authMethod = "token" + } + + // If no path is specified, we default the path to the method type + // or use the plugin name if it's a plugin + authPath := c.flagPath + if authPath == "" { + authPath = ensureTrailingSlash(authMethod) + } + + // Get the handler function + authHandler, ok := c.Handlers[authMethod] + if !ok { + c.UI.Error(wrapAtLength(fmt.Sprintf( + "Unknown auth method: %s. Use \"vault auth list\" to see the "+ + "complete list of auth methods. Additionally, some "+ + "auth methods are only available via the HTTP API.", + authMethod))) + return 1 + } + + // Pull our fake stdin if needed + stdin := (io.Reader)(os.Stdin) + if c.testStdin != nil { + stdin = c.testStdin + } + + // If the user provided a token, pass it along to the auth provier. + if authMethod == "token" && len(args) > 0 && !strings.Contains(args[0], "=") { + args = append([]string{"token=" + args[0]}, args[1:]...) + } + + config, err := parseArgsDataString(stdin, args) + if err != nil { + c.UI.Error(fmt.Sprintf("Error parsing configuration: %s", err)) + return 1 + } + + // If the user did not specify a mount path, use the provided mount path. + if config["mount"] == "" && authPath != "" { + config["mount"] = authPath + } + + // Create the client + client, err := c.Client() + if err != nil { + c.UI.Error(err.Error()) + return 2 + } + + // Authenticate delegation to the auth handler + secret, err := authHandler.Auth(client, config) + if err != nil { + c.UI.Error(fmt.Sprintf("Error authenticating: %s", err)) + return 2 + } + + // Unset any previous token wrapping functionality. If the original request + // was for a wrapped token, we don't want future requests to be wrapped. + client.SetWrappingLookupFunc(func(string, string) string { return "" }) + + // Recursively extract the token, handling wrapping + unwrap := !c.flagTokenOnly && !c.flagNoStore + secret, isWrapped, err := c.extractToken(client, secret, unwrap) + if err != nil { + c.UI.Error(fmt.Sprintf("Error extracting token: %s", err)) + return 2 + } + if secret == nil { + c.UI.Error("Vault returned an empty secret") + return 2 + } + + // Handle special cases if the token was wrapped + if isWrapped { + if c.flagTokenOnly { + return PrintRawField(c.UI, secret, "wrapping_token") + } + if c.flagNoStore { + return OutputSecret(c.UI, c.flagFormat, secret) + } + } + + // If we got this far, verify we have authentication data before continuing + if secret.Auth == nil { + c.UI.Error(wrapAtLength( + "Vault returned a secret, but the secret has no authentication " + + "information attached. This should never happen and is likely a " + + "bug.")) + return 2 + } + + // Pull the token itself out, since we don't need the rest of the auth + // information anymore/. + token := secret.Auth.ClientToken + + if !c.flagNoStore { + // Grab the token helper so we can store + tokenHelper, err := c.TokenHelper() + if err != nil { + c.UI.Error(wrapAtLength(fmt.Sprintf( + "Error initializing token helper. Please verify that the token "+ + "helper is available and properly configured for your system. The "+ + "error was: %s", err))) + return 1 + } + + // Store the token in the local client + if err := tokenHelper.Store(token); err != nil { + c.UI.Error(fmt.Sprintf("Error storing token: %s", err)) + c.UI.Error(wrapAtLength( + "Authentication was successful, but the token was not persisted. The "+ + "resulting token is shown below for your records.") + "\n") + OutputSecret(c.UI, c.flagFormat, secret) + return 2 + } + + // Warn if the VAULT_TOKEN environment variable is set, as that will take + // precedence. We output as a warning, so piping should still work since it + // will be on a different stream. + if os.Getenv("VAULT_TOKEN") != "" { + c.UI.Warn(wrapAtLength("WARNING! The VAULT_TOKEN environment variable "+ + "is set! This takes precedence over the value set by this command. To "+ + "use the value set by this command, unset the VAULT_TOKEN environment "+ + "variable or set it to the token displayed below.") + "\n") + } + } else { + c.UI.Warn(wrapAtLength( + "The token was not stored in token helper. Set the VAULT_TOKEN "+ + "environment variable or pass the token below with each request to "+ + "Vault.") + "\n") + } + + // If the user requested a particular field, print that out now since we + // are likely piping to another process. + if c.flagField != "" { + return PrintRawField(c.UI, secret, c.flagField) + } + + // Print some yay! text, but only in table mode. + if c.flagFormat == "table" { + c.UI.Output(wrapAtLength( + "Success! You are now authenticated. The token information displayed "+ + "below is already stored in the token helper. You do NOT need to run "+ + "\"vault login\" again. Future Vault requests will automatically use "+ + "this token.") + "\n") + } + + return OutputSecret(c.UI, c.flagFormat, secret) +} + +// extractToken extracts the token from the given secret, automatically +// unwrapping responses and handling error conditions if unwrap is true. The +// result also returns whether it was a wrapped resonse that was not unwrapped. +func (c *LoginCommand) extractToken(client *api.Client, secret *api.Secret, unwrap bool) (*api.Secret, bool, error) { + switch { + case secret == nil: + return nil, false, fmt.Errorf("empty response from auth helper") + + case secret.Auth != nil: + return secret, false, nil + + case secret.WrapInfo != nil: + if secret.WrapInfo.WrappedAccessor == "" { + return nil, false, fmt.Errorf("wrapped response does not contain a token") + } + + if !unwrap { + return secret, true, nil + } + + client.SetToken(secret.WrapInfo.Token) + secret, err := client.Logical().Unwrap("") + if err != nil { + return nil, false, err + } + return c.extractToken(client, secret, unwrap) + + default: + return nil, false, fmt.Errorf("no auth or wrapping info in response") + } +} diff --git a/command/login_test.go b/command/login_test.go new file mode 100644 index 0000000000..3776e2873f --- /dev/null +++ b/command/login_test.go @@ -0,0 +1,496 @@ +package command + +import ( + "strings" + "testing" + + "github.com/mitchellh/cli" + + "github.com/hashicorp/vault/api" + credToken "github.com/hashicorp/vault/builtin/credential/token" + credUserpass "github.com/hashicorp/vault/builtin/credential/userpass" + "github.com/hashicorp/vault/command/token" +) + +func testLoginCommand(tb testing.TB) (*cli.MockUi, *LoginCommand) { + tb.Helper() + + ui := cli.NewMockUi() + return ui, &LoginCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + + // Override to our own token helper + tokenHelper: token.NewTestingTokenHelper(), + }, + Handlers: map[string]LoginHandler{ + "token": &credToken.CLIHandler{}, + "userpass": &credUserpass.CLIHandler{}, + }, + } +} + +func TestLoginCommand_Run(t *testing.T) { + t.Parallel() + + t.Run("custom_path", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + if err := client.Sys().EnableAuth("my-auth", "userpass", ""); err != nil { + t.Fatal(err) + } + if _, err := client.Logical().Write("auth/my-auth/users/test", map[string]interface{}{ + "password": "test", + "policies": "default", + }); err != nil { + t.Fatal(err) + } + + ui, cmd := testLoginCommand(t) + cmd.client = client + + tokenHelper, err := cmd.TokenHelper() + if err != nil { + t.Fatal(err) + } + + code := cmd.Run([]string{ + "-method", "userpass", + "-path", "my-auth", + "username=test", + "password=test", + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Success! You are now authenticated." + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to be %q", combined, expected) + } + + storedToken, err := tokenHelper.Get() + if err != nil { + t.Fatal(err) + } + + if l, exp := len(storedToken), 36; l != exp { + t.Errorf("expected token to be %d characters, was %d: %q", exp, l, storedToken) + } + }) + + t.Run("no_store", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + secret, err := client.Auth().Token().Create(&api.TokenCreateRequest{ + Policies: []string{"default"}, + TTL: "30m", + }) + if err != nil { + t.Fatal(err) + } + token := secret.Auth.ClientToken + + _, cmd := testLoginCommand(t) + cmd.client = client + + tokenHelper, err := cmd.TokenHelper() + if err != nil { + t.Fatal(err) + } + + // Ensure we have no token to start + if storedToken, err := tokenHelper.Get(); err != nil || storedToken != "" { + t.Errorf("expected token helper to be empty: %s: %q", err, storedToken) + } + + code := cmd.Run([]string{ + "-no-store", + token, + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + storedToken, err := tokenHelper.Get() + if err != nil { + t.Fatal(err) + } + + if exp := ""; storedToken != exp { + t.Errorf("expected %q to be %q", storedToken, exp) + } + }) + + t.Run("stores", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + secret, err := client.Auth().Token().Create(&api.TokenCreateRequest{ + Policies: []string{"default"}, + TTL: "30m", + }) + if err != nil { + t.Fatal(err) + } + token := secret.Auth.ClientToken + + _, cmd := testLoginCommand(t) + cmd.client = client + + tokenHelper, err := cmd.TokenHelper() + if err != nil { + t.Fatal(err) + } + + code := cmd.Run([]string{ + token, + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + storedToken, err := tokenHelper.Get() + if err != nil { + t.Fatal(err) + } + + if storedToken != token { + t.Errorf("expected %q to be %q", storedToken, token) + } + }) + + t.Run("token_only", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + if err := client.Sys().EnableAuth("userpass", "userpass", ""); err != nil { + t.Fatal(err) + } + if _, err := client.Logical().Write("auth/userpass/users/test", map[string]interface{}{ + "password": "test", + "policies": "default", + }); err != nil { + t.Fatal(err) + } + + ui, cmd := testLoginCommand(t) + cmd.client = client + + tokenHelper, err := cmd.TokenHelper() + if err != nil { + t.Fatal(err) + } + + code := cmd.Run([]string{ + "-token-only", + "-method", "userpass", + "username=test", + "password=test", + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + // Verify only the token was printed + token := ui.OutputWriter.String() + if l, exp := len(token), 36; l != exp { + t.Errorf("expected token to be %d characters, was %d: %q", exp, l, token) + } + + // Verify the token was not stored + if storedToken, err := tokenHelper.Get(); err != nil || storedToken != "" { + t.Fatalf("expted token to not be stored: %s: %q", err, storedToken) + } + }) + + t.Run("failure_no_store", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + ui, cmd := testLoginCommand(t) + cmd.client = client + + tokenHelper, err := cmd.TokenHelper() + if err != nil { + t.Fatal(err) + } + + code := cmd.Run([]string{ + "not-a-real-token", + }) + if exp := 2; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Error authenticating: " + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + + if storedToken, err := tokenHelper.Get(); err != nil || storedToken != "" { + t.Fatalf("expected token to not be stored: %s: %q", err, storedToken) + } + }) + + t.Run("wrap_auto_unwrap", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + if err := client.Sys().EnableAuth("userpass", "userpass", ""); err != nil { + t.Fatal(err) + } + if _, err := client.Logical().Write("auth/userpass/users/test", map[string]interface{}{ + "password": "test", + "policies": "default", + }); err != nil { + t.Fatal(err) + } + + _, cmd := testLoginCommand(t) + cmd.client = client + + // Set the wrapping ttl to 5s. We can't set this via the flag because we + // override the client object before that particular flag is parsed. + client.SetWrappingLookupFunc(func(string, string) string { return "5m" }) + + code := cmd.Run([]string{ + "-method", "userpass", + "username=test", + "password=test", + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + // Unset the wrapping + client.SetWrappingLookupFunc(func(string, string) string { return "" }) + + tokenHelper, err := cmd.TokenHelper() + if err != nil { + t.Fatal(err) + } + token, err := tokenHelper.Get() + if err != nil || token == "" { + t.Fatalf("expected token from helper: %s: %q", err, token) + } + client.SetToken(token) + + // Ensure the resulting token is unwrapped + secret, err := client.Auth().Token().LookupSelf() + if err != nil { + t.Error(err) + } + if secret == nil { + t.Fatal("secret was nil") + } + + if secret.WrapInfo != nil { + t.Errorf("expected to be unwrapped: %#v", secret) + } + }) + + t.Run("wrap_token_only", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + if err := client.Sys().EnableAuth("userpass", "userpass", ""); err != nil { + t.Fatal(err) + } + if _, err := client.Logical().Write("auth/userpass/users/test", map[string]interface{}{ + "password": "test", + "policies": "default", + }); err != nil { + t.Fatal(err) + } + + ui, cmd := testLoginCommand(t) + cmd.client = client + + // Set the wrapping ttl to 5s. We can't set this via the flag because we + // override the client object before that particular flag is parsed. + client.SetWrappingLookupFunc(func(string, string) string { return "5m" }) + + code := cmd.Run([]string{ + "-token-only", + "-method", "userpass", + "username=test", + "password=test", + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + // Unset the wrapping + client.SetWrappingLookupFunc(func(string, string) string { return "" }) + + tokenHelper, err := cmd.TokenHelper() + if err != nil { + t.Fatal(err) + } + storedToken, err := tokenHelper.Get() + if err != nil || storedToken != "" { + t.Fatalf("expected token to not be stored: %s: %q", err, storedToken) + } + + token := strings.TrimSpace(ui.OutputWriter.String()) + if token == "" { + t.Errorf("expected %q to not be %q", token, "") + } + + // Ensure the resulting token is, in fact, still wrapped. + client.SetToken(token) + secret, err := client.Logical().Unwrap("") + if err != nil { + t.Error(err) + } + if secret == nil || secret.Auth == nil || secret.Auth.ClientToken == "" { + t.Fatalf("expected secret to have auth: %#v", secret) + } + }) + + t.Run("wrap_no_store", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + if err := client.Sys().EnableAuth("userpass", "userpass", ""); err != nil { + t.Fatal(err) + } + if _, err := client.Logical().Write("auth/userpass/users/test", map[string]interface{}{ + "password": "test", + "policies": "default", + }); err != nil { + t.Fatal(err) + } + + ui, cmd := testLoginCommand(t) + cmd.client = client + + // Set the wrapping ttl to 5s. We can't set this via the flag because we + // override the client object before that particular flag is parsed. + client.SetWrappingLookupFunc(func(string, string) string { return "5m" }) + + code := cmd.Run([]string{ + "-no-store", + "-method", "userpass", + "username=test", + "password=test", + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + // Unset the wrapping + client.SetWrappingLookupFunc(func(string, string) string { return "" }) + + tokenHelper, err := cmd.TokenHelper() + if err != nil { + t.Fatal(err) + } + storedToken, err := tokenHelper.Get() + if err != nil || storedToken != "" { + t.Fatalf("expected token to not be stored: %s: %q", err, storedToken) + } + + expected := "wrapping_token" + output := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(output, expected) { + t.Errorf("expected %q to contain %q", output, expected) + } + }) + + t.Run("communication_failure", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServerBad(t) + defer closer() + + ui, cmd := testLoginCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "token", + }) + if exp := 2; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Error authenticating: " + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + }) + + // Deprecations + // TODO: remove in 0.9.0 + t.Run("deprecated_no_verify", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + secret, err := client.Auth().Token().Create(&api.TokenCreateRequest{ + Policies: []string{"default"}, + TTL: "30m", + NumUses: 1, + }) + if err != nil { + t.Fatal(err) + } + token := secret.Auth.ClientToken + + _, cmd := testLoginCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "-no-verify", + token, + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + lookup, err := client.Auth().Token().Lookup(token) + if err != nil { + t.Fatal(err) + } + + // There was 1 use to start, make sure we didn't use it (verifying would + // use it). + uses, err := lookup.TokenRemainingUses() + if err != nil { + t.Fatal(err) + } + if uses != 1 { + t.Errorf("expected %d to be %d", uses, 1) + } + }) + + t.Run("no_tabs", func(t *testing.T) { + t.Parallel() + + _, cmd := testLoginCommand(t) + assertNoTabs(t, cmd) + }) +} diff --git a/command/main.go b/command/main.go new file mode 100644 index 0000000000..6cc9d43da4 --- /dev/null +++ b/command/main.go @@ -0,0 +1,112 @@ +package command + +import ( + "bytes" + "fmt" + "io" + "os" + "sort" + "strings" + "text/tabwriter" + + "github.com/mitchellh/cli" +) + +func Run(args []string) int { + // Handle -v shorthand + for _, arg := range args { + if arg == "--" { + break + } + + if arg == "-v" || arg == "-version" || arg == "--version" { + args = []string{"version"} + break + } + } + + // Calculate hidden commands from the deprecated ones + hiddenCommands := make([]string, 0, len(DeprecatedCommands)+1) + for k := range DeprecatedCommands { + hiddenCommands = append(hiddenCommands, k) + } + hiddenCommands = append(hiddenCommands, "version") + + cli := &cli.CLI{ + Name: "vault", + Args: args, + Commands: Commands, + HelpFunc: groupedHelpFunc( + cli.BasicHelpFunc("vault"), + ), + HiddenCommands: hiddenCommands, + Autocomplete: true, + AutocompleteNoDefaultFlags: true, + } + + exitCode, err := cli.Run() + if err != nil { + fmt.Fprintf(os.Stderr, "Error executing CLI: %s\n", err.Error()) + return 1 + } + + return exitCode +} + +var commonCommands = []string{ + "read", + "write", + "delete", + "list", + "login", + "server", + "status", + "unwrap", +} + +func groupedHelpFunc(f cli.HelpFunc) cli.HelpFunc { + return func(commands map[string]cli.CommandFactory) string { + var b bytes.Buffer + tw := tabwriter.NewWriter(&b, 0, 2, 6, ' ', 0) + + fmt.Fprintf(tw, "Usage: vault [args]\n\n") + fmt.Fprintf(tw, "Common commands:\n") + for _, v := range commonCommands { + printCommand(tw, v, commands[v]) + } + + otherCommands := make([]string, 0, len(commands)) + for k := range commands { + found := false + for _, v := range commonCommands { + if k == v { + found = true + break + } + } + + if !found { + otherCommands = append(otherCommands, k) + } + } + sort.Strings(otherCommands) + + fmt.Fprintf(tw, "\n") + fmt.Fprintf(tw, "Other commands:\n") + for _, v := range otherCommands { + printCommand(tw, v, commands[v]) + } + + tw.Flush() + + return strings.TrimSpace(b.String()) + } +} + +func printCommand(w io.Writer, name string, cmdFn cli.CommandFactory) { + cmd, err := cmdFn() + if err != nil { + panic(fmt.Sprintf("failed to load %q command: %s", name, err)) + } + fmt.Fprintf(w, " %s\t%s\n", name, cmd.Synopsis()) +} diff --git a/command/mount.go b/command/mount.go deleted file mode 100644 index b97392fa33..0000000000 --- a/command/mount.go +++ /dev/null @@ -1,169 +0,0 @@ -package command - -import ( - "fmt" - "strings" - - "github.com/hashicorp/vault/api" - "github.com/hashicorp/vault/meta" - "github.com/posener/complete" -) - -// MountCommand is a Command that mounts a new mount. -type MountCommand struct { - meta.Meta -} - -func (c *MountCommand) Run(args []string) int { - var description, path, defaultLeaseTTL, maxLeaseTTL, pluginName string - var local, forceNoCache, sealWrap bool - flags := c.Meta.FlagSet("mount", meta.FlagSetDefault) - flags.StringVar(&description, "description", "", "") - flags.StringVar(&path, "path", "", "") - flags.StringVar(&defaultLeaseTTL, "default-lease-ttl", "", "") - flags.StringVar(&maxLeaseTTL, "max-lease-ttl", "", "") - flags.StringVar(&pluginName, "plugin-name", "", "") - flags.BoolVar(&forceNoCache, "force-no-cache", false, "") - flags.BoolVar(&local, "local", false, "") - flags.BoolVar(&sealWrap, "seal-wrap", false, "") - flags.Usage = func() { c.Ui.Error(c.Help()) } - if err := flags.Parse(args); err != nil { - return 1 - } - - args = flags.Args() - if len(args) != 1 { - flags.Usage() - c.Ui.Error(fmt.Sprintf( - "\nmount expects one argument: the type to mount.")) - return 1 - } - - mountType := args[0] - - // If no path is specified, we default the path to the backend type - // or use the plugin name if it's a plugin backend - if path == "" { - if mountType == "plugin" { - path = pluginName - } else { - path = mountType - } - } - - client, err := c.Client() - if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error initializing client: %s", err)) - return 2 - } - - mountInfo := &api.MountInput{ - Type: mountType, - Description: description, - Config: api.MountConfigInput{ - DefaultLeaseTTL: defaultLeaseTTL, - MaxLeaseTTL: maxLeaseTTL, - ForceNoCache: forceNoCache, - PluginName: pluginName, - }, - Local: local, - SealWrap: sealWrap, - } - - if err := client.Sys().Mount(path, mountInfo); err != nil { - c.Ui.Error(fmt.Sprintf( - "Mount error: %s", err)) - return 2 - } - - mountTypeOutput := fmt.Sprintf("'%s'", mountType) - if mountType == "plugin" { - mountTypeOutput = fmt.Sprintf("plugin '%s'", pluginName) - } - - c.Ui.Output(fmt.Sprintf( - "Successfully mounted %s at '%s'!", - mountTypeOutput, path)) - - return 0 -} - -func (c *MountCommand) Synopsis() string { - return "Mount a logical backend" -} - -func (c *MountCommand) Help() string { - helpText := ` -Usage: vault mount [options] type - - Mount a logical backend. - - This command mounts a logical backend for storing and/or generating - secrets. - -General Options: -` + meta.GeneralOptionsUsage() + ` -Mount Options: - - -description= Human-friendly description of the purpose for - the mount. This shows up in the mounts command. - - -path= Mount point for the logical backend. This - defaults to the type of the mount. - - -default-lease-ttl= Default lease time-to-live for this backend. - If not specified, uses the global default, or - the previously set value. Set to '0' to - explicitly set it to use the global default. - - -max-lease-ttl= Max lease time-to-live for this backend. - If not specified, uses the global default, or - the previously set value. Set to '0' to - explicitly set it to use the global default. - - -force-no-cache Forces the backend to disable caching. If not - specified, uses the global default. This does - not affect caching of the underlying encrypted - data storage. - - -plugin-name Name of the plugin to mount based from the name - in the plugin catalog. - - -local Mark the mount as a local mount. Local mounts - are not replicated nor (if a secondary) - removed by replication. - - -seal-wrap Turn on seal wrapping for the mount. -` - return strings.TrimSpace(helpText) -} - -func (c *MountCommand) AutocompleteArgs() complete.Predictor { - // This list does not contain deprecated backends - return complete.PredictSet( - "aws", - "consul", - "pki", - "transit", - "ssh", - "rabbitmq", - "database", - "totp", - "plugin", - ) - -} - -func (c *MountCommand) AutocompleteFlags() complete.Flags { - return complete.Flags{ - "-description": complete.PredictNothing, - "-path": complete.PredictNothing, - "-default-lease-ttl": complete.PredictNothing, - "-max-lease-ttl": complete.PredictNothing, - "-force-no-cache": complete.PredictNothing, - "-plugin-name": complete.PredictNothing, - "-local": complete.PredictNothing, - "-seal-wrap": complete.PredictNothing, - } -} diff --git a/command/mount_test.go b/command/mount_test.go deleted file mode 100644 index ea9108cb71..0000000000 --- a/command/mount_test.go +++ /dev/null @@ -1,90 +0,0 @@ -package command - -import ( - "testing" - - "github.com/hashicorp/vault/http" - "github.com/hashicorp/vault/meta" - "github.com/hashicorp/vault/vault" - "github.com/mitchellh/cli" -) - -func TestMount(t *testing.T) { - core, _, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() - - ui := new(cli.MockUi) - c := &MountCommand{ - Meta: meta.Meta{ - ClientToken: token, - Ui: ui, - }, - } - - args := []string{ - "-address", addr, - "kv", - } - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } - - client, err := c.Client() - if err != nil { - t.Fatalf("err: %s", err) - } - - mounts, err := client.Sys().ListMounts() - if err != nil { - t.Fatalf("err: %s", err) - } - - mount, ok := mounts["kv/"] - if !ok { - t.Fatal("should have kv mount") - } - if mount.Type != "kv" { - t.Fatal("should be kv type") - } -} - -func TestMount_Generic(t *testing.T) { - core, _, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() - - ui := new(cli.MockUi) - c := &MountCommand{ - Meta: meta.Meta{ - ClientToken: token, - Ui: ui, - }, - } - - args := []string{ - "-address", addr, - "generic", - } - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } - - client, err := c.Client() - if err != nil { - t.Fatalf("err: %s", err) - } - - mounts, err := client.Sys().ListMounts() - if err != nil { - t.Fatalf("err: %s", err) - } - - mount, ok := mounts["generic/"] - if !ok { - t.Fatal("should have generic mount path") - } - if mount.Type != "generic" { - t.Fatal("should be generic type") - } -} diff --git a/command/mount_tune.go b/command/mount_tune.go deleted file mode 100644 index e1efdd241d..0000000000 --- a/command/mount_tune.go +++ /dev/null @@ -1,89 +0,0 @@ -package command - -import ( - "fmt" - "strings" - - "github.com/hashicorp/vault/api" - "github.com/hashicorp/vault/meta" -) - -// MountTuneCommand is a Command that remounts a mounted secret backend -// to a new endpoint. -type MountTuneCommand struct { - meta.Meta -} - -func (c *MountTuneCommand) Run(args []string) int { - var defaultLeaseTTL, maxLeaseTTL string - flags := c.Meta.FlagSet("mount-tune", meta.FlagSetDefault) - flags.StringVar(&defaultLeaseTTL, "default-lease-ttl", "", "") - flags.StringVar(&maxLeaseTTL, "max-lease-ttl", "", "") - flags.Usage = func() { c.Ui.Error(c.Help()) } - if err := flags.Parse(args); err != nil { - return 1 - } - - args = flags.Args() - if len(args) != 1 { - flags.Usage() - c.Ui.Error(fmt.Sprintf( - "\nmount-tune expects one arguments: the mount path")) - return 1 - } - - path := args[0] - - mountConfig := api.MountConfigInput{ - DefaultLeaseTTL: defaultLeaseTTL, - MaxLeaseTTL: maxLeaseTTL, - } - - client, err := c.Client() - if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error initializing client: %s", err)) - return 2 - } - - if err := client.Sys().TuneMount(path, mountConfig); err != nil { - c.Ui.Error(fmt.Sprintf( - "Mount tune error: %s", err)) - return 2 - } - - c.Ui.Output(fmt.Sprintf( - "Successfully tuned mount '%s'!", path)) - - return 0 -} - -func (c *MountTuneCommand) Synopsis() string { - return "Tune mount configuration parameters" -} - -func (c *MountTuneCommand) Help() string { - helpText := ` - Usage: vault mount-tune [options] path - - Tune configuration options for a mounted secret backend. - - Example: vault mount-tune -default-lease-ttl="24h" secret - -General Options: -` + meta.GeneralOptionsUsage() + ` -Mount Options: - - -default-lease-ttl= Default lease time-to-live for this backend. - If not specified, uses the system default, or - the previously set value. Set to 'system' to - explicitly set it to use the system default. - - -max-lease-ttl= Max lease time-to-live for this backend. - If not specified, uses the system default, or - the previously set value. Set to 'system' to - explicitly set it to use the system default. - -` - return strings.TrimSpace(helpText) -} diff --git a/command/mounts.go b/command/mounts.go deleted file mode 100644 index b14bbd8320..0000000000 --- a/command/mounts.go +++ /dev/null @@ -1,98 +0,0 @@ -package command - -import ( - "fmt" - "sort" - "strconv" - "strings" - - "github.com/hashicorp/vault/meta" - "github.com/ryanuber/columnize" -) - -// MountsCommand is a Command that lists the mounts. -type MountsCommand struct { - meta.Meta -} - -func (c *MountsCommand) Run(args []string) int { - flags := c.Meta.FlagSet("mounts", meta.FlagSetDefault) - flags.Usage = func() { c.Ui.Error(c.Help()) } - if err := flags.Parse(args); err != nil { - return 1 - } - - client, err := c.Client() - if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error initializing client: %s", err)) - return 2 - } - - mounts, err := client.Sys().ListMounts() - if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error reading mounts: %s", err)) - return 2 - } - - paths := make([]string, 0, len(mounts)) - for path := range mounts { - paths = append(paths, path) - } - sort.Strings(paths) - - columns := []string{"Path | Type | Accessor | Plugin | Default TTL | Max TTL | Force No Cache | Replication Behavior | Seal Wrap | Description"} - for _, path := range paths { - mount := mounts[path] - pluginName := "n/a" - if mount.Config.PluginName != "" { - pluginName = mount.Config.PluginName - } - defTTL := "system" - switch { - case mount.Type == "system", mount.Type == "cubbyhole", mount.Type == "identity": - defTTL = "n/a" - case mount.Config.DefaultLeaseTTL != 0: - defTTL = strconv.Itoa(mount.Config.DefaultLeaseTTL) - } - - maxTTL := "system" - switch { - case mount.Type == "system", mount.Type == "cubbyhole", mount.Type == "identity": - maxTTL = "n/a" - case mount.Config.MaxLeaseTTL != 0: - maxTTL = strconv.Itoa(mount.Config.MaxLeaseTTL) - } - - replicatedBehavior := "replicated" - if mount.Local { - replicatedBehavior = "local" - } - columns = append(columns, fmt.Sprintf( - "%s | %s | %s | %s | %s | %s | %v | %s | %t | %s", path, mount.Type, mount.Accessor, pluginName, defTTL, maxTTL, - mount.Config.ForceNoCache, replicatedBehavior, mount.SealWrap, mount.Description)) - } - - c.Ui.Output(columnize.SimpleFormat(columns)) - return 0 -} - -func (c *MountsCommand) Synopsis() string { - return "Lists mounted backends in Vault" -} - -func (c *MountsCommand) Help() string { - helpText := ` -Usage: vault mounts [options] - - Outputs information about the mounted backends. - - This command lists the mounted backends, their mount points, the - configured TTLs, and a human-friendly description of the mount point. - A TTL of 'system' indicates that the system default is being used. - -General Options: -` + meta.GeneralOptionsUsage() - return strings.TrimSpace(helpText) -} diff --git a/command/mounts_test.go b/command/mounts_test.go deleted file mode 100644 index 55e5f679f6..0000000000 --- a/command/mounts_test.go +++ /dev/null @@ -1,31 +0,0 @@ -package command - -import ( - "testing" - - "github.com/hashicorp/vault/http" - "github.com/hashicorp/vault/meta" - "github.com/hashicorp/vault/vault" - "github.com/mitchellh/cli" -) - -func TestMounts(t *testing.T) { - core, _, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() - - ui := new(cli.MockUi) - c := &MountsCommand{ - Meta: meta.Meta{ - ClientToken: token, - Ui: ui, - }, - } - - args := []string{ - "-address", addr, - } - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } -} diff --git a/command/operator.go b/command/operator.go new file mode 100644 index 0000000000..ad1bb439fc --- /dev/null +++ b/command/operator.go @@ -0,0 +1,47 @@ +package command + +import ( + "strings" + + "github.com/mitchellh/cli" +) + +var _ cli.Command = (*OperatorCommand)(nil) + +type OperatorCommand struct { + *BaseCommand +} + +func (c *OperatorCommand) Synopsis() string { + return "Perform operator-specific tasks" +} + +func (c *OperatorCommand) Help() string { + helpText := ` +Usage: vault operator [options] [args] + + This command groups subcommands for operators interacting with Vault. Most + users will not need to interact with these commands. Here are a few examples + of the operator commands: + + Initialize a new Vault cluster: + + $ vault operator init + + Force a Vault to resign leadership in a cluster: + + $ vault operator step-down + + Rotate Vault's underlying encryption key: + + $ vault operator rotate + + Please see the individual subcommand help for detailed usage information. +` + + return strings.TrimSpace(helpText) +} + +func (c *OperatorCommand) Run(args []string) int { + return cli.RunResultHelp +} diff --git a/command/operator_generate_root.go b/command/operator_generate_root.go new file mode 100644 index 0000000000..6eb17042a6 --- /dev/null +++ b/command/operator_generate_root.go @@ -0,0 +1,448 @@ +package command + +import ( + "bytes" + "crypto/rand" + "encoding/base64" + "fmt" + "io" + "os" + "strings" + + "github.com/hashicorp/go-uuid" + "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/helper/password" + "github.com/hashicorp/vault/helper/pgpkeys" + "github.com/hashicorp/vault/helper/xor" + "github.com/mitchellh/cli" + "github.com/posener/complete" +) + +var _ cli.Command = (*OperatorGenerateRootCommand)(nil) +var _ cli.CommandAutocomplete = (*OperatorGenerateRootCommand)(nil) + +type OperatorGenerateRootCommand struct { + *BaseCommand + + flagInit bool + flagCancel bool + flagStatus bool + flagDecode string + flagOTP string + flagPGPKey string + flagNonce string + flagGenerateOTP bool + + // Deprecation + // TODO: remove in 0.9.0 + flagGenOTP bool + + testStdin io.Reader // for tests +} + +func (c *OperatorGenerateRootCommand) Synopsis() string { + return "Generates a new root token" +} + +func (c *OperatorGenerateRootCommand) Help() string { + helpText := ` +Usage: vault operator generate-root [options] [KEY] + + Generates a new root token by combining a quorum of share holders. One of + the following must be provided to start the root token generation: + + - A base64-encoded one-time-password (OTP) provided via the "-otp" flag. + Use the "-generate-otp" flag to generate a usable value. The resulting + token is XORed with this value when it is returned. Use the "-decode" + flag to output the final value. + + - A file containing a PGP key or a keybase username in the "-pgp-key" + flag. The resulting token is encrypted with this public key. + + An unseal key may be provided directly on the command line as an argument to + the command. If key is specified as "-", the command will read from stdin. If + a TTY is available, the command will prompt for text. + + Generate an OTP code for the final token: + + $ vault operator generate-root -generate-otp + + Start a root token generation: + + $ vault operator generate-root -init -otp="..." + $ vault operator generate-root -init -pgp-key="..." + + Enter an unseal key to progress root token generation: + + $ vault operator generate-root -otp="..." + +` + c.Flags().Help() + return strings.TrimSpace(helpText) +} + +func (c *OperatorGenerateRootCommand) Flags() *FlagSets { + set := c.flagSet(FlagSetHTTP) + + f := set.NewFlagSet("Command Options") + + f.BoolVar(&BoolVar{ + Name: "init", + Target: &c.flagInit, + Default: false, + EnvVar: "", + Completion: complete.PredictNothing, + Usage: "Start a root token generation. This can only be done if " + + "there is not currently one in progress.", + }) + + f.BoolVar(&BoolVar{ + Name: "cancel", + Target: &c.flagCancel, + Default: false, + EnvVar: "", + Completion: complete.PredictNothing, + Usage: "Reset the root token generation progress. This will discard any " + + "submitted unseal keys or configuration.", + }) + + f.BoolVar(&BoolVar{ + Name: "status", + Target: &c.flagStatus, + Default: false, + EnvVar: "", + Completion: complete.PredictNothing, + Usage: "Print the status of the current attempt without providing an " + + "unseal key.", + }) + + f.StringVar(&StringVar{ + Name: "decode", + Target: &c.flagDecode, + Default: "", + EnvVar: "", + Completion: complete.PredictAnything, + Usage: "Decode and output the generated root token. This option requires " + + "the \"-otp\" flag be set to the OTP used during initialization.", + }) + + f.BoolVar(&BoolVar{ + Name: "generate-otp", + Target: &c.flagGenerateOTP, + Default: false, + EnvVar: "", + Completion: complete.PredictNothing, + Usage: "Generate and print a high-entropy one-time-password (OTP) " + + "suitable for use with the \"-init\" flag.", + }) + + f.StringVar(&StringVar{ + Name: "otp", + Target: &c.flagOTP, + Default: "", + EnvVar: "", + Completion: complete.PredictAnything, + Usage: "OTP code to use with \"-decode\" or \"-init\".", + }) + + f.VarFlag(&VarFlag{ + Name: "pgp-key", + Value: (*pgpkeys.PubKeyFileFlag)(&c.flagPGPKey), + Default: "", + EnvVar: "", + Completion: complete.PredictAnything, + Usage: "Path to a file on disk containing a binary or base64-encoded " + + "public GPG key. This can also be specified as a Keybase username " + + "using the format \"keybase:\". When supplied, the generated " + + "root token will be encrypted and base64-encoded with the given public " + + "key.", + }) + + f.StringVar(&StringVar{ + Name: "nonce", + Target: &c.flagNonce, + Default: "", + EnvVar: "", + Completion: complete.PredictAnything, + Usage: "Nonce value provided at initialization. The same nonce value " + + "must be provided with each unseal key.", + }) + + // Deprecations: prefer longer-form, descriptive flags + // TODO: remove in 0.9.0 + f.BoolVar(&BoolVar{ + Name: "genotp", // -generate-otp + Target: &c.flagGenOTP, + Default: false, + Hidden: true, + }) + + return set +} + +func (c *OperatorGenerateRootCommand) AutocompleteArgs() complete.Predictor { + return nil +} + +func (c *OperatorGenerateRootCommand) AutocompleteFlags() complete.Flags { + return c.Flags().Completions() +} + +func (c *OperatorGenerateRootCommand) Run(args []string) int { + f := c.Flags() + + if err := f.Parse(args); err != nil { + c.UI.Error(err.Error()) + return 1 + } + + args = f.Args() + if len(args) > 1 { + c.UI.Error(fmt.Sprintf("Too many arguments (expected 0-1, got %d)", len(args))) + return 1 + } + + // Deprecations + // TODO: remove in 0.9.0 + switch { + case c.flagGenOTP: + c.UI.Warn(wrapAtLength( + "The -gen-otp flag is deprecated. Please use the -generate-otp flag " + + "instead.")) + c.flagGenerateOTP = c.flagGenOTP + } + + client, err := c.Client() + if err != nil { + c.UI.Error(err.Error()) + return 2 + } + + switch { + case c.flagGenerateOTP: + return c.generateOTP() + case c.flagDecode != "": + return c.decode(c.flagDecode, c.flagOTP) + case c.flagCancel: + return c.cancel(client) + case c.flagInit: + return c.init(client, c.flagOTP, c.flagPGPKey) + case c.flagStatus: + return c.status(client) + default: + // If there are no other flags, prompt for an unseal key. + key := "" + if len(args) > 0 { + key = strings.TrimSpace(args[0]) + } + return c.provide(client, key) + } +} + +// verifyOTP verifies the given OTP code is exactly 16 bytes. +func (c *OperatorGenerateRootCommand) verifyOTP(otp string) error { + if len(otp) == 0 { + return fmt.Errorf("No OTP passed in") + } + otpBytes, err := base64.StdEncoding.DecodeString(otp) + if err != nil { + return fmt.Errorf("Error decoding base64 OTP value: %s", err) + } + if otpBytes == nil || len(otpBytes) != 16 { + return fmt.Errorf("Decoded OTP value is invalid or wrong length") + } + + return nil +} + +// generateOTP generates a suitable OTP code for generating a root token. +func (c *OperatorGenerateRootCommand) generateOTP() int { + buf := make([]byte, 16) + readLen, err := rand.Read(buf) + if err != nil { + c.UI.Error(fmt.Sprintf("Error reading random bytes: %s", err)) + return 2 + } + + if readLen != 16 { + c.UI.Error(fmt.Sprintf("Read %d bytes when we should have read 16", readLen)) + return 2 + } + + return PrintRaw(c.UI, base64.StdEncoding.EncodeToString(buf)) +} + +// decode decodes the given value using the otp. +func (c *OperatorGenerateRootCommand) decode(encoded, otp string) int { + if encoded == "" { + c.UI.Error("Missing encoded value: use -decode= to supply it") + return 1 + } + if otp == "" { + c.UI.Error("Missing otp: use -otp to supply it") + return 1 + } + + tokenBytes, err := xor.XORBase64(encoded, otp) + if err != nil { + c.UI.Error(fmt.Sprintf("Error xoring token: %s", err)) + return 1 + } + + token, err := uuid.FormatUUID(tokenBytes) + if err != nil { + c.UI.Error(fmt.Sprintf("Error formatting base64 token value: %s", err)) + return 1 + } + + return PrintRaw(c.UI, strings.TrimSpace(token)) +} + +// init is used to start the generation process +func (c *OperatorGenerateRootCommand) init(client *api.Client, otp string, pgpKey string) int { + // Validate incoming fields. Either OTP OR PGP keys must be supplied. + switch { + case otp == "" && pgpKey == "": + c.UI.Error("Error initializing: must specify either -otp or -pgp-key") + return 1 + case otp != "" && pgpKey != "": + c.UI.Error("Error initializing: cannot specify both -otp and -pgp-key") + return 1 + case otp != "": + if err := c.verifyOTP(otp); err != nil { + c.UI.Error(fmt.Sprintf("Error initializing: invalid OTP: %s", err)) + return 1 + } + case pgpKey != "": + // OK + } + + // Start the root generation + status, err := client.Sys().GenerateRootInit(otp, pgpKey) + if err != nil { + c.UI.Error(fmt.Sprintf("Error initializing root generation: %s", err)) + return 2 + } + return c.printStatus(status) +} + +// provide prompts the user for the seal key and posts it to the update root +// endpoint. If this is the last unseal, this function outputs it. +func (c *OperatorGenerateRootCommand) provide(client *api.Client, key string) int { + status, err := client.Sys().GenerateRootStatus() + if err != nil { + c.UI.Error(fmt.Sprintf("Error getting root generation status: %s", err)) + return 2 + } + + // Verify a root token generation is in progress. If there is not one in + // progress, return an error instructing the user to start one. + if !status.Started { + c.UI.Error(wrapAtLength( + "No root generation is in progress. Start a root generation by " + + "running \"vault generate-root -init\".")) + return 1 + } + + var nonce string + + switch key { + case "-": // Read from stdin + nonce = c.flagNonce + + // Pull our fake stdin if needed + stdin := (io.Reader)(os.Stdin) + if c.testStdin != nil { + stdin = c.testStdin + } + + var buf bytes.Buffer + if _, err := io.Copy(&buf, stdin); err != nil { + c.UI.Error(fmt.Sprintf("Failed to read from stdin: %s", err)) + return 1 + } + + key = buf.String() + case "": // Prompt using the tty + // Nonce value is not required if we are prompting via the terminal + nonce = status.Nonce + + w := getWriterFromUI(c.UI) + fmt.Fprintf(w, "Root generation operation nonce: %s\n", nonce) + fmt.Fprintf(w, "Unseal Key (will be hidden): ") + key, err = password.Read(os.Stdin) + fmt.Fprintf(w, "\n") + if err != nil { + if err == password.ErrInterrupted { + c.UI.Error("user canceled") + return 1 + } + + c.UI.Error(wrapAtLength(fmt.Sprintf("An error occurred attempting to "+ + "ask for the unseal key. The raw error message is shown below, but "+ + "usually this is because you attempted to pipe a value into the "+ + "command or you are executing outside of a terminal (tty). If you "+ + "want to pipe the value, pass \"-\" as the argument to read from "+ + "stdin. The raw error was: %s", err))) + return 1 + } + default: // Supplied directly as an arg + nonce = c.flagNonce + } + + // Trim any whitespace from they key, especially since we might have prompted + // the user for it. + key = strings.TrimSpace(key) + + // Verify we have a nonce value + if nonce == "" { + c.UI.Error("Missing nonce value: specify it via the -nonce flag") + return 1 + } + + // Provide the key, this may potentially complete the update + status, err = client.Sys().GenerateRootUpdate(key, nonce) + if err != nil { + c.UI.Error(fmt.Sprintf("Error posting unseal key: %s", err)) + return 2 + } + return c.printStatus(status) +} + +// cancel cancels the root token generation +func (c *OperatorGenerateRootCommand) cancel(client *api.Client) int { + if err := client.Sys().GenerateRootCancel(); err != nil { + c.UI.Error(fmt.Sprintf("Error canceling root token generation: %s", err)) + return 2 + } + c.UI.Output("Success! Root token generation canceled (if it was started)") + return 0 +} + +// status is used just to fetch and dump the status +func (c *OperatorGenerateRootCommand) status(client *api.Client) int { + status, err := client.Sys().GenerateRootStatus() + if err != nil { + c.UI.Error(fmt.Sprintf("Error getting root generation status: %s", err)) + return 2 + } + return c.printStatus(status) +} + +// printStatus dumps the status to output +func (c *OperatorGenerateRootCommand) printStatus(status *api.GenerateRootStatusResponse) int { + out := []string{} + out = append(out, fmt.Sprintf("Nonce | %s", status.Nonce)) + out = append(out, fmt.Sprintf("Started | %t", status.Started)) + out = append(out, fmt.Sprintf("Progress | %d/%d", status.Progress, status.Required)) + out = append(out, fmt.Sprintf("Complete | %t", status.Complete)) + if status.PGPFingerprint != "" { + out = append(out, fmt.Sprintf("PGP Fingerprint | %s", status.PGPFingerprint)) + } + if status.EncodedRootToken != "" { + out = append(out, fmt.Sprintf("Root Token | %s", status.EncodedRootToken)) + } + + output := columnOutput(out, nil) + c.UI.Output(output) + return 0 +} diff --git a/command/operator_generate_root_test.go b/command/operator_generate_root_test.go new file mode 100644 index 0000000000..ad5e67e239 --- /dev/null +++ b/command/operator_generate_root_test.go @@ -0,0 +1,448 @@ +package command + +import ( + "io" + "regexp" + "strings" + "testing" + + uuid "github.com/hashicorp/go-uuid" + "github.com/hashicorp/vault/helper/xor" + "github.com/mitchellh/cli" +) + +func testOperatorGenerateRootCommand(tb testing.TB) (*cli.MockUi, *OperatorGenerateRootCommand) { + tb.Helper() + + ui := cli.NewMockUi() + return ui, &OperatorGenerateRootCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + } +} + +func TestOperatorGenerateRootCommand_Run(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + args []string + out string + code int + }{ + { + "init_no_args", + []string{ + "-init", + }, + "must specify either -otp or -pgp-key", + 1, + }, + { + "init_invalid_otp", + []string{ + "-init", + "-otp", "not-a-valid-otp", + }, + "Error initializing: invalid OTP:", + 1, + }, + { + "init_pgp_multi", + []string{ + "-init", + "-pgp-key", "keybase:hashicorp", + "-pgp-key", "keybase:jefferai", + }, + "can only be specified once", + 1, + }, + { + "init_pgp_multi_inline", + []string{ + "-init", + "-pgp-key", "keybase:hashicorp,keybase:jefferai", + }, + "can only specify one pgp key", + 1, + }, + { + "init_pgp_otp", + []string{ + "-init", + "-pgp-key", "keybase:hashicorp", + "-otp", "abcd1234", + }, + "cannot specify both -otp and -pgp-key", + 1, + }, + } + + t.Run("validations", func(t *testing.T) { + t.Parallel() + + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + ui, cmd := testOperatorGenerateRootCommand(t) + + code := cmd.Run(tc.args) + if code != tc.code { + t.Errorf("expected %d to be %d", code, tc.code) + } + + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, tc.out) { + t.Errorf("expected %q to contain %q", combined, tc.out) + } + }) + } + }) + + t.Run("generate_otp", func(t *testing.T) { + t.Parallel() + + ui, cmd := testOperatorGenerateRootCommand(t) + + code := cmd.Run([]string{ + "-generate-otp", + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + output := ui.OutputWriter.String() + ui.ErrorWriter.String() + if err := cmd.verifyOTP(output); err != nil { + t.Fatal(err) + } + }) + + t.Run("decode", func(t *testing.T) { + t.Parallel() + + encoded := "L9MaZ/4mQanpOV6QeWd84g==" + otp := "dIeeezkjpDUv3fy7MYPOLQ==" + + ui, cmd := testOperatorGenerateRootCommand(t) + + code := cmd.Run([]string{ + "-decode", encoded, + "-otp", otp, + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "5b54841c-c705-e59c-c6e4-a22b48e4b2cf" + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if combined != expected { + t.Errorf("expected %q to be %q", combined, expected) + } + }) + + t.Run("cancel", func(t *testing.T) { + t.Parallel() + + otp := "dIeeezkjpDUv3fy7MYPOLQ==" + + client, closer := testVaultServer(t) + defer closer() + + // Initialize a generation + if _, err := client.Sys().GenerateRootInit(otp, ""); err != nil { + t.Fatal(err) + } + + ui, cmd := testOperatorGenerateRootCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "-cancel", + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Success! Root token generation canceled" + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + + status, err := client.Sys().GenerateRootStatus() + if err != nil { + t.Fatal(err) + } + + if status.Started { + t.Errorf("expected status to be canceled: %#v", status) + } + }) + + t.Run("init_otp", func(t *testing.T) { + t.Parallel() + + otp := "dIeeezkjpDUv3fy7MYPOLQ==" + + client, closer := testVaultServer(t) + defer closer() + + ui, cmd := testOperatorGenerateRootCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "-init", + "-otp", otp, + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Nonce" + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + + status, err := client.Sys().GenerateRootStatus() + if err != nil { + t.Fatal(err) + } + + if !status.Started { + t.Errorf("expected status to be started: %#v", status) + } + }) + + t.Run("init_pgp", func(t *testing.T) { + t.Parallel() + + pgpKey := "keybase:hashicorp" + pgpFingerprint := "91a6e7f85d05c65630bef18951852d87348ffc4c" + + client, closer := testVaultServer(t) + defer closer() + + ui, cmd := testOperatorGenerateRootCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "-init", + "-pgp-key", pgpKey, + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Nonce" + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + + status, err := client.Sys().GenerateRootStatus() + if err != nil { + t.Fatal(err) + } + + if !status.Started { + t.Errorf("expected status to be started: %#v", status) + } + if status.PGPFingerprint != pgpFingerprint { + t.Errorf("expected %q to be %q", status.PGPFingerprint, pgpFingerprint) + } + }) + + t.Run("status", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + ui, cmd := testOperatorGenerateRootCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "-status", + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Nonce" + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + }) + + t.Run("provide_arg", func(t *testing.T) { + t.Parallel() + + otp := "dIeeezkjpDUv3fy7MYPOLQ==" + + client, keys, closer := testVaultServerUnseal(t) + defer closer() + + // Initialize a generation + status, err := client.Sys().GenerateRootInit(otp, "") + if err != nil { + t.Fatal(err) + } + nonce := status.Nonce + + // Supply the first n-1 unseal keys + for _, key := range keys[:len(keys)-1] { + _, cmd := testOperatorGenerateRootCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "-nonce", nonce, + key, + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + } + + ui, cmd := testOperatorGenerateRootCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "-nonce", nonce, + keys[len(keys)-1], // the last unseal key + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + reToken := regexp.MustCompile(`Root Token\s+(.+)`) + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + match := reToken.FindAllStringSubmatch(combined, -1) + if len(match) < 1 || len(match[0]) < 2 { + t.Fatalf("no match: %#v", match) + } + + tokenBytes, err := xor.XORBase64(match[0][1], otp) + if err != nil { + t.Fatal(err) + } + token, err := uuid.FormatUUID(tokenBytes) + if err != nil { + t.Fatal(err) + } + + if l, exp := len(token), 36; l != exp { + t.Errorf("expected %d to be %d: %s", l, exp, token) + } + }) + + t.Run("provide_stdin", func(t *testing.T) { + t.Parallel() + + otp := "dIeeezkjpDUv3fy7MYPOLQ==" + + client, keys, closer := testVaultServerUnseal(t) + defer closer() + + // Initialize a generation + status, err := client.Sys().GenerateRootInit(otp, "") + if err != nil { + t.Fatal(err) + } + nonce := status.Nonce + + // Supply the first n-1 unseal keys + for _, key := range keys[:len(keys)-1] { + stdinR, stdinW := io.Pipe() + go func() { + stdinW.Write([]byte(key)) + stdinW.Close() + }() + + _, cmd := testOperatorGenerateRootCommand(t) + cmd.client = client + cmd.testStdin = stdinR + + code := cmd.Run([]string{ + "-nonce", nonce, + "-", + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + } + + stdinR, stdinW := io.Pipe() + go func() { + stdinW.Write([]byte(keys[len(keys)-1])) // the last unseal key + stdinW.Close() + }() + + ui, cmd := testOperatorGenerateRootCommand(t) + cmd.client = client + cmd.testStdin = stdinR + + code := cmd.Run([]string{ + "-nonce", nonce, + "-", + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + reToken := regexp.MustCompile(`Root Token\s+(.+)`) + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + match := reToken.FindAllStringSubmatch(combined, -1) + if len(match) < 1 || len(match[0]) < 2 { + t.Fatalf("no match: %#v", match) + } + + tokenBytes, err := xor.XORBase64(match[0][1], otp) + if err != nil { + t.Fatal(err) + } + token, err := uuid.FormatUUID(tokenBytes) + if err != nil { + t.Fatal(err) + } + + if l, exp := len(token), 36; l != exp { + t.Errorf("expected %d to be %d: %s", l, exp, token) + } + }) + + t.Run("communication_failure", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServerBad(t) + defer closer() + + ui, cmd := testOperatorGenerateRootCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "secret/foo", + }) + if exp := 2; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Error getting root generation status: " + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + }) + + t.Run("no_tabs", func(t *testing.T) { + t.Parallel() + + _, cmd := testOperatorGenerateRootCommand(t) + assertNoTabs(t, cmd) + }) +} diff --git a/command/operator_init.go b/command/operator_init.go new file mode 100644 index 0000000000..cf3aaf31db --- /dev/null +++ b/command/operator_init.go @@ -0,0 +1,590 @@ +package command + +import ( + "encoding/json" + "fmt" + "net/url" + "runtime" + "strings" + + "github.com/ghodss/yaml" + "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/helper/pgpkeys" + "github.com/mitchellh/cli" + "github.com/posener/complete" + + consulapi "github.com/hashicorp/consul/api" +) + +var _ cli.Command = (*OperatorInitCommand)(nil) +var _ cli.CommandAutocomplete = (*OperatorInitCommand)(nil) + +type OperatorInitCommand struct { + *BaseCommand + + flagStatus bool + flagKeyShares int + flagKeyThreshold int + flagPGPKeys []string + flagRootTokenPGPKey string + + // HSM + flagStoredShares int + flagRecoveryShares int + flagRecoveryThreshold int + flagRecoveryPGPKeys []string + + // Consul + flagConsulAuto bool + flagConsulService string + + // Deprecations + // TODO: remove in 0.9.0 + flagAuto bool + flagCheck bool +} + +func (c *OperatorInitCommand) Synopsis() string { + return "Initializes a server" +} + +func (c *OperatorInitCommand) Help() string { + helpText := ` +Usage: vault operator init [options] + + Initializes a Vault server. Initialization is the process by which Vault's + storage backend is prepared to receive data. Since Vault server's share the + same storage backend in HA mode, you only need to initialize one Vault to + initialize the storage backend. + + During initialization, Vault generates an in-memory master key and applies + Shamir's secret sharing algorithm to disassemble that master key into a + configuration number of key shares such that a configurable subset of those + key shares must come together to regenerate the master key. These keys are + often called "unseal keys" in Vault's documentation. + + This command cannot be run against already-initialized Vault cluster. + + Start initialization with the default options: + + $ vault operator init + + Initialize, but encrypt the unseal keys with pgp keys: + + $ vault operator init \ + -key-shares=3 \ + -key-threshold=2 \ + -pgp-keys="keybase:hashicorp,keybase:jefferai,keybase:sethvargo" + + Encrypt the initial root token using a pgp key: + + $ vault operator init -root-token-pgp-key="keybase:hashicorp" + +` + c.Flags().Help() + return strings.TrimSpace(helpText) +} + +func (c *OperatorInitCommand) Flags() *FlagSets { + set := c.flagSet(FlagSetHTTP | FlagSetOutputFormat) + + // Common Options + f := set.NewFlagSet("Common Options") + + f.BoolVar(&BoolVar{ + Name: "status", + Target: &c.flagStatus, + Default: false, + Usage: "Print the current initialization status. An exit code of 0 means " + + "the Vault is already initialized. An exit code of 1 means an error " + + "occurred. An exit code of 2 means the mean is not initialized.", + }) + + f.IntVar(&IntVar{ + Name: "key-shares", + Aliases: []string{"n"}, + Target: &c.flagKeyShares, + Default: 5, + Completion: complete.PredictAnything, + Usage: "Number of key shares to split the generated master key into. " + + "This is the number of \"unseal keys\" to generate.", + }) + + f.IntVar(&IntVar{ + Name: "key-threshold", + Aliases: []string{"t"}, + Target: &c.flagKeyThreshold, + Default: 3, + Completion: complete.PredictAnything, + Usage: "Number of key shares required to reconstruct the master key. " + + "This must be less than or equal to -key-shares.", + }) + + f.VarFlag(&VarFlag{ + Name: "pgp-keys", + Value: (*pgpkeys.PubKeyFilesFlag)(&c.flagPGPKeys), + Completion: complete.PredictAnything, + Usage: "Comma-separated list of paths to files on disk containing " + + "public GPG keys OR a comma-separated list of Keybase usernames using " + + "the format \"keybase:\". When supplied, the generated " + + "unseal keys will be encrypted and base64-encoded in the order " + + "specified in this list. The number of entires must match -key-shares, " + + "unless -store-shares are used.", + }) + + f.VarFlag(&VarFlag{ + Name: "root-token-pgp-key", + Value: (*pgpkeys.PubKeyFileFlag)(&c.flagRootTokenPGPKey), + Completion: complete.PredictAnything, + Usage: "Path to a file on disk containing a binary or base64-encoded " + + "public GPG key. This can also be specified as a Keybase username " + + "using the format \"keybase:\". When supplied, the generated " + + "root token will be encrypted and base64-encoded with the given public " + + "key.", + }) + + // Consul Options + f = set.NewFlagSet("Consul Options") + + f.BoolVar(&BoolVar{ + Name: "consul-auto", + Target: &c.flagConsulAuto, + Default: false, + Usage: "Perform automatic service discovery using Consul in HA mode. " + + "When all nodes in a Vault HA cluster are registered with Consul, " + + "enabling this option will trigger automatic service discovery based " + + "on the provided -consul-service value. When Consul is Vault's HA " + + "backend, this functionality is automatically enabled. Ensure the " + + "proper Consul environment variables are set (CONSUL_HTTP_ADDR, etc). " + + "When only one Vault server is discovered, it will be initialized " + + "automatically. When more than one Vault server is discovered, they " + + "will each be output for selection.", + }) + + f.StringVar(&StringVar{ + Name: "consul-service", + Target: &c.flagConsulService, + Default: "vault", + Completion: complete.PredictAnything, + Usage: "Name of the service in Consul under which the Vault servers are " + + "registered.", + }) + + // HSM Options + f = set.NewFlagSet("HSM Options") + + f.IntVar(&IntVar{ + Name: "recovery-shares", + Target: &c.flagRecoveryShares, + Default: 5, + Completion: complete.PredictAnything, + Usage: "Number of key shares to split the recovery key into. " + + "This is only used in HSM mode.", + }) + + f.IntVar(&IntVar{ + Name: "recovery-threshold", + Target: &c.flagRecoveryThreshold, + Default: 3, + Completion: complete.PredictAnything, + Usage: "Number of key shares required to reconstruct the recovery key. " + + "This is only used in HSM mode.", + }) + + f.VarFlag(&VarFlag{ + Name: "recovery-pgp-keys", + Value: (*pgpkeys.PubKeyFilesFlag)(&c.flagRecoveryPGPKeys), + Completion: complete.PredictAnything, + Usage: "Behaves like -pgp-keys, but for the recovery key shares. This " + + "is only used in HSM mode.", + }) + + f.IntVar(&IntVar{ + Name: "stored-shares", + Target: &c.flagStoredShares, + Default: 0, // No default, because we need to check if was supplied + Completion: complete.PredictAnything, + Usage: "Number of unseal keys to store on an HSM. This must be equal to " + + "-key-shares. This is only used in HSM mode.", + }) + + // Deprecations + // TODO: remove in 0.9.0 + f.BoolVar(&BoolVar{ + Name: "check", // prefer -status + Target: &c.flagCheck, + Default: false, + Hidden: true, + Usage: "", + }) + f.BoolVar(&BoolVar{ + Name: "auto", // prefer -consul-auto + Target: &c.flagAuto, + Default: false, + Hidden: true, + Usage: "", + }) + + return set +} + +func (c *OperatorInitCommand) AutocompleteArgs() complete.Predictor { + return nil +} + +func (c *OperatorInitCommand) AutocompleteFlags() complete.Flags { + return c.Flags().Completions() +} + +func (c *OperatorInitCommand) Run(args []string) int { + f := c.Flags() + + if err := f.Parse(args); err != nil { + c.UI.Error(err.Error()) + return 1 + } + + // Deprecations + // TODO: remove in 0.9.0 + if c.flagAuto { + c.UI.Warn(wrapAtLength("WARNING! -auto is deprecated. Please use " + + "-consul-auto instead. This will be removed the next major release " + + "of Vault.")) + c.flagConsulAuto = true + } + if c.flagCheck { + c.UI.Warn(wrapAtLength("WARNING! -check is deprecated. Please use " + + "-status instead. This will be removed in the next major release " + + "of Vault.")) + c.flagStatus = true + } + + // Build the initial init request + initReq := &api.InitRequest{ + SecretShares: c.flagKeyShares, + SecretThreshold: c.flagKeyThreshold, + PGPKeys: c.flagPGPKeys, + RootTokenPGPKey: c.flagRootTokenPGPKey, + + StoredShares: c.flagStoredShares, + RecoveryShares: c.flagRecoveryShares, + RecoveryThreshold: c.flagRecoveryThreshold, + RecoveryPGPKeys: c.flagRecoveryPGPKeys, + } + + client, err := c.Client() + if err != nil { + c.UI.Error(err.Error()) + return 2 + } + + // Check auto mode + switch { + case c.flagStatus: + return c.status(client) + case c.flagConsulAuto: + return c.consulAuto(client, initReq) + default: + return c.init(client, initReq) + } +} + +// consulAuto enables auto-joining via Consul. +func (c *OperatorInitCommand) consulAuto(client *api.Client, req *api.InitRequest) int { + // Capture the client original address and reset it + originalAddr := client.Address() + defer client.SetAddress(originalAddr) + + // Create a client to communicate with Consul + consulClient, err := consulapi.NewClient(consulapi.DefaultConfig()) + if err != nil { + c.UI.Error(fmt.Sprintf("Failed to create Consul client:%v", err)) + return 1 + } + + // Pull the scheme from the Vault client to determine if the Consul agent + // should talk via HTTP or HTTPS. + addr := client.Address() + clientURL, err := url.Parse(addr) + if err != nil || clientURL == nil { + c.UI.Error(fmt.Sprintf("Failed to parse Vault address %s: %s", addr, err)) + return 1 + } + + var uninitedVaults []string + var initedVault string + + // Query the nodes belonging to the cluster + services, _, err := consulClient.Catalog().Service(c.flagConsulService, "", &consulapi.QueryOptions{ + AllowStale: true, + }) + if err == nil { + for _, service := range services { + // Set the address on the client temporarily + vaultAddr := (&url.URL{ + Scheme: clientURL.Scheme, + Host: fmt.Sprintf("%s:%d", service.ServiceAddress, service.ServicePort), + }).String() + client.SetAddress(vaultAddr) + + // Check the initialization status of the discovered node + inited, err := client.Sys().InitStatus() + if err != nil { + c.UI.Error(fmt.Sprintf("Error checking init status of %q: %s", vaultAddr, err)) + } + if inited { + initedVault = vaultAddr + break + } + + // If we got this far, we communicated successfully with Vault, but it + // was not initialized. + uninitedVaults = append(uninitedVaults, vaultAddr) + } + } + + // Get the correct export keywords and quotes for *nix vs Windows + export := "export" + quote := "\"" + if runtime.GOOS == "windows" { + export = "set" + quote = "" + } + + if initedVault != "" { + vaultURL, err := url.Parse(initedVault) + if err != nil { + c.UI.Error(fmt.Sprintf("Failed to parse Vault address %q: %s", initedVault, err)) + return 2 + } + vaultAddr := vaultURL.String() + + c.UI.Output(wrapAtLength(fmt.Sprintf( + "Discovered an initialized Vault node at %q with Consul service name "+ + "%q. Set the following environment variable to target the discovered "+ + "Vault server:", + vaultURL.String(), c.flagConsulService))) + c.UI.Output("") + c.UI.Output(fmt.Sprintf(" $ %s VAULT_ADDR=%s%s%s", export, quote, vaultAddr, quote)) + c.UI.Output("") + return 0 + } + + switch len(uninitedVaults) { + case 0: + c.UI.Error(fmt.Sprintf("No Vault nodes registered as %q in Consul", c.flagConsulService)) + return 2 + case 1: + // There was only one node found in the Vault cluster and it was + // uninitialized. + vaultURL, err := url.Parse(uninitedVaults[0]) + if err != nil { + c.UI.Error(fmt.Sprintf("Failed to parse Vault address %q: %s", initedVault, err)) + return 2 + } + vaultAddr := vaultURL.String() + + // Update the client to connect to this Vault server + client.SetAddress(vaultAddr) + + // Let the client know that initialization is perfomed on the + // discovered node. + c.UI.Output(wrapAtLength(fmt.Sprintf( + "Discovered an initialized Vault node at %q with Consul service name "+ + "%q. Set the following environment variable to target the discovered "+ + "Vault server:", + vaultURL.String(), c.flagConsulService))) + c.UI.Output("") + c.UI.Output(fmt.Sprintf(" $ %s VAULT_ADDR=%s%s%s", export, quote, vaultAddr, quote)) + c.UI.Output("") + c.UI.Output("Attempting to initialize it...") + c.UI.Output("") + + // Attempt to initialize it + return c.init(client, req) + default: + // If more than one Vault node were discovered, print out all of them, + // requiring the client to update VAULT_ADDR and to run init again. + c.UI.Output(wrapAtLength(fmt.Sprintf( + "Discovered %d uninitialized Vault servers with Consul service name "+ + "%q. To initialize these Vatuls, set any one of the following "+ + "environment variables and run \"vault init\":", + len(uninitedVaults), c.flagConsulService))) + c.UI.Output("") + + // Print valid commands to make setting the variables easier + for _, node := range uninitedVaults { + vaultURL, err := url.Parse(node) + if err != nil { + c.UI.Error(fmt.Sprintf("Failed to parse Vault address %q: %s", initedVault, err)) + return 2 + } + vaultAddr := vaultURL.String() + + c.UI.Output(fmt.Sprintf(" $ %s VAULT_ADDR=%s%s%s", export, quote, vaultAddr, quote)) + } + + c.UI.Output("") + return 0 + } +} + +func (c *OperatorInitCommand) init(client *api.Client, req *api.InitRequest) int { + resp, err := client.Sys().Init(req) + if err != nil { + c.UI.Error(fmt.Sprintf("Error initializing: %s", err)) + return 2 + } + + switch c.flagFormat { + case "yaml", "yml": + return c.initOutputYAML(req, resp) + case "json": + return c.initOutputJSON(req, resp) + case "table": + default: + c.UI.Error(fmt.Sprintf("Unknown format: %s", c.flagFormat)) + return 1 + } + + for i, key := range resp.Keys { + if resp.KeysB64 != nil && len(resp.KeysB64) == len(resp.Keys) { + c.UI.Output(fmt.Sprintf("Unseal Key %d: %s", i+1, resp.KeysB64[i])) + } else { + c.UI.Output(fmt.Sprintf("Unseal Key %d: %s", i+1, key)) + } + } + for i, key := range resp.RecoveryKeys { + if resp.RecoveryKeysB64 != nil && len(resp.RecoveryKeysB64) == len(resp.RecoveryKeys) { + c.UI.Output(fmt.Sprintf("Recovery Key %d: %s", i+1, resp.RecoveryKeysB64[i])) + } else { + c.UI.Output(fmt.Sprintf("Recovery Key %d: %s", i+1, key)) + } + } + + c.UI.Output("") + c.UI.Output(fmt.Sprintf("Initial Root Token: %s", resp.RootToken)) + + if req.StoredShares < 1 { + c.UI.Output("") + c.UI.Output(wrapAtLength(fmt.Sprintf( + "Vault initialized with %d key shares an a key threshold of %d. Please "+ + "securely distributed the key shares printed above. When the Vault is "+ + "re-sealed, restarted, or stopped, you must supply at least %d of "+ + "these keys to unseal it before it can start servicing requests.", + req.SecretShares, + req.SecretThreshold, + req.SecretThreshold))) + + c.UI.Output("") + c.UI.Output(wrapAtLength(fmt.Sprintf( + "Vault does not store the generated master key. Without at least %d "+ + "key to reconstruct the master key, Vault will remain permanently "+ + "sealed!", + req.SecretThreshold))) + + c.UI.Output("") + c.UI.Output(wrapAtLength( + "It is possible to generate new unseal keys, provided you have a quorum " + + "of existing unseal keys shares. See \"vault rekey\" for more " + + "information.")) + } else { + c.UI.Output("") + c.UI.Output("Success! Vault is initialized") + } + + if len(resp.RecoveryKeys) > 0 { + c.UI.Output("") + c.UI.Output(wrapAtLength(fmt.Sprintf( + "Recovery key initialized with %d key shares and a key threshold of %d. "+ + "Please securely distribute the key shares printed above.", + req.RecoveryShares, + req.RecoveryThreshold))) + } + + return 0 +} + +// initOutputYAML outputs the init output as YAML. +func (c *OperatorInitCommand) initOutputYAML(req *api.InitRequest, resp *api.InitResponse) int { + b, err := yaml.Marshal(newMachineInit(req, resp)) + if err != nil { + c.UI.Error(fmt.Sprintf("Error marshaling YAML: %s", err)) + return 2 + } + return PrintRaw(c.UI, strings.TrimSpace(string(b))) +} + +// initOutputJSON outputs the init output as JSON. +func (c *OperatorInitCommand) initOutputJSON(req *api.InitRequest, resp *api.InitResponse) int { + b, err := json.Marshal(newMachineInit(req, resp)) + if err != nil { + c.UI.Error(fmt.Sprintf("Error marshaling JSON: %s", err)) + return 2 + } + return PrintRaw(c.UI, strings.TrimSpace(string(b))) +} + +// status inspects the init status of vault and returns an appropriate error +// code and message. +func (c *OperatorInitCommand) status(client *api.Client) int { + inited, err := client.Sys().InitStatus() + if err != nil { + c.UI.Error(fmt.Sprintf("Error checking init status: %s", err)) + return 1 // Normally we'd return 2, but 2 means something special here + } + + if inited { + c.UI.Output("Vault is initialized") + return 0 + } + + c.UI.Output("Vault is not initialized") + return 2 +} + +// machineInit is used to output information about the init command. +type machineInit struct { + UnsealKeysB64 []string `json:"unseal_keys_b64"` + UnsealKeysHex []string `json:"unseal_keys_hex"` + UnsealShares int `json:"unseal_shares"` + UnsealThreshold int `json:"unseal_threshold"` + RecoveryKeysB64 []string `json:"recovery_keys_b64"` + RecoveryKeysHex []string `json:"recovery_keys_hex"` + RecoveryShares int `json:"recovery_keys_shares"` + RecoveryThreshold int `json:"recovery_keys_threshold"` + RootToken string `json:"root_token"` +} + +func newMachineInit(req *api.InitRequest, resp *api.InitResponse) *machineInit { + init := &machineInit{} + + init.UnsealKeysHex = make([]string, len(resp.Keys)) + for i, v := range resp.Keys { + init.UnsealKeysHex[i] = v + } + + init.UnsealKeysB64 = make([]string, len(resp.KeysB64)) + for i, v := range resp.KeysB64 { + init.UnsealKeysB64[i] = v + } + + init.UnsealShares = req.SecretShares + init.UnsealThreshold = req.SecretThreshold + + init.RecoveryKeysHex = make([]string, len(resp.RecoveryKeys)) + for i, v := range resp.RecoveryKeys { + init.RecoveryKeysHex[i] = v + } + + init.RecoveryKeysB64 = make([]string, len(resp.RecoveryKeysB64)) + for i, v := range resp.RecoveryKeysB64 { + init.RecoveryKeysB64[i] = v + } + + init.RecoveryShares = req.RecoveryShares + init.RecoveryThreshold = req.RecoveryThreshold + + init.RootToken = resp.RootToken + + return init +} diff --git a/command/operator_init_test.go b/command/operator_init_test.go new file mode 100644 index 0000000000..f398dd3756 --- /dev/null +++ b/command/operator_init_test.go @@ -0,0 +1,361 @@ +package command + +import ( + "fmt" + "os" + "regexp" + "strconv" + "strings" + "testing" + + "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/helper/pgpkeys" + "github.com/mitchellh/cli" +) + +func testOperatorInitCommand(tb testing.TB) (*cli.MockUi, *OperatorInitCommand) { + tb.Helper() + + ui := cli.NewMockUi() + return ui, &OperatorInitCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + } +} + +func TestOperatorInitCommand_Run(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + args []string + out string + code int + }{ + { + "pgp_keys_multi", + []string{ + "-pgp-keys", "keybase:hashicorp", + "-pgp-keys", "keybase:jefferai", + }, + "can only be specified once", + 1, + }, + { + "root_token_pgp_key_multi", + []string{ + "-root-token-pgp-key", "keybase:hashicorp", + "-root-token-pgp-key", "keybase:jefferai", + }, + "can only be specified once", + 1, + }, + { + "root_token_pgp_key_multi_inline", + []string{ + "-root-token-pgp-key", "keybase:hashicorp,keybase:jefferai", + }, + "can only specify one pgp key", + 1, + }, + { + "recovery_pgp_keys_multi", + []string{ + "-recovery-pgp-keys", "keybase:hashicorp", + "-recovery-pgp-keys", "keybase:jefferai", + }, + "can only be specified once", + 1, + }, + { + "key_shares_pgp_less", + []string{ + "-key-shares", "10", + "-pgp-keys", "keybase:jefferai,keybase:sethvargo", + }, + "incorrect number", + 2, + }, + { + "key_shares_pgp_more", + []string{ + "-key-shares", "1", + "-pgp-keys", "keybase:jefferai,keybase:sethvargo", + }, + "incorrect number", + 2, + }, + } + + t.Run("validations", func(t *testing.T) { + t.Parallel() + + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + ui, cmd := testOperatorInitCommand(t) + cmd.client = client + + code := cmd.Run(tc.args) + if code != tc.code { + t.Errorf("expected %d to be %d", code, tc.code) + } + + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, tc.out) { + t.Errorf("expected %q to contain %q", combined, tc.out) + } + }) + } + }) + + t.Run("status", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServerUninit(t) + defer closer() + + ui, cmd := testOperatorInitCommand(t) + cmd.client = client + + // Verify the non-init response code + code := cmd.Run([]string{ + "-status", + }) + if exp := 2; code != exp { + t.Errorf("expected %d to be %d: %s", code, exp, ui.ErrorWriter.String()) + } + + // Now init to verify the init response code + if _, err := client.Sys().Init(&api.InitRequest{ + SecretShares: 1, + SecretThreshold: 1, + }); err != nil { + t.Fatal(err) + } + + // Verify the init response code + ui, cmd = testOperatorInitCommand(t) + cmd.client = client + code = cmd.Run([]string{ + "-status", + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d: %s", code, exp, ui.ErrorWriter.String()) + } + }) + + t.Run("default", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServerUninit(t) + defer closer() + + ui, cmd := testOperatorInitCommand(t) + cmd.client = client + + code := cmd.Run([]string{}) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d: %s", code, exp, ui.ErrorWriter.String()) + } + + init, err := client.Sys().InitStatus() + if err != nil { + t.Fatal(err) + } + if !init { + t.Error("expected initialized") + } + + re := regexp.MustCompile(`Unseal Key \d+: (.+)`) + output := ui.OutputWriter.String() + match := re.FindAllStringSubmatch(output, -1) + if len(match) < 5 || len(match[0]) < 2 { + t.Fatalf("no match: %#v", match) + } + + keys := make([]string, len(match)) + for i := range match { + keys[i] = match[i][1] + } + + // Try unsealing with those keys - only use 3, which is the default + // threshold. + for i, key := range keys[:3] { + resp, err := client.Sys().Unseal(key) + if err != nil { + t.Fatal(err) + } + + exp := (i + 1) % 3 // 1, 2, 0 + if resp.Progress != exp { + t.Errorf("expected %d to be %d", resp.Progress, exp) + } + } + + status, err := client.Sys().SealStatus() + if err != nil { + t.Fatal(err) + } + if status.Sealed { + t.Errorf("expected vault to be unsealed: %#v", status) + } + }) + + t.Run("custom_shares_threshold", func(t *testing.T) { + t.Parallel() + + keyShares, keyThreshold := 20, 15 + + client, closer := testVaultServerUninit(t) + defer closer() + + ui, cmd := testOperatorInitCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "-key-shares", strconv.Itoa(keyShares), + "-key-threshold", strconv.Itoa(keyThreshold), + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d: %s", code, exp, ui.ErrorWriter.String()) + } + + init, err := client.Sys().InitStatus() + if err != nil { + t.Fatal(err) + } + if !init { + t.Error("expected initialized") + } + + re := regexp.MustCompile(`Unseal Key \d+: (.+)`) + output := ui.OutputWriter.String() + match := re.FindAllStringSubmatch(output, -1) + if len(match) < keyShares || len(match[0]) < 2 { + t.Fatalf("no match: %#v", match) + } + + keys := make([]string, len(match)) + for i := range match { + keys[i] = match[i][1] + } + + // Try unsealing with those keys - only use 3, which is the default + // threshold. + for i, key := range keys[:keyThreshold] { + resp, err := client.Sys().Unseal(key) + if err != nil { + t.Fatal(err) + } + + exp := (i + 1) % keyThreshold + if resp.Progress != exp { + t.Errorf("expected %d to be %d", resp.Progress, exp) + } + } + + status, err := client.Sys().SealStatus() + if err != nil { + t.Fatal(err) + } + if status.Sealed { + t.Errorf("expected vault to be unsealed: %#v", status) + } + }) + + t.Run("pgp", func(t *testing.T) { + t.Parallel() + + tempDir, pubFiles, err := getPubKeyFiles(t) + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tempDir) + + client, closer := testVaultServerUninit(t) + defer closer() + + ui, cmd := testOperatorInitCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "-key-shares", "4", + "-key-threshold", "2", + "-pgp-keys", fmt.Sprintf("%s,@%s, %s, %s ", + pubFiles[0], pubFiles[1], pubFiles[2], pubFiles[3]), + "-root-token-pgp-key", pubFiles[0], + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d: %s", code, exp, ui.ErrorWriter.String()) + } + + re := regexp.MustCompile(`Unseal Key \d+: (.+)`) + output := ui.OutputWriter.String() + match := re.FindAllStringSubmatch(output, -1) + if len(match) < 4 || len(match[0]) < 2 { + t.Fatalf("no match: %#v", match) + } + + keys := make([]string, len(match)) + for i := range match { + keys[i] = match[i][1] + } + + // Try unsealing with one key + decryptedKey := testPGPDecrypt(t, pgpkeys.TestPrivKey1, keys[0]) + if _, err := client.Sys().Unseal(decryptedKey); err != nil { + t.Fatal(err) + } + + // Decrypt the root token + reToken := regexp.MustCompile(`Root Token: (.+)`) + match = reToken.FindAllStringSubmatch(output, -1) + if len(match) < 1 || len(match[0]) < 2 { + t.Fatalf("no match") + } + root := match[0][1] + decryptedRoot := testPGPDecrypt(t, pgpkeys.TestPrivKey1, root) + + if l, exp := len(decryptedRoot), 36; l != exp { + t.Errorf("expected %d to be %d", l, exp) + } + }) + + t.Run("communication_failure", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServerBad(t) + defer closer() + + ui, cmd := testOperatorInitCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "secret/foo", + }) + if exp := 2; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Error initializing: " + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + }) + + t.Run("no_tabs", func(t *testing.T) { + t.Parallel() + + _, cmd := testOperatorInitCommand(t) + assertNoTabs(t, cmd) + }) +} diff --git a/command/operator_key_status.go b/command/operator_key_status.go new file mode 100644 index 0000000000..6558290ca7 --- /dev/null +++ b/command/operator_key_status.go @@ -0,0 +1,74 @@ +package command + +import ( + "fmt" + "strings" + + "github.com/mitchellh/cli" + "github.com/posener/complete" +) + +var _ cli.Command = (*OperatorKeyStatusCommand)(nil) +var _ cli.CommandAutocomplete = (*OperatorKeyStatusCommand)(nil) + +type OperatorKeyStatusCommand struct { + *BaseCommand +} + +func (c *OperatorKeyStatusCommand) Synopsis() string { + return "Provides information about the active encryption key" +} + +func (c *OperatorKeyStatusCommand) Help() string { + helpText := ` +Usage: vault operator key-status [options] + + Provides information about the active encryption key. Specifically, + the current key term and the key installation time. + +` + c.Flags().Help() + + return strings.TrimSpace(helpText) +} + +func (c *OperatorKeyStatusCommand) Flags() *FlagSets { + return c.flagSet(FlagSetHTTP) +} + +func (c *OperatorKeyStatusCommand) AutocompleteArgs() complete.Predictor { + return nil +} + +func (c *OperatorKeyStatusCommand) AutocompleteFlags() complete.Flags { + return c.Flags().Completions() +} + +func (c *OperatorKeyStatusCommand) Run(args []string) int { + f := c.Flags() + + if err := f.Parse(args); err != nil { + c.UI.Error(err.Error()) + return 1 + } + + args = f.Args() + if len(args) > 0 { + c.UI.Error(fmt.Sprintf("Too many arguments (expected 0, got %d)", len(args))) + return 1 + } + + client, err := c.Client() + if err != nil { + c.UI.Error(err.Error()) + return 2 + } + + status, err := client.Sys().KeyStatus() + if err != nil { + c.UI.Error(fmt.Sprintf("Error reading key status: %s", err)) + return 2 + } + + c.UI.Output(printKeyStatus(status)) + return 0 +} diff --git a/command/operator_key_status_test.go b/command/operator_key_status_test.go new file mode 100644 index 0000000000..5c1aada3e6 --- /dev/null +++ b/command/operator_key_status_test.go @@ -0,0 +1,110 @@ +package command + +import ( + "strings" + "testing" + + "github.com/mitchellh/cli" +) + +func testOperatorKeyStatusCommand(tb testing.TB) (*cli.MockUi, *OperatorKeyStatusCommand) { + tb.Helper() + + ui := cli.NewMockUi() + return ui, &OperatorKeyStatusCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + } +} + +func TestOperatorKeyStatusCommand_Run(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + args []string + out string + code int + }{ + { + "too_many_args", + []string{"foo", "bar"}, + "Too many arguments", + 1, + }, + } + + t.Run("validations", func(t *testing.T) { + t.Parallel() + + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + ui, cmd := testOperatorKeyStatusCommand(t) + + code := cmd.Run(tc.args) + if code != tc.code { + t.Errorf("expected %d to be %d", code, tc.code) + } + + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, tc.out) { + t.Errorf("expected %q to contain %q", combined, tc.out) + } + }) + } + }) + + t.Run("integration", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + ui, cmd := testOperatorKeyStatusCommand(t) + cmd.client = client + + code := cmd.Run([]string{}) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Key Term" + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + }) + + t.Run("communication_failure", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServerBad(t) + defer closer() + + ui, cmd := testOperatorKeyStatusCommand(t) + cmd.client = client + + code := cmd.Run([]string{}) + if exp := 2; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Error reading key status: " + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + }) + + t.Run("no_tabs", func(t *testing.T) { + t.Parallel() + + _, cmd := testOperatorKeyStatusCommand(t) + assertNoTabs(t, cmd) + }) +} diff --git a/command/operator_rekey.go b/command/operator_rekey.go new file mode 100644 index 0000000000..595f1f3af9 --- /dev/null +++ b/command/operator_rekey.go @@ -0,0 +1,638 @@ +package command + +import ( + "bytes" + "fmt" + "io" + "os" + "strings" + + "github.com/fatih/structs" + "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/helper/password" + "github.com/hashicorp/vault/helper/pgpkeys" + "github.com/mitchellh/cli" + "github.com/posener/complete" +) + +var _ cli.Command = (*OperatorRekeyCommand)(nil) +var _ cli.CommandAutocomplete = (*OperatorRekeyCommand)(nil) + +type OperatorRekeyCommand struct { + *BaseCommand + + flagCancel bool + flagInit bool + flagKeyShares int + flagKeyThreshold int + flagNonce string + flagPGPKeys []string + flagStatus bool + flagTarget string + + // Backup options + flagBackup bool + flagBackupDelete bool + flagBackupRetrieve bool + + // Deprecations + // TODO: remove in 0.9.0 + flagDelete bool + flagRecoveryKey bool + flagRetrieve bool + + testStdin io.Reader // for tests +} + +func (c *OperatorRekeyCommand) Synopsis() string { + return "Generates new unseal keys" +} + +func (c *OperatorRekeyCommand) Help() string { + helpText := ` +Usage: vault rekey [options] [KEY] + + Generates a new set of unseal keys. This can optionally change the total + number of key shares or the required threshold of those key shares to + reconstruct the master key. This operation is zero downtime, but it requires + the Vault is unsealed and a quorum of existing unseal keys are provided. + + An unseal key may be provided directly on the command line as an argument to + the command. If key is specified as "-", the command will read from stdin. If + a TTY is available, the command will prompt for text. + + Initialize a rekey: + + $ vault operator rekey \ + -init \ + -key-shares=15 \ + -key-threshold=9 + + Rekey and encrypt the resulting unseal keys with PGP: + + $ vault operator rekey \ + -init \ + -key-shares=3 \ + -key-threshold=2 \ + -pgp-keys="keybase:hashicorp,keybase:jefferai,keybase:sethvargo" + + Store encrypted PGP keys in Vault's core: + + $ vault operator rekey \ + -init \ + -pgp-keys="..." \ + -backup + + Retrieve backed-up unseal keys: + + $ vault operator rekey -backup-retrieve + + Delete backed-up unseal keys: + + $ vault operator rekey -backup-delete + +` + c.Flags().Help() + return strings.TrimSpace(helpText) +} + +func (c *OperatorRekeyCommand) Flags() *FlagSets { + set := c.flagSet(FlagSetHTTP) + + f := set.NewFlagSet("Common Options") + + f.BoolVar(&BoolVar{ + Name: "init", + Target: &c.flagInit, + Default: false, + Usage: "Initialize the rekeying operation. This can only be done if no " + + "rekeying operation is in progress. Customize the new number of key " + + "shares and key threshold using the -key-shares and -key-threshold " + + "flags.", + }) + + f.BoolVar(&BoolVar{ + Name: "cancel", + Target: &c.flagCancel, + Default: false, + Usage: "Reset the rekeying progress. This will discard any submitted " + + "unseal keys or configuration.", + }) + + f.BoolVar(&BoolVar{ + Name: "status", + Target: &c.flagStatus, + Default: false, + Usage: "Print the status of the current attempt without providing an " + + "unseal key.", + }) + + f.IntVar(&IntVar{ + Name: "key-shares", + Aliases: []string{"n"}, + Target: &c.flagKeyShares, + Default: 5, + Completion: complete.PredictAnything, + Usage: "Number of key shares to split the generated master key into. " + + "This is the number of \"unseal keys\" to generate.", + }) + + f.IntVar(&IntVar{ + Name: "key-threshold", + Aliases: []string{"t"}, + Target: &c.flagKeyThreshold, + Default: 3, + Completion: complete.PredictAnything, + Usage: "Number of key shares required to reconstruct the master key. " + + "This must be less than or equal to -key-shares.", + }) + + f.StringVar(&StringVar{ + Name: "nonce", + Target: &c.flagNonce, + Default: "", + EnvVar: "", + Completion: complete.PredictAnything, + Usage: "Nonce value provided at initialization. The same nonce value " + + "must be provided with each unseal key.", + }) + + f.StringVar(&StringVar{ + Name: "target", + Target: &c.flagTarget, + Default: "barrier", + EnvVar: "", + Completion: complete.PredictSet("barrier", "recovery"), + Usage: "Target for rekeying. \"recovery\" only applies when HSM support " + + "is enabled.", + }) + + f.VarFlag(&VarFlag{ + Name: "pgp-keys", + Value: (*pgpkeys.PubKeyFilesFlag)(&c.flagPGPKeys), + Completion: complete.PredictAnything, + Usage: "Comma-separated list of paths to files on disk containing " + + "public GPG keys OR a comma-separated list of Keybase usernames using " + + "the format \"keybase:\". When supplied, the generated " + + "unseal keys will be encrypted and base64-encoded in the order " + + "specified in this list.", + }) + + f = set.NewFlagSet("Backup Options") + + f.BoolVar(&BoolVar{ + Name: "backup", + Target: &c.flagBackup, + Default: false, + Usage: "Store a backup of the current PGP encrypted unseal keys in " + + "Vault's core. The encrypted values can be recovered in the event of " + + "failure or discarded after success. See the -backup-delete and " + + "-backup-retrieve options for more information. This option only " + + "applies when the existing unseal keys were PGP encrypted.", + }) + + f.BoolVar(&BoolVar{ + Name: "backup-delete", + Target: &c.flagBackupDelete, + Default: false, + Usage: "Delete any stored backup unseal keys.", + }) + + f.BoolVar(&BoolVar{ + Name: "backup-retrieve", + Target: &c.flagBackupRetrieve, + Default: false, + Usage: "Retrieve the backed-up unseal keys. This option is only available " + + "if the PGP keys were provided and the backup has not been deleted.", + }) + + // Deprecations + // TODO: remove in 0.9.0 + f.BoolVar(&BoolVar{ + Name: "delete", // prefer -backup-delete + Target: &c.flagDelete, + Default: false, + Hidden: true, + Usage: "", + }) + + f.BoolVar(&BoolVar{ + Name: "retrieve", // prefer -backup-retrieve + Target: &c.flagRetrieve, + Default: false, + Hidden: true, + Usage: "", + }) + + f.BoolVar(&BoolVar{ + Name: "recovery-key", // prefer -target=recovery + Target: &c.flagRecoveryKey, + Default: false, + Hidden: true, + Usage: "", + }) + + return set +} + +func (c *OperatorRekeyCommand) AutocompleteArgs() complete.Predictor { + return complete.PredictAnything +} + +func (c *OperatorRekeyCommand) AutocompleteFlags() complete.Flags { + return c.Flags().Completions() +} + +func (c *OperatorRekeyCommand) Run(args []string) int { + f := c.Flags() + + if err := f.Parse(args); err != nil { + c.UI.Error(err.Error()) + return 1 + } + + args = f.Args() + if len(args) > 1 { + c.UI.Error(fmt.Sprintf("Too many arguments (expected 0-1, got %d)", len(args))) + return 1 + } + + // Deprecations + // TODO: remove in 0.9.0 + if c.flagDelete { + c.UI.Warn(wrapAtLength( + "WARNING! The -delete flag is deprecated. Please use -backup-delete " + + "instead. This flag will be removed in the next major release of " + + "Vault.")) + c.flagBackupDelete = true + } + if c.flagRetrieve { + c.UI.Warn(wrapAtLength( + "WARNING! The -retrieve flag is deprecated. Please use -backup-retrieve " + + "instead. This flag will be removed in the next major release of " + + "Vault.")) + c.flagBackupRetrieve = true + } + if c.flagRecoveryKey { + c.UI.Warn(wrapAtLength( + "WARNING! The -recovery-key flag is deprecated. Please use -target=recovery " + + "instead. This flag will be removed in the next major release of " + + "Vault.")) + c.flagTarget = "recovery" + } + + // Create the client + client, err := c.Client() + if err != nil { + c.UI.Error(err.Error()) + return 2 + } + + switch { + case c.flagBackupDelete: + return c.backupDelete(client) + case c.flagBackupRetrieve: + return c.backupRetrieve(client) + case c.flagCancel: + return c.cancel(client) + case c.flagInit: + return c.init(client) + case c.flagStatus: + return c.status(client) + default: + // If there are no other flags, prompt for an unseal key. + key := "" + if len(args) > 0 { + key = strings.TrimSpace(args[0]) + } + return c.provide(client, key) + } +} + +// init starts the rekey process. +func (c *OperatorRekeyCommand) init(client *api.Client) int { + // Handle the different API requests + var fn func(*api.RekeyInitRequest) (*api.RekeyStatusResponse, error) + switch strings.ToLower(strings.TrimSpace(c.flagTarget)) { + case "barrier": + fn = client.Sys().RekeyInit + case "recovery", "hsm": + fn = client.Sys().RekeyRecoveryKeyInit + default: + c.UI.Error(fmt.Sprintf("Unknown target: %s", c.flagTarget)) + return 1 + } + + // Make the request + status, err := fn(&api.RekeyInitRequest{ + SecretShares: c.flagKeyShares, + SecretThreshold: c.flagKeyThreshold, + PGPKeys: c.flagPGPKeys, + Backup: c.flagBackup, + }) + if err != nil { + c.UI.Error(fmt.Sprintf("Error initializing rekey: %s", err)) + return 2 + } + + // Print warnings about recovery, etc. + if len(c.flagPGPKeys) == 0 { + c.UI.Warn(wrapAtLength( + "WARNING! If you lose the keys after they are returned, there is no " + + "recovery. Consider canceling this operation and re-initializing " + + "with the -pgp-keys flag to protect the returned unseal keys along " + + "with -backup to allow recovery of the encrypted keys in case of " + + "emergency. You can delete the stored keys later using the -delete " + + "flag.")) + c.UI.Output("") + } + if len(c.flagPGPKeys) > 0 && !c.flagBackup { + c.UI.Warn(wrapAtLength( + "WARNING! You are using PGP keys for encrypted the resulting unseal " + + "keys, but you did not enable the option to backup the keys to " + + "Vault's core. If you lose the encrypted keys after they are " + + "returned, you will not be able to recover them. Consider canceling " + + "this operation and re-running with -backup to allow recovery of the " + + "encrypted unseal keys in case of emergency. You can delete the " + + "stored keys later using the -delete flag.")) + c.UI.Output("") + } + + // Provide the current status + return c.printStatus(status) +} + +// cancel is used to abort the rekey process. +func (c *OperatorRekeyCommand) cancel(client *api.Client) int { + // Handle the different API requests + var fn func() error + switch strings.ToLower(strings.TrimSpace(c.flagTarget)) { + case "barrier": + fn = client.Sys().RekeyCancel + case "recovery", "hsm": + fn = client.Sys().RekeyRecoveryKeyCancel + default: + c.UI.Error(fmt.Sprintf("Unknown target: %s", c.flagTarget)) + return 1 + } + + // Make the request + if err := fn(); err != nil { + c.UI.Error(fmt.Sprintf("Error canceling rekey: %s", err)) + return 2 + } + + c.UI.Output("Success! Canceled rekeying (if it was started)") + return 0 +} + +// provide prompts the user for the seal key and posts it to the update root +// endpoint. If this is the last unseal, this function outputs it. +func (c *OperatorRekeyCommand) provide(client *api.Client, key string) int { + var statusFn func() (*api.RekeyStatusResponse, error) + var updateFn func(string, string) (*api.RekeyUpdateResponse, error) + + switch strings.ToLower(strings.TrimSpace(c.flagTarget)) { + case "barrier": + statusFn = client.Sys().RekeyStatus + updateFn = client.Sys().RekeyUpdate + case "recovery", "hsm": + statusFn = client.Sys().RekeyRecoveryKeyStatus + updateFn = client.Sys().RekeyRecoveryKeyUpdate + default: + c.UI.Error(fmt.Sprintf("Unknown target: %s", c.flagTarget)) + return 1 + } + + status, err := statusFn() + if err != nil { + c.UI.Error(fmt.Sprintf("Error getting rekey status: %s", err)) + return 2 + } + + // Verify a root token generation is in progress. If there is not one in + // progress, return an error instructing the user to start one. + if !status.Started { + c.UI.Error(wrapAtLength( + "No rekey is in progress. Start a rekey process by running " + + "\"vault rekey -init\".")) + return 1 + } + + var nonce string + + switch key { + case "-": // Read from stdin + nonce = c.flagNonce + + // Pull our fake stdin if needed + stdin := (io.Reader)(os.Stdin) + if c.testStdin != nil { + stdin = c.testStdin + } + + var buf bytes.Buffer + if _, err := io.Copy(&buf, stdin); err != nil { + c.UI.Error(fmt.Sprintf("Failed to read from stdin: %s", err)) + return 1 + } + + key = buf.String() + case "": // Prompt using the tty + // Nonce value is not required if we are prompting via the terminal + nonce = status.Nonce + + w := getWriterFromUI(c.UI) + fmt.Fprintf(w, "Rekey operation nonce: %s\n", nonce) + fmt.Fprintf(w, "Unseal Key (will be hidden): ") + key, err = password.Read(os.Stdin) + fmt.Fprintf(w, "\n") + if err != nil { + if err == password.ErrInterrupted { + c.UI.Error("user canceled") + return 1 + } + + c.UI.Error(wrapAtLength(fmt.Sprintf("An error occurred attempting to "+ + "ask for the unseal key. The raw error message is shown below, but "+ + "usually this is because you attempted to pipe a value into the "+ + "command or you are executing outside of a terminal (tty). If you "+ + "want to pipe the value, pass \"-\" as the argument to read from "+ + "stdin. The raw error was: %s", err))) + return 1 + } + default: // Supplied directly as an arg + nonce = c.flagNonce + } + + // Trim any whitespace from they key, especially since we might have + // prompted the user for it. + key = strings.TrimSpace(key) + + // Verify we have a nonce value + if nonce == "" { + c.UI.Error("Missing nonce value: specify it via the -nonce flag") + return 1 + } + + // Provide the key, this may potentially complete the update + resp, err := updateFn(key, nonce) + if err != nil { + c.UI.Error(fmt.Sprintf("Error posting unseal key: %s", err)) + return 2 + } + + if !resp.Complete { + return c.status(client) + } + + return c.printUnsealKeys(status, resp) +} + +// status is used just to fetch and dump the status. +func (c *OperatorRekeyCommand) status(client *api.Client) int { + // Handle the different API requests + var fn func() (*api.RekeyStatusResponse, error) + switch strings.ToLower(strings.TrimSpace(c.flagTarget)) { + case "barrier": + fn = client.Sys().RekeyStatus + case "recovery", "hsm": + fn = client.Sys().RekeyRecoveryKeyStatus + default: + c.UI.Error(fmt.Sprintf("Unknown target: %s", c.flagTarget)) + return 1 + } + + // Make the request + status, err := fn() + if err != nil { + c.UI.Error(fmt.Sprintf("Error reading rekey status: %s", err)) + return 2 + } + + return c.printStatus(status) +} + +// backupRetrieve retrieves the stored backup keys. +func (c *OperatorRekeyCommand) backupRetrieve(client *api.Client) int { + // Handle the different API requests + var fn func() (*api.RekeyRetrieveResponse, error) + switch strings.ToLower(strings.TrimSpace(c.flagTarget)) { + case "barrier": + fn = client.Sys().RekeyRetrieveBackup + case "recovery", "hsm": + fn = client.Sys().RekeyRetrieveRecoveryBackup + default: + c.UI.Error(fmt.Sprintf("Unknown target: %s", c.flagTarget)) + return 1 + } + + // Make the request + storedKeys, err := fn() + if err != nil { + c.UI.Error(fmt.Sprintf("Error retrieving rekey stored keys: %s", err)) + return 2 + } + + secret := &api.Secret{ + Data: structs.New(storedKeys).Map(), + } + + return OutputSecret(c.UI, "table", secret) +} + +// backupDelete deletes the stored backup keys. +func (c *OperatorRekeyCommand) backupDelete(client *api.Client) int { + // Handle the different API requests + var fn func() error + switch strings.ToLower(strings.TrimSpace(c.flagTarget)) { + case "barrier": + fn = client.Sys().RekeyDeleteBackup + case "recovery", "hsm": + fn = client.Sys().RekeyDeleteRecoveryBackup + default: + c.UI.Error(fmt.Sprintf("Unknown target: %s", c.flagTarget)) + return 1 + } + + // Make the request + if err := fn(); err != nil { + c.UI.Error(fmt.Sprintf("Error deleting rekey stored keys: %s", err)) + return 2 + } + + c.UI.Output("Success! Delete stored keys (if they existed)") + return 0 +} + +// printStatus dumps the status to output +func (c *OperatorRekeyCommand) printStatus(status *api.RekeyStatusResponse) int { + out := []string{} + out = append(out, "Key | Value") + out = append(out, fmt.Sprintf("Nonce | %s", status.Nonce)) + out = append(out, fmt.Sprintf("Started | %t", status.Started)) + + if status.Started { + out = append(out, fmt.Sprintf("Rekey Progress | %d/%d", status.Progress, status.Required)) + out = append(out, fmt.Sprintf("New Shares | %d", status.N)) + out = append(out, fmt.Sprintf("New Threshold | %d", status.T)) + } + + if len(status.PGPFingerprints) > 0 { + out = append(out, fmt.Sprintf("PGP Fingerprints | %s", status.PGPFingerprints)) + out = append(out, fmt.Sprintf("Backup | %t", status.Backup)) + } + + c.UI.Output(tableOutput(out, nil)) + return 0 +} + +func (c *OperatorRekeyCommand) printUnsealKeys(status *api.RekeyStatusResponse, resp *api.RekeyUpdateResponse) int { + // Space between the key prompt, if any, and the output + c.UI.Output("") + + // Provide the keys + var haveB64 bool + if resp.KeysB64 != nil && len(resp.KeysB64) == len(resp.Keys) { + haveB64 = true + } + for i, key := range resp.Keys { + if len(resp.PGPFingerprints) > 0 { + if haveB64 { + c.UI.Output(fmt.Sprintf("Key %d fingerprint: %s; value: %s", i+1, resp.PGPFingerprints[i], resp.KeysB64[i])) + } else { + c.UI.Output(fmt.Sprintf("Key %d fingerprint: %s; value: %s", i+1, resp.PGPFingerprints[i], key)) + } + } else { + if haveB64 { + c.UI.Output(fmt.Sprintf("Key %d: %s", i+1, resp.KeysB64[i])) + } else { + c.UI.Output(fmt.Sprintf("Key %d: %s", i+1, key)) + } + } + } + + c.UI.Output("") + c.UI.Output(fmt.Sprintf("Operation nonce: %s", resp.Nonce)) + + if len(resp.PGPFingerprints) > 0 && resp.Backup { + c.UI.Output("") + c.UI.Output(wrapAtLength(fmt.Sprintf( + "The encrypted unseal keys are backed up to \"core/unseal-keys-backup\"" + + "in the storage backend. Remove these keys at any time using " + + "\"vault rekey -delete-backup\". Vault does not automatically remove " + + "these keys.", + ))) + } + + c.UI.Output("") + c.UI.Output(wrapAtLength(fmt.Sprintf( + "Vault rekeyed with %d key shares an a key threshold of %d. Please "+ + "securely distributed the key shares printed above. When the Vault is "+ + "re-sealed, restarted, or stopped, you must supply at least %d of "+ + "these keys to unseal it before it can start servicing requests.", + status.N, + status.T, + status.T))) + + return 0 +} diff --git a/command/operator_rekey_test.go b/command/operator_rekey_test.go new file mode 100644 index 0000000000..47154c73ec --- /dev/null +++ b/command/operator_rekey_test.go @@ -0,0 +1,515 @@ +package command + +import ( + "io" + "reflect" + "regexp" + "strings" + "testing" + + "github.com/hashicorp/vault/api" + "github.com/mitchellh/cli" +) + +func testOperatorRekeyCommand(tb testing.TB) (*cli.MockUi, *OperatorRekeyCommand) { + tb.Helper() + + ui := cli.NewMockUi() + return ui, &OperatorRekeyCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + } +} + +func TestOperatorRekeyCommand_Run(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + args []string + out string + code int + }{ + { + "pgp_keys_multi", + []string{ + "-init", + "-pgp-keys", "keybase:hashicorp", + "-pgp-keys", "keybase:jefferai", + }, + "can only be specified once", + 1, + }, + { + "key_shares_pgp_less", + []string{ + "-init", + "-key-shares", "10", + "-pgp-keys", "keybase:jefferai,keybase:sethvargo", + }, + "incorrect number", + 2, + }, + { + "key_shares_pgp_more", + []string{ + "-init", + "-key-shares", "1", + "-pgp-keys", "keybase:jefferai,keybase:sethvargo", + }, + "incorrect number", + 2, + }, + } + + t.Run("validations", func(t *testing.T) { + t.Parallel() + + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + ui, cmd := testOperatorRekeyCommand(t) + cmd.client = client + + code := cmd.Run(tc.args) + if code != tc.code { + t.Errorf("expected %d to be %d", code, tc.code) + } + + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, tc.out) { + t.Errorf("expected %q to contain %q", combined, tc.out) + } + }) + } + }) + + t.Run("status", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + ui, cmd := testOperatorRekeyCommand(t) + cmd.client = client + + // Verify the non-init response + code := cmd.Run([]string{ + "-status", + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d: %s", code, exp, ui.ErrorWriter.String()) + } + + expected := "Nonce" + combined := ui.OutputWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + + // Now init to verify the init response + if _, err := client.Sys().RekeyInit(&api.RekeyInitRequest{ + SecretShares: 1, + SecretThreshold: 1, + }); err != nil { + t.Fatal(err) + } + + // Verify the init response + ui, cmd = testOperatorRekeyCommand(t) + cmd.client = client + code = cmd.Run([]string{ + "-status", + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d: %s", code, exp, ui.ErrorWriter.String()) + } + + expected = "Progress" + combined = ui.OutputWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + }) + + t.Run("cancel", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + // Initialize a rekey + if _, err := client.Sys().RekeyInit(&api.RekeyInitRequest{ + SecretShares: 1, + SecretThreshold: 1, + }); err != nil { + t.Fatal(err) + } + + ui, cmd := testOperatorRekeyCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "-cancel", + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Success! Canceled rekeying" + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + + status, err := client.Sys().GenerateRootStatus() + if err != nil { + t.Fatal(err) + } + + if status.Started { + t.Errorf("expected status to be canceled: %#v", status) + } + }) + + t.Run("init", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + ui, cmd := testOperatorRekeyCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "-init", + "-key-shares", "1", + "-key-threshold", "1", + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d: %s", code, exp, ui.ErrorWriter.String()) + } + + expected := "Nonce" + combined := ui.OutputWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + + status, err := client.Sys().RekeyStatus() + if err != nil { + t.Fatal(err) + } + if !status.Started { + t.Errorf("expected status to be started: %#v", status) + } + }) + + t.Run("init_pgp", func(t *testing.T) { + t.Parallel() + + pgpKey := "keybase:hashicorp" + pgpFingerprints := []string{"91a6e7f85d05c65630bef18951852d87348ffc4c"} + + client, closer := testVaultServer(t) + defer closer() + + ui, cmd := testOperatorRekeyCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "-init", + "-key-shares", "1", + "-key-threshold", "1", + "-pgp-keys", pgpKey, + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d: %s", code, exp, ui.ErrorWriter.String()) + } + + expected := "Nonce" + combined := ui.OutputWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + + status, err := client.Sys().RekeyStatus() + if err != nil { + t.Fatal(err) + } + if !status.Started { + t.Errorf("expected status to be started: %#v", status) + } + if !reflect.DeepEqual(status.PGPFingerprints, pgpFingerprints) { + t.Errorf("expected %#v to be %#v", status.PGPFingerprints, pgpFingerprints) + } + }) + + t.Run("provide_arg", func(t *testing.T) { + t.Parallel() + + client, keys, closer := testVaultServerUnseal(t) + defer closer() + + // Initialize a rekey + status, err := client.Sys().RekeyInit(&api.RekeyInitRequest{ + SecretShares: 1, + SecretThreshold: 1, + }) + if err != nil { + t.Fatal(err) + } + nonce := status.Nonce + + // Supply the first n-1 unseal keys + for _, key := range keys[:len(keys)-1] { + ui, cmd := testOperatorRekeyCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "-nonce", nonce, + key, + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d: %s", code, exp, ui.ErrorWriter.String()) + } + } + + ui, cmd := testOperatorRekeyCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "-nonce", nonce, + keys[len(keys)-1], // the last unseal key + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d: %s", code, exp, ui.ErrorWriter.String()) + } + + re := regexp.MustCompile(`Key 1: (.+)`) + output := ui.OutputWriter.String() + match := re.FindAllStringSubmatch(output, -1) + if len(match) < 1 || len(match[0]) < 2 { + t.Fatalf("bad match: %#v", match) + } + + // Grab the unseal key and try to unseal + unsealKey := match[0][1] + if err := client.Sys().Seal(); err != nil { + t.Fatal(err) + } + sealStatus, err := client.Sys().Unseal(unsealKey) + if err != nil { + t.Fatal(err) + } + if sealStatus.Sealed { + t.Errorf("expected vault to be unsealed: %#v", sealStatus) + } + }) + + t.Run("provide_stdin", func(t *testing.T) { + t.Parallel() + + client, keys, closer := testVaultServerUnseal(t) + defer closer() + + // Initialize a rekey + status, err := client.Sys().RekeyInit(&api.RekeyInitRequest{ + SecretShares: 1, + SecretThreshold: 1, + }) + if err != nil { + t.Fatal(err) + } + nonce := status.Nonce + + // Supply the first n-1 unseal keys + for _, key := range keys[:len(keys)-1] { + stdinR, stdinW := io.Pipe() + go func() { + stdinW.Write([]byte(key)) + stdinW.Close() + }() + + ui, cmd := testOperatorRekeyCommand(t) + cmd.client = client + cmd.testStdin = stdinR + + code := cmd.Run([]string{ + "-nonce", nonce, + "-", + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d: %s", code, exp, ui.ErrorWriter.String()) + } + } + + stdinR, stdinW := io.Pipe() + go func() { + stdinW.Write([]byte(keys[len(keys)-1])) // the last unseal key + stdinW.Close() + }() + + ui, cmd := testOperatorRekeyCommand(t) + cmd.client = client + cmd.testStdin = stdinR + + code := cmd.Run([]string{ + "-nonce", nonce, + "-", + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + re := regexp.MustCompile(`Key 1: (.+)`) + output := ui.OutputWriter.String() + match := re.FindAllStringSubmatch(output, -1) + if len(match) < 1 || len(match[0]) < 2 { + t.Fatalf("bad match: %#v", match) + } + + // Grab the unseal key and try to unseal + unsealKey := match[0][1] + if err := client.Sys().Seal(); err != nil { + t.Fatal(err) + } + sealStatus, err := client.Sys().Unseal(unsealKey) + if err != nil { + t.Fatal(err) + } + if sealStatus.Sealed { + t.Errorf("expected vault to be unsealed: %#v", sealStatus) + } + }) + + t.Run("backup", func(t *testing.T) { + t.Parallel() + + pgpKey := "keybase:hashicorp" + // pgpFingerprints := []string{"91a6e7f85d05c65630bef18951852d87348ffc4c"} + + client, keys, closer := testVaultServerUnseal(t) + defer closer() + + ui, cmd := testOperatorRekeyCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "-init", + "-key-shares", "1", + "-key-threshold", "1", + "-pgp-keys", pgpKey, + "-backup", + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d: %s", code, exp, ui.ErrorWriter.String()) + } + + // Get the status for the nonce + status, err := client.Sys().RekeyStatus() + if err != nil { + t.Fatal(err) + } + nonce := status.Nonce + + var combined string + // Supply the unseal keys + for _, key := range keys { + ui, cmd := testOperatorRekeyCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "-nonce", nonce, + key, + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d: %s", code, exp, ui.ErrorWriter.String()) + } + + // Append to our output string + combined += ui.OutputWriter.String() + } + + re := regexp.MustCompile(`Key 1 fingerprint: (.+); value: (.+)`) + match := re.FindAllStringSubmatch(combined, -1) + if len(match) < 1 || len(match[0]) < 3 { + t.Fatalf("bad match: %#v", match) + } + + // Grab the output fingerprint and encrypted key + fingerprint, encryptedKey := match[0][1], match[0][2] + + // Get the backup + ui, cmd = testOperatorRekeyCommand(t) + cmd.client = client + + code = cmd.Run([]string{ + "-backup-retrieve", + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d: %s", code, exp, ui.ErrorWriter.String()) + } + + output := ui.OutputWriter.String() + if !strings.Contains(output, fingerprint) { + t.Errorf("expected %q to contain %q", output, fingerprint) + } + if !strings.Contains(output, encryptedKey) { + t.Errorf("expected %q to contain %q", output, encryptedKey) + } + + // Delete the backup + ui, cmd = testOperatorRekeyCommand(t) + cmd.client = client + + code = cmd.Run([]string{ + "-backup-delete", + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d: %s", code, exp, ui.ErrorWriter.String()) + } + + secret, err := client.Sys().RekeyRetrieveBackup() + if err == nil { + t.Errorf("expected error: %#v", secret) + } + }) + + t.Run("communication_failure", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServerBad(t) + defer closer() + + ui, cmd := testOperatorRekeyCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "secret/foo", + }) + if exp := 2; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Error getting rekey status: " + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + }) + + t.Run("no_tabs", func(t *testing.T) { + t.Parallel() + + _, cmd := testOperatorRekeyCommand(t) + assertNoTabs(t, cmd) + }) +} diff --git a/command/operator_seal.go b/command/operator_seal.go new file mode 100644 index 0000000000..b5fe3b8eb7 --- /dev/null +++ b/command/operator_seal.go @@ -0,0 +1,84 @@ +package command + +import ( + "fmt" + "strings" + + "github.com/mitchellh/cli" + "github.com/posener/complete" +) + +var _ cli.Command = (*OperatorSealCommand)(nil) +var _ cli.CommandAutocomplete = (*OperatorSealCommand)(nil) + +type OperatorSealCommand struct { + *BaseCommand +} + +func (c *OperatorSealCommand) Synopsis() string { + return "Seals the Vault server" +} + +func (c *OperatorSealCommand) Help() string { + helpText := ` +Usage: vault seal [options] + + Seals the Vault server. Sealing tells the Vault server to stop responding + to any operations until it is unsealed. When sealed, the Vault server + discards its in-memory master key to unlock the data, so it is physically + blocked from responding to operations unsealed. + + If an unseal is in progress, sealing the Vault will reset the unsealing + process. Users will have to re-enter their portions of the master key again. + + This command does nothing if the Vault server is already sealed. + + Seal the Vault server: + + $ vault operator seal + +` + c.Flags().Help() + + return strings.TrimSpace(helpText) +} + +func (c *OperatorSealCommand) Flags() *FlagSets { + return c.flagSet(FlagSetHTTP) +} + +func (c *OperatorSealCommand) AutocompleteArgs() complete.Predictor { + return nil +} + +func (c *OperatorSealCommand) AutocompleteFlags() complete.Flags { + return c.Flags().Completions() +} + +func (c *OperatorSealCommand) Run(args []string) int { + f := c.Flags() + + if err := f.Parse(args); err != nil { + c.UI.Error(err.Error()) + return 1 + } + + args = f.Args() + if len(args) > 0 { + c.UI.Error(fmt.Sprintf("Too many arguments (expected 0, got %d)", len(args))) + return 1 + } + + client, err := c.Client() + if err != nil { + c.UI.Error(err.Error()) + return 2 + } + + if err := client.Sys().Seal(); err != nil { + c.UI.Error(fmt.Sprintf("Error sealing: %s", err)) + return 2 + } + + c.UI.Output("Success! Vault is sealed.") + return 0 +} diff --git a/command/operator_seal_test.go b/command/operator_seal_test.go new file mode 100644 index 0000000000..86722d2e84 --- /dev/null +++ b/command/operator_seal_test.go @@ -0,0 +1,122 @@ +package command + +import ( + "strings" + "testing" + + "github.com/mitchellh/cli" +) + +func testOperatorSealCommand(tb testing.TB) (*cli.MockUi, *OperatorSealCommand) { + tb.Helper() + + ui := cli.NewMockUi() + return ui, &OperatorSealCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + } +} + +func TestOperatorSealCommand_Run(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + args []string + out string + code int + }{ + { + "args", + []string{"foo"}, + "Too many arguments", + 1, + }, + } + + t.Run("validations", func(t *testing.T) { + t.Parallel() + + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + ui, cmd := testOperatorSealCommand(t) + cmd.client = client + + code := cmd.Run(tc.args) + if code != tc.code { + t.Errorf("expected %d to be %d", code, tc.code) + } + + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, tc.out) { + t.Errorf("expected %q to contain %q", combined, tc.out) + } + }) + } + }) + + t.Run("integration", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + ui, cmd := testOperatorSealCommand(t) + cmd.client = client + + code := cmd.Run([]string{}) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Success! Vault is sealed." + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + + sealStatus, err := client.Sys().SealStatus() + if err != nil { + t.Fatal(err) + } + if !sealStatus.Sealed { + t.Errorf("expected to be sealed") + } + }) + + t.Run("communication_failure", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServerBad(t) + defer closer() + + ui, cmd := testOperatorSealCommand(t) + cmd.client = client + + code := cmd.Run([]string{}) + if exp := 2; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Error sealing: " + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + }) + + t.Run("no_tabs", func(t *testing.T) { + t.Parallel() + + _, cmd := testOperatorSealCommand(t) + assertNoTabs(t, cmd) + }) +} diff --git a/command/operator_step_down.go b/command/operator_step_down.go new file mode 100644 index 0000000000..63208faf04 --- /dev/null +++ b/command/operator_step_down.go @@ -0,0 +1,80 @@ +package command + +import ( + "fmt" + "strings" + + "github.com/mitchellh/cli" + "github.com/posener/complete" +) + +var _ cli.Command = (*OperatorStepDownCommand)(nil) +var _ cli.CommandAutocomplete = (*OperatorStepDownCommand)(nil) + +type OperatorStepDownCommand struct { + *BaseCommand +} + +func (c *OperatorStepDownCommand) Synopsis() string { + return "Forces Vault to resign active duty" +} + +func (c *OperatorStepDownCommand) Help() string { + helpText := ` +Usage: vault operator step-down [options] + + Forces the Vault server at the given address to step down from active duty. + While the affected node will have a delay before attempting to acquire the + leader lock again, if no other Vault nodes acquire the lock beforehand, it + is possible for the same node to re-acquire the lock and become active + again. + + Force Vault to step down as the leader: + + $ vault operator step-down + +` + c.Flags().Help() + + return strings.TrimSpace(helpText) +} + +func (c *OperatorStepDownCommand) Flags() *FlagSets { + return c.flagSet(FlagSetHTTP) +} + +func (c *OperatorStepDownCommand) AutocompleteArgs() complete.Predictor { + return nil +} + +func (c *OperatorStepDownCommand) AutocompleteFlags() complete.Flags { + return c.Flags().Completions() +} + +func (c *OperatorStepDownCommand) Run(args []string) int { + f := c.Flags() + + if err := f.Parse(args); err != nil { + c.UI.Error(err.Error()) + return 1 + } + + args = f.Args() + if len(args) > 0 { + c.UI.Error(fmt.Sprintf("Too many arguments (expected 0, got %d)", len(args))) + return 1 + } + + client, err := c.Client() + if err != nil { + c.UI.Error(err.Error()) + return 2 + } + + if err := client.Sys().StepDown(); err != nil { + c.UI.Error(fmt.Sprintf("Error stepping down: %s", err)) + return 2 + } + + c.UI.Output(fmt.Sprintf("Success! Stepped down: %s", client.Address())) + return 0 +} diff --git a/command/operator_step_down_test.go b/command/operator_step_down_test.go new file mode 100644 index 0000000000..93117a856b --- /dev/null +++ b/command/operator_step_down_test.go @@ -0,0 +1,99 @@ +package command + +import ( + "strings" + "testing" + + "github.com/mitchellh/cli" +) + +func testOperatorStepDownCommand(tb testing.TB) (*cli.MockUi, *OperatorStepDownCommand) { + tb.Helper() + + ui := cli.NewMockUi() + return ui, &OperatorStepDownCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + } +} + +func TestOperatorStepDownCommand_Run(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + args []string + out string + code int + }{ + { + "too_many_args", + []string{"foo"}, + "Too many arguments", + 1, + }, + { + "default", + nil, + "Success! Stepped down: ", + 0, + }, + } + + t.Run("validations", func(t *testing.T) { + t.Parallel() + + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + ui, cmd := testOperatorStepDownCommand(t) + cmd.client = client + + code := cmd.Run(tc.args) + if code != tc.code { + t.Errorf("expected %d to be %d", code, tc.code) + } + + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, tc.out) { + t.Errorf("expected %q to contain %q", combined, tc.out) + } + }) + } + }) + + t.Run("communication_failure", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServerBad(t) + defer closer() + + ui, cmd := testOperatorStepDownCommand(t) + cmd.client = client + + code := cmd.Run([]string{}) + if exp := 2; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Error stepping down: " + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + }) + + t.Run("no_tabs", func(t *testing.T) { + t.Parallel() + + _, cmd := testOperatorStepDownCommand(t) + assertNoTabs(t, cmd) + }) +} diff --git a/command/operator_unseal.go b/command/operator_unseal.go new file mode 100644 index 0000000000..e2957647d1 --- /dev/null +++ b/command/operator_unseal.go @@ -0,0 +1,145 @@ +package command + +import ( + "fmt" + "io" + "os" + "strings" + + "github.com/hashicorp/vault/helper/password" + "github.com/mitchellh/cli" + "github.com/posener/complete" +) + +var _ cli.Command = (*OperatorUnsealCommand)(nil) +var _ cli.CommandAutocomplete = (*OperatorUnsealCommand)(nil) + +type OperatorUnsealCommand struct { + *BaseCommand + + flagReset bool + + testOutput io.Writer // for tests +} + +func (c *OperatorUnsealCommand) Synopsis() string { + return "Unseals the Vault server" +} + +func (c *OperatorUnsealCommand) Help() string { + helpText := ` +Usage: vault operator unseal [options] [KEY] + + Provide a portion of the master key to unseal a Vault server. Vault starts + in a sealed state. It cannot perform operations until it is unsealed. This + command accepts a portion of the master key (an "unseal key"). + + The unseal key can be supplied as an argument to the command, but this is + not recommended as the unseal key will be available in your history: + + $ vault operator unseal IXyR0OJnSFobekZMMCKCoVEpT7wI6l+USMzE3IcyDyo= + + Instead, run the command with no arguments and it will prompt for the key: + + $ vault operator unseal + Key (will be hidden): IXyR0OJnSFobekZMMCKCoVEpT7wI6l+USMzE3IcyDyo= + +` + c.Flags().Help() + + return strings.TrimSpace(helpText) +} + +func (c *OperatorUnsealCommand) Flags() *FlagSets { + set := c.flagSet(FlagSetHTTP) + + f := set.NewFlagSet("Command Options") + + f.BoolVar(&BoolVar{ + Name: "reset", + Aliases: []string{}, + Target: &c.flagReset, + Default: false, + EnvVar: "", + Completion: complete.PredictNothing, + Usage: "Discard any previously entered keys to the unseal process.", + }) + + return set +} + +func (c *OperatorUnsealCommand) AutocompleteArgs() complete.Predictor { + return complete.PredictAnything +} + +func (c *OperatorUnsealCommand) AutocompleteFlags() complete.Flags { + return c.Flags().Completions() +} + +func (c *OperatorUnsealCommand) Run(args []string) int { + f := c.Flags() + + if err := f.Parse(args); err != nil { + c.UI.Error(err.Error()) + return 1 + } + + unsealKey := "" + + args = f.Args() + switch len(args) { + case 0: + // We will prompt for the unsealKey later + case 1: + unsealKey = strings.TrimSpace(args[0]) + default: + c.UI.Error(fmt.Sprintf("Too many arguments (expected 1, got %d)", len(args))) + return 1 + } + + client, err := c.Client() + if err != nil { + c.UI.Error(err.Error()) + return 2 + } + + if c.flagReset { + status, err := client.Sys().ResetUnsealProcess() + if err != nil { + c.UI.Error(fmt.Sprintf("Error resetting unseal process: %s", err)) + return 2 + } + return OutputSealStatus(c.UI, client, status) + } + + if unsealKey == "" { + // Override the output + writer := (io.Writer)(os.Stdout) + if c.testOutput != nil { + writer = c.testOutput + } + + fmt.Fprintf(writer, "Unseal Key (will be hidden): ") + value, err := password.Read(os.Stdin) + fmt.Fprintf(writer, "\n") + if err != nil { + c.UI.Error(wrapAtLength(fmt.Sprintf("An error occurred attempting to "+ + "ask for an unseal key. The raw error message is shown below, but "+ + "usually this is because you attempted to pipe a value into the "+ + "unseal command or you are executing outside of a terminal (tty). "+ + "You should run the unseal command from a terminal for maximum "+ + "security. If this is not an option, the unseal can be provided as "+ + "the first argument to the unseal command. The raw error "+ + "was:\n\n%s", err))) + return 1 + } + unsealKey = strings.TrimSpace(value) + } + + status, err := client.Sys().Unseal(unsealKey) + if err != nil { + c.UI.Error(fmt.Sprintf("Error unsealing: %s", err)) + return 2 + } + + return OutputSealStatus(c.UI, client, status) +} diff --git a/command/operator_unseal_test.go b/command/operator_unseal_test.go new file mode 100644 index 0000000000..e2222fc73b --- /dev/null +++ b/command/operator_unseal_test.go @@ -0,0 +1,140 @@ +package command + +import ( + "io/ioutil" + "strings" + "testing" + + "github.com/mitchellh/cli" +) + +func testOperatorUnsealCommand(tb testing.TB) (*cli.MockUi, *OperatorUnsealCommand) { + tb.Helper() + + ui := cli.NewMockUi() + return ui, &OperatorUnsealCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + } +} + +func TestOperatorUnsealCommand_Run(t *testing.T) { + t.Parallel() + + t.Run("error_non_terminal", func(t *testing.T) { + t.Parallel() + + ui, cmd := testOperatorUnsealCommand(t) + cmd.testOutput = ioutil.Discard + + code := cmd.Run(nil) + if exp := 1; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "is not a terminal" + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + }) + + t.Run("reset", func(t *testing.T) { + t.Parallel() + + client, keys, closer := testVaultServerUnseal(t) + defer closer() + + // Seal so we can unseal + if err := client.Sys().Seal(); err != nil { + t.Fatal(err) + } + + // Enter an unseal key + if _, err := client.Sys().Unseal(keys[0]); err != nil { + t.Fatal(err) + } + + ui, cmd := testOperatorUnsealCommand(t) + cmd.client = client + cmd.testOutput = ioutil.Discard + + // Reset and check output + code := cmd.Run([]string{ + "-reset", + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + expected := "0/3" + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + }) + + t.Run("full", func(t *testing.T) { + t.Parallel() + + client, keys, closer := testVaultServerUnseal(t) + defer closer() + + // Seal so we can unseal + if err := client.Sys().Seal(); err != nil { + t.Fatal(err) + } + + for _, key := range keys { + ui, cmd := testOperatorUnsealCommand(t) + cmd.client = client + cmd.testOutput = ioutil.Discard + + // Reset and check output + code := cmd.Run([]string{ + key, + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d: %s", code, exp, ui.ErrorWriter.String()) + } + } + + status, err := client.Sys().SealStatus() + if err != nil { + t.Fatal(err) + } + if status.Sealed { + t.Error("expected unsealed") + } + }) + + t.Run("communication_failure", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServerBad(t) + defer closer() + + ui, cmd := testOperatorUnsealCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "abcd", + }) + if exp := 2; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Error unsealing: " + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + }) + + t.Run("no_tabs", func(t *testing.T) { + t.Parallel() + + _, cmd := testOperatorUnsealCommand(t) + assertNoTabs(t, cmd) + }) +} diff --git a/command/path_help.go b/command/path_help.go index 6eed9607d8..2ce4a38bfd 100644 --- a/command/path_help.go +++ b/command/path_help.go @@ -4,73 +4,99 @@ import ( "fmt" "strings" - "github.com/hashicorp/vault/meta" + "github.com/mitchellh/cli" + "github.com/posener/complete" ) -// PathHelpCommand is a Command that lists the mounts. +var _ cli.Command = (*PathHelpCommand)(nil) +var _ cli.CommandAutocomplete = (*PathHelpCommand)(nil) + +var pathHelpVaultSealedMessage = strings.TrimSpace(` +Error: Vault is sealed. + +The "path-help" command requires the Vault to be unsealed so that the mount +points of the secret engines are known. +`) + type PathHelpCommand struct { - meta.Meta -} - -func (c *PathHelpCommand) Run(args []string) int { - flags := c.Meta.FlagSet("help", meta.FlagSetDefault) - flags.Usage = func() { c.Ui.Error(c.Help()) } - if err := flags.Parse(args); err != nil { - return 1 - } - - args = flags.Args() - if len(args) != 1 { - flags.Usage() - c.Ui.Error("\nhelp expects a single argument") - return 1 - } - - path := args[0] - - client, err := c.Client() - if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error initializing client: %s", err)) - return 2 - } - - help, err := client.Help(path) - if err != nil { - if strings.Contains(err.Error(), "Vault is sealed") { - c.Ui.Error(`Error: Vault is sealed. - -The path-help command requires the vault to be unsealed so that -mount points of secret backends are known.`) - } else { - c.Ui.Error(fmt.Sprintf( - "Error reading help: %s", err)) - } - return 1 - } - - c.Ui.Output(help.Help) - return 0 + *BaseCommand } func (c *PathHelpCommand) Synopsis() string { - return "Look up the help for a path" + return "Retrieve API help for paths" } func (c *PathHelpCommand) Help() string { helpText := ` -Usage: vault path-help [options] path +Usage: vault path-help [options] PATH - Look up the help for a path. + Retrieves API help for paths. All endpoints in Vault provide built-in help + in markdown format. This includes system paths, secret engines, and auth + methods. - All endpoints in Vault from system paths, secret paths, and credential - providers provide built-in help. This command looks up and outputs that - help. + Get help for the thing mounted at database/: - The command requires that the vault be unsealed, because otherwise - the mount points of the backends are unknown. + $ vault path-help database/ + + The response object will return additional paths to retrieve help: + + $ vault path-help database/roles/ + + Each secret engine produces different help output. + +` + c.Flags().Help() -General Options: -` + meta.GeneralOptionsUsage() return strings.TrimSpace(helpText) } + +func (c *PathHelpCommand) Flags() *FlagSets { + return c.flagSet(FlagSetHTTP) +} + +func (c *PathHelpCommand) AutocompleteArgs() complete.Predictor { + return complete.PredictAnything // TODO: programatic way to invoke help +} + +func (c *PathHelpCommand) AutocompleteFlags() complete.Flags { + return c.Flags().Completions() +} + +func (c *PathHelpCommand) Run(args []string) int { + f := c.Flags() + + if err := f.Parse(args); err != nil { + c.UI.Error(err.Error()) + return 1 + } + + args = f.Args() + switch { + case len(args) < 1: + c.UI.Error(fmt.Sprintf("Not enough arguments (expected 1, got %d)", len(args))) + return 1 + case len(args) > 1: + c.UI.Error(fmt.Sprintf("Too many arguments (expected 1, got %d)", len(args))) + return 1 + } + + client, err := c.Client() + if err != nil { + c.UI.Error(err.Error()) + return 2 + } + + path := sanitizePath(args[0]) + + help, err := client.Help(path) + if err != nil { + if strings.Contains(err.Error(), "Vault is sealed") { + c.UI.Error(pathHelpVaultSealedMessage) + } else { + c.UI.Error(fmt.Sprintf("Error retrieving help: %s", err)) + } + return 2 + } + + c.UI.Output(help.Help) + return 0 +} diff --git a/command/path_help_test.go b/command/path_help_test.go index 46219ba5dc..688bcf09ce 100644 --- a/command/path_help_test.go +++ b/command/path_help_test.go @@ -1,32 +1,115 @@ package command import ( + "strings" "testing" - "github.com/hashicorp/vault/http" - "github.com/hashicorp/vault/meta" - "github.com/hashicorp/vault/vault" "github.com/mitchellh/cli" ) -func TestHelp(t *testing.T) { - core, _, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() +func testPathHelpCommand(tb testing.TB) (*cli.MockUi, *PathHelpCommand) { + tb.Helper() - ui := new(cli.MockUi) - c := &PathHelpCommand{ - Meta: meta.Meta{ - ClientToken: token, - Ui: ui, + ui := cli.NewMockUi() + return ui, &PathHelpCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + } +} + +func TestPathHelpCommand_Run(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + args []string + out string + code int + }{ + { + "not_enough_args", + []string{}, + "Not enough arguments", + 1, + }, + { + "too_many_args", + []string{"foo", "bar"}, + "Too many arguments", + 1, + }, + { + "not_found", + []string{"nope/not/once/never"}, + "", + 2, + }, + { + "kv", + []string{"secret/"}, + "The kv backend", + 0, + }, + { + "sys", + []string{"sys/mounts"}, + "currently mounted backends", + 0, }, } - args := []string{ - "-address", addr, - "sys/mounts", - } - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + ui, cmd := testPathHelpCommand(t) + cmd.client = client + + code := cmd.Run(tc.args) + if code != tc.code { + t.Errorf("expected %d to be %d", code, tc.code) + } + + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, tc.out) { + t.Errorf("expected %q to contain %q", combined, tc.out) + } + }) } + + t.Run("communication_failure", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServerBad(t) + defer closer() + + ui, cmd := testPathHelpCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "sys/mounts", + }) + if exp := 2; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Error retrieving help: " + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + }) + + t.Run("no_tabs", func(t *testing.T) { + t.Parallel() + + _, cmd := testPathHelpCommand(t) + assertNoTabs(t, cmd) + }) } diff --git a/command/pgp_test.go b/command/pgp_test.go index c368e31133..4cfda985b7 100644 --- a/command/pgp_test.go +++ b/command/pgp_test.go @@ -62,6 +62,35 @@ func getPubKeyFiles(t *testing.T) (string, []string, error) { return tempDir, pubFiles, nil } +func testPGPDecrypt(tb testing.TB, privKey, enc string) string { + tb.Helper() + + privKeyBytes, err := base64.StdEncoding.DecodeString(privKey) + if err != nil { + tb.Fatal(err) + } + + ptBuf := bytes.NewBuffer(nil) + entity, err := openpgp.ReadEntity(packet.NewReader(bytes.NewBuffer(privKeyBytes))) + if err != nil { + tb.Fatal(err) + } + + var rootBytes []byte + rootBytes, err = base64.StdEncoding.DecodeString(enc) + if err != nil { + tb.Fatal(err) + } + + entityList := &openpgp.EntityList{entity} + md, err := openpgp.ReadMessage(bytes.NewBuffer(rootBytes), entityList, nil, nil) + if err != nil { + tb.Fatal(err) + } + ptBuf.ReadFrom(md.UnverifiedBody) + return ptBuf.String() +} + func parseDecryptAndTestUnsealKeys(t *testing.T, input, rootToken string, fingerprints bool, diff --git a/command/policies_deprecated.go b/command/policies_deprecated.go new file mode 100644 index 0000000000..b7b5f5cf17 --- /dev/null +++ b/command/policies_deprecated.go @@ -0,0 +1,52 @@ +package command + +import ( + "github.com/mitchellh/cli" +) + +// Deprecation +// TODO: remove in 0.9.0 + +var _ cli.Command = (*PoliciesDeprecatedCommand)(nil) + +type PoliciesDeprecatedCommand struct { + *BaseCommand +} + +func (c *PoliciesDeprecatedCommand) Synopsis() string { return "" } + +func (c *PoliciesDeprecatedCommand) Help() string { + return (&PolicyListCommand{ + BaseCommand: c.BaseCommand, + }).Help() +} + +func (c *PoliciesDeprecatedCommand) Run(args []string) int { + oargs := args + + f := c.flagSet(FlagSetHTTP) + if err := f.Parse(args); err != nil { + c.UI.Error(err.Error()) + return 1 + } + args = f.Args() + + // Got an arg, this is trying to read a policy + if len(args) > 0 { + return (&PolicyReadCommand{ + BaseCommand: &BaseCommand{ + UI: c.UI, + client: c.client, + }, + }).Run(oargs) + } + + // No args, probably ran "vault policies" and we want to translate that to + // "vault policy list" + return (&PolicyListCommand{ + BaseCommand: &BaseCommand{ + UI: c.UI, + client: c.client, + }, + }).Run(oargs) +} diff --git a/command/policies_deprecated_test.go b/command/policies_deprecated_test.go new file mode 100644 index 0000000000..de8ff7a329 --- /dev/null +++ b/command/policies_deprecated_test.go @@ -0,0 +1,96 @@ +package command + +import ( + "strings" + "testing" + + "github.com/mitchellh/cli" +) + +func testPoliciesDeprecatedCommand(tb testing.TB) (*cli.MockUi, *PoliciesDeprecatedCommand) { + tb.Helper() + + ui := cli.NewMockUi() + return ui, &PoliciesDeprecatedCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + } +} + +func TestPoliciesDeprecatedCommand_Run(t *testing.T) { + t.Parallel() + + // TODO: remove in 0.9.0 + t.Run("deprecated_arg", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + ui, cmd := testPoliciesDeprecatedCommand(t) + cmd.client = client + + // vault policies ARG -> vault policy read ARG + code := cmd.Run([]string{"default"}) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d: %s", code, exp, ui.ErrorWriter.String()) + } + stdout := ui.OutputWriter.String() + + if expected := "token/"; !strings.Contains(stdout, expected) { + t.Errorf("expected %q to contain %q", stdout, expected) + } + }) + + t.Run("deprecated_no_args", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + ui, cmd := testPoliciesDeprecatedCommand(t) + cmd.client = client + + // vault policies -> vault policy list + code := cmd.Run([]string{}) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d: %s", code, exp, ui.ErrorWriter.String()) + } + stdout := ui.OutputWriter.String() + + if expected := "root"; !strings.Contains(stdout, expected) { + t.Errorf("expected %q to contain %q", stdout, expected) + } + }) + + t.Run("deprecated_with_flags", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + ui, cmd := testPoliciesDeprecatedCommand(t) + cmd.client = client + + // vault policies -flag -> vault policy list + code := cmd.Run([]string{ + "-address", client.Address(), + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d: %s", code, exp, ui.ErrorWriter.String()) + } + stdout := ui.OutputWriter.String() + + if expected := "root"; !strings.Contains(stdout, expected) { + t.Errorf("expected %q to contain %q", stdout, expected) + } + }) + + t.Run("no_tabs", func(t *testing.T) { + t.Parallel() + + _, cmd := testPoliciesDeprecatedCommand(t) + assertNoTabs(t, cmd) + }) +} diff --git a/command/policy.go b/command/policy.go new file mode 100644 index 0000000000..5d812cadeb --- /dev/null +++ b/command/policy.go @@ -0,0 +1,47 @@ +package command + +import ( + "strings" + + "github.com/mitchellh/cli" +) + +var _ cli.Command = (*PolicyCommand)(nil) + +// PolicyCommand is a Command that holds the audit commands +type PolicyCommand struct { + *BaseCommand +} + +func (c *PolicyCommand) Synopsis() string { + return "Interact with policies" +} + +func (c *PolicyCommand) Help() string { + helpText := ` +Usage: vault policy [options] [args] + + This command groups subcommands for interacting with policies. Users can + Users can write, read, and list policies in Vault. + + List all enabled policies: + + $ vault policy list + + Create a policy named "my-policy" from contents on local disk: + + $ vault policy write my-policy ./my-policy.hcl + + Delete the policy named my-policy: + + $ vault policy delete my-policy + + Please see the individual subcommand help for detailed usage information. +` + + return strings.TrimSpace(helpText) +} + +func (c *PolicyCommand) Run(args []string) int { + return cli.RunResultHelp +} diff --git a/command/policy_delete.go b/command/policy_delete.go index ff8342a625..e74030640f 100644 --- a/command/policy_delete.go +++ b/command/policy_delete.go @@ -4,62 +4,82 @@ import ( "fmt" "strings" - "github.com/hashicorp/vault/meta" + "github.com/mitchellh/cli" + "github.com/posener/complete" ) -// PolicyDeleteCommand is a Command that enables a new endpoint. +var _ cli.Command = (*PolicyDeleteCommand)(nil) +var _ cli.CommandAutocomplete = (*PolicyDeleteCommand)(nil) + type PolicyDeleteCommand struct { - meta.Meta + *BaseCommand +} + +func (c *PolicyDeleteCommand) Synopsis() string { + return "Deletes a policy by name" +} + +func (c *PolicyDeleteCommand) Help() string { + helpText := ` +Usage: vault policy delete [options] NAME + + Deletes the policy named NAME in the Vault server. Once the policy is deleted, + all tokens associated with the policy are affected immediately. + + Delete the policy named "my-policy": + + $ vault policy delete my-policy + + Note that it is not possible to delete the "default" or "root" policies. + These are built-in policies. + +` + c.Flags().Help() + + return strings.TrimSpace(helpText) +} + +func (c *PolicyDeleteCommand) Flags() *FlagSets { + return c.flagSet(FlagSetHTTP) +} + +func (c *PolicyDeleteCommand) AutocompleteArgs() complete.Predictor { + return c.PredictVaultPolicies() +} + +func (c *PolicyDeleteCommand) AutocompleteFlags() complete.Flags { + return c.Flags().Completions() } func (c *PolicyDeleteCommand) Run(args []string) int { - flags := c.Meta.FlagSet("policy-delete", meta.FlagSetDefault) - flags.Usage = func() { c.Ui.Error(c.Help()) } - if err := flags.Parse(args); err != nil { + f := c.Flags() + + if err := f.Parse(args); err != nil { + c.UI.Error(err.Error()) return 1 } - args = flags.Args() - if len(args) != 1 { - flags.Usage() - c.Ui.Error(fmt.Sprintf( - "\npolicy-delete expects exactly one argument")) + args = f.Args() + switch { + case len(args) < 1: + c.UI.Error(fmt.Sprintf("Not enough arguments (expected 1, got %d)", len(args))) + return 1 + case len(args) > 1: + c.UI.Error(fmt.Sprintf("Too many arguments (expected 1, got %d)", len(args))) return 1 } client, err := c.Client() if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error initializing client: %s", err)) + c.UI.Error(err.Error()) return 2 } - name := args[0] + name := strings.TrimSpace(strings.ToLower(args[0])) if err := client.Sys().DeletePolicy(name); err != nil { - c.Ui.Error(fmt.Sprintf( - "Error: %s", err)) - return 1 + c.UI.Error(fmt.Sprintf("Error deleting %s: %s", name, err)) + return 2 } - c.Ui.Output(fmt.Sprintf("Policy '%s' deleted.", name)) + c.UI.Output(fmt.Sprintf("Success! Deleted policy: %s", name)) return 0 } - -func (c *PolicyDeleteCommand) Synopsis() string { - return "Delete a policy from the server" -} - -func (c *PolicyDeleteCommand) Help() string { - helpText := ` -Usage: vault policy-delete [options] name - - Delete a policy with the given name. - - Once the policy is deleted, all users associated with the policy will - be affected immediately. When a user is associated with a policy that - doesn't exist, it is identical to not being associated with that policy. - -General Options: -` + meta.GeneralOptionsUsage() - return strings.TrimSpace(helpText) -} diff --git a/command/policy_delete_test.go b/command/policy_delete_test.go index 4f62a1028d..1c30c739d7 100644 --- a/command/policy_delete_test.go +++ b/command/policy_delete_test.go @@ -1,61 +1,136 @@ package command import ( + "reflect" + "strings" "testing" - "github.com/hashicorp/vault/http" - "github.com/hashicorp/vault/meta" - "github.com/hashicorp/vault/vault" "github.com/mitchellh/cli" ) -func TestPolicyDelete(t *testing.T) { - core, _, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() +func testPolicyDeleteCommand(tb testing.TB) (*cli.MockUi, *PolicyDeleteCommand) { + tb.Helper() - ui := new(cli.MockUi) - c := &PolicyDeleteCommand{ - Meta: meta.Meta{ - ClientToken: token, - Ui: ui, + ui := cli.NewMockUi() + return ui, &PolicyDeleteCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + } +} + +func TestPolicyDeleteCommand_Run(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + args []string + out string + code int + }{ + { + "not_enough_args", + nil, + "Not enough arguments", + 1, + }, + { + "too_many_args", + []string{"foo", "bar"}, + "Too many arguments", + 1, }, } - args := []string{ - "-address", addr, - "foo", - } + t.Run("validations", func(t *testing.T) { + t.Parallel() - // Run once so the client is setup, ignore errors - c.Run(args) + for _, tc := range cases { + tc := tc - // Get the client so we can write data - client, err := c.Client() - if err != nil { - t.Fatalf("err: %s", err) - } - if err := client.Sys().PutPolicy("foo", testPolicyDeleteRules); err != nil { - t.Fatalf("err: %s", err) - } + t.Run(tc.name, func(t *testing.T) { + t.Parallel() - // Test that the delete works - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } + ui, cmd := testPolicyDeleteCommand(t) - // Test the policy is gone - rules, err := client.Sys().GetPolicy("foo") - if err != nil { - t.Fatalf("err: %s", err) - } - if rules != "" { - t.Fatalf("bad: %#v", rules) - } + code := cmd.Run(tc.args) + if code != tc.code { + t.Errorf("expected %d to be %d", code, tc.code) + } + + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, tc.out) { + t.Errorf("expected %q to contain %q", combined, tc.out) + } + }) + } + }) + + t.Run("integration", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + policy := `path "secret/" {}` + if err := client.Sys().PutPolicy("my-policy", policy); err != nil { + t.Fatal(err) + } + + ui, cmd := testPolicyDeleteCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "my-policy", + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Success! Deleted policy: my-policy" + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + + policies, err := client.Sys().ListPolicies() + if err != nil { + t.Fatal(err) + } + + list := []string{"default", "root"} + if !reflect.DeepEqual(policies, list) { + t.Errorf("expected %q to be %q", policies, list) + } + }) + + t.Run("communication_failure", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServerBad(t) + defer closer() + + ui, cmd := testPolicyDeleteCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "my-policy", + }) + if exp := 2; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Error deleting my-policy: " + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + }) + + t.Run("no_tabs", func(t *testing.T) { + t.Parallel() + + _, cmd := testPolicyDeleteCommand(t) + assertNoTabs(t, cmd) + }) } - -const testPolicyDeleteRules = ` -path "sys" { - policy = "deny" -} -` diff --git a/command/policy_fmt.go b/command/policy_fmt.go new file mode 100644 index 0000000000..01bde16e60 --- /dev/null +++ b/command/policy_fmt.go @@ -0,0 +1,109 @@ +package command + +import ( + "fmt" + "io/ioutil" + "strings" + + "github.com/hashicorp/hcl/hcl/printer" + "github.com/hashicorp/vault/vault" + "github.com/mitchellh/cli" + homedir "github.com/mitchellh/go-homedir" + "github.com/posener/complete" +) + +var _ cli.Command = (*PolicyFmtCommand)(nil) +var _ cli.CommandAutocomplete = (*PolicyFmtCommand)(nil) + +type PolicyFmtCommand struct { + *BaseCommand +} + +func (c *PolicyFmtCommand) Synopsis() string { + return "Formats a policy on disk" +} + +func (c *PolicyFmtCommand) Help() string { + helpText := ` +Usage: vault policy fmt [options] PATH + + Formats a local policy file to the policy specification. This command will + overwrite the file at the given PATH with the properly-formatted policy + file contents. + + Format the local file "my-policy.hcl" as a policy file: + + $ vault policy fmt my-policy.hcl + +` + c.Flags().Help() + + return strings.TrimSpace(helpText) +} + +func (c *PolicyFmtCommand) Flags() *FlagSets { + return c.flagSet(FlagSetNone) +} + +func (c *PolicyFmtCommand) AutocompleteArgs() complete.Predictor { + return complete.PredictFiles("*.hcl") +} + +func (c *PolicyFmtCommand) AutocompleteFlags() complete.Flags { + return c.Flags().Completions() +} + +func (c *PolicyFmtCommand) Run(args []string) int { + f := c.Flags() + + if err := f.Parse(args); err != nil { + c.UI.Error(err.Error()) + return 1 + } + + args = f.Args() + switch { + case len(args) < 1: + c.UI.Error(fmt.Sprintf("Not enough arguments (expected 1, got %d)", len(args))) + return 1 + case len(args) > 1: + c.UI.Error(fmt.Sprintf("Too many arguments (expected 1, got %d)", len(args))) + return 1 + } + + // Get the filepath, accounting for ~ and stuff + path, err := homedir.Expand(strings.TrimSpace(args[0])) + if err != nil { + c.UI.Error(fmt.Sprintf("Failed to expand path: %s", err)) + return 1 + } + + // Read the entire contents into memory - it would be nice if we could use + // a buffer, but hcl wants the full contents. + b, err := ioutil.ReadFile(path) + if err != nil { + c.UI.Error(fmt.Sprintf("Error reading source file: %s", err)) + return 1 + } + + // Actually parse the policy + if _, err := vault.ParseACLPolicy(string(b)); err != nil { + c.UI.Error(err.Error()) + return 1 + } + + // Generate final contents + result, err := printer.Format(b) + if err != nil { + c.UI.Error(fmt.Sprintf("Error printing result: %s", err)) + return 1 + } + + // Write them back out + if err := ioutil.WriteFile(path, result, 0644); err != nil { + c.UI.Error(fmt.Sprintf("Error writing result: %s", err)) + return 1 + } + + c.UI.Output(fmt.Sprintf("Success! Formatted policy: %s", path)) + return 0 +} diff --git a/command/policy_fmt_test.go b/command/policy_fmt_test.go new file mode 100644 index 0000000000..93e8daa1bf --- /dev/null +++ b/command/policy_fmt_test.go @@ -0,0 +1,213 @@ +package command + +import ( + "io/ioutil" + "os" + "strings" + "testing" + + "github.com/mitchellh/cli" +) + +func testPolicyFmtCommand(tb testing.TB) (*cli.MockUi, *PolicyFmtCommand) { + tb.Helper() + + ui := cli.NewMockUi() + return ui, &PolicyFmtCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + } +} + +func TestPolicyFmtCommand_Run(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + args []string + out string + code int + }{ + { + "not_enough_args", + []string{}, + "Not enough arguments", + 1, + }, + { + "too_many_args", + []string{"foo", "bar"}, + "Too many arguments", + 1, + }, + } + + t.Run("validations", func(t *testing.T) { + t.Parallel() + + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + ui, cmd := testPolicyFmtCommand(t) + + code := cmd.Run(tc.args) + if code != tc.code { + t.Errorf("expected %d to be %d", code, tc.code) + } + + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, tc.out) { + t.Errorf("expected %q to contain %q", combined, tc.out) + } + }) + } + }) + + t.Run("default", func(t *testing.T) { + t.Parallel() + + policy := strings.TrimSpace(` +path "secret" { + capabilities = ["create", "update","delete"] + +} +`) + + f, err := ioutil.TempFile("", "") + if err != nil { + t.Fatal(err) + } + defer os.Remove(f.Name()) + if _, err := f.Write([]byte(policy)); err != nil { + t.Fatal(err) + } + f.Close() + + _, cmd := testPolicyFmtCommand(t) + + code := cmd.Run([]string{ + f.Name(), + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := strings.TrimSpace(` +path "secret" { + capabilities = ["create", "update", "delete"] +} +`) + "\n" + + contents, err := ioutil.ReadFile(f.Name()) + if err != nil { + t.Fatal(err) + } + if string(contents) != expected { + t.Errorf("expected %q to be %q", string(contents), expected) + } + }) + + t.Run("bad_hcl", func(t *testing.T) { + t.Parallel() + + policy := `dafdaf` + + f, err := ioutil.TempFile("", "") + if err != nil { + t.Fatal(err) + } + defer os.Remove(f.Name()) + if _, err := f.Write([]byte(policy)); err != nil { + t.Fatal(err) + } + f.Close() + + ui, cmd := testPolicyFmtCommand(t) + + code := cmd.Run([]string{ + f.Name(), + }) + if exp := 1; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + stderr := ui.ErrorWriter.String() + expected := "Failed to parse policy" + if !strings.Contains(stderr, expected) { + t.Errorf("expected %q to include %q", stderr, expected) + } + }) + + t.Run("bad_policy", func(t *testing.T) { + t.Parallel() + + policy := `banana "foo" {}` + + f, err := ioutil.TempFile("", "") + if err != nil { + t.Fatal(err) + } + defer os.Remove(f.Name()) + if _, err := f.Write([]byte(policy)); err != nil { + t.Fatal(err) + } + f.Close() + + ui, cmd := testPolicyFmtCommand(t) + + code := cmd.Run([]string{ + f.Name(), + }) + if exp := 1; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + stderr := ui.ErrorWriter.String() + expected := "Failed to parse policy" + if !strings.Contains(stderr, expected) { + t.Errorf("expected %q to include %q", stderr, expected) + } + }) + + t.Run("bad_policy", func(t *testing.T) { + t.Parallel() + + policy := `path "secret/" { capabilities = ["bogus"] }` + + f, err := ioutil.TempFile("", "") + if err != nil { + t.Fatal(err) + } + defer os.Remove(f.Name()) + if _, err := f.Write([]byte(policy)); err != nil { + t.Fatal(err) + } + f.Close() + + ui, cmd := testPolicyFmtCommand(t) + + code := cmd.Run([]string{ + f.Name(), + }) + if exp := 1; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + stderr := ui.ErrorWriter.String() + expected := "Failed to parse policy" + if !strings.Contains(stderr, expected) { + t.Errorf("expected %q to include %q", stderr, expected) + } + }) + + t.Run("no_tabs", func(t *testing.T) { + t.Parallel() + + _, cmd := testPolicyFmtCommand(t) + assertNoTabs(t, cmd) + }) +} diff --git a/command/policy_list.go b/command/policy_list.go index 73cb9c5b4e..1a61136ff1 100644 --- a/command/policy_list.go +++ b/command/policy_list.go @@ -4,89 +4,73 @@ import ( "fmt" "strings" - "github.com/hashicorp/vault/meta" + "github.com/mitchellh/cli" + "github.com/posener/complete" ) -// PolicyListCommand is a Command that enables a new endpoint. +var _ cli.Command = (*PolicyListCommand)(nil) +var _ cli.CommandAutocomplete = (*PolicyListCommand)(nil) + type PolicyListCommand struct { - meta.Meta + *BaseCommand +} + +func (c *PolicyListCommand) Synopsis() string { + return "Lists the installed policies" +} + +func (c *PolicyListCommand) Help() string { + helpText := ` +Usage: vault policy list [options] + + Lists the names of the policies that are installed on the Vault server. + +` + c.Flags().Help() + + return strings.TrimSpace(helpText) +} + +func (c *PolicyListCommand) Flags() *FlagSets { + return c.flagSet(FlagSetHTTP) +} + +func (c *PolicyListCommand) AutocompleteArgs() complete.Predictor { + return nil +} + +func (c *PolicyListCommand) AutocompleteFlags() complete.Flags { + return c.Flags().Completions() } func (c *PolicyListCommand) Run(args []string) int { - flags := c.Meta.FlagSet("policy-list", meta.FlagSetDefault) - flags.Usage = func() { c.Ui.Error(c.Help()) } - if err := flags.Parse(args); err != nil { + f := c.Flags() + + if err := f.Parse(args); err != nil { + c.UI.Error(err.Error()) return 1 } - args = flags.Args() - if len(args) == 1 { - return c.read(args[0]) - } else if len(args) == 0 { - return c.list() - } else { - flags.Usage() - c.Ui.Error(fmt.Sprintf( - "\npolicies expects zero or one arguments")) + args = f.Args() + switch { + case len(args) > 0: + c.UI.Error(fmt.Sprintf("Too many arguments (expected 0, got %d)", len(args))) return 1 } -} -func (c *PolicyListCommand) list() int { client, err := c.Client() if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error initializing client: %s", err)) + c.UI.Error(err.Error()) return 2 } policies, err := client.Sys().ListPolicies() if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error: %s", err)) - return 1 - } - - for _, p := range policies { - c.Ui.Output(p) - } - - return 0 -} - -func (c *PolicyListCommand) read(n string) int { - client, err := c.Client() - if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error initializing client: %s", err)) + c.UI.Error(fmt.Sprintf("Error listing policies: %s", err)) return 2 } - - rules, err := client.Sys().GetPolicy(n) - if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error: %s", err)) - return 1 + for _, p := range policies { + c.UI.Output(p) } - c.Ui.Output(rules) return 0 } - -func (c *PolicyListCommand) Synopsis() string { - return "List the policies on the server" -} - -func (c *PolicyListCommand) Help() string { - helpText := ` -Usage: vault policies [options] [name] - - List the policies that are available or read a single policy. - - This command lists the policies that are written to the Vault server. - If a name of a policy is specified, that policy is outputted. - -General Options: -` + meta.GeneralOptionsUsage() - return strings.TrimSpace(helpText) -} diff --git a/command/policy_list_test.go b/command/policy_list_test.go index b2afe293c1..70defe54ea 100644 --- a/command/policy_list_test.go +++ b/command/policy_list_test.go @@ -1,53 +1,114 @@ package command import ( + "strings" "testing" - "github.com/hashicorp/vault/http" - "github.com/hashicorp/vault/meta" - "github.com/hashicorp/vault/vault" "github.com/mitchellh/cli" ) -func TestPolicyList(t *testing.T) { - core, _, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() +func testPolicyListCommand(tb testing.TB) (*cli.MockUi, *PolicyListCommand) { + tb.Helper() - ui := new(cli.MockUi) - c := &PolicyListCommand{ - Meta: meta.Meta{ - ClientToken: token, - Ui: ui, + ui := cli.NewMockUi() + return ui, &PolicyListCommand{ + BaseCommand: &BaseCommand{ + UI: ui, }, } - - args := []string{ - "-address", addr, - } - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } } -func TestPolicyRead(t *testing.T) { - core, _, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() +func TestPolicyListCommand_Run(t *testing.T) { + t.Parallel() - ui := new(cli.MockUi) - c := &PolicyListCommand{ - Meta: meta.Meta{ - ClientToken: token, - Ui: ui, + cases := []struct { + name string + args []string + out string + code int + }{ + { + "too_many_args", + []string{"foo"}, + "Too many arguments", + 1, }, } - args := []string{ - "-address", addr, - "root", - } - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } + t.Run("validations", func(t *testing.T) { + t.Parallel() + + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + ui, cmd := testPolicyListCommand(t) + cmd.client = client + + code := cmd.Run(tc.args) + if code != tc.code { + t.Errorf("expected %d to be %d", code, tc.code) + } + + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, tc.out) { + t.Errorf("expected %q to contain %q", combined, tc.out) + } + }) + } + }) + + t.Run("default", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + ui, cmd := testPolicyListCommand(t) + cmd.client = client + + code := cmd.Run([]string{}) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "default\nroot" + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + }) + + t.Run("communication_failure", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServerBad(t) + defer closer() + + ui, cmd := testPolicyListCommand(t) + cmd.client = client + + code := cmd.Run([]string{}) + if exp := 2; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Error listing policies: " + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + }) + + t.Run("no_tabs", func(t *testing.T) { + t.Parallel() + + _, cmd := testPolicyListCommand(t) + assertNoTabs(t, cmd) + }) } diff --git a/command/policy_read.go b/command/policy_read.go new file mode 100644 index 0000000000..324615935a --- /dev/null +++ b/command/policy_read.go @@ -0,0 +1,87 @@ +package command + +import ( + "fmt" + "strings" + + "github.com/mitchellh/cli" + "github.com/posener/complete" +) + +var _ cli.Command = (*PolicyReadCommand)(nil) +var _ cli.CommandAutocomplete = (*PolicyReadCommand)(nil) + +type PolicyReadCommand struct { + *BaseCommand +} + +func (c *PolicyReadCommand) Synopsis() string { + return "Prints the contents of a policy" +} + +func (c *PolicyReadCommand) Help() string { + helpText := ` +Usage: vault policy read [options] [NAME] + + Prints the contents and metadata of the Vault policy named NAME. If the policy + does not exist, an error is returned. + + Read the policy named "my-policy": + + $ vault policy read my-policy + +` + c.Flags().Help() + + return strings.TrimSpace(helpText) +} + +func (c *PolicyReadCommand) Flags() *FlagSets { + return c.flagSet(FlagSetHTTP) +} + +func (c *PolicyReadCommand) AutocompleteArgs() complete.Predictor { + return c.PredictVaultPolicies() +} + +func (c *PolicyReadCommand) AutocompleteFlags() complete.Flags { + return c.Flags().Completions() +} + +func (c *PolicyReadCommand) Run(args []string) int { + f := c.Flags() + + if err := f.Parse(args); err != nil { + c.UI.Error(err.Error()) + return 1 + } + + args = f.Args() + switch { + case len(args) < 1: + c.UI.Error(fmt.Sprintf("Not enough arguments (expected 1, got %d)", len(args))) + return 1 + case len(args) > 1: + c.UI.Error(fmt.Sprintf("Too many arguments (expected 1, got %d)", len(args))) + return 1 + } + + client, err := c.Client() + if err != nil { + c.UI.Error(err.Error()) + return 2 + } + + name := strings.ToLower(strings.TrimSpace(args[0])) + rules, err := client.Sys().GetPolicy(name) + if err != nil { + c.UI.Error(fmt.Sprintf("Error reading policy named %s: %s", name, err)) + return 2 + } + if rules == "" { + c.UI.Error(fmt.Sprintf("No policy named: %s", name)) + return 2 + } + c.UI.Output(strings.TrimSpace(rules)) + + return 0 +} diff --git a/command/policy_read_test.go b/command/policy_read_test.go new file mode 100644 index 0000000000..8cd7c066b8 --- /dev/null +++ b/command/policy_read_test.go @@ -0,0 +1,128 @@ +package command + +import ( + "strings" + "testing" + + "github.com/mitchellh/cli" +) + +func testPolicyReadCommand(tb testing.TB) (*cli.MockUi, *PolicyReadCommand) { + tb.Helper() + + ui := cli.NewMockUi() + return ui, &PolicyReadCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + } +} + +func TestPolicyReadCommand_Run(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + args []string + out string + code int + }{ + { + "too_many_args", + []string{"foo", "bar"}, + "Too many arguments", + 1, + }, + { + "no_policy_exists", + []string{"not-a-real-policy"}, + "No policy named", + 2, + }, + } + + t.Run("validations", func(t *testing.T) { + t.Parallel() + + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + ui, cmd := testPolicyReadCommand(t) + cmd.client = client + + code := cmd.Run(tc.args) + if code != tc.code { + t.Errorf("expected %d to be %d", code, tc.code) + } + + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, tc.out) { + t.Errorf("expected %q to contain %q", combined, tc.out) + } + }) + } + }) + + t.Run("default", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + policy := `path "secret/" {}` + if err := client.Sys().PutPolicy("my-policy", policy); err != nil { + t.Fatal(err) + } + + ui, cmd := testPolicyReadCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "my-policy", + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, policy) { + t.Errorf("expected %q to contain %q", combined, policy) + } + }) + + t.Run("communication_failure", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServerBad(t) + defer closer() + + ui, cmd := testPolicyReadCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "my-policy", + }) + if exp := 2; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Error reading policy named my-policy: " + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + }) + + t.Run("no_tabs", func(t *testing.T) { + t.Parallel() + + _, cmd := testPolicyReadCommand(t) + assertNoTabs(t, cmd) + }) +} diff --git a/command/policy_write.go b/command/policy_write.go index 59b26fb472..9f6cb2222d 100644 --- a/command/policy_write.go +++ b/command/policy_write.go @@ -7,84 +7,121 @@ import ( "os" "strings" - "github.com/hashicorp/vault/meta" + "github.com/mitchellh/cli" + "github.com/posener/complete" ) -// PolicyWriteCommand is a Command that enables a new endpoint. +var _ cli.Command = (*PolicyWriteCommand)(nil) +var _ cli.CommandAutocomplete = (*PolicyWriteCommand)(nil) + type PolicyWriteCommand struct { - meta.Meta + *BaseCommand + + testStdin io.Reader // for tests +} + +func (c *PolicyWriteCommand) Synopsis() string { + return "Uploads a named policy from a file" +} + +func (c *PolicyWriteCommand) Help() string { + helpText := ` +Usage: vault policy write [options] NAME PATH + + Uploads a policy with name NAME from the contents of a local file PATH or + stdin. If PATH is "-", the policy is read from stdin. Otherwise, it is + loaded from the file at the given path on the local disk. + + Upload a policy named "my-policy" from "/tmp/policy.hcl" on the local disk: + + $ vault policy write my-policy /tmp/policy.hcl + + Upload a policy from stdin: + + $ cat my-policy.hcl | vault policy write my-policy - + +` + c.Flags().Help() + + return strings.TrimSpace(helpText) +} + +func (c *PolicyWriteCommand) Flags() *FlagSets { + return c.flagSet(FlagSetHTTP) +} + +func (c *PolicyWriteCommand) AutocompleteArgs() complete.Predictor { + return complete.PredictFunc(func(args complete.Args) []string { + // Predict the LAST argument hcl files - we don't want to predict the + // name argument as a filepath. + if len(args.All) == 3 { + return complete.PredictFiles("*.hcl").Predict(args) + } + return nil + }) +} + +func (c *PolicyWriteCommand) AutocompleteFlags() complete.Flags { + return c.Flags().Completions() } func (c *PolicyWriteCommand) Run(args []string) int { - flags := c.Meta.FlagSet("policy-write", meta.FlagSetDefault) - flags.Usage = func() { c.Ui.Error(c.Help()) } - if err := flags.Parse(args); err != nil { + f := c.Flags() + + if err := f.Parse(args); err != nil { + c.UI.Error(err.Error()) return 1 } - args = flags.Args() - if len(args) != 2 { - flags.Usage() - c.Ui.Error(fmt.Sprintf( - "\npolicy-write expects exactly two arguments")) + args = f.Args() + switch { + case len(args) < 2: + c.UI.Error(fmt.Sprintf("Not enough arguments (expected 2, got %d)", len(args))) + return 1 + case len(args) > 2: + c.UI.Error(fmt.Sprintf("Too many arguments (expected 2, got %d)", len(args))) return 1 } client, err := c.Client() if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error initializing client: %s", err)) + c.UI.Error(err.Error()) return 2 } // Policies are normalized to lowercase - name := strings.ToLower(args[0]) - path := args[1] + name := strings.TrimSpace(strings.ToLower(args[0])) + path := strings.TrimSpace(args[1]) - // Read the policy - var f io.Reader = os.Stdin - if path != "-" { + // Get the policy contents, either from stdin of a file + var reader io.Reader + if path == "-" { + reader = os.Stdin + if c.testStdin != nil { + reader = c.testStdin + } + } else { file, err := os.Open(path) if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error opening file: %s", err)) - return 1 + c.UI.Error(fmt.Sprintf("Error opening policy file: %s", err)) + return 2 } defer file.Close() - f = file + reader = file } + + // Read the policy var buf bytes.Buffer - if _, err := io.Copy(&buf, f); err != nil { - c.Ui.Error(fmt.Sprintf( - "Error reading file: %s", err)) - return 1 + if _, err := io.Copy(&buf, reader); err != nil { + c.UI.Error(fmt.Sprintf("Error reading policy: %s", err)) + return 2 } rules := buf.String() if err := client.Sys().PutPolicy(name, rules); err != nil { - c.Ui.Error(fmt.Sprintf( - "Error: %s", err)) - return 1 + c.UI.Error(fmt.Sprintf("Error uploading policy: %s", err)) + return 2 } - c.Ui.Output(fmt.Sprintf("Policy '%s' written.", name)) + c.UI.Output(fmt.Sprintf("Success! Uploaded policy: %s", name)) return 0 } - -func (c *PolicyWriteCommand) Synopsis() string { - return "Write a policy to the server" -} - -func (c *PolicyWriteCommand) Help() string { - helpText := ` -Usage: vault policy-write [options] name path - - Write a policy with the given name from the contents of a file or stdin. - - If the path is "-", the policy is read from stdin. Otherwise, it is - loaded from the file at the given path. - -General Options: -` + meta.GeneralOptionsUsage() - return strings.TrimSpace(helpText) -} diff --git a/command/policy_write_test.go b/command/policy_write_test.go index d0deeaac69..c8db7dc9dd 100644 --- a/command/policy_write_test.go +++ b/command/policy_write_test.go @@ -1,33 +1,207 @@ package command import ( + "bytes" + "io" + "io/ioutil" + "os" + "reflect" + "strings" "testing" - "github.com/hashicorp/vault/http" - "github.com/hashicorp/vault/meta" - "github.com/hashicorp/vault/vault" "github.com/mitchellh/cli" ) -func TestPolicyWrite(t *testing.T) { - core, _, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() +func testPolicyWriteCommand(tb testing.TB) (*cli.MockUi, *PolicyWriteCommand) { + tb.Helper() - ui := new(cli.MockUi) - c := &PolicyWriteCommand{ - Meta: meta.Meta{ - ClientToken: token, - Ui: ui, + ui := cli.NewMockUi() + return ui, &PolicyWriteCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + } +} + +func testPolicyWritePolicyContents(tb testing.TB) []byte { + return bytes.TrimSpace([]byte(` +path "secret/" { + capabilities = ["read"] +} + `)) +} + +func TestPolicyWriteCommand_Run(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + args []string + out string + code int + }{ + { + "too_many_args", + []string{"foo", "bar", "baz"}, + "Too many arguments", + 1, + }, + { + "not_enough_args", + []string{"foo"}, + "Not enough arguments", + 1, + }, + { + "bad_file", + []string{"my-policy", "/not/a/real/path.hcl"}, + "Error opening policy file", + 2, }, } - args := []string{ - "-address", addr, - "foo", - "./test-fixtures/policy.hcl", - } - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } + t.Run("validations", func(t *testing.T) { + t.Parallel() + + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + ui, cmd := testPolicyWriteCommand(t) + cmd.client = client + + code := cmd.Run(tc.args) + if code != tc.code { + t.Errorf("expected %d to be %d", code, tc.code) + } + + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, tc.out) { + t.Errorf("expected %q to contain %q", combined, tc.out) + } + }) + } + }) + + t.Run("file", func(t *testing.T) { + t.Parallel() + + policy := testPolicyWritePolicyContents(t) + f, err := ioutil.TempFile("", "vault-policy-write") + if err != nil { + t.Fatal(err) + } + if _, err := f.Write(policy); err != nil { + t.Fatal(err) + } + if err := f.Close(); err != nil { + t.Fatal(err) + } + defer os.Remove(f.Name()) + + client, closer := testVaultServer(t) + defer closer() + + ui, cmd := testPolicyWriteCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "my-policy", f.Name(), + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Success! Uploaded policy: my-policy" + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + + policies, err := client.Sys().ListPolicies() + if err != nil { + t.Fatal(err) + } + + list := []string{"default", "my-policy", "root"} + if !reflect.DeepEqual(policies, list) { + t.Errorf("expected %q to be %q", policies, list) + } + }) + + t.Run("stdin", func(t *testing.T) { + t.Parallel() + + stdinR, stdinW := io.Pipe() + go func() { + policy := testPolicyWritePolicyContents(t) + stdinW.Write(policy) + stdinW.Close() + }() + + client, closer := testVaultServer(t) + defer closer() + + ui, cmd := testPolicyWriteCommand(t) + cmd.client = client + cmd.testStdin = stdinR + + code := cmd.Run([]string{ + "my-policy", "-", + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Success! Uploaded policy: my-policy" + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + + policies, err := client.Sys().ListPolicies() + if err != nil { + t.Fatal(err) + } + + list := []string{"default", "my-policy", "root"} + if !reflect.DeepEqual(policies, list) { + t.Errorf("expected %q to be %q", policies, list) + } + }) + + t.Run("communication_failure", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServerBad(t) + defer closer() + + ui, cmd := testPolicyWriteCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "my-policy", "-", + }) + if exp := 2; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Error uploading policy: " + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + }) + + t.Run("no_tabs", func(t *testing.T) { + t.Parallel() + + _, cmd := testPolicyWriteCommand(t) + assertNoTabs(t, cmd) + }) } diff --git a/command/read.go b/command/read.go index d989178229..6dc2b5337b 100644 --- a/command/read.go +++ b/command/read.go @@ -1,109 +1,94 @@ package command import ( - "flag" "fmt" "strings" - "github.com/hashicorp/vault/api" - "github.com/hashicorp/vault/meta" + "github.com/mitchellh/cli" "github.com/posener/complete" ) -// ReadCommand is a Command that reads data from the Vault. +var _ cli.Command = (*ReadCommand)(nil) +var _ cli.CommandAutocomplete = (*ReadCommand)(nil) + type ReadCommand struct { - meta.Meta -} - -func (c *ReadCommand) Run(args []string) int { - var format string - var field string - var err error - var secret *api.Secret - var flags *flag.FlagSet - flags = c.Meta.FlagSet("read", meta.FlagSetDefault) - flags.StringVar(&format, "format", "table", "") - flags.StringVar(&field, "field", "", "") - flags.Usage = func() { c.Ui.Error(c.Help()) } - if err := flags.Parse(args); err != nil { - return 1 - } - - args = flags.Args() - if len(args) != 1 || len(args[0]) == 0 { - c.Ui.Error("read expects one argument") - flags.Usage() - return 1 - } - - path := args[0] - if path[0] == '/' { - path = path[1:] - } - - client, err := c.Client() - if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error initializing client: %s", err)) - return 2 - } - - secret, err = client.Logical().Read(path) - if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error reading %s: %s", path, err)) - return 1 - } - if secret == nil { - c.Ui.Error(fmt.Sprintf( - "No value found at %s", path)) - return 1 - } - - // Handle single field output - if field != "" { - return PrintRawField(c.Ui, secret, field) - } - - return OutputSecret(c.Ui, format, secret) + *BaseCommand } func (c *ReadCommand) Synopsis() string { - return "Read data or secrets from Vault" + return "Read data and retrieves secrets" } func (c *ReadCommand) Help() string { helpText := ` -Usage: vault read [options] path +Usage: vault read [options] PATH - Read data from Vault. + Reads data from Vault at the given path. This can be used to read secrets, + generate dynamic credentials, get configuration details, and more. - Reads data at the given path from Vault. This can be used to read - secrets and configuration as well as generate dynamic values from - materialized backends. Please reference the documentation for the - backends in use to determine key structure. + Read a secret from the static secrets engine: -General Options: -` + meta.GeneralOptionsUsage() + ` -Read Options: + $ vault read secret/my-secret - -format=table The format for output. By default it is a whitespace- - delimited table. This can also be json or yaml. + For a full list of examples and paths, please see the documentation that + corresponds to the secrets engine in use. - -field=field If included, the raw value of the specified field - will be output raw to stdout. +` + c.Flags().Help() -` return strings.TrimSpace(helpText) } +func (c *ReadCommand) Flags() *FlagSets { + return c.flagSet(FlagSetHTTP | FlagSetOutputField | FlagSetOutputFormat) +} + func (c *ReadCommand) AutocompleteArgs() complete.Predictor { - return complete.PredictNothing + return c.PredictVaultFiles() } func (c *ReadCommand) AutocompleteFlags() complete.Flags { - return complete.Flags{ - "-format": predictFormat, - "-field": complete.PredictNothing, - } + return c.Flags().Completions() +} + +func (c *ReadCommand) Run(args []string) int { + f := c.Flags() + + if err := f.Parse(args); err != nil { + c.UI.Error(err.Error()) + return 1 + } + + args = f.Args() + switch { + case len(args) < 1: + c.UI.Error(fmt.Sprintf("Not enough arguments (expected 1, got %d)", len(args))) + return 1 + case len(args) > 1: + c.UI.Error(fmt.Sprintf("Too many arguments (expected 1, got %d)", len(args))) + return 1 + } + + client, err := c.Client() + if err != nil { + c.UI.Error(err.Error()) + return 2 + } + + path := sanitizePath(args[0]) + + secret, err := client.Logical().Read(path) + if err != nil { + c.UI.Error(fmt.Sprintf("Error reading %s: %s", path, err)) + return 2 + } + if secret == nil { + c.UI.Error(fmt.Sprintf("No value found at %s", path)) + return 2 + } + + if c.flagField != "" { + return PrintRawField(c.UI, secret, c.flagField) + } + + return OutputSecret(c.UI, c.flagFormat, secret) } diff --git a/command/read_test.go b/command/read_test.go index 5cf0f08c3e..1a45ed28a6 100644 --- a/command/read_test.go +++ b/command/read_test.go @@ -1,137 +1,155 @@ package command import ( + "strings" "testing" - "github.com/hashicorp/vault/http" - "github.com/hashicorp/vault/meta" - "github.com/hashicorp/vault/vault" "github.com/mitchellh/cli" ) -func TestRead(t *testing.T) { - core, _, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() +func testReadCommand(tb testing.TB) (*cli.MockUi, *ReadCommand) { + tb.Helper() - ui := new(cli.MockUi) - c := &ReadCommand{ - Meta: meta.Meta{ - ClientToken: token, - Ui: ui, + ui := cli.NewMockUi() + return ui, &ReadCommand{ + BaseCommand: &BaseCommand{ + UI: ui, }, } - - args := []string{ - "-address", addr, - "sys/mounts", - } - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } } -func TestRead_notFound(t *testing.T) { - core, _, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() +func TestReadCommand_Run(t *testing.T) { + t.Parallel() - ui := new(cli.MockUi) - c := &ReadCommand{ - Meta: meta.Meta{ - ClientToken: token, - Ui: ui, + cases := []struct { + name string + args []string + out string + code int + }{ + { + "not_enough_args", + []string{}, + "Not enough arguments", + 1, + }, + { + "too_many_args", + []string{"foo", "bar"}, + "Too many arguments", + 1, + }, + { + "not_found", + []string{"nope/not/once/never"}, + "", + 2, + }, + { + "default", + []string{"secret/read/foo"}, + "foo", + 0, + }, + { + "field", + []string{ + "-field", "foo", + "secret/read/foo", + }, + "bar", + 0, + }, + { + "field_not_found", + []string{ + "-field", "not-a-real-field", + "secret/read/foo", + }, + "not present in secret", + 1, + }, + { + "format", + []string{ + "-format", "json", + "secret/read/foo", + }, + "{", + 0, + }, + { + "format_bad", + []string{ + "-format", "nope-not-real", + "secret/read/foo", + }, + "Invalid output format", + 1, }, } - args := []string{ - "-address", addr, - "secret/nope", - } - if code := c.Run(args); code != 1 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } -} - -func TestRead_field(t *testing.T) { - core, _, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() - - ui := new(cli.MockUi) - c := &ReadCommand{ - Meta: meta.Meta{ - ClientToken: token, - Ui: ui, - }, - } - - args := []string{ - "-address", addr, - "-field", "value", - "secret/foo", - } - - // Run once so the client is setup, ignore errors - c.Run(args) - - // Get the client so we can write data - client, err := c.Client() - if err != nil { - t.Fatalf("err: %s", err) - } - - data := map[string]interface{}{"value": "bar"} - if _, err := client.Logical().Write("secret/foo", data); err != nil { - t.Fatalf("err: %s", err) - } - - // Run the read - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } - - output := ui.OutputWriter.String() - if output != "bar\n" { - t.Fatalf("unexpectd output:\n%s", output) - } -} - -func TestRead_field_notFound(t *testing.T) { - core, _, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() - - ui := new(cli.MockUi) - c := &ReadCommand{ - Meta: meta.Meta{ - ClientToken: token, - Ui: ui, - }, - } - - args := []string{ - "-address", addr, - "-field", "nope", - "secret/foo", - } - - // Run once so the client is setup, ignore errors - c.Run(args) - - // Get the client so we can write data - client, err := c.Client() - if err != nil { - t.Fatalf("err: %s", err) - } - - data := map[string]interface{}{"value": "bar"} - if _, err := client.Logical().Write("secret/foo", data); err != nil { - t.Fatalf("err: %s", err) - } - - // Run the read - if code := c.Run(args); code != 1 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } + t.Run("validations", func(t *testing.T) { + t.Parallel() + + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + if _, err := client.Logical().Write("secret/read/foo", map[string]interface{}{ + "foo": "bar", + }); err != nil { + t.Fatal(err) + } + + ui, cmd := testReadCommand(t) + cmd.client = client + + code := cmd.Run(tc.args) + if code != tc.code { + t.Errorf("expected %d to be %d", code, tc.code) + } + + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, tc.out) { + t.Errorf("expected %q to contain %q", combined, tc.out) + } + }) + } + }) + + t.Run("communication_failure", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServerBad(t) + defer closer() + + ui, cmd := testReadCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "secret/foo", + }) + if exp := 2; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Error reading secret/foo: " + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + }) + + t.Run("no_tabs", func(t *testing.T) { + t.Parallel() + + _, cmd := testReadCommand(t) + assertNoTabs(t, cmd) + }) } diff --git a/command/rekey.go b/command/rekey.go deleted file mode 100644 index 1c6b85412f..0000000000 --- a/command/rekey.go +++ /dev/null @@ -1,449 +0,0 @@ -package command - -import ( - "fmt" - "os" - "strings" - - "github.com/fatih/structs" - "github.com/hashicorp/vault/api" - "github.com/hashicorp/vault/helper/password" - "github.com/hashicorp/vault/helper/pgpkeys" - "github.com/hashicorp/vault/meta" - "github.com/posener/complete" -) - -// RekeyCommand is a Command that rekeys the vault. -type RekeyCommand struct { - meta.Meta - - // Key can be used to pre-seed the key. If it is set, it will not - // be asked with the `password` helper. - Key string - - // The nonce for the rekey request to send along - Nonce string - - // Whether to use the recovery key instead of barrier key, if available - RecoveryKey bool -} - -func (c *RekeyCommand) Run(args []string) int { - var init, cancel, status, delete, retrieve, backup, recoveryKey bool - var shares, threshold, storedShares int - var nonce string - var pgpKeys pgpkeys.PubKeyFilesFlag - flags := c.Meta.FlagSet("rekey", meta.FlagSetDefault) - flags.BoolVar(&init, "init", false, "") - flags.BoolVar(&cancel, "cancel", false, "") - flags.BoolVar(&status, "status", false, "") - flags.BoolVar(&delete, "delete", false, "") - flags.BoolVar(&retrieve, "retrieve", false, "") - flags.BoolVar(&backup, "backup", false, "") - flags.BoolVar(&recoveryKey, "recovery-key", c.RecoveryKey, "") - flags.IntVar(&shares, "key-shares", 5, "") - flags.IntVar(&threshold, "key-threshold", 3, "") - flags.IntVar(&storedShares, "stored-shares", 0, "") - flags.StringVar(&nonce, "nonce", "", "") - flags.Var(&pgpKeys, "pgp-keys", "") - flags.Usage = func() { c.Ui.Error(c.Help()) } - if err := flags.Parse(args); err != nil { - return 1 - } - - if nonce != "" { - c.Nonce = nonce - } - - client, err := c.Client() - if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error initializing client: %s", err)) - return 2 - } - - // Check if we are running doing any restricted variants - switch { - case init: - return c.initRekey(client, shares, threshold, storedShares, pgpKeys, backup, recoveryKey) - case cancel: - return c.cancelRekey(client, recoveryKey) - case status: - return c.rekeyStatus(client, recoveryKey) - case retrieve: - return c.rekeyRetrieveStored(client, recoveryKey) - case delete: - return c.rekeyDeleteStored(client, recoveryKey) - } - - // Check if the rekey is started - var rekeyStatus *api.RekeyStatusResponse - if recoveryKey { - rekeyStatus, err = client.Sys().RekeyRecoveryKeyStatus() - } else { - rekeyStatus, err = client.Sys().RekeyStatus() - } - if err != nil { - c.Ui.Error(fmt.Sprintf("Error reading rekey status: %s", err)) - return 1 - } - - // Start the rekey process if not started - if !rekeyStatus.Started { - if recoveryKey { - rekeyStatus, err = client.Sys().RekeyRecoveryKeyInit(&api.RekeyInitRequest{ - SecretShares: shares, - SecretThreshold: threshold, - PGPKeys: pgpKeys, - Backup: backup, - }) - } else { - rekeyStatus, err = client.Sys().RekeyInit(&api.RekeyInitRequest{ - SecretShares: shares, - SecretThreshold: threshold, - PGPKeys: pgpKeys, - Backup: backup, - }) - } - if err != nil { - c.Ui.Error(fmt.Sprintf("Error initializing rekey: %s", err)) - return 1 - } - c.Nonce = rekeyStatus.Nonce - } - - shares = rekeyStatus.N - threshold = rekeyStatus.T - serverNonce := rekeyStatus.Nonce - - // Get the unseal key - args = flags.Args() - key := c.Key - if len(args) > 0 { - key = args[0] - } - if key == "" { - c.Nonce = serverNonce - fmt.Printf("Rekey operation nonce: %s\n", serverNonce) - fmt.Printf("Key (will be hidden): ") - key, err = password.Read(os.Stdin) - fmt.Printf("\n") - if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error attempting to ask for password. The raw error message\n"+ - "is shown below, but the most common reason for this error is\n"+ - "that you attempted to pipe a value into unseal or you're\n"+ - "executing `vault rekey` from outside of a terminal.\n\n"+ - "You should use `vault rekey` from a terminal for maximum\n"+ - "security. If this isn't an option, the unseal key can be passed\n"+ - "in using the first parameter.\n\n"+ - "Raw error: %s", err)) - return 1 - } - } - - // Provide the key, this may potentially complete the update - var result *api.RekeyUpdateResponse - if recoveryKey { - result, err = client.Sys().RekeyRecoveryKeyUpdate(strings.TrimSpace(key), c.Nonce) - } else { - result, err = client.Sys().RekeyUpdate(strings.TrimSpace(key), c.Nonce) - } - if err != nil { - c.Ui.Error(fmt.Sprintf("Error attempting rekey update: %s", err)) - return 1 - } - - // If we are not complete, then dump the status - if !result.Complete { - return c.rekeyStatus(client, recoveryKey) - } - - // Space between the key prompt, if any, and the output - c.Ui.Output("\n") - // Provide the keys - var haveB64 bool - if result.KeysB64 != nil && len(result.KeysB64) == len(result.Keys) { - haveB64 = true - } - for i, key := range result.Keys { - if len(result.PGPFingerprints) > 0 { - if haveB64 { - c.Ui.Output(fmt.Sprintf("Key %d fingerprint: %s; value: %s", i+1, result.PGPFingerprints[i], result.KeysB64[i])) - } else { - c.Ui.Output(fmt.Sprintf("Key %d fingerprint: %s; value: %s", i+1, result.PGPFingerprints[i], key)) - } - } else { - if haveB64 { - c.Ui.Output(fmt.Sprintf("Key %d: %s", i+1, result.KeysB64[i])) - } else { - c.Ui.Output(fmt.Sprintf("Key %d: %s", i+1, key)) - } - } - } - - c.Ui.Output(fmt.Sprintf("\nOperation nonce: %s", result.Nonce)) - - if len(result.PGPFingerprints) > 0 && result.Backup { - c.Ui.Output(fmt.Sprintf( - "\n" + - "The encrypted unseal keys have been backed up to \"core/unseal-keys-backup\"\n" + - "in your physical backend. It is your responsibility to remove these if and\n" + - "when desired.", - )) - } - - c.Ui.Output(fmt.Sprintf( - "\n"+ - "Vault rekeyed with %d keys and a key threshold of %d.\n", - shares, - threshold, - )) - - // Print this message if keys are returned - if len(result.Keys) > 0 { - c.Ui.Output(fmt.Sprintf( - "\n"+ - "Please securely distribute the above keys. When the vault is re-sealed,\n"+ - "restarted, or stopped, you must provide at least %d of these keys\n"+ - "to unseal it again.\n\n"+ - "Vault does not store the master key. Without at least %[1]d keys,\n"+ - "your vault will remain permanently sealed.", - threshold, - )) - } - - return 0 -} - -// initRekey is used to start the rekey process -func (c *RekeyCommand) initRekey(client *api.Client, - shares, threshold, storedShares int, - pgpKeys pgpkeys.PubKeyFilesFlag, - backup, recoveryKey bool) int { - // Start the rekey - request := &api.RekeyInitRequest{ - SecretShares: shares, - SecretThreshold: threshold, - StoredShares: storedShares, - PGPKeys: pgpKeys, - Backup: backup, - } - var status *api.RekeyStatusResponse - var err error - if recoveryKey { - status, err = client.Sys().RekeyRecoveryKeyInit(request) - } else { - status, err = client.Sys().RekeyInit(request) - } - if err != nil { - c.Ui.Error(fmt.Sprintf("Error initializing rekey: %s", err)) - return 1 - } - - if pgpKeys == nil || len(pgpKeys) == 0 { - c.Ui.Output(` -WARNING: If you lose the keys after they are returned to you, there is no -recovery. Consider using the '-pgp-keys' option to protect the returned unseal -keys along with '-backup=true' to allow recovery of the encrypted keys in case -of emergency. They can easily be deleted at a later time with -'vault rekey -delete'. -`) - } - - if pgpKeys != nil && len(pgpKeys) > 0 && !backup { - c.Ui.Output(` -WARNING: You are using PGP keys for encryption, but have not set the option to -back up the new unseal keys to physical storage. If you lose the keys after -they are returned to you, there is no recovery. Consider setting '-backup=true' -to allow recovery of the encrypted keys in case of emergency. They can easily -be deleted at a later time with 'vault rekey -delete'. -`) - } - - // Provide the current status - return c.dumpRekeyStatus(status) -} - -// cancelRekey is used to abort the rekey process -func (c *RekeyCommand) cancelRekey(client *api.Client, recovery bool) int { - var err error - if recovery { - err = client.Sys().RekeyRecoveryKeyCancel() - } else { - err = client.Sys().RekeyCancel() - } - if err != nil { - c.Ui.Error(fmt.Sprintf("Failed to cancel rekey: %s", err)) - return 1 - } - c.Ui.Output("Rekey canceled.") - return 0 -} - -// rekeyStatus is used just to fetch and dump the status -func (c *RekeyCommand) rekeyStatus(client *api.Client, recovery bool) int { - // Check the status - var status *api.RekeyStatusResponse - var err error - if recovery { - status, err = client.Sys().RekeyRecoveryKeyStatus() - } else { - status, err = client.Sys().RekeyStatus() - } - if err != nil { - c.Ui.Error(fmt.Sprintf("Error reading rekey status: %s", err)) - return 1 - } - - return c.dumpRekeyStatus(status) -} - -func (c *RekeyCommand) dumpRekeyStatus(status *api.RekeyStatusResponse) int { - // Dump the status - statString := fmt.Sprintf( - "Nonce: %s\n"+ - "Started: %t\n"+ - "Key Shares: %d\n"+ - "Key Threshold: %d\n"+ - "Rekey Progress: %d\n"+ - "Required Keys: %d", - status.Nonce, - status.Started, - status.N, - status.T, - status.Progress, - status.Required, - ) - if len(status.PGPFingerprints) != 0 { - statString = fmt.Sprintf("%s\nPGP Key Fingerprints: %s", statString, status.PGPFingerprints) - statString = fmt.Sprintf("%s\nBackup Storage: %t", statString, status.Backup) - } - c.Ui.Output(statString) - return 0 -} - -func (c *RekeyCommand) rekeyRetrieveStored(client *api.Client, recovery bool) int { - var storedKeys *api.RekeyRetrieveResponse - var err error - if recovery { - storedKeys, err = client.Sys().RekeyRetrieveRecoveryBackup() - } else { - storedKeys, err = client.Sys().RekeyRetrieveBackup() - } - if err != nil { - c.Ui.Error(fmt.Sprintf("Error retrieving stored keys: %s", err)) - return 1 - } - - secret := &api.Secret{ - Data: structs.New(storedKeys).Map(), - } - - return OutputSecret(c.Ui, "table", secret) -} - -func (c *RekeyCommand) rekeyDeleteStored(client *api.Client, recovery bool) int { - var err error - if recovery { - err = client.Sys().RekeyDeleteRecoveryBackup() - } else { - err = client.Sys().RekeyDeleteBackup() - } - if err != nil { - c.Ui.Error(fmt.Sprintf("Failed to delete stored keys: %s", err)) - return 1 - } - c.Ui.Output("Stored keys deleted.") - return 0 -} - -func (c *RekeyCommand) Synopsis() string { - return "Rekeys Vault to generate new unseal keys" -} - -func (c *RekeyCommand) Help() string { - helpText := ` -Usage: vault rekey [options] [key] - - Rekey is used to change the unseal keys. This can be done to generate - a new set of unseal keys or to change the number of shares and the - required threshold. - - Rekey can only be done when the vault is already unsealed. The operation - is done online, but requires that a threshold of the current unseal - keys be provided. - -General Options: -` + meta.GeneralOptionsUsage() + ` -Rekey Options: - - -init Initialize the rekey operation by setting the desired - number of shares and the key threshold. This can only be - done if no rekey is already initiated. - - -cancel Reset the rekey process by throwing away - prior keys and the rekey configuration. - - -status Prints the status of the current rekey operation. - This can be used to see the status without attempting - to provide an unseal key. - - -retrieve Retrieve backed-up keys. Only available if the PGP keys - were provided and the backup has not been deleted. - - -delete Delete any backed-up keys. - - -key-shares=5 The number of key shares to split the master key - into. - - -key-threshold=3 The number of key shares required to reconstruct - the master key. - - -nonce=abcd The nonce provided at rekey initialization time. This - same nonce value must be provided with each unseal - key. If the unseal key is not being passed in via the - the command line the nonce parameter is not required, - and will instead be displayed with the key prompt. - - -pgp-keys If provided, must be a comma-separated list of - files on disk containing binary- or base64-format - public PGP keys, or Keybase usernames specified as - "keybase:". The number of given entries - must match 'key-shares'. The output unseal keys will - be encrypted and base64-encoded, in order, with the - given public keys. If you want to use them with the - 'vault unseal' command, you will need to base64-decode - and decrypt; this will be the plaintext unseal key. - - -backup=false If true, and if the key shares are PGP-encrypted, a - plaintext backup of the PGP-encrypted keys will be - stored at "core/unseal-keys-backup" in your physical - storage. You can retrieve or delete them via the - 'sys/rekey/backup' endpoint. - - -recovery-key=false Whether to rekey the recovery key instead of the - barrier key. Only used with Vault HSM. -` - return strings.TrimSpace(helpText) -} - -func (c *RekeyCommand) AutocompleteArgs() complete.Predictor { - return complete.PredictNothing -} - -func (c *RekeyCommand) AutocompleteFlags() complete.Flags { - return complete.Flags{ - "-init": complete.PredictNothing, - "-cancel": complete.PredictNothing, - "-status": complete.PredictNothing, - "-retrieve": complete.PredictNothing, - "-delete": complete.PredictNothing, - "-key-shares": complete.PredictNothing, - "-key-threshold": complete.PredictNothing, - "-nonce": complete.PredictNothing, - "-pgp-keys": complete.PredictNothing, - "-backup": complete.PredictNothing, - "-recovery-key": complete.PredictNothing, - } -} diff --git a/command/rekey_test.go b/command/rekey_test.go deleted file mode 100644 index c22f26244f..0000000000 --- a/command/rekey_test.go +++ /dev/null @@ -1,313 +0,0 @@ -package command - -import ( - "context" - "encoding/hex" - "os" - "sort" - "strings" - "testing" - "time" - - "github.com/hashicorp/vault/http" - "github.com/hashicorp/vault/logical" - "github.com/hashicorp/vault/meta" - "github.com/hashicorp/vault/vault" - "github.com/mitchellh/cli" -) - -func TestRekey(t *testing.T) { - core, keys, _ := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() - - ui := new(cli.MockUi) - - for i, key := range keys { - c := &RekeyCommand{ - Key: hex.EncodeToString(key), - RecoveryKey: false, - Meta: meta.Meta{ - Ui: ui, - }, - } - - if i > 0 { - conf, err := core.RekeyConfig(false) - if err != nil { - t.Fatal(err) - } - c.Nonce = conf.Nonce - } - - args := []string{"-address", addr} - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } - } - - config, err := core.SealAccess().BarrierConfig() - if err != nil { - t.Fatalf("err: %s", err) - } - if config.SecretShares != 5 { - t.Fatal("should rekey") - } -} - -func TestRekey_arg(t *testing.T) { - core, keys, _ := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() - - ui := new(cli.MockUi) - - for i, key := range keys { - c := &RekeyCommand{ - RecoveryKey: false, - Meta: meta.Meta{ - Ui: ui, - }, - } - - if i > 0 { - conf, err := core.RekeyConfig(false) - if err != nil { - t.Fatal(err) - } - c.Nonce = conf.Nonce - } - - args := []string{"-address", addr, hex.EncodeToString(key)} - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } - } - - config, err := core.SealAccess().BarrierConfig() - if err != nil { - t.Fatalf("err: %s", err) - } - if config.SecretShares != 5 { - t.Fatal("should rekey") - } -} - -func TestRekey_init(t *testing.T) { - core, _, _ := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() - - ui := new(cli.MockUi) - - c := &RekeyCommand{ - Meta: meta.Meta{ - Ui: ui, - }, - } - - args := []string{ - "-address", addr, - "-init", - "-key-threshold", "10", - "-key-shares", "10", - } - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } - - config, err := core.RekeyConfig(false) - if err != nil { - t.Fatalf("err: %s", err) - } - if config.SecretShares != 10 { - t.Fatal("should rekey") - } - if config.SecretThreshold != 10 { - t.Fatal("should rekey") - } -} - -func TestRekey_cancel(t *testing.T) { - core, keys, _ := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() - - ui := new(cli.MockUi) - c := &RekeyCommand{ - Key: hex.EncodeToString(keys[0]), - Meta: meta.Meta{ - Ui: ui, - }, - } - - args := []string{"-address", addr, "-init"} - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } - - args = []string{"-address", addr, "-cancel"} - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } - - config, err := core.RekeyConfig(false) - if err != nil { - t.Fatalf("err: %s", err) - } - if config != nil { - t.Fatal("should not rekey") - } -} - -func TestRekey_status(t *testing.T) { - core, keys, _ := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() - - ui := new(cli.MockUi) - c := &RekeyCommand{ - Key: hex.EncodeToString(keys[0]), - Meta: meta.Meta{ - Ui: ui, - }, - } - - args := []string{"-address", addr, "-init"} - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } - - args = []string{"-address", addr, "-status"} - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } - - if !strings.Contains(ui.OutputWriter.String(), "Started: true") { - t.Fatalf("bad: %s", ui.OutputWriter.String()) - } -} - -func TestRekey_init_pgp(t *testing.T) { - core, keys, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() - - bc := &logical.BackendConfig{ - Logger: nil, - System: logical.StaticSystemView{ - DefaultLeaseTTLVal: time.Hour * 24, - MaxLeaseTTLVal: time.Hour * 24 * 32, - }, - } - sysBackend := vault.NewSystemBackend(core) - err := sysBackend.Backend.Setup(bc) - if err != nil { - t.Fatal(err) - } - - ui := new(cli.MockUi) - c := &RekeyCommand{ - Meta: meta.Meta{ - Ui: ui, - }, - } - - tempDir, pubFiles, err := getPubKeyFiles(t) - if err != nil { - t.Fatal(err) - } - defer os.RemoveAll(tempDir) - - args := []string{ - "-address", addr, - "-init", - "-key-shares", "4", - "-pgp-keys", pubFiles[0] + ",@" + pubFiles[1] + "," + pubFiles[2] + "," + pubFiles[3], - "-key-threshold", "2", - "-backup", "true", - } - - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } - - config, err := core.RekeyConfig(false) - if err != nil { - t.Fatalf("err: %s", err) - } - if config.SecretShares != 4 { - t.Fatal("should rekey") - } - if config.SecretThreshold != 2 { - t.Fatal("should rekey") - } - - for _, key := range keys { - c = &RekeyCommand{ - Key: hex.EncodeToString(key), - Meta: meta.Meta{ - Ui: ui, - }, - } - - c.Nonce = config.Nonce - - args = []string{ - "-address", addr, - } - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } - } - - type backupStruct struct { - Keys map[string][]string - KeysB64 map[string][]string - } - backupVals := &backupStruct{} - - req := logical.TestRequest(t, logical.ReadOperation, "rekey/backup") - resp, err := sysBackend.HandleRequest(context.Background(), req) - if err != nil { - t.Fatalf("error running backed-up unseal key fetch: %v", err) - } - if resp == nil { - t.Fatalf("got nil resp with unseal key fetch") - } - if resp.Data["keys"] == nil { - t.Fatalf("could not retrieve unseal keys from token") - } - if resp.Data["nonce"] != config.Nonce { - t.Fatalf("nonce mismatch between rekey and backed-up keys") - } - - backupVals.Keys = resp.Data["keys"].(map[string][]string) - backupVals.KeysB64 = resp.Data["keys_base64"].(map[string][]string) - - // Now delete and try again; the values should be inaccessible - req = logical.TestRequest(t, logical.DeleteOperation, "rekey/backup") - resp, err = sysBackend.HandleRequest(context.Background(), req) - if err != nil { - t.Fatalf("error running backed-up unseal key delete: %v", err) - } - req = logical.TestRequest(t, logical.ReadOperation, "rekey/backup") - resp, err = sysBackend.HandleRequest(context.Background(), req) - if err != nil { - t.Fatalf("error running backed-up unseal key fetch: %v", err) - } - if resp == nil { - t.Fatalf("got nil resp with unseal key fetch") - } - if resp.Data["keys"] != nil { - t.Fatalf("keys found when they should have been deleted") - } - - // Sort, because it'll be tested with DeepEqual later - for k, _ := range backupVals.Keys { - sort.Strings(backupVals.Keys[k]) - sort.Strings(backupVals.KeysB64[k]) - } - - parseDecryptAndTestUnsealKeys(t, ui.OutputWriter.String(), token, true, backupVals.Keys, backupVals.KeysB64, core) -} diff --git a/command/remount.go b/command/remount.go deleted file mode 100644 index a36f1410ad..0000000000 --- a/command/remount.go +++ /dev/null @@ -1,74 +0,0 @@ -package command - -import ( - "fmt" - "strings" - - "github.com/hashicorp/vault/meta" -) - -// RemountCommand is a Command that remounts a mounted secret backend -// to a new endpoint. -type RemountCommand struct { - meta.Meta -} - -func (c *RemountCommand) Run(args []string) int { - flags := c.Meta.FlagSet("remount", meta.FlagSetDefault) - flags.Usage = func() { c.Ui.Error(c.Help()) } - if err := flags.Parse(args); err != nil { - return 1 - } - - args = flags.Args() - if len(args) != 2 { - flags.Usage() - c.Ui.Error(fmt.Sprintf( - "\nremount expects two arguments: the from and to path")) - return 1 - } - - from := args[0] - to := args[1] - - client, err := c.Client() - if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error initializing client: %s", err)) - return 2 - } - - if err := client.Sys().Remount(from, to); err != nil { - c.Ui.Error(fmt.Sprintf( - "Unmount error: %s", err)) - return 2 - } - - c.Ui.Output(fmt.Sprintf( - "Successfully remounted from '%s' to '%s'!", from, to)) - - return 0 -} - -func (c *RemountCommand) Synopsis() string { - return "Remount a secret backend to a new path" -} - -func (c *RemountCommand) Help() string { - helpText := ` -Usage: vault remount [options] from to - - Remount a mounted secret backend to a new path. - - This command remounts a secret backend that is already mounted to - a new path. All the secrets from the old path will be revoked, but - the data associated with the backend (such as configuration), will - be preserved. - - Example: vault remount secret/ kv/ - -General Options: -` + meta.GeneralOptionsUsage() - - return strings.TrimSpace(helpText) -} diff --git a/command/remount_test.go b/command/remount_test.go deleted file mode 100644 index 7ec1321432..0000000000 --- a/command/remount_test.go +++ /dev/null @@ -1,52 +0,0 @@ -package command - -import ( - "testing" - - "github.com/hashicorp/vault/http" - "github.com/hashicorp/vault/meta" - "github.com/hashicorp/vault/vault" - "github.com/mitchellh/cli" -) - -func TestRemount(t *testing.T) { - core, _, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() - - ui := new(cli.MockUi) - c := &RemountCommand{ - Meta: meta.Meta{ - ClientToken: token, - Ui: ui, - }, - } - - args := []string{ - "-address", addr, - "secret/", "kv", - } - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } - - client, err := c.Client() - if err != nil { - t.Fatalf("err: %s", err) - } - - mounts, err := client.Sys().ListMounts() - if err != nil { - t.Fatalf("err: %s", err) - } - - _, ok := mounts["secret/"] - if ok { - t.Fatal("should not have mount") - } - - _, ok = mounts["kv/"] - if !ok { - t.Fatal("should have kv") - } -} diff --git a/command/renew.go b/command/renew.go deleted file mode 100644 index 6a3eafe52a..0000000000 --- a/command/renew.go +++ /dev/null @@ -1,90 +0,0 @@ -package command - -import ( - "fmt" - "strconv" - "strings" - - "github.com/hashicorp/vault/meta" -) - -// RenewCommand is a Command that mounts a new mount. -type RenewCommand struct { - meta.Meta -} - -func (c *RenewCommand) Run(args []string) int { - var format string - flags := c.Meta.FlagSet("renew", meta.FlagSetDefault) - flags.StringVar(&format, "format", "table", "") - flags.Usage = func() { c.Ui.Error(c.Help()) } - if err := flags.Parse(args); err != nil { - return 1 - } - - args = flags.Args() - if len(args) < 1 || len(args) >= 3 { - flags.Usage() - c.Ui.Error(fmt.Sprintf( - "\nrenew expects at least one argument: the lease ID to renew")) - return 1 - } - - var increment int - leaseId := args[0] - if len(args) > 1 { - parsed, err := strconv.ParseInt(args[1], 10, 0) - if err != nil { - c.Ui.Error(fmt.Sprintf( - "Invalid increment, must be an int: %s", err)) - return 1 - } - - increment = int(parsed) - } - - client, err := c.Client() - if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error initializing client: %s", err)) - return 2 - } - - secret, err := client.Sys().Renew(leaseId, increment) - if err != nil { - c.Ui.Error(fmt.Sprintf( - "Renew error: %s", err)) - return 1 - } - - return OutputSecret(c.Ui, format, secret) -} - -func (c *RenewCommand) Synopsis() string { - return "Renew the lease of a secret" -} - -func (c *RenewCommand) Help() string { - helpText := ` -Usage: vault renew [options] id [increment] - - Renew the lease on a secret, extending the time that it can be used - before it is revoked by Vault. - - Every secret in Vault has a lease associated with it. If the user of - the secret wants to use it longer than the lease, then it must be - renewed. Renewing the lease will not change the contents of the secret. - - To renew a secret, run this command with the lease ID returned when it - was read. Optionally, request a specific increment in seconds. Vault - is not required to honor this request. - -General Options: -` + meta.GeneralOptionsUsage() + ` -Renew Options: - - -format=table The format for output. By default it is a whitespace- - delimited table. This can also be json or yaml. -` - return strings.TrimSpace(helpText) -} diff --git a/command/renew_test.go b/command/renew_test.go deleted file mode 100644 index 2191662220..0000000000 --- a/command/renew_test.go +++ /dev/null @@ -1,143 +0,0 @@ -package command - -import ( - "testing" - - "github.com/hashicorp/vault/api" - "github.com/hashicorp/vault/http" - "github.com/hashicorp/vault/meta" - "github.com/hashicorp/vault/vault" - "github.com/mitchellh/cli" -) - -func TestRenew(t *testing.T) { - core, _, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() - - ui := new(cli.MockUi) - c := &RenewCommand{ - Meta: meta.Meta{ - ClientToken: token, - Ui: ui, - }, - } - - // write a secret with a lease - client := testClient(t, addr, token) - _, err := client.Logical().Write("secret/foo", map[string]interface{}{ - "key": "value", - "lease": "1m", - }) - if err != nil { - t.Fatalf("err: %s", err) - } - - // read the secret to get its lease ID - secret, err := client.Logical().Read("secret/foo") - if err != nil { - t.Fatalf("err: %s", err) - } - - args := []string{ - "-address", addr, - secret.LeaseID, - } - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } -} - -func TestRenewBothWays(t *testing.T) { - core, _, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() - - // write a secret with a lease - client := testClient(t, addr, token) - _, err := client.Logical().Write("secret/foo", map[string]interface{}{ - "key": "value", - "ttl": "1m", - }) - if err != nil { - t.Fatalf("err: %s", err) - } - - // read the secret to get its lease ID - secret, err := client.Logical().Read("secret/foo") - if err != nil { - t.Fatalf("err: %s", err) - } - - // Test one renew path - r := client.NewRequest("PUT", "/v1/sys/renew") - body := map[string]interface{}{ - "lease_id": secret.LeaseID, - } - if err := r.SetJSONBody(body); err != nil { - t.Fatal(err) - } - resp, err := client.RawRequest(r) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - secret, err = api.ParseSecret(resp.Body) - if err != nil { - t.Fatal(err) - } - if secret.LeaseDuration != 60 { - t.Fatal("bad lease duration") - } - - // Test another - r = client.NewRequest("PUT", "/v1/sys/leases/renew") - body = map[string]interface{}{ - "lease_id": secret.LeaseID, - } - if err := r.SetJSONBody(body); err != nil { - t.Fatal(err) - } - resp, err = client.RawRequest(r) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - secret, err = api.ParseSecret(resp.Body) - if err != nil { - t.Fatal(err) - } - if secret.LeaseDuration != 60 { - t.Fatal("bad lease duration") - } - - // Test the other - r = client.NewRequest("PUT", "/v1/sys/renew/"+secret.LeaseID) - resp, err = client.RawRequest(r) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - secret, err = api.ParseSecret(resp.Body) - if err != nil { - t.Fatal(err) - } - if secret.LeaseDuration != 60 { - t.Fatalf("bad lease duration; secret is %#v\n", *secret) - } - - // Test another - r = client.NewRequest("PUT", "/v1/sys/leases/renew/"+secret.LeaseID) - resp, err = client.RawRequest(r) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - secret, err = api.ParseSecret(resp.Body) - if err != nil { - t.Fatal(err) - } - if secret.LeaseDuration != 60 { - t.Fatalf("bad lease duration; secret is %#v\n", *secret) - } -} diff --git a/command/revoke.go b/command/revoke.go deleted file mode 100644 index 50933ada42..0000000000 --- a/command/revoke.go +++ /dev/null @@ -1,95 +0,0 @@ -package command - -import ( - "fmt" - "strings" - - "github.com/hashicorp/vault/meta" -) - -// RevokeCommand is a Command that mounts a new mount. -type RevokeCommand struct { - meta.Meta -} - -func (c *RevokeCommand) Run(args []string) int { - var prefix, force bool - flags := c.Meta.FlagSet("revoke", meta.FlagSetDefault) - flags.BoolVar(&prefix, "prefix", false, "") - flags.BoolVar(&force, "force", false, "") - flags.Usage = func() { c.Ui.Error(c.Help()) } - if err := flags.Parse(args); err != nil { - return 1 - } - - args = flags.Args() - if len(args) != 1 { - flags.Usage() - c.Ui.Error(fmt.Sprintf( - "\nrevoke expects one argument: the ID to revoke")) - return 1 - } - leaseId := args[0] - - client, err := c.Client() - if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error initializing client: %s", err)) - return 2 - } - - switch { - case force && !prefix: - c.Ui.Error(fmt.Sprintf( - "-force requires -prefix")) - return 1 - case force && prefix: - err = client.Sys().RevokeForce(leaseId) - case prefix: - err = client.Sys().RevokePrefix(leaseId) - default: - err = client.Sys().Revoke(leaseId) - } - if err != nil { - c.Ui.Error(fmt.Sprintf( - "Revoke error: %s", err)) - return 1 - } - - c.Ui.Output(fmt.Sprintf("Success! Revoked the secret with ID '%s', if it existed.", leaseId)) - return 0 -} - -func (c *RevokeCommand) Synopsis() string { - return "Revoke a secret." -} - -func (c *RevokeCommand) Help() string { - helpText := ` -Usage: vault revoke [options] id - - Revoke a secret by its lease ID. - - This command revokes a secret by its lease ID that was returned with it. Once - the key is revoked, it is no longer valid. - - With the -prefix flag, the revoke is done by prefix: any secret prefixed with - the given partial ID is revoked. Lease IDs are structured in such a way to - make revocation of prefixes useful. - - With the -force flag, the lease is removed from Vault even if the revocation - fails. This is meant for certain recovery scenarios and should not be used - lightly. This option requires -prefix. - -General Options: -` + meta.GeneralOptionsUsage() + ` -Revoke Options: - - -prefix=true Revoke all secrets with the matching prefix. This - defaults to false: an exact revocation. - - -force=true Delete the lease even if the actual revocation - operation fails. -` - return strings.TrimSpace(helpText) -} diff --git a/command/revoke_test.go b/command/revoke_test.go deleted file mode 100644 index cb9febf6d0..0000000000 --- a/command/revoke_test.go +++ /dev/null @@ -1,46 +0,0 @@ -package command - -import ( - "testing" - - "github.com/hashicorp/vault/http" - "github.com/hashicorp/vault/meta" - "github.com/hashicorp/vault/vault" - "github.com/mitchellh/cli" -) - -func TestRevoke(t *testing.T) { - core, _, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() - - ui := new(cli.MockUi) - c := &RevokeCommand{ - Meta: meta.Meta{ - ClientToken: token, - Ui: ui, - }, - } - - client := testClient(t, addr, token) - _, err := client.Logical().Write("secret/foo", map[string]interface{}{ - "key": "value", - "lease": "1m", - }) - if err != nil { - t.Fatalf("err: %s", err) - } - - secret, err := client.Logical().Read("secret/foo") - if err != nil { - t.Fatalf("err: %s", err) - } - - args := []string{ - "-address", addr, - secret.LeaseID, - } - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } -} diff --git a/command/rotate.go b/command/rotate.go index 9da387370b..77bc0602b7 100644 --- a/command/rotate.go +++ b/command/rotate.go @@ -4,64 +4,93 @@ import ( "fmt" "strings" - "github.com/hashicorp/vault/meta" + "github.com/mitchellh/cli" + "github.com/posener/complete" ) -// RotateCommand is a Command that rotates the encryption key being used -type RotateCommand struct { - meta.Meta +var _ cli.Command = (*OperatorRotateCommand)(nil) +var _ cli.CommandAutocomplete = (*OperatorRotateCommand)(nil) + +type OperatorRotateCommand struct { + *BaseCommand } -func (c *RotateCommand) Run(args []string) int { - flags := c.Meta.FlagSet("rotate", meta.FlagSetDefault) - flags.Usage = func() { c.Ui.Error(c.Help()) } - if err := flags.Parse(args); err != nil { +func (c *OperatorRotateCommand) Synopsis() string { + return "Rotates the underlying encryption key" +} + +func (c *OperatorRotateCommand) Help() string { + helpText := ` +Usage: vault rotate [options] + + Rotates the underlying encryption key which is used to secure data written + to the storage backend. This installs a new key in the key ring. This new + key is used to encrypted new data, while older keys in the ring are used to + decrypt older data. + + This is an online operation and does not cause downtime. This command is run + per-cluster (not per-server), since Vault servers in HA mode share the same + storage backend. + + Rotate Vault's encryption key: + + $ vault rotate + + For a full list of examples, please see the documentation. + +` + c.Flags().Help() + + return strings.TrimSpace(helpText) +} + +func (c *OperatorRotateCommand) Flags() *FlagSets { + return c.flagSet(FlagSetHTTP) +} + +func (c *OperatorRotateCommand) AutocompleteArgs() complete.Predictor { + return nil +} + +func (c *OperatorRotateCommand) AutocompleteFlags() complete.Flags { + return c.Flags().Completions() +} + +func (c *OperatorRotateCommand) Run(args []string) int { + f := c.Flags() + + if err := f.Parse(args); err != nil { + c.UI.Error(err.Error()) + return 1 + } + + args = f.Args() + if len(args) > 0 { + c.UI.Error(fmt.Sprintf("Too many arguments (expected 0, got %d)", len(args))) return 1 } client, err := c.Client() if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error initializing client: %s", err)) + c.UI.Error(err.Error()) return 2 } // Rotate the key err = client.Sys().Rotate() if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error with key rotation: %s", err)) + c.UI.Error(fmt.Sprintf("Error rotating key: %s", err)) return 2 } // Print the key status status, err := client.Sys().KeyStatus() if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error reading audits: %s", err)) + c.UI.Error(fmt.Sprintf("Error reading key status: %s", err)) return 2 } - c.Ui.Output(fmt.Sprintf("Key Term: %d", status.Term)) - c.Ui.Output(fmt.Sprintf("Installation Time: %v", status.InstallTime)) + c.UI.Output("Success! Rotated key") + c.UI.Output("") + c.UI.Output(printKeyStatus(status)) return 0 } - -func (c *RotateCommand) Synopsis() string { - return "Rotates the backend encryption key used to persist data" -} - -func (c *RotateCommand) Help() string { - helpText := ` -Usage: vault rotate [options] - - Rotates the backend encryption key which is used to secure data - written to the storage backend. This is done by installing a new key - which encrypts new data, while old keys are still used to decrypt - secrets written previously. This is an online operation and is not - disruptive. - -General Options: -` + meta.GeneralOptionsUsage() - return strings.TrimSpace(helpText) -} diff --git a/command/rotate_test.go b/command/rotate_test.go index 257f280071..8f9756de06 100644 --- a/command/rotate_test.go +++ b/command/rotate_test.go @@ -1,31 +1,118 @@ package command import ( + "strings" "testing" - "github.com/hashicorp/vault/http" - "github.com/hashicorp/vault/meta" - "github.com/hashicorp/vault/vault" "github.com/mitchellh/cli" ) -func TestRotate(t *testing.T) { - core, _, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() +func testOperatorRotateCommand(tb testing.TB) (*cli.MockUi, *OperatorRotateCommand) { + tb.Helper() - ui := new(cli.MockUi) - c := &RotateCommand{ - Meta: meta.Meta{ - ClientToken: token, - Ui: ui, + ui := cli.NewMockUi() + return ui, &OperatorRotateCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + } +} + +func TestOperatorRotateCommand_Run(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + args []string + out string + code int + }{ + { + "too_many_args", + []string{"abcd1234"}, + "Too many arguments", + 1, }, } - args := []string{ - "-address", addr, - } - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } + t.Run("validations", func(t *testing.T) { + t.Parallel() + + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + ui, cmd := testOperatorRotateCommand(t) + + code := cmd.Run(tc.args) + if code != tc.code { + t.Errorf("expected %d to be %d", code, tc.code) + } + + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, tc.out) { + t.Errorf("expected %q to contain %q", combined, tc.out) + } + }) + } + }) + + t.Run("default", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + ui, cmd := testOperatorRotateCommand(t) + cmd.client = client + + code := cmd.Run([]string{}) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Success! Rotated key" + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + + status, err := client.Sys().KeyStatus() + if err != nil { + t.Fatal(err) + } + if exp := 1; status.Term < exp { + t.Errorf("expected %d to be less than %d", status.Term, exp) + } + }) + + t.Run("communication_failure", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServerBad(t) + defer closer() + + ui, cmd := testOperatorRotateCommand(t) + cmd.client = client + + code := cmd.Run([]string{}) + if exp := 2; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Error rotating key: " + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + }) + + t.Run("no_tabs", func(t *testing.T) { + t.Parallel() + + _, cmd := testOperatorRotateCommand(t) + assertNoTabs(t, cmd) + }) } diff --git a/command/seal.go b/command/seal.go deleted file mode 100644 index 033c164587..0000000000 --- a/command/seal.go +++ /dev/null @@ -1,63 +0,0 @@ -package command - -import ( - "fmt" - "strings" - - "github.com/hashicorp/vault/meta" -) - -// SealCommand is a Command that seals the vault. -type SealCommand struct { - meta.Meta -} - -func (c *SealCommand) Run(args []string) int { - flags := c.Meta.FlagSet("seal", meta.FlagSetDefault) - flags.Usage = func() { c.Ui.Error(c.Help()) } - if err := flags.Parse(args); err != nil { - return 1 - } - - client, err := c.Client() - if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error initializing client: %s", err)) - return 2 - } - - if err := client.Sys().Seal(); err != nil { - c.Ui.Error(fmt.Sprintf("Error sealing: %s", err)) - return 1 - } - - c.Ui.Output("Vault is now sealed.") - return 0 -} - -func (c *SealCommand) Synopsis() string { - return "Seals the Vault server" -} - -func (c *SealCommand) Help() string { - helpText := ` -Usage: vault seal [options] - - Seal the vault. - - Sealing a vault tells the Vault server to stop responding to any - access operations until it is unsealed again. A sealed vault throws away - its master key to unlock the data, so it is physically blocked from - responding to operations again until the vault is unsealed with - the "unseal" command or via the API. - - This command is idempotent, if the vault is already sealed it does nothing. - - If an unseal has started, sealing the vault will reset the unsealing - process. You'll have to re-enter every portion of the master key again. - This is the same as running "vault unseal -reset". - -General Options: -` + meta.GeneralOptionsUsage() - return strings.TrimSpace(helpText) -} diff --git a/command/seal_test.go b/command/seal_test.go deleted file mode 100644 index c224aee31e..0000000000 --- a/command/seal_test.go +++ /dev/null @@ -1,37 +0,0 @@ -package command - -import ( - "testing" - - "github.com/hashicorp/vault/http" - "github.com/hashicorp/vault/meta" - "github.com/hashicorp/vault/vault" - "github.com/mitchellh/cli" -) - -func Test_Seal(t *testing.T) { - core, _, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() - - ui := new(cli.MockUi) - c := &SealCommand{ - Meta: meta.Meta{ - ClientToken: token, - Ui: ui, - }, - } - - args := []string{"-address", addr} - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } - - sealed, err := core.Sealed() - if err != nil { - t.Fatalf("err: %s", err) - } - if !sealed { - t.Fatal("should be sealed") - } -} diff --git a/command/secrets.go b/command/secrets.go new file mode 100644 index 0000000000..06e63bec28 --- /dev/null +++ b/command/secrets.go @@ -0,0 +1,43 @@ +package command + +import ( + "strings" + + "github.com/mitchellh/cli" +) + +var _ cli.Command = (*SecretsCommand)(nil) + +type SecretsCommand struct { + *BaseCommand +} + +func (c *SecretsCommand) Synopsis() string { + return "Interact with secrets engines" +} + +func (c *SecretsCommand) Help() string { + helpText := ` +Usage: vault secrets [options] [args] + + This command groups subcommands for interacting with Vault's secrets engines. + Each secret engine behaves differently. Please see the documentation for + more information. + + List all enabled secrets engines: + + $ vault secrets list + + Enable a new secrets engine: + + $ vault secrets enable database + + Please see the individual subcommand help for detailed usage information. +` + + return strings.TrimSpace(helpText) +} + +func (c *SecretsCommand) Run(args []string) int { + return cli.RunResultHelp +} diff --git a/command/secrets_disable.go b/command/secrets_disable.go new file mode 100644 index 0000000000..7002874409 --- /dev/null +++ b/command/secrets_disable.go @@ -0,0 +1,84 @@ +package command + +import ( + "fmt" + "strings" + + "github.com/mitchellh/cli" + "github.com/posener/complete" +) + +var _ cli.Command = (*SecretsDisableCommand)(nil) +var _ cli.CommandAutocomplete = (*SecretsDisableCommand)(nil) + +type SecretsDisableCommand struct { + *BaseCommand +} + +func (c *SecretsDisableCommand) Synopsis() string { + return "Disable a secret engine" +} + +func (c *SecretsDisableCommand) Help() string { + helpText := ` +Usage: vault secrets disable [options] PATH + + Disables a secrets engine at the given PATH. The argument corresponds to + the enabled PATH of the engine, not the TYPE! All secrets created by this + engine are revoked and its Vault data is removed. + + Disable the secrets engine enabled at aws/: + + $ vault secrets disable aws/ + +` + c.Flags().Help() + + return strings.TrimSpace(helpText) +} + +func (c *SecretsDisableCommand) Flags() *FlagSets { + return c.flagSet(FlagSetHTTP) +} + +func (c *SecretsDisableCommand) AutocompleteArgs() complete.Predictor { + return c.PredictVaultMounts() +} + +func (c *SecretsDisableCommand) AutocompleteFlags() complete.Flags { + return c.Flags().Completions() +} + +func (c *SecretsDisableCommand) Run(args []string) int { + f := c.Flags() + + if err := f.Parse(args); err != nil { + c.UI.Error(err.Error()) + return 1 + } + + args = f.Args() + switch { + case len(args) < 1: + c.UI.Error(fmt.Sprintf("Not enough arguments (expected 1, got %d)", len(args))) + return 1 + case len(args) > 1: + c.UI.Error(fmt.Sprintf("Too many arguments (expected 1, got %d)", len(args))) + return 1 + } + + client, err := c.Client() + if err != nil { + c.UI.Error(err.Error()) + return 2 + } + + path := ensureTrailingSlash(sanitizePath(args[0])) + + if err := client.Sys().Unmount(path); err != nil { + c.UI.Error(fmt.Sprintf("Error disabling secrets engine at %s: %s", path, err)) + return 2 + } + + c.UI.Output(fmt.Sprintf("Success! Disabled the secrets engine (if it existed) at: %s", path)) + return 0 +} diff --git a/command/secrets_disable_test.go b/command/secrets_disable_test.go new file mode 100644 index 0000000000..567c8956d6 --- /dev/null +++ b/command/secrets_disable_test.go @@ -0,0 +1,152 @@ +package command + +import ( + "strings" + "testing" + + "github.com/hashicorp/vault/api" + "github.com/mitchellh/cli" +) + +func testSecretsDisableCommand(tb testing.TB) (*cli.MockUi, *SecretsDisableCommand) { + tb.Helper() + + ui := cli.NewMockUi() + return ui, &SecretsDisableCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + } +} + +func TestSecretsDisableCommand_Run(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + args []string + out string + code int + }{ + { + "not_enough_args", + []string{}, + "Not enough arguments", + 1, + }, + { + "too_many_args", + []string{"foo", "bar"}, + "Too many arguments", + 1, + }, + { + "not_real", + []string{"not_real"}, + "Success! Disabled the secrets engine (if it existed) at: not_real/", + 0, + }, + { + "default", + []string{"secret"}, + "Success! Disabled the secrets engine (if it existed) at: secret/", + 0, + }, + } + + t.Run("validations", func(t *testing.T) { + t.Parallel() + + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + ui, cmd := testSecretsDisableCommand(t) + cmd.client = client + + code := cmd.Run(tc.args) + if code != tc.code { + t.Errorf("expected %d to be %d", code, tc.code) + } + + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, tc.out) { + t.Errorf("expected %q to contain %q", combined, tc.out) + } + }) + } + }) + + t.Run("integration", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + if err := client.Sys().Mount("my-secret/", &api.MountInput{ + Type: "generic", + }); err != nil { + t.Fatal(err) + } + + ui, cmd := testSecretsDisableCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "my-secret/", + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Success! Disabled the secrets engine (if it existed) at: my-secret/" + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + + mounts, err := client.Sys().ListMounts() + if err != nil { + t.Fatal(err) + } + + if _, ok := mounts["integration_unmount"]; ok { + t.Errorf("expected mount to not exist: %#v", mounts) + } + }) + + t.Run("communication_failure", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServerBad(t) + defer closer() + + ui, cmd := testSecretsDisableCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "pki/", + }) + if exp := 2; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Error disabling secrets engine at pki/: " + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + }) + + t.Run("no_tabs", func(t *testing.T) { + t.Parallel() + + _, cmd := testSecretsDisableCommand(t) + assertNoTabs(t, cmd) + }) +} diff --git a/command/secrets_enable.go b/command/secrets_enable.go new file mode 100644 index 0000000000..e31d77ec24 --- /dev/null +++ b/command/secrets_enable.go @@ -0,0 +1,217 @@ +package command + +import ( + "fmt" + "strings" + "time" + + "github.com/hashicorp/vault/api" + "github.com/mitchellh/cli" + "github.com/posener/complete" +) + +var _ cli.Command = (*SecretsEnableCommand)(nil) +var _ cli.CommandAutocomplete = (*SecretsEnableCommand)(nil) + +type SecretsEnableCommand struct { + *BaseCommand + + flagDescription string + flagPath string + flagDefaultLeaseTTL time.Duration + flagMaxLeaseTTL time.Duration + flagForceNoCache bool + flagPluginName string + flagLocal bool + flagSealWrap bool +} + +func (c *SecretsEnableCommand) Synopsis() string { + return "Enable a secrets engine" +} + +func (c *SecretsEnableCommand) Help() string { + helpText := ` +Usage: vault secrets enable [options] TYPE + + Enables a secrets engine. By default, secrets engines are enabled at the path + corresponding to their TYPE, but users can customize the path using the + -path option. + + Once enabled, Vault will route all requests which begin with the path to the + secrets engine. + + Enable the AWS secrets engine at aws/: + + $ vault secrets enable aws + + Enable the SSH secrets engine at ssh-prod/: + + $ vault secrets enable -path=ssh-prod ssh + + Enable the database secrets engine with an explicit maximum TTL of 30m: + + $ vault secrets enable -max-lease-ttl=30m database + + Enable a custom plugin (after it is registered in the plugin registry): + + $ vault secrets enable -path=my-secrets -plugin-name=my-plugin plugin + + For a full list of secrets engines and examples, please see the documentation. + +` + c.Flags().Help() + + return strings.TrimSpace(helpText) +} + +func (c *SecretsEnableCommand) Flags() *FlagSets { + set := c.flagSet(FlagSetHTTP) + + f := set.NewFlagSet("Command Options") + + f.StringVar(&StringVar{ + Name: "description", + Target: &c.flagDescription, + Completion: complete.PredictAnything, + Usage: "Human-friendly description for the purpose of this engine.", + }) + + f.StringVar(&StringVar{ + Name: "path", + Target: &c.flagPath, + Default: "", // The default is complex, so we have to manually document + Completion: complete.PredictAnything, + Usage: "Place where the secrets engine will be accessible. This must be " + + "unique cross all secrets engines. This defaults to the \"type\" of the " + + "secrets engine.", + }) + + f.DurationVar(&DurationVar{ + Name: "default-lease-ttl", + Target: &c.flagDefaultLeaseTTL, + Completion: complete.PredictAnything, + Usage: "The default lease TTL for this secrets engine. If unspecified, " + + "this defaults to the Vault server's globally configured default lease " + + "TTL.", + }) + + f.DurationVar(&DurationVar{ + Name: "max-lease-ttl", + Target: &c.flagMaxLeaseTTL, + Completion: complete.PredictAnything, + Usage: "The maximum lease TTL for this secrets engine. If unspecified, " + + "this defaults to the Vault server's globally configured maximum lease " + + "TTL.", + }) + + f.BoolVar(&BoolVar{ + Name: "force-no-cache", + Target: &c.flagForceNoCache, + Default: false, + Usage: "Force the secrets engine to disable caching. If unspecified, this " + + "defaults to the Vault server's globally configured cache settings. " + + "This does not affect caching of the underlying encrypted data storage.", + }) + + f.StringVar(&StringVar{ + Name: "plugin-name", + Target: &c.flagPluginName, + Completion: complete.PredictAnything, + Usage: "Name of the secrets engine plugin. This plugin name must already " + + "exist in Vault's plugin catalog.", + }) + + f.BoolVar(&BoolVar{ + Name: "local", + Target: &c.flagLocal, + Default: false, + Usage: "Mark the secrets engine as local-only. Local engines are not " + + "replicated or removed by replication.", + }) + + f.BoolVar(&BoolVar{ + Name: "seal-wrap", + Target: &c.flagSealWrap, + Default: false, + Usage: "Enable seal wrapping of critical values in the secrets engine.", + }) + + return set +} + +func (c *SecretsEnableCommand) AutocompleteArgs() complete.Predictor { + return c.PredictVaultAvailableMounts() +} + +func (c *SecretsEnableCommand) AutocompleteFlags() complete.Flags { + return c.Flags().Completions() +} + +func (c *SecretsEnableCommand) Run(args []string) int { + f := c.Flags() + + if err := f.Parse(args); err != nil { + c.UI.Error(err.Error()) + return 1 + } + + args = f.Args() + switch { + case len(args) < 1: + c.UI.Error(fmt.Sprintf("Not enough arguments (expected 1, got %d)", len(args))) + return 1 + case len(args) > 1: + c.UI.Error(fmt.Sprintf("Too many arguments (expected 1, got %d)", len(args))) + return 1 + } + + client, err := c.Client() + if err != nil { + c.UI.Error(err.Error()) + return 2 + } + + // Get the engine type type (first arg) + engineType := strings.TrimSpace(args[0]) + + // If no path is specified, we default the path to the backend type + // or use the plugin name if it's a plugin backend + mountPath := c.flagPath + if mountPath == "" { + if engineType == "plugin" { + mountPath = c.flagPluginName + } else { + mountPath = engineType + } + } + + // Append a trailing slash to indicate it's a path in output + mountPath = ensureTrailingSlash(mountPath) + + // Build mount input + mountInput := &api.MountInput{ + Type: engineType, + Description: c.flagDescription, + Local: c.flagLocal, + SealWrap: c.flagSealWrap, + Config: api.MountConfigInput{ + DefaultLeaseTTL: c.flagDefaultLeaseTTL.String(), + MaxLeaseTTL: c.flagMaxLeaseTTL.String(), + ForceNoCache: c.flagForceNoCache, + PluginName: c.flagPluginName, + }, + } + + if err := client.Sys().Mount(mountPath, mountInput); err != nil { + c.UI.Error(fmt.Sprintf("Error enabling: %s", err)) + return 2 + } + + thing := engineType + " secrets engine" + if engineType == "plugin" { + thing = c.flagPluginName + " plugin" + } + + c.UI.Output(fmt.Sprintf("Success! Enabled the %s at: %s", thing, mountPath)) + return 0 +} diff --git a/command/secrets_enable_test.go b/command/secrets_enable_test.go new file mode 100644 index 0000000000..e241edfa65 --- /dev/null +++ b/command/secrets_enable_test.go @@ -0,0 +1,171 @@ +package command + +import ( + "strings" + "testing" + + "github.com/mitchellh/cli" +) + +func testSecretsEnableCommand(tb testing.TB) (*cli.MockUi, *SecretsEnableCommand) { + tb.Helper() + + ui := cli.NewMockUi() + return ui, &SecretsEnableCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + } +} + +func TestSecretsEnableCommand_Run(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + args []string + out string + code int + }{ + { + "not_enough_args", + []string{}, + "Not enough arguments", + 1, + }, + { + "too_many_args", + []string{"foo", "bar"}, + "Too many arguments", + 1, + }, + { + "not_a_valid_mount", + []string{"nope_definitely_not_a_valid_mount_like_ever"}, + "", + 2, + }, + { + "mount", + []string{"transit"}, + "Success! Enabled the transit secrets engine at: transit/", + 0, + }, + { + "mount_path", + []string{ + "-path", "transit_mount_point", + "transit", + }, + "Success! Enabled the transit secrets engine at: transit_mount_point/", + 0, + }, + } + + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + ui, cmd := testSecretsEnableCommand(t) + cmd.client = client + + code := cmd.Run(tc.args) + if code != tc.code { + t.Errorf("expected %d to be %d", code, tc.code) + } + + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, tc.out) { + t.Errorf("expected %q to contain %q", combined, tc.out) + } + }) + } + + t.Run("integration", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + ui, cmd := testSecretsEnableCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "-path", "mount_integration/", + "-description", "The best kind of test", + "-default-lease-ttl", "30m", + "-max-lease-ttl", "1h", + "-force-no-cache", + "pki", + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Success! Enabled the pki secrets engine at: mount_integration/" + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + + mounts, err := client.Sys().ListMounts() + if err != nil { + t.Fatal(err) + } + + mountInfo, ok := mounts["mount_integration/"] + if !ok { + t.Fatalf("expected mount to exist") + } + if exp := "pki"; mountInfo.Type != exp { + t.Errorf("expected %q to be %q", mountInfo.Type, exp) + } + if exp := "The best kind of test"; mountInfo.Description != exp { + t.Errorf("expected %q to be %q", mountInfo.Description, exp) + } + if exp := 1800; mountInfo.Config.DefaultLeaseTTL != exp { + t.Errorf("expected %d to be %d", mountInfo.Config.DefaultLeaseTTL, exp) + } + if exp := 3600; mountInfo.Config.MaxLeaseTTL != exp { + t.Errorf("expected %d to be %d", mountInfo.Config.MaxLeaseTTL, exp) + } + if exp := true; mountInfo.Config.ForceNoCache != exp { + t.Errorf("expected %t to be %t", mountInfo.Config.ForceNoCache, exp) + } + }) + + t.Run("communication_failure", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServerBad(t) + defer closer() + + ui, cmd := testSecretsEnableCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "pki", + }) + if exp := 2; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Error enabling: " + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + }) + + t.Run("no_tabs", func(t *testing.T) { + t.Parallel() + + _, cmd := testSecretsEnableCommand(t) + assertNoTabs(t, cmd) + }) +} diff --git a/command/secrets_list.go b/command/secrets_list.go new file mode 100644 index 0000000000..f50f618085 --- /dev/null +++ b/command/secrets_list.go @@ -0,0 +1,169 @@ +package command + +import ( + "fmt" + "sort" + "strconv" + "strings" + + "github.com/hashicorp/vault/api" + "github.com/mitchellh/cli" + "github.com/posener/complete" +) + +var _ cli.Command = (*SecretsListCommand)(nil) +var _ cli.CommandAutocomplete = (*SecretsListCommand)(nil) + +type SecretsListCommand struct { + *BaseCommand + + flagDetailed bool +} + +func (c *SecretsListCommand) Synopsis() string { + return "List enabled secrets engines" +} + +func (c *SecretsListCommand) Help() string { + helpText := ` +Usage: vault secrets list [options] + + Lists the enabled secret engines on the Vault server. This command also + outputs information about the enabled path including configured TTLs and + human-friendly descriptions. A TTL of "system" indicates that the system + default is in use. + + List all enabled secrets engines: + + $ vault secrets list + + List all enabled secrets engines with detailed output: + + $ vault secrets list -detailed + +` + c.Flags().Help() + + return strings.TrimSpace(helpText) +} + +func (c *SecretsListCommand) Flags() *FlagSets { + set := c.flagSet(FlagSetHTTP) + + f := set.NewFlagSet("Command Options") + + f.BoolVar(&BoolVar{ + Name: "detailed", + Target: &c.flagDetailed, + Default: false, + Usage: "Print detailed information such as TTLs and replication status " + + "about each secrets engine.", + }) + + return set +} + +func (c *SecretsListCommand) AutocompleteArgs() complete.Predictor { + return c.PredictVaultFiles() +} + +func (c *SecretsListCommand) AutocompleteFlags() complete.Flags { + return c.Flags().Completions() +} + +func (c *SecretsListCommand) Run(args []string) int { + f := c.Flags() + + if err := f.Parse(args); err != nil { + c.UI.Error(err.Error()) + return 1 + } + + args = f.Args() + if len(args) > 0 { + c.UI.Error(fmt.Sprintf("Too many arguments (expected 0, got %d)", len(args))) + return 1 + } + + client, err := c.Client() + if err != nil { + c.UI.Error(err.Error()) + return 2 + } + + mounts, err := client.Sys().ListMounts() + if err != nil { + c.UI.Error(fmt.Sprintf("Error listing secrets engines: %s", err)) + return 2 + } + + if c.flagDetailed { + c.UI.Output(tableOutput(c.detailedMounts(mounts), nil)) + return 0 + } + + c.UI.Output(tableOutput(c.simpleMounts(mounts), nil)) + return 0 +} + +func (c *SecretsListCommand) simpleMounts(mounts map[string]*api.MountOutput) []string { + paths := make([]string, 0, len(mounts)) + for path := range mounts { + paths = append(paths, path) + } + sort.Strings(paths) + + out := []string{"Path | Type | Description"} + for _, path := range paths { + mount := mounts[path] + out = append(out, fmt.Sprintf("%s | %s | %s", path, mount.Type, mount.Description)) + } + + return out +} + +func (c *SecretsListCommand) detailedMounts(mounts map[string]*api.MountOutput) []string { + paths := make([]string, 0, len(mounts)) + for path := range mounts { + paths = append(paths, path) + } + sort.Strings(paths) + + calcTTL := func(typ string, ttl int) string { + switch { + case typ == "system", typ == "cubbyhole": + return "" + case ttl != 0: + return strconv.Itoa(ttl) + default: + return "system" + } + } + + out := []string{"Path | Type | Accessor | Plugin | Default TTL | Max TTL | Force No Cache | Replication | Seal Wrap | Description"} + for _, path := range paths { + mount := mounts[path] + + defaultTTL := calcTTL(mount.Type, mount.Config.DefaultLeaseTTL) + maxTTL := calcTTL(mount.Type, mount.Config.MaxLeaseTTL) + + replication := "replicated" + if mount.Local { + replication = "local" + } + + out = append(out, fmt.Sprintf("%s | %s | %s | %s | %s | %s | %t | %s | %t | %s", + path, + mount.Type, + mount.Accessor, + mount.Config.PluginName, + defaultTTL, + maxTTL, + mount.Config.ForceNoCache, + replication, + mount.SealWrap, + mount.Description, + )) + } + + return out +} diff --git a/command/secrets_list_test.go b/command/secrets_list_test.go new file mode 100644 index 0000000000..9edb628202 --- /dev/null +++ b/command/secrets_list_test.go @@ -0,0 +1,105 @@ +package command + +import ( + "strings" + "testing" + + "github.com/mitchellh/cli" +) + +func testSecretsListCommand(tb testing.TB) (*cli.MockUi, *SecretsListCommand) { + tb.Helper() + + ui := cli.NewMockUi() + return ui, &SecretsListCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + } +} + +func TestSecretsListCommand_Run(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + args []string + out string + code int + }{ + { + "too_many_args", + []string{"foo"}, + "Too many arguments", + 1, + }, + { + "lists", + nil, + "Path", + 0, + }, + { + "detailed", + []string{"-detailed"}, + "Default TTL", + 0, + }, + } + + t.Run("validations", func(t *testing.T) { + t.Parallel() + + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + ui, cmd := testSecretsListCommand(t) + cmd.client = client + + code := cmd.Run(tc.args) + if code != tc.code { + t.Errorf("expected %d to be %d", code, tc.code) + } + + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, tc.out) { + t.Errorf("expected %q to contain %q", combined, tc.out) + } + }) + } + }) + + t.Run("communication_failure", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServerBad(t) + defer closer() + + ui, cmd := testSecretsListCommand(t) + cmd.client = client + + code := cmd.Run([]string{}) + if exp := 2; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Error listing secrets engines: " + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + }) + + t.Run("no_tabs", func(t *testing.T) { + t.Parallel() + + _, cmd := testSecretsListCommand(t) + assertNoTabs(t, cmd) + }) +} diff --git a/command/secrets_move.go b/command/secrets_move.go new file mode 100644 index 0000000000..542b14d3d2 --- /dev/null +++ b/command/secrets_move.go @@ -0,0 +1,89 @@ +package command + +import ( + "fmt" + "strings" + + "github.com/mitchellh/cli" + "github.com/posener/complete" +) + +var _ cli.Command = (*SecretsMoveCommand)(nil) +var _ cli.CommandAutocomplete = (*SecretsMoveCommand)(nil) + +type SecretsMoveCommand struct { + *BaseCommand +} + +func (c *SecretsMoveCommand) Synopsis() string { + return "Move a secrets engine to a new path" +} + +func (c *SecretsMoveCommand) Help() string { + helpText := ` +Usage: vault secrets move [options] SOURCE DESTINATION + + Moves an existing secrets engine to a new path. Any leases from the old + secrets engine are revoked, but all configuration associated with the engine + is preserved. + + WARNING! Moving an existing secrets engine will revoke any leases from the + old engine. + + Move the existing secrets engine at secret/ to generic/: + + $ vault secrets move secret/ generic/ + +` + c.Flags().Help() + + return strings.TrimSpace(helpText) +} + +func (c *SecretsMoveCommand) Flags() *FlagSets { + return c.flagSet(FlagSetHTTP) +} + +func (c *SecretsMoveCommand) AutocompleteArgs() complete.Predictor { + return c.PredictVaultMounts() +} + +func (c *SecretsMoveCommand) AutocompleteFlags() complete.Flags { + return c.Flags().Completions() +} + +func (c *SecretsMoveCommand) Run(args []string) int { + f := c.Flags() + + if err := f.Parse(args); err != nil { + c.UI.Error(err.Error()) + return 1 + } + + args = f.Args() + switch { + case len(args) < 2: + c.UI.Error(fmt.Sprintf("Not enough arguments (expected 2, got %d)", len(args))) + return 1 + case len(args) > 2: + c.UI.Error(fmt.Sprintf("Too many arguments (expected 2, got %d)", len(args))) + return 1 + } + + // Grab the source and destination + source := ensureTrailingSlash(args[0]) + destination := ensureTrailingSlash(args[1]) + + client, err := c.Client() + if err != nil { + c.UI.Error(err.Error()) + return 2 + } + + if err := client.Sys().Remount(source, destination); err != nil { + c.UI.Error(fmt.Sprintf("Error moving secrets engine %s to %s: %s", source, destination, err)) + return 2 + } + + c.UI.Output(fmt.Sprintf("Success! Moved secrets engine %s to: %s", source, destination)) + return 0 +} diff --git a/command/secrets_move_test.go b/command/secrets_move_test.go new file mode 100644 index 0000000000..0936a7dd30 --- /dev/null +++ b/command/secrets_move_test.go @@ -0,0 +1,135 @@ +package command + +import ( + "strings" + "testing" + + "github.com/mitchellh/cli" +) + +func testSecretsMoveCommand(tb testing.TB) (*cli.MockUi, *SecretsMoveCommand) { + tb.Helper() + + ui := cli.NewMockUi() + return ui, &SecretsMoveCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + } +} + +func TestSecretsMoveCommand_Run(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + args []string + out string + code int + }{ + { + "not_enough_args", + []string{}, + "Not enough arguments", + 1, + }, + { + "too_many_args", + []string{"foo", "bar", "baz"}, + "Too many arguments", + 1, + }, + { + "non_existent", + []string{"not_real", "over_here"}, + "Error moving secrets engine not_real/ to over_here/", + 2, + }, + } + + t.Run("validations", func(t *testing.T) { + t.Parallel() + + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + ui, cmd := testSecretsMoveCommand(t) + + code := cmd.Run(tc.args) + if code != tc.code { + t.Errorf("expected %d to be %d", code, tc.code) + } + + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, tc.out) { + t.Errorf("expected %q to contain %q", combined, tc.out) + } + }) + } + }) + + t.Run("integration", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + ui, cmd := testSecretsMoveCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "secret/", "generic/", + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Success! Moved secrets engine secret/ to: generic/" + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + + mounts, err := client.Sys().ListMounts() + if err != nil { + t.Fatal(err) + } + + if _, ok := mounts["generic/"]; !ok { + t.Errorf("expected mount at generic/: %#v", mounts) + } + }) + + t.Run("communication_failure", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServerBad(t) + defer closer() + + ui, cmd := testSecretsMoveCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "secret/", "generic/", + }) + if exp := 2; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Error moving secrets engine secret/ to generic/:" + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + }) + + t.Run("no_tabs", func(t *testing.T) { + t.Parallel() + + _, cmd := testSecretsMoveCommand(t) + assertNoTabs(t, cmd) + }) +} diff --git a/command/secrets_tune.go b/command/secrets_tune.go new file mode 100644 index 0000000000..b2029b7507 --- /dev/null +++ b/command/secrets_tune.go @@ -0,0 +1,119 @@ +package command + +import ( + "fmt" + "strings" + "time" + + "github.com/hashicorp/vault/api" + "github.com/mitchellh/cli" + "github.com/posener/complete" +) + +var _ cli.Command = (*SecretsTuneCommand)(nil) +var _ cli.CommandAutocomplete = (*SecretsTuneCommand)(nil) + +type SecretsTuneCommand struct { + *BaseCommand + + flagDefaultLeaseTTL time.Duration + flagMaxLeaseTTL time.Duration +} + +func (c *SecretsTuneCommand) Synopsis() string { + return "Tune a secrets engine configuration" +} + +func (c *SecretsTuneCommand) Help() string { + helpText := ` +Usage: vault secrets tune [options] PATH + + Tunes the configuration options for the secrets engine at the given PATH. + The argument corresponds to the PATH where the secrets engine is enabled, + not the TYPE! + + Tune the default lease for the PKI secrets engine: + + $ vault secrets tune -default-lease-ttl=72h pki/ + +` + c.Flags().Help() + + return strings.TrimSpace(helpText) +} + +func (c *SecretsTuneCommand) Flags() *FlagSets { + set := c.flagSet(FlagSetHTTP) + + f := set.NewFlagSet("Command Options") + + f.DurationVar(&DurationVar{ + Name: "default-lease-ttl", + Target: &c.flagDefaultLeaseTTL, + Default: 0, + EnvVar: "", + Completion: complete.PredictAnything, + Usage: "The default lease TTL for this secrets engine. If unspecified, " + + "this defaults to the Vault server's globally configured default lease " + + "TTL, or a previously configured value for the secrets engine.", + }) + + f.DurationVar(&DurationVar{ + Name: "max-lease-ttl", + Target: &c.flagMaxLeaseTTL, + Default: 0, + EnvVar: "", + Completion: complete.PredictAnything, + Usage: "The maximum lease TTL for this secrets engine. If unspecified, " + + "this defaults to the Vault server's globally configured maximum lease " + + "TTL, or a previously configured value for the secrets engine.", + }) + + return set +} + +func (c *SecretsTuneCommand) AutocompleteArgs() complete.Predictor { + return c.PredictVaultMounts() +} + +func (c *SecretsTuneCommand) AutocompleteFlags() complete.Flags { + return c.Flags().Completions() +} + +func (c *SecretsTuneCommand) Run(args []string) int { + f := c.Flags() + + if err := f.Parse(args); err != nil { + c.UI.Error(err.Error()) + return 1 + } + + args = f.Args() + switch { + case len(args) < 1: + c.UI.Error(fmt.Sprintf("Not enough arguments (expected 1, got %d)", len(args))) + return 1 + case len(args) > 1: + c.UI.Error(fmt.Sprintf("Too many arguments (expected 1, got %d)", len(args))) + return 1 + } + + client, err := c.Client() + if err != nil { + c.UI.Error(err.Error()) + return 2 + } + + // Append a trailing slash to indicate it's a path in output + mountPath := ensureTrailingSlash(sanitizePath(args[0])) + + if err := client.Sys().TuneMount(mountPath, api.MountConfigInput{ + DefaultLeaseTTL: ttlToAPI(c.flagDefaultLeaseTTL), + MaxLeaseTTL: ttlToAPI(c.flagMaxLeaseTTL), + }); err != nil { + c.UI.Error(fmt.Sprintf("Error tuning secrets engine %s: %s", mountPath, err)) + return 2 + } + + c.UI.Output(fmt.Sprintf("Success! Tuned the secrets engine at: %s", mountPath)) + return 0 +} diff --git a/command/secrets_tune_test.go b/command/secrets_tune_test.go new file mode 100644 index 0000000000..11d90263d3 --- /dev/null +++ b/command/secrets_tune_test.go @@ -0,0 +1,149 @@ +package command + +import ( + "strings" + "testing" + + "github.com/hashicorp/vault/api" + "github.com/mitchellh/cli" +) + +func testSecretsTuneCommand(tb testing.TB) (*cli.MockUi, *SecretsTuneCommand) { + tb.Helper() + + ui := cli.NewMockUi() + return ui, &SecretsTuneCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + } +} + +func TestSecretsTuneCommand_Run(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + args []string + out string + code int + }{ + { + "not_enough_args", + []string{}, + "Not enough arguments", + 1, + }, + { + "too_many_args", + []string{"foo", "bar"}, + "Too many arguments", + 1, + }, + } + + t.Run("validations", func(t *testing.T) { + t.Parallel() + + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + ui, cmd := testSecretsTuneCommand(t) + + code := cmd.Run(tc.args) + if code != tc.code { + t.Errorf("expected %d to be %d", code, tc.code) + } + + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, tc.out) { + t.Errorf("expected %q to contain %q", combined, tc.out) + } + }) + } + }) + + t.Run("integration", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + ui, cmd := testSecretsTuneCommand(t) + cmd.client = client + + // Mount + if err := client.Sys().Mount("mount_tune_integration", &api.MountInput{ + Type: "pki", + }); err != nil { + t.Fatal(err) + } + + code := cmd.Run([]string{ + "-default-lease-ttl", "30m", + "-max-lease-ttl", "1h", + "mount_tune_integration/", + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Success! Tuned the secrets engine at: mount_tune_integration/" + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + + mounts, err := client.Sys().ListMounts() + if err != nil { + t.Fatal(err) + } + + mountInfo, ok := mounts["mount_tune_integration/"] + if !ok { + t.Fatalf("expected mount to exist") + } + if exp := "pki"; mountInfo.Type != exp { + t.Errorf("expected %q to be %q", mountInfo.Type, exp) + } + if exp := 1800; mountInfo.Config.DefaultLeaseTTL != exp { + t.Errorf("expected %d to be %d", mountInfo.Config.DefaultLeaseTTL, exp) + } + if exp := 3600; mountInfo.Config.MaxLeaseTTL != exp { + t.Errorf("expected %d to be %d", mountInfo.Config.MaxLeaseTTL, exp) + } + }) + + t.Run("communication_failure", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServerBad(t) + defer closer() + + ui, cmd := testSecretsTuneCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "pki/", + }) + if exp := 2; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Error tuning secrets engine pki/: " + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + }) + + t.Run("no_tabs", func(t *testing.T) { + t.Parallel() + + _, cmd := testSecretsTuneCommand(t) + assertNoTabs(t, cmd) + }) +} diff --git a/command/server.go b/command/server.go index 17bee97eeb..4943cb7092 100644 --- a/command/server.go +++ b/command/server.go @@ -8,18 +8,17 @@ import ( "net/http" "net/url" "os" - "os/signal" "path/filepath" "runtime" "sort" "strconv" "strings" "sync" - "syscall" "time" colorable "github.com/mattn/go-colorable" log "github.com/mgutz/logxi/v1" + "github.com/mitchellh/cli" testing "github.com/mitchellh/go-testing-interface" "github.com/posener/complete" @@ -33,7 +32,6 @@ import ( "github.com/hashicorp/go-multierror" "github.com/hashicorp/vault/audit" "github.com/hashicorp/vault/command/server" - "github.com/hashicorp/vault/helper/flag-slice" "github.com/hashicorp/vault/helper/gated-writer" "github.com/hashicorp/vault/helper/logbridge" "github.com/hashicorp/vault/helper/logformat" @@ -42,14 +40,17 @@ import ( "github.com/hashicorp/vault/helper/reload" vaulthttp "github.com/hashicorp/vault/http" "github.com/hashicorp/vault/logical" - "github.com/hashicorp/vault/meta" "github.com/hashicorp/vault/physical" "github.com/hashicorp/vault/vault" "github.com/hashicorp/vault/version" ) -// ServerCommand is a Command that starts the Vault server. +var _ cli.Command = (*ServerCommand)(nil) +var _ cli.CommandAutocomplete = (*ServerCommand)(nil) + type ServerCommand struct { + *BaseCommand + AuditBackends map[string]audit.Factory CredentialBackends map[string]logical.Factory LogicalBackends map[string]logical.Factory @@ -60,8 +61,6 @@ type ServerCommand struct { WaitGroup *sync.WaitGroup - meta.Meta - logGate *gatedwriter.Writer logger log.Logger @@ -69,30 +68,200 @@ type ServerCommand struct { reloadFuncsLock *sync.RWMutex reloadFuncs *map[string][]reload.ReloadFunc + startedCh chan (struct{}) // for tests + reloadedCh chan (struct{}) // for tests + + // new stuff + flagConfigs []string + flagLogLevel string + flagDev bool + flagDevRootTokenID string + flagDevListenAddr string + + flagDevPluginDir string + flagDevHA bool + flagDevLatency int + flagDevLatencyJitter int + flagDevLeasedKV bool + flagDevSkipInit bool + flagDevThreeNode bool + flagDevTransactional bool + flagTestVerifyOnly bool +} + +func (c *ServerCommand) Synopsis() string { + return "Start a Vault server" +} + +func (c *ServerCommand) Help() string { + helpText := ` +Usage: vault server [options] + + This command starts a Vault server that responds to API requests. By default, + Vault will start in a "sealed" state. The Vault cluster must be initialized + before use, usually by the "vault init" command. Each Vault server must also + be unsealed using the "vault unseal" command or the API before the server can + respond to requests. + + Start a server with a configuration file: + + $ vault server -config=/etc/vault/config.hcl + + Run in "dev" mode: + + $ vault server -dev -dev-root-token-id="root" + + For a full list of examples, please see the documentation. + +` + c.Flags().Help() + return strings.TrimSpace(helpText) +} + +func (c *ServerCommand) Flags() *FlagSets { + set := c.flagSet(FlagSetHTTP) + + f := set.NewFlagSet("Command Options") + + f.StringSliceVar(&StringSliceVar{ + Name: "config", + Target: &c.flagConfigs, + Completion: complete.PredictOr( + complete.PredictFiles("*.hcl"), + complete.PredictFiles("*.json"), + complete.PredictDirs("*"), + ), + Usage: "Path to a configuration file or directory of configuration " + + "files. This flag can be specified multiple times to load multiple " + + "configurations. If the path is a directory, all files which end in " + + ".hcl or .json are loaded.", + }) + + f.StringVar(&StringVar{ + Name: "log-level", + Target: &c.flagLogLevel, + Default: "info", + EnvVar: "VAULT_LOG_LEVEL", + Completion: complete.PredictSet("trace", "debug", "info", "warn", "err"), + Usage: "Log verbosity level. Supported values (in order of detail) are " + + "\"trace\", \"debug\", \"info\", \"warn\", and \"err\".", + }) + + f = set.NewFlagSet("Dev Options") + + f.BoolVar(&BoolVar{ + Name: "dev", + Target: &c.flagDev, + Usage: "Enable development mode. In this mode, Vault runs in-memory and " + + "starts unsealed. As the name implies, do not run \"dev\" mode in " + + "production.", + }) + + f.StringVar(&StringVar{ + Name: "dev-root-token-id", + Target: &c.flagDevRootTokenID, + Default: "", + EnvVar: "VAULT_DEV_ROOT_TOKEN_ID", + Usage: "Initial root token. This only applies when running in \"dev\" " + + "mode.", + }) + + f.StringVar(&StringVar{ + Name: "dev-listen-address", + Target: &c.flagDevListenAddr, + Default: "127.0.0.1:8200", + EnvVar: "VAULT_DEV_LISTEN_ADDRESS", + Usage: "Address to bind to in \"dev\" mode.", + }) + + // Internal-only flags to follow. + // + // Why hello there little source code reader! Welcome to the Vault source + // code. The remaining options are intentionally undocumented and come with + // no warranty or backwards-compatability promise. Do not use these flags + // in production. Do not build automation using these flags. Unless you are + // developing against Vault, you should not need any of these flags. + + f.StringVar(&StringVar{ + Name: "dev-plugin-dir", + Target: &c.flagDevPluginDir, + Default: "", + Completion: complete.PredictDirs("*"), + Hidden: true, + }) + + f.BoolVar(&BoolVar{ + Name: "dev-ha", + Target: &c.flagDevHA, + Default: false, + Hidden: true, + }) + + f.BoolVar(&BoolVar{ + Name: "dev-transactional", + Target: &c.flagDevTransactional, + Default: false, + Hidden: true, + }) + + f.IntVar(&IntVar{ + Name: "dev-latency", + Target: &c.flagDevLatency, + Hidden: true, + }) + + f.IntVar(&IntVar{ + Name: "dev-latency-jitter", + Target: &c.flagDevLatencyJitter, + Hidden: true, + }) + + f.BoolVar(&BoolVar{ + Name: "dev-leased-kv", + Target: &c.flagDevLeasedKV, + Default: false, + Hidden: true, + }) + + f.BoolVar(&BoolVar{ + Name: "dev-skip-init", + Target: &c.flagDevSkipInit, + Default: false, + Hidden: true, + }) + + f.BoolVar(&BoolVar{ + Name: "dev-three-node", + Target: &c.flagDevThreeNode, + Default: false, + Hidden: true, + }) + + // TODO: should this be a public flag? + f.BoolVar(&BoolVar{ + Name: "test-verify-only", + Target: &c.flagTestVerifyOnly, + Default: false, + Hidden: true, + }) + + // End internal-only flags. + + return set +} + +func (c *ServerCommand) AutocompleteArgs() complete.Predictor { + return complete.PredictNothing +} + +func (c *ServerCommand) AutocompleteFlags() complete.Flags { + return c.Flags().Completions() } func (c *ServerCommand) Run(args []string) int { - var dev, verifyOnly, devHA, devTransactional, devLeasedKV, devThreeNode, devSkipInit bool - var configPath []string - var logLevelFlag, devRootTokenID, devListenAddress, devPluginDir string - var devLatency, devLatencyJitter int - flags := c.Meta.FlagSet("server", meta.FlagSetDefault) - flags.BoolVar(&dev, "dev", false, "") - flags.StringVar(&devRootTokenID, "dev-root-token-id", "", "") - flags.StringVar(&devListenAddress, "dev-listen-address", "", "") - flags.StringVar(&devPluginDir, "dev-plugin-dir", "", "") - flags.StringVar(&logLevelFlag, "log-level", "", "") - flags.IntVar(&devLatency, "dev-latency", 0, "") - flags.IntVar(&devLatencyJitter, "dev-latency-jitter", 20, "") - flags.BoolVar(&verifyOnly, "verify-only", false, "") - flags.BoolVar(&devHA, "dev-ha", false, "") - flags.BoolVar(&devTransactional, "dev-transactional", false, "") - flags.BoolVar(&devLeasedKV, "dev-leased-kv", false, "") - flags.BoolVar(&devThreeNode, "dev-three-node", false, "") - flags.BoolVar(&devSkipInit, "dev-skip-init", false, "") - flags.Usage = func() { c.Ui.Output(c.Help()) } - flags.Var((*sliceflag.StringFlag)(&configPath), "config", "config") - if err := flags.Parse(args); err != nil { + f := c.Flags() + + if err := f.Parse(args); err != nil { + c.UI.Error(err.Error()) return 1 } @@ -100,14 +269,8 @@ func (c *ServerCommand) Run(args []string) int { // start logging too early. c.logGate = &gatedwriter.Writer{Writer: colorable.NewColorable(os.Stderr)} var level int - var logLevel string - if os.Getenv("VAULT_LOG_LEVEL") != "" { - logLevel = os.Getenv("VAULT_LOG_LEVEL") - } - if logLevelFlag != "" { - logLevel = strings.ToLower(strings.TrimSpace(logLevelFlag)) - } - switch logLevel { + c.flagLogLevel = strings.ToLower(strings.TrimSpace(c.flagLogLevel)) + switch c.flagLogLevel { case "trace": level = log.LevelTrace case "debug": @@ -121,7 +284,7 @@ func (c *ServerCommand) Run(args []string) int { case "err", "error": level = log.LevelError default: - c.Ui.Output(fmt.Sprintf("Unknown log level %s", logLevel)) + c.UI.Error(fmt.Sprintf("Unknown log level: %s", c.flagLogLevel)) return 1 } @@ -131,7 +294,7 @@ func (c *ServerCommand) Run(args []string) int { } switch strings.ToLower(logFormat) { case "vault", "vault_json", "vault-json", "vaultjson", "json", "": - if devThreeNode { + if c.flagDevThreeNode { c.logger = logbridge.NewLogger(hclog.New(&hclog.LoggerOptions{ Mutex: &sync.Mutex{}, Output: c.logGate, @@ -148,45 +311,37 @@ func (c *ServerCommand) Run(args []string) int { log: os.Getenv("VAULT_GRPC_LOGGING") != "", }) - if os.Getenv("VAULT_DEV_ROOT_TOKEN_ID") != "" && devRootTokenID == "" { - devRootTokenID = os.Getenv("VAULT_DEV_ROOT_TOKEN_ID") - } - - if os.Getenv("VAULT_DEV_LISTEN_ADDRESS") != "" && devListenAddress == "" { - devListenAddress = os.Getenv("VAULT_DEV_LISTEN_ADDRESS") - } - - if devHA || devTransactional || devLeasedKV || devThreeNode { - dev = true + // Automatically enable dev mode if other dev flags are provided. + if c.flagDevHA || c.flagDevTransactional || c.flagDevLeasedKV || c.flagDevThreeNode { + c.flagDev = true } // Validation - if !dev { + if !c.flagDev { switch { - case len(configPath) == 0: - c.Ui.Output("At least one config path must be specified with -config") - flags.Usage() - return 1 - case devRootTokenID != "": - c.Ui.Output("Root token ID can only be specified with -dev") - flags.Usage() + case len(c.flagConfigs) == 0: + c.UI.Error("Must specify at least one config path using -config") return 1 + case c.flagDevRootTokenID != "": + c.UI.Warn(wrapAtLength( + "You cannot specify a custom root token ID outside of \"dev\" mode. " + + "Your request has been ignored.")) + c.flagDevRootTokenID = "" } } // Load the configuration var config *server.Config - if dev { - config = server.DevConfig(devHA, devTransactional) - if devListenAddress != "" { - config.Listeners[0].Config["address"] = devListenAddress + if c.flagDev { + config = server.DevConfig(c.flagDevHA, c.flagDevTransactional) + if c.flagDevListenAddr != "" { + config.Listeners[0].Config["address"] = c.flagDevListenAddr } } - for _, path := range configPath { + for _, path := range c.flagConfigs { current, err := server.LoadConfig(path, c.logger) if err != nil { - c.Ui.Output(fmt.Sprintf( - "Error loading configuration from %s: %s", path, err)) + c.UI.Error(fmt.Sprintf("Error loading configuration from %s: %s", path, err)) return 1 } @@ -199,49 +354,51 @@ func (c *ServerCommand) Run(args []string) int { // Ensure at least one config was found. if config == nil { - c.Ui.Output("No configuration files found.") + c.UI.Output(wrapAtLength( + "No configuration files found. Please provide configurations with the " + + "-config flag. If you are supply the path to a directory, please " + + "ensure the directory contains files with the .hcl or .json " + + "extension.")) return 1 } // Ensure that a backend is provided if config.Storage == nil { - c.Ui.Output("A storage backend must be specified") + c.UI.Output("A storage backend must be specified") return 1 } // If mlockall(2) isn't supported, show a warning. We disable this // in dev because it is quite scary to see when first using Vault. - if !dev && !mlock.Supported() { - c.Ui.Output("==> WARNING: mlock not supported on this system!\n") - c.Ui.Output(" An `mlockall(2)`-like syscall to prevent memory from being") - c.Ui.Output(" swapped to disk is not supported on this system. Running") - c.Ui.Output(" Vault on an mlockall(2) enabled system is much more secure.\n") + if !c.flagDev && !mlock.Supported() { + c.UI.Warn(wrapAtLength( + "WARNING! mlock is not supported on this system! An mlockall(2)-like " + + "syscall to prevent memory from being swapped to disk is not " + + "supported on this system. For better security, only run Vault on " + + "systems where this call is supported. If you are running Vault " + + "in a Docker container, provide the IPC_LOCK cap to the container.")) } if err := c.setupTelemetry(config); err != nil { - c.Ui.Output(fmt.Sprintf("Error initializing telemetry: %s", err)) + c.UI.Error(fmt.Sprintf("Error initializing telemetry: %s", err)) return 1 } // Initialize the backend factory, exists := c.PhysicalBackends[config.Storage.Type] if !exists { - c.Ui.Output(fmt.Sprintf( - "Unknown storage type %s", - config.Storage.Type)) + c.UI.Error(fmt.Sprintf("Unknown storage type %s", config.Storage.Type)) return 1 } backend, err := factory(config.Storage.Config, c.logger) if err != nil { - c.Ui.Output(fmt.Sprintf( - "Error initializing storage of type %s: %s", - config.Storage.Type, err)) + c.UI.Error(fmt.Sprintf("Error initializing storage of type %s: %s", config.Storage.Type, err)) return 1 } infoKeys := make([]string, 0, 10) info := make(map[string]string) - info["log level"] = logLevel + info["log level"] = c.flagLogLevel infoKeys = append(infoKeys, "log level") var seal vault.Seal = &vault.DefaultSeal{} @@ -251,13 +408,13 @@ func (c *ServerCommand) Run(args []string) int { if seal != nil { err = seal.Finalize() if err != nil { - c.Ui.Error(fmt.Sprintf("Error finalizing seals: %v", err)) + c.UI.Error(fmt.Sprintf("Error finalizing seals: %v", err)) } } }() if seal == nil { - c.Ui.Error(fmt.Sprintf("Could not create seal; most likely proper Seal configuration information was not set, but no error was generated.")) + c.UI.Error(fmt.Sprintf("Could not create seal! Most likely proper Seal configuration information was not set, but no error was generated.")) return 1 } @@ -279,27 +436,26 @@ func (c *ServerCommand) Run(args []string) int { PluginDirectory: config.PluginDirectory, EnableRaw: config.EnableRawEndpoint, } - - if dev { - coreConfig.DevToken = devRootTokenID - if devLeasedKV { + if c.flagDev { + coreConfig.DevToken = c.flagDevRootTokenID + if c.flagDevLeasedKV { coreConfig.LogicalBackends["kv"] = vault.LeasedPassthroughBackendFactory } - if devPluginDir != "" { - coreConfig.PluginDirectory = devPluginDir + if c.flagDevPluginDir != "" { + coreConfig.PluginDirectory = c.flagDevPluginDir } - if devLatency > 0 { - injectLatency := time.Duration(devLatency) * time.Millisecond + if c.flagDevLatency > 0 { + injectLatency := time.Duration(c.flagDevLatency) * time.Millisecond if _, txnOK := backend.(physical.Transactional); txnOK { - coreConfig.Physical = physical.NewTransactionalLatencyInjector(backend, injectLatency, devLatencyJitter, c.logger) + coreConfig.Physical = physical.NewTransactionalLatencyInjector(backend, injectLatency, c.flagDevLatencyJitter, c.logger) } else { - coreConfig.Physical = physical.NewLatencyInjector(backend, injectLatency, devLatencyJitter, c.logger) + coreConfig.Physical = physical.NewLatencyInjector(backend, injectLatency, c.flagDevLatencyJitter, c.logger) } } } - if devThreeNode { - return c.enableThreeNodeDevCluster(coreConfig, info, infoKeys, devListenAddress, os.Getenv("VAULT_DEV_TEMP_DIR")) + if c.flagDevThreeNode { + return c.enableThreeNodeDevCluster(coreConfig, info, infoKeys, c.flagDevListenAddr, os.Getenv("VAULT_DEV_TEMP_DIR")) } var disableClustering bool @@ -309,26 +465,25 @@ func (c *ServerCommand) Run(args []string) int { if config.HAStorage != nil { factory, exists := c.PhysicalBackends[config.HAStorage.Type] if !exists { - c.Ui.Output(fmt.Sprintf( - "Unknown HA storage type %s", - config.HAStorage.Type)) + c.UI.Error(fmt.Sprintf("Unknown HA storage type %s", config.HAStorage.Type)) return 1 + } habackend, err := factory(config.HAStorage.Config, c.logger) if err != nil { - c.Ui.Output(fmt.Sprintf( - "Error initializing HA storage of type %s: %s", - config.HAStorage.Type, err)) + c.UI.Error(fmt.Sprintf( + "Error initializing HA storage of type %s: %s", config.HAStorage.Type, err)) return 1 + } if coreConfig.HAPhysical, ok = habackend.(physical.HABackend); !ok { - c.Ui.Output("Specified HA storage does not support HA") + c.UI.Error("Specified HA storage does not support HA") return 1 } if !coreConfig.HAPhysical.HAEnabled() { - c.Ui.Output("Specified HA storage has HA support disabled; please consult documentation") + c.UI.Error("Specified HA storage has HA support disabled; please consult documentation") return 1 } @@ -365,14 +520,14 @@ func (c *ServerCommand) Run(args []string) int { if ok && coreConfig.RedirectAddr == "" { redirect, err := c.detectRedirect(detect, config) if err != nil { - c.Ui.Output(fmt.Sprintf("Error detecting redirect address: %s", err)) + c.UI.Error(fmt.Sprintf("Error detecting redirect address: %s", err)) } else if redirect == "" { - c.Ui.Output("Failed to detect redirect address.") + c.UI.Error("Failed to detect redirect address.") } else { coreConfig.RedirectAddr = redirect } } - if coreConfig.RedirectAddr == "" && dev { + if coreConfig.RedirectAddr == "" && c.flagDev { coreConfig.RedirectAddr = fmt.Sprintf("http://%s", config.Listeners[0].Config["address"]) } @@ -387,14 +542,15 @@ func (c *ServerCommand) Run(args []string) int { switch { case coreConfig.ClusterAddr == "" && coreConfig.RedirectAddr != "": addrToUse = coreConfig.RedirectAddr - case dev: + case c.flagDev: addrToUse = fmt.Sprintf("http://%s", config.Listeners[0].Config["address"]) default: goto CLUSTER_SYNTHESIS_COMPLETE } u, err := url.ParseRequestURI(addrToUse) if err != nil { - c.Ui.Output(fmt.Sprintf("Error parsing synthesized cluster address %s: %v", addrToUse, err)) + c.UI.Error(fmt.Sprintf( + "Error parsing synthesized cluster address %s: %v", addrToUse, err)) return 1 } host, port, err := net.SplitHostPort(u.Host) @@ -404,13 +560,14 @@ func (c *ServerCommand) Run(args []string) int { host = u.Host port = "443" } else { - c.Ui.Output(fmt.Sprintf("Error parsing redirect address: %v", err)) + c.UI.Error(fmt.Sprintf("Error parsing redirect address: %v", err)) return 1 } } nPort, err := strconv.Atoi(port) if err != nil { - c.Ui.Output(fmt.Sprintf("Error parsing synthesized address; failed to convert %q to a numeric: %v", port, err)) + c.UI.Error(fmt.Sprintf( + "Error parsing synthesized address; failed to convert %q to a numeric: %v", port, err)) return 1 } u.Host = net.JoinHostPort(host, strconv.Itoa(nPort+1)) @@ -425,8 +582,8 @@ CLUSTER_SYNTHESIS_COMPLETE: // Force https as we'll always be TLS-secured u, err := url.ParseRequestURI(coreConfig.ClusterAddr) if err != nil { - c.Ui.Output(fmt.Sprintf("Error parsing cluster address %s: %v", coreConfig.RedirectAddr, err)) - return 1 + c.UI.Error(fmt.Sprintf("Error parsing cluster address %s: %v", coreConfig.RedirectAddr, err)) + return 11 } u.Scheme = "https" coreConfig.ClusterAddr = u.String() @@ -436,7 +593,7 @@ CLUSTER_SYNTHESIS_COMPLETE: core, newCoreError := vault.NewCore(coreConfig) if newCoreError != nil { if !errwrap.ContainsType(newCoreError, new(vault.NonFatalError)) { - c.Ui.Output(fmt.Sprintf("Error initializing core: %s", newCoreError)) + c.UI.Error(fmt.Sprintf("Error initializing core: %s", newCoreError)) return 1 } } @@ -447,6 +604,7 @@ CLUSTER_SYNTHESIS_COMPLETE: // Compile server information for output later info["storage"] = config.Storage.Type + info["log level"] = c.flagLogLevel info["mlock"] = fmt.Sprintf( "supported: %v, enabled: %v", mlock.Supported(), !config.DisableMlock && mlock.Supported()) @@ -481,11 +639,9 @@ CLUSTER_SYNTHESIS_COMPLETE: c.reloadFuncsLock.Lock() lns := make([]net.Listener, 0, len(config.Listeners)) for i, lnConfig := range config.Listeners { - ln, props, reloadFunc, err := server.NewListener(lnConfig.Type, lnConfig.Config, c.logGate, c.Ui) + ln, props, reloadFunc, err := server.NewListener(lnConfig.Type, lnConfig.Config, c.logGate, c.UI) if err != nil { - c.Ui.Output(fmt.Sprintf( - "Error initializing listener of type %s: %s", - lnConfig.Type, err)) + c.UI.Error(fmt.Sprintf("Error initializing listener of type %s: %s", lnConfig.Type, err)) return 1 } @@ -505,16 +661,14 @@ CLUSTER_SYNTHESIS_COMPLETE: addr = addrRaw.(string) tcpAddr, err := net.ResolveTCPAddr("tcp", addr) if err != nil { - c.Ui.Output(fmt.Sprintf( - "Error resolving cluster_address: %s", - err)) + c.UI.Error(fmt.Sprintf("Error resolving cluster_address: %s", err)) return 1 } clusterAddrs = append(clusterAddrs, tcpAddr) } else { tcpAddr, ok := ln.Addr().(*net.TCPAddr) if !ok { - c.Ui.Output("Failed to parse tcp listener") + c.UI.Error("Failed to parse tcp listener") return 1 } clusterAddr := &net.TCPAddr{ @@ -572,17 +726,19 @@ CLUSTER_SYNTHESIS_COMPLETE: // Server configuration output padding := 24 sort.Strings(infoKeys) - c.Ui.Output("==> Vault server configuration:\n") + c.UI.Output("==> Vault server configuration:\n") for _, k := range infoKeys { - c.Ui.Output(fmt.Sprintf( + c.UI.Output(fmt.Sprintf( "%s%s: %s", strings.Repeat(" ", padding-len(k)), strings.Title(k), info[k])) } - c.Ui.Output("") + c.UI.Output("") - if verifyOnly { + // Tests might not want to start a vault server and just want to verify + // the configuration. + if c.flagTestVerifyOnly { return 0 } @@ -596,7 +752,7 @@ CLUSTER_SYNTHESIS_COMPLETE: err = core.UnsealWithStoredKeys() if err != nil { if !errwrap.ContainsType(err, new(vault.NonFatalError)) { - c.Ui.Output(fmt.Sprintf("Error initializing core: %s", err)) + c.UI.Error(fmt.Sprintf("Error initializing core: %s", err)) return 1 } } @@ -626,18 +782,17 @@ CLUSTER_SYNTHESIS_COMPLETE: } if err := sd.RunServiceDiscovery(c.WaitGroup, c.ShutdownCh, coreConfig.RedirectAddr, activeFunc, sealedFunc); err != nil { - c.Ui.Output(fmt.Sprintf("Error initializing service discovery: %v", err)) + c.UI.Error(fmt.Sprintf("Error initializing service discovery: %v", err)) return 1 } } } // If we're in Dev mode, then initialize the core - if dev && !devSkipInit { + if c.flagDev && !c.flagDevSkipInit { init, err := c.enableDev(core, coreConfig) if err != nil { - c.Ui.Output(fmt.Sprintf( - "Error initializing Dev mode: %s", err)) + c.UI.Error(fmt.Sprintf("Error initializing Dev mode: %s", err)) return 1 } @@ -648,38 +803,43 @@ CLUSTER_SYNTHESIS_COMPLETE: quote = "" } - c.Ui.Output(fmt.Sprint( - "==> WARNING: Dev mode is enabled!\n\n" + - "In this mode, Vault is completely in-memory and unsealed.\n" + - "Vault is configured to only have a single unseal key. The root\n" + - "token has already been authenticated with the CLI, so you can\n" + - "immediately begin using the Vault CLI.\n\n" + - "The only step you need to take is to set the following\n" + - "environment variables:\n\n" + - " " + export + " VAULT_ADDR=" + quote + "http://" + config.Listeners[0].Config["address"].(string) + quote + "\n\n" + - "The unseal key and root token are reproduced below in case you\n" + - "want to seal/unseal the Vault or play with authentication.\n", - )) + // Print the big dev mode warning! + c.UI.Warn(wrapAtLength( + "WARNING! dev mode is enabled! In this mode, Vault runs entirely " + + "in-memory and starts unsealed with a single unseal key. The root " + + "token is already authenticated to the CLI, so you can immediately " + + "begin using Vault.")) + c.UI.Warn("") + c.UI.Warn("You may need to set the following environment variable:") + c.UI.Warn("") + c.UI.Warn(fmt.Sprintf(" $ %s VAULT_ADDR=%s%s%s", + export, quote, "http://"+config.Listeners[0].Config["address"].(string), quote)) // Unseal key is not returned if stored shares is supported if len(init.SecretShares) > 0 { - c.Ui.Output(fmt.Sprintf( - "Unseal Key: %s", - base64.StdEncoding.EncodeToString(init.SecretShares[0]), - )) + c.UI.Warn("") + c.UI.Warn(wrapAtLength( + "The unseal key and root token are displayed below in case you want " + + "to seal/unseal the Vault or re-authenticate.")) + c.UI.Warn("") + c.UI.Warn(fmt.Sprintf("Unseal Key: %s", base64.StdEncoding.EncodeToString(init.SecretShares[0]))) } if len(init.RecoveryShares) > 0 { - c.Ui.Output(fmt.Sprintf( - "Recovery Key: %s", - base64.StdEncoding.EncodeToString(init.RecoveryShares[0]), - )) + c.UI.Warn("") + c.UI.Warn(wrapAtLength( + "The recovery key and root token are displayed below in case you want " + + "to seal/unseal the Vault or re-authenticate.")) + c.UI.Warn("") + c.UI.Warn(fmt.Sprintf("Unseal Key: %s", base64.StdEncoding.EncodeToString(init.RecoveryShares[0]))) } - c.Ui.Output(fmt.Sprintf( - "Root Token: %s\n", - init.RootToken, - )) + c.UI.Warn(fmt.Sprintf("Root Token: %s", init.RootToken)) + + c.UI.Warn("") + c.UI.Warn(wrapAtLength( + "Development mode should NOT be used in production installations!")) + c.UI.Warn("") } // Initialize the HTTP servers @@ -691,25 +851,33 @@ CLUSTER_SYNTHESIS_COMPLETE: } if newCoreError != nil { - c.Ui.Output("==> Warning:\n\nNon-fatal error during initialization; check the logs for more information.") - c.Ui.Output("") + c.UI.Warn(wrapAtLength( + "WARNING! A non-fatal error occurred during initialization. Please " + + "check the logs for more information.")) + c.UI.Warn("") } // Output the header that the server has started - c.Ui.Output("==> Vault server started! Log data will stream in below:\n") + c.UI.Output("==> Vault server started! Log data will stream in below:\n") + + // Inform any tests that the server is ready + select { + case c.startedCh <- struct{}{}: + default: + } // Release the log gate. c.logGate.Flush() // Write out the PID to the file now that server has successfully started if err := c.storePidFile(config.PidFile); err != nil { - c.Ui.Output(fmt.Sprintf("Error storing PID: %v", err)) + c.UI.Error(fmt.Sprintf("Error storing PID: %s", err)) return 1 } defer func() { if err := c.removePidFile(config.PidFile); err != nil { - c.Ui.Output(fmt.Sprintf("Error deleting the PID file: %v", err)) + c.UI.Error(fmt.Sprintf("Error deleting the PID file: %s", err)) } }() @@ -719,7 +887,7 @@ CLUSTER_SYNTHESIS_COMPLETE: for !shutdownTriggered { select { case <-c.ShutdownCh: - c.Ui.Output("==> Vault shutdown triggered") + c.UI.Output("==> Vault shutdown triggered") // Stop the listners so that we don't process further client requests. c.cleanupGuard.Do(listenerCloseFunc) @@ -728,15 +896,15 @@ CLUSTER_SYNTHESIS_COMPLETE: // request forwarding listeners will also be closed (and also // waited for). if err := core.Shutdown(); err != nil { - c.Ui.Output(fmt.Sprintf("Error with core shutdown: %s", err)) + c.UI.Error(fmt.Sprintf("Error with core shutdown: %s", err)) } shutdownTriggered = true case <-c.SighupCh: - c.Ui.Output("==> Vault reload triggered") - if err := c.Reload(c.reloadFuncsLock, c.reloadFuncs, configPath); err != nil { - c.Ui.Output(fmt.Sprintf("Error(s) were encountered during reload: %s", err)) + c.UI.Output("==> Vault reload triggered") + if err := c.Reload(c.reloadFuncsLock, c.reloadFuncs, c.flagConfigs); err != nil { + c.UI.Error(fmt.Sprintf("Error(s) were encountered during reload: %s", err)) } } } @@ -866,7 +1034,7 @@ func (c *ServerCommand) enableDev(core *vault.Core, coreConfig *vault.CoreConfig func (c *ServerCommand) enableThreeNodeDevCluster(base *vault.CoreConfig, info map[string]string, infoKeys []string, devListenAddress, tempDir string) int { testCluster := vault.NewTestCluster(&testing.RuntimeT{}, base, &vault.TestClusterOptions{ HandlerFunc: vaulthttp.Handler, - BaseListenAddress: devListenAddress, + BaseListenAddress: c.flagDevListenAddr, RawLogger: c.logger, TempDir: tempDir, }) @@ -896,15 +1064,15 @@ func (c *ServerCommand) enableThreeNodeDevCluster(base *vault.CoreConfig, info m // Server configuration output padding := 24 sort.Strings(infoKeys) - c.Ui.Output("==> Vault server configuration:\n") + c.UI.Output("==> Vault server configuration:\n") for _, k := range infoKeys { - c.Ui.Output(fmt.Sprintf( + c.UI.Output(fmt.Sprintf( "%s%s: %s", strings.Repeat(" ", padding-len(k)), strings.Title(k), info[k])) } - c.Ui.Output("") + c.UI.Output("") for _, core := range testCluster.Cores { core.Server.Handler = vaulthttp.Handler(core.Core) @@ -928,15 +1096,15 @@ func (c *ServerCommand) enableThreeNodeDevCluster(base *vault.CoreConfig, info m } resp, err := testCluster.Cores[0].HandleRequest(req) if err != nil { - c.Ui.Output(fmt.Sprintf("failed to create root token with ID %s: %s", base.DevToken, err)) + c.UI.Error(fmt.Sprintf("failed to create root token with ID %s: %s", base.DevToken, err)) return 1 } if resp == nil { - c.Ui.Output(fmt.Sprintf("nil response when creating root token with ID %s", base.DevToken)) + c.UI.Error(fmt.Sprintf("nil response when creating root token with ID %s", base.DevToken)) return 1 } if resp.Auth == nil { - c.Ui.Output(fmt.Sprintf("nil auth when creating root token with ID %s", base.DevToken)) + c.UI.Error(fmt.Sprintf("nil auth when creating root token with ID %s", base.DevToken)) return 1 } @@ -947,7 +1115,7 @@ func (c *ServerCommand) enableThreeNodeDevCluster(base *vault.CoreConfig, info m req.Data = nil resp, err = testCluster.Cores[0].HandleRequest(req) if err != nil { - c.Ui.Output(fmt.Sprintf("failed to revoke initial root token: %s", err)) + c.UI.Output(fmt.Sprintf("failed to revoke initial root token: %s", err)) return 1 } } @@ -955,37 +1123,37 @@ func (c *ServerCommand) enableThreeNodeDevCluster(base *vault.CoreConfig, info m // Set the token tokenHelper, err := c.TokenHelper() if err != nil { - c.Ui.Output(fmt.Sprintf("%v", err)) + c.UI.Error(fmt.Sprintf("Error getting token helper: %s", err)) return 1 } if err := tokenHelper.Store(testCluster.RootToken); err != nil { - c.Ui.Output(fmt.Sprintf("%v", err)) + c.UI.Error(fmt.Sprintf("Error storing in token helper: %s", err)) return 1 } if err := ioutil.WriteFile(filepath.Join(testCluster.TempDir, "root_token"), []byte(testCluster.RootToken), 0755); err != nil { - c.Ui.Output(fmt.Sprintf("%v", err)) + c.UI.Error(fmt.Sprintf("Error writing token to tempfile: %s", err)) return 1 } - c.Ui.Output(fmt.Sprintf( + c.UI.Output(fmt.Sprintf( "==> Three node dev mode is enabled\n\n" + "The unseal key and root token are reproduced below in case you\n" + "want to seal/unseal the Vault or play with authentication.\n", )) for i, key := range testCluster.BarrierKeys { - c.Ui.Output(fmt.Sprintf( + c.UI.Output(fmt.Sprintf( "Unseal Key %d: %s", i+1, base64.StdEncoding.EncodeToString(key), )) } - c.Ui.Output(fmt.Sprintf( + c.UI.Output(fmt.Sprintf( "\nRoot Token: %s\n", testCluster.RootToken, )) - c.Ui.Output(fmt.Sprintf( + c.UI.Output(fmt.Sprintf( "\nUseful env vars:\n"+ "VAULT_TOKEN=%s\n"+ "VAULT_ADDR=%s\n"+ @@ -996,7 +1164,13 @@ func (c *ServerCommand) enableThreeNodeDevCluster(base *vault.CoreConfig, info m )) // Output the header that the server has started - c.Ui.Output("==> Vault server started! Log data will stream in below:\n") + c.UI.Output("==> Vault server started! Log data will stream in below:\n") + + // Inform any tests that the server is ready + select { + case c.startedCh <- struct{}{}: + default: + } // Release the log gate. c.logGate.Flush() @@ -1007,7 +1181,7 @@ func (c *ServerCommand) enableThreeNodeDevCluster(base *vault.CoreConfig, info m for !shutdownTriggered { select { case <-c.ShutdownCh: - c.Ui.Output("==> Vault shutdown triggered") + c.UI.Output("==> Vault shutdown triggered") // Stop the listners so that we don't process further client requests. c.cleanupGuard.Do(testCluster.Cleanup) @@ -1017,17 +1191,17 @@ func (c *ServerCommand) enableThreeNodeDevCluster(base *vault.CoreConfig, info m // waited for). for _, core := range testCluster.Cores { if err := core.Shutdown(); err != nil { - c.Ui.Output(fmt.Sprintf("Error with core shutdown: %s", err)) + c.UI.Error(fmt.Sprintf("Error with core shutdown: %s", err)) } } shutdownTriggered = true case <-c.SighupCh: - c.Ui.Output("==> Vault reload triggered") + c.UI.Output("==> Vault reload triggered") for _, core := range testCluster.Cores { if err := c.Reload(core.ReloadFuncsLock, core.ReloadFuncs, nil); err != nil { - c.Ui.Output(fmt.Sprintf("Error(s) were encountered during reload: %s", err)) + c.UI.Error(fmt.Sprintf("Error(s) were encountered during reload: %s", err)) } } } @@ -1231,77 +1405,21 @@ func (c *ServerCommand) Reload(lock *sync.RWMutex, reloadFuncs *map[string][]rel for _, relFunc := range relFuncs { if relFunc != nil { if err := relFunc(nil); err != nil { - reloadErrors = multierror.Append(reloadErrors, fmt.Errorf("Error encountered reloading file audit backend at path %s: %v", strings.TrimPrefix(k, "audit_file|"), err)) + reloadErrors = multierror.Append(reloadErrors, fmt.Errorf("Error encountered reloading file audit device at path %s: %v", strings.TrimPrefix(k, "audit_file|"), err)) } } } } } - return reloadErrors.ErrorOrNil() -} - -func (c *ServerCommand) Synopsis() string { - return "Start a Vault server" -} - -func (c *ServerCommand) Help() string { - helpText := ` -Usage: vault server [options] - - Start a Vault server. - - This command starts a Vault server that responds to API requests. - Vault will start in a "sealed" state. The Vault must be unsealed - with "vault unseal" or the API before this server can respond to requests. - This must be done for every server. - - If the server is being started against a storage backend that is - brand new (no existing Vault data in it), it must be initialized with - "vault init" or the API first. - - -General Options: - - -config= Path to the configuration file or directory. This can - be specified multiple times. If it is a directory, - all files with a ".hcl" or ".json" suffix will be - loaded. - - -dev Enables Dev mode. In this mode, Vault is completely - in-memory and unsealed. Do not run the Dev server in - production! - - -dev-root-token-id="" If set, the root token returned in Dev mode will have - the given ID. This *only* has an effect when running - in Dev mode. Can also be specified with the - VAULT_DEV_ROOT_TOKEN_ID environment variable. - - -dev-listen-address="" If set, this overrides the normal Dev mode listen - address of "127.0.0.1:8200". Can also be specified - with the VAULT_DEV_LISTEN_ADDRESS environment - variable. - - -log-level=info Log verbosity. Defaults to "info", will be output to - stderr. Supported values: "trace", "debug", "info", - "warn", "err". Can also be specified with the - VAULT_LOG_LEVEL environment variable. -` - return strings.TrimSpace(helpText) -} - -func (c *ServerCommand) AutocompleteArgs() complete.Predictor { - return complete.PredictNothing -} - -func (c *ServerCommand) AutocompleteFlags() complete.Flags { - return complete.Flags{ - "-config": complete.PredictOr(complete.PredictFiles("*.hcl"), complete.PredictFiles("*.json")), - "-dev": complete.PredictNothing, - "-dev-root-token-id": complete.PredictNothing, - "-dev-listen-address": complete.PredictNothing, - "-log-level": complete.PredictSet("trace", "debug", "info", "warn", "err"), + // Send a message that we reloaded. This prevents "guessing" sleep times + // in tests. + select { + case c.reloadedCh <- struct{}{}: + default: } + + return reloadErrors.ErrorOrNil() } // storePidFile is used to write out our PID to a file if necessary @@ -1335,38 +1453,6 @@ func (c *ServerCommand) removePidFile(pidPath string) error { return os.Remove(pidPath) } -// MakeShutdownCh returns a channel that can be used for shutdown -// notifications for commands. This channel will send a message for every -// SIGINT or SIGTERM received. -func MakeShutdownCh() chan struct{} { - resultCh := make(chan struct{}) - - shutdownCh := make(chan os.Signal, 4) - signal.Notify(shutdownCh, os.Interrupt, syscall.SIGTERM) - go func() { - <-shutdownCh - close(resultCh) - }() - return resultCh -} - -// MakeSighupCh returns a channel that can be used for SIGHUP -// reloading. This channel will send a message for every -// SIGHUP received. -func MakeSighupCh() chan struct{} { - resultCh := make(chan struct{}) - - signalCh := make(chan os.Signal, 4) - signal.Notify(signalCh, syscall.SIGHUP) - go func() { - for { - <-signalCh - resultCh <- struct{}{} - } - }() - return resultCh -} - type grpclogFaker struct { logger log.Logger log bool diff --git a/command/server_ha_test.go b/command/server_ha_test.go deleted file mode 100644 index a9b1188126..0000000000 --- a/command/server_ha_test.go +++ /dev/null @@ -1,106 +0,0 @@ -// +build !race - -package command - -import ( - "io/ioutil" - "os" - "strings" - "testing" - - "github.com/hashicorp/vault/meta" - "github.com/hashicorp/vault/physical" - "github.com/mitchellh/cli" - - physConsul "github.com/hashicorp/vault/physical/consul" -) - -// The following tests have a go-metrics/exp manager race condition -func TestServer_CommonHA(t *testing.T) { - ui := new(cli.MockUi) - c := &ServerCommand{ - Meta: meta.Meta{ - Ui: ui, - }, - PhysicalBackends: map[string]physical.Factory{ - "consul": physConsul.NewConsulBackend, - }, - } - - tmpfile, err := ioutil.TempFile("", "") - if err != nil { - t.Fatalf("error creating temp dir: %v", err) - } - - tmpfile.WriteString(basehcl + consulhcl) - tmpfile.Close() - defer os.Remove(tmpfile.Name()) - - args := []string{"-config", tmpfile.Name(), "-verify-only", "true"} - - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s\n\n%s", code, ui.ErrorWriter.String(), ui.OutputWriter.String()) - } - - if !strings.Contains(ui.OutputWriter.String(), "(HA available)") { - t.Fatalf("did not find HA available: %s", ui.OutputWriter.String()) - } -} - -func TestServer_GoodSeparateHA(t *testing.T) { - ui := new(cli.MockUi) - c := &ServerCommand{ - Meta: meta.Meta{ - Ui: ui, - }, - PhysicalBackends: map[string]physical.Factory{ - "consul": physConsul.NewConsulBackend, - }, - } - - tmpfile, err := ioutil.TempFile("", "") - if err != nil { - t.Fatalf("error creating temp dir: %v", err) - } - - tmpfile.WriteString(basehcl + consulhcl + haconsulhcl) - tmpfile.Close() - defer os.Remove(tmpfile.Name()) - - args := []string{"-config", tmpfile.Name(), "-verify-only", "true"} - - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s\n\n%s", code, ui.ErrorWriter.String(), ui.OutputWriter.String()) - } - - if !strings.Contains(ui.OutputWriter.String(), "HA Storage:") { - t.Fatalf("did not find HA Storage: %s", ui.OutputWriter.String()) - } -} - -func TestServer_BadSeparateHA(t *testing.T) { - ui := new(cli.MockUi) - c := &ServerCommand{ - Meta: meta.Meta{ - Ui: ui, - }, - PhysicalBackends: map[string]physical.Factory{ - "consul": physConsul.NewConsulBackend, - }, - } - - tmpfile, err := ioutil.TempFile("", "") - if err != nil { - t.Fatalf("error creating temp dir: %v", err) - } - - tmpfile.WriteString(basehcl + consulhcl + badhaconsulhcl) - tmpfile.Close() - defer os.Remove(tmpfile.Name()) - - args := []string{"-config", tmpfile.Name()} - - if code := c.Run(args); code == 0 { - t.Fatalf("bad: should have gotten an error on a bad HA config") - } -} diff --git a/command/server_test.go b/command/server_test.go index 9a90239011..c15cf0596f 100644 --- a/command/server_test.go +++ b/command/server_test.go @@ -1,4 +1,5 @@ // +build !race +// The server tests have a go-metrics/exp manager race condition :(. package command @@ -7,72 +8,112 @@ import ( "crypto/x509" "fmt" "io/ioutil" - "math/rand" + "net" "os" "strings" "sync" "testing" "time" - "github.com/hashicorp/vault/meta" "github.com/hashicorp/vault/physical" "github.com/mitchellh/cli" + physConsul "github.com/hashicorp/vault/physical/consul" physFile "github.com/hashicorp/vault/physical/file" ) -var ( - basehcl = ` -disable_mlock = true +func testRandomPort(tb testing.TB) int { + tb.Helper() -listener "tcp" { - address = "127.0.0.1:8200" - tls_disable = "true" + addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:0") + if err != nil { + tb.Fatal(err) + } + + l, err := net.ListenTCP("tcp", addr) + if err != nil { + tb.Fatal(err) + } + defer l.Close() + + return l.Addr().(*net.TCPAddr).Port } -` - consulhcl = ` +func testBaseHCL(tb testing.TB) string { + tb.Helper() + + return strings.TrimSpace(fmt.Sprintf(` + disable_mlock = true + listener "tcp" { + address = "127.0.0.1:%d" + tls_disable = "true" + } + `, testRandomPort(tb))) +} + +const ( + consulHCL = ` backend "consul" { - prefix = "foo/" - advertise_addr = "http://127.0.0.1:8200" - disable_registration = "true" + prefix = "foo/" + advertise_addr = "http://127.0.0.1:8200" + disable_registration = "true" } ` - haconsulhcl = ` + haConsulHCL = ` ha_backend "consul" { - prefix = "bar/" - redirect_addr = "http://127.0.0.1:8200" - disable_registration = "true" + prefix = "bar/" + redirect_addr = "http://127.0.0.1:8200" + disable_registration = "true" } ` - badhaconsulhcl = ` + badHAConsulHCL = ` ha_backend "file" { - path = "/dev/null" + path = "/dev/null" } ` - reloadhcl = ` + reloadHCL = ` backend "file" { - path = "/dev/null" + path = "/dev/null" } - disable_mlock = true - listener "tcp" { - address = "127.0.0.1:8203" - tls_cert_file = "TMPDIR/reload_cert.pem" - tls_key_file = "TMPDIR/reload_key.pem" + address = "127.0.0.1:8203" + tls_cert_file = "TMPDIR/reload_cert.pem" + tls_key_file = "TMPDIR/reload_key.pem" } ` ) -// The following tests have a go-metrics/exp manager race condition +func testServerCommand(tb testing.TB) (*cli.MockUi, *ServerCommand) { + tb.Helper() + + ui := cli.NewMockUi() + return ui, &ServerCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + ShutdownCh: MakeShutdownCh(), + SighupCh: MakeSighupCh(), + PhysicalBackends: map[string]physical.Factory{ + "file": physFile.NewFileBackend, + "consul": physConsul.NewConsulBackend, + }, + + // These prevent us from random sleep guessing... + startedCh: make(chan struct{}, 5), + reloadedCh: make(chan struct{}, 5), + } +} + func TestServer_ReloadListener(t *testing.T) { + t.Parallel() + wd, _ := os.Getwd() wd += "/server/test-fixtures/reload/" - td, err := ioutil.TempDir("", fmt.Sprintf("vault-test-%d", rand.New(rand.NewSource(time.Now().Unix())).Int63)) + td, err := ioutil.TempDir("", "vault-test-") if err != nil { t.Fatal(err) } @@ -86,7 +127,7 @@ func TestServer_ReloadListener(t *testing.T) { inBytes, _ = ioutil.ReadFile(wd + "reload_foo.key") ioutil.WriteFile(td+"/reload_key.pem", inBytes, 0777) - relhcl := strings.Replace(reloadhcl, "TMPDIR", td, -1) + relhcl := strings.Replace(reloadHCL, "TMPDIR", td, -1) ioutil.WriteFile(td+"/reload.hcl", []byte(relhcl), 0777) inBytes, _ = ioutil.ReadFile(wd + "reload_ca.pem") @@ -96,17 +137,8 @@ func TestServer_ReloadListener(t *testing.T) { t.Fatal("not ok when appending CA cert") } - ui := new(cli.MockUi) - c := &ServerCommand{ - Meta: meta.Meta{ - Ui: ui, - }, - ShutdownCh: MakeShutdownCh(), - SighupCh: MakeSighupCh(), - PhysicalBackends: map[string]physical.Factory{ - "file": physFile.NewFileBackend, - }, - } + ui, cmd := testServerCommand(t) + _ = ui finished := false finishedMutex := sync.Mutex{} @@ -114,7 +146,7 @@ func TestServer_ReloadListener(t *testing.T) { wg.Add(1) args := []string{"-config", td + "/reload.hcl"} go func() { - if code := c.Run(args); code != 0 { + if code := cmd.Run(args); code != 0 { t.Error("got a non-zero exit status") } finishedMutex.Lock() @@ -123,14 +155,6 @@ func TestServer_ReloadListener(t *testing.T) { wg.Done() }() - checkFinished := func() { - finishedMutex.Lock() - if finished { - t.Fatalf(fmt.Sprintf("finished early; relhcl was\n%s\nstdout was\n%s\nstderr was\n%s\n", relhcl, ui.OutputWriter.String(), ui.ErrorWriter.String())) - } - finishedMutex.Unlock() - } - testCertificateName := func(cn string) error { conn, err := tls.Dial("tcp", "127.0.0.1:8203", &tls.Config{ RootCAs: certPool, @@ -149,31 +173,95 @@ func TestServer_ReloadListener(t *testing.T) { return nil } - checkFinished() - time.Sleep(5 * time.Second) - checkFinished() + select { + case <-cmd.startedCh: + case <-time.After(5 * time.Second): + t.Fatalf("timeout") + } if err := testCertificateName("foo.example.com"); err != nil { t.Fatalf("certificate name didn't check out: %s", err) } - relhcl = strings.Replace(reloadhcl, "TMPDIR", td, -1) + relhcl = strings.Replace(reloadHCL, "TMPDIR", td, -1) inBytes, _ = ioutil.ReadFile(wd + "reload_bar.pem") ioutil.WriteFile(td+"/reload_cert.pem", inBytes, 0777) inBytes, _ = ioutil.ReadFile(wd + "reload_bar.key") ioutil.WriteFile(td+"/reload_key.pem", inBytes, 0777) ioutil.WriteFile(td+"/reload.hcl", []byte(relhcl), 0777) - c.SighupCh <- struct{}{} - checkFinished() - time.Sleep(2 * time.Second) - checkFinished() + cmd.SighupCh <- struct{}{} + select { + case <-cmd.reloadedCh: + case <-time.After(5 * time.Second): + t.Fatalf("timeout") + } if err := testCertificateName("bar.example.com"); err != nil { t.Fatalf("certificate name didn't check out: %s", err) } - c.ShutdownCh <- struct{}{} + cmd.ShutdownCh <- struct{}{} wg.Wait() } + +func TestServer(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + contents string + exp string + code int + }{ + { + "common_ha", + testBaseHCL(t) + consulHCL, + "(HA available)", + 0, + }, + { + "separate_ha", + testBaseHCL(t) + consulHCL + haConsulHCL, + "HA Storage:", + 0, + }, + { + "bad_separate_ha", + testBaseHCL(t) + consulHCL + badHAConsulHCL, + "Specified HA storage does not support HA", + 1, + }, + } + + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + ui, cmd := testServerCommand(t) + f, err := ioutil.TempFile("", "") + if err != nil { + t.Fatalf("error creating temp dir: %v", err) + } + f.WriteString(tc.contents) + f.Close() + defer os.Remove(f.Name()) + + code := cmd.Run([]string{ + "-config", f.Name(), + "-test-verify-only", + }) + output := ui.ErrorWriter.String() + ui.OutputWriter.String() + if code != tc.code { + t.Errorf("expected %d to be %d: %s", code, tc.code, output) + } + + if !strings.Contains(output, tc.exp) { + t.Fatalf("expected %q to contain %q", output, tc.exp) + } + }) + } +} diff --git a/command/ssh.go b/command/ssh.go index 03e1933da6..99939c0e75 100644 --- a/command/ssh.go +++ b/command/ssh.go @@ -9,46 +9,201 @@ import ( "os/exec" "os/user" "strings" + "syscall" "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/builtin/logical/ssh" - "github.com/hashicorp/vault/meta" - homedir "github.com/mitchellh/go-homedir" + "github.com/mitchellh/cli" "github.com/mitchellh/mapstructure" "github.com/pkg/errors" + "github.com/posener/complete" ) -// SSHCommand is a Command that establishes a SSH connection with target by -// generating a dynamic key +var _ cli.Command = (*SSHCommand)(nil) +var _ cli.CommandAutocomplete = (*SSHCommand)(nil) + type SSHCommand struct { - meta.Meta + *BaseCommand - // API - client *api.Client - sshClient *api.SSH + // Common SSH options + flagMode string + flagRole string + flagNoExec bool + flagMountPoint string + flagStrictHostKeyChecking string + flagUserKnownHostsFile string - // Common options - mode string - noExec bool - format string - mountPoint string - role string - username string - ip string - sshArgs []string - - // Key options - strictHostKeyChecking string - userKnownHostsFile string - - // SSH CA backend specific options - publicKeyPath string - privateKeyPath string - hostKeyMountPoint string - hostKeyHostnames string + // SSH CA Mode options + flagPublicKeyPath string + flagPrivateKeyPath string + flagHostKeyMountPoint string + flagHostKeyHostnames string } -// Structure to hold the fields returned when asked for a credential from SSHh backend. +func (c *SSHCommand) Synopsis() string { + return "Initiate an SSH session" +} + +func (c *SSHCommand) Help() string { + helpText := ` +Usage: vault ssh [options] username@ip [ssh options] + + Establishes an SSH connection with the target machine. + + This command uses one of the SSH secrets engines to authenticate and + automatically establish an SSH connection to a host. This operation requires + that the SSH secrets engine is mounted and configured. + + SSH using the OTP mode (requires sshpass for full automation): + + $ vault ssh -mode=otp -role=my-role user@1.2.3.4 + + SSH using the CA mode: + + $ vault ssh -mode=ca -role=my-role user@1.2.3.4 + + SSH using CA mode with host key verification: + + $ vault ssh \ + -mode=ca \ + -role=my-role \ + -host-key-mount-point=host-signer \ + -host-key-hostnames=example.com \ + user@example.com + + For the full list of options and arguments, please see the documentation. + +` + c.Flags().Help() + + return strings.TrimSpace(helpText) +} + +func (c *SSHCommand) Flags() *FlagSets { + set := c.flagSet(FlagSetHTTP | FlagSetOutputField | FlagSetOutputFormat) + + f := set.NewFlagSet("SSH Options") + + // TODO: doc field? + + // General + f.StringVar(&StringVar{ + Name: "mode", + Target: &c.flagMode, + Default: "", + EnvVar: "", + Completion: complete.PredictSet("ca", "dynamic", "otp"), + Usage: "Name of the role to use to generate the key.", + }) + + f.StringVar(&StringVar{ + Name: "role", + Target: &c.flagRole, + Default: "", + EnvVar: "", + Completion: complete.PredictAnything, + Usage: "Name of the role to use to generate the key.", + }) + + f.BoolVar(&BoolVar{ + Name: "no-exec", + Target: &c.flagNoExec, + Default: false, + EnvVar: "", + Completion: complete.PredictNothing, + Usage: "Print the generated credentials, but do not establish a " + + "connection.", + }) + + f.StringVar(&StringVar{ + Name: "mount-point", + Target: &c.flagMountPoint, + Default: "ssh/", + EnvVar: "", + Completion: complete.PredictAnything, + Usage: "Mount point to the SSH secrets engine.", + }) + + f.StringVar(&StringVar{ + Name: "strict-host-key-checking", + Target: &c.flagStrictHostKeyChecking, + Default: "ask", + EnvVar: "VAULT_SSH_STRICT_HOST_KEY_CHECKING", + Completion: complete.PredictSet("ask", "no", "yes"), + Usage: "Value to use for the SSH configuration option " + + "\"StrictHostKeyChecking\".", + }) + + f.StringVar(&StringVar{ + Name: "user-known-hosts-file", + Target: &c.flagUserKnownHostsFile, + Default: "~/.ssh/known_hosts", + EnvVar: "VAULT_SSH_USER_KNOWN_HOSTS_FILE", + Completion: complete.PredictFiles("*"), + Usage: "Value to use for the SSH configuration option " + + "\"UserKnownHostsFile\".", + }) + + // SSH CA + f = set.NewFlagSet("CA Mode Options") + + f.StringVar(&StringVar{ + Name: "public-key-path", + Target: &c.flagPublicKeyPath, + Default: "~/.ssh/id_rsa.pub", + EnvVar: "", + Completion: complete.PredictFiles("*"), + Usage: "Path to the SSH public key to send to Vault for signing.", + }) + + f.StringVar(&StringVar{ + Name: "private-key-path", + Target: &c.flagPrivateKeyPath, + Default: "~/.ssh/id_rsa", + EnvVar: "", + Completion: complete.PredictFiles("*"), + Usage: "Path to the SSH private key to use for authentication. This must " + + "be the corresponding private key to -public-key-path.", + }) + + f.StringVar(&StringVar{ + Name: "host-key-mount-point", + Target: &c.flagHostKeyMountPoint, + Default: "", + EnvVar: "VAULT_SSH_HOST_KEY_MOUNT_POINT", + Completion: complete.PredictAnything, + Usage: "Mount point to the SSH secrets engine where host keys are signed. " + + "When given a value, Vault will generate a custom \"known_hosts\" file " + + "with delegation to the CA at the provided mount point to verify the " + + "SSH connection's host keys against the provided CA. By default, host " + + "keys are validated against the user's local \"known_hosts\" file. " + + "This flag forces strict key host checking and ignores a custom user " + + "known hosts file.", + }) + + f.StringVar(&StringVar{ + Name: "host-key-hostnames", + Target: &c.flagHostKeyHostnames, + Default: "*", + EnvVar: "VAULT_SSH_HOST_KEY_HOSTNAMES", + Completion: complete.PredictAnything, + Usage: "List of hostnames to delegate for the CA. The default value " + + "allows all domains and IPs. This is specified as a comma-separated " + + "list of values.", + }) + + return set +} + +func (c *SSHCommand) AutocompleteArgs() complete.Predictor { + return nil +} + +func (c *SSHCommand) AutocompleteFlags() complete.Flags { + return c.Flags().Completions() +} + +// Structure to hold the fields returned when asked for a credential from SSH +// secrets engine. type SSHCredentialResp struct { KeyType string `mapstructure:"key_type"` Key string `mapstructure:"key"` @@ -58,74 +213,35 @@ type SSHCredentialResp struct { } func (c *SSHCommand) Run(args []string) int { + f := c.Flags() - flags := c.Meta.FlagSet("ssh", meta.FlagSetDefault) - - envOrDefault := func(key string, def string) string { - if k := os.Getenv(key); k != "" { - return k - } - return def - } - - expandPath := func(p string) string { - e, err := homedir.Expand(p) - if err != nil { - return p - } - return e - } - - // Common options - flags.StringVar(&c.mode, "mode", "", "") - flags.BoolVar(&c.noExec, "no-exec", false, "") - flags.StringVar(&c.format, "format", "table", "") - flags.StringVar(&c.mountPoint, "mount-point", "ssh", "") - flags.StringVar(&c.role, "role", "", "") - - // Key options - flags.StringVar(&c.strictHostKeyChecking, "strict-host-key-checking", - envOrDefault("VAULT_SSH_STRICT_HOST_KEY_CHECKING", "ask"), "") - flags.StringVar(&c.userKnownHostsFile, "user-known-hosts-file", - envOrDefault("VAULT_SSH_USER_KNOWN_HOSTS_FILE", expandPath("~/.ssh/known_hosts")), "") - - // CA-specific options - flags.StringVar(&c.publicKeyPath, "public-key-path", - expandPath("~/.ssh/id_rsa.pub"), "") - flags.StringVar(&c.privateKeyPath, "private-key-path", - expandPath("~/.ssh/id_rsa"), "") - flags.StringVar(&c.hostKeyMountPoint, "host-key-mount-point", "", "") - flags.StringVar(&c.hostKeyHostnames, "host-key-hostnames", "*", "") - - flags.Usage = func() { c.Ui.Error(c.Help()) } - if err := flags.Parse(args); err != nil { + if err := f.Parse(args); err != nil { + c.UI.Error(err.Error()) return 1 } - args = flags.Args() + // Use homedir to expand any relative paths such as ~/.ssh + c.flagUserKnownHostsFile = expandPath(c.flagUserKnownHostsFile) + c.flagPublicKeyPath = expandPath(c.flagPublicKeyPath) + c.flagPrivateKeyPath = expandPath(c.flagPrivateKeyPath) + + args = f.Args() if len(args) < 1 { - c.Ui.Error("ssh expects at least one argument") + c.UI.Error(fmt.Sprintf("Not enough arguments, (expected 1-n, got %d)", len(args))) return 1 } - client, err := c.Client() - if err != nil { - c.Ui.Error(fmt.Sprintf("Error initializing client: %v", err)) - return 1 - } - c.client = client - c.sshClient = client.SSHWithMountPoint(c.mountPoint) - // Extract the username and IP. - c.username, c.ip, err = c.userAndIP(args[0]) + username, ip, err := c.userAndIP(args[0]) if err != nil { - c.Ui.Error(fmt.Sprintf("Error parsing user and IP: %s", err)) + c.UI.Error(fmt.Sprintf("Error parsing user and IP: %s", err)) return 1 } // The rest of the args are ssh args + sshArgs := []string{} if len(args) > 1 { - c.sshArgs = args[1:] + sshArgs = args[1:] } // Credentials are generated only against a registered role. If user @@ -134,100 +250,101 @@ func (c *SSHCommand) Run(args []string) int { // only one role associated with it, use it to establish the connection. // // TODO: remove in 0.9.0, convert to validation error - if c.role == "" { - c.Ui.Warn("" + - "WARNING: No -role specified. Use -role to tell Vault which ssh role\n" + - "to use for authentication. In the future, you will need to tell Vault\n" + - "which role to use. For now, Vault will attempt to guess based on a\n" + - "the API response.") + if c.flagRole == "" { + c.UI.Warn(wrapAtLength( + "WARNING: No -role specified. Use -role to tell Vault which ssh role " + + "to use for authentication. In the future, you will need to tell " + + "Vault which role to use. For now, Vault will attempt to guess based " + + "on a the API response. This will be removed in the next major " + + "version of Vault.")) - role, err := c.defaultRole(c.mountPoint, c.ip) + role, err := c.defaultRole(c.flagMountPoint, ip) if err != nil { - c.Ui.Error(fmt.Sprintf("Error choosing role: %v", err)) + c.UI.Error(fmt.Sprintf("Error choosing role: %v", err)) return 1 } // Print the default role chosen so that user knows the role name // if something doesn't work. If the role chosen is not allowed to // be used by the user (ACL enforcement), then user should see an // error message accordingly. - c.Ui.Output(fmt.Sprintf("Vault SSH: Role: %q", role)) - c.role = role + c.UI.Output(fmt.Sprintf("Vault SSH: Role: %q", role)) + c.flagRole = role } // If no mode was given, perform the old-school lookup. Keep this now for // backwards-compatability, but print a warning. // // TODO: remove in 0.9.0, convert to validation error - if c.mode == "" { - c.Ui.Warn("" + - "WARNING: No -mode specified. Use -mode to tell Vault which ssh\n" + - "authentication mode to use. In the future, you will need to tell\n" + - "Vault which mode to use. For now, Vault will attempt to guess based\n" + - "on the API response. This guess involves creating a temporary\n" + - "credential, reading its type, and then revoking it. To reduce the\n" + - "number of API calls and surface area, specify -mode directly.") - secret, cred, err := c.generateCredential() + if c.flagMode == "" { + c.UI.Warn(wrapAtLength( + "WARNING: No -mode specified. Use -mode to tell Vault which ssh " + + "authentication mode to use. In the future, you will need to tell " + + "Vault which mode to use. For now, Vault will attempt to guess based " + + "on the API response. This guess involves creating a temporary " + + "credential, reading its type, and then revoking it. To reduce the " + + "number of API calls and surface area, specify -mode directly. This " + + "will be removed in the next major version of Vault.")) + secret, cred, err := c.generateCredential(username, ip) if err != nil { // This is _very_ hacky, but is the only sane backwards-compatible way // to do this. If the error is "key type unknown", we just assume the // type is "ca". In the future, mode will be required as an option. if strings.Contains(err.Error(), "key type unknown") { - c.mode = ssh.KeyTypeCA + c.flagMode = ssh.KeyTypeCA } else { - c.Ui.Error(fmt.Sprintf("Error getting credential: %s", err)) + c.UI.Error(fmt.Sprintf("Error getting credential: %s", err)) return 1 } } else { - c.mode = cred.KeyType + c.flagMode = cred.KeyType } // Revoke the secret, since the child functions will generate their own // credential. Users wishing to avoid this should specify -mode. if secret != nil { if err := c.client.Sys().Revoke(secret.LeaseID); err != nil { - c.Ui.Warn(fmt.Sprintf("Failed to revoke temporary key: %s", err)) + c.UI.Warn(fmt.Sprintf("Failed to revoke temporary key: %s", err)) } } } - switch strings.ToLower(c.mode) { + switch strings.ToLower(c.flagMode) { case ssh.KeyTypeCA: - if err := c.handleTypeCA(); err != nil { - c.Ui.Error(err.Error()) - return 1 - } + return c.handleTypeCA(username, ip, sshArgs) case ssh.KeyTypeOTP: - if err := c.handleTypeOTP(); err != nil { - c.Ui.Error(err.Error()) - return 1 - } + return c.handleTypeOTP(username, ip, sshArgs) case ssh.KeyTypeDynamic: - if err := c.handleTypeDynamic(); err != nil { - c.Ui.Error(err.Error()) - return 1 - } + return c.handleTypeDynamic(username, ip, sshArgs) default: - c.Ui.Error(fmt.Sprintf("Unknown SSH mode: %s", c.mode)) + c.UI.Error(fmt.Sprintf("Unknown SSH mode: %s", c.flagMode)) return 1 } - - return 0 } // handleTypeCA is used to handle SSH logins using the "CA" key type. -func (c *SSHCommand) handleTypeCA() error { +func (c *SSHCommand) handleTypeCA(username, ip string, sshArgs []string) int { // Read the key from disk - publicKey, err := ioutil.ReadFile(c.publicKeyPath) + publicKey, err := ioutil.ReadFile(c.flagPublicKeyPath) if err != nil { - return errors.Wrap(err, "failed to read public key") + c.UI.Error(fmt.Sprintf("failed to read public key %s: %s", + c.flagPublicKeyPath, err)) + return 1 } + client, err := c.Client() + if err != nil { + c.UI.Error(err.Error()) + return 1 + } + + sshClient := client.SSHWithMountPoint(c.flagMountPoint) + // Attempt to sign the public key - secret, err := c.sshClient.SignKey(c.role, map[string]interface{}{ + secret, err := sshClient.SignKey(c.flagRole, map[string]interface{}{ // WARNING: publicKey is []byte, which is b64 encoded on JSON upload. We // have to convert it to a string. SV lost many hours to this... "public_key": string(publicKey), - "valid_principals": c.username, + "valid_principals": username, "cert_type": "user", // TODO: let the user configure these. In the interim, if users want to @@ -241,55 +358,62 @@ func (c *SSHCommand) handleTypeCA() error { }, }) if err != nil { - return errors.Wrap(err, "failed to sign public key") + c.UI.Error(fmt.Sprintf("failed to sign public key %s: %s", + c.flagPublicKeyPath, err)) + return 2 } if secret == nil || secret.Data == nil { - return fmt.Errorf("client signing returned empty credentials") + c.UI.Error("missing signed key") + return 2 } // Handle no-exec - if c.noExec { - // This is hacky, but OutputSecret returns an int, not an error :( - if i := OutputSecret(c.Ui, c.format, secret); i != 0 { - return fmt.Errorf("an error occurred outputting the secret") + if c.flagNoExec { + if c.flagFormat != "" { + return PrintRawField(c.UI, secret, c.flagField) } - return nil + return OutputSecret(c.UI, c.flagFormat, secret) } // Extract public key key, ok := secret.Data["signed_key"].(string) - if !ok { - return fmt.Errorf("missing signed key") + if !ok || key == "" { + c.UI.Error("signed key is empty") + return 2 } // Capture the current value - this could be overwritten later if the user // enabled host key signing verification. - userKnownHostsFile := c.userKnownHostsFile - strictHostKeyChecking := c.strictHostKeyChecking + userKnownHostsFile := c.flagUserKnownHostsFile + strictHostKeyChecking := c.flagStrictHostKeyChecking // Handle host key signing verification. If the user specified a mount point, // download the public key, trust it with the given domains, and use that // instead of the user's regular known_hosts file. - if c.hostKeyMountPoint != "" { - secret, err := c.client.Logical().Read(c.hostKeyMountPoint + "/config/ca") + if c.flagHostKeyMountPoint != "" { + secret, err := c.client.Logical().Read(c.flagHostKeyMountPoint + "/config/ca") if err != nil { - return errors.Wrap(err, "failed to get host signing key") + c.UI.Error(fmt.Sprintf("failed to get host signing key: %s", err)) + return 2 } if secret == nil || secret.Data == nil { - return fmt.Errorf("missing host signing key") + c.UI.Error("missing host signing key") + return 2 } publicKey, ok := secret.Data["public_key"].(string) - if !ok { - return fmt.Errorf("host signing key is empty") + if !ok || publicKey == "" { + c.UI.Error("host signing key is empty") + return 2 } // Write the known_hosts file - name := fmt.Sprintf("vault_ssh_ca_known_hosts_%s_%s", c.username, c.ip) - data := fmt.Sprintf("@cert-authority %s %s", c.hostKeyHostnames, publicKey) + name := fmt.Sprintf("vault_ssh_ca_known_hosts_%s_%s", username, ip) + data := fmt.Sprintf("@cert-authority %s %s", c.flagHostKeyHostnames, publicKey) knownHosts, err, closer := c.writeTemporaryFile(name, []byte(data), 0644) defer closer() if err != nil { - return errors.Wrap(err, "failed to write host public key") + c.UI.Error(fmt.Sprintf("failed to write host public key: %s", err)) + return 1 } // Update the variables @@ -298,20 +422,21 @@ func (c *SSHCommand) handleTypeCA() error { } // Write the signed public key to disk - name := fmt.Sprintf("vault_ssh_ca_%s_%s", c.username, c.ip) + name := fmt.Sprintf("vault_ssh_ca_%s_%s", username, ip) signedPublicKeyPath, err, closer := c.writeTemporaryKey(name, []byte(key)) defer closer() if err != nil { - return errors.Wrap(err, "failed to write signed public key") + c.UI.Error(fmt.Sprintf("failed to write signed public key: %s", err)) + return 2 } args := append([]string{ - "-i", c.privateKeyPath, + "-i", c.flagPrivateKeyPath, "-i", signedPublicKeyPath, "-o UserKnownHostsFile=" + userKnownHostsFile, "-o StrictHostKeyChecking=" + strictHostKeyChecking, - c.username + "@" + c.ip, - }, c.sshArgs...) + username + "@" + ip, + }, sshArgs...) cmd := exec.Command("ssh", args...) cmd.Stdin = os.Stdin @@ -319,61 +444,71 @@ func (c *SSHCommand) handleTypeCA() error { cmd.Stderr = os.Stderr err = cmd.Run() if err != nil { - return errors.Wrap(err, "failed to run ssh command") + exitCode := 2 + + if exitError, ok := err.(*exec.ExitError); ok { + if exitError.Success() { + return 0 + } + if ws, ok := exitError.Sys().(syscall.WaitStatus); ok { + exitCode = ws.ExitStatus() + } + } + + c.UI.Error(fmt.Sprintf("failed to run ssh command: %s", err)) + return exitCode } // There is no secret to revoke, since it's a certificate signing - - return nil + return 0 } // handleTypeOTP is used to handle SSH logins using the "otp" key type. -func (c *SSHCommand) handleTypeOTP() error { - secret, cred, err := c.generateCredential() +func (c *SSHCommand) handleTypeOTP(username, ip string, sshArgs []string) int { + secret, cred, err := c.generateCredential(username, ip) if err != nil { - return errors.Wrap(err, "failed to generate credential") + c.UI.Error(fmt.Sprintf("failed to generate credential: %s", err)) + return 2 } // Handle no-exec - if c.noExec { - // This is hacky, but OutputSecret returns an int, not an error :( - if i := OutputSecret(c.Ui, c.format, secret); i != 0 { - return fmt.Errorf("an error occurred outputting the secret") + if c.flagNoExec { + if c.flagFormat != "" { + return PrintRawField(c.UI, secret, c.flagField) } - return nil + return OutputSecret(c.UI, c.flagFormat, secret) } var cmd *exec.Cmd - // Check if the application 'sshpass' is installed in the client machine. - // If it is then, use it to automate typing in OTP to the prompt. Unfortunately, + // Check if the application 'sshpass' is installed in the client machine. If + // it is then, use it to automate typing in OTP to the prompt. Unfortunately, // it was not possible to automate it without a third-party application, with - // only the Go libraries. - // Feel free to try and remove this dependency. + // only the Go libraries. Feel free to try and remove this dependency. sshpassPath, err := exec.LookPath("sshpass") if err != nil { - c.Ui.Warn("" + - "Vault could not locate sshpass. The OTP code for the session will be\n" + - "displayed below. Enter this code in the SSH password prompt. If you\n" + - "install sshpass, Vault can automatically perform this step for you.") - c.Ui.Output("OTP for the session is " + cred.Key) + c.UI.Warn(wrapAtLength( + "Vault could not locate \"sshpass\". The OTP code for the session is " + + "displayed below. Enter this code in the SSH password prompt. If you " + + "install sshpass, Vault can automatically perform this step for you.")) + c.UI.Output("OTP for the session is: " + cred.Key) args := append([]string{ - "-o UserKnownHostsFile=" + c.userKnownHostsFile, - "-o StrictHostKeyChecking=" + c.strictHostKeyChecking, + "-o UserKnownHostsFile=" + c.flagUserKnownHostsFile, + "-o StrictHostKeyChecking=" + c.flagStrictHostKeyChecking, "-p", cred.Port, - c.username + "@" + c.ip, - }, c.sshArgs...) + username + "@" + ip, + }, sshArgs...) cmd = exec.Command("ssh", args...) } else { args := append([]string{ "-e", // Read password for SSHPASS environment variable "ssh", - "-o UserKnownHostsFile=" + c.userKnownHostsFile, - "-o StrictHostKeyChecking=" + c.strictHostKeyChecking, + "-o UserKnownHostsFile=" + c.flagUserKnownHostsFile, + "-o StrictHostKeyChecking=" + c.flagStrictHostKeyChecking, "-p", cred.Port, - c.username + "@" + c.ip, - }, c.sshArgs...) + username + "@" + ip, + }, sshArgs...) cmd = exec.Command(sshpassPath, args...) env := os.Environ() env = append(env, fmt.Sprintf("SSHPASS=%s", string(cred.Key))) @@ -385,49 +520,63 @@ func (c *SSHCommand) handleTypeOTP() error { cmd.Stderr = os.Stderr err = cmd.Run() if err != nil { - return errors.Wrap(err, "failed to run ssh command") + exitCode := 2 + + if exitError, ok := err.(*exec.ExitError); ok { + if exitError.Success() { + return 0 + } + if ws, ok := exitError.Sys().(syscall.WaitStatus); ok { + exitCode = ws.ExitStatus() + } + } + + c.UI.Error(fmt.Sprintf("failed to run ssh command: %s", err)) + return exitCode } // Revoke the key if it's longer than expected if err := c.client.Sys().Revoke(secret.LeaseID); err != nil { - return errors.Wrap(err, "failed to revoke key") + c.UI.Error(fmt.Sprintf("failed to revoke key: %s", err)) + return 2 } - return nil + return 0 } // handleTypeDynamic is used to handle SSH logins using the "dyanmic" key type. -func (c *SSHCommand) handleTypeDynamic() error { +func (c *SSHCommand) handleTypeDynamic(username, ip string, sshArgs []string) int { // Generate the credential - secret, cred, err := c.generateCredential() + secret, cred, err := c.generateCredential(username, ip) if err != nil { - return errors.Wrap(err, "failed to generate credential") + c.UI.Error(fmt.Sprintf("failed to generate credential: %s", err)) + return 2 } // Handle no-exec - if c.noExec { - // This is hacky, but OutputSecret returns an int, not an error :( - if i := OutputSecret(c.Ui, c.format, secret); i != 0 { - return fmt.Errorf("an error occurred outputting the secret") + if c.flagNoExec { + if c.flagFormat != "" { + return PrintRawField(c.UI, secret, c.flagField) } - return nil + return OutputSecret(c.UI, c.flagFormat, secret) } // Write the dynamic key to disk - name := fmt.Sprintf("vault_ssh_dynamic_%s_%s", c.username, c.ip) + name := fmt.Sprintf("vault_ssh_dynamic_%s_%s", username, ip) keyPath, err, closer := c.writeTemporaryKey(name, []byte(cred.Key)) defer closer() if err != nil { - return errors.Wrap(err, "failed to save dyanmic key") + c.UI.Error(fmt.Sprintf("failed to write dynamic key: %s", err)) + return 1 } args := append([]string{ "-i", keyPath, - "-o UserKnownHostsFile=" + c.userKnownHostsFile, - "-o StrictHostKeyChecking=" + c.strictHostKeyChecking, + "-o UserKnownHostsFile=" + c.flagUserKnownHostsFile, + "-o StrictHostKeyChecking=" + c.flagStrictHostKeyChecking, "-p", cred.Port, - c.username + "@" + c.ip, - }, c.sshArgs...) + username + "@" + ip, + }, sshArgs...) cmd := exec.Command("ssh", args...) cmd.Stdin = os.Stdin @@ -435,24 +584,44 @@ func (c *SSHCommand) handleTypeDynamic() error { cmd.Stderr = os.Stderr err = cmd.Run() if err != nil { - return errors.Wrap(err, "failed to run ssh command") + exitCode := 2 + + if exitError, ok := err.(*exec.ExitError); ok { + if exitError.Success() { + return 0 + } + if ws, ok := exitError.Sys().(syscall.WaitStatus); ok { + exitCode = ws.ExitStatus() + } + } + + c.UI.Error(fmt.Sprintf("failed to run ssh command: %s", err)) + return exitCode } // Revoke the key if it's longer than expected if err := c.client.Sys().Revoke(secret.LeaseID); err != nil { - return errors.Wrap(err, "failed to revoke key") + c.UI.Error(fmt.Sprintf("failed to revoke key: %s", err)) + return 2 } - return nil + return 0 } // generateCredential generates a credential for the given role and returns the // decoded secret data. -func (c *SSHCommand) generateCredential() (*api.Secret, *SSHCredentialResp, error) { +func (c *SSHCommand) generateCredential(username, ip string) (*api.Secret, *SSHCredentialResp, error) { + client, err := c.Client() + if err != nil { + return nil, nil, err + } + + sshClient := client.SSHWithMountPoint(c.flagMountPoint) + // Attempt to generate the credential. - secret, err := c.sshClient.Credential(c.role, map[string]interface{}{ - "username": c.username, - "ip": c.ip, + secret, err := sshClient.Credential(c.flagRole, map[string]interface{}{ + "username": username, + "ip": ip, }) if err != nil { return nil, nil, errors.Wrap(err, "failed to get credentials") @@ -540,9 +709,9 @@ func (c *SSHCommand) defaultRole(mountPoint, ip string) (string, error) { } roleNames = strings.TrimRight(roleNames, ", ") return "", fmt.Errorf("Roles:%q. "+` - Multiple roles are registered for this IP. - Select a role using '-role' option. - Note that all roles may not be permitted, based on ACLs.`, roleNames) + Multiple roles are registered for this IP. + Select a role using '-role' option. + Note that all roles may not be permitted, based on ACLs.`, roleNames) } } @@ -580,102 +749,3 @@ func (c *SSHCommand) userAndIP(s string) (string, string, error) { return username, ip, nil } - -func (c *SSHCommand) Synopsis() string { - return "Initiate an SSH session" -} - -func (c *SSHCommand) Help() string { - helpText := ` -Usage: vault ssh [options] username@ip [ssh options] - - Establishes an SSH connection with the target machine. - - This command uses one of the SSH authentication backends to authenticate and - automatically establish an SSH connection to a host. This operation requires - that the SSH backend is mounted and configured. - - SSH using the OTP mode (requires sshpass for full automation): - - $ vault ssh -mode=otp -role=my-role user@1.2.3.4 - - SSH using the CA mode: - - $ vault ssh -mode=ca -role=my-role user@1.2.3.4 - - SSH using CA mode with host key verification: - - $ vault ssh \ - -mode=ca \ - -role=my-role \ - -host-key-mount-point=host-signer \ - -host-key-hostnames=example.com \ - user@example.com - - For the full list of options and arguments, please see the documentation. - -General Options: -` + meta.GeneralOptionsUsage() + ` -SSH Options: - - -role Role to be used to create the key. Each IP is associated with - a role. To see the associated roles with IP, use "lookup" - endpoint. If you are certain that there is only one role - associated with the IP, you can skip mentioning the role. It - will be chosen by default. If there are no roles associated - with the IP, register the CIDR block of that IP using the - "roles/" endpoint. - - -no-exec Shows the credentials but does not establish connection. - - -mount-point Mount point of SSH backend. If the backend is mounted at - "ssh" (default), this parameter can be skipped. - - -format If the "no-exec" option is enabled, the credentials will be - printed out and SSH connection will not be established. The - format of the output can be "json" or "table" (default). - - -strict-host-key-checking This option corresponds to "StrictHostKeyChecking" - of SSH configuration. If "sshpass" is employed to enable - automated login, then if host key is not "known" to the - client, "vault ssh" command will fail. Set this option to - "no" to bypass the host key checking. Defaults to "ask". - Can also be specified with the - "VAULT_SSH_STRICT_HOST_KEY_CHECKING" environment variable. - - -user-known-hosts-file This option corresponds to "UserKnownHostsFile" of - SSH configuration. Assigns the file to use for storing the - host keys. If this option is set to "/dev/null" along with - "-strict-host-key-checking=no", both warnings and host key - checking can be avoided while establishing the connection. - Defaults to "~/.ssh/known_hosts". Can also be specified with - "VAULT_SSH_USER_KNOWN_HOSTS_FILE" environment variable. - -CA Mode Options: - - - public-key-path= - The path to the public key to send to Vault for signing. The default value - is ~/.ssh/id_rsa.pub. - - - private-key-path= - The path to the private key to use for authentication. This must be the - corresponding private key to -public-key-path. The default value is - ~/.ssh/id_rsa. - - - host-key-mount-point= - The mount point to the SSH backend where host keys are signed. When given - a value, Vault will generate a custom known_hosts file with delegation to - the CA at the provided mount point and verify the SSH connection's host - keys against the provided CA. By default, this command uses the users's - existing known_hosts file. When this flag is set, this command will force - strict host key checking and will override any values provided for a - custom -user-known-hosts-file. - - - host-key-hostnames= - The list of hostnames to delegate for this certificate authority. By - default, this is "*", which allows all domains and IPs. To restrict - validation to a series of hostnames, specify them as comma-separated - values here. -` - return strings.TrimSpace(helpText) -} diff --git a/command/ssh_test.go b/command/ssh_test.go index 70a58f5431..189ea2887f 100644 --- a/command/ssh_test.go +++ b/command/ssh_test.go @@ -1,199 +1,23 @@ package command import ( - "bytes" - "fmt" - "io" - "os" - "strings" "testing" - logicalssh "github.com/hashicorp/vault/builtin/logical/ssh" - "github.com/hashicorp/vault/http" - "github.com/hashicorp/vault/meta" - "github.com/hashicorp/vault/vault" "github.com/mitchellh/cli" ) -const ( - testCidr = "127.0.0.1/32" - testRoleName = "testRoleName" - testKey = "testKey" - testSharedPrivateKey = ` ------BEGIN RSA PRIVATE KEY----- -MIIEogIBAAKCAQEAvYvoRcWRxqOim5VZnuM6wHCbLUeiND0yaM1tvOl+Fsrz55DG -A0OZp4RGAu1Fgr46E1mzxFz1+zY4UbcEExg+u21fpa8YH8sytSWW1FyuD8ICib0A -/l8slmDMw4BkkGOtSlEqgscpkpv/TWZD1NxJWkPcULk8z6c7TOETn2/H9mL+v2RE -mbE6NDEwJKfD3MvlpIqCP7idR+86rNBAODjGOGgyUbtFLT+K01XmDRALkV3V/nh+ -GltyjL4c6RU4zG2iRyV5RHlJtkml+UzUMkzr4IQnkCC32CC/wmtoo/IsAprpcHVe -nkBn3eFQ7uND70p5n6GhN/KOh2j519JFHJyokwIDAQABAoIBAHX7VOvBC3kCN9/x -+aPdup84OE7Z7MvpX6w+WlUhXVugnmsAAVDczhKoUc/WktLLx2huCGhsmKvyVuH+ -MioUiE+vx75gm3qGx5xbtmOfALVMRLopjCnJYf6EaFA0ZeQ+NwowNW7Lu0PHmAU8 -Z3JiX8IwxTz14DU82buDyewO7v+cEr97AnERe3PUcSTDoUXNaoNxjNpEJkKREY6h -4hAY676RT/GsRcQ8tqe/rnCqPHNd7JGqL+207FK4tJw7daoBjQyijWuB7K5chSal -oPInylM6b13ASXuOAOT/2uSUBWmFVCZPDCmnZxy2SdnJGbsJAMl7Ma3MUlaGvVI+ -Tfh1aQkCgYEA4JlNOabTb3z42wz6mz+Nz3JRwbawD+PJXOk5JsSnV7DtPtfgkK9y -6FTQdhnozGWShAvJvc+C4QAihs9AlHXoaBY5bEU7R/8UK/pSqwzam+MmxmhVDV7G -IMQPV0FteoXTaJSikhZ88mETTegI2mik+zleBpVxvfdhE5TR+lq8Br0CgYEA2AwJ -CUD5CYUSj09PluR0HHqamWOrJkKPFPwa+5eiTTCzfBBxImYZh7nXnWuoviXC0sg2 -AuvCW+uZ48ygv/D8gcz3j1JfbErKZJuV+TotK9rRtNIF5Ub7qysP7UjyI7zCssVM -kuDd9LfRXaB/qGAHNkcDA8NxmHW3gpln4CFdSY8CgYANs4xwfercHEWaJ1qKagAe -rZyrMpffAEhicJ/Z65lB0jtG4CiE6w8ZeUMWUVJQVcnwYD+4YpZbX4S7sJ0B8Ydy -AhkSr86D/92dKTIt2STk6aCN7gNyQ1vW198PtaAWH1/cO2UHgHOy3ZUt5X/Uwxl9 -cex4flln+1Viumts2GgsCQKBgCJH7psgSyPekK5auFdKEr5+Gc/jB8I/Z3K9+g4X -5nH3G1PBTCJYLw7hRzw8W/8oALzvddqKzEFHphiGXK94Lqjt/A4q1OdbCrhiE68D -My21P/dAKB1UYRSs9Y8CNyHCjuZM9jSMJ8vv6vG/SOJPsnVDWVAckAbQDvlTHC9t -O98zAoGAcbW6uFDkrv0XMCpB9Su3KaNXOR0wzag+WIFQRXCcoTvxVi9iYfUReQPi -oOyBJU/HMVvBfv4g+OVFLVgSwwm6owwsouZ0+D/LasbuHqYyqYqdyPJQYzWA2Y+F -+B6f4RoPdSXj24JHPg/ioRxjaj094UXJxua2yfkcecGNEuBQHSs= ------END RSA PRIVATE KEY----- -` -) +func testSSHCommand(tb testing.TB) (*cli.MockUi, *SSHCommand) { + tb.Helper() -var testIP string -var testPort string -var testUserName string -var testAdminUser string - -// Starts the server and initializes the servers IP address, -// port and usernames to be used by the test cases. -func initTest() { - addr, err := vault.StartSSHHostTestServer() - if err != nil { - panic(fmt.Sprintf("Error starting mock server:%s", err)) + ui := cli.NewMockUi() + return ui, &SSHCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, } - input := strings.Split(addr, ":") - testIP = input[0] - testPort = input[1] - - testUserName := os.Getenv("VAULT_SSHTEST_USER") - if len(testUserName) == 0 { - panic("VAULT_SSHTEST_USER must be set to the desired user") - } - testAdminUser = testUserName } -// This test is broken. Hence temporarily disabling it. -func testSSH(t *testing.T) { - initTest() - // Add the SSH backend to the unsealed test core. - // This should be done before the unsealed core is created. - err := vault.AddTestLogicalBackend("ssh", logicalssh.Factory) - if err != nil { - t.Fatalf("err: %s", err) - } - core, _, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() - - ui := new(cli.MockUi) - mountCmd := &MountCommand{ - Meta: meta.Meta{ - ClientToken: token, - Ui: ui, - }, - } - - args := []string{"-address", addr, "ssh"} - - // Mount the SSH backend - if code := mountCmd.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } - - client, err := mountCmd.Client() - if err != nil { - t.Fatalf("err: %s", err) - } - - mounts, err := client.Sys().ListMounts() - if err != nil { - t.Fatalf("err: %s", err) - } - - // Check if SSH backend is mounted or not - mount, ok := mounts["ssh/"] - if !ok { - t.Fatal("should have ssh mount") - } - if mount.Type != "ssh" { - t.Fatal("should have ssh type") - } - - writeCmd := &WriteCommand{ - Meta: meta.Meta{ - ClientToken: token, - Ui: ui, - }, - } - - // Create a 'named' key in vault - args = []string{ - "-address", addr, - "ssh/keys/" + testKey, - "key=" + testSharedPrivateKey, - } - if code := writeCmd.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } - - // Create a role using the named key along with cidr, username and port - args = []string{ - "-address", addr, - "ssh/roles/" + testRoleName, - "key=" + testKey, - "admin_user=" + testUserName, - "cidr=" + testCidr, - "port=" + testPort, - } - if code := writeCmd.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } - - sshCmd := &SSHCommand{ - Meta: meta.Meta{ - ClientToken: token, - Ui: ui, - }, - } - - // Get the dynamic key and establish an SSH connection with target. - // Inline command when supplied, runs on target and terminates the - // connection. Use whoami as the inline command in target and get - // the result. Compare the result with the username used to connect - // to target. Test succeeds if they match. - args = []string{ - "-address", addr, - "-role=" + testRoleName, - testUserName + "@" + testIP, - "/usr/bin/whoami", - } - - // Creating pipe to get the result of the inline command run in target machine. - stdout := os.Stdout - r, w, err := os.Pipe() - if err != nil { - t.Fatalf("err: %s", err) - } - os.Stdout = w - if code := sshCmd.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } - bufChan := make(chan string) - go func() { - var buf bytes.Buffer - io.Copy(&buf, r) - bufChan <- buf.String() - }() - w.Close() - os.Stdout = stdout - userName := <-bufChan - userName = strings.TrimSpace(userName) - - // Comparing the username used to connect to target and - // the username on the target, thereby verifying successful - // execution - if userName != testUserName { - t.Fatalf("err: username mismatch") - } +func TestSSHCommand_Run(t *testing.T) { + t.Parallel() + t.Skip("Need a way to setup target infrastructure") } diff --git a/command/status.go b/command/status.go index 76b7a0fc30..b31b093c17 100644 --- a/command/status.go +++ b/command/status.go @@ -4,123 +4,86 @@ import ( "fmt" "strings" - "github.com/hashicorp/vault/api" - "github.com/hashicorp/vault/meta" + "github.com/mitchellh/cli" + "github.com/posener/complete" ) -// StatusCommand is a Command that outputs the status of whether -// Vault is sealed or not as well as HA information. +var _ cli.Command = (*StatusCommand)(nil) +var _ cli.CommandAutocomplete = (*StatusCommand)(nil) + type StatusCommand struct { - meta.Meta -} - -func (c *StatusCommand) Run(args []string) int { - flags := c.Meta.FlagSet("status", meta.FlagSetDefault) - flags.Usage = func() { c.Ui.Error(c.Help()) } - if err := flags.Parse(args); err != nil { - return 1 - } - - client, err := c.Client() - if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error initializing client: %s", err)) - return 1 - } - - sealStatus, err := client.Sys().SealStatus() - if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error checking seal status: %s", err)) - return 1 - } - - var sealPrefix string - if sealStatus.RecoverySeal { - sealPrefix = "Recovery " - } - outStr := fmt.Sprintf( - "%sSeal Type: %s\n"+ - "Sealed: %v\n"+ - "%sKey Shares: %d\n"+ - "%sKey Threshold: %d\n"+ - "Unseal Progress: %d\n"+ - "Unseal Nonce: %v\n"+ - "Version: %s", - sealPrefix, - sealStatus.Type, - sealStatus.Sealed, - sealPrefix, - sealStatus.N, - sealPrefix, - sealStatus.T, - sealStatus.Progress, - sealStatus.Nonce, - sealStatus.Version) - - if sealStatus.ClusterName != "" && sealStatus.ClusterID != "" { - outStr = fmt.Sprintf("%s\nCluster Name: %s\nCluster ID: %s", outStr, sealStatus.ClusterName, sealStatus.ClusterID) - } - - c.Ui.Output(outStr) - - // Mask the 'Vault is sealed' error, since this means HA is enabled, - // but that we cannot query for the leader since we are sealed. - leaderStatus, err := client.Sys().Leader() - if err != nil && strings.Contains(err.Error(), "Vault is sealed") { - leaderStatus = &api.LeaderResponse{HAEnabled: true} - err = nil - } - if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error checking leader status: %s", err)) - return 1 - } - - // Output if HA is enabled - c.Ui.Output("") - c.Ui.Output(fmt.Sprintf("High-Availability Enabled: %v", leaderStatus.HAEnabled)) - if leaderStatus.HAEnabled { - if sealStatus.Sealed { - c.Ui.Output("\tMode: sealed") - } else { - mode := "standby" - if leaderStatus.IsSelf { - mode = "active" - } - c.Ui.Output(fmt.Sprintf("\tMode: %s", mode)) - - if leaderStatus.LeaderAddress == "" { - leaderStatus.LeaderAddress = "" - } - if leaderStatus.LeaderClusterAddress == "" { - leaderStatus.LeaderClusterAddress = "" - } - c.Ui.Output(fmt.Sprintf("\tLeader Cluster Address: %s", leaderStatus.LeaderClusterAddress)) - } - } - - if sealStatus.Sealed { - return 2 - } else { - return 0 - } + *BaseCommand } func (c *StatusCommand) Synopsis() string { - return "Outputs status of whether Vault is sealed and if HA mode is enabled" + return "Print seal and HA status" } func (c *StatusCommand) Help() string { helpText := ` Usage: vault status [options] - Outputs the state of the Vault, sealed or unsealed and if HA is enabled. + Prints the current state of Vault including whether it is sealed and if HA + mode is enabled. This command prints regardless of whether the Vault is + sealed. - This command outputs whether or not the Vault is sealed. The exit - code also reflects the seal status (0 unsealed, 2 sealed, 1 error). + The exit code reflects the seal status: + + - 0 - unsealed + - 1 - error + - 2 - sealed + +` + c.Flags().Help() -General Options: -` + meta.GeneralOptionsUsage() return strings.TrimSpace(helpText) } + +func (c *StatusCommand) Flags() *FlagSets { + return c.flagSet(FlagSetHTTP) +} + +func (c *StatusCommand) AutocompleteArgs() complete.Predictor { + return complete.PredictNothing +} + +func (c *StatusCommand) AutocompleteFlags() complete.Flags { + return c.Flags().Completions() +} + +func (c *StatusCommand) Run(args []string) int { + f := c.Flags() + + if err := f.Parse(args); err != nil { + c.UI.Error(err.Error()) + return 1 + } + + args = f.Args() + if len(args) > 0 { + c.UI.Error(fmt.Sprintf("Too many arguments (expected 0, got %d)", len(args))) + return 1 + } + + client, err := c.Client() + if err != nil { + c.UI.Error(err.Error()) + // We return 2 everywhere else, but 2 is reserved for "sealed" here + return 1 + } + + status, err := client.Sys().SealStatus() + if err != nil { + c.UI.Error(fmt.Sprintf("Error checking seal status: %s", err)) + return 1 + } + + // Do not return the int here, since we want to return a custom error code + // depending on the seal status. + OutputSealStatus(c.UI, client, status) + + if status.Sealed { + return 2 + } + + return 0 +} diff --git a/command/status_test.go b/command/status_test.go index 92e7f74719..e34a72c578 100644 --- a/command/status_test.go +++ b/command/status_test.go @@ -1,39 +1,115 @@ package command import ( + "strings" "testing" - "github.com/hashicorp/vault/http" - "github.com/hashicorp/vault/meta" - "github.com/hashicorp/vault/vault" "github.com/mitchellh/cli" ) -func TestStatus(t *testing.T) { - ui := new(cli.MockUi) - c := &StatusCommand{ - Meta: meta.Meta{ - Ui: ui, +func testStatusCommand(tb testing.TB) (*cli.MockUi, *StatusCommand) { + tb.Helper() + + ui := cli.NewMockUi() + return ui, &StatusCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + } +} + +func TestStatusCommand_Run(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + args []string + sealed bool + out string + code int + }{ + { + "unsealed", + nil, + false, + "Sealed false", + 0, + }, + { + "sealed", + nil, + true, + "Sealed true", + 2, + }, + { + "args", + []string{"foo"}, + false, + "Too many arguments", + 1, }, } - core := vault.TestCore(t) - keys, _ := vault.TestCoreInit(t, core) - ln, addr := http.TestServer(t, core) - defer ln.Close() + t.Run("validations", func(t *testing.T) { + t.Parallel() - args := []string{"-address", addr} - if code := c.Run(args); code != 2 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } + for _, tc := range cases { + tc := tc - for _, key := range keys { - if _, err := core.Unseal(key); err != nil { - t.Fatalf("err: %s", err) + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + if tc.sealed { + if err := client.Sys().Seal(); err != nil { + t.Fatal(err) + } + } + + ui, cmd := testStatusCommand(t) + cmd.client = client + + code := cmd.Run(tc.args) + if code != tc.code { + t.Errorf("expected %d to be %d", code, tc.code) + } + + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, tc.out) { + t.Errorf("expected %q to contain %q", combined, tc.out) + } + }) } - } + }) - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } + t.Run("communication_failure", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServerBad(t) + defer closer() + + ui, cmd := testStatusCommand(t) + cmd.client = client + + code := cmd.Run([]string{}) + if exp := 1; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Error checking seal status: " + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + }) + + t.Run("no_tabs", func(t *testing.T) { + t.Parallel() + + _, cmd := testStatusCommand(t) + assertNoTabs(t, cmd) + }) } diff --git a/command/step-down.go b/command/step-down.go deleted file mode 100644 index be445a8389..0000000000 --- a/command/step-down.go +++ /dev/null @@ -1,55 +0,0 @@ -package command - -import ( - "fmt" - "strings" - - "github.com/hashicorp/vault/meta" -) - -// StepDownCommand is a Command that seals the vault. -type StepDownCommand struct { - meta.Meta -} - -func (c *StepDownCommand) Run(args []string) int { - flags := c.Meta.FlagSet("step-down", meta.FlagSetDefault) - flags.Usage = func() { c.Ui.Error(c.Help()) } - if err := flags.Parse(args); err != nil { - return 1 - } - - client, err := c.Client() - if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error initializing client: %s", err)) - return 2 - } - - if err := client.Sys().StepDown(); err != nil { - c.Ui.Error(fmt.Sprintf("Error stepping down: %s", err)) - return 1 - } - - return 0 -} - -func (c *StepDownCommand) Synopsis() string { - return "Force the Vault node to give up active duty" -} - -func (c *StepDownCommand) Help() string { - helpText := ` -Usage: vault step-down [options] - - Force the Vault node to step down from active duty. - - This causes the indicated node to give up active status. Note that while the - affected node will have a short delay before attempting to grab the lock - again, if no other node grabs the lock beforehand, it is possible for the - same node to re-grab the lock and become active again. - -General Options: -` + meta.GeneralOptionsUsage() - return strings.TrimSpace(helpText) -} diff --git a/command/token.go b/command/token.go new file mode 100644 index 0000000000..20af230a5b --- /dev/null +++ b/command/token.go @@ -0,0 +1,46 @@ +package command + +import ( + "strings" + + "github.com/mitchellh/cli" +) + +var _ cli.Command = (*TokenCommand)(nil) + +type TokenCommand struct { + *BaseCommand +} + +func (c *TokenCommand) Synopsis() string { + return "Interact with tokens" +} + +func (c *TokenCommand) Help() string { + helpText := ` +Usage: vault token [options] [args] + + This command groups subcommands for interacting with tokens. Users can + create, lookup, renew, and revoke tokens. + + Create a new token: + + $ vault token create + + Revoke a token: + + $ vault token revoke 96ddf4bc-d217-f3ba-f9bd-017055595017 + + Renew a token: + + $ vault token renew 96ddf4bc-d217-f3ba-f9bd-017055595017 + + Please see the individual subcommand help for detailed usage information. +` + + return strings.TrimSpace(helpText) +} + +func (c *TokenCommand) Run(args []string) int { + return cli.RunResultHelp +} diff --git a/command/token/helper.go b/command/token/helper.go index db068beb29..cac79487fa 100644 --- a/command/token/helper.go +++ b/command/token/helper.go @@ -3,7 +3,7 @@ package token // TokenHelper is an interface that contains basic operations that must be // implemented by a token helper type TokenHelper interface { - // Path displays a backend-specific path; for the internal helper this + // Path displays a method-specific path; for the internal helper this // is the location of the token stored on disk; for the external helper // this is the location of the binary being invoked Path() string diff --git a/command/token/helper_external.go b/command/token/helper_external.go index 40de9bfde1..4483074af7 100644 --- a/command/token/helper_external.go +++ b/command/token/helper_external.go @@ -36,6 +36,8 @@ func ExternalTokenHelperPath(path string) (string, error) { return path, nil } +var _ TokenHelper = (*ExternalTokenHelper)(nil) + // ExternalTokenHelper is the struct that has all the logic for storing and retrieving // tokens from the token helper. The API for the helpers is simple: the // BinaryPath is executed within a shell with environment Env. The last argument diff --git a/command/token/helper_internal.go b/command/token/helper_internal.go index 89793cb63d..58dceaebcc 100644 --- a/command/token/helper_internal.go +++ b/command/token/helper_internal.go @@ -10,6 +10,8 @@ import ( "github.com/mitchellh/go-homedir" ) +var _ TokenHelper = (*InternalTokenHelper)(nil) + // InternalTokenHelper fulfills the TokenHelper interface when no external // token-helper is configured, and avoids shelling out type InternalTokenHelper struct { diff --git a/command/token/helper_testing.go b/command/token/helper_testing.go new file mode 100644 index 0000000000..93465931b7 --- /dev/null +++ b/command/token/helper_testing.go @@ -0,0 +1,42 @@ +package token + +import ( + "sync" +) + +var _ TokenHelper = (*TestingTokenHelper)(nil) + +// TestingTokenHelper implements token.TokenHelper which runs entirely +// in-memory. This should not be used outside of testing. +type TestingTokenHelper struct { + lock sync.RWMutex + token string +} + +func NewTestingTokenHelper() *TestingTokenHelper { + return &TestingTokenHelper{} +} + +func (t *TestingTokenHelper) Erase() error { + t.lock.Lock() + defer t.lock.Unlock() + t.token = "" + return nil +} + +func (t *TestingTokenHelper) Get() (string, error) { + t.lock.RLock() + defer t.lock.RUnlock() + return t.token, nil +} + +func (t *TestingTokenHelper) Path() string { + return "" +} + +func (t *TestingTokenHelper) Store(token string) error { + t.lock.Lock() + defer t.lock.Unlock() + t.token = token + return nil +} diff --git a/command/token_capabilities.go b/command/token_capabilities.go new file mode 100644 index 0000000000..3212546b43 --- /dev/null +++ b/command/token_capabilities.go @@ -0,0 +1,100 @@ +package command + +import ( + "fmt" + "sort" + "strings" + + "github.com/mitchellh/cli" + "github.com/posener/complete" +) + +var _ cli.Command = (*TokenCapabilitiesCommand)(nil) +var _ cli.CommandAutocomplete = (*TokenCapabilitiesCommand)(nil) + +type TokenCapabilitiesCommand struct { + *BaseCommand +} + +func (c *TokenCapabilitiesCommand) Synopsis() string { + return "Print capabilities of a token on a path" +} + +func (c *TokenCapabilitiesCommand) Help() string { + helpText := ` +Usage: vault token capabilities [options] [TOKEN] PATH + + Fetches the capabilities of a token for a given path. If a TOKEN is provided + as an argument, the "/sys/capabilities" endpoint and permission is used. If + no TOKEN is provided, the "/sys/capabilities-self" endpoint and permission + is used with the locally authenticated token. + + List capabilities for the local token on the "secret/foo" path: + + $ vault token capabilities secret/foo + + List capabilities for a token on the "cubbyhole/foo" path: + + $ vault token capabilities 96ddf4bc-d217-f3ba-f9bd-017055595017 cubbyhole/foo + + For a full list of examples, please see the documentation. + +` + c.Flags().Help() + + return strings.TrimSpace(helpText) +} + +func (c *TokenCapabilitiesCommand) Flags() *FlagSets { + return c.flagSet(FlagSetHTTP) +} + +func (c *TokenCapabilitiesCommand) AutocompleteArgs() complete.Predictor { + return nil +} + +func (c *TokenCapabilitiesCommand) AutocompleteFlags() complete.Flags { + return c.Flags().Completions() +} + +func (c *TokenCapabilitiesCommand) Run(args []string) int { + f := c.Flags() + + if err := f.Parse(args); err != nil { + c.UI.Error(err.Error()) + return 1 + } + + token := "" + path := "" + args = f.Args() + switch { + case len(args) == 1: + path = args[0] + case len(args) == 2: + token, path = args[0], args[1] + default: + c.UI.Error(fmt.Sprintf("Too many arguments (expected 1-2, got %d)", len(args))) + return 1 + } + + client, err := c.Client() + if err != nil { + c.UI.Error(err.Error()) + return 2 + } + + var capabilities []string + if token == "" { + capabilities, err = client.Sys().CapabilitiesSelf(path) + } else { + capabilities, err = client.Sys().Capabilities(token, path) + } + if err != nil { + c.UI.Error(fmt.Sprintf("Error listing capabilities: %s", err)) + return 2 + } + + sort.Strings(capabilities) + c.UI.Output(strings.Join(capabilities, ", ")) + return 0 +} diff --git a/command/token_capabilities_test.go b/command/token_capabilities_test.go new file mode 100644 index 0000000000..74efd0cab5 --- /dev/null +++ b/command/token_capabilities_test.go @@ -0,0 +1,170 @@ +package command + +import ( + "strings" + "testing" + + "github.com/hashicorp/vault/api" + "github.com/mitchellh/cli" +) + +func testTokenCapabilitiesCommand(tb testing.TB) (*cli.MockUi, *TokenCapabilitiesCommand) { + tb.Helper() + + ui := cli.NewMockUi() + return ui, &TokenCapabilitiesCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + } +} + +func TestTokenCapabilitiesCommand_Run(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + args []string + out string + code int + }{ + { + "too_many_args", + []string{"foo", "bar", "zip"}, + "Too many arguments", + 1, + }, + } + + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + ui, cmd := testTokenCapabilitiesCommand(t) + + code := cmd.Run(tc.args) + if code != tc.code { + t.Errorf("expected %d to be %d", code, tc.code) + } + + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, tc.out) { + t.Errorf("expected %q to contain %q", combined, tc.out) + } + }) + } + + t.Run("token", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + policy := `path "secret/foo" { capabilities = ["read"] }` + if err := client.Sys().PutPolicy("policy", policy); err != nil { + t.Error(err) + } + + secret, err := client.Auth().Token().Create(&api.TokenCreateRequest{ + Policies: []string{"policy"}, + TTL: "30m", + }) + if err != nil { + t.Fatal(err) + } + if secret == nil || secret.Auth == nil || secret.Auth.ClientToken == "" { + t.Fatalf("missing auth data: %#v", secret) + } + token := secret.Auth.ClientToken + + ui, cmd := testTokenCapabilitiesCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + token, "secret/foo", + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "read" + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + }) + + t.Run("local", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + policy := `path "secret/foo" { capabilities = ["read"] }` + if err := client.Sys().PutPolicy("policy", policy); err != nil { + t.Error(err) + } + + secret, err := client.Auth().Token().Create(&api.TokenCreateRequest{ + Policies: []string{"policy"}, + TTL: "30m", + }) + if err != nil { + t.Fatal(err) + } + if secret == nil || secret.Auth == nil || secret.Auth.ClientToken == "" { + t.Fatalf("missing auth data: %#v", secret) + } + token := secret.Auth.ClientToken + + client.SetToken(token) + + ui, cmd := testTokenCapabilitiesCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "secret/foo", + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "read" + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + }) + + t.Run("communication_failure", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServerBad(t) + defer closer() + + ui, cmd := testTokenCapabilitiesCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "foo", "bar", + }) + if exp := 2; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Error listing capabilities: " + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + }) + + t.Run("no_tabs", func(t *testing.T) { + t.Parallel() + + _, cmd := testTokenCapabilitiesCommand(t) + assertNoTabs(t, cmd) + }) +} diff --git a/command/token_create.go b/command/token_create.go index f8d8c59265..a75a6066ff 100644 --- a/command/token_create.go +++ b/command/token_create.go @@ -3,174 +3,248 @@ package command import ( "fmt" "strings" + "time" "github.com/hashicorp/vault/api" - "github.com/hashicorp/vault/helper/flag-kv" - "github.com/hashicorp/vault/helper/flag-slice" - "github.com/hashicorp/vault/meta" + "github.com/mitchellh/cli" + "github.com/posener/complete" ) -// TokenCreateCommand is a Command that mounts a new mount. +var _ cli.Command = (*TokenCreateCommand)(nil) +var _ cli.CommandAutocomplete = (*TokenCreateCommand)(nil) + type TokenCreateCommand struct { - meta.Meta -} + *BaseCommand -func (c *TokenCreateCommand) Run(args []string) int { - var format string - var id, displayName, lease, ttl, explicitMaxTTL, period, role string - var orphan, noDefaultPolicy, renewable bool - var metadata map[string]string - var numUses int - var policies []string - flags := c.Meta.FlagSet("mount", meta.FlagSetDefault) - flags.StringVar(&format, "format", "table", "") - flags.StringVar(&displayName, "display-name", "", "") - flags.StringVar(&id, "id", "", "") - flags.StringVar(&lease, "lease", "", "") - flags.StringVar(&ttl, "ttl", "", "") - flags.StringVar(&explicitMaxTTL, "explicit-max-ttl", "", "") - flags.StringVar(&period, "period", "", "") - flags.StringVar(&role, "role", "", "") - flags.BoolVar(&orphan, "orphan", false, "") - flags.BoolVar(&renewable, "renewable", true, "") - flags.BoolVar(&noDefaultPolicy, "no-default-policy", false, "") - flags.IntVar(&numUses, "use-limit", 0, "") - flags.Var((*kvFlag.Flag)(&metadata), "metadata", "") - flags.Var((*sliceflag.StringFlag)(&policies), "policy", "") - flags.Usage = func() { c.Ui.Error(c.Help()) } - if err := flags.Parse(args); err != nil { - return 1 - } + flagID string + flagDisplayName string + flagTTL time.Duration + flagExplicitMaxTTL time.Duration + flagPeriod time.Duration + flagRenewable bool + flagOrphan bool + flagNoDefaultPolicy bool + flagUseLimit int + flagRole string + flagMetadata map[string]string + flagPolicies []string - args = flags.Args() - if len(args) != 0 { - flags.Usage() - c.Ui.Error(fmt.Sprintf( - "\ntoken-create expects no arguments")) - return 1 - } - - client, err := c.Client() - if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error initializing client: %s", err)) - return 2 - } - - if ttl == "" { - ttl = lease - } - - tcr := &api.TokenCreateRequest{ - ID: id, - Policies: policies, - Metadata: metadata, - TTL: ttl, - NoParent: orphan, - NoDefaultPolicy: noDefaultPolicy, - DisplayName: displayName, - NumUses: numUses, - Renewable: new(bool), - ExplicitMaxTTL: explicitMaxTTL, - Period: period, - } - *tcr.Renewable = renewable - - var secret *api.Secret - if role != "" { - secret, err = client.Auth().Token().CreateWithRole(tcr, role) - } else { - secret, err = client.Auth().Token().Create(tcr) - } - - if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error creating token: %s", err)) - return 2 - } - - return OutputSecret(c.Ui, format, secret) + // Deprecated flags + flagLease time.Duration } func (c *TokenCreateCommand) Synopsis() string { - return "Create a new auth token" + return "Create a new token" } func (c *TokenCreateCommand) Help() string { helpText := ` -Usage: vault token-create [options] +Usage: vault token create [options] - Create a new auth token. + Creates a new token that can be used for authentication. This token will be + created as a child of the currently authenticated token. The generated token + will inherit all policies and permissions of the currently authenticated + token unless you explicitly define a subset list policies to assign to the + token. - This command creates a new token that can be used for authentication. - This token will be created as a child of your token. The created token - will inherit your policies, or can be assigned a subset of your policies. - - A lease can also be associated with the token. If a lease is not associated - with the token, then it cannot be renewed. If a lease is associated with + A ttl can also be associated with the token. If a ttl is not associated + with the token, then it cannot be renewed. If a ttl is associated with the token, it will expire after that amount of time unless it is renewed. - Metadata associated with the token (specified with "-metadata") is - written to the audit log when the token is used. + Metadata associated with the token (specified with "-metadata") is written + to the audit log when the token is used. If a role is specified, the role may override parameters specified here. -General Options: -` + meta.GeneralOptionsUsage() + ` -Token Options: +` + c.Flags().Help() - -id="7699125c-d8...." The token value that clients will use to authenticate - with Vault. If not provided this defaults to a 36 - character UUID. A root token is required to specify - the ID of a token. - - -display-name="name" A display name to associate with this token. This - is a non-security sensitive value used to help - identify created secrets, i.e. prefixes. - - -ttl="1h" Initial TTL to associate with the token; renewals can - extend this value. - - -explicit-max-ttl="1h" An explicit maximum lifetime for the token. Unlike - normal token TTLs, which can be renewed up until the - maximum TTL set on the auth/token mount or the system - configuration file, this lifetime is a hard limit set - on the token itself and cannot be exceeded. - - -period="1h" If specified, the token will be periodic; it will - have no maximum TTL (unless an "explicit-max-ttl" is - also set) but every renewal will use the given - period. Requires a root/sudo token to use. - - -renewable=true Whether or not the token is renewable to extend its - TTL up to Vault's configured maximum TTL for tokens. - This defaults to true; set to false to disable - renewal of this token. - - -metadata="key=value" Metadata to associate with the token. This shows - up in the audit log. This can be specified multiple - times. - - -orphan If specified, the token will have no parent. This - prevents the new token from being revoked with - your token. Requires a root/sudo token to use. - - -no-default-policy If specified, the token will not have the "default" - policy included in its policy set. - - -policy="name" Policy to associate with this token. This can be - specified multiple times. - - -use-limit=5 The number of times this token can be used until - it is automatically revoked. - - -format=table The format for output. By default it is a whitespace- - delimited table. This can also be json or yaml. - - -role=name If set, the token will be created against the named - role. The role may override other parameters. This - requires the client to have permissions on the - appropriate endpoint (auth/token/create/). -` return strings.TrimSpace(helpText) } + +func (c *TokenCreateCommand) Flags() *FlagSets { + set := c.flagSet(FlagSetHTTP | FlagSetOutputField | FlagSetOutputFormat) + + f := set.NewFlagSet("Command Options") + + f.StringVar(&StringVar{ + Name: "id", + Target: &c.flagID, + Completion: complete.PredictAnything, + Usage: "Value for the token. By default, this is an auto-generated 36 " + + "character UUID. Specifying this value requires sudo permissions.", + }) + + f.StringVar(&StringVar{ + Name: "display-name", + Target: &c.flagDisplayName, + Completion: complete.PredictAnything, + Usage: "Name to associate with this token. This is a non-sensitive value " + + "that can be used to help identify created secrets (e.g. prefixes).", + }) + + f.DurationVar(&DurationVar{ + Name: "ttl", + Target: &c.flagTTL, + Completion: complete.PredictAnything, + Usage: "Initial TTL to associate with the token. Token renewals may be " + + "able to extend beyond this value, depending on the configured maximum" + + "TTLs. This is specified as a numeric string with suffix like \"30s\" " + + "or \"5m\".", + }) + + f.DurationVar(&DurationVar{ + Name: "explicit-max-ttl", + Target: &c.flagExplicitMaxTTL, + Completion: complete.PredictAnything, + Usage: "Explicit maximum lifetime for the token. Unlike normal TTLs, the " + + "maximum TTL is a hard limit and cannot be exceeded. This is specified " + + "as a numeric string with suffix like \"30s\" or \"5m\".", + }) + + f.DurationVar(&DurationVar{ + Name: "period", + Target: &c.flagPeriod, + Completion: complete.PredictAnything, + Usage: "If specified, every renewal will use the given period. Periodic " + + "tokens do not expire (unless -explicit-max-ttl is also provided). " + + "Setting this value requires sudo permissions. This is specified as a " + + "numeric string with suffix like \"30s\" or \"5m\".", + }) + + f.BoolVar(&BoolVar{ + Name: "renewable", + Target: &c.flagRenewable, + Default: true, + Usage: "Allow the token to be renewed up to it's maximum TTL.", + }) + + f.BoolVar(&BoolVar{ + Name: "orphan", + Target: &c.flagOrphan, + Default: false, + Usage: "Create the token with no parent. This prevents the token from " + + "being revoked when the token which created it expires. Setting this " + + "value requires sudo permissions.", + }) + + f.BoolVar(&BoolVar{ + Name: "no-default-policy", + Target: &c.flagNoDefaultPolicy, + Default: false, + Usage: "Detach the \"default\" policy from the policy set for this " + + "token.", + }) + + f.IntVar(&IntVar{ + Name: "use-limit", + Target: &c.flagUseLimit, + Default: 0, + Usage: "Number of times this token can be used. After the last use, the " + + "token is automatically revoked. By default, tokens can be used an " + + "unlimited number of times until their expiration.", + }) + + f.StringVar(&StringVar{ + Name: "role", + Target: &c.flagRole, + Default: "", + Usage: "Name of the role to create the token against. Specifying -role " + + "may override other arguments. The locally authenticated Vault token " + + "must have permission for \"auth/token/create/\".", + }) + + f.StringMapVar(&StringMapVar{ + Name: "metadata", + Target: &c.flagMetadata, + Completion: complete.PredictAnything, + Usage: "Arbitrary key=value metadata to associate with the token. " + + "This metadata will show in the audit log when the token is used. " + + "This can be specified multiple times to add multiple pieces of " + + "metadata.", + }) + + f.StringSliceVar(&StringSliceVar{ + Name: "policy", + Target: &c.flagPolicies, + Completion: c.PredictVaultPolicies(), + Usage: "Name of a policy to associate with this token. This can be " + + "specified multiple times to attach multiple policies.", + }) + + // Deprecated flags + // TODO: remove in 0.9.0 + f.DurationVar(&DurationVar{ + Name: "lease", // prefer -ttl + Target: &c.flagLease, + Default: 0, + Hidden: true, + }) + + return set +} + +func (c *TokenCreateCommand) AutocompleteArgs() complete.Predictor { + return nil +} + +func (c *TokenCreateCommand) AutocompleteFlags() complete.Flags { + return c.Flags().Completions() +} + +func (c *TokenCreateCommand) Run(args []string) int { + f := c.Flags() + + if err := f.Parse(args); err != nil { + c.UI.Error(err.Error()) + return 1 + } + + args = f.Args() + if len(args) > 0 { + c.UI.Error(fmt.Sprintf("Too many arguments (expected 0, got %d)", len(args))) + return 1 + } + + // TODO: remove in 0.9.0 + if c.flagLease != 0 { + c.UI.Warn("The -lease flag is deprecated. Please use -ttl instead.") + c.flagTTL = c.flagLease + } + + client, err := c.Client() + if err != nil { + c.UI.Error(err.Error()) + return 2 + } + + tcr := &api.TokenCreateRequest{ + ID: c.flagID, + Policies: c.flagPolicies, + Metadata: c.flagMetadata, + TTL: c.flagTTL.String(), + NoParent: c.flagOrphan, + NoDefaultPolicy: c.flagNoDefaultPolicy, + DisplayName: c.flagDisplayName, + NumUses: c.flagUseLimit, + Renewable: &c.flagRenewable, + ExplicitMaxTTL: c.flagExplicitMaxTTL.String(), + Period: c.flagPeriod.String(), + } + + var secret *api.Secret + if c.flagRole != "" { + secret, err = client.Auth().Token().CreateWithRole(tcr, c.flagRole) + } else { + secret, err = client.Auth().Token().Create(tcr) + } + if err != nil { + c.UI.Error(fmt.Sprintf("Error creating token: %s", err)) + return 2 + } + + if c.flagField != "" { + return PrintRawField(c.UI, secret, c.flagField) + } + + return OutputSecret(c.UI, c.flagFormat, secret) +} diff --git a/command/token_create_test.go b/command/token_create_test.go index 9db2a26a29..ec0cc79fb5 100644 --- a/command/token_create_test.go +++ b/command/token_create_test.go @@ -1,38 +1,243 @@ package command import ( + "reflect" "strings" "testing" - "github.com/hashicorp/vault/http" - "github.com/hashicorp/vault/meta" - "github.com/hashicorp/vault/vault" "github.com/mitchellh/cli" ) -func TestTokenCreate(t *testing.T) { - core, _, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() +func testTokenCreateCommand(tb testing.TB) (*cli.MockUi, *TokenCreateCommand) { + tb.Helper() - ui := new(cli.MockUi) - c := &TokenCreateCommand{ - Meta: meta.Meta{ - ClientToken: token, - Ui: ui, + ui := cli.NewMockUi() + return ui, &TokenCreateCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + } +} + +func TestTokenCreateCommand_Run(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + args []string + out string + code int + }{ + { + "too_many_args", + []string{"abcd1234"}, + "Too many arguments", + 1, + }, + { + "default", + nil, + "token", + 0, + }, + { + "metadata", + []string{"-metadata", "foo=bar", "-metadata", "zip=zap"}, + "token", + 0, + }, + { + "policies", + []string{"-policy", "foo", "-policy", "bar"}, + "token", + 0, + }, + { + "field", + []string{ + "-field", "token_renewable", + }, + "false", + 0, + }, + { + "field_not_found", + []string{ + "-field", "not-a-real-field", + }, + "not present in secret", + 1, + }, + { + "format", + []string{ + "-format", "json", + }, + "{", + 0, + }, + { + "format_bad", + []string{ + "-format", "nope-not-real", + }, + "Invalid output format", + 1, }, } - args := []string{ - "-address", addr, - } - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } + t.Run("validations", func(t *testing.T) { + t.Parallel() - // Ensure we get lease info - output := ui.OutputWriter.String() - if !strings.Contains(output, "token_duration") { - t.Fatalf("bad: %#v", output) - } + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + ui, cmd := testTokenCreateCommand(t) + cmd.client = client + + code := cmd.Run(tc.args) + if code != tc.code { + t.Errorf("expected %d to be %d", code, tc.code) + } + + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, tc.out) { + t.Errorf("expected %q to contain %q", combined, tc.out) + } + }) + } + }) + + t.Run("default", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + ui, cmd := testTokenCreateCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "-field", "token", + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + token := strings.TrimSpace(ui.OutputWriter.String()) + secret, err := client.Auth().Token().Lookup(token) + if secret == nil || err != nil { + t.Fatal(err) + } + }) + + t.Run("metadata", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + ui, cmd := testTokenCreateCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "-metadata", "foo=bar", + "-metadata", "zip=zap", + "-field", "token", + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + token := strings.TrimSpace(ui.OutputWriter.String()) + secret, err := client.Auth().Token().Lookup(token) + if secret == nil || err != nil { + t.Fatal(err) + } + + meta, ok := secret.Data["meta"].(map[string]interface{}) + if !ok { + t.Fatalf("missing meta: %#v", secret) + } + if _, ok := meta["foo"]; !ok { + t.Errorf("missing meta.foo: %#v", meta) + } + if _, ok := meta["zip"]; !ok { + t.Errorf("missing meta.bar: %#v", meta) + } + }) + + t.Run("policies", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + ui, cmd := testTokenCreateCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "-policy", "foo", + "-policy", "bar", + "-field", "token", + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + token := strings.TrimSpace(ui.OutputWriter.String()) + secret, err := client.Auth().Token().Lookup(token) + if secret == nil || err != nil { + t.Fatal(err) + } + + raw, ok := secret.Data["policies"].([]interface{}) + if !ok { + t.Fatalf("missing policies: %#v", secret) + } + + policies := make([]string, len(raw)) + for i := range raw { + policies[i] = raw[i].(string) + } + + expected := []string{"bar", "default", "foo"} + if !reflect.DeepEqual(policies, expected) { + t.Errorf("expected %q to be %q", policies, expected) + } + }) + + t.Run("communication_failure", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServerBad(t) + defer closer() + + ui, cmd := testTokenCreateCommand(t) + cmd.client = client + + code := cmd.Run([]string{}) + if exp := 2; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Error creating token: " + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + }) + + t.Run("no_tabs", func(t *testing.T) { + t.Parallel() + + _, cmd := testTokenCreateCommand(t) + assertNoTabs(t, cmd) + }) } diff --git a/command/token_lookup.go b/command/token_lookup.go index c1c62ef716..2c885eaee9 100644 --- a/command/token_lookup.go +++ b/command/token_lookup.go @@ -5,96 +5,124 @@ import ( "strings" "github.com/hashicorp/vault/api" - "github.com/hashicorp/vault/meta" + "github.com/mitchellh/cli" + "github.com/posener/complete" ) -// TokenLookupCommand is a Command that outputs details about the -// provided. +var _ cli.Command = (*TokenLookupCommand)(nil) +var _ cli.CommandAutocomplete = (*TokenLookupCommand)(nil) + type TokenLookupCommand struct { - meta.Meta + *BaseCommand + + flagAccessor bool +} + +func (c *TokenLookupCommand) Synopsis() string { + return "Display information about a token" +} + +func (c *TokenLookupCommand) Help() string { + helpText := ` +Usage: vault token lookup [options] [TOKEN | ACCESSOR] + + Displays information about a token or accessor. If a TOKEN is not provided, + the locally authenticated token is used. + + Get information about the locally authenticated token (this uses the + /auth/token/lookup-self endpoint and permission): + + $ vault token lookup + + Get information about a particular token (this uses the /auth/token/lookup + endpoint and permission): + + $ vault token lookup 96ddf4bc-d217-f3ba-f9bd-017055595017 + + Get information about a token via its accessor: + + $ vault token lookup -accessor 9793c9b3-e04a-46f3-e7b8-748d7da248da + + For a full list of examples, please see the documentation. + +` + c.Flags().Help() + + return strings.TrimSpace(helpText) +} + +func (c *TokenLookupCommand) Flags() *FlagSets { + set := c.flagSet(FlagSetHTTP | FlagSetOutputFormat) + + f := set.NewFlagSet("Command Options") + + f.BoolVar(&BoolVar{ + Name: "accessor", + Target: &c.flagAccessor, + Default: false, + EnvVar: "", + Completion: complete.PredictNothing, + Usage: "Treat the argument as an accessor instead of a token. When " + + "this option is selected, the output will NOT include the token.", + }) + + return set +} + +func (c *TokenLookupCommand) AutocompleteArgs() complete.Predictor { + return c.PredictVaultFiles() +} + +func (c *TokenLookupCommand) AutocompleteFlags() complete.Flags { + return c.Flags().Completions() } func (c *TokenLookupCommand) Run(args []string) int { - var format string - var accessor bool - flags := c.Meta.FlagSet("token-lookup", meta.FlagSetDefault) - flags.BoolVar(&accessor, "accessor", false, "") - flags.StringVar(&format, "format", "table", "") - flags.Usage = func() { c.Ui.Error(c.Help()) } - if err := flags.Parse(args); err != nil { + f := c.Flags() + + if err := f.Parse(args); err != nil { + c.UI.Error(err.Error()) return 1 } - args = flags.Args() - if len(args) > 1 { - flags.Usage() - c.Ui.Error(fmt.Sprintf( - "\ntoken-lookup expects at most one argument")) + token := "" + + args = f.Args() + switch { + case c.flagAccessor && len(args) < 1: + c.UI.Error(fmt.Sprintf("Not enough arguments with -accessor (expected 1, got %d)", len(args))) + return 1 + case c.flagAccessor && len(args) > 1: + c.UI.Error(fmt.Sprintf("Too many arguments with -accessor (expected 1, got %d)", len(args))) + return 1 + case len(args) == 0: + // Use the local token + case len(args) == 1: + token = strings.TrimSpace(args[0]) + case len(args) > 1: + c.UI.Error(fmt.Sprintf("Too many arguments (expected 0-1, got %d)", len(args))) return 1 } client, err := c.Client() if err != nil { - c.Ui.Error(fmt.Sprintf( - "error initializing client: %s", err)) + c.UI.Error(err.Error()) return 2 } var secret *api.Secret switch { - case !accessor && len(args) == 0: + case token == "": secret, err = client.Auth().Token().LookupSelf() - case !accessor && len(args) == 1: - secret, err = client.Auth().Token().Lookup(args[0]) - case accessor && len(args) == 1: - secret, err = client.Auth().Token().LookupAccessor(args[0]) + case c.flagAccessor: + secret, err = client.Auth().Token().LookupAccessor(token) default: - // This happens only when accessor is set and no argument is passed - c.Ui.Error(fmt.Sprintf("token-lookup expects an argument when accessor flag is set")) - return 1 + secret, err = client.Auth().Token().Lookup(token) } if err != nil { - c.Ui.Error(fmt.Sprintf( - "error looking up token: %s", err)) - return 1 - } - return OutputSecret(c.Ui, format, secret) -} - -func doTokenLookup(args []string, client *api.Client) (*api.Secret, error) { - if len(args) == 0 { - return client.Auth().Token().LookupSelf() + c.UI.Error(fmt.Sprintf("Error looking up token: %s", err)) + return 2 } - token := args[0] - return client.Auth().Token().Lookup(token) -} - -func (c *TokenLookupCommand) Synopsis() string { - return "Display information about the specified token" -} - -func (c *TokenLookupCommand) Help() string { - helpText := ` -Usage: vault token-lookup [options] [token|accessor] - - Displays information about the specified token. If no token is specified, the - operation is performed on the currently authenticated token i.e. lookup-self. - Information about the token can be retrieved using the token accessor via the - '-accessor' flag. - -General Options: -` + meta.GeneralOptionsUsage() + ` -Token Lookup Options: - -accessor A boolean flag, if set, treats the argument as an accessor of the token. - Note that the response of the command when this is set, will not contain - the token ID. Accessor is only meant for looking up the token properties - (and for revocation via '/auth/token/revoke-accessor/' endpoint). - - -format=table The format for output. By default it is a whitespace- - delimited table. This can also be json or yaml. - -` - return strings.TrimSpace(helpText) + return OutputSecret(c.UI, c.flagFormat, secret) } diff --git a/command/token_lookup_test.go b/command/token_lookup_test.go index 143b9447d6..eeeaefd2b3 100644 --- a/command/token_lookup_test.go +++ b/command/token_lookup_test.go @@ -1,124 +1,189 @@ package command import ( + "strings" "testing" - "github.com/hashicorp/vault/api" - "github.com/hashicorp/vault/http" - "github.com/hashicorp/vault/meta" - "github.com/hashicorp/vault/vault" "github.com/mitchellh/cli" ) -func TestTokenLookupAccessor(t *testing.T) { - core, _, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() +func testTokenLookupCommand(tb testing.TB) (*cli.MockUi, *TokenLookupCommand) { + tb.Helper() - ui := new(cli.MockUi) - c := &TokenLookupCommand{ - Meta: meta.Meta{ - ClientToken: token, - Ui: ui, + ui := cli.NewMockUi() + return ui, &TokenLookupCommand{ + BaseCommand: &BaseCommand{ + UI: ui, }, } - args := []string{ - "-address", addr, - } - c.Run(args) +} - // Create a new token for us to use - client, err := c.Client() - if err != nil { - t.Fatalf("err: %s", err) +func TestTokenLookupCommand_Run(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + args []string + out string + code int + }{ + { + "accessor_no_args", + []string{"-accessor"}, + "Not enough arguments", + 1, + }, + { + "accessor_too_many_args", + []string{"-accessor", "abcd1234", "efgh5678"}, + "Too many arguments", + 1, + }, + { + "too_many_args", + []string{"abcd1234", "efgh5678"}, + "Too many arguments", + 1, + }, + { + "format", + []string{"-format", "json"}, + "{", + 0, + }, + { + "format_bad", + []string{"-format", "nope-not-real"}, + "Invalid output format", + 1, + }, } - resp, err := client.Auth().Token().Create(&api.TokenCreateRequest{ - Lease: "1h", + + t.Run("validations", func(t *testing.T) { + t.Parallel() + + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + ui, cmd := testTokenLookupCommand(t) + cmd.client = client + + code := cmd.Run(tc.args) + if code != tc.code { + t.Errorf("expected %d to be %d", code, tc.code) + } + + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, tc.out) { + t.Errorf("expected %q to contain %q", combined, tc.out) + } + }) + } }) - if err != nil { - t.Fatalf("err: %s", err) - } - // Enable the accessor flag - args = append(args, "-accessor") + t.Run("token", func(t *testing.T) { + t.Parallel() - // Expect failure if no argument is passed when accessor flag is set - code := c.Run(args) - if code == 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } + client, closer := testVaultServer(t) + defer closer() - // Add token accessor as arg - args = append(args, resp.Auth.Accessor) - code = c.Run(args) - if code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } -} + token, _ := testTokenAndAccessor(t, client) -func TestTokenLookupSelf(t *testing.T) { - core, _, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() + ui, cmd := testTokenLookupCommand(t) + cmd.client = client - ui := new(cli.MockUi) - c := &TokenLookupCommand{ - Meta: meta.Meta{ - ClientToken: token, - Ui: ui, - }, - } + code := cmd.Run([]string{ + token, + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } - args := []string{ - "-address", addr, - } - - // Run it against itself - code := c.Run(args) - - // Verify it worked - if code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } -} - -func TestTokenLookup(t *testing.T) { - core, _, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() - - ui := new(cli.MockUi) - c := &TokenLookupCommand{ - Meta: meta.Meta{ - ClientToken: token, - Ui: ui, - }, - } - - args := []string{ - "-address", addr, - } - // Run it once for client - c.Run(args) - - // Create a new token for us to use - client, err := c.Client() - if err != nil { - t.Fatalf("err: %s", err) - } - resp, err := client.Auth().Token().Create(&api.TokenCreateRequest{ - Lease: "1h", + expected := token + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } }) - if err != nil { - t.Fatalf("err: %s", err) - } - // Add token as arg for real test and run it - args = append(args, resp.Auth.ClientToken) - code := c.Run(args) + t.Run("self", func(t *testing.T) { + t.Parallel() - // Verify it worked - if code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } + client, closer := testVaultServer(t) + defer closer() + + ui, cmd := testTokenLookupCommand(t) + cmd.client = client + + code := cmd.Run([]string{}) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "display_name" + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + }) + + t.Run("accessor", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + _, accessor := testTokenAndAccessor(t, client) + + ui, cmd := testTokenLookupCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "-accessor", + accessor, + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := accessor + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + }) + + t.Run("communication_failure", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServerBad(t) + defer closer() + + ui, cmd := testTokenLookupCommand(t) + cmd.client = client + + code := cmd.Run([]string{}) + if exp := 2; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Error looking up token: " + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + }) + + t.Run("no_tabs", func(t *testing.T) { + t.Parallel() + + _, cmd := testTokenLookupCommand(t) + assertNoTabs(t, cmd) + }) } diff --git a/command/token_renew.go b/command/token_renew.go index 8ec1a550bb..6505ce328d 100644 --- a/command/token_renew.go +++ b/command/token_renew.go @@ -6,110 +6,131 @@ import ( "time" "github.com/hashicorp/vault/api" - "github.com/hashicorp/vault/helper/parseutil" - "github.com/hashicorp/vault/meta" + "github.com/mitchellh/cli" + "github.com/posener/complete" ) -// TokenRenewCommand is a Command that mounts a new mount. +var _ cli.Command = (*TokenRenewCommand)(nil) +var _ cli.CommandAutocomplete = (*TokenRenewCommand)(nil) + type TokenRenewCommand struct { - meta.Meta + *BaseCommand + + flagIncrement time.Duration +} + +func (c *TokenRenewCommand) Synopsis() string { + return "Renew a token lease" +} + +func (c *TokenRenewCommand) Help() string { + helpText := ` +Usage: vault token renew [options] [TOKEN] + + Renews a token's lease, extending the amount of time it can be used. If a + TOKEN is not provided, the locally authenticated token is used. Lease renewal + will fail if the token is not renewable, the token has already been revoked, + or if the token has already reached its maximum TTL. + + Renew a token (this uses the /auth/token/renew endpoint and permission): + + $ vault token renew 96ddf4bc-d217-f3ba-f9bd-017055595017 + + Renew the currently authenticated token (this uses the /auth/token/renew-self + endpoint and permission): + + $ vault token renew + + Renew a token requesting a specific increment value: + + $ vault token renew -increment=30m 96ddf4bc-d217-f3ba-f9bd-017055595017 + + For a full list of examples, please see the documentation. + +` + c.Flags().Help() + + return strings.TrimSpace(helpText) +} + +func (c *TokenRenewCommand) Flags() *FlagSets { + set := c.flagSet(FlagSetHTTP | FlagSetOutputFormat) + f := set.NewFlagSet("Command Options") + + f.DurationVar(&DurationVar{ + Name: "increment", + Aliases: []string{"i"}, + Target: &c.flagIncrement, + Default: 0, + EnvVar: "", + Completion: complete.PredictAnything, + Usage: "Request a specific increment for renewal. Vault is not required " + + "to honor this request. If not supplied, Vault will use the default " + + "TTL. This is specified as a numeric string with suffix like \"30s\" " + + "or \"5m\".", + }) + + return set +} + +func (c *TokenRenewCommand) AutocompleteArgs() complete.Predictor { + return c.PredictVaultFiles() +} + +func (c *TokenRenewCommand) AutocompleteFlags() complete.Flags { + return c.Flags().Completions() } func (c *TokenRenewCommand) Run(args []string) int { - var format, increment string - flags := c.Meta.FlagSet("token-renew", meta.FlagSetDefault) - flags.StringVar(&format, "format", "table", "") - flags.StringVar(&increment, "increment", "", "") - flags.Usage = func() { c.Ui.Error(c.Help()) } - if err := flags.Parse(args); err != nil { + f := c.Flags() + + if err := f.Parse(args); err != nil { + c.UI.Error(err.Error()) return 1 } - args = flags.Args() - if len(args) > 2 { - flags.Usage() - c.Ui.Error(fmt.Sprintf( - "\ntoken-renew expects at most two arguments")) - return 1 - } + token := "" + increment := c.flagIncrement - var token string - if len(args) > 0 { - token = args[0] - } + args = f.Args() + switch len(args) { + case 0: + // Use the local token + case 1: + token = strings.TrimSpace(args[0]) + case 2: + // TODO: remove in 0.9.0 - backwards compat + c.UI.Warn("Specifying increment as a second argument is deprecated. " + + "Please use -increment instead.") - var inc int - // If both are specified prefer the argument - if len(args) == 2 { - increment = args[1] - } - if increment != "" { - dur, err := parseutil.ParseDurationSecond(increment) + token = strings.TrimSpace(args[0]) + parsed, err := time.ParseDuration(appendDurationSuffix(args[1])) if err != nil { - c.Ui.Error(fmt.Sprintf("Invalid increment: %s", err)) + c.UI.Error(fmt.Sprintf("Invalid increment: %s", err)) return 1 } - - inc = int(dur / time.Second) + increment = parsed + default: + c.UI.Error(fmt.Sprintf("Too many arguments (expected 1, got %d)", len(args))) + return 1 } client, err := c.Client() if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error initializing client: %s", err)) + c.UI.Error(err.Error()) return 2 } - // If the given token is the same as the client's, use renew-self instead - // as this is far more likely to be allowed via policy var secret *api.Secret + inc := truncateToSeconds(increment) if token == "" { secret, err = client.Auth().Token().RenewSelf(inc) } else { secret, err = client.Auth().Token().Renew(token, inc) } if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error renewing token: %s", err)) - return 1 + c.UI.Error(fmt.Sprintf("Error renewing token: %s", err)) + return 2 } - return OutputSecret(c.Ui, format, secret) -} - -func (c *TokenRenewCommand) Synopsis() string { - return "Renew an auth token if there is an associated lease" -} - -func (c *TokenRenewCommand) Help() string { - helpText := ` -Usage: vault token-renew [options] [token] [increment] - - Renew an auth token, extending the amount of time it can be used. If a token - is given to the command, '/auth/token/renew' will be called with the given - token; otherwise, '/auth/token/renew-self' will be called with the client - token. - - This command is similar to "renew", but "renew" is only for leases; this - command is only for tokens. - - An optional increment can be given to request a certain number of seconds to - increment the lease. This request is advisory; Vault may not adhere to it at - all. If a token is being passed in on the command line, the increment can as - well; otherwise it must be passed in via the '-increment' flag. - -General Options: -` + meta.GeneralOptionsUsage() + ` -Token Renew Options: - - -increment=3600 The desired increment. If not supplied, Vault will - use the default TTL. If supplied, it may still be - ignored. This can be submitted as an integer number - of seconds or a string duration (e.g. "72h"). - - -format=table The format for output. By default it is a whitespace- - delimited table. This can also be json or yaml. - -` - return strings.TrimSpace(helpText) + return OutputSecret(c.UI, c.flagFormat, secret) } diff --git a/command/token_renew_test.go b/command/token_renew_test.go index 270ee7ec71..2c36412932 100644 --- a/command/token_renew_test.go +++ b/command/token_renew_test.go @@ -1,177 +1,204 @@ package command import ( + "encoding/json" + "strconv" + "strings" "testing" - "github.com/hashicorp/vault/api" - "github.com/hashicorp/vault/http" - "github.com/hashicorp/vault/meta" - "github.com/hashicorp/vault/vault" "github.com/mitchellh/cli" ) -func TestTokenRenew(t *testing.T) { - core, _, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() +func testTokenRenewCommand(tb testing.TB) (*cli.MockUi, *TokenRenewCommand) { + tb.Helper() - ui := new(cli.MockUi) - c := &TokenRenewCommand{ - Meta: meta.Meta{ - ClientToken: token, - Ui: ui, + ui := cli.NewMockUi() + return ui, &TokenRenewCommand{ + BaseCommand: &BaseCommand{ + UI: ui, }, } - - args := []string{ - "-address", addr, - } - - // Run it once for client - c.Run(args) - - // Create a token - client, err := c.Client() - if err != nil { - t.Fatalf("err: %s", err) - } - resp, err := client.Auth().Token().Create(&api.TokenCreateRequest{ - Lease: "1h", - }) - if err != nil { - t.Fatalf("err: %s", err) - } - - // Renew, passing in the token - args = append(args, resp.Auth.ClientToken) - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } } -func TestTokenRenewWithIncrement(t *testing.T) { - core, _, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() +func TestTokenRenewCommand_Run(t *testing.T) { + t.Parallel() - ui := new(cli.MockUi) - c := &TokenRenewCommand{ - Meta: meta.Meta{ - ClientToken: token, - Ui: ui, + cases := []struct { + name string + args []string + out string + code int + }{ + { + "too_many_args", + []string{"foo", "bar", "baz"}, + "Too many arguments", + 1, + }, + { + "default", + nil, + "", + 0, + }, + { + "increment", + []string{"-increment", "60s"}, + "", + 0, + }, + { + "increment_no_suffix", + []string{"-increment", "60"}, + "", + 0, + }, + { + "format", + []string{"-format", "json"}, + "{", + 0, + }, + { + "format_bad", + []string{"-format", "nope-not-real"}, + "Invalid output format", + 1, }, } - args := []string{ - "-address", addr, - } + t.Run("validations", func(t *testing.T) { + t.Parallel() - // Run it once for client - c.Run(args) + for _, tc := range cases { + tc := tc - // Create a token - client, err := c.Client() - if err != nil { - t.Fatalf("err: %s", err) - } - resp, err := client.Auth().Token().Create(&api.TokenCreateRequest{ - Lease: "1h", + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + // Login with the token so we can renew-self. + token, _ := testTokenAndAccessor(t, client) + client.SetToken(token) + + ui, cmd := testTokenRenewCommand(t) + cmd.client = client + + code := cmd.Run(tc.args) + if code != tc.code { + t.Errorf("expected %d to be %d", code, tc.code) + } + + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, tc.out) { + t.Errorf("expected %q to contain %q", combined, tc.out) + } + }) + } }) - if err != nil { - t.Fatalf("err: %s", err) - } - // Renew, passing in the token - args = append(args, resp.Auth.ClientToken) - args = append(args, "72h") - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } -} + t.Run("token", func(t *testing.T) { + t.Parallel() -func TestTokenRenewSelf(t *testing.T) { - core, _, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() + client, closer := testVaultServer(t) + defer closer() - ui := new(cli.MockUi) - c := &TokenRenewCommand{ - Meta: meta.Meta{ - ClientToken: token, - Ui: ui, - }, - } + token, _ := testTokenAndAccessor(t, client) - args := []string{ - "-address", addr, - } + _, cmd := testTokenRenewCommand(t) + cmd.client = client - // Run it once for client - c.Run(args) + code := cmd.Run([]string{ + "-increment", "30m", + token, + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } - // Create a token - client, err := c.Client() - if err != nil { - t.Fatalf("err: %s", err) - } - resp, err := client.Auth().Token().Create(&api.TokenCreateRequest{ - Lease: "1h", + secret, err := client.Auth().Token().Lookup(token) + if err != nil { + t.Fatal(err) + } + + str := string(secret.Data["ttl"].(json.Number)) + ttl, err := strconv.ParseInt(str, 10, 64) + if err != nil { + t.Fatalf("bad ttl: %#v", secret.Data["ttl"]) + } + if exp := int64(1800); ttl > exp { + t.Errorf("expected %d to be <= to %d", ttl, exp) + } }) - if err != nil { - t.Fatalf("err: %s", err) - } - if resp.Auth.ClientToken == "" { - t.Fatal("returned client token is empty") - } - c.Meta.ClientToken = resp.Auth.ClientToken + t.Run("self", func(t *testing.T) { + t.Parallel() - // Renew using the self endpoint - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } -} + client, closer := testVaultServer(t) + defer closer() -func TestTokenRenewSelfWithIncrement(t *testing.T) { - core, _, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() + token, _ := testTokenAndAccessor(t, client) - ui := new(cli.MockUi) - c := &TokenRenewCommand{ - Meta: meta.Meta{ - ClientToken: token, - Ui: ui, - }, - } + // Get the old token and login as the new token. We need the old token + // to query after the lookup, but we need the new token on the client. + oldToken := client.Token() + client.SetToken(token) - args := []string{ - "-address", addr, - } + _, cmd := testTokenRenewCommand(t) + cmd.client = client - // Run it once for client - c.Run(args) + code := cmd.Run([]string{ + "-increment", "30m", + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } - // Create a token - client, err := c.Client() - if err != nil { - t.Fatalf("err: %s", err) - } - resp, err := client.Auth().Token().Create(&api.TokenCreateRequest{ - Lease: "1h", + client.SetToken(oldToken) + secret, err := client.Auth().Token().Lookup(token) + if err != nil { + t.Fatal(err) + } + + str := string(secret.Data["ttl"].(json.Number)) + ttl, err := strconv.ParseInt(str, 10, 64) + if err != nil { + t.Fatalf("bad ttl: %#v", secret.Data["ttl"]) + } + if exp := int64(1800); ttl > exp { + t.Errorf("expected %d to be <= to %d", ttl, exp) + } }) - if err != nil { - t.Fatalf("err: %s", err) - } - if resp.Auth.ClientToken == "" { - t.Fatal("returned client token is empty") - } - c.Meta.ClientToken = resp.Auth.ClientToken + t.Run("communication_failure", func(t *testing.T) { + t.Parallel() - args = append(args, "-increment=72h") - // Renew using the self endpoint - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } + client, closer := testVaultServerBad(t) + defer closer() + + ui, cmd := testTokenRenewCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "foo/bar", + }) + if exp := 2; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Error renewing token: " + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + }) + + t.Run("no_tabs", func(t *testing.T) { + t.Parallel() + + _, cmd := testTokenRenewCommand(t) + assertNoTabs(t, cmd) + }) } diff --git a/command/token_revoke.go b/command/token_revoke.go index 6e4105d0f2..351ce639c3 100644 --- a/command/token_revoke.go +++ b/command/token_revoke.go @@ -4,135 +4,171 @@ import ( "fmt" "strings" - "github.com/hashicorp/vault/meta" + "github.com/mitchellh/cli" + "github.com/posener/complete" ) -// TokenRevokeCommand is a Command that mounts a new mount. +var _ cli.Command = (*TokenRevokeCommand)(nil) +var _ cli.CommandAutocomplete = (*TokenRevokeCommand)(nil) + type TokenRevokeCommand struct { - meta.Meta + *BaseCommand + + flagAccessor bool + flagSelf bool + flagMode string +} + +func (c *TokenRevokeCommand) Synopsis() string { + return "Revoke a token and its children" +} + +func (c *TokenRevokeCommand) Help() string { + helpText := ` +Usage: vault token revoke [options] [TOKEN | ACCESSOR] + + Revokes authentication tokens and their children. If a TOKEN is not provided, + the locally authenticated token is used. The "-mode" flag can be used to + control the behavior of the revocation. See the "-mode" flag documentation + for more information. + + Revoke a token and all the token's children: + + $ vault token revoke 96ddf4bc-d217-f3ba-f9bd-017055595017 + + Revoke a token leaving the token's children: + + $ vault token revoke -mode=orphan 96ddf4bc-d217-f3ba-f9bd-017055595017 + + Revoke a token by accessor: + + $ vault token revoke -accessor 9793c9b3-e04a-46f3-e7b8-748d7da248da + + For a full list of examples, please see the documentation. + +` + c.Flags().Help() + + return strings.TrimSpace(helpText) +} + +func (c *TokenRevokeCommand) Flags() *FlagSets { + set := c.flagSet(FlagSetHTTP) + + f := set.NewFlagSet("Command Options") + + f.BoolVar(&BoolVar{ + Name: "accessor", + Target: &c.flagAccessor, + Default: false, + EnvVar: "", + Completion: complete.PredictNothing, + Usage: "Treat the argument as an accessor instead of a token.", + }) + + f.BoolVar(&BoolVar{ + Name: "self", + Target: &c.flagSelf, + Default: false, + EnvVar: "", + Completion: complete.PredictNothing, + Usage: "Perform the revocation on the currently authenticated token.", + }) + + f.StringVar(&StringVar{ + Name: "mode", + Target: &c.flagMode, + Default: "", + EnvVar: "", + Completion: complete.PredictSet("orphan", "path"), + Usage: "Type of revocation to perform. If unspecified, Vault will revoke " + + "the token and all of the token's children. If \"orphan\", Vault will " + + "revoke only the token, leaving the children as orphans. If \"path\", " + + "tokens created from the given authentication path prefix are deleted " + + "along with their children.", + }) + + return set +} + +func (c *TokenRevokeCommand) AutocompleteArgs() complete.Predictor { + return nil +} + +func (c *TokenRevokeCommand) AutocompleteFlags() complete.Flags { + return c.Flags().Completions() } func (c *TokenRevokeCommand) Run(args []string) int { - var mode string - var accessor bool - var self bool - var token string - flags := c.Meta.FlagSet("token-revoke", meta.FlagSetDefault) - flags.BoolVar(&accessor, "accessor", false, "") - flags.BoolVar(&self, "self", false, "") - flags.StringVar(&mode, "mode", "", "") - flags.Usage = func() { c.Ui.Error(c.Help()) } - if err := flags.Parse(args); err != nil { + f := c.Flags() + + if err := f.Parse(args); err != nil { + c.UI.Error(err.Error()) return 1 } - args = flags.Args() - switch { - case len(args) == 1 && !self: - token = args[0] - case len(args) != 0 && self: - flags.Usage() - c.Ui.Error(fmt.Sprintf( - "\ntoken-revoke expects no arguments when revoking self")) + args = f.Args() + token := "" + if len(args) > 0 { + token = strings.TrimSpace(args[0]) + } + + switch c.flagMode { + case "", "orphan", "path": + default: + c.UI.Error(fmt.Sprintf("Invalid mode: %s", c.flagMode)) return 1 - case len(args) != 1 && !self: - flags.Usage() - c.Ui.Error(fmt.Sprintf( - "\ntoken-revoke expects one argument or the 'self' flag")) + } + + switch { + case c.flagSelf && len(args) > 0: + c.UI.Error(fmt.Sprintf("Too many arguments with -self (expected 0, got %d)", len(args))) + return 1 + case !c.flagSelf && len(args) > 1: + c.UI.Error(fmt.Sprintf("Too many arguments (expected 1 or -self, got %d)", len(args))) + return 1 + case !c.flagSelf && len(args) < 1: + c.UI.Error(fmt.Sprintf("Not enough arguments (expected 1 or -self, got %d)", len(args))) + return 1 + case c.flagSelf && c.flagAccessor: + c.UI.Error("Cannot use -self with -accessor!") + return 1 + case c.flagSelf && c.flagMode != "": + c.UI.Error("Cannot use -self with -mode!") + return 1 + case c.flagAccessor && c.flagMode == "orphan": + c.UI.Error("Cannot use -accessor with -mode=orphan!") + return 1 + case c.flagAccessor && c.flagMode == "path": + c.UI.Error("Cannot use -accessor with -mode=path!") return 1 } client, err := c.Client() if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error initializing client: %s", err)) + c.UI.Error(err.Error()) return 2 } - var fn func(string) error + var revokeFn func(string) error // Handle all 6 possible combinations switch { - case !accessor && self && mode == "": - fn = client.Auth().Token().RevokeSelf - case !accessor && !self && mode == "": - fn = client.Auth().Token().RevokeTree - case !accessor && !self && mode == "orphan": - fn = client.Auth().Token().RevokeOrphan - case !accessor && !self && mode == "path": - fn = client.Sys().RevokePrefix - case accessor && !self && mode == "": - fn = client.Auth().Token().RevokeAccessor - case accessor && self: - c.Ui.Error("token-revoke cannot be run on self when 'accessor' flag is set") - return 1 - case self && mode != "": - c.Ui.Error("token-revoke cannot be run on self when 'mode' flag is set") - return 1 - case accessor && mode == "orphan": - c.Ui.Error("token-revoke cannot be run for 'orphan' mode when 'accessor' flag is set") - return 1 - case accessor && mode == "path": - c.Ui.Error("token-revoke cannot be run for 'path' mode when 'accessor' flag is set") - return 1 + case !c.flagAccessor && c.flagSelf && c.flagMode == "": + revokeFn = client.Auth().Token().RevokeSelf + case !c.flagAccessor && !c.flagSelf && c.flagMode == "": + revokeFn = client.Auth().Token().RevokeTree + case !c.flagAccessor && !c.flagSelf && c.flagMode == "orphan": + revokeFn = client.Auth().Token().RevokeOrphan + case !c.flagAccessor && !c.flagSelf && c.flagMode == "path": + revokeFn = client.Sys().RevokePrefix + case c.flagAccessor && !c.flagSelf && c.flagMode == "": + revokeFn = client.Auth().Token().RevokeAccessor } - if err := fn(token); err != nil { - c.Ui.Error(fmt.Sprintf( - "Error revoking token: %s", err)) + if err := revokeFn(token); err != nil { + c.UI.Error(fmt.Sprintf("Error revoking token: %s", err)) return 2 } - c.Ui.Output("Success! Token revoked if it existed.") + c.UI.Output("Success! Revoked token (if it existed)") return 0 } - -func (c *TokenRevokeCommand) Synopsis() string { - return "Revoke one or more auth tokens" -} - -func (c *TokenRevokeCommand) Help() string { - helpText := ` -Usage: vault token-revoke [options] [token|accessor] - - Revoke one or more auth tokens. - - This command revokes auth tokens. Use the "revoke" command for - revoking secrets. - - Depending on the flags used, auth tokens can be revoked in multiple ways - depending on the "-mode" flag: - - * Without any value, the token specified and all of its children - will be revoked. - - * With the "orphan" value, only the specific token will be revoked. - All of its children will be orphaned. - - * With the "path" value, tokens created from the given auth path - prefix will be deleted, along with all their children. In this case - the "token" arg above is actually a "path". This mode does *not* - work with token values or parts of token values. - - Token can be revoked using the token accessor. This can be done by - setting the '-accessor' flag. Note that when '-accessor' flag is set, - '-mode' should not be set for 'orphan' or 'path'. This is because, - a token accessor always revokes the token along with its child tokens. - -General Options: -` + meta.GeneralOptionsUsage() + ` -Token Options: - - -accessor A boolean flag, if set, treats the argument as an accessor of the token. - Note that accessor can also be used for looking up the token properties - via '/auth/token/lookup-accessor/' endpoint. - Accessor is used when there is no access to token ID. - - -self A boolean flag, if set, the operation is performed on the currently - authenticated token i.e. lookup-self. - - -mode=value The type of revocation to do. See the documentation - above for more information. - -` - return strings.TrimSpace(helpText) -} diff --git a/command/token_revoke_test.go b/command/token_revoke_test.go index 7265a106ee..60a1345c4b 100644 --- a/command/token_revoke_test.go +++ b/command/token_revoke_test.go @@ -1,102 +1,222 @@ package command import ( + "strings" "testing" - "github.com/hashicorp/vault/http" - "github.com/hashicorp/vault/meta" - "github.com/hashicorp/vault/vault" "github.com/mitchellh/cli" ) -func TestTokenRevokeAccessor(t *testing.T) { - core, _, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() +func testTokenRevokeCommand(tb testing.TB) (*cli.MockUi, *TokenRevokeCommand) { + tb.Helper() - ui := new(cli.MockUi) - c := &TokenRevokeCommand{ - Meta: meta.Meta{ - ClientToken: token, - Ui: ui, + ui := cli.NewMockUi() + return ui, &TokenRevokeCommand{ + BaseCommand: &BaseCommand{ + UI: ui, }, } - - args := []string{ - "-address", addr, - } - - // Run it once for client - c.Run(args) - - // Create a token - client, err := c.Client() - if err != nil { - t.Fatalf("err: %s", err) - } - resp, err := client.Auth().Token().Create(nil) - if err != nil { - t.Fatalf("err: %s", err) - } - - // Treat the argument as accessor - args = append(args, "-accessor") - if code := c.Run(args); code == 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } - - // Verify it worked with proper accessor - args1 := append(args, resp.Auth.Accessor) - if code := c.Run(args1); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } - - // Fail if mode is set to 'orphan' when accessor is set - args2 := append(args, "-mode=\"orphan\"") - if code := c.Run(args2); code == 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } - - // Fail if mode is set to 'path' when accessor is set - args3 := append(args, "-mode=\"path\"") - if code := c.Run(args3); code == 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } } -func TestTokenRevoke(t *testing.T) { - core, _, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() +func TestTokenRevokeCommand_Run(t *testing.T) { + t.Parallel() - ui := new(cli.MockUi) - c := &TokenRevokeCommand{ - Meta: meta.Meta{ - ClientToken: token, - Ui: ui, + validations := []struct { + name string + args []string + out string + code int + }{ + { + "bad_mode", + []string{"-mode=banana"}, + "Invalid mode", + 1, + }, + { + "empty", + nil, + "Not enough arguments", + 1, + }, + { + "args_with_self", + []string{"-self", "abcd1234"}, + "Too many arguments", + 1, + }, + { + "too_many_args", + []string{"abcd1234", "efgh5678"}, + "Too many arguments", + 1, + }, + { + "self_and_accessor", + []string{"-self", "-accessor"}, + "Cannot use -self with -accessor", + 1, + }, + { + "self_and_mode", + []string{"-self", "-mode=orphan"}, + "Cannot use -self with -mode", + 1, + }, + { + "accessor_and_mode_orphan", + []string{"-accessor", "-mode=orphan", "abcd1234"}, + "Cannot use -accessor with -mode=orphan", + 1, + }, + { + "accessor_and_mode_path", + []string{"-accessor", "-mode=path", "abcd1234"}, + "Cannot use -accessor with -mode=path", + 1, }, } - args := []string{ - "-address", addr, - } + t.Run("validations", func(t *testing.T) { + t.Parallel() - // Run it once for client - c.Run(args) + for _, tc := range validations { + tc := tc - // Create a token - client, err := c.Client() - if err != nil { - t.Fatalf("err: %s", err) - } - resp, err := client.Auth().Token().Create(nil) - if err != nil { - t.Fatalf("err: %s", err) - } + t.Run(tc.name, func(t *testing.T) { + t.Parallel() - // Verify it worked - args = append(args, resp.Auth.ClientToken) - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } + ui, cmd := testTokenRevokeCommand(t) + + code := cmd.Run(tc.args) + if code != tc.code { + t.Errorf("expected %d to be %d", code, tc.code) + } + + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, tc.out) { + t.Errorf("expected %q to contain %q", combined, tc.out) + } + }) + } + }) + + t.Run("token", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + token, _ := testTokenAndAccessor(t, client) + + ui, cmd := testTokenRevokeCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + token, + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Success! Revoked token" + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + + secret, err := client.Auth().Token().Lookup(token) + if secret != nil || err == nil { + t.Errorf("expected token to be revoked: %#v", secret) + } + }) + + t.Run("self", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + ui, cmd := testTokenRevokeCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "-self", + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Success! Revoked token" + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + + secret, err := client.Auth().Token().LookupSelf() + if secret != nil || err == nil { + t.Errorf("expected token to be revoked: %#v", secret) + } + }) + + t.Run("accessor", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + token, accessor := testTokenAndAccessor(t, client) + + ui, cmd := testTokenRevokeCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "-accessor", + accessor, + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Success! Revoked token" + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + + secret, err := client.Auth().Token().Lookup(token) + if secret != nil || err == nil { + t.Errorf("expected token to be revoked: %#v", secret) + } + }) + + t.Run("communication_failure", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServerBad(t) + defer closer() + + ui, cmd := testTokenRevokeCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "abcd1234", + }) + if exp := 2; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Error revoking token: " + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + }) + + t.Run("no_tabs", func(t *testing.T) { + t.Parallel() + + _, cmd := testTokenRevokeCommand(t) + assertNoTabs(t, cmd) + }) } diff --git a/command/unmount.go b/command/unmount.go deleted file mode 100644 index b04e532a39..0000000000 --- a/command/unmount.go +++ /dev/null @@ -1,67 +0,0 @@ -package command - -import ( - "fmt" - "strings" - - "github.com/hashicorp/vault/meta" -) - -// UnmountCommand is a Command that mounts a new mount. -type UnmountCommand struct { - meta.Meta -} - -func (c *UnmountCommand) Run(args []string) int { - flags := c.Meta.FlagSet("mount", meta.FlagSetDefault) - flags.Usage = func() { c.Ui.Error(c.Help()) } - if err := flags.Parse(args); err != nil { - return 1 - } - - args = flags.Args() - if len(args) != 1 { - flags.Usage() - c.Ui.Error(fmt.Sprintf( - "\nunmount expects one argument: the path to unmount")) - return 1 - } - - path := args[0] - - client, err := c.Client() - if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error initializing client: %s", err)) - return 2 - } - - if err := client.Sys().Unmount(path); err != nil { - c.Ui.Error(fmt.Sprintf( - "Unmount error: %s", err)) - return 2 - } - - c.Ui.Output(fmt.Sprintf( - "Successfully unmounted '%s' if it was mounted", path)) - - return 0 -} - -func (c *UnmountCommand) Synopsis() string { - return "Unmount a secret backend" -} - -func (c *UnmountCommand) Help() string { - helpText := ` -Usage: vault unmount [options] path - - Unmount a secret backend. - - This command unmounts a secret backend. All the secrets created - by this backend will be revoked and its Vault data will be deleted. - -General Options: -` + meta.GeneralOptionsUsage() - return strings.TrimSpace(helpText) -} diff --git a/command/unmount_test.go b/command/unmount_test.go deleted file mode 100644 index 1af5ef8b0f..0000000000 --- a/command/unmount_test.go +++ /dev/null @@ -1,47 +0,0 @@ -package command - -import ( - "testing" - - "github.com/hashicorp/vault/http" - "github.com/hashicorp/vault/meta" - "github.com/hashicorp/vault/vault" - "github.com/mitchellh/cli" -) - -func TestUnmount(t *testing.T) { - core, _, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() - - ui := new(cli.MockUi) - c := &UnmountCommand{ - Meta: meta.Meta{ - ClientToken: token, - Ui: ui, - }, - } - - args := []string{ - "-address", addr, - "secret", - } - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } - - client, err := c.Client() - if err != nil { - t.Fatalf("err: %s", err) - } - - mounts, err := client.Sys().ListMounts() - if err != nil { - t.Fatalf("err: %s", err) - } - - _, ok := mounts["secret/"] - if ok { - t.Fatal("should not have mount") - } -} diff --git a/command/unseal.go b/command/unseal.go deleted file mode 100644 index 2dfb9476de..0000000000 --- a/command/unseal.go +++ /dev/null @@ -1,128 +0,0 @@ -package command - -import ( - "fmt" - "os" - "strings" - - "github.com/hashicorp/vault/helper/password" - "github.com/hashicorp/vault/meta" -) - -// UnsealCommand is a Command that unseals the vault. -type UnsealCommand struct { - meta.Meta - - // Key can be used to pre-seed the key. If it is set, it will not - // be asked with the `password` helper. - Key string -} - -func (c *UnsealCommand) Run(args []string) int { - var reset bool - flags := c.Meta.FlagSet("unseal", meta.FlagSetDefault) - flags.BoolVar(&reset, "reset", false, "") - flags.Usage = func() { c.Ui.Error(c.Help()) } - if err := flags.Parse(args); err != nil { - return 1 - } - - client, err := c.Client() - if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error initializing client: %s", err)) - return 2 - } - - sealStatus, err := client.Sys().SealStatus() - if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error checking seal status: %s", err)) - return 2 - } - - if !sealStatus.Sealed { - c.Ui.Output("Vault is already unsealed.") - return 0 - } - - args = flags.Args() - if reset { - sealStatus, err = client.Sys().ResetUnsealProcess() - } else { - value := c.Key - if len(args) > 0 { - value = args[0] - } - if value == "" { - fmt.Printf("Key (will be hidden): ") - value, err = password.Read(os.Stdin) - fmt.Printf("\n") - if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error attempting to ask for password. The raw error message\n"+ - "is shown below, but the most common reason for this error is\n"+ - "that you attempted to pipe a value into unseal or you're\n"+ - "executing `vault unseal` from outside of a terminal.\n\n"+ - "You should use `vault unseal` from a terminal for maximum\n"+ - "security. If this isn't an option, the unseal key can be passed\n"+ - "in using the first parameter.\n\n"+ - "Raw error: %s", err)) - return 1 - } - } - sealStatus, err = client.Sys().Unseal(strings.TrimSpace(value)) - } - - if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error: %s", err)) - return 1 - } - - c.Ui.Output(fmt.Sprintf( - "Sealed: %v\n"+ - "Key Shares: %d\n"+ - "Key Threshold: %d\n"+ - "Unseal Progress: %d\n"+ - "Unseal Nonce: %v", - sealStatus.Sealed, - sealStatus.N, - sealStatus.T, - sealStatus.Progress, - sealStatus.Nonce, - )) - - return 0 -} - -func (c *UnsealCommand) Synopsis() string { - return "Unseals the Vault server" -} - -func (c *UnsealCommand) Help() string { - helpText := ` -Usage: vault unseal [options] [key] - - Unseal the vault by entering a portion of the master key. Once all - portions are entered, the vault will be unsealed. - - Every Vault server initially starts as sealed. It cannot perform any - operation except unsealing until it is sealed. Secrets cannot be accessed - in any way until the vault is unsealed. This command allows you to enter - a portion of the master key to unseal the vault. - - The unseal key can be specified via the command line, but this is - not recommended. The key may then live in your terminal history. This - only exists to assist in scripting. - -General Options: -` + meta.GeneralOptionsUsage() + ` -Unseal Options: - - -reset Reset the unsealing process by throwing away - prior keys in process to unseal the vault. - -` - return strings.TrimSpace(helpText) -} diff --git a/command/unseal_test.go b/command/unseal_test.go deleted file mode 100644 index 699fdd8fb7..0000000000 --- a/command/unseal_test.go +++ /dev/null @@ -1,72 +0,0 @@ -package command - -import ( - "encoding/hex" - "testing" - - "github.com/hashicorp/vault/http" - "github.com/hashicorp/vault/meta" - "github.com/hashicorp/vault/vault" - "github.com/mitchellh/cli" -) - -func TestUnseal(t *testing.T) { - core := vault.TestCore(t) - keys, _ := vault.TestCoreInit(t, core) - ln, addr := http.TestServer(t, core) - defer ln.Close() - - ui := new(cli.MockUi) - - for _, key := range keys { - c := &UnsealCommand{ - Key: hex.EncodeToString(key), - Meta: meta.Meta{ - Ui: ui, - }, - } - - args := []string{"-address", addr} - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } - } - - sealed, err := core.Sealed() - if err != nil { - t.Fatalf("err: %s", err) - } - if sealed { - t.Fatal("should not be sealed") - } -} - -func TestUnseal_arg(t *testing.T) { - core := vault.TestCore(t) - keys, _ := vault.TestCoreInit(t, core) - ln, addr := http.TestServer(t, core) - defer ln.Close() - - ui := new(cli.MockUi) - - for _, key := range keys { - c := &UnsealCommand{ - Meta: meta.Meta{ - Ui: ui, - }, - } - - args := []string{"-address", addr, hex.EncodeToString(key)} - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } - } - - sealed, err := core.Sealed() - if err != nil { - t.Fatalf("err: %s", err) - } - if sealed { - t.Fatal("should not be sealed") - } -} diff --git a/command/unwrap.go b/command/unwrap.go index 5a21920eb5..8246768383 100644 --- a/command/unwrap.go +++ b/command/unwrap.go @@ -1,79 +1,20 @@ package command import ( - "flag" "fmt" "strings" - "github.com/hashicorp/vault/api" - "github.com/hashicorp/vault/meta" + "github.com/mitchellh/cli" + "github.com/posener/complete" ) -// UnwrapCommand is a Command that behaves like ReadCommand but specifically -// for unwrapping cubbyhole-wrapped secrets +var _ cli.Command = (*UnwrapCommand)(nil) +var _ cli.CommandAutocomplete = (*UnwrapCommand)(nil) + +// UnwrapCommand is a Command that behaves like ReadCommand but specifically for +// unwrapping cubbyhole-wrapped secrets type UnwrapCommand struct { - meta.Meta -} - -func (c *UnwrapCommand) Run(args []string) int { - var format string - var field string - var err error - var secret *api.Secret - var flags *flag.FlagSet - flags = c.Meta.FlagSet("unwrap", meta.FlagSetDefault) - flags.StringVar(&format, "format", "table", "") - flags.StringVar(&field, "field", "", "") - flags.Usage = func() { c.Ui.Error(c.Help()) } - if err := flags.Parse(args); err != nil { - return 1 - } - - var tokenID string - - args = flags.Args() - switch len(args) { - case 0: - case 1: - tokenID = args[0] - default: - c.Ui.Error("unwrap expects zero or one argument (the ID of the wrapping token)") - flags.Usage() - return 1 - } - - client, err := c.Client() - if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error initializing client: %s", err)) - return 2 - } - - secret, err = client.Logical().Unwrap(tokenID) - if err != nil { - c.Ui.Error(err.Error()) - return 1 - } - if secret == nil { - c.Ui.Error("Server gave empty response or secret returned was empty") - return 1 - } - - // Handle single field output - if field != "" { - return PrintRawField(c.Ui, secret, field) - } - - // Check if the original was a list response and format as a list if so - if secret.Data != nil && - len(secret.Data) == 1 && - secret.Data["keys"] != nil { - _, ok := secret.Data["keys"].([]interface{}) - if ok { - return OutputList(c.Ui, format, secret) - } - } - return OutputSecret(c.Ui, format, secret) + *BaseCommand } func (c *UnwrapCommand) Synopsis() string { @@ -82,23 +23,84 @@ func (c *UnwrapCommand) Synopsis() string { func (c *UnwrapCommand) Help() string { helpText := ` -Usage: vault unwrap [options] +Usage: vault unwrap [options] [TOKEN] - Unwrap a wrapped secret. + Unwraps a wrapped secret from Vault by the given token. The result is the + same as the "vault read" operation on the non-wrapped secret. If no token + is given, the data in the currently authenticated token is unwrapped. - Unwraps the data wrapped by the given token ID. The returned result is the - same as a 'read' operation on a non-wrapped secret. + Unwrap the data in the cubbyhole secrets engine for a token: -General Options: -` + meta.GeneralOptionsUsage() + ` -Read Options: + $ vault unwrap 3de9ece1-b347-e143-29b0-dc2dc31caafd - -format=table The format for output. By default it is a whitespace- - delimited table. This can also be json or yaml. + Unwrap the data in the active token: - -field=field If included, the raw value of the specified field - will be output raw to stdout. + $ vault login 848f9ccf-7176-098c-5e2b-75a0689d41cd + $ vault unwrap # unwraps 848f9ccf... + + For a full list of examples and paths, please see the online documentation. + +` + c.Flags().Help() -` return strings.TrimSpace(helpText) } + +func (c *UnwrapCommand) Flags() *FlagSets { + return c.flagSet(FlagSetHTTP | FlagSetOutputField | FlagSetOutputFormat) +} + +func (c *UnwrapCommand) AutocompleteArgs() complete.Predictor { + return c.PredictVaultFiles() +} + +func (c *UnwrapCommand) AutocompleteFlags() complete.Flags { + return c.Flags().Completions() +} + +func (c *UnwrapCommand) Run(args []string) int { + f := c.Flags() + + if err := f.Parse(args); err != nil { + c.UI.Error(err.Error()) + return 1 + } + + args = f.Args() + token := "" + switch len(args) { + case 0: + // Leave token as "", that will use the local token + case 1: + token = strings.TrimSpace(args[0]) + default: + c.UI.Error(fmt.Sprintf("Too many arguments (expected 0-1, got %d)", len(args))) + return 1 + } + + client, err := c.Client() + if err != nil { + c.UI.Error(err.Error()) + return 2 + } + + secret, err := client.Logical().Unwrap(token) + if err != nil { + c.UI.Error(fmt.Sprintf("Error unwrapping: %s", err)) + return 2 + } + if secret == nil { + c.UI.Error("Could not find wrapped response") + return 2 + } + + // Handle single field output + if c.flagField != "" { + return PrintRawField(c.UI, secret, c.flagField) + } + + // Check if the original was a list response and format as a list + if _, ok := extractListData(secret); ok { + return OutputList(c.UI, c.flagFormat, secret) + } + return OutputSecret(c.UI, c.flagFormat, secret) +} diff --git a/command/unwrap_test.go b/command/unwrap_test.go index e5dc0bfd33..2d90d47d42 100644 --- a/command/unwrap_test.go +++ b/command/unwrap_test.go @@ -4,104 +4,169 @@ import ( "strings" "testing" - "github.com/hashicorp/vault/http" - "github.com/hashicorp/vault/meta" - "github.com/hashicorp/vault/vault" + "github.com/hashicorp/vault/api" "github.com/mitchellh/cli" ) -func TestUnwrap(t *testing.T) { - core, _, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() +func testUnwrapCommand(tb testing.TB) (*cli.MockUi, *UnwrapCommand) { + tb.Helper() - ui := new(cli.MockUi) - c := &UnwrapCommand{ - Meta: meta.Meta{ - ClientToken: token, - Ui: ui, + ui := cli.NewMockUi() + return ui, &UnwrapCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + } +} + +func testUnwrapWrappedToken(tb testing.TB, client *api.Client, data map[string]interface{}) string { + tb.Helper() + + wrapped, err := client.Logical().Write("sys/wrapping/wrap", data) + if err != nil { + tb.Fatal(err) + } + if wrapped == nil || wrapped.WrapInfo == nil || wrapped.WrapInfo.Token == "" { + tb.Fatalf("missing wrap info: %v", wrapped) + } + return wrapped.WrapInfo.Token +} + +func TestUnwrapCommand_Run(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + args []string + out string + code int + }{ + { + "too_many_args", + []string{"foo", "bar"}, + "Too many arguments", + 1, + }, + { + "default", + nil, // Token comes in the test func + "bar", + 0, + }, + { + "field", + []string{"-field", "foo"}, + "bar", + 0, + }, + { + "field_not_found", + []string{"-field", "not-a-real-field"}, + "not present in secret", + 1, + }, + { + "format", + []string{"-format", "json"}, + "{", + 0, + }, + { + "format_bad", + []string{"-format", "nope-not-real"}, + "Invalid output format", + 1, }, } - args := []string{ - "-address", addr, - "-field", "zip", - } + t.Run("validations", func(t *testing.T) { + t.Parallel() - // Run once so the client is setup, ignore errors - c.Run(args) + for _, tc := range cases { + tc := tc - // Get the client so we can write data - client, err := c.Client() - if err != nil { - t.Fatalf("err: %s", err) - } + t.Run(tc.name, func(t *testing.T) { + t.Parallel() - wrapLookupFunc := func(method, path string) string { - if method == "GET" && path == "secret/foo" { - return "60s" + client, closer := testVaultServer(t) + defer closer() + + wrappedToken := testUnwrapWrappedToken(t, client, map[string]interface{}{ + "foo": "bar", + }) + + ui, cmd := testUnwrapCommand(t) + cmd.client = client + + tc.args = append(tc.args, wrappedToken) + code := cmd.Run(tc.args) + if code != tc.code { + t.Errorf("expected %d to be %d", code, tc.code) + } + + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, tc.out) { + t.Errorf("expected %q to contain %q", combined, tc.out) + } + }) } - if method == "LIST" && path == "secret" { - return "60s" + }) + + t.Run("communication_failure", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServerBad(t) + defer closer() + + ui, cmd := testUnwrapCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "foo", + }) + if exp := 2; code != exp { + t.Errorf("expected %d to be %d", code, exp) } - return "" - } - client.SetWrappingLookupFunc(wrapLookupFunc) - data := map[string]interface{}{"zip": "zap"} - if _, err := client.Logical().Write("secret/foo", data); err != nil { - t.Fatalf("err: %s", err) - } + expected := "Error unwrapping: " + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + }) - outer, err := client.Logical().Read("secret/foo") - if err != nil { - t.Fatalf("err: %s", err) - } - if outer == nil { - t.Fatal("outer response was nil") - } - if outer.WrapInfo == nil { - t.Fatalf("outer wrapinfo was nil, response was %#v", *outer) - } + // This test needs its own client and server because it modifies the client + // to the wrapping token + t.Run("local_token", func(t *testing.T) { + t.Parallel() - args = append(args, outer.WrapInfo.Token) + client, closer := testVaultServer(t) + defer closer() - // Run the read - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } + wrappedToken := testUnwrapWrappedToken(t, client, map[string]interface{}{ + "foo": "bar", + }) - output := ui.OutputWriter.String() - if output != "zap\n" { - t.Fatalf("unexpectd output:\n%s", output) - } + ui, cmd := testUnwrapCommand(t) + cmd.client = client + cmd.client.SetToken(wrappedToken) - // Now test with list handling, specifically that it will be called with - // the list output formatter - ui.OutputWriter.Reset() + // Intentionally don't pass the token here - it shoudl use the local token + code := cmd.Run([]string{}) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } - outer, err = client.Logical().List("secret") - if err != nil { - t.Fatalf("err: %s", err) - } - if outer == nil { - t.Fatal("outer response was nil") - } - if outer.WrapInfo == nil { - t.Fatalf("outer wrapinfo was nil, response was %#v", *outer) - } + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, "bar") { + t.Errorf("expected %q to contain %q", combined, "bar") + } + }) - args = []string{ - "-address", addr, - outer.WrapInfo.Token, - } - // Run the read - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } + t.Run("no_tabs", func(t *testing.T) { + t.Parallel() - output = ui.OutputWriter.String() - if strings.TrimSpace(output) != "Keys\n----\nfoo" { - t.Fatalf("unexpected output:\n%s", output) - } + _, cmd := testUnwrapCommand(t) + assertNoTabs(t, cmd) + }) } diff --git a/command/util.go b/command/util.go index e8d46152c9..a51999205f 100644 --- a/command/util.go +++ b/command/util.go @@ -2,10 +2,12 @@ package command import ( "fmt" + "io" "os" - "reflect" "time" + "golang.org/x/crypto/ssh/terminal" + "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/command/token" "github.com/mitchellh/cli" @@ -30,7 +32,9 @@ func DefaultTokenHelper() (token.TokenHelper, error) { return &token.ExternalTokenHelper{BinaryPath: path}, nil } -func PrintRawField(ui cli.Ui, secret *api.Secret, field string) int { +// RawField extracts the raw field from the given data and returns it as a +// string for printing purposes. +func RawField(secret *api.Secret, field string) (string, bool) { var val interface{} switch { case secret.Auth != nil: @@ -76,21 +80,50 @@ func PrintRawField(ui cli.Ui, secret *api.Secret, field string) int { } } - if val != nil { - // c.Ui.Output() prints a CR character which in this case is - // not desired. Since Vault CLI currently only uses BasicUi, - // which writes to standard output, os.Stdout is used here to - // directly print the message. If mitchellh/cli exposes method - // to print without CR, this check needs to be removed. - if reflect.TypeOf(ui).String() == "*cli.BasicUi" { - fmt.Fprintf(os.Stdout, "%v", val) - } else { - ui.Output(fmt.Sprintf("%v", val)) - } - return 0 - } else { - ui.Error(fmt.Sprintf( - "Field %s not present in secret", field)) + str := fmt.Sprintf("%v", val) + return str, val != nil +} + +// PrintRawField prints raw field from the secret. +func PrintRawField(ui cli.Ui, secret *api.Secret, field string) int { + str, ok := RawField(secret, field) + if !ok { + ui.Error(fmt.Sprintf("Field %q not present in secret", field)) return 1 } + + return PrintRaw(ui, str) +} + +// PrintRaw prints a raw value to the terminal. If the process is being "piped" +// to something else, the "raw" value is printed without a newline character. +// Otherwise the value is printed as normal. +func PrintRaw(ui cli.Ui, str string) int { + if terminal.IsTerminal(int(os.Stdout.Fd())) { + ui.Output(str) + } else { + // The cli.Ui prints a CR, which is not wanted since the user probably wants + // just the raw value. + w := getWriterFromUI(ui) + fmt.Fprintf(w, str) + } + return 0 +} + +// getWriterFromUI accepts a cli.Ui and returns the underlying io.Writer by +// unwrapping as many wrapped Uis as necessary. If there is an unknown UI +// type, this falls back to os.Stdout. +func getWriterFromUI(ui cli.Ui) io.Writer { + switch t := ui.(type) { + case *cli.BasicUi: + return t.Writer + case *cli.ColoredUi: + return getWriterFromUI(t.Ui) + case *cli.ConcurrentUi: + return getWriterFromUI(t.Ui) + case *cli.MockUi: + return t.OutputWriter + default: + return os.Stdout + } } diff --git a/command/version.go b/command/version.go index 4665436e2e..398036d3a1 100644 --- a/command/version.go +++ b/command/version.go @@ -1,18 +1,54 @@ package command import ( + "strings" + "github.com/hashicorp/vault/version" "github.com/mitchellh/cli" + "github.com/posener/complete" ) +var _ cli.Command = (*VersionCommand)(nil) +var _ cli.CommandAutocomplete = (*VersionCommand)(nil) + // VersionCommand is a Command implementation prints the version. type VersionCommand struct { + *BaseCommand + VersionInfo *version.VersionInfo - Ui cli.Ui +} + +func (c *VersionCommand) Synopsis() string { + return "Prints the Vault CLI version" } func (c *VersionCommand) Help() string { - return "" + helpText := ` +Usage: vault version + + Prints the version of this Vault CLI. This does not print the target Vault + server version. + + Print the version: + + $ vault version + + There are no arguments or flags to this command. Any additional arguments or + flags are ignored. +` + return strings.TrimSpace(helpText) +} + +func (c *VersionCommand) Flags() *FlagSets { + return nil +} + +func (c *VersionCommand) AutocompleteArgs() complete.Predictor { + return nil +} + +func (c *VersionCommand) AutocompleteFlags() complete.Flags { + return nil } func (c *VersionCommand) Run(_ []string) int { @@ -20,10 +56,6 @@ func (c *VersionCommand) Run(_ []string) int { if version.CgoEnabled { out += " (cgo)" } - c.Ui.Output(out) + c.UI.Output(out) return 0 } - -func (c *VersionCommand) Synopsis() string { - return "Prints the Vault version" -} diff --git a/command/version_test.go b/command/version_test.go index 2a645690fa..2f132fed79 100644 --- a/command/version_test.go +++ b/command/version_test.go @@ -1,11 +1,48 @@ package command import ( + "strings" "testing" + "github.com/hashicorp/vault/version" "github.com/mitchellh/cli" ) -func TestVersionCommand_implements(t *testing.T) { - var _ cli.Command = &VersionCommand{} +func testVersionCommand(tb testing.TB) (*cli.MockUi, *VersionCommand) { + tb.Helper() + + ui := cli.NewMockUi() + return ui, &VersionCommand{ + VersionInfo: &version.VersionInfo{}, + BaseCommand: &BaseCommand{ + UI: ui, + }, + } +} + +func TestVersionCommand_Run(t *testing.T) { + t.Parallel() + + t.Run("output", func(t *testing.T) { + t.Parallel() + + ui, cmd := testVersionCommand(t) + code := cmd.Run(nil) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Vault" + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to equal %q", combined, expected) + } + }) + + t.Run("no_tabs", func(t *testing.T) { + t.Parallel() + + _, cmd := testVersionCommand(t) + assertNoTabs(t, cmd) + }) } diff --git a/command/wrapping_test.go b/command/wrapping_test.go deleted file mode 100644 index a380cfc7d9..0000000000 --- a/command/wrapping_test.go +++ /dev/null @@ -1,109 +0,0 @@ -package command - -import ( - "os" - "testing" - - "github.com/hashicorp/vault/api" - "github.com/hashicorp/vault/http" - "github.com/hashicorp/vault/meta" - "github.com/hashicorp/vault/vault" - "github.com/mitchellh/cli" -) - -func TestWrapping_Env(t *testing.T) { - core, _, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() - - ui := new(cli.MockUi) - c := &TokenLookupCommand{ - Meta: meta.Meta{ - ClientToken: token, - Ui: ui, - }, - } - - args := []string{ - "-address", addr, - } - // Run it once for client - c.Run(args) - - // Create a new token for us to use - client, err := c.Client() - if err != nil { - t.Fatalf("err: %s", err) - } - resp, err := client.Auth().Token().Create(&api.TokenCreateRequest{ - Lease: "1h", - }) - if err != nil { - t.Fatalf("err: %s", err) - } - - prevWrapTTLEnv := os.Getenv(api.EnvVaultWrapTTL) - os.Setenv(api.EnvVaultWrapTTL, "5s") - defer func() { - os.Setenv(api.EnvVaultWrapTTL, prevWrapTTLEnv) - }() - - // Now when we do a lookup-self the response should be wrapped - args = append(args, resp.Auth.ClientToken) - - resp, err = client.Auth().Token().LookupSelf() - if err != nil { - t.Fatalf("err: %s", err) - } - if resp == nil { - t.Fatal("nil response") - } - if resp.WrapInfo == nil { - t.Fatal("nil wrap info") - } - if resp.WrapInfo.Token == "" || resp.WrapInfo.TTL != 5 { - t.Fatal("did not get token or ttl wrong") - } -} - -func TestWrapping_Flag(t *testing.T) { - core, _, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() - - ui := new(cli.MockUi) - c := &TokenLookupCommand{ - Meta: meta.Meta{ - ClientToken: token, - Ui: ui, - }, - } - - args := []string{ - "-address", addr, - "-wrap-ttl", "5s", - } - // Run it once for client - c.Run(args) - - // Create a new token for us to use - client, err := c.Client() - if err != nil { - t.Fatalf("err: %s", err) - } - resp, err := client.Auth().Token().Create(&api.TokenCreateRequest{ - Lease: "1h", - }) - if err != nil { - t.Fatalf("err: %s", err) - } - if resp == nil { - t.Fatal("nil response") - } - if resp.WrapInfo == nil { - t.Fatal("nil wrap info") - } - if resp.WrapInfo.Token == "" || resp.WrapInfo.TTL != 5 { - t.Fatal("did not get token or ttl wrong") - } -} diff --git a/command/write.go b/command/write.go index 6f7b495b40..ed25ba6ce7 100644 --- a/command/write.go +++ b/command/write.go @@ -6,149 +6,145 @@ import ( "os" "strings" - "github.com/hashicorp/vault/helper/kv-builder" - "github.com/hashicorp/vault/meta" + "github.com/mitchellh/cli" "github.com/posener/complete" ) +var _ cli.Command = (*WriteCommand)(nil) +var _ cli.CommandAutocomplete = (*WriteCommand)(nil) + // WriteCommand is a Command that puts data into the Vault. type WriteCommand struct { - meta.Meta + *BaseCommand - // The fields below can be overwritten for tests - testStdin io.Reader + flagForce bool + + testStdin io.Reader // for tests +} + +func (c *WriteCommand) Synopsis() string { + return "Write data, configuration, and secrets" +} + +func (c *WriteCommand) Help() string { + helpText := ` +Usage: vault write [options] PATH [DATA K=V...] + + Writes data to Vault at the given path. The data can be credentials, secrets, + configuration, or arbitrary data. The specific behavior of this command is + determined at the thing mounted at the path. + + Data is specified as "key=value" pairs. If the value begins with an "@", then + it is loaded from a file. If the value is "-", Vault will read the value from + stdin. + + Persist data in the generic secrets engine: + + $ vault write secret/my-secret foo=bar + + Create a new encryption key in the transit secrets engine: + + $ vault write -f transit/keys/my-key + + Upload an AWS IAM policy from a file on disk: + + $ vault write aws/roles/ops policy=@policy.json + + Configure access to Consul by providing an access token: + + $ echo $MY_TOKEN | vault write consul/config/access token=- + + For a full list of examples and paths, please see the documentation that + corresponds to the secret engines in use. + +` + c.Flags().Help() + + return strings.TrimSpace(helpText) +} + +func (c *WriteCommand) Flags() *FlagSets { + set := c.flagSet(FlagSetHTTP | FlagSetOutputField | FlagSetOutputFormat) + f := set.NewFlagSet("Command Options") + + f.BoolVar(&BoolVar{ + Name: "force", + Aliases: []string{"f"}, + Target: &c.flagForce, + Default: false, + EnvVar: "", + Completion: complete.PredictNothing, + Usage: "Allow the operation to continue with no key=value pairs. This " + + "allows writing to keys that do not need or expect data.", + }) + + return set +} + +func (c *WriteCommand) AutocompleteArgs() complete.Predictor { + // Return an anything predictor here. Without a way to access help + // information, we don't know what paths we could write to. + return complete.PredictAnything +} + +func (c *WriteCommand) AutocompleteFlags() complete.Flags { + return c.Flags().Completions() } func (c *WriteCommand) Run(args []string) int { - var field, format string - var force bool - flags := c.Meta.FlagSet("write", meta.FlagSetDefault) - flags.StringVar(&format, "format", "table", "") - flags.StringVar(&field, "field", "", "") - flags.BoolVar(&force, "force", false, "") - flags.BoolVar(&force, "f", false, "") - flags.Usage = func() { c.Ui.Error(c.Help()) } - if err := flags.Parse(args); err != nil { + f := c.Flags() + + if err := f.Parse(args); err != nil { + c.UI.Error(err.Error()) return 1 } - args = flags.Args() - if len(args) < 1 { - c.Ui.Error("write requires a path") - flags.Usage() + args = f.Args() + switch { + case len(args) < 1: + c.UI.Error(fmt.Sprintf("Not enough arguments (expected 1, got %d)", len(args))) + return 1 + case len(args) == 1 && !c.flagForce: + c.UI.Error("Must supply data or use -force") return 1 } - if len(args) < 2 && !force { - c.Ui.Error("write expects at least two arguments; use -f to perform the write anyways") - flags.Usage() - return 1 + // Pull our fake stdin if needed + stdin := (io.Reader)(os.Stdin) + if c.testStdin != nil { + stdin = c.testStdin } - path := args[0] - if path[0] == '/' { - path = path[1:] - } + path := sanitizePath(args[0]) - data, err := c.parseData(args[1:]) + data, err := parseArgsData(stdin, args[1:]) if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error loading data: %s", err)) + c.UI.Error(fmt.Sprintf("Failed to parse K=V data: %s", err)) return 1 } client, err := c.Client() if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error initializing client: %s", err)) + c.UI.Error(err.Error()) return 2 } secret, err := client.Logical().Write(path, data) if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error writing data to %s: %s", path, err)) - return 1 + c.UI.Error(fmt.Sprintf("Error writing data to %s: %s", path, err)) + return 2 } - if secret == nil { - // Don't output anything if people aren't using the "human" output - if format == "table" { - c.Ui.Output(fmt.Sprintf("Success! Data written to: %s", path)) + // Don't output anything unless using the "table" format + if c.flagFormat == "table" { + c.UI.Info(fmt.Sprintf("Success! Data written to: %s", path)) } return 0 } // Handle single field output - if field != "" { - return PrintRawField(c.Ui, secret, field) + if c.flagField != "" { + return PrintRawField(c.UI, secret, c.flagField) } - return OutputSecret(c.Ui, format, secret) -} - -func (c *WriteCommand) parseData(args []string) (map[string]interface{}, error) { - var stdin io.Reader = os.Stdin - if c.testStdin != nil { - stdin = c.testStdin - } - - builder := &kvbuilder.Builder{Stdin: stdin} - if err := builder.Add(args...); err != nil { - return nil, err - } - - return builder.Map(), nil -} - -func (c *WriteCommand) Synopsis() string { - return "Write secrets or configuration into Vault" -} - -func (c *WriteCommand) Help() string { - helpText := ` -Usage: vault write [options] path [data] - - Write data (secrets or configuration) into Vault. - - Write sends data into Vault at the given path. The behavior of the write is - determined by the backend at the given path. For example, writing to - "aws/policy/ops" will create an "ops" IAM policy for the AWS backend - (configuration), but writing to "consul/foo" will write a value directly into - Consul at that key. Check the documentation of the logical backend you're - using for more information on key structure. - - Data is sent via additional arguments in "key=value" pairs. If value begins - with an "@", then it is loaded from a file. Write expects data in the file to - be in JSON format. If you want to start the value with a literal "@", then - prefix the "@" with a slash: "\@". - -General Options: -` + meta.GeneralOptionsUsage() + ` -Write Options: - - -f | -force Force the write to continue without any data values - specified. This allows writing to keys that do not - need or expect any fields to be specified. - - -format=table The format for output. By default it is a whitespace- - delimited table. This can also be json or yaml. - - -field=field If included, the raw value of the specified field - will be output raw to stdout. - -` - return strings.TrimSpace(helpText) -} - -func (c *WriteCommand) AutocompleteArgs() complete.Predictor { - return complete.PredictNothing -} - -func (c *WriteCommand) AutocompleteFlags() complete.Flags { - return complete.Flags{ - "-force": complete.PredictNothing, - "-format": predictFormat, - "-field": complete.PredictNothing, - } + return OutputSecret(c.UI, c.flagFormat, secret) } diff --git a/command/write_test.go b/command/write_test.go index 5aa3c1e559..03aab4c79a 100644 --- a/command/write_test.go +++ b/command/write_test.go @@ -2,271 +2,279 @@ package command import ( "io" - "io/ioutil" - "os" "strings" "testing" - "github.com/hashicorp/vault/http" - "github.com/hashicorp/vault/meta" - "github.com/hashicorp/vault/vault" + "github.com/hashicorp/vault/api" "github.com/mitchellh/cli" ) -func TestWrite(t *testing.T) { - core, _, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() +func testWriteCommand(tb testing.TB) (*cli.MockUi, *WriteCommand) { + tb.Helper() - ui := new(cli.MockUi) - c := &WriteCommand{ - Meta: meta.Meta{ - ClientToken: token, - Ui: ui, + ui := cli.NewMockUi() + return ui, &WriteCommand{ + BaseCommand: &BaseCommand{ + UI: ui, }, } - - args := []string{ - "-address", addr, - "secret/foo", - "value=bar", - } - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } - - client, err := c.Client() - if err != nil { - t.Fatalf("err: %s", err) - } - - resp, err := client.Logical().Read("secret/foo") - if err != nil { - t.Fatalf("err: %s", err) - } - - if resp.Data["value"] != "bar" { - t.Fatalf("bad: %#v", resp) - } } -func TestWrite_arbitrary(t *testing.T) { - core, _, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() +func TestWriteCommand_Run(t *testing.T) { + t.Parallel() - stdinR, stdinW := io.Pipe() - ui := new(cli.MockUi) - c := &WriteCommand{ - Meta: meta.Meta{ - ClientToken: token, - Ui: ui, + cases := []struct { + name string + args []string + out string + code int + }{ + { + "not_enough_args", + []string{}, + "Not enough arguments", + 1, }, - - testStdin: stdinR, - } - - go func() { - stdinW.Write([]byte(`{"foo":"bar"}`)) - stdinW.Close() - }() - - args := []string{ - "-address", addr, - "secret/foo", - "-", - } - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } - - client, err := c.Client() - if err != nil { - t.Fatalf("err: %s", err) - } - - resp, err := client.Logical().Read("secret/foo") - if err != nil { - t.Fatalf("err: %s", err) - } - - if resp.Data["foo"] != "bar" { - t.Fatalf("bad: %#v", resp) - } -} - -func TestWrite_escaped(t *testing.T) { - core, _, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() - - ui := new(cli.MockUi) - c := &WriteCommand{ - Meta: meta.Meta{ - ClientToken: token, - Ui: ui, + { + "empty_kvs", + []string{"secret/write/foo"}, + "Must supply data or use -force", + 1, + }, + { + "force_kvs", + []string{"-force", "auth/token/create"}, + "token", + 0, + }, + { + "force_f_kvs", + []string{"-f", "auth/token/create"}, + "token", + 0, + }, + { + "kvs_no_value", + []string{"secret/write/foo", "foo"}, + "Failed to parse K=V data", + 1, + }, + { + "single_value", + []string{"secret/write/foo", "foo=bar"}, + "Success!", + 0, + }, + { + "multi_value", + []string{"secret/write/foo", "foo=bar", "zip=zap"}, + "Success!", + 0, + }, + { + "field", + []string{ + "-field", "token_renewable", + "auth/token/create", "display_name=foo", + }, + "false", + 0, + }, + { + "field_not_found", + []string{ + "-field", "not-a-real-field", + "auth/token/create", "display_name=foo", + }, + "not present in secret", + 1, }, } - args := []string{ - "-address", addr, - "secret/foo", - "value=\\@bar", - } - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + ui, cmd := testWriteCommand(t) + cmd.client = client + + code := cmd.Run(tc.args) + if code != tc.code { + t.Errorf("expected %d to be %d", code, tc.code) + } + + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, tc.out) { + t.Errorf("expected %q to contain %q", combined, tc.out) + } + }) } - client, err := c.Client() - if err != nil { - t.Fatalf("err: %s", err) - } + t.Run("force", func(t *testing.T) { + t.Parallel() - resp, err := client.Logical().Read("secret/foo") - if err != nil { - t.Fatalf("err: %s", err) - } + client, closer := testVaultServer(t) + defer closer() - if resp.Data["value"] != "@bar" { - t.Fatalf("bad: %#v", resp) - } -} - -func TestWrite_file(t *testing.T) { - tf, err := ioutil.TempFile("", "vault") - if err != nil { - t.Fatalf("err: %s", err) - } - tf.Write([]byte(`{"foo":"bar"}`)) - tf.Close() - defer os.Remove(tf.Name()) - - core, _, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() - - ui := new(cli.MockUi) - c := &WriteCommand{ - Meta: meta.Meta{ - ClientToken: token, - Ui: ui, - }, - } - - args := []string{ - "-address", addr, - "secret/foo", - "@" + tf.Name(), - } - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } - - client, err := c.Client() - if err != nil { - t.Fatalf("err: %s", err) - } - - resp, err := client.Logical().Read("secret/foo") - if err != nil { - t.Fatalf("err: %s", err) - } - - if resp.Data["foo"] != "bar" { - t.Fatalf("bad: %#v", resp) - } -} - -func TestWrite_fileValue(t *testing.T) { - tf, err := ioutil.TempFile("", "vault") - if err != nil { - t.Fatalf("err: %s", err) - } - tf.Write([]byte("foo")) - tf.Close() - defer os.Remove(tf.Name()) - - core, _, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() - - ui := new(cli.MockUi) - c := &WriteCommand{ - Meta: meta.Meta{ - ClientToken: token, - Ui: ui, - }, - } - - args := []string{ - "-address", addr, - "secret/foo", - "value=@" + tf.Name(), - } - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } - - client, err := c.Client() - if err != nil { - t.Fatalf("err: %s", err) - } - - resp, err := client.Logical().Read("secret/foo") - if err != nil { - t.Fatalf("err: %s", err) - } - - if resp.Data["value"] != "foo" { - t.Fatalf("bad: %#v", resp) - } -} - -func TestWrite_Output(t *testing.T) { - core, _, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() - - ui := new(cli.MockUi) - c := &WriteCommand{ - Meta: meta.Meta{ - ClientToken: token, - Ui: ui, - }, - } - - args := []string{ - "-address", addr, - "auth/token/create", - "display_name=foo", - } - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } - if !strings.Contains(ui.OutputWriter.String(), "Key") { - t.Fatalf("bad: %s", ui.OutputWriter.String()) - } -} - -func TestWrite_force(t *testing.T) { - core, _, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() - - ui := new(cli.MockUi) - c := &WriteCommand{ - Meta: meta.Meta{ - ClientToken: token, - Ui: ui, - }, - } - - args := []string{ - "-address", addr, - "-force", - "sys/rotate", - } - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } + if err := client.Sys().Mount("transit/", &api.MountInput{ + Type: "transit", + }); err != nil { + t.Fatal(err) + } + + ui, cmd := testWriteCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "-force", + "transit/keys/my-key", + }) + if exp := 0; code != exp { + t.Fatalf("expected %d to be %d: %q", code, exp, ui.ErrorWriter.String()) + } + + secret, err := client.Logical().Read("transit/keys/my-key") + if err != nil { + t.Fatal(err) + } + if secret == nil || secret.Data == nil { + t.Fatal("expected secret to have data") + } + }) + + t.Run("stdin_full", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + stdinR, stdinW := io.Pipe() + go func() { + stdinW.Write([]byte(`{"foo":"bar"}`)) + stdinW.Close() + }() + + _, cmd := testWriteCommand(t) + cmd.client = client + cmd.testStdin = stdinR + + code := cmd.Run([]string{ + "secret/write/stdin_full", "-", + }) + if code != 0 { + t.Fatalf("expected 0 to be %d", code) + } + + secret, err := client.Logical().Read("secret/write/stdin_full") + if err != nil { + t.Fatal(err) + } + if secret == nil || secret.Data == nil { + t.Fatal("expected secret to have data") + } + if exp, act := "bar", secret.Data["foo"].(string); exp != act { + t.Errorf("expected %q to be %q", act, exp) + } + }) + + t.Run("stdin_value", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + stdinR, stdinW := io.Pipe() + go func() { + stdinW.Write([]byte("bar")) + stdinW.Close() + }() + + _, cmd := testWriteCommand(t) + cmd.client = client + cmd.testStdin = stdinR + + code := cmd.Run([]string{ + "secret/write/stdin_value", "foo=-", + }) + if code != 0 { + t.Fatalf("expected 0 to be %d", code) + } + + secret, err := client.Logical().Read("secret/write/stdin_value") + if err != nil { + t.Fatal(err) + } + if secret == nil || secret.Data == nil { + t.Fatal("expected secret to have data") + } + if exp, act := "bar", secret.Data["foo"].(string); exp != act { + t.Errorf("expected %q to be %q", act, exp) + } + }) + + t.Run("integration", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + _, cmd := testWriteCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "secret/write/integration", "foo=bar", "zip=zap", + }) + if code != 0 { + t.Fatalf("expected 0 to be %d", code) + } + + secret, err := client.Logical().Read("secret/write/integration") + if err != nil { + t.Fatal(err) + } + if secret == nil || secret.Data == nil { + t.Fatal("expected secret to have data") + } + if exp, act := "bar", secret.Data["foo"].(string); exp != act { + t.Errorf("expected %q to be %q", act, exp) + } + if exp, act := "zap", secret.Data["zip"].(string); exp != act { + t.Errorf("expected %q to be %q", act, exp) + } + }) + + t.Run("communication_failure", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServerBad(t) + defer closer() + + ui, cmd := testWriteCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "foo/bar", "a=b", + }) + if exp := 2; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Error writing data to foo/bar: " + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + }) + + t.Run("no_tabs", func(t *testing.T) { + t.Parallel() + + _, cmd := testWriteCommand(t) + assertNoTabs(t, cmd) + }) } diff --git a/helper/mfa/duo/path_duo_config.go b/helper/mfa/duo/path_duo_config.go index 166df78c51..c8c317e9d3 100644 --- a/helper/mfa/duo/path_duo_config.go +++ b/helper/mfa/duo/path_duo_config.go @@ -19,7 +19,7 @@ func pathDuoConfig() *framework.Path { }, "username_format": &framework.FieldSchema{ Type: framework.TypeString, - Description: "Format string given auth backend username as argument to create Duo username (default '%s')", + Description: "Format string given auth method username as argument to create Duo username (default '%s')", }, "push_info": &framework.FieldSchema{ Type: framework.TypeString, @@ -101,10 +101,10 @@ type DuoConfig struct { } const pathDuoConfigHelpSyn = ` -Configure Duo second factor behavior. +Configure Duo second factor behavior. ` const pathDuoConfigHelpDesc = ` -This endpoint allows you to configure how the original auth backend username maps to +This endpoint allows you to configure how the original auth method username maps to the Duo username by providing a template format string. ` diff --git a/helper/mfa/mfa.go b/helper/mfa/mfa.go index 8d4628ba0e..bce388da3f 100644 --- a/helper/mfa/mfa.go +++ b/helper/mfa/mfa.go @@ -1,5 +1,5 @@ // Package mfa provides wrappers to add multi-factor authentication -// to any auth backend. +// to any auth method. // // To add MFA to a backend, replace its login path with the // paths returned by MFAPaths and add the additional root diff --git a/helper/parseutil/parseutil.go b/helper/parseutil/parseutil.go index 957d5332e1..ad5a96f641 100644 --- a/helper/parseutil/parseutil.go +++ b/helper/parseutil/parseutil.go @@ -56,6 +56,43 @@ func ParseDurationSecond(in interface{}) (time.Duration, error) { return dur, nil } +func ParseInt(in interface{}) (int64, error) { + var ret int64 + jsonIn, ok := in.(json.Number) + if ok { + in = jsonIn.String() + } + switch in.(type) { + case string: + inp := in.(string) + if inp == "" { + return 0, nil + } + var err error + left, err := strconv.ParseInt(inp, 10, 64) + if err != nil { + return ret, err + } + ret = left + case int: + ret = int64(in.(int)) + case int32: + ret = int64(in.(int32)) + case int64: + ret = in.(int64) + case uint: + ret = int64(in.(uint)) + case uint32: + ret = int64(in.(uint32)) + case uint64: + ret = int64(in.(uint64)) + default: + return 0, errors.New("could not parse value from input") + } + + return ret, nil +} + func ParseBool(in interface{}) (bool, error) { var result bool if err := mapstructure.WeakDecode(in, &result); err != nil { diff --git a/helper/pgpkeys/flag.go b/helper/pgpkeys/flag.go index ccfc64b804..a7371e50a9 100644 --- a/helper/pgpkeys/flag.go +++ b/helper/pgpkeys/flag.go @@ -11,48 +11,90 @@ import ( "github.com/keybase/go-crypto/openpgp" ) -// PGPPubKeyFiles implements the flag.Value interface and allows -// parsing and reading a list of pgp public key files +// PubKeyFileFlag implements flag.Value and command.Example to receive exactly +// one PGP or keybase key via a flag. +type PubKeyFileFlag string + +func (p *PubKeyFileFlag) String() string { return string(*p) } + +func (p *PubKeyFileFlag) Set(val string) error { + if p != nil && *p != "" { + return errors.New("can only be specified once") + } + + keys, err := ParsePGPKeys(strings.Split(val, ",")) + if err != nil { + return err + } + + if len(keys) > 1 { + return errors.New("can only specify one pgp key") + } + + *p = PubKeyFileFlag(keys[0]) + return nil +} + +func (p *PubKeyFileFlag) Example() string { return "keybase:user" } + +// PGPPubKeyFiles implements the flag.Value interface and allows parsing and +// reading a list of PGP public key files. type PubKeyFilesFlag []string func (p *PubKeyFilesFlag) String() string { return fmt.Sprint(*p) } -func (p *PubKeyFilesFlag) Set(value string) error { +func (p *PubKeyFilesFlag) Set(val string) error { if len(*p) > 0 { - return errors.New("pgp-keys can only be specified once") + return errors.New("can only be specified once") } - splitValues := strings.Split(value, ",") - - keybaseMap, err := FetchKeybasePubkeys(splitValues) + keys, err := ParsePGPKeys(strings.Split(val, ",")) if err != nil { return err } - // Now go through the actual flag, and substitute in resolved keybase - // entries where appropriate - for _, keyfile := range splitValues { + *p = PubKeyFilesFlag(keys) + return nil +} + +func (p *PubKeyFilesFlag) Example() string { return "keybase:user1, keybase:user2, ..." } + +// ParsePGPKeys takes a list of PGP keys and parses them either using keybase +// or reading them from disk and returns the "expanded" list of pgp keys in +// the same order. +func ParsePGPKeys(keyfiles []string) ([]string, error) { + keys := make([]string, len(keyfiles)) + + keybaseMap, err := FetchKeybasePubkeys(keyfiles) + if err != nil { + return nil, err + } + + for i, keyfile := range keyfiles { + keyfile = strings.TrimSpace(keyfile) + if strings.HasPrefix(keyfile, kbPrefix) { - key := keybaseMap[keyfile] - if key == "" { - return fmt.Errorf("key for keybase user %s was not found in the map", strings.TrimPrefix(keyfile, kbPrefix)) + key, ok := keybaseMap[keyfile] + if !ok || key == "" { + return nil, fmt.Errorf("keybase user %q not found", strings.TrimPrefix(keyfile, kbPrefix)) } - *p = append(*p, key) + keys[i] = key continue } pgpStr, err := ReadPGPFile(keyfile) if err != nil { - return err + return nil, err } - - *p = append(*p, pgpStr) + keys[i] = pgpStr } - return nil + + return keys, nil } +// ReadPGPFile reads the given PGP file from disk. func ReadPGPFile(path string) (string, error) { if path[0] == '@' { path = path[1:] diff --git a/http/testing.go b/http/testing.go index bda4819d4b..2299006c98 100644 --- a/http/testing.go +++ b/http/testing.go @@ -9,12 +9,12 @@ import ( "github.com/hashicorp/vault/vault" ) -func TestListener(t *testing.T) (net.Listener, string) { +func TestListener(tb testing.TB) (net.Listener, string) { fail := func(format string, args ...interface{}) { panic(fmt.Sprintf(format, args...)) } - if t != nil { - fail = t.Fatalf + if tb != nil { + fail = tb.Fatalf } ln, err := net.Listen("tcp", "127.0.0.1:0") @@ -25,7 +25,7 @@ func TestListener(t *testing.T) (net.Listener, string) { return ln, addr } -func TestServerWithListener(t *testing.T, ln net.Listener, addr string, core *vault.Core) { +func TestServerWithListener(tb testing.TB, ln net.Listener, addr string, core *vault.Core) { // Create a muxer to handle our requests so that we can authenticate // for tests. mux := http.NewServeMux() @@ -39,15 +39,15 @@ func TestServerWithListener(t *testing.T, ln net.Listener, addr string, core *va go server.Serve(ln) } -func TestServer(t *testing.T, core *vault.Core) (net.Listener, string) { - ln, addr := TestListener(t) - TestServerWithListener(t, ln, addr, core) +func TestServer(tb testing.TB, core *vault.Core) (net.Listener, string) { + ln, addr := TestListener(tb) + TestServerWithListener(tb, ln, addr, core) return ln, addr } -func TestServerAuth(t *testing.T, addr string, token string) { +func TestServerAuth(tb testing.TB, addr string, token string) { if _, err := http.Get(addr + "/_test/auth?token=" + token); err != nil { - t.Fatalf("error authenticating: %s", err) + tb.Fatalf("error authenticating: %s", err) } } diff --git a/main.go b/main.go index 6cd34fe36c..7e4b1c9d28 100644 --- a/main.go +++ b/main.go @@ -3,9 +3,9 @@ package main // import "github.com/hashicorp/vault" import ( "os" - "github.com/hashicorp/vault/cli" + "github.com/hashicorp/vault/command" ) func main() { - os.Exit(cli.Run(os.Args[1:])) + os.Exit(command.Run(os.Args[1:])) } diff --git a/meta/meta.go b/meta/meta.go deleted file mode 100644 index b25bfaf462..0000000000 --- a/meta/meta.go +++ /dev/null @@ -1,229 +0,0 @@ -package meta - -import ( - "bufio" - "flag" - "io" - "os" - - "github.com/hashicorp/vault/api" - "github.com/hashicorp/vault/command/token" - "github.com/hashicorp/vault/helper/flag-slice" - "github.com/mitchellh/cli" -) - -// FlagSetFlags is an enum to define what flags are present in the -// default FlagSet returned by Meta.FlagSet. -type FlagSetFlags uint - -type TokenHelperFunc func() (token.TokenHelper, error) - -const ( - FlagSetNone FlagSetFlags = 0 - FlagSetServer FlagSetFlags = 1 << iota - FlagSetDefault = FlagSetServer -) - -var ( - additionalOptionsUsage = func() string { - return ` - -wrap-ttl="" Indicates that the response should be wrapped in a - cubbyhole token with the requested TTL. The response - can be fetched by calling the "sys/wrapping/unwrap" - endpoint, passing in the wrapping token's ID. This - is a numeric string with an optional suffix - "s", "m", or "h"; if no suffix is specified it will - be parsed as seconds. May also be specified via - VAULT_WRAP_TTL. - - -policy-override Indicates that any soft-mandatory Sentinel policies - be overridden. -` - } -) - -// Meta contains the meta-options and functionality that nearly every -// Vault command inherits. -type Meta struct { - ClientToken string - Ui cli.Ui - - // The things below can be set, but aren't common - ForceAddress string // Address to force for API clients - - // These are set by the command line flags. - flagAddress string - flagCACert string - flagCAPath string - flagClientCert string - flagClientKey string - flagWrapTTL string - flagInsecure bool - flagMFA []string - flagPolicyOverride bool - - // Queried if no token can be found - TokenHelper TokenHelperFunc -} - -func (m *Meta) DefaultWrappingLookupFunc(operation, path string) string { - if m.flagWrapTTL != "" { - return m.flagWrapTTL - } - - return api.DefaultWrappingLookupFunc(operation, path) -} - -// Client returns the API client to a Vault server given the configured -// flag settings for this command. -func (m *Meta) Client() (*api.Client, error) { - config := api.DefaultConfig() - - if m.flagAddress != "" { - config.Address = m.flagAddress - } - if m.ForceAddress != "" { - config.Address = m.ForceAddress - } - // If we need custom TLS configuration, then set it - if m.flagCACert != "" || m.flagCAPath != "" || m.flagClientCert != "" || m.flagClientKey != "" || m.flagInsecure { - t := &api.TLSConfig{ - CACert: m.flagCACert, - CAPath: m.flagCAPath, - ClientCert: m.flagClientCert, - ClientKey: m.flagClientKey, - TLSServerName: "", - Insecure: m.flagInsecure, - } - if err := config.ConfigureTLS(t); err != nil { - return nil, err - } - } - - // Build the client - client, err := api.NewClient(config) - if err != nil { - return nil, err - } - - client.SetWrappingLookupFunc(m.DefaultWrappingLookupFunc) - - var mfaCreds []string - - // Extract the MFA credentials from environment variable first - if os.Getenv(api.EnvVaultMFA) != "" { - mfaCreds = []string{os.Getenv(api.EnvVaultMFA)} - } - - // If CLI MFA flags were supplied, prefer that over environment variable - if len(m.flagMFA) != 0 { - mfaCreds = m.flagMFA - } - - client.SetMFACreds(mfaCreds) - - client.SetPolicyOverride(m.flagPolicyOverride) - - // If we have a token directly, then set that - token := m.ClientToken - - // Try to set the token to what is already stored - if token == "" { - token = client.Token() - } - - // If we don't have a token, check the token helper - if token == "" { - if m.TokenHelper != nil { - // If we have a token, then set that - tokenHelper, err := m.TokenHelper() - if err != nil { - return nil, err - } - token, err = tokenHelper.Get() - if err != nil { - return nil, err - } - } - } - - // Set the token - if token != "" { - client.SetToken(token) - } - - return client, nil -} - -// FlagSet returns a FlagSet with the common flags that every -// command implements. The exact behavior of FlagSet can be configured -// using the flags as the second parameter, for example to disable -// server settings on the commands that don't talk to a server. -func (m *Meta) FlagSet(n string, fs FlagSetFlags) *flag.FlagSet { - f := flag.NewFlagSet(n, flag.ContinueOnError) - - // FlagSetServer tells us to enable the settings for selecting - // the server information. - if fs&FlagSetServer != 0 { - f.StringVar(&m.flagAddress, "address", "", "") - f.StringVar(&m.flagCACert, "ca-cert", "", "") - f.StringVar(&m.flagCAPath, "ca-path", "", "") - f.StringVar(&m.flagClientCert, "client-cert", "", "") - f.StringVar(&m.flagClientKey, "client-key", "", "") - f.StringVar(&m.flagWrapTTL, "wrap-ttl", "", "") - f.BoolVar(&m.flagInsecure, "insecure", false, "") - f.BoolVar(&m.flagInsecure, "tls-skip-verify", false, "") - f.BoolVar(&m.flagPolicyOverride, "policy-override", false, "") - f.Var((*sliceflag.StringFlag)(&m.flagMFA), "mfa", "") - } - - // Create an io.Writer that writes to our Ui properly for errors. - // This is kind of a hack, but it does the job. Basically: create - // a pipe, use a scanner to break it into lines, and output each line - // to the UI. Do this forever. - errR, errW := io.Pipe() - errScanner := bufio.NewScanner(errR) - go func() { - for errScanner.Scan() { - m.Ui.Error(errScanner.Text()) - } - }() - f.SetOutput(errW) - - return f -} - -// GeneralOptionsUsage returns the usage documentation for commonly -// available options -func GeneralOptionsUsage() string { - general := ` - -address=addr The address of the Vault server. - Overrides the VAULT_ADDR environment variable if set. - - -ca-cert=path Path to a PEM encoded CA cert file to use to - verify the Vault server SSL certificate. - Overrides the VAULT_CACERT environment variable if set. - - -ca-path=path Path to a directory of PEM encoded CA cert files - to verify the Vault server SSL certificate. If both - -ca-cert and -ca-path are specified, -ca-cert is used. - Overrides the VAULT_CAPATH environment variable if set. - - -client-cert=path Path to a PEM encoded client certificate for TLS - authentication to the Vault server. Must also specify - -client-key. Overrides the VAULT_CLIENT_CERT - environment variable if set. - - -client-key=path Path to an unencrypted PEM encoded private key - matching the client certificate from -client-cert. - Overrides the VAULT_CLIENT_KEY environment variable - if set. - - -tls-skip-verify Do not verify TLS certificate. This is highly - not recommended. Verification will also be skipped - if VAULT_SKIP_VERIFY is set. -` - - general += additionalOptionsUsage() - return general -} diff --git a/meta/meta_test.go b/meta/meta_test.go deleted file mode 100644 index 99a294d249..0000000000 --- a/meta/meta_test.go +++ /dev/null @@ -1,41 +0,0 @@ -package meta - -import ( - "flag" - "reflect" - "sort" - "testing" -) - -func TestFlagSet(t *testing.T) { - cases := []struct { - Flags FlagSetFlags - Expected []string - }{ - { - FlagSetNone, - []string{}, - }, - { - FlagSetServer, - []string{"address", "ca-cert", "ca-path", "client-cert", "client-key", "insecure", "mfa", "policy-override", "tls-skip-verify", "wrap-ttl"}, - }, - } - - for i, tc := range cases { - var m Meta - fs := m.FlagSet("foo", tc.Flags) - - actual := make([]string, 0, 0) - fs.VisitAll(func(f *flag.Flag) { - actual = append(actual, f.Name) - }) - sort.Strings(actual) - sort.Strings(tc.Expected) - - if !reflect.DeepEqual(actual, tc.Expected) { - t.Fatalf("%d: flags: %#v\n\nExpected: %#v\nGot: %#v", - i, tc.Flags, tc.Expected, actual) - } - } -} diff --git a/vault/auth.go b/vault/auth.go index 8a1f5d6c95..e2e8046705 100644 --- a/vault/auth.go +++ b/vault/auth.go @@ -115,7 +115,7 @@ func (c *Core) enableCredential(entry *MountEntry) error { // Check for the correct backend type backendType := backend.Type() if entry.Type == "plugin" && backendType != logical.TypeCredential { - return fmt.Errorf("cannot mount '%s' of type '%s' as an auth backend", entry.Config.PluginName, backendType) + return fmt.Errorf("cannot mount '%s' of type '%s' as an auth method", entry.Config.PluginName, backendType) } if err := backend.Initialize(); err != nil { diff --git a/vault/logical_system.go b/vault/logical_system.go index 6dd2150e56..6611d227b1 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -1583,7 +1583,7 @@ func (b *SystemBackend) handleMountTuneRead(ctx context.Context, req *logical.Re logical.ErrInvalidRequest } - // This call will read both logical backend's configuration as well as auth backends'. + // This call will read both logical backend's configuration as well as auth methods'. // Retaining this behavior for backward compatibility. If this behavior is not desired, // an error can be returned if path has a prefix of "auth/". return b.handleTuneReadCommon(path) @@ -1633,7 +1633,7 @@ func (b *SystemBackend) handleMountTuneWrite(ctx context.Context, req *logical.R return logical.ErrorResponse("path must be specified as a string"), logical.ErrInvalidRequest } - // This call will write both logical backend's configuration as well as auth backends'. + // This call will write both logical backend's configuration as well as auth methods'. // Retaining this behavior for backward compatibility. If this behavior is not desired, // an error can be returned if path has a prefix of "auth/". return b.handleTuneWriteCommon(path, data) @@ -3081,10 +3081,10 @@ This path responds to the following HTTP methods. credential backend. POST / - Enable a new auth backend. + Enable a new auth method. DELETE / - Disable the auth backend at the given mount point. + Disable the auth method at the given mount point. `, }, diff --git a/vault/request_handling.go b/vault/request_handling.go index bef50741ba..18ba2d3816 100644 --- a/vault/request_handling.go +++ b/vault/request_handling.go @@ -464,7 +464,7 @@ func (c *Core) handleLoginRequest(ctx context.Context, req *logical.Request) (re } if strutil.StrListSubset(auth.Policies, []string{"root"}) { - return logical.ErrorResponse("authentication backends cannot create root tokens"), nil, logical.ErrInvalidRequest + return logical.ErrorResponse("auth methods cannot create root tokens"), nil, logical.ErrInvalidRequest } // Determine the source of the login diff --git a/vendor/github.com/hashicorp/hcl/hcl/printer/nodes.go b/vendor/github.com/hashicorp/hcl/hcl/printer/nodes.go new file mode 100644 index 0000000000..c896d5844a --- /dev/null +++ b/vendor/github.com/hashicorp/hcl/hcl/printer/nodes.go @@ -0,0 +1,779 @@ +package printer + +import ( + "bytes" + "fmt" + "sort" + + "github.com/hashicorp/hcl/hcl/ast" + "github.com/hashicorp/hcl/hcl/token" +) + +const ( + blank = byte(' ') + newline = byte('\n') + tab = byte('\t') + infinity = 1 << 30 // offset or line +) + +var ( + unindent = []byte("\uE123") // in the private use space +) + +type printer struct { + cfg Config + prev token.Pos + + comments []*ast.CommentGroup // may be nil, contains all comments + standaloneComments []*ast.CommentGroup // contains all standalone comments (not assigned to any node) + + enableTrace bool + indentTrace int +} + +type ByPosition []*ast.CommentGroup + +func (b ByPosition) Len() int { return len(b) } +func (b ByPosition) Swap(i, j int) { b[i], b[j] = b[j], b[i] } +func (b ByPosition) Less(i, j int) bool { return b[i].Pos().Before(b[j].Pos()) } + +// collectComments comments all standalone comments which are not lead or line +// comment +func (p *printer) collectComments(node ast.Node) { + // first collect all comments. This is already stored in + // ast.File.(comments) + ast.Walk(node, func(nn ast.Node) (ast.Node, bool) { + switch t := nn.(type) { + case *ast.File: + p.comments = t.Comments + return nn, false + } + return nn, true + }) + + standaloneComments := make(map[token.Pos]*ast.CommentGroup, 0) + for _, c := range p.comments { + standaloneComments[c.Pos()] = c + } + + // next remove all lead and line comments from the overall comment map. + // This will give us comments which are standalone, comments which are not + // assigned to any kind of node. + ast.Walk(node, func(nn ast.Node) (ast.Node, bool) { + switch t := nn.(type) { + case *ast.LiteralType: + if t.LeadComment != nil { + for _, comment := range t.LeadComment.List { + if _, ok := standaloneComments[comment.Pos()]; ok { + delete(standaloneComments, comment.Pos()) + } + } + } + + if t.LineComment != nil { + for _, comment := range t.LineComment.List { + if _, ok := standaloneComments[comment.Pos()]; ok { + delete(standaloneComments, comment.Pos()) + } + } + } + case *ast.ObjectItem: + if t.LeadComment != nil { + for _, comment := range t.LeadComment.List { + if _, ok := standaloneComments[comment.Pos()]; ok { + delete(standaloneComments, comment.Pos()) + } + } + } + + if t.LineComment != nil { + for _, comment := range t.LineComment.List { + if _, ok := standaloneComments[comment.Pos()]; ok { + delete(standaloneComments, comment.Pos()) + } + } + } + } + + return nn, true + }) + + for _, c := range standaloneComments { + p.standaloneComments = append(p.standaloneComments, c) + } + + sort.Sort(ByPosition(p.standaloneComments)) +} + +// output prints creates b printable HCL output and returns it. +func (p *printer) output(n interface{}) []byte { + var buf bytes.Buffer + + switch t := n.(type) { + case *ast.File: + // File doesn't trace so we add the tracing here + defer un(trace(p, "File")) + return p.output(t.Node) + case *ast.ObjectList: + defer un(trace(p, "ObjectList")) + + var index int + for { + // Determine the location of the next actual non-comment + // item. If we're at the end, the next item is at "infinity" + var nextItem token.Pos + if index != len(t.Items) { + nextItem = t.Items[index].Pos() + } else { + nextItem = token.Pos{Offset: infinity, Line: infinity} + } + + // Go through the standalone comments in the file and print out + // the comments that we should be for this object item. + for _, c := range p.standaloneComments { + // Go through all the comments in the group. The group + // should be printed together, not separated by double newlines. + printed := false + newlinePrinted := false + for _, comment := range c.List { + // We only care about comments after the previous item + // we've printed so that comments are printed in the + // correct locations (between two objects for example). + // And before the next item. + if comment.Pos().After(p.prev) && comment.Pos().Before(nextItem) { + // if we hit the end add newlines so we can print the comment + // we don't do this if prev is invalid which means the + // beginning of the file since the first comment should + // be at the first line. + if !newlinePrinted && p.prev.IsValid() && index == len(t.Items) { + buf.Write([]byte{newline, newline}) + newlinePrinted = true + } + + // Write the actual comment. + buf.WriteString(comment.Text) + buf.WriteByte(newline) + + // Set printed to true to note that we printed something + printed = true + } + } + + // If we're not at the last item, write a new line so + // that there is a newline separating this comment from + // the next object. + if printed && index != len(t.Items) { + buf.WriteByte(newline) + } + } + + if index == len(t.Items) { + break + } + + buf.Write(p.output(t.Items[index])) + if index != len(t.Items)-1 { + // Always write a newline to separate us from the next item + buf.WriteByte(newline) + + // Need to determine if we're going to separate the next item + // with a blank line. The logic here is simple, though there + // are a few conditions: + // + // 1. The next object is more than one line away anyways, + // so we need an empty line. + // + // 2. The next object is not a "single line" object, so + // we need an empty line. + // + // 3. This current object is not a single line object, + // so we need an empty line. + current := t.Items[index] + next := t.Items[index+1] + if next.Pos().Line != t.Items[index].Pos().Line+1 || + !p.isSingleLineObject(next) || + !p.isSingleLineObject(current) { + buf.WriteByte(newline) + } + } + index++ + } + case *ast.ObjectKey: + buf.WriteString(t.Token.Text) + case *ast.ObjectItem: + p.prev = t.Pos() + buf.Write(p.objectItem(t)) + case *ast.LiteralType: + buf.Write(p.literalType(t)) + case *ast.ListType: + buf.Write(p.list(t)) + case *ast.ObjectType: + buf.Write(p.objectType(t)) + default: + fmt.Printf(" unknown type: %T\n", n) + } + + return buf.Bytes() +} + +func (p *printer) literalType(lit *ast.LiteralType) []byte { + result := []byte(lit.Token.Text) + switch lit.Token.Type { + case token.HEREDOC: + // Clear the trailing newline from heredocs + if result[len(result)-1] == '\n' { + result = result[:len(result)-1] + } + + // Poison lines 2+ so that we don't indent them + result = p.heredocIndent(result) + case token.STRING: + // If this is a multiline string, poison lines 2+ so we don't + // indent them. + if bytes.IndexRune(result, '\n') >= 0 { + result = p.heredocIndent(result) + } + } + + return result +} + +// objectItem returns the printable HCL form of an object item. An object type +// starts with one/multiple keys and has a value. The value might be of any +// type. +func (p *printer) objectItem(o *ast.ObjectItem) []byte { + defer un(trace(p, fmt.Sprintf("ObjectItem: %s", o.Keys[0].Token.Text))) + var buf bytes.Buffer + + if o.LeadComment != nil { + for _, comment := range o.LeadComment.List { + buf.WriteString(comment.Text) + buf.WriteByte(newline) + } + } + + for i, k := range o.Keys { + buf.WriteString(k.Token.Text) + buf.WriteByte(blank) + + // reach end of key + if o.Assign.IsValid() && i == len(o.Keys)-1 && len(o.Keys) == 1 { + buf.WriteString("=") + buf.WriteByte(blank) + } + } + + buf.Write(p.output(o.Val)) + + if o.Val.Pos().Line == o.Keys[0].Pos().Line && o.LineComment != nil { + buf.WriteByte(blank) + for _, comment := range o.LineComment.List { + buf.WriteString(comment.Text) + } + } + + return buf.Bytes() +} + +// objectType returns the printable HCL form of an object type. An object type +// begins with a brace and ends with a brace. +func (p *printer) objectType(o *ast.ObjectType) []byte { + defer un(trace(p, "ObjectType")) + var buf bytes.Buffer + buf.WriteString("{") + + var index int + var nextItem token.Pos + var commented, newlinePrinted bool + for { + // Determine the location of the next actual non-comment + // item. If we're at the end, the next item is the closing brace + if index != len(o.List.Items) { + nextItem = o.List.Items[index].Pos() + } else { + nextItem = o.Rbrace + } + + // Go through the standalone comments in the file and print out + // the comments that we should be for this object item. + for _, c := range p.standaloneComments { + printed := false + var lastCommentPos token.Pos + for _, comment := range c.List { + // We only care about comments after the previous item + // we've printed so that comments are printed in the + // correct locations (between two objects for example). + // And before the next item. + if comment.Pos().After(p.prev) && comment.Pos().Before(nextItem) { + // If there are standalone comments and the initial newline has not + // been printed yet, do it now. + if !newlinePrinted { + newlinePrinted = true + buf.WriteByte(newline) + } + + // add newline if it's between other printed nodes + if index > 0 { + commented = true + buf.WriteByte(newline) + } + + // Store this position + lastCommentPos = comment.Pos() + + // output the comment itself + buf.Write(p.indent(p.heredocIndent([]byte(comment.Text)))) + + // Set printed to true to note that we printed something + printed = true + + /* + if index != len(o.List.Items) { + buf.WriteByte(newline) // do not print on the end + } + */ + } + } + + // Stuff to do if we had comments + if printed { + // Always write a newline + buf.WriteByte(newline) + + // If there is another item in the object and our comment + // didn't hug it directly, then make sure there is a blank + // line separating them. + if nextItem != o.Rbrace && nextItem.Line != lastCommentPos.Line+1 { + buf.WriteByte(newline) + } + } + } + + if index == len(o.List.Items) { + p.prev = o.Rbrace + break + } + + // At this point we are sure that it's not a totally empty block: print + // the initial newline if it hasn't been printed yet by the previous + // block about standalone comments. + if !newlinePrinted { + buf.WriteByte(newline) + newlinePrinted = true + } + + // check if we have adjacent one liner items. If yes we'll going to align + // the comments. + var aligned []*ast.ObjectItem + for _, item := range o.List.Items[index:] { + // we don't group one line lists + if len(o.List.Items) == 1 { + break + } + + // one means a oneliner with out any lead comment + // two means a oneliner with lead comment + // anything else might be something else + cur := lines(string(p.objectItem(item))) + if cur > 2 { + break + } + + curPos := item.Pos() + + nextPos := token.Pos{} + if index != len(o.List.Items)-1 { + nextPos = o.List.Items[index+1].Pos() + } + + prevPos := token.Pos{} + if index != 0 { + prevPos = o.List.Items[index-1].Pos() + } + + // fmt.Println("DEBUG ----------------") + // fmt.Printf("prev = %+v prevPos: %s\n", prev, prevPos) + // fmt.Printf("cur = %+v curPos: %s\n", cur, curPos) + // fmt.Printf("next = %+v nextPos: %s\n", next, nextPos) + + if curPos.Line+1 == nextPos.Line { + aligned = append(aligned, item) + index++ + continue + } + + if curPos.Line-1 == prevPos.Line { + aligned = append(aligned, item) + index++ + + // finish if we have a new line or comment next. This happens + // if the next item is not adjacent + if curPos.Line+1 != nextPos.Line { + break + } + continue + } + + break + } + + // put newlines if the items are between other non aligned items. + // newlines are also added if there is a standalone comment already, so + // check it too + if !commented && index != len(aligned) { + buf.WriteByte(newline) + } + + if len(aligned) >= 1 { + p.prev = aligned[len(aligned)-1].Pos() + + items := p.alignedItems(aligned) + buf.Write(p.indent(items)) + } else { + p.prev = o.List.Items[index].Pos() + + buf.Write(p.indent(p.objectItem(o.List.Items[index]))) + index++ + } + + buf.WriteByte(newline) + } + + buf.WriteString("}") + return buf.Bytes() +} + +func (p *printer) alignedItems(items []*ast.ObjectItem) []byte { + var buf bytes.Buffer + + // find the longest key and value length, needed for alignment + var longestKeyLen int // longest key length + var longestValLen int // longest value length + for _, item := range items { + key := len(item.Keys[0].Token.Text) + val := len(p.output(item.Val)) + + if key > longestKeyLen { + longestKeyLen = key + } + + if val > longestValLen { + longestValLen = val + } + } + + for i, item := range items { + if item.LeadComment != nil { + for _, comment := range item.LeadComment.List { + buf.WriteString(comment.Text) + buf.WriteByte(newline) + } + } + + for i, k := range item.Keys { + keyLen := len(k.Token.Text) + buf.WriteString(k.Token.Text) + for i := 0; i < longestKeyLen-keyLen+1; i++ { + buf.WriteByte(blank) + } + + // reach end of key + if i == len(item.Keys)-1 && len(item.Keys) == 1 { + buf.WriteString("=") + buf.WriteByte(blank) + } + } + + val := p.output(item.Val) + valLen := len(val) + buf.Write(val) + + if item.Val.Pos().Line == item.Keys[0].Pos().Line && item.LineComment != nil { + for i := 0; i < longestValLen-valLen+1; i++ { + buf.WriteByte(blank) + } + + for _, comment := range item.LineComment.List { + buf.WriteString(comment.Text) + } + } + + // do not print for the last item + if i != len(items)-1 { + buf.WriteByte(newline) + } + } + + return buf.Bytes() +} + +// list returns the printable HCL form of an list type. +func (p *printer) list(l *ast.ListType) []byte { + var buf bytes.Buffer + buf.WriteString("[") + + var longestLine int + for _, item := range l.List { + // for now we assume that the list only contains literal types + if lit, ok := item.(*ast.LiteralType); ok { + lineLen := len(lit.Token.Text) + if lineLen > longestLine { + longestLine = lineLen + } + } + } + + insertSpaceBeforeItem := false + lastHadLeadComment := false + for i, item := range l.List { + // Keep track of whether this item is a heredoc since that has + // unique behavior. + heredoc := false + if lit, ok := item.(*ast.LiteralType); ok && lit.Token.Type == token.HEREDOC { + heredoc = true + } + + if item.Pos().Line != l.Lbrack.Line { + // multiline list, add newline before we add each item + buf.WriteByte(newline) + insertSpaceBeforeItem = false + + // If we have a lead comment, then we want to write that first + leadComment := false + if lit, ok := item.(*ast.LiteralType); ok && lit.LeadComment != nil { + leadComment = true + + // If this isn't the first item and the previous element + // didn't have a lead comment, then we need to add an extra + // newline to properly space things out. If it did have a + // lead comment previously then this would be done + // automatically. + if i > 0 && !lastHadLeadComment { + buf.WriteByte(newline) + } + + for _, comment := range lit.LeadComment.List { + buf.Write(p.indent([]byte(comment.Text))) + buf.WriteByte(newline) + } + } + + // also indent each line + val := p.output(item) + curLen := len(val) + buf.Write(p.indent(val)) + + // if this item is a heredoc, then we output the comma on + // the next line. This is the only case this happens. + comma := []byte{','} + if heredoc { + buf.WriteByte(newline) + comma = p.indent(comma) + } + + buf.Write(comma) + + if lit, ok := item.(*ast.LiteralType); ok && lit.LineComment != nil { + // if the next item doesn't have any comments, do not align + buf.WriteByte(blank) // align one space + for i := 0; i < longestLine-curLen; i++ { + buf.WriteByte(blank) + } + + for _, comment := range lit.LineComment.List { + buf.WriteString(comment.Text) + } + } + + lastItem := i == len(l.List)-1 + if lastItem { + buf.WriteByte(newline) + } + + if leadComment && !lastItem { + buf.WriteByte(newline) + } + + lastHadLeadComment = leadComment + } else { + if insertSpaceBeforeItem { + buf.WriteByte(blank) + insertSpaceBeforeItem = false + } + + // Output the item itself + // also indent each line + val := p.output(item) + curLen := len(val) + buf.Write(val) + + // If this is a heredoc item we always have to output a newline + // so that it parses properly. + if heredoc { + buf.WriteByte(newline) + } + + // If this isn't the last element, write a comma. + if i != len(l.List)-1 { + buf.WriteString(",") + insertSpaceBeforeItem = true + } + + if lit, ok := item.(*ast.LiteralType); ok && lit.LineComment != nil { + // if the next item doesn't have any comments, do not align + buf.WriteByte(blank) // align one space + for i := 0; i < longestLine-curLen; i++ { + buf.WriteByte(blank) + } + + for _, comment := range lit.LineComment.List { + buf.WriteString(comment.Text) + } + } + } + + } + + buf.WriteString("]") + return buf.Bytes() +} + +// indent indents the lines of the given buffer for each non-empty line +func (p *printer) indent(buf []byte) []byte { + var prefix []byte + if p.cfg.SpacesWidth != 0 { + for i := 0; i < p.cfg.SpacesWidth; i++ { + prefix = append(prefix, blank) + } + } else { + prefix = []byte{tab} + } + + var res []byte + bol := true + for _, c := range buf { + if bol && c != '\n' { + res = append(res, prefix...) + } + + res = append(res, c) + bol = c == '\n' + } + return res +} + +// unindent removes all the indentation from the tombstoned lines +func (p *printer) unindent(buf []byte) []byte { + var res []byte + for i := 0; i < len(buf); i++ { + skip := len(buf)-i <= len(unindent) + if !skip { + skip = !bytes.Equal(unindent, buf[i:i+len(unindent)]) + } + if skip { + res = append(res, buf[i]) + continue + } + + // We have a marker. we have to backtrace here and clean out + // any whitespace ahead of our tombstone up to a \n + for j := len(res) - 1; j >= 0; j-- { + if res[j] == '\n' { + break + } + + res = res[:j] + } + + // Skip the entire unindent marker + i += len(unindent) - 1 + } + + return res +} + +// heredocIndent marks all the 2nd and further lines as unindentable +func (p *printer) heredocIndent(buf []byte) []byte { + var res []byte + bol := false + for _, c := range buf { + if bol && c != '\n' { + res = append(res, unindent...) + } + res = append(res, c) + bol = c == '\n' + } + return res +} + +// isSingleLineObject tells whether the given object item is a single +// line object such as "obj {}". +// +// A single line object: +// +// * has no lead comments (hence multi-line) +// * has no assignment +// * has no values in the stanza (within {}) +// +func (p *printer) isSingleLineObject(val *ast.ObjectItem) bool { + // If there is a lead comment, can't be one line + if val.LeadComment != nil { + return false + } + + // If there is assignment, we always break by line + if val.Assign.IsValid() { + return false + } + + // If it isn't an object type, then its not a single line object + ot, ok := val.Val.(*ast.ObjectType) + if !ok { + return false + } + + // If the object has no items, it is single line! + return len(ot.List.Items) == 0 +} + +func lines(txt string) int { + endline := 1 + for i := 0; i < len(txt); i++ { + if txt[i] == '\n' { + endline++ + } + } + return endline +} + +// ---------------------------------------------------------------------------- +// Tracing support + +func (p *printer) printTrace(a ...interface{}) { + if !p.enableTrace { + return + } + + const dots = ". . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . " + const n = len(dots) + i := 2 * p.indentTrace + for i > n { + fmt.Print(dots) + i -= n + } + // i <= n + fmt.Print(dots[0:i]) + fmt.Println(a...) +} + +func trace(p *printer, msg string) *printer { + p.printTrace(msg, "(") + p.indentTrace++ + return p +} + +// Usage pattern: defer un(trace(p, "...")) +func un(p *printer) { + p.indentTrace-- + p.printTrace(")") +} diff --git a/vendor/github.com/hashicorp/hcl/hcl/printer/printer.go b/vendor/github.com/hashicorp/hcl/hcl/printer/printer.go new file mode 100644 index 0000000000..6617ab8e7a --- /dev/null +++ b/vendor/github.com/hashicorp/hcl/hcl/printer/printer.go @@ -0,0 +1,66 @@ +// Package printer implements printing of AST nodes to HCL format. +package printer + +import ( + "bytes" + "io" + "text/tabwriter" + + "github.com/hashicorp/hcl/hcl/ast" + "github.com/hashicorp/hcl/hcl/parser" +) + +var DefaultConfig = Config{ + SpacesWidth: 2, +} + +// A Config node controls the output of Fprint. +type Config struct { + SpacesWidth int // if set, it will use spaces instead of tabs for alignment +} + +func (c *Config) Fprint(output io.Writer, node ast.Node) error { + p := &printer{ + cfg: *c, + comments: make([]*ast.CommentGroup, 0), + standaloneComments: make([]*ast.CommentGroup, 0), + // enableTrace: true, + } + + p.collectComments(node) + + if _, err := output.Write(p.unindent(p.output(node))); err != nil { + return err + } + + // flush tabwriter, if any + var err error + if tw, _ := output.(*tabwriter.Writer); tw != nil { + err = tw.Flush() + } + + return err +} + +// Fprint "pretty-prints" an HCL node to output +// It calls Config.Fprint with default settings. +func Fprint(output io.Writer, node ast.Node) error { + return DefaultConfig.Fprint(output, node) +} + +// Format formats src HCL and returns the result. +func Format(src []byte) ([]byte, error) { + node, err := parser.Parse(src) + if err != nil { + return nil, err + } + + var buf bytes.Buffer + if err := DefaultConfig.Fprint(&buf, node); err != nil { + return nil, err + } + + // Add trailing newline to result + buf.WriteString("\n") + return buf.Bytes(), nil +} diff --git a/vendor/github.com/kr/text/License b/vendor/github.com/kr/text/License new file mode 100644 index 0000000000..480a328059 --- /dev/null +++ b/vendor/github.com/kr/text/License @@ -0,0 +1,19 @@ +Copyright 2012 Keith Rarick + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/vendor/github.com/kr/text/Readme b/vendor/github.com/kr/text/Readme new file mode 100644 index 0000000000..7e6e7c0687 --- /dev/null +++ b/vendor/github.com/kr/text/Readme @@ -0,0 +1,3 @@ +This is a Go package for manipulating paragraphs of text. + +See http://go.pkgdoc.org/github.com/kr/text for full documentation. diff --git a/vendor/github.com/kr/text/doc.go b/vendor/github.com/kr/text/doc.go new file mode 100644 index 0000000000..cf4c198f95 --- /dev/null +++ b/vendor/github.com/kr/text/doc.go @@ -0,0 +1,3 @@ +// Package text provides rudimentary functions for manipulating text in +// paragraphs. +package text diff --git a/vendor/github.com/kr/text/indent.go b/vendor/github.com/kr/text/indent.go new file mode 100644 index 0000000000..4ebac45c09 --- /dev/null +++ b/vendor/github.com/kr/text/indent.go @@ -0,0 +1,74 @@ +package text + +import ( + "io" +) + +// Indent inserts prefix at the beginning of each non-empty line of s. The +// end-of-line marker is NL. +func Indent(s, prefix string) string { + return string(IndentBytes([]byte(s), []byte(prefix))) +} + +// IndentBytes inserts prefix at the beginning of each non-empty line of b. +// The end-of-line marker is NL. +func IndentBytes(b, prefix []byte) []byte { + var res []byte + bol := true + for _, c := range b { + if bol && c != '\n' { + res = append(res, prefix...) + } + res = append(res, c) + bol = c == '\n' + } + return res +} + +// Writer indents each line of its input. +type indentWriter struct { + w io.Writer + bol bool + pre [][]byte + sel int + off int +} + +// NewIndentWriter makes a new write filter that indents the input +// lines. Each line is prefixed in order with the corresponding +// element of pre. If there are more lines than elements, the last +// element of pre is repeated for each subsequent line. +func NewIndentWriter(w io.Writer, pre ...[]byte) io.Writer { + return &indentWriter{ + w: w, + pre: pre, + bol: true, + } +} + +// The only errors returned are from the underlying indentWriter. +func (w *indentWriter) Write(p []byte) (n int, err error) { + for _, c := range p { + if w.bol { + var i int + i, err = w.w.Write(w.pre[w.sel][w.off:]) + w.off += i + if err != nil { + return n, err + } + } + _, err = w.w.Write([]byte{c}) + if err != nil { + return n, err + } + n++ + w.bol = c == '\n' + if w.bol { + w.off = 0 + if w.sel < len(w.pre)-1 { + w.sel++ + } + } + } + return n, nil +} diff --git a/vendor/github.com/kr/text/wrap.go b/vendor/github.com/kr/text/wrap.go new file mode 100644 index 0000000000..b09bb03736 --- /dev/null +++ b/vendor/github.com/kr/text/wrap.go @@ -0,0 +1,86 @@ +package text + +import ( + "bytes" + "math" +) + +var ( + nl = []byte{'\n'} + sp = []byte{' '} +) + +const defaultPenalty = 1e5 + +// Wrap wraps s into a paragraph of lines of length lim, with minimal +// raggedness. +func Wrap(s string, lim int) string { + return string(WrapBytes([]byte(s), lim)) +} + +// WrapBytes wraps b into a paragraph of lines of length lim, with minimal +// raggedness. +func WrapBytes(b []byte, lim int) []byte { + words := bytes.Split(bytes.Replace(bytes.TrimSpace(b), nl, sp, -1), sp) + var lines [][]byte + for _, line := range WrapWords(words, 1, lim, defaultPenalty) { + lines = append(lines, bytes.Join(line, sp)) + } + return bytes.Join(lines, nl) +} + +// WrapWords is the low-level line-breaking algorithm, useful if you need more +// control over the details of the text wrapping process. For most uses, either +// Wrap or WrapBytes will be sufficient and more convenient. +// +// WrapWords splits a list of words into lines with minimal "raggedness", +// treating each byte as one unit, accounting for spc units between adjacent +// words on each line, and attempting to limit lines to lim units. Raggedness +// is the total error over all lines, where error is the square of the +// difference of the length of the line and lim. Too-long lines (which only +// happen when a single word is longer than lim units) have pen penalty units +// added to the error. +func WrapWords(words [][]byte, spc, lim, pen int) [][][]byte { + n := len(words) + + length := make([][]int, n) + for i := 0; i < n; i++ { + length[i] = make([]int, n) + length[i][i] = len(words[i]) + for j := i + 1; j < n; j++ { + length[i][j] = length[i][j-1] + spc + len(words[j]) + } + } + + nbrk := make([]int, n) + cost := make([]int, n) + for i := range cost { + cost[i] = math.MaxInt32 + } + for i := n - 1; i >= 0; i-- { + if length[i][n-1] <= lim || i == n-1 { + cost[i] = 0 + nbrk[i] = n + } else { + for j := i + 1; j < n; j++ { + d := lim - length[i][j-1] + c := d*d + cost[j] + if length[i][j-1] > lim { + c += pen // too-long lines get a worse penalty + } + if c < cost[i] { + cost[i] = c + nbrk[i] = j + } + } + } + } + + var lines [][][]byte + i := 0 + for i < n { + lines = append(lines, words[i:nbrk[i]]) + i = nbrk[i] + } + return lines +} diff --git a/website/config.rb b/website/config.rb index 162afd2c80..4b21fbe4fa 100644 --- a/website/config.rb +++ b/website/config.rb @@ -37,7 +37,6 @@ helpers do # @return [String] def description_for(page) description = (page.data.description || "") - .gsub('"', '') .gsub(/\n+/, ' ') .squeeze(' ') diff --git a/website/redirects.txt b/website/redirects.txt index 528394bbb1..0a317b3f1f 100644 --- a/website/redirects.txt +++ b/website/redirects.txt @@ -84,6 +84,7 @@ /docs/secrets/custom.html /docs/plugin/index.html /docs/secrets/generic/index.html /docs/secrets/kv/index.html /intro/getting-started/acl.html /intro/getting-started/policies.html +/intro/getting-started/secret-backends.html /intro/getting-started/secrets-engines.html /docs/vault-enterprise/index.html /docs/enterprise/index.html /docs/vault-enterprise/replication/index.html /docs/enterprise/replication/index.html @@ -99,4 +100,8 @@ /docs/vault-enterprise/mfa/mfa-pingid.html /docs/enterprise/mfa/mfa-pingid.html /docs/vault-enterprise/mfa/mfa-totp.html /docs/enterprise/mfa/mfa-totp.html -/docs/enterprise/hsm/configuration.html /docs/configuration/seal/pkcs11.html \ No newline at end of file +/docs/commands/environment.html /docs/commands/index.html#environment-variables +/docs/commands/read-write.html /docs/commands/index.html#reading-and-writing-data +/docs/commands/help.html /docs/commands/path-help.html + +/docs/enterprise/hsm/configuration.html /docs/configuration/seal/pkcs11.html diff --git a/website/source/_ember_steps.html.erb b/website/source/_ember_steps.html.erb index c28a008e8b..7980f25c15 100644 --- a/website/source/_ember_steps.html.erb +++ b/website/source/_ember_steps.html.erb @@ -95,7 +95,7 @@