Skip to content

Commit dca60a3

Browse files
Charlotte TumescheitCharlotte Tumescheit
authored andcommitted
adjust all regression tasks to new logic
1 parent ed1d4b4 commit dca60a3

16 files changed

+750
-72
lines changed

chebai/cli.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser):
4343
def call_data_methods(data: Type[XYBaseDataModule]):
4444
if data._num_of_labels is None:
4545
data.prepare_data()
46-
data.setup()
46+
data.setup()
4747
return data.num_of_labels
4848

4949
parser.link_arguments(
@@ -60,7 +60,9 @@ def call_data_methods(data: Type[XYBaseDataModule]):
6060
)
6161

6262
for kind in ("train", "val", "test"):
63-
for average in ("micro-f1", "macro-f1", "balanced-accuracy", "f1", "mse", "rmse","r2"):
63+
# todo: fix this
64+
# for average in ("mse", "rmse","r2"): # for regression
65+
for average in ("micro-f1", "macro-f1", "balanced-accuracy", "f1"): # for classification
6466
parser.link_arguments(
6567
"data.num_of_labels",
6668
f"model.init_args.{kind}_metrics.init_args.metrics.{average}.init_args.num_labels",
@@ -79,7 +81,7 @@ def call_data_methods(data: Type[XYBaseDataModule]):
7981
# parser.link_arguments(
8082
# "data.init_args.chebi_version",
8183
# "model.init_args.criterion.init_args.data_extractor.init_args.chebi_version",
82-
# )
84+
# )
8385

8486
@staticmethod
8587
def subcommands() -> Dict[str, Set[str]]:

0 commit comments

Comments
 (0)