Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions delphi/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,9 @@ async def process_item(self, item: Any, semaphore: asyncio.Semaphore) -> Any:
async with semaphore:
result = item
for pipe in self.pipes:
if result is not None:
result = await pipe(result)
else:
pass
if result is None:
break
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why would we want to break

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

once a pipe returns None, we iterate through the rest of the pipes and simply hit the pass until we return the result, which is None. this skips that pointless extra iteration; perhaps it may be clearer if we replace break with return result or return None?


result = await pipe(result)

return result
93 changes: 43 additions & 50 deletions delphi/scorers/classifier/intruder.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import re
from dataclasses import dataclass
from itertools import cycle, groupby
from typing import Literal

from beartype.typing import Sequence
Expand Down Expand Up @@ -136,12 +137,15 @@ def _get_quantiled_examples(
"""
Get the quantiled examples.
"""
quantiles = {}
for example in examples:
if example.quantile not in quantiles:
quantiles[example.quantile] = []
quantiles[example.quantile].append(example)
return quantiles
examples_sorted_by_quantile = sorted(examples, key=lambda x: x.quantile)
examples_grouped_by_quantile = groupby(
examples_sorted_by_quantile, key=lambda x: x.quantile
)

return {
quantile: list(examples)
for quantile, examples in examples_grouped_by_quantile
}

def _prepare_and_batch(self, record: LatentRecord) -> list[IntruderSentence]:
"""
Expand All @@ -153,38 +157,37 @@ def _prepare_and_batch(self, record: LatentRecord) -> list[IntruderSentence]:
quantiled_intruder_sentences = self._get_quantiled_examples(record.test)

intruder_sentences = record.not_active
for i, intruder in enumerate(intruder_sentences):
# select each quantile equally
quantile_index = i % len(quantiled_intruder_sentences.keys())

active_examples = quantiled_intruder_sentences[quantile_index]
# select each quantile equally by repeatedly cycling through them
quantile_iterator = cycle(quantiled_intruder_sentences.items())
for (active_quantile, all_active_examples), intruder in zip(
quantile_iterator, intruder_sentences
):
# if there are more examples than the number of examples to show,
# sample which examples to show
examples_to_show = min(self.n_examples_shown - 1, len(active_examples))
example_indices = self.rng.sample(
range(len(active_examples)), examples_to_show
num_active_examples = min(
self.n_examples_shown - 1, len(all_active_examples)
)
active_examples = [active_examples[i] for i in example_indices]
active_examples = self.rng.sample(all_active_examples, num_active_examples)

# convert the examples to strings

# highlights the active tokens
majority_examples = []
active_tokens = 0
# highlights the active tokens with <<>> markers
examples = []
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I like the name majority examples more

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hm, my issue is that it does not only contain the activating majority samples; later on, after majority_examples.insert(intruder_index, intruder_sentence), it also includes the intruder.

i think a better solution may be to keep the name majority_samples up until that point and reassigning it to a new examples variable that contains both the majority and intruder samples

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed

num_active_tokens = 0
for example in active_examples:
text, _ = _prepare_text(
text, _str_tokens = _prepare_text(
example, n_incorrect=0, threshold=0.3, highlighted=True
)
majority_examples.append(text)
active_tokens += (example.activations > 0).sum().item()
active_tokens = int(active_tokens / len(active_examples))
examples.append(text)
num_active_tokens += (example.activations > 0).sum().item()

avg_active_tokens_per_example = num_active_tokens // len(active_examples)
if self.type == "default":
# if example is contrastive, use the active tokens
# otherwise use the non-activating tokens
if intruder.activations.max() > 0:
n_incorrect = 0
else:
n_incorrect = active_tokens
n_incorrect = avg_active_tokens_per_example
intruder_sentence, _ = _prepare_text(
intruder,
n_incorrect=n_incorrect,
Expand All @@ -194,22 +197,15 @@ def _prepare_and_batch(self, record: LatentRecord) -> list[IntruderSentence]:
elif self.type == "internal":
# randomly select a quantile to be the intruder, make sure it's not
# the same as the source quantile
intruder_quantile_index = self.rng.randint(
0, len(quantiled_intruder_sentences.keys()) - 1
)
while intruder_quantile_index == quantile_index:
intruder_quantile_index = self.rng.randint(
0, len(quantiled_intruder_sentences.keys()) - 1
)
posible_intruder_sentences = quantiled_intruder_sentences[
intruder_quantile_index
]
intruder_index_selected = self.rng.randint(
0, len(posible_intruder_sentences) - 1
)
intruder = posible_intruder_sentences[intruder_index_selected]
all_quantile_examples = list(quantiled_intruder_sentences.values())
all_quantile_examples.remove(all_active_examples)
possible_intruder_sentences = self.rng.choice(all_quantile_examples)

intruder = self.rng.choice(possible_intruder_sentences)
# here the examples are activating, so we have to convert them
# to non-activating examples
assert intruder.str_tokens is not None, "intruder has no str_tokens"

non_activating_intruder = NonActivatingExample(
tokens=intruder.tokens,
activations=intruder.activations,
Expand All @@ -226,21 +222,19 @@ def _prepare_and_batch(self, record: LatentRecord) -> list[IntruderSentence]:
intruder = non_activating_intruder

# select a random index to insert the intruder sentence
intruder_index = self.rng.randint(0, examples_to_show)
majority_examples.insert(intruder_index, intruder_sentence)
intruder_index = self.rng.randint(0, num_active_examples)
examples.insert(intruder_index, intruder_sentence)

activations = [example.activations.tolist() for example in active_examples]
tokens = [example.str_tokens for example in active_examples]
activations.insert(intruder_index, intruder.activations.tolist())
tokens.insert(intruder_index, intruder.str_tokens)
example_activations = [example.activations.tolist() for example in examples]
example_tokens = [example.str_tokens for example in examples]

batches.append(
IntruderSentence(
examples=majority_examples,
examples=examples,
intruder_index=intruder_index,
chosen_quantile=quantile_index,
activations=activations,
tokens=tokens,
chosen_quantile=active_quantile,
activations=example_activations,
tokens=example_tokens,
intruder_distance=intruder.distance,
)
)
Expand Down Expand Up @@ -275,7 +269,7 @@ def _build_prompt(
"""

examples = "\n".join(
f"Example {i}: {example}" for i, example in enumerate(sample.examples)
f"Example {i}:{example}" for i, example in enumerate(sample.examples)
)

return self.prompt(examples=examples)
Expand Down Expand Up @@ -319,7 +313,6 @@ async def _generate(self, sample: IntruderSentence) -> IntruderResult:
# default result is a error
return IntruderResult()
else:

try:
interpretation, prediction = self._parse(response.text)
except Exception as e:
Expand Down
86 changes: 55 additions & 31 deletions delphi/scorers/classifier/sample.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import random
from collections import deque
from dataclasses import dataclass
from typing import NamedTuple
from itertools import groupby
from typing import Callable, NamedTuple

import torch

Expand Down Expand Up @@ -88,73 +90,95 @@ def _prepare_text(
threshold: float,
highlighted: bool,
) -> tuple[str, list[str]]:
assert n_incorrect >= 0, (
"n_incorrect must be 0 if highlighting correct example "
"or positive if creating false positives. "
f"Got {n_incorrect}"
)

str_toks = example.str_tokens
assert str_toks is not None, "str_toks were not set"
clean = "".join(str_toks)

# Just return text if there's no highlighting
if not highlighted:
clean = "".join(str_toks)

return clean, str_toks

threshold = threshold * example.max_activation
abs_threshold = threshold * example.max_activation

# Highlight tokens with activations above threshold
# if correct example
if n_incorrect == 0:

def threshold_check(i):
return example.activations[i] >= threshold
def is_above_activation_threshold(i: int) -> bool:
return example.activations[i] >= abs_threshold

return _highlight(str_toks, threshold_check), str_toks
return _highlight(str_toks, is_above_activation_threshold), str_toks

# Highlight n_incorrect tokens with activations
# below threshold if incorrect example
below_threshold = torch.nonzero(example.activations <= threshold).squeeze()
tokens_below_threshold = torch.nonzero(
example.activations <= abs_threshold
).squeeze()

# Rare case where there are no tokens below threshold
if below_threshold.dim() == 0:
logger.error("Failed to prepare example.")
if tokens_below_threshold.dim() == 0:
logger.error(
f"Tried to prepare false-positive example with {n_incorrect} tokens "
"incorrectly highlighted, but no tokens were below activation threshold."
)
return DEFAULT_MESSAGE, str_toks

random.seed(22)

n_incorrect = min(n_incorrect, len(below_threshold))
num_tokens_to_highlight = min(n_incorrect, tokens_below_threshold.shape[0])

# The activating token is always ctx_len - ctx_len//4
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this actually should say "when examples are centered the activating token is always blab".
When the examples are not centered (and its not that easy to check if the activating examples are centered when we are inside this function), this might bias the non activating examples to always have the same tokens selected, but It probably does not matter that much because the model that is being explained and the explainer model don't have the same tokenizer so that information shouldn't leak that much

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good point, perhaps there should be additional assertions to make sure the activating token is centered

# so we always highlight this one, and if n_incorrect > 1
# we highlight n_incorrect-1 random ones
# so we always highlight this one, and if num_tokens_to_highlight > 1
# we highlight num_tokens_to_highlight - 1 random ones
token_pos = len(str_toks) - len(str_toks) // 4
if token_pos in below_threshold:
if token_pos in tokens_below_threshold:
random_indices = [token_pos]
if n_incorrect > 1:

num_remaining_tokens_to_highlight = num_tokens_to_highlight - 1
if num_remaining_tokens_to_highlight > 0:
remaining_tokens_below_threshold = tokens_below_threshold.tolist()
remaining_tokens_below_threshold.remove(token_pos)

random_indices.extend(
random.sample(below_threshold.tolist(), n_incorrect - 1)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in the old code, this could result in token_pos being selected again since it's still in below_threshold. then, after being turned into a set, random_indices would have one less element than expected, resulting in one fewer token being incorrectly highlighted than was specified by n_incorrect

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch

random.sample(
remaining_tokens_below_threshold,
num_remaining_tokens_to_highlight,
)
)
else:
random_indices = random.sample(below_threshold.tolist(), n_incorrect)
random_indices = random.sample(
tokens_below_threshold.tolist(), num_tokens_to_highlight
)

random_indices = set(random_indices)

def check(i):
def is_false_positive(i):
return i in random_indices

return _highlight(str_toks, check), str_toks
return _highlight(str_toks, is_false_positive), str_toks


def _highlight(tokens: list[str], check: Callable[[int], bool]) -> str:
result: deque[str] = deque()

def _highlight(tokens, check):
result = []
tokens_grouped_by_check_fn = groupby(
enumerate(tokens), key=lambda item: check(item[0])
)

i = 0
while i < len(tokens):
if check(i):
result.append(L)
for should_highlight, token_group in tokens_grouped_by_check_fn:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I might be ok with using these fancy groupings, but for me this is harder to understand what is happening

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Its not obvious to me that this is doing what it should do. If the tokens to be highlighted are concecutive does it correctly close the brackets? I'm not sure I understand the expected behaviour of groupby

Copy link
Contributor Author

@d0rbu d0rbu Oct 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah i'm fine with not making this change, groupby is kind of misleading. the python docs put it best:

The operation of groupby() is similar to the uniq filter in Unix. It generates a break or new group every time the value of the key function changes (which is why it is usually necessary to have sorted the data using the same key function). That behavior differs from SQL’s GROUP BY which aggregates common elements regardless of their input order.

so, for example, using groupby() on a sequence like 001110 would result in 00 111 0 rather than 000 111. the behavior should be right but it's easy to mistake it for buggy code

highlighted_tokens = deque(token for _token_index, token in token_group)

while i < len(tokens) and check(i):
result.append(tokens[i])
i += 1
if should_highlight:
highlighted_tokens.appendleft(L)
highlighted_tokens.append(R)

result.append(R)
else:
result.append(tokens[i])
i += 1
result.extend(highlighted_tokens)

return "".join(result)