diff --git a/airbyte_cdk/sources/declarative/declarative_component_schema.yaml b/airbyte_cdk/sources/declarative/declarative_component_schema.yaml index 642f7b5bb..55a6e02d9 100644 --- a/airbyte_cdk/sources/declarative/declarative_component_schema.yaml +++ b/airbyte_cdk/sources/declarative/declarative_component_schema.yaml @@ -40,6 +40,8 @@ properties: "$ref": "#/definitions/Spec" concurrency_level: "$ref": "#/definitions/ConcurrencyLevel" + api_budget: + "$ref": "#/definitions/HTTPAPIBudget" metadata: type: object description: For internal Airbyte use only - DO NOT modify manually. Used by consumers of declarative manifests for storing related metadata. @@ -794,7 +796,7 @@ definitions: description: This option is used to adjust the upper and lower boundaries of each datetime window to beginning and end of the provided target period (day, week, month) type: object required: - - target + - target properties: target: title: Target @@ -1365,6 +1367,170 @@ definitions: $parameters: type: object additional_properties: true + HTTPAPIBudget: + title: HTTP API Budget + description: > + Defines how many requests can be made to the API in a given time frame. `HTTPAPIBudget` extracts the remaining + call count and the reset time from HTTP response headers using the header names provided by + `ratelimit_remaining_header` and `ratelimit_reset_header`. Only requests using `HttpRequester` + are rate-limited; custom components that bypass `HttpRequester` are not covered by this budget. + type: object + required: + - type + - policies + properties: + type: + type: string + enum: [HTTPAPIBudget] + policies: + title: Policies + description: List of call rate policies that define how many calls are allowed. + type: array + items: + anyOf: + - "$ref": "#/definitions/FixedWindowCallRatePolicy" + - "$ref": "#/definitions/MovingWindowCallRatePolicy" + - "$ref": "#/definitions/UnlimitedCallRatePolicy" + ratelimit_reset_header: + title: Rate Limit Reset Header + description: The HTTP response header name that indicates when the rate limit resets. + type: string + default: "ratelimit-reset" + ratelimit_remaining_header: + title: Rate Limit Remaining Header + description: The HTTP response header name that indicates the number of remaining allowed calls. + type: string + default: "ratelimit-remaining" + status_codes_for_ratelimit_hit: + title: Status Codes for Rate Limit Hit + description: List of HTTP status codes that indicate a rate limit has been hit. + type: array + items: + type: integer + default: [429] + additionalProperties: true + FixedWindowCallRatePolicy: + title: Fixed Window Call Rate Policy + description: A policy that allows a fixed number of calls within a specific time window. + type: object + required: + - type + - period + - call_limit + - matchers + properties: + type: + type: string + enum: [FixedWindowCallRatePolicy] + period: + title: Period + description: The time interval for the rate limit window. + type: string + call_limit: + title: Call Limit + description: The maximum number of calls allowed within the period. + type: integer + matchers: + title: Matchers + description: List of matchers that define which requests this policy applies to. + type: array + items: + "$ref": "#/definitions/HttpRequestRegexMatcher" + additionalProperties: true + MovingWindowCallRatePolicy: + title: Moving Window Call Rate Policy + description: A policy that allows a fixed number of calls within a moving time window. + type: object + required: + - type + - rates + - matchers + properties: + type: + type: string + enum: [MovingWindowCallRatePolicy] + rates: + title: Rates + description: List of rates that define the call limits for different time intervals. + type: array + items: + "$ref": "#/definitions/Rate" + matchers: + title: Matchers + description: List of matchers that define which requests this policy applies to. + type: array + items: + "$ref": "#/definitions/HttpRequestRegexMatcher" + additionalProperties: true + UnlimitedCallRatePolicy: + title: Unlimited Call Rate Policy + description: A policy that allows unlimited calls for specific requests. + type: object + required: + - type + - matchers + properties: + type: + type: string + enum: [UnlimitedCallRatePolicy] + matchers: + title: Matchers + description: List of matchers that define which requests this policy applies to. + type: array + items: + "$ref": "#/definitions/HttpRequestRegexMatcher" + additionalProperties: true + Rate: + title: Rate + description: Defines a rate limit with a specific number of calls allowed within a time interval. + type: object + required: + - limit + - interval + properties: + limit: + title: Limit + description: The maximum number of calls allowed within the interval. + type: integer + interval: + title: Interval + description: The time interval for the rate limit. + type: string + examples: + - "PT1H" + - "P1D" + additionalProperties: true + HttpRequestRegexMatcher: + title: HTTP Request Matcher + description: > + Matches HTTP requests based on method, base URL, URL path pattern, query parameters, and headers. + Use `url_base` to specify the scheme and host (without trailing slash) and + `url_path_pattern` to apply a regex to the request path. + type: object + properties: + method: + title: Method + description: The HTTP method to match (e.g., GET, POST). + type: string + url_base: + title: URL Base + description: The base URL (scheme and host, e.g. "https://api.example.com") to match. + type: string + url_path_pattern: + title: URL Path Pattern + description: A regular expression pattern to match the URL path. + type: string + params: + title: Parameters + description: The query parameters to match. + type: object + additionalProperties: true + headers: + title: Headers + description: The headers to match. + type: object + additionalProperties: true + additionalProperties: true DefaultErrorHandler: title: Default Error Handler description: Component defining how to handle errors. Default behavior includes only retrying server errors (HTTP 5XX) and too many requests (HTTP 429) with an exponential backoff. diff --git a/airbyte_cdk/sources/declarative/manifest_declarative_source.py b/airbyte_cdk/sources/declarative/manifest_declarative_source.py index efc779464..d3afb1396 100644 --- a/airbyte_cdk/sources/declarative/manifest_declarative_source.py +++ b/airbyte_cdk/sources/declarative/manifest_declarative_source.py @@ -137,6 +137,10 @@ def streams(self, config: Mapping[str, Any]) -> List[Stream]: self._source_config, config ) + api_budget_model = self._source_config.get("api_budget") + if api_budget_model: + self._constructor.set_api_budget(api_budget_model, config) + source_streams = [ self._constructor.create_component( DeclarativeStreamModel, diff --git a/airbyte_cdk/sources/declarative/models/declarative_component_schema.py b/airbyte_cdk/sources/declarative/models/declarative_component_schema.py index d2454ee78..8f5be6867 100644 --- a/airbyte_cdk/sources/declarative/models/declarative_component_schema.py +++ b/airbyte_cdk/sources/declarative/models/declarative_component_schema.py @@ -642,6 +642,48 @@ class OAuthAuthenticator(BaseModel): parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") +class Rate(BaseModel): + class Config: + extra = Extra.allow + + limit: int = Field( + ..., + description="The maximum number of calls allowed within the interval.", + title="Limit", + ) + interval: str = Field( + ..., + description="The time interval for the rate limit.", + examples=["PT1H", "P1D"], + title="Interval", + ) + + +class HttpRequestRegexMatcher(BaseModel): + class Config: + extra = Extra.allow + + method: Optional[str] = Field( + None, description="The HTTP method to match (e.g., GET, POST).", title="Method" + ) + url_base: Optional[str] = Field( + None, + description='The base URL (scheme and host, e.g. "https://api.example.com") to match.', + title="URL Base", + ) + url_path_pattern: Optional[str] = Field( + None, + description="A regular expression pattern to match the URL path.", + title="URL Path Pattern", + ) + params: Optional[Dict[str, Any]] = Field( + None, description="The query parameters to match.", title="Parameters" + ) + headers: Optional[Dict[str, Any]] = Field( + None, description="The headers to match.", title="Headers" + ) + + class DpathExtractor(BaseModel): type: Literal["DpathExtractor"] field_path: List[str] = Field( @@ -1565,6 +1607,55 @@ class DatetimeBasedCursor(BaseModel): parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") +class FixedWindowCallRatePolicy(BaseModel): + class Config: + extra = Extra.allow + + type: Literal["FixedWindowCallRatePolicy"] + period: str = Field( + ..., description="The time interval for the rate limit window.", title="Period" + ) + call_limit: int = Field( + ..., + description="The maximum number of calls allowed within the period.", + title="Call Limit", + ) + matchers: List[HttpRequestRegexMatcher] = Field( + ..., + description="List of matchers that define which requests this policy applies to.", + title="Matchers", + ) + + +class MovingWindowCallRatePolicy(BaseModel): + class Config: + extra = Extra.allow + + type: Literal["MovingWindowCallRatePolicy"] + rates: List[Rate] = Field( + ..., + description="List of rates that define the call limits for different time intervals.", + title="Rates", + ) + matchers: List[HttpRequestRegexMatcher] = Field( + ..., + description="List of matchers that define which requests this policy applies to.", + title="Matchers", + ) + + +class UnlimitedCallRatePolicy(BaseModel): + class Config: + extra = Extra.allow + + type: Literal["UnlimitedCallRatePolicy"] + matchers: List[HttpRequestRegexMatcher] = Field( + ..., + description="List of matchers that define which requests this policy applies to.", + title="Matchers", + ) + + class DefaultErrorHandler(BaseModel): type: Literal["DefaultErrorHandler"] backoff_strategies: Optional[ @@ -1696,6 +1787,39 @@ class CompositeErrorHandler(BaseModel): parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") +class HTTPAPIBudget(BaseModel): + class Config: + extra = Extra.allow + + type: Literal["HTTPAPIBudget"] + policies: List[ + Union[ + FixedWindowCallRatePolicy, + MovingWindowCallRatePolicy, + UnlimitedCallRatePolicy, + ] + ] = Field( + ..., + description="List of call rate policies that define how many calls are allowed.", + title="Policies", + ) + ratelimit_reset_header: Optional[str] = Field( + "ratelimit-reset", + description="The HTTP response header name that indicates when the rate limit resets.", + title="Rate Limit Reset Header", + ) + ratelimit_remaining_header: Optional[str] = Field( + "ratelimit-remaining", + description="The HTTP response header name that indicates the number of remaining allowed calls.", + title="Rate Limit Remaining Header", + ) + status_codes_for_ratelimit_hit: Optional[List[int]] = Field( + [429], + description="List of HTTP status codes that indicate a rate limit has been hit.", + title="Status Codes for Rate Limit Hit", + ) + + class ZipfileDecoder(BaseModel): class Config: extra = Extra.allow @@ -1724,6 +1848,7 @@ class Config: definitions: Optional[Dict[str, Any]] = None spec: Optional[Spec] = None concurrency_level: Optional[ConcurrencyLevel] = None + api_budget: Optional[HTTPAPIBudget] = None metadata: Optional[Dict[str, Any]] = Field( None, description="For internal Airbyte use only - DO NOT modify manually. Used by consumers of declarative manifests for storing related metadata.", @@ -1750,6 +1875,7 @@ class Config: definitions: Optional[Dict[str, Any]] = None spec: Optional[Spec] = None concurrency_level: Optional[ConcurrencyLevel] = None + api_budget: Optional[HTTPAPIBudget] = None metadata: Optional[Dict[str, Any]] = Field( None, description="For internal Airbyte use only - DO NOT modify manually. Used by consumers of declarative manifests for storing related metadata.", diff --git a/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py b/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py index b690ce705..12464b40a 100644 --- a/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py +++ b/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py @@ -221,18 +221,27 @@ from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( ExponentialBackoffStrategy as ExponentialBackoffStrategyModel, ) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + FixedWindowCallRatePolicy as FixedWindowCallRatePolicyModel, +) from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( FlattenFields as FlattenFieldsModel, ) from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( GzipDecoder as GzipDecoderModel, ) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + HTTPAPIBudget as HTTPAPIBudgetModel, +) from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( HttpComponentsResolver as HttpComponentsResolverModel, ) from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( HttpRequester as HttpRequesterModel, ) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + HttpRequestRegexMatcher as HttpRequestRegexMatcherModel, +) from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( HttpResponseFilter as HttpResponseFilterModel, ) @@ -281,6 +290,9 @@ from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( MinMaxDatetime as MinMaxDatetimeModel, ) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + MovingWindowCallRatePolicy as MovingWindowCallRatePolicyModel, +) from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( NoAuth as NoAuthModel, ) @@ -299,6 +311,9 @@ from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( ParentStreamConfig as ParentStreamConfigModel, ) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + Rate as RateModel, +) from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( RecordFilter as RecordFilterModel, ) @@ -342,6 +357,9 @@ from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( TypesMap as TypesMapModel, ) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + UnlimitedCallRatePolicy as UnlimitedCallRatePolicyModel, +) from airbyte_cdk.sources.declarative.models.declarative_component_schema import ValueType from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( WaitTimeFromHeader as WaitTimeFromHeaderModel, @@ -455,6 +473,15 @@ MessageRepository, NoopMessageRepository, ) +from airbyte_cdk.sources.streams.call_rate import ( + APIBudget, + FixedWindowCallRatePolicy, + HttpAPIBudget, + HttpRequestRegexMatcher, + MovingWindowCallRatePolicy, + Rate, + UnlimitedCallRatePolicy, +) from airbyte_cdk.sources.streams.concurrent.clamping import ( ClampingEndProvider, ClampingStrategy, @@ -506,6 +533,7 @@ def __init__( self._evaluate_log_level(emit_connector_builder_messages) ) self._connector_state_manager = connector_state_manager or ConnectorStateManager() + self._api_budget: Optional[Union[APIBudget, HttpAPIBudget]] = None def _init_mappings(self) -> None: self.PYDANTIC_MODEL_TO_CONSTRUCTOR: Mapping[Type[BaseModel], Callable[..., Any]] = { @@ -590,6 +618,12 @@ def _init_mappings(self) -> None: StreamConfigModel: self.create_stream_config, ComponentMappingDefinitionModel: self.create_components_mapping_definition, ZipfileDecoderModel: self.create_zipfile_decoder, + HTTPAPIBudgetModel: self.create_http_api_budget, + FixedWindowCallRatePolicyModel: self.create_fixed_window_call_rate_policy, + MovingWindowCallRatePolicyModel: self.create_moving_window_call_rate_policy, + UnlimitedCallRatePolicyModel: self.create_unlimited_call_rate_policy, + RateModel: self.create_rate, + HttpRequestRegexMatcherModel: self.create_http_request_matcher, } # Needed for the case where we need to perform a second parse on the fields of a custom component @@ -1902,6 +1936,8 @@ def create_http_requester( ) ) + api_budget = self._api_budget + request_options_provider = InterpolatedRequestOptionsProvider( request_body_data=model.request_body_data, request_body_json=model.request_body_json, @@ -1922,6 +1958,7 @@ def create_http_requester( path=model.path, authenticator=authenticator, error_handler=error_handler, + api_budget=api_budget, http_method=HttpMethod[model.http_method.value], request_options_provider=request_options_provider, config=config, @@ -2921,3 +2958,84 @@ def _is_supported_parser_for_pagination(self, parser: Parser) -> bool: return isinstance(parser.inner_parser, JsonParser) else: return False + + def create_http_api_budget( + self, model: HTTPAPIBudgetModel, config: Config, **kwargs: Any + ) -> HttpAPIBudget: + policies = [ + self._create_component_from_model(model=policy, config=config) + for policy in model.policies + ] + + return HttpAPIBudget( + policies=policies, + ratelimit_reset_header=model.ratelimit_reset_header or "ratelimit-reset", + ratelimit_remaining_header=model.ratelimit_remaining_header or "ratelimit-remaining", + status_codes_for_ratelimit_hit=model.status_codes_for_ratelimit_hit or [429], + ) + + def create_fixed_window_call_rate_policy( + self, model: FixedWindowCallRatePolicyModel, config: Config, **kwargs: Any + ) -> FixedWindowCallRatePolicy: + matchers = [ + self._create_component_from_model(model=matcher, config=config) + for matcher in model.matchers + ] + + # Set the initial reset timestamp to 10 days from now. + # This value will be updated by the first request. + return FixedWindowCallRatePolicy( + next_reset_ts=datetime.datetime.now() + datetime.timedelta(days=10), + period=parse_duration(model.period), + call_limit=model.call_limit, + matchers=matchers, + ) + + def create_moving_window_call_rate_policy( + self, model: MovingWindowCallRatePolicyModel, config: Config, **kwargs: Any + ) -> MovingWindowCallRatePolicy: + rates = [ + self._create_component_from_model(model=rate, config=config) for rate in model.rates + ] + matchers = [ + self._create_component_from_model(model=matcher, config=config) + for matcher in model.matchers + ] + return MovingWindowCallRatePolicy( + rates=rates, + matchers=matchers, + ) + + def create_unlimited_call_rate_policy( + self, model: UnlimitedCallRatePolicyModel, config: Config, **kwargs: Any + ) -> UnlimitedCallRatePolicy: + matchers = [ + self._create_component_from_model(model=matcher, config=config) + for matcher in model.matchers + ] + + return UnlimitedCallRatePolicy( + matchers=matchers, + ) + + def create_rate(self, model: RateModel, config: Config, **kwargs: Any) -> Rate: + return Rate( + limit=model.limit, + interval=parse_duration(model.interval), + ) + + def create_http_request_matcher( + self, model: HttpRequestRegexMatcherModel, config: Config, **kwargs: Any + ) -> HttpRequestRegexMatcher: + return HttpRequestRegexMatcher( + method=model.method, + url_base=model.url_base, + url_path_pattern=model.url_path_pattern, + params=model.params, + headers=model.headers, + ) + + def set_api_budget(self, component_definition: ComponentDefinition, config: Config) -> None: + self._api_budget = self.create_component( + model_type=HTTPAPIBudgetModel, component_definition=component_definition, config=config + ) diff --git a/airbyte_cdk/sources/declarative/requesters/http_requester.py b/airbyte_cdk/sources/declarative/requesters/http_requester.py index ad23f4d06..b206bd688 100644 --- a/airbyte_cdk/sources/declarative/requesters/http_requester.py +++ b/airbyte_cdk/sources/declarative/requesters/http_requester.py @@ -22,6 +22,7 @@ ) from airbyte_cdk.sources.declarative.requesters.requester import HttpMethod, Requester from airbyte_cdk.sources.message import MessageRepository, NoopMessageRepository +from airbyte_cdk.sources.streams.call_rate import APIBudget from airbyte_cdk.sources.streams.http import HttpClient from airbyte_cdk.sources.streams.http.error_handlers import ErrorHandler from airbyte_cdk.sources.types import Config, StreamSlice, StreamState @@ -55,6 +56,7 @@ class HttpRequester(Requester): http_method: Union[str, HttpMethod] = HttpMethod.GET request_options_provider: Optional[InterpolatedRequestOptionsProvider] = None error_handler: Optional[ErrorHandler] = None + api_budget: Optional[APIBudget] = None disable_retries: bool = False message_repository: MessageRepository = NoopMessageRepository() use_cache: bool = False @@ -91,6 +93,7 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: name=self.name, logger=self.logger, error_handler=self.error_handler, + api_budget=self.api_budget, authenticator=self._authenticator, use_cache=self.use_cache, backoff_strategy=backoff_strategies, diff --git a/airbyte_cdk/sources/streams/call_rate.py b/airbyte_cdk/sources/streams/call_rate.py index 81ebac78e..14f823e45 100644 --- a/airbyte_cdk/sources/streams/call_rate.py +++ b/airbyte_cdk/sources/streams/call_rate.py @@ -6,6 +6,7 @@ import dataclasses import datetime import logging +import re import time from datetime import timedelta from threading import RLock @@ -25,6 +26,7 @@ MIXIN_BASE = object logger = logging.getLogger("airbyte") +logging.getLogger("pyrate_limiter").setLevel(logging.WARNING) @dataclasses.dataclass @@ -98,7 +100,7 @@ def __call__(self, request: Any) -> bool: class HttpRequestMatcher(RequestMatcher): - """Simple implementation of RequestMatcher for http requests case""" + """Simple implementation of RequestMatcher for HTTP requests using HttpRequestRegexMatcher under the hood.""" def __init__( self, @@ -109,32 +111,94 @@ def __init__( ): """Constructor - :param method: - :param url: - :param params: - :param headers: + :param method: HTTP method (e.g., "GET", "POST"). + :param url: Full URL to match. + :param params: Dictionary of query parameters to match. + :param headers: Dictionary of headers to match. """ - self._method = method - self._url = url + # Parse the URL to extract the base and path + if url: + parsed_url = parse.urlsplit(url) + url_base = f"{parsed_url.scheme}://{parsed_url.netloc}" + url_path = parsed_url.path if parsed_url.path != "/" else None + else: + url_base = None + url_path = None + + # Use HttpRequestRegexMatcher under the hood + self._regex_matcher = HttpRequestRegexMatcher( + method=method, + url_base=url_base, + url_path_pattern=re.escape(url_path) if url_path else None, + params=params, + headers=headers, + ) + + def __call__(self, request: Any) -> bool: + """ + :param request: A requests.Request or requests.PreparedRequest instance. + :return: True if the request matches all provided criteria; False otherwise. + """ + return self._regex_matcher(request) + + def __str__(self) -> str: + return ( + f"HttpRequestMatcher(method={self._regex_matcher._method}, " + f"url={self._regex_matcher._url_base}{self._regex_matcher._url_path_pattern.pattern if self._regex_matcher._url_path_pattern else ''}, " + f"params={self._regex_matcher._params}, headers={self._regex_matcher._headers})" + ) + + +class HttpRequestRegexMatcher(RequestMatcher): + """ + Extended RequestMatcher for HTTP requests that supports matching on: + - HTTP method (case-insensitive) + - URL base (scheme + netloc) optionally + - URL path pattern (a regex applied to the path portion of the URL) + - Query parameters (must be present) + - Headers (header names compared case-insensitively) + """ + + def __init__( + self, + method: Optional[str] = None, + url_base: Optional[str] = None, + url_path_pattern: Optional[str] = None, + params: Optional[Mapping[str, Any]] = None, + headers: Optional[Mapping[str, Any]] = None, + ): + """ + :param method: HTTP method (e.g. "GET", "POST"); compared case-insensitively. + :param url_base: Base URL (scheme://host) that must match. + :param url_path_pattern: A regex pattern that will be applied to the path portion of the URL. + :param params: Dictionary of query parameters that must be present in the request. + :param headers: Dictionary of headers that must be present (header keys are compared case-insensitively). + """ + self._method = method.upper() if method else None + + # Normalize the url_base if provided: remove trailing slash. + self._url_base = url_base.rstrip("/") if url_base else None + + # Compile the URL path pattern if provided. + self._url_path_pattern = re.compile(url_path_pattern) if url_path_pattern else None + + # Normalize query parameters to strings. self._params = {str(k): str(v) for k, v in (params or {}).items()} - self._headers = {str(k): str(v) for k, v in (headers or {}).items()} + + # Normalize header keys to lowercase. + self._headers = {str(k).lower(): str(v) for k, v in (headers or {}).items()} @staticmethod def _match_dict(obj: Mapping[str, Any], pattern: Mapping[str, Any]) -> bool: - """Check that all elements from pattern dict present and have the same values in obj dict - - :param obj: - :param pattern: - :return: - """ + """Check that every key/value in the pattern exists in the object.""" return pattern.items() <= obj.items() def __call__(self, request: Any) -> bool: """ - - :param request: - :return: True if matches the provided request object, False - otherwise + :param request: A requests.Request or requests.PreparedRequest instance. + :return: True if the request matches all provided criteria; False otherwise. """ + # Prepare the request (if needed) and extract the URL details. if isinstance(request, requests.Request): prepared_request = request.prepare() elif isinstance(request, requests.PreparedRequest): @@ -142,23 +206,49 @@ def __call__(self, request: Any) -> bool: else: return False + # Check HTTP method. if self._method is not None: if prepared_request.method != self._method: return False - if self._url is not None and prepared_request.url is not None: - url_without_params = prepared_request.url.split("?")[0] - if url_without_params != self._url: + + # Parse the URL. + parsed_url = parse.urlsplit(prepared_request.url) + # Reconstruct the base: scheme://netloc + request_url_base = f"{str(parsed_url.scheme)}://{str(parsed_url.netloc)}" + # The path (without query parameters) + request_path = str(parsed_url.path).rstrip("/") + + # If a base URL is provided, check that it matches. + if self._url_base is not None: + if request_url_base != self._url_base: + return False + + # If a URL path pattern is provided, ensure the path matches the regex. + if self._url_path_pattern is not None: + if not self._url_path_pattern.search(request_path): return False - if self._params is not None: - parsed_url = parse.urlsplit(prepared_request.url) - params = dict(parse.parse_qsl(str(parsed_url.query))) - if not self._match_dict(params, self._params): + + # Check query parameters. + if self._params: + query_params = dict(parse.parse_qsl(str(parsed_url.query))) + if not self._match_dict(query_params, self._params): return False - if self._headers is not None: - if not self._match_dict(prepared_request.headers, self._headers): + + # Check headers (normalize keys to lower-case). + if self._headers: + req_headers = {k.lower(): v for k, v in prepared_request.headers.items()} + if not self._match_dict(req_headers, self._headers): return False + return True + def __str__(self) -> str: + regex = self._url_path_pattern.pattern if self._url_path_pattern else None + return ( + f"HttpRequestRegexMatcher(method={self._method}, url_base={self._url_base}, " + f"url_path_pattern={regex}, params={self._params}, headers={self._headers})" + ) + class BaseCallRatePolicy(AbstractCallRatePolicy, abc.ABC): def __init__(self, matchers: list[RequestMatcher]): @@ -257,6 +347,14 @@ def try_acquire(self, request: Any, weight: int) -> None: self._calls_num += weight + def __str__(self) -> str: + matcher_str = ", ".join(f"{matcher}" for matcher in self._matchers) + return ( + f"FixedWindowCallRatePolicy(call_limit={self._call_limit}, period={self._offset}, " + f"calls_used={self._calls_num}, next_reset={self._next_reset_ts}, " + f"matchers=[{matcher_str}])" + ) + def update( self, available_calls: Optional[int], call_reset_ts: Optional[datetime.datetime] ) -> None: @@ -363,6 +461,19 @@ def update( # if available_calls is not None and call_reset_ts is not None: # ts = call_reset_ts.timestamp() + def __str__(self) -> str: + """Return a human-friendly description of the moving window rate policy for logging purposes.""" + rates_info = ", ".join( + f"{rate.limit} per {timedelta(milliseconds=rate.interval)}" + for rate in self._bucket.rates + ) + current_bucket_count = self._bucket.count() + matcher_str = ", ".join(f"{matcher}" for matcher in self._matchers) + return ( + f"MovingWindowCallRatePolicy(rates=[{rates_info}], current_bucket_count={current_bucket_count}, " + f"matchers=[{matcher_str}])" + ) + class AbstractAPIBudget(abc.ABC): """Interface to some API where a client allowed to have N calls per T interval. @@ -415,6 +526,23 @@ def __init__( self._policies = policies self._maximum_attempts_to_acquire = maximum_attempts_to_acquire + def _extract_endpoint(self, request: Any) -> str: + """Extract the endpoint URL from the request if available.""" + endpoint = None + try: + # If the request is already a PreparedRequest, it should have a URL. + if isinstance(request, requests.PreparedRequest): + endpoint = request.url + # If it's a requests.Request, we call prepare() to extract the URL. + elif isinstance(request, requests.Request): + prepared = request.prepare() + endpoint = prepared.url + except Exception as e: + logger.debug(f"Error extracting endpoint: {e}") + if endpoint: + return endpoint + return "unknown endpoint" + def get_matching_policy(self, request: Any) -> Optional[AbstractCallRatePolicy]: for policy in self._policies: if policy.matches(request): @@ -428,20 +556,24 @@ def acquire_call( Matchers will be called sequentially in the same order they were added. The first matcher that returns True will - :param request: - :param block: when true (default) will block the current thread until call credit is available - :param timeout: if provided will limit maximum time in block, otherwise will wait until credit is available - :raises: CallRateLimitHit - when no calls left and if timeout was set the waiting time exceed the timeout + :param request: the API request + :param block: when True (default) will block until a call credit is available + :param timeout: if provided, limits maximum waiting time; otherwise, waits indefinitely + :raises: CallRateLimitHit if the call credit cannot be acquired within the timeout """ policy = self.get_matching_policy(request) + endpoint = self._extract_endpoint(request) if policy: + logger.debug(f"Acquiring call for endpoint {endpoint} using policy: {policy}") self._do_acquire(request=request, policy=policy, block=block, timeout=timeout) elif self._policies: - logger.info("no policies matched with requests, allow call by default") + logger.debug( + f"No policies matched for endpoint {endpoint} (request: {request}). Allowing call by default." + ) def update_from_response(self, request: Any, response: Any) -> None: - """Update budget information based on response from API + """Update budget information based on the API response. :param request: the initial request that triggered this response :param response: response from the API @@ -451,15 +583,17 @@ def update_from_response(self, request: Any, response: Any) -> None: def _do_acquire( self, request: Any, policy: AbstractCallRatePolicy, block: bool, timeout: Optional[float] ) -> None: - """Internal method to try to acquire a call credit + """Internal method to try to acquire a call credit. - :param request: - :param policy: - :param block: - :param timeout: + :param request: the API request + :param policy: the matching rate-limiting policy + :param block: indicates whether to block until a call credit is available + :param timeout: maximum time to wait if blocking + :raises: CallRateLimitHit if unable to acquire a call credit """ last_exception = None - # sometimes we spend all budget before a second attempt, so we have few more here + endpoint = self._extract_endpoint(request) + # sometimes we spend all budget before a second attempt, so we have a few more attempts for attempt in range(1, self._maximum_attempts_to_acquire): try: policy.try_acquire(request, weight=1) @@ -471,20 +605,24 @@ def _do_acquire( time_to_wait = min(timedelta(seconds=timeout), exc.time_to_wait) else: time_to_wait = exc.time_to_wait - - time_to_wait = max( - timedelta(0), time_to_wait - ) # sometimes we get negative duration - logger.info( - "reached call limit %s. going to sleep for %s", exc.rate, time_to_wait + # Ensure we never sleep for a negative duration. + time_to_wait = max(timedelta(0), time_to_wait) + logger.debug( + f"Policy {policy} reached call limit for endpoint {endpoint} ({exc.rate}). " + f"Sleeping for {time_to_wait} on attempt {attempt}." ) time.sleep(time_to_wait.total_seconds()) else: + logger.debug( + f"Policy {policy} reached call limit for endpoint {endpoint} ({exc.rate}) " + f"and blocking is disabled." + ) raise if last_exception: - logger.info( - "we used all %s attempts to acquire and failed", self._maximum_attempts_to_acquire + logger.debug( + f"Exhausted all {self._maximum_attempts_to_acquire} attempts to acquire a call for endpoint {endpoint} " + f"using policy: {policy}" ) raise last_exception @@ -496,7 +634,7 @@ def __init__( self, ratelimit_reset_header: str = "ratelimit-reset", ratelimit_remaining_header: str = "ratelimit-remaining", - status_codes_for_ratelimit_hit: tuple[int] = (429,), + status_codes_for_ratelimit_hit: list[int] = [429], **kwargs: Any, ): """Constructor diff --git a/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py b/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py index 32a73f364..a062cdfc7 100644 --- a/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py +++ b/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py @@ -142,6 +142,7 @@ from airbyte_cdk.sources.declarative.transformations import AddFields, RemoveFields from airbyte_cdk.sources.declarative.transformations.add_fields import AddedFieldDefinition from airbyte_cdk.sources.declarative.yaml_declarative_source import YamlDeclarativeSource +from airbyte_cdk.sources.streams.call_rate import MovingWindowCallRatePolicy from airbyte_cdk.sources.streams.concurrent.clamping import ( ClampingEndProvider, DayClampingStrategy, @@ -3684,3 +3685,155 @@ def test_create_async_retriever(): assert isinstance(selector, RecordSelector) assert isinstance(extractor, DpathExtractor) assert extractor.field_path == ["data"] + + +def test_api_budget(): + manifest = { + "type": "DeclarativeSource", + "api_budget": { + "type": "HTTPAPIBudget", + "ratelimit_reset_header": "X-RateLimit-Reset", + "ratelimit_remaining_header": "X-RateLimit-Remaining", + "status_codes_for_ratelimit_hit": [429, 503], + "policies": [ + { + "type": "MovingWindowCallRatePolicy", + "rates": [ + { + "type": "Rate", + "limit": 3, + "interval": "PT0.1S", # 0.1 seconds + } + ], + "matchers": [ + { + "type": "HttpRequestRegexMatcher", + "method": "GET", + "url_base": "https://api.sendgrid.com", + "url_path_pattern": "/v3/marketing/lists", + } + ], + } + ], + }, + "my_requester": { + "type": "HttpRequester", + "path": "/v3/marketing/lists", + "url_base": "https://api.sendgrid.com", + "http_method": "GET", + "authenticator": { + "type": "BasicHttpAuthenticator", + "username": "admin", + "password": "{{ config['password'] }}", + }, + }, + } + + config = { + "password": "verysecrettoken", + } + + factory = ModelToComponentFactory() + if "api_budget" in manifest: + factory.set_api_budget(manifest["api_budget"], config) + + from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + HttpRequester as HttpRequesterModel, + ) + + requester_definition = manifest["my_requester"] + assert requester_definition["type"] == "HttpRequester" + + http_requester = factory.create_component( + model_type=HttpRequesterModel, + component_definition=requester_definition, + config=config, + name="lists_stream", + decoder=None, + ) + + assert http_requester.api_budget is not None + assert http_requester.api_budget._ratelimit_reset_header == "X-RateLimit-Reset" + assert http_requester.api_budget._status_codes_for_ratelimit_hit == [429, 503] + assert len(http_requester.api_budget._policies) == 1 + + # The single policy is a MovingWindowCallRatePolicy + policy = http_requester.api_budget._policies[0] + assert isinstance(policy, MovingWindowCallRatePolicy) + assert policy._bucket.rates[0].limit == 3 + # The 0.1s from 'PT0.1S' is stored in ms by PyRateLimiter internally + # but here just check that the limit and interval exist + assert policy._bucket.rates[0].interval == 100 # 100 ms + + +def test_api_budget_fixed_window_policy(): + manifest = { + "type": "DeclarativeSource", + # Root-level api_budget referencing a FixedWindowCallRatePolicy + "api_budget": { + "type": "HTTPAPIBudget", + "policies": [ + { + "type": "FixedWindowCallRatePolicy", + "period": "PT1M", # 1 minute + "call_limit": 10, + "matchers": [ + { + "type": "HttpRequestRegexMatcher", + "method": "GET", + "url_base": "https://example.org", + "url_path_pattern": "/v2/data", + } + ], + } + ], + }, + # We'll define a single HttpRequester that references that base + "my_requester": { + "type": "HttpRequester", + "path": "/v2/data", + "url_base": "https://example.org", + "http_method": "GET", + "authenticator": {"type": "NoAuth"}, + }, + } + + config = {} + + factory = ModelToComponentFactory() + if "api_budget" in manifest: + factory.set_api_budget(manifest["api_budget"], config) + + from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + HttpRequester as HttpRequesterModel, + ) + + requester_definition = manifest["my_requester"] + assert requester_definition["type"] == "HttpRequester" + http_requester = factory.create_component( + model_type=HttpRequesterModel, + component_definition=requester_definition, + config=config, + name="my_stream", + decoder=None, + ) + + assert http_requester.api_budget is not None + assert len(http_requester.api_budget._policies) == 1 + + from airbyte_cdk.sources.streams.call_rate import FixedWindowCallRatePolicy + + policy = http_requester.api_budget._policies[0] + assert isinstance(policy, FixedWindowCallRatePolicy) + assert policy._call_limit == 10 + # The period is "PT1M" => 60 seconds + assert policy._offset.total_seconds() == 60 + + assert len(policy._matchers) == 1 + matcher = policy._matchers[0] + from airbyte_cdk.sources.streams.call_rate import HttpRequestRegexMatcher + + assert isinstance(matcher, HttpRequestRegexMatcher) + assert matcher._method == "GET" + assert matcher._url_base == "https://example.org" + assert matcher._url_path_pattern.pattern == "/v2/data" diff --git a/unit_tests/sources/declarative/requesters/test_http_requester.py b/unit_tests/sources/declarative/requesters/test_http_requester.py index f02ec206b..c5d5c218d 100644 --- a/unit_tests/sources/declarative/requesters/test_http_requester.py +++ b/unit_tests/sources/declarative/requesters/test_http_requester.py @@ -2,6 +2,7 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from datetime import timedelta from typing import Any, Mapping, Optional from unittest import mock from unittest.mock import MagicMock @@ -9,6 +10,7 @@ import pytest as pytest import requests +import requests.sessions from requests import PreparedRequest from airbyte_cdk.sources.declarative.auth.declarative_authenticator import DeclarativeAuthenticator @@ -27,6 +29,12 @@ InterpolatedRequestOptionsProvider, ) from airbyte_cdk.sources.message import MessageRepository +from airbyte_cdk.sources.streams.call_rate import ( + AbstractAPIBudget, + HttpAPIBudget, + MovingWindowCallRatePolicy, + Rate, +) from airbyte_cdk.sources.streams.http.error_handlers.response_models import ResponseAction from airbyte_cdk.sources.streams.http.exceptions import ( RequestBodyException, @@ -45,6 +53,7 @@ def factory( request_options_provider: Optional[InterpolatedRequestOptionsProvider] = None, authenticator: Optional[DeclarativeAuthenticator] = None, error_handler: Optional[ErrorHandler] = None, + api_budget: Optional[HttpAPIBudget] = None, config: Optional[Config] = None, parameters: Mapping[str, Any] = None, disable_retries: bool = False, @@ -61,6 +70,7 @@ def factory( http_method=http_method, request_options_provider=request_options_provider, error_handler=error_handler, + api_budget=api_budget, disable_retries=disable_retries, message_repository=message_repository or MagicMock(), use_cache=use_cache, @@ -934,3 +944,25 @@ def test_backoff_strategy_from_manifest_is_respected(http_requester_factory: Any http_requester._http_client._request_attempt_count.get(request_mock) == http_requester._http_client._max_retries + 1 ) + + +def test_http_requester_with_mock_apibudget(http_requester_factory, monkeypatch): + mock_budget = MagicMock(spec=HttpAPIBudget) + + requester = http_requester_factory( + url_base="https://example.com", + path="test", + api_budget=mock_budget, + ) + + dummy_response = requests.Response() + dummy_response.status_code = 200 + send_mock = MagicMock(return_value=dummy_response) + monkeypatch.setattr(requests.Session, "send", send_mock) + + response = requester.send_request() + + assert send_mock.call_count == 1 + assert response.status_code == 200 + + assert mock_budget.acquire_call.call_count == 1 diff --git a/unit_tests/sources/streams/test_call_rate.py b/unit_tests/sources/streams/test_call_rate.py index 16bce68e3..853e2997e 100644 --- a/unit_tests/sources/streams/test_call_rate.py +++ b/unit_tests/sources/streams/test_call_rate.py @@ -17,6 +17,7 @@ CallRateLimitHit, FixedWindowCallRatePolicy, HttpRequestMatcher, + HttpRequestRegexMatcher, MovingWindowCallRatePolicy, Rate, UnlimitedCallRatePolicy, @@ -357,3 +358,90 @@ def test_with_cache(self, mocker, requests_mock): assert next(records) == {"data": "some_data"} assert MovingWindowCallRatePolicy.try_acquire.call_count == 1 + + +class TestHttpRequestRegexMatcher: + """ + Tests for the new regex-based logic: + - Case-insensitive HTTP method matching + - Optional url_base (scheme://netloc) + - Regex-based path matching + - Query params (must be present) + - Headers (case-insensitive keys) + """ + + def test_case_insensitive_method(self): + matcher = HttpRequestRegexMatcher(method="GET") + + req_ok = Request("get", "https://example.com/test/path") + req_wrong = Request("POST", "https://example.com/test/path") + + assert matcher(req_ok) + assert not matcher(req_wrong) + + def test_url_base(self): + matcher = HttpRequestRegexMatcher(url_base="https://example.com") + + req_ok = Request("GET", "https://example.com/test/path?foo=bar") + req_wrong = Request("GET", "https://another.com/test/path?foo=bar") + + assert matcher(req_ok) + assert not matcher(req_wrong) + + def test_url_path_pattern(self): + matcher = HttpRequestRegexMatcher(url_path_pattern=r"/test/") + + req_ok = Request("GET", "https://example.com/test/something") + req_wrong = Request("GET", "https://example.com/other/something") + + assert matcher(req_ok) + assert not matcher(req_wrong) + + def test_query_params(self): + matcher = HttpRequestRegexMatcher(params={"foo": "bar"}) + + req_ok = Request("GET", "https://example.com/api?foo=bar&extra=123") + req_missing = Request("GET", "https://example.com/api?not_foo=bar") + + assert matcher(req_ok) + assert not matcher(req_missing) + + def test_headers_case_insensitive(self): + matcher = HttpRequestRegexMatcher(headers={"X-Custom-Header": "abc"}) + + req_ok = Request( + "GET", + "https://example.com/api?foo=bar", + headers={"x-custom-header": "abc", "other": "123"}, + ) + req_wrong = Request("GET", "https://example.com/api", headers={"x-custom-header": "wrong"}) + + assert matcher(req_ok) + assert not matcher(req_wrong) + + def test_combined_criteria(self): + matcher = HttpRequestRegexMatcher( + method="GET", + url_base="https://example.com", + url_path_pattern=r"/test/", + params={"foo": "bar"}, + headers={"X-Test": "123"}, + ) + + req_ok = Request("GET", "https://example.com/test/me?foo=bar", headers={"x-test": "123"}) + req_bad_base = Request( + "GET", "https://other.com/test/me?foo=bar", headers={"x-test": "123"} + ) + req_bad_path = Request("GET", "https://example.com/nope?foo=bar", headers={"x-test": "123"}) + req_bad_param = Request( + "GET", "https://example.com/test/me?extra=xyz", headers={"x-test": "123"} + ) + req_bad_header = Request( + "GET", "https://example.com/test/me?foo=bar", headers={"some-other-header": "xyz"} + ) + + assert matcher(req_ok) + assert not matcher(req_bad_base) + assert not matcher(req_bad_path) + assert not matcher(req_bad_param) + assert not matcher(req_bad_header)