Skip to content

Commit 769cf10

Browse files
quaffmarkpollack
authored andcommitted
JdbcChatMemoryRepository should use the provided JdbcTemplate
Before this commit, the underlying `JdbcTemplate` is created like `new JdbcTemplate(providedJdbcTemplate.getDataSource())`, it means that settings on provided `JdbcTemplate` will lose. Signed-off-by: Yanming Zhou <[email protected]>
1 parent b83a162 commit 769cf10

File tree

2 files changed

+29
-5
lines changed

2 files changed

+29
-5
lines changed

memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/main/java/org/springframework/ai/chat/memory/repository/jdbc/JdbcChatMemoryRepository.java

+17-5
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
* @author Thomas Vitale
5454
* @author Linar Abzaltdinov
5555
* @author Mark Pollack
56+
* @author Yanming Zhou
5657
* @since 1.0.0
5758
*/
5859
public final class JdbcChatMemoryRepository implements ChatMemoryRepository {
@@ -65,14 +66,14 @@ public final class JdbcChatMemoryRepository implements ChatMemoryRepository {
6566

6667
private static final Logger logger = LoggerFactory.getLogger(JdbcChatMemoryRepository.class);
6768

68-
private JdbcChatMemoryRepository(DataSource dataSource, JdbcChatMemoryRepositoryDialect dialect,
69+
private JdbcChatMemoryRepository(JdbcTemplate jdbcTemplate, JdbcChatMemoryRepositoryDialect dialect,
6970
PlatformTransactionManager txManager) {
70-
Assert.notNull(dataSource, "dataSource cannot be null");
71+
Assert.notNull(jdbcTemplate, "jdbcTemplate cannot be null");
7172
Assert.notNull(dialect, "dialect cannot be null");
72-
this.jdbcTemplate = new JdbcTemplate(dataSource);
73+
this.jdbcTemplate = jdbcTemplate;
7374
this.dialect = dialect;
7475
this.transactionTemplate = new TransactionTemplate(
75-
txManager != null ? txManager : new DataSourceTransactionManager(dataSource));
76+
txManager != null ? txManager : new DataSourceTransactionManager(jdbcTemplate.getDataSource()));
7677
}
7778

7879
@Override
@@ -192,7 +193,18 @@ public Builder transactionManager(PlatformTransactionManager txManager) {
192193
public JdbcChatMemoryRepository build() {
193194
DataSource effectiveDataSource = resolveDataSource();
194195
JdbcChatMemoryRepositoryDialect effectiveDialect = resolveDialect(effectiveDataSource);
195-
return new JdbcChatMemoryRepository(effectiveDataSource, effectiveDialect, this.platformTransactionManager);
196+
return new JdbcChatMemoryRepository(resolveJdbcTemplate(), effectiveDialect,
197+
this.platformTransactionManager);
198+
}
199+
200+
private JdbcTemplate resolveJdbcTemplate() {
201+
if (this.jdbcTemplate != null) {
202+
return this.jdbcTemplate;
203+
}
204+
if (this.dataSource != null) {
205+
return new JdbcTemplate(this.dataSource);
206+
}
207+
throw new IllegalArgumentException("DataSource must be set (either via dataSource() or jdbcTemplate())");
196208
}
197209

198210
private DataSource resolveDataSource() {

memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/test/java/org/springframework/ai/chat/memory/repository/jdbc/JdbcChatMemoryRepositoryBuilderTests.java

+12
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
import org.junit.jupiter.api.Test;
2626

27+
import org.springframework.jdbc.core.JdbcTemplate;
2728
import org.springframework.transaction.PlatformTransactionManager;
2829

2930
import static org.assertj.core.api.Assertions.assertThat;
@@ -35,6 +36,7 @@
3536
* Tests for {@link JdbcChatMemoryRepository.Builder}.
3637
*
3738
* @author Mark Pollack
39+
* @author Yanming Zhou
3840
*/
3941
public class JdbcChatMemoryRepositoryBuilderTests {
4042

@@ -224,4 +226,14 @@ void testBuilderPreferenceForExplicitDialect() throws SQLException {
224226
// for this)
225227
}
226228

229+
@Test
230+
void repositoryShouldUseProvidedJdbcTemplate() throws SQLException {
231+
DataSource dataSource = mock(DataSource.class);
232+
JdbcTemplate jdbcTemplate = new JdbcTemplate(dataSource);
233+
234+
JdbcChatMemoryRepository repository = JdbcChatMemoryRepository.builder().jdbcTemplate(jdbcTemplate).build();
235+
236+
assertThat(repository).extracting("jdbcTemplate").isSameAs(jdbcTemplate);
237+
}
238+
227239
}

0 commit comments

Comments
 (0)