-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathscheduler.py
192 lines (141 loc) · 5.74 KB
/
scheduler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
from typing import Tuple
from celery import group
from abc import ABC, abstractmethod
from tracetrack.tasks import celery
import uuid
class TaskNotFoundError(Exception):
    """Raised when a task id cannot be resolved to a known task or task group."""
class Scheduler(ABC):
    """Abstract interface for submitting tasks and retrieving their results.

    Concrete implementations decide where tasks run (e.g. a Celery queue or
    synchronously in memory); callers only see task ids and results.
    """

    @abstractmethod
    def get_result(self, task_id, index=None, timeout=15):
        """Return the result of a task, or of the index-th (1-indexed) member of a group."""

    @abstractmethod
    def get_result_task_id(self, task_id, index):
        """Return the task id of the index-th (1-indexed) member of a group."""

    @abstractmethod
    def get_results(self, task_id, timeout=15):
        """Return the results of all members of a task group."""

    @abstractmethod
    def are_results_ready(self, task_id):
        """Return True when every task in the group identified by task_id has finished."""

    @abstractmethod
    def is_result_ready(self, task_id):
        """Return True when the single task identified by task_id has finished."""

    @abstractmethod
    def get_results_progress(self, task_id) -> Tuple[int, int, int]:
        """Return (running, completed, total) counts for a task group."""

    @abstractmethod
    def schedule_task(self, fun, *args, **kwargs):
        """Schedule a single task and return its task id."""

    @abstractmethod
    def schedule_tasks(self, fun, inputs):
        """Schedule one task per kwargs dict in inputs; return a group task id."""
class CeleryScheduler(Scheduler):
    """Scheduler backed by the Celery task queue.

    Tasks are submitted through Celery; results are fetched from the Celery
    result backend, either individually (AsyncResult) or as previously saved
    group results (GroupResult).
    """

    def get_celery_group_result(self, task_id):
        """Restore a saved Celery GroupResult, raising TaskNotFoundError if unknown."""
        restored = celery.GroupResult.restore(task_id)
        if restored is None:
            raise TaskNotFoundError(task_id)
        return restored

    def get_result(self, task_id, index=None, timeout=15):
        """Return a single task result.

        When index is given (1-indexed), task_id is treated as a group id and
        the result of the index-th member task is returned instead.
        """
        if index is None:
            async_result = celery.AsyncResult(task_id)
            # We update status of all created tasks to SENT, so no task should
            # be PENDING; a PENDING state therefore means a missing task.
            if async_result.state == 'PENDING':
                raise TaskNotFoundError(task_id)
            return async_result.get(timeout=timeout)
        position = int(index)
        assert position > 0, f'Index is 1-indexed, got {index}'
        return self.get_celery_group_result(task_id)[position - 1].get(timeout=timeout)

    def get_result_task_id(self, task_id, index):
        """Return the Celery task id of the index-th (1-indexed) member of a group."""
        member = self.get_celery_group_result(task_id)[int(index) - 1]
        return member.id

    def get_results(self, task_id, timeout=15):
        """Return results of all group members without raising on failed members."""
        return self.get_celery_group_result(task_id).join(timeout=timeout, propagate=False)

    def are_results_ready(self, task_id):
        """Return True when every member of the group has finished."""
        return self.get_celery_group_result(task_id).ready()

    def is_result_ready(self, task_id):
        """Return True when the single task has finished."""
        return celery.AsyncResult(task_id).ready()

    def get_results_progress(self, task_id) -> Tuple[int, int, int]:
        """Return (running, completed, total) counts for a group of tasks."""
        members = self.get_celery_group_result(task_id)
        running = sum(member.state == 'RUNNING' for member in members)
        return running, members.completed_count(), len(members)

    def schedule_task(self, fun, *args, **kwargs):
        """Submit a single Celery task and return its task id."""
        return fun.delay(*args, **kwargs).id

    def schedule_tasks(self, fun, inputs):
        """Submit one task per kwargs dict as a Celery group; save and return the group id."""
        result = group([fun.s(**kwargs) for kwargs in inputs]).apply_async()
        # save() persists the group so GroupResult.restore() can find it later.
        result.save()
        return result.id
class SimpleInMemoryScheduler(Scheduler):
    """Scheduler that runs tasks synchronously and keeps results in memory.

    For running TraceTrack locally without a task queue: each "scheduled"
    task is executed immediately and its result stored under a fresh id.
    A group is stored as a list of its member task ids.
    """

    # Class-level store: all instances share the same result dictionary.
    results = {}

    def save_result(self, result):
        """Store result under a new unique id and return that id."""
        while True:
            candidate = uuid.uuid4().hex
            if candidate not in self.results:
                break
        self.results[candidate] = result
        return candidate

    def get_result(self, task_id, index=None, timeout=15):
        """Return a stored result; with index (1-indexed), a member of a stored group."""
        if task_id not in self.results:
            raise TaskNotFoundError(task_id)
        if index is None:
            return self.results[task_id]
        position = int(index)
        assert position > 0, f'Index is 1-indexed, got {index}'
        # Group entries hold member task ids; resolve the member's own result.
        return self.get_result(self.results[task_id][position - 1])

    def get_result_task_id(self, task_id, index):
        """Return the member task id at 1-indexed index of group task_id."""
        members = self.get_result(task_id)
        return members[int(index) - 1]

    def get_results(self, task_id, timeout=15):
        """Return the results of every member of the stored group task_id."""
        return [self.get_result(member_id) for member_id in self.get_result(task_id)]

    def are_results_ready(self, task_id):
        """Results exist as soon as they are stored, so readiness == existence."""
        return task_id in self.results

    def is_result_ready(self, task_id):
        """Results exist as soon as they are stored, so readiness == existence."""
        return task_id in self.results

    def get_results_progress(self, task_id) -> Tuple[int, int, int]:
        # This should never be called because tasks are always ready
        raise NotImplementedError()

    def schedule_task(self, fun, *args, **kwargs):
        """Run fun immediately and return the id its result was stored under."""
        return self.save_result(fun(*args, **kwargs))

    def schedule_tasks(self, fun, inputs):
        """Run fun for each kwargs dict; store the member ids as the group entry."""
        member_ids = [self.save_result(fun(**kwargs)) for kwargs in inputs]
        return self.save_result(member_ids)
class NotInitializedScheduler(Scheduler):
    """Placeholder implementation: every operation fails until use_scheduler() runs."""

    def throw(self):
        """Fail loudly; the real scheduler has not been selected yet."""
        raise NotImplementedError('Scheduler not initialized')

    def get_result(self, task_id, index=None, timeout=15):
        self.throw()

    def get_result_task_id(self, task_id, index):
        self.throw()

    def get_results(self, task_id, timeout=15):
        self.throw()

    def are_results_ready(self, task_id):
        self.throw()

    def is_result_ready(self, task_id):
        self.throw()

    def get_results_progress(self, task_id):
        self.throw()

    def schedule_task(self, fun, *args, **kwargs):
        self.throw()

    def schedule_tasks(self, fun, inputs):
        self.throw()
class SchedulerProxy:
    """Late-binding indirection around the active scheduler.

    Lets the module-level `scheduler` be imported before use_scheduler()
    selects the real implementation; all attribute access is forwarded to
    `wrapped`, which starts as a NotInitializedScheduler.
    """

    wrapped = NotInitializedScheduler()

    def __getattr__(self, name):
        # Only invoked for attributes not found normally (i.e. everything
        # except `wrapped` itself), so all scheduler calls are delegated.
        return getattr(self.wrapped, name)
# Module-level scheduler handle: a proxy whose target is swapped by use_scheduler().
scheduler: Scheduler = SchedulerProxy()
def use_scheduler(name):
    """Select the scheduler implementation by name.

    name is passed depending on how TraceTrack is run: 'simple' for running
    locally, 'celery' for running on a server with a task queue.

    The module-level `scheduler` proxy is mutated in place (only its
    `wrapped` attribute is replaced), so references imported before this
    call keep working.

    Raises:
        ValueError: if name is not a supported scheduler name.
    """
    # No `global` declaration needed: the proxy object is mutated, never rebound.
    if name == 'celery':
        scheduler.wrapped = CeleryScheduler()
    elif name == 'simple':
        scheduler.wrapped = SimpleInMemoryScheduler()
    else:
        raise ValueError(f"Unsupported scheduler: {name}")