Skip to content

Commit 6a340b6

Browse files
authored
Merge pull request #51 from adf-python/adf/impl/module/complex_49
adf/impl/module/complex/のAmbulanceで使われている部分の実装
2 parents ca0e98c + 382a47e commit 6a340b6

File tree

13 files changed

+681
-22
lines changed

13 files changed

+681
-22
lines changed

adf_core_python/core/agent/info/scenario_info.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from enum import Enum
2+
from typing import Any
23

34
from adf_core_python.core.config.config import Config
45

@@ -57,18 +58,20 @@ def get_mode(self) -> Mode:
5758
"""
5859
return self._mode
5960

60-
def get_config_value(self, key: str, default: str) -> str:
61+
def get_value(self, key: str, default: Any) -> Any:
6162
"""
6263
Get the value of the configuration
6364
6465
Parameters
6566
----------
6667
key : str
6768
Key of the configuration
69+
default : Any
70+
Default value of the configuration
6871
6972
Returns
7073
-------
71-
str
74+
Any
7275
Value of the configuration
7376
"""
7477
return self._config.get_value(key, default)

adf_core_python/core/agent/info/world_info.py

Lines changed: 85 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,14 +53,16 @@ def get_entity(self, entity_id: EntityID) -> Optional[Entity]:
5353
"""
5454
return self._world_model.get_entity(entity_id)
5555

56-
def get_entity_ids_of_type(self, entity_type: type[Entity]) -> list[EntityID]:
56+
def get_entity_ids_of_types(
57+
self, entity_types: list[type[Entity]]
58+
) -> list[EntityID]:
5759
"""
58-
Get the entity IDs of the specified type
60+
Get the entity IDs of the specified types
5961
6062
Parameters
6163
----------
62-
entity_type : type[Entity]
63-
Entity type
64+
entity_types : list[type[Entity]]
65+
List of entity types
6466
6567
Returns
6668
-------
@@ -69,7 +71,85 @@ def get_entity_ids_of_type(self, entity_type: type[Entity]) -> list[EntityID]:
6971
"""
7072
entity_ids: list[EntityID] = []
7173
for entity in self._world_model.get_entities():
72-
if isinstance(entity, entity_type):
74+
if any(isinstance(entity, entity_type) for entity_type in entity_types):
7375
entity_ids.append(entity.get_id())
7476

7577
return entity_ids
78+
79+
def get_entities_of_types(self, entity_types: list[type[Entity]]) -> list[Entity]:
80+
"""
81+
Get the entities of the specified types
82+
83+
Parameters
84+
----------
85+
entity_types : list[type[Entity]]
86+
List of entity types
87+
88+
Returns
89+
-------
90+
list[Entity]
91+
Entities
92+
"""
93+
entities: list[Entity] = []
94+
for entity in self._world_model.get_entities():
95+
if any(isinstance(entity, entity_type) for entity_type in entity_types):
96+
entities.append(entity)
97+
98+
return entities
99+
100+
def get_distance(self, entity_id1: EntityID, entity_id2: EntityID) -> float:
101+
"""
102+
Get the distance between two entities
103+
104+
Parameters
105+
----------
106+
entity_id1 : EntityID
107+
Entity ID 1
108+
entity_id2 : EntityID
109+
Entity ID 2
110+
111+
Returns
112+
-------
113+
float
114+
Distance
115+
116+
Raises
117+
------
118+
ValueError
119+
If one or both entities are invalid or the location is invalid
120+
"""
121+
entity1: Optional[Entity] = self.get_entity(entity_id1)
122+
entity2: Optional[Entity] = self.get_entity(entity_id2)
123+
if entity1 is None or entity2 is None:
124+
raise ValueError(
125+
f"One or both entities are invalid: entity_id1={entity_id1}, entity_id2={entity_id2}, entity1={entity1}, entity2={entity2}"
126+
)
127+
128+
location1_x, location1_y = entity1.get_location()
129+
location2_x, location2_y = entity2.get_location()
130+
if (
131+
location1_x is None
132+
or location1_y is None
133+
or location2_x is None
134+
or location2_y is None
135+
):
136+
raise ValueError(
137+
f"Invalid location: entity_id1={entity_id1}, entity_id2={entity_id2}, location1_x={location1_x}, location1_y={location1_y}, location2_x={location2_x}, location2_y={location2_y}"
138+
)
139+
140+
distance: float = (
141+
(location1_x - location2_x) ** 2 + (location1_y - location2_y) ** 2
142+
) ** 0.5
143+
144+
return distance
145+
146+
def get_change_set(self) -> ChangeSet:
147+
"""
148+
Get the change set
149+
150+
Returns
151+
-------
152+
ChangeSet
153+
Change set
154+
"""
155+
return self._change_set
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
from __future__ import annotations
2+
3+
from abc import abstractmethod
4+
from typing import TYPE_CHECKING
5+
6+
from adf_core_python.core.component.module.abstract_module import AbstractModule
7+
8+
if TYPE_CHECKING:
9+
from rcrs_core.entities.entity import Entity
10+
from rcrs_core.worldmodel.entityID import EntityID
11+
12+
from adf_core_python.core.agent.communication.message_manager import MessageManager
13+
from adf_core_python.core.agent.develop.develop_data import DevelopData
14+
from adf_core_python.core.agent.info.agent_info import AgentInfo
15+
from adf_core_python.core.agent.info.scenario_info import ScenarioInfo
16+
from adf_core_python.core.agent.info.world_info import WorldInfo
17+
from adf_core_python.core.agent.module.module_manager import ModuleManager
18+
from adf_core_python.core.agent.precompute.precompute_data import PrecomputeData
19+
20+
21+
class Clustering(AbstractModule):
22+
def __init__(
23+
self,
24+
agent_info: AgentInfo,
25+
world_info: WorldInfo,
26+
scenario_info: ScenarioInfo,
27+
module_manager: ModuleManager,
28+
develop_data: DevelopData,
29+
) -> None:
30+
super().__init__(
31+
agent_info, world_info, scenario_info, module_manager, develop_data
32+
)
33+
34+
@abstractmethod
35+
def get_cluster_number(self) -> int:
36+
pass
37+
38+
@abstractmethod
39+
def get_cluster_index(self, entity_id: EntityID) -> int:
40+
pass
41+
42+
@abstractmethod
43+
def get_cluster_entities(self, cluster_index: int) -> list[Entity]:
44+
pass
45+
46+
@abstractmethod
47+
def get_cluster_entity_ids(self, cluster_index: int) -> list[EntityID]:
48+
pass
49+
50+
@abstractmethod
51+
def calculate(self) -> Clustering:
52+
pass
53+
54+
def precompute(self, precompute_data: PrecomputeData) -> Clustering:
55+
super().precompute(precompute_data)
56+
return self
57+
58+
def resume(self, precompute_data: PrecomputeData) -> Clustering:
59+
super().resume(precompute_data)
60+
return self
61+
62+
def prepare(self) -> Clustering:
63+
super().prepare()
64+
return self
65+
66+
def update_info(self, message_manager: MessageManager) -> Clustering:
67+
super().update_info(message_manager)
68+
return self

adf_core_python/core/component/module/complex/target_detector.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
from abc import abstractmethod
4-
from typing import TYPE_CHECKING, Generic, TypeVar
4+
from typing import TYPE_CHECKING, Generic, Optional, TypeVar
55

66
from rcrs_core.entities.entity import Entity
77

@@ -35,7 +35,7 @@ def __init__(
3535
)
3636

3737
@abstractmethod
38-
def get_target_entity_id(self) -> EntityID:
38+
def get_target_entity_id(self) -> Optional[EntityID]:
3939
pass
4040

4141
@abstractmethod

adf_core_python/implement/extend_action/default_extend_action_transport.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from adf_core_python.core.component.module.algorithm.path_planning import PathPlanning
2424

2525

26+
# TODO: refactor this class
2627
class DefaultExtendActionTransport(ExtAction):
2728
def __init__(
2829
self,
@@ -99,7 +100,7 @@ def calc(self) -> ExtAction:
99100
agent: AmbulanceTeamEntity = cast(
100101
AmbulanceTeamEntity, self.agent_info.get_myself()
101102
)
102-
transport_human: Human = self.agent_info.some_one_on_board()
103+
transport_human: Optional[Human] = self.agent_info.some_one_on_board()
103104
if transport_human is not None:
104105
self.result = self.calc_unload(
105106
agent, self._path_planning, transport_human, self._target_entity_id
@@ -134,9 +135,7 @@ def calc_rescue(
134135

135136
target_position = human.get_position()
136137
if agent_position == target_position:
137-
if isinstance(human, Civilian) and (
138-
human.get_buriedness() is not None and human.get_buriedness() > 0
139-
):
138+
if isinstance(human, Civilian) and ((human.get_buriedness() or 0) > 0):
140139
return ActionLoad(human.get_id())
141140
else:
142141
path = path_planning.get_path(agent_position, target_position)
@@ -176,9 +175,7 @@ def calc_unload(
176175
if isinstance(position, Refuge):
177176
return ActionUnload()
178177
else:
179-
path = path_planning.get_path(
180-
agent_position, self.world_info.get_entity_ids_of_type(Refuge)
181-
)
178+
path = self.get_nearest_refuge_path(agent, path_planning)
182179
if path is not None and len(path) > 0:
183180
return ActionMove(path)
184181

@@ -191,7 +188,7 @@ def calc_unload(
191188
human = cast(Human, target_entity)
192189
if human.get_position() is not None:
193190
return self.calc_refuge_action(
194-
agent, path_planning, [human.get_position()], True
191+
agent, path_planning, human.get_position(), True
195192
)
196193
path = self.get_nearest_refuge_path(agent, path_planning)
197194
if path is not None and len(path) > 0:
@@ -207,7 +204,7 @@ def calc_refuge_action(
207204
is_unload: bool,
208205
) -> Optional[ActionMove | ActionUnload | ActionRest]:
209206
position = human.get_position()
210-
refuges = self.world_info.get_entity_ids_of_type(Refuge)
207+
refuges = self.world_info.get_entity_ids_of_types([Refuge])
211208
size = len(refuges)
212209

213210
if position in refuges:
@@ -242,7 +239,7 @@ def get_nearest_refuge_path(
242239
self, human: Human, path_planning: PathPlanning
243240
) -> list[EntityID]:
244241
position = human.get_position()
245-
refuges = self.world_info.get_entity_ids_of_type(Refuge)
242+
refuges = self.world_info.get_entity_ids_of_types([Refuge])
246243
nearest_path = None
247244

248245
for refuge_id in refuges:

0 commit comments

Comments
 (0)