Skip to content

Commit bc60d8e

Browse files
Add grpc circuit breaker utility using interceptors
1 parent 20402cf commit bc60d8e

7 files changed

+420
-0
lines changed
+39
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
// Plugins applied to this module: the core java-library and jacoco plugins plus the
// Hypertrace publish and jacoco-report plugins.
plugins {
  id("java-library")
  id("jacoco")
  id("org.hypertrace.publish-plugin")
  id("org.hypertrace.jacoco-report-plugin")
}
7+
8+
dependencies {
  // Single source for the Lombok coordinate, used for both annotation processing and compilation.
  val lombok = "org.projectlombok:lombok:1.18.24"

  // gRPC artifacts exposed to consumers; versions are aligned through the gRPC BOM.
  api(platform("io.grpc:grpc-bom:1.68.3"))
  api("io.grpc:grpc-context")
  api("io.grpc:grpc-api")
  api("io.grpc:grpc-inprocess")
  api(platform("io.netty:netty-bom:4.1.118.Final"))
  constraints {
    api("com.google.protobuf:protobuf-java:3.25.5") {
      because("https://nvd.nist.gov/vuln/detail/CVE-2024-7254")
    }
  }

  // Internal implementation dependencies (not exposed on the consumer compile classpath).
  implementation(project(":grpc-context-utils"))
  implementation("org.slf4j:slf4j-api:1.7.36")
  implementation("io.grpc:grpc-core")
  implementation("io.github.resilience4j:resilience4j-circuitbreaker:1.7.1")
  implementation("com.typesafe:config:1.4.2")
  implementation("com.google.inject:guice:7.0.0")
  implementation("org.hypertrace.core.serviceframework:platform-metrics:0.1.87")

  // Lombok is needed only at compile time.
  annotationProcessor(lombok)
  compileOnly(lombok)

  // Test-only dependencies.
  testImplementation("org.junit.jupiter:junit-jupiter:5.8.2")
  testImplementation("org.mockito:mockito-core:5.8.0")
  testRuntimeOnly("io.grpc:grpc-netty")
}
36+
37+
// Run the test task on the JUnit Platform.
tasks.named<Test>("test") {
  useJUnitPlatform()
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
package org.hypertrace.circuitbreaker.grpcutils;
2+
3+
import com.typesafe.config.Config;
4+
import io.github.resilience4j.circuitbreaker.CircuitBreakerConfig;
5+
import java.util.Map;
6+
import java.util.concurrent.ConcurrentHashMap;
7+
import java.util.stream.Collectors;
8+
import lombok.extern.slf4j.Slf4j;
9+
10+
@Slf4j
11+
public class CircuitBreakerConfigProvider {
12+
13+
public static final String CIRCUIT_BREAKER_CONFIG = "circuit.breaker.config";
14+
public static final String DEFAULT_CONFIG_KEY = "default";
15+
16+
// Whether to enable circuit breaker or not.
17+
private static final String ENABLED = "enabled";
18+
19+
// Percentage of failures to trigger OPEN state
20+
private static final String FAILURE_RATE_THRESHOLD = "failureRateThreshold";
21+
// Percentage of slow calls to trigger OPEN state
22+
private static final String SLOW_CALL_RATE_THRESHOLD = "slowCallRateThreshold";
23+
// Define what a "slow" call is
24+
private static final String SLOW_CALL_DURATION_THRESHOLD = "slowCallDurationThreshold";
25+
// Number of calls to consider in the sliding window
26+
private static final String SLIDING_WINDOW_SIZE = "slidingWindowSize";
27+
// Time before retrying after OPEN state
28+
private static final String WAIT_DURATION_IN_OPEN_STATE = "waitDurationInOpenState";
29+
// Minimum calls before evaluating failure rate
30+
private static final String MINIMUM_NUMBER_OF_CALLS = "minimumNumberOfCalls";
31+
// Calls allowed in HALF_OPEN state before deciding to
32+
// CLOSE or OPEN again
33+
private static final String PERMITTED_NUMBER_OF_CALLS_IN_HALF_OPEN_STATE =
34+
"permittedNumberOfCallsInHalfOpenState";
35+
private static final String SLIDING_WINDOW_TYPE = "slidingWindowType";
36+
37+
// Cache for storing CircuitBreakerConfig instances
38+
private static final ConcurrentHashMap<String, CircuitBreakerConfig> configCache =
39+
new ConcurrentHashMap<>();
40+
41+
// Global flag for circuit breaker enablement
42+
private boolean circuitBreakerEnabled = false;
43+
44+
public CircuitBreakerConfigProvider(Config config) {
45+
initialize(config);
46+
}
47+
48+
public CircuitBreakerConfigProvider() {}
49+
50+
/** Initializes and caches all CircuitBreaker configurations. */
51+
public void initialize(Config config) {
52+
if (!config.hasPath(CIRCUIT_BREAKER_CONFIG)) {
53+
log.warn("No circuit breaker configurations found in the config file.");
54+
return;
55+
}
56+
57+
Config circuitBreakerConfig = config.getConfig(CIRCUIT_BREAKER_CONFIG);
58+
59+
// Read global enabled flag (default to false if not provided)
60+
circuitBreakerEnabled =
61+
circuitBreakerConfig.hasPath(ENABLED) && circuitBreakerConfig.getBoolean(ENABLED);
62+
63+
// Load all circuit breaker configurations and cache them
64+
Map<String, CircuitBreakerConfig> allConfigs =
65+
circuitBreakerConfig.root().keySet().stream()
66+
.filter(key -> !key.equals(ENABLED)) // Ignore the global enabled flag
67+
.collect(
68+
Collectors.toMap(
69+
key -> key, // Circuit breaker key
70+
key -> createCircuitBreakerConfig(circuitBreakerConfig.getConfig(key))));
71+
72+
// Store in cache
73+
configCache.putAll(allConfigs);
74+
75+
log.info(
76+
"Loaded {} circuit breaker configurations, Global Enabled: {}. Configs: {}",
77+
allConfigs.size(),
78+
circuitBreakerEnabled,
79+
allConfigs);
80+
}
81+
82+
/**
83+
* Retrieves the CircuitBreakerConfig for a specific key. Falls back to default if key-specific
84+
* config is not found.
85+
*/
86+
public CircuitBreakerConfig getConfig(String circuitBreakerKey) {
87+
return configCache.getOrDefault(circuitBreakerKey, configCache.get(DEFAULT_CONFIG_KEY));
88+
}
89+
90+
/** Checks if Circuit Breaker is globally enabled. */
91+
public boolean isCircuitBreakerEnabled() {
92+
return circuitBreakerEnabled;
93+
}
94+
95+
private CircuitBreakerConfig createCircuitBreakerConfig(Config config) {
96+
return CircuitBreakerConfig.custom()
97+
.failureRateThreshold((float) config.getDouble(FAILURE_RATE_THRESHOLD))
98+
.slowCallRateThreshold((float) config.getDouble(SLOW_CALL_RATE_THRESHOLD))
99+
.slowCallDurationThreshold(config.getDuration(SLOW_CALL_DURATION_THRESHOLD))
100+
.slidingWindowType(getSlidingWindowType(config.getString(SLIDING_WINDOW_TYPE)))
101+
.slidingWindowSize(config.getInt(SLIDING_WINDOW_SIZE))
102+
.waitDurationInOpenState(config.getDuration(WAIT_DURATION_IN_OPEN_STATE))
103+
.permittedNumberOfCallsInHalfOpenState(
104+
config.getInt(PERMITTED_NUMBER_OF_CALLS_IN_HALF_OPEN_STATE))
105+
.minimumNumberOfCalls(config.getInt(MINIMUM_NUMBER_OF_CALLS))
106+
.build();
107+
}
108+
109+
private CircuitBreakerConfig.SlidingWindowType getSlidingWindowType(String slidingWindowType) {
110+
return CircuitBreakerConfig.SlidingWindowType.valueOf(slidingWindowType);
111+
}
112+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
package org.hypertrace.circuitbreaker.grpcutils;
2+
3+
import io.github.resilience4j.circuitbreaker.CircuitBreaker;
4+
import java.util.Set;
5+
import java.util.concurrent.ConcurrentHashMap;
6+
import lombok.extern.slf4j.Slf4j;
7+
8+
@Slf4j
9+
public class CircuitBreakerEventListener {
10+
private static final Set<String> attachedCircuitBreakers = ConcurrentHashMap.newKeySet();
11+
12+
public static synchronized void attachListeners(CircuitBreaker circuitBreaker) {
13+
if (!attachedCircuitBreakers.add(
14+
circuitBreaker.getName())) { // Ensures only one listener is attached
15+
return;
16+
}
17+
circuitBreaker
18+
.getEventPublisher()
19+
.onStateTransition(
20+
event ->
21+
log.info(
22+
"State transition: {} for circuit breaker {} ",
23+
event.getStateTransition(),
24+
event.getCircuitBreakerName()))
25+
.onCallNotPermitted(
26+
event ->
27+
log.debug(
28+
"Call not permitted: Circuit is OPEN for circuit breaker {} ",
29+
event.getCircuitBreakerName()))
30+
.onEvent(
31+
event -> {
32+
log.debug(
33+
"Circuit breaker event type {} for circuit breaker name {} ",
34+
event.getEventType(),
35+
event.getCircuitBreakerName());
36+
});
37+
}
38+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
package org.hypertrace.circuitbreaker.grpcutils;
2+
3+
import io.github.resilience4j.circuitbreaker.CircuitBreaker;
4+
import io.github.resilience4j.circuitbreaker.CircuitBreakerRegistry;
5+
import io.grpc.CallOptions;
6+
import io.grpc.Channel;
7+
import io.grpc.ClientCall;
8+
import io.grpc.ClientInterceptor;
9+
import io.grpc.ForwardingClientCall;
10+
import io.grpc.ForwardingClientCallListener;
11+
import io.grpc.Metadata;
12+
import io.grpc.MethodDescriptor;
13+
import io.grpc.Status;
14+
import java.util.concurrent.TimeUnit;
15+
import lombok.extern.slf4j.Slf4j;
16+
17+
@Slf4j
18+
public class CircuitBreakerInterceptor implements ClientInterceptor {
19+
20+
public static final CallOptions.Key<String> CIRCUIT_BREAKER_KEY =
21+
CallOptions.Key.createWithDefault("circuitBreakerKey", "default");
22+
private final CircuitBreakerRegistry circuitBreakerRegistry;
23+
private final CircuitBreakerConfigProvider circuitBreakerConfigProvider;
24+
private final CircuitBreakerMetricsNotifier circuitBreakerMetricsNotifier;
25+
26+
public CircuitBreakerInterceptor(
27+
CircuitBreakerRegistry circuitBreakerRegistry,
28+
CircuitBreakerConfigProvider circuitBreakerConfigProvider,
29+
CircuitBreakerMetricsNotifier circuitBreakerMetricsNotifier) {
30+
this.circuitBreakerRegistry = circuitBreakerRegistry;
31+
this.circuitBreakerConfigProvider = circuitBreakerConfigProvider;
32+
this.circuitBreakerMetricsNotifier = circuitBreakerMetricsNotifier;
33+
}
34+
35+
// Intercepts the call and applies circuit breaker logic
36+
@Override
37+
public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
38+
MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
39+
if (!circuitBreakerConfigProvider.isCircuitBreakerEnabled()) {
40+
return next.newCall(method, callOptions);
41+
}
42+
43+
// Get circuit breaker key from CallOptions
44+
String circuitBreakerKey = callOptions.getOption(CIRCUIT_BREAKER_KEY);
45+
CircuitBreaker circuitBreaker = getCircuitBreaker(circuitBreakerKey);
46+
return new ForwardingClientCall.SimpleForwardingClientCall<>(
47+
next.newCall(method, callOptions)) {
48+
@Override
49+
public void start(Listener<RespT> responseListener, Metadata headers) {
50+
long startTime = System.nanoTime();
51+
52+
// Wrap response listener to track failures
53+
Listener<RespT> wrappedListener =
54+
new ForwardingClientCallListener.SimpleForwardingClientCallListener<>(
55+
responseListener) {
56+
@Override
57+
public void onClose(Status status, Metadata trailers) {
58+
long duration = System.nanoTime() - startTime;
59+
if (status.isOk()) {
60+
circuitBreaker.onSuccess(duration, TimeUnit.NANOSECONDS);
61+
} else {
62+
log.debug(
63+
"Circuit Breaker '{}' detected failure. Status: {}, Description: {}",
64+
circuitBreaker.getName(),
65+
status.getCode(),
66+
status.getDescription());
67+
circuitBreaker.onError(
68+
duration, TimeUnit.NANOSECONDS, status.asRuntimeException());
69+
}
70+
super.onClose(status, trailers);
71+
}
72+
};
73+
74+
super.start(wrappedListener, headers);
75+
}
76+
77+
@Override
78+
public void sendMessage(ReqT message) {
79+
if (!circuitBreaker.tryAcquirePermission()) {
80+
handleCircuitBreakerRejection(circuitBreakerKey, circuitBreaker);
81+
String rejectionReason =
82+
circuitBreaker.getState() == CircuitBreaker.State.HALF_OPEN
83+
? "Circuit Breaker is HALF-OPEN and rejecting excess requests"
84+
: "Circuit Breaker is OPEN and blocking requests";
85+
throw Status.UNAVAILABLE.withDescription(rejectionReason).asRuntimeException();
86+
}
87+
super.sendMessage(message);
88+
}
89+
};
90+
}
91+
92+
private void handleCircuitBreakerRejection(
93+
String circuitBreakerKey, CircuitBreaker circuitBreaker) {
94+
String tenantId = getTenantId(circuitBreakerKey);
95+
if (circuitBreaker.getState() == CircuitBreaker.State.HALF_OPEN) {
96+
circuitBreakerMetricsNotifier.incrementCount(tenantId, "circuitbreaker.halfopen.rejected");
97+
log.debug(
98+
"Circuit Breaker '{}' is HALF-OPEN and rejecting excess requests for tenant '{}'.",
99+
circuitBreakerKey,
100+
tenantId);
101+
} else if (circuitBreaker.getState() == CircuitBreaker.State.OPEN) {
102+
circuitBreakerMetricsNotifier.incrementCount(tenantId, "circuitbreaker.open.blocked");
103+
log.debug(
104+
"Circuit Breaker '{}' is OPEN. Blocking request for tenant '{}'.",
105+
circuitBreakerKey,
106+
tenantId);
107+
} else {
108+
log.debug( // Added unexpected state handling for safety
109+
"Unexpected Circuit Breaker state '{}' for '{}'. Blocking request.",
110+
circuitBreaker.getState(),
111+
circuitBreakerKey);
112+
}
113+
}
114+
115+
private static String getTenantId(String circuitBreakerKey) {
116+
if (!circuitBreakerKey.contains(".")) {
117+
return "Unknown";
118+
}
119+
return circuitBreakerKey.split("\\.", 2)[0]; // Ensures only the first split
120+
}
121+
122+
/** Retrieve the Circuit Breaker based on the key. */
123+
private CircuitBreaker getCircuitBreaker(String circuitBreakerKey) {
124+
CircuitBreaker circuitBreaker =
125+
circuitBreakerRegistry.circuitBreaker(
126+
circuitBreakerKey, circuitBreakerConfigProvider.getConfig(circuitBreakerKey));
127+
CircuitBreakerEventListener.attachListeners(circuitBreaker);
128+
return circuitBreaker;
129+
}
130+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
package org.hypertrace.circuitbreaker.grpcutils;
2+
3+
import io.micrometer.core.instrument.Counter;
4+
import io.micrometer.core.instrument.Meter;
5+
import io.micrometer.core.instrument.Tags;
6+
import io.micrometer.core.instrument.noop.NoopCounter;
7+
import java.util.Map;
8+
import java.util.concurrent.ConcurrentHashMap;
9+
import org.hypertrace.core.serviceframework.metrics.PlatformMetricsRegistry;
10+
11+
public class CircuitBreakerMetricsNotifier {
12+
private static final ConcurrentHashMap<String, Counter> counterMap = new ConcurrentHashMap<>();
13+
public static final String UNKNOWN_TENANT = "unknown";
14+
15+
public void incrementCount(String tenantId, String counterName) {
16+
getCounter(tenantId, counterName).increment();
17+
}
18+
19+
public Counter getCounter(String tenantId, String counterName) {
20+
if (tenantId == null || tenantId.equals(UNKNOWN_TENANT)) {
21+
return getNoopCounter();
22+
}
23+
return counterMap.computeIfAbsent(
24+
tenantId + counterName,
25+
(unused) ->
26+
PlatformMetricsRegistry.registerCounter(counterName, Map.of("tenantId", tenantId)));
27+
}
28+
29+
private NoopCounter getNoopCounter() {
30+
Meter.Id dummyId = new Meter.Id("noopCounter", Tags.empty(), null, null, Meter.Type.COUNTER);
31+
return new NoopCounter(dummyId);
32+
}
33+
}

0 commit comments

Comments
 (0)