diff --git a/documentation/examples/remote_storage/remote_storage_adapter/main.go b/documentation/examples/remote_storage/remote_storage_adapter/main.go index fc3da02c6f..dd131a1101 100644 --- a/documentation/examples/remote_storage/remote_storage_adapter/main.go +++ b/documentation/examples/remote_storage/remote_storage_adapter/main.go @@ -16,7 +16,6 @@ package main import ( "fmt" - "io" "log/slog" "net/http" _ "net/http/pprof" @@ -203,8 +202,6 @@ func buildClients(logger *slog.Logger, cfg *config) ([]writer, []reader) { } func serve(logger *slog.Logger, addr string, writers []writer, readers []reader) error { - bodyLimit := int64(32 * 1024 * 1024) - http.HandleFunc("/write", func(w http.ResponseWriter, r *http.Request) { req, err := remote.DecodeWriteRequest(r.Body) if err != nil { @@ -228,38 +225,13 @@ func serve(logger *slog.Logger, addr string, writers []writer, readers []reader) }) http.HandleFunc("/read", func(w http.ResponseWriter, r *http.Request) { - compressed, err := io.ReadAll(io.LimitReader(r.Body, bodyLimit)) - if err != nil { - logger.Error("Read error", "err", err.Error()) - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - if decodedLen, err := snappy.DecodedLen(compressed); err != nil { - logger.Error("Decode error", "err", err.Error()) - http.Error(w, err.Error(), http.StatusBadRequest) - return - } else if int64(decodedLen) > bodyLimit { - err := fmt.Errorf("decoded read request too large (>%d bytes)", bodyLimit) - logger.Error("Decode error", "err", err.Error()) - http.Error(w, err.Error(), http.StatusRequestEntityTooLarge) - return - } - - reqBuf, err := snappy.Decode(nil, compressed) + req, err := remote.DecodeReadRequest(r) if err != nil { logger.Error("Decode error", "err", err.Error()) http.Error(w, err.Error(), http.StatusBadRequest) return } - var req prompb.ReadRequest - if err := proto.Unmarshal(reqBuf, &req); err != nil { - logger.Error("Unmarshal error", "err", err.Error()) - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - // TODO: Support reading from more than one reader and merging the results. if len(readers) != 1 { http.Error(w, fmt.Sprintf("expected exactly one reader, found %d readers", len(readers)), http.StatusInternalServerError) @@ -267,8 +239,7 @@ func serve(logger *slog.Logger, addr string, writers []writer, readers []reader) } reader := readers[0] - var resp *prompb.ReadResponse - resp, err = reader.Read(&req) + resp, err := reader.Read(req) if err != nil { logger.Warn("Error executing query", "query", req, "storage", reader.Name(), "err", err) http.Error(w, err.Error(), http.StatusInternalServerError) @@ -284,7 +255,7 @@ func serve(logger *slog.Logger, addr string, writers []writer, readers []reader) w.Header().Set("Content-Type", "application/x-protobuf") w.Header().Set("Content-Encoding", "snappy") - compressed = snappy.Encode(nil, data) + compressed := snappy.Encode(nil, data) if _, err := w.Write(compressed); err != nil { logger.Warn("Error writing response", "storage", reader.Name(), "err", err) } diff --git a/storage/remote/client.go b/storage/remote/client.go index e8ab7c730e..f56e8f36f0 100644 --- a/storage/remote/client.go +++ b/storage/remote/client.go @@ -452,22 +452,12 @@ func (c *Client) handleReadResponse(httpResp *http.Response, req *prompb.ReadReq } func (*Client) handleSampledResponse(req *prompb.ReadRequest, httpResp *http.Response, sortSeries bool) (storage.SeriesSet, error) { - compressed, err := io.ReadAll(httpResp.Body) - if err != nil { - return nil, fmt.Errorf("error reading response. HTTP status code: %s: %w", httpResp.Status, err) - } defer func() { _, _ = io.Copy(io.Discard, httpResp.Body) _ = httpResp.Body.Close() }() - if decodedLen, err := snappy.DecodedLen(compressed); err != nil { - return nil, fmt.Errorf("error reading response: %w", err) - } else if decodedLen > decodeReadLimit { - return nil, fmt.Errorf("decoded remote read response too large (>%d bytes)", decodeReadLimit) - } - - uncompressed, err := snappy.Decode(nil, compressed) + uncompressed, err := decodeSnappyWithLimit(httpResp.Body, decodeReadLimit, "remote read response") if err != nil { return nil, fmt.Errorf("error reading response: %w", err) } diff --git a/storage/remote/codec.go b/storage/remote/codec.go index 5b99c1e96f..dae003abe2 100644 --- a/storage/remote/codec.go +++ b/storage/remote/codec.go @@ -62,21 +62,31 @@ func (e HTTPError) Status() int { return e.status } -// DecodeReadRequest reads a remote.Request from a http.Request. -func DecodeReadRequest(r *http.Request) (*prompb.ReadRequest, error) { - compressed, err := io.ReadAll(io.LimitReader(r.Body, int64(snappy.MaxEncodedLen(decodeReadLimit)+1))) +// decodeSnappyWithLimit reads and decodes snappy-compressed data enforcing both +// compressed and decoded size limits. +func decodeSnappyWithLimit(r io.Reader, limit int, name string) ([]byte, error) { + compressed, err := io.ReadAll(io.LimitReader(r, int64(limit)+1)) if err != nil { return nil, err } - - // Ensure the decoded size is within a safe bound before allocating. - if decodedLen, err := snappy.DecodedLen(compressed); err != nil { - return nil, err - } else if decodedLen > decodeReadLimit { - return nil, fmt.Errorf("decoded read request too large (>%d bytes)", decodeReadLimit) + if len(compressed) > limit { + return nil, fmt.Errorf("compressed %s too large (%d bytes; limit %d bytes)", name, len(compressed), limit) } - reqBuf, err := snappy.Decode(nil, compressed) + decodedLen, err := snappy.DecodedLen(compressed) + if err != nil { + return nil, err + } + if decodedLen > limit { + return nil, fmt.Errorf("%s too large (%d bytes; limit %d bytes)", name, decodedLen, limit) + } + + return snappy.Decode(nil, compressed) +} + +// DecodeReadRequest reads a remote.Request from a http.Request. +func DecodeReadRequest(r *http.Request) (*prompb.ReadRequest, error) { + reqBuf, err := decodeSnappyWithLimit(r.Body, decodeReadLimit, "read request") if err != nil { return nil, err } @@ -921,18 +931,7 @@ func FromLabelMatchers(matchers []*prompb.LabelMatcher) ([]*labels.Matcher, erro // snappy decompression. // Used also by documentation/examples/remote_storage. func DecodeWriteRequest(r io.Reader) (*prompb.WriteRequest, error) { - compressed, err := io.ReadAll(io.LimitReader(r, int64(snappy.MaxEncodedLen(decodeWriteLimit)+1))) - if err != nil { - return nil, err - } - - if decodedLen, err := snappy.DecodedLen(compressed); err != nil { - return nil, err - } else if decodedLen > decodeWriteLimit { - return nil, fmt.Errorf("decoded write request too large (>%d bytes)", decodeWriteLimit) - } - - reqBuf, err := snappy.Decode(nil, compressed) + reqBuf, err := decodeSnappyWithLimit(r, decodeWriteLimit, "write request") if err != nil { return nil, err } @@ -949,18 +948,7 @@ func DecodeWriteRequest(r io.Reader) (*prompb.WriteRequest, error) { // snappy decompression. // Used also by documentation/examples/remote_storage. func DecodeWriteV2Request(r io.Reader) (*writev2.Request, error) { - compressed, err := io.ReadAll(io.LimitReader(r, int64(snappy.MaxEncodedLen(decodeWriteLimit)+1))) - if err != nil { - return nil, err - } - - if decodedLen, err := snappy.DecodedLen(compressed); err != nil { - return nil, err - } else if decodedLen > decodeWriteLimit { - return nil, fmt.Errorf("decoded write request too large (>%d bytes)", decodeWriteLimit) - } - - reqBuf, err := snappy.Decode(nil, compressed) + reqBuf, err := decodeSnappyWithLimit(r, decodeWriteLimit, "write v2 request") if err != nil { return nil, err } @@ -994,7 +982,6 @@ func DecodeOTLPWriteRequest(r *http.Request) (pmetricotlp.ExportRequest, error) } reader := r.Body - var gzipReader *gzip.Reader // Handle compression. switch r.Header.Get("Content-Encoding") { case "gzip": @@ -1003,7 +990,6 @@ func DecodeOTLPWriteRequest(r *http.Request) (pmetricotlp.ExportRequest, error) return pmetricotlp.NewExportRequest(), err } reader = gr - gzipReader = gr case "": // No compression. @@ -1012,29 +998,12 @@ func DecodeOTLPWriteRequest(r *http.Request) (pmetricotlp.ExportRequest, error) return pmetricotlp.NewExportRequest(), fmt.Errorf("unsupported compression: %s. Only \"gzip\" or no compression supported", r.Header.Get("Content-Encoding")) } - limitedReader := io.LimitReader(reader, int64(decodeWriteLimit)+1) - body, err := io.ReadAll(limitedReader) + body, err := io.ReadAll(reader) if err != nil { - if gzipReader != nil { - _ = gzipReader.Close() - } - _ = r.Body.Close() + r.Body.Close() return pmetricotlp.NewExportRequest(), err } - if len(body) > decodeWriteLimit { - if gzipReader != nil { - _ = gzipReader.Close() - } - _ = r.Body.Close() - return pmetricotlp.NewExportRequest(), fmt.Errorf("decoded write request too large (>%d bytes)", decodeWriteLimit) - } - if gzipReader != nil { - if err := gzipReader.Close(); err != nil { - _ = r.Body.Close() - return pmetricotlp.NewExportRequest(), err - } - } - if err := r.Body.Close(); err != nil { + if err = r.Body.Close(); err != nil { return pmetricotlp.NewExportRequest(), err } otlpReq, err := decoderFunc(body)