Skip to content

Commit

Permalink
Add count_bytes method to RecordSets (#3083)
Browse files Browse the repository at this point in the history
Co-authored-by: Heng Pan <[email protected]>
  • Loading branch information
jafermarq and panh99 authored Mar 13, 2024
1 parent 808ef75 commit 7b1433d
Show file tree
Hide file tree
Showing 5 changed files with 155 additions and 8 deletions.
38 changes: 37 additions & 1 deletion src/py/flwr/common/record/configsrecord.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""ConfigsRecord."""


from typing import Dict, Optional, get_args
from typing import Dict, List, Optional, get_args

from flwr.common.typing import ConfigsRecordValues, ConfigsScalar

Expand Down Expand Up @@ -85,3 +85,39 @@ def __init__(
self[k] = configs_dict[k]
if not keep_input:
del configs_dict[k]

def count_bytes(self) -> int:
    """Return number of Bytes stored in this object.

    Booleans are counted as 1 Byte, ints and floats as 8 Bytes (they are
    represented as 64-bit values when serialized), and ``str``/``bytes``
    values as ``len(value)``. The length of every key is added as well.
    For a list, the type of its first element decides how the whole list
    is counted; an empty list contributes only the Bytes of its key.

    Returns
    -------
    int
        The number of Bytes attributed to all keys and values.

    Raises
    ------
    TypeError
        If a value is of a type that cannot be sized.
    """

    def get_var_bytes(value: "ConfigsScalar") -> int:
        """Return Bytes of value passed."""
        # bool has to be tested before int because bool is a subclass of int
        if isinstance(value, bool):
            return 1
        if isinstance(value, (int, float)):
            # protobuf represents int/floats in ConfigsRecords as 64bit
            return 8
        if isinstance(value, (str, bytes)):
            # NOTE: for str this is the number of characters
            return len(value)
        raise TypeError(f"Cannot count bytes of value of type {type(value)}")

    num_bytes = 0

    for k, v in self.items():
        if isinstance(v, List):
            if not v:
                # Nothing to count (and no first element to inspect)
                pass
            elif isinstance(v[0], (bytes, str)):
                # not all str/bytes are necessarily of equal length, so
                # each element has to be measured individually
                num_bytes += int(sum(len(s) for s in v))  # type: ignore
            else:
                # bool/int/float: every element has the same footprint
                num_bytes += get_var_bytes(v[0]) * len(v)
        else:
            num_bytes += get_var_bytes(v)

        # We also count the bytes footprint of the keys
        num_bytes += len(k)

    return num_bytes
18 changes: 17 additions & 1 deletion src/py/flwr/common/record/metricsrecord.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""MetricsRecord."""


from typing import Dict, Optional, get_args
from typing import Dict, List, Optional, get_args

from flwr.common.typing import MetricsRecordValues, MetricsScalar

Expand Down Expand Up @@ -84,3 +84,19 @@ def __init__(
self[k] = metrics_dict[k]
if not keep_input:
del metrics_dict[k]

def count_bytes(self) -> int:
    """Return number of Bytes stored in this object.

    Every scalar (whether stored on its own or inside a list) is counted as
    8 Bytes, and the length of each key is added on top.
    """
    total = 0
    for key, value in self.items():
        # int and float are mapped to 64bit int/float during protobuffing,
        # hence 8 Bytes per scalar
        n_scalars = len(value) if isinstance(value, List) else 1
        # the footprint of the key is counted together with its value
        total += 8 * n_scalars + len(key)
    return total
17 changes: 17 additions & 0 deletions src/py/flwr/common/record/parametersrecord.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,20 @@ def __init__(
self[k] = array_dict[k]
if not keep_input:
del array_dict[k]

def count_bytes(self) -> int:
    """Return number of Bytes stored in this object.

    The result is the sum, over all entries, of the length of the entry's
    ``data`` buffer and the length of its key. Note that the ``data`` buffer
    may include a small amount of metadata of the serialized object (e.g. of
    a NumPy array) that is needed for deserialization.
    """
    # payload of each array plus the footprint of its key
    return sum(len(array.data) + len(key) for key, array in self.items())
51 changes: 45 additions & 6 deletions src/py/flwr/common/record/parametersrecord_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,26 @@
# ==============================================================================
"""Unit tests for ParametersRecord and Array."""


import unittest
from collections import OrderedDict
from io import BytesIO
from typing import List

import numpy as np
import pytest

from flwr.common import ndarray_to_bytes

from ..constant import SType
from .parametersrecord import Array
from ..typing import NDArray
from .parametersrecord import Array, ParametersRecord


def _get_buffer_from_ndarray(array: NDArray) -> bytes:
    """Return the bytes written by ``np.save`` for a given NumPy array."""
    with BytesIO() as stream:
        # pickling is disabled when saving the array
        np.save(stream, array, allow_pickle=False)
        return stream.getvalue()


class TestArray(unittest.TestCase):
Expand All @@ -31,16 +43,15 @@ def test_numpy_conversion_valid(self) -> None:
"""Test the numpy method with valid Array instance."""
# Prepare
original_array = np.array([1, 2, 3], dtype=np.float32)
buffer = BytesIO()
np.save(buffer, original_array, allow_pickle=False)
buffer.seek(0)

buffer = _get_buffer_from_ndarray(original_array)

# Execute
array_instance = Array(
dtype=str(original_array.dtype),
shape=list(original_array.shape),
stype=SType.NUMPY,
data=buffer.read(),
data=buffer,
)
converted_array = array_instance.numpy()

Expand All @@ -60,3 +71,31 @@ def test_numpy_conversion_invalid(self) -> None:
# Execute and assert
with self.assertRaises(TypeError):
array_instance.numpy()


@pytest.mark.parametrize(
    "shape, dtype",
    [
        ([100], "float32"),
        ([31, 31], "int8"),
        ([31, 153], "bool_"),  # bool_ is represented as a whole Byte in NumPy
    ],
)
def test_count_bytes(shape: List[int], dtype: str) -> None:
    """Check that count_bytes equals the serialized array size plus key length."""
    # Prepare
    ndarray = np.random.randn(*shape).astype(np.dtype(dtype))
    # reference payload produced by the library serializer
    expected_payload = ndarray_to_bytes(ndarray)
    record_key = "data"
    array = Array(
        dtype=str(ndarray.dtype),
        shape=list(ndarray.shape),
        stype=SType.NUMPY,
        data=_get_buffer_from_ndarray(ndarray),
    )

    # Execute
    record = ParametersRecord(OrderedDict({record_key: array}))

    # Assert
    assert len(expected_payload) + len(record_key) == record.count_bytes()
39 changes: 39 additions & 0 deletions src/py/flwr/common/record/recordset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,3 +359,42 @@ def test_set_configs_to_configsrecord_with_incorrect_types(

with pytest.raises(TypeError):
c_record.update(my_configs)


def test_count_bytes_metricsrecord() -> None:
    """Test counting bytes in MetricsRecord."""
    # Prepare
    metrics = {"a": 1, "b": 2.0, "c": [1, 2, 3], "d": [1.0, 2.0, 3.0, 4.0, 5.0]}
    value_bytes = 8 + 8 + 3 * 8 + 5 * 8
    key_bytes = 4  # four single-character keys
    expected = value_bytes + key_bytes

    # Execute
    record = MetricsRecord()
    record.update(OrderedDict(metrics))

    # Assert
    assert expected == record.count_bytes()


def test_count_bytes_configsrecord() -> None:
    """Test counting bytes in ConfigsRecord."""
    # Prepare
    configs = {
        # numeric scalars and lists
        "a": 1,
        "b": 2.0,
        "c": [1, 2, 3],
        "d": [1.0, 2.0, 3.0, 4.0, 5.0],
        # bool, str and bytes scalars and lists
        "aa": True,
        "bb": "False",
        "cc": bytes(9),
        "dd": [True, False, False],
        "ee": ["True", "False"],
        "ff": [bytes(1), bytes(13), bytes(51)],
    }
    numeric_bytes = 8 + 8 + 3 * 8 + 5 * 8
    other_bytes = 1 + 5 + 9 + 3 + (4 + 5) + (1 + 13 + 51)
    key_bytes = 4 + 12  # four one-character keys, six two-character keys
    expected = numeric_bytes + other_bytes + key_bytes

    # Execute
    record = ConfigsRecord()
    record.update(OrderedDict(configs))

    # Assert
    assert expected == record.count_bytes()

0 comments on commit 7b1433d

Please sign in to comment.