Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,21 @@ public String updateApiCollectionNameForVxlan() {
return Action.SUCCESS.toUpperCase();
}

// Bound from the JSON request body by the struts "json" interceptor.
String transportType;

/**
 * Struts action: persists the MCP transport type for the collection identified
 * by {@code apiCollectionId}. Silently succeeds when the collection is unknown.
 *
 * @return {@code "SUCCESS"} on completion (including unknown collection),
 *         {@code "ERROR"} when the lookup/update throws.
 */
public String updateTransportType() {
    try {
        ApiCollection apiCollection = ApiCollectionsDao.instance.getMeta(apiCollectionId);
        if (apiCollection == null) {
            // Unknown collection id: nothing to update, treated as success.
            return Action.SUCCESS.toUpperCase();
        }
        ApiCollectionsDao.instance.updateTransportType(apiCollection, transportType);
        return Action.SUCCESS.toUpperCase();
    } catch (Exception e) {
        loggerMaker.errorAndAddToDb(e, "error in updateTransportType " + e.toString());
        return Action.ERROR.toUpperCase();
    }
}


public String updateCidrList() {
try {
DbLayer.updateCidrList(cidrList);
Expand Down Expand Up @@ -414,9 +429,9 @@ public String bulkWriteSti() {
}
}
}

System.out.println("filters: " + filters.toString());

if (isDeleteWrite) {
writes.add(
new DeleteOneModel<>(Filters.and(filters), new DeleteOptions())
Expand All @@ -427,7 +442,7 @@ public String bulkWriteSti() {
);
}
}

DbLayer.bulkWriteSingleTypeInfo(writes);
} catch (Exception e) {
String err = "Error: ";
Expand Down Expand Up @@ -466,15 +481,15 @@ public String bulkWriteSampleData() {

String responseCodeStr = mObj.get("responseCode").toString();
int responseCode = Integer.valueOf(responseCodeStr);

Bson filters = Filters.and(Filters.eq("_id.apiCollectionId", apiCollectionId),
Filters.eq("_id.bucketEndEpoch", bucketEndEpoch),
Filters.eq("_id.bucketStartEpoch", bucketStartEpoch),
Filters.eq("_id.method", mObj.get("method")),
Filters.eq("_id.responseCode", responseCode),
Filters.eq("_id.url", mObj.get("url")));
List<String> updatePayloadList = bulkUpdate.getUpdates();

List<Bson> updates = new ArrayList<>();
for (String payload: updatePayloadList) {
Map<String, Object> json = gson.fromJson(payload, Map.class);
Expand Down Expand Up @@ -836,23 +851,23 @@ public String bulkWriteOverageInfo() {
Filters.eq(UningestedApiOverage.METHOD, bulkUpdate.getFilters().get(UningestedApiOverage.METHOD)),
Filters.eq(UningestedApiOverage.URL, bulkUpdate.getFilters().get(UningestedApiOverage.URL))
);

List<String> updatePayloadList = bulkUpdate.getUpdates();
List<Bson> updates = new ArrayList<>();

for (String payload: updatePayloadList) {
Map<String, Object> json = gson.fromJson(payload, Map.class);
String field = (String) json.get("field");
Object val = json.get("val");
String op = (String) json.get("op");

if ("setOnInsert".equals(op)) {
updates.add(Updates.setOnInsert(field, val));
} else if ("set".equals(op)) {
updates.add(Updates.set(field, val));
}
}

if (!updates.isEmpty()) {
writes.add(
new UpdateOneModel<>(filters, Updates.combine(updates), new UpdateOptions().upsert(true))
Expand Down Expand Up @@ -1129,7 +1144,7 @@ public String findPendingTestingRun() {
testingRun = DbLayer.findPendingTestingRun(delta);
if (testingRun != null) {
/*
* There is a db call involved for collectionWiseTestingEndpoints, thus this hack.
* There is a db call involved for collectionWiseTestingEndpoints, thus this hack.
*/
if(testingRun.getTestingEndpoints() instanceof CollectionWiseTestingEndpoints){
CollectionWiseTestingEndpoints ts = (CollectionWiseTestingEndpoints) testingRun.getTestingEndpoints();
Expand Down Expand Up @@ -1790,7 +1805,7 @@ public String fetchTestScript() {
return Action.ERROR.toUpperCase();
}
}

public String countTestingRunResultSummaries() {
count = DbLayer.countTestingRunResultSummaries(filter);
return Action.SUCCESS.toUpperCase();
Expand Down Expand Up @@ -2773,4 +2788,11 @@ public void setTestingRunPlayground(TestingRunPlayground testingRunPlayground) {
this.testingRunPlayground = testingRunPlayground;
}

// Accessors for the JSON-bound transportType request field (used by the
// api/updateTransportType struts action).
public String getTransportType() {
    return transportType;
}

public void setTransportType(String transportType) {
    this.transportType = transportType;
}
}
14 changes: 13 additions & 1 deletion apps/database-abstractor/src/main/resources/struts.xml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,18 @@
</result>
</action>

<!-- Persists the MCP transport type for an API collection.
     Maps to DbAction.updateTransportType; expects apiCollectionId and
     transportType in the JSON body. Returns 422 with actionErrors on failure. -->
<action name="api/updateTransportType" class="com.akto.action.DbAction" method="updateTransportType">
    <interceptor-ref name="json"/>
    <interceptor-ref name="defaultStack" />
    <result name="SUCCESS" type="json"/>
    <result name="ERROR" type="json">
        <param name="statusCode">422</param>
        <param name="ignoreHierarchy">false</param>
        <param name="includeProperties">^actionErrors.*</param>
    </result>
</action>


<action name="api/updateCidrList" class="com.akto.action.DbAction" method="updateCidrList">
<interceptor-ref name="json"/>
<interceptor-ref name="defaultStack" />
Expand Down Expand Up @@ -1200,7 +1212,7 @@
<param name="includeProperties">^actionErrors.*</param>
</result>
</action>

<action name="api/countTestingRunResultSummaries" class="com.akto.action.DbAction" method="countTestingRunResultSummaries">
<interceptor-ref name="json"/>
<interceptor-ref name="defaultStack" />
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.akto.hybrid_runtime;

import com.akto.dao.context.Context;
import com.akto.data_actor.DataActor;
import com.akto.data_actor.DataActorFactory;
import com.akto.dto.APIConfig;
import com.akto.dto.AccountSettings;
Expand Down Expand Up @@ -32,6 +33,7 @@
import com.akto.testing.ApiExecutor;
import com.akto.util.Constants;
import com.akto.util.JSONUtils;
import com.akto.util.McpSseEndpointHelper;
import com.akto.util.Pair;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.module.SimpleModule;
Expand All @@ -57,11 +59,17 @@ public class McpToolsSyncJobExecutor {

private static final LoggerMaker logger = new LoggerMaker(McpToolsSyncJobExecutor.class, LogDb.RUNTIME);
private static final ObjectMapper mapper = new ObjectMapper();
public static final DataActor dataActor = DataActorFactory.fetchInstance();
public static final String MCP_TOOLS_LIST_REQUEST_JSON =
"{\"jsonrpc\": \"2.0\", \"id\": 1, \"method\": \"" + McpSchema.METHOD_TOOLS_LIST + "\", \"params\": {}}";
public static final String MCP_RESOURCE_LIST_REQUEST_JSON =
"{\"jsonrpc\": \"2.0\", \"id\": 1, \"method\": \"" + McpSchema.METHOD_RESOURCES_LIST + "\", \"params\": {}}";
public static final String LOCAL_IP = "127.0.0.1";

// MCP Transport types
private static final String TRANSPORT_SSE = "SSE";
private static final String TRANSPORT_HTTP = "HTTP";

private ServerCapabilities mcpServerCapabilities = null;

public static final McpToolsSyncJobExecutor INSTANCE = new McpToolsSyncJobExecutor();
Expand Down Expand Up @@ -310,7 +318,7 @@ public static Map<String, Object> generateExampleArguments(JsonSchema inputSchem
public Pair<JSONRPCResponse, HttpResponseParams> getMcpMethodResponse(String host, String mcpMethod,
String mcpMethodRequestJson, ApiCollection apiCollection) throws Exception {
OriginalHttpRequest mcpRequest = createRequest(host, mcpMethod, mcpMethodRequestJson);
String jsonrpcResponse = sendRequest(mcpRequest);
String jsonrpcResponse = sendRequest(mcpRequest, apiCollection);

JSONRPCResponse rpcResponse = (JSONRPCResponse) McpSchema.deserializeJsonRpcMessage(mapper, jsonrpcResponse);

Expand Down Expand Up @@ -347,13 +355,61 @@ private String buildHeaders(String host) {
return "{\"Content-Type\":\"application/json\",\"Accept\":\"*/*\",\"host\":\"" + host + "\"}";
}

private String sendRequest(OriginalHttpRequest request) throws Exception {

private String detectAndSetTransportType(OriginalHttpRequest request, ApiCollection apiCollection) throws Exception {
// Try SSE first if sseCallbackUrl is set
if (apiCollection.getSseCallbackUrl() != null && !apiCollection.getSseCallbackUrl().isEmpty()) {
try {
logger.info("Attempting to detect transport type for MCP server: {}", apiCollection.getHostName());

// Clone request for SSE detection to avoid modifying original
OriginalHttpRequest sseTestRequest = request.copy();
McpSseEndpointHelper.addSseEndpointHeader(sseTestRequest, apiCollection.getId());
ApiExecutor.sendRequestWithSse(sseTestRequest, true, null, false,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally we should first check for Streamable support and then fallback to SSE as SSE is deprecated.

new ArrayList<>(), false, true);

// If SSE works, update the collection
dataActor.updateTransportType(apiCollection.getId(), TRANSPORT_SSE);
logger.info("Detected SSE transport for MCP server: {}", apiCollection.getHostName());
return TRANSPORT_SSE;
} catch (Exception sseException) {
logger.info("SSE transport failed, falling back to HTTP transport: {}", sseException.getMessage());
// Fall back to HTTP - no need to test, just store it
dataActor.updateTransportType(apiCollection.getId(), TRANSPORT_SSE);
return TRANSPORT_HTTP;
}
}

// Default to HTTP if no sseCallbackUrl
logger.info("No SSE callback URL found, using HTTP transport for: {}", apiCollection.getHostName());
dataActor.updateTransportType(apiCollection.getId(), TRANSPORT_SSE);
return TRANSPORT_HTTP;
}
/**
 * Sends an MCP JSON-RPC request using the collection's stored transport type,
 * detecting (and persisting) the transport first if it is not yet known.
 *
 * @param request       the request to send (the SSE path adds an endpoint header to it)
 * @param apiCollection the target collection; supplies the transport type and id
 * @return the response body returned by the MCP server
 * @throws Exception any failure from transport detection or the request itself
 */
private String sendRequest(OriginalHttpRequest request, ApiCollection apiCollection) throws Exception {
    try {
        String transportType = apiCollection.getMcpTransportType();

        // If transport type is not set, try to detect it
        if (transportType == null || transportType.isEmpty()) {
            transportType = detectAndSetTransportType(request, apiCollection);
        }

        OriginalHttpResponse response;
        if (TRANSPORT_HTTP.equals(transportType)) {
            // Standard HTTP POST; sendRequestSkipSse prevents ApiExecutor from trying SSE
            logger.info("Using HTTP transport for MCP server: {}", apiCollection.getHostName());
            response = ApiExecutor.sendRequestSkipSse(request, true, null, false, new ArrayList<>(), false);
        } else {
            // SSE transport: the SSE endpoint header must be attached before sending
            logger.info("Using SSE transport for MCP server: {}", apiCollection.getHostName());
            McpSseEndpointHelper.addSseEndpointHeader(request, apiCollection.getId());
            response = ApiExecutor.sendRequestWithSse(request, true, null, false,
                new ArrayList<>(), false, true);
        }
        return response.getBody();
    } catch (Exception e) {
        logger.error("Error while making request to MCP server: {}", e.getMessage(), e);
        throw e;
    }
}
Expand Down
8 changes: 8 additions & 0 deletions libs/dao/src/main/java/com/akto/dao/ApiCollectionsDao.java
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,14 @@ public Map<Integer, ApiCollection> getApiCollectionsMetaMap() {
return apiCollectionsMap;
}

/**
 * Persists the MCP transport type for the given collection and mirrors the
 * new value on the in-memory object so callers see a consistent state.
 *
 * @param apiCollection the collection to update; ignored when {@code null}
 *                      (callers such as DbActor pass getMeta() results unchecked)
 * @param transportType the transport type value to store
 */
public void updateTransportType(ApiCollection apiCollection, String transportType) {
    if (apiCollection == null) {
        // Robustness fix: avoid NPE when the collection lookup returned null.
        return;
    }
    Bson filter = Filters.eq(ApiCollection.ID, apiCollection.getId());
    Bson update = Updates.set(ApiCollection.MCP_TRANSPORT_TYPE, transportType);
    ApiCollectionsDao.instance.updateOne(filter, update);
    apiCollection.setMcpTransportType(transportType);
}


public List<ApiCollection> getMetaAll() {
return ApiCollectionsDao.instance.findAll(new BasicDBObject(), Projections.exclude("urls"));
}
Expand Down
13 changes: 12 additions & 1 deletion libs/dao/src/main/java/com/akto/dto/ApiCollection.java
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,17 @@ public class ApiCollection {
String sseCallbackUrl;
public static final String SSE_CALLBACK_URL = "sseCallbackUrl";

// Transport used to talk to this MCP server ("SSE" or "HTTP" — see
// McpToolsSyncJobExecutor); null/empty until detected.
private String mcpTransportType;
public static final String MCP_TRANSPORT_TYPE = "mcpTransportType";

public String getMcpTransportType() {
    return mcpTransportType;
}

public void setMcpTransportType(String mcpTransportType) {
    this.mcpTransportType = mcpTransportType;
}

public enum Type {
API_GROUP
}
Expand Down Expand Up @@ -91,7 +102,7 @@ public enum ENV_TYPE {
"localhost", "local", "intranet", "lan", "example", "invalid",
"home", "corp", "priv", "localdomain", "localnet", "network",
"int", "private");

private static final List<String> ENV_KEYWORDS_WITHOUT_DOT = Arrays.asList(
"kubernetes", "internal"
);
Expand Down
19 changes: 19 additions & 0 deletions libs/utils/src/main/java/com/akto/data_actor/ClientActor.java
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,25 @@ public void updateModuleInfo(ModuleInfo moduleInfo) {
}
}

/**
 * Persists the MCP transport type for a collection by POSTing to the
 * database-abstractor's /updateTransportType endpoint. Failures are logged
 * and swallowed (best-effort, matching the other update calls in this class).
 *
 * @param apiCollectionId id of the collection being updated
 * @param transportType   transport type value to store
 */
@Override
public void updateTransportType(int apiCollectionId, String transportType) {
    Map<String, List<String>> headers = buildHeaders();
    BasicDBObject obj = new BasicDBObject();
    obj.put("apiCollectionId", apiCollectionId);
    obj.put("transportType", transportType);
    OriginalHttpRequest request = new OriginalHttpRequest(url + "/updateTransportType", "", "POST", obj.toString(), headers, "");
    try {
        OriginalHttpResponse response = ApiExecutor.sendRequest(request, true, null, false, null);
        if (response.getStatusCode() != 200) {
            // Message fix: the check is for exactly 200, not the whole 2xx class.
            loggerMaker.errorAndAddToDb("non 200 response in updateTransportType", LoggerMaker.LogDb.RUNTIME);
        }
    } catch (Exception e) {
        // Message fix: added missing space between the text and the exception.
        loggerMaker.errorAndAddToDb("error updating transport type " + e + " apiCollectionId " + apiCollectionId
                + " transportType " + transportType, LoggerMaker.LogDb.RUNTIME);
    }
}

public APIConfig fetchApiConfig(String configName) {
Map<String, List<String>> headers = buildHeaders();
String queryParams = "?configName="+configName;
Expand Down
2 changes: 2 additions & 0 deletions libs/utils/src/main/java/com/akto/data_actor/DataActor.java
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ public abstract class DataActor {

public abstract void updateApiCollectionNameForVxlan(int vxlanId, String name);

/**
 * Persists the MCP transport type for the given API collection
 * (implemented against the DB directly in DbActor, over HTTP in ClientActor).
 */
public abstract void updateTransportType(int apiCollectionId, String transportType);

public abstract APIConfig fetchApiConfig(String configName);

public abstract void bulkWriteSingleTypeInfo(List<Object> writesForApiInfo);
Expand Down
6 changes: 6 additions & 0 deletions libs/utils/src/main/java/com/akto/data_actor/DbActor.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.akto.data_actor;

import com.akto.dao.ApiCollectionsDao;
import com.akto.dto.*;
import com.akto.dto.ApiInfo.ApiInfoKey;
import com.akto.dto.billing.Organization;
Expand Down Expand Up @@ -70,6 +71,11 @@ public void updateApiCollectionNameForVxlan(int vxlanId, String name) {
DbLayer.updateApiCollectionName(vxlanId, name);
}

/**
 * Persists the MCP transport type for the given collection directly via the DAO.
 * No-ops when the collection id is unknown.
 *
 * @param apiCollectionId id of the collection being updated
 * @param transportType   transport type value to store
 */
@Override
public void updateTransportType(int apiCollectionId, String transportType) {
    ApiCollection apiCollection = ApiCollectionsDao.instance.getMeta(apiCollectionId);
    if (apiCollection == null) {
        // BUG FIX: getMeta can return null (DbAction checks for this); previously
        // the null flowed straight into updateTransportType and caused an NPE.
        return;
    }
    ApiCollectionsDao.instance.updateTransportType(apiCollection, transportType);
}

public APIConfig fetchApiConfig(String configName) {
return DbLayer.fetchApiconfig(configName);
}
Expand Down
10 changes: 9 additions & 1 deletion libs/utils/src/main/java/com/akto/testing/ApiExecutor.java
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,8 @@ private static OriginalHttpResponse sendWithRequestBody(OriginalHttpRequest requ

if (payload == null) payload = "";
if (body == null) {// body not created by GRPC block yet
if (request.getHeaders().containsKey("charset")) {
// Create body with null MediaType for JSON-RPC requests to prevent OkHttp from adding charset parameter
if (request.getHeaders().containsKey("charset") || isJsonRpcRequest(request)) {
body = RequestBody.create(payload, null);
request.getHeaders().remove("charset");
} else {
Expand Down Expand Up @@ -664,6 +665,13 @@ private static String waitForMatchingSseMessage(SseSession session, String id, l
throw new Exception("Timeout waiting for SSE message with id=" + id);
}


/**
 * Sends a request with SSE handling explicitly disabled (delegates to
 * sendRequest with skipSse=true). Used for MCP servers on the HTTP transport
 * so ApiExecutor never attempts an SSE upgrade.
 */
public static OriginalHttpResponse sendRequestSkipSse(OriginalHttpRequest request, boolean followRedirects,
    TestingRunConfig testingRunConfig, boolean debug, List<TestingRunResult.TestLog> testLogs,
    boolean skipSSRFCheck) throws Exception {
    return sendRequest(request, followRedirects, testingRunConfig, debug, testLogs, skipSSRFCheck, true);
}

public static OriginalHttpResponse sendRequestWithSse(OriginalHttpRequest request, boolean followRedirects,
TestingRunConfig testingRunConfig, boolean debug, List<TestingRunResult.TestLog> testLogs,
boolean skipSSRFCheck, boolean overrideMessageEndpoint) throws Exception {
Expand Down
Loading