Skip to content

Commit

Permalink
refactor plugin system to inject hooks at session start and query run
Browse files Browse the repository at this point in the history
  • Loading branch information
buremba committed Feb 7, 2025
1 parent cff9bc6 commit d894890
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 39 deletions.
36 changes: 24 additions & 12 deletions universql/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,17 @@ def executor(self) -> "Executor":

T = typing.TypeVar('T', bound=ICatalog)


def _track_call(method):
@functools.wraps(method)
def wrapper(self, *args, **kwargs):
# Mark that the method was called on this instance.
self._warm = True
return method(self, *args, **kwargs)

return wrapper


class Executor(typing.Protocol[T]):

def __init__(self, catalog: T):
Expand Down Expand Up @@ -85,26 +88,35 @@ def close(self):
pass


class UniversqlPlugin:
def __init__(self,
source_executor: Executor
):
self.source_executor = source_executor
class UQuery:
def __init__(self, ast: typing.Optional[List[sqlglot.exp.Expression]], raw_query: str):
self.ast = ast
self.raw_query = raw_query

def transform_sql(self, expression: Expression, target_executor: Executor) -> Expression:
return expression
def transform_ast(self, ast: sqlglot.exp.Expression, target_executor: Executor) -> Expression:
return ast

def post_execute(self, expression: Expression, locations : typing.Optional[Locations], target_executor: Executor):
def post_execute(self, locations: typing.Optional[Locations], target_executor: Executor):
pass

def pre_execute(self, expression: Expression, target_executor: Executor):
def end(self, table : pyarrow.Table):
pass


class UniversqlPlugin(ABC):
def __init__(self,
session: "universql.protocol.session.UniverSQLSession"
):
self.session = session

def start_query(self, ast: typing.Optional[List[sqlglot.exp.Expression]], raw_query : str) -> UQuery:
return UQuery(ast, raw_query)


# {"duckdb": DuckdbCatalog ..}
COMPUTES = {}
# [method]
TRANSFORMS = []
PLUGINS = []
# apps to be installed
APPS = []

Expand All @@ -123,7 +135,7 @@ def decorator(cls):
raise SystemError("name is required for catalogs")
COMPUTES[name] = cls
elif issubclass(cls, UniversqlPlugin) and cls is not UniversqlPlugin:
TRANSFORMS.append(cls)
PLUGINS.append(cls)
elif inspect.isfunction(cls):
signature = inspect.signature(cls)
if len(signature.parameters) == 1 and signature.parameters.values().__iter__().__next__().annotation is FastAPI:
Expand All @@ -132,4 +144,4 @@ def decorator(cls):
raise SystemError(f"Unknown type {cls}")
return cls

return decorator
return decorator
27 changes: 21 additions & 6 deletions universql/plugins/snow.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import typing
from typing import List

import pyarrow
import sqlglot
from sqlglot import Expression
from sqlglot.expressions import TableSample

from universql.plugin import UniversqlPlugin, register
from universql.plugin import UniversqlPlugin, register, UQuery, Locations
from universql.warehouse.duckdb import DuckDBExecutor
from universql.warehouse.snowflake import SnowflakeExecutor

Expand All @@ -12,12 +16,9 @@
# when FILES is specified:
# COPY INTO stg_device_metadata FROM 's3:/test/initial_objects/device_metadata.csv' (TYPE = CSV SKIP_HEADER = 1)
# COPY INTO stg_device_metadata FROM 's3:/test/initial_objects/file2.csv' (TYPE = CSV SKIP_HEADER = 1)
@register()
class SnowflakeStageUniversqlPlugin(UniversqlPlugin):
def __init__(self, source_executor: SnowflakeExecutor):
super().__init__(source_executor)

def transform_sql(self, expression: Expression, target_executor: DuckDBExecutor) -> Expression:
class StageTransformer(UQuery):
def transform_ast(self, expression: Expression, target_executor: DuckDBExecutor) -> Expression:
if isinstance(expression, sqlglot.exp.Var) and expression.name.startswith('@'):
expression.args['name'] = 'myname'
return expression
Expand All @@ -32,6 +33,20 @@ def _get_stage(self, table: sqlglot.exp.Table):
# self.source_executor.execute_raw("DESCRIBE STAGE {}", self.source_executor)
return

def post_execute(self, locations: typing.Optional[Locations], target_executor: DuckDBExecutor):
pass

def end(self, table : pyarrow.Table):
pass

@register()
class SnowflakeStageUniversqlPlugin(UniversqlPlugin):
def __init__(self, session: "universql.protocol.session.UniverSQLSession"):
super().__init__(session)

def start_query(self, ast: typing.Optional[List[sqlglot.exp.Expression]], raw_query: str) -> UQuery:
return StageTransformer(ast, raw_query)


# @register()
class TableSampleUniversqlPlugin(UniversqlPlugin):
Expand Down
52 changes: 31 additions & 21 deletions universql/protocol/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@
from universql.lake.cloud import CACHE_DIRECTORY_KEY, MAX_CACHE_SIZE
from universql.util import get_friendly_time_since, \
prepend_to_lines, parse_compute, QueryError, full_qualifier
from universql.plugin import Executor, Tables, ICatalog, COMPUTES, TRANSFORMS, UniversqlPlugin

from universql.plugin import Executor, Tables, ICatalog, COMPUTES, PLUGINS, UniversqlPlugin, UQuery

logger = logging.getLogger("💡")

Expand All @@ -42,7 +41,7 @@ def __init__(self, context, session_id, credentials: dict, session_parameters: d
self.computes = {"snowflake": self.catalog_executor}
self.processing = False
self.metadata_db = None
self.transforms : List[UniversqlPlugin] = [transform(self.catalog_executor) for transform in TRANSFORMS]
self.plugins : List[UniversqlPlugin] = [plugin(self) for plugin in PLUGINS]


def _get_iceberg_catalog(self):
Expand Down Expand Up @@ -111,8 +110,18 @@ def _do_query(self, start_time: float, raw_query: str) -> pyarrow.Table:

last_executor = None

plugin_hooks = []
for plugin in self.plugins:
try:
plugin_hooks.append(plugin.start_query(queries, raw_query))
except Exception as e:
print_exc(10)
message = f"Unable to call start_query on plugin {plugin.__class__}"
logger.error(message, exc_info=e)
raise QueryError(f"{message}: {str(e)}")

if queries is None:
last_executor = self.perform_query(self.catalog_executor, raw_query)
last_executor = self.perform_query(self.catalog_executor, raw_query, plugin_hooks)
else:
last_error = None
for ast in queries:
Expand All @@ -129,7 +138,7 @@ def _do_query(self, start_time: float, raw_query: str) -> pyarrow.Table:
last_executor = target_compute(self, compute).executor()
self.computes[compute_name] = last_executor
try:
last_executor = self.perform_query(last_executor, raw_query, ast=ast)
last_executor = self.perform_query(last_executor, raw_query, plugin_hooks, ast=ast)
break
except QueryError as e:
logger.warning(f"Unable to run query: {e.message}")
Expand All @@ -145,7 +154,16 @@ def _do_query(self, start_time: float, raw_query: str) -> pyarrow.Table:
f"[{self.session_id}] {last_executor.get_query_log(query_duration)} 🚀 "
f"({get_friendly_time_since(start_time, performance_counter)})")

return last_executor.get_as_table()
table = last_executor.get_as_table()
for hook in plugin_hooks:
try:
hook.end(table)
except Exception as e:
print_exc(10)
message = f"Unable to end query execution on plugin {hook.__class__}"
logger.error(message, exc_info=e)
raise QueryError(f"{message}: {str(e)}")
return table

def _find_tables(self, ast: sqlglot.exp.Expression, cte_aliases=None):
if cte_aliases is None:
Expand All @@ -159,7 +177,7 @@ def _find_tables(self, ast: sqlglot.exp.Expression, cte_aliases=None):
if expression.catalog or expression.db or str(expression.this.this) not in cte_aliases:
yield full_qualifier(expression, self.credentials), cte_aliases

def perform_query(self, alternative_executor: Executor, raw_query, ast=None) -> Executor:
def perform_query(self, alternative_executor: Executor, raw_query, plugin_hooks : List[UQuery], ast=None) -> Executor:
if ast is not None and alternative_executor != self.catalog_executor:
must_run_on_catalog = False
if isinstance(ast, Create):
Expand All @@ -180,29 +198,21 @@ def perform_query(self, alternative_executor: Executor, raw_query, ast=None) ->
locations = self.get_table_paths_from_catalog(alternative_executor.catalog, tables_list)
with sentry_sdk.start_span(op=op_name, name="Execute query"):
current_ast = ast
for transform in self.transforms:
try:
current_ast = transform.transform_sql(current_ast, alternative_executor)
except Exception as e:
print_exc(10)
message = f"Unable to perform transformation {transform.__class__}"
logger.error(message, exc_info=e)
raise QueryError(f"{message}: {str(e)}")
for transform in self.transforms:
for plugin_hook in plugin_hooks:
try:
transform.pre_execute(current_ast, alternative_executor)
current_ast = plugin_hook.transform_ast(ast, alternative_executor)
except Exception as e:
print_exc(10)
message = f"Unable to perform transformation {transform.__class__}"
message = f"Unable to tranform_ast on plugin {plugin_hook.__class__}"
logger.error(message, exc_info=e)
raise QueryError(f"{message}: {str(e)}")
new_locations = alternative_executor.execute(current_ast, self.catalog_executor, locations)
for transform in self.transforms:
for plugin_hook in plugin_hooks:
try:
transform.post_execute(current_ast, new_locations, alternative_executor)
plugin_hook.post_execute(new_locations, alternative_executor)
except Exception as e:
print_exc(10)
message = f"Unable to perform transformation {transform.__class__}"
message = f"Unable to post_execute on plugin {plugin_hook.__class__}"
logger.error(message, exc_info=e)
raise QueryError(f"{message}: {str(e)}")
if new_locations is not None:
Expand Down

0 comments on commit d894890

Please sign in to comment.