diff --git a/builtin/logical/mysql/backend.go b/builtin/logical/mysql/backend.go new file mode 100644 index 0000000000..1d02222084 --- /dev/null +++ b/builtin/logical/mysql/backend.go @@ -0,0 +1,123 @@ +package mysql + +import ( + "database/sql" + "fmt" + "strings" + "sync" + + _ "github.com/go-sql-driver/mysql" + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/logical/framework" +) + +func Factory(map[string]string) (logical.Backend, error) { + return Backend(), nil +} + +func Backend() *framework.Backend { + var b backend + b.Backend = &framework.Backend{ + Help: strings.TrimSpace(backendHelp), + + PathsSpecial: &logical.Paths{ + Root: []string{ + "config/*", + }, + }, + + Paths: []*framework.Path{ + pathConfigConnection(&b), + pathConfigLease(&b), + pathRoles(&b), + pathRoleCreate(&b), + }, + + Secrets: []*framework.Secret{ + secretCreds(&b), + }, + } + + return b.Backend +} + +type backend struct { + *framework.Backend + + db *sql.DB + lock sync.Mutex +} + +// DB returns the database connection. +func (b *backend) DB(s logical.Storage) (*sql.DB, error) { + b.lock.Lock() + defer b.lock.Unlock() + + // If we already have a DB, we got it! + if b.db != nil { + return b.db, nil + } + + // Otherwise, attempt to make connection + entry, err := s.Get("config/connection") + if err != nil { + return nil, err + } + if entry == nil { + return nil, + fmt.Errorf("configure the DB connection with config/connection first") + } + + var conn string + if err := entry.DecodeJSON(&conn); err != nil { + return nil, err + } + + b.db, err = sql.Open("mysql", conn) + if err != nil { + return nil, err + } + + // Set some connection pool settings. We don't need much of this, + // since the request rate shouldn't be high. + b.db.SetMaxOpenConns(2) + + return b.db, nil +} + +// ResetDB forces a connection next time DB() is called. +func (b *backend) ResetDB() { + b.lock.Lock() + defer b.lock.Unlock() + + if b.db != nil { + b.db.Close() + } + + b.db = nil +} + +// Lease returns the lease information +func (b *backend) Lease(s logical.Storage) (*configLease, error) { + entry, err := s.Get("config/lease") + if err != nil { + return nil, err + } + if entry == nil { + return nil, nil + } + + var result configLease + if err := entry.DecodeJSON(&result); err != nil { + return nil, err + } + + return &result, nil +} + +const backendHelp = ` +The MySQL backend dynamically generates database users. + +After mounting this backend, configure it using the endpoints within +the "config/" path. +` diff --git a/builtin/logical/mysql/path_config_connection.go b/builtin/logical/mysql/path_config_connection.go new file mode 100644 index 0000000000..786ee8c8bb --- /dev/null +++ b/builtin/logical/mysql/path_config_connection.go @@ -0,0 +1,73 @@ +package mysql + +import ( + "database/sql" + "fmt" + + _ "github.com/go-sql-driver/mysql" + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/logical/framework" +) + +func pathConfigConnection(b *backend) *framework.Path { + return &framework.Path{ + Pattern: "config/connection", + Fields: map[string]*framework.FieldSchema{ + "value": &framework.FieldSchema{ + Type: framework.TypeString, + Description: "DB connection string", + }, + }, + + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.WriteOperation: b.pathConnectionWrite, + }, + + HelpSynopsis: pathConfigConnectionHelpSyn, + HelpDescription: pathConfigConnectionHelpDesc, + } +} + +func (b *backend) pathConnectionWrite( + req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + connString := data.Get("value").(string) + + // Verify the string + db, err := sql.Open("mysql", connString) + if err != nil { + return logical.ErrorResponse(fmt.Sprintf( + "Error validating connection info: %s", err)), nil + } + defer db.Close() + if err := db.Ping(); err != nil { + return logical.ErrorResponse(fmt.Sprintf( + "Error validating connection info: %s", err)), nil + } + + // Store it + entry, err := logical.StorageEntryJSON("config/connection", connString) + if err != nil { + return nil, err + } + if err := req.Storage.Put(entry); err != nil { + return nil, err + } + + // Reset the DB connection + b.ResetDB() + return nil, nil +} + +const pathConfigConnectionHelpSyn = ` +Configure the connection string to talk to MySQL. +` + +const pathConfigConnectionHelpDesc = ` +This path configures the connection string used to connect to MySQL. +The value of the string is a Data Source Name (DSN). An example is +using "username:password@protocol(address)/dbname?param=value" + +For example, RDS may look like: "id:password@tcp(your-amazonaws-uri.com:3306)/dbname" + +When configuring the connection string, the backend will verify its validity. +` diff --git a/builtin/logical/mysql/path_config_lease.go b/builtin/logical/mysql/path_config_lease.go new file mode 100644 index 0000000000..d2f0950d71 --- /dev/null +++ b/builtin/logical/mysql/path_config_lease.go @@ -0,0 +1,83 @@ +package mysql + +import ( + "fmt" + "time" + + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/logical/framework" +) + +func pathConfigLease(b *backend) *framework.Path { + return &framework.Path{ + Pattern: "config/lease", + Fields: map[string]*framework.FieldSchema{ + "lease": &framework.FieldSchema{ + Type: framework.TypeString, + Description: "Default lease for roles.", + }, + + "lease_max": &framework.FieldSchema{ + Type: framework.TypeString, + Description: "Maximum time a credential is valid for.", + }, + }, + + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.WriteOperation: b.pathLeaseWrite, + }, + + HelpSynopsis: pathConfigLeaseHelpSyn, + HelpDescription: pathConfigLeaseHelpDesc, + } +} + +func (b *backend) pathLeaseWrite( + req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + leaseRaw := d.Get("lease").(string) + leaseMaxRaw := d.Get("lease").(string) + + lease, err := time.ParseDuration(leaseRaw) + if err != nil { + return logical.ErrorResponse(fmt.Sprintf( + "Invalid lease: %s", err)), nil + } + leaseMax, err := time.ParseDuration(leaseMaxRaw) + if err != nil { + return logical.ErrorResponse(fmt.Sprintf( + "Invalid lease: %s", err)), nil + } + + // Store it + entry, err := logical.StorageEntryJSON("config/lease", &configLease{ + Lease: lease, + LeaseMax: leaseMax, + }) + if err != nil { + return nil, err + } + if err := req.Storage.Put(entry); err != nil { + return nil, err + } + + return nil, nil +} + +type configLease struct { + Lease time.Duration + LeaseMax time.Duration +} + +const pathConfigLeaseHelpSyn = ` +Configure the default lease information for generated credentials. +` + +const pathConfigLeaseHelpDesc = ` +This configures the default lease information used for credentials +generated by this backend. The lease specifies the duration that a +credential will be valid for, as well as the maximum session for +a set of credentials. + +The format for the lease is "1h" or integer and then unit. The longest +unit is hour. +` diff --git a/builtin/logical/mysql/path_role_create.go b/builtin/logical/mysql/path_role_create.go new file mode 100644 index 0000000000..ece4dc5846 --- /dev/null +++ b/builtin/logical/mysql/path_role_create.go @@ -0,0 +1,112 @@ +package mysql + +import ( + "fmt" + "math/rand" + "time" + + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/logical/framework" + _ "github.com/lib/pq" +) + +func pathRoleCreate(b *backend) *framework.Path { + return &framework.Path{ + Pattern: `creds/(?P\w+)`, + Fields: map[string]*framework.FieldSchema{ + "name": &framework.FieldSchema{ + Type: framework.TypeString, + Description: "Name of the role.", + }, + }, + + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.ReadOperation: b.pathRoleCreateRead, + }, + + HelpSynopsis: pathRoleCreateReadHelpSyn, + HelpDescription: pathRoleCreateReadHelpDesc, + } +} + +func (b *backend) pathRoleCreateRead( + req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + name := data.Get("name").(string) + + // Get the role + role, err := b.Role(req.Storage, name) + if err != nil { + return nil, err + } + if role == nil { + return logical.ErrorResponse(fmt.Sprintf("unknown role: %s", name)), nil + } + + // Determine if we have a lease + lease, err := b.Lease(req.Storage) + if err != nil { + return nil, err + } + if lease == nil { + lease = &configLease{Lease: 1 * time.Hour} + } + + // Generate our username and password + username := fmt.Sprintf( + "vault-%s-%d-%d", + req.DisplayName, time.Now().Unix(), rand.Int31n(10000)) + password := generateUUID() + + // Get our connection + db, err := b.DB(req.Storage) + if err != nil { + return nil, err + } + + // Start a transaction + tx, err := db.Begin() + if err != nil { + return nil, err + } + defer tx.Rollback() + + // Execute each query + // Test the query by trying to prepare it + for _, query := range SplitSQL(role.SQL) { + stmt, err := db.Prepare(Query(query, map[string]string{ + "name": username, + "password": password, + })) + if err != nil { + return nil, err + } + if _, err := stmt.Exec(); err != nil { + return nil, err + } + } + + // Commit the transaction + if err := tx.Commit(); err != nil { + return nil, err + } + + // Return the secret + resp := b.Secret(SecretCredsType).Response(map[string]interface{}{ + "username": username, + "password": password, + }, map[string]interface{}{ + "username": username, + }) + resp.Secret.Lease = lease.Lease + return resp, nil +} + +const pathRoleCreateReadHelpSyn = ` +Request database credentials for a certain role. +` + +const pathRoleCreateReadHelpDesc = ` +This path reads database credentials for a certain role. The +database credentials will be generated on demand and will be automatically +revoked when the lease is up. +` diff --git a/builtin/logical/mysql/path_roles.go b/builtin/logical/mysql/path_roles.go new file mode 100644 index 0000000000..61370d2d24 --- /dev/null +++ b/builtin/logical/mysql/path_roles.go @@ -0,0 +1,144 @@ +package mysql + +import ( + "fmt" + + _ "github.com/go-sql-driver/mysql" + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/logical/framework" +) + +func pathRoles(b *backend) *framework.Path { + return &framework.Path{ + Pattern: "roles/(?P\\w+)", + Fields: map[string]*framework.FieldSchema{ + "name": &framework.FieldSchema{ + Type: framework.TypeString, + Description: "Name of the role.", + }, + + "sql": &framework.FieldSchema{ + Type: framework.TypeString, + Description: "SQL string to create a user. See help for more info.", + }, + }, + + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.ReadOperation: b.pathRoleRead, + logical.WriteOperation: b.pathRoleCreate, + logical.DeleteOperation: b.pathRoleDelete, + }, + + HelpSynopsis: pathRoleHelpSyn, + HelpDescription: pathRoleHelpDesc, + } +} + +func (b *backend) Role(s logical.Storage, n string) (*roleEntry, error) { + entry, err := s.Get("role/" + n) + if err != nil { + return nil, err + } + if entry == nil { + return nil, nil + } + + var result roleEntry + if err := entry.DecodeJSON(&result); err != nil { + return nil, err + } + + return &result, nil +} + +func (b *backend) pathRoleDelete( + req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + err := req.Storage.Delete("role/" + data.Get("name").(string)) + if err != nil { + return nil, err + } + + return nil, nil +} + +func (b *backend) pathRoleRead( + req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + role, err := b.Role(req.Storage, data.Get("name").(string)) + if err != nil { + return nil, err + } + if role == nil { + return nil, nil + } + + return &logical.Response{ + Data: map[string]interface{}{ + "sql": role.SQL, + }, + }, nil +} + +func (b *backend) pathRoleCreate( + req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + name := data.Get("name").(string) + sql := data.Get("sql").(string) + + // Get our connection + db, err := b.DB(req.Storage) + if err != nil { + return nil, err + } + + // Test the query by trying to prepare it + for _, query := range SplitSQL(sql) { + stmt, err := db.Prepare(Query(query, map[string]string{ + "name": "foo", + "password": "bar", + })) + if err != nil { + return logical.ErrorResponse(fmt.Sprintf( + "Error testing query: %s", err)), nil + } + stmt.Close() + } + + // Store it + entry, err := logical.StorageEntryJSON("role/"+name, &roleEntry{ + SQL: sql, + }) + if err != nil { + return nil, err + } + if err := req.Storage.Put(entry); err != nil { + return nil, err + } + return nil, nil +} + +type roleEntry struct { + SQL string `json:"sql"` +} + +const pathRoleHelpSyn = ` +Manage the roles that can be created with this backend. +` + +const pathRoleHelpDesc = ` +This path lets you manage the roles that can be created with this backend. + +The "sql" parameter customizes the SQL string used to create the role. +This can be a sequence of SQL queries, each semi-colon seperated. Some +substitution will be done to the SQL string for certain keys. +The names of the variables must be surrounded by "{{" and "}}" to be replaced. + + * "name" - The random username generated for the DB user. + + * "password" - The random password generated for the DB user. + +Example of a decent SQL query to use: + + CREATE USER "{{name}}" IDENTIFIED BY "{{password}}"; GRANT ALL ON db1.* TO "{{name}}"; + +Note the above user would be able to access anything in db1. Please see the MySQL +manual on the GRANT command to learn how to do more fine grained access. +` diff --git a/builtin/logical/mysql/secret_creds.go b/builtin/logical/mysql/secret_creds.go new file mode 100644 index 0000000000..fc1bae7854 --- /dev/null +++ b/builtin/logical/mysql/secret_creds.go @@ -0,0 +1,100 @@ +package mysql + +import ( + "fmt" + "time" + + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/logical/framework" +) + +const SecretCredsType = "creds" + +func secretCreds(b *backend) *framework.Secret { + return &framework.Secret{ + Type: SecretCredsType, + Fields: map[string]*framework.FieldSchema{ + "username": &framework.FieldSchema{ + Type: framework.TypeString, + Description: "Username", + }, + + "password": &framework.FieldSchema{ + Type: framework.TypeString, + Description: "Password", + }, + }, + + DefaultDuration: 1 * time.Hour, + DefaultGracePeriod: 10 * time.Minute, + + Renew: b.secretCredsRenew, + Revoke: b.secretCredsRevoke, + } +} + +func (b *backend) secretCredsRenew( + req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + // Get the lease information + lease, err := b.Lease(req.Storage) + if err != nil { + return nil, err + } + if lease == nil { + lease = &configLease{Lease: 1 * time.Hour} + } + + f := framework.LeaseExtend(lease.Lease, lease.LeaseMax) + return f(req, d) +} + +func (b *backend) secretCredsRevoke( + req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + // Get the username from the internal data + usernameRaw, ok := req.Secret.InternalData["username"] + if !ok { + return nil, fmt.Errorf("secret is missing username internal data") + } + username, ok := usernameRaw.(string) + + // Get our connection + db, err := b.DB(req.Storage) + if err != nil { + return nil, err + } + + // Start a transaction + tx, err := db.Begin() + if err != nil { + return nil, err + } + defer tx.Rollback() + + // Revoke all permissions for the user. This is done before the + // drop, because MySQL explicitly documents that open user connections + // will not be closed. By revoking all grants, at least we ensure + // that the open connection is useless. + stmt, err := tx.Prepare("REVOKE ALL PRIVILEGES, GRANT OPTION FROM ?") + if err != nil { + return nil, err + } + if _, err := stmt.Exec(username); err != nil { + return nil, err + } + + // Drop this user. This only affects the next connection, which is + // why we do the revoke initially. + stmt, err = db.Prepare("DROP USER ?") + if err != nil { + return nil, err + } + if _, err := stmt.Exec(username); err != nil { + return nil, err + } + + // Commit the transaction + if err := tx.Commit(); err != nil { + return nil, err + } + return nil, nil +} diff --git a/builtin/logical/mysql/util.go b/builtin/logical/mysql/util.go new file mode 100644 index 0000000000..57529d8f25 --- /dev/null +++ b/builtin/logical/mysql/util.go @@ -0,0 +1,44 @@ +package mysql + +import ( + "crypto/rand" + "fmt" + "strings" +) + +// SplitSQL is used to split a series of SQL statements +func SplitSQL(sql string) []string { + parts := strings.Split(sql, ";") + out := make([]string, 0, len(parts)) + for _, p := range parts { + clean := strings.TrimSpace(p) + if len(clean) > 0 { + out = append(out, clean) + } + } + return out +} + +// Query templates a query for us. +func Query(tpl string, data map[string]string) string { + for k, v := range data { + tpl = strings.Replace(tpl, fmt.Sprintf("{{%s}}", k), v, -1) + } + + return tpl +} + +// generateUUID is used to generate a random UUID +func generateUUID() string { + buf := make([]byte, 16) + if _, err := rand.Read(buf); err != nil { + panic(fmt.Errorf("failed to read random bytes: %v", err)) + } + + return fmt.Sprintf("%08x-%04x-%04x-%04x-%12x", + buf[0:4], + buf[4:6], + buf[6:8], + buf[8:10], + buf[10:16]) +}