diff --git a/library/agent/Context.ts b/library/agent/Context.ts index f324feb8d..0f7acf3d1 100644 --- a/library/agent/Context.ts +++ b/library/agent/Context.ts @@ -3,6 +3,7 @@ import { extractStringsFromUserInput } from "../helpers/extractStringsFromUserIn import { ContextStorage } from "./context/ContextStorage"; import { AsyncResource } from "async_hooks"; import { Source, SOURCES } from "./Source"; +import type { Endpoint } from "./Config"; export type User = { id: string; name?: string }; @@ -30,6 +31,7 @@ export type Context = { */ outgoingRequestRedirects?: { source: URL; destination: URL }[]; executedMiddleware?: boolean; + cachedMatchingEndpoints?: Endpoint[]; }; /** @@ -56,6 +58,11 @@ export function updateContext( if (context.cache && isSourceKey(key)) { context.cache.delete(key); } + + // Delete cacheMatchingEndpoints if method, url, or route changes + if (key === "method" || key === "url" || key === "route") { + delete context.cachedMatchingEndpoints; + } } /** @@ -69,8 +76,17 @@ export function runWithContext(context: Context, fn: () => T) { const current = ContextStorage.getStore(); // If there is already a context, we just update it - // In this way we don't lose the `attackDetected` flag + // In this way we don't lose the `attackDetected` flag or other properties like `markUnsafe` or `xml` if (current) { + if ( + current.method !== context.method || + current.url !== context.url || + current.route !== context.route + ) { + // If the method, url, or route changes, we need to delete the cacheMatchingEndpoints + delete current.cachedMatchingEndpoints; + } + current.url = context.url; current.method = context.method; current.query = context.query; diff --git a/library/agent/ServiceConfig.test.ts b/library/agent/ServiceConfig.test.ts index ef291e4a9..06ecb2296 100644 --- a/library/agent/ServiceConfig.test.ts +++ b/library/agent/ServiceConfig.test.ts @@ -1,5 +1,28 @@ import * as t from "tap"; import { ServiceConfig } from "./ServiceConfig"; +import { + getContext, + runWithContext, + updateContext, + type Context, +} from "./Context"; + +const getTestContext = ( + url: string | undefined, + method: string | undefined, + route: string | undefined +): Context => ({ + url, + method, + route, + query: {}, + headers: {}, + routeParams: {}, + remoteAddress: undefined, + body: undefined, + cookies: {}, + source: "http", +}); t.test("it returns false if empty rules", async () => { const config = new ServiceConfig([], 0, [], [], false, [], []); @@ -7,11 +30,7 @@ t.test("it returns false if empty rules", async () => { t.same(config.isUserBlocked("id"), false); t.same(config.isBypassedIP("1.2.3.4"), false); t.same( - config.getEndpoints({ - url: undefined, - method: undefined, - route: undefined, - }), + config.getEndpoints(getTestContext(undefined, undefined, undefined)), [] ); }); @@ -60,25 +79,18 @@ t.test("it works", async () => { t.same(config.isUserBlocked("123"), true); t.same(config.isUserBlocked("567"), false); - t.same( - config.getEndpoints({ - url: undefined, + t.same(config.getEndpoints(getTestContext("/foo", "GET", "/foo")), [ + { method: "GET", route: "/foo", - }), - [ - { - method: "GET", - route: "/foo", - forceProtectionOff: false, - rateLimiting: { - enabled: false, - maxRequests: 0, - windowSizeInMS: 0, - }, + forceProtectionOff: false, + rateLimiting: { + enabled: false, + maxRequests: 0, + windowSizeInMS: 0, }, - ] - ); + }, + ]); }); t.test("it checks if IP is bypassed", async () => { @@ -254,3 +266,93 @@ t.test("bypassed ips support cidr", async () => { t.same(config.isBypassedIP("123.123.123.1"), false); t.same(config.isBypassedIP("999.999.999.999"), false); }); + +t.test("matching endpoints are cached", async () => { + const config = new ServiceConfig( + [ + { + method: "GET", + route: "/foo", + forceProtectionOff: false, + rateLimiting: { + enabled: false, + maxRequests: 0, + windowSizeInMS: 0, + }, + }, + ], + 0, + [], + [], + false, + [], + [] + ); + + const testContext = getTestContext("/foo", "GET", "/foo"); + t.same(config.getEndpoints(testContext), [ + { + method: "GET", + route: "/foo", + forceProtectionOff: false, + rateLimiting: { + enabled: false, + maxRequests: 0, + windowSizeInMS: 0, + }, + }, + ]); + + t.same(testContext.cachedMatchingEndpoints, [ + { + method: "GET", + route: "/foo", + forceProtectionOff: false, + rateLimiting: { + enabled: false, + maxRequests: 0, + windowSizeInMS: 0, + }, + }, + ]); + // Cached + t.same(config.getEndpoints(testContext), [ + { + method: "GET", + route: "/foo", + forceProtectionOff: false, + rateLimiting: { + enabled: false, + maxRequests: 0, + windowSizeInMS: 0, + }, + }, + ]); + + // Clears cache + updateContext(testContext, "method", "POST"); + t.same(testContext.cachedMatchingEndpoints, undefined); + t.same(config.getEndpoints(testContext), []); + t.same(testContext.cachedMatchingEndpoints, []); + + runWithContext(testContext, () => { + t.same(getContext()!.cachedMatchingEndpoints, []); + }); + + runWithContext( + { + ...testContext, + route: "/bar", + }, + () => { + const context = getContext(); + if (!context) { + t.fail("context is undefined"); + return; + } + t.same(context.cachedMatchingEndpoints, []); + t.same(config.getEndpoints(context), []); + t.same(context.cachedMatchingEndpoints, []); + } + ); +}); diff --git a/library/agent/ServiceConfig.ts b/library/agent/ServiceConfig.ts index 40b86381f..b9699f21f 100644 --- a/library/agent/ServiceConfig.ts +++ b/library/agent/ServiceConfig.ts @@ -2,6 +2,7 @@ import { IPMatcher } from "../helpers/ip-matcher/IPMatcher"; import { LimitedContext, matchEndpoints } from "../helpers/matchEndpoints"; import { isPrivateIP } from "../vulnerabilities/ssrf/isPrivateIP"; import { Endpoint } from "./Config"; +import { Context, updateContext } from "./Context"; import { IPList } from "./api/fetchBlockedLists"; export class ServiceConfig { @@ -45,8 +46,14 @@ export class ServiceConfig { ); } - getEndpoints(context: LimitedContext) { - return matchEndpoints(context, this.nonGraphQLEndpoints); + getEndpoints(context: Context): Endpoint[] { + if (context.cachedMatchingEndpoints) { + return context.cachedMatchingEndpoints; + } + const endpoints = matchEndpoints(context, this.nonGraphQLEndpoints); + // Cache the endpoints to avoid re-running the matchEndpoints function + updateContext(context, "cachedMatchingEndpoints", endpoints); + return endpoints; } getGraphQLField( diff --git a/library/agent/applyHooks.test.ts b/library/agent/applyHooks.test.ts index 5c3ca692a..ebe4f4ede 100644 --- a/library/agent/applyHooks.test.ts +++ b/library/agent/applyHooks.test.ts @@ -7,17 +7,19 @@ import { Hooks } from "./hooks/Hooks"; import { wrapExport } from "./hooks/wrapExport"; import { createTestAgent } from "../helpers/createTestAgent"; -const context: Context = { - remoteAddress: "::1", - method: "POST", - url: "http://localhost:4000", - query: {}, - headers: {}, - body: undefined, - cookies: {}, - routeParams: {}, - source: "express", - route: "/posts/:id", +const getContext = (): Context => { + return { + remoteAddress: "::1", + method: "POST", + url: "http://localhost:4000", + query: {}, + headers: {}, + body: undefined, + cookies: {}, + routeParams: {}, + source: "express", + route: "/posts/:id", + }; }; const reportingAPI = new ReportingAPIForTesting(); @@ -74,7 +76,7 @@ t.test( applyHooks(hooks); - await runWithContext(context, async () => { + await runWithContext(getContext(), async () => { await fetch("https://app.aikido.dev"); t.same(modifyCalled, true); @@ -127,7 +129,7 @@ t.test("it ignores route if force protection off is on", async (t) => { await lookup("www.google.com"); t.same(inspectionCalls, [{ args: ["www.google.com"] }]); - await runWithContext(context, async () => { + await runWithContext(getContext(), async () => { await lookup("www.aikido.dev"); }); @@ -138,7 +140,7 @@ t.test("it ignores route if force protection off is on", async (t) => { await runWithContext( { - ...context, + ...getContext(), method: "GET", route: "/route", }, @@ -187,7 +189,7 @@ t.test("it does not report attack if IP is allowed", async (t) => { const { hostname } = require("os"); - await runWithContext(context, async () => { + await runWithContext(getContext(), async () => { const name = hostname(); t.ok(typeof name === "string"); }); diff --git a/library/agent/context/isProtectionOffForRoute.ts b/library/agent/context/isProtectionOffForRoute.ts new file mode 100644 index 000000000..28189bb69 --- /dev/null +++ b/library/agent/context/isProtectionOffForRoute.ts @@ -0,0 +1,18 @@ +import type { Agent } from "../Agent"; +import { type Context } from "../Context"; + +export function isProtectionOffForRoute( + agent: Agent, + context: Readonly | undefined +): boolean { + if (!context) { + return false; + } + + const matches = agent.getConfig().getEndpoints(context); + const protectionOff = matches.some( + (match) => match.forceProtectionOff === true + ); + + return protectionOff; +} diff --git a/library/agent/hooks/wrapExport.ts b/library/agent/hooks/wrapExport.ts index 49cc8f28e..ccdc5b85d 100644 --- a/library/agent/hooks/wrapExport.ts +++ b/library/agent/hooks/wrapExport.ts @@ -6,6 +6,7 @@ import type { InterceptorResult } from "./InterceptorResult"; import type { WrapPackageInfo } from "./WrapPackageInfo"; import { wrapDefaultOrNamed } from "./wrapDefaultOrNamed"; import { onInspectionInterceptorResult } from "./onInspectionInterceptorResult"; +import { isProtectionOffForRoute } from "../context/isProtectionOffForRoute"; type InspectArgsInterceptor = ( args: unknown[], @@ -128,12 +129,8 @@ function inspectArgs( pkgInfo: WrapPackageInfo, methodName: string ) { - if (context) { - const matches = agent.getConfig().getEndpoints(context); - - if (matches.find((match) => match.forceProtectionOff)) { - return; - } + if (isProtectionOffForRoute(agent, context)) { + return; } const start = performance.now(); diff --git a/library/ratelimiting/getRateLimitedEndpoint.test.ts b/library/ratelimiting/getRateLimitedEndpoint.test.ts index fd85762a4..b492d6d45 100644 --- a/library/ratelimiting/getRateLimitedEndpoint.test.ts +++ b/library/ratelimiting/getRateLimitedEndpoint.test.ts @@ -3,7 +3,7 @@ import type { Context } from "../agent/Context"; import { ServiceConfig } from "../agent/ServiceConfig"; import { getRateLimitedEndpoint } from "./getRateLimitedEndpoint"; -const context: Context = { +const getContext = (): Context => ({ remoteAddress: "1.2.3.4", method: "POST", url: "https://acme.com/api/login", @@ -14,12 +14,12 @@ const context: Context = { routeParams: {}, source: "express", route: "/api/login", -}; +}); t.test("it returns undefined if no endpoints", async () => { t.same( getRateLimitedEndpoint( - context, + getContext(), new ServiceConfig([], 0, [], [], true, [], []) ), undefined @@ -29,7 +29,7 @@ t.test("it returns undefined if no endpoints", async () => { t.test("it returns undefined if no matching endpoints", async () => { t.same( getRateLimitedEndpoint( - context, + getContext(), new ServiceConfig( [ { @@ -59,7 +59,7 @@ t.test("it returns undefined if no matching endpoints", async () => { t.test("it returns undefined if matching but not enabled", async () => { t.same( getRateLimitedEndpoint( - context, + getContext(), new ServiceConfig( [ { @@ -89,7 +89,7 @@ t.test("it returns undefined if matching but not enabled", async () => { t.test("it returns endpoint if matching and enabled", async () => { t.same( getRateLimitedEndpoint( - context, + getContext(), new ServiceConfig( [ { @@ -129,7 +129,7 @@ t.test("it returns endpoint if matching and enabled", async () => { t.test("it returns endpoint with lowest max requests", async () => { t.same( getRateLimitedEndpoint( - context, + getContext(), new ServiceConfig( [ { @@ -180,7 +180,7 @@ t.test("it returns endpoint with lowest max requests", async () => { t.test("it returns endpoint with smallest window size", async () => { t.same( getRateLimitedEndpoint( - context, + getContext(), new ServiceConfig( [ { @@ -231,7 +231,7 @@ t.test("it returns endpoint with smallest window size", async () => { t.test("it always returns exact matches first", async () => { t.same( getRateLimitedEndpoint( - context, + getContext(), new ServiceConfig( [ { diff --git a/library/vulnerabilities/ssrf/inspectDNSLookupCalls.test.ts b/library/vulnerabilities/ssrf/inspectDNSLookupCalls.test.ts index 0efa34664..d147a0357 100644 --- a/library/vulnerabilities/ssrf/inspectDNSLookupCalls.test.ts +++ b/library/vulnerabilities/ssrf/inspectDNSLookupCalls.test.ts @@ -15,19 +15,21 @@ wrap(console, "log", function log() { }; }); -const context: Context = { - remoteAddress: "::1", - method: "POST", - url: "http://localhost:4000", - query: {}, - headers: {}, - body: { - image: "http://localhost", - }, - cookies: {}, - routeParams: {}, - source: "express", - route: "/posts/:id", +const getContext = (): Context => { + return { + remoteAddress: "::1", + method: "POST", + url: "http://localhost:4000", + query: {}, + headers: {}, + body: { + image: "http://localhost", + }, + cookies: {}, + routeParams: {}, + source: "express", + route: "/posts/:id", + }; }; t.test("it resolves private IPv4 without context", (t) => { @@ -86,7 +88,7 @@ t.test("it blocks lookup in blocking mode", (t) => { "operation" ); - runWithContext(context, () => { + runWithContext(getContext(), () => { wrappedLookup("localhost", {}, (err, address) => { t.same(err instanceof Error, true); if (err instanceof Error) { @@ -130,7 +132,7 @@ t.test("it allows resolved public IP", (t) => { ); runWithContext( - { ...context, body: { image: "http://www.google.be" } }, + { ...getContext(), body: { image: "http://www.google.be" } }, () => { wrappedLookup("www.google.be", {}, (err, address) => { t.same(err, null); @@ -160,7 +162,7 @@ t.test( "operation" ); - runWithContext({ ...context, body: undefined }, () => { + runWithContext({ ...getContext(), body: undefined }, () => { wrappedLookup("localhost", {}, (err, address) => { t.same(err, null); t.same(address, getMajorNodeVersion() === 16 ? "127.0.0.1" : "::1"); @@ -211,7 +213,7 @@ t.test( ); await new Promise((resolve) => { - runWithContext(context, () => { + runWithContext(getContext(), () => { wrappedLookup("localhost", {}, (err, address) => { t.same(err, null); t.same(address, getMajorNodeVersion() === 16 ? "127.0.0.1" : "::1"); @@ -236,7 +238,7 @@ t.test("it blocks lookup in blocking mode with all option", (t) => { "operation" ); - runWithContext(context, () => { + runWithContext(getContext(), () => { wrappedLookup("localhost", { all: true }, (err, address) => { t.same(err instanceof Error, true); if (err instanceof Error) { @@ -268,7 +270,7 @@ t.test("it does not block in dry mode", (t) => { "operation" ); - runWithContext(context, () => { + runWithContext(getContext(), () => { wrappedLookup("localhost", {}, (err, address) => { t.same(err, null); t.same(address, getMajorNodeVersion() === 16 ? "127.0.0.1" : "::1"); @@ -378,7 +380,7 @@ t.test("Blocks IMDS SSRF with untrusted domain", async (t) => { }); }), new Promise((resolve) => { - runWithContext(context, () => { + runWithContext(getContext(), () => { wrappedLookup("imds.test.com", { family: 4 }, (err, address) => { t.same(err instanceof Error, true); if (err instanceof Error) { @@ -433,7 +435,7 @@ t.test( "operation" ); - runWithContext(context, () => { + runWithContext(getContext(), () => { wrappedLookup("imds.test.com", { family: 4 }, (err, address) => { t.same(err, null); t.same(address, "169.254.169.254"); @@ -469,7 +471,7 @@ t.test("Does not block IMDS SSRF with Google metadata domain", async (t) => { ); }), new Promise((resolve) => { - runWithContext(context, () => { + runWithContext(getContext(), () => { wrappedLookup( "metadata.google.internal", { family: 4 }, @@ -500,7 +502,7 @@ t.test("it ignores when the argument is an IP address", async (t) => { await Promise.all([ new Promise((resolve) => { runWithContext( - { ...context, routeParams: { id: "169.254.169.254" } }, + { ...getContext(), routeParams: { id: "169.254.169.254" } }, () => { wrappedLookup("169.254.169.254", {}, (err, address) => { t.same(err, null); @@ -512,7 +514,7 @@ t.test("it ignores when the argument is an IP address", async (t) => { }), new Promise((resolve) => { runWithContext( - { ...context, routeParams: { id: "fd00:ec2::254" } }, + { ...getContext(), routeParams: { id: "fd00:ec2::254" } }, () => { wrappedLookup("fd00:ec2::254", {}, (err, address) => { t.same(err, null); diff --git a/library/vulnerabilities/ssrf/inspectDNSLookupCalls.ts b/library/vulnerabilities/ssrf/inspectDNSLookupCalls.ts index 4731902a4..03a4a6dc1 100644 --- a/library/vulnerabilities/ssrf/inspectDNSLookupCalls.ts +++ b/library/vulnerabilities/ssrf/inspectDNSLookupCalls.ts @@ -15,6 +15,7 @@ import { getRedirectOrigin } from "./getRedirectOrigin"; import { getPortFromURL } from "../../helpers/getPortFromURL"; import { getLibraryRoot } from "../../helpers/getLibraryRoot"; import { cleanError } from "../../helpers/cleanError"; +import { isProtectionOffForRoute } from "../../agent/context/isProtectionOffForRoute"; export function inspectDNSLookupCalls( lookup: Function, @@ -91,14 +92,10 @@ function wrapDNSLookupCallback( const context = getContext(); - if (context) { - const matches = agent.getConfig().getEndpoints(context); - - if (matches.find((endpoint) => endpoint.forceProtectionOff)) { - // User disabled protection for this endpoint, we don't need to inspect the resolved IPs - // Just call the original callback to allow the DNS lookup - return callback(err, addresses, family); - } + if (isProtectionOffForRoute(agent, context)) { + // User disabled protection for this endpoint, we don't need to inspect the resolved IPs + // Just call the original callback to allow the DNS lookup + return callback(err, addresses, family); } const resolvedIPAddresses = getResolvedIPAddresses(addresses);