From 8218de3eec0308e501d61f73d38861bff80ab01e Mon Sep 17 00:00:00 2001
From: KohlerHECTOR
Date: Thu, 26 Jun 2025 12:14:01 +0200
Subject: [PATCH 01/11] first commit to add dpdt

---
 setup.py                                     |   5 +
 tabrepo/benchmark/models/ag/dpdt/__init__.py |   0
 .../benchmark/models/ag/dpdt/dpdt_model.py   | 124 ++++++++++++++++++
 3 files changed, 129 insertions(+)
 create mode 100644 tabrepo/benchmark/models/ag/dpdt/__init__.py
 create mode 100644 tabrepo/benchmark/models/ag/dpdt/dpdt_model.py

diff --git a/setup.py b/setup.py
index eab4ed0be..fd1c5b4f3 100644
--- a/setup.py
+++ b/setup.py
@@ -39,6 +39,10 @@
     "modernnca": [
         "category_encoders",
     ],
+    "dpdt": [
+        # TODO: pypi package is not available yet
+        "dpdt @ git+https://github.com/KohlerHECTOR/DPDTreeEstimator.git",
+    ],
 }
 
 benchmark_requires = []
@@ -51,6 +55,7 @@
         "tabdpt",
         "tabm",
         "modernnca",
+        "dpdt",
     ]:
         benchmark_requires += extras_require[extra_package]
 benchmark_requires = list(set(benchmark_requires))
diff --git a/tabrepo/benchmark/models/ag/dpdt/__init__.py b/tabrepo/benchmark/models/ag/dpdt/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tabrepo/benchmark/models/ag/dpdt/dpdt_model.py b/tabrepo/benchmark/models/ag/dpdt/dpdt_model.py
new file mode 100644
index 000000000..46ae10262
--- /dev/null
+++ b/tabrepo/benchmark/models/ag/dpdt/dpdt_model.py
@@ -0,0 +1,124 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import numpy as np
+from autogluon.core.models import AbstractModel
+
+if TYPE_CHECKING:
+    import pandas as pd
+
+
+class CustomRandomForestModel(AbstractModel):
+    ag_key = "DPDT"
+    ag_name = "dpdt"
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self._feature_generator = None
+
+    def _preprocess(self, X: pd.DataFrame, **kwargs) -> np.ndarray:
+        X = super()._preprocess(X, **kwargs)
+        return X.to_numpy()
+
+    def _fit(
+        self,
+        X: pd.DataFrame,  # training data
+        y: pd.Series,  # training labels
+        # X_val=None,  # val data (unused in RF model)
+        # y_val=None,  # val labels (unused in RF model)
+        # time_limit=None,  # time limit in seconds (ignored in tutorial)
+        num_cpus: int = 1,  # number of CPUs to use for training
+        # num_gpus: int = 0,  # number of GPUs to use for training
+        **kwargs,  # kwargs includes many other potential inputs, refer to AbstractModel documentation for details
+    ):
+        # Select model class
+        if self.problem_type in ["regression"]:
+            from dpdt import DPDTreeRegressor
+
+            model_cls = DPDTreeRegressor
+        else:
+            from dpdt import DPDTreeClassifier
+
+            # case for 'binary' and 'multiclass',
+            model_cls = DPDTreeClassifier
+
+        X = self.preprocess(X)
+        y = self.preprocess(y)
+        params = self._get_model_params()
+        self.model = model_cls(**params)
+        self.model.fit(X, y)
+
+    def _set_default_params(self):
+        """Default parameters for the model."""
+        default_params = {
+            "max_depth": 10,
+            "n_jobs": -1,
+            "random_state": 0,
+            "cart_nodes_list": (8, 3),
+        }
+        for param, val in default_params.items():
+            self._set_default_param_value(param, val)
+
+    def _get_default_auxiliary_params(self) -> dict:
+        """Specifies the allowed input data types; all other dtypes should be
+        handled by the model-agnostic preprocessor.
+        """
+        default_auxiliary_params = super()._get_default_auxiliary_params()
+        extra_auxiliary_params = {
+            "valid_raw_types": ["int", "float", "category"],
+        }
+        default_auxiliary_params.update(extra_auxiliary_params)
+        return default_auxiliary_params
+
+
+# def get_configs_for_custom_rf(
+#     *,
+#     default_config: bool = True,
+#     num_random_configs: int = 1,
+#     sequential_fold_fitting: bool = False,
+# ):
+#     """Generate the hyperparameter configurations to run for our custom random
+#     forest model.
+
+#     sequential_fold_fitting: bool = False
+#         If True, the model will be configured to use sequential
+#         fold fitting (better for debugging, but usually slower). This is also a good
+#         idea to use on SLURM or other shared compute clusters where you want to run
+#         multiple jobs on the same node.
+#         See `tabflow_slurm.run_tabarena_experiment.setup_slurm_job` for ways to
+#         optimally use sequential_fold_fitting=False on SLURM.
+#     """
+#     from autogluon.common.space import Int
+#     from tabrepo.utils.config_utils import ConfigGenerator

+#     manual_configs = [
+#         {},
+#     ]
+#     search_space = {
+#         "n_estimators": Int(4, 50),
+#     }

+#     gen_custom_rf = ConfigGenerator(
+#         model_cls=CustomRandomForestModel,
+#         manual_configs=manual_configs if default_config else None,
+#         search_space=search_space,
+#     )
+#     experiments_lst = gen_custom_rf.generate_all_bag_experiments(
+#         num_random_configs=num_random_configs
+#     )

+#     if sequential_fold_fitting:
+#         for m_i in range(len(experiments_lst)):
+#             if (
+#                 "ag_args_ensemble"
+#                 not in experiments_lst[m_i].method_kwargs["model_hyperparameters"]
+#             ):
+#                 experiments_lst[m_i].method_kwargs["model_hyperparameters"][
+#                     "ag_args_ensemble"
+#                 ] = {}
+#             experiments_lst[m_i].method_kwargs["model_hyperparameters"][
+#                 "ag_args_ensemble"
+#             ]["fold_fitting_strategy"] = "sequential_local"

+#     return experiments_lst
\ No newline at end of file
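
The class above follows AutoGluon's custom-model pattern: _preprocess produces the numpy input for the wrapped estimator, _fit instantiates and fits it, and _get_model_params pulls the defaults registered in _set_default_params. A minimal smoke test for such a model, sketched under the assumption that AutoGluon and the dpdt package are installed; FitHelper.verify_model is the same utility the test added later in this series relies on:

    from autogluon.tabular.testing import FitHelper

    from tabrepo.benchmark.models.ag.dpdt.dpdt_model import CustomRandomForestModel

    # Fits the model on small synthetic datasets and exercises the
    # fit/predict path end to end.
    FitHelper.verify_model(
        model_cls=CustomRandomForestModel,
        model_hyperparameters={"max_depth": 3, "cart_nodes_list": (4, 3)},
    )
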
+ """ + default_auxiliary_params = super()._get_default_auxiliary_params() + extra_auxiliary_params = { + "valid_raw_types": ["int", "float", "category"], + } + default_auxiliary_params.update(extra_auxiliary_params) + return default_auxiliary_params + + +# def get_configs_for_custom_rf( +# *, +# default_config: bool = True, +# num_random_configs: int = 1, +# sequential_fold_fitting: bool = False, +# ): +# """Generate the hyperparameter configurations to run for our custom random +# forest model. + +# sequential_fold_fitting: bool = False +# If True, the model will be configured to use sequential +# fold fitting (better for debugging, but usually slower). This is also a good +# idea to use on SLURM or other shared compute clusters where you want to run +# multiple jobs on the same node. +# See `tabflow_slurm.run_tabarena_experiment.setup_slurm_job` for ways to +# optimally use sequential_fold_fitting=False on SLURM. +# """ +# from autogluon.common.space import Int +# from tabrepo.utils.config_utils import ConfigGenerator + +# manual_configs = [ +# {}, +# ] +# search_space = { +# "n_estimators": Int(4, 50), +# } + +# gen_custom_rf = ConfigGenerator( +# model_cls=CustomRandomForestModel, +# manual_configs=manual_configs if default_config else None, +# search_space=search_space, +# ) +# experiments_lst = gen_custom_rf.generate_all_bag_experiments( +# num_random_configs=num_random_configs +# ) + +# if sequential_fold_fitting: +# for m_i in range(len(experiments_lst)): +# if ( +# "ag_args_ensemble" +# not in experiments_lst[m_i].method_kwargs["model_hyperparameters"] +# ): +# experiments_lst[m_i].method_kwargs["model_hyperparameters"][ +# "ag_args_ensemble" +# ] = {} +# experiments_lst[m_i].method_kwargs["model_hyperparameters"][ +# "ag_args_ensemble" +# ]["fold_fitting_strategy"] = "sequential_local" + +# return experiments_lst \ No newline at end of file From 38d660f02e3196f8272f198877d7698419d827aa Mon Sep 17 00:00:00 2001 From: KohlerHECTOR Date: Sun, 29 Jun 2025 21:55:24 +0200 Subject: [PATCH 02/11] Changed to AdaBoostDPDT to have compat with predict_proba --- tabrepo/benchmark/models/ag/dpdt/dpdt_model.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/tabrepo/benchmark/models/ag/dpdt/dpdt_model.py b/tabrepo/benchmark/models/ag/dpdt/dpdt_model.py index 46ae10262..ef5f2646b 100644 --- a/tabrepo/benchmark/models/ag/dpdt/dpdt_model.py +++ b/tabrepo/benchmark/models/ag/dpdt/dpdt_model.py @@ -10,8 +10,8 @@ class CustomRandomForestModel(AbstractModel): - ag_key = "DPDT" - ag_name = "dpdt" + ag_key = "BOOSTEDDPDT" + ag_name = "boosted_dpdt" def __init__(self, **kwargs): super().__init__(**kwargs) @@ -34,17 +34,14 @@ def _fit( ): # Select model class if self.problem_type in ["regression"]: - from dpdt import DPDTreeRegressor - - model_cls = DPDTreeRegressor + raise AssertionError, "Boosted DPDT does not support regression yet" else: - from dpdt import DPDTreeClassifier + from dpdt import AdaBoostDPDT # case for 'binary' and 'multiclass', - model_cls = DPDTreeClassifier + model_cls = AdaBoostDPDT X = self.preprocess(X) - y = self.preprocess(y) params = self._get_model_params() self.model = model_cls(**params) self.model.fit(X, y) From 3738c816ed7c6b5df0c99c3ade5926aefee73f93 Mon Sep 17 00:00:00 2001 From: KohlerHECTOR Date: Mon, 30 Jun 2025 01:29:40 +0200 Subject: [PATCH 03/11] Progress on BoostedDPDT; added memory estimate (estimators * dpdt tree memory estimate) --- .../benchmark/models/ag/dpdt/dpdt_model.py | 167 +++++++----------- 1 file changed, 67 
From 3738c816ed7c6b5df0c99c3ade5926aefee73f93 Mon Sep 17 00:00:00 2001
From: KohlerHECTOR
Date: Mon, 30 Jun 2025 01:29:40 +0200
Subject: [PATCH 03/11] Progress on BoostedDPDT; added memory estimate
 (estimators * dpdt tree memory estimate)

---
 .../benchmark/models/ag/dpdt/dpdt_model.py | 167 ++++++++-----------
 1 file changed, 67 insertions(+), 100 deletions(-)

diff --git a/tabrepo/benchmark/models/ag/dpdt/dpdt_model.py b/tabrepo/benchmark/models/ag/dpdt/dpdt_model.py
index ef5f2646b..e168d541a 100644
--- a/tabrepo/benchmark/models/ag/dpdt/dpdt_model.py
+++ b/tabrepo/benchmark/models/ag/dpdt/dpdt_model.py
@@ -2,120 +2,87 @@
 
 from typing import TYPE_CHECKING
 
-import numpy as np
+from autogluon.common.utils.pandas_utils import get_approximate_df_mem_usage
+from autogluon.common.utils.resource_utils import ResourceManager
 from autogluon.core.models import AbstractModel
 
 if TYPE_CHECKING:
     import pandas as pd
 
 
-class CustomRandomForestModel(AbstractModel):
-    ag_key = "BOOSTEDDPDT"
-    ag_name = "boosted_dpdt"
+class BoostedDPDT(AbstractModel):
+    ag_key = "ADABOOSTDPDT"
+    ag_name = "adaboost_dpdt"
 
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
-        self._feature_generator = None
+    def get_model_cls(self):
+        from dpdt import AdaBoostDPDT
 
-    def _preprocess(self, X: pd.DataFrame, **kwargs) -> np.ndarray:
-        X = super()._preprocess(X, **kwargs)
-        return X.to_numpy()
-
-    def _fit(
-        self,
-        X: pd.DataFrame,  # training data
-        y: pd.Series,  # training labels
-        # X_val=None,  # val data (unused in RF model)
-        # y_val=None,  # val labels (unused in RF model)
-        # time_limit=None,  # time limit in seconds (ignored in tutorial)
-        num_cpus: int = 1,  # number of CPUs to use for training
-        # num_gpus: int = 0,  # number of GPUs to use for training
-        **kwargs,  # kwargs includes many other potential inputs, refer to AbstractModel documentation for details
-    ):
-        # Select model class
-        if self.problem_type in ["regression"]:
-            raise AssertionError("Boosted DPDT does not support regression yet")
-        else:
-            from dpdt import AdaBoostDPDT
-
-            # case for 'binary' and 'multiclass',
+        if self.problem_type in ["binary", "multiclass"]:
             model_cls = AdaBoostDPDT
-
+        else:
+            raise AssertionError(f"Unsupported problem_type: {self.problem_type}")
+        return model_cls
+
+    def _fit(self, X: pd.DataFrame, y: pd.Series, num_cpus: int = 1, **kwargs):
+        model_cls = self.get_model_cls()
+        hyp = self._get_model_params()
+        if num_cpus < 1:
+            num_cpus = 'best'
+        self.model = model_cls(
+            **hyp,
+            n_jobs=num_cpus,
+        )
         X = self.preprocess(X)
+        self.model = self.model.fit(
+            X=X,
+            y=y,
+        )
-        params = self._get_model_params()
-        self.model = model_cls(**params)
-        self.model.fit(X, y)
 
     def _set_default_params(self):
-        """Default parameters for the model."""
         default_params = {
-            "max_depth": 10,
-            "n_jobs": -1,
-            "random_state": 0,
-            "cart_nodes_list": (8, 3),
+            "random_state": 42,
         }
         for param, val in default_params.items():
             self._set_default_param_value(param, val)
 
-    def _get_default_auxiliary_params(self) -> dict:
-        """Specifies the allowed input data types; all other dtypes should be
-        handled by the model-agnostic preprocessor.
-        """
-        default_auxiliary_params = super()._get_default_auxiliary_params()
-        extra_auxiliary_params = {
-            "valid_raw_types": ["int", "float", "category"],
-        }
-        default_auxiliary_params.update(extra_auxiliary_params)
-        return default_auxiliary_params
-
-
-# def get_configs_for_custom_rf(
-#     *,
-#     default_config: bool = True,
-#     num_random_configs: int = 1,
-#     sequential_fold_fitting: bool = False,
-# ):
-#     """Generate the hyperparameter configurations to run for our custom random
-#     forest model.

-#     sequential_fold_fitting: bool = False
-#         If True, the model will be configured to use sequential
-#         fold fitting (better for debugging, but usually slower). This is also a good
-#         idea to use on SLURM or other shared compute clusters where you want to run
-#         multiple jobs on the same node.
-#         See `tabflow_slurm.run_tabarena_experiment.setup_slurm_job` for ways to
-#         optimally use sequential_fold_fitting=False on SLURM.
-#     """
-#     from autogluon.common.space import Int
-#     from tabrepo.utils.config_utils import ConfigGenerator

-#     manual_configs = [
-#         {},
-#     ]
-#     search_space = {
-#         "n_estimators": Int(4, 50),
-#     }

-#     gen_custom_rf = ConfigGenerator(
-#         model_cls=CustomRandomForestModel,
-#         manual_configs=manual_configs if default_config else None,
-#         search_space=search_space,
-#     )
-#     experiments_lst = gen_custom_rf.generate_all_bag_experiments(
-#         num_random_configs=num_random_configs
-#     )

-#     if sequential_fold_fitting:
-#         for m_i in range(len(experiments_lst)):
-#             if (
-#                 "ag_args_ensemble"
-#                 not in experiments_lst[m_i].method_kwargs["model_hyperparameters"]
-#             ):
-#                 experiments_lst[m_i].method_kwargs["model_hyperparameters"][
-#                     "ag_args_ensemble"
-#                 ] = {}
-#             experiments_lst[m_i].method_kwargs["model_hyperparameters"][
-#                 "ag_args_ensemble"
-#             ]["fold_fitting_strategy"] = "sequential_local"

-#     return experiments_lst
\ No newline at end of file
+    @classmethod
+    def supported_problem_types(cls) -> list[str] | None:
+        return ["binary", "multiclass"]
+
+    def _get_default_resources(self) -> tuple[int, int]:
+        import torch
+        # logical=False is faster in training
+        num_cpus = ResourceManager.get_cpu_count_psutil(logical=False)
+        num_gpus = 0
+        return num_cpus, num_gpus
+
+    def _estimate_memory_usage(self, X: pd.DataFrame, **kwargs) -> int:
+        hyperparameters = self._get_model_params()
+        return self.estimate_memory_usage_static(X=X, problem_type=self.problem_type, num_classes=self.num_classes, hyperparameters=hyperparameters, **kwargs)
+
+    @classmethod
+    def _estimate_memory_usage_static(
+        cls,
+        *,
+        X: pd.DataFrame,
+        hyperparameters: dict = None,
+        **kwargs,
+    ) -> int:
+        if hyperparameters is None:
+            hyperparameters = {}
+
+        dataset_size_mem_est = 5 * hyperparameters.get('n_estimators') * hyperparameters.get('cart_nodes_list')[0] * get_approximate_df_mem_usage(X).sum()
+        baseline_overhead_mem_est = 3e8  # 300 MB generic overhead
+
+        mem_estimate = dataset_size_mem_est + baseline_overhead_mem_est
+
+        return mem_estimate
+
+    @classmethod
+    def _class_tags(cls):
+        return {"can_estimate_memory_usage_static": True}
+
+    def _more_tags(self) -> dict:
+        """DPDT does not yet support refit full."""
+        return {"can_refit_full": False}
\ No newline at end of file
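
The heuristic in _estimate_memory_usage_static scales with both the number of boosting rounds and the width of the first CART layer, so it grows quickly. A worked example with illustrative numbers (the constants are the ones from this commit; later commits in the series revise them):

    n_estimators = 1000
    first_layer_nodes = 8  # cart_nodes_list[0]
    df_mem_bytes = 8e6     # suppose X occupies roughly 8 MB in memory

    dataset_size_mem_est = 5 * n_estimators * first_layer_nodes * df_mem_bytes
    baseline_overhead_mem_est = 3e8  # 300 MB generic overhead

    total = dataset_size_mem_est + baseline_overhead_mem_est
    print(f"{total / 1e9:.1f} GB")  # 320.3 GB for these inputs
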
From 250df43b49cefbbed28be1ccac68be7266e9453f Mon Sep 17 00:00:00 2001
From: KohlerHECTOR
Date: Mon, 30 Jun 2025 02:05:00 +0200
Subject: [PATCH 04/11] added configs and registered model and tests

---
 tabrepo/benchmark/models/ag/__init__.py    |  2 +
 .../benchmark/models/ag/dpdt/dpdt_model.py |  6 +-
 tabrepo/benchmark/models/model_register.py |  2 +
 tabrepo/models/dpdt/__init__.py            |  0
 tabrepo/models/dpdt/generate.py            | 58 +++++++++++++++++++
 tabrepo/models/utils.py                    |  1 +
 tst/benchmark/models/test_dpdt.py          | 17 ++++++
 7 files changed, 83 insertions(+), 3 deletions(-)
 create mode 100644 tabrepo/models/dpdt/__init__.py
 create mode 100644 tabrepo/models/dpdt/generate.py
 create mode 100644 tst/benchmark/models/test_dpdt.py

diff --git a/tabrepo/benchmark/models/ag/__init__.py b/tabrepo/benchmark/models/ag/__init__.py
index 4cfded7c6..03e3b3484 100644
--- a/tabrepo/benchmark/models/ag/__init__.py
+++ b/tabrepo/benchmark/models/ag/__init__.py
@@ -8,6 +8,7 @@
 from tabrepo.benchmark.models.ag.tabm.tabm_model import TabMModel
 from tabrepo.benchmark.models.ag.tabpfnv2.tabpfnv2_client_model import TabPFNV2ClientModel
 from tabrepo.benchmark.models.ag.tabpfnv2.tabpfnv2_model import TabPFNV2Model
+from tabrepo.benchmark.models.ag.dpdt.dpdt_model import BoostedDPDTModel
 
 __all__ = [
     "ExplainableBoostingMachineModel",
@@ -18,4 +19,5 @@
     "TabMModel",
     "TabPFNV2ClientModel",
     "TabPFNV2Model",
+    "BoostedDPDTModel",
 ]
diff --git a/tabrepo/benchmark/models/ag/dpdt/dpdt_model.py b/tabrepo/benchmark/models/ag/dpdt/dpdt_model.py
index e168d541a..c181c9803 100644
--- a/tabrepo/benchmark/models/ag/dpdt/dpdt_model.py
+++ b/tabrepo/benchmark/models/ag/dpdt/dpdt_model.py
@@ -10,9 +10,9 @@
     import pandas as pd
 
 
-class BoostedDPDT(AbstractModel):
-    ag_key = "ADABOOSTDPDT"
-    ag_name = "adaboost_dpdt"
+class BoostedDPDTModel(AbstractModel):
+    ag_key = "BOOSTEDDPDT"
+    ag_name = "boosted_dpdt"
 
     def get_model_cls(self):
         from dpdt import AdaBoostDPDT
diff --git a/tabrepo/benchmark/models/model_register.py b/tabrepo/benchmark/models/model_register.py
index 9066e788d..15e7a5940 100644
--- a/tabrepo/benchmark/models/model_register.py
+++ b/tabrepo/benchmark/models/model_register.py
@@ -13,6 +13,7 @@
     TabMModel,
     TabPFNV2ClientModel,
     TabPFNV2Model,
+    BoostedDPDTModel,
 )
 
 tabrepo_model_register: ModelRegistry = copy.deepcopy(ag_model_registry)
@@ -26,6 +27,7 @@
     TabDPTModel,
     TabMModel,
     ModernNCAModel,
+    BoostedDPDTModel,
 ]
 
 for _model_cls in _models_to_add:
diff --git a/tabrepo/models/dpdt/__init__.py b/tabrepo/models/dpdt/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tabrepo/models/dpdt/generate.py b/tabrepo/models/dpdt/generate.py
new file mode 100644
index 000000000..ff64c2162
--- /dev/null
+++ b/tabrepo/models/dpdt/generate.py
@@ -0,0 +1,58 @@
+from autogluon.common.space import Categorical, Real, Int
+import numpy as np
+
+from tabrepo.benchmark.models.ag.dpdt.dpdt_model import BoostedDPDTModel
+from tabrepo.utils.config_utils import ConfigGenerator
+
+name = 'BoostedDPDT'
+manual_configs = [
+    {},
+]
+
+# get config from paper
+
+# Generate 1000 samples from log-normal distribution
+# Parameters: mu = log(0.01), sigma = log(10.0)
+mu = float(np.log(0.01))
+sigma = float(np.log(10.0))
+samples = np.random.lognormal(mean=mu, sigma=sigma, size=1000)
+
+# Generate 1000 samples from q_log_uniform_values distribution
+# Parameters: min=1.5, max=50.5, q=1
+min_val = 1.5
+max_val = 50.5
+q = 1
+# Generate log-uniform samples and quantize
+log_min = np.log(min_val)
+log_max = np.log(max_val)
+log_uniform_samples = np.random.uniform(log_min, log_max, size=1000)
+min_samples_leaf_samples = np.round(np.exp(log_uniform_samples) / q) * q
+min_samples_leaf_samples = np.clip(min_samples_leaf_samples, min_val, max_val).astype(int)
+
+# Generate 1000 samples for min_weight_fraction_leaf
+# Values: [0.0, 0.01], probabilities: [0.95, 0.05]
+min_weight_fraction_leaf_samples = np.random.choice([0.0, 0.01], size=1000, p=[0.95, 0.05])
+
+# Generate 1000 samples for max_features
+# Values: ["sqrt", "log2", 10000], probabilities: [0.5, 0.25, 0.25]
+max_features_samples = np.random.choice(["sqrt", "log2", 10000], size=1000, p=[0.5, 0.25, 0.25])
+
+search_space = {
+    'learning_rate': Categorical(*samples),  # log_normal distribution equivalent
+    'n_estimators': 1000,  # Fixed value as per old config
+    'max_depth': Categorical(2, 2, 2, 2, 3, 3, 3, 3, 3, 3),
+    'min_samples_split': Categorical(*np.random.choice([2, 3], size=1000, p=[0.95, 0.05])),
+    'min_impurity_decrease': Categorical(*np.random.choice([0, 0.01, 0.02, 0.05], size=1000, p=[0.85, 0.05, 0.05, 0.05])),
+    'cart_nodes_list': Categorical((32,), (8, 4), (4, 8), (16, 2), (4, 4, 2)),
+    'min_samples_leaf': Categorical(*min_samples_leaf_samples),  # q_log_uniform equivalent
+    'min_weight_fraction_leaf': Categorical(*min_weight_fraction_leaf_samples),
+    'max_features': Categorical(*max_features_samples),
+    'random_state': Categorical(0, 1, 2, 3, 4)
+}
+
+gen_boosteddpdt = ConfigGenerator(model_cls=BoostedDPDTModel, manual_configs=manual_configs, search_space=search_space)
+
+
+def generate_configs_boosted_dpdt(num_random_configs=200):
+    config_generator = ConfigGenerator(name=name, manual_configs=manual_configs, search_space=search_space)
+    return config_generator.generate_all_configs(num_random_configs=num_random_configs)
diff --git a/tabrepo/models/utils.py b/tabrepo/models/utils.py
index 5f021f4bc..5940d95c0 100644
--- a/tabrepo/models/utils.py
+++ b/tabrepo/models/utils.py
@@ -46,6 +46,7 @@
         # "TabPFN": lambda: importlib.import_module("tabrepo.models.tabpfn.generate").gen_tabpfn,  # not supported in TabArena
         "TabPFNv2": lambda: importlib.import_module("tabrepo.models.tabpfnv2.generate").gen_tabpfnv2,
         "XGBoost": lambda: importlib.import_module("tabrepo.models.xgboost.generate").gen_xgboost,
+        "BoostedDPDT": lambda: importlib.import_module("tabrepo.models.dpdt.generate").gen_boosteddpdt,
     }
 
     if model_name not in name_to_import_map:
diff --git a/tst/benchmark/models/test_dpdt.py b/tst/benchmark/models/test_dpdt.py
new file mode 100644
index 000000000..5c6061e7d
--- /dev/null
+++ b/tst/benchmark/models/test_dpdt.py
@@ -0,0 +1,17 @@
+import pytest
+
+
+def test_dpdt():
+    model_hyperparameters = {"n_estimators": 2, "cart_nodes_list": (4, 3)}
+
+    try:
+        from autogluon.tabular.testing import FitHelper
+        from tabrepo.benchmark.models.ag.tabicl.tabicl_model import BoostedDPDTModel
+        model_cls = BoostedDPDTModel
+        FitHelper.verify_model(model_cls=model_cls, model_hyperparameters=model_hyperparameters)
+    except ImportError as err:
+        pytest.skip(
+            f"Import Error, skipping test... "
+            f"Ensure you have the proper dependencies installed to run this test:\n"
+            f"{err}"
+        )
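
The search space above encodes the continuous priors from the paper as large Categorical pools of pre-drawn samples, since the config generator only picks uniformly among discrete choices. The trick in isolation, as a sketch using only numpy:

    import numpy as np

    rng = np.random.default_rng(0)

    # Log-normal prior for learning_rate: exp(N(log 0.01, (log 10)^2)).
    lr_pool = rng.lognormal(mean=np.log(0.01), sigma=np.log(10.0), size=1000)

    # Picking uniformly from the pre-sampled pool approximates drawing
    # from the continuous distribution directly.
    print(rng.choice(lr_pool))
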
From 0534f8e6ba76f82b2000266cb40fddb0cd4b56fd Mon Sep 17 00:00:00 2001
From: KohlerHECTOR
Date: Mon, 30 Jun 2025 02:17:48 +0200
Subject: [PATCH 05/11] Fixed some typos

---
 tabrepo/benchmark/models/ag/dpdt/dpdt_model.py | 2 +-
 tabrepo/models/dpdt/generate.py                | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/tabrepo/benchmark/models/ag/dpdt/dpdt_model.py b/tabrepo/benchmark/models/ag/dpdt/dpdt_model.py
index c181c9803..a8879b2cc 100644
--- a/tabrepo/benchmark/models/ag/dpdt/dpdt_model.py
+++ b/tabrepo/benchmark/models/ag/dpdt/dpdt_model.py
@@ -72,7 +72,7 @@ def _estimate_memory_usage_static(
         if hyperparameters is None:
             hyperparameters = {}
 
-        dataset_size_mem_est = 5 * hyperparameters.get('n_estimators') * hyperparameters.get('cart_nodes_list')[0] * get_approximate_df_mem_usage(X).sum()
+        dataset_size_mem_est = 10 * hyperparameters.get('cart_nodes_list')[0] * get_approximate_df_mem_usage(X).sum()
         baseline_overhead_mem_est = 3e8  # 300 MB generic overhead
 
         mem_estimate = dataset_size_mem_est + baseline_overhead_mem_est
diff --git a/tabrepo/models/dpdt/generate.py b/tabrepo/models/dpdt/generate.py
index ff64c2162..235aed37f 100644
--- a/tabrepo/models/dpdt/generate.py
+++ b/tabrepo/models/dpdt/generate.py
@@ -43,7 +43,7 @@
     'max_depth': Categorical(2, 2, 2, 2, 3, 3, 3, 3, 3, 3),
     'min_samples_split': Categorical(*np.random.choice([2, 3], size=1000, p=[0.95, 0.05])),
     'min_impurity_decrease': Categorical(*np.random.choice([0, 0.01, 0.02, 0.05], size=1000, p=[0.85, 0.05, 0.05, 0.05])),
-    'cart_nodes_list': Categorical((32,), (8, 4), (4, 8), (16, 2), (4, 4, 2)),
+    'cart_nodes_list': Categorical((8, 4), (4, 8), (16, 2), (4, 4, 2)),
     'min_samples_leaf': Categorical(*min_samples_leaf_samples),  # q_log_uniform equivalent
     'min_weight_fraction_leaf': Categorical(*min_weight_fraction_leaf_samples),
     'max_features': Categorical(*max_features_samples),
     'random_state': Categorical(0, 1, 2, 3, 4)
From c8f8df779af9bba60bf0f0f598bc0670b4199432 Mon Sep 17 00:00:00 2001
From: KohlerHECTOR
Date: Mon, 30 Jun 2025 19:03:48 +0200
Subject: [PATCH 06/11] Fixed some cpus stuff

---
 tabrepo/benchmark/models/ag/dpdt/dpdt_model.py | 7 +++----
 tabrepo/models/dpdt/generate.py                | 1 -
 2 files changed, 3 insertions(+), 5 deletions(-)

diff --git a/tabrepo/benchmark/models/ag/dpdt/dpdt_model.py b/tabrepo/benchmark/models/ag/dpdt/dpdt_model.py
index a8879b2cc..c149366d6 100644
--- a/tabrepo/benchmark/models/ag/dpdt/dpdt_model.py
+++ b/tabrepo/benchmark/models/ag/dpdt/dpdt_model.py
@@ -26,11 +26,11 @@
     def _fit(self, X: pd.DataFrame, y: pd.Series, num_cpus: int = 1, **kwargs):
         model_cls = self.get_model_cls()
         hyp = self._get_model_params()
-        if num_cpus < 1:
-            num_cpus = 'best'
+
         self.model = model_cls(
             **hyp,
-            n_jobs=num_cpus,
+            n_jobs='best',
+            n_estimators=1000,
         )
         X = self.preprocess(X)
         self.model = self.model.fit(
@@ -51,7 +51,6 @@
         return ["binary", "multiclass"]
 
     def _get_default_resources(self) -> tuple[int, int]:
-        import torch
         # logical=False is faster in training
         num_cpus = ResourceManager.get_cpu_count_psutil(logical=False)
         num_gpus = 0
diff --git a/tabrepo/models/dpdt/generate.py b/tabrepo/models/dpdt/generate.py
index 235aed37f..e354576dc 100644
--- a/tabrepo/models/dpdt/generate.py
+++ b/tabrepo/models/dpdt/generate.py
@@ -39,7 +39,6 @@
 search_space = {
     'learning_rate': Categorical(*samples),  # log_normal distribution equivalent
-    'n_estimators': 1000,  # Fixed value as per old config
     'max_depth': Categorical(2, 2, 2, 2, 3, 3, 3, 3, 3, 3),
     'min_samples_split': Categorical(*np.random.choice([2, 3], size=1000, p=[0.95, 0.05])),
     'min_impurity_decrease': Categorical(*np.random.choice([0, 0.01, 0.02, 0.05], size=1000, p=[0.85, 0.05, 0.05, 0.05])),
" + f"Ensure you have the proper dependencies installed to run this test:\n" + f"{err}" + ) From 0534f8e6ba76f82b2000266cb40fddb0cd4b56fd Mon Sep 17 00:00:00 2001 From: KohlerHECTOR Date: Mon, 30 Jun 2025 02:17:48 +0200 Subject: [PATCH 05/11] Fixed some typos --- tabrepo/benchmark/models/ag/dpdt/dpdt_model.py | 2 +- tabrepo/models/dpdt/generate.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tabrepo/benchmark/models/ag/dpdt/dpdt_model.py b/tabrepo/benchmark/models/ag/dpdt/dpdt_model.py index c181c9803..a8879b2cc 100644 --- a/tabrepo/benchmark/models/ag/dpdt/dpdt_model.py +++ b/tabrepo/benchmark/models/ag/dpdt/dpdt_model.py @@ -72,7 +72,7 @@ def _estimate_memory_usage_static( if hyperparameters is None: hyperparameters = {} - dataset_size_mem_est = 5 * hyperparameters.get('n_estimators') * hyperparameters.get('cart_nodes_list')[0] * get_approximate_df_mem_usage(X).sum() + dataset_size_mem_est = 10 * hyperparameters.get('cart_nodes_list')[0] * get_approximate_df_mem_usage(X).sum() baseline_overhead_mem_est = 3e8 # 300 MB generic overhead mem_estimate = dataset_size_mem_est + baseline_overhead_mem_est diff --git a/tabrepo/models/dpdt/generate.py b/tabrepo/models/dpdt/generate.py index ff64c2162..235aed37f 100644 --- a/tabrepo/models/dpdt/generate.py +++ b/tabrepo/models/dpdt/generate.py @@ -43,7 +43,7 @@ 'max_depth': Categorical(2, 2, 2, 2, 3, 3, 3, 3, 3, 3), 'min_samples_split': Categorical(*np.random.choice([2, 3], size=1000, p=[0.95, 0.05])), 'min_impurity_decrease': Categorical(*np.random.choice([0, 0.01, 0.02, 0.05], size=1000, p=[0.85, 0.05, 0.05, 0.05])), - 'cart_nodes_list': Categorical((32,), (8, 4), (4, 8), (16, 2), (4, 4, 2)), + 'cart_nodes_list': Categorical((8, 4), (4, 8), (16, 2), (4, 4, 2)), 'min_samples_leaf': Categorical(*min_samples_leaf_samples), # q_log_uniform equivalent 'min_weight_fraction_leaf': Categorical(*min_weight_fraction_leaf_samples), 'max_features': Categorical(*max_features_samples), From c8f8df779af9bba60bf0f0f598bc0670b4199432 Mon Sep 17 00:00:00 2001 From: KohlerHECTOR Date: Mon, 30 Jun 2025 19:03:48 +0200 Subject: [PATCH 06/11] Fixed some cpus stuff --- tabrepo/benchmark/models/ag/dpdt/dpdt_model.py | 7 +++---- tabrepo/models/dpdt/generate.py | 1 - 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/tabrepo/benchmark/models/ag/dpdt/dpdt_model.py b/tabrepo/benchmark/models/ag/dpdt/dpdt_model.py index a8879b2cc..c149366d6 100644 --- a/tabrepo/benchmark/models/ag/dpdt/dpdt_model.py +++ b/tabrepo/benchmark/models/ag/dpdt/dpdt_model.py @@ -26,11 +26,11 @@ def get_model_cls(self): def _fit(self, X: pd.DataFrame, y: pd.Series, num_cpus: int = 1, **kwargs): model_cls = self.get_model_cls() hyp = self._get_model_params() - if num_cpus < 1: - num_cpus = 'best' + self.model = model_cls( **hyp, - n_jobs=num_cpus, + n_jobs='best', + n_estimators=1000, ) X = self.preprocess(X) self.model = self.model.fit( @@ -51,7 +51,6 @@ def supported_problem_types(cls) -> list[str] | None: return ["binary", "multiclass"] def _get_default_resources(self) -> tuple[int, int]: - import torch # logical=False is faster in training num_cpus = ResourceManager.get_cpu_count_psutil(logical=False) num_gpus = 0 diff --git a/tabrepo/models/dpdt/generate.py b/tabrepo/models/dpdt/generate.py index 235aed37f..e354576dc 100644 --- a/tabrepo/models/dpdt/generate.py +++ b/tabrepo/models/dpdt/generate.py @@ -39,7 +39,6 @@ search_space = { 'learning_rate': Categorical(*samples), # log_normal distribution equivalent - 'n_estimators': 1000, # Fixed value as per old config 
From 2096a9bad635aa814b4b8d4a78c31ff95a4eab78 Mon Sep 17 00:00:00 2001
From: LennartPurucker
Date: Fri, 18 Jul 2025 22:47:36 +0200
Subject: [PATCH 08/11] maint: minor refactor and make test run

---
 setup.py                                   |  1 +
 .../benchmark/models/ag/dpdt/dpdt_model.py | 42 ++++++++++++------
 tst/benchmark/models/test_dpdt.py          |  2 +-
 3 files changed, 30 insertions(+), 15 deletions(-)

diff --git a/setup.py b/setup.py
index 2d1e4a62e..5f20c8704 100644
--- a/setup.py
+++ b/setup.py
@@ -44,6 +44,7 @@
     "dpdt": [
         # TODO: pypi package is not available yet
         "dpdt @ git+https://github.com/KohlerHECTOR/DPDTreeEstimator.git",
+        # used hash: a74791d2190da27b43accd4da9e7d141380326ea
     ],
 }
 
diff --git a/tabrepo/benchmark/models/ag/dpdt/dpdt_model.py b/tabrepo/benchmark/models/ag/dpdt/dpdt_model.py
index 26b2be4fd..4b2bbfe92 100644
--- a/tabrepo/benchmark/models/ag/dpdt/dpdt_model.py
+++ b/tabrepo/benchmark/models/ag/dpdt/dpdt_model.py
@@ -22,15 +22,21 @@
         else:
             raise AssertionError(f"Unsupported problem_type: {self.problem_type}")
         return model_cls
-
-    def _fit(self, X: pd.DataFrame, y: pd.Series, num_cpus: int = 1, time_limit: float = None, **kwargs):
+
+    def _fit(
+        self,
+        X: pd.DataFrame,
+        y: pd.Series,
+        num_cpus: int = 1,
+        time_limit: float | None = None,
+        **kwargs,
+    ):
         model_cls = self.get_model_cls()
         hyp = self._get_model_params()
 
         self.model = model_cls(
             **hyp,
-            n_jobs='best',
-            n_estimators=1000,
+            n_jobs="best" if num_cpus > 1 else num_cpus,
             time_limit=time_limit,
         )
         X = self.preprocess(X)
@@ -38,11 +44,11 @@
             X=X,
             y=y,
         )
-
     def _set_default_params(self):
         default_params = {
             "random_state": 42,
+            "n_estimators": 1000,
         }
         for param, val in default_params.items():
             self._set_default_param_value(param, val)
@@ -50,7 +56,7 @@ def _set_default_params(self):
     @classmethod
     def supported_problem_types(cls) -> list[str] | None:
         return ["binary", "multiclass"]
-
+
     def _get_default_resources(self) -> tuple[int, int]:
         # logical=False is faster in training
         num_cpus = ResourceManager.get_cpu_count_psutil(logical=False)
         num_gpus = 0
         return num_cpus, num_gpus
 
     def _estimate_memory_usage(self, X: pd.DataFrame, **kwargs) -> int:
         hyperparameters = self._get_model_params()
-        return self.estimate_memory_usage_static(X=X, problem_type=self.problem_type, num_classes=self.num_classes, hyperparameters=hyperparameters, **kwargs)
+        return self.estimate_memory_usage_static(
+            X=X,
+            problem_type=self.problem_type,
+            num_classes=self.num_classes,
+            hyperparameters=hyperparameters,
+            **kwargs,
+        )
 
     @classmethod
     def _estimate_memory_usage_static(
         cls,
         *,
         X: pd.DataFrame,
-        hyperparameters: dict = None,
+        hyperparameters: dict | None = None,
         **kwargs,
     ) -> int:
         if hyperparameters is None:
             hyperparameters = {}
-
-        dataset_size_mem_est = 10 * hyperparameters.get('cart_nodes_list')[0] * get_approximate_df_mem_usage(X).sum()
-        baseline_overhead_mem_est = 3e8  # 300 MB generic overhead
 
-        mem_estimate = dataset_size_mem_est + baseline_overhead_mem_est
+        dataset_size_mem_est = (
+            10
+            * hyperparameters.get("cart_nodes_list")[0]
+            * get_approximate_df_mem_usage(X).sum()
+        )
+        baseline_overhead_mem_est = 3e8  # 300 MB generic overhead
 
-        return mem_estimate
+        return dataset_size_mem_est + baseline_overhead_mem_est
 
     @classmethod
     def _class_tags(cls):
@@ -85,4 +99,4 @@
     def _more_tags(self) -> dict:
         """DPDT does not yet support refit full."""
-        return {"can_refit_full": False}
\ No newline at end of file
+        return {"can_refit_full": False}
diff --git a/tst/benchmark/models/test_dpdt.py b/tst/benchmark/models/test_dpdt.py
index 5c6061e7d..d2347e200 100644
--- a/tst/benchmark/models/test_dpdt.py
+++ b/tst/benchmark/models/test_dpdt.py
@@ -6,7 +6,7 @@
     try:
         from autogluon.tabular.testing import FitHelper
-        from tabrepo.benchmark.models.ag.tabicl.tabicl_model import BoostedDPDTModel
+        from tabrepo.benchmark.models.ag import BoostedDPDTModel
         model_cls = BoostedDPDTModel
         FitHelper.verify_model(model_cls=model_cls, model_hyperparameters=model_hyperparameters)
     except ImportError as err:
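
After this refactor the static estimate can be computed without instantiating the model, but it still assumes cart_nodes_list is present in the hyperparameters (the next commit adds a fallback). A sketch with illustrative numbers, using the same helper the model code imports:

    import pandas as pd
    from autogluon.common.utils.pandas_utils import get_approximate_df_mem_usage

    X = pd.DataFrame({"a": range(1000), "b": [0.5] * 1000})
    hyperparameters = {"cart_nodes_list": (8, 4)}

    est = (
        10
        * hyperparameters["cart_nodes_list"][0]
        * get_approximate_df_mem_usage(X).sum()
    ) + 3e8
    print(int(est))  # dominated by the 300 MB baseline for a frame this small
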
From 6dbe698537c8fc6af91c0bc4b93176c336e24b36 Mon Sep 17 00:00:00 2001
From: LennartPurucker
Date: Sat, 19 Jul 2025 21:29:10 +0200
Subject: [PATCH 09/11] add: preprocessing for nan and cat handling

---
 .../benchmark/models/ag/dpdt/dpdt_model.py | 61 ++++++++++++++++++-
 1 file changed, 59 insertions(+), 2 deletions(-)

diff --git a/tabrepo/benchmark/models/ag/dpdt/dpdt_model.py b/tabrepo/benchmark/models/ag/dpdt/dpdt_model.py
index 4b2bbfe92..90ff0f594 100644
--- a/tabrepo/benchmark/models/ag/dpdt/dpdt_model.py
+++ b/tabrepo/benchmark/models/ag/dpdt/dpdt_model.py
@@ -10,10 +10,18 @@
     import pandas as pd
 
 
+def _to_cat(X):
+    return X
+
+
 class BoostedDPDTModel(AbstractModel):
     ag_key = "BOOSTEDDPDT"
     ag_name = "boosted_dpdt"
 
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self._preprocessor = None
+
     def get_model_cls(self):
         from dpdt import AdaBoostDPDT
 
@@ -23,6 +31,55 @@ def get_model_cls(self):
             raise AssertionError(f"Unsupported problem_type: {self.problem_type}")
         return model_cls
 
+    def _preprocess(self, X, **kwargs):
+        X = super()._preprocess(X, **kwargs)
+        if self._preprocessor is None:
+            import numpy as np
+            from sklearn.compose import ColumnTransformer, make_column_selector
+            from sklearn.impute import SimpleImputer
+            from sklearn.pipeline import Pipeline
+            from sklearn.preprocessing import FunctionTransformer, OrdinalEncoder
+
+            categorical_pipeline = Pipeline(
+                [
+                    (
+                        "encoder",
+                        OrdinalEncoder(
+                            handle_unknown="use_encoded_value", unknown_value=-1
+                        ),
+                    ),
+                    (
+                        "imputer",
+                        SimpleImputer(
+                            strategy="constant", add_indicator=True, fill_value=-1
+                        ),
+                    ),
+                    (
+                        "to_category",
+                        FunctionTransformer(_to_cat),
+                    ),
+                ]
+            ).set_output(transform="pandas")
+
+            self._preprocessor = ColumnTransformer(
+                transformers=[
+                    (
+                        "num",
+                        SimpleImputer(strategy="mean", add_indicator=True),
+                        make_column_selector(dtype_include=np.number),
+                    ),
+                    (
+                        "cat",
+                        categorical_pipeline,
+                        make_column_selector(dtype_include=["object", "category"]),
+                    ),
+                ],
+                remainder="passthrough",
+            ).set_output(transform="pandas")
+            self._preprocessor.fit(X)
+
+        return self._preprocessor.transform(X)
+
     def _fit(
         self,
         X: pd.DataFrame,
@@ -86,12 +143,12 @@ def _estimate_memory_usage_static(
         dataset_size_mem_est = (
             10
-            * hyperparameters.get("cart_nodes_list")[0]
+            * hyperparameters.get("cart_nodes_list", [2.5])[0]
             * get_approximate_df_mem_usage(X).sum()
         )
         baseline_overhead_mem_est = 3e8  # 300 MB generic overhead
 
-        return dataset_size_mem_est + baseline_overhead_mem_est
+        return int(dataset_size_mem_est + baseline_overhead_mem_est)
 
     @classmethod
     def _class_tags(cls):
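
On mixed input, the preprocessor above mean-imputes numeric columns (adding missing-value indicator columns) and ordinal-encodes categoricals, with unknown categories mapped to -1 and encoded NaNs imputed to -1. A reduced sketch of the numeric/categorical split (assumes scikit-learn 1.2+ for set_output; the post-encoding imputation step of the full pipeline is omitted here):

    import numpy as np
    import pandas as pd
    from sklearn.compose import ColumnTransformer, make_column_selector
    from sklearn.impute import SimpleImputer
    from sklearn.preprocessing import OrdinalEncoder

    X = pd.DataFrame(
        {
            "num": [1.0, np.nan, 3.0],
            "cat": pd.Categorical(["a", "b", None]),
        }
    )

    pre = ColumnTransformer(
        transformers=[
            (
                "num",
                SimpleImputer(strategy="mean", add_indicator=True),
                make_column_selector(dtype_include=np.number),
            ),
            (
                "cat",
                OrdinalEncoder(handle_unknown="use_encoded_value", unknown_value=-1),
                make_column_selector(dtype_include=["object", "category"]),
            ),
        ],
        remainder="passthrough",
    ).set_output(transform="pandas")

    # "num" becomes [1.0, 2.0, 3.0] plus an indicator column; "cat" becomes
    # float codes, with the missing entry kept as NaN here (the full
    # pipeline above additionally imputes it to -1).
    print(pre.fit_transform(X))
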
From 996e72d0f1f4faf2d30a2398209c7cddb94d7990 Mon Sep 17 00:00:00 2001
From: LennartPurucker
Date: Sat, 19 Jul 2025 22:15:23 +0200
Subject: [PATCH 10/11] add/fix: search space for HPO of dpdt

---
 tabrepo/models/dpdt/generate.py | 144 ++++++++++++++++++------------
 1 file changed, 90 insertions(+), 54 deletions(-)

diff --git a/tabrepo/models/dpdt/generate.py b/tabrepo/models/dpdt/generate.py
index e354576dc..349bf1f4b 100644
--- a/tabrepo/models/dpdt/generate.py
+++ b/tabrepo/models/dpdt/generate.py
@@ -1,57 +1,93 @@
-from autogluon.common.space import Categorical, Real, Int
+from __future__ import annotations
+
 import numpy as np
 
 from tabrepo.benchmark.models.ag.dpdt.dpdt_model import BoostedDPDTModel
-from tabrepo.utils.config_utils import ConfigGenerator
-
-name = 'BoostedDPDT'
-manual_configs = [
-    {},
-]
-
-# get config from paper
-
-# Generate 1000 samples from log-normal distribution
-# Parameters: mu = log(0.01), sigma = log(10.0)
-mu = float(np.log(0.01))
-sigma = float(np.log(10.0))
-samples = np.random.lognormal(mean=mu, sigma=sigma, size=1000)
-
-# Generate 1000 samples from q_log_uniform_values distribution
-# Parameters: min=1.5, max=50.5, q=1
-min_val = 1.5
-max_val = 50.5
-q = 1
-# Generate log-uniform samples and quantize
-log_min = np.log(min_val)
-log_max = np.log(max_val)
-log_uniform_samples = np.random.uniform(log_min, log_max, size=1000)
-min_samples_leaf_samples = np.round(np.exp(log_uniform_samples) / q) * q
-min_samples_leaf_samples = np.clip(min_samples_leaf_samples, min_val, max_val).astype(int)
-
-# Generate 1000 samples for min_weight_fraction_leaf
-# Values: [0.0, 0.01], probabilities: [0.95, 0.05]
-min_weight_fraction_leaf_samples = np.random.choice([0.0, 0.01], size=1000, p=[0.95, 0.05])
-
-# Generate 1000 samples for max_features
-# Values: ["sqrt", "log2", 10000], probabilities: [0.5, 0.25, 0.25]
-max_features_samples = np.random.choice(["sqrt", "log2", 10000], size=1000, p=[0.5, 0.25, 0.25])
-
-search_space = {
-    'learning_rate': Categorical(*samples),  # log_normal distribution equivalent
-    'max_depth': Categorical(2, 2, 2, 2, 3, 3, 3, 3, 3, 3),
-    'min_samples_split': Categorical(*np.random.choice([2, 3], size=1000, p=[0.95, 0.05])),
-    'min_impurity_decrease': Categorical(*np.random.choice([0, 0.01, 0.02, 0.05], size=1000, p=[0.85, 0.05, 0.05, 0.05])),
-    'cart_nodes_list': Categorical((8, 4), (4, 8), (16, 2), (4, 4, 2)),
-    'min_samples_leaf': Categorical(*min_samples_leaf_samples),  # q_log_uniform equivalent
-    'min_weight_fraction_leaf': Categorical(*min_weight_fraction_leaf_samples),
-    'max_features': Categorical(*max_features_samples),
-    'random_state': Categorical(0, 1, 2, 3, 4)
-}
-
-gen_boosteddpdt = ConfigGenerator(model_cls=BoostedDPDTModel, manual_configs=manual_configs, search_space=search_space)
-
-
-def generate_configs_boosted_dpdt(num_random_configs=200):
-    config_generator = ConfigGenerator(name=name, manual_configs=manual_configs, search_space=search_space)
-    return config_generator.generate_all_configs(num_random_configs=num_random_configs)
+from tabrepo.models.utils import convert_numpy_dtypes
+from tabrepo.utils.config_utils import CustomAGConfigGenerator
+
+
+def generate_configs_bdpdt(num_random_configs=200):
+    # TODO: transform this to a ConfigSpace configuration space or similar
+    # TODO: and/or switch to better random seed logic
+
+    # Generate num_random_configs samples from the log-normal distribution
+    # Parameters: mu = log(0.01), sigma = log(10.0)
+    np.random.seed(42)  # For reproducibility
+    mu = float(np.log(0.01))
+    sigma = float(np.log(10.0))
+    samples = np.random.lognormal(mean=mu, sigma=sigma, size=num_random_configs)
+
+    # Generate num_random_configs samples from the q_log_uniform_values
+    # distribution. Parameters: min=1.5, max=50.5, q=1
+    np.random.seed(43)
+    min_val = 1.5
+    max_val = 50.5
+    q = 1
+    # Generate log-uniform samples and quantize
+    log_min = np.log(min_val)
+    log_max = np.log(max_val)
+    log_uniform_samples = np.random.uniform(log_min, log_max, size=num_random_configs)
+    min_samples_leaf_samples = np.round(np.exp(log_uniform_samples) / q) * q
+    min_samples_leaf_samples = np.clip(
+        min_samples_leaf_samples, min_val, max_val
+    ).astype(int)
+
+    # Generate num_random_configs samples for min_weight_fraction_leaf
+    # Values: [0.0, 0.01], probabilities: [0.95, 0.05]
+    np.random.seed(44)
+    min_weight_fraction_leaf_samples = np.random.choice(
+        [0.0, 0.01], size=num_random_configs, p=[0.95, 0.05]
+    )
+
+    # Generate num_random_configs samples for max_features
+    # Values: ["sqrt", "log2", 10000], probabilities: [0.5, 0.25, 0.25]
+    np.random.seed(45)
+    max_features_samples = np.random.choice(
+        ["sqrt", "log2", 10000], size=num_random_configs, p=[0.5, 0.25, 0.25]
+    )
+
+    np.random.seed(46)
+    max_depth_samples = np.random.choice([2, 3], size=num_random_configs, p=[0.4, 0.6])
+
+    np.random.seed(47)
+    min_samples_split = np.random.choice(
+        [2, 3], size=num_random_configs, p=[0.95, 0.05]
+    )
+
+    np.random.seed(48)
+    min_impurity_decrease_samples = np.random.choice(
+        [0, 0.01, 0.02, 0.05], size=num_random_configs, p=[0.85, 0.05, 0.05, 0.05]
+    )
+
+    np.random.seed(49)
+    choices = [[8, 4], [4, 8], [16, 2], [4, 4, 2]]
+    indices = np.random.choice(len(choices), size=num_random_configs)
+    cart_nodes_list = [choices[i] for i in indices]
+
+    configs = []
+    for i in range(num_random_configs):
+        config = {
+            "learning_rate": samples[i],
+            "max_depth": max_depth_samples[i],
+            "min_samples_split": min_samples_split[i],
+            "min_impurity_decrease": min_impurity_decrease_samples[i],
+            "cart_nodes_list": cart_nodes_list[i],
+            "min_samples_leaf": min_samples_leaf_samples[i],
+            "min_weight_fraction_leaf": min_weight_fraction_leaf_samples[i],
+            "max_features": max_features_samples[i],
+        }
+        configs.append(config)
+
+    return [convert_numpy_dtypes(config) for config in configs]
+
+
+gen_boosteddpdt = CustomAGConfigGenerator(
+    model_cls=BoostedDPDTModel,
+    search_space_func=generate_configs_bdpdt,
+    manual_configs=[{}],
+)
+
+
+if __name__ == "__main__":
+    print(generate_configs_bdpdt(3))
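
The rewritten generator returns plain-Python config dicts (convert_numpy_dtypes, judging by its use here, normalizes numpy scalar types) instead of an AutoGluon search-space object, and the fixed seeds make the draws reproducible across calls. Expected shape of the output, as a sketch; the exact values depend on the seeds above:

    from tabrepo.models.dpdt.generate import generate_configs_bdpdt

    configs = generate_configs_bdpdt(num_random_configs=3)
    assert len(configs) == 3
    assert set(configs[0]) == {
        "learning_rate", "max_depth", "min_samples_split",
        "min_impurity_decrease", "cart_nodes_list", "min_samples_leaf",
        "min_weight_fraction_leaf", "max_features",
    }
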
From ac10777b7090b5b0325f7872891e08aaeb8e4d08 Mon Sep 17 00:00:00 2001
From: LennartPurucker
Date: Wed, 23 Jul 2025 16:04:12 +0200
Subject: [PATCH 11/11] add: state after EBM rerun

---
 tabrepo/benchmark/models/ag/dpdt/dpdt_model.py | 4 +++-
 tabrepo/models/dpdt/generate.py                | 6 +++++-
 2 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/tabrepo/benchmark/models/ag/dpdt/dpdt_model.py b/tabrepo/benchmark/models/ag/dpdt/dpdt_model.py
index 90ff0f594..a9ee6e3d9 100644
--- a/tabrepo/benchmark/models/ag/dpdt/dpdt_model.py
+++ b/tabrepo/benchmark/models/ag/dpdt/dpdt_model.py
@@ -138,12 +138,14 @@ def _estimate_memory_usage_static(
         hyperparameters: dict | None = None,
         **kwargs,
     ) -> int:
+        # TODO: add a callback that stops when running out of memory.
         if hyperparameters is None:
             hyperparameters = {}
 
         dataset_size_mem_est = (
-            10
+            40
             * hyperparameters.get("cart_nodes_list", [2.5])[0]
+            * hyperparameters.get("cart_nodes_list", [0, 1])[1]
             * get_approximate_df_mem_usage(X).sum()
         )
         baseline_overhead_mem_est = 3e8  # 300 MB generic overhead
diff --git a/tabrepo/models/dpdt/generate.py b/tabrepo/models/dpdt/generate.py
index 349bf1f4b..9d4188bea 100644
--- a/tabrepo/models/dpdt/generate.py
+++ b/tabrepo/models/dpdt/generate.py
@@ -67,6 +67,10 @@ def generate_configs_bdpdt(num_random_configs=200):
     configs = []
     for i in range(num_random_configs):
+        try:
+            max_features = int(max_features_samples[i])
+        except Exception:
+            max_features = max_features_samples[i]
         config = {
             "learning_rate": samples[i],
             "max_depth": max_depth_samples[i],
@@ -75,7 +79,7 @@ def generate_configs_bdpdt(num_random_configs=200):
             "cart_nodes_list": cart_nodes_list[i],
             "min_samples_leaf": min_samples_leaf_samples[i],
             "min_weight_fraction_leaf": min_weight_fraction_leaf_samples[i],
-            "max_features": max_features_samples[i],
+            "max_features": max_features,
         }
         configs.append(config)
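
With the final constants, the memory heuristic scales with the product of the first two CART layer widths. A worked example with illustrative numbers, assuming cart_nodes_list=(8, 4) and a training frame of about 10 MB:

    cart_nodes_list = (8, 4)
    df_mem_bytes = 10e6  # suppose X occupies roughly 10 MB

    dataset_size_mem_est = 40 * cart_nodes_list[0] * cart_nodes_list[1] * df_mem_bytes
    total = int(dataset_size_mem_est + 3e8)  # plus the 300 MB baseline
    print(f"{total / 1e9:.2f} GB")  # 13.10 GB
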