Skip to content

Commit

Permalink
Show RoutingGroup info in query history
Browse files Browse the repository at this point in the history
  • Loading branch information
andythsu authored and asu80 committed Feb 3, 2025
1 parent 5fa46ac commit f57f2ad
Show file tree
Hide file tree
Showing 17 changed files with 109 additions and 62 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.gateway.ha.handler;

public record RoutingDestinationInfo(String group, String clusterUri) {}
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,16 @@ public RoutingTargetHandler(
cookiesEnabled = GatewayCookieConfigurationPropertiesProvider.getInstance().isEnabled();
}

public String getRoutingDestination(HttpServletRequest request)
public RoutingDestinationInfo getRoutingDestination(HttpServletRequest request)
{
Optional<String> previousBackend = getPreviousBackend(request);
String clusterHost = previousBackend.orElseGet(() -> getBackendFromRoutingGroup(request));
// This falls back on adhoc routing group if no routing group can be determined
String routingGroup = routingGroupSelector.findRoutingGroup(request).orElse("adhoc");
String user = request.getHeader(USER_HEADER);
String clusterHost = previousBackend.orElseGet(() -> routingManager.provideBackendForRoutingGroup(routingGroup, user));
logRewrite(clusterHost, request);

return buildUriWithNewBackend(clusterHost, request);
return new RoutingDestinationInfo(routingGroup, buildUriWithNewBackend(clusterHost, request));
}

public boolean isPathWhiteListed(String path)
Expand All @@ -87,17 +90,6 @@ public boolean isPathWhiteListed(String path)
|| extraWhitelistPaths.stream().anyMatch(pattern -> pattern.matcher(path).matches());
}

private String getBackendFromRoutingGroup(HttpServletRequest request)
{
String routingGroup = routingGroupSelector.findRoutingGroup(request);
String user = request.getHeader(USER_HEADER);
if (!isNullOrEmpty(routingGroup)) {
// This falls back on adhoc backend if there is no cluster found for the routing group.
return routingManager.provideBackendForRoutingGroup(routingGroup, user);
}
return routingManager.provideAdhocBackend(user);
}

private Optional<String> getPreviousBackend(HttpServletRequest request)
{
Optional<String> queryId = extractQueryIdIfPresent(request, statementPaths, requestAnalyserClientsUseV2Format, requestAnalyserMaxBodySize);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ public record QueryHistory(
@ColumnName("backend_url") String backendUrl,
@ColumnName("user_name") @Nullable String userName,
@ColumnName("source") @Nullable String source,
@ColumnName("created") long created)
@ColumnName("created") long created,
@ColumnName("routing_group") String routingGroup)
{
public QueryHistory
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,10 @@ GROUP BY FLOOR(created / 1000 / 60), backend_url
List<Map<String, Object>> findDistribution(long created);

@SqlUpdate("""
INSERT INTO query_history (query_id, query_text, backend_url, user_name, source, created)
VALUES (:queryId, :queryText, :backendUrl, :userName, :source, :created)
INSERT INTO query_history (query_id, query_text, backend_url, user_name, source, created, routing_group)
VALUES (:queryId, :queryText, :backendUrl, :userName, :source, :created, :routingGroup)
""")
void insertHistory(String queryId, String queryText, String backendUrl, String userName, String source, long created);
void insertHistory(String queryId, String queryText, String backendUrl, String userName, String source, long created, String routingGroup);

@SqlUpdate("""
DELETE FROM query_history
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ public class ExternalRoutingGroupSelector
}

@Override
public String findRoutingGroup(HttpServletRequest servletRequest)
public Optional<String> findRoutingGroup(HttpServletRequest servletRequest)
{
try {
RoutingGroupExternalBody requestBody = createRequestBody(servletRequest);
Expand All @@ -100,13 +100,13 @@ public String findRoutingGroup(HttpServletRequest servletRequest)
else if (response.errors() != null && !response.errors().isEmpty()) {
throw new RuntimeException("Response with error: " + String.join(", ", response.errors()));
}
return response.routingGroup();
return Optional.ofNullable(response.routingGroup());
}
catch (Exception e) {
log.error(e, "Error occurred while retrieving routing group "
+ "from external routing rules processing at " + uri);
}
return servletRequest.getHeader(ROUTING_GROUP_HEADER);
return Optional.ofNullable(servletRequest.getHeader(ROUTING_GROUP_HEADER));
}

private RoutingGroupExternalBody createRequestBody(HttpServletRequest request)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

import static com.google.common.base.Suppliers.memoizeWithExpiration;
import static java.nio.charset.StandardCharsets.UTF_8;
Expand Down Expand Up @@ -60,7 +61,7 @@ public FileBasedRoutingGroupSelector(String rulesPath, Duration rulesRefreshPeri
}

@Override
public String findRoutingGroup(HttpServletRequest request)
public Optional<String> findRoutingGroup(HttpServletRequest request)
{
Map<String, String> result = new HashMap<>();
Map<String, Object> state = new HashMap<>();
Expand All @@ -84,7 +85,7 @@ public String findRoutingGroup(HttpServletRequest request)
rule.evaluateAction(result, data, state);
}
});
return result.get(RESULTS_ROUTING_GROUP_KEY);
return Optional.ofNullable(result.get(RESULTS_ROUTING_GROUP_KEY));
}

public List<RoutingRule> readRulesFromPath(Path rulesPath)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ public void submitQueryDetail(QueryDetail queryDetail)
queryDetail.getBackendUrl(),
queryDetail.getUser(),
queryDetail.getSource(),
queryDetail.getCaptureTime());
queryDetail.getCaptureTime(),
queryDetail.getRoutingGroup());
}

@Override
Expand All @@ -87,6 +88,7 @@ private static List<QueryHistoryManager.QueryDetail> upcast(List<QueryHistory> q
queryDetail.setBackendUrl(dao.backendUrl());
queryDetail.setUser(dao.userName());
queryDetail.setSource(dao.source());
queryDetail.setRoutingGroup(dao.routingGroup());
queryDetails.add(queryDetail);
}
return queryDetails;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class QueryDetail
private String source;
private String backendUrl;
private long captureTime;
private String routingGroup;

public QueryDetail() {}

Expand Down Expand Up @@ -125,6 +126,17 @@ public void setCaptureTime(long captureTime)
this.captureTime = captureTime;
}

@JsonProperty
public String getRoutingGroup()
{
return this.routingGroup;
}

public void setRoutingGroup(String routingGroup)
{
this.routingGroup = routingGroup;
}

@Override
public boolean equals(Object o)
{
Expand All @@ -140,13 +152,14 @@ public boolean equals(Object o)
Objects.equals(queryText, that.queryText) &&
Objects.equals(user, that.user) &&
Objects.equals(source, that.source) &&
Objects.equals(backendUrl, that.backendUrl);
Objects.equals(backendUrl, that.backendUrl) &&
Objects.equals(routingGroup, that.routingGroup);
}

@Override
public int hashCode()
{
return Objects.hash(queryId, queryText, user, source, backendUrl, captureTime);
return Objects.hash(queryId, queryText, user, source, backendUrl, captureTime, routingGroup);
}

@Override
Expand All @@ -159,6 +172,7 @@ public String toString()
.add("source", source)
.add("backendUrl", backendUrl)
.add("captureTime", captureTime)
.add("routingGroup", routingGroup)
.toString();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import io.trino.gateway.ha.config.RulesExternalConfiguration;
import jakarta.servlet.http.HttpServletRequest;

import java.util.Optional;

/**
* RoutingGroupSelector provides a way to match an HTTP request to a Gateway routing group.
*/
Expand All @@ -32,7 +34,7 @@ public interface RoutingGroupSelector
*/
static RoutingGroupSelector byRoutingGroupHeader()
{
return request -> request.getHeader(ROUTING_GROUP_HEADER);
return request -> Optional.ofNullable(request.getHeader(ROUTING_GROUP_HEADER));
}

/**
Expand Down Expand Up @@ -60,5 +62,5 @@ static RoutingGroupSelector byRoutingExternal(
* Given an HTTP request find a routing group to direct the request to. If a routing group cannot
* be determined return null.
*/
String findRoutingGroup(HttpServletRequest request);
Optional<String> findRoutingGroup(HttpServletRequest request);
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import io.trino.gateway.ha.config.GatewayCookieConfigurationPropertiesProvider;
import io.trino.gateway.ha.config.HaGatewayConfiguration;
import io.trino.gateway.ha.config.ProxyResponseConfiguration;
import io.trino.gateway.ha.handler.RoutingDestinationInfo;
import io.trino.gateway.ha.router.GatewayCookie;
import io.trino.gateway.ha.router.OAuth2GatewayCookie;
import io.trino.gateway.ha.router.QueryHistoryManager;
Expand Down Expand Up @@ -118,49 +119,50 @@ public void shutdown()
public void deleteRequest(
HttpServletRequest servletRequest,
AsyncResponse asyncResponse,
URI remoteUri)
RoutingDestinationInfo routingDestinationInfo)
{
Request.Builder request = prepareDelete();
performRequest(remoteUri, servletRequest, asyncResponse, request);
performRequest(routingDestinationInfo, servletRequest, asyncResponse, request);
}

public void getRequest(
HttpServletRequest servletRequest,
AsyncResponse asyncResponse,
URI remoteUri)
RoutingDestinationInfo routingDestinationInfo)
{
Request.Builder request = prepareGet();
performRequest(remoteUri, servletRequest, asyncResponse, request);
performRequest(routingDestinationInfo, servletRequest, asyncResponse, request);
}

public void postRequest(
String statement,
HttpServletRequest servletRequest,
AsyncResponse asyncResponse,
URI remoteUri)
RoutingDestinationInfo routingDestinationInfo)
{
Request.Builder request = preparePost()
.setBodyGenerator(createStaticBodyGenerator(statement, UTF_8));
performRequest(remoteUri, servletRequest, asyncResponse, request);
performRequest(routingDestinationInfo, servletRequest, asyncResponse, request);
}

public void putRequest(
String statement,
HttpServletRequest servletRequest,
AsyncResponse asyncResponse,
URI remoteUri)
RoutingDestinationInfo routingDestinationInfo)
{
Request.Builder request = preparePut()
.setBodyGenerator(createStaticBodyGenerator(statement, UTF_8));
performRequest(remoteUri, servletRequest, asyncResponse, request);
performRequest(routingDestinationInfo, servletRequest, asyncResponse, request);
}

private void performRequest(
URI remoteUri,
RoutingDestinationInfo routingDestinationInfo,
HttpServletRequest servletRequest,
AsyncResponse asyncResponse,
Request.Builder requestBuilder)
{
URI remoteUri = URI.create(routingDestinationInfo.clusterUri());
requestBuilder.setUri(remoteUri);

for (String name : list(servletRequest.getHeaderNames())) {
Expand All @@ -181,6 +183,8 @@ private void performRequest(
ImmutableList.Builder<NewCookie> cookieBuilder = ImmutableList.builder();
cookieBuilder.addAll(getOAuth2GatewayCookie(remoteUri, servletRequest));

requestBuilder.addHeader("X-Routing-Group", routingDestinationInfo.group());

Request request = requestBuilder
.setFollowRedirects(false)
.build();
Expand Down Expand Up @@ -288,6 +292,7 @@ private ProxyResponse recordBackendForQueryId(Request request, ProxyResponse res
else {
log.error("Non OK HTTP Status code with response [%s] , Status code [%s]", response.body(), response.statusCode());
}
queryDetail.setRoutingGroup(request.getHeader("X-Routing-Group"));
queryHistoryManager.submitQueryDetail(queryDetail);
return response;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import com.google.inject.Inject;
import io.trino.gateway.ha.handler.ProxyHandlerStats;
import io.trino.gateway.ha.handler.RoutingDestinationInfo;
import io.trino.gateway.ha.handler.RoutingTargetHandler;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.ws.rs.DELETE;
Expand All @@ -26,8 +27,6 @@
import jakarta.ws.rs.container.Suspended;
import jakarta.ws.rs.core.Context;

import java.net.URI;

import static io.trino.gateway.ha.handler.HttpUtils.V1_STATEMENT_PATH;
import static io.trino.gateway.proxyserver.RouterPreMatchContainerRequestFilter.ROUTE_TO_BACKEND;
import static java.util.Objects.requireNonNull;
Expand Down Expand Up @@ -65,26 +64,26 @@ public void postHandler(
if (multiReadHttpServletRequest.getRequestURI().startsWith(V1_STATEMENT_PATH)) {
proxyHandlerStats.recordRequest();
}
String remoteUri = routingTargetHandler.getRoutingDestination(multiReadHttpServletRequest);
proxyRequestHandler.postRequest(body, multiReadHttpServletRequest, asyncResponse, URI.create(remoteUri));
RoutingDestinationInfo routingDestinationInfo = routingTargetHandler.getRoutingDestination(multiReadHttpServletRequest);
proxyRequestHandler.postRequest(body, multiReadHttpServletRequest, asyncResponse, routingDestinationInfo);
}

@GET
public void getHandler(
@Context HttpServletRequest servletRequest,
@Suspended AsyncResponse asyncResponse)
{
String remoteUri = routingTargetHandler.getRoutingDestination(servletRequest);
proxyRequestHandler.getRequest(servletRequest, asyncResponse, URI.create(remoteUri));
RoutingDestinationInfo routingDestinationInfo = routingTargetHandler.getRoutingDestination(servletRequest);
proxyRequestHandler.getRequest(servletRequest, asyncResponse, routingDestinationInfo);
}

@DELETE
public void deleteHandler(
@Context HttpServletRequest servletRequest,
@Suspended AsyncResponse asyncResponse)
{
String remoteUri = routingTargetHandler.getRoutingDestination(servletRequest);
proxyRequestHandler.deleteRequest(servletRequest, asyncResponse, URI.create(remoteUri));
RoutingDestinationInfo routingDestinationInfo = routingTargetHandler.getRoutingDestination(servletRequest);
proxyRequestHandler.deleteRequest(servletRequest, asyncResponse, routingDestinationInfo);
}

@PUT
Expand All @@ -94,7 +93,7 @@ public void putHandler(
@Suspended AsyncResponse asyncResponse)
{
MultiReadHttpServletRequest multiReadHttpServletRequest = new MultiReadHttpServletRequest(servletRequest, body);
String remoteUri = routingTargetHandler.getRoutingDestination(multiReadHttpServletRequest);
proxyRequestHandler.putRequest(body, multiReadHttpServletRequest, asyncResponse, URI.create(remoteUri));
RoutingDestinationInfo routingDestinationInfo = routingTargetHandler.getRoutingDestination(multiReadHttpServletRequest);
proxyRequestHandler.putRequest(body, multiReadHttpServletRequest, asyncResponse, routingDestinationInfo);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
ALTER TABLE query_history
ADD routing_group VARCHAR(255);
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
ALTER TABLE query_history
ADD routing_group VARCHAR(255);
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
ALTER TABLE query_history
ADD routing_group VARCHAR(255);
Loading

0 comments on commit f57f2ad

Please sign in to comment.