|
5 | 5 | * you may not use this file except in compliance with the License.
|
6 | 6 | * You may obtain a copy of the License at
|
7 | 7 | *
|
8 |
| - * http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | + * https://www.apache.org/licenses/LICENSE-2.0 |
9 | 9 | *
|
10 | 10 | * Unless required by applicable law or agreed to in writing, software
|
11 | 11 | * distributed under the License is distributed on an "AS IS" BASIS,
|
|
16 | 16 | package com.example.spi.mongodb.atlas;
|
17 | 17 |
|
18 | 18 | import java.util.List;
|
19 |
| -import java.util.Objects; |
20 | 19 |
|
21 | 20 | import org.bson.Document;
|
22 | 21 | import org.springframework.beans.factory.annotation.Autowired;
|
| 22 | +import org.springframework.core.ResolvableType; |
23 | 23 | import org.springframework.data.domain.Limit;
|
24 | 24 | import org.springframework.data.mongodb.core.MongoOperations;
|
25 | 25 | import org.springframework.data.mongodb.core.aggregation.Aggregation;
|
26 |
| -import org.springframework.data.mongodb.core.query.CriteriaDefinition; |
27 |
| -import org.springframework.data.repository.core.support.RepositoryMethodContext; |
28 |
| -import org.springframework.lang.Nullable; |
| 26 | +import org.springframework.data.repository.core.RepositoryMetadata; |
| 27 | +import org.springframework.data.repository.core.RepositoryMethodContext; |
| 28 | +import org.springframework.data.repository.core.support.RepositoryMetadataAccess; |
29 | 29 |
|
30 |
| -class AtlasRepositoryFragment<T> implements AtlasRepository<T> { |
| 30 | +class AtlasRepositoryFragment<T> implements AtlasRepository<T>, RepositoryMetadataAccess { |
31 | 31 |
|
32 |
| - private MongoOperations mongoOperations; |
| 32 | + private final MongoOperations mongoOperations; |
33 | 33 |
|
34 | 34 | public AtlasRepositoryFragment(@Autowired MongoOperations mongoOperations) {
|
35 | 35 | this.mongoOperations = mongoOperations;
|
36 | 36 | }
|
37 | 37 |
|
38 | 38 | @Override
|
39 | 39 | @SuppressWarnings("unchecked")
|
40 |
| - public List<T> vectorSearch(String index, String path, List<Double> vector, Limit limit) { |
| 40 | + public List<T> vectorSearch(String index, String path, List<Double> vector) { |
41 | 41 |
|
42 |
| - RepositoryMethodContext metadata = RepositoryMethodContext.currentMethod(); |
| 42 | + RepositoryMethodContext methodContext = RepositoryMethodContext.getContext(); |
43 | 43 |
|
44 |
| - Class<?> domainType = metadata.getRepository().getDomainType(); |
45 |
| - System.out.println("domainType: " + domainType); |
| 44 | + Class<?> domainType = resolveDomainType(methodContext.getMetadata()); |
46 | 45 |
|
47 |
| - Document $vectorSearch = createDocument(index, path, vector, limit, null, null, null); |
| 46 | + Document $vectorSearch = createDocument(index, path, vector, Limit.of(10)); |
48 | 47 | Aggregation aggregation = Aggregation.newAggregation(ctx -> $vectorSearch);
|
49 | 48 |
|
50 |
| - return (List<T>) mongoOperations.aggregate(aggregation, mongoOperations.getCollectionName(domainType), domainType); |
| 49 | + return (List<T>) mongoOperations.aggregate(aggregation, mongoOperations.getCollectionName(domainType), domainType).getMappedResults(); |
51 | 50 | }
|
52 | 51 |
|
53 |
| - private static Document createDocument(String indexName, String path, List<Double> vector, Limit limit, @Nullable Boolean exact, @Nullable CriteriaDefinition filter, @Nullable Integer numCandidates) { |
| 52 | + @SuppressWarnings("unchecked") |
| 53 | + private static <T> Class<T> resolveDomainType(RepositoryMetadata metadata) { |
| 54 | + |
| 55 | + // resolve the actual generic type argument of the AtlasRepository<T>. |
| 56 | + return (Class<T>) ResolvableType.forClass(metadata.getRepositoryInterface()) |
| 57 | + .as(AtlasRepository.class) |
| 58 | + .getGeneric(0) |
| 59 | + .resolve(); |
| 60 | + } |
| 61 | + |
| 62 | + private static Document createDocument(String indexName, String path, List<Double> vector, Limit limit) { |
54 | 63 |
|
55 | 64 | Document $vectorSearch = new Document();
|
56 | 65 |
|
57 | 66 | $vectorSearch.append("index", indexName);
|
58 | 67 | $vectorSearch.append("path", path);
|
59 | 68 | $vectorSearch.append("queryVector", vector);
|
60 | 69 | $vectorSearch.append("limit", limit.max());
|
61 |
| - |
62 |
| - if (exact != null) { |
63 |
| - $vectorSearch.append("exact", exact); |
64 |
| - } |
65 |
| - |
66 |
| - if (filter != null) { |
67 |
| - $vectorSearch.append("filter", filter.getCriteriaObject()); |
68 |
| - } |
69 |
| - |
70 |
| - if (numCandidates != null) { |
71 |
| - $vectorSearch.append("numCandidates", numCandidates); |
72 |
| - } |
| 70 | + $vectorSearch.append("numCandidates", 150); |
73 | 71 |
|
74 | 72 | return new Document("$vectorSearch", $vectorSearch);
|
75 | 73 | }
|
|
0 commit comments