@@ -90,6 +90,37 @@ def is_final_state(self, state: Any) -> bool:
    def copy(self) -> "Guide":
        ...

+    def accepts(self, token_ids: List[int], state=None) -> bool:
+        """
+        Determine whether the sequence `token_ids` is accepted by the Guide.
+        `token_ids` doesn't need to complete the guide to be accepted.
+        """
+        derived = self.derive(token_ids, state)
+        return derived is not None
+
+    def derive(self, token_ids: List[int], state=None) -> Union[Any, None]:
+        if state is None:
+            state = self.initial_state
+        for token_id in token_ids:
+            instruction = self.get_next_instruction(state)
+
+            # determine whether token_id is allowed by the instruction
+            if isinstance(instruction, Write):
+                raise NotImplementedError("TODO")
+            elif isinstance(instruction, Generate):
+                if (
+                    instruction.tokens is not None
+                    and token_id not in instruction.tokens
+                ):
+                    return None
+            else:
+                raise TypeError(f"Expected instruction, got {instruction}")
+
+            # advance the state
+            state = self.get_next_state(state, token_id)
+
+        return state
+
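# A minimal sketch of how `accepts` and `derive` behave, assuming a made-up
# guide whose every state allows only token ids 1 and 2. The class name,
# token ids, and states below are illustrative assumptions, not part of the API.
class ToyGuide(Guide):
    initial_state = 0

    def get_next_instruction(self, state: Any) -> Instruction:
        # every state permits the same two tokens
        return Generate([1, 2])

    def get_next_state(self, state: Any, token_id: int) -> Any:
        # the state simply counts how many tokens have been consumed
        return state + 1

    def is_final_state(self, state: Any) -> bool:
        return True

    def copy(self) -> "Guide":
        return self


assert ToyGuide().accepts([1, 2, 1])     # every token allowed by Generate([1, 2])
assert not ToyGuide().accepts([1, 3])    # 3 is rejected, so derive() returns None
assert ToyGuide().derive([1, 2]) == 2    # the state reached after two tokens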

class StopAtEOSGuide(Guide):
    """Guide to generate tokens until the EOS token has been generated."""
@@ -487,3 +518,214 @@ def must_terminate_state(self, state: CFGState) -> bool:
    def copy(self) -> "CFGGuide":
        """Create a copy of the Guide."""
        return CFGGuide(self.cfg_string, self.tokenizer)
+
+
+@cache()
+def build_vocab_prefix_map(tokenizer: "Tokenizer") -> Dict[str, Set[Tuple[str, Tuple]]]:
+    """Build a map from token prefix to Set[Tuple[suffix, suffix_token_ids]]"""
+
+    # precompute the token ids of all vocab suffixes
+    suffixes = list(
+        {tok[i:] for tok in tokenizer.vocabulary for i in range(1, len(tok))}
+    )
+    encoded_suffixes, _ = tokenizer.encode(suffixes)
+    encoded_suffixes = [
+        [tok for tok in seq_ids if tok != tokenizer.pad_token_id]
+        for seq_ids in encoded_suffixes.tolist()
+    ]
+    suffix_map = dict(zip(suffixes, map(tuple, encoded_suffixes)))
+    suffix_map[""] = tuple()
+
+    # compute prefix-suffix map for all tokens, s.t. prefix + suffix = token
+    prefix_map = collections.defaultdict(set)
+    for token, token_id in tokenizer.vocabulary.items():
+        for i in range(1, len(token) + 1):
+            prefix_map[token[:i]].add((token[i:], suffix_map[token[i:]]))
+    return prefix_map
+
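# A minimal sketch of what build_vocab_prefix_map produces, assuming a toy
# character-level vocabulary and a stand-in encode function (both made up for
# illustration): every prefix of every vocab token maps to the set of
# (suffix, suffix_token_ids) pairs that complete it.
import collections

toy_vocabulary = {"h": 0, "e": 1, "l": 2, "o": 3, "hello": 4}


def toy_encode(text):
    # stand-in for tokenizer.encode: one id per character
    return tuple(toy_vocabulary[char] for char in text)


toy_prefix_map = collections.defaultdict(set)
for token in toy_vocabulary:
    for i in range(1, len(token) + 1):
        toy_prefix_map[token[:i]].add((token[i:], toy_encode(token[i:])))

assert toy_prefix_map["hel"] == {("lo", (2, 3))}
assert ("ello", (1, 2, 2, 3)) in toy_prefix_map["h"]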
+
+AlignmentGuideState = collections.namedtuple(
+    "AlignmentGuideState", ["legal_path_map", "child_guide_state"]
+)
+
+
+class AlignmentGuide(Guide):
+    def __init__(
+        self, prompt: str, tokenizer: "Tokenizer", child_guide: Optional[Guide] = None
+    ):
+        """
+        Initialize the AlignmentGuide with a prompt, tokenizer, and an optional child guide.
+
+        Parameters
+        ----------
+        prompt : str
+            The prompt text to be aligned with the generated tokens.
+        tokenizer : Tokenizer
+            The tokenizer used to align the prompt.
+        child_guide : Guide, optional
+            A guide which takes control once alignment is complete. If None,
+            generation is unconstrained after alignment.
+        """
+        self.prompt = prompt
+        self.tokenizer = tokenizer
+        self.child_guide = child_guide
+
+        alignment_seqs, child_guide_ids = self._get_alignment_sequences(
+            prompt, tokenizer, child_guide
+        )
+        alignment_prompt_ids, common_prompt_len = self._get_longest_common_prompt_ids(
+            alignment_seqs
+        )
+
+        self.alignment_prompt = self.tokenizer.decode(
+            [alignment_seqs[0, :common_prompt_len]]
+        )[0]
+
+        # calculate map of alignment_prompt continuation tokens -> child_guide advancement tokens
+        legal_paths = [
+            tuple([t for t in seq if t != tokenizer.pad_token_id])
+            for seq in alignment_seqs[:, common_prompt_len:].tolist()
+        ]
+        legal_path_map = dict(zip(legal_paths, child_guide_ids))
+
+        self.initial_state = AlignmentGuideState(
+            legal_path_map=legal_path_map, child_guide_state=None
+        )
+
+    @staticmethod
+    def _get_alignment_sequences(
+        prompt: str, tokenizer: "Tokenizer", child_guide: Optional[Guide] = None
+    ):
+        """
+        Calculate all possible token sequences which are valid with the prompt + child_guide,
+        e.g. prompt="hello wo", child guide accepts "rld" -> the tokenization
+        ["hello", " world"] is valid.
+
+        Returns a tuple of (alignment_seqs, child_guide_ids) of the same length:
+        - alignment_seqs:
+            All token sequences which can represent `prompt` + the start of generation.
+            The last token must reach the end of the prompt and may extend beyond it to
+            start generation. Sequences are only included if the start-of-generation
+            portion is legal with the child guide.
+        - child_guide_ids:
+            Tokens to send to the child guide to simulate the start of generation. In the
+            example above, " world" is the last alignment-seq token, therefore we must
+            advance the state of the child guide with the tokenization of "rld" in order
+            to continue generation with the child guide.
+        """
612
+ guide_accepts : Dict [
613
+ Tuple [int ], bool
614
+ ] = {} # cache of suffix acceptance for child_guide.accepts()
615
+
616
+ # prompts with alignment tokens at end
617
+ aligned_prompt_completions : List [str ] = []
618
+ # tokens to feed child guide once alignment completes
619
+ child_guide_ids : List [Tuple ] = []
620
+
621
+ # compute alignment seqs which are valid with prompt and child guide
622
+ for prefix , alignment_details in build_vocab_prefix_map (tokenizer ).items ():
623
+ if prompt .endswith (prefix ):
624
+ for suffix , suffix_ids in alignment_details :
625
+ if child_guide is None :
626
+ aligned_prompt_completions .append (prompt + suffix )
627
+ child_guide_ids .append (tuple ())
628
+ elif guide_accepts .setdefault (
629
+ suffix_ids , child_guide .accepts (suffix_ids )
630
+ ):
631
+ aligned_prompt_completions .append (prompt + suffix )
632
+ child_guide_ids .append (suffix_ids )
633
+
634
+ alignment_seqs , _ = tokenizer .encode (aligned_prompt_completions )
635
+ return alignment_seqs , child_guide_ids
636
+
+    @staticmethod
+    def _get_longest_common_prompt_ids(alignment_seqs):
+        """
+        Among all candidate prompt alignment seqs, get the longest shared prefix and its length
+        """
+        # get the longest common prefix among alignment sequences, which forms our alignment prompt
+        common = (
+            (alignment_seqs.unsqueeze(1) == alignment_seqs.unsqueeze(0))
+            .all(0)
+            .cumprod(1)
+        )
+        common_len = common.sum(1).max().item()
+        return alignment_seqs[0, :common_len], common_len
+
+    def get_next_instruction(self, state: AlignmentGuideState) -> Instruction:
+        """
+        Return the next set of valid tokens for generation based on the current state.
+
+        If alignment hasn't completed:
+            tokens which continue one of the candidate alignment paths are legal
+        If alignment has completed:
+            get the instruction from the child guide
+        """
+        if state.legal_path_map is not None:
+            return Generate(
+                sorted({token_ids[0] for token_ids in state.legal_path_map.keys()})
+            )
+        elif self.child_guide is None:
+            return Generate(None)
+        else:
+            return self.child_guide.get_next_instruction(state.child_guide_state)
+
+    def get_next_state(
+        self, state: AlignmentGuideState, token_id: int
+    ) -> AlignmentGuideState:
+        """
+        Get the AlignmentGuideState advanced by the token ID.
+
+        If alignment has completed:
+            Advance the state of the child guide
+        If alignment hasn't completed:
+            Filter out alignment paths which don't start with token_id
+            Remove the first token from the remaining paths
+            If advancing the state completes alignment:
+                Advance the child_guide state
+        """
+        if state.legal_path_map is None:
+            if self.child_guide is not None:
+                return AlignmentGuideState(
+                    legal_path_map=None,
+                    child_guide_state=self.child_guide.get_next_state(
+                        state.child_guide_state, token_id
+                    ),
+                )
+            else:
+                return AlignmentGuideState(None, None)
+        else:
+            next_state_legal_path_map = {
+                key[1:]: value
+                for key, value in state.legal_path_map.items()
+                if key[0] == token_id
+            }
+            # if no alignment tokens remain, alignment is complete: hand off to the child guide
+            if not any(next_state_legal_path_map):
+                if self.child_guide is not None:
+                    child_guide_advancement_ids = next(
+                        iter(next_state_legal_path_map.values())
+                    )
+                    return AlignmentGuideState(
+                        legal_path_map=None,
+                        child_guide_state=self.child_guide.derive(
+                            child_guide_advancement_ids, state.child_guide_state
+                        ),
+                    )
+                else:
+                    return AlignmentGuideState(None, None)
+
+            # if paths remain, return the advanced legal_path_map
+            else:
+                return AlignmentGuideState(
+                    legal_path_map=next_state_legal_path_map,
+                    child_guide_state=state.child_guide_state,
+                )
+
+    def is_final_state(self, state: AlignmentGuideState) -> bool:
+        if state.legal_path_map is not None:
+            return False
+        elif self.child_guide is None:
+            return True
+        else:
+            return self.child_guide.is_final_state(state.child_guide_state)
+
+    def copy(self):
+        """AlignmentGuide isn't mutated"""
+        return self
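# A standalone sketch of the alignment bookkeeping above, using made-up token
# ids (nothing here comes from a real tokenizer). legal_path_map maps the
# remaining alignment tokens to the ids that must later be replayed into the
# child guide; each generated token narrows it until only exhausted paths remain.
toy_legal_path_map = {
    (11,): (),        # the model re-emits exactly the end of the prompt
    (12, 13): (42,),  # the model crosses the prompt boundary; token 42 is then
                      # fed to the child guide to simulate the start of generation
}


def toy_advance(path_map, token_id):
    # the same filtering rule used in AlignmentGuide.get_next_state
    return {key[1:]: value for key, value in path_map.items() if key[0] == token_id}


step1 = toy_advance(toy_legal_path_map, 12)
assert step1 == {(13,): (42,)}

step2 = toy_advance(step1, 13)
assert step2 == {(): (42,)}
assert not any(step2)  # only an exhausted path remains: alignment is complete,
                       # and the child guide would be derived with (42,)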