Skip to content

Commit 54aaf87

Browse files
committed
Implement AlignmentGuide
1 parent 289ef5d commit 54aaf87

File tree

2 files changed

+335
-0
lines changed

2 files changed

+335
-0
lines changed

outlines/fsm/guide.py

Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,43 @@ def is_final_state(self, state: Any) -> bool:
9090
def copy(self) -> "Guide":
    """Interface stub: concrete guides return a Guide from this method."""
    ...
9292

93+
def accepts(self, token_ids: List[int], state=None) -> bool:
    """
    Report whether the Guide can consume `token_ids` without rejecting them.

    The sequence only has to be a legal continuation from `state`; it does
    not have to bring the guide to completion. `state` is forwarded to
    `derive` unchanged.
    """
    try:
        self.derive(token_ids, state)
    except ValueError:
        # a ValueError from `derive` is treated as "sequence rejected"
        return False
    else:
        return True
103+
104+
def derive(self, token_ids: List[int], state=None) -> Any:
    """
    Advance the guide through `token_ids` and return the resulting state.

    Parameters
    ----------
    token_ids
        Token ids to consume, in order.
    state
        State to start from. ``None`` means start from ``self.initial_state``.

    Returns
    -------
    The state reached after consuming every token. With an empty
    `token_ids` this is the starting state itself.

    Raises
    ------
    ValueError
        If a `Generate` instruction restricts the allowed tokens and the
        next token is not among them.
    NotImplementedError
        If the guide returns a `Write` instruction (not supported yet).
    TypeError
        If the guide returns anything other than `Write` or `Generate`.
    """
    if state is None:
        state = self.initial_state

    for token_id in token_ids:
        instruction = self.get_next_instruction(state)

        # determine if token_id is allowed by the instruction
        if isinstance(instruction, Write):
            raise NotImplementedError("TODO")
        elif isinstance(instruction, Generate):
            # `tokens is None` means the instruction places no restriction
            allowed = instruction.tokens
            if allowed is not None and token_id not in allowed:
                raise ValueError("Cannot advance state with provided token_ids")
        else:
            raise TypeError(f"Expected instruction, got {instruction}")

        # advance state
        state = self.get_next_state(state, token_id)

    return state
129+
93130

94131
class StopAtEOSGuide(Guide):
95132
"""Guide to generate tokens until the EOS token has been generated."""
@@ -487,3 +524,214 @@ def must_terminate_state(self, state: CFGState) -> bool:
487524
def copy(self) -> "CFGGuide":
    """Build a new CFGGuide from this guide's grammar string and tokenizer."""
    duplicate = CFGGuide(self.cfg_string, self.tokenizer)
    return duplicate
527+
528+
529+
@cache()
def build_vocab_prefix_map(tokenizer: "Tokenizer") -> Dict[str, Set[Tuple[str, Tuple]]]:
    """Map every non-empty token prefix to the ways it completes into a token.

    For each vocabulary token and each split ``token == prefix + suffix``
    with a non-empty ``prefix``, the entry for ``prefix`` contains
    ``(suffix, suffix_token_ids)``. ``suffix_token_ids`` is the tokenizer's
    encoding of ``suffix`` with ``tokenizer.pad_token_id`` entries removed;
    the empty suffix maps to the empty tuple.

    Parameters
    ----------
    tokenizer
        Must expose ``vocabulary`` (token string -> id), ``pad_token_id`` and
        ``encode(list_of_str)`` returning ``(ids, _)`` where ``ids`` has a
        ``tolist()`` method.

    Returns
    -------
    A ``defaultdict(set)`` keyed by prefix string.
    """
    vocabulary = tokenizer.vocabulary

    # precompute the token ids of all proper, non-empty vocab suffixes
    suffixes = list({tok[i:] for tok in vocabulary for i in range(1, len(tok))})
    suffix_map: Dict[str, Tuple] = {"": tuple()}
    if suffixes:
        # skip encoding entirely when every token is a single character:
        # there is nothing to encode and tokenizers may reject an empty batch
        encoded_suffixes, _ = tokenizer.encode(suffixes)
        pad_token_id = tokenizer.pad_token_id
        for suffix, seq_ids in zip(suffixes, encoded_suffixes.tolist()):
            suffix_map[suffix] = tuple(
                tok for tok in seq_ids if tok != pad_token_id
            )

    # compute prefix-suffix map for all tokens, s.t. prefix + suffix = token
    prefix_map = collections.defaultdict(set)
    for token in vocabulary:
        for i in range(1, len(token) + 1):
            suffix = token[i:]
            prefix_map[token[:i]].add((suffix, suffix_map[suffix]))
    return prefix_map
551+
552+
553+
# State carried by AlignmentGuide between tokens.
#   legal_path_map:    remaining alignment token paths -> token ids used to
#                      advance the child guide; None once alignment is done.
#   child_guide_state: state of the child guide (None while unused).
AlignmentGuideState = collections.namedtuple(
    "AlignmentGuideState", "legal_path_map child_guide_state"
)
556+
557+
558+
class AlignmentGuide(Guide):
    """Guide that lets generation re-tokenize the end of a prompt.

    The prompt is split into an ``alignment_prompt`` (the token prefix shared
    by every candidate tokenization) and a set of "legal paths": token
    sequences that finish the prompt and may run into the start of the
    generation. Until one path is fully consumed only tokens continuing a
    path are allowed; afterwards control passes to ``child_guide`` (or
    generation is unconstrained when there is none).
    """

    def __init__(
        self, prompt: str, tokenizer: "Tokenizer", child_guide: Optional[Guide] = None
    ):
        """
        Initialize the AlignmentGuide with a prompt, tokenizer, and an optional child guide.

        Parameters
        ----------
        prompt : str
            The prompt text to be aligned with the generated tokens.
        tokenizer : Tokenizer
            Tokenizer used to align the prompt.
        child_guide : Guide, optional
            A guide to take control after alignment is complete. None -> Unconstrained after alignment

        Raises
        ------
        ValueError
            If no vocabulary token can be aligned with the end of `prompt`.
        """
        self.prompt = prompt
        self.tokenizer = tokenizer
        self.child_guide = child_guide

        alignment_seqs, child_guide_ids = self._get_alignment_sequences(
            prompt, tokenizer, child_guide
        )
        alignment_prompt_ids, common_prompt_len = self._get_longest_common_prompt_ids(
            alignment_seqs
        )

        # NOTE(review): taking a common *prefix* presumes the tokenizer pads
        # on the right - confirm for tokenizers configured to pad left.
        self.alignment_prompt = self.tokenizer.decode([alignment_prompt_ids])[0]

        # calculate map of alignment_prompt continuation tokens -> child_guide advancement tokens
        legal_paths = [
            tuple(t for t in seq if t != tokenizer.pad_token_id)
            for seq in alignment_seqs[:, common_prompt_len:].tolist()
        ]
        legal_path_map = dict(zip(legal_paths, child_guide_ids))

        if () in legal_path_map:
            # One candidate is exactly the shared prefix: nothing is left to
            # align, so start in the completed state instead of keeping an
            # empty path around (indexing its first token would fail).
            self.initial_state = self._completed_state(legal_path_map[()], None)
        else:
            self.initial_state = AlignmentGuideState(
                legal_path_map=legal_path_map, child_guide_state=None
            )

    @staticmethod
    def _get_alignment_sequences(
        prompt: str, tokenizer: "Tokenizer", child_guide: Optional[Guide] = None
    ):
        """
        Calculate all possible sequences which are valid with a prompt + child_guide
        E.g. prompt="hello wo", child guide accepts "rld" -> tokenization ["hello", "world"] is valid

        Returns tuple of (alignment_seqs, child_guide_ids) of same length
        - alignment_seqs:
            All token sequences which can represent `prompt` + start of generation. The last token
            must represent the end of the prompt and can extend beyond the prompt to start generation.
            Sequences are only included if the start of generation portion is legal with child guide.
        - child_guide_ids:
            Tokens to send to the child guide to simulate the start of generation. In the example above
            "world" is the last alignment seq token, therefore we must advance the state of the child
            guide with the tokenization of "rld" in order to continue generation with the child guide.

        Raises ValueError when no candidate exists.
        """
        # cache of suffix acceptance for child_guide.accepts()
        guide_accepts: Dict[Tuple[int, ...], bool] = {}

        def child_accepts(suffix_ids) -> bool:
            # dict.setdefault would evaluate accepts() on every call and
            # defeat the cache, so test membership explicitly.
            if suffix_ids not in guide_accepts:
                guide_accepts[suffix_ids] = child_guide.accepts(suffix_ids)
            return guide_accepts[suffix_ids]

        # prompts with alignment tokens at end
        aligned_prompt_completions: List[str] = []
        # tokens to feed child guide once alignment completes
        child_guide_ids: List[Tuple] = []

        # compute alignment seqs which are valid with prompt and child guide
        for prefix, alignment_details in build_vocab_prefix_map(tokenizer).items():
            if not prompt.endswith(prefix):
                continue
            # sorted only to make the candidate order reproducible
            for suffix, suffix_ids in sorted(alignment_details):
                if child_guide is None:
                    aligned_prompt_completions.append(prompt + suffix)
                    child_guide_ids.append(tuple())
                elif child_accepts(suffix_ids):
                    aligned_prompt_completions.append(prompt + suffix)
                    child_guide_ids.append(suffix_ids)

        if not aligned_prompt_completions:
            raise ValueError(
                "Cannot align prompt: no vocabulary token starts with text "
                f"that ends the prompt {prompt!r}"
            )

        alignment_seqs, _ = tokenizer.encode(aligned_prompt_completions)
        return alignment_seqs, child_guide_ids

    @staticmethod
    def _get_longest_common_prompt_ids(alignment_seqs):
        """
        Among all candidate prompt alignment seqs, get the longest shared prefix and its length

        `alignment_seqs` is indexed as a 2-D tensor (sequence, position).
        Returns (ids of the shared prefix taken from row 0, prefix length).
        """
        # a position is shared when every row agrees with row 0 there;
        # comparing against one row avoids a rows x rows x positions tensor
        shared_position = (alignment_seqs == alignment_seqs[0]).all(0)
        # cumprod zeroes everything after the first disagreement
        common_len = int(shared_position.cumprod(0).sum().item())
        return alignment_seqs[0, :common_len], common_len

    def _completed_state(self, child_guide_ids, child_guide_state) -> AlignmentGuideState:
        """Return the post-alignment state, advancing the child guide by `child_guide_ids`."""
        if self.child_guide is None:
            return AlignmentGuideState(None, None)
        return AlignmentGuideState(
            legal_path_map=None,
            child_guide_state=self.child_guide.derive(
                child_guide_ids, child_guide_state
            ),
        )

    def get_next_instruction(self, state: AlignmentGuideState) -> Instruction:
        """
        Return the next set of valid tokens for generation based on the current state.

        If alignment hasn't completed:
            tokens which continue one of the candidate alignment paths are legal
        If alignment has completed:
            get instruction from the child guide (unconstrained if there is none)
        """
        if state.legal_path_map is not None:
            # skip empty paths defensively: they have no first token
            return Generate(
                sorted({path[0] for path in state.legal_path_map if path})
            )
        elif self.child_guide is None:
            return Generate(None)
        else:
            return self.child_guide.get_next_instruction(state.child_guide_state)

    def get_next_state(
        self, state: AlignmentGuideState, token_id: int
    ) -> AlignmentGuideState:
        """
        Get AlignmentGuideState advanced by token ID.

        If alignment has completed:
            Advance the child guide state (if there is a child guide)
        If alignment hasn't completed:
            Filter out alignment paths which don't start with token_id
            Remove first token from remaining paths
            If a remaining path is now empty, alignment is complete:
                Advance the child_guide state with that path's token ids

        Raises
        ------
        ValueError
            If alignment hasn't completed and `token_id` continues no path.
        """
        if state.legal_path_map is None:
            if self.child_guide is None:
                return AlignmentGuideState(None, None)
            return AlignmentGuideState(
                legal_path_map=None,
                child_guide_state=self.child_guide.get_next_state(
                    state.child_guide_state, token_id
                ),
            )

        next_state_legal_path_map = {
            path[1:]: ids
            for path, ids in state.legal_path_map.items()
            if path and path[0] == token_id
        }

        if not next_state_legal_path_map:
            # previously this fell through to next(iter(...)) on an empty
            # dict (StopIteration) or silently completed alignment
            raise ValueError(
                f"Token {token_id} does not continue any alignment path"
            )

        # A fully consumed path finishes alignment, even if longer paths
        # sharing this token remain; keeping the empty path around would
        # fail on the next first-token lookup.
        if () in next_state_legal_path_map:
            return self._completed_state(
                next_state_legal_path_map[()], state.child_guide_state
            )

        # paths remaining: return advanced legal_path_map
        return AlignmentGuideState(
            legal_path_map=next_state_legal_path_map,
            child_guide_state=state.child_guide_state,
        )

    def is_final_state(self, state: AlignmentGuideState) -> bool:
        """Final only after alignment; then defer to the child guide (True if none)."""
        if state.legal_path_map is not None:
            return False
        elif self.child_guide is None:
            return True
        else:
            return self.child_guide.is_final_state(state.child_guide_state)

    def copy(self):
        """AlignmentGuide isn't mutated"""
        return self

tests/fsm/test_alignment_guide.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
import torch
2+
from transformers import AutoTokenizer
3+
4+
from outlines.fsm.guide import AlignmentGuide, RegexGuide
5+
from outlines.models.transformers import TransformerTokenizer
6+
7+
8+
class MockTokenizer:
    """Tiny greedy longest-match tokenizer for tests.

    Token ids are the positions in `vocabulary`, followed by an appended
    ``"<eos>"`` token. ``pad_token_id`` is ``-1``; it pads batches and also
    marks a text that cannot be tokenized.
    """

    def __init__(self, vocabulary):
        self.vocabulary = {tok: i for i, tok in enumerate(vocabulary)}
        self.vocabulary["<eos>"] = len(self.vocabulary)
        self.special_tokens = {"<eos>"}
        self.eos_token_id = self.vocabulary["<eos>"]
        self.pad_token_id = -1

        self.inverse_vocabulary = {i: tok for tok, i in self.vocabulary.items()}

    def convert_token_to_string(self, token):
        """Tokens are already plain strings."""
        return token

    def decode(self, token_ids):
        """
        Decode token ids.

        - empty input -> ""
        - a batch (first element is a list, tuple or non-scalar tensor)
          -> one joined string per sequence, padding ids skipped
        - a flat sequence of ids -> list with one token string per id
        """
        token_ids = list(token_ids)
        if not token_ids:
            return ""

        first = token_ids[0]
        # tensors expose `ndim`; 0-d tensors and plain ints count as scalars
        is_batch = isinstance(first, (list, tuple)) or getattr(first, "ndim", 0) > 0
        if is_batch:
            return [
                "".join(
                    self.inverse_vocabulary[int(token_id)]
                    for token_id in sequence
                    if int(token_id) != self.pad_token_id
                )
                for sequence in token_ids
            ]
        return [self.inverse_vocabulary[int(token_id)] for token_id in token_ids]

    def encode(self, texts):
        """
        Encodes the input texts by finding the longest matching tokens in the vocabulary.

        `texts` may be a single string or a list of strings. A text that
        cannot be fully tokenized is encoded as ``[pad_token_id]``.
        Returns ``(ids, None)`` where ``ids`` is a right-padded int64 tensor
        of shape (number of texts, longest sequence).
        """
        if isinstance(texts, str):
            texts = [texts]

        # longest first, computed once; empty tokens could never consume text
        candidates = [
            tok for tok in sorted(self.vocabulary, key=len, reverse=True) if tok
        ]

        seqs = []
        for text in texts:
            tokens = []
            while text:
                token = next(
                    (tok for tok in candidates if text.startswith(tok)), None
                )
                if token is None:
                    tokens = [self.pad_token_id]
                    break
                tokens.append(self.vocabulary[token])
                text = text[len(token) :]
            seqs.append(tokens)

        max_len = max((len(seq) for seq in seqs), default=0)
        padded_seqs = torch.tensor(
            [seq + [self.pad_token_id] * (max_len - len(seq)) for seq in seqs],
            dtype=torch.long,
        ).reshape(len(seqs), max_len)
        return padded_seqs, None
59+
60+
61+
def test_alignment_with_pseudo_token_and_regex_guide():
    """Prompt "hello wo" + regex "rld!" must allow the healed path " world", "!"."""
    # Mock vocabulary: "hello", " world", " wo", "rld" and "!"
    vocab = ["hello", " world", " wo", "rld", "!"]
    tokenizer = MockTokenizer(vocab)

    # Child guide expecting the continuation "rld!"
    regex_guide = RegexGuide(regex_string="rld!", tokenizer=tokenizer)

    # Alignment guide wrapping the child guide
    guide = AlignmentGuide("hello wo", tokenizer, child_guide=regex_guide)

    assert guide.alignment_prompt == "hello"

    # " world" followed by "!" is legal and reaches a final state
    token_path = [tokenizer.vocabulary[piece] for piece in (" world", "!")]
    assert guide.accepts(token_path)
    end_state = guide.derive(token_path, guide.initial_state)
    assert guide.is_final_state(end_state) is True
78+
79+
80+
def test_alignment_guide_gpt2_url():
    """With GPT-2, "http:" is healed so that "://google.com" stays a legal path."""
    # Based on notebook
    # https://github.com/guidance-ai/guidance/blob/af63e6/notebooks/tutorials/token_healing.ipynb#L4
    tokenizer = TransformerTokenizer(AutoTokenizer.from_pretrained("gpt2"))
    prompt = "The url of Google is http:"
    guide = AlignmentGuide(prompt, tokenizer)
    assert guide.alignment_prompt == "The url of Google is http"

    # `accepts` takes a list of ints; tolist() avoids handing it 0-d tensors
    batch_ids, _ = tokenizer.encode("://google.com")
    assert guide.accepts(batch_ids[0].tolist())

0 commit comments

Comments
 (0)