Skip to content

Commit b5c22d8

Browse files
authored
feat: row_count and unique columns (#14)
Signed-off-by: cutecutecat <[email protected]>
1 parent 5e175a9 commit b5c22d8

File tree

8 files changed

+238
-25
lines changed

8 files changed

+238
-25
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -328,13 +328,13 @@ from pgvecto_rs.sdk import PGVectoRs, Record
328328
# Create a client
329329
client = PGVectoRs(
330330
db_url="postgresql+psycopg://postgres:mysecretpassword@localhost:5432/postgres",
331-
table_name="example",
331+
collection_name="example",
332332
dimension=3,
333333
)
334334

335335
try:
336336
# Add some records
337-
client.add_records(
337+
client.insert(
338338
[
339339
Record.from_text("hello 1", [1, 2, 3]),
340340
Record.from_text("hello 2", [1, 2, 4]),

examples/sdk_example.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@ def embed(text: str):
4040
client.insert(records1)
4141
client.insert(records2)
4242

43+
# Count rows
44+
client.row_count(estimate=True)
45+
4346
# Query (With a filter from the filters module)
4447
print("#################### First Query ####################")
4548
for record, dis in client.search(

src/pgvecto_rs/errors.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,3 +65,10 @@ def __init__(self, vtype: str) -> None:
6565
class TextParseError(PGVectoRsError):
6666
def __init__(self, payload: str, dtype: type) -> None:
6767
super().__init__(f"failed to parse text of '{payload}' as a {dtype}")
68+
69+
70+
class CountRowsEstimateCondError(PGVectoRsError):
71+
def __init__(self) -> None:
72+
super().__init__(
73+
"cannot use estimate=True and a condition for row count requests"
74+
)

src/pgvecto_rs/sdk/client.py

Lines changed: 81 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,55 +2,95 @@
22
from uuid import UUID
33

44
from numpy import ndarray
5-
from sqlalchemy import ColumnElement, Float, create_engine, delete, insert, select, text
5+
from sqlalchemy import (
6+
BIGINT,
7+
Column,
8+
ColumnElement,
9+
Float,
10+
create_engine,
11+
delete,
12+
func,
13+
insert,
14+
select,
15+
text,
16+
)
617
from sqlalchemy.dialects import postgresql
18+
from sqlalchemy.dialects.postgresql.pg_catalog import pg_class
719
from sqlalchemy.engine import Engine
8-
from sqlalchemy.orm import Mapped, mapped_column
20+
from sqlalchemy.orm import mapped_column
921
from sqlalchemy.orm.session import Session
1022
from sqlalchemy.types import String
1123

24+
from pgvecto_rs.errors import CountRowsEstimateCondError
1225
from pgvecto_rs.sdk.filters import Filter
13-
from pgvecto_rs.sdk.record import Record, RecordORM, RecordORMType
26+
from pgvecto_rs.sdk.record import Record, RecordORM, RecordORMType, Unique
1427
from pgvecto_rs.sqlalchemy import VECTOR
1528

1629

30+
def table_factory(collection_name, dimension, table_args, base_class=RecordORM):
31+
def __init__(self, **kwargs): # noqa: N807
32+
base_class.__init__(self, **kwargs)
33+
34+
newclass = type(
35+
collection_name,
36+
(base_class,),
37+
{
38+
"__init__": __init__,
39+
"__tablename__": f"collection_{collection_name}",
40+
"__table_args__": table_args,
41+
"id": mapped_column(
42+
postgresql.UUID(as_uuid=True),
43+
primary_key=True,
44+
),
45+
"text": mapped_column(String),
46+
"meta": mapped_column(postgresql.JSONB),
47+
"embedding": mapped_column(VECTOR(dimension)),
48+
},
49+
)
50+
return newclass
51+
52+
1753
class PGVectoRs:
1854
_engine: Engine
1955
_table: Type[RecordORM]
2056
dimension: int
2157

22-
def __init__(
23-
self, db_url: str, collection_name: str, dimension: int, recreate: bool = False
58+
def __init__( # noqa: PLR0913
59+
self,
60+
db_url: str,
61+
collection_name: str,
62+
dimension: int,
63+
recreate: bool = False,
64+
constraints: Union[List[Unique], None] = None,
2465
) -> None:
2566
"""Connect to an existing table or create a new empty one.
2667
If the `recreate=True`, the table will be dropped if it exists.
2768
2869
Args:
2970
----
3071
db_url (str): url to the database.
31-
table_name (str): name of the table.
72+
collection_name (str): name of the collection. A prefix `collection_` is added to actual table name.
3273
dimension (int): dimension of the embeddings.
3374
recreate (bool): drop the table if it exists. Defaults to False.
75+
constraints (List[Unique]): add constraints to columns, e.g. UNIQUE constraint
3476
"""
35-
36-
class _Table(RecordORM):
37-
__tablename__ = f"collection_{collection_name}"
38-
__table_args__ = {"extend_existing": True} # noqa: RUF012
39-
id: Mapped[UUID] = mapped_column(
40-
postgresql.UUID(as_uuid=True),
41-
primary_key=True,
77+
if constraints is None or len(constraints) == 0:
78+
table_args = {"extend_existing": True}
79+
else:
80+
table_args = (
81+
*[col.make() for col in constraints],
82+
{"extend_existing": True},
4283
)
43-
text: Mapped[str] = mapped_column(String)
44-
meta: Mapped[dict] = mapped_column(postgresql.JSONB)
45-
embedding: Mapped[ndarray] = mapped_column(VECTOR(dimension))
4684

4785
self._engine = create_engine(db_url)
86+
self._table = table_factory(collection_name, dimension, table_args)
4887
with Session(self._engine) as session:
4988
session.execute(text("CREATE EXTENSION IF NOT EXISTS vectors"))
5089
if recreate:
51-
session.execute(text(f"DROP TABLE IF EXISTS {_Table.__tablename__}"))
90+
session.execute(
91+
text(f"DROP TABLE IF EXISTS {self._table.__tablename__}")
92+
)
5293
session.commit()
53-
self._table = _Table
5494
self._table.__table__.create(self._engine, checkfirst=True)
5595
self.dimension = dimension
5696

@@ -105,6 +145,29 @@ def search(
105145
res = session.execute(stmt)
106146
return [(Record.from_orm(row[0]), row[1]) for row in res]
107147

148+
# ================ Stat ==================
149+
def row_count(self, estimate: bool = True, filter: Optional[Filter] = None) -> int:
150+
if estimate and filter is not None:
151+
raise CountRowsEstimateCondError()
152+
if estimate:
153+
stmt = (
154+
select(func.cast(Column("reltuples", Float), BIGINT).label("rows"))
155+
.select_from(pg_class)
156+
.where(
157+
Column("oid", Float)
158+
== func.cast(self._table.__tablename__, postgresql.REGCLASS)
159+
)
160+
)
161+
with Session(self._engine) as session:
162+
result = session.execute(stmt).fetchone()
163+
else:
164+
stmt = select(func.count("*").label("rows")).select_from(self._table)
165+
if filter is not None:
166+
stmt = stmt.where(filter(self._table))
167+
with Session(self._engine) as session:
168+
result = session.execute(stmt).fetchone()
169+
return result[0]
170+
108171
# ================ Delete ================
109172
def delete(self, filter: Filter) -> None:
110173
with Session(self._engine) as session:

src/pgvecto_rs/sdk/record.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,34 @@
1+
from enum import IntEnum
2+
from functools import reduce
13
from typing import List, Optional, Type, Union
24
from uuid import UUID, uuid4
35

46
from numpy import array, float32, ndarray
7+
from sqlalchemy import UniqueConstraint
58
from sqlalchemy.orm import DeclarativeBase, Mapped
69

710

11+
class Column(IntEnum):
12+
TEXT = 1
13+
META = 2
14+
EMBEDDING = 4
15+
16+
17+
class Unique:
18+
def __init__(self, columns: List[Column]):
19+
self.value = reduce(lambda x, y: x | y, columns)
20+
21+
def make(self) -> UniqueConstraint:
22+
ans: List[UniqueConstraint] = []
23+
if self.value & Column.TEXT:
24+
ans.append("text")
25+
if self.value & Column.META:
26+
ans.append("meta")
27+
if self.value & Column.EMBEDDING:
28+
ans.append("embedding")
29+
return UniqueConstraint(*ans)
30+
31+
832
class RecordORM(DeclarativeBase):
933
__tablename__: str
1034
id: Mapped[UUID]

src/pgvecto_rs/types/index.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,12 @@ def __init__(
1919
self.ratio = ratio
2020

2121
def dump(self) -> dict:
22-
if self.type == "trivial":
23-
return {"quantization": {"trivial": {}}}
22+
if self.type == "product":
23+
return {"quantization": {"product": {"ratio": self.ratio}}}
2424
elif self.type == "scalar":
2525
return {"quantization": {"scalar": {}}}
2626
else:
27-
return {"quantization": {"product": {"ratio": self.ratio}}}
27+
return {}
2828

2929

3030
class Flat:

tests/__init__.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
DB_NAME = os.getenv("DB_NAME", "postgres")
2424

2525
# Run tests with shell:
26-
# DB_HOST=localhost DB_USER=postgres DB_PASS=password DB_NAME=postgres python3 -m pytest bindings/python/tests/
26+
# DB_HOST=localhost DB_USER=postgres DB_PASS=password DB_NAME=postgres python3 -m pytest tests/
2727
URL = f"postgresql://{USER}:{PASS}@{HOST}:{PORT}/{DB_NAME}"
2828
DATABASES = {
2929
"default": {
@@ -106,7 +106,11 @@
106106
),
107107
(
108108
IndexOption(index=Ivf(quantization=Quantization(typ="trivial"))),
109-
"[indexing.ivf.quantization.trivial]\n",
109+
"[indexing.ivf]\n",
110+
),
111+
(
112+
IndexOption(index=Ivf()),
113+
"[indexing.ivf]\n",
110114
),
111115
(
112116
IndexOption(

tests/test_sdk.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1+
import time
12
from typing import Callable, List
23

34
import numpy as np
45
import pytest
6+
from sqlalchemy.exc import IntegrityError
57

68
from pgvecto_rs.sdk import Filter, PGVectoRs, Record, filters
9+
from pgvecto_rs.sdk.record import Column, Unique
710
from tests import (
811
COSINE_DIS_OP,
912
L2_DIS_OP,
@@ -82,3 +85,112 @@ def test_search_order_and_limit(
8285
for rec, dis in client.search(dis_oprand, dis_op, top_k=4):
8386
expect = assert_func(dis_oprand, rec.embedding.to_numpy())
8487
assert np.allclose(expect, dis, atol=1e-10)
88+
89+
90+
def test_unique_text_table(
91+
client: PGVectoRs,
92+
):
93+
unique_client = PGVectoRs(
94+
db_url=URL,
95+
collection_name="unique_text",
96+
dimension=3,
97+
recreate=True,
98+
constraints=[Unique(columns=[Column.TEXT])],
99+
)
100+
it = iter(MockTexts.items())
101+
text1, vector1 = next(it)
102+
_, vector2 = next(it)
103+
records_ok = [Record.from_text(t, v, {"src": "src1"}) for t, v in MockTexts.items()]
104+
records_fail = [
105+
Record.from_text(text1, vector1, {"src": "src1"}),
106+
Record.from_text(text1, vector2, {"src": "src2"}),
107+
]
108+
unique_client.insert(records_ok)
109+
unique_client.delete_all()
110+
with pytest.raises(IntegrityError):
111+
unique_client.insert(records_fail)
112+
113+
114+
def test_unique_meta_table(
115+
client: PGVectoRs,
116+
):
117+
unique_client = PGVectoRs(
118+
db_url=URL,
119+
collection_name="unique_meta",
120+
dimension=3,
121+
recreate=True,
122+
constraints=[Unique(columns=[Column.META])],
123+
)
124+
it = iter(MockTexts.items())
125+
text1, vector1 = next(it)
126+
text2, vector2 = next(it)
127+
records_ok = [
128+
Record.from_text(text1, vector1, {"src": "src1"}),
129+
Record.from_text(text2, vector2, {"src": "src2"}),
130+
]
131+
records_fail = [
132+
Record.from_text(text1, vector1, {"src": "src1"}),
133+
Record.from_text(text2, vector2, {"src": "src1"}),
134+
]
135+
unique_client.insert(records_ok)
136+
unique_client.delete_all()
137+
with pytest.raises(IntegrityError):
138+
unique_client.insert(records_fail)
139+
140+
141+
def test_unique_text_meta_table(
142+
client: PGVectoRs,
143+
):
144+
unique_client = PGVectoRs(
145+
db_url=URL,
146+
collection_name="unique_both",
147+
dimension=3,
148+
recreate=True,
149+
constraints=[Unique(columns=[Column.TEXT, Column.META])],
150+
)
151+
it = iter(MockTexts.items())
152+
text1, vector1 = next(it)
153+
text2, vector2 = next(it)
154+
records_ok = [
155+
Record.from_text(text1, vector1, {"src": "src1"}),
156+
Record.from_text(text2, vector2, {"src": "src1"}),
157+
]
158+
records_fail = [
159+
Record.from_text(text1, vector1, {"src": "src1"}),
160+
Record.from_text(text1, vector2, {"src": "src1"}),
161+
]
162+
unique_client.insert(records_ok)
163+
unique_client.delete_all()
164+
with pytest.raises(IntegrityError):
165+
unique_client.insert(records_fail)
166+
167+
168+
COUNT = 1000
169+
170+
171+
def test_count_table(
172+
client: PGVectoRs,
173+
):
174+
count_client = PGVectoRs(
175+
db_url=URL,
176+
collection_name="count",
177+
dimension=3,
178+
recreate=True,
179+
)
180+
it = iter(MockTexts.items())
181+
text1, vector1 = next(it)
182+
records = [Record.from_text(text1, vector1, {"src": "src1"}) for _ in range(COUNT)]
183+
count_client.insert(records)
184+
185+
rows = count_client.row_count(estimate=False)
186+
assert rows == COUNT
187+
188+
rows = count_client.row_count(estimate=False, filter=filter_src2)
189+
assert rows == 0
190+
191+
for _ in range(90):
192+
estimate_rows = count_client.row_count(estimate=True)
193+
if estimate_rows == COUNT:
194+
return
195+
time.sleep(1)
196+
raise AssertionError

0 commit comments

Comments
 (0)