From c9991a520d15af96e44cb7d1af56c6219c914277 Mon Sep 17 00:00:00 2001 From: Trey Spiller Date: Thu, 14 Aug 2025 14:05:40 -0500 Subject: [PATCH 01/31] Add rows tracking --- sqlmesh/core/console.py | 77 +++++--- sqlmesh/core/engine_adapter/base.py | 51 ++++-- sqlmesh/core/engine_adapter/bigquery.py | 5 + sqlmesh/core/engine_adapter/clickhouse.py | 4 +- sqlmesh/core/engine_adapter/duckdb.py | 2 + sqlmesh/core/engine_adapter/redshift.py | 4 +- sqlmesh/core/engine_adapter/snowflake.py | 2 + sqlmesh/core/engine_adapter/spark.py | 2 + sqlmesh/core/engine_adapter/trino.py | 2 + sqlmesh/core/execution_tracker.py | 164 ++++++++++++++++++ sqlmesh/core/scheduler.py | 130 ++++++++------ sqlmesh/core/snapshot/evaluator.py | 1 + sqlmesh/core/state_sync/db/environment.py | 2 + sqlmesh/core/state_sync/db/interval.py | 2 + sqlmesh/core/state_sync/db/migrator.py | 4 +- sqlmesh/core/state_sync/db/snapshot.py | 2 + sqlmesh/core/state_sync/db/version.py | 1 + .../integration/test_integration.py | 26 ++- tests/core/test_execution_tracker.py | 74 ++++++++ web/server/console.py | 1 + 20 files changed, 461 insertions(+), 95 deletions(-) create mode 100644 sqlmesh/core/execution_tracker.py create mode 100644 tests/core/test_execution_tracker.py diff --git a/sqlmesh/core/console.py b/sqlmesh/core/console.py index e046e17630..bfe2a31264 100644 --- a/sqlmesh/core/console.py +++ b/sqlmesh/core/console.py @@ -439,6 +439,7 @@ def update_snapshot_evaluation_progress( num_audits_passed: int, num_audits_failed: int, audit_only: bool = False, + rows_processed: t.Optional[int] = None, auto_restatement_triggers: t.Optional[t.List[SnapshotId]] = None, ) -> None: """Updates the snapshot evaluation progress.""" @@ -587,6 +588,7 @@ def update_snapshot_evaluation_progress( num_audits_passed: int, num_audits_failed: int, audit_only: bool = False, + rows_processed: t.Optional[int] = None, auto_restatement_triggers: t.Optional[t.List[SnapshotId]] = None, ) -> None: pass @@ -1032,7 +1034,9 @@ def start_evaluation_progress( # determine column widths self.evaluation_column_widths["annotation"] = ( - _calculate_annotation_str_len(batched_intervals, self.AUDIT_PADDING) + _calculate_annotation_str_len( + batched_intervals, self.AUDIT_PADDING, len(" (XXXXXX rows processed)") + ) + 3 # brackets and opening escape backslash ) self.evaluation_column_widths["name"] = max( @@ -1077,6 +1081,7 @@ def update_snapshot_evaluation_progress( num_audits_passed: int, num_audits_failed: int, audit_only: bool = False, + rows_processed: t.Optional[int] = None, auto_restatement_triggers: t.Optional[t.List[SnapshotId]] = None, ) -> None: """Update the snapshot evaluation progress.""" @@ -1097,7 +1102,7 @@ def update_snapshot_evaluation_progress( ).ljust(self.evaluation_column_widths["name"]) annotation = _create_evaluation_model_annotation( - snapshot, _format_evaluation_model_interval(snapshot, interval) + snapshot, _format_evaluation_model_interval(snapshot, interval), rows_processed ) audits_str = "" if num_audits_passed: @@ -3668,6 +3673,7 @@ def update_snapshot_evaluation_progress( num_audits_passed: int, num_audits_failed: int, audit_only: bool = False, + rows_processed: t.Optional[int] = None, auto_restatement_triggers: t.Optional[t.List[SnapshotId]] = None, ) -> None: view_name, loaded_batches = self.evaluation_batch_progress[snapshot.snapshot_id] @@ -3838,6 +3844,7 @@ def update_snapshot_evaluation_progress( num_audits_passed: int, num_audits_failed: int, audit_only: bool = False, + rows_processed: t.Optional[int] = None, auto_restatement_triggers: 
t.Optional[t.List[SnapshotId]] = None, ) -> None: message = f"Evaluated {snapshot.name} | batch={batch_idx} | duration={duration_ms}ms | num_audits_passed={num_audits_passed} | num_audits_failed={num_audits_failed}" @@ -4022,7 +4029,8 @@ def show_table_diff_summary(self, table_diff: TableDiff) -> None: self._write(f"Join On: {keys}") -_CONSOLE: Console = NoopConsole() +# _CONSOLE: Console = NoopConsole() +_CONSOLE: Console = TerminalConsole() def set_console(console: Console) -> None: @@ -4169,33 +4177,49 @@ def _format_evaluation_model_interval(snapshot: Snapshot, interval: Interval) -> return "" -def _create_evaluation_model_annotation(snapshot: Snapshot, interval_info: t.Optional[str]) -> str: +def _create_evaluation_model_annotation( + snapshot: Snapshot, interval_info: t.Optional[str], rows_processed: t.Optional[int] +) -> str: + annotation = None + num_rows_processed = str(rows_processed) if rows_processed else "" + rows_processed_str = f" ({num_rows_processed} rows processed)" if num_rows_processed else "" + if snapshot.is_audit: - return "run standalone audit" - if snapshot.is_model and snapshot.model.kind.is_external: - return "run external audits" - if snapshot.model.kind.is_seed: - return "insert seed file" - if snapshot.model.kind.is_full: - return "full refresh" - if snapshot.model.kind.is_view: - return "recreate view" - if snapshot.model.kind.is_incremental_by_unique_key: - return "insert/update rows" - if snapshot.model.kind.is_incremental_by_partition: - return "insert partitions" - - return interval_info if interval_info else "" - - -def _calculate_interval_str_len(snapshot: Snapshot, intervals: t.List[Interval]) -> int: + annotation = "run standalone audit" + if snapshot.is_model: + if snapshot.model.kind.is_external: + annotation = "run external audits" + if snapshot.model.kind.is_view: + annotation = "recreate view" + if snapshot.model.kind.is_seed: + # no "processed" for seeds + seed_num_rows_inserted = ( + f" ({num_rows_processed} rows inserted)" if num_rows_processed else "" + ) + annotation = f"insert seed file{seed_num_rows_inserted}" + if snapshot.model.kind.is_full: + annotation = f"full refresh{rows_processed_str}" + if snapshot.model.kind.is_incremental_by_unique_key: + annotation = f"insert/update rows{rows_processed_str}" + if snapshot.model.kind.is_incremental_by_partition: + annotation = f"insert partitions{rows_processed_str}" + + if annotation: + return annotation + + return f"{interval_info}{rows_processed_str}" if interval_info else "" + + +def _calculate_interval_str_len( + snapshot: Snapshot, intervals: t.List[Interval], rows_processed: t.Optional[int] = None +) -> int: interval_str_len = 0 for interval in intervals: interval_str_len = max( interval_str_len, len( _create_evaluation_model_annotation( - snapshot, _format_evaluation_model_interval(snapshot, interval) + snapshot, _format_evaluation_model_interval(snapshot, interval), rows_processed ) ), ) @@ -4248,13 +4272,16 @@ def _calculate_audit_str_len(snapshot: Snapshot, audit_padding: int = 0) -> int: def _calculate_annotation_str_len( - batched_intervals: t.Dict[Snapshot, t.List[Interval]], audit_padding: int = 0 + batched_intervals: t.Dict[Snapshot, t.List[Interval]], + audit_padding: int = 0, + rows_processed_len: int = 0, ) -> int: annotation_str_len = 0 for snapshot, intervals in batched_intervals.items(): annotation_str_len = max( annotation_str_len, _calculate_interval_str_len(snapshot, intervals) - + _calculate_audit_str_len(snapshot, audit_padding), + + _calculate_audit_str_len(snapshot, 
audit_padding) + + rows_processed_len, ) return annotation_str_len diff --git a/sqlmesh/core/engine_adapter/base.py b/sqlmesh/core/engine_adapter/base.py index fe19f7df0f..a3224b4a47 100644 --- a/sqlmesh/core/engine_adapter/base.py +++ b/sqlmesh/core/engine_adapter/base.py @@ -40,6 +40,7 @@ ) from sqlmesh.core.model.kind import TimeColumn from sqlmesh.core.schema_diff import SchemaDiffer, TableAlterOperation +from sqlmesh.core.execution_tracker import record_execution as track_execution_record from sqlmesh.utils import ( CorrelationId, columns_to_types_all_known, @@ -854,6 +855,7 @@ def _create_table_from_source_queries( table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, table_kind: t.Optional[str] = None, + track_row_count: bool = True, **kwargs: t.Any, ) -> None: table = exp.to_table(table_name) @@ -899,11 +901,15 @@ def _create_table_from_source_queries( replace=replace, table_description=table_description, table_kind=table_kind, + track_row_count=track_row_count, **kwargs, ) else: self._insert_append_query( - table_name, query, target_columns_to_types or self.columns(table) + table_name, + query, + target_columns_to_types or self.columns(table), + track_row_count=track_row_count, ) # Register comments with commands if the engine supports comments and we weren't able to @@ -927,6 +933,7 @@ def _create_table( table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, table_kind: t.Optional[str] = None, + track_row_count: bool = True, **kwargs: t.Any, ) -> None: self.execute( @@ -943,7 +950,8 @@ def _create_table( ), table_kind=table_kind, **kwargs, - ) + ), + track_row_count=track_row_count, ) def _build_create_table_exp( @@ -1431,6 +1439,7 @@ def insert_append( table_name: TableName, query_or_df: QueryOrDF, target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + track_row_count: bool = True, source_columns: t.Optional[t.List[str]] = None, ) -> None: source_queries, target_columns_to_types = self._get_source_queries_and_columns_to_types( @@ -1439,19 +1448,24 @@ def insert_append( target_table=table_name, source_columns=source_columns, ) - self._insert_append_source_queries(table_name, source_queries, target_columns_to_types) + self._insert_append_source_queries( + table_name, source_queries, target_columns_to_types, track_row_count + ) def _insert_append_source_queries( self, table_name: TableName, source_queries: t.List[SourceQuery], target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + track_row_count: bool = True, ) -> None: with self.transaction(condition=len(source_queries) > 0): target_columns_to_types = target_columns_to_types or self.columns(table_name) for source_query in source_queries: with source_query as query: - self._insert_append_query(table_name, query, target_columns_to_types) + self._insert_append_query( + table_name, query, target_columns_to_types, track_row_count=track_row_count + ) def _insert_append_query( self, @@ -1459,10 +1473,14 @@ def _insert_append_query( query: Query, target_columns_to_types: t.Dict[str, exp.DataType], order_projections: bool = True, + track_row_count: bool = True, ) -> None: if order_projections: query = self._order_projections_and_filter(query, target_columns_to_types) - self.execute(exp.insert(query, table_name, columns=list(target_columns_to_types))) + self.execute( + exp.insert(query, table_name, columns=list(target_columns_to_types)), + track_row_count=track_row_count, + ) def insert_overwrite_by_partition( 
self, @@ -1604,7 +1622,7 @@ def _insert_overwrite_by_condition( ) if insert_overwrite_strategy.is_replace_where: insert_exp.set("where", where or exp.true()) - self.execute(insert_exp) + self.execute(insert_exp, track_row_count=True) def update_table( self, @@ -1625,7 +1643,7 @@ def _merge( using = exp.alias_( exp.Subquery(this=query), alias=MERGE_SOURCE_ALIAS, copy=False, table=True ) - self.execute(exp.Merge(this=this, using=using, on=on, whens=whens)) + self.execute(exp.Merge(this=this, using=using, on=on, whens=whens), track_row_count=True) def scd_type_2_by_time( self, @@ -2374,6 +2392,7 @@ def execute( expressions: t.Union[str, exp.Expression, t.Sequence[exp.Expression]], ignore_unsupported_errors: bool = False, quote_identifiers: bool = True, + track_row_count: bool = False, **kwargs: t.Any, ) -> None: """Execute a sql query.""" @@ -2395,7 +2414,7 @@ def execute( expression=e if isinstance(e, exp.Expression) else None, quote_identifiers=quote_identifiers, ) - self._execute(sql, **kwargs) + self._execute(sql, track_row_count, **kwargs) def _attach_correlation_id(self, sql: str) -> str: if self.ATTACH_CORRELATION_ID and self.correlation_id: @@ -2420,9 +2439,20 @@ def _log_sql( logger.log(self._execute_log_level, "Executing SQL: %s", sql_to_log) - def _execute(self, sql: str, **kwargs: t.Any) -> None: + def _execute(self, sql: str, track_row_count: bool = False, **kwargs: t.Any) -> None: self.cursor.execute(sql, **kwargs) + if track_row_count: + rowcount_raw = getattr(self.cursor, "rowcount", None) + rowcount = None + if rowcount_raw is not None: + try: + rowcount = int(rowcount_raw) + except (TypeError, ValueError): + pass + + track_execution_record(sql, rowcount) + @contextlib.contextmanager def temp_table( self, @@ -2467,6 +2497,7 @@ def temp_table( exists=True, table_description=None, column_descriptions=None, + track_row_count=False, **kwargs, ) @@ -2718,7 +2749,7 @@ def _replace_by_key( insert_statement.set("where", delete_filter) insert_statement.set("this", exp.to_table(target_table)) - self.execute(insert_statement) + self.execute(insert_statement, track_row_count=True) finally: self.drop_table(temp_table) diff --git a/sqlmesh/core/engine_adapter/bigquery.py b/sqlmesh/core/engine_adapter/bigquery.py index 4c8a125fa3..3212acbe7c 100644 --- a/sqlmesh/core/engine_adapter/bigquery.py +++ b/sqlmesh/core/engine_adapter/bigquery.py @@ -21,6 +21,7 @@ SourceQuery, set_catalog, ) +from sqlmesh.core.execution_tracker import record_execution as track_execution_record from sqlmesh.core.node import IntervalUnit from sqlmesh.core.schema_diff import TableAlterOperation, NestedSupport from sqlmesh.utils import optional_import, get_source_columns_to_types @@ -1049,6 +1050,7 @@ def _db_call(self, func: t.Callable[..., t.Any], *args: t.Any, **kwargs: t.Any) def _execute( self, sql: str, + track_row_count: bool = False, **kwargs: t.Any, ) -> None: """Execute a sql query.""" @@ -1094,6 +1096,9 @@ def _execute( self.cursor._set_rowcount(query_results) self.cursor._set_description(query_results.schema) + if track_row_count: + track_execution_record(sql, query_results.total_rows) + def _get_data_objects( self, schema_name: SchemaName, object_names: t.Optional[t.Set[str]] = None ) -> t.List[DataObject]: diff --git a/sqlmesh/core/engine_adapter/clickhouse.py b/sqlmesh/core/engine_adapter/clickhouse.py index 635e6f369b..7458e4887c 100644 --- a/sqlmesh/core/engine_adapter/clickhouse.py +++ b/sqlmesh/core/engine_adapter/clickhouse.py @@ -294,7 +294,7 @@ def _insert_overwrite_by_condition( ) try: - 
self.execute(existing_records_insert_exp) + self.execute(existing_records_insert_exp, track_row_count=True) finally: if table_partition_exp: self.drop_table(partitions_temp_table_name) @@ -489,6 +489,7 @@ def _create_table( table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, table_kind: t.Optional[str] = None, + track_row_count: bool = True, **kwargs: t.Any, ) -> None: """Creates a table in the database. @@ -525,6 +526,7 @@ def _create_table( column_descriptions, table_kind, empty_ctas=(self.engine_run_mode.is_cloud and expression is not None), + track_row_count=track_row_count, **kwargs, ) diff --git a/sqlmesh/core/engine_adapter/duckdb.py b/sqlmesh/core/engine_adapter/duckdb.py index 4bce813610..f08aebdf06 100644 --- a/sqlmesh/core/engine_adapter/duckdb.py +++ b/sqlmesh/core/engine_adapter/duckdb.py @@ -170,6 +170,7 @@ def _create_table( table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, table_kind: t.Optional[str] = None, + track_row_count: bool = True, **kwargs: t.Any, ) -> None: catalog = self.get_current_catalog() @@ -193,6 +194,7 @@ def _create_table( table_description, column_descriptions, table_kind, + track_row_count=track_row_count, **kwargs, ) diff --git a/sqlmesh/core/engine_adapter/redshift.py b/sqlmesh/core/engine_adapter/redshift.py index 30ebc8e30d..fa9575820e 100644 --- a/sqlmesh/core/engine_adapter/redshift.py +++ b/sqlmesh/core/engine_adapter/redshift.py @@ -173,6 +173,7 @@ def _create_table_from_source_queries( table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, table_kind: t.Optional[str] = None, + track_row_count: bool = True, **kwargs: t.Any, ) -> None: """ @@ -426,7 +427,8 @@ def resolve_target_table(expression: exp.Expression) -> exp.Expression: using=using, on=on.transform(resolve_target_table), whens=whens.transform(resolve_target_table), - ) + ), + track_row_count=True, ) def _normalize_decimal_value(self, expr: exp.Expression, precision: int) -> exp.Expression: diff --git a/sqlmesh/core/engine_adapter/snowflake.py b/sqlmesh/core/engine_adapter/snowflake.py index c5fa8540b0..ad04645185 100644 --- a/sqlmesh/core/engine_adapter/snowflake.py +++ b/sqlmesh/core/engine_adapter/snowflake.py @@ -166,6 +166,7 @@ def _create_table( table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, table_kind: t.Optional[str] = None, + track_row_count: bool = True, **kwargs: t.Any, ) -> None: table_format = kwargs.get("table_format") @@ -185,6 +186,7 @@ def _create_table( table_description=table_description, column_descriptions=column_descriptions, table_kind=table_kind, + track_row_count=track_row_count, **kwargs, ) diff --git a/sqlmesh/core/engine_adapter/spark.py b/sqlmesh/core/engine_adapter/spark.py index 8a529390c1..99ff623301 100644 --- a/sqlmesh/core/engine_adapter/spark.py +++ b/sqlmesh/core/engine_adapter/spark.py @@ -433,6 +433,7 @@ def _create_table( table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, table_kind: t.Optional[str] = None, + track_row_count: bool = True, **kwargs: t.Any, ) -> None: table_name = ( @@ -461,6 +462,7 @@ def _create_table( target_columns_to_types=target_columns_to_types, table_description=table_description, column_descriptions=column_descriptions, + track_row_count=track_row_count, **kwargs, ) table_name = ( diff --git a/sqlmesh/core/engine_adapter/trino.py b/sqlmesh/core/engine_adapter/trino.py index 
fc08dd10af..936680baab 100644 --- a/sqlmesh/core/engine_adapter/trino.py +++ b/sqlmesh/core/engine_adapter/trino.py @@ -357,6 +357,7 @@ def _create_table( table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, table_kind: t.Optional[str] = None, + track_row_count: bool = True, **kwargs: t.Any, ) -> None: super()._create_table( @@ -368,6 +369,7 @@ def _create_table( table_description=table_description, column_descriptions=column_descriptions, table_kind=table_kind, + track_row_count=track_row_count, **kwargs, ) diff --git a/sqlmesh/core/execution_tracker.py b/sqlmesh/core/execution_tracker.py new file mode 100644 index 0000000000..24d1a8af7c --- /dev/null +++ b/sqlmesh/core/execution_tracker.py @@ -0,0 +1,164 @@ +from __future__ import annotations + +import time +import typing as t +from contextlib import contextmanager +from threading import get_ident, Lock +from dataclasses import dataclass, field + + +@dataclass +class QueryExecutionContext: + """ + Container for tracking rows processed or other execution information during snapshot evaluation. + + It accumulates statistics from multiple cursor.execute() calls during a single snapshot evaluation. + + Attributes: + id: Identifier linking this context to a specific operation + total_rows_processed: Running sum of cursor.rowcount from all executed queries during evaluation + query_count: Total number of SQL statements executed + queries_executed: List of (sql_snippet, row_count, timestamp) tuples for debugging + """ + + id: str + total_rows_processed: int = 0 + query_count: int = 0 + queries_executed: t.List[t.Tuple[str, t.Optional[int], float]] = field(default_factory=list) + + def add_execution(self, sql: str, row_count: t.Optional[int]) -> None: + """Record a single query execution.""" + if row_count is not None and row_count >= 0: + self.total_rows_processed += row_count + self.query_count += 1 + # for debugging + self.queries_executed.append((sql[:300], row_count, time.time())) + + def get_execution_stats(self) -> t.Dict[str, t.Any]: + return { + "id": self.id, + "total_rows_processed": self.total_rows_processed, + "query_count": self.query_count, + "queries": self.queries_executed, + } + + +class QueryExecutionTracker: + """ + Thread-local context manager for snapshot evaluation execution statistics, such as + rows processed. + """ + + _thread_contexts: t.Dict[int, QueryExecutionContext] = {} + _contexts_lock = Lock() + + @classmethod + def get_execution_context(cls) -> t.Optional[QueryExecutionContext]: + thread_id = get_ident() + with cls._contexts_lock: + return cls._thread_contexts.get(thread_id) + + @classmethod + def is_tracking(cls) -> bool: + return cls.get_execution_context() is not None + + @classmethod + @contextmanager + def track_execution(cls, snapshot_name_batch: str) -> t.Iterator[QueryExecutionContext]: + """ + Context manager for tracking snapshot evaluation execution statistics. 
+ """ + context = QueryExecutionContext(id=snapshot_name_batch) + thread_id = get_ident() + + with cls._contexts_lock: + cls._thread_contexts[thread_id] = context + try: + yield context + finally: + with cls._contexts_lock: + cls._thread_contexts.pop(thread_id, None) + + @classmethod + def record_execution(cls, sql: str, row_count: t.Optional[int]) -> None: + thread_id = get_ident() + with cls._contexts_lock: + context = cls._thread_contexts.get(thread_id) + if context is not None: + context.add_execution(sql, row_count) + + @classmethod + def get_execution_stats(cls) -> t.Optional[t.Dict[str, t.Any]]: + context = cls.get_execution_context() + return context.get_execution_stats() if context else None + + +class SeedExecutionTracker: + _seed_contexts: t.Dict[str, QueryExecutionContext] = {} + _active_threads: t.Set[int] = set() + _thread_to_seed_id: t.Dict[int, str] = {} + _seed_contexts_lock = Lock() + + @classmethod + @contextmanager + def track_execution(cls, model_name: str) -> t.Iterator[QueryExecutionContext]: + """ + Context manager for tracking seed creation execution statistics. + """ + + context = QueryExecutionContext(id=model_name) + thread_id = get_ident() + + with cls._seed_contexts_lock: + cls._seed_contexts[model_name] = context + cls._active_threads.add(thread_id) + cls._thread_to_seed_id[thread_id] = model_name + + try: + yield context + finally: + with cls._seed_contexts_lock: + cls._active_threads.discard(thread_id) + cls._thread_to_seed_id.pop(thread_id, None) + + @classmethod + def get_and_clear_seed_stats(cls, model_name: str) -> t.Optional[t.Dict[str, t.Any]]: + with cls._seed_contexts_lock: + context = cls._seed_contexts.pop(model_name, None) + return context.get_execution_stats() if context else None + + @classmethod + def clear_all_seed_stats(cls) -> None: + """Clear all remaining seed stats. Used for cleanup after evaluation completes.""" + with cls._seed_contexts_lock: + cls._seed_contexts.clear() + + @classmethod + def is_tracking(cls) -> bool: + thread_id = get_ident() + with cls._seed_contexts_lock: + return thread_id in cls._active_threads + + @classmethod + def record_execution(cls, sql: str, row_count: t.Optional[int]) -> None: + thread_id = get_ident() + with cls._seed_contexts_lock: + seed_id = cls._thread_to_seed_id.get(thread_id) + if not seed_id: + return + context = cls._seed_contexts.get(seed_id) + if context is not None: + context.add_execution(sql, row_count) + + +def record_execution(sql: str, row_count: t.Optional[int]) -> None: + """ + Record execution statistics for a single SQL statement. + + Automatically infers which tracker is active based on the current thread. 
+ """ + if SeedExecutionTracker.is_tracking(): + SeedExecutionTracker.record_execution(sql, row_count) + return + if QueryExecutionTracker.is_tracking(): + QueryExecutionTracker.record_execution(sql, row_count) diff --git a/sqlmesh/core/scheduler.py b/sqlmesh/core/scheduler.py index 7a653877ae..61fcad3eee 100644 --- a/sqlmesh/core/scheduler.py +++ b/sqlmesh/core/scheduler.py @@ -9,6 +9,7 @@ from sqlmesh.core import constants as c from sqlmesh.core.console import Console, get_console from sqlmesh.core.environment import EnvironmentNamingInfo, execute_environment_statements +from sqlmesh.core.execution_tracker import QueryExecutionTracker, SeedExecutionTracker from sqlmesh.core.macros import RuntimeStage from sqlmesh.core.model.definition import AuditResult from sqlmesh.core.node import IntervalUnit @@ -490,65 +491,84 @@ def run_node(node: SchedulingUnit) -> None: if isinstance(node, DummyNode): return - snapshot = self.snapshots_by_name[node.snapshot_name] - - if isinstance(node, EvaluateNode): - self.console.start_snapshot_evaluation_progress(snapshot) - execution_start_ts = now_timestamp() - evaluation_duration_ms: t.Optional[int] = None - start, end = node.interval - - audit_results: t.List[AuditResult] = [] - try: - assert execution_time # mypy - assert deployability_index # mypy - - if audit_only: - audit_results = self._audit_snapshot( - snapshot=snapshot, - environment_naming_info=environment_naming_info, - deployability_index=deployability_index, - snapshots=self.snapshots_by_name, - start=start, - end=end, - execution_time=execution_time, - ) - else: - audit_results = self.evaluate( - snapshot=snapshot, - environment_naming_info=environment_naming_info, - start=start, - end=end, - execution_time=execution_time, - deployability_index=deployability_index, - batch_index=node.batch_index, - allow_destructive_snapshots=allow_destructive_snapshots, - allow_additive_snapshots=allow_additive_snapshots, + with QueryExecutionTracker.track_execution( + f"{snapshot.name}_{batch_idx}" + ) as execution_context: + snapshot = self.snapshots_by_name[node.snapshot_name] + + if isinstance(node, EvaluateNode): + self.console.start_snapshot_evaluation_progress(snapshot) + execution_start_ts = now_timestamp() + evaluation_duration_ms: t.Optional[int] = None + start, end = node.interval + + audit_results: t.List[AuditResult] = [] + try: + assert execution_time # mypy + assert deployability_index # mypy + + if audit_only: + audit_results = self._audit_snapshot( + snapshot=snapshot, + environment_naming_info=environment_naming_info, + deployability_index=deployability_index, + snapshots=self.snapshots_by_name, + start=start, + end=end, + execution_time=execution_time, + ) + else: + audit_results = self.evaluate( + snapshot=snapshot, + environment_naming_info=environment_naming_info, + start=start, + end=end, + execution_time=execution_time, + deployability_index=deployability_index, + batch_index=node.batch_index, + allow_destructive_snapshots=allow_destructive_snapshots, + allow_additive_snapshots=allow_additive_snapshots, target_table_exists=snapshot.snapshot_id not in snapshots_to_create, - ) - - evaluation_duration_ms = now_timestamp() - execution_start_ts - finally: - num_audits = len(audit_results) - num_audits_failed = sum(1 for result in audit_results if result.count) - self.console.update_snapshot_evaluation_progress( - snapshot, - batched_intervals[snapshot][node.batch_index], - node.batch_index, - evaluation_duration_ms, - num_audits - num_audits_failed, - num_audits_failed, - 
auto_restatement_triggers=auto_restatement_triggers.get( + ) + + evaluation_duration_ms = now_timestamp() - execution_start_ts + finally: + num_audits = len(audit_results) + num_audits_failed = sum(1 for result in audit_results if result.count) + + rows_processed = None + if snapshot.is_seed: + # seed stats are tracked in SeedStrategy.create by model name, not snapshot name + seed_stats = SeedExecutionTracker.get_and_clear_seed_stats( + snapshot.model.name + ) + rows_processed = ( + seed_stats.get("total_rows_processed") if seed_stats else None + ) + else: + rows_processed = ( + execution_context.total_rows_processed if execution_context else None + ) + + self.console.update_snapshot_evaluation_progress( + snapshot, + batched_intervals[snapshot][node.batch_index], + node.batch_index, + evaluation_duration_ms, + num_audits - num_audits_failed, + num_audits_failed, + auto_restatement_triggers=auto_restatement_triggers.get( snapshot.snapshot_id ), ) - elif isinstance(node, CreateNode): - self.snapshot_evaluator.create_snapshot( - snapshot=snapshot, - snapshots=self.snapshots_by_name, - deployability_index=deployability_index, - allow_destructive_snapshots=allow_destructive_snapshots or set(), - allow_additive_snapshots=allow_additive_snapshots or set(), + elif isinstance(node, CreateNode): + self.snapshot_evaluator.create_snapshot( + snapshot=snapshot, + snapshots=self.snapshots_by_name, + deployability_index=deployability_index, + allow_destructive_snapshots=allow_destructive_snapshots or set(), + rows_processed=rows_processed, + allow_additive_snapshots=allow_additive_snapshots or set(), ) try: diff --git a/sqlmesh/core/snapshot/evaluator.py b/sqlmesh/core/snapshot/evaluator.py index 82924e4c3a..9535197718 100644 --- a/sqlmesh/core/snapshot/evaluator.py +++ b/sqlmesh/core/snapshot/evaluator.py @@ -39,6 +39,7 @@ from sqlmesh.core.dialect import schema_ from sqlmesh.core.engine_adapter import EngineAdapter from sqlmesh.core.engine_adapter.shared import InsertOverwriteStrategy, DataObjectType +from sqlmesh.core.execution_tracker import SeedExecutionTracker from sqlmesh.core.macros import RuntimeStage from sqlmesh.core.model import ( AuditResult, diff --git a/sqlmesh/core/state_sync/db/environment.py b/sqlmesh/core/state_sync/db/environment.py index 3196d18078..db3844f0ce 100644 --- a/sqlmesh/core/state_sync/db/environment.py +++ b/sqlmesh/core/state_sync/db/environment.py @@ -78,6 +78,7 @@ def update_environment(self, environment: Environment) -> None: self.environments_table, _environment_to_df(environment), target_columns_to_types=self._environment_columns_to_types, + track_row_count=False, ) def update_environment_statements( @@ -108,6 +109,7 @@ def update_environment_statements( self.environment_statements_table, _environment_statements_to_df(environment_name, plan_id, environment_statements), target_columns_to_types=self._environment_statements_columns_to_types, + track_row_count=False, ) def invalidate_environment(self, name: str, protect_prod: bool = True) -> None: diff --git a/sqlmesh/core/state_sync/db/interval.py b/sqlmesh/core/state_sync/db/interval.py index bdfedace1e..ae37fd9734 100644 --- a/sqlmesh/core/state_sync/db/interval.py +++ b/sqlmesh/core/state_sync/db/interval.py @@ -115,6 +115,7 @@ def remove_intervals( self.intervals_table, _intervals_to_df(intervals_to_remove, is_dev=False, is_removed=True), target_columns_to_types=self._interval_columns_to_types, + track_row_count=False, ) def get_snapshot_intervals( @@ -243,6 +244,7 @@ def _push_snapshot_intervals( 
self.intervals_table, pd.DataFrame(new_intervals), target_columns_to_types=self._interval_columns_to_types, + track_row_count=False, ) def _get_snapshot_intervals( diff --git a/sqlmesh/core/state_sync/db/migrator.py b/sqlmesh/core/state_sync/db/migrator.py index ca89668763..f796a26f62 100644 --- a/sqlmesh/core/state_sync/db/migrator.py +++ b/sqlmesh/core/state_sync/db/migrator.py @@ -413,7 +413,9 @@ def _backup_state(self) -> None: backup_name = _backup_table_name(table) self.engine_adapter.drop_table(backup_name) self.engine_adapter.create_table_like(backup_name, table) - self.engine_adapter.insert_append(backup_name, exp.select("*").from_(table)) + self.engine_adapter.insert_append( + backup_name, exp.select("*").from_(table), track_row_count=False + ) def _restore_table( self, diff --git a/sqlmesh/core/state_sync/db/snapshot.py b/sqlmesh/core/state_sync/db/snapshot.py index 9cf4f2fbf5..223b11153e 100644 --- a/sqlmesh/core/state_sync/db/snapshot.py +++ b/sqlmesh/core/state_sync/db/snapshot.py @@ -103,6 +103,7 @@ def push_snapshots(self, snapshots: t.Iterable[Snapshot], overwrite: bool = Fals self.snapshots_table, _snapshots_to_df(snapshots_to_store), target_columns_to_types=self._snapshot_columns_to_types, + track_row_count=False, ) for snapshot in snapshots: @@ -406,6 +407,7 @@ def _push_snapshots(self, snapshots: t.Iterable[Snapshot]) -> None: self.snapshots_table, _snapshots_to_df(snapshots_to_store), target_columns_to_types=self._snapshot_columns_to_types, + track_row_count=False, ) def _get_snapshots( diff --git a/sqlmesh/core/state_sync/db/version.py b/sqlmesh/core/state_sync/db/version.py index 492d74cc09..2732c0ca47 100644 --- a/sqlmesh/core/state_sync/db/version.py +++ b/sqlmesh/core/state_sync/db/version.py @@ -55,6 +55,7 @@ def update_versions( ] ), target_columns_to_types=self._version_columns_to_types, + track_row_count=False, ) def get_versions(self) -> Versions: diff --git a/tests/core/engine_adapter/integration/test_integration.py b/tests/core/engine_adapter/integration/test_integration.py index 1960848e24..afdadb560c 100644 --- a/tests/core/engine_adapter/integration/test_integration.py +++ b/tests/core/engine_adapter/integration/test_integration.py @@ -7,7 +7,7 @@ import typing as t import shutil from datetime import datetime, timedelta - +from unittest.mock import patch import numpy as np # noqa: TID253 import pandas as pd # noqa: TID253 import pytest @@ -2382,8 +2382,25 @@ def _mutate_config(gateway: str, config: Config): ) context._models.update(replacement_models) + # capture row counts for each evaluated snapshot + row_counts = {} + + def capture_row_counts( + snapshot, + interval, + batch_idx, + duration_ms, + num_audits_passed, + num_audits_failed, + audit_only=False, + rows_processed=None, + ): + if rows_processed is not None: + row_counts[snapshot.model.name.replace(f"{schema_name}.", "")] = rows_processed + # apply prod plan - context.plan(auto_apply=True, no_prompts=True) + with patch.object(context.console, "update_snapshot_evaluation_progress", capture_row_counts): + context.plan(auto_apply=True, no_prompts=True) prod_schema_results = ctx.get_metadata_results(object_names["view_schema"][0]) assert sorted(prod_schema_results.views) == object_names["views"] @@ -2395,6 +2412,11 @@ def _mutate_config(gateway: str, config: Config): assert len(physical_layer_results.materialized_views) == 0 assert len(physical_layer_results.tables) == len(physical_layer_results.non_temp_tables) == 3 + assert len(row_counts) == 3 + assert row_counts["seed_model"] == 7 + assert 
row_counts["incremental_model"] == 7 + assert row_counts["full_model"] == 3 + # make and validate unmodified dev environment no_change_plan: Plan = context.plan_builder( environment="test_dev", diff --git a/tests/core/test_execution_tracker.py b/tests/core/test_execution_tracker.py new file mode 100644 index 0000000000..7bdcbdb62a --- /dev/null +++ b/tests/core/test_execution_tracker.py @@ -0,0 +1,74 @@ +# Tests the sqlmesh.core.execution_tracker module +# - creates a scenario where executions will take place in multiple threads +# - generates the scenario with known numbers of rows to be processed +# - tests that the execution tracker correctly tracks the number of rows processed in both threads +# - may use mocks, an existing test project, manually created snapshots, or a duckdb database to create the scenario + +from __future__ import annotations + +import threading +from queue import Queue +from typing import List, Optional + +from sqlmesh.core.execution_tracker import QueryExecutionTracker + + +def test_execution_tracker_thread_isolation_and_aggregation() -> None: + """ + Two worker threads each track executions in their own context. Verify: + - isolation across threads + - correct aggregation of rows + - query metadata is captured + - main thread has no active tracking + """ + + assert not QueryExecutionTracker.is_tracking() + assert QueryExecutionTracker.get_execution_stats() is None + + counts_a: List[Optional[int]] = [10, 5, None] + counts_b: List[Optional[int]] = [3, 7] + + start_barrier = threading.Barrier(3) # 2 workers + main + results: "Queue[dict]" = Queue() + + def worker(batch_id: str, counts: List[Optional[int]]) -> None: + with QueryExecutionTracker.track_execution(batch_id) as ctx: + # tracking active in this thread + assert QueryExecutionTracker.is_tracking() + # synchronize start to overlap execution + start_barrier.wait() + for c in counts: + QueryExecutionTracker.record_execution("SELECT 1", c) + + stats = ctx.get_execution_stats() + + assert stats["snapshot_batch"] == batch_id + assert stats["query_count"] == len(counts) + results.put(stats) + + t1 = threading.Thread(target=worker, args=("batch_A", counts_a)) + t2 = threading.Thread(target=worker, args=("batch_B", counts_b)) + + t1.start() + t2.start() + # Release workers at the same time + start_barrier.wait() + t1.join() + t2.join() + + # Main thread has no active tracking context + assert not QueryExecutionTracker.is_tracking() + QueryExecutionTracker.record_execution("q", 10) + assert QueryExecutionTracker.get_execution_stats() is None + + collected = [results.get_nowait(), results.get_nowait()] + # by name since order is non-deterministic + by_batch = {s["snapshot_batch"]: s for s in collected} + + stats_a = by_batch["batch_A"] + assert stats_a["total_rows_processed"] == 15 # 10 + 5 + 0 (None) + assert stats_a["query_count"] == len(counts_a) + + stats_b = by_batch["batch_B"] + assert stats_b["total_rows_processed"] == 10 # 3 + 7 + assert stats_b["query_count"] == len(counts_b) diff --git a/web/server/console.py b/web/server/console.py index 902a85418c..b2d12cd624 100644 --- a/web/server/console.py +++ b/web/server/console.py @@ -142,6 +142,7 @@ def update_snapshot_evaluation_progress( num_audits_passed: int, num_audits_failed: int, audit_only: bool = False, + rows_processed: t.Optional[int] = None, auto_restatement_triggers: t.Optional[t.List[SnapshotId]] = None, ) -> None: if audit_only: From dba820656bcf4003d0646263b992c0956effc7af Mon Sep 17 00:00:00 2001 From: Trey Spiller Date: Thu, 14 Aug 2025 15:17:08 
-0500 Subject: [PATCH 02/31] Add flag for supporting row tracking to engine adapters --- .circleci/continue_config.yml | 9 +++++---- sqlmesh/core/console.py | 1 + sqlmesh/core/engine_adapter/base.py | 4 +++- sqlmesh/core/engine_adapter/mssql.py | 1 + sqlmesh/core/engine_adapter/mysql.py | 1 + sqlmesh/core/engine_adapter/postgres.py | 1 + sqlmesh/core/engine_adapter/trino.py | 1 + sqlmesh/core/execution_tracker.py | 1 + .../core/engine_adapter/integration/test_integration.py | 9 +++++---- 9 files changed, 19 insertions(+), 9 deletions(-) diff --git a/.circleci/continue_config.yml b/.circleci/continue_config.yml index 8f8324a2a0..b38a587a6d 100644 --- a/.circleci/continue_config.yml +++ b/.circleci/continue_config.yml @@ -310,10 +310,11 @@ workflows: - athena - fabric - gcp-postgres - filters: - branches: - only: - - main + # TODO: uncomment this + # filters: + # branches: + # only: + # - main - ui_style - ui_test - vscode_test diff --git a/sqlmesh/core/console.py b/sqlmesh/core/console.py index bfe2a31264..5baea97d8b 100644 --- a/sqlmesh/core/console.py +++ b/sqlmesh/core/console.py @@ -4029,6 +4029,7 @@ def show_table_diff_summary(self, table_diff: TableDiff) -> None: self._write(f"Join On: {keys}") +# TODO: remove this # _CONSOLE: Console = NoopConsole() _CONSOLE: Console = TerminalConsole() diff --git a/sqlmesh/core/engine_adapter/base.py b/sqlmesh/core/engine_adapter/base.py index a3224b4a47..c872401799 100644 --- a/sqlmesh/core/engine_adapter/base.py +++ b/sqlmesh/core/engine_adapter/base.py @@ -118,6 +118,8 @@ class EngineAdapter: QUOTE_IDENTIFIERS_IN_VIEWS = True MAX_IDENTIFIER_LENGTH: t.Optional[int] = None ATTACH_CORRELATION_ID = True + # TODO: change to False + SUPPORTS_QUERY_EXECUTION_TRACKING = True def __init__( self, @@ -2442,7 +2444,7 @@ def _log_sql( def _execute(self, sql: str, track_row_count: bool = False, **kwargs: t.Any) -> None: self.cursor.execute(sql, **kwargs) - if track_row_count: + if track_row_count and self.SUPPORTS_QUERY_EXECUTION_TRACKING: rowcount_raw = getattr(self.cursor, "rowcount", None) rowcount = None if rowcount_raw is not None: diff --git a/sqlmesh/core/engine_adapter/mssql.py b/sqlmesh/core/engine_adapter/mssql.py index 6aefd51fc0..50a67b4b37 100644 --- a/sqlmesh/core/engine_adapter/mssql.py +++ b/sqlmesh/core/engine_adapter/mssql.py @@ -53,6 +53,7 @@ class MSSQLEngineAdapter( COMMENT_CREATION_TABLE = CommentCreationTable.UNSUPPORTED COMMENT_CREATION_VIEW = CommentCreationView.UNSUPPORTED SUPPORTS_REPLACE_TABLE = False + SUPPORTS_QUERY_EXECUTION_TRACKING = True SCHEMA_DIFFER_KWARGS = { "parameterized_type_defaults": { exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(18, 0), (0,)], diff --git a/sqlmesh/core/engine_adapter/mysql.py b/sqlmesh/core/engine_adapter/mysql.py index e81b30e25e..26cc7c0197 100644 --- a/sqlmesh/core/engine_adapter/mysql.py +++ b/sqlmesh/core/engine_adapter/mysql.py @@ -39,6 +39,7 @@ class MySQLEngineAdapter( MAX_COLUMN_COMMENT_LENGTH = 1024 SUPPORTS_REPLACE_TABLE = False MAX_IDENTIFIER_LENGTH = 64 + SUPPORTS_QUERY_EXECUTION_TRACKING = True SCHEMA_DIFFER_KWARGS = { "parameterized_type_defaults": { exp.DataType.build("BIT", dialect=DIALECT).this: [(1,)], diff --git a/sqlmesh/core/engine_adapter/postgres.py b/sqlmesh/core/engine_adapter/postgres.py index faeb52b207..e9c212bd5f 100644 --- a/sqlmesh/core/engine_adapter/postgres.py +++ b/sqlmesh/core/engine_adapter/postgres.py @@ -35,6 +35,7 @@ class PostgresEngineAdapter( CURRENT_CATALOG_EXPRESSION = exp.column("current_catalog") SUPPORTS_REPLACE_TABLE = False 
MAX_IDENTIFIER_LENGTH = 63 + SUPPORTS_QUERY_EXECUTION_TRACKING = True SCHEMA_DIFFER_KWARGS = { "parameterized_type_defaults": { # DECIMAL without precision is "up to 131072 digits before the decimal point; up to 16383 digits after the decimal point" diff --git a/sqlmesh/core/engine_adapter/trino.py b/sqlmesh/core/engine_adapter/trino.py index 936680baab..21dac81255 100644 --- a/sqlmesh/core/engine_adapter/trino.py +++ b/sqlmesh/core/engine_adapter/trino.py @@ -55,6 +55,7 @@ class TrinoEngineAdapter( SUPPORTED_DROP_CASCADE_OBJECT_KINDS = ["SCHEMA"] DEFAULT_CATALOG_TYPE = "hive" QUOTE_IDENTIFIERS_IN_VIEWS = False + SUPPORTS_QUERY_EXECUTION_TRACKING = True SCHEMA_DIFFER_KWARGS = { "parameterized_type_defaults": { # default decimal precision varies across backends diff --git a/sqlmesh/core/execution_tracker.py b/sqlmesh/core/execution_tracker.py index 24d1a8af7c..8e900a26b4 100644 --- a/sqlmesh/core/execution_tracker.py +++ b/sqlmesh/core/execution_tracker.py @@ -31,6 +31,7 @@ def add_execution(self, sql: str, row_count: t.Optional[int]) -> None: if row_count is not None and row_count >= 0: self.total_rows_processed += row_count self.query_count += 1 + # TODO: remove this # for debugging self.queries_executed.append((sql[:300], row_count, time.time())) diff --git a/tests/core/engine_adapter/integration/test_integration.py b/tests/core/engine_adapter/integration/test_integration.py index afdadb560c..4298cd58a7 100644 --- a/tests/core/engine_adapter/integration/test_integration.py +++ b/tests/core/engine_adapter/integration/test_integration.py @@ -2412,10 +2412,11 @@ def capture_row_counts( assert len(physical_layer_results.materialized_views) == 0 assert len(physical_layer_results.tables) == len(physical_layer_results.non_temp_tables) == 3 - assert len(row_counts) == 3 - assert row_counts["seed_model"] == 7 - assert row_counts["incremental_model"] == 7 - assert row_counts["full_model"] == 3 + if ctx.engine_adapter.SUPPORTS_QUERY_EXECUTION_TRACKING: + assert len(row_counts) == 3 + assert row_counts["seed_model"] == 7 + assert row_counts["incremental_model"] == 7 + assert row_counts["full_model"] == 3 # make and validate unmodified dev environment no_change_plan: Plan = context.plan_builder( From 296c25c84daabfdabc7faf6a95f3d8cafb093b8e Mon Sep 17 00:00:00 2001 From: Trey Spiller Date: Thu, 14 Aug 2025 15:34:24 -0500 Subject: [PATCH 03/31] Run cloud tests in CI --- .circleci/continue_config.yml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.circleci/continue_config.yml b/.circleci/continue_config.yml index b38a587a6d..35f15e2c2f 100644 --- a/.circleci/continue_config.yml +++ b/.circleci/continue_config.yml @@ -297,8 +297,9 @@ workflows: name: cloud_engine_<< matrix.engine >> context: - sqlmesh_cloud_database_integration - requires: - - engine_tests_docker + # TODO: uncomment this + # requires: + # - engine_tests_docker matrix: parameters: engine: From ed6a56447e856738db43ce32b53b61efc4150729 Mon Sep 17 00:00:00 2001 From: Trey Spiller Date: Thu, 14 Aug 2025 17:14:44 -0500 Subject: [PATCH 04/31] Add engine support flags --- sqlmesh/core/engine_adapter/athena.py | 1 + sqlmesh/core/engine_adapter/base.py | 3 +-- sqlmesh/core/engine_adapter/base_postgres.py | 1 + sqlmesh/core/engine_adapter/bigquery.py | 9 ++++++++- 4 files changed, 11 insertions(+), 3 deletions(-) diff --git a/sqlmesh/core/engine_adapter/athena.py b/sqlmesh/core/engine_adapter/athena.py index 48b9e4ad4e..aa8a5ce0c1 100644 --- a/sqlmesh/core/engine_adapter/athena.py +++ b/sqlmesh/core/engine_adapter/athena.py 
@@ -45,6 +45,7 @@ class AthenaEngineAdapter(PandasNativeFetchDFSupportMixin, RowDiffMixin): # >>> self._execute('/* test */ DESCRIBE foo') # pyathena.error.OperationalError: FAILED: ParseException line 1:0 cannot recognize input near '/' '*' 'test' ATTACH_CORRELATION_ID = False + SUPPORTS_QUERY_EXECUTION_TRACKING = True SUPPORTED_DROP_CASCADE_OBJECT_KINDS = ["DATABASE", "SCHEMA"] def __init__( diff --git a/sqlmesh/core/engine_adapter/base.py b/sqlmesh/core/engine_adapter/base.py index c872401799..9834fdac59 100644 --- a/sqlmesh/core/engine_adapter/base.py +++ b/sqlmesh/core/engine_adapter/base.py @@ -118,8 +118,7 @@ class EngineAdapter: QUOTE_IDENTIFIERS_IN_VIEWS = True MAX_IDENTIFIER_LENGTH: t.Optional[int] = None ATTACH_CORRELATION_ID = True - # TODO: change to False - SUPPORTS_QUERY_EXECUTION_TRACKING = True + SUPPORTS_QUERY_EXECUTION_TRACKING = False def __init__( self, diff --git a/sqlmesh/core/engine_adapter/base_postgres.py b/sqlmesh/core/engine_adapter/base_postgres.py index 26446aacfd..c6ba7d6d62 100644 --- a/sqlmesh/core/engine_adapter/base_postgres.py +++ b/sqlmesh/core/engine_adapter/base_postgres.py @@ -24,6 +24,7 @@ class BasePostgresEngineAdapter(EngineAdapter): DEFAULT_BATCH_SIZE = 400 COMMENT_CREATION_TABLE = CommentCreationTable.COMMENT_COMMAND_ONLY COMMENT_CREATION_VIEW = CommentCreationView.COMMENT_COMMAND_ONLY + SUPPORTS_QUERY_EXECUTION_TRACKING = True SUPPORTED_DROP_CASCADE_OBJECT_KINDS = ["SCHEMA", "TABLE", "VIEW"] def columns( diff --git a/sqlmesh/core/engine_adapter/bigquery.py b/sqlmesh/core/engine_adapter/bigquery.py index 3212acbe7c..703ce51a37 100644 --- a/sqlmesh/core/engine_adapter/bigquery.py +++ b/sqlmesh/core/engine_adapter/bigquery.py @@ -67,6 +67,7 @@ class BigQueryEngineAdapter(InsertOverwriteWithMergeMixin, ClusteredByMixin, Row SUPPORTS_CLONING = True MAX_TABLE_COMMENT_LENGTH = 1024 MAX_COLUMN_COMMENT_LENGTH = 1024 + SUPPORTS_QUERY_EXECUTION_TRACKING = True SUPPORTED_DROP_CASCADE_OBJECT_KINDS = ["SCHEMA"] SCHEMA_DIFFER_KWARGS = { @@ -1097,7 +1098,13 @@ def _execute( self.cursor._set_description(query_results.schema) if track_row_count: - track_execution_record(sql, query_results.total_rows) + if query_job.statement_type == "CREATE_TABLE_AS_SELECT": + query_table = self.client.get_table(query_job.destination) + num_rows = query_table.num_rows + elif query_job.statement_type in ["INSERT", "DELETE", "MERGE", "UPDATE"]: + num_rows = query_job.num_dml_affected_rows + + track_execution_record(sql, num_rows) def _get_data_objects( self, schema_name: SchemaName, object_names: t.Optional[t.Set[str]] = None From 700f00600b2d6b407423fee6a5f23f2776b01dcd Mon Sep 17 00:00:00 2001 From: Trey Spiller Date: Fri, 15 Aug 2025 10:34:03 -0500 Subject: [PATCH 05/31] Use threading.local() instead of locks --- sqlmesh/core/execution_tracker.py | 63 ++++++++++--------------------- 1 file changed, 20 insertions(+), 43 deletions(-) diff --git a/sqlmesh/core/execution_tracker.py b/sqlmesh/core/execution_tracker.py index 8e900a26b4..3002c11c89 100644 --- a/sqlmesh/core/execution_tracker.py +++ b/sqlmesh/core/execution_tracker.py @@ -3,7 +3,7 @@ import time import typing as t from contextlib import contextmanager -from threading import get_ident, Lock +from threading import local from dataclasses import dataclass, field @@ -50,14 +50,11 @@ class QueryExecutionTracker: rows processed. 
""" - _thread_contexts: t.Dict[int, QueryExecutionContext] = {} - _contexts_lock = Lock() + _thread_local = local() @classmethod def get_execution_context(cls) -> t.Optional[QueryExecutionContext]: - thread_id = get_ident() - with cls._contexts_lock: - return cls._thread_contexts.get(thread_id) + return getattr(cls._thread_local, "context", None) @classmethod def is_tracking(cls) -> bool: @@ -70,23 +67,18 @@ def track_execution(cls, snapshot_name_batch: str) -> t.Iterator[QueryExecutionC Context manager for tracking snapshot evaluation execution statistics. """ context = QueryExecutionContext(id=snapshot_name_batch) - thread_id = get_ident() - - with cls._contexts_lock: - cls._thread_contexts[thread_id] = context + cls._thread_local.context = context try: yield context finally: - with cls._contexts_lock: - cls._thread_contexts.pop(thread_id, None) + if hasattr(cls._thread_local, "context"): + delattr(cls._thread_local, "context") @classmethod def record_execution(cls, sql: str, row_count: t.Optional[int]) -> None: - thread_id = get_ident() - with cls._contexts_lock: - context = cls._thread_contexts.get(thread_id) - if context is not None: - context.add_execution(sql, row_count) + context = cls.get_execution_context() + if context is not None: + context.add_execution(sql, row_count) @classmethod def get_execution_stats(cls) -> t.Optional[t.Dict[str, t.Any]]: @@ -96,9 +88,7 @@ def get_execution_stats(cls) -> t.Optional[t.Dict[str, t.Any]]: class SeedExecutionTracker: _seed_contexts: t.Dict[str, QueryExecutionContext] = {} - _active_threads: t.Set[int] = set() - _thread_to_seed_id: t.Dict[int, str] = {} - _seed_contexts_lock = Lock() + _thread_local = local() @classmethod @contextmanager @@ -106,47 +96,34 @@ def track_execution(cls, model_name: str) -> t.Iterator[QueryExecutionContext]: """ Context manager for tracking seed creation execution statistics. """ - context = QueryExecutionContext(id=model_name) - thread_id = get_ident() - - with cls._seed_contexts_lock: - cls._seed_contexts[model_name] = context - cls._active_threads.add(thread_id) - cls._thread_to_seed_id[thread_id] = model_name + cls._seed_contexts[model_name] = context + cls._thread_local.seed_id = model_name try: yield context finally: - with cls._seed_contexts_lock: - cls._active_threads.discard(thread_id) - cls._thread_to_seed_id.pop(thread_id, None) + if hasattr(cls._thread_local, "seed_id"): + delattr(cls._thread_local, "seed_id") @classmethod def get_and_clear_seed_stats(cls, model_name: str) -> t.Optional[t.Dict[str, t.Any]]: - with cls._seed_contexts_lock: - context = cls._seed_contexts.pop(model_name, None) - return context.get_execution_stats() if context else None + context = cls._seed_contexts.pop(model_name, None) + return context.get_execution_stats() if context else None @classmethod def clear_all_seed_stats(cls) -> None: """Clear all remaining seed stats. 
Used for cleanup after evaluation completes.""" - with cls._seed_contexts_lock: - cls._seed_contexts.clear() + cls._seed_contexts.clear() @classmethod def is_tracking(cls) -> bool: - thread_id = get_ident() - with cls._seed_contexts_lock: - return thread_id in cls._active_threads + return hasattr(cls._thread_local, "seed_id") @classmethod def record_execution(cls, sql: str, row_count: t.Optional[int]) -> None: - thread_id = get_ident() - with cls._seed_contexts_lock: - seed_id = cls._thread_to_seed_id.get(thread_id) - if not seed_id: - return + seed_id = getattr(cls._thread_local, "seed_id", None) + if seed_id: context = cls._seed_contexts.get(seed_id) if context is not None: context.add_execution(sql, row_count) From 8436adf24200b153d06b54a03e3efab0e4d03bdb Mon Sep 17 00:00:00 2001 From: Trey Spiller Date: Fri, 15 Aug 2025 18:23:38 -0500 Subject: [PATCH 06/31] Move all tracking into snapshot evaluator, remove seed tracker class --- sqlmesh/core/engine_adapter/base.py | 10 ++- sqlmesh/core/engine_adapter/bigquery.py | 4 +- sqlmesh/core/execution_tracker.py | 89 ++++++------------------- sqlmesh/core/scheduler.py | 2 +- sqlmesh/core/snapshot/evaluator.py | 31 +++++---- tests/core/test_execution_tracker.py | 80 ++++++---------------- 6 files changed, 69 insertions(+), 147 deletions(-) diff --git a/sqlmesh/core/engine_adapter/base.py b/sqlmesh/core/engine_adapter/base.py index 9834fdac59..459bfdea05 100644 --- a/sqlmesh/core/engine_adapter/base.py +++ b/sqlmesh/core/engine_adapter/base.py @@ -40,7 +40,7 @@ ) from sqlmesh.core.model.kind import TimeColumn from sqlmesh.core.schema_diff import SchemaDiffer, TableAlterOperation -from sqlmesh.core.execution_tracker import record_execution as track_execution_record +from sqlmesh.core.execution_tracker import QueryExecutionTracker from sqlmesh.utils import ( CorrelationId, columns_to_types_all_known, @@ -2443,7 +2443,11 @@ def _log_sql( def _execute(self, sql: str, track_row_count: bool = False, **kwargs: t.Any) -> None: self.cursor.execute(sql, **kwargs) - if track_row_count and self.SUPPORTS_QUERY_EXECUTION_TRACKING: + if ( + self.SUPPORTS_QUERY_EXECUTION_TRACKING + and track_row_count + and QueryExecutionTracker.is_tracking() + ): rowcount_raw = getattr(self.cursor, "rowcount", None) rowcount = None if rowcount_raw is not None: @@ -2452,7 +2456,7 @@ def _execute(self, sql: str, track_row_count: bool = False, **kwargs: t.Any) -> except (TypeError, ValueError): pass - track_execution_record(sql, rowcount) + QueryExecutionTracker.record_execution(sql, rowcount) @contextlib.contextmanager def temp_table( diff --git a/sqlmesh/core/engine_adapter/bigquery.py b/sqlmesh/core/engine_adapter/bigquery.py index 703ce51a37..4292a1e37d 100644 --- a/sqlmesh/core/engine_adapter/bigquery.py +++ b/sqlmesh/core/engine_adapter/bigquery.py @@ -21,7 +21,7 @@ SourceQuery, set_catalog, ) -from sqlmesh.core.execution_tracker import record_execution as track_execution_record +from sqlmesh.core.execution_tracker import QueryExecutionTracker from sqlmesh.core.node import IntervalUnit from sqlmesh.core.schema_diff import TableAlterOperation, NestedSupport from sqlmesh.utils import optional_import, get_source_columns_to_types @@ -1104,7 +1104,7 @@ def _execute( elif query_job.statement_type in ["INSERT", "DELETE", "MERGE", "UPDATE"]: num_rows = query_job.num_dml_affected_rows - track_execution_record(sql, num_rows) + QueryExecutionTracker.record_execution(sql, num_rows) def _get_data_objects( self, schema_name: SchemaName, object_names: t.Optional[t.Set[str]] = None diff 
--git a/sqlmesh/core/execution_tracker.py b/sqlmesh/core/execution_tracker.py index 3002c11c89..d31fde68bb 100644 --- a/sqlmesh/core/execution_tracker.py +++ b/sqlmesh/core/execution_tracker.py @@ -27,7 +27,6 @@ class QueryExecutionContext: queries_executed: t.List[t.Tuple[str, t.Optional[int], float]] = field(default_factory=list) def add_execution(self, sql: str, row_count: t.Optional[int]) -> None: - """Record a single query execution.""" if row_count is not None and row_count >= 0: self.total_rows_processed += row_count self.query_count += 1 @@ -46,97 +45,49 @@ def get_execution_stats(self) -> t.Dict[str, t.Any]: class QueryExecutionTracker: """ - Thread-local context manager for snapshot evaluation execution statistics, such as + Thread-local context manager for snapshot execution statistics, such as rows processed. """ _thread_local = local() + _contexts: t.Dict[str, QueryExecutionContext] = {} @classmethod - def get_execution_context(cls) -> t.Optional[QueryExecutionContext]: - return getattr(cls._thread_local, "context", None) + def get_execution_context(cls, snapshot_id_batch: str) -> t.Optional[QueryExecutionContext]: + return cls._contexts.get(snapshot_id_batch) @classmethod def is_tracking(cls) -> bool: - return cls.get_execution_context() is not None + return getattr(cls._thread_local, "context", None) is not None @classmethod @contextmanager - def track_execution(cls, snapshot_name_batch: str) -> t.Iterator[QueryExecutionContext]: + def track_execution( + cls, snapshot_id_batch: str, condition: bool = True + ) -> t.Iterator[t.Optional[QueryExecutionContext]]: """ - Context manager for tracking snapshot evaluation execution statistics. + Context manager for tracking snapshot execution statistics. """ - context = QueryExecutionContext(id=snapshot_name_batch) + if not condition: + yield None + return + + context = QueryExecutionContext(id=snapshot_id_batch) cls._thread_local.context = context + cls._contexts[snapshot_id_batch] = context try: yield context finally: - if hasattr(cls._thread_local, "context"): - delattr(cls._thread_local, "context") + cls._thread_local.context = None @classmethod def record_execution(cls, sql: str, row_count: t.Optional[int]) -> None: - context = cls.get_execution_context() + context = getattr(cls._thread_local, "context", None) if context is not None: context.add_execution(sql, row_count) @classmethod - def get_execution_stats(cls) -> t.Optional[t.Dict[str, t.Any]]: - context = cls.get_execution_context() + def get_execution_stats(cls, snapshot_id_batch: str) -> t.Optional[t.Dict[str, t.Any]]: + context = cls.get_execution_context(snapshot_id_batch) + cls._contexts.pop(snapshot_id_batch, None) return context.get_execution_stats() if context else None - - -class SeedExecutionTracker: - _seed_contexts: t.Dict[str, QueryExecutionContext] = {} - _thread_local = local() - - @classmethod - @contextmanager - def track_execution(cls, model_name: str) -> t.Iterator[QueryExecutionContext]: - """ - Context manager for tracking seed creation execution statistics. 
- """ - context = QueryExecutionContext(id=model_name) - cls._seed_contexts[model_name] = context - cls._thread_local.seed_id = model_name - - try: - yield context - finally: - if hasattr(cls._thread_local, "seed_id"): - delattr(cls._thread_local, "seed_id") - - @classmethod - def get_and_clear_seed_stats(cls, model_name: str) -> t.Optional[t.Dict[str, t.Any]]: - context = cls._seed_contexts.pop(model_name, None) - return context.get_execution_stats() if context else None - - @classmethod - def clear_all_seed_stats(cls) -> None: - """Clear all remaining seed stats. Used for cleanup after evaluation completes.""" - cls._seed_contexts.clear() - - @classmethod - def is_tracking(cls) -> bool: - return hasattr(cls._thread_local, "seed_id") - - @classmethod - def record_execution(cls, sql: str, row_count: t.Optional[int]) -> None: - seed_id = getattr(cls._thread_local, "seed_id", None) - if seed_id: - context = cls._seed_contexts.get(seed_id) - if context is not None: - context.add_execution(sql, row_count) - - -def record_execution(sql: str, row_count: t.Optional[int]) -> None: - """ - Record execution statistics for a single SQL statement. - - Automatically infers which tracker is active based on the current thread. - """ - if SeedExecutionTracker.is_tracking(): - SeedExecutionTracker.record_execution(sql, row_count) - return - if QueryExecutionTracker.is_tracking(): - QueryExecutionTracker.record_execution(sql, row_count) diff --git a/sqlmesh/core/scheduler.py b/sqlmesh/core/scheduler.py index 61fcad3eee..60c45ab546 100644 --- a/sqlmesh/core/scheduler.py +++ b/sqlmesh/core/scheduler.py @@ -9,7 +9,7 @@ from sqlmesh.core import constants as c from sqlmesh.core.console import Console, get_console from sqlmesh.core.environment import EnvironmentNamingInfo, execute_environment_statements -from sqlmesh.core.execution_tracker import QueryExecutionTracker, SeedExecutionTracker +from sqlmesh.core.execution_tracker import QueryExecutionTracker from sqlmesh.core.macros import RuntimeStage from sqlmesh.core.model.definition import AuditResult from sqlmesh.core.node import IntervalUnit diff --git a/sqlmesh/core/snapshot/evaluator.py b/sqlmesh/core/snapshot/evaluator.py index 9535197718..47a4709fd4 100644 --- a/sqlmesh/core/snapshot/evaluator.py +++ b/sqlmesh/core/snapshot/evaluator.py @@ -39,7 +39,7 @@ from sqlmesh.core.dialect import schema_ from sqlmesh.core.engine_adapter import EngineAdapter from sqlmesh.core.engine_adapter.shared import InsertOverwriteStrategy, DataObjectType -from sqlmesh.core.execution_tracker import SeedExecutionTracker +from sqlmesh.core.execution_tracker import QueryExecutionTracker from sqlmesh.core.macros import RuntimeStage from sqlmesh.core.model import ( AuditResult, @@ -170,19 +170,22 @@ def evaluate( Returns: The WAP ID of this evaluation if supported, None otherwise. 
""" - result = self._evaluate_snapshot( - start=start, - end=end, - execution_time=execution_time, - snapshot=snapshot, - snapshots=snapshots, - allow_destructive_snapshots=allow_destructive_snapshots or set(), - allow_additive_snapshots=allow_additive_snapshots or set(), - deployability_index=deployability_index, - batch_index=batch_index, - target_table_exists=target_table_exists, - **kwargs, - ) + with QueryExecutionTracker.track_execution( + f"{snapshot.snapshot_id}_{batch_index}", condition=not snapshot.is_seed + ): + result = self._evaluate_snapshot( + start=start, + end=end, + execution_time=execution_time, + snapshot=snapshot, + snapshots=snapshots, + allow_destructive_snapshots=allow_destructive_snapshots or set(), + allow_additive_snapshots=allow_additive_snapshots or set(), + deployability_index=deployability_index, + batch_index=batch_index, + target_table_exists=target_table_exists, + **kwargs, + ) if result is None or isinstance(result, str): return result raise SQLMeshError( diff --git a/tests/core/test_execution_tracker.py b/tests/core/test_execution_tracker.py index 7bdcbdb62a..5791abeccb 100644 --- a/tests/core/test_execution_tracker.py +++ b/tests/core/test_execution_tracker.py @@ -1,74 +1,38 @@ -# Tests the sqlmesh.core.execution_tracker module -# - creates a scenario where executions will take place in multiple threads -# - generates the scenario with known numbers of rows to be processed -# - tests that the execution tracker correctly tracks the number of rows processed in both threads -# - may use mocks, an existing test project, manually created snapshots, or a duckdb database to create the scenario - from __future__ import annotations -import threading -from queue import Queue -from typing import List, Optional +import typing as t +from concurrent.futures import ThreadPoolExecutor from sqlmesh.core.execution_tracker import QueryExecutionTracker -def test_execution_tracker_thread_isolation_and_aggregation() -> None: - """ - Two worker threads each track executions in their own context. 
Verify: - - isolation across threads - - correct aggregation of rows - - query metadata is captured - - main thread has no active tracking - """ - - assert not QueryExecutionTracker.is_tracking() - assert QueryExecutionTracker.get_execution_stats() is None - - counts_a: List[Optional[int]] = [10, 5, None] - counts_b: List[Optional[int]] = [3, 7] - - start_barrier = threading.Barrier(3) # 2 workers + main - results: "Queue[dict]" = Queue() - - def worker(batch_id: str, counts: List[Optional[int]]) -> None: - with QueryExecutionTracker.track_execution(batch_id) as ctx: - # tracking active in this thread +def test_execution_tracker_thread_isolation() -> None: + def worker(id: str, row_counts: list[int]) -> t.Dict[str, t.Any]: + with QueryExecutionTracker.track_execution(id) as ctx: assert QueryExecutionTracker.is_tracking() - # synchronize start to overlap execution - start_barrier.wait() - for c in counts: - QueryExecutionTracker.record_execution("SELECT 1", c) - stats = ctx.get_execution_stats() + for count in row_counts: + QueryExecutionTracker.record_execution("SELECT 1", count) - assert stats["snapshot_batch"] == batch_id - assert stats["query_count"] == len(counts) - results.put(stats) + assert ctx is not None + return ctx.get_execution_stats() - t1 = threading.Thread(target=worker, args=("batch_A", counts_a)) - t2 = threading.Thread(target=worker, args=("batch_B", counts_b)) - - t1.start() - t2.start() - # Release workers at the same time - start_barrier.wait() - t1.join() - t2.join() + with ThreadPoolExecutor() as executor: + futures = [ + executor.submit(worker, "batch_A", [10, 5]), + executor.submit(worker, "batch_B", [3, 7]), + ] + results = [f.result() for f in futures] # Main thread has no active tracking context assert not QueryExecutionTracker.is_tracking() QueryExecutionTracker.record_execution("q", 10) - assert QueryExecutionTracker.get_execution_stats() is None - - collected = [results.get_nowait(), results.get_nowait()] - # by name since order is non-deterministic - by_batch = {s["snapshot_batch"]: s for s in collected} + assert QueryExecutionTracker.get_execution_stats("q") is None - stats_a = by_batch["batch_A"] - assert stats_a["total_rows_processed"] == 15 # 10 + 5 + 0 (None) - assert stats_a["query_count"] == len(counts_a) + # Order of results is not deterministic, so look up by id + by_batch = {s["id"]: s for s in results} - stats_b = by_batch["batch_B"] - assert stats_b["total_rows_processed"] == 10 # 3 + 7 - assert stats_b["query_count"] == len(counts_b) + assert by_batch["batch_A"]["total_rows_processed"] == 15 + assert by_batch["batch_A"]["query_count"] == 2 + assert by_batch["batch_B"]["total_rows_processed"] == 10 + assert by_batch["batch_B"]["query_count"] == 2 From a87df2a3724dbae9d819810c1b75c70fffeccfdf Mon Sep 17 00:00:00 2001 From: Trey Spiller Date: Mon, 18 Aug 2025 13:04:01 -0500 Subject: [PATCH 07/31] Remove 'processed' and 'inserted' from console output --- sqlmesh/core/console.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/sqlmesh/core/console.py b/sqlmesh/core/console.py index 5baea97d8b..6ea7383dc8 100644 --- a/sqlmesh/core/console.py +++ b/sqlmesh/core/console.py @@ -4183,7 +4183,7 @@ def _create_evaluation_model_annotation( ) -> str: annotation = None num_rows_processed = str(rows_processed) if rows_processed else "" - rows_processed_str = f" ({num_rows_processed} rows processed)" if num_rows_processed else "" + rows_processed_str = f" ({num_rows_processed} rows)" if num_rows_processed else "" if snapshot.is_audit: 
annotation = "run standalone audit" @@ -4193,11 +4193,7 @@ def _create_evaluation_model_annotation( if snapshot.model.kind.is_view: annotation = "recreate view" if snapshot.model.kind.is_seed: - # no "processed" for seeds - seed_num_rows_inserted = ( - f" ({num_rows_processed} rows inserted)" if num_rows_processed else "" - ) - annotation = f"insert seed file{seed_num_rows_inserted}" + annotation = f"insert seed file{rows_processed_str}" if snapshot.model.kind.is_full: annotation = f"full refresh{rows_processed_str}" if snapshot.model.kind.is_incremental_by_unique_key: From f7c6866aab0fb5865361554f7956cde9de660161 Mon Sep 17 00:00:00 2001 From: Trey Spiller Date: Mon, 18 Aug 2025 16:11:26 -0500 Subject: [PATCH 08/31] Add BQ support and track bytes processed --- sqlmesh/core/console.py | 93 +++++++++++++++---- sqlmesh/core/engine_adapter/base.py | 45 +++++---- sqlmesh/core/engine_adapter/bigquery.py | 8 +- sqlmesh/core/engine_adapter/clickhouse.py | 6 +- sqlmesh/core/engine_adapter/duckdb.py | 4 +- sqlmesh/core/engine_adapter/redshift.py | 4 +- sqlmesh/core/engine_adapter/snowflake.py | 4 +- sqlmesh/core/engine_adapter/spark.py | 4 +- sqlmesh/core/engine_adapter/trino.py | 4 +- sqlmesh/core/execution_tracker.py | 55 +++++++---- sqlmesh/core/scheduler.py | 21 ++--- sqlmesh/core/state_sync/db/environment.py | 4 +- sqlmesh/core/state_sync/db/interval.py | 4 +- sqlmesh/core/state_sync/db/migrator.py | 2 +- sqlmesh/core/state_sync/db/snapshot.py | 4 +- sqlmesh/core/state_sync/db/version.py | 2 +- .../integration/test_integration.py | 23 +++-- tests/core/test_execution_tracker.py | 19 ++-- web/server/console.py | 3 +- 19 files changed, 193 insertions(+), 116 deletions(-) diff --git a/sqlmesh/core/console.py b/sqlmesh/core/console.py index 6ea7383dc8..9f512123b1 100644 --- a/sqlmesh/core/console.py +++ b/sqlmesh/core/console.py @@ -32,6 +32,7 @@ from sqlmesh.core.environment import EnvironmentNamingInfo, EnvironmentSummary from sqlmesh.core.linter.rule import RuleViolation from sqlmesh.core.model import Model +from sqlmesh.core.execution_tracker import QueryExecutionStats from sqlmesh.core.snapshot import ( Snapshot, SnapshotChangeCategory, @@ -439,7 +440,7 @@ def update_snapshot_evaluation_progress( num_audits_passed: int, num_audits_failed: int, audit_only: bool = False, - rows_processed: t.Optional[int] = None, + execution_stats: t.Optional[QueryExecutionStats] = None, auto_restatement_triggers: t.Optional[t.List[SnapshotId]] = None, ) -> None: """Updates the snapshot evaluation progress.""" @@ -588,7 +589,7 @@ def update_snapshot_evaluation_progress( num_audits_passed: int, num_audits_failed: int, audit_only: bool = False, - rows_processed: t.Optional[int] = None, + execution_stats: t.Optional[QueryExecutionStats] = None, auto_restatement_triggers: t.Optional[t.List[SnapshotId]] = None, ) -> None: pass @@ -1035,7 +1036,7 @@ def start_evaluation_progress( # determine column widths self.evaluation_column_widths["annotation"] = ( _calculate_annotation_str_len( - batched_intervals, self.AUDIT_PADDING, len(" (XXXXXX rows processed)") + batched_intervals, self.AUDIT_PADDING, len(" (123.4m rows, 123.4 KiB)") ) + 3 # brackets and opening escape backslash ) @@ -1081,7 +1082,7 @@ def update_snapshot_evaluation_progress( num_audits_passed: int, num_audits_failed: int, audit_only: bool = False, - rows_processed: t.Optional[int] = None, + execution_stats: t.Optional[QueryExecutionStats] = None, auto_restatement_triggers: t.Optional[t.List[SnapshotId]] = None, ) -> None: """Update the snapshot evaluation 
progress.""" @@ -1102,7 +1103,7 @@ def update_snapshot_evaluation_progress( ).ljust(self.evaluation_column_widths["name"]) annotation = _create_evaluation_model_annotation( - snapshot, _format_evaluation_model_interval(snapshot, interval), rows_processed + snapshot, _format_evaluation_model_interval(snapshot, interval), execution_stats ) audits_str = "" if num_audits_passed: @@ -3673,7 +3674,7 @@ def update_snapshot_evaluation_progress( num_audits_passed: int, num_audits_failed: int, audit_only: bool = False, - rows_processed: t.Optional[int] = None, + execution_stats: t.Optional[QueryExecutionStats] = None, auto_restatement_triggers: t.Optional[t.List[SnapshotId]] = None, ) -> None: view_name, loaded_batches = self.evaluation_batch_progress[snapshot.snapshot_id] @@ -3844,7 +3845,7 @@ def update_snapshot_evaluation_progress( num_audits_passed: int, num_audits_failed: int, audit_only: bool = False, - rows_processed: t.Optional[int] = None, + execution_stats: t.Optional[QueryExecutionStats] = None, auto_restatement_triggers: t.Optional[t.List[SnapshotId]] = None, ) -> None: message = f"Evaluated {snapshot.name} | batch={batch_idx} | duration={duration_ms}ms | num_audits_passed={num_audits_passed} | num_audits_failed={num_audits_failed}" @@ -4179,11 +4180,27 @@ def _format_evaluation_model_interval(snapshot: Snapshot, interval: Interval) -> def _create_evaluation_model_annotation( - snapshot: Snapshot, interval_info: t.Optional[str], rows_processed: t.Optional[int] + snapshot: Snapshot, + interval_info: t.Optional[str], + execution_stats: t.Optional[QueryExecutionStats], ) -> str: annotation = None - num_rows_processed = str(rows_processed) if rows_processed else "" - rows_processed_str = f" ({num_rows_processed} rows)" if num_rows_processed else "" + execution_stats_str = "" + if execution_stats: + rows_processed = execution_stats.total_rows_processed + execution_stats_str += ( + f"{_abbreviate_integer_count(rows_processed)} row{'s' if rows_processed > 1 else ''}" + if rows_processed + else "" + ) + + bytes_processed = execution_stats.total_bytes_processed + execution_stats_str += ( + f"{', ' if execution_stats_str else ''}{_format_bytes(bytes_processed)}" + if bytes_processed + else "" + ) + execution_stats_str = f" ({execution_stats_str})" if execution_stats_str else "" if snapshot.is_audit: annotation = "run standalone audit" @@ -4193,22 +4210,24 @@ def _create_evaluation_model_annotation( if snapshot.model.kind.is_view: annotation = "recreate view" if snapshot.model.kind.is_seed: - annotation = f"insert seed file{rows_processed_str}" + annotation = f"insert seed file{execution_stats_str}" if snapshot.model.kind.is_full: - annotation = f"full refresh{rows_processed_str}" + annotation = f"full refresh{execution_stats_str}" if snapshot.model.kind.is_incremental_by_unique_key: - annotation = f"insert/update rows{rows_processed_str}" + annotation = f"insert/update rows{execution_stats_str}" if snapshot.model.kind.is_incremental_by_partition: - annotation = f"insert partitions{rows_processed_str}" + annotation = f"insert partitions{execution_stats_str}" if annotation: return annotation - return f"{interval_info}{rows_processed_str}" if interval_info else "" + return f"{interval_info}{execution_stats_str}" if interval_info else "" def _calculate_interval_str_len( - snapshot: Snapshot, intervals: t.List[Interval], rows_processed: t.Optional[int] = None + snapshot: Snapshot, + intervals: t.List[Interval], + execution_stats: t.Optional[QueryExecutionStats] = None, ) -> int: interval_str_len = 0 
for interval in intervals: @@ -4216,7 +4235,7 @@ def _calculate_interval_str_len( interval_str_len, len( _create_evaluation_model_annotation( - snapshot, _format_evaluation_model_interval(snapshot, interval), rows_processed + snapshot, _format_evaluation_model_interval(snapshot, interval), execution_stats ) ), ) @@ -4271,7 +4290,7 @@ def _calculate_audit_str_len(snapshot: Snapshot, audit_padding: int = 0) -> int: def _calculate_annotation_str_len( batched_intervals: t.Dict[Snapshot, t.List[Interval]], audit_padding: int = 0, - rows_processed_len: int = 0, + execution_stats_len: int = 0, ) -> int: annotation_str_len = 0 for snapshot, intervals in batched_intervals.items(): @@ -4279,6 +4298,42 @@ def _calculate_annotation_str_len( annotation_str_len, _calculate_interval_str_len(snapshot, intervals) + _calculate_audit_str_len(snapshot, audit_padding) - + rows_processed_len, + + execution_stats_len, ) return annotation_str_len + + +# Convert number of bytes to a human-readable string +# https://github.com/dbt-labs/dbt-adapters/blob/34fd178539dcb6f82e18e738adc03de7784c032f/dbt-bigquery/src/dbt/adapters/bigquery/connections.py#L165 +def _format_bytes(num_bytes: t.Optional[int]) -> str: + if num_bytes and num_bytes > 0: + if num_bytes < 1024: + return f"{num_bytes} Bytes" + + num_bytes_float = float(num_bytes) / 1024.0 + for unit in ["KiB", "MiB", "GiB", "TiB", "PiB"]: + if num_bytes_float < 1024.0: + return f"{num_bytes_float:3.1f} {unit}" + num_bytes_float /= 1024.0 + + num_bytes_float *= 1024.0 # undo last division in loop + return f"{num_bytes_float:3.1f} {unit}" + return "" + + +# Abbreviate integer count. Example: 1,000,000,000 -> 1b +# https://github.com/dbt-labs/dbt-adapters/blob/34fd178539dcb6f82e18e738adc03de7784c032f/dbt-bigquery/src/dbt/adapters/bigquery/connections.py#L178 +def _abbreviate_integer_count(count: t.Optional[int]) -> str: + if count and count > 0: + if count < 1000: + return str(count) + + count_float = float(count) / 1000.0 + for unit in ["k", "m", "b", "t"]: + if count_float < 1000.0: + return f"{count_float:3.1f}{unit}".strip() + count_float /= 1000.0 + + count_float *= 1000.0 # undo last division in loop + return f"{count_float:3.1f}{unit}".strip() + return "" diff --git a/sqlmesh/core/engine_adapter/base.py b/sqlmesh/core/engine_adapter/base.py index 459bfdea05..e7582c079e 100644 --- a/sqlmesh/core/engine_adapter/base.py +++ b/sqlmesh/core/engine_adapter/base.py @@ -856,7 +856,7 @@ def _create_table_from_source_queries( table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, table_kind: t.Optional[str] = None, - track_row_count: bool = True, + track_execution_stats: bool = True, **kwargs: t.Any, ) -> None: table = exp.to_table(table_name) @@ -902,7 +902,7 @@ def _create_table_from_source_queries( replace=replace, table_description=table_description, table_kind=table_kind, - track_row_count=track_row_count, + track_execution_stats=track_execution_stats, **kwargs, ) else: @@ -910,7 +910,7 @@ def _create_table_from_source_queries( table_name, query, target_columns_to_types or self.columns(table), - track_row_count=track_row_count, + track_execution_stats=track_execution_stats, ) # Register comments with commands if the engine supports comments and we weren't able to @@ -934,7 +934,7 @@ def _create_table( table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, table_kind: t.Optional[str] = None, - track_row_count: bool = True, + track_execution_stats: bool = True, **kwargs: t.Any, 
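# What the new annotation looks like once execution stats flow through: rows go
# through _abbreviate_integer_count and bytes through _format_bytes (both added
# to console.py earlier in this patch), producing e.g. "full refresh (1.5m rows, 5.0 GiB)".
# The helpers below are simplified stand-ins that mirror the same unit-walking idea.
def _abbrev_count_sketch(count: int) -> str:
    if count < 1000:
        return str(count)
    value = float(count)
    for unit in ("k", "m", "b", "t"):
        value /= 1000.0
        if value < 1000.0:
            return f"{value:3.1f}{unit}".strip()
    return f"{value:3.1f}t".strip()


def _format_bytes_sketch(num_bytes: int) -> str:
    if num_bytes < 1024:
        return f"{num_bytes} bytes"
    value = float(num_bytes)
    for unit in ("KiB", "MiB", "GiB", "TiB", "PiB"):
        value /= 1024.0
        if value < 1024.0:
            return f"{value:3.1f} {unit}"
    return f"{value:3.1f} PiB"


assert _abbrev_count_sketch(1_500_000) == "1.5m"
assert _format_bytes_sketch(5 * 1024**3) == "5.0 GiB"
assert (
    f"full refresh ({_abbrev_count_sketch(1_500_000)} rows, {_format_bytes_sketch(5 * 1024**3)})"
    == "full refresh (1.5m rows, 5.0 GiB)"
)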
) -> None: self.execute( @@ -952,7 +952,7 @@ def _create_table( table_kind=table_kind, **kwargs, ), - track_row_count=track_row_count, + track_execution_stats=track_execution_stats, ) def _build_create_table_exp( @@ -1440,7 +1440,7 @@ def insert_append( table_name: TableName, query_or_df: QueryOrDF, target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, - track_row_count: bool = True, + track_execution_stats: bool = True, source_columns: t.Optional[t.List[str]] = None, ) -> None: source_queries, target_columns_to_types = self._get_source_queries_and_columns_to_types( @@ -1450,7 +1450,7 @@ def insert_append( source_columns=source_columns, ) self._insert_append_source_queries( - table_name, source_queries, target_columns_to_types, track_row_count + table_name, source_queries, target_columns_to_types, track_execution_stats ) def _insert_append_source_queries( @@ -1458,14 +1458,17 @@ def _insert_append_source_queries( table_name: TableName, source_queries: t.List[SourceQuery], target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, - track_row_count: bool = True, + track_execution_stats: bool = True, ) -> None: with self.transaction(condition=len(source_queries) > 0): target_columns_to_types = target_columns_to_types or self.columns(table_name) for source_query in source_queries: with source_query as query: self._insert_append_query( - table_name, query, target_columns_to_types, track_row_count=track_row_count + table_name, + query, + target_columns_to_types, + track_execution_stats=track_execution_stats, ) def _insert_append_query( @@ -1474,13 +1477,13 @@ def _insert_append_query( query: Query, target_columns_to_types: t.Dict[str, exp.DataType], order_projections: bool = True, - track_row_count: bool = True, + track_execution_stats: bool = True, ) -> None: if order_projections: query = self._order_projections_and_filter(query, target_columns_to_types) self.execute( exp.insert(query, table_name, columns=list(target_columns_to_types)), - track_row_count=track_row_count, + track_execution_stats=track_execution_stats, ) def insert_overwrite_by_partition( @@ -1623,7 +1626,7 @@ def _insert_overwrite_by_condition( ) if insert_overwrite_strategy.is_replace_where: insert_exp.set("where", where or exp.true()) - self.execute(insert_exp, track_row_count=True) + self.execute(insert_exp, track_execution_stats=True) def update_table( self, @@ -1644,7 +1647,9 @@ def _merge( using = exp.alias_( exp.Subquery(this=query), alias=MERGE_SOURCE_ALIAS, copy=False, table=True ) - self.execute(exp.Merge(this=this, using=using, on=on, whens=whens), track_row_count=True) + self.execute( + exp.Merge(this=this, using=using, on=on, whens=whens), track_execution_stats=True + ) def scd_type_2_by_time( self, @@ -2393,7 +2398,7 @@ def execute( expressions: t.Union[str, exp.Expression, t.Sequence[exp.Expression]], ignore_unsupported_errors: bool = False, quote_identifiers: bool = True, - track_row_count: bool = False, + track_execution_stats: bool = False, **kwargs: t.Any, ) -> None: """Execute a sql query.""" @@ -2415,7 +2420,7 @@ def execute( expression=e if isinstance(e, exp.Expression) else None, quote_identifiers=quote_identifiers, ) - self._execute(sql, track_row_count, **kwargs) + self._execute(sql, track_execution_stats, **kwargs) def _attach_correlation_id(self, sql: str) -> str: if self.ATTACH_CORRELATION_ID and self.correlation_id: @@ -2440,12 +2445,12 @@ def _log_sql( logger.log(self._execute_log_level, "Executing SQL: %s", sql_to_log) - def _execute(self, sql: str, track_row_count: 
bool = False, **kwargs: t.Any) -> None: + def _execute(self, sql: str, track_execution_stats: bool = False, **kwargs: t.Any) -> None: self.cursor.execute(sql, **kwargs) if ( self.SUPPORTS_QUERY_EXECUTION_TRACKING - and track_row_count + and track_execution_stats and QueryExecutionTracker.is_tracking() ): rowcount_raw = getattr(self.cursor, "rowcount", None) @@ -2456,7 +2461,7 @@ def _execute(self, sql: str, track_row_count: bool = False, **kwargs: t.Any) -> except (TypeError, ValueError): pass - QueryExecutionTracker.record_execution(sql, rowcount) + QueryExecutionTracker.record_execution(sql, rowcount, None) @contextlib.contextmanager def temp_table( @@ -2502,7 +2507,7 @@ def temp_table( exists=True, table_description=None, column_descriptions=None, - track_row_count=False, + track_execution_stats=False, **kwargs, ) @@ -2754,7 +2759,7 @@ def _replace_by_key( insert_statement.set("where", delete_filter) insert_statement.set("this", exp.to_table(target_table)) - self.execute(insert_statement, track_row_count=True) + self.execute(insert_statement, track_execution_stats=True) finally: self.drop_table(temp_table) diff --git a/sqlmesh/core/engine_adapter/bigquery.py b/sqlmesh/core/engine_adapter/bigquery.py index 4292a1e37d..3473b05e03 100644 --- a/sqlmesh/core/engine_adapter/bigquery.py +++ b/sqlmesh/core/engine_adapter/bigquery.py @@ -1051,7 +1051,7 @@ def _db_call(self, func: t.Callable[..., t.Any], *args: t.Any, **kwargs: t.Any) def _execute( self, sql: str, - track_row_count: bool = False, + track_execution_stats: bool = False, **kwargs: t.Any, ) -> None: """Execute a sql query.""" @@ -1097,14 +1097,16 @@ def _execute( self.cursor._set_rowcount(query_results) self.cursor._set_description(query_results.schema) - if track_row_count: + if track_execution_stats and QueryExecutionTracker.is_tracking(): + num_rows = None if query_job.statement_type == "CREATE_TABLE_AS_SELECT": + # since table was just created, number rows in table == number rows processed query_table = self.client.get_table(query_job.destination) num_rows = query_table.num_rows elif query_job.statement_type in ["INSERT", "DELETE", "MERGE", "UPDATE"]: num_rows = query_job.num_dml_affected_rows - QueryExecutionTracker.record_execution(sql, num_rows) + QueryExecutionTracker.record_execution(sql, num_rows, query_job.total_bytes_processed) def _get_data_objects( self, schema_name: SchemaName, object_names: t.Optional[t.Set[str]] = None diff --git a/sqlmesh/core/engine_adapter/clickhouse.py b/sqlmesh/core/engine_adapter/clickhouse.py index 7458e4887c..34c6aad431 100644 --- a/sqlmesh/core/engine_adapter/clickhouse.py +++ b/sqlmesh/core/engine_adapter/clickhouse.py @@ -294,7 +294,7 @@ def _insert_overwrite_by_condition( ) try: - self.execute(existing_records_insert_exp, track_row_count=True) + self.execute(existing_records_insert_exp, track_execution_stats=True) finally: if table_partition_exp: self.drop_table(partitions_temp_table_name) @@ -489,7 +489,7 @@ def _create_table( table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, table_kind: t.Optional[str] = None, - track_row_count: bool = True, + track_execution_stats: bool = True, **kwargs: t.Any, ) -> None: """Creates a table in the database. 
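# Why _execute above coerces cursor.rowcount defensively and the stats code only
# counts non-negative values: DB-API drivers report -1 whenever the row count is
# unknown or not applicable, so -1 means "no information", not zero rows. A quick
# demonstration with the standard-library sqlite3 driver:
import sqlite3

conn = sqlite3.connect(":memory:")
cur = conn.cursor()

cur.execute("CREATE TABLE t (x INTEGER)")
assert cur.rowcount == -1   # DDL: no applicable row count

cur.execute("INSERT INTO t VALUES (1), (2), (3)")
assert cur.rowcount == 3    # DML: affected rows are reported

cur.execute("SELECT * FROM t")
assert cur.rowcount == -1   # SELECT: sqlite3 does not pre-count result rows

conn.close()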
@@ -526,7 +526,7 @@ def _create_table( column_descriptions, table_kind, empty_ctas=(self.engine_run_mode.is_cloud and expression is not None), - track_row_count=track_row_count, + track_execution_stats=track_execution_stats, **kwargs, ) diff --git a/sqlmesh/core/engine_adapter/duckdb.py b/sqlmesh/core/engine_adapter/duckdb.py index f08aebdf06..b501d15945 100644 --- a/sqlmesh/core/engine_adapter/duckdb.py +++ b/sqlmesh/core/engine_adapter/duckdb.py @@ -170,7 +170,7 @@ def _create_table( table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, table_kind: t.Optional[str] = None, - track_row_count: bool = True, + track_execution_stats: bool = True, **kwargs: t.Any, ) -> None: catalog = self.get_current_catalog() @@ -194,7 +194,7 @@ def _create_table( table_description, column_descriptions, table_kind, - track_row_count=track_row_count, + track_execution_stats=track_execution_stats, **kwargs, ) diff --git a/sqlmesh/core/engine_adapter/redshift.py b/sqlmesh/core/engine_adapter/redshift.py index fa9575820e..1b7faef05e 100644 --- a/sqlmesh/core/engine_adapter/redshift.py +++ b/sqlmesh/core/engine_adapter/redshift.py @@ -173,7 +173,7 @@ def _create_table_from_source_queries( table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, table_kind: t.Optional[str] = None, - track_row_count: bool = True, + track_execution_stats: bool = True, **kwargs: t.Any, ) -> None: """ @@ -428,7 +428,7 @@ def resolve_target_table(expression: exp.Expression) -> exp.Expression: on=on.transform(resolve_target_table), whens=whens.transform(resolve_target_table), ), - track_row_count=True, + track_execution_stats=True, ) def _normalize_decimal_value(self, expr: exp.Expression, precision: int) -> exp.Expression: diff --git a/sqlmesh/core/engine_adapter/snowflake.py b/sqlmesh/core/engine_adapter/snowflake.py index ad04645185..8148714dcd 100644 --- a/sqlmesh/core/engine_adapter/snowflake.py +++ b/sqlmesh/core/engine_adapter/snowflake.py @@ -166,7 +166,7 @@ def _create_table( table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, table_kind: t.Optional[str] = None, - track_row_count: bool = True, + track_execution_stats: bool = True, **kwargs: t.Any, ) -> None: table_format = kwargs.get("table_format") @@ -186,7 +186,7 @@ def _create_table( table_description=table_description, column_descriptions=column_descriptions, table_kind=table_kind, - track_row_count=track_row_count, + track_execution_stats=track_execution_stats, **kwargs, ) diff --git a/sqlmesh/core/engine_adapter/spark.py b/sqlmesh/core/engine_adapter/spark.py index 99ff623301..2d9a4abd99 100644 --- a/sqlmesh/core/engine_adapter/spark.py +++ b/sqlmesh/core/engine_adapter/spark.py @@ -433,7 +433,7 @@ def _create_table( table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, table_kind: t.Optional[str] = None, - track_row_count: bool = True, + track_execution_stats: bool = True, **kwargs: t.Any, ) -> None: table_name = ( @@ -462,7 +462,7 @@ def _create_table( target_columns_to_types=target_columns_to_types, table_description=table_description, column_descriptions=column_descriptions, - track_row_count=track_row_count, + track_execution_stats=track_execution_stats, **kwargs, ) table_name = ( diff --git a/sqlmesh/core/engine_adapter/trino.py b/sqlmesh/core/engine_adapter/trino.py index 21dac81255..2e15cfb78e 100644 --- a/sqlmesh/core/engine_adapter/trino.py +++ 
b/sqlmesh/core/engine_adapter/trino.py @@ -358,7 +358,7 @@ def _create_table( table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, table_kind: t.Optional[str] = None, - track_row_count: bool = True, + track_execution_stats: bool = True, **kwargs: t.Any, ) -> None: super()._create_table( @@ -370,7 +370,7 @@ def _create_table( table_description=table_description, column_descriptions=column_descriptions, table_kind=table_kind, - track_row_count=track_row_count, + track_execution_stats=track_execution_stats, **kwargs, ) diff --git a/sqlmesh/core/execution_tracker.py b/sqlmesh/core/execution_tracker.py index d31fde68bb..f98ada750c 100644 --- a/sqlmesh/core/execution_tracker.py +++ b/sqlmesh/core/execution_tracker.py @@ -7,6 +7,17 @@ from dataclasses import dataclass, field +@dataclass +class QueryExecutionStats: + snapshot_batch_id: str + total_rows_processed: int = 0 + total_bytes_processed: int = 0 + query_count: int = 0 + queries_executed: t.List[t.Tuple[str, t.Optional[int], t.Optional[int], float]] = field( + default_factory=list + ) + + @dataclass class QueryExecutionContext: """ @@ -21,26 +32,30 @@ class QueryExecutionContext: queries_executed: List of (sql_snippet, row_count, timestamp) tuples for debugging """ - id: str - total_rows_processed: int = 0 - query_count: int = 0 - queries_executed: t.List[t.Tuple[str, t.Optional[int], float]] = field(default_factory=list) + snapshot_batch_id: str + stats: QueryExecutionStats = field(init=False) + + def __post_init__(self) -> None: + self.stats = QueryExecutionStats(snapshot_batch_id=self.snapshot_batch_id) - def add_execution(self, sql: str, row_count: t.Optional[int]) -> None: + def add_execution( + self, sql: str, row_count: t.Optional[int], bytes_processed: t.Optional[int] + ) -> None: if row_count is not None and row_count >= 0: - self.total_rows_processed += row_count - self.query_count += 1 + self.stats.total_rows_processed += row_count + + # conditional on row_count because we should only count bytes corresponding to + # DML actions whose rows were captured + if bytes_processed is not None and bytes_processed >= 0: + self.stats.total_bytes_processed += bytes_processed + + self.stats.query_count += 1 # TODO: remove this # for debugging - self.queries_executed.append((sql[:300], row_count, time.time())) + self.stats.queries_executed.append((sql[:300], row_count, bytes_processed, time.time())) - def get_execution_stats(self) -> t.Dict[str, t.Any]: - return { - "id": self.id, - "total_rows_processed": self.total_rows_processed, - "query_count": self.query_count, - "queries": self.queries_executed, - } + def get_execution_stats(self) -> QueryExecutionStats: + return self.stats class QueryExecutionTracker: @@ -72,7 +87,7 @@ def track_execution( yield None return - context = QueryExecutionContext(id=snapshot_id_batch) + context = QueryExecutionContext(snapshot_batch_id=snapshot_id_batch) cls._thread_local.context = context cls._contexts[snapshot_id_batch] = context try: @@ -81,13 +96,15 @@ def track_execution( cls._thread_local.context = None @classmethod - def record_execution(cls, sql: str, row_count: t.Optional[int]) -> None: + def record_execution( + cls, sql: str, row_count: t.Optional[int], bytes_processed: t.Optional[int] + ) -> None: context = getattr(cls._thread_local, "context", None) if context is not None: - context.add_execution(sql, row_count) + context.add_execution(sql, row_count, bytes_processed) @classmethod - def get_execution_stats(cls, snapshot_id_batch: str) -> 
t.Optional[t.Dict[str, t.Any]]: + def get_execution_stats(cls, snapshot_id_batch: str) -> t.Optional[QueryExecutionStats]: context = cls.get_execution_context(snapshot_id_batch) cls._contexts.pop(snapshot_id_batch, None) return context.get_execution_stats() if context else None diff --git a/sqlmesh/core/scheduler.py b/sqlmesh/core/scheduler.py index 60c45ab546..97613f13c5 100644 --- a/sqlmesh/core/scheduler.py +++ b/sqlmesh/core/scheduler.py @@ -536,19 +536,9 @@ def run_node(node: SchedulingUnit) -> None: num_audits = len(audit_results) num_audits_failed = sum(1 for result in audit_results if result.count) - rows_processed = None - if snapshot.is_seed: - # seed stats are tracked in SeedStrategy.create by model name, not snapshot name - seed_stats = SeedExecutionTracker.get_and_clear_seed_stats( - snapshot.model.name - ) - rows_processed = ( - seed_stats.get("total_rows_processed") if seed_stats else None - ) - else: - rows_processed = ( - execution_context.total_rows_processed if execution_context else None - ) + execution_stats = QueryExecutionTracker.get_execution_stats( + f"{snapshot.snapshot_id}_{batch_idx}" + ) self.console.update_snapshot_evaluation_progress( snapshot, @@ -557,6 +547,7 @@ def run_node(node: SchedulingUnit) -> None: evaluation_duration_ms, num_audits - num_audits_failed, num_audits_failed, + execution_stats=execution_stats, auto_restatement_triggers=auto_restatement_triggers.get( snapshot.snapshot_id ), @@ -567,9 +558,9 @@ def run_node(node: SchedulingUnit) -> None: snapshots=self.snapshots_by_name, deployability_index=deployability_index, allow_destructive_snapshots=allow_destructive_snapshots or set(), - rows_processed=rows_processed, allow_additive_snapshots=allow_additive_snapshots or set(), - ) + rows_processed=rows_processed, + ) try: with self.snapshot_evaluator.concurrent_context(): diff --git a/sqlmesh/core/state_sync/db/environment.py b/sqlmesh/core/state_sync/db/environment.py index db3844f0ce..444985274d 100644 --- a/sqlmesh/core/state_sync/db/environment.py +++ b/sqlmesh/core/state_sync/db/environment.py @@ -78,7 +78,7 @@ def update_environment(self, environment: Environment) -> None: self.environments_table, _environment_to_df(environment), target_columns_to_types=self._environment_columns_to_types, - track_row_count=False, + track_execution_stats=False, ) def update_environment_statements( @@ -109,7 +109,7 @@ def update_environment_statements( self.environment_statements_table, _environment_statements_to_df(environment_name, plan_id, environment_statements), target_columns_to_types=self._environment_statements_columns_to_types, - track_row_count=False, + track_execution_stats=False, ) def invalidate_environment(self, name: str, protect_prod: bool = True) -> None: diff --git a/sqlmesh/core/state_sync/db/interval.py b/sqlmesh/core/state_sync/db/interval.py index ae37fd9734..e06100f904 100644 --- a/sqlmesh/core/state_sync/db/interval.py +++ b/sqlmesh/core/state_sync/db/interval.py @@ -115,7 +115,7 @@ def remove_intervals( self.intervals_table, _intervals_to_df(intervals_to_remove, is_dev=False, is_removed=True), target_columns_to_types=self._interval_columns_to_types, - track_row_count=False, + track_execution_stats=False, ) def get_snapshot_intervals( @@ -244,7 +244,7 @@ def _push_snapshot_intervals( self.intervals_table, pd.DataFrame(new_intervals), target_columns_to_types=self._interval_columns_to_types, - track_row_count=False, + track_execution_stats=False, ) def _get_snapshot_intervals( diff --git a/sqlmesh/core/state_sync/db/migrator.py 
b/sqlmesh/core/state_sync/db/migrator.py index f796a26f62..ecf00baa3c 100644 --- a/sqlmesh/core/state_sync/db/migrator.py +++ b/sqlmesh/core/state_sync/db/migrator.py @@ -414,7 +414,7 @@ def _backup_state(self) -> None: self.engine_adapter.drop_table(backup_name) self.engine_adapter.create_table_like(backup_name, table) self.engine_adapter.insert_append( - backup_name, exp.select("*").from_(table), track_row_count=False + backup_name, exp.select("*").from_(table), track_execution_stats=False ) def _restore_table( diff --git a/sqlmesh/core/state_sync/db/snapshot.py b/sqlmesh/core/state_sync/db/snapshot.py index 223b11153e..2a9046986b 100644 --- a/sqlmesh/core/state_sync/db/snapshot.py +++ b/sqlmesh/core/state_sync/db/snapshot.py @@ -103,7 +103,7 @@ def push_snapshots(self, snapshots: t.Iterable[Snapshot], overwrite: bool = Fals self.snapshots_table, _snapshots_to_df(snapshots_to_store), target_columns_to_types=self._snapshot_columns_to_types, - track_row_count=False, + track_execution_stats=False, ) for snapshot in snapshots: @@ -407,7 +407,7 @@ def _push_snapshots(self, snapshots: t.Iterable[Snapshot]) -> None: self.snapshots_table, _snapshots_to_df(snapshots_to_store), target_columns_to_types=self._snapshot_columns_to_types, - track_row_count=False, + track_execution_stats=False, ) def _get_snapshots( diff --git a/sqlmesh/core/state_sync/db/version.py b/sqlmesh/core/state_sync/db/version.py index 2732c0ca47..487347f7d1 100644 --- a/sqlmesh/core/state_sync/db/version.py +++ b/sqlmesh/core/state_sync/db/version.py @@ -55,7 +55,7 @@ def update_versions( ] ), target_columns_to_types=self._version_columns_to_types, - track_row_count=False, + track_execution_stats=False, ) def get_versions(self) -> Versions: diff --git a/tests/core/engine_adapter/integration/test_integration.py b/tests/core/engine_adapter/integration/test_integration.py index 4298cd58a7..5abc860669 100644 --- a/tests/core/engine_adapter/integration/test_integration.py +++ b/tests/core/engine_adapter/integration/test_integration.py @@ -2383,7 +2383,7 @@ def _mutate_config(gateway: str, config: Config): context._models.update(replacement_models) # capture row counts for each evaluated snapshot - row_counts = {} + actual_execution_stats = {} def capture_row_counts( snapshot, @@ -2393,10 +2393,12 @@ def capture_row_counts( num_audits_passed, num_audits_failed, audit_only=False, - rows_processed=None, + execution_stats=None, ): - if rows_processed is not None: - row_counts[snapshot.model.name.replace(f"{schema_name}.", "")] = rows_processed + if execution_stats is not None: + actual_execution_stats[snapshot.model.name.replace(f"{schema_name}.", "")] = ( + execution_stats + ) # apply prod plan with patch.object(context.console, "update_snapshot_evaluation_progress", capture_row_counts): @@ -2413,10 +2415,15 @@ def capture_row_counts( assert len(physical_layer_results.tables) == len(physical_layer_results.non_temp_tables) == 3 if ctx.engine_adapter.SUPPORTS_QUERY_EXECUTION_TRACKING: - assert len(row_counts) == 3 - assert row_counts["seed_model"] == 7 - assert row_counts["incremental_model"] == 7 - assert row_counts["full_model"] == 3 + assert len(actual_execution_stats) == 3 + assert actual_execution_stats["seed_model"].total_rows_processed == 7 + assert actual_execution_stats["incremental_model"].total_rows_processed == 7 + assert actual_execution_stats["full_model"].total_rows_processed == 3 + + if ctx.mark.startswith("bigquery"): + assert actual_execution_stats["seed_model"].total_bytes_processed + assert 
actual_execution_stats["incremental_model"].total_bytes_processed + assert actual_execution_stats["full_model"].total_bytes_processed # make and validate unmodified dev environment no_change_plan: Plan = context.plan_builder( diff --git a/tests/core/test_execution_tracker.py b/tests/core/test_execution_tracker.py index 5791abeccb..3d51dbd4dd 100644 --- a/tests/core/test_execution_tracker.py +++ b/tests/core/test_execution_tracker.py @@ -1,18 +1,17 @@ from __future__ import annotations -import typing as t from concurrent.futures import ThreadPoolExecutor -from sqlmesh.core.execution_tracker import QueryExecutionTracker +from sqlmesh.core.execution_tracker import QueryExecutionStats, QueryExecutionTracker def test_execution_tracker_thread_isolation() -> None: - def worker(id: str, row_counts: list[int]) -> t.Dict[str, t.Any]: + def worker(id: str, row_counts: list[int]) -> QueryExecutionStats: with QueryExecutionTracker.track_execution(id) as ctx: assert QueryExecutionTracker.is_tracking() for count in row_counts: - QueryExecutionTracker.record_execution("SELECT 1", count) + QueryExecutionTracker.record_execution("SELECT 1", count, None) assert ctx is not None return ctx.get_execution_stats() @@ -26,13 +25,13 @@ def worker(id: str, row_counts: list[int]) -> t.Dict[str, t.Any]: # Main thread has no active tracking context assert not QueryExecutionTracker.is_tracking() - QueryExecutionTracker.record_execution("q", 10) + QueryExecutionTracker.record_execution("q", 10, None) assert QueryExecutionTracker.get_execution_stats("q") is None # Order of results is not deterministic, so look up by id - by_batch = {s["id"]: s for s in results} + by_batch = {s.snapshot_batch_id: s for s in results} - assert by_batch["batch_A"]["total_rows_processed"] == 15 - assert by_batch["batch_A"]["query_count"] == 2 - assert by_batch["batch_B"]["total_rows_processed"] == 10 - assert by_batch["batch_B"]["query_count"] == 2 + assert by_batch["batch_A"].total_rows_processed == 15 + assert by_batch["batch_A"].query_count == 2 + assert by_batch["batch_B"].total_rows_processed == 10 + assert by_batch["batch_B"].query_count == 2 diff --git a/web/server/console.py b/web/server/console.py index b2d12cd624..c7bdbbfc51 100644 --- a/web/server/console.py +++ b/web/server/console.py @@ -8,6 +8,7 @@ from sqlmesh.core.snapshot.definition import Interval, Intervals from sqlmesh.core.console import TerminalConsole from sqlmesh.core.environment import EnvironmentNamingInfo +from sqlmesh.core.execution_tracker import QueryExecutionStats from sqlmesh.core.plan.definition import EvaluatablePlan from sqlmesh.core.snapshot import Snapshot, SnapshotInfoLike, SnapshotTableInfo, SnapshotId from sqlmesh.core.test import ModelTest @@ -142,7 +143,7 @@ def update_snapshot_evaluation_progress( num_audits_passed: int, num_audits_failed: int, audit_only: bool = False, - rows_processed: t.Optional[int] = None, + execution_stats: t.Optional[QueryExecutionStats] = None, auto_restatement_triggers: t.Optional[t.List[SnapshotId]] = None, ) -> None: if audit_only: From 4df2e3225ff041dc844d67a014b5cd8bc479e11b Mon Sep 17 00:00:00 2001 From: Trey Spiller Date: Mon, 18 Aug 2025 17:30:26 -0500 Subject: [PATCH 09/31] Remove seed tracking, have snapshot evaluator own tracker instance --- sqlmesh/core/console.py | 2 +- sqlmesh/core/execution_tracker.py | 32 +++++++++---------- sqlmesh/core/scheduler.py | 3 +- sqlmesh/core/snapshot/evaluator.py | 5 ++- .../integration/test_integration.py | 3 -- tests/core/test_execution_tracker.py | 14 ++++---- 6 files changed, 27 
insertions(+), 32 deletions(-) diff --git a/sqlmesh/core/console.py b/sqlmesh/core/console.py index 9f512123b1..f4accf551a 100644 --- a/sqlmesh/core/console.py +++ b/sqlmesh/core/console.py @@ -4308,7 +4308,7 @@ def _calculate_annotation_str_len( def _format_bytes(num_bytes: t.Optional[int]) -> str: if num_bytes and num_bytes > 0: if num_bytes < 1024: - return f"{num_bytes} Bytes" + return f"{num_bytes} bytes" num_bytes_float = float(num_bytes) / 1024.0 for unit in ["KiB", "MiB", "GiB", "TiB", "PiB"]: diff --git a/sqlmesh/core/execution_tracker.py b/sqlmesh/core/execution_tracker.py index f98ada750c..9d7d03887a 100644 --- a/sqlmesh/core/execution_tracker.py +++ b/sqlmesh/core/execution_tracker.py @@ -3,7 +3,7 @@ import time import typing as t from contextlib import contextmanager -from threading import local +from threading import local, Lock from dataclasses import dataclass, field @@ -66,34 +66,32 @@ class QueryExecutionTracker: _thread_local = local() _contexts: t.Dict[str, QueryExecutionContext] = {} + _contexts_lock = Lock() - @classmethod - def get_execution_context(cls, snapshot_id_batch: str) -> t.Optional[QueryExecutionContext]: - return cls._contexts.get(snapshot_id_batch) + def get_execution_context(self, snapshot_id_batch: str) -> t.Optional[QueryExecutionContext]: + with self._contexts_lock: + return self._contexts.get(snapshot_id_batch) @classmethod def is_tracking(cls) -> bool: return getattr(cls._thread_local, "context", None) is not None - @classmethod @contextmanager def track_execution( - cls, snapshot_id_batch: str, condition: bool = True + self, snapshot_id_batch: str ) -> t.Iterator[t.Optional[QueryExecutionContext]]: """ Context manager for tracking snapshot execution statistics. """ - if not condition: - yield None - return - context = QueryExecutionContext(snapshot_batch_id=snapshot_id_batch) - cls._thread_local.context = context - cls._contexts[snapshot_id_batch] = context + self._thread_local.context = context + with self._contexts_lock: + self._contexts[snapshot_id_batch] = context + try: yield context finally: - cls._thread_local.context = None + self._thread_local.context = None @classmethod def record_execution( @@ -103,8 +101,8 @@ def record_execution( if context is not None: context.add_execution(sql, row_count, bytes_processed) - @classmethod - def get_execution_stats(cls, snapshot_id_batch: str) -> t.Optional[QueryExecutionStats]: - context = cls.get_execution_context(snapshot_id_batch) - cls._contexts.pop(snapshot_id_batch, None) + def get_execution_stats(self, snapshot_id_batch: str) -> t.Optional[QueryExecutionStats]: + with self._contexts_lock: + context = self._contexts.get(snapshot_id_batch) + self._contexts.pop(snapshot_id_batch, None) return context.get_execution_stats() if context else None diff --git a/sqlmesh/core/scheduler.py b/sqlmesh/core/scheduler.py index 97613f13c5..66b0b115d6 100644 --- a/sqlmesh/core/scheduler.py +++ b/sqlmesh/core/scheduler.py @@ -9,7 +9,6 @@ from sqlmesh.core import constants as c from sqlmesh.core.console import Console, get_console from sqlmesh.core.environment import EnvironmentNamingInfo, execute_environment_statements -from sqlmesh.core.execution_tracker import QueryExecutionTracker from sqlmesh.core.macros import RuntimeStage from sqlmesh.core.model.definition import AuditResult from sqlmesh.core.node import IntervalUnit @@ -536,7 +535,7 @@ def run_node(node: SchedulingUnit) -> None: num_audits = len(audit_results) num_audits_failed = sum(1 for result in audit_results if result.count) - execution_stats = 
QueryExecutionTracker.get_execution_stats( + execution_stats = self.snapshot_evaluator.execution_tracker.get_execution_stats( f"{snapshot.snapshot_id}_{batch_idx}" ) diff --git a/sqlmesh/core/snapshot/evaluator.py b/sqlmesh/core/snapshot/evaluator.py index 47a4709fd4..45daa0feb6 100644 --- a/sqlmesh/core/snapshot/evaluator.py +++ b/sqlmesh/core/snapshot/evaluator.py @@ -136,6 +136,7 @@ def __init__( ) self.selected_gateway = selected_gateway self.ddl_concurrent_tasks = ddl_concurrent_tasks + self.execution_tracker = QueryExecutionTracker() def evaluate( self, @@ -170,9 +171,7 @@ def evaluate( Returns: The WAP ID of this evaluation if supported, None otherwise. """ - with QueryExecutionTracker.track_execution( - f"{snapshot.snapshot_id}_{batch_index}", condition=not snapshot.is_seed - ): + with self.execution_tracker.track_execution(f"{snapshot.snapshot_id}_{batch_index}"): result = self._evaluate_snapshot( start=start, end=end, diff --git a/tests/core/engine_adapter/integration/test_integration.py b/tests/core/engine_adapter/integration/test_integration.py index 5abc860669..a9674208b0 100644 --- a/tests/core/engine_adapter/integration/test_integration.py +++ b/tests/core/engine_adapter/integration/test_integration.py @@ -2415,13 +2415,10 @@ def capture_row_counts( assert len(physical_layer_results.tables) == len(physical_layer_results.non_temp_tables) == 3 if ctx.engine_adapter.SUPPORTS_QUERY_EXECUTION_TRACKING: - assert len(actual_execution_stats) == 3 - assert actual_execution_stats["seed_model"].total_rows_processed == 7 assert actual_execution_stats["incremental_model"].total_rows_processed == 7 assert actual_execution_stats["full_model"].total_rows_processed == 3 if ctx.mark.startswith("bigquery"): - assert actual_execution_stats["seed_model"].total_bytes_processed assert actual_execution_stats["incremental_model"].total_bytes_processed assert actual_execution_stats["full_model"].total_bytes_processed diff --git a/tests/core/test_execution_tracker.py b/tests/core/test_execution_tracker.py index 3d51dbd4dd..6172bade75 100644 --- a/tests/core/test_execution_tracker.py +++ b/tests/core/test_execution_tracker.py @@ -7,15 +7,17 @@ def test_execution_tracker_thread_isolation() -> None: def worker(id: str, row_counts: list[int]) -> QueryExecutionStats: - with QueryExecutionTracker.track_execution(id) as ctx: - assert QueryExecutionTracker.is_tracking() + with execution_tracker.track_execution(id) as ctx: + assert execution_tracker.is_tracking() for count in row_counts: - QueryExecutionTracker.record_execution("SELECT 1", count, None) + execution_tracker.record_execution("SELECT 1", count, None) assert ctx is not None return ctx.get_execution_stats() + execution_tracker = QueryExecutionTracker() + with ThreadPoolExecutor() as executor: futures = [ executor.submit(worker, "batch_A", [10, 5]), @@ -24,9 +26,9 @@ def worker(id: str, row_counts: list[int]) -> QueryExecutionStats: results = [f.result() for f in futures] # Main thread has no active tracking context - assert not QueryExecutionTracker.is_tracking() - QueryExecutionTracker.record_execution("q", 10, None) - assert QueryExecutionTracker.get_execution_stats("q") is None + assert not execution_tracker.is_tracking() + execution_tracker.record_execution("q", 10, None) + assert execution_tracker.get_execution_stats("q") is None # Order of results is not deterministic, so look up by id by_batch = {s.snapshot_batch_id: s for s in results} From 8c2e184ca9bfdd49d418b35191d66baccc14dc33 Mon Sep 17 00:00:00 2001 From: Trey Spiller Date: Mon, 18 Aug 
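# Usage sketch of the ownership model patch 09 lands above: the SnapshotEvaluator
# owns a QueryExecutionTracker instance, each evaluation runs inside
# track_execution(), and the scheduler later pops the stats for the same
# "<snapshot_id>_<batch_index>" key. The key below is a placeholder value, and
# the import path is the one in use at this point in the series.
from sqlmesh.core.execution_tracker import QueryExecutionTracker

tracker = QueryExecutionTracker()
batch_key = "example_snapshot_0"  # stands in for f"{snapshot.snapshot_id}_{batch_index}"

with tracker.track_execution(batch_key):
    # engine adapters executing on this thread report back through record_execution
    tracker.record_execution("INSERT INTO t SELECT ...", 42, None)

stats = tracker.get_execution_stats(batch_key)  # removes the context for this key
assert stats is not None and stats.total_rows_processed == 42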
2025 17:34:24 -0500 Subject: [PATCH 10/31] Move tracker class into snapshot module --- sqlmesh/core/console.py | 2 +- sqlmesh/core/engine_adapter/base.py | 2 +- sqlmesh/core/engine_adapter/bigquery.py | 2 +- sqlmesh/core/snapshot/evaluator.py | 3 ++- sqlmesh/core/{ => snapshot}/execution_tracker.py | 0 tests/core/test_execution_tracker.py | 2 +- web/server/console.py | 2 +- 7 files changed, 7 insertions(+), 6 deletions(-) rename sqlmesh/core/{ => snapshot}/execution_tracker.py (100%) diff --git a/sqlmesh/core/console.py b/sqlmesh/core/console.py index f4accf551a..ea7445be82 100644 --- a/sqlmesh/core/console.py +++ b/sqlmesh/core/console.py @@ -32,7 +32,6 @@ from sqlmesh.core.environment import EnvironmentNamingInfo, EnvironmentSummary from sqlmesh.core.linter.rule import RuleViolation from sqlmesh.core.model import Model -from sqlmesh.core.execution_tracker import QueryExecutionStats from sqlmesh.core.snapshot import ( Snapshot, SnapshotChangeCategory, @@ -40,6 +39,7 @@ SnapshotInfoLike, ) from sqlmesh.core.snapshot.definition import Interval, Intervals, SnapshotTableInfo +from sqlmesh.core.snapshot.execution_tracker import QueryExecutionStats from sqlmesh.core.test import ModelTest from sqlmesh.utils import rich as srich from sqlmesh.utils import Verbosity diff --git a/sqlmesh/core/engine_adapter/base.py b/sqlmesh/core/engine_adapter/base.py index e7582c079e..4a8ec4d34b 100644 --- a/sqlmesh/core/engine_adapter/base.py +++ b/sqlmesh/core/engine_adapter/base.py @@ -40,7 +40,7 @@ ) from sqlmesh.core.model.kind import TimeColumn from sqlmesh.core.schema_diff import SchemaDiffer, TableAlterOperation -from sqlmesh.core.execution_tracker import QueryExecutionTracker +from sqlmesh.core.snapshot.execution_tracker import QueryExecutionTracker from sqlmesh.utils import ( CorrelationId, columns_to_types_all_known, diff --git a/sqlmesh/core/engine_adapter/bigquery.py b/sqlmesh/core/engine_adapter/bigquery.py index 3473b05e03..37a0ab2578 100644 --- a/sqlmesh/core/engine_adapter/bigquery.py +++ b/sqlmesh/core/engine_adapter/bigquery.py @@ -21,9 +21,9 @@ SourceQuery, set_catalog, ) -from sqlmesh.core.execution_tracker import QueryExecutionTracker from sqlmesh.core.node import IntervalUnit from sqlmesh.core.schema_diff import TableAlterOperation, NestedSupport +from sqlmesh.core.snapshot.execution_tracker import QueryExecutionTracker from sqlmesh.utils import optional_import, get_source_columns_to_types from sqlmesh.utils.date import to_datetime from sqlmesh.utils.errors import SQLMeshError diff --git a/sqlmesh/core/snapshot/evaluator.py b/sqlmesh/core/snapshot/evaluator.py index 45daa0feb6..9a615d6ad7 100644 --- a/sqlmesh/core/snapshot/evaluator.py +++ b/sqlmesh/core/snapshot/evaluator.py @@ -39,7 +39,6 @@ from sqlmesh.core.dialect import schema_ from sqlmesh.core.engine_adapter import EngineAdapter from sqlmesh.core.engine_adapter.shared import InsertOverwriteStrategy, DataObjectType -from sqlmesh.core.execution_tracker import QueryExecutionTracker from sqlmesh.core.macros import RuntimeStage from sqlmesh.core.model import ( AuditResult, @@ -66,6 +65,8 @@ SnapshotInfoLike, SnapshotTableCleanupTask, ) +from sqlmesh.core.snapshot.definition import parent_snapshots_by_name +from sqlmesh.core.snapshot.execution_tracker import QueryExecutionTracker from sqlmesh.utils import random_id, CorrelationId from sqlmesh.utils.concurrency import ( concurrent_apply_to_snapshots, diff --git a/sqlmesh/core/execution_tracker.py b/sqlmesh/core/snapshot/execution_tracker.py similarity index 100% rename from 
sqlmesh/core/execution_tracker.py rename to sqlmesh/core/snapshot/execution_tracker.py diff --git a/tests/core/test_execution_tracker.py b/tests/core/test_execution_tracker.py index 6172bade75..3afe56df16 100644 --- a/tests/core/test_execution_tracker.py +++ b/tests/core/test_execution_tracker.py @@ -2,7 +2,7 @@ from concurrent.futures import ThreadPoolExecutor -from sqlmesh.core.execution_tracker import QueryExecutionStats, QueryExecutionTracker +from sqlmesh.core.snapshot.execution_tracker import QueryExecutionStats, QueryExecutionTracker def test_execution_tracker_thread_isolation() -> None: diff --git a/web/server/console.py b/web/server/console.py index c7bdbbfc51..871aaefbb1 100644 --- a/web/server/console.py +++ b/web/server/console.py @@ -8,9 +8,9 @@ from sqlmesh.core.snapshot.definition import Interval, Intervals from sqlmesh.core.console import TerminalConsole from sqlmesh.core.environment import EnvironmentNamingInfo -from sqlmesh.core.execution_tracker import QueryExecutionStats from sqlmesh.core.plan.definition import EvaluatablePlan from sqlmesh.core.snapshot import Snapshot, SnapshotInfoLike, SnapshotTableInfo, SnapshotId +from sqlmesh.core.snapshot.execution_tracker import QueryExecutionStats from sqlmesh.core.test import ModelTest from sqlmesh.core.test.result import ModelTextTestResult from sqlmesh.utils.date import now_timestamp From ff7f0960cf539260bd86204229c2d4f147980f74 Mon Sep 17 00:00:00 2001 From: Trey Spiller Date: Mon, 18 Aug 2025 17:50:44 -0500 Subject: [PATCH 11/31] Fix circular import --- sqlmesh/core/snapshot/definition.py | 7 +++++-- sqlmesh/core/snapshot/evaluator.py | 2 +- tests/core/engine_adapter/integration/test_integration.py | 6 ++++-- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/sqlmesh/core/snapshot/definition.py b/sqlmesh/core/snapshot/definition.py index 1a286edcfc..5a9ad60166 100644 --- a/sqlmesh/core/snapshot/definition.py +++ b/sqlmesh/core/snapshot/definition.py @@ -13,10 +13,13 @@ from sqlglot import exp from sqlglot.optimizer.normalize_identifiers import normalize_identifiers -from sqlmesh.core.config.common import TableNamingConvention, VirtualEnvironmentMode +from sqlmesh.core.config.common import ( + TableNamingConvention, + VirtualEnvironmentMode, + EnvironmentSuffixTarget, +) from sqlmesh.core import constants as c from sqlmesh.core.audit import StandaloneAudit -from sqlmesh.core.environment import EnvironmentSuffixTarget from sqlmesh.core.macros import call_macro from sqlmesh.core.model import Model, ModelKindMixin, ModelKindName, ViewKind, CustomKind from sqlmesh.core.model.definition import _Model diff --git a/sqlmesh/core/snapshot/evaluator.py b/sqlmesh/core/snapshot/evaluator.py index 9a615d6ad7..9729b7a66b 100644 --- a/sqlmesh/core/snapshot/evaluator.py +++ b/sqlmesh/core/snapshot/evaluator.py @@ -37,7 +37,6 @@ from sqlmesh.core import dialect as d from sqlmesh.core.audit import Audit, StandaloneAudit from sqlmesh.core.dialect import schema_ -from sqlmesh.core.engine_adapter import EngineAdapter from sqlmesh.core.engine_adapter.shared import InsertOverwriteStrategy, DataObjectType from sqlmesh.core.macros import RuntimeStage from sqlmesh.core.model import ( @@ -90,6 +89,7 @@ if t.TYPE_CHECKING: from sqlmesh.core.engine_adapter._typing import DF, QueryOrDF + from sqlmesh.core.engine_adapter.base import EngineAdapter from sqlmesh.core.environment import EnvironmentNamingInfo logger = logging.getLogger(__name__) diff --git a/tests/core/engine_adapter/integration/test_integration.py 
b/tests/core/engine_adapter/integration/test_integration.py index a9674208b0..7999252145 100644 --- a/tests/core/engine_adapter/integration/test_integration.py +++ b/tests/core/engine_adapter/integration/test_integration.py @@ -2385,7 +2385,7 @@ def _mutate_config(gateway: str, config: Config): # capture row counts for each evaluated snapshot actual_execution_stats = {} - def capture_row_counts( + def capture_execution_stats( snapshot, interval, batch_idx, @@ -2401,7 +2401,9 @@ def capture_row_counts( ) # apply prod plan - with patch.object(context.console, "update_snapshot_evaluation_progress", capture_row_counts): + with patch.object( + context.console, "update_snapshot_evaluation_progress", capture_execution_stats + ): context.plan(auto_apply=True, no_prompts=True) prod_schema_results = ctx.get_metadata_results(object_names["view_schema"][0]) From fe8adba74ae2ba447448a23fe9b56a4f2e0b001c Mon Sep 17 00:00:00 2001 From: Trey Spiller Date: Mon, 18 Aug 2025 19:04:08 -0500 Subject: [PATCH 12/31] Handle snowflake lack of CTAS tracking --- sqlmesh/core/engine_adapter/base.py | 7 +++- sqlmesh/core/engine_adapter/snowflake.py | 33 +++++++++++++++++++ sqlmesh/core/snapshot/execution_tracker.py | 14 +++++--- .../integration/test_integration.py | 7 +++- 4 files changed, 55 insertions(+), 6 deletions(-) diff --git a/sqlmesh/core/engine_adapter/base.py b/sqlmesh/core/engine_adapter/base.py index 4a8ec4d34b..5c25f3fcc8 100644 --- a/sqlmesh/core/engine_adapter/base.py +++ b/sqlmesh/core/engine_adapter/base.py @@ -2445,6 +2445,11 @@ def _log_sql( logger.log(self._execute_log_level, "Executing SQL: %s", sql_to_log) + def _record_execution_stats( + self, sql: str, rowcount: t.Optional[int] = None, bytes_processed: t.Optional[int] = None + ) -> None: + QueryExecutionTracker.record_execution(sql, rowcount, bytes_processed) + def _execute(self, sql: str, track_execution_stats: bool = False, **kwargs: t.Any) -> None: self.cursor.execute(sql, **kwargs) @@ -2461,7 +2466,7 @@ def _execute(self, sql: str, track_execution_stats: bool = False, **kwargs: t.An except (TypeError, ValueError): pass - QueryExecutionTracker.record_execution(sql, rowcount, None) + self._record_execution_stats(sql, rowcount) @contextlib.contextmanager def temp_table( diff --git a/sqlmesh/core/engine_adapter/snowflake.py b/sqlmesh/core/engine_adapter/snowflake.py index 8148714dcd..79640bd154 100644 --- a/sqlmesh/core/engine_adapter/snowflake.py +++ b/sqlmesh/core/engine_adapter/snowflake.py @@ -2,6 +2,7 @@ import contextlib import logging +import re import typing as t from sqlglot import exp @@ -23,6 +24,7 @@ SourceQuery, set_catalog, ) +from sqlmesh.core.snapshot.execution_tracker import QueryExecutionTracker from sqlmesh.utils import optional_import, get_source_columns_to_types from sqlmesh.utils.errors import SQLMeshError from sqlmesh.utils.pandas import columns_to_types_from_dtypes @@ -72,6 +74,7 @@ class SnowflakeEngineAdapter(GetCurrentCatalogFromFunctionMixin, ClusteredByMixi } MANAGED_TABLE_KIND = "DYNAMIC TABLE" SNOWPARK = "snowpark" + SUPPORTS_QUERY_EXECUTION_TRACKING = True @contextlib.contextmanager def session(self, properties: SessionProperties) -> t.Iterator[None]: @@ -664,3 +667,33 @@ def close(self) -> t.Any: self._connection_pool.set_attribute(self.SNOWPARK, None) return super().close() + + def _record_execution_stats( + self, sql: str, rowcount: t.Optional[int] = None, bytes_processed: t.Optional[int] = None + ) -> None: + """Snowflake does not report row counts for CTAS like other DML operations. 
+ + They neither report the sentinel value -1 nor do they report 0 rows. Instead, they return a single data row + containing the string "Table successfully created." and a row count of 1. + + We do not want to record the row count of 1 for CTAS operations, so we check for that data pattern and return + early if it is detected. + + Regex explanation - Snowflake identifiers may be: + - An unquoted contiguous set of [a-zA-Z0-9_$] characters + - A double-quoted string that may contain spaces and nested double-quotes represented by `""` + - Example: " my ""table"" name " + - Pattern: "(?:[^"]|"")+" + - ?: is a non-capturing group + - [^"] matches any single character except a double-quote + - "" matches two sequential double-quotes + """ + if rowcount == 1: + results = self.cursor.fetchall() + if results and len(results) == 1: + is_ctas = re.match( + r'Table ([a-zA-Z0-9_$]+|"(?:[^"]|"")+") successfully created\.', results[0][0] + ) + if is_ctas: + return + QueryExecutionTracker.record_execution(sql, rowcount, bytes_processed) diff --git a/sqlmesh/core/snapshot/execution_tracker.py b/sqlmesh/core/snapshot/execution_tracker.py index 9d7d03887a..8e7ec245ef 100644 --- a/sqlmesh/core/snapshot/execution_tracker.py +++ b/sqlmesh/core/snapshot/execution_tracker.py @@ -10,8 +10,8 @@ @dataclass class QueryExecutionStats: snapshot_batch_id: str - total_rows_processed: int = 0 - total_bytes_processed: int = 0 + total_rows_processed: t.Optional[int] = None + total_bytes_processed: t.Optional[int] = None query_count: int = 0 queries_executed: t.List[t.Tuple[str, t.Optional[int], t.Optional[int], float]] = field( default_factory=list @@ -42,12 +42,18 @@ def add_execution( self, sql: str, row_count: t.Optional[int], bytes_processed: t.Optional[int] ) -> None: if row_count is not None and row_count >= 0: - self.stats.total_rows_processed += row_count + if self.stats.total_rows_processed is None: + self.stats.total_rows_processed = row_count + else: + self.stats.total_rows_processed += row_count # conditional on row_count because we should only count bytes corresponding to # DML actions whose rows were captured if bytes_processed is not None and bytes_processed >= 0: - self.stats.total_bytes_processed += bytes_processed + if self.stats.total_bytes_processed is None: + self.stats.total_bytes_processed = bytes_processed + else: + self.stats.total_bytes_processed += bytes_processed self.stats.query_count += 1 # TODO: remove this diff --git a/tests/core/engine_adapter/integration/test_integration.py b/tests/core/engine_adapter/integration/test_integration.py index 7999252145..464ccfd996 100644 --- a/tests/core/engine_adapter/integration/test_integration.py +++ b/tests/core/engine_adapter/integration/test_integration.py @@ -2418,7 +2418,12 @@ def capture_execution_stats( if ctx.engine_adapter.SUPPORTS_QUERY_EXECUTION_TRACKING: assert actual_execution_stats["incremental_model"].total_rows_processed == 7 - assert actual_execution_stats["full_model"].total_rows_processed == 3 + # snowflake doesn't track rows for CTAS + assert actual_execution_stats["full_model"].total_rows_processed == ( + None if ctx.mark.startswith("snowflake") else 3 + ) + # seed rows aren't tracked + assert actual_execution_stats["seed_model"].total_rows_processed is None if ctx.mark.startswith("bigquery"): assert actual_execution_stats["incremental_model"].total_bytes_processed From bb4b0164eb255a8cafcdb5445c9a3eb464a27596 Mon Sep 17 00:00:00 2001 From: Trey Spiller Date: Tue, 19 Aug 2025 12:33:27 -0500 Subject: [PATCH 13/31] Fix tests and 
snowflake regex --- sqlmesh/core/engine_adapter/snowflake.py | 33 ++++++++++------- sqlmesh/core/snapshot/execution_tracker.py | 4 +-- tests/core/test_snapshot_evaluator.py | 4 ++- tests/core/test_table_diff.py | 12 +++---- tests/core/test_test.py | 42 +++++++++++++--------- 5 files changed, 57 insertions(+), 38 deletions(-) diff --git a/sqlmesh/core/engine_adapter/snowflake.py b/sqlmesh/core/engine_adapter/snowflake.py index 79640bd154..831368d841 100644 --- a/sqlmesh/core/engine_adapter/snowflake.py +++ b/sqlmesh/core/engine_adapter/snowflake.py @@ -676,23 +676,32 @@ def _record_execution_stats( They neither report the sentinel value -1 nor do they report 0 rows. Instead, they return a single data row containing the string "Table successfully created." and a row count of 1. - We do not want to record the row count of 1 for CTAS operations, so we check for that data pattern and return - early if it is detected. - - Regex explanation - Snowflake identifiers may be: - - An unquoted contiguous set of [a-zA-Z0-9_$] characters - - A double-quoted string that may contain spaces and nested double-quotes represented by `""` - - Example: " my ""table"" name " - - Pattern: "(?:[^"]|"")+" - - ?: is a non-capturing group - - [^"] matches any single character except a double-quote - - "" matches two sequential double-quotes + We do not want to record the incorrect row count of 1, so we check whether: + - There is exactly one row to fetch (in general, DML operations should return no rows to fetch from the cursor) + - That row contains the table successfully created string + + If so, we return early and do not record the row count. """ if rowcount == 1: results = self.cursor.fetchall() if results and len(results) == 1: + try: + results_str = str(results[0][0]) + except (ValueError, TypeError): + return + + # Snowflake identifiers may be: + # - An unquoted contiguous set of [a-zA-Z0-9_$] characters + # - A double-quoted string that may contain spaces and nested double-quotes represented by `""`. 
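As a standalone illustration of the identifier pattern described in the comment above, a minimal sketch (the table names are made up for the example; this is not the adapter code itself):

    import re

    # "Table <identifier> successfully created." where <identifier> is either an
    # unquoted [a-zA-Z0-9_$]+ name or a double-quoted name that may contain spaces
    # and embedded quotes written as "".
    ctas_message = re.compile(r'Table ([a-zA-Z0-9_$]+|"(?:[^"]|"")+") successfully created\.')

    assert ctas_message.match('Table MY_TABLE successfully created.')
    assert ctas_message.match('Table " my ""table"" name " successfully created.')
    assert not ctas_message.match('Rows inserted: 7')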
Example: " my ""table"" name " + # - Regex: + # - [a-zA-Z0-9_$]+ matches one or more character in the set + # - "(?:[^"]|"")+" matches a double-quoted string that may contain spaces and nested double-quotes + # - ?: non-capturing group + # - [^"] matches any single character except a double-quote + # - | or + # - "" matches two sequential double-quotes is_ctas = re.match( - r'Table ([a-zA-Z0-9_$]+|"(?:[^"]|"")+") successfully created\.', results[0][0] + r'Table ([a-zA-Z0-9_$]+|"(?:[^"]|"")+") successfully created\.', results_str ) if is_ctas: return diff --git a/sqlmesh/core/snapshot/execution_tracker.py b/sqlmesh/core/snapshot/execution_tracker.py index 8e7ec245ef..b80b746dcc 100644 --- a/sqlmesh/core/snapshot/execution_tracker.py +++ b/sqlmesh/core/snapshot/execution_tracker.py @@ -41,7 +41,7 @@ def __post_init__(self) -> None: def add_execution( self, sql: str, row_count: t.Optional[int], bytes_processed: t.Optional[int] ) -> None: - if row_count is not None and row_count >= 0: + if row_count is not None: if self.stats.total_rows_processed is None: self.stats.total_rows_processed = row_count else: @@ -49,7 +49,7 @@ def add_execution( # conditional on row_count because we should only count bytes corresponding to # DML actions whose rows were captured - if bytes_processed is not None and bytes_processed >= 0: + if bytes_processed is not None: if self.stats.total_bytes_processed is None: self.stats.total_bytes_processed = bytes_processed else: diff --git a/tests/core/test_snapshot_evaluator.py b/tests/core/test_snapshot_evaluator.py index 53f9bd425a..c19d118c8c 100644 --- a/tests/core/test_snapshot_evaluator.py +++ b/tests/core/test_snapshot_evaluator.py @@ -675,8 +675,10 @@ def test_evaluate_materialized_view_with_partitioned_by_cluster_by( execute_mock.assert_has_calls( [ + call("CREATE SCHEMA IF NOT EXISTS `sqlmesh__test_schema`", False), call( - f"CREATE MATERIALIZED VIEW `sqlmesh__test_schema`.`test_schema__test_model__{snapshot.version}` PARTITION BY `a` CLUSTER BY `b` AS SELECT `a` AS `a`, `b` AS `b` FROM `tbl` AS `tbl`" + f"CREATE MATERIALIZED VIEW `sqlmesh__test_schema`.`test_schema__test_model__{snapshot.version}` PARTITION BY `a` CLUSTER BY `b` AS SELECT `a` AS `a`, `b` AS `b` FROM `tbl` AS `tbl`", + False, ), ] ) diff --git a/tests/core/test_table_diff.py b/tests/core/test_table_diff.py index b2848676b4..73fd37a2f7 100644 --- a/tests/core/test_table_diff.py +++ b/tests/core/test_table_diff.py @@ -360,11 +360,11 @@ def test_generated_sql(sushi_context_fixed_date: Context, mocker: MockerFixture) temp_schema="sqlmesh_temp_test", ) - spy_execute.assert_any_call(query_sql) - spy_execute.assert_any_call(summary_query_sql) - spy_execute.assert_any_call(compare_sql) - spy_execute.assert_any_call(sample_query_sql) - spy_execute.assert_any_call(drop_sql) + spy_execute.assert_any_call(query_sql, False) + spy_execute.assert_any_call(summary_query_sql, False) + spy_execute.assert_any_call(compare_sql, False) + spy_execute.assert_any_call(sample_query_sql, False) + spy_execute.assert_any_call(drop_sql, False) spy_execute.reset_mock() @@ -378,7 +378,7 @@ def test_generated_sql(sushi_context_fixed_date: Context, mocker: MockerFixture) ) query_sql_where = 'CREATE TABLE IF NOT EXISTS "memory"."sqlmesh_temp"."__temp_diff_abcdefgh" AS WITH "__source" AS (SELECT "s"."key", "s"."value", "s"."key" AS "__sqlmesh_join_key" FROM "table_diff_source" AS "s" WHERE "s"."key" = 2), "__target" AS (SELECT "t"."key", "t"."value", "t"."key" AS "__sqlmesh_join_key" FROM "table_diff_target" AS "t" WHERE "t"."key" 
= 2), "__stats" AS (SELECT "s"."key" AS "s__key", "s"."value" AS "s__value", "s"."__sqlmesh_join_key" AS "s____sqlmesh_join_key", "t"."key" AS "t__key", "t"."value" AS "t__value", "t"."__sqlmesh_join_key" AS "t____sqlmesh_join_key", CASE WHEN NOT "s"."__sqlmesh_join_key" IS NULL THEN 1 ELSE 0 END AS "s_exists", CASE WHEN NOT "t"."__sqlmesh_join_key" IS NULL THEN 1 ELSE 0 END AS "t_exists", CASE WHEN "s"."__sqlmesh_join_key" = "t"."__sqlmesh_join_key" THEN 1 ELSE 0 END AS "row_joined", CASE WHEN "s"."key" IS NULL AND "t"."key" IS NULL THEN 1 ELSE 0 END AS "null_grain", CASE WHEN "s"."key" = "t"."key" THEN 1 WHEN ("s"."key" IS NULL) AND ("t"."key" IS NULL) THEN 1 WHEN ("s"."key" IS NULL) OR ("t"."key" IS NULL) THEN 0 ELSE 0 END AS "key_matches", CASE WHEN ROUND("s"."value", 3) = ROUND("t"."value", 3) THEN 1 WHEN ("s"."value" IS NULL) AND ("t"."value" IS NULL) THEN 1 WHEN ("s"."value" IS NULL) OR ("t"."value" IS NULL) THEN 0 ELSE 0 END AS "value_matches" FROM "__source" AS "s" FULL JOIN "__target" AS "t" ON "s"."__sqlmesh_join_key" = "t"."__sqlmesh_join_key") SELECT *, CASE WHEN "key_matches" = 1 AND "value_matches" = 1 THEN 1 ELSE 0 END AS "row_full_match" FROM "__stats"' - spy_execute.assert_any_call(query_sql_where) + spy_execute.assert_any_call(query_sql_where, False) def test_tables_and_grain_inferred_from_model(sushi_context_fixed_date: Context): diff --git a/tests/core/test_test.py b/tests/core/test_test.py index 9c3c3aba4b..1b5425068f 100644 --- a/tests/core/test_test.py +++ b/tests/core/test_test.py @@ -874,7 +874,8 @@ def test_partially_inferred_schemas(sushi_context: Context, mocker: MockerFixtur 'CAST("s" AS STRUCT("d" DATE)) AS "s", ' 'CAST("a" AS INT) AS "a", ' 'CAST("b" AS TEXT) AS "b" ' - """FROM (VALUES ({'d': CAST('2020-01-01' AS DATE)}, 1, 'bla')) AS "t"("s", "a", "b")""" + """FROM (VALUES ({'d': CAST('2020-01-01' AS DATE)}, 1, 'bla')) AS "t"("s", "a", "b")""", + False, ) @@ -1329,14 +1330,15 @@ def test_freeze_time(mocker: MockerFixture) -> None: spy_execute.assert_has_calls( [ - call('CREATE SCHEMA IF NOT EXISTS "memory"."sqlmesh_test_jzngz56a"'), + call('CREATE SCHEMA IF NOT EXISTS "memory"."sqlmesh_test_jzngz56a"', False), call( "SELECT " """CAST('2023-01-01 12:05:03+00:00' AS DATE) AS "cur_date", """ """CAST('2023-01-01 12:05:03+00:00' AS TIME) AS "cur_time", """ - '''CAST('2023-01-01 12:05:03+00:00' AS TIMESTAMP) AS "cur_timestamp"''' + '''CAST('2023-01-01 12:05:03+00:00' AS TIMESTAMP) AS "cur_timestamp"''', + False, ), - call('DROP SCHEMA IF EXISTS "memory"."sqlmesh_test_jzngz56a" CASCADE'), + call('DROP SCHEMA IF EXISTS "memory"."sqlmesh_test_jzngz56a" CASCADE', False), ] ) @@ -1361,7 +1363,12 @@ def test_freeze_time(mocker: MockerFixture) -> None: _check_successful_or_raise(test.run()) spy_execute.assert_has_calls( - [call('''SELECT CAST('2023-01-01 12:05:03+00:00' AS TIMESTAMPTZ) AS "cur_timestamp"''')] + [ + call( + '''SELECT CAST('2023-01-01 12:05:03+00:00' AS TIMESTAMPTZ) AS "cur_timestamp"''', + False, + ) + ] ) @model("py_model", columns={"ts1": "timestamptz", "ts2": "timestamptz"}) @@ -1496,7 +1503,7 @@ def test_gateway(copy_to_temp_path: t.Callable, mocker: MockerFixture) -> None: 'AS "t"("id", "customer_id", "waiter_id", "start_ts", "end_ts", "event_date")' ) test_adapter = t.cast(ModelTest, result.successes[0]).engine_adapter - assert call(test_adapter, expected_view_sql) in spy_execute.mock_calls + assert call(test_adapter, expected_view_sql, False) in spy_execute.mock_calls _check_successful_or_raise(context.test()) @@ -1621,7 +1628,8 @@ def 
test_generate_input_data_using_sql(mocker: MockerFixture, tmp_path: Path) -> spy_execute.assert_any_call( 'CREATE OR REPLACE VIEW "memory"."sqlmesh_test_jzngz56a"."foo" AS ' - '''SELECT {'x': 1, 'n': {'y': 2}} AS "struct_value"''' + '''SELECT {'x': 1, 'n': {'y': 2}} AS "struct_value"''', + False, ) with pytest.raises( @@ -1817,9 +1825,9 @@ def test_custom_testing_schema(mocker: MockerFixture) -> None: spy_execute.assert_has_calls( [ - call('CREATE SCHEMA IF NOT EXISTS "memory"."my_schema"'), - call('SELECT 1 AS "a"'), - call('DROP SCHEMA IF EXISTS "memory"."my_schema" CASCADE'), + call('CREATE SCHEMA IF NOT EXISTS "memory"."my_schema"', False), + call('SELECT 1 AS "a"', False), + call('DROP SCHEMA IF EXISTS "memory"."my_schema" CASCADE', False), ] ) @@ -1845,9 +1853,9 @@ def test_pretty_query(mocker: MockerFixture) -> None: _check_successful_or_raise(test.run()) spy_execute.assert_has_calls( [ - call('CREATE SCHEMA IF NOT EXISTS "memory"."my_schema"'), - call('SELECT\n 1 AS "a"'), - call('DROP SCHEMA IF EXISTS "memory"."my_schema" CASCADE'), + call('CREATE SCHEMA IF NOT EXISTS "memory"."my_schema"', False), + call('SELECT\n 1 AS "a"', False), + call('DROP SCHEMA IF EXISTS "memory"."my_schema" CASCADE', False), ] ) @@ -2950,7 +2958,7 @@ def test_parameterized_name_sql_model() -> None: outputs: query: - id: 1 - name: foo + name: foo """, variables=variables, ), @@ -2999,7 +3007,7 @@ def execute( outputs: query: - id: 1 - name: foo + name: foo """, variables=variables, ), @@ -3049,7 +3057,7 @@ def test_parameterized_name_self_referential_model(): v: int outputs: query: - - v: 1 + - v: 1 """, variables=variables, ), @@ -3171,7 +3179,7 @@ def execute( - id: 5 outputs: query: - - id: 8 + - id: 8 """, variables=variables, ), From e4e30f0022188e375ced7565bb29334180cfa45b Mon Sep 17 00:00:00 2001 From: Trey Spiller Date: Tue, 19 Aug 2025 12:40:49 -0500 Subject: [PATCH 14/31] Change tracking arg name to track_rows_processed --- sqlmesh/core/console.py | 4 +-- sqlmesh/core/engine_adapter/base.py | 38 +++++++++++------------ sqlmesh/core/engine_adapter/bigquery.py | 4 +-- sqlmesh/core/engine_adapter/clickhouse.py | 6 ++-- sqlmesh/core/engine_adapter/duckdb.py | 4 +-- sqlmesh/core/engine_adapter/redshift.py | 4 +-- sqlmesh/core/engine_adapter/snowflake.py | 4 +-- sqlmesh/core/engine_adapter/spark.py | 4 +-- sqlmesh/core/engine_adapter/trino.py | 4 +-- sqlmesh/core/state_sync/db/environment.py | 4 +-- sqlmesh/core/state_sync/db/interval.py | 4 +-- sqlmesh/core/state_sync/db/migrator.py | 2 +- sqlmesh/core/state_sync/db/snapshot.py | 4 +-- sqlmesh/core/state_sync/db/version.py | 2 +- 14 files changed, 44 insertions(+), 44 deletions(-) diff --git a/sqlmesh/core/console.py b/sqlmesh/core/console.py index ea7445be82..4d3af6c2dc 100644 --- a/sqlmesh/core/console.py +++ b/sqlmesh/core/console.py @@ -4306,7 +4306,7 @@ def _calculate_annotation_str_len( # Convert number of bytes to a human-readable string # https://github.com/dbt-labs/dbt-adapters/blob/34fd178539dcb6f82e18e738adc03de7784c032f/dbt-bigquery/src/dbt/adapters/bigquery/connections.py#L165 def _format_bytes(num_bytes: t.Optional[int]) -> str: - if num_bytes and num_bytes > 0: + if num_bytes and num_bytes >= 0: if num_bytes < 1024: return f"{num_bytes} bytes" @@ -4324,7 +4324,7 @@ def _format_bytes(num_bytes: t.Optional[int]) -> str: # Abbreviate integer count. 
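For context on the byte formatting referenced above, a rough, self-contained sketch of 1024-based formatting (the unit labels and rounding here are assumptions and may not match the helper's exact output):

    def format_bytes(num_bytes: int) -> str:
        # Walk up 1024-based units until the value fits; labels are illustrative only.
        value = float(num_bytes)
        for unit in ("bytes", "KB", "MB", "GB"):
            if value < 1024:
                return f"{round(value, 1)} {unit}"
            value /= 1024
        return f"{round(value, 1)} TB"

    print(format_bytes(0))      # 0.0 bytes
    print(format_bytes(1536))   # 1.5 KB
    print(format_bytes(10**9))  # 953.7 MB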
Example: 1,000,000,000 -> 1b # https://github.com/dbt-labs/dbt-adapters/blob/34fd178539dcb6f82e18e738adc03de7784c032f/dbt-bigquery/src/dbt/adapters/bigquery/connections.py#L178 def _abbreviate_integer_count(count: t.Optional[int]) -> str: - if count and count > 0: + if count and count >= 0: if count < 1000: return str(count) diff --git a/sqlmesh/core/engine_adapter/base.py b/sqlmesh/core/engine_adapter/base.py index 5c25f3fcc8..fe723d1109 100644 --- a/sqlmesh/core/engine_adapter/base.py +++ b/sqlmesh/core/engine_adapter/base.py @@ -856,7 +856,7 @@ def _create_table_from_source_queries( table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, table_kind: t.Optional[str] = None, - track_execution_stats: bool = True, + track_rows_processed: bool = True, **kwargs: t.Any, ) -> None: table = exp.to_table(table_name) @@ -902,7 +902,7 @@ def _create_table_from_source_queries( replace=replace, table_description=table_description, table_kind=table_kind, - track_execution_stats=track_execution_stats, + track_rows_processed=track_rows_processed, **kwargs, ) else: @@ -910,7 +910,7 @@ def _create_table_from_source_queries( table_name, query, target_columns_to_types or self.columns(table), - track_execution_stats=track_execution_stats, + track_rows_processed=track_rows_processed, ) # Register comments with commands if the engine supports comments and we weren't able to @@ -934,7 +934,7 @@ def _create_table( table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, table_kind: t.Optional[str] = None, - track_execution_stats: bool = True, + track_rows_processed: bool = True, **kwargs: t.Any, ) -> None: self.execute( @@ -952,7 +952,7 @@ def _create_table( table_kind=table_kind, **kwargs, ), - track_execution_stats=track_execution_stats, + track_rows_processed=track_rows_processed, ) def _build_create_table_exp( @@ -1440,7 +1440,7 @@ def insert_append( table_name: TableName, query_or_df: QueryOrDF, target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, - track_execution_stats: bool = True, + track_rows_processed: bool = True, source_columns: t.Optional[t.List[str]] = None, ) -> None: source_queries, target_columns_to_types = self._get_source_queries_and_columns_to_types( @@ -1450,7 +1450,7 @@ def insert_append( source_columns=source_columns, ) self._insert_append_source_queries( - table_name, source_queries, target_columns_to_types, track_execution_stats + table_name, source_queries, target_columns_to_types, track_rows_processed ) def _insert_append_source_queries( @@ -1458,7 +1458,7 @@ def _insert_append_source_queries( table_name: TableName, source_queries: t.List[SourceQuery], target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, - track_execution_stats: bool = True, + track_rows_processed: bool = True, ) -> None: with self.transaction(condition=len(source_queries) > 0): target_columns_to_types = target_columns_to_types or self.columns(table_name) @@ -1468,7 +1468,7 @@ def _insert_append_source_queries( table_name, query, target_columns_to_types, - track_execution_stats=track_execution_stats, + track_rows_processed=track_rows_processed, ) def _insert_append_query( @@ -1477,13 +1477,13 @@ def _insert_append_query( query: Query, target_columns_to_types: t.Dict[str, exp.DataType], order_projections: bool = True, - track_execution_stats: bool = True, + track_rows_processed: bool = True, ) -> None: if order_projections: query = self._order_projections_and_filter(query, 
target_columns_to_types) self.execute( exp.insert(query, table_name, columns=list(target_columns_to_types)), - track_execution_stats=track_execution_stats, + track_rows_processed=track_rows_processed, ) def insert_overwrite_by_partition( @@ -1626,7 +1626,7 @@ def _insert_overwrite_by_condition( ) if insert_overwrite_strategy.is_replace_where: insert_exp.set("where", where or exp.true()) - self.execute(insert_exp, track_execution_stats=True) + self.execute(insert_exp, track_rows_processed=True) def update_table( self, @@ -1648,7 +1648,7 @@ def _merge( exp.Subquery(this=query), alias=MERGE_SOURCE_ALIAS, copy=False, table=True ) self.execute( - exp.Merge(this=this, using=using, on=on, whens=whens), track_execution_stats=True + exp.Merge(this=this, using=using, on=on, whens=whens), track_rows_processed=True ) def scd_type_2_by_time( @@ -2398,7 +2398,7 @@ def execute( expressions: t.Union[str, exp.Expression, t.Sequence[exp.Expression]], ignore_unsupported_errors: bool = False, quote_identifiers: bool = True, - track_execution_stats: bool = False, + track_rows_processed: bool = False, **kwargs: t.Any, ) -> None: """Execute a sql query.""" @@ -2420,7 +2420,7 @@ def execute( expression=e if isinstance(e, exp.Expression) else None, quote_identifiers=quote_identifiers, ) - self._execute(sql, track_execution_stats, **kwargs) + self._execute(sql, track_rows_processed, **kwargs) def _attach_correlation_id(self, sql: str) -> str: if self.ATTACH_CORRELATION_ID and self.correlation_id: @@ -2450,12 +2450,12 @@ def _record_execution_stats( ) -> None: QueryExecutionTracker.record_execution(sql, rowcount, bytes_processed) - def _execute(self, sql: str, track_execution_stats: bool = False, **kwargs: t.Any) -> None: + def _execute(self, sql: str, track_rows_processed: bool = False, **kwargs: t.Any) -> None: self.cursor.execute(sql, **kwargs) if ( self.SUPPORTS_QUERY_EXECUTION_TRACKING - and track_execution_stats + and track_rows_processed and QueryExecutionTracker.is_tracking() ): rowcount_raw = getattr(self.cursor, "rowcount", None) @@ -2512,7 +2512,7 @@ def temp_table( exists=True, table_description=None, column_descriptions=None, - track_execution_stats=False, + track_rows_processed=False, **kwargs, ) @@ -2764,7 +2764,7 @@ def _replace_by_key( insert_statement.set("where", delete_filter) insert_statement.set("this", exp.to_table(target_table)) - self.execute(insert_statement, track_execution_stats=True) + self.execute(insert_statement, track_rows_processed=True) finally: self.drop_table(temp_table) diff --git a/sqlmesh/core/engine_adapter/bigquery.py b/sqlmesh/core/engine_adapter/bigquery.py index 37a0ab2578..679cff05ec 100644 --- a/sqlmesh/core/engine_adapter/bigquery.py +++ b/sqlmesh/core/engine_adapter/bigquery.py @@ -1051,7 +1051,7 @@ def _db_call(self, func: t.Callable[..., t.Any], *args: t.Any, **kwargs: t.Any) def _execute( self, sql: str, - track_execution_stats: bool = False, + track_rows_processed: bool = False, **kwargs: t.Any, ) -> None: """Execute a sql query.""" @@ -1097,7 +1097,7 @@ def _execute( self.cursor._set_rowcount(query_results) self.cursor._set_description(query_results.schema) - if track_execution_stats and QueryExecutionTracker.is_tracking(): + if track_rows_processed and QueryExecutionTracker.is_tracking(): num_rows = None if query_job.statement_type == "CREATE_TABLE_AS_SELECT": # since table was just created, number rows in table == number rows processed diff --git a/sqlmesh/core/engine_adapter/clickhouse.py b/sqlmesh/core/engine_adapter/clickhouse.py index 34c6aad431..ccffe64118 
100644 --- a/sqlmesh/core/engine_adapter/clickhouse.py +++ b/sqlmesh/core/engine_adapter/clickhouse.py @@ -294,7 +294,7 @@ def _insert_overwrite_by_condition( ) try: - self.execute(existing_records_insert_exp, track_execution_stats=True) + self.execute(existing_records_insert_exp, track_rows_processed=True) finally: if table_partition_exp: self.drop_table(partitions_temp_table_name) @@ -489,7 +489,7 @@ def _create_table( table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, table_kind: t.Optional[str] = None, - track_execution_stats: bool = True, + track_rows_processed: bool = True, **kwargs: t.Any, ) -> None: """Creates a table in the database. @@ -526,7 +526,7 @@ def _create_table( column_descriptions, table_kind, empty_ctas=(self.engine_run_mode.is_cloud and expression is not None), - track_execution_stats=track_execution_stats, + track_rows_processed=track_rows_processed, **kwargs, ) diff --git a/sqlmesh/core/engine_adapter/duckdb.py b/sqlmesh/core/engine_adapter/duckdb.py index b501d15945..3b057219e0 100644 --- a/sqlmesh/core/engine_adapter/duckdb.py +++ b/sqlmesh/core/engine_adapter/duckdb.py @@ -170,7 +170,7 @@ def _create_table( table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, table_kind: t.Optional[str] = None, - track_execution_stats: bool = True, + track_rows_processed: bool = True, **kwargs: t.Any, ) -> None: catalog = self.get_current_catalog() @@ -194,7 +194,7 @@ def _create_table( table_description, column_descriptions, table_kind, - track_execution_stats=track_execution_stats, + track_rows_processed=track_rows_processed, **kwargs, ) diff --git a/sqlmesh/core/engine_adapter/redshift.py b/sqlmesh/core/engine_adapter/redshift.py index 1b7faef05e..7d14207b52 100644 --- a/sqlmesh/core/engine_adapter/redshift.py +++ b/sqlmesh/core/engine_adapter/redshift.py @@ -173,7 +173,7 @@ def _create_table_from_source_queries( table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, table_kind: t.Optional[str] = None, - track_execution_stats: bool = True, + track_rows_processed: bool = True, **kwargs: t.Any, ) -> None: """ @@ -428,7 +428,7 @@ def resolve_target_table(expression: exp.Expression) -> exp.Expression: on=on.transform(resolve_target_table), whens=whens.transform(resolve_target_table), ), - track_execution_stats=True, + track_rows_processed=True, ) def _normalize_decimal_value(self, expr: exp.Expression, precision: int) -> exp.Expression: diff --git a/sqlmesh/core/engine_adapter/snowflake.py b/sqlmesh/core/engine_adapter/snowflake.py index 831368d841..d90d8c7afa 100644 --- a/sqlmesh/core/engine_adapter/snowflake.py +++ b/sqlmesh/core/engine_adapter/snowflake.py @@ -169,7 +169,7 @@ def _create_table( table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, table_kind: t.Optional[str] = None, - track_execution_stats: bool = True, + track_rows_processed: bool = True, **kwargs: t.Any, ) -> None: table_format = kwargs.get("table_format") @@ -189,7 +189,7 @@ def _create_table( table_description=table_description, column_descriptions=column_descriptions, table_kind=table_kind, - track_execution_stats=track_execution_stats, + track_rows_processed=track_rows_processed, **kwargs, ) diff --git a/sqlmesh/core/engine_adapter/spark.py b/sqlmesh/core/engine_adapter/spark.py index 2d9a4abd99..412e01f5bb 100644 --- a/sqlmesh/core/engine_adapter/spark.py +++ b/sqlmesh/core/engine_adapter/spark.py @@ -433,7 +433,7 @@ 
def _create_table( table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, table_kind: t.Optional[str] = None, - track_execution_stats: bool = True, + track_rows_processed: bool = True, **kwargs: t.Any, ) -> None: table_name = ( @@ -462,7 +462,7 @@ def _create_table( target_columns_to_types=target_columns_to_types, table_description=table_description, column_descriptions=column_descriptions, - track_execution_stats=track_execution_stats, + track_rows_processed=track_rows_processed, **kwargs, ) table_name = ( diff --git a/sqlmesh/core/engine_adapter/trino.py b/sqlmesh/core/engine_adapter/trino.py index 2e15cfb78e..4cef557d94 100644 --- a/sqlmesh/core/engine_adapter/trino.py +++ b/sqlmesh/core/engine_adapter/trino.py @@ -358,7 +358,7 @@ def _create_table( table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, table_kind: t.Optional[str] = None, - track_execution_stats: bool = True, + track_rows_processed: bool = True, **kwargs: t.Any, ) -> None: super()._create_table( @@ -370,7 +370,7 @@ def _create_table( table_description=table_description, column_descriptions=column_descriptions, table_kind=table_kind, - track_execution_stats=track_execution_stats, + track_rows_processed=track_rows_processed, **kwargs, ) diff --git a/sqlmesh/core/state_sync/db/environment.py b/sqlmesh/core/state_sync/db/environment.py index 444985274d..4a28d7d70a 100644 --- a/sqlmesh/core/state_sync/db/environment.py +++ b/sqlmesh/core/state_sync/db/environment.py @@ -78,7 +78,7 @@ def update_environment(self, environment: Environment) -> None: self.environments_table, _environment_to_df(environment), target_columns_to_types=self._environment_columns_to_types, - track_execution_stats=False, + track_rows_processed=False, ) def update_environment_statements( @@ -109,7 +109,7 @@ def update_environment_statements( self.environment_statements_table, _environment_statements_to_df(environment_name, plan_id, environment_statements), target_columns_to_types=self._environment_statements_columns_to_types, - track_execution_stats=False, + track_rows_processed=False, ) def invalidate_environment(self, name: str, protect_prod: bool = True) -> None: diff --git a/sqlmesh/core/state_sync/db/interval.py b/sqlmesh/core/state_sync/db/interval.py index e06100f904..75f475b75b 100644 --- a/sqlmesh/core/state_sync/db/interval.py +++ b/sqlmesh/core/state_sync/db/interval.py @@ -115,7 +115,7 @@ def remove_intervals( self.intervals_table, _intervals_to_df(intervals_to_remove, is_dev=False, is_removed=True), target_columns_to_types=self._interval_columns_to_types, - track_execution_stats=False, + track_rows_processed=False, ) def get_snapshot_intervals( @@ -244,7 +244,7 @@ def _push_snapshot_intervals( self.intervals_table, pd.DataFrame(new_intervals), target_columns_to_types=self._interval_columns_to_types, - track_execution_stats=False, + track_rows_processed=False, ) def _get_snapshot_intervals( diff --git a/sqlmesh/core/state_sync/db/migrator.py b/sqlmesh/core/state_sync/db/migrator.py index ecf00baa3c..7edd7de3c4 100644 --- a/sqlmesh/core/state_sync/db/migrator.py +++ b/sqlmesh/core/state_sync/db/migrator.py @@ -414,7 +414,7 @@ def _backup_state(self) -> None: self.engine_adapter.drop_table(backup_name) self.engine_adapter.create_table_like(backup_name, table) self.engine_adapter.insert_append( - backup_name, exp.select("*").from_(table), track_execution_stats=False + backup_name, exp.select("*").from_(table), track_rows_processed=False ) def 
_restore_table( diff --git a/sqlmesh/core/state_sync/db/snapshot.py b/sqlmesh/core/state_sync/db/snapshot.py index 2a9046986b..8d504993fc 100644 --- a/sqlmesh/core/state_sync/db/snapshot.py +++ b/sqlmesh/core/state_sync/db/snapshot.py @@ -103,7 +103,7 @@ def push_snapshots(self, snapshots: t.Iterable[Snapshot], overwrite: bool = Fals self.snapshots_table, _snapshots_to_df(snapshots_to_store), target_columns_to_types=self._snapshot_columns_to_types, - track_execution_stats=False, + track_rows_processed=False, ) for snapshot in snapshots: @@ -407,7 +407,7 @@ def _push_snapshots(self, snapshots: t.Iterable[Snapshot]) -> None: self.snapshots_table, _snapshots_to_df(snapshots_to_store), target_columns_to_types=self._snapshot_columns_to_types, - track_execution_stats=False, + track_rows_processed=False, ) def _get_snapshots( diff --git a/sqlmesh/core/state_sync/db/version.py b/sqlmesh/core/state_sync/db/version.py index 487347f7d1..c95592bc31 100644 --- a/sqlmesh/core/state_sync/db/version.py +++ b/sqlmesh/core/state_sync/db/version.py @@ -55,7 +55,7 @@ def update_versions( ] ), target_columns_to_types=self._version_columns_to_types, - track_execution_stats=False, + track_rows_processed=False, ) def get_versions(self) -> Versions: From 6b0932a7046678c43a885aa1790c9560aafb88b6 Mon Sep 17 00:00:00 2001 From: Trey Spiller Date: Tue, 19 Aug 2025 16:31:25 -0500 Subject: [PATCH 15/31] Report 0 rows correctly --- sqlmesh/core/console.py | 10 ++++----- .../integration/test_integration.py | 21 ++++++++++++++++++- 2 files changed, 25 insertions(+), 6 deletions(-) diff --git a/sqlmesh/core/console.py b/sqlmesh/core/console.py index 4d3af6c2dc..e6010f7c0b 100644 --- a/sqlmesh/core/console.py +++ b/sqlmesh/core/console.py @@ -4189,15 +4189,15 @@ def _create_evaluation_model_annotation( if execution_stats: rows_processed = execution_stats.total_rows_processed execution_stats_str += ( - f"{_abbreviate_integer_count(rows_processed)} row{'s' if rows_processed > 1 else ''}" - if rows_processed + f"{_abbreviate_integer_count(rows_processed)} row{'s' if rows_processed != 1 else ''}" + if rows_processed is not None and rows_processed >= 0 else "" ) bytes_processed = execution_stats.total_bytes_processed execution_stats_str += ( f"{', ' if execution_stats_str else ''}{_format_bytes(bytes_processed)}" - if bytes_processed + if bytes_processed is not None and bytes_processed >= 0 else "" ) execution_stats_str = f" ({execution_stats_str})" if execution_stats_str else "" @@ -4306,7 +4306,7 @@ def _calculate_annotation_str_len( # Convert number of bytes to a human-readable string # https://github.com/dbt-labs/dbt-adapters/blob/34fd178539dcb6f82e18e738adc03de7784c032f/dbt-bigquery/src/dbt/adapters/bigquery/connections.py#L165 def _format_bytes(num_bytes: t.Optional[int]) -> str: - if num_bytes and num_bytes >= 0: + if num_bytes is not None and num_bytes >= 0: if num_bytes < 1024: return f"{num_bytes} bytes" @@ -4324,7 +4324,7 @@ def _format_bytes(num_bytes: t.Optional[int]) -> str: # Abbreviate integer count. 
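The patch above treats None (nothing captured) differently from a genuine 0-row result. A distilled sketch of that aggregation rule (not the tracker class itself, just the None-vs-zero behavior):

    import typing as t

    def combine(total: t.Optional[int], row_count: t.Optional[int]) -> t.Optional[int]:
        # None means "no stats were captured"; 0 is a real measurement and must be kept.
        if row_count is None or row_count < 0:
            return total
        return row_count if total is None else total + row_count

    assert combine(None, None) is None  # nothing tracked -> no row annotation is rendered
    assert combine(None, 0) == 0        # a query reported 0 rows -> "0 rows" is rendered
    assert combine(0, 7) == 7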
Example: 1,000,000,000 -> 1b # https://github.com/dbt-labs/dbt-adapters/blob/34fd178539dcb6f82e18e738adc03de7784c032f/dbt-bigquery/src/dbt/adapters/bigquery/connections.py#L178 def _abbreviate_integer_count(count: t.Optional[int]) -> str: - if count and count >= 0: + if count is not None and count >= 0: if count < 1000: return str(count) diff --git a/tests/core/engine_adapter/integration/test_integration.py b/tests/core/engine_adapter/integration/test_integration.py index 464ccfd996..42f1c8c584 100644 --- a/tests/core/engine_adapter/integration/test_integration.py +++ b/tests/core/engine_adapter/integration/test_integration.py @@ -6,12 +6,13 @@ import sys import typing as t import shutil -from datetime import datetime, timedelta +from datetime import date, datetime, timedelta from unittest.mock import patch import numpy as np # noqa: TID253 import pandas as pd # noqa: TID253 import pytest import pytz +import time_machine from sqlglot import exp, parse_one from sqlglot.optimizer.normalize_identifiers import normalize_identifiers from sqlglot.optimizer.qualify_columns import quote_identifiers @@ -2429,6 +2430,24 @@ def capture_execution_stats( assert actual_execution_stats["incremental_model"].total_bytes_processed assert actual_execution_stats["full_model"].total_bytes_processed + # run that loads 0 rows in incremental model + with patch.object( + context.console, "update_snapshot_evaluation_progress", capture_execution_stats + ): + with time_machine.travel(date.today() + timedelta(days=1)): + context.run() + + if ctx.engine_adapter.SUPPORTS_QUERY_EXECUTION_TRACKING: + assert actual_execution_stats["incremental_model"].total_rows_processed == 0 + # snowflake doesn't track rows for CTAS + assert actual_execution_stats["full_model"].total_rows_processed == ( + None if ctx.mark.startswith("snowflake") else 3 + ) + + if ctx.mark.startswith("bigquery"): + assert actual_execution_stats["incremental_model"].total_bytes_processed + assert actual_execution_stats["full_model"].total_bytes_processed + # make and validate unmodified dev environment no_change_plan: Plan = context.plan_builder( environment="test_dev", From 55c5ffca01e8b39205aa7bc7e7e5c3a629895b08 Mon Sep 17 00:00:00 2001 From: Trey Spiller Date: Tue, 19 Aug 2025 18:35:35 -0500 Subject: [PATCH 16/31] Add databricks support --- sqlmesh/core/engine_adapter/databricks.py | 53 ++++++++++++++++++- sqlmesh/core/snapshot/execution_tracker.py | 2 +- .../integration/test_integration.py | 13 ++--- 3 files changed, 60 insertions(+), 8 deletions(-) diff --git a/sqlmesh/core/engine_adapter/databricks.py b/sqlmesh/core/engine_adapter/databricks.py index da70163db4..6abd84b700 100644 --- a/sqlmesh/core/engine_adapter/databricks.py +++ b/sqlmesh/core/engine_adapter/databricks.py @@ -4,7 +4,7 @@ import typing as t from functools import partial -from sqlglot import exp +from sqlglot import exp, parse_one from sqlmesh.core.dialect import to_schema from sqlmesh.core.engine_adapter.shared import ( CatalogSupport, @@ -16,6 +16,7 @@ from sqlmesh.core.engine_adapter.spark import SparkEngineAdapter from sqlmesh.core.node import IntervalUnit from sqlmesh.core.schema_diff import NestedSupport +from sqlmesh.core.snapshot.execution_tracker import QueryExecutionTracker from sqlmesh.engines.spark.db_api.spark_session import connection, SparkSessionConnection from sqlmesh.utils.errors import SQLMeshError, MissingDefaultCatalogError @@ -34,6 +35,7 @@ class DatabricksEngineAdapter(SparkEngineAdapter): SUPPORTS_CLONING = True SUPPORTS_MATERIALIZED_VIEWS = True 
SUPPORTS_MATERIALIZED_VIEW_SCHEMA = True + SUPPORTS_QUERY_EXECUTION_TRACKING = True SCHEMA_DIFFER_KWARGS = { "support_positional_add": True, "nested_support": NestedSupport.ALL, @@ -363,3 +365,52 @@ def _build_table_properties_exp( expressions.append(clustered_by_exp) properties = exp.Properties(expressions=expressions) return properties + + def _record_execution_stats( + self, sql: str, rowcount: t.Optional[int] = None, bytes_processed: t.Optional[int] = None + ) -> None: + parsed = parse_one(sql, dialect=self.dialect) + table = parsed.find(exp.Table) + table_name = table.sql(dialect=self.dialect) if table else None + + if table_name: + try: + self.cursor.execute(f"DESCRIBE HISTORY {table_name}") + except: + return + + history = self.cursor.fetchall_arrow() + if history.num_rows: + history_df = history.to_pandas() + write_df = history_df[history_df["operation"] == "WRITE"] + write_df = write_df[write_df["timestamp"] == write_df["timestamp"].max()] + if not write_df.empty: + metrics = write_df["operationMetrics"][0] + if metrics: + rowcount = None + rowcount_str = [ + metric[1] for metric in metrics if metric[0] == "numOutputRows" + ] + if rowcount_str: + try: + rowcount = int(rowcount_str[0]) + except (TypeError, ValueError): + pass + + bytes_processed = None + bytes_str = [ + metric[1] for metric in metrics if metric[0] == "numOutputBytes" + ] + if bytes_str: + try: + bytes_processed = int(bytes_str[0]) + except (TypeError, ValueError): + pass + + if rowcount is not None or bytes_processed is not None: + # if no rows were written, df contains 0 for bytes but no value for rows + rowcount = ( + 0 if rowcount is None and bytes_processed is not None else rowcount + ) + + QueryExecutionTracker.record_execution(sql, rowcount, bytes_processed) diff --git a/sqlmesh/core/snapshot/execution_tracker.py b/sqlmesh/core/snapshot/execution_tracker.py index b80b746dcc..bb29c09862 100644 --- a/sqlmesh/core/snapshot/execution_tracker.py +++ b/sqlmesh/core/snapshot/execution_tracker.py @@ -41,7 +41,7 @@ def __post_init__(self) -> None: def add_execution( self, sql: str, row_count: t.Optional[int], bytes_processed: t.Optional[int] ) -> None: - if row_count is not None: + if row_count is not None and row_count >= 0: if self.stats.total_rows_processed is None: self.stats.total_rows_processed = row_count else: diff --git a/tests/core/engine_adapter/integration/test_integration.py b/tests/core/engine_adapter/integration/test_integration.py index 42f1c8c584..fca9a6c32b 100644 --- a/tests/core/engine_adapter/integration/test_integration.py +++ b/tests/core/engine_adapter/integration/test_integration.py @@ -2426,11 +2426,12 @@ def capture_execution_stats( # seed rows aren't tracked assert actual_execution_stats["seed_model"].total_rows_processed is None - if ctx.mark.startswith("bigquery"): - assert actual_execution_stats["incremental_model"].total_bytes_processed - assert actual_execution_stats["full_model"].total_bytes_processed + if ctx.mark.startswith("bigquery") or ctx.mark.startswith("databricks"): + assert actual_execution_stats["incremental_model"].total_bytes_processed is not None + assert actual_execution_stats["full_model"].total_bytes_processed is not None # run that loads 0 rows in incremental model + actual_execution_stats = {} with patch.object( context.console, "update_snapshot_evaluation_progress", capture_execution_stats ): @@ -2444,9 +2445,9 @@ def capture_execution_stats( None if ctx.mark.startswith("snowflake") else 3 ) - if ctx.mark.startswith("bigquery"): - assert 
actual_execution_stats["incremental_model"].total_bytes_processed - assert actual_execution_stats["full_model"].total_bytes_processed + if ctx.mark.startswith("bigquery") or ctx.mark.startswith("databricks"): + assert actual_execution_stats["incremental_model"].total_bytes_processed is not None + assert actual_execution_stats["full_model"].total_bytes_processed is not None # make and validate unmodified dev environment no_change_plan: Plan = context.plan_builder( From 9e3f2aa71eaf1024915408bec1374ecaa1d19252 Mon Sep 17 00:00:00 2001 From: Trey Spiller Date: Wed, 20 Aug 2025 12:58:40 -0500 Subject: [PATCH 17/31] Remove time travel test for cloud engines, handle pyspark DFs in dbx --- sqlmesh/core/engine_adapter/databricks.py | 92 ++++++++++++------- .../integration/test_integration.py | 32 +++---- 2 files changed, 72 insertions(+), 52 deletions(-) diff --git a/sqlmesh/core/engine_adapter/databricks.py b/sqlmesh/core/engine_adapter/databricks.py index 6abd84b700..eb177f5bf0 100644 --- a/sqlmesh/core/engine_adapter/databricks.py +++ b/sqlmesh/core/engine_adapter/databricks.py @@ -14,6 +14,7 @@ SourceQuery, ) from sqlmesh.core.engine_adapter.spark import SparkEngineAdapter +from sqlmesh.engines.spark.db_api.spark_session import SparkSessionCursor from sqlmesh.core.node import IntervalUnit from sqlmesh.core.schema_diff import NestedSupport from sqlmesh.core.snapshot.execution_tracker import QueryExecutionTracker @@ -379,38 +380,59 @@ def _record_execution_stats( except: return - history = self.cursor.fetchall_arrow() - if history.num_rows: - history_df = history.to_pandas() - write_df = history_df[history_df["operation"] == "WRITE"] - write_df = write_df[write_df["timestamp"] == write_df["timestamp"].max()] - if not write_df.empty: - metrics = write_df["operationMetrics"][0] - if metrics: - rowcount = None - rowcount_str = [ - metric[1] for metric in metrics if metric[0] == "numOutputRows" - ] - if rowcount_str: - try: - rowcount = int(rowcount_str[0]) - except (TypeError, ValueError): - pass - - bytes_processed = None - bytes_str = [ - metric[1] for metric in metrics if metric[0] == "numOutputBytes" - ] - if bytes_str: - try: - bytes_processed = int(bytes_str[0]) - except (TypeError, ValueError): - pass - - if rowcount is not None or bytes_processed is not None: - # if no rows were written, df contains 0 for bytes but no value for rows - rowcount = ( - 0 if rowcount is None and bytes_processed is not None else rowcount - ) - - QueryExecutionTracker.record_execution(sql, rowcount, bytes_processed) + history = ( + self.cursor.fetchdf() + if isinstance(self.cursor, SparkSessionCursor) + else self.cursor.fetchall_arrow() + ) + if history is not None: + from pandas import DataFrame as PandasDataFrame + from pyspark.sql import DataFrame as PySparkDataFrame + from pyspark.sql.connect.dataframe import DataFrame as PySparkConnectDataFrame + + history_df = None + if isinstance(history, PandasDataFrame): + history_df = history + elif isinstance(history, (PySparkDataFrame, PySparkConnectDataFrame)): + history_df = history.toPandas() + else: + # arrow table + history_df = history.to_pandas() + + if history_df is not None and not history_df.empty: + write_df = history_df[history_df["operation"] == "WRITE"] + write_df = write_df[write_df["timestamp"] == write_df["timestamp"].max()] + if not write_df.empty: + metrics = write_df["operationMetrics"][0] + if metrics: + rowcount = None + rowcount_str = [ + metric[1] for metric in metrics if metric[0] == "numOutputRows" + ] + if rowcount_str: + try: + rowcount = 
int(rowcount_str[0]) + except (TypeError, ValueError): + pass + + bytes_processed = None + bytes_str = [ + metric[1] for metric in metrics if metric[0] == "numOutputBytes" + ] + if bytes_str: + try: + bytes_processed = int(bytes_str[0]) + except (TypeError, ValueError): + pass + + if rowcount is not None or bytes_processed is not None: + # if no rows were written, df contains 0 for bytes but no value for rows + rowcount = ( + 0 + if rowcount is None and bytes_processed is not None + else rowcount + ) + + QueryExecutionTracker.record_execution( + sql, rowcount, bytes_processed + ) diff --git a/tests/core/engine_adapter/integration/test_integration.py b/tests/core/engine_adapter/integration/test_integration.py index fca9a6c32b..cc1d3791f3 100644 --- a/tests/core/engine_adapter/integration/test_integration.py +++ b/tests/core/engine_adapter/integration/test_integration.py @@ -2431,23 +2431,21 @@ def capture_execution_stats( assert actual_execution_stats["full_model"].total_bytes_processed is not None # run that loads 0 rows in incremental model - actual_execution_stats = {} - with patch.object( - context.console, "update_snapshot_evaluation_progress", capture_execution_stats - ): - with time_machine.travel(date.today() + timedelta(days=1)): - context.run() - - if ctx.engine_adapter.SUPPORTS_QUERY_EXECUTION_TRACKING: - assert actual_execution_stats["incremental_model"].total_rows_processed == 0 - # snowflake doesn't track rows for CTAS - assert actual_execution_stats["full_model"].total_rows_processed == ( - None if ctx.mark.startswith("snowflake") else 3 - ) - - if ctx.mark.startswith("bigquery") or ctx.mark.startswith("databricks"): - assert actual_execution_stats["incremental_model"].total_bytes_processed is not None - assert actual_execution_stats["full_model"].total_bytes_processed is not None + # - some cloud DBs error because time travel messes up token expiration + if not ctx.is_remote: + actual_execution_stats = {} + with patch.object( + context.console, "update_snapshot_evaluation_progress", capture_execution_stats + ): + with time_machine.travel(date.today() + timedelta(days=1)): + context.run() + + if ctx.engine_adapter.SUPPORTS_QUERY_EXECUTION_TRACKING: + assert actual_execution_stats["incremental_model"].total_rows_processed == 0 + # snowflake doesn't track rows for CTAS + assert actual_execution_stats["full_model"].total_rows_processed == ( + None if ctx.mark.startswith("snowflake") else 3 + ) # make and validate unmodified dev environment no_change_plan: Plan = context.plan_builder( From d45b197af99c4ffd739b149f947a10cef755a8b1 Mon Sep 17 00:00:00 2001 From: Trey Spiller Date: Wed, 20 Aug 2025 16:25:45 -0500 Subject: [PATCH 18/31] Fix rebase --- sqlmesh/core/scheduler.py | 101 ++++++++++++++--------------- sqlmesh/core/snapshot/evaluator.py | 1 - 2 files changed, 48 insertions(+), 54 deletions(-) diff --git a/sqlmesh/core/scheduler.py b/sqlmesh/core/scheduler.py index 66b0b115d6..7110dcfa38 100644 --- a/sqlmesh/core/scheduler.py +++ b/sqlmesh/core/scheduler.py @@ -490,33 +490,33 @@ def run_node(node: SchedulingUnit) -> None: if isinstance(node, DummyNode): return - with QueryExecutionTracker.track_execution( - f"{snapshot.name}_{batch_idx}" - ) as execution_context: - snapshot = self.snapshots_by_name[node.snapshot_name] - - if isinstance(node, EvaluateNode): - self.console.start_snapshot_evaluation_progress(snapshot) - execution_start_ts = now_timestamp() - evaluation_duration_ms: t.Optional[int] = None - start, end = node.interval - - audit_results: t.List[AuditResult] = [] - 
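The Databricks adapter above reads row and byte counts from Delta's DESCRIBE HISTORY output rather than the cursor. A simplified sketch of that lookup (the DataFrame shape and the dict-style operationMetrics below are assumptions for illustration; the real cursor may return key/value pairs instead):

    import typing as t
    import pandas as pd

    def latest_write_metrics(history: pd.DataFrame) -> t.Tuple[t.Optional[int], t.Optional[int]]:
        # Keep only WRITE operations and take the most recent one.
        writes = history[history["operation"] == "WRITE"]
        if writes.empty:
            return None, None
        metrics = writes.loc[writes["timestamp"].idxmax()]["operationMetrics"] or {}

        def as_int(key: str) -> t.Optional[int]:
            try:
                return int(metrics[key])
            except (KeyError, TypeError, ValueError):
                return None

        return as_int("numOutputRows"), as_int("numOutputBytes")

    history = pd.DataFrame(
        {
            "operation": ["CREATE TABLE", "WRITE"],
            "timestamp": pd.to_datetime(["2025-08-20 10:00", "2025-08-20 10:05"]),
            "operationMetrics": [None, {"numOutputRows": "7", "numOutputBytes": "1024"}],
        }
    )
    print(latest_write_metrics(history))  # (7, 1024)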
try: - assert execution_time # mypy - assert deployability_index # mypy - - if audit_only: - audit_results = self._audit_snapshot( - snapshot=snapshot, - environment_naming_info=environment_naming_info, - deployability_index=deployability_index, - snapshots=self.snapshots_by_name, - start=start, - end=end, - execution_time=execution_time, - ) - else: + snapshot = self.snapshots_by_name[node.snapshot_name] + + if isinstance(node, EvaluateNode): + self.console.start_snapshot_evaluation_progress(snapshot) + execution_start_ts = now_timestamp() + evaluation_duration_ms: t.Optional[int] = None + start, end = node.interval + + audit_results: t.List[AuditResult] = [] + try: + assert execution_time # mypy + assert deployability_index # mypy + + if audit_only: + audit_results = self._audit_snapshot( + snapshot=snapshot, + environment_naming_info=environment_naming_info, + deployability_index=deployability_index, + snapshots=self.snapshots_by_name, + start=start, + end=end, + execution_time=execution_time, + ) + else: + with self.snapshot_evaluator.execution_tracker.track_execution( + f"{snapshot.name}_{node.batch_index}" + ) as execution_context: audit_results = self.evaluate( snapshot=snapshot, environment_naming_info=environment_naming_info, @@ -530,35 +530,30 @@ def run_node(node: SchedulingUnit) -> None: target_table_exists=snapshot.snapshot_id not in snapshots_to_create, ) - evaluation_duration_ms = now_timestamp() - execution_start_ts - finally: - num_audits = len(audit_results) - num_audits_failed = sum(1 for result in audit_results if result.count) + evaluation_duration_ms = now_timestamp() - execution_start_ts + finally: + num_audits = len(audit_results) + num_audits_failed = sum(1 for result in audit_results if result.count) - execution_stats = self.snapshot_evaluator.execution_tracker.get_execution_stats( - f"{snapshot.snapshot_id}_{batch_idx}" - ) + execution_stats = self.snapshot_evaluator.execution_tracker.get_execution_stats( + f"{snapshot.snapshot_id}_{node.batch_index}" + ) - self.console.update_snapshot_evaluation_progress( - snapshot, - batched_intervals[snapshot][node.batch_index], - node.batch_index, - evaluation_duration_ms, - num_audits - num_audits_failed, - num_audits_failed, - execution_stats=execution_stats, - auto_restatement_triggers=auto_restatement_triggers.get( - snapshot.snapshot_id - ), + self.console.update_snapshot_evaluation_progress( + snapshot, + batched_intervals[snapshot][node.batch_index], + node.batch_index, + evaluation_duration_ms, + num_audits - num_audits_failed, + num_audits_failed, + execution_stats=execution_stats, ) - elif isinstance(node, CreateNode): - self.snapshot_evaluator.create_snapshot( - snapshot=snapshot, - snapshots=self.snapshots_by_name, - deployability_index=deployability_index, - allow_destructive_snapshots=allow_destructive_snapshots or set(), - allow_additive_snapshots=allow_additive_snapshots or set(), - rows_processed=rows_processed, + elif isinstance(node, CreateNode): + self.snapshot_evaluator.create_snapshot( + snapshot=snapshot, + snapshots=self.snapshots_by_name, + deployability_index=deployability_index, + allow_destructive_snapshots=allow_destructive_snapshots or set(), ) try: diff --git a/sqlmesh/core/snapshot/evaluator.py b/sqlmesh/core/snapshot/evaluator.py index 9729b7a66b..f1569d0163 100644 --- a/sqlmesh/core/snapshot/evaluator.py +++ b/sqlmesh/core/snapshot/evaluator.py @@ -64,7 +64,6 @@ SnapshotInfoLike, SnapshotTableCleanupTask, ) -from sqlmesh.core.snapshot.definition import parent_snapshots_by_name from 
sqlmesh.core.snapshot.execution_tracker import QueryExecutionTracker from sqlmesh.utils import random_id, CorrelationId from sqlmesh.utils.concurrency import ( From 3e1747957723773d40d4276ad84f1b5a90d74f0c Mon Sep 17 00:00:00 2001 From: Trey Spiller Date: Wed, 20 Aug 2025 18:06:56 -0500 Subject: [PATCH 19/31] Seeds are now handled in evaluator --- sqlmesh/core/engine_adapter/databricks.py | 4 ++-- sqlmesh/core/scheduler.py | 16 ++++++++-------- .../integration/test_integration.py | 3 +-- tests/core/test_snapshot_evaluator.py | 1 - 4 files changed, 11 insertions(+), 13 deletions(-) diff --git a/sqlmesh/core/engine_adapter/databricks.py b/sqlmesh/core/engine_adapter/databricks.py index eb177f5bf0..4ec95b9f8c 100644 --- a/sqlmesh/core/engine_adapter/databricks.py +++ b/sqlmesh/core/engine_adapter/databricks.py @@ -402,8 +402,8 @@ def _record_execution_stats( if history_df is not None and not history_df.empty: write_df = history_df[history_df["operation"] == "WRITE"] write_df = write_df[write_df["timestamp"] == write_df["timestamp"].max()] - if not write_df.empty: - metrics = write_df["operationMetrics"][0] + if not write_df.empty and "operationMetrics" in write_df.columns: + metrics = write_df["operationMetrics"].iloc[0] if metrics: rowcount = None rowcount_str = [ diff --git a/sqlmesh/core/scheduler.py b/sqlmesh/core/scheduler.py index 7110dcfa38..5d4d398138 100644 --- a/sqlmesh/core/scheduler.py +++ b/sqlmesh/core/scheduler.py @@ -540,13 +540,13 @@ def run_node(node: SchedulingUnit) -> None: ) self.console.update_snapshot_evaluation_progress( - snapshot, - batched_intervals[snapshot][node.batch_index], - node.batch_index, - evaluation_duration_ms, - num_audits - num_audits_failed, - num_audits_failed, - execution_stats=execution_stats, + snapshot, + batched_intervals[snapshot][node.batch_index], + node.batch_index, + evaluation_duration_ms, + num_audits - num_audits_failed, + num_audits_failed, + execution_stats=execution_stats, ) elif isinstance(node, CreateNode): self.snapshot_evaluator.create_snapshot( @@ -554,7 +554,7 @@ def run_node(node: SchedulingUnit) -> None: snapshots=self.snapshots_by_name, deployability_index=deployability_index, allow_destructive_snapshots=allow_destructive_snapshots or set(), - ) + ) try: with self.snapshot_evaluator.concurrent_context(): diff --git a/tests/core/engine_adapter/integration/test_integration.py b/tests/core/engine_adapter/integration/test_integration.py index cc1d3791f3..eaac95e7c5 100644 --- a/tests/core/engine_adapter/integration/test_integration.py +++ b/tests/core/engine_adapter/integration/test_integration.py @@ -2418,13 +2418,12 @@ def capture_execution_stats( assert len(physical_layer_results.tables) == len(physical_layer_results.non_temp_tables) == 3 if ctx.engine_adapter.SUPPORTS_QUERY_EXECUTION_TRACKING: + assert actual_execution_stats["seed_model"].total_rows_processed == 7 assert actual_execution_stats["incremental_model"].total_rows_processed == 7 # snowflake doesn't track rows for CTAS assert actual_execution_stats["full_model"].total_rows_processed == ( None if ctx.mark.startswith("snowflake") else 3 ) - # seed rows aren't tracked - assert actual_execution_stats["seed_model"].total_rows_processed is None if ctx.mark.startswith("bigquery") or ctx.mark.startswith("databricks"): assert actual_execution_stats["incremental_model"].total_bytes_processed is not None diff --git a/tests/core/test_snapshot_evaluator.py b/tests/core/test_snapshot_evaluator.py index c19d118c8c..c25acd279e 100644 --- a/tests/core/test_snapshot_evaluator.py +++ 
b/tests/core/test_snapshot_evaluator.py @@ -675,7 +675,6 @@ def test_evaluate_materialized_view_with_partitioned_by_cluster_by( execute_mock.assert_has_calls( [ - call("CREATE SCHEMA IF NOT EXISTS `sqlmesh__test_schema`", False), call( f"CREATE MATERIALIZED VIEW `sqlmesh__test_schema`.`test_schema__test_model__{snapshot.version}` PARTITION BY `a` CLUSTER BY `b` AS SELECT `a` AS `a`, `b` AS `b` FROM `tbl` AS `tbl`", False, From 46717a05f4e5804c2ad80807cf1846f7cb7969ba Mon Sep 17 00:00:00 2001 From: Trey Spiller Date: Wed, 20 Aug 2025 18:14:53 -0500 Subject: [PATCH 20/31] Fix rebase --- sqlmesh/core/scheduler.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sqlmesh/core/scheduler.py b/sqlmesh/core/scheduler.py index 5d4d398138..22921c6abb 100644 --- a/sqlmesh/core/scheduler.py +++ b/sqlmesh/core/scheduler.py @@ -527,7 +527,7 @@ def run_node(node: SchedulingUnit) -> None: batch_index=node.batch_index, allow_destructive_snapshots=allow_destructive_snapshots, allow_additive_snapshots=allow_additive_snapshots, - target_table_exists=snapshot.snapshot_id not in snapshots_to_create, + target_table_exists=snapshot.snapshot_id not in snapshots_to_create, ) evaluation_duration_ms = now_timestamp() - execution_start_ts @@ -554,6 +554,7 @@ def run_node(node: SchedulingUnit) -> None: snapshots=self.snapshots_by_name, deployability_index=deployability_index, allow_destructive_snapshots=allow_destructive_snapshots or set(), + allow_additive_snapshots=allow_additive_snapshots or set(), ) try: From f227e5afd36029caaf06d27ed369fb77a490e865 Mon Sep 17 00:00:00 2001 From: Trey Spiller Date: Wed, 20 Aug 2025 19:04:22 -0500 Subject: [PATCH 21/31] Handle snowflake table already exists --- sqlmesh/core/engine_adapter/snowflake.py | 8 ++++++-- tests/core/engine_adapter/integration/test_integration.py | 4 +++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/sqlmesh/core/engine_adapter/snowflake.py b/sqlmesh/core/engine_adapter/snowflake.py index d90d8c7afa..b5f89d77b0 100644 --- a/sqlmesh/core/engine_adapter/snowflake.py +++ b/sqlmesh/core/engine_adapter/snowflake.py @@ -700,9 +700,13 @@ def _record_execution_stats( # - [^"] matches any single character except a double-quote # - | or # - "" matches two sequential double-quotes - is_ctas = re.match( + is_created = re.match( r'Table ([a-zA-Z0-9_$]+|"(?:[^"]|"")+") successfully created\.', results_str ) - if is_ctas: + is_already_exists = re.match( + r'([a-zA-Z0-9_$]+|"(?:[^"]|"")+") already exists, statement succeeded\.', + results_str, + ) + if is_created or is_already_exists: return QueryExecutionTracker.record_execution(sql, rowcount, bytes_processed) diff --git a/tests/core/engine_adapter/integration/test_integration.py b/tests/core/engine_adapter/integration/test_integration.py index eaac95e7c5..bc000e599f 100644 --- a/tests/core/engine_adapter/integration/test_integration.py +++ b/tests/core/engine_adapter/integration/test_integration.py @@ -2418,7 +2418,9 @@ def capture_execution_stats( assert len(physical_layer_results.tables) == len(physical_layer_results.non_temp_tables) == 3 if ctx.engine_adapter.SUPPORTS_QUERY_EXECUTION_TRACKING: - assert actual_execution_stats["seed_model"].total_rows_processed == 7 + assert actual_execution_stats["seed_model"].total_rows_processed == ( + None if ctx.mark.startswith("snowflake") else 7 + ) assert actual_execution_stats["incremental_model"].total_rows_processed == 7 # snowflake doesn't track rows for CTAS assert actual_execution_stats["full_model"].total_rows_processed == ( From 
775eadffec23c17946b24c01dde01128c5596a3f Mon Sep 17 00:00:00 2001 From: Trey Spiller Date: Thu, 21 Aug 2025 11:51:57 -0500 Subject: [PATCH 22/31] Query info schema for snowflake CTAS num rows --- sqlmesh/core/engine_adapter/snowflake.py | 39 +++++++------------ .../integration/test_integration.py | 13 ++----- 2 files changed, 16 insertions(+), 36 deletions(-) diff --git a/sqlmesh/core/engine_adapter/snowflake.py b/sqlmesh/core/engine_adapter/snowflake.py index b5f89d77b0..e704987456 100644 --- a/sqlmesh/core/engine_adapter/snowflake.py +++ b/sqlmesh/core/engine_adapter/snowflake.py @@ -2,10 +2,9 @@ import contextlib import logging -import re import typing as t -from sqlglot import exp +from sqlglot import exp, parse_one from sqlglot.helper import ensure_list from sqlglot.optimizer.normalize_identifiers import normalize_identifiers from sqlglot.optimizer.qualify_columns import quote_identifiers @@ -683,30 +682,18 @@ def _record_execution_stats( If so, we return early and do not record the row count. """ if rowcount == 1: - results = self.cursor.fetchall() - if results and len(results) == 1: - try: - results_str = str(results[0][0]) - except (ValueError, TypeError): + query_parsed = parse_one(sql, dialect=self.dialect) + if isinstance(query_parsed, exp.Create): + if query_parsed.expression and isinstance(query_parsed.expression, exp.Select): + table = query_parsed.find(exp.Table) + if table: + row_query = f"SELECT ROW_COUNT as row_count FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = '{table.db}' AND TABLE_NAME = '{table.name}'" + row_query_results = self.fetchone(row_query, quote_identifiers=True) + if row_query_results: + rowcount = row_query_results[0] + else: + return + else: return - # Snowflake identifiers may be: - # - An unquoted contiguous set of [a-zA-Z0-9_$] characters - # - A double-quoted string that may contain spaces and nested double-quotes represented by `""`. 
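For reference, the lookup introduced here reduces to the following standalone sketch; `adapter.fetchone` stands in for the adapter's own fetch helper, and this path is only reached when the cursor has already reported the misleading rowcount of 1 for a CTAS:

from sqlglot import exp, parse_one

def ctas_row_count(adapter, sql: str):
    # Only CREATE TABLE ... AS SELECT statements need the information-schema fallback.
    parsed = parse_one(sql, dialect="snowflake")
    if isinstance(parsed, exp.Create) and isinstance(parsed.expression, exp.Select):
        table = parsed.find(exp.Table)
        if table:
            row = adapter.fetchone(
                "SELECT ROW_COUNT FROM INFORMATION_SCHEMA.TABLES "
                f"WHERE TABLE_SCHEMA = '{table.db}' AND TABLE_NAME = '{table.name}'"
            )
            return row[0] if row else None
    return None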
Example: " my ""table"" name " - # - Regex: - # - [a-zA-Z0-9_$]+ matches one or more character in the set - # - "(?:[^"]|"")+" matches a double-quoted string that may contain spaces and nested double-quotes - # - ?: non-capturing group - # - [^"] matches any single character except a double-quote - # - | or - # - "" matches two sequential double-quotes - is_created = re.match( - r'Table ([a-zA-Z0-9_$]+|"(?:[^"]|"")+") successfully created\.', results_str - ) - is_already_exists = re.match( - r'([a-zA-Z0-9_$]+|"(?:[^"]|"")+") already exists, statement succeeded\.', - results_str, - ) - if is_created or is_already_exists: - return QueryExecutionTracker.record_execution(sql, rowcount, bytes_processed) diff --git a/tests/core/engine_adapter/integration/test_integration.py b/tests/core/engine_adapter/integration/test_integration.py index bc000e599f..1ef7f27941 100644 --- a/tests/core/engine_adapter/integration/test_integration.py +++ b/tests/core/engine_adapter/integration/test_integration.py @@ -2418,14 +2418,10 @@ def capture_execution_stats( assert len(physical_layer_results.tables) == len(physical_layer_results.non_temp_tables) == 3 if ctx.engine_adapter.SUPPORTS_QUERY_EXECUTION_TRACKING: - assert actual_execution_stats["seed_model"].total_rows_processed == ( - None if ctx.mark.startswith("snowflake") else 7 - ) + assert actual_execution_stats["seed_model"].total_rows_processed == 7 assert actual_execution_stats["incremental_model"].total_rows_processed == 7 # snowflake doesn't track rows for CTAS - assert actual_execution_stats["full_model"].total_rows_processed == ( - None if ctx.mark.startswith("snowflake") else 3 - ) + assert actual_execution_stats["full_model"].total_rows_processed == 3 if ctx.mark.startswith("bigquery") or ctx.mark.startswith("databricks"): assert actual_execution_stats["incremental_model"].total_bytes_processed is not None @@ -2443,10 +2439,7 @@ def capture_execution_stats( if ctx.engine_adapter.SUPPORTS_QUERY_EXECUTION_TRACKING: assert actual_execution_stats["incremental_model"].total_rows_processed == 0 - # snowflake doesn't track rows for CTAS - assert actual_execution_stats["full_model"].total_rows_processed == ( - None if ctx.mark.startswith("snowflake") else 3 - ) + assert actual_execution_stats["full_model"].total_rows_processed == 3 # make and validate unmodified dev environment no_change_plan: Plan = context.plan_builder( From 81e21cdda7ef57b907c3288fcb57ce974335076c Mon Sep 17 00:00:00 2001 From: Trey Spiller Date: Fri, 22 Aug 2025 11:58:06 -0500 Subject: [PATCH 23/31] Remove databricks, snowflake metadata calls --- sqlmesh/core/engine_adapter/databricks.py | 75 +------------------ sqlmesh/core/engine_adapter/snowflake.py | 44 ++++++----- .../integration/test_integration.py | 17 +++-- 3 files changed, 35 insertions(+), 101 deletions(-) diff --git a/sqlmesh/core/engine_adapter/databricks.py b/sqlmesh/core/engine_adapter/databricks.py index 4ec95b9f8c..da70163db4 100644 --- a/sqlmesh/core/engine_adapter/databricks.py +++ b/sqlmesh/core/engine_adapter/databricks.py @@ -4,7 +4,7 @@ import typing as t from functools import partial -from sqlglot import exp, parse_one +from sqlglot import exp from sqlmesh.core.dialect import to_schema from sqlmesh.core.engine_adapter.shared import ( CatalogSupport, @@ -14,10 +14,8 @@ SourceQuery, ) from sqlmesh.core.engine_adapter.spark import SparkEngineAdapter -from sqlmesh.engines.spark.db_api.spark_session import SparkSessionCursor from sqlmesh.core.node import IntervalUnit from sqlmesh.core.schema_diff import NestedSupport 
-from sqlmesh.core.snapshot.execution_tracker import QueryExecutionTracker from sqlmesh.engines.spark.db_api.spark_session import connection, SparkSessionConnection from sqlmesh.utils.errors import SQLMeshError, MissingDefaultCatalogError @@ -36,7 +34,6 @@ class DatabricksEngineAdapter(SparkEngineAdapter): SUPPORTS_CLONING = True SUPPORTS_MATERIALIZED_VIEWS = True SUPPORTS_MATERIALIZED_VIEW_SCHEMA = True - SUPPORTS_QUERY_EXECUTION_TRACKING = True SCHEMA_DIFFER_KWARGS = { "support_positional_add": True, "nested_support": NestedSupport.ALL, @@ -366,73 +363,3 @@ def _build_table_properties_exp( expressions.append(clustered_by_exp) properties = exp.Properties(expressions=expressions) return properties - - def _record_execution_stats( - self, sql: str, rowcount: t.Optional[int] = None, bytes_processed: t.Optional[int] = None - ) -> None: - parsed = parse_one(sql, dialect=self.dialect) - table = parsed.find(exp.Table) - table_name = table.sql(dialect=self.dialect) if table else None - - if table_name: - try: - self.cursor.execute(f"DESCRIBE HISTORY {table_name}") - except: - return - - history = ( - self.cursor.fetchdf() - if isinstance(self.cursor, SparkSessionCursor) - else self.cursor.fetchall_arrow() - ) - if history is not None: - from pandas import DataFrame as PandasDataFrame - from pyspark.sql import DataFrame as PySparkDataFrame - from pyspark.sql.connect.dataframe import DataFrame as PySparkConnectDataFrame - - history_df = None - if isinstance(history, PandasDataFrame): - history_df = history - elif isinstance(history, (PySparkDataFrame, PySparkConnectDataFrame)): - history_df = history.toPandas() - else: - # arrow table - history_df = history.to_pandas() - - if history_df is not None and not history_df.empty: - write_df = history_df[history_df["operation"] == "WRITE"] - write_df = write_df[write_df["timestamp"] == write_df["timestamp"].max()] - if not write_df.empty and "operationMetrics" in write_df.columns: - metrics = write_df["operationMetrics"].iloc[0] - if metrics: - rowcount = None - rowcount_str = [ - metric[1] for metric in metrics if metric[0] == "numOutputRows" - ] - if rowcount_str: - try: - rowcount = int(rowcount_str[0]) - except (TypeError, ValueError): - pass - - bytes_processed = None - bytes_str = [ - metric[1] for metric in metrics if metric[0] == "numOutputBytes" - ] - if bytes_str: - try: - bytes_processed = int(bytes_str[0]) - except (TypeError, ValueError): - pass - - if rowcount is not None or bytes_processed is not None: - # if no rows were written, df contains 0 for bytes but no value for rows - rowcount = ( - 0 - if rowcount is None and bytes_processed is not None - else rowcount - ) - - QueryExecutionTracker.record_execution( - sql, rowcount, bytes_processed - ) diff --git a/sqlmesh/core/engine_adapter/snowflake.py b/sqlmesh/core/engine_adapter/snowflake.py index e704987456..58515e6675 100644 --- a/sqlmesh/core/engine_adapter/snowflake.py +++ b/sqlmesh/core/engine_adapter/snowflake.py @@ -2,9 +2,10 @@ import contextlib import logging +import re import typing as t -from sqlglot import exp, parse_one +from sqlglot import exp from sqlglot.helper import ensure_list from sqlglot.optimizer.normalize_identifiers import normalize_identifiers from sqlglot.optimizer.qualify_columns import quote_identifiers @@ -672,28 +673,31 @@ def _record_execution_stats( ) -> None: """Snowflake does not report row counts for CTAS like other DML operations. - They neither report the sentinel value -1 nor do they report 0 rows. 
Instead, they return a single data row - containing the string "Table successfully created." and a row count of 1. + They neither report the sentinel value -1 nor do they report 0 rows. Instead, they report a rowcount + of 1 and return a single data row containing one of the strings: + - "Table successfully created." + - " already exists, statement succeeded." - We do not want to record the incorrect row count of 1, so we check whether: - - There is exactly one row to fetch (in general, DML operations should return no rows to fetch from the cursor) - - That row contains the table successfully created string - - If so, we return early and do not record the row count. + We do not want to record the incorrect row count of 1, so we check whether that row contains the table + successfully created string. If so, we return early and do not record the row count. """ if rowcount == 1: - query_parsed = parse_one(sql, dialect=self.dialect) - if isinstance(query_parsed, exp.Create): - if query_parsed.expression and isinstance(query_parsed.expression, exp.Select): - table = query_parsed.find(exp.Table) - if table: - row_query = f"SELECT ROW_COUNT as row_count FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = '{table.db}' AND TABLE_NAME = '{table.name}'" - row_query_results = self.fetchone(row_query, quote_identifiers=True) - if row_query_results: - rowcount = row_query_results[0] - else: - return - else: + results = self.cursor.fetchone() + if results: + try: + results_str = str(results[0]) + except (ValueError, TypeError): + return + + # Snowflake identifiers may be: + # - An unquoted contiguous set of [a-zA-Z0-9_$] characters + # - A double-quoted string that may contain spaces and nested double-quotes represented by `""`. Example: " my ""table"" name " + is_created = re.match(r'Table [a-zA-Z0-9_$"]*? successfully created\.', results_str) + is_already_exists = re.match( + r'[a-zA-Z0-9_$"]*? 
already exists, statement succeeded\.', + results_str, + ) + if is_created or is_already_exists: return QueryExecutionTracker.record_execution(sql, rowcount, bytes_processed) diff --git a/tests/core/engine_adapter/integration/test_integration.py b/tests/core/engine_adapter/integration/test_integration.py index 1ef7f27941..a320e19687 100644 --- a/tests/core/engine_adapter/integration/test_integration.py +++ b/tests/core/engine_adapter/integration/test_integration.py @@ -6,13 +6,12 @@ import sys import typing as t import shutil -from datetime import date, datetime, timedelta +from datetime import datetime, timedelta from unittest.mock import patch import numpy as np # noqa: TID253 import pandas as pd # noqa: TID253 import pytest import pytz -import time_machine from sqlglot import exp, parse_one from sqlglot.optimizer.normalize_identifiers import normalize_identifiers from sqlglot.optimizer.qualify_columns import quote_identifiers @@ -2418,14 +2417,18 @@ def capture_execution_stats( assert len(physical_layer_results.tables) == len(physical_layer_results.non_temp_tables) == 3 if ctx.engine_adapter.SUPPORTS_QUERY_EXECUTION_TRACKING: - assert actual_execution_stats["seed_model"].total_rows_processed == 7 assert actual_execution_stats["incremental_model"].total_rows_processed == 7 # snowflake doesn't track rows for CTAS - assert actual_execution_stats["full_model"].total_rows_processed == 3 + assert actual_execution_stats["full_model"].total_rows_processed == ( + None if ctx.mark.startswith("snowflake") else 3 + ) + assert actual_execution_stats["seed_model"].total_rows_processed == ( + None if ctx.mark.startswith("snowflake") else 7 + ) - if ctx.mark.startswith("bigquery") or ctx.mark.startswith("databricks"): - assert actual_execution_stats["incremental_model"].total_bytes_processed is not None - assert actual_execution_stats["full_model"].total_bytes_processed is not None + if ctx.mark.startswith("bigquery"): + assert actual_execution_stats["incremental_model"].total_bytes_processed + assert actual_execution_stats["full_model"].total_bytes_processed # run that loads 0 rows in incremental model # - some cloud DBs error because time travel messes up token expiration From 7f5f301a50a11fcd45658f461d533e1135224d70 Mon Sep 17 00:00:00 2001 From: Trey Spiller Date: Fri, 22 Aug 2025 12:47:51 -0500 Subject: [PATCH 24/31] Add snowflake test --- .../integration/test_integration.py | 3 ++- .../integration/test_integration_snowflake.py | 19 +++++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/tests/core/engine_adapter/integration/test_integration.py b/tests/core/engine_adapter/integration/test_integration.py index a320e19687..29287f4f46 100644 --- a/tests/core/engine_adapter/integration/test_integration.py +++ b/tests/core/engine_adapter/integration/test_integration.py @@ -6,12 +6,13 @@ import sys import typing as t import shutil -from datetime import datetime, timedelta +from datetime import datetime, timedelta, date from unittest.mock import patch import numpy as np # noqa: TID253 import pandas as pd # noqa: TID253 import pytest import pytz +import time_machine from sqlglot import exp, parse_one from sqlglot.optimizer.normalize_identifiers import normalize_identifiers from sqlglot.optimizer.qualify_columns import quote_identifiers diff --git a/tests/core/engine_adapter/integration/test_integration_snowflake.py b/tests/core/engine_adapter/integration/test_integration_snowflake.py index 01cbe1c0aa..7fa7366a46 100644 --- 
a/tests/core/engine_adapter/integration/test_integration_snowflake.py +++ b/tests/core/engine_adapter/integration/test_integration_snowflake.py @@ -13,6 +13,7 @@ from tests.core.engine_adapter.integration import TestContext from sqlmesh import model, ExecutionContext from sqlmesh.core.model import ModelKindName +from sqlmesh.core.snapshot.execution_tracker import QueryExecutionTracker from datetime import datetime from tests.core.engine_adapter.integration import ( @@ -307,3 +308,21 @@ def fetch_database_names() -> t.Set[str]: engine_adapter.drop_catalog(sqlmesh_managed_catalog) # works, catalog is SQLMesh-managed assert fetch_database_names() == {non_sqlmesh_managed_catalog} + + +def test_rows_tracker(ctx: TestContext, engine_adapter: SnowflakeEngineAdapter): + sqlmesh = ctx.create_context() + tracker = QueryExecutionTracker() + + with tracker.track_execution("a"): + # Snowflake doesn't report row counts for CTAS, so this should not be tracked + engine_adapter.execute( + "CREATE TABLE a (id int) AS SELECT 1 as id", track_rows_processed=True + ) + engine_adapter.execute("INSERT INTO a VALUES (2), (3)", track_rows_processed=True) + engine_adapter.execute("INSERT INTO a VALUES (4)", track_rows_processed=True) + + stats = tracker.get_execution_stats("a") + assert stats is not None + assert stats.query_count == 2 + assert stats.total_rows_processed == 3 From 175012d56a6534a14b855e3a645a6a5f6da99a10 Mon Sep 17 00:00:00 2001 From: Trey Spiller Date: Fri, 22 Aug 2025 16:38:08 -0500 Subject: [PATCH 25/31] Tidy up --- .circleci/continue_config.yml | 18 +++++------ docs/integrations/engines/snowflake.md | 8 +++++ sqlmesh/core/console.py | 4 +-- sqlmesh/core/engine_adapter/base.py | 2 +- sqlmesh/core/engine_adapter/snowflake.py | 10 +++++-- sqlmesh/core/scheduler.py | 30 +++++++++---------- sqlmesh/core/snapshot/execution_tracker.py | 25 +++------------- .../integration/test_integration.py | 1 + .../integration/test_integration_snowflake.py | 15 ++++++++-- tests/core/test_execution_tracker.py | 2 -- 10 files changed, 57 insertions(+), 58 deletions(-) diff --git a/.circleci/continue_config.yml b/.circleci/continue_config.yml index 35f15e2c2f..e21f3d869b 100644 --- a/.circleci/continue_config.yml +++ b/.circleci/continue_config.yml @@ -239,7 +239,7 @@ jobs: - checkout - run: name: Install OS-level dependencies - command: ./.circleci/install-prerequisites.sh "<< parameters.engine >>" + command: ./.circleci/install-prerequisites.sh "<< parameters.engine >>" - run: name: Generate database name command: | @@ -297,9 +297,8 @@ workflows: name: cloud_engine_<< matrix.engine >> context: - sqlmesh_cloud_database_integration - # TODO: uncomment this - # requires: - # - engine_tests_docker + requires: + - engine_tests_docker matrix: parameters: engine: @@ -308,14 +307,13 @@ workflows: - redshift - bigquery - clickhouse-cloud - - athena + - athena - fabric - gcp-postgres - # TODO: uncomment this - # filters: - # branches: - # only: - # - main + filters: + branches: + only: + - main - ui_style - ui_test - vscode_test diff --git a/docs/integrations/engines/snowflake.md b/docs/integrations/engines/snowflake.md index 30de0bfd14..fc2ccbd6bb 100644 --- a/docs/integrations/engines/snowflake.md +++ b/docs/integrations/engines/snowflake.md @@ -250,6 +250,14 @@ And confirm that our schemas and objects exist in the Snowflake catalog: Congratulations - your SQLMesh project is up and running on Snowflake! +### Where are the row counts? 
+ +SQLMesh reports the number of rows processed by each model in its `plan` and `run` terminal output. + +However, due to limitations in the Snowflake Python connector, row counts cannot be determined for `CREATE TABLE AS` statements. Therefore, SQLMesh does not report row counts for certain model kinds, such as `FULL` models. + +Learn more about the connector limitation [on Github](https://github.com/snowflakedb/snowflake-connector-python/issues/645). + ## Local/Built-in Scheduler **Engine Adapter Type**: `snowflake` diff --git a/sqlmesh/core/console.py b/sqlmesh/core/console.py index e6010f7c0b..9fdc643905 100644 --- a/sqlmesh/core/console.py +++ b/sqlmesh/core/console.py @@ -4030,9 +4030,7 @@ def show_table_diff_summary(self, table_diff: TableDiff) -> None: self._write(f"Join On: {keys}") -# TODO: remove this -# _CONSOLE: Console = NoopConsole() -_CONSOLE: Console = TerminalConsole() +_CONSOLE: Console = NoopConsole() def set_console(console: Console) -> None: diff --git a/sqlmesh/core/engine_adapter/base.py b/sqlmesh/core/engine_adapter/base.py index fe723d1109..721e7c61c1 100644 --- a/sqlmesh/core/engine_adapter/base.py +++ b/sqlmesh/core/engine_adapter/base.py @@ -2464,7 +2464,7 @@ def _execute(self, sql: str, track_rows_processed: bool = False, **kwargs: t.Any try: rowcount = int(rowcount_raw) except (TypeError, ValueError): - pass + return self._record_execution_stats(sql, rowcount) diff --git a/sqlmesh/core/engine_adapter/snowflake.py b/sqlmesh/core/engine_adapter/snowflake.py index 58515e6675..8e74eb67d0 100644 --- a/sqlmesh/core/engine_adapter/snowflake.py +++ b/sqlmesh/core/engine_adapter/snowflake.py @@ -680,21 +680,25 @@ def _record_execution_stats( We do not want to record the incorrect row count of 1, so we check whether that row contains the table successfully created string. If so, we return early and do not record the row count. + + Ref: https://github.com/snowflakedb/snowflake-connector-python/issues/645 """ if rowcount == 1: results = self.cursor.fetchone() if results: try: results_str = str(results[0]) - except (ValueError, TypeError): + except (TypeError, ValueError, IndexError): return # Snowflake identifiers may be: # - An unquoted contiguous set of [a-zA-Z0-9_$] characters # - A double-quoted string that may contain spaces and nested double-quotes represented by `""`. Example: " my ""table"" name " - is_created = re.match(r'Table [a-zA-Z0-9_$"]*? successfully created\.', results_str) + is_created = re.match( + r'Table [a-zA-Z0-9_$ "]*? successfully created\.', results_str + ) is_already_exists = re.match( - r'[a-zA-Z0-9_$"]*? already exists, statement succeeded\.', + r'[a-zA-Z0-9_$ "]*? 
already exists, statement succeeded\.', results_str, ) if is_created or is_already_exists: diff --git a/sqlmesh/core/scheduler.py b/sqlmesh/core/scheduler.py index 22921c6abb..ce27ee1838 100644 --- a/sqlmesh/core/scheduler.py +++ b/sqlmesh/core/scheduler.py @@ -514,21 +514,18 @@ def run_node(node: SchedulingUnit) -> None: execution_time=execution_time, ) else: - with self.snapshot_evaluator.execution_tracker.track_execution( - f"{snapshot.name}_{node.batch_index}" - ) as execution_context: - audit_results = self.evaluate( - snapshot=snapshot, - environment_naming_info=environment_naming_info, - start=start, - end=end, - execution_time=execution_time, - deployability_index=deployability_index, - batch_index=node.batch_index, - allow_destructive_snapshots=allow_destructive_snapshots, - allow_additive_snapshots=allow_additive_snapshots, - target_table_exists=snapshot.snapshot_id not in snapshots_to_create, - ) + audit_results = self.evaluate( + snapshot=snapshot, + environment_naming_info=environment_naming_info, + start=start, + end=end, + execution_time=execution_time, + deployability_index=deployability_index, + batch_index=node.batch_index, + allow_destructive_snapshots=allow_destructive_snapshots, + allow_additive_snapshots=allow_additive_snapshots, + target_table_exists=snapshot.snapshot_id not in snapshots_to_create, + ) evaluation_duration_ms = now_timestamp() - execution_start_ts finally: @@ -547,6 +544,9 @@ def run_node(node: SchedulingUnit) -> None: num_audits - num_audits_failed, num_audits_failed, execution_stats=execution_stats, + auto_restatement_triggers=auto_restatement_triggers.get( + snapshot.snapshot_id + ), ) elif isinstance(node, CreateNode): self.snapshot_evaluator.create_snapshot( diff --git a/sqlmesh/core/snapshot/execution_tracker.py b/sqlmesh/core/snapshot/execution_tracker.py index bb29c09862..61c10696ce 100644 --- a/sqlmesh/core/snapshot/execution_tracker.py +++ b/sqlmesh/core/snapshot/execution_tracker.py @@ -1,6 +1,5 @@ from __future__ import annotations -import time import typing as t from contextlib import contextmanager from threading import local, Lock @@ -12,10 +11,6 @@ class QueryExecutionStats: snapshot_batch_id: str total_rows_processed: t.Optional[int] = None total_bytes_processed: t.Optional[int] = None - query_count: int = 0 - queries_executed: t.List[t.Tuple[str, t.Optional[int], t.Optional[int], float]] = field( - default_factory=list - ) @dataclass @@ -26,10 +21,8 @@ class QueryExecutionContext: It accumulates statistics from multiple cursor.execute() calls during a single snapshot evaluation. 
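A minimal usage sketch of the tracker as of this commit; the batch id string and SQL text are illustrative only:

tracker = QueryExecutionTracker()
with tracker.track_execution("example_snapshot_0"):
    # Engine adapters report each tracked cursor.execute() through the classmethod.
    QueryExecutionTracker.record_execution("INSERT INTO t SELECT ...", 7, None)
    QueryExecutionTracker.record_execution("INSERT INTO t SELECT ...", 3, None)

stats = tracker.get_execution_stats("example_snapshot_0")
assert stats is not None and stats.total_rows_processed == 10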
Attributes: - id: Identifier linking this context to a specific operation - total_rows_processed: Running sum of cursor.rowcount from all executed queries during evaluation - query_count: Total number of SQL statements executed - queries_executed: List of (sql_snippet, row_count, timestamp) tuples for debugging + snapshot_batch_id: Identifier linking this context to a specific snapshot evaluation + stats: Running sum of cursor.rowcount and possibly bytes processed from all executed queries during evaluation """ snapshot_batch_id: str @@ -55,20 +48,12 @@ def add_execution( else: self.stats.total_bytes_processed += bytes_processed - self.stats.query_count += 1 - # TODO: remove this - # for debugging - self.stats.queries_executed.append((sql[:300], row_count, bytes_processed, time.time())) - def get_execution_stats(self) -> QueryExecutionStats: return self.stats class QueryExecutionTracker: - """ - Thread-local context manager for snapshot execution statistics, such as - rows processed. - """ + """Thread-local context manager for snapshot execution statistics, such as rows processed.""" _thread_local = local() _contexts: t.Dict[str, QueryExecutionContext] = {} @@ -86,9 +71,7 @@ def is_tracking(cls) -> bool: def track_execution( self, snapshot_id_batch: str ) -> t.Iterator[t.Optional[QueryExecutionContext]]: - """ - Context manager for tracking snapshot execution statistics. - """ + """Context manager for tracking snapshot execution statistics such as row counts and bytes processed.""" context = QueryExecutionContext(snapshot_batch_id=snapshot_id_batch) self._thread_local.context = context with self._contexts_lock: diff --git a/tests/core/engine_adapter/integration/test_integration.py b/tests/core/engine_adapter/integration/test_integration.py index 29287f4f46..19a45329d5 100644 --- a/tests/core/engine_adapter/integration/test_integration.py +++ b/tests/core/engine_adapter/integration/test_integration.py @@ -2395,6 +2395,7 @@ def capture_execution_stats( num_audits_failed, audit_only=False, execution_stats=None, + auto_restatement_triggers=None, ): if execution_stats is not None: actual_execution_stats[snapshot.model.name.replace(f"{schema_name}.", "")] = ( diff --git a/tests/core/engine_adapter/integration/test_integration_snowflake.py b/tests/core/engine_adapter/integration/test_integration_snowflake.py index 7fa7366a46..d81ab38233 100644 --- a/tests/core/engine_adapter/integration/test_integration_snowflake.py +++ b/tests/core/engine_adapter/integration/test_integration_snowflake.py @@ -12,8 +12,12 @@ from sqlmesh.core.plan import Plan from tests.core.engine_adapter.integration import TestContext from sqlmesh import model, ExecutionContext +from pytest_mock import MockerFixture +from sqlmesh.core.snapshot.execution_tracker import ( + QueryExecutionContext, + QueryExecutionTracker, +) from sqlmesh.core.model import ModelKindName -from sqlmesh.core.snapshot.execution_tracker import QueryExecutionTracker from datetime import datetime from tests.core.engine_adapter.integration import ( @@ -310,10 +314,14 @@ def fetch_database_names() -> t.Set[str]: assert fetch_database_names() == {non_sqlmesh_managed_catalog} -def test_rows_tracker(ctx: TestContext, engine_adapter: SnowflakeEngineAdapter): +def test_rows_tracker( + ctx: TestContext, engine_adapter: SnowflakeEngineAdapter, mocker: MockerFixture +): sqlmesh = ctx.create_context() tracker = QueryExecutionTracker() + add_execution_spy = mocker.spy(QueryExecutionContext, "add_execution") + with tracker.track_execution("a"): # Snowflake doesn't 
report row counts for CTAS, so this should not be tracked engine_adapter.execute( @@ -322,7 +330,8 @@ def test_rows_tracker(ctx: TestContext, engine_adapter: SnowflakeEngineAdapter): engine_adapter.execute("INSERT INTO a VALUES (2), (3)", track_rows_processed=True) engine_adapter.execute("INSERT INTO a VALUES (4)", track_rows_processed=True) + assert add_execution_spy.call_count == 2 + stats = tracker.get_execution_stats("a") assert stats is not None - assert stats.query_count == 2 assert stats.total_rows_processed == 3 diff --git a/tests/core/test_execution_tracker.py b/tests/core/test_execution_tracker.py index 3afe56df16..79479f808f 100644 --- a/tests/core/test_execution_tracker.py +++ b/tests/core/test_execution_tracker.py @@ -34,6 +34,4 @@ def worker(id: str, row_counts: list[int]) -> QueryExecutionStats: by_batch = {s.snapshot_batch_id: s for s in results} assert by_batch["batch_A"].total_rows_processed == 15 - assert by_batch["batch_A"].query_count == 2 assert by_batch["batch_B"].total_rows_processed == 10 - assert by_batch["batch_B"].query_count == 2 From d947fec932003a7edb1333b1b7085a8a733544b8 Mon Sep 17 00:00:00 2001 From: Trey Spiller Date: Mon, 25 Aug 2025 11:13:13 -0500 Subject: [PATCH 26/31] PR feedback --- sqlmesh/core/console.py | 41 ++----------------- sqlmesh/core/engine_adapter/base.py | 8 +--- sqlmesh/core/scheduler.py | 3 +- sqlmesh/core/snapshot/__init__.py | 1 + sqlmesh/core/snapshot/definition.py | 5 +++ sqlmesh/core/snapshot/evaluator.py | 5 ++- sqlmesh/core/snapshot/execution_tracker.py | 23 +++++++---- .../integration/test_integration_snowflake.py | 9 +++- tests/core/test_execution_tracker.py | 33 ++++++++++----- 9 files changed, 61 insertions(+), 67 deletions(-) diff --git a/sqlmesh/core/console.py b/sqlmesh/core/console.py index 9fdc643905..7ed4da1ede 100644 --- a/sqlmesh/core/console.py +++ b/sqlmesh/core/console.py @@ -9,6 +9,7 @@ import textwrap from itertools import zip_longest from pathlib import Path +from humanize import naturalsize, metric from hyperscript import h from rich.console import Console as RichConsole from rich.live import Live @@ -4187,14 +4188,14 @@ def _create_evaluation_model_annotation( if execution_stats: rows_processed = execution_stats.total_rows_processed execution_stats_str += ( - f"{_abbreviate_integer_count(rows_processed)} row{'s' if rows_processed != 1 else ''}" + f"{metric(rows_processed)} row{'s' if rows_processed != 1 else ''}" if rows_processed is not None and rows_processed >= 0 else "" ) bytes_processed = execution_stats.total_bytes_processed execution_stats_str += ( - f"{', ' if execution_stats_str else ''}{_format_bytes(bytes_processed)}" + f"{', ' if execution_stats_str else ''}{naturalsize(bytes_processed, binary=True)}" if bytes_processed is not None and bytes_processed >= 0 else "" ) @@ -4299,39 +4300,3 @@ def _calculate_annotation_str_len( + execution_stats_len, ) return annotation_str_len - - -# Convert number of bytes to a human-readable string -# https://github.com/dbt-labs/dbt-adapters/blob/34fd178539dcb6f82e18e738adc03de7784c032f/dbt-bigquery/src/dbt/adapters/bigquery/connections.py#L165 -def _format_bytes(num_bytes: t.Optional[int]) -> str: - if num_bytes is not None and num_bytes >= 0: - if num_bytes < 1024: - return f"{num_bytes} bytes" - - num_bytes_float = float(num_bytes) / 1024.0 - for unit in ["KiB", "MiB", "GiB", "TiB", "PiB"]: - if num_bytes_float < 1024.0: - return f"{num_bytes_float:3.1f} {unit}" - num_bytes_float /= 1024.0 - - num_bytes_float *= 1024.0 # undo last division in loop - return 
f"{num_bytes_float:3.1f} {unit}" - return "" - - -# Abbreviate integer count. Example: 1,000,000,000 -> 1b -# https://github.com/dbt-labs/dbt-adapters/blob/34fd178539dcb6f82e18e738adc03de7784c032f/dbt-bigquery/src/dbt/adapters/bigquery/connections.py#L178 -def _abbreviate_integer_count(count: t.Optional[int]) -> str: - if count is not None and count >= 0: - if count < 1000: - return str(count) - - count_float = float(count) / 1000.0 - for unit in ["k", "m", "b", "t"]: - if count_float < 1000.0: - return f"{count_float:3.1f}{unit}".strip() - count_float /= 1000.0 - - count_float *= 1000.0 # undo last division in loop - return f"{count_float:3.1f}{unit}".strip() - return "" diff --git a/sqlmesh/core/engine_adapter/base.py b/sqlmesh/core/engine_adapter/base.py index 721e7c61c1..530d9ce876 100644 --- a/sqlmesh/core/engine_adapter/base.py +++ b/sqlmesh/core/engine_adapter/base.py @@ -2458,16 +2458,12 @@ def _execute(self, sql: str, track_rows_processed: bool = False, **kwargs: t.Any and track_rows_processed and QueryExecutionTracker.is_tracking() ): - rowcount_raw = getattr(self.cursor, "rowcount", None) - rowcount = None - if rowcount_raw is not None: + if (rowcount := getattr(self.cursor, "rowcount", None)) and rowcount is not None: try: - rowcount = int(rowcount_raw) + self._record_execution_stats(sql, int(rowcount)) except (TypeError, ValueError): return - self._record_execution_stats(sql, rowcount) - @contextlib.contextmanager def temp_table( self, diff --git a/sqlmesh/core/scheduler.py b/sqlmesh/core/scheduler.py index ce27ee1838..210aff230d 100644 --- a/sqlmesh/core/scheduler.py +++ b/sqlmesh/core/scheduler.py @@ -20,6 +20,7 @@ DeployabilityIndex, Snapshot, SnapshotId, + SnapshotIdBatch, SnapshotEvaluator, apply_auto_restatements, earliest_start_date, @@ -533,7 +534,7 @@ def run_node(node: SchedulingUnit) -> None: num_audits_failed = sum(1 for result in audit_results if result.count) execution_stats = self.snapshot_evaluator.execution_tracker.get_execution_stats( - f"{snapshot.snapshot_id}_{node.batch_index}" + SnapshotIdBatch(snapshot_id=snapshot.snapshot_id, batch_id=node.batch_index) ) self.console.update_snapshot_evaluation_progress( diff --git a/sqlmesh/core/snapshot/__init__.py b/sqlmesh/core/snapshot/__init__.py index da44278aa8..8ad574f8ac 100644 --- a/sqlmesh/core/snapshot/__init__.py +++ b/sqlmesh/core/snapshot/__init__.py @@ -8,6 +8,7 @@ SnapshotDataVersion as SnapshotDataVersion, SnapshotFingerprint as SnapshotFingerprint, SnapshotId as SnapshotId, + SnapshotIdBatch as SnapshotIdBatch, SnapshotIdLike as SnapshotIdLike, SnapshotInfoLike as SnapshotInfoLike, SnapshotIntervals as SnapshotIntervals, diff --git a/sqlmesh/core/snapshot/definition.py b/sqlmesh/core/snapshot/definition.py index 5a9ad60166..afc8e06458 100644 --- a/sqlmesh/core/snapshot/definition.py +++ b/sqlmesh/core/snapshot/definition.py @@ -162,6 +162,11 @@ def __str__(self) -> str: return f"SnapshotId<{self.name}: {self.identifier}>" +class SnapshotIdBatch(PydanticModel, frozen=True): + snapshot_id: SnapshotId + batch_id: int + + class SnapshotNameVersion(PydanticModel, frozen=True): name: str version: str diff --git a/sqlmesh/core/snapshot/evaluator.py b/sqlmesh/core/snapshot/evaluator.py index f1569d0163..cef6b825e6 100644 --- a/sqlmesh/core/snapshot/evaluator.py +++ b/sqlmesh/core/snapshot/evaluator.py @@ -61,6 +61,7 @@ Intervals, Snapshot, SnapshotId, + SnapshotIdBatch, SnapshotInfoLike, SnapshotTableCleanupTask, ) @@ -171,7 +172,9 @@ def evaluate( Returns: The WAP ID of this evaluation if supported, None 
otherwise. """ - with self.execution_tracker.track_execution(f"{snapshot.snapshot_id}_{batch_index}"): + with self.execution_tracker.track_execution( + SnapshotIdBatch(snapshot_id=snapshot.snapshot_id, batch_id=batch_index) + ): result = self._evaluate_snapshot( start=start, end=end, diff --git a/sqlmesh/core/snapshot/execution_tracker.py b/sqlmesh/core/snapshot/execution_tracker.py index 61c10696ce..fb2af1d5dc 100644 --- a/sqlmesh/core/snapshot/execution_tracker.py +++ b/sqlmesh/core/snapshot/execution_tracker.py @@ -4,11 +4,12 @@ from contextlib import contextmanager from threading import local, Lock from dataclasses import dataclass, field +from sqlmesh.core.snapshot import SnapshotIdBatch @dataclass class QueryExecutionStats: - snapshot_batch_id: str + snapshot_id_batch: SnapshotIdBatch total_rows_processed: t.Optional[int] = None total_bytes_processed: t.Optional[int] = None @@ -21,15 +22,15 @@ class QueryExecutionContext: It accumulates statistics from multiple cursor.execute() calls during a single snapshot evaluation. Attributes: - snapshot_batch_id: Identifier linking this context to a specific snapshot evaluation + snapshot_id_batch: Identifier linking this context to a specific snapshot evaluation stats: Running sum of cursor.rowcount and possibly bytes processed from all executed queries during evaluation """ - snapshot_batch_id: str + snapshot_id_batch: SnapshotIdBatch stats: QueryExecutionStats = field(init=False) def __post_init__(self) -> None: - self.stats = QueryExecutionStats(snapshot_batch_id=self.snapshot_batch_id) + self.stats = QueryExecutionStats(snapshot_id_batch=self.snapshot_id_batch) def add_execution( self, sql: str, row_count: t.Optional[int], bytes_processed: t.Optional[int] @@ -56,10 +57,12 @@ class QueryExecutionTracker: """Thread-local context manager for snapshot execution statistics, such as rows processed.""" _thread_local = local() - _contexts: t.Dict[str, QueryExecutionContext] = {} + _contexts: t.Dict[SnapshotIdBatch, QueryExecutionContext] = {} _contexts_lock = Lock() - def get_execution_context(self, snapshot_id_batch: str) -> t.Optional[QueryExecutionContext]: + def get_execution_context( + self, snapshot_id_batch: SnapshotIdBatch + ) -> t.Optional[QueryExecutionContext]: with self._contexts_lock: return self._contexts.get(snapshot_id_batch) @@ -69,10 +72,10 @@ def is_tracking(cls) -> bool: @contextmanager def track_execution( - self, snapshot_id_batch: str + self, snapshot_id_batch: SnapshotIdBatch ) -> t.Iterator[t.Optional[QueryExecutionContext]]: """Context manager for tracking snapshot execution statistics such as row counts and bytes processed.""" - context = QueryExecutionContext(snapshot_batch_id=snapshot_id_batch) + context = QueryExecutionContext(snapshot_id_batch=snapshot_id_batch) self._thread_local.context = context with self._contexts_lock: self._contexts[snapshot_id_batch] = context @@ -90,7 +93,9 @@ def record_execution( if context is not None: context.add_execution(sql, row_count, bytes_processed) - def get_execution_stats(self, snapshot_id_batch: str) -> t.Optional[QueryExecutionStats]: + def get_execution_stats( + self, snapshot_id_batch: SnapshotIdBatch + ) -> t.Optional[QueryExecutionStats]: with self._contexts_lock: context = self._contexts.get(snapshot_id_batch) self._contexts.pop(snapshot_id_batch, None) diff --git a/tests/core/engine_adapter/integration/test_integration_snowflake.py b/tests/core/engine_adapter/integration/test_integration_snowflake.py index d81ab38233..6c1498d889 100644 --- 
a/tests/core/engine_adapter/integration/test_integration_snowflake.py +++ b/tests/core/engine_adapter/integration/test_integration_snowflake.py @@ -13,6 +13,7 @@ from tests.core.engine_adapter.integration import TestContext from sqlmesh import model, ExecutionContext from pytest_mock import MockerFixture +from sqlmesh.core.snapshot import SnapshotId, SnapshotIdBatch from sqlmesh.core.snapshot.execution_tracker import ( QueryExecutionContext, QueryExecutionTracker, @@ -322,7 +323,9 @@ def test_rows_tracker( add_execution_spy = mocker.spy(QueryExecutionContext, "add_execution") - with tracker.track_execution("a"): + with tracker.track_execution( + SnapshotIdBatch(snapshot_id=SnapshotId(name="a", identifier="a"), batch_id=0) + ): # Snowflake doesn't report row counts for CTAS, so this should not be tracked engine_adapter.execute( "CREATE TABLE a (id int) AS SELECT 1 as id", track_rows_processed=True @@ -332,6 +335,8 @@ def test_rows_tracker( assert add_execution_spy.call_count == 2 - stats = tracker.get_execution_stats("a") + stats = tracker.get_execution_stats( + SnapshotIdBatch(snapshot_id=SnapshotId(name="a", identifier="a"), batch_id=0) + ) assert stats is not None assert stats.total_rows_processed == 3 diff --git a/tests/core/test_execution_tracker.py b/tests/core/test_execution_tracker.py index 79479f808f..0e58395bee 100644 --- a/tests/core/test_execution_tracker.py +++ b/tests/core/test_execution_tracker.py @@ -3,11 +3,12 @@ from concurrent.futures import ThreadPoolExecutor from sqlmesh.core.snapshot.execution_tracker import QueryExecutionStats, QueryExecutionTracker +from sqlmesh.core.snapshot import SnapshotIdBatch, SnapshotId def test_execution_tracker_thread_isolation() -> None: - def worker(id: str, row_counts: list[int]) -> QueryExecutionStats: - with execution_tracker.track_execution(id) as ctx: + def worker(id: SnapshotId, row_counts: list[int]) -> QueryExecutionStats: + with execution_tracker.track_execution(SnapshotIdBatch(snapshot_id=id, batch_id=0)) as ctx: assert execution_tracker.is_tracking() for count in row_counts: @@ -20,18 +21,30 @@ def worker(id: str, row_counts: list[int]) -> QueryExecutionStats: with ThreadPoolExecutor() as executor: futures = [ - executor.submit(worker, "batch_A", [10, 5]), - executor.submit(worker, "batch_B", [3, 7]), + executor.submit(worker, SnapshotId(name="batch_A", identifier="batch_A"), [10, 5]), + executor.submit(worker, SnapshotId(name="batch_B", identifier="batch_B"), [3, 7]), ] results = [f.result() for f in futures] # Main thread has no active tracking context assert not execution_tracker.is_tracking() - execution_tracker.record_execution("q", 10, None) - assert execution_tracker.get_execution_stats("q") is None # Order of results is not deterministic, so look up by id - by_batch = {s.snapshot_batch_id: s for s in results} - - assert by_batch["batch_A"].total_rows_processed == 15 - assert by_batch["batch_B"].total_rows_processed == 10 + by_batch = {s.snapshot_id_batch: s for s in results} + + assert ( + by_batch[ + SnapshotIdBatch( + snapshot_id=SnapshotId(name="batch_A", identifier="batch_A"), batch_id=0 + ) + ].total_rows_processed + == 15 + ) + assert ( + by_batch[ + SnapshotIdBatch( + snapshot_id=SnapshotId(name="batch_B", identifier="batch_B"), batch_id=0 + ) + ].total_rows_processed + == 10 + ) From 2f6f11a32822e518291c6194b4db1dc7f872ad81 Mon Sep 17 00:00:00 2001 From: Trey Spiller Date: Mon, 25 Aug 2025 13:44:40 -0500 Subject: [PATCH 27/31] Remove humanize functions --- sqlmesh/core/console.py | 49 
+++++++++++++++++++++++++---- sqlmesh/core/engine_adapter/base.py | 4 ++- 2 files changed, 46 insertions(+), 7 deletions(-) diff --git a/sqlmesh/core/console.py b/sqlmesh/core/console.py index 7ed4da1ede..4d3af6c2dc 100644 --- a/sqlmesh/core/console.py +++ b/sqlmesh/core/console.py @@ -9,7 +9,6 @@ import textwrap from itertools import zip_longest from pathlib import Path -from humanize import naturalsize, metric from hyperscript import h from rich.console import Console as RichConsole from rich.live import Live @@ -4031,7 +4030,9 @@ def show_table_diff_summary(self, table_diff: TableDiff) -> None: self._write(f"Join On: {keys}") -_CONSOLE: Console = NoopConsole() +# TODO: remove this +# _CONSOLE: Console = NoopConsole() +_CONSOLE: Console = TerminalConsole() def set_console(console: Console) -> None: @@ -4188,15 +4189,15 @@ def _create_evaluation_model_annotation( if execution_stats: rows_processed = execution_stats.total_rows_processed execution_stats_str += ( - f"{metric(rows_processed)} row{'s' if rows_processed != 1 else ''}" - if rows_processed is not None and rows_processed >= 0 + f"{_abbreviate_integer_count(rows_processed)} row{'s' if rows_processed > 1 else ''}" + if rows_processed else "" ) bytes_processed = execution_stats.total_bytes_processed execution_stats_str += ( - f"{', ' if execution_stats_str else ''}{naturalsize(bytes_processed, binary=True)}" - if bytes_processed is not None and bytes_processed >= 0 + f"{', ' if execution_stats_str else ''}{_format_bytes(bytes_processed)}" + if bytes_processed else "" ) execution_stats_str = f" ({execution_stats_str})" if execution_stats_str else "" @@ -4300,3 +4301,39 @@ def _calculate_annotation_str_len( + execution_stats_len, ) return annotation_str_len + + +# Convert number of bytes to a human-readable string +# https://github.com/dbt-labs/dbt-adapters/blob/34fd178539dcb6f82e18e738adc03de7784c032f/dbt-bigquery/src/dbt/adapters/bigquery/connections.py#L165 +def _format_bytes(num_bytes: t.Optional[int]) -> str: + if num_bytes and num_bytes >= 0: + if num_bytes < 1024: + return f"{num_bytes} bytes" + + num_bytes_float = float(num_bytes) / 1024.0 + for unit in ["KiB", "MiB", "GiB", "TiB", "PiB"]: + if num_bytes_float < 1024.0: + return f"{num_bytes_float:3.1f} {unit}" + num_bytes_float /= 1024.0 + + num_bytes_float *= 1024.0 # undo last division in loop + return f"{num_bytes_float:3.1f} {unit}" + return "" + + +# Abbreviate integer count. 
Example: 1,000,000,000 -> 1b +# https://github.com/dbt-labs/dbt-adapters/blob/34fd178539dcb6f82e18e738adc03de7784c032f/dbt-bigquery/src/dbt/adapters/bigquery/connections.py#L178 +def _abbreviate_integer_count(count: t.Optional[int]) -> str: + if count and count >= 0: + if count < 1000: + return str(count) + + count_float = float(count) / 1000.0 + for unit in ["k", "m", "b", "t"]: + if count_float < 1000.0: + return f"{count_float:3.1f}{unit}".strip() + count_float /= 1000.0 + + count_float *= 1000.0 # undo last division in loop + return f"{count_float:3.1f}{unit}".strip() + return "" diff --git a/sqlmesh/core/engine_adapter/base.py b/sqlmesh/core/engine_adapter/base.py index 530d9ce876..3da3d73d5d 100644 --- a/sqlmesh/core/engine_adapter/base.py +++ b/sqlmesh/core/engine_adapter/base.py @@ -2458,7 +2458,9 @@ def _execute(self, sql: str, track_rows_processed: bool = False, **kwargs: t.Any and track_rows_processed and QueryExecutionTracker.is_tracking() ): - if (rowcount := getattr(self.cursor, "rowcount", None)) and rowcount is not None: + if ( + rowcount := getattr(self.cursor, "rowcount", None) + ) is not None and rowcount is not None: try: self._record_execution_stats(sql, int(rowcount)) except (TypeError, ValueError): From 0920f39cb7915e234dc17fe0083aa6d6a78bce46 Mon Sep 17 00:00:00 2001 From: Trey Spiller Date: Tue, 26 Aug 2025 10:28:45 -0500 Subject: [PATCH 28/31] Make tracking fully instance-based by passing to engine adapter --- .circleci/continue_config.yml | 12 +++--- sqlmesh/core/engine_adapter/base.py | 8 +++- sqlmesh/core/engine_adapter/bigquery.py | 11 +++-- sqlmesh/core/engine_adapter/snowflake.py | 42 +------------------ sqlmesh/core/snapshot/evaluator.py | 6 ++- sqlmesh/core/snapshot/execution_tracker.py | 29 ++++++------- .../integration/test_integration_snowflake.py | 11 ++--- 7 files changed, 42 insertions(+), 77 deletions(-) diff --git a/.circleci/continue_config.yml b/.circleci/continue_config.yml index e21f3d869b..5b0db2a5bb 100644 --- a/.circleci/continue_config.yml +++ b/.circleci/continue_config.yml @@ -297,8 +297,8 @@ workflows: name: cloud_engine_<< matrix.engine >> context: - sqlmesh_cloud_database_integration - requires: - - engine_tests_docker + # requires: + # - engine_tests_docker matrix: parameters: engine: @@ -310,10 +310,10 @@ workflows: - athena - fabric - gcp-postgres - filters: - branches: - only: - - main + # filters: + # branches: + # only: + # - main - ui_style - ui_test - vscode_test diff --git a/sqlmesh/core/engine_adapter/base.py b/sqlmesh/core/engine_adapter/base.py index 3da3d73d5d..878b7c6aca 100644 --- a/sqlmesh/core/engine_adapter/base.py +++ b/sqlmesh/core/engine_adapter/base.py @@ -135,6 +135,7 @@ def __init__( shared_connection: bool = False, correlation_id: t.Optional[CorrelationId] = None, schema_differ_overrides: t.Optional[t.Dict[str, t.Any]] = None, + query_execution_tracker: t.Optional[QueryExecutionTracker] = None, **kwargs: t.Any, ): self.dialect = dialect.lower() or self.DIALECT @@ -158,6 +159,7 @@ def __init__( self._multithreaded = multithreaded self.correlation_id = correlation_id self._schema_differ_overrides = schema_differ_overrides + self._query_execution_tracker = query_execution_tracker def with_settings(self, **kwargs: t.Any) -> EngineAdapter: extra_kwargs = { @@ -2448,7 +2450,8 @@ def _log_sql( def _record_execution_stats( self, sql: str, rowcount: t.Optional[int] = None, bytes_processed: t.Optional[int] = None ) -> None: - QueryExecutionTracker.record_execution(sql, rowcount, bytes_processed) + if 
self._query_execution_tracker: + self._query_execution_tracker.record_execution(sql, rowcount, bytes_processed) def _execute(self, sql: str, track_rows_processed: bool = False, **kwargs: t.Any) -> None: self.cursor.execute(sql, **kwargs) @@ -2456,7 +2459,8 @@ def _execute(self, sql: str, track_rows_processed: bool = False, **kwargs: t.Any if ( self.SUPPORTS_QUERY_EXECUTION_TRACKING and track_rows_processed - and QueryExecutionTracker.is_tracking() + and self._query_execution_tracker + and self._query_execution_tracker.is_tracking() ): if ( rowcount := getattr(self.cursor, "rowcount", None) diff --git a/sqlmesh/core/engine_adapter/bigquery.py b/sqlmesh/core/engine_adapter/bigquery.py index 679cff05ec..b3d02d8bbf 100644 --- a/sqlmesh/core/engine_adapter/bigquery.py +++ b/sqlmesh/core/engine_adapter/bigquery.py @@ -23,7 +23,6 @@ ) from sqlmesh.core.node import IntervalUnit from sqlmesh.core.schema_diff import TableAlterOperation, NestedSupport -from sqlmesh.core.snapshot.execution_tracker import QueryExecutionTracker from sqlmesh.utils import optional_import, get_source_columns_to_types from sqlmesh.utils.date import to_datetime from sqlmesh.utils.errors import SQLMeshError @@ -1097,7 +1096,11 @@ def _execute( self.cursor._set_rowcount(query_results) self.cursor._set_description(query_results.schema) - if track_rows_processed and QueryExecutionTracker.is_tracking(): + if ( + track_rows_processed + and self._query_execution_tracker + and self._query_execution_tracker.is_tracking() + ): num_rows = None if query_job.statement_type == "CREATE_TABLE_AS_SELECT": # since table was just created, number rows in table == number rows processed @@ -1106,7 +1109,9 @@ def _execute( elif query_job.statement_type in ["INSERT", "DELETE", "MERGE", "UPDATE"]: num_rows = query_job.num_dml_affected_rows - QueryExecutionTracker.record_execution(sql, num_rows, query_job.total_bytes_processed) + self._query_execution_tracker.record_execution( + sql, num_rows, query_job.total_bytes_processed + ) def _get_data_objects( self, schema_name: SchemaName, object_names: t.Optional[t.Set[str]] = None diff --git a/sqlmesh/core/engine_adapter/snowflake.py b/sqlmesh/core/engine_adapter/snowflake.py index 8e74eb67d0..dca867e0ee 100644 --- a/sqlmesh/core/engine_adapter/snowflake.py +++ b/sqlmesh/core/engine_adapter/snowflake.py @@ -2,7 +2,6 @@ import contextlib import logging -import re import typing as t from sqlglot import exp @@ -24,7 +23,6 @@ SourceQuery, set_catalog, ) -from sqlmesh.core.snapshot.execution_tracker import QueryExecutionTracker from sqlmesh.utils import optional_import, get_source_columns_to_types from sqlmesh.utils.errors import SQLMeshError from sqlmesh.utils.pandas import columns_to_types_from_dtypes @@ -189,7 +187,7 @@ def _create_table( table_description=table_description, column_descriptions=column_descriptions, table_kind=table_kind, - track_rows_processed=track_rows_processed, + track_rows_processed=False, **kwargs, ) @@ -667,41 +665,3 @@ def close(self) -> t.Any: self._connection_pool.set_attribute(self.SNOWPARK, None) return super().close() - - def _record_execution_stats( - self, sql: str, rowcount: t.Optional[int] = None, bytes_processed: t.Optional[int] = None - ) -> None: - """Snowflake does not report row counts for CTAS like other DML operations. - - They neither report the sentinel value -1 nor do they report 0 rows. Instead, they report a rowcount - of 1 and return a single data row containing one of the strings: - - "Table successfully created." - - " already exists, statement succeeded." 
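For context, the BigQuery branch above derives the row count from the completed QueryJob rather than from cursor.rowcount. A condensed sketch of that decision follows; resolving the CTAS count via `client.get_table` is an assumption for illustration, not necessarily how the adapter does it:

def bigquery_rows_processed(client, query_job):
    # DML statements report affected rows directly on the finished job.
    if query_job.statement_type in ("INSERT", "DELETE", "MERGE", "UPDATE"):
        return query_job.num_dml_affected_rows
    # CTAS: the destination table was just created, so its current row count
    # equals the number of rows processed (assumption: read it back via get_table).
    if query_job.statement_type == "CREATE_TABLE_AS_SELECT" and query_job.destination:
        return client.get_table(query_job.destination).num_rows
    return None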
- - We do not want to record the incorrect row count of 1, so we check whether that row contains the table - successfully created string. If so, we return early and do not record the row count. - - Ref: https://github.com/snowflakedb/snowflake-connector-python/issues/645 - """ - if rowcount == 1: - results = self.cursor.fetchone() - if results: - try: - results_str = str(results[0]) - except (TypeError, ValueError, IndexError): - return - - # Snowflake identifiers may be: - # - An unquoted contiguous set of [a-zA-Z0-9_$] characters - # - A double-quoted string that may contain spaces and nested double-quotes represented by `""`. Example: " my ""table"" name " - is_created = re.match( - r'Table [a-zA-Z0-9_$ "]*? successfully created\.', results_str - ) - is_already_exists = re.match( - r'[a-zA-Z0-9_$ "]*? already exists, statement succeeded\.', - results_str, - ) - if is_created or is_already_exists: - return - - QueryExecutionTracker.record_execution(sql, rowcount, bytes_processed) diff --git a/sqlmesh/core/snapshot/evaluator.py b/sqlmesh/core/snapshot/evaluator.py index cef6b825e6..90186faba7 100644 --- a/sqlmesh/core/snapshot/evaluator.py +++ b/sqlmesh/core/snapshot/evaluator.py @@ -130,6 +130,11 @@ def __init__( self.adapters = ( adapters if isinstance(adapters, t.Dict) else {selected_gateway or "": adapters} ) + self.execution_tracker = QueryExecutionTracker() + self.adapters = { + gateway: adapter.with_settings(query_execution_tracker=self.execution_tracker) + for gateway, adapter in self.adapters.items() + } self.adapter = ( next(iter(self.adapters.values())) if not selected_gateway @@ -137,7 +142,6 @@ def __init__( ) self.selected_gateway = selected_gateway self.ddl_concurrent_tasks = ddl_concurrent_tasks - self.execution_tracker = QueryExecutionTracker() def evaluate( self, diff --git a/sqlmesh/core/snapshot/execution_tracker.py b/sqlmesh/core/snapshot/execution_tracker.py index fb2af1d5dc..bcafec8d28 100644 --- a/sqlmesh/core/snapshot/execution_tracker.py +++ b/sqlmesh/core/snapshot/execution_tracker.py @@ -2,7 +2,7 @@ import typing as t from contextlib import contextmanager -from threading import local, Lock +from threading import local from dataclasses import dataclass, field from sqlmesh.core.snapshot import SnapshotIdBatch @@ -56,19 +56,17 @@ def get_execution_stats(self) -> QueryExecutionStats: class QueryExecutionTracker: """Thread-local context manager for snapshot execution statistics, such as rows processed.""" - _thread_local = local() - _contexts: t.Dict[SnapshotIdBatch, QueryExecutionContext] = {} - _contexts_lock = Lock() + def __init__(self) -> None: + self._thread_local = local() + self._contexts: t.Dict[SnapshotIdBatch, QueryExecutionContext] = {} def get_execution_context( self, snapshot_id_batch: SnapshotIdBatch ) -> t.Optional[QueryExecutionContext]: - with self._contexts_lock: - return self._contexts.get(snapshot_id_batch) + return self._contexts.get(snapshot_id_batch) - @classmethod - def is_tracking(cls) -> bool: - return getattr(cls._thread_local, "context", None) is not None + def is_tracking(self) -> bool: + return getattr(self._thread_local, "context", None) is not None @contextmanager def track_execution( @@ -77,26 +75,23 @@ def track_execution( """Context manager for tracking snapshot execution statistics such as row counts and bytes processed.""" context = QueryExecutionContext(snapshot_id_batch=snapshot_id_batch) self._thread_local.context = context - with self._contexts_lock: - self._contexts[snapshot_id_batch] = context + 
self._contexts[snapshot_id_batch] = context try: yield context finally: self._thread_local.context = None - @classmethod def record_execution( - cls, sql: str, row_count: t.Optional[int], bytes_processed: t.Optional[int] + self, sql: str, row_count: t.Optional[int], bytes_processed: t.Optional[int] ) -> None: - context = getattr(cls._thread_local, "context", None) + context = getattr(self._thread_local, "context", None) if context is not None: context.add_execution(sql, row_count, bytes_processed) def get_execution_stats( self, snapshot_id_batch: SnapshotIdBatch ) -> t.Optional[QueryExecutionStats]: - with self._contexts_lock: - context = self._contexts.get(snapshot_id_batch) - self._contexts.pop(snapshot_id_batch, None) + context = self._contexts.get(snapshot_id_batch) + self._contexts.pop(snapshot_id_batch, None) return context.get_execution_stats() if context else None diff --git a/tests/core/engine_adapter/integration/test_integration_snowflake.py b/tests/core/engine_adapter/integration/test_integration_snowflake.py index 6c1498d889..aed6bf83e4 100644 --- a/tests/core/engine_adapter/integration/test_integration_snowflake.py +++ b/tests/core/engine_adapter/integration/test_integration_snowflake.py @@ -327,16 +327,13 @@ def test_rows_tracker( SnapshotIdBatch(snapshot_id=SnapshotId(name="a", identifier="a"), batch_id=0) ): # Snowflake doesn't report row counts for CTAS, so this should not be tracked - engine_adapter.execute( - "CREATE TABLE a (id int) AS SELECT 1 as id", track_rows_processed=True - ) - engine_adapter.execute("INSERT INTO a VALUES (2), (3)", track_rows_processed=True) - engine_adapter.execute("INSERT INTO a VALUES (4)", track_rows_processed=True) + engine_adapter._create_table("a", exp.select("1 as id")) - assert add_execution_spy.call_count == 2 + assert add_execution_spy.call_count == 0 stats = tracker.get_execution_stats( SnapshotIdBatch(snapshot_id=SnapshotId(name="a", identifier="a"), batch_id=0) ) assert stats is not None - assert stats.total_rows_processed == 3 + assert stats.total_rows_processed is None + assert stats.total_bytes_processed is None From 8e8ddabab98a306c46f0b79ef2e4131b7571494a Mon Sep 17 00:00:00 2001 From: Trey Spiller Date: Tue, 26 Aug 2025 11:10:23 -0500 Subject: [PATCH 29/31] Fix snapshot evaluator tests --- sqlmesh/core/engine_adapter/snowflake.py | 2 +- tests/core/test_snapshot_evaluator.py | 17 +++++++++++++++-- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/sqlmesh/core/engine_adapter/snowflake.py b/sqlmesh/core/engine_adapter/snowflake.py index dca867e0ee..8a6f5e2fcc 100644 --- a/sqlmesh/core/engine_adapter/snowflake.py +++ b/sqlmesh/core/engine_adapter/snowflake.py @@ -187,7 +187,7 @@ def _create_table( table_description=table_description, column_descriptions=column_descriptions, table_kind=table_kind, - track_rows_processed=False, + track_rows_processed=False, # snowflake tracks CTAS row counts incorrectly **kwargs, ) diff --git a/tests/core/test_snapshot_evaluator.py b/tests/core/test_snapshot_evaluator.py index c25acd279e..3b72a14f5f 100644 --- a/tests/core/test_snapshot_evaluator.py +++ b/tests/core/test_snapshot_evaluator.py @@ -115,6 +115,7 @@ def mock_exit(self, exc_type, exc_value, traceback): adapter_mock.HAS_VIEW_BINDING = False adapter_mock.wap_supported.return_value = False adapter_mock.get_data_objects.return_value = [] + adapter_mock.with_settings.return_value = adapter_mock return adapter_mock @@ -137,6 +138,7 @@ def adapters(mocker: MockerFixture): adapter_mock.HAS_VIEW_BINDING = False 
adapter_mock.wap_supported.return_value = False adapter_mock.get_data_objects.return_value = [] + adapter_mock.with_settings.return_value = adapter_mock adapters.append(adapter_mock) return adapters @@ -652,6 +654,7 @@ def test_evaluate_materialized_view_with_partitioned_by_cluster_by( adapter.table_exists = lambda *args, **kwargs: False # type: ignore adapter.get_data_objects = lambda *args, **kwargs: [] # type: ignore adapter._execute = execute_mock # type: ignore + adapter.with_settings = lambda **kwargs: adapter # type: ignore evaluator = SnapshotEvaluator(adapter) model = SqlModel( @@ -992,6 +995,7 @@ def test_create_tables_exist( ): adapter_mock = mocker.patch("sqlmesh.core.engine_adapter.EngineAdapter") adapter_mock.dialect = "duckdb" + adapter_mock.with_settings.return_value = adapter_mock evaluator = SnapshotEvaluator(adapter_mock) snapshot.categorize_as(category=snapshot_category, forward_only=forward_only) @@ -1194,6 +1198,7 @@ def test_create_view_with_properties(mocker: MockerFixture, adapter_mock, make_s def test_promote_model_info(mocker: MockerFixture, make_snapshot): adapter_mock = mocker.patch("sqlmesh.core.engine_adapter.EngineAdapter") adapter_mock.dialect = "duckdb" + adapter_mock.with_settings.return_value = adapter_mock evaluator = SnapshotEvaluator(adapter_mock) @@ -1222,6 +1227,7 @@ def test_promote_model_info(mocker: MockerFixture, make_snapshot): def test_promote_deployable(mocker: MockerFixture, make_snapshot): adapter_mock = mocker.patch("sqlmesh.core.engine_adapter.EngineAdapter") adapter_mock.dialect = "duckdb" + adapter_mock.with_settings.return_value = adapter_mock evaluator = SnapshotEvaluator(adapter_mock) @@ -1267,6 +1273,7 @@ def test_promote_deployable(mocker: MockerFixture, make_snapshot): def test_migrate(mocker: MockerFixture, make_snapshot, make_mocked_engine_adapter): adapter = make_mocked_engine_adapter(EngineAdapter) + adapter.with_settings = lambda **kwargs: adapter # type: ignore session_spy = mocker.spy(adapter, "session") current_table = "sqlmesh__test_schema.test_schema__test_model__1" @@ -1322,6 +1329,7 @@ def columns(table_name): def test_migrate_missing_table(mocker: MockerFixture, make_snapshot, make_mocked_engine_adapter): adapter = make_mocked_engine_adapter(EngineAdapter) adapter.table_exists = lambda _: False # type: ignore + adapter.with_settings = lambda **kwargs: adapter # type: ignore mocker.patch.object(adapter, "get_data_object", return_value=None) evaluator = SnapshotEvaluator(adapter) @@ -1390,6 +1398,7 @@ def test_migrate_snapshot_data_object_type_mismatch( make_mocked_engine_adapter, ): adapter = make_mocked_engine_adapter(EngineAdapter) + adapter.with_settings = lambda **kwargs: adapter # type: ignore mocker.patch.object( adapter, "get_data_object", @@ -1804,7 +1813,7 @@ def test_on_destructive_change_runtime_check( make_mocked_engine_adapter, ): adapter = make_mocked_engine_adapter(EngineAdapter) - + adapter.with_settings = lambda **kwargs: adapter # type: ignore current_table = "sqlmesh__test_schema.test_schema__test_model__1" def columns(table_name): @@ -1886,7 +1895,7 @@ def test_on_additive_change_runtime_check( make_mocked_engine_adapter, ): adapter = make_mocked_engine_adapter(EngineAdapter) - + adapter.with_settings = lambda **kwargs: adapter # type: ignore current_table = "sqlmesh__test_schema.test_schema__test_model__1" def columns(table_name): @@ -3778,6 +3787,7 @@ def test_create_managed_forward_only_with_previous_version_doesnt_clone_for_dev_ def test_migrate_snapshot(snapshot: Snapshot, mocker: 
MockerFixture, adapter_mock, make_snapshot): adapter_mock = mocker.patch("sqlmesh.core.engine_adapter.EngineAdapter") adapter_mock.dialect = "duckdb" + adapter_mock.with_settings.return_value = adapter_mock evaluator = SnapshotEvaluator(adapter_mock) evaluator.create([snapshot], {}) @@ -3987,6 +3997,7 @@ def test_multiple_engine_promotion(mocker: MockerFixture, adapter_mock, make_sna cursor_mock = mocker.Mock() connection_mock.cursor.return_value = cursor_mock adapter = EngineAdapter(lambda: connection_mock, "") + adapter.with_settings = lambda **kwargs: adapter # type: ignore engine_adapters = {"default": adapter_mock, "secondary": adapter} def columns(table_name): @@ -4046,7 +4057,9 @@ def test_multiple_engine_migration( mocker: MockerFixture, adapter_mock, make_snapshot, make_mocked_engine_adapter ): adapter_one = make_mocked_engine_adapter(EngineAdapter) + adapter_one.with_settings = lambda **kwargs: adapter_one # type: ignore adapter_two = adapter_mock + adapter_two.with_settings.return_value = adapter_two engine_adapters = {"one": adapter_one, "two": adapter_two} current_table = "sqlmesh__test_schema.test_schema__test_model__1" From 61161ca70655129626cccdcdc02d0ecc0caa47f9 Mon Sep 17 00:00:00 2001 From: Trey Spiller Date: Tue, 26 Aug 2025 13:58:49 -0500 Subject: [PATCH 30/31] Retain previous with_settings settings --- .circleci/continue_config.yml | 12 ++++++------ sqlmesh/core/console.py | 4 +--- sqlmesh/core/engine_adapter/base.py | 4 ++++ tests/core/test_integration.py | 28 ++++++++++++++-------------- 4 files changed, 25 insertions(+), 23 deletions(-) diff --git a/.circleci/continue_config.yml b/.circleci/continue_config.yml index 5b0db2a5bb..e21f3d869b 100644 --- a/.circleci/continue_config.yml +++ b/.circleci/continue_config.yml @@ -297,8 +297,8 @@ workflows: name: cloud_engine_<< matrix.engine >> context: - sqlmesh_cloud_database_integration - # requires: - # - engine_tests_docker + requires: + - engine_tests_docker matrix: parameters: engine: @@ -310,10 +310,10 @@ workflows: - athena - fabric - gcp-postgres - # filters: - # branches: - # only: - # - main + filters: + branches: + only: + - main - ui_style - ui_test - vscode_test diff --git a/sqlmesh/core/console.py b/sqlmesh/core/console.py index 4d3af6c2dc..fa67b0549a 100644 --- a/sqlmesh/core/console.py +++ b/sqlmesh/core/console.py @@ -4030,9 +4030,7 @@ def show_table_diff_summary(self, table_diff: TableDiff) -> None: self._write(f"Join On: {keys}") -# TODO: remove this -# _CONSOLE: Console = NoopConsole() -_CONSOLE: Console = TerminalConsole() +_CONSOLE: Console = NoopConsole() def set_console(console: Console) -> None: diff --git a/sqlmesh/core/engine_adapter/base.py b/sqlmesh/core/engine_adapter/base.py index 878b7c6aca..2901831940 100644 --- a/sqlmesh/core/engine_adapter/base.py +++ b/sqlmesh/core/engine_adapter/base.py @@ -165,6 +165,10 @@ def with_settings(self, **kwargs: t.Any) -> EngineAdapter: extra_kwargs = { "null_connection": True, "execute_log_level": kwargs.pop("execute_log_level", self._execute_log_level), + "correlation_id": kwargs.pop("correlation_id", self.correlation_id), + "query_execution_tracker": kwargs.pop( + "query_execution_tracker", self._query_execution_tracker + ), **self._extra_config, **kwargs, } diff --git a/tests/core/test_integration.py b/tests/core/test_integration.py index f80c42f579..8cd50dc732 100644 --- a/tests/core/test_integration.py +++ b/tests/core/test_integration.py @@ -8022,7 +8022,7 @@ def test_incremental_by_time_model_ignore_additive_change(tmp_path: Path): cron '@daily' ); - 
SELECT + SELECT *, 1 as id, 'test_name' as name, @@ -8068,7 +8068,7 @@ def test_incremental_by_time_model_ignore_additive_change(tmp_path: Path): cron '@daily' ); - SELECT + SELECT *, 1 as id, 'other' as other_column, @@ -8124,7 +8124,7 @@ def test_incremental_by_time_model_ignore_additive_change(tmp_path: Path): cron '@daily' ); - SELECT + SELECT *, CAST(1 AS STRING) as id, 'other' as other_column, @@ -8170,7 +8170,7 @@ def test_incremental_by_time_model_ignore_additive_change(tmp_path: Path): cron '@daily' ); - SELECT + SELECT *, CAST(1 AS STRING) as id, 'other' as other_column, @@ -8344,7 +8344,7 @@ def test_incremental_by_unique_key_model_ignore_additive_change(tmp_path: Path): cron '@daily' ); - SELECT + SELECT *, 1 as id, 'test_name' as name, @@ -8389,7 +8389,7 @@ def test_incremental_by_unique_key_model_ignore_additive_change(tmp_path: Path): cron '@daily' ); - SELECT + SELECT *, 2 as id, 3 as new_column, @@ -8566,7 +8566,7 @@ def test_incremental_unmanaged_model_ignore_additive_change(tmp_path: Path): cron '@daily' ); - SELECT + SELECT *, 1 as id, 'test_name' as name, @@ -8610,7 +8610,7 @@ def test_incremental_unmanaged_model_ignore_additive_change(tmp_path: Path): ); SELECT - *, + *, 2 as id, 3 as new_column, @start_ds as ds @@ -9020,7 +9020,7 @@ def test_scd_type_2_by_column_ignore_additive_change(tmp_path: Path): cron '@daily' ); - SELECT + SELECT *, 1 as id, 'test_name' as name, @@ -9067,7 +9067,7 @@ def test_scd_type_2_by_column_ignore_additive_change(tmp_path: Path): ); SELECT - *, + *, 1 as id, 'stable2' as stable, 3 as new_column, @@ -9247,7 +9247,7 @@ def test_incremental_partition_ignore_additive_change(tmp_path: Path): cron '@daily' ); - SELECT + SELECT *, 1 as id, 'test_name' as name, @@ -9292,7 +9292,7 @@ def test_incremental_partition_ignore_additive_change(tmp_path: Path): ); SELECT - *, + *, 1 as id, 3 as new_column, @start_ds as ds @@ -9526,7 +9526,7 @@ def test_incremental_by_time_model_ignore_additive_change_unit_test(tmp_path: Pa cron '@daily' ); - SELECT + SELECT id, name, ds @@ -9595,7 +9595,7 @@ def test_incremental_by_time_model_ignore_additive_change_unit_test(tmp_path: Pa cron '@daily' ); - SELECT + SELECT id, new_column, ds From 61c7c41cc147bcbdf02165acf644509ebd2c542b Mon Sep 17 00:00:00 2001 From: Trey Spiller Date: Tue, 26 Aug 2025 15:35:36 -0500 Subject: [PATCH 31/31] Use humanize for integer/bytes abbreviation --- pyproject.toml | 1 + sqlmesh/core/console.py | 48 ++++++----------------------------------- 2 files changed, 7 insertions(+), 42 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 9a4fd0632c..f371cdee0e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,6 +12,7 @@ dependencies = [ "croniter", "duckdb>=0.10.0,!=0.10.3", "dateparser<=1.2.1", + "humanize", "hyperscript>=0.1.0", "importlib-metadata; python_version<'3.12'", "ipywidgets", diff --git a/sqlmesh/core/console.py b/sqlmesh/core/console.py index fa67b0549a..0907b39987 100644 --- a/sqlmesh/core/console.py +++ b/sqlmesh/core/console.py @@ -7,6 +7,7 @@ import uuid import logging import textwrap +from humanize import metric, naturalsize from itertools import zip_longest from pathlib import Path from hyperscript import h @@ -4186,15 +4187,14 @@ def _create_evaluation_model_annotation( execution_stats_str = "" if execution_stats: rows_processed = execution_stats.total_rows_processed - execution_stats_str += ( - f"{_abbreviate_integer_count(rows_processed)} row{'s' if rows_processed > 1 else ''}" - if rows_processed - else "" - ) + if rows_processed: + # 1.00 and 1.0 to 1 + 
rows_processed_str = metric(rows_processed).replace(".00", "").replace(".0", "") + execution_stats_str += f"{rows_processed_str} row{'s' if rows_processed > 1 else ''}" bytes_processed = execution_stats.total_bytes_processed execution_stats_str += ( - f"{', ' if execution_stats_str else ''}{_format_bytes(bytes_processed)}" + f"{', ' if execution_stats_str else ''}{naturalsize(bytes_processed, binary=True)}" if bytes_processed else "" ) @@ -4299,39 +4299,3 @@ def _calculate_annotation_str_len( + execution_stats_len, ) return annotation_str_len - - -# Convert number of bytes to a human-readable string -# https://github.com/dbt-labs/dbt-adapters/blob/34fd178539dcb6f82e18e738adc03de7784c032f/dbt-bigquery/src/dbt/adapters/bigquery/connections.py#L165 -def _format_bytes(num_bytes: t.Optional[int]) -> str: - if num_bytes and num_bytes >= 0: - if num_bytes < 1024: - return f"{num_bytes} bytes" - - num_bytes_float = float(num_bytes) / 1024.0 - for unit in ["KiB", "MiB", "GiB", "TiB", "PiB"]: - if num_bytes_float < 1024.0: - return f"{num_bytes_float:3.1f} {unit}" - num_bytes_float /= 1024.0 - - num_bytes_float *= 1024.0 # undo last division in loop - return f"{num_bytes_float:3.1f} {unit}" - return "" - - -# Abbreviate integer count. Example: 1,000,000,000 -> 1b -# https://github.com/dbt-labs/dbt-adapters/blob/34fd178539dcb6f82e18e738adc03de7784c032f/dbt-bigquery/src/dbt/adapters/bigquery/connections.py#L178 -def _abbreviate_integer_count(count: t.Optional[int]) -> str: - if count and count >= 0: - if count < 1000: - return str(count) - - count_float = float(count) / 1000.0 - for unit in ["k", "m", "b", "t"]: - if count_float < 1000.0: - return f"{count_float:3.1f}{unit}".strip() - count_float /= 1000.0 - - count_float *= 1000.0 # undo last division in loop - return f"{count_float:3.1f}{unit}".strip() - return ""
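
Below is a minimal, standalone sketch of the humanize-based formatting that PATCH 31 switches to. The wrapper names (format_rows, format_bytes) and the sample values are hypothetical, not code from the repository, and the exact output strings may vary slightly across humanize versions.

    # A minimal sketch of the humanize-based formatting adopted above.
    # Wrapper names and sample values are illustrative assumptions only.
    from humanize import metric, naturalsize


    def format_rows(rows_processed: int) -> str:
        # metric() renders SI-prefixed counts with three significant digits
        # ("1.00", "2.00 k", "1.50 M"); stripping the trailing ".00"/".0"
        # mirrors the console change so whole numbers read as "1 row" / "2 k rows".
        rows_str = metric(rows_processed).replace(".00", "").replace(".0", "")
        return f"{rows_str} row{'s' if rows_processed > 1 else ''}"


    def format_bytes(bytes_processed: int) -> str:
        # binary=True selects 1024-based units (KiB, MiB, ...), matching the
        # removed _format_bytes() helper in console.py.
        return naturalsize(bytes_processed, binary=True)


    if __name__ == "__main__":
        print(format_rows(1))                  # 1 row
        print(format_rows(2000))               # 2 k rows
        print(format_bytes(10 * 1024 * 1024))  # 10.0 MiB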
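
And a hedged, self-contained sketch of the per-instance QueryExecutionTracker API that the series settles on (track_execution / record_execution / get_execution_stats). The snapshot name, SQL strings, and row counts are placeholders; only the call shapes and import paths are taken from the diffs above.

    # Sketch of the instance-based tracker API; the snapshot name, SQL, and
    # counts below are placeholders, only the call shapes come from the patches.
    from sqlmesh.core.snapshot import SnapshotId, SnapshotIdBatch
    from sqlmesh.core.snapshot.execution_tracker import QueryExecutionTracker

    tracker = QueryExecutionTracker()
    batch = SnapshotIdBatch(snapshot_id=SnapshotId(name="a", identifier="a"), batch_id=0)

    # The active context is stored thread-locally, so record_execution() only
    # has an effect on the thread that opened track_execution().
    with tracker.track_execution(batch):
        assert tracker.is_tracking()
        tracker.record_execution("INSERT INTO a VALUES (2), (3)", 2, None)
        tracker.record_execution("INSERT INTO a VALUES (4)", 1, None)

    # get_execution_stats() pops the batch context, so stats are read once.
    stats = tracker.get_execution_stats(batch)
    assert stats is not None
    assert stats.total_rows_processed == 3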