diff --git a/kyo-kernel/shared/src/main/scala/kyo/kernel/ArrowEffect.scala b/kyo-kernel/shared/src/main/scala/kyo/kernel/ArrowEffect.scala index f2894068c..d6ac17158 100644 --- a/kyo-kernel/shared/src/main/scala/kyo/kernel/ArrowEffect.scala +++ b/kyo-kernel/shared/src/main/scala/kyo/kernel/ArrowEffect.scala @@ -442,7 +442,7 @@ object ArrowEffect: handleLoopLoop(Loop.continue(kyo(v, context)), context) end new case kyo => - kyo.asInstanceOf[A] + kyo.unsafeGet case _ => v.asInstanceOf[A < (S & S2)] end handleLoopLoop @@ -536,7 +536,7 @@ object ArrowEffect: handleLoopLoop(Loop.continue(state, kyo(v, context)), context) end new case kyo => - done(state, kyo.asInstanceOf[A]) + done(state, kyo.unsafeGet) end match case _ => v.asInstanceOf[B < (S & S2)] // Loop.done diff --git a/kyo-kernel/shared/src/test/scala/kyo/kernel/ArrowEffectTest.scala b/kyo-kernel/shared/src/test/scala/kyo/kernel/ArrowEffectTest.scala index 9900b3c38..98b39867d 100644 --- a/kyo-kernel/shared/src/test/scala/kyo/kernel/ArrowEffectTest.scala +++ b/kyo-kernel/shared/src/test/scala/kyo/kernel/ArrowEffectTest.scala @@ -337,6 +337,171 @@ class ArrowEffectTest extends Test: } } + "nested effects handling" - { + + given [A, B]: CanEqual[A, B] = CanEqual.derived + + sealed trait NestedTestEffect extends ArrowEffect[Const[Int], Const[Int]] + + def suspendNested(i: Int): Int < NestedTestEffect = + ArrowEffect.suspend[Any](Tag[NestedTestEffect], i) + + val nestedTag: Tag[NestedTestEffect] = Tag[NestedTestEffect] + + def flatten[A, B, C](v: A < B < C): A < (B & C) = v.map(a => a) + + "not handle Nested" - { + + def handle[A, S](v: A < (S & NestedTestEffect)): A < S = + ArrowEffect.handle(nestedTag, v): + [C] => (input, cont) => cont(input * 10) + + "unwraps Nested and returns inner suspension" in { + val comp: Int < NestedTestEffect = suspendNested(5) + val nested: Int < NestedTestEffect < Any = Nested(comp) + val result: Int < NestedTestEffect < Any = handle(nested) + + assert(result == nested, "handleSimple should return the nested computation") + + val flattened = flatten(result) + val finalResult = handle(flattened) + + assert(finalResult.eval == 50) + } + } + + "handleFirst on Nested" - { + + def handle[A, S](v: A < (S & NestedTestEffect)): A < (S & NestedTestEffect) = + ArrowEffect.handleFirst(nestedTag, v)( + [C] => (input, cont) => cont(input * 10), + identity + ) + + "done callback receives unwrapped value" in { + val comp = suspendNested(5) + val nested: Int < NestedTestEffect < Any = Nested(comp) + + val result = handle(nested) + + assert(result == nested, "handleFirst should return the nested computation") + + val flattened = flatten(result) + val finalResult: Int < NestedTestEffect = handle(flattened) + assert(finalResult.evalNow == Maybe(50)) + } + } + + "handleLoop (stateless) on Nested" - { + + def handle[A, S](v: A < (S & NestedTestEffect)): A < S = + ArrowEffect.handleLoop(Tag[NestedTestEffect], v): + [C] => (input, cont) => Loop.continue(cont(input * 10)) + + "unwraps Nested and handles inner suspension" in { + val comp: Int < NestedTestEffect = suspendNested(5) + val nested: Int < NestedTestEffect < Any = Nested(comp) + + val result = handle(nested) + assert(result == nested, "handleLoop should return the nested computation") + + val flattened = flatten(result) + val finalResult: Int < Any = handle(flattened) + + assert(finalResult.eval == 50) + } + } + + "handleLoop (stateful) on Nested" - { + + def handle[A, S](v: A < (S & NestedTestEffect)): A < S = + ArrowEffect.handleLoop(nestedTag, 0, v)( + [C] => (input, state, cont) => Loop.continue(state + 1, cont((input + state) * 10)) + ) + + "unwraps Nested and handles inner suspension" in { + val comp: Int < NestedTestEffect = suspendNested(5) + val nested: Int < NestedTestEffect < Any = Nested(comp) + + val result = handle(nested) + assert(result == nested, "handleLoop should return the nested computation") + + val flattened = flatten(result) + val finalResult: Int < Any = handle(flattened) + + assert(finalResult.eval == 50) + } + + } + + "handleLoop (stateful + done) on Nested" - { + + def handle[A, S](v: A < (S & NestedTestEffect)): A < S = + ArrowEffect.handleLoop(nestedTag, 0, v)( + [C] => (input, state, cont) => Loop.continue(state + 1, cont(input * 10)), + (state, v) => v + ) + + "unwraps Nested and handles inner suspension" in { + val comp: Int < NestedTestEffect = suspendNested(5) + val nested: Int < NestedTestEffect < Any = Nested(comp) + + val result = handle(nested) + assert(result == nested, "handleLoop should return the nested computation") + + val flattened = flatten(result) + val finalResult: Int < Any = handle(flattened) + + assert(finalResult.eval == 50) + } + } + + "handleCatching on Nested" - { + + def handle[A, S](v: A < (S & NestedTestEffect)): A < S = + ArrowEffect.handleCatching(nestedTag, v)( + [C] => (input, cont) => cont(input * 10), + recover = e => throw e + ) + + "unwraps Nested and handles inner suspension" in { + val comp: Int < NestedTestEffect = suspendNested(5) + val nested: Int < NestedTestEffect < Any = Nested(comp) + + val result = handle(nested) + assert(result == nested, "handleLoop should return the nested computation") + + val flattened = flatten(result) + val finalResult: Int < Any = handle(flattened) + + assert(finalResult.eval == 50) + } + } + + "handlePartial on Nested" - { + + def handle[A, S](v: A < (S & NestedTestEffect)): A < (S & NestedTestEffect) = + ArrowEffect.handlePartial(nestedTag, nestedTag, v, Context.empty)( + stop = + false, + [C] => (input, cont) => cont(input * 10), + [C] => (input, cont) => cont(input * 10) + ) + + "unwraps Nested and handles inner suspension" in { + val comp: Int < NestedTestEffect = suspendNested(5) + val nested: Int < NestedTestEffect < Any = Nested(comp) + + val result = handle(nested) + assert(result == nested, "handlePartial should return the nested computation") + + val flattened = flatten(result) + val finalResult = handle(flattened) + assert(finalResult.evalNow == Maybe(50)) + } + } + } + "effects with variance" - { "delimited continuation" - {