Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 16 additions & 6 deletions garak/attempt.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,9 @@ def from_dict(cls, value: dict):
raise ValueError("Expected `role` in Turn dict")
message = entity.pop("content", {})
if isinstance(message, str):
content = Message(text=message)
raise TypeError(
"Turn does not support str-type content, use Message / report this as a bug"
)
else:
content = Message(**message)
return cls(role=role, content=content)
Expand Down Expand Up @@ -156,7 +158,7 @@ class Attempt:
:param status: The status of this attempt; ``ATTEMPT_NEW``, ``ATTEMPT_STARTED``, or ``ATTEMPT_COMPLETE``
:type status: int
:param prompt: The processed prompt that will presented to the generator
:type prompt: Union[str|Turn|Conversation]
:type prompt: Message|Conversation
:param probe_classname: Name of the probe class that originated this ``Attempt``
:type probe_classname: str
:param probe_params: Non-default parameters logged by the probe
Expand Down Expand Up @@ -223,11 +225,16 @@ def __init__(
if isinstance(prompt, Conversation):
self.conversations = [prompt]
elif isinstance(prompt, str):
msg = Message(text=prompt, lang=lang)
raise ValueError(
"attempt Prompt must be Message or Conversation, not string"
)
# msg = Message(text=prompt, lang=lang)
elif isinstance(prompt, Message):
msg = prompt
else:
raise TypeError("prompts must be of type str | Message | Conversation")
raise TypeError(
"attempt prompts must be of type Message | Conversation"
)
if not hasattr(self, "conversations"):
self.conversations = [Conversation([Turn("user", msg)])]
self.prompt = self.conversations[0]
Expand Down Expand Up @@ -321,14 +328,17 @@ def all_outputs(self) -> List[Message]:
return all_outputs

@prompt.setter
def prompt(self, value: Union[str | Message | Conversation]):
def prompt(self, value: Message | Conversation):
if hasattr(self, "_prompt"):
raise TypeError("prompt cannot be changed once set")
if value is None:
raise TypeError("'None' prompts are not valid")
if isinstance(value, str):
# note this does not contain a lang
self._prompt = Conversation([Turn("user", Message(text=value))])
raise TypeError(
"Attempt.prompt must be Message or Conversation, not bare string"
)
# self._prompt = Conversation([Turn("user", Message(text=value))])
if isinstance(value, Message):
# make a copy to store an immutable object
self._prompt = Conversation([Turn("user", Message(**asdict(value)))])
Expand Down
14 changes: 8 additions & 6 deletions garak/probes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,11 @@ def _postprocess_hook(
return attempt

def _mint_attempt(
self, prompt=None, seq=None, notes=None, lang="*"
self,
prompt: str | garak.attempt.Message | garak.attempt.Conversation | None = None,
seq=None,
notes=None,
lang="*",
) -> garak.attempt.Attempt:
"""function for creating a new attempt given a prompt"""
turns = []
Expand All @@ -195,12 +199,10 @@ def _mint_attempt(
turns.append(
garak.attempt.Turn(
role="system",
content=garak.attempt.Message(
text=self.system_prompt, lang=lang
),
content=garak.attempt.Message(text=self.system_prompt, lang=lang),
)
)
if isinstance(prompt, str):
if isinstance(prompt, str): # we can mint with a string
turns.append(
garak.attempt.Turn(
role="user", content=garak.attempt.Message(text=prompt, lang=lang)
Expand Down Expand Up @@ -353,7 +355,7 @@ def probe(self, generator) -> Iterable[garak.attempt.Attempt]:
colour=f"#{garak.resources.theme.LANGPROVIDER_RGB}",
desc="Preparing prompts",
)
if isinstance(prompts[0], str):
if isinstance(prompts[0], str): # self.prompts can be strings
localized_prompts = self.langprovider.get_text(
prompts, notify_callback=preparation_bar.update
)
Expand Down
22 changes: 14 additions & 8 deletions tests/test_attempt.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def test_attempt_history_lengths():

def test_attempt_illegal_ops():
a = garak.attempt.Attempt()
a.prompt = "prompt"
a.prompt = garak.attempt.Message("prompt")
a.outputs = [garak.attempt.Message("output")]
with pytest.raises(TypeError):
a.prompt = "shouldn't be able to set initial prompt after output turned up"
Expand All @@ -238,15 +238,15 @@ def test_attempt_illegal_ops():

a = garak.attempt.Attempt()
with pytest.raises(TypeError):
a.prompt = "obsidian"
a.prompt = garak.attempt.Message("obsidian")
a.outputs = [garak.attempt.Message("order")]
a._expand_prompt_to_histories(
1
) # "shouldn't be able to expand histories twice"

a = garak.attempt.Attempt()
with pytest.raises(TypeError):
a.prompt = "obsidian"
a.prompt = garak.attempt.Message("obsidian")
a._expand_prompt_to_histories(3)
a._expand_prompt_to_histories(
3
Expand All @@ -268,15 +268,17 @@ def test_attempt_no_prompt_output_access():
def test_attempt_set_prompt_var():
test_text = "Plain Simple Garak"
direct_attempt = garak.attempt.Attempt()
direct_attempt.prompt = test_text
direct_attempt.prompt = garak.attempt.Message(test_text)
assert direct_attempt.prompt == garak.attempt.Conversation(
[garak.attempt.Turn("user", garak.attempt.Message(test_text))]
), "setting attempt.prompt should put the a Prompt with the given text in attempt.prompt"


def test_attempt_constructor_prompt():
test_text = "Plain Simple Garak"
constructor_attempt = garak.attempt.Attempt(prompt=test_text, lang="*")
constructor_attempt = garak.attempt.Attempt(
prompt=garak.attempt.Message(test_text, lang="*"), lang="*"
)
assert constructor_attempt.prompt == garak.attempt.Conversation(
[garak.attempt.Turn("user", garak.attempt.Message(test_text, lang="*"))]
), "instantiating an Attempt with prompt in the constructor should put a Prompt with the prompt text in attempt.prompt"
Expand Down Expand Up @@ -401,7 +403,7 @@ def test_attempt_outputs():
output_a = garak.attempt.Attempt()
assert output_a.outputs == []

output_a.prompt = test_prompt
output_a.prompt = garak.attempt.Message(test_prompt)
assert output_a.outputs == []

output_a.outputs = [garak.attempt.Message(test_sys1, lang=prompt_lang)]
Expand Down Expand Up @@ -457,14 +459,18 @@ def test_attempt_all_outputs():

def test_attempt_message_prompt_init():
test_prompt = "Enabran Tain"
att = garak.attempt.Attempt(prompt=test_prompt, lang="*")
att = garak.attempt.Attempt(
prompt=garak.attempt.Message(test_prompt, lang="*"), lang="*"
)
assert att.prompt == garak.attempt.Conversation(
[garak.attempt.Turn("user", garak.attempt.Message(text=test_prompt, lang="*"))]
)


def test_json_serialize():
att = garak.attempt.Attempt(prompt="well hello", lang="*")
att = garak.attempt.Attempt(
prompt=garak.attempt.Message("well hello", lang="*"), lang="*"
)
att.outputs = [garak.attempt.Message("output one"), None]

att_dict = att.as_dict()
Expand Down
Loading