Skip to content

Commit e9d1ff0

Browse files
authored
feat: support CustomRetriever with partition router (#753)
1 parent fd3a330 commit e9d1ff0

File tree

2 files changed

+235
-7
lines changed

2 files changed

+235
-7
lines changed

airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1976,7 +1976,10 @@ def create_default_stream(
19761976
primary_key = model.primary_key.__root__ if model.primary_key else None
19771977

19781978
partition_router = self._build_stream_slicer_from_partition_router(
1979-
model.retriever, config, stream_name=model.name
1979+
model.retriever,
1980+
config,
1981+
stream_name=model.name,
1982+
**kwargs,
19801983
)
19811984
concurrent_cursor = self._build_concurrent_cursor(model, partition_router, config)
19821985
if model.incremental_sync and isinstance(model.incremental_sync, DatetimeBasedCursorModel):
@@ -2155,10 +2158,11 @@ def _build_stream_slicer_from_partition_router(
21552158
],
21562159
config: Config,
21572160
stream_name: Optional[str] = None,
2161+
**kwargs: Any,
21582162
) -> PartitionRouter:
21592163
if (
21602164
hasattr(model, "partition_router")
2161-
and isinstance(model, SimpleRetrieverModel | AsyncRetrieverModel)
2165+
and isinstance(model, (SimpleRetrieverModel, AsyncRetrieverModel, CustomRetrieverModel))
21622166
and model.partition_router
21632167
):
21642168
stream_slicer_model = model.partition_router
@@ -2172,6 +2176,23 @@ def _build_stream_slicer_from_partition_router(
21722176
],
21732177
parameters={},
21742178
)
2179+
elif isinstance(stream_slicer_model, dict):
2180+
# partition router comes from CustomRetrieverModel therefore has not been parsed as a model
2181+
params = stream_slicer_model.get("$parameters")
2182+
if not isinstance(params, dict):
2183+
params = {}
2184+
stream_slicer_model["$parameters"] = params
2185+
2186+
if stream_name is not None:
2187+
params["stream_name"] = stream_name
2188+
2189+
return self._create_nested_component( # type: ignore[no-any-return] # There is no guarantee that this will return a stream slicer. If not, we expect an AttributeError during the call to `stream_slices`
2190+
model,
2191+
"partition_router",
2192+
stream_slicer_model,
2193+
config,
2194+
**kwargs,
2195+
)
21752196
else:
21762197
return self._create_component_from_model( # type: ignore[no-any-return] # Will be created PartitionRouter as stream_slicer_model is model.partition_router
21772198
model=stream_slicer_model, config=config, stream_name=stream_name or ""
@@ -2886,7 +2907,7 @@ def create_page_increment(
28862907
)
28872908

28882909
def create_parent_stream_config(
2889-
self, model: ParentStreamConfigModel, config: Config, stream_name: str, **kwargs: Any
2910+
self, model: ParentStreamConfigModel, config: Config, *, stream_name: str, **kwargs: Any
28902911
) -> ParentStreamConfig:
28912912
declarative_stream = self._create_component_from_model(
28922913
model.stream,
@@ -3693,14 +3714,19 @@ def create_spec(self, model: SpecModel, config: Config, **kwargs: Any) -> Spec:
36933714
)
36943715

36953716
def create_substream_partition_router(
3696-
self, model: SubstreamPartitionRouterModel, config: Config, **kwargs: Any
3717+
self,
3718+
model: SubstreamPartitionRouterModel,
3719+
config: Config,
3720+
*,
3721+
stream_name: str,
3722+
**kwargs: Any,
36973723
) -> SubstreamPartitionRouter:
36983724
parent_stream_configs = []
36993725
if model.parent_stream_configs:
37003726
parent_stream_configs.extend(
37013727
[
37023728
self.create_parent_stream_config_with_substream_wrapper(
3703-
model=parent_stream_config, config=config, **kwargs
3729+
model=parent_stream_config, config=config, stream_name=stream_name, **kwargs
37043730
)
37053731
for parent_stream_config in model.parent_stream_configs
37063732
]
@@ -3720,7 +3746,7 @@ def create_parent_stream_config_with_substream_wrapper(
37203746

37213747
# This flag will be used exclusively for StateDelegatingStream when a parent stream is created
37223748
has_parent_state = bool(
3723-
self._connector_state_manager.get_stream_state(kwargs.get("stream_name", ""), None)
3749+
self._connector_state_manager.get_stream_state(stream_name, None)
37243750
if model.incremental_dependency
37253751
else False
37263752
)
@@ -4113,11 +4139,17 @@ def set_api_budget(self, component_definition: ComponentDefinition, config: Conf
41134139
)
41144140

41154141
def create_grouping_partition_router(
4116-
self, model: GroupingPartitionRouterModel, config: Config, **kwargs: Any
4142+
self,
4143+
model: GroupingPartitionRouterModel,
4144+
config: Config,
4145+
*,
4146+
stream_name: str,
4147+
**kwargs: Any,
41174148
) -> GroupingPartitionRouter:
41184149
underlying_router = self._create_component_from_model(
41194150
model=model.underlying_partition_router,
41204151
config=config,
4152+
stream_name=stream_name,
41214153
**kwargs,
41224154
)
41234155
if model.group_size < 1:

unit_tests/sources/declarative/parsers/test_model_to_component_factory.py

Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -992,6 +992,202 @@ def test_stream_with_custom_retriever_and_transformations():
992992
assert get_retriever(stream).record_selector.transformations
993993

994994

995+
def test_stream_with_custom_retriever_and_partition_router():
996+
content = """
997+
a_stream:
998+
type: DeclarativeStream
999+
primary_key: id
1000+
schema_loader:
1001+
type: InlineSchemaLoader
1002+
schema:
1003+
$schema: "http://json-schema.org/draft-07/schema"
1004+
type: object
1005+
properties:
1006+
id:
1007+
type: string
1008+
retriever:
1009+
type: CustomRetriever
1010+
class_name: unit_tests.sources.declarative.parsers.testing_components.TestingCustomRetriever
1011+
record_selector:
1012+
type: RecordSelector
1013+
extractor:
1014+
field_path: []
1015+
requester:
1016+
type: HttpRequester
1017+
url_base: "https://api.sendgrid.com/v3/"
1018+
http_method: "GET"
1019+
partition_router:
1020+
type: SubstreamPartitionRouter
1021+
parent_stream_configs:
1022+
- parent_key: id
1023+
partition_field: id
1024+
stream:
1025+
type: DeclarativeStream
1026+
primary_key: id
1027+
schema_loader:
1028+
type: InlineSchemaLoader
1029+
schema:
1030+
$schema: "http://json-schema.org/draft-07/schema"
1031+
type: object
1032+
properties:
1033+
id:
1034+
type: string
1035+
retriever:
1036+
type: SimpleRetriever
1037+
requester:
1038+
type: HttpRequester
1039+
url_base: "https://api.sendgrid.com/v3/parent"
1040+
http_method: "GET"
1041+
record_selector:
1042+
type: RecordSelector
1043+
extractor:
1044+
field_path: []
1045+
$parameters:
1046+
name: a_stream
1047+
"""
1048+
1049+
parsed_manifest = YamlDeclarativeSource._parse(content)
1050+
resolved_manifest = resolver.preprocess_manifest(parsed_manifest)
1051+
stream_manifest = transformer.propagate_types_and_parameters(
1052+
"", resolved_manifest["a_stream"], {}
1053+
)
1054+
1055+
stream = factory.create_component(
1056+
model_type=DeclarativeStreamModel, component_definition=stream_manifest, config=input_config
1057+
)
1058+
1059+
assert isinstance(stream, DefaultStream)
1060+
assert isinstance(stream._stream_partition_generator._stream_slicer, SubstreamPartitionRouter)
1061+
1062+
1063+
def test_stream_with_custom_retriever_with_partition_router_field_that_is_not_a_partition_router():
1064+
"""
1065+
This test documents the behavior where if a custom retriever has a field named partition_router, it will assume
1066+
it can generate stream_slices with this parameter. In this test, the partition_router is a RecordSelector that can't
1067+
generate stream_slices so there will be an AttributeError.
1068+
"""
1069+
content = """
1070+
a_stream:
1071+
type: DeclarativeStream
1072+
primary_key: id
1073+
schema_loader:
1074+
type: InlineSchemaLoader
1075+
schema:
1076+
$schema: "http://json-schema.org/draft-07/schema"
1077+
type: object
1078+
properties:
1079+
id:
1080+
type: string
1081+
retriever:
1082+
type: CustomRetriever
1083+
class_name: unit_tests.sources.declarative.parsers.testing_components.TestingCustomRetriever
1084+
record_selector:
1085+
type: RecordSelector
1086+
extractor:
1087+
field_path: []
1088+
requester:
1089+
type: HttpRequester
1090+
url_base: "https://api.sendgrid.com/v3/"
1091+
http_method: "GET"
1092+
partition_router:
1093+
type: RecordSelector
1094+
extractor:
1095+
field_path: []
1096+
$parameters:
1097+
name: a_stream
1098+
"""
1099+
1100+
parsed_manifest = YamlDeclarativeSource._parse(content)
1101+
resolved_manifest = resolver.preprocess_manifest(parsed_manifest)
1102+
stream_manifest = transformer.propagate_types_and_parameters(
1103+
"", resolved_manifest["a_stream"], {}
1104+
)
1105+
1106+
stream = factory.create_component(
1107+
model_type=DeclarativeStreamModel, component_definition=stream_manifest, config=input_config
1108+
)
1109+
1110+
assert isinstance(stream, DefaultStream)
1111+
with pytest.raises(AttributeError) as e:
1112+
list(stream.generate_partitions())
1113+
assert e.value.args[0] == "'RecordSelector' object has no attribute 'stream_slices'"
1114+
1115+
1116+
def test_incremental_stream_with_custom_retriever_and_partition_router():
1117+
content = """
1118+
a_stream:
1119+
type: DeclarativeStream
1120+
primary_key: id
1121+
schema_loader:
1122+
type: InlineSchemaLoader
1123+
schema:
1124+
$schema: "http://json-schema.org/draft-07/schema"
1125+
type: object
1126+
properties:
1127+
id:
1128+
type: string
1129+
incremental_sync:
1130+
type: DatetimeBasedCursor
1131+
datetime_format: "%Y-%m-%dT%H:%M:%S.%f%z"
1132+
start_datetime: "{{ config['start_time'] }}"
1133+
cursor_field: "created"
1134+
retriever:
1135+
type: CustomRetriever
1136+
class_name: unit_tests.sources.declarative.parsers.testing_components.TestingCustomRetriever
1137+
record_selector:
1138+
type: RecordSelector
1139+
extractor:
1140+
field_path: []
1141+
requester:
1142+
type: HttpRequester
1143+
url_base: "https://api.sendgrid.com/v3/"
1144+
http_method: "GET"
1145+
partition_router:
1146+
type: SubstreamPartitionRouter
1147+
parent_stream_configs:
1148+
- parent_key: id
1149+
partition_field: id
1150+
stream:
1151+
type: DeclarativeStream
1152+
primary_key: id
1153+
schema_loader:
1154+
type: InlineSchemaLoader
1155+
schema:
1156+
$schema: "http://json-schema.org/draft-07/schema"
1157+
type: object
1158+
properties:
1159+
id:
1160+
type: string
1161+
retriever:
1162+
type: SimpleRetriever
1163+
requester:
1164+
type: HttpRequester
1165+
url_base: "https://api.sendgrid.com/v3/parent"
1166+
http_method: "GET"
1167+
record_selector:
1168+
type: RecordSelector
1169+
extractor:
1170+
field_path: []
1171+
$parameters:
1172+
name: a_stream
1173+
"""
1174+
1175+
parsed_manifest = YamlDeclarativeSource._parse(content)
1176+
resolved_manifest = resolver.preprocess_manifest(parsed_manifest)
1177+
stream_manifest = transformer.propagate_types_and_parameters(
1178+
"", resolved_manifest["a_stream"], {}
1179+
)
1180+
1181+
stream = factory.create_component(
1182+
model_type=DeclarativeStreamModel, component_definition=stream_manifest, config=input_config
1183+
)
1184+
1185+
assert isinstance(stream, DefaultStream)
1186+
assert isinstance(
1187+
stream._stream_partition_generator._stream_slicer, ConcurrentPerPartitionCursor
1188+
)
1189+
1190+
9951191
@pytest.mark.parametrize(
9961192
"use_legacy_state",
9971193
[

0 commit comments

Comments
 (0)