Skip to content

Commit 8cf8fd1

Browse files
authoredJan 23, 2025
grpc: fix message length checks when compression is enabled and maxReceiveMessageSize is MaxInt (grpc#7918)
1 parent 67bee55 commit 8cf8fd1

File tree

2 files changed

+166
-30
lines changed

2 files changed

+166
-30
lines changed
 

‎rpc_util.go

+39-30
Original file line numberDiff line numberDiff line change
@@ -828,30 +828,13 @@ func recvAndDecompress(p *parser, s recvCompressor, dc Decompressor, maxReceiveM
828828
return nil, st.Err()
829829
}
830830

831-
var size int
832831
if pf.isCompressed() {
833832
defer compressed.Free()
834-
835833
// To match legacy behavior, if the decompressor is set by WithDecompressor or RPCDecompressor,
836834
// use this decompressor as the default.
837-
if dc != nil {
838-
var uncompressedBuf []byte
839-
uncompressedBuf, err = dc.Do(compressed.Reader())
840-
if err == nil {
841-
out = mem.BufferSlice{mem.SliceBuffer(uncompressedBuf)}
842-
}
843-
size = len(uncompressedBuf)
844-
} else {
845-
out, size, err = decompress(compressor, compressed, maxReceiveMessageSize, p.bufferPool)
846-
}
835+
out, err = decompress(compressor, compressed, dc, maxReceiveMessageSize, p.bufferPool)
847836
if err != nil {
848-
return nil, status.Errorf(codes.Internal, "grpc: failed to decompress the received message: %v", err)
849-
}
850-
if size > maxReceiveMessageSize {
851-
out.Free()
852-
// TODO: Revisit the error code. Currently keep it consistent with java
853-
// implementation.
854-
return nil, status.Errorf(codes.ResourceExhausted, "grpc: received message after decompression larger than max (%d vs. %d)", size, maxReceiveMessageSize)
837+
return nil, err
855838
}
856839
} else {
857840
out = compressed
@@ -866,20 +849,46 @@ func recvAndDecompress(p *parser, s recvCompressor, dc Decompressor, maxReceiveM
866849
return out, nil
867850
}
868851

869-
// Using compressor, decompress d, returning data and size.
870-
// Optionally, if data will be over maxReceiveMessageSize, just return the size.
871-
func decompress(compressor encoding.Compressor, d mem.BufferSlice, maxReceiveMessageSize int, pool mem.BufferPool) (mem.BufferSlice, int, error) {
872-
dcReader, err := compressor.Decompress(d.Reader())
873-
if err != nil {
874-
return nil, 0, err
852+
// decompress processes the given data by decompressing it using either a custom decompressor or a standard compressor.
853+
// If a custom decompressor is provided, it takes precedence. The function validates that the decompressed data
854+
// does not exceed the specified maximum size and returns an error if this limit is exceeded.
855+
// On success, it returns the decompressed data. Otherwise, it returns an error if decompression fails or the data exceeds the size limit.
856+
func decompress(compressor encoding.Compressor, d mem.BufferSlice, dc Decompressor, maxReceiveMessageSize int, pool mem.BufferPool) (mem.BufferSlice, error) {
857+
if dc != nil {
858+
uncompressed, err := dc.Do(d.Reader())
859+
if err != nil {
860+
return nil, status.Errorf(codes.Internal, "grpc: failed to decompress the received message: %v", err)
861+
}
862+
if len(uncompressed) > maxReceiveMessageSize {
863+
return nil, status.Errorf(codes.ResourceExhausted, "grpc: message after decompression larger than max (%d vs. %d)", len(uncompressed), maxReceiveMessageSize)
864+
}
865+
return mem.BufferSlice{mem.SliceBuffer(uncompressed)}, nil
875866
}
867+
if compressor != nil {
868+
dcReader, err := compressor.Decompress(d.Reader())
869+
if err != nil {
870+
return nil, status.Errorf(codes.Internal, "grpc: failed to decompress the message: %v", err)
871+
}
876872

877-
out, err := mem.ReadAll(io.LimitReader(dcReader, int64(maxReceiveMessageSize)+1), pool)
878-
if err != nil {
879-
out.Free()
880-
return nil, 0, err
873+
out, err := mem.ReadAll(io.LimitReader(dcReader, int64(maxReceiveMessageSize)), pool)
874+
if err != nil {
875+
out.Free()
876+
return nil, status.Errorf(codes.Internal, "grpc: failed to read decompressed data: %v", err)
877+
}
878+
879+
if out.Len() == maxReceiveMessageSize && !atEOF(dcReader) {
880+
out.Free()
881+
return nil, status.Errorf(codes.ResourceExhausted, "grpc: received message after decompression larger than max %d", maxReceiveMessageSize)
882+
}
883+
return out, nil
881884
}
882-
return out, out.Len(), nil
885+
return nil, status.Errorf(codes.Internal, "grpc: no decompressor available for compressed payload")
886+
}
887+
888+
// atEOF reads data from r and returns true if zero bytes could be read and r.Read returns EOF.
889+
func atEOF(dcReader io.Reader) bool {
890+
n, err := dcReader.Read(make([]byte, 1))
891+
return n == 0 && err == io.EOF
883892
}
884893

885894
type recvCompressor interface {

‎rpc_util_test.go

+127
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,17 @@ package grpc
2121
import (
2222
"bytes"
2323
"compress/gzip"
24+
"errors"
2425
"io"
2526
"math"
2627
"reflect"
2728
"testing"
2829

30+
"github.com/google/go-cmp/cmp"
31+
"github.com/google/go-cmp/cmp/cmpopts"
2932
"google.golang.org/grpc/codes"
33+
"google.golang.org/grpc/encoding"
34+
_ "google.golang.org/grpc/encoding/gzip"
3035
protoenc "google.golang.org/grpc/encoding/proto"
3136
"google.golang.org/grpc/internal/testutils"
3237
"google.golang.org/grpc/internal/transport"
@@ -36,6 +41,11 @@ import (
3641
"google.golang.org/protobuf/proto"
3742
)
3843

44+
const (
45+
defaultDecompressedData = "default decompressed data"
46+
decompressionErrorMsg = "invalid compression format"
47+
)
48+
3949
type fullReader struct {
4050
data []byte
4151
}
@@ -294,3 +304,120 @@ func BenchmarkGZIPCompressor512KiB(b *testing.B) {
294304
func BenchmarkGZIPCompressor1MiB(b *testing.B) {
295305
bmCompressor(b, 1024*1024, NewGZIPCompressor())
296306
}
307+
308+
// compressWithDeterministicError compresses the input data and returns a BufferSlice.
309+
func compressWithDeterministicError(t *testing.T, input []byte) mem.BufferSlice {
310+
t.Helper()
311+
var buf bytes.Buffer
312+
gz := gzip.NewWriter(&buf)
313+
if _, err := gz.Write(input); err != nil {
314+
t.Fatalf("compressInput() failed to write data: %v", err)
315+
}
316+
if err := gz.Close(); err != nil {
317+
t.Fatalf("compressInput() failed to close gzip writer: %v", err)
318+
}
319+
compressedData := buf.Bytes()
320+
return mem.BufferSlice{mem.NewBuffer(&compressedData, nil)}
321+
}
322+
323+
// MockDecompressor is a mock implementation of a decompressor used for testing purposes.
324+
// It simulates decompression behavior, returning either decompressed data or an error based on the ShouldError flag.
325+
type MockDecompressor struct {
326+
ShouldError bool // Flag to control whether the decompression should simulate an error.
327+
}
328+
329+
// Do simulates decompression. It returns a predefined error if ShouldError is true,
330+
// or a fixed set of decompressed data if ShouldError is false.
331+
func (m *MockDecompressor) Do(_ io.Reader) ([]byte, error) {
332+
if m.ShouldError {
333+
return nil, errors.New(decompressionErrorMsg)
334+
}
335+
return []byte(defaultDecompressedData), nil
336+
}
337+
338+
// Type returns the string identifier for the MockDecompressor.
339+
func (m *MockDecompressor) Type() string {
340+
return "MockDecompressor"
341+
}
342+
343+
// TestDecompress tests the decompress function behaves correctly for following scenarios
344+
// decompress successfully when message is <= maxReceiveMessageSize
345+
// errors when message > maxReceiveMessageSize
346+
// decompress successfully when maxReceiveMessageSize is MaxInt
347+
// errors when the decompressed message has an invalid format
348+
// errors when the decompressed message exceeds the maxReceiveMessageSize.
349+
func (s) TestDecompress(t *testing.T) {
350+
compressor := encoding.GetCompressor("gzip")
351+
validDecompressor := &MockDecompressor{ShouldError: false}
352+
invalidFormatDecompressor := &MockDecompressor{ShouldError: true}
353+
354+
testCases := []struct {
355+
name string
356+
input mem.BufferSlice
357+
dc Decompressor
358+
maxReceiveMessageSize int
359+
want []byte
360+
wantErr error
361+
}{
362+
{
363+
name: "Decompresses successfully with sufficient buffer size",
364+
input: compressWithDeterministicError(t, []byte("decompressed data")),
365+
dc: nil,
366+
maxReceiveMessageSize: 50,
367+
want: []byte("decompressed data"),
368+
wantErr: nil,
369+
},
370+
{
371+
name: "Fails due to exceeding maxReceiveMessageSize",
372+
input: compressWithDeterministicError(t, []byte("message that is too large")),
373+
dc: nil,
374+
maxReceiveMessageSize: len("message that is too large") - 1,
375+
want: nil,
376+
wantErr: status.Errorf(codes.ResourceExhausted, "grpc: received message after decompression larger than max %d", len("message that is too large")-1),
377+
},
378+
{
379+
name: "Decompresses to exactly maxReceiveMessageSize",
380+
input: compressWithDeterministicError(t, []byte("exact size message")),
381+
dc: nil,
382+
maxReceiveMessageSize: len("exact size message"),
383+
want: []byte("exact size message"),
384+
wantErr: nil,
385+
},
386+
{
387+
name: "Decompresses successfully with maxReceiveMessageSize MaxInt",
388+
input: compressWithDeterministicError(t, []byte("large message")),
389+
dc: nil,
390+
maxReceiveMessageSize: math.MaxInt,
391+
want: []byte("large message"),
392+
wantErr: nil,
393+
},
394+
{
395+
name: "Fails with decompression error due to invalid format",
396+
input: compressWithDeterministicError(t, []byte("invalid compressed data")),
397+
dc: invalidFormatDecompressor,
398+
maxReceiveMessageSize: 50,
399+
want: nil,
400+
wantErr: status.Errorf(codes.Internal, "grpc: failed to decompress the received message: %v", errors.New(decompressionErrorMsg)),
401+
},
402+
{
403+
name: "Fails with resourceExhausted error when decompressed message exceeds maxReceiveMessageSize",
404+
input: compressWithDeterministicError(t, []byte("large compressed data")),
405+
dc: validDecompressor,
406+
maxReceiveMessageSize: 20,
407+
want: nil,
408+
wantErr: status.Errorf(codes.ResourceExhausted, "grpc: message after decompression larger than max (%d vs. %d)", 25, 20),
409+
},
410+
}
411+
412+
for _, tc := range testCases {
413+
t.Run(tc.name, func(t *testing.T) {
414+
output, err := decompress(compressor, tc.input, tc.dc, tc.maxReceiveMessageSize, mem.DefaultBufferPool())
415+
if !cmp.Equal(err, tc.wantErr, cmpopts.EquateErrors()) {
416+
t.Fatalf("decompress() err = %v, wantErr = %v", err, tc.wantErr)
417+
}
418+
if !cmp.Equal(tc.want, output.Materialize()) {
419+
t.Fatalf("decompress() output mismatch: got = %v, want = %v", output.Materialize(), tc.want)
420+
}
421+
})
422+
}
423+
}

0 commit comments

Comments
 (0)
Please sign in to comment.