Skip to content

Commit 9f070d4

Browse files
committed
Use overloads to make get_attr type safe
The overloads allow us to properly distinguish between: - Data type passed or not passed (returns str) - Optional set to true or false (returns Optional[T] or T) Also remove the now redundant casts.
1 parent 550f034 commit 9f070d4

File tree

5 files changed

+68
-39
lines changed

5 files changed

+68
-39
lines changed

launch/launch/action.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
from typing import Text
2121
from typing import Tuple
2222
from typing import Dict
23-
from typing import cast
2423

2524
from .condition import Condition
2625
from .launch_context import LaunchContext
@@ -62,8 +61,8 @@ def parse(entity: 'Entity', parser: 'Parser'):
6261
# Import here for avoiding cyclic imports.
6362
from .conditions import IfCondition
6463
from .conditions import UnlessCondition
65-
if_cond = cast(str, entity.get_attr('if', optional=True))
66-
unless_cond = cast(str, entity.get_attr('unless', optional=True))
64+
if_cond = entity.get_attr('if', optional=True)
65+
unless_cond = entity.get_attr('unless', optional=True)
6766
kwargs: Dict[str, Condition] = {}
6867
if if_cond is not None and unless_cond is not None:
6968
raise RuntimeError("if and unless conditions can't be used simultaneously")

launch/launch/actions/append_environment_variable.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import os
1818
from typing import List
1919
from typing import Union
20-
from typing import cast
2120

2221
from ..action import Action
2322
from ..frontend import Entity
@@ -76,14 +75,14 @@ def parse(
7675
):
7776
"""Parse an 'append_env' entity."""
7877
_, kwargs = super().parse(entity, parser)
79-
kwargs['name'] = parser.parse_substitution(cast(str, entity.get_attr('name')))
80-
kwargs['value'] = parser.parse_substitution(cast(str, entity.get_attr('value')))
78+
kwargs['name'] = parser.parse_substitution(entity.get_attr('name'))
79+
kwargs['value'] = parser.parse_substitution(entity.get_attr('value'))
8180
prepend = entity.get_attr('prepend', optional=True, data_type=bool, can_be_str=True)
8281
if prepend is not None:
83-
kwargs['prepend'] = parser.parse_if_substitutions(cast(Union[str, bool], prepend))
82+
kwargs['prepend'] = parser.parse_if_substitutions(prepend)
8483
separator = entity.get_attr('separator', optional=True)
8584
if separator is not None:
86-
kwargs['separator'] = parser.parse_substitution(cast(str, separator))
85+
kwargs['separator'] = parser.parse_substitution(separator)
8786
return cls, kwargs
8887

8988
@property

launch/launch/actions/declare_launch_argument.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from typing import List
1818
from typing import Optional
1919
from typing import Text
20-
from typing import cast
2120

2221
import launch.logging
2322

@@ -166,19 +165,17 @@ def parse(
166165
):
167166
"""Parse `arg` tag."""
168167
_, kwargs = super().parse(entity, parser)
169-
kwargs['name'] = parser.escape_characters(cast(str, entity.get_attr('name')))
168+
kwargs['name'] = parser.escape_characters(entity.get_attr('name'))
170169
default_value = entity.get_attr('default', optional=True)
171170
if default_value is not None:
172-
kwargs['default_value'] = parser.parse_substitution(cast(str, default_value))
171+
kwargs['default_value'] = parser.parse_substitution(default_value)
173172
description = entity.get_attr('description', optional=True)
174173
if description is not None:
175-
kwargs['description'] = parser.escape_characters(cast(str, description))
176-
# TODO: What to do here? get_attr is supposed to parse scalar / list of scalars, but in
177-
# this case asks for a list of entities?
178-
choices = cast(List[Entity], entity.get_attr('choice', data_type=List[Entity], optional=True)) # type: ignore
174+
kwargs['description'] = parser.escape_characters(description)
175+
choices = entity.get_attr('choice', data_type=List[Entity], optional=True)
179176
if choices is not None:
180177
kwargs['choices'] = [
181-
parser.escape_characters(cast(str, choice.get_attr('value'))) for choice in choices
178+
parser.escape_characters(choice.get_attr('value')) for choice in choices
182179
]
183180
return cls, kwargs
184181

launch/launch/actions/execute_process.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -318,35 +318,35 @@ def parse(
318318
ignore = []
319319

320320
if 'cmd' not in ignore:
321-
kwargs['cmd'] = cls._parse_cmdline(cast(str, entity.get_attr('cmd')), parser)
321+
kwargs['cmd'] = cls._parse_cmdline(entity.get_attr('cmd'), parser)
322322

323323
if 'cwd' not in ignore:
324-
cwd = cast(Optional[str], entity.get_attr('cwd', optional=True))
324+
cwd = entity.get_attr('cwd', optional=True)
325325
if cwd is not None:
326326
kwargs['cwd'] = parser.parse_substitution(cwd)
327327

328328
if 'name' not in ignore:
329-
name = cast(Optional[str], entity.get_attr('name', optional=True))
329+
name = entity.get_attr('name', optional=True)
330330
if name is not None:
331331
kwargs['name'] = parser.parse_substitution(name)
332332

333333
if 'prefix' not in ignore:
334-
prefix = cast(Optional[str], entity.get_attr('launch-prefix', optional=True))
334+
prefix = entity.get_attr('launch-prefix', optional=True)
335335
if prefix is not None:
336336
kwargs['prefix'] = parser.parse_substitution(prefix)
337337

338338
if 'output' not in ignore:
339-
output = cast(Optional[str], entity.get_attr('output', optional=True))
339+
output = entity.get_attr('output', optional=True)
340340
if output is not None:
341341
kwargs['output'] = parser.parse_substitution(output)
342342

343343
if 'respawn' not in ignore:
344-
respawn = cast(Optional[str], entity.get_attr('respawn', optional=True))
344+
respawn = entity.get_attr('respawn', optional=True)
345345
if respawn is not None:
346346
kwargs['respawn'] = parser.parse_substitution(respawn)
347347

348348
if 'respawn_delay' not in ignore:
349-
respawn_delay = cast(Optional[float], entity.get_attr('respawn_delay', data_type=float, optional=True))
349+
respawn_delay = entity.get_attr('respawn_delay', data_type=float, optional=True)
350350
if respawn_delay is not None:
351351
if respawn_delay < 0.0:
352352
raise ValueError(
@@ -356,7 +356,7 @@ def parse(
356356
kwargs['respawn_delay'] = respawn_delay
357357

358358
if 'sigkill_timeout' not in ignore:
359-
sigkill_timeout = cast(Optional[float], entity.get_attr('sigkill_timeout', data_type=float, optional=True))
359+
sigkill_timeout = entity.get_attr('sigkill_timeout', data_type=float, optional=True)
360360
if sigkill_timeout is not None:
361361
if sigkill_timeout < 0.0:
362362
raise ValueError(
@@ -366,7 +366,7 @@ def parse(
366366
kwargs['sigkill_timeout'] = str(sigkill_timeout)
367367

368368
if 'sigterm_timeout' not in ignore:
369-
sigterm_timeout = cast(Optional[float], entity.get_attr('sigterm_timeout', data_type=float, optional=True))
369+
sigterm_timeout = entity.get_attr('sigterm_timeout', data_type=float, optional=True)
370370
if sigterm_timeout is not None:
371371
if sigterm_timeout < 0.0:
372372
raise ValueError(
@@ -376,25 +376,24 @@ def parse(
376376
kwargs['sigterm_timeout'] = str(sigterm_timeout)
377377

378378
if 'shell' not in ignore:
379-
shell = cast(Optional[bool], entity.get_attr('shell', data_type=bool, optional=True))
379+
shell = entity.get_attr('shell', data_type=bool, optional=True)
380380
if shell is not None:
381381
kwargs['shell'] = shell
382382

383383
if 'emulate_tty' not in ignore:
384-
emulate_tty = cast(Optional[bool], entity.get_attr('emulate_tty', data_type=bool, optional=True))
384+
emulate_tty = entity.get_attr('emulate_tty', data_type=bool, optional=True)
385385
if emulate_tty is not None:
386386
kwargs['emulate_tty'] = emulate_tty
387387

388388
if 'additional_env' not in ignore:
389389
# Conditions won't be allowed in the `env` tag.
390390
# If that feature is needed, `set_enviroment_variable` and
391391
# `unset_enviroment_variable` actions should be used.
392-
# TODO: Fixup the data_type annotation
393-
env = cast(Optional[List[Entity]], entity.get_attr('env', data_type=List[Entity], optional=True)) # type: ignore
392+
env = entity.get_attr('env', data_type=List[Entity], optional=True)
394393
if env is not None:
395394
kwargs['additional_env'] = {
396-
tuple(parser.parse_substitution(cast(str, e.get_attr('name')))):
397-
parser.parse_substitution(cast(str, e.get_attr('value'))) for e in env
395+
tuple(parser.parse_substitution(e.get_attr('name'))):
396+
parser.parse_substitution(e.get_attr('value')) for e in env
398397
}
399398
for e in env:
400399
e.assert_entity_completely_parsed()

launch/launch/frontend/entity.py

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,20 @@
1515
"""Module for Entity class."""
1616

1717
from typing import List
18+
from typing import Literal
1819
from typing import Optional
1920
from typing import Text
21+
from typing import Type
2022
from typing import Union
23+
from typing import TypeVar
24+
from typing import overload
2125

2226
from launch.utilities.type_utils import AllowedTypesType
2327
from launch.utilities.type_utils import AllowedValueType
2428

2529

30+
TargetType = TypeVar("TargetType")
31+
2632
class Entity:
2733
"""Single item in the intermediate front_end representation."""
2834

@@ -41,17 +47,46 @@ def children(self) -> List['Entity']:
4147
"""Get the Entity's children."""
4248
raise NotImplementedError()
4349

50+
# We need a few overloads for type checking:
51+
# - Depending on optional, the return value is either T or Optional[T].
52+
# Unfortunately, if the optional is not present, we need another overload to denote the
53+
# default, so this has three values: true, false and not present.
54+
# - If no data type is passed, the return type is str. Similar to the above, it has two
55+
# possibilities: present or not present.
56+
# => 6 overloads to cover every combination
57+
@overload
58+
def get_attr(self, name: Text, *, data_type: Type[TargetType], optional: Literal[False], can_be_str: bool = True) -> TargetType:
59+
...
60+
61+
@overload
62+
def get_attr(self, name: Text, *, data_type: Type[TargetType], optional: Literal[True], can_be_str: bool = True) -> Optional[TargetType]:
63+
...
64+
65+
@overload
66+
def get_attr(self, name: Text, *, data_type: Type[TargetType], can_be_str: bool = True) -> TargetType:
67+
...
68+
69+
@overload
70+
def get_attr(self, name: Text, *, optional: Literal[False], can_be_str: bool = True) -> str:
71+
...
72+
73+
@overload
74+
def get_attr(self, name: Text, *, optional: Literal[True], can_be_str: bool = True) -> Optional[str]:
75+
...
76+
77+
@overload
78+
def get_attr(self, name: Text, *, can_be_str: bool = True) -> str:
79+
...
80+
81+
4482
def get_attr(
4583
self,
46-
name: Text,
84+
name,
4785
*,
48-
data_type: AllowedTypesType = str,
49-
optional: bool = False,
50-
can_be_str: bool = True,
51-
) -> Optional[Union[
52-
AllowedValueType,
53-
List['Entity'],
54-
]]:
86+
data_type=str,
87+
optional=False,
88+
can_be_str=True,
89+
):
5590
"""
5691
Access an attribute of the entity.
5792

0 commit comments

Comments
 (0)