Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
<parent>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-parent</artifactId>
<version>4.0.1</version>
<version>4.0.2</version>
<relativePath/> <!-- lookup parent from repository -->
</parent>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,5 @@
@Retention(RetentionPolicy.RUNTIME)
@NullMarked
public @interface ForceSwaggerSchema {

boolean includeSubTypes() default true;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package it.aboutbits.springboot.toolbox.swagger.annotation;

import org.jspecify.annotations.NullMarked;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
* Annotation to mark classes that should be ignored when forcing Swagger schemas.
*/
@Target(ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME)
@NullMarked
public @interface ForceSwaggerSchemaIgnore {
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,17 @@
import io.swagger.v3.oas.models.media.Schema;
import it.aboutbits.springboot.toolbox.reflection.util.ClassScannerUtil;
import it.aboutbits.springboot.toolbox.swagger.annotation.ForceSwaggerSchema;
import it.aboutbits.springboot.toolbox.swagger.annotation.ForceSwaggerSchemaIgnore;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.jspecify.annotations.NullMarked;
import org.springdoc.core.customizers.OpenApiCustomizer;

import java.lang.reflect.Modifier;
import java.util.Arrays;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.Set;

@RequiredArgsConstructor
@Slf4j
Expand Down Expand Up @@ -42,7 +46,31 @@ private void addAnnotatedSchemas(OpenAPI openAPI) {
// Scan for classes with @ForceSwaggerSchema annotation
var annotatedClasses = classScanner.getClassesAnnotatedWith(ForceSwaggerSchema.class);

var classesToProcess = new HashSet<Class<?>>();
for (var clazz : annotatedClasses) {
if (clazz.isAnnotationPresent(ForceSwaggerSchemaIgnore.class)) {
continue;
}
classesToProcess.add(clazz);
var annotation = clazz.getAnnotation(ForceSwaggerSchema.class);
if (annotation != null && annotation.includeSubTypes()) {
var subTypes = classScanner.getSubTypesOf(clazz);
for (var subType : subTypes) {
if (!subType.isAnnotationPresent(ForceSwaggerSchemaIgnore.class)) {
classesToProcess.add(subType);
}
}

collectPublicNestedTypes(clazz, classesToProcess);
for (var subType : subTypes) {
if (classesToProcess.contains(subType)) {
collectPublicNestedTypes(subType, classesToProcess);
}
}
}
}

for (var clazz : classesToProcess) {
log.info("Forcing schema for class: {}", clazz.getName());

if (clazz.isEnum()) {
Expand Down Expand Up @@ -77,5 +105,15 @@ private void addAnnotatedSchemas(OpenAPI openAPI) {
log.debug("Scanned packages: {}", String.join(", ", classScanner.getScannedPackages()));
}
}

private void collectPublicNestedTypes(Class<?> clazz, Set<Class<?>> collected) {
for (var nested : clazz.getDeclaredClasses()) {
if (Modifier.isPublic(nested.getModifiers()) && !nested.isAnnotationPresent(ForceSwaggerSchemaIgnore.class)) {
if (collected.add(nested)) {
collectPublicNestedTypes(nested, collected);
}
}
}
}
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
package it.aboutbits.springboot.toolbox.swagger.customization.force_schema;

import com.fasterxml.jackson.databind.ObjectMapper;
import io.swagger.v3.core.jackson.ModelResolver;
import io.swagger.v3.oas.models.OpenAPI;
import it.aboutbits.springboot.toolbox.reflection.util.ClassScannerUtil;
import it.aboutbits.springboot.toolbox.swagger.annotation.ForceSwaggerSchema;
import it.aboutbits.springboot.toolbox.swagger.annotation.ForceSwaggerSchemaIgnore;
import org.jspecify.annotations.NullUnmarked;
import org.junit.jupiter.api.Test;

import java.util.Set;

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

@NullUnmarked
class ForceSchemaCustomizerTest {

@ForceSwaggerSchema(includeSubTypes = true)
public static class AnnotatedClass {
public String field;

public static class InnerClass {
public String innerField;
}

public record InnerRecord(String recordField) {
}

@SuppressWarnings("checkstyle:FinalClass")
private static class PrivateInnerClass {
public String privateField;
}
}

@ForceSwaggerSchema(includeSubTypes = false)
public static class AnnotatedClassWithoutSubTypes {
public String field;

public static class InnerClass {
public String innerField;
}
}

@ForceSwaggerSchema(includeSubTypes = true)
public static class BaseClass {
}

public static class SubClass extends BaseClass {
public static class SubInnerClass {
}
}

@ForceSwaggerSchema(includeSubTypes = true)
public static class ClassWithIgnoredMembers {
public String field;

@ForceSwaggerSchemaIgnore
public static class IgnoredInnerClass {
public String innerField;
}

public static class NotIgnoredInnerClass {
public String innerField;
}
}

@ForceSwaggerSchema(includeSubTypes = true)
public static class BaseWithIgnoredSubClass {
}

@ForceSwaggerSchemaIgnore
public static class IgnoredSubClass extends BaseWithIgnoredSubClass {
}

public static class NotIgnoredSubClass extends BaseWithIgnoredSubClass {
}

@ForceSwaggerSchema
@ForceSwaggerSchemaIgnore
public static class AnnotatedAndIgnored {
}

@Test
void shouldIncludeSubTypesWhenEnabled() {
var classScanner = mock(ClassScannerUtil.ClassScanner.class);
when(classScanner.getClassesAnnotatedWith(ForceSwaggerSchema.class)).thenReturn(Set.of(
AnnotatedClass.class,
BaseClass.class
));
when(classScanner.getSubTypesOf(BaseClass.class)).thenReturn(Set.of(SubClass.class));

var modelResolver = new ModelResolver(new ObjectMapper());
var customizer = new ForceSchemaCustomizer(modelResolver, classScanner);
var openApi = new OpenAPI();

customizer.customise(openApi);

var schemas = openApi.getComponents().getSchemas();
assertThat(schemas).containsKey(AnnotatedClass.class.getSimpleName());
assertThat(schemas).containsKey(AnnotatedClass.InnerClass.class.getSimpleName());
assertThat(schemas).containsKey(AnnotatedClass.InnerRecord.class.getSimpleName());
assertThat(schemas).doesNotContainKey(AnnotatedClass.PrivateInnerClass.class.getSimpleName());

assertThat(schemas).containsKey(BaseClass.class.getSimpleName());
assertThat(schemas).containsKey(SubClass.class.getSimpleName());
assertThat(schemas).containsKey(SubClass.SubInnerClass.class.getSimpleName());
}

@Test
void shouldNotIncludeSubTypesWhenDisabled() {
var classScanner = mock(ClassScannerUtil.ClassScanner.class);
when(classScanner.getClassesAnnotatedWith(ForceSwaggerSchema.class)).thenReturn(Set.of(
AnnotatedClassWithoutSubTypes.class));

var modelResolver = new ModelResolver(new ObjectMapper());
var customizer = new ForceSchemaCustomizer(modelResolver, classScanner);
var openApi = new OpenAPI();

customizer.customise(openApi);

var schemas = openApi.getComponents().getSchemas();
assertThat(schemas).containsKey(AnnotatedClassWithoutSubTypes.class.getSimpleName());
assertThat(schemas).doesNotContainKey(AnnotatedClassWithoutSubTypes.InnerClass.class.getSimpleName());
}

@Test
void shouldExcludeIgnoredClasses() {
var classScanner = mock(ClassScannerUtil.ClassScanner.class);
when(classScanner.getClassesAnnotatedWith(ForceSwaggerSchema.class)).thenReturn(Set.of(
ClassWithIgnoredMembers.class,
BaseWithIgnoredSubClass.class
));
when(classScanner.getSubTypesOf(BaseWithIgnoredSubClass.class)).thenReturn(Set.of(
IgnoredSubClass.class,
NotIgnoredSubClass.class
));

var modelResolver = new ModelResolver(new ObjectMapper());
var customizer = new ForceSchemaCustomizer(modelResolver, classScanner);
var openApi = new OpenAPI();

customizer.customise(openApi);

var schemas = openApi.getComponents().getSchemas();
assertThat(schemas).containsKey(ClassWithIgnoredMembers.class.getSimpleName());
assertThat(schemas).containsKey(ClassWithIgnoredMembers.NotIgnoredInnerClass.class.getSimpleName());
assertThat(schemas).doesNotContainKey(ClassWithIgnoredMembers.IgnoredInnerClass.class.getSimpleName());

assertThat(schemas).containsKey(BaseWithIgnoredSubClass.class.getSimpleName());
assertThat(schemas).containsKey(NotIgnoredSubClass.class.getSimpleName());
assertThat(schemas).doesNotContainKey(IgnoredSubClass.class.getSimpleName());
}

@Test
void shouldExcludeClassWhenBothAnnotatedAndIgnored() {
var classScanner = mock(ClassScannerUtil.ClassScanner.class);
when(classScanner.getClassesAnnotatedWith(ForceSwaggerSchema.class)).thenReturn(Set.of(
AnnotatedAndIgnored.class
));

var modelResolver = new ModelResolver(new ObjectMapper());
var customizer = new ForceSchemaCustomizer(modelResolver, classScanner);
var openApi = new OpenAPI();

customizer.customise(openApi);

var schemas = openApi.getComponents().getSchemas();
assertThat(schemas).doesNotContainKey(AnnotatedAndIgnored.class.getSimpleName());
}
}