
Commit b79ca86

streaming ingestion support for PUT operation

Signed-off-by: Sreekanth Vadigi <[email protected]>

1 parent e0ca049 commit b79ca86

File tree

5 files changed: +424 −48 lines changed


examples/streaming_put.py

Lines changed: 36 additions & 0 deletions
#!/usr/bin/env python3
"""
Simple example of streaming PUT operations.

This demonstrates the basic usage of streaming PUT with the __input_stream__ token.
"""

import io
import os
from databricks import sql

def main():
    """Simple streaming PUT example."""

    # Connect to Databricks
    connection = sql.connect(
        server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"),
        http_path=os.getenv("DATABRICKS_HTTP_PATH"),
        access_token=os.getenv("DATABRICKS_TOKEN"),
    )

    with connection.cursor() as cursor:
        # Create a simple data stream
        data = b"Hello, streaming world!"
        stream = io.BytesIO(data)

        # Upload to Unity Catalog volume
        cursor.execute(
            "PUT '__input_stream__' INTO '/Volumes/my_catalog/my_schema/my_volume/hello.txt'",
            input_stream=stream
        )

        print("File uploaded successfully!")

if __name__ == "__main__":
    main()
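
Note: the example uses an in-memory io.BytesIO, but execute() only requires a binary stream with a read() method, so an open file handle works the same way. A minimal sketch under that assumption (the local path and volume names here are placeholders, not part of this commit):

    # Hypothetical variant: stream an existing local file without reading it into memory
    with open("local_data.bin", "rb") as f:
        cursor.execute(
            "PUT '__input_stream__' INTO '/Volumes/my_catalog/my_schema/my_volume/data.bin'",
            input_stream=f,
        )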

src/databricks/sql/client.py

Lines changed: 134 additions & 47 deletions
@@ -1,5 +1,5 @@
 import time
-from typing import Dict, Tuple, List, Optional, Any, Union, Sequence
+from typing import Dict, Tuple, List, Optional, Any, Union, Sequence, BinaryIO
 import pandas
 
 try:
@@ -455,6 +455,7 @@ def __init__(
         self.active_command_id = None
         self.escaper = ParamEscaper()
         self.lastrowid = None
+        self._input_stream_data = None
 
         self.ASYNC_DEFAULT_POLLING_INTERVAL = 2
 
@@ -625,6 +626,33 @@ def _handle_staging_operation(
         is not descended from staging_allowed_local_path.
         """
 
+        assert self.active_result_set is not None
+        row = self.active_result_set.fetchone()
+        assert row is not None
+
+        # Parse headers
+        headers = (
+            json.loads(row.headers) if isinstance(row.headers, str) else row.headers
+        )
+        headers = dict(headers) if headers else {}
+
+        # Handle __input_stream__ token for PUT operations
+        if (
+            row.operation == "PUT" and
+            getattr(row, "localFile", None) == "__input_stream__"
+        ):
+            if not self._input_stream_data:
+                raise ProgrammingError(
+                    "No input stream provided for streaming operation",
+                    session_id_hex=self.connection.get_session_id_hex()
+                )
+            return self._handle_staging_put_stream(
+                presigned_url=row.presignedUrl,
+                stream=self._input_stream_data,
+                headers=headers
+            )
+
+        # For non-streaming operations, validate staging_allowed_local_path
         if isinstance(staging_allowed_local_path, type(str())):
             _staging_allowed_local_paths = [staging_allowed_local_path]
         elif isinstance(staging_allowed_local_path, type(list())):
@@ -639,10 +667,6 @@ def _handle_staging_operation(
             os.path.abspath(i) for i in _staging_allowed_local_paths
         ]
 
-        assert self.active_result_set is not None
-        row = self.active_result_set.fetchone()
-        assert row is not None
-
         # Must set to None in cases where server response does not include localFile
         abs_localFile = None
 
@@ -665,15 +689,10 @@ def _handle_staging_operation(
                 session_id_hex=self.connection.get_session_id_hex(),
             )
 
-        # May be real headers, or could be json string
-        headers = (
-            json.loads(row.headers) if isinstance(row.headers, str) else row.headers
-        )
-
         handler_args = {
             "presigned_url": row.presignedUrl,
             "local_file": abs_localFile,
-            "headers": dict(headers) or {},
+            "headers": headers,
         }
 
         logger.debug(
@@ -696,6 +715,60 @@ def _handle_staging_operation(
                 session_id_hex=self.connection.get_session_id_hex(),
             )
 
+    @log_latency(StatementType.SQL)
+    def _handle_staging_put_stream(
+        self,
+        presigned_url: str,
+        stream: BinaryIO,
+        headers: Optional[dict] = None,
+    ) -> None:
+        """Handle PUT operation with streaming data.
+
+        Args:
+            presigned_url: The presigned URL for upload
+            stream: Binary stream to upload
+            headers: Optional HTTP headers
+
+        Raises:
+            OperationalError: If the upload fails
+        """
+
+        # Prepare headers
+        http_headers = dict(headers) if headers else {}
+
+        try:
+            # Stream directly to presigned URL
+            response = requests.put(
+                url=presigned_url,
+                data=stream,
+                headers=http_headers,
+                timeout=300  # 5 minute timeout
+            )
+
+            # Check response codes
+            OK = requests.codes.ok  # 200
+            CREATED = requests.codes.created  # 201
+            ACCEPTED = requests.codes.accepted  # 202
+            NO_CONTENT = requests.codes.no_content  # 204
+
+            if response.status_code not in [OK, CREATED, NO_CONTENT, ACCEPTED]:
+                raise OperationalError(
+                    f"Staging operation over HTTP was unsuccessful: {response.status_code}-{response.text}",
+                    session_id_hex=self.connection.get_session_id_hex()
+                )
+
+            if response.status_code == ACCEPTED:
+                logger.debug(
+                    f"Response code {ACCEPTED} from server indicates upload was accepted "
+                    "but not yet applied on the server. It's possible this command may fail later."
+                )
+
+        except requests.exceptions.RequestException as e:
+            raise OperationalError(
+                f"HTTP request failed during stream upload: {str(e)}",
+                session_id_hex=self.connection.get_session_id_hex()
+            ) from e
+
     @log_latency(StatementType.SQL)
     def _handle_staging_put(
         self, presigned_url: str, local_file: str, headers: Optional[dict] = None
@@ -783,6 +856,7 @@ def execute(
         self,
         operation: str,
         parameters: Optional[TParameterCollection] = None,
+        input_stream: Optional[BinaryIO] = None,
        enforce_embedded_schema_correctness=False,
     ) -> "Cursor":
         """
@@ -820,47 +894,60 @@ def execute(
         logger.debug(
             "Cursor.execute(operation=%s, parameters=%s)", operation, parameters
         )
+        try:
+            # Store stream data if provided
+            self._input_stream_data = None
+            if input_stream is not None:
+                # Validate stream has required methods
+                if not hasattr(input_stream, 'read'):
+                    raise TypeError(
+                        "input_stream must be a binary stream with read() method"
+                    )
+                self._input_stream_data = input_stream
 
-        param_approach = self._determine_parameter_approach(parameters)
-        if param_approach == ParameterApproach.NONE:
-            prepared_params = NO_NATIVE_PARAMS
-            prepared_operation = operation
+            param_approach = self._determine_parameter_approach(parameters)
+            if param_approach == ParameterApproach.NONE:
+                prepared_params = NO_NATIVE_PARAMS
+                prepared_operation = operation
 
-        elif param_approach == ParameterApproach.INLINE:
-            prepared_operation, prepared_params = self._prepare_inline_parameters(
-                operation, parameters
-            )
-        elif param_approach == ParameterApproach.NATIVE:
-            normalized_parameters = self._normalize_tparametercollection(parameters)
-            param_structure = self._determine_parameter_structure(normalized_parameters)
-            transformed_operation = transform_paramstyle(
-                operation, normalized_parameters, param_structure
-            )
-            prepared_operation, prepared_params = self._prepare_native_parameters(
-                transformed_operation, normalized_parameters, param_structure
-            )
-
-        self._check_not_closed()
-        self._close_and_clear_active_result_set()
-        self.active_result_set = self.backend.execute_command(
-            operation=prepared_operation,
-            session_id=self.connection.session.session_id,
-            max_rows=self.arraysize,
-            max_bytes=self.buffer_size_bytes,
-            lz4_compression=self.connection.lz4_compression,
-            cursor=self,
-            use_cloud_fetch=self.connection.use_cloud_fetch,
-            parameters=prepared_params,
-            async_op=False,
-            enforce_embedded_schema_correctness=enforce_embedded_schema_correctness,
-        )
+            elif param_approach == ParameterApproach.INLINE:
+                prepared_operation, prepared_params = self._prepare_inline_parameters(
+                    operation, parameters
+                )
+            elif param_approach == ParameterApproach.NATIVE:
+                normalized_parameters = self._normalize_tparametercollection(parameters)
+                param_structure = self._determine_parameter_structure(normalized_parameters)
+                transformed_operation = transform_paramstyle(
+                    operation, normalized_parameters, param_structure
+                )
+                prepared_operation, prepared_params = self._prepare_native_parameters(
+                    transformed_operation, normalized_parameters, param_structure
+                )
 
-        if self.active_result_set and self.active_result_set.is_staging_operation:
-            self._handle_staging_operation(
-                staging_allowed_local_path=self.connection.staging_allowed_local_path
+            self._check_not_closed()
+            self._close_and_clear_active_result_set()
+            self.active_result_set = self.backend.execute_command(
+                operation=prepared_operation,
+                session_id=self.connection.session.session_id,
+                max_rows=self.arraysize,
+                max_bytes=self.buffer_size_bytes,
+                lz4_compression=self.connection.lz4_compression,
+                cursor=self,
+                use_cloud_fetch=self.connection.use_cloud_fetch,
+                parameters=prepared_params,
+                async_op=False,
+                enforce_embedded_schema_correctness=enforce_embedded_schema_correctness,
             )
 
-        return self
+            if self.active_result_set and self.active_result_set.is_staging_operation:
+                self._handle_staging_operation(
+                    staging_allowed_local_path=self.connection.staging_allowed_local_path
+                )
+
+            return self
+        finally:
+            # Clean up stream data
+            self._input_stream_data = None
 
     @log_latency(StatementType.QUERY)
     def execute_async(
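
The execute() validation above only checks for a read() method, and requests.put(data=stream) consumes any file-like body, so data produced on the fly can be uploaded by wrapping a chunk iterator in a readable stream. A minimal sketch under those assumptions (my own illustration, not code from this commit):

    import io

    class IterStream(io.RawIOBase):
        """Hypothetical adapter: exposes an iterator of bytes chunks as a readable stream."""

        def __init__(self, chunks):
            self._chunks = iter(chunks)
            self._buf = b""

        def readable(self):
            return True

        def readinto(self, b):
            # Refill the buffer from the iterator, then copy as much as fits into b.
            while not self._buf:
                try:
                    self._buf = next(self._chunks)
                except StopIteration:
                    return 0  # EOF
            n = min(len(b), len(self._buf))
            b[:n] = self._buf[:n]
            self._buf = self._buf[n:]
            return n

    # io.BufferedReader supplies the read() method the cursor checks for:
    # stream = io.BufferedReader(IterStream(generate_chunks()))
    # cursor.execute("PUT '__input_stream__' INTO '/Volumes/...'", input_stream=stream)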
tests/e2e/common/streaming_put_tests.py

Lines changed: 51 additions & 0 deletions
#!/usr/bin/env python3
"""
E2E tests for streaming PUT operations.
"""

import io
import pytest
from datetime import datetime


class PySQLStreamingPutTestSuiteMixin:
    """Test suite for streaming PUT operations."""

    def test_streaming_put_basic(self, catalog, schema):
        """Test basic streaming PUT functionality."""

        # Create test data
        test_data = b"Hello, streaming world! This is test data."
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"stream_test_{timestamp}.txt"

        with self.connection() as conn:
            with conn.cursor() as cursor:
                with io.BytesIO(test_data) as stream:
                    cursor.execute(
                        f"PUT '__input_stream__' INTO '/Volumes/{catalog}/{schema}/e2etests/{filename}'",
                        input_stream=stream
                    )

                # Verify file exists
                cursor.execute(f"LIST '/Volumes/{catalog}/{schema}/e2etests/'")
                files = cursor.fetchall()

                # Check if our file is in the list
                file_paths = [row[0] for row in files]
                expected_path = f"/Volumes/{catalog}/{schema}/e2etests/{filename}"

                assert expected_path in file_paths, f"File {expected_path} not found in {file_paths}"


    def test_streaming_put_missing_stream(self, catalog, schema):
        """Test that missing stream raises appropriate error."""

        with self.connection() as conn:
            with conn.cursor() as cursor:
                # Test without providing stream
                with pytest.raises(Exception):  # Should fail
                    cursor.execute(
                        f"PUT '__input_stream__' INTO '/Volumes/{catalog}/{schema}/e2etests/test.txt'"
                        # Note: No input_stream parameter
                    )
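
The basic test leaves the uploaded file behind in the volume. If cleanup were wanted, REMOVE (the staging operation the connector exposes alongside PUT and GET) could be issued after the assertion; a hedged one-liner, not part of this commit:

    # Hypothetical cleanup of the file created by the test
    cursor.execute(f"REMOVE '/Volumes/{catalog}/{schema}/e2etests/{filename}'")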

tests/e2e/test_driver.py

Lines changed: 2 additions & 1 deletion
@@ -47,8 +47,8 @@
 )
 from tests.e2e.common.staging_ingestion_tests import PySQLStagingIngestionTestSuiteMixin
 from tests.e2e.common.retry_test_mixins import PySQLRetryTestsMixin
-
 from tests.e2e.common.uc_volume_tests import PySQLUCVolumeTestSuiteMixin
+from tests.e2e.common.streaming_put_tests import PySQLStreamingPutTestSuiteMixin
 
 from databricks.sql.exc import SessionAlreadyClosedError
 
@@ -256,6 +256,7 @@ class TestPySQLCoreSuite(
     PySQLStagingIngestionTestSuiteMixin,
     PySQLRetryTestsMixin,
     PySQLUCVolumeTestSuiteMixin,
+    PySQLStreamingPutTestSuiteMixin,
 ):
     validate_row_value_type = True
     validate_result = True
