Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions src/langchain_google_alloydb_pg/model_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def __init__(
generate_headers_fn: Optional[str] = None,
input_batch_transform_fn: Optional[str] = None,
output_batch_transform_fn: Optional[str] = None,
model_availability: Optional[str] = None,
**kwargs: Any,
):
self.model_id = model_id
Expand All @@ -53,6 +54,7 @@ def __init__(
self.generate_headers_fn = generate_headers_fn or kwargs.get("header_gen_fn")
self.input_batch_transform_fn = input_batch_transform_fn
self.output_batch_transform_fn = output_batch_transform_fn
self.model_availability = model_availability


class AlloyDBModelManager:
Expand Down Expand Up @@ -174,14 +176,14 @@ async def __avalidate(self) -> None:
"""Private async function to validate prerequisites.

Raises:
Exception if google_ml_integration EXTENSION is not 1.5.2.
Exception if google_ml_integration EXTENSION is not 1.5.3.
Exception if google_ml_integration.enable_model_support DB Flag not set.
"""
extension_version = await self.__fetch_google_ml_extension()
db_flag = await self.__fetch_db_flag()
if extension_version < "1.5.2":
if extension_version < "1.5.3":
raise Exception(
"Please upgrade google_ml_integration EXTENSION to version 1.5.2 or above."
"Please upgrade google_ml_integration EXTENSION to version 1.5.3 or above."
)
if db_flag != "on":
raise Exception(
Expand Down Expand Up @@ -212,6 +214,7 @@ async def __aget_model(self, model_id: str) -> Optional[AlloyDBModel]:
query = f"""SELECT * FROM
google_ml.list_model('{model_id}')
AS t(model_id VARCHAR,
model_availability text,
model_request_url VARCHAR,
model_provider google_ml.model_provider,
model_type google_ml.model_type,
Expand Down Expand Up @@ -291,13 +294,13 @@ async def __adrop_model(self, model_id: str) -> None:
await conn.commit()

async def __fetch_google_ml_extension(self) -> str:
"""Creates the Google ML Extension if it does not exist and returns the version number (Default creates version 1.5.2)."""
"""Creates the Google ML Extension if it does not exist and returns the version number (Default creates version 1.5.3)."""
create_extension_query = """
DO $$
BEGIN
IF NOT EXISTS (
SELECT 1 FROM pg_extension WHERE extname = 'google_ml_integration' )
THEN CREATE EXTENSION google_ml_integration VERSION '1.5.2' CASCADE;
THEN CREATE EXTENSION google_ml_integration VERSION '1.5.3' CASCADE;
END IF;
END
$$;
Expand Down