Skip to content

Commit b6f0d23

Browse files
Charlotte TumescheitCharlotte Tumescheit
authored andcommitted
adjust classification tasks for new logic
1 parent dca60a3 commit b6f0d23

File tree

4 files changed

+1256
-85
lines changed

4 files changed

+1256
-85
lines changed

chebai/cli.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,9 @@ def call_data_methods(data: Type[XYBaseDataModule]):
6262
for kind in ("train", "val", "test"):
6363
# todo: fix this
6464
# for average in ("mse", "rmse","r2"): # for regression
65-
for average in ("micro-f1", "macro-f1", "balanced-accuracy", "f1"): # for classification
65+
for average in ("f1", "roc-auc"): # for binary classification
66+
# for average in ("micro-f1", "macro-f1", "roc-auc"): # for multilabel classification
67+
# for average in ("micro-f1", "macro-f1", "balanced-accuracy", "roc-auc"): # for multilabel classification using balanced-accuracy
6668
parser.link_arguments(
6769
"data.num_of_labels",
6870
f"model.init_args.{kind}_metrics.init_args.metrics.{average}.init_args.num_labels",

0 commit comments

Comments
 (0)