Return a duration instead and port a few other places to use it

This commit is contained in:
Jeff Mitchell 2016-07-11 18:19:35 +00:00
parent 984641af21
commit 58efdcba47
6 changed files with 31 additions and 51 deletions

View file

@ -6,23 +6,23 @@ import (
"time"
)
func ParseDurationSecond(inp string) (int, error) {
var result int
func ParseDurationSecond(inp string) (time.Duration, error) {
var err error
var dur time.Duration
// Look for a suffix otherwise its a plain second value
if strings.HasSuffix(inp, "s") || strings.HasSuffix(inp, "m") || strings.HasSuffix(inp, "h") {
dur, err := time.ParseDuration(inp)
dur, err = time.ParseDuration(inp)
if err != nil {
return result, err
return dur, err
}
result = int(dur.Seconds())
} else {
// Plain integer
val, err := strconv.ParseInt(inp, 10, 64)
secs, err := strconv.ParseInt(inp, 10, 64)
if err != nil {
return result, err
return dur, err
}
result = int(val)
dur = time.Duration(secs) * time.Second
}
return result, nil
return dur, nil
}

View file

@ -1,20 +1,23 @@
package duration
import "testing"
import (
"testing"
"time"
)
func Test_ParseDurationSecond(t *testing.T) {
outp, err := ParseDurationSecond("9876s")
if err != nil {
t.Fatal(err)
}
if outp != 9876 {
if outp != time.Duration(9876)*time.Second {
t.Fatal("not equivalent")
}
outp, err = ParseDurationSecond("9876")
if err != nil {
t.Fatal(err)
}
if outp != 9876 {
if outp != time.Duration(9876)*time.Second {
t.Fatal("not equivalent")
}
}

View file

@ -6,11 +6,10 @@ import (
"io"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/vault/helper/duration"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/vault"
)
@ -167,23 +166,14 @@ func requestWrapTTL(r *http.Request, req *logical.Request) (*logical.Request, er
}
// If it has an allowed suffix parse as a duration string
if strings.HasSuffix(wrapTTL, "s") || strings.HasSuffix(wrapTTL, "m") || strings.HasSuffix(wrapTTL, "h") {
dur, err := time.ParseDuration(wrapTTL)
if err != nil {
return req, err
}
req.WrapTTL = dur
} else {
// Parse as a straight number of seconds
seconds, err := strconv.ParseInt(wrapTTL, 10, 64)
if err != nil {
return req, err
}
req.WrapTTL = time.Duration(seconds) * time.Second
dur, err := duration.ParseDurationSecond(wrapTTL)
if err != nil {
return req, err
}
if int64(req.WrapTTL) < 0 {
if int64(dur) < 0 {
return req, fmt.Errorf("requested wrap ttl cannot be negative")
}
req.WrapTTL = dur
return req, nil
}

View file

@ -150,7 +150,6 @@ func (d *FieldData) getPrimitive(
case TypeDurationSecond:
var result int
var err error
switch inp := raw.(type) {
case nil:
return nil, false, nil
@ -161,10 +160,11 @@ func (d *FieldData) getPrimitive(
case float64:
result = int(inp)
case string:
result, err = duration.ParseDurationSecond(inp)
dur, err := duration.ParseDurationSecond(inp)
if err != nil {
return nil, true, err
}
result = int(dur.Seconds())
default:
return nil, false, fmt.Errorf("invalid input '%v'", raw)

View file

@ -3,10 +3,9 @@ package vault
import (
"encoding/json"
"fmt"
"strconv"
"strings"
"time"
"github.com/hashicorp/vault/helper/duration"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
@ -132,19 +131,9 @@ func (b *PassthroughBackend) handleRead(
}
ttlDuration := b.System().DefaultLeaseTTL()
if len(ttl) != 0 {
// Parse as a duration string if it has an appropriate suffix
if strings.HasSuffix(ttl, "s") || strings.HasSuffix(ttl, "m") || strings.HasSuffix(ttl, "h") {
dur, err := time.ParseDuration(ttl)
if err == nil {
ttlDuration = dur
}
} else {
// Parse as a straight number of seconds
seconds, err := strconv.ParseInt(ttl, 10, 64)
if err == nil {
ttlDuration = time.Duration(seconds) * time.Second
}
dur, err := duration.ParseDurationSecond(ttl)
if err == nil {
ttlDuration = dur
}
if b.generateLeases {

View file

@ -709,7 +709,7 @@ func (b *SystemBackend) handleMount(
"unable to parse default TTL of %s: %s", apiConfig.DefaultLeaseTTL, err)),
logical.ErrInvalidRequest
}
config.DefaultLeaseTTL = time.Duration(tmpDef) * time.Second
config.DefaultLeaseTTL = tmpDef
}
switch apiConfig.MaxLeaseTTL {
@ -722,7 +722,7 @@ func (b *SystemBackend) handleMount(
"unable to parse max TTL of %s: %s", apiConfig.MaxLeaseTTL, err)),
logical.ErrInvalidRequest
}
config.MaxLeaseTTL = time.Duration(tmpMax) * time.Second
config.MaxLeaseTTL = tmpMax
}
if config.MaxLeaseTTL != 0 && config.DefaultLeaseTTL > config.MaxLeaseTTL {
@ -932,8 +932,7 @@ func (b *SystemBackend) handleTuneWriteCommon(
if err != nil {
return handleError(err)
}
tmpDurDef := time.Duration(tmpDef) * time.Second
newDefault = &tmpDurDef
newDefault = &tmpDef
}
maxTTL := data.Get("max_lease_ttl").(string)
@ -947,8 +946,7 @@ func (b *SystemBackend) handleTuneWriteCommon(
if err != nil {
return handleError(err)
}
tmpDurMax := time.Duration(tmpMax) * time.Second
newMax = &tmpDurMax
newMax = &tmpMax
}
if newDefault != nil || newMax != nil {