Skip to content

Commit d74223c

Browse files
authored
Allow recursive schemas in multipart-based paths (#813)
### Motivation Fixes #769 Currently, as seen in the issue mentioned above, if a recursive schema gets referenced from a multipart-based path, a reference cycle error gets thrown. Since we support recursive type generation already, this shouldn't happen. ### Modifications This PR removes the last `dereferenced(in:)` method call, which would throw if a reference cycle was found, replacing it with a simple lookup in the OpenAPI components tree. It also replaces `DerereferencedJSONSchema` with `JSONSchema` accordingly. ### Result Multipart-based paths can now reference recursive types. ### Test Plan This also adds a test with such a case. We could probably also add a test with a multipart path referencing an _array_ of references instead of just a reference to be thorough.
1 parent 3a95e87 commit d74223c

File tree

2 files changed

+211
-7
lines changed

2 files changed

+211
-7
lines changed

Sources/_OpenAPIGeneratorCore/Translator/Multipart/MultipartContentInspector.swift

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,8 @@ extension FileTranslator {
257257
default: return .infer(.primitive)
258258
}
259259
}
260-
func inferAllOfAnyOfOneOf(_ schemas: [DereferencedJSONSchema]) throws -> MultipartPartInfo.ContentTypeSource? {
260+
261+
func inferAllOfAnyOfOneOf(_ schemas: [JSONSchema]) throws -> MultipartPartInfo.ContentTypeSource? {
261262
// If all schemas are primitive, the allOf/anyOf/oneOf is also primitive.
262263
// These cannot be binary, so only primitive vs complex.
263264
for schema in schemas {
@@ -266,12 +267,13 @@ extension FileTranslator {
266267
}
267268
return .infer(.primitive)
268269
}
269-
func inferSchema(_ schema: DereferencedJSONSchema) throws -> (
270+
271+
func inferSchema(_ schema: JSONSchema) throws -> (
270272
MultipartPartInfo.RepetitionKind, MultipartPartInfo.ContentTypeSource
271273
)? {
272274
let repetitionKind: MultipartPartInfo.RepetitionKind
273275
let candidateSource: MultipartPartInfo.ContentTypeSource
274-
switch schema {
276+
switch schema.value {
275277
case .null, .not: return nil
276278
case .boolean, .number, .integer:
277279
repetitionKind = .single
@@ -289,21 +291,30 @@ extension FileTranslator {
289291
case .array(_, let context):
290292
repetitionKind = .array
291293
if let items = context.items {
292-
switch items {
294+
switch items.value {
293295
case .null, .not: return nil
294296
case .boolean, .number, .integer: candidateSource = .infer(.primitive)
295297
case .string(_, let context): candidateSource = try inferStringContent(context)
296298
case .object, .all, .one, .any, .fragment, .array: candidateSource = .infer(.complex)
299+
case .reference(let ref, _):
300+
guard let source = try inferSchema(components.lookup(ref))?.1 else { return nil }
301+
candidateSource = source
297302
}
298303
} else {
299304
candidateSource = .infer(.complex)
300305
}
306+
case .reference(let ref, _):
307+
guard let (refRepetitionKind, refCandidateSource) = try inferSchema(components.lookup(ref)) else {
308+
return nil
309+
}
310+
repetitionKind = refRepetitionKind
311+
candidateSource = refCandidateSource
301312
}
313+
302314
return (repetitionKind, candidateSource)
303315
}
304-
guard let (repetitionKind, candidateSource) = try inferSchema(schema.dereferenced(in: components)) else {
305-
return nil
306-
}
316+
guard let (repetitionKind, candidateSource) = try inferSchema(schema) else { return nil }
317+
307318
let finalContentTypeSource: MultipartPartInfo.ContentTypeSource
308319
if let encoding, let contentType = encoding.contentTypes.first, encoding.contentTypes.count == 1 {
309320
finalContentTypeSource = try .explicit(contentType.asGeneratorContentType)

Tests/OpenAPIGeneratorReferenceTests/SnippetBasedReferenceTests.swift

Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4225,6 +4225,199 @@ final class SnippetBasedReferenceTests: XCTestCase {
42254225
)
42264226
}
42274227

4228+
func testRequestMultipartBodyReferencedSchemaRecursive() throws {
4229+
try self.assertRequestInTypesClientServerTranslation(
4230+
"""
4231+
/foo:
4232+
post:
4233+
requestBody:
4234+
required: true
4235+
content:
4236+
multipart/form-data:
4237+
schema:
4238+
$ref: '#/components/schemas/NodeWrapper'
4239+
responses:
4240+
default:
4241+
description: Response
4242+
""",
4243+
"""
4244+
schemas:
4245+
NodeWrapper:
4246+
type: object
4247+
properties:
4248+
node:
4249+
$ref: '#/components/schemas/Node'
4250+
Node:
4251+
type: object
4252+
properties:
4253+
parent:
4254+
$ref: '#/components/schemas/Node'
4255+
""",
4256+
input: """
4257+
public struct Input: Sendable, Hashable {
4258+
@frozen public enum Body: Sendable, Hashable {
4259+
case multipartForm(OpenAPIRuntime.MultipartBody<Components.Schemas.NodeWrapper>)
4260+
}
4261+
public var body: Operations.post_sol_foo.Input.Body
4262+
public init(body: Operations.post_sol_foo.Input.Body) {
4263+
self.body = body
4264+
}
4265+
}
4266+
""",
4267+
schemas: """
4268+
public enum Schemas {
4269+
@frozen public enum NodeWrapper: Sendable, Hashable {
4270+
public struct nodePayload: Sendable, Hashable {
4271+
public var body: Components.Schemas.Node
4272+
public init(body: Components.Schemas.Node) {
4273+
self.body = body
4274+
}
4275+
}
4276+
case node(OpenAPIRuntime.MultipartPart<Components.Schemas.NodeWrapper.nodePayload>)
4277+
case undocumented(OpenAPIRuntime.MultipartRawPart)
4278+
}
4279+
public struct Node: Codable, Hashable, Sendable {
4280+
public var parent: Components.Schemas.Node? {
4281+
get {
4282+
self.storage.value.parent
4283+
}
4284+
_modify {
4285+
yield &self.storage.value.parent
4286+
}
4287+
}
4288+
public init(parent: Components.Schemas.Node? = nil) {
4289+
self.storage = .init(value: .init(parent: parent))
4290+
}
4291+
public enum CodingKeys: String, CodingKey {
4292+
case parent
4293+
}
4294+
public init(from decoder: any Decoder) throws {
4295+
self.storage = try .init(from: decoder)
4296+
}
4297+
public func encode(to encoder: any Encoder) throws {
4298+
try self.storage.encode(to: encoder)
4299+
}
4300+
private var storage: OpenAPIRuntime.CopyOnWriteBox<Storage>
4301+
private struct Storage: Codable, Hashable, Sendable {
4302+
var parent: Components.Schemas.Node?
4303+
init(parent: Components.Schemas.Node? = nil) {
4304+
self.parent = parent
4305+
}
4306+
typealias CodingKeys = Components.Schemas.Node.CodingKeys
4307+
}
4308+
}
4309+
}
4310+
""",
4311+
client: """
4312+
{ input in
4313+
let path = try converter.renderedPath(
4314+
template: "/foo",
4315+
parameters: []
4316+
)
4317+
var request: HTTPTypes.HTTPRequest = .init(
4318+
soar_path: path,
4319+
method: .post
4320+
)
4321+
suppressMutabilityWarning(&request)
4322+
let body: OpenAPIRuntime.HTTPBody?
4323+
switch input.body {
4324+
case let .multipartForm(value):
4325+
body = try converter.setRequiredRequestBodyAsMultipart(
4326+
value,
4327+
headerFields: &request.headerFields,
4328+
contentType: "multipart/form-data",
4329+
allowsUnknownParts: true,
4330+
requiredExactlyOncePartNames: [],
4331+
requiredAtLeastOncePartNames: [],
4332+
atMostOncePartNames: [
4333+
"node"
4334+
],
4335+
zeroOrMoreTimesPartNames: [],
4336+
encoding: { part in
4337+
switch part {
4338+
case let .node(wrapped):
4339+
var headerFields: HTTPTypes.HTTPFields = .init()
4340+
let value = wrapped.payload
4341+
let body = try converter.setRequiredRequestBodyAsJSON(
4342+
value.body,
4343+
headerFields: &headerFields,
4344+
contentType: "application/json; charset=utf-8"
4345+
)
4346+
return .init(
4347+
name: "node",
4348+
filename: wrapped.filename,
4349+
headerFields: headerFields,
4350+
body: body
4351+
)
4352+
case let .undocumented(value):
4353+
return value
4354+
}
4355+
}
4356+
)
4357+
}
4358+
return (request, body)
4359+
}
4360+
""",
4361+
server: """
4362+
{ request, requestBody, metadata in
4363+
let contentType = converter.extractContentTypeIfPresent(in: request.headerFields)
4364+
let body: Operations.post_sol_foo.Input.Body
4365+
let chosenContentType = try converter.bestContentType(
4366+
received: contentType,
4367+
options: [
4368+
"multipart/form-data"
4369+
]
4370+
)
4371+
switch chosenContentType {
4372+
case "multipart/form-data":
4373+
body = try converter.getRequiredRequestBodyAsMultipart(
4374+
OpenAPIRuntime.MultipartBody<Components.Schemas.NodeWrapper>.self,
4375+
from: requestBody,
4376+
transforming: { value in
4377+
.multipartForm(value)
4378+
},
4379+
boundary: contentType.requiredBoundary(),
4380+
allowsUnknownParts: true,
4381+
requiredExactlyOncePartNames: [],
4382+
requiredAtLeastOncePartNames: [],
4383+
atMostOncePartNames: [
4384+
"node"
4385+
],
4386+
zeroOrMoreTimesPartNames: [],
4387+
decoding: { part in
4388+
let headerFields = part.headerFields
4389+
let (name, filename) = try converter.extractContentDispositionNameAndFilename(in: headerFields)
4390+
switch name {
4391+
case "node":
4392+
try converter.verifyContentTypeIfPresent(
4393+
in: headerFields,
4394+
matches: "application/json"
4395+
)
4396+
let body = try await converter.getRequiredRequestBodyAsJSON(
4397+
Components.Schemas.Node.self,
4398+
from: part.body,
4399+
transforming: {
4400+
$0
4401+
}
4402+
)
4403+
return .node(.init(
4404+
payload: .init(body: body),
4405+
filename: filename
4406+
))
4407+
default:
4408+
return .undocumented(part)
4409+
}
4410+
}
4411+
)
4412+
default:
4413+
preconditionFailure("bestContentType chose an invalid content type.")
4414+
}
4415+
return Operations.post_sol_foo.Input(body: body)
4416+
}
4417+
"""
4418+
)
4419+
}
4420+
42284421
func testRequestMultipartBodyReferencedSchemaWithEncoding() throws {
42294422
try self.assertRequestInTypesClientServerTranslation(
42304423
"""

0 commit comments

Comments
 (0)