Commit 98d3b14

Update from rebase

1 parent ce949e3

3 files changed: +20 -91 lines

bergson/attributor.py

Lines changed: 0 additions & 61 deletions

@@ -121,67 +121,6 @@ def search(
 
         return torch.topk(scores, k)
 
-    def search_module(
-        self, queries: Tensor, k: int, module: str
-    ) -> tuple[Tensor, Tensor]:
-        """
-        Search for the `k` nearest examples in the index based on the query or queries.
-        If fewer than `k` examples are found FAISS will return items with the index -1
-        and the maximum negative distance.
-
-        Args:
-            queries: The query tensor of shape [..., d].
-            k: The number of nearest examples to return for each query.
-            nprobe: The number of FAISS vector clusters to search if using ANN.
-
-        Returns:
-            A namedtuple containing the top `k` indices and inner products for each
-            query. Both have shape [..., k].
-        """
-        assert isinstance(
-            self.grads, dict
-        ), "Gradients must be a dictionary of tensors."
-        assert module in self.grads, f"Module {module} not found in gradients."
-
-        k = min(k, self.grads[module].shape[0])
-
-        q = queries
-
-        if self.unit_norm:
-            q /= q.norm(dim=1, keepdim=True)
-
-        if not self.faiss_cfg:
-            return torch.topk(q.to(self.device) @ self.grads[module].mT, k)
-
-        q = q.cpu().numpy()
-
-        shard_distances = []
-        shard_indices = []
-        offset = 0
-
-        for index in self.faiss_shards:
-            index.nprobe = self.faiss_cfg.nprobe
-            distances, indices = index.search(q, k)
-
-            indices += offset
-            offset += index.ntotal
-
-            shard_distances.append(distances)
-            shard_indices.append(indices)
-
-        distances = np.concatenate(shard_distances, axis=1)
-        indices = np.concatenate(shard_indices, axis=1)
-
-        # Rerank results overfetched from multiple shards
-        if len(self.faiss_shards) > 1:
-            topk_indices = np.argsort(distances, axis=1)[:, :k]
-            indices = indices[np.arange(indices.shape[0])[:, None], topk_indices]
-            distances = distances[np.arange(distances.shape[0])[:, None], topk_indices]
-
-        return torch.from_numpy(distances.squeeze()), torch.from_numpy(
-            indices.squeeze()
-        )
-
     @contextmanager
     def trace(
         self, module: nn.Module, k: int, *, precondition: bool = False
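For context, the removed `search_module` method implemented an overfetch-and-rerank lookup across multiple FAISS shards: each shard is queried for its own top-k, local ids are shifted by a running `ntotal` offset, and the concatenated candidates are reranked to one global top-k per query. A minimal standalone sketch of that pattern, assuming plain `faiss.Index` shards and float32 query arrays (the `search_shards` helper is hypothetical and not part of bergson; the surviving `search` method remains the supported entry point):

import faiss
import numpy as np
import torch


def search_shards(shards: list[faiss.Index], queries: np.ndarray, k: int):
    """Overfetch k hits from every shard, then keep the global top-k per query."""
    shard_distances, shard_indices, offset = [], [], 0
    for index in shards:
        distances, indices = index.search(queries, k)  # (n_queries, k) per shard
        shard_indices.append(indices + offset)         # shift local ids into a global id space
        shard_distances.append(distances)
        offset += index.ntotal

    distances = np.concatenate(shard_distances, axis=1)
    indices = np.concatenate(shard_indices, axis=1)

    # Rerank the len(shards) * k overfetched candidates, largest inner product first
    top = np.argsort(-distances, axis=1)[:, :k]
    rows = np.arange(distances.shape[0])[:, None]
    return torch.from_numpy(distances[rows, top]), torch.from_numpy(indices[rows, top])

Sorting the concatenated scores in descending order keeps the largest similarities, which matches FAISS's inner-product convention for this kind of merged search.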

bergson/huggingface.py

Lines changed: 4 additions & 26 deletions

@@ -31,23 +31,25 @@ def __init__(
         path: str,
         head_cfgs: dict[str, HeadConfig],
         projection_dim: int = 16,
+        dtype: DTypeLike = np.float16,
         accumulate_grads: bool = False,
         use_optimizer_state: bool = True,
         track_order: bool = False,
     ):
         """
         Args:
             path: The path to save the gradients
+            head_cfgs: Information used to split matrix-valued parameters into
+                per-head matrices before down projection.
             projection_dim: The dimension to project the gradients onto
+            dtype: The dtype of the on-disk gradient store
             accumulate_grads: Whether to take the sum of the gradients
                 of the same example across epochs. If `False`, the
                 gradients for each epoch are stored separately.
             use_optimizer_state: Whether to use the optimizer state to
                 normalize the gradients. If `False`, no normalization is
                 applied.
             track_order: Whether to record the shuffled order of training data.
-            head_cfgs: Information used to split matrix-valued parameters into
-                per-head matrices before down projection.
         """
         super().__init__()
 

@@ -73,9 +75,6 @@ def __init__(
         # TODO: Handle this more elegantly
         self.torch_dtype = torch.float32 if self.dtype == np.float32 else torch.float16
 
-        # TODO: Handle this more elegantly
-        self.torch_dtype = torch.float32 if self.dtype == np.float32 else torch.float16
-
     def write_grads(self, grad_buffer: np.memmap):
         # Ensure the nonblocking copies are all finished
         torch.cuda.synchronize()

@@ -84,12 +83,6 @@ def write_grads(self, grad_buffer: np.memmap):
 
         self.mod_grads.clear()
 
-    def on_step_begin(self, args, state, control, **kwargs):
-        """Track the current step and epoch for training order recording."""
-        if self.order:
-            self._current_step = state.global_step
-            self._current_epoch = int(state.epoch or 0)
-
     def on_train_begin(
         self,
         args: TrainingArguments,

@@ -266,21 +259,6 @@ def on_step_end(
         if not self.use_optimizer_state:
             return
 
-        # Record training order if enabled
-        if self.order:
-            assert (
-                self.batch_indices is not None
-            ), "Batch indices are not available for training order tracking"
-
-            self.order.extend(
-                {
-                    "_idx": int(idx),
-                    "global_step": getattr(self, "_current_step", 0),
-                    "epoch": getattr(self, "_current_epoch", 0),
-                }
-                for idx in self.batch_indices.tolist()
-            )
-
         # The optimizer doesn't actually know the names of the parameters
         model = getattr(model, "base_model", model)
         param_to_name = {
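For orientation, here is a hedged sketch of constructing the callback with the revised signature. The values are illustrative only; the GPT-2-style module name and the `HeadConfig(12, 192, 2)` arguments simply mirror the example script updated below, not a required configuration:

import numpy as np

from bergson import HeadConfig
from bergson.huggingface import GradientCollectorCallback

callback = GradientCollectorCallback(
    path="output/gradients",
    head_cfgs={
        # split this matrix-valued parameter into per-head blocks before projection
        "h.0.attn.c_attn": HeadConfig(12, 192, 2),
    },
    projection_dim=16,         # default from the signature above
    dtype=np.float16,          # dtype of the on-disk gradient store (new argument)
    accumulate_grads=False,    # store per-epoch gradients instead of summing them
    use_optimizer_state=True,  # normalize gradients with the optimizer state
    track_order=True,          # record the shuffled order of training data
)

The callback would typically be registered with a `transformers` `Trainer` through its `callbacks` argument.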

examples/find_induction_heads.py

Lines changed: 16 additions & 4 deletions

@@ -39,9 +39,10 @@
 from transformers.modeling_outputs import CausalLMOutputWithPast
 
 import wandb
-from bergson.attributor import Attributor
 
 # from bergson.data import load_gradient_dataset
+from bergson import HeadConfig
+from bergson.attributor import Attributor
 from bergson.collection import collect_gradients
 from bergson.gradients import GradientProcessor
 from bergson.huggingface import (

@@ -349,9 +350,11 @@ def create_transformer():
     return model, tokenizer
 
 
-def load_tinystories_data(tokenizer, max_length=512, N=10000):
+def load_tinystories_data(tokenizer, max_length=512, N: int | None = 10_000):
     """Load and preprocess TinyStories dataset."""
     dataset = load_dataset("EleutherAI/SmolLM2-135M-10B", split="train")
+    if N is not None:
+        dataset = dataset.select(range(min(N, len(dataset))))
     # dataset = load_dataset("roneneldan/TinyStories", split="train")
     # dataset = dataset.select(range(min(N, len(dataset))))
 

@@ -552,9 +555,14 @@ def setup_training(
 
     bergson_callback = GradientCollectorCallback(
         path=f"{output_dir}/gradients",
+        head_cfgs={
+            "h.0.attn.c_attn": HeadConfig(12, 192, 2),
+            "h.0.attn.c_proj": HeadConfig(12, 64, 2),
+            "h.1.attn.c_attn": HeadConfig(12, 192, 2),
+            "h.1.attn.c_proj": HeadConfig(12, 64, 2),
+        },
         projection_dim=projection_dim,
         dtype=np.float32,
-        torch_dtype=torch.float32,
         accumulate_grads=False,
         track_order=True,
     )

@@ -683,7 +691,10 @@ def main(args):
     model, tokenizer = create_transformer()
 
     # # Load TinyStories data
-    train_dataset, eval_dataset = load_tinystories_data(tokenizer)
+    if args.small:
+        train_dataset, eval_dataset = load_tinystories_data(tokenizer, N=1000)
+    else:
+        train_dataset, eval_dataset = load_tinystories_data(tokenizer)
 
     # # Create induction head dataset
     test_induction_head_labels()

@@ -899,6 +910,7 @@ def main(args):
     parser.add_argument("--seed", type=int, default=0)
     parser.add_argument("--train", action="store_true")
     parser.add_argument("--unit_norm", action="store_true")
+    parser.add_argument("--small", action="store_true")
     parser.add_argument("--tag", type=str, default="")
     args = parser.parse_args()
     main(args)
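Based on the argparse flags visible above, a small smoke-test run of the example could look like the following (a hypothetical invocation; the script may accept flags not shown in this diff):

python examples/find_induction_heads.py --train --small --seed 0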
