Skip to content

Commit c7622f2

Browse files
committed
PR feedback
1 parent 2020cc7 commit c7622f2

File tree

9 files changed

+61
-67
lines changed

9 files changed

+61
-67
lines changed

sqlmesh/core/console.py

Lines changed: 3 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import textwrap
1010
from itertools import zip_longest
1111
from pathlib import Path
12+
from humanize import naturalsize, metric
1213
from hyperscript import h
1314
from rich.console import Console as RichConsole
1415
from rich.live import Live
@@ -4187,14 +4188,14 @@ def _create_evaluation_model_annotation(
41874188
if execution_stats:
41884189
rows_processed = execution_stats.total_rows_processed
41894190
execution_stats_str += (
4190-
f"{_abbreviate_integer_count(rows_processed)} row{'s' if rows_processed != 1 else ''}"
4191+
f"{metric(rows_processed)} row{'s' if rows_processed != 1 else ''}"
41914192
if rows_processed is not None and rows_processed >= 0
41924193
else ""
41934194
)
41944195

41954196
bytes_processed = execution_stats.total_bytes_processed
41964197
execution_stats_str += (
4197-
f"{', ' if execution_stats_str else ''}{_format_bytes(bytes_processed)}"
4198+
f"{', ' if execution_stats_str else ''}{naturalsize(bytes_processed, binary=True)}"
41984199
if bytes_processed is not None and bytes_processed >= 0
41994200
else ""
42004201
)
@@ -4299,39 +4300,3 @@ def _calculate_annotation_str_len(
42994300
+ execution_stats_len,
43004301
)
43014302
return annotation_str_len
4302-
4303-
4304-
# Convert number of bytes to a human-readable string
4305-
# https://github.com/dbt-labs/dbt-adapters/blob/34fd178539dcb6f82e18e738adc03de7784c032f/dbt-bigquery/src/dbt/adapters/bigquery/connections.py#L165
4306-
def _format_bytes(num_bytes: t.Optional[int]) -> str:
4307-
if num_bytes is not None and num_bytes >= 0:
4308-
if num_bytes < 1024:
4309-
return f"{num_bytes} bytes"
4310-
4311-
num_bytes_float = float(num_bytes) / 1024.0
4312-
for unit in ["KiB", "MiB", "GiB", "TiB", "PiB"]:
4313-
if num_bytes_float < 1024.0:
4314-
return f"{num_bytes_float:3.1f} {unit}"
4315-
num_bytes_float /= 1024.0
4316-
4317-
num_bytes_float *= 1024.0 # undo last division in loop
4318-
return f"{num_bytes_float:3.1f} {unit}"
4319-
return ""
4320-
4321-
4322-
# Abbreviate integer count. Example: 1,000,000,000 -> 1b
4323-
# https://github.com/dbt-labs/dbt-adapters/blob/34fd178539dcb6f82e18e738adc03de7784c032f/dbt-bigquery/src/dbt/adapters/bigquery/connections.py#L178
4324-
def _abbreviate_integer_count(count: t.Optional[int]) -> str:
4325-
if count is not None and count >= 0:
4326-
if count < 1000:
4327-
return str(count)
4328-
4329-
count_float = float(count) / 1000.0
4330-
for unit in ["k", "m", "b", "t"]:
4331-
if count_float < 1000.0:
4332-
return f"{count_float:3.1f}{unit}".strip()
4333-
count_float /= 1000.0
4334-
4335-
count_float *= 1000.0 # undo last division in loop
4336-
return f"{count_float:3.1f}{unit}".strip()
4337-
return ""

sqlmesh/core/engine_adapter/base.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2458,16 +2458,12 @@ def _execute(self, sql: str, track_rows_processed: bool = False, **kwargs: t.Any
24582458
and track_rows_processed
24592459
and QueryExecutionTracker.is_tracking()
24602460
):
2461-
rowcount_raw = getattr(self.cursor, "rowcount", None)
2462-
rowcount = None
2463-
if rowcount_raw is not None:
2461+
if (rowcount := getattr(self.cursor, "rowcount", None)) and rowcount is not None:
24642462
try:
2465-
rowcount = int(rowcount_raw)
2463+
self._record_execution_stats(sql, int(rowcount))
24662464
except (TypeError, ValueError):
24672465
return
24682466

2469-
self._record_execution_stats(sql, rowcount)
2470-
24712467
@contextlib.contextmanager
24722468
def temp_table(
24732469
self,

sqlmesh/core/scheduler.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
DeployabilityIndex,
2121
Snapshot,
2222
SnapshotId,
23+
SnapshotIdBatch,
2324
SnapshotEvaluator,
2425
apply_auto_restatements,
2526
earliest_start_date,
@@ -533,7 +534,7 @@ def run_node(node: SchedulingUnit) -> None:
533534
num_audits_failed = sum(1 for result in audit_results if result.count)
534535

535536
execution_stats = self.snapshot_evaluator.execution_tracker.get_execution_stats(
536-
f"{snapshot.snapshot_id}_{node.batch_index}"
537+
SnapshotIdBatch(snapshot_id=snapshot.snapshot_id, batch_id=node.batch_index)
537538
)
538539

539540
self.console.update_snapshot_evaluation_progress(

sqlmesh/core/snapshot/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
SnapshotDataVersion as SnapshotDataVersion,
99
SnapshotFingerprint as SnapshotFingerprint,
1010
SnapshotId as SnapshotId,
11+
SnapshotIdBatch as SnapshotIdBatch,
1112
SnapshotIdLike as SnapshotIdLike,
1213
SnapshotInfoLike as SnapshotInfoLike,
1314
SnapshotIntervals as SnapshotIntervals,

sqlmesh/core/snapshot/definition.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,11 @@ def __str__(self) -> str:
162162
return f"SnapshotId<{self.name}: {self.identifier}>"
163163

164164

165+
class SnapshotIdBatch(PydanticModel, frozen=True):
166+
snapshot_id: SnapshotId
167+
batch_id: int
168+
169+
165170
class SnapshotNameVersion(PydanticModel, frozen=True):
166171
name: str
167172
version: str

sqlmesh/core/snapshot/evaluator.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
Intervals,
6262
Snapshot,
6363
SnapshotId,
64+
SnapshotIdBatch,
6465
SnapshotInfoLike,
6566
SnapshotTableCleanupTask,
6667
)
@@ -171,7 +172,9 @@ def evaluate(
171172
Returns:
172173
The WAP ID of this evaluation if supported, None otherwise.
173174
"""
174-
with self.execution_tracker.track_execution(f"{snapshot.snapshot_id}_{batch_index}"):
175+
with self.execution_tracker.track_execution(
176+
SnapshotIdBatch(snapshot_id=snapshot.snapshot_id, batch_id=batch_index)
177+
):
175178
result = self._evaluate_snapshot(
176179
start=start,
177180
end=end,

sqlmesh/core/snapshot/execution_tracker.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@
44
from contextlib import contextmanager
55
from threading import local, Lock
66
from dataclasses import dataclass, field
7+
from sqlmesh.core.snapshot import SnapshotIdBatch
78

89

910
@dataclass
1011
class QueryExecutionStats:
11-
snapshot_batch_id: str
12+
snapshot_id_batch: SnapshotIdBatch
1213
total_rows_processed: t.Optional[int] = None
1314
total_bytes_processed: t.Optional[int] = None
1415

@@ -21,15 +22,15 @@ class QueryExecutionContext:
2122
It accumulates statistics from multiple cursor.execute() calls during a single snapshot evaluation.
2223
2324
Attributes:
24-
snapshot_batch_id: Identifier linking this context to a specific snapshot evaluation
25+
snapshot_id_batch: Identifier linking this context to a specific snapshot evaluation
2526
stats: Running sum of cursor.rowcount and possibly bytes processed from all executed queries during evaluation
2627
"""
2728

28-
snapshot_batch_id: str
29+
snapshot_id_batch: SnapshotIdBatch
2930
stats: QueryExecutionStats = field(init=False)
3031

3132
def __post_init__(self) -> None:
32-
self.stats = QueryExecutionStats(snapshot_batch_id=self.snapshot_batch_id)
33+
self.stats = QueryExecutionStats(snapshot_id_batch=self.snapshot_id_batch)
3334

3435
def add_execution(
3536
self, sql: str, row_count: t.Optional[int], bytes_processed: t.Optional[int]
@@ -56,10 +57,12 @@ class QueryExecutionTracker:
5657
"""Thread-local context manager for snapshot execution statistics, such as rows processed."""
5758

5859
_thread_local = local()
59-
_contexts: t.Dict[str, QueryExecutionContext] = {}
60+
_contexts: t.Dict[SnapshotIdBatch, QueryExecutionContext] = {}
6061
_contexts_lock = Lock()
6162

62-
def get_execution_context(self, snapshot_id_batch: str) -> t.Optional[QueryExecutionContext]:
63+
def get_execution_context(
64+
self, snapshot_id_batch: SnapshotIdBatch
65+
) -> t.Optional[QueryExecutionContext]:
6366
with self._contexts_lock:
6467
return self._contexts.get(snapshot_id_batch)
6568

@@ -69,10 +72,10 @@ def is_tracking(cls) -> bool:
6972

7073
@contextmanager
7174
def track_execution(
72-
self, snapshot_id_batch: str
75+
self, snapshot_id_batch: SnapshotIdBatch
7376
) -> t.Iterator[t.Optional[QueryExecutionContext]]:
7477
"""Context manager for tracking snapshot execution statistics such as row counts and bytes processed."""
75-
context = QueryExecutionContext(snapshot_batch_id=snapshot_id_batch)
78+
context = QueryExecutionContext(snapshot_id_batch=snapshot_id_batch)
7679
self._thread_local.context = context
7780
with self._contexts_lock:
7881
self._contexts[snapshot_id_batch] = context
@@ -90,7 +93,9 @@ def record_execution(
9093
if context is not None:
9194
context.add_execution(sql, row_count, bytes_processed)
9295

93-
def get_execution_stats(self, snapshot_id_batch: str) -> t.Optional[QueryExecutionStats]:
96+
def get_execution_stats(
97+
self, snapshot_id_batch: SnapshotIdBatch
98+
) -> t.Optional[QueryExecutionStats]:
9499
with self._contexts_lock:
95100
context = self._contexts.get(snapshot_id_batch)
96101
self._contexts.pop(snapshot_id_batch, None)

tests/core/engine_adapter/integration/test_integration_snowflake.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from tests.core.engine_adapter.integration import TestContext
1414
from sqlmesh import model, ExecutionContext
1515
from pytest_mock import MockerFixture
16+
from sqlmesh.core.snapshot import SnapshotId, SnapshotIdBatch
1617
from sqlmesh.core.snapshot.execution_tracker import (
1718
QueryExecutionContext,
1819
QueryExecutionTracker,
@@ -322,7 +323,9 @@ def test_rows_tracker(
322323

323324
add_execution_spy = mocker.spy(QueryExecutionContext, "add_execution")
324325

325-
with tracker.track_execution("a"):
326+
with tracker.track_execution(
327+
SnapshotIdBatch(snapshot_id=SnapshotId(name="a", identifier="a"), batch_id=0)
328+
):
326329
# Snowflake doesn't report row counts for CTAS, so this should not be tracked
327330
engine_adapter.execute(
328331
"CREATE TABLE a (id int) AS SELECT 1 as id", track_rows_processed=True
@@ -332,6 +335,8 @@ def test_rows_tracker(
332335

333336
assert add_execution_spy.call_count == 2
334337

335-
stats = tracker.get_execution_stats("a")
338+
stats = tracker.get_execution_stats(
339+
SnapshotIdBatch(snapshot_id=SnapshotId(name="a", identifier="a"), batch_id=0)
340+
)
336341
assert stats is not None
337342
assert stats.total_rows_processed == 3

tests/core/test_execution_tracker.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@
33
from concurrent.futures import ThreadPoolExecutor
44

55
from sqlmesh.core.snapshot.execution_tracker import QueryExecutionStats, QueryExecutionTracker
6+
from sqlmesh.core.snapshot import SnapshotIdBatch, SnapshotId
67

78

89
def test_execution_tracker_thread_isolation() -> None:
9-
def worker(id: str, row_counts: list[int]) -> QueryExecutionStats:
10-
with execution_tracker.track_execution(id) as ctx:
10+
def worker(id: SnapshotId, row_counts: list[int]) -> QueryExecutionStats:
11+
with execution_tracker.track_execution(SnapshotIdBatch(snapshot_id=id, batch_id=0)) as ctx:
1112
assert execution_tracker.is_tracking()
1213

1314
for count in row_counts:
@@ -20,18 +21,30 @@ def worker(id: str, row_counts: list[int]) -> QueryExecutionStats:
2021

2122
with ThreadPoolExecutor() as executor:
2223
futures = [
23-
executor.submit(worker, "batch_A", [10, 5]),
24-
executor.submit(worker, "batch_B", [3, 7]),
24+
executor.submit(worker, SnapshotId(name="batch_A", identifier="batch_A"), [10, 5]),
25+
executor.submit(worker, SnapshotId(name="batch_B", identifier="batch_B"), [3, 7]),
2526
]
2627
results = [f.result() for f in futures]
2728

2829
# Main thread has no active tracking context
2930
assert not execution_tracker.is_tracking()
30-
execution_tracker.record_execution("q", 10, None)
31-
assert execution_tracker.get_execution_stats("q") is None
3231

3332
# Order of results is not deterministic, so look up by id
34-
by_batch = {s.snapshot_batch_id: s for s in results}
35-
36-
assert by_batch["batch_A"].total_rows_processed == 15
37-
assert by_batch["batch_B"].total_rows_processed == 10
33+
by_batch = {s.snapshot_id_batch: s for s in results}
34+
35+
assert (
36+
by_batch[
37+
SnapshotIdBatch(
38+
snapshot_id=SnapshotId(name="batch_A", identifier="batch_A"), batch_id=0
39+
)
40+
].total_rows_processed
41+
== 15
42+
)
43+
assert (
44+
by_batch[
45+
SnapshotIdBatch(
46+
snapshot_id=SnapshotId(name="batch_B", identifier="batch_B"), batch_id=0
47+
)
48+
].total_rows_processed
49+
== 10
50+
)

0 commit comments

Comments
 (0)