Skip to content
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ test = [
"pytest==8.3.3",
"pytest-cov==6.0.0"
]
langgraph = ["langgraph-checkpoint"]

[build-system]
requires = ["setuptools"]
Expand Down
6 changes: 6 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,9 @@ langchain-core==0.3.8
numpy==1.26.4
pgvector==0.3.4
SQLAlchemy[asyncio]==2.0.35
langgraph-checkpoint==2.0.10
aiohttp~=3.11.11
asyncpg~=0.30.0
langgraph~=0.2.68
pytest~=8.3.4
pytest-asyncio~=0.25.2
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove these dependencies. The test dependencies are defined in the pyproejct.toml. We will only need to keep langgraph-checkpoint so we can test against it.

237 changes: 237 additions & 0 deletions src/langchain_google_cloud_sql_pg/async_checkpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,237 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
from contextlib import asynccontextmanager
from typing import Any, Optional

from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.base import (
BaseCheckpointSaver,
ChannelVersions,
Checkpoint,
CheckpointMetadata,
)
from langgraph.checkpoint.serde.base import SerializerProtocol
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncEngine

from .engine import CHECKPOINT_WRITES_TABLE, CHECKPOINTS_TABLE, PostgresEngine


class AsyncPostgresSaver(BaseCheckpointSaver[str]):
"""Checkpoint storage for a PostgreSQL database."""

__create_key = object()

jsonplus_serde = JsonPlusSerializer()

def __init__(
self,
key: object,
pool: AsyncEngine,
schema_name: str = "public",
serde: Optional[SerializerProtocol] = None,
) -> None:
"""
Initializes an AsyncPostgresSaver instance.
Args:
key (object): Internal key to restrict instantiation.
pool (AsyncEngine): The database connection pool.
schema_name (str, optional): The schema where the checkpoint tables reside. Defaults to "public".
serde (Optional[SerializerProtocol], optional): Serializer for encoding/decoding checkpoints. Defaults to None.
"""
super().__init__(serde=serde)
if key != AsyncPostgresSaver.__create_key:
raise Exception(
"Only create class through 'create' or 'create_sync' methods"
)
self.pool = pool
self.schema_name = schema_name

@classmethod
async def create(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Create will also have to support a custom table name for the tests to work

cls,
engine: PostgresEngine,
schema_name: str = "public",
serde: Optional[SerializerProtocol] = None,
) -> "AsyncPostgresSaver":
"""
Creates a new AsyncPostgresSaver instance.
Args:
engine (PostgresEngine): The PostgreSQL engine to use.
schema_name (str, optional): The schema name where the table is located. Defaults to "public".
serde (Optional[SerializerProtocol], optional): Serializer for encoding/decoding checkpoints. Defaults to None.
Raises:
IndexError: If the table does not contain the required schema.
Returns:
AsyncPostgresSaver: A newly created instance.
"""

checkpoints_table_schema = await engine._aload_table_schema(
CHECKPOINTS_TABLE, schema_name
)
checkpoints_column_names = checkpoints_table_schema.columns.keys()

checkpoints_required_columns = [
"thread_id",
"checkpoint_ns",
"checkpoint_id",
"parent_checkpoint_id",
"type",
"checkpoint",
"metadata",
]

if not (
all(x in checkpoints_column_names for x in checkpoints_required_columns)
):
raise IndexError(
f"Table checkpoints.'{schema_name}' has incorrect schema. Got "
f"column names '{checkpoints_column_names}' but required column names "
f"'{checkpoints_required_columns}'.\nPlease create table with the following schema:"
f"\nCREATE TABLE {schema_name}.checkpoints ("
"\n thread_id TEXT NOT NULL,"
"\n checkpoint_ns TEXT NOT NULL,"
"\n checkpoint_id TEXT NOT NULL,"
"\n parent_checkpoint_id TEXT,"
"\n type TEXT,"
"\n checkpoint JSONB NOT NULL,"
"\n metadata JSONB NOT NULL"
"\n);"
)

checkpoint_writes_table_schema = await engine._aload_table_schema(CHECKPOINT_WRITES_TABLE, schema_name)
checkpoint_writes_column_names = checkpoint_writes_table_schema.columns.keys()

checkpoint_writes_columns = [
"thread_id",
"checkpoint_ns",
"checkpoint_id",
"task_id",
"idx",
"channel",
"type",
"blob",
]

if not (
all(x in checkpoint_writes_column_names for x in checkpoint_writes_columns)
):
raise IndexError(
f"Table checkpoint_writes.'{schema_name}' has incorrect schema. Got "
f"column names '{checkpoint_writes_column_names}' but required column names "
f"'{checkpoint_writes_columns}'.\nPlease create table with following schema:"
f"\nCREATE TABLE {schema_name}.checkpoint_writes ("
"\n thread_id TEXT NOT NULL,"
"\n checkpoint_ns TEXT NOT NULL,"
"\n checkpoint_id TEXT NOT NULL,"
"\n task_id TEXT NOT NULL,"
"\n idx INT NOT NULL,"
"\n channel TEXT NOT NULL,"
"\n type TEXT,"
"\n blob JSONB NOT NULL"
"\n);"
)

return cls(cls.__create_key, engine._pool, schema_name, serde)

def _dump_checkpoint(self, checkpoint: Checkpoint) -> str:
"""
Serializes a checkpoint into a JSON string.
Args:
checkpoint (Checkpoint): The checkpoint to serialize.
Returns:
str: The serialized checkpoint as a JSON string.
"""
checkpoint["pending_sends"] = []
return json.dumps(checkpoint)

def _dump_metadata(self, metadata: CheckpointMetadata) -> str:
"""
Serializes checkpoint metadata into a JSON string.
Args:
metadata (CheckpointMetadata): The metadata to serialize.
Returns:
str: The serialized metadata as a JSON string.
"""
serialized_metadata = self.jsonplus_serde.dumps(metadata)
return serialized_metadata.decode().replace("\\u0000", "")

async def aput(
self,
config: RunnableConfig,
checkpoint: Checkpoint,
metadata: CheckpointMetadata,
new_versions: ChannelVersions,
) -> RunnableConfig:
"""
Asynchronously stores a checkpoint with its configuration and metadata.
Args:
config (RunnableConfig): Configuration for the checkpoint.
checkpoint (Checkpoint): The checkpoint to store.
metadata (CheckpointMetadata): Additional metadata for the checkpoint.
new_versions (ChannelVersions): New channel versions as of this write.
Returns:
RunnableConfig: Updated configuration after storing the checkpoint.
"""
configurable = config["configurable"].copy()
thread_id = configurable.pop("thread_id")
checkpoint_ns = configurable.pop("checkpoint_ns")
checkpoint_id = configurable.pop(
"checkpoint_id", configurable.pop("thread_ts", None)
)

copy = checkpoint.copy()
next_config: RunnableConfig = {
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": checkpoint_ns,
"checkpoint_id": checkpoint["id"],
}
}

query = f"""INSERT INTO "{self.schema_name}".{CHECKPOINTS_TABLE}(thread_id, checkpoint_ns, checkpoint_id, parent_checkpoint_id, checkpoint, metadata)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add quotes around table name

VALUES (:thread_id, :checkpoint_ns, :checkpoint_id, :parent_checkpoint_id, :checkpoint, :metadata)
ON CONFLICT (thread_id, checkpoint_ns, checkpoint_id)
DO UPDATE SET
checkpoint = EXCLUDED.checkpoint,
metadata = EXCLUDED.metadata;"""

async with self.pool.connect() as conn:
await conn.execute(
text(query),
{
"thread_id": thread_id,
"checkpoint_ns": checkpoint_ns,
"checkpoint_id": checkpoint["id"],
"parent_checkpoint_id": checkpoint_id,
"checkpoint": self._dump_checkpoint(copy),
"metadata": self._dump_metadata(metadata),
},
)
await conn.commit()

return next_config
90 changes: 89 additions & 1 deletion src/langchain_google_cloud_sql_pg/engine.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024 Google LLC
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -39,6 +39,9 @@

USER_AGENT = "langchain-google-cloud-sql-pg-python/" + __version__

CHECKPOINTS_TABLE = "checkpoints"
CHECKPOINT_WRITES_TABLE = "checkpoint_writes"


async def _get_iam_principal_email(
credentials: google.auth.credentials.Credentials,
Expand Down Expand Up @@ -747,6 +750,91 @@ def init_document_table(
)
)

async def _ainit_checkpoint_table(
self,
schema_name: str = "public",
checkpoints_table_name: str = CHECKPOINTS_TABLE,
checkpoint_writes_table_name: str = CHECKPOINT_WRITES_TABLE,
) -> None:
"""
Create AlloyDB tables to save checkpoints.
Args:
schema_name (str): The schema name to store the checkpoint tables. Default: "public".
checkpoints_table_name (str): Custom table name for checkpoints. Default: CHECKPOINTS_TABLE.
checkpoint_writes_table_name (str): Custom table name for checkpoint writes. Default: CHECKPOINT_WRITES_TABLE.
Returns:
None
"""
create_checkpoints_table = f"""
CREATE TABLE IF NOT EXISTS "{schema_name}".{checkpoints_table_name}(
thread_id TEXT NOT NULL,
checkpoint_ns TEXT NOT NULL DEFAULT '',
checkpoint_id TEXT NOT NULL,
parent_checkpoint_id TEXT,
type TEXT,
checkpoint JSONB NOT NULL,
metadata JSONB NOT NULL DEFAULT '{{}}',
PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id)
);"""

create_checkpoint_writes_table = f"""
CREATE TABLE IF NOT EXISTS "{schema_name}".{checkpoint_writes_table_name} (
thread_id TEXT NOT NULL,
checkpoint_ns TEXT NOT NULL DEFAULT '',
checkpoint_id TEXT NOT NULL,
task_id TEXT NOT NULL,
idx INTEGER NOT NULL,
channel TEXT NOT NULL,
type TEXT,
blob BYTEA NOT NULL,
PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id, task_id, idx)
);"""

async with self._pool.connect() as conn:
await conn.execute(text(create_checkpoints_table))
await conn.execute(text(create_checkpoint_writes_table))
await conn.commit()

async def ainit_checkpoint_table(
self,
schema_name: str = "public",
checkpoints_table_name: str = CHECKPOINTS_TABLE,
checkpoint_writes_table_name: str = CHECKPOINT_WRITES_TABLE,
) -> None:
"""Create an AlloyDB table to save checkpoint messages.
Args:
schema_name (str): The schema name to store checkpoint tables. Default: "public".
checkpoints_table_name (str): Custom table name for checkpoints. Default: CHECKPOINTS_TABLE.
checkpoint_writes_table_name (str): Custom table name for checkpoint writes. Default: CHECKPOINT_WRITES_TABLE.
Returns:
None
"""
await self._run_as_async(
self._ainit_checkpoint_table(
schema_name, checkpoints_table_name, checkpoint_writes_table_name
)
)

def init_checkpoint_table(
self,
schema_name: str = "public",
checkpoints_table_name: str = CHECKPOINTS_TABLE,
checkpoint_writes_table_name: str = CHECKPOINT_WRITES_TABLE,
) -> None:
"""Create Cloud SQL tables to store checkpoints.
Args:
schema_name (str): The schema name to store checkpoint tables. Default: "public".
checkpoints_table_name (str): Custom table name for checkpoints. Default: CHECKPOINTS_TABLE.
checkpoint_writes_table_name (str): Custom table name for checkpoint writes. Default: CHECKPOINT_WRITES_TABLE.
Returns:
None
"""
self._run_as_sync(
self._ainit_checkpoint_table(
schema_name, checkpoints_table_name, checkpoint_writes_table_name
)
)

async def _aload_table_schema(
self,
table_name: str,
Expand Down
Loading
Loading