diff --git a/.scalafmt.conf b/.scalafmt.conf index aa47f2b4b..b267c1186 100644 --- a/.scalafmt.conf +++ b/.scalafmt.conf @@ -32,6 +32,9 @@ fileOverride { "glob:**/kyo-stats-otel/**" { runner.dialect = scala212source3 } + "glob:**/kyo-grpc-code-gen/**" { + runner.dialect = scala212source3 + } "glob:**/scala-3/**" { runner.dialect = scala3 } diff --git a/build.sbt b/build.sbt index 460f8a14d..046eaf3e9 100644 --- a/build.sbt +++ b/build.sbt @@ -3,11 +3,13 @@ import org.scalajs.jsenv.nodejs.* import org.typelevel.scalacoptions.ScalacOption import org.typelevel.scalacoptions.ScalacOptions import org.typelevel.scalacoptions.ScalaVersion +import protocbridge.Target import sbtdynver.DynVerPlugin.autoImport.* val scala3Version = "3.8.1" val scala3LTSVersion = "3.3.7" val scala213Version = "2.13.18" +val scala212Version = "2.12.20" val zioVersion = "2.1.24" val catsVersion = "3.6.3" @@ -56,7 +58,7 @@ lazy val `kyo-settings` = Seq( crossScalaVersions := List(scala3Version), scalacOptions ++= scalacOptionTokens(compilerOptions).value, Test / scalacOptions --= scalacOptionTokens(Set(ScalacOptions.warnNonUnitStatement)).value, - scalafmtOnCompile := true, + scalafmtOnCompile := false, scalacOptions += compilerOptionFailDiscard, Test / testOptions += Tests.Argument("-oDG"), ThisBuild / versionScheme := Some("early-semver"), @@ -125,7 +127,8 @@ lazy val kyoJVM = project `kyo-combinators`.jvm, `kyo-playwright`.jvm, `kyo-examples`.jvm, - `kyo-actor`.jvm + `kyo-actor`.jvm, + `kyo-grpc`.jvm ) lazy val kyoJS = project @@ -150,7 +153,8 @@ lazy val kyoJS = project `kyo-zio`.js, `kyo-cats`.js, `kyo-combinators`.js, - `kyo-actor`.js + `kyo-actor`.js, + `kyo-grpc`.js ) lazy val kyoNative = project @@ -596,6 +600,124 @@ lazy val `kyo-cats` = ) .jvmSettings(mimaCheck(false)) +lazy val `kyo-grpc` = + crossProject(JVMPlatform, JSPlatform) + .withoutSuffixFor(JVMPlatform) + .in(file("kyo-grpc")) + .settings( + crossScalaVersions := Seq.empty, + publishArtifact := false, + publish := {}, + 
publishLocal := {} + ) + .aggregate( + `kyo-grpc-core`, + `kyo-grpc-code-gen`, + `kyo-grpc-e2e` + ) + +lazy val `kyo-grpc-jvm` = + `kyo-grpc` + .jvm + .aggregate(`kyo-grpc-protoc-gen`.componentProjects.map(p => p: ProjectReference) *) + +lazy val `kyo-grpc-core` = + crossProject(JVMPlatform, JSPlatform) + .withoutSuffixFor(JVMPlatform) + .crossType(CrossType.Full) + .in(file("kyo-grpc-core")) + .dependsOn(`kyo-core`) + .settings(`kyo-settings`) + .settings( + libraryDependencies += "org.scalamock" %% "scalamock" % "7.5.0" % Test + ) + .jvmSettings( + libraryDependencies ++= Seq( + "com.thesamet.scalapb" %% "scalapb-runtime-grpc" % scalapb.compiler.Version.scalapbVersion, + "io.grpc" % "grpc-api" % "1.72.0", + // It is a little unusual to include this here but it greatly reduces the amount of generated code. + "io.grpc" % "grpc-stub" % "1.72.0", + "ch.qos.logback" % "logback-classic" % "1.5.18" % Test + ) + ).jsSettings( + `js-settings`, + libraryDependencies ++= Seq( + "com.thesamet.scalapb.grpcweb" %%% "scalapb-grpcweb" % "0.7.0") + ) + +lazy val `kyo-grpc-code-gen` = + crossProject(JVMPlatform, JSPlatform) + .withoutSuffixFor(JVMPlatform) + .crossType(CrossType.Full) + .in(file("kyo-grpc-code-gen")) + .enablePlugins(BuildInfoPlugin) + .settings( + `kyo-settings`, + buildInfoKeys := Seq[BuildInfoKey](name, organization, version, scalaVersion, sbtVersion), + buildInfoPackage := "kyo.grpc.compiler", + crossScalaVersions := List(scala213Version, scala3Version), + scalacOptions ++= scalacOptionToken(ScalacOptions.source3).value, + libraryDependencies ++= Seq( + "com.thesamet.scalapb" %% "compilerplugin" % scalapb.compiler.Version.scalapbVersion, + "org.scala-lang.modules" %%% "scala-collection-compat" % "2.12.0", + "org.typelevel" %%% "paiges-core" % "0.4.3" + ) + ).jsSettings( + `js-settings` + ) + +lazy val `kyo-grpc-code-gen_2.12` = + `kyo-grpc-code-gen` + .jvm + .settings(scalaVersion := scala212Version) + +lazy val `kyo-grpc-code-genJS_2.12` = + 
`kyo-grpc-code-gen` + .js + .settings(scalaVersion := scala212Version) + +lazy val `kyo-grpc-protoc-gen` = + protocGenProject("kyo-grpc-protoc-gen", `kyo-grpc-code-gen_2.12`) + .settings( + `kyo-settings`, + scalaVersion := scala212Version, + crossScalaVersions := Seq(scala212Version), + Compile / mainClass := Some("kyo.grpc.compiler.CodeGenerator") + ) + .aggregateProjectSettings( + scalaVersion := scala212Version, + crossScalaVersions := Seq(scala212Version) + ) + +lazy val `kyo-grpc-e2e` = + crossProject(JVMPlatform, JSPlatform) + .withoutSuffixFor(JVMPlatform) + .crossType(CrossType.Full) + .in(file("kyo-grpc-e2e")) + .enablePlugins(LocalCodeGenPlugin) + .dependsOn(`kyo-grpc-core` % "compile->compile;test->test") + .settings( + `kyo-settings`, + publish / skip := true, + Compile / PB.protoSources += (ThisBuild / baseDirectory).value / "kyo-grpc-e2e/shared/src/main/protobuf", + Compile / PB.targets := Seq( + scalapb.gen() -> (Compile / sourceManaged).value / "scalapb", + // Users of the plugin can use: kyo.grpc.gen() -> (Compile / sourceManaged).value / "scalapb" + genModule("kyo.grpc.compiler.CodeGenerator$") -> (Compile / sourceManaged).value / "scalapb" + ), + Compile / scalacOptions ++= scalacOptionToken(ScalacOptions.warnOption("conf:src=.*/src_managed/main/scalapb/kgrpc/.*:silent")).value + ).jvmSettings( + codeGenClasspath := (`kyo-grpc-code-gen_2.12` / Compile / fullClasspath).value, + libraryDependencies ++= Seq( + "io.grpc" % "grpc-netty-shaded" % "1.72.0", + "ch.qos.logback" % "logback-classic" % "1.5.18" % Test + ) + ).jsSettings( + `js-settings`, + codeGenClasspath := (`kyo-grpc-code-genJS_2.12` / Compile / fullClasspath).value, + libraryDependencies += "com.thesamet.scalapb.grpcweb" %%% "scalapb-grpcweb" % "0.7.0" + ) + lazy val `kyo-combinators` = crossProject(JSPlatform, JVMPlatform, NativePlatform) .withoutSuffixFor(JVMPlatform) @@ -647,17 +769,52 @@ lazy val `kyo-bench` = .withoutSuffixFor(JVMPlatform) .crossType(CrossType.Pure) 
.in(file("kyo-bench")) - .enablePlugins(JmhPlugin) - .dependsOn(`kyo-core`) - .dependsOn(`kyo-parse`) - .dependsOn(`kyo-sttp`) - .dependsOn(`kyo-stm`) - .dependsOn(`kyo-direct`) - .dependsOn(`kyo-scheduler-zio`) - .dependsOn(`kyo-scheduler-cats`) + .enablePlugins(Fs2Grpc, JmhPlugin, LocalCodeGenPlugin) .disablePlugins(MimaPlugin) + .dependsOn( + `kyo-core`, + `kyo-direct`, + `kyo-grpc-core`, + `kyo-parse`, + `kyo-scheduler-cats`, + `kyo-scheduler-zio`, + `kyo-stm`, + `kyo-sttp` + ) .settings( `kyo-settings`, + publish / skip := true, + Compile / PB.protoSources += baseDirectory.value.getParentFile / "src" / "main" / "protobuf", + Compile / PB.targets := { + val scalapbDir = (Compile / sourceManaged).value / "scalapb" + // This includes the base scalapb.gen. + val catsGen = Fs2GrpcPlugin.autoImport.scalapbCodeGenerators.value + catsGen ++ Seq[Target]( + scalapb.gen(scala3Sources = true) -> scalapbDir / "vanilla", + scalapb.zio_grpc.ZioCodeGenerator -> scalapbDir, + genModule("kyo.grpc.compiler.CodeGenerator$") -> scalapbDir + ) + }, + Compile / PB.generate ~= { files => + files.filter(_.isFile).filter(_.getPath.contains("/vanilla/")).foreach { file => + val fileContent = IO.read(file) + val updatedContent = fileContent + // Workaround for https://github.com/scalapb/ScalaPB/issues/1816. + .replace( + "_unknownFields__.parseField(tag, _input__)", + "_unknownFields__.parseField(tag, _input__): Unit" + ) + // Hacky workaround to not get a collision with the one generated by Kyo and ZIO. 
+ .replace( + "kgrpc.bench", + "vanilla.kgrpc.bench", + ) + IO.write(file, updatedContent) + } + files + }, + codeGenClasspath := (`kyo-grpc-code-gen_2.12` / Compile / fullClasspath).value, + Compile / scalacOptions ++= scalacOptionToken(ScalacOptions.warnOption("conf:src=.*/src_managed/main/scalapb/kgrpc/.*:silent")).value, Test / testForkedParallel := true, // Forks each test suite individually Test / testGrouping := { @@ -681,6 +838,7 @@ lazy val `kyo-bench` = ) } }, + libraryDependencies += "io.grpc" % "grpc-netty-shaded" % "1.72.0", libraryDependencies += "dev.zio" %% "izumi-reflect" % "3.0.9", libraryDependencies += "org.typelevel" %% "cats-effect" % catsVersion, libraryDependencies += "org.typelevel" %% "log4cats-core" % "2.7.1", @@ -713,6 +871,7 @@ lazy val readme = .crossType(CrossType.Full) .in(file("target/readme")) .enablePlugins(MdocPlugin) + .disablePlugins(ProtocPlugin) .settings( `kyo-settings`, mdocIn := new File("./../../README-in.md"), @@ -773,6 +932,10 @@ def mimaCheck(failOnProblem: Boolean) = mimaFailOnProblem := failOnProblem ) +def sharedSourceDir(conf: String) = Def.setting { + CrossType.Full.sharedSrcDir(baseDirectory.value, conf).get.getParentFile +} + // --- Scalafix lazy val V = _root_.scalafix.sbt.BuildInfo diff --git a/kyo-bench/src/main/protobuf/bench.proto b/kyo-bench/src/main/protobuf/bench.proto new file mode 100644 index 000000000..a9cb34220 --- /dev/null +++ b/kyo-bench/src/main/protobuf/bench.proto @@ -0,0 +1,19 @@ +syntax = "proto3"; + +// Don't use kyo here because otherwise it cannot derive the Frame. 
+package kgrpc; + +service TestService { + rpc OneToOne(Request) returns (Response); + rpc OneToMany(Request) returns (stream Response); + rpc ManyToOne(stream Request) returns (Response); + rpc ManyToMany(stream Request) returns (stream Response); +} + +message Request { + string message = 1; +} + +message Response { + string message = 1; +} diff --git a/kyo-bench/src/main/scala/kyo/bench/arena/ArenaBench2.scala b/kyo-bench/src/main/scala/kyo/bench/arena/ArenaBench2.scala new file mode 100644 index 000000000..2add95a4b --- /dev/null +++ b/kyo-bench/src/main/scala/kyo/bench/arena/ArenaBench2.scala @@ -0,0 +1,80 @@ +package kyo.bench.arena + +import kyo.bench.BaseBench +import org.openjdk.jmh.annotations.* +import scala.compiletime.uninitialized + +// TODO: What to call this? +abstract class ArenaBench2[A](val expectedResult: A) extends BaseBench: + + def forkCats(catsBench: cats.effect.IO[A])(using cats.effect.unsafe.IORuntime): A = + cats.effect.IO.cede.flatMap(_ => catsBench).unsafeRunSync() + + def forkKyo(kyoBenchFiber: kyo.<[A, kyo.Async & kyo.Abort[Throwable] & kyo.Scope]): A = + import kyo.* + import AllowUnsafe.embrace.danger + given Frame = Frame.internal + kyoBenchFiber.handle( + kyo.Scope.run, + Fiber.initUnscoped(_), + _.map(_.block(Duration.Infinity)), + Sync.Unsafe.evalOrThrow + ).getOrThrow + end forkKyo + + def forkZIO(zioBench: zio.Task[A])(using zioRuntime: zio.Runtime[Any]): A = zio.Unsafe.unsafe(implicit u => + zioRuntime.unsafe.run(zio.ZIO.yieldNow.flatMap(_ => zioBench)).getOrThrow() + ) + +end ArenaBench2 + +object ArenaBench2: + + @State(Scope.Benchmark) + class CatsRuntime: + + var ioRuntime: cats.effect.unsafe.IORuntime = uninitialized + given cats.effect.unsafe.IORuntime = ioRuntime + + @Setup + def setup() = + ioRuntime = + if System.getProperty("replaceCatsExecutor", "false") == "true" then + kyo.KyoSchedulerIORuntime.global + else + cats.effect.unsafe.implicits.global + end setup + + end CatsRuntime + + @State(Scope.Benchmark) + class 
ZIORuntime: + + var zioRuntime: zio.Runtime[Any] = uninitialized + + private var finalizer: () => Unit = () => () + + @Setup + def setup(): Unit = + val zioRuntimeLayer: zio.ZLayer[Any, Any, Any] = + if System.getProperty("replaceZIOExecutor", "false") == "true" then + kyo.KyoSchedulerZIORuntime.layer + else + zio.ZLayer.empty + + zioRuntime = + if zioRuntimeLayer ne zio.ZLayer.empty then + val (runtime, finalizer) = ZIORuntime.fromLayerWithFinalizer(zioRuntimeLayer) + this.finalizer = finalizer + runtime + else + zio.Runtime.default + end setup + + @TearDown + def tearDown(): Unit = + finalizer() + + end ZIORuntime + +end ArenaBench2 diff --git a/kyo-bench/src/main/scala/kyo/bench/arena/grpc/GrpcClientBench.scala b/kyo-bench/src/main/scala/kyo/bench/arena/grpc/GrpcClientBench.scala new file mode 100644 index 000000000..137a49eb4 --- /dev/null +++ b/kyo-bench/src/main/scala/kyo/bench/arena/grpc/GrpcClientBench.scala @@ -0,0 +1,164 @@ +package kyo.bench.arena.grpc + +import GrpcClientBench.* +import GrpcService.* +import io.grpc.* +import io.grpc.stub.StreamObserver +import java.util.concurrent.TimeoutException +import java.util.concurrent.TimeUnit +import kgrpc.bench.TestServiceFs2Grpc +import kgrpc.bench.ZioBench +import kyo.* +import kyo.Scope +import kyo.bench.arena.ArenaBench2.* +import kyo.grpc.Grpc +import kyo.kernel.ContextEffect +import org.openjdk.jmh.annotations.* +import org.openjdk.jmh.annotations.Scope as JmhScope +import scala.compiletime.uninitialized +import scala.concurrent.Future +import scalapb.zio_grpc +import vanilla.kgrpc.bench.* +import vanilla.kgrpc.bench.TestServiceGrpc.* +import zio.ZIO + +object GrpcClientBench: + + private val executionContext = kyo.scheduler.Scheduler.get.asExecutionContext + + class TestServiceImpl extends TestService: + + private val response = Response("response") + private val responses = Chunk.fill(size)(response) + + def oneToOne(request: Request): Future[Response] = + Future.successful(response) + + def 
oneToMany(request: Request, responseObserver: StreamObserver[Response]): Unit = + responses.foreach(responseObserver.onNext) + responseObserver.onCompleted() + end oneToMany + + def manyToOne(responseObserver: StreamObserver[Response]): StreamObserver[Request] = + new StreamObserver[Request]: + def onNext(request: Request): Unit = () + + def onError(t: Throwable): Unit = + responseObserver.onError(t) + + def onCompleted(): Unit = + responseObserver.onNext(response) + responseObserver.onCompleted() + end onCompleted + + def manyToMany(responseObserver: StreamObserver[Response]): StreamObserver[Request] = + new StreamObserver[Request]: + def onNext(request: Request): Unit = + responses.foreach(responseObserver.onNext) + + def onError(t: Throwable): Unit = + responseObserver.onError(t) + + def onCompleted(): Unit = + responseObserver.onCompleted() + + end TestServiceImpl + + @State(JmhScope.Benchmark) + abstract class BaseState: + + protected var port: Int = uninitialized + var server: Server = uninitialized + protected var finalizers: List[() => Unit] = Nil + + protected def addFinalizer(finalizer: => Unit): Unit = + finalizers = (() => finalizer) :: finalizers + + def setup(): Unit = + port = findFreePort() + server = ServerBuilder + .forPort(port) + .addService(TestService.bindService(new TestServiceImpl, executionContext)) + .build + .start + end setup + + @TearDown + def tearDown(): Unit = + server.shutdownNow() + finalizers.foreach(_()) + val isShutdown = server.awaitTermination(10, TimeUnit.SECONDS) + if !isShutdown then throw TimeoutException("Server did not shutdown within 10 seconds.") + end tearDown + + end BaseState + + @State(JmhScope.Benchmark) + class CatsState extends BaseState: + + var ioRuntime: cats.effect.unsafe.IORuntime = uninitialized + given cats.effect.unsafe.IORuntime = ioRuntime + var client: TestServiceFs2Grpc[cats.effect.IO, Metadata] = uninitialized + + @Setup + def setup(runtime: CatsRuntime): Unit = + super.setup() + + ioRuntime = 
runtime.ioRuntime + + val (client, finalizer) = createCatsClient(port).allocated.unsafeRunSync() + this.client = client + + addFinalizer: + finalizer.unsafeRunSync() + end setup + + end CatsState + + @State(JmhScope.Benchmark) + class KyoState extends BaseState: + + var client: kgrpc.bench.TestService.Client = uninitialized + + @Setup + override def setup(): Unit = + super.setup() + + import AllowUnsafe.embrace.danger + + val finalizer = Scope.Finalizer.Awaitable.Unsafe.init(1) + val clientEffect = createKyoClient(port) + val clientResult = ContextEffect.handle(Tag[Scope], finalizer, _ => finalizer)(clientEffect) + + client = Abort.run(Sync.Unsafe.run(clientResult)).eval.getOrThrow + + addFinalizer: + Abort.run(Sync.Unsafe.run(finalizer.close(Absent))).eval.getOrThrow + end setup + + end KyoState + + @State(JmhScope.Benchmark) + class ZIOState extends BaseState: + + var zioRuntime: zio.Runtime[Any] = uninitialized + given zio.Runtime[Any] = zioRuntime + var client: ZioBench.TestServiceClient = uninitialized + + @Setup + def setup(runtime: ZIORuntime): Unit = + super.setup() + + zioRuntime = runtime.zioRuntime + given zio.Unsafe = zio.Unsafe + + val clientScope = zio.Scope.unsafe.make + val zioClientEffect = clientScope.extend(createZioClient(port)) + client = zioRuntime.unsafe.run(zioClientEffect).getOrThrow() + addFinalizer: + zioRuntime.unsafe.run(clientScope.close(zio.Exit.unit)).getOrThrow() + end setup + + end ZIOState + +end GrpcClientBench diff --git a/kyo-bench/src/main/scala/kyo/bench/arena/grpc/GrpcClientManyToManyBench.scala b/kyo-bench/src/main/scala/kyo/bench/arena/grpc/GrpcClientManyToManyBench.scala new file mode 100644 index 000000000..020e19681 --- /dev/null +++ b/kyo-bench/src/main/scala/kyo/bench/arena/grpc/GrpcClientManyToManyBench.scala @@ -0,0 +1,48 @@ +package kyo.bench.arena.grpc + +import GrpcClientBench.* +import GrpcService.* +import io.grpc.* +import java.util.concurrent.TimeoutException +import java.util.concurrent.TimeUnit +import 
kgrpc.bench.* +import kgrpc.bench.TestServiceGrpc.* +import kyo.* +import kyo.Scope +import kyo.bench.arena.ArenaBench2 +import kyo.bench.arena.ArenaBench2.* +import kyo.bench.arena.WarmupJITProfile.CatsForkWarmup +import kyo.bench.arena.WarmupJITProfile.KyoForkWarmup +import kyo.bench.arena.WarmupJITProfile.ZIOForkWarmup +import kyo.grpc.Grpc +import org.openjdk.jmh.annotations.* +import scala.compiletime.uninitialized +import scalapb.zio_grpc +import zio.ZIO + +class GrpcClientManyToManyBench extends ArenaBench2[Long](sizeSquared): + + @Benchmark + def catsBench(warmup: CatsForkWarmup, state: CatsState): Long = + import state.{*, given} + forkCats: + client.manyToMany(fs2.Stream.emits(requests), Metadata()).compile.count + end catsBench + + @Benchmark + def kyoBench(warmup: KyoForkWarmup, state: KyoState): Long = + import state.* + forkKyo: + Env.run(Metadata()): + // TODO: Can we avoid the lift here? + client.manyToMany(Kyo.lift(Stream.init(requests))).into(Sink.count.map(_.toLong)) + end kyoBench + + @Benchmark + def zioBench(warmup: ZIOForkWarmup, state: ZIOState): Long = + import state.{*, given} + forkZIO: + client.manyToMany(zio.stream.ZStream.fromIterable(requests)).runCount + end zioBench + +end GrpcClientManyToManyBench diff --git a/kyo-bench/src/main/scala/kyo/bench/arena/grpc/GrpcClientManyToOneBench.scala b/kyo-bench/src/main/scala/kyo/bench/arena/grpc/GrpcClientManyToOneBench.scala new file mode 100644 index 000000000..5a7b9385d --- /dev/null +++ b/kyo-bench/src/main/scala/kyo/bench/arena/grpc/GrpcClientManyToOneBench.scala @@ -0,0 +1,47 @@ +package kyo.bench.arena.grpc + +import GrpcClientBench.* +import GrpcService.* +import io.grpc.* +import java.util.concurrent.TimeoutException +import java.util.concurrent.TimeUnit +import kgrpc.bench.* +import kgrpc.bench.TestServiceGrpc.* +import kyo.* +import kyo.bench.arena.ArenaBench2 +import kyo.bench.arena.ArenaBench2.* +import kyo.bench.arena.WarmupJITProfile.CatsForkWarmup +import 
kyo.bench.arena.WarmupJITProfile.KyoForkWarmup +import kyo.bench.arena.WarmupJITProfile.ZIOForkWarmup +import kyo.grpc.Grpc +import org.openjdk.jmh.annotations.* +import scala.compiletime.uninitialized +import scalapb.zio_grpc +import zio.ZIO + +class GrpcClientManyToOneBench extends ArenaBench2(response): + + @Benchmark + def catsBench(warmup: CatsForkWarmup, state: CatsState): Response = + import state.{*, given} + forkCats: + client.manyToOne(fs2.Stream.emits(requests), Metadata()) + end catsBench + + @Benchmark + def kyoBench(warmup: KyoForkWarmup, state: KyoState): Response = + import state.* + forkKyo: + Env.run(Metadata()): + // TODO: Can we avoid the lift here? + client.manyToOne(Kyo.lift(Stream.init(requests))) + end kyoBench + + @Benchmark + def zioBench(warmup: ZIOForkWarmup, state: ZIOState): Response = + import state.{*, given} + forkZIO: + client.manyToOne(zio.stream.ZStream.fromIterable(requests)) + end zioBench + +end GrpcClientManyToOneBench diff --git a/kyo-bench/src/main/scala/kyo/bench/arena/grpc/GrpcClientOneToManyBench.scala b/kyo-bench/src/main/scala/kyo/bench/arena/grpc/GrpcClientOneToManyBench.scala new file mode 100644 index 000000000..85f53ba4d --- /dev/null +++ b/kyo-bench/src/main/scala/kyo/bench/arena/grpc/GrpcClientOneToManyBench.scala @@ -0,0 +1,47 @@ +package kyo.bench.arena.grpc + +import GrpcClientBench.* +import GrpcService.* +import io.grpc.* +import java.util.concurrent.TimeoutException +import java.util.concurrent.TimeUnit +import kgrpc.bench.* +import kgrpc.bench.TestServiceGrpc.* +import kyo.* +import kyo.bench.arena.ArenaBench2 +import kyo.bench.arena.ArenaBench2.* +import kyo.bench.arena.WarmupJITProfile.CatsForkWarmup +import kyo.bench.arena.WarmupJITProfile.KyoForkWarmup +import kyo.bench.arena.WarmupJITProfile.ZIOForkWarmup +import kyo.grpc.Grpc +import org.openjdk.jmh.annotations.* +import scala.compiletime.uninitialized +import scalapb.zio_grpc +import zio.ZIO + +class GrpcClientOneToManyBench extends 
ArenaBench2[Long](size): + + @Benchmark + def catsBench(warmup: CatsForkWarmup, state: CatsState): Long = + import state.{*, given} + forkCats: + client.oneToMany(request, Metadata()).compile.count + end catsBench + + @Benchmark + def kyoBench(warmup: KyoForkWarmup, state: KyoState): Long = + import state.* + forkKyo: + Env.run(Metadata()): + // TODO: Can we avoid the lift here? + client.oneToMany(Kyo.lift(request)).into(Sink.count.map(_.toLong)) + end kyoBench + + @Benchmark + def zioBench(warmup: ZIOForkWarmup, state: ZIOState): Long = + import state.{*, given} + forkZIO: + client.oneToMany(request).runCount + end zioBench + +end GrpcClientOneToManyBench diff --git a/kyo-bench/src/main/scala/kyo/bench/arena/grpc/GrpcClientUnaryBench.scala b/kyo-bench/src/main/scala/kyo/bench/arena/grpc/GrpcClientUnaryBench.scala new file mode 100644 index 000000000..93c576f8e --- /dev/null +++ b/kyo-bench/src/main/scala/kyo/bench/arena/grpc/GrpcClientUnaryBench.scala @@ -0,0 +1,47 @@ +package kyo.bench.arena.grpc + +import GrpcClientBench.* +import GrpcService.* +import io.grpc.* +import java.util.concurrent.TimeoutException +import java.util.concurrent.TimeUnit +import kgrpc.bench.* +import kgrpc.bench.TestServiceGrpc.* +import kyo.* +import kyo.bench.arena.ArenaBench2 +import kyo.bench.arena.ArenaBench2.* +import kyo.bench.arena.WarmupJITProfile.CatsForkWarmup +import kyo.bench.arena.WarmupJITProfile.KyoForkWarmup +import kyo.bench.arena.WarmupJITProfile.ZIOForkWarmup +import kyo.grpc.Grpc +import org.openjdk.jmh.annotations.* +import scala.compiletime.uninitialized +import scalapb.zio_grpc +import zio.ZIO + +class GrpcClientUnaryBench extends ArenaBench2(response): + + @Benchmark + def catsBench(warmup: CatsForkWarmup, state: CatsState): Response = + import state.{*, given} + forkCats: + client.oneToOne(request, Metadata()) + end catsBench + + @Benchmark + def kyoBench(warmup: KyoForkWarmup, state: KyoState): Response = + import state.* + forkKyo: + Env.run(Metadata()): + // 
TODO: Can we avoid the lift here? + client.oneToOne(Kyo.lift(request)) + end kyoBench + + @Benchmark + def zioBench(warmup: ZIOForkWarmup, state: ZIOState): Response = + import state.{*, given} + forkZIO: + client.oneToOne(request) + end zioBench + +end GrpcClientUnaryBench diff --git a/kyo-bench/src/main/scala/kyo/bench/arena/grpc/GrpcE2EManyToManyBench.scala b/kyo-bench/src/main/scala/kyo/bench/arena/grpc/GrpcE2EManyToManyBench.scala new file mode 100644 index 000000000..ace2fab9b --- /dev/null +++ b/kyo-bench/src/main/scala/kyo/bench/arena/grpc/GrpcE2EManyToManyBench.scala @@ -0,0 +1,56 @@ +package kyo.bench.arena.grpc + +import GrpcService.* +import io.grpc.Metadata +import kgrpc.* +import kgrpc.bench.* +import kyo.* +import kyo.Scope +import kyo.bench.arena.ArenaBench +import kyo.grpc.Grpc +import org.openjdk.jmh.annotations.* +import scala.compiletime.uninitialized +import scalapb.zio_grpc.Server +import zio.Chunk +import zio.UIO +import zio.ZIO +import zio.stream.ZStream + +class GrpcE2EManyToManyBench extends ArenaBench.ForkOnly[Long](sizeSquared): + + private var port: Int = uninitialized + + @Setup + def buildChannel(): Unit = + port = findFreePort() + end buildChannel + + override def catsBench(): cats.effect.IO[Long] = + import cats.effect.* + import fs2.Stream + createCatsServer(port, static = false).use: _ => + createCatsClient(port).use: client => + val requestStream = Stream.emits(requests) + client.manyToMany(requestStream, Metadata()).compile.count + end catsBench + + override def kyoBenchFiber(): Long < (Async & Abort[Throwable]) = + Scope.run: + Env.run(Metadata()): + for + _ <- createKyoServer(port, static = false) + client <- createKyoClient(port) + // TODO: Can we avoid the lift here? 
+ yield client.manyToMany(Kyo.lift(Stream.init(requests))).into(Sink.count.map(_.toLong)) + + override def zioBench(): UIO[Long] = + ZIO.scoped: + val run = + for + _ <- createZioServer(port, static = false) + client <- createZioClient(port) + reply <- client.manyToMany(ZStream.fromIterable(requests)).runCount + yield reply + run.orDie + +end GrpcE2EManyToManyBench diff --git a/kyo-bench/src/main/scala/kyo/bench/arena/grpc/GrpcE2EManyToOneBench.scala b/kyo-bench/src/main/scala/kyo/bench/arena/grpc/GrpcE2EManyToOneBench.scala new file mode 100644 index 000000000..d35c50406 --- /dev/null +++ b/kyo-bench/src/main/scala/kyo/bench/arena/grpc/GrpcE2EManyToOneBench.scala @@ -0,0 +1,56 @@ +package kyo.bench.arena.grpc + +import GrpcService.* +import io.grpc.Metadata +import kgrpc.* +import kgrpc.bench.* +import kyo.* +import kyo.Scope +import kyo.bench.arena.ArenaBench +import kyo.grpc.Grpc +import org.openjdk.jmh.annotations.* +import scala.compiletime.uninitialized +import scalapb.zio_grpc.Server +import zio.Chunk +import zio.UIO +import zio.ZIO +import zio.stream.ZStream + +class GrpcE2EManyToOneBench extends ArenaBench.ForkOnly(response): + + private var port: Int = uninitialized + + @Setup + def buildChannel(): Unit = + port = findFreePort() + end buildChannel + + override def catsBench(): cats.effect.IO[Response] = + import cats.effect.* + import fs2.Stream + createCatsServer(port, static = false).use: _ => + createCatsClient(port).use: client => + val requestStream = Stream.emits(requests) + client.manyToOne(requestStream, Metadata()) + end catsBench + + override def kyoBenchFiber(): Response < (Async & Abort[Throwable]) = + Scope.run: + Env.run(Metadata()): + for + _ <- createKyoServer(port, static = false) + client <- createKyoClient(port) + // TODO: Can we avoid the lift here? 
+ yield client.manyToOne(Kyo.lift(Stream.init(requests))) + + override def zioBench(): UIO[Response] = + ZIO.scoped: + val run = + for + _ <- createZioServer(port, static = false) + client <- createZioClient(port) + reply <- client.manyToOne(ZStream.fromIterable(requests)) + yield reply + run.orDie + +end GrpcE2EManyToOneBench diff --git a/kyo-bench/src/main/scala/kyo/bench/arena/grpc/GrpcE2EOneToManyBench.scala b/kyo-bench/src/main/scala/kyo/bench/arena/grpc/GrpcE2EOneToManyBench.scala new file mode 100644 index 000000000..256b11413 --- /dev/null +++ b/kyo-bench/src/main/scala/kyo/bench/arena/grpc/GrpcE2EOneToManyBench.scala @@ -0,0 +1,53 @@ +package kyo.bench.arena.grpc + +import GrpcService.* +import io.grpc.Metadata +import kgrpc.* +import kgrpc.bench.* +import kyo.* +import kyo.Scope +import kyo.bench.arena.ArenaBench +import kyo.grpc.Grpc +import kyo.grpc.SafeMetadata +import org.openjdk.jmh.annotations.* +import scala.compiletime.uninitialized +import scalapb.zio_grpc.Server +import zio.UIO +import zio.ZIO + +class GrpcE2EOneToManyBench extends ArenaBench.ForkOnly[Long](size): + + private var port: Int = uninitialized + + @Setup + def buildChannel(): Unit = + port = findFreePort() + end buildChannel + + override def catsBench(): cats.effect.IO[Long] = + import cats.effect.* + createCatsServer(port, static = false).use: _ => + createCatsClient(port).use: client => + client.oneToMany(request, Metadata()).compile.count + end catsBench + + override def kyoBenchFiber(): Long < (Async & Abort[Throwable]) = + Scope.run: + Env.run(SafeMetadata.empty): + for + _ <- createKyoServer(port, static = false) + client <- createKyoClient(port) + // TODO: Can we avoid the lift here? 
+ yield client.oneToMany(Kyo.lift(request)).into(Sink.count.map(_.toLong)) + + override def zioBench(): UIO[Long] = + ZIO.scoped: + val run = + for + _ <- createZioServer(port, static = false) + client <- createZioClient(port) + reply <- client.oneToMany(request).runCount + yield reply + run.orDie + +end GrpcE2EOneToManyBench diff --git a/kyo-bench/src/main/scala/kyo/bench/arena/grpc/GrpcE2EUnaryBench.scala b/kyo-bench/src/main/scala/kyo/bench/arena/grpc/GrpcE2EUnaryBench.scala new file mode 100644 index 000000000..f6cf5f28b --- /dev/null +++ b/kyo-bench/src/main/scala/kyo/bench/arena/grpc/GrpcE2EUnaryBench.scala @@ -0,0 +1,53 @@ +package kyo.bench.arena.grpc + +import GrpcService.* +import io.grpc.Metadata +import kgrpc.* +import kgrpc.bench.* +import kyo.* +import kyo.Scope +import kyo.bench.arena.ArenaBench +import kyo.grpc.Grpc +import kyo.grpc.SafeMetadata +import org.openjdk.jmh.annotations.* +import scala.compiletime.uninitialized +import scalapb.zio_grpc.Server +import zio.UIO +import zio.ZIO + +class GrpcE2EUnaryBench extends ArenaBench.ForkOnly(response): + + private var port: Int = uninitialized + + @Setup + def buildChannel(): Unit = + port = findFreePort() + end buildChannel + + override def catsBench(): cats.effect.IO[Response] = + import cats.effect.* + createCatsServer(port, static = false).use: _ => + createCatsClient(port).use: client => + client.oneToOne(request, Metadata()) + end catsBench + + override def kyoBenchFiber(): Response < (Async & Abort[Throwable]) = + Scope.run: + Env.run(SafeMetadata.empty): + for + _ <- createKyoServer(port, static = false) + client <- createKyoClient(port) + // TODO: Can we avoid the lift here? 
+ yield client.oneToOne(Kyo.lift(request)) + + override def zioBench(): UIO[Response] = + ZIO.scoped: + val run = + for + _ <- createZioServer(port, static = false) + client <- createZioClient(port) + reply <- client.oneToOne(request) + yield reply + run.orDie + +end GrpcE2EUnaryBench diff --git a/kyo-bench/src/main/scala/kyo/bench/arena/grpc/GrpcServerBench.scala b/kyo-bench/src/main/scala/kyo/bench/arena/grpc/GrpcServerBench.scala new file mode 100644 index 000000000..b587baa20 --- /dev/null +++ b/kyo-bench/src/main/scala/kyo/bench/arena/grpc/GrpcServerBench.scala @@ -0,0 +1,115 @@ +package kyo.bench.arena.grpc + +import GrpcService.* +import io.grpc.* +import java.util.concurrent.TimeoutException +import java.util.concurrent.TimeUnit +import kgrpc.bench.* +import kgrpc.bench.TestServiceGrpc.* +import kyo.* +import kyo.Scope +import kyo.bench.arena.ArenaBench2.* +import kyo.kernel.ContextEffect +import org.openjdk.jmh.annotations.* +import org.openjdk.jmh.annotations.Scope as JmhScope +import scala.compiletime.uninitialized +import scalapb.zio_grpc +import zio.ZIO + +object GrpcServerBench: + + @State(JmhScope.Benchmark) + abstract class BaseState: + + protected var port: Int = uninitialized + var channel: ManagedChannel = uninitialized + + var blockingStub: TestServiceBlockingStub = uninitialized + var stub: TestServiceStub = uninitialized + + protected var finalizers: List[() => Unit] = Nil + + protected def addFinalizer(finalizer: => Unit): Unit = + finalizers = (() => finalizer) :: finalizers + + def setup(): Unit = + port = findFreePort() + channel = ManagedChannelBuilder.forAddress(host, port).usePlaintext().build + blockingStub = TestServiceGrpc.blockingStub(channel) + stub = TestServiceGrpc.stub(channel) + end setup + + @TearDown + def tearDown(): Unit = + channel.shutdownNow() + finalizers.foreach(_()) + val isShutdown = channel.awaitTermination(10, TimeUnit.SECONDS) + if !isShutdown then throw TimeoutException("Channel did not shutdown within 10 
seconds.") + end tearDown + + end BaseState + + @State(JmhScope.Benchmark) + class CatsState extends BaseState: + + var ioRuntime: cats.effect.unsafe.IORuntime = uninitialized + given cats.effect.unsafe.IORuntime = ioRuntime + + @Setup + def setup(runtime: CatsRuntime): Unit = + super.setup() + + ioRuntime = runtime.ioRuntime + + val (_: Server, finalizer) = createCatsServer(port, static = true).allocated.unsafeRunSync() + + addFinalizer: + finalizer.unsafeRunSync() + end setup + + end CatsState + + @State(JmhScope.Benchmark) + class KyoState extends BaseState: + + @Setup + override def setup(): Unit = + super.setup() + + import AllowUnsafe.embrace.danger + + val finalizer = Scope.Finalizer.Awaitable.Unsafe.init(1) + val kyoServer = createKyoServer(port, static = true) + val result = ContextEffect.handle(Tag[Scope], finalizer, _ => finalizer)(kyoServer) + val _: Server = Abort.run(Sync.Unsafe.run(result)).eval.getOrThrow + + addFinalizer: + Abort.run(Sync.Unsafe.run(finalizer.close(Absent))).eval.getOrThrow + end setup + + end KyoState + + @State(JmhScope.Benchmark) + class ZIOState extends BaseState: + + var zioRuntime: zio.Runtime[Any] = uninitialized + given zio.Runtime[Any] = zioRuntime + + @Setup + def setup(runtime: ZIORuntime): Unit = + super.setup() + + zioRuntime = runtime.zioRuntime + given zio.Unsafe = zio.Unsafe + + val scope = zio.Scope.unsafe.make + val zioServer = scope.extend(createZioServer(port, static = true)) + val _: zio_grpc.Server = zioRuntime.unsafe.run(zioServer).getOrThrow() + + addFinalizer: + zioRuntime.unsafe.run(scope.close(zio.Exit.unit)).getOrThrow() + end setup + + end ZIOState + +end GrpcServerBench diff --git a/kyo-bench/src/main/scala/kyo/bench/arena/grpc/GrpcServerManyToManyBench.scala b/kyo-bench/src/main/scala/kyo/bench/arena/grpc/GrpcServerManyToManyBench.scala new file mode 100644 index 000000000..618e996e7 --- /dev/null +++ b/kyo-bench/src/main/scala/kyo/bench/arena/grpc/GrpcServerManyToManyBench.scala @@ -0,0 +1,77 @@ 
+package kyo.bench.arena.grpc + +import GrpcServerBench.* +import GrpcService.* +import io.grpc.stub.StreamObserver +import kgrpc.bench.* +import kyo.* +import kyo.Fiber.Promise.Unsafe +import kyo.bench.arena.ArenaBench2 +import kyo.bench.arena.WarmupJITProfile.CatsForkWarmup +import kyo.bench.arena.WarmupJITProfile.KyoForkWarmup +import kyo.bench.arena.WarmupJITProfile.ZIOForkWarmup +import org.openjdk.jmh.annotations.* +import zio.ZIO + +class GrpcServerManyToManyBench extends ArenaBench2(sizeSquared): + + private given Frame = Frame.internal + + @Benchmark + def catsBench(warmup: CatsForkWarmup, state: CatsState): Int = + import state.{*, given} + forkCats: + cats.effect.IO.async[Int]: cb => + val observer = new StreamObserver[Response]: + private var count: Int = 0 + def onNext(response: Response): Unit = count += 1 + def onError(t: Throwable): Unit = cb(Left(t)) + def onCompleted(): Unit = cb(Right(count)) + + cats.effect.IO: + val requestObserver = stub.manyToMany(observer) + requests.foreach(requestObserver.onNext) + requestObserver.onCompleted() + None + end catsBench + + @Benchmark + def kyoBench(warmup: KyoForkWarmup, state: KyoState): Int = + import state.* + forkKyo: + Promise.initWith[Int, Abort[Throwable]]: promise => + val observer = new StreamObserver[Response]: + private var count: Int = 0 + def onNext(response: Response): Unit = count += 1 + def onError(t: Throwable): Unit = + import AllowUnsafe.embrace.danger + discard(promise.unsafe.complete(Result.fail(t))) + end onError + def onCompleted(): Unit = + import AllowUnsafe.embrace.danger + discard(promise.unsafe.complete(Result.succeed(count))) + + val run = Async.defer: + val requestObserver = stub.manyToMany(observer) + requests.foreach(requestObserver.onNext) + requestObserver.onCompleted() + run.andThen(promise.get) + end kyoBench + + @Benchmark + def zioBench(warmup: ZIOForkWarmup, state: ZIOState): Int = + import state.{*, given} + forkZIO: + ZIO.async[Any, Throwable, Int]: cb => + val 
observer = new StreamObserver[Response]: + private var count: Int = 0 + def onNext(response: Response): Unit = count += 1 + def onError(t: Throwable): Unit = cb(ZIO.fail(t)) + def onCompleted(): Unit = cb(ZIO.succeed(count)) + + val requestObserver = stub.manyToMany(observer) + requests.foreach(requestObserver.onNext) + requestObserver.onCompleted() + end zioBench + +end GrpcServerManyToManyBench diff --git a/kyo-bench/src/main/scala/kyo/bench/arena/grpc/GrpcServerManyToOneBench.scala b/kyo-bench/src/main/scala/kyo/bench/arena/grpc/GrpcServerManyToOneBench.scala new file mode 100644 index 000000000..6353f78c0 --- /dev/null +++ b/kyo-bench/src/main/scala/kyo/bench/arena/grpc/GrpcServerManyToOneBench.scala @@ -0,0 +1,87 @@ +package kyo.bench.arena.grpc + +import GrpcServerBench.* +import GrpcService.* +import io.grpc.stub.StreamObserver +import java.util.NoSuchElementException +import kgrpc.bench.* +import kyo.* +import kyo.bench.arena.ArenaBench2 +import kyo.bench.arena.WarmupJITProfile.CatsForkWarmup +import kyo.bench.arena.WarmupJITProfile.KyoForkWarmup +import kyo.bench.arena.WarmupJITProfile.ZIOForkWarmup +import org.openjdk.jmh.annotations.* +import zio.ZIO + +class GrpcServerManyToOneBench extends ArenaBench2(response): + + private given Frame = Frame.internal + + @Benchmark + def catsBench(warmup: CatsForkWarmup, state: CatsState): Response = + import state.{*, given} + forkCats: + cats.effect.IO.async[Response]: cb => + val observer = new StreamObserver[Response]: + private var response: Maybe[Response] = Maybe.empty + def onNext(response: Response): Unit = + if this.response.isDefined then throw IllegalStateException("Response already set.") + this.response = Maybe(response) + def onError(t: Throwable): Unit = + cb(Left(t)) + def onCompleted(): Unit = + cb(response.fold(Left(NoSuchElementException("No response")))(Right(_))) + + cats.effect.IO: + val requestObserver = stub.manyToOne(observer) + requests.foreach(requestObserver.onNext) + 
requestObserver.onCompleted() + None + end catsBench + + @Benchmark + def kyoBench(warmup: KyoForkWarmup, state: KyoState): Response = + import state.* + forkKyo: + Promise.initWith[Response, Abort[Throwable]]: promise => + val observer = new StreamObserver[Response]: + private var response: Maybe[Response] = Maybe.empty + def onNext(response: Response): Unit = + if this.response.isDefined then throw IllegalStateException("Response already set.") + this.response = Maybe(response) + def onError(t: Throwable): Unit = + import AllowUnsafe.embrace.danger + discard(promise.unsafe.complete(Result.fail(t))) + def onCompleted(): Unit = + import AllowUnsafe.embrace.danger + discard( + promise.unsafe.complete(response.fold(Result.fail(NoSuchElementException("No response")))(Result.succeed(_))) + ) + end onCompleted + + val run = Async.defer: + val requestObserver = stub.manyToOne(observer) + requests.foreach(requestObserver.onNext) + requestObserver.onCompleted() + run.andThen(promise.get) + end kyoBench + + @Benchmark + def zioBench(warmup: ZIOForkWarmup, state: ZIOState): Response = + import state.{*, given} + forkZIO: + ZIO.async[Any, Throwable, Response]: cb => + val observer = new StreamObserver[Response]: + private var response: Option[Response] = None + def onNext(response: Response): Unit = + if this.response.isDefined then throw IllegalStateException("Response already set.") + this.response = Some(response) + def onError(t: Throwable): Unit = cb(ZIO.fail(t)) + def onCompleted(): Unit = cb(response.fold(ZIO.fail(NoSuchElementException("No response")))(ZIO.succeed(_))) + + val requestObserver = stub.manyToOne(observer) + requests.foreach(requestObserver.onNext) + requestObserver.onCompleted() + end zioBench + +end GrpcServerManyToOneBench diff --git a/kyo-bench/src/main/scala/kyo/bench/arena/grpc/GrpcServerOneToManyBench.scala b/kyo-bench/src/main/scala/kyo/bench/arena/grpc/GrpcServerOneToManyBench.scala new file mode 100644 index 000000000..8f8a6a7d6 --- /dev/null 
+++ b/kyo-bench/src/main/scala/kyo/bench/arena/grpc/GrpcServerOneToManyBench.scala @@ -0,0 +1,48 @@ +package kyo.bench.arena.grpc + +import GrpcServerBench.* +import GrpcService.* +import io.grpc.* +import java.util.concurrent.TimeoutException +import java.util.concurrent.TimeUnit +import kgrpc.bench.* +import kgrpc.bench.TestServiceGrpc.TestServiceBlockingStub +import kyo.* +import kyo.bench.arena.ArenaBench2 +import kyo.bench.arena.WarmupJITProfile.CatsForkWarmup +import kyo.bench.arena.WarmupJITProfile.KyoForkWarmup +import kyo.bench.arena.WarmupJITProfile.ZIOForkWarmup +import kyo.kernel.ContextEffect +import org.openjdk.jmh.annotations.* +import scala.compiletime.uninitialized +import scalapb.zio_grpc +import zio.ZIO + +class GrpcServerOneToManyBench extends ArenaBench2(size): + + @Benchmark + def catsBench(warmup: CatsForkWarmup, state: CatsState): Int = + import state.{*, given} + forkCats: + cats.effect.IO(consume(blockingStub.oneToMany(request))) + end catsBench + + @Benchmark + def kyoBench(warmup: KyoForkWarmup, state: KyoState): Int = + import state.* + forkKyo: + Sync.defer(consume(blockingStub.oneToMany(request))) + end kyoBench + + @Benchmark + def zioBench(warmup: ZIOForkWarmup, state: ZIOState): Int = + import state.{*, given} + forkZIO: + ZIO.attempt(consume(blockingStub.oneToMany(request))).orDie + end zioBench + + // Consume the iterator otherwise Netty has a hissy fit about resource leaks. 
+ private def consume(replies: Iterator[Response]): Int = + scala.collection.immutable.LazyList.from(replies).foldLeft(0)((acc, _) => acc + 1) + +end GrpcServerOneToManyBench diff --git a/kyo-bench/src/main/scala/kyo/bench/arena/grpc/GrpcServerUnaryBench.scala b/kyo-bench/src/main/scala/kyo/bench/arena/grpc/GrpcServerUnaryBench.scala new file mode 100644 index 000000000..ede88ab74 --- /dev/null +++ b/kyo-bench/src/main/scala/kyo/bench/arena/grpc/GrpcServerUnaryBench.scala @@ -0,0 +1,44 @@ +package kyo.bench.arena.grpc + +import GrpcServerBench.* +import GrpcService.* +import io.grpc.* +import java.util.concurrent.TimeoutException +import java.util.concurrent.TimeUnit +import kgrpc.bench.* +import kgrpc.bench.TestServiceGrpc.TestServiceBlockingStub +import kyo.* +import kyo.bench.arena.ArenaBench2 +import kyo.bench.arena.WarmupJITProfile.CatsForkWarmup +import kyo.bench.arena.WarmupJITProfile.KyoForkWarmup +import kyo.bench.arena.WarmupJITProfile.ZIOForkWarmup +import kyo.kernel.ContextEffect +import org.openjdk.jmh.annotations.* +import scala.compiletime.uninitialized +import scalapb.zio_grpc +import zio.ZIO + +class GrpcServerUnaryBench extends ArenaBench2(response): + + @Benchmark + def catsBench(warmup: CatsForkWarmup, state: CatsState): Response = + import state.{*, given} + forkCats: + cats.effect.IO(blockingStub.oneToOne(request)) + end catsBench + + @Benchmark + def kyoBench(warmup: KyoForkWarmup, state: KyoState): Response = + import state.* + forkKyo: + Sync.defer(blockingStub.oneToOne(request)) + end kyoBench + + @Benchmark + def zioBench(warmup: ZIOForkWarmup, state: ZIOState): Response = + import state.{*, given} + forkZIO: + ZIO.attempt(blockingStub.oneToOne(request)).orDie + end zioBench + +end GrpcServerUnaryBench diff --git a/kyo-bench/src/main/scala/kyo/bench/arena/grpc/GrpcService.scala b/kyo-bench/src/main/scala/kyo/bench/arena/grpc/GrpcService.scala new file mode 100644 index 000000000..fcceb6da8 --- /dev/null +++ 
b/kyo-bench/src/main/scala/kyo/bench/arena/grpc/GrpcService.scala @@ -0,0 +1,209 @@ +package kyo.bench.arena.grpc + +import cats.effect +import cats.effect.IO.given +import cats.effect.IO as CIO +import fs2.grpc.syntax.all.* +import io.grpc.ManagedChannelBuilder +import io.grpc.Metadata +import io.grpc.Server +import io.grpc.ServerBuilder +import io.grpc.StatusException +import io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder +import io.grpc.netty.shaded.io.grpc.netty.NettyServerBuilder +import java.net.ServerSocket +import java.util.concurrent.TimeUnit +import kgrpc.bench.* +import kyo.* +import kyo.grpc.* +import scala.language.implicitConversions +import scalapb.zio_grpc +import scalapb.zio_grpc.ScopedServer +import scalapb.zio_grpc.ZChannel +import zio.ZIO +import zio.given +import zio.stream + +object GrpcService: + + given Frame = Frame.internal + + val host = "localhost" + val size = 10 + val sizeSquared: Int = size ^ 2 + val message = "Hello" + val request: Request = Request(message) + val response: Response = Response(message) + val requests: Chunk[Request] = Chunk.fill(GrpcService.size)(Request(message)) + + def createCatsServer(port: Int, static: Boolean): cats.effect.Resource[CIO, Server] = + val service = if static then StaticCatsTestService(size) else CatsTestService(size) + TestServiceFs2Grpc + .bindServiceResource[CIO](service) + .flatMap: service => + NettyServerBuilder + .forPort(port) + .addService(service) + .resourceWithShutdown { server => + for + _ <- CIO(server.shutdown()) + terminated <- CIO.interruptible(server.awaitTermination(30, TimeUnit.SECONDS)) + _ <- CIO.unlessA(terminated)(CIO.interruptible(server.shutdownNow().awaitTermination())) + yield () + } + .evalMap(server => CIO(server.start())) + end createCatsServer + + def createCatsClient(port: Int): cats.effect.Resource[CIO, TestServiceFs2Grpc[CIO, Metadata]] = + NettyChannelBuilder + .forAddress(host, port) + .usePlaintext() + .resource[CIO] + 
.flatMap(TestServiceFs2Grpc.stubResource[CIO](_)) + end createCatsClient + + def createKyoServer(port: Int, static: Boolean): grpc.Server < (Scope & Sync) = + val service = if static then StaticKyoTestService(size) else KyoTestService(size) + kyo.grpc.Server.start(port)(_.addService(service)) + end createKyoServer + + def createKyoClient(port: Int): TestService.Client < (Scope & Sync) = + TestService.managedClient(host, port)(_.usePlaintext()) + + def createZioServer(port: Int, static: Boolean): ZIO[zio.Scope, Throwable, zio_grpc.Server] = + val service = if static then StaticZIOTestService(size) else ZIOTestService(size) + ScopedServer.fromService(ServerBuilder.forPort(port), service) + end createZioServer + + def createZioClient(port: Int): ZIO[zio.Scope, Throwable, ZioBench.TestServiceClient] = + val builder = ManagedChannelBuilder.forAddress(host, port).usePlaintext() + val channel = ZIO.acquireRelease { + ZIO.attempt(ZChannel(builder.build(), None, Nil)) + } { c => + c.shutdown().flatMap(_ => c.awaitTermination(30.seconds)).ignore + } + ZioBench.TestServiceClient.scoped(channel) + end createZioClient + + def findFreePort(): Int = + val socket = ServerSocket(0) + try + socket.getLocalPort + finally + socket.close() + end try + end findFreePort + +end GrpcService + +class CatsTestService(size: Int) extends TestServiceFs2Grpc[CIO, Metadata]: + + import cats.effect.* + + override def oneToOne(request: kgrpc.bench.Request, ctx: Metadata): IO[Response] = + IO.pure(Response(request.message)) + + override def oneToMany(request: Request, ctx: Metadata): fs2.Stream[IO, Response] = + fs2.Stream.chunk(fs2.Chunk.constant(Response(request.message), size)) + + override def manyToOne(requests: fs2.Stream[IO, Request], ctx: Metadata): IO[Response] = + requests + .compile + .last + .map(maybe => Response(maybe.fold("")(_.message))) + + override def manyToMany(requests: fs2.Stream[IO, Request], ctx: Metadata): fs2.Stream[IO, Response] = + requests.flatMap(oneToMany(_, ctx)) + 
+end CatsTestService + +class KyoTestService(size: Int)(using Frame) extends TestService: + + override def oneToOne(request: Request): Response < Any = + Response(request.message) + + override def oneToMany(request: Request): Stream[Response, Grpc] < Grpc = + Stream.init(Chunk.fill(size)(Response(request.message))) + + override def manyToOne(requests: Stream[Request, Grpc]): Response < Grpc = + Sink.fold[Maybe[Request], Request](Absent)((_, v) => Present(v)).drain(requests).map(maybe => Response(maybe.fold("")(_.message))) + + override def manyToMany(requests: Stream[Request, Grpc]): Stream[Response, Grpc] < Grpc = + requests.flatMap(oneToMany) + +end KyoTestService + +class ZIOTestService(size: Int) extends ZioBench.TestService: + + override def oneToOne(request: Request): ZIO[Any, StatusException, Response] = + ZIO.succeed(Response(request.message)) + + override def oneToMany(request: Request): stream.Stream[StatusException, Response] = + stream.ZStream.fromChunk(zio.Chunk.fill(size)(Response(request.message))) + + override def manyToOne(requests: stream.Stream[StatusException, Request]): zio.IO[StatusException, Response] = + requests.runLast.map(maybe => Response(maybe.fold("")(_.message))) + + override def manyToMany(requests: stream.Stream[StatusException, Request]): stream.Stream[StatusException, Response] = + requests.flatMap(oneToMany) + +end ZIOTestService + +class StaticCatsTestService(size: Int) extends TestServiceFs2Grpc[CIO, Metadata]: + + import cats.effect.* + + private val response = Response("response") + private val responses = fs2.Chunk.constant(response, size) + + override def oneToOne(request: kgrpc.bench.Request, ctx: Metadata): IO[Response] = + IO.pure(response) + + override def oneToMany(request: Request, ctx: Metadata): fs2.Stream[IO, Response] = + fs2.Stream.chunk(responses) + + override def manyToOne(requests: fs2.Stream[IO, Request], ctx: Metadata): IO[Response] = + requests.compile.drain.map(_ => response) + + override def 
manyToMany(requests: fs2.Stream[IO, Request], ctx: Metadata): fs2.Stream[IO, Response] = + requests.flatMap(oneToMany(_, ctx)) + +end StaticCatsTestService + +class StaticKyoTestService(size: Int)(using Frame) extends TestService: + + private val response = Response("response") + private val responses = Chunk.fill(size)(response) + + override def oneToOne(request: Request): Response < Any = + response + + override def oneToMany(request: Request): Stream[Response, Grpc] < Grpc = + Stream.init(responses) + + override def manyToOne(requests: Stream[Request, Grpc]): Response < Grpc = + requests.discard.andThen(response) + + override def manyToMany(requests: Stream[Request, Grpc]): Stream[Response, Grpc] < Grpc = + requests.flatMap(oneToMany) + +end StaticKyoTestService + +class StaticZIOTestService(size: Int) extends ZioBench.TestService: + + private val response = Response("response") + private val responses = zio.Chunk.fill(size)(response) + + override def oneToOne(request: Request): ZIO[Any, StatusException, Response] = + ZIO.succeed(response) + + override def oneToMany(request: Request): stream.Stream[StatusException, Response] = + stream.ZStream.fromChunk(responses) + + override def manyToOne(requests: stream.Stream[StatusException, Request]): zio.IO[StatusException, Response] = + requests.runDrain.as(response) + + override def manyToMany(requests: stream.Stream[StatusException, Request]): stream.Stream[StatusException, Response] = + requests.flatMap(oneToMany) + +end StaticZIOTestService diff --git a/kyo-bench/src/test/scala/kyo/bench/arena/BenchTest.scala b/kyo-bench/src/test/scala/kyo/bench/arena/BenchTest.scala index 2891bfa4a..f4f5e9b54 100644 --- a/kyo-bench/src/test/scala/kyo/bench/arena/BenchTest.scala +++ b/kyo-bench/src/test/scala/kyo/bench/arena/BenchTest.scala @@ -42,7 +42,8 @@ abstract class BenchTest extends AsyncFreeSpec with Assertions: b match case b: Fork[A] => s"fork$target" in { - assert(runFork(b) == b.expectedResult) + val result = runFork(b) 
+ assert(result == b.expectedResult) detectRuntimeLeak() } case _ => diff --git a/kyo-core/shared/src/main/scala/kyo/Async.scala b/kyo-core/shared/src/main/scala/kyo/Async.scala index 838707b1c..2aa525d18 100644 --- a/kyo-core/shared/src/main/scala/kyo/Async.scala +++ b/kyo-core/shared/src/main/scala/kyo/Async.scala @@ -180,6 +180,19 @@ object Async extends AsyncPlatformSpecific: end if end timeout + def tapFiber[E, A, S, S2](using isolate: Isolate[S, Abort[E] & Async, S]) + (v: => A < (Abort[E] & Async & S)) + (f: Fiber[Any, Abort[E] & S] => Unit < S2) + (using frame: Frame): A < (kyo.Abort[E] & kyo.Async & S & S2) = + isolate.capture { state => + Fiber.initUnscoped(isolate.isolate(state, v)).map { fiber => + f(fiber).andThen { + isolate.restore(fiber.get) + } + } + } + end tapFiber + /** Races multiple computations and returns the result of the first successful computation to complete. When one computation succeeds, * all other computations are interrupted. * diff --git a/kyo-core/shared/src/main/scala/kyo/Channel.scala b/kyo-core/shared/src/main/scala/kyo/Channel.scala index 7f9c544aa..686d97fdd 100644 --- a/kyo-core/shared/src/main/scala/kyo/Channel.scala +++ b/kyo-core/shared/src/main/scala/kyo/Channel.scala @@ -219,21 +219,51 @@ object Channel: /** Closes the channel and asynchronously waits until it's empty. * * This method closes the channel to new elements and returns a computation that completes when all elements have been consumed. - * Unlike the regular `close` method, this allows consumers to process all remaining elements before considering the channel fully - * closed. + * Unlike the regular [[close]] method, this allows consumers to process all remaining elements before considering the channel + * fully closed. 
* * @return - * true if the channel was successfully closed and emptied, false if it was already closed + * `true` if the channel was successfully closed and emptied, `false` if it was already closed */ def closeAwaitEmpty(using Frame): Boolean < Async = Sync.Unsafe.defer(self.closeAwaitEmpty().safe.get) + // TODO: I think this can be removed now. + /** Closes the channel and returns the [[Fiber]] waits until it's empty. + * + * This method closes the channel to new elements and returns a `Fiber` that completes when all elements have been consumed. Unlike + * the regular [[close]] method, this allows consumers to process all remaining elements before considering the channel fully + * closed. + * + * This differs from [[closeAwaitEmpty]] in that once the `Fiber` has been obtained it guarantees to have begun closing the channel + * and future offers to the channel will abort with [[Closed]] even if the channel is not yet completely closed. On the other hand, + * when handling the `Async` effect from `closeAwaitEmpty` the `Fiber` it returns may not have started closing the channel yet. + * + * @return + * a `Fiber` that completes with `true` if the channel was successfully closed and emptied, `false` if it was already closed + */ + def closeAwaitEmptyFiber(using Frame): Fiber[Boolean, Any] < Sync = Sync.Unsafe.defer(self.closeAwaitEmpty().safe) + /** Checks if the channel is closed. + * + * A channel is considered closed if it has fully closed, i.e. it is not open and it is empty. + * + * This will always be `true` after [[close]]. In the case of [[closeAwaitEmpty]] and [[closeAwaitEmptyFiber]], it will only be + * `true` once the channel has been emptied. * * @return - * true if the channel is closed, false otherwise + * `true` if the channel is closed, `false` otherwise */ def closed(using Frame): Boolean < Sync = Sync.Unsafe.defer(self.closed()) + /** Checks if the channel is open. 
+ * + * A channel is considered open if it has not begun closing, and it may still accept new elements (although it might be full). + * + * @return + * `true` if the channel is open, `false` otherwise + */ + def open(using Frame): Boolean < Sync = Sync.Unsafe.defer(self.open()) + /** Checks if the channel is empty. * * @return @@ -410,6 +440,7 @@ object Channel: def empty()(using AllowUnsafe, Frame): Result[Closed, Boolean] def full()(using AllowUnsafe, Frame): Result[Closed, Boolean] def closed()(using AllowUnsafe): Boolean + def open()(using AllowUnsafe): Boolean def safe: Channel[A] = this end Unsafe @@ -606,6 +637,7 @@ object Channel: def empty()(using AllowUnsafe, Frame) = succeedIfOpen(true) def full()(using AllowUnsafe, Frame) = succeedIfOpen(true) def closed()(using AllowUnsafe) = isClosed.get() + def open()(using AllowUnsafe) = !isClosed.get() @tailrec protected def flush()(using Frame): Unit = // This method ensures that all values are processed @@ -762,6 +794,7 @@ object Channel: def empty()(using AllowUnsafe, Frame) = queue.empty() def full()(using AllowUnsafe, Frame) = queue.full() def closed()(using AllowUnsafe) = queue.closed() + def open()(using AllowUnsafe) = queue.open() @tailrec protected def flush()(using Frame): Unit = // This method ensures that all values are processed diff --git a/kyo-core/shared/src/main/scala/kyo/Queue.scala b/kyo-core/shared/src/main/scala/kyo/Queue.scala index 356107d5c..e1382f5a8 100644 --- a/kyo-core/shared/src/main/scala/kyo/Queue.scala +++ b/kyo-core/shared/src/main/scala/kyo/Queue.scala @@ -124,22 +124,51 @@ object Queue: /** Closes the queue and asynchronously waits until it's empty. * * This method closes the queue to new elements and returns a computation that completes when all elements have been consumed. 
- * Unlike the regular `close` method, this allows consumers to process all remaining elements before considering the queue fully + * Unlike the regular [[close]] method, this allows consumers to process all remaining elements before considering the queue fully * closed. * * @return - * true if the queue was successfully closed and emptied, false if it was already closed or another closeAwaitEmpty is already - * running. + * `true` if the queue was successfully closed and emptied, `false` if it was already closed or another `closeAwaitEmpty` is + * already running. */ def closeAwaitEmpty(using Frame): Boolean < Async = Sync.Unsafe.defer(self.closeAwaitEmpty().safe.get) + /** Closes the queue and returns the [[Fiber]] waits until it's empty. + * + * This method closes the queue to new elements and returns a `Fiber` that completes when all elements have been consumed. Unlike + * the regular [[close]] method, this allows consumers to process all remaining elements before considering the queue fully closed. + * + * This differs from [[closeAwaitEmpty]] in that once the `Fiber` has been obtained it guarantees to have begun closing the queue + * and future offers to the queue will abort with [[Closed]] even if the queue is not yet completely closed. On the other hand, + * when handling the `Async` effect from `closeAwaitEmpty` the `Fiber` it returns may not have started closing the queue yet. + * + * @return + * A `Fiber` that completes with `true` if the queue was successfully closed and emptied, `false` if it was already closed or + * another `closeAwaitEmpty` is already running. + */ + def closeAwaitEmptyFiber(using Frame): Fiber[Boolean, Any] < Sync = Sync.Unsafe.defer(self.closeAwaitEmpty().safe) + /** Checks if the queue is closed. + * + * A queue is considered closed if it has fully closed, i.e. it is not open and it is empty. + * + * This will always be `true` after [[close]]. 
In the case of [[closeAwaitEmpty]] and [[closeAwaitEmptyFiber]], it will only be + * `true` once the queue has been emptied. * * @return - * true if the queue is closed, false otherwise + * `true` if the queue is closed, `false` otherwise */ def closed(using Frame): Boolean < Sync = Sync.Unsafe.defer(self.closed()) + /** Checks if the queue is open. + * + * A queue is considered open if it has not begun closing, and it may still accept new elements (although it might be full). + * + * @return + * `true` if the queue is open, `false` otherwise + */ + def open(using Frame): Boolean < Sync = Sync.Unsafe.defer(self.open()) + /** Returns the unsafe version of the queue. * * @return @@ -484,6 +513,7 @@ object Queue: def close()(using Frame, AllowUnsafe) = underlying.close() def closeAwaitEmpty()(using Frame, AllowUnsafe) = underlying.closeAwaitEmpty() def closed()(using AllowUnsafe): Boolean = underlying.closed() + def open()(using AllowUnsafe): Boolean = underlying.open() end new end initDropping @@ -516,6 +546,7 @@ object Queue: def close()(using Frame, AllowUnsafe) = underlying.close() def closeAwaitEmpty()(using Frame, AllowUnsafe) = underlying.closeAwaitEmpty() def closed()(using AllowUnsafe): Boolean = underlying.closed() + def open()(using AllowUnsafe): Boolean = underlying.open() end new end initSliding end Unsafe @@ -535,6 +566,7 @@ object Queue: def close()(using Frame, AllowUnsafe): Maybe[Seq[A]] def closeAwaitEmpty()(using Frame, AllowUnsafe): Fiber.Unsafe[Boolean, Any] def closed()(using AllowUnsafe): Boolean + def open()(using AllowUnsafe): Boolean final def safe: Queue[A] = this end Unsafe @@ -576,6 +608,9 @@ object Queue: final def closed()(using AllowUnsafe) = state.get().isInstanceOf[State.FullyClosed] + final def open()(using AllowUnsafe) = + state.get() eq State.Open + final def drainUpTo(max: Int)(using AllowUnsafe): Result[Closed, Chunk[A]] = pollOp(_drain(Maybe.Present(max))) final def drain()(using AllowUnsafe): Result[Closed, Chunk[A]] = 
pollOp(_drain()) diff --git a/kyo-core/shared/src/test/scala/kyo/ChannelTest.scala b/kyo-core/shared/src/test/scala/kyo/ChannelTest.scala index be415e439..67671170a 100644 --- a/kyo-core/shared/src/test/scala/kyo/ChannelTest.scala +++ b/kyo-core/shared/src/test/scala/kyo/ChannelTest.scala @@ -515,6 +515,30 @@ class ChannelTest extends Test: t <- Abort.run[Throwable](c.put(1)) yield assert(r == Maybe(Seq()) && d.isFailure && t.isFailure) } + "states" in run { + for + c <- Channel.init[Int](1) + closed1 <- c.closed + open1 <- c.open + closed2 <- c.closed + open2 <- c.open + _ <- c.close + closed3 <- c.closed + open3 <- c.open + yield assert(!closed1 && open1 && !closed2 && open2 && closed3 && !open3) + } + "states no buffer" in run { + for + c <- Channel.init[Int](0) + closed1 <- c.closed + open1 <- c.open + closed2 <- c.closed + open2 <- c.open + _ <- c.close + closed3 <- c.closed + open3 <- c.open + yield assert(!closed1 && open1 && !closed2 && open2 && closed3 && !open3) + } } "no buffer" in run { for @@ -1038,7 +1062,8 @@ class ChannelTest extends Test: for c <- Channel.init[Int](10) result <- c.closeAwaitEmpty - yield assert(result) + closed <- c.closed + yield assert(result && closed) } "returns true when channel becomes empty after closing" in run { @@ -1211,6 +1236,93 @@ class ChannelTest extends Test: } } + "closeAwaitEmptyFiber" - { + "returns true when channel is already empty" in run { + for + c <- Channel.init[Int](10) + result <- c.closeAwaitEmpty + closed <- c.closed + open <- c.open + yield assert(result && closed && !open) + } + + "returns true when channel becomes empty after closing" in run { + for + c <- Channel.init[Int](10) + _ <- c.put(1) + _ <- c.put(2) + fiber <- c.closeAwaitEmptyFiber + closed1 <- c.closed + open1 <- c.open + _ <- c.take + closed2 <- c.closed + open2 <- c.open + _ <- c.take + result <- fiber.get + closed3 <- c.closed + open3 <- c.open + yield assert( + !closed1 && + !open1 && + !closed2 && + !open2 && + result && + closed3 
&& + !open3 + ) + } + + "returns false if channel is already closed" in run { + for + c <- Channel.init[Int](10) + _ <- c.close + result <- c.closeAwaitEmpty + yield assert(!result) + } + + "concurrent taking and waiting" in run { + for + c <- Channel.init[Int](10) + _ <- Kyo.foreach(1 to 5)(i => c.put(i)) + fiber <- c.closeAwaitEmptyFiber + _ <- Async.foreach(1 to 5)(_ => c.take) + result <- fiber.get + yield assert(result) + } + + "zero capacity channel" in run { + for + c <- Channel.init[Int](0) + result <- c.closeAwaitEmpty + yield assert(result) + } + + "should discard new takes" in run { + for + c <- Channel.init[Int](2) + _ <- c.put(1) + _ <- c.put(2) + fiber <- c.closeAwaitEmptyFiber + _ <- c.take + _ <- c.take + take <- Abort.run(c.take) + result <- fiber.get + yield assert(result && take.isFailure) + } + + "concurrent closeAwaitEmpty calls" in run { + for + c <- Channel.init[Int](10) + _ <- c.put(1) + _ <- c.put(2) + fiber <- Fiber.initUnscoped(Async.fill(10)(c.closeAwaitEmpty)) + _ <- c.take + _ <- c.take + closes <- fiber.get + yield assert(closes.count(identity) == 1) + } + } + "pendingPuts and pendingTakes" - { "should return 0 for empty channel" in run { for diff --git a/kyo-core/shared/src/test/scala/kyo/QueueTest.scala b/kyo-core/shared/src/test/scala/kyo/QueueTest.scala index eaca24eba..fda5a83d4 100644 --- a/kyo-core/shared/src/test/scala/kyo/QueueTest.scala +++ b/kyo-core/shared/src/test/scala/kyo/QueueTest.scala @@ -88,30 +88,43 @@ class QueueTest extends Test: } } - "close" in runNotNative { - for - q <- Queue.init[Int](2) - b <- q.offer(1) - c1 <- q.close - v1 <- Abort.run(q.size) - v2 <- Abort.run(q.empty) - v3 <- Abort.run(q.full) - v4 <- Abort.run(q.offer(2)) - v5 <- Abort.run(q.poll) - v6 <- Abort.run(q.peek) - v7 <- Abort.run(q.drain) - c2 <- q.close - yield assert( - b && c1 == Maybe(Seq(1)) && - v1.isFailure && - v2.isFailure && - v3.isFailure && - v4.isFailure && - v5.isFailure && - v6.isFailure && - v7.isFailure && - c2.isEmpty - ) + 
"close" - { + "allowed following ops" in runNotNative { + for + q <- Queue.init[Int](2) + b <- q.offer(1) + c1 <- q.close + v1 <- Abort.run(q.size) + v2 <- Abort.run(q.empty) + v3 <- Abort.run(q.full) + v4 <- Abort.run(q.offer(2)) + v5 <- Abort.run(q.poll) + v6 <- Abort.run(q.peek) + v7 <- Abort.run(q.drain) + c2 <- q.close + yield assert( + b && c1 == Maybe(Seq(1)) && + v1.isFailure && + v2.isFailure && + v3.isFailure && + v4.isFailure && + v5.isFailure && + v6.isFailure && + v7.isFailure && + c2.isEmpty + ) + } + "states" in runNotNative { + for + q <- Queue.init[Int](2) + closed1 <- q.closed + open1 <- q.open + _ <- q.offer(1) + _ <- q.close + closed2 <- q.closed + open2 <- q.open + yield assert(!closed1 && open1 && closed2 && !open2) + } } "drain" in runNotNative { @@ -585,19 +598,22 @@ class QueueTest extends Test: for queue <- Queue.init[Int](10) result <- queue.closeAwaitEmpty - yield assert(result) + closed <- queue.closed + yield assert(result && closed) } "returns true when queue becomes empty after closing" in runNotNative { for - queue <- Queue.init[Int](10) - _ <- queue.offer(1) - _ <- queue.offer(2) - fiber <- Fiber.initUnscoped(queue.closeAwaitEmpty) - _ <- queue.poll - _ <- queue.poll - result <- fiber.get - yield assert(result) + queue <- Queue.init[Int](10) + _ <- queue.offer(1) + _ <- queue.offer(2) + fiber <- Fiber.initUnscoped(queue.closeAwaitEmpty) + closed1 <- queue.closed + _ <- queue.poll + _ <- queue.poll + result <- fiber.get + closed2 <- queue.closed + yield assert(!closed1 && result && closed2) } "returns false if queue is already closed" in runNotNative { @@ -776,4 +792,213 @@ class QueueTest extends Test: } } + "closeAwaitEmptyFiber" - { + "returns true when queue is already empty" in runNotNative { + for + queue <- Queue.init[Int](10) + result <- queue.closeAwaitEmptyFiber.map(_.get) + closed <- queue.closed + open <- queue.open + yield assert(result && closed && !open) + } + + "returns true when queue becomes empty after closing" in 
runNotNative { + for + queue <- Queue.init[Int](10) + _ <- queue.offer(1) + _ <- queue.offer(2) + fiber <- queue.closeAwaitEmptyFiber + closed1 <- queue.closed + open1 <- queue.open + _ <- queue.poll + _ <- queue.poll + result <- fiber.get + closed2 <- queue.closed + open2 <- queue.open + yield assert( + !closed1 && + !open1 && + !open2 && + result && + closed2 && + !open2 + ) + } + + "returns false if queue is already closed" in runNotNative { + for + queue <- Queue.init[Int](10) + _ <- queue.close + result <- queue.closeAwaitEmptyFiber.map(_.get) + yield assert(!result) + } + + "unbounded queue" - { + "returns true when queue is already empty" in runNotNative { + for + queue <- Queue.Unbounded.init[Int]() + result <- queue.closeAwaitEmptyFiber.map(_.get) + yield assert(result) + } + + "returns true when queue becomes empty after closing" in runNotNative { + for + queue <- Queue.Unbounded.init[Int]() + _ <- queue.add(1) + _ <- queue.add(2) + fiber <- queue.closeAwaitEmptyFiber + _ <- queue.poll + _ <- queue.poll + result <- fiber.get + yield assert(result) + } + } + + "concurrent polling and waiting" in runNotNative { + for + queue <- Queue.init[Int](10) + _ <- Kyo.foreach(1 to 5)(i => queue.offer(i)) + fiber <- queue.closeAwaitEmptyFiber + _ <- Async.foreach(1 to 5)(_ => queue.poll) + result <- fiber.get + yield assert(result) + } + + "sliding queue" in runNotNative { + for + queue <- Queue.Unbounded.initSliding[Int](2) + _ <- queue.add(1) + _ <- queue.add(2) + fiber <- queue.closeAwaitEmptyFiber + _ <- queue.poll + _ <- queue.poll + result <- fiber.get + yield assert(result) + } + + "dropping queue" in runNotNative { + for + queue <- Queue.Unbounded.initDropping[Int](2) + _ <- queue.add(1) + _ <- queue.add(2) + fiber <- queue.closeAwaitEmptyFiber + _ <- queue.poll + _ <- queue.poll + result <- fiber.get + yield assert(result) + } + + "zero capacity queue" in runNotNative { + for + queue <- Queue.init[Int](0) + result <- queue.closeAwaitEmptyFiber.map(_.get) + 
yield assert(result) + } + + "race between closeAwaitEmpty and close" in runNotNative { + (for + size <- Choice.eval(0, 1, 2, 10, 100) + queue <- Queue.init[Int](size) + _ <- Kyo.foreach(1 to (size min 5))(i => queue.offer(i)) + latch <- Latch.init(1) + closeAwaitEmptyFiber <- Fiber.initUnscoped( + latch.await.andThen(queue.closeAwaitEmptyFiber.map(_.get)) + ) + closeFiber <- Fiber.initUnscoped( + latch.await.andThen(queue.close) + ) + _ <- latch.release + _ <- Abort.run(queue.drain) + result1 <- closeAwaitEmptyFiber.get + result2 <- closeFiber.get + isClosed <- queue.closed + yield + assert(isClosed) + assert((result1 && result2.isEmpty) || (!result1 && result2.isDefined)) + ) + .handle(Choice.run, _.unit, Loop.repeat(10)) + .andThen(succeed) + } + + "two producers calling closeAwaitEmpty" in runNotNative { + (for + size <- Choice.eval(0, 1, 2, 10, 100) + queue <- Queue.init[Int](size) + latch <- Latch.init(1) + + producerFiber1 <- Fiber.initUnscoped( + latch.await.andThen( + Async.foreach(1 to 25, 10)(i => Abort.run(queue.offer(i))) + .andThen(queue.closeAwaitEmptyFiber.map(_.get)) + ) + ) + producerFiber2 <- Fiber.initUnscoped( + latch.await.andThen( + Async.foreach(26 to 50, 10)(i => Abort.run(queue.offer(i))) + .andThen(queue.closeAwaitEmptyFiber.map(_.get)) + ) + ) + + consumerFiber <- Fiber.initUnscoped( + latch.await.andThen( + Async.fill(100, 10)(untilTrue(queue.poll.map(_.isDefined))) + ) + ) + + _ <- latch.release + result1 <- producerFiber1.getResult + result2 <- producerFiber2.getResult + isClosed <- queue.closed + _ <- consumerFiber.getResult + yield + assert(isClosed) + assert(Seq(result1, result2).count(_.contains(true)) == 1) + assert(Seq(result1, result2).count(r => r.contains(false) || r.isFailure) == 1) + ) + .handle(Choice.run, _.unit, Loop.repeat(10)) + .andThen(succeed) + } + + "producer calling closeAwaitEmpty and another calling close" in runNotNative { + (for + size <- Choice.eval(0, 1, 2, 10, 100) + queue <- Queue.init[Int](size) + latch 
<- Latch.init(1) + + producerFiber1 <- Fiber.initUnscoped( + latch.await.andThen( + Async.foreach(1 to 25, 10)(i => Abort.run(queue.offer(i))) + .andThen(queue.closeAwaitEmptyFiber.map(_.get)) + ) + ) + producerFiber2 <- Fiber.initUnscoped( + latch.await.andThen( + Async.foreach(26 to 50, 10)(i => Abort.run(queue.offer(i))) + .andThen(queue.close) + ) + ) + + consumerFiber <- Fiber.initUnscoped( + latch.await.andThen( + Async.fill(100, 10)(untilTrue(queue.poll.map(_.isDefined))) + ) + ) + + _ <- latch.release + result1 <- producerFiber1.getResult + result2 <- producerFiber2.getResult + isClosed <- queue.closed + _ <- consumerFiber.getResult + yield + assert(isClosed) + assert( + (result1.isFailure || result1.contains(false)) && !result2.contains(Absent) || + (result1.contains(true)) && result2.contains(Absent) + ) + ) + .handle(Choice.run, _.unit, Loop.repeat(10)) + .andThen(succeed) + } + } + end QueueTest diff --git a/kyo-grpc-code-gen/shared/src/main/scala/kyo/grpc/compiler/CodeGenerator.scala b/kyo-grpc-code-gen/shared/src/main/scala/kyo/grpc/compiler/CodeGenerator.scala new file mode 100644 index 000000000..ee1a3c5c6 --- /dev/null +++ b/kyo-grpc-code-gen/shared/src/main/scala/kyo/grpc/compiler/CodeGenerator.scala @@ -0,0 +1,66 @@ +package kyo.grpc.compiler + +import com.google.protobuf.Descriptors.FileDescriptor +import com.google.protobuf.ExtensionRegistry +import com.google.protobuf.compiler.PluginProtos.CodeGeneratorResponse +import kyo.grpc.compiler.internal.FilePrinter +import protocbridge.Artifact +import protocgen.CodeGenApp +import protocgen.CodeGenRequest +import protocgen.CodeGenResponse +import scala.jdk.CollectionConverters.* +import scalapb.compiler.DescriptorImplicits +import scalapb.compiler.GeneratorException +import scalapb.compiler.ProtobufGenerator +import scalapb.options.Scalapb + +object CodeGenerator extends CodeGenApp { + + override def registerExtensions(registry: ExtensionRegistry): Unit = + Scalapb.registerAllExtensions(registry) + + 
package kyo.grpc.compiler

import com.google.protobuf.Descriptors.FileDescriptor
import com.google.protobuf.ExtensionRegistry
import com.google.protobuf.compiler.PluginProtos.CodeGeneratorResponse
import kyo.grpc.compiler.internal.FilePrinter
import protocbridge.Artifact
import protocgen.CodeGenApp
import protocgen.CodeGenRequest
import protocgen.CodeGenResponse
import scala.jdk.CollectionConverters.*
import scalapb.compiler.DescriptorImplicits
import scalapb.compiler.GeneratorException
import scalapb.compiler.ProtobufGenerator
import scalapb.options.Scalapb

/** Protoc code generator that emits Kyo gRPC service traits and clients for each service
  * found in the compiled `.proto` files.
  */
object CodeGenerator extends CodeGenApp {

  /** Registers the ScalaPB extensions so ScalaPB file/service options can be read from the request. */
  override def registerExtensions(registry: ExtensionRegistry): Unit =
    Scalapb.registerAllExtensions(registry)

  // When this generator is invoked from SBT via sbt-protoc, this adds the following
  // artifact to the user's build whenever the generator is used in `PB.targets`.
  override def suggestedDependencies: Seq[Artifact] =
    Seq(
      Artifact(
        BuildInfo.organization,
        "kyo-grpc-core",
        BuildInfo.version,
        crossVersion = true
      )
    )

  // Called by CodeGenApp after the request is parsed.
  // Mirrors scalapb.compiler.ProtobufGenerator.handleCodeGeneratorRequest.
  def process(request: CodeGenRequest): CodeGenResponse =
    ProtobufGenerator.parseParameters(request.parameter) match {
      case Right(params) =>
        try {
          val implicits = DescriptorImplicits.fromCodeGenRequest(params, request)
          import implicits.ExtendedFileDescriptor
          // Honour ScalaPB's `single_file` option: either all services in one
          // generated file, or one generated file per service.
          val files = request.filesToGenerate.filterNot(_.disableOutput).flatMap { file =>
            if (file.scalaOptions.getSingleFile)
              Seq(singleFile(file, implicits))
            else
              multipleFiles(file, implicits)
          }
          CodeGenResponse.succeed(files, Set(CodeGeneratorResponse.Feature.FEATURE_PROTO3_OPTIONAL))
        } catch {
          case e: GeneratorException =>
            CodeGenResponse.fail(e.message)
        }
      case Left(error) =>
        CodeGenResponse.fail(error)
    }

  // Every service of the descriptor rendered into a single output file.
  private def singleFile(file: FileDescriptor, implicits: DescriptorImplicits) =
    file.getServices.asScala.foldLeft(FilePrinter(file, implicits).addPackage) { (fp, service) =>
      fp.addService(service)
    }.result

  // One output file per service, each named after its service.
  private def multipleFiles(file: FileDescriptor, implicits: DescriptorImplicits) =
    file.getServices.asScala.map { service =>
      FilePrinter(file, implicits).addPackage.addService(service).setNameFromService(service).result
    }
}
package kyo.grpc.compiler.internal

import scala.language.implicitConversions

/** Enumeration-like helper: a concrete `Choice` object exposes values of type [[A]], and a
  * [[Choose]] function selects one of them from the object itself, enabling call sites such as
  * `addMods(_.Override)`.
  */
private[compiler] trait Choice { self =>

  /** The type of the values that can be chosen. */
  type A

  /** A selector that, applied to this object, yields one of its values. */
  type Choose = self.type => A

  /** Evaluates a selector against this object, producing the chosen value. */
  implicit def makeChoice(choose: Choose): A = choose(self)

  implicit final class ChooseOps(choose: Choose) {
    /** The value selected by this chooser. */
    def choice: A = choose(self)
  }

  implicit final class ChoiceOps(choice: A) {
    /** Wraps an already-chosen value as a constant selector. */
    def choose: Choose = _ => choice
  }
}
package kyo.grpc.compiler.internal

import kyo.grpc.compiler.internal
import org.typelevel.paiges.Doc
import org.typelevel.paiges.internal.ExtendedSyntax.*
import scalapb.compiler.FunctionalPrinter.PrinterEndo

/** Immutable builder that renders a Scala `class` definition as a paiges [[Doc]].
  *
  * Each `append*`/`set*` method returns a copy with the extra pieces recorded; the final
  * document is produced by `result` (inherited from [[TemplateBuilder]]).
  */
final private[compiler] case class ClassBuilder(
    override val id: String,
    override val annotations: Vector[Doc] = Vector.empty,
    override val mods: Vector[Doc] = Vector.empty,
    typeParameters: Vector[String] = Vector.empty,
    parameterLists: Vector[Seq[Parameter]] = Vector.empty,
    implicitParameters: Vector[Parameter] = Vector.empty,
    override val parents: Vector[Doc] = Vector.empty,
    override val body: Doc = Doc.empty
) extends TemplateBuilder {

  override protected def keyword: String = "class"

  def appendAnnotations(annotations: Seq[String]): ClassBuilder =
    copy(annotations = this.annotations ++ annotations.map(Doc.text))

  def appendMods(mods: Seq[String]): ClassBuilder =
    copy(mods = this.mods ++ mods.map(Doc.text))

  def appendTypeParameters(params: Seq[String]): ClassBuilder =
    copy(typeParameters = typeParameters ++ params)

  def appendParameterList(params: Seq[Parameter]): ClassBuilder =
    copy(parameterLists = parameterLists :+ params)

  def appendImplicitParameters(params: Seq[Parameter]): ClassBuilder =
    copy(implicitParameters = implicitParameters ++ params)

  def appendParents(parents: Seq[String]): ClassBuilder =
    copy(parents = this.parents ++ parents.map(Doc.text))

  def setBody(body: PrinterEndo): ClassBuilder =
    setBody(printToDoc(body))

  def setBody(body: Doc): ClassBuilder =
    copy(body = body)

  /** Renders the type-parameter clause and all parameter lists (regular, then implicit). */
  override protected def preamble: Doc = {
    val typeParametersDocs = typeParameters.map(Doc.text)

    val typeParametersDoc = when(typeParametersDocs.nonEmpty)(
      "[" +: spreadList(typeParametersDocs) :+ "]"
    )

    val parameterListsDoc = internal.parameterLists(parameterLists)

    val implicitParametersDoc = when(implicitParameters.nonEmpty) {
      stackList(implicitParameters.map(parameter))
        .tightBracketRightBy(Doc.text("(implicit"), Doc.char(')'))
    }

    val allParameterListsDoc = (parameterListsDoc + implicitParametersDoc).regrouped

    typeParametersDoc + allParameterListsDoc
  }
}
package kyo.grpc.compiler.internal

import com.google.protobuf.Descriptors.*
import com.google.protobuf.compiler.PluginProtos.CodeGeneratorResponse.File
import scala.util.chaining.scalaUtilChainingOps
import scalapb.compiler.DescriptorImplicits
import scalapb.compiler.FunctionalPrinter
import scalapb.compiler.NameUtils

/** Accumulates the contents and metadata of one generated output file.
  *
  * Wraps a [[FunctionalPrinter]] for the source text and a protobuf [[File]] builder for the
  * response metadata (file name and content).
  */
private[compiler] case class FilePrinter(
    file: FileDescriptor,
    implicits: DescriptorImplicits,
    fp: FunctionalPrinter = new FunctionalPrinter(),
    builder: File.Builder = File.newBuilder()
) {

  import implicits.*

  /** Appends the `package` declaration derived from the descriptor's Scala package. */
  def addPackage: FilePrinter =
    copy(fp = fp.addPackage(file.scalaPackage.fullName).newline)

  /** Appends the service trait and its companion object for `service`. */
  def addService(service: ServiceDescriptor): FilePrinter =
    copy(fp = ServicePrinter(service, implicits, fp).addTrait.addObject.fp)

  /** Names the output after the service — used when one file is generated per service. */
  def setNameFromService(service: ServiceDescriptor): FilePrinter = {
    val dir  = file.scalaPackage.fullName.replace(".", "/")
    val name = NameUtils.snakeCaseToCamelCase(service.getName, upperInitial = true)
    copy(builder = builder.setName(s"$dir/$name.scala"))
  }

  /** Builds the response file; the name defaults to the descriptor's Scala file name. */
  def result: File =
    builder
      .pipe { b => if (b.getName.isEmpty) b.setName(file.scalaFileName) else b }
      .setContent(fp.result())
      .build()
}
package kyo.grpc.compiler.internal

import kyo.grpc.compiler.internal
import org.typelevel.paiges.Doc
import org.typelevel.paiges.internal.ExtendedSyntax.*
import scalapb.compiler.FunctionalPrinter.PrinterEndo

/** Immutable builder that renders a Scala `def` as a paiges [[Doc]].
  *
  * Supports annotations, modifiers, type parameters, multiple parameter lists, `implicit`
  * and `using` clauses, an optional return type, and an optional body.
  */
final private[compiler] case class MethodBuilder(
    id: String,
    annotations: Vector[Doc] = Vector.empty,
    mods: Vector[Doc] = Vector.empty,
    typeParameters: Vector[String] = Vector.empty,
    parameterLists: Vector[Seq[Parameter]] = Vector.empty,
    implicitParameters: Vector[Parameter] = Vector.empty,
    usingParameters: Vector[Parameter] = Vector.empty,
    returnType: Option[String] = None,
    body: Option[Doc] = None
) {

  def appendAnnotations(annotations: Seq[String]): MethodBuilder =
    copy(annotations = this.annotations ++ annotations.map(Doc.text))

  def appendMods(mods: Seq[String]): MethodBuilder =
    copy(mods = this.mods ++ mods.map(Doc.text))

  def appendTypeParameters(params: Seq[String]): MethodBuilder =
    copy(typeParameters = typeParameters ++ params)

  def appendParameterList(params: Seq[Parameter]): MethodBuilder =
    copy(parameterLists = parameterLists :+ params)

  def appendImplicitParameters(params: Seq[Parameter]): MethodBuilder =
    copy(implicitParameters = implicitParameters ++ params)

  def appendUsingParameters(params: Seq[Parameter]): MethodBuilder =
    copy(usingParameters = usingParameters ++ params)

  def setReturnType(returnType: String): MethodBuilder =
    copy(returnType = Some(returnType))

  def setBody(body: PrinterEndo): MethodBuilder =
    setBody(printToDoc(body))

  def setBody(body: Doc): MethodBuilder =
    copy(body = Some(body))

  /** Renders the complete method definition. */
  def result: Doc = {
    // Has trailing whitespace if non-empty.
    val annotationsDoc =
      if (annotations.isEmpty) Doc.empty
      else hardList(annotations) + Doc.hardLine

    val modPrefixDoc = when(mods.nonEmpty)(Doc.spread(mods) + Doc.space)

    val defNameDoc = Doc.text("def ") :+ id

    val typeParametersDocs = typeParameters.map(Doc.text)

    val typeParametersDoc = when(typeParametersDocs.nonEmpty)(
      "[" +: spreadList(typeParametersDocs) :+ "]"
    )

    val parameterListsDoc = internal.parameterLists(parameterLists)

    val implicitParametersDoc = when(implicitParameters.nonEmpty) {
      stackList(implicitParameters.map(parameter))
        .tightBracketRightBy(Doc.text("(implicit"), Doc.char(')'))
    }

    val usingParametersDoc = when(usingParameters.nonEmpty) {
      stackList(usingParameters.map(parameter))
        .tightBracketRightBy(Doc.text("(using"), Doc.char(')'))
    }

    val allParameterListsDoc = (parameterListsDoc + implicitParametersDoc + usingParametersDoc).regrouped

    val returnTypeDoc = returnType.fold(Doc.empty) { s =>
      Doc.text(": ") :+ s
    }

    val signatureDoc =
      (annotationsDoc +
        modPrefixDoc +
        defNameDoc +
        typeParametersDoc +
        allParameterListsDoc +
        returnTypeDoc).grouped

    body.fold(signatureDoc) { bodyDoc =>
      // Multi-line bodies are wrapped in braces; single-line bodies hang after `=`.
      val bracketedBodyDoc = {
        if (bodyDoc.containsHardLine) Doc.text(" {") + Doc.hardLine + bodyDoc.indent(INDENT) + Doc.hardLine + Doc.char('}')
        else bodyDoc.hangingUnsafe(INDENT)
      }
      (signatureDoc :+ " =") + bracketedBodyDoc
    }.grouped
  }
}
package kyo.grpc.compiler.internal

import org.typelevel.paiges.Doc
import scalapb.compiler.FunctionalPrinter.PrinterEndo

/** The modifiers that may be applied to generated definitions (see [[Choice]]). */
private[compiler] object Mod extends Choice {
  override type A = String

  val Case     = "case"
  val Override = "override"
  val Private  = "private"
}

/** Immutable builder that renders a Scala `object` definition as a paiges [[Doc]]. */
final private[compiler] case class ObjectBuilder(
    override val id: String,
    override val annotations: Vector[Doc] = Vector.empty,
    override val mods: Vector[Doc] = Vector.empty,
    override val parents: Vector[Doc] = Vector.empty,
    override val body: Doc = Doc.empty
) extends TemplateBuilder {

  override protected def keyword: String = "object"

  def appendAnnotations(annotations: Seq[String]): ObjectBuilder =
    copy(annotations = this.annotations ++ annotations.map(Doc.text))

  def appendMods(mods: Seq[String]): ObjectBuilder =
    copy(mods = this.mods ++ mods.map(Doc.text))

  def appendParents(parents: Seq[String]): ObjectBuilder =
    copy(parents = this.parents ++ parents.map(Doc.text))

  def setBody(body: PrinterEndo): ObjectBuilder =
    setBody(printToDoc(body))

  def setBody(body: Doc): ObjectBuilder =
    copy(body = body)
}
package kyo.grpc.compiler.internal

import scala.language.implicitConversions

/** A single (possibly anonymous) parameter of a generated method, class, or clause.
  *
  * @param typeName the rendered Scala type
  * @param name     the parameter name; absent for anonymous (type-only) parameters
  * @param default  the rendered default-value expression, if any
  */
final private[compiler] case class Parameter(typeName: String, name: Option[String] = None, default: Option[String] = None) {

  /** Returns a copy carrying the given default-value expression. */
  def :=(default: String): Parameter = copy(default = Some(default))
}

private[compiler] object Parameter {

  /** Convenience constructor with the name first, matching Scala declaration order. */
  def apply(name: String, typeName: String, default: Option[String]): Parameter =
    Parameter(typeName, Some(name), default)

  /** Allows a bare type name to stand for an anonymous parameter. */
  implicit def typeNameToParameter(typeName: String): Parameter = Parameter(typeName)
}
package kyo.grpc.compiler.internal

import com.google.protobuf.Descriptors.*
import scala.jdk.CollectionConverters.*
import scalapb.compiler.DescriptorImplicits
import scalapb.compiler.FunctionalPrinter
import scalapb.compiler.FunctionalPrinter.PrinterEndo
import scalapb.compiler.NameUtils
import scalapb.compiler.ProtobufGenerator.asScalaDocBlock
import scalapb.compiler.StreamType

/** Prints the Kyo service trait and its companion object (server binding, client factory,
  * client trait and implementation) for a single gRPC service descriptor.
  */
private[compiler] case class ServicePrinter(
    service: ServiceDescriptor,
    implicits: DescriptorImplicits,
    fp: FunctionalPrinter = new FunctionalPrinter()
) {

  import implicits.*

  private val name = NameUtils.snakeCaseToCamelCase(service.getName, upperInitial = true)

  // noinspection MutatorLikeMethodIsParameterless
  def addTrait: ServicePrinter =
    copy(fp = fp.call(printScalaDoc).call(printServiceTrait).newline)

  // noinspection MutatorLikeMethodIsParameterless
  def addObject: ServicePrinter =
    copy(fp = fp.call(printServiceObject).newline)

  /** The name of the grpc-stub factory method matching the method's stream shape.
    * Shared by the server handler registration and the client delegate call.
    */
  private def handlerName(method: MethodDescriptor): String =
    method.streamType match {
      case StreamType.Unary           => "unary"
      case StreamType.ClientStreaming => "clientStreaming"
      case StreamType.ServerStreaming => "serverStreaming"
      case StreamType.Bidirectional   => "bidiStreaming"
    }

  private def printScalaDoc: PrinterEndo = {
    val lines = asScalaDocBlock(service.comment.map(_.split('\n').toSeq).getOrElse(Seq.empty))
    _.add(lines*)
  }

  private def printServiceTrait: PrinterEndo =
    _.addTrait(name)
      .addAnnotations(Seq(service.deprecatedAnnotation).filter(_.nonEmpty)*)
      .addParents(Types.service)
      .addBody {
        _.newline
          .call(printServiceDefinitionMethod)
          .newline
          .print(service.methods) { (fp, md) =>
            printServiceMethod(md)(fp)
          }
      }

  private def printServiceDefinitionMethod: PrinterEndo =
    _.addMethod("definition")
      .addMods(_.Override)
      .addReturnType(Types.serverServiceDefinition)
      .addBody {
        _.add(s"$name.service(this)")
      }

  private def printServiceMethod(method: MethodDescriptor): PrinterEndo = {
    val parameters = serverMethodParameters(method)
    val returnType = serviceMethodReturnType(method)
    _.call(printScalaDoc(method))
      .addMethod(method.name)
      .addAnnotations(Seq(method.deprecatedAnnotation).filter(_.nonEmpty)*)
      .addParameterList(parameters*)
      .addReturnType(returnType)
  }

  private def printScalaDoc(method: MethodDescriptor): PrinterEndo = {
    val lines = asScalaDocBlock(method.comment.map(_.split('\n').toSeq).getOrElse(Seq.empty))
    _.add(lines*)
  }

  private def printServiceObject: PrinterEndo =
    _.addObject(name)
      .addAnnotations(Seq(service.deprecatedAnnotation).filter(_.nonEmpty)*)
      .addBody {
        _.newline
          .call(printServerMethod)
          .newline
          .call(printClientMethod)
          .newline
          .call(printManagedClientMethod)
          .newline
          .call(printClientTrait)
          .newline
          .call(printClientImpl)
      }

  private def printServerMethod: PrinterEndo = {
    val methods = service.methods.map(printAddMethod)
    _.addMethod("service")
      .addParameterList("serviceImpl" :- name)
      .addReturnType(Types.serverServiceDefinition)
      .addBody(
        _.add(s"""${Types.serverServiceDefinition}.builder(${service.grpcDescriptor.fullNameWithMaybeRoot})""")
          .call(methods*)
          .add(".build()")
      )
  }

  private def printAddMethod(method: MethodDescriptor): PrinterEndo = {
    val handler = s"${Types.serverCallHandlers}.${handlerName(method)}(serviceImpl.${method.name})"
    _.add(".addMethod(")
      .indented(
        _.add(s"${method.grpcDescriptor.fullNameWithMaybeRoot},")
          .add(handler)
      )
      .add(")")
  }

  private def printClientMethod: PrinterEndo =
    _.addMethod("client")
      .addParameterList( //
        "channel" :- Types.channel,
        "options" :- Types.callOptions := (Types.callOptions + ".DEFAULT")
      )
      .addReturnType("Client")
      .addBody(
        _.add("ClientImpl(channel, options)")
      )

  private def printManagedClientMethod: PrinterEndo =
    _.addMethod("managedClient")
      .addParameterList( //
        "host" :- Types.string,
        "port" :- Types.int,
        "timeout" :- Types.duration := s"${Types.duration}.fromUnits(30, ${Types.duration}.Units.Millis)",
        "options" :- Types.callOptions := (Types.callOptions + ".DEFAULT")
      )
      .addParameterList( //
        "configure" :- s"${Types.managedChannelBuilder("?")} => ${Types.managedChannelBuilder("?")}",
        "shutdown" :- s"(${Types.managedChannel}, ${Types.duration}) => ${Types.frame} ?=> ${Types.pending(Types.any, Types.sync)}" := s"${Types.client}.shutdown"
      )
      .addUsingParameters(Types.frame)
      .addReturnType(Types.pending("Client", s"${Types.scope} & ${Types.sync}"))
      .addBody(
        _.add(s"${Types.client}.channel(host, port, timeout)(configure, shutdown).map($name.client(_, options))")
      )

  private def printClientTrait: PrinterEndo =
    _.addTrait("Client")
      .addBody {
        _.print(service.methods) { (fp, md) =>
          printClientServiceMethod(md)(fp)
        }
      }

  private def printClientServiceMethod(method: MethodDescriptor): PrinterEndo = {
    val parameters = clientMethodParameters(method)
    val returnType = clientMethodReturnType(method)
    _.call(printScalaDoc(method))
      .addMethod(method.name)
      .addAnnotations(Seq(method.deprecatedAnnotation).filter(_.nonEmpty)*)
      .addParameterList(parameters*)
      .addReturnType(returnType)
  }

  private def printClientImpl: PrinterEndo =
    _.addClass("ClientImpl")
      .addParameterList( //
        "channel" :- Types.channel,
        "options" :- Types.callOptions
      )
      .addParents("Client")
      .addBody {
        _.print(service.methods) { (fp, md) =>
          printClientImplMethod(md)(fp)
        }
      }

  private def printClientImplMethod(method: MethodDescriptor): PrinterEndo = {
    val parameters = clientMethodParameters(method)
    val returnType = clientMethodReturnType(method)
    // Unary and server-streaming calls take a single request; the others take a request stream.
    val requestParameter = method.streamType match {
      case StreamType.Unary | StreamType.ServerStreaming         => "request"
      case StreamType.ClientStreaming | StreamType.Bidirectional => "requests"
    }
    val delegate = s"${handlerName(method)}(channel, ${method.grpcDescriptor.fullNameWithMaybeRoot}, options, $requestParameter)"
    _.call(printScalaDoc(method))
      .addMethod(method.name)
      .addAnnotations(Seq(method.deprecatedAnnotation).filter(_.nonEmpty)*)
      .addMods(_.Override)
      .addParameterList(parameters*)
      .addReturnType(returnType)
      .addBody(
        _.add(s"${Types.clientCall}.$delegate")
      )
  }

  private def clientMethodParameters(method: MethodDescriptor): Seq[Parameter] = {
    def requestParameter  = "request" :- Types.grpcRequestInit(method.inputType.scalaType)
    def requestsParameter = "requests" :- Types.grpcRequestsInit(Types.streamGrpcRequest(method.inputType.scalaType))
    method.streamType match {
      case StreamType.Unary | StreamType.ServerStreaming         => Seq(requestParameter)
      case StreamType.ClientStreaming | StreamType.Bidirectional => Seq(requestsParameter)
    }
  }

  private def serverMethodParameters(method: MethodDescriptor): Seq[Parameter] = {
    def requestParameter  = "request" :- method.inputType.scalaType
    def requestsParameter = "requests" :- Types.streamGrpcRequest(method.inputType.scalaType)
    method.streamType match {
      case StreamType.Unary | StreamType.ServerStreaming         => Seq(requestParameter)
      case StreamType.ClientStreaming | StreamType.Bidirectional => Seq(requestsParameter)
    }
  }

  private def clientMethodReturnType(method: MethodDescriptor): String = {
    method.streamType match {
      case StreamType.Unary | StreamType.ClientStreaming =>
        Types.pendingGrpcRequest(method.outputType.scalaType)
      case StreamType.ServerStreaming | StreamType.Bidirectional =>
        Types.streamGrpcRequest(method.outputType.scalaType)
    }
  }

  private def serviceMethodReturnType(method: MethodDescriptor): String = {
    method.streamType match {
      case StreamType.Unary | StreamType.ClientStreaming =>
        Types.pendingGrpcResponse(method.outputType.scalaType)
      case StreamType.ServerStreaming | StreamType.Bidirectional =>
        Types.pendingGrpcResponse(Types.streamGrpcResponse(method.outputType.scalaType))
    }
  }
}
package kyo.grpc.compiler.internal

import org.typelevel.paiges.Doc

/** Common rendering logic for template definitions (`class`, `trait`, `object`). */
private[compiler] trait TemplateBuilder {

  def annotations: Iterable[Doc]
  def mods: Iterable[Doc]
  def id: String
  def parents: Iterable[Doc]
  def body: Doc

  /** The introducing keyword, e.g. `"class"`. */
  protected def keyword: String

  /** The part between the id and the template.
    *
    * It must contain leading whitespace and no trailing whitespace.
    */
  protected def preamble: Doc = Doc.empty

  /** Renders the full definition: annotations, mods, keyword, id, preamble, parents, body. */
  def result: Doc = {
    // Has trailing whitespace if non-empty.
    val annotationsDoc =
      if (annotations.isEmpty) Doc.empty
      else hardList(annotations) + Doc.hardLine

    val modsDoc = when(mods.nonEmpty)(Doc.spread(mods) + Doc.space)

    val idDoc = Doc.text(s"$keyword ") :+ id

    // Has leading whitespace.
    val parentsDoc = extendsList(parents)

    val headerDoc =
      (annotationsDoc +
        modsDoc +
        idDoc +
        preamble +
        parentsDoc).grouped

    if (body.isEmpty) headerDoc
    else (headerDoc :+ " {") + Doc.hardLine + body.indent(INDENT) + (Doc.hardLine :+ "}")
  }
}
package kyo.grpc.compiler.internal

import org.typelevel.paiges.Doc
import scalapb.compiler.FunctionalPrinter.PrinterEndo

/** Immutable builder that renders a Scala `trait` definition as a paiges [[Doc]]. */
final private[compiler] case class TraitBuilder(
    override val id: String,
    override val annotations: Vector[Doc] = Vector.empty,
    override val mods: Vector[Doc] = Vector.empty,
    typeParameters: Vector[String] = Vector.empty,
    override val parents: Vector[Doc] = Vector.empty,
    override val body: Doc = Doc.empty
) extends TemplateBuilder {

  override protected def keyword: String = "trait"

  def appendAnnotations(annotations: Seq[String]): TraitBuilder =
    copy(annotations = this.annotations ++ annotations.map(Doc.text))

  def appendMods(mods: Seq[String]): TraitBuilder =
    copy(mods = this.mods ++ mods.map(Doc.text))

  def appendTypeParameters(params: Seq[String]): TraitBuilder =
    copy(typeParameters = typeParameters ++ params)

  def appendParents(parents: Seq[String]): TraitBuilder =
    copy(parents = this.parents ++ parents.map(Doc.text))

  def setBody(body: PrinterEndo): TraitBuilder =
    setBody(printToDoc(body))

  def setBody(body: Doc): TraitBuilder =
    copy(body = body)

  /** Renders the type-parameter clause, if any. */
  override protected def preamble: Doc = {
    val typeParametersDocs = typeParameters.map(Doc.text)

    when(typeParametersDocs.nonEmpty)(
      "[" +: spreadList(typeParametersDocs) :+ "]"
    )
  }
}
package kyo.grpc.compiler.internal

/** Fully-qualified (`_root_`-anchored) type names spliced into generated source text. */
private[compiler] object Types {

  val any = "Any"

  val int = "Int"

  val string = "String"

  /** `t` pending effects `s`, i.e. kyo's `<` type. */
  def pending(t: String, s: String) = s"_root_.kyo.<[$t, $s]"

  def pendingGrpcRequest(t: String) = pending(t, grpcRequest)

  def pendingGrpcResponse(t: String) = pending(t, grpcResponse)

  val duration = "_root_.kyo.Duration"

  val frame = "_root_.kyo.Frame"

  val sync = "_root_.kyo.Sync"

  val scope = "_root_.kyo.Scope"

  def streamGrpcResponse(t: String) = s"_root_.kyo.Stream[$t, $grpcResponse]"

  def streamGrpcRequest(t: String) = s"_root_.kyo.Stream[$t, $grpcRequest]"

  val client = "_root_.kyo.grpc.Client"

  val clientCall = "_root_.kyo.grpc.ClientCall"

  // Requests and responses currently share the same Grpc effect; two names are
  // kept so use sites read correctly and the types can diverge later.
  val grpcResponse = "_root_.kyo.grpc.Grpc"

  val grpcRequest = "_root_.kyo.grpc.Grpc"

  def grpcRequestInit(request: String) = s"_root_.kyo.grpc.GrpcRequestInit[$request]"

  def grpcRequestsInit(request: String) = s"_root_.kyo.grpc.GrpcRequestsInit[$request]"

  val serverCallHandlers = "_root_.kyo.grpc.ServerCallHandlers"

  val service = "_root_.kyo.grpc.Service"

  val callOptions = "_root_.io.grpc.CallOptions"

  val channel = "_root_.io.grpc.Channel"

  val managedChannel = "_root_.io.grpc.ManagedChannel"

  def managedChannelBuilder(t: String) = s"_root_.io.grpc.ManagedChannelBuilder[$t]"

  val serverServiceDefinition = "_root_.io.grpc.ServerServiceDefinition"

}
builder.setBody(body)) + + def addBodyDoc(body: Doc): AddClassDsl = + copy(builder = builder.setBody(body)) + + def endClass: FunctionalPrinter = fp.addDoc(builder.result) + } + + private[compiler] object AddClassDsl { + implicit def endClass(dsl: AddClassDsl): FunctionalPrinter = dsl.endClass + } + + final private[compiler] case class AddObjectDsl(builder: ObjectBuilder, fp: FunctionalPrinter) { + + def addAnnotations(annotations: String*): AddObjectDsl = + copy(builder = builder.appendAnnotations(annotations)) + + def addMods(mods: Mod.Choose*): AddObjectDsl = + copy(builder = builder.appendMods(mods.map(_.choice))) + + def addParents(params: String*): AddObjectDsl = + copy(builder = builder.appendParents(params)) + + def addBody(body: PrinterEndo): AddObjectDsl = + copy(builder = builder.setBody(body)) + + def addBodyDoc(body: Doc): AddObjectDsl = + copy(builder = builder.setBody(body)) + + def endObject: FunctionalPrinter = fp.addDoc(builder.result) + } + + private[compiler] object AddObjectDsl { + implicit def endObject(dsl: AddObjectDsl): FunctionalPrinter = dsl.endObject + } + + final private[compiler] case class AddTraitDsl(builder: TraitBuilder, fp: FunctionalPrinter) { + + def addAnnotations(annotations: String*): AddTraitDsl = + copy(builder = builder.appendAnnotations(annotations)) + + def addMods(mods: Mod.Choose*): AddTraitDsl = + copy(builder = builder.appendMods(mods.map(_.choice))) + + def addTypeParameters(params: String*): AddTraitDsl = + copy(builder = builder.appendTypeParameters(params)) + + def addParents(params: String*): AddTraitDsl = + copy(builder = builder.appendParents(params)) + + def addBody(body: PrinterEndo): AddTraitDsl = + copy(builder = builder.setBody(body)) + + def addBodyDoc(body: Doc): AddTraitDsl = + copy(builder = builder.setBody(body)) + + def endTrait: FunctionalPrinter = fp.addDoc(builder.result) + } + + private[compiler] object AddTraitDsl { + implicit def endTrait(dsl: AddTraitDsl): FunctionalPrinter = dsl.endTrait + } 
+ + final private[compiler] case class AddMethodDsl(builder: MethodBuilder, fp: FunctionalPrinter) { + + def addAnnotations(annotations: String*): AddMethodDsl = + copy(builder = builder.appendAnnotations(annotations)) + + def addMods(mods: Mod.Choose*): AddMethodDsl = + copy(builder = builder.appendMods(mods.map(_.choice))) + + def addTypeParameters(params: String*): AddMethodDsl = + copy(builder = builder.appendTypeParameters(params)) + + def addParameterList(params: Parameter*): AddMethodDsl = + copy(builder = builder.appendParameterList(params)) + + def addImplicitParameters(params: Parameter*): AddMethodDsl = + copy(builder = builder.appendImplicitParameters(params)) + + def addUsingParameters(params: Parameter*): AddMethodDsl = + copy(builder = builder.appendUsingParameters(params)) + + def addReturnType(returnType: String): AddMethodDsl = + copy(builder = builder.setReturnType(returnType)) + + def addBody(body: PrinterEndo): AddMethodDsl = + copy(builder = builder.setBody(body)) + + def addBodyDoc(body: Doc): AddMethodDsl = + copy(builder = builder.setBody(body)) + + def endMethod: FunctionalPrinter = fp.addDoc(builder.result) + } + + private[compiler] object AddMethodDsl { + implicit def endMethod(dsl: AddMethodDsl): FunctionalPrinter = dsl.endMethod + } + + implicit private[compiler] class ScalaFunctionalPrinterOps(val fp: FunctionalPrinter) extends AnyVal { + + def addPackage(id: String): FunctionalPrinter = + fp.add(s"package $id") + + def addClass(id: String): AddClassDsl = + AddClassDsl(ClassBuilder(id), fp) + + def addObject(id: String): AddObjectDsl = + AddObjectDsl(ObjectBuilder(id), fp) + + def addTrait(id: String): AddTraitDsl = + AddTraitDsl(TraitBuilder(id), fp) + + def addMethod(id: String): AddMethodDsl = + AddMethodDsl(MethodBuilder(id), fp) + + def append(s: String): FunctionalPrinter = { + val lastIndex = fp.content.size - 1 + if (lastIndex > 0) fp.copy(content = fp.content.updated(lastIndex, fp.content(lastIndex) + s)) + else fp.add(s) + } 
+ + def addDoc(doc: Doc): FunctionalPrinter = + fp.add(doc.render(WIDTH)) + } + + implicit private[compiler] class StringParameterOps(val parameterName: String) extends AnyVal { + + def :-(typeName: String): Parameter = Parameter(parameterName, typeName, None) + } + + private[compiler] def when(condition: Boolean)(doc: => Doc): Doc = + if (condition) doc else Doc.empty + + private[compiler] def hardList(docs: Iterable[Doc]): Doc = + Doc.intercalate(Doc.hardLine, docs) + + private[compiler] def stackList(docs: Iterable[Doc]): Doc = + Doc.intercalate(Doc.char(',') + Doc.line, docs) + + private[compiler] def spreadList(docs: Iterable[Doc]): Doc = + Doc.intercalate(Doc.text(", "), docs) + + private[compiler] def extendsList(docs: Iterable[Doc]): Doc = + when(docs.nonEmpty) { + (Doc.text("extends ") + Doc.intercalate(Doc.line + Doc.text("with "), docs)).hangingUnsafe(INDENT * 2) + } + + private[compiler] def typedName(name: Option[String], tpe: String): Doc = + name match { + case None => Doc.text(tpe) + case Some(n) => n +: (Doc.text(": ") + Doc.text(tpe)) + } + + private[compiler] def parameter(parameter: Parameter): Doc = + typedName(parameter.name, parameter.typeName) + + parameter.default.fold(Doc.empty)(default => Doc.text(" = ") + Doc.text(default)) + + private[compiler] def parameterLists(parameterss: Vector[Seq[Parameter]]): Doc = + when(parameterss.nonEmpty) { + val parametersDocs = parameterss + .map(_.map(parameter)) + .map(stackList) + .map(_.tightBracketBy(Doc.char('('), Doc.char(')'))) + Doc.cat(parametersDocs) + } + + private[compiler] def printToDoc(f: PrinterEndo): Doc = + Docx.literal(f(new FunctionalPrinter()).result()) +} diff --git a/kyo-grpc-code-gen/shared/src/main/scala/kyo/grpc/gen.scala b/kyo-grpc-code-gen/shared/src/main/scala/kyo/grpc/gen.scala new file mode 100644 index 000000000..8c72233b8 --- /dev/null +++ b/kyo-grpc-code-gen/shared/src/main/scala/kyo/grpc/gen.scala @@ -0,0 +1,23 @@ +package kyo.grpc + +import protocbridge.Artifact 
+import protocbridge.SandboxedJvmGenerator +import scalapb.GeneratorOption + +object gen { + def apply(options: GeneratorOption*): (SandboxedJvmGenerator, Seq[String]) = ( + SandboxedJvmGenerator.forModule( + "scala", + Artifact( + kyo.grpc.compiler.BuildInfo.organization, + "kyo-grpc-code-gen_2.12", + kyo.grpc.compiler.BuildInfo.version + ), + "kyo.grpc.compiler.CodeGenerator$", + kyo.grpc.compiler.CodeGenerator.suggestedDependencies + ), + options.map(_.toString) + ) + + def apply(options: Set[GeneratorOption] = Set.empty): (SandboxedJvmGenerator, Seq[String]) = apply(options.toSeq*) +} diff --git a/kyo-grpc-code-gen/shared/src/main/scala/org/typelevel/paiges/internal/Docx.scala b/kyo-grpc-code-gen/shared/src/main/scala/org/typelevel/paiges/internal/Docx.scala new file mode 100644 index 000000000..07515266b --- /dev/null +++ b/kyo-grpc-code-gen/shared/src/main/scala/org/typelevel/paiges/internal/Docx.scala @@ -0,0 +1,44 @@ +package org.typelevel.paiges.internal + +import org.typelevel.paiges.* +import org.typelevel.paiges.Doc.* +import scala.annotation.tailrec + +// Workaround for https://github.com/typelevel/paiges/issues/628. +object Docx { + + import ExtendedSyntax.* + + /** The width of `left` must be shorter or equal to this doc. + */ + def bracketIfMultiline(left: Doc, doc: Doc, right: Doc, indent: Int = 2): Doc = + doc.bracketIfMultiline(left, right, indent) + + /** Unsafe as it violates the invariant of FlatAlt `width(default) <= width(whenFlat)`. + */ + private[paiges] def orEmpty(doc: Doc): Doc = { + if (doc.isEmpty) doc + else FlatAlt(doc, Doc.empty) + } + + def literal(str: String): Doc = { + def tx(i: Int, j: Int): Doc = + if (i == j) Empty + else if (i == j - 1) Doc.char(str.charAt(i)) + else Text(str.substring(i, j)) + + // parse the string right-to-left, splitting at newlines. + // this ensures that our concatenations are right-associated. 
+ @tailrec def parse(i: Int, limit: Int, doc: Doc): Doc = + if (i < 0) { + val next = tx(0, limit) + if (doc.isEmpty) next else next + doc + } else + str.charAt(i) match { + case '\n' => parse(i - 1, i, hardLine + (tx(i + 1, limit) + doc)) + case _ => parse(i - 1, limit, doc) + } + + parse(str.length - 1, str.length, Empty) + } +} diff --git a/kyo-grpc-code-gen/shared/src/main/scala/org/typelevel/paiges/internal/ExtendedSyntax.scala b/kyo-grpc-code-gen/shared/src/main/scala/org/typelevel/paiges/internal/ExtendedSyntax.scala new file mode 100644 index 000000000..fe69e3af2 --- /dev/null +++ b/kyo-grpc-code-gen/shared/src/main/scala/org/typelevel/paiges/internal/ExtendedSyntax.scala @@ -0,0 +1,57 @@ +package org.typelevel.paiges.internal + +import org.typelevel.paiges.* +import org.typelevel.paiges.Doc.* + +// Workaround for https://github.com/typelevel/paiges/issues/628. +object ExtendedSyntax { + + implicit class DocOps(val doc: Doc) extends AnyVal { + + /** The width of `left` must be shorter or equal to this doc. + */ + def bracketIfMultiline(left: Doc, right: Doc, indent: Int = 2): Doc = { + // Long means not completely flat. orEmpty alone is not enough as a doc with hard lines is considered flat. + // Flat needs to be broken down into minimal vs single-line. 
+ if (doc.containsHardLine) left + Doc.hardLine + doc.indent(indent) + Doc.hardLine + right + else ((Docx.orEmpty(left + Doc.hardLine) + doc).nested(indent) + Docx.orEmpty(Doc.hardLine + right)).grouped + } + + def tightBracketLeftBy(left: Doc, right: Doc, indent: Int = 2): Doc = + Concat(left, Concat(Concat(Doc.lineBreak, doc).nested(indent), Concat(Doc.line, right)).grouped) + + def tightBracketRightBy(left: Doc, right: Doc, indent: Int = 2): Doc = + Concat(left, Concat(Concat(Doc.line, doc).nested(indent), Concat(Doc.lineBreak, right)).grouped) + + def ungrouped: Doc = + doc match { + case Union(_, b) => b + case FlatAlt(a, b) => FlatAlt(a.ungrouped, Doc.defer(b.ungrouped)) + case Concat(a, b) => Concat(a.ungrouped, b.ungrouped) + case Nest(i, d) => Nest(i, d.ungrouped) + case d @ LazyDoc(_) => Doc.defer(d.evaluated.ungrouped) + case Align(d) => Align(d.ungrouped) + case ZeroWidth(_) | Text(_) | Empty | Line => doc + } + + def regrouped: Doc = + doc.ungrouped.grouped + + def containsHardLine: Boolean = + doc match { + case Line => true + case ZeroWidth(_) | Text(_) | Empty => false + case Nest(_, d) => d.containsHardLine + case Align(d) => d.containsHardLine + case FlatAlt(_, b) => b.containsHardLine + case Union(a, _) => a.containsHardLine + case Concat(a, b) => a.containsHardLine || b.containsHardLine + case d @ LazyDoc(_) => d.evaluated.containsHardLine + } + + // This is unsafe as it violates the invariants of FlatAlt, but it seems to be OK for it is used. 
+ def hangingUnsafe(i: Int, sep: Doc = Doc.space): Doc = { + FlatAlt(Doc.hardLine + Doc.spaces(i) + doc.aligned, sep + doc.flatten).grouped + } + } +} diff --git a/kyo-grpc-code-gen/shared/src/test/resources/output/multiple-files-1 b/kyo-grpc-code-gen/shared/src/test/resources/output/multiple-files-1 new file mode 100644 index 000000000..40d4812d5 --- /dev/null +++ b/kyo-grpc-code-gen/shared/src/test/resources/output/multiple-files-1 @@ -0,0 +1,94 @@ +package kgrpc.test + +trait TestService extends _root_.kyo.grpc.Service { + + override def definition: _root_.io.grpc.ServerServiceDefinition = TestService.service(this) + + def oneToOne(request: kgrpc.test.Request): _root_.kyo.<[kgrpc.test.Response, _root_.kyo.grpc.Grpc] + def oneToMany( + request: kgrpc.test.Request + ): _root_.kyo.<[_root_.kyo.Stream[kgrpc.test.Response, _root_.kyo.grpc.Grpc], _root_.kyo.grpc.Grpc] + def manyToOne( + requests: _root_.kyo.Stream[kgrpc.test.Request, _root_.kyo.grpc.Grpc] + ): _root_.kyo.<[kgrpc.test.Response, _root_.kyo.grpc.Grpc] + def manyToMany( + requests: _root_.kyo.Stream[kgrpc.test.Request, _root_.kyo.grpc.Grpc] + ): _root_.kyo.<[_root_.kyo.Stream[kgrpc.test.Response, _root_.kyo.grpc.Grpc], _root_.kyo.grpc.Grpc] +} + +object TestService { + + def service(serviceImpl: TestService): _root_.io.grpc.ServerServiceDefinition = { + _root_.io.grpc.ServerServiceDefinition.builder(_root_.kgrpc.test.TestServiceGrpc.SERVICE) + .addMethod( + _root_.kgrpc.test.TestServiceGrpc.METHOD_ONE_TO_ONE, + _root_.kyo.grpc.ServerCallHandlers.unary(serviceImpl.oneToOne) + ) + .addMethod( + _root_.kgrpc.test.TestServiceGrpc.METHOD_ONE_TO_MANY, + _root_.kyo.grpc.ServerCallHandlers.serverStreaming(serviceImpl.oneToMany) + ) + .addMethod( + _root_.kgrpc.test.TestServiceGrpc.METHOD_MANY_TO_ONE, + _root_.kyo.grpc.ServerCallHandlers.clientStreaming(serviceImpl.manyToOne) + ) + .addMethod( + _root_.kgrpc.test.TestServiceGrpc.METHOD_MANY_TO_MANY, + 
_root_.kyo.grpc.ServerCallHandlers.bidiStreaming(serviceImpl.manyToMany) + ) + .build() + } + + def client( + channel: _root_.io.grpc.Channel, + options: _root_.io.grpc.CallOptions = _root_.io.grpc.CallOptions.DEFAULT + ): Client = ClientImpl(channel, options) + + def managedClient( + host: String, + port: Int, + timeout: _root_.kyo.Duration = _root_.kyo.Duration.fromUnits(30, _root_.kyo.Duration.Units.Millis), + options: _root_.io.grpc.CallOptions = _root_.io.grpc.CallOptions.DEFAULT + )( + configure: _root_.io.grpc.ManagedChannelBuilder[?] => _root_.io.grpc.ManagedChannelBuilder[?], + shutdown: (_root_.io.grpc.ManagedChannel, _root_.kyo.Duration) => _root_.kyo.Frame ?=> _root_.kyo.<[Any, _root_.kyo.Sync] = _root_.kyo.grpc.Client.shutdown + )(using + _root_.kyo.Frame + ): _root_.kyo.<[Client, _root_.kyo.Scope & _root_.kyo.Sync] = + _root_.kyo.grpc.Client.channel(host, port, timeout)(configure, shutdown).map(TestService.client(_, options)) + + trait Client { + def oneToOne( + request: _root_.kyo.grpc.GrpcRequestInit[kgrpc.test.Request] + ): _root_.kyo.<[kgrpc.test.Response, _root_.kyo.grpc.Grpc] + def oneToMany( + request: _root_.kyo.grpc.GrpcRequestInit[kgrpc.test.Request] + ): _root_.kyo.Stream[kgrpc.test.Response, _root_.kyo.grpc.Grpc] + def manyToOne( + requests: _root_.kyo.grpc.GrpcRequestsInit[_root_.kyo.Stream[kgrpc.test.Request, _root_.kyo.grpc.Grpc]] + ): _root_.kyo.<[kgrpc.test.Response, _root_.kyo.grpc.Grpc] + def manyToMany( + requests: _root_.kyo.grpc.GrpcRequestsInit[_root_.kyo.Stream[kgrpc.test.Request, _root_.kyo.grpc.Grpc]] + ): _root_.kyo.Stream[kgrpc.test.Response, _root_.kyo.grpc.Grpc] + } + + class ClientImpl(channel: _root_.io.grpc.Channel, options: _root_.io.grpc.CallOptions) + extends Client { + override def oneToOne( + request: _root_.kyo.grpc.GrpcRequestInit[kgrpc.test.Request] + ): _root_.kyo.<[kgrpc.test.Response, _root_.kyo.grpc.Grpc] = + _root_.kyo.grpc.ClientCall.unary(channel, _root_.kgrpc.test.TestServiceGrpc.METHOD_ONE_TO_ONE, 
options, request) + override def oneToMany( + request: _root_.kyo.grpc.GrpcRequestInit[kgrpc.test.Request] + ): _root_.kyo.Stream[kgrpc.test.Response, _root_.kyo.grpc.Grpc] = + _root_.kyo.grpc.ClientCall.serverStreaming(channel, _root_.kgrpc.test.TestServiceGrpc.METHOD_ONE_TO_MANY, options, request) + override def manyToOne( + requests: _root_.kyo.grpc.GrpcRequestsInit[_root_.kyo.Stream[kgrpc.test.Request, _root_.kyo.grpc.Grpc]] + ): _root_.kyo.<[kgrpc.test.Response, _root_.kyo.grpc.Grpc] = + _root_.kyo.grpc.ClientCall.clientStreaming(channel, _root_.kgrpc.test.TestServiceGrpc.METHOD_MANY_TO_ONE, options, requests) + override def manyToMany( + requests: _root_.kyo.grpc.GrpcRequestsInit[_root_.kyo.Stream[kgrpc.test.Request, _root_.kyo.grpc.Grpc]] + ): _root_.kyo.Stream[kgrpc.test.Response, _root_.kyo.grpc.Grpc] = + _root_.kyo.grpc.ClientCall.bidiStreaming(channel, _root_.kgrpc.test.TestServiceGrpc.METHOD_MANY_TO_MANY, options, requests) + } +} diff --git a/kyo-grpc-code-gen/shared/src/test/resources/output/multiple-files-2 b/kyo-grpc-code-gen/shared/src/test/resources/output/multiple-files-2 new file mode 100644 index 000000000..4026cc568 --- /dev/null +++ b/kyo-grpc-code-gen/shared/src/test/resources/output/multiple-files-2 @@ -0,0 +1,80 @@ +package kgrpc.test + +trait UtilityService extends _root_.kyo.grpc.Service { + + override def definition: _root_.io.grpc.ServerServiceDefinition = UtilityService.service(this) + + def health(request: kgrpc.test.Request): _root_.kyo.<[kgrpc.test.Response, _root_.kyo.grpc.Grpc] + def monitor( + request: kgrpc.test.Request + ): _root_.kyo.<[_root_.kyo.Stream[kgrpc.test.Response, _root_.kyo.grpc.Grpc], _root_.kyo.grpc.Grpc] + def batch( + requests: _root_.kyo.Stream[kgrpc.test.Request, _root_.kyo.grpc.Grpc] + ): _root_.kyo.<[kgrpc.test.Response, _root_.kyo.grpc.Grpc] +} + +object UtilityService { + + def service(serviceImpl: UtilityService): _root_.io.grpc.ServerServiceDefinition = { + 
_root_.io.grpc.ServerServiceDefinition.builder(_root_.kgrpc.test.UtilityServiceGrpc.SERVICE) + .addMethod( + _root_.kgrpc.test.UtilityServiceGrpc.METHOD_HEALTH, + _root_.kyo.grpc.ServerCallHandlers.unary(serviceImpl.health) + ) + .addMethod( + _root_.kgrpc.test.UtilityServiceGrpc.METHOD_MONITOR, + _root_.kyo.grpc.ServerCallHandlers.serverStreaming(serviceImpl.monitor) + ) + .addMethod( + _root_.kgrpc.test.UtilityServiceGrpc.METHOD_BATCH, + _root_.kyo.grpc.ServerCallHandlers.clientStreaming(serviceImpl.batch) + ) + .build() + } + + def client( + channel: _root_.io.grpc.Channel, + options: _root_.io.grpc.CallOptions = _root_.io.grpc.CallOptions.DEFAULT + ): Client = ClientImpl(channel, options) + + def managedClient( + host: String, + port: Int, + timeout: _root_.kyo.Duration = _root_.kyo.Duration.fromUnits(30, _root_.kyo.Duration.Units.Millis), + options: _root_.io.grpc.CallOptions = _root_.io.grpc.CallOptions.DEFAULT + )( + configure: _root_.io.grpc.ManagedChannelBuilder[?] => _root_.io.grpc.ManagedChannelBuilder[?], + shutdown: (_root_.io.grpc.ManagedChannel, _root_.kyo.Duration) => _root_.kyo.Frame ?=> _root_.kyo.<[Any, _root_.kyo.Sync] = _root_.kyo.grpc.Client.shutdown + )(using + _root_.kyo.Frame + ): _root_.kyo.<[Client, _root_.kyo.Scope & _root_.kyo.Sync] = + _root_.kyo.grpc.Client.channel(host, port, timeout)(configure, shutdown).map(UtilityService.client(_, options)) + + trait Client { + def health( + request: _root_.kyo.grpc.GrpcRequestInit[kgrpc.test.Request] + ): _root_.kyo.<[kgrpc.test.Response, _root_.kyo.grpc.Grpc] + def monitor( + request: _root_.kyo.grpc.GrpcRequestInit[kgrpc.test.Request] + ): _root_.kyo.Stream[kgrpc.test.Response, _root_.kyo.grpc.Grpc] + def batch( + requests: _root_.kyo.grpc.GrpcRequestsInit[_root_.kyo.Stream[kgrpc.test.Request, _root_.kyo.grpc.Grpc]] + ): _root_.kyo.<[kgrpc.test.Response, _root_.kyo.grpc.Grpc] + } + + class ClientImpl(channel: _root_.io.grpc.Channel, options: _root_.io.grpc.CallOptions) + extends Client { + 
override def health( + request: _root_.kyo.grpc.GrpcRequestInit[kgrpc.test.Request] + ): _root_.kyo.<[kgrpc.test.Response, _root_.kyo.grpc.Grpc] = + _root_.kyo.grpc.ClientCall.unary(channel, _root_.kgrpc.test.UtilityServiceGrpc.METHOD_HEALTH, options, request) + override def monitor( + request: _root_.kyo.grpc.GrpcRequestInit[kgrpc.test.Request] + ): _root_.kyo.Stream[kgrpc.test.Response, _root_.kyo.grpc.Grpc] = + _root_.kyo.grpc.ClientCall.serverStreaming(channel, _root_.kgrpc.test.UtilityServiceGrpc.METHOD_MONITOR, options, request) + override def batch( + requests: _root_.kyo.grpc.GrpcRequestsInit[_root_.kyo.Stream[kgrpc.test.Request, _root_.kyo.grpc.Grpc]] + ): _root_.kyo.<[kgrpc.test.Response, _root_.kyo.grpc.Grpc] = + _root_.kyo.grpc.ClientCall.clientStreaming(channel, _root_.kgrpc.test.UtilityServiceGrpc.METHOD_BATCH, options, requests) + } +} diff --git a/kyo-grpc-code-gen/shared/src/test/resources/output/single-file b/kyo-grpc-code-gen/shared/src/test/resources/output/single-file new file mode 100644 index 000000000..d271822f1 --- /dev/null +++ b/kyo-grpc-code-gen/shared/src/test/resources/output/single-file @@ -0,0 +1,173 @@ +package kgrpc.test + +trait TestService extends _root_.kyo.grpc.Service { + + override def definition: _root_.io.grpc.ServerServiceDefinition = TestService.service(this) + + def oneToOne(request: kgrpc.test.Request): _root_.kyo.<[kgrpc.test.Response, _root_.kyo.grpc.Grpc] + def oneToMany( + request: kgrpc.test.Request + ): _root_.kyo.<[_root_.kyo.Stream[kgrpc.test.Response, _root_.kyo.grpc.Grpc], _root_.kyo.grpc.Grpc] + def manyToOne( + requests: _root_.kyo.Stream[kgrpc.test.Request, _root_.kyo.grpc.Grpc] + ): _root_.kyo.<[kgrpc.test.Response, _root_.kyo.grpc.Grpc] + def manyToMany( + requests: _root_.kyo.Stream[kgrpc.test.Request, _root_.kyo.grpc.Grpc] + ): _root_.kyo.<[_root_.kyo.Stream[kgrpc.test.Response, _root_.kyo.grpc.Grpc], _root_.kyo.grpc.Grpc] +} + +object TestService { + + def service(serviceImpl: TestService): 
_root_.io.grpc.ServerServiceDefinition = { + _root_.io.grpc.ServerServiceDefinition.builder(_root_.kgrpc.test.TestServiceGrpc.SERVICE) + .addMethod( + _root_.kgrpc.test.TestServiceGrpc.METHOD_ONE_TO_ONE, + _root_.kyo.grpc.ServerCallHandlers.unary(serviceImpl.oneToOne) + ) + .addMethod( + _root_.kgrpc.test.TestServiceGrpc.METHOD_ONE_TO_MANY, + _root_.kyo.grpc.ServerCallHandlers.serverStreaming(serviceImpl.oneToMany) + ) + .addMethod( + _root_.kgrpc.test.TestServiceGrpc.METHOD_MANY_TO_ONE, + _root_.kyo.grpc.ServerCallHandlers.clientStreaming(serviceImpl.manyToOne) + ) + .addMethod( + _root_.kgrpc.test.TestServiceGrpc.METHOD_MANY_TO_MANY, + _root_.kyo.grpc.ServerCallHandlers.bidiStreaming(serviceImpl.manyToMany) + ) + .build() + } + + def client( + channel: _root_.io.grpc.Channel, + options: _root_.io.grpc.CallOptions = _root_.io.grpc.CallOptions.DEFAULT + ): Client = ClientImpl(channel, options) + + def managedClient( + host: String, + port: Int, + timeout: _root_.kyo.Duration = _root_.kyo.Duration.fromUnits(30, _root_.kyo.Duration.Units.Millis), + options: _root_.io.grpc.CallOptions = _root_.io.grpc.CallOptions.DEFAULT + )( + configure: _root_.io.grpc.ManagedChannelBuilder[?] 
=> _root_.io.grpc.ManagedChannelBuilder[?], + shutdown: (_root_.io.grpc.ManagedChannel, _root_.kyo.Duration) => _root_.kyo.Frame ?=> _root_.kyo.<[Any, _root_.kyo.Sync] = _root_.kyo.grpc.Client.shutdown + )(using + _root_.kyo.Frame + ): _root_.kyo.<[Client, _root_.kyo.Scope & _root_.kyo.Sync] = + _root_.kyo.grpc.Client.channel(host, port, timeout)(configure, shutdown).map(TestService.client(_, options)) + + trait Client { + def oneToOne( + request: _root_.kyo.grpc.GrpcRequestInit[kgrpc.test.Request] + ): _root_.kyo.<[kgrpc.test.Response, _root_.kyo.grpc.Grpc] + def oneToMany( + request: _root_.kyo.grpc.GrpcRequestInit[kgrpc.test.Request] + ): _root_.kyo.Stream[kgrpc.test.Response, _root_.kyo.grpc.Grpc] + def manyToOne( + requests: _root_.kyo.grpc.GrpcRequestsInit[_root_.kyo.Stream[kgrpc.test.Request, _root_.kyo.grpc.Grpc]] + ): _root_.kyo.<[kgrpc.test.Response, _root_.kyo.grpc.Grpc] + def manyToMany( + requests: _root_.kyo.grpc.GrpcRequestsInit[_root_.kyo.Stream[kgrpc.test.Request, _root_.kyo.grpc.Grpc]] + ): _root_.kyo.Stream[kgrpc.test.Response, _root_.kyo.grpc.Grpc] + } + + class ClientImpl(channel: _root_.io.grpc.Channel, options: _root_.io.grpc.CallOptions) + extends Client { + override def oneToOne( + request: _root_.kyo.grpc.GrpcRequestInit[kgrpc.test.Request] + ): _root_.kyo.<[kgrpc.test.Response, _root_.kyo.grpc.Grpc] = + _root_.kyo.grpc.ClientCall.unary(channel, _root_.kgrpc.test.TestServiceGrpc.METHOD_ONE_TO_ONE, options, request) + override def oneToMany( + request: _root_.kyo.grpc.GrpcRequestInit[kgrpc.test.Request] + ): _root_.kyo.Stream[kgrpc.test.Response, _root_.kyo.grpc.Grpc] = + _root_.kyo.grpc.ClientCall.serverStreaming(channel, _root_.kgrpc.test.TestServiceGrpc.METHOD_ONE_TO_MANY, options, request) + override def manyToOne( + requests: _root_.kyo.grpc.GrpcRequestsInit[_root_.kyo.Stream[kgrpc.test.Request, _root_.kyo.grpc.Grpc]] + ): _root_.kyo.<[kgrpc.test.Response, _root_.kyo.grpc.Grpc] = + _root_.kyo.grpc.ClientCall.clientStreaming(channel, 
_root_.kgrpc.test.TestServiceGrpc.METHOD_MANY_TO_ONE, options, requests) + override def manyToMany( + requests: _root_.kyo.grpc.GrpcRequestsInit[_root_.kyo.Stream[kgrpc.test.Request, _root_.kyo.grpc.Grpc]] + ): _root_.kyo.Stream[kgrpc.test.Response, _root_.kyo.grpc.Grpc] = + _root_.kyo.grpc.ClientCall.bidiStreaming(channel, _root_.kgrpc.test.TestServiceGrpc.METHOD_MANY_TO_MANY, options, requests) + } +} + +trait UtilityService extends _root_.kyo.grpc.Service { + + override def definition: _root_.io.grpc.ServerServiceDefinition = UtilityService.service(this) + + def health(request: kgrpc.test.Request): _root_.kyo.<[kgrpc.test.Response, _root_.kyo.grpc.Grpc] + def monitor( + request: kgrpc.test.Request + ): _root_.kyo.<[_root_.kyo.Stream[kgrpc.test.Response, _root_.kyo.grpc.Grpc], _root_.kyo.grpc.Grpc] + def batch( + requests: _root_.kyo.Stream[kgrpc.test.Request, _root_.kyo.grpc.Grpc] + ): _root_.kyo.<[kgrpc.test.Response, _root_.kyo.grpc.Grpc] +} + +object UtilityService { + + def service(serviceImpl: UtilityService): _root_.io.grpc.ServerServiceDefinition = { + _root_.io.grpc.ServerServiceDefinition.builder(_root_.kgrpc.test.UtilityServiceGrpc.SERVICE) + .addMethod( + _root_.kgrpc.test.UtilityServiceGrpc.METHOD_HEALTH, + _root_.kyo.grpc.ServerCallHandlers.unary(serviceImpl.health) + ) + .addMethod( + _root_.kgrpc.test.UtilityServiceGrpc.METHOD_MONITOR, + _root_.kyo.grpc.ServerCallHandlers.serverStreaming(serviceImpl.monitor) + ) + .addMethod( + _root_.kgrpc.test.UtilityServiceGrpc.METHOD_BATCH, + _root_.kyo.grpc.ServerCallHandlers.clientStreaming(serviceImpl.batch) + ) + .build() + } + + def client( + channel: _root_.io.grpc.Channel, + options: _root_.io.grpc.CallOptions = _root_.io.grpc.CallOptions.DEFAULT + ): Client = ClientImpl(channel, options) + + def managedClient( + host: String, + port: Int, + timeout: _root_.kyo.Duration = _root_.kyo.Duration.fromUnits(30, _root_.kyo.Duration.Units.Millis), + options: _root_.io.grpc.CallOptions = 
_root_.io.grpc.CallOptions.DEFAULT + )( + configure: _root_.io.grpc.ManagedChannelBuilder[?] => _root_.io.grpc.ManagedChannelBuilder[?], + shutdown: (_root_.io.grpc.ManagedChannel, _root_.kyo.Duration) => _root_.kyo.Frame ?=> _root_.kyo.<[Any, _root_.kyo.Sync] = _root_.kyo.grpc.Client.shutdown + )(using + _root_.kyo.Frame + ): _root_.kyo.<[Client, _root_.kyo.Scope & _root_.kyo.Sync] = + _root_.kyo.grpc.Client.channel(host, port, timeout)(configure, shutdown).map(UtilityService.client(_, options)) + + trait Client { + def health( + request: _root_.kyo.grpc.GrpcRequestInit[kgrpc.test.Request] + ): _root_.kyo.<[kgrpc.test.Response, _root_.kyo.grpc.Grpc] + def monitor( + request: _root_.kyo.grpc.GrpcRequestInit[kgrpc.test.Request] + ): _root_.kyo.Stream[kgrpc.test.Response, _root_.kyo.grpc.Grpc] + def batch( + requests: _root_.kyo.grpc.GrpcRequestsInit[_root_.kyo.Stream[kgrpc.test.Request, _root_.kyo.grpc.Grpc]] + ): _root_.kyo.<[kgrpc.test.Response, _root_.kyo.grpc.Grpc] + } + + class ClientImpl(channel: _root_.io.grpc.Channel, options: _root_.io.grpc.CallOptions) + extends Client { + override def health( + request: _root_.kyo.grpc.GrpcRequestInit[kgrpc.test.Request] + ): _root_.kyo.<[kgrpc.test.Response, _root_.kyo.grpc.Grpc] = + _root_.kyo.grpc.ClientCall.unary(channel, _root_.kgrpc.test.UtilityServiceGrpc.METHOD_HEALTH, options, request) + override def monitor( + request: _root_.kyo.grpc.GrpcRequestInit[kgrpc.test.Request] + ): _root_.kyo.Stream[kgrpc.test.Response, _root_.kyo.grpc.Grpc] = + _root_.kyo.grpc.ClientCall.serverStreaming(channel, _root_.kgrpc.test.UtilityServiceGrpc.METHOD_MONITOR, options, request) + override def batch( + requests: _root_.kyo.grpc.GrpcRequestsInit[_root_.kyo.Stream[kgrpc.test.Request, _root_.kyo.grpc.Grpc]] + ): _root_.kyo.<[kgrpc.test.Response, _root_.kyo.grpc.Grpc] = + _root_.kyo.grpc.ClientCall.clientStreaming(channel, _root_.kgrpc.test.UtilityServiceGrpc.METHOD_BATCH, options, requests) + } +} diff --git 
a/kyo-grpc-code-gen/shared/src/test/scala/kyo/grpc/compiler/CodeGeneratorTest.scala b/kyo-grpc-code-gen/shared/src/test/scala/kyo/grpc/compiler/CodeGeneratorTest.scala new file mode 100644 index 000000000..15902bb73 --- /dev/null +++ b/kyo-grpc-code-gen/shared/src/test/scala/kyo/grpc/compiler/CodeGeneratorTest.scala @@ -0,0 +1,267 @@ +package kyo.grpc.compiler + +import com.google.protobuf.DescriptorProtos.* +import com.google.protobuf.compiler.PluginProtos +import com.google.protobuf.compiler.PluginProtos.CodeGeneratorRequest +import org.scalatest.freespec.AnyFreeSpec +import protocgen.CodeGenRequest +import scalapb.options.Scalapb +import scalapb.options.Scalapb.ScalaPbOptions + +class CodeGeneratorTest extends AnyFreeSpec { + + "CodeGenerator" - { + "process a request with default options" in { + val protoFile = createTestProtoDescriptor() + + val response = testProcess(protoFile) + + assert(response.getError.isEmpty) + assert(response.getUnknownFields.asMap().isEmpty) + assert(response.getFileCount === 2) + + val file1 = response.getFile(0) + assert(file1.getUnknownFields.asMap().isEmpty) + assert(file1.getName === "kgrpc/test/TestService.scala") + java.nio.file.Files.writeString(java.nio.file.Path.of("/tmp/gen-multi-1.txt"), file1.getContent) + val expected1 = scala.io.Source.fromResource("output/multiple-files-1").mkString + assert(file1.getContent === expected1) + + val file2 = response.getFile(1) + assert(file2.getUnknownFields.asMap().isEmpty) + assert(file2.getName === "kgrpc/test/UtilityService.scala") + java.nio.file.Files.writeString(java.nio.file.Path.of("/tmp/gen-multi-2.txt"), file2.getContent) + val expected2 = scala.io.Source.fromResource("output/multiple-files-2").mkString + assert(file2.getContent === expected2) + } + "process a request with single file" in { + val options = + ScalaPbOptions + .newBuilder() + .setSingleFile(true) + .build() + + val protoFile = createTestProtoDescriptor(options) + + val response = testProcess(protoFile) + + 
assert(response.getError.isEmpty) + assert(response.getUnknownFields.asMap().isEmpty) + assert(response.getFileCount === 1) + + val file = response.getFile(0) + assert(file.getUnknownFields.asMap().isEmpty) + assert(file.getName === "kgrpc/test/TestProto.scala") + java.nio.file.Files.writeString(java.nio.file.Path.of("/tmp/gen-single.txt"), file.getContent) + val expected = scala.io.Source.fromResource("output/single-file").mkString + assert(file.getContent === expected) + } + } + + private def testProcess(protoFile: FileDescriptorProto) = { + val version = PluginProtos.Version.newBuilder().setMajor(3).setMinor(21).setPatch(7).build() + + val request = CodeGenRequest( + CodeGeneratorRequest + .newBuilder() + .setCompilerVersion(version) + .setParameter("grpc") + .addFileToGenerate("test.proto") + .addProtoFile(protoFile) + .build() + ) + + CodeGenerator.process(request).toCodeGeneratorResponse + } + + private def createTestProtoDescriptor(options: ScalaPbOptions = ScalaPbOptions.getDefaultInstance): FileDescriptorProto = { + // Create oneof for Request + val requestOneof = OneofDescriptorProto.newBuilder() + .setName("sealed_value") + .build() + + val requestMessage = DescriptorProto.newBuilder() + .setName("Request") + .addField(FieldDescriptorProto.newBuilder() + .setName("success") + .setNumber(1) + .setType(FieldDescriptorProto.Type.TYPE_MESSAGE) + .setTypeName(".kgrpc.Success") + .setLabel(FieldDescriptorProto.Label.LABEL_OPTIONAL) + .setOneofIndex(0) + .setJsonName("success")) + .addField(FieldDescriptorProto.newBuilder() + .setName("fail") + .setNumber(2) + .setType(FieldDescriptorProto.Type.TYPE_MESSAGE) + .setTypeName(".kgrpc.Fail") + .setLabel(FieldDescriptorProto.Label.LABEL_OPTIONAL) + .setOneofIndex(0) + .setJsonName("fail")) + .addField(FieldDescriptorProto.newBuilder() + .setName("panic") + .setNumber(3) + .setType(FieldDescriptorProto.Type.TYPE_MESSAGE) + .setTypeName(".kgrpc.Panic") + .setLabel(FieldDescriptorProto.Label.LABEL_OPTIONAL) + 
.setOneofIndex(0) + .setJsonName("panic")) + .addOneofDecl(requestOneof) + .build() + + // Create oneof for Response + val responseOneof = OneofDescriptorProto.newBuilder() + .setName("sealed_value") + .build() + + val responseMessage = DescriptorProto.newBuilder() + .setName("Response") + .addField(FieldDescriptorProto.newBuilder() + .setName("echo") + .setNumber(1) + .setType(FieldDescriptorProto.Type.TYPE_MESSAGE) + .setTypeName(".kgrpc.Echo") + .setLabel(FieldDescriptorProto.Label.LABEL_OPTIONAL) + .setOneofIndex(0) + .setJsonName("echo")) + .addOneofDecl(responseOneof) + .build() + + val successMessage = DescriptorProto.newBuilder() + .setName("Success") + .addField(FieldDescriptorProto.newBuilder() + .setName("message") + .setNumber(1) + .setType(FieldDescriptorProto.Type.TYPE_STRING) + .setLabel(FieldDescriptorProto.Label.LABEL_OPTIONAL) + .setJsonName("message")) + .addField(FieldDescriptorProto.newBuilder() + .setName("count") + .setNumber(2) + .setType(FieldDescriptorProto.Type.TYPE_INT32) + .setLabel(FieldDescriptorProto.Label.LABEL_OPTIONAL) + .setJsonName("count")) + .build() + + val failMessage = DescriptorProto.newBuilder() + .setName("Fail") + .addField(FieldDescriptorProto.newBuilder() + .setName("code") + .setNumber(1) + .setType(FieldDescriptorProto.Type.TYPE_INT32) + .setLabel(FieldDescriptorProto.Label.LABEL_OPTIONAL) + .setJsonName("code")) + .addField(FieldDescriptorProto.newBuilder() + .setName("after") + .setNumber(2) + .setType(FieldDescriptorProto.Type.TYPE_INT32) + .setLabel(FieldDescriptorProto.Label.LABEL_OPTIONAL) + .setJsonName("after")) + .addField(FieldDescriptorProto.newBuilder() + .setName("outside") + .setNumber(3) + .setType(FieldDescriptorProto.Type.TYPE_BOOL) + .setLabel(FieldDescriptorProto.Label.LABEL_OPTIONAL) + .setJsonName("outside")) + .build() + + val panicMessage = DescriptorProto.newBuilder() + .setName("Panic") + .addField(FieldDescriptorProto.newBuilder() + .setName("message") + .setNumber(1) + 
.setType(FieldDescriptorProto.Type.TYPE_STRING) + .setLabel(FieldDescriptorProto.Label.LABEL_OPTIONAL) + .setJsonName("message")) + .addField(FieldDescriptorProto.newBuilder() + .setName("after") + .setNumber(2) + .setType(FieldDescriptorProto.Type.TYPE_INT32) + .setLabel(FieldDescriptorProto.Label.LABEL_OPTIONAL) + .setJsonName("after")) + .addField(FieldDescriptorProto.newBuilder() + .setName("outside") + .setNumber(3) + .setType(FieldDescriptorProto.Type.TYPE_BOOL) + .setLabel(FieldDescriptorProto.Label.LABEL_OPTIONAL) + .setJsonName("outside")) + .build() + + val echoMessage = DescriptorProto.newBuilder() + .setName("Echo") + .addField(FieldDescriptorProto.newBuilder() + .setName("message") + .setNumber(1) + .setType(FieldDescriptorProto.Type.TYPE_STRING) + .setLabel(FieldDescriptorProto.Label.LABEL_OPTIONAL) + .setJsonName("message")) + .build() + + // Create services + val testService = ServiceDescriptorProto.newBuilder() + .setName("TestService") + .addMethod(MethodDescriptorProto.newBuilder() + .setName("OneToOne") + .setInputType(".kgrpc.Request") + .setOutputType(".kgrpc.Response") + .setClientStreaming(false) + .setServerStreaming(false)) + .addMethod(MethodDescriptorProto.newBuilder() + .setName("OneToMany") + .setInputType(".kgrpc.Request") + .setOutputType(".kgrpc.Response") + .setClientStreaming(false) + .setServerStreaming(true)) + .addMethod(MethodDescriptorProto.newBuilder() + .setName("ManyToOne") + .setInputType(".kgrpc.Request") + .setOutputType(".kgrpc.Response") + .setClientStreaming(true) + .setServerStreaming(false)) + .addMethod(MethodDescriptorProto.newBuilder() + .setName("ManyToMany") + .setInputType(".kgrpc.Request") + .setOutputType(".kgrpc.Response") + .setClientStreaming(true) + .setServerStreaming(true)) + .build() + + val utilityService = ServiceDescriptorProto.newBuilder() + .setName("UtilityService") + .addMethod(MethodDescriptorProto.newBuilder() + .setName("Health") + .setInputType(".kgrpc.Request") + 
.setOutputType(".kgrpc.Response") + .setClientStreaming(false) + .setServerStreaming(false)) + .addMethod(MethodDescriptorProto.newBuilder() + .setName("Monitor") + .setInputType(".kgrpc.Request") + .setOutputType(".kgrpc.Response") + .setClientStreaming(false) + .setServerStreaming(true)) + .addMethod(MethodDescriptorProto.newBuilder() + .setName("Batch") + .setInputType(".kgrpc.Request") + .setOutputType(".kgrpc.Response") + .setClientStreaming(true) + .setServerStreaming(false)) + .build() + + // Create the file descriptor + FileDescriptorProto.newBuilder() + .setName("test.proto") + .setPackage("kgrpc") + .setSyntax("proto3") + .setOptions(FileOptions.newBuilder().setExtension(Scalapb.options, options)) + .addMessageType(requestMessage) + .addMessageType(responseMessage) + .addMessageType(successMessage) + .addMessageType(failMessage) + .addMessageType(panicMessage) + .addMessageType(echoMessage) + .addService(testService) + .addService(utilityService) + .build() + } +} diff --git a/kyo-grpc-code-gen/shared/src/test/scala/kyo/grpc/compiler/internal/ScalaFunctionalPrinterOpsTest.scala b/kyo-grpc-code-gen/shared/src/test/scala/kyo/grpc/compiler/internal/ScalaFunctionalPrinterOpsTest.scala new file mode 100644 index 000000000..9ae182b4c --- /dev/null +++ b/kyo-grpc-code-gen/shared/src/test/scala/kyo/grpc/compiler/internal/ScalaFunctionalPrinterOpsTest.scala @@ -0,0 +1,366 @@ +package kyo.grpc.compiler.internal + +import org.scalatest.freespec.AnyFreeSpec +import org.typelevel.paiges.Doc +import scalapb.compiler.FunctionalPrinter + +class ScalaFunctionalPrinterOpsTest extends AnyFreeSpec { + + "ScalaFunctionalPrinterOps" - { + "addDoc" - { + "should add multiline docs correctly" in { + val fp = new FunctionalPrinter() + val doc = Doc.text("a") + Doc.hardLine + Doc.text("b") + val actual = fp.addDoc(doc).result() + val expected = + """a + |b""".stripMargin + assert(actual == expected) + } + } + "addMethod" - { + "should add an abstract method without mods" in { + 
val fp = new FunctionalPrinter() + val actual = fp.addMethod("foo").result() + val expected = + """def foo""".stripMargin + assert(actual == expected) + } + "should add an abstract method with a mod" in { + val fp = new FunctionalPrinter() + val actual = fp.addMethod("foo").addMods(_.Override).result() + val expected = + """override def foo""".stripMargin + assert(actual == expected) + } + "should add an abstract method with multiple mods" in { + val fp = new FunctionalPrinter() + val actual = fp.addMethod("foo").addMods(_.Private, _.Override).result() + val expected = + """private override def foo""".stripMargin + assert(actual == expected) + } + "should add an abstract method with a type parameter" in { + val fp = new FunctionalPrinter() + val actual = fp.addMethod("foo").addTypeParameters("A").result() + val expected = + """def foo[A]""".stripMargin + assert(actual == expected) + } + "should add an abstract method with multiple type parameters" in { + val fp = new FunctionalPrinter() + val actual = fp.addMethod("foo").addTypeParameters("A", "B").result() + val expected = + """def foo[A, B]""".stripMargin + assert(actual == expected) + } + "should add an abstract method with multiple long type parameters" in { + val fp = new FunctionalPrinter() + val as = "A" * (WIDTH / 2) + val bs = "B" * (WIDTH / 2) + val actual = fp.addMethod("foo").addTypeParameters(as, bs).result() + val expected = + s"""def foo[$as, $bs]""".stripMargin + assert(actual == expected) + } + "should add an abstract method with a parameter" in { + val fp = new FunctionalPrinter() + val actual = fp.addMethod("foo").addParameterList("a" :- "A").result() + val expected = + """def foo(a: A)""".stripMargin + assert(actual == expected) + } + "should add an abstract method with multiple parameters on one line" in { + val fp = new FunctionalPrinter() + val actual = fp.addMethod("foo").addParameterList("a" :- "A", "b" :- "B").result() + val expected = + """def foo(a: A, b: B)""".stripMargin + 
assert(actual == expected) + } + "should add an abstract method with multiple parameters on multiple lines" in { + val fp = new FunctionalPrinter() + val as = "a" * (WIDTH / 2) + val bs = "b" * (WIDTH / 2) + val actual = fp.addMethod("foo").addParameterList(as :- "A", bs :- "B").result() + val expected = + s"""def foo( + | $as: A, + | $bs: B + |)""".stripMargin + assert(actual == expected) + } + "should add an abstract method with multiple parameter lists on one line" in { + val fp = new FunctionalPrinter() + val actual = fp.addMethod("foo").addParameterList("a" :- "A", "b" :- "B").addParameterList("c" :- "C").result() + val expected = + """def foo(a: A, b: B)(c: C)""".stripMargin + assert(actual == expected) + } + "should add an abstract method with multiple parameter lists on mixed lines" in { + val fp = new FunctionalPrinter() + val as = "a" * (WIDTH / 2) + val bs = "b" * (WIDTH / 2) + val cs = "c" * (WIDTH / 2) + val actual = fp.addMethod("foo").addParameterList(as :- "A").addParameterList(bs :- "B", cs :- "C").result() + val expected = + s"""def foo( + | $as: A + |)( + | $bs: B, + | $cs: C + |)""".stripMargin + assert(actual == expected) + } + "should add an abstract method with multiple parameter lists on mixed lines 2" in { + val fp = new FunctionalPrinter() + val as = "a" * (WIDTH / 2) + val bs = "b" * (WIDTH / 2) + val cs = "c" * (WIDTH / 2) + val actual = fp.addMethod("foo").addParameterList(as :- "A", bs :- "B").addParameterList(cs :- "C").result() + val expected = + s"""def foo( + | $as: A, + | $bs: B + |)( + | $cs: C + |)""".stripMargin + assert(actual == expected) + } + "should add an abstract method with multiple parameter lists on multiple lines" in { + val fp = new FunctionalPrinter() + val as = "a" * (WIDTH / 2) + val bs = "b" * (WIDTH / 2) + val cs = "c" * (WIDTH / 2) + val ds = "d" * (WIDTH / 2) + val actual = fp.addMethod("foo").addParameterList(as :- "A", bs :- "B").addParameterList(cs :- "C", ds :- "D").result() + val expected = + s"""def 
foo( + | $as: A, + | $bs: B + |)( + | $cs: C, + | $ds: D + |)""".stripMargin + assert(actual == expected) + } + "should add an abstract method with an implicit parameter" in { + val fp = new FunctionalPrinter() + val actual = fp.addMethod("foo").addImplicitParameters("a" :- "A").result() + val expected = + """def foo(implicit a: A)""".stripMargin + assert(actual == expected) + } + "should add an abstract method with multiple implicit parameters on one line" in { + val fp = new FunctionalPrinter() + val actual = fp.addMethod("foo").addImplicitParameters("a" :- "A", "b" :- "B").result() + val expected = + """def foo(implicit a: A, b: B)""".stripMargin + assert(actual == expected) + } + "should add an abstract method with multiple implicit parameters on multiple lines" in { + val fp = new FunctionalPrinter() + val as = "a" * (WIDTH / 2) + val bs = "b" * (WIDTH / 2) + val actual = fp.addMethod("foo").addImplicitParameters(as :- "A", bs :- "B").result() + val expected = + s"""def foo(implicit + | $as: A, + | $bs: B + |)""".stripMargin + assert(actual == expected) + } + "should add an abstract method with normal parameters and implicit parameters on one line" in { + val fp = new FunctionalPrinter() + val actual = fp.addMethod("foo").addParameterList("a" :- "A", "b" :- "B").addImplicitParameters("c" :- "C").result() + val expected = + """def foo(a: A, b: B)(implicit c: C)""".stripMargin + assert(actual == expected) + } + "should add an abstract method with normal parameters and implicit parameters on mixed lines" in { + val fp = new FunctionalPrinter() + val as = "a" * (WIDTH / 2) + val bs = "b" * (WIDTH / 2) + val cs = "c" * (WIDTH / 2) + val actual = fp.addMethod("foo").addParameterList(as :- "A").addImplicitParameters(bs :- "B", cs :- "C").result() + val expected = + s"""def foo( + | $as: A + |)(implicit + | $bs: B, + | $cs: C + |)""".stripMargin + assert(actual == expected) + } + "should add an abstract method with normal parameters and implicit parameters on mixed 
lines 2" in { + val fp = new FunctionalPrinter() + val as = "a" * (WIDTH / 2) + val bs = "b" * (WIDTH / 2) + val cs = "c" * (WIDTH / 2) + val actual = fp.addMethod("foo").addParameterList(as :- "A", bs :- "B").addImplicitParameters(cs :- "C").result() + val expected = + s"""def foo( + | $as: A, + | $bs: B + |)(implicit + | $cs: C + |)""".stripMargin + assert(actual == expected) + } + "should add an abstract method with normal parameters and implicit parameters on multiple lines" in { + val fp = new FunctionalPrinter() + val as = "a" * (WIDTH / 2) + val bs = "b" * (WIDTH / 2) + val cs = "c" * (WIDTH / 2) + val ds = "d" * (WIDTH / 2) + val actual = fp.addMethod("foo").addParameterList(as :- "A", bs :- "B").addImplicitParameters(cs :- "C", ds :- "D").result() + val expected = + s"""def foo( + | $as: A, + | $bs: B + |)(implicit + | $cs: C, + | $ds: D + |)""".stripMargin + assert(actual == expected) + } + "should add an abstract method with a return type" in { + val fp = new FunctionalPrinter() + val actual = fp.addMethod("foo").addReturnType("Foo").result() + val expected = + """def foo: Foo""".stripMargin + assert(actual == expected) + } + "should add a method with a simple body" in { + val fp = new FunctionalPrinter() + val actual = fp.addMethod("foo").addBody(_.add("Foo")).result() + val expected = + """def foo = Foo""".stripMargin + assert(actual == expected) + } + "should add a method with a long body" in { + val fp = new FunctionalPrinter() + val as = "a" * (WIDTH + 1) + val actual = fp.addMethod("foo").addBody(_.add(as)).result() + val expected = + s"""def foo = + | $as""".stripMargin + assert(actual == expected) + } + "should add a method with a multiline signature an a long body" in { + val fp = new FunctionalPrinter() + + val as = "a" * (WIDTH + 1) + val bs = "b" * (WIDTH + 1) + val actual = fp.addMethod("foo").addParameterList(as :- "A").addBody(_.add(bs)).result() + val expected = + s"""def foo( + | $as: A + |) = + | $bs""".stripMargin + assert(actual == 
expected) + } + "should add a method with a multiline body" in { + val fp = new FunctionalPrinter() + val body = + """a + |b""".stripMargin + val actual = fp.addMethod("foo").addBody(_.add(body)).result() + val expected = + """def foo = { + | a + | b + |}""".stripMargin + assert(actual == expected) + } + } + "addObject" - { + "should add an object without mods" in { + val fp = new FunctionalPrinter() + val actual = fp.addObject("Foo").result() + val expected = + """object Foo""".stripMargin + assert(actual == expected) + } + "should add an object with a mod" in { + val fp = new FunctionalPrinter() + val actual = fp.addObject("Foo").addMods(_.Case).result() + val expected = + """case object Foo""".stripMargin + assert(actual == expected) + } + "should add an object with multiple mods" in { + val fp = new FunctionalPrinter() + val actual = fp.addObject("Foo").addMods(_.Private, _.Case).result() + val expected = + """private case object Foo""".stripMargin + assert(actual == expected) + } + "should add an object with a parent" in { + val fp = new FunctionalPrinter() + val actual = fp.addObject("Foo").addParents("A").result() + val expected = + """object Foo extends A""".stripMargin + assert(actual == expected) + } + "should add an object with a long parent" in { + val fp = new FunctionalPrinter() + val as = "A" * WIDTH + val actual = fp.addObject("Foo").addParents(as).result() + val expected = + s"""object Foo + | extends $as""".stripMargin + assert(actual == expected) + } + "should add an object with multiple parents" in { + val fp = new FunctionalPrinter() + val actual = fp.addObject("Foo").addParents("A", "B").result() + val expected = + """object Foo extends A with B""".stripMargin + assert(actual == expected) + } + "should add an object with multiple long parents" in { + val fp = new FunctionalPrinter() + val as = "A" * (WIDTH / 2) + val bs = "B" * (WIDTH / 2) + val actual = fp.addObject("Foo").addParents(as, bs).result() + val expected = + s"""object Foo + | 
extends $as + | with $bs""".stripMargin + assert(actual == expected) + } + "should add an object with a simple body" in { + val fp = new FunctionalPrinter() + val actual = fp.addObject("Foo").addBody(_.add("Foo")).result() + val expected = + """object Foo { + | Foo + |}""".stripMargin + assert(actual == expected) + } + "should add an object with a long body" in { + val fp = new FunctionalPrinter() + val as = "a" * (WIDTH + 1) + val actual = fp.addObject("Foo").addBody(_.add(as)).result() + val expected = + s"""object Foo { + | $as + |}""".stripMargin + assert(actual == expected) + } + "should add an object with a multiline body" in { + val fp = new FunctionalPrinter() + val body = + """a + |b""".stripMargin + val actual = fp.addObject("Foo").addBody(_.add(body)).result() + val expected = + """object Foo { + | a + | b + |}""".stripMargin + assert(actual == expected) + } + } + } +} diff --git a/kyo-grpc-code-gen/shared/src/test/scala/org/typelevel/paiges/internal/DocxTest.scala b/kyo-grpc-code-gen/shared/src/test/scala/org/typelevel/paiges/internal/DocxTest.scala new file mode 100644 index 000000000..858a2a69c --- /dev/null +++ b/kyo-grpc-code-gen/shared/src/test/scala/org/typelevel/paiges/internal/DocxTest.scala @@ -0,0 +1,43 @@ +package org.typelevel.paiges.internal + +import org.scalatest.freespec.AnyFreeSpec +import org.typelevel.paiges.* + +class DocxTest extends AnyFreeSpec { + + "Docx" - { + "bracketIfMultiline" - { + "should add separators when is too long" in { + val sig = Doc.text("def foo = ") + val body = Doc.text("howlongcanyougo") + val doc = sig + Docx.bracketIfMultiline(Doc.char('{'), body, Doc.char('}')) + val actual = doc.render(24) + val expected = + """def foo = { + | howlongcanyougo + |}""".stripMargin + assert(actual == expected) + } + "should not add separators when the document fits" in { + val sig = Doc.text("def foo = ") + val body = Doc.text("howlongcanyougo") + val doc = sig + Docx.bracketIfMultiline(Doc.char('{'), body, Doc.char('}')) + 
val actual = doc.render(25) + val expected = """def foo = howlongcanyougo""".stripMargin + assert(actual == expected) + } + "should add separators when the document is multiline" in { + val sig = Doc.text("def foo = ") + val body = Doc.text("a") + Doc.hardLine + Doc.text("b") + val doc = sig + Docx.bracketIfMultiline(Doc.char('{'), body, Doc.char('}')) + val actual = doc.render(25) + val expected = + """def foo = { + | a + | b + |}""".stripMargin + assert(actual == expected) + } + } + } +} diff --git a/kyo-grpc-code-gen/shared/src/test/scala/org/typelevel/paiges/internal/ExtendedSyntaxTest.scala b/kyo-grpc-code-gen/shared/src/test/scala/org/typelevel/paiges/internal/ExtendedSyntaxTest.scala new file mode 100644 index 000000000..73e906799 --- /dev/null +++ b/kyo-grpc-code-gen/shared/src/test/scala/org/typelevel/paiges/internal/ExtendedSyntaxTest.scala @@ -0,0 +1,122 @@ +package org.typelevel.paiges.internal + +import org.scalatest.freespec.AnyFreeSpec +import org.typelevel.paiges.* +import org.typelevel.paiges.internal.ExtendedSyntax.* + +class ExtendedSyntaxTest extends AnyFreeSpec { + + "ExtendedSyntax" - { + "ungrouped" - { + "should collapse none when fits" in { + val doc = ((Doc.text("a") + Doc.line + Doc.text("b")).grouped + Doc.line + Doc.text("c")).ungrouped + val actual = doc.render(5) + val expected = + """|a + |b + |c""".stripMargin + assert(actual == expected) + } + "should collapse none when too long" in { + val doc = ((Doc.text("a") + Doc.line + Doc.text("b")).grouped + Doc.line + Doc.text("c")).ungrouped + val actual = doc.render(3) + val expected = + """|a + |b + |c""".stripMargin + assert(actual == expected) + } + } + "regrouped" - { + "should collapse all when fits" in { + val doc = ((Doc.text("a") + Doc.line + Doc.text("b")).grouped + Doc.line + Doc.text("c")).regrouped + val actual = doc.render(5) + val expected = + """|a b c""".stripMargin + assert(actual == expected) + } + "should collapse none when too long" in { + val doc = ((Doc.text("a") + 
Doc.line + Doc.text("b")).grouped + Doc.line + Doc.text("c")).regrouped + val actual = doc.render(3) + val expected = + """|a + |b + |c""".stripMargin + assert(actual == expected) + } + } + "hangingUnsafe" - { + "should not indent when short" in { + val prefix = Doc.text("foo") + val doc = prefix + Doc.text("bar").hangingUnsafe(2) + val actual = doc.grouped.render(10) + val expected = + """foo bar""".stripMargin + assert(actual == expected) + } + "should indent when long" in { + val prefix = Doc.text("foo") + val doc = prefix + Doc.text("bar").hangingUnsafe(2) + val actual = doc.grouped.render(5) + val expected = + """foo + | bar""".stripMargin + assert(actual == expected) + } + "should indent when long from multiline" in { + val prefix = Doc.text("foo") + Doc.line + Doc.text("baz") + val doc = prefix + Doc.text("bar").hangingUnsafe(2) + val actual = doc.grouped.render(5) + val expected = + """foo + |baz + | bar""".stripMargin + assert(actual == expected) + } + // These tests convince us why we can't use hang. + "hang" - { + "should not indent when short" in { + val prefix = Doc.text("foo") + Doc.line + val doc = (prefix + Doc.text("bar")).hang(2) + val actual = doc.grouped.render(10) + val expected = + """foo bar""".stripMargin + assert(actual == expected) + } + "should indent when long" in { + val prefix = Doc.text("foo") + Doc.line + val doc = (prefix + Doc.text("bar")).hang(2) + val actual = doc.grouped.render(5) + val expected = + """foo + | bar""".stripMargin + assert(actual == expected) + } + "should indent when long from multiline" in { + // This is why hang doesn't work. It requires that you include the previous new line. 
+ val prefix = Doc.text("foo") + Doc.line + Doc.text("baz") + val doc = (prefix + Doc.line + Doc.text("bar")).hang(2) + val actual = doc.grouped.render(5) + val expected = + """foo + | baz + | bar""".stripMargin + assert(actual == expected) + } + "should indent when long from multiline 2" in { + // If you only include the previous line it doesn't work either. + // It hangs relative to the current position which is whatever was before the hanged document. + // hanging has to include its own new line in order to ensure that the position is reset. + val prefix = Doc.text("foo") + Doc.line + Doc.text("baz") + val doc = prefix + (Doc.line + Doc.text("bar")).hang(2) + val actual = doc.grouped.render(5) + val expected = + """foo + |baz + | bar""".stripMargin + assert(actual == expected) + } + } + } + } +} diff --git a/kyo-grpc-core/README.md b/kyo-grpc-core/README.md new file mode 100644 index 000000000..790a318f1 --- /dev/null +++ b/kyo-grpc-core/README.md @@ -0,0 +1,43 @@ +# kyo-grpc + +A Protoc plugin that generates...
+ +# Using the plugin + + + +To add the plugin to another project: + +``` +addSbtPlugin("com.thesamet" % "sbt-protoc" % "1.0.6") + +libraryDependencies += "com.example" %% "kyo-grpc-codegen" % "0.1.0" +``` + +and the following to your `build.sbt`: +``` +PB.targets in Compile := Seq( + scalapb.gen() -> (sourceManaged in Compile).value / "scalapb", + kyo.grpc.gen() -> (sourceManaged in Compile).value / "scalapb" +) +``` + +# Development and testing + +Code structure: +- [`core`](./core/): The runtime library for this plugin +- [`code-gen`](./code-gen): The protoc plugin (code generator) +- [`e2e`](./e2e): Integration tests for the plugin +- [`examples`](./examples): Example projects + +To test the plugin, within SBT: + +``` +> e2eJVM2_13/test +``` + +or + +``` +> e2eJVM2_12/test +``` diff --git a/kyo-grpc-core/shared/src/main/scala/kyo/grpc/CallClosed.scala b/kyo-grpc-core/shared/src/main/scala/kyo/grpc/CallClosed.scala new file mode 100644 index 000000000..19336a135 --- /dev/null +++ b/kyo-grpc-core/shared/src/main/scala/kyo/grpc/CallClosed.scala @@ -0,0 +1,7 @@ +package kyo.grpc + +import io.grpc.Status +import io.grpc.StatusException + +final case class CallClosed(status: Status, trailers: SafeMetadata): + def asException: StatusException = status.asException(trailers.toJava) diff --git a/kyo-grpc-core/shared/src/main/scala/kyo/grpc/Client.scala b/kyo-grpc-core/shared/src/main/scala/kyo/grpc/Client.scala new file mode 100644 index 000000000..4449e616b --- /dev/null +++ b/kyo-grpc-core/shared/src/main/scala/kyo/grpc/Client.scala @@ -0,0 +1,99 @@ +package kyo.grpc + +import io.grpc.* +import java.util.concurrent.TimeUnit +import kyo.* + +/** Utilities for creating and managing gRPC client channels. + * + * This object provides functionality for creating managed gRPC channels with automatic resource cleanup and proper shutdown handling. 
+ * + * Consider using the `managedClient` method on service companion objects for a more concise way to create both the channel and the client + * in one step. + * + * Key features: + * - Automatic resource cleanup via [[Scope]] effect + * - Graceful shutdown with fallback to forced shutdown + * - Configurable channel settings through builder pattern + * - Integration with generated gRPC service clients + */ +object Client: + + /** Attempts an orderly shut down of the [[ManagedChannel]] within a timeout. + * + * First attempts graceful shutdown by calling [[ManagedChannel.shutdown]] and waits up to `timeout` for termination. If the channel + * doesn't terminate within the timeout, forces shutdown with [[ManagedChannel.shutdownNow]] and then waits up to 1 minute for it to + * terminate (there is no indefinite wait). + * + * @param channel + * The channel to shut down + * @param timeout + * The maximum duration to wait for graceful termination (default: 30 seconds) + */ + def shutdown(channel: ManagedChannel, timeout: Duration = 30.seconds)(using Frame): Unit < Sync = + Sync.defer: + val terminated = + channel + .shutdown() + .awaitTermination(timeout.toNanos, TimeUnit.NANOSECONDS) + if terminated then () else discard(channel.shutdownNow().awaitTermination(1, TimeUnit.MINUTES)) + + /** Creates a managed gRPC channel with automatic resource cleanup. + * + * Creates a [[ManagedChannel]] that is automatically acquired and released via the [[Scope]] effect. The channel is configured using + * the provided `configure` function and will be properly shut down when the resource is released. + * + * The returned channel can be passed to the `client` method on service companion objects generated by kyo-grpc-code-gen to create + * typed gRPC clients. Alternatively, you can use the `managedClient` method on service companion objects which combines channel + * creation and client creation in a single step.
+ * + * @example + * {{{ + * // Create a simple insecure channel + * val simpleChannel = Client.channel("localhost", 9090)(_.usePlaintext()) + * + * // Create a channel with custom configuration + * val customChannel = Client.channel("api.example.com", 443)( + * _.useTransportSecurity() + * .keepAliveTime(30, java.util.concurrent.TimeUnit.SECONDS) + * .maxInboundMessageSize(4 * 1024 * 1024) + * ) + * + * // Use with generated service client + * for + * channel <- Client.channel("localhost", 9090)(_.usePlaintext()) + * client <- GreeterService.client(channel) + * response <- client.sayHello(HelloRequest("World")) + * yield response + * + * // Or use managedClient for a more concise approach + * for + * client <- GreeterService.managedClient("localhost", 9090)(_.usePlaintext()) + * response <- client.sayHello(HelloRequest("World")) + * yield response + * }}} + * + * @param host + * The target server hostname or IP address to connect to + * @param port + * The target server port number + * @param timeout + * The maximum duration to wait for graceful channel shutdown when the resource is released (default: 30 seconds) + * @param configure + * A function to configure the [[ManagedChannelBuilder]] before building the channel. This allows customization of channel settings + * like TLS, keepalive, retry policies, etc. + * @param shutdown + * The shutdown function to use when releasing the channel resource. Defaults to [[shutdown]] method which performs graceful shutdown + * with fallback to forced shutdown + * @return + * A `ManagedChannel` pending `Scope` and `Sync` effects + */ + def channel(host: String, port: Int, timeout: Duration = 30.seconds)( + configure: ManagedChannelBuilder[?] 
=> ManagedChannelBuilder[?], + shutdown: (ManagedChannel, Duration) => Frame ?=> Any < Sync = shutdown + )(using Frame): ManagedChannel < (Scope & Sync) = + Scope.acquireRelease( + Sync.defer(configure(ManagedChannelBuilder.forAddress(host, port)).build()) + )(shutdown(_, timeout)) + +end Client diff --git a/kyo-grpc-core/shared/src/main/scala/kyo/grpc/ClientCall.scala b/kyo-grpc-core/shared/src/main/scala/kyo/grpc/ClientCall.scala new file mode 100644 index 000000000..8bcd99b3a --- /dev/null +++ b/kyo-grpc-core/shared/src/main/scala/kyo/grpc/ClientCall.scala @@ -0,0 +1,464 @@ +package kyo.grpc + +import io.grpc.CallOptions +import io.grpc.ClientCall +import io.grpc.MethodDescriptor +import io.grpc.StatusException +import kyo.* +import kyo.grpc.* +import kyo.grpc.internal.ServerStreamingClientCallListener +import kyo.grpc.internal.UnaryClientCallListener + +// TODO: Name these. + +type GrpcRequestCompletion = Unit < (Env[CallClosed] & Async) + +private[grpc] type GrpcResponsesAwaitingCompletion[MaybeResponses] = MaybeResponses < (Emit[GrpcRequestCompletion] & Async) + +// Unary and server streaming method calls do not flush the request headers so they are only sent when the request is +// sent. That means that they cannot be used in the creation of the request and can only be provided to the application +// when the response is received. +// TODO: Singular vs plural is confusing here. +type GrpcRequest[Requests] = Requests < (Emit[GrpcRequestCompletion] & Async) + +type GrpcRequestsPendingHeaders[Requests] = Requests < (Env[SafeMetadata] & Emit[GrpcRequestCompletion] & Async) + +type GrpcRequestsInit[Requests] = GrpcRequestsPendingHeaders[Requests] < (Emit[RequestOptions] & Async) + +type GrpcRequestInit[Request] = GrpcRequest[Request] < (Emit[RequestOptions] & Async) + +/** Provides client-side gRPC call implementations for different RPC patterns. 
+ * + * This object contains methods for executing unary, client streaming, server streaming, and bidirectional streaming gRPC calls using Kyo's + * effect system. + */ +object ClientCall: + + // TODO: + // - unary and serverStreaming need to provide the headers with the response. + // - all methods need to provide trailers. + + /** Executes a unary gRPC call. + * + * A unary call sends a single request and receives a single response. + * + * @param channel + * the gRPC channel to use for the call + * @param method + * the method descriptor defining the RPC method + * @param options + * call options for configuring the request + * @param requestInit + * the request message to send + * @tparam Request + * the type of the request message + * @tparam Response + * the type of the response message + * @return + * the response message pending [[Grpc]] + */ + def unary[Request, Response]( + channel: io.grpc.Channel, + method: MethodDescriptor[Request, Response], + options: CallOptions, + requestInit: GrpcRequestInit[Request] + )(using Frame): Response < Grpc = + def start(call: ClientCall[Request, Response], options: RequestOptions): UnaryClientCallListener[Response] < Sync = + for + headersPromise <- Promise.init[SafeMetadata, Any] + responsePromise <- Promise.init[Response, Abort[StatusException]] + completionPromise <- Promise.init[CallClosed, Any] + readySignal <- Signal.initRef[Boolean](false) + listener = UnaryClientCallListener(headersPromise, responsePromise, completionPromise, readySignal) + _ <- Sync.defer(call.start(listener, options.headers.toJava)) + _ <- Sync.defer(options.messageCompression.foreach(call.setMessageCompression)) + _ <- Sync.defer(call.request(1)) + yield listener + end start + + def sendAndReceive( + call: ClientCall[Request, Response], + listener: UnaryClientCallListener[Response], + requestEffect: GrpcRequest[Request] + ): GrpcResponsesAwaitingCompletion[Result[GrpcFailure, Response]] = + for + // We ignore the ready signal here as we want 
the request ready as soon as possible, + // and we will only buffer at most one request. + // The request was already made when starting. + request <- requestEffect + _ <- Sync.defer(call.sendMessage(request)) + _ <- Sync.defer(call.halfClose()) + result <- listener.responsePromise.getResult + // TODO: Where is the emit of the effect that waits for completion? + yield result + end sendAndReceive + + def processCompletion(listener: UnaryClientCallListener[Response])(completionEffect: GrpcResponsesAwaitingCompletion[Result[ + GrpcFailure, + Response + ]]): Response < Grpc = + Emit.runForeach(completionEffect)(handler => + listener.completionPromise.get.map(Env.run(_)(handler)) + ).map(Abort.get) + + def run(call: ClientCall[Request, Response]): Response < Grpc = + RequestOptions.run(requestInit).map: (options, requestEffect) => + for + listener <- start(call, options) + response <- sendAndReceive(call, listener, requestEffect).handle( + processCompletion(listener), + cancelOnError(call), + cancelOnInterrupt(call) + ) + yield response + + Sync.defer(channel.newCall(method, options)).map(run) + end unary + + /** Executes a client streaming gRPC call. + * + * A client streaming call sends multiple requests via a stream and receives a single response. The client can send multiple messages + * over time, and the server responds with a single message after processing all requests. 
+ * + * @param channel + * the gRPC channel to use for the call + * @param method + * the method descriptor defining the RPC method + * @param options + * call options for configuring the request + * @param requestsInit + * a stream of request messages to send + * @tparam Request + * the type of the request messages + * @tparam Response + * the type of the response message + * @return + * the response message pending [[Grpc]] + */ + def clientStreaming[Request: Tag, Response]( + channel: io.grpc.Channel, + method: MethodDescriptor[Request, Response], + options: CallOptions, + requestsInit: GrpcRequestsInit[Stream[Request, Grpc]] + )(using Frame, Tag[Emit[Chunk[Request]]]): Response < Grpc = + def start(call: ClientCall[Request, Response], options: RequestOptions): UnaryClientCallListener[Response] < Sync = + for + headersPromise <- Promise.init[SafeMetadata, Any] + responsePromise <- Promise.init[Response, Abort[StatusException]] + completionPromise <- Promise.init[CallClosed, Any] + readySignal <- Signal.initRef[Boolean](false) + listener = UnaryClientCallListener(headersPromise, responsePromise, completionPromise, readySignal) + _ <- Sync.defer(call.start(listener, options.headers.toJava)) + _ <- Sync.defer(options.messageCompression.foreach(call.setMessageCompression)) + yield listener + end start + + def processHeaders( + listener: UnaryClientCallListener[Response], + requestsEffect: GrpcRequestsPendingHeaders[Stream[Request, Grpc]] + ): GrpcRequest[Stream[Request, Grpc]] = + listener.headersPromise.get.map(Env.run(_)(requestsEffect)) + + def sendAndClose( + call: ClientCall[Request, Response], + listener: UnaryClientCallListener[Response], + requests: Stream[Request, Grpc] + ): Result[GrpcFailure, Unit] < Async = + // Sends the first message regardless of readiness to ensure progress. 
+ val send = requests.foreach(request => + for + _ <- Sync.defer(call.sendMessage(request)) + // There is a race condition between setting the ready signal to false and the listener setting it + // to true. Either update may be lost, however, we always check isReady which is the source of + // truth. The only case where the signal value matters is when isReady is false. We know that the + // signal will still be false, and the listener guarantees that the ready signal will be set to true + // when isReady becomes true. + _ <- listener.readySignal.set(false) + isReady <- Sync.defer(call.isReady) + // TODO: We have to handle the case where the listener completes. + _ <- if isReady then Kyo.unit else listener.readySignal.next + yield () + ) + + Abort.run(send).map((result: Result[GrpcFailure, Unit]) => + result match + case success: Result.Success[Unit] @unchecked => + Sync.defer(call.halfClose()).andThen(success) + case error: Result.Error[GrpcFailure] @unchecked => + Sync.defer(call.cancel("Call was cancelled due to an error.", error.failureOrPanic)).andThen(error) + ) + end sendAndClose + + def sendAndReceive( + call: ClientCall[Request, Response], + listener: UnaryClientCallListener[Response], + requestsEffect: GrpcRequest[Stream[Request, Grpc]] + ): GrpcResponsesAwaitingCompletion[Result[GrpcFailure, Response]] = + for + requests <- requestsEffect + _ <- Sync.defer(call.request(1)) + sendResult <- sendAndClose(call, listener, requests) + result <- + sendResult match + case Result.Success(_) => listener.responsePromise.getResult + case Result.Failure(e) => Kyo.lift(Result.fail(e)) + case Result.Panic(e) => Kyo.lift(Result.panic(e)) + yield result + end sendAndReceive + + def processCompletion(listener: UnaryClientCallListener[Response])(completionEffect: GrpcResponsesAwaitingCompletion[Result[ + GrpcFailure, + Response + ]]) + : Response < Grpc = + Emit.runForeach(completionEffect)(handler => + listener.completionPromise.get.map(Env.run(_)(handler)) + 
).map(Abort.get) + + def run(call: ClientCall[Request, Response]): Response < Grpc = + RequestOptions.run(requestsInit).map: (options, requestsEffect) => + for + listener <- start(call, options) + response <- (for + requestsWithHeaders <- processHeaders(listener, requestsEffect) + response <- sendAndReceive(call, listener, requestsWithHeaders) + yield response).handle( + processCompletion(listener), + cancelOnError(call), + cancelOnInterrupt(call) + ) + yield response + + Sync.defer(channel.newCall(method, options)).map(run) + end clientStreaming + + /** Executes a server streaming gRPC call. + * + * A server streaming call sends a single request and receives multiple responses via a stream. The client sends one message, and the + * server responds with a stream of messages over time. + * + * @param channel + * the gRPC channel to use for the call + * @param method + * the method descriptor defining the RPC method + * @param options + * call options for configuring the request + * @param requestInit + * the request message to send + * @tparam Request + * the type of the request message + * @tparam Response + * the type of the response messages + * @return + * a stream of response messages pending [[Grpc]] + */ + def serverStreaming[Request, Response: Tag]( + channel: io.grpc.Channel, + method: MethodDescriptor[Request, Response], + options: CallOptions, + requestInit: GrpcRequestInit[Request] + )(using Frame, Tag[Emit[Chunk[Response]]]): Stream[Response, Grpc] = + def start(call: ClientCall[Request, Response], options: RequestOptions): ServerStreamingClientCallListener[Response] < Sync = + for + headersPromise <- Promise.init[SafeMetadata, Any] + // TODO: What about the Scope? + // Assumption is that SPSC is fine which I think it is according to gRPC docs. 
+ responseStream <- + Channel.initUnscoped[Response](options.responseCapacityOrDefault, access = Access.SingleProducerSingleConsumer) + completionPromise <- Promise.init[CallClosed, Any] + readySignal <- Signal.initRef[Boolean](false) + listener = ServerStreamingClientCallListener(headersPromise, responseStream, completionPromise, readySignal) + _ <- Sync.defer(call.start(listener, options.headers.toJava)) + _ <- Sync.defer(options.messageCompression.foreach(call.setMessageCompression)) + // TODO: Add tests that ensure that we request the right amount. + _ <- Sync.defer(call.request(Math.max(1, options.responseCapacityOrDefault))) + yield listener + end start + + def sendAndReceive( + call: ClientCall[Request, Response], + listener: ServerStreamingClientCallListener[Response], + requestEffect: GrpcRequest[Request] + ): GrpcResponsesAwaitingCompletion[Stream[Response, Grpc]] = + def onChunk(chunk: Chunk[Response]) = + Sync.defer(call.request(chunk.size)) + + for + // We ignore the ready signal here as we want the request ready as soon as possible, + // and we will only buffer at most one request. 
+ request <- requestEffect + _ <- Sync.defer(call.sendMessage(request)) + _ <- Sync.defer(call.halfClose()) + stream <- listener.responseChannel.streamUntilClosed().tapChunk(onChunk) + yield stream + end for + end sendAndReceive + + def processCompletion(listener: ServerStreamingClientCallListener[Response])( + completionEffect: GrpcResponsesAwaitingCompletion[Stream[Response, Grpc]] + ): Stream[Response, Grpc] < Async = + Emit.run[GrpcRequestCompletion](completionEffect).map: (handlers, responses) => + listener.completionPromise.get.map: callClosed => + val completed = handlers.foldLeft(Kyo.unit: Unit < Async): (acc, handler) => + acc.andThen(Env.run(callClosed)(handler)) + completed.andThen: + if callClosed.status.isOk then responses + else responses.concat(Stream(Abort.fail(callClosed.asException))) + end processCompletion + + def run(call: ClientCall[Request, Response]): Stream[Response, Grpc] < Async = + RequestOptions.run(requestInit).map: (options, requestEffect) => + for + listener <- start(call, options) + responses <- sendAndReceive(call, listener, requestEffect).handle( + processCompletion(listener), + cancelOnError(call), + cancelOnInterrupt(call) + ) + yield responses + + Stream.unwrap: + Sync.defer(channel.newCall(method, options)).map(run) + end serverStreaming + + /** Executes a bidirectional streaming gRPC call. + * + * A bidirectional streaming call allows both client and server to send multiple messages via streams. Both sides can send messages + * independently and asynchronously, enabling full-duplex communication patterns. 
+ * + * @param channel + * the gRPC channel to use for the call + * @param method + * the method descriptor defining the RPC method + * @param options + * call options for configuring the request + * @param requestsInit + * a stream of request messages to send + * @tparam Request + * the type of the request messages + * @tparam Response + * the type of the response messages + * @return + * a stream of response messages pending [[Grpc]] + */ + def bidiStreaming[Request: Tag, Response: Tag]( + channel: io.grpc.Channel, + method: MethodDescriptor[Request, Response], + options: CallOptions, + requestsInit: GrpcRequestsInit[Stream[Request, Grpc]] + )(using Frame, Tag[Emit[Chunk[Request]]], Tag[Emit[Chunk[Response]]]): Stream[Response, Grpc] = + def start(call: ClientCall[Request, Response], options: RequestOptions): ServerStreamingClientCallListener[Response] < Sync = + for + headersPromise <- Promise.init[SafeMetadata, Any] + // TODO: What about the Scope? + // Assumption is that SPSC is fine which I think it is according to gRPC docs. 
+ responseStream <- + Channel.initUnscoped[Response](options.responseCapacityOrDefault, access = Access.SingleProducerSingleConsumer) + completionPromise <- Promise.init[CallClosed, Any] + readySignal <- Signal.initRef[Boolean](false) + listener = ServerStreamingClientCallListener(headersPromise, responseStream, completionPromise, readySignal) + _ <- Sync.defer(call.start(listener, options.headers.toJava)) + _ <- Sync.defer(options.messageCompression.foreach(call.setMessageCompression)) + yield listener + end start + + def processHeaders( + listener: ServerStreamingClientCallListener[Response], + requestsEffect: GrpcRequestsPendingHeaders[Stream[Request, Grpc]] + ): GrpcRequest[Stream[Request, Grpc]] = + listener.headersPromise.get.map(Env.run(_)(requestsEffect)) + + def sendAndClose( + call: ClientCall[Request, Response], + listener: ServerStreamingClientCallListener[Response], + requests: Stream[Request, Grpc] + ): Result[GrpcFailure, Unit] < Async = + // Sends the first message regardless of readiness to ensure progress. + val send = requests.foreach(request => + for + _ <- Sync.defer(call.sendMessage(request)) + // There is a race condition between setting the ready signal to false and the listener setting it + // to true. Either update may be lost, however, we always check isReady which is the source of + // truth. The only case where the signal value matters is when isReady is false. We know that the + // signal will still be false, and the listener guarantees that the ready signal will be set to true + // when isReady becomes true. 
+ _ <- listener.readySignal.set(false) + isReady <- Sync.defer(call.isReady) + _ <- if isReady then Kyo.unit else listener.readySignal.next + yield () + ) + + Abort.run(send).map((result: Result[GrpcFailure, Unit]) => + result match + case success: Result.Success[Unit] @unchecked => + Sync.defer(call.halfClose()).andThen(success) + case error: Result.Error[GrpcFailure] @unchecked => + Sync.defer(call.cancel("Call was cancelled due to an error.", error.failureOrPanic)).andThen(error) + ) + end sendAndClose + + def sendAndReceive( + call: ClientCall[Request, Response], + listener: ServerStreamingClientCallListener[Response], + requestsEffect: GrpcRequest[Stream[Request, Grpc]] + ): GrpcResponsesAwaitingCompletion[Stream[Response, Grpc]] = + def onResponseChunk(chunk: Chunk[Response]) = + Sync.defer(call.request(chunk.size)) + + for + requests <- requestsEffect + _ <- Sync.defer(call.request(1)) + // TODO: Is it fine for this to be unscoped? + _ <- Fiber.initUnscoped(sendAndClose(call, listener, requests)) + stream <- listener.responseChannel.streamUntilClosed().tapChunk(onResponseChunk) + yield stream + end for + end sendAndReceive + + def processCompletion(listener: ServerStreamingClientCallListener[Response])( + completionEffect: GrpcResponsesAwaitingCompletion[Stream[Response, Grpc]] + ): Stream[Response, Grpc] < Async = + Emit.runForeach(completionEffect)(handler => + listener.completionPromise.get.map(Env.run(_)(handler)) + ) + + def run(call: ClientCall[Request, Response]): Stream[Response, Grpc] < Async = + RequestOptions.run(requestsInit).map: (options, requestsEffect) => + for + listener <- start(call, options) + responses <- (for + requestsWithHeaders <- processHeaders(listener, requestsEffect) + responses <- sendAndReceive(call, listener, requestsWithHeaders) + yield responses).handle( + processCompletion(listener), + cancelOnError(call), + cancelOnInterrupt(call) + ) + yield responses + + Stream.unwrap: + Sync.defer(channel.newCall(method, 
options)).map(run) + end bidiStreaming + + private def cancelOnError[E <: Throwable: ConcreteTag, Response, S](call: ClientCall[?, ?])(v: => Response < (Abort[E] & S))(using + Frame + ): Response < (Abort[E] & Sync & S) = + Abort.recoverError[E](error => + Sync.defer(call.cancel("Call was cancelled due to an error.", error.failureOrPanic)) + .andThen(Abort.error(error)) + )(v) + + private def cancelOnInterrupt[E, Response](call: ClientCall[?, ?])(v: => Response < (Abort[E] & Async))(using + Frame + ): Response < (Abort[E] & Async) = + Async.tapFiber(v)(fiber => + fiber.onInterrupt(error => + val ex = error match + case Result.Panic(e) => e + case _ => null + Sync.defer(call.cancel("Kyo Fiber was interrupted.", ex)) + ) + ) + end cancelOnInterrupt + +end ClientCall diff --git a/kyo-grpc-core/shared/src/main/scala/kyo/grpc/Grpc.scala b/kyo-grpc-core/shared/src/main/scala/kyo/grpc/Grpc.scala new file mode 100644 index 000000000..9cfae1d5a --- /dev/null +++ b/kyo-grpc-core/shared/src/main/scala/kyo/grpc/Grpc.scala @@ -0,0 +1,42 @@ +package kyo.grpc + +import io.grpc.StatusException +import io.grpc.StatusRuntimeException +import kyo.* +import scala.concurrent.Future + +// TODO: Document how to include trailers in error. +/** Effect of sending or receiving a gRPC message. + * + * Service method implementations will be [[Async]] effects that either succeed with some `Response` or terminate early with a + * [[GrpcFailure]]. For example: + * {{{ + * Abort.fail(Status.INVALID_ARGUMENT.withDescription("Id cannot be empty.").asException) + * }}} + * + * Clients will typically handle the effect of calling a gRPC method using functions such as [[Abort.run]]. + */ +type Grpc = Async & Abort[GrpcFailure] + +object Grpc: + + /** Creates a computation pending the [[Grpc]] effect from a [[Future]]. + * + * If the `Future` fails with an exception, it will be converted to a [[GrpcFailure]] using [[GrpcFailure.fromThrowable]] and the + * computation will abort. 
+ * + * @param f + * The `Future` that produces the computation result + * @tparam A + * The type of the successful result + * @return + * A computation that completes with the result of the Future + */ + def fromFuture[A](f: Future[A])(using Frame): A < Grpc = + Abort.recoverError[Throwable] { + // TODO: Fix match not exhaustive warning + case Result.Error(t) => Abort.fail(GrpcFailure.fromThrowable(t)) + }(Async.fromFuture(f)) + end fromFuture + +end Grpc diff --git a/kyo-grpc-core/shared/src/main/scala/kyo/grpc/GrpcFailure.scala b/kyo-grpc-core/shared/src/main/scala/kyo/grpc/GrpcFailure.scala new file mode 100644 index 000000000..c502167f5 --- /dev/null +++ b/kyo-grpc-core/shared/src/main/scala/kyo/grpc/GrpcFailure.scala @@ -0,0 +1,36 @@ +package kyo.grpc + +import io.grpc.* +import scala.util.chaining.scalaUtilChainingOps + +/** A failure that occurred while sending or receiving a gRPC message. + * + * These are typically created from a [[io.grpc.Status]] via `asException`. + * + * @see + * [[StatusException]] + */ +type GrpcFailure = StatusException + +object GrpcFailure: + + /** Converts a [[Throwable]] to a [[GrpcFailure]]. + * + * Conversions are handled as follows: + * - [[StatusException]]: Returns the exception unchanged. + * - [[StatusRuntimeException]]: Converts to `StatusException` while preserving status, trailers, and stack trace. + * - Other exceptions: Uses [[Status#fromThrowable(java.lang.Throwable)]] which attempts to find a gRPC status from nested causes, + * defaulting to [[Status.UNKNOWN]] status if none is found. 
+ * + * @param t + * The `Throwable` to convert + * @return + * A `GrpcFailure` suitable for gRPC error reporting + */ + def fromThrowable(t: Throwable): GrpcFailure = + t match + case e: StatusException => e + case e: StatusRuntimeException => StatusException(e.getStatus, e.getTrailers).tap(_.setStackTrace(e.getStackTrace)) + case _ => Status.fromThrowable(t).asException() + +end GrpcFailure diff --git a/kyo-grpc-core/shared/src/main/scala/kyo/grpc/RequestOptions.scala b/kyo-grpc-core/shared/src/main/scala/kyo/grpc/RequestOptions.scala new file mode 100644 index 000000000..b1140fd15 --- /dev/null +++ b/kyo-grpc-core/shared/src/main/scala/kyo/grpc/RequestOptions.scala @@ -0,0 +1,34 @@ +package kyo.grpc + +import kyo.* +import kyo.grpc.RequestOptions.DefaultResponseCapacity + +final case class RequestOptions( + headers: SafeMetadata = SafeMetadata.empty, + messageCompression: Maybe[Boolean] = Maybe.empty, + responseCapacity: Maybe[Int] = Maybe.empty +): + + def combine(that: RequestOptions)(using Frame): RequestOptions < Sync = + Sync.defer: + RequestOptions( + headers = this.headers.merge(that.headers), + messageCompression = that.messageCompression.orElse(this.messageCompression), + responseCapacity = that.responseCapacity.orElse(this.responseCapacity) + ) + end combine + + def responseCapacityOrDefault: Int = responseCapacity.getOrElse(DefaultResponseCapacity) + +end RequestOptions + +object RequestOptions: + + // TODO: What are sensible defaults? 
+ val DefaultRequestBuffer: Int = 8 + val DefaultResponseCapacity: Int = 8 + + def run[A, S](v: A < (Emit[RequestOptions] & S))(using Frame): (RequestOptions, A) < (Sync & S) = + Emit.runFold[RequestOptions](RequestOptions())(_.combine(_))(v) + +end RequestOptions diff --git a/kyo-grpc-core/shared/src/main/scala/kyo/grpc/ResponseOptions.scala b/kyo-grpc-core/shared/src/main/scala/kyo/grpc/ResponseOptions.scala new file mode 100644 index 000000000..d98c70dd1 --- /dev/null +++ b/kyo-grpc-core/shared/src/main/scala/kyo/grpc/ResponseOptions.scala @@ -0,0 +1,52 @@ +package kyo.grpc + +import io.grpc.ServerCall +import kyo.* + +// TODO: What to call this? +final case class ResponseOptions( + headers: SafeMetadata = SafeMetadata.empty, + messageCompression: Maybe[Boolean] = Maybe.empty, + compression: Maybe[String] = Maybe.empty, + onReadyThreshold: Maybe[Int] = Maybe.empty, + requestBuffer: Maybe[Int] = Maybe.empty +): + + def requestBufferOrDefault: Int = + requestBuffer.getOrElse(ResponseOptions.DefaultRequestBuffer) + + def combine(that: ResponseOptions)(using Frame): ResponseOptions < Sync = + Sync.defer: + ResponseOptions( + headers = this.headers.merge(that.headers), + messageCompression = that.messageCompression.orElse(this.messageCompression), + compression = that.compression.orElse(this.compression), + onReadyThreshold = that.onReadyThreshold.orElse(this.onReadyThreshold), + requestBuffer = that.requestBuffer.orElse(this.requestBuffer) + ) + end combine + + def sendHeaders(call: ServerCall[?, ?])(using Frame): Unit < Sync = + Sync.defer: + // These may only be called once and must be called before sendMessage. + messageCompression.foreach(call.setMessageCompression) + compression.foreach(call.setCompression) + onReadyThreshold.foreach(call.setOnReadyThreshold) + // Headers must be sent even if empty. 
+ call.sendHeaders(headers.toJava) + end sendHeaders + +end ResponseOptions + +object ResponseOptions: + + val DefaultRequestBuffer: Int = 8 + + def run[A, S](v: A < (Emit[ResponseOptions] & S))(using Frame): (ResponseOptions, A) < (Sync & S) = + Emit.runFold[ResponseOptions](ResponseOptions())(_.combine(_))(v) + + def runSend[A, S](call: ServerCall[?, ?])(v: A < (Emit[ResponseOptions] & S))(using Frame): A < (Sync & S) = + run(v).map: (options, a) => + options.sendHeaders(call).andThen(a) + +end ResponseOptions diff --git a/kyo-grpc-core/shared/src/main/scala/kyo/grpc/SafeMetadata.scala b/kyo-grpc-core/shared/src/main/scala/kyo/grpc/SafeMetadata.scala new file mode 100644 index 000000000..c8c884918 --- /dev/null +++ b/kyo-grpc-core/shared/src/main/scala/kyo/grpc/SafeMetadata.scala @@ -0,0 +1,93 @@ +package kyo.grpc + +import io.grpc.Metadata +import java.util.Base64 +import kyo.* + +final case class SafeMetadata( + entries: Map[String, Seq[String]] = Map.empty +): + def add(key: String, value: String): SafeMetadata = + if key.endsWith(Metadata.BINARY_HEADER_SUFFIX) then + throw new IllegalArgumentException( + s"Binary header key $key must end with ${Metadata.BINARY_HEADER_SUFFIX} and value must be Chunk[Byte]" + ) + end if + update(key, value) + end add + + def add(key: String, value: Chunk[Byte]): SafeMetadata = + if !key.endsWith(Metadata.BINARY_HEADER_SUFFIX) then + throw new IllegalArgumentException(s"Binary header key $key must end with ${Metadata.BINARY_HEADER_SUFFIX}") + update(key, Base64.getEncoder.encodeToString(value.toArray)) + end add + + private def update(key: String, value: String): SafeMetadata = + val newValues = entries.getOrElse(key, Seq.empty) :+ value + copy(entries = entries.updated(key, newValues)) + + def merge(that: SafeMetadata): SafeMetadata = + val merged = that.entries.foldLeft(entries) { case (acc, (k, v)) => + acc.updated(k, acc.getOrElse(k, Seq.empty) ++ v) + } + copy(entries = merged) + end merge + + def toJava: Metadata = + val 
md = new Metadata() + entries.foreach { case (k, values) => + values.foreach { v => + if k.endsWith(Metadata.BINARY_HEADER_SUFFIX) then + val decoded = Base64.getDecoder.decode(v) + md.put(Metadata.Key.of(k, Metadata.BINARY_BYTE_MARSHALLER), decoded) + else + md.put(Metadata.Key.of(k, Metadata.ASCII_STRING_MARSHALLER), v) + } + } + md + end toJava + +end SafeMetadata + +object SafeMetadata: + val empty: SafeMetadata = SafeMetadata() + + def fromJava(md: Metadata): SafeMetadata = + var result = empty + val keys = md.keys() + if keys != null then + val iterator = keys.iterator() + while iterator.hasNext do + val key = iterator.next() + if key.endsWith(Metadata.BINARY_HEADER_SUFFIX) then + val keyObj = Metadata.Key.of(key, Metadata.BINARY_BYTE_MARSHALLER) + val values = md.getAll(keyObj) + if values != null then + val it = values.iterator() + while it.hasNext do + result = result.update(key, Base64.getEncoder.encodeToString(it.next())) + end if + else + val keyObj = Metadata.Key.of(key, Metadata.ASCII_STRING_MARSHALLER) + val values = md.getAll(keyObj) + if values != null then + val it = values.iterator() + while it.hasNext do + result = result.update(key, it.next()) + end if + end if + end while + end if + result + end fromJava +end SafeMetadata + +extension (metadata: SafeMetadata) + + def getStrings(key: String): Seq[String] = + metadata.entries.getOrElse(key, Seq.empty) + + def getBinary(key: String): Seq[Chunk[Byte]] = + metadata.entries.getOrElse(key, Seq.empty).map(s => Chunk.from(Base64.getDecoder.decode(s))) + +end extension diff --git a/kyo-grpc-core/shared/src/main/scala/kyo/grpc/Server.scala b/kyo-grpc-core/shared/src/main/scala/kyo/grpc/Server.scala new file mode 100644 index 000000000..ebcb84c30 --- /dev/null +++ b/kyo-grpc-core/shared/src/main/scala/kyo/grpc/Server.scala @@ -0,0 +1,75 @@ +package kyo.grpc + +import io.grpc.ServerBuilder +import java.util.concurrent.TimeUnit +import kyo.* +import sun.misc.Signal + +type Server = io.grpc.Server + +/** 
Server utilities for managing gRPC servers in Kyo. + * + * Provides functionality to start and gracefully shutdown gRPC servers with proper resource management. + * + * @example + * {{{ + * for + * _ <- Console.printLine(s"Server is running on port $port. Press Ctrl-C to stop.") + * server <- Server.start(port)( + * _.addService(GreeterService), + * { (server, duration) => + * for + * _ <- Console.print("Shutting down...") + * _ <- Server.shutdown(server, duration) + * _ <- Console.printLine("Done.") + * yield () + * } + * ) + * _ <- Async.never + * yield () + * }}} + */ +object Server: + + /** Attempts an orderly shut down of the [[io.grpc.Server]] within a timeout. + * + * First attempts graceful shutdown by calling [[io.grpc.Server.shutdown]] and waits up to `timeout` for termination. If the server + * doesn't terminate within the timeout, forces shutdown with [[io.grpc.Server.shutdownNow]] and then waits indefinitely for it to + * terminate. + * + * @param server + * The server to shut down + * @param timeout + * The maximum duration to wait for graceful termination (default: 30 seconds) + */ + def shutdown(server: Server, timeout: Duration = 30.seconds)(using Frame): Unit < Sync = + Sync.defer: + val terminated = + server + .shutdown() + .awaitTermination(timeout.toJava.toNanos, TimeUnit.NANOSECONDS) + if terminated then () else server.shutdownNow().awaitTermination() + + /** Starts an [[io.grpc.Server]] on the specified port with the provided configuration and shutdown logic. + * + * @param port + * the port on which the server will listen + * @param timeout + * the maximum duration to wait for graceful termination (default: 30 seconds) + * @param configure + * a function to configure the [[ServerBuilder]] such as adding services + * @param shutdown + * A function to handle the shutdown of the server, which takes the server instance and a timeout duration. 
Defaults to + * [[Server.shutdown]] + * @return + * the running server pending [[Scope]] and [[Sync]] + */ + def start(port: Int, timeout: Duration = 30.seconds)( + configure: ServerBuilder[?] => ServerBuilder[?], + shutdown: (Server, Duration) => Frame ?=> Any < Sync = shutdown + )(using Frame): Server < (Scope & Sync) = + Scope.acquireRelease( + Sync.defer(configure(ServerBuilder.forPort(port)).build().start()) + )(shutdown(_, timeout)) + +end Server diff --git a/kyo-grpc-core/shared/src/main/scala/kyo/grpc/ServerCallHandlers.scala b/kyo-grpc-core/shared/src/main/scala/kyo/grpc/ServerCallHandlers.scala new file mode 100644 index 000000000..8bae9dca5 --- /dev/null +++ b/kyo-grpc-core/shared/src/main/scala/kyo/grpc/ServerCallHandlers.scala @@ -0,0 +1,113 @@ +package kyo.grpc + +import io.grpc.ServerCallHandler +import io.grpc.Status +import kyo.* +import kyo.grpc.* +import kyo.grpc.internal.BidiStreamingServerCallHandler +import kyo.grpc.internal.ClientStreamingServerCallHandler +import kyo.grpc.internal.ServerStreamingServerCallHandler +import kyo.grpc.internal.UnaryServerCallHandler + +type GrpcResponseMeta = Env[SafeMetadata] & Emit[ResponseOptions] + +type GrpcHandler[Requests, Responses] = Requests => Responses < (Grpc & Emit[SafeMetadata]) + +type GrpcHandlerInit[Requests, Responses] = GrpcHandler[Requests, Responses] < GrpcResponseMeta + +/** Factory for creating gRPC server call handlers that integrate with Kyo. + * + * This object provides methods to create server handlers for all four types of gRPC methods: unary, client streaming, server streaming, + * and bidirectional streaming. Each handler method converts Kyo effects into gRPC-compatible server handlers. + * + * All handlers are designed to work with the [[Grpc]] effect type and handle exceptions and stream completion appropriately. + */ +object ServerCallHandlers: + + // TODO: Update the docs for f + // TODO: Update the callers to include metadata and response options. 
+ + /** Creates a server handler for unary gRPC calls. + * + * A unary call receives a single request and produces a single response. The handler function `f` takes a request and returns a + * response pending the [[Grpc]] effect. + * + * @param f + * the handler function that processes the request + * @tparam Request + * the type of the incoming request + * @tparam Response + * the type of the outgoing response + * @return + * a gRPC [[ServerCallHandler]] for unary calls + */ + def unary[Request, Response](f: GrpcHandlerInit[Request, Response])(using Frame): ServerCallHandler[Request, Response] = + UnaryServerCallHandler(f) + + /** Creates a server handler for client streaming gRPC calls. + * + * A client streaming call receives a stream of requests from the client and produces a single response. The handler function `f` takes + * a stream of requests and returns a response pending the [[Grpc]] effect. + * + * @param f + * the handler function that processes the request stream + * @tparam Request + * the type of each request in the stream + * @tparam Response + * the type of the single response + * @return + * a gRPC [[ServerCallHandler]] for client streaming calls + */ + def clientStreaming[Request, Response](f: GrpcHandlerInit[Stream[Request, Grpc], Response])(using + Frame, + Tag[Emit[Chunk[Request]]] + ): ServerCallHandler[Request, Response] = + ClientStreamingServerCallHandler(f) + + /** Creates a server handler for server streaming gRPC calls. + * + * A server streaming call receives a single request and produces a stream of responses. The handler function `f` takes a request and + * returns a stream of responses pending the [[Grpc]] effect. 
+ * + * @param f + * the handler function that processes the request and produces a response stream + * @tparam Request + * the type of the single request + * @tparam Response + * the type of each response in the stream + * @return + * a gRPC [[ServerCallHandler]] for server streaming calls + */ + def serverStreaming[Request, Response](f: GrpcHandlerInit[Request, Stream[Response, Grpc]])(using + Frame, + Tag[Emit[Chunk[Response]]] + ): ServerCallHandler[Request, Response] = + ServerStreamingServerCallHandler(f) + + /** Creates a server handler for bidirectional streaming gRPC calls. + * + * A bidirectional streaming call receives a stream of requests and produces a stream of responses. The handler function `f` takes a + * stream of requests and returns a stream of responses pending the [[Grpc]] effect. + * + * @param f + * the handler function that processes the request stream and produces a response stream + * @tparam Request + * the type of each request in the input stream + * @tparam Response + * the type of each response in the output stream + * @return + * a gRPC [[ServerCallHandler]] for bidirectional streaming calls + */ + def bidiStreaming[Request, Response](f: GrpcHandlerInit[Stream[Request, Grpc], Stream[Response, Grpc]])(using + Frame, + Tag[Emit[Chunk[Request]]], + Tag[Emit[Chunk[Response]]] + ): ServerCallHandler[Request, Response] = + BidiStreamingServerCallHandler(f) + + // TODO: Inline this. + private[kyo] def errorStatus(error: Result.Error[Throwable]): Status = + val t = error.failureOrPanic + Status.fromThrowable(t) + +end ServerCallHandlers diff --git a/kyo-grpc-core/shared/src/main/scala/kyo/grpc/Service.scala b/kyo-grpc-core/shared/src/main/scala/kyo/grpc/Service.scala new file mode 100644 index 000000000..e867ce0dd --- /dev/null +++ b/kyo-grpc-core/shared/src/main/scala/kyo/grpc/Service.scala @@ -0,0 +1,41 @@ +package kyo.grpc + +import io.grpc.ServerServiceDefinition + +/** A trait representing a Kyo gRPC service. 
+ * + * This trait provides a common interface for all gRPC services, requiring implementations to provide a [[ServerServiceDefinition]] that + * describes the service's methods and their handlers. + * + * Services implementing this trait can be easily converted to gRPC's `ServerServiceDefinition` through the provided implicit conversion. + * + * @note + * Implementations of this trait are typically generated by `kyo-grpc-code-gen` from protobuf service definitions rather than being + * manually implemented. + */ +trait Service: + + /** Returns the gRPC server service definition for this service. + * + * @return + * the [[ServerServiceDefinition]] that describes this service + */ + def definition: ServerServiceDefinition + +end Service + +/** Companion object for [[Service]] providing utility conversions. + */ +object Service: + + /** Implicit conversion from [[Service]] to [[ServerServiceDefinition]]. + * + * This conversion allows `Service` instances to be used directly wherever a `ServerServiceDefinition` is expected, providing seamless + * integration with gRPC server builders and other APIs. 
+ * + * @return + * a conversion function that extracts the `definition` from a `Service` + */ + given Conversion[Service, ServerServiceDefinition] = _.definition + +end Service diff --git a/kyo-grpc-core/shared/src/main/scala/kyo/grpc/internal/BaseStreamingServerCallHandler.scala b/kyo-grpc-core/shared/src/main/scala/kyo/grpc/internal/BaseStreamingServerCallHandler.scala new file mode 100644 index 000000000..4ff252f9e --- /dev/null +++ b/kyo-grpc-core/shared/src/main/scala/kyo/grpc/internal/BaseStreamingServerCallHandler.scala @@ -0,0 +1,97 @@ +package kyo.grpc.internal + +import io.grpc.* +import kyo.* +import kyo.Channel +import kyo.grpc.* +import kyo.grpc.Grpc + +abstract private[grpc] class BaseStreamingServerCallHandler[Request, Response, Handler](f: Handler < GrpcResponseMeta)(using Frame) + extends ServerCallHandler[Request, Response]: + + import AllowUnsafe.embrace.danger + + protected def send( + call: ServerCall[Request, Response], + handler: Handler, + channel: Channel[Request], + ready: SignalRef[Boolean] + ): Status < (Grpc & Emit[SafeMetadata]) + + override def startCall(call: ServerCall[Request, Response], headers: Metadata): ServerCall.Listener[Request] = + // WARNING: call is not guaranteed to be thread-safe. + // WARNING: headers are definitely not thread-safe. + // This handler has ownership of the call and headers, so we can use them with care. 
+ + def sendAndClose(handler: Handler, channel: Channel[Request], ready: SignalRef[Boolean]) = + for + (trailers, status) <- + send(call, handler, channel, ready).handle( + Abort.recoverError(ServerCallHandlers.errorStatus), + Emit.runFold[SafeMetadata](SafeMetadata.empty)(_.merge(_)) + ) + _ <- Sync.defer(call.close(status, trailers.toJava)) + yield () + + def start(handler: Handler, channel: Channel[Request], ready: SignalRef[Boolean]) = + for + fiber <- Fiber.initUnscoped(sendAndClose(handler, channel, ready)) + _ <- fiber.onInterrupt: _ => + val status = Status.CANCELLED.withDescription("Call was cancelled.") + try + call.close(status, SafeMetadata.empty.toJava) + catch + case _: IllegalStateException => // Ignore + end try + _ <- fiber.onComplete: _ => + channel.close + yield fiber + + val init = + for + // Request 1 up front to ensure that we get the headers. + _ <- Sync.defer(call.request(1)) + (options, handler) <- f.handle( + Env.run(SafeMetadata.fromJava(headers)), + ResponseOptions.run + ) + requestBuffer = options.requestBufferOrDefault + _ <- options.sendHeaders(call) + // Request the remaining messages to fill the request buffer. 
+ _ <- Sync.defer(if requestBuffer > 1 then call.request(requestBuffer - 1) else ()) + ready <- Signal.initRef(false) + channel <- Channel.initUnscoped[Request](capacity = requestBuffer, access = Access.SingleProducerSingleConsumer) + sentFiber <- start(handler, channel, ready) + yield StreamingServerCallListener(channel.unsafe, sentFiber.unsafe, ready.unsafe, call) + + init.handle( + Sync.Unsafe.run, + Abort.run, + _.eval.getOrThrow + ) + end startCall + + private class StreamingServerCallListener( + requests: Channel.Unsafe[Request], + fiber: Fiber.Unsafe[Any, Nothing], + ready: SignalRef.Unsafe[Boolean], + call: ServerCall[Request, Response] + ) extends ServerCall.Listener[Request]: + + override def onMessage(message: Request): Unit = + discard(requests.putFiber(message)) + + override def onHalfClose(): Unit = + discard(requests.closeAwaitEmpty()) + + override def onCancel(): Unit = + discard(fiber.interrupt()) + + override def onComplete(): Unit = () + + override def onReady(): Unit = + ready.set(call.isReady) + + end StreamingServerCallListener + +end BaseStreamingServerCallHandler diff --git a/kyo-grpc-core/shared/src/main/scala/kyo/grpc/internal/BaseUnaryServerCallHandler.scala b/kyo-grpc-core/shared/src/main/scala/kyo/grpc/internal/BaseUnaryServerCallHandler.scala new file mode 100644 index 000000000..59b630bb2 --- /dev/null +++ b/kyo-grpc-core/shared/src/main/scala/kyo/grpc/internal/BaseUnaryServerCallHandler.scala @@ -0,0 +1,97 @@ +package kyo.grpc.internal + +import io.grpc.* +import kyo.* +import kyo.grpc.* +import kyo.grpc.Grpc + +abstract private[grpc] class BaseUnaryServerCallHandler[Request, Response, Handler](f: Handler < GrpcResponseMeta)(using Frame) + extends ServerCallHandler[Request, Response]: + + import AllowUnsafe.embrace.danger + + protected def send( + call: ServerCall[Request, Response], + handler: Handler, + promise: Promise[Request, Abort[Status]], + ready: SignalRef[Boolean] + ): Status < (Grpc & Emit[SafeMetadata]) + + override def 
startCall(call: ServerCall[Request, Response], headers: Metadata): ServerCall.Listener[Request] = + // WARNING: call is not guaranteed to be thread-safe. + // WARNING: headers are definitely not thread-safe. + // This handler has ownership of the call and headers, so we can use them with care. + + def sendAndClose(handler: Handler, promise: Promise[Request, Abort[Status]], ready: SignalRef[Boolean]) = + for + (trailers, status) <- + send(call, handler, promise, ready).handle( + Abort.recoverError(ServerCallHandlers.errorStatus), + Emit.runFold[SafeMetadata](SafeMetadata.empty)(_.merge(_)) + ) + _ <- Sync.defer(call.close(status, trailers.toJava)) + yield () + + def start(handler: Handler, promise: Promise[Request, Abort[Status]], ready: SignalRef[Boolean]) = + for + fiber <- Fiber.initUnscoped(sendAndClose(handler, promise, ready)) + _ <- fiber.onInterrupt: _ => + val status = Status.CANCELLED.withDescription("Call was cancelled.") + try + call.close(status, SafeMetadata.empty.toJava) + catch + case _: IllegalStateException => // Ignore + end try + yield fiber + + val init = + for + _ <- Sync.defer(call.request(1)) + (options, handler) <- f.handle( + Env.run(SafeMetadata.fromJava(headers)), + ResponseOptions.run + ) + _ <- options.sendHeaders(call) + ready <- Signal.initRef(false) + promise <- Promise.init[Request, Abort[Status]] + sentFiber <- start(handler, promise, ready) + yield UnaryServerCallListener(promise.unsafe, sentFiber.unsafe, ready.unsafe, call) + + init.handle( + Sync.Unsafe.run, + Abort.run, + _.eval.getOrThrow + ) + end startCall + + private class UnaryServerCallListener( + request: Promise.Unsafe[Request, Abort[Status]], + fiber: Fiber.Unsafe[Any, Nothing], + ready: SignalRef.Unsafe[Boolean], + call: ServerCall[Request, Response] + ) extends ServerCall.Listener[Request]: + + override def onMessage(message: Request): Unit = + // Unlike io.grpc.stub.ServerCalls.UnaryServerCallHandler.UnaryServerCallListener, + // this does not attempt to detect if 
the client is misbehaving by sending multiple messages. + // It is unnecessary as the server can reply and close the call after the first message. + // It should be up to the client implementation to detect that if it wants to. + request.completeDiscard(Result.succeed(message)) + + override def onHalfClose(): Unit = + // If the promise has not been completed yet, we complete it with an error, otherwise this does nothing. + request.completeDiscard(Result.fail( + Status.INVALID_ARGUMENT.withDescription("Client completed before sending a request.") + )) + + override def onCancel(): Unit = + discard(fiber.interrupt()) + + override def onComplete(): Unit = () + + override def onReady(): Unit = + ready.set(call.isReady) + + end UnaryServerCallListener + +end BaseUnaryServerCallHandler diff --git a/kyo-grpc-core/shared/src/main/scala/kyo/grpc/internal/BidiStreamingServerCallHandler.scala b/kyo-grpc-core/shared/src/main/scala/kyo/grpc/internal/BidiStreamingServerCallHandler.scala new file mode 100644 index 000000000..b93421791 --- /dev/null +++ b/kyo-grpc-core/shared/src/main/scala/kyo/grpc/internal/BidiStreamingServerCallHandler.scala @@ -0,0 +1,40 @@ +package kyo.grpc.internal + +import io.grpc.* +import kyo.* +import kyo.Channel +import kyo.grpc.* +import kyo.grpc.Grpc + +private[grpc] class BidiStreamingServerCallHandler[Request, Response](f: GrpcHandlerInit[Stream[Request, Grpc], Stream[Response, Grpc]])( + using + Frame, + Tag[Emit[Chunk[Request]]], + Tag[Emit[Chunk[Response]]] +) extends BaseStreamingServerCallHandler[Request, Response, GrpcHandler[Stream[Request, Grpc], Stream[Response, Grpc]]](f): + + override protected def send( + call: ServerCall[Request, Response], + handler: GrpcHandler[Stream[Request, Grpc], Stream[Response, Grpc]], + channel: Channel[Request], + ready: SignalRef[Boolean] + ): Status < (Grpc & Emit[SafeMetadata]) = + def onChunk(chunk: Chunk[Request]) = + Sync.defer(call.request(chunk.size)) + + def sendMessages(isFirst: 
AtomicRef[Boolean])(response: Response): Unit < Async = + // Send the first message whether the call is ready or not and let it buffer internally as a fast path + // under the assumption that the client will be ready for at least one response after the initial request. + isFirst.getAndSet(false).flatMap: first => + if first || call.isReady then Sync.defer(call.sendMessage(response)) + else ready.next.andThen(sendMessages(isFirst)(response)) + + for + responses <- handler(channel.streamUntilClosed().tapChunk(onChunk)) + isFirst <- AtomicRef.init(true) + _ <- responses.foreach(sendMessages(isFirst)) + yield Status.OK + end for + end send + +end BidiStreamingServerCallHandler diff --git a/kyo-grpc-core/shared/src/main/scala/kyo/grpc/internal/ClientStreamingServerCallHandler.scala b/kyo-grpc-core/shared/src/main/scala/kyo/grpc/internal/ClientStreamingServerCallHandler.scala new file mode 100644 index 000000000..08f68dee3 --- /dev/null +++ b/kyo-grpc-core/shared/src/main/scala/kyo/grpc/internal/ClientStreamingServerCallHandler.scala @@ -0,0 +1,30 @@ +package kyo.grpc.internal + +import io.grpc.* +import kyo.* +import kyo.Channel +import kyo.grpc.* +import kyo.grpc.Grpc + +private[grpc] class ClientStreamingServerCallHandler[Request, Response](f: GrpcHandlerInit[Stream[Request, Grpc], Response])(using + Frame, + Tag[Emit[Chunk[Request]]] +) extends BaseStreamingServerCallHandler[Request, Response, GrpcHandler[Stream[Request, Grpc], Response]](f): + + override protected def send( + call: ServerCall[Request, Response], + handler: GrpcHandler[Stream[Request, Grpc], Response], + channel: Channel[Request], + ready: SignalRef[Boolean] + ): Status < (Grpc & Emit[SafeMetadata]) = + def onChunk(chunk: Chunk[Request]) = + Sync.defer(call.request(chunk.size)) + + for + response <- handler(channel.streamUntilClosed().tapChunk(onChunk)) + _ <- Sync.defer(call.sendMessage(response)) + yield Status.OK + end for + end send + +end ClientStreamingServerCallHandler diff --git 
a/kyo-grpc-core/shared/src/main/scala/kyo/grpc/internal/MetadataExtensions.scala b/kyo-grpc-core/shared/src/main/scala/kyo/grpc/internal/MetadataExtensions.scala new file mode 100644 index 000000000..8b3d12d07 --- /dev/null +++ b/kyo-grpc-core/shared/src/main/scala/kyo/grpc/internal/MetadataExtensions.scala @@ -0,0 +1,15 @@ +package kyo.grpc.internal + +import kyo.* +import kyo.grpc.SafeMetadata + +extension (maybeMetadata: Maybe[SafeMetadata]) + + inline def mergeIfDefined(maybeOther: Maybe[SafeMetadata])(using Frame): Maybe[SafeMetadata] < Sync = + maybeMetadata match + case Maybe.Present(metadata) => + Sync.defer(Maybe.Present(metadata.merge(maybeOther.getOrElse(SafeMetadata.empty)))) + case Maybe.Absent => + maybeOther + +end extension diff --git a/kyo-grpc-core/shared/src/main/scala/kyo/grpc/internal/ServerStreamingClientCallListener.scala b/kyo-grpc-core/shared/src/main/scala/kyo/grpc/internal/ServerStreamingClientCallListener.scala new file mode 100644 index 000000000..ddf1999d0 --- /dev/null +++ b/kyo-grpc-core/shared/src/main/scala/kyo/grpc/internal/ServerStreamingClientCallListener.scala @@ -0,0 +1,34 @@ +package kyo.grpc.internal + +import io.grpc.* +import io.grpc.ClientCall.Listener +import kyo.* +import kyo.Channel +import kyo.grpc.CallClosed +import kyo.grpc.SafeMetadata + +private[grpc] class ServerStreamingClientCallListener[Response]( + val headersPromise: Promise[SafeMetadata, Any], + val responseChannel: Channel[Response], + val completionPromise: Promise[CallClosed, Any], + val readySignal: SignalRef[Boolean] +) extends Listener[Response]: + + import AllowUnsafe.embrace.danger + private given Frame = Frame.internal + + override def onHeaders(headers: Metadata): Unit = + headersPromise.unsafe.completeDiscard(Result.succeed(SafeMetadata.fromJava(headers))) + + override def onMessage(message: Response): Unit = + val _ = responseChannel.unsafe.offer(message) + + override def onClose(status: Status, trailers: Metadata): Unit = + val _ = 
responseChannel.unsafe.close() + completionPromise.unsafe.completeDiscard(Result.succeed(CallClosed(status, SafeMetadata.fromJava(trailers)))) + + override def onReady(): Unit = + // May not be called if the method type is unary. + readySignal.unsafe.set(true) + +end ServerStreamingClientCallListener diff --git a/kyo-grpc-core/shared/src/main/scala/kyo/grpc/internal/ServerStreamingServerCallHandler.scala b/kyo-grpc-core/shared/src/main/scala/kyo/grpc/internal/ServerStreamingServerCallHandler.scala new file mode 100644 index 000000000..d3b6339ea --- /dev/null +++ b/kyo-grpc-core/shared/src/main/scala/kyo/grpc/internal/ServerStreamingServerCallHandler.scala @@ -0,0 +1,38 @@ +package kyo.grpc.internal + +import io.grpc.* +import kyo.* +import kyo.grpc.* +import kyo.grpc.Grpc + +private[grpc] class ServerStreamingServerCallHandler[Request, Response](f: GrpcHandlerInit[Request, Stream[Response, Grpc]])(using + Frame, + Tag[Emit[Chunk[Response]]] +) extends BaseUnaryServerCallHandler[Request, Response, GrpcHandler[Request, Stream[Response, Grpc]]](f): + + import AllowUnsafe.embrace.danger + + override protected def send( + call: ServerCall[Request, Response], + handler: GrpcHandler[Request, Stream[Response, Grpc]], + promise: Promise[Request, Abort[Status]], + ready: SignalRef[Boolean] + ): Status < (Grpc & Emit[SafeMetadata]) = + def sendMessages(isFirst: AtomicRef[Boolean])(response: Response): Unit < Async = + // Send the first message whether the call is ready or not and let it buffer internally as a fast path + // under the assumption that the client will be ready for at least one response after the initial request. 
+ isFirst.getAndSet(false).flatMap: first => + if first || call.isReady then Sync.defer(call.sendMessage(response)) + else ready.next.andThen(sendMessages(isFirst)(response)) + + Abort.run[Status]: + for + request <- promise.get + responses <- handler(request) + isFirst <- AtomicRef.init(true) + _ <- responses.foreach(sendMessages(isFirst)) + yield Status.OK + .map(_.fold(identity, identity, e => throw e)) + end send + +end ServerStreamingServerCallHandler diff --git a/kyo-grpc-core/shared/src/main/scala/kyo/grpc/internal/UnaryClientCallListener.scala b/kyo-grpc-core/shared/src/main/scala/kyo/grpc/internal/UnaryClientCallListener.scala new file mode 100644 index 000000000..3b3d544f9 --- /dev/null +++ b/kyo-grpc-core/shared/src/main/scala/kyo/grpc/internal/UnaryClientCallListener.scala @@ -0,0 +1,35 @@ +package kyo.grpc.internal + +import io.grpc.* +import io.grpc.ClientCall.Listener +import kyo.* +import kyo.Channel +import kyo.grpc.CallClosed +import kyo.grpc.SafeMetadata + +private[grpc] class UnaryClientCallListener[Response]( + val headersPromise: Promise[SafeMetadata, Any], + val responsePromise: Promise[Response, Abort[StatusException]], + val completionPromise: Promise[CallClosed, Any], + val readySignal: SignalRef[Boolean] +) extends Listener[Response]: + + import AllowUnsafe.embrace.danger + + override def onHeaders(headers: Metadata): Unit = + headersPromise.unsafe.completeDiscard(Result.succeed(SafeMetadata.fromJava(headers))) + + override def onMessage(message: Response): Unit = + if !responsePromise.unsafe.complete(Result.succeed(message)) then + throw Status.INVALID_ARGUMENT.withDescription("Server sent more than one response.").asException() + end onMessage + + override def onClose(status: Status, trailers: Metadata): Unit = + responsePromise.unsafe.completeDiscard(Result.fail(status.asException(trailers))) + completionPromise.unsafe.completeDiscard(Result.succeed(CallClosed(status, SafeMetadata.fromJava(trailers)))) + + override def onReady(): Unit 
= + // May not be called if the method type is unary. + readySignal.unsafe.set(true) + +end UnaryClientCallListener diff --git a/kyo-grpc-core/shared/src/main/scala/kyo/grpc/internal/UnaryServerCallHandler.scala b/kyo-grpc-core/shared/src/main/scala/kyo/grpc/internal/UnaryServerCallHandler.scala new file mode 100644 index 000000000..bed10d516 --- /dev/null +++ b/kyo-grpc-core/shared/src/main/scala/kyo/grpc/internal/UnaryServerCallHandler.scala @@ -0,0 +1,26 @@ +package kyo.grpc.internal + +import io.grpc.* +import kyo.* +import kyo.grpc.* +import kyo.grpc.Grpc + +private[grpc] class UnaryServerCallHandler[Request, Response](f: GrpcHandlerInit[Request, Response])(using Frame) + extends BaseUnaryServerCallHandler[Request, Response, GrpcHandler[Request, Response]](f): + + override protected def send( + call: ServerCall[Request, Response], + handler: GrpcHandler[Request, Response], + promise: Promise[Request, Abort[Status]], + ready: SignalRef[Boolean] + ): Status < (Grpc & Emit[SafeMetadata]) = + Abort.run[Status]: + for + request <- promise.get + response <- handler(request) + _ <- Sync.defer(call.sendMessage(response)) + yield Status.OK + .map(_.fold(identity, identity, e => throw e)) + end send + +end UnaryServerCallHandler diff --git a/kyo-grpc-core/shared/src/test/scala/kyo/grpc/ArgEquals.scala b/kyo-grpc-core/shared/src/test/scala/kyo/grpc/ArgEquals.scala new file mode 100644 index 000000000..e01a2dc98 --- /dev/null +++ b/kyo-grpc-core/shared/src/test/scala/kyo/grpc/ArgEquals.scala @@ -0,0 +1,14 @@ +package kyo.grpc + +import org.scalactic.Equality +import org.scalamock.matchers.Matcher +import org.scalamock.matchers.MatcherBase +import scala.reflect.ClassTag + +class ArgEquals[T](arg: T, clue: Option[String])(using classTag: ClassTag[T], equality: Equality[T]) extends Matcher[T]: + override def toString: String = "argEquals[" + classTag.runtimeClass.getSimpleName + "]" + clue.map(c => s" - $c").getOrElse("") + override def safeEquals(obj: T): Boolean = 
equality.areEqual(arg, obj) + +def argEquals[T: {ClassTag, Equality}](clue: String)(arg: T): MatcherBase = ArgEquals(arg, Some(clue)) + +def argEquals[T: {ClassTag, Equality}](arg: T): MatcherBase = ArgEquals(arg, None) diff --git a/kyo-grpc-core/shared/src/test/scala/kyo/grpc/CallClosedTest.scala b/kyo-grpc-core/shared/src/test/scala/kyo/grpc/CallClosedTest.scala new file mode 100644 index 000000000..75d93ef9a --- /dev/null +++ b/kyo-grpc-core/shared/src/test/scala/kyo/grpc/CallClosedTest.scala @@ -0,0 +1,249 @@ +package kyo.grpc + +import io.grpc.* +import org.scalactic.TripleEquals.* + +class CallClosedTest extends Test: + + "construction" - { + "creates CallClosed with status and trailers" in { + val status = Status.OK + val trailers = SafeMetadata.empty + val result = CallClosed(status, trailers) + + assert(result.status === status) + assert(result.trailers === trailers) + } + + "creates CallClosed with non-OK status" in { + val status = Status.CANCELLED.withDescription("Request cancelled") + val trailers = SafeMetadata.empty + val result = CallClosed(status, trailers) + + assert(result.status.getCode === Status.Code.CANCELLED) + assert(result.status.getDescription === "Request cancelled") + assert(result.trailers === trailers) + } + + "creates CallClosed with metadata in trailers" in { + val status = Status.OK + val trailers = SafeMetadata.empty.add("test-key", "test-value") + val result = CallClosed(status, trailers) + + assert(result.status === status) + assert(result.trailers.getStrings("test-key") === Seq("test-value")) + } + } + + "equality" - { + "equals itself" in { + val status = Status.OK + val trailers = SafeMetadata.empty + val result = CallClosed(status, trailers) + + assert(result === result) + } + + "equals another CallClosed with same status and trailers" in { + val status = Status.OK + val trailers = SafeMetadata.empty + val result1 = CallClosed(status, trailers) + val result2 = CallClosed(status, trailers) + + assert(result1 === result2) + } 
+ + "not equals CallClosed with different status" in { + val trailers = SafeMetadata.empty + val result1 = CallClosed(Status.OK, trailers) + val result2 = CallClosed(Status.CANCELLED, trailers) + + assert(result1 !== result2) + } + + "not equals CallClosed with different trailers" in { + val status = Status.OK + val result1 = CallClosed(status, SafeMetadata.empty.add("key", "v1")) + val result2 = CallClosed(status, SafeMetadata.empty.add("key", "v2")) + + assert(result1 !== result2) + } + } + + "copy" - { + "creates copy with modified status" in { + val status1 = Status.OK + val status2 = Status.CANCELLED + val trailers = SafeMetadata.empty + val original = CallClosed(status1, trailers) + val copied = original.copy(status = status2) + + assert(copied.status === status2) + assert(copied.trailers === trailers) + assert(original.status === status1) + } + + "creates copy with modified trailers" in { + val status = Status.OK + val trailers1 = SafeMetadata.empty + val trailers2 = SafeMetadata.empty.add("test-key", "test-value") + val original = CallClosed(status, trailers1) + val copied = original.copy(trailers = trailers2) + + assert(copied.status === status) + assert(copied.trailers === trailers2) + assert(copied.trailers.getStrings("test-key") === Seq("test-value")) + assert(original.trailers === trailers1) + } + + "creates copy with all fields modified" in { + val status1 = Status.OK + val status2 = Status.CANCELLED + val trailers1 = SafeMetadata.empty + val trailers2 = SafeMetadata.empty.add("key", "val") + val original = CallClosed(status1, trailers1) + val copied = original.copy(status = status2, trailers = trailers2) + + assert(copied.status === status2) + assert(copied.trailers === trailers2) + assert(original.status === status1) + assert(original.trailers === trailers1) + } + } + + "field access" - { + "provides access to status field" in { + val status = Status.DEADLINE_EXCEEDED.withDescription("Timeout") + val trailers = SafeMetadata.empty + val result = 
CallClosed(status, trailers) + + assert(result.status.getCode === Status.Code.DEADLINE_EXCEEDED) + assert(result.status.getDescription === "Timeout") + } + + "provides access to trailers field" in { + val status = Status.OK + val trailers = SafeMetadata.empty.add("key1", "value1").add("key2", "value2") + val result = CallClosed(status, trailers) + + assert(result.trailers.getStrings("key1") === Seq("value1")) + assert(result.trailers.getStrings("key2") === Seq("value2")) + } + } + + "different status codes" - { + "works with OK status" in { + val result = CallClosed(Status.OK, SafeMetadata.empty) + assert(result.status.getCode === Status.Code.OK) + } + + "works with CANCELLED status" in { + val result = CallClosed(Status.CANCELLED, SafeMetadata.empty) + assert(result.status.getCode === Status.Code.CANCELLED) + } + + "works with UNKNOWN status" in { + val result = CallClosed(Status.UNKNOWN, SafeMetadata.empty) + assert(result.status.getCode === Status.Code.UNKNOWN) + } + + "works with INVALID_ARGUMENT status" in { + val result = CallClosed(Status.INVALID_ARGUMENT, SafeMetadata.empty) + assert(result.status.getCode === Status.Code.INVALID_ARGUMENT) + } + + "works with DEADLINE_EXCEEDED status" in { + val result = CallClosed(Status.DEADLINE_EXCEEDED, SafeMetadata.empty) + assert(result.status.getCode === Status.Code.DEADLINE_EXCEEDED) + } + + "works with NOT_FOUND status" in { + val result = CallClosed(Status.NOT_FOUND, SafeMetadata.empty) + assert(result.status.getCode === Status.Code.NOT_FOUND) + } + + "works with ALREADY_EXISTS status" in { + val result = CallClosed(Status.ALREADY_EXISTS, SafeMetadata.empty) + assert(result.status.getCode === Status.Code.ALREADY_EXISTS) + } + + "works with PERMISSION_DENIED status" in { + val result = CallClosed(Status.PERMISSION_DENIED, SafeMetadata.empty) + assert(result.status.getCode === Status.Code.PERMISSION_DENIED) + } + + "works with RESOURCE_EXHAUSTED status" in { + val result = CallClosed(Status.RESOURCE_EXHAUSTED, 
SafeMetadata.empty) + assert(result.status.getCode === Status.Code.RESOURCE_EXHAUSTED) + } + + "works with FAILED_PRECONDITION status" in { + val result = CallClosed(Status.FAILED_PRECONDITION, SafeMetadata.empty) + assert(result.status.getCode === Status.Code.FAILED_PRECONDITION) + } + + "works with ABORTED status" in { + val result = CallClosed(Status.ABORTED, SafeMetadata.empty) + assert(result.status.getCode === Status.Code.ABORTED) + } + + "works with OUT_OF_RANGE status" in { + val result = CallClosed(Status.OUT_OF_RANGE, SafeMetadata.empty) + assert(result.status.getCode === Status.Code.OUT_OF_RANGE) + } + + "works with UNIMPLEMENTED status" in { + val result = CallClosed(Status.UNIMPLEMENTED, SafeMetadata.empty) + assert(result.status.getCode === Status.Code.UNIMPLEMENTED) + } + + "works with INTERNAL status" in { + val result = CallClosed(Status.INTERNAL, SafeMetadata.empty) + assert(result.status.getCode === Status.Code.INTERNAL) + } + + "works with UNAVAILABLE status" in { + val result = CallClosed(Status.UNAVAILABLE, SafeMetadata.empty) + assert(result.status.getCode === Status.Code.UNAVAILABLE) + } + + "works with DATA_LOSS status" in { + val result = CallClosed(Status.DATA_LOSS, SafeMetadata.empty) + assert(result.status.getCode === Status.Code.DATA_LOSS) + } + + "works with UNAUTHENTICATED status" in { + val result = CallClosed(Status.UNAUTHENTICATED, SafeMetadata.empty) + assert(result.status.getCode === Status.Code.UNAUTHENTICATED) + } + } + + "status with cause and description" - { + "preserves status description" in { + val status = Status.INTERNAL.withDescription("Internal server error") + val trailers = SafeMetadata.empty + val result = CallClosed(status, trailers) + + assert(result.status.getDescription === "Internal server error") + } + + "preserves status cause" in { + val cause = new RuntimeException("Original error") + val status = Status.INTERNAL.withCause(cause) + val trailers = SafeMetadata.empty + val result = CallClosed(status, 
trailers) + + assert(result.status.getCause === cause) + } + + "preserves both description and cause" in { + val cause = new RuntimeException("Original error") + val status = Status.INTERNAL.withDescription("Internal server error").withCause(cause) + val trailers = SafeMetadata.empty + val result = CallClosed(status, trailers) + + assert(result.status.getDescription === "Internal server error") + assert(result.status.getCause === cause) + } + } +end CallClosedTest diff --git a/kyo-grpc-core/shared/src/test/scala/kyo/grpc/ClientCallTest.scala b/kyo-grpc-core/shared/src/test/scala/kyo/grpc/ClientCallTest.scala new file mode 100644 index 000000000..85dec49e1 --- /dev/null +++ b/kyo-grpc-core/shared/src/test/scala/kyo/grpc/ClientCallTest.scala @@ -0,0 +1,135 @@ +package kyo.grpc + +import io.grpc.MethodDescriptor +import kyo.* +import org.scalactic.TripleEquals.* + +class ClientCallTest extends Test: + + case class TestRequest(message: String) + case class TestResponse(result: String) + + "ClientCall" - { + + "method descriptors" - { + + "unary method descriptor" in run { + val method = MethodDescriptor.newBuilder[TestRequest, TestResponse]() + .setType(MethodDescriptor.MethodType.UNARY) + .setFullMethodName("test.Service/UnaryMethod") + .setRequestMarshaller(TestMarshaller[TestRequest]()) + .setResponseMarshaller(TestMarshaller[TestResponse]()) + .build() + + assert(method.getType.equals(MethodDescriptor.MethodType.UNARY)) + assert(method.getFullMethodName == "test.Service/UnaryMethod") + succeed + } + + "client streaming method descriptor" in run { + val method = MethodDescriptor.newBuilder[TestRequest, TestResponse]() + .setType(MethodDescriptor.MethodType.CLIENT_STREAMING) + .setFullMethodName("test.Service/ClientStreamingMethod") + .setRequestMarshaller(TestMarshaller[TestRequest]()) + .setResponseMarshaller(TestMarshaller[TestResponse]()) + .build() + + assert(method.getType.equals(MethodDescriptor.MethodType.CLIENT_STREAMING)) + assert(method.getFullMethodName 
== "test.Service/ClientStreamingMethod") + succeed + } + + "server streaming method descriptor" in run { + val method = MethodDescriptor.newBuilder[TestRequest, TestResponse]() + .setType(MethodDescriptor.MethodType.SERVER_STREAMING) + .setFullMethodName("test.Service/ServerStreamingMethod") + .setRequestMarshaller(TestMarshaller[TestRequest]()) + .setResponseMarshaller(TestMarshaller[TestResponse]()) + .build() + + assert(method.getType.equals(MethodDescriptor.MethodType.SERVER_STREAMING)) + assert(method.getFullMethodName == "test.Service/ServerStreamingMethod") + succeed + } + + "bidirectional streaming method descriptor" in run { + val method = MethodDescriptor.newBuilder[TestRequest, TestResponse]() + .setType(MethodDescriptor.MethodType.BIDI_STREAMING) + .setFullMethodName("test.Service/BidiStreamingMethod") + .setRequestMarshaller(TestMarshaller[TestRequest]()) + .setResponseMarshaller(TestMarshaller[TestResponse]()) + .build() + + assert(method.getType.equals(MethodDescriptor.MethodType.BIDI_STREAMING)) + assert(method.getFullMethodName == "test.Service/BidiStreamingMethod") + succeed + } + } + + "request options" - { + + "default options have empty headers" in run { + val options = RequestOptions() + assert(options.headers === SafeMetadata.empty) + assert(options.messageCompression.isEmpty) + assert(options.responseCapacity.isEmpty) + succeed + } + + "options with headers" in run { + val headers = SafeMetadata.empty.add("test-header", "test-value") + + val options = RequestOptions(headers = headers) + + assert(options.headers.getStrings("test-header") == Seq("test-value")) + succeed + } + + "options with message compression" in run { + val options = RequestOptions(messageCompression = Maybe(true)) + + assert(options.messageCompression.isDefined) + assert(options.messageCompression.get == true) + succeed + } + + "options with response capacity" in run { + val options = RequestOptions(responseCapacity = Maybe(100)) + + 
assert(options.responseCapacity.isDefined) + assert(options.responseCapacity.get == 100) + assert(options.responseCapacityOrDefault == 100) + succeed + } + + "responseCapacityOrDefault returns default when not set" in run { + val options = RequestOptions() + assert(options.responseCapacityOrDefault == RequestOptions.DefaultResponseCapacity) + succeed + } + + "combine merges options" in run { + val h1 = SafeMetadata.empty.add("header1", "value1") + val h2 = SafeMetadata.empty.add("header2", "value2") + + val options1 = RequestOptions( + headers = h1, + messageCompression = Maybe(true) + ) + + val options2 = RequestOptions( + headers = h2, + responseCapacity = Maybe(50) + ) + + options1.combine(options2).map: result => + assert(result.headers.getStrings("header1") == Seq("value1")) + assert(result.headers.getStrings("header2") == Seq("value2")) + assert(result.messageCompression == Maybe(true)) + assert(result.responseCapacity == Maybe(50)) + succeed + } + } + } + +end ClientCallTest diff --git a/kyo-grpc-core/shared/src/test/scala/kyo/grpc/ClientTest.scala b/kyo-grpc-core/shared/src/test/scala/kyo/grpc/ClientTest.scala new file mode 100644 index 000000000..08a4df4f4 --- /dev/null +++ b/kyo-grpc-core/shared/src/test/scala/kyo/grpc/ClientTest.scala @@ -0,0 +1,177 @@ +package kyo.grpc + +import io.grpc.ManagedChannel +import io.grpc.ManagedChannelBuilder +import io.grpc.ManagedChannelProvider +import io.grpc.ManagedChannelProvider.ProviderNotFoundException +import io.grpc.ManagedChannelRegistry +import io.grpc.Status +import io.grpc.stub.StreamObserver +import java.net.SocketAddress +import java.util.concurrent.TimeUnit +import kyo.* +import org.scalamock.scalatest.AsyncMockFactory + +class ClientTest extends Test with AsyncMockFactory: + + private val host = "localhost" + private val port = 50051 + + "shutdown shuts down the channel gracefully" in run { + val channel = mock[ManagedChannel] + + (() => channel.shutdown()) + .expects() + .returns(channel) + .once() + + 
channel.awaitTermination + .expects(30000000000L, TimeUnit.NANOSECONDS) + .returns(true) + .once() + + Client.shutdown(channel).map(_ => succeed) + } + + "shutdown shuts down the channel forcefully" in run { + val channel = mock[ManagedChannel] + + (() => channel.shutdown()) + .expects() + .returns(channel) + .once() + + channel.awaitTermination + .expects(30000000000L, TimeUnit.NANOSECONDS) + .returns(false) + .once() + + (() => channel.shutdownNow()) + .expects() + .returns(channel) + .once() + + channel.awaitTermination + .expects(1, TimeUnit.MINUTES) + .returns(true) + .once() + + Client.shutdown(channel).map(_ => succeed) + } + + "configures channel" in run { + val channel = mock[ManagedChannel] + (() => channel.shutdown()) + .expects() + .returns(channel) + .once() + + channel.awaitTermination + .expects(30000000000L, TimeUnit.NANOSECONDS) + .returns(true) + .once() + + val unconfiguredBuilder = mock[Builder] + + val configuredBuilder = mock[Builder] + (() => configuredBuilder.build()) + .expects() + .returns(channel) + .once() + + val provider = StubProvider(unconfiguredBuilder) + + var configured = false + + def configure( + actualBuilder: ManagedChannelBuilder[?] + ): ManagedChannelBuilder[?] = + configured = true + assert(actualBuilder eq unconfiguredBuilder) + configuredBuilder + end configure + + for + _ <- replaceProviders(provider) + _ <- Client.channel(host, port)(configure) + yield + assert(provider.builderAddressName == host) + assert(provider.builderAddressPort == port) + assert(configured) + end for + } + + "shuts down channel" in run { + val channel = mock[ManagedChannel] + + // Be careful here. Unexpected calls will fail when shutdown is called which gets swallowed by Scope and so + // the test will not fail. See https://github.com/ScalaMock/ScalaMock/issues/633. 
+ var shutdownCount = 0 + (() => channel.shutdown()) + .expects() + .onCall(() => + shutdownCount += 1 + channel + ) + .once() + + channel.awaitTermination + .expects(30000000000L, TimeUnit.NANOSECONDS) + .returns(true) + .once() + + val builder = mock[Builder] + (() => builder.build()) + .expects() + .returns(channel) + .once() + + val provider = StubProvider(builder) + + val result = Scope.run: + for + _ <- replaceProviders(provider) + _ <- Client.channel(host, port)(identity) + yield assert(shutdownCount == 0) + + result.map(_ => assert(shutdownCount == 1)) + } + + private def replaceProviders(provider: ManagedChannelProvider): Unit < Sync = + for + registry <- Sync.defer(ManagedChannelRegistry.getDefaultRegistry()) + _ <- removeProviders(registry) + _ <- Sync.defer(registry.register(provider)) + yield () + + private def removeProviders(registry: ManagedChannelRegistry): Unit < Sync = + Loop(registry): registry => + Abort.recover[ProviderNotFoundException](_ => Loop.done): + for + provider <- Abort.catching[ProviderNotFoundException](ManagedChannelProvider.provider()) + _ <- Sync.defer(registry.deregister(provider)) + yield Loop.continue(registry) + + abstract private class Builder extends ManagedChannelBuilder[Builder] + + private class StubProvider(builder: ManagedChannelBuilder[?]) extends ManagedChannelProvider: + var builderAddressName: String = "" + var builderAddressPort: Int = -1 + + override protected def isAvailable(): Boolean = true + + override protected def priority(): Int = 0 + + override protected def builderForAddress(name: String, port: Int): ManagedChannelBuilder[?] = + builderAddressName = name + builderAddressPort = port + builder + end builderForAddress + + override protected def builderForTarget(target: String): ManagedChannelBuilder[?] = builder + + override protected def getSupportedSocketAddressTypes: java.util.Collection[Class[? 
<: SocketAddress]] = + java.util.Collections.emptyList() + end StubProvider + +end ClientTest diff --git a/kyo-grpc-core/shared/src/test/scala/kyo/grpc/Equalities.scala b/kyo-grpc-core/shared/src/test/scala/kyo/grpc/Equalities.scala new file mode 100644 index 000000000..6fca1b079 --- /dev/null +++ b/kyo-grpc-core/shared/src/test/scala/kyo/grpc/Equalities.scala @@ -0,0 +1,44 @@ +package kyo.grpc + +import io.grpc.* +import kyo.Maybe +import org.scalactic.* +import org.scalactic.TripleEquals.* +import scala.reflect.ClassTag + +object Equalities: + + given statusEquality: Equality[Status] with + def areEqual(a: Status, b: Any): Boolean = + b match + case b: Status => a.getCode === b.getCode && a.getDescription === b.getDescription && a.getCause === b.getCause + case _ => false + end statusEquality + + given metadataEquality: Equality[Metadata] with + def areEqual(a: Metadata, b: Any): Boolean = + if Maybe(a).isEmpty then Maybe(b).isEmpty + else + b match + case b: Metadata => a.toString === b.toString + case _ => false + end areEqual + end metadataEquality + + given statusExceptionEquality: Equality[StatusException] with + def areEqual(a: StatusException, b: Any): Boolean = + b match + case b: StatusException => a.getStatus === b.getStatus && a.getTrailers === b.getTrailers + case _ => false + end areEqual + end statusExceptionEquality + + given tuple2Equality: [A: {Equality as eqA, ClassTag}, B: {Equality as eqB, ClassTag}] => Equality[(A, B)]: + def areEqual(a: (A, B), b: Any): Boolean = + b match + case (b1: A, b2: B) => eqA.areEqual(a._1, b1) && eqB.areEqual(a._2, b2) + case _ => false + end areEqual + end tuple2Equality + +end Equalities diff --git a/kyo-grpc-core/shared/src/test/scala/kyo/grpc/GrpcFailureTest.scala b/kyo-grpc-core/shared/src/test/scala/kyo/grpc/GrpcFailureTest.scala new file mode 100644 index 000000000..f93fa613a --- /dev/null +++ b/kyo-grpc-core/shared/src/test/scala/kyo/grpc/GrpcFailureTest.scala @@ -0,0 +1,102 @@ +package kyo.grpc + +import 
io.grpc.* +import kyo.* +import org.scalactic.TripleEquals.* + +class GrpcFailureTest extends Test: + + "fromThrowable" - { + "preserves existing StatusException" in { + val original = Status.ALREADY_EXISTS.withDescription("Already exists").asException() + val result = GrpcFailure.fromThrowable(original) + + assert(result eq original) + } + + "converts StatusRuntimeException to StatusException" in { + val status = Status.INVALID_ARGUMENT.withDescription("Invalid argument") + val metadata = new Metadata() + metadata.put(Metadata.Key.of("test-key", Metadata.ASCII_STRING_MARSHALLER), "test-value") + val original = status.asRuntimeException(metadata) + val result = GrpcFailure.fromThrowable(original) + + assert(result.getStatus === status) + assert(result.getTrailers === metadata) + assert(result.getStackTrace === original.getStackTrace) + } + + "converts other exceptions to StatusException with UNKNOWN status" in { + val original = new IllegalArgumentException("Invalid argument") + val result = GrpcFailure.fromThrowable(original) + + assert(result.getStatus.getCode === Status.Code.UNKNOWN) + assert(result.getStatus.getDescription === null) + assert(result.getStatus.getCause === original) + } + + "extracts nested StatusException from wrapped exception" in { + val innerStatusException = Status.PERMISSION_DENIED.withDescription("Access denied").asException() + val wrapperException = new RuntimeException("Wrapper", innerStatusException) + val result = GrpcFailure.fromThrowable(wrapperException) + + assert(result.getStatus.getCode === Status.Code.PERMISSION_DENIED) + assert(result.getStatus.getDescription === "Access denied") + } + + "extracts nested StatusRuntimeException from wrapped exception" in { + val status = Status.UNAVAILABLE.withDescription("Service unavailable") + val metadata = new Metadata() + val innerRuntimeException = status.asRuntimeException(metadata) + val wrapperException = new IllegalStateException("Wrapper", innerRuntimeException) + val result = 
GrpcFailure.fromThrowable(wrapperException) + + assert(result.getStatus.getCode === Status.Code.UNAVAILABLE) + assert(result.getStatus.getDescription === "Service unavailable") + } + + "handles deeply nested gRPC exceptions" in { + val innerStatusException = Status.NOT_FOUND.withDescription("Scope not found").asException() + val middleException = new IllegalArgumentException("Middle", innerStatusException) + val outerException = new RuntimeException("Outer", middleException) + val result = GrpcFailure.fromThrowable(outerException) + + assert(result.getStatus.getCode === Status.Code.NOT_FOUND) + assert(result.getStatus.getDescription === "Scope not found") + } + + "defaults to UNKNOWN when no gRPC exception in cause chain" in { + val innerException = new IllegalArgumentException("Inner") + val middleException = new RuntimeException("Middle", innerException) + val outerException = new IllegalStateException("Outer", middleException) + val result = GrpcFailure.fromThrowable(outerException) + + assert(result.getStatus.getCode === Status.Code.UNKNOWN) + assert(result.getStatus.getCause === outerException) + } + + "preserves original exception metadata for StatusRuntimeException" in { + val status = Status.DEADLINE_EXCEEDED.withDescription("Request timeout") + val metadata = new Metadata() + metadata.put(Metadata.Key.of("retry-info", Metadata.ASCII_STRING_MARSHALLER), "delay=1000ms") + metadata.put(Metadata.Key.of("request-id", Metadata.ASCII_STRING_MARSHALLER), "req-123") + val original = status.asRuntimeException(metadata) + val result = GrpcFailure.fromThrowable(original) + + assert(result.getStatus === status) + assert(result.getTrailers === metadata) + assert(result.getTrailers.get(Metadata.Key.of("retry-info", Metadata.ASCII_STRING_MARSHALLER)) === "delay=1000ms") + assert(result.getTrailers.get(Metadata.Key.of("request-id", Metadata.ASCII_STRING_MARSHALLER)) === "req-123") + } + } + + "type verification" - { + "GrpcFailure type alias" in { + val statusException: 
StatusException = Status.INTERNAL.asException() + val grpcFailure: GrpcFailure = statusException + val _: StatusException = grpcFailure + succeed + } + } + +end GrpcFailureTest diff --git a/kyo-grpc-core/shared/src/test/scala/kyo/grpc/GrpcTest.scala b/kyo-grpc-core/shared/src/test/scala/kyo/grpc/GrpcTest.scala new file mode 100644 index 000000000..dfd7c52e9 --- /dev/null +++ b/kyo-grpc-core/shared/src/test/scala/kyo/grpc/GrpcTest.scala @@ -0,0 +1,84 @@ +package kyo.grpc + +import io.grpc.Metadata +import io.grpc.Status +import io.grpc.StatusException +import kyo.* +import org.scalactic.TripleEquals.* +import scala.concurrent.Future +import scala.concurrent.Promise +import scala.util.Failure +import scala.util.Success + +class GrpcTest extends Test: + + "fromFuture" - { + "successful Future" in run { + val future = Future.successful("test result") + val grpcComputation = Grpc.fromFuture(future) + + Abort.run[GrpcFailure](grpcComputation).map: result => + assert(result === Result.succeed("test result")) + } + + "failed Future with StatusException" in run { + val statusException = Status.INVALID_ARGUMENT.withDescription("Invalid input").asException() + val future = Future.failed(statusException) + val grpcComputation = Grpc.fromFuture(future) + + Abort.run[GrpcFailure](grpcComputation).map: result => + assert(result.isFailure && (result.failure.get eq statusException)) + } + + "failed Future with StatusRuntimeException" in run { + val status = Status.UNAVAILABLE.withDescription("Service unavailable") + val metadata = new Metadata() + metadata.put(Metadata.Key.of("retry-after", Metadata.ASCII_STRING_MARSHALLER), "30") + val runtimeException = status.asRuntimeException(metadata) + val future = Future.failed(runtimeException) + val grpcComputation = Grpc.fromFuture(future) + + Abort.run[GrpcFailure](grpcComputation).map: result => + assert(result.isFailure) + val failure = result.failure.get + assert(failure.getStatus === status) + assert(failure.getTrailers === 
metadata) + assert(failure.getStackTrace sameElements runtimeException.getStackTrace) + } + + "failed Future with other exception" in run { + val originalException = new IllegalArgumentException("Custom error") + val future = Future.failed(originalException) + val grpcComputation = Grpc.fromFuture(future) + + Abort.run[GrpcFailure](grpcComputation).map: result => + assert(result.isFailure) + val failure = result.failure.get + assert(failure.getStatus.getCode === Status.Code.UNKNOWN) + assert(failure.getStatus.getCause === originalException) + } + + "failed Future with nested StatusException" in run { + val innerStatusException = Status.PERMISSION_DENIED.withDescription("Access denied").asException() + val wrapperException = new RuntimeException("Wrapper", innerStatusException) + val future = Future.failed(wrapperException) + val grpcComputation = Grpc.fromFuture(future) + + Abort.run[GrpcFailure](grpcComputation).map: result => + assert(result.isFailure) + val failure = result.failure.get + assert(failure.getStatus.getCode === Status.Code.PERMISSION_DENIED) + assert(failure.getStatus.getDescription === "Access denied") + } + } + + "type verification" - { + "Grpc type alias" in { + val grpcEffect: String < Grpc = "test" + val asyncEffect: String < (Async & Abort[GrpcFailure]) = grpcEffect + val _: String < Grpc = asyncEffect + succeed + } + } + +end GrpcTest diff --git a/kyo-grpc-core/shared/src/test/scala/kyo/grpc/RequestOptionsTest.scala b/kyo-grpc-core/shared/src/test/scala/kyo/grpc/RequestOptionsTest.scala new file mode 100644 index 000000000..da3c93679 --- /dev/null +++ b/kyo-grpc-core/shared/src/test/scala/kyo/grpc/RequestOptionsTest.scala @@ -0,0 +1,133 @@ +package kyo.grpc + +import kyo.* +import org.scalactic.TripleEquals.* + +class RequestOptionsTest extends Test: + + "constructor" - { + "creates instance with default empty values" in run { + val options = RequestOptions() + + assert(options.headers === SafeMetadata.empty) + 
assert(options.messageCompression === Maybe.empty) + assert(options.responseCapacity === Maybe.empty) + succeed + } + + "creates instance with specified values" in run { + val headers = SafeMetadata.empty.add("key", "value") + + val options = RequestOptions( + headers = headers, + messageCompression = Maybe.Present(true), + responseCapacity = Maybe.Present(16) + ) + + assert(options.headers === headers) + assert(options.messageCompression === Maybe.Present(true)) + assert(options.responseCapacity === Maybe.Present(16)) + succeed + } + } + + "responseCapacityOrDefault" - { + "returns specified capacity when present" in run { + val options = RequestOptions(responseCapacity = Maybe.Present(42)) + assert(options.responseCapacityOrDefault === 42) + succeed + } + + "returns default capacity when absent" in run { + val options = RequestOptions() + assert(options.responseCapacityOrDefault === RequestOptions.DefaultResponseCapacity) + succeed + } + } + + "combine" - { + "merges two empty options" in run { + val options1 = RequestOptions() + val options2 = RequestOptions() + + options1.combine(options2).map: result => + assert(result.headers === SafeMetadata.empty) + assert(result.messageCompression === Maybe.empty) + assert(result.responseCapacity === Maybe.empty) + succeed + } + + "prefers second options messageCompression when both present" in run { + val options1 = RequestOptions(messageCompression = Maybe.Present(true)) + val options2 = RequestOptions(messageCompression = Maybe.Present(false)) + + options1.combine(options2).map: result => + assert(result.messageCompression === Maybe.Present(false)) + succeed + } + + "keeps first options messageCompression when second is absent" in run { + val options1 = RequestOptions(messageCompression = Maybe.Present(true)) + val options2 = RequestOptions() + + options1.combine(options2).map: result => + assert(result.messageCompression === Maybe.Present(true)) + succeed + } + + "prefers second options responseCapacity when both 
present" in run { + val options1 = RequestOptions(responseCapacity = Maybe.Present(10)) + val options2 = RequestOptions(responseCapacity = Maybe.Present(20)) + + options1.combine(options2).map: result => + assert(result.responseCapacity === Maybe.Present(20)) + succeed + } + + "merges headers from both options" in run { + val h1 = SafeMetadata.empty.add("key1", "value1") + val h2 = SafeMetadata.empty.add("key2", "value2") + + val options1 = RequestOptions(headers = h1) + val options2 = RequestOptions(headers = h2) + + options1.combine(options2).map: result => + assert(result.headers.getStrings("key1") === Seq("value1")) + assert(result.headers.getStrings("key2") === Seq("value2")) + succeed + } + } + + "run" - { + "extracts single emitted option" in run { + val options = RequestOptions(messageCompression = Maybe.Present(true)) + val computation = Emit.value(options) + + RequestOptions.run(computation).map: (result, _) => + assert(result.messageCompression === Maybe.Present(true)) + succeed + } + + "handles computation with no emissions" in run { + val computation = "result" + + RequestOptions.run(computation).map: (result, value) => + assert(result === RequestOptions()) + assert(value === "result") + succeed + } + } + + "constants" - { + "DefaultRequestBuffer is 8" in { + assert(RequestOptions.DefaultRequestBuffer === 8) + succeed + } + + "DefaultResponseCapacity is 8" in { + assert(RequestOptions.DefaultResponseCapacity === 8) + succeed + } + } + +end RequestOptionsTest diff --git a/kyo-grpc-core/shared/src/test/scala/kyo/grpc/ResponseOptionsTest.scala b/kyo-grpc-core/shared/src/test/scala/kyo/grpc/ResponseOptionsTest.scala new file mode 100644 index 000000000..fed677d0e --- /dev/null +++ b/kyo-grpc-core/shared/src/test/scala/kyo/grpc/ResponseOptionsTest.scala @@ -0,0 +1,129 @@ +package kyo.grpc + +import io.grpc.ServerCall +import io.grpc.Status +import kyo.* +import org.scalactic.TripleEquals.* + +class ResponseOptionsTest extends Test: + + "constructor" - { + 
"creates instance with default empty values" in run { + val options = ResponseOptions() + + assert(options.headers === SafeMetadata.empty) + assert(options.messageCompression === Maybe.empty) + assert(options.compression === Maybe.empty) + assert(options.onReadyThreshold === Maybe.empty) + assert(options.requestBuffer === Maybe.empty) + succeed + } + + "creates instance with specified values" in run { + val headers = SafeMetadata.empty.add("key", "value") + + val options = ResponseOptions( + headers = headers, + messageCompression = Maybe.Present(true), + compression = Maybe.Present("gzip"), + onReadyThreshold = Maybe.Present(32), + requestBuffer = Maybe.Present(16) + ) + + assert(options.headers === headers) + assert(options.messageCompression === Maybe.Present(true)) + assert(options.compression === Maybe.Present("gzip")) + assert(options.onReadyThreshold === Maybe.Present(32)) + assert(options.requestBuffer === Maybe.Present(16)) + succeed + } + } + + "requestBufferOrDefault" - { + "returns specified buffer when present" in run { + val options = ResponseOptions(requestBuffer = Maybe.Present(42)) + assert(options.requestBufferOrDefault === 42) + succeed + } + + "returns default buffer when absent" in run { + val options = ResponseOptions() + assert(options.requestBufferOrDefault === ResponseOptions.DefaultRequestBuffer) + succeed + } + } + + "combine" - { + "merges two empty options" in run { + val options1 = ResponseOptions() + val options2 = ResponseOptions() + + options1.combine(options2).map: result => + assert(result.headers === SafeMetadata.empty) + assert(result.messageCompression === Maybe.empty) + assert(result.compression === Maybe.empty) + assert(result.onReadyThreshold === Maybe.empty) + assert(result.requestBuffer === Maybe.empty) + succeed + } + + "prefers second options messageCompression when both present" in run { + val options1 = ResponseOptions(messageCompression = Maybe.Present(true)) + val options2 = ResponseOptions(messageCompression = 
Maybe.Present(false)) + + options1.combine(options2).map: result => + assert(result.messageCompression === Maybe.Present(false)) + succeed + } + + "prefers second options compression when both present" in run { + val options1 = ResponseOptions(compression = Maybe.Present("gzip")) + val options2 = ResponseOptions(compression = Maybe.Present("snappy")) + + options1.combine(options2).map: result => + assert(result.compression === Maybe.Present("snappy")) + succeed + } + + "merges headers from both options" in run { + val h1 = SafeMetadata.empty.add("key1", "value1") + val h2 = SafeMetadata.empty.add("key2", "value2") + + val options1 = ResponseOptions(headers = h1) + val options2 = ResponseOptions(headers = h2) + + options1.combine(options2).map: result => + assert(result.headers.getStrings("key1") === Seq("value1")) + assert(result.headers.getStrings("key2") === Seq("value2")) + succeed + } + } + + "run" - { + "extracts single emitted option" in run { + val options = ResponseOptions(messageCompression = Maybe.Present(true)) + val computation = Emit.value(options) + + ResponseOptions.run(computation).map: (result, _) => + assert(result.messageCompression === Maybe.Present(true)) + succeed + } + + "handles computation with no emissions" in run { + val computation = "result" + + ResponseOptions.run(computation).map: (result, value) => + assert(result === ResponseOptions()) + assert(value === "result") + succeed + } + } + + "constants" - { + "DefaultRequestBuffer is 8" in { + assert(ResponseOptions.DefaultRequestBuffer === 8) + succeed + } + } + +end ResponseOptionsTest diff --git a/kyo-grpc-core/shared/src/test/scala/kyo/grpc/Test.scala b/kyo-grpc-core/shared/src/test/scala/kyo/grpc/Test.scala new file mode 100644 index 000000000..21f40de67 --- /dev/null +++ b/kyo-grpc-core/shared/src/test/scala/kyo/grpc/Test.scala @@ -0,0 +1,18 @@ +package kyo.grpc + +import kyo.internal.BaseKyoCoreTest +import kyo.internal.Platform +import org.scalactic.TripleEquals +import 
org.scalatest.NonImplicitAssertions +import org.scalatest.freespec.AsyncFreeSpec +import scala.concurrent.ExecutionContext +import scala.language.implicitConversions + +abstract class Test extends AsyncFreeSpec with NonImplicitAssertions with TripleEquals with BaseKyoCoreTest: + + type Assertion = org.scalatest.compatible.Assertion + def assertionSuccess = succeed + def assertionFailure(msg: String) = fail(msg) + + override given executionContext: ExecutionContext = Platform.executionContext +end Test diff --git a/kyo-grpc-core/shared/src/test/scala/kyo/grpc/TestMarshaller.scala b/kyo-grpc-core/shared/src/test/scala/kyo/grpc/TestMarshaller.scala new file mode 100644 index 000000000..a278e1250 --- /dev/null +++ b/kyo-grpc-core/shared/src/test/scala/kyo/grpc/TestMarshaller.scala @@ -0,0 +1,11 @@ +package kyo.grpc + +import io.grpc.MethodDescriptor.Marshaller +import java.io.* + +class TestMarshaller[T] extends Marshaller[T]: + override def stream(value: T): InputStream = + ByteArrayInputStream(value.toString.getBytes) + override def parse(stream: InputStream): T = + null.asInstanceOf[T] +end TestMarshaller diff --git a/kyo-grpc-core/shared/src/test/scala/kyo/grpc/internal/BidiStreamingServerCallHandlerTest.scala b/kyo-grpc-core/shared/src/test/scala/kyo/grpc/internal/BidiStreamingServerCallHandlerTest.scala new file mode 100644 index 000000000..79b8dba49 --- /dev/null +++ b/kyo-grpc-core/shared/src/test/scala/kyo/grpc/internal/BidiStreamingServerCallHandlerTest.scala @@ -0,0 +1,548 @@ +package kyo.grpc.internal + +import io.grpc.{Grpc as _, *} +import java.util.concurrent.atomic.AtomicBoolean as JAtomicBoolean +import java.util.concurrent.atomic.AtomicInteger +import kyo.* +import kyo.grpc.* +import kyo.grpc.Equalities.given +import org.scalamock.scalatest.AsyncMockFactory +import org.scalamock.stubs.Stubs +import org.scalatest.concurrent.Eventually +import org.scalatest.matchers.must.Matchers.* +import org.scalatest.time.Seconds +import org.scalatest.time.Span + 
+class BidiStreamingServerCallHandlerTest extends Test with Stubs with Eventually: + + case class TestRequest(message: String) + case class TestResponse(result: String) + + implicit override def patienceConfig: PatienceConfig = super.patienceConfig.copy(timeout = scaled(Span(5, Seconds))) + + "BidiStreamingServerCallHandler" - { + + "startup" - { + "requests one message from client initially" in run { + val handler: GrpcHandler[Stream[TestRequest, Grpc], Stream[TestResponse, Grpc]] = + requests => requests.map(req => TestResponse(s"echo: ${req.message}")) + + val callHandler = BidiStreamingServerCallHandler(handler) + + val call = stub[ServerCall[TestRequest, TestResponse]] + call.request.returnsWith(()) + call.sendHeaders.returnsWith(()) + + val requestHeaders = Metadata() + + val listener = callHandler.startCall(call, requestHeaders) + + // Requests 1 initially, then (bufferSize - 1) = 7 more to fill the buffer + assert(call.request.calls === List(1, 7)) + } + + "set options and sends headers" in run { + val handler: GrpcHandler[Stream[TestRequest, Grpc], Stream[TestResponse, Grpc]] = + requests => requests.map(req => TestResponse(s"echo: ${req.message}")) + + val requestHeaders = Metadata() + + val responseHeaders = Metadata() + responseHeaders.put(Metadata.Key.of("custom-header", Metadata.ASCII_STRING_MARSHALLER), "custom-value") + + val responseOptions = ResponseOptions( + headers = SafeMetadata.fromJava(responseHeaders), + messageCompression = Maybe.Present(true), + compression = Maybe.Present("gzip"), + onReadyThreshold = Maybe.Present(16), + requestBuffer = Maybe.Present(4) + ) + + val init: GrpcHandlerInit[Stream[TestRequest, Grpc], Stream[TestResponse, Grpc]] = + for + actualRequestHeaders <- Env.get[SafeMetadata] + _ <- Emit.value(responseOptions) + yield + assert(actualRequestHeaders === SafeMetadata.fromJava(requestHeaders)) + handler + + val callHandler = BidiStreamingServerCallHandler(init) + + val call = stub[ServerCall[TestRequest, TestResponse]] + 
call.request.returnsWith(()) + call.setMessageCompression.returnsWith(()) + call.setCompression.returnsWith(()) + call.setOnReadyThreshold.returnsWith(()) + call.sendHeaders.returnsWith(()) + + val listener = callHandler.startCall(call, requestHeaders) + + assert(call.setMessageCompression.calls === responseOptions.messageCompression.toList) + assert(call.setCompression.calls === responseOptions.compression.toList) + assert(call.setOnReadyThreshold.calls === responseOptions.onReadyThreshold.toList) + assert(call.sendHeaders.calls.map(_.toString) === List(responseOptions.headers.toJava.toString)) + } + + "requests additional messages based on buffer size" in run { + val handler: GrpcHandler[Stream[TestRequest, Grpc], Stream[TestResponse, Grpc]] = + requests => requests.map(req => TestResponse(s"echo: ${req.message}")) + + val requestHeaders = Metadata() + + val responseOptions = ResponseOptions( + requestBuffer = Maybe.Present(5) + ) + + val init: GrpcHandlerInit[Stream[TestRequest, Grpc], Stream[TestResponse, Grpc]] = + for + _ <- Emit.value(responseOptions) + yield handler + + val callHandler = BidiStreamingServerCallHandler(init) + + val call = stub[ServerCall[TestRequest, TestResponse]] + call.request.returnsWith(()) + call.sendHeaders.returnsWith(()) + + val listener = callHandler.startCall(call, requestHeaders) + + // Should request 1 initially, then 4 more to fill buffer + assert(call.request.calls === List(1, 4)) + } + } + + "success" - { + "echoes multiple request-response pairs" in run { + import org.scalactic.TraversableEqualityConstraints.* + + val requests = List( + TestRequest("msg1"), + TestRequest("msg2"), + TestRequest("msg3") + ) + + val handler: GrpcHandler[Stream[TestRequest, Grpc], Stream[TestResponse, Grpc]] = + stream => stream.map(req => TestResponse(s"echo: ${req.message}")) + + val callHandler = BidiStreamingServerCallHandler(handler) + + val call = stub[ServerCall[TestRequest, TestResponse]] + call.request.returnsWith(()) + 
call.sendHeaders.returnsWith(()) + (() => call.isReady()).returnsWith(true) + call.sendMessage.returnsWith(()) + call.close.returnsWith(()) + + val requestHeaders = Metadata() + + val listener = callHandler.startCall(call, requestHeaders) + + // Send multiple messages + requests.foreach(listener.onMessage) + listener.onHalfClose() + + eventually { + assert(call.sendMessage.times === 3) + assert(call.close.times === 1) + } + + val expectedResponses = requests.map(req => TestResponse(s"echo: ${req.message}")) + assert(call.sendMessage.calls === expectedResponses) + call.close.calls must contain theSameElementsInOrderAs List((Status.OK, Metadata())) + } + + "processes stream with transformations" in run { + val requests = List( + TestRequest("1"), + TestRequest("2"), + TestRequest("3") + ) + + val handler: GrpcHandler[Stream[TestRequest, Grpc], Stream[TestResponse, Grpc]] = + stream => + stream + .map(req => req.message.toInt) + .map(num => num * 2) + .map(num => TestResponse(num.toString)) + + val callHandler = BidiStreamingServerCallHandler(handler) + + val call = stub[ServerCall[TestRequest, TestResponse]] + call.request.returnsWith(()) + call.sendHeaders.returnsWith(()) + (() => call.isReady()).returnsWith(true) + call.sendMessage.returnsWith(()) + call.close.returnsWith(()) + + val requestHeaders = Metadata() + + val listener = callHandler.startCall(call, requestHeaders) + + requests.foreach(listener.onMessage) + listener.onHalfClose() + + eventually { + assert(call.sendMessage.times === 3) + } + + val expectedResponses = List( + TestResponse("2"), + TestResponse("4"), + TestResponse("6") + ) + assert(call.sendMessage.calls === expectedResponses) + } + + "handles filter operations" in run { + val requests = List( + TestRequest("keep1"), + TestRequest("filter"), + TestRequest("keep2"), + TestRequest("filter"), + TestRequest("keep3") + ) + + val handler: GrpcHandler[Stream[TestRequest, Grpc], Stream[TestResponse, Grpc]] = + stream => + stream + .filter(req => 
req.message.startsWith("keep")) + .map(req => TestResponse(s"kept: ${req.message}")) + + val callHandler = BidiStreamingServerCallHandler(handler) + + val call = stub[ServerCall[TestRequest, TestResponse]] + call.request.returnsWith(()) + call.sendHeaders.returnsWith(()) + (() => call.isReady()).returnsWith(true) + call.sendMessage.returnsWith(()) + call.close.returnsWith(()) + + val requestHeaders = Metadata() + + val listener = callHandler.startCall(call, requestHeaders) + + requests.foreach(listener.onMessage) + listener.onHalfClose() + + eventually { + assert(call.sendMessage.times === 3) + } + + val expectedResponses = List( + TestResponse("kept: keep1"), + TestResponse("kept: keep2"), + TestResponse("kept: keep3") + ) + assert(call.sendMessage.calls === expectedResponses) + } + + "processes requests incrementally" in run { + val requests = List( + TestRequest("msg1"), + TestRequest("msg2"), + TestRequest("msg3") + ) + + val processedMessages = new AtomicInteger(0) + + val handler: GrpcHandler[Stream[TestRequest, Grpc], Stream[TestResponse, Grpc]] = + stream => + stream + .tap(_ => Sync.defer(processedMessages.incrementAndGet())) + .map(req => TestResponse(s"processed: ${req.message}")) + + val callHandler = BidiStreamingServerCallHandler(handler) + + val call = stub[ServerCall[TestRequest, TestResponse]] + call.request.returnsWith(()) + call.sendHeaders.returnsWith(()) + (() => call.isReady()).returnsWith(true) + call.sendMessage.returnsWith(()) + call.close.returnsWith(()) + + val requestHeaders = Metadata() + + val listener = callHandler.startCall(call, requestHeaders) + + requests.foreach(listener.onMessage) + listener.onHalfClose() + + eventually { + assert(call.sendMessage.times === 3) + assert(processedMessages.get === 3) + } + } + + "requests more messages after processing chunks" in run { + val handler: GrpcHandler[Stream[TestRequest, Grpc], Stream[TestResponse, Grpc]] = + stream => stream.map(req => TestResponse(s"echo: ${req.message}")) + + val 
callHandler = BidiStreamingServerCallHandler(handler) + + val call = stub[ServerCall[TestRequest, TestResponse]] + call.request.returnsWith(()) + call.sendHeaders.returnsWith(()) + call.sendMessage.returnsWith(()) + call.close.returnsWith(()) + + val requestHeaders = Metadata() + + val listener = callHandler.startCall(call, requestHeaders) + + // Send multiple messages in sequence + listener.onMessage(TestRequest("msg1")) + listener.onMessage(TestRequest("msg2")) + + eventually { + // Should request more as chunks are processed + assert(call.request.times >= 2) + } + + listener.onHalfClose() + + eventually { + assert(call.close.times === 1) + } + } + + "closes with trailers from handler" in run { + val requests = List(TestRequest("test")) + val responseTrailers = Metadata() + responseTrailers.put(Metadata.Key.of("custom-header", Metadata.ASCII_STRING_MARSHALLER), "custom-value") + + val handler: GrpcHandler[Stream[TestRequest, Grpc], Stream[TestResponse, Grpc]] = + stream => + for + _ <- Emit.value(SafeMetadata.fromJava(responseTrailers)) + yield stream.map(req => TestResponse(s"echo: ${req.message}")) + + val callHandler = BidiStreamingServerCallHandler(handler) + val call = stub[ServerCall[TestRequest, TestResponse]] + val requestHeaders = Metadata() + + call.request.returnsWith(()) + call.sendHeaders.returnsWith(()) + call.sendMessage.returnsWith(()) + call.close.returnsWith(()) + + val listener = callHandler.startCall(call, requestHeaders) + + requests.foreach(listener.onMessage) + listener.onHalfClose() + + eventually { + assert(call.close.times === 1) + } + + call.close.calls must contain theSameElementsInOrderAs List((Status.OK, responseTrailers)) + } + + "handles empty request stream" in run { + val handler: GrpcHandler[Stream[TestRequest, Grpc], Stream[TestResponse, Grpc]] = + stream => stream.map(req => TestResponse(s"echo: ${req.message}")) + + val callHandler = BidiStreamingServerCallHandler(handler) + val call = stub[ServerCall[TestRequest, 
TestResponse]] + val requestHeaders = Metadata() + + call.request.returnsWith(()) + call.sendHeaders.returnsWith(()) + call.sendMessage.returnsWith(()) + call.close.returnsWith(()) + + val listener = callHandler.startCall(call, requestHeaders) + + // Close without sending any messages + listener.onHalfClose() + + eventually { + assert(call.sendMessage.times === 0) + assert(call.close.times === 1) + } + + call.close.calls must contain theSameElementsInOrderAs List((Status.OK, Metadata())) + } + + "handles empty response stream" in run { + val requests = List(TestRequest("test")) + + val handler: GrpcHandler[Stream[TestRequest, Grpc], Stream[TestResponse, Grpc]] = + stream => stream.filter(_ => false).map(req => TestResponse("never")) + + val callHandler = BidiStreamingServerCallHandler(handler) + val call = stub[ServerCall[TestRequest, TestResponse]] + val requestHeaders = Metadata() + + call.request.returnsWith(()) + call.sendHeaders.returnsWith(()) + call.sendMessage.returnsWith(()) + call.close.returnsWith(()) + + val listener = callHandler.startCall(call, requestHeaders) + + requests.foreach(listener.onMessage) + listener.onHalfClose() + + eventually { + assert(call.sendMessage.times === 0) + assert(call.close.times === 1) + } + + call.close.calls must contain theSameElementsInOrderAs List((Status.OK, Metadata())) + } + } + + "errors" - { + "handles abort failure correctly" in run { + val status = Status.INVALID_ARGUMENT.withDescription("Bad request") + + val handler: GrpcHandler[Stream[TestRequest, Grpc], Stream[TestResponse, Grpc]] = + stream => Abort.fail(status.asException()) + + val callHandler = BidiStreamingServerCallHandler(handler) + val call = stub[ServerCall[TestRequest, TestResponse]] + val requestHeaders = Metadata() + + call.request.returnsWith(()) + call.sendHeaders.returnsWith(()) + call.close.returnsWith(()) + + val listener = callHandler.startCall(call, requestHeaders) + + listener.onMessage(TestRequest("test")) + listener.onHalfClose() + + 
eventually { + assert(call.close.times === 1) + } + + call.close.calls must contain theSameElementsInOrderAs List((status, Metadata())) + } + + "handles panic correctly" in run { + val cause = Exception("Something went wrong") + + val handler: GrpcHandler[Stream[TestRequest, Grpc], Stream[TestResponse, Grpc]] = + stream => Abort.panic(cause) + + val callHandler = BidiStreamingServerCallHandler(handler) + val call = stub[ServerCall[TestRequest, TestResponse]] + val requestHeaders = Metadata() + + call.request.returnsWith(()) + call.sendHeaders.returnsWith(()) + call.close.returnsWith(()) + + val listener = callHandler.startCall(call, requestHeaders) + + listener.onMessage(TestRequest("test")) + listener.onHalfClose() + + eventually { + assert(call.close.times === 1) + } + + val status = Status.UNKNOWN.withCause(cause) + call.close.calls must contain theSameElementsInOrderAs List((status, Metadata())) + } + + "handles error during stream processing" in run { + val handler: GrpcHandler[Stream[TestRequest, Grpc], Stream[TestResponse, Grpc]] = + stream => + stream.map(req => + if req.message == "error" then + throw Exception("Processing error") + else + TestResponse(s"echo: ${req.message}") + ) + + val callHandler = BidiStreamingServerCallHandler(handler) + val call = stub[ServerCall[TestRequest, TestResponse]] + val requestHeaders = Metadata() + + call.request.returnsWith(()) + call.sendHeaders.returnsWith(()) + call.sendMessage.returnsWith(()) + call.close.returnsWith(()) + + val listener = callHandler.startCall(call, requestHeaders) + + listener.onMessage(TestRequest("ok")) + listener.onMessage(TestRequest("error")) + listener.onHalfClose() + + eventually { + assert(call.close.times === 1) + } + + val (status, _) = call.close.calls.head + assert(status.getCode === Status.Code.UNKNOWN) + } + } + + "cancellation" - { + "interrupts stream processing" in run { + val handler: GrpcHandler[Stream[TestRequest, Grpc], Stream[TestResponse, Grpc]] = + stream => stream.map(req 
=> TestResponse(s"echo: ${req.message}")) + + val callHandler = BidiStreamingServerCallHandler(handler) + val call = stub[ServerCall[TestRequest, TestResponse]] + val requestHeaders = Metadata() + + call.request.returnsWith(()) + call.sendHeaders.returnsWith(()) + call.close.returnsWith(()) + + val interrupted = new JAtomicBoolean(false) + call.sendMessage.returnsWith { + try + Thread.sleep(patienceConfig.timeout.toMillis + 5000) + catch + case e: InterruptedException => + interrupted.set(true) + throw e + } + + val listener = callHandler.startCall(call, requestHeaders) + + listener.onMessage(TestRequest("test")) + listener.onHalfClose() + listener.onCancel() + + eventually { + assert(call.close.times === 1) + } + + val status = Status.CANCELLED.withDescription("Call was cancelled.") + call.close.calls must contain theSameElementsInOrderAs List((status, Metadata())) + } + } + + // TODO: Re-enable these tests after implementing Signal-based flow control + // "flow control" - { + // ... + // } + + "lifecycle" - { + "onComplete does nothing" in run { + val handler: GrpcHandler[Stream[TestRequest, Grpc], Stream[TestResponse, Grpc]] = + stream => stream.map(req => TestResponse(s"echo: ${req.message}")) + + val callHandler = BidiStreamingServerCallHandler(handler) + val call = stub[ServerCall[TestRequest, TestResponse]] + val requestHeaders = Metadata() + + call.request.returnsWith(()) + call.sendHeaders.returnsWith(()) + + val listener = callHandler.startCall(call, requestHeaders) + + // Call onComplete - should not throw + listener.onComplete() + + assert(call.request.times === 2) // Initial 1 + buffer fill + } + } + } + +end BidiStreamingServerCallHandlerTest diff --git a/kyo-grpc-core/shared/src/test/scala/kyo/grpc/internal/ClientStreamingServerCallHandlerTest.scala b/kyo-grpc-core/shared/src/test/scala/kyo/grpc/internal/ClientStreamingServerCallHandlerTest.scala new file mode 100644 index 000000000..e19d4306f --- /dev/null +++ 
b/kyo-grpc-core/shared/src/test/scala/kyo/grpc/internal/ClientStreamingServerCallHandlerTest.scala @@ -0,0 +1,454 @@ +package kyo.grpc.internal + +import io.grpc.{Grpc as _, *} +import java.util.concurrent.atomic.AtomicBoolean as JAtomicBoolean +import java.util.concurrent.atomic.AtomicInteger +import kyo.* +import kyo.grpc.* +import kyo.grpc.Equalities.given +import org.scalamock.scalatest.AsyncMockFactory +import org.scalamock.stubs.Stubs +import org.scalatest.concurrent.Eventually +import org.scalatest.matchers.must.Matchers.* +import org.scalatest.time.Seconds +import org.scalatest.time.Span + +class ClientStreamingServerCallHandlerTest extends Test with Stubs with Eventually: + + case class TestRequest(message: String) + case class TestResponse(result: String) + + implicit override def patienceConfig: PatienceConfig = super.patienceConfig.copy(timeout = scaled(Span(5, Seconds))) + + "ClientStreamingServerCallHandler" - { + + "startup" - { + "requests one message from client initially" in run { + val handler: GrpcHandler[Stream[TestRequest, Grpc], TestResponse] = + requests => requests.run.map(_ => TestResponse("response")) + + val callHandler = ClientStreamingServerCallHandler(handler) + + val call = stub[ServerCall[TestRequest, TestResponse]] + call.request.returnsWith(()) + call.sendHeaders.returnsWith(()) + + val requestHeaders = Metadata() + + val listener = callHandler.startCall(call, requestHeaders) + + // Requests 1 initially, then (bufferSize - 1) = 7 more to fill the buffer + assert(call.request.calls === List(1, 7)) + } + + "set options and sends headers" in run { + val handler: GrpcHandler[Stream[TestRequest, Grpc], TestResponse] = + requests => requests.run.map(_ => TestResponse("response")) + + val requestHeaders = Metadata() + + val responseHeaders = Metadata() + responseHeaders.put(Metadata.Key.of("custom-header", Metadata.ASCII_STRING_MARSHALLER), "custom-value") + + val responseOptions = ResponseOptions( + headers = 
SafeMetadata.fromJava(responseHeaders), + messageCompression = Maybe.Present(true), + compression = Maybe.Present("gzip"), + onReadyThreshold = Maybe.Present(16), + requestBuffer = Maybe.Present(4) + ) + + val init: GrpcHandlerInit[Stream[TestRequest, Grpc], TestResponse] = + for + actualRequestHeaders <- Env.get[SafeMetadata] + _ <- Emit.value(responseOptions) + yield + assert(actualRequestHeaders === SafeMetadata.fromJava(requestHeaders)) + handler + + val callHandler = ClientStreamingServerCallHandler(init) + + val call = stub[ServerCall[TestRequest, TestResponse]] + call.request.returnsWith(()) + call.setMessageCompression.returnsWith(()) + call.setCompression.returnsWith(()) + call.setOnReadyThreshold.returnsWith(()) + call.sendHeaders.returnsWith(()) + + val listener = callHandler.startCall(call, requestHeaders) + + assert(call.setMessageCompression.calls === responseOptions.messageCompression.toList) + assert(call.setCompression.calls === responseOptions.compression.toList) + assert(call.setOnReadyThreshold.calls === responseOptions.onReadyThreshold.toList) + assert(call.sendHeaders.calls.map(_.toString) === List(responseOptions.headers.toJava.toString)) + } + + "requests additional messages based on buffer size" in run { + val handler: GrpcHandler[Stream[TestRequest, Grpc], TestResponse] = + requests => requests.run.map(_ => TestResponse("response")) + + val requestHeaders = Metadata() + + val responseOptions = ResponseOptions( + requestBuffer = Maybe.Present(5) + ) + + val init: GrpcHandlerInit[Stream[TestRequest, Grpc], TestResponse] = + for + _ <- Emit.value(responseOptions) + yield handler + + val callHandler = ClientStreamingServerCallHandler(init) + + val call = stub[ServerCall[TestRequest, TestResponse]] + call.request.returnsWith(()) + call.sendHeaders.returnsWith(()) + + val listener = callHandler.startCall(call, requestHeaders) + + // Should request 1 initially, then 4 more to fill buffer + assert(call.request.calls === List(1, 4)) + } + } + + 
"success" - { + "receives multiple messages and sends single response" in run { + import org.scalactic.TraversableEqualityConstraints.* + + val requests = List( + TestRequest("msg1"), + TestRequest("msg2"), + TestRequest("msg3") + ) + + val handler: GrpcHandler[Stream[TestRequest, Grpc], TestResponse] = + stream => + stream.run.map(chunk => + TestResponse(s"received ${chunk.size} messages") + ) + + val callHandler = ClientStreamingServerCallHandler(handler) + + val call = stub[ServerCall[TestRequest, TestResponse]] + call.request.returnsWith(()) + call.sendHeaders.returnsWith(()) + call.sendMessage.returnsWith(()) + call.close.returnsWith(()) + + val requestHeaders = Metadata() + + val listener = callHandler.startCall(call, requestHeaders) + + // Send multiple messages + requests.foreach(listener.onMessage) + listener.onHalfClose() + + eventually { + assert(call.sendMessage.times === 1) + assert(call.close.times === 1) + } + + assert(call.sendMessage.calls.head.result === "received 3 messages") + call.close.calls must contain theSameElementsInOrderAs List((Status.OK, Metadata())) + } + + "processes stream incrementally" in run { + val requests = List( + TestRequest("msg1"), + TestRequest("msg2"), + TestRequest("msg3") + ) + + val processedMessages = new AtomicInteger(0) + + val handler: GrpcHandler[Stream[TestRequest, Grpc], TestResponse] = + stream => + stream + .tap(_ => Sync.defer(processedMessages.incrementAndGet())) + .run + .map(_ => TestResponse(s"processed ${processedMessages.get} messages")) + + val callHandler = ClientStreamingServerCallHandler(handler) + + val call = stub[ServerCall[TestRequest, TestResponse]] + call.request.returnsWith(()) + call.sendHeaders.returnsWith(()) + call.sendMessage.returnsWith(()) + call.close.returnsWith(()) + + val requestHeaders = Metadata() + + val listener = callHandler.startCall(call, requestHeaders) + + requests.foreach(listener.onMessage) + listener.onHalfClose() + + eventually { + assert(call.sendMessage.times === 1) 
+ assert(processedMessages.get === 3) + } + } + + "requests more messages after processing chunks" in run { + val handler: GrpcHandler[Stream[TestRequest, Grpc], TestResponse] = + stream => stream.run.map(_ => TestResponse("done")) + + val callHandler = ClientStreamingServerCallHandler(handler) + + val call = stub[ServerCall[TestRequest, TestResponse]] + call.request.returnsWith(()) + call.sendHeaders.returnsWith(()) + call.sendMessage.returnsWith(()) + call.close.returnsWith(()) + + val requestHeaders = Metadata() + + val listener = callHandler.startCall(call, requestHeaders) + + // Send multiple messages in chunks + listener.onMessage(TestRequest("msg1")) + listener.onMessage(TestRequest("msg2")) + + eventually { + // Should request more as chunks are processed + assert(call.request.times >= 2) + } + + listener.onHalfClose() + + eventually { + assert(call.close.times === 1) + } + } + + "closes with trailers from handler" in run { + val requests = List(TestRequest("test")) + val responseTrailers = Metadata() + responseTrailers.put(Metadata.Key.of("custom-header", Metadata.ASCII_STRING_MARSHALLER), "custom-value") + + val handler: GrpcHandler[Stream[TestRequest, Grpc], TestResponse] = + stream => + for + _ <- Emit.value(SafeMetadata.fromJava(responseTrailers)) + chunk <- stream.run + yield TestResponse("response") + + val callHandler = ClientStreamingServerCallHandler(handler) + val call = stub[ServerCall[TestRequest, TestResponse]] + val requestHeaders = Metadata() + + call.request.returnsWith(()) + call.sendHeaders.returnsWith(()) + call.sendMessage.returnsWith(()) + call.close.returnsWith(()) + + val listener = callHandler.startCall(call, requestHeaders) + + requests.foreach(listener.onMessage) + listener.onHalfClose() + + eventually { + assert(call.close.times === 1) + } + + call.close.calls must contain theSameElementsInOrderAs List((Status.OK, responseTrailers)) + } + } + + "errors" - { + "handles abort failure correctly" in run { + val status = 
Status.INVALID_ARGUMENT.withDescription("Bad request") + + val handler: GrpcHandler[Stream[TestRequest, Grpc], TestResponse] = + stream => Abort.fail(status.asException()) + + val callHandler = ClientStreamingServerCallHandler(handler) + val call = stub[ServerCall[TestRequest, TestResponse]] + val requestHeaders = Metadata() + + call.request.returnsWith(()) + call.sendHeaders.returnsWith(()) + call.close.returnsWith(()) + + val listener = callHandler.startCall(call, requestHeaders) + + listener.onMessage(TestRequest("test")) + listener.onHalfClose() + + eventually { + assert(call.close.times === 1) + } + + call.close.calls must contain theSameElementsInOrderAs List((status, Metadata())) + } + + "handles panic correctly" in run { + val cause = Exception("Something went wrong") + + val handler: GrpcHandler[Stream[TestRequest, Grpc], TestResponse] = + stream => Abort.panic(cause) + + val callHandler = ClientStreamingServerCallHandler(handler) + val call = stub[ServerCall[TestRequest, TestResponse]] + val requestHeaders = Metadata() + + call.request.returnsWith(()) + call.sendHeaders.returnsWith(()) + call.close.returnsWith(()) + + val listener = callHandler.startCall(call, requestHeaders) + + listener.onMessage(TestRequest("test")) + listener.onHalfClose() + + eventually { + assert(call.close.times === 1) + } + + val status = Status.UNKNOWN.withCause(cause) + call.close.calls must contain theSameElementsInOrderAs List((status, Metadata())) + } + + "handles error during stream processing" in run { + val handler: GrpcHandler[Stream[TestRequest, Grpc], TestResponse] = + stream => + stream + .map(req => + if req.message == "error" then + throw Exception("Processing error") + else + req + ) + .run + .map(_ => TestResponse("response")) + + val callHandler = ClientStreamingServerCallHandler(handler) + val call = stub[ServerCall[TestRequest, TestResponse]] + val requestHeaders = Metadata() + + call.request.returnsWith(()) + call.sendHeaders.returnsWith(()) + 
call.close.returnsWith(()) + + val listener = callHandler.startCall(call, requestHeaders) + + listener.onMessage(TestRequest("ok")) + listener.onMessage(TestRequest("error")) + listener.onHalfClose() + + eventually { + assert(call.close.times === 1) + } + + val (status, _) = call.close.calls.head + assert(status.getCode === Status.Code.UNKNOWN) + } + } + + "cancellation" - { + "interrupts stream processing" in run { + val handler: GrpcHandler[Stream[TestRequest, Grpc], TestResponse] = + stream => stream.run.map(_ => TestResponse("response")) + + val callHandler = ClientStreamingServerCallHandler(handler) + val call = stub[ServerCall[TestRequest, TestResponse]] + val requestHeaders = Metadata() + + call.request.returnsWith(()) + call.sendHeaders.returnsWith(()) + call.close.returnsWith(()) + + val interrupted = new JAtomicBoolean(false) + call.sendMessage.returnsWith { + try + Thread.sleep(patienceConfig.timeout.toMillis + 5000) + catch + case e: InterruptedException => + interrupted.set(true) + throw e + } + + val listener = callHandler.startCall(call, requestHeaders) + + listener.onMessage(TestRequest("test")) + listener.onHalfClose() + listener.onCancel() + + eventually { + assert(call.close.times === 1) + } + + val status = Status.CANCELLED.withDescription("Call was cancelled.") + call.close.calls must contain theSameElementsInOrderAs List((status, Metadata())) + } + } + + "lifecycle" - { + "onComplete does nothing" in run { + val handler: GrpcHandler[Stream[TestRequest, Grpc], TestResponse] = + stream => stream.run.map(_ => TestResponse("response")) + + val callHandler = ClientStreamingServerCallHandler(handler) + val call = stub[ServerCall[TestRequest, TestResponse]] + val requestHeaders = Metadata() + + call.request.returnsWith(()) + call.sendHeaders.returnsWith(()) + + val listener = callHandler.startCall(call, requestHeaders) + + // Call onComplete - should not throw + listener.onComplete() + + assert(call.request.times === 2) // Initial 1 + buffer fill + } 
+ + "onReady does nothing" in run { + val handler: GrpcHandler[Stream[TestRequest, Grpc], TestResponse] = + stream => stream.run.map(_ => TestResponse("response")) + + val callHandler = ClientStreamingServerCallHandler(handler) + val call = stub[ServerCall[TestRequest, TestResponse]] + val requestHeaders = Metadata() + + call.request.returnsWith(()) + call.sendHeaders.returnsWith(()) + (() => call.isReady()).returnsWith(true) + + val listener = callHandler.startCall(call, requestHeaders) + + // Call onReady - should not throw + listener.onReady() + + assert(call.request.times === 2) // Initial 1 + buffer fill + } + + "handles empty stream" in run { + val handler: GrpcHandler[Stream[TestRequest, Grpc], TestResponse] = + stream => stream.run.map(_ => TestResponse("no messages")) + + val callHandler = ClientStreamingServerCallHandler(handler) + val call = stub[ServerCall[TestRequest, TestResponse]] + val requestHeaders = Metadata() + + call.request.returnsWith(()) + call.sendHeaders.returnsWith(()) + call.sendMessage.returnsWith(()) + call.close.returnsWith(()) + + val listener = callHandler.startCall(call, requestHeaders) + + // Close without sending any messages + listener.onHalfClose() + + eventually { + assert(call.sendMessage.times === 1) + assert(call.close.times === 1) + } + + call.close.calls must contain theSameElementsInOrderAs List((Status.OK, Metadata())) + } + } + } + +end ClientStreamingServerCallHandlerTest diff --git a/kyo-grpc-core/shared/src/test/scala/kyo/grpc/internal/MetadataExtensionsTest.scala b/kyo-grpc-core/shared/src/test/scala/kyo/grpc/internal/MetadataExtensionsTest.scala new file mode 100644 index 000000000..216fd2e53 --- /dev/null +++ b/kyo-grpc-core/shared/src/test/scala/kyo/grpc/internal/MetadataExtensionsTest.scala @@ -0,0 +1,86 @@ +package kyo.grpc.internal + +import kyo.* +import kyo.grpc.* +import org.scalactic.TripleEquals.* + +class MetadataExtensionsTest extends Test: + + "Maybe[SafeMetadata] extension" - { + + "mergeIfDefined" - 
{ + + "merges Present metadata with Present other" in run { + val m1 = SafeMetadata.empty.add("key1", "value1") + val m2 = SafeMetadata.empty.add("key2", "value2") + + Maybe.Present(m1).mergeIfDefined(Maybe.Present(m2)).map: + case Maybe.Present(merged) => + assert(merged.getStrings("key1") === Seq("value1")) + assert(merged.getStrings("key2") === Seq("value2")) + succeed + case Maybe.Absent => + fail("Expected Present but got Absent") + } + + "returns Present metadata when other is Absent" in run { + val m1 = SafeMetadata.empty.add("key1", "value1") + + Maybe.Present(m1).mergeIfDefined(Maybe.Absent).map: + case Maybe.Present(merged) => + assert(merged.getStrings("key1") === Seq("value1")) + succeed + case Maybe.Absent => + fail("Expected Present but got Absent") + } + + "returns other when metadata is Absent and other is Present" in run { + val m2 = SafeMetadata.empty.add("key1", "value1") + + (Maybe.Absent: Maybe[SafeMetadata]).mergeIfDefined(Maybe.Present(m2)).map: + case Maybe.Present(merged) => + assert(merged.getStrings("key1") === Seq("value1")) + succeed + case Maybe.Absent => + fail("Expected Present but got Absent") + } + + "returns Absent when both are Absent" in run { + (Maybe.Absent: Maybe[SafeMetadata]).mergeIfDefined(Maybe.Absent).map: + case Maybe.Present(_) => + fail("Expected Absent but got Present") + case Maybe.Absent => + succeed + } + } + } + + "SafeMetadata" - { + + "fromJava and toJava round-trip" in { + val sm = SafeMetadata.empty.add("test-key", "test-value") + val java = sm.toJava + val back = SafeMetadata.fromJava(java) + assert(back.getStrings("test-key") === Seq("test-value")) + succeed + } + + "merge combines entries" in { + val m1 = SafeMetadata.empty.add("key1", "value1") + val m2 = SafeMetadata.empty.add("key2", "value2") + val merged = m1.merge(m2) + assert(merged.getStrings("key1") === Seq("value1")) + assert(merged.getStrings("key2") === Seq("value2")) + succeed + } + + "merge appends duplicate keys" in { + val m1 = 
SafeMetadata.empty.add("key1", "value1") + val m2 = SafeMetadata.empty.add("key1", "value2") + val merged = m1.merge(m2) + assert(merged.getStrings("key1") === Seq("value1", "value2")) + succeed + } + } + +end MetadataExtensionsTest diff --git a/kyo-grpc-core/shared/src/test/scala/kyo/grpc/internal/ServerStreamingClientCallListenerTest.scala b/kyo-grpc-core/shared/src/test/scala/kyo/grpc/internal/ServerStreamingClientCallListenerTest.scala new file mode 100644 index 000000000..e39aa52ae --- /dev/null +++ b/kyo-grpc-core/shared/src/test/scala/kyo/grpc/internal/ServerStreamingClientCallListenerTest.scala @@ -0,0 +1,85 @@ +package kyo.grpc.internal + +import io.grpc.{Channel as _, *} +import kyo.* +import kyo.grpc.* +import org.scalactic.TripleEquals.* + +class ServerStreamingClientCallListenerTest extends Test: + + case class TestResponse(result: String) + + "ServerStreamingClientCallListener" - { + + "onHeaders completes headers promise with SafeMetadata" in run { + for + headersPromise <- Promise.init[SafeMetadata, Any] + responseChannel <- Channel.init[TestResponse](8) + completionPromise <- Promise.init[CallClosed, Any] + readySignal <- Signal.initRef[Boolean](false) + listener = ServerStreamingClientCallListener(headersPromise, responseChannel, completionPromise, readySignal) + + headers = new Metadata() + key = Metadata.Key.of("test-header", Metadata.ASCII_STRING_MARSHALLER) + _ = headers.put(key, "test-value") + _ = listener.onHeaders(headers) + + result <- headersPromise.get + yield assert(result.getStrings("test-header") === Seq("test-value")) + } + + "onMessage offers messages to channel" in run { + for + headersPromise <- Promise.init[SafeMetadata, Any] + responseChannel <- Channel.init[TestResponse](8) + completionPromise <- Promise.init[CallClosed, Any] + readySignal <- Signal.initRef[Boolean](false) + listener = ServerStreamingClientCallListener(headersPromise, responseChannel, completionPromise, readySignal) + + response1 = TestResponse("first") + 
response2 = TestResponse("second") + _ = listener.onMessage(response1) + _ = listener.onMessage(response2) + + result1 <- responseChannel.take + result2 <- responseChannel.take + yield + assert(result1 === response1) + assert(result2 === response2) + } + + "onClose closes channel and completes completion promise" in run { + for + headersPromise <- Promise.init[SafeMetadata, Any] + responseChannel <- Channel.init[TestResponse](8) + completionPromise <- Promise.init[CallClosed, Any] + readySignal <- Signal.initRef[Boolean](false) + listener = ServerStreamingClientCallListener(headersPromise, responseChannel, completionPromise, readySignal) + + status = Status.OK + trailers = new Metadata() + _ = listener.onClose(status, trailers) + + completionResult <- completionPromise.get + channelClosed <- responseChannel.closed + yield + assert(completionResult.status === status) + assert(channelClosed === true) + } + + "onReady sets ready signal to true" in run { + for + headersPromise <- Promise.init[SafeMetadata, Any] + responseChannel <- Channel.init[TestResponse](8) + completionPromise <- Promise.init[CallClosed, Any] + readySignal <- Signal.initRef[Boolean](false) + listener = ServerStreamingClientCallListener(headersPromise, responseChannel, completionPromise, readySignal) + + _ = listener.onReady() + + ready <- readySignal.get + yield assert(ready === true) + } + } + +end ServerStreamingClientCallListenerTest diff --git a/kyo-grpc-core/shared/src/test/scala/kyo/grpc/internal/ServerStreamingServerCallHandlerTest.scala b/kyo-grpc-core/shared/src/test/scala/kyo/grpc/internal/ServerStreamingServerCallHandlerTest.scala new file mode 100644 index 000000000..2d756e523 --- /dev/null +++ b/kyo-grpc-core/shared/src/test/scala/kyo/grpc/internal/ServerStreamingServerCallHandlerTest.scala @@ -0,0 +1,411 @@ +package kyo.grpc.internal + +import io.grpc.{Grpc as _, *} +import java.util.concurrent.atomic.AtomicBoolean as JAtomicBoolean +import java.util.concurrent.atomic.AtomicInteger 
+import kyo.* +import kyo.grpc.* +import kyo.grpc.Equalities.given +import org.scalamock.scalatest.AsyncMockFactory +import org.scalamock.stubs.Stubs +import org.scalatest.concurrent.Eventually +import org.scalatest.matchers.must.Matchers.* +import org.scalatest.time.Seconds +import org.scalatest.time.Span + +class ServerStreamingServerCallHandlerTest extends Test with Stubs with Eventually: + + case class TestRequest(message: String) + case class TestResponse(result: String) + + implicit override def patienceConfig: PatienceConfig = super.patienceConfig.copy(timeout = scaled(Span(5, Seconds))) + + "ServerStreamingServerCallHandler" - { + + "startup" - { + "requests one message from client" in run { + val handler: GrpcHandler[TestRequest, Stream[TestResponse, Grpc]] = + req => Stream.empty[TestResponse] + + val callHandler = ServerStreamingServerCallHandler(handler) + + val call = stub[ServerCall[TestRequest, TestResponse]] + call.request.returnsWith(()) + call.sendHeaders.returnsWith(()) + + val requestHeaders = Metadata() + + val listener = callHandler.startCall(call, requestHeaders) + + assert(call.request.calls === List(1)) + } + + "set options and sends headers" in run { + val handler: GrpcHandler[TestRequest, Stream[TestResponse, Grpc]] = + req => Stream.empty[TestResponse] + + val requestHeaders = Metadata() + + val responseHeaders = Metadata() + responseHeaders.put(Metadata.Key.of("custom-header", Metadata.ASCII_STRING_MARSHALLER), "custom-value") + + val responseOptions = ResponseOptions( + headers = SafeMetadata.fromJava(responseHeaders), + messageCompression = Maybe.Present(true), + compression = Maybe.Present("gzip"), + onReadyThreshold = Maybe.Present(16), + requestBuffer = Maybe.Present(4) + ) + + val init: GrpcHandlerInit[TestRequest, Stream[TestResponse, Grpc]] = + for + actualRequestHeaders <- Env.get[SafeMetadata] + _ <- Emit.value(responseOptions) + yield + assert(actualRequestHeaders === SafeMetadata.fromJava(requestHeaders)) + handler + + val 
callHandler = ServerStreamingServerCallHandler(init) + + val call = stub[ServerCall[TestRequest, TestResponse]] + call.request.returnsWith(()) + call.setMessageCompression.returnsWith(()) + call.setCompression.returnsWith(()) + call.setOnReadyThreshold.returnsWith(()) + call.sendHeaders.returnsWith(()) + + val listener = callHandler.startCall(call, requestHeaders) + + assert(call.setMessageCompression.calls === responseOptions.messageCompression.toList) + assert(call.setCompression.calls === responseOptions.compression.toList) + assert(call.setOnReadyThreshold.calls === responseOptions.onReadyThreshold.toList) + assert(call.sendHeaders.calls.map(_.toString) === List(responseOptions.headers.toJava.toString)) + } + } + + "success" - { + "sends multiple response messages" in run { + import org.scalactic.TraversableEqualityConstraints.* + + val request = TestRequest("test") + val responses = List( + TestResponse("response1"), + TestResponse("response2"), + TestResponse("response3") + ) + + val handler: GrpcHandler[TestRequest, Stream[TestResponse, Grpc]] = + req => Stream.init(responses) + + val callHandler = ServerStreamingServerCallHandler(handler) + + val call = stub[ServerCall[TestRequest, TestResponse]] + call.request.returnsWith(()) + call.sendHeaders.returnsWith(()) + (() => call.isReady()).returnsWith(true) + call.sendMessage.returnsWith(()) + call.close.returnsWith(()) + + val requestHeaders = Metadata() + + val listener = callHandler.startCall(call, requestHeaders) + + // Simulate receiving a message + listener.onMessage(request) + listener.onHalfClose() + + eventually { + assert(call.sendMessage.times === 3) + assert(call.close.times === 1) + } + + assert(call.sendMessage.calls === responses) + call.close.calls must contain theSameElementsInOrderAs List((Status.OK, Metadata())) + } + + "sends empty stream" in run { + import org.scalactic.TraversableEqualityConstraints.* + + val request = TestRequest("test") + + val handler: GrpcHandler[TestRequest, 
Stream[TestResponse, Grpc]] = + req => Stream.empty[TestResponse] + + val callHandler = ServerStreamingServerCallHandler(handler) + + val call = stub[ServerCall[TestRequest, TestResponse]] + call.request.returnsWith(()) + call.sendHeaders.returnsWith(()) + call.sendMessage.returnsWith(()) + call.close.returnsWith(()) + + val requestHeaders = Metadata() + + val listener = callHandler.startCall(call, requestHeaders) + + listener.onMessage(request) + listener.onHalfClose() + + eventually { + assert(call.sendMessage.times === 0) + assert(call.close.times === 1) + } + + call.close.calls must contain theSameElementsInOrderAs List((Status.OK, Metadata())) + } + + "processes responses incrementally" in run { + val request = TestRequest("test") + val sentMessages = new AtomicInteger(0) + + val handler: GrpcHandler[TestRequest, Stream[TestResponse, Grpc]] = + req => + Stream.init(List( + TestResponse("response1"), + TestResponse("response2"), + TestResponse("response3") + )).tap(_ => Sync.defer(sentMessages.incrementAndGet())) + + val callHandler = ServerStreamingServerCallHandler(handler) + + val call = stub[ServerCall[TestRequest, TestResponse]] + call.request.returnsWith(()) + call.sendHeaders.returnsWith(()) + (() => call.isReady()).returnsWith(true) + call.sendMessage.returnsWith(()) + call.close.returnsWith(()) + + val requestHeaders = Metadata() + + val listener = callHandler.startCall(call, requestHeaders) + + listener.onMessage(request) + listener.onHalfClose() + + eventually { + assert(call.sendMessage.times === 3) + assert(sentMessages.get === 3) + } + } + + "closes with trailers from handler" in run { + val request = TestRequest("test") + val responseTrailers = Metadata() + responseTrailers.put(Metadata.Key.of("custom-header", Metadata.ASCII_STRING_MARSHALLER), "custom-value") + + val handler: GrpcHandler[TestRequest, Stream[TestResponse, Grpc]] = + req => + for + _ <- Emit.value(SafeMetadata.fromJava(responseTrailers)) + yield 
Stream.init(List(TestResponse("response"))) + + val callHandler = ServerStreamingServerCallHandler(handler) + val call = stub[ServerCall[TestRequest, TestResponse]] + val requestHeaders = Metadata() + + call.request.returnsWith(()) + call.sendHeaders.returnsWith(()) + call.sendMessage.returnsWith(()) + call.close.returnsWith(()) + + val listener = callHandler.startCall(call, requestHeaders) + + listener.onMessage(request) + listener.onHalfClose() + + eventually { + assert(call.close.times === 1) + } + + call.close.calls must contain theSameElementsInOrderAs List((Status.OK, responseTrailers)) + } + } + + "errors" - { + "handles abort failure correctly" in run { + val request = TestRequest("test") + val status = Status.INVALID_ARGUMENT.withDescription("Bad request") + + val handler: GrpcHandler[TestRequest, Stream[TestResponse, Grpc]] = + req => Abort.fail(status.asException()) + + val callHandler = ServerStreamingServerCallHandler(handler) + val call = stub[ServerCall[TestRequest, TestResponse]] + val requestHeaders = Metadata() + + call.request.returnsWith(()) + call.sendHeaders.returnsWith(()) + call.close.returnsWith(()) + + val listener = callHandler.startCall(call, requestHeaders) + + listener.onMessage(request) + listener.onHalfClose() + + eventually { + assert(call.close.times === 1) + } + + call.close.calls must contain theSameElementsInOrderAs List((status, Metadata())) + } + + "handles panic correctly" in run { + val request = TestRequest("test") + val cause = Exception("Something went wrong") + + val handler: GrpcHandler[TestRequest, Stream[TestResponse, Grpc]] = + req => Abort.panic(cause) + + val callHandler = ServerStreamingServerCallHandler(handler) + val call = stub[ServerCall[TestRequest, TestResponse]] + val requestHeaders = Metadata() + + call.request.returnsWith(()) + call.sendHeaders.returnsWith(()) + call.close.returnsWith(()) + + val listener = callHandler.startCall(call, requestHeaders) + + listener.onMessage(request) + 
listener.onHalfClose() + + eventually { + assert(call.close.times === 1) + } + + val status = Status.UNKNOWN.withCause(cause) + call.close.calls must contain theSameElementsInOrderAs List((status, Metadata())) + } + + "handles error during stream generation" in run { + val request = TestRequest("test") + + val handler: GrpcHandler[TestRequest, Stream[TestResponse, Grpc]] = + req => + Stream.init(List( + TestResponse("response1"), + TestResponse("response2") + )).map(resp => + if resp.result == "response2" then + throw Exception("Stream error") + else + resp + ) + + val callHandler = ServerStreamingServerCallHandler(handler) + val call = stub[ServerCall[TestRequest, TestResponse]] + val requestHeaders = Metadata() + + call.request.returnsWith(()) + call.sendHeaders.returnsWith(()) + call.sendMessage.returnsWith(()) + call.close.returnsWith(()) + + val listener = callHandler.startCall(call, requestHeaders) + + listener.onMessage(request) + listener.onHalfClose() + + eventually { + assert(call.close.times === 1) + } + + val (status, _) = call.close.calls.head + assert(status.getCode === Status.Code.UNKNOWN) + } + + "fails when client completes without sending request" in run { + val handler: GrpcHandler[TestRequest, Stream[TestResponse, Grpc]] = + req => Stream.empty[TestResponse] + + val callHandler = ServerStreamingServerCallHandler(handler) + val call = stub[ServerCall[TestRequest, TestResponse]] + val requestHeaders = Metadata() + + call.request.returnsWith(()) + call.sendHeaders.returnsWith(()) + call.close.returnsWith(()) + + val listener = callHandler.startCall(call, requestHeaders) + + // Complete without sending a message + listener.onHalfClose() + + eventually { + assert(call.close.times === 1) + } + + val status = Status.INVALID_ARGUMENT.withDescription("Client completed before sending a request.") + call.close.calls must contain theSameElementsInOrderAs List((status, Metadata())) + } + } + + "cancellation" - { + "interrupts while sending messages" in run { 
+ val request = TestRequest("test") + val responses = List.fill(10)(TestResponse("response")) + + val handler: GrpcHandler[TestRequest, Stream[TestResponse, Grpc]] = + req => Stream.init(responses) + + val callHandler = ServerStreamingServerCallHandler(handler) + val call = stub[ServerCall[TestRequest, TestResponse]] + val requestHeaders = Metadata() + + call.request.returnsWith(()) + call.sendHeaders.returnsWith(()) + call.close.returnsWith(()) + + val interrupted = new JAtomicBoolean(false) + call.sendMessage.returnsWith { + try + Thread.sleep(patienceConfig.timeout.toMillis + 5000) + catch + case e: InterruptedException => + interrupted.set(true) + throw e + } + + val listener = callHandler.startCall(call, requestHeaders) + + listener.onMessage(request) + listener.onHalfClose() + listener.onCancel() + + eventually { + assert(call.close.times === 1) + } + + val status = Status.CANCELLED.withDescription("Call was cancelled.") + call.close.calls must contain theSameElementsInOrderAs List((status, Metadata())) + } + } + + // TODO: Re-enable these tests after implementing Signal-based flow control + // "flow control" - { + // ... 
+ // } + + "lifecycle" - { + "onComplete does nothing" in run { + val handler: GrpcHandler[TestRequest, Stream[TestResponse, Grpc]] = + req => Stream.empty[TestResponse] + + val callHandler = ServerStreamingServerCallHandler(handler) + val call = stub[ServerCall[TestRequest, TestResponse]] + val requestHeaders = Metadata() + + call.request.returnsWith(()) + call.sendHeaders.returnsWith(()) + + val listener = callHandler.startCall(call, requestHeaders) + + // Call onComplete - should not throw + listener.onComplete() + + assert(call.request.times === 1) + } + } + } + +end ServerStreamingServerCallHandlerTest diff --git a/kyo-grpc-core/shared/src/test/scala/kyo/grpc/internal/UnaryClientCallListenerTest.scala b/kyo-grpc-core/shared/src/test/scala/kyo/grpc/internal/UnaryClientCallListenerTest.scala new file mode 100644 index 000000000..63f2d42aa --- /dev/null +++ b/kyo-grpc-core/shared/src/test/scala/kyo/grpc/internal/UnaryClientCallListenerTest.scala @@ -0,0 +1,229 @@ +package kyo.grpc.internal + +import io.grpc.* +import kyo.* +import kyo.grpc.* +import kyo.grpc.Equalities.given +import org.scalatest.concurrent.Eventually +import org.scalatest.matchers.must.Matchers.* +import org.scalatest.time.Seconds +import org.scalatest.time.Span + +class UnaryClientCallListenerTest extends Test with Eventually: + + case class TestResponse(result: String) + + implicit override def patienceConfig: PatienceConfig = super.patienceConfig.copy(timeout = scaled(Span(5, Seconds))) + + "UnaryClientCallListener" - { + + "onHeaders" - { + "completes headers promise" in run { + for + headersPromise <- Promise.init[SafeMetadata, Any] + responsePromise <- Promise.init[TestResponse, Abort[StatusException]] + completionPromise <- Promise.init[CallClosed, Any] + readySignal <- Signal.initRef[Boolean](false) + listener = UnaryClientCallListener(headersPromise, responsePromise, completionPromise, readySignal) + + headers = Metadata() + key = Metadata.Key.of("test-header", 
Metadata.ASCII_STRING_MARSHALLER) + _ = headers.put(key, "test-value") + _ = listener.onHeaders(headers) + + result <- headersPromise.get + yield assert(result.getStrings("test-header") === Seq("test-value")) + } + } + + "onMessage" - { + "completes response promise with first message" in run { + for + headersPromise <- Promise.init[SafeMetadata, Any] + responsePromise <- Promise.init[TestResponse, Abort[StatusException]] + completionPromise <- Promise.init[CallClosed, Any] + readySignal <- Signal.initRef[Boolean](false) + listener = UnaryClientCallListener(headersPromise, responsePromise, completionPromise, readySignal) + + response = TestResponse("success") + _ = listener.onMessage(response) + + result <- Abort.run[StatusException](responsePromise.get) + yield + assert(result.isSuccess) + assert(result.getOrThrow === response) + } + + "throws exception when server sends more than one response" in run { + for + headersPromise <- Promise.init[SafeMetadata, Any] + responsePromise <- Promise.init[TestResponse, Abort[StatusException]] + completionPromise <- Promise.init[CallClosed, Any] + readySignal <- Signal.initRef[Boolean](false) + listener = UnaryClientCallListener(headersPromise, responsePromise, completionPromise, readySignal) + + response1 = TestResponse("first") + response2 = TestResponse("second") + _ = listener.onMessage(response1) + + exception <- Abort.run[Throwable]: + Abort.catching[StatusException]: + Sync.defer(listener.onMessage(response2)) + yield exception.fold( + _ => fail("Expected exception but got success"), + ex => + ex match + case se: StatusException => + assert(se.getStatus.getCode === Status.Code.INVALID_ARGUMENT) + assert(se.getStatus.getDescription === "Server sent more than one response.") + case _ => + fail(s"Expected StatusException but got ${ex.getClass}") + , + _ => fail("Expected exception but got panic") + ) + } + } + + "onClose" - { + "completes response promise with error when no message received" in run { + for + headersPromise 
<- Promise.init[SafeMetadata, Any] + responsePromise <- Promise.init[TestResponse, Abort[StatusException]] + completionPromise <- Promise.init[CallClosed, Any] + readySignal <- Signal.initRef[Boolean](false) + listener = UnaryClientCallListener(headersPromise, responsePromise, completionPromise, readySignal) + + status = Status.CANCELLED.withDescription("Client cancelled") + trailers = Metadata() + _ = listener.onClose(status, trailers) + + completionResult <- completionPromise.get + responseResult <- Abort.run[StatusException](responsePromise.get) + yield + assert(completionResult.status === status) + assert(completionResult.trailers === SafeMetadata.fromJava(trailers)) + assert(responseResult.isFailure) + val failure = responseResult.failure.get + assert(failure.getStatus === status) + } + + "completes completion promise after message received" in run { + for + headersPromise <- Promise.init[SafeMetadata, Any] + responsePromise <- Promise.init[TestResponse, Abort[StatusException]] + completionPromise <- Promise.init[CallClosed, Any] + readySignal <- Signal.initRef[Boolean](false) + listener = UnaryClientCallListener(headersPromise, responsePromise, completionPromise, readySignal) + + response = TestResponse("success") + _ = listener.onMessage(response) + + status = Status.OK + trailers = Metadata() + key = Metadata.Key.of("trailer-key", Metadata.ASCII_STRING_MARSHALLER) + _ = trailers.put(key, "trailer-value") + _ = listener.onClose(status, trailers) + + completionResult <- completionPromise.get + yield + assert(completionResult.status === status) + assert(completionResult.trailers.getStrings("trailer-key") === Seq("trailer-value")) + } + } + + "onReady" - { + "sets ready signal to true" in run { + for + headersPromise <- Promise.init[SafeMetadata, Any] + responsePromise <- Promise.init[TestResponse, Abort[StatusException]] + completionPromise <- Promise.init[CallClosed, Any] + readySignal <- Signal.initRef[Boolean](false) + listener = 
UnaryClientCallListener(headersPromise, responsePromise, completionPromise, readySignal) + + _ = listener.onReady() + + ready <- readySignal.get + yield assert(ready === true) + } + + "can be called multiple times" in run { + for + headersPromise <- Promise.init[SafeMetadata, Any] + responsePromise <- Promise.init[TestResponse, Abort[StatusException]] + completionPromise <- Promise.init[CallClosed, Any] + readySignal <- Signal.initRef[Boolean](false) + listener = UnaryClientCallListener(headersPromise, responsePromise, completionPromise, readySignal) + + _ = listener.onReady() + ready1 <- readySignal.get + _ = listener.onReady() + ready2 <- readySignal.get + yield + assert(ready1 === true) + assert(ready2 === true) + } + } + + "full lifecycle" - { + "processes successful unary call" in run { + for + headersPromise <- Promise.init[SafeMetadata, Any] + responsePromise <- Promise.init[TestResponse, Abort[StatusException]] + completionPromise <- Promise.init[CallClosed, Any] + readySignal <- Signal.initRef[Boolean](false) + listener = UnaryClientCallListener(headersPromise, responsePromise, completionPromise, readySignal) + + // Simulate call lifecycle + headers = Metadata() + _ = headers.put(Metadata.Key.of("content-type", Metadata.ASCII_STRING_MARSHALLER), "application/grpc") + _ = listener.onHeaders(headers) + + _ = listener.onReady() + + response = TestResponse("final result") + _ = listener.onMessage(response) + + trailers = Metadata() + _ = listener.onClose(Status.OK, trailers) + + headersResult <- headersPromise.get + responseResult <- Abort.run[StatusException](responsePromise.get) + completionResult <- completionPromise.get + readyResult <- readySignal.get + yield + assert(headersResult === SafeMetadata.fromJava(headers)) + assert(responseResult === Result.succeed(response)) + assert(completionResult.status === Status.OK) + assert(readyResult === true) + } + + "processes failed unary call" in run { + for + headersPromise <- Promise.init[SafeMetadata, Any] + 
responsePromise <- Promise.init[TestResponse, Abort[StatusException]] + completionPromise <- Promise.init[CallClosed, Any] + readySignal <- Signal.initRef[Boolean](false) + listener = UnaryClientCallListener(headersPromise, responsePromise, completionPromise, readySignal) + + // Simulate call lifecycle with error + headers = Metadata() + _ = listener.onHeaders(headers) + + errorStatus = Status.UNAVAILABLE.withDescription("Service unavailable") + trailers = Metadata() + _ = listener.onClose(errorStatus, trailers) + + headersResult <- headersPromise.get + responseResult <- Abort.run[StatusException](responsePromise.get) + completionResult <- completionPromise.get + yield + assert(headersResult === SafeMetadata.fromJava(headers)) + assert(responseResult.isFailure) + val failure = responseResult.failure.get + assert(failure.getStatus === errorStatus) + assert(completionResult.status === errorStatus) + } + } + } + +end UnaryClientCallListenerTest diff --git a/kyo-grpc-core/shared/src/test/scala/kyo/grpc/internal/UnaryServerCallHandlerTest.scala b/kyo-grpc-core/shared/src/test/scala/kyo/grpc/internal/UnaryServerCallHandlerTest.scala new file mode 100644 index 000000000..f319cdf12 --- /dev/null +++ b/kyo-grpc-core/shared/src/test/scala/kyo/grpc/internal/UnaryServerCallHandlerTest.scala @@ -0,0 +1,312 @@ +package kyo.grpc.internal + +import io.grpc.* +import java.util.concurrent.atomic.AtomicBoolean as JAtomicBoolean +import kyo.* +import kyo.grpc.* +import kyo.grpc.Equalities.given +import org.scalamock.scalatest.AsyncMockFactory +import org.scalamock.stubs.Stubs +import org.scalatest.concurrent.Eventually +import org.scalatest.matchers.must.Matchers.* +import org.scalatest.time.Seconds +import org.scalatest.time.Span + +class UnaryServerCallHandlerTest extends Test with Stubs with Eventually: + + case class TestRequest(message: String) + case class TestResponse(result: String) + + implicit override def patienceConfig: PatienceConfig = super.patienceConfig.copy(timeout = 
scaled(Span(5, Seconds))) + + "UnaryServerCallHandler" - { + + "startup" - { + "requests one message from client" in run { + val handler: GrpcHandler[TestRequest, TestResponse] = + req => TestResponse("response") + + val callHandler = UnaryServerCallHandler(handler) + + val call = stub[ServerCall[TestRequest, TestResponse]] + call.request.returnsWith(()) + call.sendHeaders.returnsWith(()) + + val requestHeaders = Metadata() + + val listener = callHandler.startCall(call, requestHeaders) + + assert(call.request.calls === List(1)) + } + + "set options and sends headers" in run { + val handler: GrpcHandler[TestRequest, TestResponse] = + req => TestResponse("response") + + val requestHeaders = Metadata() + + val responseHeaders = Metadata() + responseHeaders.put(Metadata.Key.of("custom-header", Metadata.ASCII_STRING_MARSHALLER), "custom-value") + + val responseOptions = ResponseOptions( + headers = SafeMetadata.fromJava(responseHeaders), + messageCompression = Maybe.Present(true), + compression = Maybe.Present("gzip"), + onReadyThreshold = Maybe.Present(16), + requestBuffer = Maybe.Present(4) + ) + + val init: GrpcHandlerInit[TestRequest, TestResponse] = + for + actualRequestHeaders <- Env.get[SafeMetadata] + _ <- Emit.value(responseOptions) + yield + assert(actualRequestHeaders === SafeMetadata.fromJava(requestHeaders)) + handler + + val callHandler = UnaryServerCallHandler(init) + + val call = stub[ServerCall[TestRequest, TestResponse]] + call.request.returnsWith(()) + call.setMessageCompression.returnsWith(()) + call.setCompression.returnsWith(()) + call.setOnReadyThreshold.returnsWith(()) + call.sendHeaders.returnsWith(()) + + val listener = callHandler.startCall(call, requestHeaders) + + assert(call.setMessageCompression.calls === responseOptions.messageCompression.toList) + assert(call.setCompression.calls === responseOptions.compression.toList) + assert(call.setOnReadyThreshold.calls === responseOptions.onReadyThreshold.toList) + 
assert(call.sendHeaders.calls.map(_.toString) === List(responseOptions.headers.toJava.toString)) + } + } + + "success" - { + "sends message and closes" in run { + import org.scalactic.TraversableEqualityConstraints.* + + val request = TestRequest("test") + val expectedResponse = TestResponse("response") + + val handler: GrpcHandler[TestRequest, TestResponse] = _ => expectedResponse + + val callHandler = UnaryServerCallHandler(handler) + + val call = stub[ServerCall[TestRequest, TestResponse]] + call.request.returnsWith(()) + call.sendHeaders.returnsWith(()) + call.sendMessage.returnsWith(()) + call.close.returnsWith(()) + + val requestHeaders = Metadata() + + val listener = callHandler.startCall(call, requestHeaders) + + // Simulate receiving a message + listener.onMessage(request) + + eventually { + assert(call.sendMessage.times === 1) + assert(call.close.times === 1) + } + + assert(call.sendMessage.calls === List(expectedResponse)) + call.close.calls must contain theSameElementsInOrderAs List((Status.OK, Metadata())) + } + + "closes with trailers from handler" in run { + val request = TestRequest("test") + val expectedResponse = TestResponse("response") + val responseTrailers = Metadata() + responseTrailers.put(Metadata.Key.of("custom-header", Metadata.ASCII_STRING_MARSHALLER), "custom-value") + + val handler: GrpcHandler[TestRequest, TestResponse] = + req => Emit.value(SafeMetadata.fromJava(responseTrailers)).map(_ => expectedResponse) + + val callHandler = UnaryServerCallHandler(handler) + val call = stub[ServerCall[TestRequest, TestResponse]] + val requestHeaders = Metadata() + + call.request.returnsWith(()) + call.sendHeaders.returnsWith(()) + call.sendMessage.returnsWith(()) + call.close.returnsWith(()) + + val listener = callHandler.startCall(call, requestHeaders) + + listener.onMessage(request) + listener.onHalfClose() + + eventually { + assert(call.close.times === 1) + } + + call.close.calls must contain theSameElementsInOrderAs List((Status.OK, 
responseTrailers)) + } + } + + "errors" - { + "handles abort failure correctly" in run { + val request = TestRequest("test") + val status = Status.INVALID_ARGUMENT.withDescription("Bad request") + + val handler: GrpcHandler[TestRequest, TestResponse] = + req => Abort.fail(status.asException()) + + val callHandler = UnaryServerCallHandler(handler) + val call = stub[ServerCall[TestRequest, TestResponse]] + val requestHeaders = Metadata() + + call.request.returnsWith(()) + call.sendHeaders.returnsWith(()) + call.close.returnsWith(()) + + val listener = callHandler.startCall(call, requestHeaders) + + listener.onMessage(request) + + eventually { + assert(call.close.times === 1) + } + + call.close.calls must contain theSameElementsInOrderAs List((status, Metadata())) + } + + "handles panic correctly" in run { + val request = TestRequest("test") + val cause = Exception("Something went wrong") + + val handler: GrpcHandler[TestRequest, TestResponse] = + req => Abort.panic(cause) + + val callHandler = UnaryServerCallHandler(handler) + val call = stub[ServerCall[TestRequest, TestResponse]] + val requestHeaders = Metadata() + + call.request.returnsWith(()) + call.sendHeaders.returnsWith(()) + call.close.returnsWith(()) + + val listener = callHandler.startCall(call, requestHeaders) + + listener.onMessage(request) + + eventually { + assert(call.close.times === 1) + } + + val status = Status.UNKNOWN.withCause(cause) + call.close.calls must contain theSameElementsInOrderAs List((status, Metadata())) + } + + "fails when client completes without sending request" in run { + val handler: GrpcHandler[TestRequest, TestResponse] = + req => TestResponse("response") + + val callHandler = UnaryServerCallHandler(handler) + val call = stub[ServerCall[TestRequest, TestResponse]] + val requestHeaders = Metadata() + + call.request.returnsWith(()) + call.sendHeaders.returnsWith(()) + call.close.returnsWith(()) + + val listener = callHandler.startCall(call, requestHeaders) + + // Complete without 
sending a message + listener.onHalfClose() + + eventually { + assert(call.close.times === 1) + } + + val status = Status.INVALID_ARGUMENT.withDescription("Client completed before sending a request.") + call.close.calls must contain theSameElementsInOrderAs List((status, Metadata())) + } + } + + "cancellation" - { + "interrupts when sending message" in run { + val request = TestRequest("test") + + val handler: GrpcHandler[TestRequest, TestResponse] = + req => TestResponse("response") + + val callHandler = UnaryServerCallHandler(handler) + val call = stub[ServerCall[TestRequest, TestResponse]] + val requestHeaders = Metadata() + + call.request.returnsWith(()) + call.sendHeaders.returnsWith(()) + call.close.returnsWith(()) + + val interrupted = new JAtomicBoolean(false) + call.sendMessage.returnsWith { + try + Thread.sleep(patienceConfig.timeout.toMillis + 5000) + catch + case e: InterruptedException => + interrupted.set(true) + throw e + } + + val listener = callHandler.startCall(call, requestHeaders) + + listener.onMessage(request) + listener.onCancel() + + eventually { + // This fails because of https://github.com/getkyo/kyo/issues/1431. 
+ // assert(interrupted.get === true) + assert(call.close.times === 1) + } + + val status = Status.CANCELLED.withDescription("Call was cancelled.") + call.close.calls must contain theSameElementsInOrderAs List((status, Metadata())) + } + } + + "lifecycle" - { + "onComplete does nothing" in run { + val handler: GrpcHandler[TestRequest, TestResponse] = + req => TestResponse("response") + + val callHandler = UnaryServerCallHandler(handler) + val call = stub[ServerCall[TestRequest, TestResponse]] + val requestHeaders = Metadata() + + call.request.returnsWith(()) + call.sendHeaders.returnsWith(()) + + val listener = callHandler.startCall(call, requestHeaders) + + // Call onComplete - should not throw + listener.onComplete() + + assert(call.request.times === 1) + } + + "onReady does nothing" in run { + val handler: GrpcHandler[TestRequest, TestResponse] = + req => TestResponse("response") + + val callHandler = UnaryServerCallHandler(handler) + val call = stub[ServerCall[TestRequest, TestResponse]] + val requestHeaders = Metadata() + + call.request.returnsWith(()) + call.sendHeaders.returnsWith(()) + (() => call.isReady()).returnsWith(true) + + val listener = callHandler.startCall(call, requestHeaders) + + // Call onReady - should not throw + listener.onReady() + + assert(call.request.times === 1) + } + } + } + +end UnaryServerCallHandlerTest diff --git a/kyo-grpc-e2e/jvm/src/test/resources/logging.properties b/kyo-grpc-e2e/jvm/src/test/resources/logging.properties new file mode 100644 index 000000000..ee9a9eeed --- /dev/null +++ b/kyo-grpc-e2e/jvm/src/test/resources/logging.properties @@ -0,0 +1,13 @@ +# Enable fine-grained gRPC logging +handlers=java.util.logging.ConsoleHandler +.level=INFO + +# gRPC transport logging +io.grpc.netty.level=FINE +io.grpc.internal.level=FINE +io.grpc.level=FINE + +# Console handler config +java.util.logging.ConsoleHandler.level=FINE +java.util.logging.ConsoleHandler.formatter=java.util.logging.SimpleFormatter 
+java.util.logging.SimpleFormatter.format=[%1$tF %1$tT] [%4$-7s] %5$s %n diff --git a/kyo-grpc-e2e/shared/src/main/protobuf/test.proto b/kyo-grpc-e2e/shared/src/main/protobuf/test.proto new file mode 100644 index 000000000..09249d096 --- /dev/null +++ b/kyo-grpc-e2e/shared/src/main/protobuf/test.proto @@ -0,0 +1,47 @@ +syntax = "proto3"; + +// Don't use kyo here because otherwise it cannot derive the Frame. +package kgrpc; + +service TestService { + rpc OneToOne(Request) returns (Response); + rpc OneToMany(Request) returns (stream Response); + rpc ManyToOne(stream Request) returns (Response); + rpc ManyToMany(stream Request) returns (stream Response); +} + +message Request { + oneof sealed_value { + Success success = 1; + Fail fail = 2; + Panic panic = 3; + } +} + +message Response { + oneof sealed_value { + Echo echo = 1; + } +} + +message Success { + string message = 1; + int32 count = 2; +} + +message Fail { + string message = 1; + int32 code = 2; + int32 after = 3; + bool outside = 4; +} + +message Panic { + string message = 1; + int32 after = 2; + bool outside = 3; +} + +message Echo { + string message = 1; +} diff --git a/kyo-grpc-e2e/shared/src/main/scala/kgrpc/TestServiceImpl.scala b/kyo-grpc-e2e/shared/src/main/scala/kgrpc/TestServiceImpl.scala new file mode 100644 index 000000000..882077db5 --- /dev/null +++ b/kyo-grpc-e2e/shared/src/main/scala/kgrpc/TestServiceImpl.scala @@ -0,0 +1,70 @@ +package kgrpc + +import io.grpc.Status +import kgrpc.test.* +import kgrpc.test.given +import kyo.* +import kyo.grpc.* + +object TestServiceImpl extends TestService: + + override def oneToOne(request: Request): Response < Grpc = + requestToResponse(request) + end oneToOne + + private def requestToResponse(request: Request): Response < Grpc = + request match + case Request.Empty => Abort.fail(Status.INVALID_ARGUMENT.asException) + case nonEmpty: Request.NonEmpty => + nonEmpty match + case Success(message, _, _) => Kyo.lift(Echo(message)) + case Fail(message, code, _, _, 
_) => Abort.fail(Status.fromCodeValue(code).withDescription(message).asException) + case Panic(message, _, _, _) => Abort.panic(new Exception(message)) + end match + end match + end requestToResponse + + override def oneToMany(request: Request): Stream[Response, Grpc] < Grpc = + requestToResponses(request) + end oneToMany + + private def requestToResponses(request: Request): Stream[Response, Grpc] < Grpc = + request match + case Request.Empty => Stream.empty[Response] + case nonEmpty: Request.NonEmpty => + nonEmpty match + case Success(message, count, _) => + stream((1 to count).map(n => Echo(s"$message $n"))) + case Fail(message, code, _, true, _) => + Abort.fail(Status.fromCodeValue(code).withDescription(message).asException) + case Fail(message, code, after, _, _) => + val echos = (after to 1 by -1).map(n => Kyo.lift(Echo(s"Failing in $n"))) + stream(echos :+ Abort.fail(Status.fromCodeValue(code).withDescription(message).asException)) + case Panic(message, _, true, _) => + Abort.panic(new Exception(message)) + case Panic(message, after, _, _) => + val echos = (after to 1 by -1).map(n => Kyo.lift(Echo(s"Panicing in $n"))) + stream(echos :+ Abort.panic(new Exception(message))) + end match + end match + end requestToResponses + + private def stream(responses: Seq[Response < Grpc]): Stream[Response, Grpc] = + Stream: + Kyo.foldLeft(responses)(()) { (_, response) => + response.map(r => Emit.value(Chunk(r))) + } + + override def manyToOne(requests: Stream[Request, Grpc]): Response < Grpc = + requests.fold(Maybe.empty[String])((acc, request) => + for + response <- requestToResponse(request) + nextAcc <- response.asNonEmpty.get match + case Echo(message, _) => acc.map(_ + " " + message).orElse(Maybe(message)) + yield nextAcc + ).map(maybeMessage => Echo(maybeMessage.getOrElse(""))) + + override def manyToMany(requests: Stream[Request, Grpc]): Stream[Response, Grpc] < Grpc = + requests.flatMap(requestToResponses) + +end TestServiceImpl diff --git 
a/kyo-grpc-e2e/shared/src/main/scala/kgrpc/test/Implicits.scala b/kyo-grpc-e2e/shared/src/main/scala/kgrpc/test/Implicits.scala new file mode 100644 index 000000000..0f710202e --- /dev/null +++ b/kyo-grpc-e2e/shared/src/main/scala/kgrpc/test/Implicits.scala @@ -0,0 +1,9 @@ +package kgrpc.test + +// Workaround for https://github.com/scalapb/ScalaPB/issues/1705. + +given CanEqual[RequestMessage.SealedValue.Empty.type, RequestMessage.SealedValue] = CanEqual.derived +given CanEqual[Request.Empty.type, Request] = CanEqual.derived + +given emptyCanEqualResponse: CanEqual[ResponseMessage.SealedValue.Empty.type, ResponseMessage.SealedValue] = CanEqual.derived +given CanEqual[Response.Empty.type, Response] = CanEqual.derived diff --git a/kyo-grpc-e2e/shared/src/test/resources/logback.xml b/kyo-grpc-e2e/shared/src/test/resources/logback.xml new file mode 100644 index 000000000..6bcce5cb6 --- /dev/null +++ b/kyo-grpc-e2e/shared/src/test/resources/logback.xml @@ -0,0 +1,13 @@ + + + + %d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n + + + + + + + + + diff --git a/kyo-grpc-e2e/shared/src/test/scala/kyo/grpc/ServiceTest.scala b/kyo-grpc-e2e/shared/src/test/scala/kyo/grpc/ServiceTest.scala new file mode 100644 index 000000000..2cb8dee85 --- /dev/null +++ b/kyo-grpc-e2e/shared/src/test/scala/kyo/grpc/ServiceTest.scala @@ -0,0 +1,564 @@ +package kyo.grpc + +import io.grpc.{Server as _, *} +import io.grpc.internal.GrpcUtil +import java.net.ServerSocket +import java.util.concurrent.TimeUnit +import kgrpc.* +import kgrpc.test.* +import kgrpc.test.TestService.* +import kyo.* +import kyo.grpc.* +import kyo.grpc.Equalities.given +import org.scalactic.Equality +import org.scalactic.TripleEquals.* +import org.scalatest.Inspectors.* +import scala.concurrent.Future +import scala.util.chaining.scalaUtilChainingOps + +class ServiceTest extends Test: + + private given CanEqual[Response, Echo] = CanEqual.derived + private given CanEqual[Status.Code, Status.Code] = CanEqual.derived + + 
private val emptyTrailers = Metadata() + // TODO: Test trailers. +// private val trailers = Metadata().tap(_.put(GrpcUtil.CONTENT_TYPE_KEY, GrpcUtil.CONTENT_TYPE_GRPC)) + + private val notOKStatusCodes = Status.Code.values().filterNot(_ === Status.Code.OK) + + "unary" - { + "success" in run { + val message = "Hello" + val request = Success(message) + for + client <- createClientAndServer + // TODO: Can we avoid the lift here? + response <- client.oneToOne(Kyo.lift(request)) + yield assert(response === Echo(message)) + end for + } + + "fail" in { + forEvery(notOKStatusCodes) { code => + run { + val message = "Yeah nah bro" + val status = code.toStatus.withDescription(message) + val request = Fail(message, status.getCode.value) + val expected = status.asException(emptyTrailers) + for + client <- createClientAndServer + // TODO: Can we avoid the lift here? + result <- Abort.run[StatusException](client.oneToOne(Kyo.lift(request))) + yield assertStatusException(result, expected) + end for + } + } + } + + "panic" in run { + val message = "Oh no!" + val request = Panic(message) + val expected = Status.UNKNOWN.asException(emptyTrailers) + for + client <- createClientAndServer + // TODO: Can we avoid the lift here? + result <- Abort.run[StatusException](client.oneToOne(Kyo.lift(request))) + yield + assertStatusException(result, expected) + // Do not expose the internal error message + val actual = result.failure.get + assert(actual.getMessage === "UNKNOWN") + assert(actual.getCause === null) + assert(actual.getStatus().getCause === null) + end for + } + } + + "server streaming" - { + "success" in run { + val message = "Hello" + val request = Success(message, count = 5) + for + client <- createClientAndServer + // TODO: Can we avoid the lift here? 
+ responses <- + println(client) + client.oneToMany(Kyo.lift(request)) + .run + yield assert(responses == Chunk.from((1 to 5).map(n => Echo(s"$message $n")))) + end for + } + + "fail" - { + "producing stream" in { +// forEvery(notOKStatusCodes) { code => + val code = Status.Code.INTERNAL + run { + val message = "Yeah nah bro" + val status = code.toStatus.withDescription(message) + val request = Fail(message, status.getCode.value, outside = true) + val expected = status.asException(emptyTrailers) + for + client <- createClientAndServer + // TODO: Can we avoid the lift here? + response <- Abort.run[StatusException](client.oneToMany(Kyo.lift(request)).take(1).run) + yield assertStatusException(response, expected) + end for + } +// } + } + + "first element" in { + forEvery(notOKStatusCodes) { code => + run { + val message = "Yeah nah bro" + val status = code.toStatus.withDescription(message) + val request = Fail(message, status.getCode.value) + val expected = status.asException(emptyTrailers) + for + client <- createClientAndServer + // TODO: Can we avoid the lift here? + result <- Abort.run[StatusException](client.oneToMany(Kyo.lift(request)).take(1).run) + yield assertStatusException(result, expected) + end for + } + } + } + + "after some elements" in { + forEvery(notOKStatusCodes) { code => + run { + val message = "Yeah nah bro" + val status = code.toStatus.withDescription(message) + val after = 5 + val request = Fail(message, status.getCode.value, after) + val expected = status.asException(emptyTrailers) + for + client <- createClientAndServer + // TODO: Can we avoid the lift here? + (responses, tail) <- client.oneToMany(Kyo.lift(request)).splitAt(5) + failedResponse <- Abort.run[StatusException](tail.run) + yield + assert(responses == Chunk.from((after to 1 by -1).map(n => Echo(s"Failing in $n")))) + assertStatusException(failedResponse, expected) + end for + } + } + } + } + + "panic" - { + "producing stream" in run { + val message = "Oh no!" 
+ val request = Panic(message) + val expected = Status.UNKNOWN.asException(emptyTrailers) + for + client <- createClientAndServer + // TODO: Can we avoid the lift here? + result <- Abort.run[StatusException](client.oneToMany(Kyo.lift(request)).take(1).run) + yield + assertStatusException(result, expected) + // Do not expose the internal error message + val actual = result.failure.get + assert(actual.getMessage === "UNKNOWN") + assert(actual.getCause === null) + assert(actual.getStatus().getCause === null) + end for + } + + "first element" in run { + val message = "Oh no!" + val request = Panic(message) + val expected = Status.UNKNOWN.asException(emptyTrailers) + for + client <- createClientAndServer + // TODO: Can we avoid the lift here? + result <- Abort.run[StatusException](client.oneToMany(Kyo.lift(request)).take(1).run) + yield + assertStatusException(result, expected) + // Do not expose the internal error message + val actual = result.failure.get + assert(actual.getMessage === "UNKNOWN") + assert(actual.getCause === null) + assert(actual.getStatus().getCause === null) + end for + } + + "after some elements" in run { + val message = "Oh no!" + val after = 5 + val request = Panic(message, after) + val expected = Status.UNKNOWN.asException(emptyTrailers) + for + client <- createClientAndServer + // TODO: Can we avoid the lift here? 
+                    (responses, tail) <- client.oneToMany(Kyo.lift(request)).splitAt(5)
+                    failedResponse    <- Abort.run[StatusException](tail.run)
+                yield
+                    // NOTE(review): assumes the server echoes in *descending* order before panicking,
+                    // and the "Panicing" spelling must match TestServiceImpl's message — confirm against
+                    // the server implementation (fix the spelling in both places together, if at all).
+                    assert(responses == Chunk.from((after to 1 by -1).map(n => Echo(s"Panicing in $n"))))
+                    assertStatusException(failedResponse, expected)
+                    // Do not expose the internal error message
+                    val actual = failedResponse.failure.get
+                    assert(actual.getMessage === "UNKNOWN")
+                    assert(actual.getCause === null)
+                    assert(actual.getStatus().getCause === null)
+                end for
+            }
+        }
+    }
+
+    // Client-streaming calls: the client sends a stream of requests, the server replies with one Echo.
+    "client streaming" - {
+        // An empty request stream still yields a (default) response rather than an error.
+        "empty" in run {
+            val successes = Chunk.empty[Request]
+            val requests  = Stream(Emit.value(successes))
+            for
+                client <- createClientAndServer
+                // TODO: Can we avoid the lift here?
+                response <- client.manyToOne(Kyo.lift(requests))
+            yield assert(response === Echo())
+            end for
+        }
+
+        // All request payloads are combined into a single space-separated Echo.
+        "success" in run {
+            val successes = Chunk.from((1 to 5).map(n => Success(n.toString): Request))
+            val requests  = Stream(Emit.value(successes))
+            for
+                client <- createClientAndServer
+                // TODO: Can we avoid the lift here?
+                response <- client.manyToOne(Kyo.lift(requests))
+            yield assert(response === Echo((1 to 5).mkString(" ")))
+            end for
+        }
+
+        // Expected (typed) server failures surface to the client as a StatusException
+        // carrying the original status code and description, for every non-OK code.
+        "fail" - {
+            "first element" in {
+                forEvery(notOKStatusCodes) { code =>
+                    run {
+                        val message   = "Yeah nah bro"
+                        val status    = code.toStatus.withDescription(message)
+                        val fail      = Fail(message, status.getCode.value)
+                        val successes = Chunk.from((1 to 5).map(n => Success(n.toString): Request))
+                        val requests  = Stream(Emit.value(Chunk(fail).concat(successes)))
+                        val expected  = status.asException(emptyTrailers)
+                        for
+                            client <- createClientAndServer
+                            // TODO: Can we avoid the lift here?
+                            result <- Abort.run[StatusException](client.manyToOne(Kyo.lift(requests)))
+                        yield assertStatusException(result, expected)
+                        end for
+                    }
+                }
+            }
+
+            "after some elements" in {
+                forEvery(notOKStatusCodes) { code =>
+                    run {
+                        val message   = "Yeah nah bro"
+                        val status    = code.toStatus.withDescription(message)
+                        val after     = 5
+                        val successes = Chunk.from((1 to after).map(n => Success(n.toString): Request))
+                        val fail      = Fail(message, status.getCode.value)
+                        val requests  = Stream(Emit.value(successes.append(fail)))
+                        val expected  = status.asException(emptyTrailers)
+                        for
+                            client <- createClientAndServer
+                            // TODO: Can we avoid the lift here?
+                            result <- Abort.run[StatusException](client.manyToOne(Kyo.lift(requests)))
+                        yield assertStatusException(result, expected)
+                        end for
+                    }
+                }
+            }
+        }
+
+        // Unexpected (untyped) server failures map to UNKNOWN and must not leak the cause or message.
+        "panic" - {
+            "first element" in run {
+                val message   = "Oh no!"
+                val panic     = Panic(message)
+                val successes = Chunk.from((1 to 5).map(n => Success(n.toString): Request))
+                val requests  = Stream(Emit.value(Chunk(panic).concat(successes)))
+                val expected  = Status.UNKNOWN.asException(emptyTrailers)
+                for
+                    client <- createClientAndServer
+                    // TODO: Can we avoid the lift here?
+                    result <- Abort.run[StatusException](client.manyToOne(Kyo.lift(requests)))
+                yield
+                    assertStatusException(result, expected)
+                    // Do not expose the internal error message
+                    val actual = result.failure.get
+                    assert(actual.getMessage === "UNKNOWN")
+                    assert(actual.getCause === null)
+                    assert(actual.getStatus().getCause === null)
+                end for
+            }
+
+            "after some elements" in run {
+                val message   = "Oh no!"
+                val after     = 5
+                val panic     = Panic(message)
+                val successes = Chunk.from((1 to after).map(n => Success(n.toString): Request))
+                val requests  = Stream(Emit.value(successes.append(panic)))
+                val expected  = Status.UNKNOWN.asException(emptyTrailers)
+                for
+                    client <- createClientAndServer
+                    // TODO: Can we avoid the lift here?
+                    result <- Abort.run[StatusException](client.manyToOne(Kyo.lift(requests)))
+                yield
+                    assertStatusException(result, expected)
+                    // Do not expose the internal error message
+                    val actual = result.failure.get
+                    assert(actual.getMessage === "UNKNOWN")
+                    assert(actual.getCause === null)
+                    assert(actual.getStatus().getCause === null)
+                end for
+            }
+        }
+    }
+
+    // Bidirectional-streaming calls: requests and responses are both streamed.
+    "bidirectional streaming" - {
+        "empty" in run {
+            val successes = Chunk.empty[Request]
+            val requests  = Stream(Emit.value(successes))
+            for
+                client <- createClientAndServer
+                // TODO: Can we avoid the lift here?
+                responses <- client.manyToMany(Kyo.lift(requests)).run
+            yield assert(responses == Chunk.empty)
+            end for
+        }
+
+        // Each Success(n, count = n - 2) yields (n - 2) Echo("$n $m") responses,
+        // so only n = 3..5 contribute (per the expected value below).
+        "success" in run {
+            val successes = Chunk.from((1 to 5).map(n => Success(n.toString, count = n - 2): Request))
+            val expected  = Chunk.from((3 to 5).flatMap(n => Chunk.from((1 to (n - 2)).map(m => Echo(s"$n $m")))))
+            val requests  = Stream(Emit.value(successes))
+            for
+                client <- createClientAndServer
+                // TODO: Can we avoid the lift here?
+                responses <- client.manyToMany(Kyo.lift(requests)).run
+            yield assert(responses == expected)
+            end for
+        }
+
+        "fail" - {
+            // `outside = true` makes the server fail while producing the response stream itself.
+            "producing stream on first element" in {
+                forEvery(notOKStatusCodes) { code =>
+                    run {
+                        val message   = "Yeah nah bro"
+                        val status    = code.toStatus.withDescription(message)
+                        val fail      = Fail(message, status.getCode.value, outside = true)
+                        val successes = Chunk.from((1 to 5).map(n => Success(n.toString, count = 1): Request))
+                        val requests  = Stream(Emit.value(Chunk(fail).concat(successes)))
+                        val expected  = status.asException(emptyTrailers)
+                        for
+                            client <- createClientAndServer
+                            // TODO: Can we avoid the lift here?
+                            response <- Abort.run[StatusException](client.manyToMany(Kyo.lift(requests)).take(1).run)
+                        yield assertStatusException(response, expected)
+                        end for
+                    }
+                }
+            }
+
+            // The server fails part-way through producing the stream: the first `after`
+            // responses arrive intact, then the tail fails with the expected status.
+            "producing stream after some elements" in {
+                forEvery(notOKStatusCodes) { code =>
+                    run {
+                        val message   = "Yeah nah bro"
+                        val status    = code.toStatus.withDescription(message)
+                        val after     = 5
+                        val successes = Chunk.from((1 to after).map(n => Success(n.toString, count = 1): Request))
+                        val fail      = Fail(message, status.getCode.value, outside = true)
+                        val requests  = Stream(Emit.value(successes.append(fail)))
+                        val expected  = status.asException(emptyTrailers)
+                        for
+                            client <- createClientAndServer
+                            // TODO: Can we avoid the lift here?
+                            // NOTE(review): `splitAt(5)` hard-codes the value of `after`; prefer `splitAt(after)`.
+                            (responses, tail) <- client.manyToMany(Kyo.lift(requests)).splitAt(5)
+                            failedResponse    <- Abort.run[StatusException](tail.run)
+                        yield
+                            assert(responses == Chunk.from((1 to after).map(n => Echo(s"$n 1"))))
+                            assertStatusException(failedResponse, expected)
+                        end for
+                    }
+                }
+            }
+
+            "first element" in {
+                forEvery(notOKStatusCodes) { code =>
+                    run {
+                        val message   = "Yeah nah bro"
+                        val status    = code.toStatus.withDescription(message)
+                        val fail      = Fail(message, status.getCode.value)
+                        val successes = Chunk.from((1 to 5).map(n => Success(n.toString, count = 1): Request))
+                        val requests  = Stream(Emit.value(Chunk(fail).concat(successes)))
+                        val expected  = status.asException(emptyTrailers)
+                        for
+                            client <- createClientAndServer
+                            // TODO: Can we avoid the lift here?
+                            result <- Abort.run[StatusException](client.manyToMany(Kyo.lift(requests)).take(1).run)
+                        yield assertStatusException(result, expected)
+                        end for
+                    }
+                }
+            }
+
+            "after some elements" in {
+                forEvery(notOKStatusCodes) { code =>
+                    run {
+                        val message   = "Yeah nah bro"
+                        val status    = code.toStatus.withDescription(message)
+                        val after     = 5
+                        val successes = Chunk.from((1 to after).map(n => Success(n.toString, count = 1): Request))
+                        val fail      = Fail(message, status.getCode.value)
+                        val requests  = Stream(Emit.value(successes.append(fail)))
+                        val expected  = status.asException(emptyTrailers)
+                        for
+                            client <- createClientAndServer
+                            // TODO: Can we avoid the lift here?
+                            // NOTE(review): `splitAt(5)` hard-codes the value of `after`; prefer `splitAt(after)`.
+                            (responses, tail) <- client.manyToMany(Kyo.lift(requests)).splitAt(5)
+                            failedResponse    <- Abort.run[StatusException](tail.run)
+                        yield
+                            assert(responses == Chunk.from((1 to after).map(n => Echo(s"$n 1"))))
+                            assertStatusException(failedResponse, expected)
+                        end for
+                    }
+                }
+            }
+        }
+
+        // Unexpected failures during bidirectional streaming map to UNKNOWN without leaking the cause.
+        "panic" - {
+            "producing stream on first element" in {
+                run {
+                    val message   = "Oh no!"
+                    val panic     = Panic(message)
+                    val successes = Chunk.from((1 to 5).map(n => Success(n.toString, count = 1): Request))
+                    val requests  = Stream(Emit.value(Chunk(panic).concat(successes)))
+                    val expected  = Status.UNKNOWN.asException(emptyTrailers)
+                    for
+                        client <- createClientAndServer
+                        // TODO: Can we avoid the lift here?
+                        response <- Abort.run[StatusException](client.manyToMany(Kyo.lift(requests)).take(1).run)
+                    yield
+                        assertStatusException(response, expected)
+                        // Do not expose the internal error message
+                        val actual = response.failure.get
+                        assert(actual.getMessage === "UNKNOWN")
+                        assert(actual.getCause === null)
+                        assert(actual.getStatus().getCause === null)
+                    end for
+                }
+            }
+
+            "producing stream after some elements" in {
+                run {
+                    val after     = 5
+                    val successes = Chunk.from((1 to after).map(n => Success(n.toString, count = 1): Request))
+                    val message   = "Oh no!"
+                    val panic     = Panic(message)
+                    val requests  = Stream(Emit.value(successes.append(panic)))
+                    val expected  = Status.UNKNOWN.asException(emptyTrailers)
+                    for
+                        client <- createClientAndServer
+                        // TODO: Can we avoid the lift here?
+                        (responses, tail) <- client.manyToMany(Kyo.lift(requests)).splitAt(after)
+                        failedResponse    <- Abort.run[StatusException](tail.run)
+                    yield
+                        assert(responses == Chunk.from((1 to after).map(n => Echo(s"$n 1"))))
+                        assertStatusException(failedResponse, expected)
+                        // Do not expose the internal error message
+                        val actual = failedResponse.failure.get
+                        assert(actual.getMessage === "UNKNOWN")
+                        assert(actual.getCause === null)
+                        assert(actual.getStatus().getCause === null)
+                    end for
+                }
+            }
+
+            "first element" in {
+                run {
+                    val message   = "Oh no!"
+                    val panic     = Panic(message)
+                    val successes = Chunk.from((1 to 5).map(n => Success(n.toString, count = 1): Request))
+                    val requests  = Stream(Emit.value(Chunk(panic).concat(successes)))
+                    val expected  = Status.UNKNOWN.asException(emptyTrailers)
+                    for
+                        client <- createClientAndServer
+                        // TODO: Can we avoid the lift here?
+                        result <- Abort.run[StatusException](client.manyToMany(Kyo.lift(requests)).take(1).run)
+                    yield
+                        assertStatusException(result, expected)
+                        // Do not expose the internal error message
+                        val actual = result.failure.get
+                        assert(actual.getMessage === "UNKNOWN")
+                        assert(actual.getCause === null)
+                        assert(actual.getStatus().getCause === null)
+                    end for
+                }
+            }
+
+            "after some elements" in {
+                run {
+                    val after     = 5
+                    val successes = Chunk.from((1 to after).map(n => Success(n.toString, count = 1): Request))
+                    val message   = "Oh no!"
+                    val panic     = Panic(message)
+                    val requests  = Stream(Emit.value(successes.append(panic)))
+                    val expected  = Status.UNKNOWN.asException(emptyTrailers)
+                    for
+                        client <- createClientAndServer
+                        // TODO: Can we avoid the lift here?
+                        // NOTE(review): `splitAt(5)` hard-codes the value of `after`; prefer `splitAt(after)`.
+                        (responses, tail) <- client.manyToMany(Kyo.lift(requests)).splitAt(5)
+                        failedResponse    <- Abort.run[StatusException](tail.run)
+                    yield
+                        assert(responses == Chunk.from((1 to after).map(n => Echo(s"$n 1"))))
+                        assertStatusException(failedResponse, expected)
+                        // Do not expose the internal error message
+                        val actual = failedResponse.failure.get
+                        assert(actual.getMessage === "UNKNOWN")
+                        assert(actual.getCause === null)
+                        assert(actual.getStatus().getCause === null)
+                    end for
+                }
+            }
+        }
+    }
+
+    // Asserts that `result` is a typed failure carrying exactly the expected StatusException.
+    private def assertStatusException(result: Result[StatusException, Any], expected: StatusException) =
+        assert(result.isError)
+        assert(result.isFailure)
+        val actual = result.failure.get
+        assert(actual === expected)
+    end assertStatusException
+
+    // Starts a server on a free port and returns a client connected to it; both are Scope-managed.
+    private def createClientAndServer: Client < (Scope & Sync) =
+        for
+            port   <- findFreePort
+            _      <- createServer(port)
+            client <- createClient(port)
+        yield client
+
+    private def createServer(port: Int): Server < (Scope & Sync) =
+        Server.start(port)(_.addService(TestServiceImpl.definition))
+
+    private def createClient(port: Int): Client < (Scope & Sync) =
+        createChannel(port).map(TestService.client(_))
+
+    // Plaintext channel to localhost; shut down (waiting up to 10s) when the scope closes.
+    private def createChannel(port: Int) =
+        Scope.acquireRelease(
+            Sync.defer(
+                ManagedChannelBuilder
+                    .forAddress("localhost", port)
+                    .usePlaintext()
+                    .build()
+            )
+        ) { channel =>
+            Sync.defer(channel.shutdownNow().awaitTermination(10, TimeUnit.SECONDS)).unit
+        }
+
+    // Binds an ephemeral ServerSocket to discover a free port, closing the socket afterwards.
+    // NOTE(review): another process could grab the port before the server binds it (small race) —
+    // acceptable for tests, but worth knowing about if these become flaky.
+    private def findFreePort =
+        for
+            socket <- Sync.defer(new ServerSocket(0))
+            port   <- Sync.ensure(Sync.defer(socket.close()))(socket.getLocalPort)
+        yield port
+
+end ServiceTest
diff --git a/kyo-grpc-e2e/shared/src/test/scala/kyo/grpc/Test.scala b/kyo-grpc-e2e/shared/src/test/scala/kyo/grpc/Test.scala
new file mode 100644
index 000000000..371c476d2
--- /dev/null
+++ b/kyo-grpc-e2e/shared/src/test/scala/kyo/grpc/Test.scala
@@ -0,0 +1,17 @@
+package kyo.grpc
+
+import kyo.internal.BaseKyoCoreTest
+import kyo.internal.Platform
+import
org.scalatest.NonImplicitAssertions
+import org.scalatest.freespec.AsyncFreeSpec
+import scala.concurrent.ExecutionContext
+import scala.language.implicitConversions
+
+// Base class for kyo-grpc e2e suites: an async ScalaTest FreeSpec wired to Kyo's
+// platform ExecutionContext via BaseKyoCoreTest.
+// NOTE(review): adding lines here requires bumping the hunk header count above.
+abstract class Test extends AsyncFreeSpec with NonImplicitAssertions with BaseKyoCoreTest:
+
+    type Assertion = org.scalatest.compatible.Assertion
+    def assertionSuccess              = succeed
+    def assertionFailure(msg: String) = fail(msg)
+
+    override given executionContext: ExecutionContext = Platform.executionContext
+end Test
diff --git a/project/plugins.sbt b/project/plugins.sbt
index e15e2ba5c..99e2fbf51 100644
--- a/project/plugins.sbt
+++ b/project/plugins.sbt
@@ -22,6 +22,13 @@ addSbtPlugin("ch.epfl.scala" % "sbt-scalafix" % "0.14.5")
 
 // addSbtPlugin("com.github.sbt" % "sbt-jacoco" % "3.4.0")
 
+addSbtPlugin("com.eed3si9n" % "sbt-buildinfo" % "0.12.0")
+addSbtPlugin("com.thesamet" % "sbt-protoc" % "1.0.7")
+addSbtPlugin("com.thesamet" % "sbt-protoc-gen-project" % "0.1.8")
+addSbtPlugin("org.typelevel" % "sbt-fs2-grpc" % "2.7.21")
+
 libraryDependencies ++= Seq(
-    "org.typelevel" %% "scalac-options" % "0.1.8"
+    "com.thesamet.scalapb" %% "compilerplugin" % "0.11.17",
+    "com.thesamet.scalapb.zio-grpc" %% "zio-grpc-codegen" % "0.6.3",
+    "org.typelevel" %% "scalac-options" % "0.1.8"
 )