Skip to content

Commit 6676414

Browse files
committed
fixed tests
1 parent 74f8a82 commit 6676414

File tree

1 file changed

+20
-4
lines changed

1 file changed

+20
-4
lines changed

backend/app/tests/test_toxicity_hub_validators.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def test_build_with_default_policies(self):
3737
def test_build_with_explicit_policies(self):
3838
config = LlamaGuard7BSafetyValidatorConfig(
3939
type="llamaguard_7b",
40-
policies=["O1", "O2"],
40+
policies=["no_violence_hate", "no_sexual_content"],
4141
)
4242

4343
with patch(_LLAMAGUARD_PATCH) as mock_validator:
@@ -56,7 +56,14 @@ def test_build_with_empty_policies_list(self):
5656
assert kwargs["policies"] == []
5757

5858
def test_build_with_all_policy_codes(self):
59-
all_policies = ["O1", "O2", "O3", "O4", "O5", "O6"]
59+
all_policies = [
60+
"no_violence_hate",
61+
"no_sexual_content",
62+
"no_criminal_planning",
63+
"no_guns_and_illegal_weapons",
64+
"no_illegal_drugs",
65+
"no_encourage_self_harm",
66+
]
6067
config = LlamaGuard7BSafetyValidatorConfig(
6168
type="llamaguard_7b", policies=all_policies
6269
)
@@ -65,11 +72,11 @@ def test_build_with_all_policy_codes(self):
6572
config.build()
6673

6774
_, kwargs = mock_validator.call_args
68-
assert kwargs["policies"] == all_policies
75+
assert kwargs["policies"] == ["O1", "O2", "O3", "O4", "O5", "O6"]
6976

7077
def test_build_with_single_policy(self):
7178
config = LlamaGuard7BSafetyValidatorConfig(
72-
type="llamaguard_7b", policies=["O3"]
79+
type="llamaguard_7b", policies=["no_criminal_planning"]
7380
)
7481

7582
with patch(_LLAMAGUARD_PATCH) as mock_validator:
@@ -78,6 +85,15 @@ def test_build_with_single_policy(self):
7885
_, kwargs = mock_validator.call_args
7986
assert kwargs["policies"] == ["O3"]
8087

88+
def test_build_with_invalid_policy_raises(self):
89+
config = LlamaGuard7BSafetyValidatorConfig(
90+
type="llamaguard_7b", policies=["O1"]
91+
)
92+
93+
with patch(_LLAMAGUARD_PATCH):
94+
with pytest.raises(ValueError, match="Unknown policy"):
95+
config.build()
96+
8197
def test_build_returns_validator_instance(self):
8298
config = LlamaGuard7BSafetyValidatorConfig(type="llamaguard_7b")
8399

0 commit comments

Comments
 (0)