14 | 14 | SourceQuery, |
15 | 15 | ) |
16 | 16 | from sqlmesh.core.engine_adapter.spark import SparkEngineAdapter |
| 17 | +from sqlmesh.engines.spark.db_api.spark_session import SparkSessionCursor |
17 | 18 | from sqlmesh.core.node import IntervalUnit |
18 | 19 | from sqlmesh.core.schema_diff import SchemaDiffer |
19 | 20 | from sqlmesh.core.snapshot.execution_tracker import QueryExecutionTracker |
@@ -380,38 +381,59 @@ def _record_execution_stats( |
380 | 381 | except: |
381 | 382 | return |
382 | 383 |
383 | | - history = self.cursor.fetchall_arrow() |
384 | | - if history.num_rows: |
385 | | - history_df = history.to_pandas() |
386 | | - write_df = history_df[history_df["operation"] == "WRITE"] |
387 | | - write_df = write_df[write_df["timestamp"] == write_df["timestamp"].max()] |
388 | | - if not write_df.empty: |
389 | | - metrics = write_df["operationMetrics"][0] |
390 | | - if metrics: |
391 | | - rowcount = None |
392 | | - rowcount_str = [ |
393 | | - metric[1] for metric in metrics if metric[0] == "numOutputRows" |
394 | | - ] |
395 | | - if rowcount_str: |
396 | | - try: |
397 | | - rowcount = int(rowcount_str[0]) |
398 | | - except (TypeError, ValueError): |
399 | | - pass |
400 | | - |
401 | | - bytes_processed = None |
402 | | - bytes_str = [ |
403 | | - metric[1] for metric in metrics if metric[0] == "numOutputBytes" |
404 | | - ] |
405 | | - if bytes_str: |
406 | | - try: |
407 | | - bytes_processed = int(bytes_str[0]) |
408 | | - except (TypeError, ValueError): |
409 | | - pass |
410 | | - |
411 | | - if rowcount is not None or bytes_processed is not None: |
412 | | - # if no rows were written, df contains 0 for bytes but no value for rows |
413 | | - rowcount = ( |
414 | | - 0 if rowcount is None and bytes_processed is not None else rowcount |
415 | | - ) |
416 | | - |
417 | | - QueryExecutionTracker.record_execution(sql, rowcount, bytes_processed) |
| 384 | + history = ( |
| 385 | + self.cursor.fetchdf() |
| 386 | + if isinstance(self.cursor, SparkSessionCursor) |
| 387 | + else self.cursor.fetchall_arrow() |
| 388 | + ) |
| 389 | + if history is not None: |
| 390 | + from pandas import DataFrame as PandasDataFrame |
| 391 | + from pyspark.sql import DataFrame as PySparkDataFrame |
| 392 | + from pyspark.sql.connect.dataframe import DataFrame as PySparkConnectDataFrame |
| 393 | + |
| 394 | + history_df = None |
| 395 | + if isinstance(history, PandasDataFrame): |
| 396 | + history_df = history |
| 397 | + elif isinstance(history, (PySparkDataFrame, PySparkConnectDataFrame)): |
| 398 | + history_df = history.toPandas() |
| 399 | + else: |
| 400 | + # arrow table |
| 401 | + history_df = history.to_pandas() |
| 402 | + |
| 403 | + if history_df is not None and not history_df.empty: |
| 404 | + write_df = history_df[history_df["operation"] == "WRITE"] |
| 405 | + write_df = write_df[write_df["timestamp"] == write_df["timestamp"].max()] |
| 406 | + if not write_df.empty: |
| 407 | + metrics = write_df["operationMetrics"][0] |
| 408 | + if metrics: |
| 409 | + rowcount = None |
| 410 | + rowcount_str = [ |
| 411 | + metric[1] for metric in metrics if metric[0] == "numOutputRows" |
| 412 | + ] |
| 413 | + if rowcount_str: |
| 414 | + try: |
| 415 | + rowcount = int(rowcount_str[0]) |
| 416 | + except (TypeError, ValueError): |
| 417 | + pass |
| 418 | + |
| 419 | + bytes_processed = None |
| 420 | + bytes_str = [ |
| 421 | + metric[1] for metric in metrics if metric[0] == "numOutputBytes" |
| 422 | + ] |
| 423 | + if bytes_str: |
| 424 | + try: |
| 425 | + bytes_processed = int(bytes_str[0]) |
| 426 | + except (TypeError, ValueError): |
| 427 | + pass |
| 428 | + |
| 429 | + if rowcount is not None or bytes_processed is not None: |
| 430 | + # if no rows were written, the metrics report bytes but omit the row count, so default it to 0 |
| 431 | + rowcount = ( |
| 432 | + 0 |
| 433 | + if rowcount is None and bytes_processed is not None |
| 434 | + else rowcount |
| 435 | + ) |
| 436 | + |
| 437 | + QueryExecutionTracker.record_execution( |
| 438 | + sql, rowcount, bytes_processed |
| 439 | + ) |
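For reference, a minimal, self-contained sketch of the metric-extraction step above, assuming the Delta `DESCRIBE HISTORY` result has already been normalized to a pandas DataFrame and that `operationMetrics` arrives either as key/value pairs (the Arrow map representation) or as a plain dict. `extract_write_stats` and the sample rows are illustrative only, not part of the adapter:

```python
import typing as t

import pandas as pd


def extract_write_stats(history_df: pd.DataFrame) -> t.Tuple[t.Optional[int], t.Optional[int]]:
    """Return (rowcount, bytes_processed) for the most recent WRITE, or (None, None)."""
    write_df = history_df[history_df["operation"] == "WRITE"]
    write_df = write_df[write_df["timestamp"] == write_df["timestamp"].max()]
    if write_df.empty:
        return None, None

    # dict() accepts both a list of (key, value) pairs and an existing dict.
    metrics = dict(write_df["operationMetrics"].iloc[0] or {})

    def to_int(value: t.Optional[str]) -> t.Optional[int]:
        try:
            return int(value)  # type: ignore[arg-type]
        except (TypeError, ValueError):
            return None

    rowcount = to_int(metrics.get("numOutputRows"))
    bytes_processed = to_int(metrics.get("numOutputBytes"))
    # A write of zero rows reports byte metrics but no row count, so default it to 0.
    if rowcount is None and bytes_processed is not None:
        rowcount = 0
    return rowcount, bytes_processed


# Example usage with two fabricated history rows; the most recent WRITE wins.
history = pd.DataFrame(
    {
        "operation": ["WRITE", "OPTIMIZE"],
        "timestamp": [2, 1],
        "operationMetrics": [[("numOutputRows", "10"), ("numOutputBytes", "2048")], None],
    }
)
print(extract_write_stats(history))  # (10, 2048)
```

The sketch uses positional `.iloc[0]` rather than label-based `[0]`, so it does not depend on which index labels survive the filtering.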