Skip to content

Commit

Permalink
feat(python): unify type system between python and xlang serializatio…
Browse files Browse the repository at this point in the history
…n in pyfury (#2034)

## What does this PR do?

This pr unifies type system between python and xlang serialization in
pyfury, so that we can remove duplicate code in pyfury, and lay the
foundation for following features:
- [ ] implement protocol such as chunk based map serialization for
pythobn and xlang serialization
- [ ] use python exsiting fastpath optimization in xlang serialization
- [ ] extend `DataClassSerializer` to support codegen based
serialization for xlang

## Related issues
#1690

## Does this PR introduce any user-facing change?

<!--
If any user-facing interface changes, please [open an
issue](https://github.com/apache/fury/issues/new/choose) describing the
need to do so and update the document if necessary.
-->

- [ ] Does this PR introduce any public API change?
- [ ] Does this PR introduce any binary protocol compatibility change?

## Benchmark

<!--
When the PR has an impact on performance (if you don't know whether the
PR will have an impact on performance, you can submit the PR first, and
if it will have impact on performance, the code reviewer will explain
it), be sure to attach a benchmark data here.
-->
  • Loading branch information
chaokunyang authored Jan 31, 2025
1 parent c5ef8ba commit b4f5a2a
Show file tree
Hide file tree
Showing 7 changed files with 168 additions and 313 deletions.
45 changes: 23 additions & 22 deletions python/pyfury/_fury.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
NOT_NULL_VALUE_FLAG,
)
from pyfury.util import is_little_endian, set_bit, get_bit, clear_bit
from pyfury.type import TypeId

try:
import numpy as np
Expand All @@ -44,24 +45,21 @@


MAGIC_NUMBER = 0x62D4
DEFAULT_DYNAMIC_WRITE_STRING_ID = -1
DEFAULT_DYNAMIC_WRITE_META_STR_ID = -1
DYNAMIC_TYPE_ID = -1
USE_CLASSNAME = 0
USE_CLASS_ID = 1
# preserve 0 as flag for class id not set in ClassInfo`
NO_CLASS_ID = 0
PYINT_CLASS_ID = 1
PYFLOAT_CLASS_ID = 2
PYBOOL_CLASS_ID = 3
STRING_CLASS_ID = 4
PICKLE_CLASS_ID = 5
PICKLE_STRONG_CACHE_CLASS_ID = 6
PICKLE_CACHE_CLASS_ID = 7
INT64_CLASS_ID = TypeId.INT64
FLOAT64_CLASS_ID = TypeId.FLOAT64
BOOL_CLASS_ID = TypeId.BOOL
STRING_CLASS_ID = TypeId.STRING
# `NOT_NULL_VALUE_FLAG` + `CLASS_ID << 1` in little-endian order
NOT_NULL_PYINT_FLAG = NOT_NULL_VALUE_FLAG & 0b11111111 | (PYINT_CLASS_ID << 9)
NOT_NULL_PYFLOAT_FLAG = NOT_NULL_VALUE_FLAG & 0b11111111 | (PYFLOAT_CLASS_ID << 9)
NOT_NULL_PYBOOL_FLAG = NOT_NULL_VALUE_FLAG & 0b11111111 | (PYBOOL_CLASS_ID << 9)
NOT_NULL_STRING_FLAG = NOT_NULL_VALUE_FLAG & 0b11111111 | (STRING_CLASS_ID << 9)
NOT_NULL_INT64_FLAG = NOT_NULL_VALUE_FLAG & 0b11111111 | (INT64_CLASS_ID << 8)
NOT_NULL_FLOAT64_FLAG = NOT_NULL_VALUE_FLAG & 0b11111111 | (FLOAT64_CLASS_ID << 8)
NOT_NULL_BOOL_FLAG = NOT_NULL_VALUE_FLAG & 0b11111111 | (BOOL_CLASS_ID << 8)
NOT_NULL_STRING_FLAG = NOT_NULL_VALUE_FLAG & 0b11111111 | (STRING_CLASS_ID << 8)
SMALL_STRING_THRESHOLD = 16


Expand Down Expand Up @@ -156,7 +154,7 @@ def __init__(
stacklevel=2,
)
self.pickler = Pickler(self.buffer)
self.unpickler = Unpickler(self.buffer)
self.unpickler = None
else:
self.pickler = _PicklerStub()
self.unpickler = _UnpicklerStub()
Expand Down Expand Up @@ -263,32 +261,32 @@ def serialize_ref(self, buffer, obj, classinfo=None):
buffer.write_string(obj)
return
elif cls is int:
buffer.write_int16(NOT_NULL_PYINT_FLAG)
buffer.write_int16(NOT_NULL_INT64_FLAG)
buffer.write_varint64(obj)
return
elif cls is bool:
buffer.write_int16(NOT_NULL_PYBOOL_FLAG)
buffer.write_int16(NOT_NULL_BOOL_FLAG)
buffer.write_bool(obj)
return
if self.ref_resolver.write_ref_or_null(buffer, obj):
return
if classinfo is None:
classinfo = self.class_resolver.get_classinfo(cls)
self.class_resolver.write_classinfo(buffer, classinfo)
self.class_resolver.write_typeinfo(buffer, classinfo)
classinfo.serializer.write(buffer, obj)

def serialize_nonref(self, buffer, obj):
cls = type(obj)
if cls is str:
buffer.write_varuint32(STRING_CLASS_ID << 1)
buffer.write_varuint32(STRING_CLASS_ID)
buffer.write_string(obj)
return
elif cls is int:
buffer.write_varuint32(PYINT_CLASS_ID << 1)
buffer.write_varuint32(INT64_CLASS_ID)
buffer.write_varint64(obj)
return
elif cls is bool:
buffer.write_varuint32(PYBOOL_CLASS_ID << 1)
buffer.write_varuint32(BOOL_CLASS_ID)
buffer.write_bool(obj)
return
else:
Expand Down Expand Up @@ -380,7 +378,7 @@ def deserialize_ref(self, buffer):
ref_id = ref_resolver.try_preserve_ref_id(buffer)
# indicates that the object is first read.
if ref_id >= NOT_NULL_VALUE_FLAG:
classinfo = self.class_resolver.read_classinfo(buffer)
classinfo = self.class_resolver.read_typeinfo(buffer)
o = classinfo.serializer.read(buffer)
ref_resolver.set_read_object(ref_id, o)
return o
Expand All @@ -389,7 +387,7 @@ def deserialize_ref(self, buffer):

def deserialize_nonref(self, buffer):
"""Deserialize not-null and non-reference object from buffer."""
classinfo = self.class_resolver.read_classinfo(buffer)
classinfo = self.class_resolver.read_typeinfo(buffer)
return classinfo.serializer.read(buffer)

def xdeserialize_ref(self, buffer, serializer=None):
Expand Down Expand Up @@ -448,7 +446,10 @@ def handle_unsupported_write(self, buffer, obj):
def handle_unsupported_read(self, buffer):
in_band = buffer.read_bool()
if in_band:
return self.unpickler.load()
unpickler = self.unpickler
if unpickler is None:
self.unpickler = unpickler = Unpickler(buffer)
return unpickler.load()
else:
assert self._unsupported_objects is not None
return next(self._unsupported_objects)
Expand Down
116 changes: 31 additions & 85 deletions python/pyfury/_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,8 @@
Int16Serializer,
Int32Serializer,
Int64Serializer,
DynamicIntSerializer,
FloatSerializer,
DoubleSerializer,
DynamicFloatSerializer,
Float32Serializer,
Float64Serializer,
StringSerializer,
DateSerializer,
TimestampSerializer,
Expand All @@ -60,9 +58,9 @@
PickleCacheSerializer,
PickleStrongCacheSerializer,
PickleSerializer,
DataClassSerializer,
)
from pyfury._struct import ComplexObjectSerializer
from pyfury.buffer import Buffer
from pyfury.meta.metastring import MetaStringEncoder, MetaStringDecoder
from pyfury.type import (
TypeId,
Expand All @@ -78,13 +76,6 @@
DYNAMIC_TYPE_ID,
# preserve 0 as flag for class id not set in ClassInfo`
NO_CLASS_ID,
PYINT_CLASS_ID,
PYFLOAT_CLASS_ID,
PYBOOL_CLASS_ID,
STRING_CLASS_ID,
PICKLE_CLASS_ID,
PICKLE_STRONG_CACHE_CLASS_ID,
PICKLE_CACHE_CLASS_ID,
)

try:
Expand Down Expand Up @@ -168,7 +159,7 @@ def __init__(self, fury):
self._hash_to_classinfo = dict()
self._dynamic_written_metastr = []
self._type_id_to_classinfo = dict()
self._type_id_counter = PICKLE_CACHE_CLASS_ID + 1
self._type_id_counter = 64
self._dynamic_write_string_id = 0
# hold objects to avoid gc, since `flat_hash_map/vector` doesn't
# hold python reference.
Expand All @@ -181,60 +172,30 @@ def __init__(self, fury):
self.typename_decoder = MetaStringDecoder("$", "_")

def initialize(self):
self._initialize_xlang()
if self.fury.language == Language.PYTHON:
self._initialize_py()
else:
self._initialize_xlang()

def _initialize_py(self):
register = functools.partial(self._register_type, internal=True)
register(int, type_id=PYINT_CLASS_ID, serializer=Int64Serializer)
register(float, type_id=PYFLOAT_CLASS_ID, serializer=DoubleSerializer)
register(bool, type_id=PYBOOL_CLASS_ID, serializer=BooleanSerializer)
register(str, type_id=STRING_CLASS_ID, serializer=StringSerializer)
register(_PickleStub, type_id=PICKLE_CLASS_ID, serializer=PickleSerializer)
register(
_PickleStub,
type_id=PickleSerializer.PICKLE_CLASS_ID,
serializer=PickleSerializer,
)
register(
PickleStrongCacheStub,
type_id=PICKLE_STRONG_CACHE_CLASS_ID,
type_id=97,
serializer=PickleStrongCacheSerializer(self.fury),
)
register(
PickleCacheStub,
type_id=PICKLE_CACHE_CLASS_ID,
type_id=98,
serializer=PickleCacheSerializer(self.fury),
)
register(type(None), serializer=NoneSerializer)
register(Int8Type, serializer=ByteSerializer)
register(Int16Type, serializer=Int16Serializer)
register(Int32Type, serializer=Int32Serializer)
register(Int64Type, serializer=Int64Serializer)
register(Float32Type, serializer=FloatSerializer)
register(Float64Type, serializer=DoubleSerializer)
register(datetime.date, serializer=DateSerializer)
register(datetime.datetime, serializer=TimestampSerializer)
register(bytes, serializer=BytesSerializer)
register(list, serializer=ListSerializer)
register(tuple, serializer=TupleSerializer)
register(dict, serializer=MapSerializer)
register(set, serializer=SetSerializer)
register(enum.Enum, serializer=EnumSerializer)
register(slice, serializer=SliceSerializer)
try:
import pyarrow as pa
from pyfury.format.serializer import (
ArrowRecordBatchSerializer,
ArrowTableSerializer,
)

register(pa.RecordBatch, serializer=ArrowRecordBatchSerializer)
register(pa.Table, serializer=ArrowTableSerializer)
except Exception:
pass
for size, ftype, type_id in PyArraySerializer.typecode_dict.values():
register(ftype, serializer=PyArraySerializer(self.fury, ftype, type_id))
register(array.array, serializer=DynamicPyArraySerializer)
if np:
register(np.ndarray, serializer=NDArraySerializer)

def _initialize_xlang(self):
register = functools.partial(self._register_type, internal=True)
Expand All @@ -243,18 +204,18 @@ def _initialize_xlang(self):
register(Int16Type, type_id=TypeId.INT16, serializer=Int16Serializer)
register(Int32Type, type_id=TypeId.INT32, serializer=Int32Serializer)
register(Int64Type, type_id=TypeId.INT64, serializer=Int64Serializer)
register(int, type_id=DYNAMIC_TYPE_ID, serializer=DynamicIntSerializer)
register(int, type_id=TypeId.INT64, serializer=Int64Serializer)
register(
Float32Type,
type_id=TypeId.FLOAT32,
serializer=FloatSerializer,
serializer=Float32Serializer,
)
register(
Float64Type,
type_id=TypeId.FLOAT64,
serializer=DoubleSerializer,
serializer=Float64Serializer,
)
register(float, type_id=DYNAMIC_TYPE_ID, serializer=DynamicFloatSerializer)
register(float, type_id=TypeId.FLOAT64, serializer=Float64Serializer)
register(str, type_id=TypeId.STRING, serializer=StringSerializer)
# TODO(chaokunyang) DURATION DECIMAL
register(
Expand Down Expand Up @@ -512,9 +473,19 @@ def get_classinfo(self, cls, create=True):
raise TypeUnregisteredError(f"{cls} not registered")
logger.info("Class %s not registered", cls)
serializer = self._create_serializer(cls)
type_id = (
NO_CLASS_ID if type(serializer) is not PickleSerializer else PICKLE_CLASS_ID
)
type_id = None
if self.language == Language.PYTHON:
if isinstance(serializer, EnumSerializer):
type_id = TypeId.NAMED_ENUM
elif type(serializer) is PickleSerializer:
type_id = PickleSerializer.PICKLE_CLASS_ID
if not self.require_registration:
if isinstance(serializer, DataClassSerializer):
type_id = TypeId.NAMED_STRUCT
if type_id is None:
raise TypeUnregisteredError(
f"{cls} must be registered using `fury.register_type` API"
)
return self.__register_type(
cls,
type_id=type_id,
Expand Down Expand Up @@ -544,33 +515,6 @@ def _create_serializer(self, cls):
serializer = PickleSerializer(self.fury, cls)
return serializer

def write_classinfo(self, buffer: Buffer, classinfo):
if classinfo.dynamic_type:
return
type_id = classinfo.type_id
if type_id != NO_CLASS_ID:
buffer.write_varuint32(type_id << 1)
return
buffer.write_varuint32(1)
self.metastring_resolver.write_meta_string_bytes(
buffer, classinfo.namespace_bytes
)
self.metastring_resolver.write_meta_string_bytes(
buffer, classinfo.typename_bytes
)

def read_classinfo(self, buffer):
header = buffer.read_varuint32()
if header & 0b1 == 0:
type_id = header >> 1
classinfo = self._type_id_to_classinfo[type_id]
if classinfo.serializer is None:
classinfo.serializer = self._create_serializer(classinfo.cls)
return classinfo
ns_metabytes = self.metastring_resolver.read_meta_string_bytes(buffer)
type_metabytes = self.metastring_resolver.read_meta_string_bytes(buffer)
return self._load_metabytes_to_classinfo(ns_metabytes, type_metabytes)

def _load_metabytes_to_classinfo(self, ns_metabytes, type_metabytes):
typeinfo = self._ns_type_to_classinfo.get((ns_metabytes, type_metabytes))
if typeinfo is not None:
Expand All @@ -588,6 +532,8 @@ def _load_metabytes_to_classinfo(self, ns_metabytes, type_metabytes):
return classinfo

def write_typeinfo(self, buffer, classinfo):
if classinfo.dynamic_type:
return
type_id = classinfo.type_id
internal_type_id = type_id & 0xFF
buffer.write_varuint32(type_id)
Expand Down
Loading

0 comments on commit b4f5a2a

Please sign in to comment.