Skip to content

Commit 28b53a8

Browse files
authored
Merge branch 'EleutherAI:main' into gemma_device_fix
2 parents 0e95694 + 0837a97 commit 28b53a8

File tree

6 files changed

+113
-98
lines changed

6 files changed

+113
-98
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,3 +183,4 @@ results/
183183
statistics/
184184
.embedding_cache/
185185
wandb/
186+
uv.lock

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@ repos:
77
- id: trailing-whitespace
88
- id: end-of-file-fixer
99
- id: check-added-large-files
10-
- repo: https://github.com/psf/black
10+
- repo: https://github.com/psf/black-pre-commit-mirror
1111
rev: 25.9.0
1212
hooks:
1313
- id: black
1414
- repo: https://github.com/astral-sh/ruff-pre-commit
15-
rev: 'v0.13.2'
15+
rev: 'v0.13.3'
1616
hooks:
1717
- id: ruff
1818
args: [--fix, --exit-non-zero-on-fix]

delphi/pipeline.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -161,8 +161,9 @@ async def process_item(self, item: Any, semaphore: asyncio.Semaphore) -> Any:
161161
async with semaphore:
162162
result = item
163163
for pipe in self.pipes:
164-
if result is not None:
165-
result = await pipe(result)
166-
else:
167-
pass
164+
if result is None:
165+
return None
166+
167+
result = await pipe(result)
168+
168169
return result

delphi/scorers/classifier/intruder.py

Lines changed: 47 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import asyncio
22
import re
3+
from collections import defaultdict
34
from dataclasses import dataclass
5+
from itertools import cycle
46
from typing import Literal
57

68
from beartype.typing import Sequence
@@ -136,12 +138,11 @@ def _get_quantiled_examples(
136138
"""
137139
Get the quantiled examples.
138140
"""
139-
quantiles = {}
141+
examples_grouped_by_quantiles = defaultdict(list)
140142
for example in examples:
141-
if example.quantile not in quantiles:
142-
quantiles[example.quantile] = []
143-
quantiles[example.quantile].append(example)
144-
return quantiles
143+
examples_grouped_by_quantiles[example.quantile].append(example)
144+
145+
return examples_grouped_by_quantiles
145146

146147
def _prepare_and_batch(self, record: LatentRecord) -> list[IntruderSentence]:
147148
"""
@@ -153,38 +154,39 @@ def _prepare_and_batch(self, record: LatentRecord) -> list[IntruderSentence]:
153154
quantiled_intruder_sentences = self._get_quantiled_examples(record.test)
154155

155156
intruder_sentences = record.not_active
156-
for i, intruder in enumerate(intruder_sentences):
157-
# select each quantile equally
158-
quantile_index = i % len(quantiled_intruder_sentences.keys())
159157

160-
active_examples = quantiled_intruder_sentences[quantile_index]
158+
# select each quantile equally by repeatedly cycling through them
159+
quantile_iterator = cycle(quantiled_intruder_sentences.items())
160+
for (active_quantile, all_active_examples), intruder in zip(
161+
quantile_iterator, intruder_sentences
162+
):
161163
# if there are more examples than the number of examples to show,
162164
# sample which examples to show
163-
examples_to_show = min(self.n_examples_shown - 1, len(active_examples))
164-
example_indices = self.rng.sample(
165-
range(len(active_examples)), examples_to_show
165+
num_active_examples = min(
166+
# - 1 because we are going to insert the intruder sentence
167+
self.n_examples_shown - 1,
168+
len(all_active_examples),
166169
)
167-
active_examples = [active_examples[i] for i in example_indices]
168-
169-
# convert the examples to strings
170+
active_examples = self.rng.sample(all_active_examples, num_active_examples)
170171

171-
# highlights the active tokens
172+
# highlights the active tokens with <<>> markers
172173
majority_examples = []
173-
active_tokens = 0
174+
num_active_tokens = 0
174175
for example in active_examples:
175-
text, _ = _prepare_text(
176+
text, _str_tokens = _prepare_text(
176177
example, n_incorrect=0, threshold=0.3, highlighted=True
177178
)
178179
majority_examples.append(text)
179-
active_tokens += (example.activations > 0).sum().item()
180-
active_tokens = int(active_tokens / len(active_examples))
180+
num_active_tokens += (example.activations > 0).sum().item()
181+
182+
avg_active_tokens_per_example = num_active_tokens // len(active_examples)
181183
if self.type == "default":
182184
# if example is contrastive, use the active tokens
183185
# otherwise use the non-activating tokens
184186
if intruder.activations.max() > 0:
185187
n_incorrect = 0
186188
else:
187-
n_incorrect = active_tokens
189+
n_incorrect = avg_active_tokens_per_example
188190
intruder_sentence, _ = _prepare_text(
189191
intruder,
190192
n_incorrect=n_incorrect,
@@ -194,22 +196,15 @@ def _prepare_and_batch(self, record: LatentRecord) -> list[IntruderSentence]:
194196
elif self.type == "internal":
195197
# randomly select a quantile to be the intruder, make sure it's not
196198
# the same as the source quantile
197-
intruder_quantile_index = self.rng.randint(
198-
0, len(quantiled_intruder_sentences.keys()) - 1
199-
)
200-
while intruder_quantile_index == quantile_index:
201-
intruder_quantile_index = self.rng.randint(
202-
0, len(quantiled_intruder_sentences.keys()) - 1
203-
)
204-
posible_intruder_sentences = quantiled_intruder_sentences[
205-
intruder_quantile_index
206-
]
207-
intruder_index_selected = self.rng.randint(
208-
0, len(posible_intruder_sentences) - 1
209-
)
210-
intruder = posible_intruder_sentences[intruder_index_selected]
199+
all_quantile_examples = list(quantiled_intruder_sentences.values())
200+
all_quantile_examples.remove(all_active_examples)
201+
possible_intruder_sentences = self.rng.choice(all_quantile_examples)
202+
203+
intruder = self.rng.choice(possible_intruder_sentences)
211204
# here the examples are activating, so we have to convert them
212205
# to non-activating examples
206+
assert intruder.str_tokens is not None, "intruder has no str_tokens"
207+
213208
non_activating_intruder = NonActivatingExample(
214209
tokens=intruder.tokens,
215210
activations=intruder.activations,
@@ -224,23 +219,27 @@ def _prepare_and_batch(self, record: LatentRecord) -> list[IntruderSentence]:
224219
highlighted=True,
225220
)
226221
intruder = non_activating_intruder
222+
else:
223+
raise ValueError("Invalid intruder scorer type")
227224

228225
# select a random index to insert the intruder sentence
229-
intruder_index = self.rng.randint(0, examples_to_show)
230-
majority_examples.insert(intruder_index, intruder_sentence)
226+
intruder_index = self.rng.randint(0, num_active_examples)
227+
examples = (
228+
majority_examples[:intruder_index]
229+
+ [intruder_sentence]
230+
+ majority_examples[intruder_index:]
231+
)
231232

232-
activations = [example.activations.tolist() for example in active_examples]
233-
tokens = [example.str_tokens for example in active_examples]
234-
activations.insert(intruder_index, intruder.activations.tolist())
235-
tokens.insert(intruder_index, intruder.str_tokens)
233+
example_activations = [example.activations.tolist() for example in examples]
234+
example_tokens = [example.str_tokens for example in examples]
236235

237236
batches.append(
238237
IntruderSentence(
239-
examples=majority_examples,
238+
examples=examples,
240239
intruder_index=intruder_index,
241-
chosen_quantile=quantile_index,
242-
activations=activations,
243-
tokens=tokens,
240+
chosen_quantile=active_quantile,
241+
activations=example_activations,
242+
tokens=example_tokens,
244243
intruder_distance=intruder.distance,
245244
)
246245
)
@@ -275,7 +274,7 @@ def _build_prompt(
275274
"""
276275

277276
examples = "\n".join(
278-
f"Example {i}: {example}" for i, example in enumerate(sample.examples)
277+
f"Example {i}:{example}" for i, example in enumerate(sample.examples)
279278
)
280279

281280
return self.prompt(examples=examples)
@@ -311,21 +310,11 @@ async def _generate(self, sample: IntruderSentence) -> IntruderResult:
311310
prompt = self._build_prompt(sample)
312311
try:
313312
response = await self.client.generate(prompt, **self.generation_kwargs)
313+
interpretation, prediction = self._parse(response.text)
314314
except Exception as e:
315-
logger.error(f"Error generating text: {e}")
316-
response = None
317-
318-
if response is None:
315+
logger.error(str(e))
319316
# default result is a error
320317
return IntruderResult()
321-
else:
322-
323-
try:
324-
interpretation, prediction = self._parse(response.text)
325-
except Exception as e:
326-
logger.error(f"Parsing selections failed: {e}")
327-
# default result is a error
328-
return IntruderResult()
329318

330319
# check that the only prediction is the intruder
331320
correct = prediction == sample.intruder_index

delphi/scorers/classifier/sample.py

Lines changed: 57 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import random
2+
from collections import deque
23
from dataclasses import dataclass
3-
from typing import NamedTuple
4+
from itertools import groupby
5+
from typing import Callable, NamedTuple
46

57
import torch
68

@@ -88,73 +90,95 @@ def _prepare_text(
8890
threshold: float,
8991
highlighted: bool,
9092
) -> tuple[str, list[str]]:
93+
assert n_incorrect >= 0, (
94+
"n_incorrect must be 0 if highlighting correct example "
95+
"or positive if creating false positives. "
96+
f"Got {n_incorrect}"
97+
)
98+
9199
str_toks = example.str_tokens
92100
assert str_toks is not None, "str_toks were not set"
93-
clean = "".join(str_toks)
101+
94102
# Just return text if there's no highlighting
95103
if not highlighted:
104+
clean = "".join(str_toks)
105+
96106
return clean, str_toks
97107

98-
threshold = threshold * example.max_activation
108+
abs_threshold = threshold * example.max_activation
99109

100110
# Highlight tokens with activations above threshold
101-
# if correct example
111+
# if this is a correct example
102112
if n_incorrect == 0:
103113

104-
def threshold_check(i):
105-
return example.activations[i] >= threshold
114+
def is_above_activation_threshold(i: int) -> bool:
115+
return example.activations[i] >= abs_threshold
106116

107-
return _highlight(str_toks, threshold_check), str_toks
117+
return _highlight(str_toks, is_above_activation_threshold), str_toks
108118

109119
# Highlight n_incorrect tokens with activations
110-
# below threshold if incorrect example
111-
below_threshold = torch.nonzero(example.activations <= threshold).squeeze()
120+
# below threshold if this is an incorrect example
121+
tokens_below_threshold = torch.nonzero(
122+
example.activations <= abs_threshold
123+
).squeeze()
112124

113125
# Rare case where there are no tokens below threshold
114-
if below_threshold.dim() == 0:
115-
logger.error("Failed to prepare example.")
126+
if tokens_below_threshold.dim() == 0:
127+
logger.error(
128+
f"Tried to prepare false-positive example with {n_incorrect} tokens "
129+
"incorrectly highlighted, but no tokens were below activation threshold."
130+
)
116131
return DEFAULT_MESSAGE, str_toks
117132

118133
random.seed(22)
119134

120-
n_incorrect = min(n_incorrect, len(below_threshold))
135+
num_tokens_to_highlight = min(n_incorrect, tokens_below_threshold.shape[0])
121136

122137
# The activating token is always ctx_len - ctx_len//4
123-
# so we always highlight this one, and if n_incorrect > 1
124-
# we highlight n_incorrect-1 random ones
138+
# so we always highlight this one, and if num_tokens_to_highlight > 1
139+
# we highlight num_tokens_to_highlight - 1 random ones
125140
token_pos = len(str_toks) - len(str_toks) // 4
126-
if token_pos in below_threshold:
141+
if token_pos in tokens_below_threshold:
127142
random_indices = [token_pos]
128-
if n_incorrect > 1:
143+
144+
num_remaining_tokens_to_highlight = num_tokens_to_highlight - 1
145+
if num_remaining_tokens_to_highlight > 0:
146+
remaining_tokens_below_threshold = tokens_below_threshold.tolist()
147+
remaining_tokens_below_threshold.remove(token_pos)
148+
129149
random_indices.extend(
130-
random.sample(below_threshold.tolist(), n_incorrect - 1)
150+
random.sample(
151+
remaining_tokens_below_threshold,
152+
num_remaining_tokens_to_highlight,
153+
)
131154
)
132155
else:
133-
random_indices = random.sample(below_threshold.tolist(), n_incorrect)
156+
random_indices = random.sample(
157+
tokens_below_threshold.tolist(), num_tokens_to_highlight
158+
)
134159

135160
random_indices = set(random_indices)
136161

137-
def check(i):
162+
def is_false_positive(i):
138163
return i in random_indices
139164

140-
return _highlight(str_toks, check), str_toks
165+
return _highlight(str_toks, is_false_positive), str_toks
166+
141167

168+
def _highlight(tokens: list[str], check: Callable[[int], bool]) -> str:
169+
result: deque[str] = deque()
142170

143-
def _highlight(tokens, check):
144-
result = []
171+
tokens_grouped_by_check_fn = groupby(
172+
enumerate(tokens), key=lambda item: check(item[0])
173+
)
145174

146-
i = 0
147-
while i < len(tokens):
148-
if check(i):
149-
result.append(L)
175+
for should_highlight, token_group in tokens_grouped_by_check_fn:
176+
highlighted_tokens = deque(token for _token_index, token in token_group)
150177

151-
while i < len(tokens) and check(i):
152-
result.append(tokens[i])
153-
i += 1
178+
if should_highlight:
179+
highlighted_tokens.appendleft(L)
180+
highlighted_tokens.append(R)
154181

155-
result.append(R)
156-
else:
157-
result.append(tokens[i])
158-
i += 1
182+
result.extend(highlighted_tokens)
159183

160184
return "".join(result)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ dependencies = [
2222
"blobfile",
2323
"bitsandbytes",
2424
"flask",
25-
"vllm",
25+
"vllm>=0.10.2",
2626
"aiofiles",
2727
"sentence_transformers",
2828
"anyio>=4.8.0",

0 commit comments

Comments (0)