Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions garak/generators/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,10 +215,15 @@ def _call_model(
# reload client once when consuming the generator
self._load_client()

# TODO: refactor to always use local scoped variables for _call_model client objects to avoid serialization state issues
client = self.client
generator = self.generator
is_completion = generator == client.completions

create_args = {}
if "n" not in self.suppressed_params:
create_args["n"] = generations_this_call
for arg in inspect.signature(self.generator.create).parameters:
for arg in inspect.signature(generator.create).parameters:
if arg == "model":
create_args[arg] = self.name
continue
Expand All @@ -232,7 +237,7 @@ def _call_model(
for k, v in self.extra_params.items():
create_args[k] = v

if self.generator == self.client.completions:
if is_completion:
if not isinstance(prompt, Conversation) or len(prompt.turns) > 1:
msg = (
f"Expected a Conversation with one Turn for {self.generator_family_name} completions model {self.name}, but got {type(prompt)}. "
Expand All @@ -243,7 +248,7 @@ def _call_model(

create_args["prompt"] = prompt.last_message().text

elif self.generator == self.client.chat.completions:
else: # is chat
if isinstance(prompt, Conversation):
messages = self._conversation_to_list(prompt)
elif isinstance(prompt, list):
Expand All @@ -260,7 +265,7 @@ def _call_model(
create_args["messages"] = messages

try:
response = self.generator.create(**create_args)
response = generator.create(**create_args)
except openai.BadRequestError as e:
msg = "Bad request: " + str(repr(prompt))
logging.exception(e)
Expand All @@ -284,9 +289,9 @@ def _call_model(
else:
return [None]

if self.generator == self.client.completions:
if is_completion:
return [Message(c.text) for c in response.choices]
elif self.generator == self.client.chat.completions:
else:
return [Message(c.message.content) for c in response.choices]


Expand Down