-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain_local.py
60 lines (51 loc) · 2.04 KB
/
main_local.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from transformers import AutoTokenizer, AutoModelForCausalLM
# from flask import Flask, request, jsonify
# app = Flask(__name__)
import os
# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
# Prompt variant to evaluate: "doc_only", "func_only" or "doc_func".
folder = 'func_only'
path = f"/Users/cindy/code-generation/{folder}/input"
# Make the input directory the working directory so relative file
# operations resolve against it.
os.chdir(path)

# Hub checkpoint id is "<source>/<modelname>". Checkpoints listed as options:
#   EleutherAI: gpt-neo-2.7B / gpt-neo-1.3B / gpt-neo-1.5B
#   microsoft:  CodeGPT-small-py-adaptedGPT2 / CodeGPT-small-py
# NOTE(review): EleutherAI's published GPT-Neo sizes appear to be
# 125M / 1.3B / 2.7B — confirm a 1.5B checkpoint exists before using it.
source, modelname = "microsoft", "CodeGPT-small-py-adaptedGPT2"
device = "cpu"

# Load tokenizer first, then model (tuple elements evaluate left to right).
tokenizer, model = (
    AutoTokenizer.from_pretrained(f"{source}/{modelname}"),
    AutoModelForCausalLM.from_pretrained(f"{source}/{modelname}"),
)
model.to(device)
# with open('text.txt', 'r') as file:
# prime = file.read()
def inference(prompt, temperature, max_length):
    """Sample one continuation of *prompt* from the module-level model.

    Args:
        prompt: Text to condition generation on.
        temperature: Sampling temperature passed to ``model.generate``.
        max_length: Passed to ``model.generate`` as ``max_length`` (per the
            Transformers docs this limit counts the prompt tokens too).

    Returns:
        The decoded text of the first generated sequence. For decoder-only
        models ``generate`` returns the prompt tokens followed by the new
        ones, which ``autocomplete`` relies on when it strips the prompt.
    """
    encoded = tokenizer(prompt, return_tensors="pt").to(device)
    # GPT-2-style tokenizers frequently have no pad token; fall back to EOS
    # explicitly rather than letting generate() guess and warn.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id
    gen_tokens = model.generate(
        encoded.input_ids,
        # Pass the mask explicitly: generate() warns that results may be
        # unreliable when it has to infer the attention mask itself.
        attention_mask=encoded.attention_mask,
        pad_token_id=pad_token_id,
        do_sample=True,
        temperature=temperature,
        max_length=max_length,
    )
    gen_text = tokenizer.batch_decode(gen_tokens)[0]
    return gen_text
def autocomplete(plaintext, to_prime=True, temperature=0.8, max_length=300):
    """Return the model's completion of *plaintext*, without the prompt itself.

    Args:
        plaintext: Prompt text to complete.
        to_prime: Currently unused. It previously selected whether a
            ``prime`` prefix was prepended; that code is commented out.
        temperature: Sampling temperature forwarded to ``inference``.
        max_length: Length limit forwarded to ``inference``.

    Returns:
        The text that follows the prompt, truncated at the first ``"###"``
        marker when one is present.
    """
    # Priming is disabled, so the prompt is exactly the caller's text.
    # (Was: prompt = prime + plaintext if to_prime else plaintext)
    full_text = inference(plaintext, temperature, max_length)
    # NOTE(review): assumes the decoded output starts with the prompt
    # verbatim; the prompt's character length is used to skip past it.
    continuation = full_text[len(plaintext):]
    completion, _separator, _rest = continuation.partition("###")
    return completion
# @app.route("/")
def arguments(text, file_name):
    """Generate a completion for *text* and save it next to the inputs.

    The result is written to the ``output`` directory that sits beside the
    input directory ``path``, as ``<stem>-<modelname>.txt``. Here ``<stem>``
    is *file_name* with its last extension removed. The directory is created
    if it does not exist yet.

    Args:
        text: Prompt text handed to ``autocomplete`` (e.g. "Convert list of
            strings into ints and return its sum").
        file_name: Name of the input file the prompt came from.

    Returns:
        The path of the file that was written.
    """
    generation = autocomplete(text)
    # Derive the output directory from the input directory instead of
    # repeating the hard-coded absolute prefix.
    output_dir = os.path.join(os.path.dirname(path), "output")
    # A missing output directory would otherwise raise FileNotFoundError.
    os.makedirs(output_dir, exist_ok=True)
    stem = file_name.rsplit('.', 1)[0]
    output_path = os.path.join(output_dir, f"{stem}-{modelname}.txt")
    # `with` closes the file even if write() raises; model text may contain
    # non-ASCII characters, so pin the encoding rather than use the locale's.
    with open(output_path, "w", encoding="utf-8") as f:
        f.write(generation)
    return output_path
# Run every ``.txt`` prompt file in the input directory through the model.
found_text_file = False
try:
    # Sorted so the processing order is deterministic (os.listdir is not).
    for file in sorted(os.listdir(path)):
        if not file.endswith(".txt"):
            continue  # not a prompt file
        found_text_file = True
        file_path = f"{path}/{file}"
        with open(file_path, 'r', encoding="utf-8") as f:
            arguments(f.read(), file)
        print(file_path + " ran successfully")
except FileNotFoundError as err:
    # Include the OS error so the missing path is visible.
    print(f"File not found: {err}")
else:
    # Report an empty directory once, after the whole listing was scanned,
    # rather than once per non-.txt entry.
    if not found_text_file:
        print(f"No text files in {path}")
# if __name__ == "__main__":
# app.run(host="0.0.0.0", port="9900")