From ff5a0cc8515638806f84bb726520d643737c7ce9 Mon Sep 17 00:00:00 2001 From: Michael Eischer Date: Sun, 14 Sep 2025 19:21:51 +0200 Subject: [PATCH] termstatus: fully wrap reading password from terminal --- cmd/restic/global.go | 25 ++----------------------- cmd/restic/global_test.go | 11 ----------- internal/terminal/password.go | 15 +++++++-------- internal/ui/mock.go | 9 ++++++++- internal/ui/terminal.go | 6 +++++- internal/ui/termstatus/status.go | 23 +++++++++++++++++++++++ internal/ui/termstatus/status_test.go | 11 +++++++++++ 7 files changed, 56 insertions(+), 44 deletions(-) diff --git a/cmd/restic/global.go b/cmd/restic/global.go index b31c37515..0170830f0 100644 --- a/cmd/restic/global.go +++ b/cmd/restic/global.go @@ -1,7 +1,6 @@ package main import ( - "bufio" "context" "fmt" "io" @@ -32,7 +31,6 @@ import ( "github.com/restic/restic/internal/options" "github.com/restic/restic/internal/repository" "github.com/restic/restic/internal/restic" - "github.com/restic/restic/internal/terminal" "github.com/restic/restic/internal/textfile" "github.com/restic/restic/internal/ui" "github.com/restic/restic/internal/ui/progress" @@ -232,14 +230,6 @@ func loadPasswordFromFile(pwdFile string) (string, error) { return strings.TrimSpace(string(s)), errors.Wrap(err, "Readfile") } -// readPassword reads the password from the given reader directly. -func readPassword(in io.Reader) (password string, err error) { - sc := bufio.NewScanner(in) - sc.Scan() - - return sc.Text(), errors.WithStack(sc.Err()) -} - // ReadPassword reads the password from a password file, the environment // variable RESTIC_PASSWORD or prompts the user. If the context is canceled, // the function leaks the password reading goroutine. @@ -255,20 +245,9 @@ func ReadPassword(ctx context.Context, gopts GlobalOptions, prompt string, print return gopts.password, nil } - var ( - password string - err error - ) - - if gopts.term.InputIsTerminal() { - password, err = terminal.ReadPassword(ctx, os.Stdin, os.Stderr, prompt) - } else { - printer.PT("reading repository password from stdin") - password, err = readPassword(os.Stdin) - } - + password, err := gopts.term.ReadPassword(ctx, prompt) if err != nil { - return "", errors.Wrap(err, "unable to read password") + return "", fmt.Errorf("unable to read password: %w", err) } if len(password) == 0 { diff --git a/cmd/restic/global_test.go b/cmd/restic/global_test.go index 173a7a2a8..de8275876 100644 --- a/cmd/restic/global_test.go +++ b/cmd/restic/global_test.go @@ -7,21 +7,10 @@ import ( "strings" "testing" - "github.com/restic/restic/internal/errors" rtest "github.com/restic/restic/internal/test" "github.com/restic/restic/internal/ui/progress" ) -type errorReader struct{ err error } - -func (r *errorReader) Read([]byte) (int, error) { return 0, r.err } - -func TestReadPassword(t *testing.T) { - want := errors.New("foo") - _, err := readPassword(&errorReader{want}) - rtest.Assert(t, errors.Is(err, want), "wrong error %v", err) -} - func TestReadRepo(t *testing.T) { tempDir := rtest.TempDir(t) diff --git a/internal/terminal/password.go b/internal/terminal/password.go index 6d1b6c912..675344f77 100644 --- a/internal/terminal/password.go +++ b/internal/terminal/password.go @@ -3,7 +3,7 @@ package terminal import ( "context" "fmt" - "os" + "io" "golang.org/x/term" ) @@ -12,11 +12,10 @@ import ( // tty. Prompt is printed on the writer out before attempting to read the // password. If the context is canceled, the function leaks the password reading // goroutine. -func ReadPassword(ctx context.Context, in *os.File, out *os.File, prompt string) (password string, err error) { - fd := int(out.Fd()) - state, err := term.GetState(fd) +func ReadPassword(ctx context.Context, inFd int, out io.Writer, prompt string) (password string, err error) { + state, err := term.GetState(inFd) if err != nil { - _, _ = fmt.Fprintf(os.Stderr, "unable to get terminal state: %v\n", err) + _, _ = fmt.Fprintf(out, "unable to get terminal state: %v\n", err) return "", err } @@ -29,7 +28,7 @@ func ReadPassword(ctx context.Context, in *os.File, out *os.File, prompt string) if err != nil { return } - buf, err = term.ReadPassword(int(in.Fd())) + buf, err = term.ReadPassword(inFd) if err != nil { return } @@ -38,9 +37,9 @@ func ReadPassword(ctx context.Context, in *os.File, out *os.File, prompt string) select { case <-ctx.Done(): - err := term.Restore(fd, state) + err := term.Restore(inFd, state) if err != nil { - _, _ = fmt.Fprintf(os.Stderr, "unable to restore terminal state: %v\n", err) + _, _ = fmt.Fprintf(out, "unable to restore terminal state: %v\n", err) } return "", ctx.Err() case <-done: diff --git a/internal/ui/mock.go b/internal/ui/mock.go index 70a95fe1b..edc9050f9 100644 --- a/internal/ui/mock.go +++ b/internal/ui/mock.go @@ -1,6 +1,9 @@ package ui -import "io" +import ( + "context" + "io" +) var _ Terminal = &MockTerminal{} @@ -33,6 +36,10 @@ func (m *MockTerminal) InputIsTerminal() bool { return true } +func (m *MockTerminal) ReadPassword(_ context.Context, _ string) (string, error) { + return "password", nil +} + func (m *MockTerminal) OutputRaw() io.Writer { return nil } diff --git a/internal/ui/terminal.go b/internal/ui/terminal.go index 8ff5d6f27..c53de7bf2 100644 --- a/internal/ui/terminal.go +++ b/internal/ui/terminal.go @@ -1,6 +1,9 @@ package ui -import "io" +import ( + "context" + "io" +) // Terminal is used to write messages and display status lines which can be // updated. See termstatus.Terminal for a concrete implementation. @@ -15,6 +18,7 @@ type Terminal interface { CanUpdateStatus() bool InputRaw() io.ReadCloser InputIsTerminal() bool + ReadPassword(ctx context.Context, prompt string) (string, error) // OutputRaw returns the output writer. Should only be used if there is no // other option. Must not be used in combination with Print, Error, SetStatus // or any other method that writes to the terminal. diff --git a/internal/ui/termstatus/status.go b/internal/ui/termstatus/status.go index be3a3ce59..5ee21eb37 100644 --- a/internal/ui/termstatus/status.go +++ b/internal/ui/termstatus/status.go @@ -1,6 +1,7 @@ package termstatus import ( + "bufio" "context" "fmt" "io" @@ -18,6 +19,7 @@ var _ ui.Terminal = &Terminal{} // printed. type Terminal struct { rd io.ReadCloser + inFd uintptr wr io.Writer fd uintptr errWriter io.Writer @@ -100,6 +102,7 @@ func New(rd io.ReadCloser, wr io.Writer, errWriter io.Writer, disableStatus bool if d, ok := rd.(fder); ok { if terminal.InputIsTerminal(d.Fd()) { + t.inFd = d.Fd() t.inputIsTerminal = true } } @@ -130,6 +133,26 @@ func (t *Terminal) InputRaw() io.ReadCloser { return t.rd } +func (t *Terminal) ReadPassword(ctx context.Context, prompt string) (string, error) { + if t.InputIsTerminal() { + return terminal.ReadPassword(ctx, int(t.inFd), t.errWriter, prompt) + } + if t.OutputIsTerminal() { + t.Print("reading repository password from stdin") + } + return readPassword(t.rd) +} + +// readPassword reads the password from the given reader directly. +func readPassword(in io.Reader) (password string, err error) { + sc := bufio.NewScanner(in) + sc.Scan() + if sc.Err() != nil { + return "", fmt.Errorf("readPassword: %w", sc.Err()) + } + return sc.Text(), nil +} + // CanUpdateStatus return whether the status output is updated in place. func (t *Terminal) CanUpdateStatus() bool { return t.canUpdateStatus diff --git a/internal/ui/termstatus/status_test.go b/internal/ui/termstatus/status_test.go index 064b02989..bddb7c5d1 100644 --- a/internal/ui/termstatus/status_test.go +++ b/internal/ui/termstatus/status_test.go @@ -3,6 +3,7 @@ package termstatus import ( "bytes" "context" + "errors" "fmt" "io" "testing" @@ -76,3 +77,13 @@ func TestSanitizeLines(t *testing.T) { }) } } + +type errorReader struct{ err error } + +func (r *errorReader) Read([]byte) (int, error) { return 0, r.err } + +func TestReadPassword(t *testing.T) { + want := errors.New("foo") + _, err := readPassword(&errorReader{want}) + rtest.Assert(t, errors.Is(err, want), "wrong error %v", err) +}