-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
145 lines (118 loc) · 4.35 KB
/
train.py
File metadata and controls
145 lines (118 loc) · 4.35 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
"""
Main training script for TRNTR.
Usage:
python train.py # Train with default config
python train.py model=trntr_7m # Override model config
python train.py training.learning_rate=5e-4 # Override specific parameter
python train.py +experiment=ablation_no_recursion # Use ablation config
"""
import os
import hydra
from omegaconf import DictConfig, OmegaConf
import torch
from pathlib import Path
from src.trntr.models import build_model, EMA
from src.trntr.data import TRNTRTokenizer, get_dataloaders, train_tokenizer_from_dataset
from src.trntr.training import Trainer, setup_logger
def setup_reproducibility(seed: int, deterministic: bool = False):
    """Seed the random number generators used by the training run.

    Seeds Python's ``random`` module, NumPy's global generator, and PyTorch
    (the default generator plus every CUDA device) with the same value.

    Args:
        seed: Value passed to each generator's seeding function.
        deterministic: When true, additionally switch cuDNN to deterministic
            mode and turn off its benchmark mode (slower but reproducible).
    """
    import random

    import numpy as np

    # The four seeding calls are independent of one another.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    if not deterministic:
        return

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
def _load_or_train_tokenizer(config: DictConfig, logger):
    """Return the tokenizer, reusing a saved one when it exists on disk.

    The tokenizer file is looked up at
    ``<checkpoint_dir>/tokenizer/trntr_tokenizer.json``. If the file is
    missing, a new tokenizer is trained on 100,000 samples of the C4 train
    split and saved to that same path.

    Args:
        config: Hydra configuration; reads ``paths.checkpoint_dir`` and
            ``model.vocab_size``.
        logger: Logger returned by ``setup_logger``; used for progress messages.

    Returns:
        The loaded or freshly trained tokenizer.
    """
    tokenizer_path = Path(config.paths.checkpoint_dir) / "tokenizer" / "trntr_tokenizer.json"

    if tokenizer_path.exists():
        logger.log_text(f"Loading tokenizer from {tokenizer_path}")
        return TRNTRTokenizer.load(tokenizer_path)

    logger.log_text("Training new tokenizer...")
    # Only checkpoint_dir itself is created by main(); make sure the
    # "tokenizer" subdirectory exists too, in case the training helper does
    # not create parent directories for save_path.
    tokenizer_path.parent.mkdir(parents=True, exist_ok=True)
    return train_tokenizer_from_dataset(
        dataset_name="c4",
        dataset_split="train",
        num_samples=100000,
        vocab_size=config.model.vocab_size,
        save_path=str(tokenizer_path),
    )


def _create_accelerator(config: DictConfig, logger):
    """Create an Accelerate ``Accelerator`` if the package is installed.

    Args:
        config: Hydra configuration; reads ``training.mixed_precision`` and
            ``training.gradient_accumulation_steps``.
        logger: Logger returned by ``setup_logger``; used for progress messages.

    Returns:
        An ``Accelerator`` instance, or ``None`` when ``accelerate`` cannot be
        imported.
    """
    # Guard only the import: an ImportError raised while constructing the
    # Accelerator is a genuine configuration problem and must not be
    # misreported as "Accelerate not available".
    try:
        from accelerate import Accelerator
    except ImportError:
        logger.log_text("Accelerate not available, using single GPU/CPU", level="warning")
        return None

    accelerator = Accelerator(
        mixed_precision=config.training.mixed_precision,
        gradient_accumulation_steps=config.training.gradient_accumulation_steps,
    )
    logger.log_text(f"Using Accelerate with {accelerator.num_processes} processes")
    logger.log_text(f"Mixed precision: {config.training.mixed_precision}")
    return accelerator


@hydra.main(version_base=None, config_path="configs", config_name="config")
def main(config: DictConfig):
    """Run a full TRNTR training session.

    Prints the resolved configuration, seeds the random number generators,
    creates the data/checkpoint/log directories, then builds the tokenizer,
    model, dataloaders, optional accelerator, and trainer before starting
    the training loop.

    Args:
        config: Hydra configuration.

    Raises:
        Exception: Any error raised by ``trainer.train()`` other than
            ``KeyboardInterrupt`` is logged and re-raised. On
            ``KeyboardInterrupt`` an "interrupted" checkpoint is saved and
            the function returns normally.
    """
    # Print config
    separator = "=" * 80
    print(separator)
    print("TRNTR Training Configuration")
    print(separator)
    print(OmegaConf.to_yaml(config))
    print(separator)

    # Setup reproducibility
    setup_reproducibility(
        seed=config.experiment.seed,
        deterministic=config.deterministic,
    )

    # Setup paths
    for directory in (
        config.paths.data_dir,
        config.paths.checkpoint_dir,
        config.paths.log_dir,
    ):
        Path(directory).mkdir(parents=True, exist_ok=True)

    # Setup logger
    logger = setup_logger(config)
    logger.log_text("Starting TRNTR training", level="info")

    # Load or train tokenizer
    tokenizer = _load_or_train_tokenizer(config, logger)
    logger.log_text(f"Tokenizer vocab size: {len(tokenizer)}")

    # Build model
    logger.log_text("Building model...")
    model = build_model(config.model)

    # Log model info
    num_params = model.count_parameters()
    logger.log_text(f"Model has {num_params:,} parameters")

    # Create dataloaders (the test split is not used by this script)
    logger.log_text("Creating dataloaders...")
    train_loader, val_loader, _test_loader = get_dataloaders(
        config=config,
        tokenizer=tokenizer,
    )

    # Setup Accelerate (for multi-GPU)
    accelerator = _create_accelerator(config, logger)

    # Create trainer
    logger.log_text("Creating trainer...")
    trainer = Trainer(
        model=model,
        tokenizer=tokenizer,
        train_loader=train_loader,
        val_loader=val_loader,
        config=config,
        logger=logger,
        accelerator=accelerator,
    )

    # Train
    logger.log_text("Starting training loop...")
    try:
        trainer.train()
    except KeyboardInterrupt:
        logger.log_text("Training interrupted by user", level="warning")
        trainer.save_checkpoint("interrupted")
    except Exception as e:
        logger.log_text(f"Training failed with error: {e}", level="error")
        raise
    else:
        # Report completion only when the loop finished on its own, not
        # after an interruption was handled above.
        logger.log_text("Training complete!")
if __name__ == "__main__":
    # main is wrapped by @hydra.main, which supplies the config argument,
    # so it is invoked here without arguments.
    main()