@@ -43,7 +43,7 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser):
4343 def call_data_methods (data : Type [XYBaseDataModule ]):
4444 if data ._num_of_labels is None :
4545 data .prepare_data ()
46- data .setup ()
46+ data .setup ()
4747 return data .num_of_labels
4848
4949 parser .link_arguments (
@@ -60,18 +60,27 @@ def call_data_methods(data: Type[XYBaseDataModule]):
6060 )
6161
6262 for kind in ("train" , "val" , "test" ):
63- for average in ("micro-f1" , "macro-f1" , "balanced-accuracy" , "roc-auc" , "f1" , "mse" , "rmse" , "r2" ):
64- # When using lightning > 2.5.1 then need to uncomment all metrics that are not used
65- # for average in ("mse", "rmse","r2"): # for regression
66- # for average in ("f1", "roc-auc"): # for binary classification
67- # for average in ("micro-f1", "macro-f1", "roc-auc"): # for multilabel classification
68- # for average in ("micro-f1", "macro-f1", "balanced-accuracy", "roc-auc"): # for multilabel classification using balanced-accuracy
63+ for average in (
64+ "micro-f1" ,
65+ "macro-f1" ,
66+ "balanced-accuracy" ,
67+ "roc-auc" ,
68+ "f1" ,
69+ "mse" ,
70+ "rmse" ,
71+ "r2" ,
72+ ):
73+ # When using lightning > 2.5.1 then need to uncomment all metrics that are not used
74+ # for average in ("mse", "rmse","r2"): # for regression
75+ # for average in ("f1", "roc-auc"): # for binary classification
76+ # for average in ("micro-f1", "macro-f1", "roc-auc"): # for multilabel classification
77+ # for average in ("micro-f1", "macro-f1", "balanced-accuracy", "roc-auc"): # for multilabel classification using balanced-accuracy
6978 parser .link_arguments (
7079 "data.num_of_labels" ,
7180 f"model.init_args.{ kind } _metrics.init_args.metrics.{ average } .init_args.num_labels" ,
7281 apply_on = "instantiate" ,
7382 )
74-
83+
7584 parser .link_arguments (
7685 "data.num_of_labels" , "trainer.callbacks.init_args.num_labels"
7786 )
@@ -84,7 +93,7 @@ def call_data_methods(data: Type[XYBaseDataModule]):
8493 # parser.link_arguments(
8594 # "data.init_args.chebi_version",
8695 # "model.init_args.criterion.init_args.data_extractor.init_args.chebi_version",
87- # )
96+ # )
8897
8998 @staticmethod
9099 def subcommands () -> Dict [str , Set [str ]]:
0 commit comments