diff --git a/examples/streaming_put.py b/examples/streaming_put.py new file mode 100644 index 000000000..4e7697099 --- /dev/null +++ b/examples/streaming_put.py @@ -0,0 +1,34 @@ +#!/usr/bin/env python3 +""" +Simple example of streaming PUT operations. + +This demonstrates the basic usage of streaming PUT with the __input_stream__ token. +""" + +import io +import os +from databricks import sql + +with sql.connect( + server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"), + http_path=os.getenv("DATABRICKS_HTTP_PATH"), + access_token=os.getenv("DATABRICKS_TOKEN"), +) as connection: + + with connection.cursor() as cursor: + # Create a simple data stream + data = b"Hello, streaming world!" + stream = io.BytesIO(data) + + # Get catalog, schema, and volume from environment variables + catalog = os.getenv("DATABRICKS_CATALOG") + schema = os.getenv("DATABRICKS_SCHEMA") + volume = os.getenv("DATABRICKS_VOLUME") + + # Upload to Unity Catalog volume + cursor.execute( + f"PUT '__input_stream__' INTO '/Volumes/{catalog}/{schema}/{volume}/hello.txt' OVERWRITE", + input_stream=stream + ) + + print("File uploaded successfully!") \ No newline at end of file diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index e4166f117..1d37d5f96 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -1,5 +1,5 @@ import time -from typing import Dict, Tuple, List, Optional, Any, Union, Sequence +from typing import Dict, Tuple, List, Optional, Any, Union, Sequence, BinaryIO import pandas try: @@ -67,6 +67,7 @@ ) from databricks.sql.telemetry.latency_logger import log_latency from databricks.sql.telemetry.models.enums import StatementType +from databricks.sql.common.http import DatabricksHttpClient, HttpMethod, UploadType logger = logging.getLogger(__name__) @@ -615,8 +616,34 @@ def _check_not_closed(self): session_id_hex=self.connection.get_session_id_hex(), ) + def _validate_staging_http_response( + self, response: requests.Response, operation_name: 
str = "staging operation" + ) -> None: + + # Check response codes + OK = requests.codes.ok # 200 + CREATED = requests.codes.created # 201 + ACCEPTED = requests.codes.accepted # 202 + NO_CONTENT = requests.codes.no_content # 204 + + if response.status_code not in [OK, CREATED, NO_CONTENT, ACCEPTED]: + raise OperationalError( + f"{operation_name} over HTTP was unsuccessful: {response.status_code}-{response.text}", + session_id_hex=self.connection.get_session_id_hex(), + ) + + if response.status_code == ACCEPTED: + logger.debug( + "Response code %s from server indicates %s was accepted " + "but not yet applied on the server. It's possible this command may fail later.", + ACCEPTED, + operation_name, + ) + def _handle_staging_operation( - self, staging_allowed_local_path: Union[None, str, List[str]] + self, + staging_allowed_local_path: Union[None, str, List[str]], + input_stream: Optional[BinaryIO] = None, ): """Fetch the HTTP request instruction from a staging ingestion command and call the designated handler. @@ -625,6 +652,28 @@ def _handle_staging_operation( is not descended from staging_allowed_local_path. 
""" + assert self.active_result_set is not None + row = self.active_result_set.fetchone() + assert row is not None + + # Parse headers + headers = ( + json.loads(row.headers) if isinstance(row.headers, str) else row.headers + ) + headers = dict(headers) if headers else {} + + # Handle __input_stream__ token for PUT operations + if ( + row.operation == "PUT" + and getattr(row, "localFile", None) == "__input_stream__" + ): + return self._handle_staging_put_stream( + presigned_url=row.presignedUrl, + stream=input_stream, + headers=headers, + ) + + # For non-streaming operations, validate staging_allowed_local_path if isinstance(staging_allowed_local_path, type(str())): _staging_allowed_local_paths = [staging_allowed_local_path] elif isinstance(staging_allowed_local_path, type(list())): @@ -639,10 +688,6 @@ def _handle_staging_operation( os.path.abspath(i) for i in _staging_allowed_local_paths ] - assert self.active_result_set is not None - row = self.active_result_set.fetchone() - assert row is not None - # Must set to None in cases where server response does not include localFile abs_localFile = None @@ -665,19 +710,16 @@ def _handle_staging_operation( session_id_hex=self.connection.get_session_id_hex(), ) - # May be real headers, or could be json string - headers = ( - json.loads(row.headers) if isinstance(row.headers, str) else row.headers - ) - handler_args = { "presigned_url": row.presignedUrl, "local_file": abs_localFile, - "headers": dict(headers) or {}, + "headers": headers, } logger.debug( - f"Attempting staging operation indicated by server: {row.operation} - {getattr(row, 'localFile', '')}" + "Attempting staging operation indicated by server: %s - %s", + row.operation, + getattr(row, "localFile", ""), ) # TODO: Create a retry loop here to re-attempt if the request times out or fails @@ -696,6 +738,45 @@ def _handle_staging_operation( session_id_hex=self.connection.get_session_id_hex(), ) + @log_latency(StatementType.SQL) + def _handle_staging_put_stream( + 
self,
+        presigned_url: str,
+        stream: BinaryIO,
+        headers: Optional[dict] = None,
+    ) -> None:
+        """Handle PUT operation with streaming data.
+
+        Args:
+            presigned_url: The presigned URL for upload
+            stream: Binary stream to upload
+            headers: HTTP headers
+
+        Raises:
+            ProgrammingError: If no input stream is provided
+            OperationalError: If the upload fails
+        """
+
+        if not stream:
+            raise ProgrammingError(
+                "No input stream provided for streaming operation",
+                session_id_hex=self.connection.get_session_id_hex(),
+            )
+
+        http_client = DatabricksHttpClient.get_instance()
+
+        # Stream directly to presigned URL
+        with http_client.execute(
+            method=HttpMethod.PUT,
+            url=presigned_url,
+            data=stream,
+            headers=headers or {},
+            timeout=300,  # 5 minute timeout
+        ) as response:
+            self._validate_staging_http_response(
+                response, UploadType.STREAM_UPLOAD.value
+            )
+
     @log_latency(StatementType.SQL)
     def _handle_staging_put(
         self, presigned_url: str, local_file: str, headers: Optional[dict] = None
@@ -714,27 +795,7 @@ def _handle_staging_put(
         with open(local_file, "rb") as fh:
             r = requests.put(url=presigned_url, data=fh, headers=headers)
 
-        # fmt: off
-        # Design borrowed from: https://stackoverflow.com/a/2342589/5093960
-
-        OK = requests.codes.ok                  # 200
-        CREATED = requests.codes.created        # 201
-        ACCEPTED = requests.codes.accepted      # 202
-        NO_CONTENT = requests.codes.no_content  # 204
-
-        # fmt: on
-
-        if r.status_code not in [OK, CREATED, NO_CONTENT, ACCEPTED]:
-            raise OperationalError(
-                f"Staging operation over HTTP was unsuccessful: {r.status_code}-{r.text}",
-                session_id_hex=self.connection.get_session_id_hex(),
-            )
-
-        if r.status_code == ACCEPTED:
-            logger.debug(
-                f"Response code {ACCEPTED} from server indicates ingestion command was accepted "
-                + "but not yet applied on the server. It's possible this command may fail later."
- ) + self._validate_staging_http_response(r, UploadType.FILE_UPLOAD.value) @log_latency(StatementType.SQL) def _handle_staging_get( @@ -784,6 +845,7 @@ def execute( operation: str, parameters: Optional[TParameterCollection] = None, enforce_embedded_schema_correctness=False, + input_stream: Optional[BinaryIO] = None, ) -> "Cursor": """ Execute a query and wait for execution to complete. @@ -820,7 +882,6 @@ def execute( logger.debug( "Cursor.execute(operation=%s, parameters=%s)", operation, parameters ) - param_approach = self._determine_parameter_approach(parameters) if param_approach == ParameterApproach.NONE: prepared_params = NO_NATIVE_PARAMS @@ -857,7 +918,8 @@ def execute( if self.active_result_set and self.active_result_set.is_staging_operation: self._handle_staging_operation( - staging_allowed_local_path=self.connection.staging_allowed_local_path + staging_allowed_local_path=self.connection.staging_allowed_local_path, + input_stream=input_stream, ) return self diff --git a/src/databricks/sql/common/http.py b/src/databricks/sql/common/http.py index ec4e3341a..cd7562666 100644 --- a/src/databricks/sql/common/http.py +++ b/src/databricks/sql/common/http.py @@ -25,6 +25,11 @@ class HttpHeader(str, Enum): AUTHORIZATION = "Authorization" +class UploadType(str, Enum): + FILE_UPLOAD = "file_upload" + STREAM_UPLOAD = "stream_upload" + + # Dataclass for OAuthHTTP Response @dataclass class OAuthResponse: diff --git a/tests/e2e/common/streaming_put_tests.py b/tests/e2e/common/streaming_put_tests.py new file mode 100644 index 000000000..5d3e88943 --- /dev/null +++ b/tests/e2e/common/streaming_put_tests.py @@ -0,0 +1,66 @@ +#!/usr/bin/env python3 +""" +E2E tests for streaming PUT operations. 
+"""
+
+import io
+import logging
+import pytest
+from datetime import datetime
+
+logger = logging.getLogger(__name__)
+
+
+class PySQLStreamingPutTestSuiteMixin:
+    """Test suite for streaming PUT operations."""
+
+    def test_streaming_put_basic(self, catalog, schema):
+        """Test basic streaming PUT functionality."""
+
+        # Create test data
+        test_data = b"Hello, streaming world! This is test data."
+        filename = "streaming_put_test.txt"
+        file_path = f"/Volumes/{catalog}/{schema}/e2etests/{filename}"
+
+        try:
+            with self.connection() as conn:
+                with conn.cursor() as cursor:
+                    self._cleanup_test_file(file_path)
+
+                    with io.BytesIO(test_data) as stream:
+                        cursor.execute(
+                            f"PUT '__input_stream__' INTO '{file_path}'",
+                            input_stream=stream
+                        )
+
+                    # Verify file exists
+                    cursor.execute(f"LIST '/Volumes/{catalog}/{schema}/e2etests/'")
+                    files = cursor.fetchall()
+
+                    # Check if our file is in the list
+                    file_paths = [row[0] for row in files]
+                    assert file_path in file_paths, f"File {file_path} not found in {file_paths}"
+        finally:
+            self._cleanup_test_file(file_path)
+
+    def test_streaming_put_missing_stream(self, catalog, schema):
+        """Test that missing stream raises appropriate error."""
+
+        with self.connection() as conn:
+            with conn.cursor() as cursor:
+                # Test without providing stream
+                with pytest.raises(Exception):  # Should fail
+                    cursor.execute(
+                        f"PUT '__input_stream__' INTO '/Volumes/{catalog}/{schema}/e2etests/test.txt'"
+                        # Note: No input_stream parameter
+                    )
+
+    def _cleanup_test_file(self, file_path):
+        """Clean up a test file if it exists."""
+        try:
+            with self.connection(extra_params={"staging_allowed_local_path": "/"}) as conn:
+                with conn.cursor() as cursor:
+                    cursor.execute(f"REMOVE '{file_path}'")
+                    logger.info("Successfully cleaned up test file: %s", file_path)
+        except Exception as e:
+            logger.error("Cleanup failed for %s: %s", file_path, e)
\ No newline at end of file
diff --git a/tests/e2e/test_driver.py b/tests/e2e/test_driver.py
index 
042fcc10a..7a7041094 100644 --- a/tests/e2e/test_driver.py +++ b/tests/e2e/test_driver.py @@ -47,8 +47,8 @@ ) from tests.e2e.common.staging_ingestion_tests import PySQLStagingIngestionTestSuiteMixin from tests.e2e.common.retry_test_mixins import PySQLRetryTestsMixin - from tests.e2e.common.uc_volume_tests import PySQLUCVolumeTestSuiteMixin +from tests.e2e.common.streaming_put_tests import PySQLStreamingPutTestSuiteMixin from databricks.sql.exc import SessionAlreadyClosedError @@ -256,6 +256,7 @@ class TestPySQLCoreSuite( PySQLStagingIngestionTestSuiteMixin, PySQLRetryTestsMixin, PySQLUCVolumeTestSuiteMixin, + PySQLStreamingPutTestSuiteMixin, ): validate_row_value_type = True validate_result = True diff --git a/tests/unit/test_streaming_put.py b/tests/unit/test_streaming_put.py new file mode 100644 index 000000000..e1d3f27e3 --- /dev/null +++ b/tests/unit/test_streaming_put.py @@ -0,0 +1,171 @@ + +import io +import pytest +from unittest.mock import patch, Mock, MagicMock +import databricks.sql.client as client +from databricks.sql import ProgrammingError +import requests + + +class TestStreamingPut: + """Unit tests for streaming PUT functionality.""" + + @pytest.fixture + def mock_connection(self): + return Mock() + + @pytest.fixture + def mock_backend(self): + return Mock() + + @pytest.fixture + def cursor(self, mock_connection, mock_backend): + return client.Cursor( + connection=mock_connection, + backend=mock_backend + ) + + def _setup_mock_staging_put_stream_response(self, mock_backend): + """Helper method to set up mock staging PUT stream response.""" + mock_result_set = Mock() + mock_result_set.is_staging_operation = True + mock_backend.execute_command.return_value = mock_result_set + + mock_row = Mock() + mock_row.operation = "PUT" + mock_row.localFile = "__input_stream__" + mock_row.presignedUrl = "https://example.com/upload" + mock_row.headers = "{}" + mock_result_set.fetchone.return_value = mock_row + + return mock_result_set + + def 
test_execute_with_valid_stream(self, cursor, mock_backend): + """Test execute method with valid input stream.""" + + # Mock the backend response + self._setup_mock_staging_put_stream_response(mock_backend) + + # Test with valid stream + test_stream = io.BytesIO(b"test data") + + with patch.object(cursor, '_handle_staging_put_stream') as mock_handler: + cursor.execute( + "PUT '__input_stream__' INTO '/Volumes/test/cat/schema/vol/file.txt'", + input_stream=test_stream + ) + + # Verify staging handler was called + mock_handler.assert_called_once() + + def test_execute_with_invalid_stream_types(self, cursor, mock_backend): + + # Mock the backend response + self._setup_mock_staging_put_stream_response(mock_backend) + + # Test with None input stream + with pytest.raises(client.ProgrammingError) as excinfo: + cursor.execute( + "PUT '__input_stream__' INTO '/Volumes/test/cat/schema/vol/file.txt'", + input_stream=None + ) + assert "No input stream provided for streaming operation" in str(excinfo.value) + + def test_execute_with_none_stream_for_staging_put(self, cursor, mock_backend): + """Test execute method rejects None stream for streaming PUT operations.""" + + # Mock staging operation response for None case + self._setup_mock_staging_put_stream_response(mock_backend) + + # None with __input_stream__ raises ProgrammingError + with pytest.raises(client.ProgrammingError) as excinfo: + cursor.execute( + "PUT '__input_stream__' INTO '/Volumes/test/cat/schema/vol/file.txt'", + input_stream=None + ) + error_msg = str(excinfo.value) + assert "No input stream provided for streaming operation" in error_msg + + def test_handle_staging_put_stream_success(self, cursor): + """Test successful streaming PUT operation.""" + + test_stream = io.BytesIO(b"test data") + presigned_url = "https://example.com/upload" + headers = {"Content-Type": "text/plain"} + + with patch('databricks.sql.client.DatabricksHttpClient') as mock_client_class: + mock_client = Mock() + 
mock_client_class.get_instance.return_value = mock_client + + # Mock the context manager properly using MagicMock + mock_context = MagicMock() + mock_response = Mock() + mock_response.status_code = 200 + mock_context.__enter__.return_value = mock_response + mock_context.__exit__.return_value = None + mock_client.execute.return_value = mock_context + + cursor._handle_staging_put_stream( + presigned_url=presigned_url, + stream=test_stream, + headers=headers + ) + + # Verify the HTTP client was called correctly + mock_client.execute.assert_called_once() + call_args = mock_client.execute.call_args + assert call_args[1]['method'].value == 'PUT' + assert call_args[1]['url'] == presigned_url + assert call_args[1]['data'] == test_stream + assert call_args[1]['headers'] == headers + + def test_handle_staging_put_stream_http_error(self, cursor): + """Test streaming PUT operation with HTTP error.""" + + test_stream = io.BytesIO(b"test data") + presigned_url = "https://example.com/upload" + + with patch('databricks.sql.client.DatabricksHttpClient') as mock_client_class: + mock_client = Mock() + mock_client_class.get_instance.return_value = mock_client + + # Mock the context manager with error response + mock_context = MagicMock() + mock_response = Mock() + mock_response.status_code = 500 + mock_response.text = "Internal Server Error" + mock_context.__enter__.return_value = mock_response + mock_context.__exit__.return_value = None + mock_client.execute.return_value = mock_context + + with pytest.raises(client.OperationalError) as excinfo: + cursor._handle_staging_put_stream( + presigned_url=presigned_url, + stream=test_stream + ) + + # Check for the actual error message format + assert "500" in str(excinfo.value) + + def test_handle_staging_put_stream_network_error(self, cursor): + """Test streaming PUT operation with network error.""" + + test_stream = io.BytesIO(b"test data") + presigned_url = "https://example.com/upload" + + with 
patch('databricks.sql.client.DatabricksHttpClient') as mock_client_class: + mock_client = Mock() + mock_client_class.get_instance.return_value = mock_client + + # Mock the context manager to raise an exception + mock_context = MagicMock() + mock_context.__enter__.side_effect = requests.exceptions.RequestException("Network error") + mock_client.execute.return_value = mock_context + + with pytest.raises(requests.exceptions.RequestException) as excinfo: + cursor._handle_staging_put_stream( + presigned_url=presigned_url, + stream=test_stream + ) + + assert "Network error" in str(excinfo.value)