[SPARK-42437][CONNECT][PYTHON][FOLLOW-UP] Storage level proto converters
### What changes were proposed in this pull request?
Adds converters between the `StorageLevel` protobuf message and `pyspark.storagelevel.StorageLevel`, so the conversion logic lives in one place instead of being repeated at each call site.
This is a follow-up to apache#40015.
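
For illustration, a minimal round-trip sketch of the two new helpers; the field values below are arbitrary and any `StorageLevel` works:

```python
from pyspark.storagelevel import StorageLevel
from pyspark.sql.connect.conversion import storage_level_to_proto, proto_to_storage_level

# An arbitrary storage level; the values here are illustrative only.
level = StorageLevel(
    useDisk=True, useMemory=True, useOffHeap=False, deserialized=False, replication=2
)

message = storage_level_to_proto(level)     # pyspark.sql.connect.proto.StorageLevel
restored = proto_to_storage_level(message)  # back to pyspark.storagelevel.StorageLevel

# Every field survives the round trip.
assert (restored.useDisk, restored.useMemory, restored.useOffHeap,
        restored.deserialized, restored.replication) == (True, True, False, False, 2)
```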

### Why are the changes needed?
Code deduplication: three call sites previously built the same conversion inline.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Existing tests

Closes apache#40859 from khalidmammadov/storage_level_converter.

Authored-by: khalidmammadov <[email protected]>
Signed-off-by: Takuya UESHIN <[email protected]>
khalidmammadov authored and ueshin committed Apr 19, 2023
1 parent 77b72fc commit c291564
Showing 3 changed files with 29 additions and 25 deletions.

python/pyspark/sql/connect/client.py (3 additions, 16 deletions):

```diff
@@ -60,6 +60,7 @@
 from google.rpc import error_details_pb2
 
 from pyspark.resource.information import ResourceInformation
+from pyspark.sql.connect.conversion import storage_level_to_proto, proto_to_storage_level
 import pyspark.sql.connect.proto as pb2
 import pyspark.sql.connect.proto.base_pb2_grpc as grpc_lib
 import pyspark.sql.connect.types as types
@@ -469,13 +470,7 @@ def fromProto(cls, pb: Any) -> "AnalyzeResult":
         elif pb.HasField("unpersist"):
             pass
         elif pb.HasField("get_storage_level"):
-            storage_level = StorageLevel(
-                useDisk=pb.get_storage_level.storage_level.use_disk,
-                useMemory=pb.get_storage_level.storage_level.use_memory,
-                useOffHeap=pb.get_storage_level.storage_level.use_off_heap,
-                deserialized=pb.get_storage_level.storage_level.deserialized,
-                replication=pb.get_storage_level.storage_level.replication,
-            )
+            storage_level = proto_to_storage_level(pb.get_storage_level.storage_level)
         else:
             raise SparkConnectException("No analyze result found!")
 
@@ -877,15 +872,7 @@ def _analyze(self, method: str, **kwargs: Any) -> AnalyzeResult:
             req.persist.relation.CopyFrom(cast(pb2.Relation, kwargs.get("relation")))
             if kwargs.get("storage_level", None) is not None:
                 storage_level = cast(StorageLevel, kwargs.get("storage_level"))
-                req.persist.storage_level.CopyFrom(
-                    pb2.StorageLevel(
-                        use_disk=storage_level.useDisk,
-                        use_memory=storage_level.useMemory,
-                        use_off_heap=storage_level.useOffHeap,
-                        deserialized=storage_level.deserialized,
-                        replication=storage_level.replication,
-                    )
-                )
+                req.persist.storage_level.CopyFrom(storage_level_to_proto(storage_level))
         elif method == "unpersist":
             req.unpersist.relation.CopyFrom(cast(pb2.Relation, kwargs.get("relation")))
             if kwargs.get("blocking", None) is not None:
```
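
For context, a rough sketch of the DataFrame-level path that exercises both branches above, assuming `spark` is a Spark Connect `SparkSession` and the `persist()`/`storageLevel` API surface from the parent change:

```python
from pyspark.storagelevel import StorageLevel

df = spark.range(100)  # `spark` assumed to be a Spark Connect session

# persist() sends a "persist" analyze request; the client now encodes the
# level with storage_level_to_proto(...) instead of building the message inline.
df.persist(StorageLevel.MEMORY_AND_DISK)

# The level comes back in a "get_storage_level" analyze response, which
# fromProto() now decodes with proto_to_storage_level(...).
print(df.storageLevel)
```
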
python/pyspark/sql/connect/conversion.py (24 additions, 0 deletions):

```diff
@@ -43,7 +43,9 @@
     cast,
 )
 
+from pyspark.storagelevel import StorageLevel
 from pyspark.sql.connect.types import to_arrow_schema
+import pyspark.sql.connect.proto as pb2
 
 from typing import (
     Any,
@@ -486,3 +488,25 @@ def convert(table: "pa.Table", schema: StructType) -> List[Row]:
             values = [field_converters[j](columnar_data[j][i]) for j in range(table.num_columns)]
             rows.append(_create_row(fields=schema.fieldNames(), values=values))
         return rows
+
+
+def storage_level_to_proto(storage_level: StorageLevel) -> pb2.StorageLevel:
+    assert storage_level is not None and isinstance(storage_level, StorageLevel)
+    return pb2.StorageLevel(
+        use_disk=storage_level.useDisk,
+        use_memory=storage_level.useMemory,
+        use_off_heap=storage_level.useOffHeap,
+        deserialized=storage_level.deserialized,
+        replication=storage_level.replication,
+    )
+
+
+def proto_to_storage_level(storage_level: pb2.StorageLevel) -> StorageLevel:
+    assert storage_level is not None and isinstance(storage_level, pb2.StorageLevel)
+    return StorageLevel(
+        useDisk=storage_level.use_disk,
+        useMemory=storage_level.use_memory,
+        useOffHeap=storage_level.use_off_heap,
+        deserialized=storage_level.deserialized,
+        replication=storage_level.replication,
+    )
```

python/pyspark/sql/connect/plan.py (2 additions, 9 deletions):

```diff
@@ -30,6 +30,7 @@
 from pyspark.sql.types import DataType
 
 import pyspark.sql.connect.proto as proto
+from pyspark.sql.connect.conversion import storage_level_to_proto
 from pyspark.sql.connect.column import Column
 from pyspark.sql.connect.expressions import (
     SortOrder,
@@ -1896,15 +1897,7 @@ def __init__(self, table_name: str, storage_level: Optional[StorageLevel] = None
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         _cache_table = proto.CacheTable(table_name=self._table_name)
         if self._storage_level:
-            _cache_table.storage_level.CopyFrom(
-                proto.StorageLevel(
-                    use_disk=self._storage_level.useDisk,
-                    use_memory=self._storage_level.useMemory,
-                    use_off_heap=self._storage_level.useOffHeap,
-                    deserialized=self._storage_level.deserialized,
-                    replication=self._storage_level.replication,
-                )
-            )
+            _cache_table.storage_level.CopyFrom(storage_level_to_proto(self._storage_level))
         plan = proto.Relation(catalog=proto.Catalog(cache_table=_cache_table))
         return plan
 
```
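
And a sketch of the catalog path that reaches this `CacheTable` plan, assuming the `storageLevel` parameter that the parent change added to `cacheTable` and a hypothetical existing table name:

```python
from pyspark.storagelevel import StorageLevel

# "people" is a hypothetical existing table; the storageLevel argument is
# what ends up in CacheTable.storage_level via storage_level_to_proto(...).
spark.catalog.cacheTable("people", storageLevel=StorageLevel.DISK_ONLY)
```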
