-
Notifications
You must be signed in to change notification settings - Fork 11
Expand file tree
/
Copy pathtoxic_chat_t5.py
More file actions
77 lines (60 loc) · 2.18 KB
/
Copy pathtoxic_chat_t5.py
File metadata and controls
77 lines (60 loc) · 2.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import argparse
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from guardbench import benchmark
def moderate(
    conversations: list[list[dict[str, str]]],
    tokenizer: AutoTokenizer,
    model: AutoModelForSeq2SeqLM,
    safe_token_id: int,
    unsafe_token_id: int,
    max_length: int = 4096,
) -> list[float]:
    """Score a batch of conversations with the ToxicChat-T5 classifier.

    The message contents of each conversation are joined with newlines and
    prefixed with the ``"ToxicChat: "`` prompt. The result is fed to the
    seq2seq model for a single decoding step. The scores of the "safe" and
    "unsafe" label tokens at that step are then renormalised with a softmax
    over just those two tokens.

    Args:
        conversations: Batch of conversations. Each one is a list of message
            dicts, of which only the ``"content"`` key is read.
        tokenizer: Tokenizer used to encode the prompts.
        model: Seq2seq model whose first generated token carries the label.
        safe_token_id: Vocabulary id of the token meaning "safe".
        unsafe_token_id: Vocabulary id of the token meaning "unsafe".
        max_length: Maximum number of input tokens; longer prompts are
            truncated. Defaults to 4096.

    Returns:
        One probability of "unsafe" per conversation, in input order. An
        empty batch yields an empty list.
    """
    # Nothing to score: avoid calling the tokenizer / model on an empty batch.
    if not conversations:
        return []

    # Flatten each conversation into a single text and apply the prompt template.
    texts = [
        "ToxicChat: " + "\n".join(message["content"] for message in conversation)
        for conversation in conversations
    ]

    # Tokenize (padding to the longest prompt in the batch) and move the tensors
    # to the model's device. Named `encoded` so the builtin `input` is not shadowed.
    encoded = tokenizer(
        texts,
        max_length=max_length,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    encoded = {key: tensor.to(model.device) for key, tensor in encoded.items()}

    # A single decoding step suffices: the label is the first generated token.
    output = model.generate(
        **encoded,
        max_new_tokens=1,
        output_scores=True,
        return_dict_in_generate=True,
    )

    # Keep only the two label-token columns of the first step's scores.
    # Upcast before the softmax so half-precision models do not lose accuracy.
    label_scores = output.scores[0][:, [safe_token_id, unsafe_token_id]].float()

    # Column 1 is the renormalised probability of the "unsafe" token.
    return torch.softmax(label_scores, dim=-1)[:, 1].tolist()
def _single_token_id(tokenizer: "AutoTokenizer", word: str) -> int:
    """Return the vocabulary id of ``word``, raising if it is not exactly one token."""
    # Encode without special tokens so the trailing EOS is not counted.
    token_ids = tokenizer.encode(word, add_special_tokens=False)
    if len(token_ids) != 1:
        raise ValueError(
            f"Label word {word!r} must map to exactly one token, got {token_ids}"
        )
    return token_ids[0]


def main(device: str, datasets: list[str], batch_size: int) -> None:
    """Run the GuardBench evaluation for the ToxicChat-T5 classifier.

    Args:
        device: Torch device string the model is moved to (e.g. ``"cuda"``).
        datasets: Dataset selection, forwarded unchanged to ``benchmark``.
        batch_size: Forwarded unchanged to ``benchmark``.

    Raises:
        ValueError: If a label word does not tokenize to a single token.
    """
    # The tokenizer of the base t5-large model, with its length cap raised to 4096.
    tokenizer = AutoTokenizer.from_pretrained("t5-large", model_max_length=4096)
    model = AutoModelForSeq2SeqLM.from_pretrained("lmsys/toxicchat-t5-large-v1.0")
    model = model.to(device).eval()

    # The classifier answers "negative" (safe) or "positive" (unsafe).
    # `moderate` scores only the first generated token, so each label word must
    # be a single token; otherwise only a fragment of it would be compared.
    safe_token_id = _single_token_id(tokenizer, "negative")
    unsafe_token_id = _single_token_id(tokenizer, "positive")

    benchmark(
        moderate=moderate,
        model_name="Toxic Chat T5",
        batch_size=batch_size,
        datasets=datasets,
        # Moderate kwargs: extra keyword arguments passed through to `moderate`.
        tokenizer=tokenizer,
        model=model,
        safe_token_id=safe_token_id,
        unsafe_token_id=unsafe_token_id,
    )
if __name__ == "__main__":
    # Command-line entry point: read options, then run the benchmark with
    # autograd disabled.
    cli = argparse.ArgumentParser()
    cli.add_argument("--device", type=str, default="cuda", help="Device")
    cli.add_argument("--datasets", nargs="+", default="all", help="Datasets")
    cli.add_argument("--batch_size", type=int, default=8, help="Batch size")
    opts = cli.parse_args()

    with torch.no_grad():
        main(
            device=opts.device,
            datasets=opts.datasets,
            batch_size=opts.batch_size,
        )