Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(framework) Allow dict-like access for RecordSet while maintaining the original APIs #4963

Merged
merged 4 commits into from
Feb 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 97 additions & 82 deletions src/py/flwr/common/record/recordset.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,82 +17,79 @@

from __future__ import annotations

from dataclasses import dataclass
from typing import cast
from logging import WARN
from textwrap import indent
from typing import TypeVar, Union, cast

from ..logger import log
from .configsrecord import ConfigsRecord
from .metricsrecord import MetricsRecord
from .parametersrecord import ParametersRecord
from .typeddict import TypedDict

RecordType = Union[ParametersRecord, MetricsRecord, ConfigsRecord]

@dataclass
class RecordSetData:
"""Inner data container for the RecordSet class."""
T = TypeVar("T")

parameters_records: TypedDict[str, ParametersRecord]
metrics_records: TypedDict[str, MetricsRecord]
configs_records: TypedDict[str, ConfigsRecord]

def __init__(
self,
parameters_records: dict[str, ParametersRecord] | None = None,
metrics_records: dict[str, MetricsRecord] | None = None,
configs_records: dict[str, ConfigsRecord] | None = None,
) -> None:
self.parameters_records = TypedDict[str, ParametersRecord](
self._check_fn_str, self._check_fn_params
)
self.metrics_records = TypedDict[str, MetricsRecord](
self._check_fn_str, self._check_fn_metrics
def _check_key(key: str) -> None:
if not isinstance(key, str):
raise TypeError(
f"Expected `{str.__name__}`, but "
f"received `{type(key).__name__}` for the key."
)
self.configs_records = TypedDict[str, ConfigsRecord](
self._check_fn_str, self._check_fn_configs


def _check_value(value: RecordType) -> None:
    """Validate that a value is one of the supported record types.

    Parameters
    ----------
    value : RecordType
        The value to validate.

    Raises
    ------
    TypeError
        If `value` is not a `ParametersRecord`, `MetricsRecord`, or
        `ConfigsRecord` instance.
    """
    if not isinstance(value, (ParametersRecord, MetricsRecord, ConfigsRecord)):
        raise TypeError(
            f"Expected `{ParametersRecord.__name__}`, `{MetricsRecord.__name__}`, "
            f"or `{ConfigsRecord.__name__}` but received "
            f"`{type(value).__name__}` for the value."
        )
if parameters_records is not None:
self.parameters_records.update(parameters_records)
if metrics_records is not None:
self.metrics_records.update(metrics_records)
if configs_records is not None:
self.configs_records.update(configs_records)

def _check_fn_str(self, key: str) -> None:
if not isinstance(key, str):
raise TypeError(
f"Expected `{str.__name__}`, but "
f"received `{type(key).__name__}` for the key."
)

def _check_fn_params(self, record: ParametersRecord) -> None:
if not isinstance(record, ParametersRecord):
raise TypeError(
f"Expected `{ParametersRecord.__name__}`, but "
f"received `{type(record).__name__}` for the value."
)

def _check_fn_metrics(self, record: MetricsRecord) -> None:
if not isinstance(record, MetricsRecord):
raise TypeError(
f"Expected `{MetricsRecord.__name__}`, but "
f"received `{type(record).__name__}` for the value."
)
class _SyncedDict(TypedDict[str, T]):
"""A synchronized dictionary that mirrors changes to an underlying RecordSet.

def _check_fn_configs(self, record: ConfigsRecord) -> None:
if not isinstance(record, ConfigsRecord):
This dictionary ensures that any modifications (set or delete operations)
are automatically reflected in the associated `RecordSet`. Only values of
the specified `allowed_type` are permitted.
"""

def __init__(self, ref_recordset: RecordSet, allowed_type: type[T]) -> None:
    """Create an empty dictionary bound to `ref_recordset`.

    Parameters
    ----------
    ref_recordset : RecordSet
        The record set that item assignments and deletions are forwarded to.
    allowed_type : type[T]
        The record class accepted as a value. It must be a subclass of
        `ParametersRecord`, `MetricsRecord`, or `ConfigsRecord`.

    Raises
    ------
    TypeError
        If `allowed_type` is not a subclass of one of the record classes.
    """
    if not issubclass(
        allowed_type, (ParametersRecord, MetricsRecord, ConfigsRecord)
    ):
        raise TypeError(f"{allowed_type} is not a valid type.")
    # NOTE(review): the two arguments are presumably the key and value check
    # functions expected by the `TypedDict` base class -- confirm there.
    super().__init__(_check_key, self.check_value)
    self.recordset = ref_recordset
    self.allowed_type = allowed_type

def __setitem__(self, key: str, value: T) -> None:
    """Store `value` under `key` here and in the referenced record set."""
    # Update this dictionary first, then mirror the entry into the record set.
    super().__setitem__(key, value)
    # The cast only informs the type checker; it has no runtime effect.
    self.recordset[key] = cast(RecordType, value)

def __delitem__(self, key: str) -> None:
    """Remove `key` from this dictionary and from the referenced record set."""
    # Delete locally first, then mirror the deletion into the record set.
    super().__delitem__(key)
    del self.recordset[key]

def check_value(self, value: T) -> None:
"""Check if value is of expected type."""
if not isinstance(value, self.allowed_type):
raise TypeError(
f"Expected `{ConfigsRecord.__name__}`, but "
f"received `{type(record).__name__}` for the value."
f"Expected `{self.allowed_type.__name__}`, but "
f"received `{type(value).__name__}` for the value."
)


class RecordSet:
class RecordSet(TypedDict[str, RecordType]):
"""RecordSet stores groups of parameters, metrics and configs.

A :code:`RecordSet` is the unified mechanism by which parameters,
metrics and configs can be either stored as part of a
`flwr.common.Context <flwr.common.Context.html>`_ in your apps
or communicated as part of a
`flwr.common.Message <flwr.common.Message.html>`_ between your apps.
A :class:`RecordSet` is the unified mechanism by which parameters,
metrics and configs can be either stored as part of a :class:`Context`
in your apps or communicated as part of a :class:`Message` between
your apps.

Parameters
----------
Expand Down Expand Up @@ -127,12 +124,12 @@ class RecordSet:
>>> # We can create a ConfigsRecord
>>> c_record = ConfigsRecord({"lr": 0.1, "batch-size": 128})
>>> # Adding it to the record_set would look like this
>>> my_recordset.configs_records["my_config"] = c_record
>>> my_recordset["my_config"] = c_record
>>>
>>> # We can create a MetricsRecord following a similar process
>>> m_record = MetricsRecord({"accuracy": 0.93, "losses": [0.23, 0.1]})
>>> # Adding it to the record_set would look like this
>>> my_recordset.metrics_records["my_metrics"] = m_record
>>> my_recordset["my_metrics"] = m_record

Adding a :code:`ParametersRecord` follows the same steps as above but first,
the array needs to be serialized and represented as a :code:`flwr.common.Array`.
Expand All @@ -151,7 +148,7 @@ class RecordSet:
>>> p_record = ParametersRecord({"my_array": arr})
>>>
>>> # Adding it to the record_set would look like this
>>> my_recordset.parameters_records["my_parameters"] = p_record
>>> my_recordset["my_parameters"] = p_record

For additional examples on how to construct each of the records types shown
above, please refer to the documentation for :code:`ConfigsRecord`,
Expand All @@ -164,39 +161,57 @@ def __init__(
metrics_records: dict[str, MetricsRecord] | None = None,
configs_records: dict[str, ConfigsRecord] | None = None,
) -> None:
data = RecordSetData(
parameters_records=parameters_records,
metrics_records=metrics_records,
configs_records=configs_records,
)
self.__dict__["_data"] = data
super().__init__(_check_key, _check_value)
for key, p_record in (parameters_records or {}).items():
self[key] = p_record
for key, m_record in (metrics_records or {}).items():
self[key] = m_record
for key, c_record in (configs_records or {}).items():
self[key] = c_record
Comment on lines +165 to +170
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we deprecate this constructor and instead encourage users to build the RecordSet like this?

rs = RecordSet()
rs["..."] = ...


@property
def parameters_records(self) -> TypedDict[str, ParametersRecord]:
"""Dictionary holding ParametersRecord instances."""
data = cast(RecordSetData, self.__dict__["_data"])
return data.parameters_records
"""Dictionary holding only ParametersRecord instances."""
synced_dict = _SyncedDict[ParametersRecord](self, ParametersRecord)
for key, record in self.items():
if isinstance(record, ParametersRecord):
synced_dict[key] = record
return synced_dict

@property
def metrics_records(self) -> TypedDict[str, MetricsRecord]:
"""Dictionary holding MetricsRecord instances."""
data = cast(RecordSetData, self.__dict__["_data"])
return data.metrics_records
"""Dictionary holding only MetricsRecord instances."""
synced_dict = _SyncedDict[MetricsRecord](self, MetricsRecord)
for key, record in self.items():
if isinstance(record, MetricsRecord):
synced_dict[key] = record
return synced_dict

@property
def configs_records(self) -> TypedDict[str, ConfigsRecord]:
"""Dictionary holding ConfigsRecord instances."""
data = cast(RecordSetData, self.__dict__["_data"])
return data.configs_records
"""Dictionary holding only ConfigsRecord instances."""
synced_dict = _SyncedDict[ConfigsRecord](self, ConfigsRecord)
for key, record in self.items():
if isinstance(record, ConfigsRecord):
synced_dict[key] = record
return synced_dict

def __repr__(self) -> str:
"""Return a string representation of this instance."""
flds = ("parameters_records", "metrics_records", "configs_records")
view = ", ".join([f"{fld}={getattr(self, fld)!r}" for fld in flds])
return f"{self.__class__.__qualname__}({view})"

def __eq__(self, other: object) -> bool:
"""Compare two instances of the class."""
if not isinstance(other, self.__class__):
raise NotImplementedError
return self.__dict__ == other.__dict__
fld_views = [f"{fld}={dict(getattr(self, fld))!r}" for fld in flds]
view = indent(",\n".join(fld_views), " ")
return f"{self.__class__.__qualname__}(\n{view}\n)"

def __setitem__(self, key: str, value: RecordType) -> None:
    """Set the given key to the given value after type checking.

    A warning is logged when `key` already holds a record and the new value
    is not an instance of that record's type.

    Parameters
    ----------
    key : str
        The key under which the record is stored.
    value : RecordType
        The record to store.
    """
    # Read the current entry before it is replaced so the types can be compared.
    original_value = self.get(key, None)
    super().__setitem__(key, value)
    if original_value is not None and not isinstance(value, type(original_value)):
        log(
            WARN,
            "Key '%s' was overwritten: record of type `%s` replaced with type `%s`",
            key,
            type(original_value).__name__,
            type(value).__name__,
        )
50 changes: 41 additions & 9 deletions src/py/flwr/common/record/recordset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@


import pickle
from collections import OrderedDict, namedtuple
from collections import OrderedDict
from copy import deepcopy
from typing import Callable, Union

Expand Down Expand Up @@ -421,13 +421,45 @@ def test_record_is_picklable() -> None:
def test_recordset_repr() -> None:
"""Test the string representation of RecordSet."""
# Prepare
kwargs = {
"parameters_records": {"params": ParametersRecord()},
"metrics_records": {"metrics": MetricsRecord({"aa": 123})},
"configs_records": {"configs": ConfigsRecord({"cc": bytes(9)})},
}
rs = RecordSet(**kwargs) # type: ignore
expected = namedtuple("RecordSet", kwargs.keys())(**kwargs)
rs = RecordSet(
parameters_records={"params": ParametersRecord()},
metrics_records={"metrics": MetricsRecord({"aa": 123})},
configs_records={"configs": ConfigsRecord({"cc": bytes(5)})},
)
expected = """RecordSet(
parameters_records={'params': {}},
metrics_records={'metrics': {'aa': 123}},
configs_records={'configs': {'cc': b'\\x00\\x00\\x00\\x00\\x00'}}
)"""

# Assert
assert str(rs) == expected


def test_recordset_set_get_del_item() -> None:
"""Test setting, getting, and deleting items in RecordSet."""
# Prepare
rs = RecordSet()
p_record = ParametersRecord()
m_record = MetricsRecord({"aa": 123})
c_record = ConfigsRecord({"cc": bytes(5)})

# Execute
rs["params"] = p_record
rs["metrics"] = m_record
rs["configs"] = c_record

# Assert
assert rs["params"] == p_record
assert rs["metrics"] == m_record
assert rs["configs"] == c_record

# Execute
del rs["params"]
del rs["metrics"]
del rs["configs"]

# Assert
assert str(rs) == str(expected)
assert "params" not in rs
assert "metrics" not in rs
assert "configs" not in rs