diff --git a/clientconn_test.go b/clientconn_test.go index f1bddde09848..b18bb28c6cb7 100644 --- a/clientconn_test.go +++ b/clientconn_test.go @@ -50,6 +50,7 @@ import ( const ( defaultTestTimeout = 10 * time.Second + defaultTestShortTimeout = 10 * time.Millisecond stateRecordingBalancerName = "state_recording_balancer" ) diff --git a/rpc_util.go b/rpc_util.go index a8ddb0af5285..ad20e9dff206 100644 --- a/rpc_util.go +++ b/rpc_util.go @@ -870,13 +870,19 @@ func decompress(compressor encoding.Compressor, d mem.BufferSlice, dc Decompress return nil, status.Errorf(codes.Internal, "grpc: failed to decompress the message: %v", err) } - out, err := mem.ReadAll(io.LimitReader(dcReader, int64(maxReceiveMessageSize)), pool) + // Read at most one byte more than the limit from the decompressor. + // Unless the limit is MaxInt64, in which case, that's impossible, so + // apply no limit. + if limit := int64(maxReceiveMessageSize); limit < math.MaxInt64 { + dcReader = io.LimitReader(dcReader, limit+1) + } + out, err := mem.ReadAll(dcReader, pool) if err != nil { out.Free() return nil, status.Errorf(codes.Internal, "grpc: failed to read decompressed data: %v", err) } - if out.Len() == maxReceiveMessageSize && !atEOF(dcReader) { + if out.Len() > maxReceiveMessageSize { out.Free() return nil, status.Errorf(codes.ResourceExhausted, "grpc: received message after decompression larger than max %d", maxReceiveMessageSize) } @@ -885,12 +891,6 @@ func decompress(compressor encoding.Compressor, d mem.BufferSlice, dc Decompress return nil, status.Errorf(codes.Internal, "grpc: no decompressor available for compressed payload") } -// atEOF reads data from r and returns true if zero bytes could be read and r.Read returns EOF. -func atEOF(dcReader io.Reader) bool { - n, err := dcReader.Read(make([]byte, 1)) - return n == 0 && err == io.EOF -} - type recvCompressor interface { RecvCompress() string } diff --git a/rpc_util_test.go b/rpc_util_test.go index 608cc1002471..a5c5cb8b17e2 100644 --- a/rpc_util_test.go +++ b/rpc_util_test.go @@ -21,10 +21,12 @@ package grpc import ( "bytes" "compress/gzip" + "context" "errors" "io" "math" "reflect" + "sync" "testing" "github.com/google/go-cmp/cmp" @@ -421,3 +423,57 @@ func (s) TestDecompress(t *testing.T) { }) } } + +type mockCompressor struct { + // Written to by the io.Reader on every call to Read. + ch chan<- struct{} +} + +func (m *mockCompressor) Compress(io.Writer) (io.WriteCloser, error) { + panic("unimplemented") +} + +func (m *mockCompressor) Decompress(io.Reader) (io.Reader, error) { + return m, nil +} + +func (m *mockCompressor) Read([]byte) (int, error) { + m.ch <- struct{}{} + return 1, io.EOF +} + +func (m *mockCompressor) Name() string { return "" } + +// Tests that the decompressor's Read method is not called after it returns EOF. +func (s) TestDecompress_NoReadAfterEOF(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + ch := make(chan struct{}, 10) + mc := &mockCompressor{ch: ch} + in := mem.BufferSlice{mem.NewBuffer(&[]byte{1, 2, 3}, nil)} + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + out, err := decompress(mc, in, nil, 1, mem.DefaultBufferPool()) + if err != nil { + t.Errorf("Unexpected error from decompress: %v", err) + return + } + out.Free() + }() + select { + case <-ch: + case <-ctx.Done(): + t.Fatalf("Timed out waiting for call to compressor") + } + ctx, cancel = context.WithTimeout(ctx, defaultTestShortTimeout) + defer cancel() + select { + case <-ch: + t.Fatalf("Unexpected second compressor.Read call detected") + case <-ctx.Done(): + } + wg.Wait() +}