
Commit 6d833cf

Meet Patel authored and quic-mamta committed
Made fixes to training script based on recent findings.
Signed-off-by: meetkuma <[email protected]>
1 parent 03d9871 commit 6d833cf

File tree

10 files changed: +150, -66 lines

QEfficient/cloud/finetune.py

Lines changed: 5 additions & 2 deletions
@@ -5,6 +5,7 @@
 #
 # -----------------------------------------------------------------------------

+import os
 import random
 import warnings
 from typing import Any, Dict, Optional, Union
@@ -139,7 +140,7 @@ def load_model_and_tokenizer(
        train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name
    )
    if not tokenizer.pad_token_id:
-        tokenizer.pad_token_id = tokenizer.eos_token_id
+        tokenizer.add_special_tokens({"pad_token": "[PAD]"})

    # If there is a mismatch between tokenizer vocab size and embedding matrix,
    # throw a warning and then expand the embedding matrix
@@ -195,7 +196,9 @@ def apply_peft(
    else:
        peft_config = generate_peft_config(train_config, peft_config_file, **kwargs)
        model = get_peft_model(model, peft_config)
-        model.print_trainable_parameters()
+
+        if os.getenv("LOCAL_RANK", 0) == 0:
+            model.print_trainable_parameters()

    return model
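
Note on the hunks above: reserving a dedicated [PAD] token instead of reusing EOS keeps padding distinguishable from end-of-sequence, but it also grows the vocabulary, which is why the existing vocab-vs-embedding mismatch check right below the hunk matters. A minimal sketch of the pattern, assuming a Hugging Face tokenizer/model pair; "gpt2" is only a placeholder checkpoint id and is not taken from this commit:

import os

from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"  # placeholder model without a pad token; not from this diff
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Mirror of the change: register a real [PAD] token instead of aliasing it to EOS.
if not tokenizer.pad_token_id:
    tokenizer.add_special_tokens({"pad_token": "[PAD]"})

# The new token enlarges the vocab, so the embedding matrix must grow with it.
if len(tokenizer) != model.get_input_embeddings().weight.shape[0]:
    model.resize_token_embeddings(len(tokenizer))

# Rank-guarded logging in the same spirit as the apply_peft change: only one
# local process reports, so DDP workers do not repeat the message.
if os.getenv("LOCAL_RANK", "0") == "0":
    print(f"pad_token_id = {tokenizer.pad_token_id}")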

QEfficient/finetune/data/sampler.py

Lines changed: 8 additions & 4 deletions
@@ -49,14 +49,18 @@ def __init__(
    ) -> None:
        random.seed(seed)
        self.batch_sampler = LengthBasedBatchSampler(
-            data_source, batch_size=batch_size, drop_last=True, shuffle=shuffle
+            data_source, batch_size=batch_size, drop_last=False, shuffle=shuffle
        )
        self.num_replicas = num_replicas
        self.rank = rank
+        assert len(self.batch_sampler) % self.num_replicas == 0, (
+            "Length of batch samples should be divisible by number of processes in DDP."
+        )
+        self.sampler_len = len(self.batch_sampler) // self.num_replicas
+        self.max_length = len(self.batch_sampler)

    def __iter__(self):
-        max_length = len(self.batch_sampler) // self.num_replicas * self.num_replicas
-        return islice(self.batch_sampler, self.rank, max_length, self.num_replicas)
+        return islice(self.batch_sampler, self.rank, self.max_length, self.num_replicas)

    def __len__(self):
-        return len(self.batch_sampler) // self.num_replicas
+        return self.sampler_len
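
The sampler now keeps the last (possibly short) batch and instead asserts that the batch count divides evenly across the DDP processes, then strides over the shared batch list with islice. A toy illustration of that striding, using plain lists in place of LengthBasedBatchSampler:

from itertools import islice

# Toy stand-in for the batch sampler: 4 batches of indices, world size 2.
batches = [[0, 1], [2, 3], [4, 5], [6, 7]]
num_replicas = 2
assert len(batches) % num_replicas == 0  # same invariant the new assert enforces

for rank in range(num_replicas):
    shard = list(islice(batches, rank, len(batches), num_replicas))
    print(rank, shard)
# 0 [[0, 1], [4, 5]]
# 1 [[2, 3], [6, 7]]

Each rank therefore sees len(batches) // num_replicas batches, which is exactly what __len__ now returns via sampler_len.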

QEfficient/finetune/dataset/alpaca_dataset.py

Lines changed: 2 additions & 2 deletions
@@ -11,6 +11,8 @@
 import torch
 from torch.utils.data import Dataset

+from QEfficient.finetune.dataset.helper import IGNORE_INDEX
+
 PROMPT_DICT = {
     "prompt_input": (
         "Below is an instruction that describes a task, paired with an input that provides further context. "
@@ -42,8 +44,6 @@ def __len__(self):
        return len(self.ann)

    def __getitem__(self, index):
-        IGNORE_INDEX = -100  # The default setting
-
        ann = self.ann[index]
        if ann.get("input", "") == "":
            prompt = PROMPT_DICT["prompt_no_input"].format_map(ann)

QEfficient/finetune/dataset/grammar_dataset.py

Lines changed: 3 additions & 1 deletion
@@ -10,6 +10,8 @@
 from datasets import load_dataset
 from torch.utils.data import Dataset

+from QEfficient.finetune.dataset.helper import IGNORE_INDEX
+

 class grammar(Dataset):
     def __init__(self, tokenizer, csv_name=None, context_length=None):
@@ -58,7 +60,7 @@ def convert_to_features(self, example_batch):
        sample = {
            "input_ids": prompt_ids + label_ids,
            "attention_mask": [1] * len(prompt_ids + label_ids),
-            "labels": [-100] * len(prompt_ids) + label_ids,
+            "labels": [IGNORE_INDEX] * len(prompt_ids) + label_ids,
        }

        return sample

QEfficient/finetune/dataset/gsm8k_dataset.py

Lines changed: 7 additions & 6 deletions
@@ -9,6 +9,8 @@

 from datasets import Dataset, load_dataset

+from QEfficient.finetune.dataset.helper import IGNORE_INDEX
+
 default_instruction = """### Instruction: Solve the math question using a basic calculator.
 Calculator can be invoked using the format: <<expression=answer>>.
 "expression" can be one of the 4 arithmetic operations, and "answer" will be filled in for you.
@@ -26,9 +28,8 @@ def tokenize_and_mask(row: Dict[str, str], *, tokenizer, instruction) -> Dict[st

    input_str = tokenizer.bos_token + instruction.format(**row)
    ques_ids = tokenizer(input_str, add_special_tokens=False, return_attention_mask=False)["input_ids"]
-    ans_ids = tokenizer(row["answer"] + tokenizer.eos_token, add_special_tokens=False, return_attention_mask=False)[
-        "input_ids"
-    ]
+    ans_str = row["answer"] + tokenizer.eos_token
+    ans_ids = tokenizer(ans_str, add_special_tokens=False, return_attention_mask=False)["input_ids"]
    input_ids = ques_ids + ans_ids

    # State machine to recognize <<expression=answer>> and mask answer
@@ -39,11 +40,11 @@ def tokenize_and_mask(row: Dict[str, str], *, tokenizer, instruction) -> Dict[st
        elif mode == 1 and token in equal_tokens:
            mode = 2
        elif mode == 2:
-            ans_ids[i] = -100
+            ans_ids[i] = IGNORE_INDEX
            if token in end_tokens:
                mode = 0

-    labels = [-100] * len(ques_ids) + ans_ids
+    labels = [IGNORE_INDEX] * len(ques_ids) + ans_ids

    inputs = {"input_ids": input_ids, "labels": labels}
    return inputs
@@ -54,7 +55,7 @@ def pad_to_max_length(row: Dict[str, list], *, tokenizer, max_length: int) -> Di
    return {
        "input_ids": row["input_ids"] + [tokenizer.pad_token_id] * (max_length - length),
        "attention_mask": [1] * length + [0] * (max_length - length),
-        "labels": row["labels"] + [-100] * (max_length - length),
+        "labels": row["labels"] + [IGNORE_INDEX] * (max_length - length),
    }
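
Swapping the literal -100 for IGNORE_INDEX does not change behaviour here; the state machine still masks whatever the calculator returns inside <<expression=answer>>, so the loss covers the expression but not the precomputed result. A word-level toy version of that masking (the real code works on tokenizer ids and token sets, so this is illustrative only):

IGNORE_INDEX = -100  # same sentinel as QEfficient.finetune.dataset.helper


def mask_calculator_answers(tokens):
    labels = list(tokens)
    mode = 0  # 0: outside <<...>>, 1: inside the expression, 2: past '='
    for i, tok in enumerate(tokens):
        if mode == 0 and tok == "<<":
            mode = 1
        elif mode == 1 and tok == "=":
            mode = 2
        elif mode == 2:
            labels[i] = IGNORE_INDEX  # hide the calculator's answer from the loss
            if tok == ">>":
                mode = 0
    return labels


print(mask_calculator_answers(["<<", "4", "*", "2", "=", "8", ">>", "so", "8"]))
# ['<<', '4', '*', '2', '=', -100, -100, 'so', '8']

In the full function the question tokens are additionally masked via [IGNORE_INDEX] * len(ques_ids) before being concatenated with the answer labels.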

QEfficient/finetune/dataset/helper.py

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+IGNORE_INDEX = -100
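
The value -100 is not arbitrary: it matches the default ignore_index of torch.nn.CrossEntropyLoss, so every position labelled IGNORE_INDEX (prompt tokens, padding, masked calculator answers) simply drops out of the loss. A small self-contained check:

import torch

IGNORE_INDEX = -100

logits = torch.randn(4, 10)  # 4 token positions, vocabulary of 10
labels = torch.tensor([IGNORE_INDEX, IGNORE_INDEX, 3, 7])

loss_fn = torch.nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)
print(loss_fn(logits, labels))  # averaged over the two unmasked positions only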

QEfficient/finetune/dataset/samsum_dataset.py

Lines changed: 9 additions & 2 deletions
@@ -7,9 +7,11 @@

 import datasets

+from QEfficient.finetune.dataset.helper import IGNORE_INDEX
+

 def get_preprocessed_samsum(dataset_config, tokenizer, split, context_length=None):
-    dataset = datasets.load_dataset("Samsung/samsum", split=split, trust_remote_code=True)
+    dataset = datasets.load_dataset("knkarthick/samsum", split=split, trust_remote_code=True)

     prompt = "Summarize this dialog:\n{dialog}\n---\nSummary:\n"

@@ -35,10 +37,15 @@ def tokenize_add_label(sample):
            pad_to_max_length=True,
        )

+        labels = [IGNORE_INDEX] * len(prompt) + summary
+        # labels = [l if l != tokenizer.pad_token_id else -100 for l in labels]
+        # sentence: <bos> <prompt> <summary> <eos> <pad>
+        # labels : -100 -100 <summary> <eos> -100
+
        sample = {
            "input_ids": prompt + summary,
            "attention_mask": [1] * (len(prompt) + len(summary)),
-            "labels": [-100] * len(prompt) + summary,
+            "labels": labels,
        }

        return sample

QEfficient/finetune/utils/config_utils.py

Lines changed: 18 additions & 3 deletions
@@ -115,10 +115,26 @@ def generate_dataset_config(dataset_name: str) -> Any:
    return dataset_config


+def pad_dataset(dataset, batch_size, num_replicas):
+    reminder = len(dataset) % (batch_size * num_replicas)
+    if reminder == 0:
+        return dataset
+
+    sample_input = dataset[0]
+    sample_input["labels"] = [-100] * len(sample_input["labels"])
+    num_pads = (batch_size * num_replicas) - reminder
+    for _ in range(num_pads):
+        dataset = dataset.add_item(sample_input)
+    return dataset
+
+
 def get_dataloader_kwargs(train_config, dataset, dataset_processer, mode):
    kwargs = {}
    batch_size = train_config.batch_size_training if mode == "train" else train_config.val_batch_size
    if train_config.enable_ddp:
+        print("Length of dataset before: ", len(dataset))
+        dataset = pad_dataset(dataset, batch_size, 2)
+        print("Length of dataset after: ", len(dataset))
        if train_config.enable_sorting_for_ddp:
            if train_config.context_length:
                raise ValueError(
@@ -134,13 +150,12 @@ def get_dataloader_kwargs(train_config, dataset, dataset_processer, mode):
                )
            else:
                kwargs["sampler"] = data_utils.DistributedSampler(
-                    dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True
+                    dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True, drop_last=False
                )
            kwargs["batch_size"] = batch_size
-            kwargs["drop_last"] = True
    else:
        kwargs["batch_size"] = batch_size
-        kwargs["drop_last"] = True
+        kwargs["drop_last"] = False
    kwargs["collate_fn"] = DataCollatorForSeq2Seq(dataset_processer)
    return kwargs
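
With drop_last now False everywhere, pad_dataset keeps the ranks in lockstep by growing the dataset to the next multiple of batch_size * num_replicas; the padding samples copy an existing row but fully mask its labels, so they add no gradient signal. A sketch of just the size arithmetic (numbers are illustrative and not from this diff):

def num_padding_samples(dataset_len, batch_size, num_replicas):
    # How many dummy samples pad_dataset would append so that dataset_len
    # becomes a multiple of batch_size * num_replicas.
    remainder = dataset_len % (batch_size * num_replicas)
    return 0 if remainder == 0 else (batch_size * num_replicas) - remainder


print(num_padding_samples(103, batch_size=4, num_replicas=2))  # 1 -> padded length 104
print(num_padding_samples(96, batch_size=4, num_replicas=2))   # 0 -> already divisible

A padded length of 104 splits into 13 batches of 4 per rank across 2 ranks, so neither sampler runs out of batches before the other.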

QEfficient/finetune/utils/helper.py

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import os
+
+
+def print_rank_0(msg):
+    if os.getenv("LOCAL_RANK", None) in [None, 0]:
+        print(msg)
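
A possible call site for the new helper (illustrative usage, not part of this diff): it is meant for log lines that should not be repeated by every DDP worker, and it decides by inspecting the LOCAL_RANK environment variable that launchers such as torchrun set for each local process.

from QEfficient.finetune.utils.helper import print_rank_0

# Intended for messages that should appear once per run rather than once per
# DDP worker; the helper inspects the LOCAL_RANK environment variable.
print_rank_0("Trainable parameter summary follows ...")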
