fix bugs in _prepare_text and IntruderScorer prompt formatting #153
Changes from 5 commits
The first file in the diff:

```diff
@@ -1,6 +1,7 @@
 import asyncio
 import re
 from dataclasses import dataclass
+from itertools import cycle, groupby
 from typing import Literal

 from beartype.typing import Sequence
```
@@ -136,12 +137,15 @@ def _get_quantiled_examples( | |
| """ | ||
| Get the quantiled examples. | ||
| """ | ||
| quantiles = {} | ||
| for example in examples: | ||
| if example.quantile not in quantiles: | ||
| quantiles[example.quantile] = [] | ||
| quantiles[example.quantile].append(example) | ||
| return quantiles | ||
| examples_sorted_by_quantile = sorted(examples, key=lambda x: x.quantile) | ||
| examples_grouped_by_quantile = groupby( | ||
| examples_sorted_by_quantile, key=lambda x: x.quantile | ||
| ) | ||
|
|
||
| return { | ||
| quantile: list(examples) | ||
| for quantile, examples in examples_grouped_by_quantile | ||
| } | ||
|
|
||
| def _prepare_and_batch(self, record: LatentRecord) -> list[IntruderSentence]: | ||
| """ | ||
|
|
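A note on the `groupby` rewrite: `itertools.groupby` only merges *consecutive* items with equal keys, which is why the examples are sorted by quantile first. A minimal standalone sketch of the pattern, using `(name, quantile)` tuples in place of the real example objects:

```python
from itertools import groupby

# Toy stand-ins for the real examples: (name, quantile) tuples.
examples = [("a", 1), ("b", 0), ("c", 1), ("d", 0)]

# Without the sort, groupby would emit a separate group each time
# the quantile value changes mid-sequence.
examples_sorted = sorted(examples, key=lambda x: x[1])
grouped = {
    quantile: list(group)
    for quantile, group in groupby(examples_sorted, key=lambda x: x[1])
}
print(grouped)  # {0: [('b', 0), ('d', 0)], 1: [('a', 1), ('c', 1)]}
```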
@@ -153,38 +157,37 @@ def _prepare_and_batch(self, record: LatentRecord) -> list[IntruderSentence]: | |
| quantiled_intruder_sentences = self._get_quantiled_examples(record.test) | ||
|
|
||
| intruder_sentences = record.not_active | ||
| for i, intruder in enumerate(intruder_sentences): | ||
| # select each quantile equally | ||
| quantile_index = i % len(quantiled_intruder_sentences.keys()) | ||
|
|
||
| active_examples = quantiled_intruder_sentences[quantile_index] | ||
| # select each quantile equally by repeatedly cycling through them | ||
| quantile_iterator = cycle(quantiled_intruder_sentences.items()) | ||
| for (active_quantile, all_active_examples), intruder in zip( | ||
| quantile_iterator, intruder_sentences | ||
| ): | ||
| # if there are more examples than the number of examples to show, | ||
| # sample which examples to show | ||
| examples_to_show = min(self.n_examples_shown - 1, len(active_examples)) | ||
| example_indices = self.rng.sample( | ||
| range(len(active_examples)), examples_to_show | ||
| num_active_examples = min( | ||
| self.n_examples_shown - 1, len(all_active_examples) | ||
| ) | ||
| active_examples = [active_examples[i] for i in example_indices] | ||
| active_examples = self.rng.sample(all_active_examples, num_active_examples) | ||
|
|
||
| # convert the examples to strings | ||
|
|
||
| # highlights the active tokens | ||
| majority_examples = [] | ||
| active_tokens = 0 | ||
| # highlights the active tokens with <<>> markers | ||
| examples = [] | ||
|
||
| num_active_tokens = 0 | ||
| for example in active_examples: | ||
| text, _ = _prepare_text( | ||
| text, _str_tokens = _prepare_text( | ||
| example, n_incorrect=0, threshold=0.3, highlighted=True | ||
| ) | ||
| majority_examples.append(text) | ||
| active_tokens += (example.activations > 0).sum().item() | ||
| active_tokens = int(active_tokens / len(active_examples)) | ||
| examples.append(text) | ||
| num_active_tokens += (example.activations > 0).sum().item() | ||
|
|
||
| avg_active_tokens_per_example = num_active_tokens // len(active_examples) | ||
| if self.type == "default": | ||
| # if example is contrastive, use the active tokens | ||
| # otherwise use the non-activating tokens | ||
| if intruder.activations.max() > 0: | ||
| n_incorrect = 0 | ||
| else: | ||
| n_incorrect = active_tokens | ||
| n_incorrect = avg_active_tokens_per_example | ||
| intruder_sentence, _ = _prepare_text( | ||
| intruder, | ||
| n_incorrect=n_incorrect, | ||
|
|
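The `cycle`/`zip` combination does the same round-robin pairing as the old modular indexing, but without manual bookkeeping: `cycle` yields the quantiles forever, and `zip` stops as soon as the finite intruder list runs out. A small sketch of just that iteration pattern, with strings standing in for example objects:

```python
from itertools import cycle

quantiles = {0: "q0 examples", 1: "q1 examples", 2: "q2 examples"}
intruders = ["i1", "i2", "i3", "i4", "i5"]

# zip stops after len(intruders) pairs; cycle restarts at quantile 0 as needed.
for (quantile, actives), intruder in zip(cycle(quantiles.items()), intruders):
    print(quantile, intruder)
# 0 i1
# 1 i2
# 2 i3
# 0 i4
# 1 i5
```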
@@ -194,22 +197,15 @@ def _prepare_and_batch(self, record: LatentRecord) -> list[IntruderSentence]: | |
| elif self.type == "internal": | ||
| # randomly select a quantile to be the intruder, make sure it's not | ||
| # the same as the source quantile | ||
| intruder_quantile_index = self.rng.randint( | ||
| 0, len(quantiled_intruder_sentences.keys()) - 1 | ||
| ) | ||
| while intruder_quantile_index == quantile_index: | ||
| intruder_quantile_index = self.rng.randint( | ||
| 0, len(quantiled_intruder_sentences.keys()) - 1 | ||
| ) | ||
| posible_intruder_sentences = quantiled_intruder_sentences[ | ||
| intruder_quantile_index | ||
| ] | ||
| intruder_index_selected = self.rng.randint( | ||
| 0, len(posible_intruder_sentences) - 1 | ||
| ) | ||
| intruder = posible_intruder_sentences[intruder_index_selected] | ||
| all_quantile_examples = list(quantiled_intruder_sentences.values()) | ||
| all_quantile_examples.remove(all_active_examples) | ||
| possible_intruder_sentences = self.rng.choice(all_quantile_examples) | ||
|
|
||
| intruder = self.rng.choice(possible_intruder_sentences) | ||
| # here the examples are activating, so we have to convert them | ||
| # to non-activating examples | ||
| assert intruder.str_tokens is not None, "intruder has no str_tokens" | ||
|
|
||
| non_activating_intruder = NonActivatingExample( | ||
| tokens=intruder.tokens, | ||
| activations=intruder.activations, | ||
|
|
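The `remove`/`choice` pair replaces the old rejection-sampling `while` loop: deleting the active quantile's example list from the candidates guarantees a different quantile in a single draw. A rough sketch, assuming `rng` is a `random.Random` instance as the `sample`/`randint` calls elsewhere suggest:

```python
import random

rng = random.Random(0)
quantiled = {0: ["a0", "a1"], 1: ["b0", "b1"], 2: ["c0"]}
all_active_examples = quantiled[1]  # the quantile currently being shown

candidates = list(quantiled.values())
# remove() matches by equality, which is fine as long as the lists differ
candidates.remove(all_active_examples)
intruder = rng.choice(rng.choice(candidates))  # pick a quantile, then an example
print(intruder)  # always from quantile 0 or 2, never 1
```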
@@ -226,21 +222,19 @@ def _prepare_and_batch(self, record: LatentRecord) -> list[IntruderSentence]: | |
| intruder = non_activating_intruder | ||
|
|
||
| # select a random index to insert the intruder sentence | ||
| intruder_index = self.rng.randint(0, examples_to_show) | ||
| majority_examples.insert(intruder_index, intruder_sentence) | ||
| intruder_index = self.rng.randint(0, num_active_examples) | ||
| examples.insert(intruder_index, intruder_sentence) | ||
|
|
||
| activations = [example.activations.tolist() for example in active_examples] | ||
| tokens = [example.str_tokens for example in active_examples] | ||
| activations.insert(intruder_index, intruder.activations.tolist()) | ||
| tokens.insert(intruder_index, intruder.str_tokens) | ||
| example_activations = [example.activations.tolist() for example in examples] | ||
| example_tokens = [example.str_tokens for example in examples] | ||
|
|
||
| batches.append( | ||
| IntruderSentence( | ||
| examples=majority_examples, | ||
| examples=examples, | ||
| intruder_index=intruder_index, | ||
| chosen_quantile=quantile_index, | ||
| activations=activations, | ||
| tokens=tokens, | ||
| chosen_quantile=active_quantile, | ||
| activations=example_activations, | ||
| tokens=example_tokens, | ||
| intruder_distance=intruder.distance, | ||
| ) | ||
| ) | ||
|
|
@@ -275,7 +269,7 @@ def _build_prompt( | |
| """ | ||
|
|
||
| examples = "\n".join( | ||
| f"Example {i}: {example}" for i, example in enumerate(sample.examples) | ||
SrGonao marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| f"Example {i}:{example}" for i, example in enumerate(sample.examples) | ||
| ) | ||
|
|
||
| return self.prompt(examples=examples) | ||
|
|
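The only change here is dropping the space after the colon. The likely reason — an assumption on my part, not stated in the diff — is that each prepared example string already starts with its own whitespace from the detokenized tokens, so the literal space produced a doubled separator:

```python
# Hypothetical example text: detokenized strings often keep the
# leading space of their first token.
example = " The <<cat>> sat"
print(f"Example 0: {example}")  # Example 0:  The <<cat>> sat  (double space)
print(f"Example 0:{example}")   # Example 0: The <<cat>> sat
```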
@@ -319,7 +313,6 @@ async def _generate(self, sample: IntruderSentence) -> IntruderResult: | |
| # default result is a error | ||
| return IntruderResult() | ||
| else: | ||
|
|
||
| try: | ||
| interpretation, prediction = self._parse(response.text) | ||
| except Exception as e: | ||
|
|
||
The second file in the diff:

```diff
@@ -1,6 +1,8 @@
 import random
+from collections import deque
 from dataclasses import dataclass
-from typing import NamedTuple
+from itertools import groupby
+from typing import Callable, NamedTuple

 import torch
```
@@ -88,73 +90,95 @@ def _prepare_text( | |
| threshold: float, | ||
| highlighted: bool, | ||
| ) -> tuple[str, list[str]]: | ||
| assert n_incorrect >= 0, ( | ||
| "n_incorrect must be 0 if highlighting correct example " | ||
| "or positive if creating false positives. " | ||
| f"Got {n_incorrect}" | ||
| ) | ||
|
|
||
| str_toks = example.str_tokens | ||
| assert str_toks is not None, "str_toks were not set" | ||
| clean = "".join(str_toks) | ||
|
|
||
| # Just return text if there's no highlighting | ||
| if not highlighted: | ||
| clean = "".join(str_toks) | ||
|
|
||
| return clean, str_toks | ||
|
|
||
| threshold = threshold * example.max_activation | ||
| abs_threshold = threshold * example.max_activation | ||
|
|
||
| # Highlight tokens with activations above threshold | ||
| # if correct example | ||
| if n_incorrect == 0: | ||
|
|
||
| def threshold_check(i): | ||
| return example.activations[i] >= threshold | ||
| def is_above_activation_threshold(i: int) -> bool: | ||
| return example.activations[i] >= abs_threshold | ||
|
|
||
| return _highlight(str_toks, threshold_check), str_toks | ||
| return _highlight(str_toks, is_above_activation_threshold), str_toks | ||
|
|
||
| # Highlight n_incorrect tokens with activations | ||
| # below threshold if incorrect example | ||
| below_threshold = torch.nonzero(example.activations <= threshold).squeeze() | ||
| tokens_below_threshold = torch.nonzero( | ||
| example.activations <= abs_threshold | ||
| ).squeeze() | ||
|
|
||
| # Rare case where there are no tokens below threshold | ||
| if below_threshold.dim() == 0: | ||
| logger.error("Failed to prepare example.") | ||
| if tokens_below_threshold.dim() == 0: | ||
| logger.error( | ||
| f"Tried to prepare false-positive example with {n_incorrect} tokens " | ||
| "incorrectly highlighted, but no tokens were below activation threshold." | ||
| ) | ||
| return DEFAULT_MESSAGE, str_toks | ||
|
|
||
| random.seed(22) | ||
|
|
||
| n_incorrect = min(n_incorrect, len(below_threshold)) | ||
| num_tokens_to_highlight = min(n_incorrect, tokens_below_threshold.shape[0]) | ||
|
|
||
| # The activating token is always ctx_len - ctx_len//4 | ||
> **Review comment:** I think this actually should say "when examples are centered the activating token is always blab".
>
> **Reply:** good point, perhaps there should be additional assertions to make sure the activating token is centered
```diff
-    # so we always highlight this one, and if n_incorrect > 1
-    # we highlight n_incorrect-1 random ones
+    # so we always highlight this one, and if num_tokens_to_highlight > 1
+    # we highlight num_tokens_to_highlight - 1 random ones
     token_pos = len(str_toks) - len(str_toks) // 4
-    if token_pos in below_threshold:
+    if token_pos in tokens_below_threshold:
         random_indices = [token_pos]
-        if n_incorrect > 1:
-            random_indices.extend(
-                random.sample(below_threshold.tolist(), n_incorrect - 1)
-            )
```
> **Review comment:** in the old code, this could result in `token_pos` being sampled again, leaving fewer than `n_incorrect` distinct highlighted tokens
>
> **Reply:** Good catch
```diff
+        num_remaining_tokens_to_highlight = num_tokens_to_highlight - 1
+        if num_remaining_tokens_to_highlight > 0:
+            remaining_tokens_below_threshold = tokens_below_threshold.tolist()
+            remaining_tokens_below_threshold.remove(token_pos)
+            random_indices.extend(
+                random.sample(
+                    remaining_tokens_below_threshold,
+                    num_remaining_tokens_to_highlight,
+                )
+            )
     else:
-        random_indices = random.sample(below_threshold.tolist(), n_incorrect)
+        random_indices = random.sample(
+            tokens_below_threshold.tolist(), num_tokens_to_highlight
+        )

     random_indices = set(random_indices)

-    def check(i):
+    def is_false_positive(i):
         return i in random_indices

-    return _highlight(str_toks, check), str_toks
+    return _highlight(str_toks, is_false_positive), str_toks


-def _highlight(tokens, check):
-    result = []
+def _highlight(tokens: list[str], check: Callable[[int], bool]) -> str:
+    result: deque[str] = deque()
+
+    tokens_grouped_by_check_fn = groupby(
+        enumerate(tokens), key=lambda item: check(item[0])
+    )

-    i = 0
-    while i < len(tokens):
-        if check(i):
-            result.append(L)
+    for should_highlight, token_group in tokens_grouped_by_check_fn:
+        highlighted_tokens = deque(token for _token_index, token in token_group)

-            while i < len(tokens) and check(i):
-                result.append(tokens[i])
-                i += 1
+        if should_highlight:
+            highlighted_tokens.appendleft(L)
+            highlighted_tokens.append(R)

-            result.append(R)
-        else:
-            result.append(tokens[i])
-            i += 1
+        result.extend(highlighted_tokens)

     return "".join(result)
```
> **Review comment:** Why would we want to `break`?
>
> **Reply:** once a pipe returns `None`, we iterate through the rest of the pipes and simply hit the `pass` until we return the result, which is `None`. this skips that pointless extra iteration; perhaps it may be clearer if we replace `break` with `return result` or `return None`?
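For readers without the surrounding file: the loop being discussed is not part of the diffs above, but the control flow in question looks roughly like this hypothetical reconstruction —

```python
# Hypothetical reconstruction of the pipe loop under discussion;
# names and structure are assumptions, not code from this PR.
def run_pipes(pipes, value):
    result = value
    for pipe in pipes:
        result = pipe(result)
        if result is None:
            break  # skip the remaining pipes instead of no-op iterating them
    return result  # None if any pipe returned None
```

As the reply notes, `return result` in place of `break` would behave the same here and may read more clearly.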