Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add completion and instruct engines for GoogleAI #822

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 74 additions & 10 deletions guidance/models/_googleai.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

_image_token_pattern = re.compile(r"<\|_image:(.*)\|>")


class GoogleAIEngine(GrammarlessEngine):
def __init__(
self,
Expand Down Expand Up @@ -65,14 +64,6 @@ def __init__(
# chat
found_subclass = GoogleAIChat # we assume all models are chat right now

# instruct
# elif "instruct" in model:
# found_subclass = GoogleAIInstruct

# # regular completion
# else:
# found_subclass = GoogleAICompletion

# convert to any found subclass
self.__class__ = found_subclass
found_subclass.__init__(
Expand All @@ -89,7 +80,11 @@ def __init__(
return # we return since we just ran init above and don't need to run again

# this allows us to use a single constructor for all our subclasses
engine_map = {GoogleAIChat: GoogleAIChatEngine}
engine_map = {
GoogleAIChat: GoogleAIChatEngine,
GoogleAIInstruct: GoogleAIInstructEngine,
GoogleAICompletion: GoogleAICompletionEngine,
}

super().__init__(
engine=engine_map[self.__class__](
Expand All @@ -104,6 +99,75 @@ def __init__(
echo=echo,
)

class GoogleAICompletion(GoogleAI):
    """Raw-completion variant of the GoogleAI model.

    Adds no behavior of its own; it exists so the GoogleAI constructor can
    dispatch to GoogleAICompletionEngine via the class-to-engine map.
    """
    pass

class GoogleAICompletionEngine(GoogleAIEngine):
    """Grammarless engine that streams a raw completion from a GoogleAI model."""

    def _generator(self, prompt, temperature):
        """Stream the model's continuation of `prompt` as UTF-8 encoded chunks.

        Args:
            prompt: UTF-8 encoded bytes to send as the completion context.
            temperature: sampling temperature forwarded to the API.

        Yields:
            bytes: UTF-8 encoded text of each streamed response chunk.
        """
        # Mark the stream as running and record the prompt as our data so far.
        self._not_running_stream.clear()
        self._data = prompt

        # Building the config dict cannot raise, so it can live outside the try.
        generation_config = {"temperature": temperature}
        if self.max_streaming_tokens is not None:
            generation_config["max_output_tokens"] = self.max_streaming_tokens

        try:
            stream = self.model_obj.generate_content(
                contents=self._data.decode("utf8"),
                stream=True,
                generation_config=generation_config,
            )
        except Exception as e:  # TODO: add retry logic
            raise e

        for chunk in stream:
            # Take the text of the first candidate's first part for each chunk.
            yield chunk.candidates[0].content.parts[0].text.encode("utf8")

class GoogleAIInstruct(GoogleAI, Instruct):
    """Instruct-style GoogleAI model: one instruction block closed by a sentinel."""

    def get_role_start(self, name):
        # Instruct models emit no opening text for a role.
        return ""

    def get_role_end(self, name):
        # Guard clause: only the "instruction" role is recognized.
        if name != "instruction":
            raise Exception(
                f"The GoogleAIInstruct model does not know about the {name} role type!"
            )
        # The sentinel that GoogleAIInstructEngine looks for in the prompt.
        return "<|endofprompt|>"

class GoogleAIInstructEngine(GoogleAIEngine):
    """Engine for instruct-style GoogleAI models.

    Normalizes the prompt to end with exactly one ``<|endofprompt|>`` sentinel
    before streaming a completion from the underlying model.
    """

    def _generator(self, prompt, temperature):
        """Stream the model's response to an instruct prompt as UTF-8 chunks.

        Args:
            prompt: UTF-8 encoded bytes, containing at most one
                ``<|endofprompt|>`` sentinel.
            temperature: sampling temperature forwarded to the API.

        Yields:
            bytes: UTF-8 encoded text of each streamed response chunk.

        Raises:
            Exception: if the prompt contains more than one sentinel.
        """
        # At most one end-of-prompt marker is allowed in the instruct format.
        eop_count = prompt.count(b"<|endofprompt|>")
        if eop_count > 1:
            raise Exception(
                "This model has been given multiple instruct blocks or <|endofprompt|> tokens, but this is not allowed!"
            )
        # Append the sentinel only if the caller has not already closed the prompt.
        if eop_count == 0:
            final_prompt = prompt + b"<|endofprompt|>"
        else:
            final_prompt = prompt

        # Mark the stream as running and record the normalized prompt.
        self._not_running_stream.clear()
        self._data = final_prompt

        # Building the config dict cannot raise, so it can live outside the try.
        generation_config = {"temperature": temperature}
        if self.max_streaming_tokens is not None:
            generation_config["max_output_tokens"] = self.max_streaming_tokens

        try:
            stream = self.model_obj.generate_content(
                contents=self._data.decode("utf8"),
                stream=True,
                generation_config=generation_config,
            )
        except Exception as e:  # TODO: add retry logic
            raise e

        for chunk in stream:
            # Take the text of the first candidate's first part for each chunk.
            yield chunk.candidates[0].content.parts[0].text.encode("utf8")

class GoogleAIChatEngine(GoogleAIEngine):
def _generator(self, prompt, temperature):
Expand Down
23 changes: 23 additions & 0 deletions tests/models/test_googleai.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,29 @@

from ..utils import get_model

def test_googleai_basic():
    """Smoke test: a GoogleAICompletion model streams text after a plain prompt."""
    try:
        lm = models.GoogleAICompletion("gemini-pro")
    except Exception:
        # Bare `except:` would also swallow KeyboardInterrupt/SystemExit and
        # pytest's own control-flow exceptions; catch Exception instead.
        pytest.skip("Skipping GoogleAI test because we can't load the model!")

    lm += "Count to 20: 1,2,3,4,"
    nl = "\n"
    lm += f"""\
5,6,7"""
    # One generated token, then a newline suffix, then our literal marker.
    lm += f"""{gen(max_tokens=1, suffix=nl)}aaaaaa"""
    assert str(lm)[-5:] == "aaaaa"

def test_googleai_instruct():
    """Smoke test: a GoogleAIInstruct model answers inside an instruction block."""
    try:
        lm = models.GoogleAIInstruct("gemini-pro")
    except Exception:
        # Bare `except:` would also swallow KeyboardInterrupt/SystemExit and
        # pytest's own control-flow exceptions; catch Exception instead.
        pytest.skip("Skipping GoogleAI test because we can't load the model!")

    with instruction():
        lm += "this is a test about"
    lm += gen("test", max_tokens=100)
    assert len(lm["test"]) > 0

def test_gemini_pro():
from guidance import assistant, gen, models, system, user
Expand Down