@@ -33,6 +33,8 @@ import (
3333 "google.golang.org/grpc/credentials"
3434 "google.golang.org/grpc/encoding"
3535 "google.golang.org/grpc/encoding/proto"
36+ "google.golang.org/grpc/internal"
37+ "google.golang.org/grpc/internal/grpcutil"
3638 "google.golang.org/grpc/internal/transport"
3739 "google.golang.org/grpc/mem"
3840 "google.golang.org/grpc/metadata"
@@ -41,6 +43,10 @@ import (
4143 "google.golang.org/grpc/status"
4244)
4345
46+ func init () {
47+ internal .AcceptedCompressionNames = acceptedCompressionNames
48+ }
49+
4450// Compressor defines the interface gRPC uses to compress a message.
4551//
4652// Deprecated: use package encoding.
@@ -151,16 +157,33 @@ func (d *gzipDecompressor) Type() string {
151157
152158// callInfo contains all related configuration and information about an RPC.
153159type callInfo struct {
154- compressorName string
155- failFast bool
156- maxReceiveMessageSize * int
157- maxSendMessageSize * int
158- creds credentials.PerRPCCredentials
159- contentSubtype string
160- codec baseCodec
161- maxRetryRPCBufferSize int
162- onFinish []func (err error )
163- authority string
160+ compressorName string
161+ failFast bool
162+ maxReceiveMessageSize * int
163+ maxSendMessageSize * int
164+ creds credentials.PerRPCCredentials
165+ contentSubtype string
166+ codec baseCodec
167+ maxRetryRPCBufferSize int
168+ onFinish []func (err error )
169+ authority string
170+ acceptedResponseCompressors * acceptedCompressionConfig
171+ }
172+
173+ type acceptedCompressionConfig struct {
174+ headerValue string
175+ allowed map [string ]struct {}
176+ }
177+
178+ func (cfg * acceptedCompressionConfig ) allows (name string ) bool {
179+ if cfg == nil {
180+ return true
181+ }
182+ if name == "" || name == encoding .Identity {
183+ return true
184+ }
185+ _ , ok := cfg .allowed [name ]
186+ return ok
164187}
165188
166189func defaultCallInfo () * callInfo {
@@ -170,6 +193,35 @@ func defaultCallInfo() *callInfo {
170193 }
171194}
172195
196+ func newAcceptedCompressionConfig (names []string ) (* acceptedCompressionConfig , error ) {
197+ cfg := & acceptedCompressionConfig {
198+ allowed : make (map [string ]struct {}, len (names )),
199+ }
200+ if len (names ) == 0 {
201+ return cfg , nil
202+ }
203+ var ordered []string
204+ for _ , name := range names {
205+ name = strings .TrimSpace (name )
206+ if name == "" || name == encoding .Identity {
207+ continue
208+ }
209+ if ! grpcutil .IsCompressorNameRegistered (name ) {
210+ return nil , status .Errorf (codes .InvalidArgument , "grpc: compressor %q is not registered" , name )
211+ }
212+ if _ , dup := cfg .allowed [name ]; dup {
213+ continue
214+ }
215+ cfg .allowed [name ] = struct {}{}
216+ ordered = append (ordered , name )
217+ }
218+ if len (ordered ) == 0 {
219+ return nil , status .Error (codes .InvalidArgument , "grpc: no valid compressor names provided" )
220+ }
221+ cfg .headerValue = strings .Join (ordered , "," )
222+ return cfg , nil
223+ }
224+
173225// CallOption configures a Call before it starts or extracts information from
174226// a Call after it completes.
175227type CallOption interface {
@@ -471,6 +523,26 @@ func (o CompressorCallOption) before(c *callInfo) error {
471523}
472524func (o CompressorCallOption ) after (* callInfo , * csAttempt ) {}
473525
526+ func acceptedCompressionNames (names ... string ) CallOption {
527+ cp := append ([]string (nil ), names ... )
528+ return acceptedCompressionNamesCallOption {names : cp }
529+ }
530+
531+ type acceptedCompressionNamesCallOption struct {
532+ names []string
533+ }
534+
535+ func (o acceptedCompressionNamesCallOption ) before (c * callInfo ) error {
536+ cfg , err := newAcceptedCompressionConfig (o .names )
537+ if err != nil {
538+ return err
539+ }
540+ c .acceptedResponseCompressors = cfg
541+ return nil
542+ }
543+
544+ func (acceptedCompressionNamesCallOption ) after (* callInfo , * csAttempt ) {}
545+
474546// CallContentSubtype returns a CallOption that will set the content-subtype
475547// for a call. For example, if content-subtype is "json", the Content-Type over
476548// the wire will be "application/grpc+json". The content-subtype is converted
@@ -821,7 +893,7 @@ func outPayload(client bool, msg any, dataLength, payloadLength int, t time.Time
821893 }
822894}
823895
824- func checkRecvPayload (pf payloadFormat , recvCompress string , haveCompressor bool , isServer bool ) * status.Status {
896+ func checkRecvPayload (pf payloadFormat , recvCompress string , haveCompressor bool , isServer bool , acceptedCfg * acceptedCompressionConfig ) * status.Status {
825897 switch pf {
826898 case compressionNone :
827899 case compressionMade :
@@ -834,6 +906,9 @@ func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool
834906 }
835907 return status .Newf (codes .Internal , "grpc: Decompressor is not installed for grpc-encoding %q" , recvCompress )
836908 }
909+ if ! isServer && acceptedCfg != nil && ! acceptedCfg .allows (recvCompress ) {
910+ return status .Newf (codes .FailedPrecondition , "grpc: peer compressed the response with %q which is not allowed by AcceptedCompressionNames" , recvCompress )
911+ }
837912 default :
838913 return status .Newf (codes .Internal , "grpc: received unexpected payload format %d" , pf )
839914 }
@@ -857,7 +932,7 @@ func (p *payloadInfo) free() {
857932// the buffer is no longer needed.
858933// TODO: Refactor this function to reduce the number of arguments.
859934// See: https://google.github.io/styleguide/go/best-practices.html#function-argument-lists
860- func recvAndDecompress (p * parser , s recvCompressor , dc Decompressor , maxReceiveMessageSize int , payInfo * payloadInfo , compressor encoding.Compressor , isServer bool ,
935+ func recvAndDecompress (p * parser , s recvCompressor , dc Decompressor , maxReceiveMessageSize int , payInfo * payloadInfo , compressor encoding.Compressor , isServer bool , acceptedCfg * acceptedCompressionConfig ,
861936) (out mem.BufferSlice , err error ) {
862937 pf , compressed , err := p .recvMsg (maxReceiveMessageSize )
863938 if err != nil {
@@ -866,7 +941,7 @@ func recvAndDecompress(p *parser, s recvCompressor, dc Decompressor, maxReceiveM
866941
867942 compressedLength := compressed .Len ()
868943
869- if st := checkRecvPayload (pf , s .RecvCompress (), compressor != nil || dc != nil , isServer ); st != nil {
944+ if st := checkRecvPayload (pf , s .RecvCompress (), compressor != nil || dc != nil , isServer , acceptedCfg ); st != nil {
870945 compressed .Free ()
871946 return nil , st .Err ()
872947 }
@@ -941,8 +1016,8 @@ type recvCompressor interface {
9411016// For the two compressor parameters, both should not be set, but if they are,
9421017// dc takes precedence over compressor.
9431018// TODO(dfawley): wrap the old compressor/decompressor using the new API?
944- func recv (p * parser , c baseCodec , s recvCompressor , dc Decompressor , m any , maxReceiveMessageSize int , payInfo * payloadInfo , compressor encoding.Compressor , isServer bool ) error {
945- data , err := recvAndDecompress (p , s , dc , maxReceiveMessageSize , payInfo , compressor , isServer )
1019+ func recv (p * parser , c baseCodec , s recvCompressor , dc Decompressor , m any , maxReceiveMessageSize int , payInfo * payloadInfo , compressor encoding.Compressor , isServer bool , acceptedCfg * acceptedCompressionConfig ) error {
1020+ data , err := recvAndDecompress (p , s , dc , maxReceiveMessageSize , payInfo , compressor , isServer , acceptedCfg )
9461021 if err != nil {
9471022 return err
9481023 }
0 commit comments