diff --git a/documentation/examples/remote_storage/remote_storage_adapter/main.go b/documentation/examples/remote_storage/remote_storage_adapter/main.go index ffcbb5385a..fc3da02c6f 100644 --- a/documentation/examples/remote_storage/remote_storage_adapter/main.go +++ b/documentation/examples/remote_storage/remote_storage_adapter/main.go @@ -203,6 +203,8 @@ func buildClients(logger *slog.Logger, cfg *config) ([]writer, []reader) { } func serve(logger *slog.Logger, addr string, writers []writer, readers []reader) error { + bodyLimit := int64(32 * 1024 * 1024) + http.HandleFunc("/write", func(w http.ResponseWriter, r *http.Request) { req, err := remote.DecodeWriteRequest(r.Body) if err != nil { @@ -226,13 +228,24 @@ func serve(logger *slog.Logger, addr string, writers []writer, readers []reader) }) http.HandleFunc("/read", func(w http.ResponseWriter, r *http.Request) { - compressed, err := io.ReadAll(r.Body) + compressed, err := io.ReadAll(io.LimitReader(r.Body, bodyLimit)) if err != nil { logger.Error("Read error", "err", err.Error()) http.Error(w, err.Error(), http.StatusInternalServerError) return } + if decodedLen, err := snappy.DecodedLen(compressed); err != nil { + logger.Error("Decode error", "err", err.Error()) + http.Error(w, err.Error(), http.StatusBadRequest) + return + } else if int64(decodedLen) > bodyLimit { + err := fmt.Errorf("decoded read request too large (>%d bytes)", bodyLimit) + logger.Error("Decode error", "err", err.Error()) + http.Error(w, err.Error(), http.StatusRequestEntityTooLarge) + return + } + reqBuf, err := snappy.Decode(nil, compressed) if err != nil { logger.Error("Decode error", "err", err.Error()) diff --git a/storage/remote/client.go b/storage/remote/client.go index c535ea3425..e8ab7c730e 100644 --- a/storage/remote/client.go +++ b/storage/remote/client.go @@ -461,6 +461,12 @@ func (*Client) handleSampledResponse(req *prompb.ReadRequest, httpResp *http.Res _ = httpResp.Body.Close() }() + if decodedLen, err := snappy.DecodedLen(compressed); err != nil { + return nil, fmt.Errorf("error reading response: %w", err) + } else if decodedLen > decodeReadLimit { + return nil, fmt.Errorf("decoded remote read response too large (>%d bytes)", decodeReadLimit) + } + uncompressed, err := snappy.Decode(nil, compressed) if err != nil { return nil, fmt.Errorf("error reading response: %w", err) diff --git a/storage/remote/codec.go b/storage/remote/codec.go index 059d5e66ce..7713f14a49 100644 --- a/storage/remote/codec.go +++ b/storage/remote/codec.go @@ -42,6 +42,8 @@ import ( const ( // decodeReadLimit is the maximum size of a read request body in bytes. decodeReadLimit = 32 * 1024 * 1024 + // decodeWriteLimit is the maximum size of a remote write request body in bytes. + decodeWriteLimit = 32 * 1024 * 1024 pbContentType = "application/x-protobuf" jsonContentType = "application/json" @@ -67,6 +69,13 @@ func DecodeReadRequest(r *http.Request) (*prompb.ReadRequest, error) { return nil, err } + // Ensure the decoded size is within a safe bound before allocating. + if decodedLen, err := snappy.DecodedLen(compressed); err != nil { + return nil, err + } else if decodedLen > decodeReadLimit { + return nil, fmt.Errorf("decoded read request too large (>%d bytes)", decodeReadLimit) + } + reqBuf, err := snappy.Decode(nil, compressed) if err != nil { return nil, err @@ -912,11 +921,17 @@ func FromLabelMatchers(matchers []*prompb.LabelMatcher) ([]*labels.Matcher, erro // snappy decompression. // Used also by documentation/examples/remote_storage. func DecodeWriteRequest(r io.Reader) (*prompb.WriteRequest, error) { - compressed, err := io.ReadAll(r) + compressed, err := io.ReadAll(io.LimitReader(r, decodeWriteLimit)) if err != nil { return nil, err } + if decodedLen, err := snappy.DecodedLen(compressed); err != nil { + return nil, err + } else if decodedLen > decodeWriteLimit { + return nil, fmt.Errorf("decoded write request too large (>%d bytes)", decodeWriteLimit) + } + reqBuf, err := snappy.Decode(nil, compressed) if err != nil { return nil, err @@ -934,11 +949,17 @@ func DecodeWriteRequest(r io.Reader) (*prompb.WriteRequest, error) { // snappy decompression. // Used also by documentation/examples/remote_storage. func DecodeWriteV2Request(r io.Reader) (*writev2.Request, error) { - compressed, err := io.ReadAll(r) + compressed, err := io.ReadAll(io.LimitReader(r, decodeWriteLimit)) if err != nil { return nil, err } + if decodedLen, err := snappy.DecodedLen(compressed); err != nil { + return nil, err + } else if decodedLen > decodeWriteLimit { + return nil, fmt.Errorf("decoded write request too large (>%d bytes)", decodeWriteLimit) + } + reqBuf, err := snappy.Decode(nil, compressed) if err != nil { return nil, err @@ -973,6 +994,7 @@ func DecodeOTLPWriteRequest(r *http.Request) (pmetricotlp.ExportRequest, error) } reader := r.Body + var gzipReader *gzip.Reader // Handle compression. switch r.Header.Get("Content-Encoding") { case "gzip": @@ -981,6 +1003,7 @@ func DecodeOTLPWriteRequest(r *http.Request) (pmetricotlp.ExportRequest, error) return pmetricotlp.NewExportRequest(), err } reader = gr + gzipReader = gr case "": // No compression. @@ -989,12 +1012,29 @@ func DecodeOTLPWriteRequest(r *http.Request) (pmetricotlp.ExportRequest, error) return pmetricotlp.NewExportRequest(), fmt.Errorf("unsupported compression: %s. Only \"gzip\" or no compression supported", r.Header.Get("Content-Encoding")) } - body, err := io.ReadAll(reader) + limitedReader := io.LimitReader(reader, int64(decodeWriteLimit)+1) + body, err := io.ReadAll(limitedReader) if err != nil { - r.Body.Close() + if gzipReader != nil { + _ = gzipReader.Close() + } + _ = r.Body.Close() return pmetricotlp.NewExportRequest(), err } - if err = r.Body.Close(); err != nil { + if len(body) > decodeWriteLimit { + if gzipReader != nil { + _ = gzipReader.Close() + } + _ = r.Body.Close() + return pmetricotlp.NewExportRequest(), fmt.Errorf("decoded write request too large (>%d bytes)", decodeWriteLimit) + } + if gzipReader != nil { + if err := gzipReader.Close(); err != nil { + _ = r.Body.Close() + return pmetricotlp.NewExportRequest(), err + } + } + if err := r.Body.Close(); err != nil { return pmetricotlp.NewExportRequest(), err } otlpReq, err := decoderFunc(body)