|
12 | 12 | from collections.abc import Generator |
13 | 13 | from contextlib import contextmanager |
14 | 14 | from types import TracebackType |
15 | | -from typing import Any, Literal |
| 15 | +from typing import TYPE_CHECKING, Any, Literal |
16 | 16 |
|
17 | 17 | import sqlite_vec |
18 | 18 |
|
|
29 | 29 | validate_top_k, |
30 | 30 | ) |
31 | 31 |
|
| 32 | +if TYPE_CHECKING: |
| 33 | + from .pool import ConnectionPool |
| 34 | + |
32 | 35 | logger = get_logger() |
33 | 36 |
|
34 | 37 |
|
@@ -91,22 +94,34 @@ def rows_to_results(rows: list[sqlite3.Row]) -> list[Result]: |
91 | 94 | for row in rows |
92 | 95 | ] |
93 | 96 |
|
94 | | - def __init__(self, table: str, db_path: str) -> None: |
| 97 | + def __init__( |
| 98 | + self, table: str, db_path: str | None = None, pool: ConnectionPool | None = None |
| 99 | + ) -> None: |
95 | 100 | """Initialize the client for a given base table and database file. |
96 | 101 |
|
97 | 102 | Args: |
98 | 103 | table: Name of the base table |
99 | | - db_path: Path to SQLite database file |
| 104 | + db_path: Path to SQLite database file (required if pool is None) |
| 105 | + pool: Optional connection pool for connection reuse |
100 | 106 |
|
101 | 107 | Raises: |
102 | 108 | TableNameError: If table name is invalid |
103 | 109 | VecConnectionError: If connection fails |
| 110 | + ValueError: If both db_path and pool are None |
104 | 111 | """ |
105 | 112 | validate_table_name(table) |
106 | 113 | self.table = table |
107 | 114 | self._in_transaction = False |
| 115 | + self._pool = pool |
108 | 116 | logger.debug(f"Initializing SQLiteVecClient for table: {table}") |
109 | | - self.connection = self.create_connection(db_path) |
| 117 | + |
| 118 | + if pool: |
| 119 | + self.connection = pool.get_connection() |
| 120 | + logger.debug("Using connection from pool") |
| 121 | + elif db_path: |
| 122 | + self.connection = self.create_connection(db_path) |
| 123 | + else: |
| 124 | + raise ValueError("Either db_path or pool must be provided") |
110 | 125 |
|
111 | 126 | def __enter__(self) -> SQLiteVecClient: |
112 | 127 | """Support context manager protocol and return `self`.""" |
@@ -518,10 +533,14 @@ def transaction(self) -> Generator[None, None, None]: |
518 | 533 | self._in_transaction = False |
519 | 534 |
|
520 | 535 | def close(self) -> None: |
521 | | - """Close the underlying SQLite connection, suppressing close errors.""" |
| 536 | + """Close or return the connection to pool, suppressing close errors.""" |
522 | 537 | try: |
523 | 538 | logger.debug(f"Closing connection for table '{self.table}'") |
524 | | - self.connection.close() |
525 | | - logger.info(f"Connection closed for table '{self.table}'") |
| 539 | + if self._pool: |
| 540 | + self._pool.return_connection(self.connection) |
| 541 | + logger.info(f"Connection returned to pool for table '{self.table}'") |
| 542 | + else: |
| 543 | + self.connection.close() |
| 544 | + logger.info(f"Connection closed for table '{self.table}'") |
526 | 545 | except Exception as e: |
527 | 546 | logger.warning(f"Error closing connection: {e}") |
0 commit comments