# TreeGatewayTest.py -- forked from cea-hpc/clustershell
"""
Unit test for ClusterShell.Gateway
"""
import logging
import os
import re
import unittest
import xml.sax
from ClusterShell import __version__
from ClusterShell.Communication import ConfigurationMessage, ControlMessage, \
StdOutMessage, StdErrMessage, RetcodeMessage, ACKMessage, ErrorMessage, \
TimeoutMessage, StartMessage, EndMessage, XMLReader
from ClusterShell.Gateway import GatewayChannel
from ClusterShell.NodeSet import NodeSet
from ClusterShell.Task import Task, task_self
from ClusterShell.Topology import TopologyGraph
from ClusterShell.Worker.Tree import TreeWorker
from ClusterShell.Worker.Worker import StreamWorker
from TLib import HOSTNAME
# Enable DEBUG-level logging for this test module; to see the records live,
# run the tests with: nosetests --nologcapture
logging.basicConfig(level=logging.DEBUG)
class Gateway(object):
    """Gateway special test class.

    Initialize a GatewayChannel through a R/W StreamWorker like a real
    remote ClusterShell Gateway but:
        - using pipes to communicate,
        - running on a dedicated task/thread.

    Two pipes are created: a pseudo stdin and a pseudo stdout. The read
    end of stdin and the write end of stdout are handed to the
    StreamWorker; the test side keeps the write end of stdin (see send())
    and the read end of stdout (see recv()).
    """

    # Class-level default so that close() stays safe to call at any time,
    # including more than once.
    _parent_fds_closed = False

    def __init__(self):
        """init Gateway bound objects"""
        self.task = Task()
        self.channel = GatewayChannel(self.task)
        self.worker = StreamWorker(handler=self.channel)
        # create communication pipes
        self.pipe_stdin = os.pipe()
        self.pipe_stdout = os.pipe()
        self._parent_fds_closed = False
        # avoid nonblocking flag as we want recv/read() to block
        self.worker.set_reader(self.channel.SNAME_READER,
                               self.pipe_stdin[0])
        self.worker.set_writer(self.channel.SNAME_WRITER,
                               self.pipe_stdout[1], retain=False)
        self.task.schedule(self.worker)
        self.task.resume()

    def send(self, msg):
        """send msg (bytes) to pseudo stdin, followed by a newline"""
        os.write(self.pipe_stdin[1], msg + b'\n')

    def send_str(self, msgstr):
        """send msg (string) to pseudo stdin"""
        self.send(msgstr.encode())

    def recv(self):
        """recv buf from pseudo stdout (blocking call)

        Return up to 4096 bytes; an empty bytes object means EOF.
        """
        return os.read(self.pipe_stdout[0], 4096)

    def wait(self):
        """wait for task/thread termination"""
        # can be blocked indefinitely if StreamWorker doesn't complete
        self.task.join()

    def close(self):
        """close parent fds (idempotent)

        Only the first call closes the descriptors: calling os.close()
        again on the same numbers could hit unrelated descriptors that
        reused them in the meantime.
        """
        if self._parent_fds_closed:
            return
        self._parent_fds_closed = True
        os.close(self.pipe_stdout[0])
        os.close(self.pipe_stdin[1])

    def destroy(self):
        """abort task/thread and release parent fds

        close() is a no-op when the test already called it; otherwise
        (e.g. the test failed early) this avoids leaking both parent fds.
        """
        self.task.abort(kill=True)
        self.close()
class TreeGatewayBaseTest(unittest.TestCase):
    """base test class

    Provides a fresh Gateway, a small test topology and an incremental
    XML parser for each test, plus helpers to send channel messages to the
    gateway and to read its replies.
    """

    def setUp(self):
        """setup gateway and topology for each test"""
        # gateway
        self.gateway = Gateway()
        self.chan = self.gateway.channel
        # topology
        graph = TopologyGraph()
        graph.add_route(NodeSet(HOSTNAME), NodeSet('n[1-2]'))
        graph.add_route(NodeSet('n1'), NodeSet('n[10-49]'))
        graph.add_route(NodeSet('n2'), NodeSet('n[50-89]'))
        self.topology = graph.to_tree(HOSTNAME)
        # xml parser with Communication.XMLReader as content handler
        self.xml_reader = XMLReader()
        self.parser = xml.sax.make_parser(["IncrementalParser"])
        self.parser.setContentHandler(self.xml_reader)

    def tearDown(self):
        """destroy gateway after each test"""
        self.gateway.destroy()
        self.gateway = None

    #
    # Send to GW
    #
    def channel_send_start(self):
        """send starting channel tag"""
        self.gateway.send_str('<channel version="%s">' % __version__)

    def channel_send_stop(self):
        """send channel ending tag"""
        self.gateway.send_str("</channel>")

    def channel_send_cfg(self, gateway):
        """send configuration part of channel"""
        # code snippet from PropagationChannel.start()
        cfg = ConfigurationMessage(gateway)
        cfg.data_encode(self.topology)
        self.gateway.send(cfg.xml())

    #
    # Receive from GW
    #
    def assert_isinstance(self, msg, msg_class):
        """helper to check a message instance"""
        self.assertIsInstance(msg, msg_class,
                              "%s is not a %s" % (type(msg), msg_class))

    def _recvxml(self):
        """read from the gateway until a message is available or EOF

        Return the oldest parsed message as given by XMLReader.pop_msg().
        """
        while not self.xml_reader.msg_available():
            xml_msg = self.gateway.recv()
            if not xml_msg:
                # EOF on pseudo stdout: finalize parsing and stop reading
                self.parser.close()
                break
            self.assertIsInstance(xml_msg, bytes)
            self.parser.feed(xml_msg)
            if hasattr(self.parser, 'flush'):  # >=3.13 and backports
                self.parser.flush()
        return self.xml_reader.pop_msg()

    def recvxml(self, expected_msg_class=None):
        """receive next message and check its class

        When expected_msg_class is None, assert that no message was
        received. The received message is returned in both cases.
        """
        msg = self._recvxml()
        if expected_msg_class is None:
            self.assertIsNone(msg)
        else:
            self.assert_isinstance(msg, expected_msg_class)
        return msg
class TreeGatewayTest(TreeGatewayBaseTest):
    """Gateway channel tests: open/close, error replies, shell actions"""

    def test_basic_noop(self):
        """test gateway channel open/close"""
        self.channel_send_start()
        self.recvxml(StartMessage)
        self.assertEqual(self.chan.opened, True)
        self.assertEqual(self.chan.setup, False)
        self.channel_send_stop()
        self.recvxml(EndMessage)
        # ending tag should abort gateway worker without delay
        self.gateway.wait()
        self.gateway.close()

    def test_channel_err_dup(self):
        """test gateway channel duplicate tags"""
        self.channel_send_start()
        self.recvxml(StartMessage)
        self.assertEqual(self.chan.opened, True)
        self.assertEqual(self.chan.setup, False)
        # send an unexpected second channel tag
        self.channel_send_start()
        msg = self.recvxml(ErrorMessage)
        self.assertEqual(msg.type, 'ERR')
        reason = 'unexpected message: Message CHA '
        self.assertEqual(msg.reason[:len(reason)], reason)
        # gateway should terminate channel session
        self.recvxml(EndMessage)
        self.gateway.wait()
        self.gateway.close()

    def _check_channel_err(self, sendmsg, errback, openchan=True,
                           setupchan=False):
        """helper to ease test of erroneous messages sent to gateway

        sendmsg is the erroneous string sent to the gateway. errback is
        either the exact expected error reason or a compiled regex that
        must be found in it. openchan/setupchan select whether the channel
        is opened and configured before sending.
        """
        if openchan:
            self.channel_send_start()
            self.recvxml(StartMessage)
            self.assertEqual(self.chan.opened, True)
            self.assertEqual(self.chan.setup, False)

        if setupchan:
            # send channel configuration
            self.channel_send_cfg('n1')
            self.recvxml(ACKMessage)
            self.assertEqual(self.chan.setup, True)

        # send the erroneous message and test gateway reply
        self.gateway.send_str(sendmsg)
        msg = self.recvxml(ErrorMessage)
        self.assertEqual(msg.type, 'ERR')
        if hasattr(errback, 'search'):
            # compiled regex: fail explicitly on mismatch, even when the
            # reported reason is empty
            if errback.search(msg.reason) is None:
                self.fail('Pattern "%s" not found in reason="%s"'
                          % (errback.pattern, msg.reason))
        else:
            # not a regex
            self.assertEqual(msg.reason, errback)

        # gateway should terminate channel session
        if openchan:
            msg = self.recvxml(EndMessage)
            self.assertEqual(msg.type, 'END')
        else:
            self.recvxml()

        # gateway task should exit properly
        self.gateway.wait()
        self.gateway.close()

    def test_err_start_with_ending_tag(self):
        """test gateway missing opening channel tag"""
        self._check_channel_err('</channel>',
                                'Parse error: not well-formed (invalid token)',
                                openchan=False)

    def test_err_channel_end_msg(self):
        """test gateway channel missing opening message tag"""
        self._check_channel_err('</message>',
                                'Parse error: mismatched tag')

    def test_err_channel_end_msg_setup(self):
        """test gateway channel missing opening message tag (setup)"""
        self._check_channel_err('</message>',
                                'Parse error: mismatched tag',
                                setupchan=True)

    def test_err_unknown_tag(self):
        """test gateway unknown tag"""
        self._check_channel_err('<foobar></footbar>',
                                'Invalid starting tag foobar',
                                openchan=False)

    def test_channel_err_unknown_tag(self):
        """test gateway unknown tag in channel"""
        self._check_channel_err('<foo></foo>', 'Invalid starting tag foo')

    def test_channel_err_unknown_tag_setup(self):
        """test gateway unknown tag in channel (setup)"""
        self._check_channel_err('<foo></foo>',
                                'Invalid starting tag foo',
                                setupchan=True)

    def test_err_unknown_msg(self):
        """test gateway unknown message"""
        self._check_channel_err('<message msgid="24" type="ABC"></message>',
                                'Unknown message type',
                                openchan=False)

    def test_channel_err_unknown_msg(self):
        """test gateway channel unknown message"""
        self._check_channel_err('<message msgid="24" type="ABC"></message>',
                                'Unknown message type')

    def test_err_xml_malformed(self):
        """test gateway malformed xml message"""
        self._check_channel_err('<message type="ABC"</message>',
                                'Parse error: not well-formed (invalid token)',
                                openchan=False)

    def test_channel_err_xml_malformed(self):
        """test gateway channel malformed xml message"""
        self._check_channel_err('<message type="ABC"</message>',
                                'Parse error: not well-formed (invalid token)')

    def test_channel_err_xml_malformed_setup(self):
        """test gateway channel malformed xml message (setup)"""
        self._check_channel_err('<message type="ABC"</message>',
                                'Parse error: not well-formed (invalid token)',
                                setupchan=True)

    def test_channel_err_xml_bad_char(self):
        """test gateway channel malformed xml message (bad chars)"""
        self._check_channel_err('\x11<message type="ABC"></message>',
                                'Parse error: not well-formed (invalid token)')

    def test_channel_err_missingattr(self):
        """test gateway channel message bad attributes"""
        self._check_channel_err(
            '<message msgid="24" nodes="foo" retcode="4" type="RET"></message>',
            'Invalid "message" attributes: missing key "srcid"')

    def test_channel_err_unexpected(self):
        """test gateway channel unexpected message"""
        self._check_channel_err(
            '<message type="ACK" ack="2" msgid="2"></message>',
            re.compile(r'unexpected message: Message ACK \(.*ack: 2.*\)'))

    def test_channel_err_cfg_missing_gw(self):
        """test gateway channel message missing gateway nodename"""
        self._check_channel_err(
            '<message msgid="337" type="CFG">DUMMY</message>',
            'Invalid "message" attributes: missing key "gateway"')

    def test_channel_err_missing_pl(self):
        """test gateway channel message missing payload"""
        self._check_channel_err(
            '<message msgid="14" type="CFG" gateway="n1"></message>',
            'Message CFG has an invalid payload')

    def test_channel_err_unexpected_pl(self):
        """test gateway channel message unexpected payload"""
        self._check_channel_err(
            '<message msgid="14" type="ERR" reason="test">FOO</message>',
            'Got unexpected payload for Message ERR', setupchan=True)

    def test_channel_err_badenc_b2a_pl(self):
        """test gateway channel message badly encoded payload (base64)"""
        # Generate TypeError (py2) or binascii.Error (py3)
        self._check_channel_err(
            '<message msgid="14" type="CFG" gateway="n1">bar</message>',
            'Message CFG has an invalid payload')

    def test_channel_err_badenc_pickle_pl(self):
        """test gateway channel message badly encoded payload (pickle)"""
        # Generate pickle error
        self._check_channel_err(
            '<message msgid="14" type="CFG" gateway="n1">barm</message>',
            'Message CFG has an invalid payload')

    def test_channel_basic_abort(self):
        """test gateway channel aborted while opened"""
        self.channel_send_start()
        self.recvxml(StartMessage)
        self.assertEqual(self.chan.opened, True)
        self.assertEqual(self.chan.setup, False)
        self.gateway.close()
        self.gateway.wait()

    def _check_channel_ctl_shell(self, command, target, stderr, remote,
                                 reply_msg_class, reply_pattern,
                                 write_buf=None, timeout=-1, replycnt=1,
                                 reply_rc=0):
        """helper to check channel shell action

        Open and configure the channel, send a 'shell' control message for
        command on target, optionally followed by 'write' and 'eof'
        messages when write_buf is set, then check the replies.

        reply_msg_class is the expected class of each reply and
        reply_pattern is either the exact expected payload or a compiled
        regex that must be found in it. replycnt is the number of nodes to
        collect replies for. Unless a positive timeout is given, a final
        RetcodeMessage with retcode reply_rc is expected.
        """
        self.channel_send_start()
        self.recvxml(StartMessage)
        self.channel_send_cfg('n1')
        self.recvxml(ACKMessage)

        # prepare a remote shell command request...
        workertree = TreeWorker(nodes=target, handler=None, timeout=timeout,
                                command=command)
        # code snippet from PropagationChannel.shell()
        ctl = ControlMessage(id(workertree))
        ctl.action = 'shell'
        ctl.target = NodeSet(target)

        info = task_self()._info.copy()
        info['debug'] = False

        ctl_data = {
            'cmd': command,
            'invoke_gateway': workertree.invoke_gateway,
            'taskinfo': info,
            'stderr': stderr,
            'timeout': timeout,
            'remote': remote
        }
        ctl.data_encode(ctl_data)
        self.gateway.send(ctl.xml())

        self.recvxml(ACKMessage)

        if write_buf:
            ctl = ControlMessage(id(workertree))
            ctl.action = 'write'
            ctl.target = NodeSet(target)
            ctl_data = {
                'buf': write_buf,
            }
            # Send write message
            ctl.data_encode(ctl_data)
            self.gateway.send(ctl.xml())
            self.recvxml(ACKMessage)

            # Send EOF message
            ctl = ControlMessage(id(workertree))
            ctl.action = 'eof'
            ctl.target = NodeSet(target)
            self.gateway.send(ctl.xml())
            self.recvxml(ACKMessage)

        while replycnt > 0:
            msg = self.recvxml(reply_msg_class)
            # one reply may cover several nodes at once
            replycnt -= len(NodeSet(msg.nodes))
            self.assertIn(msg.nodes, ctl.target)
            if msg.has_payload or reply_pattern:
                msg_data = msg.data_decode()
                if hasattr(reply_pattern, 'search'):
                    if reply_pattern.search(msg_data) is None:
                        self.fail('Pattern "%s" not found in data="%s"'
                                  % (reply_pattern.pattern, msg_data))
                else:
                    # not a regexp
                    self.assertEqual(msg_data, reply_pattern)

        if timeout <= 0:
            msg = self.recvxml(RetcodeMessage)
            self.assertEqual(msg.retcode, reply_rc)

        self.channel_send_stop()
        self.gateway.wait()
        self.gateway.close()

    def test_channel_ctl_shell_local1(self):
        """test gateway channel shell stdout (stderr=False remote=False)"""
        self._check_channel_ctl_shell("echo ok", "n10", False, False,
                                      StdOutMessage, b"ok")

    def test_channel_ctl_shell_local2(self):
        """test gateway channel shell stdout (stderr=True remote=False)"""
        self._check_channel_ctl_shell("echo ok", "n10", True, False,
                                      StdOutMessage, b"ok")

    def test_channel_ctl_shell_local3(self):
        """test gateway channel shell stderr (stderr=True remote=False)"""
        self._check_channel_ctl_shell("echo ok >&2", "n10", True, False,
                                      StdErrMessage, b"ok")

    def test_channel_ctl_shell_mlocal1(self):
        """test gateway channel shell multi (remote=False)"""
        self._check_channel_ctl_shell("echo ok", "n[10-49]", True, False,
                                      StdOutMessage, b"ok", replycnt=40)

    def test_channel_ctl_shell_mlocal2(self):
        """test gateway channel shell multi stderr (remote=False)"""
        self._check_channel_ctl_shell("echo ok 1>&2", "n[10-49]", True, False,
                                      StdErrMessage, b"ok", replycnt=40)

    def test_channel_ctl_shell_mlocal3(self):
        """test gateway channel shell multi placeholder (remote=False)"""
        self._check_channel_ctl_shell('echo node %h rank %n', "n[10-29]", True,
                                      False, StdOutMessage,
                                      re.compile(br"node n\d+ rank \d+"),
                                      replycnt=20)

    def test_channel_ctl_shell_remote1(self):
        """test gateway channel shell stdout (stderr=False remote=True)"""
        self._check_channel_ctl_shell("echo ok", "n10", False, True,
                                      StdOutMessage,
                                      re.compile(b"(Could not resolve hostname|"
                                                 b"Name or service not known)"),
                                      reply_rc=255)

    def test_channel_ctl_shell_remote2(self):
        """test gateway channel shell stdout (stderr=True remote=True)"""
        self._check_channel_ctl_shell("echo ok", "n10", True, True,
                                      StdErrMessage,
                                      re.compile(b"(Could not resolve hostname|"
                                                 b"Name or service not known)"),
                                      reply_rc=255)

    def test_channel_ctl_shell_timeo1(self):
        """test gateway channel shell timeout"""
        self._check_channel_ctl_shell("sleep 10", "n10", False, False,
                                      TimeoutMessage, None, timeout=0.5)

    def test_channel_ctl_shell_wrloc1(self):
        """test gateway channel write (stderr=False remote=False)"""
        self._check_channel_ctl_shell("cat", "n10", False, False,
                                      StdOutMessage, b"ok", write_buf=b"ok\n")

    def test_channel_ctl_shell_wrloc2(self):
        """test gateway channel write (stderr=True remote=False)"""
        self._check_channel_ctl_shell("cat", "n10", True, False,
                                      StdOutMessage, b"ok", write_buf=b"ok\n")

    def test_channel_ctl_shell_mwrloc1(self):
        """test gateway channel write multi (remote=False)"""
        self._check_channel_ctl_shell("cat", "n[10-49]", True, False,
                                      StdOutMessage, b"ok", write_buf=b"ok\n")