@@ -38,23 +38,31 @@ def setup(self, grammar_name):
38
38
)
39
39
40
40
@staticmethod
41
- def _run_random_cfg (guide ):
41
+ def _run_random_cfg (guide , rejection_sampling = True ):
42
42
state = guide .initial_state
43
43
token_ids = list (guide .tokenizer .vocabulary .values ())
44
44
for i in range (40 ):
45
45
# simulate ordering of logits top prob to lowest prob
46
46
random .shuffle (token_ids )
47
47
# simulate sampling and state update
48
- next_token_id = next (guide .iter_valid_token_ids (state , token_ids ))
49
- state = guide .get_next_state (state , next_token_id )
48
+ if rejection_sampling :
49
+ next_token_id = next (guide .iter_valid_token_ids (state , token_ids ))
50
+ state = guide .get_next_state (state , next_token_id )
51
+ else :
52
+ next_token_id = random .choice (guide .get_next_instruction (state ).tokens )
53
+ state = guide .get_next_state (state , next_token_id )
50
54
51
55
@cache_disabled ()
52
56
def time_cfg_guide_setup (self , grammar_name ):
53
57
CFGGuide (benched_grammars [grammar_name ], self .tokenizer )
54
58
59
+ @cache_disabled ()
60
+ def time_cfg_guide_run_rejection_sampling (self , grammar ):
61
+ self ._run_random_cfg (self .prebuilt_cfg_guide , rejection_sampling = True )
62
+
55
63
@cache_disabled ()
56
64
def time_cfg_guide_run (self , grammar ):
57
- self ._run_random_cfg (self .prebuilt_cfg_guide )
65
+ self ._run_random_cfg (self .prebuilt_cfg_guide , rejection_sampling = False )
58
66
59
67
@cache_disabled ()
60
68
def peakmem_cfg_guide_run (self , grammar ):
0 commit comments