From d57135eeec6fd2478521c717a24484bb5d7c615a Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Wed, 26 Feb 2025 14:19:50 +0000 Subject: [PATCH 1/3] init --- src/py/flwr/common/record/recordset.py | 179 +++++++++++--------- src/py/flwr/common/record/recordset_test.py | 23 ++- 2 files changed, 111 insertions(+), 91 deletions(-) diff --git a/src/py/flwr/common/record/recordset.py b/src/py/flwr/common/record/recordset.py index f8a85894bb80..d920db0d88ca 100644 --- a/src/py/flwr/common/record/recordset.py +++ b/src/py/flwr/common/record/recordset.py @@ -17,82 +17,79 @@ from __future__ import annotations -from dataclasses import dataclass -from typing import cast +from logging import WARN +from textwrap import indent +from typing import TypeVar, cast +from ..logger import log from .configsrecord import ConfigsRecord from .metricsrecord import MetricsRecord from .parametersrecord import ParametersRecord from .typeddict import TypedDict +RecordType = ParametersRecord | MetricsRecord | ConfigsRecord -@dataclass -class RecordSetData: - """Inner data container for the RecordSet class.""" +T = TypeVar("T") - parameters_records: TypedDict[str, ParametersRecord] - metrics_records: TypedDict[str, MetricsRecord] - configs_records: TypedDict[str, ConfigsRecord] - def __init__( - self, - parameters_records: dict[str, ParametersRecord] | None = None, - metrics_records: dict[str, MetricsRecord] | None = None, - configs_records: dict[str, ConfigsRecord] | None = None, - ) -> None: - self.parameters_records = TypedDict[str, ParametersRecord]( - self._check_fn_str, self._check_fn_params - ) - self.metrics_records = TypedDict[str, MetricsRecord]( - self._check_fn_str, self._check_fn_metrics +def _check_key(key: str) -> None: + if not isinstance(key, str): + raise TypeError( + f"Expected `{str.__name__}`, but " + f"received `{type(key).__name__}` for the key." ) - self.configs_records = TypedDict[str, ConfigsRecord]( - self._check_fn_str, self._check_fn_configs + + +def _check_value(value: RecordType) -> None: + if not isinstance(value, (ParametersRecord, MetricsRecord, ConfigsRecord)): + raise TypeError( + f"Expected `{ParametersRecord.__name__}`, `{MetricsRecord.__name__}`, " + f"or `{ConfigsRecord.__name__}` but received " + f"`{type(value).__name__}` for the value." ) - if parameters_records is not None: - self.parameters_records.update(parameters_records) - if metrics_records is not None: - self.metrics_records.update(metrics_records) - if configs_records is not None: - self.configs_records.update(configs_records) - - def _check_fn_str(self, key: str) -> None: - if not isinstance(key, str): - raise TypeError( - f"Expected `{str.__name__}`, but " - f"received `{type(key).__name__}` for the key." - ) - def _check_fn_params(self, record: ParametersRecord) -> None: - if not isinstance(record, ParametersRecord): - raise TypeError( - f"Expected `{ParametersRecord.__name__}`, but " - f"received `{type(record).__name__}` for the value." - ) - def _check_fn_metrics(self, record: MetricsRecord) -> None: - if not isinstance(record, MetricsRecord): - raise TypeError( - f"Expected `{MetricsRecord.__name__}`, but " - f"received `{type(record).__name__}` for the value." - ) +class _SyncedDict(TypedDict[str, T]): + """A synchronized dictionary that mirrors changes to an underlying RecordSet. - def _check_fn_configs(self, record: ConfigsRecord) -> None: - if not isinstance(record, ConfigsRecord): + This dictionary ensures that any modifications (set or delete operations) + are automatically reflected in the associated `RecordSet`. Only values of + the specified `allowed_type` are permitted. + """ + + def __init__(self, ref_recordset: RecordSet, allowed_type: type[T]) -> None: + if not issubclass( + allowed_type, (ParametersRecord, MetricsRecord, ConfigsRecord) + ): + raise TypeError(f"{allowed_type} is not a valid type.") + super().__init__(_check_key, self.check_value) + self.recordset = ref_recordset + self.allowed_type = allowed_type + + def __setitem__(self, key: str, value: T) -> None: + super().__setitem__(key, value) + self.recordset[key] = cast(RecordType, value) + + def __delitem__(self, key: str) -> None: + super().__delitem__(key) + del self.recordset[key] + + def check_value(self, value: T) -> None: + """Check if value is of expected type.""" + if not isinstance(value, self.allowed_type): raise TypeError( - f"Expected `{ConfigsRecord.__name__}`, but " - f"received `{type(record).__name__}` for the value." + f"Expected `{self.allowed_type.__name__}`, but " + f"received `{type(value).__name__}` for the value." ) -class RecordSet: +class RecordSet(TypedDict[str, RecordType]): """RecordSet stores groups of parameters, metrics and configs. - A :code:`RecordSet` is the unified mechanism by which parameters, - metrics and configs can be either stored as part of a - `flwr.common.Context `_ in your apps - or communicated as part of a - `flwr.common.Message `_ between your apps. + A :class:`RecordSet` is the unified mechanism by which parameters, + metrics and configs can be either stored as part of a :class:`Context` + in your apps or communicated as part of a :class:`Message` between + your apps. Parameters ---------- @@ -127,12 +124,12 @@ class RecordSet: >>> # We can create a ConfigsRecord >>> c_record = ConfigsRecord({"lr": 0.1, "batch-size": 128}) >>> # Adding it to the record_set would look like this - >>> my_recordset.configs_records["my_config"] = c_record + >>> my_recordset["my_config"] = c_record >>> >>> # We can create a MetricsRecord following a similar process >>> m_record = MetricsRecord({"accuracy": 0.93, "losses": [0.23, 0.1]}) >>> # Adding it to the record_set would look like this - >>> my_recordset.metrics_records["my_metrics"] = m_record + >>> my_recordset["my_metrics"] = m_record Adding a :code:`ParametersRecord` follows the same steps as above but first, the array needs to be serialized and represented as a :code:`flwr.common.Array`. @@ -151,7 +148,7 @@ class RecordSet: >>> p_record = ParametersRecord({"my_array": arr}) >>> >>> # Adding it to the record_set would look like this - >>> my_recordset.parameters_records["my_parameters"] = p_record + >>> my_recordset["my_parameters"] = p_record For additional examples on how to construct each of the records types shown above, please refer to the documentation for :code:`ConfigsRecord`, @@ -164,39 +161,57 @@ def __init__( metrics_records: dict[str, MetricsRecord] | None = None, configs_records: dict[str, ConfigsRecord] | None = None, ) -> None: - data = RecordSetData( - parameters_records=parameters_records, - metrics_records=metrics_records, - configs_records=configs_records, - ) - self.__dict__["_data"] = data + super().__init__(_check_key, _check_value) + for key, p_record in (parameters_records or {}).items(): + self[key] = p_record + for key, m_record in (metrics_records or {}).items(): + self[key] = m_record + for key, c_record in (configs_records or {}).items(): + self[key] = c_record @property def parameters_records(self) -> TypedDict[str, ParametersRecord]: - """Dictionary holding ParametersRecord instances.""" - data = cast(RecordSetData, self.__dict__["_data"]) - return data.parameters_records + """Dictionary holding only ParametersRecord instances.""" + synced_dict = _SyncedDict[ParametersRecord](self, ParametersRecord) + for key, record in self.items(): + if isinstance(record, ParametersRecord): + synced_dict[key] = record + return synced_dict @property def metrics_records(self) -> TypedDict[str, MetricsRecord]: - """Dictionary holding MetricsRecord instances.""" - data = cast(RecordSetData, self.__dict__["_data"]) - return data.metrics_records + """Dictionary holding only MetricsRecord instances.""" + synced_dict = _SyncedDict[MetricsRecord](self, MetricsRecord) + for key, record in self.items(): + if isinstance(record, MetricsRecord): + synced_dict[key] = record + return synced_dict @property def configs_records(self) -> TypedDict[str, ConfigsRecord]: - """Dictionary holding ConfigsRecord instances.""" - data = cast(RecordSetData, self.__dict__["_data"]) - return data.configs_records + """Dictionary holding only ConfigsRecord instances.""" + synced_dict = _SyncedDict[ConfigsRecord](self, ConfigsRecord) + for key, record in self.items(): + if isinstance(record, ConfigsRecord): + synced_dict[key] = record + return synced_dict def __repr__(self) -> str: """Return a string representation of this instance.""" flds = ("parameters_records", "metrics_records", "configs_records") - view = ", ".join([f"{fld}={getattr(self, fld)!r}" for fld in flds]) - return f"{self.__class__.__qualname__}({view})" - - def __eq__(self, other: object) -> bool: - """Compare two instances of the class.""" - if not isinstance(other, self.__class__): - raise NotImplementedError - return self.__dict__ == other.__dict__ + fld_views = [f"{fld}={dict(getattr(self, fld))!r}" for fld in flds] + view = indent(",\n".join(fld_views), " ") + return f"{self.__class__.__qualname__}(\n{view}\n)" + + def __setitem__(self, key: str, value: RecordType) -> None: + """Set the given key to the given value after type checking.""" + original_value = self.get(key, None) + super().__setitem__(key, value) + if original_value is not None and not isinstance(value, type(original_value)): + log( + WARN, + "Key '%s' was overwritten: record of type `%s` replaced with type `%s`", + key, + type(original_value).__name__, + type(value).__name__, + ) diff --git a/src/py/flwr/common/record/recordset_test.py b/src/py/flwr/common/record/recordset_test.py index 09f227b92ba6..7e62b86a414a 100644 --- a/src/py/flwr/common/record/recordset_test.py +++ b/src/py/flwr/common/record/recordset_test.py @@ -16,7 +16,7 @@ import pickle -from collections import OrderedDict, namedtuple +from collections import OrderedDict from copy import deepcopy from typing import Callable, Union @@ -421,13 +421,18 @@ def test_record_is_picklable() -> None: def test_recordset_repr() -> None: """Test the string representation of RecordSet.""" # Prepare - kwargs = { - "parameters_records": {"params": ParametersRecord()}, - "metrics_records": {"metrics": MetricsRecord({"aa": 123})}, - "configs_records": {"configs": ConfigsRecord({"cc": bytes(9)})}, - } - rs = RecordSet(**kwargs) # type: ignore - expected = namedtuple("RecordSet", kwargs.keys())(**kwargs) + rs = RecordSet( + parameters_records={"params": ParametersRecord()}, + metrics_records={"metrics": MetricsRecord({"aa": 123})}, + configs_records={"configs": ConfigsRecord({"cc": bytes(5)})}, + ) + expected = """RecordSet( + parameters_records={'params': {}}, + metrics_records={'metrics': {'aa': 123}}, + configs_records={'configs': {'cc': b'\\x00\\x00\\x00\\x00\\x00'}} +)""" + print(str(rs)) + print(expected) # Assert - assert str(rs) == str(expected) + assert str(rs) == expected From d9b3cbe71b03ceab346fa5f5da1233ebf6798dea Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Wed, 26 Feb 2025 14:33:50 +0000 Subject: [PATCH 2/3] add unittest --- src/py/flwr/common/record/recordset_test.py | 31 +++++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/src/py/flwr/common/record/recordset_test.py b/src/py/flwr/common/record/recordset_test.py index 7e62b86a414a..561dd2dd6eeb 100644 --- a/src/py/flwr/common/record/recordset_test.py +++ b/src/py/flwr/common/record/recordset_test.py @@ -431,8 +431,35 @@ def test_recordset_repr() -> None: metrics_records={'metrics': {'aa': 123}}, configs_records={'configs': {'cc': b'\\x00\\x00\\x00\\x00\\x00'}} )""" - print(str(rs)) - print(expected) # Assert assert str(rs) == expected + + +def test_recordset_set_get_del_item() -> None: + """Test setting, getting, and deleting items in RecordSet.""" + # Prepare + rs = RecordSet() + p_record = ParametersRecord() + m_record = MetricsRecord({"aa": 123}) + c_record = ConfigsRecord({"cc": bytes(5)}) + + # Execute + rs["params"] = p_record + rs["metrics"] = m_record + rs["configs"] = c_record + + # Assert + assert rs["params"] == p_record + assert rs["metrics"] == m_record + assert rs["configs"] == c_record + + # Execute + del rs["params"] + del rs["metrics"] + del rs["configs"] + + # Assert + assert "params" not in rs + assert "metrics" not in rs + assert "configs" not in rs From 1e1a0422d0c8e56ff618bc5f21242398d821ba59 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Wed, 26 Feb 2025 15:13:27 +0000 Subject: [PATCH 3/3] fix mypy --- src/py/flwr/common/record/recordset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/py/flwr/common/record/recordset.py b/src/py/flwr/common/record/recordset.py index d920db0d88ca..4dc3861fc41b 100644 --- a/src/py/flwr/common/record/recordset.py +++ b/src/py/flwr/common/record/recordset.py @@ -19,7 +19,7 @@ from logging import WARN from textwrap import indent -from typing import TypeVar, cast +from typing import TypeVar, Union, cast from ..logger import log from .configsrecord import ConfigsRecord @@ -27,7 +27,7 @@ from .parametersrecord import ParametersRecord from .typeddict import TypedDict -RecordType = ParametersRecord | MetricsRecord | ConfigsRecord +RecordType = Union[ParametersRecord, MetricsRecord, ConfigsRecord] T = TypeVar("T")