From 6d96d1666cc177355835913124abccc67f775946 Mon Sep 17 00:00:00 2001 From: Nabil Abdel-Hafeez <7283535+987Nabil@users.noreply.github.com> Date: Thu, 9 Jan 2025 21:31:48 +0100 Subject: [PATCH 1/2] Schema based header codecs, unified with query codecs (#3232) Fix for publish CI job --- build.sbt | 1 + project/MimaSettings.scala | 11 + .../ServerInboundHandlerBenchmark.scala | 2 +- .../zio/http/endpoint/cli/CliEndpoint.scala | 39 +- .../zio/http/endpoint/cli/HttpOptions.scala | 8 +- .../scala/zio/http/endpoint/cli/AuxGen.scala | 6 +- .../scala/zio/http/endpoint/cli/CliSpec.scala | 2 +- .../zio/http/endpoint/cli/CommandGen.scala | 18 +- .../zio/http/endpoint/cli/EndpointGen.scala | 13 +- .../zio/http/endpoint/cli/OptionsGen.scala | 33 +- .../zio/http/gen/scala/CodeGenSpec.scala | 2 +- .../scala/zio/http/endpoint/AuthSpec.scala | 6 +- .../scala/zio/http/endpoint/HeaderSpec.scala | 204 ++++++ .../http/endpoint/QueryParameterSpec.scala | 50 +- .../scala/zio/http/endpoint/RequestSpec.scala | 54 +- .../endpoint/openapi/OpenAPIGenSpec.scala | 135 +++- .../src/main/scala/zio/http/Header.scala | 1 + .../scala/zio/http/codec/HeaderCodecs.scala | 34 +- .../main/scala/zio/http/codec/HttpCodec.scala | 354 ++++++---- .../scala/zio/http/codec/HttpCodecError.scala | 10 + .../scala/zio/http/codec/QueryCodecs.scala | 126 +--- .../scala/zio/http/codec/StringCodec.scala | 394 +++++++++++ .../zio/http/codec/TextBinaryCodec.scala | 8 +- .../zio/http/codec/internal/Atomized.scala | 30 +- .../http/codec/internal/AtomizedCodecs.scala | 7 +- .../http/codec/internal/EncoderDecoder.scala | 613 +++++++++++------- .../zio/http/endpoint/http/HttpGen.scala | 24 +- .../http/endpoint/openapi/JsonSchema.scala | 20 +- .../http/endpoint/openapi/OpenAPIGen.scala | 88 ++- .../zio/http/internal/HeaderGetters.scala | 7 + 30 files changed, 1688 insertions(+), 612 deletions(-) create mode 100644 zio-http/jvm/src/test/scala/zio/http/endpoint/HeaderSpec.scala create mode 100644 zio-http/shared/src/main/scala/zio/http/codec/StringCodec.scala diff --git a/build.sbt b/build.sbt index a19027bdb9..4e5199a9dd 100644 --- a/build.sbt +++ b/build.sbt @@ -59,6 +59,7 @@ ThisBuild / githubWorkflowPublishTargetBranches += RefPredicate.StartsWith(Ref.T ThisBuild / githubWorkflowPublishPreamble := Seq(coursierSetup) ThisBuild / githubWorkflowPublish := Seq( + WorkflowStep.Use(UseRef.Public("coursier", "setup-action", "v1"), Map("apps" -> "sbt")), WorkflowStep.Sbt( List("ci-release"), name = Some("Release"), diff --git a/project/MimaSettings.scala b/project/MimaSettings.scala index 1edcce1f98..fd62bacb55 100644 --- a/project/MimaSettings.scala +++ b/project/MimaSettings.scala @@ -12,6 +12,17 @@ object MimaSettings { mimaBinaryIssueFilters ++= Seq( exclude[Problem]("zio.http.internal.*"), exclude[Problem]("zio.http.codec.internal.*"), + exclude[Problem]("zio.http.codec.HttpCodec$Query$QueryType$Record$"), + exclude[Problem]("zio.http.codec.HttpCodec$Query$QueryType$Record"), + exclude[Problem]("zio.http.codec.HttpCodec$Query$QueryType$Primitive$"), + exclude[Problem]("zio.http.codec.HttpCodec$Query$QueryType$Primitive"), + exclude[Problem]("zio.http.codec.HttpCodec$Query$QueryType$Collection$"), + exclude[Problem]("zio.http.codec.HttpCodec$Query$QueryType$Collection"), + exclude[Problem]("zio.http.codec.HttpCodec$Query$QueryType$"), + exclude[Problem]("zio.http.codec.HttpCodec$Query$QueryType"), + exclude[Problem]("zio.http.endpoint.openapi.OpenAPIGen#AtomizedMetaCodecs.apply"), + exclude[Problem]("zio.http.endpoint.openapi.OpenAPIGen#AtomizedMetaCodecs.this"), + exclude[Problem]("zio.http.endpoint.openapi.OpenAPIGen#AtomizedMetaCodecs.copy"), ), mimaFailOnProblem := failOnProblem ) diff --git a/zio-http-benchmarks/src/main/scala/zhttp.benchmarks/ServerInboundHandlerBenchmark.scala b/zio-http-benchmarks/src/main/scala/zhttp.benchmarks/ServerInboundHandlerBenchmark.scala index d28c665e5e..8f594614c6 100644 --- a/zio-http-benchmarks/src/main/scala/zhttp.benchmarks/ServerInboundHandlerBenchmark.scala +++ b/zio-http-benchmarks/src/main/scala/zhttp.benchmarks/ServerInboundHandlerBenchmark.scala @@ -18,7 +18,7 @@ class ServerInboundHandlerBenchmark { private val largeString = random.alphanumeric.take(100000).mkString private val baseUrl = "http://localhost:8080" - private val headers = Headers(Header.ContentType(MediaType.text.`plain`).untyped) + private val headers = Headers(Header.ContentType(MediaType.text.`plain`)) private val arrayEndpoint = "array" private val arrayResponse = ZIO.succeed( diff --git a/zio-http-cli/src/main/scala/zio/http/endpoint/cli/CliEndpoint.scala b/zio-http-cli/src/main/scala/zio/http/endpoint/cli/CliEndpoint.scala index 566289ea37..3228b85c32 100644 --- a/zio-http-cli/src/main/scala/zio/http/endpoint/cli/CliEndpoint.scala +++ b/zio-http-cli/src/main/scala/zio/http/endpoint/cli/CliEndpoint.scala @@ -1,7 +1,6 @@ package zio.http.endpoint.cli import zio.http._ -import zio.http.codec.HttpCodec.Query.QueryType import zio.http.codec._ import zio.http.endpoint._ @@ -112,13 +111,11 @@ private[cli] object CliEndpoint { } CliEndpoint(body = HttpOptions.Body(name, codec.defaultMediaType, codec.defaultSchema) :: List()) - case HttpCodec.Header(name, textCodec, _) if textCodec.isInstanceOf[TextCodec.Constant] => - CliEndpoint(headers = - HttpOptions.HeaderConstant(name, textCodec.asInstanceOf[TextCodec.Constant].string) :: List(), - ) - case HttpCodec.Header(name, textCodec, _) => - CliEndpoint(headers = HttpOptions.Header(name, textCodec) :: List()) - case HttpCodec.Method(codec, _) => + case HttpCodec.Header(headerType, _) => + CliEndpoint(headers = HttpOptions.Header(headerType.name, TextCodec.string) :: List()) + case HttpCodec.HeaderCustom(codec, _) => + CliEndpoint(headers = HttpOptions.Header(codec.name.get, TextCodec.string) :: List()) + case HttpCodec.Method(codec, _) => codec.asInstanceOf[SimpleCodec[_, _]] match { case SimpleCodec.Specified(method: Method) => CliEndpoint(methods = method) @@ -128,22 +125,16 @@ private[cli] object CliEndpoint { case HttpCodec.Path(pathCodec, _) => CliEndpoint(url = HttpOptions.Path(pathCodec) :: List()) - case HttpCodec.Query(queryType, _) => - queryType match { - case QueryType.Primitive(name, codec) => - CliEndpoint(url = HttpOptions.Query(name, codec) :: List()) - case record @ QueryType.Record(_) => - val queryOptions = record.fieldAndCodecs.map { case (field, codec) => - HttpOptions.Query(field.name, codec) - } - CliEndpoint(url = queryOptions.toList) - case QueryType.Collection(_, elements, _) => - val queryOptions = - HttpOptions.Query(elements.name, elements.codec) - CliEndpoint(url = queryOptions :: List()) - } - - case HttpCodec.Status(_, _) => CliEndpoint.empty + case HttpCodec.Query(codec, _) => + if (codec.isPrimitive) + CliEndpoint(url = HttpOptions.Query(codec) :: List()) + else if (codec.isRecord) + CliEndpoint(url = codec.recordFields.map { case (_, codec) => + HttpOptions.Query(codec) + }.toList) + else + CliEndpoint(url = HttpOptions.Query(codec) :: List()) + case HttpCodec.Status(_, _) => CliEndpoint.empty } } diff --git a/zio-http-cli/src/main/scala/zio/http/endpoint/cli/HttpOptions.scala b/zio-http-cli/src/main/scala/zio/http/endpoint/cli/HttpOptions.scala index 191194864d..2abb5704b7 100644 --- a/zio-http-cli/src/main/scala/zio/http/endpoint/cli/HttpOptions.scala +++ b/zio-http-cli/src/main/scala/zio/http/endpoint/cli/HttpOptions.scala @@ -11,6 +11,7 @@ import zio.schema._ import zio.schema.annotation.description import zio.http._ +import zio.http.codec.HttpCodec.SchemaCodec import zio.http.codec._ /* @@ -264,10 +265,9 @@ private[cli] object HttpOptions { } - final case class Query(override val name: String, codec: BinaryCodecWithSchema[_], doc: Doc = Doc.empty) - extends URLOptions { + final case class Query(codec: SchemaCodec[_], doc: Doc = Doc.empty) extends URLOptions { self => - + override val name = codec.name.get override val tag = "?" + name def options: Options[_] = optionsFromSchema(codec)(name) @@ -293,7 +293,7 @@ private[cli] object HttpOptions { } - private[cli] def optionsFromSchema[A](codec: BinaryCodecWithSchema[A]): String => Options[A] = + private[cli] def optionsFromSchema[A](codec: SchemaCodec[A]): String => Options[A] = codec.schema match { case Schema.Primitive(standardType, _) => standardType match { diff --git a/zio-http-cli/src/test/scala/zio/http/endpoint/cli/AuxGen.scala b/zio-http-cli/src/test/scala/zio/http/endpoint/cli/AuxGen.scala index 8fdb42d863..06c365fc0e 100644 --- a/zio-http-cli/src/test/scala/zio/http/endpoint/cli/AuxGen.scala +++ b/zio-http-cli/src/test/scala/zio/http/endpoint/cli/AuxGen.scala @@ -16,11 +16,7 @@ import zio.http.codec._ */ object AuxGen { - lazy val anyTextCodec: Gen[Any, TextCodec[_]] = - Gen.oneOf( - Gen.fromIterable(List(TextCodec.boolean, TextCodec.int, TextCodec.string, TextCodec.uuid)), - Gen.alphaNumericStringBounded(1, 30).map(TextCodec.constant(_)), - ) + lazy val anyTextCodec: Gen[Any, TextCodec[_]] = Gen.const(TextCodec.string) lazy val anyMediaType: Gen[Any, MediaType] = Gen.fromIterable(MediaType.allMediaTypes) diff --git a/zio-http-cli/src/test/scala/zio/http/endpoint/cli/CliSpec.scala b/zio-http-cli/src/test/scala/zio/http/endpoint/cli/CliSpec.scala index 5153f1ab47..f26145431c 100644 --- a/zio-http-cli/src/test/scala/zio/http/endpoint/cli/CliSpec.scala +++ b/zio-http-cli/src/test/scala/zio/http/endpoint/cli/CliSpec.scala @@ -27,7 +27,7 @@ object CliSpec extends ZIOSpecDefault { val bodyStream = ContentCodec.contentStream[BigInt]("bodyStream") - val headerCodec = HttpCodec.Header("header", TextCodec.string) + val headerCodec = HttpCodec.headerAs[String]("header") val path1 = PathCodec.bool("path1") diff --git a/zio-http-cli/src/test/scala/zio/http/endpoint/cli/CommandGen.scala b/zio-http-cli/src/test/scala/zio/http/endpoint/cli/CommandGen.scala index ad99897218..9d84e04b9f 100644 --- a/zio-http-cli/src/test/scala/zio/http/endpoint/cli/CommandGen.scala +++ b/zio-http-cli/src/test/scala/zio/http/endpoint/cli/CommandGen.scala @@ -47,20 +47,20 @@ object CommandGen { case _: HttpOptions.Constant => false case _ => true }.map { - case HttpOptions.Path(pathCodec, _) => - pathCodec.segments.toList.flatMap { case segment => + case HttpOptions.Path(pathCodec, _) => + pathCodec.segments.toList.flatMap { segment => getSegment(segment) match { case (_, "") => Nil case (name, "boolean") => s"[${getName(name, "")}]" :: Nil case (name, codec) => s"${getName(name, "")} $codec" :: Nil } } - case HttpOptions.Query(name, codec, _) => - getType(codec) match { - case "" => s"[${getName(name, "")}]" :: Nil - case codec => s"${getName(name, "")} $codec" :: Nil + case HttpOptions.Query(codec, _) if codec.isPrimitive => + getType(codec.schema) match { + case "" => s"[${getName(codec.name.get, "")}]" :: Nil + case tpy => s"${getName(codec.name.get, "")} $tpy" :: Nil } - case _ => Nil + case _ => Nil }.foldRight(List[String]())(_ ++ _) val headersOptions = cliEndpoint.headers.filter { @@ -121,8 +121,8 @@ object CommandGen { case _ => "" } - def getType[A](codec: BinaryCodecWithSchema[A]): String = - codec.schema match { + def getType[A](schema: Schema[A]): String = + schema match { case Schema.Primitive(standardType, _) => standardType match { case StandardType.UnitType => "" diff --git a/zio-http-cli/src/test/scala/zio/http/endpoint/cli/EndpointGen.scala b/zio-http-cli/src/test/scala/zio/http/endpoint/cli/EndpointGen.scala index d868a86cca..792cbdb2f7 100644 --- a/zio-http-cli/src/test/scala/zio/http/endpoint/cli/EndpointGen.scala +++ b/zio-http-cli/src/test/scala/zio/http/endpoint/cli/EndpointGen.scala @@ -5,7 +5,9 @@ import zio.test._ import zio.schema.Schema +import zio.http.Header.HeaderType import zio.http._ +import zio.http.codec.HttpCodec.SchemaCodec import zio.http.codec._ import zio.http.endpoint._ import zio.http.endpoint.cli.AuxGen._ @@ -78,10 +80,9 @@ object EndpointGen { lazy val anyHeader: Gen[Any, CliReprOf[Codec[_]]] = Gen.alphaNumericStringBounded(1, 30).zip(anyTextCodec).map { case (name, codec) => CliRepr( - HttpCodec.Header(name, codec), + HttpCodec.Header(Header.Custom(name, "").headerType), // todo use schema bases header codec match { - case TextCodec.Constant(value) => CliEndpoint(headers = HttpOptions.HeaderConstant(name, value) :: Nil) - case _ => CliEndpoint(headers = HttpOptions.Header(name, codec) :: Nil) + case _ => CliEndpoint(headers = HttpOptions.Header(name, codec) :: Nil) }, ) } @@ -102,10 +103,10 @@ object EndpointGen { lazy val anyQuery: Gen[Any, CliReprOf[Codec[_]]] = Gen.alphaNumericStringBounded(1, 30).zip(anyStandardType).map { case (name, schema0) => val schema = schema0.asInstanceOf[Schema[Any]] - val codec = BinaryCodecWithSchema(TextBinaryCodec.fromSchema(schema), schema) + val codec = SchemaCodec(Some(name), schema) CliRepr( - HttpCodec.Query(HttpCodec.Query.QueryType.Primitive(name, codec)), - CliEndpoint(url = HttpOptions.Query(name, codec) :: Nil), + HttpCodec.Query(codec), + CliEndpoint(url = HttpOptions.Query(codec) :: Nil), ) } diff --git a/zio-http-cli/src/test/scala/zio/http/endpoint/cli/OptionsGen.scala b/zio-http-cli/src/test/scala/zio/http/endpoint/cli/OptionsGen.scala index 58fe22aa85..1cb6016f4b 100644 --- a/zio-http-cli/src/test/scala/zio/http/endpoint/cli/OptionsGen.scala +++ b/zio-http-cli/src/test/scala/zio/http/endpoint/cli/OptionsGen.scala @@ -7,6 +7,7 @@ import zio.test.Gen import zio.schema.Schema import zio.http._ +import zio.http.codec.HttpCodec.SchemaCodec import zio.http.codec._ import zio.http.endpoint.cli.AuxGen._ import zio.http.endpoint.cli.CliRepr._ @@ -32,10 +33,10 @@ object OptionsGen { .optionsFromTextCodec(textCodec)(name) .map(value => textCodec.encode(value)) - def encodeOptions[A](name: String, codec: BinaryCodecWithSchema[A]): Options[String] = + def encodeOptions[A](name: String, codec: SchemaCodec[A]): Options[String] = HttpOptions .optionsFromSchema(codec)(name) - .map(value => codec.codec(CodecConfig.defaultConfig).encode(value).asString) + .map(value => codec.stringCodec.encode(value)) lazy val anyBodyOption: Gen[Any, CliReprOf[Options[Retriever]]] = Gen @@ -50,18 +51,12 @@ object OptionsGen { } lazy val anyHeaderOption: Gen[Any, CliReprOf[Options[Headers]]] = - Gen.alphaNumericStringBounded(1, 30).zip(anyTextCodec).map { - case (name, TextCodec.Constant(value)) => - CliRepr( - Options.Empty.map(_ => Headers(name, value)), - CliEndpoint(headers = HttpOptions.HeaderConstant(name, value) :: Nil), - ) - case (name, codec) => - CliRepr( - encodeOptions(name, codec) - .map(value => Headers(name, value)), - CliEndpoint(headers = HttpOptions.Header(name, codec) :: Nil), - ) + Gen.alphaNumericStringBounded(1, 30).zip(anyTextCodec).map { case (name, codec) => + CliRepr( + encodeOptions(name, codec) + .map(value => Headers(name, value)), + CliEndpoint(headers = HttpOptions.Header(name, codec) :: Nil), + ) } lazy val anyURLOption: Gen[Any, CliReprOf[Options[String]]] = @@ -83,14 +78,12 @@ object OptionsGen { }, Gen .alphaNumericStringBounded(1, 30) - .zip(anyStandardType.map { s => - val schema = s.asInstanceOf[Schema[Any]] - BinaryCodecWithSchema(TextBinaryCodec.fromSchema(schema), schema) - }) - .map { case (name, codec) => + .zip(anyStandardType) + .map { case (name, schema) => + val codec = SchemaCodec(Some(name), schema) CliRepr( encodeOptions(name, codec), - CliEndpoint(url = HttpOptions.Query(name, codec) :: Nil), + CliEndpoint(url = HttpOptions.Query(codec) :: Nil), ) }, ) diff --git a/zio-http-gen/src/test/scala/zio/http/gen/scala/CodeGenSpec.scala b/zio-http-gen/src/test/scala/zio/http/gen/scala/CodeGenSpec.scala index fda70f5ff4..ee50ec1c81 100644 --- a/zio-http-gen/src/test/scala/zio/http/gen/scala/CodeGenSpec.scala +++ b/zio-http-gen/src/test/scala/zio/http/gen/scala/CodeGenSpec.scala @@ -155,7 +155,7 @@ object CodeGenSpec extends ZIOSpecDefault { Endpoint(Method.GET / "api" / "v1" / "users") .header(HeaderCodec.accept) .header(HeaderCodec.contentType) - .header(HeaderCodec.name[String]("Token")) + .header(HeaderCodec.headerAs[String]("Token")) val openAPI = OpenAPIGen.fromEndpoints(endpoint) codeGenFromOpenAPI(openAPI) { testDir => diff --git a/zio-http/jvm/src/test/scala/zio/http/endpoint/AuthSpec.scala b/zio-http/jvm/src/test/scala/zio/http/endpoint/AuthSpec.scala index e912e7f410..d140450719 100644 --- a/zio-http/jvm/src/test/scala/zio/http/endpoint/AuthSpec.scala +++ b/zio-http/jvm/src/test/scala/zio/http/endpoint/AuthSpec.scala @@ -152,7 +152,7 @@ object AuthSpec extends ZIOSpecDefault { .catchAllCause(c => ZIO.logInfoCause(c)) <* ZIO.sleep(1.seconds) response <- response } yield assertTrue(response == "admin") - } @@ flaky, + }, test("Auth basic or bearer with context and endpoint client") { val endpoint = Endpoint(Method.GET / "multiAuth") @@ -212,7 +212,7 @@ object AuthSpec extends ZIOSpecDefault { .catchAllCause(c => ZIO.logInfoCause(c)) <* ZIO.sleep(1.seconds) response <- response } yield assertTrue(response == "admin") - } @@ TestAspect.flaky, + }, test("Auth with context and endpoint client with path parameter") { val endpoint = Endpoint(Method.GET / int("a")).out[String](MediaType.text.`plain`).auth(AuthType.Basic) @@ -237,7 +237,7 @@ object AuthSpec extends ZIOSpecDefault { response <- response } yield assertTrue(response == "admin") }, - ).provideShared(Client.default, Server.default) @@ TestAspect.withLiveClock, + ).provideShared(Client.default, Server.default) @@ TestAspect.withLiveClock @@ TestAspect.flaky, test("Require Basic Auth, but get Bearer Auth") { val endpoint = Endpoint(Method.GET / "test").out[String](MediaType.text.`plain`).auth(AuthType.Basic) val routes = diff --git a/zio-http/jvm/src/test/scala/zio/http/endpoint/HeaderSpec.scala b/zio-http/jvm/src/test/scala/zio/http/endpoint/HeaderSpec.scala new file mode 100644 index 0000000000..ee5a7896ea --- /dev/null +++ b/zio-http/jvm/src/test/scala/zio/http/endpoint/HeaderSpec.scala @@ -0,0 +1,204 @@ +package zio.http.endpoint + +import zio.test._ +import zio.{NonEmptyChunk, Scope} + +import zio.schema.Schema +import zio.schema.annotation.fieldName + +import zio.http._ +import zio.http.codec.HttpCodec +import zio.http.endpoint.EndpointSpec.testEndpointWithHeaders + +object HeaderSpec extends ZIOHttpSpec { + case class MyHeaders(age: String, @fieldName("content-type") cType: String = "application", xApiKey: Option[String]) + + object MyHeaders { + implicit val schema: Schema[MyHeaders] = zio.schema.DeriveSchema.gen[MyHeaders] + } + + override def spec: Spec[TestEnvironment with Scope, Any] = + suite("HeaderCodec")( + test("Headers from case class") { + check( + Gen.alphaNumericStringBounded(1, 10), + Gen.alphaNumericStringBounded(1, 10), + Gen.alphaNumericStringBounded(1, 10), + ) { (age, cType, apiKey) => + val testRoutes = testEndpointWithHeaders( + Routes( + Endpoint(Method.GET / "users") + .header(HttpCodec.headers[MyHeaders]) + .out[String] + .implementPurely(_.toString), + ), + ) _ + + testRoutes( + s"/users", + List( + "age" -> age, + "content-type" -> cType, + "x-api-key" -> apiKey, + ), + MyHeaders(age, cType, Some(apiKey)).toString, + ) && + testRoutes( + s"/users", + List( + "age" -> age, + "content-type" -> cType, + "x-api-key" -> "", + ), + MyHeaders(age, cType, Some("")).toString, + ) && + testRoutes( + s"/users", + List( + "age" -> age, + ), + MyHeaders(age, "application", None).toString, + ) + } + }, + test("Optional Headers from case class") { + check( + Gen.alphaNumericStringBounded(1, 10), + Gen.alphaNumericStringBounded(1, 10), + Gen.alphaNumericStringBounded(1, 10), + ) { (age, cType, apiKey) => + val testRoutes = testEndpointWithHeaders( + Routes( + Endpoint(Method.GET / "users") + .header(HttpCodec.headers[MyHeaders].optional) + .out[String] + .implementPurely(_.toString), + ), + ) _ + + testRoutes( + s"/users", + List( + "content-type" -> cType, + ), + None.toString, + ) && + testRoutes( + s"/users", + List( + "age" -> age, + "content-type" -> cType, + "x-api-key" -> apiKey, + ), + Some(MyHeaders(age, cType, Some(apiKey))).toString, + ) && testRoutes( + s"/users", + List( + "age" -> age, + ), + Some(MyHeaders(age, "application", None)).toString, + ) + } + }, + test("Multiple Header values") { + check( + Gen.alphaNumericStringBounded(1, 10), + Gen.alphaNumericStringBounded(1, 10), + Gen.alphaNumericStringBounded(1, 10), + ) { (age, age2, age3) => + val testRoutes = testEndpointWithHeaders( + Routes( + Endpoint(Method.GET / "users") + .header(HttpCodec.headerAs[List[String]]("age")) + .out[String] + .implementPurely(_.toString), + ), + ) _ + + testRoutes( + s"/users", + List( + "age" -> age, + ), + List(age).toString, + ) && testRoutes( + s"/users", + List( + "age" -> age, + "age" -> age2, + ), + List(age, age2).toString, + ) && testRoutes( + s"/users", + List( + "age" -> age, + "age" -> age2, + "age" -> age3, + ), + List(age, age2, age3).toString, + ) + } + }, + test("Multiple Header values non empty") { + check( + Gen.alphaNumericStringBounded(1, 10), + Gen.alphaNumericStringBounded(1, 10), + Gen.alphaNumericStringBounded(1, 10), + ) { (age, age2, age3) => + val testRoutes = testEndpointWithHeaders( + Routes( + Endpoint(Method.GET / "users") + .header(HttpCodec.headerAs[NonEmptyChunk[String]]("age")) + .out[String] + .implementPurely(_.toString), + ), + ) _ + + testRoutes( + s"/users", + List( + "age" -> age, + ), + NonEmptyChunk(age).toString, + ) && testRoutes( + s"/users", + List( + "age" -> age, + "age" -> age2, + ), + NonEmptyChunk(age, age2).toString, + ) && testRoutes( + s"/users", + List( + "age" -> age, + "age" -> age2, + "age" -> age3, + ), + NonEmptyChunk(age, age2, age3).toString, + ) + } + }, + test("Header from transformed schema") { + case class Wrapper(age: Int) + implicit val schema: Schema[Wrapper] = zio.schema.Schema[Int].transform[Wrapper](Wrapper(_), _.age) + check(Gen.int) { age => + val testRoutes = testEndpointWithHeaders( + Routes( + Endpoint(Method.GET / "users") + .header(HttpCodec.headerAs[Wrapper]("age")) + .out[String] + .implementPurely(_.toString), + ), + ) _ + + testRoutes( + s"/users", + List( + "age" -> age.toString, + ), + Wrapper(age).toString, + ) + } + }, + ) +} diff --git a/zio-http/jvm/src/test/scala/zio/http/endpoint/QueryParameterSpec.scala b/zio-http/jvm/src/test/scala/zio/http/endpoint/QueryParameterSpec.scala index 0c5b732485..af0a2eca16 100644 --- a/zio-http/jvm/src/test/scala/zio/http/endpoint/QueryParameterSpec.scala +++ b/zio-http/jvm/src/test/scala/zio/http/endpoint/QueryParameterSpec.scala @@ -55,16 +55,6 @@ object QueryParameterSpec extends ZIOHttpSpec { testRoutes( s"/users?int=$int&optInt=${optInt.mkString}&string=$string&strings=${strings.mkString(",")}", Params(int, optInt, string, strings).toString, - ) && - testRoutes( - s"/users?int=$int&string=$string&strings=${strings.mkString(",")}", - Params(int, None, string, strings).toString, - ) && testRoutes( - s"/users?int=$int&optInt=${optInt.mkString}&strings=${strings.mkString(",")}", - Params(int, optInt, "", strings).toString, - ) && testRoutes( - s"/users?int=$int&optInt=${optInt.mkString}&string=$string", - Params(int, optInt, string, Chunk("defaultString")).toString, ) } }, @@ -110,8 +100,8 @@ object QueryParameterSpec extends ZIOHttpSpec { }, ), ) _ - // testRoutes(s"/users/$userId", s"path(users, $userId, None)") && - // testRoutes(s"/users/$userId?details=", s"path(users, $userId, None)") && + testRoutes(s"/users/$userId", s"path(users, $userId, None)") && + testRoutes(s"/users/$userId?details=", s"path(users, $userId, Some())") && testRoutes(s"/users/$userId?details=$details", s"path(users, $userId, Some($details))") } }, @@ -168,6 +158,38 @@ object QueryParameterSpec extends ZIOHttpSpec { ) } }, + test("query parameters with multiple values non empty") { + check(Gen.int, Gen.listOfN(3)(Gen.alphaNumericString)) { (userId, keys) => + val routes = Routes( + Endpoint(GET / "users" / int("userId")) + .query(HttpCodec.query[NonEmptyChunk[String]]("key")) + .out[String] + .implementHandler { + Handler.fromFunction { case (userId, keys) => + s"""path(users, $userId, ${keys.mkString(", ")})""" + } + }, + ) + val testRoutes = testEndpoint( + routes, + ) _ + + testRoutes( + s"/users/$userId?key=${keys(0)}&key=${keys(1)}&key=${keys(2)}", + s"path(users, $userId, ${keys.mkString(", ")})", + ) && + testRoutes( + s"/users/$userId?key=${keys(0)}&key=${keys(1)}", + s"path(users, $userId, ${keys.take(2).mkString(", ")})", + ) && + testRoutes( + s"/users/$userId?key=${keys(0)}", + s"path(users, $userId, ${keys.take(1).mkString(", ")})", + ) && routes + .runZIO(Request.get(s"/users/$userId")) + .map(resp => assertTrue(resp.status == Status.BadRequest)) + } + }, test("optional query parameters with multiple values") { check(Gen.int, Gen.listOfN(3)(Gen.alphaNumericString)) { (userId, keys) => val testRoutes = testEndpoint( @@ -341,7 +363,7 @@ object QueryParameterSpec extends ZIOHttpSpec { test("query parameters keys without values for multi value query") { val routes = Routes( Endpoint(GET / "users") - .query(HttpCodec.query[Chunk[RuntimeFlags]]("ints")) + .query(HttpCodec.query[Chunk[Int]]("ints")) .out[String] .implementHandler { Handler.fromFunction { queryParams => s"path(users, $queryParams)" } @@ -438,6 +460,6 @@ object QueryParameterSpec extends ZIOHttpSpec { assertTrue(response.status == Status.Ok) } }, - ) + ).provide(ErrorResponseConfig.debugLayer) } diff --git a/zio-http/jvm/src/test/scala/zio/http/endpoint/RequestSpec.scala b/zio-http/jvm/src/test/scala/zio/http/endpoint/RequestSpec.scala index 7e6e1fb6c8..345c63dd29 100644 --- a/zio-http/jvm/src/test/scala/zio/http/endpoint/RequestSpec.scala +++ b/zio-http/jvm/src/test/scala/zio/http/endpoint/RequestSpec.scala @@ -41,7 +41,7 @@ object RequestSpec extends ZIOHttpSpec { val testRoutes = testEndpointWithHeaders( Routes( Endpoint(GET / "users" / int("userId")) - .header(HeaderCodec.name[java.util.UUID]("X-Correlation-ID")) + .header(HeaderCodec.headerAs[java.util.UUID]("X-Correlation-ID")) .out[String] .implementHandler { Handler.fromFunction { case (userId, correlationId) => @@ -49,7 +49,7 @@ object RequestSpec extends ZIOHttpSpec { } }, Endpoint(GET / "users" / int("userId") / "posts" / int("postId")) - .header(HeaderCodec.name[java.util.UUID]("X-Correlation-ID")) + .header(HeaderCodec.headerAs[java.util.UUID]("X-Correlation-ID")) .out[String] .implementHandler { Handler.fromFunction { case (userId, postId, correlationId) => @@ -70,6 +70,48 @@ object RequestSpec extends ZIOHttpSpec { ) } }, + test("simple request with header with multiple values") { + check(Gen.int, Gen.listOfN(3)(Gen.uuid)) { (userId, correlationId) => + val testRoutes = testEndpointWithHeaders( + Routes( + Endpoint(GET / "users" / int("userId")) + .header(HeaderCodec.headerAs[Chunk[java.util.UUID]]("X-Correlation-ID")) + .out[String] + .implementHandler { + Handler.fromFunction { case (userId, correlationId) => + s"path(users, $userId) header(correlationId=${correlationId.mkString(",")})" + } + }, + ), + ) _ + testRoutes( + s"/users/$userId", + correlationId.map(uuid => "X-Correlation-ID" -> uuid.toString), + s"path(users, $userId) header(correlationId=${correlationId.mkString(",")})", + ) + } + }, + test("simple request with header with multiple values non empty") { + check(Gen.int, Gen.listOfN(3)(Gen.uuid)) { (userId, correlationId) => + val testRoutes = testEndpointWithHeaders( + Routes( + Endpoint(GET / "users" / int("userId")) + .header(HeaderCodec.headerAs[NonEmptyChunk[java.util.UUID]]("X-Correlation-ID")) + .out[String] + .implementHandler { + Handler.fromFunction { case (userId, correlationId) => + s"path(users, $userId) header(correlationId=${correlationId.mkString(",")})" + } + }, + ), + ) _ + testRoutes( + s"/users/$userId", + correlationId.map(uuid => "X-Correlation-ID" -> uuid.toString), + s"path(users, $userId) header(correlationId=${correlationId.mkString(",")})", + ) + } + }, test("custom content type") { check(Gen.int) { id => val endpoint = @@ -200,7 +242,7 @@ object RequestSpec extends ZIOHttpSpec { check(Gen.int, Gen.alphaNumericString) { (id, notACorrelationId) => val endpoint = Endpoint(GET / "posts") - .header(HeaderCodec.name[java.util.UUID]("X-Correlation-ID")) + .header(HeaderCodec.headerAs[java.util.UUID]("X-Correlation-ID")) .out[Int] val routes = endpoint.implementHandler { @@ -219,7 +261,7 @@ object RequestSpec extends ZIOHttpSpec { check(Gen.int) { id => val endpoint = Endpoint(GET / "posts") - .header(HeaderCodec.name[java.util.UUID]("X-Correlation-ID")) + .header(HeaderCodec.headerAs[java.util.UUID]("X-Correlation-ID")) .out[Int] val routes = endpoint.implementHandler { @@ -453,7 +495,7 @@ object RequestSpec extends ZIOHttpSpec { }, test("composite in codecs") { check(Gen.alphaNumericString, Gen.alphaNumericString) { (queryValue, headerValue) => - val headerOrQuery = HeaderCodec.name[String]("X-Header") | HttpCodec.query[String]("header") + val headerOrQuery = HeaderCodec.headerAs[String]("X-Header") | HttpCodec.query[String]("header") val endpoint = Endpoint(GET / "test").out[String].inCodec(headerOrQuery) val routes = endpoint.implementHandler(Handler.identity).toRoutes val request = Request.get( @@ -487,7 +529,7 @@ object RequestSpec extends ZIOHttpSpec { } }, test("composite out codecs") { - val headerOrQuery = HeaderCodec.name[String]("X-Header") | StatusCodec.status(Status.Created) + val headerOrQuery = HeaderCodec.headerAs[String]("X-Header") | StatusCodec.status(Status.Created) val endpoint = Endpoint(GET / "test").query(HttpCodec.query[Boolean]("Created")).outCodec(headerOrQuery) val routes = endpoint.implementHandler { diff --git a/zio-http/jvm/src/test/scala/zio/http/endpoint/openapi/OpenAPIGenSpec.scala b/zio-http/jvm/src/test/scala/zio/http/endpoint/openapi/OpenAPIGenSpec.scala index 6ff5b44fad..f89203bd13 100644 --- a/zio-http/jvm/src/test/scala/zio/http/endpoint/openapi/OpenAPIGenSpec.scala +++ b/zio-http/jvm/src/test/scala/zio/http/endpoint/openapi/OpenAPIGenSpec.scala @@ -2,7 +2,7 @@ package zio.http.endpoint.openapi import zio.json.ast.Json import zio.test._ -import zio.{Chunk, Scope, ZIO} +import zio.{Chunk, NonEmptyChunk, Scope, ZIO} import zio.schema.annotation._ import zio.schema.validation.Validation @@ -194,6 +194,13 @@ object OpenAPIGenSpec extends ZIOSpecDefault { .out[SimpleOutputBody] .outError[NotFoundError](Status.NotFound) + private val queryParamNonEmptyCollectionEndpoint = + Endpoint(GET / "withQuery") + .in[SimpleInputBody] + .query(HttpCodec.query[NonEmptyChunk[String]]("query")) + .out[SimpleOutputBody] + .outError[NotFoundError](Status.NotFound) + private val queryParamValidationEndpoint = Endpoint(GET / "withQuery") .in[SimpleInputBody] @@ -648,6 +655,132 @@ object OpenAPIGenSpec extends ZIOSpecDefault { test("with query parameter with multiple values") { val generated = OpenAPIGen.fromEndpoints("Simple Endpoint", "1.0", queryParamCollectionEndpoint) val json = toJsonAst(generated) + val expectedJson = """{ + | "openapi" : "3.1.0", + | "info" : { + | "title" : "Simple Endpoint", + | "version" : "1.0" + | }, + | "paths" : { + | "/withQuery" : { + | "get" : { + | "parameters" : [ + | + | { + | "name" : "query", + | "in" : "query", + | "schema" : + | { + | "type" : + | "string" + | }, + | "allowReserved" : false, + | "style" : "form" + | } + | ], + | "requestBody" : + | { + | "content" : { + | "application/json" : { + | "schema" : + | { + | "$ref" : "#/components/schemas/SimpleInputBody" + | } + | } + | }, + | "required" : true + | }, + | "responses" : { + | "200" : + | { + | "content" : { + | "application/json" : { + | "schema" : + | { + | "$ref" : "#/components/schemas/SimpleOutputBody" + | } + | } + | } + | }, + | "404" : + | { + | "content" : { + | "application/json" : { + | "schema" : + | { + | "$ref" : "#/components/schemas/NotFoundError" + | } + | } + | } + | } + | } + | } + | } + | }, + | "components" : { + | "schemas" : { + | "NotFoundError" : + | { + | "type" : + | "object", + | "properties" : { + | "message" : { + | "type" : + | "string" + | } + | }, + | "required" : [ + | "message" + | ] + | }, + | "SimpleInputBody" : + | { + | "type" : + | "object", + | "properties" : { + | "name" : { + | "type" : + | "string" + | }, + | "age" : { + | "type" : + | "integer", + | "format" : "int32" + | } + | }, + | "required" : [ + | "name", + | "age" + | ] + | }, + | "SimpleOutputBody" : + | { + | "type" : + | "object", + | "properties" : { + | "userName" : { + | "type" : + | "string" + | }, + | "score" : { + | "type" : + | "integer", + | "format" : "int32" + | } + | }, + | "required" : [ + | "userName", + | "score" + | ] + | } + | } + | } + |}""".stripMargin + assertTrue(json == toJsonAst(expectedJson)) + }, + test("with query parameter with multiple values - non empty") { + val generated = OpenAPIGen.fromEndpoints("Simple Endpoint", "1.0", queryParamNonEmptyCollectionEndpoint) + val json = toJsonAst(generated) val expectedJson = """{ | "openapi" : "3.1.0", | "info" : { diff --git a/zio-http/shared/src/main/scala/zio/http/Header.scala b/zio-http/shared/src/main/scala/zio/http/Header.scala index bccf90fde9..f5515e1426 100644 --- a/zio-http/shared/src/main/scala/zio/http/Header.scala +++ b/zio-http/shared/src/main/scala/zio/http/Header.scala @@ -64,6 +64,7 @@ object Header { type Typed[HV] = HeaderType { type HeaderValue = HV } } + // @deprecated("Use Schema based header codecs instead", "3.1.0") final case class Custom(customName: CharSequence, value: CharSequence) extends Header { override type Self = Custom override def self: Self = this diff --git a/zio-http/shared/src/main/scala/zio/http/codec/HeaderCodecs.scala b/zio-http/shared/src/main/scala/zio/http/codec/HeaderCodecs.scala index a6ced4dec4..ce21ab1267 100644 --- a/zio-http/shared/src/main/scala/zio/http/codec/HeaderCodecs.scala +++ b/zio-http/shared/src/main/scala/zio/http/codec/HeaderCodecs.scala @@ -16,24 +16,48 @@ package zio.http.codec +import java.util.UUID + import scala.util.Try import zio.stacktracer.TracingImplicits.disableAutoTrace +import zio.schema._ + import zio.http.Header.HeaderType import zio.http._ private[codec] trait HeaderCodecs { - private[http] def headerCodec[A](name: String, value: TextCodec[A]): HeaderCodec[A] = - HttpCodec.Header(name, value) + private[http] def headerCodec[A](name: String, value: TextCodec[A]): HeaderCodec[A] = { + val schema = value match { + case TextCodec.Constant(string) => + Schema[String].transformOrFail[Unit]( + s => if (s == string) Right(()) else Left(s"Header $name was not $string"), + (_: Unit) => Right(string), + ) + case TextCodec.StringCodec => Schema[String] + case TextCodec.IntCodec => Schema[Int] + case TextCodec.LongCodec => Schema[Long] + case TextCodec.BooleanCodec => Schema[Boolean] + case TextCodec.UUIDCodec => Schema[UUID] + } + HttpCodec.HeaderCustom(name, schema.asInstanceOf[Schema[A]]) + } def header(headerType: HeaderType): HeaderCodec[headerType.HeaderValue] = - headerCodec(headerType.name, TextCodec.string) - .transformOrFailLeft(headerType.parse(_))(headerType.render(_)) + HttpCodec.Header(headerType) + + def headerAs[A](name: String)(implicit schema: Schema[A]): HeaderCodec[A] = + HttpCodec.HeaderCustom(name, schema) + + def headers[A](implicit schema: Schema[A]): HeaderCodec[A] = + HttpCodec.HeaderCustom(schema) + @deprecated("Use Schema based headerAs instead", "3.1.0") def name[A](name: String)(implicit codec: TextCodec[A]): HeaderCodec[A] = headerCodec(name, codec) + @deprecated("Use Schema based API instead", "3.1.0") def nameTransform[A, B]( name: String, parse: B => A, @@ -43,11 +67,13 @@ private[codec] trait HeaderCodecs { Try(parse(s)).toEither.left.map(e => s"Failed to parse header $name: ${e.getMessage}"), )(render) + @deprecated("Use Schema based API instead", "3.1.0") def nameTransformOption[A, B](name: String, parse: B => Option[A], render: A => B)(implicit codec: TextCodec[B], ): HeaderCodec[A] = headerCodec(name, codec).transformOrFailLeft(parse(_).toRight(s"Failed to parse header $name"))(render) + @deprecated("Use Schema based API instead", "3.1.0") def nameTransformOrFail[A, B](name: String, parse: B => Either[String, A], render: A => B)(implicit codec: TextCodec[B], ): HeaderCodec[A] = diff --git a/zio-http/shared/src/main/scala/zio/http/codec/HttpCodec.scala b/zio-http/shared/src/main/scala/zio/http/codec/HttpCodec.scala index 278f185366..673dcc242d 100644 --- a/zio-http/shared/src/main/scala/zio/http/codec/HttpCodec.scala +++ b/zio-http/shared/src/main/scala/zio/http/codec/HttpCodec.scala @@ -18,18 +18,22 @@ package zio.http.codec import scala.annotation.tailrec import scala.reflect.ClassTag +import scala.util.Try import zio._ -import zio.stream.ZStream +import zio.stream.{ZPipeline, ZStream} import zio.schema.Schema -import zio.schema.annotation._ +import zio.schema.codec.DecodeError +import zio.schema.validation.{Validation, ValidationError} import zio.http.Header.Accept.MediaTypeWithQFactor +import zio.http.Header.HeaderType import zio.http._ -import zio.http.codec.HttpCodec.Query.QueryType +import zio.http.codec.HttpCodec.SchemaCodec.camelToKebab import zio.http.codec.HttpCodec.{Annotated, Metadata} +import zio.http.codec.StringCodec.StringCodec import zio.http.codec.internal._ /** @@ -337,12 +341,13 @@ object HttpCodec extends ContentCodecs with HeaderCodecs with MethodCodecs with private[http] sealed trait AtomTag private[http] object AtomTag { - case object Status extends AtomTag - case object Path extends AtomTag - case object Content extends AtomTag - case object Query extends AtomTag - case object Header extends AtomTag - case object Method extends AtomTag + case object Status extends AtomTag + case object Path extends AtomTag + case object Content extends AtomTag + case object Query extends AtomTag + case object Header extends AtomTag + case object HeaderCustom extends AtomTag + case object Method extends AtomTag } def empty: HttpCodec[Any, Unit] = @@ -2264,140 +2269,220 @@ object HttpCodec extends ContentCodecs with HeaderCodecs with MethodCodecs with def index(index: Int): ContentStream[A] = copy(index = index) } private[http] final case class Query[A, Out]( - queryType: Query.QueryType[A], + codec: SchemaCodec[A], index: Int = 0, ) extends Atom[HttpCodecType.Query, Out] { self => def erase: Query[Any, Any] = self.asInstanceOf[Query[Any, Any]] - def tag: AtomTag = AtomTag.Query - def index(index: Int): Query[A, Out] = copy(index = index) - def isOptional: Boolean = - queryType match { - case QueryType.Primitive(_, BinaryCodecWithSchema(_, schema)) if schema.isInstanceOf[Schema.Optional[_]] => - true - case QueryType.Record(recordSchema) => - recordSchema match { - case s if s.isInstanceOf[Schema.Optional[_]] => true - case record: Schema.Record[_] if record.fields.forall(_.optional) => true - case _ => false - } - case _ => false - } + def isCollection: Boolean = codec.isCollection + + def isOptional: Boolean = codec.isOptional + + def isOptionalSchema: Boolean = codec.isOptionalSchema + + def isPrimitive: Boolean = codec.isPrimitive + + def isRecord: Boolean = codec.isRecord + + def nameUnsafe: String = codec.name.get /** * Returns a new codec, where the value produced by this one is optional. */ override def optional: HttpCodec[HttpCodecType.Query, Option[Out]] = - queryType match { - case QueryType.Primitive(name, codec) if codec.schema.isInstanceOf[Schema.Optional[_]] => - throw new IllegalArgumentException( - s"Cannot make an optional query parameter optional. Name: $name schema: ${codec.schema}", - ) - case QueryType.Primitive(name, codec) => - val optionalSchema = codec.schema.optional - copy(queryType = - QueryType.Primitive(name, BinaryCodecWithSchema(TextBinaryCodec.fromSchema(optionalSchema), optionalSchema)), - ) - case QueryType.Record(recordSchema) if recordSchema.isInstanceOf[Schema.Optional[_]] => - throw new IllegalArgumentException(s"Cannot make an optional query parameter optional") - case QueryType.Record(recordSchema) => - val optionalSchema = recordSchema.optional - copy(queryType = QueryType.Record(optionalSchema)) - case queryType @ QueryType.Collection(_, _, false) => - copy(queryType = QueryType.Collection(queryType.colSchema, queryType.elements, optional = true)) - case queryType @ QueryType.Collection(_, _, true) => - throw new IllegalArgumentException(s"Cannot make an optional query parameter optional: $queryType") - + if (isOptionalSchema) { + throw new IllegalArgumentException("Query is already optional") + } else { + Annotated(Query(codec.optional, index), Metadata.Optional()) } + def tag: AtomTag = AtomTag.Query + } - private[http] object Query { - sealed trait QueryType[A] - object QueryType { - case class Primitive[A](name: String, codec: BinaryCodecWithSchema[A]) extends QueryType[A] - case class Collection[A](colSchema: Schema.Collection[_, _], elements: QueryType.Primitive[A], optional: Boolean) - extends QueryType[A] { - def toCollection(values: Chunk[Any]): A = - colSchema match { - case Schema.Sequence(_, fromChunk, _, _, _) => - fromChunk.asInstanceOf[Chunk[Any] => Any](values).asInstanceOf[A] - case Schema.Set(_, _) => - values.toSet.asInstanceOf[A] - case _ => - throw new IllegalArgumentException( - s"Unsupported collection schema for query object field of type: $colSchema", - ) - } - } - case class Record[A](recordSchema: Schema[A]) extends QueryType[A] { - private var namesAndCodecs: Chunk[(Schema.Field[_, _], BinaryCodecWithSchema[Any])] = _ - private[http] def fieldAndCodecs: Chunk[(Schema.Field[_, _], BinaryCodecWithSchema[Any])] = - if (namesAndCodecs == null) { - namesAndCodecs = recordSchema match { - case record: Schema.Record[A] => - record.fields.map { field => - validateSchema(field.name, field.schema) - val codec = binaryCodecForField(field.annotations.foldLeft(field.schema)(_ annotate _)) - (unlazy(field.asInstanceOf[Schema.Field[Any, Any]]), codec) - } - case s if s.isInstanceOf[Schema.Optional[_]] => - val record = s.asInstanceOf[Schema.Optional[A]].schema.asInstanceOf[Schema.Record[A]] - record.fields.map { field => - validateSchema(field.name, field.annotations.foldLeft(field.schema)(_ annotate _)) - val codec = binaryCodecForField(field.schema) - (field, codec) - } - case s => throw new IllegalArgumentException(s"Unsupported schema for query object field of type: $s") - } - namesAndCodecs - } else { - namesAndCodecs - } + object Query { + def apply[A](name: String, schema: Schema[A]): Query[A, A] = Query(SchemaCodec(Some(name), schema)) + def apply[A](schema: Schema[A]): Query[A, A] = Query(SchemaCodec(None, schema)) + } + + final case class SchemaCodec[A](name: Option[String], schema: Schema[A], kebabCase: Boolean = false) { + + def erasedSchema: Schema[Any] = schema.asInstanceOf[Schema[Any]] + + val isCollection: Boolean = schema match { + case _: Schema.Collection[_, _] => true + case s: Schema.Optional[_] if s.schema.isInstanceOf[Schema.Collection[_, _]] => true + case _ => false + } + + val isOptional: Boolean = schema match { + case _: Schema.Optional[_] => + true + case record: Schema.Record[_] => + record.fields.forall(_.optional) || record.defaultValue.isRight + case d: Schema.Collection[_, _] => + Try(d.empty).isSuccess || d.defaultValue.isRight + case _ => + false + } + + val isOptionalSchema: Boolean = + schema match { + case _: Schema.Optional[_] => true + case s: Schema.Transform[_, _, _] if s.schema.isInstanceOf[Schema.Optional[_]] => true + case _ => false } - private def unlazy(field: Schema.Field[Any, Any]): Schema.Field[Any, Any] = field.schema match { - case Schema.Lazy(schema) => - Schema.Field( - field.name, - schema(), - field.annotations, - field.validation, - field.get, - field.set, + val isPrimitive: Boolean = schema match { + case _: Schema.Primitive[_] => true + case s: Schema.Optional[_] if s.schema.isInstanceOf[Schema.Primitive[_]] => true + case s: Schema.Transform[_, _, _] if s.schema.isInstanceOf[Schema.Primitive[_]] => true + case _ => false + } + + val isRecord: Boolean = schema match { + case _: Schema.Record[_] => true + case s: Schema.Optional[_] if s.schema.isInstanceOf[Schema.Record[_]] => true + case s: Schema.Transform[_, _, _] if s.schema.isInstanceOf[Schema.Record[_]] => true + case _ => false + } + + def optional: SchemaCodec[Option[A]] = copy(schema = schema.optional) + + val recordFields: Chunk[(Schema.Field[_, _], SchemaCodec[Any])] = { + val fields = schema match { + case record: Schema.Record[A] => + record.fields + case s: Schema.Optional[_] if s.schema.isInstanceOf[Schema.Record[_]] => + s.schema.asInstanceOf[Schema.Record[A]].fields + case s: Schema.Transform[_, _, _] if s.schema.isInstanceOf[Schema.Record[_]] => + s.schema.asInstanceOf[Schema.Record[A]].fields + case _ => Chunk.empty + } + fields.map(unlazyField).map { + case field if field.schema.isInstanceOf[Schema.Collection[_, _]] => + val elementSchema = field.schema.asInstanceOf[Schema.Collection[_, _]] match { + case s: Schema.NonEmptySequence[_, _, _] => s.elementSchema + case s: Schema.Sequence[_, _, _] => s.elementSchema + case s: Schema.Set[_] => s.elementSchema + case _: Schema.Map[_, _] => throw new IllegalArgumentException("Maps are not supported") + case _: Schema.NonEmptyMap[_, _] => throw new IllegalArgumentException("Maps are not supported") + } + val codec = SchemaCodec(Some(if (!kebabCase) field.name else camelToKebab(field.name)), elementSchema) + (field, codec.asInstanceOf[SchemaCodec[Any]]) + case field => + val codec = SchemaCodec( + Some(if (!kebabCase) field.name else camelToKebab(field.name)), + field.annotations.foldLeft(field.schema)(_ annotate _), ) - case _ => field + (field, codec.asInstanceOf[SchemaCodec[Any]]) } + } - private def binaryCodecForField[A](schema: Schema[A]): BinaryCodecWithSchema[Any] = (schema match { - case schema @ Schema.Primitive(_, _) => BinaryCodecWithSchema(TextBinaryCodec.fromSchema(schema), schema) - case Schema.Transform(_, _, _, _, _) => BinaryCodecWithSchema(TextBinaryCodec.fromSchema(schema), schema) - case Schema.Optional(_, _) => BinaryCodecWithSchema(TextBinaryCodec.fromSchema(schema), schema) - case e: Schema.Enum[_] if isSimple(e) => BinaryCodecWithSchema(TextBinaryCodec.fromSchema(schema), schema) - case l @ Schema.Lazy(_) => binaryCodecForField(l.schema) - case Schema.Set(schema, _) => binaryCodecForField(schema) - case Schema.Sequence(schema, _, _, _, _) => binaryCodecForField(schema) - case schema => throw new IllegalArgumentException(s"Unsupported schema for query object field of type: $schema") - }).asInstanceOf[BinaryCodecWithSchema[Any]] - - def isSimple(schema: Schema.Enum[_]): Boolean = - schema.annotations.exists(_.isInstanceOf[simpleEnum]) - - @tailrec - private def validateSchema[A](name: String, schema: Schema[A]): Unit = schema match { - case _: Schema.Primitive[A] => () - case Schema.Transform(schema, _, _, _, _) => validateSchema(name, schema) - case Schema.Optional(schema, _) => validateSchema(name, schema) - case Schema.Lazy(schema) => validateSchema(name, schema()) - case Schema.Set(schema, _) => validateSchema(name, schema) - case Schema.Sequence(schema, _, _, _, _) => validateSchema(name, schema) - case s => throw new IllegalArgumentException(s"Unsupported schema for query object field of type: $s") - } + val recordSchema: Schema.Record[Any] = schema match { + case record: Schema.Record[_] => + record.asInstanceOf[Schema.Record[Any]] + case s: Schema.Optional[_] if s.schema.isInstanceOf[Schema.Record[_]] => + s.schema.asInstanceOf[Schema.Record[Any]] + case _ => null + } + val stringCodec: StringCodec[Any] = + stringCodecForSchema(schema.asInstanceOf[Schema[Any]]) + + private def stringCodecForSchema(s: Schema[_]): StringCodec[Any] = { + (s match { + case s: Schema.Optional[_] if s.schema.isInstanceOf[Schema.Primitive[_]] => + StringCodec.fromSchema(schema) + case s: Schema.Optional[_] => + stringCodecForSchema(s.schema) + case s: Schema.Collection[_, _] => + s match { + case schema: Schema.NonEmptySequence[_, _, _] => StringCodec.fromSchema(schema.elementSchema) + case schema: Schema.Sequence[_, _, _] => StringCodec.fromSchema(schema.elementSchema) + case schema: Schema.Set[_] => StringCodec.fromSchema(schema.elementSchema) + case _: Schema.Map[_, _] => StringCodec.fromSchema(s) + case _: Schema.NonEmptyMap[_, _] => StringCodec.fromSchema(s) + } + case s: Schema.Lazy[_] => StringCodec.fromSchema(s.schema) + case s: Schema.Transform[Any, Any, _] @unchecked => + val stringCodec = StringCodec.fromSchema(s.schema) + new StringCodec[Any] { + override def decode(whole: String): Either[DecodeError, Any] = + stringCodec.decode(whole).flatMap(s.f(_).left.map(DecodeError.ReadError(Cause.empty, _))) + + override def streamDecoder: ZPipeline[Any, DecodeError, Char, Any] = + stringCodec.streamDecoder >>> ZPipeline.map(s.f(_).left.map(DecodeError.ReadError(Cause.empty, _))) + + override def encode(value: Any): String = + stringCodec.encode(s.g(value).fold(msg => throw new Exception(msg), identity)) + + override def streamEncoder: ZPipeline[Any, Nothing, Any, Char] = + ZPipeline.map[Any, Any]( + s.g(_).fold(msg => throw new Exception(msg), identity), + ) >>> stringCodec.streamEncoder + } + case schema: Schema[_] => StringCodec.fromSchema(schema) + }).asInstanceOf[StringCodec[Any]] } + + private def unlazyField(field: Schema.Field[_, _]): Schema.Field[_, _] = field match { + case f if f.schema.isInstanceOf[Schema.Lazy[_]] => + Schema.Field( + f.name, + f.schema.asInstanceOf[Schema.Lazy[_]].schema.asInstanceOf[Schema[Any]], + f.annotations, + f.validation.asInstanceOf[Validation[Any]], + f.get.asInstanceOf[Any => Any], + f.set.asInstanceOf[(Any, Any) => Any], + ) + case f => f + } + + def validate(value: Any): Chunk[ValidationError] = + schema.asInstanceOf[Schema[_]] match { + case Schema.Optional(schema: Schema[Any], _) => + schema.validate(value)(schema) + case schema: Schema[_] => + schema.asInstanceOf[Schema[Any]].validate(value)(schema.asInstanceOf[Schema[Any]]) + } + val defaultValue: A = + if (schema.isInstanceOf[Schema.Collection[_, _]]) { + Try(schema.asInstanceOf[Schema.Collection[A, _]].empty).fold( + _ => null.asInstanceOf[A], + identity, + ) + } else { + schema.defaultValue match { + case Right(value) => value + case Left(_) => + schema match { + case _: Schema.Optional[_] => None.asInstanceOf[A] + case collection: Schema.Collection[A, _] => + Try(collection.empty).fold( + _ => null.asInstanceOf[A], + identity, + ) + case _ => null.asInstanceOf[A] + } + } + } + + } + + object SchemaCodec { + private def camelToKebab(s: String): String = + if (s.isEmpty) "" + else if (s.head.isUpper) s.head.toLower.toString + camelToKebab(s.tail) + else if (s.contains('-')) s + else + s.foldLeft("") { (acc, c) => + if (c.isUpper) acc + "-" + c.toLower + else acc + c + } } private[http] final case class Method[A](codec: SimpleCodec[zio.http.Method, A], index: Int = 0) @@ -2409,7 +2494,34 @@ object HttpCodec extends ContentCodecs with HeaderCodecs with MethodCodecs with def index(index: Int): Method[A] = copy(index = index) } - private[http] final case class Header[A](name: String, textCodec: TextCodec[A], index: Int = 0) + private[http] final case class HeaderCustom[A](codec: SchemaCodec[A], index: Int = 0) + extends Atom[HttpCodecType.Header, A] { + self => + def erase: HeaderCustom[Any] = self.asInstanceOf[HeaderCustom[Any]] + + override def optional: HttpCodec[HttpCodecType.Header, Option[A]] = + if (codec.isOptionalSchema) { + throw new IllegalArgumentException("Header is already optional") + } else { + Annotated( + HeaderCustom(codec.optional, index), + Metadata.Optional(), + ) + } + + def tag: AtomTag = AtomTag.HeaderCustom + + def index(index: Int): HeaderCustom[A] = copy(index = index) + } + + object HeaderCustom { + def apply[A](name: String, schema: Schema[A]): HeaderCustom[A] = + HeaderCustom(SchemaCodec(Some(name), schema, kebabCase = true)) + def apply[A](schema: Schema[A]): HeaderCustom[A] = + HeaderCustom(SchemaCodec(None, schema, kebabCase = true)) + } + + private[http] final case class Header[A](headerType: HeaderType.Typed[A], index: Int = 0) extends Atom[HttpCodecType.Header, A] { self => def erase: Header[Any] = self.asInstanceOf[Header[Any]] diff --git a/zio-http/shared/src/main/scala/zio/http/codec/HttpCodecError.scala b/zio-http/shared/src/main/scala/zio/http/codec/HttpCodecError.scala index bcd97223d8..3df1973abb 100644 --- a/zio-http/shared/src/main/scala/zio/http/codec/HttpCodecError.scala +++ b/zio-http/shared/src/main/scala/zio/http/codec/HttpCodecError.scala @@ -23,6 +23,7 @@ import zio.{Cause, Chunk} import zio.schema.codec.DecodeError import zio.schema.validation.ValidationError +import zio.http.Header.HeaderType import zio.http.{Path, Status} sealed trait HttpCodecError extends Exception with NoStackTrace with Product with Serializable { @@ -33,6 +34,9 @@ object HttpCodecError { final case class MissingHeader(headerName: String) extends HttpCodecError { def message = s"Missing header $headerName" } + final case class MissingHeaders(headerNames: Chunk[String]) extends HttpCodecError { + def message = s"Missing headers ${headerNames.mkString(", ")}" + } final case class MalformedMethod(expected: zio.http.Method, actual: zio.http.Method) extends HttpCodecError { def message = s"Expected $expected but found $actual" } @@ -48,6 +52,12 @@ object HttpCodecError { final case class MalformedHeader(headerName: String, textCodec: TextCodec[_]) extends HttpCodecError { def message = s"Malformed header $headerName failed to decode using $textCodec" } + final case class MalformedCustomHeader(headerName: String, cause: DecodeError) extends HttpCodecError { + def message = s"Malformed custom header $headerName could not be decoded: $cause" + } + final case class MalformedTypedHeader(headerName: String) extends HttpCodecError { + def message = s"Malformed header $headerName" + } final case class MissingQueryParam(queryParamName: String) extends HttpCodecError { def message = s"Missing query parameter $queryParamName" } diff --git a/zio-http/shared/src/main/scala/zio/http/codec/QueryCodecs.scala b/zio-http/shared/src/main/scala/zio/http/codec/QueryCodecs.scala index 4bc203f5e1..4f98ec8e46 100644 --- a/zio-http/shared/src/main/scala/zio/http/codec/QueryCodecs.scala +++ b/zio-http/shared/src/main/scala/zio/http/codec/QueryCodecs.scala @@ -26,100 +26,33 @@ private[codec] trait QueryCodecs { def query[A](name: String)(implicit schema: Schema[A]): QueryCodec[A] = schema match { - case s @ Schema.Primitive(_, _) => - HttpCodec.Query( - HttpCodec.Query.QueryType - .Primitive(name, BinaryCodecWithSchema.fromBinaryCodec(TextBinaryCodec.fromSchema(s))(s)), - ) - case c @ Schema.Sequence(elementSchema, _, _, _, _) => - if (supportedElementSchema(elementSchema.asInstanceOf[Schema[Any]])) { - HttpCodec.Query( - HttpCodec.Query.QueryType.Collection( - c, - HttpCodec.Query.QueryType.Primitive( - name, - BinaryCodecWithSchema(TextBinaryCodec.fromSchema(elementSchema), elementSchema), - ), - optional = false, - ), - ) - } else { - throw new IllegalArgumentException("Only primitive types can be elements of sequences") - } - case c @ Schema.Set(elementSchema, _) => - if (supportedElementSchema(elementSchema.asInstanceOf[Schema[Any]])) { - HttpCodec.Query( - HttpCodec.Query.QueryType.Collection( - c, - HttpCodec.Query.QueryType.Primitive( - name, - BinaryCodecWithSchema(TextBinaryCodec.fromSchema(elementSchema), elementSchema), - ), - optional = false, - ), - ) - } else { - throw new IllegalArgumentException("Only primitive types can be elements of sets") - } - case Schema.Optional(Schema.Primitive(_, _), _) => - HttpCodec.Query( - HttpCodec.Query.QueryType - .Primitive(name, BinaryCodecWithSchema.fromBinaryCodec(TextBinaryCodec.fromSchema(schema))(schema)), - ) - case Schema.Optional(c @ Schema.Sequence(elementSchema, _, _, _, _), _) => - if (supportedElementSchema(elementSchema.asInstanceOf[Schema[Any]])) { - HttpCodec.Query( - HttpCodec.Query.QueryType.Collection( - c, - HttpCodec.Query.QueryType.Primitive( - name, - BinaryCodecWithSchema(TextBinaryCodec.fromSchema(elementSchema), elementSchema), - ), - optional = true, - ), - ) - } else { - throw new IllegalArgumentException("Only primitive types can be elements of sequences") - } - case Schema.Optional(inner, _) if inner.isInstanceOf[Schema.Set[_]] => - val elementSchema = inner.asInstanceOf[Schema.Set[Any]].elementSchema - if (supportedElementSchema(elementSchema)) { - HttpCodec.Query( - HttpCodec.Query.QueryType.Collection( - inner.asInstanceOf[Schema.Set[_]], - HttpCodec.Query.QueryType.Primitive( - name, - BinaryCodecWithSchema(TextBinaryCodec.fromSchema(inner), inner), - ), - optional = true, - ), - ) - } else { - throw new IllegalArgumentException("Only primitive types can be elements of sets") - } - case enum0: Schema.Enum[_] if enum0.annotations.exists(_.isInstanceOf[simpleEnum]) => - HttpCodec.Query( - HttpCodec.Query.QueryType - .Primitive(name, BinaryCodecWithSchema.fromBinaryCodec(TextBinaryCodec.fromSchema(schema))(schema)), - ) - case record: Schema.Record[A] if record.fields.size == 1 => - val field = record.fields.head - if (supportedElementSchema(field.schema.asInstanceOf[Schema[Any]])) { - HttpCodec.Query( - HttpCodec.Query.QueryType.Primitive( - name, - BinaryCodecWithSchema(TextBinaryCodec.fromSchema(record), record), - ), - ) - } else { - throw new IllegalArgumentException("Only primitive types can be elements of records") - } - case other => + case c: Schema.Collection[_, _] if !supportedCollection(c) => + throw new IllegalArgumentException(s"Collection schema $c is not supported for query codecs") + case enum0: Schema.Enum[_] if !enum0.annotations.exists(_.isInstanceOf[simpleEnum]) => + throw new IllegalArgumentException(s"Enum schema $enum0 is not supported. All cases must be objects.") + case record: Schema.Record[A] if record.fields.size != 1 => + throw new IllegalArgumentException("Use queryAll[A] for records with more than one field") + case record: Schema.Record[A] if !supportedElementSchema(record.fields.head.schema.asInstanceOf[Schema[Any]]) => throw new IllegalArgumentException( - s"Only primitive types, sequences, sets, optional, enums and records with a single field can be used to infer query codecs, but got $other", + s"Only primitive types and simple enums can be used in single field records, but got ${record.fields.head.schema}", ) + case other => + HttpCodec.Query(name, other) } + private def supportedCollection(schema: Schema.Collection[_, _]): Boolean = schema match { + case Schema.Map(_, _, _) => + false + case Schema.NonEmptyMap(_, _, _) => + false + case Schema.Sequence(elementSchema, _, _, _, _) => + supportedElementSchema(elementSchema.asInstanceOf[Schema[Any]]) + case Schema.NonEmptySequence(elementSchema, _, _, _, _) => + supportedElementSchema(elementSchema.asInstanceOf[Schema[Any]]) + case Schema.Set(elementSchema, _) => + supportedElementSchema(elementSchema.asInstanceOf[Schema[Any]]) + } + @tailrec private def supportedElementSchema(elementSchema: Schema[Any]): Boolean = elementSchema match { case Schema.Lazy(schema0) => supportedElementSchema(schema0()) @@ -131,11 +64,16 @@ private[codec] trait QueryCodecs { def queryAll[A](implicit schema: Schema[A]): QueryCodec[A] = schema match { - case _: Schema.Primitive[A] => + case _: Schema.Primitive[A] => throw new IllegalArgumentException("Use query[A](name: String) for primitive types") - case record: Schema.Record[A] => HttpCodec.Query(HttpCodec.Query.QueryType.Record(record)) - case Schema.Optional(_, _) => HttpCodec.Query(HttpCodec.Query.QueryType.Record(schema)) - case _ => throw new IllegalArgumentException("Only case classes can be used to infer query codecs") + case record: Schema.Record[A] => + HttpCodec.Query(record) + case Schema.Optional(s, _) if s.isInstanceOf[Schema.Record[_]] => + HttpCodec.Query(schema) + case _ => + throw new IllegalArgumentException( + "Only case classes can be used with queryAll. Maybe you wanted to use query[A](name: String)?", + ) } } diff --git a/zio-http/shared/src/main/scala/zio/http/codec/StringCodec.scala b/zio-http/shared/src/main/scala/zio/http/codec/StringCodec.scala new file mode 100644 index 0000000000..ae49864c51 --- /dev/null +++ b/zio-http/shared/src/main/scala/zio/http/codec/StringCodec.scala @@ -0,0 +1,394 @@ +package zio.http.codec + +import java.time._ +import java.util.{Currency, UUID} + +import scala.annotation.tailrec + +import zio._ + +import zio.stream._ + +import zio.schema._ +import zio.schema.annotation.simpleEnum +import zio.schema.codec._ + +import zio.http.Charsets + +object StringCodec { + type StringCodec[A] = Codec[String, Char, A] + private def errorCodec[A](schema: Schema[A]) = + new Codec[String, Char, A] { + override def decode(whole: String): Either[DecodeError, A] = throw new IllegalArgumentException( + s"Schema $schema is not supported by StringCodec.", + ) + + override def streamDecoder: ZPipeline[Any, DecodeError, Char, A] = throw new IllegalArgumentException( + s"Schema $schema is not supported by StringCodec.", + ) + + override def encode(value: A): String = throw new IllegalArgumentException( + s"Schema $schema is not supported by StringCodec.", + ) + + override def streamEncoder: ZPipeline[Any, Nothing, A, Char] = throw new IllegalArgumentException( + s"Schema $schema is not supported by StringCodec.", + ) + } + + @tailrec + private def emptyStringIsValue(schema: Schema[_]): Boolean = { + schema match { + case value: Schema.Optional[_] => + val innerSchema = value.schema + emptyStringIsValue(innerSchema) + case _ => + schema.asInstanceOf[Schema.Primitive[_]].standardType match { + case StandardType.UnitType => true + case StandardType.StringType => true + case StandardType.BinaryType => true + case StandardType.CharType => true + case _ => false + } + } + } + + implicit def fromSchema[A](implicit schema: Schema[A]): Codec[String, Char, A] = { + schema match { + case Schema.Optional(schema, _) => + val codec = fromSchema(schema).asInstanceOf[Codec[String, Char, Any]] + new Codec[String, Char, A] { + override def encode(a: A): String = { + a match { + case Some(value) => codec.encode(value) + case None => "" + } + } + + override def decode(c: String): Either[DecodeError, A] = { + if (c.isEmpty && !emptyStringIsValue(schema)) Right(None.asInstanceOf[A]) + else { + codec.decode(c).map(Some(_)).asInstanceOf[Either[DecodeError, A]] + } + } + + override def streamEncoder: ZPipeline[Any, Nothing, A, Char] = + ZPipeline.map((a: A) => encode(a).toSeq).flattenIterables + override def streamDecoder: ZPipeline[Any, DecodeError, Char, A] = + codec.streamDecoder.map(v => Some(v).asInstanceOf[A]) + } + case enum0: Schema.Enum[_] if enum0.annotations.exists(_.isInstanceOf[simpleEnum]) => + val stringCodec = fromSchema(Schema.Primitive(StandardType.StringType)) + val caseMap = enum0.nonTransientCases + .map(case_ => + case_.schema.asInstanceOf[Schema.CaseClass0[A]].defaultConstruct() -> + case_.caseName, + ) + .toMap + val reverseCaseMap = caseMap.map(_.swap) + new Codec[String, Char, A] { + override def encode(a: A): String = { + val caseName = caseMap(a.asInstanceOf[A]) + stringCodec.encode(caseName) + } + + override def decode(c: String): Either[DecodeError, A] = + stringCodec.decode(c).flatMap { caseName => + reverseCaseMap.get(caseName) match { + case Some(value) => Right(value.asInstanceOf[A]) + case None => Left(DecodeError.MissingCase(caseName, enum0)) + } + } + override def streamEncoder: ZPipeline[Any, Nothing, A, Char] = + ZPipeline.map((a: A) => encode(a).toSeq).flattenIterables + override def streamDecoder: ZPipeline[Any, DecodeError, Char, A] = + stringCodec.streamDecoder.mapZIO { caseName => + reverseCaseMap.get(caseName) match { + case Some(value) => ZIO.succeed(value.asInstanceOf[A]) + case None => ZIO.fail(DecodeError.MissingCase(caseName, enum0)) + } + } + } + + case enum0: Schema.Enum[_] => errorCodec(enum0) + case record: Schema.Record[_] if record.fields.size == 1 => + val fieldSchema = record.fields.head.schema + val codec = fromSchema(fieldSchema).asInstanceOf[Codec[String, Char, A]] + new Codec[String, Char, A] { + override def encode(a: A): String = + codec.encode(record.deconstruct(a)(Unsafe.unsafe).head.get.asInstanceOf[A]) + override def decode(c: String): Either[DecodeError, A] = + codec + .decode(c) + .flatMap(a => + record.construct(Chunk(a))(Unsafe.unsafe).left.map(s => DecodeError.ReadError(Cause.empty, s)), + ) + override def streamEncoder: ZPipeline[Any, Nothing, A, Char] = + ZPipeline.map((a: A) => encode(a).toSeq).flattenIterables + override def streamDecoder: ZPipeline[Any, DecodeError, Char, A] = + codec.streamDecoder.mapZIO(a => + ZIO.fromEither( + record.construct(Chunk(a))(Unsafe.unsafe).left.map(s => DecodeError.ReadError(Cause.empty, s)), + ), + ) + } + case record: Schema.Record[_] => errorCodec(record) + case collection: Schema.Collection[_, _] => errorCodec(collection) + case Schema.Transform(schema, f, g, _, _) => + val codec = fromSchema(schema) + new Codec[String, Char, A] { + override def encode(a: A): String = codec.encode(g(a).fold(e => throw new Exception(e), identity)) + override def decode(c: String): Either[DecodeError, A] = codec + .decode(c) + .flatMap(x => + f(x).left + .map(DecodeError.ReadError(Cause.fail(new Exception("Error during decoding")), _)), + ) + override def streamEncoder: ZPipeline[Any, Nothing, A, Char] = + ZPipeline.mapChunks(_.flatMap(encode)) + override def streamDecoder: ZPipeline[Any, DecodeError, Char, A] = codec.streamDecoder.map { x => + f(x) match { + case Left(value) => throw DecodeError.ReadError(Cause.fail(new Exception("Error in decoding")), value) + case Right(a) => a + } + } + } + case Schema.Primitive(_, _) => + new Codec[String, Char, A] { + val decode0: String => Either[DecodeError, Any] = + schema match { + case Schema.Primitive(standardType, _) => + standardType match { + case StandardType.UnitType => + val result = Right("") + (_: String) => result + case StandardType.StringType => + (s: String) => Right(s) + case StandardType.BoolType => + (s: String) => + s.toLowerCase match { + case "true" | "on" | "yes" | "1" => Right(true) + case "false" | "off" | "no" | "0" => Right(false) + case _ => Left(DecodeError.ReadError(Cause.fail(new Exception("Invalid boolean value")), s)) + } + case StandardType.ByteType => + (s: String) => + try { + Right(s.toByte) + } catch { + case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) + } + case StandardType.ShortType => + (s: String) => + try { + Right(s.toShort) + } catch { + case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) + } + case StandardType.IntType => + (s: String) => + try { + Right(s.toInt) + } catch { + case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) + } + case StandardType.LongType => + (s: String) => + try { + Right(s.toLong) + } catch { + case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) + } + case StandardType.FloatType => + (s: String) => + try { + Right(s.toFloat) + } catch { + case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) + } + case StandardType.DoubleType => + (s: String) => + try { + Right(s.toDouble) + } catch { + case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) + } + case StandardType.BinaryType => + val result = Left(DecodeError.UnsupportedSchema(schema, "TextCodec")) + (_: String) => result + case StandardType.CharType => + (s: String) => Right(s.charAt(0)) + case StandardType.UUIDType => + (s: String) => + try { + Right(UUID.fromString(s)) + } catch { + case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) + } + case StandardType.BigDecimalType => + (s: String) => + try { + Right(BigDecimal(s)) + } catch { + case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) + } + case StandardType.BigIntegerType => + (s: String) => + try { + Right(BigInt(s)) + } catch { + case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) + } + case StandardType.DayOfWeekType => + (s: String) => + try { + Right(DayOfWeek.valueOf(s)) + } catch { + case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) + } + case StandardType.MonthType => + (s: String) => + try { + Right(Month.valueOf(s)) + } catch { + case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) + } + case StandardType.MonthDayType => + (s: String) => + try { + Right(MonthDay.parse(s)) + } catch { + case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) + } + case StandardType.PeriodType => + (s: String) => + try { + Right(Period.parse(s)) + } catch { + case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) + } + case StandardType.YearType => + (s: String) => + try { + Right(Year.parse(s)) + } catch { + case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) + } + case StandardType.YearMonthType => + (s: String) => + try { + Right(YearMonth.parse(s)) + } catch { + case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) + } + case StandardType.ZoneIdType => + (s: String) => + try { + Right(ZoneId.of(s)) + } catch { + case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) + } + case StandardType.ZoneOffsetType => + (s: String) => + try { + Right(ZoneOffset.of(s)) + } catch { + case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) + } + case StandardType.DurationType => + (s: String) => + try { + Right(java.time.Duration.parse(s)) + } catch { + case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) + } + case StandardType.InstantType => + (s: String) => + try { + Right(Instant.parse(s)) + } catch { + case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) + } + case StandardType.LocalDateType => + (s: String) => + try { + Right(LocalDate.parse(s)) + } catch { + case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) + } + case StandardType.LocalTimeType => + (s: String) => + try { + Right(LocalTime.parse(s)) + } catch { + case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) + } + case StandardType.LocalDateTimeType => + (s: String) => + try { + Right(LocalDateTime.parse(s)) + } catch { + case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) + } + case StandardType.OffsetTimeType => + (s: String) => + try { + Right(OffsetTime.parse(s)) + } catch { + case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) + } + case StandardType.OffsetDateTimeType => + (s: String) => + try { + Right(OffsetDateTime.parse(s)) + } catch { + case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) + } + case StandardType.ZonedDateTimeType => + (s: String) => + try { + Right(ZonedDateTime.parse(s)) + } catch { + case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) + } + case StandardType.CurrencyType => + (s: String) => + try { + Right(Currency.getInstance(s)) + } catch { + case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) + } + } + case schema => + val result = Left( + DecodeError.UnsupportedSchema(schema, "Only primitive types are supported for text decoding."), + ) + (_: String) => result + } + override def encode(a: A): String = + schema match { + case Schema.Primitive(_, _) => a.toString + case _ => + throw new IllegalArgumentException( + s"Cannot encode $a of type ${a.getClass} with schema $schema", + ) + } + + override def decode(c: String): Either[DecodeError, A] = + decode0(c).map(_.asInstanceOf[A]) + + override def streamEncoder: ZPipeline[Any, Nothing, A, Char] = + ZPipeline.map((a: A) => a.toString.toSeq).flattenIterables + + override def streamDecoder: ZPipeline[Any, DecodeError, Char, A] = + ZPipeline + .chunks[Char] + .map(_.asString) + .mapZIO(s => ZIO.fromEither(decode(s))) + .mapErrorCause(e => Cause.fail(DecodeError.ReadError(e, e.squash.getMessage))) + } + case Schema.Lazy(schema0) => fromSchema(schema0()) + case _ => errorCodec(schema) + } + } +} diff --git a/zio-http/shared/src/main/scala/zio/http/codec/TextBinaryCodec.scala b/zio-http/shared/src/main/scala/zio/http/codec/TextBinaryCodec.scala index cd552ad814..9059501f96 100644 --- a/zio-http/shared/src/main/scala/zio/http/codec/TextBinaryCodec.scala +++ b/zio-http/shared/src/main/scala/zio/http/codec/TextBinaryCodec.scala @@ -120,10 +120,10 @@ object TextBinaryCodec { ) override def streamEncoder: ZPipeline[Any, Nothing, A, Byte] = ZPipeline.mapChunks(_.flatMap(encode)) - override def streamDecoder: ZPipeline[Any, DecodeError, Byte, A] = codec.streamDecoder.map { x => + override def streamDecoder: ZPipeline[Any, DecodeError, Byte, A] = codec.streamDecoder.mapZIO { x => f(x) match { - case Left(value) => throw DecodeError.ReadError(Cause.fail(new Exception("Error in decoding")), value) - case Right(a) => a + case Left(value) => ZIO.fail(DecodeError.ReadError(Cause.fail(new Exception("Error in decoding")), value)) + case Right(a) => ZIO.succeed(a) } } } @@ -356,7 +356,7 @@ object TextBinaryCodec { override def streamDecoder: ZPipeline[Any, DecodeError, Byte, A] = (ZPipeline[Byte] >>> ZPipeline.utf8Decode) - .map(s => decode(Chunk.fromArray(s.getBytes)).fold(throw _, identity)) + .mapZIO(s => ZIO.fromEither(decode(Chunk.fromArray(s.getBytes)))) .mapErrorCause(e => Cause.fail(DecodeError.ReadError(e, e.squash.getMessage))) } case Schema.Lazy(schema0) => fromSchema(schema0()) diff --git a/zio-http/shared/src/main/scala/zio/http/codec/internal/Atomized.scala b/zio-http/shared/src/main/scala/zio/http/codec/internal/Atomized.scala index 9a53a41add..3557b776d9 100644 --- a/zio-http/shared/src/main/scala/zio/http/codec/internal/Atomized.scala +++ b/zio-http/shared/src/main/scala/zio/http/codec/internal/Atomized.scala @@ -25,30 +25,34 @@ private[http] final case class Atomized[A]( path: A, query: A, header: A, + headerCustom: A, content: A, ) { def get(tag: HttpCodec.AtomTag): A = { tag match { - case HttpCodec.AtomTag.Status => status - case HttpCodec.AtomTag.Path => path - case HttpCodec.AtomTag.Content => content - case HttpCodec.AtomTag.Query => query - case HttpCodec.AtomTag.Header => header - case HttpCodec.AtomTag.Method => method + case HttpCodec.AtomTag.Status => status + case HttpCodec.AtomTag.Path => path + case HttpCodec.AtomTag.Content => content + case HttpCodec.AtomTag.Query => query + case HttpCodec.AtomTag.Header => header + case HttpCodec.AtomTag.HeaderCustom => headerCustom + case HttpCodec.AtomTag.Method => method } } def update(tag: HttpCodec.AtomTag)(f: A => A): Atomized[A] = { tag match { - case HttpCodec.AtomTag.Status => copy(status = f(status)) - case HttpCodec.AtomTag.Path => copy(path = f(path)) - case HttpCodec.AtomTag.Content => copy(content = f(content)) - case HttpCodec.AtomTag.Query => copy(query = f(query)) - case HttpCodec.AtomTag.Header => copy(header = f(header)) - case HttpCodec.AtomTag.Method => copy(method = f(method)) + case HttpCodec.AtomTag.Status => copy(status = f(status)) + case HttpCodec.AtomTag.Path => copy(path = f(path)) + case HttpCodec.AtomTag.Content => copy(content = f(content)) + case HttpCodec.AtomTag.Query => copy(query = f(query)) + case HttpCodec.AtomTag.Header => copy(header = f(header)) + case HttpCodec.AtomTag.HeaderCustom => copy(headerCustom = f(header)) + case HttpCodec.AtomTag.Method => copy(method = f(method)) } } } private[http] object Atomized { - def apply[A](defValue: => A): Atomized[A] = Atomized(defValue, defValue, defValue, defValue, defValue, defValue) + def apply[A](defValue: => A): Atomized[A] = + Atomized(defValue, defValue, defValue, defValue, defValue, defValue, defValue) } diff --git a/zio-http/shared/src/main/scala/zio/http/codec/internal/AtomizedCodecs.scala b/zio-http/shared/src/main/scala/zio/http/codec/internal/AtomizedCodecs.scala index 52cce2c84a..af296b3cfe 100644 --- a/zio-http/shared/src/main/scala/zio/http/codec/internal/AtomizedCodecs.scala +++ b/zio-http/shared/src/main/scala/zio/http/codec/internal/AtomizedCodecs.scala @@ -27,6 +27,7 @@ private[http] final case class AtomizedCodecs( path: Chunk[PathCodec[_]], query: Chunk[Query[_, _]], header: Chunk[Header[_]], + headerCustom: Chunk[HeaderCustom[_]], content: Chunk[BodyCodec[_]], status: Chunk[SimpleCodec[zio.http.Status, _]], ) { self => @@ -35,9 +36,10 @@ private[http] final case class AtomizedCodecs( case method0: Method[_] => self.copy(method = method :+ method0.codec) case query0: Query[_, _] => self.copy(query = query :+ query0) case header0: Header[_] => self.copy(header = header :+ header0) + case header0: HeaderCustom[_] => self.copy(headerCustom = headerCustom :+ header0) + case status0: Status[_] => self.copy(status = status :+ status0.codec) case content0: Content[_] => self.copy(content = content :+ BodyCodec.Single(content0.codec, content0.name)) - case status0: Status[_] => self.copy(status = status :+ status0.codec) case stream0: ContentStream[_] => self.copy(content = content :+ BodyCodec.Multiple(stream0.codec, stream0.name)) } @@ -48,6 +50,7 @@ private[http] final case class AtomizedCodecs( path = Array.ofDim(path.length), query = Array.ofDim(query.length), header = Array.ofDim(header.length), + headerCustom = Array.ofDim(headerCustom.length), content = Array.ofDim(content.length), status = Array.ofDim(status.length), ) @@ -59,6 +62,7 @@ private[http] final case class AtomizedCodecs( path = path.materialize, query = query.materialize, header = header.materialize, + headerCustom = headerCustom.materialize, content = content.materialize, status = status.materialize, ) @@ -71,6 +75,7 @@ private[http] object AtomizedCodecs { path = Chunk.empty, query = Chunk.empty, header = Chunk.empty, + headerCustom = Chunk.empty, content = Chunk.empty, status = Chunk.empty, ) diff --git a/zio-http/shared/src/main/scala/zio/http/codec/internal/EncoderDecoder.scala b/zio-http/shared/src/main/scala/zio/http/codec/internal/EncoderDecoder.scala index 44d99b72dd..e00ba20664 100644 --- a/zio-http/shared/src/main/scala/zio/http/codec/internal/EncoderDecoder.scala +++ b/zio-http/shared/src/main/scala/zio/http/codec/internal/EncoderDecoder.scala @@ -16,16 +16,17 @@ package zio.http.codec.internal +import scala.annotation.tailrec import scala.util.Try import zio._ -import zio.schema.codec.{BinaryCodec, DecodeError} +import zio.schema.codec.DecodeError import zio.schema.{Schema, StandardType} import zio.http.Header.Accept.MediaTypeWithQFactor import zio.http._ -import zio.http.codec.HttpCodec.Query.QueryType +import zio.http.codec.StringCodec.StringCodec import zio.http.codec._ private[codec] trait EncoderDecoder[-AtomTypes, Value] { self => @@ -46,7 +47,7 @@ private[codec] object EncoderDecoder { val flattened = httpCodec.alternatives flattened.length match { - case 0 => Undefined() + case 0 => Undefined.asInstanceOf[EncoderDecoder[AtomTypes, Value]] case 1 => Single(flattened.head._1) case _ => Multiple(flattened) } @@ -109,7 +110,7 @@ private[codec] object EncoderDecoder { } } - private final case class Undefined[-AtomTypes, Value]() extends EncoderDecoder[AtomTypes, Value] { + private object Undefined extends EncoderDecoder[Any, Any] { val encodeWithErrorMessage = """ @@ -125,7 +126,7 @@ private[codec] object EncoderDecoder { override def encodeWith[Z]( config: CodecConfig, - value: Value, + value: Any, outputTypes: Chunk[MediaTypeWithQFactor], )(f: (zio.http.URL, Option[zio.http.Status], Option[zio.http.Method], zio.http.Headers, zio.http.Body) => Z): Z = { throw new IllegalStateException(encodeWithErrorMessage) @@ -138,7 +139,7 @@ private[codec] object EncoderDecoder { method: zio.http.Method, headers: zio.http.Headers, body: zio.http.Body, - )(implicit trace: zio.Trace): zio.Task[Value] = { + )(implicit trace: zio.Trace): zio.Task[Any] = { ZIO.fail(new IllegalStateException(decodeErrorMessage)) } } @@ -168,6 +169,7 @@ private[codec] object EncoderDecoder { decodeStatus(status, inputsBuilder.status) decodeMethod(method, inputsBuilder.method) decodeHeaders(headers, inputsBuilder.header) + decodeCustomHeaders(headers, inputsBuilder.headerCustom) decodeBody(config, body, inputsBuilder.content).as(constructor(inputsBuilder)) } @@ -180,7 +182,7 @@ private[codec] object EncoderDecoder { val query = encodeQuery(config, inputs.query) val status = encodeStatus(inputs.status) val method = encodeMethod(inputs.method) - val headers = encodeHeaders(inputs.header) + val headers = encodeHeaders(inputs.header) ++ encodeCustomHeaders(inputs.headerCustom) def contentTypeHeaders = encodeContentType(inputs.content, outputTypes) val body = encodeBody(config, inputs.content, outputTypes) @@ -220,156 +222,276 @@ private[codec] object EncoderDecoder { inputs, (codec, queryParams) => { val query = codec.erase - val isOptional = query.isOptional - query.queryType match { - case QueryType.Primitive(name, bc @ BinaryCodecWithSchema(_, schema)) => - val count = queryParams.valueCount(name) - val hasParam = queryParams.hasQueryParam(name) - if (!hasParam && isOptional) None - else if (!hasParam) throw HttpCodecError.MissingQueryParam(name) - else if (count != 1) throw HttpCodecError.InvalidQueryParamCount(name, 1, count) - else { - val decoded = bc - .codec(config) - .decode( - Chunk.fromArray(queryParams.unsafeQueryParam(name).getBytes(Charsets.Utf8)), - ) match { + val optional = query.isOptionalSchema + val hasDefault = query.codec.defaultValue != null && query.isOptional + val default = query.codec.defaultValue + if (codec.isPrimitive) { + val name = query.nameUnsafe + val hasParam = queryParams.hasQueryParam(name) + if ( + (!hasParam || (queryParams + .unsafeQueryParam(name) == "" && !emptyStringIsValue(codec.codec.schema))) && hasDefault + ) + default + else if (!hasParam) + throw HttpCodecError.MissingQueryParam(name) + else if (queryParams.valueCount(name) != 1) + throw HttpCodecError.InvalidQueryParamCount(name, 1, queryParams.valueCount(name)) + else { + val decoded = + codec.codec.stringCodec.decode(queryParams.unsafeQueryParam(name)) match { case Left(error) => throw HttpCodecError.MalformedQueryParam(name, error) case Right(value) => value } - val validationErrors = schema.validate(decoded)(schema) - if (validationErrors.nonEmpty) throw HttpCodecError.InvalidEntity.wrap(validationErrors) - if (isOptional && decoded == None && emptyStringIsValue(schema.asInstanceOf[Schema.Optional[_]].schema)) - Some("") - else decoded + val validationErrors = codec.codec.erasedSchema.validate(decoded)(codec.codec.erasedSchema) + if (validationErrors.nonEmpty) throw HttpCodecError.InvalidEntity.wrap(validationErrors) + else decoded + } + + } else if (codec.isCollection) { + val name = query.nameUnsafe + val hasParam = queryParams.hasQueryParam(name) + + if (!hasParam) { + if (query.codec.defaultValue != null) query.codec.defaultValue + else throw HttpCodecError.MissingQueryParam(name) + } else { + val decoded = queryParams.queryParams(name).map { value => + query.codec.stringCodec.decode(value) match { + case Left(error) => throw HttpCodecError.MalformedQueryParam(name, error) + case Right(value) => value + } + } + if (optional) + Some( + createAndValidateCollection( + query.codec.schema.asInstanceOf[Schema.Optional[_]].schema.asInstanceOf[Schema.Collection[_, _]], + decoded, + ), + ) + else createAndValidateCollection(query.codec.schema.asInstanceOf[Schema.Collection[_, _]], decoded) + } + } else { + val recordSchema = query.codec.recordSchema + val fields = query.codec.recordFields + val hasAllParams = fields.forall { case (field, codec) => + queryParams.hasQueryParam(field.fieldName) || field.optional || codec.isOptional + } + if (!hasAllParams && hasDefault) default + else if (!hasAllParams) throw HttpCodecError.MissingQueryParams { + fields.collect { + case (field, codec) + if !(queryParams.hasQueryParam(field.fieldName) || field.optional || codec.isOptional) => + field.fieldName } - case c @ QueryType.Collection(_, QueryType.Primitive(name, bc), optional) => - if (!queryParams.hasQueryParam(name)) { - if (!optional) c.toCollection(Chunk.empty) - else None + } + else { + val decoded = fields.map { + case (field, codec) if field.schema.isInstanceOf[Schema.Collection[_, _]] => + val schema = field.schema.asInstanceOf[Schema.Collection[_, _]] + if (!queryParams.hasQueryParam(field.fieldName)) { + if (field.defaultValue.isDefined) field.defaultValue.get + else throw HttpCodecError.MissingQueryParam(field.fieldName) + } else { + val values = queryParams.queryParams(field.fieldName) + val decoded = + values.map(decodeAndUnwrap(field, codec, _, HttpCodecError.MalformedQueryParam.apply)) + createAndValidateCollection(schema, decoded) + + } + case (field, codec) => + val value = queryParams.queryParamOrElse(field.fieldName, null) + val decoded = { + if (value == null || (value == "" && !emptyStringIsValue(codec.schema))) codec.defaultValue + else decodeAndUnwrap(field, codec, value, HttpCodecError.MalformedQueryParam.apply) + } + validateDecoded(codec, decoded) + } + if (optional) { + val constructed = recordSchema.construct(decoded)(Unsafe.unsafe) + constructed match { + case Left(value) => + throw HttpCodecError.MalformedQueryParam( + s"${recordSchema.id}", + DecodeError.ReadError(Cause.empty, value), + ) + case Right(value) => + recordSchema.validate(value)(recordSchema) match { + case errors if errors.nonEmpty => throw HttpCodecError.InvalidEntity.wrap(errors) + case _ => Some(value) + } + } } else { - val values = queryParams.queryParams(name) - val decoded = c.toCollection { - values.map { value => - bc.codec(config).decode(Chunk.fromArray(value.getBytes(Charsets.Utf8))) match { - case Left(error) => throw HttpCodecError.MalformedQueryParam(name, error) - case Right(value) => value + val constructed = recordSchema.construct(decoded)(Unsafe.unsafe) + constructed match { + case Left(value) => + throw HttpCodecError.MalformedQueryParam( + s"${recordSchema.id}", + DecodeError.ReadError(Cause.empty, value), + ) + case Right(value) => + recordSchema.validate(value)(recordSchema) match { + case errors if errors.nonEmpty => throw HttpCodecError.InvalidEntity.wrap(errors) + case _ => value } - } } - val erasedSchema = c.colSchema.asInstanceOf[Schema[Any]] - val validationErrors = erasedSchema.validate(decoded)(erasedSchema) - if (validationErrors.nonEmpty) throw HttpCodecError.InvalidEntity.wrap(validationErrors) - if (optional) Some(decoded) - else decoded } - case query @ QueryType.Record(recordSchema) => - val hasAllParams = query.fieldAndCodecs.forall { case (field, _) => - queryParams.hasQueryParam(field.name) || field.optional || field.defaultValue.isDefined + } + } + }, + ) + + private def createAndValidateCollection(schema: Schema.Collection[_, _], decoded: Chunk[Any]) = { + val collection = schema.fromChunk.asInstanceOf[Chunk[Any] => Any](decoded) + val erasedSchema = schema.asInstanceOf[Schema[Any]] + val validationErrors = erasedSchema.validate(collection)(erasedSchema) + if (validationErrors.nonEmpty) throw HttpCodecError.InvalidEntity.wrap(validationErrors) + collection + } + + @tailrec + private def emptyStringIsValue(schema: Schema[_]): Boolean = { + schema match { + case value: Schema.Optional[_] => + val innerSchema = value.schema + emptyStringIsValue(innerSchema) + case _ => + schema.asInstanceOf[Schema.Primitive[_]].standardType match { + case StandardType.UnitType => true + case StandardType.StringType => true + case StandardType.BinaryType => true + case StandardType.CharType => true + case _ => false + } + } + } + + private def decodeCustomHeaders(headers: Headers, inputs: Array[Any]): Unit = + genericDecode[Headers, HttpCodec.HeaderCustom[_]]( + headers, + flattened.headerCustom, + inputs, + (header, headers) => { + val optional = header.codec.isOptionalSchema + if (header.codec.isPrimitive) { + val schema = header.erase.codec.schema + val name = header.codec.name.get + val value = headers.getUnsafe(name) + if (value ne null) { + val decoded = header.codec.stringCodec.decode(value) match { + case Left(error) => throw HttpCodecError.MalformedCustomHeader(name, error) + case Right(value) => value } - if (!hasAllParams && recordSchema.isInstanceOf[Schema.Optional[_]]) None - else if (!hasAllParams && isOptional) { - recordSchema.defaultValue match { - case Left(err) => - throw new IllegalStateException(s"Cannot compute default value for $recordSchema. Error was: $err") - case Right(value) => value - } - } else if (!hasAllParams) throw HttpCodecError.MissingQueryParams { - query.fieldAndCodecs.collect { - case (field, _) - if !(queryParams.hasQueryParam(field.name) || field.optional || field.defaultValue.isDefined) => - field.name + val validationErrors = schema.validate(decoded)(schema) + if (validationErrors.nonEmpty) throw HttpCodecError.InvalidEntity.wrap(validationErrors) + else decoded + } else { + if (optional) None + else throw HttpCodecError.MissingHeader(name) + } + } else if (header.codec.isCollection) { + val name = header.codec.name.get + val values = headers.rawHeaders(name) + val decoded = values.map { value => + header.codec.stringCodec.decode(value) match { + case Left(error) => throw HttpCodecError.MalformedCustomHeader(name, error) + case Right(value) => value + } + } + if (optional) + Some( + createAndValidateCollection( + header.codec.schema.asInstanceOf[Schema.Optional[_]].schema.asInstanceOf[Schema.Collection[_, _]], + decoded, + ), + ) + else createAndValidateCollection(header.codec.schema.asInstanceOf[Schema.Collection[_, _]], decoded) + } else { + val recordSchema = header.codec.recordSchema + val fields = header.codec.recordFields + val hasAllParams = fields.forall { case (field, codec) => + headers.contains(field.fieldName) || field.optional || codec.isOptional + } + if (!hasAllParams) { + if (header.codec.defaultValue != null && header.codec.isOptional) header.codec.defaultValue + else + throw HttpCodecError.MissingHeaders { + fields.collect { + case (field, codec) if !(headers.contains(field.fieldName) || field.optional || codec.isOptional) => + field.fieldName + } } + } else { + val decoded = fields.map { + case (field, codec) if field.schema.isInstanceOf[Schema.Collection[_, _]] => + if (!headers.contains(codec.name.get)) { + if (codec.defaultValue != null) codec.defaultValue + else throw HttpCodecError.MissingHeader(codec.name.get) + } else { + val schema = field.schema.asInstanceOf[Schema.Collection[_, _]] + val values = headers.rawHeaders(codec.name.get) + val decoded = + values.map(decodeAndUnwrap(field, codec, _, HttpCodecError.MalformedCustomHeader.apply)) + createAndValidateCollection(schema, decoded) + } + case (field, codec) => + val value = headers.getUnsafe(codec.name.get) + val decoded = + if (value == null || (value == "" && !emptyStringIsValue(codec.schema))) codec.defaultValue + else decodeAndUnwrap(field, codec, value, HttpCodecError.MalformedCustomHeader.apply) + validateDecoded(codec, decoded) } - else { - val decoded = query.fieldAndCodecs.map { - case (field, codec) if field.schema.isInstanceOf[Schema.Collection[_, _]] => - if (!queryParams.hasQueryParam(field.name) && field.defaultValue.nonEmpty) field.defaultValue.get - else { - val values = queryParams.queryParams(field.name) - val decoded = values.map { value => - codec.codec(config).decode(Chunk.fromArray(value.getBytes(Charsets.Utf8))) match { - case Left(error) => throw HttpCodecError.MalformedQueryParam(field.name, error) - case Right(value) => value - } - } - val decodedCollection = - field.schema match { - case s @ Schema.Sequence(_, fromChunk, _, _, _) => - val collection = fromChunk.asInstanceOf[Chunk[Any] => Any](decoded) - val erasedSchema = s.asInstanceOf[Schema[Any]] - val validationErrors = erasedSchema.validate(collection)(erasedSchema) - if (validationErrors.nonEmpty) throw HttpCodecError.InvalidEntity.wrap(validationErrors) - collection - case s @ Schema.Set(_, _) => - val collection = decoded.toSet[Any] - val erasedSchema = s.asInstanceOf[Schema.Set[Any]] - val validationErrors = erasedSchema.validate(collection)(erasedSchema) - if (validationErrors.nonEmpty) throw HttpCodecError.InvalidEntity.wrap(validationErrors) - collection - case _ => throw new IllegalStateException("Only Sequence and Set are supported.") - } - decodedCollection - } - case (field, codec) => - val value = queryParams.queryParamOrElse(field.name, null) - val decoded = { - if (value == null) field.defaultValue.get - else { - codec.codec(config).decode(Chunk.fromArray(value.getBytes(Charsets.Utf8))) match { - case Left(error) => throw HttpCodecError.MalformedQueryParam(field.name, error) - case Right(value) => value - } - } + if (optional) { + val constructed = recordSchema.construct(decoded)(Unsafe.unsafe) + constructed match { + case Left(value) => + throw HttpCodecError.MalformedCustomHeader( + s"${recordSchema.id}", + DecodeError.ReadError(Cause.empty, value), + ) + case Right(value) => + recordSchema.validate(value)(recordSchema) match { + case errors if errors.nonEmpty => throw HttpCodecError.InvalidEntity.wrap(errors) + case _ => Some(value) } - val validationErrors = codec.schema.validate(decoded)(codec.schema) - if (validationErrors.nonEmpty) throw HttpCodecError.InvalidEntity.wrap(validationErrors) - decoded } - if (recordSchema.isInstanceOf[Schema.Optional[_]]) { - val schema = recordSchema.asInstanceOf[Schema.Optional[_]].schema.asInstanceOf[Schema.Record[Any]] - val constructed = schema.construct(decoded)(Unsafe.unsafe) - constructed match { - case Left(value) => - throw HttpCodecError.MalformedQueryParam( - s"${schema.id}", - DecodeError.ReadError(Cause.empty, value), - ) - case Right(value) => - schema.validate(value)(schema) match { - case errors if errors.nonEmpty => throw HttpCodecError.InvalidEntity.wrap(errors) - case _ => Some(value) - } - } - } else { - val schema = recordSchema.asInstanceOf[Schema.Record[Any]] - val constructed = schema.construct(decoded)(Unsafe.unsafe) - constructed match { - case Left(value) => - throw HttpCodecError.MalformedQueryParam( - s"${schema.id}", - DecodeError.ReadError(Cause.empty, value), - ) - case Right(value) => - schema.validate(value)(schema) match { - case errors if errors.nonEmpty => throw HttpCodecError.InvalidEntity.wrap(errors) - case _ => value - } - } + } else { + val constructed = recordSchema.construct(decoded)(Unsafe.unsafe) + constructed match { + case Left(value) => + throw HttpCodecError.MalformedCustomHeader( + s"${recordSchema.id}", + DecodeError.ReadError(Cause.empty, value), + ) + case Right(value) => + recordSchema.validate(value)(recordSchema) match { + case errors if errors.nonEmpty => throw HttpCodecError.InvalidEntity.wrap(errors) + case _ => value + } } } + } } }, ) - private def emptyStringIsValue(schema: Schema[_]): Boolean = - schema.asInstanceOf[Schema.Primitive[_]].standardType match { - case StandardType.UnitType => true - case StandardType.StringType => true - case StandardType.BinaryType => true - case StandardType.CharType => true - case _ => false + private def validateDecoded(codec: HttpCodec.SchemaCodec[Any], decoded: Any) = { + val validationErrors = codec.schema.validate(decoded)(codec.schema) + if (validationErrors.nonEmpty) throw HttpCodecError.InvalidEntity.wrap(validationErrors) + decoded + } + + private def decodeAndUnwrap( + field: Schema.Field[_, _], + codec: HttpCodec.SchemaCodec[Any], + value: String, + ex: (String, DecodeError) => HttpCodecError, + ) = { + codec.stringCodec.decode(value) match { + case Left(error) => throw ex(codec.name.get, error) + case Right(value) => value } + } private def decodeHeaders(headers: Headers, inputs: Array[Any]): Unit = genericDecode[Headers, HttpCodec.Header[_]]( @@ -377,14 +499,14 @@ private[codec] object EncoderDecoder { flattened.header, inputs, (codec, headers) => - headers.get(codec.name) match { + headers.get(codec.headerType.name) match { case Some(value) => - codec.erase.textCodec - .decode(value) - .getOrElse(throw HttpCodecError.MalformedHeader(codec.name, codec.textCodec)) + codec.erase.headerType + .parse(value) + .getOrElse(throw HttpCodecError.MalformedTypedHeader(codec.headerType.name)) case None => - throw HttpCodecError.MissingHeader(codec.name) + throw HttpCodecError.MissingHeader(codec.headerType.name) }, ) @@ -513,111 +635,154 @@ private[codec] object EncoderDecoder { inputs, QueryParams.empty, (codec, input, queryParams) => { - val query = codec.erase - - query.queryType match { - case QueryType.Primitive(name, codec) => - val schema = codec.schema - if (schema.isInstanceOf[Schema.Primitive[_]]) { - if (schema.asInstanceOf[Schema.Primitive[_]].standardType.isInstanceOf[StandardType.UnitType.type]) { - queryParams.addQueryParams(name, Chunk.empty[String]) - } else { - val encoded = codec.codec(config).asInstanceOf[BinaryCodec[Any]].encode(input).asString - queryParams.addQueryParams(name, Chunk(encoded)) - } - } else if (schema.isInstanceOf[Schema.Optional[_]]) { - val encoded = codec.codec(config).asInstanceOf[BinaryCodec[Any]].encode(input).asString - if (encoded.nonEmpty) queryParams.addQueryParams(name, Chunk(encoded)) else queryParams + val query = codec.erase + val optional = query.isOptionalSchema + val stringCodec = codec.codec.stringCodec.asInstanceOf[StringCodec[Any]] + + if (query.isPrimitive) { + val schema = codec.codec.schema + val name = query.nameUnsafe + if (schema.isInstanceOf[Schema.Primitive[_]]) { + if (schema.asInstanceOf[Schema.Primitive[_]].standardType.isInstanceOf[StandardType.UnitType.type]) { + queryParams.addQueryParams(name, Chunk.empty[String]) } else { - throw new IllegalStateException( - "Only primitive schema is supported for query parameters of type Primitive", - ) - } - case QueryType.Collection(_, QueryType.Primitive(name, codec), optional) => - var in: Any = input - if (optional) { - in = input.asInstanceOf[Option[Any]].getOrElse(Chunk.empty) - } - val values = input.asInstanceOf[Iterable[Any]] - if (values.nonEmpty) { - queryParams.addQueryParams( - name, - Chunk.fromIterable( - values.map { value => - codec.codec(config).asInstanceOf[BinaryCodec[Any]].encode(value).asString - }, - ), - ) - } else queryParams - case query @ QueryType.Record(recordSchema) if recordSchema.isInstanceOf[Schema.Optional[_]] => - input match { - case None => queryParams - case Some(value) => - val innerSchema = - recordSchema.asInstanceOf[Schema.Optional[_]].schema.asInstanceOf[Schema.Record[Any]] - val fieldValues = innerSchema.deconstruct(value)(Unsafe.unsafe) - var j = 0 - var qp = queryParams - while (j < fieldValues.size) { - val (field, codec) = query.fieldAndCodecs(j) - val name = field.name - val value = fieldValues(j) match { - case Some(value) => value - case None => field.defaultValue - } - value match { - case values: Iterable[_] => - qp = qp.addQueryParams( - name, - Chunk.fromIterable(values.map { v => - codec.codec(config).asInstanceOf[BinaryCodec[Any]].encode(v).asString - }), - ) - case _ => - val encoded = codec.codec(config).asInstanceOf[BinaryCodec[Any]].encode(value).asString - qp = qp.addQueryParam(name, encoded) - } - j = j + 1 - } - qp + val encoded = stringCodec.encode(input) + queryParams.addQueryParams(name, Chunk(encoded)) } - case query @ QueryType.Record(recordSchema) => - val innerSchema = recordSchema.asInstanceOf[Schema.Record[Any]] - val fieldValues = innerSchema.deconstruct(input)(Unsafe.unsafe) - var j = 0 - var qp = queryParams - while (j < fieldValues.size) { - val (field, codec) = query.fieldAndCodecs(j) - val name = field.name - val value = fieldValues(j) match { + } else if (schema.isInstanceOf[Schema.Optional[_]]) { + val encoded = stringCodec.encode(input) + if (encoded.nonEmpty) queryParams.addQueryParams(name, Chunk(encoded)) else queryParams + } else { + throw new IllegalStateException( + "Only primitive schema is supported for query parameters of type Primitive", + ) + } + } else if (query.isCollection) { + val name = query.nameUnsafe + var in: Any = input + if (optional) { + in = input.asInstanceOf[Option[Any]].getOrElse(Chunk.empty) + } + val values = input.asInstanceOf[Iterable[Any]] + if (values.nonEmpty) { + queryParams.addQueryParams( + name, + Chunk.fromIterable(values.map { value => stringCodec.encode(value) }), + ) + } else queryParams + } else if (query.isRecord) { + val value = input match { + case None => null + case Some(value) => value + case value => value + } + if (value == null) queryParams + else { + val innerSchema = query.codec.recordSchema + val fieldValues = innerSchema.deconstruct(value)(Unsafe.unsafe) + var qp = queryParams + val fieldIt = query.codec.recordFields.iterator + val fieldValuesIt = fieldValues.iterator + while (fieldIt.hasNext) { + val (field, codec) = fieldIt.next() + val name = field.fieldName + val value = fieldValuesIt.next() match { case Some(value) => value case None => field.defaultValue } value match { - case values if values.isInstanceOf[Iterable[_]] => + case values: Iterable[_] => qp = qp.addQueryParams( name, - Chunk.fromIterable(values.asInstanceOf[Iterable[Any]].map { v => - codec.codec(config).asInstanceOf[BinaryCodec[Any]].encode(v).asString + Chunk.fromIterable(values.map { v => + codec.stringCodec.encode(v) }), ) - case _ => - val encoded = codec.codec(config).asInstanceOf[BinaryCodec[Any]].encode(value).asString + case _ => + val encoded = codec.stringCodec.encode(value) qp = qp.addQueryParam(name, encoded) } - j = j + 1 } qp + } + } else { + queryParams + } + }, + ) + + private def encodeCustomHeaders(inputs: Array[Any]): Headers = { + genericEncode[Headers, HttpCodec.HeaderCustom[_]]( + flattened.headerCustom, + inputs, + Headers.empty, + (codec, input, headers) => { + val optional = codec.codec.isOptionalSchema + val stringCodec = codec.erase.codec.stringCodec + if (codec.codec.isPrimitive) { + val name = codec.codec.name.get + val value = input + if (optional && value == None) headers + else { + val encoded = stringCodec.encode(value) + headers ++ Headers(name, encoded) + } + } else if (codec.codec.isCollection) { + val name = codec.codec.name.get + val values = input.asInstanceOf[Iterable[Any]] + if (values.nonEmpty) { + headers ++ Headers.FromIterable( + values.map { value => + Header.Custom(name, stringCodec.encode(value)) + }, + ) + } else headers + } else { + val recordSchema = codec.codec.recordSchema + val fields = codec.codec.recordFields + val value = input match { + case None => null + case Some(value) => value + case value => value + } + if (value == null) headers + else { + val fieldValues = recordSchema.deconstruct(value)(Unsafe.unsafe) + var hs = headers + val fieldIt = fields.iterator + val fieldValuesIt = fieldValues.iterator + while (fieldIt.hasNext) { + val (field, codec) = fieldIt.next() + val name = field.fieldName + val value = fieldValuesIt.next() match { + case Some(value) => value + case None => field.defaultValue + } + value match { + case values: Iterable[_] => + hs = hs ++ Headers.FromIterable( + values.map { v => + Header.Custom(name, codec.stringCodec.encode(v)) + }, + ) + case _ => + val encoded = codec.stringCodec.encode(value) + hs = hs ++ Headers(name, encoded) + } + } + hs + } } }, ) + } private def encodeHeaders(inputs: Array[Any]): Headers = genericEncode[Headers, HttpCodec.Header[_]]( flattened.header, inputs, Headers.empty, - (codec, input, headers) => headers ++ Headers(codec.name, codec.erase.textCodec.encode(input)), + (codec, input, headers) => headers ++ Headers(codec.headerType.name, codec.erase.headerType.render(input)), ) private def encodeStatus(inputs: Array[Any]): Option[Status] = diff --git a/zio-http/shared/src/main/scala/zio/http/endpoint/http/HttpGen.scala b/zio-http/shared/src/main/scala/zio/http/endpoint/http/HttpGen.scala index eae8f5aba1..a1b84d1cd8 100644 --- a/zio-http/shared/src/main/scala/zio/http/endpoint/http/HttpGen.scala +++ b/zio-http/shared/src/main/scala/zio/http/endpoint/http/HttpGen.scala @@ -125,38 +125,36 @@ object HttpGen { private def getName(name: Option[String]) = { name.getOrElse(throw new IllegalArgumentException("name is required")) } def headersVariables(inAtoms: AtomizedMetaCodecs): Seq[HttpVariable] = - inAtoms.header.collect { case mc @ MetaCodec(HttpCodec.Header(name, codec, _), _) => + inAtoms.header.collect { case mc @ MetaCodec(HttpCodec.Header(headerType, _), _) => HttpVariable( - name.capitalize, - mc.examples.values.headOption.map(e => codec.asInstanceOf[TextCodec[Any]].encode(e)), + headerType.name.capitalize, + mc.examples.values.headOption.map(e => headerType.render(e.asInstanceOf[headerType.HeaderValue])), ) } def queryVariables(config: CodecConfig, inAtoms: AtomizedMetaCodecs): Seq[HttpVariable] = { inAtoms.query.collect { - case mc @ MetaCodec(HttpCodec.Query(HttpCodec.Query.QueryType.Primitive(name, codec), _), _) => + case mc @ MetaCodec(HttpCodec.Query(codec, _), _) if codec.isPrimitive => HttpVariable( - name, - mc.examples.values.headOption.map((e: Any) => - codec.codec(config).asInstanceOf[BinaryCodec[Any]].encode(e).asString, - ), + codec.name.get, + mc.examples.values.headOption.map((e: Any) => codec.stringCodec.encode(e)), ) :: Nil - case mc @ MetaCodec(HttpCodec.Query(record @ HttpCodec.Query.QueryType.Record(schema), _), _) => - val recordSchema = (schema match { + case mc @ MetaCodec(HttpCodec.Query(codec, _), _) if codec.isRecord => + val recordSchema = (codec.schema match { case value if value.isInstanceOf[Schema.Optional[_]] => value.asInstanceOf[Schema.Optional[Any]].schema - case _ => schema + case _ => codec.schema }).asInstanceOf[Schema.Record[Any]] val examples = mc.examples.values.headOption.map { ex => recordSchema.deconstruct(ex)(Unsafe.unsafe) } - record.fieldAndCodecs.zipWithIndex.map { case ((field, codec), index) => + codec.recordFields.zipWithIndex.map { case ((field, codec), index) => HttpVariable( field.name, examples.map(values => { val fieldValue = values(index) .orElse(field.defaultValue) .getOrElse(throw new Exception(s"No value or default value for field ${field.name}")) - codec.codec(config).encode(fieldValue).asString + codec.stringCodec.encode(fieldValue) }), ) } diff --git a/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/JsonSchema.scala b/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/JsonSchema.scala index 8dced7348f..69ca61f5a0 100644 --- a/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/JsonSchema.scala +++ b/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/JsonSchema.scala @@ -253,6 +253,16 @@ object JsonSchema { .toOption .get + def fromTextCodec(codec: TextCodec[_]): JsonSchema = + codec match { + case TextCodec.Constant(string) => JsonSchema.Enum(Chunk(EnumValue.Str(string))) + case TextCodec.StringCodec => JsonSchema.String() + case TextCodec.IntCodec => JsonSchema.Integer(JsonSchema.IntegerFormat.Int32) + case TextCodec.LongCodec => JsonSchema.Integer(JsonSchema.IntegerFormat.Int64) + case TextCodec.BooleanCodec => JsonSchema.Boolean + case TextCodec.UUIDCodec => JsonSchema.String(JsonSchema.StringFormat.UUID) + } + private[openapi] def fromSerializableSchema(schema: SerializableJsonSchema): JsonSchema = { val definedAttributesCount = schema.productIterator.count(_.asInstanceOf[Option[_]].isDefined) @@ -389,16 +399,6 @@ object JsonSchema { } } - def fromTextCodec(codec: TextCodec[_]): JsonSchema = - codec match { - case TextCodec.Constant(string) => JsonSchema.Enum(Chunk(EnumValue.Str(string))) - case TextCodec.StringCodec => JsonSchema.String() - case TextCodec.IntCodec => JsonSchema.Integer(JsonSchema.IntegerFormat.Int32) - case TextCodec.LongCodec => JsonSchema.Integer(JsonSchema.IntegerFormat.Int64) - case TextCodec.BooleanCodec => JsonSchema.Boolean - case TextCodec.UUIDCodec => JsonSchema.String(JsonSchema.StringFormat.UUID) - } - def fromSegmentCodec(codec: SegmentCodec[_]): JsonSchema = codec match { case SegmentCodec.BoolSeg(_) => JsonSchema.Boolean diff --git a/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/OpenAPIGen.scala b/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/OpenAPIGen.scala index 59ab71fae7..9f3c47e385 100644 --- a/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/OpenAPIGen.scala +++ b/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/OpenAPIGen.scala @@ -16,6 +16,7 @@ import zio.schema.{Schema, TypeId} import zio.http._ import zio.http.codec.HttpCodec.Metadata +import zio.http.codec.HttpCodecType.Content import zio.http.codec._ import zio.http.endpoint._ import zio.http.endpoint.openapi.JsonSchema.SchemaStyle @@ -102,13 +103,15 @@ object OpenAPIGen { path: Chunk[MetaCodec[SegmentCodec[_]]], query: Chunk[MetaCodec[HttpCodec.Query[_, _]]], header: Chunk[MetaCodec[HttpCodec.Header[_]]], - content: Chunk[MetaCodec[HttpCodec.Atom[HttpCodecType.Content, _]]], + content: Chunk[MetaCodec[HttpCodec.Atom[Content, _]]], status: Chunk[MetaCodec[HttpCodec.Status[_]]], + headerCustom: Chunk[MetaCodec[HttpCodec.HeaderCustom[_]]] = Chunk.empty, ) { def append(metaCodec: MetaCodec[_]): AtomizedMetaCodecs = metaCodec match { case MetaCodec(codec: HttpCodec.Method[_], annotations) => - copy(method = - (method :+ MetaCodec(codec.codec, annotations)).asInstanceOf[Chunk[MetaCodec[SimpleCodec[Method, _]]]], + copy( + method = + (method :+ MetaCodec(codec.codec, annotations)).asInstanceOf[Chunk[MetaCodec[SimpleCodec[Method, _]]]], ) case MetaCodec(_: SegmentCodec[_], _) => copy(path = path :+ metaCodec.asInstanceOf[MetaCodec[SegmentCodec[_]]]) @@ -116,6 +119,8 @@ object OpenAPIGen { copy(query = query :+ metaCodec.asInstanceOf[MetaCodec[HttpCodec.Query[_, _]]]) case MetaCodec(_: HttpCodec.Header[_], _) => copy(header = header :+ metaCodec.asInstanceOf[MetaCodec[HttpCodec.Header[_]]]) + case MetaCodec(_: HttpCodec.HeaderCustom[_], _) => + copy(headerCustom = headerCustom :+ metaCodec.asInstanceOf[MetaCodec[HttpCodec.HeaderCustom[_]]]) case MetaCodec(_: HttpCodec.Status[_], _) => copy(status = status :+ metaCodec.asInstanceOf[MetaCodec[HttpCodec.Status[_]]]) case MetaCodec(_: HttpCodec.Content[_], _) => @@ -133,6 +138,7 @@ object OpenAPIGen { header ++ that.header, content ++ that.content, status ++ that.status, + headerCustom ++ that.headerCustom, ) def contentExamples: Map[String, OpenAPI.ReferenceOr.Or[OpenAPI.Example]] = @@ -170,6 +176,7 @@ object OpenAPIGen { header.materialize, content.materialize, status.materialize, + headerCustom.materialize, ) } @@ -181,6 +188,7 @@ object OpenAPIGen { header = Chunk.empty, content = Chunk.empty, status = Chunk.empty, + headerCustom = Chunk.empty, ) def flatten[R, A](codec: HttpCodec[R, A]): AtomizedMetaCodecs = { @@ -752,10 +760,10 @@ object OpenAPIGen { def queryParams: Set[OpenAPI.ReferenceOr[OpenAPI.Parameter]] = { inAtoms.query.collect { - case mc @ MetaCodec(q @ HttpCodec.Query(HttpCodec.Query.QueryType.Primitive(name, codec), _), _) => + case mc @ MetaCodec(q @ HttpCodec.Query(codec, _), _) if codec.isPrimitive => OpenAPI.ReferenceOr.Or( OpenAPI.Parameter.queryParameter( - name = name, + name = q.nameUnsafe, description = mc.docsOpt, schema = Some(OpenAPI.ReferenceOr.Or(JsonSchema.fromZSchema(codec.schema))), deprecated = mc.deprecated, @@ -768,15 +776,15 @@ object OpenAPIGen { required = mc.required && !q.isOptional, ), ) :: Nil - case mc @ MetaCodec(HttpCodec.Query(record @ HttpCodec.Query.QueryType.Record(schema), _), _) => - val recordSchema = (schema match { + case mc @ MetaCodec(HttpCodec.Query(codec, _), _) if codec.isRecord => + val recordSchema = (codec.schema match { case schema if schema.isInstanceOf[Schema.Optional[_]] => schema.asInstanceOf[Schema.Optional[_]].schema - case _ => schema + case _ => codec.schema }).asInstanceOf[Schema.Record[Any]] val examples = mc.examples.map { case (exName, ex) => exName -> recordSchema.deconstruct(ex)(Unsafe.unsafe) } - record.fieldAndCodecs.zipWithIndex.map { case ((field, codec), index) => + codec.recordFields.zipWithIndex.map { case ((field, codec), index) => OpenAPI.ReferenceOr.Or( OpenAPI.Parameter.queryParameter( name = field.name, @@ -793,9 +801,7 @@ object OpenAPIGen { throw new Exception(s"No value or default value found for field ${exName}_${field.name}"), ) s"${exName}_${field.name}" -> OpenAPI.ReferenceOr.Or( - OpenAPI.Example(value = - Json.Str(codec.codec(CodecConfig.defaultConfig).encode(fieldValue).asString), - ), + OpenAPI.Example(value = Json.Str(codec.stringCodec.encode(fieldValue))), ) }, required = mc.required, @@ -803,22 +809,22 @@ object OpenAPIGen { ) } - case mc @ MetaCodec( - HttpCodec.Query( - HttpCodec.Query.QueryType.Collection( - _, - HttpCodec.Query.QueryType.Primitive(name, codec), - optional, - ), - _, - ), - _, - ) => + case mc @ MetaCodec(q @ HttpCodec.Query(codec, _), _) if codec.isCollection => + var required = false + val schema = codec.schema.asInstanceOf[Schema.Collection[_, _]] match { + case s: Schema.Sequence[_, _, _] => s.elementSchema + case _: Schema.Map[_, _] => throw new Exception("Map query parameters not supported") + case _: Schema.NonEmptyMap[_, _] => throw new Exception("Map query parameters not supported") + case s: Schema.NonEmptySequence[_, _, _] => + required = true + s.elementSchema + case s: Schema.Set[_] => s.elementSchema + } OpenAPI.ReferenceOr.Or( OpenAPI.Parameter.queryParameter( - name = name, + name = q.nameUnsafe, description = mc.docsOpt, - schema = Some(OpenAPI.ReferenceOr.Or(JsonSchema.fromZSchema(codec.schema))), + schema = Some(OpenAPI.ReferenceOr.Or(JsonSchema.fromZSchema(schema))), deprecated = mc.deprecated, style = OpenAPI.Parameter.Style.Form, explode = false, @@ -826,7 +832,7 @@ object OpenAPIGen { examples = mc.examples.map { case (exName, value) => exName -> OpenAPI.ReferenceOr.Or(OpenAPI.Example(value = Json.Str(value.toString))) }, - required = mc.required && !optional, + required = required, ), ) :: Nil } @@ -855,19 +861,35 @@ object OpenAPIGen { .map { case mc @ MetaCodec(codec, _) => OpenAPI.ReferenceOr.Or( OpenAPI.Parameter.headerParameter( - name = mc.name.getOrElse(codec.name), + name = mc.name.getOrElse(codec.headerType.name), + description = mc.docsOpt, + definition = Some(OpenAPI.ReferenceOr.Or(JsonSchema.String().nullable(!mc.required))), + deprecated = mc.deprecated, + examples = mc.examples.map { case (name, value) => + name -> OpenAPI.ReferenceOr.Or(OpenAPI.Example(codec.headerType.render(value).toJsonAST.toOption.get)) + }, + required = mc.required, + ), + ) + } + .toSet ++ inAtoms.headerCustom + .asInstanceOf[Chunk[MetaCodec[HttpCodec.HeaderCustom[Any]]]] + // todo must handle collection and record + .map { case mc @ MetaCodec(codec, _) => + OpenAPI.ReferenceOr.Or( + OpenAPI.Parameter.headerParameter( + name = codec.codec.name.getOrElse(throw new Exception("Header parameter must have a name")), description = mc.docsOpt, - definition = - Some(OpenAPI.ReferenceOr.Or(JsonSchema.fromTextCodec(codec.textCodec).nullable(!mc.required))), + definition = Some(OpenAPI.ReferenceOr.Or(JsonSchema.String().nullable(!mc.required))), deprecated = mc.deprecated, examples = mc.examples.map { case (name, value) => - name -> OpenAPI.ReferenceOr.Or(OpenAPI.Example(codec.textCodec.encode(value).toJsonAST.toOption.get)) + name -> OpenAPI.ReferenceOr + .Or(OpenAPI.Example(codec.codec.stringCodec.encode(value).toJsonAST.toOption.get)) }, required = mc.required, ), ) } - .toSet def genDiscriminator(schema: Schema[_]): Option[OpenAPI.Discriminator] = { schema match { @@ -1133,13 +1155,13 @@ object OpenAPIGen { private def headersFrom(codec: AtomizedMetaCodecs) = { codec.header.map { case mc @ MetaCodec(codec, _) => - codec.name -> OpenAPI.ReferenceOr.Or( + codec.headerType.name -> OpenAPI.ReferenceOr.Or( OpenAPI.Header( description = mc.docsOpt, required = true, deprecated = mc.deprecated, allowEmptyValue = false, - schema = Some(JsonSchema.fromTextCodec(codec.textCodec)), + schema = Some(JsonSchema.String().nullable(!mc.required)), ), ) }.toMap diff --git a/zio-http/shared/src/main/scala/zio/http/internal/HeaderGetters.scala b/zio-http/shared/src/main/scala/zio/http/internal/HeaderGetters.scala index 8b6ad2b167..1a9841c329 100644 --- a/zio-http/shared/src/main/scala/zio/http/internal/HeaderGetters.scala +++ b/zio-http/shared/src/main/scala/zio/http/internal/HeaderGetters.scala @@ -67,6 +67,13 @@ trait HeaderGetters { self => /** Gets the raw unparsed header value */ final def rawHeader(name: CharSequence): Option[String] = headers.get(name) + final def rawHeaders(name: CharSequence): Chunk[String] = + Chunk.fromIterator( + headers.iterator + .filter(header => CharSequenceExtensions.equals(header.headerNameAsCharSequence, name, CaseMode.Insensitive)) + .map(_.renderedValue), + ) + /** Gets the raw unparsed header value */ final def rawHeader(headerType: HeaderType): Option[String] = rawHeader(headerType.name) From 14e54e0cbfcb1c06ff856979accec1d6f293fbb6 Mon Sep 17 00:00:00 2001 From: Nabil Abdel-Hafeez <7283535+987Nabil@users.noreply.github.com> Date: Fri, 31 Jan 2025 22:27:52 +0100 Subject: [PATCH 2/2] Simplify schema based header codecs (#3232) --- .github/workflows/ci.yml | 4 + .scalafmt.conf | 2 +- project/Dependencies.scala | 4 +- .../zio/http/endpoint/cli/CliEndpoint.scala | 19 +- .../zio/http/endpoint/cli/HttpOptions.scala | 12 +- .../zio/http/endpoint/cli/CommandGen.scala | 10 +- .../zio/http/endpoint/cli/EndpointGen.scala | 8 +- .../zio/http/endpoint/cli/OptionsGen.scala | 14 +- .../WebSocketReconnectingClient.scala | 4 +- .../websocket/WebSocketServerAdvanced.scala | 12 +- .../websocket/WebSocketSimpleClient.scala | 4 +- .../zio/http/gen/openapi/EndpointGen.scala | 22 +- .../zio/http/gen/scala/CodeGenSpec.scala | 13 +- .../zio/http/netty/model/Conversions.scala | 11 +- .../http/LogAnnotationMiddlewareSpec.scala | 5 +- .../test/scala/zio/http/WebSocketSpec.scala | 2 +- .../zio/http/endpoint/EndpointSpec.scala | 2 + .../endpoint/openapi/OpenAPIGenSpec.scala | 62 +- .../zio/http/security/TimingAttacksSpec.scala | 8 +- .../scala-2/zio/http/UrlInterpolator.scala | 2 +- .../main/scala/zio/http/ClientSSLConfig.scala | 4 +- .../src/main/scala/zio/http/Header.scala | 134 +++- .../src/main/scala/zio/http/Headers.scala | 4 + .../src/main/scala/zio/http/MediaTypes.scala | 2 +- .../src/main/scala/zio/http/ZClient.scala | 2 +- .../scala/zio/http/codec/HeaderCodecs.scala | 8 +- .../main/scala/zio/http/codec/HttpCodec.scala | 275 +------ .../scala/zio/http/codec/HttpCodecError.scala | 12 +- .../scala/zio/http/codec/QueryCodecs.scala | 71 +- .../scala/zio/http/codec/StringCodec.scala | 394 ---------- .../zio/http/codec/internal/Atomized.scala | 29 +- .../http/codec/internal/AtomizedCodecs.scala | 9 +- .../http/codec/internal/EncoderDecoder.scala | 438 +---------- .../zio/http/endpoint/http/HttpGen.scala | 52 +- .../http/endpoint/openapi/OpenAPIGen.scala | 138 +--- .../zio/http/internal/HeaderGetters.scala | 2 +- .../zio/http/internal/HeaderModifier.scala | 3 + .../zio/http/internal/QueryModifier.scala | 7 + .../zio/http/internal/StringSchemaCodec.scala | 723 ++++++++++++++++++ 39 files changed, 1118 insertions(+), 1409 deletions(-) delete mode 100644 zio-http/shared/src/main/scala/zio/http/codec/StringCodec.scala create mode 100644 zio-http/shared/src/main/scala/zio/http/internal/StringSchemaCodec.scala diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9aef2786a7..14d19e071b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -200,6 +200,10 @@ jobs: with: apps: sbt + - uses: coursier/setup-action@v1 + with: + apps: sbt + - name: Release env: PGP_PASSPHRASE: ${{ secrets.PGP_PASSPHRASE }} diff --git a/.scalafmt.conf b/.scalafmt.conf index cac439b8f2..6cdcd8efbe 100644 --- a/.scalafmt.conf +++ b/.scalafmt.conf @@ -1,4 +1,4 @@ -version = 3.8.1 +version = 3.8.6 maxColumn = 120 align.preset = more diff --git a/project/Dependencies.scala b/project/Dependencies.scala index bcc954795d..853bb67fcd 100644 --- a/project/Dependencies.scala +++ b/project/Dependencies.scala @@ -16,8 +16,8 @@ object Dependencies { val `jwt-core` = "com.github.jwt-scala" %% "jwt-core" % JwtCoreVersion val `scala-compact-collection` = "org.scala-lang.modules" %% "scala-collection-compat" % ScalaCompactCollectionVersion - val scalafmt = "org.scalameta" %% "scalafmt-dynamic" % "3.8.1" - val scalametaParsers = "org.scalameta" %% "parsers" % "4.9.9" + val scalafmt = "org.scalameta" %% "scalafmt-dynamic" % "3.8.6" + val scalametaParsers = "org.scalameta" %% "parsers" % "4.12.7" val netty = Seq( diff --git a/zio-http-cli/src/main/scala/zio/http/endpoint/cli/CliEndpoint.scala b/zio-http-cli/src/main/scala/zio/http/endpoint/cli/CliEndpoint.scala index 3228b85c32..2985811983 100644 --- a/zio-http-cli/src/main/scala/zio/http/endpoint/cli/CliEndpoint.scala +++ b/zio-http-cli/src/main/scala/zio/http/endpoint/cli/CliEndpoint.scala @@ -111,11 +111,9 @@ private[cli] object CliEndpoint { } CliEndpoint(body = HttpOptions.Body(name, codec.defaultMediaType, codec.defaultSchema) :: List()) - case HttpCodec.Header(headerType, _) => - CliEndpoint(headers = HttpOptions.Header(headerType.name, TextCodec.string) :: List()) - case HttpCodec.HeaderCustom(codec, _) => - CliEndpoint(headers = HttpOptions.Header(codec.name.get, TextCodec.string) :: List()) - case HttpCodec.Method(codec, _) => + case HttpCodec.Header(headerType, _) => + CliEndpoint(headers = HttpOptions.Header(headerType.names.head, TextCodec.string) :: List()) + case HttpCodec.Method(codec, _) => codec.asInstanceOf[SimpleCodec[_, _]] match { case SimpleCodec.Specified(method: Method) => CliEndpoint(methods = method) @@ -126,14 +124,9 @@ private[cli] object CliEndpoint { CliEndpoint(url = HttpOptions.Path(pathCodec) :: List()) case HttpCodec.Query(codec, _) => - if (codec.isPrimitive) - CliEndpoint(url = HttpOptions.Query(codec) :: List()) - else if (codec.isRecord) - CliEndpoint(url = codec.recordFields.map { case (_, codec) => - HttpOptions.Query(codec) - }.toList) - else - CliEndpoint(url = HttpOptions.Query(codec) :: List()) + CliEndpoint(url = codec.recordFields.map { case (f, codec) => + HttpOptions.Query(codec, f.fieldName) + }.toList) case HttpCodec.Status(_, _) => CliEndpoint.empty } diff --git a/zio-http-cli/src/main/scala/zio/http/endpoint/cli/HttpOptions.scala b/zio-http-cli/src/main/scala/zio/http/endpoint/cli/HttpOptions.scala index 2abb5704b7..7cbe6ca565 100644 --- a/zio-http-cli/src/main/scala/zio/http/endpoint/cli/HttpOptions.scala +++ b/zio-http-cli/src/main/scala/zio/http/endpoint/cli/HttpOptions.scala @@ -11,8 +11,9 @@ import zio.schema._ import zio.schema.annotation.description import zio.http._ -import zio.http.codec.HttpCodec.SchemaCodec import zio.http.codec._ +import zio.http.internal.StringSchemaCodec +import zio.http.internal.StringSchemaCodec.PrimitiveCodec /* * HttpOptions is a wrapper of a transformation Options[CliRequest] => Options[CliRequest]. @@ -265,11 +266,10 @@ private[cli] object HttpOptions { } - final case class Query(codec: SchemaCodec[_], doc: Doc = Doc.empty) extends URLOptions { + final case class Query(codec: PrimitiveCodec[_], name: String, doc: Doc = Doc.empty) extends URLOptions { self => - override val name = codec.name.get override val tag = "?" + name - def options: Options[_] = optionsFromSchema(codec)(name) + def options: Options[_] = optionsFromSchema(codec.schema)(name) override def ??(doc: Doc): Query = self.copy(doc = self.doc + doc) @@ -293,8 +293,8 @@ private[cli] object HttpOptions { } - private[cli] def optionsFromSchema[A](codec: SchemaCodec[A]): String => Options[A] = - codec.schema match { + private[cli] def optionsFromSchema[A](schema: Schema[A]): String => Options[A] = + schema match { case Schema.Primitive(standardType, _) => standardType match { case StandardType.UnitType => diff --git a/zio-http-cli/src/test/scala/zio/http/endpoint/cli/CommandGen.scala b/zio-http-cli/src/test/scala/zio/http/endpoint/cli/CommandGen.scala index 9d84e04b9f..c724eaa218 100644 --- a/zio-http-cli/src/test/scala/zio/http/endpoint/cli/CommandGen.scala +++ b/zio-http-cli/src/test/scala/zio/http/endpoint/cli/CommandGen.scala @@ -47,7 +47,7 @@ object CommandGen { case _: HttpOptions.Constant => false case _ => true }.map { - case HttpOptions.Path(pathCodec, _) => + case HttpOptions.Path(pathCodec, _) => pathCodec.segments.toList.flatMap { segment => getSegment(segment) match { case (_, "") => Nil @@ -55,12 +55,12 @@ object CommandGen { case (name, codec) => s"${getName(name, "")} $codec" :: Nil } } - case HttpOptions.Query(codec, _) if codec.isPrimitive => + case HttpOptions.Query(codec, name, _) => getType(codec.schema) match { - case "" => s"[${getName(codec.name.get, "")}]" :: Nil - case tpy => s"${getName(codec.name.get, "")} $tpy" :: Nil + case "" => s"[${getName(name, "")}]" :: Nil + case tpy => s"${getName(name, "")} $tpy" :: Nil } - case _ => Nil + case _ => Nil }.foldRight(List[String]())(_ ++ _) val headersOptions = cliEndpoint.headers.filter { diff --git a/zio-http-cli/src/test/scala/zio/http/endpoint/cli/EndpointGen.scala b/zio-http-cli/src/test/scala/zio/http/endpoint/cli/EndpointGen.scala index 792cbdb2f7..7f34ae9adc 100644 --- a/zio-http-cli/src/test/scala/zio/http/endpoint/cli/EndpointGen.scala +++ b/zio-http-cli/src/test/scala/zio/http/endpoint/cli/EndpointGen.scala @@ -5,9 +5,7 @@ import zio.test._ import zio.schema.Schema -import zio.http.Header.HeaderType import zio.http._ -import zio.http.codec.HttpCodec.SchemaCodec import zio.http.codec._ import zio.http.endpoint._ import zio.http.endpoint.cli.AuxGen._ @@ -103,10 +101,10 @@ object EndpointGen { lazy val anyQuery: Gen[Any, CliReprOf[Codec[_]]] = Gen.alphaNumericStringBounded(1, 30).zip(anyStandardType).map { case (name, schema0) => val schema = schema0.asInstanceOf[Schema[Any]] - val codec = SchemaCodec(Some(name), schema) + val codec = QueryCodec.query(name)(schema).asInstanceOf[HttpCodec.Query[Any]] CliRepr( - HttpCodec.Query(codec), - CliEndpoint(url = HttpOptions.Query(codec) :: Nil), + codec, + CliEndpoint(url = HttpOptions.Query(codec.codec.recordFields.head._2, name) :: Nil), ) } diff --git a/zio-http-cli/src/test/scala/zio/http/endpoint/cli/OptionsGen.scala b/zio-http-cli/src/test/scala/zio/http/endpoint/cli/OptionsGen.scala index 1cb6016f4b..37350d29fe 100644 --- a/zio-http-cli/src/test/scala/zio/http/endpoint/cli/OptionsGen.scala +++ b/zio-http-cli/src/test/scala/zio/http/endpoint/cli/OptionsGen.scala @@ -7,10 +7,10 @@ import zio.test.Gen import zio.schema.Schema import zio.http._ -import zio.http.codec.HttpCodec.SchemaCodec import zio.http.codec._ import zio.http.endpoint.cli.AuxGen._ import zio.http.endpoint.cli.CliRepr._ +import zio.http.internal.StringSchemaCodec.PrimitiveCodec /** * Constructs a Gen[Options[CliRequest], CliEndpoint] @@ -33,10 +33,10 @@ object OptionsGen { .optionsFromTextCodec(textCodec)(name) .map(value => textCodec.encode(value)) - def encodeOptions[A](name: String, codec: SchemaCodec[A]): Options[String] = + def encodeOptions[A](name: String, codec: PrimitiveCodec[A], schema: Schema[A]): Options[String] = HttpOptions - .optionsFromSchema(codec)(name) - .map(value => codec.stringCodec.encode(value)) + .optionsFromSchema(schema)(name) + .map(value => codec.encode(value)) lazy val anyBodyOption: Gen[Any, CliReprOf[Options[Retriever]]] = Gen @@ -80,10 +80,10 @@ object OptionsGen { .alphaNumericStringBounded(1, 30) .zip(anyStandardType) .map { case (name, schema) => - val codec = SchemaCodec(Some(name), schema) + val codec = QueryCodec.query(name)(schema).asInstanceOf[HttpCodec.Query[Any]] CliRepr( - encodeOptions(name, codec), - CliEndpoint(url = HttpOptions.Query(codec) :: Nil), + encodeOptions(name, codec.codec.recordFields.head._2, schema.asInstanceOf[Schema[Any]]), + CliEndpoint(url = HttpOptions.Query(codec.codec.recordFields.head._2, name) :: Nil), ) }, ) diff --git a/zio-http-example/src/main/scala/example/websocket/WebSocketReconnectingClient.scala b/zio-http-example/src/main/scala/example/websocket/WebSocketReconnectingClient.scala index 599ac554ec..cecddbc6bb 100644 --- a/zio-http-example/src/main/scala/example/websocket/WebSocketReconnectingClient.scala +++ b/zio-http-example/src/main/scala/example/websocket/WebSocketReconnectingClient.scala @@ -22,13 +22,13 @@ object WebSocketReconnectingClient extends ZIOAppDefault { channel.send(ChannelEvent.Read(WebSocketFrame.text("foo"))) // On receiving "foo", we'll reply with another "foo" to keep echo loop going - case Read(WebSocketFrame.Text("foo")) => + case Read(WebSocketFrame.Text("foo")) => ZIO.logInfo("Received foo message.") *> ZIO.sleep(1.second) *> channel.send(ChannelEvent.Read(WebSocketFrame.text("foo"))) // Handle exception and convert it to failure to signal the shutdown of the socket connection via the promise - case ExceptionCaught(t) => + case ExceptionCaught(t) => ZIO.fail(t) case _ => diff --git a/zio-http-example/src/main/scala/example/websocket/WebSocketServerAdvanced.scala b/zio-http-example/src/main/scala/example/websocket/WebSocketServerAdvanced.scala index 19a98f3437..c7f08d288a 100644 --- a/zio-http-example/src/main/scala/example/websocket/WebSocketServerAdvanced.scala +++ b/zio-http-example/src/main/scala/example/websocket/WebSocketServerAdvanced.scala @@ -13,19 +13,19 @@ object WebSocketServerAdvanced extends ZIOAppDefault { val socketApp: WebSocketApp[Any] = Handler.webSocket { channel => channel.receiveAll { - case Read(WebSocketFrame.Text("end")) => + case Read(WebSocketFrame.Text("end")) => channel.shutdown // Send a "bar" if the client sends a "foo" - case Read(WebSocketFrame.Text("foo")) => + case Read(WebSocketFrame.Text("foo")) => channel.send(Read(WebSocketFrame.text("bar"))) // Send a "foo" if the client sends a "bar" - case Read(WebSocketFrame.Text("bar")) => + case Read(WebSocketFrame.Text("bar")) => channel.send(Read(WebSocketFrame.text("foo"))) // Echo the same message 10 times if it's not "foo" or "bar" - case Read(WebSocketFrame.Text(text)) => + case Read(WebSocketFrame.Text(text)) => channel .send(Read(WebSocketFrame.text(s"echo $text"))) .repeatN(10) @@ -38,11 +38,11 @@ object WebSocketServerAdvanced extends ZIOAppDefault { channel.send(Read(WebSocketFrame.text("Greetings!"))) // Log when the channel is getting closed - case Read(WebSocketFrame.Close(status, reason)) => + case Read(WebSocketFrame.Close(status, reason)) => Console.printLine("Closing channel with status: " + status + " and reason: " + reason) // Print the exception if it's not a normal close - case ExceptionCaught(cause) => + case ExceptionCaught(cause) => Console.printLine(s"Channel error!: ${cause.getMessage}") case _ => diff --git a/zio-http-example/src/main/scala/example/websocket/WebSocketSimpleClient.scala b/zio-http-example/src/main/scala/example/websocket/WebSocketSimpleClient.scala index 798935e6c1..0f6e085592 100644 --- a/zio-http-example/src/main/scala/example/websocket/WebSocketSimpleClient.scala +++ b/zio-http-example/src/main/scala/example/websocket/WebSocketSimpleClient.scala @@ -21,11 +21,11 @@ object WebSocketSimpleClient extends ZIOAppDefault { channel.send(Read(WebSocketFrame.text("foo"))) // Send a "bar" if the server sends a "foo" - case Read(WebSocketFrame.Text("foo")) => + case Read(WebSocketFrame.Text("foo")) => channel.send(Read(WebSocketFrame.text("bar"))) // Close the connection if the server sends a "bar" - case Read(WebSocketFrame.Text("bar")) => + case Read(WebSocketFrame.Text("bar")) => ZIO.succeed(println("Goodbye!")) *> channel.send(Read(WebSocketFrame.close(1000))) case _ => diff --git a/zio-http-gen/src/main/scala/zio/http/gen/openapi/EndpointGen.scala b/zio-http-gen/src/main/scala/zio/http/gen/openapi/EndpointGen.scala index d913d9e291..26bc00630d 100644 --- a/zio-http-gen/src/main/scala/zio/http/gen/openapi/EndpointGen.scala +++ b/zio-http-gen/src/main/scala/zio/http/gen/openapi/EndpointGen.scala @@ -823,9 +823,9 @@ final case class EndpointGen(config: Config) { * `transform` that simply `wrap` / `unwrap` the provided value. */ case JsonSchema.Boolean => aliasedSchemaToCode(openAPI, name, schema) - case JsonSchema.OneOfSchema(schemas) if schemas.exists(_.isPrimitive) => + case JsonSchema.OneOfSchema(schemas) if schemas.exists(_.isPrimitive) => throw new Exception("OneOf schemas with primitive types are not supported") - case JsonSchema.OneOfSchema(schemas) => + case JsonSchema.OneOfSchema(schemas) => val discriminatorInfo = annotations.collectFirst { case JsonSchema.MetaData.Discriminator(discriminator) => discriminator } val discriminator: Option[String] = discriminatorInfo.map(_.propertyName) @@ -885,7 +885,7 @@ final case class EndpointGen(config: Config) { ), ), ) - case JsonSchema.AllOfSchema(schemas) => + case JsonSchema.AllOfSchema(schemas) => val genericFieldIndex = Iterator.from(0) val unvalidatedFields = schemas.toList.map(_.withoutAnnotations).flatMap { case schema @ JsonSchema.Object(_, _, _) => @@ -928,9 +928,9 @@ final case class EndpointGen(config: Config) { enums = Nil, ), ) - case JsonSchema.AnyOfSchema(schemas) if schemas.exists(_.isPrimitive) => + case JsonSchema.AnyOfSchema(schemas) if schemas.exists(_.isPrimitive) => throw new Exception("AnyOf schemas with primitive types are not supported") - case JsonSchema.AnyOfSchema(schemas) => + case JsonSchema.AnyOfSchema(schemas) => val discriminatorInfo = annotations.collectFirst { case JsonSchema.MetaData.Discriminator(discriminator) => discriminator } val discriminator: Option[String] = discriminatorInfo.map(_.propertyName) @@ -987,12 +987,12 @@ final case class EndpointGen(config: Config) { ), ), ) - case JsonSchema.Number(_, _, _, _, _, _) => aliasedSchemaToCode(openAPI, name, schema) + case JsonSchema.Number(_, _, _, _, _, _) => aliasedSchemaToCode(openAPI, name, schema) // should we provide support for (Newtype) aliasing arrays of primitives? - case JsonSchema.ArrayType(None, _, _) => None - case JsonSchema.ArrayType(Some(schema), _, _) => + case JsonSchema.ArrayType(None, _, _) => None + case JsonSchema.ArrayType(Some(schema), _, _) => schemaToCode(schema, openAPI, name, annotations) - case obj: JsonSchema.Object if obj.isInvalid => + case obj: JsonSchema.Object if obj.isInvalid => throw new Exception("Object with properties and additionalProperties is not supported") case obj @ JsonSchema.Object(properties, _, _) if obj.isClosedDictionary => val unvalidatedFields = fieldsOfObject(openAPI, annotations)(obj) @@ -1052,8 +1052,8 @@ final case class EndpointGen(config: Config) { ), ), ) - case JsonSchema.Null => throw new Exception("Null query parameters are not supported") - case JsonSchema.AnyJson => None + case JsonSchema.Null => throw new Exception("Null query parameters are not supported") + case JsonSchema.AnyJson => None } } diff --git a/zio-http-gen/src/test/scala/zio/http/gen/scala/CodeGenSpec.scala b/zio-http-gen/src/test/scala/zio/http/gen/scala/CodeGenSpec.scala index ee50ec1c81..71e34f4c7a 100644 --- a/zio-http-gen/src/test/scala/zio/http/gen/scala/CodeGenSpec.scala +++ b/zio-http-gen/src/test/scala/zio/http/gen/scala/CodeGenSpec.scala @@ -32,12 +32,11 @@ import zio.http.gen.openapi.{Config, EndpointGen} object CodeGenSpec extends ZIOSpecDefault { case class ValidatedData( - @validate(Validation.maxLength(10)) - name: String, - @validate(Validation.greaterThan(0) && Validation.lessThan(100)) - age: Int, + @validate(Validation.maxLength(10)) name: String, + @validate(Validation.greaterThan(0) && Validation.lessThan(100)) age: Int, ) - implicit val validatedDataSchema: Schema[ValidatedData] = DeriveSchema.gen[ValidatedData] + implicit val validatedDataSchema: Schema[ValidatedData] = + DeriveSchema.gen[ValidatedData] private def fileShouldBe(dir: java.nio.file.Path, subPath: String, expectedFile: String): TestResult = { val filePath = dir.resolve(Paths.get(subPath)) @@ -156,7 +155,8 @@ object CodeGenSpec extends ZIOSpecDefault { .header(HeaderCodec.accept) .header(HeaderCodec.contentType) .header(HeaderCodec.headerAs[String]("Token")) - val openAPI = OpenAPIGen.fromEndpoints(endpoint) + + val openAPI = OpenAPIGen.fromEndpoints(endpoint) codeGenFromOpenAPI(openAPI) { testDir => fileShouldBe(testDir, "api/v1/Users.scala", "/EndpointWithHeaders.scala") @@ -605,6 +605,7 @@ object CodeGenSpec extends ZIOSpecDefault { } } } @@ TestAspect.exceptScala3, // for some reason, the temp dir is empty in Scala 3 + //format: off test("Endpoint with array field in input") { val endpoint = Endpoint(Method.POST / "api" / "v1" / "users").in[UserNameArray].out[User] val openAPI = OpenAPIGen.fromEndpoints("", "", endpoint) diff --git a/zio-http/jvm/src/main/scala/zio/http/netty/model/Conversions.scala b/zio-http/jvm/src/main/scala/zio/http/netty/model/Conversions.scala index 374c9ba27f..2874de2c63 100644 --- a/zio-http/jvm/src/main/scala/zio/http/netty/model/Conversions.scala +++ b/zio-http/jvm/src/main/scala/zio/http/netty/model/Conversions.scala @@ -18,6 +18,8 @@ package zio.http.netty.model import scala.collection.AbstractIterator +import zio.Chunk + import zio.http.Server.Config.CompressionOptions import zio.http._ @@ -58,10 +60,10 @@ private[netty] object Conversions { def headersToNetty(headers: Headers): HttpHeaders = headers match { - case Headers.FromIterable(_) => encodeHeaderListToNetty(headers) - case Headers.Native(value, _, _, _) => value.asInstanceOf[HttpHeaders] - case Headers.Concat(_, _) => encodeHeaderListToNetty(headers) - case Headers.Empty => new DefaultHttpHeaders() + case Headers.FromIterable(_) => encodeHeaderListToNetty(headers) + case Headers.Native(value, _, _, _, _) => value.asInstanceOf[HttpHeaders] + case Headers.Concat(_, _) => encodeHeaderListToNetty(headers) + case Headers.Empty => new DefaultHttpHeaders() } def urlToNetty(url: URL): String = { @@ -89,6 +91,7 @@ private[netty] object Conversions { (headers: HttpHeaders) => nettyHeadersIterator(headers), // NOTE: Netty's headers.get is case-insensitive (headers: HttpHeaders, key: CharSequence) => headers.get(key), + (headers: HttpHeaders, key: CharSequence) => Chunk.fromJavaIterable(headers.getAll(key)), (headers: HttpHeaders, key: CharSequence) => headers.contains(key), ) diff --git a/zio-http/jvm/src/test/scala/zio/http/LogAnnotationMiddlewareSpec.scala b/zio-http/jvm/src/test/scala/zio/http/LogAnnotationMiddlewareSpec.scala index 51775c944a..6bf212e7dd 100644 --- a/zio-http/jvm/src/test/scala/zio/http/LogAnnotationMiddlewareSpec.scala +++ b/zio-http/jvm/src/test/scala/zio/http/LogAnnotationMiddlewareSpec.scala @@ -30,9 +30,8 @@ object LogAnnotationMiddlewareSpec extends ZIOSpecDefault { handler(ZIO.logWarning("Oh!") *> ZIO.succeed(Response.text("Hey logging!"))), ) .@@( - Middleware.logAnnotate(req => - Set(LogAnnotation("method", req.method.name), LogAnnotation("path", req.path.encode)), - ), + Middleware + .logAnnotate(req => Set(LogAnnotation("method", req.method.name), LogAnnotation("path", req.path.encode))), ) .runZIO(Request.get("/")) diff --git a/zio-http/jvm/src/test/scala/zio/http/WebSocketSpec.scala b/zio-http/jvm/src/test/scala/zio/http/WebSocketSpec.scala index 87c7d335e4..cc3e222b6c 100644 --- a/zio-http/jvm/src/test/scala/zio/http/WebSocketSpec.scala +++ b/zio-http/jvm/src/test/scala/zio/http/WebSocketSpec.scala @@ -79,7 +79,7 @@ object WebSocketSpec extends RoutesRunnableSpec { // Setup websocket server - serverHttp = Handler.webSocket { channel => + serverHttp = Handler.webSocket { channel => channel.receiveAll { case Unregistered => isStarted.succeed(()) <&> isSet.succeed(()).delay(5 seconds).withClock(clock) diff --git a/zio-http/jvm/src/test/scala/zio/http/endpoint/EndpointSpec.scala b/zio-http/jvm/src/test/scala/zio/http/endpoint/EndpointSpec.scala index a2c6334e62..3a093669bb 100644 --- a/zio-http/jvm/src/test/scala/zio/http/endpoint/EndpointSpec.scala +++ b/zio-http/jvm/src/test/scala/zio/http/endpoint/EndpointSpec.scala @@ -18,6 +18,8 @@ package zio.http.endpoint import java.time.Instant +import scala.math.BigDecimal.javaBigDecimal2bigDecimal + import zio._ import zio.test._ diff --git a/zio-http/jvm/src/test/scala/zio/http/endpoint/openapi/OpenAPIGenSpec.scala b/zio-http/jvm/src/test/scala/zio/http/endpoint/openapi/OpenAPIGenSpec.scala index f89203bd13..49023d128e 100644 --- a/zio-http/jvm/src/test/scala/zio/http/endpoint/openapi/OpenAPIGenSpec.scala +++ b/zio-http/jvm/src/test/scala/zio/http/endpoint/openapi/OpenAPIGenSpec.scala @@ -537,27 +537,20 @@ object OpenAPIGenSpec extends ZIOSpecDefault { | "/withQuery" : { | "get" : { | "parameters" : [ - | - | { + | { | "name" : "query", | "in" : "query", - | "schema" : - | { - | "type" :[ - | "string", - | "null" - | ] + | "schema" : { + | "type" : "string" | }, | "allowReserved" : false, | "style" : "form" | } | ], - | "requestBody" : - | { + | "requestBody" : { | "content" : { | "application/json" : { - | "schema" : - | { + | "schema" : { | "$ref" : "#/components/schemas/SimpleInputBody" | } | } @@ -565,23 +558,19 @@ object OpenAPIGenSpec extends ZIOSpecDefault { | "required" : true | }, | "responses" : { - | "200" : - | { + | "200" : { | "content" : { | "application/json" : { - | "schema" : - | { + | "schema" : { | "$ref" : "#/components/schemas/SimpleOutputBody" | } | } | } | }, - | "404" : - | { + | "404" : { | "content" : { | "application/json" : { - | "schema" : - | { + | "schema" : { | "$ref" : "#/components/schemas/NotFoundError" | } | } @@ -593,32 +582,25 @@ object OpenAPIGenSpec extends ZIOSpecDefault { | }, | "components" : { | "schemas" : { - | "NotFoundError" : - | { - | "type" : - | "object", + | "NotFoundError" : { + | "type" : "object", | "properties" : { | "message" : { - | "type" : - | "string" + | "type" : "string" | } | }, | "required" : [ | "message" | ] | }, - | "SimpleInputBody" : - | { - | "type" : - | "object", + | "SimpleInputBody" : { + | "type" : "object", | "properties" : { | "name" : { - | "type" : - | "string" + | "type" : "string" | }, | "age" : { - | "type" : - | "integer", + | "type" : "integer", | "format" : "int32" | } | }, @@ -627,18 +609,14 @@ object OpenAPIGenSpec extends ZIOSpecDefault { | "age" | ] | }, - | "SimpleOutputBody" : - | { - | "type" : - | "object", + | "SimpleOutputBody" : { + | "type" : "object", | "properties" : { | "userName" : { - | "type" : - | "string" + | "type" : "string" | }, | "score" : { - | "type" : - | "integer", + | "type" : "integer", | "format" : "int32" | } | }, diff --git a/zio-http/jvm/src/test/scala/zio/http/security/TimingAttacksSpec.scala b/zio-http/jvm/src/test/scala/zio/http/security/TimingAttacksSpec.scala index f4466f69c6..9afc70d1b8 100644 --- a/zio-http/jvm/src/test/scala/zio/http/security/TimingAttacksSpec.scala +++ b/zio-http/jvm/src/test/scala/zio/http/security/TimingAttacksSpec.scala @@ -49,10 +49,10 @@ object TimingAttacksSpec extends ZIOSpecDefault { val a = java.lang.System.nanoTime sampleUnsorted = (a - b) :: sampleUnsorted } - val sample = sampleUnsorted.sorted - val tail = sample.drop((nOfTries * i).round.toInt) - val low = tail.head - val high = tail.drop((nOfTries * (j - i)).round.toInt).head + val sample = sampleUnsorted.sorted + val tail = sample.drop((nOfTries * i).round.toInt) + val low = tail.head + val high = tail.drop((nOfTries * (j - i)).round.toInt).head (low, high) } diff --git a/zio-http/shared/src/main/scala-2/zio/http/UrlInterpolator.scala b/zio-http/shared/src/main/scala-2/zio/http/UrlInterpolator.scala index c58fa33f4c..415042e94d 100644 --- a/zio-http/shared/src/main/scala-2/zio/http/UrlInterpolator.scala +++ b/zio-http/shared/src/main/scala-2/zio/http/UrlInterpolator.scala @@ -66,7 +66,7 @@ private[http] object UrlInterpolatorMacro { } } val exampleParts = staticParts.zipAll(injectedPartExamples, "", "").flatMap { case (a, b) => List(a, b) } - val example = exampleParts.mkString + val example = exampleParts.mkString URL.decode(example) match { case Left(error) => c.abort(c.enclosingPosition, error.getMessage) case Right(_) => diff --git a/zio-http/shared/src/main/scala/zio/http/ClientSSLConfig.scala b/zio-http/shared/src/main/scala/zio/http/ClientSSLConfig.scala index a7a4ab8f7f..bbc1aa97a7 100644 --- a/zio-http/shared/src/main/scala/zio/http/ClientSSLConfig.scala +++ b/zio-http/shared/src/main/scala/zio/http/ClientSSLConfig.scala @@ -119,8 +119,8 @@ object ClientSSLConfig { def isInvalidBuild: Boolean = !isValidBuild def build(): FromJavaxNetSsl = this - def keyManagerKeyStoreType(tpe: String): FromJavaxNetSsl = self.copy(keyManagerKeyStoreType = tpe) - def keyManagerFile(file: String): FromJavaxNetSsl = + def keyManagerKeyStoreType(tpe: String): FromJavaxNetSsl = self.copy(keyManagerKeyStoreType = tpe) + def keyManagerFile(file: String): FromJavaxNetSsl = keyManagerSource match { case FromJavaxNetSsl.Resource(_) => this case _ => self.copy(keyManagerSource = FromJavaxNetSsl.File(file)) diff --git a/zio-http/shared/src/main/scala/zio/http/Header.scala b/zio-http/shared/src/main/scala/zio/http/Header.scala index f5515e1426..cc7e24834e 100644 --- a/zio-http/shared/src/main/scala/zio/http/Header.scala +++ b/zio-http/shared/src/main/scala/zio/http/Header.scala @@ -31,8 +31,14 @@ import scala.util.{Either, Failure, Success, Try} import zio.Config.Secret import zio._ -import zio.http.codec.RichTextCodec -import zio.http.internal.DateEncoding +import zio.schema.Schema +import zio.schema.codec.DecodeError +import zio.schema.codec.DecodeError.ReadError +import zio.schema.validation.ValidationError + +import zio.http.Header.HeaderTypeBase.Typed +import zio.http.codec.{HttpCodecError, RichTextCodec} +import zio.http.internal.{DateEncoding, ErrorConstructor, StringSchemaCodec} sealed trait Header { type Self <: Header @@ -50,21 +56,141 @@ sealed trait Header { object Header { - sealed trait HeaderType { + sealed trait HeaderTypeBase { + type HeaderValue + + def names: Chunk[String] + + def fromHeaders(headers: Headers): Either[String, HeaderValue] + + private[http] def fromHeadersUnsafe(headers: Headers): HeaderValue + + def toHeaders(value: HeaderValue): Headers = + value match { + case h: Header => Headers.fromIterable(h :: Nil) + case _ => Headers.empty + } + } + + object HeaderTypeBase { + type Typed[HV] = HeaderTypeBase { type HeaderValue = HV } + } + + sealed trait SchemaHeaderType extends HeaderTypeBase { + def schema: Schema[HeaderValue] + + def optional: HeaderTypeBase.Typed[Option[HeaderValue]] + } + + object SchemaHeaderType { + type Typed[H] = SchemaHeaderType { type HeaderValue = H } + + private val errorConstructor = + new ErrorConstructor { + override def missing(fieldName: String): HttpCodecError = + HttpCodecError.MissingHeader(fieldName) + + override def missingAll(fieldNames: Chunk[String]): HttpCodecError = + HttpCodecError.MissingHeaders(fieldNames) + + override def invalid(errors: Chunk[ValidationError]): HttpCodecError = + HttpCodecError.InvalidEntity.wrap(errors) + + override def malformed(fieldName: String, error: DecodeError): HttpCodecError = + HttpCodecError.DecodingErrorHeader(fieldName, error) + + override def invalidCount(fieldName: String, expected: Int, actual: Int): HttpCodecError = + HttpCodecError.InvalidHeaderCount(fieldName, expected, actual) + } + + def apply[H](implicit schema0: Schema[H]): SchemaHeaderType.Typed[H] = { + new SchemaHeaderType { + type HeaderValue = H + val schema: Schema[H] = + schema0 + val codec: StringSchemaCodec[H, Headers] = + StringSchemaCodec.headerFromSchema(schema0, errorConstructor, null) + + override def names: Chunk[String] = + codec.recordFields.map(_._1.fieldName) + + override def optional: SchemaHeaderType.Typed[Option[H]] = + apply(schema.optional) + + override def fromHeaders(headers: Headers): Either[String, H] = + try Right(codec.decode(headers)) + catch { + case NonFatal(e) => Left(e.getMessage) + } + + private[http] override def fromHeadersUnsafe(headers: Headers): H = + codec.decode(headers) + + override def toHeaders(value: H): Headers = + codec.encode(value, Headers.empty) + } + } + + def apply[H](name: String)(implicit schema0: Schema[H]): SchemaHeaderType.Typed[H] = { + new SchemaHeaderType { + type HeaderValue = H + val schema: Schema[H] = schema0 + val codec: StringSchemaCodec[H, Headers] = + StringSchemaCodec.headerFromSchema(schema, errorConstructor, name) + + override def names: Chunk[String] = + codec.recordFields.map(_._1.fieldName) + + override def optional: SchemaHeaderType.Typed[Option[H]] = + apply(name)(schema.optional) + + override def fromHeaders(headers: Headers): Either[String, H] = + try Right(codec.decode(headers)) + catch { + case NonFatal(e) => Left(e.getMessage) + } + + private[http] override def fromHeadersUnsafe(headers: Headers): H = + codec.decode(headers) + + override def toHeaders(value: H): Headers = + codec.encode(value, Headers.empty) + } + } + } + + sealed trait HeaderType extends HeaderTypeBase { type HeaderValue <: Header + def names: Chunk[String] = Chunk.single(name) + def name: String def parse(value: String): Either[String, HeaderValue] def render(value: HeaderValue): String + + def fromHeaders(headers: Headers): Either[String, HeaderValue] = + headers.getUnsafe(name) match { + case null => Left(s"Header $name not found") + case value => parse(value) + } + + def fromHeadersUnsafe(headers: Headers): HeaderValue = + fromHeaders(headers).fold( + e => throw HttpCodecError.DecodingErrorHeader(name, ReadError(Cause.empty, e)), + identity, + ) + + override def toHeaders(value: HeaderValue): Headers = + Headers.FromIterable(Iterable(value)) + } object HeaderType { type Typed[HV] = HeaderType { type HeaderValue = HV } } - // @deprecated("Use Schema based header codecs instead", "3.1.0") final case class Custom(customName: CharSequence, value: CharSequence) extends Header { override type Self = Custom override def self: Self = this diff --git a/zio-http/shared/src/main/scala/zio/http/Headers.scala b/zio-http/shared/src/main/scala/zio/http/Headers.scala index 3df48e5739..08c24ea016 100644 --- a/zio-http/shared/src/main/scala/zio/http/Headers.scala +++ b/zio-http/shared/src/main/scala/zio/http/Headers.scala @@ -99,6 +99,7 @@ object Headers { value: T, iterate: T => Iterator[Header], unsafeGet: (T, CharSequence) => String, + getAll: (T, CharSequence) => Chunk[String], contains: (T, CharSequence) => Boolean, ) extends Headers { override def contains(key: CharSequence): Boolean = contains(value, key) @@ -106,6 +107,9 @@ object Headers { override def iterator: Iterator[Header] = iterate(value) override private[http] def getUnsafe(key: CharSequence): String = unsafeGet(value, key) + + override def rawHeaders(name: CharSequence): Chunk[String] = getAll(value, name) + } private[zio] final case class Concat(first: Headers, second: Headers) extends Headers { diff --git a/zio-http/shared/src/main/scala/zio/http/MediaTypes.scala b/zio-http/shared/src/main/scala/zio/http/MediaTypes.scala index f9fa13ef35..8660f275b7 100644 --- a/zio-http/shared/src/main/scala/zio/http/MediaTypes.scala +++ b/zio-http/shared/src/main/scala/zio/http/MediaTypes.scala @@ -9688,7 +9688,7 @@ private[zio] trait MediaTypes { lazy val `vnd.dolby.heaac.2`: MediaType = new MediaType("audio", "vnd.dolby.heaac.2", compressible = false, binary = true) - lazy val `amr-wb+` : MediaType = + lazy val `amr-wb+`: MediaType = new MediaType("audio", "amr-wb+", compressible = false, binary = true) lazy val `dsr-es202211`: MediaType = diff --git a/zio-http/shared/src/main/scala/zio/http/ZClient.scala b/zio-http/shared/src/main/scala/zio/http/ZClient.scala index ce6d47b332..e9efe54629 100644 --- a/zio-http/shared/src/main/scala/zio/http/ZClient.scala +++ b/zio-http/shared/src/main/scala/zio/http/ZClient.scala @@ -759,7 +759,7 @@ object ZClient extends ZClientPlatformSpecific { } } yield () }.forkDaemon // Needs to live as long as the channel is alive, as the response body may be streaming - _ <- ZIO.addFinalizer(onComplete.interrupt) + _ <- ZIO.addFinalizer(onComplete.interrupt) response <- restore(onResponse.await.onInterrupt { ZIO.unlessZIO(connectionAcquired.get)(channelFiber.interrupt) *> onComplete.interrupt *> diff --git a/zio-http/shared/src/main/scala/zio/http/codec/HeaderCodecs.scala b/zio-http/shared/src/main/scala/zio/http/codec/HeaderCodecs.scala index ce21ab1267..2bafeeaa48 100644 --- a/zio-http/shared/src/main/scala/zio/http/codec/HeaderCodecs.scala +++ b/zio-http/shared/src/main/scala/zio/http/codec/HeaderCodecs.scala @@ -24,7 +24,7 @@ import zio.stacktracer.TracingImplicits.disableAutoTrace import zio.schema._ -import zio.http.Header.HeaderType +import zio.http.Header.{HeaderType, SchemaHeaderType} import zio.http._ private[codec] trait HeaderCodecs { @@ -41,17 +41,17 @@ private[codec] trait HeaderCodecs { case TextCodec.BooleanCodec => Schema[Boolean] case TextCodec.UUIDCodec => Schema[UUID] } - HttpCodec.HeaderCustom(name, schema.asInstanceOf[Schema[A]]) + HttpCodec.Header(SchemaHeaderType(name)(schema.asInstanceOf[Schema[A]])) } def header(headerType: HeaderType): HeaderCodec[headerType.HeaderValue] = HttpCodec.Header(headerType) def headerAs[A](name: String)(implicit schema: Schema[A]): HeaderCodec[A] = - HttpCodec.HeaderCustom(name, schema) + HttpCodec.Header(SchemaHeaderType(name)) def headers[A](implicit schema: Schema[A]): HeaderCodec[A] = - HttpCodec.HeaderCustom(schema) + HttpCodec.Header(SchemaHeaderType("headers")) @deprecated("Use Schema based headerAs instead", "3.1.0") def name[A](name: String)(implicit codec: TextCodec[A]): HeaderCodec[A] = diff --git a/zio-http/shared/src/main/scala/zio/http/codec/HttpCodec.scala b/zio-http/shared/src/main/scala/zio/http/codec/HttpCodec.scala index 673dcc242d..3ccabbebf8 100644 --- a/zio-http/shared/src/main/scala/zio/http/codec/HttpCodec.scala +++ b/zio-http/shared/src/main/scala/zio/http/codec/HttpCodec.scala @@ -20,21 +20,18 @@ import scala.annotation.tailrec import scala.reflect.ClassTag import scala.util.Try -import zio._ +import zio.{http, _} import zio.stream.{ZPipeline, ZStream} import zio.schema.Schema -import zio.schema.codec.DecodeError -import zio.schema.validation.{Validation, ValidationError} import zio.http.Header.Accept.MediaTypeWithQFactor -import zio.http.Header.HeaderType +import zio.http.Header.{HeaderTypeBase, SchemaHeaderType} import zio.http._ -import zio.http.codec.HttpCodec.SchemaCodec.camelToKebab import zio.http.codec.HttpCodec.{Annotated, Metadata} -import zio.http.codec.StringCodec.StringCodec import zio.http.codec.internal._ +import zio.http.internal.StringSchemaCodec /** * A [[zio.http.codec.HttpCodec]] represents a codec for a part of an HTTP @@ -341,13 +338,12 @@ object HttpCodec extends ContentCodecs with HeaderCodecs with MethodCodecs with private[http] sealed trait AtomTag private[http] object AtomTag { - case object Status extends AtomTag - case object Path extends AtomTag - case object Content extends AtomTag - case object Query extends AtomTag - case object Header extends AtomTag - case object HeaderCustom extends AtomTag - case object Method extends AtomTag + case object Status extends AtomTag + case object Path extends AtomTag + case object Content extends AtomTag + case object Query extends AtomTag + case object Header extends AtomTag + case object Method extends AtomTag } def empty: HttpCodec[Any, Unit] = @@ -2268,223 +2264,25 @@ object HttpCodec extends ContentCodecs with HeaderCodecs with MethodCodecs with def index(index: Int): ContentStream[A] = copy(index = index) } - private[http] final case class Query[A, Out]( - codec: SchemaCodec[A], + private[http] final case class Query[A]( + codec: StringSchemaCodec[A, QueryParams], index: Int = 0, - ) extends Atom[HttpCodecType.Query, Out] { + ) extends Atom[HttpCodecType.Query, A] { self => - def erase: Query[Any, Any] = self.asInstanceOf[Query[Any, Any]] + def erase: Query[Any] = self.asInstanceOf[Query[Any]] - def index(index: Int): Query[A, Out] = copy(index = index) - - def isCollection: Boolean = codec.isCollection - - def isOptional: Boolean = codec.isOptional - - def isOptionalSchema: Boolean = codec.isOptionalSchema - - def isPrimitive: Boolean = codec.isPrimitive - - def isRecord: Boolean = codec.isRecord - - def nameUnsafe: String = codec.name.get + def index(index: Int): Query[A] = copy(index = index) /** * Returns a new codec, where the value produced by this one is optional. */ - override def optional: HttpCodec[HttpCodecType.Query, Option[Out]] = - if (isOptionalSchema) { - throw new IllegalArgumentException("Query is already optional") - } else { - Annotated(Query(codec.optional, index), Metadata.Optional()) - } + override def optional: HttpCodec[HttpCodecType.Query, Option[A]] = + Annotated(Query(codec.optional, index), Metadata.Optional()) def tag: AtomTag = AtomTag.Query } - object Query { - def apply[A](name: String, schema: Schema[A]): Query[A, A] = Query(SchemaCodec(Some(name), schema)) - def apply[A](schema: Schema[A]): Query[A, A] = Query(SchemaCodec(None, schema)) - } - - final case class SchemaCodec[A](name: Option[String], schema: Schema[A], kebabCase: Boolean = false) { - - def erasedSchema: Schema[Any] = schema.asInstanceOf[Schema[Any]] - - val isCollection: Boolean = schema match { - case _: Schema.Collection[_, _] => true - case s: Schema.Optional[_] if s.schema.isInstanceOf[Schema.Collection[_, _]] => true - case _ => false - } - - val isOptional: Boolean = schema match { - case _: Schema.Optional[_] => - true - case record: Schema.Record[_] => - record.fields.forall(_.optional) || record.defaultValue.isRight - case d: Schema.Collection[_, _] => - Try(d.empty).isSuccess || d.defaultValue.isRight - case _ => - false - } - - val isOptionalSchema: Boolean = - schema match { - case _: Schema.Optional[_] => true - case s: Schema.Transform[_, _, _] if s.schema.isInstanceOf[Schema.Optional[_]] => true - case _ => false - } - - val isPrimitive: Boolean = schema match { - case _: Schema.Primitive[_] => true - case s: Schema.Optional[_] if s.schema.isInstanceOf[Schema.Primitive[_]] => true - case s: Schema.Transform[_, _, _] if s.schema.isInstanceOf[Schema.Primitive[_]] => true - case _ => false - } - - val isRecord: Boolean = schema match { - case _: Schema.Record[_] => true - case s: Schema.Optional[_] if s.schema.isInstanceOf[Schema.Record[_]] => true - case s: Schema.Transform[_, _, _] if s.schema.isInstanceOf[Schema.Record[_]] => true - case _ => false - } - - def optional: SchemaCodec[Option[A]] = copy(schema = schema.optional) - - val recordFields: Chunk[(Schema.Field[_, _], SchemaCodec[Any])] = { - val fields = schema match { - case record: Schema.Record[A] => - record.fields - case s: Schema.Optional[_] if s.schema.isInstanceOf[Schema.Record[_]] => - s.schema.asInstanceOf[Schema.Record[A]].fields - case s: Schema.Transform[_, _, _] if s.schema.isInstanceOf[Schema.Record[_]] => - s.schema.asInstanceOf[Schema.Record[A]].fields - case _ => Chunk.empty - } - fields.map(unlazyField).map { - case field if field.schema.isInstanceOf[Schema.Collection[_, _]] => - val elementSchema = field.schema.asInstanceOf[Schema.Collection[_, _]] match { - case s: Schema.NonEmptySequence[_, _, _] => s.elementSchema - case s: Schema.Sequence[_, _, _] => s.elementSchema - case s: Schema.Set[_] => s.elementSchema - case _: Schema.Map[_, _] => throw new IllegalArgumentException("Maps are not supported") - case _: Schema.NonEmptyMap[_, _] => throw new IllegalArgumentException("Maps are not supported") - } - val codec = SchemaCodec(Some(if (!kebabCase) field.name else camelToKebab(field.name)), elementSchema) - (field, codec.asInstanceOf[SchemaCodec[Any]]) - case field => - val codec = SchemaCodec( - Some(if (!kebabCase) field.name else camelToKebab(field.name)), - field.annotations.foldLeft(field.schema)(_ annotate _), - ) - (field, codec.asInstanceOf[SchemaCodec[Any]]) - } - } - - val recordSchema: Schema.Record[Any] = schema match { - case record: Schema.Record[_] => - record.asInstanceOf[Schema.Record[Any]] - case s: Schema.Optional[_] if s.schema.isInstanceOf[Schema.Record[_]] => - s.schema.asInstanceOf[Schema.Record[Any]] - case _ => null - } - - val stringCodec: StringCodec[Any] = - stringCodecForSchema(schema.asInstanceOf[Schema[Any]]) - - private def stringCodecForSchema(s: Schema[_]): StringCodec[Any] = { - (s match { - case s: Schema.Optional[_] if s.schema.isInstanceOf[Schema.Primitive[_]] => - StringCodec.fromSchema(schema) - case s: Schema.Optional[_] => - stringCodecForSchema(s.schema) - case s: Schema.Collection[_, _] => - s match { - case schema: Schema.NonEmptySequence[_, _, _] => StringCodec.fromSchema(schema.elementSchema) - case schema: Schema.Sequence[_, _, _] => StringCodec.fromSchema(schema.elementSchema) - case schema: Schema.Set[_] => StringCodec.fromSchema(schema.elementSchema) - case _: Schema.Map[_, _] => StringCodec.fromSchema(s) - case _: Schema.NonEmptyMap[_, _] => StringCodec.fromSchema(s) - } - case s: Schema.Lazy[_] => StringCodec.fromSchema(s.schema) - case s: Schema.Transform[Any, Any, _] @unchecked => - val stringCodec = StringCodec.fromSchema(s.schema) - new StringCodec[Any] { - override def decode(whole: String): Either[DecodeError, Any] = - stringCodec.decode(whole).flatMap(s.f(_).left.map(DecodeError.ReadError(Cause.empty, _))) - - override def streamDecoder: ZPipeline[Any, DecodeError, Char, Any] = - stringCodec.streamDecoder >>> ZPipeline.map(s.f(_).left.map(DecodeError.ReadError(Cause.empty, _))) - - override def encode(value: Any): String = - stringCodec.encode(s.g(value).fold(msg => throw new Exception(msg), identity)) - - override def streamEncoder: ZPipeline[Any, Nothing, Any, Char] = - ZPipeline.map[Any, Any]( - s.g(_).fold(msg => throw new Exception(msg), identity), - ) >>> stringCodec.streamEncoder - } - case schema: Schema[_] => StringCodec.fromSchema(schema) - }).asInstanceOf[StringCodec[Any]] - } - - private def unlazyField(field: Schema.Field[_, _]): Schema.Field[_, _] = field match { - case f if f.schema.isInstanceOf[Schema.Lazy[_]] => - Schema.Field( - f.name, - f.schema.asInstanceOf[Schema.Lazy[_]].schema.asInstanceOf[Schema[Any]], - f.annotations, - f.validation.asInstanceOf[Validation[Any]], - f.get.asInstanceOf[Any => Any], - f.set.asInstanceOf[(Any, Any) => Any], - ) - case f => f - } - - def validate(value: Any): Chunk[ValidationError] = - schema.asInstanceOf[Schema[_]] match { - case Schema.Optional(schema: Schema[Any], _) => - schema.validate(value)(schema) - case schema: Schema[_] => - schema.asInstanceOf[Schema[Any]].validate(value)(schema.asInstanceOf[Schema[Any]]) - } - val defaultValue: A = - if (schema.isInstanceOf[Schema.Collection[_, _]]) { - Try(schema.asInstanceOf[Schema.Collection[A, _]].empty).fold( - _ => null.asInstanceOf[A], - identity, - ) - } else { - schema.defaultValue match { - case Right(value) => value - case Left(_) => - schema match { - case _: Schema.Optional[_] => None.asInstanceOf[A] - case collection: Schema.Collection[A, _] => - Try(collection.empty).fold( - _ => null.asInstanceOf[A], - identity, - ) - case _ => null.asInstanceOf[A] - } - } - } - - } - - object SchemaCodec { - private def camelToKebab(s: String): String = - if (s.isEmpty) "" - else if (s.head.isUpper) s.head.toLower.toString + camelToKebab(s.tail) - else if (s.contains('-')) s - else - s.foldLeft("") { (acc, c) => - if (c.isUpper) acc + "-" + c.toLower - else acc + c - } - } - private[http] final case class Method[A](codec: SimpleCodec[zio.http.Method, A], index: Int = 0) extends Atom[HttpCodecType.Method, A] { self => @@ -2494,34 +2292,7 @@ object HttpCodec extends ContentCodecs with HeaderCodecs with MethodCodecs with def index(index: Int): Method[A] = copy(index = index) } - private[http] final case class HeaderCustom[A](codec: SchemaCodec[A], index: Int = 0) - extends Atom[HttpCodecType.Header, A] { - self => - def erase: HeaderCustom[Any] = self.asInstanceOf[HeaderCustom[Any]] - - override def optional: HttpCodec[HttpCodecType.Header, Option[A]] = - if (codec.isOptionalSchema) { - throw new IllegalArgumentException("Header is already optional") - } else { - Annotated( - HeaderCustom(codec.optional, index), - Metadata.Optional(), - ) - } - - def tag: AtomTag = AtomTag.HeaderCustom - - def index(index: Int): HeaderCustom[A] = copy(index = index) - } - - object HeaderCustom { - def apply[A](name: String, schema: Schema[A]): HeaderCustom[A] = - HeaderCustom(SchemaCodec(Some(name), schema, kebabCase = true)) - def apply[A](schema: Schema[A]): HeaderCustom[A] = - HeaderCustom(SchemaCodec(None, schema, kebabCase = true)) - } - - private[http] final case class Header[A](headerType: HeaderType.Typed[A], index: Int = 0) + private[http] final case class Header[A](headerType: HeaderTypeBase.Typed[A], index: Int = 0) extends Atom[HttpCodecType.Header, A] { self => def erase: Header[Any] = self.asInstanceOf[Header[Any]] @@ -2529,6 +2300,18 @@ object HttpCodec extends ContentCodecs with HeaderCodecs with MethodCodecs with def tag: AtomTag = AtomTag.Header def index(index: Int): Header[A] = copy(index = index) + + override def optional: HttpCodec[HttpCodecType.Header, Option[A]] = { + headerType match { + case headerType if headerType.isInstanceOf[SchemaHeaderType] => + Annotated( + Header(headerType.asInstanceOf[SchemaHeaderType.Typed[A]].optional, index), + Metadata.Optional(), + ) + case _ => + super.optional + } + } } private[http] final case class Annotated[AtomTypes, Value]( diff --git a/zio-http/shared/src/main/scala/zio/http/codec/HttpCodecError.scala b/zio-http/shared/src/main/scala/zio/http/codec/HttpCodecError.scala index 3df1973abb..bfbacea8d9 100644 --- a/zio-http/shared/src/main/scala/zio/http/codec/HttpCodecError.scala +++ b/zio-http/shared/src/main/scala/zio/http/codec/HttpCodecError.scala @@ -52,8 +52,8 @@ object HttpCodecError { final case class MalformedHeader(headerName: String, textCodec: TextCodec[_]) extends HttpCodecError { def message = s"Malformed header $headerName failed to decode using $textCodec" } - final case class MalformedCustomHeader(headerName: String, cause: DecodeError) extends HttpCodecError { - def message = s"Malformed custom header $headerName could not be decoded: $cause" + final case class DecodingErrorHeader(headerName: String, cause: DecodeError) extends HttpCodecError { + def message = s"Malformed header $headerName could not be decoded: $cause" } final case class MalformedTypedHeader(headerName: String) extends HttpCodecError { def message = s"Malformed header $headerName" @@ -83,6 +83,9 @@ object HttpCodecError { final case class InvalidQueryParamCount(name: String, expected: Int, actual: Int) extends HttpCodecError { def message = s"Invalid query parameter count for $name: expected $expected but found $actual." } + final case class InvalidHeaderCount(name: String, expected: Int, actual: Int) extends HttpCodecError { + def message = s"Invalid query parameter count for $name: expected $expected but found $actual." + } final case class CustomError(name: String, message: String) extends HttpCodecError final case class UnsupportedContentType(contentType: String) extends HttpCodecError { @@ -102,6 +105,9 @@ object HttpCodecError { def isMissingDataOnly(cause: Cause[Any]): Boolean = !cause.isFailure && cause.defects.forall(e => - e.isInstanceOf[HttpCodecError.MissingHeader] || e.isInstanceOf[HttpCodecError.MissingQueryParam], + e.isInstanceOf[HttpCodecError.MissingHeader] + || e.isInstanceOf[HttpCodecError.MissingQueryParam] + || e.isInstanceOf[HttpCodecError.MissingQueryParams] + || e.isInstanceOf[HttpCodecError.MissingHeaders], ) } diff --git a/zio-http/shared/src/main/scala/zio/http/codec/QueryCodecs.scala b/zio-http/shared/src/main/scala/zio/http/codec/QueryCodecs.scala index 4f98ec8e46..a941d4df85 100644 --- a/zio-http/shared/src/main/scala/zio/http/codec/QueryCodecs.scala +++ b/zio-http/shared/src/main/scala/zio/http/codec/QueryCodecs.scala @@ -15,65 +15,38 @@ */ package zio.http.codec -import scala.annotation.tailrec - +import zio.Chunk import zio.stacktracer.TracingImplicits.disableAutoTrace import zio.schema.Schema -import zio.schema.annotation.simpleEnum +import zio.schema.codec.DecodeError +import zio.schema.validation.ValidationError + +import zio.http.internal.{ErrorConstructor, StringSchemaCodec} private[codec] trait QueryCodecs { - def query[A](name: String)(implicit schema: Schema[A]): QueryCodec[A] = - schema match { - case c: Schema.Collection[_, _] if !supportedCollection(c) => - throw new IllegalArgumentException(s"Collection schema $c is not supported for query codecs") - case enum0: Schema.Enum[_] if !enum0.annotations.exists(_.isInstanceOf[simpleEnum]) => - throw new IllegalArgumentException(s"Enum schema $enum0 is not supported. All cases must be objects.") - case record: Schema.Record[A] if record.fields.size != 1 => - throw new IllegalArgumentException("Use queryAll[A] for records with more than one field") - case record: Schema.Record[A] if !supportedElementSchema(record.fields.head.schema.asInstanceOf[Schema[Any]]) => - throw new IllegalArgumentException( - s"Only primitive types and simple enums can be used in single field records, but got ${record.fields.head.schema}", - ) - case other => - HttpCodec.Query(name, other) - } + private val errorConstructor = new ErrorConstructor { + override def missing(fieldName: String): HttpCodecError = + HttpCodecError.MissingQueryParam(fieldName) - private def supportedCollection(schema: Schema.Collection[_, _]): Boolean = schema match { - case Schema.Map(_, _, _) => - false - case Schema.NonEmptyMap(_, _, _) => - false - case Schema.Sequence(elementSchema, _, _, _, _) => - supportedElementSchema(elementSchema.asInstanceOf[Schema[Any]]) - case Schema.NonEmptySequence(elementSchema, _, _, _, _) => - supportedElementSchema(elementSchema.asInstanceOf[Schema[Any]]) - case Schema.Set(elementSchema, _) => - supportedElementSchema(elementSchema.asInstanceOf[Schema[Any]]) - } + override def missingAll(fieldNames: Chunk[String]): HttpCodecError = + HttpCodecError.MissingQueryParams(fieldNames) - @tailrec - private def supportedElementSchema(elementSchema: Schema[Any]): Boolean = elementSchema match { - case Schema.Lazy(schema0) => supportedElementSchema(schema0()) - case _ => - elementSchema.isInstanceOf[Schema.Primitive[_]] || - elementSchema.isInstanceOf[Schema.Enum[_]] && elementSchema.annotations.exists(_.isInstanceOf[simpleEnum]) || - elementSchema.isInstanceOf[Schema.Record[_]] && elementSchema.asInstanceOf[Schema.Record[_]].fields.size == 1 + override def invalid(errors: Chunk[ValidationError]): HttpCodecError = + HttpCodecError.InvalidEntity.wrap(errors) + + override def malformed(fieldName: String, error: DecodeError): HttpCodecError = + HttpCodecError.MalformedQueryParam(fieldName, error) + + override def invalidCount(fieldName: String, expected: Int, actual: Int): HttpCodecError = + HttpCodecError.InvalidQueryParamCount(fieldName, expected, actual) } + def query[A](name: String)(implicit schema: Schema[A]): QueryCodec[A] = + HttpCodec.Query(StringSchemaCodec.queryFromSchema[A](schema, errorConstructor, name)) + def queryAll[A](implicit schema: Schema[A]): QueryCodec[A] = - schema match { - case _: Schema.Primitive[A] => - throw new IllegalArgumentException("Use query[A](name: String) for primitive types") - case record: Schema.Record[A] => - HttpCodec.Query(record) - case Schema.Optional(s, _) if s.isInstanceOf[Schema.Record[_]] => - HttpCodec.Query(schema) - case _ => - throw new IllegalArgumentException( - "Only case classes can be used with queryAll. Maybe you wanted to use query[A](name: String)?", - ) - } + HttpCodec.Query(StringSchemaCodec.queryFromSchema[A](schema, errorConstructor, null)) } diff --git a/zio-http/shared/src/main/scala/zio/http/codec/StringCodec.scala b/zio-http/shared/src/main/scala/zio/http/codec/StringCodec.scala deleted file mode 100644 index ae49864c51..0000000000 --- a/zio-http/shared/src/main/scala/zio/http/codec/StringCodec.scala +++ /dev/null @@ -1,394 +0,0 @@ -package zio.http.codec - -import java.time._ -import java.util.{Currency, UUID} - -import scala.annotation.tailrec - -import zio._ - -import zio.stream._ - -import zio.schema._ -import zio.schema.annotation.simpleEnum -import zio.schema.codec._ - -import zio.http.Charsets - -object StringCodec { - type StringCodec[A] = Codec[String, Char, A] - private def errorCodec[A](schema: Schema[A]) = - new Codec[String, Char, A] { - override def decode(whole: String): Either[DecodeError, A] = throw new IllegalArgumentException( - s"Schema $schema is not supported by StringCodec.", - ) - - override def streamDecoder: ZPipeline[Any, DecodeError, Char, A] = throw new IllegalArgumentException( - s"Schema $schema is not supported by StringCodec.", - ) - - override def encode(value: A): String = throw new IllegalArgumentException( - s"Schema $schema is not supported by StringCodec.", - ) - - override def streamEncoder: ZPipeline[Any, Nothing, A, Char] = throw new IllegalArgumentException( - s"Schema $schema is not supported by StringCodec.", - ) - } - - @tailrec - private def emptyStringIsValue(schema: Schema[_]): Boolean = { - schema match { - case value: Schema.Optional[_] => - val innerSchema = value.schema - emptyStringIsValue(innerSchema) - case _ => - schema.asInstanceOf[Schema.Primitive[_]].standardType match { - case StandardType.UnitType => true - case StandardType.StringType => true - case StandardType.BinaryType => true - case StandardType.CharType => true - case _ => false - } - } - } - - implicit def fromSchema[A](implicit schema: Schema[A]): Codec[String, Char, A] = { - schema match { - case Schema.Optional(schema, _) => - val codec = fromSchema(schema).asInstanceOf[Codec[String, Char, Any]] - new Codec[String, Char, A] { - override def encode(a: A): String = { - a match { - case Some(value) => codec.encode(value) - case None => "" - } - } - - override def decode(c: String): Either[DecodeError, A] = { - if (c.isEmpty && !emptyStringIsValue(schema)) Right(None.asInstanceOf[A]) - else { - codec.decode(c).map(Some(_)).asInstanceOf[Either[DecodeError, A]] - } - } - - override def streamEncoder: ZPipeline[Any, Nothing, A, Char] = - ZPipeline.map((a: A) => encode(a).toSeq).flattenIterables - override def streamDecoder: ZPipeline[Any, DecodeError, Char, A] = - codec.streamDecoder.map(v => Some(v).asInstanceOf[A]) - } - case enum0: Schema.Enum[_] if enum0.annotations.exists(_.isInstanceOf[simpleEnum]) => - val stringCodec = fromSchema(Schema.Primitive(StandardType.StringType)) - val caseMap = enum0.nonTransientCases - .map(case_ => - case_.schema.asInstanceOf[Schema.CaseClass0[A]].defaultConstruct() -> - case_.caseName, - ) - .toMap - val reverseCaseMap = caseMap.map(_.swap) - new Codec[String, Char, A] { - override def encode(a: A): String = { - val caseName = caseMap(a.asInstanceOf[A]) - stringCodec.encode(caseName) - } - - override def decode(c: String): Either[DecodeError, A] = - stringCodec.decode(c).flatMap { caseName => - reverseCaseMap.get(caseName) match { - case Some(value) => Right(value.asInstanceOf[A]) - case None => Left(DecodeError.MissingCase(caseName, enum0)) - } - } - override def streamEncoder: ZPipeline[Any, Nothing, A, Char] = - ZPipeline.map((a: A) => encode(a).toSeq).flattenIterables - override def streamDecoder: ZPipeline[Any, DecodeError, Char, A] = - stringCodec.streamDecoder.mapZIO { caseName => - reverseCaseMap.get(caseName) match { - case Some(value) => ZIO.succeed(value.asInstanceOf[A]) - case None => ZIO.fail(DecodeError.MissingCase(caseName, enum0)) - } - } - } - - case enum0: Schema.Enum[_] => errorCodec(enum0) - case record: Schema.Record[_] if record.fields.size == 1 => - val fieldSchema = record.fields.head.schema - val codec = fromSchema(fieldSchema).asInstanceOf[Codec[String, Char, A]] - new Codec[String, Char, A] { - override def encode(a: A): String = - codec.encode(record.deconstruct(a)(Unsafe.unsafe).head.get.asInstanceOf[A]) - override def decode(c: String): Either[DecodeError, A] = - codec - .decode(c) - .flatMap(a => - record.construct(Chunk(a))(Unsafe.unsafe).left.map(s => DecodeError.ReadError(Cause.empty, s)), - ) - override def streamEncoder: ZPipeline[Any, Nothing, A, Char] = - ZPipeline.map((a: A) => encode(a).toSeq).flattenIterables - override def streamDecoder: ZPipeline[Any, DecodeError, Char, A] = - codec.streamDecoder.mapZIO(a => - ZIO.fromEither( - record.construct(Chunk(a))(Unsafe.unsafe).left.map(s => DecodeError.ReadError(Cause.empty, s)), - ), - ) - } - case record: Schema.Record[_] => errorCodec(record) - case collection: Schema.Collection[_, _] => errorCodec(collection) - case Schema.Transform(schema, f, g, _, _) => - val codec = fromSchema(schema) - new Codec[String, Char, A] { - override def encode(a: A): String = codec.encode(g(a).fold(e => throw new Exception(e), identity)) - override def decode(c: String): Either[DecodeError, A] = codec - .decode(c) - .flatMap(x => - f(x).left - .map(DecodeError.ReadError(Cause.fail(new Exception("Error during decoding")), _)), - ) - override def streamEncoder: ZPipeline[Any, Nothing, A, Char] = - ZPipeline.mapChunks(_.flatMap(encode)) - override def streamDecoder: ZPipeline[Any, DecodeError, Char, A] = codec.streamDecoder.map { x => - f(x) match { - case Left(value) => throw DecodeError.ReadError(Cause.fail(new Exception("Error in decoding")), value) - case Right(a) => a - } - } - } - case Schema.Primitive(_, _) => - new Codec[String, Char, A] { - val decode0: String => Either[DecodeError, Any] = - schema match { - case Schema.Primitive(standardType, _) => - standardType match { - case StandardType.UnitType => - val result = Right("") - (_: String) => result - case StandardType.StringType => - (s: String) => Right(s) - case StandardType.BoolType => - (s: String) => - s.toLowerCase match { - case "true" | "on" | "yes" | "1" => Right(true) - case "false" | "off" | "no" | "0" => Right(false) - case _ => Left(DecodeError.ReadError(Cause.fail(new Exception("Invalid boolean value")), s)) - } - case StandardType.ByteType => - (s: String) => - try { - Right(s.toByte) - } catch { - case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) - } - case StandardType.ShortType => - (s: String) => - try { - Right(s.toShort) - } catch { - case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) - } - case StandardType.IntType => - (s: String) => - try { - Right(s.toInt) - } catch { - case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) - } - case StandardType.LongType => - (s: String) => - try { - Right(s.toLong) - } catch { - case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) - } - case StandardType.FloatType => - (s: String) => - try { - Right(s.toFloat) - } catch { - case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) - } - case StandardType.DoubleType => - (s: String) => - try { - Right(s.toDouble) - } catch { - case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) - } - case StandardType.BinaryType => - val result = Left(DecodeError.UnsupportedSchema(schema, "TextCodec")) - (_: String) => result - case StandardType.CharType => - (s: String) => Right(s.charAt(0)) - case StandardType.UUIDType => - (s: String) => - try { - Right(UUID.fromString(s)) - } catch { - case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) - } - case StandardType.BigDecimalType => - (s: String) => - try { - Right(BigDecimal(s)) - } catch { - case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) - } - case StandardType.BigIntegerType => - (s: String) => - try { - Right(BigInt(s)) - } catch { - case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) - } - case StandardType.DayOfWeekType => - (s: String) => - try { - Right(DayOfWeek.valueOf(s)) - } catch { - case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) - } - case StandardType.MonthType => - (s: String) => - try { - Right(Month.valueOf(s)) - } catch { - case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) - } - case StandardType.MonthDayType => - (s: String) => - try { - Right(MonthDay.parse(s)) - } catch { - case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) - } - case StandardType.PeriodType => - (s: String) => - try { - Right(Period.parse(s)) - } catch { - case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) - } - case StandardType.YearType => - (s: String) => - try { - Right(Year.parse(s)) - } catch { - case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) - } - case StandardType.YearMonthType => - (s: String) => - try { - Right(YearMonth.parse(s)) - } catch { - case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) - } - case StandardType.ZoneIdType => - (s: String) => - try { - Right(ZoneId.of(s)) - } catch { - case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) - } - case StandardType.ZoneOffsetType => - (s: String) => - try { - Right(ZoneOffset.of(s)) - } catch { - case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) - } - case StandardType.DurationType => - (s: String) => - try { - Right(java.time.Duration.parse(s)) - } catch { - case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) - } - case StandardType.InstantType => - (s: String) => - try { - Right(Instant.parse(s)) - } catch { - case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) - } - case StandardType.LocalDateType => - (s: String) => - try { - Right(LocalDate.parse(s)) - } catch { - case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) - } - case StandardType.LocalTimeType => - (s: String) => - try { - Right(LocalTime.parse(s)) - } catch { - case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) - } - case StandardType.LocalDateTimeType => - (s: String) => - try { - Right(LocalDateTime.parse(s)) - } catch { - case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) - } - case StandardType.OffsetTimeType => - (s: String) => - try { - Right(OffsetTime.parse(s)) - } catch { - case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) - } - case StandardType.OffsetDateTimeType => - (s: String) => - try { - Right(OffsetDateTime.parse(s)) - } catch { - case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) - } - case StandardType.ZonedDateTimeType => - (s: String) => - try { - Right(ZonedDateTime.parse(s)) - } catch { - case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) - } - case StandardType.CurrencyType => - (s: String) => - try { - Right(Currency.getInstance(s)) - } catch { - case e: Exception => Left(DecodeError.ReadError(Cause.fail(e), e.getMessage)) - } - } - case schema => - val result = Left( - DecodeError.UnsupportedSchema(schema, "Only primitive types are supported for text decoding."), - ) - (_: String) => result - } - override def encode(a: A): String = - schema match { - case Schema.Primitive(_, _) => a.toString - case _ => - throw new IllegalArgumentException( - s"Cannot encode $a of type ${a.getClass} with schema $schema", - ) - } - - override def decode(c: String): Either[DecodeError, A] = - decode0(c).map(_.asInstanceOf[A]) - - override def streamEncoder: ZPipeline[Any, Nothing, A, Char] = - ZPipeline.map((a: A) => a.toString.toSeq).flattenIterables - - override def streamDecoder: ZPipeline[Any, DecodeError, Char, A] = - ZPipeline - .chunks[Char] - .map(_.asString) - .mapZIO(s => ZIO.fromEither(decode(s))) - .mapErrorCause(e => Cause.fail(DecodeError.ReadError(e, e.squash.getMessage))) - } - case Schema.Lazy(schema0) => fromSchema(schema0()) - case _ => errorCodec(schema) - } - } -} diff --git a/zio-http/shared/src/main/scala/zio/http/codec/internal/Atomized.scala b/zio-http/shared/src/main/scala/zio/http/codec/internal/Atomized.scala index 3557b776d9..d7a51ebbd1 100644 --- a/zio-http/shared/src/main/scala/zio/http/codec/internal/Atomized.scala +++ b/zio-http/shared/src/main/scala/zio/http/codec/internal/Atomized.scala @@ -25,34 +25,31 @@ private[http] final case class Atomized[A]( path: A, query: A, header: A, - headerCustom: A, content: A, ) { def get(tag: HttpCodec.AtomTag): A = { tag match { - case HttpCodec.AtomTag.Status => status - case HttpCodec.AtomTag.Path => path - case HttpCodec.AtomTag.Content => content - case HttpCodec.AtomTag.Query => query - case HttpCodec.AtomTag.Header => header - case HttpCodec.AtomTag.HeaderCustom => headerCustom - case HttpCodec.AtomTag.Method => method + case HttpCodec.AtomTag.Status => status + case HttpCodec.AtomTag.Path => path + case HttpCodec.AtomTag.Content => content + case HttpCodec.AtomTag.Query => query + case HttpCodec.AtomTag.Header => header + case HttpCodec.AtomTag.Method => method } } def update(tag: HttpCodec.AtomTag)(f: A => A): Atomized[A] = { tag match { - case HttpCodec.AtomTag.Status => copy(status = f(status)) - case HttpCodec.AtomTag.Path => copy(path = f(path)) - case HttpCodec.AtomTag.Content => copy(content = f(content)) - case HttpCodec.AtomTag.Query => copy(query = f(query)) - case HttpCodec.AtomTag.Header => copy(header = f(header)) - case HttpCodec.AtomTag.HeaderCustom => copy(headerCustom = f(header)) - case HttpCodec.AtomTag.Method => copy(method = f(method)) + case HttpCodec.AtomTag.Status => copy(status = f(status)) + case HttpCodec.AtomTag.Path => copy(path = f(path)) + case HttpCodec.AtomTag.Content => copy(content = f(content)) + case HttpCodec.AtomTag.Query => copy(query = f(query)) + case HttpCodec.AtomTag.Header => copy(header = f(header)) + case HttpCodec.AtomTag.Method => copy(method = f(method)) } } } private[http] object Atomized { def apply[A](defValue: => A): Atomized[A] = - Atomized(defValue, defValue, defValue, defValue, defValue, defValue, defValue) + Atomized(defValue, defValue, defValue, defValue, defValue, defValue) } diff --git a/zio-http/shared/src/main/scala/zio/http/codec/internal/AtomizedCodecs.scala b/zio-http/shared/src/main/scala/zio/http/codec/internal/AtomizedCodecs.scala index af296b3cfe..4ae3b295b7 100644 --- a/zio-http/shared/src/main/scala/zio/http/codec/internal/AtomizedCodecs.scala +++ b/zio-http/shared/src/main/scala/zio/http/codec/internal/AtomizedCodecs.scala @@ -25,18 +25,16 @@ import zio.http.codec._ private[http] final case class AtomizedCodecs( method: Chunk[SimpleCodec[zio.http.Method, _]], path: Chunk[PathCodec[_]], - query: Chunk[Query[_, _]], + query: Chunk[Query[_]], header: Chunk[Header[_]], - headerCustom: Chunk[HeaderCustom[_]], content: Chunk[BodyCodec[_]], status: Chunk[SimpleCodec[zio.http.Status, _]], ) { self => def append(atom: Atom[_, _]): AtomizedCodecs = atom match { case path0: Path[_] => self.copy(path = path :+ path0.pathCodec) case method0: Method[_] => self.copy(method = method :+ method0.codec) - case query0: Query[_, _] => self.copy(query = query :+ query0) + case query0: Query[_] => self.copy(query = query :+ query0) case header0: Header[_] => self.copy(header = header :+ header0) - case header0: HeaderCustom[_] => self.copy(headerCustom = headerCustom :+ header0) case status0: Status[_] => self.copy(status = status :+ status0.codec) case content0: Content[_] => self.copy(content = content :+ BodyCodec.Single(content0.codec, content0.name)) @@ -50,7 +48,6 @@ private[http] final case class AtomizedCodecs( path = Array.ofDim(path.length), query = Array.ofDim(query.length), header = Array.ofDim(header.length), - headerCustom = Array.ofDim(headerCustom.length), content = Array.ofDim(content.length), status = Array.ofDim(status.length), ) @@ -62,7 +59,6 @@ private[http] final case class AtomizedCodecs( path = path.materialize, query = query.materialize, header = header.materialize, - headerCustom = headerCustom.materialize, content = content.materialize, status = status.materialize, ) @@ -75,7 +71,6 @@ private[http] object AtomizedCodecs { path = Chunk.empty, query = Chunk.empty, header = Chunk.empty, - headerCustom = Chunk.empty, content = Chunk.empty, status = Chunk.empty, ) diff --git a/zio-http/shared/src/main/scala/zio/http/codec/internal/EncoderDecoder.scala b/zio-http/shared/src/main/scala/zio/http/codec/internal/EncoderDecoder.scala index e00ba20664..63c315179d 100644 --- a/zio-http/shared/src/main/scala/zio/http/codec/internal/EncoderDecoder.scala +++ b/zio-http/shared/src/main/scala/zio/http/codec/internal/EncoderDecoder.scala @@ -21,12 +21,8 @@ import scala.util.Try import zio._ -import zio.schema.codec.DecodeError -import zio.schema.{Schema, StandardType} - import zio.http.Header.Accept.MediaTypeWithQFactor import zio.http._ -import zio.http.codec.StringCodec.StringCodec import zio.http.codec._ private[codec] trait EncoderDecoder[-AtomTypes, Value] { self => @@ -169,7 +165,6 @@ private[codec] object EncoderDecoder { decodeStatus(status, inputsBuilder.status) decodeMethod(method, inputsBuilder.method) decodeHeaders(headers, inputsBuilder.header) - decodeCustomHeaders(headers, inputsBuilder.headerCustom) decodeBody(config, body, inputsBuilder.content).as(constructor(inputsBuilder)) } @@ -182,7 +177,7 @@ private[codec] object EncoderDecoder { val query = encodeQuery(config, inputs.query) val status = encodeStatus(inputs.status) val method = encodeMethod(inputs.method) - val headers = encodeHeaders(inputs.header) ++ encodeCustomHeaders(inputs.headerCustom) + val headers = encodeHeaders(inputs.header) def contentTypeHeaders = encodeContentType(inputs.content, outputTypes) val body = encodeBody(config, inputs.content, outputTypes) @@ -216,298 +211,19 @@ private[codec] object EncoderDecoder { ) private def decodeQuery(config: CodecConfig, queryParams: QueryParams, inputs: Array[Any]): Unit = - genericDecode[QueryParams, HttpCodec.Query[_, _]]( + genericDecode[QueryParams, HttpCodec.Query[_]]( queryParams, flattened.query, inputs, - (codec, queryParams) => { - val query = codec.erase - val optional = query.isOptionalSchema - val hasDefault = query.codec.defaultValue != null && query.isOptional - val default = query.codec.defaultValue - if (codec.isPrimitive) { - val name = query.nameUnsafe - val hasParam = queryParams.hasQueryParam(name) - if ( - (!hasParam || (queryParams - .unsafeQueryParam(name) == "" && !emptyStringIsValue(codec.codec.schema))) && hasDefault - ) - default - else if (!hasParam) - throw HttpCodecError.MissingQueryParam(name) - else if (queryParams.valueCount(name) != 1) - throw HttpCodecError.InvalidQueryParamCount(name, 1, queryParams.valueCount(name)) - else { - val decoded = - codec.codec.stringCodec.decode(queryParams.unsafeQueryParam(name)) match { - case Left(error) => throw HttpCodecError.MalformedQueryParam(name, error) - case Right(value) => value - } - val validationErrors = codec.codec.erasedSchema.validate(decoded)(codec.codec.erasedSchema) - if (validationErrors.nonEmpty) throw HttpCodecError.InvalidEntity.wrap(validationErrors) - else decoded - } - - } else if (codec.isCollection) { - val name = query.nameUnsafe - val hasParam = queryParams.hasQueryParam(name) - - if (!hasParam) { - if (query.codec.defaultValue != null) query.codec.defaultValue - else throw HttpCodecError.MissingQueryParam(name) - } else { - val decoded = queryParams.queryParams(name).map { value => - query.codec.stringCodec.decode(value) match { - case Left(error) => throw HttpCodecError.MalformedQueryParam(name, error) - case Right(value) => value - } - } - if (optional) - Some( - createAndValidateCollection( - query.codec.schema.asInstanceOf[Schema.Optional[_]].schema.asInstanceOf[Schema.Collection[_, _]], - decoded, - ), - ) - else createAndValidateCollection(query.codec.schema.asInstanceOf[Schema.Collection[_, _]], decoded) - } - } else { - val recordSchema = query.codec.recordSchema - val fields = query.codec.recordFields - val hasAllParams = fields.forall { case (field, codec) => - queryParams.hasQueryParam(field.fieldName) || field.optional || codec.isOptional - } - if (!hasAllParams && hasDefault) default - else if (!hasAllParams) throw HttpCodecError.MissingQueryParams { - fields.collect { - case (field, codec) - if !(queryParams.hasQueryParam(field.fieldName) || field.optional || codec.isOptional) => - field.fieldName - } - } - else { - val decoded = fields.map { - case (field, codec) if field.schema.isInstanceOf[Schema.Collection[_, _]] => - val schema = field.schema.asInstanceOf[Schema.Collection[_, _]] - if (!queryParams.hasQueryParam(field.fieldName)) { - if (field.defaultValue.isDefined) field.defaultValue.get - else throw HttpCodecError.MissingQueryParam(field.fieldName) - } else { - val values = queryParams.queryParams(field.fieldName) - val decoded = - values.map(decodeAndUnwrap(field, codec, _, HttpCodecError.MalformedQueryParam.apply)) - createAndValidateCollection(schema, decoded) - - } - case (field, codec) => - val value = queryParams.queryParamOrElse(field.fieldName, null) - val decoded = { - if (value == null || (value == "" && !emptyStringIsValue(codec.schema))) codec.defaultValue - else decodeAndUnwrap(field, codec, value, HttpCodecError.MalformedQueryParam.apply) - } - validateDecoded(codec, decoded) - } - if (optional) { - val constructed = recordSchema.construct(decoded)(Unsafe.unsafe) - constructed match { - case Left(value) => - throw HttpCodecError.MalformedQueryParam( - s"${recordSchema.id}", - DecodeError.ReadError(Cause.empty, value), - ) - case Right(value) => - recordSchema.validate(value)(recordSchema) match { - case errors if errors.nonEmpty => throw HttpCodecError.InvalidEntity.wrap(errors) - case _ => Some(value) - } - } - } else { - val constructed = recordSchema.construct(decoded)(Unsafe.unsafe) - constructed match { - case Left(value) => - throw HttpCodecError.MalformedQueryParam( - s"${recordSchema.id}", - DecodeError.ReadError(Cause.empty, value), - ) - case Right(value) => - recordSchema.validate(value)(recordSchema) match { - case errors if errors.nonEmpty => throw HttpCodecError.InvalidEntity.wrap(errors) - case _ => value - } - } - } - } - } - }, - ) - - private def createAndValidateCollection(schema: Schema.Collection[_, _], decoded: Chunk[Any]) = { - val collection = schema.fromChunk.asInstanceOf[Chunk[Any] => Any](decoded) - val erasedSchema = schema.asInstanceOf[Schema[Any]] - val validationErrors = erasedSchema.validate(collection)(erasedSchema) - if (validationErrors.nonEmpty) throw HttpCodecError.InvalidEntity.wrap(validationErrors) - collection - } - - @tailrec - private def emptyStringIsValue(schema: Schema[_]): Boolean = { - schema match { - case value: Schema.Optional[_] => - val innerSchema = value.schema - emptyStringIsValue(innerSchema) - case _ => - schema.asInstanceOf[Schema.Primitive[_]].standardType match { - case StandardType.UnitType => true - case StandardType.StringType => true - case StandardType.BinaryType => true - case StandardType.CharType => true - case _ => false - } - } - } - - private def decodeCustomHeaders(headers: Headers, inputs: Array[Any]): Unit = - genericDecode[Headers, HttpCodec.HeaderCustom[_]]( - headers, - flattened.headerCustom, - inputs, - (header, headers) => { - val optional = header.codec.isOptionalSchema - if (header.codec.isPrimitive) { - val schema = header.erase.codec.schema - val name = header.codec.name.get - val value = headers.getUnsafe(name) - if (value ne null) { - val decoded = header.codec.stringCodec.decode(value) match { - case Left(error) => throw HttpCodecError.MalformedCustomHeader(name, error) - case Right(value) => value - } - val validationErrors = schema.validate(decoded)(schema) - if (validationErrors.nonEmpty) throw HttpCodecError.InvalidEntity.wrap(validationErrors) - else decoded - } else { - if (optional) None - else throw HttpCodecError.MissingHeader(name) - } - } else if (header.codec.isCollection) { - val name = header.codec.name.get - val values = headers.rawHeaders(name) - val decoded = values.map { value => - header.codec.stringCodec.decode(value) match { - case Left(error) => throw HttpCodecError.MalformedCustomHeader(name, error) - case Right(value) => value - } - } - if (optional) - Some( - createAndValidateCollection( - header.codec.schema.asInstanceOf[Schema.Optional[_]].schema.asInstanceOf[Schema.Collection[_, _]], - decoded, - ), - ) - else createAndValidateCollection(header.codec.schema.asInstanceOf[Schema.Collection[_, _]], decoded) - } else { - val recordSchema = header.codec.recordSchema - val fields = header.codec.recordFields - val hasAllParams = fields.forall { case (field, codec) => - headers.contains(field.fieldName) || field.optional || codec.isOptional - } - if (!hasAllParams) { - if (header.codec.defaultValue != null && header.codec.isOptional) header.codec.defaultValue - else - throw HttpCodecError.MissingHeaders { - fields.collect { - case (field, codec) if !(headers.contains(field.fieldName) || field.optional || codec.isOptional) => - field.fieldName - } - } - } else { - val decoded = fields.map { - case (field, codec) if field.schema.isInstanceOf[Schema.Collection[_, _]] => - if (!headers.contains(codec.name.get)) { - if (codec.defaultValue != null) codec.defaultValue - else throw HttpCodecError.MissingHeader(codec.name.get) - } else { - val schema = field.schema.asInstanceOf[Schema.Collection[_, _]] - val values = headers.rawHeaders(codec.name.get) - val decoded = - values.map(decodeAndUnwrap(field, codec, _, HttpCodecError.MalformedCustomHeader.apply)) - createAndValidateCollection(schema, decoded) - } - case (field, codec) => - val value = headers.getUnsafe(codec.name.get) - val decoded = - if (value == null || (value == "" && !emptyStringIsValue(codec.schema))) codec.defaultValue - else decodeAndUnwrap(field, codec, value, HttpCodecError.MalformedCustomHeader.apply) - validateDecoded(codec, decoded) - } - if (optional) { - val constructed = recordSchema.construct(decoded)(Unsafe.unsafe) - constructed match { - case Left(value) => - throw HttpCodecError.MalformedCustomHeader( - s"${recordSchema.id}", - DecodeError.ReadError(Cause.empty, value), - ) - case Right(value) => - recordSchema.validate(value)(recordSchema) match { - case errors if errors.nonEmpty => throw HttpCodecError.InvalidEntity.wrap(errors) - case _ => Some(value) - } - } - } else { - val constructed = recordSchema.construct(decoded)(Unsafe.unsafe) - constructed match { - case Left(value) => - throw HttpCodecError.MalformedCustomHeader( - s"${recordSchema.id}", - DecodeError.ReadError(Cause.empty, value), - ) - case Right(value) => - recordSchema.validate(value)(recordSchema) match { - case errors if errors.nonEmpty => throw HttpCodecError.InvalidEntity.wrap(errors) - case _ => value - } - } - } - } - } - }, + (codec, queryParams) => codec.erase.codec.decode(queryParams), ) - private def validateDecoded(codec: HttpCodec.SchemaCodec[Any], decoded: Any) = { - val validationErrors = codec.schema.validate(decoded)(codec.schema) - if (validationErrors.nonEmpty) throw HttpCodecError.InvalidEntity.wrap(validationErrors) - decoded - } - - private def decodeAndUnwrap( - field: Schema.Field[_, _], - codec: HttpCodec.SchemaCodec[Any], - value: String, - ex: (String, DecodeError) => HttpCodecError, - ) = { - codec.stringCodec.decode(value) match { - case Left(error) => throw ex(codec.name.get, error) - case Right(value) => value - } - } - private def decodeHeaders(headers: Headers, inputs: Array[Any]): Unit = genericDecode[Headers, HttpCodec.Header[_]]( headers, flattened.header, inputs, - (codec, headers) => - headers.get(codec.headerType.name) match { - case Some(value) => - codec.erase.headerType - .parse(value) - .getOrElse(throw HttpCodecError.MalformedTypedHeader(codec.headerType.name)) - - case None => - throw HttpCodecError.MissingHeader(codec.headerType.name) - }, + (codec, headers) => codec.headerType.fromHeadersUnsafe(headers), ) private def decodeStatus(status: Status, inputs: Array[Any]): Unit = @@ -630,159 +346,19 @@ private[codec] object EncoderDecoder { ) private def encodeQuery(config: CodecConfig, inputs: Array[Any]): QueryParams = - genericEncode[QueryParams, HttpCodec.Query[_, _]]( + genericEncode[QueryParams, HttpCodec.Query[_]]( flattened.query, inputs, QueryParams.empty, - (codec, input, queryParams) => { - val query = codec.erase - val optional = query.isOptionalSchema - val stringCodec = codec.codec.stringCodec.asInstanceOf[StringCodec[Any]] - - if (query.isPrimitive) { - val schema = codec.codec.schema - val name = query.nameUnsafe - if (schema.isInstanceOf[Schema.Primitive[_]]) { - if (schema.asInstanceOf[Schema.Primitive[_]].standardType.isInstanceOf[StandardType.UnitType.type]) { - queryParams.addQueryParams(name, Chunk.empty[String]) - } else { - val encoded = stringCodec.encode(input) - queryParams.addQueryParams(name, Chunk(encoded)) - } - } else if (schema.isInstanceOf[Schema.Optional[_]]) { - val encoded = stringCodec.encode(input) - if (encoded.nonEmpty) queryParams.addQueryParams(name, Chunk(encoded)) else queryParams - } else { - throw new IllegalStateException( - "Only primitive schema is supported for query parameters of type Primitive", - ) - } - } else if (query.isCollection) { - val name = query.nameUnsafe - var in: Any = input - if (optional) { - in = input.asInstanceOf[Option[Any]].getOrElse(Chunk.empty) - } - val values = input.asInstanceOf[Iterable[Any]] - if (values.nonEmpty) { - queryParams.addQueryParams( - name, - Chunk.fromIterable(values.map { value => stringCodec.encode(value) }), - ) - } else queryParams - } else if (query.isRecord) { - val value = input match { - case None => null - case Some(value) => value - case value => value - } - if (value == null) queryParams - else { - val innerSchema = query.codec.recordSchema - val fieldValues = innerSchema.deconstruct(value)(Unsafe.unsafe) - var qp = queryParams - val fieldIt = query.codec.recordFields.iterator - val fieldValuesIt = fieldValues.iterator - while (fieldIt.hasNext) { - val (field, codec) = fieldIt.next() - val name = field.fieldName - val value = fieldValuesIt.next() match { - case Some(value) => value - case None => field.defaultValue - } - value match { - case values: Iterable[_] => - qp = qp.addQueryParams( - name, - Chunk.fromIterable(values.map { v => - codec.stringCodec.encode(v) - }), - ) - case _ => - val encoded = codec.stringCodec.encode(value) - qp = qp.addQueryParam(name, encoded) - } - } - qp - } - } else { - queryParams - } - }, + (codec, input, queryParams) => codec.erase.codec.encode(input, queryParams), ) - private def encodeCustomHeaders(inputs: Array[Any]): Headers = { - genericEncode[Headers, HttpCodec.HeaderCustom[_]]( - flattened.headerCustom, - inputs, - Headers.empty, - (codec, input, headers) => { - val optional = codec.codec.isOptionalSchema - val stringCodec = codec.erase.codec.stringCodec - if (codec.codec.isPrimitive) { - val name = codec.codec.name.get - val value = input - if (optional && value == None) headers - else { - val encoded = stringCodec.encode(value) - headers ++ Headers(name, encoded) - } - } else if (codec.codec.isCollection) { - val name = codec.codec.name.get - val values = input.asInstanceOf[Iterable[Any]] - if (values.nonEmpty) { - headers ++ Headers.FromIterable( - values.map { value => - Header.Custom(name, stringCodec.encode(value)) - }, - ) - } else headers - } else { - val recordSchema = codec.codec.recordSchema - val fields = codec.codec.recordFields - val value = input match { - case None => null - case Some(value) => value - case value => value - } - if (value == null) headers - else { - val fieldValues = recordSchema.deconstruct(value)(Unsafe.unsafe) - var hs = headers - val fieldIt = fields.iterator - val fieldValuesIt = fieldValues.iterator - while (fieldIt.hasNext) { - val (field, codec) = fieldIt.next() - val name = field.fieldName - val value = fieldValuesIt.next() match { - case Some(value) => value - case None => field.defaultValue - } - value match { - case values: Iterable[_] => - hs = hs ++ Headers.FromIterable( - values.map { v => - Header.Custom(name, codec.stringCodec.encode(v)) - }, - ) - case _ => - val encoded = codec.stringCodec.encode(value) - hs = hs ++ Headers(name, encoded) - } - } - hs - } - } - }, - ) - - } private def encodeHeaders(inputs: Array[Any]): Headers = genericEncode[Headers, HttpCodec.Header[_]]( flattened.header, inputs, Headers.empty, - (codec, input, headers) => headers ++ Headers(codec.headerType.name, codec.erase.headerType.render(input)), + (codec, input, headers) => headers ++ codec.erase.headerType.toHeaders(input), ) private def encodeStatus(inputs: Array[Any]): Option[Status] = diff --git a/zio-http/shared/src/main/scala/zio/http/endpoint/http/HttpGen.scala b/zio-http/shared/src/main/scala/zio/http/endpoint/http/HttpGen.scala index a1b84d1cd8..cdce331a56 100644 --- a/zio-http/shared/src/main/scala/zio/http/endpoint/http/HttpGen.scala +++ b/zio-http/shared/src/main/scala/zio/http/endpoint/http/HttpGen.scala @@ -112,8 +112,8 @@ object HttpGen { case JsonSchema.Object(properties, _, _) => properties.flatMap { case (key, value) => loop(value, Some(key)) }.toSeq case JsonSchema.Enum(values) => Seq(HttpVariable(getName(name), None, Some(s"enum: ${values.mkString(",")}"))) - case JsonSchema.Null => Seq.empty - case JsonSchema.AnyJson => Seq.empty + case JsonSchema.Null => Seq.empty + case JsonSchema.AnyJson => Seq.empty } bodySchema0 match { @@ -127,37 +127,33 @@ object HttpGen { def headersVariables(inAtoms: AtomizedMetaCodecs): Seq[HttpVariable] = inAtoms.header.collect { case mc @ MetaCodec(HttpCodec.Header(headerType, _), _) => HttpVariable( - headerType.name.capitalize, - mc.examples.values.headOption.map(e => headerType.render(e.asInstanceOf[headerType.HeaderValue])), + headerType.names.head.capitalize, + mc.examples.values.headOption.map(e => + headerType.toHeaders(e.asInstanceOf[headerType.HeaderValue]).head.renderedValue, + ), ) } def queryVariables(config: CodecConfig, inAtoms: AtomizedMetaCodecs): Seq[HttpVariable] = { - inAtoms.query.collect { - case mc @ MetaCodec(HttpCodec.Query(codec, _), _) if codec.isPrimitive => + inAtoms.query.collect { case mc @ MetaCodec(HttpCodec.Query(codec, _), _) => + val recordSchema = (codec.schema match { + case value if value.isInstanceOf[Schema.Optional[_]] => value.asInstanceOf[Schema.Optional[Any]].schema + case _ => codec.schema + }).asInstanceOf[Schema.Record[Any]] + val examples = mc.examples.values.headOption.map { ex => + recordSchema.deconstruct(ex)(Unsafe.unsafe) + } + codec.recordFields.zipWithIndex.map { case ((field, codec), index) => HttpVariable( - codec.name.get, - mc.examples.values.headOption.map((e: Any) => codec.stringCodec.encode(e)), - ) :: Nil - case mc @ MetaCodec(HttpCodec.Query(codec, _), _) if codec.isRecord => - val recordSchema = (codec.schema match { - case value if value.isInstanceOf[Schema.Optional[_]] => value.asInstanceOf[Schema.Optional[Any]].schema - case _ => codec.schema - }).asInstanceOf[Schema.Record[Any]] - val examples = mc.examples.values.headOption.map { ex => - recordSchema.deconstruct(ex)(Unsafe.unsafe) - } - codec.recordFields.zipWithIndex.map { case ((field, codec), index) => - HttpVariable( - field.name, - examples.map(values => { - val fieldValue = values(index) - .orElse(field.defaultValue) - .getOrElse(throw new Exception(s"No value or default value for field ${field.name}")) - codec.stringCodec.encode(fieldValue) - }), - ) - } + field.name, + examples.map(values => { + val fieldValue = values(index) + .orElse(field.defaultValue) + .getOrElse(throw new Exception(s"No value or default value for field ${field.name}")) + codec.encode(fieldValue) + }), + ) + } }.flatten } diff --git a/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/OpenAPIGen.scala b/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/OpenAPIGen.scala index 9f3c47e385..a7cc99bc4f 100644 --- a/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/OpenAPIGen.scala +++ b/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/OpenAPIGen.scala @@ -21,12 +21,13 @@ import zio.http.codec._ import zio.http.endpoint._ import zio.http.endpoint.openapi.JsonSchema.SchemaStyle import zio.http.endpoint.openapi.OpenAPI.{Path, PathItem} +import zio.http.internal.StringSchemaCodec object OpenAPIGen { private val PathWildcard = "pathWildcard" private[openapi] def groupMap[A, K, B](chunk: Chunk[A])(key: A => K)(f: A => B): immutable.Map[K, Chunk[B]] = { - val m = mutable.Map.empty[K, mutable.Builder[B, Chunk[B]]] + val m = mutable.Map.empty[K, mutable.Builder[B, Chunk[B]]] for (elem <- chunk) { val k = key(elem) val bldr = m.getOrElseUpdate(k, Chunk.newBuilder[B]) @@ -101,11 +102,10 @@ object OpenAPIGen { final case class AtomizedMetaCodecs( method: Chunk[MetaCodec[SimpleCodec[Method, _]]], path: Chunk[MetaCodec[SegmentCodec[_]]], - query: Chunk[MetaCodec[HttpCodec.Query[_, _]]], + query: Chunk[MetaCodec[HttpCodec.Query[_]]], header: Chunk[MetaCodec[HttpCodec.Header[_]]], content: Chunk[MetaCodec[HttpCodec.Atom[Content, _]]], status: Chunk[MetaCodec[HttpCodec.Status[_]]], - headerCustom: Chunk[MetaCodec[HttpCodec.HeaderCustom[_]]] = Chunk.empty, ) { def append(metaCodec: MetaCodec[_]): AtomizedMetaCodecs = metaCodec match { case MetaCodec(codec: HttpCodec.Method[_], annotations) => @@ -115,12 +115,10 @@ object OpenAPIGen { ) case MetaCodec(_: SegmentCodec[_], _) => copy(path = path :+ metaCodec.asInstanceOf[MetaCodec[SegmentCodec[_]]]) - case MetaCodec(_: HttpCodec.Query[_, _], _) => - copy(query = query :+ metaCodec.asInstanceOf[MetaCodec[HttpCodec.Query[_, _]]]) + case MetaCodec(_: HttpCodec.Query[_], _) => + copy(query = query :+ metaCodec.asInstanceOf[MetaCodec[HttpCodec.Query[_]]]) case MetaCodec(_: HttpCodec.Header[_], _) => copy(header = header :+ metaCodec.asInstanceOf[MetaCodec[HttpCodec.Header[_]]]) - case MetaCodec(_: HttpCodec.HeaderCustom[_], _) => - copy(headerCustom = headerCustom :+ metaCodec.asInstanceOf[MetaCodec[HttpCodec.HeaderCustom[_]]]) case MetaCodec(_: HttpCodec.Status[_], _) => copy(status = status :+ metaCodec.asInstanceOf[MetaCodec[HttpCodec.Status[_]]]) case MetaCodec(_: HttpCodec.Content[_], _) => @@ -138,7 +136,6 @@ object OpenAPIGen { header ++ that.header, content ++ that.content, status ++ that.status, - headerCustom ++ that.headerCustom, ) def contentExamples: Map[String, OpenAPI.ReferenceOr.Or[OpenAPI.Example]] = @@ -176,7 +173,6 @@ object OpenAPIGen { header.materialize, content.materialize, status.materialize, - headerCustom.materialize, ) } @@ -188,7 +184,6 @@ object OpenAPIGen { header = Chunk.empty, content = Chunk.empty, status = Chunk.empty, - headerCustom = Chunk.empty, ) def flatten[R, A](codec: HttpCodec[R, A]): AtomizedMetaCodecs = { @@ -758,85 +753,41 @@ object OpenAPIGen { def parameters: Set[OpenAPI.ReferenceOr[OpenAPI.Parameter]] = queryParams ++ pathParams ++ headerParams - def queryParams: Set[OpenAPI.ReferenceOr[OpenAPI.Parameter]] = { - inAtoms.query.collect { - case mc @ MetaCodec(q @ HttpCodec.Query(codec, _), _) if codec.isPrimitive => + def queryParams: Set[OpenAPI.ReferenceOr[OpenAPI.Parameter]] = + inAtoms.query.collect { case mc @ MetaCodec(HttpCodec.Query(codec, _), _) => + val recordSchema = (codec.schema match { + case schema if schema.isInstanceOf[Schema.Optional[_]] => schema.asInstanceOf[Schema.Optional[_]].schema + case _ => codec.schema + }).asInstanceOf[Schema.Record[Any]] + val examples = mc.examples.map { case (exName, ex) => + exName -> recordSchema.deconstruct(ex)(Unsafe.unsafe) + } + codec.recordFields.zipWithIndex.map { case ((field, codec), index) => OpenAPI.ReferenceOr.Or( OpenAPI.Parameter.queryParameter( - name = q.nameUnsafe, + name = field.name, description = mc.docsOpt, schema = Some(OpenAPI.ReferenceOr.Or(JsonSchema.fromZSchema(codec.schema))), deprecated = mc.deprecated, style = OpenAPI.Parameter.Style.Form, explode = false, allowReserved = false, - examples = mc.examples.map { case (name, value) => - name -> OpenAPI.ReferenceOr.Or(OpenAPI.Example(value = Json.Str(value.toString))) - }, - required = mc.required && !q.isOptional, - ), - ) :: Nil - case mc @ MetaCodec(HttpCodec.Query(codec, _), _) if codec.isRecord => - val recordSchema = (codec.schema match { - case schema if schema.isInstanceOf[Schema.Optional[_]] => schema.asInstanceOf[Schema.Optional[_]].schema - case _ => codec.schema - }).asInstanceOf[Schema.Record[Any]] - val examples = mc.examples.map { case (exName, ex) => - exName -> recordSchema.deconstruct(ex)(Unsafe.unsafe) - } - codec.recordFields.zipWithIndex.map { case ((field, codec), index) => - OpenAPI.ReferenceOr.Or( - OpenAPI.Parameter.queryParameter( - name = field.name, - description = mc.docsOpt, - schema = Some(OpenAPI.ReferenceOr.Or(JsonSchema.fromZSchema(codec.schema))), - deprecated = mc.deprecated, - style = OpenAPI.Parameter.Style.Form, - explode = false, - allowReserved = false, - examples = examples.map { case (exName, values) => - val fieldValue = values(index) - .orElse(field.defaultValue) - .getOrElse( - throw new Exception(s"No value or default value found for field ${exName}_${field.name}"), - ) - s"${exName}_${field.name}" -> OpenAPI.ReferenceOr.Or( - OpenAPI.Example(value = Json.Str(codec.stringCodec.encode(fieldValue))), + examples = examples.map { case (exName, values) => + val fieldValue = values(index) + .orElse(field.defaultValue) + .getOrElse( + throw new Exception(s"No value or default value found for field ${exName}_${field.name}"), ) - }, - required = mc.required, - ), - ) - - } - case mc @ MetaCodec(q @ HttpCodec.Query(codec, _), _) if codec.isCollection => - var required = false - val schema = codec.schema.asInstanceOf[Schema.Collection[_, _]] match { - case s: Schema.Sequence[_, _, _] => s.elementSchema - case _: Schema.Map[_, _] => throw new Exception("Map query parameters not supported") - case _: Schema.NonEmptyMap[_, _] => throw new Exception("Map query parameters not supported") - case s: Schema.NonEmptySequence[_, _, _] => - required = true - s.elementSchema - case s: Schema.Set[_] => s.elementSchema - } - OpenAPI.ReferenceOr.Or( - OpenAPI.Parameter.queryParameter( - name = q.nameUnsafe, - description = mc.docsOpt, - schema = Some(OpenAPI.ReferenceOr.Or(JsonSchema.fromZSchema(schema))), - deprecated = mc.deprecated, - style = OpenAPI.Parameter.Style.Form, - explode = false, - allowReserved = false, - examples = mc.examples.map { case (exName, value) => - exName -> OpenAPI.ReferenceOr.Or(OpenAPI.Example(value = Json.Str(value.toString))) + s"${exName}_${field.name}" -> OpenAPI.ReferenceOr.Or( + OpenAPI.Example(value = Json.Str(codec.encode(fieldValue))), + ) }, - required = required, + required = mc.required && !StringSchemaCodec.isOptional(field.schema), ), - ) :: Nil - } - }.flatten.toSet + ) + + } + }.flatten.toSet def pathParams: Set[OpenAPI.ReferenceOr[OpenAPI.Parameter]] = inAtoms.path.collect { @@ -861,35 +812,19 @@ object OpenAPIGen { .map { case mc @ MetaCodec(codec, _) => OpenAPI.ReferenceOr.Or( OpenAPI.Parameter.headerParameter( - name = mc.name.getOrElse(codec.headerType.name), + name = mc.name.getOrElse(codec.headerType.names.head), description = mc.docsOpt, definition = Some(OpenAPI.ReferenceOr.Or(JsonSchema.String().nullable(!mc.required))), deprecated = mc.deprecated, - examples = mc.examples.map { case (name, value) => - name -> OpenAPI.ReferenceOr.Or(OpenAPI.Example(codec.headerType.render(value).toJsonAST.toOption.get)) - }, - required = mc.required, - ), - ) - } - .toSet ++ inAtoms.headerCustom - .asInstanceOf[Chunk[MetaCodec[HttpCodec.HeaderCustom[Any]]]] - // todo must handle collection and record - .map { case mc @ MetaCodec(codec, _) => - OpenAPI.ReferenceOr.Or( - OpenAPI.Parameter.headerParameter( - name = codec.codec.name.getOrElse(throw new Exception("Header parameter must have a name")), - description = mc.docsOpt, - definition = Some(OpenAPI.ReferenceOr.Or(JsonSchema.String().nullable(!mc.required))), - deprecated = mc.deprecated, - examples = mc.examples.map { case (name, value) => - name -> OpenAPI.ReferenceOr - .Or(OpenAPI.Example(codec.codec.stringCodec.encode(value).toJsonAST.toOption.get)) - }, + examples = Map.empty, +// mc.examples.map { case (name, value) => +// name -> OpenAPI.ReferenceOr.Or(OpenAPI.Example(codec.headerType.render(value).toJsonAST.toOption.get)) +// }, required = mc.required, ), ) } + .toSet def genDiscriminator(schema: Schema[_]): Option[OpenAPI.Discriminator] = { schema match { @@ -1155,7 +1090,8 @@ object OpenAPIGen { private def headersFrom(codec: AtomizedMetaCodecs) = { codec.header.map { case mc @ MetaCodec(codec, _) => - codec.headerType.name -> OpenAPI.ReferenceOr.Or( + // todo use all headers + codec.headerType.names.head -> OpenAPI.ReferenceOr.Or( OpenAPI.Header( description = mc.docsOpt, required = true, diff --git a/zio-http/shared/src/main/scala/zio/http/internal/HeaderGetters.scala b/zio-http/shared/src/main/scala/zio/http/internal/HeaderGetters.scala index 1a9841c329..2fb9b2c441 100644 --- a/zio-http/shared/src/main/scala/zio/http/internal/HeaderGetters.scala +++ b/zio-http/shared/src/main/scala/zio/http/internal/HeaderGetters.scala @@ -67,7 +67,7 @@ trait HeaderGetters { self => /** Gets the raw unparsed header value */ final def rawHeader(name: CharSequence): Option[String] = headers.get(name) - final def rawHeaders(name: CharSequence): Chunk[String] = + def rawHeaders(name: CharSequence): Chunk[String] = Chunk.fromIterator( headers.iterator .filter(header => CharSequenceExtensions.equals(header.headerNameAsCharSequence, name, CaseMode.Insensitive)) diff --git a/zio-http/shared/src/main/scala/zio/http/internal/HeaderModifier.scala b/zio-http/shared/src/main/scala/zio/http/internal/HeaderModifier.scala index 255358d8c5..629f497f01 100644 --- a/zio-http/shared/src/main/scala/zio/http/internal/HeaderModifier.scala +++ b/zio-http/shared/src/main/scala/zio/http/internal/HeaderModifier.scala @@ -39,6 +39,9 @@ trait HeaderModifier[+A] { self => final def addHeaders(headers: Headers): A = updateHeaders(_ ++ headers) + final def addHeaders(headers: Iterable[(CharSequence, CharSequence)]): A = + addHeaders(Headers.fromIterable(headers.map { case (k, v) => Header.Custom(k, v) })) + final def removeHeader(headerType: HeaderType): A = removeHeader(headerType.name) final def removeHeader(name: String): A = removeHeaders(Set(name)) diff --git a/zio-http/shared/src/main/scala/zio/http/internal/QueryModifier.scala b/zio-http/shared/src/main/scala/zio/http/internal/QueryModifier.scala index 77996c462e..75ad863332 100644 --- a/zio-http/shared/src/main/scala/zio/http/internal/QueryModifier.scala +++ b/zio-http/shared/src/main/scala/zio/http/internal/QueryModifier.scala @@ -45,6 +45,13 @@ trait QueryModifier[+A] { self: QueryOps[A] with A => def addQueryParams(values: String): A = updateQueryParams(params => params ++ QueryParams.decode(values)) + def addQueryParams(queryParams: Iterable[(String, String)]): A = + updateQueryParams(params => + params ++ QueryParams(queryParams.groupBy(_._1).map { case (k, v) => + k -> Chunk.fromIterable(v).map(_._2) + }), + ) + /** * Removes the specified key from the query parameters. */ diff --git a/zio-http/shared/src/main/scala/zio/http/internal/StringSchemaCodec.scala b/zio-http/shared/src/main/scala/zio/http/internal/StringSchemaCodec.scala new file mode 100644 index 0000000000..d56047298a --- /dev/null +++ b/zio-http/shared/src/main/scala/zio/http/internal/StringSchemaCodec.scala @@ -0,0 +1,723 @@ +package zio.http.internal + +import java.time._ +import java.util.{Currency, UUID} + +import scala.annotation.tailrec +import scala.util.Try + +import zio.{Cause, Chunk, Unsafe} + +import zio.schema.codec.DecodeError +import zio.schema.validation.{Validation, ValidationError} +import zio.schema.{Schema, StandardType, TypeId} + +import zio.http.codec.HttpCodecError +import zio.http.internal.StringSchemaCodec.{PrimitiveCodec, decodeAndUnwrap, emptyStringIsValue, validateDecoded} +import zio.http.{Headers, QueryParams} + +private[http] trait ErrorConstructor { + def missing(fieldName: String): HttpCodecError + def missingAll(fieldNames: Chunk[String]): HttpCodecError + def invalid(errors: Chunk[ValidationError]): HttpCodecError + def malformed(fieldName: String, error: DecodeError): HttpCodecError + def invalidCount(fieldName: String, expected: Int, actual: Int): HttpCodecError +} + +private[http] trait StringSchemaCodec[A, Target] { + private[http] def schema: Schema[A] + private[http] def add(target: Target, key: String, value: String): Target + private[http] def addAll(target: Target, headers: Iterable[(String, String)]): Target + private[http] def contains(target: Target, key: String): Boolean + private[http] def unsafeGet(target: Target, key: String): String + private[http] def getAll(target: Target, key: String): Chunk[String] + private[http] def count(target: Target, key: String): Int + private[http] def error: ErrorConstructor + private[http] def kebabCase: Boolean + private[http] val defaultValue: A + private[http] val isOptional: Boolean + private[http] val isOptionalSchema: Boolean + private[http] val recordFields: Chunk[(Schema.Field[_, _], PrimitiveCodec[Any])] = { + val fields = schema match { + case record: Schema.Record[A] => + record.fields + case s: Schema.Optional[_] if s.schema.isInstanceOf[Schema.Record[_]] => + s.schema.asInstanceOf[Schema.Record[A]].fields + case s: Schema.Transform[_, _, _] if s.schema.isInstanceOf[Schema.Record[_]] => + s.schema.asInstanceOf[Schema.Record[A]].fields + case _ => Chunk.empty + } + fields.map(StringSchemaCodec.unlazyField).map { + case field if field.schema.isInstanceOf[Schema.Collection[_, _]] => + val elementSchema = field.schema.asInstanceOf[Schema.Collection[_, _]] match { + case s: Schema.NonEmptySequence[_, _, _] => s.elementSchema + case s: Schema.Sequence[_, _, _] => s.elementSchema + case s: Schema.Set[_] => s.elementSchema + case _: Schema.Map[_, _] => throw new IllegalArgumentException("Maps are not supported") + case _: Schema.NonEmptyMap[_, _] => throw new IllegalArgumentException("Maps are not supported") + } + val codec = PrimitiveCodec(elementSchema).asInstanceOf[PrimitiveCodec[Any]] + (StringSchemaCodec.mapFieldName(field, kebabCase), codec) + case field => + val codec = + PrimitiveCodec(field.annotations.foldLeft(field.schema)(_.annotate(_))).asInstanceOf[PrimitiveCodec[Any]] + (StringSchemaCodec.mapFieldName(field, kebabCase), codec) + } + } + + private[http] val recordSchema: Schema.Record[Any] = schema match { + case record: Schema.Record[_] => + record.asInstanceOf[Schema.Record[Any]] + case s: Schema.Optional[_] if s.schema.isInstanceOf[Schema.Record[_]] => + s.schema.asInstanceOf[Schema.Record[Any]] + case _ => null + } + + private def createAndValidateCollection(schema: Schema.Collection[_, _], decoded: Chunk[Any]) = { + val collection = schema.fromChunk.asInstanceOf[Chunk[Any] => Any](decoded) + val erasedSchema = schema.asInstanceOf[Schema[Any]] + val validationErrors = erasedSchema.validate(collection)(erasedSchema) + if (validationErrors.nonEmpty) throw error.invalid(validationErrors) + collection + } + + private[http] def decode(target: Target): A = { + val optional = isOptionalSchema + val hasDefault = defaultValue != null && isOptional + val default = defaultValue + val hasAllParams = recordFields.forall { case (field, codec) => + contains(target, field.fieldName) || field.optional || codec.isOptional + } + if (!hasAllParams && hasDefault) default + else if (!hasAllParams) { + throw error.missingAll { + recordFields.collect { + case (field, codec) if !(contains(target, field.fieldName) || field.optional || codec.isOptional) => + field.fieldName + } + } + } else { + val decoded = recordFields.map { + case (field, codec) if field.schema.isInstanceOf[Schema.Collection[_, _]] => + val schema = field.schema.asInstanceOf[Schema.Collection[_, _]] + if (!contains(target, field.fieldName)) { + if (field.defaultValue.isDefined) field.defaultValue.get + else throw error.missing(field.fieldName) + } else { + val values = getAll(target, field.fieldName) + val decoded = + values.map(decodeAndUnwrap(field, codec, _, error.malformed)) + createAndValidateCollection(schema, decoded) + + } + case (field, codec) => + val count0 = count(target, field.fieldName) + if (count0 > 1) throw error.invalidCount(field.fieldName, 1, count0) + val value = unsafeGet(target, field.fieldName) + val decoded = { + if (value == null || (value == "" && !emptyStringIsValue(codec.schema) && codec.isOptional)) + codec.defaultValue + else decodeAndUnwrap(field, codec, value, error.malformed) + } + validateDecoded(codec, decoded, error) + } + if (optional) { + val constructed = recordSchema.construct(decoded)(Unsafe.unsafe) + constructed match { + case Left(value) => + throw error.malformed( + s"${recordSchema.id}", + DecodeError.ReadError(Cause.empty, value), + ) + case Right(value) => + recordSchema.validate(value)(recordSchema) match { + case errors if errors.nonEmpty => throw error.invalid(errors) + case _ => Some(value).asInstanceOf[A] + } + } + } else { + val constructed = recordSchema.construct(decoded)(Unsafe.unsafe) + constructed match { + case Left(value) => + throw error.malformed( + s"${recordSchema.id}", + DecodeError.ReadError(Cause.empty, value), + ) + case Right(value) => + recordSchema.validate(value)(recordSchema) match { + case errors if errors.nonEmpty => throw error.invalid(errors) + case _ => value.asInstanceOf[A] + } + } + } + } + + } + + private[http] def encode(input: A, target: Target): Target = { + val fields = recordFields + val value = input.asInstanceOf[Any] match { + case None => null + case it: Iterable[_] if it.isEmpty => null + case Some(value) => value + case value => value + } + if (value == null) target + else { + val fieldValues = recordSchema.deconstruct(value)(Unsafe.unsafe) + var target0 = target + val fieldIt = fields.iterator + val fieldValuesIt = fieldValues.iterator + while (fieldIt.hasNext) { + val (field, codec) = fieldIt.next() + val name = field.fieldName + val value = fieldValuesIt.next() match { + case Some(value) => value + case None => field.defaultValue + } + value match { + case values: Iterable[_] => + target0 = addAll(target0, values.map { v => (name, codec.encode(v)) }) + case _ => + val encoded = codec.encode(value) + target0 = add(target0, name, encoded) + } + } + target0 + } + } + + private[http] def optional: StringSchemaCodec[Option[A], Target] + +} + +private[http] object StringSchemaCodec { + private[http] def unlazyField(field: Schema.Field[_, _]): Schema.Field[_, _] = field match { + case f if f.schema.isInstanceOf[Schema.Lazy[_]] => + Schema.Field( + f.name, + f.schema.asInstanceOf[Schema.Lazy[_]].schema.asInstanceOf[Schema[Any]], + f.annotations, + f.validation.asInstanceOf[Validation[Any]], + f.get.asInstanceOf[Any => Any], + f.set.asInstanceOf[(Any, Any) => Any], + ) + case f => f + } + private[http] def defaultValue[A](schema: Schema[A]): A = + if (schema.isInstanceOf[Schema.Collection[_, _]]) { + Try(schema.asInstanceOf[Schema.Collection[A, _]].empty).fold( + _ => null.asInstanceOf[A], + identity, + ) + } else { + schema.defaultValue match { + case Right(value) => value + case Left(_) => + schema match { + case _: Schema.Optional[_] => None.asInstanceOf[A] + case collection: Schema.Collection[A, _] => + Try(collection.empty).fold( + _ => null.asInstanceOf[A], + identity, + ) + case _ => null.asInstanceOf[A] + } + } + } + + private[http] def isOptional(schema: Schema[_]): Boolean = schema match { + case _: Schema.Optional[_] => + true + case record: Schema.Record[_] => + record.fields.forall(_.optional) || record.defaultValue.isRight + case d: Schema.Collection[_, _] => + val bool = Try(d.empty).isSuccess || d.defaultValue.isRight + bool + case _ => + false + } + + private[http] def isOptionalSchema(schema: Schema[_]): Boolean = + schema match { + case _: Schema.Optional[_] => true + case s: Schema.Transform[_, _, _] if s.schema.isInstanceOf[Schema.Optional[_]] => true + case _ => false + } + + private[http] final case class PrimitiveCodec[A]( + private[http] val schema: Schema[A], + ) { + + val defaultValue: A = + StringSchemaCodec.defaultValue(schema) + + private[http] val isOptional: Boolean = + StringSchemaCodec.isOptional(schema) + + private[http] val isOptionalSchema: Boolean = + StringSchemaCodec.isOptionalSchema(schema) + + private[http] val encode: A => String = + PrimitiveCodec.primitiveSchemaEncoder(schema) + + private[http] val decode: String => A = + PrimitiveCodec.primitiveSchemaDecoder(schema) + + } + + object PrimitiveCodec { + + private[http] def primitiveSchemaDecoder[A](schema: Schema[A]): String => A = schema match { + case Schema.Optional(schema, _) => + primitiveSchemaDecoder(schema).andThen(Some(_)).asInstanceOf[String => A] + case Schema.Transform(schema, f, _, _, _) => + primitiveSchemaDecoder(schema).andThen { + f(_) match { + case Left(value) => throw new IllegalArgumentException(value) + case Right(value) => value + } + }.asInstanceOf[String => A] + case Schema.Primitive(standardType, _) => + parsePrimitive(standardType.asInstanceOf[StandardType[Any]]).asInstanceOf[String => A] + case Schema.Lazy(schema0) => + primitiveSchemaDecoder(schema0()).asInstanceOf[String => A] + case _ => throw new IllegalArgumentException(s"Unsupported schema $schema") + } + + private[http] def primitiveSchemaEncoder[A](schema: Schema[A]): A => String = schema match { + case Schema.Optional(schema, _) => + val innerEncoder: Any => String = primitiveSchemaEncoder(schema.asInstanceOf[Schema[Any]]) + (a: A) => if (a.isInstanceOf[None.type]) null else innerEncoder(a.asInstanceOf[Some[Any]].get) + case Schema.Transform(schema, f, _, _, _) => + val innerEncoder: Any => String = primitiveSchemaEncoder(schema.asInstanceOf[Schema[Any]]) + (a: A) => + f.asInstanceOf[Any => Either[String, Any]](a.asInstanceOf[Any]) match { + case Left(value) => throw new IllegalArgumentException(value) + case Right(value) => innerEncoder(value) + } + case Schema.Lazy(schema0) => + primitiveSchemaEncoder(schema0()).asInstanceOf[A => String] + case Schema.Primitive(_, _) => + (a: A) => a.toString + case _ => + throw new IllegalArgumentException(s"Unsupported schema $schema") + } + } + + private def decodeAndUnwrap( + field: Schema.Field[_, _], + codec: PrimitiveCodec[Any], + value: String, + ex: (String, DecodeError) => HttpCodecError, + ) = + try codec.decode(value) + catch { + case err: DecodeError => throw ex(field.fieldName, err) + } + + private def validateDecoded(codec: PrimitiveCodec[Any], decoded: Any, error: ErrorConstructor) = { + val validationErrors = codec.schema.validate(decoded)(codec.schema) + if (validationErrors.nonEmpty) throw error.invalid(validationErrors) + decoded + } + + @tailrec + private def emptyStringIsValue(schema: Schema[_]): Boolean = { + schema match { + case value: Schema.Optional[_] => + val innerSchema = value.schema + emptyStringIsValue(innerSchema) + case _ => + schema.asInstanceOf[Schema.Primitive[_]].standardType match { + case StandardType.UnitType => true + case StandardType.StringType => true + case StandardType.BinaryType => true + case StandardType.CharType => true + case _ => false + } + } + } + private[http] def mapFieldName(field: Schema.Field[_, _], kebabCase: Boolean): Schema.Field[_, _] = { + Schema.Field( + if (!kebabCase) field.fieldName else camelToKebab(field.fieldName), + field.annotations.foldLeft(field.schema)(_ annotate _).asInstanceOf[Schema[Any]], + field.annotations, + field.validation.asInstanceOf[Validation[Any]], + field.get.asInstanceOf[Any => Any], + field.set.asInstanceOf[(Any, Any) => Any], + ) + } + + private[http] def headerFromSchema[A]( + schema0: Schema[A], + error0: ErrorConstructor, + name: String, + ): StringSchemaCodec[A, Headers] = { + + def stringSchemaCodec(schema1: Schema[Any]): StringSchemaCodec[A, Headers] = + new StringSchemaCodec[A, Headers] { + override def schema: Schema[A] = schema1.asInstanceOf[Schema[A]] + + override private[http] def add(headers: Headers, key: String, value: String): Headers = + headers.addHeader(key, value) + + override private[http] def addAll(headers: Headers, kvs: Iterable[(String, String)]): Headers = + headers.addHeaders(kvs) + + override private[http] def contains(headers: Headers, key: String): Boolean = + headers.contains(key) + + override private[http] def unsafeGet(headers: Headers, key: String): String = + headers.getUnsafe(key) + + override private[http] def getAll(headers: Headers, key: String): Chunk[String] = + headers.rawHeaders(key) + + override private[http] def count(headers: Headers, key: String): Int = + headers.rawHeaders(key).size + + override private[http] def error: ErrorConstructor = + error0 + + override private[http] def optional: StringSchemaCodec[Option[A], Headers] = + StringSchemaCodec.headerFromSchema(schema.optional, error0, name) + + override private[http] def kebabCase: Boolean = + true + override private[http] val defaultValue: A = + StringSchemaCodec.defaultValue(schema0) + override private[http] val isOptional: Boolean = + StringSchemaCodec.isOptional(schema0) + override private[http] val isOptionalSchema: Boolean = + StringSchemaCodec.isOptionalSchema(schema0) + } + schema0 match { + case s @ Schema.Primitive(_, _) => + stringSchemaCodec(recordSchema(s.asInstanceOf[Schema[Any]], name)) + case s @ Schema.Optional(schema, _) => + schema match { + case _: Schema.Collection[_, _] | _: Schema.Primitive[_] => + stringSchemaCodec(recordSchema(s.asInstanceOf[Schema[Any]], name)) + case s if s.isInstanceOf[Schema.Record[_]] => stringSchemaCodec(schema.asInstanceOf[Schema[Any]]) + case _ => throw new IllegalArgumentException(s"Unsupported schema $s") + } + case s @ Schema.Transform(schema, _, _, _, _) => + schema match { + case _: Schema.Collection[_, _] | _: Schema.Primitive[_] => + stringSchemaCodec(recordSchema(s.asInstanceOf[Schema[Any]], name)) + case _: Schema.Record[_] => stringSchemaCodec(s.asInstanceOf[Schema[Any]]) + case _ => throw new IllegalArgumentException(s"Unsupported schema $s") + } + case Schema.Lazy(schema0) => + headerFromSchema( + schema0().asInstanceOf[Schema[A]], + error0, + name, + ) + case _: Schema.Collection[_, _] => + stringSchemaCodec(recordSchema(schema0.asInstanceOf[Schema[Any]], name)) + case s: Schema.Record[_] => + stringSchemaCodec(s.asInstanceOf[Schema[Any]]) + case _ => + throw new IllegalArgumentException(s"Unsupported schema $schema0") + + } + + } + + private[http] def queryFromSchema[A]( + schema0: Schema[A], + error0: ErrorConstructor, + name: String, + ): StringSchemaCodec[A, QueryParams] = { + + def stringSchemaCodec(schema1: Schema[Any]): StringSchemaCodec[A, QueryParams] = + new StringSchemaCodec[A, QueryParams] { + override def schema: Schema[A] = schema1.asInstanceOf[Schema[A]] + + override private[http] def add(queryParams: QueryParams, key: String, value: String): QueryParams = + queryParams.addQueryParam(key, value) + + override private[http] def addAll(queryParams: QueryParams, kvs: Iterable[(String, String)]): QueryParams = + queryParams.addQueryParams(kvs) + + override private[http] def contains(queryParams: QueryParams, key: String): Boolean = + queryParams.hasQueryParam(key) + + override private[http] def unsafeGet(queryParams: QueryParams, key: String): String = + queryParams.unsafeQueryParam(key) + + override private[http] def getAll(queryParams: QueryParams, key: String): Chunk[String] = + queryParams.getAll(key) + + override private[http] def count(queryParams: QueryParams, key: String): Int = + queryParams.valueCount(key) + + override private[http] def error: ErrorConstructor = + error0 + + override private[http] def optional: StringSchemaCodec[Option[A], QueryParams] = + StringSchemaCodec.queryFromSchema(schema.optional, error0, name) + + override private[http] def kebabCase: Boolean = + false + override private[http] val defaultValue: A = + StringSchemaCodec.defaultValue(schema0) + override private[http] val isOptional: Boolean = + StringSchemaCodec.isOptional(schema0) + override private[http] val isOptionalSchema: Boolean = + StringSchemaCodec.isOptionalSchema(schema0) + } + schema0 match { + case s @ Schema.Primitive(_, _) => + stringSchemaCodec(recordSchema(s.asInstanceOf[Schema[Any]], name)) + case s @ Schema.Optional(schema, _) => + schema match { + case _: Schema.Collection[_, _] | _: Schema.Primitive[_] => + stringSchemaCodec(recordSchema(s.asInstanceOf[Schema[Any]], name)) + case s if s.isInstanceOf[Schema.Record[_]] => stringSchemaCodec(schema.asInstanceOf[Schema[Any]]) + case _ => throw new IllegalArgumentException(s"Unsupported schema $s") + } + case s @ Schema.Transform(schema, _, _, _, _) => + schema match { + case _: Schema.Collection[_, _] | _: Schema.Primitive[_] => + stringSchemaCodec(recordSchema(s.asInstanceOf[Schema[Any]], name)) + case _: Schema.Record[_] => stringSchemaCodec(s.asInstanceOf[Schema[Any]]) + case _ => throw new IllegalArgumentException(s"Unsupported schema $s") + } + case Schema.Lazy(schema0) => + queryFromSchema( + schema0().asInstanceOf[Schema[A]], + error0, + name, + ) + case _: Schema.Collection[_, _] => + stringSchemaCodec(recordSchema(schema0.asInstanceOf[Schema[Any]], name)) + case s: Schema.Record[_] => + stringSchemaCodec(s.asInstanceOf[Schema[Any]]) + case _ => + throw new IllegalArgumentException(s"Unsupported schema $schema0") + + } + } + + private def recordSchema[A](s: Schema[A], name: String): Schema[A] = Schema.CaseClass1[A, A]( + TypeId.Structural, + Schema.Field(name, s, Chunk.empty, Validation.succeed, identity, (_, v) => v), + identity, + ) + + private def parsePrimitive(standardType: StandardType[_]): String => Any = + standardType match { + case StandardType.UnitType => + val result = "" + (_: String) => result + case StandardType.StringType => + (s: String) => s + case StandardType.BoolType => + (s: String) => + s.toLowerCase match { + case "true" | "on" | "yes" | "1" => true + case "false" | "off" | "no" | "0" => false + case _ => throw DecodeError.ReadError(Cause.fail(new Exception("Invalid boolean value")), s) + } + case StandardType.ByteType => + (s: String) => + try { + s.toByte + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.ShortType => + (s: String) => + try { + s.toShort + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.IntType => + (s: String) => + try { + s.toInt + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.LongType => + (s: String) => + try { + s.toLong + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.FloatType => + (s: String) => + try { + s.toFloat + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.DoubleType => + (s: String) => + try { + s.toDouble + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.BinaryType => + val result = DecodeError.UnsupportedSchema(Schema.Primitive(standardType), "TextCodec") + (_: String) => throw result + case StandardType.CharType => + (s: String) => s.charAt(0) + case StandardType.UUIDType => + (s: String) => + try { + UUID.fromString(s) + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.BigDecimalType => + (s: String) => + try { + BigDecimal(s) + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.BigIntegerType => + (s: String) => + try { + BigInt(s) + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.DayOfWeekType => + (s: String) => + try { + DayOfWeek.valueOf(s) + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.MonthType => + (s: String) => + try { + Month.valueOf(s) + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.MonthDayType => + (s: String) => + try { + MonthDay.parse(s) + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.PeriodType => + (s: String) => + try { + Period.parse(s) + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.YearType => + (s: String) => + try { + Year.parse(s) + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.YearMonthType => + (s: String) => + try { + YearMonth.parse(s) + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.ZoneIdType => + (s: String) => + try { + ZoneId.of(s) + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.ZoneOffsetType => + (s: String) => + try { + ZoneOffset.of(s) + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.DurationType => + (s: String) => + try { + java.time.Duration.parse(s) + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.InstantType => + (s: String) => + try { + Instant.parse(s) + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.LocalDateType => + (s: String) => + try { + LocalDate.parse(s) + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.LocalTimeType => + (s: String) => + try { + LocalTime.parse(s) + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.LocalDateTimeType => + (s: String) => + try { + LocalDateTime.parse(s) + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.OffsetTimeType => + (s: String) => + try { + OffsetTime.parse(s) + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.OffsetDateTimeType => + (s: String) => + try { + OffsetDateTime.parse(s) + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.ZonedDateTimeType => + (s: String) => + try { + ZonedDateTime.parse(s) + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.CurrencyType => + (s: String) => + try { + Currency.getInstance(s) + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + } + + private def camelToKebab(s: String): String = + if (s.isEmpty) "" + else if (s.head.isUpper) s.head.toLower.toString + camelToKebab(s.tail) + else if (s.contains('-')) s + else + s.foldLeft("") { (acc, c) => + if (c.isUpper) acc + "-" + c.toLower + else acc + c + } +}