4 changes: 2 additions & 2 deletions samples/index_tuning_sample/README.md
@@ -187,7 +187,7 @@ class HNSWIndex(
index_type: str = "hnsw",
# Distance strategy does not affect recall and has minimal impact on latency; refer to this guide to learn more: https://cloud.google.com/spanner/docs/choose-vector-distance-function
distance_strategy: DistanceStrategy = lambda : DistanceStrategy.COSINE_DISTANCE,
partial_indexes: List[str] | None = None,
partial_indexes: list[str] | None = None,
m: int = 16,
ef_construction: int = 64
)
@@ -235,7 +235,7 @@ class IVFFlatIndex(
name: str = DEFAULT_INDEX_NAME,
index_type: str = "ivfflat",
distance_strategy: DistanceStrategy = lambda : DistanceStrategy.COSINE_DISTANCE,
partial_indexes: List[str] | None = None,
partial_indexes: list[str] | None = None,
lists: int = 1
)

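For readers of the index-tuning README above, a minimal construction sketch (not part of this diff). Only the signatures shown in the hunks are taken from the source; the import path, the index names, and the lists value are assumptions for illustration.

from langchain_google_alloydb_pg.indexes import (
    DistanceStrategy,
    HNSWIndex,
    IVFFlatIndex,
)

# HNSW: larger m / ef_construction raise build time and memory in exchange
# for better recall at query time.
hnsw_index = HNSWIndex(
    name="hnsw_movies_idx",  # hypothetical index name
    distance_strategy=DistanceStrategy.COSINE_DISTANCE,
    m=16,
    ef_construction=64,
)

# IVFFlat: "lists" is the number of clusters the index builds over the table;
# more lists generally helps on larger tables.
ivfflat_index = IVFFlatIndex(
    name="ivfflat_movies_idx",  # hypothetical index name
    distance_strategy=DistanceStrategy.COSINE_DISTANCE,
    lists=100,
)

# Either object is then handed to the vector store's index-apply method
# (e.g. aapply_vector_index in this package family; the call name is an
# assumption here, not taken from the diff).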
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import List

import vertexai # type: ignore
from config import (
@@ -39,14 +38,14 @@
engine = None # Use global variable to share connection pooling


def similarity_search(query: str) -> List[Document]:
def similarity_search(query: str) -> list[Document]:
"""Searches and returns movies.

Args:
query: The user query to search for related items

Returns:
List[Document]: A list of Documents
list[Document]: A list of Documents
"""
global engine
if not engine: # Reuse connection pool
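The hunk above shows the change that runs through the whole PR: built-in generics from PEP 585 (Python 3.9+) replace the typing aliases, so List[Document] becomes list[Document] and the "from typing import List" line can be dropped. A short illustration, with only Document taken from the source:

from langchain_core.documents import Document

# Before (needs "from typing import List"):
#     def similarity_search(query: str) -> List[Document]: ...
# After (Python 3.9+, no typing import required):
def similarity_search(query: str) -> list[Document]:
    return []  # placeholder body for illustration

# On Python 3.7/3.8 the new spelling still works inside annotations when the
# module begins with "from __future__ import annotations", which the library
# modules touched in this PR already do.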
@@ -15,7 +15,7 @@
from __future__ import annotations

import json
from typing import List, Sequence
from typing import Sequence

from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import BaseMessage, messages_from_dict
@@ -128,7 +128,7 @@ async def aclear(self) -> None:
await conn.execute(text(query), {"session_id": self.session_id})
await conn.commit()

async def _aget_messages(self) -> List[BaseMessage]:
async def _aget_messages(self) -> list[BaseMessage]:
"""Retrieve the messages from AlloyDB."""
query = f"""SELECT data, type FROM "{self.schema_name}"."{self.table_name}" WHERE session_id = :session_id ORDER BY id;"""
async with self.pool.connect() as conn:
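The hunk above fetches (data, type) rows for a session. A hedged sketch of how such rows are typically turned back into message objects with langchain_core's messages_from_dict; the rows_to_messages helper and the JSON handling below are assumptions, not code from this PR.

import json

from langchain_core.messages import BaseMessage, messages_from_dict


def rows_to_messages(rows: list[tuple]) -> list[BaseMessage]:
    # Each row carries the serialized message payload ("data") plus its type
    # ("human", "ai", ...); messages_from_dict rebuilds BaseMessage objects.
    items = [
        {"data": json.loads(data) if isinstance(data, str) else data, "type": kind}
        for data, kind in rows
    ]
    return messages_from_dict(items)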
50 changes: 25 additions & 25 deletions src/langchain_google_alloydb_pg/async_loader.py
@@ -15,7 +15,7 @@
from __future__ import annotations

import json
from typing import Any, AsyncIterator, Callable, Dict, Iterable, List, Optional
from typing import Any, AsyncIterator, Callable, Iterable, Optional

from langchain_core.document_loaders.base import BaseLoader
from langchain_core.documents import Document
@@ -28,24 +28,24 @@
DEFAULT_METADATA_COL = "langchain_metadata"


def text_formatter(row: dict, content_columns: List[str]) -> str:
def text_formatter(row: dict, content_columns: list[str]) -> str:
"""txt document formatter."""
return " ".join(str(row[column]) for column in content_columns if column in row)


def csv_formatter(row: dict, content_columns: List[str]) -> str:
def csv_formatter(row: dict, content_columns: list[str]) -> str:
"""CSV document formatter."""
return ", ".join(str(row[column]) for column in content_columns if column in row)


def yaml_formatter(row: dict, content_columns: List[str]) -> str:
def yaml_formatter(row: dict, content_columns: list[str]) -> str:
"""YAML document formatter."""
return "\n".join(
f"{column}: {str(row[column])}" for column in content_columns if column in row
)


def json_formatter(row: dict, content_columns: List[str]) -> str:
def json_formatter(row: dict, content_columns: list[str]) -> str:
"""JSON document formatter."""
dictionary = {}
for column in content_columns:
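An illustrative aside before the next hunk: what the formatters defined above produce for a single row. The sample row and column list are made up; the import path simply mirrors this file's location.

from langchain_google_alloydb_pg.async_loader import (
    csv_formatter,
    json_formatter,
    text_formatter,
    yaml_formatter,
)

row = {"title": "Alien", "year": 1979, "langchain_metadata": {"genre": "horror"}}
cols = ["title", "year"]

print(text_formatter(row, cols))  # Alien 1979
print(csv_formatter(row, cols))   # Alien, 1979
print(yaml_formatter(row, cols))  # title: Alien
                                  # year: 1979
print(json_formatter(row, cols))  # {"title": "Alien", "year": 1979}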
@@ -63,7 +63,7 @@ def _parse_doc_from_row(
) -> Document:
"""Parse row into document."""
page_content = formatter(row, content_columns)
metadata: Dict[str, Any] = {}
metadata: dict[str, Any] = {}
# unnest metadata from langchain_metadata column
if metadata_json_column and row.get(metadata_json_column):
for k, v in row[metadata_json_column].items():
@@ -81,10 +81,10 @@ def _parse_row_from_doc(
column_names: Iterable[str],
content_column: str = DEFAULT_CONTENT_COL,
metadata_json_column: Optional[str] = DEFAULT_METADATA_COL,
) -> Dict:
) -> dict:
"""Parse document into a dictionary of rows."""
doc_metadata = doc.metadata.copy()
row: Dict[str, Any] = {content_column: doc.page_content}
row: dict[str, Any] = {content_column: doc.page_content}
for entry in doc.metadata:
if entry in column_names:
row[entry] = doc_metadata[entry]
@@ -111,8 +111,8 @@ def __init__(
key: object,
pool: AsyncEngine,
query: str,
content_columns: List[str],
metadata_columns: List[str],
content_columns: list[str],
metadata_columns: list[str],
formatter: Callable,
metadata_json_column: Optional[str] = None,
) -> None:
@@ -122,8 +122,8 @@ def __init__(
key (object): Prevent direct constructor usage.
engine (AlloyDBEngine): AsyncEngine with pool connection to the postgres database
query (Optional[str], optional): SQL query. Defaults to None.
content_columns (Optional[List[str]], optional): Column that represent a Document's page_content. Defaults to the first column.
metadata_columns (Optional[List[str]], optional): Column(s) that represent a Document's metadata. Defaults to None.
content_columns (Optional[list[str]], optional): Column that represent a Document's page_content. Defaults to the first column.
metadata_columns (Optional[list[str]], optional): Column(s) that represent a Document's metadata. Defaults to None.
formatter (Optional[Callable], optional): A function to format page content (OneOf: format, formatter). Defaults to None.
metadata_json_column (Optional[str], optional): Column to store metadata as JSON. Defaults to "langchain_metadata".

@@ -149,8 +149,8 @@ async def create(
query: Optional[str] = None,
table_name: Optional[str] = None,
schema_name: str = "public",
content_columns: Optional[List[str]] = None,
metadata_columns: Optional[List[str]] = None,
content_columns: Optional[list[str]] = None,
metadata_columns: Optional[list[str]] = None,
metadata_json_column: Optional[str] = None,
format: Optional[str] = None,
formatter: Optional[Callable] = None,
@@ -162,8 +162,8 @@ async def create(
query (Optional[str], optional): SQL query. Defaults to None.
table_name (Optional[str], optional): Name of table to query. Defaults to None.
schema_name (str, optional): Name of the schema where table is located. Defaults to "public".
content_columns (Optional[List[str]], optional): Column that represent a Document's page_content. Defaults to the first column.
metadata_columns (Optional[List[str]], optional): Column(s) that represent a Document's metadata. Defaults to None.
content_columns (Optional[list[str]], optional): Column that represent a Document's page_content. Defaults to the first column.
metadata_columns (Optional[list[str]], optional): Column(s) that represent a Document's metadata. Defaults to None.
metadata_json_column (Optional[str], optional): Column to store metadata as JSON. Defaults to "langchain_metadata".
format (Optional[str], optional): Format of page content (OneOf: text, csv, YAML, JSON). Defaults to 'text'.
formatter (Optional[Callable], optional): A function to format page content (OneOf: format, formatter). Defaults to None.
@@ -236,7 +236,7 @@ async def create(
metadata_json_column,
)

async def aload(self) -> List[Document]:
async def aload(self) -> list[Document]:
"""Load PostgreSQL data into Document objects."""
return [doc async for doc in self.alazy_load()]
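A usage sketch for the loader whose create() and aload() appear above. Everything not shown in the diff is an assumption: the class name (written here as AsyncAlloyDBLoader, following the saver's naming), the AlloyDBEngine.afrom_instance factory and its arguments, and the movies table and columns are placeholders to illustrate the call shape.

from langchain_google_alloydb_pg import AlloyDBEngine  # assumed public export
from langchain_google_alloydb_pg.async_loader import AsyncAlloyDBLoader


async def load_movies():
    engine = await AlloyDBEngine.afrom_instance(  # hypothetical connection details
        project_id="my-project",
        region="us-central1",
        cluster="my-cluster",
        instance="my-instance",
        database="movies_db",
    )
    loader = await AsyncAlloyDBLoader.create(
        engine,
        table_name="movies",                    # placeholder table
        content_columns=["title", "overview"],  # joined into page_content
        metadata_columns=["year", "genre"],     # copied into Document.metadata
        format="text",                          # one of: text, csv, YAML, JSON
    )
    return await loader.aload()  # -> list[Document]


# Run with: asyncio.run(load_movies())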

@@ -282,7 +282,7 @@ def __init__(
table_name: str,
content_column: str,
schema_name: str = "public",
metadata_columns: List[str] = [],
metadata_columns: list[str] = [],
metadata_json_column: Optional[str] = None,
):
"""AsyncAlloyDBDocumentSaver constructor.
@@ -293,7 +293,7 @@ def __init__(
table_name (str): Name of table to query.
schema_name (str, optional): Name of schema where the table is located. Defaults to "public".
content_column (str, optional): Column that represent a Document's page_content. Defaults to "page_content".
metadata_columns (Optional[List[str]], optional): Column(s) that represent a Document's metadata. Defaults to an empty list.
metadata_columns (Optional[list[str]], optional): Column(s) that represent a Document's metadata. Defaults to an empty list.
metadata_json_column (str, optional): Column to store metadata as JSON. Defaults to "langchain_metadata".

Raises:
@@ -317,7 +317,7 @@ async def create(
table_name: str,
schema_name: str = "public",
content_column: str = DEFAULT_CONTENT_COL,
metadata_columns: List[str] = [],
metadata_columns: list[str] = [],
metadata_json_column: Optional[str] = DEFAULT_METADATA_COL,
) -> AsyncAlloyDBDocumentSaver:
"""Create an AsyncAlloyDBDocumentSaver instance.
@@ -327,7 +327,7 @@ async def create(
table_name (str): Name of table to query.
schema_name (str, optional): Name of schema where the table is located. Defaults to "public".
content_column (str, optional): Column that represent a Document's page_content. Defaults to "page_content".
metadata_columns (List[str], optional): Column(s) that represent a Document's metadata. Defaults to an empty list.
metadata_columns (list[str], optional): Column(s) that represent a Document's metadata. Defaults to an empty list.
metadata_json_column (Optional[str], optional): Column to store metadata as JSON. Defaults to "langchain_metadata".

Returns:
@@ -370,13 +370,13 @@ async def create(
metadata_json_column,
)

async def aadd_documents(self, docs: List[Document]) -> None:
async def aadd_documents(self, docs: list[Document]) -> None:
"""
Save documents in the DocumentSaver table. Document’s metadata is added to columns if found or
stored in langchain_metadata JSON column.

Args:
docs (List[langchain_core.documents.Document]): a list of documents to be saved.
docs (list[langchain_core.documents.Document]): a list of documents to be saved.
"""

for doc in docs:
@@ -414,13 +414,13 @@ async def aadd_documents(self, docs: List[Document]) -> None:
await conn.execute(text(query), row)
await conn.commit()

async def adelete(self, docs: List[Document]) -> None:
async def adelete(self, docs: list[Document]) -> None:
"""
Delete all instances of a document from the DocumentSaver table by matching the entire Document
object.

Args:
docs (List[langchain_core.documents.Document]): a list of documents to be deleted.
docs (list[langchain_core.documents.Document]): a list of documents to be deleted.
"""
for doc in docs:
row = _parse_row_from_doc(
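Finally, a hedged usage sketch for AsyncAlloyDBDocumentSaver tying the hunks above together. Only the create/aadd_documents/adelete signatures come from the diff; the engine argument, table name, column lists, and document contents are placeholders.

from langchain_core.documents import Document
from langchain_google_alloydb_pg.async_loader import AsyncAlloyDBDocumentSaver


async def save_then_delete(engine) -> None:
    # "engine" is an AlloyDBEngine created elsewhere (see the loader sketch above).
    saver = await AsyncAlloyDBDocumentSaver.create(
        engine,
        table_name="movies",                 # placeholder table
        metadata_columns=["year", "genre"],  # written to matching columns
    )
    docs = [
        Document(
            page_content="A retired cop chases replicants through Los Angeles.",
            metadata={"year": 1982, "genre": "sci-fi"},  # extras go to langchain_metadata
        ),
    ]
    await saver.aadd_documents(docs)  # insert rows
    await saver.adelete(docs)         # delete rows matching these exact documents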