diff --git a/dataframely/columns/_base.py b/dataframely/columns/_base.py index 6b9e0ba..c8bde7c 100644 --- a/dataframely/columns/_base.py +++ b/dataframely/columns/_base.py @@ -18,6 +18,7 @@ warn_nullable_default_change, ) from dataframely._polars import PolarsDataType +from dataframely.columns._utils import first_non_null from dataframely.random import Generator if sys.version_info >= (3, 11): @@ -246,6 +247,24 @@ def col(self) -> pl.Expr: """Obtain a Polars column expression for the column.""" return pl.col(self.name) + def with_property( + self, + *, + nullable: bool | None = None, + primary_key: bool | None = None, + check: Check | None = None, + alias: str | None = None, + metadata: dict[str, Any] | None = None, + ) -> Self: + """Create a copy of this column with updated properties.""" + return self.__class__( + nullable=first_non_null(nullable, self.nullable, allow_null_response=True), + primary_key=first_non_null(primary_key, default=self.primary_key), + check=self.check if check is None else check, + alias=first_non_null(alias, self.alias, allow_null_response=True), + metadata=first_non_null(metadata, self.metadata, allow_null_response=True), + ) + # ----------------------------------- SAMPLING ----------------------------------- # def sample(self, generator: Generator, n: int = 1) -> pl.Series: diff --git a/dataframely/columns/_mixins.py b/dataframely/columns/_mixins.py index 1f8b002..0a8fd0c 100644 --- a/dataframely/columns/_mixins.py +++ b/dataframely/columns/_mixins.py @@ -7,6 +7,8 @@ import polars as pl +from ._utils import first_non_null + if TYPE_CHECKING: # pragma: no cover from ._base import Column @@ -80,6 +82,26 @@ def validation_rules(self, expr: pl.Expr) -> dict[str, pl.Expr]: result["max_exclusive"] = expr < self.max_exclusive # type: ignore return result + def with_property( + self, + *, + min: T | None = None, + min_exclusive: T | None = None, + max: T | None = None, + max_exclusive: T | None = None, + **kwargs: Any, + ) -> Self: + new_column = super().with_property(**kwargs) + new_column.min = first_non_null(min, self.min, allow_null_response=True) + new_column.min_exclusive = first_non_null( + min_exclusive, self.min_exclusive, allow_null_response=True + ) + new_column.max = first_non_null(max, self.max, allow_null_response=True) + new_column.max_exclusive = first_non_null( + max_exclusive, self.max_exclusive, allow_null_response=True + ) + return new_column + # ------------------------------------ IS IN MIXIN ----------------------------------- # @@ -98,3 +120,8 @@ def validation_rules(self, expr: pl.Expr) -> dict[str, pl.Expr]: if self.is_in is not None: result["is_in"] = expr.is_in(self.is_in) return result + + def with_property(self, *, is_in: Sequence[U] | None = None, **kwargs: Any) -> Self: + new_column = super().with_property(**kwargs) + new_column.is_in = first_non_null(is_in, self.is_in, allow_null_response=True) + return new_column diff --git a/dataframely/columns/any.py b/dataframely/columns/any.py index cd41e4e..985e76a 100644 --- a/dataframely/columns/any.py +++ b/dataframely/columns/any.py @@ -3,6 +3,13 @@ from __future__ import annotations +import sys + +if sys.version_info >= (3, 11): + from typing import Self +else: + from typing_extensions import Self + import polars as pl from dataframely._compat import pa, sa, sa_mssql, sa_TypeEngine @@ -11,6 +18,7 @@ from ._base import Check, Column from ._registry import register +from ._utils import first_non_null @register @@ -79,3 +87,22 @@ def pyarrow_dtype(self) -> pa.DataType: def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: return pl.repeat(None, n, dtype=pl.Null, eager=True) + + def with_property( + self, + *, + nullable: bool | None = None, + primary_key: bool | None = None, + check: Check | None = None, + alias: str | None = None, + metadata: dict[str, Any] | None = None, + ) -> Self: + if nullable is not None and not nullable: + raise ValueError("Column `Any` must be nullable.") + if primary_key is not None and primary_key: + raise ValueError("Column `Any` can't be a primary key.") + return self.__class__( + check=check if check is not None else self.check, + alias=first_non_null(alias, self.alias, allow_null_response=True), + metadata=first_non_null(metadata, self.metadata, allow_null_response=True), + ) diff --git a/dataframely/columns/datetime.py b/dataframely/columns/datetime.py index 82fd7d1..303a894 100644 --- a/dataframely/columns/datetime.py +++ b/dataframely/columns/datetime.py @@ -4,8 +4,14 @@ from __future__ import annotations import datetime as dt +import sys from typing import Any, cast +if sys.version_info >= (3, 11): + from typing import Self +else: + from typing_extensions import Self + import polars as pl from polars._typing import TimeUnit @@ -149,6 +155,36 @@ def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: null_probability=self._null_probability, ) + def with_property( + self, + *, + nullable: bool | None = None, + primary_key: bool | None = None, + min: dt.date | None = None, + min_exclusive: dt.date | None = None, + max: dt.date | None = None, + max_exclusive: dt.date | None = None, + resolution: str | None = None, + check: Check | None = None, + alias: str | None = None, + metadata: dict[str, Any] | None = None, + ) -> Self: + result = super().with_property( + nullable=nullable, + primary_key=primary_key, + min=min, + min_exclusive=min_exclusive, + max=max, + max_exclusive=max_exclusive, + check=check, + alias=alias, + metadata=metadata, + ) + result.resolution = first_non_null( + resolution, self.resolution, allow_null_response=True + ) + return result + @register class Time(OrdinalMixin[dt.time], Column): @@ -278,6 +314,36 @@ def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: null_probability=self._null_probability, ) + def with_property( + self, + *, + nullable: bool | None = None, + primary_key: bool | None = None, + min: dt.time | None = None, + min_exclusive: dt.time | None = None, + max: dt.time | None = None, + max_exclusive: dt.time | None = None, + resolution: str | None = None, + check: Check | None = None, + alias: str | None = None, + metadata: dict[str, Any] | None = None, + ) -> Self: + result = super().with_property( + nullable=nullable, + primary_key=primary_key, + min=min, + min_exclusive=min_exclusive, + max=max, + max_exclusive=max_exclusive, + check=check, + alias=alias, + metadata=metadata, + ) + result.resolution = first_non_null( + resolution, self.resolution, allow_null_response=True + ) + return result + @register class Datetime(OrdinalMixin[dt.datetime], Column): @@ -425,6 +491,42 @@ def _attributes_match( return lhs.utcoffset(now) == rhs.utcoffset(now) return super()._attributes_match(lhs, rhs, name, column_expr) + def with_property( + self, + *, + nullable: bool | None = None, + primary_key: bool = False, + min: dt.datetime | None = None, + min_exclusive: dt.datetime | None = None, + max: dt.datetime | None = None, + max_exclusive: dt.datetime | None = None, + resolution: str | None = None, + time_zone: str | dt.tzinfo | None = None, + time_unit: TimeUnit | None = None, + check: Check | None = None, + alias: str | None = None, + metadata: dict[str, Any] | None = None, + ) -> Self: + result = super().with_property( + nullable=nullable, + primary_key=primary_key, + min=min, + min_exclusive=min_exclusive, + max=max, + max_exclusive=max_exclusive, + check=check, + alias=alias, + metadata=metadata, + ) + result.resolution = first_non_null( + resolution, self.resolution, allow_null_response=True + ) + result.time_zone = first_non_null( + time_zone, self.time_zone, allow_null_response=True + ) + result.time_unit = first_non_null(time_unit, default=self.time_unit) + return result + @register class Duration(OrdinalMixin[dt.timedelta], Column): @@ -546,6 +648,36 @@ def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: null_probability=self._null_probability, ) + def with_property( + self, + *, + nullable: bool | None = None, + primary_key: bool = False, + min: dt.timedelta | None = None, + min_exclusive: dt.timedelta | None = None, + max: dt.timedelta | None = None, + max_exclusive: dt.timedelta | None = None, + resolution: str | None = None, + check: Check | None = None, + alias: str | None = None, + metadata: dict[str, Any] | None = None, + ) -> Self: + result = super().with_property( + nullable=nullable, + primary_key=primary_key, + min=min, + min_exclusive=min_exclusive, + max=max, + max_exclusive=max_exclusive, + check=check, + alias=alias, + metadata=metadata, + ) + result.resolution = first_non_null( + resolution, self.resolution, allow_null_response=True + ) + return result + # --------------------------------------- UTILS -------------------------------------- # diff --git a/dataframely/columns/decimal.py b/dataframely/columns/decimal.py index 547ba8c..828ede0 100644 --- a/dataframely/columns/decimal.py +++ b/dataframely/columns/decimal.py @@ -5,10 +5,17 @@ import decimal import math +import sys from typing import Any import polars as pl +if sys.version_info >= (3, 11): + from typing import Self +else: + from typing_extensions import Self + + from dataframely._compat import pa, sa, sa_TypeEngine from dataframely._polars import PolarsDataType from dataframely.random import Generator @@ -157,6 +164,39 @@ def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: ) return ((samples * 10**self.scale).floor() / 10**self.scale).cast(self.dtype) + def with_property( + self, + *, + precision: int | None = None, + scale: int | None = None, + nullable: bool | None = None, + primary_key: bool = False, + min: decimal.Decimal | None = None, + min_exclusive: decimal.Decimal | None = None, + max: decimal.Decimal | None = None, + max_exclusive: decimal.Decimal | None = None, + check: Check | None = None, + alias: str | None = None, + metadata: dict[str, Any] | None = None, + ) -> Self: + # TODO validate + result = super().with_property( + nullable=nullable, + primary_key=primary_key, + min=min, + min_exclusive=min_exclusive, + max=max, + max_exclusive=max_exclusive, + check=check, + alias=alias, + metadata=metadata, + ) + result.precision = first_non_null( + precision, self.precision, allow_null_response=True + ) + result.scale = scale if scale is not None else self.scale + return result + # --------------------------------------- UTILS -------------------------------------- # diff --git a/dataframely/columns/enum.py b/dataframely/columns/enum.py index 9119aa4..4cbcb1c 100644 --- a/dataframely/columns/enum.py +++ b/dataframely/columns/enum.py @@ -3,6 +3,7 @@ from __future__ import annotations +import sys from collections.abc import Sequence from typing import Any @@ -15,6 +16,11 @@ from ._base import Check, Column from ._registry import register +if sys.version_info >= (3, 11): + from typing import Self +else: + from typing_extensions import Self + @register class Enum(Column): @@ -88,3 +94,23 @@ def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: return generator.sample_choice( n, choices=self.categories, null_probability=self._null_probability ).cast(self.dtype) + + def with_property( + self, + *, + categories: Sequence[str] | None = None, + nullable: bool | None = None, + primary_key: bool = False, + check: Check | None = None, + alias: str | None = None, + metadata: dict[str, Any] | None = None, + ) -> Self: + result = super().with_property( + nullable=nullable, + primary_key=primary_key, + check=check, + alias=alias, + metadata=metadata, + ) + result.categories = categories if categories is not None else self.categories + return result diff --git a/dataframely/columns/float.py b/dataframely/columns/float.py index ed357a0..659fe7f 100644 --- a/dataframely/columns/float.py +++ b/dataframely/columns/float.py @@ -8,6 +8,12 @@ from abc import abstractmethod from typing import Any +if sys.version_info >= (3, 11): + from typing import Self +else: + from typing_extensions import Self + + import numpy as np import polars as pl from polars.datatypes.group import FLOAT_DTYPES @@ -137,6 +143,36 @@ def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: inf_probability=self._inf_probability, ).cast(self.dtype) + def with_property( + self, + *, + nullable: bool | None = None, + primary_key: bool | None = None, + allow_inf_nan: bool | None = None, + min: float | None = None, + min_exclusive: float | None = None, + max: float | None = None, + max_exclusive: float | None = None, + check: Check | None = None, + alias: str | None = None, + metadata: dict[str, Any] | None = None, + ) -> Self: + result = super().with_property( + nullable=nullable, + primary_key=primary_key, + min=min, + min_exclusive=min_exclusive, + max=max, + max_exclusive=max_exclusive, + check=check, + alias=alias, + metadata=metadata, + ) + result.allow_inf_nan = ( + allow_inf_nan if allow_inf_nan is not None else self.allow_inf_nan + ) + return result + # ------------------------------------------------------------------------------------ # diff --git a/tests/columns/test_with_property.py b/tests/columns/test_with_property.py new file mode 100644 index 0000000..9382fa1 --- /dev/null +++ b/tests/columns/test_with_property.py @@ -0,0 +1,60 @@ +# Copyright (c) QuantCo 2025-2025 +# SPDX-License-Identifier: BSD-3-Clause + +import polars as pl +import pytest + +import dataframely as dy + + +class SchemaOne(dy.Schema): + column_one = dy.Integer(primary_key=True) + column_two = dy.Integer() + + +class SchemaTwo(dy.Schema): + column_one = SchemaOne.column_one + column_two = SchemaOne.column_two.with_property(primary_key=True) + + +def test_with_property() -> None: + # Check that the second schema has the updated column + SchemaTwo.validate(pl.LazyFrame({"column_one": [1, 1], "column_two": [1, 2]})) + # Check that the first schema is unchanged + with pytest.raises(dy.exc.ValidationError): + SchemaOne.validate(pl.LazyFrame({"column_one": [1, 1], "column_two": [1, 2]})) + + +class SchemaWithIsInConstraint(dy.Schema): + column_one = SchemaOne.column_one.with_property(is_in=[1, 2, 3]) + + +def test_with_is_in_property() -> None: + # Check that the updated schema has the constraint + with pytest.raises(dy.exc.ValidationError): + SchemaWithIsInConstraint.validate(pl.LazyFrame({"column_one": [1, 4]})) + + # Check that the original schema is unchanged: + SchemaOne.validate(pl.LazyFrame({"column_one": [1, 4], "column_two": [1, 2]})) + + +class SchemaWithMultipleProperties(dy.Schema): + column_one = SchemaOne.column_one.with_property(is_in=[1, 2, 3, 4, 5, 6], max=4) + + +def test_multiple() -> None: + # Is in + with pytest.raises(dy.exc.ValidationError): + SchemaWithMultipleProperties.validate(pl.LazyFrame({"column_one": [0]})) + # Max + with pytest.raises(dy.exc.ValidationError): + SchemaWithMultipleProperties.validate(pl.LazyFrame({"column_one": [6]})) + + +class SchemaAny(dy.Schema): + column_any = dy.Any().with_property(check=lambda x: x > 7) + + +def test_any() -> None: + with pytest.raises(dy.exc.ValidationError): + SchemaAny.validate(pl.LazyFrame({"column_any": [6, 7]}))