import datetime
from dataclasses import dataclass
-from typing import TYPE_CHECKING, Iterator, List, Optional, Union
+from typing import Iterator, List, Optional, Sequence, Union
+
+import torch

from outlines.generate.generator import sequence_generator
from outlines.samplers import BeamSearchSampler, GreedySampler, MultinomialSampler

-if TYPE_CHECKING:
-    import torch
-
FormattedOutput = Union[
    str, int, float, bool, datetime.date, datetime.time, datetime.datetime
]
+TotalCompletionsType = Optional[Union[List[str], str]]


class SequenceGenerator:
@@ -461,6 +461,47 @@ def prepare_generation_parameters(

        return generation_params

+    def strip_completions(
+        self,
+        completions,
+        prompts: Union[str, List[str]],
+        aligned_prompts: Union[str, List[str]],
+    ):
+        """Remove characters generated through token alignment from the completions.
+
+        As token alignment makes the model re-generate some of the characters at
+        the end of the prompt, we want to remove those from the beginning of the
+        completions to only return the characters after the end of the user prompts.
+
+        Parameters
+        ----------
+        completions
+            Text generated by the model
+        prompts
+            The original prompts provided by the user
+        aligned_prompts
+            The prompts of the user after token alignment (what's given to the model)
+
+        Returns
+        -------
+        The stripped completions
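+
+        Examples
+        --------
+        A hypothetical illustration (values assumed for the example): token
+        alignment shortened the prompt by the three characters ``" is"``, which
+        the model then re-generated at the start of its completion, so they
+        are sliced off.
+
+        >>> generator.strip_completions(" is 4", "The answer is", "The answer")
+        ' 4'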
488+ """
489+ if isinstance (prompts , str ):
490+ if isinstance (completions , str ):
491+ return completions [len (prompts ) - len (aligned_prompts ) :]
492+
493+ return [
494+ self .strip_completions (completion , prompts , aligned_prompts )
495+ for completion in completions
496+ ]
497+
498+ return [
499+ self .strip_completions (completion , prompt , aligned_prompt )
500+ for completion , prompt , aligned_prompt in zip (
501+ completions , prompts , aligned_prompts
502+ )
503+ ]
504+
    def format_sequence(self, sequence: str) -> FormattedOutput:
        """Translate the generated sequence to another type.

@@ -500,15 +541,24 @@ def format(sequences):
            max_tokens, stop_at, seed
        )

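+        # Token alignment may shorten the prompt at a token boundary; keep the
+        # aligned version so we know how many characters the model will
+        # re-generate at the start of each completion.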
+        aligned_prompts = self.logits_processor.align_prompts(prompts)
+
        completions = self.model.generate(
-            prompts,
+            aligned_prompts,
            generation_params,
            self.logits_processor,
            self.sampling_params,
            **model_specific_params,
        )

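+        # The completions begin with the characters the model re-generated
+        # because of token alignment; strip them so callers only see text
+        # beyond their original prompt.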
-        return format(completions)
+        stripped_completions = self.strip_completions(
+            completions, prompts, aligned_prompts
+        )
+
+        return format(stripped_completions)

    def stream(
        self,
@@ -519,13 +569,72 @@ def stream(
        **model_specific_params,
    ):
        """Return a text generator from a prompt or a list of prompts."""
+
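+        # The streamed chunks mirror the shape of the completions (a str for a
+        # single prompt, a list for batches, a nested list for multiple
+        # samples), so both helpers below recurse over that structure.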
+        def add_chunks_to_completions(
+            text_chunks: Union[str, List[str], List[List[str]], Sequence[str]],
+            total_completions: Optional[
+                Union[str, List[str], List[List[str]], Sequence[str]]
+            ],
+        ):
+            """Append each text chunk to the end of the corresponding completion."""
+            if isinstance(text_chunks, str):
+                if isinstance(total_completions, str):
+                    return total_completions + text_chunks
+                return text_chunks
+
+            if total_completions:
+                return [
+                    add_chunks_to_completions(text_chunk, total_completion)
+                    for text_chunk, total_completion in zip(
+                        text_chunks, total_completions
+                    )
+                ]
+
+            return [
+                add_chunks_to_completions(text_chunk, None)
+                for text_chunk in text_chunks
+            ]
+
+        def strip_text_chunks(
+            text_chunks: Union[str, List[str], List[List[str]], Sequence[str]],
+            stripped_completions: Union[str, List[str], List[List[str]], Sequence[str]],
+        ):
+            """Get the stripped text_chunks from the stripped_completions."""
+            if isinstance(text_chunks, str):
+                return (
+                    stripped_completions[-len(text_chunks) :]
+                    if len(text_chunks) > 0
+                    else ""
+                )
+
+            return [
+                strip_text_chunks(text_chunk, stripped_completion)
+                for text_chunk, stripped_completion in zip(
+                    text_chunks, stripped_completions
+                )
+            ]
+
        generation_params = self.prepare_generation_parameters(
            max_tokens, stop_at, seed
        )
-        return self.model.stream(
+
+        aligned_prompts = self.logits_processor.align_prompts(prompts)
+
+        total_completions: TotalCompletionsType = None
+
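+        # Accumulate the raw chunks, strip the re-generated alignment prefix
+        # from the running completions, and yield only the newly visible text.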
+        for text_chunks in self.model.stream(
-            prompts,
+            aligned_prompts,
            generation_params,
            self.logits_processor,
            self.sampling_params,
            **model_specific_params,
-        )
+        ):
+            total_completions = add_chunks_to_completions(
+                text_chunks, total_completions
+            )
+
+            stripped_completions = self.strip_completions(
+                total_completions, prompts, aligned_prompts
+            )
+
+            yield strip_text_chunks(text_chunks, stripped_completions)