Skip to content

Adding readMany() support for findAllByIds() to improve performance. #43759

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Feb 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sdk/spring/azure-spring-data-cosmos/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
### 5.21.0-beta.1 (Unreleased)

#### Features Added
* Added `readMany()` API support to `findAllByIds()` - See [PR 43759](https://github.com/Azure/azure-sdk-for-java/pull/43759).

#### Breaking Changes

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@
import com.azure.cosmos.models.CosmosContainerProperties;
import com.azure.cosmos.models.CosmosContainerResponse;
import com.azure.cosmos.models.CosmosDatabaseResponse;
import com.azure.cosmos.models.CosmosItemIdentity;
import com.azure.cosmos.models.CosmosItemOperation;
import com.azure.cosmos.models.CosmosItemRequestOptions;
import com.azure.cosmos.models.CosmosItemResponse;
import com.azure.cosmos.models.CosmosPatchItemRequestOptions;
import com.azure.cosmos.models.CosmosPatchOperations;
import com.azure.cosmos.models.CosmosQueryRequestOptions;
import com.azure.cosmos.models.CosmosReadManyRequestOptions;
import com.azure.cosmos.models.FeedResponse;
import com.azure.cosmos.models.PartitionKey;
import com.azure.cosmos.models.PartitionKeyDefinition;
Expand Down Expand Up @@ -856,13 +858,40 @@ public <T, ID> Iterable<T> findByIds(Iterable<ID> ids, Class<T> domainType, Stri
Assert.notNull(ids, "Id list should not be null");
Assert.notNull(domainType, "domainType should not be null.");
Assert.hasText(containerName, "container should not be null, empty or only whitespaces");
final List<Object> idList = new ArrayList<>();
for (ID id : ids) {
idList.add(CosmosUtils.getStringIDValue(id));

CosmosEntityInformation<?, ?> cosmosEntityInformation = CosmosEntityInformation.getInstance(domainType);
String containerPartitionKey = cosmosEntityInformation.getPartitionKeyFieldName();
if ("id".equals(containerPartitionKey) && ids.iterator().next() != null) {
List<CosmosItemIdentity> idList = new ArrayList<>();
for (ID id : ids) {
idList.add(new CosmosItemIdentity(new PartitionKey(id), String.valueOf(id)));
}

final CosmosReadManyRequestOptions cosmosReadManyRequestOptions = new CosmosReadManyRequestOptions();
containerName = getContainerNameOverride(containerName);
cosmosReadManyRequestOptions.setQueryMetricsEnabled(this.queryMetricsEnabled);
cosmosReadManyRequestOptions.setIndexMetricsEnabled(this.indexMetricsEnabled);
cosmosReadManyRequestOptions.setResponseContinuationTokenLimitInKb(this.responseContinuationTokenLimitInKb);

return this.getCosmosAsyncClient()
.getDatabase(this.getDatabaseName())
.getContainer(containerName)
.readMany(idList, cosmosReadManyRequestOptions, domainType)
.publishOn(CosmosSchedulers.SPRING_DATA_COSMOS_PARALLEL)
.onErrorResume(throwable ->
CosmosExceptionUtils.exceptionHandler("Failed to find items", throwable,
this.responseDiagnosticsProcessor))
.block().getResults();
} else {
final List<Object> idList = new ArrayList<>();
for (ID id : ids) {
idList.add(CosmosUtils.getStringIDValue(id));
}
final CosmosQuery query = new CosmosQuery(Criteria.getInstance(CriteriaType.IN, "id",
Collections.singletonList(idList), Part.IgnoreCaseType.NEVER));
return find(query, domainType, containerName);
}
final CosmosQuery query = new CosmosQuery(Criteria.getInstance(CriteriaType.IN, "id",
Collections.singletonList(idList), Part.IgnoreCaseType.NEVER));
return find(query, domainType, containerName);

}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
package com.azure.spring.data.cosmos.repository.integration;

import com.azure.spring.data.cosmos.IntegrationTestCollectionManager;
import com.azure.spring.data.cosmos.common.ResponseDiagnosticsTestUtils;
import com.azure.spring.data.cosmos.common.TestUtils;
import com.azure.spring.data.cosmos.config.CosmosConfig;
import com.azure.spring.data.cosmos.core.CosmosTemplate;
import com.azure.spring.data.cosmos.domain.Address;
import com.azure.spring.data.cosmos.domain.BasicItem;
import com.azure.spring.data.cosmos.repository.TestRepositoryConfig;
import com.azure.spring.data.cosmos.repository.repository.BasicItemRepository;
import org.assertj.core.util.Lists;
import org.junit.Assert;
import org.junit.Before;
import org.junit.ClassRule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.test.context.ContextConfiguration;
import org.springframework.test.context.junit4.SpringJUnit4ClassRunner;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;

import static com.azure.spring.data.cosmos.common.TestConstants.ID_1;
import static com.azure.spring.data.cosmos.common.TestConstants.ID_2;

@RunWith(SpringJUnit4ClassRunner.class)
@ContextConfiguration(classes = TestRepositoryConfig.class)
public class BasicItemRepositoryIT {

@ClassRule
public static final IntegrationTestCollectionManager collectionManager = new IntegrationTestCollectionManager();

@Autowired
BasicItemRepository repository;

@Autowired
CosmosConfig cosmosConfig;

@Autowired
private CosmosTemplate template;

@Autowired
private ResponseDiagnosticsTestUtils responseDiagnosticsTestUtils;

private static final BasicItem BASIC_ITEM_1 = new BasicItem(ID_1);

private static final BasicItem BASIC_ITEM_2 = new BasicItem(ID_2);

@Before
public void setUp() {
collectionManager.ensureContainersCreatedAndEmpty(template, Address.class);
repository.saveAll(Lists.newArrayList(BASIC_ITEM_1, BASIC_ITEM_2));
}

@Test
public void testFindAllById() {
final Iterable<BasicItem> allById =
TestUtils.toList(this.repository.findAllById(Arrays.asList(BASIC_ITEM_1.getId(), BASIC_ITEM_2.getId())));
Assert.assertTrue(((ArrayList) allById).size() == 2);
Iterator<BasicItem> it = allById.iterator();
Assert.assertEquals(BASIC_ITEM_1, it.next());
Assert.assertEquals(BASIC_ITEM_2, it.next());
}
}