Skip to content
Merged
42 changes: 20 additions & 22 deletions src/apify/_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,14 +380,6 @@ def event_manager(self) -> EventManager:
def _charging_manager_implementation(self) -> ChargingManagerImplementation:
return ChargingManagerImplementation(self.configuration, self.apify_client)

@cached_property
def _charge_lock(self) -> asyncio.Lock:
"""Lock to synchronize charge operations.

Prevents race conditions between Actor.charge and Actor.push_data calls.
"""
return asyncio.Lock()

@cached_property
def _storage_client(self) -> SmartApifyStorageClient:
"""Storage client used by the Actor.
Expand Down Expand Up @@ -639,21 +631,25 @@ async def push_data(self, data: dict | list[dict], charged_event_name: str | Non

data = data if isinstance(data, list) else [data]

# No charging, just push the data without locking.
if charged_event_name is None:
dataset = await self.open_dataset()
await dataset.push_data(data)
return None
if charged_event_name and charged_event_name.startswith('apify-'):
raise ValueError(f'Cannot charge for synthetic event "{charged_event_name}" manually')

# If charging is requested, acquire the charge lock to prevent race conditions between concurrent
charging_manager = self.get_charging_manager()

# Acquire the charge lock to prevent race conditions between concurrent
# push_data calls. We need to hold the lock for the entire push_data + charge sequence.
async with self._charge_lock:
max_charged_count = self.get_charging_manager().calculate_max_event_charge_count_within_limit(
charged_event_name
)
async with charging_manager.charge_lock:
# No explicit charging requested; synthetic events are handled within dataset.push_data.
if charged_event_name is None:
dataset = await self.open_dataset()
await dataset.push_data(data)
return None

# Push as many items as we can charge for.
pushed_items_count = min(max_charged_count, len(data)) if max_charged_count is not None else len(data)
pushed_items_count = self.get_charging_manager().compute_push_data_limit(
items_count=len(data),
event_name=charged_event_name,
is_default_dataset=True,
)

dataset = await self.open_dataset()

Expand All @@ -662,6 +658,7 @@ async def push_data(self, data: dict | list[dict], charged_event_name: str | Non
elif pushed_items_count > 0:
await dataset.push_data(data)

# Only charge explicit events; synthetic events will be processed within the client.
return await self.get_charging_manager().charge(
event_name=charged_event_name,
count=pushed_items_count,
Expand Down Expand Up @@ -727,8 +724,9 @@ async def charge(self, event_name: str, count: int = 1) -> ChargeResult:
count: Number of events to charge for.
"""
# Acquire lock to prevent race conditions with concurrent charge/push_data calls.
async with self._charge_lock:
return await self.get_charging_manager().charge(event_name, count)
charging_manager = self.get_charging_manager()
async with charging_manager.charge_lock:
return await charging_manager.charge(event_name, count)

@overload
def on(
Expand Down
87 changes: 78 additions & 9 deletions src/apify/_charging.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import asyncio
import math
from contextvars import ContextVar
from dataclasses import dataclass
from datetime import datetime, timezone
from decimal import Decimal
Expand Down Expand Up @@ -29,6 +31,12 @@

run_validator = TypeAdapter[ActorRun | None](ActorRun | None)

DEFAULT_DATASET_ITEM_EVENT = 'apify-default-dataset-item'

# Context variable to hold the current `ChargingManager` instance, if any. This allows PPE-aware dataset clients to
# access the charging manager without needing to pass it explicitly.
charging_manager_ctx: ContextVar[ChargingManager | None] = ContextVar('charging_manager_ctx', default=None)

_ensure_context = ensure_context('active')


Expand All @@ -45,6 +53,9 @@ class ChargingManager(Protocol):
- Apify platform documentation: https://docs.apify.com/platform/actors/publishing/monetize
"""

charge_lock: asyncio.Lock
"""Lock to synchronize charge operations. Prevents race conditions between `charge` and `push_data` calls."""

async def charge(self, event_name: str, count: int = 1) -> ChargeResult:
"""Charge for a specified number of events - sub-operations of the Actor.

Expand Down Expand Up @@ -81,6 +92,28 @@ def get_charged_event_count(self, event_name: str) -> int:
def get_max_total_charge_usd(self) -> Decimal:
"""Get the configured maximum total charge for this Actor run."""

def compute_push_data_limit(
self,
items_count: int,
event_name: str,
*,
is_default_dataset: bool,
) -> int:
"""Compute how many items can be pushed and charged within the current budget.

Accounts for both the explicit event and the synthetic `DEFAULT_DATASET_ITEM_EVENT` event,
so that the combined cost per item does not exceed the remaining budget.

Args:
items_count: The number of items to be pushed.
event_name: The explicit event name to charge for each item.
is_default_dataset: Whether the data is pushed to the default dataset.
If True, the synthetic event cost is included in the combined price.

Returns:
Max number of items that can be pushed within the budget.
"""


@docs_group('Charging')
@dataclass(frozen=True)
Expand Down Expand Up @@ -137,6 +170,8 @@ def __init__(self, configuration: Configuration, client: ApifyClientAsync) -> No
self._not_ppe_warning_printed = False
self.active = False

self.charge_lock = asyncio.Lock()

async def __aenter__(self) -> None:
"""Initialize the charging manager - this is called by the `Actor` class and shouldn't be invoked manually."""
# Validate config
Expand Down Expand Up @@ -190,6 +225,11 @@ async def __aenter__(self) -> None:

self._charging_log_dataset = await Dataset.open(name=self.LOCAL_CHARGING_LOG_DATASET_NAME)

# if the Actor runs with the pay-per-event pricing model, set the context variable so that PPE-aware dataset
# clients can access the charging manager and charge for synthetic events.
if self._pricing_model == 'PAY_PER_EVENT':
charging_manager_ctx.set(self)

async def __aexit__(
self,
exc_type: type[BaseException] | None,
Expand All @@ -199,8 +239,10 @@ async def __aexit__(
if not self.active:
raise RuntimeError('Exiting an uninitialized ChargingManager')

charging_manager_ctx.set(None)
self.active = False

@_ensure_context
@_ensure_context
async def charge(self, event_name: str, count: int = 1) -> ChargeResult:
def calculate_chargeable() -> dict[str, int | None]:
Expand Down Expand Up @@ -258,7 +300,11 @@ def calculate_chargeable() -> dict[str, int | None]:
if self._actor_run_id is None:
raise RuntimeError('Actor run ID not configured')

if event_name in self._pricing_info:
if event_name.startswith('apify-'):
# Synthetic events (e.g. apify-default-dataset-item) are tracked internally only,
# the platform handles them automatically based on dataset writes.
pass
elif event_name in self._pricing_info:
await self._client.run(self._actor_run_id).charge(event_name, charged_count)
else:
logger.warning(f"Attempting to charge for an unknown event '{event_name}'")
Expand Down Expand Up @@ -291,30 +337,26 @@ def calculate_chargeable() -> dict[str, int | None]:
chargeable_within_limit=calculate_chargeable(),
)

@_ensure_context
@_ensure_context
def calculate_total_charged_amount(self) -> Decimal:
    """Return the total USD amount charged so far across all events."""
    total = Decimal()
    for state_item in self._charging_state.values():
        total += state_item.total_charged_amount
    return total

@_ensure_context
@_ensure_context
def calculate_max_event_charge_count_within_limit(self, event_name: str) -> int | None:
pricing_info = self._pricing_info.get(event_name)

if pricing_info is not None:
price = pricing_info.price
elif not self._is_at_home:
price = Decimal(1) # Use a nonzero price for local development so that the maximum budget can be reached
else:
price = Decimal()
price = self._get_event_price(event_name)

if not price:
return None

result = (self._max_total_charge_usd - self.calculate_total_charged_amount()) / price
return max(0, math.floor(result)) if result.is_finite() else None

@_ensure_context
@_ensure_context
def get_pricing_info(self) -> ActorPricingInfo:
return ActorPricingInfo(
Expand All @@ -328,15 +370,36 @@ def get_pricing_info(self) -> ActorPricingInfo:
},
)

@_ensure_context
@_ensure_context
def get_charged_event_count(self, event_name: str) -> int:
    """Return how many times the given event has been charged in this run."""
    state_item = self._charging_state.get(event_name)
    if state_item is None:
        return 0
    return state_item.charge_count

@_ensure_context
@_ensure_context
def get_max_total_charge_usd(self) -> Decimal:
    """Return the configured maximum total charge (USD) for this Actor run."""
    budget_limit = self._max_total_charge_usd
    return budget_limit

@_ensure_context
def compute_push_data_limit(
    self,
    items_count: int,
    event_name: str,
    *,
    is_default_dataset: bool,
) -> int:
    """Return how many of `items_count` items fit into the remaining budget.

    The per-item price combines the explicit event with the synthetic
    default-dataset event when the data is pushed to the default dataset.

    Args:
        items_count: The number of items to be pushed.
        event_name: The explicit event name to charge for each item.
        is_default_dataset: Whether the data is pushed to the default dataset.

    Returns:
        The maximum number of items that can be pushed within the budget.
    """
    per_item_price = self._get_event_price(event_name)
    if is_default_dataset:
        per_item_price += self._get_event_price(DEFAULT_DATASET_ITEM_EVENT)

    # A zero combined price means charging imposes no limit on the push.
    if not per_item_price:
        return items_count

    remaining_budget = self._max_total_charge_usd - self.calculate_total_charged_amount()
    affordable = remaining_budget / per_item_price
    if not affordable.is_finite():
        # An unlimited budget never restricts the number of items.
        return items_count
    return min(items_count, max(0, math.floor(affordable)))

async def _fetch_pricing_info(self) -> _FetchedPricingInfoDict:
"""Fetch pricing information from environment variables or API."""
# Check if pricing info is available via environment variables
Expand Down Expand Up @@ -370,6 +433,12 @@ async def _fetch_pricing_info(self) -> _FetchedPricingInfoDict:
max_total_charge_usd=self._configuration.max_total_charge_usd or Decimal('inf'),
)

def _get_event_price(self, event_name: str) -> Decimal:
pricing_info = self._pricing_info.get(event_name)
if pricing_info is not None:
return pricing_info.price
return Decimal(0) if self._is_at_home else Decimal(1)


@dataclass
class ChargingStateItem:
Expand Down
36 changes: 22 additions & 14 deletions src/apify/storage_clients/_apify/_dataset_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from crawlee.storage_clients.models import DatasetItemsListPage, DatasetMetadata

from ._api_client_creation import create_storage_api_client
from apify.storage_clients._ppe_dataset_mixin import DatasetClientPpeMixin

if TYPE_CHECKING:
from collections.abc import AsyncIterator
Expand All @@ -25,7 +26,7 @@
logger = getLogger(__name__)


class ApifyDatasetClient(DatasetClient):
class ApifyDatasetClient(DatasetClient, DatasetClientPpeMixin):
"""An Apify platform implementation of the dataset client."""

_MAX_PAYLOAD_SIZE = ByteSize.from_mb(9)
Expand All @@ -48,6 +49,9 @@ def __init__(

Preferably use the `ApifyDatasetClient.open` class method to create a new instance.
"""
DatasetClient.__init__(self)
DatasetClientPpeMixin.__init__(self)

self._api_client = api_client
"""The Apify dataset client for API operations."""

Expand Down Expand Up @@ -108,12 +112,18 @@ async def open(
id=id,
)

return cls(
dataset_client = cls(
api_client=api_client,
api_public_base_url='', # Remove in version 4.0, https://github.com/apify/apify-sdk-python/issues/635
lock=asyncio.Lock(),
)

dataset_client.is_default_dataset = (
alias is None and name is None and (id is None or id == configuration.default_dataset_id)
)

return dataset_client

@override
async def purge(self) -> None:
raise NotImplementedError(
Expand All @@ -128,21 +138,19 @@ async def drop(self) -> None:

@override
async def push_data(self, data: list[Any] | dict[str, Any]) -> None:
async def payloads_generator() -> AsyncIterator[str]:
for index, item in enumerate(data):
async def payloads_generator(items: list[Any]) -> AsyncIterator[str]:
for index, item in enumerate(items):
yield await self._check_and_serialize(item, index)

async with self._lock:
# Handle lists
if isinstance(data, list):
# Invoke client in series to preserve the order of data
async for items in self._chunk_by_size(payloads_generator()):
await self._api_client.push_items(items=items)
async with self._lock, self._charge_lock():
items = data if isinstance(data, list) else [data]
limit = self._compute_limit_for_push(len(items))
items = items[:limit]

# Handle singular items
else:
items = await self._check_and_serialize(data)
await self._api_client.push_items(items=items)
async for chunk in self._chunk_by_size(payloads_generator(items)):
await self._api_client.push_items(items=chunk)

await self._charge_for_items(count_items=limit)

@override
async def get_data(
Expand Down
Loading
Loading