Merged
scripts/run_generate_all_configs.py (2 changes: 1 addition & 1 deletion)
@@ -44,7 +44,7 @@

experiments_tabicl = gen_tabicl.generate_all_bag_experiments(num_random_configs=0)
experiments_tabpfnv2 = gen_tabpfnv2.generate_all_bag_experiments(num_random_configs=n_random_configs)
experiments_tabdpt = gen_tabdpt.generate_all_bag_experiments(num_random_configs=0)
experiments_tabdpt = gen_tabdpt.generate_all_bag_experiments(num_random_configs=n_random_configs)
experiments_modernnca = gen_modernnca.generate_all_bag_experiments(num_random_configs=n_random_configs)

# Dummy (constant predictor)
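Note: this flips TabDPT from 0 random configs to the shared `n_random_configs`, so it is tuned like the other model families. A minimal sketch of what the script now produces for TabDPT, assuming `generate_all_bag_experiments` returns a flat, list-like collection of experiment definitions (the manual default plus the requested random configs):

```python
# Sketch only: n_random_configs stands in for the value the script defines itself.
from tabrepo.models.tabdpt.generate import gen_tabdpt

n_random_configs = 200  # hypothetical value, for illustration
experiments_tabdpt = gen_tabdpt.generate_all_bag_experiments(
    num_random_configs=n_random_configs,  # previously hard-coded to 0 for TabDPT
)

# Assumption: the result behaves like a list of experiments
# (manual default config + n_random_configs drawn from the new search space).
print(len(experiments_tabdpt))
```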
setup.py (6 changes: 1 addition & 5 deletions)
@@ -38,11 +38,7 @@
"pytabkit>=1.5.0,<2.0",
],
"tabdpt": [
# TODO: pypi package is not available yet
# FIXME: newest version (1.1) has (unnecessary) strict version requirements
# that are not compatible with autogluon, so we stick to the old hash
"tabdpt @ git+https://github.com/layer6ai-labs/TabDPT.git@9699d9592b61c5f70fc88f5531cdb87b40cbedf5",
# used hash: 9699d9592b61c5f70fc88f5531cdb87b40cbedf5
"tabdpt>=1.1.5"
],
"tabm": [
"torch",
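After this change the optional dependency resolves from PyPI instead of the pinned TabDPT git hash. A small, hedged check that an environment picked up the new requirement after reinstalling the extra (for example with `pip install -e ".[tabdpt]"`, assuming an editable install of this repo):

```python
# Reads the installed distribution's version via importlib.metadata, so it makes
# no assumptions about tabdpt's own Python API.
from importlib.metadata import version

print(version("tabdpt"))  # expected to be >= 1.1.5 per the updated extra above
```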
tabrepo/benchmark/models/ag/tabdpt/tabdpt_model.py (92 changes: 9 additions & 83 deletions)
@@ -1,11 +1,6 @@
from __future__ import annotations

import math
import os
import shutil
import sys
import warnings
from pathlib import Path
from typing import TYPE_CHECKING

from autogluon.common.utils.resource_utils import ResourceManager
Expand Down Expand Up @@ -47,16 +42,20 @@ def _fit(
from tabdpt import TabDPTClassifier, TabDPTRegressor

model_cls = TabDPTClassifier if self.problem_type in [BINARY, MULTICLASS] else TabDPTRegressor
supported_hps = ('context_size', 'permute_classes', 'temperature') \
if model_cls is TabDPTClassifier \
else ('context_size',)

hps = self._get_model_params()
self._predict_hps = dict(seed=42, context_size=1024)
self._predict_hps = {
k:v for k,v in hps.items() if k in supported_hps
}
self._predict_hps['seed'] = 42
X = self.preprocess(X)
y = y.to_numpy()
self.model = model_cls(
path=self._download_and_get_model_path(),
device=device,
use_flash=self._use_flash(),
**hps,
use_flash=self._use_flash()
)
self.model.fit(X=X, y=y)

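The block above now routes only predict-time options into `_predict_hps` (plus a fixed seed) and keeps arbitrary config keys out of the TabDPT constructor, with regressors restricted to `context_size`. A standalone sketch of that routing, using illustrative values and no model object:

```python
# Illustrative only: mirrors the supported-HP split in _fit without touching TabDPT.
def split_predict_hps(hps: dict, is_classifier: bool) -> dict:
    supported = (
        ("context_size", "permute_classes", "temperature")
        if is_classifier
        else ("context_size",)
    )
    predict_hps = {k: v for k, v in hps.items() if k in supported}
    predict_hps["seed"] = 42  # fixed seed, matching _fit above
    return predict_hps


hps = {"context_size": 2048, "temperature": 0.8, "permute_classes": True}
print(split_predict_hps(hps, is_classifier=True))
# {'context_size': 2048, 'temperature': 0.8, 'permute_classes': True, 'seed': 42}
print(split_predict_hps(hps, is_classifier=False))
# {'context_size': 2048, 'seed': 42}
```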
@@ -76,23 +75,6 @@ def _use_flash() -> bool:

return True

@staticmethod
def _download_and_get_model_path() -> str:
# We follow TabPFN-logic for model caching as /tmp is not a persistent cache location.
from tabdpt.estimator import TabDPTEstimator
from tabdpt.utils import download_model

model_dir = _user_cache_dir(platform=sys.platform, appname="tabdpt")
model_dir.mkdir(exist_ok=True, parents=True)

final_model_path = model_dir / Path(TabDPTEstimator._DEFAULT_CHECKPOINT_PATH).name

if not final_model_path.exists():
model_path = Path(download_model()) # downloads to /tmp
shutil.copy(model_path, final_model_path) # copy to user cache dir

return str(final_model_path)

def _get_default_resources(self) -> tuple[int, int]:
# Use only physical cores for better performance based on benchmarks
num_cpus = ResourceManager.get_cpu_count(only_physical_cores=True)
@@ -110,17 +92,10 @@ def get_minimum_resources(self, is_gpu_available: bool = False) -> dict[str, int
def _predict_proba(self, X, **kwargs) -> np.ndarray:
X = self.preprocess(X, **kwargs)

# Fix bug in TabDPt where batches of length 1 crash the prediction.
# - We set the inference size such that there are no batches of length 1.
math.ceil(len(X) / self.model.inf_batch_size)
last_batch_size = len(X) % self.model.inf_batch_size
if last_batch_size == 1:
self.model.inf_batch_size += 1

if self.problem_type in [REGRESSION]:
return self.model.predict(X, **self._predict_hps)

y_pred_proba = self.model.predict_proba(X, **self._predict_hps)
y_pred_proba = self.model.ensemble_predict_proba(X, **self._predict_hps)
return self._convert_proba_to_unified_form(y_pred_proba)

def _preprocess(self, X: pd.DataFrame, **kwargs) -> pd.DataFrame:
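For context on the prediction change: the batch-size-1 workaround is dropped and the classifier entry point moves from `predict_proba` to `ensemble_predict_proba`, with the stored predict-time kwargs forwarded unchanged. A hedged end-to-end sketch, assuming `tabdpt>=1.1.5` is installed and can fetch its checkpoint on first use; the constructor and call signatures are taken from this diff, not from the full upstream API:

```python
from sklearn.datasets import load_breast_cancer

from tabdpt import TabDPTClassifier  # same import path as in _fit above

X, y = load_breast_cancer(return_X_y=True)
X_train, y_train, X_test = X[:400], y[:400], X[400:]

# Constructor kwargs mirror the diff; device="cpu" and use_flash=False are
# assumptions meant to keep the sketch runnable without a GPU.
clf = TabDPTClassifier(device="cpu", use_flash=False)
clf.fit(X=X_train, y=y_train)

# Same call shape as _predict_proba: seed plus supported config keys.
proba = clf.ensemble_predict_proba(X_test, seed=42, context_size=1024)
print(proba.shape)  # expected (len(X_test), 2) for this binary dataset
```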
@@ -150,52 +125,3 @@ def _get_default_ag_args_ensemble(cls, **kwargs) -> dict:
}
default_ag_args_ensemble.update(extra_ag_args_ensemble)
return default_ag_args_ensemble

# Vendored from TabPFNv2 Code
def _user_cache_dir(platform: str, appname: str = "tabpfn") -> Path:
use_instead_path = (Path.cwd() / ".tabpfn_models").resolve()

# https://docs.python.org/3/library/sys.html#sys.platform
if platform == "win32":
# Honestly, I don't want to do what `platformdirs` does:
# https://github.com/tox-dev/platformdirs/blob/b769439b2a3b70769a93905944a71b3e63ef4823/src/platformdirs/windows.py#L252-L265
APPDATA_PATH = os.environ.get("APPDATA", "")
if APPDATA_PATH.strip() != "":
return Path(APPDATA_PATH) / appname

warnings.warn(
"Could not find APPDATA environment variable to get user cache dir,"
" but detected platform 'win32'."
f" Defaulting to a path '{use_instead_path}'."
" If you would prefer, please specify a directory when creating"
" the model.",
UserWarning,
stacklevel=2,
)
return use_instead_path

if platform == "darwin":
return Path.home() / "Library" / "Caches" / appname

# TODO: Not entirely sure here, Python doesn't explicitly list
# all of these and defaults to the underlying operating system
# if not sure.
linux_likes = ("freebsd", "linux", "netbsd", "openbsd")
if any(platform.startswith(linux) for linux in linux_likes):
# The reason to use "" as default is that the env var could exist but be empty.
# We catch all this with the `.strip() != ""` below
XDG_CACHE_HOME = os.environ.get("XDG_CACHE_HOME", "")
if XDG_CACHE_HOME.strip() != "":
return Path(XDG_CACHE_HOME) / appname
return Path.home() / ".cache" / appname

warnings.warn(
f"Unknown platform '{platform}' to get user cache dir."
f" Defaulting to a path at the execution site '{use_instead_path}'."
" If you would prefer, please specify a directory when creating"
" the model.",
UserWarning,
stacklevel=2,
)
return use_instead_path

tabrepo/models/tabdpt/generate.py (8 changes: 7 additions & 1 deletion)
@@ -1,5 +1,6 @@
from __future__ import annotations

from autogluon.common.space import Real, Categorical
from tabrepo.benchmark.models.ag.tabdpt.tabdpt_model import TabDPTModel
from tabrepo.utils.config_utils import ConfigGenerator

@@ -8,9 +9,14 @@
# Default config with refit after cross-validation.
{"ag_args_ensemble": {"refit_folds": True}},
]
search_space = {
'temperature': Real(0.05, 1.5, default=0.8),
'context_size': Categorical(2048, 768, 256),
'permute_classes': Categorical(True, False)
}

gen_tabdpt = ConfigGenerator(
model_cls=TabDPTModel, manual_configs=manual_configs, search_space={}
model_cls=TabDPTModel, manual_configs=manual_configs, search_space=search_space
)

if __name__ == "__main__":
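Side note on the new search space: `Real` and `Categorical` come from `autogluon.common.space`, exactly as imported above, and `n_random_configs` configs are drawn from it elsewhere in tabrepo. The sketch below draws one random TabDPT config locally; the attribute access (`.lower`, `.upper`, `.data`) and the sampling loop are assumptions for illustration, not ConfigGenerator internals:

```python
import random

from autogluon.common.space import Categorical, Real

search_space = {
    "temperature": Real(0.05, 1.5, default=0.8),
    "context_size": Categorical(2048, 768, 256),
    "permute_classes": Categorical(True, False),
}


def sample_config(space: dict, rng: random.Random) -> dict:
    """Draw one config; a stand-in for tabrepo's random-config sampling."""
    config = {}
    for name, hp in space.items():
        if isinstance(hp, Real):
            config[name] = rng.uniform(hp.lower, hp.upper)  # assumed attributes
        else:
            config[name] = rng.choice(list(hp.data))  # assumed attribute
    return config


print(sample_config(search_space, random.Random(0)))  # one random TabDPT config
```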