Skip to content

Commit 5922b3b

Browse files
committed
Add benchmark: CFG rejection sampling + CFG no rejection sampling
1 parent a1a017a commit 5922b3b

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

benchmarks/bench_cfg_guide.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -38,23 +38,31 @@ def setup(self, grammar_name):
3838
)
3939

4040
@staticmethod
41-
def _run_random_cfg(guide):
41+
def _run_random_cfg(guide, rejection_sampling=True):
4242
state = guide.initial_state
4343
token_ids = list(guide.tokenizer.vocabulary.values())
4444
for i in range(40):
4545
# simulate ordering of logits top prob to lowest prob
4646
random.shuffle(token_ids)
4747
# simulate sampling and state update
48-
next_token_id = next(guide.iter_valid_token_ids(state, token_ids))
49-
state = guide.get_next_state(state, next_token_id)
48+
if rejection_sampling:
49+
next_token_id = next(guide.iter_valid_token_ids(state, token_ids))
50+
state = guide.get_next_state(state, next_token_id)
51+
else:
52+
next_token_id = random.choice(guide.get_next_instruction(state).tokens)
53+
state = guide.get_next_state(state, next_token_id)
5054

5155
@cache_disabled()
5256
def time_cfg_guide_setup(self, grammar_name):
5357
CFGGuide(benched_grammars[grammar_name], self.tokenizer)
5458

59+
@cache_disabled()
60+
def time_cfg_guide_run_rejection_sampling(self, grammar):
61+
self._run_random_cfg(self.prebuilt_cfg_guide, rejection_sampling=True)
62+
5563
@cache_disabled()
5664
def time_cfg_guide_run(self, grammar):
57-
self._run_random_cfg(self.prebuilt_cfg_guide)
65+
self._run_random_cfg(self.prebuilt_cfg_guide, rejection_sampling=False)
5866

5967
@cache_disabled()
6068
def peakmem_cfg_guide_run(self, grammar):

0 commit comments

Comments
 (0)