# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Run MindFormer."""
import argparse
import os
import shutil

import numpy as np

import mindspore as ms
from mindspore.common import set_seed

from mindformers.tools.register import MindFormerConfig, ActionDict
from mindformers.core.parallel_config import build_parallel_config
from mindformers.tools.utils import str2bool, set_remote_save_url, check_in_modelarts, parse_value
from mindformers.core.context import build_context, build_profile_cb
from mindformers.trainer import build_trainer
from mindformers.tools.cloud_adapter import cloud_monitor
from mindformers.tools.logger import logger
from mindformers.tools import get_output_root_path, set_output_path
from mindformers.mindformer_book import MindFormerBook

if check_in_modelarts():
    import moxing as mox

SUPPORT_MODEL_NAMES = MindFormerBook().get_model_name_support_list()
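
# Illustrative launch commands (the config path is this repo's default; adjust to your setup):
#   python run_mindformer.py --config configs/mae/run_mae_vit_base_p16_224_800ep.yaml --run_mode train
#   python run_mindformer.py --config configs/mae/run_mae_vit_base_p16_224_800ep.yaml \
#       --run_mode predict --predict_data /path/to/image.jpg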


def update_checkpoint_config(config, is_train=True):
    """Update the checkpoint config depending on is_train."""
    if (is_train and config.resume_training) or config.auto_trans_ckpt or os.path.isdir(config.load_checkpoint):
        logger.info("load_checkpoint is kept unchanged because: ")
        logger.info("1. resume training needs the resume-training info. ")
        logger.info("2. a distributed shard checkpoint needs to be loaded. ")
        if not config.load_checkpoint:
            config.load_checkpoint = config.model.model_config.checkpoint_name_or_path
        config.model.model_config.checkpoint_name_or_path = None
    else:
        if config.run_mode in ('train', 'finetune'):
            config.model.model_config.checkpoint_name_or_path = config.load_checkpoint
        elif config.run_mode in ['eval', 'predict'] and config.load_checkpoint:
            config.model.model_config.checkpoint_name_or_path = config.load_checkpoint
        config.load_checkpoint = None
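
# Net effect of update_checkpoint_config (summary of the branches above):
#   resume / auto-transform / directory checkpoint -> config.load_checkpoint is kept
#       and model_config.checkpoint_name_or_path is cleared;
#   otherwise -> checkpoint_name_or_path receives the path and load_checkpoint is cleared.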


def clear_auto_trans_output(config):
    """Clear the transformed_checkpoint and strategy directories."""
    if check_in_modelarts():
        obs_strategy_dir = os.path.join(config.remote_save_url, "strategy")
        if mox.file.exists(obs_strategy_dir) and config.local_rank == 0:
            mox.file.remove(obs_strategy_dir, recursive=True)
        obs_transformed_ckpt_dir = os.path.join(config.remote_save_url, "transformed_checkpoint")
        if mox.file.exists(obs_transformed_ckpt_dir) and config.local_rank == 0:
            mox.file.remove(obs_transformed_ckpt_dir, recursive=True)
        mox.file.make_dirs(obs_strategy_dir)
        mox.file.make_dirs(obs_transformed_ckpt_dir)
    else:
        strategy_dir = os.path.join(get_output_root_path(), "strategy")
        if os.path.exists(strategy_dir) and config.local_rank % 8 == 0:
            shutil.rmtree(strategy_dir)
        transformed_ckpt_dir = os.path.join(get_output_root_path(), "transformed_checkpoint")
        if os.path.exists(transformed_ckpt_dir) and config.local_rank % 8 == 0:
            shutil.rmtree(transformed_ckpt_dir)
        os.makedirs(strategy_dir, exist_ok=True)
        os.makedirs(transformed_ckpt_dir, exist_ok=True)
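
# Directories recreated above, relative to the output root
# (or to remote_save_url when running on ModelArts):
#   strategy/                - parallel strategy files
#   transformed_checkpoint/  - checkpoints rewritten for the new sharding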


def create_task_trainer(config):
    """Build the trainer and dispatch according to config.run_mode."""
    trainer = build_trainer(config.trainer)
    if config.run_mode in ('train', 'finetune'):
        trainer.train(config, is_full_config=True)
    elif config.run_mode == 'eval':
        trainer.evaluate(config, is_full_config=True)
    elif config.run_mode == 'predict':
        trainer.predict(config, is_full_config=True, batch_size=config.predict_batch_size)
    elif config.run_mode == 'export':
        trainer.export(config, is_full_config=True)
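
# Minimal programmatic sketch (assumes a valid YAML config; mirrors the CLI path at the
# bottom of this file):
#   config = MindFormerConfig("configs/mae/run_mae_vit_base_p16_224_800ep.yaml")
#   config.run_mode = 'eval'
#   main(config)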


@cloud_monitor()
def main(config):
    """main."""
    # set output path
    set_output_path(config.output_dir)
    # init context
    build_context(config)

    if config.seed and \
            ms.context.get_auto_parallel_context("parallel_mode") \
            not in ["semi_auto_parallel", "auto_parallel"]:
        set_seed(config.seed)
        np.random.seed(config.seed)

    # build context config
    logger.info(".........Build context config..........")
    if config.run_mode == 'predict':
        if config.use_parallel and \
                config.parallel.parallel_mode in ['semi_auto_parallel', 1] and \
                config.parallel_config.data_parallel != 1:
            raise ValueError("The value of data_parallel can only be 1, since the input batch size is 1.")
    build_parallel_config(config)
    logger.info("context config is: %s", config.parallel_config)
    logger.info("moe config is: %s", config.moe_config)

    if config.run_mode == 'train':
        update_checkpoint_config(config)
    if config.run_mode == 'finetune':
        if not config.load_checkpoint:
            raise ValueError("load_checkpoint must be set when run_mode is finetune.")
        update_checkpoint_config(config)
    if config.run_mode in ['eval', 'predict']:
        update_checkpoint_config(config, is_train=False)

    # remote save url
    if check_in_modelarts() and config.remote_save_url:
        logger.info("remote_save_url is %s, all output files will be uploaded there.", config.remote_save_url)
        set_remote_save_url(config.remote_save_url)

    # define callback and add profile callback
    if config.profile:
        config.profile_cb = build_profile_cb(config)

    if config.auto_trans_ckpt:
        clear_auto_trans_output(config)

    create_task_trainer(config)


if __name__ == "__main__":
    work_path = os.path.dirname(os.path.abspath(__file__))
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--config',
        default="configs/mae/run_mae_vit_base_p16_224_800ep.yaml",
        required=True,
        help='YAML config file')
    parser.add_argument(
        '--mode', default=None, type=int,
        help='running mode: GRAPH_MODE(0) or PYNATIVE_MODE(1); both modes support all backends. '
             'Default: None (use the config value)')
    parser.add_argument(
        '--device_id', default=None, type=int,
        help='ID of the target device; the value must be in [0, device_num_per_host-1], '
             'where device_num_per_host is at most 4096. Default: None')
    parser.add_argument(
        '--device_target', default=None, type=str,
        help='the target device to run on: "Ascend", "GPU", or "CPU". '
             'If not set, the device matching the installed MindSpore package is used. '
             'Default: None')
    parser.add_argument(
        '--run_mode', default=None, type=str,
        help='task running mode, one of [train, finetune, eval, predict]. '
             'Default: None')
    parser.add_argument(
        '--do_eval', default=None, type=str2bool,
        help='whether to evaluate during training. '
             'Default: None')
    parser.add_argument(
        '--train_dataset_dir', default=None, type=str,
        help='dataset directory of the data loader for train/finetune. '
             'Default: None')
    parser.add_argument(
        '--eval_dataset_dir', default=None, type=str,
        help='dataset directory of the data loader for eval. '
             'Default: None')
    parser.add_argument(
        '--predict_data', default=None, type=str, nargs='+',
        help='input data for predict; supports a real data path or a data directory. '
             'Default: None')
    parser.add_argument(
        '--predict_batch_size', default=None, type=int,
        help='batch size of predict data; set this to perform batch prediction. '
             'Default: None')
    parser.add_argument(
        '--load_checkpoint', default=None, type=str,
        help="model checkpoint to load for train/finetune/eval/predict; "
             "a model name such as 'mae_vit_base_p16' is also supported, "
             "see https://gitee.com/mindspore/mindformers#%E4%BB%8B%E7%BB%8D. "
             "Default: None")
    parser.add_argument(
        '--src_strategy_path_or_dir', default=None, type=str,
        help="the strategy of load_checkpoint; "
             "a directory is merged before the checkpoint transform, "
             "a file is used in the transform directly. "
             "Default: None, meaning load_checkpoint is a single whole ckpt, not distributed")
    parser.add_argument(
        '--auto_trans_ckpt', default=None, type=str2bool,
        help="if true, automatically transform load_checkpoint so it can be loaded "
             "into the distributed model.")
    parser.add_argument(
        '--only_save_strategy', default=None, type=str2bool,
        help="if true, exit once the strategy files are saved.")
    parser.add_argument(
        '--resume_training', default=None, type=str2bool,
        help="whether to load training context info, such as the optimizer state and epoch number")
    parser.add_argument(
        '--strategy_load_checkpoint', default=None, type=str,
        help='path of the parallel strategy checkpoint to load; supports a real path or a directory. '
             'Default: None')
    parser.add_argument(
        '--remote_save_url', default=None, type=str,
        help='remote save url to which all output files will be transferred and stored. '
             'Default: None')
    parser.add_argument(
        '--seed', default=None, type=int,
        help='global random seed for train/finetune. '
             'Default: None')
    parser.add_argument(
        '--use_parallel', default=None, type=str2bool,
        help='whether to use parallel mode. Default: None')
    parser.add_argument(
        '--profile', default=None, type=str2bool,
        help='whether to run profiling analysis. Default: None')
    parser.add_argument(
        '--options',
        nargs='+',
        action=ActionDict,
        help='override some settings in the used config; key-value pairs '
             'in xxx=yyy format will be merged into the config file')
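    # Illustrative --options usage (xxx=yyy pairs; these keys appear elsewhere in this file):
    #   --options runner_config.epochs=10 runner_config.sink_mode=True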
    parser.add_argument(
        '--epochs', default=None, type=int,
        help='number of training epochs. '
             'Default: None')
    parser.add_argument(
        '--batch_size', default=None, type=int,
        help='batch size of the datasets. '
             'Default: None')
    parser.add_argument(
        '--sink_mode', default=None, type=str2bool,
        help='whether to use data sink mode. '
             'Default: None')
    parser.add_argument(
        '--num_samples', default=None, type=int,
        help='number of dataset samples to use. '
             'Default: None')
    parser.add_argument(
        '--output_dir', default=None, type=str,
        help='output directory.')
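
    # Values parsed below override the corresponding YAML settings only when explicitly provided.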
    args_, rest_args_ = parser.parse_known_args()
    rest_args_ = [i for item in rest_args_ for i in item.split("=")]
    if len(rest_args_) % 2 != 0:
        raise ValueError("input arg key-values are not in pairs, please check the input args.")
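
    # Unrecognized arguments survive in rest_args_ and are merged by the while-loop at the
    # bottom of this block as dotted-key overrides, e.g. (illustrative key):
    #   python run_mindformer.py --config ... --model.model_config.num_layers 24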
    if args_.config is not None:
        args_.config = os.path.join(work_path, args_.config)
    config_ = MindFormerConfig(args_.config)
    if args_.device_id is not None:
        config_.context.device_id = args_.device_id
    if args_.device_target is not None:
        config_.context.device_target = args_.device_target
    if args_.mode is not None:
        config_.context.mode = args_.mode
    if args_.run_mode is not None:
        config_.run_mode = args_.run_mode
    if args_.do_eval is not None:
        config_.do_eval = args_.do_eval
    if args_.seed is not None:
        config_.seed = args_.seed
    if args_.use_parallel is not None:
        config_.use_parallel = args_.use_parallel
    if args_.load_checkpoint is not None:
        config_.load_checkpoint = args_.load_checkpoint
    if args_.src_strategy_path_or_dir is not None:
        config_.src_strategy_path_or_dir = args_.src_strategy_path_or_dir
    if args_.auto_trans_ckpt is not None:
        config_.auto_trans_ckpt = args_.auto_trans_ckpt
    if args_.only_save_strategy is not None:
        config_.only_save_strategy = args_.only_save_strategy
    if args_.resume_training is not None:
        config_.resume_training = args_.resume_training
    if args_.strategy_load_checkpoint is not None:
        if os.path.isdir(args_.strategy_load_checkpoint):
            ckpt_list = [os.path.join(args_.strategy_load_checkpoint, file)
                         for file in os.listdir(args_.strategy_load_checkpoint) if file.endswith(".ckpt")]
            args_.strategy_load_checkpoint = ckpt_list[0]
        config_.parallel.strategy_ckpt_load_file = args_.strategy_load_checkpoint
    if args_.remote_save_url is not None:
        config_.remote_save_url = args_.remote_save_url
    if args_.profile is not None:
        config_.profile = args_.profile
    if args_.options is not None:
        config_.merge_from_dict(args_.options)
    assert config_.run_mode in ['train', 'eval', 'predict', 'finetune', 'export'], \
        f"run_mode must be in {['train', 'eval', 'predict', 'finetune', 'export']}, but got {config_.run_mode}"
    if args_.train_dataset_dir:
        config_.train_dataset.data_loader.dataset_dir = args_.train_dataset_dir
    if args_.eval_dataset_dir:
        config_.eval_dataset.data_loader.dataset_dir = args_.eval_dataset_dir
    if config_.run_mode == 'predict':
        if args_.predict_data is None:
            logger.info("dataset from the config is used as input_data.")
        if isinstance(args_.predict_data, list):
            if len(args_.predict_data) > 1:
                logger.info("predict data is a list, taking it as an input text list.")
            else:
                args_.predict_data = args_.predict_data[0]
        if isinstance(args_.predict_data, str):
            if os.path.isdir(args_.predict_data):
                predict_data = [os.path.join(root, file)
                                for root, _, file_list in os.walk(args_.predict_data)
                                for file in file_list
                                if file.endswith((".jpg", ".png", ".jpeg", ".JPEG", ".bmp"))]
                args_.predict_data = predict_data
            else:
                args_.predict_data = args_.predict_data.replace(r"\n", "\n")
        config_.input_data = args_.predict_data
    if args_.predict_batch_size is not None:
        config_.predict_batch_size = args_.predict_batch_size
    if args_.epochs is not None:
        config_.runner_config.epochs = args_.epochs
    if args_.batch_size is not None:
        config_.runner_config.batch_size = args_.batch_size
    if args_.sink_mode is not None:
        config_.runner_config.sink_mode = args_.sink_mode
    if args_.num_samples is not None:
        if config_.train_dataset and config_.train_dataset.data_loader:
            config_.train_dataset.data_loader.num_samples = args_.num_samples
        if config_.eval_dataset and config_.eval_dataset.data_loader:
            config_.eval_dataset.data_loader.num_samples = args_.num_samples
    if args_.output_dir is not None:
        config_.output_dir = args_.output_dir
    while rest_args_:
        key = rest_args_.pop(0)
        value = rest_args_.pop(0)
        if not key.startswith("--"):
            raise ValueError("Custom config keys must start with --.")
        dists = key[2:].split(".")
        dist_config = config_
        while len(dists) > 1:
            if dists[0] not in dist_config:
                raise ValueError(f"{dists[0]} is not a key of {dist_config}, please check the input arg keys.")
            dist_config = dist_config[dists.pop(0)]
        dist_config[dists.pop()] = parse_value(value)

    main(config_)