Skip to content

Commit 2462c6e

Browse files
Charlotte TumescheitCharlotte Tumescheit
authored andcommitted
black-lint fix
1 parent 9b29411 commit 2462c6e

File tree

7 files changed

+195
-106
lines changed

7 files changed

+195
-106
lines changed

chebai/cli.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser):
4343
def call_data_methods(data: Type[XYBaseDataModule]):
4444
if data._num_of_labels is None:
4545
data.prepare_data()
46-
data.setup()
46+
data.setup()
4747
return data.num_of_labels
4848

4949
parser.link_arguments(
@@ -60,18 +60,27 @@ def call_data_methods(data: Type[XYBaseDataModule]):
6060
)
6161

6262
for kind in ("train", "val", "test"):
63-
for average in ("micro-f1", "macro-f1", "balanced-accuracy", "roc-auc", "f1", "mse", "rmse", "r2"):
64-
# When using lightning > 2.5.1 then need to uncomment all metrics that are not used
65-
# for average in ("mse", "rmse","r2"): # for regression
66-
# for average in ("f1", "roc-auc"): # for binary classification
67-
# for average in ("micro-f1", "macro-f1", "roc-auc"): # for multilabel classification
68-
# for average in ("micro-f1", "macro-f1", "balanced-accuracy", "roc-auc"): # for multilabel classification using balanced-accuracy
63+
for average in (
64+
"micro-f1",
65+
"macro-f1",
66+
"balanced-accuracy",
67+
"roc-auc",
68+
"f1",
69+
"mse",
70+
"rmse",
71+
"r2",
72+
):
73+
# When using lightning > 2.5.1 then need to uncomment all metrics that are not used
74+
# for average in ("mse", "rmse","r2"): # for regression
75+
# for average in ("f1", "roc-auc"): # for binary classification
76+
# for average in ("micro-f1", "macro-f1", "roc-auc"): # for multilabel classification
77+
# for average in ("micro-f1", "macro-f1", "balanced-accuracy", "roc-auc"): # for multilabel classification using balanced-accuracy
6978
parser.link_arguments(
7079
"data.num_of_labels",
7180
f"model.init_args.{kind}_metrics.init_args.metrics.{average}.init_args.num_labels",
7281
apply_on="instantiate",
7382
)
74-
83+
7584
parser.link_arguments(
7685
"data.num_of_labels", "trainer.callbacks.init_args.num_labels"
7786
)
@@ -84,7 +93,7 @@ def call_data_methods(data: Type[XYBaseDataModule]):
8493
# parser.link_arguments(
8594
# "data.init_args.chebi_version",
8695
# "model.init_args.criterion.init_args.data_extractor.init_args.chebi_version",
87-
# )
96+
# )
8897

8998
@staticmethod
9099
def subcommands() -> Dict[str, Set[str]]:

chebai/models/electra.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ class ElectraPre(ChebaiBaseNet):
4141

4242
def __init__(self, config: Dict[str, Any] = None, **kwargs: Any):
4343
super().__init__(config=config, **kwargs)
44-
44+
4545
self.generator_config = ElectraConfig(**config["generator"])
4646
self.generator = ElectraForMaskedLM(self.generator_config)
4747
self.discriminator_config = ElectraConfig(**config["discriminator"])

0 commit comments

Comments
 (0)