Merged
101 changes: 101 additions & 0 deletions ai_docs/COMPILE_AST_QUERY_RELATIONSHIP.md
@@ -0,0 +1,101 @@
# Relationship between `compile_ast` and `compile_query`

## Overview

`compile_ast` and `compile_query` work together to convert the AST (Abstract Syntax Tree) into SQLAlchemy query objects:

- **`compile_ast`**: Recursively traverses the AST and builds up intermediate state
- **`compile_query`**: Converts the intermediate state into a concrete SQLAlchemy `Select` statement

## `compile_ast` - State Builder

**Signature:**
```python
compile_ast(nd: AstNode, needed_cols: dict[UUID, int])
-> tuple[sqa.Table, Query, dict[UUID, sqa.Label]]
```

**What it does:**
- Recursively processes the AST from leaves to root
- Incrementally builds up three pieces of state:
1. **`table`**: A SQLAlchemy `Table` or `Subquery` object representing the data source
2. **`query`**: A `Query` object containing metadata (select list, where clauses, group_by, etc.)
3. **`sqa_expr`**: A dictionary mapping column UUIDs to SQLAlchemy `Label` expressions

**Key behaviors:**
- Processes verbs by modifying the `query` and `sqa_expr` incrementally
- Tracks which columns are "needed" via `needed_cols` for subquery optimization
- Only calls `compile_query` when materializing a subquery (at `SubqueryMarker`)
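
The accumulation pattern can be sketched in plain Python. This is a toy model, not the library's actual code — the real `Query`, AST node types, and verb set differ; only the shape of the recursion (a leaf seeds the state, each verb folds into it) is the point:

```python
import uuid
from dataclasses import dataclass, field


@dataclass
class Query:
    # Toy stand-in for the real Query object; it only accumulates state.
    select: list = field(default_factory=list)
    where: list = field(default_factory=list)


def compile_ast(nd: dict, needed_cols: dict):
    if nd["verb"] == "table":
        # Leaf: a table scan seeds all three pieces of state.
        sqa_expr = {uuid.uuid4(): name for name in nd["cols"]}
        return nd["name"], Query(select=list(sqa_expr)), sqa_expr
    # Inner node: recurse first, then fold this verb into the state.
    table, query, sqa_expr = compile_ast(nd["child"], needed_cols)
    if nd["verb"] == "filter":
        query.where.append(nd["pred"])
    return table, query, sqa_expr


ast = {
    "verb": "filter",
    "pred": "x > 0",
    "child": {"verb": "table", "name": "t", "cols": ["x", "y"]},
}
table, query, sqa_expr = compile_ast(ast, {})
print(table, query.where, len(query.select))  # t ['x > 0'] 2
```

Note that `compile_query` is never called here; the state simply grows as verbs are processed.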

## `compile_query` - Query Materializer

**Signature:**
```python
compile_query(table: sqa.Table, query: Query, sqa_expr: dict[UUID, sqa.ColumnElement])
-> sqa.sql.Select
```

**What it does:**
- Takes the accumulated state from `compile_ast`
- Builds a concrete SQLAlchemy `Select` statement by:
1. Starting from `table.select().select_from(table)`
2. Adding WHERE clauses from `query.where`
3. Adding GROUP BY from `query.group_by`
4. Adding HAVING from `query.having`
5. Adding LIMIT/OFFSET from `query.limit`/`query.offset`
6. Adding ORDER BY from `query.order_by`
7. Finally selecting only the columns in `query.select` using `with_only_columns()`

## Important Assumptions

### 1. **State Consistency**
When calling `compile_query`, the state must be consistent:
- All UUIDs in `query.select` must exist as keys in `sqa_expr`
- All UUIDs referenced in `query.where`, `query.group_by`, `query.having`, `query.order_by` must exist in `sqa_expr`
- The `table` must be a valid SQLAlchemy Table/Subquery that contains or can reference the columns in `sqa_expr`
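
This invariant lends itself to a defensive check before materialization. The helper below is hypothetical (not part of the library) and uses a bare namespace in place of the real `Query`:

```python
import uuid
from types import SimpleNamespace


def check_state_consistency(query, sqa_expr: dict) -> None:
    # Every UUID the query references must resolve to a compiled expression.
    referenced = set(query.select) | set(query.group_by)
    missing = referenced - sqa_expr.keys()
    if missing:
        raise AssertionError(f"UUIDs missing from sqa_expr: {missing}")


u = uuid.uuid4()
check_state_consistency(SimpleNamespace(select=[u], group_by=[]), {u: "x"})  # consistent
try:
    check_state_consistency(SimpleNamespace(select=[u, uuid.uuid4()], group_by=[]), {u: "x"})
except AssertionError:
    print("inconsistent state detected")
```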

### 2. **When to Call `compile_query`**
- **Only call `compile_query` when you need to materialize a query** (e.g., for subqueries)
- During normal AST traversal, `compile_ast` just modifies the `query` and `sqa_expr` state
- Don't call `compile_query` prematurely; let `compile_ast` build up the full state first

### 3. **Column References**
- `sqa_expr` maps UUIDs to SQLAlchemy expressions that can be used in:
- SELECT clauses
- WHERE/HAVING predicates
- GROUP BY clauses
- ORDER BY clauses
- These expressions must be valid in the context of the `table`

### 4. **Query Object State**
The `Query` object accumulates state as verbs are processed:
- `query.select`: List of UUIDs to select (final output columns)
- `query.where`: List of predicates for WHERE clause
- `query.having`: List of predicates for HAVING clause
- `query.group_by`: List of UUIDs for GROUP BY
- `query.order_by`: List of Order objects for ORDER BY
- `query.limit`/`query.offset`: For LIMIT/OFFSET

### 5. **Subquery Handling**
When `compile_query` is called (typically at `SubqueryMarker`):
- It creates a `Select` statement
- That `Select` is converted to a subquery via `.subquery()`
- The `sqa_expr` is updated to reference columns from the subquery
- The `query` is reset to only select the needed columns
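
The re-pointing step can be sketched with toy strings in place of SQLAlchemy objects (the helper name and the `sq` alias are illustrative, not the library's actual code):

```python
import uuid


def materialize_subquery(select_sql: str, sqa_expr: dict, needed: set):
    # Wrap the statement as a named subquery, re-point every needed UUID at
    # the subquery's columns, and reset the select list to just those.
    alias = "sq"
    table = f"({select_sql}) AS {alias}"
    new_expr = {uid: f"{alias}.{name}" for uid, name in sqa_expr.items() if uid in needed}
    return table, new_expr, list(new_expr)


u1, u2 = uuid.uuid4(), uuid.uuid4()
table, expr, sel = materialize_subquery(
    "SELECT x, y FROM t", {u1: "x", u2: "y"}, needed={u1}
)
print(table)                 # (SELECT x, y FROM t) AS sq
print(list(expr.values()))   # ['sq.x']
```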

## Example: Union Implementation

In the Union implementation, we:
1. Call `compile_ast` on both left and right to get their state
2. Call `compile_query` on both to get their `Select` statements
3. Use SQLAlchemy's `union`/`union_all` to combine them
4. Convert the result to a subquery

**Important**: Before calling `sa.union`, we check whether either side is already a `Subquery` and unwrap it via `.original` to get the underlying `CompoundSelect`, because two subqueries can't be unioned directly; the union has to be built from the original `CompoundSelect` objects.
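
The unwrap check can be illustrated with a toy model of the wrapper relationship (the class names echo SQLAlchemy's `Subquery`/`CompoundSelect`, but this is a sketch, not SQLAlchemy code):

```python
class CompoundSelect:
    """Toy stand-in for a SELECT or UNION statement."""

    def __init__(self, sql: str):
        self.sql = sql


class Subquery:
    """Toy stand-in: wraps a statement, keeping it reachable via .original."""

    def __init__(self, original: CompoundSelect):
        self.original = original


def union_all(left, right) -> CompoundSelect:
    # Subqueries cannot be unioned directly; recover the wrapped statements.
    def unwrap(stmt):
        return stmt.original if isinstance(stmt, Subquery) else stmt

    left, right = unwrap(left), unwrap(right)
    return CompoundSelect(f"({left.sql}) UNION ALL ({right.sql})")


lhs = Subquery(CompoundSelect("SELECT a FROM t1"))
rhs = CompoundSelect("SELECT a FROM t2")
print(union_all(lhs, rhs).sql)
# (SELECT a FROM t1) UNION ALL (SELECT a FROM t2)
```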

## Common Pitfalls

1. **Calling `compile_query` too early**: Don't call it during AST traversal unless you're materializing a subquery
2. **Inconsistent state**: Make sure all UUIDs in `query` exist in `sqa_expr`
3. **Missing column references**: Ensure columns referenced in WHERE/HAVING/etc. are in `sqa_expr`
4. **Subquery unwrapping**: When combining queries (like UNION), unwrap subqueries to get the original `CompoundSelect`
3 changes: 3 additions & 0 deletions docs/source/changelog.md
@@ -1,5 +1,8 @@
# Changelog

## 0.6.3 (2025-12-17)
- implement `tbl1 >> union(tbl2)` and `union(tbl1, tbl2)`

## 0.6.2 (2025-12-15)
- drop support for python 3.10
- fix is_sql_backend(Table) and backend(Table)
1 change: 1 addition & 0 deletions docs/source/reference/verbs.rst
@@ -30,3 +30,4 @@ Verbs
slice_head
summarize
ungroup
union
52 changes: 11 additions & 41 deletions fuzz.py
@@ -31,11 +31,7 @@
pdt.Float(): rng.standard_normal,
pdt.Int(): partial(rng.integers, -(1 << 13), 1 << 13),
pdt.Bool(): partial(rng.integers, 0, 1, dtype=bool),
pdt.String(): (
lambda rows: np.array(
["".join(random.choices(letters, k=rng.poisson(10))) for _ in range(rows)]
)
),
pdt.String(): (lambda rows: np.array(["".join(random.choices(letters, k=rng.poisson(10))) for _ in range(rows)])),
}


@@ -45,45 +41,29 @@ def gen_table(rows: int, types: dict[pdt.Dtype, int]) -> pl.DataFrame:
for ty, fn in RNG_FNS.items():
if ty in types:
d = d.with_columns(
**{
f"{ty.__class__.__name__.lower()} #{i + 1}": pl.lit(fn(rows))
for i in range(types[ty])
}
**{f"{ty.__class__.__name__.lower()} #{i + 1}": pl.lit(fn(rows)) for i in range(types[ty])}
)

return d


ops_with_return_type: dict[pdt.Dtype, list[tuple[Operator, Signature]]] = {
ty: [] for ty in ALL_TYPES
}
ops_with_return_type: dict[pdt.Dtype, list[tuple[Operator, Signature]]] = {ty: [] for ty in ALL_TYPES}

for op in ops.__dict__.values():
if (
not isinstance(op, Operator)
or op.ftype != Ftype.ELEMENT_WISE
or isinstance(op, Marker)
):
if not isinstance(op, Operator) or op.ftype != Ftype.ELEMENT_WISE or isinstance(op, Marker):
continue
for sig in op.signatures:
if not all(
t in (*ALL_TYPES, Tyvar("T")) for t in (*sig.types, sig.return_type)
):
if not all(t in (*ALL_TYPES, Tyvar("T")) for t in (*sig.types, sig.return_type)):
continue

if isinstance(sig.return_type, Tyvar) or any(
isinstance(param, Tyvar) for param in sig.types
):
if isinstance(sig.return_type, Tyvar) or any(isinstance(param, Tyvar) for param in sig.types):
for ty in ALL_TYPES:
rtype = ty if isinstance(sig.return_type, Tyvar) else sig.return_type
ops_with_return_type[rtype].append(
(
op,
Signature(
*(
ty if isinstance(param, Tyvar) else param
for param in sig.types
),
*(ty if isinstance(param, Tyvar) else param for param in sig.types),
return_type=rtype,
),
)
@@ -92,9 +72,7 @@ def gen_table(rows: int, types: dict[pdt.Dtype, int]) -> pl.DataFrame:
ops_with_return_type[sig.return_type].append((op, sig))


def gen_expr(
dtype: pdt.Dtype, cols: dict[pdt.Dtype, list[str]], q: float = 0.0
) -> pdt.ColExpr:
def gen_expr(dtype: pdt.Dtype, cols: dict[pdt.Dtype, list[str]], q: float = 0.0) -> pdt.ColExpr:
if dtype.const:
return RNG_FNS[dtype.without_const()](1).item()

@@ -114,9 +92,7 @@ def gen_expr(
if sig.is_vararg:
nargs = int(rng.normal(2.5, 1 / 1.5))
for _ in range(nargs):
args.append(
gen_expr(sig.types[-1], cols, q + rng.exponential(1 / MEAN_HEIGHT))
)
args.append(gen_expr(sig.types[-1], cols, q + rng.exponential(1 / MEAN_HEIGHT)))

return ColFn(op, *args)

@@ -132,16 +108,10 @@ def gen_expr(


tables = {backend: fn(df, "t") for backend, fn in BACKEND_TABLES.items()}
cols = {
dtype: [col.name for col in tables["polars"] if col.dtype() <= dtype]
for dtype in ALL_TYPES
}
cols = {dtype: [col.name for col in tables["polars"] if col.dtype() <= dtype] for dtype in ALL_TYPES}

for _ in range(it):
expr = gen_expr(rng.choice(ALL_TYPES), cols)
results = {
backend: table >> mutate(y=expr) >> select(C.y) >> export(Polars())
for backend, table in tables.items()
}
results = {backend: table >> mutate(y=expr) >> select(C.y) >> export(Polars()) for backend, table in tables.items()}
for _backend, res in results:
assert_frame_equal(results["polars"], res)
57 changes: 13 additions & 44 deletions generate_col_ops.py
@@ -38,36 +38,26 @@ def add_vararg_star(formatted_args: str) -> str:


def type_annotation(dtype: Dtype, specialize_generic: bool) -> str:
if (not specialize_generic and not types.is_const(dtype)) or isinstance(
types.without_const(dtype), Tyvar
):
if (not specialize_generic and not types.is_const(dtype)) or isinstance(types.without_const(dtype), Tyvar):
return "ColExpr"
if types.is_const(dtype):
python_type = types.to_python(dtype)
return python_type.__name__ if python_type is not NoneType else "None"
return f"ColExpr[{dtype.__class__.__name__}]"


def generate_fn_decl(
op: Operator, sig: Signature, *, name=None, specialize_generic: bool = True
) -> str:
def generate_fn_decl(op: Operator, sig: Signature, *, name=None, specialize_generic: bool = True) -> str:
if name is None:
name = op.name

defaults: Iterable = (
op.default_values
if op.default_values is not None
else (... for _ in op.param_names)
)
defaults: Iterable = op.default_values if op.default_values is not None else (... for _ in op.param_names)

annotated_args = ", ".join(
name
+ ": "
+ type_annotation(dtype, specialize_generic)
+ (f" = {repr(default_val)}" if default_val is not ... else "")
for dtype, name, default_val in zip(
sig.types, op.param_names, defaults, strict=True
)
for dtype, name, default_val in zip(sig.types, op.param_names, defaults, strict=True)
)
if sig.is_vararg:
annotated_args = add_vararg_star(annotated_args)
@@ -80,8 +70,7 @@ def generate_fn_decl(
}

annotated_kwargs = "".join(
f", {kwarg.name}: {context_kwarg_annotation[kwarg.name]}"
+ f"{'' if kwarg.required else ' | None = None'}"
f", {kwarg.name}: {context_kwarg_annotation[kwarg.name]}" + f"{'' if kwarg.required else ' | None = None'}"
for kwarg in op.context_kwargs
)

@@ -93,8 +82,7 @@
annotated_kwargs = ""

return (
f"def {name}({annotated_args}{annotated_kwargs}) "
f"-> {type_annotation(sig.return_type, specialize_generic)}:\n"
f"def {name}({annotated_args}{annotated_kwargs}) -> {type_annotation(sig.return_type, specialize_generic)}:\n"
)


@@ -126,9 +114,7 @@ def generate_fn_body(
return f" return ColFn(ops.{op_var_name}{args}{kwargs})\n\n"


def generate_overloads(
op: Operator, *, name: str | None = None, rversion: bool = False, op_var_name: str
):
def generate_overloads(op: Operator, *, name: str | None = None, rversion: bool = False, op_var_name: str):
res = ""
in_namespace = "." in op.name
if name is None:
@@ -140,9 +126,7 @@
res += "@overload\n" + generate_fn_decl(op, sig, name=name) + " ...\n\n"

res += (
generate_fn_decl(
op, op.signatures[0], name=name, specialize_generic=not has_overloads
)
generate_fn_decl(op, op.signatures[0], name=name, specialize_generic=not has_overloads)
+ f' """\n{op.doc.strip()}\n"""\n\n'
+ generate_fn_body(
op,
@@ -177,13 +161,8 @@ def indent(s: str, by: int) -> str:
" # --- generated code starts here, do not delete this comment ---"
):
in_generated_section = True
new_file_contents += (
" # --- generated code starts here, do not delete this "
"comment ---\n\n"
)
elif in_generated_section and line.startswith(
"# --- generated code ends here, do not delete this comment ---"
):
new_file_contents += " # --- generated code starts here, do not delete this comment ---\n\n"
elif in_generated_section and line.startswith("# --- generated code ends here, do not delete this comment ---"):
for op_var_name in sorted(ops.__dict__):
op = ops.__dict__[op_var_name]
if not isinstance(op, Operator) or not op.generate_expr_method:
@@ -208,11 +187,7 @@ def indent(s: str, by: int) -> str:
for name in NAMESPACES:
new_file_contents += f" {name} : {name.title()}Namespace\n"

new_file_contents += (
"@dataclasses.dataclass(slots=True)\n"
"class FnNamespace:\n"
" arg: ColExpr\n"
)
new_file_contents += "@dataclasses.dataclass(slots=True)\nclass FnNamespace:\n arg: ColExpr\n"

for name in NAMESPACES:
new_file_contents += namespace_contents[name]
@@ -234,9 +209,7 @@ def indent(s: str, by: int) -> str:

for line in file:
new_file_contents += line
if line.startswith(
"# --- from here the code is generated, do not delete this comment ---"
):
if line.startswith("# --- from here the code is generated, do not delete this comment ---"):
new_file_contents += "\n"
for op_var_name in sorted(ops.__dict__):
op = ops.__dict__[op_var_name]
@@ -265,11 +238,7 @@ def indent(s: str, by: int) -> str:

new_file_contents += "\n ".join(
sorted(
[
op.name
for op in ops.__dict__.values()
if isinstance(op, Operator) and op.generate_expr_method
]
[op.name for op in ops.__dict__.values() if isinstance(op, Operator) and op.generate_expr_method]
+ ["rank", "dense_rank", "map", "cast"]
)
)
4 changes: 3 additions & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "pydiverse-transform"
version = "0.6.2"
version = "0.6.3"
description = "Pipe based dataframe manipulation library that can also transform data on SQL databases"
authors = [
{ name = "QuantCo, Inc." },
@@ -42,6 +42,8 @@ packages = ["src/pydiverse"]
target-version = "py311"
extend-exclude = ["docs/*"]
fix = true
# 88 leads to worse comments
line-length = 120

[tool.ruff.lint]
select = ["F", "E", "UP", "W", "I001", "I002", "B", "A"]
5 changes: 1 addition & 4 deletions src/pydiverse/transform/__init__.py
@@ -13,8 +13,5 @@
from .version import __version__

__all__ = (
["__version__", "Table", "ColExpr", "Col", "verb", "backend", "is_sql_backed"]
+ __extended
+ __types
+ __errors
["__version__", "Table", "ColExpr", "Col", "verb", "backend", "is_sql_backed"] + __extended + __types + __errors
)