Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Show RoutingGroup info in query history #607

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.gateway.ha.handler;

public record RoutingDestinationInfo(String group, String clusterUri) {}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

name RoutingDestination would be better.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would expand the scope of this change, but it seems like a good time to add some audibility on which rules were triggered. For example, you could have RoutingDestinationInfo(String group, String clusterUri, List<String> rulesFired), where rulesFired would contain the names of all rules who's condition evaluated to true. We can chat about this offline

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can clusterUri be URIinstead of aString`?

Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,16 @@ public RoutingTargetHandler(
cookiesEnabled = GatewayCookieConfigurationPropertiesProvider.getInstance().isEnabled();
}

public String getRoutingDestination(HttpServletRequest request)
public RoutingDestinationInfo getRoutingDestination(HttpServletRequest request)
{
Optional<String> previousBackend = getPreviousBackend(request);
String clusterHost = previousBackend.orElseGet(() -> getBackendFromRoutingGroup(request));
// This falls back on adhoc routing group if no routing group can be determined
String routingGroup = routingGroupSelector.findRoutingGroup(request).orElse("adhoc");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

moving this call out of the orElseGet will cause the routing logic to be triggered for every incoming request, including those containing query IDs. Please revert

String user = request.getHeader(USER_HEADER);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This may not be always true, especially in case of bearer tokens (JWT). But I see that this logic is from before and it's not newly added as a part of the PR.

String clusterHost = previousBackend.orElseGet(() -> routingManager.provideBackendForRoutingGroup(routingGroup, user));
logRewrite(clusterHost, request);

return buildUriWithNewBackend(clusterHost, request);
return new RoutingDestinationInfo(routingGroup, buildUriWithNewBackend(clusterHost, request));
}

public boolean isPathWhiteListed(String path)
Expand All @@ -87,17 +90,6 @@ public boolean isPathWhiteListed(String path)
|| extraWhitelistPaths.stream().anyMatch(pattern -> pattern.matcher(path).matches());
}

private String getBackendFromRoutingGroup(HttpServletRequest request)
{
String routingGroup = routingGroupSelector.findRoutingGroup(request);
String user = request.getHeader(USER_HEADER);
if (!isNullOrEmpty(routingGroup)) {
// This falls back on adhoc backend if there is no cluster found for the routing group.
return routingManager.provideBackendForRoutingGroup(routingGroup, user);
}
return routingManager.provideAdhocBackend(user);
}

private Optional<String> getPreviousBackend(HttpServletRequest request)
{
Optional<String> queryId = extractQueryIdIfPresent(request, statementPaths, requestAnalyserClientsUseV2Format, requestAnalyserMaxBodySize);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ public record QueryHistory(
@ColumnName("backend_url") String backendUrl,
@ColumnName("user_name") @Nullable String userName,
@ColumnName("source") @Nullable String source,
@ColumnName("created") long created)
@ColumnName("created") long created,
@ColumnName("routing_group") String routingGroup)
{
public QueryHistory
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,10 @@ GROUP BY FLOOR(created / 1000 / 60), backend_url
List<Map<String, Object>> findDistribution(long created);

@SqlUpdate("""
INSERT INTO query_history (query_id, query_text, backend_url, user_name, source, created)
VALUES (:queryId, :queryText, :backendUrl, :userName, :source, :created)
INSERT INTO query_history (query_id, query_text, backend_url, user_name, source, created, routing_group)
VALUES (:queryId, :queryText, :backendUrl, :userName, :source, :created, :routingGroup)
""")
void insertHistory(String queryId, String queryText, String backendUrl, String userName, String source, long created);
void insertHistory(String queryId, String queryText, String backendUrl, String userName, String source, long created, String routingGroup);

@SqlUpdate("""
DELETE FROM query_history
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ public class ExternalRoutingGroupSelector
}

@Override
public String findRoutingGroup(HttpServletRequest servletRequest)
public Optional<String> findRoutingGroup(HttpServletRequest servletRequest)
{
try {
RoutingGroupExternalBody requestBody = createRequestBody(servletRequest);
Expand All @@ -100,13 +100,13 @@ public String findRoutingGroup(HttpServletRequest servletRequest)
else if (response.errors() != null && !response.errors().isEmpty()) {
throw new RuntimeException("Response with error: " + String.join(", ", response.errors()));
}
return response.routingGroup();
return Optional.ofNullable(response.routingGroup());
}
catch (Exception e) {
log.error(e, "Error occurred while retrieving routing group "
+ "from external routing rules processing at " + uri);
}
return servletRequest.getHeader(ROUTING_GROUP_HEADER);
return Optional.ofNullable(servletRequest.getHeader(ROUTING_GROUP_HEADER));
}

private RoutingGroupExternalBody createRequestBody(HttpServletRequest request)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

import static com.google.common.base.Suppliers.memoizeWithExpiration;
import static java.nio.charset.StandardCharsets.UTF_8;
Expand Down Expand Up @@ -60,7 +61,7 @@ public FileBasedRoutingGroupSelector(String rulesPath, Duration rulesRefreshPeri
}

@Override
public String findRoutingGroup(HttpServletRequest request)
public Optional<String> findRoutingGroup(HttpServletRequest request)
{
Map<String, String> result = new HashMap<>();
Map<String, Object> state = new HashMap<>();
Expand All @@ -84,7 +85,7 @@ public String findRoutingGroup(HttpServletRequest request)
rule.evaluateAction(result, data, state);
}
});
return result.get(RESULTS_ROUTING_GROUP_KEY);
return Optional.ofNullable(result.get(RESULTS_ROUTING_GROUP_KEY));
}

public List<RoutingRule> readRulesFromPath(Path rulesPath)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ public void submitQueryDetail(QueryDetail queryDetail)
queryDetail.getBackendUrl(),
queryDetail.getUser(),
queryDetail.getSource(),
queryDetail.getCaptureTime());
queryDetail.getCaptureTime(),
queryDetail.getRoutingGroup());
}

@Override
Expand All @@ -87,6 +88,7 @@ private static List<QueryHistoryManager.QueryDetail> upcast(List<QueryHistory> q
queryDetail.setBackendUrl(dao.backendUrl());
queryDetail.setUser(dao.userName());
queryDetail.setSource(dao.source());
queryDetail.setRoutingGroup(dao.routingGroup());
queryDetails.add(queryDetail);
}
return queryDetails;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class QueryDetail
private String source;
private String backendUrl;
private long captureTime;
private String routingGroup;

public QueryDetail() {}

Expand Down Expand Up @@ -125,6 +126,17 @@ public void setCaptureTime(long captureTime)
this.captureTime = captureTime;
}

@JsonProperty
public String getRoutingGroup()
{
return this.routingGroup;
}

public void setRoutingGroup(String routingGroup)
{
this.routingGroup = routingGroup;
}

@Override
public boolean equals(Object o)
{
Expand All @@ -140,13 +152,14 @@ public boolean equals(Object o)
Objects.equals(queryText, that.queryText) &&
Objects.equals(user, that.user) &&
Objects.equals(source, that.source) &&
Objects.equals(backendUrl, that.backendUrl);
Objects.equals(backendUrl, that.backendUrl) &&
Objects.equals(routingGroup, that.routingGroup);
}

@Override
public int hashCode()
{
return Objects.hash(queryId, queryText, user, source, backendUrl, captureTime);
return Objects.hash(queryId, queryText, user, source, backendUrl, captureTime, routingGroup);
}

@Override
Expand All @@ -159,6 +172,7 @@ public String toString()
.add("source", source)
.add("backendUrl", backendUrl)
.add("captureTime", captureTime)
.add("routingGroup", routingGroup)
.toString();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import io.trino.gateway.ha.config.RulesExternalConfiguration;
import jakarta.servlet.http.HttpServletRequest;

import java.util.Optional;

/**
* RoutingGroupSelector provides a way to match an HTTP request to a Gateway routing group.
*/
Expand All @@ -32,7 +34,7 @@ public interface RoutingGroupSelector
*/
static RoutingGroupSelector byRoutingGroupHeader()
{
return request -> request.getHeader(ROUTING_GROUP_HEADER);
return request -> Optional.ofNullable(request.getHeader(ROUTING_GROUP_HEADER));
}

/**
Expand Down Expand Up @@ -60,5 +62,5 @@ static RoutingGroupSelector byRoutingExternal(
* Given an HTTP request find a routing group to direct the request to. If a routing group cannot
* be determined return null.
*/
String findRoutingGroup(HttpServletRequest request);
Optional<String> findRoutingGroup(HttpServletRequest request);
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import io.trino.gateway.ha.config.GatewayCookieConfigurationPropertiesProvider;
import io.trino.gateway.ha.config.HaGatewayConfiguration;
import io.trino.gateway.ha.config.ProxyResponseConfiguration;
import io.trino.gateway.ha.handler.RoutingDestinationInfo;
import io.trino.gateway.ha.router.GatewayCookie;
import io.trino.gateway.ha.router.OAuth2GatewayCookie;
import io.trino.gateway.ha.router.QueryHistoryManager;
Expand Down Expand Up @@ -118,49 +119,50 @@ public void shutdown()
public void deleteRequest(
HttpServletRequest servletRequest,
AsyncResponse asyncResponse,
URI remoteUri)
RoutingDestinationInfo routingDestinationInfo)
{
Request.Builder request = prepareDelete();
performRequest(remoteUri, servletRequest, asyncResponse, request);
performRequest(routingDestinationInfo, servletRequest, asyncResponse, request);
}

public void getRequest(
HttpServletRequest servletRequest,
AsyncResponse asyncResponse,
URI remoteUri)
RoutingDestinationInfo routingDestinationInfo)
{
Request.Builder request = prepareGet();
performRequest(remoteUri, servletRequest, asyncResponse, request);
performRequest(routingDestinationInfo, servletRequest, asyncResponse, request);
}

public void postRequest(
String statement,
HttpServletRequest servletRequest,
AsyncResponse asyncResponse,
URI remoteUri)
RoutingDestinationInfo routingDestinationInfo)
{
Request.Builder request = preparePost()
.setBodyGenerator(createStaticBodyGenerator(statement, UTF_8));
performRequest(remoteUri, servletRequest, asyncResponse, request);
performRequest(routingDestinationInfo, servletRequest, asyncResponse, request);
}

public void putRequest(
String statement,
HttpServletRequest servletRequest,
AsyncResponse asyncResponse,
URI remoteUri)
RoutingDestinationInfo routingDestinationInfo)
{
Request.Builder request = preparePut()
.setBodyGenerator(createStaticBodyGenerator(statement, UTF_8));
performRequest(remoteUri, servletRequest, asyncResponse, request);
performRequest(routingDestinationInfo, servletRequest, asyncResponse, request);
}

private void performRequest(
URI remoteUri,
RoutingDestinationInfo routingDestinationInfo,
HttpServletRequest servletRequest,
AsyncResponse asyncResponse,
Request.Builder requestBuilder)
{
URI remoteUri = URI.create(routingDestinationInfo.clusterUri());
requestBuilder.setUri(remoteUri);

for (String name : list(servletRequest.getHeaderNames())) {
Expand All @@ -181,6 +183,8 @@ private void performRequest(
ImmutableList.Builder<NewCookie> cookieBuilder = ImmutableList.builder();
cookieBuilder.addAll(getOAuth2GatewayCookie(remoteUri, servletRequest));

requestBuilder.addHeader("X-Routing-Group", routingDestinationInfo.group());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you pass the routing info directly to recordBackendForQueryId in the future.transform call? If not, then use ROUTING_GROUP_HEADER from RoutingGroupSelector as the header key


Request request = requestBuilder
.setFollowRedirects(false)
.build();
Expand Down Expand Up @@ -288,6 +292,7 @@ private ProxyResponse recordBackendForQueryId(Request request, ProxyResponse res
else {
log.error("Non OK HTTP Status code with response [%s] , Status code [%s]", response.body(), response.statusCode());
}
queryDetail.setRoutingGroup(request.getHeader("X-Routing-Group"));
queryHistoryManager.submitQueryDetail(queryDetail);
return response;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import com.google.inject.Inject;
import io.trino.gateway.ha.handler.ProxyHandlerStats;
import io.trino.gateway.ha.handler.RoutingDestinationInfo;
import io.trino.gateway.ha.handler.RoutingTargetHandler;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.ws.rs.DELETE;
Expand All @@ -26,8 +27,6 @@
import jakarta.ws.rs.container.Suspended;
import jakarta.ws.rs.core.Context;

import java.net.URI;

import static io.trino.gateway.ha.handler.HttpUtils.V1_STATEMENT_PATH;
import static io.trino.gateway.proxyserver.RouterPreMatchContainerRequestFilter.ROUTE_TO_BACKEND;
import static java.util.Objects.requireNonNull;
Expand Down Expand Up @@ -65,26 +64,26 @@ public void postHandler(
if (multiReadHttpServletRequest.getRequestURI().startsWith(V1_STATEMENT_PATH)) {
proxyHandlerStats.recordRequest();
}
String remoteUri = routingTargetHandler.getRoutingDestination(multiReadHttpServletRequest);
proxyRequestHandler.postRequest(body, multiReadHttpServletRequest, asyncResponse, URI.create(remoteUri));
RoutingDestinationInfo routingDestinationInfo = routingTargetHandler.getRoutingDestination(multiReadHttpServletRequest);
proxyRequestHandler.postRequest(body, multiReadHttpServletRequest, asyncResponse, routingDestinationInfo);
}

@GET
public void getHandler(
@Context HttpServletRequest servletRequest,
@Suspended AsyncResponse asyncResponse)
{
String remoteUri = routingTargetHandler.getRoutingDestination(servletRequest);
proxyRequestHandler.getRequest(servletRequest, asyncResponse, URI.create(remoteUri));
RoutingDestinationInfo routingDestinationInfo = routingTargetHandler.getRoutingDestination(servletRequest);
proxyRequestHandler.getRequest(servletRequest, asyncResponse, routingDestinationInfo);
}

@DELETE
public void deleteHandler(
@Context HttpServletRequest servletRequest,
@Suspended AsyncResponse asyncResponse)
{
String remoteUri = routingTargetHandler.getRoutingDestination(servletRequest);
proxyRequestHandler.deleteRequest(servletRequest, asyncResponse, URI.create(remoteUri));
RoutingDestinationInfo routingDestinationInfo = routingTargetHandler.getRoutingDestination(servletRequest);
proxyRequestHandler.deleteRequest(servletRequest, asyncResponse, routingDestinationInfo);
}

@PUT
Expand All @@ -94,7 +93,7 @@ public void putHandler(
@Suspended AsyncResponse asyncResponse)
{
MultiReadHttpServletRequest multiReadHttpServletRequest = new MultiReadHttpServletRequest(servletRequest, body);
String remoteUri = routingTargetHandler.getRoutingDestination(multiReadHttpServletRequest);
proxyRequestHandler.putRequest(body, multiReadHttpServletRequest, asyncResponse, URI.create(remoteUri));
RoutingDestinationInfo routingDestinationInfo = routingTargetHandler.getRoutingDestination(multiReadHttpServletRequest);
proxyRequestHandler.putRequest(body, multiReadHttpServletRequest, asyncResponse, routingDestinationInfo);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
ALTER TABLE query_history
ADD routing_group VARCHAR(255);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: end the file with a newline. This goes for the other .sql files as well

Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
ALTER TABLE query_history
ADD routing_group VARCHAR(255);
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
ALTER TABLE query_history
ADD routing_group VARCHAR(255);
Loading