Skip to content

Cache matching endpoints in context #565

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion library/agent/Context.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { extractStringsFromUserInput } from "../helpers/extractStringsFromUserIn
import { ContextStorage } from "./context/ContextStorage";
import { AsyncResource } from "async_hooks";
import { Source, SOURCES } from "./Source";
import type { Endpoint } from "./Config";

export type User = { id: string; name?: string };

Expand Down Expand Up @@ -30,6 +31,7 @@ export type Context = {
*/
outgoingRequestRedirects?: { source: URL; destination: URL }[];
executedMiddleware?: boolean;
cachedMatchingEndpoints?: Endpoint[];
};

/**
Expand All @@ -56,6 +58,11 @@ export function updateContext<K extends keyof Context>(
if (context.cache && isSourceKey(key)) {
context.cache.delete(key);
}

// Delete cacheMatchingEndpoints if method, url, or route changes
if (key === "method" || key === "url" || key === "route") {
delete context.cachedMatchingEndpoints;
}
}

/**
Expand All @@ -69,8 +76,17 @@ export function runWithContext<T>(context: Context, fn: () => T) {
const current = ContextStorage.getStore();

// If there is already a context, we just update it
// In this way we don't lose the `attackDetected` flag
// In this way we don't lose the `attackDetected` flag or other properties like `markUnsafe` or `xml`
if (current) {
if (
current.method !== context.method ||
current.url !== context.url ||
current.route !== context.route
) {
// If the method, url, or route changes, we need to delete the cacheMatchingEndpoints
delete current.cachedMatchingEndpoints;
}

current.url = context.url;
current.method = context.method;
current.query = context.query;
Expand Down
144 changes: 123 additions & 21 deletions library/agent/ServiceConfig.test.ts
Original file line number Diff line number Diff line change
@@ -1,17 +1,36 @@
import * as t from "tap";
import { ServiceConfig } from "./ServiceConfig";
import {
getContext,
runWithContext,
updateContext,
type Context,
} from "./Context";

const getTestContext = (
url: string | undefined,
method: string | undefined,
route: string | undefined
): Context => ({
url,
method,
route,
query: {},
headers: {},
routeParams: {},
remoteAddress: undefined,
body: undefined,
cookies: {},
source: "http",
});

t.test("it returns false if empty rules", async () => {
const config = new ServiceConfig([], 0, [], [], false, [], []);
t.same(config.getLastUpdatedAt(), 0);
t.same(config.isUserBlocked("id"), false);
t.same(config.isBypassedIP("1.2.3.4"), false);
t.same(
config.getEndpoints({
url: undefined,
method: undefined,
route: undefined,
}),
config.getEndpoints(getTestContext(undefined, undefined, undefined)),
[]
);
});
Expand Down Expand Up @@ -60,25 +79,18 @@ t.test("it works", async () => {

t.same(config.isUserBlocked("123"), true);
t.same(config.isUserBlocked("567"), false);
t.same(
config.getEndpoints({
url: undefined,
t.same(config.getEndpoints(getTestContext("/foo", "GET", "/foo")), [
{
method: "GET",
route: "/foo",
}),
[
{
method: "GET",
route: "/foo",
forceProtectionOff: false,
rateLimiting: {
enabled: false,
maxRequests: 0,
windowSizeInMS: 0,
},
forceProtectionOff: false,
rateLimiting: {
enabled: false,
maxRequests: 0,
windowSizeInMS: 0,
},
]
);
},
]);
});

t.test("it checks if IP is bypassed", async () => {
Expand Down Expand Up @@ -254,3 +266,93 @@ t.test("bypassed ips support cidr", async () => {
t.same(config.isBypassedIP("123.123.123.1"), false);
t.same(config.isBypassedIP("999.999.999.999"), false);
});

t.test("matching endpoints are cached", async () => {
const config = new ServiceConfig(
[
{
method: "GET",
route: "/foo",
forceProtectionOff: false,
rateLimiting: {
enabled: false,
maxRequests: 0,
windowSizeInMS: 0,
},
},
],
0,
[],
[],
false,
[],
[]
);

const testContext = getTestContext("/foo", "GET", "/foo");
t.same(config.getEndpoints(testContext), [
{
method: "GET",
route: "/foo",
forceProtectionOff: false,
rateLimiting: {
enabled: false,
maxRequests: 0,
windowSizeInMS: 0,
},
},
]);

t.same(testContext.cachedMatchingEndpoints, [
{
method: "GET",
route: "/foo",
forceProtectionOff: false,
rateLimiting: {
enabled: false,
maxRequests: 0,
windowSizeInMS: 0,
},
},
]);
// Cached
t.same(config.getEndpoints(testContext), [
{
method: "GET",
route: "/foo",
forceProtectionOff: false,
rateLimiting: {
enabled: false,
maxRequests: 0,
windowSizeInMS: 0,
},
},
]);

// Clears cache
updateContext(testContext, "method", "POST");
t.same(testContext.cachedMatchingEndpoints, undefined);
t.same(config.getEndpoints(testContext), []);
t.same(testContext.cachedMatchingEndpoints, []);

runWithContext(testContext, () => {
t.same(getContext()!.cachedMatchingEndpoints, []);
});

runWithContext(
{
...testContext,
route: "/bar",
},
() => {
const context = getContext();
if (!context) {
t.fail("context is undefined");
return;
}
t.same(context.cachedMatchingEndpoints, []);
t.same(config.getEndpoints(context), []);
t.same(context.cachedMatchingEndpoints, []);
}
);
});
11 changes: 9 additions & 2 deletions library/agent/ServiceConfig.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { IPMatcher } from "../helpers/ip-matcher/IPMatcher";
import { LimitedContext, matchEndpoints } from "../helpers/matchEndpoints";
import { isPrivateIP } from "../vulnerabilities/ssrf/isPrivateIP";
import { Endpoint } from "./Config";
import { Context, updateContext } from "./Context";
import { IPList } from "./api/fetchBlockedLists";

export class ServiceConfig {
Expand Down Expand Up @@ -45,8 +46,14 @@ export class ServiceConfig {
);
}

getEndpoints(context: LimitedContext) {
return matchEndpoints(context, this.nonGraphQLEndpoints);
getEndpoints(context: Context): Endpoint[] {
if (context.cachedMatchingEndpoints) {
return context.cachedMatchingEndpoints;
}
const endpoints = matchEndpoints(context, this.nonGraphQLEndpoints);
// Cache the endpoints to avoid re-running the matchEndpoints function
updateContext(context, "cachedMatchingEndpoints", endpoints);
return endpoints;
}

getGraphQLField(
Expand Down
32 changes: 17 additions & 15 deletions library/agent/applyHooks.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,19 @@ import { Hooks } from "./hooks/Hooks";
import { wrapExport } from "./hooks/wrapExport";
import { createTestAgent } from "../helpers/createTestAgent";

const context: Context = {
remoteAddress: "::1",
method: "POST",
url: "http://localhost:4000",
query: {},
headers: {},
body: undefined,
cookies: {},
routeParams: {},
source: "express",
route: "/posts/:id",
const getContext = (): Context => {
return {
remoteAddress: "::1",
method: "POST",
url: "http://localhost:4000",
query: {},
headers: {},
body: undefined,
cookies: {},
routeParams: {},
source: "express",
route: "/posts/:id",
};
};

const reportingAPI = new ReportingAPIForTesting();
Expand Down Expand Up @@ -74,7 +76,7 @@ t.test(

applyHooks(hooks);

await runWithContext(context, async () => {
await runWithContext(getContext(), async () => {
await fetch("https://app.aikido.dev");
t.same(modifyCalled, true);

Expand Down Expand Up @@ -127,7 +129,7 @@ t.test("it ignores route if force protection off is on", async (t) => {
await lookup("www.google.com");
t.same(inspectionCalls, [{ args: ["www.google.com"] }]);

await runWithContext(context, async () => {
await runWithContext(getContext(), async () => {
await lookup("www.aikido.dev");
});

Expand All @@ -138,7 +140,7 @@ t.test("it ignores route if force protection off is on", async (t) => {

await runWithContext(
{
...context,
...getContext(),
method: "GET",
route: "/route",
},
Expand Down Expand Up @@ -187,7 +189,7 @@ t.test("it does not report attack if IP is allowed", async (t) => {

const { hostname } = require("os");

await runWithContext(context, async () => {
await runWithContext(getContext(), async () => {
const name = hostname();
t.ok(typeof name === "string");
});
Expand Down
18 changes: 18 additions & 0 deletions library/agent/context/isProtectionOffForRoute.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import type { Agent } from "../Agent";
import { type Context } from "../Context";

export function isProtectionOffForRoute(
agent: Agent,
context: Readonly<Context> | undefined
): boolean {
if (!context) {
return false;
}

const matches = agent.getConfig().getEndpoints(context);
const protectionOff = matches.some(
(match) => match.forceProtectionOff === true
);

return protectionOff;
}
9 changes: 3 additions & 6 deletions library/agent/hooks/wrapExport.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import type { InterceptorResult } from "./InterceptorResult";
import type { WrapPackageInfo } from "./WrapPackageInfo";
import { wrapDefaultOrNamed } from "./wrapDefaultOrNamed";
import { onInspectionInterceptorResult } from "./onInspectionInterceptorResult";
import { isProtectionOffForRoute } from "../context/isProtectionOffForRoute";

type InspectArgsInterceptor = (
args: unknown[],
Expand Down Expand Up @@ -128,12 +129,8 @@ function inspectArgs(
pkgInfo: WrapPackageInfo,
methodName: string
) {
if (context) {
const matches = agent.getConfig().getEndpoints(context);

if (matches.find((match) => match.forceProtectionOff)) {
return;
}
if (isProtectionOffForRoute(agent, context)) {
return;
}

const start = performance.now();
Expand Down
Loading