diff --git a/build.gradle.kts b/build.gradle.kts index ffc733c..10adac6 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -7,7 +7,7 @@ plugins { id("org.hypertrace.publish-plugin") version "1.0.5" apply false id("org.hypertrace.jacoco-report-plugin") version "0.2.1" apply false id("org.hypertrace.code-style-plugin") version "2.0.0" apply false - id("org.owasp.dependencycheck") version "10.0.3" + id("org.owasp.dependencycheck") version "12.1.0" } subprojects { diff --git a/grpc-circuitbreaker-utils/build.gradle.kts b/grpc-circuitbreaker-utils/build.gradle.kts new file mode 100644 index 0000000..fadc1cf --- /dev/null +++ b/grpc-circuitbreaker-utils/build.gradle.kts @@ -0,0 +1,29 @@ +plugins { + `java-library` + jacoco + id("org.hypertrace.publish-plugin") + id("org.hypertrace.jacoco-report-plugin") +} + +dependencies { + + api(platform("io.grpc:grpc-bom:1.68.3")) + api("io.grpc:grpc-api") + api(project(":grpc-context-utils")) + + implementation("org.slf4j:slf4j-api:1.7.36") + implementation("io.github.resilience4j:resilience4j-circuitbreaker:1.7.1") + implementation("com.typesafe:config:1.4.2") + implementation("com.google.guava:guava:32.0.1-jre") + + annotationProcessor("org.projectlombok:lombok:1.18.24") + compileOnly("org.projectlombok:lombok:1.18.24") + + testImplementation("org.junit.jupiter:junit-jupiter:5.8.2") + testImplementation("org.mockito:mockito-core:5.8.0") + testImplementation("org.mockito:mockito-junit-jupiter:5.8.0") +} + +tasks.test { + useJUnitPlatform() +} diff --git a/grpc-circuitbreaker-utils/src/main/java/org/hypertrace/circuitbreaker/grpcutils/CircuitBreakerConfigParser.java b/grpc-circuitbreaker-utils/src/main/java/org/hypertrace/circuitbreaker/grpcutils/CircuitBreakerConfigParser.java new file mode 100644 index 0000000..dc4867f --- /dev/null +++ b/grpc-circuitbreaker-utils/src/main/java/org/hypertrace/circuitbreaker/grpcutils/CircuitBreakerConfigParser.java @@ -0,0 +1,111 @@ +package org.hypertrace.circuitbreaker.grpcutils; + +import com.typesafe.config.Config; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import lombok.extern.slf4j.Slf4j; + +@Slf4j +public class CircuitBreakerConfigParser { + + // Percentage of failures to trigger OPEN state + private static final String FAILURE_RATE_THRESHOLD = "failureRateThreshold"; + // Percentage of slow calls to trigger OPEN state + private static final String SLOW_CALL_RATE_THRESHOLD = "slowCallRateThreshold"; + // Define what a "slow" call is + private static final String SLOW_CALL_DURATION_THRESHOLD = "slowCallDurationThreshold"; + // Number of calls to consider in the sliding window + private static final String SLIDING_WINDOW_SIZE = "slidingWindowSize"; + // Time before retrying after OPEN state + private static final String WAIT_DURATION_IN_OPEN_STATE = "waitDurationInOpenState"; + // Minimum calls before evaluating failure rate + private static final String MINIMUM_NUMBER_OF_CALLS = "minimumNumberOfCalls"; + // Calls allowed in HALF_OPEN state before deciding to + // CLOSE or OPEN again + private static final String PERMITTED_NUMBER_OF_CALLS_IN_HALF_OPEN_STATE = + "permittedNumberOfCallsInHalfOpenState"; + private static final String SLIDING_WINDOW_TYPE = "slidingWindowType"; + public static final String ENABLED = "enabled"; + public static final String DEFAULT_THRESHOLDS = "defaultThresholds"; + private static final Set NON_THRESHOLD_KEYS = Set.of(ENABLED, DEFAULT_THRESHOLDS); + + public static CircuitBreakerConfiguration.CircuitBreakerConfigurationBuilder parseConfig( + Config config) { + CircuitBreakerConfiguration.CircuitBreakerConfigurationBuilder builder = + CircuitBreakerConfiguration.builder(); + if (config.hasPath(ENABLED)) { + builder.enabled(config.getBoolean(ENABLED)); + } + + Map circuitBreakerThresholdsMap = + config.root().keySet().stream() + .filter(key -> !NON_THRESHOLD_KEYS.contains(key)) // Filter out non-threshold keys + .collect( + Collectors.toMap( + key -> key, // Circuit breaker key + key -> buildCircuitBreakerThresholds(config.getConfig(key)))); + + builder.defaultThresholds( + config.hasPath(DEFAULT_THRESHOLDS) + ? buildCircuitBreakerThresholds(config.getConfig(DEFAULT_THRESHOLDS)) + : buildCircuitBreakerDefaultThresholds()); + + builder.circuitBreakerThresholdsMap(circuitBreakerThresholdsMap); + log.debug("Loaded circuit breaker configs: {}", builder); + return builder; + } + + private static CircuitBreakerThresholds buildCircuitBreakerThresholds(Config config) { + CircuitBreakerThresholds.CircuitBreakerThresholdsBuilder builder = + CircuitBreakerThresholds.builder(); + + if (config.hasPath(FAILURE_RATE_THRESHOLD)) { + builder.failureRateThreshold((float) config.getDouble(FAILURE_RATE_THRESHOLD)); + } + + if (config.hasPath(SLOW_CALL_RATE_THRESHOLD)) { + builder.slowCallRateThreshold((float) config.getDouble(SLOW_CALL_RATE_THRESHOLD)); + } + + if (config.hasPath(SLOW_CALL_DURATION_THRESHOLD)) { + builder.slowCallDurationThreshold(config.getDuration(SLOW_CALL_DURATION_THRESHOLD)); + } + + if (config.hasPath(SLIDING_WINDOW_TYPE)) { + builder.slidingWindowType(getSlidingWindowType(config.getString(SLIDING_WINDOW_TYPE))); + } + + if (config.hasPath(SLIDING_WINDOW_SIZE)) { + builder.slidingWindowSize(config.getInt(SLIDING_WINDOW_SIZE)); + } + + if (config.hasPath(WAIT_DURATION_IN_OPEN_STATE)) { + builder.waitDurationInOpenState(config.getDuration(WAIT_DURATION_IN_OPEN_STATE)); + } + + if (config.hasPath(PERMITTED_NUMBER_OF_CALLS_IN_HALF_OPEN_STATE)) { + builder.permittedNumberOfCallsInHalfOpenState( + config.getInt(PERMITTED_NUMBER_OF_CALLS_IN_HALF_OPEN_STATE)); + } + + if (config.hasPath(MINIMUM_NUMBER_OF_CALLS)) { + builder.minimumNumberOfCalls(config.getInt(MINIMUM_NUMBER_OF_CALLS)); + } + + if (config.hasPath(ENABLED)) { + builder.enabled(config.getBoolean(ENABLED)); + } + + return builder.build(); + } + + public static CircuitBreakerThresholds buildCircuitBreakerDefaultThresholds() { + return CircuitBreakerThresholds.builder().build(); + } + + private static CircuitBreakerThresholds.SlidingWindowType getSlidingWindowType( + String slidingWindowType) { + return CircuitBreakerThresholds.SlidingWindowType.valueOf(slidingWindowType); + } +} diff --git a/grpc-circuitbreaker-utils/src/main/java/org/hypertrace/circuitbreaker/grpcutils/CircuitBreakerConfiguration.java b/grpc-circuitbreaker-utils/src/main/java/org/hypertrace/circuitbreaker/grpcutils/CircuitBreakerConfiguration.java new file mode 100644 index 0000000..15866f9 --- /dev/null +++ b/grpc-circuitbreaker-utils/src/main/java/org/hypertrace/circuitbreaker/grpcutils/CircuitBreakerConfiguration.java @@ -0,0 +1,27 @@ +package org.hypertrace.circuitbreaker.grpcutils; + +import io.grpc.Status; +import io.grpc.StatusRuntimeException; +import java.util.Map; +import java.util.function.BiFunction; +import java.util.function.Function; +import lombok.Builder; +import lombok.Value; +import org.hypertrace.core.grpcutils.context.RequestContext; + +@Value +@Builder +public class CircuitBreakerConfiguration { + Class requestClass; + BiFunction keyFunction; + @Builder.Default boolean enabled = false; + // Standard/default thresholds + CircuitBreakerThresholds defaultThresholds; + // Custom overrides for specific cases (less common) + @Builder.Default Map circuitBreakerThresholdsMap = Map.of(); + + // New exception builder logic + @Builder.Default + Function exceptionBuilder = + reason -> Status.RESOURCE_EXHAUSTED.withDescription(reason).asRuntimeException(); +} diff --git a/grpc-circuitbreaker-utils/src/main/java/org/hypertrace/circuitbreaker/grpcutils/CircuitBreakerInterceptor.java b/grpc-circuitbreaker-utils/src/main/java/org/hypertrace/circuitbreaker/grpcutils/CircuitBreakerInterceptor.java new file mode 100644 index 0000000..b94cce8 --- /dev/null +++ b/grpc-circuitbreaker-utils/src/main/java/org/hypertrace/circuitbreaker/grpcutils/CircuitBreakerInterceptor.java @@ -0,0 +1,23 @@ +package org.hypertrace.circuitbreaker.grpcutils; + +import io.grpc.CallOptions; +import io.grpc.Channel; +import io.grpc.ClientCall; +import io.grpc.ClientInterceptor; +import io.grpc.MethodDescriptor; + +public abstract class CircuitBreakerInterceptor implements ClientInterceptor { + @Override + public ClientCall interceptCall( + MethodDescriptor method, CallOptions callOptions, Channel next) { + if (!isCircuitBreakerEnabled()) { + return next.newCall(method, callOptions); + } + return createInterceptedCall(method, callOptions, next); + } + + protected abstract boolean isCircuitBreakerEnabled(); + + protected abstract ClientCall createInterceptedCall( + MethodDescriptor method, CallOptions callOptions, Channel next); +} diff --git a/grpc-circuitbreaker-utils/src/main/java/org/hypertrace/circuitbreaker/grpcutils/CircuitBreakerThresholds.java b/grpc-circuitbreaker-utils/src/main/java/org/hypertrace/circuitbreaker/grpcutils/CircuitBreakerThresholds.java new file mode 100644 index 0000000..3b7f89e --- /dev/null +++ b/grpc-circuitbreaker-utils/src/main/java/org/hypertrace/circuitbreaker/grpcutils/CircuitBreakerThresholds.java @@ -0,0 +1,32 @@ +package org.hypertrace.circuitbreaker.grpcutils; + +import java.time.Duration; +import lombok.Builder; +import lombok.Value; + +@Value +@Builder +public class CircuitBreakerThresholds { + // Percentage of failures to trigger OPEN state + @Builder.Default float failureRateThreshold = 50f; + // Percentage of slow calls to trigger OPEN state + @Builder.Default float slowCallRateThreshold = 50f; + // Define what a "slow" call is + @Builder.Default Duration slowCallDurationThreshold = Duration.ofSeconds(2); + // Number of calls to consider in the sliding window + @Builder.Default SlidingWindowType slidingWindowType = SlidingWindowType.TIME_BASED; + @Builder.Default int slidingWindowSize = 60; + // Time before retrying after OPEN state + @Builder.Default Duration waitDurationInOpenState = Duration.ofSeconds(60); + // Minimum calls before evaluating failure rate + @Builder.Default int minimumNumberOfCalls = 10; + // Calls allowed in HALF_OPEN state before deciding to + // CLOSE or OPEN again + @Builder.Default int permittedNumberOfCallsInHalfOpenState = 5; + @Builder.Default boolean enabled = true; + + public enum SlidingWindowType { + COUNT_BASED, + TIME_BASED + } +} diff --git a/grpc-circuitbreaker-utils/src/main/java/org/hypertrace/circuitbreaker/grpcutils/resilience/ResilienceCircuitBreakerConfigConverter.java b/grpc-circuitbreaker-utils/src/main/java/org/hypertrace/circuitbreaker/grpcutils/resilience/ResilienceCircuitBreakerConfigConverter.java new file mode 100644 index 0000000..28cfd0c --- /dev/null +++ b/grpc-circuitbreaker-utils/src/main/java/org/hypertrace/circuitbreaker/grpcutils/resilience/ResilienceCircuitBreakerConfigConverter.java @@ -0,0 +1,50 @@ +package org.hypertrace.circuitbreaker.grpcutils.resilience; + +import io.github.resilience4j.circuitbreaker.CircuitBreakerConfig; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import org.hypertrace.circuitbreaker.grpcutils.CircuitBreakerThresholds; + +/** Utility class to parse CircuitBreakerConfiguration to Resilience4j CircuitBreakerConfig */ +class ResilienceCircuitBreakerConfigConverter { + + public static Map getCircuitBreakerConfigs( + Map configurationMap) { + return configurationMap.entrySet().stream() + .collect(Collectors.toMap(Map.Entry::getKey, entry -> convertConfig(entry.getValue()))); + } + + public static List getDisabledKeys( + Map configurationMap) { + return configurationMap.entrySet().stream() + .filter(entry -> entry.getValue().isEnabled()) + .map(Map.Entry::getKey) + .collect(Collectors.toList()); + } + + static CircuitBreakerConfig convertConfig(CircuitBreakerThresholds configuration) { + return CircuitBreakerConfig.custom() + .failureRateThreshold(configuration.getFailureRateThreshold()) + .slowCallRateThreshold(configuration.getSlowCallRateThreshold()) + .slowCallDurationThreshold(configuration.getSlowCallDurationThreshold()) + .slidingWindowType(getSlidingWindowType(configuration.getSlidingWindowType())) + .slidingWindowSize(configuration.getSlidingWindowSize()) + .waitDurationInOpenState(configuration.getWaitDurationInOpenState()) + .permittedNumberOfCallsInHalfOpenState( + configuration.getPermittedNumberOfCallsInHalfOpenState()) + .minimumNumberOfCalls(configuration.getMinimumNumberOfCalls()) + .build(); + } + + private static CircuitBreakerConfig.SlidingWindowType getSlidingWindowType( + CircuitBreakerThresholds.SlidingWindowType slidingWindowType) { + switch (slidingWindowType) { + case COUNT_BASED: + return CircuitBreakerConfig.SlidingWindowType.COUNT_BASED; + case TIME_BASED: + default: + return CircuitBreakerConfig.SlidingWindowType.TIME_BASED; + } + } +} diff --git a/grpc-circuitbreaker-utils/src/main/java/org/hypertrace/circuitbreaker/grpcutils/resilience/ResilienceCircuitBreakerFactory.java b/grpc-circuitbreaker-utils/src/main/java/org/hypertrace/circuitbreaker/grpcutils/resilience/ResilienceCircuitBreakerFactory.java new file mode 100644 index 0000000..fb47045 --- /dev/null +++ b/grpc-circuitbreaker-utils/src/main/java/org/hypertrace/circuitbreaker/grpcutils/resilience/ResilienceCircuitBreakerFactory.java @@ -0,0 +1,29 @@ +package org.hypertrace.circuitbreaker.grpcutils.resilience; + +import io.github.resilience4j.circuitbreaker.CircuitBreakerConfig; +import io.github.resilience4j.circuitbreaker.CircuitBreakerRegistry; +import java.time.Clock; +import java.util.Map; +import org.hypertrace.circuitbreaker.grpcutils.CircuitBreakerConfiguration; + +public class ResilienceCircuitBreakerFactory { + public static ResilienceCircuitBreakerInterceptor getResilienceCircuitBreakerInterceptor( + CircuitBreakerConfiguration circuitBreakerConfiguration, Clock clock) { + Map resilienceCircuitBreakerConfigMap = + ResilienceCircuitBreakerConfigConverter.getCircuitBreakerConfigs( + circuitBreakerConfiguration.getCircuitBreakerThresholdsMap()); + CircuitBreakerRegistry resilienceCircuitBreakerRegistry = + new ResilienceCircuitBreakerRegistryProvider( + circuitBreakerConfiguration.getDefaultThresholds()) + .getCircuitBreakerRegistry(); + ResilienceCircuitBreakerProvider resilienceCircuitBreakerProvider = + new ResilienceCircuitBreakerProvider( + resilienceCircuitBreakerRegistry, + resilienceCircuitBreakerConfigMap, + ResilienceCircuitBreakerConfigConverter.getDisabledKeys( + circuitBreakerConfiguration.getCircuitBreakerThresholdsMap()), + circuitBreakerConfiguration.getDefaultThresholds().isEnabled()); + return new ResilienceCircuitBreakerInterceptor( + circuitBreakerConfiguration, clock, resilienceCircuitBreakerProvider); + } +} diff --git a/grpc-circuitbreaker-utils/src/main/java/org/hypertrace/circuitbreaker/grpcutils/resilience/ResilienceCircuitBreakerInterceptor.java b/grpc-circuitbreaker-utils/src/main/java/org/hypertrace/circuitbreaker/grpcutils/resilience/ResilienceCircuitBreakerInterceptor.java new file mode 100644 index 0000000..64f4082 --- /dev/null +++ b/grpc-circuitbreaker-utils/src/main/java/org/hypertrace/circuitbreaker/grpcutils/resilience/ResilienceCircuitBreakerInterceptor.java @@ -0,0 +1,134 @@ +package org.hypertrace.circuitbreaker.grpcutils.resilience; + +import io.github.resilience4j.circuitbreaker.CircuitBreaker; +import io.grpc.CallOptions; +import io.grpc.Channel; +import io.grpc.ClientCall; +import io.grpc.ForwardingClientCall; +import io.grpc.ForwardingClientCallListener; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.Status; +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.TimeUnit; +import lombok.extern.slf4j.Slf4j; +import org.hypertrace.circuitbreaker.grpcutils.CircuitBreakerConfiguration; +import org.hypertrace.circuitbreaker.grpcutils.CircuitBreakerInterceptor; +import org.hypertrace.core.grpcutils.context.RequestContext; + +@Slf4j +public class ResilienceCircuitBreakerInterceptor extends CircuitBreakerInterceptor { + + private final ResilienceCircuitBreakerProvider resilienceCircuitBreakerProvider; + private final CircuitBreakerConfiguration circuitBreakerConfiguration; + private final Clock clock; + + ResilienceCircuitBreakerInterceptor( + CircuitBreakerConfiguration circuitBreakerConfiguration, + Clock clock, + ResilienceCircuitBreakerProvider resilienceCircuitBreakerProvider) { + this.circuitBreakerConfiguration = circuitBreakerConfiguration; + this.clock = clock; + this.resilienceCircuitBreakerProvider = resilienceCircuitBreakerProvider; + } + + @Override + protected boolean isCircuitBreakerEnabled() { + return circuitBreakerConfiguration.isEnabled(); + } + + @Override + protected ClientCall createInterceptedCall( + MethodDescriptor method, CallOptions callOptions, Channel next) { + return new ForwardingClientCall.SimpleForwardingClientCall<>( + next.newCall(method, callOptions)) { + Optional optionalCircuitBreaker; + + @Override + public void start(Listener responseListener, Metadata headers) { + Instant startTime = clock.instant(); + // Wrap response listener to track failures + Listener wrappedListener = + wrapListenerWithCircuitBreaker(responseListener, startTime); + super.start(wrappedListener, headers); + } + + @SuppressWarnings("unchecked") + @Override + public void sendMessage(ReqT message) { + CircuitBreakerConfiguration config = + (CircuitBreakerConfiguration) circuitBreakerConfiguration; + // Type check for message class compatibility + if (config.getRequestClass() != null && !config.getRequestClass().isInstance(message)) { + super.sendMessage(message); + return; + } + String circuitBreakerKey = null; + if (config.getKeyFunction() != null) { + circuitBreakerKey = config.getKeyFunction().apply(RequestContext.CURRENT.get(), message); + } + optionalCircuitBreaker = + circuitBreakerKey != null + ? resilienceCircuitBreakerProvider.getCircuitBreaker(circuitBreakerKey) + : resilienceCircuitBreakerProvider.getSharedCircuitBreaker(); + + CircuitBreaker circuitBreaker = optionalCircuitBreaker.orElse(null); + if (circuitBreaker == null) { + super.sendMessage(message); + return; + } + if (!circuitBreaker.tryAcquirePermission()) { + logCircuitBreakerRejection(circuitBreakerKey, circuitBreaker); + String rejectionReason = + circuitBreaker.getState() == CircuitBreaker.State.HALF_OPEN + ? "Circuit Breaker is HALF-OPEN and rejecting excess requests" + : "Circuit Breaker is OPEN and blocking requests"; + throw config.getExceptionBuilder().apply(rejectionReason); + } + super.sendMessage(message); + } + + private ForwardingClientCallListener.SimpleForwardingClientCallListener + wrapListenerWithCircuitBreaker(Listener responseListener, Instant startTime) { + return new ForwardingClientCallListener.SimpleForwardingClientCallListener<>( + responseListener) { + @Override + public void onClose(Status status, Metadata trailers) { + long duration = Duration.between(startTime, clock.instant()).toNanos(); + CircuitBreaker circuitBreaker = optionalCircuitBreaker.orElse(null); + if (circuitBreaker == null) { + super.onClose(status, trailers); + return; + } + if (status.isOk()) { + circuitBreaker.onSuccess(duration, TimeUnit.NANOSECONDS); + } else { + log.debug( + "Circuit Breaker '{}' detected failure. Status: {}, Description: {}", + circuitBreaker.getName(), + status.getCode(), + status.getDescription()); + circuitBreaker.onError(duration, TimeUnit.NANOSECONDS, status.asRuntimeException()); + } + super.onClose(status, trailers); + } + }; + } + }; + } + + private void logCircuitBreakerRejection(String circuitBreakerKey, CircuitBreaker circuitBreaker) { + Map stateMessages = + Map.of( + CircuitBreaker.State.HALF_OPEN, "is HALF-OPEN and rejecting excess requests.", + CircuitBreaker.State.OPEN, "is OPEN and blocking requests"); + log.debug( + "Circuit Breaker '{}' {}", + circuitBreakerKey, + stateMessages.getOrDefault(circuitBreaker.getState(), "is in an unexpected state")); + } +} diff --git a/grpc-circuitbreaker-utils/src/main/java/org/hypertrace/circuitbreaker/grpcutils/resilience/ResilienceCircuitBreakerProvider.java b/grpc-circuitbreaker-utils/src/main/java/org/hypertrace/circuitbreaker/grpcutils/resilience/ResilienceCircuitBreakerProvider.java new file mode 100644 index 0000000..897a3de --- /dev/null +++ b/grpc-circuitbreaker-utils/src/main/java/org/hypertrace/circuitbreaker/grpcutils/resilience/ResilienceCircuitBreakerProvider.java @@ -0,0 +1,104 @@ +package org.hypertrace.circuitbreaker.grpcutils.resilience; + +import com.google.common.cache.CacheBuilder; +import com.google.common.cache.CacheLoader; +import com.google.common.cache.LoadingCache; +import io.github.resilience4j.circuitbreaker.CircuitBreaker; +import io.github.resilience4j.circuitbreaker.CircuitBreakerConfig; +import io.github.resilience4j.circuitbreaker.CircuitBreakerRegistry; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.TimeUnit; +import lombok.extern.slf4j.Slf4j; + +/** Utility class to provide Resilience4j CircuitBreaker */ +@Slf4j +class ResilienceCircuitBreakerProvider { + + private static final String SHARED_KEY = "SHARED_KEY"; + private final CircuitBreakerRegistry circuitBreakerRegistry; + private final Map circuitBreakerConfigMap; + private final List disabledKeys; + private final boolean defaultEnabled; + + // LoadingCache to manage CircuitBreaker instances with automatic loading and eviction + private final LoadingCache> circuitBreakerCache = + CacheBuilder.newBuilder() + .expireAfterAccess(60, TimeUnit.MINUTES) // Auto-evict after 60 minutes + .maximumSize(10000) // Limit max cache size + .build( + new CacheLoader<>() { + @Override + public Optional load(String key) { + return buildNewCircuitBreaker(key); + } + }); + + public ResilienceCircuitBreakerProvider( + CircuitBreakerRegistry circuitBreakerRegistry, + Map circuitBreakerConfigMap, + List disabledKeys, + boolean defaultEnabled) { + this.circuitBreakerRegistry = circuitBreakerRegistry; + this.circuitBreakerConfigMap = circuitBreakerConfigMap; + this.disabledKeys = disabledKeys; + this.defaultEnabled = defaultEnabled; + } + + public Optional getCircuitBreaker(String circuitBreakerKey) { + if (disabledKeys.contains(circuitBreakerKey)) { + return Optional.empty(); + } + return circuitBreakerCache.getUnchecked(circuitBreakerKey); + } + + public Optional getSharedCircuitBreaker() { + return defaultEnabled ? getCircuitBreaker(SHARED_KEY) : Optional.empty(); + } + + private static void attachListeners(CircuitBreaker circuitBreaker) { + circuitBreaker + .getEventPublisher() + .onStateTransition( + event -> + log.info( + "State transition: {} for circuit breaker {}", + event.getStateTransition(), + event.getCircuitBreakerName())) + .onCallNotPermitted( + event -> + log.debug( + "Call not permitted: Circuit is OPEN for circuit breaker {}", + event.getCircuitBreakerName())) + .onEvent( + event -> + log.debug( + "Circuit breaker event type {} for circuit breaker name {}", + event.getEventType(), + event.getCircuitBreakerName())); + } + + private Optional buildNewCircuitBreaker(String circuitBreakerKey) { + return Optional.ofNullable(circuitBreakerConfigMap.get(circuitBreakerKey)) + .map( + config -> { + CircuitBreaker circuitBreaker = + circuitBreakerRegistry.circuitBreaker(circuitBreakerKey, config); + attachListeners(circuitBreaker); // Attach listeners here + return circuitBreaker; + }) + .or( + () -> { + if (defaultEnabled) { + CircuitBreaker circuitBreaker = + circuitBreakerRegistry.circuitBreaker(circuitBreakerKey); + attachListeners( + circuitBreaker); // Attach listeners here for default circuit breaker + return Optional.of(circuitBreaker); + } else { + return Optional.empty(); + } + }); + } +} diff --git a/grpc-circuitbreaker-utils/src/main/java/org/hypertrace/circuitbreaker/grpcutils/resilience/ResilienceCircuitBreakerRegistryProvider.java b/grpc-circuitbreaker-utils/src/main/java/org/hypertrace/circuitbreaker/grpcutils/resilience/ResilienceCircuitBreakerRegistryProvider.java new file mode 100644 index 0000000..329c909 --- /dev/null +++ b/grpc-circuitbreaker-utils/src/main/java/org/hypertrace/circuitbreaker/grpcutils/resilience/ResilienceCircuitBreakerRegistryProvider.java @@ -0,0 +1,36 @@ +package org.hypertrace.circuitbreaker.grpcutils.resilience; + +import io.github.resilience4j.circuitbreaker.CircuitBreaker; +import io.github.resilience4j.circuitbreaker.CircuitBreakerRegistry; +import lombok.extern.slf4j.Slf4j; +import org.hypertrace.circuitbreaker.grpcutils.CircuitBreakerThresholds; + +/** Utility class to provide Resilience4j CircuitBreakerRegistry */ +@Slf4j +class ResilienceCircuitBreakerRegistryProvider { + private final CircuitBreakerThresholds circuitBreakerThresholds; + + public ResilienceCircuitBreakerRegistryProvider( + CircuitBreakerThresholds circuitBreakerThresholds) { + this.circuitBreakerThresholds = circuitBreakerThresholds; + } + + public CircuitBreakerRegistry getCircuitBreakerRegistry() { + CircuitBreakerRegistry circuitBreakerRegistry = + CircuitBreakerRegistry.of( + ResilienceCircuitBreakerConfigConverter.convertConfig(circuitBreakerThresholds)); + circuitBreakerRegistry + .getEventPublisher() + .onEntryAdded( + entryAddedEvent -> { + CircuitBreaker addedCircuitBreaker = entryAddedEvent.getAddedEntry(); + log.debug("CircuitBreaker {} added", addedCircuitBreaker.getName()); + }) + .onEntryRemoved( + entryRemovedEvent -> { + CircuitBreaker removedCircuitBreaker = entryRemovedEvent.getRemovedEntry(); + log.debug("CircuitBreaker {} removed", removedCircuitBreaker.getName()); + }); + return circuitBreakerRegistry; + } +} diff --git a/grpc-circuitbreaker-utils/src/test/java/org/hypertrace/circuitbreaker/grpcutils/resilience/ResilienceCircuitBreakerConfigConverterTest.java b/grpc-circuitbreaker-utils/src/test/java/org/hypertrace/circuitbreaker/grpcutils/resilience/ResilienceCircuitBreakerConfigConverterTest.java new file mode 100644 index 0000000..b49f3b3 --- /dev/null +++ b/grpc-circuitbreaker-utils/src/test/java/org/hypertrace/circuitbreaker/grpcutils/resilience/ResilienceCircuitBreakerConfigConverterTest.java @@ -0,0 +1,54 @@ +package org.hypertrace.circuitbreaker.grpcutils.resilience; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import io.github.resilience4j.circuitbreaker.CircuitBreakerConfig; +import java.time.Duration; +import java.util.HashMap; +import java.util.Map; +import org.hypertrace.circuitbreaker.grpcutils.CircuitBreakerThresholds; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +public class ResilienceCircuitBreakerConfigConverterTest { + + @Test + void shouldParseValidConfiguration() { + CircuitBreakerThresholds thresholds = + CircuitBreakerThresholds.builder() + .failureRateThreshold(50.0f) + .slowCallRateThreshold(30.0f) + .slowCallDurationThreshold(Duration.ofSeconds(2)) + .slidingWindowType(CircuitBreakerThresholds.SlidingWindowType.TIME_BASED) + .slidingWindowSize(100) + .waitDurationInOpenState(Duration.ofSeconds(60)) + .permittedNumberOfCallsInHalfOpenState(5) + .minimumNumberOfCalls(20) + .build(); + + Map configMap = new HashMap<>(); + configMap.put("testService", thresholds); + + Map result = + ResilienceCircuitBreakerConfigConverter.getCircuitBreakerConfigs(configMap); + + Assertions.assertTrue(result.containsKey("testService")); + + CircuitBreakerConfig config = result.get("testService"); + assertEquals(50.0f, config.getFailureRateThreshold()); + assertEquals(30.0f, config.getSlowCallRateThreshold()); + assertEquals(Duration.ofSeconds(2), config.getSlowCallDurationThreshold()); + assertEquals(CircuitBreakerConfig.SlidingWindowType.TIME_BASED, config.getSlidingWindowType()); + assertEquals(100, config.getSlidingWindowSize()); + assertEquals(5, config.getPermittedNumberOfCallsInHalfOpenState()); + assertEquals(20, config.getMinimumNumberOfCalls()); + } + + @Test + void shouldThrowExceptionWhenConfigurationIsNull() { + assertThrows( + NullPointerException.class, + () -> ResilienceCircuitBreakerConfigConverter.convertConfig(null)); + } +} diff --git a/grpc-circuitbreaker-utils/src/test/java/org/hypertrace/circuitbreaker/grpcutils/resilience/ResilienceCircuitBreakerInterceptorTest.java b/grpc-circuitbreaker-utils/src/test/java/org/hypertrace/circuitbreaker/grpcutils/resilience/ResilienceCircuitBreakerInterceptorTest.java new file mode 100644 index 0000000..c136b95 --- /dev/null +++ b/grpc-circuitbreaker-utils/src/test/java/org/hypertrace/circuitbreaker/grpcutils/resilience/ResilienceCircuitBreakerInterceptorTest.java @@ -0,0 +1,172 @@ +package org.hypertrace.circuitbreaker.grpcutils.resilience; + +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; + +import io.github.resilience4j.circuitbreaker.CircuitBreaker; +import io.github.resilience4j.circuitbreaker.CircuitBreakerRegistry; +import io.grpc.*; +import java.time.Clock; +import java.time.Instant; +import java.time.ZoneOffset; +import java.util.Optional; +import java.util.concurrent.TimeUnit; +import org.hypertrace.circuitbreaker.grpcutils.CircuitBreakerConfiguration; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.*; +import org.mockito.junit.jupiter.MockitoExtension; + +@ExtendWith(MockitoExtension.class) +class ResilienceCircuitBreakerInterceptorTest { + + @Mock private Channel mockChannel; + @Mock private ClientCall mockClientCall; + @Mock private CircuitBreaker mockCircuitBreaker; + @Mock private Metadata mockMetadata; + @Mock private ClientCall.Listener mockListener; + @Mock private ResilienceCircuitBreakerRegistryProvider mockCircuitBreakerRegistryProvider; + @Mock private ResilienceCircuitBreakerProvider mockCircuitBreakerProvider; + @Mock private CircuitBreakerConfiguration mockCircuitBreakerConfig; + @Mock private CircuitBreakerRegistry mockCircuitBreakerRegistry; + + @Mock private Clock fixedClock; + + @BeforeEach + void setUp() { + MockitoAnnotations.openMocks(this); + + fixedClock = Clock.fixed(Instant.now(), ZoneOffset.UTC); + when(mockChannel.newCall(any(), any())).thenReturn(mockClientCall); + } + + @Test + void testSendMessage_CallsSuperSendMessage_Success() { + doNothing().when(mockClientCall).sendMessage(any()); + + ResilienceCircuitBreakerInterceptor interceptor = + new ResilienceCircuitBreakerInterceptor( + mockCircuitBreakerConfig, fixedClock, mockCircuitBreakerProvider); + + ClientCall interceptedCall = + interceptor.createInterceptedCall( + mock(MethodDescriptor.class), CallOptions.DEFAULT, mockChannel); + + interceptedCall.start(mockListener, mockMetadata); + interceptedCall.sendMessage(new Object()); + + verify(mockClientCall).sendMessage(any()); + } + + @Test + void testSendMessage_CircuitBreakerRejectsRequest() { + when(mockCircuitBreaker.tryAcquirePermission()).thenReturn(false); + when(mockCircuitBreaker.getState()).thenReturn(CircuitBreaker.State.OPEN); + when(mockCircuitBreakerProvider.getSharedCircuitBreaker()) + .thenReturn(Optional.of(mockCircuitBreaker)); + when(mockCircuitBreakerConfig.getExceptionBuilder()) + .thenReturn( + reason -> + new StatusRuntimeException( + Status.RESOURCE_EXHAUSTED.withDescription(reason), mock(Metadata.class))); + ResilienceCircuitBreakerInterceptor interceptor = + new ResilienceCircuitBreakerInterceptor( + mockCircuitBreakerConfig, fixedClock, mockCircuitBreakerProvider); + + ClientCall interceptedCall = + interceptor.createInterceptedCall( + mock(MethodDescriptor.class), CallOptions.DEFAULT, mockChannel); + + interceptedCall.start(mockListener, mockMetadata); + + assertThrows( + StatusRuntimeException.class, + () -> interceptedCall.sendMessage(new Object()), + "Circuit Breaker should reject request"); + + verify(mockClientCall, never()).sendMessage(any()); + } + + @Test + void testSendMessage_CircuitBreakerInHalfOpenState() { + when(mockCircuitBreaker.tryAcquirePermission()).thenReturn(false); + when(mockCircuitBreaker.getState()).thenReturn(CircuitBreaker.State.HALF_OPEN); + when(mockCircuitBreakerProvider.getSharedCircuitBreaker()) + .thenReturn(Optional.of(mockCircuitBreaker)); + when(mockCircuitBreakerConfig.getExceptionBuilder()) + .thenReturn( + reason -> + new StatusRuntimeException( + Status.RESOURCE_EXHAUSTED.withDescription(reason), mock(Metadata.class))); + ResilienceCircuitBreakerInterceptor interceptor = + new ResilienceCircuitBreakerInterceptor( + mockCircuitBreakerConfig, fixedClock, mockCircuitBreakerProvider); + + ClientCall interceptedCall = + interceptor.createInterceptedCall( + mock(MethodDescriptor.class), CallOptions.DEFAULT, mockChannel); + + interceptedCall.start(mockListener, mockMetadata); + + assertThrows( + StatusRuntimeException.class, + () -> interceptedCall.sendMessage(new Object()), + "Circuit Breaker should reject requests when in HALF-OPEN state"); + + verify(mockClientCall, never()).sendMessage(any()); + } + + @Test + void testWrapListenerWithCircuitBreaker_Success() { + when(mockCircuitBreaker.tryAcquirePermission()).thenReturn(true); + when(mockCircuitBreakerProvider.getSharedCircuitBreaker()) + .thenReturn(Optional.of(mockCircuitBreaker)); + ResilienceCircuitBreakerInterceptor interceptor = + new ResilienceCircuitBreakerInterceptor( + mockCircuitBreakerConfig, fixedClock, mockCircuitBreakerProvider); + + ClientCall interceptedCall = + interceptor.createInterceptedCall( + mock(MethodDescriptor.class), CallOptions.DEFAULT, mockChannel); + + interceptedCall.start(mockListener, mockMetadata); + interceptedCall.sendMessage(new Object()); + + // Trigger `onClose` directly to mimic gRPC's flow + ArgumentCaptor> listenerCaptor = + ArgumentCaptor.forClass(ForwardingClientCallListener.class); + verify(mockClientCall).start(listenerCaptor.capture(), any()); + listenerCaptor.getValue().onClose(Status.OK, mockMetadata); + + verify(mockClientCall).sendMessage(any()); + verify(mockCircuitBreaker).onSuccess(anyLong(), eq(TimeUnit.NANOSECONDS)); + } + + @Test + void testWrapListenerWithCircuitBreaker_Failure() { + when(mockCircuitBreaker.tryAcquirePermission()).thenReturn(true); + when(mockCircuitBreakerProvider.getSharedCircuitBreaker()) + .thenReturn(Optional.of(mockCircuitBreaker)); + ResilienceCircuitBreakerInterceptor interceptor = + new ResilienceCircuitBreakerInterceptor( + mockCircuitBreakerConfig, fixedClock, mockCircuitBreakerProvider); + + ClientCall interceptedCall = + interceptor.createInterceptedCall( + mock(MethodDescriptor.class), CallOptions.DEFAULT, mockChannel); + + interceptedCall.start(mockListener, mockMetadata); + interceptedCall.sendMessage(new Object()); + + // Trigger `onClose` directly to mimic gRPC's flow + ArgumentCaptor> listenerCaptor = + ArgumentCaptor.forClass(ForwardingClientCallListener.class); + verify(mockClientCall).start(listenerCaptor.capture(), any()); + listenerCaptor.getValue().onClose(Status.UNKNOWN, mockMetadata); + + verify(mockClientCall).sendMessage(any()); + verify(mockCircuitBreaker).onError(anyLong(), eq(TimeUnit.NANOSECONDS), any()); + } +} diff --git a/settings.gradle.kts b/settings.gradle.kts index c13b603..94ab61c 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -18,3 +18,4 @@ include(":grpc-server-rx-utils") include(":grpc-context-utils") include(":grpc-server-utils") include(":grpc-validation-utils") +include(":grpc-circuitbreaker-utils")