Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions experimental/experimental.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,13 @@ func WithBufferPool(bufferPool mem.BufferPool) grpc.DialOption {
func BufferPool(bufferPool mem.BufferPool) grpc.ServerOption {
return internal.BufferPool.(func(mem.BufferPool) grpc.ServerOption)(bufferPool)
}

// AcceptedCompressionNames returns a CallOption that limits the values
// advertised in the grpc-accept-encoding header for the provided RPC. The
// supplied names must correspond to compressors registered via
// encoding.RegisterCompressor. Passing no names advertises identity only.
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a later release.
func AcceptedCompressionNames(names ...string) grpc.CallOption {
return internal.AcceptedCompressionNames.(func(...string) grpc.CallOption)(names...)
}
4 changes: 4 additions & 0 deletions internal/experimental.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,8 @@ var (
// BufferPool is implemented by the grpc package and returns a server
// option to configure a shared buffer pool for a grpc.Server.
BufferPool any // func (grpc.SharedBufferPool) grpc.ServerOption

// AcceptedCompressionNames is implemented by the grpc package and returns
// a call option that restricts the grpc-accept-encoding header for a call.
AcceptedCompressionNames any // func(...string) grpc.CallOption
)
7 changes: 6 additions & 1 deletion internal/transport/http2_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -321,12 +321,14 @@ func NewHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
maxHeaderListSize = *opts.MaxHeaderListSize
}

registeredCompressors := grpcutil.RegisteredCompressors()

t := &http2Client{
ctx: ctx,
ctxDone: ctx.Done(), // Cache Done chan.
cancel: cancel,
userAgent: opts.UserAgent,
registeredCompressors: grpcutil.RegisteredCompressors(),
registeredCompressors: registeredCompressors,
address: addr,
conn: conn,
remoteAddr: conn.RemoteAddr(),
Expand Down Expand Up @@ -551,6 +553,9 @@ func (t *http2Client) createHeaderFields(ctx context.Context, callHdr *CallHdr)
hfLen := 7 // :method, :scheme, :path, :authority, content-type, user-agent, te
hfLen += len(authData) + len(callAuthData)
registeredCompressors := t.registeredCompressors
if callHdr.AcceptedCompressors != nil {
registeredCompressors = *callHdr.AcceptedCompressors
}
if callHdr.PreviousAttempts > 0 {
hfLen++
}
Expand Down
6 changes: 6 additions & 0 deletions internal/transport/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,12 @@ type CallHdr struct {
// outbound message.
SendCompress string

// AcceptedCompressors overrides the grpc-accept-encoding header for this
// call. When nil, the transport advertises the default set of registered
// compressors. A non-nil pointer overrides that value (including the empty
// string to advertise none).
AcceptedCompressors *string

// Creds specifies credentials.PerRPCCredentials for a call.
Creds credentials.PerRPCCredentials

Expand Down
105 changes: 90 additions & 15 deletions rpc_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ import (
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/encoding"
"google.golang.org/grpc/encoding/proto"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/grpcutil"
"google.golang.org/grpc/internal/transport"
"google.golang.org/grpc/mem"
"google.golang.org/grpc/metadata"
Expand All @@ -41,6 +43,10 @@ import (
"google.golang.org/grpc/status"
)

func init() {
internal.AcceptedCompressionNames = acceptedCompressionNames
}

// Compressor defines the interface gRPC uses to compress a message.
//
// Deprecated: use package encoding.
Expand Down Expand Up @@ -151,16 +157,33 @@ func (d *gzipDecompressor) Type() string {

// callInfo contains all related configuration and information about an RPC.
type callInfo struct {
compressorName string
failFast bool
maxReceiveMessageSize *int
maxSendMessageSize *int
creds credentials.PerRPCCredentials
contentSubtype string
codec baseCodec
maxRetryRPCBufferSize int
onFinish []func(err error)
authority string
compressorName string
failFast bool
maxReceiveMessageSize *int
maxSendMessageSize *int
creds credentials.PerRPCCredentials
contentSubtype string
codec baseCodec
maxRetryRPCBufferSize int
onFinish []func(err error)
authority string
acceptedResponseCompressors *acceptedCompressionConfig
}

type acceptedCompressionConfig struct {
headerValue string
allowed map[string]struct{}
}

func (cfg *acceptedCompressionConfig) allows(name string) bool {
if cfg == nil {
return true
}
if name == "" || name == encoding.Identity {
return true
}
_, ok := cfg.allowed[name]
return ok
}

func defaultCallInfo() *callInfo {
Expand All @@ -170,6 +193,35 @@ func defaultCallInfo() *callInfo {
}
}

func newAcceptedCompressionConfig(names []string) (*acceptedCompressionConfig, error) {
cfg := &acceptedCompressionConfig{
allowed: make(map[string]struct{}, len(names)),
}
if len(names) == 0 {
return cfg, nil
}
var ordered []string
for _, name := range names {
name = strings.TrimSpace(name)
if name == "" || name == encoding.Identity {
continue
}
if !grpcutil.IsCompressorNameRegistered(name) {
return nil, status.Errorf(codes.InvalidArgument, "grpc: compressor %q is not registered", name)
}
if _, dup := cfg.allowed[name]; dup {
continue
}
cfg.allowed[name] = struct{}{}
ordered = append(ordered, name)
}
if len(ordered) == 0 {
return nil, status.Error(codes.InvalidArgument, "grpc: no valid compressor names provided")
}
cfg.headerValue = strings.Join(ordered, ",")
return cfg, nil
}

// CallOption configures a Call before it starts or extracts information from
// a Call after it completes.
type CallOption interface {
Expand Down Expand Up @@ -471,6 +523,26 @@ func (o CompressorCallOption) before(c *callInfo) error {
}
func (o CompressorCallOption) after(*callInfo, *csAttempt) {}

func acceptedCompressionNames(names ...string) CallOption {
cp := append([]string(nil), names...)
return acceptedCompressionNamesCallOption{names: cp}
}

type acceptedCompressionNamesCallOption struct {
names []string
}

func (o acceptedCompressionNamesCallOption) before(c *callInfo) error {
cfg, err := newAcceptedCompressionConfig(o.names)
if err != nil {
return err
}
c.acceptedResponseCompressors = cfg
return nil
}

func (acceptedCompressionNamesCallOption) after(*callInfo, *csAttempt) {}

// CallContentSubtype returns a CallOption that will set the content-subtype
// for a call. For example, if content-subtype is "json", the Content-Type over
// the wire will be "application/grpc+json". The content-subtype is converted
Expand Down Expand Up @@ -821,7 +893,7 @@ func outPayload(client bool, msg any, dataLength, payloadLength int, t time.Time
}
}

func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool, isServer bool) *status.Status {
func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool, isServer bool, acceptedCfg *acceptedCompressionConfig) *status.Status {
switch pf {
case compressionNone:
case compressionMade:
Expand All @@ -834,6 +906,9 @@ func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool
}
return status.Newf(codes.Internal, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress)
}
if !isServer && acceptedCfg != nil && !acceptedCfg.allows(recvCompress) {
return status.Newf(codes.FailedPrecondition, "grpc: peer compressed the response with %q which is not allowed by AcceptedCompressionNames", recvCompress)
}
default:
return status.Newf(codes.Internal, "grpc: received unexpected payload format %d", pf)
}
Expand All @@ -857,7 +932,7 @@ func (p *payloadInfo) free() {
// the buffer is no longer needed.
// TODO: Refactor this function to reduce the number of arguments.
// See: https://google.github.io/styleguide/go/best-practices.html#function-argument-lists
func recvAndDecompress(p *parser, s recvCompressor, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, isServer bool,
func recvAndDecompress(p *parser, s recvCompressor, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, isServer bool, acceptedCfg *acceptedCompressionConfig,
) (out mem.BufferSlice, err error) {
pf, compressed, err := p.recvMsg(maxReceiveMessageSize)
if err != nil {
Expand All @@ -866,7 +941,7 @@ func recvAndDecompress(p *parser, s recvCompressor, dc Decompressor, maxReceiveM

compressedLength := compressed.Len()

if st := checkRecvPayload(pf, s.RecvCompress(), compressor != nil || dc != nil, isServer); st != nil {
if st := checkRecvPayload(pf, s.RecvCompress(), compressor != nil || dc != nil, isServer, acceptedCfg); st != nil {
compressed.Free()
return nil, st.Err()
}
Expand Down Expand Up @@ -941,8 +1016,8 @@ type recvCompressor interface {
// For the two compressor parameters, both should not be set, but if they are,
// dc takes precedence over compressor.
// TODO(dfawley): wrap the old compressor/decompressor using the new API?
func recv(p *parser, c baseCodec, s recvCompressor, dc Decompressor, m any, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, isServer bool) error {
data, err := recvAndDecompress(p, s, dc, maxReceiveMessageSize, payInfo, compressor, isServer)
func recv(p *parser, c baseCodec, s recvCompressor, dc Decompressor, m any, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, isServer bool, acceptedCfg *acceptedCompressionConfig) error {
data, err := recvAndDecompress(p, s, dc, maxReceiveMessageSize, payInfo, compressor, isServer, acceptedCfg)
if err != nil {
return err
}
Expand Down
79 changes: 79 additions & 0 deletions rpc_util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,85 @@ const (
decompressionErrorMsg = "invalid compression format"
)

func (s) TestNewAcceptedCompressionConfig(t *testing.T) {
tests := []struct {
name string
input []string
wantHeader string
wantAllowed map[string]struct{}
wantErr bool
}{
{
name: "identity-only",
input: nil,
wantHeader: "",
wantAllowed: map[string]struct{}{},
},
{
name: "single valid",
input: []string{"gzip"},
wantHeader: "gzip",
wantAllowed: map[string]struct{}{"gzip": {}},
},
{
name: "dedupe and trim",
input: []string{" gzip ", "gzip"},
wantHeader: "gzip",
wantAllowed: map[string]struct{}{"gzip": {}},
},
{
name: "ignores identity",
input: []string{"identity", "gzip"},
wantHeader: "gzip",
wantAllowed: map[string]struct{}{"gzip": {}},
},
{
name: "invalid compressor",
input: []string{"does-not-exist"},
wantErr: true,
},
{
name: "only whitespace",
input: []string{" ", "\t"},
wantErr: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg, err := newAcceptedCompressionConfig(tt.input)
if (err != nil) != tt.wantErr {
t.Fatalf("newAcceptedCompressionConfig(%v) error = %v, wantErr %v", tt.input, err, tt.wantErr)
}
if tt.wantErr {
return
}
if cfg.headerValue != tt.wantHeader {
t.Fatalf("headerValue = %q, want %q", cfg.headerValue, tt.wantHeader)
}
if diff := cmp.Diff(tt.wantAllowed, cfg.allowed); diff != "" {
t.Fatalf("allowed diff (-want +got): %v", diff)
}
})
}
}

func (s) TestCheckRecvPayloadHonorsAcceptedCompressors(t *testing.T) {
cfg, err := newAcceptedCompressionConfig([]string{"gzip"})
if err != nil {
t.Fatalf("newAcceptedCompressionConfig returned error: %v", err)
}

if st := checkRecvPayload(compressionMade, "gzip", true, false, cfg); st != nil {
t.Fatalf("checkRecvPayload returned error for allowed compressor: %v", st)
}

st := checkRecvPayload(compressionMade, "snappy", true, false, cfg)
if st == nil || st.Code() != codes.FailedPrecondition {
t.Fatalf("checkRecvPayload = %v, want code %v", st, codes.FailedPrecondition)
}
}

type fullReader struct {
data []byte
}
Expand Down
2 changes: 1 addition & 1 deletion server.go
Original file line number Diff line number Diff line change
Expand Up @@ -1381,7 +1381,7 @@ func (s *Server) processUnaryRPC(ctx context.Context, stream *transport.ServerSt
defer payInfo.free()
}

d, err := recvAndDecompress(&parser{r: stream, bufferPool: s.opts.bufferPool}, stream, dc, s.opts.maxReceiveMessageSize, payInfo, decomp, true)
d, err := recvAndDecompress(&parser{r: stream, bufferPool: s.opts.bufferPool}, stream, dc, s.opts.maxReceiveMessageSize, payInfo, decomp, true, nil)
if err != nil {
if e := stream.WriteStatus(status.Convert(err)); e != nil {
channelz.Warningf(logger, s.channelz, "grpc: Server.processUnaryRPC failed to write status: %v", e)
Expand Down
15 changes: 9 additions & 6 deletions stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,9 @@ func newClientStreamWithParams(ctx context.Context, desc *StreamDesc, cc *Client
DoneFunc: doneFunc,
Authority: callInfo.authority,
}
if cfg := callInfo.acceptedResponseCompressors; cfg != nil {
callHdr.AcceptedCompressors = &cfg.headerValue
}

// Set our outgoing compression according to the UseCompressor CallOption, if
// set. In that case, also find the compressor from the encoding package.
Expand Down Expand Up @@ -1141,7 +1144,7 @@ func (a *csAttempt) recvMsg(m any, payInfo *payloadInfo) (err error) {
// Only initialize this state once per stream.
a.decompressorSet = true
}
if err := recv(&a.parser, cs.codec, a.transportStream, a.decompressorV0, m, *cs.callInfo.maxReceiveMessageSize, payInfo, a.decompressorV1, false); err != nil {
if err := recv(&a.parser, cs.codec, a.transportStream, a.decompressorV0, m, *cs.callInfo.maxReceiveMessageSize, payInfo, a.decompressorV1, false, cs.callInfo.acceptedResponseCompressors); err != nil {
if err == io.EOF {
if statusErr := a.transportStream.Status().Err(); statusErr != nil {
return statusErr
Expand Down Expand Up @@ -1179,7 +1182,7 @@ func (a *csAttempt) recvMsg(m any, payInfo *payloadInfo) (err error) {
}
// Special handling for non-server-stream rpcs.
// This recv expects EOF or errors, so we don't collect inPayload.
if err := recv(&a.parser, cs.codec, a.transportStream, a.decompressorV0, m, *cs.callInfo.maxReceiveMessageSize, nil, a.decompressorV1, false); err == io.EOF {
if err := recv(&a.parser, cs.codec, a.transportStream, a.decompressorV0, m, *cs.callInfo.maxReceiveMessageSize, nil, a.decompressorV1, false, cs.callInfo.acceptedResponseCompressors); err == io.EOF {
return a.transportStream.Status().Err() // non-server streaming Recv returns nil on success
} else if err != nil {
return toRPCErr(err)
Expand Down Expand Up @@ -1486,7 +1489,7 @@ func (as *addrConnStream) RecvMsg(m any) (err error) {
// Only initialize this state once per stream.
as.decompressorSet = true
}
if err := recv(&as.parser, as.codec, as.transportStream, as.decompressorV0, m, *as.callInfo.maxReceiveMessageSize, nil, as.decompressorV1, false); err != nil {
if err := recv(&as.parser, as.codec, as.transportStream, as.decompressorV0, m, *as.callInfo.maxReceiveMessageSize, nil, as.decompressorV1, false, as.callInfo.acceptedResponseCompressors); err != nil {
if err == io.EOF {
if statusErr := as.transportStream.Status().Err(); statusErr != nil {
return statusErr
Expand All @@ -1508,7 +1511,7 @@ func (as *addrConnStream) RecvMsg(m any) (err error) {

// Special handling for non-server-stream rpcs.
// This recv expects EOF or errors, so we don't collect inPayload.
if err := recv(&as.parser, as.codec, as.transportStream, as.decompressorV0, m, *as.callInfo.maxReceiveMessageSize, nil, as.decompressorV1, false); err == io.EOF {
if err := recv(&as.parser, as.codec, as.transportStream, as.decompressorV0, m, *as.callInfo.maxReceiveMessageSize, nil, as.decompressorV1, false, as.callInfo.acceptedResponseCompressors); err == io.EOF {
return as.transportStream.Status().Err() // non-server streaming Recv returns nil on success
} else if err != nil {
return toRPCErr(err)
Expand Down Expand Up @@ -1785,7 +1788,7 @@ func (ss *serverStream) RecvMsg(m any) (err error) {
payInfo = &payloadInfo{}
defer payInfo.free()
}
if err := recv(&ss.p, ss.codec, ss.s, ss.decompressorV0, m, ss.maxReceiveMessageSize, payInfo, ss.decompressorV1, true); err != nil {
if err := recv(&ss.p, ss.codec, ss.s, ss.decompressorV0, m, ss.maxReceiveMessageSize, payInfo, ss.decompressorV1, true, nil); err != nil {
if err == io.EOF {
if len(ss.binlogs) != 0 {
chc := &binarylog.ClientHalfClose{}
Expand Down Expand Up @@ -1829,7 +1832,7 @@ func (ss *serverStream) RecvMsg(m any) (err error) {
}
// Special handling for non-client-stream rpcs.
// This recv expects EOF or errors, so we don't collect inPayload.
if err := recv(&ss.p, ss.codec, ss.s, ss.decompressorV0, m, ss.maxReceiveMessageSize, nil, ss.decompressorV1, true); err == io.EOF {
if err := recv(&ss.p, ss.codec, ss.s, ss.decompressorV0, m, ss.maxReceiveMessageSize, nil, ss.decompressorV1, true, nil); err == io.EOF {
return nil
} else if err != nil {
return err
Expand Down
Loading