diff --git a/builtin/logical/ssh/path_role_create.go b/builtin/logical/ssh/path_role_create.go index 9e082393e5..016715c925 100644 --- a/builtin/logical/ssh/path_role_create.go +++ b/builtin/logical/ssh/path_role_create.go @@ -103,14 +103,16 @@ func (b *backend) pathRoleCreateWrite( otkPublicKeyFileName := otkPrivateKeyFileName + ".pub" //commands to be run on vault server - rmCmd := "rm -f " + otkPrivateKeyFileName + " " + otkPublicKeyFileName + ";" - sshKeygenCmd := "ssh-keygen -f " + otkPrivateKeyFileName + " -t rsa -N ''" + ";" - chmodCmd := "chmod 400 " + otkPrivateKeyFileName + ";" + removeFile(otkPrivateKeyFileName) + removeFile(otkPublicKeyFileName) + dynamicPublicKey, dynamicPrivateKey, _ := generateRSAKeys() + ioutil.WriteFile(otkPrivateKeyFileName, []byte(dynamicPrivateKey), 0600) + ioutil.WriteFile(otkPublicKeyFileName, []byte(dynamicPublicKey), 0600) + //ioutil.WriteFile("testkey.pub", []byte(publicKeyRsa), 0600) + //sshKeygenCmd := "ssh-keygen -f " + otkPrivateKeyFileName + " -t rsa -N ''" + ";" + //chmodCmd := "chmod 600 " + otkPrivateKeyFileName + ";" scpCmd := "scp -i " + hostKeyFileName + " " + otkPublicKeyFileName + " " + username + "@" + ip + ":~;" localCmdString := strings.Join([]string{ - rmCmd, - sshKeygenCmd, - chmodCmd, scpCmd, }, "") //run the commands on vault server @@ -151,13 +153,13 @@ func (b *backend) pathRoleCreateWrite( if err != nil { fmt.Errorf("Failed to open '%s':%s", otkPrivateKeyFileName, err) } - dynamicPrivateKey := string(dynamicPrivateKeyBytes) + dynamicPrivateKey = string(dynamicPrivateKeyBytes) dynamicPublicKeyBytes, err := ioutil.ReadFile(otkPublicKeyFileName) if err != nil { fmt.Errorf("Failed to open '%s':%s", otkPublicKeyFileName, err) } - dynamicPublicKey := string(dynamicPublicKeyBytes) + dynamicPublicKey = string(dynamicPublicKeyBytes) return b.Secret(SecretOneTimeKeyType).Response(map[string]interface{}{ "key": dynamicPrivateKey, }, map[string]interface{}{ diff --git a/builtin/logical/ssh/ssh_util.go b/builtin/logical/ssh/ssh_util.go deleted file mode 100644 index bfa9aa489e..0000000000 --- a/builtin/logical/ssh/ssh_util.go +++ /dev/null @@ -1,44 +0,0 @@ -package ssh - -import ( - "fmt" - "os/exec" - - "golang.org/x/crypto/ssh" -) - -func exec_command(cmdString string) error { - cmd := exec.Command("/bin/bash", "-c", cmdString) - if _, err := cmd.Output(); err != nil { - return err - } - return nil -} - -func createSSHPublicKeysSession(username string, ipAddr string, hostKey string) *ssh.Session { - signer, err := ssh.ParsePrivateKey([]byte(hostKey)) - if err != nil { - fmt.Errorf("Parsing Private Key failed: " + err.Error()) - } - - config := &ssh.ClientConfig{ - User: username, - Auth: []ssh.AuthMethod{ - ssh.PublicKeys(signer), - }, - } - - client, err := ssh.Dial("tcp", ipAddr+":22", config) - if err != nil { - fmt.Errorf("Dial Failed: " + err.Error()) - } - if client == nil { - fmt.Errorf("SSH Dial to target failed: ", err.Error()) - } - - session, err := client.NewSession() - if err != nil { - fmt.Errorf("NewSession failed: " + err.Error()) - } - return session -} diff --git a/builtin/logical/ssh/util.go b/builtin/logical/ssh/util.go new file mode 100644 index 0000000000..af1856d09d --- /dev/null +++ b/builtin/logical/ssh/util.go @@ -0,0 +1,93 @@ +package ssh + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/base64" + "encoding/pem" + "fmt" + "log" + "os" + "os/exec" + + "golang.org/x/crypto/ssh" +) + +func exec_command(cmdString string) error { + cmd := exec.Command("/bin/bash", "-c", cmdString) + if _, err := cmd.Output(); err != nil { + return err + } + return nil +} + +func createSSHPublicKeysSession(username string, ipAddr string, hostKey string) *ssh.Session { + signer, err := ssh.ParsePrivateKey([]byte(hostKey)) + if err != nil { + fmt.Errorf("Parsing Private Key failed: " + err.Error()) + } + + config := &ssh.ClientConfig{ + User: username, + Auth: []ssh.AuthMethod{ + ssh.PublicKeys(signer), + }, + } + + client, err := ssh.Dial("tcp", ipAddr+":22", config) + if err != nil { + fmt.Errorf("Dial Failed: " + err.Error()) + } + if client == nil { + fmt.Errorf("SSH Dial to target failed: ", err.Error()) + } + + session, err := client.NewSession() + if err != nil { + fmt.Errorf("NewSession failed: " + err.Error()) + } + return session +} + +func removeFile(fileName string) { + wd, err := os.Getwd() + if err != nil { + log.Printf("Error fetching working directory:%s", err) + return + } + absFileName := wd + "/" + fileName + + if _, err := os.Stat(absFileName); err == nil { + err := os.Remove(absFileName) + if err != nil { + log.Printf(fmt.Sprintf("Failed: %s", err)) + return + } else { + log.Printf("Successful\n") + } + } +} + +func generateRSAKeys() (string, string, error) { + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return "", "", fmt.Errorf("error generating RSA key-pair: %s", err) + } + + privateKeyRsa := string(pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(privateKey), + })) + + sshPublicKey, err := ssh.NewPublicKey(privateKey.Public()) + if err != nil { + return "", "", fmt.Errorf("error generating RSA key-pair: %s", err) + } + publicKeyRsa := "ssh-rsa " + base64.StdEncoding.EncodeToString(sshPublicKey.Marshal()) + + //ioutil.WriteFile("testkey.pem", []byte(privateKeyRsa), 0600) + //ioutil.WriteFile("testkey.pub", []byte(publicKeyRsa), 0600) + + return publicKeyRsa, privateKeyRsa, nil +} diff --git a/command/ssh.go b/command/ssh.go index fd22f1b4dc..4c89ce5b0f 100644 --- a/command/ssh.go +++ b/command/ssh.go @@ -19,20 +19,23 @@ func (c *SshCommand) Run(args []string) int { log.SetFlags(log.LstdFlags | log.Lshortfile) log.Printf("Vishal: SshCommand.Run: args:%#v len(args):%d\n", args, len(args)) flags := c.Meta.FlagSet("ssh", FlagSetDefault) + var role string + flags.StringVar(&role, "role", "", "") flags.Usage = func() { c.Ui.Error(c.Help()) } if err := flags.Parse(args); err != nil { return 1 } - + log.Printf("Vishal: Role:%s\n", role) + args = flags.Args() + if len(args) < 1 { + c.Ui.Error("ssh expects at least one argument") + return 2 + } client, err := c.Client() if err != nil { c.Ui.Error(fmt.Sprintf("Error initializing client: %s", err)) return 2 } - if len(args) < 1 { - c.Ui.Error(fmt.Sprintf("Insufficient arguments")) - return 2 - } log.Printf("Vishal: sshCommand.Run: args[0]: %#v\n", args[0]) input := strings.Split(args[0], "@") username := input[0]