5
5
from traitlets import Bytes
6
6
from traitlets import Integer
7
7
from traitlets import List
8
+ from traitlets import Unicode
8
9
9
10
from ipyparallel import util
10
11
from ipyparallel .controller .scheduler import get_common_scheduler_streams
15
16
class BroadcastScheduler (Scheduler ):
16
17
port_name = 'broadcast'
17
18
accumulated_replies = {}
19
+ accumulated_targets = {}
18
20
is_leaf = Bool (False )
19
21
connected_sub_scheduler_ids = List (Bytes ())
20
22
outgoing_streams = List ()
21
23
depth = Integer ()
22
24
max_depth = Integer ()
25
+ name = Unicode ()
23
26
24
27
def start (self ):
25
28
self .client_stream .on_recv (self .dispatch_submission , copy = False )
@@ -28,12 +31,14 @@ def start(self):
28
31
else :
29
32
for outgoing_stream in self .outgoing_streams :
30
33
outgoing_stream .on_recv (self .dispatch_result , copy = False )
34
+ self .log .info (f"BroadcastScheduler { self .name } started" )
31
35
32
36
def send_to_targets (self , msg , original_msg_id , targets , idents , is_coalescing ):
33
37
if is_coalescing :
34
38
self .accumulated_replies [original_msg_id ] = {
35
- bytes ( target , 'utf8' ): None for target in targets
39
+ target . encode ( 'utf8' ): None for target in targets
36
40
}
41
+ self .accumulated_targets [original_msg_id ] = targets
37
42
38
43
for target in targets :
39
44
new_msg = self .append_new_msg_id_to_msg (
@@ -44,11 +49,6 @@ def send_to_targets(self, msg, original_msg_id, targets, idents, is_coalescing):
44
49
def send_to_sub_schedulers (
45
50
self , msg , original_msg_id , targets , idents , is_coalescing
46
51
):
47
- if is_coalescing :
48
- self .accumulated_replies [original_msg_id ] = {
49
- scheduler_id : None for scheduler_id in self .connected_sub_scheduler_ids
50
- }
51
-
52
52
trunc = 2 ** self .max_depth
53
53
fmt = f"0{ self .max_depth + 1 } b"
54
54
@@ -62,10 +62,21 @@ def send_to_sub_schedulers(
62
62
next_idx = int (path [self .depth + 1 ]) # 0 or 1
63
63
targets_by_scheduler [next_idx ].append (target_tuple )
64
64
65
+ if is_coalescing :
66
+ self .accumulated_replies [original_msg_id ] = {
67
+ scheduler_id : None for scheduler_id in self .connected_sub_scheduler_ids
68
+ }
69
+ self .accumulated_targets [original_msg_id ] = {}
70
+
65
71
for i , scheduler_id in enumerate (self .connected_sub_scheduler_ids ):
66
72
targets_for_scheduler = targets_by_scheduler [i ]
67
- if not targets_for_scheduler and is_coalescing :
68
- del self .accumulated_replies [original_msg_id ][scheduler_id ]
73
+ if is_coalescing :
74
+ if targets_for_scheduler :
75
+ self .accumulated_targets [original_msg_id ][
76
+ scheduler_id
77
+ ] = targets_for_scheduler
78
+ else :
79
+ del self .accumulated_replies [original_msg_id ][scheduler_id ]
69
80
msg ['metadata' ]['targets' ] = targets_for_scheduler
70
81
71
82
new_msg = self .append_new_msg_id_to_msg (
@@ -76,28 +87,36 @@ def send_to_sub_schedulers(
76
87
)
77
88
self .outgoing_streams [i ].send_multipart (new_msg , copy = False )
78
89
79
- def coalescing_reply (self , raw_msg , msg , original_msg_id , outgoing_id ):
90
+ def coalescing_reply (self , raw_msg , msg , original_msg_id , outgoing_id , idents ):
91
+ # accumulate buffers
92
+ self .accumulated_replies [original_msg_id ][outgoing_id ] = msg ['buffers' ]
80
93
if all (
81
- msg is not None or stored_outgoing_id == outgoing_id
82
- for stored_outgoing_id , msg in self .accumulated_replies [
83
- original_msg_id
84
- ].items ()
94
+ msg_buffers is not None
95
+ for msg_buffers in self .accumulated_replies [original_msg_id ].values ()
85
96
):
86
- new_msg = raw_msg [1 :]
87
- new_msg .extend (
88
- [
89
- buffer
90
- for msg_buffers in self .accumulated_replies [
91
- original_msg_id
92
- ].values ()
93
- if msg_buffers
94
- for buffer in msg_buffers
95
- ]
97
+ replies = self .accumulated_replies .pop (original_msg_id )
98
+ self .log .debug (f"Coalescing { len (replies )} reply to { original_msg_id } " )
99
+ targets = self .accumulated_targets .pop (original_msg_id )
100
+
101
+ new_msg = msg .copy ()
102
+ # begin rebuilding message
103
+ # metadata['targets']
104
+ if self .is_leaf :
105
+ new_msg ['metadata' ]['broadcast_targets' ] = targets
106
+ else :
107
+ new_msg ['metadata' ]['broadcast_targets' ] = []
108
+
109
+ # avoid duplicated msg buffers
110
+ buffers = []
111
+ for sub_target , msg_buffers in replies .items ():
112
+ buffers .extend (msg_buffers )
113
+ if not self .is_leaf :
114
+ new_msg ['metadata' ]['broadcast_targets' ].extend (targets [sub_target ])
115
+
116
+ new_raw_msg = self .session .serialize (new_msg )
117
+ self .client_stream .send_multipart (
118
+ idents + new_raw_msg + buffers , copy = False
96
119
)
97
- self .client_stream .send_multipart (new_msg , copy = False )
98
- del self .accumulated_replies [original_msg_id ]
99
- else :
100
- self .accumulated_replies [original_msg_id ][outgoing_id ] = msg ['buffers' ]
101
120
102
121
@util .log_errors
103
122
def dispatch_submission (self , raw_msg ):
@@ -144,7 +163,9 @@ def dispatch_result(self, raw_msg):
144
163
original_msg_id = msg ['metadata' ]['original_msg_id' ]
145
164
is_coalescing = msg ['metadata' ]['is_coalescing' ]
146
165
if is_coalescing :
147
- self .coalescing_reply (raw_msg , msg , original_msg_id , outgoing_id )
166
+ self .coalescing_reply (
167
+ raw_msg , msg , original_msg_id , outgoing_id , idents [1 :]
168
+ )
148
169
else :
149
170
self .client_stream .send_multipart (raw_msg [1 :], copy = False )
150
171
@@ -223,6 +244,7 @@ def launch_broadcast_scheduler(
223
244
config = config ,
224
245
depth = depth ,
225
246
max_depth = max_depth ,
247
+ name = identity ,
226
248
)
227
249
if is_leaf :
228
250
scheduler_args .update (engine_stream = outgoing_streams [0 ], is_leaf = True )
0 commit comments