
Commit 23ef2a7

Authored Feb 20, 2023
[ERNIE-M] Fix layerwise decay ratio to model_zoo/ernie-m (PaddlePaddle#4892)
1 parent c5d9c9e commit 23ef2a7
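
For context: the patch below drops the old `LinearDecayWithWarmup` + `paddlenlp.ops.optimizer.layerwise_lr_decay` + custom `AdamW` wiring and instead builds a per-parameter list of learning-rate ratios that is handed to the `Trainer` via `set_optimizer_grouped_parameters`. Each encoder layer `k` is scaled by `layerwise_decay ** (n_layers - k)` and the embeddings by `layerwise_decay ** (n_layers + 1)`, so lower layers train with a smaller effective learning rate. A minimal sketch of how those ratios fall off, assuming a 12-layer encoder and `layerwise_decay = 0.8` (both values are illustrative, not from the patch):

```python
# Illustrative only: per-depth learning-rate ratios as computed in
# using_layerwise_lr_decay() below, for an assumed 12-layer encoder
# and an assumed layerwise_decay of 0.8.
layerwise_decay = 0.8
n_layers = 12

# Embeddings sit below the first encoder layer, so they get the smallest ratio.
embedding_ratio = layerwise_decay ** (n_layers + 1)

# Encoder layer k is scaled by layerwise_decay ** (n_layers - k):
# higher (later) layers keep a learning rate closer to the base rate.
layer_ratios = {k: layerwise_decay ** (n_layers - k) for k in range(n_layers)}

print(f"embeddings: {embedding_ratio:.4f}")   # ~0.0550
print(f"layer 0:    {layer_ratios[0]:.4f}")   # ~0.0687
print(f"layer 11:   {layer_ratios[11]:.4f}")  # 0.8000
```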

File tree

1 file changed (+34 −38 lines)

 

model_zoo/ernie-m/run_classifier.py (+34 −38)
@@ -24,11 +24,9 @@
 from datasets import load_dataset
 from paddle.io import Dataset
 from paddle.metric import Accuracy
-from paddle.optimizer import AdamW
 
 import paddlenlp
 from paddlenlp.data import DataCollatorWithPadding
-from paddlenlp.ops.optimizer import layerwise_lr_decay
 from paddlenlp.trainer import (
     PdArgumentParser,
     Trainer,
@@ -39,7 +37,6 @@
     AutoModelForSequenceClassification,
     AutoTokenizer,
     ErnieMForSequenceClassification,
-    LinearDecayWithWarmup,
 )
 from paddlenlp.utils.log import logger
 
@@ -220,40 +217,6 @@ def compute_metrics(p):
         accu = metric.accumulate()
         return {"accuracy": accu}
 
-    n_layers = model.ernie_m.config["num_hidden_layers"]
-    warmup = training_args.warmup_steps if training_args.warmup_steps > 0 else training_args.warmup_ratio
-    if training_args.do_train:
-        num_training_steps = (
-            training_args.max_steps
-            if training_args.max_steps > 0
-            else len(train_ds) // training_args.train_batch_size * training_args.num_train_epochs
-        )
-    else:
-        num_training_steps = 10
-
-    lr_scheduler = LinearDecayWithWarmup(training_args.learning_rate, num_training_steps, warmup)
-
-    # Generate parameter names needed to perform weight decay.
-    # All bias and LayerNorm parameters are excluded.
-    decay_params = [p.name for n, p in model.named_parameters() if not any(nd in n for nd in ["bias", "norm"])]
-    # Construct dict
-    name_dict = dict()
-    for n, p in model.named_parameters():
-        name_dict[p.name] = n
-
-    simple_lr_setting = partial(layerwise_lr_decay, model_args.layerwise_decay, name_dict, n_layers)
-
-    optimizer = AdamW(
-        learning_rate=lr_scheduler,
-        beta1=0.9,
-        beta2=0.999,
-        epsilon=training_args.adam_epsilon,
-        parameters=model.parameters(),
-        weight_decay=training_args.weight_decay,
-        apply_decay_param_fun=lambda x: x in decay_params,
-        lr_ratio=simple_lr_setting,
-    )
-
     trainer = Trainer(
         model=model,
         args=training_args,
@@ -262,9 +225,41 @@ def compute_metrics(p):
         eval_dataset=eval_ds if training_args.do_eval else None,
         tokenizer=tokenizer,
         compute_metrics=compute_metrics,
-        optimizers=[optimizer, lr_scheduler],
+        # optimizers=[optimizer, lr_scheduler],
     )
 
+    def using_layerwise_lr_decay(layerwise_decay, model, training_args):
+        """
+        Generate parameter names needed to perform weight decay.
+        All bias and LayerNorm parameters are excluded.
+        """
+        # params_list = [{"params": param, "learning_rate": lr * decay_ratio}, ... ]
+        params_list = []
+        n_layers = model.config.num_hidden_layers
+        for name, param in model.named_parameters():
+            ratio = 1.0
+            param_to_train = {"params": param, "dygraph_key_name": name}
+            if any(nd in name for nd in ["bias", "norm"]):
+                param_to_train["weight_decay"] = 0.0
+            else:
+                param_to_train["weight_decay"] = training_args.weight_decay
+
+            if "encoder.layers" in name:
+                idx = name.find("encoder.layers.")
+                layer = int(name[idx:].split(".")[2])
+                ratio = layerwise_decay ** (n_layers - layer)
+            elif "embedding" in name:
+                ratio = layerwise_decay ** (n_layers + 1)
+
+            param_to_train["learning_rate"] = ratio
+
+            params_list.append(param_to_train)
+        return params_list
+
+    params_to_train = using_layerwise_lr_decay(model_args.layerwise_decay, model, training_args)
+
+    trainer.set_optimizer_grouped_parameters(params_to_train)
+
     checkpoint = None
     if training_args.resume_from_checkpoint is not None:
         checkpoint = training_args.resume_from_checkpoint
@@ -320,6 +315,7 @@ def compute_metrics(p):
         paddlenlp.transformers.export_model(
             model=model_to_save, input_spec=input_spec, path=model_args.export_model_dir
         )
+        trainer.tokenizer.save_pretrained(model_args.export_model_dir)
 
 
 if __name__ == "__main__":
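
The layer index in `using_layerwise_lr_decay` is recovered from the parameter name alone. A short sketch of that lookup, using a hypothetical ERNIE-M parameter name (the exact naming scheme and decay settings are assumptions here, not taken from the patch):

```python
# Sketch of the name-based layer lookup in using_layerwise_lr_decay().
# The parameter name and decay settings below are hypothetical examples.
name = "ernie_m.encoder.layers.3.self_attn.q_proj.weight"

idx = name.find("encoder.layers.")
# name[idx:] == "encoder.layers.3.self_attn.q_proj.weight";
# splitting on "." gives ["encoder", "layers", "3", ...], so element 2 is the layer index.
layer = int(name[idx:].split(".")[2])

layerwise_decay, n_layers = 0.8, 12
ratio = layerwise_decay ** (n_layers - layer)  # 0.8 ** 9 ≈ 0.134
print(layer, round(ratio, 3))  # 3 0.134
```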
