@@ -90,6 +90,43 @@ def is_final_state(self, state: Any) -> bool:
90
90
def copy(self) -> "Guide":
    """Interface stub: concrete guides return a copy of themselves here."""
92
92
93
def accepts(self, token_ids: List[int], state=None) -> bool:
    """
    Report whether the Guide can consume `token_ids` starting from `state`.

    The sequence only has to be a legal start; it does not need to bring the
    guide to completion. Acceptance is decided by calling `self.derive` and
    treating a `ValueError` as rejection; any other exception propagates.
    """
    try:
        self.derive(token_ids, state)
    except ValueError:
        # derive raises ValueError when a token is not allowed
        return False
    else:
        return True
103
+
104
def derive(self, token_ids: List[int], state=None) -> Any:
    """
    Advance a guide state through `token_ids`, validating each token on the way.

    Parameters
    ----------
    token_ids
        Token ids to consume, in order.
    state
        State to start from. `None` means start from `self.initial_state`.

    Returns
    -------
    The state reached after consuming every token in `token_ids`. When
    `token_ids` is empty this is the starting state itself.

    Raises
    ------
    ValueError
        If a `Generate` instruction restricts the allowed tokens and the next
        token is not one of them.
    NotImplementedError
        If the guide returns a `Write` instruction; validating a token against
        `Write` is not supported yet.
    TypeError
        If the guide returns something that is neither `Write` nor `Generate`.
    """
    if state is None:
        state = self.initial_state

    for token_id in token_ids:
        instruction = self.get_next_instruction(state)

        # determine if token_id is allowed by the instruction
        if isinstance(instruction, Write):
            raise NotImplementedError("TODO")
        if not isinstance(instruction, Generate):
            raise TypeError(f"Expected instruction, got {instruction}")

        # `tokens is None` places no restriction on the next token
        allowed = instruction.tokens
        if allowed is not None and token_id not in allowed:
            raise ValueError("Cannot advance state with provided token_ids")

        # advance state
        state = self.get_next_state(state, token_id)

    return state
129
+
93
130
94
131
class StopAtEOSGuide (Guide ):
95
132
"""Guide to generate tokens until the EOS token has been generated."""
@@ -487,3 +524,214 @@ def must_terminate_state(self, state: CFGState) -> bool:
487
524
def copy(self) -> "CFGGuide":
    """Build a new CFGGuide from this guide's grammar string and tokenizer."""
    grammar, tokenizer = self.cfg_string, self.tokenizer
    return CFGGuide(grammar, tokenizer)
527
+
528
+
529
@cache()
def build_vocab_prefix_map(tokenizer: "Tokenizer") -> Dict[str, Set[Tuple[str, Tuple]]]:
    """
    Map every non-empty token prefix to the ways it completes into a token.

    For each vocabulary token and each split `token == prefix + suffix` with a
    non-empty prefix, the set stored under `prefix` contains the pair
    `(suffix, suffix_token_ids)`. `suffix_token_ids` is the tokenizer's
    encoding of `suffix` with padding ids removed; it is the empty tuple when
    the suffix is empty (i.e. the prefix is the whole token).
    """
    vocabulary = tokenizer.vocabulary

    # precompute the token ids of all (deduplicated) non-empty proper suffixes
    suffixes = list({tok[i:] for tok in vocabulary for i in range(1, len(tok))})

    suffix_map = {"": tuple()}
    if suffixes:  # nothing to encode when no token is longer than one character
        encoded_suffixes, _ = tokenizer.encode(suffixes)
        pad_id = tokenizer.pad_token_id
        for suffix, seq_ids in zip(suffixes, encoded_suffixes.tolist()):
            # batch encoding pads the rows; padding is not part of the suffix
            suffix_map[suffix] = tuple(tok for tok in seq_ids if tok != pad_id)

    # compute prefix-suffix map for all tokens, s.t. prefix + suffix = token
    prefix_map = collections.defaultdict(set)
    for token in vocabulary:
        for i in range(1, len(token) + 1):
            suffix = token[i:]
            prefix_map[token[:i]].add((suffix, suffix_map[suffix]))
    return prefix_map
551
+
552
+
553
# State carried by AlignmentGuide:
#   legal_path_map    - remaining alignment token paths, each mapped to the
#                       token ids used to advance the child guide once that
#                       path is finished; None after alignment has completed
#   child_guide_state - state of the child guide (None until it is advanced)
AlignmentGuideState = collections.namedtuple(
    "AlignmentGuideState", ("legal_path_map", "child_guide_state")
)
556
+
557
+
558
class AlignmentGuide(Guide):
    """
    Guide which first constrains generation to tokens that re-tokenize the end
    of a prompt (token alignment) and then defers to an optional child guide.

    While alignment is in progress the state holds `legal_path_map`, a mapping
    from the remaining token path of each candidate alignment to the token ids
    that must be fed to the child guide once that path is finished. After
    alignment, `legal_path_map` is None and only `child_guide_state` matters.
    """

    def __init__(
        self, prompt: str, tokenizer: "Tokenizer", child_guide: Optional[Guide] = None
    ):
        """
        Initialize the AlignmentGuide with a prompt, tokenizer, and an optional child guide.

        Parameters
        ----------
        prompt : str
            The prompt text to be aligned with the generated tokens.
        tokenizer : Tokenizer
            Tokenizer used to align the prompt.
        child_guide : Guide, optional
            A guide to take control after alignment is complete. None -> Unconstrained after alignment
        """
        self.prompt = prompt
        self.tokenizer = tokenizer
        self.child_guide = child_guide

        alignment_seqs, child_guide_ids = self._get_alignment_sequences(
            prompt, tokenizer, child_guide
        )
        alignment_prompt_ids, common_prompt_len = self._get_longest_common_prompt_ids(
            alignment_seqs
        )

        # the text shared by every candidate; generation starts after it
        self.alignment_prompt = self.tokenizer.decode([alignment_prompt_ids])[0]

        # calculate map of alignment_prompt continuation tokens -> child_guide advancement tokens
        pad_id = tokenizer.pad_token_id
        legal_paths = [
            tuple(t for t in seq if t != pad_id)
            for seq in alignment_seqs[:, common_prompt_len:].tolist()
        ]
        legal_path_map = dict(zip(legal_paths, child_guide_ids))

        if () in legal_path_map:
            # A candidate is exactly the shared prefix, so there is nothing left
            # to align. Keeping an empty path in the map would make
            # get_next_instruction index into an empty tuple.
            self.initial_state = self._completed_state(legal_path_map[()], None)
        else:
            self.initial_state = AlignmentGuideState(
                legal_path_map=legal_path_map, child_guide_state=None
            )

    def _completed_state(
        self, child_guide_advancement_ids, child_guide_state
    ) -> AlignmentGuideState:
        """Build the post-alignment state, advancing the child guide if there is one."""
        if self.child_guide is None:
            return AlignmentGuideState(None, None)
        return AlignmentGuideState(
            legal_path_map=None,
            child_guide_state=self.child_guide.derive(
                child_guide_advancement_ids, child_guide_state
            ),
        )

    @staticmethod
    def _get_alignment_sequences(
        prompt: str, tokenizer: "Tokenizer", child_guide: Optional[Guide] = None
    ):
        """
        Calculate all possible sequences which are valid with a prompt + child_guide
        E.g. prompt="hello wo", child guide accepts "rld" -> tokenization ["hello", "world"] is valid

        Returns tuple of (alignment_seqs, child_guide_ids) of same length
        - alignment_seqs:
            Encodings of `prompt + suffix` for every vocabulary token whose prefix
            ends the prompt, i.e. the prompt plus the start of generation.
            Sequences are only included if the suffix is legal with the child guide.
        - child_guide_ids:
            Tokens to send to the child guide to simulate the start of generation. In the example
            above "world" is the last alignment seq token, therefore we must advance the state of
            the child guide with the tokenization of "rld" in order to continue generation with
            the child guide. Empty tuples when there is no child guide.
        """
        # cache of suffix acceptance for child_guide.accepts()
        guide_accepts: Dict[Tuple[int, ...], bool] = {}

        # prompts with alignment tokens at end
        aligned_prompt_completions: List[str] = []
        # tokens to feed child guide once alignment completes
        child_guide_ids: List[Tuple] = []

        # compute alignment seqs which are valid with prompt and child guide
        for prefix, alignment_details in build_vocab_prefix_map(tokenizer).items():
            if not prompt.endswith(prefix):
                continue
            for suffix, suffix_ids in alignment_details:
                if child_guide is None:
                    aligned_prompt_completions.append(prompt + suffix)
                    child_guide_ids.append(tuple())
                    continue

                # Explicit membership test: dict.setdefault would evaluate
                # child_guide.accepts() on every iteration, cached or not.
                if suffix_ids not in guide_accepts:
                    guide_accepts[suffix_ids] = child_guide.accepts(suffix_ids)
                if guide_accepts[suffix_ids]:
                    aligned_prompt_completions.append(prompt + suffix)
                    child_guide_ids.append(suffix_ids)

        alignment_seqs, _ = tokenizer.encode(aligned_prompt_completions)
        return alignment_seqs, child_guide_ids

    @staticmethod
    def _get_longest_common_prompt_ids(alignment_seqs):
        """
        Among all candidate prompt alignment seqs, get the longest shared prefix and its length

        Returns `(prefix_ids, prefix_len)` where `prefix_ids` is the first
        `prefix_len` columns of row 0.

        NOTE(review): assumes `alignment_seqs` is a 2-D tensor in which column i
        is token position i for every row (i.e. padding on the right) - confirm
        against the tokenizer's padding side.
        """
        # A column is shared iff every row agrees with row 0 there; this avoids
        # materialising an all-pairs (rows x rows x length) comparison.
        column_is_shared = (alignment_seqs == alignment_seqs[0]).all(0)
        # cumprod zeroes everything after the first mismatch, so the sum is the
        # length of the leading run of shared columns
        common_len = column_is_shared.cumprod(0).sum().item()
        return alignment_seqs[0, :common_len], common_len

    def get_next_instruction(self, state: AlignmentGuideState) -> Instruction:
        """
        Return the next set of valid tokens for generation based on the current state.

        If alignment hasn't completed:
            tokens which continue one of the candidate alignment paths are legal
        If alignment has completed:
            get instruction from the child guide (unconstrained if there is none)
        """
        if state.legal_path_map is not None:
            return Generate(sorted({token_ids[0] for token_ids in state.legal_path_map}))
        elif self.child_guide is None:
            return Generate(None)
        else:
            return self.child_guide.get_next_instruction(state.child_guide_state)

    def get_next_state(
        self, state: AlignmentGuideState, token_id: int
    ) -> AlignmentGuideState:
        """
        Get AlignmentGuideState advanced by token ID.

        If alignment has completed:
            Advance the child guide state
        If alignment hasn't completed:
            Filter out alignment paths which don't start with token_id
            Remove First token from remaining paths
            If advancing the state completes an alignment path:
                Advance the child_guide state with that path's tokens

        Raises
        ------
        ValueError
            If alignment is in progress and `token_id` does not continue any
            remaining alignment path.
        """
        if state.legal_path_map is None:
            if self.child_guide is None:
                return AlignmentGuideState(None, None)
            return AlignmentGuideState(
                legal_path_map=None,
                child_guide_state=self.child_guide.get_next_state(
                    state.child_guide_state, token_id
                ),
            )

        next_state_legal_path_map = {
            path[1:]: advancement_ids
            for path, advancement_ids in state.legal_path_map.items()
            if path and path[0] == token_id
        }

        if not next_state_legal_path_map:
            # token_id is not a legal continuation of any candidate path
            raise ValueError("Cannot advance state with provided token_ids")

        if () in next_state_legal_path_map:
            # A path has been fully consumed: alignment is complete. Longer
            # paths sharing this prefix are dropped; an empty path must not be
            # left in the map because get_next_instruction reads each path's
            # first token.
            return self._completed_state(
                next_state_legal_path_map[()], state.child_guide_state
            )

        # paths remaining, return advanced legal_path_map
        return AlignmentGuideState(
            legal_path_map=next_state_legal_path_map,
            child_guide_state=state.child_guide_state,
        )

    def is_final_state(self, state: AlignmentGuideState) -> bool:
        """
        Never final during alignment; afterwards final immediately without a
        child guide, otherwise whatever the child guide reports.
        """
        if state.legal_path_map is not None:
            return False
        elif self.child_guide is None:
            return True
        else:
            return self.child_guide.is_final_state(state.child_guide_state)

    def copy(self):
        """AlignmentGuide isn't mutated"""
        return self
0 commit comments