@@ -44,16 +44,21 @@ class AsyncPartition:
     This bucket of api_jobs is a bit useless for this iteration but should become interesting when we will be able to split jobs
     """
 
-    _MAX_NUMBER_OF_ATTEMPTS = 3
+    _DEFAULT_MAX_JOB_RETRY = 3
 
-    def __init__(self, jobs: List[AsyncJob], stream_slice: StreamSlice) -> None:
+    def __init__(
+        self, jobs: List[AsyncJob], stream_slice: StreamSlice, job_max_retry: Optional[int] = None
+    ) -> None:
         self._attempts_per_job = {job: 1 for job in jobs}
         self._stream_slice = stream_slice
+        self._job_max_retry = (
+            job_max_retry if job_max_retry is not None else self._DEFAULT_MAX_JOB_RETRY
+        )
 
     def has_reached_max_attempt(self) -> bool:
         return any(
             map(
-                lambda attempt_count: attempt_count >= self._MAX_NUMBER_OF_ATTEMPTS,
+                lambda attempt_count: attempt_count >= self._job_max_retry,
                 self._attempts_per_job.values(),
             )
         )
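As a quick illustration of the hunk above, here is a minimal sketch of how the new `job_max_retry` argument selects the retry cutoff for a partition. The import path and the `Mock` stand-ins for `AsyncJob`/`StreamSlice` are assumptions made for the sake of a runnable snippet, not part of this diff:

```python
from unittest.mock import Mock

# Assumed import path; adjust to wherever AsyncPartition lives in your CDK version.
from airbyte_cdk.sources.declarative.async_job.job_orchestrator import AsyncPartition

job, stream_slice = Mock(), Mock()  # hypothetical stand-ins for AsyncJob / StreamSlice

# Omitting job_max_retry (or passing None) falls back to _DEFAULT_MAX_JOB_RETRY == 3.
default_partition = AsyncPartition([job], stream_slice)

# An explicit value overrides the default for this partition only.
patient_partition = AsyncPartition([job], stream_slice, job_max_retry=10)

# Every job starts with an attempt count of 1, so neither partition is exhausted yet.
assert not default_partition.has_reached_max_attempt()
assert not patient_partition.has_reached_max_attempt()
```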
@@ -62,7 +67,7 @@ def replace_job(self, job_to_replace: AsyncJob, new_jobs: List[AsyncJob]) -> Non
         current_attempt_count = self._attempts_per_job.pop(job_to_replace, None)
         if current_attempt_count is None:
             raise ValueError("Could not find job to replace")
-        elif current_attempt_count >= self._MAX_NUMBER_OF_ATTEMPTS:
+        elif current_attempt_count >= self._job_max_retry:
             raise ValueError(f"Max attempt reached for job in partition {self._stream_slice}")
 
         new_attempt_count = current_attempt_count + 1
@@ -155,6 +160,7 @@ def __init__(
         message_repository: MessageRepository,
         exceptions_to_break_on: Iterable[Type[Exception]] = tuple(),
         has_bulk_parent: bool = False,
+        job_max_retry: Optional[int] = None,
     ) -> None:
         """
         If the stream slices provided as a parameters relies on a async job streams that relies on the same JobTracker, `has_bulk_parent`
@@ -175,6 +181,7 @@ def __init__(
         self._message_repository = message_repository
         self._exceptions_to_break_on: Tuple[Type[Exception], ...] = tuple(exceptions_to_break_on)
         self._has_bulk_parent = has_bulk_parent
+        self._job_max_retry = job_max_retry
 
         self._non_breaking_exceptions: List[Exception] = []
 
@@ -214,7 +221,7 @@ def _start_jobs(self) -> None:
             for _slice in self._slice_iterator:
                 at_least_one_slice_consumed_from_slice_iterator_during_current_iteration = True
                 job = self._start_job(_slice)
-                self._running_partitions.append(AsyncPartition([job], _slice))
+                self._running_partitions.append(AsyncPartition([job], _slice, self._job_max_retry))
                 if self._has_bulk_parent and self._slice_iterator.has_next():
                     break
         except ConcurrentJobLimitReached:
@@ -359,14 +366,11 @@ def _process_running_partitions_and_yield_completed_ones(
                     self._process_partitions_with_errors(partition)
                 case _:
                     self._stop_timed_out_jobs(partition)
+                    # re-allocate FAILED jobs, but TIMEOUT jobs are not re-allocated
+                    self._reallocate_partition(current_running_partitions, partition)
 
-                    # job will be restarted in `_start_job`
-                    current_running_partitions.insert(0, partition)
-
-            for job in partition.jobs:
-                # We only remove completed jobs as we want failed/timed out jobs to be re-allocated in priority
-                if job.status() == AsyncJobStatus.COMPLETED:
-                    self._job_tracker.remove_job(job.api_job_id())
+            # We only remove completed / timed out jobs as we want failed jobs to be re-allocated in priority
+            self._remove_completed_jobs(partition)
 
         # update the referenced list with running partitions
         self._running_partitions = current_running_partitions
@@ -381,7 +385,6 @@ def _stop_partition(self, partition: AsyncPartition) -> None:
     def _stop_timed_out_jobs(self, partition: AsyncPartition) -> None:
         for job in partition.jobs:
             if job.status() == AsyncJobStatus.TIMED_OUT:
-                # we don't free allocation here because it is expected to retry the job
                 self._abort_job(job, free_job_allocation=False)
 
     def _abort_job(self, job: AsyncJob, free_job_allocation: bool = True) -> None:
@@ -392,6 +395,31 @@ def _abort_job(self, job: AsyncJob, free_job_allocation: bool = True) -> None:
         except Exception as exception:
             LOGGER.warning(f"Could not free budget for job {job.api_job_id()}: {exception}")
 
+    def _remove_completed_jobs(self, partition: AsyncPartition) -> None:
+        """
+        Remove completed or timed out jobs from the partition.
+
+        Args:
+            partition (AsyncPartition): The partition to process.
+        """
+        for job in partition.jobs:
+            if job.status() == AsyncJobStatus.COMPLETED:
+                self._job_tracker.remove_job(job.api_job_id())
+
+    def _reallocate_partition(
+        self,
+        current_running_partitions: List[AsyncPartition],
+        partition: AsyncPartition,
+    ) -> None:
+        """
+        Reallocate the partition by starting a new job for each job in the
+        partition.
+        Args:
+            current_running_partitions (list): The list of currently running partitions.
+            partition (AsyncPartition): The partition to reallocate.
+        """
+        current_running_partitions.insert(0, partition)
+
     def _process_partitions_with_errors(self, partition: AsyncPartition) -> None:
         """
         Process a partition with status errors (FAILED and TIMEOUT).
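To make the two new helpers concrete, here is a toy, self-contained sketch of the behaviour they implement, using hypothetical `Tracker`/`Job` stand-ins rather than the CDK classes: a partition containing failed jobs goes back to the front of the running list so it is retried first, while only COMPLETED jobs release their slot in the job tracker.

```python
from dataclasses import dataclass, field
from typing import List, Set


@dataclass
class Tracker:                       # stand-in for the CDK's JobTracker
    active: Set[str] = field(default_factory=set)

    def remove_job(self, job_id: str) -> None:
        self.active.discard(job_id)  # frees one slot of concurrency budget


@dataclass
class Job:                           # stand-in for AsyncJob
    job_id: str
    status: str                      # "COMPLETED", "FAILED", ...


def reallocate_partition(running: List[List[Job]], partition: List[Job]) -> None:
    # mirrors _reallocate_partition: re-queued at the front so it is retried in priority
    running.insert(0, partition)


def remove_completed_jobs(tracker: Tracker, partition: List[Job]) -> None:
    # mirrors _remove_completed_jobs: only COMPLETED jobs release tracker budget;
    # failed jobs keep theirs because they are about to be restarted
    for job in partition:
        if job.status == "COMPLETED":
            tracker.remove_job(job.job_id)


tracker = Tracker(active={"a", "b"})
partition = [Job("a", "COMPLETED"), Job("b", "FAILED")]
running: List[List[Job]] = []

reallocate_partition(running, partition)
remove_completed_jobs(tracker, partition)
assert running == [partition] and tracker.active == {"b"}
```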