Skip to content

Commit afaf6a9

Browse files
authored
Merge branch 'main' into lazebnyi/fix-type-and-parameters-resolve-for-dynamic-streams
2 parents 70e6665 + 837913f commit afaf6a9

33 files changed

+2063
-100
lines changed

airbyte_cdk/sources/declarative/async_job/job.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,12 @@ def api_job_id(self) -> str:
3434

3535
def status(self) -> AsyncJobStatus:
3636
if self._timer.has_timed_out():
37+
# TODO: we should account for the fact that
38+
# certain APIs could send the `Timeout` status,
39+
# thus we should not return `Timeout` in that case,
40+
# but act based on the scenario.
41+
42+
# the default behavior is to return `Timeout` status and retry.
3743
return AsyncJobStatus.TIMED_OUT
3844
return self._status
3945

airbyte_cdk/sources/declarative/async_job/job_orchestrator.py

Lines changed: 41 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -44,16 +44,21 @@ class AsyncPartition:
4444
This bucket of api_jobs is not very useful for this iteration but should become interesting once we are able to split jobs
4545
"""
4646

47-
_MAX_NUMBER_OF_ATTEMPTS = 3
47+
_DEFAULT_MAX_JOB_RETRY = 3
4848

49-
def __init__(self, jobs: List[AsyncJob], stream_slice: StreamSlice) -> None:
49+
def __init__(
50+
self, jobs: List[AsyncJob], stream_slice: StreamSlice, job_max_retry: Optional[int] = None
51+
) -> None:
5052
self._attempts_per_job = {job: 1 for job in jobs}
5153
self._stream_slice = stream_slice
54+
self._job_max_retry = (
55+
job_max_retry if job_max_retry is not None else self._DEFAULT_MAX_JOB_RETRY
56+
)
5257

5358
def has_reached_max_attempt(self) -> bool:
5459
return any(
5560
map(
56-
lambda attempt_count: attempt_count >= self._MAX_NUMBER_OF_ATTEMPTS,
61+
lambda attempt_count: attempt_count >= self._job_max_retry,
5762
self._attempts_per_job.values(),
5863
)
5964
)
@@ -62,7 +67,7 @@ def replace_job(self, job_to_replace: AsyncJob, new_jobs: List[AsyncJob]) -> Non
6267
current_attempt_count = self._attempts_per_job.pop(job_to_replace, None)
6368
if current_attempt_count is None:
6469
raise ValueError("Could not find job to replace")
65-
elif current_attempt_count >= self._MAX_NUMBER_OF_ATTEMPTS:
70+
elif current_attempt_count >= self._job_max_retry:
6671
raise ValueError(f"Max attempt reached for job in partition {self._stream_slice}")
6772

6873
new_attempt_count = current_attempt_count + 1
@@ -155,6 +160,7 @@ def __init__(
155160
message_repository: MessageRepository,
156161
exceptions_to_break_on: Iterable[Type[Exception]] = tuple(),
157162
has_bulk_parent: bool = False,
163+
job_max_retry: Optional[int] = None,
158164
) -> None:
159165
"""
160166
If the stream slices provided as a parameter rely on an async job stream that relies on the same JobTracker, `has_bulk_parent`
@@ -175,6 +181,7 @@ def __init__(
175181
self._message_repository = message_repository
176182
self._exceptions_to_break_on: Tuple[Type[Exception], ...] = tuple(exceptions_to_break_on)
177183
self._has_bulk_parent = has_bulk_parent
184+
self._job_max_retry = job_max_retry
178185

179186
self._non_breaking_exceptions: List[Exception] = []
180187

@@ -214,7 +221,7 @@ def _start_jobs(self) -> None:
214221
for _slice in self._slice_iterator:
215222
at_least_one_slice_consumed_from_slice_iterator_during_current_iteration = True
216223
job = self._start_job(_slice)
217-
self._running_partitions.append(AsyncPartition([job], _slice))
224+
self._running_partitions.append(AsyncPartition([job], _slice, self._job_max_retry))
218225
if self._has_bulk_parent and self._slice_iterator.has_next():
219226
break
220227
except ConcurrentJobLimitReached:
@@ -359,14 +366,11 @@ def _process_running_partitions_and_yield_completed_ones(
359366
self._process_partitions_with_errors(partition)
360367
case _:
361368
self._stop_timed_out_jobs(partition)
369+
# re-allocate FAILED jobs, but TIMEOUT jobs are not re-allocated
370+
self._reallocate_partition(current_running_partitions, partition)
362371

363-
# job will be restarted in `_start_job`
364-
current_running_partitions.insert(0, partition)
365-
366-
for job in partition.jobs:
367-
# We only remove completed jobs as we want failed/timed out jobs to be re-allocated in priority
368-
if job.status() == AsyncJobStatus.COMPLETED:
369-
self._job_tracker.remove_job(job.api_job_id())
372+
# We only remove completed / timed out jobs as we want failed jobs to be re-allocated in priority
373+
self._remove_completed_jobs(partition)
370374

371375
# update the referenced list with running partitions
372376
self._running_partitions = current_running_partitions
@@ -381,7 +385,6 @@ def _stop_partition(self, partition: AsyncPartition) -> None:
381385
def _stop_timed_out_jobs(self, partition: AsyncPartition) -> None:
382386
for job in partition.jobs:
383387
if job.status() == AsyncJobStatus.TIMED_OUT:
384-
# we don't free allocation here because it is expected to retry the job
385388
self._abort_job(job, free_job_allocation=False)
386389

387390
def _abort_job(self, job: AsyncJob, free_job_allocation: bool = True) -> None:
@@ -392,6 +395,31 @@ def _abort_job(self, job: AsyncJob, free_job_allocation: bool = True) -> None:
392395
except Exception as exception:
393396
LOGGER.warning(f"Could not free budget for job {job.api_job_id()}: {exception}")
394397

398+
def _remove_completed_jobs(self, partition: AsyncPartition) -> None:
399+
"""
400+
Remove the partition's completed jobs from the job tracker (only jobs whose status is COMPLETED are removed).
401+
402+
Args:
403+
partition (AsyncPartition): The partition to process.
404+
"""
405+
for job in partition.jobs:
406+
if job.status() == AsyncJobStatus.COMPLETED:
407+
self._job_tracker.remove_job(job.api_job_id())
408+
409+
def _reallocate_partition(
410+
self,
411+
current_running_partitions: List[AsyncPartition],
412+
partition: AsyncPartition,
413+
) -> None:
414+
"""
415+
Re-queue the partition at the front of the running partitions
416+
so that it is handled first (no new job is started here).
417+
Args:
418+
current_running_partitions (list): The list of currently running partitions.
419+
partition (AsyncPartition): The partition to reallocate.
420+
"""
421+
current_running_partitions.insert(0, partition)
422+
395423
def _process_partitions_with_errors(self, partition: AsyncPartition) -> None:
396424
"""
397425
Process a partition with status errors (FAILED and TIMEOUT).

airbyte_cdk/sources/declarative/async_job/job_tracker.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33
import logging
44
import threading
55
import uuid
6-
from typing import Set
6+
from dataclasses import dataclass, field
7+
from typing import Any, Mapping, Set, Union
78

89
from airbyte_cdk.logger import lazy_log
10+
from airbyte_cdk.sources.declarative.interpolation import InterpolatedString
911

1012
LOGGER = logging.getLogger("airbyte")
1113

@@ -14,15 +16,29 @@ class ConcurrentJobLimitReached(Exception):
1416
pass
1517

1618

19+
@dataclass
1720
class JobTracker:
18-
def __init__(self, limit: int):
21+
limit: Union[int, str]
22+
config: Mapping[str, Any] = field(default_factory=dict)
23+
24+
def __post_init__(self) -> None:
1925
self._jobs: Set[str] = set()
20-
if limit < 1:
26+
self._lock = threading.Lock()
27+
if isinstance(self.limit, str):
28+
try:
29+
self.limit = int(
30+
InterpolatedString(self.limit, parameters={}).eval(config=self.config)
31+
)
32+
except Exception as e:
33+
LOGGER.warning(
34+
f"Error interpolating max job count: {self.limit}. Setting to 1. {e}"
35+
)
36+
self.limit = 1
37+
if self.limit < 1:
2138
LOGGER.warning(
22-
f"The `max_concurrent_async_job_count` property is less than 1: {limit}. Setting to 1. Please update the source manifest to set a valid value."
39+
f"The `max_concurrent_async_job_count` property is less than 1: {self.limit}. Setting to 1. Please update the source manifest to set a valid value."
2340
)
24-
self._limit = 1 if limit < 1 else limit
25-
self._lock = threading.Lock()
41+
self._limit = self.limit if self.limit >= 1 else 1
2642

2743
def try_to_get_intent(self) -> str:
2844
lazy_log(

airbyte_cdk/sources/declarative/concurrent_declarative_source.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,11 @@ def _group_streams(
206206
# these legacy Python streams the way we do low-code streams to determine if they are concurrent compatible,
207207
# so we need to treat them as synchronous
208208

209-
if name_to_stream_mapping[declarative_stream.name]["type"] == "StateDelegatingStream":
209+
if (
210+
isinstance(declarative_stream, DeclarativeStream)
211+
and name_to_stream_mapping[declarative_stream.name]["type"]
212+
== "StateDelegatingStream"
213+
):
210214
stream_state = self._connector_state_manager.get_stream_state(
211215
stream_name=declarative_stream.name, namespace=declarative_stream.namespace
212216
)

0 commit comments

Comments
 (0)