diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9aef2786a7..14d19e071b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -200,6 +200,10 @@ jobs: with: apps: sbt + - uses: coursier/setup-action@v1 + with: + apps: sbt + - name: Release env: PGP_PASSPHRASE: ${{ secrets.PGP_PASSPHRASE }} diff --git a/.scalafmt.conf b/.scalafmt.conf index cac439b8f2..6cdcd8efbe 100644 --- a/.scalafmt.conf +++ b/.scalafmt.conf @@ -1,4 +1,4 @@ -version = 3.8.1 +version = 3.8.6 maxColumn = 120 align.preset = more diff --git a/build.sbt b/build.sbt index a19027bdb9..4e5199a9dd 100644 --- a/build.sbt +++ b/build.sbt @@ -59,6 +59,7 @@ ThisBuild / githubWorkflowPublishTargetBranches += RefPredicate.StartsWith(Ref.T ThisBuild / githubWorkflowPublishPreamble := Seq(coursierSetup) ThisBuild / githubWorkflowPublish := Seq( + WorkflowStep.Use(UseRef.Public("coursier", "setup-action", "v1"), Map("apps" -> "sbt")), WorkflowStep.Sbt( List("ci-release"), name = Some("Release"), diff --git a/project/Dependencies.scala b/project/Dependencies.scala index bcc954795d..853bb67fcd 100644 --- a/project/Dependencies.scala +++ b/project/Dependencies.scala @@ -16,8 +16,8 @@ object Dependencies { val `jwt-core` = "com.github.jwt-scala" %% "jwt-core" % JwtCoreVersion val `scala-compact-collection` = "org.scala-lang.modules" %% "scala-collection-compat" % ScalaCompactCollectionVersion - val scalafmt = "org.scalameta" %% "scalafmt-dynamic" % "3.8.1" - val scalametaParsers = "org.scalameta" %% "parsers" % "4.9.9" + val scalafmt = "org.scalameta" %% "scalafmt-dynamic" % "3.8.6" + val scalametaParsers = "org.scalameta" %% "parsers" % "4.12.7" val netty = Seq( diff --git a/project/MimaSettings.scala b/project/MimaSettings.scala index 1edcce1f98..fd62bacb55 100644 --- a/project/MimaSettings.scala +++ b/project/MimaSettings.scala @@ -12,6 +12,17 @@ object MimaSettings { mimaBinaryIssueFilters ++= Seq( exclude[Problem]("zio.http.internal.*"), exclude[Problem]("zio.http.codec.internal.*"), + exclude[Problem]("zio.http.codec.HttpCodec$Query$QueryType$Record$"), + exclude[Problem]("zio.http.codec.HttpCodec$Query$QueryType$Record"), + exclude[Problem]("zio.http.codec.HttpCodec$Query$QueryType$Primitive$"), + exclude[Problem]("zio.http.codec.HttpCodec$Query$QueryType$Primitive"), + exclude[Problem]("zio.http.codec.HttpCodec$Query$QueryType$Collection$"), + exclude[Problem]("zio.http.codec.HttpCodec$Query$QueryType$Collection"), + exclude[Problem]("zio.http.codec.HttpCodec$Query$QueryType$"), + exclude[Problem]("zio.http.codec.HttpCodec$Query$QueryType"), + exclude[Problem]("zio.http.endpoint.openapi.OpenAPIGen#AtomizedMetaCodecs.apply"), + exclude[Problem]("zio.http.endpoint.openapi.OpenAPIGen#AtomizedMetaCodecs.this"), + exclude[Problem]("zio.http.endpoint.openapi.OpenAPIGen#AtomizedMetaCodecs.copy"), ), mimaFailOnProblem := failOnProblem ) diff --git a/zio-http-benchmarks/src/main/scala/zhttp.benchmarks/ServerInboundHandlerBenchmark.scala b/zio-http-benchmarks/src/main/scala/zhttp.benchmarks/ServerInboundHandlerBenchmark.scala index d28c665e5e..8f594614c6 100644 --- a/zio-http-benchmarks/src/main/scala/zhttp.benchmarks/ServerInboundHandlerBenchmark.scala +++ b/zio-http-benchmarks/src/main/scala/zhttp.benchmarks/ServerInboundHandlerBenchmark.scala @@ -18,7 +18,7 @@ class ServerInboundHandlerBenchmark { private val largeString = random.alphanumeric.take(100000).mkString private val baseUrl = "http://localhost:8080" - private val headers = Headers(Header.ContentType(MediaType.text.`plain`).untyped) + private val headers = Headers(Header.ContentType(MediaType.text.`plain`)) private val arrayEndpoint = "array" private val arrayResponse = ZIO.succeed( diff --git a/zio-http-cli/src/main/scala/zio/http/endpoint/cli/CliEndpoint.scala b/zio-http-cli/src/main/scala/zio/http/endpoint/cli/CliEndpoint.scala index 566289ea37..2985811983 100644 --- a/zio-http-cli/src/main/scala/zio/http/endpoint/cli/CliEndpoint.scala +++ b/zio-http-cli/src/main/scala/zio/http/endpoint/cli/CliEndpoint.scala @@ -1,7 +1,6 @@ package zio.http.endpoint.cli import zio.http._ -import zio.http.codec.HttpCodec.Query.QueryType import zio.http.codec._ import zio.http.endpoint._ @@ -112,13 +111,9 @@ private[cli] object CliEndpoint { } CliEndpoint(body = HttpOptions.Body(name, codec.defaultMediaType, codec.defaultSchema) :: List()) - case HttpCodec.Header(name, textCodec, _) if textCodec.isInstanceOf[TextCodec.Constant] => - CliEndpoint(headers = - HttpOptions.HeaderConstant(name, textCodec.asInstanceOf[TextCodec.Constant].string) :: List(), - ) - case HttpCodec.Header(name, textCodec, _) => - CliEndpoint(headers = HttpOptions.Header(name, textCodec) :: List()) - case HttpCodec.Method(codec, _) => + case HttpCodec.Header(headerType, _) => + CliEndpoint(headers = HttpOptions.Header(headerType.names.head, TextCodec.string) :: List()) + case HttpCodec.Method(codec, _) => codec.asInstanceOf[SimpleCodec[_, _]] match { case SimpleCodec.Specified(method: Method) => CliEndpoint(methods = method) @@ -128,22 +123,11 @@ private[cli] object CliEndpoint { case HttpCodec.Path(pathCodec, _) => CliEndpoint(url = HttpOptions.Path(pathCodec) :: List()) - case HttpCodec.Query(queryType, _) => - queryType match { - case QueryType.Primitive(name, codec) => - CliEndpoint(url = HttpOptions.Query(name, codec) :: List()) - case record @ QueryType.Record(_) => - val queryOptions = record.fieldAndCodecs.map { case (field, codec) => - HttpOptions.Query(field.name, codec) - } - CliEndpoint(url = queryOptions.toList) - case QueryType.Collection(_, elements, _) => - val queryOptions = - HttpOptions.Query(elements.name, elements.codec) - CliEndpoint(url = queryOptions :: List()) - } - - case HttpCodec.Status(_, _) => CliEndpoint.empty + case HttpCodec.Query(codec, _) => + CliEndpoint(url = codec.recordFields.map { case (f, codec) => + HttpOptions.Query(codec, f.fieldName) + }.toList) + case HttpCodec.Status(_, _) => CliEndpoint.empty } } diff --git a/zio-http-cli/src/main/scala/zio/http/endpoint/cli/HttpOptions.scala b/zio-http-cli/src/main/scala/zio/http/endpoint/cli/HttpOptions.scala index 191194864d..7cbe6ca565 100644 --- a/zio-http-cli/src/main/scala/zio/http/endpoint/cli/HttpOptions.scala +++ b/zio-http-cli/src/main/scala/zio/http/endpoint/cli/HttpOptions.scala @@ -12,6 +12,8 @@ import zio.schema.annotation.description import zio.http._ import zio.http.codec._ +import zio.http.internal.StringSchemaCodec +import zio.http.internal.StringSchemaCodec.PrimitiveCodec /* * HttpOptions is a wrapper of a transformation Options[CliRequest] => Options[CliRequest]. @@ -264,12 +266,10 @@ private[cli] object HttpOptions { } - final case class Query(override val name: String, codec: BinaryCodecWithSchema[_], doc: Doc = Doc.empty) - extends URLOptions { + final case class Query(codec: PrimitiveCodec[_], name: String, doc: Doc = Doc.empty) extends URLOptions { self => - override val tag = "?" + name - def options: Options[_] = optionsFromSchema(codec)(name) + def options: Options[_] = optionsFromSchema(codec.schema)(name) override def ??(doc: Doc): Query = self.copy(doc = self.doc + doc) @@ -293,8 +293,8 @@ private[cli] object HttpOptions { } - private[cli] def optionsFromSchema[A](codec: BinaryCodecWithSchema[A]): String => Options[A] = - codec.schema match { + private[cli] def optionsFromSchema[A](schema: Schema[A]): String => Options[A] = + schema match { case Schema.Primitive(standardType, _) => standardType match { case StandardType.UnitType => diff --git a/zio-http-cli/src/test/scala/zio/http/endpoint/cli/AuxGen.scala b/zio-http-cli/src/test/scala/zio/http/endpoint/cli/AuxGen.scala index 8fdb42d863..06c365fc0e 100644 --- a/zio-http-cli/src/test/scala/zio/http/endpoint/cli/AuxGen.scala +++ b/zio-http-cli/src/test/scala/zio/http/endpoint/cli/AuxGen.scala @@ -16,11 +16,7 @@ import zio.http.codec._ */ object AuxGen { - lazy val anyTextCodec: Gen[Any, TextCodec[_]] = - Gen.oneOf( - Gen.fromIterable(List(TextCodec.boolean, TextCodec.int, TextCodec.string, TextCodec.uuid)), - Gen.alphaNumericStringBounded(1, 30).map(TextCodec.constant(_)), - ) + lazy val anyTextCodec: Gen[Any, TextCodec[_]] = Gen.const(TextCodec.string) lazy val anyMediaType: Gen[Any, MediaType] = Gen.fromIterable(MediaType.allMediaTypes) diff --git a/zio-http-cli/src/test/scala/zio/http/endpoint/cli/CliSpec.scala b/zio-http-cli/src/test/scala/zio/http/endpoint/cli/CliSpec.scala index 5153f1ab47..f26145431c 100644 --- a/zio-http-cli/src/test/scala/zio/http/endpoint/cli/CliSpec.scala +++ b/zio-http-cli/src/test/scala/zio/http/endpoint/cli/CliSpec.scala @@ -27,7 +27,7 @@ object CliSpec extends ZIOSpecDefault { val bodyStream = ContentCodec.contentStream[BigInt]("bodyStream") - val headerCodec = HttpCodec.Header("header", TextCodec.string) + val headerCodec = HttpCodec.headerAs[String]("header") val path1 = PathCodec.bool("path1") diff --git a/zio-http-cli/src/test/scala/zio/http/endpoint/cli/CommandGen.scala b/zio-http-cli/src/test/scala/zio/http/endpoint/cli/CommandGen.scala index ad99897218..c724eaa218 100644 --- a/zio-http-cli/src/test/scala/zio/http/endpoint/cli/CommandGen.scala +++ b/zio-http-cli/src/test/scala/zio/http/endpoint/cli/CommandGen.scala @@ -48,17 +48,17 @@ object CommandGen { case _ => true }.map { case HttpOptions.Path(pathCodec, _) => - pathCodec.segments.toList.flatMap { case segment => + pathCodec.segments.toList.flatMap { segment => getSegment(segment) match { case (_, "") => Nil case (name, "boolean") => s"[${getName(name, "")}]" :: Nil case (name, codec) => s"${getName(name, "")} $codec" :: Nil } } - case HttpOptions.Query(name, codec, _) => - getType(codec) match { - case "" => s"[${getName(name, "")}]" :: Nil - case codec => s"${getName(name, "")} $codec" :: Nil + case HttpOptions.Query(codec, name, _) => + getType(codec.schema) match { + case "" => s"[${getName(name, "")}]" :: Nil + case tpy => s"${getName(name, "")} $tpy" :: Nil } case _ => Nil }.foldRight(List[String]())(_ ++ _) @@ -121,8 +121,8 @@ object CommandGen { case _ => "" } - def getType[A](codec: BinaryCodecWithSchema[A]): String = - codec.schema match { + def getType[A](schema: Schema[A]): String = + schema match { case Schema.Primitive(standardType, _) => standardType match { case StandardType.UnitType => "" diff --git a/zio-http-cli/src/test/scala/zio/http/endpoint/cli/EndpointGen.scala b/zio-http-cli/src/test/scala/zio/http/endpoint/cli/EndpointGen.scala index d868a86cca..7f34ae9adc 100644 --- a/zio-http-cli/src/test/scala/zio/http/endpoint/cli/EndpointGen.scala +++ b/zio-http-cli/src/test/scala/zio/http/endpoint/cli/EndpointGen.scala @@ -78,10 +78,9 @@ object EndpointGen { lazy val anyHeader: Gen[Any, CliReprOf[Codec[_]]] = Gen.alphaNumericStringBounded(1, 30).zip(anyTextCodec).map { case (name, codec) => CliRepr( - HttpCodec.Header(name, codec), + HttpCodec.Header(Header.Custom(name, "").headerType), // todo use schema bases header codec match { - case TextCodec.Constant(value) => CliEndpoint(headers = HttpOptions.HeaderConstant(name, value) :: Nil) - case _ => CliEndpoint(headers = HttpOptions.Header(name, codec) :: Nil) + case _ => CliEndpoint(headers = HttpOptions.Header(name, codec) :: Nil) }, ) } @@ -102,10 +101,10 @@ object EndpointGen { lazy val anyQuery: Gen[Any, CliReprOf[Codec[_]]] = Gen.alphaNumericStringBounded(1, 30).zip(anyStandardType).map { case (name, schema0) => val schema = schema0.asInstanceOf[Schema[Any]] - val codec = BinaryCodecWithSchema(TextBinaryCodec.fromSchema(schema), schema) + val codec = QueryCodec.query(name)(schema).asInstanceOf[HttpCodec.Query[Any]] CliRepr( - HttpCodec.Query(HttpCodec.Query.QueryType.Primitive(name, codec)), - CliEndpoint(url = HttpOptions.Query(name, codec) :: Nil), + codec, + CliEndpoint(url = HttpOptions.Query(codec.codec.recordFields.head._2, name) :: Nil), ) } diff --git a/zio-http-cli/src/test/scala/zio/http/endpoint/cli/OptionsGen.scala b/zio-http-cli/src/test/scala/zio/http/endpoint/cli/OptionsGen.scala index 58fe22aa85..37350d29fe 100644 --- a/zio-http-cli/src/test/scala/zio/http/endpoint/cli/OptionsGen.scala +++ b/zio-http-cli/src/test/scala/zio/http/endpoint/cli/OptionsGen.scala @@ -10,6 +10,7 @@ import zio.http._ import zio.http.codec._ import zio.http.endpoint.cli.AuxGen._ import zio.http.endpoint.cli.CliRepr._ +import zio.http.internal.StringSchemaCodec.PrimitiveCodec /** * Constructs a Gen[Options[CliRequest], CliEndpoint] @@ -32,10 +33,10 @@ object OptionsGen { .optionsFromTextCodec(textCodec)(name) .map(value => textCodec.encode(value)) - def encodeOptions[A](name: String, codec: BinaryCodecWithSchema[A]): Options[String] = + def encodeOptions[A](name: String, codec: PrimitiveCodec[A], schema: Schema[A]): Options[String] = HttpOptions - .optionsFromSchema(codec)(name) - .map(value => codec.codec(CodecConfig.defaultConfig).encode(value).asString) + .optionsFromSchema(schema)(name) + .map(value => codec.encode(value)) lazy val anyBodyOption: Gen[Any, CliReprOf[Options[Retriever]]] = Gen @@ -50,18 +51,12 @@ object OptionsGen { } lazy val anyHeaderOption: Gen[Any, CliReprOf[Options[Headers]]] = - Gen.alphaNumericStringBounded(1, 30).zip(anyTextCodec).map { - case (name, TextCodec.Constant(value)) => - CliRepr( - Options.Empty.map(_ => Headers(name, value)), - CliEndpoint(headers = HttpOptions.HeaderConstant(name, value) :: Nil), - ) - case (name, codec) => - CliRepr( - encodeOptions(name, codec) - .map(value => Headers(name, value)), - CliEndpoint(headers = HttpOptions.Header(name, codec) :: Nil), - ) + Gen.alphaNumericStringBounded(1, 30).zip(anyTextCodec).map { case (name, codec) => + CliRepr( + encodeOptions(name, codec) + .map(value => Headers(name, value)), + CliEndpoint(headers = HttpOptions.Header(name, codec) :: Nil), + ) } lazy val anyURLOption: Gen[Any, CliReprOf[Options[String]]] = @@ -83,14 +78,12 @@ object OptionsGen { }, Gen .alphaNumericStringBounded(1, 30) - .zip(anyStandardType.map { s => - val schema = s.asInstanceOf[Schema[Any]] - BinaryCodecWithSchema(TextBinaryCodec.fromSchema(schema), schema) - }) - .map { case (name, codec) => + .zip(anyStandardType) + .map { case (name, schema) => + val codec = QueryCodec.query(name)(schema).asInstanceOf[HttpCodec.Query[Any]] CliRepr( - encodeOptions(name, codec), - CliEndpoint(url = HttpOptions.Query(name, codec) :: Nil), + encodeOptions(name, codec.codec.recordFields.head._2, schema.asInstanceOf[Schema[Any]]), + CliEndpoint(url = HttpOptions.Query(codec.codec.recordFields.head._2, name) :: Nil), ) }, ) diff --git a/zio-http-example/src/main/scala/example/websocket/WebSocketReconnectingClient.scala b/zio-http-example/src/main/scala/example/websocket/WebSocketReconnectingClient.scala index 599ac554ec..cecddbc6bb 100644 --- a/zio-http-example/src/main/scala/example/websocket/WebSocketReconnectingClient.scala +++ b/zio-http-example/src/main/scala/example/websocket/WebSocketReconnectingClient.scala @@ -22,13 +22,13 @@ object WebSocketReconnectingClient extends ZIOAppDefault { channel.send(ChannelEvent.Read(WebSocketFrame.text("foo"))) // On receiving "foo", we'll reply with another "foo" to keep echo loop going - case Read(WebSocketFrame.Text("foo")) => + case Read(WebSocketFrame.Text("foo")) => ZIO.logInfo("Received foo message.") *> ZIO.sleep(1.second) *> channel.send(ChannelEvent.Read(WebSocketFrame.text("foo"))) // Handle exception and convert it to failure to signal the shutdown of the socket connection via the promise - case ExceptionCaught(t) => + case ExceptionCaught(t) => ZIO.fail(t) case _ => diff --git a/zio-http-example/src/main/scala/example/websocket/WebSocketServerAdvanced.scala b/zio-http-example/src/main/scala/example/websocket/WebSocketServerAdvanced.scala index 19a98f3437..c7f08d288a 100644 --- a/zio-http-example/src/main/scala/example/websocket/WebSocketServerAdvanced.scala +++ b/zio-http-example/src/main/scala/example/websocket/WebSocketServerAdvanced.scala @@ -13,19 +13,19 @@ object WebSocketServerAdvanced extends ZIOAppDefault { val socketApp: WebSocketApp[Any] = Handler.webSocket { channel => channel.receiveAll { - case Read(WebSocketFrame.Text("end")) => + case Read(WebSocketFrame.Text("end")) => channel.shutdown // Send a "bar" if the client sends a "foo" - case Read(WebSocketFrame.Text("foo")) => + case Read(WebSocketFrame.Text("foo")) => channel.send(Read(WebSocketFrame.text("bar"))) // Send a "foo" if the client sends a "bar" - case Read(WebSocketFrame.Text("bar")) => + case Read(WebSocketFrame.Text("bar")) => channel.send(Read(WebSocketFrame.text("foo"))) // Echo the same message 10 times if it's not "foo" or "bar" - case Read(WebSocketFrame.Text(text)) => + case Read(WebSocketFrame.Text(text)) => channel .send(Read(WebSocketFrame.text(s"echo $text"))) .repeatN(10) @@ -38,11 +38,11 @@ object WebSocketServerAdvanced extends ZIOAppDefault { channel.send(Read(WebSocketFrame.text("Greetings!"))) // Log when the channel is getting closed - case Read(WebSocketFrame.Close(status, reason)) => + case Read(WebSocketFrame.Close(status, reason)) => Console.printLine("Closing channel with status: " + status + " and reason: " + reason) // Print the exception if it's not a normal close - case ExceptionCaught(cause) => + case ExceptionCaught(cause) => Console.printLine(s"Channel error!: ${cause.getMessage}") case _ => diff --git a/zio-http-example/src/main/scala/example/websocket/WebSocketSimpleClient.scala b/zio-http-example/src/main/scala/example/websocket/WebSocketSimpleClient.scala index 798935e6c1..0f6e085592 100644 --- a/zio-http-example/src/main/scala/example/websocket/WebSocketSimpleClient.scala +++ b/zio-http-example/src/main/scala/example/websocket/WebSocketSimpleClient.scala @@ -21,11 +21,11 @@ object WebSocketSimpleClient extends ZIOAppDefault { channel.send(Read(WebSocketFrame.text("foo"))) // Send a "bar" if the server sends a "foo" - case Read(WebSocketFrame.Text("foo")) => + case Read(WebSocketFrame.Text("foo")) => channel.send(Read(WebSocketFrame.text("bar"))) // Close the connection if the server sends a "bar" - case Read(WebSocketFrame.Text("bar")) => + case Read(WebSocketFrame.Text("bar")) => ZIO.succeed(println("Goodbye!")) *> channel.send(Read(WebSocketFrame.close(1000))) case _ => diff --git a/zio-http-gen/src/main/scala/zio/http/gen/openapi/EndpointGen.scala b/zio-http-gen/src/main/scala/zio/http/gen/openapi/EndpointGen.scala index d913d9e291..26bc00630d 100644 --- a/zio-http-gen/src/main/scala/zio/http/gen/openapi/EndpointGen.scala +++ b/zio-http-gen/src/main/scala/zio/http/gen/openapi/EndpointGen.scala @@ -823,9 +823,9 @@ final case class EndpointGen(config: Config) { * `transform` that simply `wrap` / `unwrap` the provided value. */ case JsonSchema.Boolean => aliasedSchemaToCode(openAPI, name, schema) - case JsonSchema.OneOfSchema(schemas) if schemas.exists(_.isPrimitive) => + case JsonSchema.OneOfSchema(schemas) if schemas.exists(_.isPrimitive) => throw new Exception("OneOf schemas with primitive types are not supported") - case JsonSchema.OneOfSchema(schemas) => + case JsonSchema.OneOfSchema(schemas) => val discriminatorInfo = annotations.collectFirst { case JsonSchema.MetaData.Discriminator(discriminator) => discriminator } val discriminator: Option[String] = discriminatorInfo.map(_.propertyName) @@ -885,7 +885,7 @@ final case class EndpointGen(config: Config) { ), ), ) - case JsonSchema.AllOfSchema(schemas) => + case JsonSchema.AllOfSchema(schemas) => val genericFieldIndex = Iterator.from(0) val unvalidatedFields = schemas.toList.map(_.withoutAnnotations).flatMap { case schema @ JsonSchema.Object(_, _, _) => @@ -928,9 +928,9 @@ final case class EndpointGen(config: Config) { enums = Nil, ), ) - case JsonSchema.AnyOfSchema(schemas) if schemas.exists(_.isPrimitive) => + case JsonSchema.AnyOfSchema(schemas) if schemas.exists(_.isPrimitive) => throw new Exception("AnyOf schemas with primitive types are not supported") - case JsonSchema.AnyOfSchema(schemas) => + case JsonSchema.AnyOfSchema(schemas) => val discriminatorInfo = annotations.collectFirst { case JsonSchema.MetaData.Discriminator(discriminator) => discriminator } val discriminator: Option[String] = discriminatorInfo.map(_.propertyName) @@ -987,12 +987,12 @@ final case class EndpointGen(config: Config) { ), ), ) - case JsonSchema.Number(_, _, _, _, _, _) => aliasedSchemaToCode(openAPI, name, schema) + case JsonSchema.Number(_, _, _, _, _, _) => aliasedSchemaToCode(openAPI, name, schema) // should we provide support for (Newtype) aliasing arrays of primitives? - case JsonSchema.ArrayType(None, _, _) => None - case JsonSchema.ArrayType(Some(schema), _, _) => + case JsonSchema.ArrayType(None, _, _) => None + case JsonSchema.ArrayType(Some(schema), _, _) => schemaToCode(schema, openAPI, name, annotations) - case obj: JsonSchema.Object if obj.isInvalid => + case obj: JsonSchema.Object if obj.isInvalid => throw new Exception("Object with properties and additionalProperties is not supported") case obj @ JsonSchema.Object(properties, _, _) if obj.isClosedDictionary => val unvalidatedFields = fieldsOfObject(openAPI, annotations)(obj) @@ -1052,8 +1052,8 @@ final case class EndpointGen(config: Config) { ), ), ) - case JsonSchema.Null => throw new Exception("Null query parameters are not supported") - case JsonSchema.AnyJson => None + case JsonSchema.Null => throw new Exception("Null query parameters are not supported") + case JsonSchema.AnyJson => None } } diff --git a/zio-http-gen/src/test/scala/zio/http/gen/scala/CodeGenSpec.scala b/zio-http-gen/src/test/scala/zio/http/gen/scala/CodeGenSpec.scala index fda70f5ff4..71e34f4c7a 100644 --- a/zio-http-gen/src/test/scala/zio/http/gen/scala/CodeGenSpec.scala +++ b/zio-http-gen/src/test/scala/zio/http/gen/scala/CodeGenSpec.scala @@ -32,12 +32,11 @@ import zio.http.gen.openapi.{Config, EndpointGen} object CodeGenSpec extends ZIOSpecDefault { case class ValidatedData( - @validate(Validation.maxLength(10)) - name: String, - @validate(Validation.greaterThan(0) && Validation.lessThan(100)) - age: Int, + @validate(Validation.maxLength(10)) name: String, + @validate(Validation.greaterThan(0) && Validation.lessThan(100)) age: Int, ) - implicit val validatedDataSchema: Schema[ValidatedData] = DeriveSchema.gen[ValidatedData] + implicit val validatedDataSchema: Schema[ValidatedData] = + DeriveSchema.gen[ValidatedData] private def fileShouldBe(dir: java.nio.file.Path, subPath: String, expectedFile: String): TestResult = { val filePath = dir.resolve(Paths.get(subPath)) @@ -155,8 +154,9 @@ object CodeGenSpec extends ZIOSpecDefault { Endpoint(Method.GET / "api" / "v1" / "users") .header(HeaderCodec.accept) .header(HeaderCodec.contentType) - .header(HeaderCodec.name[String]("Token")) - val openAPI = OpenAPIGen.fromEndpoints(endpoint) + .header(HeaderCodec.headerAs[String]("Token")) + + val openAPI = OpenAPIGen.fromEndpoints(endpoint) codeGenFromOpenAPI(openAPI) { testDir => fileShouldBe(testDir, "api/v1/Users.scala", "/EndpointWithHeaders.scala") @@ -605,6 +605,7 @@ object CodeGenSpec extends ZIOSpecDefault { } } } @@ TestAspect.exceptScala3, // for some reason, the temp dir is empty in Scala 3 + //format: off test("Endpoint with array field in input") { val endpoint = Endpoint(Method.POST / "api" / "v1" / "users").in[UserNameArray].out[User] val openAPI = OpenAPIGen.fromEndpoints("", "", endpoint) diff --git a/zio-http/jvm/src/main/scala/zio/http/netty/model/Conversions.scala b/zio-http/jvm/src/main/scala/zio/http/netty/model/Conversions.scala index 374c9ba27f..2874de2c63 100644 --- a/zio-http/jvm/src/main/scala/zio/http/netty/model/Conversions.scala +++ b/zio-http/jvm/src/main/scala/zio/http/netty/model/Conversions.scala @@ -18,6 +18,8 @@ package zio.http.netty.model import scala.collection.AbstractIterator +import zio.Chunk + import zio.http.Server.Config.CompressionOptions import zio.http._ @@ -58,10 +60,10 @@ private[netty] object Conversions { def headersToNetty(headers: Headers): HttpHeaders = headers match { - case Headers.FromIterable(_) => encodeHeaderListToNetty(headers) - case Headers.Native(value, _, _, _) => value.asInstanceOf[HttpHeaders] - case Headers.Concat(_, _) => encodeHeaderListToNetty(headers) - case Headers.Empty => new DefaultHttpHeaders() + case Headers.FromIterable(_) => encodeHeaderListToNetty(headers) + case Headers.Native(value, _, _, _, _) => value.asInstanceOf[HttpHeaders] + case Headers.Concat(_, _) => encodeHeaderListToNetty(headers) + case Headers.Empty => new DefaultHttpHeaders() } def urlToNetty(url: URL): String = { @@ -89,6 +91,7 @@ private[netty] object Conversions { (headers: HttpHeaders) => nettyHeadersIterator(headers), // NOTE: Netty's headers.get is case-insensitive (headers: HttpHeaders, key: CharSequence) => headers.get(key), + (headers: HttpHeaders, key: CharSequence) => Chunk.fromJavaIterable(headers.getAll(key)), (headers: HttpHeaders, key: CharSequence) => headers.contains(key), ) diff --git a/zio-http/jvm/src/test/scala/zio/http/LogAnnotationMiddlewareSpec.scala b/zio-http/jvm/src/test/scala/zio/http/LogAnnotationMiddlewareSpec.scala index 51775c944a..6bf212e7dd 100644 --- a/zio-http/jvm/src/test/scala/zio/http/LogAnnotationMiddlewareSpec.scala +++ b/zio-http/jvm/src/test/scala/zio/http/LogAnnotationMiddlewareSpec.scala @@ -30,9 +30,8 @@ object LogAnnotationMiddlewareSpec extends ZIOSpecDefault { handler(ZIO.logWarning("Oh!") *> ZIO.succeed(Response.text("Hey logging!"))), ) .@@( - Middleware.logAnnotate(req => - Set(LogAnnotation("method", req.method.name), LogAnnotation("path", req.path.encode)), - ), + Middleware + .logAnnotate(req => Set(LogAnnotation("method", req.method.name), LogAnnotation("path", req.path.encode))), ) .runZIO(Request.get("/")) diff --git a/zio-http/jvm/src/test/scala/zio/http/WebSocketSpec.scala b/zio-http/jvm/src/test/scala/zio/http/WebSocketSpec.scala index 87c7d335e4..cc3e222b6c 100644 --- a/zio-http/jvm/src/test/scala/zio/http/WebSocketSpec.scala +++ b/zio-http/jvm/src/test/scala/zio/http/WebSocketSpec.scala @@ -79,7 +79,7 @@ object WebSocketSpec extends RoutesRunnableSpec { // Setup websocket server - serverHttp = Handler.webSocket { channel => + serverHttp = Handler.webSocket { channel => channel.receiveAll { case Unregistered => isStarted.succeed(()) <&> isSet.succeed(()).delay(5 seconds).withClock(clock) diff --git a/zio-http/jvm/src/test/scala/zio/http/endpoint/AuthSpec.scala b/zio-http/jvm/src/test/scala/zio/http/endpoint/AuthSpec.scala index e912e7f410..d140450719 100644 --- a/zio-http/jvm/src/test/scala/zio/http/endpoint/AuthSpec.scala +++ b/zio-http/jvm/src/test/scala/zio/http/endpoint/AuthSpec.scala @@ -152,7 +152,7 @@ object AuthSpec extends ZIOSpecDefault { .catchAllCause(c => ZIO.logInfoCause(c)) <* ZIO.sleep(1.seconds) response <- response } yield assertTrue(response == "admin") - } @@ flaky, + }, test("Auth basic or bearer with context and endpoint client") { val endpoint = Endpoint(Method.GET / "multiAuth") @@ -212,7 +212,7 @@ object AuthSpec extends ZIOSpecDefault { .catchAllCause(c => ZIO.logInfoCause(c)) <* ZIO.sleep(1.seconds) response <- response } yield assertTrue(response == "admin") - } @@ TestAspect.flaky, + }, test("Auth with context and endpoint client with path parameter") { val endpoint = Endpoint(Method.GET / int("a")).out[String](MediaType.text.`plain`).auth(AuthType.Basic) @@ -237,7 +237,7 @@ object AuthSpec extends ZIOSpecDefault { response <- response } yield assertTrue(response == "admin") }, - ).provideShared(Client.default, Server.default) @@ TestAspect.withLiveClock, + ).provideShared(Client.default, Server.default) @@ TestAspect.withLiveClock @@ TestAspect.flaky, test("Require Basic Auth, but get Bearer Auth") { val endpoint = Endpoint(Method.GET / "test").out[String](MediaType.text.`plain`).auth(AuthType.Basic) val routes = diff --git a/zio-http/jvm/src/test/scala/zio/http/endpoint/EndpointSpec.scala b/zio-http/jvm/src/test/scala/zio/http/endpoint/EndpointSpec.scala index a2c6334e62..3a093669bb 100644 --- a/zio-http/jvm/src/test/scala/zio/http/endpoint/EndpointSpec.scala +++ b/zio-http/jvm/src/test/scala/zio/http/endpoint/EndpointSpec.scala @@ -18,6 +18,8 @@ package zio.http.endpoint import java.time.Instant +import scala.math.BigDecimal.javaBigDecimal2bigDecimal + import zio._ import zio.test._ diff --git a/zio-http/jvm/src/test/scala/zio/http/endpoint/HeaderSpec.scala b/zio-http/jvm/src/test/scala/zio/http/endpoint/HeaderSpec.scala new file mode 100644 index 0000000000..ee5a7896ea --- /dev/null +++ b/zio-http/jvm/src/test/scala/zio/http/endpoint/HeaderSpec.scala @@ -0,0 +1,204 @@ +package zio.http.endpoint + +import zio.test._ +import zio.{NonEmptyChunk, Scope} + +import zio.schema.Schema +import zio.schema.annotation.fieldName + +import zio.http._ +import zio.http.codec.HttpCodec +import zio.http.endpoint.EndpointSpec.testEndpointWithHeaders + +object HeaderSpec extends ZIOHttpSpec { + case class MyHeaders(age: String, @fieldName("content-type") cType: String = "application", xApiKey: Option[String]) + + object MyHeaders { + implicit val schema: Schema[MyHeaders] = zio.schema.DeriveSchema.gen[MyHeaders] + } + + override def spec: Spec[TestEnvironment with Scope, Any] = + suite("HeaderCodec")( + test("Headers from case class") { + check( + Gen.alphaNumericStringBounded(1, 10), + Gen.alphaNumericStringBounded(1, 10), + Gen.alphaNumericStringBounded(1, 10), + ) { (age, cType, apiKey) => + val testRoutes = testEndpointWithHeaders( + Routes( + Endpoint(Method.GET / "users") + .header(HttpCodec.headers[MyHeaders]) + .out[String] + .implementPurely(_.toString), + ), + ) _ + + testRoutes( + s"/users", + List( + "age" -> age, + "content-type" -> cType, + "x-api-key" -> apiKey, + ), + MyHeaders(age, cType, Some(apiKey)).toString, + ) && + testRoutes( + s"/users", + List( + "age" -> age, + "content-type" -> cType, + "x-api-key" -> "", + ), + MyHeaders(age, cType, Some("")).toString, + ) && + testRoutes( + s"/users", + List( + "age" -> age, + ), + MyHeaders(age, "application", None).toString, + ) + } + }, + test("Optional Headers from case class") { + check( + Gen.alphaNumericStringBounded(1, 10), + Gen.alphaNumericStringBounded(1, 10), + Gen.alphaNumericStringBounded(1, 10), + ) { (age, cType, apiKey) => + val testRoutes = testEndpointWithHeaders( + Routes( + Endpoint(Method.GET / "users") + .header(HttpCodec.headers[MyHeaders].optional) + .out[String] + .implementPurely(_.toString), + ), + ) _ + + testRoutes( + s"/users", + List( + "content-type" -> cType, + ), + None.toString, + ) && + testRoutes( + s"/users", + List( + "age" -> age, + "content-type" -> cType, + "x-api-key" -> apiKey, + ), + Some(MyHeaders(age, cType, Some(apiKey))).toString, + ) && testRoutes( + s"/users", + List( + "age" -> age, + ), + Some(MyHeaders(age, "application", None)).toString, + ) + } + }, + test("Multiple Header values") { + check( + Gen.alphaNumericStringBounded(1, 10), + Gen.alphaNumericStringBounded(1, 10), + Gen.alphaNumericStringBounded(1, 10), + ) { (age, age2, age3) => + val testRoutes = testEndpointWithHeaders( + Routes( + Endpoint(Method.GET / "users") + .header(HttpCodec.headerAs[List[String]]("age")) + .out[String] + .implementPurely(_.toString), + ), + ) _ + + testRoutes( + s"/users", + List( + "age" -> age, + ), + List(age).toString, + ) && testRoutes( + s"/users", + List( + "age" -> age, + "age" -> age2, + ), + List(age, age2).toString, + ) && testRoutes( + s"/users", + List( + "age" -> age, + "age" -> age2, + "age" -> age3, + ), + List(age, age2, age3).toString, + ) + } + }, + test("Multiple Header values non empty") { + check( + Gen.alphaNumericStringBounded(1, 10), + Gen.alphaNumericStringBounded(1, 10), + Gen.alphaNumericStringBounded(1, 10), + ) { (age, age2, age3) => + val testRoutes = testEndpointWithHeaders( + Routes( + Endpoint(Method.GET / "users") + .header(HttpCodec.headerAs[NonEmptyChunk[String]]("age")) + .out[String] + .implementPurely(_.toString), + ), + ) _ + + testRoutes( + s"/users", + List( + "age" -> age, + ), + NonEmptyChunk(age).toString, + ) && testRoutes( + s"/users", + List( + "age" -> age, + "age" -> age2, + ), + NonEmptyChunk(age, age2).toString, + ) && testRoutes( + s"/users", + List( + "age" -> age, + "age" -> age2, + "age" -> age3, + ), + NonEmptyChunk(age, age2, age3).toString, + ) + } + }, + test("Header from transformed schema") { + case class Wrapper(age: Int) + implicit val schema: Schema[Wrapper] = zio.schema.Schema[Int].transform[Wrapper](Wrapper(_), _.age) + check(Gen.int) { age => + val testRoutes = testEndpointWithHeaders( + Routes( + Endpoint(Method.GET / "users") + .header(HttpCodec.headerAs[Wrapper]("age")) + .out[String] + .implementPurely(_.toString), + ), + ) _ + + testRoutes( + s"/users", + List( + "age" -> age.toString, + ), + Wrapper(age).toString, + ) + } + }, + ) +} diff --git a/zio-http/jvm/src/test/scala/zio/http/endpoint/QueryParameterSpec.scala b/zio-http/jvm/src/test/scala/zio/http/endpoint/QueryParameterSpec.scala index 0c5b732485..af0a2eca16 100644 --- a/zio-http/jvm/src/test/scala/zio/http/endpoint/QueryParameterSpec.scala +++ b/zio-http/jvm/src/test/scala/zio/http/endpoint/QueryParameterSpec.scala @@ -55,16 +55,6 @@ object QueryParameterSpec extends ZIOHttpSpec { testRoutes( s"/users?int=$int&optInt=${optInt.mkString}&string=$string&strings=${strings.mkString(",")}", Params(int, optInt, string, strings).toString, - ) && - testRoutes( - s"/users?int=$int&string=$string&strings=${strings.mkString(",")}", - Params(int, None, string, strings).toString, - ) && testRoutes( - s"/users?int=$int&optInt=${optInt.mkString}&strings=${strings.mkString(",")}", - Params(int, optInt, "", strings).toString, - ) && testRoutes( - s"/users?int=$int&optInt=${optInt.mkString}&string=$string", - Params(int, optInt, string, Chunk("defaultString")).toString, ) } }, @@ -110,8 +100,8 @@ object QueryParameterSpec extends ZIOHttpSpec { }, ), ) _ - // testRoutes(s"/users/$userId", s"path(users, $userId, None)") && - // testRoutes(s"/users/$userId?details=", s"path(users, $userId, None)") && + testRoutes(s"/users/$userId", s"path(users, $userId, None)") && + testRoutes(s"/users/$userId?details=", s"path(users, $userId, Some())") && testRoutes(s"/users/$userId?details=$details", s"path(users, $userId, Some($details))") } }, @@ -168,6 +158,38 @@ object QueryParameterSpec extends ZIOHttpSpec { ) } }, + test("query parameters with multiple values non empty") { + check(Gen.int, Gen.listOfN(3)(Gen.alphaNumericString)) { (userId, keys) => + val routes = Routes( + Endpoint(GET / "users" / int("userId")) + .query(HttpCodec.query[NonEmptyChunk[String]]("key")) + .out[String] + .implementHandler { + Handler.fromFunction { case (userId, keys) => + s"""path(users, $userId, ${keys.mkString(", ")})""" + } + }, + ) + val testRoutes = testEndpoint( + routes, + ) _ + + testRoutes( + s"/users/$userId?key=${keys(0)}&key=${keys(1)}&key=${keys(2)}", + s"path(users, $userId, ${keys.mkString(", ")})", + ) && + testRoutes( + s"/users/$userId?key=${keys(0)}&key=${keys(1)}", + s"path(users, $userId, ${keys.take(2).mkString(", ")})", + ) && + testRoutes( + s"/users/$userId?key=${keys(0)}", + s"path(users, $userId, ${keys.take(1).mkString(", ")})", + ) && routes + .runZIO(Request.get(s"/users/$userId")) + .map(resp => assertTrue(resp.status == Status.BadRequest)) + } + }, test("optional query parameters with multiple values") { check(Gen.int, Gen.listOfN(3)(Gen.alphaNumericString)) { (userId, keys) => val testRoutes = testEndpoint( @@ -341,7 +363,7 @@ object QueryParameterSpec extends ZIOHttpSpec { test("query parameters keys without values for multi value query") { val routes = Routes( Endpoint(GET / "users") - .query(HttpCodec.query[Chunk[RuntimeFlags]]("ints")) + .query(HttpCodec.query[Chunk[Int]]("ints")) .out[String] .implementHandler { Handler.fromFunction { queryParams => s"path(users, $queryParams)" } @@ -438,6 +460,6 @@ object QueryParameterSpec extends ZIOHttpSpec { assertTrue(response.status == Status.Ok) } }, - ) + ).provide(ErrorResponseConfig.debugLayer) } diff --git a/zio-http/jvm/src/test/scala/zio/http/endpoint/RequestSpec.scala b/zio-http/jvm/src/test/scala/zio/http/endpoint/RequestSpec.scala index 7e6e1fb6c8..345c63dd29 100644 --- a/zio-http/jvm/src/test/scala/zio/http/endpoint/RequestSpec.scala +++ b/zio-http/jvm/src/test/scala/zio/http/endpoint/RequestSpec.scala @@ -41,7 +41,7 @@ object RequestSpec extends ZIOHttpSpec { val testRoutes = testEndpointWithHeaders( Routes( Endpoint(GET / "users" / int("userId")) - .header(HeaderCodec.name[java.util.UUID]("X-Correlation-ID")) + .header(HeaderCodec.headerAs[java.util.UUID]("X-Correlation-ID")) .out[String] .implementHandler { Handler.fromFunction { case (userId, correlationId) => @@ -49,7 +49,7 @@ object RequestSpec extends ZIOHttpSpec { } }, Endpoint(GET / "users" / int("userId") / "posts" / int("postId")) - .header(HeaderCodec.name[java.util.UUID]("X-Correlation-ID")) + .header(HeaderCodec.headerAs[java.util.UUID]("X-Correlation-ID")) .out[String] .implementHandler { Handler.fromFunction { case (userId, postId, correlationId) => @@ -70,6 +70,48 @@ object RequestSpec extends ZIOHttpSpec { ) } }, + test("simple request with header with multiple values") { + check(Gen.int, Gen.listOfN(3)(Gen.uuid)) { (userId, correlationId) => + val testRoutes = testEndpointWithHeaders( + Routes( + Endpoint(GET / "users" / int("userId")) + .header(HeaderCodec.headerAs[Chunk[java.util.UUID]]("X-Correlation-ID")) + .out[String] + .implementHandler { + Handler.fromFunction { case (userId, correlationId) => + s"path(users, $userId) header(correlationId=${correlationId.mkString(",")})" + } + }, + ), + ) _ + testRoutes( + s"/users/$userId", + correlationId.map(uuid => "X-Correlation-ID" -> uuid.toString), + s"path(users, $userId) header(correlationId=${correlationId.mkString(",")})", + ) + } + }, + test("simple request with header with multiple values non empty") { + check(Gen.int, Gen.listOfN(3)(Gen.uuid)) { (userId, correlationId) => + val testRoutes = testEndpointWithHeaders( + Routes( + Endpoint(GET / "users" / int("userId")) + .header(HeaderCodec.headerAs[NonEmptyChunk[java.util.UUID]]("X-Correlation-ID")) + .out[String] + .implementHandler { + Handler.fromFunction { case (userId, correlationId) => + s"path(users, $userId) header(correlationId=${correlationId.mkString(",")})" + } + }, + ), + ) _ + testRoutes( + s"/users/$userId", + correlationId.map(uuid => "X-Correlation-ID" -> uuid.toString), + s"path(users, $userId) header(correlationId=${correlationId.mkString(",")})", + ) + } + }, test("custom content type") { check(Gen.int) { id => val endpoint = @@ -200,7 +242,7 @@ object RequestSpec extends ZIOHttpSpec { check(Gen.int, Gen.alphaNumericString) { (id, notACorrelationId) => val endpoint = Endpoint(GET / "posts") - .header(HeaderCodec.name[java.util.UUID]("X-Correlation-ID")) + .header(HeaderCodec.headerAs[java.util.UUID]("X-Correlation-ID")) .out[Int] val routes = endpoint.implementHandler { @@ -219,7 +261,7 @@ object RequestSpec extends ZIOHttpSpec { check(Gen.int) { id => val endpoint = Endpoint(GET / "posts") - .header(HeaderCodec.name[java.util.UUID]("X-Correlation-ID")) + .header(HeaderCodec.headerAs[java.util.UUID]("X-Correlation-ID")) .out[Int] val routes = endpoint.implementHandler { @@ -453,7 +495,7 @@ object RequestSpec extends ZIOHttpSpec { }, test("composite in codecs") { check(Gen.alphaNumericString, Gen.alphaNumericString) { (queryValue, headerValue) => - val headerOrQuery = HeaderCodec.name[String]("X-Header") | HttpCodec.query[String]("header") + val headerOrQuery = HeaderCodec.headerAs[String]("X-Header") | HttpCodec.query[String]("header") val endpoint = Endpoint(GET / "test").out[String].inCodec(headerOrQuery) val routes = endpoint.implementHandler(Handler.identity).toRoutes val request = Request.get( @@ -487,7 +529,7 @@ object RequestSpec extends ZIOHttpSpec { } }, test("composite out codecs") { - val headerOrQuery = HeaderCodec.name[String]("X-Header") | StatusCodec.status(Status.Created) + val headerOrQuery = HeaderCodec.headerAs[String]("X-Header") | StatusCodec.status(Status.Created) val endpoint = Endpoint(GET / "test").query(HttpCodec.query[Boolean]("Created")).outCodec(headerOrQuery) val routes = endpoint.implementHandler { diff --git a/zio-http/jvm/src/test/scala/zio/http/endpoint/openapi/OpenAPIGenSpec.scala b/zio-http/jvm/src/test/scala/zio/http/endpoint/openapi/OpenAPIGenSpec.scala index 6ff5b44fad..49023d128e 100644 --- a/zio-http/jvm/src/test/scala/zio/http/endpoint/openapi/OpenAPIGenSpec.scala +++ b/zio-http/jvm/src/test/scala/zio/http/endpoint/openapi/OpenAPIGenSpec.scala @@ -2,7 +2,7 @@ package zio.http.endpoint.openapi import zio.json.ast.Json import zio.test._ -import zio.{Chunk, Scope, ZIO} +import zio.{Chunk, NonEmptyChunk, Scope, ZIO} import zio.schema.annotation._ import zio.schema.validation.Validation @@ -194,6 +194,13 @@ object OpenAPIGenSpec extends ZIOSpecDefault { .out[SimpleOutputBody] .outError[NotFoundError](Status.NotFound) + private val queryParamNonEmptyCollectionEndpoint = + Endpoint(GET / "withQuery") + .in[SimpleInputBody] + .query(HttpCodec.query[NonEmptyChunk[String]]("query")) + .out[SimpleOutputBody] + .outError[NotFoundError](Status.NotFound) + private val queryParamValidationEndpoint = Endpoint(GET / "withQuery") .in[SimpleInputBody] @@ -520,6 +527,112 @@ object OpenAPIGenSpec extends ZIOSpecDefault { test("with optional query parameter") { val generated = OpenAPIGen.fromEndpoints("Simple Endpoint", "1.0", queryParamOptEndpoint) val json = toJsonAst(generated) + val expectedJson = """{ + | "openapi" : "3.1.0", + | "info" : { + | "title" : "Simple Endpoint", + | "version" : "1.0" + | }, + | "paths" : { + | "/withQuery" : { + | "get" : { + | "parameters" : [ + | { + | "name" : "query", + | "in" : "query", + | "schema" : { + | "type" : "string" + | }, + | "allowReserved" : false, + | "style" : "form" + | } + | ], + | "requestBody" : { + | "content" : { + | "application/json" : { + | "schema" : { + | "$ref" : "#/components/schemas/SimpleInputBody" + | } + | } + | }, + | "required" : true + | }, + | "responses" : { + | "200" : { + | "content" : { + | "application/json" : { + | "schema" : { + | "$ref" : "#/components/schemas/SimpleOutputBody" + | } + | } + | } + | }, + | "404" : { + | "content" : { + | "application/json" : { + | "schema" : { + | "$ref" : "#/components/schemas/NotFoundError" + | } + | } + | } + | } + | } + | } + | } + | }, + | "components" : { + | "schemas" : { + | "NotFoundError" : { + | "type" : "object", + | "properties" : { + | "message" : { + | "type" : "string" + | } + | }, + | "required" : [ + | "message" + | ] + | }, + | "SimpleInputBody" : { + | "type" : "object", + | "properties" : { + | "name" : { + | "type" : "string" + | }, + | "age" : { + | "type" : "integer", + | "format" : "int32" + | } + | }, + | "required" : [ + | "name", + | "age" + | ] + | }, + | "SimpleOutputBody" : { + | "type" : "object", + | "properties" : { + | "userName" : { + | "type" : "string" + | }, + | "score" : { + | "type" : "integer", + | "format" : "int32" + | } + | }, + | "required" : [ + | "userName", + | "score" + | ] + | } + | } + | } + |}""".stripMargin + assertTrue(json == toJsonAst(expectedJson)) + }, + test("with query parameter with multiple values") { + val generated = OpenAPIGen.fromEndpoints("Simple Endpoint", "1.0", queryParamCollectionEndpoint) + val json = toJsonAst(generated) val expectedJson = """{ | "openapi" : "3.1.0", | "info" : { @@ -536,10 +649,8 @@ object OpenAPIGenSpec extends ZIOSpecDefault { | "in" : "query", | "schema" : | { - | "type" :[ - | "string", - | "null" - | ] + | "type" : + | "string" | }, | "allowReserved" : false, | "style" : "form" @@ -645,8 +756,8 @@ object OpenAPIGenSpec extends ZIOSpecDefault { |}""".stripMargin assertTrue(json == toJsonAst(expectedJson)) }, - test("with query parameter with multiple values") { - val generated = OpenAPIGen.fromEndpoints("Simple Endpoint", "1.0", queryParamCollectionEndpoint) + test("with query parameter with multiple values - non empty") { + val generated = OpenAPIGen.fromEndpoints("Simple Endpoint", "1.0", queryParamNonEmptyCollectionEndpoint) val json = toJsonAst(generated) val expectedJson = """{ | "openapi" : "3.1.0", diff --git a/zio-http/jvm/src/test/scala/zio/http/security/TimingAttacksSpec.scala b/zio-http/jvm/src/test/scala/zio/http/security/TimingAttacksSpec.scala index f4466f69c6..9afc70d1b8 100644 --- a/zio-http/jvm/src/test/scala/zio/http/security/TimingAttacksSpec.scala +++ b/zio-http/jvm/src/test/scala/zio/http/security/TimingAttacksSpec.scala @@ -49,10 +49,10 @@ object TimingAttacksSpec extends ZIOSpecDefault { val a = java.lang.System.nanoTime sampleUnsorted = (a - b) :: sampleUnsorted } - val sample = sampleUnsorted.sorted - val tail = sample.drop((nOfTries * i).round.toInt) - val low = tail.head - val high = tail.drop((nOfTries * (j - i)).round.toInt).head + val sample = sampleUnsorted.sorted + val tail = sample.drop((nOfTries * i).round.toInt) + val low = tail.head + val high = tail.drop((nOfTries * (j - i)).round.toInt).head (low, high) } diff --git a/zio-http/shared/src/main/scala-2/zio/http/UrlInterpolator.scala b/zio-http/shared/src/main/scala-2/zio/http/UrlInterpolator.scala index c58fa33f4c..415042e94d 100644 --- a/zio-http/shared/src/main/scala-2/zio/http/UrlInterpolator.scala +++ b/zio-http/shared/src/main/scala-2/zio/http/UrlInterpolator.scala @@ -66,7 +66,7 @@ private[http] object UrlInterpolatorMacro { } } val exampleParts = staticParts.zipAll(injectedPartExamples, "", "").flatMap { case (a, b) => List(a, b) } - val example = exampleParts.mkString + val example = exampleParts.mkString URL.decode(example) match { case Left(error) => c.abort(c.enclosingPosition, error.getMessage) case Right(_) => diff --git a/zio-http/shared/src/main/scala/zio/http/ClientSSLConfig.scala b/zio-http/shared/src/main/scala/zio/http/ClientSSLConfig.scala index a7a4ab8f7f..bbc1aa97a7 100644 --- a/zio-http/shared/src/main/scala/zio/http/ClientSSLConfig.scala +++ b/zio-http/shared/src/main/scala/zio/http/ClientSSLConfig.scala @@ -119,8 +119,8 @@ object ClientSSLConfig { def isInvalidBuild: Boolean = !isValidBuild def build(): FromJavaxNetSsl = this - def keyManagerKeyStoreType(tpe: String): FromJavaxNetSsl = self.copy(keyManagerKeyStoreType = tpe) - def keyManagerFile(file: String): FromJavaxNetSsl = + def keyManagerKeyStoreType(tpe: String): FromJavaxNetSsl = self.copy(keyManagerKeyStoreType = tpe) + def keyManagerFile(file: String): FromJavaxNetSsl = keyManagerSource match { case FromJavaxNetSsl.Resource(_) => this case _ => self.copy(keyManagerSource = FromJavaxNetSsl.File(file)) diff --git a/zio-http/shared/src/main/scala/zio/http/Header.scala b/zio-http/shared/src/main/scala/zio/http/Header.scala index bccf90fde9..cc7e24834e 100644 --- a/zio-http/shared/src/main/scala/zio/http/Header.scala +++ b/zio-http/shared/src/main/scala/zio/http/Header.scala @@ -31,8 +31,14 @@ import scala.util.{Either, Failure, Success, Try} import zio.Config.Secret import zio._ -import zio.http.codec.RichTextCodec -import zio.http.internal.DateEncoding +import zio.schema.Schema +import zio.schema.codec.DecodeError +import zio.schema.codec.DecodeError.ReadError +import zio.schema.validation.ValidationError + +import zio.http.Header.HeaderTypeBase.Typed +import zio.http.codec.{HttpCodecError, RichTextCodec} +import zio.http.internal.{DateEncoding, ErrorConstructor, StringSchemaCodec} sealed trait Header { type Self <: Header @@ -50,14 +56,135 @@ sealed trait Header { object Header { - sealed trait HeaderType { + sealed trait HeaderTypeBase { + type HeaderValue + + def names: Chunk[String] + + def fromHeaders(headers: Headers): Either[String, HeaderValue] + + private[http] def fromHeadersUnsafe(headers: Headers): HeaderValue + + def toHeaders(value: HeaderValue): Headers = + value match { + case h: Header => Headers.fromIterable(h :: Nil) + case _ => Headers.empty + } + } + + object HeaderTypeBase { + type Typed[HV] = HeaderTypeBase { type HeaderValue = HV } + } + + sealed trait SchemaHeaderType extends HeaderTypeBase { + def schema: Schema[HeaderValue] + + def optional: HeaderTypeBase.Typed[Option[HeaderValue]] + } + + object SchemaHeaderType { + type Typed[H] = SchemaHeaderType { type HeaderValue = H } + + private val errorConstructor = + new ErrorConstructor { + override def missing(fieldName: String): HttpCodecError = + HttpCodecError.MissingHeader(fieldName) + + override def missingAll(fieldNames: Chunk[String]): HttpCodecError = + HttpCodecError.MissingHeaders(fieldNames) + + override def invalid(errors: Chunk[ValidationError]): HttpCodecError = + HttpCodecError.InvalidEntity.wrap(errors) + + override def malformed(fieldName: String, error: DecodeError): HttpCodecError = + HttpCodecError.DecodingErrorHeader(fieldName, error) + + override def invalidCount(fieldName: String, expected: Int, actual: Int): HttpCodecError = + HttpCodecError.InvalidHeaderCount(fieldName, expected, actual) + } + + def apply[H](implicit schema0: Schema[H]): SchemaHeaderType.Typed[H] = { + new SchemaHeaderType { + type HeaderValue = H + val schema: Schema[H] = + schema0 + val codec: StringSchemaCodec[H, Headers] = + StringSchemaCodec.headerFromSchema(schema0, errorConstructor, null) + + override def names: Chunk[String] = + codec.recordFields.map(_._1.fieldName) + + override def optional: SchemaHeaderType.Typed[Option[H]] = + apply(schema.optional) + + override def fromHeaders(headers: Headers): Either[String, H] = + try Right(codec.decode(headers)) + catch { + case NonFatal(e) => Left(e.getMessage) + } + + private[http] override def fromHeadersUnsafe(headers: Headers): H = + codec.decode(headers) + + override def toHeaders(value: H): Headers = + codec.encode(value, Headers.empty) + } + } + + def apply[H](name: String)(implicit schema0: Schema[H]): SchemaHeaderType.Typed[H] = { + new SchemaHeaderType { + type HeaderValue = H + val schema: Schema[H] = schema0 + val codec: StringSchemaCodec[H, Headers] = + StringSchemaCodec.headerFromSchema(schema, errorConstructor, name) + + override def names: Chunk[String] = + codec.recordFields.map(_._1.fieldName) + + override def optional: SchemaHeaderType.Typed[Option[H]] = + apply(name)(schema.optional) + + override def fromHeaders(headers: Headers): Either[String, H] = + try Right(codec.decode(headers)) + catch { + case NonFatal(e) => Left(e.getMessage) + } + + private[http] override def fromHeadersUnsafe(headers: Headers): H = + codec.decode(headers) + + override def toHeaders(value: H): Headers = + codec.encode(value, Headers.empty) + } + } + } + + sealed trait HeaderType extends HeaderTypeBase { type HeaderValue <: Header + def names: Chunk[String] = Chunk.single(name) + def name: String def parse(value: String): Either[String, HeaderValue] def render(value: HeaderValue): String + + def fromHeaders(headers: Headers): Either[String, HeaderValue] = + headers.getUnsafe(name) match { + case null => Left(s"Header $name not found") + case value => parse(value) + } + + def fromHeadersUnsafe(headers: Headers): HeaderValue = + fromHeaders(headers).fold( + e => throw HttpCodecError.DecodingErrorHeader(name, ReadError(Cause.empty, e)), + identity, + ) + + override def toHeaders(value: HeaderValue): Headers = + Headers.FromIterable(Iterable(value)) + } object HeaderType { diff --git a/zio-http/shared/src/main/scala/zio/http/Headers.scala b/zio-http/shared/src/main/scala/zio/http/Headers.scala index 3df48e5739..08c24ea016 100644 --- a/zio-http/shared/src/main/scala/zio/http/Headers.scala +++ b/zio-http/shared/src/main/scala/zio/http/Headers.scala @@ -99,6 +99,7 @@ object Headers { value: T, iterate: T => Iterator[Header], unsafeGet: (T, CharSequence) => String, + getAll: (T, CharSequence) => Chunk[String], contains: (T, CharSequence) => Boolean, ) extends Headers { override def contains(key: CharSequence): Boolean = contains(value, key) @@ -106,6 +107,9 @@ object Headers { override def iterator: Iterator[Header] = iterate(value) override private[http] def getUnsafe(key: CharSequence): String = unsafeGet(value, key) + + override def rawHeaders(name: CharSequence): Chunk[String] = getAll(value, name) + } private[zio] final case class Concat(first: Headers, second: Headers) extends Headers { diff --git a/zio-http/shared/src/main/scala/zio/http/MediaTypes.scala b/zio-http/shared/src/main/scala/zio/http/MediaTypes.scala index f9fa13ef35..8660f275b7 100644 --- a/zio-http/shared/src/main/scala/zio/http/MediaTypes.scala +++ b/zio-http/shared/src/main/scala/zio/http/MediaTypes.scala @@ -9688,7 +9688,7 @@ private[zio] trait MediaTypes { lazy val `vnd.dolby.heaac.2`: MediaType = new MediaType("audio", "vnd.dolby.heaac.2", compressible = false, binary = true) - lazy val `amr-wb+` : MediaType = + lazy val `amr-wb+`: MediaType = new MediaType("audio", "amr-wb+", compressible = false, binary = true) lazy val `dsr-es202211`: MediaType = diff --git a/zio-http/shared/src/main/scala/zio/http/ZClient.scala b/zio-http/shared/src/main/scala/zio/http/ZClient.scala index ce6d47b332..e9efe54629 100644 --- a/zio-http/shared/src/main/scala/zio/http/ZClient.scala +++ b/zio-http/shared/src/main/scala/zio/http/ZClient.scala @@ -759,7 +759,7 @@ object ZClient extends ZClientPlatformSpecific { } } yield () }.forkDaemon // Needs to live as long as the channel is alive, as the response body may be streaming - _ <- ZIO.addFinalizer(onComplete.interrupt) + _ <- ZIO.addFinalizer(onComplete.interrupt) response <- restore(onResponse.await.onInterrupt { ZIO.unlessZIO(connectionAcquired.get)(channelFiber.interrupt) *> onComplete.interrupt *> diff --git a/zio-http/shared/src/main/scala/zio/http/codec/HeaderCodecs.scala b/zio-http/shared/src/main/scala/zio/http/codec/HeaderCodecs.scala index a6ced4dec4..2bafeeaa48 100644 --- a/zio-http/shared/src/main/scala/zio/http/codec/HeaderCodecs.scala +++ b/zio-http/shared/src/main/scala/zio/http/codec/HeaderCodecs.scala @@ -16,24 +16,48 @@ package zio.http.codec +import java.util.UUID + import scala.util.Try import zio.stacktracer.TracingImplicits.disableAutoTrace -import zio.http.Header.HeaderType +import zio.schema._ + +import zio.http.Header.{HeaderType, SchemaHeaderType} import zio.http._ private[codec] trait HeaderCodecs { - private[http] def headerCodec[A](name: String, value: TextCodec[A]): HeaderCodec[A] = - HttpCodec.Header(name, value) + private[http] def headerCodec[A](name: String, value: TextCodec[A]): HeaderCodec[A] = { + val schema = value match { + case TextCodec.Constant(string) => + Schema[String].transformOrFail[Unit]( + s => if (s == string) Right(()) else Left(s"Header $name was not $string"), + (_: Unit) => Right(string), + ) + case TextCodec.StringCodec => Schema[String] + case TextCodec.IntCodec => Schema[Int] + case TextCodec.LongCodec => Schema[Long] + case TextCodec.BooleanCodec => Schema[Boolean] + case TextCodec.UUIDCodec => Schema[UUID] + } + HttpCodec.Header(SchemaHeaderType(name)(schema.asInstanceOf[Schema[A]])) + } def header(headerType: HeaderType): HeaderCodec[headerType.HeaderValue] = - headerCodec(headerType.name, TextCodec.string) - .transformOrFailLeft(headerType.parse(_))(headerType.render(_)) + HttpCodec.Header(headerType) + + def headerAs[A](name: String)(implicit schema: Schema[A]): HeaderCodec[A] = + HttpCodec.Header(SchemaHeaderType(name)) + + def headers[A](implicit schema: Schema[A]): HeaderCodec[A] = + HttpCodec.Header(SchemaHeaderType("headers")) + @deprecated("Use Schema based headerAs instead", "3.1.0") def name[A](name: String)(implicit codec: TextCodec[A]): HeaderCodec[A] = headerCodec(name, codec) + @deprecated("Use Schema based API instead", "3.1.0") def nameTransform[A, B]( name: String, parse: B => A, @@ -43,11 +67,13 @@ private[codec] trait HeaderCodecs { Try(parse(s)).toEither.left.map(e => s"Failed to parse header $name: ${e.getMessage}"), )(render) + @deprecated("Use Schema based API instead", "3.1.0") def nameTransformOption[A, B](name: String, parse: B => Option[A], render: A => B)(implicit codec: TextCodec[B], ): HeaderCodec[A] = headerCodec(name, codec).transformOrFailLeft(parse(_).toRight(s"Failed to parse header $name"))(render) + @deprecated("Use Schema based API instead", "3.1.0") def nameTransformOrFail[A, B](name: String, parse: B => Either[String, A], render: A => B)(implicit codec: TextCodec[B], ): HeaderCodec[A] = diff --git a/zio-http/shared/src/main/scala/zio/http/codec/HttpCodec.scala b/zio-http/shared/src/main/scala/zio/http/codec/HttpCodec.scala index 278f185366..3ccabbebf8 100644 --- a/zio-http/shared/src/main/scala/zio/http/codec/HttpCodec.scala +++ b/zio-http/shared/src/main/scala/zio/http/codec/HttpCodec.scala @@ -18,19 +18,20 @@ package zio.http.codec import scala.annotation.tailrec import scala.reflect.ClassTag +import scala.util.Try -import zio._ +import zio.{http, _} -import zio.stream.ZStream +import zio.stream.{ZPipeline, ZStream} import zio.schema.Schema -import zio.schema.annotation._ import zio.http.Header.Accept.MediaTypeWithQFactor +import zio.http.Header.{HeaderTypeBase, SchemaHeaderType} import zio.http._ -import zio.http.codec.HttpCodec.Query.QueryType import zio.http.codec.HttpCodec.{Annotated, Metadata} import zio.http.codec.internal._ +import zio.http.internal.StringSchemaCodec /** * A [[zio.http.codec.HttpCodec]] represents a codec for a part of an HTTP @@ -2263,141 +2264,23 @@ object HttpCodec extends ContentCodecs with HeaderCodecs with MethodCodecs with def index(index: Int): ContentStream[A] = copy(index = index) } - private[http] final case class Query[A, Out]( - queryType: Query.QueryType[A], + private[http] final case class Query[A]( + codec: StringSchemaCodec[A, QueryParams], index: Int = 0, - ) extends Atom[HttpCodecType.Query, Out] { + ) extends Atom[HttpCodecType.Query, A] { self => - def erase: Query[Any, Any] = self.asInstanceOf[Query[Any, Any]] + def erase: Query[Any] = self.asInstanceOf[Query[Any]] - def tag: AtomTag = AtomTag.Query - - def index(index: Int): Query[A, Out] = copy(index = index) - - def isOptional: Boolean = - queryType match { - case QueryType.Primitive(_, BinaryCodecWithSchema(_, schema)) if schema.isInstanceOf[Schema.Optional[_]] => - true - case QueryType.Record(recordSchema) => - recordSchema match { - case s if s.isInstanceOf[Schema.Optional[_]] => true - case record: Schema.Record[_] if record.fields.forall(_.optional) => true - case _ => false - } - case _ => false - } + def index(index: Int): Query[A] = copy(index = index) /** * Returns a new codec, where the value produced by this one is optional. */ - override def optional: HttpCodec[HttpCodecType.Query, Option[Out]] = - queryType match { - case QueryType.Primitive(name, codec) if codec.schema.isInstanceOf[Schema.Optional[_]] => - throw new IllegalArgumentException( - s"Cannot make an optional query parameter optional. Name: $name schema: ${codec.schema}", - ) - case QueryType.Primitive(name, codec) => - val optionalSchema = codec.schema.optional - copy(queryType = - QueryType.Primitive(name, BinaryCodecWithSchema(TextBinaryCodec.fromSchema(optionalSchema), optionalSchema)), - ) - case QueryType.Record(recordSchema) if recordSchema.isInstanceOf[Schema.Optional[_]] => - throw new IllegalArgumentException(s"Cannot make an optional query parameter optional") - case QueryType.Record(recordSchema) => - val optionalSchema = recordSchema.optional - copy(queryType = QueryType.Record(optionalSchema)) - case queryType @ QueryType.Collection(_, _, false) => - copy(queryType = QueryType.Collection(queryType.colSchema, queryType.elements, optional = true)) - case queryType @ QueryType.Collection(_, _, true) => - throw new IllegalArgumentException(s"Cannot make an optional query parameter optional: $queryType") - - } - - } + override def optional: HttpCodec[HttpCodecType.Query, Option[A]] = + Annotated(Query(codec.optional, index), Metadata.Optional()) - private[http] object Query { - sealed trait QueryType[A] - object QueryType { - case class Primitive[A](name: String, codec: BinaryCodecWithSchema[A]) extends QueryType[A] - case class Collection[A](colSchema: Schema.Collection[_, _], elements: QueryType.Primitive[A], optional: Boolean) - extends QueryType[A] { - def toCollection(values: Chunk[Any]): A = - colSchema match { - case Schema.Sequence(_, fromChunk, _, _, _) => - fromChunk.asInstanceOf[Chunk[Any] => Any](values).asInstanceOf[A] - case Schema.Set(_, _) => - values.toSet.asInstanceOf[A] - case _ => - throw new IllegalArgumentException( - s"Unsupported collection schema for query object field of type: $colSchema", - ) - } - } - case class Record[A](recordSchema: Schema[A]) extends QueryType[A] { - private var namesAndCodecs: Chunk[(Schema.Field[_, _], BinaryCodecWithSchema[Any])] = _ - private[http] def fieldAndCodecs: Chunk[(Schema.Field[_, _], BinaryCodecWithSchema[Any])] = - if (namesAndCodecs == null) { - namesAndCodecs = recordSchema match { - case record: Schema.Record[A] => - record.fields.map { field => - validateSchema(field.name, field.schema) - val codec = binaryCodecForField(field.annotations.foldLeft(field.schema)(_ annotate _)) - (unlazy(field.asInstanceOf[Schema.Field[Any, Any]]), codec) - } - case s if s.isInstanceOf[Schema.Optional[_]] => - val record = s.asInstanceOf[Schema.Optional[A]].schema.asInstanceOf[Schema.Record[A]] - record.fields.map { field => - validateSchema(field.name, field.annotations.foldLeft(field.schema)(_ annotate _)) - val codec = binaryCodecForField(field.schema) - (field, codec) - } - case s => throw new IllegalArgumentException(s"Unsupported schema for query object field of type: $s") - } - namesAndCodecs - } else { - namesAndCodecs - } - } - - private def unlazy(field: Schema.Field[Any, Any]): Schema.Field[Any, Any] = field.schema match { - case Schema.Lazy(schema) => - Schema.Field( - field.name, - schema(), - field.annotations, - field.validation, - field.get, - field.set, - ) - case _ => field - } - - private def binaryCodecForField[A](schema: Schema[A]): BinaryCodecWithSchema[Any] = (schema match { - case schema @ Schema.Primitive(_, _) => BinaryCodecWithSchema(TextBinaryCodec.fromSchema(schema), schema) - case Schema.Transform(_, _, _, _, _) => BinaryCodecWithSchema(TextBinaryCodec.fromSchema(schema), schema) - case Schema.Optional(_, _) => BinaryCodecWithSchema(TextBinaryCodec.fromSchema(schema), schema) - case e: Schema.Enum[_] if isSimple(e) => BinaryCodecWithSchema(TextBinaryCodec.fromSchema(schema), schema) - case l @ Schema.Lazy(_) => binaryCodecForField(l.schema) - case Schema.Set(schema, _) => binaryCodecForField(schema) - case Schema.Sequence(schema, _, _, _, _) => binaryCodecForField(schema) - case schema => throw new IllegalArgumentException(s"Unsupported schema for query object field of type: $schema") - }).asInstanceOf[BinaryCodecWithSchema[Any]] - - def isSimple(schema: Schema.Enum[_]): Boolean = - schema.annotations.exists(_.isInstanceOf[simpleEnum]) - - @tailrec - private def validateSchema[A](name: String, schema: Schema[A]): Unit = schema match { - case _: Schema.Primitive[A] => () - case Schema.Transform(schema, _, _, _, _) => validateSchema(name, schema) - case Schema.Optional(schema, _) => validateSchema(name, schema) - case Schema.Lazy(schema) => validateSchema(name, schema()) - case Schema.Set(schema, _) => validateSchema(name, schema) - case Schema.Sequence(schema, _, _, _, _) => validateSchema(name, schema) - case s => throw new IllegalArgumentException(s"Unsupported schema for query object field of type: $s") - } + def tag: AtomTag = AtomTag.Query - } } private[http] final case class Method[A](codec: SimpleCodec[zio.http.Method, A], index: Int = 0) @@ -2409,7 +2292,7 @@ object HttpCodec extends ContentCodecs with HeaderCodecs with MethodCodecs with def index(index: Int): Method[A] = copy(index = index) } - private[http] final case class Header[A](name: String, textCodec: TextCodec[A], index: Int = 0) + private[http] final case class Header[A](headerType: HeaderTypeBase.Typed[A], index: Int = 0) extends Atom[HttpCodecType.Header, A] { self => def erase: Header[Any] = self.asInstanceOf[Header[Any]] @@ -2417,6 +2300,18 @@ object HttpCodec extends ContentCodecs with HeaderCodecs with MethodCodecs with def tag: AtomTag = AtomTag.Header def index(index: Int): Header[A] = copy(index = index) + + override def optional: HttpCodec[HttpCodecType.Header, Option[A]] = { + headerType match { + case headerType if headerType.isInstanceOf[SchemaHeaderType] => + Annotated( + Header(headerType.asInstanceOf[SchemaHeaderType.Typed[A]].optional, index), + Metadata.Optional(), + ) + case _ => + super.optional + } + } } private[http] final case class Annotated[AtomTypes, Value]( diff --git a/zio-http/shared/src/main/scala/zio/http/codec/HttpCodecError.scala b/zio-http/shared/src/main/scala/zio/http/codec/HttpCodecError.scala index bcd97223d8..bfbacea8d9 100644 --- a/zio-http/shared/src/main/scala/zio/http/codec/HttpCodecError.scala +++ b/zio-http/shared/src/main/scala/zio/http/codec/HttpCodecError.scala @@ -23,6 +23,7 @@ import zio.{Cause, Chunk} import zio.schema.codec.DecodeError import zio.schema.validation.ValidationError +import zio.http.Header.HeaderType import zio.http.{Path, Status} sealed trait HttpCodecError extends Exception with NoStackTrace with Product with Serializable { @@ -33,6 +34,9 @@ object HttpCodecError { final case class MissingHeader(headerName: String) extends HttpCodecError { def message = s"Missing header $headerName" } + final case class MissingHeaders(headerNames: Chunk[String]) extends HttpCodecError { + def message = s"Missing headers ${headerNames.mkString(", ")}" + } final case class MalformedMethod(expected: zio.http.Method, actual: zio.http.Method) extends HttpCodecError { def message = s"Expected $expected but found $actual" } @@ -48,6 +52,12 @@ object HttpCodecError { final case class MalformedHeader(headerName: String, textCodec: TextCodec[_]) extends HttpCodecError { def message = s"Malformed header $headerName failed to decode using $textCodec" } + final case class DecodingErrorHeader(headerName: String, cause: DecodeError) extends HttpCodecError { + def message = s"Malformed header $headerName could not be decoded: $cause" + } + final case class MalformedTypedHeader(headerName: String) extends HttpCodecError { + def message = s"Malformed header $headerName" + } final case class MissingQueryParam(queryParamName: String) extends HttpCodecError { def message = s"Missing query parameter $queryParamName" } @@ -73,6 +83,9 @@ object HttpCodecError { final case class InvalidQueryParamCount(name: String, expected: Int, actual: Int) extends HttpCodecError { def message = s"Invalid query parameter count for $name: expected $expected but found $actual." } + final case class InvalidHeaderCount(name: String, expected: Int, actual: Int) extends HttpCodecError { + def message = s"Invalid query parameter count for $name: expected $expected but found $actual." + } final case class CustomError(name: String, message: String) extends HttpCodecError final case class UnsupportedContentType(contentType: String) extends HttpCodecError { @@ -92,6 +105,9 @@ object HttpCodecError { def isMissingDataOnly(cause: Cause[Any]): Boolean = !cause.isFailure && cause.defects.forall(e => - e.isInstanceOf[HttpCodecError.MissingHeader] || e.isInstanceOf[HttpCodecError.MissingQueryParam], + e.isInstanceOf[HttpCodecError.MissingHeader] + || e.isInstanceOf[HttpCodecError.MissingQueryParam] + || e.isInstanceOf[HttpCodecError.MissingQueryParams] + || e.isInstanceOf[HttpCodecError.MissingHeaders], ) } diff --git a/zio-http/shared/src/main/scala/zio/http/codec/QueryCodecs.scala b/zio-http/shared/src/main/scala/zio/http/codec/QueryCodecs.scala index 4bc203f5e1..a941d4df85 100644 --- a/zio-http/shared/src/main/scala/zio/http/codec/QueryCodecs.scala +++ b/zio-http/shared/src/main/scala/zio/http/codec/QueryCodecs.scala @@ -15,127 +15,38 @@ */ package zio.http.codec -import scala.annotation.tailrec - +import zio.Chunk import zio.stacktracer.TracingImplicits.disableAutoTrace import zio.schema.Schema -import zio.schema.annotation.simpleEnum +import zio.schema.codec.DecodeError +import zio.schema.validation.ValidationError + +import zio.http.internal.{ErrorConstructor, StringSchemaCodec} private[codec] trait QueryCodecs { - def query[A](name: String)(implicit schema: Schema[A]): QueryCodec[A] = - schema match { - case s @ Schema.Primitive(_, _) => - HttpCodec.Query( - HttpCodec.Query.QueryType - .Primitive(name, BinaryCodecWithSchema.fromBinaryCodec(TextBinaryCodec.fromSchema(s))(s)), - ) - case c @ Schema.Sequence(elementSchema, _, _, _, _) => - if (supportedElementSchema(elementSchema.asInstanceOf[Schema[Any]])) { - HttpCodec.Query( - HttpCodec.Query.QueryType.Collection( - c, - HttpCodec.Query.QueryType.Primitive( - name, - BinaryCodecWithSchema(TextBinaryCodec.fromSchema(elementSchema), elementSchema), - ), - optional = false, - ), - ) - } else { - throw new IllegalArgumentException("Only primitive types can be elements of sequences") - } - case c @ Schema.Set(elementSchema, _) => - if (supportedElementSchema(elementSchema.asInstanceOf[Schema[Any]])) { - HttpCodec.Query( - HttpCodec.Query.QueryType.Collection( - c, - HttpCodec.Query.QueryType.Primitive( - name, - BinaryCodecWithSchema(TextBinaryCodec.fromSchema(elementSchema), elementSchema), - ), - optional = false, - ), - ) - } else { - throw new IllegalArgumentException("Only primitive types can be elements of sets") - } - case Schema.Optional(Schema.Primitive(_, _), _) => - HttpCodec.Query( - HttpCodec.Query.QueryType - .Primitive(name, BinaryCodecWithSchema.fromBinaryCodec(TextBinaryCodec.fromSchema(schema))(schema)), - ) - case Schema.Optional(c @ Schema.Sequence(elementSchema, _, _, _, _), _) => - if (supportedElementSchema(elementSchema.asInstanceOf[Schema[Any]])) { - HttpCodec.Query( - HttpCodec.Query.QueryType.Collection( - c, - HttpCodec.Query.QueryType.Primitive( - name, - BinaryCodecWithSchema(TextBinaryCodec.fromSchema(elementSchema), elementSchema), - ), - optional = true, - ), - ) - } else { - throw new IllegalArgumentException("Only primitive types can be elements of sequences") - } - case Schema.Optional(inner, _) if inner.isInstanceOf[Schema.Set[_]] => - val elementSchema = inner.asInstanceOf[Schema.Set[Any]].elementSchema - if (supportedElementSchema(elementSchema)) { - HttpCodec.Query( - HttpCodec.Query.QueryType.Collection( - inner.asInstanceOf[Schema.Set[_]], - HttpCodec.Query.QueryType.Primitive( - name, - BinaryCodecWithSchema(TextBinaryCodec.fromSchema(inner), inner), - ), - optional = true, - ), - ) - } else { - throw new IllegalArgumentException("Only primitive types can be elements of sets") - } - case enum0: Schema.Enum[_] if enum0.annotations.exists(_.isInstanceOf[simpleEnum]) => - HttpCodec.Query( - HttpCodec.Query.QueryType - .Primitive(name, BinaryCodecWithSchema.fromBinaryCodec(TextBinaryCodec.fromSchema(schema))(schema)), - ) - case record: Schema.Record[A] if record.fields.size == 1 => - val field = record.fields.head - if (supportedElementSchema(field.schema.asInstanceOf[Schema[Any]])) { - HttpCodec.Query( - HttpCodec.Query.QueryType.Primitive( - name, - BinaryCodecWithSchema(TextBinaryCodec.fromSchema(record), record), - ), - ) - } else { - throw new IllegalArgumentException("Only primitive types can be elements of records") - } - case other => - throw new IllegalArgumentException( - s"Only primitive types, sequences, sets, optional, enums and records with a single field can be used to infer query codecs, but got $other", - ) - } + private val errorConstructor = new ErrorConstructor { + override def missing(fieldName: String): HttpCodecError = + HttpCodecError.MissingQueryParam(fieldName) + + override def missingAll(fieldNames: Chunk[String]): HttpCodecError = + HttpCodecError.MissingQueryParams(fieldNames) - @tailrec - private def supportedElementSchema(elementSchema: Schema[Any]): Boolean = elementSchema match { - case Schema.Lazy(schema0) => supportedElementSchema(schema0()) - case _ => - elementSchema.isInstanceOf[Schema.Primitive[_]] || - elementSchema.isInstanceOf[Schema.Enum[_]] && elementSchema.annotations.exists(_.isInstanceOf[simpleEnum]) || - elementSchema.isInstanceOf[Schema.Record[_]] && elementSchema.asInstanceOf[Schema.Record[_]].fields.size == 1 + override def invalid(errors: Chunk[ValidationError]): HttpCodecError = + HttpCodecError.InvalidEntity.wrap(errors) + + override def malformed(fieldName: String, error: DecodeError): HttpCodecError = + HttpCodecError.MalformedQueryParam(fieldName, error) + + override def invalidCount(fieldName: String, expected: Int, actual: Int): HttpCodecError = + HttpCodecError.InvalidQueryParamCount(fieldName, expected, actual) } + def query[A](name: String)(implicit schema: Schema[A]): QueryCodec[A] = + HttpCodec.Query(StringSchemaCodec.queryFromSchema[A](schema, errorConstructor, name)) + def queryAll[A](implicit schema: Schema[A]): QueryCodec[A] = - schema match { - case _: Schema.Primitive[A] => - throw new IllegalArgumentException("Use query[A](name: String) for primitive types") - case record: Schema.Record[A] => HttpCodec.Query(HttpCodec.Query.QueryType.Record(record)) - case Schema.Optional(_, _) => HttpCodec.Query(HttpCodec.Query.QueryType.Record(schema)) - case _ => throw new IllegalArgumentException("Only case classes can be used to infer query codecs") - } + HttpCodec.Query(StringSchemaCodec.queryFromSchema[A](schema, errorConstructor, null)) } diff --git a/zio-http/shared/src/main/scala/zio/http/codec/TextBinaryCodec.scala b/zio-http/shared/src/main/scala/zio/http/codec/TextBinaryCodec.scala index cd552ad814..9059501f96 100644 --- a/zio-http/shared/src/main/scala/zio/http/codec/TextBinaryCodec.scala +++ b/zio-http/shared/src/main/scala/zio/http/codec/TextBinaryCodec.scala @@ -120,10 +120,10 @@ object TextBinaryCodec { ) override def streamEncoder: ZPipeline[Any, Nothing, A, Byte] = ZPipeline.mapChunks(_.flatMap(encode)) - override def streamDecoder: ZPipeline[Any, DecodeError, Byte, A] = codec.streamDecoder.map { x => + override def streamDecoder: ZPipeline[Any, DecodeError, Byte, A] = codec.streamDecoder.mapZIO { x => f(x) match { - case Left(value) => throw DecodeError.ReadError(Cause.fail(new Exception("Error in decoding")), value) - case Right(a) => a + case Left(value) => ZIO.fail(DecodeError.ReadError(Cause.fail(new Exception("Error in decoding")), value)) + case Right(a) => ZIO.succeed(a) } } } @@ -356,7 +356,7 @@ object TextBinaryCodec { override def streamDecoder: ZPipeline[Any, DecodeError, Byte, A] = (ZPipeline[Byte] >>> ZPipeline.utf8Decode) - .map(s => decode(Chunk.fromArray(s.getBytes)).fold(throw _, identity)) + .mapZIO(s => ZIO.fromEither(decode(Chunk.fromArray(s.getBytes)))) .mapErrorCause(e => Cause.fail(DecodeError.ReadError(e, e.squash.getMessage))) } case Schema.Lazy(schema0) => fromSchema(schema0()) diff --git a/zio-http/shared/src/main/scala/zio/http/codec/internal/Atomized.scala b/zio-http/shared/src/main/scala/zio/http/codec/internal/Atomized.scala index 9a53a41add..d7a51ebbd1 100644 --- a/zio-http/shared/src/main/scala/zio/http/codec/internal/Atomized.scala +++ b/zio-http/shared/src/main/scala/zio/http/codec/internal/Atomized.scala @@ -50,5 +50,6 @@ private[http] final case class Atomized[A]( } } private[http] object Atomized { - def apply[A](defValue: => A): Atomized[A] = Atomized(defValue, defValue, defValue, defValue, defValue, defValue) + def apply[A](defValue: => A): Atomized[A] = + Atomized(defValue, defValue, defValue, defValue, defValue, defValue) } diff --git a/zio-http/shared/src/main/scala/zio/http/codec/internal/AtomizedCodecs.scala b/zio-http/shared/src/main/scala/zio/http/codec/internal/AtomizedCodecs.scala index 52cce2c84a..4ae3b295b7 100644 --- a/zio-http/shared/src/main/scala/zio/http/codec/internal/AtomizedCodecs.scala +++ b/zio-http/shared/src/main/scala/zio/http/codec/internal/AtomizedCodecs.scala @@ -25,7 +25,7 @@ import zio.http.codec._ private[http] final case class AtomizedCodecs( method: Chunk[SimpleCodec[zio.http.Method, _]], path: Chunk[PathCodec[_]], - query: Chunk[Query[_, _]], + query: Chunk[Query[_]], header: Chunk[Header[_]], content: Chunk[BodyCodec[_]], status: Chunk[SimpleCodec[zio.http.Status, _]], @@ -33,11 +33,11 @@ private[http] final case class AtomizedCodecs( def append(atom: Atom[_, _]): AtomizedCodecs = atom match { case path0: Path[_] => self.copy(path = path :+ path0.pathCodec) case method0: Method[_] => self.copy(method = method :+ method0.codec) - case query0: Query[_, _] => self.copy(query = query :+ query0) + case query0: Query[_] => self.copy(query = query :+ query0) case header0: Header[_] => self.copy(header = header :+ header0) + case status0: Status[_] => self.copy(status = status :+ status0.codec) case content0: Content[_] => self.copy(content = content :+ BodyCodec.Single(content0.codec, content0.name)) - case status0: Status[_] => self.copy(status = status :+ status0.codec) case stream0: ContentStream[_] => self.copy(content = content :+ BodyCodec.Multiple(stream0.codec, stream0.name)) } diff --git a/zio-http/shared/src/main/scala/zio/http/codec/internal/EncoderDecoder.scala b/zio-http/shared/src/main/scala/zio/http/codec/internal/EncoderDecoder.scala index 44d99b72dd..63c315179d 100644 --- a/zio-http/shared/src/main/scala/zio/http/codec/internal/EncoderDecoder.scala +++ b/zio-http/shared/src/main/scala/zio/http/codec/internal/EncoderDecoder.scala @@ -16,16 +16,13 @@ package zio.http.codec.internal +import scala.annotation.tailrec import scala.util.Try import zio._ -import zio.schema.codec.{BinaryCodec, DecodeError} -import zio.schema.{Schema, StandardType} - import zio.http.Header.Accept.MediaTypeWithQFactor import zio.http._ -import zio.http.codec.HttpCodec.Query.QueryType import zio.http.codec._ private[codec] trait EncoderDecoder[-AtomTypes, Value] { self => @@ -46,7 +43,7 @@ private[codec] object EncoderDecoder { val flattened = httpCodec.alternatives flattened.length match { - case 0 => Undefined() + case 0 => Undefined.asInstanceOf[EncoderDecoder[AtomTypes, Value]] case 1 => Single(flattened.head._1) case _ => Multiple(flattened) } @@ -109,7 +106,7 @@ private[codec] object EncoderDecoder { } } - private final case class Undefined[-AtomTypes, Value]() extends EncoderDecoder[AtomTypes, Value] { + private object Undefined extends EncoderDecoder[Any, Any] { val encodeWithErrorMessage = """ @@ -125,7 +122,7 @@ private[codec] object EncoderDecoder { override def encodeWith[Z]( config: CodecConfig, - value: Value, + value: Any, outputTypes: Chunk[MediaTypeWithQFactor], )(f: (zio.http.URL, Option[zio.http.Status], Option[zio.http.Method], zio.http.Headers, zio.http.Body) => Z): Z = { throw new IllegalStateException(encodeWithErrorMessage) @@ -138,7 +135,7 @@ private[codec] object EncoderDecoder { method: zio.http.Method, headers: zio.http.Headers, body: zio.http.Body, - )(implicit trace: zio.Trace): zio.Task[Value] = { + )(implicit trace: zio.Trace): zio.Task[Any] = { ZIO.fail(new IllegalStateException(decodeErrorMessage)) } } @@ -214,178 +211,19 @@ private[codec] object EncoderDecoder { ) private def decodeQuery(config: CodecConfig, queryParams: QueryParams, inputs: Array[Any]): Unit = - genericDecode[QueryParams, HttpCodec.Query[_, _]]( + genericDecode[QueryParams, HttpCodec.Query[_]]( queryParams, flattened.query, inputs, - (codec, queryParams) => { - val query = codec.erase - val isOptional = query.isOptional - query.queryType match { - case QueryType.Primitive(name, bc @ BinaryCodecWithSchema(_, schema)) => - val count = queryParams.valueCount(name) - val hasParam = queryParams.hasQueryParam(name) - if (!hasParam && isOptional) None - else if (!hasParam) throw HttpCodecError.MissingQueryParam(name) - else if (count != 1) throw HttpCodecError.InvalidQueryParamCount(name, 1, count) - else { - val decoded = bc - .codec(config) - .decode( - Chunk.fromArray(queryParams.unsafeQueryParam(name).getBytes(Charsets.Utf8)), - ) match { - case Left(error) => throw HttpCodecError.MalformedQueryParam(name, error) - case Right(value) => value - } - val validationErrors = schema.validate(decoded)(schema) - if (validationErrors.nonEmpty) throw HttpCodecError.InvalidEntity.wrap(validationErrors) - if (isOptional && decoded == None && emptyStringIsValue(schema.asInstanceOf[Schema.Optional[_]].schema)) - Some("") - else decoded - } - case c @ QueryType.Collection(_, QueryType.Primitive(name, bc), optional) => - if (!queryParams.hasQueryParam(name)) { - if (!optional) c.toCollection(Chunk.empty) - else None - } else { - val values = queryParams.queryParams(name) - val decoded = c.toCollection { - values.map { value => - bc.codec(config).decode(Chunk.fromArray(value.getBytes(Charsets.Utf8))) match { - case Left(error) => throw HttpCodecError.MalformedQueryParam(name, error) - case Right(value) => value - } - } - } - val erasedSchema = c.colSchema.asInstanceOf[Schema[Any]] - val validationErrors = erasedSchema.validate(decoded)(erasedSchema) - if (validationErrors.nonEmpty) throw HttpCodecError.InvalidEntity.wrap(validationErrors) - if (optional) Some(decoded) - else decoded - } - case query @ QueryType.Record(recordSchema) => - val hasAllParams = query.fieldAndCodecs.forall { case (field, _) => - queryParams.hasQueryParam(field.name) || field.optional || field.defaultValue.isDefined - } - if (!hasAllParams && recordSchema.isInstanceOf[Schema.Optional[_]]) None - else if (!hasAllParams && isOptional) { - recordSchema.defaultValue match { - case Left(err) => - throw new IllegalStateException(s"Cannot compute default value for $recordSchema. Error was: $err") - case Right(value) => value - } - } else if (!hasAllParams) throw HttpCodecError.MissingQueryParams { - query.fieldAndCodecs.collect { - case (field, _) - if !(queryParams.hasQueryParam(field.name) || field.optional || field.defaultValue.isDefined) => - field.name - } - } - else { - val decoded = query.fieldAndCodecs.map { - case (field, codec) if field.schema.isInstanceOf[Schema.Collection[_, _]] => - if (!queryParams.hasQueryParam(field.name) && field.defaultValue.nonEmpty) field.defaultValue.get - else { - val values = queryParams.queryParams(field.name) - val decoded = values.map { value => - codec.codec(config).decode(Chunk.fromArray(value.getBytes(Charsets.Utf8))) match { - case Left(error) => throw HttpCodecError.MalformedQueryParam(field.name, error) - case Right(value) => value - } - } - val decodedCollection = - field.schema match { - case s @ Schema.Sequence(_, fromChunk, _, _, _) => - val collection = fromChunk.asInstanceOf[Chunk[Any] => Any](decoded) - val erasedSchema = s.asInstanceOf[Schema[Any]] - val validationErrors = erasedSchema.validate(collection)(erasedSchema) - if (validationErrors.nonEmpty) throw HttpCodecError.InvalidEntity.wrap(validationErrors) - collection - case s @ Schema.Set(_, _) => - val collection = decoded.toSet[Any] - val erasedSchema = s.asInstanceOf[Schema.Set[Any]] - val validationErrors = erasedSchema.validate(collection)(erasedSchema) - if (validationErrors.nonEmpty) throw HttpCodecError.InvalidEntity.wrap(validationErrors) - collection - case _ => throw new IllegalStateException("Only Sequence and Set are supported.") - } - decodedCollection - } - case (field, codec) => - val value = queryParams.queryParamOrElse(field.name, null) - val decoded = { - if (value == null) field.defaultValue.get - else { - codec.codec(config).decode(Chunk.fromArray(value.getBytes(Charsets.Utf8))) match { - case Left(error) => throw HttpCodecError.MalformedQueryParam(field.name, error) - case Right(value) => value - } - } - } - val validationErrors = codec.schema.validate(decoded)(codec.schema) - if (validationErrors.nonEmpty) throw HttpCodecError.InvalidEntity.wrap(validationErrors) - decoded - } - if (recordSchema.isInstanceOf[Schema.Optional[_]]) { - val schema = recordSchema.asInstanceOf[Schema.Optional[_]].schema.asInstanceOf[Schema.Record[Any]] - val constructed = schema.construct(decoded)(Unsafe.unsafe) - constructed match { - case Left(value) => - throw HttpCodecError.MalformedQueryParam( - s"${schema.id}", - DecodeError.ReadError(Cause.empty, value), - ) - case Right(value) => - schema.validate(value)(schema) match { - case errors if errors.nonEmpty => throw HttpCodecError.InvalidEntity.wrap(errors) - case _ => Some(value) - } - } - } else { - val schema = recordSchema.asInstanceOf[Schema.Record[Any]] - val constructed = schema.construct(decoded)(Unsafe.unsafe) - constructed match { - case Left(value) => - throw HttpCodecError.MalformedQueryParam( - s"${schema.id}", - DecodeError.ReadError(Cause.empty, value), - ) - case Right(value) => - schema.validate(value)(schema) match { - case errors if errors.nonEmpty => throw HttpCodecError.InvalidEntity.wrap(errors) - case _ => value - } - } - } - } - } - }, + (codec, queryParams) => codec.erase.codec.decode(queryParams), ) - private def emptyStringIsValue(schema: Schema[_]): Boolean = - schema.asInstanceOf[Schema.Primitive[_]].standardType match { - case StandardType.UnitType => true - case StandardType.StringType => true - case StandardType.BinaryType => true - case StandardType.CharType => true - case _ => false - } - private def decodeHeaders(headers: Headers, inputs: Array[Any]): Unit = genericDecode[Headers, HttpCodec.Header[_]]( headers, flattened.header, inputs, - (codec, headers) => - headers.get(codec.name) match { - case Some(value) => - codec.erase.textCodec - .decode(value) - .getOrElse(throw HttpCodecError.MalformedHeader(codec.name, codec.textCodec)) - - case None => - throw HttpCodecError.MissingHeader(codec.name) - }, + (codec, headers) => codec.headerType.fromHeadersUnsafe(headers), ) private def decodeStatus(status: Status, inputs: Array[Any]): Unit = @@ -508,108 +346,11 @@ private[codec] object EncoderDecoder { ) private def encodeQuery(config: CodecConfig, inputs: Array[Any]): QueryParams = - genericEncode[QueryParams, HttpCodec.Query[_, _]]( + genericEncode[QueryParams, HttpCodec.Query[_]]( flattened.query, inputs, QueryParams.empty, - (codec, input, queryParams) => { - val query = codec.erase - - query.queryType match { - case QueryType.Primitive(name, codec) => - val schema = codec.schema - if (schema.isInstanceOf[Schema.Primitive[_]]) { - if (schema.asInstanceOf[Schema.Primitive[_]].standardType.isInstanceOf[StandardType.UnitType.type]) { - queryParams.addQueryParams(name, Chunk.empty[String]) - } else { - val encoded = codec.codec(config).asInstanceOf[BinaryCodec[Any]].encode(input).asString - queryParams.addQueryParams(name, Chunk(encoded)) - } - } else if (schema.isInstanceOf[Schema.Optional[_]]) { - val encoded = codec.codec(config).asInstanceOf[BinaryCodec[Any]].encode(input).asString - if (encoded.nonEmpty) queryParams.addQueryParams(name, Chunk(encoded)) else queryParams - } else { - throw new IllegalStateException( - "Only primitive schema is supported for query parameters of type Primitive", - ) - } - case QueryType.Collection(_, QueryType.Primitive(name, codec), optional) => - var in: Any = input - if (optional) { - in = input.asInstanceOf[Option[Any]].getOrElse(Chunk.empty) - } - val values = input.asInstanceOf[Iterable[Any]] - if (values.nonEmpty) { - queryParams.addQueryParams( - name, - Chunk.fromIterable( - values.map { value => - codec.codec(config).asInstanceOf[BinaryCodec[Any]].encode(value).asString - }, - ), - ) - } else queryParams - case query @ QueryType.Record(recordSchema) if recordSchema.isInstanceOf[Schema.Optional[_]] => - input match { - case None => queryParams - case Some(value) => - val innerSchema = - recordSchema.asInstanceOf[Schema.Optional[_]].schema.asInstanceOf[Schema.Record[Any]] - val fieldValues = innerSchema.deconstruct(value)(Unsafe.unsafe) - var j = 0 - var qp = queryParams - while (j < fieldValues.size) { - val (field, codec) = query.fieldAndCodecs(j) - val name = field.name - val value = fieldValues(j) match { - case Some(value) => value - case None => field.defaultValue - } - value match { - case values: Iterable[_] => - qp = qp.addQueryParams( - name, - Chunk.fromIterable(values.map { v => - codec.codec(config).asInstanceOf[BinaryCodec[Any]].encode(v).asString - }), - ) - case _ => - val encoded = codec.codec(config).asInstanceOf[BinaryCodec[Any]].encode(value).asString - qp = qp.addQueryParam(name, encoded) - } - j = j + 1 - } - qp - } - case query @ QueryType.Record(recordSchema) => - val innerSchema = recordSchema.asInstanceOf[Schema.Record[Any]] - val fieldValues = innerSchema.deconstruct(input)(Unsafe.unsafe) - var j = 0 - var qp = queryParams - while (j < fieldValues.size) { - val (field, codec) = query.fieldAndCodecs(j) - val name = field.name - val value = fieldValues(j) match { - case Some(value) => value - case None => field.defaultValue - } - value match { - case values if values.isInstanceOf[Iterable[_]] => - qp = qp.addQueryParams( - name, - Chunk.fromIterable(values.asInstanceOf[Iterable[Any]].map { v => - codec.codec(config).asInstanceOf[BinaryCodec[Any]].encode(v).asString - }), - ) - case _ => - val encoded = codec.codec(config).asInstanceOf[BinaryCodec[Any]].encode(value).asString - qp = qp.addQueryParam(name, encoded) - } - j = j + 1 - } - qp - } - }, + (codec, input, queryParams) => codec.erase.codec.encode(input, queryParams), ) private def encodeHeaders(inputs: Array[Any]): Headers = @@ -617,7 +358,7 @@ private[codec] object EncoderDecoder { flattened.header, inputs, Headers.empty, - (codec, input, headers) => headers ++ Headers(codec.name, codec.erase.textCodec.encode(input)), + (codec, input, headers) => headers ++ codec.erase.headerType.toHeaders(input), ) private def encodeStatus(inputs: Array[Any]): Option[Status] = diff --git a/zio-http/shared/src/main/scala/zio/http/endpoint/http/HttpGen.scala b/zio-http/shared/src/main/scala/zio/http/endpoint/http/HttpGen.scala index eae8f5aba1..cdce331a56 100644 --- a/zio-http/shared/src/main/scala/zio/http/endpoint/http/HttpGen.scala +++ b/zio-http/shared/src/main/scala/zio/http/endpoint/http/HttpGen.scala @@ -112,8 +112,8 @@ object HttpGen { case JsonSchema.Object(properties, _, _) => properties.flatMap { case (key, value) => loop(value, Some(key)) }.toSeq case JsonSchema.Enum(values) => Seq(HttpVariable(getName(name), None, Some(s"enum: ${values.mkString(",")}"))) - case JsonSchema.Null => Seq.empty - case JsonSchema.AnyJson => Seq.empty + case JsonSchema.Null => Seq.empty + case JsonSchema.AnyJson => Seq.empty } bodySchema0 match { @@ -125,41 +125,35 @@ object HttpGen { private def getName(name: Option[String]) = { name.getOrElse(throw new IllegalArgumentException("name is required")) } def headersVariables(inAtoms: AtomizedMetaCodecs): Seq[HttpVariable] = - inAtoms.header.collect { case mc @ MetaCodec(HttpCodec.Header(name, codec, _), _) => + inAtoms.header.collect { case mc @ MetaCodec(HttpCodec.Header(headerType, _), _) => HttpVariable( - name.capitalize, - mc.examples.values.headOption.map(e => codec.asInstanceOf[TextCodec[Any]].encode(e)), + headerType.names.head.capitalize, + mc.examples.values.headOption.map(e => + headerType.toHeaders(e.asInstanceOf[headerType.HeaderValue]).head.renderedValue, + ), ) } def queryVariables(config: CodecConfig, inAtoms: AtomizedMetaCodecs): Seq[HttpVariable] = { - inAtoms.query.collect { - case mc @ MetaCodec(HttpCodec.Query(HttpCodec.Query.QueryType.Primitive(name, codec), _), _) => + inAtoms.query.collect { case mc @ MetaCodec(HttpCodec.Query(codec, _), _) => + val recordSchema = (codec.schema match { + case value if value.isInstanceOf[Schema.Optional[_]] => value.asInstanceOf[Schema.Optional[Any]].schema + case _ => codec.schema + }).asInstanceOf[Schema.Record[Any]] + val examples = mc.examples.values.headOption.map { ex => + recordSchema.deconstruct(ex)(Unsafe.unsafe) + } + codec.recordFields.zipWithIndex.map { case ((field, codec), index) => HttpVariable( - name, - mc.examples.values.headOption.map((e: Any) => - codec.codec(config).asInstanceOf[BinaryCodec[Any]].encode(e).asString, - ), - ) :: Nil - case mc @ MetaCodec(HttpCodec.Query(record @ HttpCodec.Query.QueryType.Record(schema), _), _) => - val recordSchema = (schema match { - case value if value.isInstanceOf[Schema.Optional[_]] => value.asInstanceOf[Schema.Optional[Any]].schema - case _ => schema - }).asInstanceOf[Schema.Record[Any]] - val examples = mc.examples.values.headOption.map { ex => - recordSchema.deconstruct(ex)(Unsafe.unsafe) - } - record.fieldAndCodecs.zipWithIndex.map { case ((field, codec), index) => - HttpVariable( - field.name, - examples.map(values => { - val fieldValue = values(index) - .orElse(field.defaultValue) - .getOrElse(throw new Exception(s"No value or default value for field ${field.name}")) - codec.codec(config).encode(fieldValue).asString - }), - ) - } + field.name, + examples.map(values => { + val fieldValue = values(index) + .orElse(field.defaultValue) + .getOrElse(throw new Exception(s"No value or default value for field ${field.name}")) + codec.encode(fieldValue) + }), + ) + } }.flatten } diff --git a/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/JsonSchema.scala b/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/JsonSchema.scala index 8dced7348f..69ca61f5a0 100644 --- a/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/JsonSchema.scala +++ b/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/JsonSchema.scala @@ -253,6 +253,16 @@ object JsonSchema { .toOption .get + def fromTextCodec(codec: TextCodec[_]): JsonSchema = + codec match { + case TextCodec.Constant(string) => JsonSchema.Enum(Chunk(EnumValue.Str(string))) + case TextCodec.StringCodec => JsonSchema.String() + case TextCodec.IntCodec => JsonSchema.Integer(JsonSchema.IntegerFormat.Int32) + case TextCodec.LongCodec => JsonSchema.Integer(JsonSchema.IntegerFormat.Int64) + case TextCodec.BooleanCodec => JsonSchema.Boolean + case TextCodec.UUIDCodec => JsonSchema.String(JsonSchema.StringFormat.UUID) + } + private[openapi] def fromSerializableSchema(schema: SerializableJsonSchema): JsonSchema = { val definedAttributesCount = schema.productIterator.count(_.asInstanceOf[Option[_]].isDefined) @@ -389,16 +399,6 @@ object JsonSchema { } } - def fromTextCodec(codec: TextCodec[_]): JsonSchema = - codec match { - case TextCodec.Constant(string) => JsonSchema.Enum(Chunk(EnumValue.Str(string))) - case TextCodec.StringCodec => JsonSchema.String() - case TextCodec.IntCodec => JsonSchema.Integer(JsonSchema.IntegerFormat.Int32) - case TextCodec.LongCodec => JsonSchema.Integer(JsonSchema.IntegerFormat.Int64) - case TextCodec.BooleanCodec => JsonSchema.Boolean - case TextCodec.UUIDCodec => JsonSchema.String(JsonSchema.StringFormat.UUID) - } - def fromSegmentCodec(codec: SegmentCodec[_]): JsonSchema = codec match { case SegmentCodec.BoolSeg(_) => JsonSchema.Boolean diff --git a/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/OpenAPIGen.scala b/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/OpenAPIGen.scala index 59ab71fae7..a7cc99bc4f 100644 --- a/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/OpenAPIGen.scala +++ b/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/OpenAPIGen.scala @@ -16,16 +16,18 @@ import zio.schema.{Schema, TypeId} import zio.http._ import zio.http.codec.HttpCodec.Metadata +import zio.http.codec.HttpCodecType.Content import zio.http.codec._ import zio.http.endpoint._ import zio.http.endpoint.openapi.JsonSchema.SchemaStyle import zio.http.endpoint.openapi.OpenAPI.{Path, PathItem} +import zio.http.internal.StringSchemaCodec object OpenAPIGen { private val PathWildcard = "pathWildcard" private[openapi] def groupMap[A, K, B](chunk: Chunk[A])(key: A => K)(f: A => B): immutable.Map[K, Chunk[B]] = { - val m = mutable.Map.empty[K, mutable.Builder[B, Chunk[B]]] + val m = mutable.Map.empty[K, mutable.Builder[B, Chunk[B]]] for (elem <- chunk) { val k = key(elem) val bldr = m.getOrElseUpdate(k, Chunk.newBuilder[B]) @@ -100,20 +102,21 @@ object OpenAPIGen { final case class AtomizedMetaCodecs( method: Chunk[MetaCodec[SimpleCodec[Method, _]]], path: Chunk[MetaCodec[SegmentCodec[_]]], - query: Chunk[MetaCodec[HttpCodec.Query[_, _]]], + query: Chunk[MetaCodec[HttpCodec.Query[_]]], header: Chunk[MetaCodec[HttpCodec.Header[_]]], - content: Chunk[MetaCodec[HttpCodec.Atom[HttpCodecType.Content, _]]], + content: Chunk[MetaCodec[HttpCodec.Atom[Content, _]]], status: Chunk[MetaCodec[HttpCodec.Status[_]]], ) { def append(metaCodec: MetaCodec[_]): AtomizedMetaCodecs = metaCodec match { case MetaCodec(codec: HttpCodec.Method[_], annotations) => - copy(method = - (method :+ MetaCodec(codec.codec, annotations)).asInstanceOf[Chunk[MetaCodec[SimpleCodec[Method, _]]]], + copy( + method = + (method :+ MetaCodec(codec.codec, annotations)).asInstanceOf[Chunk[MetaCodec[SimpleCodec[Method, _]]]], ) case MetaCodec(_: SegmentCodec[_], _) => copy(path = path :+ metaCodec.asInstanceOf[MetaCodec[SegmentCodec[_]]]) - case MetaCodec(_: HttpCodec.Query[_, _], _) => - copy(query = query :+ metaCodec.asInstanceOf[MetaCodec[HttpCodec.Query[_, _]]]) + case MetaCodec(_: HttpCodec.Query[_], _) => + copy(query = query :+ metaCodec.asInstanceOf[MetaCodec[HttpCodec.Query[_]]]) case MetaCodec(_: HttpCodec.Header[_], _) => copy(header = header :+ metaCodec.asInstanceOf[MetaCodec[HttpCodec.Header[_]]]) case MetaCodec(_: HttpCodec.Status[_], _) => @@ -750,87 +753,41 @@ object OpenAPIGen { def parameters: Set[OpenAPI.ReferenceOr[OpenAPI.Parameter]] = queryParams ++ pathParams ++ headerParams - def queryParams: Set[OpenAPI.ReferenceOr[OpenAPI.Parameter]] = { - inAtoms.query.collect { - case mc @ MetaCodec(q @ HttpCodec.Query(HttpCodec.Query.QueryType.Primitive(name, codec), _), _) => + def queryParams: Set[OpenAPI.ReferenceOr[OpenAPI.Parameter]] = + inAtoms.query.collect { case mc @ MetaCodec(HttpCodec.Query(codec, _), _) => + val recordSchema = (codec.schema match { + case schema if schema.isInstanceOf[Schema.Optional[_]] => schema.asInstanceOf[Schema.Optional[_]].schema + case _ => codec.schema + }).asInstanceOf[Schema.Record[Any]] + val examples = mc.examples.map { case (exName, ex) => + exName -> recordSchema.deconstruct(ex)(Unsafe.unsafe) + } + codec.recordFields.zipWithIndex.map { case ((field, codec), index) => OpenAPI.ReferenceOr.Or( OpenAPI.Parameter.queryParameter( - name = name, + name = field.name, description = mc.docsOpt, schema = Some(OpenAPI.ReferenceOr.Or(JsonSchema.fromZSchema(codec.schema))), deprecated = mc.deprecated, style = OpenAPI.Parameter.Style.Form, explode = false, allowReserved = false, - examples = mc.examples.map { case (name, value) => - name -> OpenAPI.ReferenceOr.Or(OpenAPI.Example(value = Json.Str(value.toString))) - }, - required = mc.required && !q.isOptional, - ), - ) :: Nil - case mc @ MetaCodec(HttpCodec.Query(record @ HttpCodec.Query.QueryType.Record(schema), _), _) => - val recordSchema = (schema match { - case schema if schema.isInstanceOf[Schema.Optional[_]] => schema.asInstanceOf[Schema.Optional[_]].schema - case _ => schema - }).asInstanceOf[Schema.Record[Any]] - val examples = mc.examples.map { case (exName, ex) => - exName -> recordSchema.deconstruct(ex)(Unsafe.unsafe) - } - record.fieldAndCodecs.zipWithIndex.map { case ((field, codec), index) => - OpenAPI.ReferenceOr.Or( - OpenAPI.Parameter.queryParameter( - name = field.name, - description = mc.docsOpt, - schema = Some(OpenAPI.ReferenceOr.Or(JsonSchema.fromZSchema(codec.schema))), - deprecated = mc.deprecated, - style = OpenAPI.Parameter.Style.Form, - explode = false, - allowReserved = false, - examples = examples.map { case (exName, values) => - val fieldValue = values(index) - .orElse(field.defaultValue) - .getOrElse( - throw new Exception(s"No value or default value found for field ${exName}_${field.name}"), - ) - s"${exName}_${field.name}" -> OpenAPI.ReferenceOr.Or( - OpenAPI.Example(value = - Json.Str(codec.codec(CodecConfig.defaultConfig).encode(fieldValue).asString), - ), + examples = examples.map { case (exName, values) => + val fieldValue = values(index) + .orElse(field.defaultValue) + .getOrElse( + throw new Exception(s"No value or default value found for field ${exName}_${field.name}"), ) - }, - required = mc.required, - ), - ) - - } - case mc @ MetaCodec( - HttpCodec.Query( - HttpCodec.Query.QueryType.Collection( - _, - HttpCodec.Query.QueryType.Primitive(name, codec), - optional, - ), - _, - ), - _, - ) => - OpenAPI.ReferenceOr.Or( - OpenAPI.Parameter.queryParameter( - name = name, - description = mc.docsOpt, - schema = Some(OpenAPI.ReferenceOr.Or(JsonSchema.fromZSchema(codec.schema))), - deprecated = mc.deprecated, - style = OpenAPI.Parameter.Style.Form, - explode = false, - allowReserved = false, - examples = mc.examples.map { case (exName, value) => - exName -> OpenAPI.ReferenceOr.Or(OpenAPI.Example(value = Json.Str(value.toString))) + s"${exName}_${field.name}" -> OpenAPI.ReferenceOr.Or( + OpenAPI.Example(value = Json.Str(codec.encode(fieldValue))), + ) }, - required = mc.required && !optional, + required = mc.required && !StringSchemaCodec.isOptional(field.schema), ), - ) :: Nil - } - }.flatten.toSet + ) + + } + }.flatten.toSet def pathParams: Set[OpenAPI.ReferenceOr[OpenAPI.Parameter]] = inAtoms.path.collect { @@ -855,14 +812,14 @@ object OpenAPIGen { .map { case mc @ MetaCodec(codec, _) => OpenAPI.ReferenceOr.Or( OpenAPI.Parameter.headerParameter( - name = mc.name.getOrElse(codec.name), + name = mc.name.getOrElse(codec.headerType.names.head), description = mc.docsOpt, - definition = - Some(OpenAPI.ReferenceOr.Or(JsonSchema.fromTextCodec(codec.textCodec).nullable(!mc.required))), + definition = Some(OpenAPI.ReferenceOr.Or(JsonSchema.String().nullable(!mc.required))), deprecated = mc.deprecated, - examples = mc.examples.map { case (name, value) => - name -> OpenAPI.ReferenceOr.Or(OpenAPI.Example(codec.textCodec.encode(value).toJsonAST.toOption.get)) - }, + examples = Map.empty, +// mc.examples.map { case (name, value) => +// name -> OpenAPI.ReferenceOr.Or(OpenAPI.Example(codec.headerType.render(value).toJsonAST.toOption.get)) +// }, required = mc.required, ), ) @@ -1133,13 +1090,14 @@ object OpenAPIGen { private def headersFrom(codec: AtomizedMetaCodecs) = { codec.header.map { case mc @ MetaCodec(codec, _) => - codec.name -> OpenAPI.ReferenceOr.Or( + // todo use all headers + codec.headerType.names.head -> OpenAPI.ReferenceOr.Or( OpenAPI.Header( description = mc.docsOpt, required = true, deprecated = mc.deprecated, allowEmptyValue = false, - schema = Some(JsonSchema.fromTextCodec(codec.textCodec)), + schema = Some(JsonSchema.String().nullable(!mc.required)), ), ) }.toMap diff --git a/zio-http/shared/src/main/scala/zio/http/internal/HeaderGetters.scala b/zio-http/shared/src/main/scala/zio/http/internal/HeaderGetters.scala index 8b6ad2b167..2fb9b2c441 100644 --- a/zio-http/shared/src/main/scala/zio/http/internal/HeaderGetters.scala +++ b/zio-http/shared/src/main/scala/zio/http/internal/HeaderGetters.scala @@ -67,6 +67,13 @@ trait HeaderGetters { self => /** Gets the raw unparsed header value */ final def rawHeader(name: CharSequence): Option[String] = headers.get(name) + def rawHeaders(name: CharSequence): Chunk[String] = + Chunk.fromIterator( + headers.iterator + .filter(header => CharSequenceExtensions.equals(header.headerNameAsCharSequence, name, CaseMode.Insensitive)) + .map(_.renderedValue), + ) + /** Gets the raw unparsed header value */ final def rawHeader(headerType: HeaderType): Option[String] = rawHeader(headerType.name) diff --git a/zio-http/shared/src/main/scala/zio/http/internal/HeaderModifier.scala b/zio-http/shared/src/main/scala/zio/http/internal/HeaderModifier.scala index 255358d8c5..629f497f01 100644 --- a/zio-http/shared/src/main/scala/zio/http/internal/HeaderModifier.scala +++ b/zio-http/shared/src/main/scala/zio/http/internal/HeaderModifier.scala @@ -39,6 +39,9 @@ trait HeaderModifier[+A] { self => final def addHeaders(headers: Headers): A = updateHeaders(_ ++ headers) + final def addHeaders(headers: Iterable[(CharSequence, CharSequence)]): A = + addHeaders(Headers.fromIterable(headers.map { case (k, v) => Header.Custom(k, v) })) + final def removeHeader(headerType: HeaderType): A = removeHeader(headerType.name) final def removeHeader(name: String): A = removeHeaders(Set(name)) diff --git a/zio-http/shared/src/main/scala/zio/http/internal/QueryModifier.scala b/zio-http/shared/src/main/scala/zio/http/internal/QueryModifier.scala index 77996c462e..75ad863332 100644 --- a/zio-http/shared/src/main/scala/zio/http/internal/QueryModifier.scala +++ b/zio-http/shared/src/main/scala/zio/http/internal/QueryModifier.scala @@ -45,6 +45,13 @@ trait QueryModifier[+A] { self: QueryOps[A] with A => def addQueryParams(values: String): A = updateQueryParams(params => params ++ QueryParams.decode(values)) + def addQueryParams(queryParams: Iterable[(String, String)]): A = + updateQueryParams(params => + params ++ QueryParams(queryParams.groupBy(_._1).map { case (k, v) => + k -> Chunk.fromIterable(v).map(_._2) + }), + ) + /** * Removes the specified key from the query parameters. */ diff --git a/zio-http/shared/src/main/scala/zio/http/internal/StringSchemaCodec.scala b/zio-http/shared/src/main/scala/zio/http/internal/StringSchemaCodec.scala new file mode 100644 index 0000000000..d56047298a --- /dev/null +++ b/zio-http/shared/src/main/scala/zio/http/internal/StringSchemaCodec.scala @@ -0,0 +1,723 @@ +package zio.http.internal + +import java.time._ +import java.util.{Currency, UUID} + +import scala.annotation.tailrec +import scala.util.Try + +import zio.{Cause, Chunk, Unsafe} + +import zio.schema.codec.DecodeError +import zio.schema.validation.{Validation, ValidationError} +import zio.schema.{Schema, StandardType, TypeId} + +import zio.http.codec.HttpCodecError +import zio.http.internal.StringSchemaCodec.{PrimitiveCodec, decodeAndUnwrap, emptyStringIsValue, validateDecoded} +import zio.http.{Headers, QueryParams} + +private[http] trait ErrorConstructor { + def missing(fieldName: String): HttpCodecError + def missingAll(fieldNames: Chunk[String]): HttpCodecError + def invalid(errors: Chunk[ValidationError]): HttpCodecError + def malformed(fieldName: String, error: DecodeError): HttpCodecError + def invalidCount(fieldName: String, expected: Int, actual: Int): HttpCodecError +} + +private[http] trait StringSchemaCodec[A, Target] { + private[http] def schema: Schema[A] + private[http] def add(target: Target, key: String, value: String): Target + private[http] def addAll(target: Target, headers: Iterable[(String, String)]): Target + private[http] def contains(target: Target, key: String): Boolean + private[http] def unsafeGet(target: Target, key: String): String + private[http] def getAll(target: Target, key: String): Chunk[String] + private[http] def count(target: Target, key: String): Int + private[http] def error: ErrorConstructor + private[http] def kebabCase: Boolean + private[http] val defaultValue: A + private[http] val isOptional: Boolean + private[http] val isOptionalSchema: Boolean + private[http] val recordFields: Chunk[(Schema.Field[_, _], PrimitiveCodec[Any])] = { + val fields = schema match { + case record: Schema.Record[A] => + record.fields + case s: Schema.Optional[_] if s.schema.isInstanceOf[Schema.Record[_]] => + s.schema.asInstanceOf[Schema.Record[A]].fields + case s: Schema.Transform[_, _, _] if s.schema.isInstanceOf[Schema.Record[_]] => + s.schema.asInstanceOf[Schema.Record[A]].fields + case _ => Chunk.empty + } + fields.map(StringSchemaCodec.unlazyField).map { + case field if field.schema.isInstanceOf[Schema.Collection[_, _]] => + val elementSchema = field.schema.asInstanceOf[Schema.Collection[_, _]] match { + case s: Schema.NonEmptySequence[_, _, _] => s.elementSchema + case s: Schema.Sequence[_, _, _] => s.elementSchema + case s: Schema.Set[_] => s.elementSchema + case _: Schema.Map[_, _] => throw new IllegalArgumentException("Maps are not supported") + case _: Schema.NonEmptyMap[_, _] => throw new IllegalArgumentException("Maps are not supported") + } + val codec = PrimitiveCodec(elementSchema).asInstanceOf[PrimitiveCodec[Any]] + (StringSchemaCodec.mapFieldName(field, kebabCase), codec) + case field => + val codec = + PrimitiveCodec(field.annotations.foldLeft(field.schema)(_.annotate(_))).asInstanceOf[PrimitiveCodec[Any]] + (StringSchemaCodec.mapFieldName(field, kebabCase), codec) + } + } + + private[http] val recordSchema: Schema.Record[Any] = schema match { + case record: Schema.Record[_] => + record.asInstanceOf[Schema.Record[Any]] + case s: Schema.Optional[_] if s.schema.isInstanceOf[Schema.Record[_]] => + s.schema.asInstanceOf[Schema.Record[Any]] + case _ => null + } + + private def createAndValidateCollection(schema: Schema.Collection[_, _], decoded: Chunk[Any]) = { + val collection = schema.fromChunk.asInstanceOf[Chunk[Any] => Any](decoded) + val erasedSchema = schema.asInstanceOf[Schema[Any]] + val validationErrors = erasedSchema.validate(collection)(erasedSchema) + if (validationErrors.nonEmpty) throw error.invalid(validationErrors) + collection + } + + private[http] def decode(target: Target): A = { + val optional = isOptionalSchema + val hasDefault = defaultValue != null && isOptional + val default = defaultValue + val hasAllParams = recordFields.forall { case (field, codec) => + contains(target, field.fieldName) || field.optional || codec.isOptional + } + if (!hasAllParams && hasDefault) default + else if (!hasAllParams) { + throw error.missingAll { + recordFields.collect { + case (field, codec) if !(contains(target, field.fieldName) || field.optional || codec.isOptional) => + field.fieldName + } + } + } else { + val decoded = recordFields.map { + case (field, codec) if field.schema.isInstanceOf[Schema.Collection[_, _]] => + val schema = field.schema.asInstanceOf[Schema.Collection[_, _]] + if (!contains(target, field.fieldName)) { + if (field.defaultValue.isDefined) field.defaultValue.get + else throw error.missing(field.fieldName) + } else { + val values = getAll(target, field.fieldName) + val decoded = + values.map(decodeAndUnwrap(field, codec, _, error.malformed)) + createAndValidateCollection(schema, decoded) + + } + case (field, codec) => + val count0 = count(target, field.fieldName) + if (count0 > 1) throw error.invalidCount(field.fieldName, 1, count0) + val value = unsafeGet(target, field.fieldName) + val decoded = { + if (value == null || (value == "" && !emptyStringIsValue(codec.schema) && codec.isOptional)) + codec.defaultValue + else decodeAndUnwrap(field, codec, value, error.malformed) + } + validateDecoded(codec, decoded, error) + } + if (optional) { + val constructed = recordSchema.construct(decoded)(Unsafe.unsafe) + constructed match { + case Left(value) => + throw error.malformed( + s"${recordSchema.id}", + DecodeError.ReadError(Cause.empty, value), + ) + case Right(value) => + recordSchema.validate(value)(recordSchema) match { + case errors if errors.nonEmpty => throw error.invalid(errors) + case _ => Some(value).asInstanceOf[A] + } + } + } else { + val constructed = recordSchema.construct(decoded)(Unsafe.unsafe) + constructed match { + case Left(value) => + throw error.malformed( + s"${recordSchema.id}", + DecodeError.ReadError(Cause.empty, value), + ) + case Right(value) => + recordSchema.validate(value)(recordSchema) match { + case errors if errors.nonEmpty => throw error.invalid(errors) + case _ => value.asInstanceOf[A] + } + } + } + } + + } + + private[http] def encode(input: A, target: Target): Target = { + val fields = recordFields + val value = input.asInstanceOf[Any] match { + case None => null + case it: Iterable[_] if it.isEmpty => null + case Some(value) => value + case value => value + } + if (value == null) target + else { + val fieldValues = recordSchema.deconstruct(value)(Unsafe.unsafe) + var target0 = target + val fieldIt = fields.iterator + val fieldValuesIt = fieldValues.iterator + while (fieldIt.hasNext) { + val (field, codec) = fieldIt.next() + val name = field.fieldName + val value = fieldValuesIt.next() match { + case Some(value) => value + case None => field.defaultValue + } + value match { + case values: Iterable[_] => + target0 = addAll(target0, values.map { v => (name, codec.encode(v)) }) + case _ => + val encoded = codec.encode(value) + target0 = add(target0, name, encoded) + } + } + target0 + } + } + + private[http] def optional: StringSchemaCodec[Option[A], Target] + +} + +private[http] object StringSchemaCodec { + private[http] def unlazyField(field: Schema.Field[_, _]): Schema.Field[_, _] = field match { + case f if f.schema.isInstanceOf[Schema.Lazy[_]] => + Schema.Field( + f.name, + f.schema.asInstanceOf[Schema.Lazy[_]].schema.asInstanceOf[Schema[Any]], + f.annotations, + f.validation.asInstanceOf[Validation[Any]], + f.get.asInstanceOf[Any => Any], + f.set.asInstanceOf[(Any, Any) => Any], + ) + case f => f + } + private[http] def defaultValue[A](schema: Schema[A]): A = + if (schema.isInstanceOf[Schema.Collection[_, _]]) { + Try(schema.asInstanceOf[Schema.Collection[A, _]].empty).fold( + _ => null.asInstanceOf[A], + identity, + ) + } else { + schema.defaultValue match { + case Right(value) => value + case Left(_) => + schema match { + case _: Schema.Optional[_] => None.asInstanceOf[A] + case collection: Schema.Collection[A, _] => + Try(collection.empty).fold( + _ => null.asInstanceOf[A], + identity, + ) + case _ => null.asInstanceOf[A] + } + } + } + + private[http] def isOptional(schema: Schema[_]): Boolean = schema match { + case _: Schema.Optional[_] => + true + case record: Schema.Record[_] => + record.fields.forall(_.optional) || record.defaultValue.isRight + case d: Schema.Collection[_, _] => + val bool = Try(d.empty).isSuccess || d.defaultValue.isRight + bool + case _ => + false + } + + private[http] def isOptionalSchema(schema: Schema[_]): Boolean = + schema match { + case _: Schema.Optional[_] => true + case s: Schema.Transform[_, _, _] if s.schema.isInstanceOf[Schema.Optional[_]] => true + case _ => false + } + + private[http] final case class PrimitiveCodec[A]( + private[http] val schema: Schema[A], + ) { + + val defaultValue: A = + StringSchemaCodec.defaultValue(schema) + + private[http] val isOptional: Boolean = + StringSchemaCodec.isOptional(schema) + + private[http] val isOptionalSchema: Boolean = + StringSchemaCodec.isOptionalSchema(schema) + + private[http] val encode: A => String = + PrimitiveCodec.primitiveSchemaEncoder(schema) + + private[http] val decode: String => A = + PrimitiveCodec.primitiveSchemaDecoder(schema) + + } + + object PrimitiveCodec { + + private[http] def primitiveSchemaDecoder[A](schema: Schema[A]): String => A = schema match { + case Schema.Optional(schema, _) => + primitiveSchemaDecoder(schema).andThen(Some(_)).asInstanceOf[String => A] + case Schema.Transform(schema, f, _, _, _) => + primitiveSchemaDecoder(schema).andThen { + f(_) match { + case Left(value) => throw new IllegalArgumentException(value) + case Right(value) => value + } + }.asInstanceOf[String => A] + case Schema.Primitive(standardType, _) => + parsePrimitive(standardType.asInstanceOf[StandardType[Any]]).asInstanceOf[String => A] + case Schema.Lazy(schema0) => + primitiveSchemaDecoder(schema0()).asInstanceOf[String => A] + case _ => throw new IllegalArgumentException(s"Unsupported schema $schema") + } + + private[http] def primitiveSchemaEncoder[A](schema: Schema[A]): A => String = schema match { + case Schema.Optional(schema, _) => + val innerEncoder: Any => String = primitiveSchemaEncoder(schema.asInstanceOf[Schema[Any]]) + (a: A) => if (a.isInstanceOf[None.type]) null else innerEncoder(a.asInstanceOf[Some[Any]].get) + case Schema.Transform(schema, f, _, _, _) => + val innerEncoder: Any => String = primitiveSchemaEncoder(schema.asInstanceOf[Schema[Any]]) + (a: A) => + f.asInstanceOf[Any => Either[String, Any]](a.asInstanceOf[Any]) match { + case Left(value) => throw new IllegalArgumentException(value) + case Right(value) => innerEncoder(value) + } + case Schema.Lazy(schema0) => + primitiveSchemaEncoder(schema0()).asInstanceOf[A => String] + case Schema.Primitive(_, _) => + (a: A) => a.toString + case _ => + throw new IllegalArgumentException(s"Unsupported schema $schema") + } + } + + private def decodeAndUnwrap( + field: Schema.Field[_, _], + codec: PrimitiveCodec[Any], + value: String, + ex: (String, DecodeError) => HttpCodecError, + ) = + try codec.decode(value) + catch { + case err: DecodeError => throw ex(field.fieldName, err) + } + + private def validateDecoded(codec: PrimitiveCodec[Any], decoded: Any, error: ErrorConstructor) = { + val validationErrors = codec.schema.validate(decoded)(codec.schema) + if (validationErrors.nonEmpty) throw error.invalid(validationErrors) + decoded + } + + @tailrec + private def emptyStringIsValue(schema: Schema[_]): Boolean = { + schema match { + case value: Schema.Optional[_] => + val innerSchema = value.schema + emptyStringIsValue(innerSchema) + case _ => + schema.asInstanceOf[Schema.Primitive[_]].standardType match { + case StandardType.UnitType => true + case StandardType.StringType => true + case StandardType.BinaryType => true + case StandardType.CharType => true + case _ => false + } + } + } + private[http] def mapFieldName(field: Schema.Field[_, _], kebabCase: Boolean): Schema.Field[_, _] = { + Schema.Field( + if (!kebabCase) field.fieldName else camelToKebab(field.fieldName), + field.annotations.foldLeft(field.schema)(_ annotate _).asInstanceOf[Schema[Any]], + field.annotations, + field.validation.asInstanceOf[Validation[Any]], + field.get.asInstanceOf[Any => Any], + field.set.asInstanceOf[(Any, Any) => Any], + ) + } + + private[http] def headerFromSchema[A]( + schema0: Schema[A], + error0: ErrorConstructor, + name: String, + ): StringSchemaCodec[A, Headers] = { + + def stringSchemaCodec(schema1: Schema[Any]): StringSchemaCodec[A, Headers] = + new StringSchemaCodec[A, Headers] { + override def schema: Schema[A] = schema1.asInstanceOf[Schema[A]] + + override private[http] def add(headers: Headers, key: String, value: String): Headers = + headers.addHeader(key, value) + + override private[http] def addAll(headers: Headers, kvs: Iterable[(String, String)]): Headers = + headers.addHeaders(kvs) + + override private[http] def contains(headers: Headers, key: String): Boolean = + headers.contains(key) + + override private[http] def unsafeGet(headers: Headers, key: String): String = + headers.getUnsafe(key) + + override private[http] def getAll(headers: Headers, key: String): Chunk[String] = + headers.rawHeaders(key) + + override private[http] def count(headers: Headers, key: String): Int = + headers.rawHeaders(key).size + + override private[http] def error: ErrorConstructor = + error0 + + override private[http] def optional: StringSchemaCodec[Option[A], Headers] = + StringSchemaCodec.headerFromSchema(schema.optional, error0, name) + + override private[http] def kebabCase: Boolean = + true + override private[http] val defaultValue: A = + StringSchemaCodec.defaultValue(schema0) + override private[http] val isOptional: Boolean = + StringSchemaCodec.isOptional(schema0) + override private[http] val isOptionalSchema: Boolean = + StringSchemaCodec.isOptionalSchema(schema0) + } + schema0 match { + case s @ Schema.Primitive(_, _) => + stringSchemaCodec(recordSchema(s.asInstanceOf[Schema[Any]], name)) + case s @ Schema.Optional(schema, _) => + schema match { + case _: Schema.Collection[_, _] | _: Schema.Primitive[_] => + stringSchemaCodec(recordSchema(s.asInstanceOf[Schema[Any]], name)) + case s if s.isInstanceOf[Schema.Record[_]] => stringSchemaCodec(schema.asInstanceOf[Schema[Any]]) + case _ => throw new IllegalArgumentException(s"Unsupported schema $s") + } + case s @ Schema.Transform(schema, _, _, _, _) => + schema match { + case _: Schema.Collection[_, _] | _: Schema.Primitive[_] => + stringSchemaCodec(recordSchema(s.asInstanceOf[Schema[Any]], name)) + case _: Schema.Record[_] => stringSchemaCodec(s.asInstanceOf[Schema[Any]]) + case _ => throw new IllegalArgumentException(s"Unsupported schema $s") + } + case Schema.Lazy(schema0) => + headerFromSchema( + schema0().asInstanceOf[Schema[A]], + error0, + name, + ) + case _: Schema.Collection[_, _] => + stringSchemaCodec(recordSchema(schema0.asInstanceOf[Schema[Any]], name)) + case s: Schema.Record[_] => + stringSchemaCodec(s.asInstanceOf[Schema[Any]]) + case _ => + throw new IllegalArgumentException(s"Unsupported schema $schema0") + + } + + } + + private[http] def queryFromSchema[A]( + schema0: Schema[A], + error0: ErrorConstructor, + name: String, + ): StringSchemaCodec[A, QueryParams] = { + + def stringSchemaCodec(schema1: Schema[Any]): StringSchemaCodec[A, QueryParams] = + new StringSchemaCodec[A, QueryParams] { + override def schema: Schema[A] = schema1.asInstanceOf[Schema[A]] + + override private[http] def add(queryParams: QueryParams, key: String, value: String): QueryParams = + queryParams.addQueryParam(key, value) + + override private[http] def addAll(queryParams: QueryParams, kvs: Iterable[(String, String)]): QueryParams = + queryParams.addQueryParams(kvs) + + override private[http] def contains(queryParams: QueryParams, key: String): Boolean = + queryParams.hasQueryParam(key) + + override private[http] def unsafeGet(queryParams: QueryParams, key: String): String = + queryParams.unsafeQueryParam(key) + + override private[http] def getAll(queryParams: QueryParams, key: String): Chunk[String] = + queryParams.getAll(key) + + override private[http] def count(queryParams: QueryParams, key: String): Int = + queryParams.valueCount(key) + + override private[http] def error: ErrorConstructor = + error0 + + override private[http] def optional: StringSchemaCodec[Option[A], QueryParams] = + StringSchemaCodec.queryFromSchema(schema.optional, error0, name) + + override private[http] def kebabCase: Boolean = + false + override private[http] val defaultValue: A = + StringSchemaCodec.defaultValue(schema0) + override private[http] val isOptional: Boolean = + StringSchemaCodec.isOptional(schema0) + override private[http] val isOptionalSchema: Boolean = + StringSchemaCodec.isOptionalSchema(schema0) + } + schema0 match { + case s @ Schema.Primitive(_, _) => + stringSchemaCodec(recordSchema(s.asInstanceOf[Schema[Any]], name)) + case s @ Schema.Optional(schema, _) => + schema match { + case _: Schema.Collection[_, _] | _: Schema.Primitive[_] => + stringSchemaCodec(recordSchema(s.asInstanceOf[Schema[Any]], name)) + case s if s.isInstanceOf[Schema.Record[_]] => stringSchemaCodec(schema.asInstanceOf[Schema[Any]]) + case _ => throw new IllegalArgumentException(s"Unsupported schema $s") + } + case s @ Schema.Transform(schema, _, _, _, _) => + schema match { + case _: Schema.Collection[_, _] | _: Schema.Primitive[_] => + stringSchemaCodec(recordSchema(s.asInstanceOf[Schema[Any]], name)) + case _: Schema.Record[_] => stringSchemaCodec(s.asInstanceOf[Schema[Any]]) + case _ => throw new IllegalArgumentException(s"Unsupported schema $s") + } + case Schema.Lazy(schema0) => + queryFromSchema( + schema0().asInstanceOf[Schema[A]], + error0, + name, + ) + case _: Schema.Collection[_, _] => + stringSchemaCodec(recordSchema(schema0.asInstanceOf[Schema[Any]], name)) + case s: Schema.Record[_] => + stringSchemaCodec(s.asInstanceOf[Schema[Any]]) + case _ => + throw new IllegalArgumentException(s"Unsupported schema $schema0") + + } + } + + private def recordSchema[A](s: Schema[A], name: String): Schema[A] = Schema.CaseClass1[A, A]( + TypeId.Structural, + Schema.Field(name, s, Chunk.empty, Validation.succeed, identity, (_, v) => v), + identity, + ) + + private def parsePrimitive(standardType: StandardType[_]): String => Any = + standardType match { + case StandardType.UnitType => + val result = "" + (_: String) => result + case StandardType.StringType => + (s: String) => s + case StandardType.BoolType => + (s: String) => + s.toLowerCase match { + case "true" | "on" | "yes" | "1" => true + case "false" | "off" | "no" | "0" => false + case _ => throw DecodeError.ReadError(Cause.fail(new Exception("Invalid boolean value")), s) + } + case StandardType.ByteType => + (s: String) => + try { + s.toByte + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.ShortType => + (s: String) => + try { + s.toShort + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.IntType => + (s: String) => + try { + s.toInt + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.LongType => + (s: String) => + try { + s.toLong + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.FloatType => + (s: String) => + try { + s.toFloat + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.DoubleType => + (s: String) => + try { + s.toDouble + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.BinaryType => + val result = DecodeError.UnsupportedSchema(Schema.Primitive(standardType), "TextCodec") + (_: String) => throw result + case StandardType.CharType => + (s: String) => s.charAt(0) + case StandardType.UUIDType => + (s: String) => + try { + UUID.fromString(s) + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.BigDecimalType => + (s: String) => + try { + BigDecimal(s) + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.BigIntegerType => + (s: String) => + try { + BigInt(s) + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.DayOfWeekType => + (s: String) => + try { + DayOfWeek.valueOf(s) + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.MonthType => + (s: String) => + try { + Month.valueOf(s) + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.MonthDayType => + (s: String) => + try { + MonthDay.parse(s) + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.PeriodType => + (s: String) => + try { + Period.parse(s) + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.YearType => + (s: String) => + try { + Year.parse(s) + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.YearMonthType => + (s: String) => + try { + YearMonth.parse(s) + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.ZoneIdType => + (s: String) => + try { + ZoneId.of(s) + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.ZoneOffsetType => + (s: String) => + try { + ZoneOffset.of(s) + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.DurationType => + (s: String) => + try { + java.time.Duration.parse(s) + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.InstantType => + (s: String) => + try { + Instant.parse(s) + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.LocalDateType => + (s: String) => + try { + LocalDate.parse(s) + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.LocalTimeType => + (s: String) => + try { + LocalTime.parse(s) + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.LocalDateTimeType => + (s: String) => + try { + LocalDateTime.parse(s) + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.OffsetTimeType => + (s: String) => + try { + OffsetTime.parse(s) + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.OffsetDateTimeType => + (s: String) => + try { + OffsetDateTime.parse(s) + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.ZonedDateTimeType => + (s: String) => + try { + ZonedDateTime.parse(s) + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.CurrencyType => + (s: String) => + try { + Currency.getInstance(s) + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + } + + private def camelToKebab(s: String): String = + if (s.isEmpty) "" + else if (s.head.isUpper) s.head.toLower.toString + camelToKebab(s.tail) + else if (s.contains('-')) s + else + s.foldLeft("") { (acc, c) => + if (c.isUpper) acc + "-" + c.toLower + else acc + c + } +}