## Author: Thomas Capelle, Soumik Rakshit
## Mail: [email protected], [email protected]
""""Benchmarking apple M1Pro with Tensorflow
@wandbcode{apple_m1_pro}"""

import argparse
from types import SimpleNamespace

import torch
import wandb
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoModelForSequenceClassification, AutoTokenizer, default_data_collator

from utils import MicroTrainer, get_gpu_name

PROJECT = "pytorch-M1Pro"
ENTITY = "capecape"
GROUP = "pytorch"
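
# Benchmark defaults; every field can be overridden from the command line (see parse_args below).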
config_defaults = SimpleNamespace(
    batch_size=16,
    epochs=1,
    num_experiments=1,
    learning_rate=1e-3,
    model_name="bert-base-cased",
    dataset="yelp_review_full",
    device="mps",
    gpu_name=get_gpu_name(),
    num_workers=8,
    mixed_precision=False,
    syncro=False,
    inference_only=False,
    compile=False,
)
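

# CLI arguments mirror config_defaults; boolean flags are off unless passed.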
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', type=int, default=config_defaults.batch_size)
    parser.add_argument('--epochs', type=int, default=config_defaults.epochs)
    parser.add_argument('--num_experiments', type=int, default=config_defaults.num_experiments)
    parser.add_argument('--learning_rate', type=float, default=config_defaults.learning_rate)
    parser.add_argument('--model_name', type=str, default=config_defaults.model_name)
    parser.add_argument('--dataset', type=str, default=config_defaults.dataset)
    parser.add_argument('--device', type=str, default=config_defaults.device)
    parser.add_argument('--gpu_name', type=str, default=config_defaults.gpu_name)
    parser.add_argument('--num_workers', type=int, default=config_defaults.num_workers)
    parser.add_argument('--inference_only', action="store_true")
    parser.add_argument('--mixed_precision', action="store_true")
    parser.add_argument('--compile', action="store_true")
    parser.add_argument('--tags', type=str, default=None)
    parser.add_argument('--syncro', action="store_true")
    return parser.parse_args()
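

# Tokenize a Hugging Face text-classification dataset and wrap small fixed-size subsets
# in DataLoaders; sample_size caps each split so benchmark runs stay short.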
def get_dls(model_name="bert-base-cased", dataset_name="yelp_review_full", batch_size=8, num_workers=0, sample_size=2048):
    # download and prepare the dataset (yelp_review_full by default)
    dataset = load_dataset(dataset_name)

    # get the tokenizer matching the model checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    def tokenize_function(examples):
        return tokenizer(examples["text"], padding="max_length", truncation=True)

    # tokenize the dataset
    tokenized_datasets = dataset.map(tokenize_function, batched=True)
    small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(sample_size))
    small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(sample_size))

    train_loader = DataLoader(
        small_train_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        collate_fn=default_data_collator,
    )
    test_loader = DataLoader(
        small_eval_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        collate_fn=default_data_collator,
    )
    return train_loader, test_loader
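

# yelp_review_full rates reviews 1-5 stars, hence num_labels=5 by default.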
def get_model(model_name="bert-base-cased", num_labels=5):
    return AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)
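

# If CUDA is available, override the device and enable mixed precision;
# PyTorch/CUDA version info is recorded in the config for the W&B run.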
def check_cuda(config):
    if torch.cuda.is_available():
        config.device = "cuda"
        config.mixed_precision = True
    config.pt_version = torch.__version__
    config.cuda_version = torch.version.cuda
    return config
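

# One full benchmark run: build the data, (optionally) compile the model, train unless
# --inference_only is set, then run inference, all logged to a single W&B run.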
def train_bert(config):
    train_dl, test_loader = get_dls(
        model_name=config.model_name,
        dataset_name=config.dataset,
        batch_size=config.batch_size,
        num_workers=config.num_workers)
    model = get_model(config.model_name).to(config.device)
    if torch.__version__ >= "2.0" and config.compile:
        print("Compiling model...")
        model = torch.compile(model)
    trainer = MicroTrainer(model, train_dl, device=config.device, mixed_precision=config.mixed_precision, syncro=config.syncro)
    tags = [f"pt{torch.__version__}"]
    if torch.version.cuda is not None:  # avoid a misleading "cudaNone" tag on MPS/CPU runs
        tags.append(f"cuda{torch.version.cuda}")
    if config.tags is not None:
        tags += config.tags.split(",")
    with wandb.init(project=PROJECT, entity=ENTITY, group=GROUP, tags=tags, config=config):
        config = wandb.config
        if not config.inference_only:
            trainer.fit(config.epochs)
        trainer.inference(test_loader)
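

# Repeat the benchmark num_experiments times so run-to-run variance shows up in W&B.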
if __name__ == "__main__":
    args = parse_args()
    args = check_cuda(args)
    for _ in range(args.num_experiments):
        train_bert(config=args)
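
# Example invocations (the flags below are defined in parse_args; utils.py is assumed to
# provide MicroTrainer and get_gpu_name):
#   python train_bert.py                                # defaults: MPS device, batch_size 16
#   python train_bert.py --batch_size 32 --compile      # torch.compile needs PyTorch >= 2.0
#   python train_bert.py --inference_only --tags smoke,m1pro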