
Commit 0d68474

Implement AlignmentGuide
1 parent 289ef5d commit 0d68474

File tree

1 file changed: +242 −0 lines changed

outlines/fsm/guide.py

@@ -90,6 +90,37 @@ def is_final_state(self, state: Any) -> bool:

    def copy(self) -> "Guide":
        ...

    def accepts(self, token_ids: List[int], state=None) -> bool:
        """
        Determine whether the sequence, `token_ids`, is accepted by the Guide.
        `token_ids` doesn't need to complete the guide to be accepted.
        """
        derived = self.derive(token_ids, state)
        return derived is not None
    def derive(self, token_ids: List[int], state=None) -> Union["Guide", None]:
        if state is None:
            state = self.initial_state
        for token_id in token_ids:
            instruction = self.get_next_instruction(state)

            # determine if token_id allowed by instruction
            if isinstance(instruction, Write):
                raise NotImplementedError("TODO")
            elif isinstance(instruction, Generate):
                if (
                    instruction.tokens is not None
                    and token_id not in instruction.tokens
                ):
                    return None
            else:
                raise TypeError(f"Expected instruction, got {instruction}")

            # advance state
            state = self.get_next_state(state, token_id)

        return state

class StopAtEOSGuide(Guide):
    """Guide to generate tokens until the EOS token has been generated."""
@@ -487,3 +518,214 @@ def must_terminate_state(self, state: CFGState) -> bool:

    def copy(self) -> "CFGGuide":
        """Create a copy of the Guide."""
        return CFGGuide(self.cfg_string, self.tokenizer)

@cache()
def build_vocab_prefix_map(tokenizer: "Tokenizer") -> Dict[str, Set[Tuple[str, Tuple]]]:
    """Build a map from token prefix to Set[Tuple[suffix, suffix_token_ids]]"""

    # precompute the token ids of all vocab suffixes
    suffixes = list(
        {tok[i:] for tok in tokenizer.vocabulary for i in range(1, len(tok))}
    )
    encoded_suffixes, _ = tokenizer.encode(suffixes)
    encoded_suffixes = [
        [tok for tok in seq_ids if tok != tokenizer.pad_token_id]
        for seq_ids in encoded_suffixes.tolist()
    ]
    suffix_map = dict(zip(suffixes, map(tuple, encoded_suffixes)))
    suffix_map[""] = tuple()

    # compute prefix-suffix map for all tokens, s.t. prefix + suffix = token
    prefix_map = collections.defaultdict(set)
    for token, token_id in tokenizer.vocabulary.items():
        for i in range(1, len(token) + 1):
            prefix_map[token[:i]].add((token[i:], suffix_map[token[i:]]))
    return prefix_map
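
For intuition, a self-contained toy version of the structure this builds; the three-token vocabulary and the character-level stand-in for `tokenizer.encode` are made up for illustration:

    import collections

    vocabulary = {"wo": 5, "world": 7, "rld": 9}            # made-up tokens and ids
    fake_encode = lambda s: tuple(200 + ord(c) for c in s)  # stand-in for tokenizer.encode

    toy_prefix_map = collections.defaultdict(set)
    for token in vocabulary:
        for i in range(1, len(token) + 1):
            toy_prefix_map[token[:i]].add((token[i:], fake_encode(token[i:])))

    # A prompt ending in "wo" can keep the token "wo" as-is (empty suffix) or be
    # re-tokenized so that "world" straddles the prompt boundary (overhang "rld"):
    # toy_prefix_map["wo"] == {("", ()), ("rld", fake_encode("rld"))}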

AlignmentGuideState = collections.namedtuple(
    "AlignmentGuideState", ["legal_path_map", "child_guide_state"]
)

class AlignmentGuide(Guide):
    def __init__(
        self, prompt: str, tokenizer: "Tokenizer", child_guide: Optional[Guide] = None
    ):
        """
        Initialize the AlignmentGuide with a prompt, tokenizer, and an optional child guide.

        Parameters
        ----------
        prompt : str
            The prompt text to be aligned with the generated tokens.
        tokenizer : Tokenizer
            Tokenizer used to align the prompt.
        child_guide : Guide, optional
            A guide which takes control once alignment is complete. If None, generation
            is unconstrained after alignment.
        """
        self.prompt = prompt
        self.tokenizer = tokenizer
        self.child_guide = child_guide

        alignment_seqs, child_guide_ids = self._get_alignment_sequences(
            prompt, tokenizer, child_guide
        )
        alignment_prompt_ids, common_prompt_len = self._get_longest_common_prompt_ids(
            alignment_seqs
        )

        self.alignment_prompt = self.tokenizer.decode(
            [alignment_seqs[0, :common_prompt_len]]
        )[0]

        # calculate map of alignment_prompt continuation tokens -> child_guide advancement tokens
        legal_paths = [
            tuple([t for t in seq if t != tokenizer.pad_token_id])
            for seq in alignment_seqs[:, common_prompt_len:].tolist()
        ]
        legal_path_map = dict(zip(legal_paths, child_guide_ids))

        self.initial_state = AlignmentGuideState(
            legal_path_map=legal_path_map, child_guide_state=None
        )
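
A sketch of constructing the guide; the prompt and the exact alignment point are illustrative assumptions, since both depend on the model's vocabulary:

    guide = AlignmentGuide("hello wo", tokenizer, child_guide=None)

    # The text actually fed to the model is the longest prefix shared by every
    # candidate re-tokenization, which may stop short of the full prompt:
    print(guide.alignment_prompt)   # e.g. "hello" -- the trailing " wo" is regenerated
    # initial_state.legal_path_map maps each legal continuation (e.g. the ids of
    # " wo", " world", " wonder", ...) to the token ids the child guide should be
    # advanced with (empty tuples here, because there is no child guide).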

    @staticmethod
    def _get_alignment_sequences(
        prompt: str, tokenizer: "Tokenizer", child_guide: Optional[Guide] = None
    ):
        """
        Calculate all possible sequences which are valid with a prompt + child_guide.
        E.g. prompt="hello wo", child guide accepts "rld" -> tokenization ["hello", " world"] is valid

        Returns tuple of (alignment_seqs, child_guide_ids) of same length
        - alignment_seqs:
            All token sequences which can represent `prompt` + start of generation. The last token
            must contain the end of the prompt and can extend beyond the prompt to start generation.
            Sequences are only included if the start-of-generation portion is legal with the child guide.
        - child_guide_ids:
            Tokens to send to the child guide to simulate the start of generation. In the example above
            " world" is the last alignment seq token, therefore we must advance the state of the child
            guide with the tokenization of "rld" in order to continue generation with the child guide.
        """
        guide_accepts: Dict[
            Tuple[int], bool
        ] = {}  # cache of suffix acceptance for child_guide.accepts()

        # prompts with alignment tokens at end
        aligned_prompt_completions: List[str] = []
        # tokens to feed child guide once alignment completes
        child_guide_ids: List[Tuple] = []

        # compute alignment seqs which are valid with prompt and child guide
        for prefix, alignment_details in build_vocab_prefix_map(tokenizer).items():
            if prompt.endswith(prefix):
                for suffix, suffix_ids in alignment_details:
                    if child_guide is None:
                        aligned_prompt_completions.append(prompt + suffix)
                        child_guide_ids.append(tuple())
                    elif guide_accepts.setdefault(
                        suffix_ids, child_guide.accepts(suffix_ids)
                    ):
                        aligned_prompt_completions.append(prompt + suffix)
                        child_guide_ids.append(suffix_ids)

        alignment_seqs, _ = tokenizer.encode(aligned_prompt_completions)
        return alignment_seqs, child_guide_ids
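
Continuing the docstring's example, a sketch of the two parallel lists this returns; token strings stand in for ids and the exact entries depend on the vocabulary:

    # prompt = "hello wo", child guide accepts "r", "rl", "rld", ...
    #
    # aligned_prompt_completions: ["hello wo", "hello wor", "hello world", ...]
    # child_guide_ids:            [(),         ids("r"),    ids("rld"),    ...]
    #
    # Each completion is the prompt plus the overhanging suffix of one vocabulary token
    # that could end the prompt; the paired entry is that suffix's tokenization, used
    # later to fast-forward the child guide once alignment finishes on that path.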

    @staticmethod
    def _get_longest_common_prompt_ids(alignment_seqs):
        """
        Among all candidate prompt alignment seqs, get the longest shared prefix and its length
        """
        # get longest common prefix among alignment sequences, which will form our alignment prompt
        common = (
            (alignment_seqs.unsqueeze(1) == alignment_seqs.unsqueeze(0))
            .all(0)
            .cumprod(1)
        )
        common_len = common.sum(1).max().item()
        return alignment_seqs[0, :common_len], common_len
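
A worked example of the shared-prefix computation above on a tiny padded tensor (the ids are made up):

    import torch

    alignment_seqs = torch.tensor([
        [101, 7, 0, 0],   # ..., "world"
        [101, 5, 9, 0],   # ..., "wo", "rld"
        [101, 5, 2, 0],   # ..., "wo", some other token
    ])

    common = (
        (alignment_seqs.unsqueeze(1) == alignment_seqs.unsqueeze(0))
        .all(0)
        .cumprod(1)
    )
    # common[j, l] stays 1 only while position l agrees across every sequence,
    # so the shared prefix here is just [101]:
    common_len = common.sum(1).max().item()   # == 1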

    def get_next_instruction(self, state: AlignmentGuideState) -> Instruction:
        """
        Return the next set of valid tokens for generation based on the current state.

        If alignment hasn't completed:
            tokens which continue one of the candidate alignment paths are legal
        If alignment has completed:
            get instruction from the child guide
        """
        if state.legal_path_map is not None:
            return Generate(
                sorted({token_ids[0] for token_ids in state.legal_path_map.keys()})
            )
        elif self.child_guide is None:
            return Generate(None)
        else:
            return self.child_guide.get_next_instruction(state.child_guide_state)

    def get_next_state(
        self, state: AlignmentGuideState, token_id: int
    ) -> AlignmentGuideState:
        """
        Get AlignmentGuideState advanced by token ID.

        If alignment has completed:
            advance the child guide's state
        If alignment hasn't completed:
            Filter out alignment paths which don't start with token_id
            Remove the first token from remaining paths
            If advancing the state completes alignment:
                Advance the child_guide state
        """
        if state.legal_path_map is None:
            if self.child_guide is not None:
                return AlignmentGuideState(
                    legal_path_map=None,
                    child_guide_state=self.child_guide.get_next_state(
                        state.child_guide_state, token_id
                    ),
                )
            else:
                return AlignmentGuideState(None, None)
        else:
            next_state_legal_path_map = {
                key[1:]: value
                for key, value in state.legal_path_map.items()
                if key[0] == token_id
            }
            # if only the empty path remains, alignment completed with this token;
            # fast-forward the child guide with the stored suffix ids
            if not any(next_state_legal_path_map):
                if self.child_guide is not None:
                    child_guide_advancement_ids = next(
                        iter(next_state_legal_path_map.values())
                    )
                    return AlignmentGuideState(
                        legal_path_map=None,
                        child_guide_state=self.child_guide.derive(
                            child_guide_advancement_ids, state.child_guide_state
                        ),
                    )
                else:
                    return AlignmentGuideState(None, None)

            # if paths remaining, return advanced legal_path_map
            else:
                return AlignmentGuideState(
                    legal_path_map=next_state_legal_path_map,
                    child_guide_state=state.child_guide_state,
                )
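
A toy walkthrough of the filtering branch above with plain dicts (ids made up):

    # Two surviving alignment paths: token 5 ends the prompt exactly (nothing to feed
    # the child guide), token 7 overshoots it and the overhang tokenizes to (9,).
    legal_path_map = {(5,): (), (7,): (9,)}

    # After sampling token 7, only the second path matches and its remainder is empty:
    advanced = {key[1:]: value for key, value in legal_path_map.items() if key[0] == 7}
    assert advanced == {(): (9,)}
    # `not any(advanced)` is True, so alignment is finished and the child guide is
    # fast-forwarded with the stored suffix ids (9,) via child_guide.derive().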

    def is_final_state(self, state: AlignmentGuideState) -> bool:
        if state.legal_path_map is not None:
            return False
        elif self.child_guide is None:
            return True
        else:
            return self.child_guide.is_final_state(state.child_guide_state)

    def copy(self):
        """AlignmentGuide isn't mutated"""
        return self
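
Putting the pieces together, a hedged sketch of driving the new guide in a generation loop; the RegexGuide construction and the `sample_token` helper are assumptions for illustration, and Write instructions from the child guide are not handled:

    from outlines.fsm.guide import AlignmentGuide, RegexGuide

    child = RegexGuide(r"[0-9]{4}", tokenizer)   # assumed constructor
    guide = AlignmentGuide("The year is 1", tokenizer, child_guide=child)

    # Feed guide.alignment_prompt to the model instead of the raw prompt, then let the
    # guide regenerate the trailing characters and hand control to the child guide.
    state = guide.initial_state
    generated = []
    while not guide.is_final_state(state):
        instruction = guide.get_next_instruction(state)
        allowed = instruction.tokens                         # None means unconstrained
        token_id = sample_token(model, generated, allowed)   # hypothetical helper
        generated.append(token_id)
        state = guide.get_next_state(state, token_id)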

0 commit comments
