Skip to content

Add ratelimiter framework for grpc server interceptor #73

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions grpc-ratelimiter-utils/build.gradle.kts
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
plugins {
`java-library`
jacoco
id("org.hypertrace.publish-plugin")
id("org.hypertrace.jacoco-report-plugin")
}

dependencies {

api(platform("io.grpc:grpc-bom:1.68.3"))
api("io.grpc:grpc-api")
api(project(":grpc-context-utils"))

implementation("org.slf4j:slf4j-api:1.7.36")
implementation("com.google.guava:guava:32.0.1-jre")
implementation("com.bucket4j:bucket4j-core:8.7.0")

annotationProcessor("org.projectlombok:lombok:1.18.24")
compileOnly("org.projectlombok:lombok:1.18.24")

testImplementation("org.junit.jupiter:junit-jupiter:5.8.2")
testImplementation("org.mockito:mockito-core:5.8.0")
testImplementation("org.mockito:mockito-junit-jupiter:5.8.0")
}

tasks.test {
useJUnitPlatform()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package org.hypertrace.ratelimiter.grpcutils;

public interface RateLimiter {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's nothing in this interface tied to gRPC (which is great), but then we go ahead an implement such that it can only be used as a grpc interceptor. Suggest splitting this into a stand alone rate limit package (not in this repo) and then we can provide a utility here to build a gRPC interceptor given a rate limiter instance.

That way we can use it in non grpc use cases too.

default boolean tryAcquire(String key, RateLimiterConfiguration.RateLimit rateLimit) {
return tryAcquire(key, 1, rateLimit);

Check warning on line 5 in grpc-ratelimiter-utils/src/main/java/org/hypertrace/ratelimiter/grpcutils/RateLimiter.java

View check run for this annotation

Codecov / codecov/patch

grpc-ratelimiter-utils/src/main/java/org/hypertrace/ratelimiter/grpcutils/RateLimiter.java#L5

Added line #L5 was not covered by tests
} // default single token

boolean tryAcquire(
String key, int permits, RateLimiterConfiguration.RateLimit rateLimit); // new: batch tokens
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package org.hypertrace.ratelimiter.grpcutils;

import java.util.Map;
import java.util.function.BiFunction;
import lombok.Builder;
import lombok.Value;
import org.hypertrace.core.grpcutils.context.RequestContext;

@Value
@Builder
public class RateLimiterConfiguration {
boolean enabled;
String method;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's this? It's not clear from the name

// Attributes to match like tenant_id -> traceable
Map<String, String> matchAttributes;

// Extract attributes from gRPC request
BiFunction<RequestContext, Object, Map<String, String>> attributeExtractor;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Following on the earlier comment, we could decouple this from grpc by dropping the attribute parts (because in the general case the caller is only passing something to the rate limiter if it wants to rate limit it) and taking the token cost from the call too.

Then, for the interceptor wrapper you'd have something like

BiPredicate<RequestContext, T> rateLimitMatcher; // Should we rate limit? Alternatively could be combined with permit calculator as a 0 response
BiFunction<RequestContext, T, Object> rateLimitKeyCalculator; // What should they key be? We don't care about the type unless we have a need for human readable keys in which case swap to string
BiFunction<RequestContext, T, Integer> permitCalculator; // Same as before, just renamed/typed


// Token cost evaluator (can be static 1 or dynamic based on message)
@Builder.Default BiFunction<RequestContext, Object, Integer> tokenCostFunction = (ctx, req) -> 1;
RateLimit rateLimit;

@Value
@Builder
public static class RateLimit {
int tokens;
int refreshPeriodSeconds;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit - use duration unless the underlying lib only has second granularity.

}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package org.hypertrace.ratelimiter.grpcutils;

public interface RateLimiterFactory {
RateLimiter getRateLimiter(RateLimiterConfiguration rateLimiterConfiguration);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package org.hypertrace.ratelimiter.grpcutils;

import org.hypertrace.ratelimiter.grpcutils.bucket4j.Bucket4jRateLimiterFactory;

public final class RateLimiterFactoryProvider {

private RateLimiterFactoryProvider() {
// Prevent instantiation
}

public static Bucket4jRateLimiterFactory bucket4j() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally not public. The point of the abstraction of a provider is so the caller is not concerned with the impl.

return new Bucket4jRateLimiterFactory();
}

Check warning on line 13 in grpc-ratelimiter-utils/src/main/java/org/hypertrace/ratelimiter/grpcutils/RateLimiterFactoryProvider.java

View check run for this annotation

Codecov / codecov/patch

grpc-ratelimiter-utils/src/main/java/org/hypertrace/ratelimiter/grpcutils/RateLimiterFactoryProvider.java#L12-L13

Added lines #L12 - L13 were not covered by tests
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package org.hypertrace.ratelimiter.grpcutils;

import io.grpc.ForwardingServerCallListener;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.Status;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import org.hypertrace.core.grpcutils.context.RequestContext;

public class RateLimiterInterceptor implements ServerInterceptor {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

name should indicate server vs client as we likely want both eventually.


private final List<RateLimiterConfiguration>
rateLimitConfigs; // Provided via config or dynamic update
private final RateLimiterFactory rateLimiterFactory;

public RateLimiterInterceptor(
List<RateLimiterConfiguration> rateLimitConfigs, RateLimiterFactory factory) {
this.rateLimitConfigs = rateLimitConfigs;
this.rateLimiterFactory = factory;
}

@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
ServerCall<ReqT, RespT> call, Metadata headers, ServerCallHandler<ReqT, RespT> next) {

String method = call.getMethodDescriptor().getFullMethodName();

return new ForwardingServerCallListener.SimpleForwardingServerCallListener<>(
next.startCall(call, headers)) {
@Override
public void onMessage(ReqT message) {
RequestContext requestContext = RequestContext.fromMetadata(headers);
for (RateLimiterConfiguration config : rateLimitConfigs) {
if (!config.getMethod().equals(method)) continue;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

missing braces here + below. Even for single statement if clauses our code style is to include braces.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So this also answers my earlier question about what method is and it's the name being used for string comparison. Then, we use that to do an unchecked cast (well technically we make the rate limit config author do the unchecked cast by typing it as Object) of the message. Instead, if you accept the method descriptor itself in the config rather than just it's name you can use that directly and get type safety as it will specify your type parameters for the other functions.


Map<String, String> attributes =
config.getAttributeExtractor().apply(requestContext, message);

if (!matches(config.getMatchAttributes(), attributes)) continue;
int tokens = config.getTokenCostFunction().apply(requestContext, message);
String key = buildRateLimitKey(method, config.getMatchAttributes(), attributes);
boolean allowed =
rateLimiterFactory
.getRateLimiter(config)
.tryAcquire(key, tokens, config.getRateLimit());
if (!allowed) {
call.close(Status.RESOURCE_EXHAUSTED.withDescription("Rate limit exceeded"), headers);
return;
}
}
super.onMessage(message);
}
};
}

private boolean matches(Map<String, String> match, Map<String, String> actual) {
return match.entrySet().stream()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we adjust the config as described in an earlier comment this would go away in favor of an explicit predicate.

.allMatch(e -> Objects.equals(actual.get(e.getKey()), e.getValue()));
}

private String buildRateLimitKey(
String method, Map<String, String> keys, Map<String, String> attrs) {
return method
+ "::"
+ keys.keySet().stream()
.map(k -> attrs.getOrDefault(k, "null"))
.collect(Collectors.joining(":"));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package org.hypertrace.ratelimiter.grpcutils.bucket4j;

import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import io.github.bucket4j.Bucket;
import io.grpc.Status;
import java.time.Duration;
import java.util.concurrent.ExecutionException;
import org.hypertrace.ratelimiter.grpcutils.RateLimiter;
import org.hypertrace.ratelimiter.grpcutils.RateLimiterConfiguration;
import org.hypertrace.ratelimiter.grpcutils.RateLimiterFactory;

public class Bucket4jRateLimiterFactory implements RateLimiterFactory {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't be public, must be a singleton


private final Cache<String, Bucket> limiterCache =
CacheBuilder.newBuilder().maximumSize(10_000).build();

Check warning on line 16 in grpc-ratelimiter-utils/src/main/java/org/hypertrace/ratelimiter/grpcutils/bucket4j/Bucket4jRateLimiterFactory.java

View check run for this annotation

Codecov / codecov/patch

grpc-ratelimiter-utils/src/main/java/org/hypertrace/ratelimiter/grpcutils/bucket4j/Bucket4jRateLimiterFactory.java#L15-L16

Added lines #L15 - L16 were not covered by tests

@Override
public RateLimiter getRateLimiter(RateLimiterConfiguration rule) {
return (key, tokens, limit) -> {

Check warning on line 20 in grpc-ratelimiter-utils/src/main/java/org/hypertrace/ratelimiter/grpcutils/bucket4j/Bucket4jRateLimiterFactory.java

View check run for this annotation

Codecov / codecov/patch

grpc-ratelimiter-utils/src/main/java/org/hypertrace/ratelimiter/grpcutils/bucket4j/Bucket4jRateLimiterFactory.java#L20

Added line #L20 was not covered by tests
try {
Bucket bucket = limiterCache.get(key, () -> createBucket(limit));
return bucket.tryConsume(tokens);
} catch (ExecutionException e) {
throw Status.INTERNAL
.withDescription("Failed to create rate limiter bucket for key: " + key)
.withCause(e)
.asRuntimeException();

Check warning on line 28 in grpc-ratelimiter-utils/src/main/java/org/hypertrace/ratelimiter/grpcutils/bucket4j/Bucket4jRateLimiterFactory.java

View check run for this annotation

Codecov / codecov/patch

grpc-ratelimiter-utils/src/main/java/org/hypertrace/ratelimiter/grpcutils/bucket4j/Bucket4jRateLimiterFactory.java#L22-L28

Added lines #L22 - L28 were not covered by tests
}
};
}

private Bucket createBucket(RateLimiterConfiguration.RateLimit limit) {
return Bucket.builder()
.addLimit(

Check warning on line 35 in grpc-ratelimiter-utils/src/main/java/org/hypertrace/ratelimiter/grpcutils/bucket4j/Bucket4jRateLimiterFactory.java

View check run for this annotation

Codecov / codecov/patch

grpc-ratelimiter-utils/src/main/java/org/hypertrace/ratelimiter/grpcutils/bucket4j/Bucket4jRateLimiterFactory.java#L34-L35

Added lines #L34 - L35 were not covered by tests
bandwidth ->
bandwidth
.capacity(limit.getTokens())
.refillGreedy(
limit.getTokens(), Duration.ofSeconds(limit.getRefreshPeriodSeconds())))
.build();

Check warning on line 41 in grpc-ratelimiter-utils/src/main/java/org/hypertrace/ratelimiter/grpcutils/bucket4j/Bucket4jRateLimiterFactory.java

View check run for this annotation

Codecov / codecov/patch

grpc-ratelimiter-utils/src/main/java/org/hypertrace/ratelimiter/grpcutils/bucket4j/Bucket4jRateLimiterFactory.java#L37-L41

Added lines #L37 - L41 were not covered by tests
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
package org.hypertrace.ratelimiter.grpcutils;

import static org.junit.jupiter.api.Assertions.*;

import java.util.Map;
import java.util.function.BiFunction;
import org.hypertrace.core.grpcutils.context.RequestContext;
import org.junit.jupiter.api.Test;

class RateLimiterConfigurationTest {

@Test
void testDefaultValues() {
// Build a minimal RateLimiterConfiguration using only required fields
RateLimiterConfiguration configuration =
RateLimiterConfiguration.builder()
.method("testMethod")
.rateLimit(
RateLimiterConfiguration.RateLimit.builder()
.tokens(10)
.refreshPeriodSeconds(30)
.build())
.build();

// Assertions for default values
assertFalse(configuration.isEnabled()); // `enabled` defaults to true
assertEquals("testMethod", configuration.getMethod());
assertNull(configuration.getMatchAttributes()); // `matchAttributes` defaults to null
assertNotNull(configuration.getTokenCostFunction()); // Ensure tokenCostFunction is initialized
assertEquals(
1, configuration.getTokenCostFunction().apply(null, null)); // Default token cost value
}

@Test
void testCustomMatchAttributes() {
// Create matchAttributes map and build configuration
Map<String, String> matchAttributes = Map.of("tenant_id", "traceable");
RateLimiterConfiguration configuration =
RateLimiterConfiguration.builder()
.method("testMethod")
.matchAttributes(matchAttributes)
.rateLimit(
RateLimiterConfiguration.RateLimit.builder()
.tokens(100)
.refreshPeriodSeconds(60)
.build())
.build();

// Verify matchAttributes are correctly set
assertEquals(matchAttributes, configuration.getMatchAttributes());
}

@Test
void testCustomTokenCostFunction() {
// Define a custom tokenCostFunction
BiFunction<RequestContext, Object, Integer> customTokenCostFunction =
(ctx, req) -> (req instanceof Integer) ? (Integer) req : 5;

// Build a configuration using the custom token cost function
RateLimiterConfiguration configuration =
RateLimiterConfiguration.builder()
.method("computeTokenCost")
.tokenCostFunction(customTokenCostFunction)
.rateLimit(
RateLimiterConfiguration.RateLimit.builder()
.tokens(50)
.refreshPeriodSeconds(15)
.build())
.build();

// Verify behavior of the custom token cost function
assertEquals(10, configuration.getTokenCostFunction().apply(null, 10)); // Dynamic cost
assertEquals(
5, configuration.getTokenCostFunction().apply(null, "randomObject")); // Default cost
}

@Test
void testAttributeExtractor() {
// Define an attributeExtractor that extracts specific attributes from the request
BiFunction<RequestContext, Object, Map<String, String>> customAttributeExtractor =
(ctx, req) -> Map.of("attributeKey", "attributeValue");

// Build the configuration with the custom attribute extractor
RateLimiterConfiguration configuration =
RateLimiterConfiguration.builder()
.method("extractAttributes")
.attributeExtractor(customAttributeExtractor)
.rateLimit(
RateLimiterConfiguration.RateLimit.builder()
.tokens(20)
.refreshPeriodSeconds(45)
.build())
.build();

// Verify the custom attribute extractor
assertEquals(
Map.of("attributeKey", "attributeValue"),
configuration.getAttributeExtractor().apply(null, null));
}

@Test
void testRateLimitConfiguration() {
// Build a simple RateLimit configuration
RateLimiterConfiguration.RateLimit rateLimit =
RateLimiterConfiguration.RateLimit.builder().tokens(500).refreshPeriodSeconds(300).build();

// Verify RateLimit configuration values
assertEquals(500, rateLimit.getTokens());
assertEquals(300, rateLimit.getRefreshPeriodSeconds());
}

@Test
void testFullCustomConfiguration() {
// Define custom token cost function and attribute extractor
BiFunction<RequestContext, Object, Integer> tokenCostFunction = (ctx, req) -> 2;
BiFunction<RequestContext, Object, Map<String, String>> attributeExtractor =
(ctx, req) -> Map.of("tenant_id", "12345");

// Build a complete custom configuration
RateLimiterConfiguration configuration =
RateLimiterConfiguration.builder()
.method("fullCustomMethod")
.enabled(false)
.matchAttributes(Map.of("region", "us-west"))
.attributeExtractor(attributeExtractor)
.tokenCostFunction(tokenCostFunction)
.rateLimit(
RateLimiterConfiguration.RateLimit.builder()
.tokens(1000)
.refreshPeriodSeconds(60)
.build())
.build();

// Verify all custom configurations
assertFalse(configuration.isEnabled()); // Verify `enabled` value
assertEquals("fullCustomMethod", configuration.getMethod()); // Verify method
assertEquals(
Map.of("region", "us-west"), configuration.getMatchAttributes()); // Verify matchAttributes
assertEquals(
Map.of("tenant_id", "12345"),
configuration.getAttributeExtractor().apply(null, null)); // Custom extractor
assertEquals(2, configuration.getTokenCostFunction().apply(null, null)); // Custom token cost
assertEquals(1000, configuration.getRateLimit().getTokens()); // RateLimit tokens
assertEquals(
60, configuration.getRateLimit().getRefreshPeriodSeconds()); // RateLimit refresh period
}
}
Loading