Skip to content

Commit 1384c53

Browse files
authored
add label based indexing to pass sequence (#154)
* add label based indexing to pass sequence * add ipython completion support
1 parent bf911dc commit 1384c53

File tree

3 files changed

+87
-4
lines changed

3 files changed

+87
-4
lines changed

pyroll/core/sequence/sequence.py

+23-4
Original file line numberDiff line numberDiff line change
@@ -58,15 +58,31 @@ def __iter__(self):
5858
return self._subunits.__iter__()
5959

6060
@overload
61-
def __getitem__(self, index: int) -> Unit:
61+
def __getitem__(self, key: int) -> Unit:
62+
"""Gets unit item by index."""
6263
...
6364

6465
@overload
65-
def __getitem__(self, index: slice) -> list[Unit]:
66+
def __getitem__(self, key: str) -> Unit:
67+
"""Gets unit item by label."""
6668
...
6769

68-
def __getitem__(self, index: int) -> Unit:
69-
return self._subunits.__getitem__(index)
70+
@overload
71+
def __getitem__(self, key: slice) -> list[Unit]:
72+
"""Gets a slice of units."""
73+
...
74+
75+
def __getitem__(self, key):
76+
if isinstance(key, str):
77+
try:
78+
return next(u for u in self._subunits if u.label == key)
79+
except StopIteration:
80+
raise KeyError(f"No unit with label '{key}' found.")
81+
82+
if isinstance(key, int) or isinstance(key, slice):
83+
return self._subunits.__getitem__(key)
84+
85+
raise TypeError("Key must be int, slice or str")
7086

7187
@property
7288
def units(self) -> List[Unit]:
@@ -88,3 +104,6 @@ def __attrs__(self):
88104
return super().__attrs__ | {
89105
"units": self.units
90106
}
107+
108+
def _ipython_key_completions_(self):
109+
return [u.label for u in self._subunits]

tests/test_pass_sequence.py

+63
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
from pyroll.core import PassSequence, RollPass, CircularOvalGroove, Transport, RoundGroove, Roll
2+
import pytest
3+
4+
5+
# noinspection DuplicatedCode
6+
def test_pass_sequence_indexing():
7+
sequence = PassSequence([
8+
RollPass(
9+
label="Oval I",
10+
roll=Roll(
11+
groove=CircularOvalGroove(
12+
depth=8e-3,
13+
r1=6e-3,
14+
r2=40e-3
15+
),
16+
nominal_radius=160e-3,
17+
rotational_frequency=1,
18+
neutral_point=-20e-3
19+
),
20+
gap=2e-3,
21+
22+
),
23+
Transport(
24+
label="I => II",
25+
duration=1,
26+
),
27+
RollPass(
28+
label="Round II",
29+
roll=Roll(
30+
groove=RoundGroove(
31+
r1=1e-3,
32+
r2=12.5e-3,
33+
depth=11.5e-3
34+
),
35+
nominal_radius=160e-3,
36+
rotational_frequency=1
37+
),
38+
gap=2e-3,
39+
),
40+
Transport(
41+
label="II => III",
42+
duration=1
43+
),
44+
RollPass(
45+
label="Oval III",
46+
roll=Roll(
47+
groove=CircularOvalGroove(
48+
depth=6e-3,
49+
r1=6e-3,
50+
r2=35e-3
51+
),
52+
nominal_radius=160e-3,
53+
rotational_frequency=1
54+
),
55+
gap=2e-3,
56+
),
57+
])
58+
59+
assert sequence["Oval III"] == sequence[4]
60+
assert sequence["I => II"] == sequence[1]
61+
62+
with pytest.raises(KeyError):
63+
_ = sequence["not present"]

tests/test_solve.py

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ def flow_stress(self: RollPass.Profile):
1111
return 50e6 * (1 + self.strain) ** 0.2 * self.roll_pass.strain_rate ** 0.1
1212

1313

14+
# noinspection DuplicatedCode
1415
def test_solve(tmp_path: Path, caplog):
1516
caplog.set_level(logging.DEBUG, logger="pyroll")
1617

0 commit comments

Comments
 (0)