diff --git a/http/sys_health.go b/http/sys_health.go index 73b85688c5..6e7c0b92c1 100644 --- a/http/sys_health.go +++ b/http/sys_health.go @@ -20,6 +20,19 @@ func handleSysHealth(core *vault.Core) http.Handler { }) } +func fetchStatusCode(r *http.Request, field string) (int, bool, bool) { + var err error + statusCode := http.StatusOK + if statusCodeStr, statusCodeOk := r.URL.Query()[field]; statusCodeOk { + statusCode, err = strconv.Atoi(statusCodeStr[0]) + if err != nil || len(statusCodeStr) < 1 { + return http.StatusBadRequest, false, false + } + return statusCode, true, true + } + return statusCode, false, true +} + func handleSysHealthGet(core *vault.Core, w http.ResponseWriter, r *http.Request) { // Check if being a standby is allowed for the purpose of a 200 OK @@ -28,47 +41,27 @@ func handleSysHealthGet(core *vault.Core, w http.ResponseWriter, r *http.Request // FIXME: Change the sealed code to http.StatusServiceUnavailable at some // point sealedCode := http.StatusInternalServerError + if code, found, ok := fetchStatusCode(r, "sealedcode"); !ok { + respondError(w, http.StatusBadRequest, nil) + return + } else if found { + sealedCode = code + } + standbyCode := http.StatusTooManyRequests // Consul warning code + if code, found, ok := fetchStatusCode(r, "standbycode"); !ok { + respondError(w, http.StatusBadRequest, nil) + return + } else if found { + standbyCode = code + } + activeCode := http.StatusOK - - var err error - sealedCodeStr, sealedCodeOk := r.URL.Query()["sealedcode"] - if sealedCodeOk { - if len(sealedCodeStr) < 1 { - respondError(w, http.StatusBadRequest, nil) - return - } - sealedCode, err = strconv.Atoi(sealedCodeStr[0]) - if err != nil { - respondError(w, http.StatusBadRequest, nil) - return - } - } - standbyCodeStr, standbyCodeOk := r.URL.Query()["standbycode"] - if standbyCodeOk { - if len(standbyCodeStr) < 1 { - respondError(w, http.StatusBadRequest, nil) - return - } - standbyCode, err = strconv.Atoi(standbyCodeStr[0]) - if err != nil { - respondError(w, http.StatusBadRequest, nil) - return - } - } - - activeCodeStr, activeCodeOk := r.URL.Query()["activecode"] - if activeCodeOk { - if len(activeCodeStr) < 1 { - respondError(w, http.StatusBadRequest, nil) - return - } - - activeCode, err = strconv.Atoi(activeCodeStr[0]) - if err != nil { - respondError(w, http.StatusBadRequest, nil) - return - } + if code, found, ok := fetchStatusCode(r, "activecode"); !ok { + respondError(w, http.StatusBadRequest, nil) + return + } else if found { + activeCode = code } // Check system status