 from typing import Sized

 import numpy as np
-import pyarrow as pa
-import pyarrow.parquet as pq
 import torch
 import torch.distributed as dist
 from datasets import Dataset
@@ -41,7 +39,6 @@ def __init__(
         Args:
             path: The path to save the gradients
             projection_dim: The dimension to project the gradients onto
-            dtype: The dtype of the on-disk gradient store
             accumulate_grads: Whether to take the sum of the gradients
                 of the same example across epochs. If `False`, the
                 gradients for each epoch are stored separately.
@@ -72,8 +69,9 @@ def __init__(

         self.mod_grads = {}
         self.batch_indices: Tensor | None = None
-        self.training_order: list[dict] = []
-        self.torch_dtype = torch_dtype
+
+        # TODO: Handle this more elegantly
+        self.torch_dtype = torch.float32 if self.dtype == np.float32 else torch.float16

         # TODO: Handle this more elegantly
         self.torch_dtype = torch.float32 if self.dtype == np.float32 else torch.float16
@@ -88,7 +86,7 @@ def write_grads(self, grad_buffer: np.memmap):

     def on_step_begin(self, args, state, control, **kwargs):
         """Track the current step and epoch for training order recording."""
-        if self.track_training_order:
+        if self.order:
             self._current_step = state.global_step
             self._current_epoch = int(state.epoch or 0)

@@ -222,12 +220,6 @@ def on_module_backward(self, name: str, g: Tensor):
             device="cpu", dtype=self.torch_dtype, non_blocking=True
         )

-        if (self.mod_grads[name].pow(2).sum(dim=1) == 0).any():
-            print(
-                f"{self.mod_grads[name].pow(2).sum(dim=1).eq(0).sum().item()} "
-                f"sum of squares == 0 rows found in gradients after {self.torch_dtype}"
-            )
-
     def on_substep_end(
         self,
         args: TrainingArguments,
@@ -275,17 +267,14 @@ def on_step_end(
             return

         # Record training order if enabled
-        if self.training_order is not None:
-            if self.batch_indices is None:
-                raise ValueError(
-                    "Batch indices are not available for training order tracking"
-                )
+        if self.order:
+            assert (
+                self.batch_indices is not None
+            ), "Batch indices are not available for training order tracking"

-            rank = dist.get_rank() if dist.is_initialized() else 0
-            self.training_order.extend(
+            self.order.extend(
                 {
                     "_idx": int(idx),
-                    "rank": rank,
                     "global_step": getattr(self, "_current_step", 0),
                     "epoch": getattr(self, "_current_epoch", 0),
                 }
@@ -356,6 +345,7 @@ def on_train_end(

     def _save_order(self):
         """Save the training order to disk, handling distributed training."""
+        assert self.order is not None
         os.makedirs(self.path, exist_ok=True)

         if dist.is_initialized():
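
For context, here is a minimal, hypothetical sketch of how the per-rank order records built in `on_step_end` (dicts with `_idx`, `global_step`, and `epoch`) could be gathered across ranks and written to disk. It is not the `_save_order` body from this commit; the `save_order_records` name and the `training_order` output directory are assumed for illustration only.

```python
# Hypothetical, standalone sketch (not the _save_order implementation above):
# one way to gather each rank's training-order records and persist them.
import os

import torch.distributed as dist
from datasets import Dataset


def save_order_records(records: list[dict], path: str) -> None:
    """Gather per-rank order records (_idx, global_step, epoch) and save them."""
    if dist.is_initialized():
        # Collect every rank's list of records onto all ranks, then flatten.
        gathered: list[list[dict] | None] = [None] * dist.get_world_size()
        dist.all_gather_object(gathered, records)
        records = [r for per_rank in gathered for r in per_rank]
        if dist.get_rank() != 0:
            return  # only rank 0 writes to disk
    os.makedirs(path, exist_ok=True)
    Dataset.from_list(records).save_to_disk(os.path.join(path, "training_order"))
```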