|
| 1 | +package org.hypertrace.core.grpcutils.server; |
| 2 | + |
| 3 | +import static io.grpc.Metadata.ASCII_STRING_MARSHALLER; |
| 4 | +import static org.hypertrace.core.grpcutils.context.ContextualStatusExceptionBuilder.from; |
| 5 | + |
| 6 | +import io.grpc.ForwardingServerCall.SimpleForwardingServerCall; |
| 7 | +import io.grpc.Metadata; |
| 8 | +import io.grpc.ServerCall; |
| 9 | +import io.grpc.ServerCall.Listener; |
| 10 | +import io.grpc.ServerCallHandler; |
| 11 | +import io.grpc.ServerInterceptor; |
| 12 | +import io.grpc.Status; |
| 13 | +import java.util.Optional; |
| 14 | +import java.util.UUID; |
| 15 | +import org.hypertrace.core.grpcutils.context.ContextualExceptionDetails; |
| 16 | +import org.hypertrace.core.grpcutils.context.RequestContext; |
| 17 | +import org.hypertrace.core.grpcutils.context.RequestContextConstants; |
| 18 | + |
| 19 | +/** |
| 20 | + * This interceptor can be used at the edge to scrub any sensitive information such as an exception |
| 21 | + * cause or metadata off the context before propagating it |
| 22 | + */ |
| 23 | +public class ExternalExceptionInterceptor implements ServerInterceptor { |
| 24 | + |
| 25 | + @Override |
| 26 | + public <ReqT, RespT> Listener<ReqT> interceptCall( |
| 27 | + ServerCall<ReqT, RespT> call, Metadata headers, ServerCallHandler<ReqT, RespT> next) { |
| 28 | + ServerCall<ReqT, RespT> wrappedCall = new ExceptionWrappedServerCall<>(call, headers); |
| 29 | + return next.startCall(wrappedCall, headers); |
| 30 | + } |
| 31 | + |
| 32 | + private class ExceptionWrappedServerCall<ReqT, RespT> |
| 33 | + extends SimpleForwardingServerCall<ReqT, RespT> { |
| 34 | + private final Metadata headers; |
| 35 | + |
| 36 | + private ExceptionWrappedServerCall(ServerCall<ReqT, RespT> delegate, Metadata headers) { |
| 37 | + super(delegate); |
| 38 | + this.headers = headers; |
| 39 | + } |
| 40 | + |
| 41 | + @Override |
| 42 | + public void close(Status status, Metadata trailers) { |
| 43 | + Optional<ContextualExceptionDetails> details = |
| 44 | + resolveContextDetails(status, headers, trailers); |
| 45 | + String requestId = |
| 46 | + details |
| 47 | + .flatMap(ContextualExceptionDetails::getRequestContext) |
| 48 | + .flatMap(RequestContext::getRequestId) |
| 49 | + .orElseGet(ExternalExceptionInterceptor.this::generateDefaultRequestId); |
| 50 | + String message = |
| 51 | + details |
| 52 | + .flatMap(ContextualExceptionDetails::getExternalMessage) |
| 53 | + .orElseGet(ExternalExceptionInterceptor.this::getDefaultErrorMessage); |
| 54 | + Status externalStatus = buildExternalStatus(status, requestId, message); |
| 55 | + Metadata externalTrailers = buildExternalTrailers(trailers, requestId); |
| 56 | + super.close(externalStatus, externalTrailers); |
| 57 | + } |
| 58 | + |
| 59 | + /** Remove sensitive information from status before sending back to client. */ |
| 60 | + private Optional<ContextualExceptionDetails> resolveContextDetails( |
| 61 | + Status status, Metadata headers, Metadata trailers) { |
| 62 | + // Preference to the returned trailers then thread local value and finally calling headers |
| 63 | + return ContextualExceptionDetails.fromMetadata(trailers) |
| 64 | + .or(() -> Optional.of(from(status, RequestContext.CURRENT.get()).getDetails())) |
| 65 | + .filter( |
| 66 | + details -> |
| 67 | + details.getRequestContext().flatMap(RequestContext::getRequestId).isPresent()) |
| 68 | + .or(() -> ContextualExceptionDetails.fromMetadata(headers)); |
| 69 | + } |
| 70 | + } |
| 71 | + |
| 72 | + protected String generateDefaultRequestId() { |
| 73 | + return UUID.randomUUID().toString(); |
| 74 | + } |
| 75 | + |
| 76 | + protected String getDefaultErrorMessage() { |
| 77 | + return "Error"; |
| 78 | + } |
| 79 | + |
| 80 | + protected Status buildExternalStatus(Status status, String requestId, String message) { |
| 81 | + if (Status.OK.equals(status)) { |
| 82 | + return status; |
| 83 | + } |
| 84 | + |
| 85 | + return Status.fromCode(status.getCode()) |
| 86 | + .withDescription( |
| 87 | + String.format("Request with id: %s failed with message: %s", requestId, message)); |
| 88 | + } |
| 89 | + |
| 90 | + /** For now, only propagate request ID */ |
| 91 | + protected Metadata buildExternalTrailers(Metadata receivedTrailers, String requestId) { |
| 92 | + Metadata externalTrailers = new Metadata(); |
| 93 | + externalTrailers.put( |
| 94 | + Metadata.Key.of(RequestContextConstants.REQUEST_ID_HEADER_KEY, ASCII_STRING_MARSHALLER), |
| 95 | + requestId); |
| 96 | + return externalTrailers; |
| 97 | + } |
| 98 | +} |
0 commit comments