diff --git a/ibis_ml/core.py b/ibis_ml/core.py index 2c86dab..6d21b9d 100644 --- a/ibis_ml/core.py +++ b/ibis_ml/core.py @@ -347,13 +347,8 @@ def _get_param_names(cls) -> list[str]: # Extract and sort argument names excluding 'self' return sorted([p.name for p in parameters]) - def get_params(self, deep=True) -> dict[str, Any]: - """Get parameters for this estimator. - - Parameters - ---------- - deep : bool, default=True - Has no effect, because steps cannot contain nested substeps. + def _get_params(self) -> dict[str, Any]: + """Get parameters for this step. Returns ------- @@ -370,8 +365,8 @@ def get_params(self, deep=True) -> dict[str, Any]: """ return {key: getattr(self, key) for key in self._get_param_names()} - def set_params(self, **params): - """Set the parameters of this estimator. + def _set_params(self, **params): + """Set the parameters of this step. Parameters ---------- @@ -400,7 +395,7 @@ def set_params(self, **params): for key, value in params.items(): if key not in valid_params: raise ValueError( - f"Invalid parameter {key!r} for estimator {self}. " + f"Invalid parameter {key!r} for step {self}. " f"Valid parameters are: {valid_params!r}." ) @@ -503,16 +498,16 @@ def output_format(self) -> Literal["default", "pandas", "pyarrow", "polars"]: return self._output_format def get_params(self, deep=True) -> dict[str, Any]: - """Get parameters for this estimator. + """Get parameters for this recipe. Returns the parameters given in the constructor as well as the - estimators contained within the `steps` of the `Recipe`. + steps contained within the `steps` of the `Recipe`. Parameters ---------- deep : bool, default=True - If True, will return the parameters for this estimator and - contained subobjects that are estimators. + If True, will return the parameters for this recipe and + contained steps. Returns ------- @@ -531,26 +526,25 @@ def get_params(self, deep=True) -> dict[str, Any]: if not deep: return out - estimators = _name_estimators(self.steps) - out.update(estimators) + steps = _name_estimators(self.steps) + out.update(steps) - for name, estimator in estimators: - if hasattr(estimator, "get_params"): - for key, value in estimator.get_params(deep=True).items(): - out[f"{name}__{key}"] = value + for name, step in steps: + for key, value in step._get_params().items(): # noqa: SLF001 + out[f"{name}__{key}"] = value return out def set_params(self, **params): - """Set the parameters of this estimator. + """Set the parameters of this recipe. Valid parameter keys can be listed with ``get_params()``. Note that - you can directly set the parameters of the estimators contained in + you can directly set the parameters of the steps contained in `steps`. Parameters ---------- **params : dict - Parameters of this estimator or parameters of estimators contained + Parameters of this recipe or parameters of steps contained in `steps`. Parameters of the steps may be set using its name and the parameter name separated by a '__'. @@ -577,7 +571,7 @@ def set_params(self, **params): if "steps" in params: self.steps = params.pop("steps") - # 2. Replace items with estimators in params + # 2. Replace steps with steps in params estimator_name_indexes = { x: i for i, x in enumerate(name for name, _ in _name_estimators(self.steps)) } @@ -593,14 +587,14 @@ def set_params(self, **params): key, sub_key = key.split("__", maxsplit=1) if key not in valid_params: raise ValueError( - f"Invalid parameter {key!r} for estimator {self}. " + f"Invalid parameter {key!r} for recipe {self}. " f"Valid parameters are: ['steps']." ) nested_params[key][sub_key] = value for key, sub_params in nested_params.items(): - valid_params[key].set_params(**sub_params) + valid_params[key]._set_params(**sub_params) # noqa: SLF001 return self diff --git a/tests/test_common.py b/tests/test_common.py index effea3a..0119d62 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -34,7 +34,7 @@ def test_mutate_at_expr(): res = step.transform_table(t) sol = t.mutate(x=_.x.abs(), y=_.y.abs()) assert res.equals(sol) - assert list(step.get_params()) == ["expr", "inputs", "named_exprs"] + assert list(step._get_params()) == ["expr", "inputs", "named_exprs"] # noqa: SLF001 def test_mutate_at_named_exprs(): @@ -45,7 +45,7 @@ def test_mutate_at_named_exprs(): res = step.transform_table(t) sol = t.mutate(x=_.x.abs(), y=_.y.abs(), x_log=_.x.log(), y_log=_.y.log()) assert res.equals(sol) - assert list(step.get_params()) == ["expr", "inputs", "named_exprs"] + assert list(step._get_params()) == ["expr", "inputs", "named_exprs"] # noqa: SLF001 def test_mutate(): @@ -56,4 +56,4 @@ def test_mutate(): res = step.transform_table(t) sol = t.mutate(_.x.abs().name("x_abs"), y_log=lambda t: t.y.log()) assert res.equals(sol) - assert list(step.get_params()) == ["exprs", "named_exprs"] + assert list(step._get_params()) == ["exprs", "named_exprs"] # noqa: SLF001 diff --git a/tests/test_core.py b/tests/test_core.py index 2f10c02..4592971 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -380,13 +380,13 @@ def test_set_params(): # Nonexistent parameter in step with pytest.raises( ValueError, - match="Invalid parameter 'nonexistent_param' for estimator ExpandTimestamp", + match="Invalid parameter 'nonexistent_param' for step ExpandTimestamp", ): rec.set_params(expandtimestamp__nonexistent_param=True) # Nonexistent parameter of pipeline with pytest.raises( - ValueError, match="Invalid parameter 'expanddatetime' for estimator Recipe" + ValueError, match="Invalid parameter 'expanddatetime' for recipe Recipe" ): rec.set_params(expanddatetime__nonexistent_param=True) @@ -395,7 +395,7 @@ def test_set_params_passes_all_parameters(): # Make sure all parameters are passed together to set_params # of nested estimator. rec = ml.Recipe(ml.ExpandTimestamp(ml.timestamp())) - with patch.object(ml.ExpandTimestamp, "set_params") as mock_set_params: + with patch.object(ml.ExpandTimestamp, "_set_params") as mock_set_params: rec.set_params( expandtimestamp__inputs=["x", "y"], expandtimestamp__components=["day", "year", "hour"],