11import asyncio
22import re
3+ from collections import defaultdict
34from dataclasses import dataclass
5+ from itertools import cycle
46from typing import Literal
57
68from beartype .typing import Sequence
@@ -136,12 +138,11 @@ def _get_quantiled_examples(
136138 """
137139 Get the quantiled examples.
138140 """
139- quantiles = {}
141+ examples_grouped_by_quantiles = defaultdict ( list )
140142 for example in examples :
141- if example .quantile not in quantiles :
142- quantiles [example .quantile ] = []
143- quantiles [example .quantile ].append (example )
144- return quantiles
143+ examples_grouped_by_quantiles [example .quantile ].append (example )
144+
145+ return examples_grouped_by_quantiles
145146
146147 def _prepare_and_batch (self , record : LatentRecord ) -> list [IntruderSentence ]:
147148 """
@@ -153,38 +154,39 @@ def _prepare_and_batch(self, record: LatentRecord) -> list[IntruderSentence]:
153154 quantiled_intruder_sentences = self ._get_quantiled_examples (record .test )
154155
155156 intruder_sentences = record .not_active
156- for i , intruder in enumerate (intruder_sentences ):
157- # select each quantile equally
158- quantile_index = i % len (quantiled_intruder_sentences .keys ())
159157
160- active_examples = quantiled_intruder_sentences [quantile_index ]
158+ # select each quantile equally by repeatedly cycling through them
159+ quantile_iterator = cycle (quantiled_intruder_sentences .items ())
160+ for (active_quantile , all_active_examples ), intruder in zip (
161+ quantile_iterator , intruder_sentences
162+ ):
161163 # if there are more examples than the number of examples to show,
162164 # sample which examples to show
163- examples_to_show = min (self .n_examples_shown - 1 , len (active_examples ))
164- example_indices = self .rng .sample (
165- range (len (active_examples )), examples_to_show
165+ num_active_examples = min (
166+ # - 1 because we are going to insert the intruder sentence
167+ self .n_examples_shown - 1 ,
168+ len (all_active_examples ),
166169 )
167- active_examples = [active_examples [i ] for i in example_indices ]
168-
169- # convert the examples to strings
170+ active_examples = self .rng .sample (all_active_examples , num_active_examples )
170171
171- # highlights the active tokens
172+ # highlights the active tokens with <<>> markers
172173 majority_examples = []
173- active_tokens = 0
174+ num_active_tokens = 0
174175 for example in active_examples :
175- text , _ = _prepare_text (
176+ text , _str_tokens = _prepare_text (
176177 example , n_incorrect = 0 , threshold = 0.3 , highlighted = True
177178 )
178179 majority_examples .append (text )
179- active_tokens += (example .activations > 0 ).sum ().item ()
180- active_tokens = int (active_tokens / len (active_examples ))
180+ num_active_tokens += (example .activations > 0 ).sum ().item ()
181+
182+ avg_active_tokens_per_example = num_active_tokens // len (active_examples )
181183 if self .type == "default" :
182184 # if example is contrastive, use the active tokens
183185 # otherwise use the non-activating tokens
184186 if intruder .activations .max () > 0 :
185187 n_incorrect = 0
186188 else :
187- n_incorrect = active_tokens
189+ n_incorrect = avg_active_tokens_per_example
188190 intruder_sentence , _ = _prepare_text (
189191 intruder ,
190192 n_incorrect = n_incorrect ,
@@ -194,22 +196,15 @@ def _prepare_and_batch(self, record: LatentRecord) -> list[IntruderSentence]:
194196 elif self .type == "internal" :
195197 # randomly select a quantile to be the intruder, make sure it's not
196198 # the same as the source quantile
197- intruder_quantile_index = self .rng .randint (
198- 0 , len (quantiled_intruder_sentences .keys ()) - 1
199- )
200- while intruder_quantile_index == quantile_index :
201- intruder_quantile_index = self .rng .randint (
202- 0 , len (quantiled_intruder_sentences .keys ()) - 1
203- )
204- posible_intruder_sentences = quantiled_intruder_sentences [
205- intruder_quantile_index
206- ]
207- intruder_index_selected = self .rng .randint (
208- 0 , len (posible_intruder_sentences ) - 1
209- )
210- intruder = posible_intruder_sentences [intruder_index_selected ]
199+ all_quantile_examples = list (quantiled_intruder_sentences .values ())
200+ all_quantile_examples .remove (all_active_examples )
201+ possible_intruder_sentences = self .rng .choice (all_quantile_examples )
202+
203+ intruder = self .rng .choice (possible_intruder_sentences )
211204 # here the examples are activating, so we have to convert them
212205 # to non-activating examples
206+ assert intruder .str_tokens is not None , "intruder has no str_tokens"
207+
213208 non_activating_intruder = NonActivatingExample (
214209 tokens = intruder .tokens ,
215210 activations = intruder .activations ,
@@ -224,23 +219,27 @@ def _prepare_and_batch(self, record: LatentRecord) -> list[IntruderSentence]:
224219 highlighted = True ,
225220 )
226221 intruder = non_activating_intruder
222+ else :
223+ raise ValueError ("Invalid intruder scorer type" )
227224
228225 # select a random index to insert the intruder sentence
229- intruder_index = self .rng .randint (0 , examples_to_show )
230- majority_examples .insert (intruder_index , intruder_sentence )
226+ intruder_index = self .rng .randint (0 , num_active_examples )
227+ examples = (
228+ majority_examples [:intruder_index ]
229+ + [intruder_sentence ]
230+ + majority_examples [intruder_index :]
231+ )
231232
232- activations = [example .activations .tolist () for example in active_examples ]
233- tokens = [example .str_tokens for example in active_examples ]
234- activations .insert (intruder_index , intruder .activations .tolist ())
235- tokens .insert (intruder_index , intruder .str_tokens )
233+ example_activations = [example .activations .tolist () for example in examples ]
234+ example_tokens = [example .str_tokens for example in examples ]
236235
237236 batches .append (
238237 IntruderSentence (
239- examples = majority_examples ,
238+ examples = examples ,
240239 intruder_index = intruder_index ,
241- chosen_quantile = quantile_index ,
242- activations = activations ,
243- tokens = tokens ,
240+ chosen_quantile = active_quantile ,
241+ activations = example_activations ,
242+ tokens = example_tokens ,
244243 intruder_distance = intruder .distance ,
245244 )
246245 )
@@ -275,7 +274,7 @@ def _build_prompt(
275274 """
276275
277276 examples = "\n " .join (
278- f"Example { i } : { example } " for i , example in enumerate (sample .examples )
277+ f"Example { i } :{ example } " for i , example in enumerate (sample .examples )
279278 )
280279
281280 return self .prompt (examples = examples )
@@ -311,21 +310,11 @@ async def _generate(self, sample: IntruderSentence) -> IntruderResult:
311310 prompt = self ._build_prompt (sample )
312311 try :
313312 response = await self .client .generate (prompt , ** self .generation_kwargs )
313+ interpretation , prediction = self ._parse (response .text )
314314 except Exception as e :
315- logger .error (f"Error generating text: { e } " )
316- response = None
317-
318- if response is None :
315+ logger .error (str (e ))
319316 # default result is a error
320317 return IntruderResult ()
321- else :
322-
323- try :
324- interpretation , prediction = self ._parse (response .text )
325- except Exception as e :
326- logger .error (f"Parsing selections failed: { e } " )
327- # default result is a error
328- return IntruderResult ()
329318
330319 # check that the only prediction is the intruder
331320 correct = prediction == sample .intruder_index
0 commit comments