Skip to content

Commit 72d9361

Browse files
christophstroblmp911de
authored andcommitted
Update AOT fragment generation to align with reflective behavior.
Closes: #5027 Original pull request: #5038
1 parent bbefb0b commit 72d9361

21 files changed

+519
-137
lines changed

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,9 +111,11 @@
111111
import org.springframework.data.mongodb.core.timeseries.Granularity;
112112
import org.springframework.data.mongodb.core.validation.Validator;
113113
import org.springframework.data.projection.EntityProjection;
114+
import org.springframework.data.projection.ProjectionFactory;
114115
import org.springframework.data.util.CloseableIterator;
115116
import org.springframework.data.util.Lazy;
116117
import org.springframework.data.util.Optionals;
118+
import org.springframework.data.util.TypeInformation;
117119
import org.springframework.lang.Contract;
118120
import org.springframework.util.Assert;
119121
import org.springframework.util.ClassUtils;
@@ -2272,8 +2274,17 @@ protected <O> AggregationResults<O> doAggregate(Aggregation aggregation, String
22722274
<T, O> AggregationResults<O> doAggregate(Aggregation aggregation, String collectionName, Class<T> outputType,
22732275
QueryResultConverter<? super T, ? extends O> resultConverter, AggregationOperationContext context) {
22742276

2275-
DocumentCallback<O> callback = new QueryResultConverterCallback<>(resultConverter,
2277+
final DocumentCallback<O> callback;
2278+
if(aggregation instanceof TypedAggregation<?> ta && outputType.isInterface()) {
2279+
EntityProjection<T, ?> projection = operations.introspectProjection(outputType, ta.getInputType());
2280+
ProjectingReadCallback cb = new ProjectingReadCallback(mongoConverter, projection, collectionName);
2281+
callback = new QueryResultConverterCallback<>(resultConverter,
2282+
cb);
2283+
} else {
2284+
2285+
callback = new QueryResultConverterCallback<>(resultConverter,
22762286
new ReadDocumentCallback<>(mongoConverter, outputType, collectionName));
2287+
}
22772288

22782289
AggregationOptions options = aggregation.getOptions();
22792290
AggregationUtil aggregationUtil = new AggregationUtil(queryMapper, mappingContext);

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/QueryOperations.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,9 @@ Document getMappedFields(@Nullable MongoPersistentEntity<?> entity, EntityProjec
377377
mappedFields = queryMapper.getMappedFields(fields, entity);
378378
} else {
379379
mappedFields = propertyOperations.computeMappedFieldsForProjection(projection, fields);
380+
if(projection.getMappedType().getType().isInterface()) {
381+
mappedFields = queryMapper.getMappedFields(mappedFields, entity);
382+
}
380383
mappedFields = queryMapper.addMetaAttributes(mappedFields, entity);
381384
}
382385

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AggregationBlocks.java

Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,13 @@
2222

2323
import org.bson.Document;
2424
import org.jspecify.annotations.NullUnmarked;
25-
2625
import org.springframework.core.ResolvableType;
2726
import org.springframework.core.annotation.MergedAnnotation;
2827
import org.springframework.data.domain.SliceImpl;
2928
import org.springframework.data.domain.Sort.Order;
3029
import org.springframework.data.mongodb.core.MongoOperations;
3130
import org.springframework.data.mongodb.core.aggregation.Aggregation;
31+
import org.springframework.data.mongodb.core.aggregation.AggregationOperation;
3232
import org.springframework.data.mongodb.core.aggregation.AggregationOptions;
3333
import org.springframework.data.mongodb.core.aggregation.AggregationPipeline;
3434
import org.springframework.data.mongodb.core.aggregation.AggregationResults;
@@ -80,12 +80,7 @@ CodeBlock build() {
8080

8181
builder.add("\n");
8282

83-
Class<?> outputType = queryMethod.getReturnedObjectType();
84-
if (MongoSimpleTypes.HOLDER.isSimpleType(outputType)) {
85-
outputType = Document.class;
86-
} else if (ClassUtils.isAssignable(AggregationResults.class, outputType)) {
87-
outputType = queryMethod.getReturnType().getComponentType().getType();
88-
}
83+
Class<?> outputType = getOutputType(queryMethod);
8984

9085
if (ReflectionUtils.isVoid(queryMethod.getReturnedObjectType())) {
9186
builder.addStatement("$L.aggregate($L, $T.class)", mongoOpsRef, aggregationVariableName, outputType);
@@ -146,7 +141,6 @@ CodeBlock build() {
146141
builder.addStatement("return $L.aggregateStream($L, $T.class)", mongoOpsRef, aggregationVariableName,
147142
outputType);
148143
} else {
149-
150144
builder.addStatement("return $L.aggregate($L, $T.class).getMappedResults()", mongoOpsRef,
151145
aggregationVariableName, outputType);
152146
}
@@ -155,6 +149,17 @@ CodeBlock build() {
155149

156150
return builder.build();
157151
}
152+
153+
}
154+
155+
private static Class<?> getOutputType(MongoQueryMethod queryMethod) {
156+
Class<?> outputType = queryMethod.getReturnedObjectType();
157+
if (MongoSimpleTypes.HOLDER.isSimpleType(outputType)) {
158+
outputType = Document.class;
159+
} else if (ClassUtils.isAssignable(AggregationResults.class, outputType) && queryMethod.getReturnType().getComponentType() != null) {
160+
outputType = queryMethod.getReturnType().getComponentType().getType();
161+
}
162+
return outputType;
158163
}
159164

160165
@NullUnmarked
@@ -173,13 +178,7 @@ static class AggregationCodeBlockBuilder {
173178

174179
this.context = context;
175180
this.queryMethod = queryMethod;
176-
String parameterNames = StringUtils.collectionToDelimitedString(context.getAllParameterNames(), ", ");
177-
178-
if (StringUtils.hasText(parameterNames)) {
179-
this.parameterNames = ", " + parameterNames;
180-
} else {
181-
this.parameterNames = "";
182-
}
181+
this.parameterNames = StringUtils.collectionToDelimitedString(context.getAllParameterNames(), ", ");
183182
}
184183

185184
AggregationCodeBlockBuilder stages(AggregationInteraction aggregation) {
@@ -231,7 +230,8 @@ private CodeBlock pipeline(String pipelineVariableName) {
231230
builder.add(aggregationStages(context.localVariable("stages"), source.stages()));
232231

233232
if (StringUtils.hasText(sortParameter)) {
234-
builder.add(sortingStage(sortParameter));
233+
Class<?> outputType = getOutputType(queryMethod);
234+
builder.add(sortingStage(sortParameter, outputType));
235235
}
236236

237237
if (StringUtils.hasText(limitParameter)) {
@@ -244,6 +244,7 @@ private CodeBlock pipeline(String pipelineVariableName) {
244244

245245
builder.addStatement("$T $L = createPipeline($L)", AggregationPipeline.class, pipelineVariableName,
246246
context.localVariable("stages"));
247+
247248
return builder.build();
248249
}
249250

@@ -312,7 +313,7 @@ private CodeBlock aggregationStages(String stageListVariableName, Collection<Str
312313
return builder.build();
313314
}
314315

315-
private CodeBlock sortingStage(String sortProvider) {
316+
private CodeBlock sortingStage(String sortProvider, Class<?> outputType) {
316317

317318
Builder builder = CodeBlock.builder();
318319

@@ -322,8 +323,17 @@ private CodeBlock sortingStage(String sortProvider) {
322323
builder.addStatement("$1L.append($2L.getProperty(), $2L.isAscending() ? 1 : -1);",
323324
context.localVariable("sortDocument"), context.localVariable("order"));
324325
builder.endControlFlow();
325-
builder.addStatement("stages.add(new $T($S, $L))", Document.class, "$sort",
326-
context.localVariable("sortDocument"));
326+
327+
if (outputType == Document.class || MongoSimpleTypes.HOLDER.isSimpleType(outputType)
328+
|| ClassUtils.isAssignable(context.getRepositoryInformation().getDomainType(), outputType)) {
329+
builder.addStatement("$L.add(new $T($S, $L))", context.localVariable("stages"), Document.class, "$sort",
330+
context.localVariable("sortDocument"));
331+
} else {
332+
builder.addStatement("$L.add(($T) _ctx -> new $T($S, _ctx.getMappedObject($L, $T.class)))",
333+
context.localVariable("stages"), AggregationOperation.class, Document.class, "$sort",
334+
context.localVariable("sortDocument"), outputType);
335+
}
336+
327337
builder.endControlFlow();
328338

329339
return builder.build();
@@ -333,7 +343,7 @@ private CodeBlock pagingStage(String pageableProvider, boolean slice) {
333343

334344
Builder builder = CodeBlock.builder();
335345

336-
builder.add(sortingStage(pageableProvider + ".getSort()"));
346+
builder.add(sortingStage(pageableProvider + ".getSort()", getOutputType(queryMethod)));
337347

338348
builder.beginControlFlow("if ($L.isPaged())", pageableProvider);
339349
builder.beginControlFlow("if ($L.getOffset() > 0)", pageableProvider);

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AotPlaceholders.java

Lines changed: 42 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import java.util.List;
1919

20+
import org.jspecify.annotations.Nullable;
2021
import org.springframework.data.geo.Box;
2122
import org.springframework.data.geo.Circle;
2223
import org.springframework.data.geo.Distance;
@@ -52,7 +53,7 @@ static Placeholder indexed(int position) {
5253
* @param type
5354
* @return
5455
*/
55-
public static Shape geoJson(int index, String type) {
56+
static Shape geoJson(int index, String type) {
5657
return new GeoJsonPlaceholder(index, type);
5758
}
5859

@@ -62,7 +63,7 @@ public static Shape geoJson(int index, String type) {
6263
* @param index zero-based index referring to the bindable method parameter.
6364
* @return
6465
*/
65-
public static Point point(int index) {
66+
static Point point(int index) {
6667
return new PointPlaceholder(index);
6768
}
6869

@@ -72,7 +73,7 @@ public static Point point(int index) {
7273
* @param index zero-based index referring to the bindable method parameter.
7374
* @return
7475
*/
75-
public static Shape circle(int index) {
76+
static Shape circle(int index) {
7677
return new CirclePlaceholder(index);
7778
}
7879

@@ -82,7 +83,7 @@ public static Shape circle(int index) {
8283
* @param index zero-based index referring to the bindable method parameter.
8384
* @return
8485
*/
85-
public static Shape box(int index) {
86+
static Shape box(int index) {
8687
return new BoxPlaceholder(index);
8788
}
8889

@@ -92,7 +93,7 @@ public static Shape box(int index) {
9293
* @param index zero-based index referring to the bindable method parameter.
9394
* @return
9495
*/
95-
public static Shape sphere(int index) {
96+
static Shape sphere(int index) {
9697
return new SpherePlaceholder(index);
9798
}
9899

@@ -102,20 +103,23 @@ public static Shape sphere(int index) {
102103
* @param index zero-based index referring to the bindable method parameter.
103104
* @return
104105
*/
105-
public static Shape polygon(int index) {
106+
static Shape polygon(int index) {
106107
return new PolygonPlaceholder(index);
107108
}
108109

110+
static RegexPlaceholder regex(int index, @Nullable String options) {
111+
return new RegexPlaceholder(index, options);
112+
}
113+
109114
/**
110115
* A placeholder expression used when rending queries to JSON.
111116
*
112117
* @since 5.0
113118
* @author Christoph Strobl
114119
*/
115-
public interface Placeholder {
120+
interface Placeholder {
116121

117122
String getValue();
118-
119123
}
120124

121125
/**
@@ -139,7 +143,7 @@ private static class PointPlaceholder extends Point implements Placeholder {
139143

140144
private final int index;
141145

142-
public PointPlaceholder(int index) {
146+
PointPlaceholder(int index) {
143147
super(Double.NaN, Double.NaN);
144148
this.index = index;
145149
}
@@ -184,7 +188,7 @@ private static class CirclePlaceholder extends Circle implements Placeholder {
184188

185189
private final int index;
186190

187-
public CirclePlaceholder(int index) {
191+
CirclePlaceholder(int index) {
188192
super(new PointPlaceholder(index), Distance.of(1, Metrics.NEUTRAL)); //
189193
this.index = index;
190194
}
@@ -205,7 +209,7 @@ private static class BoxPlaceholder extends Box implements Placeholder {
205209

206210
private final int index;
207211

208-
public BoxPlaceholder(int index) {
212+
BoxPlaceholder(int index) {
209213
super(new PointPlaceholder(index), new PointPlaceholder(index));
210214
this.index = index;
211215
}
@@ -226,7 +230,7 @@ private static class SpherePlaceholder extends Sphere implements Placeholder {
226230

227231
private final int index;
228232

229-
public SpherePlaceholder(int index) {
233+
SpherePlaceholder(int index) {
230234
super(new PointPlaceholder(index), Distance.of(1, Metrics.NEUTRAL)); //
231235
this.index = index;
232236
}
@@ -247,7 +251,7 @@ private static class PolygonPlaceholder extends Polygon implements Placeholder {
247251

248252
private final int index;
249253

250-
public PolygonPlaceholder(int index) {
254+
PolygonPlaceholder(int index) {
251255
super(new PointPlaceholder(index), new PointPlaceholder(index), new PointPlaceholder(index),
252256
new PointPlaceholder(index));
253257
this.index = index;
@@ -265,4 +269,29 @@ public String toString() {
265269

266270
}
267271

272+
static class RegexPlaceholder implements Placeholder {
273+
274+
private final int index;
275+
private final @Nullable String options;
276+
277+
RegexPlaceholder(int index, @Nullable String options) {
278+
this.index = index;
279+
this.options = options;
280+
}
281+
282+
@Nullable String regexOptions() {
283+
return options;
284+
}
285+
286+
@Override
287+
public String getValue() {
288+
return "?" + index;
289+
}
290+
291+
@Override
292+
public String toString() {
293+
return getValue();
294+
}
295+
}
296+
268297
}

0 commit comments

Comments
 (0)