From af41ffc29c7932d900a933b2f75cb0eec134ed4a Mon Sep 17 00:00:00 2001 From: henry_23 Date: Sun, 2 Jan 2022 00:06:10 +0800 Subject: [PATCH 01/10] fix: rename files & add comments Co-authored-by: KisAaki <72684022+KisAaki@users.noreply.github.com> --- .gitignore | 3 + BERT_CRF.py => BERT_CRF_Model.py | 18 +--- CRF_Model.py | 17 ++-- test_NER.py => NERTest.py | 12 +-- NER_main.py => NERTrain.py | 99 +++++++++---------- test_pro.py => ProjectTest.py | 27 +++-- test_SIM.py => SIMTest.py | 7 +- SIM_main.py => SIMTrain.py | 21 ++-- input/data/.gitignore | 11 +++ ...ribute.py => ConstructDatasetAttribute.py} | 37 +++---- ..._dataset_ner.py => ConstructDatasetNer.py} | 11 ++- .../{5-triple_clean.py => ConstructTriple.py} | 18 ++-- .../data/{6-load_dbdata.py => LoadDbData.py} | 25 +++-- .../{4-print-seq-len.py => PrintSeqLen.py} | 8 +- input/data/README.md | 8 ++ input/data/{1_split_data.py => SplitData.py} | 2 +- 16 files changed, 184 insertions(+), 140 deletions(-) create mode 100644 .gitignore rename BERT_CRF.py => BERT_CRF_Model.py (93%) rename test_NER.py => NERTest.py (93%) rename NER_main.py => NERTrain.py (90%) rename test_pro.py => ProjectTest.py (94%) rename test_SIM.py => SIMTest.py (96%) rename SIM_main.py => SIMTrain.py (97%) create mode 100644 input/data/.gitignore rename input/data/{3-construct_dataset_attribute.py => ConstructDatasetAttribute.py} (61%) rename input/data/{2-construct_dataset_ner.py => ConstructDatasetNer.py} (88%) rename input/data/{5-triple_clean.py => ConstructTriple.py} (80%) rename input/data/{6-load_dbdata.py => LoadDbData.py} (82%) rename input/data/{4-print-seq-len.py => PrintSeqLen.py} (82%) create mode 100644 input/data/README.md rename input/data/{1_split_data.py => SplitData.py} (99%) diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..a3ce1a3 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +__pycache__/ + +*.log diff --git a/BERT_CRF.py b/BERT_CRF_Model.py similarity index 93% rename from BERT_CRF.py rename to 
BERT_CRF_Model.py index 8d59731..8e7fa07 100644 --- a/BERT_CRF.py +++ b/BERT_CRF_Model.py @@ -10,11 +10,6 @@ VOB_NAME = "bert-base-chinese-vocab.txt" - - - - - class BertCrf(nn.Module): def __init__(self,config_name:str,model_name:str = None,num_tags: int = 2, batch_first:bool = True) -> None: # 记录batch_first @@ -58,7 +53,6 @@ def __init__(self,config_name:str,model_name:str = None,num_tags: int = 2, batch self.crf_model = CRF(num_tags=num_tags,batch_first=batch_first) - def forward(self,input_ids:torch.Tensor, tags:torch.Tensor = None, attention_mask:Optional[torch.ByteTensor] = None, @@ -68,14 +62,12 @@ def forward(self,input_ids:torch.Tensor, emissions = self.bertModel(input_ids = input_ids,attention_mask = attention_mask,token_type_ids=token_type_ids)[0] - # 这里在seq_len的维度上去头,是去掉了[CLS],去尾巴有两种情况 - # 1、是 2、[SEP] - - + # 这里在seq_len的维度上去掉开头,是去掉了[CLS],去尾巴有两种情况 + # 第一种情况是 、第二种情况是 [SEP] new_emissions = emissions[:,1:-1] new_mask = attention_mask[:,2:].bool() - # 如果 tags 为 None,表示是一个预测的过程,不能求得loss,loss 直接为None + # 如果 tags 为 None,表示是一个预测的过程,不能求得 loss,则 loss 的值直接为 None if tags is None: loss = None pass @@ -83,8 +75,6 @@ def forward(self,input_ids:torch.Tensor, new_tags = tags[:, 1:-1] loss = self.crf_model(emissions=new_emissions, tags=new_tags, mask=new_mask, reduction=reduction) - - if decode: tag_list = self.crf_model.decode(emissions = new_emissions,mask = new_mask) return [loss, tag_list] @@ -92,5 +82,3 @@ def forward(self,input_ids:torch.Tensor, return [loss] - - diff --git a/CRF_Model.py b/CRF_Model.py index e9c3047..3e249e3 100644 --- a/CRF_Model.py +++ b/CRF_Model.py @@ -2,6 +2,9 @@ import torch import torch.nn as nn +""" +CRF 条件随机场模型; +""" class CRF(nn.Module): def __init__(self,num_tags : int = 2, batch_first:bool = True) -> None: @@ -10,7 +13,8 @@ def __init__(self,num_tags : int = 2, batch_first:bool = True) -> None: super().__init__() self.num_tags = num_tags self.batch_first = batch_first - # start 到其他tag(不包含end)的得分 + # start 到其他 tag (不包含 end) 的得分 + # 
(从开始节点到其他非 end 节点的 scores) self.start_transitions = nn.Parameter(torch.empty(num_tags)) # 到其他tag(不包含start)到end的得分 self.end_transitions = nn.Parameter(torch.empty(num_tags)) @@ -21,6 +25,7 @@ def __init__(self,num_tags : int = 2, batch_first:bool = True) -> None: self.reset_parameters() + # 对参数进行重新设置 def reset_parameters(self): init_range = 0.1 nn.init.uniform_(self.start_transitions,-init_range,init_range) @@ -30,6 +35,7 @@ def reset_parameters(self): def __repr__(self): return f'{self.__class__.__name__}(num_tags={self.num_tags})' + # 向前传播; def forward(self, emissions:torch.Tensor, tags:torch.Tensor = None, mask:Optional[torch.ByteTensor] = None, @@ -42,6 +48,7 @@ def forward(self, emissions:torch.Tensor, raise ValueError(f'invalid reduction {reduction}') if mask is None: + #生成值全为1的张量,用于掩码 mask = torch.ones_like(tags,dtype = torch.uint8) # a.shape (seq_len,batch_size) # a[0] shape ? batch_size @@ -81,10 +88,6 @@ def decode(self,emissions:torch.Tensor, return self._viterbi_decode(emissions,mask) - - - - def _validate(self, emissions:torch.Tensor, tags:Optional[torch.LongTensor] = None , @@ -146,7 +149,7 @@ def _computer_score(self, # 这里是为了获取每一个样本最后一个词的tag。 # shape: (batch_size,) 每一个batch 的真实长度 seq_ends = mask.long().sum(dim=0) - 1 - # 每个样本最火一个词的tag + # 每个样本最后一个词的tag last_tags = tags[seq_ends,torch.arange(batch_size)] # shape: (batch_size,) 每一个样本到最后一个词的得分加上之前的score score += self.end_transitions[last_tags] @@ -250,4 +253,4 @@ def _viterbi_decode(self,emissions : torch.FloatTensor , best_tags.reverse() best_tags_list.append(best_tags) - return best_tags_list \ No newline at end of file + return best_tags_list diff --git a/test_NER.py b/NERTest.py similarity index 93% rename from test_NER.py rename to NERTest.py index f6e1c09..33d5602 100644 --- a/test_NER.py +++ b/NERTest.py @@ -1,15 +1,15 @@ -from BERT_CRF import BertCrf +from BERT_CRF_Model import BertCrf from transformers import BertTokenizer -from NER_main import 
NerProcessor,statistical_real_sentences,flatten,CrfInputFeatures +from NERTrain import NerProcessor,statistical_real_sentences,flatten,CrfInputFeatures from torch.utils.data import DataLoader, RandomSampler,TensorDataset from sklearn.metrics import classification_report import torch import numpy as np from tqdm import tqdm, trange - - - +""" +对命名实体识别模型进行简单的测试 +""" processor = NerProcessor() tokenizer_inputs = () @@ -81,4 +81,4 @@ # # micro avg 0.996142 0.996142 0.996142 145137 # macro avg 0.994650 0.994380 0.994512 145137 -# weighted avg 0.996149 0.996142 0.996143 145137 \ No newline at end of file +# weighted avg 0.996149 0.996142 0.996143 145137 diff --git a/NER_main.py b/NERTrain.py similarity index 90% rename from NER_main.py rename to NERTrain.py index e256b00..993a272 100644 --- a/NER_main.py +++ b/NERTrain.py @@ -12,7 +12,7 @@ # 64 # --do_train # --train_batch_size -# 32 +# 16 # --eval_batch_size # 256 # --gradient_accumulation_steps @@ -20,7 +20,6 @@ # --num_train_epochs # 15 - import argparse import logging import codecs @@ -33,27 +32,24 @@ from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset from transformers import BertForSequenceClassification,BertTokenizer,BertConfig from transformers.data.processors.utils import DataProcessor, InputExample -from BERT_CRF import BertCrf -from transformers import AdamW, WarmupLinearSchedule +from BERT_CRF_Model import BertCrf +from transformers import AdamW, get_linear_schedule_with_warmup from sklearn.metrics import classification_report logger = logging.getLogger(__name__) -# # CRF_LABELS = ["O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "X", "[CLS]", "[SEP]"] -# 在这个项目中只需要识别三个类型的项目即可 -# 这里做以下测试,第一 LABELS = ["O", "B-LOC", "I-LOC"] ,因为需要预测的就只有这三个。 -# 第二 LABELS = ["O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "X", "[CLS]", "[SEP]"] +# 在项目中只需要识别三个类型的项目即可 +# LABELS = ["O", "B-LOC", "I-LOC"],需要预测的就只有这三个。 CRF_LABELS = ["O", "B-LOC", "I-LOC"] - def 
statistical_real_sentences(input_ids:torch.Tensor,mask:torch.Tensor,predict:list)-> list: # shape (batch_size,max_len) assert input_ids.shape == mask.shape # batch_size assert input_ids.shape[0] == len(predict) - # 第0位是[CLS] 最后一位是 或者 [SEP] + # 开头是 [CLS],结尾是 或者 [SEP] new_ids = input_ids[:,1:-1] new_mask = mask[:,2:] @@ -67,29 +63,21 @@ def statistical_real_sentences(input_ids:torch.Tensor,mask:torch.Tensor,predict: def flatten(inputs:list) -> list: result = [] + # 在列表末尾进行追加,以达到 flatten 的目的 [result.extend(line) for line in inputs] return result - - - - - - - - - - def set_seed(args): random.seed(args.seed) np.random.seed(args.seed) + # 为 CPU 设置种子用于生成随机数, 使得结果是确定的。 torch.manual_seed(args.seed) - class CrfInputExample(object): def __init__(self, guid, text, label=None): + # 初始化 self.guid = guid self.text = text self.label = label @@ -97,13 +85,13 @@ def __init__(self, guid, text, label=None): class CrfInputFeatures(object): def __init__(self, input_ids, attention_mask, token_type_ids, label): + # 初始化 self.input_ids = input_ids self.attention_mask = attention_mask self.token_type_ids = token_type_ids self.label = label - - +# 将语句序列化 def crf_convert_examples_to_features(examples,tokenizer, max_length=512, label_list=None, @@ -116,6 +104,7 @@ def crf_convert_examples_to_features(examples,tokenizer, features = [] for (ex_index, example) in enumerate(examples): + # 调用 BertTokenizer inputs = tokenizer.encode_plus( example.text, add_special_tokens=True, @@ -123,23 +112,20 @@ def crf_convert_examples_to_features(examples,tokenizer, truncate_first_sequence=True # We're truncating the first sequence in priority if True ) input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"] + # Masked 操作 attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids) - padding_length = max_length - len(input_ids) input_ids = input_ids + ([pad_token] * padding_length) attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length) 
token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length) - # 第一个和第二个[0] 加的是[CLS]和[SEP]的位置, [0]*padding_length是[pad] ,把这些都暂时算作"O",后面用mask 来消除这些,不会影响 + # 第一个和第二个[0] 加的是[CLS]和[SEP]的位置, [0]*padding_length是[pad] ,把这些都暂时算作"O",对 mask 没有影响 labels_ids = [0] + [label_map[l] for l in example.label] + [0] + [0]*padding_length - - assert len(input_ids) == max_length, "Error with input length {} vs {}".format(len(input_ids), max_length) assert len(attention_mask) == max_length, "Error with input length {} vs {}".format(len(attention_mask),max_length) assert len(token_type_ids) == max_length, "Error with input length {} vs {}".format(len(token_type_ids),max_length) - assert len(labels_ids) == max_length, "Error with input length {} vs {}".format(len(labels_ids),max_length) @@ -158,14 +144,17 @@ def crf_convert_examples_to_features(examples,tokenizer, class NerProcessor(DataProcessor): + # 获取训练集 def get_train_examples(self,data_dir): return self._create_examples( os.path.join(data_dir,"train.txt")) + # 获取验证集 def get_dev_examples(self, data_dir): return self._create_examples( os.path.join(data_dir, "dev.txt")) + # 获取测试集 def get_test_examples(self, data_dir): return self._create_examples( os.path.join(data_dir, "test.txt")) @@ -227,18 +216,22 @@ def load_and_cache_example(args,tokenizer,processor,data_type): features = crf_convert_examples_to_features(examples=examples,tokenizer=tokenizer,max_length=args.max_seq_length,label_list=label_list) logger.info("Saving features into cached file %s", cached_features_file) torch.save(features, cached_features_file) - + # 获取 input 的 ID all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) + # 获取 掩码 all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long) + # 获取 类型的 ID all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long) + # 获取标签 all_label = torch.tensor([f.label for f in features], dtype=torch.long) dataset = 
TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_label) return dataset def trains(args,train_dataset,eval_dataset,model): - + # RandomSampler, 随机采样器 train_sampler = RandomSampler(train_dataset) + train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size) t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs @@ -250,8 +243,8 @@ def trains(args,train_dataset,eval_dataset,model): {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} ] optimizer = AdamW(optimizer_grouped_parameters,lr=args.learning_rate,eps=args.adam_epsilon) - - scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total) + # 学习率预热,在实验开始提高,后面逐渐下降; + scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps = t_total) logger.info("***** Running training *****") logger.info(" Num examples = %d", len(train_dataset)) logger.info(" Num Epochs = %d", args.num_train_epochs) @@ -265,6 +258,7 @@ def trains(args,train_dataset,eval_dataset,model): set_seed(args) best_f1 = 0. 
for _ in train_iterator: + # 进度条,用于显示整个实验的进度 epoch_iterator = tqdm(train_dataloader, desc="Iteration") for step,batch in enumerate(epoch_iterator): batch = tuple(t.to(args.device) for t in batch) @@ -279,7 +273,9 @@ def trains(args,train_dataset,eval_dataset,model): if args.gradient_accumulation_steps > 1: loss = loss / args.gradient_accumulation_steps + # 反向回溯 loss.backward() + # 进行梯度截断 torch.nn.utils.clip_grad_norm_(model.parameters(),args.max_grad_norm) logging_loss += loss.item() tr_loss += loss.item() @@ -292,9 +288,9 @@ def trains(args,train_dataset,eval_dataset,model): logging_loss) logging_loss = 0.0 - # if (global_step < 100 and global_step % 10 == 0) or (global_step % 50 == 0): - # 每 相隔 100步,评估一次 - if global_step % 100 == 0: + # 每相隔 100步,评估一次 + if global_step % 100 == 0: + #评估并保存模型 best_f1 = evaluate_and_save_model(args,model,eval_dataset,_,global_step,best_f1) # 最后循环结束 再评估一次 @@ -305,6 +301,7 @@ def trains(args,train_dataset,eval_dataset,model): def evaluate_and_save_model(args,model,eval_dataset,epoch,global_step,best_f1): ret = evaluate(args, model, eval_dataset) + print(ret) precision_b = ret['1']['precision'] recall_b = ret['1']['recall'] @@ -323,9 +320,9 @@ def evaluate_and_save_model(args,model,eval_dataset,epoch,global_step,best_f1): avg_recall = recall_b * weight_b + recall_i * weight_i avg_f1 = f1_b * weight_b + f1_i * weight_i - all_avg_precision = ret['micro avg']['precision'] - all_avg_recall = ret['micro avg']['recall'] - all_avg_f1 = ret['micro avg']['f1-score'] + all_avg_precision = ret['macro avg']['precision'] + all_avg_recall = ret['macro avg']['recall'] + all_avg_f1 = ret['macro avg']['f1-score'] logger.info("Evaluating EPOCH = [%d/%d] global_step = %d", epoch+1,args.num_train_epochs,global_step) logger.info("B-LOC precision = %f recall = %f f1 = %f support = %d", precision_b, recall_b, f1_b, @@ -339,6 +336,7 @@ def evaluate_and_save_model(args,model,eval_dataset,epoch,global_step,best_f1): all_avg_f1) if avg_f1 > best_f1: + 
#若当前的模型比历史最优模型要好 best_f1 = avg_f1 torch.save(model.state_dict(), os.path.join(args.output_dir, "best_ner.bin")) logging.info("save the best model %s,avg_f1= %f", os.path.join(args.output_dir, "best_bert.bin"), @@ -347,10 +345,6 @@ def evaluate_and_save_model(args,model,eval_dataset,epoch,global_step,best_f1): return best_f1 - - - - def evaluate(args, model, eval_dataset): eval_output_dirs = args.output_dir @@ -380,8 +374,7 @@ def evaluate(args, model, eval_dataset): 'reduction':'none' } outputs = model(**inputs) - # temp_eval_loss shape: (batch_size) - # temp_pred : list[list[int]] 长度不齐 + temp_eval_loss, temp_pred = outputs[0], outputs[1] loss.extend(temp_eval_loss.tolist()) @@ -398,21 +391,18 @@ def evaluate(args, model, eval_dataset): return ret - - - - def main(): + #argparse 命令行解析的标准模块, 在命令行中传入参数后让程序运行 parser = argparse.ArgumentParser() - parser.add_argument("--data_dir", default=None, type=str, required=True, + parser.add_argument("--data_dir", default=None, type=str, help="数据文件目录,因当有train.txt dev.txt") - parser.add_argument("--vob_file", default=None, type=str, required=True, + parser.add_argument("--vob_file", default=None, type=str, help="词表文件") - parser.add_argument("--model_config", default=None, type=str, required=True, + parser.add_argument("--model_config", default=None, type=str, help="模型配置文件json文件") - parser.add_argument("--output_dir", default=None, type=str, required=True, + parser.add_argument("--output_dir", default=None, type=str, help="输出结果的文件") # Other parameters @@ -451,10 +441,10 @@ def main(): logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO) - # filename='./output/bert-crf-ner.log', + processor = NerProcessor() - # 得到tokenizer + # 得到 tokenizer tokenizer_inputs = () tokenizer_kwards = {'do_lower_case': False, 'max_len': args.max_seq_length, @@ -476,4 +466,5 @@ def main(): if __name__ == '__main__': + torch.cuda.empty_cache() main() diff --git a/test_pro.py 
b/ProjectTest.py similarity index 94% rename from test_pro.py rename to ProjectTest.py index ab91cb2..58572bb 100644 --- a/test_pro.py +++ b/ProjectTest.py @@ -1,22 +1,25 @@ -from BERT_CRF import BertCrf -from NER_main import NerProcessor, CRF_LABELS -from SIM_main import SimProcessor,SimInputFeatures +from BERT_CRF_Model import BertCrf +from NERTrain import NerProcessor, CRF_LABELS +from SIMTrain import SimProcessor,SimInputFeatures from transformers import BertTokenizer, BertConfig, BertForSequenceClassification from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset import torch import pymysql from tqdm import tqdm, trange +import WikiQuery - +# 载入 GPU,没有则使用 CPU device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +# 获取之前训练的 NER 模型 def get_ner_model(config_file,pre_train_model,label_num = 2): model = BertCrf(config_name=config_file,num_tags=label_num, batch_first=True) model.load_state_dict(torch.load(pre_train_model)) return model.to(device) +# 获取之前训练的 SIM 模型 def get_sim_model(config_file,pre_train_model,label_num = 2): bert_config = BertConfig.from_pretrained(config_file) bert_config.num_labels = label_num @@ -25,6 +28,7 @@ def get_sim_model(config_file,pre_train_model,label_num = 2): return model +# 从一个句子中获取实体 def get_entity(model,tokenizer,sentence,max_len = 64): pad_token = 0 sentence_list = list(sentence.strip().replace(' ','')) @@ -86,6 +90,7 @@ def get_entity(model,tokenizer,sentence,max_len = 64): return "".join(entity_list) +# 语义匹配 def semantic_matching(model,tokenizer,question,attribute_list,answer_list,max_length): assert len(attribute_list) == len(answer_list) @@ -157,7 +162,7 @@ def semantic_matching(model,tokenizer,question,attribute_list,answer_list,max_le def select_database(sql): - # connect database + # 连接数据库 connect = pymysql.connect(user="root",password="123456",host="127.0.0.1",port=3306,db="kb_qa",charset="utf8") cursor = connect.cursor() # 创建操作游标 try: @@ -174,7 +179,7 @@ def 
select_database(sql): return results -# 文字直接匹配,看看属性的词语在不在句子之中 +# 文字直接匹配,看看属性的词是否在句子中 def text_match(attribute_list,answer_list,sentence): assert len(attribute_list) == len(answer_list) @@ -190,7 +195,6 @@ def text_match(attribute_list,answer_list,sentence): return "","" - def main(): with torch.no_grad(): @@ -219,24 +223,30 @@ def main(): print("====="*10) raw_text = input("问题:\n") raw_text = raw_text.strip() + # 输入 quit 则退出。 if ( "quit" == raw_text ): print("quit") return + # 获取实体 entity = get_entity(model=ner_model, tokenizer=tokenizer, sentence=raw_text, max_len=64) print("实体:", entity) if '' == entity: print("未发现实体") continue sql_str = "select * from nlpccqa where entity = '{}'".format(entity) + triple_list = select_database(sql_str) triple_list = list(triple_list) + print(triple_list) if 0 == len(triple_list): print("未找到 {} 相关信息".format(entity)) continue triple_list = list(zip(*triple_list)) - # print(triple_list) + print(triple_list) + attribute_list = triple_list[1] answer_list = triple_list[2] + # 直接进行匹配 attribute, answer = text_match(attribute_list, answer_list, raw_text) if attribute != '' and answer != '': ret = "{}的{}是{}".format(entity, attribute, answer) @@ -247,6 +257,7 @@ def main(): sim_model = sim_model.to(device) sim_model.eval() + # 进行语义匹配 attribute_idx = semantic_matching(sim_model, tokenizer, raw_text, attribute_list, answer_list, 64).item() if -1 == attribute_idx: ret = '' diff --git a/test_SIM.py b/SIMTest.py similarity index 96% rename from test_SIM.py rename to SIMTest.py index 9aaa9c2..1d9929d 100644 --- a/test_SIM.py +++ b/SIMTest.py @@ -1,9 +1,14 @@ from transformers import BertConfig, BertForSequenceClassification, BertTokenizer from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset -from SIM_main import SimProcessor,SimInputFeatures,cal_acc +from SIMTrain import SimProcessor,SimInputFeatures,cal_acc import torch from tqdm import tqdm, trange +""" +对属性相似度模型进行测试 +""" + + device = torch.device("cuda:0" if 
torch.cuda.is_available() else "cpu") processor = SimProcessor() tokenizer_inputs = () diff --git a/SIM_main.py b/SIMTrain.py similarity index 97% rename from SIM_main.py rename to SIMTrain.py index 83272cb..376ebbc 100644 --- a/SIM_main.py +++ b/SIMTrain.py @@ -33,13 +33,11 @@ import torch.nn as nn from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset -from transformers import AdamW, WarmupLinearSchedule +from transformers import AdamW, get_linear_schedule_with_warmup from transformers import BertConfig, BertForSequenceClassification, BertTokenizer -from transformers import glue_convert_examples_to_features as convert_examples_to_features from transformers.data.processors.utils import DataProcessor, InputExample import numpy as np -import pandas as pd logger = logging.getLogger(__name__) @@ -76,9 +74,8 @@ def cal_acc(real_label,pred_label): - - class SimInputExample(object): + # 进行初始化 def __init__(self, guid, question,attribute, label=None): self.guid = guid self.question = question @@ -87,6 +84,7 @@ def __init__(self, guid, question,attribute, label=None): class SimInputFeatures(object): + # 进行初始化 def __init__(self, input_ids, attention_mask, token_type_ids, label = None): self.input_ids = input_ids self.attention_mask = attention_mask @@ -98,17 +96,19 @@ class SimProcessor(DataProcessor): """Processor for the FAQ problem modified from https://github.com/huggingface/transformers/blob/master/transformers/data/processors/glue.py#L154 """ - + # 获取训练集 def get_train_examples(self, data_dir): logger.info("******* train ********") return self._create_examples( os.path.join(data_dir, "train.txt")) + # 获取验证集 def get_dev_examples(self, data_dir): logger.info("******* dev ********") return self._create_examples( os.path.join(data_dir, "dev.txt")) + # 获取测试集 def get_test_examples(self,data_dir): logger.info("******* test ********") return self._create_examples( @@ -133,7 +133,7 @@ def _create_examples(cls, path): f.close() return examples - 
+# 将文本进行序列化 def sim_convert_examples_to_features(examples,tokenizer, max_length=512, label_list=None, @@ -230,7 +230,7 @@ def trains(args,train_dataset,eval_dataset,model): ] optimizer = AdamW(optimizer_grouped_parameters,lr=args.learning_rate,eps=args.adam_epsilon) - scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total) + scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total) logger.info("***** Running training *****") logger.info(" Num examples = %d", len(train_dataset)) logger.info(" Num Epochs = %d", args.num_train_epochs) @@ -239,11 +239,14 @@ def trains(args,train_dataset,eval_dataset,model): global_step = 0 tr_loss, logging_loss = 0.0, 0.0 + # 梯度清零 model.zero_grad() train_iterator = trange(int(args.num_train_epochs), desc="Epoch") + # 设置随机数种子 set_seed(args) best_acc = 0. for _ in train_iterator: + # 进度条 epoch_iterator = tqdm(train_dataloader, desc="Iteration") for step,batch in enumerate(epoch_iterator): batch = tuple(t.to(args.device) for t in batch) @@ -257,6 +260,7 @@ def trains(args,train_dataset,eval_dataset,model): if args.gradient_accumulation_steps > 1: loss = loss / args.gradient_accumulation_steps loss.backward() + # 梯度截断 torch.nn.utils.clip_grad_norm_(model.parameters(),args.max_grad_norm) logging_loss += loss.item() tr_loss += loss.item() @@ -311,6 +315,7 @@ def evaluate(args, model, eval_dataset): total_sample_num = 0 # 样本总数目 all_real_label = [] # 记录所有的真实标签列表 all_pred_label = [] # 记录所有的预测标签列表 + for batch in tqdm(eval_dataloader, desc="Evaluating"): model.eval() batch = tuple(t.to(args.device) for t in batch) diff --git a/input/data/.gitignore b/input/data/.gitignore new file mode 100644 index 0000000..90ced5a --- /dev/null +++ b/input/data/.gitignore @@ -0,0 +1,11 @@ +# Files created by SplitData.py +NLPCC2016KBQA/*.txt + +# Files created by ConstructDatasetNer.py +ner_data/ + +# Files created by ConstructDatasetAttribute.py +sim_data/ + +# 
Files created by ConstructDB.py +DB_Data/ diff --git a/input/data/3-construct_dataset_attribute.py b/input/data/ConstructDatasetAttribute.py similarity index 61% rename from input/data/3-construct_dataset_attribute.py rename to input/data/ConstructDatasetAttribute.py index 0fac19c..fefad4a 100644 --- a/input/data/3-construct_dataset_attribute.py +++ b/input/data/ConstructDatasetAttribute.py @@ -1,5 +1,6 @@ -# coding:utf-8 -import sys +#-*-coding:UTF-8 -*- + + import os import random import pandas as pd @@ -8,7 +9,6 @@ ''' 通过 ner_data 中的数据 构建出 用来匹配句子相似度的 样本集合 构造属性关联训练集,分类问题,训练BERT分类模型 -1 ''' @@ -17,11 +17,11 @@ new_dir = 'sim_data' +# 正则表达式 pattern = re.compile('^-+') # 以-开头 - for file_name in file_name_list: file_path_name = os.path.join(data_dir,file_name) assert os.path.exists(file_path_name) @@ -29,25 +29,29 @@ attribute_classify_sample = [] df = pd.read_csv(file_path_name, encoding='utf-8') df['attribute'] = df['t_str'].apply(lambda x: x.split('|||')[1].strip()) - attribute_list = df['attribute'].tolist() # 转化成列表 - attribute_list = list(set(attribute_list)) # 去重 - attribute_list = [att.strip().replace(' ','') for att in attribute_list] # 去尾部,去空格 - attribute_list = [re.sub(pattern,'',att) for att in attribute_list] # 去掉 以-开头 - - attribute_list = list(set(attribute_list)) # 再去重 + # 将 DataFrame 数据类型转化为 List + attributes_list = df['attribute'].tolist() + # 通过列表set() 对其中的数据进行去重 + attributes_list = list(set(attributes_list)) + # 去尾部,去空格 + attributes_list = [att.strip().replace(' ','') for att in attributes_list] + # 去掉 以-开头 + attributes_list = [re.sub(pattern,'',att) for att in attributes_list] + # 再次去重 + attributes_list = list(set(attributes_list)) for row in df.index: question, pos_att = df.loc[row][['q_str', 'attribute']] - question = question.strip().replace(' ','') # 去尾部,空格 - question = re.sub(pattern, '', question) # 去掉 以-开头 + question = question.strip().replace(' ','') # 去尾部,空格 + question = re.sub(pattern, '', question) # 去掉 以-开头 - pos_att = 
pos_att.strip().replace(' ','') # 去尾部,空格 - pos_att = re.sub(pattern, '', pos_att) # 去掉 以-开头 + pos_att = pos_att.strip().replace(' ','') # 去尾部,空格 + pos_att = re.sub(pattern, '', pos_att) # 去掉 以-开头 neg_att_list = [] while True: - neg_att_list = random.sample(attribute_list, 5) + neg_att_list = random.sample(attributes_list, 5) if pos_att not in neg_att_list: break attribute_classify_sample.append([question, pos_att, '1']) @@ -56,13 +60,12 @@ attribute_classify_sample.extend(neg_att_sample) seq_result = [str(lineno) + '\t' + '\t'.join(line) for (lineno, line) in enumerate(attribute_classify_sample)] + #处理后将文件写在./input/sim_data 下 if not os.path.exists(new_dir): os.makedirs(new_dir) file_type = file_name.split('.')[0] - print("***** {} ******".format(file_type)) new_file_name = file_type + '.'+'txt' with open(os.path.join(new_dir,new_file_name), "w", encoding='utf-8') as f: f.write("\n".join(seq_result)) - f.close() diff --git a/input/data/2-construct_dataset_ner.py b/input/data/ConstructDatasetNer.py similarity index 88% rename from input/data/2-construct_dataset_ner.py rename to input/data/ConstructDatasetNer.py index cd6e971..a641fb2 100644 --- a/input/data/2-construct_dataset_ner.py +++ b/input/data/ConstructDatasetNer.py @@ -40,15 +40,22 @@ if answer_str in line: a_str = line.strip() + # 新的一个问题 if start_str in line: # new question answer triple + # 根据 .txt 对信息进行提取 entities = t_str.split("|||")[0].split(">")[1].strip() q_str = q_str.split(">")[1].replace(" ", "").strip() + + # 若该实体已经存在 if entities in q_str: q_list = list(q_str) seq_q_list.extend(q_list) seq_q_list.extend([" "]) tag_list = ["O" for i in range(len(q_list))] tag_start_index = q_str.find(entities) + # B-IOC: 一个地名的开始 + # I-IOC:一个地名的中间部分 + # 其余为 O for i in range(tag_start_index, tag_start_index + len(entities)): if tag_start_index == i: tag_list[i] = "B-LOC" @@ -63,6 +70,8 @@ print('\t'.join(seq_tag_list[0:50])) print('\t'.join(seq_q_list[0:50])) seq_result = [str(q) + " " + tag for q, tag in 
zip(seq_q_list, seq_tag_list)] + + # 将处理后的文件写在 ./input/NER_data 下 if not os.path.exists(new_dir): os.mkdir(new_dir) @@ -74,4 +83,4 @@ file_type = file_name.split('.')[0] csv_name = file_type+'.'+'csv' - df.to_csv(os.path.join(new_dir,csv_name), encoding='utf-8', index=False) \ No newline at end of file + df.to_csv(os.path.join(new_dir,csv_name), encoding='utf-8', index=False) diff --git a/input/data/5-triple_clean.py b/input/data/ConstructTriple.py similarity index 80% rename from input/data/5-triple_clean.py rename to input/data/ConstructTriple.py index 988a5a7..65b07a2 100644 --- a/input/data/5-triple_clean.py +++ b/input/data/ConstructTriple.py @@ -10,7 +10,7 @@ ''' -构造NER训练集,实体序列标注,训练BERT+BiLSTM+CRF +构造NER训练集,实体序列标注,用于训练BERT+BiLSTM+CRF ''' question_str = "")[1].strip() q_str = q_str.split(">")[1].replace(" ","").strip() if ''.join(entities.split(' ')) in q_str: @@ -44,7 +47,10 @@ print(q_str) print('------------------------') +# 三元组:实体,属性,答案 df = pd.DataFrame(triple_list, columns=["entity", "attribute", "answer"]) print(df) print(df.info()) -df.to_csv("./DB_Data/clean_triple.csv", encoding='utf-8', index=False) \ No newline at end of file + +# 处理完后将文件写在 ./input/DB_Data 下 +df.to_csv("./DB_Data/clean_triple.csv", encoding='utf-8', index=False) diff --git a/input/data/6-load_dbdata.py b/input/data/LoadDbData.py similarity index 82% rename from input/data/6-load_dbdata.py rename to input/data/LoadDbData.py index 2ae4eef..2ec6131 100644 --- a/input/data/6-load_dbdata.py +++ b/input/data/LoadDbData.py @@ -11,8 +11,9 @@ from sqlalchemy import create_engine +# 创建数据库,默认使用 root 用户 def create_db(): - connect = pymysql.connect( # 连接数据库服务器-*-*- + connect = pymysql.connect( # 创建连接,若账号密码不同请记得修改 user="root", password="123456", host="127.0.0.1", @@ -21,30 +22,25 @@ def create_db(): charset="utf8" ) conn = connect.cursor() # 创建操作游标 - # 你需要一个游标 来实现对数据库的操作相当于一条线索 - # 创建表 - conn.execute("drop database if exists KB_QA") # 如果new_database数据库存在则删除 + conn.execute("drop database if 
exists KB_QA") # 如果 KB_QA 数据库存在则删除 conn.execute("create database KB_QA") # 新创建一个数据库 - conn.execute("use KB_QA") # 选择new_database这个数据库 + conn.execute("use KB_QA") # 选择使用 KB_QA 数据库 conn.execute("SET @@global.sql_mode=''") - # sql 中的内容为创建一个名为 new_table 的表 + # sql 中的内容为创建一个名为 nlpccQA 的表 sql = """create table nlpccQA(entity VARCHAR(50) character set utf8 collate utf8_unicode_ci, attribute VARCHAR(50) character set utf8 collate utf8_unicode_ci, answer VARCHAR(255) character set utf8 collate utf8_unicode_ci)""" # ()中的参数可以自行设置 conn.execute("drop table if exists nlpccQA") # 如果表存在则删除 conn.execute(sql) # 创建表 - # 删除 - # conn.execute("drop table new_table") - - conn.close() # 关闭游标连接 - connect.close() # 关闭数据库服务器连接 释放内存 + conn.close() # 关闭游标 + connect.close() # 关闭与数据库的连接 def loaddata(): - # 初始化数据库连接,使用pymysql模块 + # 使用 pymysql,与数据库进行连接,同样,注意用户名和密码的设置。 db_info = {'user': 'root', 'password': '123456', 'host': '127.0.0.1', @@ -60,7 +56,7 @@ def loaddata(): # 读取本地CSV文件 df = pd.read_csv("./DB_Data/clean_triple.csv", sep=',', encoding='utf-8') - print(df) + # 将新建的DataFrame储存为MySQL中的数据表,不储存index列(index=False) # if_exists: # 1.fail:如果表存在,啥也不做 @@ -81,6 +77,7 @@ def upload_data(sql): charset="utf8" ) cursor = connect.cursor() # 创建操作游标 + results = None try: # 执行SQL语句 cursor.execute(sql) @@ -102,4 +99,4 @@ def upload_data(sql): ret = upload_data(sql) print(list(ret)) - # \ No newline at end of file + # diff --git a/input/data/4-print-seq-len.py b/input/data/PrintSeqLen.py similarity index 82% rename from input/data/4-print-seq-len.py rename to input/data/PrintSeqLen.py index 800e71d..71a37f8 100644 --- a/input/data/4-print-seq-len.py +++ b/input/data/PrintSeqLen.py @@ -1,5 +1,9 @@ import os +""" +查看整个句子的长度 +""" + dir_name = 'sim_data' file_list = ['train.txt','dev.txt','test.txt'] @@ -14,8 +18,8 @@ line_list = line.split('\t') question = list(line_list[1]) - attribute = list(line_list[2]) - add_len = len(question) + len(attribute) + attributes = list(line_list[2]) + add_len = len(question) + 
len(attributes) if add_len > max_len: max_len = add_len print("max_len",max_len) diff --git a/input/data/README.md b/input/data/README.md new file mode 100644 index 0000000..34b05cb --- /dev/null +++ b/input/data/README.md @@ -0,0 +1,8 @@ +# 数据处理 + +1. [SplitData.py](./SplitData.py) +1. [ConstructDataSetNER](./ConstructDataSetNER.py) +1. [ConstructDatasetAttribute.py](./ConstructDatasetAttribute.py) +1. [PrintSeqLen.py](./PrintSeqLen.py) +1. [ConstructTriple.py](./ConstructTriple.py) +1. [LoadDbData.py](./LoadDbData.py) diff --git a/input/data/1_split_data.py b/input/data/SplitData.py similarity index 99% rename from input/data/1_split_data.py rename to input/data/SplitData.py index 73f3066..0d4f1c1 100644 --- a/input/data/1_split_data.py +++ b/input/data/SplitData.py @@ -15,7 +15,7 @@ data_dir = 'NLPCC2016KBQA' file_name_list = ['nlpcc-iccpol-2016.kbqa.testing-data','nlpcc-iccpol-2016.kbqa.training-data'] - +#文件处理 for file_name in file_name_list: file_path_name = os.path.join(data_dir,file_name) file = [] From a4d7b63a2accf24ddd92585234ee3223dc869e79 Mon Sep 17 00:00:00 2001 From: henry_23 Date: Sun, 2 Jan 2022 00:07:32 +0800 Subject: [PATCH 02/10] feat: do online query when entity not found Co-authored-by: KisAaki <72684022+KisAaki@users.noreply.github.com> --- ProjectTest.py | 13 ++++++++++ WikiQuery.py | 51 ++++++++++++++++++++++++++++++++++++++++ input/data/LoadDbData.py | 19 +++++++++++++++ 3 files changed, 83 insertions(+) create mode 100644 WikiQuery.py diff --git a/ProjectTest.py b/ProjectTest.py index 58572bb..22575c4 100644 --- a/ProjectTest.py +++ b/ProjectTest.py @@ -240,6 +240,17 @@ def main(): print(triple_list) if 0 == len(triple_list): print("未找到 {} 相关信息".format(entity)) + + print("正在通过网络查找中...") + WikiQuery.getInfobox(entity) + # if len(elem_dic) != 0: + # for key in elem_dic: + # #print(key.text, elem_dic[key].text) + # if len(elem_dic[key].text) <= 10: + # insert_data(entity, key.text, elem_dic[key].text) + + print("查找完毕") + continue triple_list 
= list(zip(*triple_list)) print(triple_list) @@ -267,6 +278,8 @@ def main(): ret = "{}的{}是{}".format(entity, attribute, answer) if '' == ret: print("未找到{}相关信息".format(entity)) + print("正在通过网络查找中………………………………") + WikiQuery.getInfobox(entity) else: print("回答:",ret) diff --git a/WikiQuery.py b/WikiQuery.py new file mode 100644 index 0000000..00d5d3b --- /dev/null +++ b/WikiQuery.py @@ -0,0 +1,51 @@ +import time +from selenium import webdriver +from selenium.webdriver.common.keys import Keys +from input.data.LoadDbData import insert_data + +# getInfobox函数 +def getInfobox(name): + try: + # 访问百度百科并自动搜索 + driver = webdriver.Firefox() + driver.get("http://baike.baidu.com/") + elem_inp = driver.find_element_by_xpath("//form[@id='searchForm']/input") + elem_inp.send_keys(name) + elem_inp.send_keys(Keys.RETURN) + time.sleep(1) + print(driver.current_url) + print(driver.title) + + # 爬取消息盒InfoBox内容 + elem_name = driver.find_elements_by_xpath("//div[@class='basic-info J-basic-info cmn-clearfix']/dl/dt") + elem_value = driver.find_elements_by_xpath("//div[@class='basic-info J-basic-info cmn-clearfix']/dl/dd") + + # for e in elem_name: + # print(e.text) + # for e in elem_value: + # print(e.text) + + + # 构建字段成对输出 + elem_dic = dict(zip(elem_name, elem_value)) + for key in elem_dic: + print(key.text, elem_dic[key].text) + insert_data(name, key.text, elem_dic[key].text) + + return + + except Exception as e: + print("Error: ", e) + + finally: + print('\n') + driver.close() + + +# 主函数 + +def main(): + getInfobox("仙剑奇侠传") + +if __name__ == '__main__': + main() diff --git a/input/data/LoadDbData.py b/input/data/LoadDbData.py index 2ec6131..47fb49b 100644 --- a/input/data/LoadDbData.py +++ b/input/data/LoadDbData.py @@ -92,6 +92,25 @@ def upload_data(sql): return results +def insert_data(entity, attributes, answer): + connect = pymysql.connect( + user="root", + password="123456", + host="127.0.0.1", + port=3306, + db="kb_qa", + charset="utf8" + ) + + # 创建操作游标 + cursor = connect.cursor() + 
sql = "INSERT INTO nlpccQA(entity, attribute, answer) VALUES (%s, %s, %s)" + cursor.execute(sql, (entity, attributes, answer)) + connect.commit() + cursor.close() + connect.close() + + if __name__ == '__main__': # create_db() # loaddata() From 7afdb085cee8913099f51f8df0ee3aa4f23debab Mon Sep 17 00:00:00 2001 From: henry_23 Date: Sun, 2 Jan 2022 00:12:41 +0800 Subject: [PATCH 03/10] fix: ignore binary files --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index a3ce1a3..663bdb5 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ __pycache__/ *.log +*.bin From e98ca1bc746ad241edd9be236d01901547a8817e Mon Sep 17 00:00:00 2001 From: henry_23 Date: Sun, 2 Jan 2022 00:16:13 +0800 Subject: [PATCH 04/10] feat: simple GUI with pyside6 developed from code provided in Qt's official guide --- .gitignore | 2 + chat.qml | 100 +++++++++++++++++++++++++++++++++++++ main.py | 62 +++++++++++++++++++++++ sqlDialog.py | 137 +++++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 301 insertions(+) create mode 100644 chat.qml create mode 100644 main.py create mode 100644 sqlDialog.py diff --git a/.gitignore b/.gitignore index 663bdb5..62fff29 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,5 @@ __pycache__/ *.log *.bin + +*.sqlite3 diff --git a/chat.qml b/chat.qml new file mode 100644 index 0000000..dc9ad1e --- /dev/null +++ b/chat.qml @@ -0,0 +1,100 @@ +import QtQuick +import QtQuick.Layouts +import QtQuick.Controls + +import ChatModel 1.0 + +ApplicationWindow { + id: root + title: qsTr("Chat") + width: 540 + height: 720 + visible: true + + SqlConversationModel { + id: chat_model + } + + ColumnLayout { + anchors.fill: parent + + ListView { + id: listView + Layout.fillWidth: true + Layout.fillHeight: true + Layout.margins: pane.leftPadding + messageField.leftPadding + displayMarginBeginning: 40 + displayMarginEnd: 40 + verticalLayoutDirection: ListView.BottomToTop + spacing: 12 + model: chat_model + delegate: Column { + + 
readonly property bool sentByMe: model.recipient !== "Me" + + anchors.right: sentByMe ? listView.contentItem.right : undefined + spacing: 6 + + Row { + id: messageRow + spacing: 6 + anchors.right: sentByMe ? parent.right : undefined + + Rectangle { // Message Blob + width: Math.min(messageText.implicitWidth + 24, + listView.width - (!sentByMe ? messageRow.spacing : 0)) + height: messageText.implicitHeight + 24 + radius: 15 + color: sentByMe ? "steelblue" : "lightgrey" + + Label { + id: messageText + text: model.message + color: sentByMe ? "white" : "black" + anchors.fill: parent + anchors.margins: 12 + wrapMode: Label.Wrap + } + } + } + + /*Label { + id: timestampText + text: Qt.formatDataTime(model.timestamp, "d MM hh:mm") + color: "lightgrey" + anchors.right: sentByMe ? parent.right : undefined + }*/ + } + + ScrollBar.vertical: ScrollBar {} + } + + Pane { + id: pane + Layout.fillWidth: true + + RowLayout { + width: parent.width + + TextArea { + id: messageField + Layout.fillWidth: true + placeholderText: qsTr("Compose message") + wrapMode: TextArea.Wrap + } + + Button { + id: sendButton + text: qsTr("Send") + enabled: messageField.length > 0 + onClicked: { + listView.model.send_message("machine", messageField.text, "Me"); + messageField.text = ""; + } + } + } + } + } +} + + diff --git a/main.py b/main.py new file mode 100644 index 0000000..9239402 --- /dev/null +++ b/main.py @@ -0,0 +1,62 @@ +import datetime +import logging + +import sys + +from PySide6.QtCore import QDir, QFile, QObject, QUrl +from PySide6.QtGui import QGuiApplication +from PySide6.QtQml import QQmlApplicationEngine +from PySide6.QtSql import QSqlDatabase + +import sqlDialog + +logging.basicConfig(filename='chat.log', level=logging.DEBUG) +logger = logging.getLogger('logger') + +""" +class Message(QObject): + author = "Me" + text = "" + def __init__(self, author_, text_, parent=None): + super.__init__(parent) + self.author = author_ + self.text = text_ + + +class DialogModel: + dialog_histroy = 
[] +""" + +def connectToDatabase(): + database = QSqlDatabase.database() + if not database.isValid(): + database = QSqlDatabase.addDatabase("QSQLITE") + if not database.isValid(): + logger.error("Cannot add database") + + write_dir = QDir("") + if not write_dir.mkpath("."): + logger.error("Failed to create writable directory") + + # Ensure that we have a writable location on all devices. + abs_path = write_dir.absolutePath() + filename = f"{abs_path}/chat-database.sqlite3" + + # When using the SQLite driver, open() will create the SQLite + # database if it doesn't exist. + database.setDatabaseName(filename) + if not database.open(): + logger.error("Cannot open database") + QFile.remove(filename) + +if __name__ == "__main__": + app = QGuiApplication() + connectToDatabase() + + engine = QQmlApplicationEngine() + engine.load(QUrl("chat.qml")) + + if not engine.rootObjects(): + sys.exit(-1) + + app.exec() diff --git a/sqlDialog.py b/sqlDialog.py new file mode 100644 index 0000000..025c043 --- /dev/null +++ b/sqlDialog.py @@ -0,0 +1,137 @@ +import datetime +import logging + +from PySide6.QtCore import Qt, Slot +from PySide6.QtSql import QSqlDatabase, QSqlQuery, QSqlRecord, QSqlTableModel +from PySide6.QtQml import QmlElement + +table_name = "Conversations" +QML_IMPORT_NAME = "ChatModel" +QML_IMPORT_MAJOR_VERSION = 1 +QML_IMPORT_MINOR_VERSION = 0 + +def createTable(): + if table_name in QSqlDatabase.database().tables(): + return + + query = QSqlQuery() + if not query.exec_( + """ + CREATE TABLE IF NOT EXISTS 'Conversations' ( + 'author' TEXT NOT NULL, + 'recipient' TEXT NOT NULL, + 'timestamp' TEXT NOT NULL, + 'message' TEXT NOT NULL, + FOREIGN KEY('author') REFERENCES Contacts ( name ), + FOREIGN KEY('recipient') REFERENCES Contacts ( name ) + ) + """ + ): + logging.error("Failed to query database") + + # This adds the first message from the Bot + # and further development is required to make it interactive. 
+ query.exec_( + """ + INSERT INTO Conversations VALUES( + 'machine', 'Me', '2019-01-07T14:36:06', 'Hello!' + ) + """ + ) + + +def get_answer(question): + + return "You said '" + question + "'" + +@QmlElement +class SqlConversationModel(QSqlTableModel): + def __init__(self, parent=None): + super(SqlConversationModel, self).__init__(parent) + + createTable() + self.setTable(table_name) + self.setSort(2, Qt.DescendingOrder) + self.setEditStrategy(QSqlTableModel.OnManualSubmit) + self.recipient = "" + + self.setRecipient('machine') + + self.select() + logging.debug("Table was loaded successfully.") + + + def setRecipient(self, recipient): + if recipient == self.recipient: + pass + + self.recipient = recipient + + filter_str = ( + "(recipient = '{}' AND author = 'Me') OR " + "(recipient = 'Me' AND author = '{}')".format(self.recipient, self.recipient) + ) + + self.setFilter(filter_str) + self.select() + + def data(self, index, role): + if role < Qt.UserRole: + return QSqlTableModel.data(self, index, role) + + sql_record = QSqlRecord() + sql_record = self.record(index.row()) + + return sql_record.value(role - Qt.UserRole) + + def roleNames(self): + names = dict() + author = "author".encode() + recipient = "recipient".encode() + timestamp = "timestamp".encode() + message = "message".encode() + + names[hash(Qt.UserRole)] = author + names[hash(Qt.UserRole + 1)] = recipient + names[hash(Qt.UserRole + 2)] = timestamp + names[hash(Qt.UserRole + 3)] = message + + return names + + + + @Slot(str, str, str) + def send_message(self, recipient, message, author): + timestamp = datetime.datetime.now() + + new_record = self.record() + new_record.setValue('author', author) + new_record.setValue('recipient', recipient) + new_record.setValue('timestamp', str(timestamp)) + new_record.setValue('message', message) + + logging.debug(f'Message: "{message}" Received by: "{recipient}"') + + if not self.insertRecord(self.rowCount(), new_record): + logging.error(f'Failed to send message: 
{self.lastError().text()}') + return + + + new_record = self.record() + new_record.setValue('message', get_answer(message)) + new_record.setValue('author', recipient) + new_record.setValue('recipient', author) + + timestamp = datetime.datetime.now() + datetime.timedelta(microseconds=1) + new_record.setValue('timestamp', str(timestamp)) + + if not self.insertRecord(self.rowCount(), new_record): + logging.error(f'Failed to send message: {self.lastError().text()}') + return + + self.submitAll() + self.select() + + + + From 7eb9746b9e104871b76bf5a87bd6596e4b0e37e7 Mon Sep 17 00:00:00 2001 From: henry_23 Date: Sun, 2 Jan 2022 00:23:00 +0800 Subject: [PATCH 05/10] fix: rename to chat.py --- main.py => chat.py | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) rename main.py => chat.py (81%) diff --git a/main.py b/chat.py similarity index 81% rename from main.py rename to chat.py index 9239402..41a5d4a 100644 --- a/main.py +++ b/chat.py @@ -13,19 +13,6 @@ logging.basicConfig(filename='chat.log', level=logging.DEBUG) logger = logging.getLogger('logger') -""" -class Message(QObject): - author = "Me" - text = "" - def __init__(self, author_, text_, parent=None): - super.__init__(parent) - self.author = author_ - self.text = text_ - - -class DialogModel: - dialog_histroy = [] -""" def connectToDatabase(): database = QSqlDatabase.database() @@ -49,6 +36,7 @@ def connectToDatabase(): logger.error("Cannot open database") QFile.remove(filename) + if __name__ == "__main__": app = QGuiApplication() connectToDatabase() From 52c8d82f80820d0952ad130bfb21a1b857370bef Mon Sep 17 00:00:00 2001 From: henry_23 Date: Sun, 2 Jan 2022 00:48:44 +0800 Subject: [PATCH 06/10] refactor: extract process logic as 'Model' --- ProjectTest.py | 153 ++++++++++++++++++++++++++++--------------------- 1 file changed, 87 insertions(+), 66 deletions(-) diff --git a/ProjectTest.py b/ProjectTest.py index 22575c4..6803382 100644 --- a/ProjectTest.py +++ b/ProjectTest.py @@ -195,9 +195,9 @@ def 
text_match(attribute_list,answer_list,sentence): return "","" -def main(): - - with torch.no_grad(): +class Model: + @torch.no_grad() + def __init__(self): tokenizer_inputs = () tokenizer_kwards = {'do_lower_case': False, 'max_len': 64, @@ -219,72 +219,93 @@ def main(): sim_model = sim_model.to(device) sim_model.eval() - while True: - print("====="*10) - raw_text = input("问题:\n") - raw_text = raw_text.strip() - # 输入 quit 则退出。 - if ( "quit" == raw_text ): - print("quit") - return - # 获取实体 - entity = get_entity(model=ner_model, tokenizer=tokenizer, sentence=raw_text, max_len=64) - print("实体:", entity) - if '' == entity: - print("未发现实体") - continue - sql_str = "select * from nlpccqa where entity = '{}'".format(entity) - - triple_list = select_database(sql_str) - triple_list = list(triple_list) - print(triple_list) - if 0 == len(triple_list): - print("未找到 {} 相关信息".format(entity)) - - print("正在通过网络查找中...") - WikiQuery.getInfobox(entity) - # if len(elem_dic) != 0: - # for key in elem_dic: - # #print(key.text, elem_dic[key].text) - # if len(elem_dic[key].text) <= 10: - # insert_data(entity, key.text, elem_dic[key].text) - - print("查找完毕") - - continue - triple_list = list(zip(*triple_list)) - print(triple_list) - - attribute_list = triple_list[1] - answer_list = triple_list[2] - # 直接进行匹配 - attribute, answer = text_match(attribute_list, answer_list, raw_text) - if attribute != '' and answer != '': - ret = "{}的{}是{}".format(entity, attribute, answer) - else: - sim_model = get_sim_model(config_file='./input/config/bert-base-chinese-config.json', - pre_train_model='./output/best_sim.bin', - label_num=len(sim_processor.get_labels())) - - sim_model = sim_model.to(device) - sim_model.eval() - # 进行语义匹配 - attribute_idx = semantic_matching(sim_model, tokenizer, raw_text, attribute_list, answer_list, 64).item() - if -1 == attribute_idx: - ret = '' - else: - attribute = attribute_list[attribute_idx] - answer = answer_list[attribute_idx] - ret = "{}的{}是{}".format(entity, attribute, 
answer) - if '' == ret: - print("未找到{}相关信息".format(entity)) - print("正在通过网络查找中………………………………") - WikiQuery.getInfobox(entity) + @torch.no_grad() + def query(self, raw_text) -> str: + # 获取实体 + entity = get_entity(model=self.ner_model, tokenizer=self.tokenizer, sentence=raw_text, max_len=64) + + print("实体:", entity) + + if '' == entity: + return "未发现实体" + + sql_str = "select * from nlpccqa where entity = '{}'".format(entity) + + triple_list = select_database(sql_str) + triple_list = list(triple_list) + + print(triple_list) + + if 0 == len(triple_list): + # 未找到相关信息 + + print("未找到 {} 相关信息".format(entity)) + + print("正在通过网络查找中...") + + WikiQuery.getInfobox(entity) + # if len(elem_dic) != 0: + # for key in elem_dic: + # #print(key.text, elem_dic[key].text) + # if len(elem_dic[key].text) <= 10: + # insert_data(entity, key.text, elem_dic[key].text) + + print("查找完毕") + + return "未找到 {} 相关信息,尝试通过网络查找...".format(entity) + + triple_list = list(zip(*triple_list)) + print(triple_list) + + attribute_list = triple_list[1] + answer_list = triple_list[2] + # 直接进行匹配 + attribute, answer = text_match(attribute_list, answer_list, raw_text) + if attribute != '' and answer != '': + ret = "{}的{}是{}".format(entity, attribute, answer) + else: + self.sim_model = get_sim_model(config_file='./input/config/bert-base-chinese-config.json', + pre_train_model='./output/best_sim.bin', + label_num=len(self.sim_processor.get_labels())) + + self.sim_model = self.sim_model.to(device) + self.sim_model.eval() + # 进行语义匹配 + attribute_idx = semantic_matching(self.sim_model, self.tokenizer, raw_text, attribute_list, answer_list, 64).item() + if -1 == attribute_idx: + ret = '' else: - print("回答:",ret) + attribute = attribute_list[attribute_idx] + answer = answer_list[attribute_idx] + ret = "{}的{}是{}".format(entity, attribute, answer) + + if '' == ret: + print("未找到{}相关信息".format(entity)) + print("正在通过网络查找中...") + WikiQuery.getInfobox(entity) + return "未找到 {} 相关信息,尝试通过网络查找...".format(entity) + else: + return ret + 
if __name__ == '__main__': - main() + + model = Model() + + while True: + + print("====="*10) + + raw_text = input("问题:\n").strip() + + # 输入 quit 则退出。 + if ( "quit" == raw_text ): + print("quit") + break + + ans = model.query(raw_text) + + print('回答:', ans) From 4ed268977e9584729013da7f05491564d79850ca Mon Sep 17 00:00:00 2001 From: henry_23 Date: Sun, 2 Jan 2022 00:49:47 +0800 Subject: [PATCH 07/10] feat: call Model to fetch answer for user's query --- sqlDialog.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/sqlDialog.py b/sqlDialog.py index 025c043..a2bb5c3 100644 --- a/sqlDialog.py +++ b/sqlDialog.py @@ -5,6 +5,8 @@ from PySide6.QtSql import QSqlDatabase, QSqlQuery, QSqlRecord, QSqlTableModel from PySide6.QtQml import QmlElement +from ProjectTest import Model + table_name = "Conversations" QML_IMPORT_NAME = "ChatModel" QML_IMPORT_MAJOR_VERSION = 1 @@ -40,10 +42,6 @@ def createTable(): ) -def get_answer(question): - - return "You said '" + question + "'" - @QmlElement class SqlConversationModel(QSqlTableModel): def __init__(self, parent=None): @@ -60,6 +58,8 @@ def __init__(self, parent=None): self.select() logging.debug("Table was loaded successfully.") + self.model = Model() + def setRecipient(self, recipient): if recipient == self.recipient: @@ -118,7 +118,8 @@ def send_message(self, recipient, message, author): new_record = self.record() - new_record.setValue('message', get_answer(message)) + ans = self.model.query(message.strip()) + new_record.setValue('message', ans) new_record.setValue('author', recipient) new_record.setValue('recipient', author) From 48dece8c3fd61164a11cfa3e5c1a1413754c58ea Mon Sep 17 00:00:00 2001 From: henry_23 Date: Sun, 26 Jun 2022 11:29:46 +0800 Subject: [PATCH 08/10] fix: improve code for 'ConstructDataset' --- input/data/Configurations.py | 4 + input/data/ConstructDatasetAttribute.py | 124 ++++++++++++------------ input/data/ConstructDatasetNer.py | 117 ++++++++++------------ input/data/MyUtils.py 
| 23 +++++ input/data/PrintSeqLen.py | 37 ------- input/data/README.md | 14 ++- input/data/SplitData.py | 58 +++++------ 7 files changed, 179 insertions(+), 198 deletions(-) create mode 100644 input/data/Configurations.py create mode 100644 input/data/MyUtils.py delete mode 100644 input/data/PrintSeqLen.py diff --git a/input/data/Configurations.py b/input/data/Configurations.py new file mode 100644 index 0000000..4fef354 --- /dev/null +++ b/input/data/Configurations.py @@ -0,0 +1,4 @@ +DATA_DIR = 'NLPCC2016KBQA' +NER_DATA_DIR = 'NER_Data' +SIM_DATD_DIR = 'SIM_Data' +DB_DATA_DIR = 'DB_data' diff --git a/input/data/ConstructDatasetAttribute.py b/input/data/ConstructDatasetAttribute.py index fefad4a..9bbca90 100644 --- a/input/data/ConstructDatasetAttribute.py +++ b/input/data/ConstructDatasetAttribute.py @@ -1,71 +1,73 @@ #-*-coding:UTF-8 -*- +from Configurations import NER_DATA_DIR, SIM_DATD_DIR +from MyUtils import my_strip, shuffle_from_ import os -import random -import pandas as pd -import re +import csv ''' 通过 ner_data 中的数据 构建出 用来匹配句子相似度的 样本集合 构造属性关联训练集,分类问题,训练BERT分类模型 +BERT: + 1. 预测 mask 概率; + 2. 预测 next sentence 概率. 
''' - -data_dir = 'ner_data' -file_name_list = ['train.csv','dev.csv','test.csv'] - -new_dir = 'sim_data' - -# 正则表达式 -pattern = re.compile('^-+') # 以-开头 - - - -for file_name in file_name_list: - file_path_name = os.path.join(data_dir,file_name) - assert os.path.exists(file_path_name) - - attribute_classify_sample = [] - df = pd.read_csv(file_path_name, encoding='utf-8') - df['attribute'] = df['t_str'].apply(lambda x: x.split('|||')[1].strip()) - # 将 DataFrame 数据类型转化为 List - attributes_list = df['attribute'].tolist() - # 通过列表set() 对其中的数据进行去重 - attributes_list = list(set(attributes_list)) - # 去尾部,去空格 - attributes_list = [att.strip().replace(' ','') for att in attributes_list] - # 去掉 以-开头 - attributes_list = [re.sub(pattern,'',att) for att in attributes_list] - # 再次去重 - attributes_list = list(set(attributes_list)) - - for row in df.index: - question, pos_att = df.loc[row][['q_str', 'attribute']] - - question = question.strip().replace(' ','') # 去尾部,空格 - question = re.sub(pattern, '', question) # 去掉 以-开头 - - pos_att = pos_att.strip().replace(' ','') # 去尾部,空格 - pos_att = re.sub(pattern, '', pos_att) # 去掉 以-开头 - - neg_att_list = [] - while True: - neg_att_list = random.sample(attributes_list, 5) - if pos_att not in neg_att_list: - break - attribute_classify_sample.append([question, pos_att, '1']) - - neg_att_sample = [[question, neg_att, '0'] for neg_att in neg_att_list] - attribute_classify_sample.extend(neg_att_sample) - seq_result = [str(lineno) + '\t' + '\t'.join(line) for (lineno, line) in enumerate(attribute_classify_sample)] - - #处理后将文件写在./input/sim_data 下 - if not os.path.exists(new_dir): - os.makedirs(new_dir) - - file_type = file_name.split('.')[0] - new_file_name = file_type + '.'+'txt' - with open(os.path.join(new_dir,new_file_name), "w", encoding='utf-8') as f: - f.write("\n".join(seq_result)) - +# 处理后将文件写在 ./input/SIM_data 下 +if not os.path.exists(SIM_DATD_DIR): + os.makedirs(SIM_DATD_DIR) + +filenames = ['train.csv', 'dev.csv', 'test.csv'] + 
+sample_line_tmplt = '{}\t{}\t{}\t{}\n' + +for filename in filenames: + file_path = os.path.join(NER_DATA_DIR, filename) + assert os.path.exists(file_path) # 断言处理, 判断该文件是否存在 + rows = [] # 记录文件内容 + attribute_set = set() # 存储 attribute 的集合 + with open(file_path, mode='r', encoding='utf-8') as f: + reader = csv.DictReader(f) + # 读取文件 + for row in reader: + # 处理 question 字符串 + question = my_strip(row['question']) + # 提取并处理 attribute 字段 + attribute = my_strip(row['triple'].split('|||')[1]) + # 记录 attribute 字段 + row['attribute'] = attribute + # 并存储在集合中 + rows.append(row) + # 同时记录文件内容 + attribute_set.add(attribute) + + attribute_set = list(attribute_set) + + filename_tokens = filename.split('.') + filename_tokens[-1] = 'txt' + with open(os.path.join(SIM_DATD_DIR, '.'.join(filename_tokens)), mode='w', encoding='utf-8') as attribute_classify_sample: + NEGATIVE_SAMPLE_COUNT = 5 + cnt = 0 + max_len = 0 + for row in rows: + question, postive_attribute = row['question'], row['attribute'] + # 随机生成并记录负样本 + negative_attributes = [] + # 进行随机取样 + for i in range(NEGATIVE_SAMPLE_COUNT): + sample = shuffle_from_(attribute_set) + while sample == postive_attribute: + sample = shuffle_from_(attribute_set) + negative_attributes.append(sample) + # 写出生成的样本 + # 1 为正确答案, 0 为错误答案 + attribute_classify_sample.write(sample_line_tmplt.format(cnt, question, postive_attribute, 1)) + cnt += 1 + for var in negative_attributes: + attribute_classify_sample.write(sample_line_tmplt.format(cnt, question, var, 0)) + cnt += 1 + # 记录最大字串序列长度 + max_len = max(max_len, len(question) + len(var)) + max_len = max(max_len, len(question) + len(postive_attribute)) + print('{}:\tmax_len: {}'.format(filename, max_len)) diff --git a/input/data/ConstructDatasetNer.py b/input/data/ConstructDatasetNer.py index a641fb2..bbee96f 100644 --- a/input/data/ConstructDatasetNer.py +++ b/input/data/ConstructDatasetNer.py @@ -1,86 +1,71 @@ # coding:utf-8 -import sys import os -import pandas as pd - +import csv +from Configurations 
import DATA_DIR, NER_DATA_DIR +from MyUtils import line_parse ''' -通过 NLPCC2016KBQA 中的原始数据,构建用来训练NER的样本集合 -构造NER训练集,实体序列标注,训练BERT+CRF +通过 NLPCC2016KBQA 中的原始数据, 构建用来训练NER的样本集合 +构造 NER 训练集, 实体序列标注, 训练BERT+CRF ''' -data_dir = 'NLPCC2016KBQA' -file_name_list = ['train.txt','dev.txt','test.txt'] +file_name_list = ['train.txt', 'dev.txt', 'test.txt'] -new_dir = 'ner_data' +# 将处理后的文件写在 ./input/NER_data 下 +if not os.path.exists(NER_DATA_DIR): + os.mkdir(NER_DATA_DIR) -question_str = "")[1].strip() - q_str = q_str.split(">")[1].replace(" ", "").strip() + with open(file_path, 'r', encoding='utf-8') as f: + while True: + try: + # 一次读取三行 + l = [line_parse(f.__next__()) for i in range(3)] + # 并映射到字典的形式 + s = {t[0]: t[1].strip() for t in l} + # 跳过分割线行 + f.__next__() + + q_str = s['question'].replace(' ', '') + s['question'] = q_str - # 若该实体已经存在 - if entities in q_str: - q_list = list(q_str) - seq_q_list.extend(q_list) - seq_q_list.extend([" "]) - tag_list = ["O" for i in range(len(q_list))] - tag_start_index = q_str.find(entities) + entity = s['triple'].split('|||')[0].strip() + p = q_str.find(entity) + # 若该实体名存在于问题中 + if p != -1: + tags = ['O'] * len(q_str) + # BIO 标注划分 # B-IOC: 一个地名的开始 # I-IOC:一个地名的中间部分 # 其余为 O - for i in range(tag_start_index, tag_start_index + len(entities)): - if tag_start_index == i: - tag_list[i] = "B-LOC" - else: - tag_list[i] = "I-LOC" - seq_tag_list.extend(tag_list) - seq_tag_list.extend([" "]) + tags[p] = 'B-LOC' + for i in range(p + 1, p + len(entity)): + tags[i] = 'I-LOC' + # 存储序列标注等待写出 + for q, t in zip(q_str, tags): + tagged_q_str_file.write(q + ' ' + t + '\n') + tagged_q_str_file.write('\n') else: pass - q_t_a_list.append([q_str, t_str, a_str]) - print(file_name) - print('\t'.join(seq_tag_list[0:50])) - print('\t'.join(seq_q_list[0:50])) - seq_result = [str(q) + " " + tag for q, tag in zip(seq_q_list, seq_tag_list)] - - # 将处理后的文件写在 ./input/NER_data 下 - if not os.path.exists(new_dir): - os.mkdir(new_dir) - - with 
open(os.path.join(new_dir,file_name), "w", encoding='utf-8') as f: - f.write("\n".join(seq_result)) - f.close() - - df = pd.DataFrame(q_t_a_list, columns=["q_str", "t_str", "a_str"]) - file_type = file_name.split('.')[0] - csv_name = file_type+'.'+'csv' - df.to_csv(os.path.join(new_dir,csv_name), encoding='utf-8', index=False) + qta_writer.writerow(s) + + except StopIteration: + qta_file.close() + break diff --git a/input/data/MyUtils.py b/input/data/MyUtils.py new file mode 100644 index 0000000..f5f10c8 --- /dev/null +++ b/input/data/MyUtils.py @@ -0,0 +1,23 @@ +from typing import Tuple +import re +import random + +def line_parse(line: str) -> Tuple[str, str]: + p1 = line.find('<') + p2 = line.find('>') + header = line[(p1+1):p2] + content = line[p2+1:].strip() + return header.split()[0], content + + +# 正则表达式 +pattern = re.compile('^-+') # 以-开头 + +def my_strip(s: str) -> str: + # 去首尾空白, 去除字符串之间空格 + s = s.strip().replace(' ', '') + # 去掉字符串开始处的 '-' + return re.sub(pattern, '', s) + + +shuffle_from_ = lambda l: random.choice(l) diff --git a/input/data/PrintSeqLen.py b/input/data/PrintSeqLen.py deleted file mode 100644 index 71a37f8..0000000 --- a/input/data/PrintSeqLen.py +++ /dev/null @@ -1,37 +0,0 @@ -import os - -""" -查看整个句子的长度 -""" - -dir_name = 'sim_data' -file_list = ['train.txt','dev.txt','test.txt'] - -for file in file_list: - - file_path_name = os.path.join(dir_name,file) - - max_len = 0 - print("****** {} *******".format(file)) - with open(file_path_name,'r',encoding='utf-8') as f: - for line in f: - - line_list = line.split('\t') - question = list(line_list[1]) - attributes = list(line_list[2]) - add_len = len(question) + len(attributes) - if add_len > max_len: - max_len = add_len - print("max_len",max_len) - f.close() - -# ****** train.txt ******* -# max_len 62 -# ****** dev.txt ******* -# max_len 61 -# ****** test.txt ******* -# max_len 62 - - -# 因此,最大长度为 64 合理。 - diff --git a/input/data/README.md b/input/data/README.md index 34b05cb..89ba485 100644 --- 
a/input/data/README.md +++ b/input/data/README.md @@ -1,8 +1,18 @@ # 数据处理 1. [SplitData.py](./SplitData.py) -1. [ConstructDataSetNER](./ConstructDataSetNER.py) +1. [ConstructDataSetNER.py](./ConstructDataSetNER.py) 1. [ConstructDatasetAttribute.py](./ConstructDatasetAttribute.py) -1. [PrintSeqLen.py](./PrintSeqLen.py) + +记录下数据中最长的序列长度: + +```none +train.csv: max_len: 62 +dev.csv: max_len: 60 +test.csv: max_len: 62 +``` + +因此, 将最大长度设定为 64 较为合理. + 1. [ConstructTriple.py](./ConstructTriple.py) 1. [LoadDbData.py](./LoadDbData.py) diff --git a/input/data/SplitData.py b/input/data/SplitData.py index 0d4f1c1..ddad262 100644 --- a/input/data/SplitData.py +++ b/input/data/SplitData.py @@ -1,23 +1,20 @@ -# -# 切分数据集, -# 原始的 nlpcc-iccpol-2016.kbqa.testing-data 有 9870 个样本 -# 原始的 nlpcc-iccpol-2016.kbqa.training-data 有 14609 个样本 -# -# 将nlpcc-iccpol-2016.kbqa.testing-data 中的对半分,一半变成验证集(dev.text),一半变成测试集(test.txt) -# nlpcc-iccpol-2016.kbqa.training-data 保持不变,复制成为训练集 train.txt -# - - -import pandas as pd import os +from Configurations import DATA_DIR +""" +切分数据集; +原始的 nlpcc-iccpol-2016.kbqa.testing-data 有 9870 个样本 +原始的 nlpcc-iccpol-2016.kbqa.training-data 有 14609 个样本; -data_dir = 'NLPCC2016KBQA' -file_name_list = ['nlpcc-iccpol-2016.kbqa.testing-data','nlpcc-iccpol-2016.kbqa.training-data'] +nlpcc-iccpol-2016.kbqa.testing-data 中的数据对半分,一半变成验证集(dev.text),一半变成测试集(test.txt) +nlpcc-iccpol-2016.kbqa.training-data 保持不变,复制成为训练集 train.txt +""" -#文件处理 +file_name_list = ['nlpcc-iccpol-2016.kbqa.testing-data', 'nlpcc-iccpol-2016.kbqa.training-data'] + +# 文件处理 for file_name in file_name_list: - file_path_name = os.path.join(data_dir,file_name) + file_path_name = os.path.join(DATA_DIR, file_name) file = [] with open(file_path_name,'r',encoding='utf-8') as f: for line in f: @@ -25,26 +22,23 @@ if line == '': continue file.append(line) - f.close() if 'training' in file_name: - with open(os.path.join(data_dir,'train.txt') , "w", encoding='utf-8') as f: + with open(os.path.join(DATA_DIR,'train.txt') 
, "w", encoding='utf-8') as f: f.write('\n'.join(file)) - f.close() elif 'testing' in file_name: - assert len(file) % 4 == 0 # 断言 - testing_num = len(file) / 4 # 一个样本是由 4 行构成的 - test_num = int(testing_num / 2) # 真正的测试集分一半 - - test_line_no = int(test_num * 4) - - - with open(os.path.join(data_dir, 'test.txt'), "w", encoding='utf-8') as f: - f.write('\n'.join(file[:test_line_no])) # 乘以四得到行号,前一半给 test 数据集 - f.close() - - with open(os.path.join(data_dir, 'dev.txt'), "w", encoding='utf-8') as f: - f.write('\n'.join(file[test_line_no:])) # 乘以四得到行号,后一半给 dev 数据集 - f.close() + assert len(file) % 4 == 0 # 断言处理,错误时触发异常 + testing_num = len(file) // 4 # 一个样本是由 4 行构成的 + line_no = testing_num // 2 * 4 # 将测试集分出一半用作评估 + + # 测试数据 + with open(os.path.join(DATA_DIR, 'test.txt'), "w", encoding='utf-8') as f: + for line in file[:line_no]: + f.write(line + '\n') + + # 进行评估的数据 + with open(os.path.join(DATA_DIR, 'dev.txt'), "w", encoding='utf-8') as f: + for line in file[line_no:]: + f.write(line + '\n') print("Done") From 490ad0a62e4bf9490aaec1fabdba81a631b55984 Mon Sep 17 00:00:00 2001 From: henry_23 Date: Sun, 26 Jun 2022 12:42:53 +0800 Subject: [PATCH 09/10] fix: improve code for 'ConstructTriple' --- input/data/ConstructTriple.py | 99 +++++++++++++++++------------------ 1 file changed, 49 insertions(+), 50 deletions(-) diff --git a/input/data/ConstructTriple.py b/input/data/ConstructTriple.py index 65b07a2..ff49565 100644 --- a/input/data/ConstructTriple.py +++ b/input/data/ConstructTriple.py @@ -1,56 +1,55 @@ # -*- coding: utf-8 -*- -# @Time : 2019/4/18 20:16 -# @Author : Alan -# @Email : xiezhengwen2013@163.com -# @File : triple_clean.py -# @Software: PyCharm - - -import pandas as pd +import os +import csv +from Configurations import DATA_DIR, DB_DATA_DIR +from MyUtils import line_parse, my_strip ''' -构造NER训练集,实体序列标注,用于训练BERT+BiLSTM+CRF +构造 NER 训练集 (即三元组), 实体序列标注, 用于训练 BERT + BiLSTM + CRF ''' -question_str = "")[1].strip() - q_str = q_str.split(">")[1].replace(" ","").strip() - if 
''.join(entities.split(' ')) in q_str: - clean_triple = t_str.split(">")[1].replace('\t','').replace(" ","").strip().split("|||") - triple_list.append(clean_triple) - else: - print(entities) - print(q_str) - print('------------------------') - -# 三元组:实体,属性,答案 -df = pd.DataFrame(triple_list, columns=["entity", "attribute", "answer"]) -print(df) -print(df.info()) - -# 处理完后将文件写在 ./input/DB_Data 下 -df.to_csv("./DB_Data/clean_triple.csv", encoding='utf-8', index=False) +# 将处理后的文件写在 ./input/DB_data 下 +if not os.path.exists(DB_DATA_DIR): + os.mkdir(DB_DATA_DIR) + +filename_prefix = 'nlpcc-iccpol-2016.kbqa.' +filename_suffix = '-data' +file_name_list = [ + filename_prefix + var + filename_suffix for var in ['training', 'testing'] +] + +with open(os.path.join(DB_DATA_DIR, 'clean_triple.csv'), mode='w', encoding='utf-8', newline='') as out_csv_file: + # 三元组: 实体, 属性, 答案 + writer = csv.DictWriter(out_csv_file, fieldnames=['entity', 'attribute', 'answer']) + writer.writeheader() + + for filename in file_name_list: + file_path = os.path.join(DATA_DIR, filename) + assert os.path.exists(file_path) + + with open(file_path, mode='r', encoding='utf-8') as f: + while True: + try: + # 一次读取三行 + l = [line_parse(f.__next__()) for i in range(3)] + # 并映射到字典的形式 + s = {t[0]: t[1] for t in l} + # 跳过分割线行 + f.__next__() + + triples = [t.strip() for t in s['triple'].split('|||')] + t = { + 'entity': triples[0], + 'attribute': triples[1], + 'answer': triples[2], + } + if t['entity'] in s['question']: + writer.writerow(t) + else: + # print(s['question']) + # print(t['entity']) + # print('-' * 30) + pass + except StopIteration: + break From ce03cc00162bdd42969548dbdab572c1170d1f8d Mon Sep 17 00:00:00 2001 From: henry_23 Date: Sun, 26 Jun 2022 13:18:38 +0800 Subject: [PATCH 10/10] fix: code comments & README --- NERTrain.py | 28 +------- README.md | 81 +++++++++++++++------- SIMTrain.py | 22 ------ WikiQuery.py | 2 +- input/data/{LoadDbData.py => LoadMySQL.py} | 73 ++++++++++--------- 
input/data/README.md | 15 +++- 6 files changed, 112 insertions(+), 109 deletions(-) rename input/data/{LoadDbData.py => LoadMySQL.py} (59%) diff --git a/NERTrain.py b/NERTrain.py index 993a272..be47d6d 100644 --- a/NERTrain.py +++ b/NERTrain.py @@ -1,25 +1,3 @@ -# --data_dir -# ./input/data/ner_data -# --vob_file -# ./input/config/bert-base-chinese-vocab.txt -# --model_config -# ./input/config/bert-base-chinese-config.json -# --output -# ./output -# --pre_train_model -# ./input/config/bert-base-chinese-model.bin -# --max_seq_length -# 64 -# --do_train -# --train_batch_size -# 16 -# --eval_batch_size -# 256 -# --gradient_accumulation_steps -# 4 -# --num_train_epochs -# 15 - import argparse import logging import codecs @@ -37,11 +15,11 @@ from sklearn.metrics import classification_report logger = logging.getLogger(__name__) + # CRF_LABELS = ["O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "X", "[CLS]", "[SEP]"] -# 在项目中只需要识别三个类型的项目即可 -# LABELS = ["O", "B-LOC", "I-LOC"],需要预测的就只有这三个。 -CRF_LABELS = ["O", "B-LOC", "I-LOC"] +# 在项目中只需要识别三个类型的项目即可 +CRF_LABELS = ["O", "B-LOC", "I-LOC"] # 需要预测的就只有这三个 def statistical_real_sentences(input_ids:torch.Tensor,mask:torch.Tensor,predict:list)-> list: # shape (batch_size,max_len) diff --git a/README.md b/README.md index ebcc453..faa1794 100644 --- a/README.md +++ b/README.md @@ -1,51 +1,80 @@ # bert-kbqa -基于bert的kbqa系统 - -预训练模型太大了,无法上传。放在百度网盘了 -链接:https://pan.baidu.com/s/1EK-TGghfmj-0HbWl_xe3zg -提取码:jqeg -下载下来放在./bert-kbqa/input/config目录 - -构造数据集 -通过 1_split_data.py 切分数据 - -通过 2-construct_dataset_ner.py 构造命名实体识别的数据集 +基于bert的kbqa系统 -通过 3-construct_dataset_attribute.py 构造属性相似度的数据集 +预训练模型太大, 放在百度网盘了 -通过 4-print-seq-len.py 看看句子的长度 +- 链接: https://pan.baidu.com/s/1EK-TGghfmj-0HbWl_xe3zg +- 提取码: jqeg -通过 5-triple_clean.py 构造干净的三元组 +下载下来放在 `./bert-kbqa/input/config` 目录 -通过 6-load_dbdata.py 通过创建数据库 和 上传数据 +## 使用方法 +### 安装依赖 +- PyTorch +- Transformers -CRF_Model.py 条件随机场模型 +### 构造数据集 -BERT_CRF.py bert+条件随机场 +详见 
[input/data/](./input/data/) 目录 -NER_main.py 训练命令实体识别的模型 +1. [SplitData.py](./input/data/SplitData.py) 切分数据 +2. [ConstructDatasetNer.py](./input/data/ConstructDatasetNer.py) 构造命名实体识别的数据集 +3. [ConstructDatasetAttribute.py](./input/data/ConstructDatasetAttribute.py) 构造属性相似度的数据集 +4. [ConstructTriple.py](./input/data/ConstructTriple.py) 构造干净的三元组 +5. [LoadMySQL.py](./input/data/LoadMySQL.py) 创建数据库和上传数据 -SIM_main.py 训练属性相似度的模型 +### 模型训练 +[CRF_Model.py](./CRF_Model.py) 条件随机场模型 -test_NER.py 测试命令实体识别 +[BERT_CRF_Model.py](./BERT_CRF_Model.py) bert+条件随机场 -test_SIM.py 测试属性相似度 +[NERTrain.py](./NERTrain.py) 训练命名实体识别的模型 -test_pro.py 测试整个项目 +```console +$ python3 NERTrain.py \ + --data_dir ./input/data/ner_data \ + --vob_file ./input/config/bert-base-chinese-vocab.txt \ + --model_config ./input/config/bert-base-chinese-config.json \ + --output ./output \ + --pre_train_model ./input/config/bert-base-chinese-model.bin \ + --max_seq_length 64 \ + --do_train \ + --train_batch_size 16 \ + --eval_batch_size 256 \ + --gradient_accumulation_steps 4 \ + --num_train_epochs 15 +``` +[SIMTrain.py](./SIMTrain.py) 训练属性相似度的模型 -主要依赖版本: +```console +$ python3 SIMTrain.py \ + --data_dir ./input/data/sim_data \ + --vob_file ./input/config/bert-base-chinese-vocab.txt \ + --model_config ./input/config/bert-base-chinese-config.json \ + --output ./output \ + --pre_train_model ./input/config/bert-base-chinese-model.bin \ + --max_seq_length 64 \ + --do_train \ + --train_batch_size 32 \ + --eval_batch_size 256 \ + --gradient_accumulation_steps 4 \ + --num_train_epochs 15 +``` -torch.__version__ 1.2.0 +### 模型测试 -transformers.__version__ 2.0.0 +[NERTest.py](./NERTest.py) 测试命名实体识别 +[SIMTest.py](./SIMTest.py) 测试属性相似度 -带有命令运行的py文件的命令都在 该py文件的最上方 +## 项目展示 +[ProjectTest.py](./ProjectTest.py) 测试整个项目 / 命令行界面的项目展示 +[chat.py](./chat.py) Qt/QML 编写的图形界面 diff --git a/SIMTrain.py b/SIMTrain.py index 376ebbc..3af654a 100644 --- a/SIMTrain.py +++ b/SIMTrain.py @@ -1,25 +1,3 @@ -# --data_dir -# ./input/data/sim_data -# 
--vob_file -# ./input/config/bert-base-chinese-vocab.txt -# --model_config -# ./input/config/bert-base-chinese-config.json -# --output -# ./output -# --pre_train_model -# ./input/config/bert-base-chinese-model.bin -# --max_seq_length -# 64 -# --do_train -# --train_batch_size -# 32 -# --eval_batch_size -# 256 -# --gradient_accumulation_steps -# 4 -# --num_train_epochs -# 15 - import argparse from collections import Counter import code diff --git a/WikiQuery.py b/WikiQuery.py index 00d5d3b..36e6338 100644 --- a/WikiQuery.py +++ b/WikiQuery.py @@ -1,7 +1,7 @@ import time from selenium import webdriver from selenium.webdriver.common.keys import Keys -from input.data.LoadDbData import insert_data +from input.data.LoadMySQL import insert_data # getInfobox函数 def getInfobox(name): diff --git a/input/data/LoadDbData.py b/input/data/LoadMySQL.py similarity index 59% rename from input/data/LoadDbData.py rename to input/data/LoadMySQL.py index 47fb49b..f096386 100644 --- a/input/data/LoadDbData.py +++ b/input/data/LoadMySQL.py @@ -1,19 +1,14 @@ # -*- coding: utf-8 -*- -# @Time : 2019/4/18 20:47 -# @Author : Alan -# @Email : xiezhengwen2013@163.com -# @File : load_dbdata.py -# @Software: PyCharm - import pymysql import pandas as pd from sqlalchemy import create_engine - -# 创建数据库,默认使用 root 用户 +# 创建数据库, 默认使用 root 用户 def create_db(): - connect = pymysql.connect( # 创建连接,若账号密码不同请记得修改 + + #创建连接,若账号密码不同请记得修改; + connect = pymysql.connect( user="root", password="123456", host="127.0.0.1", @@ -21,34 +16,45 @@ def create_db(): db="KB_QA", charset="utf8" ) - conn = connect.cursor() # 创建操作游标 - - conn.execute("drop database if exists KB_QA") # 如果 KB_QA 数据库存在则删除 - conn.execute("create database KB_QA") # 新创建一个数据库 - conn.execute("use KB_QA") # 选择使用 KB_QA 数据库 + #创建操作游标 + conn = connect.cursor() + + # 如果 KB_QA 数据库存在则删除 + conn.execute("drop database if exists KB_QA") + # 新创建一个数据库 + conn.execute("create database KB_QA") + # 选择使用 KB_QA 数据库 + conn.execute("use KB_QA") conn.execute("SET 
@@global.sql_mode=''") - # sql 中的内容为创建一个名为 nlpccQA 的表 - sql = """create table nlpccQA(entity VARCHAR(50) character set utf8 collate utf8_unicode_ci, + # 创建一个名为 nlpccQA 的表 + sql = """ + create table nlpccQA(entity VARCHAR(50) character set utf8 collate utf8_unicode_ci, attribute VARCHAR(50) character set utf8 collate utf8_unicode_ci, answer VARCHAR(255) character set utf8 - collate utf8_unicode_ci)""" # ()中的参数可以自行设置 - conn.execute("drop table if exists nlpccQA") # 如果表存在则删除 - conn.execute(sql) # 创建表 + collate utf8_unicode_ci) + """ + + # 如果表存在,则删除 + conn.execute("drop table if exists nlpccQA") + conn.execute(sql) - conn.close() # 关闭游标 - connect.close() # 关闭与数据库的连接 + #关闭游标 + conn.close() + #关闭与数据库的连接 + connect.close() def loaddata(): - # 使用 pymysql,与数据库进行连接,同样,注意用户名和密码的设置。 + #使用 pymysql,与数据库进行连接,同样,注意用户名和密码的设置。 db_info = {'user': 'root', - 'password': '123456', + 'password': 'root', 'host': '127.0.0.1', 'port': 3306, 'database': 'KB_QA' } - engine = create_engine( # 导入模块中的create_engine,需要利用它来进行连接数据库 + # 导入模块中的 create_engine, 需要利用它来进行连接数据库 + engine = create_engine( 'mysql+pymysql://%(user)s:%(password)s@%(host)s:%(port)d/%(database)s?charset=utf8' % db_info, encoding='utf-8') # ("mysql+pymysql://【此处填用户名】:【此处填密码】@【此处填host】:【此处填port】/【此处填数据库的名称】?charset=utf8") # 直接使用这种形式也可以engine = create_engine('mysql+pymysql://root:123456@localhost:3306/test') @@ -57,26 +63,28 @@ def loaddata(): # 读取本地CSV文件 df = pd.read_csv("./DB_Data/clean_triple.csv", sep=',', encoding='utf-8') - # 将新建的DataFrame储存为MySQL中的数据表,不储存index列(index=False) - # if_exists: - # 1.fail:如果表存在,啥也不做 - # 2.replace:如果表存在,删了表,再建立一个新表,把数据插入 - # 3.append:如果表存在,把数据插入,如果表不存在创建一个表!! + # 将新建的 DataFrame 储存到 MySQL 中的数据表, 不储存 index 列 (index=False) + # if_exists 参数: + # - fail: 如果表存在,啥也不做 + # - replace: 如果表存在,删了表,再建立一个新表,把数据插入 + # - append: 如果表存在,把数据插入,如果表不存在创建一个表!! 
pd.io.sql.to_sql(df, 'nlpccQA', con=engine, index=False, if_exists='append', chunksize=10000) # df.to_sql('example', con=engine, if_exists='replace')这种形式也可以 print("Write to MySQL successfully!") def upload_data(sql): - connect = pymysql.connect( # 连接数据库服务器 + #连接数据库服务器 + connect = pymysql.connect( user="root", - password="123456", + password="root", host="127.0.0.1", port=3306, db="kb_qa", charset="utf8" ) - cursor = connect.cursor() # 创建操作游标 + # 创建操作游标 + cursor = connect.cursor() results = None try: # 执行SQL语句 @@ -118,4 +126,3 @@ def insert_data(entity, attributes, answer): ret = upload_data(sql) print(list(ret)) - # diff --git a/input/data/README.md b/input/data/README.md index 89ba485..7d06fc4 100644 --- a/input/data/README.md +++ b/input/data/README.md @@ -1,5 +1,11 @@ # 数据处理 +### 使用的数据集 + +NLPCC2016KBQA + +### 构建训练数据 + 1. [SplitData.py](./SplitData.py) 1. [ConstructDataSetNER.py](./ConstructDataSetNER.py) 1. [ConstructDatasetAttribute.py](./ConstructDatasetAttribute.py) @@ -14,5 +20,10 @@ test.csv: max_len: 62 因此, 将最大长度设定为 64 较为合理. -1. [ConstructTriple.py](./ConstructTriple.py) -1. [LoadDbData.py](./LoadDbData.py) +### 从数据构建知识三元组 + +[ConstructTriple.py](./ConstructTriple.py) + +### 将数据载入到数据库 + +[LoadMySQL.py](./LoadMySQL.py) (MySQL)