Skip to content

Commit 1563a18

Browse files
committed
Address review comments
Signed-off-by: Sreekanth Vadigi <[email protected]>
1 parent 8ae220c commit 1563a18

File tree

4 files changed

+213
-264
lines changed

4 files changed

+213
-264
lines changed

examples/streaming_put.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -9,28 +9,26 @@
99
import os
1010
from databricks import sql
1111

12-
def main():
13-
"""Simple streaming PUT example."""
14-
15-
# Connect to Databricks
16-
connection = sql.connect(
17-
server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"),
18-
http_path=os.getenv("DATABRICKS_HTTP_PATH"),
19-
access_token=os.getenv("DATABRICKS_TOKEN"),
20-
)
21-
12+
with sql.connect(
13+
server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"),
14+
http_path=os.getenv("DATABRICKS_HTTP_PATH"),
15+
access_token=os.getenv("DATABRICKS_TOKEN"),
16+
) as connection:
17+
2218
with connection.cursor() as cursor:
2319
# Create a simple data stream
2420
data = b"Hello, streaming world!"
2521
stream = io.BytesIO(data)
2622

23+
# Get catalog, schema, and volume from environment variables
24+
catalog = os.getenv("DATABRICKS_CATALOG")
25+
schema = os.getenv("DATABRICKS_SCHEMA")
26+
volume = os.getenv("DATABRICKS_VOLUME")
27+
2728
# Upload to Unity Catalog volume
2829
cursor.execute(
29-
"PUT '__input_stream__' INTO '/Volumes/my_catalog/my_schema/my_volume/hello.txt'",
30+
f"PUT '__input_stream__' INTO '/Volumes/{catalog}/{schema}/{volume}/hello.txt' OVERWRITE",
3031
input_stream=stream
3132
)
3233

33-
print("File uploaded successfully!")
34-
35-
if __name__ == "__main__":
36-
main()
34+
print("File uploaded successfully!")

src/databricks/sql/client.py

Lines changed: 83 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
)
6868
from databricks.sql.telemetry.latency_logger import log_latency
6969
from databricks.sql.telemetry.models.enums import StatementType
70+
from databricks.sql.common.http import DatabricksHttpClient, HttpMethod
7071

7172
logger = logging.getLogger(__name__)
7273

@@ -455,7 +456,6 @@ def __init__(
455456
self.active_command_id = None
456457
self.escaper = ParamEscaper()
457458
self.lastrowid = None
458-
self._input_stream_data: Optional[BinaryIO] = None
459459

460460
self.ASYNC_DEFAULT_POLLING_INTERVAL = 2
461461

@@ -616,8 +616,29 @@ def _check_not_closed(self):
616616
session_id_hex=self.connection.get_session_id_hex(),
617617
)
618618

619+
def _validate_staging_http_response(self, response: requests.Response, operation_name: str = "staging operation") -> None:
620+
621+
# Check response codes
622+
OK = requests.codes.ok # 200
623+
CREATED = requests.codes.created # 201
624+
ACCEPTED = requests.codes.accepted # 202
625+
NO_CONTENT = requests.codes.no_content # 204
626+
627+
if response.status_code not in [OK, CREATED, NO_CONTENT, ACCEPTED]:
628+
raise OperationalError(
629+
f"{operation_name} over HTTP was unsuccessful: {response.status_code}-{response.text}",
630+
session_id_hex=self.connection.get_session_id_hex(),
631+
)
632+
633+
if response.status_code == ACCEPTED:
634+
logger.debug(
635+
"Response code %s from server indicates %s was accepted "
636+
"but not yet applied on the server. It's possible this command may fail later.",
637+
ACCEPTED, operation_name
638+
)
639+
619640
def _handle_staging_operation(
620-
self, staging_allowed_local_path: Union[None, str, List[str]]
641+
self, staging_allowed_local_path: Union[None, str, List[str]], input_stream: Optional[BinaryIO] = None
621642
):
622643
"""Fetch the HTTP request instruction from a staging ingestion command
623644
and call the designated handler.
@@ -641,14 +662,9 @@ def _handle_staging_operation(
641662
row.operation == "PUT"
642663
and getattr(row, "localFile", None) == "__input_stream__"
643664
):
644-
if not self._input_stream_data:
645-
raise ProgrammingError(
646-
"No input stream provided for streaming operation",
647-
session_id_hex=self.connection.get_session_id_hex(),
648-
)
649665
return self._handle_staging_put_stream(
650666
presigned_url=row.presignedUrl,
651-
stream=self._input_stream_data,
667+
stream=input_stream,
652668
headers=headers,
653669
)
654670

@@ -696,7 +712,8 @@ def _handle_staging_operation(
696712
}
697713

698714
logger.debug(
699-
f"Attempting staging operation indicated by server: {row.operation} - {getattr(row, 'localFile', '')}"
715+
"Attempting staging operation indicated by server: %s - %s",
716+
row.operation, getattr(row, 'localFile', '')
700717
)
701718

702719
# TODO: Create a retry loop here to re-attempt if the request times out or fails
@@ -720,54 +737,37 @@ def _handle_staging_put_stream(
720737
self,
721738
presigned_url: str,
722739
stream: BinaryIO,
723-
headers: Optional[dict] = None,
740+
headers: dict = {},
724741
) -> None:
725742
"""Handle PUT operation with streaming data.
726743
727744
Args:
728745
presigned_url: The presigned URL for upload
729746
stream: Binary stream to upload
730-
headers: Optional HTTP headers
747+
headers: HTTP headers
731748
732749
Raises:
750+
ProgrammingError: If no input stream is provided
733751
OperationalError: If the upload fails
734752
"""
735753

736-
# Prepare headers
737-
http_headers = dict(headers) if headers else {}
738-
739-
try:
740-
# Stream directly to presigned URL
741-
response = requests.put(
742-
url=presigned_url,
743-
data=stream,
744-
headers=http_headers,
745-
timeout=300, # 5 minute timeout
754+
if not stream:
755+
raise ProgrammingError(
756+
"No input stream provided for streaming operation",
757+
session_id_hex=self.connection.get_session_id_hex(),
746758
)
747759

748-
# Check response codes
749-
OK = requests.codes.ok # 200
750-
CREATED = requests.codes.created # 201
751-
ACCEPTED = requests.codes.accepted # 202
752-
NO_CONTENT = requests.codes.no_content # 204
760+
http_client = DatabricksHttpClient.get_instance()
753761

754-
if response.status_code not in [OK, CREATED, NO_CONTENT, ACCEPTED]:
755-
raise OperationalError(
756-
f"Staging operation over HTTP was unsuccessful: {response.status_code}-{response.text}",
757-
session_id_hex=self.connection.get_session_id_hex(),
758-
)
759-
760-
if response.status_code == ACCEPTED:
761-
logger.debug(
762-
f"Response code {ACCEPTED} from server indicates upload was accepted "
763-
"but not yet applied on the server. It's possible this command may fail later."
764-
)
765-
766-
except requests.exceptions.RequestException as e:
767-
raise OperationalError(
768-
f"HTTP request failed during stream upload: {str(e)}",
769-
session_id_hex=self.connection.get_session_id_hex(),
770-
) from e
762+
# Stream directly to presigned URL
763+
with http_client.execute(
764+
method=HttpMethod.PUT,
765+
url=presigned_url,
766+
data=stream,
767+
headers=headers,
768+
timeout=300, # 5 minute timeout
769+
) as response:
770+
self._validate_staging_http_response(response, "stream upload")
771771

772772
@log_latency(StatementType.SQL)
773773
def _handle_staging_put(
@@ -787,27 +787,7 @@ def _handle_staging_put(
787787
with open(local_file, "rb") as fh:
788788
r = requests.put(url=presigned_url, data=fh, headers=headers)
789789

790-
# fmt: off
791-
# Design borrowed from: https://stackoverflow.com/a/2342589/5093960
792-
793-
OK = requests.codes.ok # 200
794-
CREATED = requests.codes.created # 201
795-
ACCEPTED = requests.codes.accepted # 202
796-
NO_CONTENT = requests.codes.no_content # 204
797-
798-
# fmt: on
799-
800-
if r.status_code not in [OK, CREATED, NO_CONTENT, ACCEPTED]:
801-
raise OperationalError(
802-
f"Staging operation over HTTP was unsuccessful: {r.status_code}-{r.text}",
803-
session_id_hex=self.connection.get_session_id_hex(),
804-
)
805-
806-
if r.status_code == ACCEPTED:
807-
logger.debug(
808-
f"Response code {ACCEPTED} from server indicates ingestion command was accepted "
809-
+ "but not yet applied on the server. It's possible this command may fail later."
810-
)
790+
self._validate_staging_http_response(r, "file upload")
811791

812792
@log_latency(StatementType.SQL)
813793
def _handle_staging_get(
@@ -856,8 +836,8 @@ def execute(
856836
self,
857837
operation: str,
858838
parameters: Optional[TParameterCollection] = None,
859-
input_stream: Optional[BinaryIO] = None,
860839
enforce_embedded_schema_correctness=False,
840+
input_stream: Optional[BinaryIO] = None,
861841
) -> "Cursor":
862842
"""
863843
Execute a query and wait for execution to complete.
@@ -894,62 +874,49 @@ def execute(
894874
logger.debug(
895875
"Cursor.execute(operation=%s, parameters=%s)", operation, parameters
896876
)
897-
try:
898-
# Store stream data if provided
899-
self._input_stream_data = None
900-
if input_stream is not None:
901-
# Validate stream has required methods
902-
if not hasattr(input_stream, "read"):
903-
raise TypeError(
904-
"input_stream must be a binary stream with read() method"
905-
)
906-
self._input_stream_data = input_stream
877+
param_approach = self._determine_parameter_approach(parameters)
878+
if param_approach == ParameterApproach.NONE:
879+
prepared_params = NO_NATIVE_PARAMS
880+
prepared_operation = operation
907881

908-
param_approach = self._determine_parameter_approach(parameters)
909-
if param_approach == ParameterApproach.NONE:
910-
prepared_params = NO_NATIVE_PARAMS
911-
prepared_operation = operation
882+
elif param_approach == ParameterApproach.INLINE:
883+
prepared_operation, prepared_params = self._prepare_inline_parameters(
884+
operation, parameters
885+
)
886+
elif param_approach == ParameterApproach.NATIVE:
887+
normalized_parameters = self._normalize_tparametercollection(parameters)
888+
param_structure = self._determine_parameter_structure(
889+
normalized_parameters
890+
)
891+
transformed_operation = transform_paramstyle(
892+
operation, normalized_parameters, param_structure
893+
)
894+
prepared_operation, prepared_params = self._prepare_native_parameters(
895+
transformed_operation, normalized_parameters, param_structure
896+
)
912897

913-
elif param_approach == ParameterApproach.INLINE:
914-
prepared_operation, prepared_params = self._prepare_inline_parameters(
915-
operation, parameters
916-
)
917-
elif param_approach == ParameterApproach.NATIVE:
918-
normalized_parameters = self._normalize_tparametercollection(parameters)
919-
param_structure = self._determine_parameter_structure(
920-
normalized_parameters
921-
)
922-
transformed_operation = transform_paramstyle(
923-
operation, normalized_parameters, param_structure
924-
)
925-
prepared_operation, prepared_params = self._prepare_native_parameters(
926-
transformed_operation, normalized_parameters, param_structure
927-
)
898+
self._check_not_closed()
899+
self._close_and_clear_active_result_set()
900+
self.active_result_set = self.backend.execute_command(
901+
operation=prepared_operation,
902+
session_id=self.connection.session.session_id,
903+
max_rows=self.arraysize,
904+
max_bytes=self.buffer_size_bytes,
905+
lz4_compression=self.connection.lz4_compression,
906+
cursor=self,
907+
use_cloud_fetch=self.connection.use_cloud_fetch,
908+
parameters=prepared_params,
909+
async_op=False,
910+
enforce_embedded_schema_correctness=enforce_embedded_schema_correctness,
911+
)
928912

929-
self._check_not_closed()
930-
self._close_and_clear_active_result_set()
931-
self.active_result_set = self.backend.execute_command(
932-
operation=prepared_operation,
933-
session_id=self.connection.session.session_id,
934-
max_rows=self.arraysize,
935-
max_bytes=self.buffer_size_bytes,
936-
lz4_compression=self.connection.lz4_compression,
937-
cursor=self,
938-
use_cloud_fetch=self.connection.use_cloud_fetch,
939-
parameters=prepared_params,
940-
async_op=False,
941-
enforce_embedded_schema_correctness=enforce_embedded_schema_correctness,
913+
if self.active_result_set and self.active_result_set.is_staging_operation:
914+
self._handle_staging_operation(
915+
staging_allowed_local_path=self.connection.staging_allowed_local_path,
916+
input_stream=input_stream
942917
)
943918

944-
if self.active_result_set and self.active_result_set.is_staging_operation:
945-
self._handle_staging_operation(
946-
staging_allowed_local_path=self.connection.staging_allowed_local_path
947-
)
948-
949-
return self
950-
finally:
951-
# Clean up stream data
952-
self._input_stream_data = None
919+
return self
953920

954921
@log_latency(StatementType.QUERY)
955922
def execute_async(

0 commit comments

Comments (0)