diff --git a/grpc-ratelimiter-utils/build.gradle.kts b/grpc-ratelimiter-utils/build.gradle.kts new file mode 100644 index 0000000..2aefe8f --- /dev/null +++ b/grpc-ratelimiter-utils/build.gradle.kts @@ -0,0 +1,28 @@ +plugins { + `java-library` + jacoco + id("org.hypertrace.publish-plugin") + id("org.hypertrace.jacoco-report-plugin") +} + +dependencies { + + api(platform("io.grpc:grpc-bom:1.68.3")) + api("io.grpc:grpc-api") + api(project(":grpc-context-utils")) + + implementation("org.slf4j:slf4j-api:1.7.36") + implementation("com.google.guava:guava:32.0.1-jre") + implementation("com.bucket4j:bucket4j-core:8.7.0") + + annotationProcessor("org.projectlombok:lombok:1.18.24") + compileOnly("org.projectlombok:lombok:1.18.24") + + testImplementation("org.junit.jupiter:junit-jupiter:5.8.2") + testImplementation("org.mockito:mockito-core:5.8.0") + testImplementation("org.mockito:mockito-junit-jupiter:5.8.0") +} + +tasks.test { + useJUnitPlatform() +} diff --git a/grpc-ratelimiter-utils/src/main/java/org/hypertrace/ratelimiter/grpcutils/RateLimiter.java b/grpc-ratelimiter-utils/src/main/java/org/hypertrace/ratelimiter/grpcutils/RateLimiter.java new file mode 100644 index 0000000..e8d5a7b --- /dev/null +++ b/grpc-ratelimiter-utils/src/main/java/org/hypertrace/ratelimiter/grpcutils/RateLimiter.java @@ -0,0 +1,10 @@ +package org.hypertrace.ratelimiter.grpcutils; + +public interface RateLimiter { + default boolean tryAcquire(String key, RateLimiterConfiguration.RateLimit rateLimit) { + return tryAcquire(key, 1, rateLimit); + } // default single token + + boolean tryAcquire( + String key, int permits, RateLimiterConfiguration.RateLimit rateLimit); // new: batch tokens +} diff --git a/grpc-ratelimiter-utils/src/main/java/org/hypertrace/ratelimiter/grpcutils/RateLimiterConfiguration.java b/grpc-ratelimiter-utils/src/main/java/org/hypertrace/ratelimiter/grpcutils/RateLimiterConfiguration.java new file mode 100644 index 0000000..3f120b9 --- /dev/null +++ b/grpc-ratelimiter-utils/src/main/java/org/hypertrace/ratelimiter/grpcutils/RateLimiterConfiguration.java @@ -0,0 +1,30 @@ +package org.hypertrace.ratelimiter.grpcutils; + +import java.util.Map; +import java.util.function.BiFunction; +import lombok.Builder; +import lombok.Value; +import org.hypertrace.core.grpcutils.context.RequestContext; + +@Value +@Builder +public class RateLimiterConfiguration { + boolean enabled; + String method; + // Attributes to match like tenant_id -> traceable + Map matchAttributes; + + // Extract attributes from gRPC request + BiFunction> attributeExtractor; + + // Token cost evaluator (can be static 1 or dynamic based on message) + @Builder.Default BiFunction tokenCostFunction = (ctx, req) -> 1; + RateLimit rateLimit; + + @Value + @Builder + public static class RateLimit { + int tokens; + int refreshPeriodSeconds; + } +} diff --git a/grpc-ratelimiter-utils/src/main/java/org/hypertrace/ratelimiter/grpcutils/RateLimiterFactory.java b/grpc-ratelimiter-utils/src/main/java/org/hypertrace/ratelimiter/grpcutils/RateLimiterFactory.java new file mode 100644 index 0000000..983763d --- /dev/null +++ b/grpc-ratelimiter-utils/src/main/java/org/hypertrace/ratelimiter/grpcutils/RateLimiterFactory.java @@ -0,0 +1,5 @@ +package org.hypertrace.ratelimiter.grpcutils; + +public interface RateLimiterFactory { + RateLimiter getRateLimiter(RateLimiterConfiguration rateLimiterConfiguration); +} diff --git a/grpc-ratelimiter-utils/src/main/java/org/hypertrace/ratelimiter/grpcutils/RateLimiterFactoryProvider.java b/grpc-ratelimiter-utils/src/main/java/org/hypertrace/ratelimiter/grpcutils/RateLimiterFactoryProvider.java new file mode 100644 index 0000000..76e5853 --- /dev/null +++ b/grpc-ratelimiter-utils/src/main/java/org/hypertrace/ratelimiter/grpcutils/RateLimiterFactoryProvider.java @@ -0,0 +1,14 @@ +package org.hypertrace.ratelimiter.grpcutils; + +import org.hypertrace.ratelimiter.grpcutils.bucket4j.Bucket4jRateLimiterFactory; + +public final class RateLimiterFactoryProvider { + + private RateLimiterFactoryProvider() { + // Prevent instantiation + } + + public static Bucket4jRateLimiterFactory bucket4j() { + return new Bucket4jRateLimiterFactory(); + } +} diff --git a/grpc-ratelimiter-utils/src/main/java/org/hypertrace/ratelimiter/grpcutils/RateLimiterInterceptor.java b/grpc-ratelimiter-utils/src/main/java/org/hypertrace/ratelimiter/grpcutils/RateLimiterInterceptor.java new file mode 100644 index 0000000..1e864f0 --- /dev/null +++ b/grpc-ratelimiter-utils/src/main/java/org/hypertrace/ratelimiter/grpcutils/RateLimiterInterceptor.java @@ -0,0 +1,74 @@ +package org.hypertrace.ratelimiter.grpcutils; + +import io.grpc.ForwardingServerCallListener; +import io.grpc.Metadata; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import io.grpc.Status; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.stream.Collectors; +import org.hypertrace.core.grpcutils.context.RequestContext; + +public class RateLimiterInterceptor implements ServerInterceptor { + + private final List + rateLimitConfigs; // Provided via config or dynamic update + private final RateLimiterFactory rateLimiterFactory; + + public RateLimiterInterceptor( + List rateLimitConfigs, RateLimiterFactory factory) { + this.rateLimitConfigs = rateLimitConfigs; + this.rateLimiterFactory = factory; + } + + @Override + public ServerCall.Listener interceptCall( + ServerCall call, Metadata headers, ServerCallHandler next) { + + String method = call.getMethodDescriptor().getFullMethodName(); + + return new ForwardingServerCallListener.SimpleForwardingServerCallListener<>( + next.startCall(call, headers)) { + @Override + public void onMessage(ReqT message) { + RequestContext requestContext = RequestContext.fromMetadata(headers); + for (RateLimiterConfiguration config : rateLimitConfigs) { + if (!config.getMethod().equals(method)) continue; + + Map attributes = + config.getAttributeExtractor().apply(requestContext, message); + + if (!matches(config.getMatchAttributes(), attributes)) continue; + int tokens = config.getTokenCostFunction().apply(requestContext, message); + String key = buildRateLimitKey(method, config.getMatchAttributes(), attributes); + boolean allowed = + rateLimiterFactory + .getRateLimiter(config) + .tryAcquire(key, tokens, config.getRateLimit()); + if (!allowed) { + call.close(Status.RESOURCE_EXHAUSTED.withDescription("Rate limit exceeded"), headers); + return; + } + } + super.onMessage(message); + } + }; + } + + private boolean matches(Map match, Map actual) { + return match.entrySet().stream() + .allMatch(e -> Objects.equals(actual.get(e.getKey()), e.getValue())); + } + + private String buildRateLimitKey( + String method, Map keys, Map attrs) { + return method + + "::" + + keys.keySet().stream() + .map(k -> attrs.getOrDefault(k, "null")) + .collect(Collectors.joining(":")); + } +} diff --git a/grpc-ratelimiter-utils/src/main/java/org/hypertrace/ratelimiter/grpcutils/bucket4j/Bucket4jRateLimiterFactory.java b/grpc-ratelimiter-utils/src/main/java/org/hypertrace/ratelimiter/grpcutils/bucket4j/Bucket4jRateLimiterFactory.java new file mode 100644 index 0000000..6ed61dd --- /dev/null +++ b/grpc-ratelimiter-utils/src/main/java/org/hypertrace/ratelimiter/grpcutils/bucket4j/Bucket4jRateLimiterFactory.java @@ -0,0 +1,43 @@ +package org.hypertrace.ratelimiter.grpcutils.bucket4j; + +import com.google.common.cache.Cache; +import com.google.common.cache.CacheBuilder; +import io.github.bucket4j.Bucket; +import io.grpc.Status; +import java.time.Duration; +import java.util.concurrent.ExecutionException; +import org.hypertrace.ratelimiter.grpcutils.RateLimiter; +import org.hypertrace.ratelimiter.grpcutils.RateLimiterConfiguration; +import org.hypertrace.ratelimiter.grpcutils.RateLimiterFactory; + +public class Bucket4jRateLimiterFactory implements RateLimiterFactory { + + private final Cache limiterCache = + CacheBuilder.newBuilder().maximumSize(10_000).build(); + + @Override + public RateLimiter getRateLimiter(RateLimiterConfiguration rule) { + return (key, tokens, limit) -> { + try { + Bucket bucket = limiterCache.get(key, () -> createBucket(limit)); + return bucket.tryConsume(tokens); + } catch (ExecutionException e) { + throw Status.INTERNAL + .withDescription("Failed to create rate limiter bucket for key: " + key) + .withCause(e) + .asRuntimeException(); + } + }; + } + + private Bucket createBucket(RateLimiterConfiguration.RateLimit limit) { + return Bucket.builder() + .addLimit( + bandwidth -> + bandwidth + .capacity(limit.getTokens()) + .refillGreedy( + limit.getTokens(), Duration.ofSeconds(limit.getRefreshPeriodSeconds()))) + .build(); + } +} diff --git a/grpc-ratelimiter-utils/src/test/java/org/hypertrace/ratelimiter/grpcutils/RateLimiterConfigurationTest.java b/grpc-ratelimiter-utils/src/test/java/org/hypertrace/ratelimiter/grpcutils/RateLimiterConfigurationTest.java new file mode 100644 index 0000000..96a8a2f --- /dev/null +++ b/grpc-ratelimiter-utils/src/test/java/org/hypertrace/ratelimiter/grpcutils/RateLimiterConfigurationTest.java @@ -0,0 +1,147 @@ +package org.hypertrace.ratelimiter.grpcutils; + +import static org.junit.jupiter.api.Assertions.*; + +import java.util.Map; +import java.util.function.BiFunction; +import org.hypertrace.core.grpcutils.context.RequestContext; +import org.junit.jupiter.api.Test; + +class RateLimiterConfigurationTest { + + @Test + void testDefaultValues() { + // Build a minimal RateLimiterConfiguration using only required fields + RateLimiterConfiguration configuration = + RateLimiterConfiguration.builder() + .method("testMethod") + .rateLimit( + RateLimiterConfiguration.RateLimit.builder() + .tokens(10) + .refreshPeriodSeconds(30) + .build()) + .build(); + + // Assertions for default values + assertFalse(configuration.isEnabled()); // `enabled` defaults to true + assertEquals("testMethod", configuration.getMethod()); + assertNull(configuration.getMatchAttributes()); // `matchAttributes` defaults to null + assertNotNull(configuration.getTokenCostFunction()); // Ensure tokenCostFunction is initialized + assertEquals( + 1, configuration.getTokenCostFunction().apply(null, null)); // Default token cost value + } + + @Test + void testCustomMatchAttributes() { + // Create matchAttributes map and build configuration + Map matchAttributes = Map.of("tenant_id", "traceable"); + RateLimiterConfiguration configuration = + RateLimiterConfiguration.builder() + .method("testMethod") + .matchAttributes(matchAttributes) + .rateLimit( + RateLimiterConfiguration.RateLimit.builder() + .tokens(100) + .refreshPeriodSeconds(60) + .build()) + .build(); + + // Verify matchAttributes are correctly set + assertEquals(matchAttributes, configuration.getMatchAttributes()); + } + + @Test + void testCustomTokenCostFunction() { + // Define a custom tokenCostFunction + BiFunction customTokenCostFunction = + (ctx, req) -> (req instanceof Integer) ? (Integer) req : 5; + + // Build a configuration using the custom token cost function + RateLimiterConfiguration configuration = + RateLimiterConfiguration.builder() + .method("computeTokenCost") + .tokenCostFunction(customTokenCostFunction) + .rateLimit( + RateLimiterConfiguration.RateLimit.builder() + .tokens(50) + .refreshPeriodSeconds(15) + .build()) + .build(); + + // Verify behavior of the custom token cost function + assertEquals(10, configuration.getTokenCostFunction().apply(null, 10)); // Dynamic cost + assertEquals( + 5, configuration.getTokenCostFunction().apply(null, "randomObject")); // Default cost + } + + @Test + void testAttributeExtractor() { + // Define an attributeExtractor that extracts specific attributes from the request + BiFunction> customAttributeExtractor = + (ctx, req) -> Map.of("attributeKey", "attributeValue"); + + // Build the configuration with the custom attribute extractor + RateLimiterConfiguration configuration = + RateLimiterConfiguration.builder() + .method("extractAttributes") + .attributeExtractor(customAttributeExtractor) + .rateLimit( + RateLimiterConfiguration.RateLimit.builder() + .tokens(20) + .refreshPeriodSeconds(45) + .build()) + .build(); + + // Verify the custom attribute extractor + assertEquals( + Map.of("attributeKey", "attributeValue"), + configuration.getAttributeExtractor().apply(null, null)); + } + + @Test + void testRateLimitConfiguration() { + // Build a simple RateLimit configuration + RateLimiterConfiguration.RateLimit rateLimit = + RateLimiterConfiguration.RateLimit.builder().tokens(500).refreshPeriodSeconds(300).build(); + + // Verify RateLimit configuration values + assertEquals(500, rateLimit.getTokens()); + assertEquals(300, rateLimit.getRefreshPeriodSeconds()); + } + + @Test + void testFullCustomConfiguration() { + // Define custom token cost function and attribute extractor + BiFunction tokenCostFunction = (ctx, req) -> 2; + BiFunction> attributeExtractor = + (ctx, req) -> Map.of("tenant_id", "12345"); + + // Build a complete custom configuration + RateLimiterConfiguration configuration = + RateLimiterConfiguration.builder() + .method("fullCustomMethod") + .enabled(false) + .matchAttributes(Map.of("region", "us-west")) + .attributeExtractor(attributeExtractor) + .tokenCostFunction(tokenCostFunction) + .rateLimit( + RateLimiterConfiguration.RateLimit.builder() + .tokens(1000) + .refreshPeriodSeconds(60) + .build()) + .build(); + + // Verify all custom configurations + assertFalse(configuration.isEnabled()); // Verify `enabled` value + assertEquals("fullCustomMethod", configuration.getMethod()); // Verify method + assertEquals( + Map.of("region", "us-west"), configuration.getMatchAttributes()); // Verify matchAttributes + assertEquals( + Map.of("tenant_id", "12345"), + configuration.getAttributeExtractor().apply(null, null)); // Custom extractor + assertEquals(2, configuration.getTokenCostFunction().apply(null, null)); // Custom token cost + assertEquals(1000, configuration.getRateLimit().getTokens()); // RateLimit tokens + assertEquals( + 60, configuration.getRateLimit().getRefreshPeriodSeconds()); // RateLimit refresh period + } +} diff --git a/grpc-ratelimiter-utils/src/test/java/org/hypertrace/ratelimiter/grpcutils/RateLimiterInterceptorTest.java b/grpc-ratelimiter-utils/src/test/java/org/hypertrace/ratelimiter/grpcutils/RateLimiterInterceptorTest.java new file mode 100644 index 0000000..1fbde64 --- /dev/null +++ b/grpc-ratelimiter-utils/src/test/java/org/hypertrace/ratelimiter/grpcutils/RateLimiterInterceptorTest.java @@ -0,0 +1,171 @@ +package org.hypertrace.ratelimiter.grpcutils; + +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.mockito.Mockito.*; + +import io.grpc.*; +import java.util.List; +import java.util.Map; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.mockito.junit.jupiter.MockitoExtension; + +@ExtendWith(MockitoExtension.class) +class RateLimiterInterceptorTest { + + @Mock private ServerCall mockServerCall; + @Mock private Metadata mockMetadata; + @Mock private ServerCallHandler mockHandler; + @Mock private ServerCall.Listener mockListener; + @Mock private RateLimiter mockRateLimiter; + @Mock private RateLimiterFactory mockRateLimiterFactory; + + private static final String TEST_METHOD = "org.example.TestService/TestMethod"; + + private RateLimiterInterceptor rateLimiterInterceptor; + + @BeforeEach + void setUp() { + MockitoAnnotations.openMocks(this); + + // Mock the ServerCallHandler to return the mock listener + when(mockHandler.startCall(any(), any())).thenReturn(mockListener); + + // Mock the ServerCall's method descriptor + when(mockServerCall.getMethodDescriptor()) + .thenReturn( + MethodDescriptor.newBuilder() + .setFullMethodName(TEST_METHOD) + .setType(MethodDescriptor.MethodType.UNARY) + .setRequestMarshaller(mock(MethodDescriptor.Marshaller.class)) + .setResponseMarshaller(mock(MethodDescriptor.Marshaller.class)) + .build()); + + // Initialize RateLimiterInterceptor with a sample configuration + rateLimiterInterceptor = + new RateLimiterInterceptor( + List.of( + RateLimiterConfiguration.builder() + .method(TEST_METHOD) + .matchAttributes(Map.of("tenant_id", "123")) // Match attributes for this test + .attributeExtractor( + (ctx, req) -> Map.of("tenant_id", "123")) // Extracting mock attributes + .tokenCostFunction((ctx, req) -> 1) // Token cost per request + .rateLimit( + RateLimiterConfiguration.RateLimit.builder() + .tokens(10) // Bucket size + .refreshPeriodSeconds(60) // Bucket refresh period + .build()) + .build()), + mockRateLimiterFactory); + } + + @Test + void testRequestAllowedWhenRateLimitNotExceeded() { + when(mockRateLimiterFactory.getRateLimiter(any())).thenReturn(mockRateLimiter); + // Mock rate limiter to allow the request + when(mockRateLimiter.tryAcquire(anyString(), anyInt(), any())).thenReturn(true); + + // Intercept the call + ServerCall.Listener listener = + rateLimiterInterceptor.interceptCall(mockServerCall, mockMetadata, mockHandler); + + // Ensure listener isn't null + assertNotNull(listener); + + // Call onMessage to assert normal behavior + listener.onMessage("testMessage"); + + // Verify server call is not closed with error + verify(mockServerCall, never()).close(any(Status.class), any(Metadata.class)); + + // Verify message is passed downstream + verify(mockListener, times(1)).onMessage("testMessage"); + } + + @Test + void testRequestRejectedWhenRateLimitExceeded() { + when(mockRateLimiterFactory.getRateLimiter(any())).thenReturn(mockRateLimiter); + // Mock rate limiter to reject the request + when(mockRateLimiter.tryAcquire(anyString(), anyInt(), any())).thenReturn(false); + + // Intercept the call + ServerCall.Listener listener = + rateLimiterInterceptor.interceptCall(mockServerCall, mockMetadata, mockHandler); + + // Ensure listener isn't null + assertNotNull(listener); + + // Call onMessage to trigger the rate limiting logic + listener.onMessage("testMessage"); + + verify(mockServerCall).getMethodDescriptor(); + + // Verify server call is closed with RESOURCE_EXHAUSTED error + verify(mockServerCall, times(1)) + .close( + argThat( + status -> + status.getCode() == Status.RESOURCE_EXHAUSTED.getCode() + && "Rate limit exceeded".equals(status.getDescription())), + eq(mockMetadata)); + + // Verify the mock listener's onMessage is not called + verify(mockListener, never()).onMessage(any()); + } + + @Test + void testInterceptorSkipsIfNoMatchingRateLimitConfig() { + // Mock a different method name that doesn't match our rate limiter configuration + when(mockServerCall.getMethodDescriptor()) + .thenReturn( + MethodDescriptor.newBuilder() + .setFullMethodName("org.example.OtherService/OtherMethod") // Different method name + .setType(MethodDescriptor.MethodType.UNARY) + .setRequestMarshaller(mock(MethodDescriptor.Marshaller.class)) + .setResponseMarshaller(mock(MethodDescriptor.Marshaller.class)) + .build()); + + // Intercept the call + ServerCall.Listener listener = + rateLimiterInterceptor.interceptCall(mockServerCall, mockMetadata, mockHandler); + + // Ensure listener isn't null + assertNotNull(listener); + + // Call onMessage to test behavior for unmatched configurations + listener.onMessage("testMessage"); + + // Verify that the server call is not closed + verify(mockServerCall, never()).close(any(Status.class), any(Metadata.class)); + + // Verify the message is passed downstream + verify(mockListener, times(1)).onMessage("testMessage"); + } + + @Test + void testInterceptorHandlesEmptyRateLimiterConfiguration() { + // Initialize RateLimiterInterceptor with no configurations + RateLimiterInterceptor emptyConfigInterceptor = + new RateLimiterInterceptor(List.of(), mockRateLimiterFactory); + + // Intercept the call + ServerCall.Listener listener = + emptyConfigInterceptor.interceptCall(mockServerCall, mockMetadata, mockHandler); + + // Ensure listener isn't null + assertNotNull(listener); + + // Call onMessage to ensure normal behavior + listener.onMessage("testMessage"); + + // Verify server call is not closed with error + verify(mockServerCall, never()).close(any(Status.class), any(Metadata.class)); + + // Verify message is passed downstream + verify(mockListener, times(1)).onMessage("testMessage"); + } +} diff --git a/settings.gradle.kts b/settings.gradle.kts index 94ab61c..5ff340c 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -19,3 +19,4 @@ include(":grpc-context-utils") include(":grpc-server-utils") include(":grpc-validation-utils") include(":grpc-circuitbreaker-utils") +include(":grpc-ratelimiter-utils")