79 changes: 79 additions & 0 deletions outlines/fsm/fast_lark.py
@@ -0,0 +1,79 @@
from copy import copy, deepcopy
from typing import Any, Dict, Optional

from lark import Lark
from lark.lexer import Token
from lark.parsers.lalr_interactive_parser import InteractiveParser
from lark.parsers.lalr_parser_state import ParserState


class FastParserState(ParserState):
    """
    Lark ParserState with optimized copying.

    Works with Outlines because we never perform operations
    that mutate Tokens, so copies can share Token objects.
    """

    # Shared memo: maps "{state_id}_{value_id}" keys to copied stack
    # values, and is also reused as the deepcopy memo (which adds its
    # own integer id keys), so a value is never deep-copied twice.
    copy_memo: Dict[Any, Any] = {}

    def __copy__(self):
        new_value_stack = []
        for value in self.value_stack:
            key = f"{id(self)}_{id(value)}"
            if key not in self.copy_memo:
                self.copy_memo[key] = deepcopy(value, self.copy_memo)
            new_value_stack.append(self.copy_memo[key])

        new_instance = type(self)(
            self.parse_conf,
            self.lexer,
            copy(self.state_stack),
            new_value_stack,
        )

        # Record the copy under deepcopy's memo convention
        # (id(original) -> copy); this also keeps the instance alive
        # so the ids used as keys above cannot be recycled.
        self.copy_memo[id(self)] = new_instance
        return new_instance


class FastInteractiveParser(InteractiveParser):
    """
    InteractiveParser which uses FastParserState to manage its parser state
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Rewrap the plain ParserState built by the base class.
        self.parser_state = FastParserState(
            self.parser_state.parse_conf,
            self.parser_state.lexer,
            self.parser_state.state_stack,
            self.parser_state.value_stack,
        )
        self.hash_val = None

    def __hash__(self):
        if self.hash_val is None:
            self.hash_val = hash(tuple(self.parser_state.state_stack))
        return self.hash_val

    def __copy__(self):
        return type(self)(
            self.parser,
            copy(self.parser_state),
            copy(self.lexer_thread),
        )


class FastLark(Lark):
    """
    Lark which uses FastInteractiveParser for interactive mode
    """

    def parse_interactive(
        self, text: Optional[str] = None, start: Optional[str] = None
    ) -> "InteractiveParser":
        base_interactive_parser = self.parser.parse_interactive(text, start=start)
        return FastInteractiveParser(
            base_interactive_parser.parser,
            base_interactive_parser.parser_state,
            base_interactive_parser.lexer_thread,
        )
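
For context on why __copy__ is the hot path: during interactive LALR parsing, Outlines copies the parser once per candidate continuation while checking which next tokens the grammar allows, so a cheap copy that shares immutable Token objects avoids repeatedly deep-copying the value stack. A minimal sketch of how the classes above would be exercised (the toy grammar and the copy step are illustrative assumptions, not part of this diff, and assume a lark version whose InteractiveParser exposes lexer_thread):

from copy import copy

# Illustrative only: any LALR-compatible grammar works here.
parser = FastLark(
    r"""
    start: WORD+
    WORD: /[a-z]+/
    %ignore " "
    """,
    parser="lalr",
)

interactive = parser.parse_interactive("hello world")
interactive.exhaust_lexer()

# Speculative step: copy the parser before trying a continuation, so
# the original state survives if the candidate token is rejected.
candidate = copy(interactive)
assert hash(candidate) == hash(interactive)  # same state_stack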
6 changes: 4 additions & 2 deletions outlines/fsm/fsm.py
@@ -1,11 +1,12 @@
+import functools
 from typing import TYPE_CHECKING, List, NewType, Protocol, Tuple
 
 import interegular
-from lark import Lark
 
 # from outlines.fsm.parsing import PartialLark
 from outlines import grammars
 from outlines.caching import cache
+from outlines.fsm.fast_lark import FastLark
 from outlines.fsm.regex import create_fsm_index_tokenizer, make_deterministic_fsm
 
 if TYPE_CHECKING:
@@ -88,6 +89,7 @@ def copy(self) -> "StopAtEosFSM":
         return self
 
 
+@functools.lru_cache(maxsize=1024)
 class RegexFSM(FSM):
     """FSM to generate text that is in the language of a regular expression."""
 
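Note on the decorator: functools.lru_cache works on a class because a class is just a callable. The decorator replaces RegexFSM with a cached wrapper, so constructing it twice with equal, hashable arguments returns the same instance instead of recompiling the regex index. A self-contained sketch of the pattern (the Toy class is hypothetical, not from this diff):

import functools

@functools.lru_cache(maxsize=1024)
class Toy:
    def __init__(self, pattern: str):
        # Stand-in for expensive FSM compilation.
        self.pattern = pattern

a = Toy("[0-9]+")
b = Toy("[0-9]+")
assert a is b  # second call is a cache hit

Trade-offs this relies on: every constructor argument must be hashable (hence the tokenizer __hash__ change below), cached instances are shared so they must not be mutated, and isinstance(x, Toy) no longer works because Toy is now an _lru_cache_wrapper rather than a class.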
@@ -194,7 +196,7 @@ def __init__(self, cfg_string: str, tokenizer: "Tokenizer"):
         self.cfg_string = cfg_string
         self.tokenizer = tokenizer
 
-        self.parser = Lark(
+        self.parser = FastLark(
             cfg_string,
             parser="lalr",
             lexer="contextual",
8 changes: 6 additions & 2 deletions outlines/models/transformers.py
@@ -145,6 +145,8 @@ def __init__(self, model_name: str, **kwargs):
         self.vocabulary = self.tokenizer.get_vocab()
         self.is_llama = isinstance(self.tokenizer, get_llama_tokenizer_types())
 
+        self._hash: Optional[int] = None
+
     def encode(
         self, prompt: Union[str, List[str]], **kwargs
     ) -> Tuple[torch.LongTensor, torch.LongTensor]:
@@ -175,9 +177,11 @@ def __eq__(self, other):
         return NotImplemented
 
     def __hash__(self):
-        from datasets.fingerprint import Hasher
+        if self._hash is None:
+            from datasets.fingerprint import Hasher
 
-        return hash(Hasher.hash(self.tokenizer))
+            self._hash = hash(Hasher.hash(self.tokenizer))
+        return self._hash
 
 
 def transformers(
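
The __hash__ change pairs with the lru_cache above: the tokenizer is hashed on every RegexFSM cache lookup, and datasets' Hasher.hash re-serializes the tokenizer each time, so the fingerprint is now computed once and memoized in self._hash. The same pattern in isolation (hypothetical class; hashlib stands in for the expensive Hasher call):

import hashlib
from typing import Optional

class Fingerprinted:
    def __init__(self, payload: str):
        self.payload = payload
        self._hash: Optional[int] = None

    def __hash__(self) -> int:
        if self._hash is None:
            # The expensive step runs only on the first call; it is a
            # stand-in for hash(Hasher.hash(self.tokenizer)) above.
            digest = hashlib.sha256(self.payload.encode()).hexdigest()
            self._hash = hash(digest)
        return self._hash

t = Fingerprinted("vocab")
assert hash(t) == hash(t)  # second call returns the memoized value

Like the diff, this assumes the hashed object is effectively immutable after construction; a mutated tokenizer would keep serving its stale fingerprint.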