train.py
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import logging
logging.basicConfig(level=logging.INFO)
if __name__ == "__main__":
from argparse import ArgumentParser
from pytorch_lightning import Trainer
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
import pytorch_lightning as pl
rank_zero_info("########## work in progress ##########")
parser = ArgumentParser()
parser.add_argument("--load_model", default="", type=str, help="path of rwkv model") # full path, with .pth
parser.add_argument("--model_path", type=str, default=None, help="path of visualrwkv model") #
parser.add_argument("--wandb", default="", type=str) # wandb project name. if "" then don't use wandb
parser.add_argument("--proj_dir", default="out", type=str)
parser.add_argument("--run_name", default='demo_run', type=str,
help="run name for wandb. force to consider what is the purpose of this run")
parser.add_argument("--random_seed", default="-1", type=int)
parser.add_argument("--data_file", default="", type=str)
parser.add_argument("--data_type", default="utf-8", type=str)
parser.add_argument("--vocab_size", default=0, type=int) # vocab_size = 0 means auto (for char-level LM and .txt data)
parser.add_argument("--ctx_len", default=1024, type=int)
parser.add_argument("--epoch_steps", default=1000, type=int) # a mini "epoch" has [epoch_steps] steps
parser.add_argument("--epoch_count", default=500, type=int) # train for this many "epochs". will continue afterwards with lr = lr_final
parser.add_argument("--epoch_begin", default=0, type=int) # if you load a model trained for x "epochs", set epoch_begin = x
parser.add_argument("--epoch_save", default=5, type=int) # save the model every [epoch_save] "epochs"
parser.add_argument("--micro_bsz", default=12, type=int) # micro batch size (batch size per GPU)
parser.add_argument("--n_layer", default=6, type=int)
parser.add_argument("--n_embd", default=512, type=int)
parser.add_argument("--dim_att", default=0, type=int)
parser.add_argument("--dim_ffn", default=0, type=int)
parser.add_argument("--pre_ffn", default=0, type=int) # replace first att layer by ffn (sometimes better)
parser.add_argument("--head_size_a", default=64, type=int)
parser.add_argument("--head_size_divisor", default=8, type=int)
parser.add_argument("--lr_init", default=6e-4, type=float) # 6e-4 for L12-D768, 4e-4 for L24-D1024, 3e-4 for L24-D2048
parser.add_argument("--lr_final", default=1e-5, type=float)
parser.add_argument("--warmup_steps", default=-1, type=int) # try 50 if you load a model
parser.add_argument("--beta1", default=0.9, type=float)
parser.add_argument("--beta2", default=0.99, type=float) # use 0.999 when your model is close to convergence
parser.add_argument("--adam_eps", default=1e-8, type=float)
parser.add_argument("--grad_cp", default=0, type=int) # gradient checkpt: saves VRAM, but slower
parser.add_argument("--dropout", default=0, type=float) # try 0.01 / 0.02 / 0.05 / 0.1
parser.add_argument("--weight_decay", default=0, type=float) # try 0.1 / 0.01 / 0.001
parser.add_argument("--weight_decay_final", default=-1, type=float)
parser.add_argument("--ds_bucket_mb", default=200, type=int) # deepspeed bucket size in MB. 200 seems enough
parser.add_argument("--vision_tower_name", default="openai/clip-vit-base-patch32", type=str) # openai/clip-vit-base-patch32
parser.add_argument("--image_folder", type=str, default="images")
parser.add_argument("--grid_size", type=int, default=8) # -1 for no grid, 0 for cls token, 1 for global avg, 8 for 64 tokens
parser.add_argument("--detail", type=str, default="low")
parser.add_argument("--freeze_rwkv", default=0, type=int) # layers to freeze
parser.add_argument("--freeze_proj", default=0, type=int) # freeze proj layer
parser.add_argument("--image_position", default='first', type=str) # 'first' or 'last' or ''middle
parser.add_argument("--print_param_shape", default=0, type=int) # print param shape
parser.add_argument("--max_spots", default=10, type=int)
parser.add_argument("--max_new_tokens", default=128, type=int)
parser.add_argument("--stage", default=1, type=int)
parser.add_argument("--pin_memory", default=True, type=bool)
parser = Trainer.add_argparse_args(parser)
args = parser.parse_args()
args.inference_mode = False
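# Example launch (hypothetical paths; flags come from the parser above plus Trainer.add_argparse_args):
# python train.py --load_model rwkv-base.pth --data_file train.json --data_type json \
#   --proj_dir out --num_nodes 1 --devices 8 --accelerator gpu --precision bf16 \
#   --strategy deepspeed_stage_1 --micro_bsz 12 --n_layer 24 --n_embd 2048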
########################################################################################################
import os, warnings, math, datetime, sys, time
import numpy as np
import torch
from torch.utils.data import DataLoader
if "deepspeed" in args.strategy:
import deepspeed
from pytorch_lightning import seed_everything
# os.environ['TORCH_USE_CUDA_DSA'] = '1'
# torch.backends.cudnn.benchmark=False
# torch.backends.cudnn.deterministic=True
# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
# os.environ["TORCH_USE_CUDA_DSA"] = '1'
# print("VISIBLE_DEVICES: ", os.environ["CUDA_VISIBLE_DEVICES"])
# print("DEVICE_COUNT: ", torch.cuda.device_count())
# print("CURRENT_DEVICE: ", torch.cuda.current_device())
if args.random_seed >= 0:
print(f"########## WARNING: GLOBAL SEED {args.random_seed} THIS WILL AFFECT MULTIGPU SAMPLING ##########\n" * 3)
seed_everything(args.random_seed)
np.set_printoptions(precision=4, suppress=True, linewidth=200)
warnings.filterwarnings("ignore", ".*Consider increasing the value of the `num_workers` argument*")
warnings.filterwarnings("ignore", ".*The progress bar already tracks a metric with the*")
# os.environ["WDS_SHOW_SEED"] = "1"
args.my_timestamp = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S")
args.enable_checkpointing = False
args.replace_sampler_ddp = False
args.logger = False
args.gradient_clip_val = 1.0
args.num_sanity_val_steps = 0
args.check_val_every_n_epoch = int(1e20)
args.log_every_n_steps = int(1e20)
args.max_epochs = args.epoch_count # stop after epoch_count "epochs"
args.betas = (args.beta1, args.beta2)
args.real_bsz = int(args.num_nodes) * int(args.devices) * args.micro_bsz
os.environ["RWKV_CTXLEN"] = str(args.ctx_len)
os.environ["RWKV_HEAD_SIZE_A"] = str(args.head_size_a)
if args.dim_att <= 0:
args.dim_att = args.n_embd
if args.dim_ffn <= 0:
args.dim_ffn = int((args.n_embd * 3.5) // 32 * 32) # default = 3.5x emb size
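# e.g. the default n_embd=512 gives 512 * 3.5 = 1792, already a multiple of 32, so dim_ffn = 1792;
# the // 32 * 32 floor rounds other sizes down to the nearest multiple of 32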
#args.run_name = f"{args.vocab_size} ctx{args.ctx_len} L{args.n_layer} D{args.n_embd}"
if not os.path.exists(args.proj_dir):
os.makedirs(args.proj_dir)
samples_per_epoch = args.epoch_steps * args.real_bsz
tokens_per_epoch = samples_per_epoch * args.ctx_len
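# e.g. 1 node x 8 devices x micro_bsz 12 gives real_bsz = 96, so one "epoch" of 1000 steps
# covers 96,000 samples = 96,000 * 1024 ~= 98.3M tokens at the default ctx_len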
try:
deepspeed_version = deepspeed.__version__
except NameError: # deepspeed is only imported for deepspeed strategies
deepspeed_version = None
rank_zero_info(
f"""
############################################################################
#
# RWKV-5 {args.precision.upper()} on {args.num_nodes}x{args.devices} {args.accelerator.upper()}, bsz {args.num_nodes}x{args.devices}x{args.micro_bsz}={args.real_bsz}, {args.strategy} {'with grad_cp' if args.grad_cp > 0 else ''}
#
# Data = {args.data_file} ({args.data_type}), ProjDir = {args.proj_dir}
#
# Epoch = {args.epoch_begin} to {args.epoch_begin + args.epoch_count - 1}, save every {args.epoch_save} "epochs"
#
# Each "epoch" = {args.epoch_steps} steps, {samples_per_epoch} samples, {tokens_per_epoch} tokens
#
# Model = {args.n_layer} n_layer, {args.n_embd} n_embd, {args.ctx_len} ctx_len
#
# Adam = lr {args.lr_init} to {args.lr_final}, warmup {args.warmup_steps} steps, beta {args.betas}, eps {args.adam_eps}
#
# Found torch {torch.__version__}, recommend 1.13.1+cu117 or newer
# Found deepspeed {deepspeed_version}, recommend 0.7.0 (faster than newer versions)
# Found pytorch_lightning {pl.__version__}, recommend 1.9.5
#
############################################################################
"""
)
rank_zero_info(str(vars(args)) + "\n")
assert args.data_type in ["json"]
assert args.precision in ["fp32", "tf32", "fp16", "bf16"]
os.environ["RWKV_FLOAT_MODE"] = args.precision
if args.precision == "fp32":
for i in range(10):
rank_zero_info("\n\nNote: you are using fp32 (very slow). Try bf16 / tf32 for faster training.\n\n")
if args.precision == "fp16":
rank_zero_info("\n\nNote: you are using fp16 (might overflow). Try bf16 / tf32 for stable training.\n\n")
os.environ["RWKV_JIT_ON"] = "0"
if args.stage == 3:
os.environ["RWKV_JIT_ON"] = "0"
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True
if args.precision == "fp32":
torch.backends.cudnn.allow_tf32 = False
torch.backends.cuda.matmul.allow_tf32 = False
else:
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
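# TF32 runs fp32 matmuls on tensor cores with a truncated 10-bit mantissa: faster, slightly less precise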
# torch.backends.cudnn.allow_tf32 = False
# torch.backends.cuda.matmul.allow_tf32 = False
if "32" in args.precision:
args.precision = "32"
# args.dtype = torch.float32
elif args.precision == "fp16":
args.precision = "16"
# args.dtype = torch.float16
else:
args.precision = "bf16"
# args.dtype = torch.bfloat16
########################################################################################################
from src.trainer import train_callback
from src.dataset import MyDataset
from src.rwkv_tokenizer import TRIE_TOKENIZER
from transformers import AutoImageProcessor
args.tokenizer = TRIE_TOKENIZER("src/rwkv_vocab_v20230424.txt")
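# the TRIE tokenizer follows the RWKV "world" tokenizer interface, e.g. (assuming the
# reference encode/decode API): args.tokenizer.decode(args.tokenizer.encode("hello")) == "hello"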
args.image_processor = AutoImageProcessor.from_pretrained(args.vision_tower_name)
train_data = MyDataset(args)
args.vocab_size = train_data.vocab_size
from src.model_state import RWKV_II
# 256 GB of CPU memory is not enough to build the model for 8 GPUs;
# to run 6 GPUs within 256 GB, cast to .half() to save host memory
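# (fp32 weights cost 4 bytes per parameter vs 2 for fp16, so .half() roughly halves the host-memory footprint)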
model = RWKV_II(args).half()
if args.model_path:
msg = model.load_state_dict(torch.load(args.model_path, map_location='cpu'), strict=True)
model.half()
rank_zero_info(f"loading visual rwkv model from {args.model_path}: {msg}")
if args.freeze_rwkv > 0:
model.freeze_rwkv(args.freeze_rwkv)
if args.freeze_proj > 0:
model.freeze_proj()
model.freeze_emb() # freeze emb all the time
from pytorch_lightning.strategies import DeepSpeedStrategy
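# DeepSpeedStrategy reads its ZeRO settings from config.json in the working directory.
# A minimal sketch of what such a file might contain (assumed, not the repo's actual config):
# {
#   "zero_optimization": {"stage": 1},
#   "fp16": {"enabled": true},
#   "gradient_accumulation_steps": 1,
#   "train_micro_batch_size_per_gpu": 12
# }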
trainer = Trainer.from_argparse_args(args, callbacks=[train_callback(args)], strategy=DeepSpeedStrategy(config="config.json"), precision=16) # explicit strategy/precision here override the command-line flags
if trainer.global_rank == 0 and args.print_param_shape > 0:
for n in model.state_dict():
shape = model.state_dict()[n].shape
shape = [i for i in shape if i != 1]
if len(shape) > 1:
print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {n}")
if len(shape) == 1:
print(f"{str(shape[0]).ljust(5)} {n}")
# must set shuffle=False, persistent_workers=False (DataLoader workers run in separate processes)
data_loader = DataLoader(train_data, shuffle=False, pin_memory=True, batch_size=args.micro_bsz, num_workers=4,
persistent_workers=False, drop_last=True)
print("using DeepSpeedStrategy:", isinstance(trainer.strategy, DeepSpeedStrategy)) # sanity check
trainer.fit(model, data_loader)