diff --git a/helper/duration/duration.go b/helper/duration/duration.go index e3c5b683c8..f6bede0dd8 100644 --- a/helper/duration/duration.go +++ b/helper/duration/duration.go @@ -6,23 +6,23 @@ import ( "time" ) -func ParseDurationSecond(inp string) (int, error) { - var result int +func ParseDurationSecond(inp string) (time.Duration, error) { + var err error + var dur time.Duration // Look for a suffix otherwise its a plain second value if strings.HasSuffix(inp, "s") || strings.HasSuffix(inp, "m") || strings.HasSuffix(inp, "h") { - dur, err := time.ParseDuration(inp) + dur, err = time.ParseDuration(inp) if err != nil { - return result, err + return dur, err } - result = int(dur.Seconds()) } else { // Plain integer - val, err := strconv.ParseInt(inp, 10, 64) + secs, err := strconv.ParseInt(inp, 10, 64) if err != nil { - return result, err + return dur, err } - result = int(val) + dur = time.Duration(secs) * time.Second } - return result, nil + return dur, nil } diff --git a/helper/duration/duration_test.go b/helper/duration/duration_test.go index 9d3124ff45..02f0706eb0 100644 --- a/helper/duration/duration_test.go +++ b/helper/duration/duration_test.go @@ -1,20 +1,23 @@ package duration -import "testing" +import ( + "testing" + "time" +) func Test_ParseDurationSecond(t *testing.T) { outp, err := ParseDurationSecond("9876s") if err != nil { t.Fatal(err) } - if outp != 9876 { + if outp != time.Duration(9876)*time.Second { t.Fatal("not equivalent") } outp, err = ParseDurationSecond("9876") if err != nil { t.Fatal(err) } - if outp != 9876 { + if outp != time.Duration(9876)*time.Second { t.Fatal("not equivalent") } } diff --git a/http/handler.go b/http/handler.go index aef85e9756..3de72789a3 100644 --- a/http/handler.go +++ b/http/handler.go @@ -6,11 +6,10 @@ import ( "io" "net/http" "net/url" - "strconv" "strings" - "time" "github.com/hashicorp/errwrap" + "github.com/hashicorp/vault/helper/duration" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/vault" ) @@ -167,23 +166,14 @@ func requestWrapTTL(r *http.Request, req *logical.Request) (*logical.Request, er } // If it has an allowed suffix parse as a duration string - if strings.HasSuffix(wrapTTL, "s") || strings.HasSuffix(wrapTTL, "m") || strings.HasSuffix(wrapTTL, "h") { - dur, err := time.ParseDuration(wrapTTL) - if err != nil { - return req, err - } - req.WrapTTL = dur - } else { - // Parse as a straight number of seconds - seconds, err := strconv.ParseInt(wrapTTL, 10, 64) - if err != nil { - return req, err - } - req.WrapTTL = time.Duration(seconds) * time.Second + dur, err := duration.ParseDurationSecond(wrapTTL) + if err != nil { + return req, err } - if int64(req.WrapTTL) < 0 { + if int64(dur) < 0 { return req, fmt.Errorf("requested wrap ttl cannot be negative") } + req.WrapTTL = dur return req, nil } diff --git a/logical/framework/field_data.go b/logical/framework/field_data.go index 00e0f3122a..3bbd16fbc1 100644 --- a/logical/framework/field_data.go +++ b/logical/framework/field_data.go @@ -150,7 +150,6 @@ func (d *FieldData) getPrimitive( case TypeDurationSecond: var result int - var err error switch inp := raw.(type) { case nil: return nil, false, nil @@ -161,10 +160,11 @@ func (d *FieldData) getPrimitive( case float64: result = int(inp) case string: - result, err = duration.ParseDurationSecond(inp) + dur, err := duration.ParseDurationSecond(inp) if err != nil { return nil, true, err } + result = int(dur.Seconds()) default: return nil, false, fmt.Errorf("invalid input '%v'", raw) diff --git a/vault/logical_passthrough.go b/vault/logical_passthrough.go index b1fe8107dd..1d176e697b 100644 --- a/vault/logical_passthrough.go +++ b/vault/logical_passthrough.go @@ -3,10 +3,9 @@ package vault import ( "encoding/json" "fmt" - "strconv" "strings" - "time" + "github.com/hashicorp/vault/helper/duration" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" ) @@ -132,19 +131,9 @@ func (b *PassthroughBackend) handleRead( } ttlDuration := b.System().DefaultLeaseTTL() if len(ttl) != 0 { - - // Parse as a duration string if it has an appropriate suffix - if strings.HasSuffix(ttl, "s") || strings.HasSuffix(ttl, "m") || strings.HasSuffix(ttl, "h") { - dur, err := time.ParseDuration(ttl) - if err == nil { - ttlDuration = dur - } - } else { - // Parse as a straight number of seconds - seconds, err := strconv.ParseInt(ttl, 10, 64) - if err == nil { - ttlDuration = time.Duration(seconds) * time.Second - } + dur, err := duration.ParseDurationSecond(ttl) + if err == nil { + ttlDuration = dur } if b.generateLeases { diff --git a/vault/logical_system.go b/vault/logical_system.go index c3caaad8df..3b9a392ef9 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -709,7 +709,7 @@ func (b *SystemBackend) handleMount( "unable to parse default TTL of %s: %s", apiConfig.DefaultLeaseTTL, err)), logical.ErrInvalidRequest } - config.DefaultLeaseTTL = time.Duration(tmpDef) * time.Second + config.DefaultLeaseTTL = tmpDef } switch apiConfig.MaxLeaseTTL { @@ -722,7 +722,7 @@ func (b *SystemBackend) handleMount( "unable to parse max TTL of %s: %s", apiConfig.MaxLeaseTTL, err)), logical.ErrInvalidRequest } - config.MaxLeaseTTL = time.Duration(tmpMax) * time.Second + config.MaxLeaseTTL = tmpMax } if config.MaxLeaseTTL != 0 && config.DefaultLeaseTTL > config.MaxLeaseTTL { @@ -932,8 +932,7 @@ func (b *SystemBackend) handleTuneWriteCommon( if err != nil { return handleError(err) } - tmpDurDef := time.Duration(tmpDef) * time.Second - newDefault = &tmpDurDef + newDefault = &tmpDef } maxTTL := data.Get("max_lease_ttl").(string) @@ -947,8 +946,7 @@ func (b *SystemBackend) handleTuneWriteCommon( if err != nil { return handleError(err) } - tmpDurMax := time.Duration(tmpMax) * time.Second - newMax = &tmpDurMax + newMax = &tmpMax } if newDefault != nil || newMax != nil {