|
| 1 | +"""Add Order.billing_address |
| 2 | +
|
| 3 | +Revision ID: 1769a6e618a4 |
| 4 | +Revises: 6cbeabf73caf |
| 5 | +Create Date: 2024-11-26 14:44:03.569035 |
| 6 | +
|
| 7 | +""" |
| 8 | + |
| 9 | +import concurrent.futures |
| 10 | +import random |
| 11 | +import time |
| 12 | +from typing import Any, TypedDict, cast |
| 13 | + |
| 14 | +import sqlalchemy as sa |
| 15 | +import stripe as stripe_lib |
| 16 | +from alembic import op |
| 17 | +from pydantic import ValidationError |
| 18 | + |
| 19 | +from polar import payment_method |
| 20 | +from polar.config import settings |
| 21 | + |
| 22 | +# Polar Custom Imports |
| 23 | +from polar.integrations.stripe.utils import get_expandable_id |
| 24 | +from polar.kit.address import Address, AddressType |
| 25 | + |
| 26 | +# revision identifiers, used by Alembic. |
| 27 | +revision = "1769a6e618a4" |
| 28 | +down_revision = "6cbeabf73caf" |
| 29 | +branch_labels: tuple[str] | None = None |
| 30 | +depends_on: tuple[str] | None = None |
| 31 | + |
| 32 | + |
| 33 | +stripe_client = stripe_lib.StripeClient( |
| 34 | + settings.STRIPE_SECRET_KEY, |
| 35 | + http_client=stripe_lib.HTTPXClient(allow_sync_methods=True), |
| 36 | +) |
| 37 | + |
| 38 | + |
| 39 | +class MigratedOrder(TypedDict): |
| 40 | + order_id: str |
| 41 | + amount: int |
| 42 | + billing_address: dict[str, Any] | None |
| 43 | + |
| 44 | + |
| 45 | +def _is_empty_customer_address(customer_address: dict[str, Any] | None) -> bool: |
| 46 | + return customer_address is None or customer_address["country"] is None |
| 47 | + |
| 48 | + |
| 49 | +def migrate_order( |
| 50 | + order: tuple[str, int, str | None, str | None], retry: int = 1 |
| 51 | +) -> MigratedOrder: |
| 52 | + order_id, amount, stripe_invoice_id, stripe_charge_id = order |
| 53 | + |
| 54 | + if stripe_invoice_id is None and stripe_charge_id is None: |
| 55 | + raise ValueError(f"No invoice or charge: {order_id}") |
| 56 | + |
| 57 | + customer_address: Any | None = None |
| 58 | + try: |
| 59 | + # Get from invoice |
| 60 | + if stripe_invoice_id is not None: |
| 61 | + invoice = stripe_client.invoices.retrieve(stripe_invoice_id) |
| 62 | + customer_address = invoice.customer_address |
| 63 | + # No address on invoice, try to get from charge |
| 64 | + if ( |
| 65 | + _is_empty_customer_address(customer_address) |
| 66 | + and invoice.charge is not None |
| 67 | + ): |
| 68 | + return migrate_order( |
| 69 | + (order_id, amount, None, get_expandable_id(invoice.charge)) |
| 70 | + ) |
| 71 | + # Get from charge |
| 72 | + elif stripe_charge_id is not None: |
| 73 | + charge = stripe_client.charges.retrieve( |
| 74 | + stripe_charge_id, |
| 75 | + params={ |
| 76 | + "expand": ["payment_method_details", "payment_method_details.card"] |
| 77 | + }, |
| 78 | + ) |
| 79 | + customer_address = charge.billing_details.address |
| 80 | + # No address on charge, try to get from payment method |
| 81 | + if _is_empty_customer_address(customer_address): |
| 82 | + if payment_method_details := charge.payment_method_details: |
| 83 | + if card := getattr(payment_method_details, "card", None): |
| 84 | + customer_address = {"country": card.country} |
| 85 | + except stripe_lib.RateLimitError: |
| 86 | + time.sleep(retry + random.random()) |
| 87 | + return migrate_order(order, retry=retry + 1) |
| 88 | + |
| 89 | + billing_address: dict[str, Any] | None = None |
| 90 | + if not _is_empty_customer_address(customer_address): |
| 91 | + try: |
| 92 | + billing_address = ( |
| 93 | + Address.model_validate(customer_address).model_dump() |
| 94 | + if customer_address |
| 95 | + else None |
| 96 | + ) |
| 97 | + except ValidationError as e: |
| 98 | + raise ValueError(f"Invalid address for order {order_id}: {e}") |
| 99 | + |
| 100 | + return {"order_id": order_id, "amount": amount, "billing_address": billing_address} |
| 101 | + |
| 102 | + |
| 103 | +def migrate_orders( |
| 104 | + results: sa.CursorResult[tuple[str, int, str | None, str | None]], |
| 105 | +) -> list[MigratedOrder]: |
| 106 | + migrated_orders: list[MigratedOrder] = [] |
| 107 | + with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor: |
| 108 | + futures = [executor.submit(migrate_order, order._tuple()) for order in results] |
| 109 | + for future in concurrent.futures.as_completed(futures): |
| 110 | + migrated_orders.append(future.result()) |
| 111 | + return migrated_orders |
| 112 | + |
| 113 | + |
| 114 | +def upgrade() -> None: |
| 115 | + # ### commands auto generated by Alembic - please adjust! ### |
| 116 | + op.add_column( |
| 117 | + "orders", |
| 118 | + sa.Column( |
| 119 | + "billing_address", |
| 120 | + AddressType(astext_type=sa.Text()), |
| 121 | + nullable=True, |
| 122 | + ), |
| 123 | + ) |
| 124 | + |
| 125 | + connection = op.get_bind() |
| 126 | + orders = connection.execute( |
| 127 | + sa.text(""" |
| 128 | + SELECT orders.id, orders.amount, orders.stripe_invoice_id, orders.user_metadata->>'charge_id' AS stripe_charge_id |
| 129 | + FROM orders |
| 130 | + """) |
| 131 | + ) |
| 132 | + migrated_orders = migrate_orders(orders) |
| 133 | + for migrated_order in migrated_orders: |
| 134 | + if migrated_order["billing_address"] is None: |
| 135 | + if migrated_order["amount"] != 0: |
| 136 | + print("No billing address for paid order", migrated_order["order_id"]) # noqa: T201 |
| 137 | + continue |
| 138 | + op.execute( |
| 139 | + sa.text( |
| 140 | + """ |
| 141 | + UPDATE orders |
| 142 | + SET billing_address = :billing_address |
| 143 | + WHERE id = :order_id |
| 144 | + """ |
| 145 | + ).bindparams( |
| 146 | + sa.bindparam( |
| 147 | + "billing_address", migrated_order["billing_address"], type_=sa.JSON |
| 148 | + ), |
| 149 | + order_id=migrated_order["order_id"], |
| 150 | + ) |
| 151 | + ) |
| 152 | + |
| 153 | + # ### end Alembic commands ### |
| 154 | + |
| 155 | + |
| 156 | +def downgrade() -> None: |
| 157 | + # ### commands auto generated by Alembic - please adjust! ### |
| 158 | + op.drop_column("orders", "billing_address") |
| 159 | + # ### end Alembic commands ### |
0 commit comments