Skip to content

Commit 61a0c13

Browse files
authored
Fix copy/pickling for Scalar and Vector (#217)
* Create an XYData class Signed-off-by: Michael Johansen <[email protected]> * Add more unit tests for corner cases. Signed-off-by: Michael Johansen <[email protected]> * Update intro.inc and the RTD link in README. Signed-off-by: Michael Johansen <[email protected]> * Refactors and fixes based on review feedback. Signed-off-by: Michael Johansen <[email protected]> * Re-implement using NumPy. Remove append functionality. Signed-off-by: Michael Johansen <[email protected]> * Remove/make constants private. Signed-off-by: Michael Johansen <[email protected]> * Add a few more unit tests. Signed-off-by: Michael Johansen <[email protected]> * Fix doctests. Signed-off-by: Michael Johansen <[email protected]> * Fix repr tests on oldest python/numpy. Signed-off-by: Michael Johansen <[email protected]> * Fix linting issue. Signed-off-by: Michael Johansen <[email protected]> * Include non-units properties in __repr__(). Signed-off-by: Michael Johansen <[email protected]> * Fix linting issues. Signed-off-by: Michael Johansen <[email protected]> * Update stale docs. Remove old spectrum references. Signed-off-by: Michael Johansen <[email protected]> * Make TypeVars public Signed-off-by: Michael Johansen <[email protected]> * Rename x_values to x_data and y_values to y_data. Signed-off-by: Michael Johansen <[email protected]> * Remove x_units and y_units from reduce() Signed-off-by: Michael Johansen <[email protected]> * Add failing tests for non-units extended properties. The next commit will fix these failures Signed-off-by: Michael Johansen <[email protected]> * Fix unit tests and doctests. Signed-off-by: Michael Johansen <[email protected]> * Fix formatting errors. Signed-off-by: Michael Johansen <[email protected]> * Fix error checking around mismatching units. Signed-off-by: Michael Johansen <[email protected]> * Improvement to repr based on PR feedback. Signed-off-by: Michael Johansen <[email protected]> * Address review feedback related to repr. Signed-off-by: Michael Johansen <[email protected]> * Fix linting issue. Signed-off-by: Michael Johansen <[email protected]> --------- Signed-off-by: Michael Johansen <[email protected]>
1 parent 98d2821 commit 61a0c13

File tree

6 files changed

+231
-52
lines changed

6 files changed

+231
-52
lines changed

src/nitypes/scalar.py

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,14 @@
99

1010
from __future__ import annotations
1111

12+
from collections.abc import Mapping
1213
from typing import TYPE_CHECKING, Any, Generic, Union
1314

14-
from typing_extensions import TypeVar, final
15+
from typing_extensions import Self, TypeVar, final
1516

1617
from nitypes._exceptions import invalid_arg_type
1718
from nitypes.waveform._extended_properties import UNIT_DESCRIPTION
19+
from nitypes.waveform.typing import ExtendedPropertyValue
1820

1921
if TYPE_CHECKING:
2022
# Import from the public package so the docs don't reference private submodules.
@@ -36,9 +38,9 @@ class Scalar(Generic[TScalar_co]):
3638
To construct a scalar data object, use the :class:`Scalar` class:
3739
3840
>>> Scalar(False)
39-
nitypes.scalar.Scalar(value=False, units='')
41+
nitypes.scalar.Scalar(value=False)
4042
>>> Scalar(0)
41-
nitypes.scalar.Scalar(value=0, units='')
43+
nitypes.scalar.Scalar(value=0)
4244
>>> Scalar(5.0, 'volts')
4345
nitypes.scalar.Scalar(value=5.0, units='volts')
4446
>>> Scalar("value", "volts")
@@ -60,12 +62,18 @@ def __init__(
6062
self,
6163
value: TScalar_co,
6264
units: str = "",
65+
*,
66+
extended_properties: Mapping[str, ExtendedPropertyValue] | None = None,
67+
copy_extended_properties: bool = True,
6368
) -> None:
6469
"""Initialize a new scalar.
6570
6671
Args:
6772
value: The scalar data to store in this object.
6873
units: The units string associated with this data.
74+
extended_properties: The extended properties of the Scalar.
75+
copy_extended_properties: Specifies whether to copy the extended properties or take
76+
ownership.
6977
7078
Returns:
7179
A scalar data object.
@@ -77,8 +85,20 @@ def __init__(
7785
raise invalid_arg_type("units", "str", units)
7886

7987
self._value = value
80-
self._extended_properties = ExtendedPropertyDictionary()
81-
self._extended_properties[UNIT_DESCRIPTION] = units
88+
if copy_extended_properties or not isinstance(
89+
extended_properties, ExtendedPropertyDictionary
90+
):
91+
extended_properties = ExtendedPropertyDictionary(extended_properties)
92+
self._extended_properties = extended_properties
93+
94+
# If units are not already in extended properties, set them.
95+
if UNIT_DESCRIPTION not in self._extended_properties:
96+
self._extended_properties[UNIT_DESCRIPTION] = units
97+
elif units and units != self._extended_properties.get(UNIT_DESCRIPTION):
98+
raise ValueError(
99+
"The specified units input does not match the units specified in "
100+
"extended_properties."
101+
)
82102

83103
@property
84104
def value(self) -> TScalar_co:
@@ -164,11 +184,28 @@ def __le__(self, value: Scalar[TScalar_co]) -> bool:
164184

165185
def __reduce__(self) -> tuple[Any, ...]:
166186
"""Return object state for pickling."""
167-
return (self.__class__, (self.value, self.units))
187+
ctor_args = (self.value,)
188+
ctor_kwargs: dict[str, Any] = {
189+
"extended_properties": self._extended_properties,
190+
"copy_extended_properties": False,
191+
}
192+
return (self.__class__._unpickle, (ctor_args, ctor_kwargs))
193+
194+
@classmethod
195+
def _unpickle(cls, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Self:
196+
return cls(*args, **kwargs)
168197

169198
def __repr__(self) -> str:
170199
"""Return repr(self)."""
171-
args = [f"value={self.value!r}", f"units={self.units!r}"]
200+
args = [f"value={self.value!r}"]
201+
202+
if self.units:
203+
args.append(f"units={self.units!r}")
204+
205+
# Only display the extended properties if non-units entries are specified.
206+
if any(key for key in self.extended_properties.keys() if key != UNIT_DESCRIPTION):
207+
args.append(f"extended_properties={self.extended_properties!r}")
208+
172209
return f"{self.__class__.__module__}.{self.__class__.__name__}({', '.join(args)})"
173210

174211
def __str__(self) -> str:

src/nitypes/vector.py

Lines changed: 45 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,14 @@
88

99
from __future__ import annotations
1010

11-
from collections.abc import Iterable, MutableSequence
12-
from typing import TYPE_CHECKING, overload, Any, Union
11+
from collections.abc import Iterable, Mapping, MutableSequence
12+
from typing import TYPE_CHECKING, Any, Union, overload
1313

14-
from typing_extensions import TypeVar, final, override
14+
from typing_extensions import Self, TypeVar, final, override
1515

1616
from nitypes._exceptions import invalid_arg_type
1717
from nitypes.waveform._extended_properties import UNIT_DESCRIPTION
18+
from nitypes.waveform.typing import ExtendedPropertyValue
1819

1920
if TYPE_CHECKING:
2021
# Import from the public package so the docs don't reference private submodules.
@@ -35,9 +36,9 @@ class Vector(MutableSequence[TScalar]):
3536
To construct a vector data object, use the :class:`Vector` class:
3637
3738
>>> Vector([False, True])
38-
nitypes.vector.Vector(values=[False, True], units='')
39+
nitypes.vector.Vector(values=[False, True])
3940
>>> Vector([0, 1, 2])
40-
nitypes.vector.Vector(values=[0, 1, 2], units='')
41+
nitypes.vector.Vector(values=[0, 1, 2])
4142
>>> Vector([5.0, 6.0], 'volts')
4243
nitypes.vector.Vector(values=[5.0, 6.0], units='volts')
4344
>>> Vector(["one", "two"], "volts")
@@ -60,6 +61,8 @@ def __init__(
6061
units: str = "",
6162
*,
6263
value_type: type[TScalar] | None = None,
64+
extended_properties: Mapping[str, ExtendedPropertyValue] | None = None,
65+
copy_extended_properties: bool = True,
6366
) -> None:
6467
"""Initialize a new vector.
6568
@@ -69,6 +72,9 @@ def __init__(
6972
value_type: The type of values that will be added to this Vector.
7073
This parameter should only be used when creating a Vector with
7174
an empty Iterable.
75+
extended_properties: The extended properties of the Vector.
76+
copy_extended_properties: Specifies whether to copy the extended properties or take
77+
ownership.
7278
7379
Returns:
7480
A vector data object.
@@ -94,8 +100,20 @@ def __init__(
94100
raise invalid_arg_type("units", "str", units)
95101

96102
self._values = list(values)
97-
self._extended_properties = ExtendedPropertyDictionary()
98-
self._extended_properties[UNIT_DESCRIPTION] = units
103+
if copy_extended_properties or not isinstance(
104+
extended_properties, ExtendedPropertyDictionary
105+
):
106+
extended_properties = ExtendedPropertyDictionary(extended_properties)
107+
self._extended_properties = extended_properties
108+
109+
# If units are not already in extended properties, set them.
110+
if UNIT_DESCRIPTION not in self._extended_properties:
111+
self._extended_properties[UNIT_DESCRIPTION] = units
112+
elif units and units != self._extended_properties.get(UNIT_DESCRIPTION):
113+
raise ValueError(
114+
"The specified units input does not match the units specified in "
115+
"extended_properties."
116+
)
99117

100118
@property
101119
def units(self) -> str:
@@ -191,11 +209,29 @@ def __eq__(self, value: object, /) -> bool:
191209

192210
def __reduce__(self) -> tuple[Any, ...]:
193211
"""Return object state for pickling."""
194-
return (self.__class__, (self._values, self.units))
212+
ctor_args = (self._values,)
213+
ctor_kwargs: dict[str, Any] = {
214+
"value_type": self._value_type,
215+
"extended_properties": self._extended_properties,
216+
"copy_extended_properties": False,
217+
}
218+
return (self.__class__._unpickle, (ctor_args, ctor_kwargs))
219+
220+
@classmethod
221+
def _unpickle(cls, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Self:
222+
return cls(*args, **kwargs)
195223

196224
def __repr__(self) -> str:
197225
"""Return repr(self)."""
198-
args = [f"values={self._values!r}", f"units={self.units!r}"]
226+
args = [f"values={self._values!r}"]
227+
228+
if self.units:
229+
args.append(f"units={self.units!r}")
230+
231+
# Only display the extended properties if non-units entries are specified.
232+
if any(key for key in self.extended_properties.keys() if key != UNIT_DESCRIPTION):
233+
args.append(f"extended_properties={self.extended_properties!r}")
234+
199235
return f"{self.__class__.__module__}.{self.__class__.__name__}({', '.join(args)})"
200236

201237
def __str__(self) -> str:

src/nitypes/xy_data.py

Lines changed: 35 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -71,20 +71,16 @@ class XYData(Generic[TData]):
7171
To construct an XYData object, use the :class:`XYData` class:
7272
7373
>>> XYData(np.array([1.1], np.float64), np.array([4.1], np.float64))
74-
nitypes.xy_data.XYData(x_data=array([1.1]), y_data=array([4.1]),
75-
extended_properties={'NI_UnitDescription_X': '', 'NI_UnitDescription_Y': ''})
74+
nitypes.xy_data.XYData(x_data=array([1.1]), y_data=array([4.1]))
7675
>>> XYData(np.array([1, 2]), np.array([4, 5]), x_units="A", y_units="V")
77-
nitypes.xy_data.XYData(x_data=array([1, 2]), y_data=array([4, 5]),
78-
extended_properties={'NI_UnitDescription_X': 'A', 'NI_UnitDescription_Y': 'V'})
76+
nitypes.xy_data.XYData(x_data=array([1, 2]), y_data=array([4, 5]), x_units='A', y_units='V')
7977
8078
To construct an XYData object using built-in lists, use from_arrays_1d():
8179
8280
>>> XYData.from_arrays_1d([1, 2], [5, 6], np.int32)
83-
nitypes.xy_data.XYData(x_data=array([1, 2], dtype=int32), y_data=array([5, 6], dtype=int32),
84-
extended_properties={'NI_UnitDescription_X': '', 'NI_UnitDescription_Y': ''})
81+
nitypes.xy_data.XYData(x_data=array([1, 2], dtype=int32), y_data=array([5, 6], dtype=int32))
8582
>>> XYData.from_arrays_1d([1.0, 1.1], [1.2, 1.3], np.float64)
86-
nitypes.xy_data.XYData(x_data=array([1. , 1.1]), y_data=array([1.2, 1.3]),
87-
extended_properties={'NI_UnitDescription_X': '', 'NI_UnitDescription_Y': ''})
83+
nitypes.xy_data.XYData(x_data=array([1. , 1.1]), y_data=array([1.2, 1.3]))
8884
"""
8985

9086
__slots__ = [
@@ -246,12 +242,23 @@ def __init__(
246242
extended_properties = ExtendedPropertyDictionary(extended_properties)
247243
self._extended_properties = extended_properties
248244

249-
# If x and y units are not already in extended properties, set them.
250-
# If the caller specifies a non-blank x or y units, overwrite the existing entry.
251-
if _UNIT_DESCRIPTION_X not in self._extended_properties or x_units:
245+
# If x_units are not already in extended properties, set them.
246+
if _UNIT_DESCRIPTION_X not in self._extended_properties:
252247
self._extended_properties[_UNIT_DESCRIPTION_X] = x_units
253-
if _UNIT_DESCRIPTION_Y not in self._extended_properties or y_units:
248+
elif x_units and x_units != self._extended_properties.get(_UNIT_DESCRIPTION_X):
249+
raise ValueError(
250+
"The specified x_units input does not match the units specified in "
251+
"extended_properties."
252+
)
253+
254+
# If y_units are not already in extended properties, set them.
255+
if _UNIT_DESCRIPTION_Y not in self._extended_properties:
254256
self._extended_properties[_UNIT_DESCRIPTION_Y] = y_units
257+
elif y_units and y_units != self._extended_properties.get(_UNIT_DESCRIPTION_Y):
258+
raise ValueError(
259+
"The specified y_units input does not match the units specified in "
260+
"extended_properties."
261+
)
255262

256263
def _init_with_provided_arrays(
257264
self,
@@ -365,11 +372,22 @@ def _unpickle(cls, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Self:
365372

366373
def __repr__(self) -> str:
367374
"""Return repr(self)."""
368-
args = [
369-
f"x_data={self.x_data!r}",
370-
f"y_data={self.y_data!r}",
371-
f"extended_properties={self._extended_properties._properties!r}",
372-
]
375+
args = [f"x_data={self.x_data!r}", f"y_data={self.y_data!r}"]
376+
377+
if self.x_units:
378+
args.append(f"x_units={self.x_units!r}")
379+
380+
if self.y_units:
381+
args.append(f"y_units={self.y_units!r}")
382+
383+
# Only display the extended properties if non-units entries are specified.
384+
if any(
385+
key
386+
for key in self.extended_properties.keys()
387+
if key not in [_UNIT_DESCRIPTION_X, _UNIT_DESCRIPTION_Y]
388+
):
389+
args.append(f"extended_properties={self.extended_properties!r}")
390+
373391
return f"{self.__class__.__module__}.{self.__class__.__name__}({', '.join(args)})"
374392

375393
def __str__(self) -> str:

tests/unit/scalar/test_scalar.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,20 @@ def test___invalid_data_value___create___raises_type_error(data_value: Any) -> N
6868
assert exc.value.args[0].startswith("The scalar input data must be a bool, int, float, or str.")
6969

7070

71+
def test___both_units_specified_unequal__create___raises_value_error() -> None:
72+
with pytest.raises(ValueError) as exc:
73+
_ = Scalar(10, "Volts", extended_properties={UNIT_DESCRIPTION: "Amps"})
74+
75+
assert exc.value.args[0].startswith(
76+
"The specified units input does not match the units specified in extended_properties."
77+
)
78+
79+
80+
def test___units_only_specified_in_extended_properties__create___creates_with_units() -> None:
81+
data = Scalar(10, extended_properties={UNIT_DESCRIPTION: "Volts"})
82+
assert data.units == "Volts"
83+
84+
7185
###############################################################################
7286
# compare
7387
###############################################################################
@@ -182,14 +196,20 @@ def test___different_units___comparison___throws_exception() -> None:
182196
@pytest.mark.parametrize(
183197
"value, expected_repr",
184198
[
185-
(Scalar(False), "nitypes.scalar.Scalar(value=False, units='')"),
186-
(Scalar(10), "nitypes.scalar.Scalar(value=10, units='')"),
187-
(Scalar(20.0), "nitypes.scalar.Scalar(value=20.0, units='')"),
188-
(Scalar("value"), "nitypes.scalar.Scalar(value='value', units='')"),
199+
(Scalar(False), "nitypes.scalar.Scalar(value=False)"),
200+
(Scalar(10), "nitypes.scalar.Scalar(value=10)"),
201+
(Scalar(20.0), "nitypes.scalar.Scalar(value=20.0)"),
202+
(Scalar("value"), "nitypes.scalar.Scalar(value='value')"),
189203
(Scalar(False, "amps"), "nitypes.scalar.Scalar(value=False, units='amps')"),
190204
(Scalar(10, "volts"), "nitypes.scalar.Scalar(value=10, units='volts')"),
191205
(Scalar(20.0, "watts"), "nitypes.scalar.Scalar(value=20.0, units='watts')"),
192-
(Scalar("value", ""), "nitypes.scalar.Scalar(value='value', units='')"),
206+
(Scalar("value", ""), "nitypes.scalar.Scalar(value='value')"),
207+
(
208+
Scalar(10, units="volts", extended_properties={"Prop1": "Value1"}),
209+
"nitypes.scalar.Scalar(value=10, units='volts', "
210+
"extended_properties=nitypes.waveform.ExtendedPropertyDictionary("
211+
"{'Prop1': 'Value1', 'NI_UnitDescription': 'volts'}))",
212+
),
193213
],
194214
)
195215
def test___various_values___repr___looks_ok(value: Scalar[Any], expected_repr: str) -> None:
@@ -244,12 +264,14 @@ def test___scalar_with_units___set_units___units_updated_correctly() -> None:
244264
Scalar(10, "volts"),
245265
Scalar(20.0, "watts"),
246266
Scalar("value", ""),
267+
Scalar(10, "Volts", extended_properties={"one": 1}),
247268
],
248269
)
249270
def test___various_values___copy___makes_copy(value: Scalar[TScalar_co]) -> None:
250271
new_value = copy.copy(value)
251272
assert new_value is not value
252273
assert new_value == value
274+
assert new_value.extended_properties == value.extended_properties
253275

254276

255277
@pytest.mark.parametrize(
@@ -263,12 +285,15 @@ def test___various_values___copy___makes_copy(value: Scalar[TScalar_co]) -> None
263285
Scalar(10, "volts"),
264286
Scalar(20.0, "watts"),
265287
Scalar("value", ""),
288+
Scalar(10, "Volts", extended_properties={"one": 1}),
266289
],
267290
)
268291
def test___various_values___pickle_unpickle___makes_copy(value: Scalar[TScalar_co]) -> None:
269292
new_value = pickle.loads(pickle.dumps(value))
293+
assert isinstance(new_value, Scalar)
270294
assert new_value is not value
271295
assert new_value == value
296+
assert new_value.extended_properties == value.extended_properties
272297

273298

274299
def test___scalar___pickle___references_public_modules() -> None:

0 commit comments

Comments
 (0)