diff --git a/wherobots/db/connection.py b/wherobots/db/connection.py index a81ce77..4624334 100644 --- a/wherobots/db/connection.py +++ b/wherobots/db/connection.py @@ -22,7 +22,7 @@ GeometryRepresentation, ) from wherobots.db.cursor import Cursor -from wherobots.db.errors import NotSupportedError, OperationalError +from wherobots.db.errors import OperationalError @dataclass @@ -78,10 +78,10 @@ def close(self): self.__ws.close() def commit(self): - raise NotSupportedError + pass def rollback(self): - raise NotSupportedError + pass def cursor(self) -> Cursor: return Cursor(self.__execute_sql, self.__cancel_query) @@ -155,12 +155,20 @@ def __listen(self): query.state = ExecutionState.COMPLETED if result_format == ResultsFormat.JSON: - query.handler(json.loads(result_bytes.decode("utf-8"))) + data = json.loads(result_bytes.decode("utf-8")) + columns = data["columns"] + column_types = data.get("column_types") + rows = data["rows"] + query.handler((columns, column_types, rows)) elif result_format == ResultsFormat.ARROW: buffer = pyarrow.py_buffer(result_bytes) stream = pyarrow.input_stream(buffer, result_compression) with pyarrow.ipc.open_stream(stream) as reader: - query.handler(reader.read_pandas()) + schema = reader.schema + columns = schema.names + column_types = [field.type for field in schema] + rows = reader.read_pandas().values.tolist() + query.handler((columns, column_types, rows)) else: query.handler( OperationalError(f"Unsupported results format {result_format}") diff --git a/wherobots/db/cursor.py b/wherobots/db/cursor.py index 1316e5b..285bc6f 100644 --- a/wherobots/db/cursor.py +++ b/wherobots/db/cursor.py @@ -5,12 +5,19 @@ _TYPE_MAP = { "object": "STRING", + "string": "STRING", + "int32": "NUMBER", "int64": "NUMBER", + "float32": "NUMBER", "float64": "NUMBER", "datetime64[ns]": "DATETIME", "timedelta[ns]": "DATETIME", + "double": "NUMBER", "bool": "NUMBER", # Assuming boolean is stored as number "bytes": "BINARY", + "struct": "STRUCT", + "list": "LIST", + "geometry": "GEOMETRY", } @@ -54,20 +61,21 @@ def __get_results(self) -> Optional[List[Tuple[Any, ...]]]: if isinstance(result, DatabaseError): raise result - self.__rowcount = len(result) - self.__results = result - if not result.empty: + columns, column_types, rows = result + self.__rowcount = len(rows) + self.__results = rows + if rows: self.__description = [ ( col_name, # name - _TYPE_MAP.get(str(result[col_name].dtype), "STRING"), # type_code + _TYPE_MAP.get(str(column_types[i]), "STRING"), # type_code None, # display_size - result[col_name].memory_usage(), # internal_size + None, # internal_size None, # precision None, # scale True, # null_ok; Assuming all columns can accept NULL values ) - for col_name in result.columns + for i, col_name in enumerate(columns) ] return self.__results diff --git a/wherobots/db/driver.py b/wherobots/db/driver.py index eeec6f5..e5f441d 100644 --- a/wherobots/db/driver.py +++ b/wherobots/db/driver.py @@ -4,7 +4,6 @@ """ import logging -import os import urllib.parse import queue import requests @@ -51,6 +50,7 @@ def connect( results_format: Union[ResultsFormat, None] = None, data_compression: Union[DataCompression, None] = None, geometry_representation: Union[GeometryRepresentation, None] = None, + ws_url: str = None, ) -> Connection: if not token and not api_key: raise ValueError("At least one of `token` or `api_key` is required") @@ -67,6 +67,24 @@ def connect( runtime = runtime or DEFAULT_RUNTIME region = region or DEFAULT_REGION + if ws_url: + logging.info( + "Using existing %s/%s runtime in %s from %s ...", + runtime.name, + runtime.value, + region.value, + host, + ) + session_uri = ws_url + return connect_direct( + uri=http_to_ws(session_uri), + headers=headers, + read_timeout=read_timeout, + results_format=results_format, + data_compression=data_compression, + geometry_representation=geometry_representation, + ) + logging.info( "Requesting %s/%s runtime in %s from %s ...", runtime.name,