-
Notifications
You must be signed in to change notification settings - Fork 18
feat: add checkpoint postgresql #262
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
carloszuag
wants to merge
13
commits into
googleapis:langgraph-base
from
carloszuag:checkpoint_postgresql
Closed
Changes from all commits
Commits
Show all changes
13 commits
Select commit
Hold shift + click to select a range
6b9aac0
adding new engine method for postgresql
a6585ed
adding changes
efe4552
adding feature
d10d96d
fixing correction and code refactor
9ea952d
adding async postgresql class
ab64647
adding changes to engine and tests
0659b29
adding write schema
d476337
fix: address comments from the PR
622619a
fix: address comments from the PR
bd63d22
Merge branch 'feat_engine_methods' into checkpoint_postgresql
df0ce4e
Merge remote-tracking branch 'upstream/langgraph-base' into checkpoin…
e05d73b
fix: addressing comments and solving conflicts
8407082
code refactoring (black and isort)
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,3 +3,4 @@ langchain-core==0.3.8 | |
| numpy==1.26.4 | ||
| pgvector==0.3.4 | ||
| SQLAlchemy[asyncio]==2.0.35 | ||
| langgraph-checkpoint==2.0.10 | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,243 @@ | ||
| # Copyright 2025 Google LLC | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import json | ||
| from contextlib import asynccontextmanager | ||
| from typing import Any, Optional | ||
|
|
||
| from langchain_core.runnables import RunnableConfig | ||
| from langgraph.checkpoint.base import ( | ||
| BaseCheckpointSaver, | ||
| ChannelVersions, | ||
| Checkpoint, | ||
| CheckpointMetadata, | ||
| ) | ||
| from langgraph.checkpoint.serde.base import SerializerProtocol | ||
| from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer | ||
| from sqlalchemy import text | ||
| from sqlalchemy.ext.asyncio import AsyncEngine | ||
|
|
||
| from .engine import CHECKPOINT_WRITES_TABLE, CHECKPOINTS_TABLE, PostgresEngine | ||
|
|
||
|
|
||
class AsyncPostgresSaver(BaseCheckpointSaver[str]):
    """Checkpoint storage for a PostgreSQL database.

    Persists LangGraph checkpoints into a PostgreSQL table. Instances must
    be created through the :meth:`create` factory, which validates the live
    table schemas before handing out a saver.
    """

    # Private sentinel restricting instantiation to the factory methods.
    __create_key = object()

    # Serializer used for checkpoint metadata (checkpoints themselves are
    # serialized with plain json in `_dump_checkpoint`).
    jsonplus_serde = JsonPlusSerializer()

    def __init__(
        self,
        key: object,
        pool: AsyncEngine,
        schema_name: str = "public",
        serde: Optional[SerializerProtocol] = None,
        table_name: str = CHECKPOINTS_TABLE,
        writes_table_name: str = CHECKPOINT_WRITES_TABLE,
    ) -> None:
        """
        Initializes an AsyncPostgresSaver instance.

        Args:
            key (object): Internal key to restrict instantiation.
            pool (AsyncEngine): The database connection pool.
            schema_name (str, optional): The schema where the checkpoint tables reside. Defaults to "public".
            serde (Optional[SerializerProtocol], optional): Serializer for encoding/decoding checkpoints. Defaults to None.
            table_name (str, optional): Table used for checkpoints. Defaults to CHECKPOINTS_TABLE.
            writes_table_name (str, optional): Table used for checkpoint writes. Defaults to CHECKPOINT_WRITES_TABLE.

        Raises:
            Exception: If instantiated directly instead of via the factory.
        """
        super().__init__(serde=serde)
        if key != AsyncPostgresSaver.__create_key:
            raise Exception(
                "Only create class through 'create' or 'create_sync' methods"
            )
        self.pool = pool
        self.schema_name = schema_name
        # Store the (possibly custom) table names so that queries issued by
        # this saver target the same tables that `create` validated.
        self.table_name = table_name
        self.writes_table_name = writes_table_name

    @staticmethod
    def _validate_columns(
        schema_name: str,
        table_name: str,
        column_names: Any,
        required_columns: list,
        ddl_body: str,
    ) -> None:
        """
        Validates that a table exposes every required column.

        Args:
            schema_name (str): Schema that contains the table.
            table_name (str): Name of the table being validated.
            column_names (Any): Column names found in the live table.
            required_columns (list): Columns the saver needs to operate.
            ddl_body (str): Column definitions for the suggested CREATE TABLE.

        Raises:
            IndexError: If any required column is missing.
        """
        if not all(column in column_names for column in required_columns):
            raise IndexError(
                f"Table '{schema_name}'.'{table_name}' has incorrect schema. Got "
                f"column names '{column_names}' but required column names "
                f"'{required_columns}'.\nPlease create table with the following schema:"
                f"\nCREATE TABLE {schema_name}.{table_name} ("
                f"{ddl_body}"
                "\n);"
            )

    @classmethod
    async def create(
        cls,
        engine: PostgresEngine,
        schema_name: str = "public",
        table_name: str = CHECKPOINTS_TABLE,
        writes_table_name: str = CHECKPOINT_WRITES_TABLE,
        serde: Optional[SerializerProtocol] = None,
    ) -> "AsyncPostgresSaver":
        """
        Creates a new AsyncPostgresSaver instance.

        Args:
            engine (PostgresEngine): The PostgreSQL engine to use.
            schema_name (str, optional): The schema name where the table is located. Defaults to "public".
            table_name (str): Custom table name for checkpoints. Default: CHECKPOINTS_TABLE.
            writes_table_name (str): Custom table name for checkpoint writes. Default: CHECKPOINT_WRITES_TABLE.
            serde (Optional[SerializerProtocol], optional): Serializer for encoding/decoding checkpoints. Defaults to None.

        Raises:
            IndexError: If a table does not contain the required schema.

        Returns:
            AsyncPostgresSaver: A newly created instance.
        """
        checkpoints_table_schema = await engine._aload_table_schema(
            table_name, schema_name
        )
        cls._validate_columns(
            schema_name,
            table_name,
            checkpoints_table_schema.columns.keys(),
            [
                "thread_id",
                "checkpoint_ns",
                "checkpoint_id",
                "parent_checkpoint_id",
                "type",
                "checkpoint",
                "metadata",
            ],
            "\n    thread_id TEXT NOT NULL,"
            "\n    checkpoint_ns TEXT NOT NULL,"
            "\n    checkpoint_id TEXT NOT NULL,"
            "\n    parent_checkpoint_id TEXT,"
            "\n    type TEXT,"
            "\n    checkpoint JSONB NOT NULL,"
            "\n    metadata JSONB NOT NULL",
        )

        checkpoint_writes_table_schema = await engine._aload_table_schema(
            writes_table_name, schema_name
        )
        cls._validate_columns(
            schema_name,
            writes_table_name,
            checkpoint_writes_table_schema.columns.keys(),
            [
                "thread_id",
                "checkpoint_ns",
                "checkpoint_id",
                "task_id",
                "idx",
                "channel",
                "type",
                "blob",
            ],
            "\n    thread_id TEXT NOT NULL,"
            "\n    checkpoint_ns TEXT NOT NULL,"
            "\n    checkpoint_id TEXT NOT NULL,"
            "\n    task_id TEXT NOT NULL,"
            "\n    idx INT NOT NULL,"
            "\n    channel TEXT NOT NULL,"
            "\n    type TEXT,"
            "\n    blob JSONB NOT NULL",
        )

        # Bug fix: forward the validated table names so that queries issued
        # later (e.g. in `aput`) hit the same custom tables that were checked
        # here, instead of always falling back to the default names.
        return cls(
            cls.__create_key,
            engine._pool,
            schema_name,
            serde,
            table_name,
            writes_table_name,
        )

    def _dump_checkpoint(self, checkpoint: Checkpoint) -> str:
        """
        Serializes a checkpoint into a JSON string.

        Args:
            checkpoint (Checkpoint): The checkpoint to serialize.

        Returns:
            str: The serialized checkpoint as a JSON string.
        """
        # pending_sends is cleared before persisting; callers pass a copy so
        # the original checkpoint object is not mutated.
        checkpoint["pending_sends"] = []
        return json.dumps(checkpoint)

    def _dump_metadata(self, metadata: CheckpointMetadata) -> str:
        """
        Serializes checkpoint metadata into a JSON string.

        Args:
            metadata (CheckpointMetadata): The metadata to serialize.

        Returns:
            str: The serialized metadata as a JSON string.
        """
        serialized_metadata = self.jsonplus_serde.dumps(metadata)
        # NUL escape sequences are stripped because PostgreSQL rejects \u0000
        # inside JSONB values.
        return serialized_metadata.decode().replace("\\u0000", "")

    async def aput(
        self,
        config: RunnableConfig,
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
        new_versions: ChannelVersions,
    ) -> RunnableConfig:
        """
        Asynchronously stores a checkpoint with its configuration and metadata.

        Args:
            config (RunnableConfig): Configuration for the checkpoint.
            checkpoint (Checkpoint): The checkpoint to store.
            metadata (CheckpointMetadata): Additional metadata for the checkpoint.
            new_versions (ChannelVersions): New channel versions as of this write.

        Returns:
            RunnableConfig: Updated configuration after storing the checkpoint.
        """
        configurable = config["configurable"].copy()
        thread_id = configurable.pop("thread_id")
        checkpoint_ns = configurable.pop("checkpoint_ns")
        # "thread_ts" is the legacy key for "checkpoint_id".
        checkpoint_id = configurable.pop(
            "checkpoint_id", configurable.pop("thread_ts", None)
        )

        copy = checkpoint.copy()
        next_config: RunnableConfig = {
            "configurable": {
                "thread_id": thread_id,
                "checkpoint_ns": checkpoint_ns,
                "checkpoint_id": checkpoint["id"],
            }
        }

        # Bug fix: target the table configured on this instance rather than
        # the hardcoded default, so savers created with a custom table_name
        # write to the validated table.
        query = f"""INSERT INTO "{self.schema_name}"."{self.table_name}"(thread_id, checkpoint_ns, checkpoint_id, parent_checkpoint_id, checkpoint, metadata)
                    VALUES (:thread_id, :checkpoint_ns, :checkpoint_id, :parent_checkpoint_id, :checkpoint, :metadata)
                    ON CONFLICT (thread_id, checkpoint_ns, checkpoint_id)
                    DO UPDATE SET
                        checkpoint = EXCLUDED.checkpoint,
                        metadata = EXCLUDED.metadata;"""

        async with self.pool.connect() as conn:
            await conn.execute(
                text(query),
                {
                    "thread_id": thread_id,
                    "checkpoint_ns": checkpoint_ns,
                    "checkpoint_id": checkpoint["id"],
                    "parent_checkpoint_id": checkpoint_id,
                    "checkpoint": self._dump_checkpoint(copy),
                    "metadata": self._dump_metadata(metadata),
                },
            )
            await conn.commit()

        return next_config
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,122 @@ | ||
| # Copyright 2025 Google LLC | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import os | ||
| from typing import Sequence | ||
|
|
||
| import pytest | ||
| import pytest_asyncio | ||
| from langchain_core.runnables import RunnableConfig | ||
| from langgraph.checkpoint.base import ( | ||
| Checkpoint, | ||
| CheckpointMetadata, | ||
| ) | ||
| from sqlalchemy import text | ||
| from sqlalchemy.engine.row import RowMapping | ||
|
|
||
| from langchain_google_cloud_sql_pg.async_checkpoint import AsyncPostgresSaver | ||
| from langchain_google_cloud_sql_pg.engine import ( | ||
| CHECKPOINT_WRITES_TABLE, | ||
| CHECKPOINTS_TABLE, | ||
| PostgresEngine, | ||
| ) | ||
|
|
||
# Configurations used when writing and when reading checkpoints in the tests.
write_config: RunnableConfig = {"configurable": {"thread_id": "1", "checkpoint_ns": ""}}
read_config: RunnableConfig = {"configurable": {"thread_id": "1"}}

# Connection parameters for the PostgreSQL test instance, taken from the
# environment; the suite fails fast at import time if any is missing.
project_id = os.environ["PROJECT_ID"]
region = os.environ["REGION"]
instance_id = os.environ["INSTANCE_ID"]
db_name = os.environ["DATABASE_ID"]

# A representative checkpoint payload exercised by the tests below.
checkpoint: Checkpoint = {
    "v": 1,
    "ts": "2024-07-31T20:14:19.804150+00:00",
    "id": "1ef4f797-8335-6428-8001-8a1503f9b875",
    "channel_values": {"my_key": "meow", "node": "node"},
    "channel_versions": {"__start__": 2, "my_key": 3, "start:node": 3, "node": 3},
    "versions_seen": {
        "__input__": {},
        "__start__": {"__start__": 1},
        "node": {"start:node": 2},
    },
    "pending_sends": [],
}
|
|
||
|
|
||
async def aexecute(engine: PostgresEngine, query: str) -> None:
    """Run a single SQL statement on the engine's async pool and commit it."""
    async with engine._pool.connect() as connection:
        await connection.execute(text(query))
        await connection.commit()
|
|
||
|
|
||
async def afetch(engine: PostgresEngine, query: str) -> Sequence[RowMapping]:
    """Run a query on the engine's async pool and return all rows as mappings."""
    async with engine._pool.connect() as connection:
        cursor_result = await connection.execute(text(query))
        rows = cursor_result.mappings().fetchall()
    return rows
|
|
||
|
|
||
@pytest_asyncio.fixture
async def async_engine():
    """Yield a PostgresEngine with checkpoint tables; drop them on teardown."""
    engine = await PostgresEngine.afrom_instance(
        project_id=project_id,
        region=region,
        instance=instance_id,
        database=db_name,
    )
    await engine._ainit_checkpoint_table()

    yield engine  # Hand the engine to the test body.

    # Teardown: remove both checkpoint tables, then release the pool.
    for table in (CHECKPOINTS_TABLE, CHECKPOINT_WRITES_TABLE):
        await aexecute(engine, f'DROP TABLE IF EXISTS "{table}"')
    await engine.close()
|
|
||
|
|
||
@pytest.mark.asyncio
async def test_checkpoint_async(async_engine: PostgresEngine) -> None:
    """Round-trip a checkpoint through AsyncPostgresSaver.aput and verify it."""
    # Build a saver against the fixture-provided engine.
    saver = await AsyncPostgresSaver.create(async_engine)

    # The config aput must hand back after persisting the sample checkpoint.
    expected_config = {
        "configurable": {
            "thread_id": "1",
            "checkpoint_ns": "",
            "checkpoint_id": "1ef4f797-8335-6428-8001-8a1503f9b875",
        }
    }

    # Store the checkpoint and check the returned configuration.
    returned_config = await saver.aput(write_config, checkpoint, {}, {})
    assert dict(returned_config) == expected_config

    # Exactly one row should now exist in the checkpoints table.
    rows = await afetch(async_engine, f'SELECT * FROM "{CHECKPOINTS_TABLE}"')
    assert len(rows) == 1

    for row in rows:
        assert isinstance(row["thread_id"], str)

    # Cleanup: empty the table so later tests start from a clean slate.
    await aexecute(async_engine, f'TRUNCATE TABLE "{CHECKPOINTS_TABLE}"')
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Create will also have to support a custom table name for the tests to work