Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
__pycache__/

*.log
*.bin

*.sqlite3
18 changes: 3 additions & 15 deletions BERT_CRF.py → BERT_CRF_Model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,6 @@
VOB_NAME = "bert-base-chinese-vocab.txt"







class BertCrf(nn.Module):
def __init__(self,config_name:str,model_name:str = None,num_tags: int = 2, batch_first:bool = True) -> None:
# 记录batch_first
Expand Down Expand Up @@ -58,7 +53,6 @@ def __init__(self,config_name:str,model_name:str = None,num_tags: int = 2, batch
self.crf_model = CRF(num_tags=num_tags,batch_first=batch_first)



def forward(self,input_ids:torch.Tensor,
tags:torch.Tensor = None,
attention_mask:Optional[torch.ByteTensor] = None,
Expand All @@ -68,29 +62,23 @@ def forward(self,input_ids:torch.Tensor,

emissions = self.bertModel(input_ids = input_ids,attention_mask = attention_mask,token_type_ids=token_type_ids)[0]

# 这里在seq_len的维度上去头,是去掉了[CLS],去尾巴有两种情况
# 1、是 <pad> 2、[SEP]


# 这里在seq_len的维度上去掉开头,是去掉了[CLS],去尾巴有两种情况
# 第一种情况是 <pad>、第二种情况是 [SEP]
new_emissions = emissions[:,1:-1]
new_mask = attention_mask[:,2:].bool()

# 如果 tags 为 None,表示是一个预测的过程,不能求得loss,loss 直接为None
# 如果 tags 为 None,表示是一个预测的过程,不能求得 loss,则 loss 的值直接为 None
if tags is None:
loss = None
pass
else:
new_tags = tags[:, 1:-1]
loss = self.crf_model(emissions=new_emissions, tags=new_tags, mask=new_mask, reduction=reduction)



if decode:
tag_list = self.crf_model.decode(emissions = new_emissions,mask = new_mask)
return [loss, tag_list]

return [loss]




17 changes: 10 additions & 7 deletions CRF_Model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
import torch
import torch.nn as nn

"""
CRF 条件随机场模型;
"""

class CRF(nn.Module):
def __init__(self,num_tags : int = 2, batch_first:bool = True) -> None:
Expand All @@ -10,7 +13,8 @@ def __init__(self,num_tags : int = 2, batch_first:bool = True) -> None:
super().__init__()
self.num_tags = num_tags
self.batch_first = batch_first
# start 到其他tag(不包含end)的得分
# start 到其他 tag (不包含 end) 的得分
# (从开始节点到其他非 end 节点的 scores)
self.start_transitions = nn.Parameter(torch.empty(num_tags))
# 到其他tag(不包含start)到end的得分
self.end_transitions = nn.Parameter(torch.empty(num_tags))
Expand All @@ -21,6 +25,7 @@ def __init__(self,num_tags : int = 2, batch_first:bool = True) -> None:

self.reset_parameters()

# 对参数进行重新设置
def reset_parameters(self):
init_range = 0.1
nn.init.uniform_(self.start_transitions,-init_range,init_range)
Expand All @@ -30,6 +35,7 @@ def reset_parameters(self):
def __repr__(self):
return f'{self.__class__.__name__}(num_tags={self.num_tags})'

# 向前传播;
def forward(self, emissions:torch.Tensor,
tags:torch.Tensor = None,
mask:Optional[torch.ByteTensor] = None,
Expand All @@ -42,6 +48,7 @@ def forward(self, emissions:torch.Tensor,
raise ValueError(f'invalid reduction {reduction}')

if mask is None:
#生成值全为1的张量,用于掩码
mask = torch.ones_like(tags,dtype = torch.uint8)
# a.shape (seq_len,batch_size)
# a[0] shape ? batch_size
Expand Down Expand Up @@ -81,10 +88,6 @@ def decode(self,emissions:torch.Tensor,
return self._viterbi_decode(emissions,mask)






def _validate(self,
emissions:torch.Tensor,
tags:Optional[torch.LongTensor] = None ,
Expand Down Expand Up @@ -146,7 +149,7 @@ def _computer_score(self,
# 这里是为了获取每一个样本最后一个词的tag。
# shape: (batch_size,) 每一个batch 的真实长度
seq_ends = mask.long().sum(dim=0) - 1
# 每个样本最火一个词的tag
# 每个样本最后一个词的tag
last_tags = tags[seq_ends,torch.arange(batch_size)]
# shape: (batch_size,) 每一个样本到最后一个词的得分加上之前的score
score += self.end_transitions[last_tags]
Expand Down Expand Up @@ -250,4 +253,4 @@ def _viterbi_decode(self,emissions : torch.FloatTensor ,

best_tags.reverse()
best_tags_list.append(best_tags)
return best_tags_list
return best_tags_list
12 changes: 6 additions & 6 deletions test_NER.py → NERTest.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from BERT_CRF import BertCrf
from BERT_CRF_Model import BertCrf
from transformers import BertTokenizer
from NER_main import NerProcessor,statistical_real_sentences,flatten,CrfInputFeatures
from NERTrain import NerProcessor,statistical_real_sentences,flatten,CrfInputFeatures
from torch.utils.data import DataLoader, RandomSampler,TensorDataset
from sklearn.metrics import classification_report
import torch
import numpy as np
from tqdm import tqdm, trange




"""
对命名实体识别模型进行简单的测试
"""

processor = NerProcessor()
tokenizer_inputs = ()
Expand Down Expand Up @@ -81,4 +81,4 @@
#
# micro avg 0.996142 0.996142 0.996142 145137
# macro avg 0.994650 0.994380 0.994512 145137
# weighted avg 0.996149 0.996142 0.996143 145137
# weighted avg 0.996149 0.996142 0.996143 145137
Loading