diff --git a/pyproject.toml b/pyproject.toml
index ce86aa38..8dcb37a5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -46,6 +46,7 @@ test = [
     "pytest==8.3.3",
     "pytest-cov==6.0.0"
 ]
+langgraph = ["langgraph-checkpoint"]
 
 [build-system]
 requires = ["setuptools"]
diff --git a/requirements.txt b/requirements.txt
index 2433a992..af1ad6a1 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,3 +3,9 @@ langchain-core==0.3.8
 numpy==1.26.4
 pgvector==0.3.4
 SQLAlchemy[asyncio]==2.0.35
+langgraph-checkpoint==2.0.10
+aiohttp~=3.11.11
+asyncpg~=0.30.0
+langgraph~=0.2.68
+pytest~=8.3.4
+pytest-asyncio~=0.25.2
\ No newline at end of file
diff --git a/src/langchain_google_cloud_sql_pg/async_checkpoint.py b/src/langchain_google_cloud_sql_pg/async_checkpoint.py
new file mode 100644
index 00000000..31633141
--- /dev/null
+++ b/src/langchain_google_cloud_sql_pg/async_checkpoint.py
@@ -0,0 +1,494 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+from typing import Any, AsyncIterator, Optional, Sequence, Tuple
+
+from langchain_core.runnables import RunnableConfig
+from langgraph.checkpoint.base import (
+    WRITES_IDX_MAP,
+    BaseCheckpointSaver,
+    ChannelVersions,
+    Checkpoint,
+    CheckpointMetadata,
+    CheckpointTuple,
+    get_checkpoint_id,
+)
+from langgraph.checkpoint.serde.base import SerializerProtocol
+from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
+from langgraph.checkpoint.serde.types import TASKS
+from sqlalchemy import text
+from sqlalchemy.ext.asyncio import AsyncEngine
+
+from .engine import CHECKPOINT_WRITES_TABLE, CHECKPOINTS_TABLE, PostgresEngine
+
+MetadataInput = Optional[dict[str, Any]]
+
+# SELECT used by the `alist` method. Channel values are stored inline in the
+# `checkpoint` JSONB column (see `_dump_checkpoint`), so no separate blobs
+# table is joined. Note that the table names are unqualified here, so the
+# statement resolves against the connection's search_path (the default
+# "public" schema).
+SELECT = f"""
+select
+    thread_id,
+    checkpoint,
+    checkpoint_ns,
+    checkpoint_id,
+    parent_checkpoint_id,
+    metadata,
+    (
+        select
+            array_agg(array[cw.task_id::text::bytea, cw.channel::bytea, cw.type::bytea, cw.blob] order by cw.task_id, cw.idx)
+        from checkpoint_writes cw
+        where cw.thread_id = checkpoints.thread_id
+            and cw.checkpoint_ns = checkpoints.checkpoint_ns
+            and cw.checkpoint_id = checkpoints.checkpoint_id
+    ) as pending_writes,
+    (
+        select array_agg(array[cw.type::bytea, cw.blob] order by cw.task_id, cw.idx)
+        from checkpoint_writes cw
+        where cw.thread_id = checkpoints.thread_id
+            and cw.checkpoint_ns = checkpoints.checkpoint_ns
+            and cw.checkpoint_id = checkpoints.parent_checkpoint_id
+            and cw.channel = '{TASKS}'
+    ) as pending_sends
+from checkpoints
+"""
+
+
+class AsyncPostgresSaver(BaseCheckpointSaver[str]):
+    """Checkpoint storage for a PostgreSQL database."""
+
+    __create_key = object()
+
+    jsonplus_serde = JsonPlusSerializer()
+
+    def __init__(
+        self,
+        key: object,
+        pool: AsyncEngine,
+        schema_name: str = "public",
+        serde: Optional[SerializerProtocol] = None,
+    ) -> None:
+        """
+        Initializes an AsyncPostgresSaver instance.
+
+        Args:
+            key (object): Internal key to restrict instantiation.
+            pool (AsyncEngine): The database connection pool.
+            schema_name (str, optional): The schema where the checkpoint tables reside. Defaults to "public".
+            serde (Optional[SerializerProtocol], optional): Serializer for encoding/decoding checkpoints. Defaults to None.
+        """
+        super().__init__(serde=serde)
+        if key != AsyncPostgresSaver.__create_key:
+            raise Exception("Only create this class through the 'create' method.")
+        self.pool = pool
+        self.schema_name = schema_name
+
+    @classmethod
+    async def create(
+        cls,
+        engine: PostgresEngine,
+        schema_name: str = "public",
+        serde: Optional[SerializerProtocol] = None,
+    ) -> "AsyncPostgresSaver":
+        """
+        Creates a new AsyncPostgresSaver instance.
+
+        Args:
+            engine (PostgresEngine): The PostgreSQL engine to use.
+            schema_name (str, optional): The schema name where the tables are located. Defaults to "public".
+            serde (Optional[SerializerProtocol], optional): Serializer for encoding/decoding checkpoints. Defaults to None.
+
+        Raises:
+            IndexError: If a table does not have the required schema.
+
+        Returns:
+            AsyncPostgresSaver: A newly created instance.
+        """
+        checkpoints_table_schema = await engine._aload_table_schema(
+            CHECKPOINTS_TABLE, schema_name
+        )
+        checkpoints_column_names = checkpoints_table_schema.columns.keys()
+
+        checkpoints_required_columns = [
+            "thread_id",
+            "checkpoint_ns",
+            "checkpoint_id",
+            "parent_checkpoint_id",
+            "type",
+            "checkpoint",
+            "metadata",
+        ]
+
+        if not all(x in checkpoints_column_names for x in checkpoints_required_columns):
+            raise IndexError(
+                f"Table '{schema_name}'.checkpoints has incorrect schema. Got "
+                f"column names '{checkpoints_column_names}' but required column names "
+                f"'{checkpoints_required_columns}'.\nPlease create the table with the following schema:"
+                f"\nCREATE TABLE {schema_name}.checkpoints ("
+                "\n    thread_id TEXT NOT NULL,"
+                "\n    checkpoint_ns TEXT NOT NULL,"
+                "\n    checkpoint_id TEXT NOT NULL,"
+                "\n    parent_checkpoint_id TEXT,"
+                "\n    type TEXT,"
+                "\n    checkpoint JSONB NOT NULL,"
+                "\n    metadata JSONB NOT NULL"
+                "\n);"
+            )
+
+        checkpoint_writes_table_schema = await engine._aload_table_schema(
+            CHECKPOINT_WRITES_TABLE, schema_name
+        )
+        checkpoint_writes_column_names = checkpoint_writes_table_schema.columns.keys()
+
+        checkpoint_writes_columns = [
+            "thread_id",
+            "checkpoint_ns",
+            "checkpoint_id",
+            "task_id",
+            "idx",
+            "channel",
+            "type",
+            "blob",
+        ]
+
+        if not all(x in checkpoint_writes_column_names for x in checkpoint_writes_columns):
+            raise IndexError(
+                f"Table '{schema_name}'.checkpoint_writes has incorrect schema. Got "
+                f"column names '{checkpoint_writes_column_names}' but required column names "
+                f"'{checkpoint_writes_columns}'.\nPlease create the table with the following schema:"
+                f"\nCREATE TABLE {schema_name}.checkpoint_writes ("
+                "\n    thread_id TEXT NOT NULL,"
+                "\n    checkpoint_ns TEXT NOT NULL,"
+                "\n    checkpoint_id TEXT NOT NULL,"
+                "\n    task_id TEXT NOT NULL,"
+                "\n    idx INT NOT NULL,"
+                "\n    channel TEXT NOT NULL,"
+                "\n    type TEXT,"
+                "\n    blob BYTEA NOT NULL"
+                "\n);"
+            )
+
+        return cls(cls.__create_key, engine._pool, schema_name, serde)
+
+    def _dump_checkpoint(self, checkpoint: Checkpoint) -> str:
+        """
+        Serializes a checkpoint into a JSON string.
+
+        Args:
+            checkpoint (Checkpoint): The checkpoint to serialize.
+
+        Returns:
+            str: The serialized checkpoint as a JSON string.
+        """
+        # Pending sends are reconstructed from the parent checkpoint's writes
+        # (see the SELECT above), so they are not persisted here.
+        serialized = self.jsonplus_serde.dumps({**checkpoint, "pending_sends": []})
+        # JSONB cannot store the NUL escape sequence \u0000.
+        return serialized.decode().replace("\\u0000", "")
+
+    def _dump_metadata(self, metadata: CheckpointMetadata) -> str:
+        """
+        Serializes checkpoint metadata into a JSON string.
+
+        Args:
+            metadata (CheckpointMetadata): The metadata to serialize.
+
+        Returns:
+            str: The serialized metadata as a JSON string.
+        """
+        serialized_metadata = self.jsonplus_serde.dumps(metadata)
+        # JSONB cannot store the NUL escape sequence \u0000.
+        return serialized_metadata.decode().replace("\\u0000", "")
+
+    def _dump_writes(
+        self,
+        thread_id: str,
+        checkpoint_ns: str,
+        checkpoint_id: str,
+        task_id: str,
+        writes: Sequence[tuple[str, Any]],
+    ) -> list[dict[str, Any]]:
+        """Serialize pending writes into parameter rows for checkpoint_writes."""
+        dumped_writes = []
+        for idx, (channel, value) in enumerate(writes):
+            # Serialize each value once; dumps_typed returns a (type, blob) pair.
+            type_, blob = self.serde.dumps_typed(value)
+            dumped_writes.append(
+                {
+                    "thread_id": thread_id,
+                    "checkpoint_ns": checkpoint_ns,
+                    "checkpoint_id": checkpoint_id,
+                    "task_id": task_id,
+                    "idx": WRITES_IDX_MAP.get(channel, idx),
+                    "channel": channel,
+                    "type": type_,
+                    "blob": blob,
+                }
+            )
+        return dumped_writes
+
+    def _load_checkpoint(
+        self,
+        checkpoint: dict[str, Any],
+        pending_sends: list[tuple[bytes, bytes]],
+    ) -> Checkpoint:
+        """Rebuild a Checkpoint from its JSONB document and pending sends."""
+        return Checkpoint(
+            v=checkpoint["v"],
+            ts=checkpoint["ts"],
+            id=checkpoint["id"],
+            # Channel values are stored inline in the checkpoint document.
+            channel_values=checkpoint.get("channel_values", {}),
+            channel_versions=checkpoint["channel_versions"].copy(),
+            versions_seen={k: v.copy() for k, v in checkpoint["versions_seen"].items()},
+            pending_sends=[
+                self.serde.loads_typed((c.decode(), b)) for c, b in pending_sends or []
+            ],
+        )
+
+    def _load_metadata(self, metadata: dict[str, Any]) -> CheckpointMetadata:
+        """Normalize metadata loaded from JSONB through the serializer."""
+        return self.jsonplus_serde.loads(self.jsonplus_serde.dumps(metadata))
+
+    def _load_writes(
+        self, writes: list[tuple[bytes, bytes, bytes, bytes]]
+    ) -> list[tuple[str, str, Any]]:
+        """Deserialize pending writes into (task_id, channel, value) triples."""
+        return (
+            [
+                (
+                    tid.decode(),
+                    channel.decode(),
+                    self.serde.loads_typed((t.decode(), v)),
+                )
+                for tid, channel, t, v in writes
+            ]
+            if writes
+            else []
+        )
+
+    def _search_where(
+        self,
+        config: Optional[RunnableConfig],
+        filter: MetadataInput,
+        before: Optional[RunnableConfig] = None,
+    ) -> tuple[str, dict[str, Any]]:
+        """Return the WHERE clause predicates for alist() given config, filter, before.
+
+        Returns a tuple of the parameterized WHERE clause (including the WHERE
+        keyword), e.g. "WHERE thread_id = :thread_id AND checkpoint_ns = :checkpoint_ns",
+        and a dict mapping each named parameter to its value.
+        """
+        wheres = []
+        param_values: dict[str, Any] = {}
+
+        # construct predicates for the config filter
+        if config:
+            wheres.append("thread_id = :thread_id")
+            param_values["thread_id"] = config["configurable"]["thread_id"]
+            checkpoint_ns = config["configurable"].get("checkpoint_ns")
+            if checkpoint_ns is not None:
+                wheres.append("checkpoint_ns = :checkpoint_ns")
+                param_values["checkpoint_ns"] = checkpoint_ns
+
+            if checkpoint_id := get_checkpoint_id(config):
+                wheres.append("checkpoint_id = :checkpoint_id")
+                param_values["checkpoint_id"] = checkpoint_id
+
+        # construct the predicate for the metadata filter
+        if filter:
+            wheres.append("metadata @> :metadata_filter")
+            param_values["metadata_filter"] = json.dumps(filter)
+
+        # construct the predicate for `before`
+        if before is not None:
+            wheres.append("checkpoint_id < :checkpoint_id_before")
+            param_values["checkpoint_id_before"] = get_checkpoint_id(before)
+
+        return (
+            "WHERE " + " AND ".join(wheres) if wheres else "",
+            param_values,
+        )
+
+    async def aput(
+        self,
+        config: RunnableConfig,
+        checkpoint: Checkpoint,
+        metadata: CheckpointMetadata,
+        new_versions: ChannelVersions,
+    ) -> RunnableConfig:
+        """
+        Asynchronously stores a checkpoint with its configuration and metadata.
+
+        Args:
+            config (RunnableConfig): Configuration for the checkpoint.
+            checkpoint (Checkpoint): The checkpoint to store.
+            metadata (CheckpointMetadata): Additional metadata for the checkpoint.
+            new_versions (ChannelVersions): New channel versions as of this write.
+
+        Returns:
+            RunnableConfig: Updated configuration after storing the checkpoint.
+        """
+        configurable = config["configurable"].copy()
+        thread_id = configurable.pop("thread_id")
+        checkpoint_ns = configurable.pop("checkpoint_ns", "")
+        checkpoint_id = configurable.pop(
+            "checkpoint_id", configurable.pop("thread_ts", None)
+        )
+
+        copy = checkpoint.copy()
+        next_config: RunnableConfig = {
+            "configurable": {
+                "thread_id": thread_id,
+                "checkpoint_ns": checkpoint_ns,
+                "checkpoint_id": checkpoint["id"],
+            }
+        }
+
+        query = f"""INSERT INTO "{self.schema_name}".{CHECKPOINTS_TABLE}(thread_id, checkpoint_ns, checkpoint_id, parent_checkpoint_id, checkpoint, metadata)
+            VALUES (:thread_id, :checkpoint_ns, :checkpoint_id, :parent_checkpoint_id, :checkpoint, :metadata)
+            ON CONFLICT (thread_id, checkpoint_ns, checkpoint_id)
+            DO UPDATE SET
+                checkpoint = EXCLUDED.checkpoint,
+                metadata = EXCLUDED.metadata;"""
+
+        async with self.pool.connect() as conn:
+            await conn.execute(
+                text(query),
+                {
+                    "thread_id": thread_id,
+                    "checkpoint_ns": checkpoint_ns,
+                    "checkpoint_id": checkpoint["id"],
+                    # The incoming checkpoint_id (if any) identifies the parent.
+                    "parent_checkpoint_id": checkpoint_id,
+                    "checkpoint": self._dump_checkpoint(copy),
+                    "metadata": self._dump_metadata(metadata),
+                },
+            )
+            await conn.commit()
+
+        return next_config
+
+    async def aput_writes(
+        self,
+        config: RunnableConfig,
+        writes: Sequence[Tuple[str, Any]],
+        task_id: str,
+        task_path: str = "",
+    ) -> None:
+        """Asynchronously store intermediate writes linked to a checkpoint.
+
+        Args:
+            config (RunnableConfig): Configuration of the related checkpoint.
+            writes (Sequence[Tuple[str, Any]]): List of writes to store.
+            task_id (str): Identifier for the task creating the writes.
+            task_path (str): Path of the task creating the writes. Accepted for
+                API compatibility; it is not persisted, as the table has no
+                task_path column.
+
+        Returns:
+            None
+        """
+        upsert = f"""INSERT INTO "{self.schema_name}".{CHECKPOINT_WRITES_TABLE}(thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, blob)
+            VALUES (:thread_id, :checkpoint_ns, :checkpoint_id, :task_id, :idx, :channel, :type, :blob)
+            ON CONFLICT (thread_id, checkpoint_ns, checkpoint_id, task_id, idx) DO UPDATE SET
+                channel = EXCLUDED.channel,
+                type = EXCLUDED.type,
+                blob = EXCLUDED.blob;
+        """
+        insert = f"""INSERT INTO "{self.schema_name}".{CHECKPOINT_WRITES_TABLE}(thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, blob)
+            VALUES (:thread_id, :checkpoint_ns, :checkpoint_id, :task_id, :idx, :channel, :type, :blob)
+            ON CONFLICT (thread_id, checkpoint_ns, checkpoint_id, task_id, idx) DO NOTHING
+        """
+        # Special writes (those in WRITES_IDX_MAP) use a stable idx, so replays
+        # overwrite them; regular writes are append-only and keep the first
+        # value written for each idx.
+        query = upsert if all(w[0] in WRITES_IDX_MAP for w in writes) else insert
+
+        params = self._dump_writes(
+            config["configurable"]["thread_id"],
+            config["configurable"]["checkpoint_ns"],
+            config["configurable"]["checkpoint_id"],
+            task_id,
+            writes,
+        )
+
+        async with self.pool.connect() as conn:
+            await conn.execute(
+                text(query),
+                params,
+            )
+            await conn.commit()
+
+    async def alist(
+        self,
+        config: Optional[RunnableConfig],
+        *,
+        filter: Optional[dict[str, Any]] = None,
+        before: Optional[RunnableConfig] = None,
+        limit: Optional[int] = None,
+    ) -> AsyncIterator[CheckpointTuple]:
+        """Asynchronously list checkpoints that match the given criteria.
+
+        Args:
+            config (Optional[RunnableConfig]): Base configuration for filtering checkpoints.
+            filter (Optional[dict[str, Any]]): Additional filtering criteria for metadata.
+            before (Optional[RunnableConfig]): List checkpoints created before this configuration.
+            limit (Optional[int]): Maximum number of checkpoints to return.
+
+        Returns:
+            AsyncIterator[CheckpointTuple]: Async iterator of matching checkpoint tuples.
+        """
+        where, params = self._search_where(config, filter, before)
+        query = SELECT + where + " ORDER BY checkpoint_id DESC"
+        if limit:
+            # LIMIT cannot be bound as a named parameter here; cast defensively.
+            query += f" LIMIT {int(limit)}"
+
+        async with self.pool.connect() as conn:
+            result = await conn.execute(text(query), params)
+            rows = result.fetchall()  # Materialize all results
+
+        for row in rows:
+            value = dict(row._mapping)
+            yield CheckpointTuple(
+                config={
+                    "configurable": {
+                        "thread_id": value["thread_id"],
+                        "checkpoint_ns": value["checkpoint_ns"],
+                        "checkpoint_id": value["checkpoint_id"],
+                    }
+                },
+                checkpoint=self._load_checkpoint(
+                    value["checkpoint"],
+                    value["pending_sends"],
+                ),
+                metadata=self._load_metadata(value["metadata"]),
+                parent_config=(
+                    {
+                        "configurable": {
+                            "thread_id": value["thread_id"],
+                            "checkpoint_ns": value["checkpoint_ns"],
+                            "checkpoint_id": value["parent_checkpoint_id"],
+                        }
+                    }
+                    if value["parent_checkpoint_id"]
+                    else None
+                ),
+                pending_writes=self._load_writes(value["pending_writes"]),
+            )
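For reviewers, a minimal end-to-end sketch of how the new saver is meant to be wired up (not part of the diff; the project, region, instance, and database values are placeholders). It uses `PostgresEngine.afrom_instance` together with the `ainit_checkpoint_table` helper added to the engine below:

```python
import asyncio

from langchain_google_cloud_sql_pg import PostgresEngine
from langchain_google_cloud_sql_pg.async_checkpoint import AsyncPostgresSaver


async def main() -> None:
    # Placeholder connection details; substitute your own instance.
    engine = await PostgresEngine.afrom_instance(
        project_id="my-project",
        region="us-central1",
        instance="my-instance",
        database="my-database",
    )
    await engine.ainit_checkpoint_table()  # creates checkpoints/checkpoint_writes
    checkpointer = await AsyncPostgresSaver.create(engine)

    # List the checkpoints stored for a thread, newest first.
    config = {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}}
    async for tup in checkpointer.alist(config):
        print(tup.config["configurable"]["checkpoint_id"], tup.metadata)

    await engine.close()


asyncio.run(main())
```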
diff --git a/src/langchain_google_cloud_sql_pg/engine.py b/src/langchain_google_cloud_sql_pg/engine.py
index 1fc30815..fc8defe6 100644
--- a/src/langchain_google_cloud_sql_pg/engine.py
+++ b/src/langchain_google_cloud_sql_pg/engine.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Google LLC
+# Copyright 2025 Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -39,6 +39,9 @@
 USER_AGENT = "langchain-google-cloud-sql-pg-python/" + __version__
 
+CHECKPOINTS_TABLE = "checkpoints"
+CHECKPOINT_WRITES_TABLE = "checkpoint_writes"
+
 
 async def _get_iam_principal_email(
     credentials: google.auth.credentials.Credentials,
@@ -747,6 +750,91 @@
             )
         )
 
+    async def _ainit_checkpoint_table(
+        self,
+        schema_name: str = "public",
+        checkpoints_table_name: str = CHECKPOINTS_TABLE,
+        checkpoint_writes_table_name: str = CHECKPOINT_WRITES_TABLE,
+    ) -> None:
+        """
+        Create Cloud SQL tables to save checkpoints.
+
+        Args:
+            schema_name (str): The schema name to store the checkpoint tables. Default: "public".
+            checkpoints_table_name (str): Custom table name for checkpoints. Default: CHECKPOINTS_TABLE.
+            checkpoint_writes_table_name (str): Custom table name for checkpoint writes. Default: CHECKPOINT_WRITES_TABLE.
+
+        Returns:
+            None
+        """
+        create_checkpoints_table = f"""
+        CREATE TABLE IF NOT EXISTS "{schema_name}".{checkpoints_table_name}(
+            thread_id TEXT NOT NULL,
+            checkpoint_ns TEXT NOT NULL DEFAULT '',
+            checkpoint_id TEXT NOT NULL,
+            parent_checkpoint_id TEXT,
+            type TEXT,
+            checkpoint JSONB NOT NULL,
+            metadata JSONB NOT NULL DEFAULT '{{}}',
+            PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id)
+        );"""
+
+        create_checkpoint_writes_table = f"""
+        CREATE TABLE IF NOT EXISTS "{schema_name}".{checkpoint_writes_table_name} (
+            thread_id TEXT NOT NULL,
+            checkpoint_ns TEXT NOT NULL DEFAULT '',
+            checkpoint_id TEXT NOT NULL,
+            task_id TEXT NOT NULL,
+            idx INTEGER NOT NULL,
+            channel TEXT NOT NULL,
+            type TEXT,
+            blob BYTEA NOT NULL,
+            PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id, task_id, idx)
+        );"""
+
+        async with self._pool.connect() as conn:
+            await conn.execute(text(create_checkpoints_table))
+            await conn.execute(text(create_checkpoint_writes_table))
+            await conn.commit()
+
+    async def ainit_checkpoint_table(
+        self,
+        schema_name: str = "public",
+        checkpoints_table_name: str = CHECKPOINTS_TABLE,
+        checkpoint_writes_table_name: str = CHECKPOINT_WRITES_TABLE,
+    ) -> None:
+        """Create Cloud SQL tables to save checkpoints.
+
+        Args:
+            schema_name (str): The schema name to store checkpoint tables. Default: "public".
+            checkpoints_table_name (str): Custom table name for checkpoints. Default: CHECKPOINTS_TABLE.
+            checkpoint_writes_table_name (str): Custom table name for checkpoint writes. Default: CHECKPOINT_WRITES_TABLE.
+
+        Returns:
+            None
+        """
+        await self._run_as_async(
+            self._ainit_checkpoint_table(
+                schema_name, checkpoints_table_name, checkpoint_writes_table_name
+            )
+        )
+
+    def init_checkpoint_table(
+        self,
+        schema_name: str = "public",
+        checkpoints_table_name: str = CHECKPOINTS_TABLE,
+        checkpoint_writes_table_name: str = CHECKPOINT_WRITES_TABLE,
+    ) -> None:
+        """Create Cloud SQL tables to store checkpoints.
+
+        Args:
+            schema_name (str): The schema name to store checkpoint tables. Default: "public".
+            checkpoints_table_name (str): Custom table name for checkpoints. Default: CHECKPOINTS_TABLE.
+            checkpoint_writes_table_name (str): Custom table name for checkpoint writes. Default: CHECKPOINT_WRITES_TABLE.
+
+        Returns:
+            None
+        """
+        self._run_as_sync(
+            self._ainit_checkpoint_table(
+                schema_name, checkpoints_table_name, checkpoint_writes_table_name
+            )
+        )
+
     async def _aload_table_schema(
         self,
         table_name: str,
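A short sketch of the synchronous setup path added above (illustrative; connection details are placeholders). `init_checkpoint_table` drives the same `_ainit_checkpoint_table` coroutine through the engine's sync runner, so it can be called from plain blocking code:

```python
from langchain_google_cloud_sql_pg import PostgresEngine

# Placeholder connection details; substitute your own instance.
engine = PostgresEngine.from_instance(
    project_id="my-project",
    region="us-central1",
    instance="my-instance",
    database="my-database",
)

# Creates "public".checkpoints and "public".checkpoint_writes if absent.
engine.init_checkpoint_table(
    schema_name="public",
    checkpoints_table_name="checkpoints",
    checkpoint_writes_table_name="checkpoint_writes",
)
```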
"checkpoint_ns": "inner", + } + } + chkpnt_0: Checkpoint = { + "v": 1, + "ts": "2024-07-31T20:14:19.804150+00:00", + "id": "1ef4f797-8335-6428-8001-8a1503f9b875", + "channel_values": {"my_key": "meow", "node": "node"}, + "channel_versions": {"__start__": 2, "my_key": 3, "start:node": 3, "node": 3}, + "versions_seen": { + "__input__": {}, + "__start__": {"__start__": 1}, + "node": {"start:node": 2}, + }, + "pending_sends": [], + } + chkpnt_1: Checkpoint = empty_checkpoint() + chkpnt_2: Checkpoint = create_checkpoint(chkpnt_1, {}, 1) + chkpnt_3: Checkpoint = empty_checkpoint() + + metadata_1: CheckpointMetadata = { + "source": "input", + "step": 2, + "writes": {}, + "parents": 1, + } + metadata_2: CheckpointMetadata = { + "source": "loop", + "step": 1, + "writes": {"foo": "bar"}, + "parents": None, + } + metadata_3: CheckpointMetadata = {} + + return { + "configs": [config_0, config_1, config_2, config_3], + "checkpoints": [chkpnt_0, chkpnt_1, chkpnt_2, chkpnt_3], + "metadata": [metadata_1, metadata_2, metadata_3], + } + + +@pytest_asyncio.fixture +async def async_engine(): + """Setup and teardown for PostgresEngine instance.""" + async_engine = await PostgresEngine.afrom_instance( + project_id=project_id, + region=region, + instance=instance_id, + database=db_name, + ) + await async_engine._ainit_checkpoint_table() + + yield async_engine # Provide the engine instance for testing + + # Cleanup: Drop checkpoint tables after tests + await aexecute(async_engine, f'DROP TABLE IF EXISTS "{CHECKPOINTS_TABLE}"') + await aexecute(async_engine, f'DROP TABLE IF EXISTS "{CHECKPOINT_WRITES_TABLE}"') + await async_engine.close() + + +@pytest.mark.asyncio +async def test_checkpoint_async( + async_engine: PostgresEngine, test_data: dict[str, Any] +) -> None: + """Test inserting and retrieving a checkpoint asynchronously.""" + + # Create an instance of AsyncPostgresSaver + checkpointer = await AsyncPostgresSaver.create(async_engine) + + configs = test_data["configs"] + checkpoints = test_data["checkpoints"] + metadata = test_data["metadata"] + + test_config = { + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": "1ef4f797-8335-6428-8001-8a1503f9b875", + } + } + + # Verify if updated configuration after storing the checkpoint is correct + next_config = await checkpointer.aput(configs[0], checkpoints[0], {}, {}) + assert dict(next_config) == test_config + + # Verify if the checkpoint is stored correctly in the database + results = await afetch(async_engine, f"SELECT * FROM {CHECKPOINTS_TABLE}") + assert len(results) == 1 + for row in results: + assert isinstance(row["thread_id"], str) + await aexecute(async_engine, f"TRUNCATE TABLE {CHECKPOINTS_TABLE}") + + writes: Sequence[Tuple[str, Any]] = [("test_channel1", {}), ("test_channel2", {})] + + await checkpointer.aput_writes(configs[0], writes, task_id="1") + # Verify if the checkpoint writes are stored correctly in the database + results = await afetch(async_engine, f"SELECT * FROM {CHECKPOINT_WRITES_TABLE}") + assert len(results) == 2 + for row in results: + assert isinstance(row["task_id"], str) + await aexecute(async_engine, f"TRUNCATE TABLE {CHECKPOINT_WRITES_TABLE}") + + await checkpointer.aput(configs[1], checkpoints[1], metadata[0], {}) + await checkpointer.aput(configs[2], checkpoints[2], metadata[1], {}) + await checkpointer.aput(configs[3], checkpoints[3], metadata[2], {}) + + # call method / assertions + query_1 = {"source": "input"} # search by 1 key + query_2 = { + "step": 1, + "writes": {"foo": "bar"}, + } # search 
+    query_3: dict[str, Any] = {}  # search by no keys, return all checkpoints
+    query_4 = {"source": "update", "step": 1}  # no match
+
+    search_results_1 = [c async for c in checkpointer.alist(None, filter=query_1)]
+    assert len(search_results_1) == 1
+    assert search_results_1[0].metadata == metadata[0]
+
+    search_results_2 = [c async for c in checkpointer.alist(None, filter=query_2)]
+    assert len(search_results_2) == 1
+    assert search_results_2[0].metadata == metadata[1]
+
+    search_results_3 = [c async for c in checkpointer.alist(None, filter=query_3)]
+    assert len(search_results_3) == 3
+
+    search_results_4 = [c async for c in checkpointer.alist(None, filter=query_4)]
+    assert len(search_results_4) == 0
+
+    # search by config (defaults to checkpoints across all namespaces)
+    search_results_5 = [
+        c async for c in checkpointer.alist({"configurable": {"thread_id": "thread-2"}})
+    ]
+    assert len(search_results_5) == 2
+    assert {
+        search_results_5[0].config["configurable"]["checkpoint_ns"],
+        search_results_5[1].config["configurable"]["checkpoint_ns"],
+    } == {"", "inner"}
+
+
+@pytest.mark.asyncio
+async def test_checkpoint_aput_writes(
+    async_engine: PostgresEngine,
+) -> None:
+    checkpointer = await AsyncPostgresSaver.create(async_engine)
+
+    config: RunnableConfig = {
+        "configurable": {
+            "thread_id": "1",
+            "checkpoint_ns": "",
+            "checkpoint_id": "1ef4f797-8335-6428-8001-8a1503f9b875",
+        }
+    }
+
+    # Verify that the checkpoint writes are stored correctly in the database
+    writes: Sequence[Tuple[str, Any]] = [("test_channel1", {}), ("test_channel2", {})]
+    await checkpointer.aput_writes(config, writes, task_id="1")
+
+    results = await afetch(async_engine, f"SELECT * FROM {CHECKPOINT_WRITES_TABLE}")
+    assert len(results) == 2
+    for row in results:
+        assert isinstance(row["task_id"], str)
+    await aexecute(async_engine, f"TRUNCATE TABLE {CHECKPOINT_WRITES_TABLE}")
+
+
+@pytest.mark.asyncio
+async def test_null_chars(
+    async_engine: PostgresEngine,
+    test_data: dict[str, Any],
+) -> None:
+    checkpointer = await AsyncPostgresSaver.create(async_engine)
+    config = await checkpointer.aput(
+        test_data["configs"][0],
+        test_data["checkpoints"][0],
+        {"my_key": "\x00abc"},  # type: ignore
+        {},
+    )
+    # assert (await checkpointer.aget_tuple(config)).metadata["my_key"] == "abc"  # type: ignore
+    results = [c async for c in checkpointer.alist(None, filter={"my_key": "abc"})]
+    assert results[0].metadata["my_key"] == "abc"  # type: ignore
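For anyone running this new test module locally: it expects a live instance identified by four environment variables, which the module reads at import time. A minimal sketch of the required setup (values are placeholders):

```python
import os

# The test module reads these at import time; values below are placeholders.
os.environ.setdefault("PROJECT_ID", "my-project")
os.environ.setdefault("REGION", "us-central1")
os.environ.setdefault("INSTANCE_ID", "my-instance")
os.environ.setdefault("DATABASE_ID", "my-database")
```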
diff --git a/tests/test_engine.py b/tests/test_engine.py
index 1c2653bf..08f600ee 100644
--- a/tests/test_engine.py
+++ b/tests/test_engine.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Google LLC
+# Copyright 2025 Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -28,6 +28,10 @@
 from sqlalchemy.pool import NullPool
 
 from langchain_google_cloud_sql_pg import Column, PostgresEngine
+from langchain_google_cloud_sql_pg.engine import (
+    CHECKPOINT_WRITES_TABLE,
+    CHECKPOINTS_TABLE,
+)
 
 DEFAULT_TABLE = "test_table" + str(uuid.uuid4()).replace("-", "_")
 CUSTOM_TABLE = "test_table_custom" + str(uuid.uuid4()).replace("-", "_")
@@ -121,6 +125,7 @@ async def engine(self, db_project, db_region, db_instance, db_name):
         await aexecute(engine, f'DROP TABLE "{DEFAULT_TABLE}"')
         await aexecute(engine, f'DROP TABLE "{INT_ID_CUSTOM_TABLE}"')
         await engine.close()
+        await engine._connector.close()
 
     async def test_engine_args(self, engine):
         assert "Pool size: 3" in engine._pool.pool.status()
@@ -299,6 +304,54 @@ async def test_iam_account_override(
         assert engine
         await aexecute(engine, "SELECT 1")
         await engine.close()
+        await engine._connector.close()
+
+    async def test_ainit_checkpoints_table(self, engine):
+        custom_table_name = "test_checkpoints_table"
+        await aexecute(engine, f'DROP TABLE IF EXISTS "{custom_table_name}"')
+        await engine.ainit_checkpoint_table(
+            schema_name="public", checkpoints_table_name=custom_table_name
+        )
+        stmt = f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{custom_table_name}';"
+        results = await afetch(engine, stmt)
+
+        expected = [
+            {"column_name": "thread_id", "data_type": "text"},
+            {"column_name": "checkpoint_ns", "data_type": "text"},
+            {"column_name": "checkpoint_id", "data_type": "text"},
+            {"column_name": "parent_checkpoint_id", "data_type": "text"},
+            {"column_name": "type", "data_type": "text"},
+            {"column_name": "checkpoint", "data_type": "jsonb"},
+            {"column_name": "metadata", "data_type": "jsonb"},
+        ]
+        for row in results:
+            assert row in expected
+
+    async def test_ainit_checkpoint_writes_table(self, engine):
+        custom_table_name = "test_checkpoint_writes_table"
+        await aexecute(engine, f'DROP TABLE IF EXISTS "{custom_table_name}"')
+
+        # Initialize with a custom checkpoint-writes table name
+        await engine.ainit_checkpoint_table(
+            schema_name="public", checkpoint_writes_table_name=custom_table_name
+        )
+
+        # Verify the query inspects the custom table
+        stmt = f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{custom_table_name}';"
+        results = await afetch(engine, stmt)
+
+        expected = [
+            {"column_name": "thread_id", "data_type": "text"},
+            {"column_name": "checkpoint_ns", "data_type": "text"},
+            {"column_name": "checkpoint_id", "data_type": "text"},
+            {"column_name": "task_id", "data_type": "text"},
+            {"column_name": "idx", "data_type": "integer"},
+            {"column_name": "channel", "data_type": "text"},
+            {"column_name": "type", "data_type": "text"},
+            {"column_name": "blob", "data_type": "bytea"},
+        ]
+        for row in results:
+            assert row in expected
 
 
 @pytest.mark.asyncio(scope="module")
@@ -449,3 +502,54 @@ async def test_iam_account_override(
         assert engine
         await aexecute(engine, "SELECT 1")
         await engine.close()
+
+    async def test_init_checkpoints_table(self, engine):
+        custom_table_name = "test_checkpoints_table"
+        await aexecute(engine, f'DROP TABLE IF EXISTS "{custom_table_name}"')
+
+        # Exercise the sync `init_checkpoint_table` wrapper
+        engine.init_checkpoint_table(
+            schema_name="public", checkpoints_table_name=custom_table_name
+        )
+
+        # Verify the query inspects the custom table
+        stmt = f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{custom_table_name}';"
+        results = await afetch(engine, stmt)
+
+        expected = [
+            {"column_name": "thread_id", "data_type": "text"},
+            {"column_name": "checkpoint_ns", "data_type": "text"},
+            {"column_name": "checkpoint_id", "data_type": "text"},
+            {"column_name": "parent_checkpoint_id", "data_type": "text"},
+            {"column_name": "type", "data_type": "text"},
+            {"column_name": "checkpoint", "data_type": "jsonb"},
+            {"column_name": "metadata", "data_type": "jsonb"},
+        ]
+        for row in results:
+            assert row in expected
+
+    async def test_init_checkpoint_writes_table(self, engine):
+        custom_table_name = "test_checkpoint_writes_table"
+        await aexecute(engine, f'DROP TABLE IF EXISTS "{custom_table_name}"')
+
+        # Exercise the sync `init_checkpoint_table` wrapper
+        engine.init_checkpoint_table(
+            schema_name="public", checkpoint_writes_table_name=custom_table_name
+        )
+
+        # Verify the query inspects the custom table
+        stmt = f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{custom_table_name}';"
+        results = await afetch(engine, stmt)
+
+        expected = [
+            {"column_name": "thread_id", "data_type": "text"},
+            {"column_name": "checkpoint_ns", "data_type": "text"},
+            {"column_name": "checkpoint_id", "data_type": "text"},
+            {"column_name": "task_id", "data_type": "text"},
+            {"column_name": "idx", "data_type": "integer"},
+            {"column_name": "channel", "data_type": "text"},
+            {"column_name": "type", "data_type": "text"},
+            {"column_name": "blob", "data_type": "bytea"},
+        ]
+        for row in results:
+            assert row in expected
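Closing note on `test_null_chars`: the saver tolerates NUL characters because `_dump_metadata` strips the `\u0000` escape that Postgres JSONB rejects. A self-contained sketch of that round trip, using the same `JsonPlusSerializer` the saver uses:

```python
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer

serde = JsonPlusSerializer()

# dumps() yields JSON bytes; "\x00" is escaped as \u0000, which JSONB rejects.
raw = serde.dumps({"my_key": "\x00abc", "step": 1})
clean = raw.decode().replace("\\u0000", "")  # mirrors _dump_metadata above

print(clean)                        # {"my_key": "abc", "step": 1}
print(serde.loads(clean.encode()))  # {'my_key': 'abc', 'step': 1}
```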