diff --git a/deployment/migrations/versions/0057_a8c3d9f1b2e4_credit_balances_lot_cache.py b/deployment/migrations/versions/0057_a8c3d9f1b2e4_credit_balances_lot_cache.py new file mode 100644 index 000000000..fee46533d --- /dev/null +++ b/deployment/migrations/versions/0057_a8c3d9f1b2e4_credit_balances_lot_cache.py @@ -0,0 +1,87 @@ +"""replace credit_balances with lot-cache schema + +The previous ``credit_balances`` table was a single integer sum per address, +recomputed lazily on the read path (and written back from inside the read +itself). This change replaces it with a per-lot cache: one row per granting +``credit_history`` entry, with ``amount_remaining`` decremented eagerly by +writers (distribution, expense, transfer). Reads become a simple ``SUM`` +over still-valid lots, no FIFO walk and no write-back. + +The table is a pure cache derived from ``credit_history``. The matching +``_repair_credit_balances`` startup hook rebuilds it from history, so the +upgrade does not need to backfill data in the migration itself. 
+ +Revision ID: a8c3d9f1b2e4 +Revises: 7e5a630e4b36 +Create Date: 2026-05-12 00:00:00.000000 + +""" + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.sql import func + +revision = "a8c3d9f1b2e4" +down_revision = "7e5a630e4b36" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.drop_index("ix_credit_balances_address", table_name="credit_balances") + op.drop_table("credit_balances") + + op.create_table( + "credit_balances", + sa.Column("address", sa.String(), nullable=False), + sa.Column("credit_ref", sa.String(), nullable=False), + sa.Column("credit_index", sa.Integer(), nullable=False), + sa.Column("amount_remaining", sa.BigInteger(), nullable=False), + sa.Column("expiration_date", sa.TIMESTAMP(timezone=True), nullable=True), + sa.Column("message_timestamp", sa.TIMESTAMP(timezone=True), nullable=False), + sa.Column( + "last_update", + sa.TIMESTAMP(timezone=True), + nullable=False, + server_default=func.now(), + onupdate=func.now(), + ), + sa.PrimaryKeyConstraint( + "address", "credit_ref", "credit_index", name="credit_balances_pkey" + ), + ) + + op.create_index( + "ix_credit_balances_address_order", + "credit_balances", + ["address", "message_timestamp", "credit_ref", "credit_index"], + ) + op.create_index( + "ix_credit_balances_address_active", + "credit_balances", + ["address"], + postgresql_where=sa.text("amount_remaining > 0"), + ) + + +def downgrade() -> None: + op.drop_index("ix_credit_balances_address_active", table_name="credit_balances") + op.drop_index("ix_credit_balances_address_order", table_name="credit_balances") + op.drop_table("credit_balances") + + op.create_table( + "credit_balances", + sa.Column("address", sa.String(), nullable=False), + sa.Column("balance", sa.BigInteger(), nullable=False, server_default="0"), + sa.Column( + "last_update", + sa.TIMESTAMP(timezone=True), + nullable=False, + server_default=func.now(), + onupdate=func.now(), + ), + sa.PrimaryKeyConstraint("address", 
name="credit_balances_pkey"), + ) + op.create_index( + "ix_credit_balances_address", "credit_balances", ["address"], unique=False + ) diff --git a/src/aleph/db/accessors/balances.py b/src/aleph/db/accessors/balances.py index 85538a5c8..77f9945d2 100644 --- a/src/aleph/db/accessors/balances.py +++ b/src/aleph/db/accessors/balances.py @@ -5,11 +5,12 @@ from dataclasses import dataclass from decimal import Decimal from io import StringIO -from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple +from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union from aleph_message.models import Chain from sqlalchemy import func, select, text -from sqlalchemy.sql import Select +from sqlalchemy.dialects.postgresql import insert as pg_insert +from sqlalchemy.sql import ColumnElement, Select from aleph.db.models import AlephBalanceDb, AlephCreditBalanceDb, AlephCreditHistoryDb from aleph.toolkit.constants import ( @@ -227,270 +228,205 @@ def get_updated_balance_accounts(session: DbSession, last_update: dt.datetime): return (session.execute(select_stmt)).scalars().all() -@dataclass -class PositiveCredit: - amount: int - expiration_date: Optional[dt.datetime] - timestamp: dt.datetime - remaining: int - - -@dataclass -class NegativeAmount: - amount: int - timestamp: dt.datetime - - @dataclass class CreditBalanceDetail: expiration_date: Optional[dt.datetime] amount: int -def _apply_fifo_consumption( - session: DbSession, address: str, now: Optional[dt.datetime] = None -) -> list[PositiveCredit]: +def _insert_credit_lot( + session: DbSession, + address: str, + credit_ref: str, + credit_index: int, + amount: int, + expiration_date: Optional[dt.datetime], + message_timestamp: dt.datetime, +) -> None: + """Insert one cache row representing a granting ``credit_history`` entry. 
+ + ON CONFLICT DO NOTHING because ``(address, credit_ref, credit_index)`` is + the lot PK and message replay must be idempotent: re-applying the same + distribution / transfer-recipient row should be a no-op rather than blowing + up with a uniqueness violation. """ - Fetch all positive credits for the address and apply all existing debits via FIFO. + stmt = pg_insert(AlephCreditBalanceDb).values( + address=address, + credit_ref=credit_ref, + credit_index=credit_index, + amount_remaining=amount, + expiration_date=expiration_date, + message_timestamp=message_timestamp, + ) + stmt = stmt.on_conflict_do_nothing( + index_elements=[ + AlephCreditBalanceDb.address, + AlephCreditBalanceDb.credit_ref, + AlephCreditBalanceDb.credit_index, + ] + ) + session.execute(stmt) - Returns the list of PositiveCredit objects with `.remaining` reflecting how many - credits are still unconsumed after all past expenses/transfers. Expired credits are - included in the returned list — callers are responsible for filtering by expiration. + +def _consume_address_credits( + session: DbSession, + address: str, + amount: int, + message_timestamp: dt.datetime, +) -> List[Tuple[int, Optional[dt.datetime]]]: + """Drain ``amount`` from the address's still-valid lots in emission order. + + Returns ``(consumed_amount, source_expiration)`` per touched lot in + consumption order. Each touched lot has its ``amount_remaining`` decremented + in place. + + Emission order matches the historical FIFO: ``(message_timestamp, credit_ref, + credit_index) ASC``. ``message_timestamp`` is the cutoff for "still valid": + only lots with ``expiration_date IS NULL OR expiration_date > message_timestamp`` + are eligible. Using the message timestamp (not wall-clock now) keeps eager + writes consistent with the repair replay, which uses the historical timestamp + when reconstructing state from ``credit_history``. + + Lots are locked ``FOR UPDATE`` to serialise concurrent writers for the same + address. 
Over-draw silently drops the excess, matching the prior FIFO + behaviour gated by ``validate_credit_transfer_balance`` upstream. """ - now = now if now is not None else utc_now() + if amount <= 0: + return [] - records = ( + lots = ( session.execute( - select(AlephCreditHistoryDb) - .where(AlephCreditHistoryDb.address == address) - .order_by(AlephCreditHistoryDb.message_timestamp.asc()) + select(AlephCreditBalanceDb) + .where( + AlephCreditBalanceDb.address == address, + AlephCreditBalanceDb.amount_remaining > 0, + ( + AlephCreditBalanceDb.expiration_date.is_(None) + | (AlephCreditBalanceDb.expiration_date > message_timestamp) + ), + ) + .order_by( + AlephCreditBalanceDb.message_timestamp.asc(), + AlephCreditBalanceDb.credit_ref.asc(), + AlephCreditBalanceDb.credit_index.asc(), + ) + .with_for_update() ) .scalars() .all() ) - positive_credits: list[PositiveCredit] = [] - negative_amounts: list[NegativeAmount] = [] - - for record in records: - if record.amount > 0: - positive_credits.append( - PositiveCredit( - amount=record.amount, - expiration_date=record.expiration_date, - timestamp=record.message_timestamp, - remaining=record.amount, - ) - ) - else: - negative_amounts.append( - NegativeAmount( - amount=abs(record.amount), timestamp=record.message_timestamp - ) - ) - - for expense in negative_amounts: - remaining_expense = expense.amount - for credit in positive_credits: - if remaining_expense <= 0: - break - expense_valid = ( - credit.expiration_date is None - or expense.timestamp < credit.expiration_date - ) - if expense_valid and credit.remaining > 0: - consumed = min(credit.remaining, remaining_expense) - credit.remaining -= consumed - remaining_expense -= consumed - - return positive_credits + consumed_log: List[Tuple[int, Optional[dt.datetime]]] = [] + remaining = amount + for lot in lots: + if remaining <= 0: + break + take = min(lot.amount_remaining, remaining) + lot.amount_remaining -= take + remaining -= take + consumed_log.append((take, 
lot.expiration_date)) + session.flush() + return consumed_log def _compute_transfer_entries_by_expiration( - remaining_credits: list[PositiveCredit], - amount: int, + consumed_lots: List[Tuple[int, Optional[dt.datetime]]], requested_expiration: Optional[dt.datetime], - now: dt.datetime, -) -> list[tuple[int, Optional[dt.datetime]]]: - """ - Simulate consuming `amount` from `remaining_credits` (FIFO order) and return a list - of (portion_amount, effective_expiration) pairs. - - Credits are consumed in the same FIFO order used by the balance calculation, so the - expiration assignment for the recipient is consistent with the sender's accounting. - - The effective expiration for each portion is: - min(source_credit.expiration_date, requested_expiration) - where None means no expiration. +) -> List[Tuple[int, Optional[dt.datetime]]]: + """Cap each consumed portion's expiration at ``min(source, requested)``. - Adjacent portions with the same effective expiration are merged into one entry. - This prevents a re-transfer from extending or removing the original expiration - constraint placed on the source credits. + Adjacent portions with the same effective expiration are merged so the + recipient never sees more granularity than necessary. The cap rule prevents + a non-whitelisted re-transfer from extending or removing the original + expiration. Whitelisted senders skip this path entirely (see caller). 
""" - result: list[tuple[int, Optional[dt.datetime]]] = [] - remaining_to_consume = amount - - for credit in remaining_credits: - if remaining_to_consume <= 0: - break - if credit.expiration_date is not None and credit.expiration_date <= now: - continue - if credit.remaining <= 0: - continue - - consumed = min(credit.remaining, remaining_to_consume) - remaining_to_consume -= consumed - - # Effective expiration: most restrictive of source and requested - if credit.expiration_date is not None: - effective_exp: Optional[dt.datetime] = ( - credit.expiration_date - if requested_expiration is None - else min(credit.expiration_date, requested_expiration) - ) + result: List[Tuple[int, Optional[dt.datetime]]] = [] + for consumed, source_exp in consumed_lots: + if source_exp is None: + effective_exp: Optional[dt.datetime] = requested_expiration + elif requested_expiration is None: + effective_exp = source_exp else: - effective_exp = requested_expiration + effective_exp = min(source_exp, requested_expiration) - # Merge with previous entry if same effective expiration if result and result[-1][1] == effective_exp: result[-1] = (result[-1][0] + consumed, effective_exp) else: result.append((consumed, effective_exp)) - return result -def _calculate_credit_balance_fifo( +def _valid_lot_filter(cutoff: Union[dt.datetime, ColumnElement[dt.datetime]]): + return AlephCreditBalanceDb.expiration_date.is_(None) | ( + AlephCreditBalanceDb.expiration_date > cutoff + ) + + +def get_credit_balance( session: DbSession, address: str, now: Optional[dt.datetime] = None ) -> int: + """Sum of remaining amounts across still-valid lots for ``address``. + + Pure read: no FIFO walk, no write-back. Writers keep the cache up to date. """ - Calculate credit balance using FIFO consumption strategy. - - This function implements the core FIFO logic: - 1. Get all positive credits (ordered by message_timestamp) - 2. Get all negative amounts (expenses/transfers) - 3. 
Apply negative amounts to oldest credits first, but only if the expense - occurred before the credit's expiration date - 4. Return remaining balance considering current expiration status - """ - now = now if now is not None else utc_now() - positive_credits = _apply_fifo_consumption(session, address, now) - total_balance = sum( - c.remaining - for c in positive_credits - if c.expiration_date is None or c.expiration_date > now - ) - return max(0, total_balance) + cutoff = now if now is not None else func.now() + result = session.execute( + select(func.coalesce(func.sum(AlephCreditBalanceDb.amount_remaining), 0)).where( + AlephCreditBalanceDb.address == address, + _valid_lot_filter(cutoff), + ) + ).scalar() + return max(0, int(result or 0)) def get_credit_balance_with_details( session: DbSession, address: str, now: Optional[dt.datetime] = None ) -> Tuple[int, List[CreditBalanceDetail]]: + """Per-expiration breakdown of an address's still-valid remaining credit. + + Returns ``(total, details)`` with details sorted non-expiring first, then + by expiration ascending. Zero-amount lots are filtered out. """ - Calculate credit balance with a breakdown by expiration date. + cutoff = now if now is not None else func.now() + rows = session.execute( + select( + AlephCreditBalanceDb.expiration_date, + func.sum(AlephCreditBalanceDb.amount_remaining).label("amount"), + ) + .where( + AlephCreditBalanceDb.address == address, + AlephCreditBalanceDb.amount_remaining > 0, + _valid_lot_filter(cutoff), + ) + .group_by(AlephCreditBalanceDb.expiration_date) + ).all() - Returns (total_balance, details) where details is a list of - CreditBalanceDetail grouped by expiration_date, sorted with - non-expiring (None) first, then by expiration_date ascending. + pairs = [(row.expiration_date, int(row.amount)) for row in rows] + total = max(0, sum(amount for _, amount in pairs)) - Always recalculates (bypasses cache) since details are not cached. 
- """ - now = now if now is not None else utc_now() - positive_credits = _apply_fifo_consumption(session, address, now) - - details_map: Dict[Optional[dt.datetime], int] = {} - total_balance = 0 - for credit in positive_credits: - if credit.expiration_date is None or credit.expiration_date > now: - if credit.remaining > 0: - total_balance += credit.remaining - key = credit.expiration_date - details_map[key] = details_map.get(key, 0) + credit.remaining - - # Sort: non-expiring first (None), then by expiration_date ascending details = [ - CreditBalanceDetail(expiration_date=k, amount=v) - for k, v in sorted( - details_map.items(), + CreditBalanceDetail(expiration_date=exp, amount=amount) + for exp, amount in sorted( + pairs, key=lambda x: (x[0] is not None, x[0] or dt.datetime.min), ) ] - - return max(0, total_balance), details + return total, details -def get_credit_balance( - session: DbSession, address: str, now: Optional[dt.datetime] = None -) -> int: - """ - Get credit balance using lazy recalculation strategy. +def _credit_balance_amount_expr(): + """Reusable SQL expression: per-address sum of still-valid remaining credit. - 1. Check if cached balance exists in credit_balances table - 2. Check if credit_history has newer entries than cached balance - 3. Check if any credits have expiration dates that occurred after the cache's last update - 4. If recalculation is needed, recalculate using FIFO and update cache - 5. Return cached balance + Server-evaluated ``func.now()`` is used so the cutoff is fixed at statement + execution time, not Python expression-construction time. 
""" - - now = now if now is not None else utc_now() - - # Get the timestamp of the most recent credit history entry for this address - latest_history_timestamp = session.execute( - select(func.max(AlephCreditHistoryDb.last_update)).where( - AlephCreditHistoryDb.address == address - ) - ).scalar() - - # If no history exists, balance is 0 - if latest_history_timestamp is None: - return 0 - - # Get cached balance if it exists - cached_balance = session.execute( - select(AlephCreditBalanceDb).where(AlephCreditBalanceDb.address == address) - ).scalar_one_or_none() - - # Check if recalculation is needed - needs_recalculation = ( - cached_balance is None or cached_balance.last_update < latest_history_timestamp + return func.coalesce( + func.sum(AlephCreditBalanceDb.amount_remaining).filter( + _valid_lot_filter(func.now()) + ), + 0, ) - # Also check if any credits have expiration dates that occurred after the cache's last update - # This handles the case where credits expired since the last cache update - if not needs_recalculation and cached_balance is not None: - # Check for any credits with expiration dates between cache last_update and now - earliest_expiration_after_cache = session.execute( - select(func.min(AlephCreditHistoryDb.expiration_date)).where( - (AlephCreditHistoryDb.address == address) - & (AlephCreditHistoryDb.expiration_date.isnot(None)) - & (AlephCreditHistoryDb.expiration_date > cached_balance.last_update) - & (AlephCreditHistoryDb.expiration_date <= now) - ) - ).scalar() - - needs_recalculation = earliest_expiration_after_cache is not None - - if needs_recalculation: - # Recalculate balance using FIFO - new_balance = _calculate_credit_balance_fifo(session, address, now) - - if cached_balance is None: - # Create new cache entry - session.add( - AlephCreditBalanceDb( - address=address, balance=new_balance, last_update=now - ) - ) - else: - # Update existing cache entry - cached_balance.balance = new_balance - cached_balance.last_update = now - - 
session.flush() - return new_balance - - return cached_balance.balance if cached_balance else 0 - def get_credit_balances( session: DbSession, @@ -500,16 +436,17 @@ def get_credit_balances( after_address: Optional[str] = None, cursor_mode: bool = False, ) -> list[tuple[str, int]]: + """Paginated ``(address, balance)`` list across all addresses whose + still-valid sum is at least ``min_balance``. """ - Get paginated credit balances for all addresses. - Uses the cached balances from the credit_balances table. - """ - query = select(AlephCreditBalanceDb.address, AlephCreditBalanceDb.balance) - - if min_balance > 0: - query = query.filter(AlephCreditBalanceDb.balance >= min_balance) + balance_expr = _credit_balance_amount_expr().label("balance") - query = query.order_by(AlephCreditBalanceDb.address.asc()) + query = ( + select(AlephCreditBalanceDb.address, balance_expr) + .group_by(AlephCreditBalanceDb.address) + .having(balance_expr >= min_balance) + .order_by(AlephCreditBalanceDb.address.asc()) + ) if after_address is not None: query = query.where(AlephCreditBalanceDb.address > after_address) @@ -522,22 +459,19 @@ def get_credit_balances( if pagination: query = query.limit(pagination) - # Return results in the expected format (address, credits) - results = session.execute(query).all() - return [(row.address, row.balance) for row in results] + return [(row.address, int(row.balance)) for row in session.execute(query).all()] def count_credit_balances(session: DbSession, min_balance: int = 0) -> int: - """ - Count addresses with credit balances. - Uses the cached balances from the credit_balances table. 
- """ - query = select(func.count(AlephCreditBalanceDb.address)) - - if min_balance > 0: - query = query.filter(AlephCreditBalanceDb.balance >= min_balance) - - return session.execute(query).scalar_one() + """Count of addresses with lot rows whose still-valid sum is at least ``min_balance``.""" + balance_expr = _credit_balance_amount_expr().label("balance") + sub = ( + select(AlephCreditBalanceDb.address) + .group_by(AlephCreditBalanceDb.address) + .having(balance_expr >= min_balance) + .subquery() + ) + return session.execute(select(func.count()).select_from(sub)).scalar_one() def _format_csv_row(*fields) -> str: @@ -624,11 +558,8 @@ def update_credit_balances_distribution( message_hash: str, message_timestamp: dt.datetime, ) -> None: - """ - Updates credit balances for distribution messages (aleph_credit_distribution). - - Distribution messages include all fields like price, bonus_amount, tx_hash, provider, - payment_method, token, chain, and expiration_date. + """Apply a distribution message: insert one lot per recipient and append the + matching ``credit_history`` rows. 
""" last_update = utc_now() @@ -642,20 +573,28 @@ def update_credit_balances_distribution( tx_hash = credit_entry["tx_hash"] provider = credit_entry["provider"] - # Extract optional fields from each credit entry expiration_timestamp = credit_entry.get("expiration") or None origin = credit_entry.get("origin", "") origin_ref = credit_entry.get("ref", "") payment_method = credit_entry.get("payment_method", "") bonus_amount = credit_entry.get("bonus_amount", "") - # Convert expiration timestamp to datetime expiration_date = ( dt.datetime.fromtimestamp(expiration_timestamp / 1000, tz=dt.timezone.utc) if expiration_timestamp is not None else None ) + _insert_credit_lot( + session=session, + address=address, + credit_ref=message_hash, + credit_index=index, + amount=amount, + expiration_date=expiration_date, + message_timestamp=message_timestamp, + ) + csv_rows.append( _format_csv_row( address, @@ -686,15 +625,12 @@ def update_credit_balances_expense( message_hash: str, message_timestamp: dt.datetime, ) -> None: - """ - Updates credit balances for expense messages (aleph_credit_expense). - - Expense messages have negative amounts and can include: - - execution_id (mapped to origin) - - node_id (mapped to tx_hash) - - price (mapped to price) - - time (skipped for now) - - ref (mapped to origin_ref) + """Apply an expense message: drain the address's still-valid lots in emission + order, then append the matching ``credit_history`` rows. + + The history row's negative ``amount`` reflects the message intent. If the + address is under-funded, fewer credits are actually consumed (lots cannot go + negative), matching the prior FIFO behaviour. 
""" last_update = utc_now() @@ -703,19 +639,23 @@ def update_credit_balances_expense( for index, credit_entry in enumerate(credits_list): address = credit_entry["address"] raw_amount = int(credit_entry["amount"]) - amount = -_apply_credit_precision_multiplier(raw_amount, message_timestamp) + amount = _apply_credit_precision_multiplier(raw_amount, message_timestamp) origin_ref = credit_entry.get("ref", "") - - # Map new fields origin = credit_entry.get("execution_id", "") tx_hash = credit_entry.get("node_id", "") price = credit_entry.get("price", "") - # Skip time field for now + + _consume_address_credits( + session=session, + address=address, + amount=amount, + message_timestamp=message_timestamp, + ) csv_rows.append( _format_csv_row( address, - amount, + -amount, message_hash, index, message_timestamp, @@ -744,21 +684,13 @@ def update_credit_balances_transfer( message_hash: str, message_timestamp: dt.datetime, ) -> None: - """ - Updates credit balances for transfer messages (aleph_credit_transfer). - - Transfer messages involve two entries per transfer: - - One or more positive entries for the recipient (adding credits) - - One negative entry for the sender (subtracting credits) - - When a non-whitelisted sender re-transfers credits they received with an expiration - date, the recipient's credits are capped at the original expiration — preventing - bypass of expiration constraints through re-transfers. If the sender's credits have - mixed expirations, multiple positive entries are created for the recipient (one per - expiration group). + """Apply a transfer message: drain the sender's lots in emission order, grant + the resulting amounts to recipient(s) with each portion capped at + ``min(source_expiration, requested_expiration)``, and append the matching + ``credit_history`` rows. - Special case: If sender is in the whitelisted addresses, only add credits to recipient - using the requested expiration as-is (whitelisted senders create credits from nothing). 
+ Whitelisted senders create credits from nothing: the sender is not debited + and the recipient is granted ``amount`` with the requested expiration as-is. """ last_update = utc_now() @@ -766,18 +698,12 @@ def update_credit_balances_transfer( index = 0 is_whitelisted = sender_address in whitelisted_addresses - # Compute sender's remaining credits once for all entries in this transfer - sender_remaining: list[PositiveCredit] = [] - if not is_whitelisted: - sender_remaining = _apply_fifo_consumption(session, sender_address, last_update) - for credit_entry in credits_list: recipient_address = credit_entry["address"] raw_amount = int(credit_entry["amount"]) amount = _apply_credit_precision_multiplier(raw_amount, message_timestamp) expiration_timestamp = credit_entry.get("expiration") or None - # Convert expiration timestamp to datetime requested_expiration = ( dt.datetime.fromtimestamp(expiration_timestamp / 1000, tz=dt.timezone.utc) if expiration_timestamp is not None @@ -785,21 +711,38 @@ def update_credit_balances_transfer( ) if is_whitelisted: - # Whitelisted senders are not constrained by source credits - entries: list[tuple[int, Optional[dt.datetime]]] = [ + entries: List[Tuple[int, Optional[dt.datetime]]] = [ (amount, requested_expiration) ] else: + consumed = _consume_address_credits( + session=session, + address=sender_address, + amount=amount, + message_timestamp=message_timestamp, + ) entries = _compute_transfer_entries_by_expiration( - sender_remaining, amount, requested_expiration, last_update + consumed, requested_expiration ) - # Fallback for edge cases where sender credits are not tracked - # (e.g. whitelisted distributions not recorded in history) + # Production transfers are gated by validate_credit_transfer_balance, + # so consumed should sum to ``amount``. 
Fall back to a single + # ``(amount, requested_expiration)`` entry only if nothing was consumed + # (unfunded test scenarios, zero-amount transfers, or + # whitelisted distributions not tracked as lots) so the recipient + # still receives a history row matching the message intent. if not entries: entries = [(amount, requested_expiration)] - # Add positive entries for recipient (one per expiration group) for entry_amount, entry_expiration in entries: + _insert_credit_lot( + session=session, + address=recipient_address, + credit_ref=message_hash, + credit_index=index, + amount=entry_amount, + expiration_date=entry_expiration, + message_timestamp=message_timestamp, + ) csv_rows.append( _format_csv_row( recipient_address, @@ -822,8 +765,6 @@ def update_credit_balances_transfer( ) index += 1 - # Add negative entry for sender (unless sender is in whitelisted addresses) - # (origin = recipient, provider = ALEPH, payment_method = credit_transfer) if not is_whitelisted: csv_rows.append( _format_csv_row( diff --git a/src/aleph/db/models/balances.py b/src/aleph/db/models/balances.py index c4ef27d6b..c023010a2 100644 --- a/src/aleph/db/models/balances.py +++ b/src/aleph/db/models/balances.py @@ -70,10 +70,22 @@ class AlephCreditHistoryDb(Base): class AlephCreditBalanceDb(Base): + """Per-lot cache of remaining credit. 
One row per granting credit_history + entry; ``amount_remaining`` is decremented eagerly by expense and transfer + writers, so reads collapse to a SUM over still-valid lots.""" + __tablename__ = "credit_balances" - address: Mapped[str] = mapped_column(String, primary_key=True, index=True) - balance: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0) + address: Mapped[str] = mapped_column(String, primary_key=True) + credit_ref: Mapped[str] = mapped_column(String, primary_key=True) + credit_index: Mapped[int] = mapped_column(Integer, primary_key=True) + amount_remaining: Mapped[int] = mapped_column(BigInteger, nullable=False) + expiration_date: Mapped[Optional[dt.datetime]] = mapped_column( + TIMESTAMP(timezone=True), nullable=True + ) + message_timestamp: Mapped[dt.datetime] = mapped_column( + TIMESTAMP(timezone=True), nullable=False + ) last_update: Mapped[dt.datetime] = mapped_column( TIMESTAMP(timezone=True), nullable=False, diff --git a/src/aleph/repair.py b/src/aleph/repair.py index 2f939140a..ac2b67f7d 100644 --- a/src/aleph/repair.py +++ b/src/aleph/repair.py @@ -1,10 +1,11 @@ import logging from aleph_message.models import ItemHash -from sqlalchemy import select +from sqlalchemy import delete, select +from sqlalchemy.dialects.postgresql import insert as pg_insert from aleph.db.accessors.files import upsert_file -from aleph.db.models import StoredFileDb +from aleph.db.models import AlephCreditBalanceDb, AlephCreditHistoryDb, StoredFileDb from aleph.storage import StorageService from aleph.types.db_session import DbSession, DbSessionFactory @@ -46,6 +47,96 @@ async def _fix_file_sizes( ) +def _rebuild_credit_lots_for_address(session: DbSession, address: str) -> None: + """Replay ``credit_history`` chronologically and replace this address's + lot rows with the resulting state. + + Idempotent: clears existing lots for the address first, then rebuilds. Safe + to interrupt at address granularity (callers commit per-address). 
+ + Replays in emission order ``(message_timestamp, credit_ref, credit_index) + ASC``, the same ordering the eager writers see. Each positive history row + becomes a lot; each negative row drains lots in emission order, skipping + any whose expiration is at or before the negative row's + ``message_timestamp`` (so an expense from the past does not drain a lot + that had already expired at that moment). + """ + session.execute( + delete(AlephCreditBalanceDb).where(AlephCreditBalanceDb.address == address) + ) + + records = ( + session.execute( + select(AlephCreditHistoryDb) + .where(AlephCreditHistoryDb.address == address) + .order_by( + AlephCreditHistoryDb.message_timestamp.asc(), + AlephCreditHistoryDb.credit_ref.asc(), + AlephCreditHistoryDb.credit_index.asc(), + ) + ) + .scalars() + .all() + ) + + lots: list[dict] = [] + for record in records: + if record.amount > 0: + lots.append( + { + "credit_ref": record.credit_ref, + "credit_index": record.credit_index, + "amount_remaining": int(record.amount), + "expiration_date": record.expiration_date, + "message_timestamp": record.message_timestamp, + } + ) + else: + remaining = -int(record.amount) + for lot in lots: + if remaining <= 0: + break + if lot["amount_remaining"] <= 0: + continue + if ( + lot["expiration_date"] is not None + and lot["expiration_date"] <= record.message_timestamp + ): + continue + take = min(lot["amount_remaining"], remaining) + lot["amount_remaining"] -= take + remaining -= take + + rows = [{"address": address, **lot} for lot in lots if lot["amount_remaining"] > 0] + if rows: + session.execute(pg_insert(AlephCreditBalanceDb).values(rows)) + + +def _repair_credit_balances(session_factory: DbSessionFactory) -> None: + """Bootstrap or repair the credit_balances lot cache from credit_history. + + Rebuilds lots for every address that has any credit_history rows. 
Idempotent + and runs on every startup; after the initial bootstrap it is a bounded + full-table scan plus per-address rebuild, which is acceptable given typical + address counts. + """ + with session_factory() as session: + addresses = list( + session.execute(select(AlephCreditHistoryDb.address).distinct()).scalars() + ) + + LOGGER.info("Repairing credit_balances for %d address(es)", len(addresses)) + + for i, address in enumerate(addresses): + with session_factory() as session: + _rebuild_credit_lots_for_address(session, address) + session.commit() + if (i + 1) % 500 == 0: + LOGGER.info("Repaired %d / %d", i + 1, len(addresses)) + + LOGGER.info("Credit balances repair complete (%d address(es))", len(addresses)) + + async def repair_node( storage_service: StorageService, session_factory: DbSessionFactory ): @@ -53,3 +144,6 @@ async def repair_node( with session_factory() as session: await _fix_file_sizes(session, storage_service, store_files=True) session.commit() + + LOGGER.info("Repairing credit balances") + _repair_credit_balances(session_factory) diff --git a/tests/db/test_credit_balances.py b/tests/db/test_credit_balances.py index 8cbffe840..31e65e2b2 100644 --- a/tests/db/test_credit_balances.py +++ b/tests/db/test_credit_balances.py @@ -20,6 +20,7 @@ validate_credit_transfer_balance, ) from aleph.db.models import AlephCreditBalanceDb, AlephCreditHistoryDb +from aleph.repair import _rebuild_credit_lots_for_address from aleph.types.db_session import DbSessionFactory from aleph.types.sort_order import SortByCreditHistory, SortOrder @@ -893,42 +894,31 @@ def test_cache_invalidation_on_credit_expiration(session_factory: DbSessionFacto ) session.commit() - # Step 2: Simulate cache being calculated at T2 (before expiration) - # Mock utc_now to return cache_time during first balance calculation + # Read at T2 (before expiration): the lot is still valid. 
balance_before_expiration = get_credit_balance( session, "0xcache_bug_user", cache_time ) - session.commit() - - # Verify that at T2, the balance was 10000000 (1000 * 10000 multiplier, credit not yet expired) assert balance_before_expiration == 10000000 - # Verify that a cache entry was created and manually update its timestamp - # to simulate it being created at T2 (cache_time) - - cached_balance = session.execute( + # The eager write created the lot row at distribution time. With the + # lot cache, reads filter by expiration server-side; no write-back on + # the read path. + lot = session.execute( select(AlephCreditBalanceDb).where( AlephCreditBalanceDb.address == "0xcache_bug_user" ) ).scalar_one_or_none() + assert lot is not None + assert lot.amount_remaining == 10000000 - assert cached_balance is not None - assert cached_balance.balance == 10000000 - assert cached_balance.last_update == cache_time - - # Step 3: Now check balance at current time (T3, after expiration) - # The fix should detect that credit expired after cache update and recalculate + # Read at T3 (after expiration): same row, filtered out by the read's + # cutoff. The lot stays in the table; the balance returns 0. 
def _insert_credit_history_entries(session, entries: List[Dict[str, Any]]):
    """Seed credit_history rows, then rebuild the lot cache for their addresses.

    Every entry is turned into an ``AlephCreditHistoryDb`` row and added to
    the session. After a flush, the lots of each distinct ``"address"`` seen
    in ``entries`` are rebuilt with ``_rebuild_credit_lots_for_address``, so
    that tests inserting history rows directly still get matching lot rows.
    This helper does not commit.
    """
    touched_addresses = set()
    for row_kwargs in entries:
        session.add(AlephCreditHistoryDb(**row_kwargs))
        touched_addresses.add(row_kwargs["address"])

    # Flush first: the rebuild selects from credit_history through this session.
    session.flush()

    for owner in touched_addresses:
        _rebuild_credit_lots_for_address(session, owner)