+import contextlib
 import datetime
 import io
 import logging
+import uuid
 from collections.abc import Iterator
 from decimal import Decimal
 from typing import TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union, cast

 from google.cloud.bigquery import (
     ArrayQueryParameter,
     Client,
+    ExtractJobConfig,
     LoadJobConfig,
     QueryJob,
     QueryJobConfig,
     ScalarQueryParameter,
+    SourceFormat,
     WriteDisposition,
 )
 from google.cloud.bigquery.table import Row as BigQueryRow

 from sqlspec.utils.serializers import to_json

 if TYPE_CHECKING:
+    from pathlib import Path
+
     from sqlglot.dialects.dialect import DialectType

@@ -258,23 +264,17 @@ def _run_query_job( |
             param_value,
             type(param_value),
         )
-        # Let BigQuery generate the job ID to avoid collisions
-        # This is the recommended approach for production code and works better with emulators
-        logger.warning("About to send to BigQuery - SQL: %r", sql_str)
-        logger.warning("Query parameters in job config: %r", final_job_config.query_parameters)
         query_job = conn.query(sql_str, job_config=final_job_config)

         # Get the auto-generated job ID for callbacks
         if self.on_job_start and query_job.job_id:
-            try:
+            with contextlib.suppress(Exception):
+                # Callback errors should not interfere with job execution
                 self.on_job_start(query_job.job_id)
-            except Exception as e:
-                logger.warning("Job start callback failed: %s", str(e), extra={"adapter": "bigquery"})
         if self.on_job_complete and query_job.job_id:
-            try:
+            with contextlib.suppress(Exception):
+                # Callback errors should not interfere with job execution
                 self.on_job_complete(query_job.job_id, query_job)
-            except Exception as e:
-                logger.warning("Job complete callback failed: %s", str(e), extra={"adapter": "bigquery"})

         return query_job

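For reference, a minimal standalone sketch of the suppression pattern the hunk above switches to; the callback and job id are placeholders, not adapter code:

import contextlib

def fire_callback(callback, job_id):
    # Exceptions raised by the callback are swallowed so they cannot
    # interrupt the query job handling around this call.
    with contextlib.suppress(Exception):
        callback(job_id)

fire_callback(lambda job_id: print(f"job {job_id} started"), "job-123")
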
@@ -529,28 +529,120 @@ def _connection(self, connection: "Optional[Client]" = None) -> "Client": |
     # BigQuery Native Export Support
     # ============================================================================

-    def _export_native(self, query: str, destination_uri: str, format: str, **options: Any) -> int:
-        """BigQuery native export implementation.
+    def _export_native(self, query: str, destination_uri: "Union[str, Path]", format: str, **options: Any) -> int:
+        """BigQuery native export implementation with automatic GCS staging.

-        For local files, BigQuery doesn't support direct export, so we raise NotImplementedError
-        to trigger the fallback mechanism that uses fetch + write.
+        For GCS URIs, uses direct export. For other locations, automatically stages
+        through a temporary GCS location and transfers to the final destination.

         Args:
             query: SQL query to execute
-            destination_uri: Destination URI (local file path or gs:// URI)
+            destination_uri: Destination URI (local file path, gs:// URI, or Path object)
             format: Export format (parquet, csv, json, avro)
-            **options: Additional export options
+            **options: Additional export options including 'gcs_staging_bucket'

         Returns:
             Number of rows exported

         Raises:
-            NotImplementedError: Always, to trigger fallback to fetch + write
+            NotImplementedError: If no staging bucket is configured for non-GCS destinations
         """
-        # BigQuery only supports native export to GCS, not local files
-        # By raising NotImplementedError, the mixin will fall back to fetch + write
-        msg = "BigQuery native export only supports GCS URIs, using fallback for local files"
-        raise NotImplementedError(msg)
+        destination_str = str(destination_uri)
+
+        # If it's already a GCS URI, use direct export
+        if destination_str.startswith("gs://"):
+            return self._export_to_gcs_native(query, destination_str, format, **options)
+
+        # For non-GCS destinations, check if staging is configured
+        staging_bucket = options.get("gcs_staging_bucket") or getattr(self.config, "gcs_staging_bucket", None)
+        if not staging_bucket:
+            # Fall back to fetch + write for non-GCS destinations without staging
+            msg = "BigQuery native export requires GCS staging bucket for non-GCS destinations"
+            raise NotImplementedError(msg)
+
+        # Generate temporary GCS path
+        from datetime import timezone
+
+        timestamp = datetime.datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
+        temp_filename = f"bigquery_export_{timestamp}_{uuid.uuid4().hex[:8]}.{format}"
+        temp_gcs_uri = f"gs://{staging_bucket}/temp_exports/{temp_filename}"
+
+        try:
+            # Export to temporary GCS location
+            rows_exported = self._export_to_gcs_native(query, temp_gcs_uri, format, **options)
+
+            # Transfer from GCS to final destination using storage backend
+            backend, path = self._resolve_backend_and_path(destination_str)
+            gcs_backend = self._get_storage_backend(temp_gcs_uri)
+
+            # Download from GCS and upload to final destination
+            data = gcs_backend.read_bytes(temp_gcs_uri)
+            backend.write_bytes(path, data)
+
+            return rows_exported
+        finally:
+            # Clean up temporary file
+            try:
+                gcs_backend = self._get_storage_backend(temp_gcs_uri)
+                gcs_backend.delete(temp_gcs_uri)
+            except Exception as e:
+                logger.warning("Failed to clean up temporary GCS file %s: %s", temp_gcs_uri, e)
+
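As an aside on the staging path in `_export_native` above: once the extract has landed in the staging bucket, the transfer step amounts to copying one GCS object down and deleting it. A minimal sketch of that step using the google-cloud-storage client directly (the adapter itself goes through its storage backend abstraction; bucket, blob, and path names here are placeholders):

from google.cloud import storage

def download_staged_export(staging_bucket: str, blob_name: str, local_path: str) -> None:
    # Copy the staged export object to a local file, then remove the staging copy.
    client = storage.Client()
    blob = client.bucket(staging_bucket).blob(blob_name)
    blob.download_to_filename(local_path)
    blob.delete()
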
+    def _export_to_gcs_native(self, query: str, gcs_uri: str, format: str, **options: Any) -> int:
+        """Direct BigQuery export to GCS.
+
+        Args:
+            query: SQL query to execute
+            gcs_uri: GCS destination URI (must start with gs://)
+            format: Export format (parquet, csv, json, avro)
+            **options: Additional export options
+
+        Returns:
+            Number of rows exported
+        """
+        # First, run the query and store results in a temporary table
+
+        temp_table_id = f"temp_export_{uuid.uuid4().hex[:8]}"
+        dataset_id = getattr(self.connection, "default_dataset", None) or options.get("dataset", "temp")
+
+        # Create a temporary table with query results
+        query_with_table = f"CREATE OR REPLACE TABLE `{dataset_id}.{temp_table_id}` AS {query}"
+        create_job = self._run_query_job(query_with_table, [])
+        create_job.result()
+
+        # Get row count
+        count_query = f"SELECT COUNT(*) as cnt FROM `{dataset_id}.{temp_table_id}`"
+        count_job = self._run_query_job(count_query, [])
+        count_result = list(count_job.result())
+        row_count = count_result[0]["cnt"] if count_result else 0
+
+        try:
+            # Configure extract job
+            extract_config = ExtractJobConfig(**options)  # type: ignore[no-untyped-call]
+
+            # Set format
+            format_mapping = {
+                "parquet": SourceFormat.PARQUET,
+                "csv": SourceFormat.CSV,
+                "json": SourceFormat.NEWLINE_DELIMITED_JSON,
+                "avro": SourceFormat.AVRO,
+            }
+            extract_config.destination_format = format_mapping.get(format, SourceFormat.PARQUET)
+
+            # Extract table to GCS
+            table_ref = self.connection.dataset(dataset_id).table(temp_table_id)
+            extract_job = self.connection.extract_table(table_ref, gcs_uri, job_config=extract_config)
+            extract_job.result()
+
+            return row_count
+        finally:
+            # Clean up temporary table
+            try:
+                delete_query = f"DROP TABLE IF EXISTS `{dataset_id}.{temp_table_id}`"
+                delete_job = self._run_query_job(delete_query, [])
+                delete_job.result()
+            except Exception as e:
+                logger.warning("Failed to clean up temporary table %s: %s", temp_table_id, e)

     # ============================================================================
     # BigQuery Native Arrow Support
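For context on the extract step in `_export_to_gcs_native`, the underlying google-cloud-bigquery call shape looks roughly like this; the project, dataset, table, and bucket names are placeholders, and `DestinationFormat` carries the same string values the adapter selects via `SourceFormat`:

from google.cloud import bigquery

def extract_table_to_gcs() -> int:
    # Run an extract job that writes an existing table to a Parquet file in GCS.
    client = bigquery.Client()
    job_config = bigquery.ExtractJobConfig()
    job_config.destination_format = bigquery.DestinationFormat.PARQUET
    extract_job = client.extract_table(
        "my-project.my_dataset.temp_export_table",  # placeholder table reference
        "gs://my-staging-bucket/temp_exports/result.parquet",  # placeholder URI
        job_config=job_config,
    )
    extract_job.result()  # wait for the extract job to finish
    return client.get_table("my-project.my_dataset.temp_export_table").num_rows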