Skip to content

Commit

Permalink
feat: enable handling of nested fields when injecting request_option …
Browse files Browse the repository at this point in the history
…in request body_json (#201)
  • Loading branch information
ChristoGrab authored Feb 4, 2025
1 parent 979598c commit 126e233
Show file tree
Hide file tree
Showing 22 changed files with 707 additions and 182 deletions.
11 changes: 3 additions & 8 deletions airbyte_cdk/sources/declarative/auth/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import base64
import logging
from dataclasses import InitVar, dataclass
from typing import Any, Mapping, Union
from typing import Any, Mapping, MutableMapping, Union

import requests
from cachetools import TTLCache, cached
Expand Down Expand Up @@ -45,11 +45,6 @@ class ApiKeyAuthenticator(DeclarativeAuthenticator):
config: Config
parameters: InitVar[Mapping[str, Any]]

def __post_init__(self, parameters: Mapping[str, Any]) -> None:
self._field_name = InterpolatedString.create(
self.request_option.field_name, parameters=parameters
)

@property
def auth_header(self) -> str:
options = self._get_request_options(RequestOptionType.header)
Expand All @@ -60,9 +55,9 @@ def token(self) -> str:
return self.token_provider.get_token()

def _get_request_options(self, option_type: RequestOptionType) -> Mapping[str, Any]:
options = {}
options: MutableMapping[str, Any] = {}
if self.request_option.inject_into == option_type:
options[self._field_name.eval(self.config)] = self.token
self.request_option.inject_into_request(options, self.token, self.config)
return options

def get_request_params(self) -> Mapping[str, Any]:
Expand Down
18 changes: 14 additions & 4 deletions airbyte_cdk/sources/declarative/declarative_component_schema.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2847,25 +2847,35 @@ definitions:
enum: [RequestPath]
RequestOption:
title: Request Option
description: Specifies the key field and where in the request a component's value should be injected.
description: Specifies the key field or path and where in the request a component's value should be injected.
type: object
required:
- type
- field_name
- inject_into
properties:
type:
type: string
enum: [RequestOption]
field_name:
title: Request Option
description: Configures which key should be used in the location that the descriptor is being injected into
title: Field Name
description: Configures which key should be used in the location that the descriptor is being injected into. We hope to eventually deprecate this field in favor of `field_path` for all request_options, but must currently maintain it for backwards compatibility in the Builder.
type: string
examples:
- segment_id
interpolation_context:
- config
- parameters
field_path:
title: Field Path
description: Configures a path to be used for nested structures in JSON body requests (e.g. GraphQL queries)
type: array
items:
type: string
examples:
- ["data", "viewer", "id"]
interpolation_context:
- config
- parameters
inject_into:
title: Inject Into
description: Configures where the descriptor should be set on the HTTP requests. Note that request parameters that are already encoded in the URL path will not be duplicated.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -365,14 +365,15 @@ def _get_request_options(
options: MutableMapping[str, Any] = {}
if not stream_slice:
return options

if self.start_time_option and self.start_time_option.inject_into == option_type:
options[self.start_time_option.field_name.eval(config=self.config)] = stream_slice.get( # type: ignore # field_name is always casted to an interpolated string
self._partition_field_start.eval(self.config)
)
start_time_value = stream_slice.get(self._partition_field_start.eval(self.config))
self.start_time_option.inject_into_request(options, start_time_value, self.config)

if self.end_time_option and self.end_time_option.inject_into == option_type:
options[self.end_time_option.field_name.eval(config=self.config)] = stream_slice.get( # type: ignore [union-attr]
self._partition_field_end.eval(self.config)
)
end_time_value = stream_slice.get(self._partition_field_end.eval(self.config))
self.end_time_option.inject_into_request(options, end_time_value, self.config)

return options

def should_be_synced(self, record: Record) -> bool:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1200,11 +1200,17 @@ class InjectInto(Enum):

class RequestOption(BaseModel):
type: Literal["RequestOption"]
field_name: str = Field(
...,
description="Configures which key should be used in the location that the descriptor is being injected into",
field_name: Optional[str] = Field(
None,
description="Configures which key should be used in the location that the descriptor is being injected into. We hope to eventually deprecate this field in favor of `field_path` for all request_options, but must currently maintain it for backwards compatibility in the Builder.",
examples=["segment_id"],
title="Request Option",
title="Field Name",
)
field_path: Optional[List[str]] = Field(
None,
description="Configures a path to be used for nested structures in JSON body requests (e.g. GraphQL queries)",
examples=[["data", "viewer", "id"]],
title="Field Path",
)
inject_into: InjectInto = Field(
...,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -733,8 +733,8 @@ def _json_schema_type_name_to_type(value_type: Optional[ValueType]) -> Optional[
}
return names_to_types[value_type]

@staticmethod
def create_api_key_authenticator(
self,
model: ApiKeyAuthenticatorModel,
config: Config,
token_provider: Optional[TokenProvider] = None,
Expand All @@ -756,10 +756,8 @@ def create_api_key_authenticator(
)

request_option = (
RequestOption(
inject_into=RequestOptionType(model.inject_into.inject_into.value),
field_name=model.inject_into.field_name,
parameters=model.parameters or {},
self._create_component_from_model(
model.inject_into, config, parameters=model.parameters or {}
)
if model.inject_into
else RequestOption(
Expand All @@ -768,6 +766,7 @@ def create_api_key_authenticator(
parameters=model.parameters or {},
)
)

return ApiKeyAuthenticator(
token_provider=(
token_provider
Expand Down Expand Up @@ -849,7 +848,7 @@ def create_session_token_authenticator(
token_provider=token_provider,
)
else:
return ModelToComponentFactory.create_api_key_authenticator(
return self.create_api_key_authenticator(
ApiKeyAuthenticatorModel(
type="ApiKeyAuthenticator",
api_token="",
Expand Down Expand Up @@ -1489,19 +1488,15 @@ def create_datetime_based_cursor(
)

end_time_option = (
RequestOption(
inject_into=RequestOptionType(model.end_time_option.inject_into.value),
field_name=model.end_time_option.field_name,
parameters=model.parameters or {},
self._create_component_from_model(
model.end_time_option, config, parameters=model.parameters or {}
)
if model.end_time_option
else None
)
start_time_option = (
RequestOption(
inject_into=RequestOptionType(model.start_time_option.inject_into.value),
field_name=model.start_time_option.field_name,
parameters=model.parameters or {},
self._create_component_from_model(
model.start_time_option, config, parameters=model.parameters or {}
)
if model.start_time_option
else None
Expand Down Expand Up @@ -1572,19 +1567,15 @@ def create_declarative_stream(
cursor_model = model.incremental_sync

end_time_option = (
RequestOption(
inject_into=RequestOptionType(cursor_model.end_time_option.inject_into.value),
field_name=cursor_model.end_time_option.field_name,
parameters=cursor_model.parameters or {},
self._create_component_from_model(
cursor_model.end_time_option, config, parameters=cursor_model.parameters or {}
)
if cursor_model.end_time_option
else None
)
start_time_option = (
RequestOption(
inject_into=RequestOptionType(cursor_model.start_time_option.inject_into.value),
field_name=cursor_model.start_time_option.field_name,
parameters=cursor_model.parameters or {},
self._create_component_from_model(
cursor_model.start_time_option, config, parameters=cursor_model.parameters or {}
)
if cursor_model.start_time_option
else None
Expand Down Expand Up @@ -2150,16 +2141,11 @@ def create_jwt_authenticator(
additional_jwt_payload=model.additional_jwt_payload,
)

@staticmethod
def create_list_partition_router(
model: ListPartitionRouterModel, config: Config, **kwargs: Any
self, model: ListPartitionRouterModel, config: Config, **kwargs: Any
) -> ListPartitionRouter:
request_option = (
RequestOption(
inject_into=RequestOptionType(model.request_option.inject_into.value),
field_name=model.request_option.field_name,
parameters=model.parameters or {},
)
self._create_component_from_model(model.request_option, config)
if model.request_option
else None
)
Expand Down Expand Up @@ -2355,7 +2341,25 @@ def create_request_option(
model: RequestOptionModel, config: Config, **kwargs: Any
) -> RequestOption:
inject_into = RequestOptionType(model.inject_into.value)
return RequestOption(field_name=model.field_name, inject_into=inject_into, parameters={})
field_path: Optional[List[Union[InterpolatedString, str]]] = (
[
InterpolatedString.create(segment, parameters=kwargs.get("parameters", {}))
for segment in model.field_path
]
if model.field_path
else None
)
field_name = (
InterpolatedString.create(model.field_name, parameters=kwargs.get("parameters", {}))
if model.field_name
else None
)
return RequestOption(
field_name=field_name,
field_path=field_path,
inject_into=inject_into,
parameters=kwargs.get("parameters", {}),
)

def create_record_selector(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#

from dataclasses import InitVar, dataclass
from typing import Any, Iterable, List, Mapping, Optional, Union
from typing import Any, Iterable, List, Mapping, MutableMapping, Optional, Union

from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString
from airbyte_cdk.sources.declarative.partition_routers.partition_router import PartitionRouter
Expand Down Expand Up @@ -100,7 +100,9 @@ def _get_request_option(
):
slice_value = stream_slice.get(self._cursor_field.eval(self.config))
if slice_value:
return {self.request_option.field_name.eval(self.config): slice_value} # type: ignore # field_name is always casted to InterpolatedString
options: MutableMapping[str, Any] = {}
self.request_option.inject_into_request(options, slice_value, self.config)
return options
else:
return {}
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import copy
import logging
from dataclasses import InitVar, dataclass
from typing import TYPE_CHECKING, Any, Iterable, List, Mapping, Optional, Union
from typing import TYPE_CHECKING, Any, Iterable, List, Mapping, MutableMapping, Optional, Union

import dpath

Expand Down Expand Up @@ -118,7 +118,7 @@ def get_request_body_json(
def _get_request_option(
self, option_type: RequestOptionType, stream_slice: Optional[StreamSlice]
) -> Mapping[str, Any]:
params = {}
params: MutableMapping[str, Any] = {}
if stream_slice:
for parent_config in self.parent_stream_configs:
if (
Expand All @@ -128,13 +128,7 @@ def _get_request_option(
key = parent_config.partition_field.eval(self.config) # type: ignore # partition_field is always casted to an interpolated string
value = stream_slice.get(key)
if value:
params.update(
{
parent_config.request_option.field_name.eval( # type: ignore [union-attr]
config=self.config
): value
}
)
parent_config.request_option.inject_into_request(params, value, self.config)
return params

def stream_slices(self) -> Iterable[StreamSlice]:
Expand Down
6 changes: 5 additions & 1 deletion airbyte_cdk/sources/declarative/requesters/http_requester.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,9 @@ def _get_request_options(
Raise a ValueError if there's a key collision
Returned merged mapping otherwise
"""

is_body_json = requester_method.__name__ == "get_request_body_json"

return combine_mappings(
[
requester_method(
Expand All @@ -208,7 +211,8 @@ def _get_request_options(
),
auth_options_method(),
extra_options,
]
],
allow_same_value_merge=is_body_json,
)

def _request_headers(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def get_request_body_json(
def _get_request_options(
self, option_type: RequestOptionType, next_page_token: Optional[Mapping[str, Any]]
) -> MutableMapping[str, Any]:
options = {}
options: MutableMapping[str, Any] = {}

token = next_page_token.get("next_page_token") if next_page_token else None
if (
Expand All @@ -196,15 +196,16 @@ def _get_request_options(
and isinstance(self.page_token_option, RequestOption)
and self.page_token_option.inject_into == option_type
):
options[self.page_token_option.field_name.eval(config=self.config)] = token # type: ignore # field_name is always cast to an interpolated string
self.page_token_option.inject_into_request(options, token, self.config)

if (
self.page_size_option
and self.pagination_strategy.get_page_size()
and self.page_size_option.inject_into == option_type
):
options[self.page_size_option.field_name.eval(config=self.config)] = ( # type: ignore [union-attr]
self.pagination_strategy.get_page_size()
) # type: ignore # field_name is always cast to an interpolated string
page_size = self.pagination_strategy.get_page_size()
self.page_size_option.inject_into_request(options, page_size, self.config)

return options


Expand Down
Loading

0 comments on commit 126e233

Please sign in to comment.