diff --git a/proxy/handler.go b/proxy/handler.go index 83920f1..9d71187 100644 --- a/proxy/handler.go +++ b/proxy/handler.go @@ -9,7 +9,6 @@ import ( "golang.org/x/net/context" "google.golang.org/grpc" "google.golang.org/grpc/codes" - "google.golang.org/grpc/transport" ) var ( @@ -60,11 +59,10 @@ type handler struct { // forwarding it to a ClientStream established against the relevant ClientConn. func (s *handler) handler(srv interface{}, serverStream grpc.ServerStream) error { // little bit of gRPC internals never hurt anyone - lowLevelServerStream, ok := transport.StreamFromContext(serverStream.Context()) + fullMethodName, ok := grpc.MethodFromServerStream(serverStream) if !ok { return grpc.Errorf(codes.Internal, "lowLevelServerStream not exists in context") } - fullMethodName := lowLevelServerStream.Method() // We require that the director's returned context inherits from the serverStream.Context(). outgoingCtx, backendConn, err := s.director(serverStream.Context(), fullMethodName) clientCtx, clientCancel := context.WithCancel(outgoingCtx) diff --git a/proxy/handler_test.go b/proxy/handler_test.go index 7cb55e7..c811408 100644 --- a/proxy/handler_test.go +++ b/proxy/handler_test.go @@ -135,7 +135,7 @@ func (s *ProxyHappySuite) TestPingCarriesServerHeadersAndTrailers() { out, err := s.testClient.Ping(s.ctx(), &pb.PingRequest{Value: "foo"}, grpc.Header(&headerMd), grpc.Trailer(&trailerMd)) require.NoError(s.T(), err, "Ping should succeed without errors") require.Equal(s.T(), &pb.PingResponse{Value: "foo", Counter: 42}, out) - assert.Len(s.T(), headerMd, 1, "server response headers must contain server data") + assert.Contains(s.T(), headerMd, serverHeaderMdKey, "server response headers must contain server data") assert.Len(s.T(), trailerMd, 1, "server response trailers must contain server data") } @@ -170,7 +170,7 @@ func (s *ProxyHappySuite) TestPingStream_FullDuplexWorks() { // Check that the header arrives before all entries. headerMd, err := stream.Header() require.NoError(s.T(), err, "PingStream headers should not error.") - assert.Len(s.T(), headerMd, 1, "PingStream response headers user contain metadata") + assert.Contains(s.T(), headerMd, serverHeaderMdKey, "PingStream response headers user contain metadata") } assert.EqualValues(s.T(), i, resp.Counter, "ping roundtrip must succeed with the correct id") }