
Commit 19129e5

add rejection sampling to CFGLogitsProcessor
1 parent 21d61d1

File tree

5 files changed: +90 -17 lines changed

benchmarks/bench_cfg_guide.py

Lines changed: 6 additions & 2 deletions

@@ -40,9 +40,13 @@ def setup(self, grammar_name):
     @staticmethod
     def _run_random_cfg(guide):
         state = guide.initial_state
+
         for i in range(40):
-            next_instruction = guide.get_next_instruction(state)
-            next_token_id = random.choice(next_instruction.tokens)
+            # simulate ordering of logits top prob to lowest prob
+            token_ids = list(range(guide.tokenizer.vocabulary))
+            random.shuffle(token_ids)
+            # simulate sampling and state update
+            next_token_id = next(guide.iter_valid_token_ids(state, token_ids))
             state = guide.get_next_state(state, next_token_id)

     @cache_disabled()
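
The rewritten benchmark exercises the rejection-sampling path: validity is checked lazily, in a simulated probability order, instead of eagerly for the entire vocabulary. A self-contained sketch of the difference (the validity check and vocabulary size are toy assumptions, not the outlines API):

import random

def iter_valid(candidates, is_valid):
    # Lazily yield candidates that pass the validity check.
    for c in candidates:
        if is_valid(c):
            yield c

def is_even(t):
    return t % 2 == 0  # toy stand-in for a grammar check

vocab = list(range(50_000))  # assumed vocabulary size
random.shuffle(vocab)        # stand-in for sorting by model probability

# Eager: checks all 50,000 candidates before choosing one.
eager_valid = [t for t in vocab if is_even(t)]

# Lazy (rejection sampling): stops at the first acceptable candidate, so
# on average only a handful of checks when valid tokens are common.
first_valid = next(iter_valid(vocab, is_even))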

outlines/fsm/guide.py

Lines changed: 39 additions & 10 deletions

@@ -5,6 +5,7 @@
     Any,
     Callable,
     Dict,
+    Generator,
     List,
     Optional,
     Protocol,
@@ -351,26 +352,54 @@ def get_next_instruction(
         if parser_state is None:
             return Write(torch.tensor([self.eos_token_id]))

-        valid_tokens = []
-        for test_token, token_id in self.tokenizer.vocabulary.items():
+        valid_tokens = list(
+            self.iter_valid_token_ids(parser_state, self.tokenizer.vocabulary.values())
+        )
+        if len(valid_tokens) == 1:
+            return Write(torch.tensor(valid_tokens))
+        return Generate(torch.tensor(valid_tokens))
+
+    def iter_valid_token_ids(
+        self, parser_state: Optional[PartialParserState], token_ids: list
+    ) -> Generator[int, None, None]:
+        """
+        Iterate over the given token_ids and yield those that are valid for the current parser state.
+
+        Parameters
+        ----------
+        parser_state
+            The current state of the parser, or None if complete.
+        token_ids
+            The list of token ids to check for validity.
+
+        Yields
+        ------
+        int
+            Valid token ids.
+        """
+        if parser_state is None:
+            yield self.eos_token_id
+            return
+
+        for token_id in token_ids:
             if token_id == self.eos_token_id:
                 if self.can_terminate_state(parser_state):
-                    valid_tokens.append(token_id)
-
+                    yield token_id
             else:
                 ps = copy.copy(parser_state)
                 ls = ps.lexer.state
-                ls.text += self.tokenizer.convert_token_to_string(test_token)
+                token_str = self.tokenizer.convert_token_to_string(
+                    self.tokenizer.decode([token_id])[0]
+                )
+                if token_str == "":
+                    continue
+                ls.text += token_str
                 try:
                     self.parser.parse_from_state(ps, is_end=False)
-                    valid_tokens.append(token_id)
+                    yield token_id
                 except (EOFError, UnexpectedToken, UnexpectedCharacters, DedentError):
                     pass

-        if len(valid_tokens) == 1:
-            return Write(torch.tensor(valid_tokens))
-        return Generate(torch.tensor(valid_tokens))
-
     def get_next_state(
         self, parser_state: Optional[PartialParserState], token_id: int
     ) -> Optional[PartialParserState]:
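
Both call sites use the same generator; they differ only in how far they consume it. A minimal usage sketch, assuming `guide` is an already-constructed CFGGuide and `state` its current parser state (construction omitted):

# Assumes `guide` (a CFGGuide) and `state` (its parser state) exist.

# Eager, as in get_next_instruction: collect every valid token id.
all_valid = list(
    guide.iter_valid_token_ids(state, guide.tokenizer.vocabulary.values())
)

# Lazy, as in the rejection-sampling processor: stop at the first valid
# id in a caller-chosen (e.g. probability-sorted) order.
ranked_ids = sorted(guide.tokenizer.vocabulary.values())  # stand-in ranking
best_valid = next(guide.iter_valid_token_ids(state, ranked_ids))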

outlines/processors/structured.py

Lines changed: 42 additions & 4 deletions

@@ -24,7 +24,7 @@
 limitations under the License.
 """
 import math
-from typing import TYPE_CHECKING, Dict, List, Optional, Type, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union

 import torch
 from pydantic import BaseModel
@@ -50,6 +50,11 @@ class GuideLogitsProcessor(OutlinesLogitsProcessor):
         The `outlines.fsm.Guide` which is used to bias the logits.
     """

+    tokenizer: "Tokenizer"
+    guide: Guide
+    _guide_states: Dict[int, Any]
+    _seq_start_idx: Optional[int]
+
     def __init__(self, tokenizer: "Tokenizer", guide: Guide):
         """A Guide-based logits processor.

@@ -61,9 +66,9 @@ def __init__(self, tokenizer: "Tokenizer", guide: Guide):
             The `outlines.fsm.Guide. which is used to bias the logits.
         """
         self.tokenizer = tokenizer
-        self.guide: Guide = guide
-        self._guide_states: Dict[int, int] = {hash(tuple([])): self.guide.initial_state}
-        self._seq_start_idx: Optional[int] = None
+        self.guide = guide
+        self._guide_states = {hash(tuple([])): self.guide.initial_state}
+        self._seq_start_idx = None

     def process_logits(
         self, input_ids: List[List[int]], logits: torch.Tensor
@@ -181,6 +186,8 @@ class CFGLogitsProcessor(GuideLogitsProcessor):
         The `outlines.fsm.CFGGuide. which is used to bias the logits.
     """

+    guide: CFGGuide
+
     def __init__(self, cfg_str: str, tokenizer: "Tokenizer"):
         """Compile the CFGGuide that drives the CFG-guided generation.

@@ -193,3 +200,34 @@ def __init__(self, cfg_str: str, tokenizer: "Tokenizer"):
         """
         cfg_guide = CFGGuide(cfg_string=cfg_str, tokenizer=tokenizer)
         super().__init__(tokenizer=tokenizer, guide=cfg_guide)
+
+    def process_logits(
+        self, input_ids: List[List[int]], logits: torch.Tensor
+    ) -> torch.Tensor:
+        """Same behavior as GuideLogitsProcessor, but uses rejection sampling"""
+        if self._seq_start_idx is None:
+            self._seq_start_idx = len(input_ids[0])
+
+        sequence_states: List = []  # vector of states corresponding to `input_ids`
+
+        for seq_ids in input_ids:
+            gen_ids = seq_ids[self._seq_start_idx :]
+            curr_state_key = hash(tuple(gen_ids))
+
+            if curr_state_key not in self._guide_states:
+                prev_state = self._guide_states[hash(tuple(gen_ids[:-1]))]
+                curr_state = self.guide.get_next_state(prev_state, gen_ids[-1])
+                self._guide_states[curr_state_key] = curr_state
+
+            sequence_states.append(self._guide_states[curr_state_key])
+
+        mask = torch.full_like(logits, -math.inf)
+        for i, guide_state in enumerate(sequence_states):
+            first_legal_token = next(
+                self.guide.iter_valid_token_ids(
+                    guide_state, torch.argsort(logits[i], descending=True)
+                )
+            )
+            mask[i, [first_legal_token]] = logits[i, [first_legal_token]]
+
+        return mask
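
The new process_logits is where rejection sampling pays off: rather than computing the full valid set at each step, it masks everything except the single highest-scoring token the grammar accepts. A self-contained sketch of that masking step (ToyGuide is an assumed stand-in for a real CFGGuide, accepting only even token ids):

import math
import torch

class ToyGuide:
    def iter_valid_token_ids(self, state, token_ids):
        # Pretend the grammar only accepts even token ids.
        for t in token_ids:
            if int(t) % 2 == 0:
                yield int(t)

guide = ToyGuide()
logits = torch.randn(2, 10)  # batch of 2 sequences, vocabulary of 10

mask = torch.full_like(logits, -math.inf)
for i in range(logits.shape[0]):
    # First grammar-valid id in descending-logit order, as above.
    ranked = torch.argsort(logits[i], descending=True)
    best = next(guide.iter_valid_token_ids(None, ranked))
    mask[i, best] = logits[i, best]

print(mask.argmax(dim=-1))

Note that only one logit per row survives the mask, so whatever sampler runs downstream can only pick the best valid token for that step.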

tests/fsm/test_cfg_guide.py

Lines changed: 2 additions & 0 deletions

@@ -430,6 +430,8 @@ def test_cfg_grammar_sample(request, sample_name, tokenizer_name, cleanup_lark_i

     state = cfg_guide.initial_state
     for i, token_id in enumerate(sample_token_ids):
+        if tokenizer.decode(token_id)[0] == "":
+            continue
         next_instruction = cfg_guide.get_next_instruction(state)
         if token_id not in next_instruction.tokens:
             processed_str = tokenizer.decode([sample_token_ids[:i]])[0]
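
This guard mirrors the `if token_str == "": continue` branch added to iter_valid_token_ids: tokens that decode to the empty string contribute no text and so can never advance the lexer. A toy illustration (the ids and decodings here are assumptions, for intuition only):

# Assumed ids and decodings: special tokens such as BOS/EOS often
# decode to "" and add no text for the parser to consume.
decoded = {101: "", 7592: "hello", 102: ""}
sample_token_ids = [101, 7592, 102]

for token_id in sample_token_ids:
    if decoded[token_id] == "":
        continue  # same skip as the test above
    print(token_id, "feeds", repr(decoded[token_id]), "to the parser")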

tests/generate/test_generate.py

Lines changed: 1 addition & 1 deletion

@@ -133,7 +133,7 @@ def sample_choices():
 def sample_lark_grammar():
     # from https://github.com/lark-parser/lark/blob/master/docs/grammar.md
     return """
-    ?start: (hello_world | number)
+    ?start: hello_world "!" number
     hello_world: ("hello" | "world") ~ 3
     number: ("0".."9") ~ 5
     thanks: "Thank"i " for testing!"
