-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
140 lines (120 loc) · 5.09 KB
/
Copy pathtrain.py
File metadata and controls
140 lines (120 loc) · 5.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
# Adapted from Tevatron code
import logging
import os.path
import sys
from transformers.utils import logging as hf_logging
# Keep Transformers logging quiet: restrict the library's own log output to
# error-level messages so it does not drown out this script's INFO logs.
hf_logging.set_verbosity_error()
# Root logging configuration for this script: INFO and above, each record
# rendered with timestamp, level, logger name and line number.
_LOG_FORMAT = "[%(asctime)s] %(levelname)s [%(name)s:%(lineno)s] %(message)s"

# Send records to stdout (StreamHandler would otherwise default to stderr).
_stdout_handler = logging.StreamHandler(sys.stdout)

logging.basicConfig(
    handlers=[_stdout_handler],
    format=_LOG_FORMAT,
    level=logging.INFO,
)

# Module-level logger for messages emitted from this file.
logger = logging.getLogger(__name__)
import sys
import torch
import wandb
import yaml
from src.arguments import DataArguments, ModelArguments, TrainingArguments
from src.data.collator.train_collator import MultimodalDataCollator
from src.data.loader.mixed_dataset import init_mixed_dataset
from src.model.model import MMEBModel
from src.model.processor import get_backbone_name, load_processor
from src.trainer import GradCacheLateProcessTrainer, MMEBTrainer
from src.utils import find_latest_checkpoint, print_master, print_rank
from transformers import HfArgumentParser
def _rewrite_local_rank_args(argv):
    """Return a copy of *argv* with each ``--local-rank=N`` replaced by a trailing ``--local_rank N``."""
    kept = []
    moved = []
    for arg in argv:
        if arg.startswith("--local-rank="):
            moved.extend(("--local_rank", arg.split("=", 1)[1]))
        else:
            kept.append(arg)
    return kept + moved


def _dist_initialized():
    """Return True only if torch.distributed is available and its process group is initialized."""
    # torch.distributed exposes no other API when is_available() is False, so
    # is_initialized() must not be called without this guard.
    return torch.distributed.is_available() and torch.distributed.is_initialized()


def _resolve_resume_checkpoint(training_args):
    """Translate ``training_args.resume_from`` into the checkpoint value passed to ``trainer.train``."""
    raw = training_args.resume_from
    # Tolerate an unset value (None) instead of crashing on ``.isdigit()``.
    resume_from = "" if raw is None else str(raw)

    if resume_from == "auto":
        # NOTE(review): find_latest_checkpoint is assumed to return a falsy
        # value when output_dir holds no checkpoint — confirm in src.utils.
        checkpoint_dir = find_latest_checkpoint(training_args.output_dir)
        if checkpoint_dir:
            logger.info(f"Resuming from checkpoint: {checkpoint_dir}")
        else:
            logger.info("No checkpoint found. Starting fresh training.")
        return checkpoint_dir

    if resume_from.isdigit():
        checkpoint_dir = os.path.join(
            training_args.output_dir, f"checkpoint-{resume_from}"
        )
        if os.path.exists(checkpoint_dir):
            logger.info(f"Resuming from checkpoint: {checkpoint_dir}")
        else:
            # The path is still handed to the trainer (as before) so that an
            # explicitly requested but missing step is not silently replaced
            # by a fresh run; the warning makes the cause visible.
            logger.warning(
                "Requested checkpoint does not exist: %s", checkpoint_dir
            )
        return checkpoint_dir

    logger.info("No checkpoint found. Starting fresh training.")
    return None


def main():
    """Parse arguments, build the model, processor and dataset, then train and save.

    Steps, in order:
      1. Normalise the ``--local-rank=N`` flag on ``sys.argv`` and parse the
         model / data / training argument dataclasses.
      2. When no distributed process group is initialized, force single-process
         settings (``local_rank = -1`` and the matching environment variables).
      3. Resolve the checkpoint to resume from (``training_args.resume_from``).
      4. Initialise Weights & Biases if it is listed in ``report_to``; this is
         skipped on ranks other than 0 of an initialized process group.
      5. Build the model and processor, load the dataset config (YAML), create
         the dataset, collator and trainer.
      6. Train, save the model, and save the processor on the world-zero process.
    """
    # a hack for torch.distributed.launch: https://github.com/huggingface/transformers/issues/22171
    # The argument list is rebuilt instead of being mutated while iterated.
    sys.argv[:] = _rewrite_local_rank_args(sys.argv)

    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    model_args: ModelArguments
    data_args: DataArguments
    training_args: TrainingArguments

    # Force single-GPU debug mode
    if not _dist_initialized():
        print(" Running in single-GPU debug mode (forcing local_rank = -1)")
        training_args.local_rank = -1
        os.environ["LOCAL_RANK"] = "-1"
        os.environ["RANK"] = "0"
        os.environ["WORLD_SIZE"] = "1"

    if torch.distributed.is_available():
        print(f"torch.distributed.is_initialized: {torch.distributed.is_initialized()}")
        if torch.distributed.is_initialized():
            print(f"torch.distributed.get_rank(): {torch.distributed.get_rank()}")
            print(
                f"torch.distributed.get_world_size(): {torch.distributed.get_world_size()}"
            )

    # Check for existing checkpoints
    resume_checkpoint_dir = _resolve_resume_checkpoint(training_args)

    # Initialize WandB if enabled
    if "wandb" in training_args.report_to:
        # Only one process reports: the sole process of a non-distributed run,
        # or rank 0 of an initialized process group.
        if not _dist_initialized() or torch.distributed.get_rank() == 0:
            print_rank("init wandb")
            wandb.init(
                project=training_args.project_name,
                name=training_args.run_name,
                mode="online",
            )
            wandb.config.update(model_args)
            wandb.config.update(data_args)
            wandb.config.update(training_args)

    processor = load_processor(model_args, data_args)
    model = MMEBModel.build(
        model_args, training_args=training_args, processor=processor
    )
    model_backbone = get_backbone_name(hf_config=model.config)
    setattr(model_args, "model_backbone", model_backbone)
    setattr(training_args, "model_backbone", model_backbone)
    print_rank(f"model_backbone: {model_backbone}")
    # Reload the processor with a specific model backbone to limit the maximum input resolution.
    processor = load_processor(model_args, data_args)
    setattr(model, "processor", processor)
    model.init_tokenizer(processor)

    # Read the YAML explicitly as UTF-8 so parsing does not depend on the
    # platform's default text encoding.
    with open(data_args.dataset_config, "r", encoding="utf-8") as yaml_file:
        dataset_config = yaml.safe_load(yaml_file)
    train_dataset = init_mixed_dataset(
        dataset_config, model_args, data_args, training_args
    )
    train_collator = MultimodalDataCollator(
        processor, model_args, data_args, training_args
    )

    # Pick the trainer implementation based on the grad_cache training flag.
    if training_args.grad_cache:
        trainer_cls = GradCacheLateProcessTrainer
    else:
        trainer_cls = MMEBTrainer
    trainer = trainer_cls(
        model=model,
        processing_class=processor,
        args=training_args,
        model_args=model_args,
        train_dataset=train_dataset,
        data_collator=train_collator,
        max_length=data_args.max_len,
        num_hardneg=data_args.num_hardneg,
    )
    # Give the dataset a back-reference to the trainer that consumes it.
    train_dataset.trainer = trainer

    trainer.train(resume_from_checkpoint=resume_checkpoint_dir)
    trainer.save_model(training_args.output_dir)

    # Save the processor from a single process only.
    if trainer.is_world_process_zero():
        processor.save_pretrained(training_args.output_dir)
# Script entry point: start training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()