diff --git a/kotlinx-coroutines-core/api/kotlinx-coroutines-core.api b/kotlinx-coroutines-core/api/kotlinx-coroutines-core.api index 9b548ac49f..b24e660df9 100644 --- a/kotlinx-coroutines-core/api/kotlinx-coroutines-core.api +++ b/kotlinx-coroutines-core/api/kotlinx-coroutines-core.api @@ -1105,7 +1105,7 @@ public final class kotlinx/coroutines/flow/FlowKt { public static final fun retryWhen (Lkotlinx/coroutines/flow/Flow;Lkotlin/jvm/functions/Function4;)Lkotlinx/coroutines/flow/Flow; public static final fun runningFold (Lkotlinx/coroutines/flow/Flow;Ljava/lang/Object;Lkotlin/jvm/functions/Function3;)Lkotlinx/coroutines/flow/Flow; public static final fun runningReduce (Lkotlinx/coroutines/flow/Flow;Lkotlin/jvm/functions/Function3;)Lkotlinx/coroutines/flow/Flow; - public static final fun sample (Lkotlinx/coroutines/flow/Flow;J)Lkotlinx/coroutines/flow/Flow; + public static final fun sample (Lkotlinx/coroutines/flow/Flow;Lkotlinx/coroutines/flow/Flow;)Lkotlinx/coroutines/flow/Flow; public static final fun sample-HG0u8IE (Lkotlinx/coroutines/flow/Flow;J)Lkotlinx/coroutines/flow/Flow; public static final fun scan (Lkotlinx/coroutines/flow/Flow;Ljava/lang/Object;Lkotlin/jvm/functions/Function3;)Lkotlinx/coroutines/flow/Flow; public static final fun scanFold (Lkotlinx/coroutines/flow/Flow;Ljava/lang/Object;Lkotlin/jvm/functions/Function3;)Lkotlinx/coroutines/flow/Flow; diff --git a/kotlinx-coroutines-core/common/src/flow/operators/Delay.kt b/kotlinx-coroutines-core/common/src/flow/operators/Delay.kt index 37505dc162..4a241c019e 100644 --- a/kotlinx-coroutines-core/common/src/flow/operators/Delay.kt +++ b/kotlinx-coroutines-core/common/src/flow/operators/Delay.kt @@ -203,7 +203,7 @@ public fun Flow.debounce(timeout: (T) -> Duration): Flow = timeout(emittedItem).toDelayMillis() } -private fun Flow.debounceInternal(timeoutMillisSelector: (T) -> Long) : Flow = +private fun Flow.debounceInternal(timeoutMillisSelector: (T) -> Long): Flow = scopedFlow { downstream -> // Produce the values using the default (rendezvous) channel val values = produce { @@ -273,13 +273,49 @@ private fun Flow.debounceInternal(timeoutMillisSelector: (T) -> Long) : F */ @FlowPreview public fun Flow.sample(periodMillis: Long): Flow { - require(periodMillis > 0) { "Sample period should be positive" } + return sample(flow { + delay(periodMillis) + while (true) { + emit(Unit) + delay(periodMillis) + } + }) +} + +/** + * Returns a flow that emits only the latest value emitted by the original flow only when the [sampler] emits. + * + * Example: + * ``` + * flow { + * repeat(10) { + * emit(it) + * delay(50) + * } + * }.sampleBy(flow { + * repeat(10) { + * delay(100) + * emit(it) + * } + * }) + * + * ``` + * produces `0, 2, 4, 6, 8`. + * + * Note that the latest element is not emitted if it does not fit into the sampling window. + */ +public fun Flow.sample(sampler: Flow): Flow { return scopedFlow { downstream -> val values = produce(capacity = Channel.CONFLATED) { collect { value -> send(value ?: NULL) } } + + val samplerProducer = produce(capacity = 0) { + sampler.collect { value -> + send(value) + } + } var lastValue: Any? = null - val ticker = fixedPeriodTicker(periodMillis) while (lastValue !== DONE) { select { values.onReceiveCatching { result -> @@ -287,16 +323,26 @@ public fun Flow.sample(periodMillis: Long): Flow { .onSuccess { lastValue = it } .onFailure { it?.let { throw it } - ticker.cancel(ChildCancelledException()) + samplerProducer.cancel(ChildCancelledException()) lastValue = DONE } } - // todo: shall be start sampling only when an element arrives or sample aways as here? - ticker.onReceive { - val value = lastValue ?: return@onReceive - lastValue = null // Consume the value - downstream.emit(NULL.unbox(value)) + samplerProducer.onReceiveCatching { samplerResult -> + samplerResult + .onSuccess { sampledValue -> + if (sampledValue != null) { + val value = lastValue ?: return@onSuccess + lastValue = null // Consume the value + downstream.emit(NULL.unbox(value)) + } else { + lastValue = DONE + } + } + .onFailure { + lastValue = DONE + } + } } } @@ -306,7 +352,10 @@ public fun Flow.sample(periodMillis: Long): Flow { /* * TODO this design (and design of the corresponding operator) depends on #540 */ -internal fun CoroutineScope.fixedPeriodTicker(delayMillis: Long, initialDelayMillis: Long = delayMillis): ReceiveChannel { +internal fun CoroutineScope.fixedPeriodTicker( + delayMillis: Long, + initialDelayMillis: Long = delayMillis +): ReceiveChannel { require(delayMillis >= 0) { "Expected non-negative delay, but has $delayMillis ms" } require(initialDelayMillis >= 0) { "Expected non-negative initial delay, but has $initialDelayMillis ms" } return produce(capacity = 0) { diff --git a/kotlinx-coroutines-core/common/test/flow/operators/SampleTest.kt b/kotlinx-coroutines-core/common/test/flow/operators/SampleTest.kt index 3c04abdd99..bcdba46ee8 100644 --- a/kotlinx-coroutines-core/common/test/flow/operators/SampleTest.kt +++ b/kotlinx-coroutines-core/common/test/flow/operators/SampleTest.kt @@ -5,34 +5,46 @@ package kotlinx.coroutines.flow.operators import kotlinx.coroutines.* -import kotlinx.coroutines.channels.* +import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.flow.* -import kotlin.test.* -import kotlin.time.* +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertNull import kotlin.time.Duration.Companion.milliseconds class SampleTest : TestBase() { @Test - public fun testBasic() = withVirtualTime { + fun testBasic() = withVirtualTime { expect(1) val flow = flow { expect(3) + delay(200) emit("A") - delay(1500) + expect(4) + delay(600) emit("B") - delay(500) + delay(200) emit("C") - delay(250) + delay(200) + expect(6) emit("D") - delay(2000) - emit("E") - expect(4) + delay(10000) + expect(7) } + + val samplerFlow = flow { + delay(1000) + expect(5) + emit("A1") + delay(1000) + emit("B1") + } expect(2) - val result = flow.sample(1000).toList() - assertEquals(listOf("A", "B", "D"), result) - finish(5) + + val result = flow.sample(samplerFlow).toList() + assertEquals(listOf("B", "D"), result) + finish(8) } @Test @@ -40,15 +52,21 @@ class SampleTest : TestBase() { val flow = flow { delay(60) emit(1) - delay(60) expect(1) - }.sample(100) + delay(60) + expect(3) + }.sample(flow { + delay(100) + emit(4) + expect(2) + }) assertEquals(1, flow.singleOrNull()) - finish(2) + finish(4) } + @Test - fun testBasic2() = withVirtualTime { + fun testBasicFlow2() = withVirtualTime { expect(1) val flow = flow { expect(3) @@ -66,9 +84,15 @@ class SampleTest : TestBase() { delay(501) expect(4) } - + val samplerFlow = flow { + delay(500) + repeat(10) { + emit(1) + delay(500) + } + } expect(2) - val result = flow.sample(500).toList() + val result = flow.sample(samplerFlow).toList() assertEquals(listOf(2, 6, 7), result) finish(5) } @@ -77,12 +101,17 @@ class SampleTest : TestBase() { fun testFixedDelay() = withVirtualTime { val flow = flow { emit("A") + expect(1) delay(150) emit("B") - expect(1) - }.sample(100) + expect(3) + }.sample(flow { + delay(100) + emit("A") + expect(2) + }) assertEquals("A", flow.single()) - finish(2) + finish(4) } @Test @@ -91,7 +120,10 @@ class SampleTest : TestBase() { emit(null) delay(2) expect(1) - }.sample(1) + }.sample(flow { + delay(1) + emit(1) + }) assertNull(flow.single()) finish(2) } @@ -114,20 +146,33 @@ class SampleTest : TestBase() { } expect(2) - val result = flow.sample(1000).toList() + val sampler = flow { + delay(1000) + repeat(10) { + emit(1) + delay(1000) + } + } + val result = flow.sample(sampler).toList() assertEquals(listOf("A", null, null), result) finish(5) } @Test fun testEmpty() = runTest { - val flow = emptyFlow().sample(Long.MAX_VALUE) + val flow = emptyFlow().sample(flow { + delay(Long.MAX_VALUE) + emit(1) + }) assertNull(flow.singleOrNull()) } @Test fun testScalar() = runTest { - val flow = flowOf(1, 2, 3).sample(Long.MAX_VALUE) + val flow = flowOf(1, 2, 3).sample(flow { + delay(Long.MAX_VALUE) + emit(1) + }) assertNull(flow.singleOrNull()) } @@ -145,7 +190,13 @@ class SampleTest : TestBase() { delay(3000) // long wait again } - val result = flow.sample(1000).toList() + val result = flow.sample(flow { + repeat(10) { + delay(1000) + emit(1) + } + }).toList() + assertEquals(listOf("A", "B", "C"), result) finish(3) } @@ -164,32 +215,38 @@ class SampleTest : TestBase() { delay(100) } expect(2) - }.sample(100) + }.sample(flow { + repeat(10) { + delay(100) + emit(1) + } + }) assertEquals(listOf(-1, -3, 0, 1, 2, 3), flow.toList()) finish(3) - } - @Test - fun testUpstreamError() = testUpstreamError(TestException()) + } @Test - fun testUpstreamErrorCancellationException() = testUpstreamError(CancellationException("")) - - private inline fun testUpstreamError(cause: T) = runTest { + fun testUpstreamError() = runTest { val latch = Channel() val flow = flow { expect(1) emit(1) expect(2) latch.receive() - throw cause - }.sample(1).map { + throw TestException() + }.sample(flow { + repeat(10) { + delay(1) + emit(1) + } + }).map { latch.send(Unit) hang { expect(3) } } - assertFailsWith(flow) + assertFailsWith(flow) finish(4) } @@ -203,7 +260,12 @@ class SampleTest : TestBase() { expect(2) latch.receive() throw TestException() - }.flowOn(NamedDispatchers("upstream")).sample(1).map { + }.flowOn(NamedDispatchers("upstream")).sample(flow { + repeat(10) { + delay(1) + emit(1) + } + }).map { latch.send(Unit) hang { expect(3) } } @@ -219,7 +281,10 @@ class SampleTest : TestBase() { emit(1) expect(2) throw TestException() - }.sample(Long.MAX_VALUE).map { + }.sample(flow { + delay(Long.MAX_VALUE) + emit(1) + }).map { expectUnreached() } assertFailsWith(flow) @@ -233,9 +298,13 @@ class SampleTest : TestBase() { emit(1) expect(2) throw TestException() - }.flowOn(NamedDispatchers("unused")).sample(Long.MAX_VALUE).map { - expectUnreached() - } + }.flowOn(NamedDispatchers("unused")).sample(flow { + delay(Long.MAX_VALUE) + emit(1) + }).map { + expectUnreached() + } + assertFailsWith(flow) finish(3) @@ -247,7 +316,12 @@ class SampleTest : TestBase() { expect(1) emit(1) hang { expect(3) } - }.sample(100).map { + }.sample(flow { + repeat(100) { + delay(1) + emit(1) + } + }).map { expect(2) yield() throw TestException() @@ -264,7 +338,12 @@ class SampleTest : TestBase() { expect(1) emit(1) hang { expect(3) } - }.flowOn(NamedDispatchers("upstream")).sample(100).map { + }.flowOn(NamedDispatchers("upstream")).sample(flow { + repeat(100) { + delay(1) + emit(1) + } + }).map { expect(2) yield() throw TestException()