19
19
import os
20
20
import signal
21
21
import sys
22
+ from dataclasses import dataclass
23
+ from types import TracebackType
22
24
from typing import List
23
25
24
- from ....utils import get_next_port
26
+ from ....utils import get_next_port , dataslots
25
27
from ..config import ActorPoolConfig
26
28
from ..message import CreateActorMessage
27
29
from ..pool import MainActorPoolBase , SubActorPoolBase , _register_message_handler
@@ -58,6 +60,15 @@ def _mp_kill(self):
58
60
logger = logging .getLogger (__name__ )
59
61
60
62
63
+ @dataslots
64
+ @dataclass
65
+ class SubpoolStatus :
66
+ # for status, 0 is succeeded, 1 is failed
67
+ status : int = None
68
+ error : BaseException = None
69
+ traceback : TracebackType = None
70
+
71
+
61
72
@_register_message_handler
62
73
class MainActorPool (MainActorPoolBase ):
63
74
@@ -107,29 +118,41 @@ async def start_sub_pool(
107
118
108
119
def start_pool_in_process ():
109
120
ctx = multiprocessing .get_context (method = start_method )
110
- started = ctx .Event ()
121
+ status_queue = ctx .Queue ()
111
122
process = ctx .Process (
112
123
target = cls ._start_sub_pool ,
113
- args = (actor_pool_config , process_index , started ),
124
+ args = (actor_pool_config , process_index , status_queue ),
114
125
name = f'MarsActorPool{ process_index } ' ,
115
126
)
116
127
process .daemon = True
117
128
process .start ()
118
129
# wait for sub actor pool to finish starting
119
- started . wait ()
120
- return process
130
+ process_status = status_queue . get ()
131
+ return process , process_status
121
132
122
133
loop = asyncio .get_running_loop ()
123
134
executor = futures .ThreadPoolExecutor (1 )
124
135
create_pool_task = loop .run_in_executor (executor , start_pool_in_process )
125
136
return await create_pool_task
126
137
138
+ @classmethod
139
+ async def wait_sub_pools_ready (cls ,
140
+ create_pool_tasks : List [asyncio .Task ]):
141
+ processes = []
142
+ for task in create_pool_tasks :
143
+ process , status = await task
144
+ if status .status == 1 :
145
+ # start sub pool failed
146
+ raise status .error .with_traceback (status .traceback )
147
+ processes .append (process )
148
+ return processes
149
+
127
150
@classmethod
128
151
def _start_sub_pool (
129
152
cls ,
130
153
actor_config : ActorPoolConfig ,
131
154
process_index : int ,
132
- started : multiprocessing .Event ):
155
+ status_queue : multiprocessing .Queue ):
133
156
if not _is_windows :
134
157
try :
135
158
# register coverage hooks on SIGTERM
@@ -159,15 +182,16 @@ def _start_sub_pool(
159
182
else :
160
183
asyncio .set_event_loop (asyncio .new_event_loop ())
161
184
162
- coro = cls ._create_sub_pool (actor_config , process_index , started )
185
+ coro = cls ._create_sub_pool (actor_config , process_index , status_queue )
163
186
asyncio .run (coro )
164
187
165
188
@classmethod
166
189
async def _create_sub_pool (
167
190
cls ,
168
191
actor_config : ActorPoolConfig ,
169
192
process_index : int ,
170
- started : multiprocessing .Event ):
193
+ status_queue : multiprocessing .Queue ):
194
+ process_status = None
171
195
try :
172
196
env = actor_config .get_pool_config (process_index )['env' ]
173
197
if env :
@@ -176,9 +200,14 @@ async def _create_sub_pool(
176
200
'actor_pool_config' : actor_config ,
177
201
'process_index' : process_index
178
202
})
203
+ process_status = SubpoolStatus (status = 0 )
179
204
await pool .start ()
205
+ except : # noqa: E722 # nosec # pylint: disable=bare-except
206
+ _ , error , tb = sys .exc_info ()
207
+ process_status = SubpoolStatus (status = 1 , error = error , traceback = tb )
208
+ raise
180
209
finally :
181
- started . set ( )
210
+ status_queue . put ( process_status )
182
211
await pool .join ()
183
212
184
213
async def kill_sub_pool (self , process : multiprocessing .Process ,
@@ -203,8 +232,9 @@ async def recover_sub_pool(self, address: str):
203
232
process_index = self ._config .get_process_index (address )
204
233
# process dead, restart it
205
234
# remember always use spawn to recover sub pool
206
- self .sub_processes [address ] = await self .__class__ .start_sub_pool (
207
- self ._config , process_index , 'spawn' )
235
+ task = asyncio .create_task (self .start_sub_pool (
236
+ self ._config , process_index , 'spawn' ))
237
+ self .sub_processes [address ] = (await self .wait_sub_pools_ready ([task ]))[0 ]
208
238
209
239
if self ._auto_recover == 'actor' :
210
240
# need to recover all created actors
0 commit comments