diff --git a/ai_docs/COMPILE_AST_QUERY_RELATIONSHIP.md b/ai_docs/COMPILE_AST_QUERY_RELATIONSHIP.md new file mode 100644 index 00000000..422c0771 --- /dev/null +++ b/ai_docs/COMPILE_AST_QUERY_RELATIONSHIP.md @@ -0,0 +1,101 @@ +# Relationship between `compile_ast` and `compile_query` + +## Overview + +`compile_ast` and `compile_query` work together to convert the AST (Abstract Syntax Tree) into SQLAlchemy query objects: + +- **`compile_ast`**: Recursively traverses the AST and builds up intermediate state +- **`compile_query`**: Converts the intermediate state into a concrete SQLAlchemy `Select` statement + +## `compile_ast` - State Builder + +**Signature:** +```python +compile_ast(nd: AstNode, needed_cols: dict[UUID, int]) + -> tuple[sqa.Table, Query, dict[UUID, sqa.Label]] +``` + +**What it does:** +- Recursively processes the AST from leaves to root +- Incrementally builds up three pieces of state: + 1. **`table`**: A SQLAlchemy `Table` or `Subquery` object representing the data source + 2. **`query`**: A `Query` object containing metadata (select list, where clauses, group_by, etc.) + 3. **`sqa_expr`**: A dictionary mapping column UUIDs to SQLAlchemy `Label` expressions + +**Key behaviors:** +- Processes verbs by modifying the `query` and `sqa_expr` incrementally +- Tracks which columns are "needed" via `needed_cols` for subquery optimization +- Only calls `compile_query` when materializing a subquery (at `SubqueryMarker`) + +## `compile_query` - Query Materializer + +**Signature:** +```python +compile_query(table: sqa.Table, query: Query, sqa_expr: dict[UUID, sqa.ColumnElement]) + -> sqa.sql.Select +``` + +**What it does:** +- Takes the accumulated state from `compile_ast` +- Builds a concrete SQLAlchemy `Select` statement by: + 1. Starting from `table.select().select_from(table)` + 2. Adding WHERE clauses from `query.where` + 3. Adding GROUP BY from `query.group_by` + 4. Adding HAVING from `query.having` + 5. Adding LIMIT/OFFSET from `query.limit`/`query.offset` + 6. Adding ORDER BY from `query.order_by` + 7. Finally selecting only the columns in `query.select` using `with_only_columns()` + +## Important Assumptions + +### 1. **State Consistency** +When calling `compile_query`, the state must be consistent: +- All UUIDs in `query.select` must exist as keys in `sqa_expr` +- All UUIDs referenced in `query.where`, `query.group_by`, `query.having`, `query.order_by` must exist in `sqa_expr` +- The `table` must be a valid SQLAlchemy Table/Subquery that contains or can reference the columns in `sqa_expr` + +### 2. **When to Call `compile_query`** +- **Only call `compile_query` when you need to materialize a query** (e.g., for subqueries) +- During normal AST traversal, `compile_ast` just modifies the `query` and `sqa_expr` state +- Don't call `compile_query` prematurely - let `compile_ast` build up the full state first + +### 3. **Column References** +- `sqa_expr` maps UUIDs to SQLAlchemy expressions that can be used in: + - SELECT clauses + - WHERE/HAVING predicates + - GROUP BY clauses + - ORDER BY clauses +- These expressions must be valid in the context of the `table` + +### 4. 
**Query Object State** +The `Query` object accumulates state as verbs are processed: +- `query.select`: List of UUIDs to select (final output columns) +- `query.where`: List of predicates for WHERE clause +- `query.having`: List of predicates for HAVING clause +- `query.group_by`: List of UUIDs for GROUP BY +- `query.order_by`: List of Order objects for ORDER BY +- `query.limit`/`query.offset`: For LIMIT/OFFSET + +### 5. **Subquery Handling** +When `compile_query` is called (typically at `SubqueryMarker`): +- It creates a `Select` statement +- That `Select` is converted to a subquery via `.subquery()` +- The `sqa_expr` is updated to reference columns from the subquery +- The `query` is reset to only select the needed columns + +## Example: Union Implementation + +In the Union implementation, we: +1. Call `compile_ast` on both left and right to get their state +2. Call `compile_query` on both to get their `Select` statements +3. Use SQLAlchemy's `union`/`union_all` to combine them +4. Convert the result to a subquery + +**Important**: Before calling `sa.union`, we check if either side is already a `Subquery` and unwrap it using `.original` to get the underlying `CompoundSelect`, because you can't union two subqueries directly - you need the original `CompoundSelect` objects. + +## Common Pitfalls + +1. **Calling `compile_query` too early**: Don't call it during AST traversal unless you're materializing a subquery +2. **Inconsistent state**: Make sure all UUIDs in `query` exist in `sqa_expr` +3. **Missing column references**: Ensure columns referenced in WHERE/HAVING/etc. are in `sqa_expr` +4. **Subquery unwrapping**: When combining queries (like UNION), unwrap subqueries to get the original `CompoundSelect` diff --git a/docs/source/changelog.md b/docs/source/changelog.md index 0353313a..59a398e9 100644 --- a/docs/source/changelog.md +++ b/docs/source/changelog.md @@ -1,5 +1,8 @@ # Changelog +## 0.6.3 (2025-12-17) +- implement `tbl1 >> union(tbl2)` and `union(tbl1,tbl2)` + ## 0.6.2 (2025-12-15) - drop support for python 3.10 - fix is_sql_backend(Table) and backend(Table) diff --git a/docs/source/reference/verbs.rst b/docs/source/reference/verbs.rst index a77f37c4..7f6c2aec 100644 --- a/docs/source/reference/verbs.rst +++ b/docs/source/reference/verbs.rst @@ -30,3 +30,4 @@ Verbs slice_head summarize ungroup + union diff --git a/fuzz.py b/fuzz.py index cb30c74a..c23da5ce 100644 --- a/fuzz.py +++ b/fuzz.py @@ -31,11 +31,7 @@ pdt.Float(): rng.standard_normal, pdt.Int(): partial(rng.integers, -(1 << 13), 1 << 13), pdt.Bool(): partial(rng.integers, 0, 1, dtype=bool), - pdt.String(): ( - lambda rows: np.array( - ["".join(random.choices(letters, k=rng.poisson(10))) for _ in range(rows)] - ) - ), + pdt.String(): (lambda rows: np.array(["".join(random.choices(letters, k=rng.poisson(10))) for _ in range(rows)])), } @@ -45,45 +41,29 @@ def gen_table(rows: int, types: dict[pdt.Dtype, int]) -> pl.DataFrame: for ty, fn in RNG_FNS.items(): if ty in types: d = d.with_columns( - **{ - f"{ty.__class__.__name__.lower()} #{i + 1}": pl.lit(fn(rows)) - for i in range(types[ty]) - } + **{f"{ty.__class__.__name__.lower()} #{i + 1}": pl.lit(fn(rows)) for i in range(types[ty])} ) return d -ops_with_return_type: dict[pdt.Dtype, list[tuple[Operator, Signature]]] = { - ty: [] for ty in ALL_TYPES -} +ops_with_return_type: dict[pdt.Dtype, list[tuple[Operator, Signature]]] = {ty: [] for ty in ALL_TYPES} for op in ops.__dict__.values(): - if ( - not isinstance(op, Operator) - or op.ftype != Ftype.ELEMENT_WISE - or 
isinstance(op, Marker) - ): + if not isinstance(op, Operator) or op.ftype != Ftype.ELEMENT_WISE or isinstance(op, Marker): continue for sig in op.signatures: - if not all( - t in (*ALL_TYPES, Tyvar("T")) for t in (*sig.types, sig.return_type) - ): + if not all(t in (*ALL_TYPES, Tyvar("T")) for t in (*sig.types, sig.return_type)): continue - if isinstance(sig.return_type, Tyvar) or any( - isinstance(param, Tyvar) for param in sig.types - ): + if isinstance(sig.return_type, Tyvar) or any(isinstance(param, Tyvar) for param in sig.types): for ty in ALL_TYPES: rtype = ty if isinstance(sig.return_type, Tyvar) else sig.return_type ops_with_return_type[rtype].append( ( op, Signature( - *( - ty if isinstance(param, Tyvar) else param - for param in sig.types - ), + *(ty if isinstance(param, Tyvar) else param for param in sig.types), return_type=rtype, ), ) @@ -92,9 +72,7 @@ def gen_table(rows: int, types: dict[pdt.Dtype, int]) -> pl.DataFrame: ops_with_return_type[sig.return_type].append((op, sig)) -def gen_expr( - dtype: pdt.Dtype, cols: dict[pdt.Dtype, list[str]], q: float = 0.0 -) -> pdt.ColExpr: +def gen_expr(dtype: pdt.Dtype, cols: dict[pdt.Dtype, list[str]], q: float = 0.0) -> pdt.ColExpr: if dtype.const: return RNG_FNS[dtype.without_const()](1).item() @@ -114,9 +92,7 @@ def gen_expr( if sig.is_vararg: nargs = int(rng.normal(2.5, 1 / 1.5)) for _ in range(nargs): - args.append( - gen_expr(sig.types[-1], cols, q + rng.exponential(1 / MEAN_HEIGHT)) - ) + args.append(gen_expr(sig.types[-1], cols, q + rng.exponential(1 / MEAN_HEIGHT))) return ColFn(op, *args) @@ -132,16 +108,10 @@ def gen_expr( tables = {backend: fn(df, "t") for backend, fn in BACKEND_TABLES.items()} -cols = { - dtype: [col.name for col in tables["polars"] if col.dtype() <= dtype] - for dtype in ALL_TYPES -} +cols = {dtype: [col.name for col in tables["polars"] if col.dtype() <= dtype] for dtype in ALL_TYPES} for _ in range(it): expr = gen_expr(rng.choice(ALL_TYPES), cols) - results = { - backend: table >> mutate(y=expr) >> select(C.y) >> export(Polars()) - for backend, table in tables.items() - } + results = {backend: table >> mutate(y=expr) >> select(C.y) >> export(Polars()) for backend, table in tables.items()} for _backend, res in results: assert_frame_equal(results["polars"], res) diff --git a/generate_col_ops.py b/generate_col_ops.py index 2d24785e..9ec60181 100644 --- a/generate_col_ops.py +++ b/generate_col_ops.py @@ -38,9 +38,7 @@ def add_vararg_star(formatted_args: str) -> str: def type_annotation(dtype: Dtype, specialize_generic: bool) -> str: - if (not specialize_generic and not types.is_const(dtype)) or isinstance( - types.without_const(dtype), Tyvar - ): + if (not specialize_generic and not types.is_const(dtype)) or isinstance(types.without_const(dtype), Tyvar): return "ColExpr" if types.is_const(dtype): python_type = types.to_python(dtype) @@ -48,26 +46,18 @@ def type_annotation(dtype: Dtype, specialize_generic: bool) -> str: return f"ColExpr[{dtype.__class__.__name__}]" -def generate_fn_decl( - op: Operator, sig: Signature, *, name=None, specialize_generic: bool = True -) -> str: +def generate_fn_decl(op: Operator, sig: Signature, *, name=None, specialize_generic: bool = True) -> str: if name is None: name = op.name - defaults: Iterable = ( - op.default_values - if op.default_values is not None - else (... for _ in op.param_names) - ) + defaults: Iterable = op.default_values if op.default_values is not None else (... 
for _ in op.param_names) annotated_args = ", ".join( name + ": " + type_annotation(dtype, specialize_generic) + (f" = {repr(default_val)}" if default_val is not ... else "") - for dtype, name, default_val in zip( - sig.types, op.param_names, defaults, strict=True - ) + for dtype, name, default_val in zip(sig.types, op.param_names, defaults, strict=True) ) if sig.is_vararg: annotated_args = add_vararg_star(annotated_args) @@ -80,8 +70,7 @@ def generate_fn_decl( } annotated_kwargs = "".join( - f", {kwarg.name}: {context_kwarg_annotation[kwarg.name]}" - + f"{'' if kwarg.required else ' | None = None'}" + f", {kwarg.name}: {context_kwarg_annotation[kwarg.name]}" + f"{'' if kwarg.required else ' | None = None'}" for kwarg in op.context_kwargs ) @@ -93,8 +82,7 @@ def generate_fn_decl( annotated_kwargs = "" return ( - f"def {name}({annotated_args}{annotated_kwargs}) " - f"-> {type_annotation(sig.return_type, specialize_generic)}:\n" + f"def {name}({annotated_args}{annotated_kwargs}) -> {type_annotation(sig.return_type, specialize_generic)}:\n" ) @@ -126,9 +114,7 @@ def generate_fn_body( return f" return ColFn(ops.{op_var_name}{args}{kwargs})\n\n" -def generate_overloads( - op: Operator, *, name: str | None = None, rversion: bool = False, op_var_name: str -): +def generate_overloads(op: Operator, *, name: str | None = None, rversion: bool = False, op_var_name: str): res = "" in_namespace = "." in op.name if name is None: @@ -140,9 +126,7 @@ def generate_overloads( res += "@overload\n" + generate_fn_decl(op, sig, name=name) + " ...\n\n" res += ( - generate_fn_decl( - op, op.signatures[0], name=name, specialize_generic=not has_overloads - ) + generate_fn_decl(op, op.signatures[0], name=name, specialize_generic=not has_overloads) + f' """\n{op.doc.strip()}\n"""\n\n' + generate_fn_body( op, @@ -177,13 +161,8 @@ def indent(s: str, by: int) -> str: " # --- generated code starts here, do not delete this comment ---" ): in_generated_section = True - new_file_contents += ( - " # --- generated code starts here, do not delete this " - "comment ---\n\n" - ) - elif in_generated_section and line.startswith( - "# --- generated code ends here, do not delete this comment ---" - ): + new_file_contents += " # --- generated code starts here, do not delete this comment ---\n\n" + elif in_generated_section and line.startswith("# --- generated code ends here, do not delete this comment ---"): for op_var_name in sorted(ops.__dict__): op = ops.__dict__[op_var_name] if not isinstance(op, Operator) or not op.generate_expr_method: @@ -208,11 +187,7 @@ def indent(s: str, by: int) -> str: for name in NAMESPACES: new_file_contents += f" {name} : {name.title()}Namespace\n" - new_file_contents += ( - "@dataclasses.dataclass(slots=True)\n" - "class FnNamespace:\n" - " arg: ColExpr\n" - ) + new_file_contents += "@dataclasses.dataclass(slots=True)\nclass FnNamespace:\n arg: ColExpr\n" for name in NAMESPACES: new_file_contents += namespace_contents[name] @@ -234,9 +209,7 @@ def indent(s: str, by: int) -> str: for line in file: new_file_contents += line - if line.startswith( - "# --- from here the code is generated, do not delete this comment ---" - ): + if line.startswith("# --- from here the code is generated, do not delete this comment ---"): new_file_contents += "\n" for op_var_name in sorted(ops.__dict__): op = ops.__dict__[op_var_name] @@ -265,11 +238,7 @@ def indent(s: str, by: int) -> str: new_file_contents += "\n ".join( sorted( - [ - op.name - for op in ops.__dict__.values() - if isinstance(op, Operator) and 
op.generate_expr_method - ] + [op.name for op in ops.__dict__.values() if isinstance(op, Operator) and op.generate_expr_method] + ["rank", "dense_rank", "map", "cast"] ) ) diff --git a/pyproject.toml b/pyproject.toml index 687102fa..0ec6e07a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "pydiverse-transform" -version = "0.6.2" +version = "0.6.3" description = "Pipe based dataframe manipulation library that can also transform data on SQL databases" authors = [ { name = "QuantCo, Inc." }, @@ -42,6 +42,8 @@ packages = ["src/pydiverse"] target-version = "py311" extend-exclude = ["docs/*"] fix = true +# 88 leads to worse comments +line-length = 120 [tool.ruff.lint] select = ["F", "E", "UP", "W", "I001", "I002", "B", "A"] diff --git a/src/pydiverse/transform/__init__.py b/src/pydiverse/transform/__init__.py index a05b3b63..1d73e977 100644 --- a/src/pydiverse/transform/__init__.py +++ b/src/pydiverse/transform/__init__.py @@ -13,8 +13,5 @@ from .version import __version__ __all__ = ( - ["__version__", "Table", "ColExpr", "Col", "verb", "backend", "is_sql_backed"] - + __extended - + __types - + __errors + ["__version__", "Table", "ColExpr", "Col", "verb", "backend", "is_sql_backed"] + __extended + __types + __errors ) diff --git a/src/pydiverse/transform/_internal/backend/duckdb.py b/src/pydiverse/transform/_internal/backend/duckdb.py index 33c78362..24862b3a 100644 --- a/src/pydiverse/transform/_internal/backend/duckdb.py +++ b/src/pydiverse/transform/_internal/backend/duckdb.py @@ -44,9 +44,7 @@ def export( @classmethod def compile_cast(cls, cast: Cast, sqa_col: dict[str, sqa.Label]) -> Cast: if cast.val.dtype().is_float() and cast.target_type.is_int(): - return cls.cast_compiled( - cast, sqa.func.trunc(cls.compile_col_expr(cast.val, sqa_col)) - ) + return cls.cast_compiled(cast, sqa.func.trunc(cls.compile_col_expr(cast.val, sqa_col))) return super().compile_cast(cast, sqa_col) @@ -57,9 +55,7 @@ def compile_lit(cls, lit: LiteralCol) -> sqa.ColumnElement: return super().compile_lit(lit) @classmethod - def fix_fn_types( - cls, fn: sql.ColFn, val: sqa.ColumnElement, *args: sqa.ColumnElement - ) -> sqa.ColumnElement: + def fix_fn_types(cls, fn: sql.ColFn, val: sqa.ColumnElement, *args: sqa.ColumnElement) -> sqa.ColumnElement: if fn.op in (ops.sum, ops.cum_sum): return sqa.cast(val, type_=args[0].type) return val diff --git a/src/pydiverse/transform/_internal/backend/duckdb_polars.py b/src/pydiverse/transform/_internal/backend/duckdb_polars.py index 16f9288c..5b735077 100644 --- a/src/pydiverse/transform/_internal/backend/duckdb_polars.py +++ b/src/pydiverse/transform/_internal/backend/duckdb_polars.py @@ -28,19 +28,13 @@ def __init__(self, name: str | None, df: pl.DataFrame | pl.LazyFrame): super().__init__( name, - { - name: Dtype.from_polars(dtype) - for name, dtype in df.collect_schema().items() - }, + {name: Dtype.from_polars(dtype) for name, dtype in df.collect_schema().items()}, ) self.table = sqa.Table( name or "", sqa.MetaData(), - *( - sqa.Column(col.name, DuckDbImpl.sqa_type(col.dtype())) - for col in self.cols.values() - ), + *(sqa.Column(col.name, DuckDbImpl.sqa_type(col.dtype())) for col in self.cols.values()), ) def _table_def_repr(self) -> str: @@ -81,8 +75,5 @@ def _clone(self) -> tuple[AstNode, dict[AstNode, AstNode], dict[UUID, UUID]]: return ( cloned, {self: cloned}, - { - self.cols[name]._uuid: cloned.cols[name]._uuid - for name in self.cols.keys() - }, + {self.cols[name]._uuid: cloned.cols[name]._uuid for name in self.cols.keys()}, ) diff 
--git a/src/pydiverse/transform/_internal/backend/ibm_db2.py b/src/pydiverse/transform/_internal/backend/ibm_db2.py index c6e87bcf..50ed881e 100644 --- a/src/pydiverse/transform/_internal/backend/ibm_db2.py +++ b/src/pydiverse/transform/_internal/backend/ibm_db2.py @@ -42,30 +42,22 @@ def dialect_order_append_rand(cls): @impl(ops.horizontal_min) def _horizontal_min(*x): if len(x) == 1: - return sqa.func.LEAST( - x[0], x[0] - ) # DB2 does not support LEAST with a single argument + return sqa.func.LEAST(x[0], x[0]) # DB2 does not support LEAST with a single argument else: # the generated query will look extremely ugly but LEAST should be non-NULL # if any of the arguments is non-NULL any_non_null = sqa.func.COALESCE(*x) - return sqa.func.LEAST( - *[sqa.func.COALESCE(element, any_non_null) for element in x] - ) + return sqa.func.LEAST(*[sqa.func.COALESCE(element, any_non_null) for element in x]) @impl(ops.horizontal_max) def _horizontal_max(*x): if len(x) == 1: - return sqa.func.GREATEST( - x[0], x[0] - ) # DB2 does not support LEAST with a single argument + return sqa.func.GREATEST(x[0], x[0]) # DB2 does not support LEAST with a single argument else: # the generated query will look extremely ugly but LEAST should be non-NULL # if any of the arguments is non-NULL any_non_null = sqa.func.COALESCE(*x) - return sqa.func.GREATEST( - *[sqa.func.COALESCE(element, any_non_null) for element in x] - ) + return sqa.func.GREATEST(*[sqa.func.COALESCE(element, any_non_null) for element in x]) @impl(ops.dt_second) def _dt_second(x): @@ -92,9 +84,7 @@ def _day_of_week(x): @impl(ops.cbrt) def _cbrt(x): pow_impl = IbmDb2Impl.get_impl(ops.pow, (Float(), Float())) - return sqa.func.sign(x) * pow_impl( - sqa.func.abs(x), sqa.literal(1 / 3, type_=sqa.Double) - ) + return sqa.func.sign(x) * pow_impl(sqa.func.abs(x), sqa.literal(1 / 3, type_=sqa.Double)) @impl(ops.rand) def _rand(): diff --git a/src/pydiverse/transform/_internal/backend/mssql.py b/src/pydiverse/transform/_internal/backend/mssql.py index ffdcc1d2..53747024 100644 --- a/src/pydiverse/transform/_internal/backend/mssql.py +++ b/src/pydiverse/transform/_internal/backend/mssql.py @@ -74,9 +74,7 @@ def build_select(cls, nd: AstNode, *, final_select: list[Col] | None = None) -> desc.map_col_roots( functools.partial( convert_bool_bit, - desired_return_type="bool" - if isinstance(desc, verbs.Filter | verbs.Join) - else "bit", + desired_return_type="bool" if isinstance(desc, verbs.Filter | verbs.Join) else "bit", ) ) @@ -86,15 +84,11 @@ def build_select(cls, nd: AstNode, *, final_select: list[Col] | None = None) -> desc.order_by = convert_order_list(desc.order_by) if isinstance(desc, verbs.Verb): for node in desc.iter_col_nodes(): - if isinstance(node, ColFn) and ( - arrange := node.context_kwargs.get("arrange") - ): + if isinstance(node, ColFn) and (arrange := node.context_kwargs.get("arrange")): node.context_kwargs["arrange"] = convert_order_list(arrange) sql.create_aliases(nd, {}) - table, query, sqa_expr = cls.compile_ast( - nd, {col._uuid: 1 for col in final_select} - ) + table, query, sqa_expr = cls.compile_ast(nd, {col._uuid: 1 for col in final_select}) # mssql complains about OFFSET if there is no ORDER BY if query.offset and not query.order_by: @@ -103,19 +97,12 @@ def build_select(cls, nd: AstNode, *, final_select: list[Col] | None = None) -> return cls.compile_query(table, query, sqa_expr) @classmethod - def compile_ordered_aggregation( - cls, *args: sqa.ColumnElement, order_by: list[sqa.UnaryExpression], impl - ): + def 
compile_ordered_aggregation(cls, *args: sqa.ColumnElement, order_by: list[sqa.UnaryExpression], impl): return impl(*args).within_group(*order_by) @classmethod - def fix_fn_types( - cls, fn: ColFn, val: sqa.ColumnElement, *args: sqa.ColumnElement - ) -> sqa.ColumnElement: - if ( - fn.op in (ops.any, ops.all, ops.min, ops.max) - and types.without_const(fn.dtype()) == Bool() - ): + def fix_fn_types(cls, fn: ColFn, val: sqa.ColumnElement, *args: sqa.ColumnElement) -> sqa.ColumnElement: + if fn.op in (ops.any, ops.all, ops.min, ops.max) and types.without_const(fn.dtype()) == Bool(): return val.cast(BIT) return val @@ -189,8 +176,7 @@ def convert_bool_bit( result = copy.copy(expr) result.args = list(convert_bool_bit(arg, desired_arg_type) for arg in expr.args) result.context_kwargs = { - key: [convert_bool_bit(val, "bit") for val in arr] - for key, arr in expr.context_kwargs.items() + key: [convert_bool_bit(val, "bit") for val in arr] for key, arr in expr.context_kwargs.items() } return_type = ( @@ -222,23 +208,14 @@ def convert_bool_bit( elif isinstance(expr, CaseExpr): return_type = "bit" result = copy.copy(expr) - result.cases = [ - (convert_bool_bit(cond, "bool"), convert_bool_bit(val, "bit")) - for cond, val in expr.cases - ] - result.default_val = ( - None - if expr.default_val is None - else convert_bool_bit(expr.default_val, "bit") - ) + result.cases = [(convert_bool_bit(cond, "bool"), convert_bool_bit(val, "bit")) for cond, val in expr.cases] + result.default_val = None if expr.default_val is None else convert_bool_bit(expr.default_val, "bit") elif isinstance(expr, LiteralCol): return_type = "bit" elif isinstance(expr, Cast): - return Cast( - convert_bool_bit(expr.val, "bit"), expr.target_type, strict=expr.strict - ) + return Cast(convert_bool_bit(expr.val, "bit"), expr.target_type, strict=expr.strict) if types.without_const(expr.dtype()) == Bool(): if desired_return_type == "bool" and return_type == "bit": @@ -257,46 +234,38 @@ def convert_bool_bit( @impl(ops.equal, String(), String()) def _eq(x, y): - return (sqa.func.LENGTH(x + "a") == sqa.func.LENGTH(y + "a")) & ( - x.collate("Latin1_General_bin") == y - ) + return (sqa.func.LENGTH(x + "a") == sqa.func.LENGTH(y + "a")) & (x.collate("Latin1_General_bin") == y) @impl(ops.not_equal, String(), String()) def _ne(x, y): - return (sqa.func.LENGTH(x + "a") != sqa.func.LENGTH(y + "a")) | ( - x.collate("Latin1_General_bin") != y - ) + return (sqa.func.LENGTH(x + "a") != sqa.func.LENGTH(y + "a")) | (x.collate("Latin1_General_bin") != y) @impl(ops.less_than, String(), String()) def _lt(x, y): y_ = sqa.func.SUBSTRING(y, 1, sqa.func.LENGTH(x + "a") - 1) return (x.collate("Latin1_General_bin") < y_) | ( - (sqa.func.LENGTH(x + "a") < sqa.func.LENGTH(y + "a")) - & (x.collate("Latin1_General_bin") == y_) + (sqa.func.LENGTH(x + "a") < sqa.func.LENGTH(y + "a")) & (x.collate("Latin1_General_bin") == y_) ) @impl(ops.less_equal, String(), String()) def _le(x, y): y_ = sqa.func.SUBSTRING(y, 1, sqa.func.LENGTH(x + "a") - 1) return (x.collate("Latin1_General_bin") < y_) | ( - (sqa.func.LENGTH(x + "a") <= sqa.func.LENGTH(y + "a")) - & (x.collate("Latin1_General_bin") == y_) + (sqa.func.LENGTH(x + "a") <= sqa.func.LENGTH(y + "a")) & (x.collate("Latin1_General_bin") == y_) ) @impl(ops.greater_than, String(), String()) def _gt(x, y): y_ = sqa.func.SUBSTRING(y, 1, sqa.func.LENGTH(x + "a") - 1) return (x.collate("Latin1_General_bin") > y_) | ( - (sqa.func.LENGTH(x + "a") > sqa.func.LENGTH(y + "a")) - & (x.collate("Latin1_General_bin") == y_) + (sqa.func.LENGTH(x + 
"a") > sqa.func.LENGTH(y + "a")) & (x.collate("Latin1_General_bin") == y_) ) @impl(ops.greater_equal, String(), String()) def _ge(x, y): y_ = sqa.func.SUBSTRING(y, 1, sqa.func.LENGTH(x + "a") - 1) return (x.collate("Latin1_General_bin") > y_) | ( - (sqa.func.LENGTH(x + "a") >= sqa.func.LENGTH(y + "a")) - & (x.collate("Latin1_General_bin") == y_) + (sqa.func.LENGTH(x + "a") >= sqa.func.LENGTH(y + "a")) & (x.collate("Latin1_General_bin") == y_) ) @impl(ops.str_len) @@ -401,9 +370,7 @@ def _truediv(x, y): @impl(ops.cbrt) def _cbrt(x): - return sqa.func.sign(x) * _pow( - sqa.func.abs(x), sqa.literal(1 / 3, type_=sqa.Double) - ) + return sqa.func.sign(x) * _pow(sqa.func.abs(x), sqa.literal(1 / 3, type_=sqa.Double)) @impl(ops.rand) def _rand(): diff --git a/src/pydiverse/transform/_internal/backend/polars.py b/src/pydiverse/transform/_internal/backend/polars.py index 97038818..b82c9e2c 100644 --- a/src/pydiverse/transform/_internal/backend/polars.py +++ b/src/pydiverse/transform/_internal/backend/polars.py @@ -50,10 +50,7 @@ def __init__(self, name: str, df: pl.DataFrame | pl.LazyFrame): ) super().__init__( name, - { - name: Dtype.from_polars(pl_type) - for name, pl_type in df.collect_schema().items() - }, + {name: Dtype.from_polars(pl_type) for name, pl_type in df.collect_schema().items()}, ) def _table_def_repr(self) -> str: @@ -90,10 +87,7 @@ def _clone(self) -> tuple["PolarsImpl", dict[AstNode, AstNode], dict[UUID, UUID] return ( cloned, {self: cloned}, - { - self.cols[name]._uuid: cloned.cols[name]._uuid - for name in self.cols.keys() - }, + {self.cols[name]._uuid: cloned.cols[name]._uuid for name in self.cols.keys()}, ) @@ -119,9 +113,7 @@ def merge_desc_nulls_last( return merged -def compile_order( - order: Order, name_in_df: dict[UUID, str] -) -> tuple[pl.Expr, bool, bool | None]: +def compile_order(order: Order, name_in_df: dict[UUID, str]) -> tuple[pl.Expr, bool, bool | None]: return ( compile_col_expr(order.order_by, name_in_df), order.descending, @@ -141,15 +133,10 @@ def compile_col_expr( elif isinstance(expr, ColFn): impl = PolarsImpl.get_impl(expr.op, tuple(arg.dtype() for arg in expr.args)) - args: list[pl.Expr] = [ - compile_col_expr(arg, name_in_df, op_kwargs=op_kwargs) for arg in expr.args - ] + args: list[pl.Expr] = [compile_col_expr(arg, name_in_df, op_kwargs=op_kwargs) for arg in expr.args] if (partition_by := expr.context_kwargs.get("partition_by")) is not None: - partition_by = [ - compile_col_expr(pb, name_in_df, op_kwargs=op_kwargs) - for pb in partition_by - ] + partition_by = [compile_col_expr(pb, name_in_df, op_kwargs=op_kwargs) for pb in partition_by] arrange = expr.context_kwargs.get("arrange") if arrange: @@ -221,13 +208,11 @@ def compile_col_expr( assert len(expr.cases) >= 1 compiled = pl # to initialize the when/then-chain for cond, val in expr.cases: - compiled = compiled.when( - compile_col_expr(cond, name_in_df, op_kwargs=op_kwargs) - ).then(compile_col_expr(val, name_in_df, op_kwargs=op_kwargs)) - if expr.default_val is not None: - compiled = compiled.otherwise( - compile_col_expr(expr.default_val, name_in_df, op_kwargs=op_kwargs) + compiled = compiled.when(compile_col_expr(cond, name_in_df, op_kwargs=op_kwargs)).then( + compile_col_expr(val, name_in_df, op_kwargs=op_kwargs) ) + if expr.default_val is not None: + compiled = compiled.otherwise(compile_col_expr(expr.default_val, name_in_df, op_kwargs=op_kwargs)) return compiled elif isinstance(expr, LiteralCol): @@ -239,18 +224,15 @@ def compile_col_expr( ) elif isinstance(expr, Cast): - if ( - 
expr.target_type.is_int() or expr.target_type.is_float() - ) and types.without_const(expr.val.dtype()) == String(): + if (expr.target_type.is_int() or expr.target_type.is_float()) and types.without_const( + expr.val.dtype() + ) == String(): expr.val = expr.val.str.strip() compiled = compile_col_expr(expr.val, name_in_df, op_kwargs=op_kwargs) compiled = compiled.cast(expr.target_type.to_polars(), strict=expr.strict) - if ( - types.without_const(expr.val.dtype()).is_float() - and expr.target_type == String() - ): + if types.without_const(expr.val.dtype()).is_float() and expr.target_type == String(): compiled = compiled.replace("NaN", "nan") return compiled @@ -276,13 +258,8 @@ def rename_overwritten_cols( overwritten = names_to_consider.intersection(new_names) if overwritten: - name_map = { - name: f"{name}:{str(hex(uuid.uuid1().int))[2:]}" for name in overwritten - } - name_in_df = { - uid: (name_map[name] if name in name_map else name) - for uid, name in name_in_df.items() - } + name_map = {name: f"{name}:{str(hex(uuid.uuid1().int))[2:]}" for name in overwritten} + name_in_df = {uid: (name_map[name] if name in name_map else name) for uid, name in name_in_df.items()} df = df.rename(name_map) return df, name_in_df @@ -329,22 +306,14 @@ def compile_ast( elif isinstance(nd, verbs.Rename): df = df.rename(nd.name_map) - name_in_df = { - uid: (nd.name_map[name] if name in nd.name_map else name) - for uid, name in name_in_df.items() - } + name_in_df = {uid: (nd.name_map[name] if name in nd.name_map else name) for uid, name in name_in_df.items()} elif isinstance(nd, verbs.Mutate): df = df.with_columns( - **{ - name: compile_col_expr(value, name_in_df) - for name, value in zip(nd.names, nd.values, strict=True) - } + **{name: compile_col_expr(value, name_in_df) for name, value in zip(nd.names, nd.values, strict=True)} ) - name_in_df.update( - {uid: name for uid, name in zip(nd.uuids, nd.names, strict=True)} - ) + name_in_df.update({uid: name for uid, name in zip(nd.uuids, nd.names, strict=True)}) elif isinstance(nd, verbs.Filter): if nd.predicates: @@ -369,17 +338,13 @@ def has_path_to_leaf_without_agg(expr: ColExpr): return True if isinstance(expr, ColFn) and expr.op.ftype == Ftype.AGGREGATE: return False - return any( - has_path_to_leaf_without_agg(child) for child in expr.iter_children() - ) + return any(has_path_to_leaf_without_agg(child) for child in expr.iter_children()) aggregations = {} for name, val in zip(nd.names, nd.values, strict=True): # For some aggregations, a different polars function must be used if there # is not grouping. (In this case, `df.select` is called.) 
- compiled = compile_col_expr( - val, name_in_df, op_kwargs={"_empty_group_by": len(partition_by) == 0} - ) + compiled = compile_col_expr(val, name_in_df, op_kwargs={"_empty_group_by": len(partition_by) == 0}) if has_path_to_leaf_without_agg(val): compiled = compiled.first() aggregations[name] = compiled @@ -393,12 +358,8 @@ def has_path_to_leaf_without_agg(expr: ColExpr): # we have to remove the columns here for the join hidden column rename to work # correctly (otherwise it would try to rename hidden columns that do not exist) - name_in_df = { - name: uuid for name, uuid in name_in_df.items() if name in partition_by - } - name_in_df.update( - {uid: name for name, uid in zip(nd.names, nd.uuids, strict=True)} - ) + name_in_df = {name: uuid for name, uuid in name_in_df.items() if name in partition_by} + name_in_df.update({uid: name for name, uid in zip(nd.names, nd.uuids, strict=True)}) partition_by = [] elif isinstance(nd, verbs.SliceHead): @@ -427,9 +388,7 @@ def has_path_to_leaf_without_agg(expr: ColExpr): right_df, right_name_in_df = rename_overwritten_cols( set(name_in_df[uid] for uid in select), right_df, right_name_in_df ) - df, name_in_df = rename_overwritten_cols( - set(right_name_in_df[uid] for uid in right_select), df, name_in_df - ) + df, name_in_df = rename_overwritten_cols(set(right_name_in_df[uid] for uid in right_select), df, name_in_df) # hidden columns right_df, right_name_in_df = rename_overwritten_cols( @@ -442,9 +401,7 @@ def has_path_to_leaf_without_agg(expr: ColExpr): name_in_df.update(right_name_in_df) eq_predicates = [pred for pred in predicates if pred.op == ops.equal] - left_on, right_on = get_left_right_on( - eq_predicates, name_in_df, right_name_in_df - ) + left_on, right_on = get_left_right_on(eq_predicates, name_in_df, right_name_in_df) # If there are only equality predicates, use normal join. Else use join_where if len(eq_predicates) == len(predicates): @@ -465,9 +422,7 @@ def has_path_to_leaf_without_agg(expr: ColExpr): else: assert nd.how != "full" if nd.how == "left": - df = df.with_columns( - __INDEX__=pl.int_range(0, pl.len(), dtype=pl.Int64) - ) + df = df.with_columns(__INDEX__=pl.int_range(0, pl.len(), dtype=pl.Int64)) joined = df.join_where( right_df, @@ -486,6 +441,43 @@ def has_path_to_leaf_without_agg(expr: ColExpr): select += right_select + elif isinstance(nd, verbs.Union): + assert len(partition_by) == 0 + + right_df, right_name_in_df, right_select, _ = compile_ast(nd.right) + + # Drop all hidden columns from both dataframes before union + # Note: The union verb already validates that visible column names match between tables, + # so we can use left_col_names for both dataframes. This also ensures columns are in the + # same order, which Polars requires for union operations. 
+ left_col_names = [name_in_df[uid] for uid in select] + df = df.select(*left_col_names) + right_df = right_df.select(*left_col_names) + + # Use pl.union if available (Polars >= 1.35), otherwise use pl.concat + # pl.union is faster than pl.concat for union operations + # distinct=True means UNION (remove duplicates), distinct=False means UNION ALL (keep duplicates) + try: + # Try to use pl.union (available in Polars >= 1.35) + # pl.union takes a list of DataFrames/LazyFrames and has a distinct parameter + if nd.distinct: + # For UNION (distinct), use union with distinct=True + df = pl.union([df, right_df], distinct=True) + else: + # For UNION ALL (not distinct), use union without distinct + df = pl.union([df, right_df]) + except (AttributeError, TypeError): + # Fall back to pl.concat for older Polars versions (< 1.35) + if nd.distinct: + # For UNION (distinct), we need to deduplicate + # Polars doesn't have a direct UNION without ALL, so we concat and then distinct + df = pl.concat([df, right_df]).unique() + else: + # For UNION ALL (not distinct), just concat + df = pl.concat([df, right_df]) + + # name_in_df and select remain the same (from left table) + elif isinstance(nd, PolarsImpl): df = nd.df name_in_df = {col._uuid: col.name for col in nd.cols.values()} @@ -791,6 +783,4 @@ def _clip(x, lower, upper): @impl(ops.rand) def _rand(): - return pl.int_range(pl.len()).map_elements( - lambda x: random.random(), pl.Float64() - ) + return pl.int_range(pl.len()).map_elements(lambda x: random.random(), pl.Float64()) diff --git a/src/pydiverse/transform/_internal/backend/postgres.py b/src/pydiverse/transform/_internal/backend/postgres.py index a46547ce..7e7187ae 100644 --- a/src/pydiverse/transform/_internal/backend/postgres.py +++ b/src/pydiverse/transform/_internal/backend/postgres.py @@ -18,9 +18,7 @@ class PostgresImpl(SqlImpl): def compile_cast(cls, cast: Cast, sqa_col: dict[str, sqa.Label]) -> Cast: if types.without_const(cast.val.dtype()).is_float(): if cast.target_type.is_int(): - return cls.cast_compiled( - cast, sqa.func.trunc(cls.compile_col_expr(cast.val, sqa_col)) - ) + return cls.cast_compiled(cast, sqa.func.trunc(cls.compile_col_expr(cast.val, sqa_col))) if cast.target_type == String(): compiled = super().compile_cast(cast, sqa_col) @@ -31,10 +29,7 @@ def compile_cast(cls, cast: Cast, sqa_col: dict[str, sqa.Label]) -> Cast: else_=compiled, ) - if ( - types.without_const(cast.val.dtype()) == Bool() - and cast.target_type == Int64() - ): + if types.without_const(cast.val.dtype()) == Bool() and cast.target_type == Int64(): # postgres does not like casts bool -> bigint, so we go via int return cls.compile_cast(Cast(cast.val, Int32()).cast(Int64()), sqa_col) @@ -50,9 +45,7 @@ def cast_compiled(cls, cast: Cast, compiled_expr: sqa.ColumnElement) -> sqa.Cast ( sqa.func.pg_input_is_valid( compiled_expr, - str(cls.sqa_type(cast.target_type)) - .lower() - .replace(" ", ""), + str(cls.sqa_type(cast.target_type)).lower().replace(" ", ""), ), compiled_expr, ), @@ -110,9 +103,7 @@ def sqa_type(cls, pdt_type: Dtype): return super().sqa_type(pdt_type) @classmethod - def fix_fn_types( - cls, fn: ColFn, val: sqa.ColumnElement, *args: sqa.ColumnElement - ) -> sqa.ColumnElement: + def fix_fn_types(cls, fn: ColFn, val: sqa.ColumnElement, *args: sqa.ColumnElement) -> sqa.ColumnElement: if isinstance(fn.op, ops.DatetimeExtract | ops.DateExtract): return sqa.cast(val, sqa.BigInteger) elif fn.op in (ops.sum, ops.cum_sum): @@ -155,9 +146,7 @@ def _round(x, decimals=0): if isinstance(x.type, sqa.Float): # 
Postgres doesn't support rounding of doubles to specific precision # -> Must first cast to numeric - return sqa.func.ROUND( - sqa.cast(x, sqa.Numeric), decimals, type_=sqa.Numeric - ).cast(x.type) + return sqa.func.ROUND(sqa.cast(x, sqa.Numeric), decimals, type_=sqa.Numeric).cast(x.type) return sqa.func.ROUND(x, decimals, type_=x.type) @@ -171,10 +160,7 @@ def _dt_millisecond(x): @impl(ops.dt_microsecond) def _dt_microsecond(x): - return ( - sqa.func.FLOOR(sqa.extract("microsecond", x), type_=sqa.Integer()) - % 1_000_000 - ) + return sqa.func.FLOOR(sqa.extract("microsecond", x), type_=sqa.Integer()) % 1_000_000 @impl(ops.horizontal_max, String(), String(), ...) def _horizontal_max(*x): diff --git a/src/pydiverse/transform/_internal/backend/sql.py b/src/pydiverse/transform/_internal/backend/sql.py index bf5dd9af..922daf58 100644 --- a/src/pydiverse/transform/_internal/backend/sql.py +++ b/src/pydiverse/transform/_internal/backend/sql.py @@ -59,18 +59,9 @@ class Query: class SqlImpl(TableImpl): def __new__(cls, *args, **kwargs) -> "SqlImpl": - engine: str | sqa.Engine = ( - inspect.signature(cls.__init__) - .bind(None, *args, **kwargs) - .arguments["conf"] - .engine - ) + engine: str | sqa.Engine = inspect.signature(cls.__init__).bind(None, *args, **kwargs).arguments["conf"].engine - dialect = ( - engine.dialect.name - if isinstance(engine, sqa.Engine) - else sqa.make_url(engine).get_dialect().name - ) + dialect = engine.dialect.name if isinstance(engine, sqa.Engine) else sqa.make_url(engine).get_dialect().name # We don't want to import any SQL impls we don't use, so the mapping # name -> impl class is defined here. @@ -100,10 +91,7 @@ def __new__(cls, *args, **kwargs) -> "SqlImpl": Impl = IbmDb2Impl else: - warn( - f"Pydiverse transform is not tested for dialect '{dialect}'. " - f"Assuming Postgres compatible SQL." - ) + warn(f"Pydiverse transform is not tested for dialect '{dialect}'. 
Assuming Postgres compatible SQL.") from .postgres import PostgresImpl Impl = PostgresImpl @@ -113,15 +101,9 @@ def __new__(cls, *args, **kwargs) -> "SqlImpl": def __init__(self, table: str | sqa.Table, conf: SqlAlchemy, name: str | None): assert type(self) is not SqlImpl - self.engine = ( - conf.engine - if isinstance(conf.engine, sqa.Engine) - else sqa.create_engine(conf.engine) - ) + self.engine = conf.engine if isinstance(conf.engine, sqa.Engine) else sqa.create_engine(conf.engine) if isinstance(table, str): - self.table = sqa.Table( - table, sqa.MetaData(), schema=conf.schema, autoload_with=self.engine - ) + self.table = sqa.Table(table, sqa.MetaData(), schema=conf.schema, autoload_with=self.engine) else: self.table = table @@ -147,10 +129,7 @@ def _clone(self) -> tuple["SqlImpl", dict[AstNode, AstNode], dict[UUID, UUID]]: return ( cloned, {self: cloned}, - { - self.cols[name]._uuid: cloned.cols[name]._uuid - for name in self.cols.keys() - }, + {self.cols[name]._uuid: cloned.cols[name]._uuid for name in self.cols.keys()}, ) @classmethod @@ -166,15 +145,11 @@ def default_collation(cls) -> str | None: return "POSIX" @classmethod - def build_select( - cls, nd: AstNode, *, final_select: list[Col] | None = None - ) -> sqa.Select: + def build_select(cls, nd: AstNode, *, final_select: list[Col] | None = None) -> sqa.Select: if final_select is None: final_select = Cache.from_ast(nd).selected_cols() create_aliases(nd, {}) - nd, query, sqa_expr = cls.compile_ast( - nd, {col._uuid: 1 for col in final_select} - ) + nd, query, sqa_expr = cls.compile_ast(nd, {col._uuid: 1 for col in final_select}) return cls.compile_query(nd, query, sqa_expr) @classmethod @@ -197,16 +172,12 @@ def export( connection=conn, schema_overrides={ sql_col.name: schema_overrides[col._uuid] - for sql_col, col in zip( - sel.columns.values(), final_select, strict=True - ) + for sql_col, col in zip(sel.columns.values(), final_select, strict=True) if col._uuid in schema_overrides } | { sql_col.name: NullType().to_polars() - for sql_col, col in zip( - sel.columns.values(), final_select, strict=True - ) + for sql_col, col in zip(sel.columns.values(), final_select, strict=True) if types.without_const(col.dtype()) == NullType() }, ) @@ -238,15 +209,11 @@ def compile_lit(cls, lit: LiteralCol): # Often, floats are incorrectly converted to decimals # TODO: ensure in tests that the dtype not only match after export to # polars, but also really in the backend - return sqa.cast( - sqa.literal(lit.val, literal_execute=True), cls.sqa_type(lit.dtype()) - ) + return sqa.cast(sqa.literal(lit.val, literal_execute=True), cls.sqa_type(lit.dtype())) return sqa.literal(lit.val, cls.sqa_type(lit.dtype()), literal_execute=True) @classmethod - def compile_order( - cls, order: Order, sqa_expr: dict[str, sqa.Label] - ) -> sqa.UnaryExpression: + def compile_order(cls, order: Order, sqa_expr: dict[str, sqa.Label]) -> sqa.UnaryExpression: order_expr = cls.compile_col_expr(order.order_by, sqa_expr) if types.without_const(order.order_by.dtype()) == String(): if cls.default_collation() is not None: @@ -255,32 +222,22 @@ def compile_order( order_expr = order_expr.collate(cls.default_collation()) order_expr = order_expr.desc() if order.descending else order_expr.asc() if order.nulls_last is not None: - order_expr = ( - order_expr.nulls_last() - if order.nulls_last - else order_expr.nulls_first() - ) + order_expr = order_expr.nulls_last() if order.nulls_last else order_expr.nulls_first() return order_expr @classmethod - def compile_cast( - cls, cast: Cast, 
sqa_expr: dict[str, sqa.Label] - ) -> sqa.Case | sqa.TryCast: + def compile_cast(cls, cast: Cast, sqa_expr: dict[str, sqa.Label]) -> sqa.Case | sqa.TryCast: return cls.cast_compiled(cast, cls.compile_col_expr(cast.val, sqa_expr)) @classmethod - def cast_compiled( - cls, cast: Cast, compiled_expr: sqa.ColumnElement - ) -> sqa.Cast | sqa.TryCast: + def cast_compiled(cls, cast: Cast, compiled_expr: sqa.ColumnElement) -> sqa.Cast | sqa.TryCast: if cast.strict: return sqa.cast(compiled_expr, cls.sqa_type(cast.target_type)) else: return sqa.try_cast(compiled_expr, cls.sqa_type(cast.target_type)) @classmethod - def fix_fn_types( - cls, fn: ColFn, val: sqa.ColumnElement, *args: sqa.ColumnElement - ) -> sqa.ColumnElement: + def fix_fn_types(cls, fn: ColFn, val: sqa.ColumnElement, *args: sqa.ColumnElement) -> sqa.ColumnElement: return val @classmethod @@ -298,9 +255,7 @@ def compile_col_expr( elif isinstance(expr, ColFn): args: list[sqa.ColumnElement] = [ - cls.compile_col_expr( - arg, sqa_expr, compile_literals=not types.is_const(param) - ) + cls.compile_col_expr(arg, sqa_expr, compile_literals=not types.is_const(param)) for arg, param in zip( expr.args, expr.op.trie.best_match(tuple(arg.dtype() for arg in expr.args))[0], @@ -319,9 +274,7 @@ def compile_col_expr( arrange = expr.context_kwargs.get("arrange") if arrange: - order_by = dedup_order_by( - cls.compile_order(order, sqa_expr) for order in arrange - ) + order_by = dedup_order_by(cls.compile_order(order, sqa_expr) for order in arrange) else: order_by = None @@ -332,9 +285,7 @@ def compile_col_expr( # some backends need to do preprocessing and some postprocessing here, # so we just give them full control by passing the responsibility of # calling the `impl`. - value = cls.compile_ordered_aggregation( - *args, order_by=order_by, impl=impl - ) + value = cls.compile_ordered_aggregation(*args, order_by=order_by, impl=impl) else: value: sqa.FunctionElement = impl(*args) @@ -342,17 +293,11 @@ def compile_col_expr( if expr.op == ops.cum_sum and cls.dialect_order_append_rand(): order_by += [cls.get_impl(ops.rand, [])()] - if ( - partition_by is not None - or order_by is not None - and expr.ftype() == Ftype.WINDOW - ): + if partition_by is not None or order_by is not None and expr.ftype() == Ftype.WINDOW: value = sqa.over( value, partition_by=partition_by, - order_by=sqa.sql.expression.ClauseList(*order_by) - if order_by - else None, + order_by=sqa.sql.expression.ClauseList(*order_by) if order_by else None, ) return cls.fix_fn_types(expr, value, *args) @@ -366,11 +311,7 @@ def compile_col_expr( ) for cond, val in expr.cases ), - else_=( - cls.compile_col_expr(expr.default_val, sqa_expr) - if expr.default_val is not None - else None - ), + else_=(cls.compile_col_expr(expr.default_val, sqa_expr) if expr.default_val is not None else None), ) if not cls.pdt_type(res.type).is_subtype(expr.dtype()): @@ -395,23 +336,17 @@ def compile_col_expr( raise AssertionError @classmethod - def compile_query( - cls, table: sqa.Table, query: Query, sqa_expr: dict[UUID, sqa.ColumnElement] - ) -> sqa.sql.Select: + def compile_query(cls, table: sqa.Table, query: Query, sqa_expr: dict[UUID, sqa.ColumnElement]) -> sqa.sql.Select: sel = table.select().select_from(table) if query.where: - sel = sel.where( - *(cls.compile_col_expr(pred, sqa_expr) for pred in query.where) - ) + sel = sel.where(*(cls.compile_col_expr(pred, sqa_expr) for pred in query.where)) if query.group_by: sel = sel.group_by(*(sqa_expr[uid] for uid in query.group_by)) if query.having: - sel = sel.having( - 
*(cls.compile_col_expr(pred, sqa_expr) for pred in query.having) - ) + sel = sel.having(*(cls.compile_col_expr(pred, sqa_expr) for pred in query.having)) if query.limit is not None: sel = sel.limit(query.limit) @@ -419,20 +354,14 @@ def compile_query( sel = sel.offset(query.offset) if query.order_by: - sel = sel.order_by( - *dedup_order_by( - cls.compile_order(ord, sqa_expr) for ord in query.order_by - ) - ) + sel = sel.order_by(*dedup_order_by(cls.compile_order(ord, sqa_expr) for ord in query.order_by)) sel = sel.with_only_columns(*(sqa_expr[uid] for uid in query.select)) return sel @classmethod - def compile_ast( - cls, nd: AstNode, needed_cols: dict[UUID, int] - ) -> tuple[sqa.Table, Query, dict[UUID, sqa.Label]]: + def compile_ast(cls, nd: AstNode, needed_cols: dict[UUID, int]) -> tuple[sqa.Table, Query, dict[UUID, sqa.Label]]: if isinstance(nd, verbs.Verb): # store a counter in `needed_cols how often each UUID is referenced by # ancestors. This allows to only select necessary columns in a subquery. @@ -447,9 +376,7 @@ def compile_ast( table, query, sqa_expr = cls.compile_ast(nd.child, needed_cols) if isinstance(nd, verbs.Mutate | verbs.Summarize): - query.select = [ - uid for uid in query.select if sqa_expr[uid].name not in set(nd.names) - ] + query.select = [uid for uid in query.select if sqa_expr[uid].name not in set(nd.names)] if isinstance(nd, verbs.SubqueryMarker): if needed_cols.keys().isdisjoint(sqa_expr.keys()): @@ -480,9 +407,7 @@ def compile_ast( table = cls.compile_query(table, query, sqa_expr).subquery() sqa_expr = { - uid: sqa.label( - name_in_subquery[uid], table.columns.get(name_in_subquery[uid]) - ) + uid: sqa.label(name_in_subquery[uid], table.columns.get(name_in_subquery[uid])) for uid in needed_cols.keys() if uid in sqa_expr } @@ -497,11 +422,7 @@ def compile_ast( elif isinstance(nd, verbs.Rename): sqa_expr = { - uid: ( - sqa.label(nd.name_map[lb.name], lb) - if lb.name in nd.name_map - else lb - ) + uid: (sqa.label(nd.name_map[lb.name], lb) if lb.name in nd.name_map else lb) for uid, lb in sqa_expr.items() } @@ -526,11 +447,7 @@ def compile_ast( uid: sqa.label(name, cls.compile_col_expr(val, sqa_expr)) for name, uid, val in zip(nd.names, nd.uuids, nd.values, strict=True) } - query.group_by.extend( - col._uuid - for col in query.partition_by - if not types.is_const(col.dtype()) - ) + query.group_by.extend(col._uuid for col in query.partition_by if not types.is_const(col.dtype())) query.select = [col._uuid for col in query.partition_by] + nd.uuids query.partition_by = [] query.order_by.clear() @@ -554,9 +471,7 @@ def compile_ast( query.partition_by.clear() elif isinstance(nd, verbs.Join): - right_table, right_query, right_sqa_expr = cls.compile_ast( - nd.right, needed_cols - ) + right_table, right_query, right_sqa_expr = cls.compile_ast(nd.right, needed_cols) sqa_expr.update(right_sqa_expr) compiled_on = cls.compile_col_expr(nd.on, sqa_expr) @@ -568,10 +483,7 @@ def compile_ast( operator.and_, ( compiled_on, - *( - cls.compile_col_expr(pred, right_sqa_expr) - for pred in right_query.where - ), + *(cls.compile_col_expr(pred, right_sqa_expr) for pred in right_query.where), ), ) elif nd.how == "full": @@ -589,6 +501,69 @@ def compile_ast( assert not right_query.partition_by assert not right_query.group_by + elif isinstance(nd, verbs.Union): + # For UNION, both queries must select the same columns in the same order + # First compile the right AST to get right_query.select (which contains only selected/visible columns) + right_table, right_query, right_sqa_expr = 
cls.compile_ast(nd.right, needed_cols) + + # Get column names from both sides + left_select = query.select + left_col_names = [sqa_expr[uid].name for uid in left_select] + right_select = right_query.select + right_col_names = [right_sqa_expr[uid].name for uid in right_select] + + # If column order doesn't match, wrap right AST with a Select to reorder + if left_col_names != right_col_names: + # Get right cache to access Col objects for reordering + from pydiverse.transform._internal.pipe.cache import Cache + + right_cache = Cache.from_ast(nd.right) + + # Get Col objects from right cache in the order of left columns + reordered_cols = [] + for name in left_col_names: + if name not in right_cache.name_to_uuid: + raise ValueError(f"union requires matching column names: '{name}' not found in right table") + uid = right_cache.name_to_uuid[name] + col = right_cache.cols[uid] + reordered_cols.append(col) + + # Wrap right AST with Select to reorder columns and recompile + right_ast = verbs.Select(nd.right, reordered_cols) + right_table, right_query, right_sqa_expr = cls.compile_ast(right_ast, needed_cols) + + # Build left and right select statements + left_sel = cls.compile_query(table, query, sqa_expr) + right_sel = cls.compile_query(right_table, right_query, right_sqa_expr) + + # If either side is a subquery, get the original CompoundSelect + # to allow calling sa.union/union_all again + if isinstance(left_sel, sqa.sql.selectable.Subquery): + left_sel = left_sel.original + if isinstance(right_sel, sqa.sql.selectable.Subquery): + right_sel = right_sel.original + + # Use UNION or UNION ALL + # distinct=True means UNION (remove duplicates), distinct=False means UNION ALL (keep duplicates) + if nd.distinct: + union_query = sqa.union(left_sel, right_sel) + else: + union_query = sqa.union_all(left_sel, right_sel) + + # Create a subquery from the union + table = union_query.subquery() + + # Update sqa_expr to point to the union result columns + # Use left column names + sqa_expr = {uid: sqa.label(sqa_expr[uid].name, table.columns[sqa_expr[uid].name]) for uid in left_select} + + # Create a new query with the union result + # Only keep the select columns, reset all other query state + query = Query(select=left_select) + + assert not right_query.partition_by + assert not right_query.group_by + elif isinstance(nd, TableImpl): table = nd.table query = Query(select=[col._uuid for col in nd.cols.values()]) @@ -856,16 +831,12 @@ def _len(): @impl(ops.shift) def _shift(x, by, empty_value=None): if by >= 0: - if empty_value is not None and not isinstance( - empty_value.type, sqa.types.NullType - ): + if empty_value is not None and not isinstance(empty_value.type, sqa.types.NullType): return sqa.func.LAG(x, by, empty_value, type_=x.type) else: return sqa.func.LAG(x, by, type_=x.type) if by < 0: - if empty_value is not None and not isinstance( - empty_value.type, sqa.types.NullType - ): + if empty_value is not None and not isinstance(empty_value.type, sqa.types.NullType): return sqa.func.LEAD(x, -by, empty_value, type_=x.type) else: return sqa.func.LEAD(x, -by, type_=x.type) @@ -966,9 +937,7 @@ def _log10(x): def _clip(x, lower, upper, *, _Impl: SqlImpl, _sig: Sequence[Dtype]): return sqa.case( (x.is_(sqa.null()), sqa.null()), - else_=_Impl.get_impl(ops.horizontal_max, _sig)( - _Impl.get_impl(ops.horizontal_min, _sig)(x, upper), lower - ), + else_=_Impl.get_impl(ops.horizontal_max, _sig)(_Impl.get_impl(ops.horizontal_min, _sig)(x, upper), lower), ) @impl(ops.cum_sum) diff --git 
a/src/pydiverse/transform/_internal/backend/sqlite.py b/src/pydiverse/transform/_internal/backend/sqlite.py index 5cb8ae6d..c9877165 100644 --- a/src/pydiverse/transform/_internal/backend/sqlite.py +++ b/src/pydiverse/transform/_internal/backend/sqlite.py @@ -64,21 +64,13 @@ def cast_compiled(cls, cast: Cast, compiled_expr: sqa.ColumnElement): return sqa.cast(compiled_expr, cls.sqa_type(cast.target_type)) @classmethod - def fix_fn_types( - cls, fn: ColFn, val: sqa.ColumnElement, *args: sqa.ColumnElement - ) -> sqa.ColumnElement: - if ( - fn.op - in (ops.horizontal_min, ops.horizontal_max, ops.mean, ops.min, ops.max) - and fn.dtype().is_float() - ): + def fix_fn_types(cls, fn: ColFn, val: sqa.ColumnElement, *args: sqa.ColumnElement) -> sqa.ColumnElement: + if fn.op in (ops.horizontal_min, ops.horizontal_max, ops.mean, ops.min, ops.max) and fn.dtype().is_float(): return sqa.cast(val, sqa.Double) return val @classmethod - def compile_ordered_aggregation( - cls, *args: sqa.ColumnElement, order_by: list[sqa.UnaryExpression], impl - ): + def compile_ordered_aggregation(cls, *args: sqa.ColumnElement, order_by: list[sqa.UnaryExpression], impl): from sqlalchemy.dialects import postgresql from sqlalchemy.dialects.postgresql import aggregate_order_by @@ -201,9 +193,7 @@ def _is_not_nan(x): @impl(ops.cbrt) def _cbrt(x): pow_impl = SqliteImpl.get_impl(ops.pow, (Float(), Float())) - return sqa.func.sign(x) * pow_impl( - sqa.func.abs(x), sqa.literal(1 / 3, type_=sqa.Double) - ) + return sqa.func.sign(x) * pow_impl(sqa.func.abs(x), sqa.literal(1 / 3, type_=sqa.Double)) @impl(ops.clip) def _clip(x, lower, upper): diff --git a/src/pydiverse/transform/_internal/backend/table_impl.py b/src/pydiverse/transform/_internal/backend/table_impl.py index 61cb7d6e..195e34b9 100644 --- a/src/pydiverse/transform/_internal/backend/table_impl.py +++ b/src/pydiverse/transform/_internal/backend/table_impl.py @@ -35,14 +35,9 @@ class TableImpl(AstNode): def __init__(self, name: str | None, schema: dict[str, Dtype]): self.name = name - self.cols = { - name: Col(name, self, uuid.uuid1(), dtype, Ftype.ELEMENT_WISE) - for name, dtype in schema.items() - } - - def _unformatted_ast_repr( - self, verb_depth: int, expr_depth: int, display_name_map - ) -> str: + self.cols = {name: Col(name, self, uuid.uuid1(), dtype, Ftype.ELEMENT_WISE) for name, dtype in schema.items()} + + def _unformatted_ast_repr(self, verb_depth: int, expr_depth: int, display_name_map) -> str: return self._ast_node_repr(expr_depth, display_name_map) def _ast_node_repr(self, expr_depth, display_name_map): @@ -52,9 +47,7 @@ def _table_def_repr(self) -> str: raise NotImplementedError() def short_name(self): - return ( - "?" if self.name is None else self.name - ) + f" (source table, backend: '{self.backend_name}')" + return ("?" 
if self.name is None else self.name) + f" (source table, backend: '{self.backend_name}')" def __init_subclass__(cls) -> None: cls.impl_store = ImplStore() @@ -77,16 +70,12 @@ def from_resource( res = resource elif isinstance(resource, dict): - return TableImpl.from_resource( - pl.DataFrame(resource), backend, name=name, uuids=uuids - ) + return TableImpl.from_resource(pl.DataFrame(resource), backend, name=name, uuids=uuids) elif pd is not None and isinstance(resource, pd.DataFrame): # copy pandas dataframe to polars # TODO: try zero-copy for arrow backed pandas - return TableImpl.from_resource( - pl.DataFrame(resource), backend, name=name, uuids=uuids - ) + return TableImpl.from_resource(pl.DataFrame(resource), backend, name=name, uuids=uuids) elif isinstance(resource, pl.DataFrame | pl.LazyFrame): if name is None: @@ -155,8 +144,7 @@ def get_impl(cls, op: "Operator", sig: Sequence[Dtype]) -> Any: return cls.__bases__[0].get_impl(op, sig) except NotSupportedError as err: raise NotSupportedError( - f"operation `{op.name}` is not supported by the backend " - f"`{cls.__name__.lower()[:-4]}`" + f"operation `{op.name}` is not supported by the backend `{cls.__name__.lower()[:-4]}`" ) from err diff --git a/src/pydiverse/transform/_internal/errors/__init__.py b/src/pydiverse/transform/_internal/errors/__init__.py index 829dc376..7f641dfd 100644 --- a/src/pydiverse/transform/_internal/errors/__init__.py +++ b/src/pydiverse/transform/_internal/errors/__init__.py @@ -64,11 +64,7 @@ def check_arg_type( ): if not isinstance(arg, expected_type): type_args = typing.get_args(expected_type) - expected_type_str = ( - expected_type.__name__ - if not type_args - else " | ".join(t.__name__ for t in type_args) - ) + expected_type_str = expected_type.__name__ if not type_args else " | ".join(t.__name__ for t in type_args) raise TypeError( f"argument for parameter `{param_name}` of `{fn}` must have type " f"`{expected_type_str}`, found `{type(arg).__name__}` instead" @@ -79,14 +75,9 @@ def check_vararg_type(expected_type: type, fn: str, *args: Any): for arg in args: if not isinstance(arg, expected_type): type_args = typing.get_args(expected_type) - expected_type_str = ( - expected_type.__name__ - if not type_args - else " | ".join(t.__name__ for t in type_args) - ) + expected_type_str = expected_type.__name__ if not type_args else " | ".join(t.__name__ for t in type_args) raise TypeError( - f"varargs to `{fn}` must have type `{expected_type_str}`, found " - f"`{type(arg).__name__}` instead" + f"varargs to `{fn}` must have type `{expected_type_str}`, found `{type(arg).__name__}` instead" ) diff --git a/src/pydiverse/transform/_internal/ops/ops/comparison.py b/src/pydiverse/transform/_internal/ops/ops/comparison.py index 71adbacb..e2c86d62 100644 --- a/src/pydiverse/transform/_internal/ops/ops/comparison.py +++ b/src/pydiverse/transform/_internal/ops/ops/comparison.py @@ -5,13 +5,9 @@ from pydiverse.transform._internal.ops.signature import Signature from pydiverse.transform._internal.tree.types import COMPARABLE, Bool, Const, S -equal = Operator( - "__eq__", Signature(S, S, return_type=Bool()), doc="Equality comparison ==" -) +equal = Operator("__eq__", Signature(S, S, return_type=Bool()), doc="Equality comparison ==") -not_equal = Operator( - "__ne__", Signature(S, S, return_type=Bool()), doc="Non-equality comparison !=" -) +not_equal = Operator("__ne__", Signature(S, S, return_type=Bool()), doc="Non-equality comparison !=") less_than = Operator( diff --git a/src/pydiverse/transform/_internal/ops/ops/datetime.py 
b/src/pydiverse/transform/_internal/ops/ops/datetime.py index 39f83856..28b2024b 100644 --- a/src/pydiverse/transform/_internal/ops/ops/datetime.py +++ b/src/pydiverse/transform/_internal/ops/ops/datetime.py @@ -44,9 +44,7 @@ def __init__(self, name: str, doc: str | None = None): """, ) -dt_microsecond = DatetimeExtract( - "dt.microsecond", doc="The microsecond component of the datetime." -) +dt_microsecond = DatetimeExtract("dt.microsecond", doc="The microsecond component of the datetime.") dt_day_of_week = DateExtract( "dt.day_of_week", diff --git a/src/pydiverse/transform/_internal/ops/ops/numeric.py b/src/pydiverse/transform/_internal/ops/ops/numeric.py index ea0fd419..0f3e8b59 100644 --- a/src/pydiverse/transform/_internal/ops/ops/numeric.py +++ b/src/pydiverse/transform/_internal/ops/ops/numeric.py @@ -89,13 +89,9 @@ sin = Operator("sin", Signature(Float(), return_type=Float()), doc="Computes the sine.") -cos = Operator( - "cos", Signature(Float(), return_type=Float()), doc="Computes the cosine." -) +cos = Operator("cos", Signature(Float(), return_type=Float()), doc="Computes the cosine.") -tan = Operator( - "tan", Signature(Float(), return_type=Float()), doc="Computes the tangent." -) +tan = Operator("tan", Signature(Float(), return_type=Float()), doc="Computes the tangent.") asin = Operator( "asin", @@ -116,13 +112,9 @@ ) -sqrt = Operator( - "sqrt", Signature(Float(), return_type=Float()), doc="Computes the square root." -) +sqrt = Operator("sqrt", Signature(Float(), return_type=Float()), doc="Computes the square root.") -cbrt = Operator( - "cbrt", Signature(Float(), return_type=Float()), doc="Computes the cube root." -) +cbrt = Operator("cbrt", Signature(Float(), return_type=Float()), doc="Computes the cube root.") is_inf = Operator( diff --git a/src/pydiverse/transform/_internal/ops/ops/string.py b/src/pydiverse/transform/_internal/ops/ops/string.py index b6d2c59d..7c31e01d 100644 --- a/src/pydiverse/transform/_internal/ops/ops/string.py +++ b/src/pydiverse/transform/_internal/ops/ops/string.py @@ -277,9 +277,7 @@ def __init__(self, name: str, doc: str = ""): str_contains = Operator( "str.contains", - Signature( - String(), Const(String()), Const(Bool()), Const(Bool()), return_type=Bool() - ), + Signature(String(), Const(String()), Const(Bool()), Const(Bool()), return_type=Bool()), param_names=["self", "pattern", "allow_regex", "true_if_regex_unsupported"], default_values=[..., ..., True, False], doc=""" @@ -382,8 +380,6 @@ def __init__(self, name: str, doc: str = ""): """, ) -str_to_datetime = Operator( - "str.to_datetime", Signature(String(), return_type=Datetime()) -) +str_to_datetime = Operator("str.to_datetime", Signature(String(), return_type=Datetime())) str_to_date = Operator("str.to_date", Signature(String(), return_type=Date())) diff --git a/src/pydiverse/transform/_internal/ops/signature.py b/src/pydiverse/transform/_internal/ops/signature.py index 5330c07a..e97ceb9d 100644 --- a/src/pydiverse/transform/_internal/ops/signature.py +++ b/src/pydiverse/transform/_internal/ops/signature.py @@ -31,9 +31,7 @@ def __init__(self, *types: Dtype | EllipsisType, return_type: Dtype): class SignatureTrie: @dataclasses.dataclass(slots=True) class Node: - children: dict[Dtype, "SignatureTrie.Node"] = dataclasses.field( - default_factory=dict - ) + children: dict[Dtype, "SignatureTrie.Node"] = dataclasses.field(default_factory=dict) data: Any = None def insert( @@ -56,20 +54,14 @@ def insert( if sig[0] not in self.children: self.children[sig[0]] = SignatureTrie.Node() - 
self.children[sig[0]].insert( - sig[1:], data, last_is_vararg, last_type=sig[0] - ) + self.children[sig[0]].insert(sig[1:], data, last_is_vararg, last_type=sig[0]) - def all_matches( - self, sig: Sequence[Dtype], tyvars: dict[str, Dtype] - ) -> list[tuple[list[Dtype], Any]]: + def all_matches(self, sig: Sequence[Dtype], tyvars: dict[str, Dtype]) -> list[tuple[list[Dtype], Any]]: if len(sig) == 0: return [ ( [], - self.data - if not isinstance(self.data, Tyvar) - else tyvars[self.data.name], + self.data if not isinstance(self.data, Tyvar) else tyvars[self.data.name], ) ] @@ -78,29 +70,22 @@ def all_matches( for dtype, child in self.children.items(): base_type = types.without_const(dtype) match_dtype = ( - tyvars[base_type.name] - if isinstance(base_type, Tyvar) and base_type.name in tyvars - else dtype + tyvars[base_type.name] if isinstance(base_type, Tyvar) and base_type.name in tyvars else dtype ) if isinstance(types.without_const(match_dtype), Tyvar): assert tyvar is None tyvar = dtype elif types.converts_to(sig[0], match_dtype): matches.extend( - ([match_dtype] + match_sig, data) - for match_sig, data in child.all_matches(sig[1:], tyvars) + ([match_dtype] + match_sig, data) for match_sig, data in child.all_matches(sig[1:], tyvars) ) # When the current node is a type var, try every type we can convert to. if tyvar is not None: already_matched = {types.without_const(m[0][0]) for m in matches} for dtype in types.implicit_conversions(types.without_const(sig[0])): - match_dtype = ( - types.with_const(dtype) if types.is_const(tyvar) else dtype - ) - if dtype not in already_matched and types.converts_to( - sig[0], match_dtype - ): + match_dtype = types.with_const(dtype) if types.is_const(tyvar) else dtype + if dtype not in already_matched and types.converts_to(sig[0], match_dtype): matches.extend( ([match_dtype] + match_sig, data) for match_sig, data in self.children[tyvar].all_matches( @@ -121,15 +106,11 @@ def best_match(self, sig: Sequence[Dtype]) -> tuple[list[Dtype], Any] | None: if len(all_matches) == 0: return None - return all_matches[ - best_signature_match(sig, [match[0] for match in all_matches]) - ] + return all_matches[best_signature_match(sig, [match[0] for match in all_matches])] # returns the index of the signature in `candidates` that matches best -def best_signature_match( - sig: Sequence[Dtype], candidates: Sequence[Sequence[Dtype]] -) -> int: +def best_signature_match(sig: Sequence[Dtype], candidates: Sequence[Sequence[Dtype]]) -> int: assert len(candidates) > 0 best_index = 0 @@ -140,9 +121,7 @@ def best_signature_match( best_index = i + 1 best_distance = this_distance - assert ( - sum(int(best_distance == sig_distance(sig, match)) for match in candidates) == 1 - ) + assert sum(int(best_distance == sig_distance(sig, match)) for match in candidates) == 1 return best_index diff --git a/src/pydiverse/transform/_internal/pipe/aligned.py b/src/pydiverse/transform/_internal/pipe/aligned.py index 89732917..6228a0b0 100644 --- a/src/pydiverse/transform/_internal/pipe/aligned.py +++ b/src/pydiverse/transform/_internal/pipe/aligned.py @@ -53,9 +53,7 @@ def aligned(fn=None, *, with_: str | None = None): def decorator(fn): signature = inspect.signature(fn) if with_ is not None and with_ not in signature.parameters: - raise ValueError( - f"function `{fn.__name__}` has no argument named `{with_}`" - ) + raise ValueError(f"function `{fn.__name__}` has no argument named `{with_}`") @wraps(fn) def wrapper(*args, **kwargs): @@ -76,9 +74,7 @@ def wrapper(*args, **kwargs): return decorator 
-def eval_aligned( - val: ColExpr | pl.Series | pd.Series, with_: Table | Col | None = None -) -> EvalAligned: +def eval_aligned(val: ColExpr | pl.Series | pd.Series, with_: Table | Col | None = None) -> EvalAligned: """ Allows to evaluate a column expression containing columns from different tables and to use polars / pandas Series in column expressions. diff --git a/src/pydiverse/transform/_internal/pipe/cache.py b/src/pydiverse/transform/_internal/pipe/cache.py index 308c48ac..55f26cb8 100644 --- a/src/pydiverse/transform/_internal/pipe/cache.py +++ b/src/pydiverse/transform/_internal/pipe/cache.py @@ -72,10 +72,8 @@ def __repr__(self) -> str: @staticmethod def from_ast(node: AstNode) -> "Cache": if isinstance(node, verbs.Verb): - if isinstance(node, verbs.Join): - return Cache.from_ast(node.child).update( - node, right_cache=Cache.from_ast(node.right) - ) + if isinstance(node, verbs.Join | verbs.Union): + return Cache.from_ast(node.child).update(node, right_cache=Cache.from_ast(node.right)) else: return Cache.from_ast(node.child).update(node) @@ -92,9 +90,7 @@ def from_ast(node: AstNode) -> "Cache": backend=type(node), ) - def update( - self, node: verbs.Verb, *, right_cache: Optional["Cache"] = None - ) -> "Cache": + def update(self, node: verbs.Verb, *, right_cache: Optional["Cache"] = None) -> "Cache": """ Returns a new cache for `node`, assuming `self` is the cache of `node.child`. Does not modify `self`. @@ -104,14 +100,10 @@ def update( if isinstance(node, verbs.Alias): if node.uuid_map is not None: - res.name_to_uuid = { - name: node.uuid_map[uid] for name, uid in self.name_to_uuid.items() - } + res.name_to_uuid = {name: node.uuid_map[uid] for name, uid in self.name_to_uuid.items()} res.uuid_to_name = {uid: name for name, uid in res.name_to_uuid.items()} res.cols = { - node.uuid_map[uid]: Col( - col.name, node, node.uuid_map[uid], col._dtype, col._ftype - ) + node.uuid_map[uid]: Col(col.name, node, node.uuid_map[uid], col._dtype, col._ftype) for uid, col in self.cols.items() } res.partition_by = [node.uuid_map[uid] for uid in self.partition_by] @@ -119,11 +111,7 @@ def update( elif isinstance(node, verbs.Select): selected_uuids = set(col._uuid for col in node.select) - res.uuid_to_name = { - uid: name - for uid, name in self.uuid_to_name.items() - if uid in selected_uuids - } + res.uuid_to_name = {uid: name for uid, name in self.uuid_to_name.items() if uid in selected_uuids} res.name_to_uuid = {name: uid for uid, name in res.uuid_to_name.items()} elif isinstance(node, verbs.Rename): @@ -136,13 +124,9 @@ def update( elif isinstance(node, verbs.Mutate): res.cols = self.cols | { uid: Col(name, node, uid, val.dtype(), val.ftype(agg_is_window=True)) - for name, val, uid in zip( - node.names, node.values, node.uuids, strict=True - ) - } - res.name_to_uuid = self.name_to_uuid | { - name: uid for name, uid in zip(node.names, node.uuids, strict=True) + for name, val, uid in zip(node.names, node.values, node.uuids, strict=True) } + res.name_to_uuid = self.name_to_uuid | {name: uid for name, uid in zip(node.names, node.uuids, strict=True)} res.uuid_to_name = {uid: name for name, uid in res.name_to_uuid.items()} elif isinstance(node, verbs.Filter): @@ -157,18 +141,14 @@ def update( res.partition_by = [] elif isinstance(node, verbs.Summarize): - overwritten = { - col_name for col_name in node.names if col_name in self.name_to_uuid - } + overwritten = {col_name for col_name in node.names if col_name in self.name_to_uuid} cols = { self.uuid_to_name[uid]: self.cols[uid] for uid in 
self.partition_by
                if self.uuid_to_name[uid] not in overwritten
            } | {
                name: Col(name, node, uid, val.dtype(), val.ftype())
-                for name, val, uid in zip(
-                    node.names, node.values, node.uuids, strict=True
-                )
+                for name, val, uid in zip(node.names, node.values, node.uuids, strict=True)
             }
             res.cols = {col._uuid: col for _, col in cols.items()}
 
@@ -191,6 +171,21 @@ def update(
             res.limit = 0
             res.group_by = set()
 
+        elif isinstance(node, verbs.Union):
+            assert right_cache is not None
+
+            # For union, visible columns must match (validated in the verb function).
+            # Hidden columns are removed (we don't keep names for them and it is unlikely that they match in uuid).
+            res.cols = {uid: col for uid, col in self.cols.items() if uid in self.uuid_to_name}
+            # Visible columns should match, so we keep the left table's name_to_uuid
+            # (the right table's visible columns are the same by validation).
+            res.name_to_uuid = self.name_to_uuid.copy()
+            res.uuid_to_name = self.uuid_to_name.copy()
+
+            res.derived_from = self.derived_from | right_cache.derived_from
+            res.limit = 0
+            res.group_by = set()
+
         elif isinstance(node, verbs.SubqueryMarker):
             res.cols = {
                 uid: Col(
@@ -221,11 +216,7 @@ def requires_subquery(self, node: verbs.Verb) -> str | None:
         if (
             isinstance(
                 node,
-                verbs.Filter
-                | verbs.Summarize
-                | verbs.Arrange
-                | verbs.GroupBy
-                | verbs.Join,
+                verbs.Filter | verbs.Summarize | verbs.Arrange | verbs.GroupBy | verbs.Join | verbs.Union,
             )
             and self.limit != 0
         ):
@@ -238,16 +229,12 @@ def requires_subquery(self, node: verbs.Verb) -> str | None:
                 if isinstance(col, Col)
             )
             for fn in node.iter_col_nodes()
-            if (
-                isinstance(fn, ColFn) and fn.op.ftype in (Ftype.AGGREGATE, Ftype.WINDOW)
-            )
+            if (isinstance(fn, ColFn) and fn.op.ftype in (Ftype.AGGREGATE, Ftype.WINDOW))
         ):
             return "nested window / aggregation functions in `mutate`"
 
         if isinstance(node, verbs.Filter) and any(
-            col.ftype(agg_is_window=True) == Ftype.WINDOW
-            for col in node.iter_col_nodes()
-            if isinstance(col, Col)
+            col.ftype(agg_is_window=True) == Ftype.WINDOW for col in node.iter_col_nodes() if isinstance(col, Col)
         ):
             return "window function in `filter`"
 
@@ -267,19 +254,12 @@ def requires_subquery(self, node: verbs.Verb) -> str | None:
             if self.group_by:
                 return "join with a grouped table"
 
-            if (
-                node.how == "full"
-                or (node.child not in self.derived_from and node.how == "left")
-            ) and any(
-                types.is_const(self.cols[uid].dtype())
-                for uid in self.uuid_to_name.keys()
+            if (node.how == "full" or (node.child not in self.derived_from and node.how == "left")) and any(
+                types.is_const(self.cols[uid].dtype()) for uid in self.uuid_to_name.keys()
             ):
                 return "left / full join with a table containing a constant column"
 
-            if any(
-                self.cols[uid].ftype() == Ftype.WINDOW
-                for uid in self.uuid_to_name.keys()
-            ):
+            if any(self.cols[uid].ftype() == Ftype.WINDOW for uid in self.uuid_to_name.keys()):
                 return "join with a table containing window function expression"
 
             if any(
@@ -292,6 +272,13 @@ def requires_subquery(self, node: verbs.Verb) -> str | None:
             if self.is_filtered and node.how == "full":
                 return "full join with a filtered table"
 
+        if isinstance(node, verbs.Union):
+            if self.group_by:
+                return "union with a grouped table"
+
+            if any(self.cols[uid].ftype() == Ftype.WINDOW for uid in self.uuid_to_name.keys()):
+                return "union with a table containing window function expression"
+
         return None
 
     def selected_cols(self) -> list[Col]:
@@ -349,9 +336,7 @@ def transfer_col_references(table, ref_source):
     errors.check_arg_type(Table, "transfer_col_references", "table", table)
errors.check_arg_type(Table, "transfer_col_references", "ref_source", ref_source) - if ( - col := next((col for col in table if col.name not in ref_source), None) - ) is not None: + if (col := next((col for col in table if col.name not in ref_source), None)) is not None: raise ValueError( f"column {col.ast_repr()} of the table `{table._ast.short_name()}` does " "not exist in the reference source table " @@ -361,10 +346,7 @@ def transfer_col_references(table, ref_source): new = copy.copy(table) new._ast = Alias( new._ast, - uuid_map={ - uid: ref_source._cache.name_to_uuid[name] - for uid, name in table._cache.uuid_to_name.items() - }, + uuid_map={uid: ref_source._cache.name_to_uuid[name] for uid, name in table._cache.uuid_to_name.items()}, ) new._cache = table._cache.update(new._ast) diff --git a/src/pydiverse/transform/_internal/pipe/functions.py b/src/pydiverse/transform/_internal/pipe/functions.py index 5410891d..48d31f5f 100644 --- a/src/pydiverse/transform/_internal/pipe/functions.py +++ b/src/pydiverse/transform/_internal/pipe/functions.py @@ -34,14 +34,8 @@ def when(condition: ColExpr) -> WhenClause: condition = wrap_literals(condition) - if ( - condition.dtype() is not None - and not types.without_const(condition.dtype()) == Bool() - ): - raise errors.DataTypeError( - "argument for `when` must be of boolean type, but has type " - f"`{condition.dtype()}`" - ) + if condition.dtype() is not None and not types.without_const(condition.dtype()) == Bool(): + raise errors.DataTypeError(f"argument for `when` must be of boolean type, but has type `{condition.dtype()}`") return WhenClause([], wrap_literals(condition)) diff --git a/src/pydiverse/transform/_internal/pipe/table.py b/src/pydiverse/transform/_internal/pipe/table.py index 64654b79..d061262d 100644 --- a/src/pydiverse/transform/_internal/pipe/table.py +++ b/src/pydiverse/transform/_internal/pipe/table.py @@ -179,8 +179,7 @@ def __getitem__(self, key: str | Col | ColName) -> Col: # reference of a previous table. if key._uuid not in self._cache.uuid_to_name: raise ColumnNotFoundError( - f"column `{key.ast_repr()}` does not exist in table " - f"`{self._ast.short_name()}`" + f"column `{key.ast_repr()}` does not exist in table `{self._ast.short_name()}`" ) return Col( self._cache.uuid_to_name[key._uuid], @@ -196,9 +195,7 @@ def __getattr__(self, name: str) -> Col: # for hasattr to work correctly on dunder methods raise AttributeError if name not in self._cache.name_to_uuid: - raise ColumnNotFoundError( - f"column `{name}` does not exist in table `{self._ast.short_name()}`" - ) + raise ColumnNotFoundError(f"column `{name}` does not exist in table `{self._ast.short_name()}`") col = self._cache.cols[self._cache.name_to_uuid[name]] return Col(name, self._ast, col._uuid, col._dtype, col._ftype) @@ -291,10 +288,7 @@ def _repr_html_(self) -> str: else: tbl_name = f"Table {self >> name()} " - html = ( - tbl_name - + f"(backend: {self._cache.backend.backend_name})
" - ) + html = tbl_name + f"(backend: {self._cache.backend.backend_name})
" try: # We use polars' _repr_html_ on the first and last few rows of the table and # fix the `shape` afterwards. @@ -306,10 +300,7 @@ def _repr_html_(self) -> str: num_rows_end = df_html.find(",", num_rows_begin) except Exception as e: - return html + ( - "
export failed\n"
-                f"{escape(e.__class__.__name__)}: {escape(str(e))}
" - ) + return html + (f"
export failed\n{escape(e.__class__.__name__)}: {escape(str(e))}
") return f"{df_html[: num_rows_begin + 8]}{height}{df_html[num_rows_end:]}" @@ -328,18 +319,14 @@ def get_head_tail(tbl: Table) -> tuple[pl.DataFrame, int]: # Only export the first and last few rows. head: pl.DataFrame = tbl >> slice_head(head_tail_len) >> export(pdt.Polars) tail: pl.DataFrame = ( - tbl - >> slice_head(head_tail_len, offset=max(head_tail_len, height - head_tail_len)) - >> export(pdt.Polars) + tbl >> slice_head(head_tail_len, offset=max(head_tail_len, height - head_tail_len)) >> export(pdt.Polars) ) return pl.concat([head, tail]), height def backend( table: Table, -) -> Literal[ - "polars", "duckdb_polars", "sqlite", "postgres", "duckdb", "mssql", "ibm_db2" -]: +) -> Literal["polars", "duckdb_polars", "sqlite", "postgres", "duckdb", "mssql", "ibm_db2"]: """ Returns the backend of the table as a string. """ diff --git a/src/pydiverse/transform/_internal/pipe/verbs.py b/src/pydiverse/transform/_internal/pipe/verbs.py index c9e4f221..6b03c8b9 100644 --- a/src/pydiverse/transform/_internal/pipe/verbs.py +++ b/src/pydiverse/transform/_internal/pipe/verbs.py @@ -64,6 +64,7 @@ SliceHead, Summarize, Ungroup, + Union, ) __all__ = [ @@ -79,6 +80,7 @@ "left_join", "inner_join", "full_join", + "union", "filter", "arrange", "group_by", @@ -95,9 +97,7 @@ def alias(new_name: str | None = None, *, keep_col_refs: bool = False) -> Pipeab @verb @modify_ast -def alias( - table: Table, new_name: str | None = None, *, keep_col_refs: bool = False -) -> Pipeable: +def alias(table: Table, new_name: str | None = None, *, keep_col_refs: bool = False) -> Pipeable: """ Changes the name of the current table and allows subqueries in SQL. @@ -153,9 +153,7 @@ def alias( new = copy.copy(table) new._ast = Alias( new._ast, - uuid_map={uid: uuid.uuid1() for uid in table._cache.cols.keys()} - if not keep_col_refs - else None, + uuid_map={uid: uuid.uuid1() for uid in table._cache.cols.keys()} if not keep_col_refs else None, ) new._ast.name = new_name @@ -163,9 +161,7 @@ def alias( @overload -def collect( - target: Target | None = None, *, keep_col_refs: bool = True -) -> Pipeable: ... +def collect(target: Target | None = None, *, keep_col_refs: bool = True) -> Pipeable: ... @verb @@ -215,9 +211,7 @@ def collect( │ 4 ┆ -- r ┆ 10 │ └─────┴────────┴─────┘ """ - errors.check_arg_type( - Target | type(Target) | type(None), "collect", "target", target - ) + errors.check_arg_type(Target | type(Target) | type(None), "collect", "target", target) if target is None: target = Polars() @@ -245,17 +239,13 @@ def collect( ) ) new._cache.derived_from = table._cache.derived_from | {new._ast} - new._cache.partition_by = [ - preprocess_arg(col, new) for col in table._cache.partition_by - ] + new._cache.partition_by = [preprocess_arg(col, new) for col in table._cache.partition_by] return new @overload -def export( - target: Target, *, schema_overrides: dict[str, Any] | None = None -) -> Pipeable: ... +def export(target: Target, *, schema_overrides: dict[str, Any] | None = None) -> Pipeable: ... 
@verb @@ -326,24 +316,17 @@ def export( if isinstance(target, Scalar): if len(table) != 1: raise TypeError( - "to export a table to a scalar, it must have exactly one column, but " - f"found {len(table)} columns" + f"to export a table to a scalar, it must have exactly one column, but found {len(table)} columns" ) df: pl.DataFrame = table >> export(Polars()) if df.height != 1: - raise TypeError( - "to export a table to a scalar, it must have exactly one row, but " - f"found {df.height} rows" - ) + raise TypeError(f"to export a table to a scalar, it must have exactly one row, but found {df.height} rows") return df.item() elif isinstance(target, Dict): df: pl.DataFrame = table >> export(Polars()) if df.height != 1: - raise TypeError( - "cannot export a table with more than one row to `Dict`, " - f"found {df.height} rows" - ) + raise TypeError(f"cannot export a table with more than one row to `Dict`, found {df.height} rows") return df.to_dicts()[0] elif isinstance(target, DictOfLists): @@ -360,9 +343,7 @@ def export( return SourceBackend.export( table._ast.clone(), target, - schema_overrides={ - table[col_name]._uuid: dtype for col_name, dtype in schema_overrides.items() - }, + schema_overrides={table[col_name]._uuid: dtype for col_name, dtype in schema_overrides.items()}, ) @@ -404,10 +385,7 @@ def show_query(table: Table, pipe: bool = False) -> Pipeable | None: if query := table >> build_query(): print(query) else: - print( - f"No query to show for table {table._ast.short_name()}. " - f"(backend: {table._cache.backend.backend_name})" - ) + print(f"No query to show for table {table._ast.short_name()}. (backend: {table._cache.backend.backend_name})") return table if pipe else None @@ -446,10 +424,7 @@ def select(table: Table, *cols: Col | ColName | str) -> Pipeable: for col in cols: if isinstance(col, ColName | str) and col not in table: - raise ColumnNotFoundError( - f"column `{col.ast_repr()}` does not exist in table " - f"`{table._ast.short_name()}`" - ) + raise ColumnNotFoundError(f"column `{col.ast_repr()}` does not exist in table `{table._ast.short_name()}`") elif col not in table and col._uuid in table._cache.cols: raise ColumnNotFoundError( f"cannot select hidden column `{col.ast_repr()}` again\n" @@ -498,11 +473,7 @@ def drop(table: Table, *cols: Col | ColName | str) -> Pipeable: dropped_uuids = {preprocess_arg(col, table)._uuid for col in cols} return table >> select( - *( - name - for name, uid in table._cache.name_to_uuid.items() - if uid not in dropped_uuids - ), + *(name for name, uid in table._cache.name_to_uuid.items() if uid not in dropped_uuids), ) @@ -578,8 +549,7 @@ def rename(table: Table, name_map: dict[str | Col | ColName, str]) -> Pipeable: for v in name_map.values(): if not isinstance(v, str): raise TypeError( - "values in the `name_map` of `rename` must have type `str`, found " - f"`{v.__class__.__name__}` instead" + f"values in the `name_map` of `rename` must have type `str`, found `{v.__class__.__name__}` instead" ) for k in name_map.keys(): if not isinstance(k, str | Col | ColName): @@ -588,24 +558,13 @@ def rename(table: Table, name_map: dict[str | Col | ColName, str]) -> Pipeable: f"ColName`, found `{k.__class__.__name__}` instead" ) - name_map = { - (preprocess_arg(k, table) if isinstance(k, ColName | Col) else k): v - for k, v in name_map.items() - } - name_map = { - (table._cache.uuid_to_name[k._uuid] if isinstance(k, Col) else k): v - for k, v in name_map.items() - } + name_map = {(preprocess_arg(k, table) if isinstance(k, ColName | Col) else k): v for k, v 
in name_map.items()} + name_map = {(table._cache.uuid_to_name[k._uuid] if isinstance(k, Col) else k): v for k, v in name_map.items()} if d := set(name_map).difference(table._cache.name_to_uuid): - raise ValueError( - f"no column with name `{next(iter(d))}` in table " - f"`{table._ast.short_name()}`" - ) + raise ValueError(f"no column with name `{next(iter(d))}` in table `{table._ast.short_name()}`") - if d := (set(table._cache.name_to_uuid).difference(name_map)) & set( - name_map.values() - ): + if d := (set(table._cache.name_to_uuid).difference(name_map)) & set(name_map.values()): raise ValueError(f"rename would cause duplicate column name `{next(iter(d))}`") new = copy.copy(table) @@ -793,8 +752,7 @@ def arrange(table: Table, by: ColExpr, *more_by: ColExpr) -> Pipeable: preprocessed.append(preprocess_arg(Order.from_col_expr(ord), table)) except ErrorWithSource as e: raise type(e)( - e.args[0] - + f"\noccurred in {i}-{ordinal_suffix(i)} `arrange` argument\n" + e.args[0] + f"\noccurred in {i}-{ordinal_suffix(i)} `arrange` argument\n" f"AST path to error source: {get_ast_path_str(ord, e.source)}" ) from e @@ -965,31 +923,18 @@ def summarize(table: Table, **kwargs: ColExpr) -> Pipeable: partition_by = set(table._cache.partition_by) if len(kwargs) == 0 and len(partition_by) == 0: - raise ValueError( - "summarize without preceding group_by needs at least one column to " - "summarize" - ) + raise ValueError("summarize without preceding group_by needs at least one column to summarize") def check_summarize_col_expr(expr: ColExpr, agg_fn_above: bool): - if ( - isinstance(expr, Col) - and expr._uuid not in partition_by - and not agg_fn_above - ): + if isinstance(expr, Col) and expr._uuid not in partition_by and not agg_fn_above: raise FunctionTypeError( - f"column `{expr.ast_repr()}` is neither aggregated nor part of the " - "grouping columns" + f"column `{expr.ast_repr()}` is neither aggregated nor part of the grouping columns" ) elif isinstance(expr, ColFn): if expr.op.ftype == Ftype.WINDOW: - raise FunctionTypeError( - f"forbidden window function `{expr.op.name}` in `summarize`" - ) - elif ( - expr.op.ftype == Ftype.AGGREGATE - and "partition_by" not in expr.context_kwargs - ): + raise FunctionTypeError(f"forbidden window function `{expr.op.name}` in `summarize`") + elif expr.op.ftype == Ftype.AGGREGATE and "partition_by" not in expr.context_kwargs: agg_fn_above = True for child in expr.iter_children(): @@ -1150,9 +1095,7 @@ def join( errors.check_arg_type(ColExpr | list | str, "join", "on", on) errors.check_arg_type(str | None, "join", "suffix", suffix) errors.check_literal_type(["inner", "left", "full"], "join", "how", how) - errors.check_literal_type( - ["1:1", "1:m", "m:1", "m:m"], "join", "validate", validate - ) + errors.check_literal_type(["1:1", "1:m", "m:1", "m:m"], "join", "validate", validate) if left._cache.backend != right._cache.backend: raise TypeError("cannot join two tables with different backends") @@ -1190,17 +1133,10 @@ def _preprocess_on(expr: ColExpr): ) return left[expr.name] if expr not in right: - raise ValueError( - f"no column with name `{expr.name}` found in the left or right " - "table" - ) + raise ValueError(f"no column with name `{expr.name}` found in the left or right table") return right[expr.name] - if ( - isinstance(expr, Col) - and expr._uuid not in left._cache.cols - and expr._uuid not in right._cache.cols - ): + if isinstance(expr, Col) and expr._uuid not in left._cache.cols and expr._uuid not in right._cache.cols: raise ValueError( f"column 
`{expr.ast_repr()}` used in `on` neither exists in the table " f"`{left._ast.short_name()}` nor in the table " @@ -1259,18 +1195,14 @@ def _preprocess_on(expr: ColExpr): suffix += f"_{cnt}" on_uuids = set( - col._uuid - for col in itertools.chain(*(pred.iter_subtree_preorder() for pred in on)) - if isinstance(col, Col) + col._uuid for col in itertools.chain(*(pred.iter_subtree_preorder() for pred in on)) if isinstance(col, Col) ) right_on_names = set(col.name for col in right if col._uuid in on_uuids) if not (right_names - right_on_names) & left_names: # If nothing except join columns clashes, we only rename the clashing # columns on the right. - right >>= rename( - {col: col.name + suffix for col in right if col.name in left_names} - ) + right >>= rename({col: col.name + suffix for col in right if col.name in left_names}) else: right >>= rename({col: col.name + suffix for col in right}) @@ -1405,6 +1337,176 @@ def cross_join( return left >> join(right, how="inner", on=[], suffix=suffix) +@overload +def union( + right: Table, + *, + distinct: bool = False, +) -> Pipeable: ... + + +@overload +def union( + left: Table, + right: Table, + *, + distinct: bool = False, +) -> Table: ... + + +def _union_impl( + left: Table, + right: Table, + *, + distinct: bool = False, +) -> Table: + """ + Unions two tables by stacking rows vertically. + + The left table in the union comes through the pipe `>>` operator from the + left. + + :param right: + The right table to union with. + + :param distinct: + If ``True``, performs UNION (removes duplicates). If ``False``, + performs UNION ALL (keeps duplicates). + + Note + ---- + Both tables must have the same number of columns with compatible types. + Column names must match between the two tables. + + Examples + -------- + >>> t1 = pdt.Table({"a": [1, 2, 3], "b": [4, 5, 6]}, name="t1") + >>> t2 = pdt.Table({"a": [7, 8], "b": [9, 10]}, name="t2") + >>> t1 >> union(t2) >> show() + shape: (5, 2) + ┌─────┬─────┐ + │ a ┆ b │ + │ --- ┆ --- │ + │ i64 ┆ i64 │ + ╞═════╪═════╡ + │ 1 ┆ 4 │ + │ 2 ┆ 5 │ + │ 3 ┆ 6 │ + │ 7 ┆ 9 │ + │ 8 ┆ 10 │ + └─────┴─────┘ + """ + errors.check_arg_type(Table, "union", "right", right) + + if left._cache.backend != right._cache.backend: + raise TypeError("cannot union two tables with different backends") + + if left._cache.partition_by: + raise ValueError(f"cannot union grouped table `{left._ast.short_name()}`") + elif right._cache.partition_by: + raise ValueError(f"cannot union grouped table `{right._ast.short_name()}`") + + # Check that both tables have the same columns + left_cols = set(left._cache.name_to_uuid.keys()) + right_cols = set(right._cache.name_to_uuid.keys()) + + if left_cols != right_cols: + missing_left = right_cols - left_cols + missing_right = left_cols - right_cols + error_msg = "tables must have the same columns for union" + if missing_left: + error_msg += f"\n columns in right but not in left: {sorted(missing_left)}" + if missing_right: + error_msg += f"\n columns in left but not in right: {sorted(missing_right)}" + raise ValueError(error_msg) + + # Check column type compatibility using lca_type + from pydiverse.transform._internal.tree.types import lca_type + + for col_name in left_cols: + left_col = left._cache.cols[left._cache.name_to_uuid[col_name]] + right_col = right._cache.cols[right._cache.name_to_uuid[col_name]] + left_dtype = left_col.dtype() + right_dtype = right_col.dtype() + + # Check if types are compatible by trying to find a common ancestor + try: + lca_type([left_dtype, right_dtype]) + except 
DataTypeError as e: + raise TypeError( + f"column '{col_name}' has incompatible types: left has {left_dtype}, right has {right_dtype}" + ) from e + + new = copy.copy(left) + new._ast = Union(left._ast, right._ast, distinct) + + new, left = check_subquery(new, left) + new, right = check_subquery(new, right, is_right=True) + + new._cache = left._cache.update(new._ast, right_cache=right._cache) + + return new + + +def union( + left_or_right: Table, + right: Table | None = None, + *, + distinct: bool = False, +) -> Table | Pipeable: + """ + Unions two tables by stacking rows vertically. + + The left table in the union comes through the pipe `>>` operator from the + left. + + :param right: + The right table to union with. + + :param distinct: + If ``True``, performs UNION (removes duplicates). If ``False``, + performs UNION ALL (keeps duplicates). + + Note + ---- + Both tables must have the same number of columns with compatible types. + Column names must match between the two tables. + + Examples + -------- + >>> t1 = pdt.Table({"a": [1, 2, 3], "b": [4, 5, 6]}, name="t1") + >>> t2 = pdt.Table({"a": [7, 8], "b": [9, 10]}, name="t2") + >>> t1 >> union(t2) >> show() + shape: (5, 2) + ┌─────┬─────┐ + │ a ┆ b │ + │ --- ┆ --- │ + │ i64 ┆ i64 │ + ╞═════╪═════╡ + │ 1 ┆ 4 │ + │ 2 ┆ 5 │ + │ 3 ┆ 6 │ + │ 7 ┆ 9 │ + │ 8 ┆ 10 │ + └─────┴─────┘ + + You can also call it directly: + >>> union(t1, t2) >> show() + """ + # If called with two arguments directly (union(tbl1, tbl2)), return Table directly + if right is not None: + return _union_impl(left_or_right, right, distinct=distinct) # returns Table + + # If called with one argument (union(tbl2)), behave like a verb for pipe syntax + # Create a verb wrapper that will be called by the pipe operator + # Note: left_or_right is the right table in this branch (since right is None) + @verb + def _union_verb(left_table: Table, right_table: Table, distinct: bool) -> Table: + return _union_impl(left_table, right_table, distinct=distinct) + + return _union_verb(left_or_right, distinct=distinct) # returns Pipeable + + @overload def show(pipe: bool = False) -> Pipeable | None: ... @@ -1452,15 +1554,11 @@ def columns(table: Table) -> list[str]: @overload -def ast_repr( - verb_depth: int = 7, expr_depth: int = 2, pipe: bool = False -) -> Pipeable | None: ... +def ast_repr(verb_depth: int = 7, expr_depth: int = 2, pipe: bool = False) -> Pipeable | None: ... @verb -def ast_repr( - table: Table, verb_depth: int = 7, expr_depth: int = 2, *, pipe: bool = False -) -> Pipeable | None: +def ast_repr(table: Table, verb_depth: int = 7, expr_depth: int = 2, *, pipe: bool = False) -> Pipeable | None: r""" Prints the AST of the table to stdout. 
@@ -1526,15 +1624,8 @@ def preprocess_arg(arg: ColExpr, table: Table, *, agg_is_window: bool = True) -> assert isinstance(arg, ColExpr | Order) def _preprocess_expr(expr: ColExpr, eval_aligned: bool = False): - if ( - isinstance(expr, Col) - and expr._uuid not in table._cache.cols - and not eval_aligned - ): - raise ColumnNotFoundError( - f"column `{expr.ast_repr()}` does not exist in table " - f"`{table._ast.name}`" - ) + if isinstance(expr, Col) and expr._uuid not in table._cache.cols and not eval_aligned: + raise ColumnNotFoundError(f"column `{expr.ast_repr()}` does not exist in table `{table._ast.name}`") if not eval_aligned and isinstance(expr, Series): raise TypeError( @@ -1549,9 +1640,7 @@ def _preprocess_expr(expr: ColExpr, eval_aligned: bool = False): and "partition_by" not in expr.context_kwargs and (expr.op.ftype in (Ftype.WINDOW, Ftype.AGGREGATE)) ): - expr.context_kwargs["partition_by"] = [ - table._cache.cols[uid] for uid in table._cache.partition_by - ] + expr.context_kwargs["partition_by"] = [table._cache.cols[uid] for uid in table._cache.partition_by] if isinstance(expr, ColName): return table[expr.name] diff --git a/src/pydiverse/transform/_internal/tree/ast.py b/src/pydiverse/transform/_internal/tree/ast.py index 43273a96..36c34b7e 100644 --- a/src/pydiverse/transform/_internal/tree/ast.py +++ b/src/pydiverse/transform/_internal/tree/ast.py @@ -38,21 +38,13 @@ def next_nd(root: AstNode, cond: Callable[["AstNode"], bool]): if cond(nd): return nd - source_tbls = set( - nd - for nd in self.iter_subtree_preorder() - if isinstance(nd, TableImpl | Alias) - ) + source_tbls = set(nd for nd in self.iter_subtree_preorder() if isinstance(nd, TableImpl | Alias)) # Add required ASTs from aligned columns (they need not be in the subtree of # `self`) source_tbls |= set( next_nd(col._ast, lambda x: isinstance(x, Alias | TableImpl)) for col in itertools.chain( - *( - nd.iter_col_nodes() - for nd in self.iter_subtree_preorder() - if isinstance(nd, Verb) - ) + *(nd.iter_col_nodes() for nd in self.iter_subtree_preorder() if isinstance(nd, Verb)) ) if isinstance(col, Col) ) @@ -62,9 +54,7 @@ def next_nd(root: AstNode, cond: Callable[["AstNode"], bool]): for nd in source_tbls: display_name = nd.name or "tbl" # try to achieve valid python identifier names - display_name = ( - display_name.replace(" ", "_").replace(".", "_").replace("-", "_") - ) + display_name = display_name.replace(" ", "_").replace(".", "_").replace("-", "_") if display_name not in used: used.add(display_name) @@ -79,9 +69,7 @@ def next_nd(root: AstNode, cond: Callable[["AstNode"], bool]): # Find the last source table / alias for every node in the AST and use the # corresponding name. 
for nd in self.iter_subtree_postorder(): - table_display_name_map[nd] = table_display_name_map[ - next_nd(nd, lambda x: x in table_display_name_map) - ] + table_display_name_map[nd] = table_display_name_map[next_nd(nd, lambda x: x in table_display_name_map)] if isinstance(nd, Verb): for col in nd.iter_col_nodes(): @@ -94,11 +82,7 @@ def next_nd(root: AstNode, cond: Callable[["AstNode"], bool]): f"{display_name} = {tbl._table_def_repr()}" for tbl, display_name in table_display_name_map.items() if isinstance(tbl, TableImpl) - ) + ( - "\n\n(" - + self._unformatted_ast_repr(verb_depth, expr_depth, table_display_name_map) - + ")" - ) + ) + ("\n\n(" + self._unformatted_ast_repr(verb_depth, expr_depth, table_display_name_map) + ")") try: import black diff --git a/src/pydiverse/transform/_internal/tree/col_expr.py b/src/pydiverse/transform/_internal/tree/col_expr.py index 380a245c..3a0f123f 100644 --- a/src/pydiverse/transform/_internal/tree/col_expr.py +++ b/src/pydiverse/transform/_internal/tree/col_expr.py @@ -119,9 +119,7 @@ def _repr_html_(self) -> str: if isinstance(self, Col): return tbl_repr.replace("Table", "Column", 1) - return re.sub( - r"Table .* \(backend: .+\)
", "", tbl_repr, count=1 - ) + return re.sub(r"Table .* \(backend: .+\)
", "", tbl_repr, count=1) def _repr_pretty_(self, p, cycle): p.text(str(self) if not cycle else "...") @@ -132,9 +130,7 @@ def ast_repr(self, depth: int = -1) -> str: # Does the recursive calls and allows to alter the displayed names of the tables # of columns (this is necessary for the verb ast_repr). - def _ast_repr( - self, depth: int, needs_parens: bool, table_display_name_map: dict[AstNode, str] - ): + def _ast_repr(self, depth: int, needs_parens: bool, table_display_name_map: dict[AstNode, str]): raise NotImplementedError() def export(self, target: Target) -> pl.Series | pd.Series: @@ -237,9 +233,7 @@ def dtype(self) -> Dtype: def ftype(self, *, agg_is_window: bool | None = None) -> Ftype: return self._ftype - def map( - self, mapping: dict[tuple | ColExpr, ColExpr], *, default: ColExpr | None = None - ) -> CaseExpr: + def map(self, mapping: dict[tuple | ColExpr, ColExpr], *, default: ColExpr | None = None) -> CaseExpr: """ Replaces given values by other expressions. @@ -379,8 +373,7 @@ def cast(self, target_type: Dtype | type, *, strict: bool = True) -> Cast: errors.check_arg_type(bool, "ColExpr.cast", "strict", strict) if type(target_type) is type and not issubclass(target_type, Dtype): TypeError( - "argument for parameter `target_type` of `ColExpr.cast` must be an" - "instance or subclass of pdt.Dtype" + "argument for parameter `target_type` of `ColExpr.cast` must be aninstance or subclass of pdt.Dtype" ) return Cast(self, target_type, strict=strict) @@ -437,19 +430,13 @@ def __add__(self: ColExpr[String], rhs: ColExpr[String]) -> ColExpr[String]: ... def __add__(self: ColExpr[Bool], rhs: ColExpr[Bool]) -> ColExpr[Int]: ... @overload - def __add__( - self: ColExpr[Duration], rhs: ColExpr[Duration] - ) -> ColExpr[Duration]: ... + def __add__(self: ColExpr[Duration], rhs: ColExpr[Duration]) -> ColExpr[Duration]: ... @overload - def __add__( - self: ColExpr[Datetime], rhs: ColExpr[Duration] - ) -> ColExpr[Datetime]: ... + def __add__(self: ColExpr[Datetime], rhs: ColExpr[Duration]) -> ColExpr[Datetime]: ... @overload - def __add__( - self: ColExpr[Duration], rhs: ColExpr[Datetime] - ) -> ColExpr[Datetime]: ... + def __add__(self: ColExpr[Duration], rhs: ColExpr[Datetime]) -> ColExpr[Datetime]: ... def __add__(self: ColExpr, rhs: ColExpr) -> ColExpr: """ @@ -471,19 +458,13 @@ def __radd__(self: ColExpr[String], rhs: ColExpr[String]) -> ColExpr[String]: .. def __radd__(self: ColExpr[Bool], rhs: ColExpr[Bool]) -> ColExpr[Int]: ... @overload - def __radd__( - self: ColExpr[Duration], rhs: ColExpr[Duration] - ) -> ColExpr[Duration]: ... + def __radd__(self: ColExpr[Duration], rhs: ColExpr[Duration]) -> ColExpr[Duration]: ... @overload - def __radd__( - self: ColExpr[Datetime], rhs: ColExpr[Duration] - ) -> ColExpr[Datetime]: ... + def __radd__(self: ColExpr[Datetime], rhs: ColExpr[Duration]) -> ColExpr[Datetime]: ... @overload - def __radd__( - self: ColExpr[Duration], rhs: ColExpr[Datetime] - ) -> ColExpr[Datetime]: ... + def __radd__(self: ColExpr[Duration], rhs: ColExpr[Datetime]) -> ColExpr[Datetime]: ... def __radd__(self: ColExpr, rhs: ColExpr) -> ColExpr: """ @@ -779,44 +760,28 @@ def ceil(self: ColExpr[Float]) -> ColExpr[Float]: return ColFn(ops.ceil, self) @overload - def clip( - self: ColExpr[Int], lower_bound: int, upper_bound: int - ) -> ColExpr[Int]: ... + def clip(self: ColExpr[Int], lower_bound: int, upper_bound: int) -> ColExpr[Int]: ... @overload - def clip( - self: ColExpr[Float], lower_bound: float, upper_bound: float - ) -> ColExpr[Float]: ... 
+ def clip(self: ColExpr[Float], lower_bound: float, upper_bound: float) -> ColExpr[Float]: ... @overload - def clip( - self: ColExpr[String], lower_bound: str, upper_bound: str - ) -> ColExpr[String]: ... + def clip(self: ColExpr[String], lower_bound: str, upper_bound: str) -> ColExpr[String]: ... @overload - def clip( - self: ColExpr[Datetime], lower_bound: datetime, upper_bound: datetime - ) -> ColExpr[Datetime]: ... + def clip(self: ColExpr[Datetime], lower_bound: datetime, upper_bound: datetime) -> ColExpr[Datetime]: ... @overload - def clip( - self: ColExpr[Time], lower_bound: time, upper_bound: time - ) -> ColExpr[Time]: ... + def clip(self: ColExpr[Time], lower_bound: time, upper_bound: time) -> ColExpr[Time]: ... @overload - def clip( - self: ColExpr[Duration], lower_bound: timedelta, upper_bound: timedelta - ) -> ColExpr[Duration]: ... + def clip(self: ColExpr[Duration], lower_bound: timedelta, upper_bound: timedelta) -> ColExpr[Duration]: ... @overload - def clip( - self: ColExpr[Date], lower_bound: date, upper_bound: date - ) -> ColExpr[Date]: ... + def clip(self: ColExpr[Date], lower_bound: date, upper_bound: date) -> ColExpr[Date]: ... @overload - def clip( - self: ColExpr[Bool], lower_bound: bool, upper_bound: bool - ) -> ColExpr[Bool]: ... + def clip(self: ColExpr[Bool], lower_bound: bool, upper_bound: bool) -> ColExpr[Bool]: ... def clip(self: ColExpr, lower_bound: int, upper_bound: int) -> ColExpr: """ @@ -1641,9 +1606,7 @@ def shift( └──────┴───────┴─────┴───────┘ """ - return ColFn( - ops.shift, self, n, fill_value, partition_by=partition_by, arrange=arrange - ) + return ColFn(ops.shift, self, n, fill_value, partition_by=partition_by, arrange=arrange) def sin(self: ColExpr[Float]) -> ColExpr[Float]: """ @@ -1666,9 +1629,7 @@ def __sub__(self: ColExpr[Int], rhs: ColExpr[Int]) -> ColExpr[Int]: ... def __sub__(self: ColExpr[Float], rhs: ColExpr[Float]) -> ColExpr[Float]: ... @overload - def __sub__( - self: ColExpr[Datetime], rhs: ColExpr[Datetime] - ) -> ColExpr[Duration]: ... + def __sub__(self: ColExpr[Datetime], rhs: ColExpr[Datetime]) -> ColExpr[Duration]: ... @overload def __sub__(self: ColExpr[Date], rhs: ColExpr[Date]) -> ColExpr[Duration]: ... @@ -1687,9 +1648,7 @@ def __rsub__(self: ColExpr[Int], rhs: ColExpr[Int]) -> ColExpr[Int]: ... def __rsub__(self: ColExpr[Float], rhs: ColExpr[Float]) -> ColExpr[Float]: ... @overload - def __rsub__( - self: ColExpr[Datetime], rhs: ColExpr[Datetime] - ) -> ColExpr[Duration]: ... + def __rsub__(self: ColExpr[Datetime], rhs: ColExpr[Datetime]) -> ColExpr[Duration]: ... @overload def __rsub__(self: ColExpr[Date], rhs: ColExpr[Date]) -> ColExpr[Duration]: ... @@ -1842,9 +1801,7 @@ def contains( └────────┴────────────┴───────┴───────┴──────┘ """ - return ColFn( - ops.str_contains, self.arg, pattern, allow_regex, true_if_regex_unsupported - ) + return ColFn(ops.str_contains, self.arg, pattern, allow_regex, true_if_regex_unsupported) def ends_with(self: ColExpr[String], suffix: str) -> ColExpr[Bool]: """ @@ -1974,9 +1931,7 @@ def lower(self: ColExpr[String]) -> ColExpr[String]: return ColFn(ops.str_lower, self.arg) - def replace_all( - self: ColExpr[String], substr: str, replacement: str - ) -> ColExpr[String]: + def replace_all(self: ColExpr[String], substr: str, replacement: str) -> ColExpr[String]: """ Replaces all occurrences of a given substring by a different string. 
@@ -2021,9 +1976,7 @@ def replace_all( return ColFn(ops.str_replace_all, self.arg, substr, replacement) - def slice( - self: ColExpr[String], offset: ColExpr[Int], n: ColExpr[Int] - ) -> ColExpr[String]: + def slice(self: ColExpr[String], offset: ColExpr[Int], n: ColExpr[Int]) -> ColExpr[String]: """ Returns a substring of the input string. @@ -2349,9 +2302,7 @@ def agg( class Col(ColExpr): - def __init__( - self, name: str, _ast: AstNode, _uuid: UUID, _dtype: Dtype, _ftype: Ftype - ): + def __init__(self, name: str, _ast: AstNode, _uuid: UUID, _dtype: Dtype, _ftype: Ftype): self.name = name self._ast = _ast self._uuid = _uuid @@ -2366,9 +2317,7 @@ def __hash__(self) -> int: class ColName(ColExpr): - def __init__( - self, name: str, dtype: Dtype | None = None, ftype: Ftype | None = None - ): + def __init__(self, name: str, dtype: Dtype | None = None, ftype: Ftype | None = None): self.name = name super().__init__(dtype, ftype) @@ -2411,9 +2360,7 @@ class ColFn(ColExpr): def __init__(self, op: Operator, *args: ColExpr, **kwargs: list[ColExpr | Order]): self.op = op # While building the expression tree, we have to allow markers. - self.args: list[ColExpr] = [ - wrap_literals(arg, allow_markers=True) for arg in args - ] + self.args: list[ColExpr] = [wrap_literals(arg, allow_markers=True) for arg in args] self.context_kwargs = clean_kwargs(**kwargs) # An id to recognize the expression also after copying / preprocessing. Useful @@ -2474,23 +2421,14 @@ def parenthesize(s: str) -> str: # are called like pdt.(...) return f"{self.op.name}(...)" - return f"(...).{self.op.name}" + ( - "(...)" if len(self.args) > 1 or len(self.context_kwargs) > 0 else "()" - ) + return f"(...).{self.op.name}" + ("(...)" if len(self.args) > 1 or len(self.context_kwargs) > 0 else "()") arg_parens_limit = 1 + int(self.op.name in DUNDER_BINARY) args = [ - e._ast_repr(depth - 1, i <= arg_parens_limit, table_display_name_map) - for i, e in enumerate(self.args) + e._ast_repr(depth - 1, i <= arg_parens_limit, table_display_name_map) for i, e in enumerate(self.args) ] + [ - ( - f"{ckwarg}=[" - + ", ".join( - v._ast_repr(depth - 1, False, table_display_name_map) for v in val - ) - + "]" - ) + (f"{ckwarg}=[" + ", ".join(v._ast_repr(depth - 1, False, table_display_name_map) for v in val) + "]") for ckwarg, val in self.context_kwargs.items() ] if op_symbol := DUNDER_BINARY.get(self.op.name): @@ -2509,18 +2447,14 @@ def iter_children(self) -> Iterable[ColExpr]: def map_children(self, g: Callable[[ColExpr], ColExpr]): self.args = [g(arg) for arg in self.args] - self.context_kwargs = { - key: [g(val) for val in arr] for key, arr in self.context_kwargs.items() - } + self.context_kwargs = {key: [g(val) for val in arr] for key, arr in self.context_kwargs.items()} def dtype(self) -> Dtype: if self._dtype is not None: return self._dtype arg_dtypes = [arg.dtype() for arg in self.args] - context_kwarg_dtypes = [ - elem.dtype() for elem in itertools.chain(*self.context_kwargs.values()) - ] + context_kwarg_dtypes = [elem.dtype() for elem in itertools.chain(*self.context_kwargs.values())] # we don't need the context_kwargs' types but we need to make sure type checks # are run on them @@ -2536,8 +2470,7 @@ def dtype(self) -> Dtype: ) if self.op.ftype == Ftype.ELEMENT_WISE and all( - types.is_const(argt) - for argt in itertools.chain(arg_dtypes, context_kwarg_dtypes) + types.is_const(argt) for argt in itertools.chain(arg_dtypes, context_kwarg_dtypes) ): self._dtype = Const(self._dtype) @@ -2564,11 +2497,7 @@ def ftype(self, *, agg_is_window: bool 
| None = None): if None in ftypes: return None - actual_ftype = ( - Ftype.WINDOW - if self.op.ftype == Ftype.AGGREGATE and agg_is_window - else self.op.ftype - ) + actual_ftype = Ftype.WINDOW if self.op.ftype == Ftype.AGGREGATE and agg_is_window else self.op.ftype if actual_ftype == Ftype.ELEMENT_WISE: if Ftype.WINDOW in ftypes: @@ -2586,9 +2515,7 @@ def ftype(self, *, agg_is_window: bool | None = None): if ( node is not self and isinstance(node, ColFn) - and ( - (desc_ftype := node.op.ftype) in (Ftype.AGGREGATE, Ftype.WINDOW) - ) + and ((desc_ftype := node.op.ftype) in (Ftype.AGGREGATE, Ftype.WINDOW)) ): assert isinstance(self, ColFn) ftype_string = { @@ -2646,9 +2573,7 @@ def _ast_repr(self, depth: int, needs_parens: bool, table_display_name_map) -> s ) + ( "" if self.default_val is None - else ".otherwise(" - + self.default_val._ast_repr(depth - 1, False, table_display_name_map) - + ")" + else ".otherwise(" + self.default_val._ast_repr(depth - 1, False, table_display_name_map) + ")" ) def iter_children(self) -> Iterable[ColExpr]: @@ -2665,13 +2590,9 @@ def dtype(self): return self._dtype for cond, _ in self.cases: - if ( - cond.dtype() is not None - and not types.without_const(cond.dtype()) == types.Bool() - ): + if cond.dtype() is not None and not types.without_const(cond.dtype()) == types.Bool(): raise DataTypeError( - f"argument `{cond.ast_repr()}` for `when` must be of boolean " - f"type, but has type `{cond.dtype()}`", + f"argument `{cond.ast_repr()}` for `when` must be of boolean type, but has type `{cond.dtype()}`", source=self._fn_id, ) @@ -2693,10 +2614,7 @@ def dtype(self): if all( ( - *( - types.is_const(cond.dtype()) and types.is_const(val.dtype()) - for cond, val in self.cases - ), + *(types.is_const(cond.dtype()) and types.is_const(val.dtype()) for cond, val in self.cases), self.default_val is None or types.is_const(self.default_val.dtype()), ) ): @@ -2711,9 +2629,7 @@ def ftype(self, *, agg_is_window: bool | None = None): val_ftypes = set() # TODO: does it actually matter if we add stuff that is const? it should be # elemwise anyway... 
- if self.default_val is not None and not types.is_const( - self.default_val.dtype() - ): + if self.default_val is not None and not types.is_const(self.default_val.dtype()): val_ftypes.add(self.default_val.ftype(agg_is_window=agg_is_window)) for cond, val in self.cases: @@ -2732,9 +2648,7 @@ def ftype(self, *, agg_is_window: bool | None = None): self._ftype = Ftype.WINDOW else: raise FunctionTypeError( - "incompatible function types found in case statement: , ".join( - val_ftypes - ), + "incompatible function types found in case statement: , ".join(val_ftypes), source=self._fn_id, ) @@ -2745,13 +2659,8 @@ def when(self, condition: ColExpr) -> WhenClause: raise TypeError("cannot call `when` on a closed case expression after") condition = wrap_literals(condition) - if condition.dtype() is not None and not isinstance( - condition.dtype(), types.Bool - ): - raise DataTypeError( - "argument for `when` must be of boolean type, but has type " - f"`{condition.dtype()}`" - ) + if condition.dtype() is not None and not isinstance(condition.dtype(), types.Bool): + raise DataTypeError(f"argument for `when` must be of boolean type, but has type `{condition.dtype()}`") return WhenClause(self.cases, wrap_literals(condition)) @@ -2796,12 +2705,8 @@ def is_valid_cast(source, target) -> bool: *( (t, u) for t, u in itertools.chain( - itertools.product( - (Int(), *INT_SUBTYPES), (*FLOAT_SUBTYPES, *INT_SUBTYPES) - ), - itertools.product( - (Float(), *FLOAT_SUBTYPES), (*FLOAT_SUBTYPES, *INT_SUBTYPES) - ), + itertools.product((Int(), *INT_SUBTYPES), (*FLOAT_SUBTYPES, *INT_SUBTYPES)), + itertools.product((Float(), *FLOAT_SUBTYPES), (*FLOAT_SUBTYPES, *INT_SUBTYPES)), ) ), *((Bool(), t) for t in itertools.chain(FLOAT_SUBTYPES, INT_SUBTYPES)), @@ -2828,9 +2733,7 @@ def dtype(self) -> Dtype | None: if not types.converts_to(self.val.dtype(), self.target_type): if not Cast.is_valid_cast(self.val.dtype(), self.target_type): hint = "" - if types.without_const( - self.val.dtype() - ) == String() and self.target_type in ( + if types.without_const(self.val.dtype()) == String() and self.target_type in ( Datetime(), Date(), ): @@ -2852,10 +2755,7 @@ def dtype(self) -> Dtype | None: def _ast_repr(self, depth: int, needs_parens: bool, table_display_name_map) -> str: if depth == 0: return f"(...).cast({self.target_type})" - return ( - f"{self.val._ast_repr(depth - 1, True, table_display_name_map)}" - + f".cast({self.target_type})" - ) + return f"{self.val._ast_repr(depth - 1, True, table_display_name_map)}" + f".cast({self.target_type})" def ftype(self, *, agg_is_window: bool | None = None) -> Ftype | None: if self._ftype is None: @@ -2904,9 +2804,7 @@ def iter_children(self): def map_children(self, g): self.val = g(self.val) - def _ast_repr( - self, depth: int, needs_parens: bool, table_display_name_map: dict[AstNode, str] - ) -> str: + def _ast_repr(self, depth: int, needs_parens: bool, table_display_name_map: dict[AstNode, str]) -> str: if depth == 0: return "eval_aligned(...)" return ( @@ -3002,13 +2900,7 @@ def wrap_literals(expr: Any, *, allow_markers=False) -> Any: or ( # markers can only be at the top of an expression tree not isinstance(expr.op, Marker) - and ( - marker_args := [ - arg - for arg in expr.args - if isinstance(arg, ColFn) and isinstance(arg.op, Marker) - ] - ) + and (marker_args := [arg for arg in expr.args if isinstance(arg, ColFn) and isinstance(arg.op, Marker)]) ) ): marker = expr.op if isinstance(expr.op, Marker) else marker_args[0].op @@ -3030,21 +2922,14 @@ def wrap_literals(expr: Any, *, 
allow_markers=False) -> Any: def clean_kwargs(**kwargs) -> dict[str, list[ColExpr]]: kwargs = { - key: [val] - if not isinstance(val, Iterable) or isinstance(val, str) - else list(val) + key: [val] if not isinstance(val, Iterable) or isinstance(val, str) else list(val) for key, val in kwargs.items() if val is not None } if (partition_by := kwargs.get("partition_by")) is not None: - kwargs["partition_by"] = [ - ColName(col) if isinstance(col, str) else col for col in partition_by - ] + kwargs["partition_by"] = [ColName(col) if isinstance(col, str) else col for col in partition_by] if (arrange := kwargs.get("arrange")) is not None: - kwargs["arrange"] = [ - Order.from_col_expr(ColName(ord) if isinstance(ord, str) else ord) - for ord in arrange - ] + kwargs["arrange"] = [Order.from_col_expr(ColName(ord) if isinstance(ord, str) else ord) for ord in arrange] return {key: [wrap_literals(val) for val in arr] for key, arr in kwargs.items()} @@ -3073,9 +2958,7 @@ def get_cols(nd: ColExpr): get_cols(expr) - aligned_nodes = [ - ea for ea in expr.iter_subtree_postorder() if isinstance(ea, EvalAligned) - ] + aligned_nodes = [ea for ea in expr.iter_subtree_postorder() if isinstance(ea, EvalAligned)] # We need one column whose AST is an ancestor of all other columns' ASTs. # The following could be done in linear time. @@ -3090,16 +2973,12 @@ def get_cols(nd: ColExpr): if ancestor_index is None: c, d = next( (c, d) - for (i, c), (j, d) in itertools.product( - enumerate(cols + aligned_nodes), enumerate(cols + aligned_nodes) - ) + for (i, c), (j, d) in itertools.product(enumerate(cols + aligned_nodes), enumerate(cols + aligned_nodes)) if roots[i] not in subtrees[j] and roots[j] not in subtrees[i] ) def text_repr(x: Col | EvalAligned): - return ( - "the column " if isinstance(x, Col) else "the `with_` argument of " - ) + f"`{x.ast_repr(depth=1)}`" + return ("the column " if isinstance(x, Col) else "the `with_` argument of ") + f"`{x.ast_repr(depth=1)}`" raise ValueError( "cannot export column expression since no common ancestor table " @@ -3117,8 +2996,7 @@ def text_repr(x: Col | EvalAligned): for col in expr.iter_subtree_postorder(): if isinstance(col, ColName) and col not in tbl: raise ValueError( - f"column expression cannot be exported since the C-column `{col}` is " - "not contained in the table" + f"column expression cannot be exported since the C-column `{col}` is not contained in the table" ) col_name = expr.ast_repr(depth=0) @@ -3139,9 +3017,7 @@ def get_ast_path_str(tree_root: ColExpr, error_source: UUID | ColExpr) -> str: preorder_traversal.append(node) # We have to do this marker magic for usage in verbs since the id of the node # may have changed since the exception was thrown. 
- if ( - is_uuid and hasattr(node, "_fn_id") and node._fn_id == error_source - ) or error_source is node: + if (is_uuid and hasattr(node, "_fn_id") and node._fn_id == error_source) or error_source is node: break path = [preorder_traversal[-1]] for node in reversed(preorder_traversal): diff --git a/src/pydiverse/transform/_internal/tree/types.py b/src/pydiverse/transform/_internal/tree/types.py index abaa2534..7eb4e7bf 100644 --- a/src/pydiverse/transform/_internal/tree/types.py +++ b/src/pydiverse/transform/_internal/tree/types.py @@ -97,9 +97,7 @@ def with_const(dtype: Dtype) -> Dtype: def converts_to(source: Dtype, target: Dtype) -> bool: if is_const(target): - return is_const(source) and converts_to( - without_const(source), without_const(target) - ) + return is_const(source) and converts_to(without_const(source), without_const(target)) source = without_const(source) if isinstance(source, List): return isinstance(target, List) and converts_to(source.inner, target.inner) @@ -107,11 +105,7 @@ def converts_to(source: Dtype, target: Dtype) -> bool: return ( target == source or target == String() - or ( - type(target) is String - and source.max_length is not None - and target.max_length > source.max_length - ) + or (type(target) is String and source.max_length is not None and target.max_length > source.max_length) ) if isinstance(source, Decimal): return ( @@ -198,12 +192,8 @@ def lca_type(dtypes: list[Dtype]) -> Dtype: # reduce to simple types if isinstance(dtypes[0], List): - if diff := next( - (dtype for dtype in dtypes if not isinstance(dtype, List)), None - ): - raise DataTypeError( - f"type `{diff.__name__}` is not compatible with `List` type" - ) + if diff := next((dtype for dtype in dtypes if not isinstance(dtype, List)), None): + raise DataTypeError(f"type `{diff.__name__}` is not compatible with `List` type") return List(lca_type([dtype.inner for dtype in dtypes])) @@ -286,22 +276,14 @@ def implicit_conversions(dtype: Dtype) -> list[Dtype]: if isinstance(dtype, Enum | String): return [String()] + ([dtype] if dtype.max_length is not None else []) if isinstance(dtype, Decimal): - return ( - list(FLOAT_SUBTYPES) + [Float()] + ([dtype] if dtype != Decimal() else []) - ) + return list(FLOAT_SUBTYPES) + [Float()] + ([dtype] if dtype != Decimal() else []) return list(IMPLICIT_CONVS[dtype].keys()) IMPLICIT_CONVS: dict[Dtype, dict[Dtype, tuple[int, int]]] = { Int(): {Float(): (1, 0), Decimal(): (2, 0), Int(): (0, 0)}, - **{ - int_subtype: {Int(): (0, 1), int_subtype: (0, 0)} - for int_subtype in INT_SUBTYPES - }, - **{ - float_subtype: {Float(): (0, 1), float_subtype: (0, 0)} - for float_subtype in FLOAT_SUBTYPES - }, + **{int_subtype: {Int(): (0, 1), int_subtype: (0, 0)} for int_subtype in INT_SUBTYPES}, + **{float_subtype: {Float(): (0, 1), float_subtype: (0, 0)} for float_subtype in FLOAT_SUBTYPES}, Float(): {Float(): (0, 0)}, String(): {String(): (0, 0)}, Decimal(): {Decimal(): (0, 0), Float(): (0, 1)}, @@ -322,9 +304,7 @@ def implicit_conversions(dtype: Dtype) -> list[Dtype]: for intermediate_type, cost1 in IMPLICIT_CONVS[start_type].items(): if intermediate_type in IMPLICIT_CONVS: for target_type, cost2 in IMPLICIT_CONVS[intermediate_type].items(): - added_edges[target_type] = tuple( - sum(z) for z in zip(cost1, cost2, strict=True) - ) + added_edges[target_type] = tuple(sum(z) for z in zip(cost1, cost2, strict=True)) if start_type not in IMPLICIT_CONVS: IMPLICIT_CONVS[start_type] = added_edges IMPLICIT_CONVS[start_type] |= added_edges @@ -338,13 +318,7 @@ def conversion_cost(dtype: 
Dtype, target: Dtype) -> tuple[int, int]: if isinstance(dtype, List): return conversion_cost(dtype.inner, target.inner) if isinstance(dtype, Enum | String | Decimal): - return ( - (0, 0) - if dtype == target - else (0, 1) - if type(dtype) is type(target) - else (0, 2) - ) + return (0, 0) if dtype == target else (0, 1) if type(dtype) is type(target) else (0, 2) return IMPLICIT_CONVS[dtype][target] diff --git a/src/pydiverse/transform/_internal/tree/verbs.py b/src/pydiverse/transform/_internal/tree/verbs.py index a08ad9da..bffb2061 100644 --- a/src/pydiverse/transform/_internal/tree/verbs.py +++ b/src/pydiverse/transform/_internal/tree/verbs.py @@ -23,11 +23,7 @@ def __post_init__(self): def _unformatted_ast_repr(self, verb_depth: int, expr_depth: int, display_name_map): nd_repr = self._ast_node_repr(expr_depth, display_name_map) return ( - self.child._unformatted_ast_repr( - verb_depth - 1, expr_depth, display_name_map - ) - if verb_depth != 0 - else "..." + self.child._unformatted_ast_repr(verb_depth - 1, expr_depth, display_name_map) if verb_depth != 0 else "..." ) + f" >> {nd_repr}" def _clone(self) -> tuple["Verb", dict[AstNode, AstNode], dict[UUID, UUID]]: @@ -87,18 +83,14 @@ class Alias(Verb): # would be nice to create a separate marker node for this to distinguish it from a # user-created alias. def _ast_node_repr(self, expr_depth: int, display_name_map) -> str: - return ( - "alias(" + (f"'{self.name}'" if self.name != self.child.name else "") + ")" - ) + return "alias(" + (f"'{self.name}'" if self.name != self.child.name else "") + ")" def _clone(self) -> tuple[Verb, dict[AstNode, AstNode], dict[UUID, UUID]]: cloned, nd_map, uuid_map = Verb._clone(self) if self.uuid_map is not None: # happens if and only if keep_col_refs=False assert set(self.uuid_map.keys()).issubset(uuid_map.keys()) uuid_map = { - self.uuid_map[old_uid]: new_uid - for old_uid, new_uid in uuid_map.items() - if old_uid in self.uuid_map + self.uuid_map[old_uid]: new_uid for old_uid, new_uid in uuid_map.items() if old_uid in self.uuid_map } cloned.uuid_map = None return cloned, nd_map, uuid_map @@ -115,14 +107,7 @@ def map_col_roots(self, g: Callable[[ColExpr], ColExpr]): self.select = [g(col) for col in self.select] def _ast_node_repr(self, expr_depth: int, display_name_map) -> str: - return ( - "select(" - + ", ".join( - col._ast_repr(expr_depth, False, display_name_map) - for col in self.select - ) - + ")" - ) + return "select(" + ", ".join(col._ast_repr(expr_depth, False, display_name_map) for col in self.select) + ")" @dataclasses.dataclass(eq=False, slots=True, repr=False) @@ -130,11 +115,7 @@ class Rename(Verb): name_map: dict[str, str] def _ast_node_repr(self, expr_depth: int, display_name_map) -> str: - return ( - "rename({" - + ", ".join(f"'{k}': '{v}'" for k, v in self.name_map.items()) - + "})" - ) + return "rename({" + ", ".join(f"'{k}': '{v}'" for k, v in self.name_map.items()) + "})" @dataclasses.dataclass(eq=False, slots=True, repr=False) @@ -162,12 +143,7 @@ def _clone(self) -> tuple[Verb, dict[AstNode, AstNode], dict[UUID, UUID]]: cloned, nd_map, uuid_map = Verb._clone(self) assert isinstance(cloned, Mutate) cloned.uuids = [uuid.uuid1() for _ in self.names] - uuid_map.update( - { - old_uid: new_uid - for old_uid, new_uid in zip(self.uuids, cloned.uuids, strict=True) - } - ) + uuid_map.update({old_uid: new_uid for old_uid, new_uid in zip(self.uuids, cloned.uuids, strict=True)}) return cloned, nd_map, uuid_map @@ -177,12 +153,7 @@ class Filter(Verb): def _ast_node_repr(self, expr_depth: int, 
display_name_map) -> str: return ( - "filter(" - + ", ".join( - pred._ast_repr(expr_depth, False, display_name_map) - for pred in self.predicates - ) - + ")" + "filter(" + ", ".join(pred._ast_repr(expr_depth, False, display_name_map) for pred in self.predicates) + ")" ) def iter_col_roots(self) -> Iterable[ColExpr]: @@ -217,12 +188,7 @@ def _clone(self) -> tuple[Verb, dict[AstNode, AstNode], dict[UUID, UUID]]: cloned, nd_map, uuid_map = Verb._clone(self) assert isinstance(cloned, Summarize) cloned.uuids = [uuid.uuid1() for _ in self.names] - uuid_map.update( - { - old_uid: new_uid - for old_uid, new_uid in zip(self.uuids, cloned.uuids, strict=True) - } - ) + uuid_map.update({old_uid: new_uid for old_uid, new_uid in zip(self.uuids, cloned.uuids, strict=True)}) return cloned, nd_map, uuid_map @@ -231,23 +197,13 @@ class Arrange(Verb): order_by: list[Order] def _ast_node_repr(self, expr_depth: int, display_name_map) -> str: - return ( - "arrange(" - + ", ".join( - ord._ast_repr(expr_depth, False, display_name_map) - for ord in self.order_by - ) - + ")" - ) + return "arrange(" + ", ".join(ord._ast_repr(expr_depth, False, display_name_map) for ord in self.order_by) + ")" def iter_col_roots(self) -> Iterable[ColExpr]: yield from (ord.order_by for ord in self.order_by) def map_col_roots(self, g: Callable[[ColExpr], ColExpr]): - self.order_by = [ - Order(g(ord.order_by), ord.descending, ord.nulls_last) - for ord in self.order_by - ] + self.order_by = [Order(g(ord.order_by), ord.descending, ord.nulls_last) for ord in self.order_by] @dataclasses.dataclass(eq=False, slots=True, repr=False) @@ -266,12 +222,7 @@ class GroupBy(Verb): def _ast_node_repr(self, expr_depth: int, display_name_map) -> str: return ( - "group_by(" - + ", ".join( - col._ast_repr(expr_depth, False, display_name_map) - for col in self.group_by - ) - + ")" + "group_by(" + ", ".join(col._ast_repr(expr_depth, False, display_name_map) for col in self.group_by) + ")" ) def iter_col_roots(self) -> Iterable[ColExpr]: @@ -294,21 +245,13 @@ class Join(Verb): how: Literal["inner", "left", "full"] validate: Literal["1:1", "1:m", "m:1", "m:m"] - def _unformatted_ast_repr( - self, verb_depth: int, expr_depth: int, display_name_map - ) -> str: + def _unformatted_ast_repr(self, verb_depth: int, expr_depth: int, display_name_map) -> str: return ( - self.child._unformatted_ast_repr( - verb_depth - 1, expr_depth, display_name_map - ) - if verb_depth != 0 - else "..." + self.child._unformatted_ast_repr(verb_depth - 1, expr_depth, display_name_map) if verb_depth != 0 else "..." ) + ( ">> join(" + ( - self.right._unformatted_ast_repr( - verb_depth - 1, expr_depth, display_name_map - ) + self.right._unformatted_ast_repr(verb_depth - 1, expr_depth, display_name_map) if verb_depth != 0 else self.right._unformatted_ast_repr(0, expr_depth, display_name_map) ) @@ -359,6 +302,54 @@ def map_col_roots(self, g: Callable[[ColExpr], ColExpr]): self.on = g(self.on) +@dataclasses.dataclass(eq=False, slots=True, repr=False) +class Union(Verb): + right: AstNode + distinct: bool + + def _unformatted_ast_repr(self, verb_depth: int, expr_depth: int, display_name_map) -> str: + return ( + self.child._unformatted_ast_repr(verb_depth - 1, expr_depth, display_name_map) if verb_depth != 0 else "..." 
+ ) + ( + ">> union(" + + ( + self.right._unformatted_ast_repr(verb_depth - 1, expr_depth, display_name_map) + if verb_depth != 0 + else self.right._unformatted_ast_repr(0, expr_depth, display_name_map) + ) + + f", distinct={self.distinct})" + ) + + def _clone(self) -> tuple["Union", dict[AstNode, AstNode], dict[UUID, UUID]]: + child, nd_map, uuid_map = self.child._clone() + right_child, right_nd_map, right_uuid_map = self.right._clone() + nd_map.update(right_nd_map) + uuid_map.update(right_uuid_map) + + cloned = copy.copy(self) + cloned.child = child + cloned.right = right_child + + nd_map[self] = cloned + return cloned, nd_map, uuid_map + + def iter_subtree_postorder(self) -> Iterable[AstNode]: + yield from self.child.iter_subtree_postorder() + yield from self.right.iter_subtree_postorder() + yield self + + def iter_subtree_preorder(self): + yield self + yield from self.child.iter_subtree_preorder() + yield from self.right.iter_subtree_preorder() + + def iter_col_roots(self) -> Iterable[ColExpr]: + return iter(()) + + def map_col_roots(self, g: Callable[[ColExpr], ColExpr]): + pass + + class SubqueryMarker(Verb): def _ast_node_repr(self, expr_depth, display_name_map): return "subquery_marker" diff --git a/src/pydiverse/transform/common.py b/src/pydiverse/transform/common.py index 2ee8b882..62d80b60 100644 --- a/src/pydiverse/transform/common.py +++ b/src/pydiverse/transform/common.py @@ -32,6 +32,7 @@ slice_head, summarize, ungroup, + union, ) from .base import * # noqa: F403 from .base import __all__ as __base @@ -51,6 +52,7 @@ "left_join", "full_join", "cross_join", + "union", "mutate", "rename", "select", diff --git a/tests/conftest.py b/tests/conftest.py index 7ad344b7..d585e658 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -34,10 +34,7 @@ def pytest_collection_modifyitems(config: pytest.Config, items): skip = pytest.mark.skip(reason=f"{opt} not selected") for item in items: if opt in item.keywords or any( - kw.startswith(f"{opt}-") - or kw.endswith(f"-{opt}") - or f"-{opt}-" in kw - for kw in item.keywords + kw.startswith(f"{opt}-") or kw.endswith(f"-{opt}") or f"-{opt}-" in kw for kw in item.keywords ): item.add_marker(skip) diff --git a/tests/test_backend_equivalence/conftest.py b/tests/test_backend_equivalence/conftest.py index 13b05e4f..ecff0d2e 100644 --- a/tests/test_backend_equivalence/conftest.py +++ b/tests/test_backend_equivalence/conftest.py @@ -272,9 +272,7 @@ def pytest_generate_tests(metafunc: pytest.Metafunc): backends = {k: i for i, k in enumerate(backends)} backend_combinations = [ - (reference_backend, backend) - for reference_backend in reference_backends - for backend in backends + (reference_backend, backend) for reference_backend in reference_backends for backend in backends ] params = [] diff --git a/tests/test_backend_equivalence/test_alias.py b/tests/test_backend_equivalence/test_alias.py index 931ddd1f..f1c0c4d6 100644 --- a/tests/test_backend_equivalence/test_alias.py +++ b/tests/test_backend_equivalence/test_alias.py @@ -28,9 +28,7 @@ def run(post_op): def test_alias_mutate(df3): def run(post_op): - assert_result_equal( - df3, lambda t: t >> mutate(x=t.col1) >> alias("a") >> post_op - ) + assert_result_equal(df3, lambda t: t >> mutate(x=t.col1) >> alias("a") >> post_op) for post_op in [ mutate(), @@ -46,10 +44,7 @@ def test_alias_window(df3): def run(post_op): assert_result_equal( df3, - lambda t: t - >> mutate(x=t.col1.count(partition_by=t.col2)) - >> alias("a") - >> post_op, + lambda t: t >> mutate(x=t.col1.count(partition_by=t.col2)) >> 
alias("a") >> post_op, ) for post_op in [ @@ -70,11 +65,7 @@ def test_alias_summarize(df3): def run(post_op): assert_result_equal( df3, - lambda t: t - >> group_by(t.col1, t.col2) - >> summarize(x=t.col1.count()) - >> alias("a") - >> post_op, + lambda t: t >> group_by(t.col1, t.col2) >> summarize(x=t.col1.count()) >> alias("a") >> post_op, ) for post_op in [ diff --git a/tests/test_backend_equivalence/test_arrange.py b/tests/test_backend_equivalence/test_arrange.py index baaf8a13..b90df26b 100644 --- a/tests/test_backend_equivalence/test_arrange.py +++ b/tests/test_backend_equivalence/test_arrange.py @@ -10,9 +10,7 @@ def test_noop(df1): - assert_result_equal( - df1, lambda t: t >> arrange(), may_throw=True, exception=TypeError - ) + assert_result_equal(df1, lambda t: t >> arrange(), may_throw=True, exception=TypeError) def test_arrange(df2): @@ -23,9 +21,7 @@ def test_arrange(df2): def test_arrange_expression(df3): - assert_result_equal( - df3, lambda t: t >> arrange(t.col2, t.col4), check_row_order=True - ) + assert_result_equal(df3, lambda t: t >> arrange(t.col2, t.col4), check_row_order=True) assert_result_equal(df3, lambda t: t >> arrange(-t.col4 * 2), check_row_order=True) @@ -37,9 +33,7 @@ def test_arrange_null(df2): def test_multiple(df3): assert_result_equal(df3, lambda t: t >> arrange(t.col2, -t.col3, -t.col4)) - assert_result_equal( - df3, lambda t: t >> arrange(t.col2) >> arrange(-t.col3, -t.col4) - ) + assert_result_equal(df3, lambda t: t >> arrange(t.col2) >> arrange(-t.col3, -t.col4)) assert_result_equal( df3, @@ -94,8 +88,6 @@ def test_nulls_first_last_mixed(df4): def test_arrange_after_mutate(df4): assert_result_equal( df4, - lambda t: t - >> mutate(x=t.col1 <= t.col2) - >> arrange(C.x.nulls_last(), C.col4.nulls_first()), + lambda t: t >> mutate(x=t.col1 <= t.col2) >> arrange(C.x.nulls_last(), C.col4.nulls_first()), check_row_order=True, ) diff --git a/tests/test_backend_equivalence/test_dtypes.py b/tests/test_backend_equivalence/test_dtypes.py index 06ee369a..045582d6 100644 --- a/tests/test_backend_equivalence/test_dtypes.py +++ b/tests/test_backend_equivalence/test_dtypes.py @@ -8,7 +8,5 @@ def test_dtypes(df1): assert_result_equal( df1, - lambda t: t - >> filter(t.col1 % 2 == 1) - >> inner_join(s := t >> mutate(u=t.col1 % 2) >> alias(), t.col1 == s.u), + lambda t: t >> filter(t.col1 % 2 == 1) >> inner_join(s := t >> mutate(u=t.col1 % 2) >> alias(), t.col1 == s.u), ) diff --git a/tests/test_backend_equivalence/test_group_by.py b/tests/test_backend_equivalence/test_group_by.py index f9237843..aa478c42 100644 --- a/tests/test_backend_equivalence/test_group_by.py +++ b/tests/test_backend_equivalence/test_group_by.py @@ -44,9 +44,7 @@ def test_mutate(df3, df4): partition_by=[t.col1], ), p=t.col1 * u.col4, - y=pdt.rank( - arrange=[(t.col1 * u.col4).nulls_last().nulls_first().nulls_last()] - ), + y=pdt.rank(arrange=[(t.col1 * u.col4).nulls_last().nulls_first().nulls_last()]), ), ) @@ -71,10 +69,7 @@ def test_ungrouped_join(df1, df3, how): # After ungrouping joining should work again assert_result_equal( (df1, df3), - lambda t, u: t - >> group_by(t.col1) - >> ungroup() - >> join(u, t.col1 == u.col1, how=how), + lambda t, u: t >> group_by(t.col1) >> ungroup() >> join(u, t.col1 == u.col1, how=how), check_row_order=False, ) @@ -84,9 +79,7 @@ def test_filter(df3): def test_arrange(df3): - assert_result_equal( - df3, lambda t: t >> group_by(t.col1) >> arrange(t.col1, -t.col3) - ) + assert_result_equal(df3, lambda t: t >> group_by(t.col1) >> arrange(t.col1, -t.col3)) 
assert_result_equal(df3, lambda t: t >> group_by(t.col1) >> arrange(-t.col4)) @@ -94,17 +87,12 @@ def test_arrange(df3): def test_group_by_bool_col(df4): assert_result_equal( df4, - lambda t: t - >> mutate(x=t.col1 <= t.col2) - >> group_by(C.x) - >> mutate(y=C.col4.mean()), + lambda t: t >> mutate(x=t.col1 <= t.col2) >> group_by(C.x) >> mutate(y=C.col4.mean()), ) def test_group_by_scalar(df3): - assert_result_equal( - df3, lambda t: t >> mutate(x=0) >> group_by(C.x) >> summarize(y=t.col1.sum()) - ) + assert_result_equal(df3, lambda t: t >> mutate(x=0) >> group_by(C.x) >> summarize(y=t.col1.sum())) assert_result_equal( df3, diff --git a/tests/test_backend_equivalence/test_join.py b/tests/test_backend_equivalence/test_join.py index ad50a3d9..78c95c41 100644 --- a/tests/test_backend_equivalence/test_join.py +++ b/tests/test_backend_equivalence/test_join.py @@ -50,9 +50,7 @@ def test_join(df1, df2, how): df1, lambda t: t >> left_join( - r := t - >> left_join(s := t >> alias("s"), on=t.col1 == s.col1) - >> alias("r"), + r := t >> left_join(s := t >> alias("s"), on=t.col1 == s.col1) >> alias("r"), on=t.col1 == r.col1, ), ) @@ -99,8 +97,7 @@ def test_join_and_select(df1, df2, how): assert_result_equal( (df1, df2), - lambda t, u: t - >> join(u >> select(), (t.col1 == u.col1) & (u.col2 == t.col1), how=how), + lambda t, u: t >> join(u >> select(), (t.col1 == u.col1) & (u.col2 == t.col1), how=how), check_row_order=False, ) @@ -189,27 +186,18 @@ def test_ineq_join(df3, df4, df_strings): ), ) - assert_result_equal( - (df3, df4), lambda s, t: s >> inner_join(t, on=["col1", s.col2 <= t.col2]) - ) + assert_result_equal((df3, df4), lambda s, t: s >> inner_join(t, on=["col1", s.col2 <= t.col2])) def test_join_summarize(df3, df4): assert_result_equal( (df3, df4), - lambda t3, t4: t3 - >> group_by(t3.col2) - >> summarize(j=t3.col4.sum()) - >> alias() - >> inner_join(t4, on="col2"), + lambda t3, t4: t3 >> group_by(t3.col2) >> summarize(j=t3.col4.sum()) >> alias() >> inner_join(t4, on="col2"), ) assert_result_equal( (df3, df4), - lambda t3, t4: t4 - >> left_join( - t3 >> group_by(t3.col2) >> summarize(j=t3.col4.sum()) >> alias(), on="col2" - ), + lambda t3, t4: t4 >> left_join(t3 >> group_by(t3.col2) >> summarize(j=t3.col4.sum()) >> alias(), on="col2"), ) assert_result_equal( @@ -224,10 +212,7 @@ def test_join_summarize(df3, df4): def test_join_window(df3, df4): assert_result_equal( (df3, df4), - lambda t3, t4: t3 - >> mutate(y=t3.col1.dense_rank()) - >> alias() - >> inner_join(t4, on=C.y == t4.col1), + lambda t3, t4: t3 >> mutate(y=t3.col1.dense_rank()) >> alias() >> inner_join(t4, on=C.y == t4.col1), ) assert_result_equal( @@ -253,9 +238,7 @@ def test_join_where(df2, df3, df4): assert_result_equal( (df3, df4), - lambda t3, t4: t3 - >> filter(t3.col4 != -1729) - >> left_join(t4 >> filter(t4.col3 > 0), on=t3.col2 == t4.col2), + lambda t3, t4: t3 >> filter(t3.col4 != -1729) >> left_join(t4 >> filter(t4.col3 > 0), on=t3.col2 == t4.col2), ) assert_result_equal( @@ -278,10 +261,7 @@ def test_join_const_col(df3, df4): assert_result_equal( (df3, df4), - lambda s, t: s - >> mutate(z=2) - >> alias() - >> full_join(t >> mutate(j=True) >> alias(), on="col2"), + lambda s, t: s >> mutate(z=2) >> alias() >> full_join(t >> mutate(j=True) >> alias(), on="col2"), ) assert_result_equal( diff --git a/tests/test_backend_equivalence/test_mutate.py b/tests/test_backend_equivalence/test_mutate.py index 5dc7c3c9..a1fc5131 100644 --- a/tests/test_backend_equivalence/test_mutate.py +++ b/tests/test_backend_equivalence/test_mutate.py @@ 
-12,9 +12,7 @@ def test_noop(df2): - assert_result_equal( - df2, lambda t: t >> mutate(col1=t.col1, col2=t.col2, col3=t.col3) - ) + assert_result_equal(df2, lambda t: t >> mutate(col1=t.col1, col2=t.col2, col3=t.col3)) def test_multiply(df1): @@ -27,9 +25,7 @@ def test_reorder(df2): assert_result_equal( df2, - lambda t: t - >> mutate(col1=t.col2, col2=t.col1) - >> mutate(col1=t.col2, col2=C.col3, col3=C.col2), + lambda t: t >> mutate(col1=t.col2, col2=t.col1) >> mutate(col1=t.col2, col2=C.col3, col3=C.col2), ) @@ -65,9 +61,7 @@ def test_none(df4): def test_mutate_bool_expr(df4): assert_result_equal( df4, - lambda t: t - >> mutate(x=t.col1 <= t.col2, y=(t.col3 * 4) >= C.col4) - >> mutate(xAndY=C.x & C.y), + lambda t: t >> mutate(x=t.col1 <= t.col2, y=(t.col3 * 4) >= C.col4) >> mutate(xAndY=C.x & C.y), ) diff --git a/tests/test_backend_equivalence/test_ops/test_case_expression.py b/tests/test_backend_equivalence/test_ops/test_case_expression.py index fa2f4521..83292a4a 100644 --- a/tests/test_backend_equivalence/test_ops/test_case_expression.py +++ b/tests/test_backend_equivalence/test_ops/test_case_expression.py @@ -15,21 +15,14 @@ def test_mutate_case_ewise(df4): assert_result_equal( df4, - lambda t: t - >> mutate( - x=C.col1.map({0: 1, (1, 2): 2}), y=C.col1.map({0: 0, 1: None}, default=10.4) - ), + lambda t: t >> mutate(x=C.col1.map({0: 1, (1, 2): 2}), y=C.col1.map({0: 0, 1: None}, default=10.4)), ) assert_result_equal( df4, lambda t: t >> mutate( - x=pdt.when(C.col1 == C.col2) - .then(1) - .when(C.col2 == C.col3) - .then(2) - .otherwise(C.col1 + C.col2), + x=pdt.when(C.col1 == C.col2).then(1).when(C.col2 == C.col3).then(2).otherwise(C.col1 + C.col2), ), ) @@ -65,17 +58,11 @@ def test_mutate_case_window(df4): df4, lambda t: t >> mutate( - u=C.col1.shift( - 1, 1729, arrange=[t.col3.descending().nulls_last(), t.col4.nulls_last()] - ), + u=C.col1.shift(1, 1729, arrange=[t.col3.descending().nulls_last(), t.col4.nulls_last()]), x=C.col1.shift(1, 0, arrange=[C.col4.nulls_first()]).map( { - 1: C.col2.shift( - 1, -1, arrange=[C.col2.nulls_last(), C.col4.nulls_first()] - ), - 2: C.col3.shift( - 2, -2, arrange=[C.col3.nulls_last(), C.col4.nulls_last()] - ), + 1: C.col2.shift(1, -1, arrange=[C.col2.nulls_last(), C.col4.nulls_first()]), + 2: C.col3.shift(2, -2, arrange=[C.col3.nulls_last(), C.col4.nulls_last()]), } ), ), @@ -114,11 +101,7 @@ def test_summarize_case(df4): 2: 2, } ), - y=pdt.when(C.col2.max() > 2) - .then(1) - .when(C.col2.max() < 2) - .then(C.col2.min()) - .otherwise(C.col3.mean()), + y=pdt.when(C.col2.max() > 2).then(1).when(C.col2.max() < 2).then(C.col2.min()).otherwise(C.col3.mean()), ), ) @@ -156,9 +139,6 @@ def test_invalid_ftype(df1): assert_result_equal( df1, - lambda t: t - >> summarize( - x=pdt.when(pdt.rank(arrange=[C.col1]) == 1).then(1).otherwise(None) - ), + lambda t: t >> summarize(x=pdt.when(pdt.rank(arrange=[C.col1]) == 1).then(1).otherwise(None)), exception=FunctionTypeError, ) diff --git a/tests/test_backend_equivalence/test_ops/test_cast.py b/tests/test_backend_equivalence/test_ops/test_cast.py index c6561560..7e19b84f 100644 --- a/tests/test_backend_equivalence/test_ops/test_cast.py +++ b/tests/test_backend_equivalence/test_ops/test_cast.py @@ -55,9 +55,7 @@ def test_datetime_to_date(df_datetime): def test_int_to_string(df_int): - assert_result_equal( - df_int, lambda t: t >> mutate(**{c.name: c.cast(pdt.String()) for c in t}) - ) + assert_result_equal(df_int, lambda t: t >> mutate(**{c.name: c.cast(pdt.String()) for c in t})) def test_float_to_string(df_num): diff 
--git a/tests/test_backend_equivalence/test_ops/test_functions.py b/tests/test_backend_equivalence/test_ops/test_functions.py index 11c3d9f3..0beb069c 100644 --- a/tests/test_backend_equivalence/test_ops/test_functions.py +++ b/tests/test_backend_equivalence/test_ops/test_functions.py @@ -24,11 +24,7 @@ def test_row_number(df4): assert_result_equal( df4, lambda t: t - >> mutate( - row_number=pdt.row_number( - arrange=[C.col1.descending().nulls_first(), C.col5.nulls_last()] - ) - ), + >> mutate(row_number=pdt.row_number(arrange=[C.col1.descending().nulls_first(), C.col5.nulls_last()])), ) diff --git a/tests/test_backend_equivalence/test_ops/test_ops_datetime.py b/tests/test_backend_equivalence/test_ops/test_ops_datetime.py index 713e50de..6788078c 100644 --- a/tests/test_backend_equivalence/test_ops/test_ops_datetime.py +++ b/tests/test_backend_equivalence/test_ops/test_ops_datetime.py @@ -10,42 +10,26 @@ def test_eq(df_datetime): - assert_result_equal( - df_datetime, lambda t: t >> filter(C.col1 == datetime(1970, 1, 1)) - ) - assert_result_equal( - df_datetime, lambda t: t >> filter(C.col1 == datetime(2004, 12, 31)) - ) + assert_result_equal(df_datetime, lambda t: t >> filter(C.col1 == datetime(1970, 1, 1))) + assert_result_equal(df_datetime, lambda t: t >> filter(C.col1 == datetime(2004, 12, 31))) assert_result_equal(df_datetime, lambda t: t >> filter(C.col1 == C.col2)) def test_nq(df_datetime): - assert_result_equal( - df_datetime, lambda t: t >> filter(C.col1 != datetime(1970, 1, 1)) - ) - assert_result_equal( - df_datetime, lambda t: t >> filter(C.col1 != datetime(2004, 12, 31)) - ) + assert_result_equal(df_datetime, lambda t: t >> filter(C.col1 != datetime(1970, 1, 1))) + assert_result_equal(df_datetime, lambda t: t >> filter(C.col1 != datetime(2004, 12, 31))) assert_result_equal(df_datetime, lambda t: t >> filter(C.col1 != C.col2)) def test_lt(df_datetime): - assert_result_equal( - df_datetime, lambda t: t >> filter(C.col1 < datetime(1970, 1, 1)) - ) - assert_result_equal( - df_datetime, lambda t: t >> filter(C.col1 < datetime(2004, 12, 31)) - ) + assert_result_equal(df_datetime, lambda t: t >> filter(C.col1 < datetime(1970, 1, 1))) + assert_result_equal(df_datetime, lambda t: t >> filter(C.col1 < datetime(2004, 12, 31))) assert_result_equal(df_datetime, lambda t: t >> filter(C.col1 < C.col2)) def test_gt(df_datetime): - assert_result_equal( - df_datetime, lambda t: t >> filter(C.col1 > datetime(1970, 1, 1)) - ) - assert_result_equal( - df_datetime, lambda t: t >> filter(C.col1 > datetime(2004, 12, 31)) - ) + assert_result_equal(df_datetime, lambda t: t >> filter(C.col1 > datetime(1970, 1, 1))) + assert_result_equal(df_datetime, lambda t: t >> filter(C.col1 > datetime(2004, 12, 31))) assert_result_equal(df_datetime, lambda t: t >> filter(C.col1 > C.col2)) diff --git a/tests/test_backend_equivalence/test_ops/test_ops_int.py b/tests/test_backend_equivalence/test_ops/test_ops_int.py index 63721feb..8bb9aca8 100644 --- a/tests/test_backend_equivalence/test_ops/test_ops_int.py +++ b/tests/test_backend_equivalence/test_ops/test_ops_int.py @@ -11,10 +11,7 @@ def test_add(df_int): assert_result_equal( df_int, - lambda t: t - >> ( - lambda s: mutate(**{f"add_{c.name}_{d.name}": c + d for d in s for c in s}) - ), + lambda t: t >> (lambda s: mutate(**{f"add_{c.name}_{d.name}": c + d for d in s for c in s})), ) @@ -22,12 +19,7 @@ def test_sub(df_int): assert_result_equal( df_int, lambda t: t - >> ( - lambda s: ( - mutate(**{f"sub_{c.name}_{d.name}": c - d for d in s for c in s}) - >> (lambda u: 
mutate()) - ) - ), + >> (lambda s: (mutate(**{f"sub_{c.name}_{d.name}": c - d for d in s for c in s}) >> (lambda u: mutate()))), ) @@ -41,10 +33,7 @@ def test_neg(df_int): def test_mul(df_int): assert_result_equal( df_int, - lambda t: t - >> ( - lambda s: mutate(**{f"mul_{c.name}_{d.name}": c * d for d in s for c in s}) - ), + lambda t: t >> (lambda s: mutate(**{f"mul_{c.name}_{d.name}": c * d for d in s for c in s})), ) @@ -53,9 +42,7 @@ def test_truediv(df_int): df_int, lambda t: t >> mutate(**{c.name: c.map({0: 1}, default=c) for c in t}) - >> ( - lambda s: mutate(**{f"div_{c.name}_{d.name}": c / d for d in s for c in s}) - ), + >> (lambda s: mutate(**{f"div_{c.name}_{d.name}": c / d for d in s for c in s})), ) @@ -64,9 +51,7 @@ def test_floordiv(df_int): df_int, lambda t: t >> mutate(**{c.name: c.map({0: 1}, default=c) for c in t}) - >> ( - lambda s: mutate(**{f"div_{c.name}_{d.name}": c // d for d in s for c in s}) - ), + >> (lambda s: mutate(**{f"div_{c.name}_{d.name}": c // d for d in s for c in s})), ) @@ -75,9 +60,7 @@ def test_mod(df_int): df_int, lambda t: t >> mutate(**{c.name: c.map({0: 1}, default=c) for c in t}) - >> ( - lambda s: mutate(**{f"mod_{c.name}_{d.name}": c % d for d in s for c in s}) - ), + >> (lambda s: mutate(**{f"mod_{c.name}_{d.name}": c % d for d in s for c in s})), ) assert_result_equal( @@ -91,11 +74,7 @@ def test_mod(df_int): >> summarize( **{ f"div_plus_mod_{c.name}_{d.name}": ( - ( - C[f"div_{c.name}_{d.name}"] * C[d.name] - + C[f"mod_{c.name}_{d.name}"] - ) - == C[c.name] + (C[f"div_{c.name}_{d.name}"] * C[d.name] + C[f"mod_{c.name}_{d.name}"]) == C[c.name] ).all() for d in t for c in t diff --git a/tests/test_backend_equivalence/test_ops/test_ops_list.py b/tests/test_backend_equivalence/test_ops/test_ops_list.py index b871dead..0b495cfd 100644 --- a/tests/test_backend_equivalence/test_ops/test_ops_list.py +++ b/tests/test_backend_equivalence/test_ops/test_ops_list.py @@ -10,10 +10,7 @@ def test_list_agg(df3): assert_result_equal( df3, - lambda t: t - >> group_by(t.col3) - >> summarize(s=t.col2.list.agg()) - >> arrange(C.s), + lambda t: t >> group_by(t.col3) >> summarize(s=t.col2.list.agg()) >> arrange(C.s), check_row_order=True, ) @@ -34,7 +31,6 @@ def test_list_agg(df3): def test_list_agg_no_grouping(df3): assert_result_equal( df3, - lambda t: t - >> summarize(h=t.col5.list.agg(arrange=[t.col1, t.col4.descending()])), + lambda t: t >> summarize(h=t.col5.list.agg(arrange=[t.col1, t.col4.descending()])), check_row_order=True, ) diff --git a/tests/test_backend_equivalence/test_ops/test_ops_numerical.py b/tests/test_backend_equivalence/test_ops/test_ops_numerical.py index c55bb37e..31d6ff8a 100644 --- a/tests/test_backend_equivalence/test_ops/test_ops_numerical.py +++ b/tests/test_backend_equivalence/test_ops/test_ops_numerical.py @@ -17,9 +17,7 @@ def test_exp(df_num): def test_log(df_num): - assert_result_equal( - df_num, lambda t: t >> mutate(**{c.name: pdt.max(1e-16, c).log() for c in t}) - ) + assert_result_equal(df_num, lambda t: t >> mutate(**{c.name: pdt.max(1e-16, c).log() for c in t})) def test_abs(df_num): @@ -47,16 +45,14 @@ def test_round(df_num): def test_add(df_num): assert_result_equal( df_num, - lambda t: t - >> mutate(**{f"add_{c.name}_{d.name}": c + d for d in t for c in t}), + lambda t: t >> mutate(**{f"add_{c.name}_{d.name}": c + d for d in t for c in t}), ) def test_sub(df_num): assert_result_equal( df_num, - lambda t: t - >> mutate(**{f"sub_{c.name}_{d.name}": c - d for d in t for c in t}), + lambda t: t >> 
mutate(**{f"sub_{c.name}_{d.name}": c - d for d in t for c in t}), ) @@ -70,8 +66,7 @@ def test_neg(df_num): def test_mul(df_num): assert_result_equal( df_num, - lambda t: t - >> mutate(**{f"mul_{c.name}_{d.name}": c * d for d in t for c in t}), + lambda t: t >> mutate(**{f"mul_{c.name}_{d.name}": c * d for d in t for c in t}), ) @@ -79,21 +74,15 @@ def test_div(df_num): assert_result_equal( df_num, lambda t: t - >> mutate( - **{c.name: pdt.when(c.abs() < 1e-50).then(1e-50).otherwise(c) for c in t} - ) - >> ( - lambda s: mutate(**{f"div_{c.name}_{d.name}": c / d for d in s for c in s}) - ), + >> mutate(**{c.name: pdt.when(c.abs() < 1e-50).then(1e-50).otherwise(c) for c in t}) + >> (lambda s: mutate(**{f"div_{c.name}_{d.name}": c / d for d in s for c in s})), ) def test_decimal(df_num): assert_result_equal( df_num, - lambda t: t - >> mutate(f=t.f.cast(pdt.Decimal), g=t.g.cast(pdt.Decimal)) - >> mutate(u=C.f + C.g, z=C.f * C.g), + lambda t: t >> mutate(f=t.f.cast(pdt.Decimal), g=t.g.cast(pdt.Decimal)) >> mutate(u=C.f + C.g, z=C.f * C.g), ) @@ -186,9 +175,7 @@ def test_is_nan(df_num): def test_int_pow(df_int): - assert_result_equal( - df_int, lambda t: t >> mutate(u=pdt.min(t.a, 10) ** pdt.min(t.b.abs(), 5)) - ) + assert_result_equal(df_int, lambda t: t >> mutate(u=pdt.min(t.a, 10) ** pdt.min(t.b.abs(), 5))) def test_sin(df_num): @@ -200,21 +187,15 @@ def test_cos(df_num): def test_tan(df_num): - assert_result_equal( - df_num, lambda t: t >> mutate(**{c.name: c.atan().tan() for c in t}) - ) + assert_result_equal(df_num, lambda t: t >> mutate(**{c.name: c.atan().tan() for c in t})) def test_asin(df_num): - assert_result_equal( - df_num, lambda t: t >> mutate(**{c.name: c.sin().asin() for c in t}) - ) + assert_result_equal(df_num, lambda t: t >> mutate(**{c.name: c.sin().asin() for c in t})) def test_acos(df_num): - assert_result_equal( - df_num, lambda t: t >> mutate(**{c.name: c.cos().acos() for c in t}) - ) + assert_result_equal(df_num, lambda t: t >> mutate(**{c.name: c.cos().acos() for c in t})) def test_atan(df_num): @@ -222,9 +203,7 @@ def test_atan(df_num): def test_sqrt(df_num): - assert_result_equal( - df_num, lambda t: t >> mutate(**{c.name: c.abs().sqrt() for c in t}) - ) + assert_result_equal(df_num, lambda t: t >> mutate(**{c.name: c.abs().sqrt() for c in t})) def test_cbrt(df_num): diff --git a/tests/test_backend_equivalence/test_ops/test_ops_string.py b/tests/test_backend_equivalence/test_ops/test_ops_string.py index 04896ab4..0d774fd7 100644 --- a/tests/test_backend_equivalence/test_ops/test_ops_string.py +++ b/tests/test_backend_equivalence/test_ops/test_ops_string.py @@ -16,8 +16,7 @@ def test_eq(df_strings): assert_result_equal(df_strings, lambda t: t >> filter(C.col1 == "foo")) assert_result_equal( df_strings, - lambda t: t - >> filter(C.col1.str.replace_all(" ", "") == C.col2.str.replace_all(" ", "")), + lambda t: t >> filter(C.col1.str.replace_all(" ", "") == C.col2.str.replace_all(" ", "")), ) @@ -31,8 +30,7 @@ def test_nq(df_strings): assert_result_equal(df_strings, lambda t: t >> filter(C.col1 != "foo")) assert_result_equal( df_strings, - lambda t: t - >> filter(C.col1.str.replace_all(" ", "") != C.col2.str.replace_all(" ", "")), + lambda t: t >> filter(C.col1.str.replace_all(" ", "") != C.col2.str.replace_all(" ", "")), ) @@ -52,8 +50,7 @@ def test_gt(df_strings): assert_result_equal(df_strings, lambda t: t >> filter(C.col1 > "E")) assert_result_equal( df_strings, - lambda t: t - >> filter(C.col1.str.replace_all(" ", "") > C.col2.str.replace_all(" ", "")), + lambda t: 
t >> filter(C.col1.str.replace_all(" ", "") > C.col2.str.replace_all(" ", "")), ) @@ -77,22 +74,18 @@ def test_le(df_strings): lambda t: t >> mutate( col1_le_c=C.col1 <= C.c, - col1_le_col2=t.col1.str.replace_all(" ", "") - <= t.col2.str.replace_all(" ", ""), + col1_le_col2=t.col1.str.replace_all(" ", "") <= t.col2.str.replace_all(" ", ""), d_le_c=t.d <= t.c, ), ) assert_result_equal( df_strings, - lambda t: t - >> filter(C.col1.str.replace_all(" ", "") <= C.col2.str.replace_all(" ", "")), + lambda t: t >> filter(C.col1.str.replace_all(" ", "") <= C.col2.str.replace_all(" ", "")), ) def test_ge(df_strings): - assert_result_equal( - df_strings, lambda t: t >> mutate(col1_ge_col2=C.col1 >= C.col2) - ) + assert_result_equal(df_strings, lambda t: t >> mutate(col1_ge_col2=C.col1 >= C.col2)) assert_result_equal(df_strings, lambda t: t >> filter(C.col1 >= C.col2)) @@ -228,9 +221,7 @@ def test_slice(df_strings): v=t.col2.str.replace_all(" ", "") .str.slice(t.col1.str.len() % (t.col2.str.len() + 1), 42) .str.replace_all(" ", ""), - w=t.col1.str.replace_all(" ", "") - .str.slice(2, t.col1.str.len()) - .str.replace_all(" ", ""), + w=t.col1.str.replace_all(" ", "").str.slice(2, t.col1.str.len()).str.replace_all(" ", ""), ), ) @@ -241,9 +232,7 @@ def test_str_join(df_strings): df_strings, lambda t: t >> group_by(t.e) - >> summarize( - con=t.c.str.join(", ", arrange=[t.d.nulls_first(), t.c.nulls_last()]) - ), + >> summarize(con=t.c.str.join(", ", arrange=[t.d.nulls_first(), t.c.nulls_last()])), ) assert_result_equal( @@ -269,8 +258,7 @@ def test_str_arrange(df_strings): def bind(col): assert_result_equal( df_strings, - lambda t: t - >> arrange(t[col].str.replace_all(" ", "").nulls_last(), t.c.nulls_last()), + lambda t: t >> arrange(t[col].str.replace_all(" ", "").nulls_last(), t.c.nulls_last()), check_row_order=True, ) diff --git a/tests/test_backend_equivalence/test_rename.py b/tests/test_backend_equivalence/test_rename.py index 5c5d9e1a..2fb179e3 100644 --- a/tests/test_backend_equivalence/test_rename.py +++ b/tests/test_backend_equivalence/test_rename.py @@ -21,14 +21,10 @@ def test_simple(df3): def test_chained(df3): assert_result_equal(df3, lambda t: t >> rename({"col1": "X"}) >> rename({"X": "Y"})) - assert_result_equal( - df3, lambda t: t >> rename({"col1": "X"}) >> rename({"X": "col1"}) - ) + assert_result_equal(df3, lambda t: t >> rename({"col1": "X"}) >> rename({"X": "col1"})) assert_result_equal( df3, - lambda t: t - >> rename({"col1": "1", "col2": "2"}) - >> rename({"1": "col1", "2": "col2"}), + lambda t: t >> rename({"col1": "1", "col2": "2"}) >> rename({"1": "col1", "2": "col2"}), ) diff --git a/tests/test_backend_equivalence/test_slice_head.py b/tests/test_backend_equivalence/test_slice_head.py index 00710756..c2513422 100644 --- a/tests/test_backend_equivalence/test_slice_head.py +++ b/tests/test_backend_equivalence/test_slice_head.py @@ -12,25 +12,13 @@ def test_simple(df3): assert_result_equal(df3, lambda t: t >> arrange(*list(t)[0:-1]) >> slice_head(10)) assert_result_equal(df3, lambda t: t >> arrange(*list(t)[0:-1]) >> slice_head(100)) - assert_result_equal( - df3, lambda t: t >> arrange(*list(t)[0:-1]) >> slice_head(1, offset=8) - ) - assert_result_equal( - df3, lambda t: t >> arrange(*list(t)[0:-1]) >> slice_head(10, offset=8) - ) - assert_result_equal( - df3, lambda t: t >> arrange(*list(t)[0:-1]) >> slice_head(100, offset=8) - ) + assert_result_equal(df3, lambda t: t >> arrange(*list(t)[0:-1]) >> slice_head(1, offset=8)) + assert_result_equal(df3, lambda t: t >> 
arrange(*list(t)[0:-1]) >> slice_head(10, offset=8)) + assert_result_equal(df3, lambda t: t >> arrange(*list(t)[0:-1]) >> slice_head(100, offset=8)) - assert_result_equal( - df3, lambda t: t >> arrange(*list(t)[0:-1]) >> slice_head(1, offset=100) - ) - assert_result_equal( - df3, lambda t: t >> arrange(*list(t)[0:-1]) >> slice_head(10, offset=100) - ) - assert_result_equal( - df3, lambda t: t >> arrange(*list(t)[0:-1]) >> slice_head(100, offset=100) - ) + assert_result_equal(df3, lambda t: t >> arrange(*list(t)[0:-1]) >> slice_head(1, offset=100)) + assert_result_equal(df3, lambda t: t >> arrange(*list(t)[0:-1]) >> slice_head(10, offset=100)) + assert_result_equal(df3, lambda t: t >> arrange(*list(t)[0:-1]) >> slice_head(100, offset=100)) def test_chained(df3): @@ -76,18 +64,13 @@ def test_with_mutate(df3): def test_with_join(df1, df2): assert_result_equal( (df1, df2), - lambda t, u: t - >> arrange(*t) - >> slice_head(3) - >> alias(keep_col_refs=True) - >> left_join(u, t.col1 == u.col1), + lambda t, u: t >> arrange(*t) >> slice_head(3) >> alias(keep_col_refs=True) >> left_join(u, t.col1 == u.col1), check_row_order=False, ) assert_result_equal( (df1, df2), - lambda t, u: t - >> left_join(u >> arrange(*t) >> slice_head(2, offset=1), t.col1 == u.col1), + lambda t, u: t >> left_join(u >> arrange(*t) >> slice_head(2, offset=1), t.col1 == u.col1), check_row_order=False, exception=ColumnNotFoundError, may_throw=True, @@ -97,19 +80,12 @@ def test_with_join(df1, df2): def test_with_filter(df3): assert_result_equal( df3, - lambda t: t - >> filter(t.col4 % 2 == 0) - >> arrange(*list(t)[0:-1]) - >> slice_head(4, offset=2), + lambda t: t >> filter(t.col4 % 2 == 0) >> arrange(*list(t)[0:-1]) >> slice_head(4, offset=2), ) assert_result_equal( df3, - lambda t: t - >> arrange(*list(t)[0:-1]) - >> slice_head(4, offset=2) - >> alias() - >> filter(C.col1 == 1), + lambda t: t >> arrange(*list(t)[0:-1]) >> slice_head(4, offset=2) >> alias() >> filter(C.col1 == 1), ) assert_result_equal( @@ -126,10 +102,7 @@ def test_with_filter(df3): def test_with_arrange(df3): assert_result_equal( df3, - lambda t: t - >> mutate(x=t.col4 - (t.col1 * t.col2)) - >> arrange(C.x, *list(t)[0:-1]) - >> slice_head(4, offset=2), + lambda t: t >> mutate(x=t.col4 - (t.col1 * t.col2)) >> arrange(C.x, *list(t)[0:-1]) >> slice_head(4, offset=2), ) assert_result_equal( @@ -146,12 +119,7 @@ def test_with_arrange(df3): def test_with_group_by(df3): assert_result_equal( df3, - lambda t: t - >> arrange(*list(t)[0:-1]) - >> slice_head(1) - >> alias() - >> group_by(C.col1) - >> mutate(x=pdt.count()), + lambda t: t >> arrange(*list(t)[0:-1]) >> slice_head(1) >> alias() >> group_by(C.col1) >> mutate(x=pdt.count()), ) assert_result_equal( @@ -181,18 +149,10 @@ def test_with_group_by(df3): def test_with_summarize(df3): assert_result_equal( df3, - lambda t: t - >> arrange(*list(t)[0:-1]) - >> slice_head(4) - >> alias() - >> summarize(count=pdt.count()), + lambda t: t >> arrange(*list(t)[0:-1]) >> slice_head(4) >> alias() >> summarize(count=pdt.count()), ) assert_result_equal( df3, - lambda t: t - >> arrange(*list(t)[0:-1]) - >> slice_head(4) - >> alias() - >> summarize(c3_mean=C.col3.mean()), + lambda t: t >> arrange(*list(t)[0:-1]) >> slice_head(4) >> alias() >> summarize(c3_mean=C.col3.mean()), ) diff --git a/tests/test_backend_equivalence/test_summarize.py b/tests/test_backend_equivalence/test_summarize.py index 5e98f7c6..d7bdb36e 100644 --- a/tests/test_backend_equivalence/test_summarize.py +++ b/tests/test_backend_equivalence/test_summarize.py @@ 
-58,17 +58,13 @@ def test_chained_summarized(df3): def test_summarize_name_drop(df3): - assert_result_equal( - df3, lambda t: t >> summarize(x=t.col1.count()) >> mutate(col1=1, col2=2) - ) + assert_result_equal(df3, lambda t: t >> summarize(x=t.col1.count()) >> mutate(col1=1, col2=2)) def test_nested(df3): assert_result_equal( df3, - lambda t: t - >> group_by(t.col1, t.col2) - >> summarize(mean_of_mean3=t.col3.mean().mean()), + lambda t: t >> group_by(t.col1, t.col2) >> summarize(mean_of_mean3=t.col3.mean().mean()), exception=FunctionTypeError, ) @@ -76,60 +72,42 @@ def test_nested(df3): def test_select(df3): assert_result_equal( df3, - lambda t: t - >> group_by(t.col1, t.col2) - >> summarize(mean3=t.col3.mean()) - >> select(t.col1, C.mean3, t.col2), + lambda t: t >> group_by(t.col1, t.col2) >> summarize(mean3=t.col3.mean()) >> select(t.col1, C.mean3, t.col2), ) def test_mutate(df3): assert_result_equal( df3, - lambda t: t - >> group_by(t.col1, t.col2) - >> summarize(mean3=t.col3.mean()) - >> mutate(x10=C.mean3 * 10), + lambda t: t >> group_by(t.col1, t.col2) >> summarize(mean3=t.col3.mean()) >> mutate(x10=C.mean3 * 10), ) def test_filter(df3): assert_result_equal( df3, - lambda t: t - >> group_by(t.col1, t.col2) - >> summarize(mean3=t.col3.mean()) - >> filter(C.mean3 <= 2.0), + lambda t: t >> group_by(t.col1, t.col2) >> summarize(mean3=t.col3.mean()) >> filter(C.mean3 <= 2.0), ) def test_filter_argument(df3): assert_result_equal( df3, - lambda t: t - >> group_by(t.col2) - >> summarize(u=t.col4.sum(filter=(t.col1 != 0))), + lambda t: t >> group_by(t.col2) >> summarize(u=t.col4.sum(filter=(t.col1 != 0))), ) assert_result_equal( df3, lambda t: t >> group_by(t.col4, t.col1) - >> summarize( - u=(t.col3 * t.col4 - t.col2).sum( - filter=(t.col5.is_in("a", "e", "i", "o", "u")) - ) - ), + >> summarize(u=(t.col3 * t.col4 - t.col2).sum(filter=(t.col5.is_in("a", "e", "i", "o", "u")))), ) def test_arrange(df3): assert_result_equal( df3, - lambda t: t - >> group_by(t.col1, t.col2) - >> summarize(mean3=t.col3.mean()) - >> arrange(C.mean3), + lambda t: t >> group_by(t.col1, t.col2) >> summarize(mean3=t.col3.mean()) >> arrange(C.mean3), ) assert_result_equal( @@ -143,9 +121,7 @@ def test_arrange(df3): def test_not_summarising(df4): - assert_result_equal( - df4, lambda t: t >> summarize(x=C.col1), exception=FunctionTypeError - ) + assert_result_equal(df4, lambda t: t >> summarize(x=C.col1), exception=FunctionTypeError) def test_none(df4): @@ -155,18 +131,14 @@ def test_none(df4): def test_op_min(df4): assert_result_equal( df4, - lambda t: t - >> group_by(t.col1) - >> summarize(**{c.name + "_min": c.min() for c in t if c.name != "col7"}), + lambda t: t >> group_by(t.col1) >> summarize(**{c.name + "_min": c.min() for c in t if c.name != "col7"}), ) def test_op_max(df4): assert_result_equal( df4, - lambda t: t - >> group_by(t.col1) - >> summarize(**{c.name + "_max": c.max() for c in t if c.name != "col7"}), + lambda t: t >> group_by(t.col1) >> summarize(**{c.name + "_max": c.max() for c in t if c.name != "col7"}), ) diff --git a/tests/test_backend_equivalence/test_syntax.py b/tests/test_backend_equivalence/test_syntax.py index c95e5503..1b3563a0 100644 --- a/tests/test_backend_equivalence/test_syntax.py +++ b/tests/test_backend_equivalence/test_syntax.py @@ -9,9 +9,7 @@ def test_lambda_cols(df3): assert_result_equal(df3, lambda t: t >> select(C.col1, C.col2)) assert_result_equal(df3, lambda t: t >> mutate(col1=C.col1, col2=C.col1)) - assert_result_equal( - df3, lambda t: t >> select(C.col10), 
exception=pdt.ColumnNotFoundError - ) + assert_result_equal(df3, lambda t: t >> select(C.col10), exception=pdt.ColumnNotFoundError) def test_transfer_col_references(df3, df4): @@ -22,9 +20,7 @@ def collect_with_refs(tbl): return pdt.transfer_col_references(tbl >> collect(keep_col_refs=False), tbl) assert_result_equal(df3, lambda t: pdt.transfer_col_references(t, t)) - assert_result_equal( - df3, lambda t: t >> collect_with_refs() >> mutate(z=t.col1 + t.col2) - ) + assert_result_equal(df3, lambda t: t >> collect_with_refs() >> mutate(z=t.col1 + t.col2)) assert_result_equal( (df3, df4), lambda t, u: pdt.transfer_col_references(u, t) >> mutate(s=t.col1 * t.col4), diff --git a/tests/test_backend_equivalence/test_union.py b/tests/test_backend_equivalence/test_union.py new file mode 100644 index 00000000..3f431992 --- /dev/null +++ b/tests/test_backend_equivalence/test_union.py @@ -0,0 +1,282 @@ +# Copyright (c) QuantCo and pydiverse contributors 2025-2025 +# SPDX-License-Identifier: BSD-3-Clause + +import pytest + +import pydiverse.transform as pdt +from pydiverse.transform.extended import * +from tests.fixtures.backend import skip_backends +from tests.util import assert_result_equal + + +def test_union_basic(df3, df4): + """Test basic union with only visible columns.""" + # Test pipe syntax: tbl1 >> union(tbl2) + assert_result_equal( + (df3, df4), + lambda t, u: t >> union(u), + check_row_order=False, + ) + + # Test direct call syntax: union(tbl1, tbl2) + assert_result_equal( + (df3, df4), + lambda t, u: union(t, u), + check_row_order=False, + ) + + +def test_union_all(df3, df4): + """Test UNION ALL (keeping duplicates).""" + # Test pipe syntax with distinct=False (keeps duplicates) + assert_result_equal( + (df3, df4), + lambda t, u: t >> union(u, distinct=False), + check_row_order=False, + ) + + # Test direct call syntax with distinct=False (keeps duplicates) + assert_result_equal( + (df3, df4), + lambda t, u: union(t, u, distinct=False), + check_row_order=False, + ) + + +@skip_backends("ibm_db2") +def test_union_distinct(df3, df4): + """Test UNION (removing duplicates).""" + # Test pipe syntax with distinct=True (default, removes duplicates) + assert_result_equal( + (df3, df4), + lambda t, u: t >> union(u, distinct=True), + check_row_order=False, + ) + + # Test direct call syntax with distinct=True (default, removes duplicates) + assert_result_equal( + (df3, df4), + lambda t, u: union(t, u, distinct=True), + check_row_order=False, + ) + + +def test_union_distinct2(df3, df4): + """Test UNION (removing duplicates).""" + # Test pipe syntax with distinct=True (default, removes duplicates) + assert_result_equal( + (df3, df4), + lambda t, u: t >> drop(t.col7) >> union(u >> drop(u.col7), distinct=True), + check_row_order=False, + ) + + # Test direct call syntax with distinct=True (default, removes duplicates) + assert_result_equal( + (df3, df4), + lambda t, u: union(t >> drop(t.col7), u >> drop(u.col7), distinct=True), + check_row_order=False, + ) + + +def test_union_with_select(df3, df4): + """Test union with selected columns.""" + # Test with select on left table + assert_result_equal( + (df3, df4), + lambda t, u: t >> select(t.col1) >> union(u >> select(u.col1)), + check_row_order=False, + ) + + # Test with select on right table + assert_result_equal( + (df3, df4), + lambda t, u: t >> select(t.col2) >> union(u >> select(u.col2)), + check_row_order=False, + ) + + +def test_union_with_hidden_columns_left(df3, df4): + """Test union when left table has hidden columns.""" + # Left table has hidden 
column, right table has all columns visible + assert_result_equal( + (df3, df4), + lambda t, u: t + >> mutate(hidden_col=t.col1 * 10) # create a column + >> select(t.col1, t.col2) # hidden_col becomes hidden + >> union(u >> select(u.col1, u.col2)), # right has no hidden columns + check_row_order=False, + ) + + +def test_union_with_hidden_columns_right(df3, df4): + """Test union when right table has hidden columns that left table doesn't have.""" + # Right table has hidden column, left table has all columns visible + # The hidden column in right should not appear in the union result + assert_result_equal( + (df3, df4), + lambda t, u: t + >> select(t.col1, t.col2) # all visible + >> union(u >> mutate(hidden_col=u.col1 * 10) >> select(u.col1, u.col2)), # hidden_col becomes hidden in right + check_row_order=False, + ) + + +def test_union_with_hidden_columns_partial_match(df3, df4): + """Test union when both tables have hidden columns that partially match.""" + # Both tables have some hidden columns in common, some different + # Only hidden columns that exist in BOTH tables should be preserved + assert_result_equal( + (df3, df4), + lambda t, u: t + >> mutate(shared_hidden=t.col1 * 2, left_only=t.col1 + 100) # create columns + >> select(t.col1, t.col2) # shared_hidden and left_only become hidden + >> union( + u + >> mutate(shared_hidden=u.col1 * 2, right_only=u.col1 + 200) + >> select(u.col1, u.col2) # shared_hidden and right_only become hidden + ), + check_row_order=False, + ) + + +def test_union_with_mutate_hidden(df3, df4): + """Test union with mutated columns that become hidden.""" + # Create columns in mutate, then hide some + assert_result_equal( + (df3, df4), + lambda t, u: t + >> mutate(x=t.col1 * 2, y=t.col2 * 10) + >> select(t.col1, C.x) # y becomes hidden + >> union(u >> mutate(x=u.col1 * 2, y=u.col2 * 10) >> select(u.col1, C.x)), + check_row_order=False, + ) + + +def test_union_chained(df3, df4): + """Test chaining multiple unions.""" + # Chain multiple unions together + assert_result_equal( + (df3, df4), + lambda t, u: t >> union(u) >> union(t), + check_row_order=False, + ) + + +@skip_backends("sqlite") # sqlite only supports UNION for trivial queries +def test_union_after_operations(df3, df4): + """Test union after other operations like filter and arrange.""" + assert_result_equal( + (df3, df4), + lambda t, u: t >> filter(t.col1 > 0) >> arrange(t.col1) >> union(u >> filter(u.col1 > 0) >> arrange(u.col1)), + check_row_order=False, + ) + + +def test_union_error_different_columns(df3, df4): + """Test that union raises error when columns don't match.""" + # Should raise ValueError when column names don't match + assert_result_equal( + (df3, df4), + lambda t, u: t >> select(t.col1) >> union(u >> select(u.col2)), + exception=ValueError, + ) + + +def test_union_error_different_backends(): + """Test that union raises error when backends don't match.""" + import polars as pl + import sqlalchemy as sqa + + # Create tables with different backends + polars_tbl = pdt.Table(pl.DataFrame({"a": [1, 2]})) + engine = sqa.create_engine("sqlite:///:memory:") + pl.DataFrame({"a": [1, 2]}).write_database("test", engine, if_table_exists="replace") + sql_tbl = pdt.Table("test", pdt.SqlAlchemy(engine)) + + # Should raise TypeError + with pytest.raises(TypeError, match="cannot union two tables with different backends"): + polars_tbl >> union(sql_tbl) + + +def test_union_error_grouped_table(df3, df4): + """Test that union raises error when table is grouped.""" + # Should raise ValueError when trying to union 
a grouped table + assert_result_equal( + (df3, df4), + lambda t, u: t >> group_by(t.col1) >> union(u), + exception=ValueError, + ) + + assert_result_equal( + (df3, df4), + lambda t, u: t >> union(u >> group_by(u.col1)), + exception=ValueError, + ) + + +def test_union_with_rename(df3, df4): + """Test union after renaming columns.""" + assert_result_equal( + (df3, df4), + lambda t, u: t >> rename({"col1": "a", "col2": "b"}) >> union(u >> rename({"col1": "a", "col2": "b"})), + check_row_order=False, + ) + + +def test_union_empty_tables(): + """Test union with empty tables.""" + import polars as pl + + empty1 = pdt.Table(pl.DataFrame({"a": [], "b": []})) + empty2 = pdt.Table(pl.DataFrame({"a": [], "b": []})) + + result = empty1 >> union(empty2) >> export(pdt.Polars(lazy=False)) + assert len(result) == 0 + assert result.columns == ["a", "b"] + + +def test_union_different_column_order1(df3, df4): + """Test union when columns are in different order (should reorder automatically).""" + # Columns in different order should still work - union should handle reordering + # Note: This tests that the backend correctly reorders columns to match + assert_result_equal( + (df3, df4), + lambda t, u: t >> select(t.col1, t.col2) >> union(u >> select(u.col1, u.col2)), # same columns, same order + check_row_order=False, + ) + + +def test_union_different_column_order2(df3, df4): + # Test that union handles column reordering correctly + # The backend should reorder right table columns to match left order (by column name) + assert_result_equal( + (df3, df4), + lambda t, u: t + >> select(t.col1, t.col2) + >> union(u >> select(u.col2, u.col1)), # different order - should be reordered + check_row_order=False, + ) + + +def test_union_different_column_order3(df3, df4): + # Test that union handles column reordering correctly + # The backend should not reorder right table columns since order matches based on name + assert_result_equal( + (df3, df4), + lambda t, u: t + >> select(t.col1, t.col2) + >> union(u >> select(u.col2, u.col1) >> rename({u.col1: "col2", u.col2: "col1"})), + check_row_order=False, + ) + + +def test_union_different_column_order4(df3, df4): + # Test that union handles column reordering correctly + # The backend should reorder right table columns to match left order + assert_result_equal( + (df3, df4), + lambda t, u: t + >> select(t.col1, t.col2) + >> union(u >> select(u.col1, u.col2) >> rename({u.col1: "col2", u.col2: "col1"})), + check_row_order=False, + ) diff --git a/tests/test_backend_equivalence/test_window_function.py b/tests/test_backend_equivalence/test_window_function.py index 4c7e5703..cfd29315 100644 --- a/tests/test_backend_equivalence/test_window_function.py +++ b/tests/test_backend_equivalence/test_window_function.py @@ -17,16 +17,12 @@ def test_simple_ungrouped(df3): def test_simple_grouped(df3): assert_result_equal( df3, - lambda t: t - >> group_by(t.col1) - >> mutate(min=t.col4.min(), max=t.col4.max(), mean=t.col4.mean()), + lambda t: t >> group_by(t.col1) >> mutate(min=t.col4.min(), max=t.col4.max(), mean=t.col4.mean()), ) assert_result_equal( df3, - lambda t: t - >> group_by(t.col1, t.col2) - >> mutate(min=t.col4.min(), max=t.col4.max(), mean=t.col4.mean()), + lambda t: t >> group_by(t.col1, t.col2) >> mutate(min=t.col4.min(), max=t.col4.max(), mean=t.col4.mean()), ) @@ -41,9 +37,7 @@ def test_partition_by_argument(df3, df4): arrange=[t.col5.descending().nulls_last(), t.col4.nulls_first()], partition_by=[t.col2], ), - x=pdt.row_number( - arrange=[t.col4.nulls_last()], 
partition_by=[t.col1, t.col2] - ), + x=pdt.row_number(arrange=[t.col4.nulls_last()], partition_by=[t.col1, t.col2]), ), ) @@ -52,11 +46,7 @@ def test_partition_by_argument(df3, df4): lambda t, u: t >> join(u, t.col1 == u.col3, how="left") >> group_by(t.col2) - >> mutate( - y=(u.col3 + t.col1).max( - partition_by=(col for col in t if col.name != "col7") - ) - ), + >> mutate(y=(u.col3 + t.col1).max(partition_by=(col for col in t if col.name != "col7"))), ) assert_result_equal( @@ -73,18 +63,12 @@ def test_partition_by_argument(df3, df4): def test_chained(df3): assert_result_equal( df3, - lambda t: t - >> group_by(t.col1) - >> mutate(min=t.col4.min()) - >> mutate(max=t.col4.max(), mean=t.col4.mean()), + lambda t: t >> group_by(t.col1) >> mutate(min=t.col4.min()) >> mutate(max=t.col4.max(), mean=t.col4.mean()), ) assert_result_equal( df3, - lambda t: t - >> group_by(t.col1) - >> mutate(min=t.col4.min(), max=t.col4.max()) - >> mutate(span=C.max - C.min), + lambda t: t >> group_by(t.col1) >> mutate(min=t.col4.min(), max=t.col4.max()) >> mutate(span=C.max - C.min), ) @@ -101,11 +85,7 @@ def test_nested(df3): assert_result_equal( df3, - lambda t: t - >> mutate(x=C.col4.max()) - >> mutate(y=C.x.min() * 1) - >> mutate(z=C.y.mean()) - >> mutate(w=C.x / C.y), + lambda t: t >> mutate(x=C.col4.max()) >> mutate(y=C.x.min() * 1) >> mutate(z=C.y.mean()) >> mutate(w=C.x / C.y), may_throw=True, exception=SubqueryError, ) @@ -121,22 +101,14 @@ def test_nested(df3): def test_filter(df3): assert_result_equal( df3, - lambda t: t - >> group_by(t.col1, t.col2) - >> mutate(mean3=t.col3.mean()) - >> alias() - >> filter(C.mean3 <= 2.0), + lambda t: t >> group_by(t.col1, t.col2) >> mutate(mean3=t.col3.mean()) >> alias() >> filter(C.mean3 <= 2.0), ) def test_filter_argument(df3, df4): - assert_result_equal( - df4, lambda t: t >> mutate(u=t.col2.mean(filter=~t.col2.is_null())) - ) + assert_result_equal(df4, lambda t: t >> mutate(u=t.col2.mean(filter=~t.col2.is_null()))) - assert_result_equal( - df4, lambda t: t >> mutate(u=t.col2.mean(filter=~(t.col4 % 3 == 0))) - ) + assert_result_equal(df4, lambda t: t >> mutate(u=t.col2.mean(filter=~(t.col4 % 3 == 0)))) assert_result_equal(df3, lambda t: t >> mutate(u=t.col4.sum(partition_by=t.col2))) @@ -149,27 +121,18 @@ def test_filter_argument(df3, df4): ), ) - assert_result_equal( - df4, lambda t: t >> mutate(u=t.col3.min(filter=t.col3.is_null())) - ) + assert_result_equal(df4, lambda t: t >> mutate(u=t.col3.min(filter=t.col3.is_null()))) def test_arrange(df3): assert_result_equal( df3, - lambda t: t - >> group_by(t.col1, t.col2) - >> mutate(mean3=t.col3.mean()) - >> arrange(C.mean3), + lambda t: t >> group_by(t.col1, t.col2) >> mutate(mean3=t.col3.mean()) >> arrange(C.mean3), ) assert_result_equal( df3, - lambda t: t - >> arrange(-t.col4) - >> group_by(t.col1, t.col2) - >> mutate(mean3=t.col3.mean()) - >> arrange(C.mean3), + lambda t: t >> arrange(-t.col4) >> group_by(t.col1, t.col2) >> mutate(mean3=t.col3.mean()) >> arrange(C.mean3), ) @@ -213,18 +176,12 @@ def test_arrange_argument(df3): # Grouped assert_result_equal( df3, - lambda t: t - >> group_by(t.col1) - >> mutate(x=C.col4.shift(1, arrange=C.col3.nulls_last())) - >> select(C.x), + lambda t: t >> group_by(t.col1) >> mutate(x=C.col4.shift(1, arrange=C.col3.nulls_last())) >> select(C.x), ) assert_result_equal( df3, - lambda t: t - >> group_by(t.col2) - >> mutate(x=pdt.row_number(arrange=C.col4.descending())) - >> select(C.x), + lambda t: t >> group_by(t.col2) >> mutate(x=pdt.row_number(arrange=C.col4.descending())) >> 
select(C.x), ) # Ungrouped @@ -331,9 +288,7 @@ def test_op_row_number(df4): >> group_by(t.col1) >> mutate( row_number1=pdt.row_number(arrange=[C.col4.descending().nulls_last()]), - row_number2=pdt.row_number( - arrange=[C.col2.nulls_last(), C.col3.nulls_first(), t.col4.nulls_last()] - ), + row_number2=pdt.row_number(arrange=[C.col2.nulls_last(), C.col3.nulls_first(), t.col4.nulls_last()]), ), ) @@ -342,9 +297,7 @@ def test_op_row_number(df4): lambda t: t >> mutate( u=pdt.row_number(arrange=[C.col4.descending().nulls_last()]), - v=pdt.row_number( - arrange=[t.col3.descending().nulls_first(), t.col4.nulls_first()] - ), + v=pdt.row_number(arrange=[t.col3.descending().nulls_first(), t.col4.nulls_first()]), ), ) @@ -387,9 +340,7 @@ def test_op_dense_rank(df3): def test_partition_by_const_col(df3): - assert_result_equal( - df3, lambda t: t >> mutate(x=0) >> mutate(y=t.col3.sum(partition_by=C.x)) - ) + assert_result_equal(df3, lambda t: t >> mutate(x=0) >> mutate(y=t.col3.sum(partition_by=C.x))) def test_cum_sum(df4): diff --git a/tests/test_polars_table.py b/tests/test_polars_table.py index 211763b4..725d9a17 100644 --- a/tests/test_polars_table.py +++ b/tests/test_polars_table.py @@ -182,9 +182,7 @@ def test_mutate(self, tbl1): def test_join(self, tbl_left, tbl_right): assert_equal( - tbl_left - >> join(tbl_right, tbl_left.a == tbl_right.b, "left") - >> select(tbl_left.a, tbl_right.b), + tbl_left >> join(tbl_right, tbl_left.a == tbl_right.b, "left") >> select(tbl_left.a, tbl_right.b), pl.DataFrame({"a": [1, 2, 2, 3, 4], "b": [1, 2, 2, None, None]}), check_row_order=False, ) @@ -222,12 +220,8 @@ def test_join(self, tbl_left, tbl_right): assert_equal( tbl_right - >> inner_join( - tbl_right2 := tbl_right >> alias(), tbl_right.b == tbl_right2.b - ) - >> inner_join( - tbl_right3 := tbl_right >> alias(), tbl_right.b == tbl_right3.b - ), + >> inner_join(tbl_right2 := tbl_right >> alias(), tbl_right.b == tbl_right2.b) + >> inner_join(tbl_right3 := tbl_right >> alias(), tbl_right.b == tbl_right3.b), df_right.join(df_right, "b", suffix="_df_right", coalesce=False).join( df_right, "b", suffix="_df_right_1", coalesce=False ), @@ -336,18 +330,13 @@ def test_group_by(self, tbl3): >> group_by(tbl3.col1) >> group_by(tbl3.col2, add=True) >> summarize(mean3=tbl3.col3.mean(), mean4=tbl3.col4.mean()), - tbl3 - >> group_by(tbl3.col1, tbl3.col2) - >> summarize(mean3=tbl3.col3.mean(), mean4=tbl3.col4.mean()), + tbl3 >> group_by(tbl3.col1, tbl3.col2) >> summarize(mean3=tbl3.col3.mean(), mean4=tbl3.col4.mean()), check_row_order=False, ) # Ungroup doesn't change the result assert_equal( - tbl3 - >> group_by(tbl3.col1) - >> summarize(mean4=tbl3.col4.mean()) - >> ungroup(), + tbl3 >> group_by(tbl3.col1) >> summarize(mean4=tbl3.col4.mean()) >> ungroup(), tbl3 >> group_by(tbl3.col1) >> summarize(mean4=tbl3.col4.mean()), check_row_order=False, ) @@ -387,34 +376,18 @@ def test_window_functions(self, tbl3, tbl4): ) assert_equal( - ( - tbl3 - >> group_by(C.col2) - >> mutate(x=row_number(arrange=[-C.col4])) - >> select(C.x) - ), + (tbl3 >> group_by(C.col2) >> mutate(x=row_number(arrange=[-C.col4])) >> select(C.x)), pl.DataFrame({"x": [6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1]}), ) # group_by and partition_by should lead to the same result assert_equal( - ( - tbl3 - >> group_by(C.col2) - >> mutate(x=row_number(arrange=[-C.col4])) - >> select(C.x) - ), - ( - tbl3 - >> mutate(x=row_number(arrange=[-C.col4], partition_by=[C.col2])) - >> select(C.x) - ), + (tbl3 >> group_by(C.col2) >> mutate(x=row_number(arrange=[-C.col4])) >> 
+            (tbl3 >> group_by(C.col2) >> mutate(x=row_number(arrange=[-C.col4])) >> select(C.x)),
+            (tbl3 >> mutate(x=row_number(arrange=[-C.col4], partition_by=[C.col2])) >> select(C.x)),
         )

         assert_equal(
-            tbl3
-            >> mutate(x=tbl3.col1.shift(1, arrange=tbl3.col4))
-            >> inner_join(tbl4, on="col1"),
+            tbl3 >> mutate(x=tbl3.col1.shift(1, arrange=tbl3.col4)) >> inner_join(tbl4, on="col1"),
             df3.sort(pl.col("col4"))
             .with_columns(x=pl.col("col1").shift(1))
             .join(df4, on="col1", suffix="_df4", coalesce=False),
@@ -451,13 +424,7 @@ def test_case_expression(self, tbl3):
             (
                 tbl3
                 >> mutate(
-                    col1=when(C.col1 == 0)
-                    .then(1)
-                    .when(C.col1 == 1)
-                    .then(2)
-                    .when(C.col1 == 2)
-                    .then(3)
-                    .otherwise(-1)
+                    col1=when(C.col1 == 0).then(1).when(C.col1 == 1).then(2).when(C.col1 == 2).then(3).otherwise(-1)
                 )
                 >> select(C.col1)
             ),
@@ -467,13 +434,7 @@
         assert_equal(
             (
                 tbl3
-                >> mutate(
-                    x=when(C.col1 == C.col2)
-                    .then(1)
-                    .when(C.col1 == C.col3)
-                    .then(2)
-                    .otherwise(C.col4)
-                )
+                >> mutate(x=when(C.col1 == C.col2).then(1).when(C.col1 == C.col3).then(2).otherwise(C.col4))
                 >> select(C.x)
             ),
             pl.DataFrame({"x": [1, 1, 2, 3, 4, 2, 1, 1, 8, 9, 2, 11]}),
@@ -490,23 +451,14 @@ def test_lambda_column(self, tbl1, tbl2):
         )

         assert_equal(
-            tbl1
-            >> mutate(a=tbl1.col1 * 2)
-            >> mutate(b=C.a * 2, a=tbl1.col1)
-            >> select(C.b),
+            tbl1 >> mutate(a=tbl1.col1 * 2) >> mutate(b=C.a * 2, a=tbl1.col1) >> select(C.b),
             tbl1 >> select() >> mutate(b=tbl1.col1 * 4),
         )

         # Join
         assert_equal(
-            tbl1
-            >> mutate(a=tbl1.col1)
-            >> join(tbl2, C.a == tbl2.col1, "left")
-            >> select(C.a, *tbl2),
-            tbl1
-            >> select()
-            >> mutate(a=tbl1.col1)
-            >> join(tbl2, tbl1.col1 == tbl2.col1, "left", suffix="_df2"),
+            tbl1 >> mutate(a=tbl1.col1) >> join(tbl2, C.a == tbl2.col1, "left") >> select(C.a, *tbl2),
+            tbl1 >> select() >> mutate(a=tbl1.col1) >> join(tbl2, tbl1.col1 == tbl2.col1, "left", suffix="_df2"),
         )

         # Filter
@@ -548,10 +500,7 @@ def test_null(self, tbl4):
         )
         assert_equal(
             tbl4 >> mutate(u=tbl4.col3.fill_null(tbl4.col2)),
-            tbl4
-            >> mutate(
-                u=pdt.when(tbl4.col3.is_null()).then(tbl4.col2).otherwise(tbl4.col3)
-            ),
+            tbl4 >> mutate(u=pdt.when(tbl4.col3.is_null()).then(tbl4.col2).otherwise(tbl4.col3)),
         )

     def test_datetime(self, tbl_dt):
@@ -560,36 +509,25 @@
             >> mutate(
                 u=(tbl_dt.dt1 - tbl_dt.dt2),
                 v=tbl_dt.d1 - tbl_dt.d1,
-                w=(tbl_dt.d1.cast(pdt.Datetime) - tbl_dt.dt2)
-                + tbl_dt.dur1
-                + dt.timedelta(days=1),
+                w=(tbl_dt.d1.cast(pdt.Datetime) - tbl_dt.dt2) + tbl_dt.dur1 + dt.timedelta(days=1),
             ),
             df_dt.with_columns(
                 (pl.col("dt1") - pl.col("dt2")).alias("u"),
                 pl.duration().alias("v"),
-                (
-                    (pl.col("d1") - pl.col("dt2"))
-                    + pl.col("dur1")
-                    + pl.lit(dt.timedelta(days=1))
-                ).alias("w"),
+                ((pl.col("d1") - pl.col("dt2")) + pl.col("dur1") + pl.lit(dt.timedelta(days=1))).alias("w"),
             ),
         )

     def test_duckdb_execution(self, tbl2, tbl3):
         assert_equal(
-            tbl3
-            >> mutate(u=tbl3.col1 * 2)
-            >> collect(DuckDb())
-            >> mutate(v=tbl3.col3 + C.u),
+            tbl3 >> mutate(u=tbl3.col1 * 2) >> collect(DuckDb()) >> mutate(v=tbl3.col3 + C.u),
             tbl3 >> mutate(u=tbl3.col1 * 2) >> mutate(v=C.col3 + C.u),
         )

         assert_equal(
             tbl3
             >> collect(DuckDb())
-            >> left_join(
-                tbl2 >> collect(DuckDb()), tbl3.col1 == tbl2.col1, suffix="_right"
-            )
+            >> left_join(tbl2 >> collect(DuckDb()), tbl3.col1 == tbl2.col1, suffix="_right")
             >> mutate(v=tbl3.col3 + tbl2.col2)
             >> group_by(C.col2)
             >> summarize(y=C.col3_right.sum()),
@@ -614,9 +552,7 @@ def test_col_export(self, tbl1: pdt.Table, tbl2: pdt.Table):
             ((t.u + C.col2).exp() - t.v).export(Polars()),
         )

-        e = t >> inner_join(
-            tbl1, tbl1.col1.cast(pdt.Float64()) <= tbl2.col1 + tbl2.col3
-        )
+        e = t >> inner_join(tbl1, tbl1.col1.cast(pdt.Float64()) <= tbl2.col1 + tbl2.col3)
         e_ex = e >> export(Polars(lazy=False))

         assert_equal(
@@ -642,13 +578,8 @@ def test_list(self, tbl1, tbl3):
         )

         assert_equal(
-            tbl3
-            >> group_by(tbl3.col1)
-            >> summarize(x=tbl3.col3.list.agg(arrange="col4"))
-            >> arrange(tbl3.col1),
-            df3.group_by(pl.col("col1"))
-            .agg(x=pl.col("col3").sort_by("col4"))
-            .sort("col1"),
+            tbl3 >> group_by(tbl3.col1) >> summarize(x=tbl3.col3.list.agg(arrange="col4")) >> arrange(tbl3.col1),
+            df3.group_by(pl.col("col1")).agg(x=pl.col("col3").sort_by("col4")).sort("col1"),
         )

     def test_cum_sum(self, tbl1):
@@ -681,15 +612,17 @@ def test_dict_export(self):

     def test_dict_of_lists_export(self):
         assert (pdt.Table({"a": 1}) >> export(pdt.DictOfLists)) == {"a": [1]}
-        assert (
-            pdt.Table({"a": [1, 2], "b": [True, False]}) >> export(pdt.DictOfLists)
-        ) == {"a": [1, 2], "b": [True, False]}
+        assert (pdt.Table({"a": [1, 2], "b": [True, False]}) >> export(pdt.DictOfLists)) == {
+            "a": [1, 2],
+            "b": [True, False],
+        }

     def test_list_of_dicts_export(self):
         assert (pdt.Table({"a": 1}) >> export(pdt.ListOfDicts)) == [{"a": 1}]
-        assert (
-            pdt.Table({"a": [1, 2], "b": [True, False]}) >> export(pdt.ListOfDicts)
-        ) == [{"a": 1, "b": True}, {"a": 2, "b": False}]
+        assert (pdt.Table({"a": [1, 2], "b": [True, False]}) >> export(pdt.ListOfDicts)) == [
+            {"a": 1, "b": True},
+            {"a": 2, "b": False},
+        ]

     def test_uses_table(self, tbl2, tbl3):
         assert tbl2.col1.uses_table(tbl2)
@@ -703,26 +636,18 @@ def test_name(self, tbl3):
     def test_columns(self, tbl3):
         assert tbl3 >> columns() == [col.name for col in tbl3]
         tbl3_different_col_order = tbl3 >> mutate(col3=tbl3.col3)
-        assert tbl3_different_col_order >> columns() == [
-            col.name for col in tbl3_different_col_order
-        ]
+        assert tbl3_different_col_order >> columns() == [col.name for col in tbl3_different_col_order]

     def test_name_alias(self, tbl2):
         assert tbl2 >> alias("tbl") >> name() == "tbl"

     def test_enum(self, tbl1):
         tbl1_enum = tbl1 >> mutate(p=tbl1.col2.cast(pdt.Enum("a", "b", "c", "d")))
-        df1_enum = df1.with_columns(
-            p=pl.col("col2").cast(pl.Enum(["a", "b", "c", "d"]))
-        )
+        df1_enum = df1.with_columns(p=pl.col("col2").cast(pl.Enum(["a", "b", "c", "d"])))
         assert_equal(tbl1_enum, df1_enum)

         with pytest.raises(pl.exceptions.InvalidOperationError):
-            (
-                tbl1
-                >> mutate(p=tbl1.col2.cast(pdt.Enum("a", "b", "d")))
-                >> export(Polars)
-            )
+            (tbl1 >> mutate(p=tbl1.col2.cast(pdt.Enum("a", "b", "d"))) >> export(Polars))

         assert_equal(
             tbl1_enum >> mutate(q=C.p + "l"),
@@ -730,9 +655,7 @@ def test_enum(self, tbl1):
         )

     def test_enum_isin(self, tbl1):
-        tbl1 >> mutate(
-            p=tbl1.col2.cast(pdt.Enum("a", "b", "c", "d")).is_in("a", "b", "c", "d")
-        )
+        tbl1 >> mutate(p=tbl1.col2.cast(pdt.Enum("a", "b", "c", "d")).is_in("a", "b", "c", "d"))
         tbl1 >> mutate(q="a") >> mutate(p=C.q.cast(pdt.Enum("a")).is_in("a"))

     def test_col_rename(self, tbl2, tbl4):
@@ -751,19 +674,13 @@ def test_join_renaming(self, tbl2, tbl3):
             df2.rename({"col3": "c"})
             .join(df3, how="left", on=["col1", "col2"])
             .with_columns(
-                col1_df3=pl.when(pl.col("col4").is_null())
-                .then(None)
-                .otherwise(pl.col("col1")),
-                col2_df3=pl.when(pl.col("col4").is_null())
-                .then(None)
-                .otherwise(pl.col("col2")),
+                col1_df3=pl.when(pl.col("col4").is_null()).then(None).otherwise(pl.col("col1")),
+                col2_df3=pl.when(pl.col("col4").is_null()).then(None).otherwise(pl.col("col2")),
             ),
         )

     def test_collision_renaming(self, tbl2, tbl3):
-        with pytest.raises(
-            ValueError, match="rename would cause duplicate column name `col3`"
-        ):
+        with pytest.raises(ValueError, match="rename would cause duplicate column name `col3`"):
             _ = (
                 tbl2
                 >> mutate(c=tbl2.col3)
@@ -772,22 +689,16 @@ def test_collision_renaming(self, tbl2, tbl3):
                     (tbl2.col1 == tbl3.col1) & (tbl2.col2 == tbl3.col2),
                 )
                 >> rename({"c": "col3"}),
-            df2.join(df3, how="inner", on=["col1", "col2"]).select(
-                "col1", "col2", "col3"
-            ),
+            df2.join(df3, how="inner", on=["col1", "col2"]).select("col1", "col2", "col3"),
         )

     def test_hidden_collision_renaming0(self, tbl2, tbl3):
         assert_equal(
             tbl2
             >> rename({"col3": "c"})
-            >> inner_join(
-                tbl3 >> select(), (tbl2.col1 == tbl3.col1) & (tbl2.col2 == tbl3.col2)
-            )
+            >> inner_join(tbl3 >> select(), (tbl2.col1 == tbl3.col1) & (tbl2.col2 == tbl3.col2))
             >> rename({"c": "col3"}),
-            df2.join(df3, how="inner", on=["col1", "col2"]).select(
-                "col1", "col2", "col3"
-            ),
+            df2.join(df3, how="inner", on=["col1", "col2"]).select("col1", "col2", "col3"),
         )

     def test_hidden_collision_renaming1(self, tbl2, tbl3):
@@ -795,22 +706,16 @@
             tbl2
             >> rename({"col3": "c"})
             >> collect()
-            >> inner_join(
-                tbl3 >> select(), (tbl2.col1 == tbl3.col1) & (tbl2.col2 == tbl3.col2)
-            )
+            >> inner_join(tbl3 >> select(), (tbl2.col1 == tbl3.col1) & (tbl2.col2 == tbl3.col2))
             >> rename({C.c: "col3"}),
-            df2.join(df3, how="inner", on=["col1", "col2"]).select(
-                "col1", "col2", "col3"
-            ),
+            df2.join(df3, how="inner", on=["col1", "col2"]).select("col1", "col2", "col3"),
         )

     def test_hidden_collision_renaming2(self, tbl2, tbl3):
         assert_equal(
             tbl2
             >> rename({"col3": "c"})
-            >> inner_join(
-                tbl3 >> select(), (tbl2.col1 == tbl3.col1) & (tbl2.col2 == tbl3.col2)
-            )
+            >> inner_join(tbl3 >> select(), (tbl2.col1 == tbl3.col1) & (tbl2.col2 == tbl3.col2))
             >> mutate(col3=C.c),
             df2.join(df3, how="inner", on=["col1", "col2"])
             .select("col1", "col2", "col3")
@@ -822,9 +727,7 @@ def test_hidden_collision_renaming3(self, tbl2, tbl3):
             tbl2
             >> rename({"col3": "c"})
             >> collect()
-            >> inner_join(
-                tbl3 >> select(), (tbl2.col1 == tbl3.col1) & (tbl2.col2 == tbl3.col2)
-            )
+            >> inner_join(tbl3 >> select(), (tbl2.col1 == tbl3.col1) & (tbl2.col2 == tbl3.col2))
             >> mutate(col3=C.c),
             df2.join(df3, how="inner", on=["col1", "col2"])
             .select("col1", "col2", "col3")
@@ -834,9 +737,7 @@ def test_eval_aligned(self, tbl4):
         assert_equal(
             tbl4,
-            tbl4
-            >> drop(tbl4.col1)
-            >> mutate(col1=eval_aligned(df4.get_column("col1"))),
+            tbl4 >> drop(tbl4.col1) >> mutate(col1=eval_aligned(df4.get_column("col1"))),
         )

         s = pl.Series([1.2**i for i in range(df4.height)])
@@ -851,13 +752,8 @@ def test_eval_aligned(self, tbl4):
         # pandas series
         s_pd = s.to_pandas()
         assert_equal(
-            tbl4
-            >> mutate(z=eval_aligned(tbl4.col1 + s_pd))
-            >> group_by(tbl4.col3)
-            >> summarize(u=C.z.sum()),
-            df4.with_columns(z=pl.col("col1") + s)
-            .group_by("col3")
-            .agg(u=pl.col("z").sum()),
+            tbl4 >> mutate(z=eval_aligned(tbl4.col1 + s_pd)) >> group_by(tbl4.col3) >> summarize(u=C.z.sum()),
+            df4.with_columns(z=pl.col("col1") + s).group_by("col3").agg(u=pl.col("z").sum()),
             check_row_order=False,
         )

@@ -934,13 +830,7 @@ def test_ast_repr(self, tbl4):
         tbl4.col1.ast_repr()
         (tbl4.col1 + tbl4.col2).ast_repr()
         (tbl4.col1 + tbl4.col2 + tbl4.col3).ast_repr()
-        (
-            pdt.when(tbl4.col1 > 1)
-            .then(tbl4.col2)
-            .when(tbl4.col1 < -1)
-            .then(tbl4.col3)
-            .otherwise(7)
-        ).ast_repr()
+        (pdt.when(tbl4.col1 > 1).then(tbl4.col2).when(tbl4.col1 < -1).then(tbl4.col3).otherwise(7)).ast_repr()

         (tbl4.col1.cast(pdt.Float64) + tbl4.col2 / 2).ast_repr()

@@ -950,16 +840,12 @@ def test_ast_repr(self, tbl4):

         tbl4.col1.max(
             partition_by=[tbl4.col2, tbl4.col3],
-            filter=pdt.when(tbl4.col1 > 0)
-            .then(tbl4.col2.is_not_null())
-            .otherwise((tbl4.col3 % 2) == 0),
+            filter=pdt.when(tbl4.col1 > 0).then(tbl4.col2.is_not_null()).otherwise((tbl4.col3 % 2) == 0),
         ).ast_repr()

     def test_error_source_ptr(self, tbl1, tbl2):
         with pytest.raises(DataTypeError) as r:
-            tbl1 >> mutate(
-                z=(2 * (tbl1.col1 + C.col2) - tbl1.col1.exp()) * tbl1.col1.acos()
-            )
+            tbl1 >> mutate(z=(2 * (tbl1.col1 + C.col2) - tbl1.col1.exp()) * tbl1.col1.acos())
         assert "AST path" in r.value.args[0]

         with pytest.raises(DataTypeError) as r:
@@ -1006,10 +892,7 @@ def test_verb_ast_repr(self, tbl3, tbl4):
                 v=(tbl3.col2.exp() + tbl3.col4) * tbl3.col1,
                 w=pdt.max(tbl3.col2, tbl3.col1, tbl3.col5.str.len()),
                 x=pdt.count(),
-                y=eval_aligned(
-                    (tbl3 >> mutate(j=42) >> alias("tbl. 1729") >> filter(C.col2 > 0)).j
-                    + tbl3.col1
-                ),
+                y=eval_aligned((tbl3 >> mutate(j=42) >> alias("tbl. 1729") >> filter(C.col2 > 0)).j + tbl3.col1),
                 z=eval_aligned(series),
             )
             >> select(tbl3.col1, tbl3.col4)
diff --git a/tests/test_sql_table.py b/tests/test_sql_table.py
index d9731314..84542422 100644
--- a/tests/test_sql_table.py
+++ b/tests/test_sql_table.py
@@ -263,17 +263,12 @@ def test_group_by(self, tbl3):
             >> group_by(tbl3.col1)
             >> group_by(tbl3.col2, add=True)
             >> summarize(mean3=tbl3.col3.mean(), mean4=tbl3.col4.mean()),
-            tbl3
-            >> group_by(tbl3.col1, tbl3.col2)
-            >> summarize(mean3=tbl3.col3.mean(), mean4=tbl3.col4.mean()),
+            tbl3 >> group_by(tbl3.col1, tbl3.col2) >> summarize(mean3=tbl3.col3.mean(), mean4=tbl3.col4.mean()),
         )

         # Ungroup doesn't change the result
         assert_equal(
-            tbl3
-            >> group_by(tbl3.col1)
-            >> summarize(mean4=tbl3.col4.mean())
-            >> ungroup(),
+            tbl3 >> group_by(tbl3.col1) >> summarize(mean4=tbl3.col4.mean()) >> ungroup(),
             tbl3 >> group_by(tbl3.col1) >> summarize(mean4=tbl3.col4.mean()),
         )

@@ -293,11 +288,7 @@ def test_alias(self, tbl1, tbl2):
         assert_equal(a, b)

         # Self Join
-        self_join = (
-            tbl2
-            >> join(x, tbl2.col1 == x.col1, "left", suffix="42")
-            >> alias("self_join")
-        )
+        self_join = tbl2 >> join(x, tbl2.col1 == x.col1, "left", suffix="42") >> alias("self_join")

         self_join_expected = df2.join(
             df2,
@@ -321,23 +312,14 @@ def test_lambda_column(self, tbl1, tbl2):
         )

         assert_equal(
-            tbl1
-            >> mutate(a=tbl1.col1 * 2)
-            >> mutate(b=C.a * 2, a=tbl1.col1)
-            >> select(C.b),
+            tbl1 >> mutate(a=tbl1.col1 * 2) >> mutate(b=C.a * 2, a=tbl1.col1) >> select(C.b),
             tbl1 >> select() >> mutate(b=tbl1.col1 * 4),
         )

         # Join
         assert_equal(
-            tbl1
-            >> mutate(a=tbl1.col1)
-            >> join(tbl2, C.a == tbl2.col1, "left")
-            >> select(C.a, *tbl2),
-            tbl1
-            >> select()
-            >> mutate(a=tbl1.col1)
-            >> join(tbl2, tbl1.col1 == tbl2.col1, "left", suffix="_df2"),
+            tbl1 >> mutate(a=tbl1.col1) >> join(tbl2, C.a == tbl2.col1, "left") >> select(C.a, *tbl2),
+            tbl1 >> select() >> mutate(a=tbl1.col1) >> join(tbl2, tbl1.col1 == tbl2.col1, "left", suffix="_df2"),
         )

         # Filter
@@ -358,9 +340,7 @@ def test_select_without_tbl_ref(self, tbl2):
             tbl2 >> summarize(count=tbl2.col1.count()),
         )

-        assert_equal(
-            tbl2 >> summarize(count=count()), pl.DataFrame({"count": [len(df2)]})
-        )
+        assert_equal(tbl2 >> summarize(count=count()), pl.DataFrame({"count": [len(df2)]}))

     def test_null_comparison(self, tbl4):
         assert_equal(
@@ -378,13 +358,7 @@ def test_case_expression(self, tbl3):
             (
                 tbl3
                 >> mutate(
-                    col1=when(C.col1 == 0)
-                    .then(1)
-                    .when(C.col1 == 1)
-                    .then(2)
-                    .when(C.col1 == 2)
-                    .then(3)
-                    .otherwise(-1)
+                    col1=when(C.col1 == 0).then(1).when(C.col1 == 1).then(2).when(C.col1 == 2).then(3).otherwise(-1)
                 )
                 >> select(C.col1)
             ),
@@ -394,13 +368,7 @@
         assert_equal(
             (
                 tbl3
-                >> mutate(
-                    x=when(C.col1 == C.col2)
-                    .then(1)
-                    .when(C.col1 == C.col3)
-                    .then(2)
-                    .otherwise(C.col4)
-                )
+                >> mutate(x=when(C.col1 == C.col2).then(1).when(C.col1 == C.col3).then(2).otherwise(C.col4))
                 >> select(C.x)
             ),
             pl.DataFrame({"x": [1, 1, 2, 3, 4, 2, 1, 1, 8, 9, 2, 11]}),
diff --git a/tests/util/assertion.py b/tests/util/assertion.py
index 01398e15..e65d686c 100644
--- a/tests/util/assertion.py
+++ b/tests/util/assertion.py
@@ -21,9 +21,7 @@ def assert_equal(left, right, check_dtypes=False, check_row_order=True):
         return
     left_df = left >> export(Polars(lazy=False)) if isinstance(left, Table) else left
-    right_df = (
-        right >> export(Polars(lazy=False)) if isinstance(right, Table) else right
-    )
+    right_df = right >> export(Polars(lazy=False)) if isinstance(right, Table) else right

     try:
         assert_frame_equal(
@@ -98,13 +96,8 @@ def assert_result_equal(
         # TODO: after a join, cols containing only null values get type Null on
         # SQLite and Postgres. maybe we can fix this but for now we just ignore them
         assert dfx.columns == dfy.columns
-        null_cols = set(dfx.select(pl.col(pl.Null)).columns) | set(
-            dfy.select(pl.col(pl.Null)).columns
-        )
-        assert all(
-            all(d.get_column(col).is_null().all() for col in null_cols)
-            for d in (dfx, dfy)
-        )
+        null_cols = set(dfx.select(pl.col(pl.Null)).columns) | set(dfy.select(pl.col(pl.Null)).columns)
+        assert all(all(d.get_column(col).is_null().all() for col in null_cols) for d in (dfx, dfy))
         dfy = dfy.select(pl.all().exclude(null_cols))
         dfx = dfx.select(pl.all().exclude(null_cols))

@@ -119,10 +112,7 @@
         if isinstance(exception, type):
             exception = (exception,)
         if not isinstance(e, exception):
-            raise Exception(
-                f"Raised the wrong type of exception: {type(e)} instead of"
-                f" {exception}."
-            ) from e
+            raise Exception(f"Raised the wrong type of exception: {type(e)} instead of {exception}.") from e
         # TODO: Replace with logger
         print(f"An exception was thrown:\n{e}")
         return
@@ -130,9 +120,7 @@
         raise e

     try:
-        assert_frame_equal(
-            dfx, dfy, check_row_order=check_row_order, check_exact=False, atol=1e-6
-        )
+        assert_frame_equal(dfx, dfy, check_row_order=check_row_order, check_exact=False, atol=1e-6)
     except Exception as e:
         if xfail_warnings and did_raise_warning:
             pytest.xfail(warnings_summary)
diff --git a/tests/util/backend.py b/tests/util/backend.py
index b385222a..6fdda9ec 100644
--- a/tests/util/backend.py
+++ b/tests/util/backend.py
@@ -64,20 +64,14 @@ def sql_table(
         if dtype in dtypes_map and col not in sql_dtypes:
             sql_dtypes[col] = dtypes_map[dtype]

-    df.write_database(
-        name, engine, if_table_exists="replace", engine_options={"dtype": sql_dtypes}
-    )
+    df.write_database(name, engine, if_table_exists="replace", engine_options={"dtype": sql_dtypes})
     if fix_sql_dtypes is not None and len(fix_sql_dtypes) > 0:
         # this is a hack to fix sql types after creation of the table
         # the main reason for this is that ibm_db_sa renders sqa.boolean as SMALLINT
         # (https://github.com/ibmdb/python-ibmdbsa/issues/161)
         with engine.connect() as conn:
             for col, dtype in fix_sql_dtypes.items():
-                conn.execute(
-                    sqa.text(
-                        f"ALTER TABLE {name} ALTER COLUMN {col} {dialect_infix} {dtype}"
-                    )
-                )
+                conn.execute(sqa.text(f"ALTER TABLE {name} ALTER COLUMN {col} {dialect_infix} {dtype}"))
             conn.execute(sqa.text(f"call sysproc.admin_cmd('REORG TABLE {name}')"))
             conn.commit()
     return Table(name, SqlAlchemy(engine))
@@ -113,12 +107,7 @@ def duckdb_parquet_table(df: pl.DataFrame, name: str):

     with engine.connect() as conn:
         conn.execute(sqa.text(f"DROP VIEW IF EXISTS {name}"))
-        conn.execute(
-            sqa.text(
-                f"CREATE VIEW {name} AS SELECT * FROM "
-                f"read_parquet('{path / f'{name}.parquet'}')"
-            )
-        )
+        conn.execute(sqa.text(f"CREATE VIEW {name} AS SELECT * FROM read_parquet('{path / f'{name}.parquet'}')"))
         conn.commit()
     return Table(name, SqlAlchemy(engine))

@@ -135,10 +124,7 @@ def postgres_table(df: pl.DataFrame, name: str):

 def mssql_table(df: pl.DataFrame, name: str):
     from sqlalchemy.dialects.mssql import DATETIME2

-    url = (
-        "mssql+pyodbc://sa:PydiQuant27@127.0.0.1:1433"
-        "/master?driver=ODBC+Driver+18+for+SQL+Server&encrypt=no"
-    )
+    url = "mssql+pyodbc://sa:PydiQuant27@127.0.0.1:1433/master?driver=ODBC+Driver+18+for+SQL+Server&encrypt=no"
     return sql_table(df, name, url, dtypes_map={pl.Datetime(): DATETIME2()})