From be8105dbaccb11b739c59fc4f8b2cc1c86d7a331 Mon Sep 17 00:00:00 2001 From: LennartPurucker Date: Fri, 20 Jun 2025 19:13:04 +0200 Subject: [PATCH 1/2] add: perpetual boosting --- setup.py | 4 + tabrepo/benchmark/models/ag/__init__.py | 2 + .../benchmark/models/ag/perpetual/__init__.py | 0 .../models/ag/perpetual/perpetual_model.py | 123 ++++++++++++++++++ tabrepo/benchmark/models/model_register.py | 7 +- tabrepo/models/perpetual/__init__.py | 0 tabrepo/models/perpetual/generate.py | 13 ++ tabrepo/models/utils.py | 1 + tst/benchmark/models/test_perpetual.py | 25 ++++ 9 files changed, 174 insertions(+), 1 deletion(-) create mode 100644 tabrepo/benchmark/models/ag/perpetual/__init__.py create mode 100644 tabrepo/benchmark/models/ag/perpetual/perpetual_model.py create mode 100644 tabrepo/models/perpetual/__init__.py create mode 100644 tabrepo/models/perpetual/generate.py create mode 100644 tst/benchmark/models/test_perpetual.py diff --git a/setup.py b/setup.py index eab4ed0b..b6d3a28e 100644 --- a/setup.py +++ b/setup.py @@ -39,6 +39,10 @@ "modernnca": [ "category_encoders", ], + "perpetualboosting": [ + # used version: 0.9.4 + "perpetual", + ], } benchmark_requires = [] diff --git a/tabrepo/benchmark/models/ag/__init__.py b/tabrepo/benchmark/models/ag/__init__.py index 4cfded7c..d7bcaa4e 100644 --- a/tabrepo/benchmark/models/ag/__init__.py +++ b/tabrepo/benchmark/models/ag/__init__.py @@ -8,6 +8,7 @@ from tabrepo.benchmark.models.ag.tabm.tabm_model import TabMModel from tabrepo.benchmark.models.ag.tabpfnv2.tabpfnv2_client_model import TabPFNV2ClientModel from tabrepo.benchmark.models.ag.tabpfnv2.tabpfnv2_model import TabPFNV2Model +from tabrepo.benchmark.models.ag.perpetual.perpetual_model import PerpetualBoostingModel __all__ = [ "ExplainableBoostingMachineModel", @@ -18,4 +19,5 @@ "TabMModel", "TabPFNV2ClientModel", "TabPFNV2Model", + "PerpetualBoostingModel", ] diff --git a/tabrepo/benchmark/models/ag/perpetual/__init__.py 
from __future__ import annotations

from typing import TYPE_CHECKING

from autogluon.core.constants import BINARY, MULTICLASS, REGRESSION
from autogluon.core.models import AbstractModel

if TYPE_CHECKING:
    import numpy as np
    import pandas as pd
    from autogluon.core.metrics import Scorer

# Memory budget (in GB) handed to PerpetualBooster when the hyperparameters do
# not provide a ``memory_limit`` entry.
# FIXME: derive this from the resources of the surrounding job instead of
#   hard coding it; the value assumes one of 8 bagging folds.
_DEFAULT_MEMORY_LIMIT_GB = 4

# The library does not strictly enforce its memory limit, so only a fraction
# of the budget is requested.
_MEMORY_SAFETY_FACTOR = 0.95


class PerpetualBoostingModel(AbstractModel):
    """AutoGluon model wrapper around ``perpetual.PerpetualBooster``.

    The wrapper records which columns are categorical during preprocessing,
    translates the AutoGluon evaluation metric into a booster objective and
    forwards AutoGluon's time and CPU budget to the booster.
    """

    ag_key = "PB"
    ag_name = "PerpetualBoosting"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Names of the categorical columns; filled in by the first
        # preprocessing call and reused afterwards.
        self._category_features: list[str] | None = None

    def _preprocess_nonadaptive(self, X, **kwargs):
        """Run the parent preprocessing and remember the categorical columns.

        The column list is computed only once, from the first frame seen.
        """
        X = super()._preprocess_nonadaptive(X, **kwargs)

        if self._category_features is None:
            self._category_features = list(X.select_dtypes(include="category").columns)

        return X

    # TODO: support sample weights
    # TODO: API does not support a random seed, but rust code does (mhm?)
    # TODO: no GPU support (?), add warning.
    # TODO: no support for passing validation data (X_val / y_val)... problematic
    def _fit(
        self,
        X: pd.DataFrame,
        y: pd.Series,
        time_limit: float | None = None,
        num_cpus: int | str = "auto",
        sample_weight: np.ndarray | None = None,
        sample_weight_val: np.ndarray | None = None,
        **kwargs,
    ):
        """Fit a ``PerpetualBooster`` on the preprocessed training data.

        Parameters
        ----------
        X, y
            Training features and labels.
        time_limit
            Forwarded to the booster as ``timeout``.
        num_cpus
            Forwarded to the booster as ``num_threads``. The placeholder
            ``"auto"`` is translated to ``None`` so the booster picks its own
            default instead of receiving a string.
        sample_weight
            Forwarded to ``PerpetualBooster.fit``.
        sample_weight_val
            Accepted for interface compatibility; currently unused because no
            validation data is passed to the booster.

        Notes
        -----
        An optional ``memory_limit`` hyperparameter (in GB) overrides the
        hard-coded default. A 5% safety margin is applied to the value; the
        result is kept as a float so that small budgets are not floored
        (``int(4 * 0.95)`` would silently turn 3.8 GB into 3 GB).
        """
        # Preprocess data.
        X = self.preprocess(X, is_train=True)
        # Work on a copy so popping keys never alters the stored hyperparameters.
        paras = dict(self._get_model_params())

        from perpetual import PerpetualBooster

        memory_limit = paras.pop("memory_limit", _DEFAULT_MEMORY_LIMIT_GB)
        if memory_limit is not None:
            memory_limit = float(memory_limit) * _MEMORY_SAFETY_FACTOR

        num_threads = None if num_cpus == "auto" else int(num_cpus)

        self.model = PerpetualBooster(
            objective=get_metric_from_ag_metric(
                metric=self.eval_metric, problem_type=self.problem_type
            ),
            num_threads=num_threads,
            memory_limit=memory_limit,
            categorical_features=self._category_features,
            timeout=time_limit,
            # stopping_rounds - no idea how to set this
            **paras,
        )

        self.model.fit(X=X, y=y, sample_weight=sample_weight)

    def _set_default_params(self):
        """Register the default hyperparameters for values the user did not set."""
        default_params = {
            "iteration_limit": 10_000,
            # As we use a timeout this should be the max I guess.
            # 2.0 is used a lot in examples, so I guess it is the max (?).
            "budget": 2.0,
        }
        for param, val in default_params.items():
            self._set_default_param_value(param, val)

    def _get_default_auxiliary_params(self) -> dict:
        """Extend the parent's auxiliary params with the accepted raw feature types."""
        default_auxiliary_params = super()._get_default_auxiliary_params()
        extra_auxiliary_params = {
            "valid_raw_types": ["int", "float", "category"],
        }
        default_auxiliary_params.update(extra_auxiliary_params)
        return default_auxiliary_params

    @classmethod
    def supported_problem_types(cls) -> list[str] | None:
        """Return the problem types this wrapper can be fitted on."""
        return [BINARY, MULTICLASS, REGRESSION]

    # TODO: support refit full
    def _more_tags(self) -> dict:
        """Declare that refit-full is not supported yet."""
        return {"can_refit_full": False}


def get_metric_from_ag_metric(*, metric: Scorer, problem_type: str):
    """Map an AutoGluon metric to a PerpetualBooster objective name.

    Parameters
    ----------
    metric
        AutoGluon scorer; only its ``name`` attribute is read.
    problem_type
        AutoGluon problem type constant.

    Returns
    -------
    str
        ``"LogLoss"`` for every classification metric. For regression, the
        objective registered for the metric name, or ``"SquaredLoss"`` when
        the name is not in the table.

    Raises
    ------
    AssertionError
        If ``problem_type`` is not binary, multiclass or regression.
    """
    if problem_type in (BINARY, MULTICLASS):
        # Only supports log_loss for classification.
        objective_by_metric = {
            "log_loss": "LogLoss",
        }
        fallback = "LogLoss"
    elif problem_type == REGRESSION:
        objective_by_metric = {
            "mean_squared_error": "HuberLoss",  # seems to be a better match than RMSE.
            "root_mean_squared_error": "SquaredLoss",
        }
        fallback = "SquaredLoss"
    else:
        raise AssertionError(
            f"PerpetualBoosting does not support {problem_type} problem type."
        )

    return objective_by_metric.get(metric.name, fallback)
from __future__ import annotations

import pytest


def test_perpetual():
    """Fit ``PerpetualBoostingModel`` through AutoGluon's ``FitHelper`` on a tiny budget.

    Any ``ImportError`` raised while importing the helpers or while running
    the verification turns the test into a skip instead of a failure.
    """
    try:
        from autogluon.tabular.testing import FitHelper
        from tabrepo.benchmark.models.ag.perpetual.perpetual_model import (
            PerpetualBoostingModel,
        )

        # Small iteration limit and budget keep the fit short.
        FitHelper.verify_model(
            model_cls=PerpetualBoostingModel,
            model_hyperparameters={"iteration_limit": 10, "budget": 0.1},
        )
    except ImportError as import_error:
        skip_reason = (
            "Import Error, skipping test... "
            "Ensure you have the proper dependencies installed to run this test:\n"
            f"{import_error}"
        )
        pytest.skip(skip_reason)
needs to be added (as callback) def _fit( self, X: pd.DataFrame, @@ -49,31 +52,51 @@ def _fit( from perpetual import PerpetualBooster - # FIXME: get correct number within ray job, - # here we know it is one of 8 folds now. - # -> This is also not correct if not bagged mode is used, or outer - # memory limit is not set on SLURM. - # memory_limit = ResourceManager().get_available_virtual_mem(format="GB") / 8 - # the above does not work as env var is not set in ray job... - # need to pass this to the model kwargs as well - # future work... hard coding for now. - memory_limit = 4 - - # safety as memory limit is not strictly enforced by PB + # ----- The below is hacky workaround to set the memory limit + # TODO: set this in the outer scope or automatically here + sequential_fitting = False + if sequential_fitting: + memory_limit = 32 # all memory for the job, in GB + # 1 if sequential fitting is used + # otherwise memory limit is not enforced across threads. + num_cpus = 1 + + else: + # here we know it is one of 8 folds now. + # at this stage, num_cpus == 1 anyhow. + memory_limit = 4 + + # FIXME: does not work as env var is not set in ray job... + # need to pass this to the model kwargs as well + # future work... hard coding for now. + # memory_limit = ResourceManager().get_available_virtual_mem(format="GB") / 8 + + # ---- Additional bug + # FIXME: with a lot of categorical features, the memory limit is + # not enforced correctly. No way to control this as far as I can tell. + # - Example: to make this code run on 363711 (MIC) I set the limit to 32 GB + # but had to give it 64 GB to not OOM. If I set the limit to 64 GB, it uses + # 90ish GB and OOMs the job. + # memory_limit = 32 + # Even then, it still hangs on prediction and runs out of time. + + # safety as memory limit should be quicker than cgroups. 
memory_limit = int(memory_limit * 0.95) - self.model = PerpetualBooster( objective=get_metric_from_ag_metric( metric=self.eval_metric, problem_type=self.problem_type ), num_threads=num_cpus, - memory_limit=memory_limit, + memory_limit=memory_limit, # TODO: limit per thread? categorical_features=self._category_features, + # FIXME: time limit is also not strictly enforced, check when the + # loop is stopped and if preprocessing counts towards the limit. timeout=time_limit, # stopping_rounds - no idea how to set this **paras, ) + # FIXME: why does the out-of-budget message show up several times? self.model.fit(X=X, y=y, sample_weight=sample_weight) def _set_default_params(self):