-
Notifications
You must be signed in to change notification settings - Fork 367
Expand file tree
/
Copy pathcheckpointing.py
More file actions
157 lines (125 loc) · 4.81 KB
/
checkpointing.py
File metadata and controls
157 lines (125 loc) · 4.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
"""
Checkpoint saving utilities for GRPO trainer.
Handles saving model checkpoints for different training modes:
- Full model checkpoints (legacy and shared_vllm modes)
- LoRA adapter checkpoints
IMPORTANT: For shared_vllm mode, the model parameters are VIEWS into vLLM's
fused tensors (qkv_proj, gate_up_proj). This module handles unfusing them
back to HuggingFace format for safe checkpoint saving.
"""
import os
import shutil
from typing import Dict
import torch
def _owns_compact_storage(tensor: torch.Tensor) -> bool:
    """
    Return True when *tensor* is contiguous and exactly fills its own storage.

    A tensor can be contiguous with a zero storage offset and still be a view:
    the leading slice of a larger (fused) tensor looks exactly like that. The
    storage-size comparison below catches that case.
    """
    if not tensor.is_contiguous() or tensor.storage_offset() != 0:
        return False
    try:
        storage_bytes = tensor.untyped_storage().nbytes()
    except (AttributeError, NotImplementedError, RuntimeError):
        # Storage introspection is unavailable (older torch releases or
        # storage-less tensors); fall back to the layout-only check above.
        return True
    return storage_bytes <= tensor.numel() * tensor.element_size()


def _ensure_contiguous_state_dict(model: torch.nn.Module) -> Dict[str, torch.Tensor]:
    """
    Create a state dict whose tensors are safe to serialize.

    This matters for shared_vllm mode, where parameters are views into
    vLLM's fused tensors. torch.save serializes a tensor's whole backing
    storage, so saving a view would write the entire fused tensor.

    Entries are taken from ``model.state_dict()``, so non-persistent buffers
    are left out and tied weights keep all of their keys.

    Args:
        model: Module whose parameters and persistent buffers are collected.

    Returns:
        Mapping from state-dict key to a detached tensor. Tensors that are
        non-contiguous, start at a non-zero storage offset, or occupy only
        part of a larger storage are replaced by compact contiguous copies;
        all other tensors are returned without copying.
    """
    state_dict = {}
    for name, value in model.state_dict().items():
        if not isinstance(value, torch.Tensor):
            # Modules may expose non-tensor extra state; pass it through.
            state_dict[name] = value
            continue

        tensor = value.detach()
        if _owns_compact_storage(tensor):
            state_dict[name] = tensor
        else:
            # A single copy into fresh contiguous memory "unfuses" the view.
            state_dict[name] = tensor.clone(memory_format=torch.contiguous_format)
    return state_dict
def save_checkpoint(
    model: torch.nn.Module,
    tokenizer,
    save_path: str,
    step: int,
    is_final: bool = False,
    safe_mode: bool = True,
) -> str:
    """
    Write a full model checkpoint and the tokenizer to disk.

    The target directory is ``<save_path>/final_model`` for the final
    checkpoint and ``<save_path>/step_<step>`` otherwise. A directory that
    already exists at that location is removed before saving.

    Args:
        model: Model to save.
        tokenizer: Tokenizer to save; its ``save_pretrained`` is always called.
        save_path: Base directory for checkpoints.
        step: Current training step (used in the directory name).
        is_final: Whether this is the final checkpoint.
        safe_mode: If True, save a state dict built by
            ``_ensure_contiguous_state_dict`` to ``pytorch_model.bin`` and
            save ``model.config`` separately. Intended for shared_vllm mode,
            where params are views into fused vLLM tensors. If False,
            delegate to ``model.save_pretrained``.

    Returns:
        Path of the directory the checkpoint was written to.
    """
    subdir = "final_model" if is_final else f"step_{step}"
    checkpoint_path = os.path.join(save_path, subdir)
    print(f" Saving checkpoint to {checkpoint_path}...")

    # Start from an empty directory so stale files never mix with new ones.
    if os.path.exists(checkpoint_path):
        shutil.rmtree(checkpoint_path)
    os.makedirs(checkpoint_path, exist_ok=True)

    if not safe_mode:
        # Standard save (may have issues with view tensors).
        model.save_pretrained(checkpoint_path)
    else:
        print(" [Checkpoint] Using safe mode - ensuring contiguous tensors...")
        weights = _ensure_contiguous_state_dict(model)

        # Report how many parameters were views (non-contiguous or offset).
        view_count = 0
        for _, param in model.named_parameters():
            if not (param.is_contiguous() and param.storage_offset() == 0):
                view_count += 1
        if view_count > 0:
            print(
                f" [Checkpoint] Unfused {view_count} view tensors (qkv/gate_up fusions)"
            )

        # Weights are written by hand; the config is saved next to them.
        weights_file = os.path.join(checkpoint_path, "pytorch_model.bin")
        torch.save(weights, weights_file)
        model.config.save_pretrained(checkpoint_path)

        # Drop the copied tensors right away and release cached CUDA memory.
        del weights
        import gc

        gc.collect()
        torch.cuda.empty_cache()

    tokenizer.save_pretrained(checkpoint_path)
    print(" Checkpoint saved.")
    return checkpoint_path
def save_lora_checkpoint(
    model: torch.nn.Module,
    save_path: str,
    step: int,
    is_final: bool = False,
) -> str:
    """
    Write a LoRA adapter checkpoint to disk.

    The target directory is ``<save_path>/final_adapter`` for the final
    checkpoint and ``<save_path>/adapter_step_<step>`` otherwise. A directory
    that already exists at that location is removed before saving.

    Saving is delegated to ``model.save_pretrained``; for a PEFT model this
    is expected to write only the adapter weights rather than the full model
    (depends on the model passed in -- not enforced here).

    Args:
        model: PEFT model with LoRA adapters.
        save_path: Base directory for checkpoints.
        step: Current training step (used in the directory name).
        is_final: Whether this is the final checkpoint.

    Returns:
        Path of the directory the adapter was written to.
    """
    subdir = "final_adapter" if is_final else f"adapter_step_{step}"
    adapter_path = os.path.join(save_path, subdir)
    print(f" Saving LoRA adapter to {adapter_path}...")

    # Replace any previous adapter directory with a fresh, empty one.
    if os.path.exists(adapter_path):
        shutil.rmtree(adapter_path)
    os.makedirs(adapter_path, exist_ok=True)

    model.save_pretrained(adapter_path)
    print(" Adapter saved.")
    return adapter_path