|
| 1 | +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +import os |
| 15 | +import argparse |
| 16 | +import random |
| 17 | +import time |
| 18 | +import distutils.util |
| 19 | +from pprint import pprint |
| 20 | +from functools import partial |
| 21 | +from tqdm import tqdm |
| 22 | +import numpy as np |
| 23 | + |
| 24 | +import paddle |
| 25 | +import paddle.nn as nn |
| 26 | +from paddle.io import BatchSampler, DistributedBatchSampler, DataLoader |
| 27 | +from paddlenlp.transformers import T5ForConditionalGeneration, T5Tokenizer |
| 28 | +from paddlenlp.transformers import LinearDecayWithWarmup |
| 29 | +from paddlenlp.utils.log import logger |
| 30 | +from paddlenlp.datasets import load_dataset |
| 31 | +from paddlenlp.data import Tuple, Stack, Pad |
| 32 | +from utils import convert_example, compute_metrics |
| 33 | + |
| 34 | + |
| 35 | +def parse_args(): |
| 36 | + parser = argparse.ArgumentParser() |
| 37 | + # Required parameters |
| 38 | + parser.add_argument("--model_name_or_path", |
| 39 | + default="t5-base", |
| 40 | + type=str, |
| 41 | + required=True, |
| 42 | + help="Path to pre-trained model. ") |
| 43 | + parser.add_argument( |
| 44 | + "--dataset_name", |
| 45 | + default="squad", |
| 46 | + type=str, |
| 47 | + required=True, |
| 48 | + help="The name of the dataset to use. Selected in the list: " + "squad") |
| 49 | + parser.add_argument( |
| 50 | + "--output_dir", |
| 51 | + default="output", |
| 52 | + type=str, |
| 53 | + required=True, |
| 54 | + help= |
| 55 | + "The output directory where the model predictions and checkpoints will be written.", |
| 56 | + ) |
| 57 | + parser.add_argument( |
| 58 | + "--max_source_length", |
| 59 | + default=1024, |
| 60 | + type=int, |
| 61 | + help="The maximum total input sequence length after " |
| 62 | + "tokenization.Sequences longer than this will be truncated, sequences shorter will be padded.", |
| 63 | + ) |
| 64 | + parser.add_argument( |
| 65 | + "--min_target_length", |
| 66 | + default=0, |
| 67 | + type=int, |
| 68 | + help= |
| 69 | + "The minimum total sequence length for target text when generating. ") |
| 70 | + parser.add_argument( |
| 71 | + "--max_target_length", |
| 72 | + default=142, |
| 73 | + type=int, |
| 74 | + help="The maximum total sequence length for target text after " |
| 75 | + "tokenization. Sequences longer than this will be truncated, sequences shorter will be padded." |
| 76 | + "during ``evaluate`` and ``predict``.", |
| 77 | + ) |
| 78 | + parser.add_argument("--learning_rate", |
| 79 | + default=1e-4, |
| 80 | + type=float, |
| 81 | + help="The initial learning rate for Adam.") |
| 82 | + parser.add_argument( |
| 83 | + "--num_train_epochs", |
| 84 | + default=3, |
| 85 | + type=int, |
| 86 | + help="Total number of training epochs to perform.", |
| 87 | + ) |
| 88 | + parser.add_argument("--logging_steps", |
| 89 | + type=int, |
| 90 | + default=100, |
| 91 | + help="Log every X updates steps.") |
| 92 | + parser.add_argument("--save_steps", |
| 93 | + type=int, |
| 94 | + default=100, |
| 95 | + help="Save checkpoint every X updates steps.") |
| 96 | + parser.add_argument( |
| 97 | + "--train_batch_size", |
| 98 | + default=20, |
| 99 | + type=int, |
| 100 | + help="Batch size per GPU/CPU for training.", |
| 101 | + ) |
| 102 | + parser.add_argument( |
| 103 | + "--eval_batch_size", |
| 104 | + default=12, |
| 105 | + type=int, |
| 106 | + help="Batch size per GPU/CPU for evaluation.", |
| 107 | + ) |
| 108 | + parser.add_argument("--weight_decay", |
| 109 | + default=0.0, |
| 110 | + type=float, |
| 111 | + help="Weight decay if we apply some.") |
| 112 | + parser.add_argument( |
| 113 | + "--warmup_steps", |
| 114 | + default=0, |
| 115 | + type=int, |
| 116 | + help= |
| 117 | + "Linear warmup over warmup_steps. If > 0: Override warmup_proportion") |
| 118 | + parser.add_argument("--warmup_proportion", |
| 119 | + default=0.1, |
| 120 | + type=float, |
| 121 | + help="Linear warmup proportion over total steps.") |
| 122 | + parser.add_argument("--adam_epsilon", |
| 123 | + default=1e-6, |
| 124 | + type=float, |
| 125 | + help="Epsilon for Adam optimizer.") |
| 126 | + parser.add_argument( |
| 127 | + "--max_steps", |
| 128 | + default=-1, |
| 129 | + type=int, |
| 130 | + help= |
| 131 | + "If > 0: set total number of training steps to perform. Override num_train_epochs.", |
| 132 | + ) |
| 133 | + parser.add_argument("--seed", |
| 134 | + default=42, |
| 135 | + type=int, |
| 136 | + help="random seed for initialization") |
| 137 | + parser.add_argument( |
| 138 | + "--device", |
| 139 | + default="gpu", |
| 140 | + type=str, |
| 141 | + choices=["cpu", "gpu", "xpu"], |
| 142 | + help="The device to select to train the model, is must be cpu/gpu/xpu.") |
| 143 | + parser.add_argument("--use_amp", |
| 144 | + default=False, |
| 145 | + type=distutils.util.strtobool, |
| 146 | + help="Enable mixed precision training.") |
| 147 | + parser.add_argument("--scale_loss", |
| 148 | + default=2**15, |
| 149 | + type=float, |
| 150 | + help="The value of scale_loss for fp16.") |
| 151 | + args = parser.parse_args() |
| 152 | + return args |
| 153 | + |
| 154 | + |
| 155 | +def set_seed(args): |
| 156 | + # Use the same data seed(for data shuffle) for all procs to guarantee data |
| 157 | + # consistency after sharding. |
| 158 | + random.seed(args.seed) |
| 159 | + np.random.seed(args.seed) |
| 160 | + # Maybe different op seeds(for dropout) for different procs is better. By: |
| 161 | + # `paddle.seed(args.seed + paddle.distributed.get_rank())` |
| 162 | + paddle.seed(args.seed) |
| 163 | + |
| 164 | + |
| 165 | +@paddle.no_grad() |
| 166 | +def evaluate(model, data_loader, tokenizer, ignore_pad_token_for_loss, |
| 167 | + min_target_length, max_target_length): |
| 168 | + model.eval() |
| 169 | + all_preds = [] |
| 170 | + all_labels = [] |
| 171 | + model = model._layers if isinstance(model, paddle.DataParallel) else model |
| 172 | + for batch in tqdm(data_loader, total=len(data_loader), desc="Eval step"): |
| 173 | + input_ids, _, _, labels = batch |
| 174 | + preds = model.generate(input_ids=input_ids, |
| 175 | + min_length=min_target_length, |
| 176 | + max_length=max_target_length, |
| 177 | + use_cache=True)[0] |
| 178 | + all_preds.extend(preds.numpy()) |
| 179 | + all_labels.extend(labels.numpy()) |
| 180 | + bleu_result, decoded_preds, decoded_labels = compute_metrics( |
| 181 | + all_preds, all_labels, tokenizer, ignore_pad_token_for_loss) |
| 182 | + logger.info(bleu_result) |
| 183 | + model.train() |
| 184 | + |
| 185 | + |
| 186 | +def do_train(args): |
| 187 | + paddle.set_device(args.device) |
| 188 | + if paddle.distributed.get_world_size() > 1: |
| 189 | + paddle.distributed.init_parallel_env() |
| 190 | + |
| 191 | + set_seed(args) |
| 192 | + tokenizer = T5Tokenizer.from_pretrained(args.model_name_or_path) |
| 193 | + model = T5ForConditionalGeneration.from_pretrained(args.model_name_or_path) |
| 194 | + trans_func = partial( |
| 195 | + convert_example, |
| 196 | + tokenizer=tokenizer, |
| 197 | + decoder_start_token_id=model.t5.bos_token_id, |
| 198 | + max_source_length=args.max_source_length, |
| 199 | + max_target_length=args.max_target_length, |
| 200 | + ignore_pad_token_for_loss=args.ignore_pad_token_for_loss) |
| 201 | + logger.info("Loading train and dev dataset: %s" % args.dataset_name) |
| 202 | + train_set, dev_set = load_dataset(args.dataset_name, |
| 203 | + splits=["train_v1", "dev_v1"]) |
| 204 | + logger.info("Loaded train and dev dataset: %s" % args.dataset_name) |
| 205 | + train_set = train_set.map(trans_func, lazy=True) |
| 206 | + train_batch_sampler = DistributedBatchSampler( |
| 207 | + train_set, batch_size=args.train_batch_size, shuffle=True) |
| 208 | + |
| 209 | + batchify_fn = lambda samples, fn=Tuple( |
| 210 | + Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype="int64"), # input_ids |
| 211 | + Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype="int64" |
| 212 | + ), # attention_mask |
| 213 | + Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype="int64" |
| 214 | + ), # decoder_input_ids |
| 215 | + Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype="int64"), # labels |
| 216 | + ): fn(samples) |
| 217 | + train_data_loader = DataLoader(dataset=train_set, |
| 218 | + batch_sampler=train_batch_sampler, |
| 219 | + num_workers=0, |
| 220 | + collate_fn=batchify_fn, |
| 221 | + return_list=True) |
| 222 | + dev_set = dev_set.map(trans_func, lazy=True) |
| 223 | + dev_batch_sampler = BatchSampler(dev_set, |
| 224 | + batch_size=args.eval_batch_size, |
| 225 | + shuffle=False) |
| 226 | + dev_data_loader = DataLoader(dataset=dev_set, |
| 227 | + batch_sampler=dev_batch_sampler, |
| 228 | + num_workers=0, |
| 229 | + collate_fn=batchify_fn, |
| 230 | + return_list=True) |
| 231 | + |
| 232 | + if paddle.distributed.get_world_size() > 1: |
| 233 | + model = paddle.DataParallel(model) |
| 234 | + |
| 235 | + num_training_steps = args.max_steps if args.max_steps > 0 else ( |
| 236 | + len(train_data_loader) * args.num_train_epochs) |
| 237 | + warmup = args.warmup_steps if args.warmup_steps > 0 else args.warmup_proportion |
| 238 | + |
| 239 | + lr_scheduler = LinearDecayWithWarmup(args.learning_rate, num_training_steps, |
| 240 | + warmup) |
| 241 | + |
| 242 | + # Generate parameter names needed to perform weight decay. |
| 243 | + # All bias and LayerNorm parameters are excluded. |
| 244 | + decay_params = [ |
| 245 | + p.name for n, p in model.named_parameters() |
| 246 | + if not any(nd in n for nd in ["bias", "norm"]) |
| 247 | + ] |
| 248 | + optimizer = paddle.optimizer.AdamW( |
| 249 | + learning_rate=lr_scheduler, |
| 250 | + beta1=0.9, |
| 251 | + beta2=0.999, |
| 252 | + epsilon=args.adam_epsilon, |
| 253 | + parameters=model.parameters(), |
| 254 | + weight_decay=args.weight_decay, |
| 255 | + apply_decay_param_fun=lambda x: x in decay_params) |
| 256 | + |
| 257 | + if args.use_amp: |
| 258 | + scaler = paddle.amp.GradScaler(init_loss_scaling=args.scale_loss) |
| 259 | + global_step = 0 |
| 260 | + tic_train = time.time() |
| 261 | + for epoch in tqdm(range(args.num_train_epochs), desc="Epoch"): |
| 262 | + for step, batch in tqdm(enumerate(train_data_loader), |
| 263 | + desc="Train step", |
| 264 | + total=len(train_data_loader)): |
| 265 | + global_step += 1 |
| 266 | + input_ids, attention_mask, decoder_input_ids, labels = batch |
| 267 | + with paddle.amp.auto_cast( |
| 268 | + args.use_amp, |
| 269 | + custom_white_list=["layer_norm", "softmax", "gelu"]): |
| 270 | + output = model(input_ids, |
| 271 | + attention_mask, |
| 272 | + decoder_input_ids, |
| 273 | + labels=labels) |
| 274 | + loss = output[0] |
| 275 | + if args.use_amp: |
| 276 | + scaled_loss = scaler.scale(loss) |
| 277 | + scaled_loss.backward() |
| 278 | + scaler.minimize(optimizer, scaled_loss) |
| 279 | + else: |
| 280 | + loss.backward() |
| 281 | + optimizer.step() |
| 282 | + lr_scheduler.step() |
| 283 | + optimizer.clear_grad() |
| 284 | + if global_step % args.logging_steps == 0: |
| 285 | + logger.info( |
| 286 | + "global step %d/%d, epoch: %d, batch: %d, rank_id: %s, loss: %f, lr: %.10f, speed: %.4f step/s" |
| 287 | + % (global_step, num_training_steps, epoch, step, |
| 288 | + paddle.distributed.get_rank(), loss, optimizer.get_lr(), |
| 289 | + args.logging_steps / (time.time() - tic_train))) |
| 290 | + tic_train = time.time() |
| 291 | + if global_step % args.save_steps == 0 or global_step == num_training_steps: |
| 292 | + tic_eval = time.time() |
| 293 | + evaluate(model, dev_data_loader, tokenizer, |
| 294 | + args.ignore_pad_token_for_loss, args.min_target_length, |
| 295 | + args.max_target_length) |
| 296 | + logger.info("eval done total : %s s" % (time.time() - tic_eval)) |
| 297 | + if paddle.distributed.get_rank() == 0: |
| 298 | + output_dir = os.path.join( |
| 299 | + args.output_dir, "t5_model_%d.pdparams" % global_step) |
| 300 | + if not os.path.exists(output_dir): |
| 301 | + os.makedirs(output_dir) |
| 302 | + # Need better way to get inner model of DataParallel |
| 303 | + model_to_save = model._layers if isinstance( |
| 304 | + model, paddle.DataParallel) else model |
| 305 | + model_to_save.save_pretrained(output_dir) |
| 306 | + tokenizer.save_pretrained(output_dir) |
| 307 | + if global_step >= num_training_steps: |
| 308 | + return |
| 309 | + if paddle.distributed.get_rank() == 0: |
| 310 | + output_dir = os.path.join(args.output_dir, |
| 311 | + "t5_model_final_%d.pdparams" % global_step) |
| 312 | + if not os.path.exists(output_dir): |
| 313 | + os.makedirs(output_dir) |
| 314 | + # Need better way to get inner model of DataParallel |
| 315 | + model_to_save = model._layers if isinstance( |
| 316 | + model, paddle.DataParallel) else model |
| 317 | + model_to_save.save_pretrained(output_dir) |
| 318 | + tokenizer.save_pretrained(output_dir) |
| 319 | + |
| 320 | + |
| 321 | +if __name__ == "__main__": |
| 322 | + args = parse_args() |
| 323 | + pprint(args) |
| 324 | + do_train(args) |
0 commit comments