Skip to content

Commit

Permalink
refactor(core): privatize step param getter/setter
Browse files Browse the repository at this point in the history
  • Loading branch information
deepyaman committed Sep 16, 2024
1 parent 8b771e6 commit 6a1ec0a
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 32 deletions.
46 changes: 20 additions & 26 deletions ibis_ml/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,13 +347,8 @@ def _get_param_names(cls) -> list[str]:
# Extract and sort argument names excluding 'self'
return sorted([p.name for p in parameters])

def get_params(self, deep=True) -> dict[str, Any]:
"""Get parameters for this estimator.
Parameters
----------
deep : bool, default=True
Has no effect, because steps cannot contain nested substeps.
def _get_params(self) -> dict[str, Any]:
"""Get parameters for this step.
Returns
-------
Expand All @@ -370,8 +365,8 @@ def get_params(self, deep=True) -> dict[str, Any]:
"""
return {key: getattr(self, key) for key in self._get_param_names()}

def set_params(self, **params):
"""Set the parameters of this estimator.
def _set_params(self, **params):
"""Set the parameters of this step.
Parameters
----------
Expand Down Expand Up @@ -400,7 +395,7 @@ def set_params(self, **params):
for key, value in params.items():
if key not in valid_params:
raise ValueError(
f"Invalid parameter {key!r} for estimator {self}. "
f"Invalid parameter {key!r} for step {self}. "
f"Valid parameters are: {valid_params!r}."
)

Expand Down Expand Up @@ -503,16 +498,16 @@ def output_format(self) -> Literal["default", "pandas", "pyarrow", "polars"]:
return self._output_format

def get_params(self, deep=True) -> dict[str, Any]:
"""Get parameters for this estimator.
"""Get parameters for this recipe.
Returns the parameters given in the constructor as well as the
estimators contained within the `steps` of the `Recipe`.
steps contained within the `steps` of the `Recipe`.
Parameters
----------
deep : bool, default=True
If True, will return the parameters for this estimator and
contained subobjects that are estimators.
If True, will return the parameters for this recipe and
contained steps.
Returns
-------
Expand All @@ -531,26 +526,25 @@ def get_params(self, deep=True) -> dict[str, Any]:
if not deep:
return out

estimators = _name_estimators(self.steps)
out.update(estimators)
steps = _name_estimators(self.steps)
out.update(steps)

for name, estimator in estimators:
if hasattr(estimator, "get_params"):
for key, value in estimator.get_params(deep=True).items():
out[f"{name}__{key}"] = value
for name, step in steps:
for key, value in step._get_params().items(): # noqa: SLF001
out[f"{name}__{key}"] = value
return out

def set_params(self, **params):
"""Set the parameters of this estimator.
"""Set the parameters of this recipe.
Valid parameter keys can be listed with ``get_params()``. Note that
you can directly set the parameters of the estimators contained in
you can directly set the parameters of the steps contained in
`steps`.
Parameters
----------
**params : dict
Parameters of this estimator or parameters of estimators contained
Parameters of this recipe or parameters of steps contained
in `steps`. Parameters of the steps may be set using its name and
the parameter name separated by a '__'.
Expand All @@ -577,7 +571,7 @@ def set_params(self, **params):
if "steps" in params:
self.steps = params.pop("steps")

# 2. Replace items with estimators in params
# 2. Replace steps with steps in params
estimator_name_indexes = {
x: i for i, x in enumerate(name for name, _ in _name_estimators(self.steps))
}
Expand All @@ -593,14 +587,14 @@ def set_params(self, **params):
key, sub_key = key.split("__", maxsplit=1)
if key not in valid_params:
raise ValueError(
f"Invalid parameter {key!r} for estimator {self}. "
f"Invalid parameter {key!r} for recipe {self}. "
f"Valid parameters are: ['steps']."
)

nested_params[key][sub_key] = value

for key, sub_params in nested_params.items():
valid_params[key].set_params(**sub_params)
valid_params[key]._set_params(**sub_params) # noqa: SLF001

return self

Expand Down
6 changes: 3 additions & 3 deletions tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def test_mutate_at_expr():
res = step.transform_table(t)
sol = t.mutate(x=_.x.abs(), y=_.y.abs())
assert res.equals(sol)
assert list(step.get_params()) == ["expr", "inputs", "named_exprs"]
assert list(step._get_params()) == ["expr", "inputs", "named_exprs"] # noqa: SLF001


def test_mutate_at_named_exprs():
Expand All @@ -45,7 +45,7 @@ def test_mutate_at_named_exprs():
res = step.transform_table(t)
sol = t.mutate(x=_.x.abs(), y=_.y.abs(), x_log=_.x.log(), y_log=_.y.log())
assert res.equals(sol)
assert list(step.get_params()) == ["expr", "inputs", "named_exprs"]
assert list(step._get_params()) == ["expr", "inputs", "named_exprs"] # noqa: SLF001


def test_mutate():
Expand All @@ -56,4 +56,4 @@ def test_mutate():
res = step.transform_table(t)
sol = t.mutate(_.x.abs().name("x_abs"), y_log=lambda t: t.y.log())
assert res.equals(sol)
assert list(step.get_params()) == ["exprs", "named_exprs"]
assert list(step._get_params()) == ["exprs", "named_exprs"] # noqa: SLF001
6 changes: 3 additions & 3 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,13 +380,13 @@ def test_set_params():
# Nonexistent parameter in step
with pytest.raises(
ValueError,
match="Invalid parameter 'nonexistent_param' for estimator ExpandTimestamp",
match="Invalid parameter 'nonexistent_param' for step ExpandTimestamp",
):
rec.set_params(expandtimestamp__nonexistent_param=True)

# Nonexistent parameter of pipeline
with pytest.raises(
ValueError, match="Invalid parameter 'expanddatetime' for estimator Recipe"
ValueError, match="Invalid parameter 'expanddatetime' for recipe Recipe"
):
rec.set_params(expanddatetime__nonexistent_param=True)

Expand All @@ -395,7 +395,7 @@ def test_set_params_passes_all_parameters():
# Make sure all parameters are passed together to set_params
# of nested estimator.
rec = ml.Recipe(ml.ExpandTimestamp(ml.timestamp()))
with patch.object(ml.ExpandTimestamp, "set_params") as mock_set_params:
with patch.object(ml.ExpandTimestamp, "_set_params") as mock_set_params:
rec.set_params(
expandtimestamp__inputs=["x", "y"],
expandtimestamp__components=["day", "year", "hour"],
Expand Down

0 comments on commit 6a1ec0a

Please sign in to comment.