-
Notifications
You must be signed in to change notification settings - Fork 0
Add ratelimiter framework for grpc server interceptor #73
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
plugins { | ||
`java-library` | ||
jacoco | ||
id("org.hypertrace.publish-plugin") | ||
id("org.hypertrace.jacoco-report-plugin") | ||
} | ||
|
||
dependencies { | ||
|
||
api(platform("io.grpc:grpc-bom:1.68.3")) | ||
api("io.grpc:grpc-api") | ||
api(project(":grpc-context-utils")) | ||
|
||
implementation("org.slf4j:slf4j-api:1.7.36") | ||
implementation("com.google.guava:guava:32.0.1-jre") | ||
implementation("com.bucket4j:bucket4j-core:8.7.0") | ||
|
||
annotationProcessor("org.projectlombok:lombok:1.18.24") | ||
compileOnly("org.projectlombok:lombok:1.18.24") | ||
|
||
testImplementation("org.junit.jupiter:junit-jupiter:5.8.2") | ||
testImplementation("org.mockito:mockito-core:5.8.0") | ||
testImplementation("org.mockito:mockito-junit-jupiter:5.8.0") | ||
} | ||
|
||
tasks.test { | ||
useJUnitPlatform() | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
package org.hypertrace.ratelimiter.grpcutils; | ||
|
||
public interface RateLimiter { | ||
default boolean tryAcquire(String key, RateLimiterConfiguration.RateLimit rateLimit) { | ||
return tryAcquire(key, 1, rateLimit); | ||
Check warning on line 5 in grpc-ratelimiter-utils/src/main/java/org/hypertrace/ratelimiter/grpcutils/RateLimiter.java
|
||
} // default single token | ||
|
||
boolean tryAcquire( | ||
String key, int permits, RateLimiterConfiguration.RateLimit rateLimit); // new: batch tokens | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
package org.hypertrace.ratelimiter.grpcutils; | ||
|
||
import java.util.Map; | ||
import java.util.function.BiFunction; | ||
import lombok.Builder; | ||
import lombok.Value; | ||
import org.hypertrace.core.grpcutils.context.RequestContext; | ||
|
||
@Value | ||
@Builder | ||
public class RateLimiterConfiguration { | ||
boolean enabled; | ||
String method; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What's this? It's not clear from the name |
||
// Attributes to match like tenant_id -> traceable | ||
Map<String, String> matchAttributes; | ||
|
||
// Extract attributes from gRPC request | ||
BiFunction<RequestContext, Object, Map<String, String>> attributeExtractor; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Following on the earlier comment, we could decouple this from grpc by dropping the attribute parts (because in the general case the caller is only passing something to the rate limiter if it wants to rate limit it) and taking the token cost from the call too. Then, for the interceptor wrapper you'd have something like BiPredicate<RequestContext, T> rateLimitMatcher; // Should we rate limit? Alternatively could be combined with permit calculator as a 0 response
BiFunction<RequestContext, T, Object> rateLimitKeyCalculator; // What should they key be? We don't care about the type unless we have a need for human readable keys in which case swap to string
BiFunction<RequestContext, T, Integer> permitCalculator; // Same as before, just renamed/typed |
||
|
||
// Token cost evaluator (can be static 1 or dynamic based on message) | ||
@Builder.Default BiFunction<RequestContext, Object, Integer> tokenCostFunction = (ctx, req) -> 1; | ||
RateLimit rateLimit; | ||
|
||
@Value | ||
@Builder | ||
public static class RateLimit { | ||
int tokens; | ||
int refreshPeriodSeconds; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit - use duration unless the underlying lib only has second granularity. |
||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
package org.hypertrace.ratelimiter.grpcutils; | ||
|
||
public interface RateLimiterFactory { | ||
RateLimiter getRateLimiter(RateLimiterConfiguration rateLimiterConfiguration); | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
package org.hypertrace.ratelimiter.grpcutils; | ||
|
||
import org.hypertrace.ratelimiter.grpcutils.bucket4j.Bucket4jRateLimiterFactory; | ||
|
||
public final class RateLimiterFactoryProvider { | ||
|
||
private RateLimiterFactoryProvider() { | ||
// Prevent instantiation | ||
} | ||
|
||
public static Bucket4jRateLimiterFactory bucket4j() { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ideally not public. The point of the abstraction of a provider is so the caller is not concerned with the impl. |
||
return new Bucket4jRateLimiterFactory(); | ||
} | ||
Check warning on line 13 in grpc-ratelimiter-utils/src/main/java/org/hypertrace/ratelimiter/grpcutils/RateLimiterFactoryProvider.java
|
||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
package org.hypertrace.ratelimiter.grpcutils; | ||
|
||
import io.grpc.ForwardingServerCallListener; | ||
import io.grpc.Metadata; | ||
import io.grpc.ServerCall; | ||
import io.grpc.ServerCallHandler; | ||
import io.grpc.ServerInterceptor; | ||
import io.grpc.Status; | ||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.Objects; | ||
import java.util.stream.Collectors; | ||
import org.hypertrace.core.grpcutils.context.RequestContext; | ||
|
||
public class RateLimiterInterceptor implements ServerInterceptor { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. name should indicate server vs client as we likely want both eventually. |
||
|
||
private final List<RateLimiterConfiguration> | ||
rateLimitConfigs; // Provided via config or dynamic update | ||
private final RateLimiterFactory rateLimiterFactory; | ||
|
||
public RateLimiterInterceptor( | ||
List<RateLimiterConfiguration> rateLimitConfigs, RateLimiterFactory factory) { | ||
this.rateLimitConfigs = rateLimitConfigs; | ||
this.rateLimiterFactory = factory; | ||
} | ||
|
||
@Override | ||
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall( | ||
ServerCall<ReqT, RespT> call, Metadata headers, ServerCallHandler<ReqT, RespT> next) { | ||
|
||
String method = call.getMethodDescriptor().getFullMethodName(); | ||
|
||
return new ForwardingServerCallListener.SimpleForwardingServerCallListener<>( | ||
next.startCall(call, headers)) { | ||
@Override | ||
public void onMessage(ReqT message) { | ||
RequestContext requestContext = RequestContext.fromMetadata(headers); | ||
for (RateLimiterConfiguration config : rateLimitConfigs) { | ||
if (!config.getMethod().equals(method)) continue; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. missing braces here + below. Even for single statement There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So this also answers my earlier question about what |
||
|
||
Map<String, String> attributes = | ||
config.getAttributeExtractor().apply(requestContext, message); | ||
|
||
if (!matches(config.getMatchAttributes(), attributes)) continue; | ||
int tokens = config.getTokenCostFunction().apply(requestContext, message); | ||
String key = buildRateLimitKey(method, config.getMatchAttributes(), attributes); | ||
boolean allowed = | ||
rateLimiterFactory | ||
.getRateLimiter(config) | ||
.tryAcquire(key, tokens, config.getRateLimit()); | ||
if (!allowed) { | ||
call.close(Status.RESOURCE_EXHAUSTED.withDescription("Rate limit exceeded"), headers); | ||
return; | ||
} | ||
} | ||
super.onMessage(message); | ||
} | ||
}; | ||
} | ||
|
||
private boolean matches(Map<String, String> match, Map<String, String> actual) { | ||
return match.entrySet().stream() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If we adjust the config as described in an earlier comment this would go away in favor of an explicit predicate. |
||
.allMatch(e -> Objects.equals(actual.get(e.getKey()), e.getValue())); | ||
} | ||
|
||
private String buildRateLimitKey( | ||
String method, Map<String, String> keys, Map<String, String> attrs) { | ||
return method | ||
+ "::" | ||
+ keys.keySet().stream() | ||
.map(k -> attrs.getOrDefault(k, "null")) | ||
.collect(Collectors.joining(":")); | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
package org.hypertrace.ratelimiter.grpcutils.bucket4j; | ||
|
||
import com.google.common.cache.Cache; | ||
import com.google.common.cache.CacheBuilder; | ||
import io.github.bucket4j.Bucket; | ||
import io.grpc.Status; | ||
import java.time.Duration; | ||
import java.util.concurrent.ExecutionException; | ||
import org.hypertrace.ratelimiter.grpcutils.RateLimiter; | ||
import org.hypertrace.ratelimiter.grpcutils.RateLimiterConfiguration; | ||
import org.hypertrace.ratelimiter.grpcutils.RateLimiterFactory; | ||
|
||
public class Bucket4jRateLimiterFactory implements RateLimiterFactory { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shouldn't be public, must be a singleton |
||
|
||
private final Cache<String, Bucket> limiterCache = | ||
CacheBuilder.newBuilder().maximumSize(10_000).build(); | ||
Check warning on line 16 in grpc-ratelimiter-utils/src/main/java/org/hypertrace/ratelimiter/grpcutils/bucket4j/Bucket4jRateLimiterFactory.java
|
||
|
||
@Override | ||
public RateLimiter getRateLimiter(RateLimiterConfiguration rule) { | ||
return (key, tokens, limit) -> { | ||
Check warning on line 20 in grpc-ratelimiter-utils/src/main/java/org/hypertrace/ratelimiter/grpcutils/bucket4j/Bucket4jRateLimiterFactory.java
|
||
try { | ||
Bucket bucket = limiterCache.get(key, () -> createBucket(limit)); | ||
return bucket.tryConsume(tokens); | ||
} catch (ExecutionException e) { | ||
throw Status.INTERNAL | ||
.withDescription("Failed to create rate limiter bucket for key: " + key) | ||
.withCause(e) | ||
.asRuntimeException(); | ||
Check warning on line 28 in grpc-ratelimiter-utils/src/main/java/org/hypertrace/ratelimiter/grpcutils/bucket4j/Bucket4jRateLimiterFactory.java
|
||
} | ||
}; | ||
} | ||
|
||
private Bucket createBucket(RateLimiterConfiguration.RateLimit limit) { | ||
return Bucket.builder() | ||
.addLimit( | ||
Check warning on line 35 in grpc-ratelimiter-utils/src/main/java/org/hypertrace/ratelimiter/grpcutils/bucket4j/Bucket4jRateLimiterFactory.java
|
||
bandwidth -> | ||
bandwidth | ||
.capacity(limit.getTokens()) | ||
.refillGreedy( | ||
limit.getTokens(), Duration.ofSeconds(limit.getRefreshPeriodSeconds()))) | ||
.build(); | ||
Check warning on line 41 in grpc-ratelimiter-utils/src/main/java/org/hypertrace/ratelimiter/grpcutils/bucket4j/Bucket4jRateLimiterFactory.java
|
||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
package org.hypertrace.ratelimiter.grpcutils; | ||
|
||
import static org.junit.jupiter.api.Assertions.*; | ||
|
||
import java.util.Map; | ||
import java.util.function.BiFunction; | ||
import org.hypertrace.core.grpcutils.context.RequestContext; | ||
import org.junit.jupiter.api.Test; | ||
|
||
class RateLimiterConfigurationTest { | ||
|
||
@Test | ||
void testDefaultValues() { | ||
// Build a minimal RateLimiterConfiguration using only required fields | ||
RateLimiterConfiguration configuration = | ||
RateLimiterConfiguration.builder() | ||
.method("testMethod") | ||
.rateLimit( | ||
RateLimiterConfiguration.RateLimit.builder() | ||
.tokens(10) | ||
.refreshPeriodSeconds(30) | ||
.build()) | ||
.build(); | ||
|
||
// Assertions for default values | ||
assertFalse(configuration.isEnabled()); // `enabled` defaults to true | ||
assertEquals("testMethod", configuration.getMethod()); | ||
assertNull(configuration.getMatchAttributes()); // `matchAttributes` defaults to null | ||
assertNotNull(configuration.getTokenCostFunction()); // Ensure tokenCostFunction is initialized | ||
assertEquals( | ||
1, configuration.getTokenCostFunction().apply(null, null)); // Default token cost value | ||
} | ||
|
||
@Test | ||
void testCustomMatchAttributes() { | ||
// Create matchAttributes map and build configuration | ||
Map<String, String> matchAttributes = Map.of("tenant_id", "traceable"); | ||
RateLimiterConfiguration configuration = | ||
RateLimiterConfiguration.builder() | ||
.method("testMethod") | ||
.matchAttributes(matchAttributes) | ||
.rateLimit( | ||
RateLimiterConfiguration.RateLimit.builder() | ||
.tokens(100) | ||
.refreshPeriodSeconds(60) | ||
.build()) | ||
.build(); | ||
|
||
// Verify matchAttributes are correctly set | ||
assertEquals(matchAttributes, configuration.getMatchAttributes()); | ||
} | ||
|
||
@Test | ||
void testCustomTokenCostFunction() { | ||
// Define a custom tokenCostFunction | ||
BiFunction<RequestContext, Object, Integer> customTokenCostFunction = | ||
(ctx, req) -> (req instanceof Integer) ? (Integer) req : 5; | ||
|
||
// Build a configuration using the custom token cost function | ||
RateLimiterConfiguration configuration = | ||
RateLimiterConfiguration.builder() | ||
.method("computeTokenCost") | ||
.tokenCostFunction(customTokenCostFunction) | ||
.rateLimit( | ||
RateLimiterConfiguration.RateLimit.builder() | ||
.tokens(50) | ||
.refreshPeriodSeconds(15) | ||
.build()) | ||
.build(); | ||
|
||
// Verify behavior of the custom token cost function | ||
assertEquals(10, configuration.getTokenCostFunction().apply(null, 10)); // Dynamic cost | ||
assertEquals( | ||
5, configuration.getTokenCostFunction().apply(null, "randomObject")); // Default cost | ||
} | ||
|
||
@Test | ||
void testAttributeExtractor() { | ||
// Define an attributeExtractor that extracts specific attributes from the request | ||
BiFunction<RequestContext, Object, Map<String, String>> customAttributeExtractor = | ||
(ctx, req) -> Map.of("attributeKey", "attributeValue"); | ||
|
||
// Build the configuration with the custom attribute extractor | ||
RateLimiterConfiguration configuration = | ||
RateLimiterConfiguration.builder() | ||
.method("extractAttributes") | ||
.attributeExtractor(customAttributeExtractor) | ||
.rateLimit( | ||
RateLimiterConfiguration.RateLimit.builder() | ||
.tokens(20) | ||
.refreshPeriodSeconds(45) | ||
.build()) | ||
.build(); | ||
|
||
// Verify the custom attribute extractor | ||
assertEquals( | ||
Map.of("attributeKey", "attributeValue"), | ||
configuration.getAttributeExtractor().apply(null, null)); | ||
} | ||
|
||
@Test | ||
void testRateLimitConfiguration() { | ||
// Build a simple RateLimit configuration | ||
RateLimiterConfiguration.RateLimit rateLimit = | ||
RateLimiterConfiguration.RateLimit.builder().tokens(500).refreshPeriodSeconds(300).build(); | ||
|
||
// Verify RateLimit configuration values | ||
assertEquals(500, rateLimit.getTokens()); | ||
assertEquals(300, rateLimit.getRefreshPeriodSeconds()); | ||
} | ||
|
||
@Test | ||
void testFullCustomConfiguration() { | ||
// Define custom token cost function and attribute extractor | ||
BiFunction<RequestContext, Object, Integer> tokenCostFunction = (ctx, req) -> 2; | ||
BiFunction<RequestContext, Object, Map<String, String>> attributeExtractor = | ||
(ctx, req) -> Map.of("tenant_id", "12345"); | ||
|
||
// Build a complete custom configuration | ||
RateLimiterConfiguration configuration = | ||
RateLimiterConfiguration.builder() | ||
.method("fullCustomMethod") | ||
.enabled(false) | ||
.matchAttributes(Map.of("region", "us-west")) | ||
.attributeExtractor(attributeExtractor) | ||
.tokenCostFunction(tokenCostFunction) | ||
.rateLimit( | ||
RateLimiterConfiguration.RateLimit.builder() | ||
.tokens(1000) | ||
.refreshPeriodSeconds(60) | ||
.build()) | ||
.build(); | ||
|
||
// Verify all custom configurations | ||
assertFalse(configuration.isEnabled()); // Verify `enabled` value | ||
assertEquals("fullCustomMethod", configuration.getMethod()); // Verify method | ||
assertEquals( | ||
Map.of("region", "us-west"), configuration.getMatchAttributes()); // Verify matchAttributes | ||
assertEquals( | ||
Map.of("tenant_id", "12345"), | ||
configuration.getAttributeExtractor().apply(null, null)); // Custom extractor | ||
assertEquals(2, configuration.getTokenCostFunction().apply(null, null)); // Custom token cost | ||
assertEquals(1000, configuration.getRateLimit().getTokens()); // RateLimit tokens | ||
assertEquals( | ||
60, configuration.getRateLimit().getRefreshPeriodSeconds()); // RateLimit refresh period | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There's nothing in this interface tied to gRPC (which is great), but then we go ahead an implement such that it can only be used as a grpc interceptor. Suggest splitting this into a stand alone rate limit package (not in this repo) and then we can provide a utility here to build a gRPC interceptor given a rate limiter instance.
That way we can use it in non grpc use cases too.