|
8 | 8 | import json
|
9 | 9 | import logging
|
10 | 10 | import os
|
| 11 | +import re |
11 | 12 | import tempfile
|
12 | 13 | import time
|
13 | 14 | from dataclasses import dataclass, field
|
14 | 15 | from datetime import datetime
|
15 | 16 | from shutil import copy2, rmtree
|
16 |
| -from typing import Any, cast, Dict, Iterable, List, Mapping, Optional, Set, Type # noqa |
| 17 | +from typing import Any, cast, Dict, Iterable, List, Optional, Tuple # noqa |
17 | 18 |
|
18 | 19 | from torchx.schedulers.api import (
|
19 | 20 | AppDryRunInfo,
|
@@ -322,13 +323,25 @@ def wait_until_finish(self, app_id: str, timeout: int = 30) -> None:
|
322 | 323 | break
|
323 | 324 | time.sleep(1)
|
324 | 325 |
|
325 |
| - def _cancel_existing(self, app_id: str) -> None: # pragma: no cover |
| 326 | + def _parse_app_id(self, app_id: str) -> Tuple[str, str]: |
| 327 | + # find index of '-' in the first :\d+- |
| 328 | + m = re.search(r":\d+-", app_id) |
| 329 | + if m: |
| 330 | + sep = m.span()[1] |
| 331 | + addr = app_id[: sep - 1] |
| 332 | + app_id = app_id[sep:] |
| 333 | + return addr, app_id |
| 334 | + |
326 | 335 | addr, _, app_id = app_id.partition("-")
|
| 336 | + return addr, app_id |
| 337 | + |
def _cancel_existing(self, app_id: str) -> None:  # pragma: no cover
    """Ask the job server encoded in ``app_id`` to stop the job.

    ``app_id`` is split by ``self._parse_app_id`` into an address and a job
    id; a ``JobSubmissionClient`` is created for ``http://<address>`` and
    its ``stop_job`` is called with the job id.
    """
    address, job_id = self._parse_app_id(app_id)
    JobSubmissionClient(f"http://{address}").stop_job(job_id)
|
329 | 342 |
|
330 | 343 | def _get_job_status(self, app_id: str) -> JobStatus:
|
331 |
| - addr, _, app_id = app_id.partition("-") |
| 344 | + addr, app_id = self._parse_app_id(app_id) |
332 | 345 | client = JobSubmissionClient(f"http://{addr}")
|
333 | 346 | status = client.get_job_status(app_id)
|
334 | 347 | if isinstance(status, str):
|
@@ -375,7 +388,7 @@ def log_iter(
|
375 | 388 | streams: Optional[Stream] = None,
|
376 | 389 | ) -> Iterable[str]:
|
377 | 390 | # TODO: support tailing, streams etc..
|
378 |
| - addr, _, app_id = app_id.partition("-") |
| 391 | + addr, app_id = self._parse_app_id(app_id) |
379 | 392 | client: JobSubmissionClient = JobSubmissionClient(f"http://{addr}")
|
380 | 393 | logs: str = client.get_job_logs(app_id)
|
381 | 394 | iterator = split_lines(logs)
|
|
0 commit comments