diff --git a/lib/Dialect/ESI/runtime/python/esiaccel/types.py b/lib/Dialect/ESI/runtime/python/esiaccel/types.py index 540669459d60..2685d04eebf6 100644 --- a/lib/Dialect/ESI/runtime/python/esiaccel/types.py +++ b/lib/Dialect/ESI/runtime/python/esiaccel/types.py @@ -150,6 +150,10 @@ def serialize(self, obj: Union[bytearray, bytes, List[int]]) -> bytearray: raise ValueError(f"cannot convert {obj} to bytearray") def deserialize(self, data: bytearray) -> Tuple[bytearray, bytearray]: + if len(data) < self.max_size: + raise ValueError( + f"Cannot deserialize BitsType. Expected {self.max_size} bytes, got {len(data)}" + ) return (data[0:self.max_size], data[self.max_size:]) @@ -185,6 +189,10 @@ def serialize(self, obj: int) -> bytearray: return bytearray(int.to_bytes(obj, self.max_size, "little")) def deserialize(self, data: bytearray) -> Tuple[int, bytearray]: + if len(data) < self.max_size: + raise ValueError( + f"Cannot deserialize UIntType. Expected {self.max_size} bytes, got {len(data)}" + ) return (int.from_bytes(data[0:self.max_size], "little"), data[self.max_size:]) @@ -215,6 +223,10 @@ def serialize(self, obj: int) -> bytearray: return bytearray(int.to_bytes(obj, self.max_size, "little", signed=True)) def deserialize(self, data: bytearray) -> Tuple[int, bytearray]: + if len(data) < self.max_size: + raise ValueError( + f"Cannot deserialize SIntType. Expected {self.max_size} bytes, got {len(data)}" + ) return (int.from_bytes(data[0:self.max_size], "little", signed=True), data[self.max_size:]) @@ -266,10 +278,13 @@ def serialize(self, obj) -> bytearray: return ret def deserialize(self, data: bytearray) -> Tuple[Dict[str, Any], bytearray]: - ret = {} - for (fname, ftype) in reversed(self.fields): - (fval, data) = ftype.deserialize(data) - ret[fname] = fval + try: + ret = {} + for (fname, ftype) in reversed(self.fields): + (fval, data) = ftype.deserialize(data) + ret[fname] = fval + except Exception as e: + raise ValueError(f"Cannot deserialize StructType: {e}") return (ret, data) @@ -309,11 +324,14 @@ def serialize(self, lst: list) -> bytearray: return ret def deserialize(self, data: bytearray) -> Tuple[List[Any], bytearray]: - ret = [] - for _ in range(self.size): - (obj, data) = self.element_type.deserialize(data) - ret.append(obj) - ret.reverse() + try: + ret = [] + for _ in range(self.size): + (obj, data) = self.element_type.deserialize(data) + ret.append(obj) + ret.reverse() + except Exception as e: + raise ValueError(f"Cannot deserialize ArrayType: {e}") return (ret, data)