From d8948904cae3e70defec5b0b7c37ea0ea3fb30c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Burak=20Kabakc=C4=B1?= Date: Fri, 7 Feb 2025 14:13:04 +0000 Subject: [PATCH] refactor plugin system to inject hooks at session start and query run --- universql/plugin.py | 36 ++++++++++++++++-------- universql/plugins/snow.py | 27 ++++++++++++++---- universql/protocol/session.py | 52 +++++++++++++++++++++-------------- 3 files changed, 76 insertions(+), 39 deletions(-) diff --git a/universql/plugin.py b/universql/plugin.py index 17b28f5..b924111 100644 --- a/universql/plugin.py +++ b/universql/plugin.py @@ -39,14 +39,17 @@ def executor(self) -> "Executor": T = typing.TypeVar('T', bound=ICatalog) + def _track_call(method): @functools.wraps(method) def wrapper(self, *args, **kwargs): # Mark that the method was called on this instance. self._warm = True return method(self, *args, **kwargs) + return wrapper + class Executor(typing.Protocol[T]): def __init__(self, catalog: T): @@ -85,26 +88,35 @@ def close(self): pass -class UniversqlPlugin: - def __init__(self, - source_executor: Executor - ): - self.source_executor = source_executor +class UQuery: + def __init__(self, ast: typing.Optional[List[sqlglot.exp.Expression]], raw_query: str): + self.ast = ast + self.raw_query = raw_query - def transform_sql(self, expression: Expression, target_executor: Executor) -> Expression: - return expression + def transform_ast(self, ast: sqlglot.exp.Expression, target_executor: Executor) -> Expression: + return ast - def post_execute(self, expression: Expression, locations : typing.Optional[Locations], target_executor: Executor): + def post_execute(self, locations: typing.Optional[Locations], target_executor: Executor): pass - def pre_execute(self, expression: Expression, target_executor: Executor): + def end(self, table : pyarrow.Table): pass +class UniversqlPlugin(ABC): + def __init__(self, + session: "universql.protocol.session.UniverSQLSession" + ): + self.session = session + + def start_query(self, ast: typing.Optional[List[sqlglot.exp.Expression]], raw_query : str) -> UQuery: + return UQuery(ast, raw_query) + + # {"duckdb": DuckdbCatalog ..} COMPUTES = {} # [method] -TRANSFORMS = [] +PLUGINS = [] # apps to be installed APPS = [] @@ -123,7 +135,7 @@ def decorator(cls): raise SystemError("name is required for catalogs") COMPUTES[name] = cls elif issubclass(cls, UniversqlPlugin) and cls is not UniversqlPlugin: - TRANSFORMS.append(cls) + PLUGINS.append(cls) elif inspect.isfunction(cls): signature = inspect.signature(cls) if len(signature.parameters) == 1 and signature.parameters.values().__iter__().__next__().annotation is FastAPI: @@ -132,4 +144,4 @@ def decorator(cls): raise SystemError(f"Unknown type {cls}") return cls - return decorator \ No newline at end of file + return decorator diff --git a/universql/plugins/snow.py b/universql/plugins/snow.py index c37da40..ba86dd4 100644 --- a/universql/plugins/snow.py +++ b/universql/plugins/snow.py @@ -1,8 +1,12 @@ +import typing +from typing import List + +import pyarrow import sqlglot from sqlglot import Expression from sqlglot.expressions import TableSample -from universql.plugin import UniversqlPlugin, register +from universql.plugin import UniversqlPlugin, register, UQuery, Locations from universql.warehouse.duckdb import DuckDBExecutor from universql.warehouse.snowflake import SnowflakeExecutor @@ -12,12 +16,9 @@ # when FILES is specified: # COPY INTO stg_device_metadata FROM 's3:/test/initial_objects/device_metadata.csv' (TYPE = CSV SKIP_HEADER = 1) # COPY INTO stg_device_metadata FROM 's3:/test/initial_objects/file2.csv' (TYPE = CSV SKIP_HEADER = 1) -@register() -class SnowflakeStageUniversqlPlugin(UniversqlPlugin): - def __init__(self, source_executor: SnowflakeExecutor): - super().__init__(source_executor) - def transform_sql(self, expression: Expression, target_executor: DuckDBExecutor) -> Expression: +class StageTransformer(UQuery): + def transform_ast(self, expression: Expression, target_executor: DuckDBExecutor) -> Expression: if isinstance(expression, sqlglot.exp.Var) and expression.name.startswith('@'): expression.args['name'] = 'myname' return expression @@ -32,6 +33,20 @@ def _get_stage(self, table: sqlglot.exp.Table): # self.source_executor.execute_raw("DESCRIBE STAGE {}", self.source_executor) return + def post_execute(self, locations: typing.Optional[Locations], target_executor: DuckDBExecutor): + pass + + def end(self, table : pyarrow.Table): + pass + +@register() +class SnowflakeStageUniversqlPlugin(UniversqlPlugin): + def __init__(self, session: "universql.protocol.session.UniverSQLSession"): + super().__init__(session) + + def start_query(self, ast: typing.Optional[List[sqlglot.exp.Expression]], raw_query: str) -> UQuery: + return StageTransformer(ast, raw_query) + # @register() class TableSampleUniversqlPlugin(UniversqlPlugin): diff --git a/universql/protocol/session.py b/universql/protocol/session.py index 6a35467..f3447b7 100644 --- a/universql/protocol/session.py +++ b/universql/protocol/session.py @@ -20,8 +20,7 @@ from universql.lake.cloud import CACHE_DIRECTORY_KEY, MAX_CACHE_SIZE from universql.util import get_friendly_time_since, \ prepend_to_lines, parse_compute, QueryError, full_qualifier -from universql.plugin import Executor, Tables, ICatalog, COMPUTES, TRANSFORMS, UniversqlPlugin - +from universql.plugin import Executor, Tables, ICatalog, COMPUTES, PLUGINS, UniversqlPlugin, UQuery logger = logging.getLogger("💡") @@ -42,7 +41,7 @@ def __init__(self, context, session_id, credentials: dict, session_parameters: d self.computes = {"snowflake": self.catalog_executor} self.processing = False self.metadata_db = None - self.transforms : List[UniversqlPlugin] = [transform(self.catalog_executor) for transform in TRANSFORMS] + self.plugins : List[UniversqlPlugin] = [plugin(self) for plugin in PLUGINS] def _get_iceberg_catalog(self): @@ -111,8 +110,18 @@ def _do_query(self, start_time: float, raw_query: str) -> pyarrow.Table: last_executor = None + plugin_hooks = [] + for plugin in self.plugins: + try: + plugin_hooks.append(plugin.start_query(queries, raw_query)) + except Exception as e: + print_exc(10) + message = f"Unable to call start_query on plugin {plugin.__class__}" + logger.error(message, exc_info=e) + raise QueryError(f"{message}: {str(e)}") + if queries is None: - last_executor = self.perform_query(self.catalog_executor, raw_query) + last_executor = self.perform_query(self.catalog_executor, raw_query, plugin_hooks) else: last_error = None for ast in queries: @@ -129,7 +138,7 @@ def _do_query(self, start_time: float, raw_query: str) -> pyarrow.Table: last_executor = target_compute(self, compute).executor() self.computes[compute_name] = last_executor try: - last_executor = self.perform_query(last_executor, raw_query, ast=ast) + last_executor = self.perform_query(last_executor, raw_query, plugin_hooks, ast=ast) break except QueryError as e: logger.warning(f"Unable to run query: {e.message}") @@ -145,7 +154,16 @@ def _do_query(self, start_time: float, raw_query: str) -> pyarrow.Table: f"[{self.session_id}] {last_executor.get_query_log(query_duration)} 🚀 " f"({get_friendly_time_since(start_time, performance_counter)})") - return last_executor.get_as_table() + table = last_executor.get_as_table() + for hook in plugin_hooks: + try: + hook.end(table) + except Exception as e: + print_exc(10) + message = f"Unable to end query execution on plugin {hook.__class__}" + logger.error(message, exc_info=e) + raise QueryError(f"{message}: {str(e)}") + return table def _find_tables(self, ast: sqlglot.exp.Expression, cte_aliases=None): if cte_aliases is None: @@ -159,7 +177,7 @@ def _find_tables(self, ast: sqlglot.exp.Expression, cte_aliases=None): if expression.catalog or expression.db or str(expression.this.this) not in cte_aliases: yield full_qualifier(expression, self.credentials), cte_aliases - def perform_query(self, alternative_executor: Executor, raw_query, ast=None) -> Executor: + def perform_query(self, alternative_executor: Executor, raw_query, plugin_hooks : List[UQuery], ast=None) -> Executor: if ast is not None and alternative_executor != self.catalog_executor: must_run_on_catalog = False if isinstance(ast, Create): @@ -180,29 +198,21 @@ def perform_query(self, alternative_executor: Executor, raw_query, ast=None) -> locations = self.get_table_paths_from_catalog(alternative_executor.catalog, tables_list) with sentry_sdk.start_span(op=op_name, name="Execute query"): current_ast = ast - for transform in self.transforms: - try: - current_ast = transform.transform_sql(current_ast, alternative_executor) - except Exception as e: - print_exc(10) - message = f"Unable to perform transformation {transform.__class__}" - logger.error(message, exc_info=e) - raise QueryError(f"{message}: {str(e)}") - for transform in self.transforms: + for plugin_hook in plugin_hooks: try: - transform.pre_execute(current_ast, alternative_executor) + current_ast = plugin_hook.transform_ast(ast, alternative_executor) except Exception as e: print_exc(10) - message = f"Unable to perform transformation {transform.__class__}" + message = f"Unable to tranform_ast on plugin {plugin_hook.__class__}" logger.error(message, exc_info=e) raise QueryError(f"{message}: {str(e)}") new_locations = alternative_executor.execute(current_ast, self.catalog_executor, locations) - for transform in self.transforms: + for plugin_hook in plugin_hooks: try: - transform.post_execute(current_ast, new_locations, alternative_executor) + plugin_hook.post_execute(new_locations, alternative_executor) except Exception as e: print_exc(10) - message = f"Unable to perform transformation {transform.__class__}" + message = f"Unable to post_execute on plugin {plugin_hook.__class__}" logger.error(message, exc_info=e) raise QueryError(f"{message}: {str(e)}") if new_locations is not None: