diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..62fff29 --- /dev/null +++ b/.gitignore @@ -0,0 +1,6 @@ +__pycache__/ + +*.log +*.bin + +*.sqlite3 diff --git a/BERT_CRF.py b/BERT_CRF_Model.py similarity index 93% rename from BERT_CRF.py rename to BERT_CRF_Model.py index 8d59731..8e7fa07 100644 --- a/BERT_CRF.py +++ b/BERT_CRF_Model.py @@ -10,11 +10,6 @@ VOB_NAME = "bert-base-chinese-vocab.txt" - - - - - class BertCrf(nn.Module): def __init__(self,config_name:str,model_name:str = None,num_tags: int = 2, batch_first:bool = True) -> None: # 记录batch_first @@ -58,7 +53,6 @@ def __init__(self,config_name:str,model_name:str = None,num_tags: int = 2, batch self.crf_model = CRF(num_tags=num_tags,batch_first=batch_first) - def forward(self,input_ids:torch.Tensor, tags:torch.Tensor = None, attention_mask:Optional[torch.ByteTensor] = None, @@ -68,14 +62,12 @@ def forward(self,input_ids:torch.Tensor, emissions = self.bertModel(input_ids = input_ids,attention_mask = attention_mask,token_type_ids=token_type_ids)[0] - # 这里在seq_len的维度上去头,是去掉了[CLS],去尾巴有两种情况 - # 1、是 2、[SEP] - - + # 这里在seq_len的维度上去掉开头,是去掉了[CLS],去尾巴有两种情况 + # 第一种情况是 、第二种情况是 [SEP] new_emissions = emissions[:,1:-1] new_mask = attention_mask[:,2:].bool() - # 如果 tags 为 None,表示是一个预测的过程,不能求得loss,loss 直接为None + # 如果 tags 为 None,表示是一个预测的过程,不能求得 loss,则 loss 的值直接为 None if tags is None: loss = None pass @@ -83,8 +75,6 @@ def forward(self,input_ids:torch.Tensor, new_tags = tags[:, 1:-1] loss = self.crf_model(emissions=new_emissions, tags=new_tags, mask=new_mask, reduction=reduction) - - if decode: tag_list = self.crf_model.decode(emissions = new_emissions,mask = new_mask) return [loss, tag_list] @@ -92,5 +82,3 @@ def forward(self,input_ids:torch.Tensor, return [loss] - - diff --git a/CRF_Model.py b/CRF_Model.py index e9c3047..3e249e3 100644 --- a/CRF_Model.py +++ b/CRF_Model.py @@ -2,6 +2,9 @@ import torch import torch.nn as nn +""" +CRF 条件随机场模型; +""" class CRF(nn.Module): def __init__(self,num_tags : int = 
2, batch_first:bool = True) -> None: @@ -10,7 +13,8 @@ def __init__(self,num_tags : int = 2, batch_first:bool = True) -> None: super().__init__() self.num_tags = num_tags self.batch_first = batch_first - # start 到其他tag(不包含end)的得分 + # start 到其他 tag (不包含 end) 的得分 + # (从开始节点到其他非 end 节点的 scores) self.start_transitions = nn.Parameter(torch.empty(num_tags)) # 到其他tag(不包含start)到end的得分 self.end_transitions = nn.Parameter(torch.empty(num_tags)) @@ -21,6 +25,7 @@ def __init__(self,num_tags : int = 2, batch_first:bool = True) -> None: self.reset_parameters() + # 对参数进行重新设置 def reset_parameters(self): init_range = 0.1 nn.init.uniform_(self.start_transitions,-init_range,init_range) @@ -30,6 +35,7 @@ def reset_parameters(self): def __repr__(self): return f'{self.__class__.__name__}(num_tags={self.num_tags})' + # 向前传播; def forward(self, emissions:torch.Tensor, tags:torch.Tensor = None, mask:Optional[torch.ByteTensor] = None, @@ -42,6 +48,7 @@ def forward(self, emissions:torch.Tensor, raise ValueError(f'invalid reduction {reduction}') if mask is None: + #生成值全为1的张量,用于掩码 mask = torch.ones_like(tags,dtype = torch.uint8) # a.shape (seq_len,batch_size) # a[0] shape ? 
batch_size @@ -81,10 +88,6 @@ def decode(self,emissions:torch.Tensor, return self._viterbi_decode(emissions,mask) - - - - def _validate(self, emissions:torch.Tensor, tags:Optional[torch.LongTensor] = None , @@ -146,7 +149,7 @@ def _computer_score(self, # 这里是为了获取每一个样本最后一个词的tag。 # shape: (batch_size,) 每一个batch 的真实长度 seq_ends = mask.long().sum(dim=0) - 1 - # 每个样本最火一个词的tag + # 每个样本最后一个词的tag last_tags = tags[seq_ends,torch.arange(batch_size)] # shape: (batch_size,) 每一个样本到最后一个词的得分加上之前的score score += self.end_transitions[last_tags] @@ -250,4 +253,4 @@ def _viterbi_decode(self,emissions : torch.FloatTensor , best_tags.reverse() best_tags_list.append(best_tags) - return best_tags_list \ No newline at end of file + return best_tags_list diff --git a/test_NER.py b/NERTest.py similarity index 93% rename from test_NER.py rename to NERTest.py index f6e1c09..33d5602 100644 --- a/test_NER.py +++ b/NERTest.py @@ -1,15 +1,15 @@ -from BERT_CRF import BertCrf +from BERT_CRF_Model import BertCrf from transformers import BertTokenizer -from NER_main import NerProcessor,statistical_real_sentences,flatten,CrfInputFeatures +from NERTrain import NerProcessor,statistical_real_sentences,flatten,CrfInputFeatures from torch.utils.data import DataLoader, RandomSampler,TensorDataset from sklearn.metrics import classification_report import torch import numpy as np from tqdm import tqdm, trange - - - +""" +对命名实体识别模型进行简单的测试 +""" processor = NerProcessor() tokenizer_inputs = () @@ -81,4 +81,4 @@ # # micro avg 0.996142 0.996142 0.996142 145137 # macro avg 0.994650 0.994380 0.994512 145137 -# weighted avg 0.996149 0.996142 0.996143 145137 \ No newline at end of file +# weighted avg 0.996149 0.996142 0.996143 145137 diff --git a/NER_main.py b/NERTrain.py similarity index 90% rename from NER_main.py rename to NERTrain.py index e256b00..be47d6d 100644 --- a/NER_main.py +++ b/NERTrain.py @@ -1,26 +1,3 @@ -# --data_dir -# ./input/data/ner_data -# --vob_file -# ./input/config/bert-base-chinese-vocab.txt -# 
--model_config -# ./input/config/bert-base-chinese-config.json -# --output -# ./output -# --pre_train_model -# ./input/config/bert-base-chinese-model.bin -# --max_seq_length -# 64 -# --do_train -# --train_batch_size -# 32 -# --eval_batch_size -# 256 -# --gradient_accumulation_steps -# 4 -# --num_train_epochs -# 15 - - import argparse import logging import codecs @@ -33,19 +10,16 @@ from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset from transformers import BertForSequenceClassification,BertTokenizer,BertConfig from transformers.data.processors.utils import DataProcessor, InputExample -from BERT_CRF import BertCrf -from transformers import AdamW, WarmupLinearSchedule +from BERT_CRF_Model import BertCrf +from transformers import AdamW, get_linear_schedule_with_warmup from sklearn.metrics import classification_report logger = logging.getLogger(__name__) -# -# CRF_LABELS = ["O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "X", "[CLS]", "[SEP]"] -# 在这个项目中只需要识别三个类型的项目即可 -# 这里做以下测试,第一 LABELS = ["O", "B-LOC", "I-LOC"] ,因为需要预测的就只有这三个。 -# 第二 LABELS = ["O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "X", "[CLS]", "[SEP]"] -CRF_LABELS = ["O", "B-LOC", "I-LOC"] +# CRF_LABELS = ["O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "X", "[CLS]", "[SEP]"] +# 在项目中只需要识别三个类型的项目即可 +CRF_LABELS = ["O", "B-LOC", "I-LOC"] # 需要预测的就只有这三个 def statistical_real_sentences(input_ids:torch.Tensor,mask:torch.Tensor,predict:list)-> list: # shape (batch_size,max_len) @@ -53,7 +27,7 @@ def statistical_real_sentences(input_ids:torch.Tensor,mask:torch.Tensor,predict: # batch_size assert input_ids.shape[0] == len(predict) - # 第0位是[CLS] 最后一位是 或者 [SEP] + # 开头是 [CLS],结尾是 或者 [SEP] new_ids = input_ids[:,1:-1] new_mask = mask[:,2:] @@ -67,29 +41,21 @@ def statistical_real_sentences(input_ids:torch.Tensor,mask:torch.Tensor,predict: def flatten(inputs:list) -> list: result = [] + # 在列表末尾进行追加,以达到 flatten 的目的 [result.extend(line) for line in inputs] 
return result - - - - - - - - - - def set_seed(args): random.seed(args.seed) np.random.seed(args.seed) + # 为 CPU 设置种子用于生成随机数, 使得结果是确定的。 torch.manual_seed(args.seed) - class CrfInputExample(object): def __init__(self, guid, text, label=None): + # 初始化 self.guid = guid self.text = text self.label = label @@ -97,13 +63,13 @@ def __init__(self, guid, text, label=None): class CrfInputFeatures(object): def __init__(self, input_ids, attention_mask, token_type_ids, label): + # 初始化 self.input_ids = input_ids self.attention_mask = attention_mask self.token_type_ids = token_type_ids self.label = label - - +# 将语句序列化 def crf_convert_examples_to_features(examples,tokenizer, max_length=512, label_list=None, @@ -116,6 +82,7 @@ def crf_convert_examples_to_features(examples,tokenizer, features = [] for (ex_index, example) in enumerate(examples): + # 调用 BertTokenizer inputs = tokenizer.encode_plus( example.text, add_special_tokens=True, @@ -123,23 +90,20 @@ def crf_convert_examples_to_features(examples,tokenizer, truncate_first_sequence=True # We're truncating the first sequence in priority if True ) input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"] + # Masked 操作 attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids) - padding_length = max_length - len(input_ids) input_ids = input_ids + ([pad_token] * padding_length) attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length) token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length) - # 第一个和第二个[0] 加的是[CLS]和[SEP]的位置, [0]*padding_length是[pad] ,把这些都暂时算作"O",后面用mask 来消除这些,不会影响 + # 第一个和第二个[0] 加的是[CLS]和[SEP]的位置, [0]*padding_length是[pad] ,把这些都暂时算作"O",对 mask 没有影响 labels_ids = [0] + [label_map[l] for l in example.label] + [0] + [0]*padding_length - - assert len(input_ids) == max_length, "Error with input length {} vs {}".format(len(input_ids), max_length) assert len(attention_mask) == max_length, "Error with input length {} vs 
{}".format(len(attention_mask),max_length) assert len(token_type_ids) == max_length, "Error with input length {} vs {}".format(len(token_type_ids),max_length) - assert len(labels_ids) == max_length, "Error with input length {} vs {}".format(len(labels_ids),max_length) @@ -158,14 +122,17 @@ def crf_convert_examples_to_features(examples,tokenizer, class NerProcessor(DataProcessor): + # 获取训练集 def get_train_examples(self,data_dir): return self._create_examples( os.path.join(data_dir,"train.txt")) + # 获取验证集 def get_dev_examples(self, data_dir): return self._create_examples( os.path.join(data_dir, "dev.txt")) + # 获取测试集 def get_test_examples(self, data_dir): return self._create_examples( os.path.join(data_dir, "test.txt")) @@ -227,18 +194,22 @@ def load_and_cache_example(args,tokenizer,processor,data_type): features = crf_convert_examples_to_features(examples=examples,tokenizer=tokenizer,max_length=args.max_seq_length,label_list=label_list) logger.info("Saving features into cached file %s", cached_features_file) torch.save(features, cached_features_file) - + # 获取 input 的 ID all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) + # 获取 掩码 all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long) + # 获取 类型的 ID all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long) + # 获取标签 all_label = torch.tensor([f.label for f in features], dtype=torch.long) dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_label) return dataset def trains(args,train_dataset,eval_dataset,model): - + # RandomSampler, 随机采样器 train_sampler = RandomSampler(train_dataset) + train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size) t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs @@ -250,8 +221,8 @@ def trains(args,train_dataset,eval_dataset,model): {'params': [p for n, p in model.named_parameters() if 
any(nd in n for nd in no_decay)], 'weight_decay': 0.0} ] optimizer = AdamW(optimizer_grouped_parameters,lr=args.learning_rate,eps=args.adam_epsilon) - - scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total) + # 学习率预热,在实验开始提高,后面逐渐下降; + scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps = t_total) logger.info("***** Running training *****") logger.info(" Num examples = %d", len(train_dataset)) logger.info(" Num Epochs = %d", args.num_train_epochs) @@ -265,6 +236,7 @@ def trains(args,train_dataset,eval_dataset,model): set_seed(args) best_f1 = 0. for _ in train_iterator: + # 进度条,用于显示整个实验的进度 epoch_iterator = tqdm(train_dataloader, desc="Iteration") for step,batch in enumerate(epoch_iterator): batch = tuple(t.to(args.device) for t in batch) @@ -279,7 +251,9 @@ def trains(args,train_dataset,eval_dataset,model): if args.gradient_accumulation_steps > 1: loss = loss / args.gradient_accumulation_steps + # 反向回溯 loss.backward() + # 进行梯度截断 torch.nn.utils.clip_grad_norm_(model.parameters(),args.max_grad_norm) logging_loss += loss.item() tr_loss += loss.item() @@ -292,9 +266,9 @@ def trains(args,train_dataset,eval_dataset,model): logging_loss) logging_loss = 0.0 - # if (global_step < 100 and global_step % 10 == 0) or (global_step % 50 == 0): - # 每 相隔 100步,评估一次 - if global_step % 100 == 0: + # 每相隔 100步,评估一次 + if global_step % 100 == 0: + #评估并保存模型 best_f1 = evaluate_and_save_model(args,model,eval_dataset,_,global_step,best_f1) # 最后循环结束 再评估一次 @@ -305,6 +279,7 @@ def trains(args,train_dataset,eval_dataset,model): def evaluate_and_save_model(args,model,eval_dataset,epoch,global_step,best_f1): ret = evaluate(args, model, eval_dataset) + print(ret) precision_b = ret['1']['precision'] recall_b = ret['1']['recall'] @@ -323,9 +298,9 @@ def evaluate_and_save_model(args,model,eval_dataset,epoch,global_step,best_f1): avg_recall = recall_b * weight_b + recall_i * weight_i avg_f1 = f1_b * weight_b 
+ f1_i * weight_i - all_avg_precision = ret['micro avg']['precision'] - all_avg_recall = ret['micro avg']['recall'] - all_avg_f1 = ret['micro avg']['f1-score'] + all_avg_precision = ret['macro avg']['precision'] + all_avg_recall = ret['macro avg']['recall'] + all_avg_f1 = ret['macro avg']['f1-score'] logger.info("Evaluating EPOCH = [%d/%d] global_step = %d", epoch+1,args.num_train_epochs,global_step) logger.info("B-LOC precision = %f recall = %f f1 = %f support = %d", precision_b, recall_b, f1_b, @@ -339,6 +314,7 @@ def evaluate_and_save_model(args,model,eval_dataset,epoch,global_step,best_f1): all_avg_f1) if avg_f1 > best_f1: + #若当前的模型比历史最优模型要好 best_f1 = avg_f1 torch.save(model.state_dict(), os.path.join(args.output_dir, "best_ner.bin")) logging.info("save the best model %s,avg_f1= %f", os.path.join(args.output_dir, "best_bert.bin"), @@ -347,10 +323,6 @@ def evaluate_and_save_model(args,model,eval_dataset,epoch,global_step,best_f1): return best_f1 - - - - def evaluate(args, model, eval_dataset): eval_output_dirs = args.output_dir @@ -380,8 +352,7 @@ def evaluate(args, model, eval_dataset): 'reduction':'none' } outputs = model(**inputs) - # temp_eval_loss shape: (batch_size) - # temp_pred : list[list[int]] 长度不齐 + temp_eval_loss, temp_pred = outputs[0], outputs[1] loss.extend(temp_eval_loss.tolist()) @@ -398,21 +369,18 @@ def evaluate(args, model, eval_dataset): return ret - - - - def main(): + #argparse 命令行解析的标准模块, 在命令行中传入参数后让程序运行 parser = argparse.ArgumentParser() - parser.add_argument("--data_dir", default=None, type=str, required=True, + parser.add_argument("--data_dir", default=None, type=str, help="数据文件目录,因当有train.txt dev.txt") - parser.add_argument("--vob_file", default=None, type=str, required=True, + parser.add_argument("--vob_file", default=None, type=str, help="词表文件") - parser.add_argument("--model_config", default=None, type=str, required=True, + parser.add_argument("--model_config", default=None, type=str, help="模型配置文件json文件") - 
parser.add_argument("--output_dir", default=None, type=str, required=True, + parser.add_argument("--output_dir", default=None, type=str, help="输出结果的文件") # Other parameters @@ -451,10 +419,10 @@ def main(): logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO) - # filename='./output/bert-crf-ner.log', + processor = NerProcessor() - # 得到tokenizer + # 得到 tokenizer tokenizer_inputs = () tokenizer_kwards = {'do_lower_case': False, 'max_len': args.max_seq_length, @@ -476,4 +444,5 @@ def main(): if __name__ == '__main__': + torch.cuda.empty_cache() main() diff --git a/test_pro.py b/ProjectTest.py similarity index 71% rename from test_pro.py rename to ProjectTest.py index ab91cb2..6803382 100644 --- a/test_pro.py +++ b/ProjectTest.py @@ -1,22 +1,25 @@ -from BERT_CRF import BertCrf -from NER_main import NerProcessor, CRF_LABELS -from SIM_main import SimProcessor,SimInputFeatures +from BERT_CRF_Model import BertCrf +from NERTrain import NerProcessor, CRF_LABELS +from SIMTrain import SimProcessor,SimInputFeatures from transformers import BertTokenizer, BertConfig, BertForSequenceClassification from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset import torch import pymysql from tqdm import tqdm, trange +import WikiQuery - +# 载入 GPU,没有则使用 CPU device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +# 获取之前训练的 NER 模型 def get_ner_model(config_file,pre_train_model,label_num = 2): model = BertCrf(config_name=config_file,num_tags=label_num, batch_first=True) model.load_state_dict(torch.load(pre_train_model)) return model.to(device) +# 获取之前训练的 SIM 模型 def get_sim_model(config_file,pre_train_model,label_num = 2): bert_config = BertConfig.from_pretrained(config_file) bert_config.num_labels = label_num @@ -25,6 +28,7 @@ def get_sim_model(config_file,pre_train_model,label_num = 2): return model +# 从一个句子中获取实体 def get_entity(model,tokenizer,sentence,max_len = 
64): pad_token = 0 sentence_list = list(sentence.strip().replace(' ','')) @@ -86,6 +90,7 @@ def get_entity(model,tokenizer,sentence,max_len = 64): return "".join(entity_list) +# 语义匹配 def semantic_matching(model,tokenizer,question,attribute_list,answer_list,max_length): assert len(attribute_list) == len(answer_list) @@ -157,7 +162,7 @@ def semantic_matching(model,tokenizer,question,attribute_list,answer_list,max_le def select_database(sql): - # connect database + # 连接数据库 connect = pymysql.connect(user="root",password="123456",host="127.0.0.1",port=3306,db="kb_qa",charset="utf8") cursor = connect.cursor() # 创建操作游标 try: @@ -174,7 +179,7 @@ def select_database(sql): return results -# 文字直接匹配,看看属性的词语在不在句子之中 +# 文字直接匹配,看看属性的词是否在句子中 def text_match(attribute_list,answer_list,sentence): assert len(attribute_list) == len(answer_list) @@ -190,10 +195,9 @@ def text_match(attribute_list,answer_list,sentence): return "","" - -def main(): - - with torch.no_grad(): +class Model: + @torch.no_grad() + def __init__(self): tokenizer_inputs = () tokenizer_kwards = {'do_lower_case': False, 'max_len': 64, @@ -215,52 +219,93 @@ def main(): sim_model = sim_model.to(device) sim_model.eval() - while True: - print("====="*10) - raw_text = input("问题:\n") - raw_text = raw_text.strip() - if ( "quit" == raw_text ): - print("quit") - return - entity = get_entity(model=ner_model, tokenizer=tokenizer, sentence=raw_text, max_len=64) - print("实体:", entity) - if '' == entity: - print("未发现实体") - continue - sql_str = "select * from nlpccqa where entity = '{}'".format(entity) - triple_list = select_database(sql_str) - triple_list = list(triple_list) - if 0 == len(triple_list): - print("未找到 {} 相关信息".format(entity)) - continue - triple_list = list(zip(*triple_list)) - # print(triple_list) - attribute_list = triple_list[1] - answer_list = triple_list[2] - attribute, answer = text_match(attribute_list, answer_list, raw_text) - if attribute != '' and answer != '': - ret = "{}的{}是{}".format(entity, attribute, 
answer) - else: - sim_model = get_sim_model(config_file='./input/config/bert-base-chinese-config.json', - pre_train_model='./output/best_sim.bin', - label_num=len(sim_processor.get_labels())) - - sim_model = sim_model.to(device) - sim_model.eval() - attribute_idx = semantic_matching(sim_model, tokenizer, raw_text, attribute_list, answer_list, 64).item() - if -1 == attribute_idx: - ret = '' - else: - attribute = attribute_list[attribute_idx] - answer = answer_list[attribute_idx] - ret = "{}的{}是{}".format(entity, attribute, answer) - if '' == ret: - print("未找到{}相关信息".format(entity)) + @torch.no_grad() + def query(self, raw_text) -> str: + # 获取实体 + entity = get_entity(model=self.ner_model, tokenizer=self.tokenizer, sentence=raw_text, max_len=64) + + print("实体:", entity) + + if '' == entity: + return "未发现实体" + + sql_str = "select * from nlpccqa where entity = '{}'".format(entity) + + triple_list = select_database(sql_str) + triple_list = list(triple_list) + + print(triple_list) + + if 0 == len(triple_list): + # 未找到相关信息 + + print("未找到 {} 相关信息".format(entity)) + + print("正在通过网络查找中...") + + WikiQuery.getInfobox(entity) + # if len(elem_dic) != 0: + # for key in elem_dic: + # #print(key.text, elem_dic[key].text) + # if len(elem_dic[key].text) <= 10: + # insert_data(entity, key.text, elem_dic[key].text) + + print("查找完毕") + + return "未找到 {} 相关信息,尝试通过网络查找...".format(entity) + + triple_list = list(zip(*triple_list)) + print(triple_list) + + attribute_list = triple_list[1] + answer_list = triple_list[2] + # 直接进行匹配 + attribute, answer = text_match(attribute_list, answer_list, raw_text) + if attribute != '' and answer != '': + ret = "{}的{}是{}".format(entity, attribute, answer) + else: + self.sim_model = get_sim_model(config_file='./input/config/bert-base-chinese-config.json', + pre_train_model='./output/best_sim.bin', + label_num=len(self.sim_processor.get_labels())) + + self.sim_model = self.sim_model.to(device) + self.sim_model.eval() + # 进行语义匹配 + attribute_idx = 
semantic_matching(self.sim_model, self.tokenizer, raw_text, attribute_list, answer_list, 64).item() + if -1 == attribute_idx: + ret = '' else: - print("回答:",ret) + attribute = attribute_list[attribute_idx] + answer = answer_list[attribute_idx] + ret = "{}的{}是{}".format(entity, attribute, answer) + + if '' == ret: + print("未找到{}相关信息".format(entity)) + print("正在通过网络查找中...") + WikiQuery.getInfobox(entity) + return "未找到 {} 相关信息,尝试通过网络查找...".format(entity) + else: + return ret + if __name__ == '__main__': - main() + + model = Model() + + while True: + + print("====="*10) + + raw_text = input("问题:\n").strip() + + # 输入 quit 则退出。 + if ( "quit" == raw_text ): + print("quit") + break + + ans = model.query(raw_text) + + print('回答:', ans) diff --git a/README.md b/README.md index ebcc453..faa1794 100644 --- a/README.md +++ b/README.md @@ -1,51 +1,80 @@ # bert-kbqa -基于bert的kbqa系统 - -预训练模型太大了,无法上传。放在百度网盘了 -链接:https://pan.baidu.com/s/1EK-TGghfmj-0HbWl_xe3zg -提取码:jqeg -下载下来放在./bert-kbqa/input/config目录 - -构造数据集 -通过 1_split_data.py 切分数据 - -通过 2-construct_dataset_ner.py 构造命名实体识别的数据集 +基于bert的kbqa系统 -通过 3-construct_dataset_attribute.py 构造属性相似度的数据集 +预训练模型太大, 放在百度网盘了 -通过 4-print-seq-len.py 看看句子的长度 +- 链接: https://pan.baidu.com/s/1EK-TGghfmj-0HbWl_xe3zg +- 提取码: jqeg -通过 5-triple_clean.py 构造干净的三元组 +下载下来放在 `./bert-kbqa/input/config` 目录 -通过 6-load_dbdata.py 通过创建数据库 和 上传数据 +## 使用方法 +### 安装依赖 +- PyTorch +- Transformers -CRF_Model.py 条件随机场模型 +### 构造数据集 -BERT_CRF.py bert+条件随机场 +详见 [input/data/](./input/data/) 目录 -NER_main.py 训练命令实体识别的模型 +1. [SplitData.py](./input/data/SplitData.py) 切分数据 +2. [ConstructDatasetNer.py](./input/data/ConstructDatasetNer.py) 构造命名实体识别的数据集 +3. [ConstructDatasetAttribute.py](./input/data/ConstructDatasetAttribute.py) 构造属性相似度的数据集 +4. [ConstructTriple.py](./input/data/ConstructTriple.py) 构造干净的三元组 +5. 
[LoadMySQL.py](./input/data/LoadMySQL.py) 创建数据库和上传数据 -SIM_main.py 训练属性相似度的模型 +### 模型训练 +[CRF_Model.py](./CRF_Model.py) 条件随机场模型 -test_NER.py 测试命令实体识别 +[BERT_CRF_Model.py](./BERT_CRF_Model.py) bert+条件随机场 -test_SIM.py 测试属性相似度 +[NERTrain.py](./NERTrain.py) 训练命名实体识别的模型 -test_pro.py 测试整个项目 +```console +$ python3 NERTrain.py \ + --data_dir ./input/data/ner_data \ + --vob_file ./input/config/bert-base-chinese-vocab.txt \ + --model_config ./input/config/bert-base-chinese-config.json \ + --output ./output \ + --pre_train_model ./input/config/bert-base-chinese-model.bin \ + --max_seq_length 64 \ + --do_train \ + --train_batch_size 16 \ + --eval_batch_size 256 \ + --gradient_accumulation_steps 4 \ + --num_train_epochs 15 +``` +[SIMTrain.py](./SIMTrain.py) 训练属性相似度的模型 -主要依赖版本: +```console +$ python3 SIMTrain.py \ + --data_dir ./input/data/sim_data \ + --vob_file ./input/config/bert-base-chinese-vocab.txt \ + --model_config ./input/config/bert-base-chinese-config.json \ + --output ./output \ + --pre_train_model ./input/config/bert-base-chinese-model.bin \ + --max_seq_length 64 \ + --do_train \ + --train_batch_size 32 \ + --eval_batch_size 256 \ + --gradient_accumulation_steps 4 \ + --num_train_epochs 15 +``` -torch.__version__ 1.2.0 +### 模型测试 -transformers.__version__ 2.0.0 +[NERTest.py](./NERTest.py) 测试命名实体识别 +[SIMTest.py](./SIMTest.py) 测试属性相似度 -带有命令运行的py文件的命令都在 该py文件的最上方 +## 项目展示 +[ProjectTest.py](./ProjectTest.py) 测试整个项目 / 命令行界面的项目展示 +[chat.py](./chat.py) Qt/QML 编写的图形界面 diff --git a/test_SIM.py b/SIMTest.py similarity index 96% rename from test_SIM.py rename to SIMTest.py index 9aaa9c2..1d9929d 100644 --- a/test_SIM.py +++ b/SIMTest.py @@ -1,9 +1,14 @@ from transformers import BertConfig, BertForSequenceClassification, BertTokenizer from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset -from SIM_main import SimProcessor,SimInputFeatures,cal_acc +from SIMTrain import SimProcessor,SimInputFeatures,cal_acc import torch from tqdm import tqdm, 
trange +""" +对属性相似度模型进行测试 +""" + + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") processor = SimProcessor() tokenizer_inputs = () diff --git a/SIM_main.py b/SIMTrain.py similarity index 96% rename from SIM_main.py rename to SIMTrain.py index 83272cb..3af654a 100644 --- a/SIM_main.py +++ b/SIMTrain.py @@ -1,25 +1,3 @@ -# --data_dir -# ./input/data/sim_data -# --vob_file -# ./input/config/bert-base-chinese-vocab.txt -# --model_config -# ./input/config/bert-base-chinese-config.json -# --output -# ./output -# --pre_train_model -# ./input/config/bert-base-chinese-model.bin -# --max_seq_length -# 64 -# --do_train -# --train_batch_size -# 32 -# --eval_batch_size -# 256 -# --gradient_accumulation_steps -# 4 -# --num_train_epochs -# 15 - import argparse from collections import Counter import code @@ -33,13 +11,11 @@ import torch.nn as nn from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset -from transformers import AdamW, WarmupLinearSchedule +from transformers import AdamW, get_linear_schedule_with_warmup from transformers import BertConfig, BertForSequenceClassification, BertTokenizer -from transformers import glue_convert_examples_to_features as convert_examples_to_features from transformers.data.processors.utils import DataProcessor, InputExample import numpy as np -import pandas as pd logger = logging.getLogger(__name__) @@ -76,9 +52,8 @@ def cal_acc(real_label,pred_label): - - class SimInputExample(object): + # 进行初始化 def __init__(self, guid, question,attribute, label=None): self.guid = guid self.question = question @@ -87,6 +62,7 @@ def __init__(self, guid, question,attribute, label=None): class SimInputFeatures(object): + # 进行初始化 def __init__(self, input_ids, attention_mask, token_type_ids, label = None): self.input_ids = input_ids self.attention_mask = attention_mask @@ -98,17 +74,19 @@ class SimProcessor(DataProcessor): """Processor for the FAQ problem modified from 
https://github.com/huggingface/transformers/blob/master/transformers/data/processors/glue.py#L154 """ - + # 获取训练集 def get_train_examples(self, data_dir): logger.info("******* train ********") return self._create_examples( os.path.join(data_dir, "train.txt")) + # 获取验证集 def get_dev_examples(self, data_dir): logger.info("******* dev ********") return self._create_examples( os.path.join(data_dir, "dev.txt")) + # 获取测试集 def get_test_examples(self,data_dir): logger.info("******* test ********") return self._create_examples( @@ -133,7 +111,7 @@ def _create_examples(cls, path): f.close() return examples - +# 将文本进行序列化 def sim_convert_examples_to_features(examples,tokenizer, max_length=512, label_list=None, @@ -230,7 +208,7 @@ def trains(args,train_dataset,eval_dataset,model): ] optimizer = AdamW(optimizer_grouped_parameters,lr=args.learning_rate,eps=args.adam_epsilon) - scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total) + scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total) logger.info("***** Running training *****") logger.info(" Num examples = %d", len(train_dataset)) logger.info(" Num Epochs = %d", args.num_train_epochs) @@ -239,11 +217,14 @@ def trains(args,train_dataset,eval_dataset,model): global_step = 0 tr_loss, logging_loss = 0.0, 0.0 + # 梯度清零 model.zero_grad() train_iterator = trange(int(args.num_train_epochs), desc="Epoch") + # 设置随机数种子 set_seed(args) best_acc = 0. 
for _ in train_iterator: + # 进度条 epoch_iterator = tqdm(train_dataloader, desc="Iteration") for step,batch in enumerate(epoch_iterator): batch = tuple(t.to(args.device) for t in batch) @@ -257,6 +238,7 @@ def trains(args,train_dataset,eval_dataset,model): if args.gradient_accumulation_steps > 1: loss = loss / args.gradient_accumulation_steps loss.backward() + # 梯度截断 torch.nn.utils.clip_grad_norm_(model.parameters(),args.max_grad_norm) logging_loss += loss.item() tr_loss += loss.item() @@ -311,6 +293,7 @@ def evaluate(args, model, eval_dataset): total_sample_num = 0 # 样本总数目 all_real_label = [] # 记录所有的真实标签列表 all_pred_label = [] # 记录所有的预测标签列表 + for batch in tqdm(eval_dataloader, desc="Evaluating"): model.eval() batch = tuple(t.to(args.device) for t in batch) diff --git a/WikiQuery.py b/WikiQuery.py new file mode 100644 index 0000000..36e6338 --- /dev/null +++ b/WikiQuery.py @@ -0,0 +1,51 @@ +import time +from selenium import webdriver +from selenium.webdriver.common.keys import Keys +from input.data.LoadMySQL import insert_data + +# getInfobox函数 +def getInfobox(name): + try: + # 访问百度百科并自动搜索 + driver = webdriver.Firefox() + driver.get("http://baike.baidu.com/") + elem_inp = driver.find_element_by_xpath("//form[@id='searchForm']/input") + elem_inp.send_keys(name) + elem_inp.send_keys(Keys.RETURN) + time.sleep(1) + print(driver.current_url) + print(driver.title) + + # 爬取消息盒InfoBox内容 + elem_name = driver.find_elements_by_xpath("//div[@class='basic-info J-basic-info cmn-clearfix']/dl/dt") + elem_value = driver.find_elements_by_xpath("//div[@class='basic-info J-basic-info cmn-clearfix']/dl/dd") + + # for e in elem_name: + # print(e.text) + # for e in elem_value: + # print(e.text) + + + # 构建字段成对输出 + elem_dic = dict(zip(elem_name, elem_value)) + for key in elem_dic: + print(key.text, elem_dic[key].text) + insert_data(name, key.text, elem_dic[key].text) + + return + + except Exception as e: + print("Error: ", e) + + finally: + print('\n') + driver.close() + + +# 主函数 + +def 
main(): + getInfobox("仙剑奇侠传") + +if __name__ == '__main__': + main() diff --git a/chat.py b/chat.py new file mode 100644 index 0000000..41a5d4a --- /dev/null +++ b/chat.py @@ -0,0 +1,50 @@ +import datetime +import logging + +import sys + +from PySide6.QtCore import QDir, QFile, QObject, QUrl +from PySide6.QtGui import QGuiApplication +from PySide6.QtQml import QQmlApplicationEngine +from PySide6.QtSql import QSqlDatabase + +import sqlDialog + +logging.basicConfig(filename='chat.log', level=logging.DEBUG) +logger = logging.getLogger('logger') + + +def connectToDatabase(): + database = QSqlDatabase.database() + if not database.isValid(): + database = QSqlDatabase.addDatabase("QSQLITE") + if not database.isValid(): + logger.error("Cannot add database") + + write_dir = QDir("") + if not write_dir.mkpath("."): + logger.error("Failed to create writable directory") + + # Ensure that we have a writable location on all devices. + abs_path = write_dir.absolutePath() + filename = f"{abs_path}/chat-database.sqlite3" + + # When using the SQLite driver, open() will create the SQLite + # database if it doesn't exist. 
+ database.setDatabaseName(filename) + if not database.open(): + logger.error("Cannot open database") + QFile.remove(filename) + + +if __name__ == "__main__": + app = QGuiApplication() + connectToDatabase() + + engine = QQmlApplicationEngine() + engine.load(QUrl("chat.qml")) + + if not engine.rootObjects(): + sys.exit(-1) + + app.exec() diff --git a/chat.qml b/chat.qml new file mode 100644 index 0000000..dc9ad1e --- /dev/null +++ b/chat.qml @@ -0,0 +1,100 @@ +import QtQuick +import QtQuick.Layouts +import QtQuick.Controls + +import ChatModel 1.0 + +ApplicationWindow { + id: root + title: qsTr("Chat") + width: 540 + height: 720 + visible: true + + SqlConversationModel { + id: chat_model + } + + ColumnLayout { + anchors.fill: parent + + ListView { + id: listView + Layout.fillWidth: true + Layout.fillHeight: true + Layout.margins: pane.leftPadding + messageField.leftPadding + displayMarginBeginning: 40 + displayMarginEnd: 40 + verticalLayoutDirection: ListView.BottomToTop + spacing: 12 + model: chat_model + delegate: Column { + + readonly property bool sentByMe: model.recipient !== "Me" + + anchors.right: sentByMe ? listView.contentItem.right : undefined + spacing: 6 + + Row { + id: messageRow + spacing: 6 + anchors.right: sentByMe ? parent.right : undefined + + Rectangle { // Message Blob + width: Math.min(messageText.implicitWidth + 24, + listView.width - (!sentByMe ? messageRow.spacing : 0)) + height: messageText.implicitHeight + 24 + radius: 15 + color: sentByMe ? "steelblue" : "lightgrey" + + Label { + id: messageText + text: model.message + color: sentByMe ? "white" : "black" + anchors.fill: parent + anchors.margins: 12 + wrapMode: Label.Wrap + } + } + } + + /*Label { + id: timestampText + text: Qt.formatDataTime(model.timestamp, "d MM hh:mm") + color: "lightgrey" + anchors.right: sentByMe ? 
parent.right : undefined + }*/ + } + + ScrollBar.vertical: ScrollBar {} + } + + Pane { + id: pane + Layout.fillWidth: true + + RowLayout { + width: parent.width + + TextArea { + id: messageField + Layout.fillWidth: true + placeholderText: qsTr("Compose message") + wrapMode: TextArea.Wrap + } + + Button { + id: sendButton + text: qsTr("Send") + enabled: messageField.length > 0 + onClicked: { + listView.model.send_message("machine", messageField.text, "Me"); + messageField.text = ""; + } + } + } + } + } +} + + diff --git a/input/data/.gitignore b/input/data/.gitignore new file mode 100644 index 0000000..90ced5a --- /dev/null +++ b/input/data/.gitignore @@ -0,0 +1,11 @@ +# Files created by SplitData.py +NLPCC2016KBQA/*.txt + +# Files created by ConstructDatasetNer.py +ner_data/ + +# Files created by ConstructDatasetAttribute.py +sim_data/ + +# Files created by ConstructDB.py +DB_Data/ diff --git a/input/data/1_split_data.py b/input/data/1_split_data.py deleted file mode 100644 index 73f3066..0000000 --- a/input/data/1_split_data.py +++ /dev/null @@ -1,50 +0,0 @@ -# -# 切分数据集, -# 原始的 nlpcc-iccpol-2016.kbqa.testing-data 有 9870 个样本 -# 原始的 nlpcc-iccpol-2016.kbqa.training-data 有 14609 个样本 -# -# 将nlpcc-iccpol-2016.kbqa.testing-data 中的对半分,一半变成验证集(dev.text),一半变成测试集(test.txt) -# nlpcc-iccpol-2016.kbqa.training-data 保持不变,复制成为训练集 train.txt -# - - -import pandas as pd -import os - - -data_dir = 'NLPCC2016KBQA' -file_name_list = ['nlpcc-iccpol-2016.kbqa.testing-data','nlpcc-iccpol-2016.kbqa.training-data'] - - -for file_name in file_name_list: - file_path_name = os.path.join(data_dir,file_name) - file = [] - with open(file_path_name,'r',encoding='utf-8') as f: - for line in f: - line = line.strip() - if line == '': - continue - file.append(line) - f.close() - if 'training' in file_name: - with open(os.path.join(data_dir,'train.txt') , "w", encoding='utf-8') as f: - f.write('\n'.join(file)) - f.close() - elif 'testing' in file_name: - assert len(file) % 4 == 0 # 断言 - testing_num = 
len(file) / 4 # 一个样本是由 4 行构成的 - test_num = int(testing_num / 2) # 真正的测试集分一半 - - test_line_no = int(test_num * 4) - - - with open(os.path.join(data_dir, 'test.txt'), "w", encoding='utf-8') as f: - f.write('\n'.join(file[:test_line_no])) # 乘以四得到行号,前一半给 test 数据集 - f.close() - - with open(os.path.join(data_dir, 'dev.txt'), "w", encoding='utf-8') as f: - f.write('\n'.join(file[test_line_no:])) # 乘以四得到行号,后一半给 dev 数据集 - f.close() - -print("Done") - diff --git a/input/data/2-construct_dataset_ner.py b/input/data/2-construct_dataset_ner.py deleted file mode 100644 index cd6e971..0000000 --- a/input/data/2-construct_dataset_ner.py +++ /dev/null @@ -1,77 +0,0 @@ -# coding:utf-8 -import sys -import os -import pandas as pd - - -''' -通过 NLPCC2016KBQA 中的原始数据,构建用来训练NER的样本集合 -构造NER训练集,实体序列标注,训练BERT+CRF -''' - -data_dir = 'NLPCC2016KBQA' -file_name_list = ['train.txt','dev.txt','test.txt'] - -new_dir = 'ner_data' - -question_str = "")[1].strip() - q_str = q_str.split(">")[1].replace(" ", "").strip() - if entities in q_str: - q_list = list(q_str) - seq_q_list.extend(q_list) - seq_q_list.extend([" "]) - tag_list = ["O" for i in range(len(q_list))] - tag_start_index = q_str.find(entities) - for i in range(tag_start_index, tag_start_index + len(entities)): - if tag_start_index == i: - tag_list[i] = "B-LOC" - else: - tag_list[i] = "I-LOC" - seq_tag_list.extend(tag_list) - seq_tag_list.extend([" "]) - else: - pass - q_t_a_list.append([q_str, t_str, a_str]) - print(file_name) - print('\t'.join(seq_tag_list[0:50])) - print('\t'.join(seq_q_list[0:50])) - seq_result = [str(q) + " " + tag for q, tag in zip(seq_q_list, seq_tag_list)] - if not os.path.exists(new_dir): - os.mkdir(new_dir) - - with open(os.path.join(new_dir,file_name), "w", encoding='utf-8') as f: - f.write("\n".join(seq_result)) - f.close() - - df = pd.DataFrame(q_t_a_list, columns=["q_str", "t_str", "a_str"]) - file_type = file_name.split('.')[0] - csv_name = file_type+'.'+'csv' - - df.to_csv(os.path.join(new_dir,csv_name), 
encoding='utf-8', index=False) \ No newline at end of file diff --git a/input/data/3-construct_dataset_attribute.py b/input/data/3-construct_dataset_attribute.py deleted file mode 100644 index 0fac19c..0000000 --- a/input/data/3-construct_dataset_attribute.py +++ /dev/null @@ -1,68 +0,0 @@ -# coding:utf-8 -import sys -import os -import random -import pandas as pd -import re - -''' -通过 ner_data 中的数据 构建出 用来匹配句子相似度的 样本集合 -构造属性关联训练集,分类问题,训练BERT分类模型 -1 -''' - - -data_dir = 'ner_data' -file_name_list = ['train.csv','dev.csv','test.csv'] - -new_dir = 'sim_data' - -pattern = re.compile('^-+') # 以-开头 - - - - -for file_name in file_name_list: - file_path_name = os.path.join(data_dir,file_name) - assert os.path.exists(file_path_name) - - attribute_classify_sample = [] - df = pd.read_csv(file_path_name, encoding='utf-8') - df['attribute'] = df['t_str'].apply(lambda x: x.split('|||')[1].strip()) - attribute_list = df['attribute'].tolist() # 转化成列表 - attribute_list = list(set(attribute_list)) # 去重 - attribute_list = [att.strip().replace(' ','') for att in attribute_list] # 去尾部,去空格 - attribute_list = [re.sub(pattern,'',att) for att in attribute_list] # 去掉 以-开头 - - attribute_list = list(set(attribute_list)) # 再去重 - - for row in df.index: - question, pos_att = df.loc[row][['q_str', 'attribute']] - - question = question.strip().replace(' ','') # 去尾部,空格 - question = re.sub(pattern, '', question) # 去掉 以-开头 - - pos_att = pos_att.strip().replace(' ','') # 去尾部,空格 - pos_att = re.sub(pattern, '', pos_att) # 去掉 以-开头 - - neg_att_list = [] - while True: - neg_att_list = random.sample(attribute_list, 5) - if pos_att not in neg_att_list: - break - attribute_classify_sample.append([question, pos_att, '1']) - - neg_att_sample = [[question, neg_att, '0'] for neg_att in neg_att_list] - attribute_classify_sample.extend(neg_att_sample) - seq_result = [str(lineno) + '\t' + '\t'.join(line) for (lineno, line) in enumerate(attribute_classify_sample)] - - if not os.path.exists(new_dir): - 
os.makedirs(new_dir) - - file_type = file_name.split('.')[0] - print("***** {} ******".format(file_type)) - new_file_name = file_type + '.'+'txt' - with open(os.path.join(new_dir,new_file_name), "w", encoding='utf-8') as f: - f.write("\n".join(seq_result)) - f.close() - diff --git a/input/data/4-print-seq-len.py b/input/data/4-print-seq-len.py deleted file mode 100644 index 800e71d..0000000 --- a/input/data/4-print-seq-len.py +++ /dev/null @@ -1,33 +0,0 @@ -import os - -dir_name = 'sim_data' -file_list = ['train.txt','dev.txt','test.txt'] - -for file in file_list: - - file_path_name = os.path.join(dir_name,file) - - max_len = 0 - print("****** {} *******".format(file)) - with open(file_path_name,'r',encoding='utf-8') as f: - for line in f: - - line_list = line.split('\t') - question = list(line_list[1]) - attribute = list(line_list[2]) - add_len = len(question) + len(attribute) - if add_len > max_len: - max_len = add_len - print("max_len",max_len) - f.close() - -# ****** train.txt ******* -# max_len 62 -# ****** dev.txt ******* -# max_len 61 -# ****** test.txt ******* -# max_len 62 - - -# 因此,最大长度为 64 合理。 - diff --git a/input/data/5-triple_clean.py b/input/data/5-triple_clean.py deleted file mode 100644 index 988a5a7..0000000 --- a/input/data/5-triple_clean.py +++ /dev/null @@ -1,50 +0,0 @@ -# -*- coding: utf-8 -*- -# @Time : 2019/4/18 20:16 -# @Author : Alan -# @Email : xiezhengwen2013@163.com -# @File : triple_clean.py -# @Software: PyCharm - - -import pandas as pd - - -''' -构造NER训练集,实体序列标注,训练BERT+BiLSTM+CRF -''' - -question_str = "")[1].strip() - q_str = q_str.split(">")[1].replace(" ","").strip() - if ''.join(entities.split(' ')) in q_str: - clean_triple = t_str.split(">")[1].replace('\t','').replace(" ","").strip().split("|||") - triple_list.append(clean_triple) - else: - print(entities) - print(q_str) - print('------------------------') - -df = pd.DataFrame(triple_list, columns=["entity", "attribute", "answer"]) -print(df) -print(df.info()) 
-df.to_csv("./DB_Data/clean_triple.csv", encoding='utf-8', index=False) \ No newline at end of file diff --git a/input/data/Configurations.py b/input/data/Configurations.py new file mode 100644 index 0000000..4fef354 --- /dev/null +++ b/input/data/Configurations.py @@ -0,0 +1,4 @@ +DATA_DIR = 'NLPCC2016KBQA' +NER_DATA_DIR = 'NER_Data' +SIM_DATD_DIR = 'SIM_Data' +DB_DATA_DIR = 'DB_data' diff --git a/input/data/ConstructDatasetAttribute.py b/input/data/ConstructDatasetAttribute.py new file mode 100644 index 0000000..9bbca90 --- /dev/null +++ b/input/data/ConstructDatasetAttribute.py @@ -0,0 +1,73 @@ +#-*-coding:UTF-8 -*- + +from Configurations import NER_DATA_DIR, SIM_DATD_DIR +from MyUtils import my_strip, shuffle_from_ + +import os +import csv + +''' +通过 ner_data 中的数据 构建出 用来匹配句子相似度的 样本集合 +构造属性关联训练集,分类问题,训练BERT分类模型 +BERT: + 1. 预测 mask 概率; + 2. 预测 next sentence 概率. +''' + +# 处理后将文件写在 ./input/SIM_data 下 +if not os.path.exists(SIM_DATD_DIR): + os.makedirs(SIM_DATD_DIR) + +filenames = ['train.csv', 'dev.csv', 'test.csv'] + +sample_line_tmplt = '{}\t{}\t{}\t{}\n' + +for filename in filenames: + file_path = os.path.join(NER_DATA_DIR, filename) + assert os.path.exists(file_path) # 断言处理, 判断该文件是否存在 + rows = [] # 记录文件内容 + attribute_set = set() # 存储 attribute 的集合 + with open(file_path, mode='r', encoding='utf-8') as f: + reader = csv.DictReader(f) + # 读取文件 + for row in reader: + # 处理 question 字符串 + question = my_strip(row['question']) + # 提取并处理 attribute 字段 + attribute = my_strip(row['triple'].split('|||')[1]) + # 记录 attribute 字段 + row['attribute'] = attribute + # 并存储在集合中 + rows.append(row) + # 同时记录文件内容 + attribute_set.add(attribute) + + attribute_set = list(attribute_set) + + filename_tokens = filename.split('.') + filename_tokens[-1] = 'txt' + with open(os.path.join(SIM_DATD_DIR, '.'.join(filename_tokens)), mode='w', encoding='utf-8') as attribute_classify_sample: + NEGATIVE_SAMPLE_COUNT = 5 + cnt = 0 + max_len = 0 + for row in rows: + question, postive_attribute = 
row['question'], row['attribute'] + # 随机生成并记录负样本 + negative_attributes = [] + # 进行随机取样 + for i in range(NEGATIVE_SAMPLE_COUNT): + sample = shuffle_from_(attribute_set) + while sample == postive_attribute: + sample = shuffle_from_(attribute_set) + negative_attributes.append(sample) + # 写出生成的样本 + # 1 为正确答案, 0 为错误答案 + attribute_classify_sample.write(sample_line_tmplt.format(cnt, question, postive_attribute, 1)) + cnt += 1 + for var in negative_attributes: + attribute_classify_sample.write(sample_line_tmplt.format(cnt, question, var, 0)) + cnt += 1 + # 记录最大字串序列长度 + max_len = max(max_len, len(question) + len(var)) + max_len = max(max_len, len(question) + len(postive_attribute)) + print('{}:\tmax_len: {}'.format(filename, max_len)) diff --git a/input/data/ConstructDatasetNer.py b/input/data/ConstructDatasetNer.py new file mode 100644 index 0000000..bbee96f --- /dev/null +++ b/input/data/ConstructDatasetNer.py @@ -0,0 +1,71 @@ +# coding:utf-8 +import os +import csv +from Configurations import DATA_DIR, NER_DATA_DIR +from MyUtils import line_parse + +''' +通过 NLPCC2016KBQA 中的原始数据, 构建用来训练NER的样本集合 +构造 NER 训练集, 实体序列标注, 训练BERT+CRF +''' + +file_name_list = ['train.txt', 'dev.txt', 'test.txt'] + +# 将处理后的文件写在 ./input/NER_data 下 +if not os.path.exists(NER_DATA_DIR): + os.mkdir(NER_DATA_DIR) + +fields = ['question', 'triple', 'answer'] + +for file_name in file_name_list: + + # 写出序列标注 + tagged_q_str_file = open(os.path.join(NER_DATA_DIR, file_name), "w", encoding='utf-8') + + # 使用 csv.DictWriter 将 Q-T-A 对以 CSV 格式写出到对应的 CSV 文件里 + filename_tokens = file_name.split('.') + filename_tokens[-1] = 'csv' + qta_file = open(os.path.join(NER_DATA_DIR, '.'.join(filename_tokens)), mode='w', encoding='utf-8', newline='') + qta_writer = csv.DictWriter(qta_file, fieldnames=fields) + qta_writer.writeheader() + + file_path = os.path.join(DATA_DIR, file_name) + assert os.path.exists(file_path) + + with open(file_path, 'r', encoding='utf-8') as f: + while True: + try: + # 一次读取三行 + l = 
[line_parse(f.__next__()) for i in range(3)] + # 并映射到字典的形式 + s = {t[0]: t[1].strip() for t in l} + # 跳过分割线行 + f.__next__() + + q_str = s['question'].replace(' ', '') + s['question'] = q_str + + entity = s['triple'].split('|||')[0].strip() + p = q_str.find(entity) + # 若该实体名存在于问题中 + if p != -1: + tags = ['O'] * len(q_str) + # BIO 标注划分 + # B-IOC: 一个地名的开始 + # I-IOC:一个地名的中间部分 + # 其余为 O + tags[p] = 'B-LOC' + for i in range(p + 1, p + len(entity)): + tags[i] = 'I-LOC' + # 存储序列标注等待写出 + for q, t in zip(q_str, tags): + tagged_q_str_file.write(q + ' ' + t + '\n') + tagged_q_str_file.write('\n') + else: + pass + + qta_writer.writerow(s) + + except StopIteration: + qta_file.close() + break diff --git a/input/data/ConstructTriple.py b/input/data/ConstructTriple.py new file mode 100644 index 0000000..ff49565 --- /dev/null +++ b/input/data/ConstructTriple.py @@ -0,0 +1,55 @@ +# -*- coding: utf-8 -*- + +import os +import csv +from Configurations import DATA_DIR, DB_DATA_DIR +from MyUtils import line_parse, my_strip + +''' +构造 NER 训练集 (即三元组), 实体序列标注, 用于训练 BERT + BiLSTM + CRF +''' + +# 将处理后的文件写在 ./input/DB_data 下 +if not os.path.exists(DB_DATA_DIR): + os.mkdir(DB_DATA_DIR) + +filename_prefix = 'nlpcc-iccpol-2016.kbqa.' 
+filename_suffix = '-data' +file_name_list = [ + filename_prefix + var + filename_suffix for var in ['training', 'testing'] +] + +with open(os.path.join(DB_DATA_DIR, 'clean_triple.csv'), mode='w', encoding='utf-8', newline='') as out_csv_file: + # 三元组: 实体, 属性, 答案 + writer = csv.DictWriter(out_csv_file, fieldnames=['entity', 'attribute', 'answer']) + writer.writeheader() + + for filename in file_name_list: + file_path = os.path.join(DATA_DIR, filename) + assert os.path.exists(file_path) + + with open(file_path, mode='r', encoding='utf-8') as f: + while True: + try: + # 一次读取三行 + l = [line_parse(f.__next__()) for i in range(3)] + # 并映射到字典的形式 + s = {t[0]: t[1] for t in l} + # 跳过分割线行 + f.__next__() + + triples = [t.strip() for t in s['triple'].split('|||')] + t = { + 'entity': triples[0], + 'attribute': triples[1], + 'answer': triples[2], + } + if t['entity'] in s['question']: + writer.writerow(t) + else: + # print(s['question']) + # print(t['entity']) + # print('-' * 30) + pass + except StopIteration: + break diff --git a/input/data/6-load_dbdata.py b/input/data/LoadMySQL.py similarity index 50% rename from input/data/6-load_dbdata.py rename to input/data/LoadMySQL.py index 2ae4eef..f096386 100644 --- a/input/data/6-load_dbdata.py +++ b/input/data/LoadMySQL.py @@ -1,18 +1,14 @@ # -*- coding: utf-8 -*- -# @Time : 2019/4/18 20:47 -# @Author : Alan -# @Email : xiezhengwen2013@163.com -# @File : load_dbdata.py -# @Software: PyCharm - import pymysql import pandas as pd from sqlalchemy import create_engine - +# 创建数据库, 默认使用 root 用户 def create_db(): - connect = pymysql.connect( # 连接数据库服务器-*-*- + + #创建连接,若账号密码不同请记得修改; + connect = pymysql.connect( user="root", password="123456", host="127.0.0.1", @@ -20,39 +16,45 @@ def create_db(): db="KB_QA", charset="utf8" ) - conn = connect.cursor() # 创建操作游标 - # 你需要一个游标 来实现对数据库的操作相当于一条线索 - - # 创建表 - conn.execute("drop database if exists KB_QA") # 如果new_database数据库存在则删除 - conn.execute("create database KB_QA") # 新创建一个数据库 - conn.execute("use 
KB_QA") # 选择new_database这个数据库 + #创建操作游标 + conn = connect.cursor() + + # 如果 KB_QA 数据库存在则删除 + conn.execute("drop database if exists KB_QA") + # 新创建一个数据库 + conn.execute("create database KB_QA") + # 选择使用 KB_QA 数据库 + conn.execute("use KB_QA") conn.execute("SET @@global.sql_mode=''") - # sql 中的内容为创建一个名为 new_table 的表 - sql = """create table nlpccQA(entity VARCHAR(50) character set utf8 collate utf8_unicode_ci, - attribute VARCHAR(50) character set utf8 collate utf8_unicode_ci, answer VARCHAR(255) character set utf8 - collate utf8_unicode_ci)""" # ()中的参数可以自行设置 - conn.execute("drop table if exists nlpccQA") # 如果表存在则删除 - conn.execute(sql) # 创建表 + # 如果表存在,则删除 + conn.execute("drop table if exists nlpccQA") + conn.execute(sql) - # 删除 - # conn.execute("drop table new_table") + # 创建一个名为 nlpccQA 的表 + sql = """ + create table nlpccQA(entity VARCHAR(50) character set utf8 collate utf8_unicode_ci, + attribute VARCHAR(50) character set utf8 collate utf8_unicode_ci, answer VARCHAR(255) character set utf8 + collate utf8_unicode_ci) + """ - conn.close() # 关闭游标连接 - connect.close() # 关闭数据库服务器连接 释放内存 + #关闭游标 + conn.close() + #关闭与数据库的连接 + connect.close() def loaddata(): - # 初始化数据库连接,使用pymysql模块 + #使用 pymysql,与数据库进行连接,同样,注意用户名和密码的设置。 db_info = {'user': 'root', - 'password': '123456', + 'password': 'root', 'host': '127.0.0.1', 'port': 3306, 'database': 'KB_QA' } - engine = create_engine( # 导入模块中的create_engine,需要利用它来进行连接数据库 + # 导入模块中的 create_engine, 需要利用它来进行连接数据库 + engine = create_engine( 'mysql+pymysql://%(user)s:%(password)s@%(host)s:%(port)d/%(database)s?charset=utf8' % db_info, encoding='utf-8') # ("mysql+pymysql://【此处填用户名】:【此处填密码】@【此处填host】:【此处填port】/【此处填数据库的名称】?charset=utf8") # 直接使用这种形式也可以engine = create_engine('mysql+pymysql://root:123456@localhost:3306/test') @@ -60,27 +62,30 @@ def loaddata(): # 读取本地CSV文件 df = pd.read_csv("./DB_Data/clean_triple.csv", sep=',', encoding='utf-8') - print(df) - # 将新建的DataFrame储存为MySQL中的数据表,不储存index列(index=False) - # if_exists: - # 1.fail:如果表存在,啥也不做 - # 
2.replace:如果表存在,删了表,再建立一个新表,把数据插入 - # 3.append:如果表存在,把数据插入,如果表不存在创建一个表!! + + # 将新建的 DataFrame 储存到 MySQL 中的数据表, 不储存 index 列 (index=False) + # if_exists 参数: + # - fail: 如果表存在,啥也不做 + # - replace: 如果表存在,删了表,再建立一个新表,把数据插入 + # - append: 如果表存在,把数据插入,如果表不存在创建一个表!! pd.io.sql.to_sql(df, 'nlpccQA', con=engine, index=False, if_exists='append', chunksize=10000) # df.to_sql('example', con=engine, if_exists='replace')这种形式也可以 print("Write to MySQL successfully!") def upload_data(sql): - connect = pymysql.connect( # 连接数据库服务器 + #连接数据库服务器 + connect = pymysql.connect( user="root", - password="123456", + password="root", host="127.0.0.1", port=3306, db="kb_qa", charset="utf8" ) - cursor = connect.cursor() # 创建操作游标 + # 创建操作游标 + cursor = connect.cursor() + results = None try: # 执行SQL语句 cursor.execute(sql) @@ -95,6 +100,25 @@ def upload_data(sql): return results +def insert_data(entity, attributes, answer): + connect = pymysql.connect( + user="root", + password="123456", + host="127.0.0.1", + port=3306, + db="kb_qa", + charset="utf8" + ) + + # 创建操作游标 + cursor = connect.cursor() + sql = "INSERT INTO nlpccQA(entity, attribute, answer) VALUES (%s, %s, %s)" + cursor.execute(sql, (entity, attributes, answer)) + connect.commit() + cursor.close() + connect.close() + + if __name__ == '__main__': # create_db() # loaddata() @@ -102,4 +126,3 @@ def upload_data(sql): ret = upload_data(sql) print(list(ret)) - # \ No newline at end of file diff --git a/input/data/MyUtils.py b/input/data/MyUtils.py new file mode 100644 index 0000000..f5f10c8 --- /dev/null +++ b/input/data/MyUtils.py @@ -0,0 +1,23 @@ +from typing import Tuple +import re +import random + +def line_parse(line: str) -> Tuple[str, str]: + p1 = line.find('<') + p2 = line.find('>') + header = line[(p1+1):p2] + content = line[p2+1:].strip() + return header.split()[0], content + + +# 正则表达式 +pattern = re.compile('^-+') # 以-开头 + +def my_strip(s: str) -> str: + # 去首尾空白, 去除字符串之间空格 + s = s.strip().replace(' ', '') + # 去掉字符串开始处的 '-' + return 
re.sub(pattern, '', s) + + +shuffle_from_ = lambda l: random.choice(l) diff --git a/input/data/README.md b/input/data/README.md new file mode 100644 index 0000000..7d06fc4 --- /dev/null +++ b/input/data/README.md @@ -0,0 +1,29 @@ +# 数据处理 + +### 使用的数据集 + +NLPCC2016KBQA + +### 构建训练数据 + +1. [SplitData.py](./SplitData.py) +1. [ConstructDataSetNER.py](./ConstructDataSetNER.py) +1. [ConstructDatasetAttribute.py](./ConstructDatasetAttribute.py) + +记录下数据中最长的序列长度: + +```none +train.csv: max_len: 62 +dev.csv: max_len: 60 +test.csv: max_len: 62 +``` + +因此, 将最大长度设定为 64 较为合理. + +### 从数据构建知识三元组 + +[ConstructTriple.py](./ConstructTriple.py) + +### 将数据载入到数据库 + +[LoadDbData.py](./LoadDbData.py) (MySQL) diff --git a/input/data/SplitData.py b/input/data/SplitData.py new file mode 100644 index 0000000..ddad262 --- /dev/null +++ b/input/data/SplitData.py @@ -0,0 +1,44 @@ +import os +from Configurations import DATA_DIR + +""" +切分数据集; +原始的 nlpcc-iccpol-2016.kbqa.testing-data 有 9870 个样本 +原始的 nlpcc-iccpol-2016.kbqa.training-data 有 14609 个样本; + +nlpcc-iccpol-2016.kbqa.testing-data 中的数据对半分,一半变成验证集(dev.text),一半变成测试集(test.txt) +nlpcc-iccpol-2016.kbqa.training-data 保持不变,复制成为训练集 train.txt +""" + +file_name_list = ['nlpcc-iccpol-2016.kbqa.testing-data', 'nlpcc-iccpol-2016.kbqa.training-data'] + +# 文件处理 +for file_name in file_name_list: + file_path_name = os.path.join(DATA_DIR, file_name) + file = [] + with open(file_path_name,'r',encoding='utf-8') as f: + for line in f: + line = line.strip() + if line == '': + continue + file.append(line) + if 'training' in file_name: + with open(os.path.join(DATA_DIR,'train.txt') , "w", encoding='utf-8') as f: + f.write('\n'.join(file)) + elif 'testing' in file_name: + assert len(file) % 4 == 0 # 断言处理,错误时触发异常 + testing_num = len(file) // 4 # 一个样本是由 4 行构成的 + line_no = testing_num // 2 * 4 # 将测试集分出一半用作评估 + + # 测试数据 + with open(os.path.join(DATA_DIR, 'test.txt'), "w", encoding='utf-8') as f: + for line in file[:line_no]: + f.write(line + '\n') + + # 进行评估的数据 + with 
open(os.path.join(DATA_DIR, 'dev.txt'), "w", encoding='utf-8') as f: + for line in file[line_no:]: + f.write(line + '\n') + +print("Done") + diff --git a/sqlDialog.py b/sqlDialog.py new file mode 100644 index 0000000..a2bb5c3 --- /dev/null +++ b/sqlDialog.py @@ -0,0 +1,138 @@ +import datetime +import logging + +from PySide6.QtCore import Qt, Slot +from PySide6.QtSql import QSqlDatabase, QSqlQuery, QSqlRecord, QSqlTableModel +from PySide6.QtQml import QmlElement + +from ProjectTest import Model + +table_name = "Conversations" +QML_IMPORT_NAME = "ChatModel" +QML_IMPORT_MAJOR_VERSION = 1 +QML_IMPORT_MINOR_VERSION = 0 + +def createTable(): + if table_name in QSqlDatabase.database().tables(): + return + + query = QSqlQuery() + if not query.exec_( + """ + CREATE TABLE IF NOT EXISTS 'Conversations' ( + 'author' TEXT NOT NULL, + 'recipient' TEXT NOT NULL, + 'timestamp' TEXT NOT NULL, + 'message' TEXT NOT NULL, + FOREIGN KEY('author') REFERENCES Contacts ( name ), + FOREIGN KEY('recipient') REFERENCES Contacts ( name ) + ) + """ + ): + logging.error("Failed to query database") + + # This adds the first message from the Bot + # and further development is required to make it interactive. + query.exec_( + """ + INSERT INTO Conversations VALUES( + 'machine', 'Me', '2019-01-07T14:36:06', 'Hello!' 
+ ) + """ + ) + + +@QmlElement +class SqlConversationModel(QSqlTableModel): + def __init__(self, parent=None): + super(SqlConversationModel, self).__init__(parent) + + createTable() + self.setTable(table_name) + self.setSort(2, Qt.DescendingOrder) + self.setEditStrategy(QSqlTableModel.OnManualSubmit) + self.recipient = "" + + self.setRecipient('machine') + + self.select() + logging.debug("Table was loaded successfully.") + + self.model = Model() + + + def setRecipient(self, recipient): + if recipient == self.recipient: + pass + + self.recipient = recipient + + filter_str = ( + "(recipient = '{}' AND author = 'Me') OR " + "(recipient = 'Me' AND author = '{}')".format(self.recipient, self.recipient) + ) + + self.setFilter(filter_str) + self.select() + + def data(self, index, role): + if role < Qt.UserRole: + return QSqlTableModel.data(self, index, role) + + sql_record = QSqlRecord() + sql_record = self.record(index.row()) + + return sql_record.value(role - Qt.UserRole) + + def roleNames(self): + names = dict() + author = "author".encode() + recipient = "recipient".encode() + timestamp = "timestamp".encode() + message = "message".encode() + + names[hash(Qt.UserRole)] = author + names[hash(Qt.UserRole + 1)] = recipient + names[hash(Qt.UserRole + 2)] = timestamp + names[hash(Qt.UserRole + 3)] = message + + return names + + + + @Slot(str, str, str) + def send_message(self, recipient, message, author): + timestamp = datetime.datetime.now() + + new_record = self.record() + new_record.setValue('author', author) + new_record.setValue('recipient', recipient) + new_record.setValue('timestamp', str(timestamp)) + new_record.setValue('message', message) + + logging.debug(f'Message: "{message}" Received by: "{recipient}"') + + if not self.insertRecord(self.rowCount(), new_record): + logging.error(f'Failed to send message: {self.lastError().text()}') + return + + + new_record = self.record() + ans = self.model.query(message.strip()) + new_record.setValue('message', ans) + 
new_record.setValue('author', recipient) + new_record.setValue('recipient', author) + + timestamp = datetime.datetime.now() + datetime.timedelta(microseconds=1) + new_record.setValue('timestamp', str(timestamp)) + + if not self.insertRecord(self.rowCount(), new_record): + logging.error(f'Failed to send message: {self.lastError().text()}') + return + + self.submitAll() + self.select() + + + +