interact.py
import argparse
import logging
import os

import torch
import torch.nn.functional as F
from transformers import BertTokenizer, GPT2LMHeadModel

PAD = '[PAD]'
pad_id = 0


def set_interact_args():
    """
    Sets up the inference/interaction arguments.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', default='0,1', type=str, required=False, help='device(s) to generate on')
    parser.add_argument('--temperature', default=1, type=float, required=False, help='sampling temperature')
    parser.add_argument('--topk', default=8, type=int, required=False, help='keep only the k highest-probability tokens (top-k sampling)')
    parser.add_argument('--topp', default=0, type=float, required=False, help='cumulative probability threshold (nucleus sampling)')
    parser.add_argument('--model_config', default='GPT2_NLPCC_Summary/config.json', type=str, required=False,
                        help='model configuration file')
    parser.add_argument('--log_path', default='logs/interacting_2.log', type=str, required=False, help='where to write the interaction log')
    parser.add_argument('--voca_path', default='vocabulary/vocab_NLPCC.txt', type=str, required=False, help='vocabulary file')
    parser.add_argument('--summary_model_path', default='/home/zhy2018/projects/abstract_gpt/GPT2_NLPCC_Summary', type=str, required=False, help='path to the summarization model')
    parser.add_argument('--save_samples_path', default="sample/", type=str, required=False, help="directory for saving generated samples")
    parser.add_argument('--repetition_penalty', default=1.0, type=float, required=False,
                        help="repetition penalty; raise it if the generated text repeats itself")
    parser.add_argument('--seed', type=int, default=42, help='random seed, to make generation reproducible')
    parser.add_argument('--max_len', type=int, default=512, help='maximum length of each utterance; longer inputs are truncated')
    parser.add_argument('--max_history_len', type=int, default=1, help="maximum length of the dialogue history")
    parser.add_argument('--no_cuda', action='store_true', help='do not use the GPU for prediction')
    return parser.parse_args()
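
# Example invocation (illustrative; the flags match the parser above, the paths
# assume the repository layout from the defaults):
#   python interact.py --topk 8 --topp 0.9 --repetition_penalty 1.2 \
#       --summary_model_path GPT2_NLPCC_Summary --device 0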


def create_logger(args):
    """
    Send log output both to a log file and to the console.
    """
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(levelname)s - %(message)s')
    # Handler that writes log records to the log file
    file_handler = logging.FileHandler(filename=args.log_path)
    file_handler.setFormatter(formatter)
    file_handler.setLevel(logging.INFO)
    logger.addHandler(file_handler)
    # Handler that echoes log records to the console
    console = logging.StreamHandler()
    console.setLevel(logging.DEBUG)
    console.setFormatter(formatter)
    logger.addHandler(console)
    return logger


def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: logits distribution shape (vocabulary size)
            top_k > 0: keep only top k tokens with highest probability (top-k filtering).
            top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
        From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    assert logits.dim() == 1  # batch size 1 for now - could be updated for more but the code would be less clear
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k.
        # torch.topk() returns the top_k largest elements of the last dimension as a
        # (values, indices) tuple; `...` lets PyTorch infer the remaining dimensions.
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value  # set logits outside the top k to -inf
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)  # sort logits in descending order
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits
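
# Usage sketch (the values are made up for illustration): with top_k=2, only
# the two largest logits survive and everything else is filtered to -inf.
# Note that the function modifies `logits` in place and also returns it.
#   >>> logits = torch.tensor([1.0, 3.0, 2.0, 0.5])
#   >>> top_k_top_p_filtering(logits, top_k=2)
#   tensor([-inf, 3., 2., -inf])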


# This function could also be used for offline evaluation, e.g. to measure ROUGE scores.
def get_summary(text, model, tokenizer, device, args):
    # Truncate overly long articles
    if len(text) > 600:
        text = text[:500]
    input_ids = [tokenizer.cls_token_id]  # every input starts with [CLS]
    input_ids.extend(tokenizer.encode(text))
    input_ids.append(tokenizer.sep_token_id)
    curr_input_tensor = torch.tensor(input_ids).long().to(device)
    generated = []
    # Generate at most max_len tokens
    for _ in range(args.max_len):
        outputs = model(input_ids=curr_input_tensor)
        # outputs[0] has shape (seq_len, vocab_size), e.g. (563, 13317); there is
        # no batch dimension because the input is a 1-D tensor (batch size 1).
        next_token_logits = outputs[0][-1, :]
        # Repetition penalty: divide the logit of every token already generated,
        # lowering its probability of being sampled again.
        for token_id in set(generated):
            next_token_logits[token_id] /= args.repetition_penalty
        next_token_logits = next_token_logits / args.temperature
        # Set the logit of [UNK] to -inf so the model can never predict that token
        next_token_logits[tokenizer.convert_tokens_to_ids('[UNK]')] = -float('Inf')
        filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=args.topk, top_p=args.topp)
        # torch.multinomial draws num_samples elements without replacement, with
        # probability proportional to the given weights, and returns their indices
        next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
        if next_token.item() == tokenizer.sep_token_id:  # [SEP] marks the end of the summary
            break
        generated.append(next_token.item())
        curr_input_tensor = torch.cat((curr_input_tensor, next_token), dim=0)
    text = tokenizer.convert_ids_to_tokens(generated)
    return "".join(text)
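
# Usage sketch (hypothetical `article` variable; mirrors the __main__ block below):
#   args = set_interact_args()
#   tokenizer = BertTokenizer(vocab_file=args.voca_path)
#   model = GPT2LMHeadModel.from_pretrained(args.summary_model_path).to('cpu').eval()
#   summary = get_summary(article, model, tokenizer, 'cpu', args)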


if __name__ == '__main__':
    args = set_interact_args()
    logger = create_logger(args)
    # CUDA_VISIBLE_DEVICES must be set before CUDA is first initialized,
    # i.e. before torch.cuda.is_available() is called
    os.environ["CUDA_VISIBLE_DEVICES"] = args.device
    # Use the GPU only if the user requested it and one is available
    args.cuda = torch.cuda.is_available() and not args.no_cuda
    device = 'cuda' if args.cuda else 'cpu'
    logger.info('using device:{}'.format(device))
    tokenizer = BertTokenizer(vocab_file=args.voca_path)
    model = GPT2LMHeadModel.from_pretrained(args.summary_model_path)
    model.to(device)
    model.eval()
    logger.info(args)
    print('***********************Summary model start************************')
    # Example input, kept in Chinese since the model is trained on Chinese text
    # (a court judgment for quick manual testing):
    # text = """原告郑祥旭诉被告尹德乐、第三人敦化市黄泥河镇新开道村村民委员会(以下简称新开道村村委会)侵权责任纠纷一案,本院于2017年8月2日立案后,依法适用普通程序,公开开庭进行了审理。现要求:一、被告将原告耕地恢复原状并赔偿经济损失1万元;二、我建房的场地是一片空地,没有任何植被,而且我在村内居住40年,未见任何人耕种过此涝洼地;三、原告名下0.038公顷的道南地已经退耕还林,根本不存在我盖房占用的事实。新开道村村委会述称,原告诉请中的0.038公顷道南地据村里档案记载,已经退耕还林。本案中,原告向本院提供的《土地承包使用期合同》中涉案地即“道南地”并无登记四至,且据本院向敦化市黄泥河镇经营管理站调查了解,该地块亦无原始的四至记载。根据庭审时第三人陈述,该地块的四至中的西至位置尚不明确,故无法证实现有四至的存在,亦无法推断出被告房屋侵占该地块的事实,且第三人村委会予以证实该地块已退耕还林;其次,根据本院查明的事实,被告房屋建于2014年,建设时的空地上并无任何植被,且经第三人证实该地块当时的状态为“抛弃地”,并非耕地,原告主张之所以将该地块荒废是准备种植人参“养地”的说法亦无事实依据。依据《中华人民共和国侵权责任法》第六条;《中华人民共和国民事诉讼法》第六十四条第一款、第一百四十二条之规定,判决如下:驳回原告郑祥旭的诉讼请求。"""
    while True:
        doc = input('Enter an article: ')
        summary = get_summary(doc, model, tokenizer, device, args)
        print()
        logger.info("summary:" + summary)
        print()
        print()