diff --git a/pkg/types/state_type.go b/pkg/types/state_type.go index e2539ba9..1f45686e 100644 --- a/pkg/types/state_type.go +++ b/pkg/types/state_type.go @@ -3,8 +3,8 @@ package types import ( "database/sql/driver" "encoding" + "encoding/json" "fmt" - "strconv" ) // StateType specifies a state's hardness. @@ -12,21 +12,19 @@ type StateType uint8 // UnmarshalText implements the encoding.TextUnmarshaler interface. func (st *StateType) UnmarshalText(bytes []byte) error { - text := string(bytes) + return st.UnmarshalJSON(bytes) +} - i, err := strconv.ParseUint(text, 10, 64) - if err != nil { +// UnmarshalJSON implements the json.Unmarshaler interface. +func (st *StateType) UnmarshalJSON(data []byte) error { + var i uint8 + if err := json.Unmarshal(data, &i); err != nil { return err } s := StateType(i) - if uint64(s) != i { - // Truncated due to above cast, obviously too high - return BadStateType{text} - } - if _, ok := stateTypes[s]; !ok { - return BadStateType{text} + return BadStateType{data} } *st = s @@ -62,5 +60,6 @@ var stateTypes = map[StateType]string{ var ( _ error = BadStateType{} _ encoding.TextUnmarshaler = (*StateType)(nil) + _ json.Unmarshaler = (*StateType)(nil) _ driver.Valuer = StateType(0) )