
Commit 14535ff

Bugfix text-transformers.py (#170)
This fixes two bugs: 1) The example did not run on the stsb task of GLUE because the `if` condition in `validation_step` was wrong (it is always true, since `num_labels` is at least 1). Changing `>= 1` to `> 1` fixes it. 2) The train dataloader did not shuffle the dataset, which causes a sizeable performance drop on some GLUE datasets. Adding `shuffle=True` to the train dataloader fixes it.
1 parent ee74e71 · commit 14535ff

1 file changed (+2 −2 lines):

lightning_examples/text-transformers/text-transformers.py
@@ -99,7 +99,7 @@ def prepare_data(self):
         AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)

     def train_dataloader(self):
-        return DataLoader(self.dataset["train"], batch_size=self.train_batch_size)
+        return DataLoader(self.dataset["train"], batch_size=self.train_batch_size, shuffle=True)

     def val_dataloader(self):
         if len(self.eval_splits) == 1:
@@ -183,7 +183,7 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0):
         outputs = self(**batch)
         val_loss, logits = outputs[:2]

-        if self.hparams.num_labels >= 1:
+        if self.hparams.num_labels > 1:
             preds = torch.argmax(logits, axis=1)
         elif self.hparams.num_labels == 1:
             preds = logits.squeeze()
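To see why the second hunk matters, here is a minimal sketch (the standalone `compute_preds` helper and the toy logits are assumptions for illustration; the branching mirrors the fixed `validation_step` above). For a regression task like stsb, `num_labels` is 1, so the old `>= 1` condition always took the classification branch and `argmax` threw away the regression scores:

import torch

def compute_preds(logits, num_labels):
    # Fixed branching from the diff: classify only when there is more than one label.
    if num_labels > 1:
        return torch.argmax(logits, axis=1)  # classification: index of the top class
    elif num_labels == 1:
        return logits.squeeze()  # regression (e.g. stsb): keep the raw scores

# Toy regression logits of shape (batch, 1), as a model with num_labels == 1 produces.
logits = torch.tensor([[0.7], [2.3], [4.1]])
print(compute_preds(logits, num_labels=1))  # tensor([0.7000, 2.3000, 4.1000])
# Under the old `>= 1` condition, argmax(axis=1) would return tensor([0, 0, 0]),
# discarding the regression outputs entirely.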

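And a minimal sketch of the shuffle fix (the toy dataset is an assumption; `DataLoader` is the standard `torch.utils.data.DataLoader`). Without `shuffle=True`, every epoch iterates the training set in the same fixed order, which can hurt optimization when the underlying data is ordered, e.g. grouped by label:

from torch.utils.data import DataLoader

data = list(range(8))  # toy stand-in for self.dataset["train"]

ordered = DataLoader(data, batch_size=4)                 # old behavior: fixed order
shuffled = DataLoader(data, batch_size=4, shuffle=True)  # fixed: reshuffled every epoch

print([b.tolist() for b in ordered])   # [[0, 1, 2, 3], [4, 5, 6, 7]] on every epoch
print([b.tolist() for b in shuffled])  # e.g. [[5, 0, 7, 2], [3, 6, 1, 4]], varies per epoch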