diff --git a/.circleci/continue_config.yml b/.circleci/continue_config.yml index 8f8324a2a0..e21f3d869b 100644 --- a/.circleci/continue_config.yml +++ b/.circleci/continue_config.yml @@ -239,7 +239,7 @@ jobs: - checkout - run: name: Install OS-level dependencies - command: ./.circleci/install-prerequisites.sh "<< parameters.engine >>" + command: ./.circleci/install-prerequisites.sh "<< parameters.engine >>" - run: name: Generate database name command: | @@ -307,7 +307,7 @@ workflows: - redshift - bigquery - clickhouse-cloud - - athena + - athena - fabric - gcp-postgres filters: diff --git a/docs/integrations/engines/snowflake.md b/docs/integrations/engines/snowflake.md index 30de0bfd14..fc2ccbd6bb 100644 --- a/docs/integrations/engines/snowflake.md +++ b/docs/integrations/engines/snowflake.md @@ -250,6 +250,14 @@ And confirm that our schemas and objects exist in the Snowflake catalog: Congratulations - your SQLMesh project is up and running on Snowflake! +### Where are the row counts? + +SQLMesh reports the number of rows processed by each model in its `plan` and `run` terminal output. + +However, due to limitations in the Snowflake Python connector, row counts cannot be determined for `CREATE TABLE AS` statements. Therefore, SQLMesh does not report row counts for certain model kinds, such as `FULL` models. + +Learn more about the connector limitation [on GitHub](https://github.com/snowflakedb/snowflake-connector-python/issues/645). 
+ ## Local/Built-in Scheduler **Engine Adapter Type**: `snowflake` diff --git a/pyproject.toml b/pyproject.toml index 9a4fd0632c..f371cdee0e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,6 +12,7 @@ dependencies = [ "croniter", "duckdb>=0.10.0,!=0.10.3", "dateparser<=1.2.1", + "humanize", "hyperscript>=0.1.0", "importlib-metadata; python_version<'3.12'", "ipywidgets", diff --git a/sqlmesh/core/console.py b/sqlmesh/core/console.py index e046e17630..0907b39987 100644 --- a/sqlmesh/core/console.py +++ b/sqlmesh/core/console.py @@ -7,6 +7,7 @@ import uuid import logging import textwrap +from humanize import metric, naturalsize from itertools import zip_longest from pathlib import Path from hyperscript import h @@ -39,6 +40,7 @@ SnapshotInfoLike, ) from sqlmesh.core.snapshot.definition import Interval, Intervals, SnapshotTableInfo +from sqlmesh.core.snapshot.execution_tracker import QueryExecutionStats from sqlmesh.core.test import ModelTest from sqlmesh.utils import rich as srich from sqlmesh.utils import Verbosity @@ -439,6 +441,7 @@ def update_snapshot_evaluation_progress( num_audits_passed: int, num_audits_failed: int, audit_only: bool = False, + execution_stats: t.Optional[QueryExecutionStats] = None, auto_restatement_triggers: t.Optional[t.List[SnapshotId]] = None, ) -> None: """Updates the snapshot evaluation progress.""" @@ -587,6 +590,7 @@ def update_snapshot_evaluation_progress( num_audits_passed: int, num_audits_failed: int, audit_only: bool = False, + execution_stats: t.Optional[QueryExecutionStats] = None, auto_restatement_triggers: t.Optional[t.List[SnapshotId]] = None, ) -> None: pass @@ -1032,7 +1036,9 @@ def start_evaluation_progress( # determine column widths self.evaluation_column_widths["annotation"] = ( - _calculate_annotation_str_len(batched_intervals, self.AUDIT_PADDING) + _calculate_annotation_str_len( + batched_intervals, self.AUDIT_PADDING, len(" (123.4m rows, 123.4 KiB)") + ) + 3 # brackets and opening escape backslash ) 
self.evaluation_column_widths["name"] = max( @@ -1077,6 +1083,7 @@ def update_snapshot_evaluation_progress( num_audits_passed: int, num_audits_failed: int, audit_only: bool = False, + execution_stats: t.Optional[QueryExecutionStats] = None, auto_restatement_triggers: t.Optional[t.List[SnapshotId]] = None, ) -> None: """Update the snapshot evaluation progress.""" @@ -1097,7 +1104,7 @@ def update_snapshot_evaluation_progress( ).ljust(self.evaluation_column_widths["name"]) annotation = _create_evaluation_model_annotation( - snapshot, _format_evaluation_model_interval(snapshot, interval) + snapshot, _format_evaluation_model_interval(snapshot, interval), execution_stats ) audits_str = "" if num_audits_passed: @@ -3668,6 +3675,7 @@ def update_snapshot_evaluation_progress( num_audits_passed: int, num_audits_failed: int, audit_only: bool = False, + execution_stats: t.Optional[QueryExecutionStats] = None, auto_restatement_triggers: t.Optional[t.List[SnapshotId]] = None, ) -> None: view_name, loaded_batches = self.evaluation_batch_progress[snapshot.snapshot_id] @@ -3838,6 +3846,7 @@ def update_snapshot_evaluation_progress( num_audits_passed: int, num_audits_failed: int, audit_only: bool = False, + execution_stats: t.Optional[QueryExecutionStats] = None, auto_restatement_triggers: t.Optional[t.List[SnapshotId]] = None, ) -> None: message = f"Evaluated {snapshot.name} | batch={batch_idx} | duration={duration_ms}ms | num_audits_passed={num_audits_passed} | num_audits_failed={num_audits_failed}" @@ -4169,33 +4178,62 @@ def _format_evaluation_model_interval(snapshot: Snapshot, interval: Interval) -> return "" -def _create_evaluation_model_annotation(snapshot: Snapshot, interval_info: t.Optional[str]) -> str: +def _create_evaluation_model_annotation( + snapshot: Snapshot, + interval_info: t.Optional[str], + execution_stats: t.Optional[QueryExecutionStats], +) -> str: + annotation = None + execution_stats_str = "" + if execution_stats: + rows_processed = 
execution_stats.total_rows_processed + if rows_processed: + # 1.00 and 1.0 to 1 + rows_processed_str = metric(rows_processed).replace(".00", "").replace(".0", "") + execution_stats_str += f"{rows_processed_str} row{'s' if rows_processed > 1 else ''}" + + bytes_processed = execution_stats.total_bytes_processed + execution_stats_str += ( + f"{', ' if execution_stats_str else ''}{naturalsize(bytes_processed, binary=True)}" + if bytes_processed + else "" + ) + execution_stats_str = f" ({execution_stats_str})" if execution_stats_str else "" + if snapshot.is_audit: - return "run standalone audit" - if snapshot.is_model and snapshot.model.kind.is_external: - return "run external audits" - if snapshot.model.kind.is_seed: - return "insert seed file" - if snapshot.model.kind.is_full: - return "full refresh" - if snapshot.model.kind.is_view: - return "recreate view" - if snapshot.model.kind.is_incremental_by_unique_key: - return "insert/update rows" - if snapshot.model.kind.is_incremental_by_partition: - return "insert partitions" - - return interval_info if interval_info else "" - - -def _calculate_interval_str_len(snapshot: Snapshot, intervals: t.List[Interval]) -> int: + annotation = "run standalone audit" + if snapshot.is_model: + if snapshot.model.kind.is_external: + annotation = "run external audits" + if snapshot.model.kind.is_view: + annotation = "recreate view" + if snapshot.model.kind.is_seed: + annotation = f"insert seed file{execution_stats_str}" + if snapshot.model.kind.is_full: + annotation = f"full refresh{execution_stats_str}" + if snapshot.model.kind.is_incremental_by_unique_key: + annotation = f"insert/update rows{execution_stats_str}" + if snapshot.model.kind.is_incremental_by_partition: + annotation = f"insert partitions{execution_stats_str}" + + if annotation: + return annotation + + return f"{interval_info}{execution_stats_str}" if interval_info else "" + + +def _calculate_interval_str_len( + snapshot: Snapshot, + intervals: t.List[Interval], + 
execution_stats: t.Optional[QueryExecutionStats] = None, +) -> int: interval_str_len = 0 for interval in intervals: interval_str_len = max( interval_str_len, len( _create_evaluation_model_annotation( - snapshot, _format_evaluation_model_interval(snapshot, interval) + snapshot, _format_evaluation_model_interval(snapshot, interval), execution_stats ) ), ) @@ -4248,13 +4286,16 @@ def _calculate_audit_str_len(snapshot: Snapshot, audit_padding: int = 0) -> int: def _calculate_annotation_str_len( - batched_intervals: t.Dict[Snapshot, t.List[Interval]], audit_padding: int = 0 + batched_intervals: t.Dict[Snapshot, t.List[Interval]], + audit_padding: int = 0, + execution_stats_len: int = 0, ) -> int: annotation_str_len = 0 for snapshot, intervals in batched_intervals.items(): annotation_str_len = max( annotation_str_len, _calculate_interval_str_len(snapshot, intervals) - + _calculate_audit_str_len(snapshot, audit_padding), + + _calculate_audit_str_len(snapshot, audit_padding) + + execution_stats_len, ) return annotation_str_len diff --git a/sqlmesh/core/engine_adapter/athena.py b/sqlmesh/core/engine_adapter/athena.py index 48b9e4ad4e..aa8a5ce0c1 100644 --- a/sqlmesh/core/engine_adapter/athena.py +++ b/sqlmesh/core/engine_adapter/athena.py @@ -45,6 +45,7 @@ class AthenaEngineAdapter(PandasNativeFetchDFSupportMixin, RowDiffMixin): # >>> self._execute('/* test */ DESCRIBE foo') # pyathena.error.OperationalError: FAILED: ParseException line 1:0 cannot recognize input near '/' '*' 'test' ATTACH_CORRELATION_ID = False + SUPPORTS_QUERY_EXECUTION_TRACKING = True SUPPORTED_DROP_CASCADE_OBJECT_KINDS = ["DATABASE", "SCHEMA"] def __init__( diff --git a/sqlmesh/core/engine_adapter/base.py b/sqlmesh/core/engine_adapter/base.py index fe19f7df0f..2901831940 100644 --- a/sqlmesh/core/engine_adapter/base.py +++ b/sqlmesh/core/engine_adapter/base.py @@ -40,6 +40,7 @@ ) from sqlmesh.core.model.kind import TimeColumn from sqlmesh.core.schema_diff import SchemaDiffer, TableAlterOperation +from 
sqlmesh.core.snapshot.execution_tracker import QueryExecutionTracker from sqlmesh.utils import ( CorrelationId, columns_to_types_all_known, @@ -117,6 +118,7 @@ class EngineAdapter: QUOTE_IDENTIFIERS_IN_VIEWS = True MAX_IDENTIFIER_LENGTH: t.Optional[int] = None ATTACH_CORRELATION_ID = True + SUPPORTS_QUERY_EXECUTION_TRACKING = False def __init__( self, @@ -133,6 +135,7 @@ def __init__( shared_connection: bool = False, correlation_id: t.Optional[CorrelationId] = None, schema_differ_overrides: t.Optional[t.Dict[str, t.Any]] = None, + query_execution_tracker: t.Optional[QueryExecutionTracker] = None, **kwargs: t.Any, ): self.dialect = dialect.lower() or self.DIALECT @@ -156,11 +159,16 @@ def __init__( self._multithreaded = multithreaded self.correlation_id = correlation_id self._schema_differ_overrides = schema_differ_overrides + self._query_execution_tracker = query_execution_tracker def with_settings(self, **kwargs: t.Any) -> EngineAdapter: extra_kwargs = { "null_connection": True, "execute_log_level": kwargs.pop("execute_log_level", self._execute_log_level), + "correlation_id": kwargs.pop("correlation_id", self.correlation_id), + "query_execution_tracker": kwargs.pop( + "query_execution_tracker", self._query_execution_tracker + ), **self._extra_config, **kwargs, } @@ -854,6 +862,7 @@ def _create_table_from_source_queries( table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, table_kind: t.Optional[str] = None, + track_rows_processed: bool = True, **kwargs: t.Any, ) -> None: table = exp.to_table(table_name) @@ -899,11 +908,15 @@ def _create_table_from_source_queries( replace=replace, table_description=table_description, table_kind=table_kind, + track_rows_processed=track_rows_processed, **kwargs, ) else: self._insert_append_query( - table_name, query, target_columns_to_types or self.columns(table) + table_name, + query, + target_columns_to_types or self.columns(table), + track_rows_processed=track_rows_processed, ) # 
Register comments with commands if the engine supports comments and we weren't able to @@ -927,6 +940,7 @@ def _create_table( table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, table_kind: t.Optional[str] = None, + track_rows_processed: bool = True, **kwargs: t.Any, ) -> None: self.execute( @@ -943,7 +957,8 @@ def _create_table( ), table_kind=table_kind, **kwargs, - ) + ), + track_rows_processed=track_rows_processed, ) def _build_create_table_exp( @@ -1431,6 +1446,7 @@ def insert_append( table_name: TableName, query_or_df: QueryOrDF, target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + track_rows_processed: bool = True, source_columns: t.Optional[t.List[str]] = None, ) -> None: source_queries, target_columns_to_types = self._get_source_queries_and_columns_to_types( @@ -1439,19 +1455,27 @@ def insert_append( target_table=table_name, source_columns=source_columns, ) - self._insert_append_source_queries(table_name, source_queries, target_columns_to_types) + self._insert_append_source_queries( + table_name, source_queries, target_columns_to_types, track_rows_processed + ) def _insert_append_source_queries( self, table_name: TableName, source_queries: t.List[SourceQuery], target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + track_rows_processed: bool = True, ) -> None: with self.transaction(condition=len(source_queries) > 0): target_columns_to_types = target_columns_to_types or self.columns(table_name) for source_query in source_queries: with source_query as query: - self._insert_append_query(table_name, query, target_columns_to_types) + self._insert_append_query( + table_name, + query, + target_columns_to_types, + track_rows_processed=track_rows_processed, + ) def _insert_append_query( self, @@ -1459,10 +1483,14 @@ def _insert_append_query( query: Query, target_columns_to_types: t.Dict[str, exp.DataType], order_projections: bool = True, + track_rows_processed: bool = True, ) -> 
None: if order_projections: query = self._order_projections_and_filter(query, target_columns_to_types) - self.execute(exp.insert(query, table_name, columns=list(target_columns_to_types))) + self.execute( + exp.insert(query, table_name, columns=list(target_columns_to_types)), + track_rows_processed=track_rows_processed, + ) def insert_overwrite_by_partition( self, @@ -1604,7 +1632,7 @@ def _insert_overwrite_by_condition( ) if insert_overwrite_strategy.is_replace_where: insert_exp.set("where", where or exp.true()) - self.execute(insert_exp) + self.execute(insert_exp, track_rows_processed=True) def update_table( self, @@ -1625,7 +1653,9 @@ def _merge( using = exp.alias_( exp.Subquery(this=query), alias=MERGE_SOURCE_ALIAS, copy=False, table=True ) - self.execute(exp.Merge(this=this, using=using, on=on, whens=whens)) + self.execute( + exp.Merge(this=this, using=using, on=on, whens=whens), track_rows_processed=True + ) def scd_type_2_by_time( self, @@ -2374,6 +2404,7 @@ def execute( expressions: t.Union[str, exp.Expression, t.Sequence[exp.Expression]], ignore_unsupported_errors: bool = False, quote_identifiers: bool = True, + track_rows_processed: bool = False, **kwargs: t.Any, ) -> None: """Execute a sql query.""" @@ -2395,7 +2426,7 @@ def execute( expression=e if isinstance(e, exp.Expression) else None, quote_identifiers=quote_identifiers, ) - self._execute(sql, **kwargs) + self._execute(sql, track_rows_processed, **kwargs) def _attach_correlation_id(self, sql: str) -> str: if self.ATTACH_CORRELATION_ID and self.correlation_id: @@ -2420,9 +2451,29 @@ def _log_sql( logger.log(self._execute_log_level, "Executing SQL: %s", sql_to_log) - def _execute(self, sql: str, **kwargs: t.Any) -> None: + def _record_execution_stats( + self, sql: str, rowcount: t.Optional[int] = None, bytes_processed: t.Optional[int] = None + ) -> None: + if self._query_execution_tracker: + self._query_execution_tracker.record_execution(sql, rowcount, bytes_processed) + + def _execute(self, sql: str, 
track_rows_processed: bool = False, **kwargs: t.Any) -> None: self.cursor.execute(sql, **kwargs) + if ( + self.SUPPORTS_QUERY_EXECUTION_TRACKING + and track_rows_processed + and self._query_execution_tracker + and self._query_execution_tracker.is_tracking() + ): + if ( + rowcount := getattr(self.cursor, "rowcount", None) + ) is not None and rowcount is not None: + try: + self._record_execution_stats(sql, int(rowcount)) + except (TypeError, ValueError): + return + @contextlib.contextmanager def temp_table( self, @@ -2467,6 +2518,7 @@ def temp_table( exists=True, table_description=None, column_descriptions=None, + track_rows_processed=False, **kwargs, ) @@ -2718,7 +2770,7 @@ def _replace_by_key( insert_statement.set("where", delete_filter) insert_statement.set("this", exp.to_table(target_table)) - self.execute(insert_statement) + self.execute(insert_statement, track_rows_processed=True) finally: self.drop_table(temp_table) diff --git a/sqlmesh/core/engine_adapter/base_postgres.py b/sqlmesh/core/engine_adapter/base_postgres.py index 26446aacfd..c6ba7d6d62 100644 --- a/sqlmesh/core/engine_adapter/base_postgres.py +++ b/sqlmesh/core/engine_adapter/base_postgres.py @@ -24,6 +24,7 @@ class BasePostgresEngineAdapter(EngineAdapter): DEFAULT_BATCH_SIZE = 400 COMMENT_CREATION_TABLE = CommentCreationTable.COMMENT_COMMAND_ONLY COMMENT_CREATION_VIEW = CommentCreationView.COMMENT_COMMAND_ONLY + SUPPORTS_QUERY_EXECUTION_TRACKING = True SUPPORTED_DROP_CASCADE_OBJECT_KINDS = ["SCHEMA", "TABLE", "VIEW"] def columns( diff --git a/sqlmesh/core/engine_adapter/bigquery.py b/sqlmesh/core/engine_adapter/bigquery.py index 4c8a125fa3..b3d02d8bbf 100644 --- a/sqlmesh/core/engine_adapter/bigquery.py +++ b/sqlmesh/core/engine_adapter/bigquery.py @@ -66,6 +66,7 @@ class BigQueryEngineAdapter(InsertOverwriteWithMergeMixin, ClusteredByMixin, Row SUPPORTS_CLONING = True MAX_TABLE_COMMENT_LENGTH = 1024 MAX_COLUMN_COMMENT_LENGTH = 1024 + SUPPORTS_QUERY_EXECUTION_TRACKING = True 
SUPPORTED_DROP_CASCADE_OBJECT_KINDS = ["SCHEMA"] SCHEMA_DIFFER_KWARGS = { @@ -1049,6 +1050,7 @@ def _db_call(self, func: t.Callable[..., t.Any], *args: t.Any, **kwargs: t.Any) def _execute( self, sql: str, + track_rows_processed: bool = False, **kwargs: t.Any, ) -> None: """Execute a sql query.""" @@ -1094,6 +1096,23 @@ def _execute( self.cursor._set_rowcount(query_results) self.cursor._set_description(query_results.schema) + if ( + track_rows_processed + and self._query_execution_tracker + and self._query_execution_tracker.is_tracking() + ): + num_rows = None + if query_job.statement_type == "CREATE_TABLE_AS_SELECT": + # since table was just created, number rows in table == number rows processed + query_table = self.client.get_table(query_job.destination) + num_rows = query_table.num_rows + elif query_job.statement_type in ["INSERT", "DELETE", "MERGE", "UPDATE"]: + num_rows = query_job.num_dml_affected_rows + + self._query_execution_tracker.record_execution( + sql, num_rows, query_job.total_bytes_processed + ) + def _get_data_objects( self, schema_name: SchemaName, object_names: t.Optional[t.Set[str]] = None ) -> t.List[DataObject]: diff --git a/sqlmesh/core/engine_adapter/clickhouse.py b/sqlmesh/core/engine_adapter/clickhouse.py index 635e6f369b..ccffe64118 100644 --- a/sqlmesh/core/engine_adapter/clickhouse.py +++ b/sqlmesh/core/engine_adapter/clickhouse.py @@ -294,7 +294,7 @@ def _insert_overwrite_by_condition( ) try: - self.execute(existing_records_insert_exp) + self.execute(existing_records_insert_exp, track_rows_processed=True) finally: if table_partition_exp: self.drop_table(partitions_temp_table_name) @@ -489,6 +489,7 @@ def _create_table( table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, table_kind: t.Optional[str] = None, + track_rows_processed: bool = True, **kwargs: t.Any, ) -> None: """Creates a table in the database. 
@@ -525,6 +526,7 @@ def _create_table( column_descriptions, table_kind, empty_ctas=(self.engine_run_mode.is_cloud and expression is not None), + track_rows_processed=track_rows_processed, **kwargs, ) diff --git a/sqlmesh/core/engine_adapter/duckdb.py b/sqlmesh/core/engine_adapter/duckdb.py index 4bce813610..3b057219e0 100644 --- a/sqlmesh/core/engine_adapter/duckdb.py +++ b/sqlmesh/core/engine_adapter/duckdb.py @@ -170,6 +170,7 @@ def _create_table( table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, table_kind: t.Optional[str] = None, + track_rows_processed: bool = True, **kwargs: t.Any, ) -> None: catalog = self.get_current_catalog() @@ -193,6 +194,7 @@ def _create_table( table_description, column_descriptions, table_kind, + track_rows_processed=track_rows_processed, **kwargs, ) diff --git a/sqlmesh/core/engine_adapter/mssql.py b/sqlmesh/core/engine_adapter/mssql.py index 6aefd51fc0..50a67b4b37 100644 --- a/sqlmesh/core/engine_adapter/mssql.py +++ b/sqlmesh/core/engine_adapter/mssql.py @@ -53,6 +53,7 @@ class MSSQLEngineAdapter( COMMENT_CREATION_TABLE = CommentCreationTable.UNSUPPORTED COMMENT_CREATION_VIEW = CommentCreationView.UNSUPPORTED SUPPORTS_REPLACE_TABLE = False + SUPPORTS_QUERY_EXECUTION_TRACKING = True SCHEMA_DIFFER_KWARGS = { "parameterized_type_defaults": { exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(18, 0), (0,)], diff --git a/sqlmesh/core/engine_adapter/mysql.py b/sqlmesh/core/engine_adapter/mysql.py index e81b30e25e..26cc7c0197 100644 --- a/sqlmesh/core/engine_adapter/mysql.py +++ b/sqlmesh/core/engine_adapter/mysql.py @@ -39,6 +39,7 @@ class MySQLEngineAdapter( MAX_COLUMN_COMMENT_LENGTH = 1024 SUPPORTS_REPLACE_TABLE = False MAX_IDENTIFIER_LENGTH = 64 + SUPPORTS_QUERY_EXECUTION_TRACKING = True SCHEMA_DIFFER_KWARGS = { "parameterized_type_defaults": { exp.DataType.build("BIT", dialect=DIALECT).this: [(1,)], diff --git a/sqlmesh/core/engine_adapter/postgres.py 
b/sqlmesh/core/engine_adapter/postgres.py index faeb52b207..e9c212bd5f 100644 --- a/sqlmesh/core/engine_adapter/postgres.py +++ b/sqlmesh/core/engine_adapter/postgres.py @@ -35,6 +35,7 @@ class PostgresEngineAdapter( CURRENT_CATALOG_EXPRESSION = exp.column("current_catalog") SUPPORTS_REPLACE_TABLE = False MAX_IDENTIFIER_LENGTH = 63 + SUPPORTS_QUERY_EXECUTION_TRACKING = True SCHEMA_DIFFER_KWARGS = { "parameterized_type_defaults": { # DECIMAL without precision is "up to 131072 digits before the decimal point; up to 16383 digits after the decimal point" diff --git a/sqlmesh/core/engine_adapter/redshift.py b/sqlmesh/core/engine_adapter/redshift.py index 30ebc8e30d..7d14207b52 100644 --- a/sqlmesh/core/engine_adapter/redshift.py +++ b/sqlmesh/core/engine_adapter/redshift.py @@ -173,6 +173,7 @@ def _create_table_from_source_queries( table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, table_kind: t.Optional[str] = None, + track_rows_processed: bool = True, **kwargs: t.Any, ) -> None: """ @@ -426,7 +427,8 @@ def resolve_target_table(expression: exp.Expression) -> exp.Expression: using=using, on=on.transform(resolve_target_table), whens=whens.transform(resolve_target_table), - ) + ), + track_rows_processed=True, ) def _normalize_decimal_value(self, expr: exp.Expression, precision: int) -> exp.Expression: diff --git a/sqlmesh/core/engine_adapter/snowflake.py b/sqlmesh/core/engine_adapter/snowflake.py index c5fa8540b0..8a6f5e2fcc 100644 --- a/sqlmesh/core/engine_adapter/snowflake.py +++ b/sqlmesh/core/engine_adapter/snowflake.py @@ -72,6 +72,7 @@ class SnowflakeEngineAdapter(GetCurrentCatalogFromFunctionMixin, ClusteredByMixi } MANAGED_TABLE_KIND = "DYNAMIC TABLE" SNOWPARK = "snowpark" + SUPPORTS_QUERY_EXECUTION_TRACKING = True @contextlib.contextmanager def session(self, properties: SessionProperties) -> t.Iterator[None]: @@ -166,6 +167,7 @@ def _create_table( table_description: t.Optional[str] = None, column_descriptions: 
t.Optional[t.Dict[str, str]] = None, table_kind: t.Optional[str] = None, + track_rows_processed: bool = True, **kwargs: t.Any, ) -> None: table_format = kwargs.get("table_format") @@ -185,6 +187,7 @@ def _create_table( table_description=table_description, column_descriptions=column_descriptions, table_kind=table_kind, + track_rows_processed=False, # snowflake tracks CTAS row counts incorrectly **kwargs, ) diff --git a/sqlmesh/core/engine_adapter/spark.py b/sqlmesh/core/engine_adapter/spark.py index 8a529390c1..412e01f5bb 100644 --- a/sqlmesh/core/engine_adapter/spark.py +++ b/sqlmesh/core/engine_adapter/spark.py @@ -433,6 +433,7 @@ def _create_table( table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, table_kind: t.Optional[str] = None, + track_rows_processed: bool = True, **kwargs: t.Any, ) -> None: table_name = ( @@ -461,6 +462,7 @@ def _create_table( target_columns_to_types=target_columns_to_types, table_description=table_description, column_descriptions=column_descriptions, + track_rows_processed=track_rows_processed, **kwargs, ) table_name = ( diff --git a/sqlmesh/core/engine_adapter/trino.py b/sqlmesh/core/engine_adapter/trino.py index fc08dd10af..4cef557d94 100644 --- a/sqlmesh/core/engine_adapter/trino.py +++ b/sqlmesh/core/engine_adapter/trino.py @@ -55,6 +55,7 @@ class TrinoEngineAdapter( SUPPORTED_DROP_CASCADE_OBJECT_KINDS = ["SCHEMA"] DEFAULT_CATALOG_TYPE = "hive" QUOTE_IDENTIFIERS_IN_VIEWS = False + SUPPORTS_QUERY_EXECUTION_TRACKING = True SCHEMA_DIFFER_KWARGS = { "parameterized_type_defaults": { # default decimal precision varies across backends @@ -357,6 +358,7 @@ def _create_table( table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, table_kind: t.Optional[str] = None, + track_rows_processed: bool = True, **kwargs: t.Any, ) -> None: super()._create_table( @@ -368,6 +370,7 @@ def _create_table( table_description=table_description, 
column_descriptions=column_descriptions, table_kind=table_kind, + track_rows_processed=track_rows_processed, **kwargs, ) diff --git a/sqlmesh/core/scheduler.py b/sqlmesh/core/scheduler.py index 7a653877ae..210aff230d 100644 --- a/sqlmesh/core/scheduler.py +++ b/sqlmesh/core/scheduler.py @@ -20,6 +20,7 @@ DeployabilityIndex, Snapshot, SnapshotId, + SnapshotIdBatch, SnapshotEvaluator, apply_auto_restatements, earliest_start_date, @@ -531,6 +532,11 @@ def run_node(node: SchedulingUnit) -> None: finally: num_audits = len(audit_results) num_audits_failed = sum(1 for result in audit_results if result.count) + + execution_stats = self.snapshot_evaluator.execution_tracker.get_execution_stats( + SnapshotIdBatch(snapshot_id=snapshot.snapshot_id, batch_id=node.batch_index) + ) + self.console.update_snapshot_evaluation_progress( snapshot, batched_intervals[snapshot][node.batch_index], @@ -538,6 +544,7 @@ def run_node(node: SchedulingUnit) -> None: evaluation_duration_ms, num_audits - num_audits_failed, num_audits_failed, + execution_stats=execution_stats, auto_restatement_triggers=auto_restatement_triggers.get( snapshot.snapshot_id ), diff --git a/sqlmesh/core/snapshot/__init__.py b/sqlmesh/core/snapshot/__init__.py index da44278aa8..8ad574f8ac 100644 --- a/sqlmesh/core/snapshot/__init__.py +++ b/sqlmesh/core/snapshot/__init__.py @@ -8,6 +8,7 @@ SnapshotDataVersion as SnapshotDataVersion, SnapshotFingerprint as SnapshotFingerprint, SnapshotId as SnapshotId, + SnapshotIdBatch as SnapshotIdBatch, SnapshotIdLike as SnapshotIdLike, SnapshotInfoLike as SnapshotInfoLike, SnapshotIntervals as SnapshotIntervals, diff --git a/sqlmesh/core/snapshot/definition.py b/sqlmesh/core/snapshot/definition.py index 1a286edcfc..afc8e06458 100644 --- a/sqlmesh/core/snapshot/definition.py +++ b/sqlmesh/core/snapshot/definition.py @@ -13,10 +13,13 @@ from sqlglot import exp from sqlglot.optimizer.normalize_identifiers import normalize_identifiers -from sqlmesh.core.config.common import 
class SnapshotIdBatch(PydanticModel, frozen=True):
    """Pairs a snapshot id with a batch index to identify a single evaluation batch."""

    # The snapshot the batch belongs to.
    snapshot_id: SnapshotId
    # Index of the batch within that snapshot's evaluation.
    batch_id: int
from __future__ import annotations

import typing as t
from contextlib import contextmanager
from dataclasses import dataclass, field
from threading import local

if t.TYPE_CHECKING:
    # Needed for annotations only; annotations are lazy in this module.
    from sqlmesh.core.snapshot import SnapshotIdBatch


@dataclass
class QueryExecutionStats:
    """Accumulated statistics for one snapshot evaluation batch."""

    snapshot_id_batch: SnapshotIdBatch
    # Both totals stay None until a first value has been recorded.
    total_rows_processed: t.Optional[int] = None
    total_bytes_processed: t.Optional[int] = None


@dataclass
class QueryExecutionContext:
    """
    Container for tracking rows processed or other execution information during snapshot evaluation.

    It accumulates statistics from multiple cursor.execute() calls during a single snapshot evaluation.

    Attributes:
        snapshot_id_batch: Identifier linking this context to a specific snapshot evaluation
        stats: Running sum of cursor.rowcount and possibly bytes processed from all executed queries during evaluation
    """

    snapshot_id_batch: SnapshotIdBatch
    stats: QueryExecutionStats = field(init=False)

    def __post_init__(self) -> None:
        self.stats = QueryExecutionStats(snapshot_id_batch=self.snapshot_id_batch)

    def add_execution(
        self, sql: str, row_count: t.Optional[int], bytes_processed: t.Optional[int]
    ) -> None:
        """Add one statement's row count (and its bytes) to the running totals.

        Statements with a missing or negative row count are ignored entirely.
        """
        if row_count is None or row_count < 0:
            return

        stats = self.stats
        stats.total_rows_processed = (stats.total_rows_processed or 0) + row_count

        # conditional on row_count because we should only count bytes corresponding to
        # DML actions whose rows were captured
        if bytes_processed is not None:
            stats.total_bytes_processed = (stats.total_bytes_processed or 0) + bytes_processed

    def get_execution_stats(self) -> QueryExecutionStats:
        """Return the statistics accumulated so far."""
        return self.stats


class QueryExecutionTracker:
    """Thread-local context manager for snapshot execution statistics, such as rows processed."""

    def __init__(self) -> None:
        self._thread_local = local()
        self._contexts: t.Dict[SnapshotIdBatch, QueryExecutionContext] = {}

    def get_execution_context(
        self, snapshot_id_batch: SnapshotIdBatch
    ) -> t.Optional[QueryExecutionContext]:
        """Return the stored context for `snapshot_id_batch`, or None if there is none."""
        return self._contexts.get(snapshot_id_batch)

    def is_tracking(self) -> bool:
        """Return True while the calling thread is inside `track_execution`."""
        return getattr(self._thread_local, "context", None) is not None

    @contextmanager
    def track_execution(
        self, snapshot_id_batch: SnapshotIdBatch
    ) -> t.Iterator[t.Optional[QueryExecutionContext]]:
        """Context manager for tracking snapshot execution statistics such as row counts and bytes processed."""
        context = QueryExecutionContext(snapshot_id_batch=snapshot_id_batch)
        previous = getattr(self._thread_local, "context", None)
        self._thread_local.context = context
        self._contexts[snapshot_id_batch] = context

        try:
            yield context
        finally:
            # Restore the enclosing context (None when not nested) so that leaving a
            # nested scope does not stop tracking for the outer one.
            self._thread_local.context = previous

    def record_execution(
        self, sql: str, row_count: t.Optional[int], bytes_processed: t.Optional[int]
    ) -> None:
        """Record a statement against the calling thread's active context; no-op otherwise."""
        context = getattr(self._thread_local, "context", None)
        if context is not None:
            context.add_execution(sql, row_count, bytes_processed)

    def get_execution_stats(
        self, snapshot_id_batch: SnapshotIdBatch
    ) -> t.Optional[QueryExecutionStats]:
        """Remove the stored context for `snapshot_id_batch` and return its statistics.

        Returns None when no context is stored, including on a second call for the same key.
        """
        context = self._contexts.pop(snapshot_id_batch, None)
        return context.get_execution_stats() if context else None
a/sqlmesh/core/state_sync/db/interval.py +++ b/sqlmesh/core/state_sync/db/interval.py @@ -115,6 +115,7 @@ def remove_intervals( self.intervals_table, _intervals_to_df(intervals_to_remove, is_dev=False, is_removed=True), target_columns_to_types=self._interval_columns_to_types, + track_rows_processed=False, ) def get_snapshot_intervals( @@ -243,6 +244,7 @@ def _push_snapshot_intervals( self.intervals_table, pd.DataFrame(new_intervals), target_columns_to_types=self._interval_columns_to_types, + track_rows_processed=False, ) def _get_snapshot_intervals( diff --git a/sqlmesh/core/state_sync/db/migrator.py b/sqlmesh/core/state_sync/db/migrator.py index ca89668763..7edd7de3c4 100644 --- a/sqlmesh/core/state_sync/db/migrator.py +++ b/sqlmesh/core/state_sync/db/migrator.py @@ -413,7 +413,9 @@ def _backup_state(self) -> None: backup_name = _backup_table_name(table) self.engine_adapter.drop_table(backup_name) self.engine_adapter.create_table_like(backup_name, table) - self.engine_adapter.insert_append(backup_name, exp.select("*").from_(table)) + self.engine_adapter.insert_append( + backup_name, exp.select("*").from_(table), track_rows_processed=False + ) def _restore_table( self, diff --git a/sqlmesh/core/state_sync/db/snapshot.py b/sqlmesh/core/state_sync/db/snapshot.py index 9cf4f2fbf5..8d504993fc 100644 --- a/sqlmesh/core/state_sync/db/snapshot.py +++ b/sqlmesh/core/state_sync/db/snapshot.py @@ -103,6 +103,7 @@ def push_snapshots(self, snapshots: t.Iterable[Snapshot], overwrite: bool = Fals self.snapshots_table, _snapshots_to_df(snapshots_to_store), target_columns_to_types=self._snapshot_columns_to_types, + track_rows_processed=False, ) for snapshot in snapshots: @@ -406,6 +407,7 @@ def _push_snapshots(self, snapshots: t.Iterable[Snapshot]) -> None: self.snapshots_table, _snapshots_to_df(snapshots_to_store), target_columns_to_types=self._snapshot_columns_to_types, + track_rows_processed=False, ) def _get_snapshots( diff --git a/sqlmesh/core/state_sync/db/version.py 
b/sqlmesh/core/state_sync/db/version.py index 492d74cc09..c95592bc31 100644 --- a/sqlmesh/core/state_sync/db/version.py +++ b/sqlmesh/core/state_sync/db/version.py @@ -55,6 +55,7 @@ def update_versions( ] ), target_columns_to_types=self._version_columns_to_types, + track_rows_processed=False, ) def get_versions(self) -> Versions: diff --git a/tests/core/engine_adapter/integration/test_integration.py b/tests/core/engine_adapter/integration/test_integration.py index 1960848e24..19a45329d5 100644 --- a/tests/core/engine_adapter/integration/test_integration.py +++ b/tests/core/engine_adapter/integration/test_integration.py @@ -6,12 +6,13 @@ import sys import typing as t import shutil -from datetime import datetime, timedelta - +from datetime import datetime, timedelta, date +from unittest.mock import patch import numpy as np # noqa: TID253 import pandas as pd # noqa: TID253 import pytest import pytz +import time_machine from sqlglot import exp, parse_one from sqlglot.optimizer.normalize_identifiers import normalize_identifiers from sqlglot.optimizer.qualify_columns import quote_identifiers @@ -2382,8 +2383,30 @@ def _mutate_config(gateway: str, config: Config): ) context._models.update(replacement_models) + # capture row counts for each evaluated snapshot + actual_execution_stats = {} + + def capture_execution_stats( + snapshot, + interval, + batch_idx, + duration_ms, + num_audits_passed, + num_audits_failed, + audit_only=False, + execution_stats=None, + auto_restatement_triggers=None, + ): + if execution_stats is not None: + actual_execution_stats[snapshot.model.name.replace(f"{schema_name}.", "")] = ( + execution_stats + ) + # apply prod plan - context.plan(auto_apply=True, no_prompts=True) + with patch.object( + context.console, "update_snapshot_evaluation_progress", capture_execution_stats + ): + context.plan(auto_apply=True, no_prompts=True) prod_schema_results = ctx.get_metadata_results(object_names["view_schema"][0]) assert sorted(prod_schema_results.views) == 
object_names["views"] @@ -2395,6 +2418,34 @@ def _mutate_config(gateway: str, config: Config): assert len(physical_layer_results.materialized_views) == 0 assert len(physical_layer_results.tables) == len(physical_layer_results.non_temp_tables) == 3 + if ctx.engine_adapter.SUPPORTS_QUERY_EXECUTION_TRACKING: + assert actual_execution_stats["incremental_model"].total_rows_processed == 7 + # snowflake doesn't track rows for CTAS + assert actual_execution_stats["full_model"].total_rows_processed == ( + None if ctx.mark.startswith("snowflake") else 3 + ) + assert actual_execution_stats["seed_model"].total_rows_processed == ( + None if ctx.mark.startswith("snowflake") else 7 + ) + + if ctx.mark.startswith("bigquery"): + assert actual_execution_stats["incremental_model"].total_bytes_processed + assert actual_execution_stats["full_model"].total_bytes_processed + + # run that loads 0 rows in incremental model + # - some cloud DBs error because time travel messes up token expiration + if not ctx.is_remote: + actual_execution_stats = {} + with patch.object( + context.console, "update_snapshot_evaluation_progress", capture_execution_stats + ): + with time_machine.travel(date.today() + timedelta(days=1)): + context.run() + + if ctx.engine_adapter.SUPPORTS_QUERY_EXECUTION_TRACKING: + assert actual_execution_stats["incremental_model"].total_rows_processed == 0 + assert actual_execution_stats["full_model"].total_rows_processed == 3 + # make and validate unmodified dev environment no_change_plan: Plan = context.plan_builder( environment="test_dev", diff --git a/tests/core/engine_adapter/integration/test_integration_snowflake.py b/tests/core/engine_adapter/integration/test_integration_snowflake.py index 01cbe1c0aa..aed6bf83e4 100644 --- a/tests/core/engine_adapter/integration/test_integration_snowflake.py +++ b/tests/core/engine_adapter/integration/test_integration_snowflake.py @@ -12,6 +12,12 @@ from sqlmesh.core.plan import Plan from tests.core.engine_adapter.integration import 
def test_execution_tracker_thread_isolation() -> None:
    """Row counts recorded on one worker thread must not leak into another's stats."""
    execution_tracker = QueryExecutionTracker()

    # name -> (row counts to record, expected total for that batch)
    workloads = {
        "batch_A": ([10, 5], 15),
        "batch_B": ([3, 7], 10),
    }

    def make_batch(name: str) -> SnapshotIdBatch:
        return SnapshotIdBatch(snapshot_id=SnapshotId(name=name, identifier=name), batch_id=0)

    def worker(snapshot_id: SnapshotId, row_counts: list[int]) -> QueryExecutionStats:
        batch = SnapshotIdBatch(snapshot_id=snapshot_id, batch_id=0)
        with execution_tracker.track_execution(batch) as ctx:
            assert execution_tracker.is_tracking()

            for count in row_counts:
                execution_tracker.record_execution("SELECT 1", count, None)

            assert ctx is not None
            return ctx.get_execution_stats()

    with ThreadPoolExecutor() as executor:
        futures = [
            executor.submit(worker, SnapshotId(name=name, identifier=name), counts)
            for name, (counts, _) in workloads.items()
        ]
        results = [future.result() for future in futures]

    # Main thread has no active tracking context
    assert not execution_tracker.is_tracking()

    # Key the results by batch so the checks do not rely on result ordering
    by_batch = {stats.snapshot_id_batch: stats for stats in results}

    for name, (_, expected_total) in workloads.items():
        assert by_batch[make_batch(name)].total_rows_processed == expected_total
'other' as other_column, @@ -8344,7 +8344,7 @@ def test_incremental_by_unique_key_model_ignore_additive_change(tmp_path: Path): cron '@daily' ); - SELECT + SELECT *, 1 as id, 'test_name' as name, @@ -8389,7 +8389,7 @@ def test_incremental_by_unique_key_model_ignore_additive_change(tmp_path: Path): cron '@daily' ); - SELECT + SELECT *, 2 as id, 3 as new_column, @@ -8566,7 +8566,7 @@ def test_incremental_unmanaged_model_ignore_additive_change(tmp_path: Path): cron '@daily' ); - SELECT + SELECT *, 1 as id, 'test_name' as name, @@ -8610,7 +8610,7 @@ def test_incremental_unmanaged_model_ignore_additive_change(tmp_path: Path): ); SELECT - *, + *, 2 as id, 3 as new_column, @start_ds as ds @@ -9020,7 +9020,7 @@ def test_scd_type_2_by_column_ignore_additive_change(tmp_path: Path): cron '@daily' ); - SELECT + SELECT *, 1 as id, 'test_name' as name, @@ -9067,7 +9067,7 @@ def test_scd_type_2_by_column_ignore_additive_change(tmp_path: Path): ); SELECT - *, + *, 1 as id, 'stable2' as stable, 3 as new_column, @@ -9247,7 +9247,7 @@ def test_incremental_partition_ignore_additive_change(tmp_path: Path): cron '@daily' ); - SELECT + SELECT *, 1 as id, 'test_name' as name, @@ -9292,7 +9292,7 @@ def test_incremental_partition_ignore_additive_change(tmp_path: Path): ); SELECT - *, + *, 1 as id, 3 as new_column, @start_ds as ds @@ -9526,7 +9526,7 @@ def test_incremental_by_time_model_ignore_additive_change_unit_test(tmp_path: Pa cron '@daily' ); - SELECT + SELECT id, name, ds @@ -9595,7 +9595,7 @@ def test_incremental_by_time_model_ignore_additive_change_unit_test(tmp_path: Pa cron '@daily' ); - SELECT + SELECT id, new_column, ds diff --git a/tests/core/test_snapshot_evaluator.py b/tests/core/test_snapshot_evaluator.py index 53f9bd425a..3b72a14f5f 100644 --- a/tests/core/test_snapshot_evaluator.py +++ b/tests/core/test_snapshot_evaluator.py @@ -115,6 +115,7 @@ def mock_exit(self, exc_type, exc_value, traceback): adapter_mock.HAS_VIEW_BINDING = False adapter_mock.wap_supported.return_value 
= False adapter_mock.get_data_objects.return_value = [] + adapter_mock.with_settings.return_value = adapter_mock return adapter_mock @@ -137,6 +138,7 @@ def adapters(mocker: MockerFixture): adapter_mock.HAS_VIEW_BINDING = False adapter_mock.wap_supported.return_value = False adapter_mock.get_data_objects.return_value = [] + adapter_mock.with_settings.return_value = adapter_mock adapters.append(adapter_mock) return adapters @@ -652,6 +654,7 @@ def test_evaluate_materialized_view_with_partitioned_by_cluster_by( adapter.table_exists = lambda *args, **kwargs: False # type: ignore adapter.get_data_objects = lambda *args, **kwargs: [] # type: ignore adapter._execute = execute_mock # type: ignore + adapter.with_settings = lambda **kwargs: adapter # type: ignore evaluator = SnapshotEvaluator(adapter) model = SqlModel( @@ -676,7 +679,8 @@ def test_evaluate_materialized_view_with_partitioned_by_cluster_by( execute_mock.assert_has_calls( [ call( - f"CREATE MATERIALIZED VIEW `sqlmesh__test_schema`.`test_schema__test_model__{snapshot.version}` PARTITION BY `a` CLUSTER BY `b` AS SELECT `a` AS `a`, `b` AS `b` FROM `tbl` AS `tbl`" + f"CREATE MATERIALIZED VIEW `sqlmesh__test_schema`.`test_schema__test_model__{snapshot.version}` PARTITION BY `a` CLUSTER BY `b` AS SELECT `a` AS `a`, `b` AS `b` FROM `tbl` AS `tbl`", + False, ), ] ) @@ -991,6 +995,7 @@ def test_create_tables_exist( ): adapter_mock = mocker.patch("sqlmesh.core.engine_adapter.EngineAdapter") adapter_mock.dialect = "duckdb" + adapter_mock.with_settings.return_value = adapter_mock evaluator = SnapshotEvaluator(adapter_mock) snapshot.categorize_as(category=snapshot_category, forward_only=forward_only) @@ -1193,6 +1198,7 @@ def test_create_view_with_properties(mocker: MockerFixture, adapter_mock, make_s def test_promote_model_info(mocker: MockerFixture, make_snapshot): adapter_mock = mocker.patch("sqlmesh.core.engine_adapter.EngineAdapter") adapter_mock.dialect = "duckdb" + adapter_mock.with_settings.return_value = 
adapter_mock evaluator = SnapshotEvaluator(adapter_mock) @@ -1221,6 +1227,7 @@ def test_promote_model_info(mocker: MockerFixture, make_snapshot): def test_promote_deployable(mocker: MockerFixture, make_snapshot): adapter_mock = mocker.patch("sqlmesh.core.engine_adapter.EngineAdapter") adapter_mock.dialect = "duckdb" + adapter_mock.with_settings.return_value = adapter_mock evaluator = SnapshotEvaluator(adapter_mock) @@ -1266,6 +1273,7 @@ def test_promote_deployable(mocker: MockerFixture, make_snapshot): def test_migrate(mocker: MockerFixture, make_snapshot, make_mocked_engine_adapter): adapter = make_mocked_engine_adapter(EngineAdapter) + adapter.with_settings = lambda **kwargs: adapter # type: ignore session_spy = mocker.spy(adapter, "session") current_table = "sqlmesh__test_schema.test_schema__test_model__1" @@ -1321,6 +1329,7 @@ def columns(table_name): def test_migrate_missing_table(mocker: MockerFixture, make_snapshot, make_mocked_engine_adapter): adapter = make_mocked_engine_adapter(EngineAdapter) adapter.table_exists = lambda _: False # type: ignore + adapter.with_settings = lambda **kwargs: adapter # type: ignore mocker.patch.object(adapter, "get_data_object", return_value=None) evaluator = SnapshotEvaluator(adapter) @@ -1389,6 +1398,7 @@ def test_migrate_snapshot_data_object_type_mismatch( make_mocked_engine_adapter, ): adapter = make_mocked_engine_adapter(EngineAdapter) + adapter.with_settings = lambda **kwargs: adapter # type: ignore mocker.patch.object( adapter, "get_data_object", @@ -1803,7 +1813,7 @@ def test_on_destructive_change_runtime_check( make_mocked_engine_adapter, ): adapter = make_mocked_engine_adapter(EngineAdapter) - + adapter.with_settings = lambda **kwargs: adapter # type: ignore current_table = "sqlmesh__test_schema.test_schema__test_model__1" def columns(table_name): @@ -1885,7 +1895,7 @@ def test_on_additive_change_runtime_check( make_mocked_engine_adapter, ): adapter = make_mocked_engine_adapter(EngineAdapter) - + 
adapter.with_settings = lambda **kwargs: adapter # type: ignore current_table = "sqlmesh__test_schema.test_schema__test_model__1" def columns(table_name): @@ -3777,6 +3787,7 @@ def test_create_managed_forward_only_with_previous_version_doesnt_clone_for_dev_ def test_migrate_snapshot(snapshot: Snapshot, mocker: MockerFixture, adapter_mock, make_snapshot): adapter_mock = mocker.patch("sqlmesh.core.engine_adapter.EngineAdapter") adapter_mock.dialect = "duckdb" + adapter_mock.with_settings.return_value = adapter_mock evaluator = SnapshotEvaluator(adapter_mock) evaluator.create([snapshot], {}) @@ -3986,6 +3997,7 @@ def test_multiple_engine_promotion(mocker: MockerFixture, adapter_mock, make_sna cursor_mock = mocker.Mock() connection_mock.cursor.return_value = cursor_mock adapter = EngineAdapter(lambda: connection_mock, "") + adapter.with_settings = lambda **kwargs: adapter # type: ignore engine_adapters = {"default": adapter_mock, "secondary": adapter} def columns(table_name): @@ -4045,7 +4057,9 @@ def test_multiple_engine_migration( mocker: MockerFixture, adapter_mock, make_snapshot, make_mocked_engine_adapter ): adapter_one = make_mocked_engine_adapter(EngineAdapter) + adapter_one.with_settings = lambda **kwargs: adapter_one # type: ignore adapter_two = adapter_mock + adapter_two.with_settings.return_value = adapter_two engine_adapters = {"one": adapter_one, "two": adapter_two} current_table = "sqlmesh__test_schema.test_schema__test_model__1" diff --git a/tests/core/test_table_diff.py b/tests/core/test_table_diff.py index b2848676b4..73fd37a2f7 100644 --- a/tests/core/test_table_diff.py +++ b/tests/core/test_table_diff.py @@ -360,11 +360,11 @@ def test_generated_sql(sushi_context_fixed_date: Context, mocker: MockerFixture) temp_schema="sqlmesh_temp_test", ) - spy_execute.assert_any_call(query_sql) - spy_execute.assert_any_call(summary_query_sql) - spy_execute.assert_any_call(compare_sql) - spy_execute.assert_any_call(sample_query_sql) - 
spy_execute.assert_any_call(drop_sql) + spy_execute.assert_any_call(query_sql, False) + spy_execute.assert_any_call(summary_query_sql, False) + spy_execute.assert_any_call(compare_sql, False) + spy_execute.assert_any_call(sample_query_sql, False) + spy_execute.assert_any_call(drop_sql, False) spy_execute.reset_mock() @@ -378,7 +378,7 @@ def test_generated_sql(sushi_context_fixed_date: Context, mocker: MockerFixture) ) query_sql_where = 'CREATE TABLE IF NOT EXISTS "memory"."sqlmesh_temp"."__temp_diff_abcdefgh" AS WITH "__source" AS (SELECT "s"."key", "s"."value", "s"."key" AS "__sqlmesh_join_key" FROM "table_diff_source" AS "s" WHERE "s"."key" = 2), "__target" AS (SELECT "t"."key", "t"."value", "t"."key" AS "__sqlmesh_join_key" FROM "table_diff_target" AS "t" WHERE "t"."key" = 2), "__stats" AS (SELECT "s"."key" AS "s__key", "s"."value" AS "s__value", "s"."__sqlmesh_join_key" AS "s____sqlmesh_join_key", "t"."key" AS "t__key", "t"."value" AS "t__value", "t"."__sqlmesh_join_key" AS "t____sqlmesh_join_key", CASE WHEN NOT "s"."__sqlmesh_join_key" IS NULL THEN 1 ELSE 0 END AS "s_exists", CASE WHEN NOT "t"."__sqlmesh_join_key" IS NULL THEN 1 ELSE 0 END AS "t_exists", CASE WHEN "s"."__sqlmesh_join_key" = "t"."__sqlmesh_join_key" THEN 1 ELSE 0 END AS "row_joined", CASE WHEN "s"."key" IS NULL AND "t"."key" IS NULL THEN 1 ELSE 0 END AS "null_grain", CASE WHEN "s"."key" = "t"."key" THEN 1 WHEN ("s"."key" IS NULL) AND ("t"."key" IS NULL) THEN 1 WHEN ("s"."key" IS NULL) OR ("t"."key" IS NULL) THEN 0 ELSE 0 END AS "key_matches", CASE WHEN ROUND("s"."value", 3) = ROUND("t"."value", 3) THEN 1 WHEN ("s"."value" IS NULL) AND ("t"."value" IS NULL) THEN 1 WHEN ("s"."value" IS NULL) OR ("t"."value" IS NULL) THEN 0 ELSE 0 END AS "value_matches" FROM "__source" AS "s" FULL JOIN "__target" AS "t" ON "s"."__sqlmesh_join_key" = "t"."__sqlmesh_join_key") SELECT *, CASE WHEN "key_matches" = 1 AND "value_matches" = 1 THEN 1 ELSE 0 END AS "row_full_match" FROM "__stats"' - 
spy_execute.assert_any_call(query_sql_where) + spy_execute.assert_any_call(query_sql_where, False) def test_tables_and_grain_inferred_from_model(sushi_context_fixed_date: Context): diff --git a/tests/core/test_test.py b/tests/core/test_test.py index 9c3c3aba4b..1b5425068f 100644 --- a/tests/core/test_test.py +++ b/tests/core/test_test.py @@ -874,7 +874,8 @@ def test_partially_inferred_schemas(sushi_context: Context, mocker: MockerFixtur 'CAST("s" AS STRUCT("d" DATE)) AS "s", ' 'CAST("a" AS INT) AS "a", ' 'CAST("b" AS TEXT) AS "b" ' - """FROM (VALUES ({'d': CAST('2020-01-01' AS DATE)}, 1, 'bla')) AS "t"("s", "a", "b")""" + """FROM (VALUES ({'d': CAST('2020-01-01' AS DATE)}, 1, 'bla')) AS "t"("s", "a", "b")""", + False, ) @@ -1329,14 +1330,15 @@ def test_freeze_time(mocker: MockerFixture) -> None: spy_execute.assert_has_calls( [ - call('CREATE SCHEMA IF NOT EXISTS "memory"."sqlmesh_test_jzngz56a"'), + call('CREATE SCHEMA IF NOT EXISTS "memory"."sqlmesh_test_jzngz56a"', False), call( "SELECT " """CAST('2023-01-01 12:05:03+00:00' AS DATE) AS "cur_date", """ """CAST('2023-01-01 12:05:03+00:00' AS TIME) AS "cur_time", """ - '''CAST('2023-01-01 12:05:03+00:00' AS TIMESTAMP) AS "cur_timestamp"''' + '''CAST('2023-01-01 12:05:03+00:00' AS TIMESTAMP) AS "cur_timestamp"''', + False, ), - call('DROP SCHEMA IF EXISTS "memory"."sqlmesh_test_jzngz56a" CASCADE'), + call('DROP SCHEMA IF EXISTS "memory"."sqlmesh_test_jzngz56a" CASCADE', False), ] ) @@ -1361,7 +1363,12 @@ def test_freeze_time(mocker: MockerFixture) -> None: _check_successful_or_raise(test.run()) spy_execute.assert_has_calls( - [call('''SELECT CAST('2023-01-01 12:05:03+00:00' AS TIMESTAMPTZ) AS "cur_timestamp"''')] + [ + call( + '''SELECT CAST('2023-01-01 12:05:03+00:00' AS TIMESTAMPTZ) AS "cur_timestamp"''', + False, + ) + ] ) @model("py_model", columns={"ts1": "timestamptz", "ts2": "timestamptz"}) @@ -1496,7 +1503,7 @@ def test_gateway(copy_to_temp_path: t.Callable, mocker: MockerFixture) -> None: 'AS "t"("id", 
"customer_id", "waiter_id", "start_ts", "end_ts", "event_date")' ) test_adapter = t.cast(ModelTest, result.successes[0]).engine_adapter - assert call(test_adapter, expected_view_sql) in spy_execute.mock_calls + assert call(test_adapter, expected_view_sql, False) in spy_execute.mock_calls _check_successful_or_raise(context.test()) @@ -1621,7 +1628,8 @@ def test_generate_input_data_using_sql(mocker: MockerFixture, tmp_path: Path) -> spy_execute.assert_any_call( 'CREATE OR REPLACE VIEW "memory"."sqlmesh_test_jzngz56a"."foo" AS ' - '''SELECT {'x': 1, 'n': {'y': 2}} AS "struct_value"''' + '''SELECT {'x': 1, 'n': {'y': 2}} AS "struct_value"''', + False, ) with pytest.raises( @@ -1817,9 +1825,9 @@ def test_custom_testing_schema(mocker: MockerFixture) -> None: spy_execute.assert_has_calls( [ - call('CREATE SCHEMA IF NOT EXISTS "memory"."my_schema"'), - call('SELECT 1 AS "a"'), - call('DROP SCHEMA IF EXISTS "memory"."my_schema" CASCADE'), + call('CREATE SCHEMA IF NOT EXISTS "memory"."my_schema"', False), + call('SELECT 1 AS "a"', False), + call('DROP SCHEMA IF EXISTS "memory"."my_schema" CASCADE', False), ] ) @@ -1845,9 +1853,9 @@ def test_pretty_query(mocker: MockerFixture) -> None: _check_successful_or_raise(test.run()) spy_execute.assert_has_calls( [ - call('CREATE SCHEMA IF NOT EXISTS "memory"."my_schema"'), - call('SELECT\n 1 AS "a"'), - call('DROP SCHEMA IF EXISTS "memory"."my_schema" CASCADE'), + call('CREATE SCHEMA IF NOT EXISTS "memory"."my_schema"', False), + call('SELECT\n 1 AS "a"', False), + call('DROP SCHEMA IF EXISTS "memory"."my_schema" CASCADE', False), ] ) @@ -2950,7 +2958,7 @@ def test_parameterized_name_sql_model() -> None: outputs: query: - id: 1 - name: foo + name: foo """, variables=variables, ), @@ -2999,7 +3007,7 @@ def execute( outputs: query: - id: 1 - name: foo + name: foo """, variables=variables, ), @@ -3049,7 +3057,7 @@ def test_parameterized_name_self_referential_model(): v: int outputs: query: - - v: 1 + - v: 1 """, variables=variables, ), 
@@ -3171,7 +3179,7 @@ def execute( - id: 5 outputs: query: - - id: 8 + - id: 8 """, variables=variables, ), diff --git a/web/server/console.py b/web/server/console.py index 902a85418c..871aaefbb1 100644 --- a/web/server/console.py +++ b/web/server/console.py @@ -10,6 +10,7 @@ from sqlmesh.core.environment import EnvironmentNamingInfo from sqlmesh.core.plan.definition import EvaluatablePlan from sqlmesh.core.snapshot import Snapshot, SnapshotInfoLike, SnapshotTableInfo, SnapshotId +from sqlmesh.core.snapshot.execution_tracker import QueryExecutionStats from sqlmesh.core.test import ModelTest from sqlmesh.core.test.result import ModelTextTestResult from sqlmesh.utils.date import now_timestamp @@ -142,6 +143,7 @@ def update_snapshot_evaluation_progress( num_audits_passed: int, num_audits_failed: int, audit_only: bool = False, + execution_stats: t.Optional[QueryExecutionStats] = None, auto_restatement_triggers: t.Optional[t.List[SnapshotId]] = None, ) -> None: if audit_only: