-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdrink_lava.py
47 lines (39 loc) · 1.21 KB
/
drink_lava.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
# -*- coding: utf-8 -*-
"""Drink Lava
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/10DLgYVQYDVrTBjkcXWDrCAYmd3izelNJ
"""
# !pip install transformers
import numpy as np
import os
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import random
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# Not used by the code below; kept so the module-level name stays available.
cuda0 = torch.device('cuda:0')
# Use the GPU when one is present, otherwise fall back to the CPU instead of
# crashing at start-up on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.eval()  # deactivate the DropOut modules
model.to(device)


def likelihoodToDie(drink):
    """Score how strongly the model expects ' died' after drinking *drink*.

    Builds the prompt ``"I drank {drink}, so I"``, runs it through the
    module-level ``model`` and returns the softmax weight given to the first
    token of ``' died'`` at the position following the prompt.

    Args:
        drink: Text spliced into the prompt.

    Returns:
        A NumPy scalar taken from the softmax over the model's scores for the
        last prompt position.
    """
    given = f'I drank {drink}, so I'
    indexed_tokens = tokenizer.encode(given)
    # Keep the input on the same device the model was moved to.
    tokens_tensor = torch.tensor(indexed_tokens).to(device)
    # Only the first token id of ' died' is scored.
    word_index = tokenizer.encode(' died')[0]
    with torch.no_grad():  # inference only, no gradients needed
        outputs = model(tokens_tensor)
        # Row for the last prompt token: the scores for whatever comes next.
        predictions = outputs[0][-1, :]
        distribution = torch.nn.functional.softmax(predictions, dim=0).cpu().numpy()
    return distribution[word_index]
# Score returned by likelihoodToDie above which the drinker is declared dead.
THRESHOLD = 2.4 * 10**(-7)

print('http://www.scp-wiki.net/scp-294')
# Interactive session: ask for a drink, score it, report the outcome.
while True:
    print('SCP 294! Drink = ')
    try:
        drink = input()
    except (EOFError, KeyboardInterrupt):
        # stdin was closed or the user interrupted: end the session cleanly
        # instead of dying with a traceback.
        break
    print(f'You drank a cup of {drink}.')
    a = likelihoodToDie(drink)
    if a > THRESHOLD:
        print('You died.')
    else:
        print('You lived.')
    # Echo the raw score next to the input.
    print(drink, a)