Skip to content

Commit 07ba1a2

Browse files
committed
server/order: store billing address
1 parent f27b363 commit 07ba1a2

File tree

5 files changed

+308
-1
lines changed

5 files changed

+308
-1
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
"""Add Order.billing_address
2+
3+
Revision ID: 1769a6e618a4
4+
Revises: 6cbeabf73caf
5+
Create Date: 2024-11-26 14:44:03.569035
6+
7+
"""
8+
9+
import concurrent.futures
10+
import random
11+
import time
12+
from typing import Any, TypedDict, cast
13+
14+
import sqlalchemy as sa
15+
import stripe as stripe_lib
16+
from alembic import op
17+
from pydantic import ValidationError
18+
19+
from polar import payment_method
20+
from polar.config import settings
21+
22+
# Polar Custom Imports
23+
from polar.integrations.stripe.utils import get_expandable_id
24+
from polar.kit.address import Address, AddressType
25+
26+
# revision identifiers, used by Alembic.
27+
revision = "1769a6e618a4"
28+
down_revision = "6cbeabf73caf"
29+
branch_labels: tuple[str] | None = None
30+
depends_on: tuple[str] | None = None
31+
32+
33+
stripe_client = stripe_lib.StripeClient(
34+
settings.STRIPE_SECRET_KEY,
35+
http_client=stripe_lib.HTTPXClient(allow_sync_methods=True),
36+
)
37+
38+
39+
class MigratedOrder(TypedDict):
40+
order_id: str
41+
amount: int
42+
billing_address: dict[str, Any] | None
43+
44+
45+
def _is_empty_customer_address(customer_address: dict[str, Any] | None) -> bool:
46+
return customer_address is None or customer_address["country"] is None
47+
48+
49+
def migrate_order(
50+
order: tuple[str, int, str | None, str | None], retry: int = 1
51+
) -> MigratedOrder:
52+
order_id, amount, stripe_invoice_id, stripe_charge_id = order
53+
54+
if stripe_invoice_id is None and stripe_charge_id is None:
55+
raise ValueError(f"No invoice or charge: {order_id}")
56+
57+
customer_address: Any | None = None
58+
try:
59+
# Get from invoice
60+
if stripe_invoice_id is not None:
61+
invoice = stripe_client.invoices.retrieve(stripe_invoice_id)
62+
customer_address = invoice.customer_address
63+
# No address on invoice, try to get from charge
64+
if (
65+
_is_empty_customer_address(customer_address)
66+
and invoice.charge is not None
67+
):
68+
return migrate_order(
69+
(order_id, amount, None, get_expandable_id(invoice.charge))
70+
)
71+
# Get from charge
72+
elif stripe_charge_id is not None:
73+
charge = stripe_client.charges.retrieve(
74+
stripe_charge_id,
75+
params={
76+
"expand": ["payment_method_details", "payment_method_details.card"]
77+
},
78+
)
79+
customer_address = charge.billing_details.address
80+
# No address on charge, try to get from payment method
81+
if _is_empty_customer_address(customer_address):
82+
if payment_method_details := charge.payment_method_details:
83+
if card := getattr(payment_method_details, "card", None):
84+
customer_address = {"country": card.country}
85+
except stripe_lib.RateLimitError:
86+
time.sleep(retry + random.random())
87+
return migrate_order(order, retry=retry + 1)
88+
89+
billing_address: dict[str, Any] | None = None
90+
if not _is_empty_customer_address(customer_address):
91+
try:
92+
billing_address = (
93+
Address.model_validate(customer_address).model_dump()
94+
if customer_address
95+
else None
96+
)
97+
except ValidationError as e:
98+
raise ValueError(f"Invalid address for order {order_id}: {e}")
99+
100+
return {"order_id": order_id, "amount": amount, "billing_address": billing_address}
101+
102+
103+
def migrate_orders(
104+
results: sa.CursorResult[tuple[str, int, str | None, str | None]],
105+
) -> list[MigratedOrder]:
106+
migrated_orders: list[MigratedOrder] = []
107+
with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor:
108+
futures = [executor.submit(migrate_order, order._tuple()) for order in results]
109+
for future in concurrent.futures.as_completed(futures):
110+
migrated_orders.append(future.result())
111+
return migrated_orders
112+
113+
114+
def upgrade() -> None:
115+
# ### commands auto generated by Alembic - please adjust! ###
116+
op.add_column(
117+
"orders",
118+
sa.Column(
119+
"billing_address",
120+
AddressType(astext_type=sa.Text()),
121+
nullable=True,
122+
),
123+
)
124+
125+
connection = op.get_bind()
126+
orders = connection.execute(
127+
sa.text("""
128+
SELECT orders.id, orders.amount, orders.stripe_invoice_id, orders.user_metadata->>'charge_id' AS stripe_charge_id
129+
FROM orders
130+
""")
131+
)
132+
migrated_orders = migrate_orders(orders)
133+
for migrated_order in migrated_orders:
134+
if migrated_order["billing_address"] is None:
135+
if migrated_order["amount"] != 0:
136+
print("No billing address for paid order", migrated_order["order_id"]) # noqa: T201
137+
continue
138+
op.execute(
139+
sa.text(
140+
"""
141+
UPDATE orders
142+
SET billing_address = :billing_address
143+
WHERE id = :order_id
144+
"""
145+
).bindparams(
146+
sa.bindparam(
147+
"billing_address", migrated_order["billing_address"], type_=sa.JSON
148+
),
149+
order_id=migrated_order["order_id"],
150+
)
151+
)
152+
153+
# ### end Alembic commands ###
154+
155+
156+
def downgrade() -> None:
157+
# ### commands auto generated by Alembic - please adjust! ###
158+
op.drop_column("orders", "billing_address")
159+
# ### end Alembic commands ###

server/polar/models/order.py

+2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from sqlalchemy.orm import Mapped, declared_attr, mapped_column, relationship
88

99
from polar.custom_field.data import CustomFieldDataMixin
10+
from polar.kit.address import Address, AddressType
1011
from polar.kit.db.models import RecordModel
1112
from polar.kit.metadata import MetadataMixin
1213

@@ -38,6 +39,7 @@ class Order(CustomFieldDataMixin, MetadataMixin, RecordModel):
3839
billing_reason: Mapped[OrderBillingReason] = mapped_column(
3940
String, nullable=False, index=True
4041
)
42+
billing_address: Mapped[Address | None] = mapped_column(AddressType, nullable=True)
4143
stripe_invoice_id: Mapped[str | None] = mapped_column(
4244
String, nullable=True, unique=True
4345
)

server/polar/order/schemas.py

+2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from polar.discount.schemas import (
88
DiscountMinimal,
99
)
10+
from polar.kit.address import Address
1011
from polar.kit.metadata import MetadataOutputMixin
1112
from polar.kit.schemas import IDSchema, MergeJSONSchema, Schema, TimestampedSchema
1213
from polar.models.order import OrderBillingReason
@@ -21,6 +22,7 @@ class OrderBase(
2122
tax_amount: int
2223
currency: str
2324
billing_reason: OrderBillingReason
25+
billing_address: Address | None
2426

2527
user_id: UUID4
2628
product_id: UUID4

server/polar/order/service.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from polar.integrations.stripe.schemas import ProductType
2121
from polar.integrations.stripe.service import stripe as stripe_service
2222
from polar.integrations.stripe.utils import get_expandable_id
23+
from polar.kit.address import Address
2324
from polar.kit.db.postgres import AsyncSession
2425
from polar.kit.pagination import PaginationParams, paginate
2526
from polar.kit.services import ResourceServiceReader
@@ -149,6 +150,10 @@ def __init__(self, order: Order) -> None:
149150
super().__init__(message, 404)
150151

151152

153+
def _is_empty_customer_address(customer_address: dict[str, Any] | None) -> bool:
154+
return customer_address is None or customer_address["country"] is None
155+
156+
152157
class OrderService(ResourceServiceReader[Order]):
153158
async def list(
154159
self,
@@ -303,6 +308,16 @@ async def create_order_from_stripe(
303308

304309
product = product_price.product
305310

311+
billing_address: Address | None = None
312+
if not _is_empty_customer_address(invoice.customer_address):
313+
billing_address = Address.model_validate(invoice.customer_address)
314+
# Try to retrieve the country from the payment method
315+
elif invoice.charge is not None:
316+
charge = await stripe_service.get_charge(get_expandable_id(invoice.charge))
317+
if payment_method_details := charge.payment_method_details:
318+
if card := getattr(payment_method_details, "card", None):
319+
billing_address = Address.model_validate({"country": card.country})
320+
306321
# Get Discount if available
307322
discount: Discount | None = None
308323
if invoice.discount is not None:
@@ -329,7 +344,6 @@ async def create_order_from_stripe(
329344
user: User | None = None
330345

331346
billing_reason: OrderBillingReason = OrderBillingReason.purchase
332-
333347
tax = invoice.tax or 0
334348
amount = invoice.total - tax
335349

@@ -368,6 +382,7 @@ async def create_order_from_stripe(
368382
tax_amount=tax,
369383
currency=invoice.currency,
370384
billing_reason=billing_reason,
385+
billing_address=billing_address,
371386
stripe_invoice_id=invoice.id,
372387
product=product,
373388
product_price=product_price,

0 commit comments

Comments
 (0)