diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..8696336 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,28 @@ +# Repository Guidelines + +## Project Structure & Module Organization +- Core libraries live in module directories such as `core`, `core-reactor`, and `core-kotlin-coroutine`; each follows the Gradle layout `src/main` and `src/test`. +- Spring adapters sit under `core-spring*` modules, while runnable samples are in `req-shield-*example` projects. +- Shared utilities and constants are centralized in `support`. Generated build outputs stay under each module's `build/` folder. + +## Build, Test, and Development Commands +- `./gradlew build` compiles all modules, runs unit tests, and assembles artifacts; pass `--parallel` for faster local feedback. +- `./gradlew test` executes Kotlin/JVM unit tests across every enabled module. +- `./gradlew ktlintCheck` enforces the project's formatting contract before you open a PR. +- Use `./gradlew :req-shield-spring-boot3-example:bootRun` (or another sample module) to manually exercise integration paths. + +## Coding Style & Naming Conventions +- Kotlin sources use 4-space indentation, `UpperCamelCase` for types, and `lowerCamelCase` for functions and properties. +- Keep package names lowercase and aligned with module boundaries (e.g., `com.linecorp.reqshield.core`). +- Always add the Apache 2.0 copyright header shown in `CONTRIBUTING.md` to new files. +- Prefer early-return patterns and meaningful exception messages; align with the `ErrorCode` enums already defined. + +## Testing Guidelines +- Write tests with JUnit 5 (`org.junit.jupiter`) and place them under `src/test/kotlin` mirroring the `src/main` package. +- Use descriptive method names such as `shouldCollapseConcurrentRequests()` and cover both success and failure paths. +- When adding integration behaviour, extend the corresponding example module and run `./gradlew test` before submission. + +## Commit & Pull Request Guidelines +- Follow the repository's history of concise, imperative commits (e.g., `Add cache invalidation helper`). +- Reference related issues in the body, summarise motivation, modifications, and results, and include screenshots/logs for behaviour changes. +- Verify CLS (Contributor License Agreement) status, ensure CI passes locally, and request review from a maintainer familiar with your module. diff --git a/README.md b/README.md index 05506bb..8b6774a 100644 --- a/README.md +++ b/README.md @@ -23,6 +23,35 @@ A lib that regulates the cache-based requests an application receives in terms o `implementation("com.linecorp.cse.reqshield:core-spring-webflux:{version}")`
`implementation("com.linecorp.cse.reqshield:core-spring-webflux-kotlin-coroutine:{version}")`
+## Testing & Integration Tips + +### Integration tests with Redis (Testcontainers) + +- Some modules provide optional Redis-backed integration tests leveraging Testcontainers. +- Enable them via an environment variable to avoid Docker dependency by default: + - macOS/Linux: `RUN_REDIS_IT=true ./gradlew :core-spring-webflux:test :core-spring-webflux-kotlin-coroutine:test` + - Windows (PowerShell): `$env:RUN_REDIS_IT='true'; ./gradlew :core-spring-webflux:test :core-spring-webflux-kotlin-coroutine:test` +- If Docker is not available, these tests stay skipped. In-memory integration tests still validate core semantics (request collapsing and eviction). + +### WebFlux null handling + +- `@ReqShieldCacheable(nullHandling = ...)` controls how `null` values are emitted in WebFlux: + - `EMIT_EMPTY` (default): map `null` to `Mono.empty()`. + - `ERROR`: throw an `IllegalStateException` if a `null` value is produced. + +### Global lock guidance + +- When `isLocalLock = false`, you must provide real global lock/unlock implementations. +- Recommended approach with Redis: + - Lock: `SETNX lock:{key} 1` + `PEXPIRE lock:{key} {ttlMillis}` + - Unlock: `DEL lock:{key}` +- The provided defaults return `true` and are only suitable for local/dev usage. + +### Reactor Scheduler tuning + +- Reactor-based modules accept a `Scheduler` (e.g., `boundedElastic`) through configuration. +- Spring WebFlux adapter exposes a `reqShieldScheduler` bean you can override for tuning thread usage. + ## Contributing Pull requests are welcome. For major changes, please open an issue first to discuss what you would like to diff --git a/core-kotlin-coroutine/src/main/kotlin/com/linecorp/cse/reqshield/kotlin/coroutine/KeyLocalLock.kt b/core-kotlin-coroutine/src/main/kotlin/com/linecorp/cse/reqshield/kotlin/coroutine/KeyLocalLock.kt index cfc8472..ef36e65 100644 --- a/core-kotlin-coroutine/src/main/kotlin/com/linecorp/cse/reqshield/kotlin/coroutine/KeyLocalLock.kt +++ b/core-kotlin-coroutine/src/main/kotlin/com/linecorp/cse/reqshield/kotlin/coroutine/KeyLocalLock.kt @@ -32,26 +32,40 @@ import kotlin.coroutines.CoroutineContext private val log = LoggerFactory.getLogger(KeyLocalLock::class.java) class KeyLocalLock(private val lockTimeoutMillis: Long) : KeyLock, CoroutineScope { - private data class LockInfo(val semaphore: Semaphore, val createdAt: Long) + private data class LockInfo(val semaphore: Semaphore, val expiresAt: Long) - private val lockMap = ConcurrentHashMap() + companion object { + private val lockMap = ConcurrentHashMap() + + @Volatile + private var monitorJob: Job? = null + + private fun ensureMonitorStarted() { + if (monitorJob?.isActive == true) return + synchronized(this) { + if (monitorJob?.isActive == true) return + monitorJob = + CoroutineScope(Dispatchers.IO).launch { + while (isActive) { + runCatching { + val now = System.currentTimeMillis() + lockMap.entries.removeIf { now > it.value.expiresAt } + delay(LOCK_MONITOR_INTERVAL_MILLIS) + }.onFailure { e -> + log.error("Error in lock lifecycle monitoring : {}", e.message) + } + } + } + } + } + } private val job = Job() override val coroutineContext: CoroutineContext get() = Dispatchers.IO + job init { - launch { - while (isActive) { - runCatching { - val now = System.currentTimeMillis() - lockMap.entries.removeIf { now - it.value.createdAt > lockTimeoutMillis } // 특정 시간이 지나면 lock 여부와 상관없이 map에서 삭제한다. - delay(LOCK_MONITOR_INTERVAL_MILLIS) - }.onFailure { e -> - log.error("Error in lock lifecycle monitoring : {}", e.message) - } - } - } + ensureMonitorStarted() } override suspend fun tryLock( @@ -59,8 +73,8 @@ class KeyLocalLock(private val lockTimeoutMillis: Long) : KeyLock, CoroutineScop lockType: LockType, ): Boolean { val completeKey = "${key}_${lockType.name}" - val lockInfo = lockMap.computeIfAbsent(completeKey) { LockInfo(Semaphore(1), nowToEpochTime()) } - + val now = nowToEpochTime() + val lockInfo = lockMap.computeIfAbsent(completeKey) { LockInfo(Semaphore(1), now + lockTimeoutMillis) } return lockInfo.semaphore.tryAcquire() } diff --git a/core-kotlin-coroutine/src/test/kotlin/com/linecorp/cse/reqshield/kotlin/coroutine/KeyGlobalLockTest.kt b/core-kotlin-coroutine/src/test/kotlin/com/linecorp/cse/reqshield/kotlin/coroutine/KeyGlobalLockTest.kt index e55048a..7754b4a 100644 --- a/core-kotlin-coroutine/src/test/kotlin/com/linecorp/cse/reqshield/kotlin/coroutine/KeyGlobalLockTest.kt +++ b/core-kotlin-coroutine/src/test/kotlin/com/linecorp/cse/reqshield/kotlin/coroutine/KeyGlobalLockTest.kt @@ -31,9 +31,11 @@ import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.Assertions.assertTrue import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.Test +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable import java.util.concurrent.atomic.AtomicInteger import kotlin.test.Ignore +@EnabledIfEnvironmentVariable(named = "RUN_REDIS_IT", matches = "true") class KeyGlobalLockTest : AbstractRedisTest(), BaseKeyLockTest { diff --git a/core-kotlin-coroutine/src/test/kotlin/com/linecorp/cse/reqshield/kotlin/coroutine/KeyLocalLockTest.kt b/core-kotlin-coroutine/src/test/kotlin/com/linecorp/cse/reqshield/kotlin/coroutine/KeyLocalLockTest.kt index afb0667..9043335 100644 --- a/core-kotlin-coroutine/src/test/kotlin/com/linecorp/cse/reqshield/kotlin/coroutine/KeyLocalLockTest.kt +++ b/core-kotlin-coroutine/src/test/kotlin/com/linecorp/cse/reqshield/kotlin/coroutine/KeyLocalLockTest.kt @@ -29,6 +29,36 @@ import org.junit.jupiter.api.Test import java.util.concurrent.atomic.AtomicInteger class KeyLocalLockTest : BaseKeyLockTest { + @Test + fun `should share global lockMap across multiple instances`() = + runBlocking { + val instance1 = KeyLocalLock(lockTimeoutMillis) + val instance2 = KeyLocalLock(lockTimeoutMillis) + val key = "shared-key" + val lockType = LockType.CREATE + + assertTrue(instance1.tryLock(key, lockType)) + assertTrue(!instance2.tryLock(key, lockType)) + + instance1.unLock(key, lockType) + } + + @Test + fun `should maintain request collapsing across multiple instances`() = + runBlocking { + val instance1 = KeyLocalLock(lockTimeoutMillis) + val instance2 = KeyLocalLock(lockTimeoutMillis) + val instance3 = KeyLocalLock(lockTimeoutMillis) + val key = "collapsing-key" + val lockType = LockType.CREATE + + val acquired = listOf(instance1, instance2, instance3).map { it.tryLock(key, lockType) }.count { it } + assertEquals(1, acquired) + + // cleanup whoever acquired + listOf(instance1, instance2, instance3).forEach { it.unLock(key, lockType) } + } + @Test override fun testConcurrencyWithOneKey() = runBlocking { diff --git a/core-reactor/src/main/kotlin/com/linecorp/cse/reqshield/reactor/KeyLocalLock.kt b/core-reactor/src/main/kotlin/com/linecorp/cse/reqshield/reactor/KeyLocalLock.kt index 427744e..744db3d 100644 --- a/core-reactor/src/main/kotlin/com/linecorp/cse/reqshield/reactor/KeyLocalLock.kt +++ b/core-reactor/src/main/kotlin/com/linecorp/cse/reqshield/reactor/KeyLocalLock.kt @@ -33,20 +33,34 @@ class KeyLocalLock( ) : KeyLock { private data class LockInfo( val semaphore: Semaphore, - val createdAt: Long, + val expiresAt: Long, ) - private val lockMap = ConcurrentHashMap() + companion object { + private val lockMap = ConcurrentHashMap() + + @Volatile + private var monitoringStarted: Boolean = false + + private fun startMonitoringOnce() { + if (monitoringStarted) return + synchronized(this) { + if (monitoringStarted) return + Flux + .interval(Duration.ofMillis(LOCK_MONITOR_INTERVAL_MILLIS), Schedulers.single()) + .doOnNext { + val now = System.currentTimeMillis() + lockMap.entries.removeIf { now > it.value.expiresAt } + }.doOnError { e -> + log.error("Error in lock lifecycle monitoring : {}", e.message) + }.subscribe() + monitoringStarted = true + } + } + } init { - Flux - .interval(Duration.ofMillis(LOCK_MONITOR_INTERVAL_MILLIS), Schedulers.single()) - .doOnNext { - val now = System.currentTimeMillis() - lockMap.entries.removeIf { now - it.value.createdAt > lockTimeoutMillis } - }.doOnError { e -> - log.error("Error in lock lifecycle monitoring : {}", e.message) - }.subscribe() + startMonitoringOnce() } override fun tryLock( @@ -55,7 +69,11 @@ class KeyLocalLock( ): Mono = Mono.fromCallable { val completeKey = "${key}_${lockType.name}" - val lockInfo = lockMap.computeIfAbsent(completeKey) { LockInfo(Semaphore(1), nowToEpochTime()) } + val now = nowToEpochTime() + val lockInfo = + lockMap.computeIfAbsent(completeKey) { + LockInfo(Semaphore(1), now + lockTimeoutMillis) + } lockInfo.semaphore.tryAcquire() } diff --git a/core-reactor/src/main/kotlin/com/linecorp/cse/reqshield/reactor/ReqShield.kt b/core-reactor/src/main/kotlin/com/linecorp/cse/reqshield/reactor/ReqShield.kt index 1d0b03f..50beeae 100644 --- a/core-reactor/src/main/kotlin/com/linecorp/cse/reqshield/reactor/ReqShield.kt +++ b/core-reactor/src/main/kotlin/com/linecorp/cse/reqshield/reactor/ReqShield.kt @@ -27,7 +27,6 @@ import com.linecorp.cse.reqshield.support.model.ReqShieldData import com.linecorp.cse.reqshield.support.utils.decideToUpdateCache import reactor.core.publisher.Flux import reactor.core.publisher.Mono -import reactor.core.scheduler.Schedulers import reactor.util.retry.Retry import java.time.Duration import java.util.concurrent.Callable @@ -73,13 +72,13 @@ class ReqShield( fun processMono(): Mono> = executeCallable({ callable.call() }, true, key, lockType) .map { data -> buildReqShieldData(data, timeToLiveMillis) } - .doOnNext { reqShieldData -> + .flatMap { reqShieldData -> setReqShieldData( reqShieldConfig.setCacheFunction, key, reqShieldData, lockType, - ) + ).thenReturn(reqShieldData) }.switchIfEmpty( Mono.defer { val reqShieldData = buildReqShieldData(null, timeToLiveMillis) @@ -88,21 +87,20 @@ class ReqShield( key, reqShieldData, lockType, - ) - Mono.just(reqShieldData) + ).thenReturn(reqShieldData) }, ) if (reqShieldConfig.reqShieldWorkMode == ReqShieldWorkMode.ONLY_CREATE_CACHE) { processMono() - .subscribeOn(Schedulers.boundedElastic()) + .subscribeOn(reqShieldConfig.scheduler) .subscribe() } else { reqShieldConfig.keyLock .tryLock(key, lockType) .filter { it } .flatMap { processMono() } - .subscribeOn(Schedulers.boundedElastic()) + .subscribeOn(reqShieldConfig.scheduler) .subscribe() } } @@ -138,15 +136,12 @@ class ReqShield( executeCallable({ callable.call() }, true, key, lockType) .map { data -> buildReqShieldData(data, timeToLiveMillis) } .flatMap { reqShieldData -> - setReqShieldData( reqShieldConfig.setCacheFunction, key, reqShieldData, lockType, - ) - - Mono.just(reqShieldData) + ).thenReturn(reqShieldData) }.switchIfEmpty( Mono.defer { val reqShieldData = buildReqShieldData(null, timeToLiveMillis) @@ -156,9 +151,7 @@ class ReqShield( key, reqShieldData, lockType, - ) - - Mono.just(reqShieldData) + ).thenReturn(reqShieldData) }, ) @@ -191,7 +184,7 @@ class ReqShield( Mono.just(reqShieldData) }, ), - ).subscribeOn(Schedulers.boundedElastic()) + ).subscribeOn(reqShieldConfig.scheduler) private fun buildReqShieldData( value: T?, @@ -207,18 +200,14 @@ class ReqShield( key: String, reqShieldData: ReqShieldData, lockType: LockType, - ) { - executeSetCacheFunction(cacheSetter, key, reqShieldData, lockType).subscribe() - } + ): Mono = executeSetCacheFunction(cacheSetter, key, reqShieldData, lockType) private fun executeGetCacheFunction( getFunction: (String) -> Mono?>, key: String, ): Mono?> = getFunction(key) - .doOnError { e -> - throw ClientException(ErrorCode.GET_CACHE_ERROR, originErrorMessage = e.message) - } + .onErrorMap { e -> ClientException(ErrorCode.GET_CACHE_ERROR, originErrorMessage = e.message) } private fun executeSetCacheFunction( setFunction: (String, ReqShieldData, Long) -> Mono, @@ -227,9 +216,8 @@ class ReqShield( lockType: LockType, ): Mono = setFunction(key, value, value.timeToLiveMillis) - .doOnError { e -> - throw ClientException(ErrorCode.SET_CACHE_ERROR, originErrorMessage = e.message) - }.doFinally { + .onErrorMap { e -> ClientException(ErrorCode.SET_CACHE_ERROR, originErrorMessage = e.message) } + .doFinally { if (shouldAttemptUnlock(lockType)) { reqShieldConfig.keyLock .unLock(key, lockType) @@ -240,7 +228,7 @@ class ReqShield( ), ).subscribe() } - }.subscribeOn(Schedulers.boundedElastic()) + }.subscribeOn(reqShieldConfig.scheduler) private fun executeCallable( callable: Callable>, @@ -254,7 +242,9 @@ class ReqShield( if (isUnlockWhenException && key != null && lockType != null) { reqShieldConfig.keyLock.unLock(key, lockType).subscribe() } - throw ClientException(ErrorCode.SUPPLIER_ERROR, originErrorMessage = e.message) + } + .onErrorMap { e -> + ClientException(ErrorCode.SUPPLIER_ERROR, originErrorMessage = e.message) } private fun shouldAttemptUnlock(lockType: LockType): Boolean = diff --git a/core-reactor/src/test/kotlin/com/linecorp/cse/reqshield/reactor/KeyGlobalLockTest.kt b/core-reactor/src/test/kotlin/com/linecorp/cse/reqshield/reactor/KeyGlobalLockTest.kt index 17867ae..e01996f 100644 --- a/core-reactor/src/test/kotlin/com/linecorp/cse/reqshield/reactor/KeyGlobalLockTest.kt +++ b/core-reactor/src/test/kotlin/com/linecorp/cse/reqshield/reactor/KeyGlobalLockTest.kt @@ -24,6 +24,7 @@ import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.Assertions.assertTrue import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.Test +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable import reactor.core.publisher.Mono import reactor.core.scheduler.Schedulers import reactor.test.StepVerifier @@ -31,6 +32,7 @@ import java.time.Duration import java.util.concurrent.atomic.AtomicInteger import kotlin.test.Ignore +@EnabledIfEnvironmentVariable(named = "RUN_REDIS_IT", matches = "true") class KeyGlobalLockTest : AbstractRedisTest(), BaseKeyLockTest { diff --git a/core-reactor/src/test/kotlin/com/linecorp/cse/reqshield/reactor/KeyLocalLockTest.kt b/core-reactor/src/test/kotlin/com/linecorp/cse/reqshield/reactor/KeyLocalLockTest.kt index 6460cef..6e38355 100644 --- a/core-reactor/src/test/kotlin/com/linecorp/cse/reqshield/reactor/KeyLocalLockTest.kt +++ b/core-reactor/src/test/kotlin/com/linecorp/cse/reqshield/reactor/KeyLocalLockTest.kt @@ -27,6 +27,41 @@ import java.time.Duration import java.util.concurrent.atomic.AtomicInteger class KeyLocalLockTest : BaseKeyLockTest { + @Test + fun `should share global lockMap across multiple instances`() { + val instance1 = KeyLocalLock(lockTimeoutMillis) + val instance2 = KeyLocalLock(lockTimeoutMillis) + val key = "shared-key" + val lockType = LockType.CREATE + + StepVerifier.create(instance1.tryLock(key, lockType)).expectNext(true).verifyComplete() + StepVerifier.create(instance2.tryLock(key, lockType)).expectNext(false).verifyComplete() + + StepVerifier.create(instance1.unLock(key, lockType)).expectNext(true).verifyComplete() + } + + @Test + fun `should maintain request collapsing across multiple instances`() { + val instance1 = KeyLocalLock(lockTimeoutMillis) + val instance2 = KeyLocalLock(lockTimeoutMillis) + val instance3 = KeyLocalLock(lockTimeoutMillis) + val key = "collapsing-key" + val lockType = LockType.CREATE + + val attempts = + listOf(instance1, instance2, instance3).map { inst -> + inst.tryLock(key, lockType).map { acquired -> if (acquired) 1 else 0 } + } + + StepVerifier + .create(Mono.zip(attempts) { arr -> arr.sumOf { it as Int } }) + .expectNextMatches { it == 1 } + .verifyComplete() + + // cleanup by unlocking whoever acquired + listOf(instance1, instance2, instance3).forEach { inst -> inst.unLock(key, lockType).subscribe() } + } + @Test override fun testConcurrencyWithOneKey() { val keyLock = KeyLocalLock(lockTimeoutMillis) diff --git a/core-spring-webflux-kotlin-coroutine/build.gradle.kts b/core-spring-webflux-kotlin-coroutine/build.gradle.kts index 52595db..9bba987 100644 --- a/core-spring-webflux-kotlin-coroutine/build.gradle.kts +++ b/core-spring-webflux-kotlin-coroutine/build.gradle.kts @@ -34,7 +34,9 @@ dependencies { testImplementation(rootProject.libs.kotlin.coroutine.test) testImplementation(rootProject.libs.kotlin.coroutine.jvm) testImplementation(rootProject.libs.spring.context) + testImplementation(rootProject.libs.spring.test) testImplementation(rootProject.libs.aspectj) + testImplementation(rootProject.libs.lettuce) } tasks.withType().configureEach { diff --git a/core-spring-webflux-kotlin-coroutine/src/main/kotlin/com/linecorp/cse/reqshield/spring/webflux/kotlin/coroutine/aspect/CoroutineExtension.kt b/core-spring-webflux-kotlin-coroutine/src/main/kotlin/com/linecorp/cse/reqshield/spring/webflux/kotlin/coroutine/aspect/CoroutineExtension.kt index 6513ef1..cf09268 100644 --- a/core-spring-webflux-kotlin-coroutine/src/main/kotlin/com/linecorp/cse/reqshield/spring/webflux/kotlin/coroutine/aspect/CoroutineExtension.kt +++ b/core-spring-webflux-kotlin-coroutine/src/main/kotlin/com/linecorp/cse/reqshield/spring/webflux/kotlin/coroutine/aspect/CoroutineExtension.kt @@ -18,6 +18,8 @@ package com.linecorp.cse.reqshield.spring.webflux.kotlin.coroutine.aspect +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.withContext import org.aspectj.lang.ProceedingJoinPoint import kotlin.coroutines.Continuation import kotlin.coroutines.intrinsics.startCoroutineUninterceptedOrReturn @@ -39,3 +41,14 @@ suspend fun ProceedingJoinPoint.proceedCoroutine(args: Array = this.corout fun ProceedingJoinPoint.runCoroutine(block: suspend () -> Any?): Any? = block.startCoroutineUninterceptedOrReturn(this.coroutineContinuation) + +/** + * Proceed supporting both suspend and non-suspend join points. + * If the last argument is a Continuation, treat as suspend; otherwise proceed normally. + */ +suspend fun ProceedingJoinPoint.proceedSmart(): Any? = + if (this.args.isNotEmpty() && this.args.last() is Continuation<*>) { + this.proceedCoroutine() + } else { + withContext(Dispatchers.IO) { this@proceedSmart.proceed() } + } diff --git a/core-spring-webflux-kotlin-coroutine/src/main/kotlin/com/linecorp/cse/reqshield/spring/webflux/kotlin/coroutine/aspect/ReqShieldAspect.kt b/core-spring-webflux-kotlin-coroutine/src/main/kotlin/com/linecorp/cse/reqshield/spring/webflux/kotlin/coroutine/aspect/ReqShieldAspect.kt index 0fd9777..663677a 100644 --- a/core-spring-webflux-kotlin-coroutine/src/main/kotlin/com/linecorp/cse/reqshield/spring/webflux/kotlin/coroutine/aspect/ReqShieldAspect.kt +++ b/core-spring-webflux-kotlin-coroutine/src/main/kotlin/com/linecorp/cse/reqshield/spring/webflux/kotlin/coroutine/aspect/ReqShieldAspect.kt @@ -47,7 +47,7 @@ import kotlin.coroutines.Continuation @Aspect @Component -class ReqShieldAspect( +open class ReqShieldAspect( private val asyncCache: AsyncCache, ) : BeanFactoryAware { private lateinit var beanFactory: BeanFactory @@ -58,41 +58,37 @@ class ReqShieldAspect( private val keyGeneratorMap = ConcurrentHashMap() internal val reqShieldMap = ConcurrentHashMap>() - @Around("execution(@com.linecorp.cse.reqshield.spring.webflux.kotlin.coroutine.annotation.* * *(.., kotlin.coroutines.Continuation))") - fun aroundTargetCacheable(joinPoint: ProceedingJoinPoint): Any? { - return joinPoint.runCoroutine { - getTargetMethod(joinPoint).annotations.forEach { annotation -> - when (annotation) { - is ReqShieldCacheable -> { - val reqShield = getOrCreateReqShield(joinPoint) - val cacheKey = getCacheableCacheKey(joinPoint) - - return@runCoroutine reqShield - .getAndSetReqShieldData( - cacheKey, - { - joinPoint.proceedCoroutine().let { rtn -> - if (rtn is Mono<*>) { - rtn.awaitSingleOrNull()?.let { it as T } - } else { - rtn?.let { it as T } - } - } - }, - annotation.timeToLiveMillis, - ).value - } - - is ReqShieldCacheEvict -> { - val cacheKey = getCacheEvictCacheKey(joinPoint) - return@runCoroutine asyncCache.evict(cacheKey) - } - } - } + @Around("@annotation(com.linecorp.cse.reqshield.spring.webflux.kotlin.coroutine.annotation.ReqShieldCacheable)") + fun aroundReqShieldCacheable(joinPoint: ProceedingJoinPoint): Any? = + joinPoint.runCoroutine { + val annotation = getCacheableAnnotation(joinPoint) + val reqShield = getOrCreateReqShield(joinPoint) + val cacheKey = getCacheableCacheKey(joinPoint) + + reqShield + .getAndSetReqShieldData( + cacheKey, + { + joinPoint.proceedSmart().let { rtn -> + if (rtn is Mono<*>) { + rtn.awaitSingleOrNull()?.let { it as T } + } else { + rtn?.let { it as T } + } + } + }, + annotation.timeToLiveMillis, + ).value + } + + @Around("@annotation(com.linecorp.cse.reqshield.spring.webflux.kotlin.coroutine.annotation.ReqShieldCacheEvict)") + fun aroundReqShieldCacheEvict(joinPoint: ProceedingJoinPoint): Any? = + joinPoint.runCoroutine { + val cacheKey = getCacheEvictCacheKey(joinPoint) + asyncCache.evict(cacheKey) } - } - internal fun getTargetMethod(joinPoint: ProceedingJoinPoint): Method = (joinPoint.signature as MethodSignature).method + internal open fun getTargetMethod(joinPoint: ProceedingJoinPoint): Method = (joinPoint.signature as MethodSignature).method internal fun getCacheableAnnotation(joinPoint: ProceedingJoinPoint): ReqShieldCacheable = AnnotationUtils.getAnnotation(getTargetMethod(joinPoint), ReqShieldCacheable::class.java) @@ -141,7 +137,10 @@ class ReqShieldAspect( keyGenerator.generate(joinPoint.target, method, args).toString() } - require(!key.isNullOrBlank()) { "Null key returned for cache method : $method" } + require(!key.isNullOrBlank()) { + "Null/blank key for @ReqShieldCacheable method=${method.declaringClass.name}.${method.name} " + + "args=${args.joinToString(prefix = "[", postfix = "]") { it?.toString() ?: "null" }}" + } return key } @@ -183,7 +182,9 @@ class ReqShieldAspect( cacheKeyGenerator: String, ) { if (cacheKey.isNotBlank() && cacheKeyGenerator.isNotBlank()) { - throw IllegalArgumentException("The key and keyGenerator attributes are mutually exclusive.") + throw IllegalArgumentException( + "The key and keyGenerator attributes are mutually exclusive: key='$cacheKey', keyGenerator='$cacheKeyGenerator'", + ) } } @@ -206,8 +207,11 @@ class ReqShieldAspect( return major > 6 || (major == 6 && minor >= 1) } - private fun generateReqShieldKey(joinPoint: ProceedingJoinPoint): String = - "${getCacheableAnnotation(joinPoint).cacheName}-${getCacheableCacheKey(joinPoint)}" + private fun generateReqShieldKey(joinPoint: ProceedingJoinPoint): String { + val method = getTargetMethod(joinPoint) + return "${method.declaringClass.name}.${method.name}-" + + "${getCacheableAnnotation(joinPoint).cacheName}-${getCacheableCacheKey(joinPoint)}" + } override fun setBeanFactory(beanFactory: BeanFactory) { this.beanFactory = beanFactory diff --git a/core-spring-webflux-kotlin-coroutine/src/main/kotlin/com/linecorp/cse/reqshield/spring/webflux/kotlin/coroutine/config/LibAutoConfiguration.kt b/core-spring-webflux-kotlin-coroutine/src/main/kotlin/com/linecorp/cse/reqshield/spring/webflux/kotlin/coroutine/config/LibAutoConfiguration.kt index 34fec10..14555fa 100644 --- a/core-spring-webflux-kotlin-coroutine/src/main/kotlin/com/linecorp/cse/reqshield/spring/webflux/kotlin/coroutine/config/LibAutoConfiguration.kt +++ b/core-spring-webflux-kotlin-coroutine/src/main/kotlin/com/linecorp/cse/reqshield/spring/webflux/kotlin/coroutine/config/LibAutoConfiguration.kt @@ -16,11 +16,12 @@ package com.linecorp.cse.reqshield.spring.webflux.kotlin.coroutine.config -import org.springframework.context.annotation.ComponentScan +import com.linecorp.cse.reqshield.spring.webflux.kotlin.coroutine.aspect.ReqShieldAspect import org.springframework.context.annotation.Configuration import org.springframework.context.annotation.EnableAspectJAutoProxy +import org.springframework.context.annotation.Import @Configuration -@EnableAspectJAutoProxy -@ComponentScan(basePackages = ["com.linecorp.cse"]) +@EnableAspectJAutoProxy(proxyTargetClass = true) +@Import(ReqShieldAspect::class) open class LibAutoConfiguration diff --git a/core-spring-webflux-kotlin-coroutine/src/test/kotlin/com/linecorp/cse/reqshield/spring/webflux/kotlin/coroutine/aspect/InMemoryAsyncCache.kt b/core-spring-webflux-kotlin-coroutine/src/test/kotlin/com/linecorp/cse/reqshield/spring/webflux/kotlin/coroutine/aspect/InMemoryAsyncCache.kt new file mode 100644 index 0000000..926c3dc --- /dev/null +++ b/core-spring-webflux-kotlin-coroutine/src/test/kotlin/com/linecorp/cse/reqshield/spring/webflux/kotlin/coroutine/aspect/InMemoryAsyncCache.kt @@ -0,0 +1,56 @@ +/* + * Copyright 2024 LY Corporation + * + * LY Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.cse.reqshield.spring.webflux.kotlin.coroutine.aspect + +import com.linecorp.cse.reqshield.spring.webflux.kotlin.coroutine.cache.AsyncCache +import com.linecorp.cse.reqshield.support.model.ReqShieldData +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.Semaphore + +class InMemoryAsyncCache : AsyncCache { + private data class Entry(val data: ReqShieldData, val expiresAt: Long) + + private val store = ConcurrentHashMap>() + private val locks = ConcurrentHashMap() + + override suspend fun get(key: String): ReqShieldData? { + val now = System.currentTimeMillis() + return store[key]?.let { e -> if (now <= e.expiresAt) e.data else null } + } + + override suspend fun put( + key: String, + value: ReqShieldData, + timeToLiveMillis: Long, + ): Boolean { + val expiresAt = System.currentTimeMillis() + timeToLiveMillis + store[key] = Entry(value, expiresAt) + return true + } + + override suspend fun evict(key: String): Boolean = store.remove(key) != null + + override suspend fun globalLock( + key: String, + timeToLiveMillis: Long, + ): Boolean = locks.computeIfAbsent(key) { Semaphore(1) }.tryAcquire() + + override suspend fun globalUnLock(key: String): Boolean { + locks[key]?.release() + return true + } +} diff --git a/core-spring-webflux-kotlin-coroutine/src/test/kotlin/com/linecorp/cse/reqshield/spring/webflux/kotlin/coroutine/aspect/ReqShieldAspectIntegrationTest.kt b/core-spring-webflux-kotlin-coroutine/src/test/kotlin/com/linecorp/cse/reqshield/spring/webflux/kotlin/coroutine/aspect/ReqShieldAspectIntegrationTest.kt new file mode 100644 index 0000000..c83260f --- /dev/null +++ b/core-spring-webflux-kotlin-coroutine/src/test/kotlin/com/linecorp/cse/reqshield/spring/webflux/kotlin/coroutine/aspect/ReqShieldAspectIntegrationTest.kt @@ -0,0 +1,87 @@ +/* + * Copyright 2024 LY Corporation + * + * LY Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.cse.reqshield.spring.webflux.kotlin.coroutine.aspect + +import com.linecorp.cse.reqshield.spring.webflux.kotlin.coroutine.annotation.ReqShieldCacheEvict +import com.linecorp.cse.reqshield.spring.webflux.kotlin.coroutine.annotation.ReqShieldCacheable +import com.linecorp.cse.reqshield.spring.webflux.kotlin.coroutine.cache.AsyncCache +import com.linecorp.cse.reqshield.spring.webflux.kotlin.coroutine.config.LibAutoConfiguration +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.async +import kotlinx.coroutines.awaitAll +import kotlinx.coroutines.runBlocking +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Assertions.assertTrue +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.extension.ExtendWith +import org.springframework.beans.factory.annotation.Autowired +import org.springframework.context.annotation.Bean +import org.springframework.context.annotation.Configuration +import org.springframework.test.context.ContextConfiguration +import org.springframework.test.context.junit.jupiter.SpringExtension +import java.util.concurrent.atomic.AtomicInteger + +@ExtendWith(SpringExtension::class) +@ContextConfiguration(classes = [LibAutoConfiguration::class, ReqShieldAspectIntegrationTest.TestConfig::class]) +class ReqShieldAspectIntegrationTest { + @Autowired + private lateinit var service: TestService + + @Test + fun shouldCollapseDuplicateRequests() = + runBlocking { + val key = "dup" + val attempts = 20 + val results = (1..attempts).map { async(Dispatchers.IO) { service.get(key) } }.awaitAll() + assertEquals(attempts, results.size) + val first = results.firstOrNull() + assertTrue(results.all { it == first }) + } + + @Test + fun shouldEvictAndRecompute() = + runBlocking { + val key = "evict" + val v1 = service.get(key) + val evicted = service.evict(key) + val v2 = service.get(key) + + assertTrue(evicted) + assertTrue(v1.isNotEmpty()) + assertTrue(v2.isNotEmpty()) + assertTrue(v1 != v2) + } + + @Configuration + open class TestConfig { + @Bean + open fun asyncCache(): AsyncCache = InMemoryAsyncCache() + + @Bean + open fun service(): TestService = TestService() + } + + open class TestService { + val counter = AtomicInteger(0) + + @ReqShieldCacheable(cacheName = "it", key = "#key", timeToLiveMillis = 10_000) + open suspend fun get(key: String): String = "value-" + counter.incrementAndGet() + + @ReqShieldCacheEvict(cacheName = "it", key = "#key") + open suspend fun evict(key: String): Boolean = true + } +} diff --git a/core-spring-webflux-kotlin-coroutine/src/test/kotlin/com/linecorp/cse/reqshield/spring/webflux/kotlin/coroutine/aspect/ReqShieldAspectRedisIntegrationTest.kt b/core-spring-webflux-kotlin-coroutine/src/test/kotlin/com/linecorp/cse/reqshield/spring/webflux/kotlin/coroutine/aspect/ReqShieldAspectRedisIntegrationTest.kt new file mode 100644 index 0000000..9c6810e --- /dev/null +++ b/core-spring-webflux-kotlin-coroutine/src/test/kotlin/com/linecorp/cse/reqshield/spring/webflux/kotlin/coroutine/aspect/ReqShieldAspectRedisIntegrationTest.kt @@ -0,0 +1,124 @@ +/* + * Copyright 2024 LY Corporation + * + * LY Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.cse.reqshield.spring.webflux.kotlin.coroutine.aspect + +import com.linecorp.cse.reqshield.spring.webflux.kotlin.coroutine.annotation.ReqShieldCacheEvict +import com.linecorp.cse.reqshield.spring.webflux.kotlin.coroutine.annotation.ReqShieldCacheable +import com.linecorp.cse.reqshield.spring.webflux.kotlin.coroutine.cache.AsyncCache +import com.linecorp.cse.reqshield.spring.webflux.kotlin.coroutine.config.LibAutoConfiguration +import com.linecorp.cse.reqshield.support.model.ReqShieldData +import com.linecorp.cse.reqshield.support.redis.AbstractRedisTest +import io.lettuce.core.RedisClient +import io.lettuce.core.api.StatefulRedisConnection +import io.lettuce.core.api.sync.RedisCommands +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.async +import kotlinx.coroutines.awaitAll +import kotlinx.coroutines.runBlocking +import org.junit.jupiter.api.Assertions.assertTrue +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable +import org.junit.jupiter.api.extension.ExtendWith +import org.springframework.beans.factory.annotation.Autowired +import org.springframework.beans.factory.annotation.Value +import org.springframework.context.annotation.Bean +import org.springframework.context.annotation.Configuration +import org.springframework.test.context.ContextConfiguration +import org.springframework.test.context.junit.jupiter.SpringExtension +import java.util.concurrent.atomic.AtomicInteger + +@ExtendWith(SpringExtension::class) +@ContextConfiguration(classes = [LibAutoConfiguration::class, ReqShieldAspectRedisIntegrationTest.TestConfig::class]) +@EnabledIfEnvironmentVariable(named = "RUN_REDIS_IT", matches = "true") +class ReqShieldAspectRedisIntegrationTest : AbstractRedisTest() { + @Autowired + private lateinit var service: TestService + + @Test + fun shouldCollapseDuplicateRequestsWithRedis() = + runBlocking { + val key = "dup-redis" + val attempts = 20 + val results = (1..attempts).map { async(Dispatchers.IO) { service.get(key) } }.awaitAll() + val first = results.firstOrNull() + assertTrue(results.size == attempts && results.all { it == first }) + assertTrue(service.counter.get() == 1) + } + + @Test + fun shouldEvictAndRecomputeWithRedis() = + runBlocking { + val key = "evict-redis" + val v1 = service.get(key) + val evicted = service.evict(key) + val v2 = service.get(key) + assertTrue(evicted && v1 != v2) + assertTrue(service.counter.get() == 2) + } + + @Configuration + open class TestConfig { + @Value("\${spring.redis.host}") + private lateinit var host: String + + @Value("\${spring.redis.port}") + private var port: Int = 0 + + @Bean + open fun asyncCache(): AsyncCache { + val client: RedisClient = RedisClient.create("redis://$host:$port") + val conn: StatefulRedisConnection = client.connect() + val sync: RedisCommands = conn.sync() + + return object : AsyncCache { + override suspend fun get(key: String): ReqShieldData? = + sync.get(key)?.let { ReqShieldData(value = it, timeToLiveMillis = 10_000) } + + override suspend fun put( + key: String, + value: ReqShieldData, + timeToLiveMillis: Long, + ): Boolean { + sync.psetex(key, timeToLiveMillis, value.value ?: "") + return true + } + + override suspend fun evict(key: String): Boolean = sync.del(key) > 0 + + override suspend fun globalLock( + key: String, + timeToLiveMillis: Long, + ): Boolean = sync.setnx("lock:$key", "1").also { if (it) sync.pexpire("lock:$key", timeToLiveMillis) } + + override suspend fun globalUnLock(key: String): Boolean = sync.del("lock:$key") >= 0 + } + } + + @Bean + open fun service(): TestService = TestService() + } + + open class TestService { + val counter = AtomicInteger(0) + + @ReqShieldCacheable(cacheName = "it", key = "#key", timeToLiveMillis = 10_000) + open suspend fun get(key: String): String = "value-" + counter.incrementAndGet() + + @ReqShieldCacheEvict(cacheName = "it", key = "#key") + open suspend fun evict(key: String): Boolean = true + } +} diff --git a/core-spring-webflux-kotlin-coroutine/src/test/kotlin/com/linecorp/cse/reqshield/spring/webflux/kotlin/coroutine/aspect/ReqShieldAspectTest.kt b/core-spring-webflux-kotlin-coroutine/src/test/kotlin/com/linecorp/cse/reqshield/spring/webflux/kotlin/coroutine/aspect/ReqShieldAspectTest.kt index dfc657d..b80022b 100644 --- a/core-spring-webflux-kotlin-coroutine/src/test/kotlin/com/linecorp/cse/reqshield/spring/webflux/kotlin/coroutine/aspect/ReqShieldAspectTest.kt +++ b/core-spring-webflux-kotlin-coroutine/src/test/kotlin/com/linecorp/cse/reqshield/spring/webflux/kotlin/coroutine/aspect/ReqShieldAspectTest.kt @@ -50,7 +50,7 @@ private val log = LoggerFactory.getLogger(ReqShieldAspectTest::class.java) @OptIn(ExperimentalCoroutinesApi::class) class ReqShieldAspectTest : BaseReqShieldModuleSupportTest { - private val asyncCache: AsyncCache = mockk() + private val asyncCache: AsyncCache = InMemoryAsyncCache() private val joinPoint: ProceedingJoinPoint = mockk() private val reqShieldAspect: ReqShieldAspect = spyk(ReqShieldAspect(asyncCache)) private val targetObject = spyk(TestBean()) @@ -80,9 +80,9 @@ class ReqShieldAspectTest : BaseReqShieldModuleSupportTest { runTest { // Mock the cache data using mockk val reqShieldData = ReqShieldData(methodReturn, 1000) - coEvery { asyncCache.get(any()) } returns reqShieldData + asyncCache.put(spelEvaluatedKey, reqShieldData, 1000) coEvery { joinPoint.proceed() } coAnswers { targetObject.cacheableWithCustomKey(argument) } - coEvery { reqShieldAspect.getTargetMethod(joinPoint) } returns + every { reqShieldAspect.getTargetMethod(joinPoint) } returns TestBean::class .functions .find { @@ -90,11 +90,13 @@ class ReqShieldAspectTest : BaseReqShieldModuleSupportTest { }?.javaMethod!! // Test the aroundTargetCacheable method - val result = reqShieldAspect.aroundTargetCacheable(joinPoint) + val result = reqShieldAspect.aroundReqShieldCacheable(joinPoint) assertEquals(result, reqShieldData.value) assertTrue(reqShieldAspect.reqShieldMap.size == 1) - assertNotNull(reqShieldAspect.reqShieldMap["$cacheName-$spelEvaluatedKey"]) + val method = reqShieldAspect.getTargetMethod(joinPoint) + val expectedKey = "${method.declaringClass.name}.${method.name}-$cacheName-$spelEvaluatedKey" + assertNotNull(reqShieldAspect.reqShieldMap[expectedKey]) } @Test @@ -102,9 +104,9 @@ class ReqShieldAspectTest : BaseReqShieldModuleSupportTest { runTest { // Mock the cache data using mockk val reqShieldData = ReqShieldData(methodReturn, 1000) - coEvery { asyncCache.get(any()) } returns reqShieldData + asyncCache.put(spelEvaluatedKey, reqShieldData, 1000) coEvery { joinPoint.proceed() } coAnswers { targetObject.cacheableWithCustomKey(argument) } - coEvery { reqShieldAspect.getTargetMethod(joinPoint) } returns + every { reqShieldAspect.getTargetMethod(joinPoint) } returns TestBean::class .functions .find { @@ -114,46 +116,49 @@ class ReqShieldAspectTest : BaseReqShieldModuleSupportTest { val jobs = List(20) { async { - reqShieldAspect.aroundTargetCacheable(joinPoint) + reqShieldAspect.aroundReqShieldCacheable(joinPoint) } } jobs.awaitAll() assertTrue(reqShieldAspect.reqShieldMap.size == 1) - assertNotNull(reqShieldAspect.reqShieldMap["$cacheName-$spelEvaluatedKey"]) + val method = reqShieldAspect.getTargetMethod(joinPoint) + val expectedKey = "${method.declaringClass.name}.${method.name}-$cacheName-$spelEvaluatedKey" + assertNotNull(reqShieldAspect.reqShieldMap[expectedKey]) } @Test override fun verifyReqShieldCacheEviction() = runTest { - // Mock the cache data using mockk val reqShieldData = ReqShieldData(methodReturn, 1000) - coEvery { asyncCache.get(any()) } returns reqShieldData - coEvery { joinPoint.proceed() } coAnswers { targetObject.cacheableWithCustomKey(argument) } - coEvery { reqShieldAspect.getTargetMethod(joinPoint) } returns + // Use SpEL-based key to align with eviction method's key + every { reqShieldAspect.getTargetMethod(joinPoint) } returns TestBean::class .functions .find { - it.name == TestBean::cacheableWithDefaultKeyGenerator.name && it.parameters.size == 2 + it.name == TestBean::cacheableWithCustomKey.name && it.parameters.size == 2 }?.javaMethod!! + val generatedKey = reqShieldAspect.getCacheableCacheKey(joinPoint) + asyncCache.put(generatedKey, reqShieldData, 1000) + coEvery { joinPoint.proceed() } coAnswers { targetObject.cacheableWithCustomKey(argument) } - // Test the aroundTargetCacheable method - val result = reqShieldAspect.aroundTargetCacheable(joinPoint) + // Test the aroundTargetCacheable method using the same SpEL key + val result = reqShieldAspect.aroundReqShieldCacheable(joinPoint) assertEquals(reqShieldData.value, result) - // Validate cache eviction - coEvery { asyncCache.evict(any()) } returns true - coEvery { reqShieldAspect.getTargetMethod(joinPoint) } returns + // Validate cache eviction using the eviction method (same SpEL key) + // real eviction call + every { reqShieldAspect.getTargetMethod(joinPoint) } returns TestBean::class .functions .find { it.name == TestBean::evict.name && it.parameters.size == 2 }?.javaMethod!! - coEvery { joinPoint.proceed() } coAnswers { targetObject.evict(argument) } + // Eviction branch in aspect does not require proceeding the original method - val removeProductMono = reqShieldAspect.aroundTargetCacheable(joinPoint) + val removeProductMono = reqShieldAspect.aroundReqShieldCacheEvict(joinPoint) assertTrue(removeProductMono as Boolean) } diff --git a/core-spring-webflux/build.gradle.kts b/core-spring-webflux/build.gradle.kts index 7642a4f..c84dead 100644 --- a/core-spring-webflux/build.gradle.kts +++ b/core-spring-webflux/build.gradle.kts @@ -31,7 +31,9 @@ dependencies { testImplementation(rootProject.libs.reactor) testImplementation(rootProject.libs.reactor.test) testImplementation(rootProject.libs.spring.context) + testImplementation(rootProject.libs.spring.test) testImplementation(rootProject.libs.aspectj) + testImplementation(rootProject.libs.lettuce) } tasks.withType().configureEach { diff --git a/core-spring-webflux/src/main/kotlin/com/linecorp/cse/reqshield/spring/webflux/annotation/ReqShieldCacheable.kt b/core-spring-webflux/src/main/kotlin/com/linecorp/cse/reqshield/spring/webflux/annotation/ReqShieldCacheable.kt index c093fc7..5d7ba67 100644 --- a/core-spring-webflux/src/main/kotlin/com/linecorp/cse/reqshield/spring/webflux/annotation/ReqShieldCacheable.kt +++ b/core-spring-webflux/src/main/kotlin/com/linecorp/cse/reqshield/spring/webflux/annotation/ReqShieldCacheable.kt @@ -33,4 +33,10 @@ annotation class ReqShieldCacheable( val maxAttemptGetCache: Int = MAX_ATTEMPT_GET_CACHE, val timeToLiveMillis: Long = 10 * 60 * 1000, val reqShieldWorkMode: ReqShieldWorkMode = ReqShieldWorkMode.CREATE_AND_UPDATE_CACHE, + val nullHandling: NullHandling = NullHandling.EMIT_EMPTY, ) + +enum class NullHandling { + EMIT_EMPTY, + ERROR, +} diff --git a/core-spring-webflux/src/main/kotlin/com/linecorp/cse/reqshield/spring/webflux/aspect/ReqShieldAspect.kt b/core-spring-webflux/src/main/kotlin/com/linecorp/cse/reqshield/spring/webflux/aspect/ReqShieldAspect.kt index fa65a83..d9820c2 100644 --- a/core-spring-webflux/src/main/kotlin/com/linecorp/cse/reqshield/spring/webflux/aspect/ReqShieldAspect.kt +++ b/core-spring-webflux/src/main/kotlin/com/linecorp/cse/reqshield/spring/webflux/aspect/ReqShieldAspect.kt @@ -44,7 +44,7 @@ import java.util.concurrent.ConcurrentHashMap @Aspect @Component -class ReqShieldAspect( +open class ReqShieldAspect( private val asyncCache: AsyncCache, ) : BeanFactoryAware { private lateinit var beanFactory: BeanFactory @@ -60,14 +60,34 @@ class ReqShieldAspect( val reqShield = getOrCreateReqShield(joinPoint) val cacheKey = getCacheableCacheKey(joinPoint) - return reqShield - .getAndSetReqShieldData( - cacheKey, - { - joinPoint.proceed() as Mono - }, - annotation.timeToLiveMillis, - ).mapNotNull { it.value } + val resultMono = + reqShield + .getAndSetReqShieldData( + cacheKey, + { + joinPoint.proceed() as Mono + }, + annotation.timeToLiveMillis, + ).map { it.value } + + return when (annotation.nullHandling) { + com.linecorp.cse.reqshield.spring.webflux.annotation.NullHandling.EMIT_EMPTY -> + resultMono.flatMap { value -> + if (value == null) { + Mono.empty() + } else { + Mono.just(value) + } + } + com.linecorp.cse.reqshield.spring.webflux.annotation.NullHandling.ERROR -> + resultMono.flatMap { value -> + if (value == null) { + Mono.error(IllegalStateException("ReqShieldCacheable returned null for key=$cacheKey")) + } else { + Mono.just(value) + } + } + } } @Around("@annotation(com.linecorp.cse.reqshield.spring.webflux.annotation.ReqShieldCacheEvict)") @@ -104,7 +124,7 @@ class ReqShieldAspect( return getCacheKeyOrDefault(annotation.key, annotation.keyGenerator, joinPoint) } - internal fun getTargetMethod(joinPoint: ProceedingJoinPoint): Method = (joinPoint.signature as MethodSignature).method + internal open fun getTargetMethod(joinPoint: ProceedingJoinPoint): Method = (joinPoint.signature as MethodSignature).method private fun getCacheKeyOrDefault( annotationCacheKey: String, @@ -124,7 +144,10 @@ class ReqShieldAspect( keyGenerator.generate(joinPoint.target, method, joinPoint.args).toString() } - require(!key.isNullOrBlank()) { "Null key returned for cache method : $method" } + require(!key.isNullOrBlank()) { + "Null/blank key for @ReqShieldCacheable method=${method.declaringClass.name}.${method.name} " + + "args=${joinPoint.args.joinToString(prefix = "[", postfix = "]") { it?.toString() ?: "null" }}" + } return key } @@ -156,6 +179,7 @@ class ReqShieldAspect( decisionForUpdate = annotation.decisionForUpdate, maxAttemptGetCache = annotation.maxAttemptGetCache, reqShieldWorkMode = annotation.reqShieldWorkMode, + scheduler = beanFactory.getBean("reqShieldScheduler", reactor.core.scheduler.Scheduler::class.java), ) return ReqShield(reqShieldConfiguration) @@ -166,7 +190,9 @@ class ReqShieldAspect( cacheKeyGenerator: String, ) { if (cacheKey.isNotBlank() && cacheKeyGenerator.isNotBlank()) { - throw IllegalArgumentException("The key and keyGenerator attributes are mutually exclusive.") + throw IllegalArgumentException( + "The key and keyGenerator attributes are mutually exclusive: key='$cacheKey', keyGenerator='$cacheKeyGenerator'", + ) } } @@ -180,8 +206,11 @@ class ReqShieldAspect( } } - private fun generateReqShieldKey(joinPoint: ProceedingJoinPoint): String = - "${getCacheableAnnotation(joinPoint).cacheName}-${getCacheableCacheKey(joinPoint)}" + private fun generateReqShieldKey(joinPoint: ProceedingJoinPoint): String { + val method = getTargetMethod(joinPoint) + return "${method.declaringClass.name}.${method.name}-" + + "${getCacheableAnnotation(joinPoint).cacheName}-${getCacheableCacheKey(joinPoint)}" + } override fun setBeanFactory(beanFactory: BeanFactory) { this.beanFactory = beanFactory diff --git a/core-spring-webflux/src/main/kotlin/com/linecorp/cse/reqshield/spring/webflux/config/LibAutoConfiguration.kt b/core-spring-webflux/src/main/kotlin/com/linecorp/cse/reqshield/spring/webflux/config/LibAutoConfiguration.kt index 863cb78..2a46c27 100644 --- a/core-spring-webflux/src/main/kotlin/com/linecorp/cse/reqshield/spring/webflux/config/LibAutoConfiguration.kt +++ b/core-spring-webflux/src/main/kotlin/com/linecorp/cse/reqshield/spring/webflux/config/LibAutoConfiguration.kt @@ -16,11 +16,18 @@ package com.linecorp.cse.reqshield.spring.webflux.config -import org.springframework.context.annotation.ComponentScan +import com.linecorp.cse.reqshield.spring.webflux.aspect.ReqShieldAspect +import org.springframework.context.annotation.Bean import org.springframework.context.annotation.Configuration import org.springframework.context.annotation.EnableAspectJAutoProxy +import org.springframework.context.annotation.Import +import reactor.core.scheduler.Scheduler +import reactor.core.scheduler.Schedulers @Configuration -@EnableAspectJAutoProxy -@ComponentScan(basePackages = ["com.linecorp.cse"]) -open class LibAutoConfiguration +@EnableAspectJAutoProxy(proxyTargetClass = true) +@Import(ReqShieldAspect::class) +open class LibAutoConfiguration { + @Bean + open fun reqShieldScheduler(): Scheduler = Schedulers.boundedElastic() +} diff --git a/core-spring-webflux/src/test/kotlin/com/linecorp/cse/reqshield/spring/webflux/aspect/InMemoryAsyncCache.kt b/core-spring-webflux/src/test/kotlin/com/linecorp/cse/reqshield/spring/webflux/aspect/InMemoryAsyncCache.kt new file mode 100644 index 0000000..15a5631 --- /dev/null +++ b/core-spring-webflux/src/test/kotlin/com/linecorp/cse/reqshield/spring/webflux/aspect/InMemoryAsyncCache.kt @@ -0,0 +1,63 @@ +/* + * Copyright 2024 LY Corporation + * + * LY Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.cse.reqshield.spring.webflux.aspect + +import com.linecorp.cse.reqshield.spring.webflux.cache.AsyncCache +import com.linecorp.cse.reqshield.support.model.ReqShieldData +import reactor.core.publisher.Mono +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.Semaphore + +class InMemoryAsyncCache : AsyncCache { + private data class Entry(val data: ReqShieldData, val expiresAt: Long) + + private val store = ConcurrentHashMap>() + private val locks = ConcurrentHashMap() + + override fun get(key: String): Mono?> = + Mono.fromCallable { + val now = System.currentTimeMillis() + store[key]?.let { e -> if (now <= e.expiresAt) e.data else null } + } + + override fun put( + key: String, + value: ReqShieldData, + timeToLiveMillis: Long, + ): Mono = + Mono.fromCallable { + val expiresAt = System.currentTimeMillis() + timeToLiveMillis + store[key] = Entry(value, expiresAt) + true + } + + override fun evict(key: String): Mono = Mono.fromCallable { store.remove(key) != null } + + override fun globalLock( + key: String, + timeToLiveMillis: Long, + ): Mono = + Mono.fromCallable { + locks.computeIfAbsent(key) { Semaphore(1) }.tryAcquire() + } + + override fun globalUnLock(key: String): Mono = + Mono.fromCallable { + locks[key]?.release() + true + } +} diff --git a/core-spring-webflux/src/test/kotlin/com/linecorp/cse/reqshield/spring/webflux/aspect/ReqShieldAspectIntegrationTest.kt b/core-spring-webflux/src/test/kotlin/com/linecorp/cse/reqshield/spring/webflux/aspect/ReqShieldAspectIntegrationTest.kt new file mode 100644 index 0000000..577ac0c --- /dev/null +++ b/core-spring-webflux/src/test/kotlin/com/linecorp/cse/reqshield/spring/webflux/aspect/ReqShieldAspectIntegrationTest.kt @@ -0,0 +1,91 @@ +/* + * Copyright 2024 LY Corporation + * + * LY Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.cse.reqshield.spring.webflux.aspect + +import com.linecorp.cse.reqshield.spring.webflux.annotation.ReqShieldCacheEvict +import com.linecorp.cse.reqshield.spring.webflux.annotation.ReqShieldCacheable +import com.linecorp.cse.reqshield.spring.webflux.cache.AsyncCache +import com.linecorp.cse.reqshield.spring.webflux.config.LibAutoConfiguration +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Assertions.assertTrue +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.extension.ExtendWith +import org.springframework.beans.factory.annotation.Autowired +import org.springframework.context.annotation.Bean +import org.springframework.context.annotation.Configuration +import org.springframework.test.context.ContextConfiguration +import org.springframework.test.context.junit.jupiter.SpringExtension +import reactor.core.publisher.Flux +import reactor.core.publisher.Mono +import reactor.core.scheduler.Schedulers +import java.util.concurrent.atomic.AtomicInteger + +@ExtendWith(SpringExtension::class) +@ContextConfiguration(classes = [LibAutoConfiguration::class, ReqShieldAspectIntegrationTest.TestConfig::class]) +class ReqShieldAspectIntegrationTest { + @Autowired + private lateinit var service: TestService + + @Test + fun shouldCollapseDuplicateRequests() { + val key = "dup" + val attempts = 20 + + val result = + Flux + .range(1, attempts) + .flatMap { service.get(key).subscribeOn(Schedulers.boundedElastic()) } + .collectList() + .block() + + assertEquals(attempts, result?.size) + val first = result?.firstOrNull() + assertTrue(result?.all { it == first } == true) + } + + @Test + fun shouldEvictAndRecompute() { + val key = "evict" + val v1 = service.get(key).block() + val evicted = service.evict(key).block() + val v2 = service.get(key).block() + + assertTrue(evicted == true) + assertTrue(v1 != null) + assertTrue(v2 != null) + assertTrue(v1 != v2) + } + + @Configuration + open class TestConfig { + @Bean + open fun asyncCache(): AsyncCache = InMemoryAsyncCache() + + @Bean + open fun service(): TestService = TestService() + } + + open class TestService { + val counter = AtomicInteger(0) + + @ReqShieldCacheable(cacheName = "it", key = "#key", timeToLiveMillis = 10_000) + open fun get(key: String): Mono = Mono.fromCallable { "value-" + counter.incrementAndGet() } + + @ReqShieldCacheEvict(cacheName = "it", key = "#key") + open fun evict(key: String): Mono = Mono.just(true) + } +} diff --git a/core-spring-webflux/src/test/kotlin/com/linecorp/cse/reqshield/spring/webflux/aspect/ReqShieldAspectRedisIntegrationTest.kt b/core-spring-webflux/src/test/kotlin/com/linecorp/cse/reqshield/spring/webflux/aspect/ReqShieldAspectRedisIntegrationTest.kt new file mode 100644 index 0000000..193d24c --- /dev/null +++ b/core-spring-webflux/src/test/kotlin/com/linecorp/cse/reqshield/spring/webflux/aspect/ReqShieldAspectRedisIntegrationTest.kt @@ -0,0 +1,135 @@ +/* + * Copyright 2024 LY Corporation + * + * LY Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.cse.reqshield.spring.webflux.aspect + +import com.linecorp.cse.reqshield.spring.webflux.annotation.ReqShieldCacheEvict +import com.linecorp.cse.reqshield.spring.webflux.annotation.ReqShieldCacheable +import com.linecorp.cse.reqshield.spring.webflux.cache.AsyncCache +import com.linecorp.cse.reqshield.spring.webflux.config.LibAutoConfiguration +import com.linecorp.cse.reqshield.support.model.ReqShieldData +import com.linecorp.cse.reqshield.support.redis.AbstractRedisTest +import io.lettuce.core.RedisClient +import io.lettuce.core.api.StatefulRedisConnection +import io.lettuce.core.api.sync.RedisCommands +import org.junit.jupiter.api.Assertions.assertTrue +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable +import org.junit.jupiter.api.extension.ExtendWith +import org.springframework.beans.factory.annotation.Autowired +import org.springframework.beans.factory.annotation.Value +import org.springframework.context.annotation.Bean +import org.springframework.context.annotation.Configuration +import org.springframework.test.context.ContextConfiguration +import org.springframework.test.context.junit.jupiter.SpringExtension +import reactor.core.publisher.Flux +import reactor.core.publisher.Mono +import reactor.core.scheduler.Schedulers +import java.util.concurrent.atomic.AtomicInteger + +@ExtendWith(SpringExtension::class) +@EnabledIfEnvironmentVariable(named = "RUN_REDIS_IT", matches = "true") +@ContextConfiguration(classes = [LibAutoConfiguration::class, ReqShieldAspectRedisIntegrationTest.TestConfig::class]) +class ReqShieldAspectRedisIntegrationTest : AbstractRedisTest() { + @Autowired + private lateinit var service: TestService + + @Test + fun shouldCollapseDuplicateRequestsWithRedis() { + val key = "dup-redis" + val attempts = 20 + + val result = + Flux + .range(1, attempts) + .flatMap { service.get(key).subscribeOn(Schedulers.boundedElastic()) } + .collectList() + .block()!! + + assertTrue(result.size == attempts) + val first = result.firstOrNull() + assertTrue(result.all { it == first }) + assertTrue(service.counter.get() == 1) + } + + @Test + fun shouldEvictAndRecomputeWithRedis() { + val key = "evict-redis" + val v1 = service.get(key).block() + val evicted = service.evict(key).block() + val v2 = service.get(key).block() + + assertTrue(evicted == true) + assertTrue(v1 != null && v2 != null && v1 != v2) + assertTrue(service.counter.get() == 2) + } + + @Configuration + open class TestConfig { + @Value("\${spring.redis.host}") + private lateinit var host: String + + @Value("\${spring.redis.port}") + private var port: Int = 0 + + @Bean + open fun asyncCache(): AsyncCache { + val client: RedisClient = RedisClient.create("redis://$host:$port") + val conn: StatefulRedisConnection = client.connect() + val sync: RedisCommands = conn.sync() + + return object : AsyncCache { + override fun get(key: String): Mono?> = + Mono.fromCallable { + sync.get(key)?.let { ReqShieldData(value = it, timeToLiveMillis = 10_000) } + } + + override fun put( + key: String, + value: ReqShieldData, + timeToLiveMillis: Long, + ): Mono = + Mono.fromCallable { + sync.psetex(key, timeToLiveMillis, value.value ?: "") + true + } + + override fun evict(key: String): Mono = Mono.fromCallable { sync.del(key) > 0 } + + override fun globalLock( + key: String, + timeToLiveMillis: Long, + ): Mono = + Mono.fromCallable { sync.setnx("lock:$key", "1").also { if (it) sync.pexpire("lock:$key", timeToLiveMillis) } } + + override fun globalUnLock(key: String): Mono = Mono.fromCallable { sync.del("lock:$key") >= 0 } + } + } + + @Bean + open fun service(): TestService = TestService() + } + + open class TestService { + val counter = AtomicInteger(0) + + @ReqShieldCacheable(cacheName = "it", key = "#key", timeToLiveMillis = 10_000) + open fun get(key: String): Mono = Mono.fromCallable { "value-" + counter.incrementAndGet() } + + @ReqShieldCacheEvict(cacheName = "it", key = "#key") + open fun evict(key: String): Mono = Mono.just(true) + } +} diff --git a/core-spring-webflux/src/test/kotlin/com/linecorp/cse/reqshield/spring/webflux/aspect/ReqShieldAspectTest.kt b/core-spring-webflux/src/test/kotlin/com/linecorp/cse/reqshield/spring/webflux/aspect/ReqShieldAspectTest.kt index afdfab0..8591fd8 100644 --- a/core-spring-webflux/src/test/kotlin/com/linecorp/cse/reqshield/spring/webflux/aspect/ReqShieldAspectTest.kt +++ b/core-spring-webflux/src/test/kotlin/com/linecorp/cse/reqshield/spring/webflux/aspect/ReqShieldAspectTest.kt @@ -42,7 +42,7 @@ import kotlin.test.assertEquals import kotlin.test.assertTrue class ReqShieldAspectTest : BaseReqShieldModuleSupportTest { - private val asyncCache: AsyncCache = mockk() + private val asyncCache: AsyncCache = InMemoryAsyncCache() private val joinPoint = mockk() private val reqShieldAspect = spyk(ReqShieldAspect(asyncCache)) private val targetObject = spyk(TestBean()) @@ -63,13 +63,17 @@ class ReqShieldAspectTest : BaseReqShieldModuleSupportTest { every { joinPoint.target } returns targetObject reqShieldAspect.setBeanFactory(beanFactory) + // Provide scheduler bean expected by aspect configuration + every { + beanFactory.getBean("reqShieldScheduler", reactor.core.scheduler.Scheduler::class.java) + } returns Schedulers.boundedElastic() } @Test override fun verifyReqShieldCacheCreation() { - // Mock the cache data using mockk val reqShieldData = ReqShieldData(methodReturn, 1000) - every { asyncCache.get(any()) } returns Mono.just(reqShieldData) + // pre-populate cache + asyncCache.put(spelEvaluatedKey, reqShieldData, 1000).block() every { joinPoint.proceed() } answers { targetObject.cacheableWithCustomKey(argument) } every { reqShieldAspect.getTargetMethod(joinPoint) } returns ReflectionUtils.findMethod( @@ -87,15 +91,16 @@ class ReqShieldAspectTest : BaseReqShieldModuleSupportTest { .assertNext { value -> assertEquals(reqShieldData.value, value) Assertions.assertTrue(reqShieldAspect.reqShieldMap.size == 1) - Assertions.assertNotNull(reqShieldAspect.reqShieldMap["$cacheName-$spelEvaluatedKey"]) + val method = reqShieldAspect.getTargetMethod(joinPoint) + val expectedKey = "${method.declaringClass.name}.${method.name}-$cacheName-$spelEvaluatedKey" + Assertions.assertNotNull(reqShieldAspect.reqShieldMap[expectedKey]) }.verifyComplete() } @Test override fun reqShieldObjectShouldBeCreatedOnce() { - // Mock the cache data using mockk val reqShieldData = ReqShieldData(methodReturn, 1000) - every { asyncCache.get(any()) } returns Mono.just(reqShieldData) + asyncCache.put(spelEvaluatedKey, reqShieldData, 1000).block() every { joinPoint.proceed() } answers { targetObject.cacheableWithCustomKey(argument) } every { reqShieldAspect.getTargetMethod(joinPoint) } returns ReflectionUtils.findMethod( @@ -117,16 +122,16 @@ class ReqShieldAspectTest : BaseReqShieldModuleSupportTest { .create(flux) .assertNext { productList -> Assertions.assertTrue(reqShieldAspect.reqShieldMap.size == 1) - println(reqShieldAspect.reqShieldMap.keys().toList()) - Assertions.assertNotNull(reqShieldAspect.reqShieldMap["$cacheName-$spelEvaluatedKey"]) + val method = reqShieldAspect.getTargetMethod(joinPoint) + val expectedKey = "${method.declaringClass.name}.${method.name}-$cacheName-$spelEvaluatedKey" + Assertions.assertNotNull(reqShieldAspect.reqShieldMap[expectedKey]) }.verifyComplete() } @Test override fun verifyReqShieldCacheEviction() { - // Mock the cache data using mockk val reqShieldData = ReqShieldData(methodReturn, 1000) - every { asyncCache.get(any()) } returns Mono.just(reqShieldData) + asyncCache.put("${SimpleKeyGenerator.generateKey(arrayOf(argument))}", reqShieldData, 1000).block() every { reqShieldAspect.getTargetMethod(joinPoint) } returns ReflectionUtils.findMethod( TestBean::class.java, @@ -145,7 +150,7 @@ class ReqShieldAspectTest : BaseReqShieldModuleSupportTest { }.verifyComplete() // Validate cache eviction - every { asyncCache.evict(any()) } returns Mono.just(true) + // real eviction call every { reqShieldAspect.getTargetMethod(joinPoint) } returns ReflectionUtils.findMethod( TestBean::class.java, diff --git a/core-spring/src/main/kotlin/com/linecorp/cse/reqshield/spring/aspect/ReqShieldAspect.kt b/core-spring/src/main/kotlin/com/linecorp/cse/reqshield/spring/aspect/ReqShieldAspect.kt index bf2cae7..df19ada 100644 --- a/core-spring/src/main/kotlin/com/linecorp/cse/reqshield/spring/aspect/ReqShieldAspect.kt +++ b/core-spring/src/main/kotlin/com/linecorp/cse/reqshield/spring/aspect/ReqShieldAspect.kt @@ -151,7 +151,10 @@ class ReqShieldAspect( keyGenerator.generate(joinPoint.target, method, joinPoint.args).toString() } - require(!key.isNullOrBlank()) { "Null key returned for cache method : $method" } + require(!key.isNullOrBlank()) { + "Null/blank key for @ReqShieldCacheable method=${method.declaringClass.name}.${method.name} " + + "args=${joinPoint.args.joinToString(prefix = "[", postfix = "]") { it?.toString() ?: "null" }}" + } return key } @@ -161,7 +164,9 @@ class ReqShieldAspect( cacheKeyGenerator: String, ) { if (cacheKey.isNotBlank() && cacheKeyGenerator.isNotBlank()) { - throw IllegalArgumentException("The key and keyGenerator attributes are mutually exclusive.") + throw IllegalArgumentException( + "The key and keyGenerator attributes are mutually exclusive: key='$cacheKey', keyGenerator='$cacheKeyGenerator'", + ) } } @@ -175,8 +180,11 @@ class ReqShieldAspect( } } - private fun generateReqShieldKey(joinPoint: ProceedingJoinPoint): String = - "${getCacheableAnnotation(joinPoint).cacheName}-${getCacheableCacheKey(joinPoint)}" + private fun generateReqShieldKey(joinPoint: ProceedingJoinPoint): String { + val method = getTargetMethod(joinPoint) + return "${method.declaringClass.name}.${method.name}-" + + "${getCacheableAnnotation(joinPoint).cacheName}-${getCacheableCacheKey(joinPoint)}" + } override fun setBeanFactory(beanFactory: BeanFactory) { this.beanFactory = beanFactory diff --git a/core-spring/src/test/kotlin/aspect/ReqShieldAspectTest.kt b/core-spring/src/test/kotlin/aspect/ReqShieldAspectTest.kt index 619bd08..38b246c 100644 --- a/core-spring/src/test/kotlin/aspect/ReqShieldAspectTest.kt +++ b/core-spring/src/test/kotlin/aspect/ReqShieldAspectTest.kt @@ -90,7 +90,9 @@ class ReqShieldAspectTest : BaseReqShieldModuleSupportTest { // then assertEquals(reqShieldData.value, result) assertTrue(reqShieldAspect.reqShieldMap.size == 1) - assertNotNull(reqShieldAspect.reqShieldMap["$cacheName-$spelEvaluatedKey"]) + val method = reqShieldAspect.getTargetMethod(joinPoint) + val expectedKey = "${method.declaringClass.name}.${method.name}-$cacheName-$spelEvaluatedKey" + assertNotNull(reqShieldAspect.reqShieldMap[expectedKey]) } } @@ -117,7 +119,9 @@ class ReqShieldAspectTest : BaseReqShieldModuleSupportTest { Awaitility.await().atMost(Duration.ofMillis(BaseReqShieldTest.AWAIT_TIMEOUT)).untilAsserted { // then assertTrue(reqShieldAspect.reqShieldMap.size == 1) - assertNotNull(reqShieldAspect.reqShieldMap["$cacheName-$spelEvaluatedKey"]) + val method = reqShieldAspect.getTargetMethod(joinPoint) + val expectedKey = "${method.declaringClass.name}.${method.name}-$cacheName-$spelEvaluatedKey" + assertNotNull(reqShieldAspect.reqShieldMap[expectedKey]) } } diff --git a/core/src/main/kotlin/com/linecorp/cse/reqshield/KeyLocalLock.kt b/core/src/main/kotlin/com/linecorp/cse/reqshield/KeyLocalLock.kt index 59153dd..8396605 100644 --- a/core/src/main/kotlin/com/linecorp/cse/reqshield/KeyLocalLock.kt +++ b/core/src/main/kotlin/com/linecorp/cse/reqshield/KeyLocalLock.kt @@ -28,7 +28,7 @@ import java.util.concurrent.TimeUnit private val log = LoggerFactory.getLogger(KeyLocalLock::class.java) class KeyLocalLock(private val lockTimeoutMillis: Long) : KeyLock { - private data class LockInfo(val semaphore: Semaphore, val createdAt: Long) + private data class LockInfo(val semaphore: Semaphore, val expiresAt: Long) companion object { // Global lockMap shared by all instances - CRITICAL FIX for request collapsing @@ -61,11 +61,15 @@ class KeyLocalLock(private val lockTimeoutMillis: Long) : KeyLock { } private fun startMonitoring(scheduler: ScheduledExecutorService) { - // Batch cleanup for all instances (10ms → 1000ms) + // Single cleanup task operating on the global lockMap scheduler.scheduleWithFixedDelay({ try { - instances.forEach { instance -> - instance.cleanupExpiredLocks() + val now = System.currentTimeMillis() + val before = lockMap.size + lockMap.entries.removeIf { now > it.value.expiresAt } + val after = lockMap.size + if (log.isTraceEnabled && before > after) { + log.trace("Cleaned up {} expired locks, {} remaining", before - after, after) } } catch (e: Exception) { log.error("Error in shared lock lifecycle monitoring: {}", e.message) @@ -99,28 +103,15 @@ class KeyLocalLock(private val lockTimeoutMillis: Long) : KeyLock { getOrCreateScheduler() } - // Internal cleanup method (called by shared scheduler) - internal fun cleanupExpiredLocks() { - val now = System.currentTimeMillis() - val expiredCount = lockMap.size - lockMap.entries.removeIf { now - it.value.createdAt > lockTimeoutMillis } - val remainingCount = lockMap.size - - if (log.isTraceEnabled && expiredCount > remainingCount) { - log.trace( - "Cleaned up {} expired locks, {} remaining", - expiredCount - remainingCount, - remainingCount, - ) - } - } + // Internal cleanup method no longer needed per-instance with single shared cleanup override fun tryLock( key: String, lockType: LockType, ): Boolean { val completeKey = "${key}_${lockType.name}" - val lockInfo = lockMap.computeIfAbsent(completeKey) { LockInfo(Semaphore(1), nowToEpochTime()) } + val now = nowToEpochTime() + val lockInfo = lockMap.computeIfAbsent(completeKey) { LockInfo(Semaphore(1), now + lockTimeoutMillis) } return lockInfo.semaphore.tryAcquire() } diff --git a/core/src/main/kotlin/com/linecorp/cse/reqshield/ReqShield.kt b/core/src/main/kotlin/com/linecorp/cse/reqshield/ReqShield.kt index 20a9f2a..7af99d8 100644 --- a/core/src/main/kotlin/com/linecorp/cse/reqshield/ReqShield.kt +++ b/core/src/main/kotlin/com/linecorp/cse/reqshield/ReqShield.kt @@ -163,26 +163,24 @@ class ReqShield( callable: Callable, key: String, ) { - fun schedule(): ScheduledFuture<*> = - executor.schedule({ - if (!future.isDone) { - val funcResult = executeGetCacheFunction(cacheGetter, key) - if (funcResult != null) { - future.complete(funcResult.value) - } else if (counter.incrementAndGet() >= reqShieldConfig.maxAttemptGetCache) { - future.complete( - executeCallable({ callable.call() }, false), - ) - } - if (!future.isDone) { - schedule() // Schedule the next execution - } + val scheduled: ScheduledFuture<*> = + executor.scheduleAtFixedRate({ + if (future.isDone) { + return@scheduleAtFixedRate } - }, GET_CACHE_INTERVAL_MILLIS, TimeUnit.MILLISECONDS) - val scheduleFuture = schedule() + val funcResult = executeGetCacheFunction(cacheGetter, key) + if (funcResult != null) { + future.complete(funcResult.value) + return@scheduleAtFixedRate + } + + if (counter.incrementAndGet() >= reqShieldConfig.maxAttemptGetCache) { + future.complete(executeCallable({ callable.call() }, false)) + } + }, GET_CACHE_INTERVAL_MILLIS, GET_CACHE_INTERVAL_MILLIS, TimeUnit.MILLISECONDS) - future.whenComplete { _, _ -> scheduleFuture.cancel(false) } + future.whenComplete { _, _ -> scheduled.cancel(false) } } private fun executeGetCacheFunction( diff --git a/core/src/test/kotlin/com/linecorp/cse/reqshield/KeyGlobalLockTest.kt b/core/src/test/kotlin/com/linecorp/cse/reqshield/KeyGlobalLockTest.kt index 9e9d443..2be0427 100644 --- a/core/src/test/kotlin/com/linecorp/cse/reqshield/KeyGlobalLockTest.kt +++ b/core/src/test/kotlin/com/linecorp/cse/reqshield/KeyGlobalLockTest.kt @@ -26,11 +26,13 @@ import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.Assertions.assertTrue import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.Test +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable import java.time.Duration import java.util.concurrent.Executors import java.util.concurrent.atomic.AtomicInteger import kotlin.test.Ignore +@EnabledIfEnvironmentVariable(named = "RUN_REDIS_IT", matches = "true") class KeyGlobalLockTest : AbstractRedisTest(), BaseKeyLockTest { diff --git a/core/src/test/kotlin/com/linecorp/cse/reqshield/KeyLocalLockTest.kt b/core/src/test/kotlin/com/linecorp/cse/reqshield/KeyLocalLockTest.kt index 237e405..c83dbbe 100644 --- a/core/src/test/kotlin/com/linecorp/cse/reqshield/KeyLocalLockTest.kt +++ b/core/src/test/kotlin/com/linecorp/cse/reqshield/KeyLocalLockTest.kt @@ -194,10 +194,10 @@ class KeyLocalLockTest : BaseKeyLockTest { // Then - Instance2 should not be able to acquire the same lock val lock2Result = instance2.tryLock(key, lockType) - + assertTrue(lock1Result) assertTrue(!lock2Result, "Instance2 should not acquire lock held by Instance1") - + // Cleanup instance1.unLock(key, lockType) instance1.shutdown() @@ -208,7 +208,7 @@ class KeyLocalLockTest : BaseKeyLockTest { fun `should maintain request collapsing across multiple instances`() { // Given val instance1 = KeyLocalLock(lockTimeoutMillis) - val instance2 = KeyLocalLock(lockTimeoutMillis) + val instance2 = KeyLocalLock(lockTimeoutMillis) val instance3 = KeyLocalLock(lockTimeoutMillis) val key = "collapsing-key" val lockType = LockType.CREATE @@ -220,11 +220,12 @@ class KeyLocalLockTest : BaseKeyLockTest { // When - Multiple instances try to acquire the same lock concurrently repeat(3) { index -> executor.submit { - val instance = when (index) { - 0 -> instance1 - 1 -> instance2 - else -> instance3 - } + val instance = + when (index) { + 0 -> instance1 + 1 -> instance2 + else -> instance3 + } attemptCount.incrementAndGet() if (instance.tryLock(key, lockType)) { successCount.incrementAndGet() @@ -241,7 +242,7 @@ class KeyLocalLockTest : BaseKeyLockTest { // Then - Only one should succeed in acquiring the lock assertEquals(3, attemptCount.get()) assertEquals(1, successCount.get(), "Only one instance should acquire the lock") - + // Cleanup instance1.shutdown() instance2.shutdown() @@ -259,11 +260,11 @@ class KeyLocalLockTest : BaseKeyLockTest { // When - Instance1 acquires lock, Instance2 unlocks assertTrue(instance1.tryLock(key, lockType)) instance2.unLock(key, lockType) // Should work even from different instance - + // Then - New lock acquisition should succeed val newLockResult = instance2.tryLock(key, lockType) assertTrue(newLockResult, "Should be able to acquire lock after global unlock") - + // Cleanup instance2.unLock(key, lockType) instance1.shutdown() @@ -305,16 +306,20 @@ class KeyLocalLockTest : BaseKeyLockTest { // Then - Verify thread safety and concurrent operations handling assertEquals(0, errors.get(), "No errors should occur during concurrent operations") - + // Due to sequential nature of ThreadPool(10) and brief work duration (10ms), // multiple operations can succeed on the same key at different times - assertTrue(operations.get() >= 10, - "At least one operation per key should succeed (minimum 10)") - assertTrue(operations.get() <= 50, - "No more operations than total attempts should succeed (maximum 50)") - + assertTrue( + operations.get() >= 10, + "At least one operation per key should succeed (minimum 10)", + ) + assertTrue( + operations.get() <= 50, + "No more operations than total attempts should succeed (maximum 50)", + ) + println("Successful operations: ${operations.get()}/50 total attempts") - + // Cleanup instances.forEach { it.shutdown() } } diff --git a/libs.versions.toml b/libs.versions.toml index 80d2491..48a0376 100644 --- a/libs.versions.toml +++ b/libs.versions.toml @@ -1,6 +1,7 @@ [versions] kotlin = "1.8.20" kotlinCoroutine = "1.7.3" +kotlinCoroutineSpring = "1.6.4" reactor = "3.4.23" spring = "5.3.30" springBoot3 = "3.3.1" @@ -22,6 +23,7 @@ kotlin-coroutine = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-core", v kotlin-coroutine-jvm = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-core-jvm", version.ref = "kotlinCoroutine" } kotlin-coroutine-jdk8 = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-jdk8", version.ref = "kotlinCoroutine" } kotlin-coroutine-reactor = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-reactor", version.ref = "kotlinCoroutine" } +kotlin-coroutine-spring = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-spring", version.ref = "kotlinCoroutineSpring" } kotlin-coroutine-reactive = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-reactive", version.ref = "kotlinCoroutine" } # log diff --git a/req-shield-spring-boot3-example/src/test/kotlin/com/linecorp/cse/reqshield/spring3/mvc/example/service/CacheAnnotationTest.kt b/req-shield-spring-boot3-example/src/test/kotlin/com/linecorp/cse/reqshield/spring3/mvc/example/service/CacheAnnotationTest.kt index 5f90aa7..db77498 100644 --- a/req-shield-spring-boot3-example/src/test/kotlin/com/linecorp/cse/reqshield/spring3/mvc/example/service/CacheAnnotationTest.kt +++ b/req-shield-spring-boot3-example/src/test/kotlin/com/linecorp/cse/reqshield/spring3/mvc/example/service/CacheAnnotationTest.kt @@ -10,6 +10,7 @@ import org.junit.jupiter.api.Assertions.assertNotNull import org.junit.jupiter.api.Assertions.assertNull import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.Test +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable import org.junit.jupiter.api.extension.ExtendWith import org.springframework.beans.factory.annotation.Autowired import org.springframework.boot.test.context.SpringBootTest @@ -20,6 +21,7 @@ import java.util.concurrent.Executors import java.util.concurrent.TimeUnit @SpringBootTest +@EnabledIfEnvironmentVariable(named = "RUN_REDIS_IT", matches = "true") @ExtendWith(SpringExtension::class) class CacheAnnotationTest : AbstractRedisTest() { @Autowired diff --git a/req-shield-spring-boot3-webflux-example/src/test/kotlin/com/linecorp/cse/reqshield/spring3/webflux/example/IntegrationSmokeTest.kt b/req-shield-spring-boot3-webflux-example/src/test/kotlin/com/linecorp/cse/reqshield/spring3/webflux/example/IntegrationSmokeTest.kt new file mode 100644 index 0000000..bc25d07 --- /dev/null +++ b/req-shield-spring-boot3-webflux-example/src/test/kotlin/com/linecorp/cse/reqshield/spring3/webflux/example/IntegrationSmokeTest.kt @@ -0,0 +1,17 @@ +package com.linecorp.cse.reqshield.spring3.webflux.example + +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable +import org.junit.jupiter.api.extension.ExtendWith +import org.springframework.boot.test.context.SpringBootTest +import org.springframework.test.context.junit.jupiter.SpringExtension + +@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT) +@ExtendWith(SpringExtension::class) +@EnabledIfEnvironmentVariable(named = "RUN_REDIS_IT", matches = "true") +class IntegrationSmokeTest { + @Test + fun contextLoads() { + // just ensure context starts with Testcontainers + } +} diff --git a/req-shield-spring-boot3-webflux-example/src/test/kotlin/com/linecorp/cse/reqshield/spring3/webflux/example/service/CacheAnnotationTest.kt b/req-shield-spring-boot3-webflux-example/src/test/kotlin/com/linecorp/cse/reqshield/spring3/webflux/example/service/CacheAnnotationTest.kt index 6421006..1984cb1 100644 --- a/req-shield-spring-boot3-webflux-example/src/test/kotlin/com/linecorp/cse/reqshield/spring3/webflux/example/service/CacheAnnotationTest.kt +++ b/req-shield-spring-boot3-webflux-example/src/test/kotlin/com/linecorp/cse/reqshield/spring3/webflux/example/service/CacheAnnotationTest.kt @@ -7,6 +7,7 @@ import org.awaitility.Awaitility.await import org.junit.jupiter.api.Assertions.* import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.Test +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable import org.junit.jupiter.api.extension.ExtendWith import org.springframework.beans.factory.annotation.Autowired import org.springframework.boot.test.context.SpringBootTest @@ -19,6 +20,7 @@ import java.util.concurrent.TimeUnit import kotlin.test.assertNotNull @SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT) +@EnabledIfEnvironmentVariable(named = "RUN_REDIS_IT", matches = "true") @ExtendWith(SpringExtension::class) class CacheAnnotationTest : AbstractRedisTest() { @Autowired diff --git a/req-shield-spring-boot3-webflux-kotlin-coroutine-example/src/test/kotlin/com/linecorp/cse/reqshield/spring3/webflux/kotlin/coroutine/example/CacheAnnotationTest.kt b/req-shield-spring-boot3-webflux-kotlin-coroutine-example/src/test/kotlin/com/linecorp/cse/reqshield/spring3/webflux/kotlin/coroutine/example/CacheAnnotationTest.kt index b0a29fe..3af5e73 100644 --- a/req-shield-spring-boot3-webflux-kotlin-coroutine-example/src/test/kotlin/com/linecorp/cse/reqshield/spring3/webflux/kotlin/coroutine/example/CacheAnnotationTest.kt +++ b/req-shield-spring-boot3-webflux-kotlin-coroutine-example/src/test/kotlin/com/linecorp/cse/reqshield/spring3/webflux/kotlin/coroutine/example/CacheAnnotationTest.kt @@ -13,6 +13,7 @@ import org.junit.jupiter.api.Assertions import org.junit.jupiter.api.Assertions.* import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.Test +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable import org.junit.jupiter.api.extension.ExtendWith import org.springframework.beans.factory.annotation.Autowired import org.springframework.boot.test.context.SpringBootTest @@ -20,6 +21,7 @@ import org.springframework.test.context.junit.jupiter.SpringExtension import java.util.* @SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT) +@EnabledIfEnvironmentVariable(named = "RUN_REDIS_IT", matches = "true") @ExtendWith(SpringExtension::class) class CacheAnnotationTest : AbstractRedisTest() { @Autowired diff --git a/support/build.gradle.kts b/support/build.gradle.kts index 936dab9..105eda1 100644 --- a/support/build.gradle.kts +++ b/support/build.gradle.kts @@ -23,6 +23,7 @@ plugins { } dependencies { + testFixturesImplementation(rootProject.libs.junit) testFixturesImplementation(rootProject.libs.testcontainers) testFixturesImplementation(rootProject.libs.junit.jupiter.testcontainers) testFixturesImplementation(rootProject.libs.spring.context) diff --git a/support/src/testFixtures/kotlin/com/linecorp/cse/reqshield/support/redis/AbstractRedisTest.kt b/support/src/testFixtures/kotlin/com/linecorp/cse/reqshield/support/redis/AbstractRedisTest.kt index 50ca6e8..da583f3 100644 --- a/support/src/testFixtures/kotlin/com/linecorp/cse/reqshield/support/redis/AbstractRedisTest.kt +++ b/support/src/testFixtures/kotlin/com/linecorp/cse/reqshield/support/redis/AbstractRedisTest.kt @@ -16,6 +16,7 @@ package com.linecorp.cse.reqshield.support.redis +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable import org.springframework.context.ApplicationContextInitializer import org.springframework.context.ConfigurableApplicationContext import org.springframework.core.env.MapPropertySource @@ -24,6 +25,7 @@ import org.testcontainers.junit.jupiter.Container import org.testcontainers.junit.jupiter.Testcontainers @Testcontainers +@EnabledIfEnvironmentVariable(named = "RUN_REDIS_IT", matches = "true") @ContextConfiguration(initializers = [AbstractRedisTest.Companion.Initializer::class]) abstract class AbstractRedisTest { companion object { diff --git a/support/src/testFixtures/kotlin/com/linecorp/cse/reqshield/support/redis/RedisContainer.kt b/support/src/testFixtures/kotlin/com/linecorp/cse/reqshield/support/redis/RedisContainer.kt index 0453baa..5fa6070 100644 --- a/support/src/testFixtures/kotlin/com/linecorp/cse/reqshield/support/redis/RedisContainer.kt +++ b/support/src/testFixtures/kotlin/com/linecorp/cse/reqshield/support/redis/RedisContainer.kt @@ -24,7 +24,8 @@ object RedisContainer { GenericContainer( DockerImageName.parse("redis:6.2.7-alpine"), ).apply { - portBindings = listOf("6379:6379") + // let Testcontainers map to a random local port to reduce conflicts + addExposedPort(6379) withReuse(true) } }