Skip to content

Commit 6470258

Browse files
committed
client: allow overriding grpc-accept-encoding header
Signed-off-by: Israel Blancas <[email protected]>
1 parent 50c6321 commit 6470258

File tree

9 files changed

+230
-23
lines changed

9 files changed

+230
-23
lines changed

experimental/experimental.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,3 +62,13 @@ func WithBufferPool(bufferPool mem.BufferPool) grpc.DialOption {
6262
func BufferPool(bufferPool mem.BufferPool) grpc.ServerOption {
6363
return internal.BufferPool.(func(mem.BufferPool) grpc.ServerOption)(bufferPool)
6464
}
65+
66+
// AcceptedCompressionNames returns a CallOption that limits the values
67+
// advertised in the grpc-accept-encoding header for the provided RPC. The
68+
// supplied names must correspond to compressors registered via
69+
// encoding.RegisterCompressor. Passing no names advertises identity only.
70+
//
71+
// Notice: This API is EXPERIMENTAL and may be changed or removed in a later release.
72+
func AcceptedCompressionNames(names ...string) grpc.CallOption {
73+
return internal.AcceptedCompressionNames.(func(...string) grpc.CallOption)(names...)
74+
}

internal/experimental.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,8 @@ var (
2525
// BufferPool is implemented by the grpc package and returns a server
2626
// option to configure a shared buffer pool for a grpc.Server.
2727
BufferPool any // func (grpc.SharedBufferPool) grpc.ServerOption
28+
29+
// AcceptedCompressionNames is implemented by the grpc package and returns
30+
// a call option that restricts the grpc-accept-encoding header for a call.
31+
AcceptedCompressionNames any // func(...string) grpc.CallOption
2832
)

internal/transport/http2_client.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,12 +321,14 @@ func NewHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
321321
maxHeaderListSize = *opts.MaxHeaderListSize
322322
}
323323

324+
registeredCompressors := grpcutil.RegisteredCompressors()
325+
324326
t := &http2Client{
325327
ctx: ctx,
326328
ctxDone: ctx.Done(), // Cache Done chan.
327329
cancel: cancel,
328330
userAgent: opts.UserAgent,
329-
registeredCompressors: grpcutil.RegisteredCompressors(),
331+
registeredCompressors: registeredCompressors,
330332
address: addr,
331333
conn: conn,
332334
remoteAddr: conn.RemoteAddr(),
@@ -551,6 +553,9 @@ func (t *http2Client) createHeaderFields(ctx context.Context, callHdr *CallHdr)
551553
hfLen := 7 // :method, :scheme, :path, :authority, content-type, user-agent, te
552554
hfLen += len(authData) + len(callAuthData)
553555
registeredCompressors := t.registeredCompressors
556+
if callHdr.AcceptedCompressors != nil {
557+
registeredCompressors = *callHdr.AcceptedCompressors
558+
}
554559
if callHdr.PreviousAttempts > 0 {
555560
hfLen++
556561
}

internal/transport/transport.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -553,6 +553,12 @@ type CallHdr struct {
553553
// outbound message.
554554
SendCompress string
555555

556+
// AcceptedCompressors overrides the grpc-accept-encoding header for this
557+
// call. When nil, the transport advertises the default set of registered
558+
// compressors. A non-nil pointer overrides that value (including the empty
559+
// string to advertise none).
560+
AcceptedCompressors *string
561+
556562
// Creds specifies credentials.PerRPCCredentials for a call.
557563
Creds credentials.PerRPCCredentials
558564

rpc_util.go

Lines changed: 90 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ import (
3333
"google.golang.org/grpc/credentials"
3434
"google.golang.org/grpc/encoding"
3535
"google.golang.org/grpc/encoding/proto"
36+
"google.golang.org/grpc/internal"
37+
"google.golang.org/grpc/internal/grpcutil"
3638
"google.golang.org/grpc/internal/transport"
3739
"google.golang.org/grpc/mem"
3840
"google.golang.org/grpc/metadata"
@@ -41,6 +43,10 @@ import (
4143
"google.golang.org/grpc/status"
4244
)
4345

46+
func init() {
47+
internal.AcceptedCompressionNames = acceptedCompressionNames
48+
}
49+
4450
// Compressor defines the interface gRPC uses to compress a message.
4551
//
4652
// Deprecated: use package encoding.
@@ -151,16 +157,33 @@ func (d *gzipDecompressor) Type() string {
151157

152158
// callInfo contains all related configuration and information about an RPC.
153159
type callInfo struct {
154-
compressorName string
155-
failFast bool
156-
maxReceiveMessageSize *int
157-
maxSendMessageSize *int
158-
creds credentials.PerRPCCredentials
159-
contentSubtype string
160-
codec baseCodec
161-
maxRetryRPCBufferSize int
162-
onFinish []func(err error)
163-
authority string
160+
compressorName string
161+
failFast bool
162+
maxReceiveMessageSize *int
163+
maxSendMessageSize *int
164+
creds credentials.PerRPCCredentials
165+
contentSubtype string
166+
codec baseCodec
167+
maxRetryRPCBufferSize int
168+
onFinish []func(err error)
169+
authority string
170+
acceptedResponseCompressors *acceptedCompressionConfig
171+
}
172+
173+
type acceptedCompressionConfig struct {
174+
headerValue string
175+
allowed map[string]struct{}
176+
}
177+
178+
func (cfg *acceptedCompressionConfig) allows(name string) bool {
179+
if cfg == nil {
180+
return true
181+
}
182+
if name == "" || name == encoding.Identity {
183+
return true
184+
}
185+
_, ok := cfg.allowed[name]
186+
return ok
164187
}
165188

166189
func defaultCallInfo() *callInfo {
@@ -170,6 +193,35 @@ func defaultCallInfo() *callInfo {
170193
}
171194
}
172195

196+
func newAcceptedCompressionConfig(names []string) (*acceptedCompressionConfig, error) {
197+
cfg := &acceptedCompressionConfig{
198+
allowed: make(map[string]struct{}, len(names)),
199+
}
200+
if len(names) == 0 {
201+
return cfg, nil
202+
}
203+
var ordered []string
204+
for _, name := range names {
205+
name = strings.TrimSpace(name)
206+
if name == "" || name == encoding.Identity {
207+
continue
208+
}
209+
if !grpcutil.IsCompressorNameRegistered(name) {
210+
return nil, status.Errorf(codes.InvalidArgument, "grpc: compressor %q is not registered", name)
211+
}
212+
if _, dup := cfg.allowed[name]; dup {
213+
continue
214+
}
215+
cfg.allowed[name] = struct{}{}
216+
ordered = append(ordered, name)
217+
}
218+
if len(ordered) == 0 {
219+
return nil, status.Error(codes.InvalidArgument, "grpc: no valid compressor names provided")
220+
}
221+
cfg.headerValue = strings.Join(ordered, ",")
222+
return cfg, nil
223+
}
224+
173225
// CallOption configures a Call before it starts or extracts information from
174226
// a Call after it completes.
175227
type CallOption interface {
@@ -471,6 +523,26 @@ func (o CompressorCallOption) before(c *callInfo) error {
471523
}
472524
func (o CompressorCallOption) after(*callInfo, *csAttempt) {}
473525

526+
func acceptedCompressionNames(names ...string) CallOption {
527+
cp := append([]string(nil), names...)
528+
return acceptedCompressionNamesCallOption{names: cp}
529+
}
530+
531+
type acceptedCompressionNamesCallOption struct {
532+
names []string
533+
}
534+
535+
func (o acceptedCompressionNamesCallOption) before(c *callInfo) error {
536+
cfg, err := newAcceptedCompressionConfig(o.names)
537+
if err != nil {
538+
return err
539+
}
540+
c.acceptedResponseCompressors = cfg
541+
return nil
542+
}
543+
544+
func (acceptedCompressionNamesCallOption) after(*callInfo, *csAttempt) {}
545+
474546
// CallContentSubtype returns a CallOption that will set the content-subtype
475547
// for a call. For example, if content-subtype is "json", the Content-Type over
476548
// the wire will be "application/grpc+json". The content-subtype is converted
@@ -821,7 +893,7 @@ func outPayload(client bool, msg any, dataLength, payloadLength int, t time.Time
821893
}
822894
}
823895

824-
func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool, isServer bool) *status.Status {
896+
func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool, isServer bool, acceptedCfg *acceptedCompressionConfig) *status.Status {
825897
switch pf {
826898
case compressionNone:
827899
case compressionMade:
@@ -834,6 +906,9 @@ func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool
834906
}
835907
return status.Newf(codes.Internal, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress)
836908
}
909+
if !isServer && acceptedCfg != nil && !acceptedCfg.allows(recvCompress) {
910+
return status.Newf(codes.FailedPrecondition, "grpc: peer compressed the response with %q which is not allowed by AcceptedCompressionNames", recvCompress)
911+
}
837912
default:
838913
return status.Newf(codes.Internal, "grpc: received unexpected payload format %d", pf)
839914
}
@@ -857,7 +932,7 @@ func (p *payloadInfo) free() {
857932
// the buffer is no longer needed.
858933
// TODO: Refactor this function to reduce the number of arguments.
859934
// See: https://google.github.io/styleguide/go/best-practices.html#function-argument-lists
860-
func recvAndDecompress(p *parser, s recvCompressor, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, isServer bool,
935+
func recvAndDecompress(p *parser, s recvCompressor, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, isServer bool, acceptedCfg *acceptedCompressionConfig,
861936
) (out mem.BufferSlice, err error) {
862937
pf, compressed, err := p.recvMsg(maxReceiveMessageSize)
863938
if err != nil {
@@ -866,7 +941,7 @@ func recvAndDecompress(p *parser, s recvCompressor, dc Decompressor, maxReceiveM
866941

867942
compressedLength := compressed.Len()
868943

869-
if st := checkRecvPayload(pf, s.RecvCompress(), compressor != nil || dc != nil, isServer); st != nil {
944+
if st := checkRecvPayload(pf, s.RecvCompress(), compressor != nil || dc != nil, isServer, acceptedCfg); st != nil {
870945
compressed.Free()
871946
return nil, st.Err()
872947
}
@@ -941,8 +1016,8 @@ type recvCompressor interface {
9411016
// For the two compressor parameters, both should not be set, but if they are,
9421017
// dc takes precedence over compressor.
9431018
// TODO(dfawley): wrap the old compressor/decompressor using the new API?
944-
func recv(p *parser, c baseCodec, s recvCompressor, dc Decompressor, m any, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, isServer bool) error {
945-
data, err := recvAndDecompress(p, s, dc, maxReceiveMessageSize, payInfo, compressor, isServer)
1019+
func recv(p *parser, c baseCodec, s recvCompressor, dc Decompressor, m any, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, isServer bool, acceptedCfg *acceptedCompressionConfig) error {
1020+
data, err := recvAndDecompress(p, s, dc, maxReceiveMessageSize, payInfo, compressor, isServer, acceptedCfg)
9461021
if err != nil {
9471022
return err
9481023
}

rpc_util_test.go

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,85 @@ const (
4848
decompressionErrorMsg = "invalid compression format"
4949
)
5050

51+
func (s) TestNewAcceptedCompressionConfig(t *testing.T) {
52+
tests := []struct {
53+
name string
54+
input []string
55+
wantHeader string
56+
wantAllowed map[string]struct{}
57+
wantErr bool
58+
}{
59+
{
60+
name: "identity-only",
61+
input: nil,
62+
wantHeader: "",
63+
wantAllowed: map[string]struct{}{},
64+
},
65+
{
66+
name: "single valid",
67+
input: []string{"gzip"},
68+
wantHeader: "gzip",
69+
wantAllowed: map[string]struct{}{"gzip": {}},
70+
},
71+
{
72+
name: "dedupe and trim",
73+
input: []string{" gzip ", "gzip"},
74+
wantHeader: "gzip",
75+
wantAllowed: map[string]struct{}{"gzip": {}},
76+
},
77+
{
78+
name: "ignores identity",
79+
input: []string{"identity", "gzip"},
80+
wantHeader: "gzip",
81+
wantAllowed: map[string]struct{}{"gzip": {}},
82+
},
83+
{
84+
name: "invalid compressor",
85+
input: []string{"does-not-exist"},
86+
wantErr: true,
87+
},
88+
{
89+
name: "only whitespace",
90+
input: []string{" ", "\t"},
91+
wantErr: true,
92+
},
93+
}
94+
95+
for _, tt := range tests {
96+
t.Run(tt.name, func(t *testing.T) {
97+
cfg, err := newAcceptedCompressionConfig(tt.input)
98+
if (err != nil) != tt.wantErr {
99+
t.Fatalf("newAcceptedCompressionConfig(%v) error = %v, wantErr %v", tt.input, err, tt.wantErr)
100+
}
101+
if tt.wantErr {
102+
return
103+
}
104+
if cfg.headerValue != tt.wantHeader {
105+
t.Fatalf("headerValue = %q, want %q", cfg.headerValue, tt.wantHeader)
106+
}
107+
if diff := cmp.Diff(tt.wantAllowed, cfg.allowed); diff != "" {
108+
t.Fatalf("allowed diff (-want +got): %v", diff)
109+
}
110+
})
111+
}
112+
}
113+
114+
func (s) TestCheckRecvPayloadHonorsAcceptedCompressors(t *testing.T) {
115+
cfg, err := newAcceptedCompressionConfig([]string{"gzip"})
116+
if err != nil {
117+
t.Fatalf("newAcceptedCompressionConfig returned error: %v", err)
118+
}
119+
120+
if st := checkRecvPayload(compressionMade, "gzip", true, false, cfg); st != nil {
121+
t.Fatalf("checkRecvPayload returned error for allowed compressor: %v", st)
122+
}
123+
124+
st := checkRecvPayload(compressionMade, "snappy", true, false, cfg)
125+
if st == nil || st.Code() != codes.FailedPrecondition {
126+
t.Fatalf("checkRecvPayload = %v, want code %v", st, codes.FailedPrecondition)
127+
}
128+
}
129+
51130
type fullReader struct {
52131
data []byte
53132
}

server.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1381,7 +1381,7 @@ func (s *Server) processUnaryRPC(ctx context.Context, stream *transport.ServerSt
13811381
defer payInfo.free()
13821382
}
13831383

1384-
d, err := recvAndDecompress(&parser{r: stream, bufferPool: s.opts.bufferPool}, stream, dc, s.opts.maxReceiveMessageSize, payInfo, decomp, true)
1384+
d, err := recvAndDecompress(&parser{r: stream, bufferPool: s.opts.bufferPool}, stream, dc, s.opts.maxReceiveMessageSize, payInfo, decomp, true, nil)
13851385
if err != nil {
13861386
if e := stream.WriteStatus(status.Convert(err)); e != nil {
13871387
channelz.Warningf(logger, s.channelz, "grpc: Server.processUnaryRPC failed to write status: %v", e)

stream.go

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,9 @@ func newClientStreamWithParams(ctx context.Context, desc *StreamDesc, cc *Client
301301
DoneFunc: doneFunc,
302302
Authority: callInfo.authority,
303303
}
304+
if cfg := callInfo.acceptedResponseCompressors; cfg != nil {
305+
callHdr.AcceptedCompressors = &cfg.headerValue
306+
}
304307

305308
// Set our outgoing compression according to the UseCompressor CallOption, if
306309
// set. In that case, also find the compressor from the encoding package.
@@ -1141,7 +1144,7 @@ func (a *csAttempt) recvMsg(m any, payInfo *payloadInfo) (err error) {
11411144
// Only initialize this state once per stream.
11421145
a.decompressorSet = true
11431146
}
1144-
if err := recv(&a.parser, cs.codec, a.transportStream, a.decompressorV0, m, *cs.callInfo.maxReceiveMessageSize, payInfo, a.decompressorV1, false); err != nil {
1147+
if err := recv(&a.parser, cs.codec, a.transportStream, a.decompressorV0, m, *cs.callInfo.maxReceiveMessageSize, payInfo, a.decompressorV1, false, cs.callInfo.acceptedResponseCompressors); err != nil {
11451148
if err == io.EOF {
11461149
if statusErr := a.transportStream.Status().Err(); statusErr != nil {
11471150
return statusErr
@@ -1179,7 +1182,7 @@ func (a *csAttempt) recvMsg(m any, payInfo *payloadInfo) (err error) {
11791182
}
11801183
// Special handling for non-server-stream rpcs.
11811184
// This recv expects EOF or errors, so we don't collect inPayload.
1182-
if err := recv(&a.parser, cs.codec, a.transportStream, a.decompressorV0, m, *cs.callInfo.maxReceiveMessageSize, nil, a.decompressorV1, false); err == io.EOF {
1185+
if err := recv(&a.parser, cs.codec, a.transportStream, a.decompressorV0, m, *cs.callInfo.maxReceiveMessageSize, nil, a.decompressorV1, false, cs.callInfo.acceptedResponseCompressors); err == io.EOF {
11831186
return a.transportStream.Status().Err() // non-server streaming Recv returns nil on success
11841187
} else if err != nil {
11851188
return toRPCErr(err)
@@ -1486,7 +1489,7 @@ func (as *addrConnStream) RecvMsg(m any) (err error) {
14861489
// Only initialize this state once per stream.
14871490
as.decompressorSet = true
14881491
}
1489-
if err := recv(&as.parser, as.codec, as.transportStream, as.decompressorV0, m, *as.callInfo.maxReceiveMessageSize, nil, as.decompressorV1, false); err != nil {
1492+
if err := recv(&as.parser, as.codec, as.transportStream, as.decompressorV0, m, *as.callInfo.maxReceiveMessageSize, nil, as.decompressorV1, false, as.callInfo.acceptedResponseCompressors); err != nil {
14901493
if err == io.EOF {
14911494
if statusErr := as.transportStream.Status().Err(); statusErr != nil {
14921495
return statusErr
@@ -1508,7 +1511,7 @@ func (as *addrConnStream) RecvMsg(m any) (err error) {
15081511

15091512
// Special handling for non-server-stream rpcs.
15101513
// This recv expects EOF or errors, so we don't collect inPayload.
1511-
if err := recv(&as.parser, as.codec, as.transportStream, as.decompressorV0, m, *as.callInfo.maxReceiveMessageSize, nil, as.decompressorV1, false); err == io.EOF {
1514+
if err := recv(&as.parser, as.codec, as.transportStream, as.decompressorV0, m, *as.callInfo.maxReceiveMessageSize, nil, as.decompressorV1, false, as.callInfo.acceptedResponseCompressors); err == io.EOF {
15121515
return as.transportStream.Status().Err() // non-server streaming Recv returns nil on success
15131516
} else if err != nil {
15141517
return toRPCErr(err)
@@ -1785,7 +1788,7 @@ func (ss *serverStream) RecvMsg(m any) (err error) {
17851788
payInfo = &payloadInfo{}
17861789
defer payInfo.free()
17871790
}
1788-
if err := recv(&ss.p, ss.codec, ss.s, ss.decompressorV0, m, ss.maxReceiveMessageSize, payInfo, ss.decompressorV1, true); err != nil {
1791+
if err := recv(&ss.p, ss.codec, ss.s, ss.decompressorV0, m, ss.maxReceiveMessageSize, payInfo, ss.decompressorV1, true, nil); err != nil {
17891792
if err == io.EOF {
17901793
if len(ss.binlogs) != 0 {
17911794
chc := &binarylog.ClientHalfClose{}
@@ -1829,7 +1832,7 @@ func (ss *serverStream) RecvMsg(m any) (err error) {
18291832
}
18301833
// Special handling for non-client-stream rpcs.
18311834
// This recv expects EOF or errors, so we don't collect inPayload.
1832-
if err := recv(&ss.p, ss.codec, ss.s, ss.decompressorV0, m, ss.maxReceiveMessageSize, nil, ss.decompressorV1, true); err == io.EOF {
1835+
if err := recv(&ss.p, ss.codec, ss.s, ss.decompressorV0, m, ss.maxReceiveMessageSize, nil, ss.decompressorV1, true, nil); err == io.EOF {
18331836
return nil
18341837
} else if err != nil {
18351838
return err

0 commit comments

Comments
 (0)