fix bugs in _prepare_text and IntruderScorer prompt formatting #153
```diff
@@ -1,6 +1,8 @@
 import random
+from collections import deque
 from dataclasses import dataclass
-from typing import NamedTuple
+from itertools import groupby
+from typing import Callable, NamedTuple
 
 import torch
 
@@ -88,73 +90,95 @@ def _prepare_text(
     threshold: float,
     highlighted: bool,
 ) -> tuple[str, list[str]]:
+    assert n_incorrect >= 0, (
+        "n_incorrect must be 0 if highlighting correct example "
+        "or positive if creating false positives. "
+        f"Got {n_incorrect}"
+    )
     str_toks = example.str_tokens
-    clean = "".join(str_toks)
+    assert str_toks is not None, "str_toks were not set"
 
     # Just return text if there's no highlighting
     if not highlighted:
+        clean = "".join(str_toks)
         return clean, str_toks
 
-    threshold = threshold * example.max_activation
+    abs_threshold = threshold * example.max_activation
 
     # Highlight tokens with activations above threshold
     # if correct example
     if n_incorrect == 0:
 
-        def threshold_check(i):
-            return example.activations[i] >= threshold
+        def is_above_activation_threshold(i: int) -> bool:
+            return example.activations[i] >= abs_threshold
 
-        return _highlight(str_toks, threshold_check), str_toks
+        return _highlight(str_toks, is_above_activation_threshold), str_toks
 
     # Highlight n_incorrect tokens with activations
     # below threshold if incorrect example
-    below_threshold = torch.nonzero(example.activations <= threshold).squeeze()
+    tokens_below_threshold = torch.nonzero(
+        example.activations <= abs_threshold
+    ).squeeze()
 
     # Rare case where there are no tokens below threshold
-    if below_threshold.dim() == 0:
-        logger.error("Failed to prepare example.")
+    if tokens_below_threshold.dim() == 0:
+        logger.error(
+            f"Tried to prepare false-positive example with {n_incorrect} tokens "
+            "incorrectly highlighted, but no tokens were below activation threshold."
+        )
         return DEFAULT_MESSAGE, str_toks
 
     random.seed(22)
 
-    n_incorrect = min(n_incorrect, len(below_threshold))
+    num_tokens_to_highlight = min(n_incorrect, tokens_below_threshold.shape[0])
 
     # The activating token is always ctx_len - ctx_len//4
-    # so we always highlight this one, and if n_incorrect > 1
-    # we highlight n_incorrect-1 random ones
+    # so we always highlight this one, and if num_tokens_to_highlight > 1
+    # we highlight num_tokens_to_highlight - 1 random ones
     token_pos = len(str_toks) - len(str_toks) // 4
-    if token_pos in below_threshold:
+    if token_pos in tokens_below_threshold:
         random_indices = [token_pos]
-        if n_incorrect > 1:
+
+        num_remaining_tokens_to_highlight = num_tokens_to_highlight - 1
+        if num_remaining_tokens_to_highlight > 0:
+            remaining_tokens_below_threshold = tokens_below_threshold.tolist()
+            remaining_tokens_below_threshold.remove(token_pos)
+
             random_indices.extend(
-                random.sample(below_threshold.tolist(), n_incorrect - 1)
```
> **Review comment:** in the old code, this could result in `token_pos` being sampled a second time, so deduplication left fewer than `n_incorrect` tokens highlighted.
>
> **Reply:** Good catch
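To see the failure mode the thread refers to, here is a minimal sketch with hypothetical toy values (not the repository's data): sampling from the full pool can draw `token_pos` a second time, and the later `set()` deduplication then yields fewer highlighted tokens than requested.

```python
import random

token_pos = 6
below_threshold = [1, 3, 6, 7]  # hypothetical token indices below the threshold
n_incorrect = 3

random_indices = [token_pos]
# The old code sampled from the full pool, which still contains token_pos,
# so token_pos can be drawn a second time:
random_indices.extend(random.sample(below_threshold, n_incorrect - 1))

# After deduplication there may be fewer than n_incorrect indices:
assert len(set(random_indices)) <= n_incorrect  # sometimes 2, not 3
```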
```diff
+                random.sample(
+                    remaining_tokens_below_threshold,
+                    num_remaining_tokens_to_highlight,
+                )
             )
     else:
-        random_indices = random.sample(below_threshold.tolist(), n_incorrect)
+        random_indices = random.sample(
+            tokens_below_threshold.tolist(), num_tokens_to_highlight
+        )
 
     random_indices = set(random_indices)
 
-    def check(i):
+    def is_false_positive(i):
         return i in random_indices
 
-    return _highlight(str_toks, check), str_toks
+    return _highlight(str_toks, is_false_positive), str_toks
 
 
-def _highlight(tokens, check):
-    result = []
+def _highlight(tokens: list[str], check: Callable[[int], bool]) -> str:
+    result: deque[str] = deque()
 
-    i = 0
-    while i < len(tokens):
-        if check(i):
-            result.append(L)
+    tokens_grouped_by_check_fn = groupby(
+        enumerate(tokens), key=lambda item: check(item[0])
+    )
 
-            while i < len(tokens) and check(i):
-                result.append(tokens[i])
-                i += 1
+    for should_highlight, token_group in tokens_grouped_by_check_fn:
+        highlighted_tokens = deque(token for _token_index, token in token_group)
 
-            result.append(R)
-        else:
-            result.append(tokens[i])
-            i += 1
+        if should_highlight:
+            highlighted_tokens.appendleft(L)
+            highlighted_tokens.append(R)
+
+        result.extend(highlighted_tokens)
 
     return "".join(result)
```
> **Review comment:** I think this comment actually should say "when examples are centered, the activating token is always blah". When the examples are not centered (and it's not easy to check from inside this function whether the activating examples are centered), this might bias the non-activating examples toward always having the same tokens selected. It probably doesn't matter much, though, because the model being explained and the explainer model don't share a tokenizer, so that information shouldn't leak much.
>
> **Reply:** Good point; perhaps there should be additional assertions to make sure the activating token is centered.
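A possible shape for the suggested assertion, sketched with illustrative names (not the repository's API): compare the argmax of the activations against the position `_prepare_text` assumes.

```python
import torch

# Hypothetical helper (names are illustrative): verify the max-activation
# token sits at the position _prepare_text assumes before building
# false-positive examples.
def assert_activating_token_is_centered(
    activations: torch.Tensor, str_toks: list[str]
) -> None:
    expected_pos = len(str_toks) - len(str_toks) // 4
    actual_pos = int(torch.argmax(activations).item())
    assert actual_pos == expected_pos, (
        f"Expected activating token at {expected_pos}, found it at {actual_pos}; "
        "examples may not be centered."
    )
```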