diff --git a/src/langchain_google_alloydb_pg/model_manager.py b/src/langchain_google_alloydb_pg/model_manager.py index 899954b4..f62629cf 100644 --- a/src/langchain_google_alloydb_pg/model_manager.py +++ b/src/langchain_google_alloydb_pg/model_manager.py @@ -38,6 +38,7 @@ def __init__( generate_headers_fn: Optional[str] = None, input_batch_transform_fn: Optional[str] = None, output_batch_transform_fn: Optional[str] = None, + model_availability: Optional[str] = None, **kwargs: Any, ): self.model_id = model_id @@ -53,6 +54,7 @@ def __init__( self.generate_headers_fn = generate_headers_fn or kwargs.get("header_gen_fn") self.input_batch_transform_fn = input_batch_transform_fn self.output_batch_transform_fn = output_batch_transform_fn + self.model_availability = model_availability class AlloyDBModelManager: @@ -174,14 +176,14 @@ async def __avalidate(self) -> None: """Private async function to validate prerequisites. Raises: - Exception if google_ml_integration EXTENSION is not 1.5.2. + Exception if google_ml_integration EXTENSION is not 1.5.3. Exception if google_ml_integration.enable_model_support DB Flag not set. """ extension_version = await self.__fetch_google_ml_extension() db_flag = await self.__fetch_db_flag() - if extension_version < "1.5.2": + if extension_version < "1.5.3": raise Exception( - "Please upgrade google_ml_integration EXTENSION to version 1.5.2 or above." + "Please upgrade google_ml_integration EXTENSION to version 1.5.3 or above." ) if db_flag != "on": raise Exception( @@ -212,6 +214,7 @@ async def __aget_model(self, model_id: str) -> Optional[AlloyDBModel]: query = f"""SELECT * FROM google_ml.list_model('{model_id}') AS t(model_id VARCHAR, + model_availability text, model_request_url VARCHAR, model_provider google_ml.model_provider, model_type google_ml.model_type, @@ -291,13 +294,13 @@ async def __adrop_model(self, model_id: str) -> None: await conn.commit() async def __fetch_google_ml_extension(self) -> str: - """Creates the Google ML Extension if it does not exist and returns the version number (Default creates version 1.5.2).""" + """Creates the Google ML Extension if it does not exist and returns the version number (Default creates version 1.5.3).""" create_extension_query = """ DO $$ BEGIN IF NOT EXISTS ( SELECT 1 FROM pg_extension WHERE extname = 'google_ml_integration' ) - THEN CREATE EXTENSION google_ml_integration VERSION '1.5.2' CASCADE; + THEN CREATE EXTENSION google_ml_integration VERSION '1.5.3' CASCADE; END IF; END $$;