
Commit ce949e3

Clean up huggingface callback; simplify induction heads model arch

1 parent a57c68e

File tree

2 files changed: +535, -308 lines changed

bergson/huggingface.py

Lines changed: 10 additions & 20 deletions
@@ -5,8 +5,6 @@
 from typing import Sized
 
 import numpy as np
-import pyarrow as pa
-import pyarrow.parquet as pq
 import torch
 import torch.distributed as dist
 from datasets import Dataset
@@ -41,7 +39,6 @@ def __init__(
         Args:
             path: The path to save the gradients
             projection_dim: The dimension to project the gradients onto
-            dtype: The dtype of the on-disk gradient store
             accumulate_grads: Whether to take the sum of the gradients
                 of the same example across epochs. If `False`, the
                 gradients for each epoch are stored separately.
@@ -72,8 +69,9 @@ def __init__(
7269

7370
self.mod_grads = {}
7471
self.batch_indices: Tensor | None = None
75-
self.training_order: list[dict] = []
76-
self.torch_dtype = torch_dtype
72+
73+
# TODO: Handle this more elegantly
74+
self.torch_dtype = torch.float32 if self.dtype == np.float32 else torch.float16
7775

7876
# TODO: Handle this more elegantly
7977
self.torch_dtype = torch.float32 if self.dtype == np.float32 else torch.float16
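The TODO above concerns deriving the in-memory torch dtype from the on-disk numpy dtype via a hard-coded ternary. One way the "more elegant" handling could look is a lookup table; this is a hypothetical sketch, not part of this commit, and the names `NUMPY_TO_TORCH` and `torch_dtype_for` are invented for illustration:

import numpy as np
import torch

# Hypothetical mapping from on-disk numpy dtypes to torch dtypes,
# replacing the hard-coded float32/float16 ternary above.
NUMPY_TO_TORCH = {
    np.dtype(np.float32): torch.float32,
    np.dtype(np.float16): torch.float16,
}

def torch_dtype_for(np_dtype) -> torch.dtype:
    # Fall back to float16, mirroring the else-branch of the ternary.
    return NUMPY_TO_TORCH.get(np.dtype(np_dtype), torch.float16)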
@@ -88,7 +86,7 @@ def write_grads(self, grad_buffer: np.memmap):
 
     def on_step_begin(self, args, state, control, **kwargs):
         """Track the current step and epoch for training order recording."""
-        if self.track_training_order:
+        if self.order:
             self._current_step = state.global_step
             self._current_epoch = int(state.epoch or 0)
 
@@ -222,12 +220,6 @@ def on_module_backward(self, name: str, g: Tensor):
             device="cpu", dtype=self.torch_dtype, non_blocking=True
         )
 
-        if (self.mod_grads[name].pow(2).sum(dim=1) == 0).any():
-            print(
-                f"{self.mod_grads[name].pow(2).sum(dim=1).eq(0).sum().item()} "
-                f"sum of squares == 0 rows found in gradients after {self.torch_dtype}"
-            )
-
     def on_substep_end(
         self,
         args: TrainingArguments,
@@ -275,17 +267,14 @@ def on_step_end(
             return
 
         # Record training order if enabled
-        if self.training_order is not None:
-            if self.batch_indices is None:
-                raise ValueError(
-                    "Batch indices are not available for training order tracking"
-                )
+        if self.order:
+            assert (
+                self.batch_indices is not None
+            ), "Batch indices are not available for training order tracking"
 
-            rank = dist.get_rank() if dist.is_initialized() else 0
-            self.training_order.extend(
+            self.order.extend(
                 {
                     "_idx": int(idx),
-                    "rank": rank,
                     "global_step": getattr(self, "_current_step", 0),
                     "epoch": getattr(self, "_current_epoch", 0),
                 }
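For context, after this hunk each optimizer step appends one record per example to `self.order`, without the per-rank field that was dropped. A standalone illustration with made-up values (the indices, step, and epoch below are hypothetical, not from the repo):

# Illustrative only: the shape of the records kept in `self.order`.
order: list[dict] = []
batch_indices = [17, 4, 92]  # hypothetical example indices for one step
current_step, current_epoch = 120, 1

order.extend(
    {"_idx": int(idx), "global_step": current_step, "epoch": current_epoch}
    for idx in batch_indices
)
# order == [{"_idx": 17, "global_step": 120, "epoch": 1}, ...]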
@@ -356,6 +345,7 @@ def on_train_end(
 
     def _save_order(self):
         """Save the training order to disk, handling distributed training."""
+        assert self.order is not None
        os.makedirs(self.path, exist_ok=True)
 
        if dist.is_initialized():
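The hunk cuts off inside `_save_order`, so the distributed branch is not shown. A minimal sketch of how records could be gathered onto rank 0 before a single write, assuming `torch.distributed.gather_object`; the function name, file name, and JSON format are all assumptions, not the repo's actual code:

import json
import os

import torch.distributed as dist

def save_order_sketch(order: list[dict], path: str) -> None:
    # Gather every rank's training-order records on rank 0, then write once.
    os.makedirs(path, exist_ok=True)

    if dist.is_initialized():
        gathered = [None] * dist.get_world_size() if dist.get_rank() == 0 else None
        dist.gather_object(order, gathered, dst=0)
        if dist.get_rank() != 0:
            return
        order = [record for per_rank in gathered for record in per_rank]

    with open(os.path.join(path, "order.json"), "w") as f:
        json.dump(order, f)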
