kashgari
由于近期老师希望我完成一个知识图谱的构建,因此想用 BERT 做一下实体识别。偶然发现了 Kashgari 这个框架,感觉挺好用,就记录一下自己的学习经历吧。
kashgari1.1.5 python3.6
代码
别的不多说直接上代码
#author Hu Nan
import kashgari
from kashgari.tasks.labeling import BiGRU_CRF_Model,BiGRU_Model,BiLSTM_CRF_Model,BiLSTM_Model
from untils_any2vec import bems2data
from kashgari.embeddings import BERTEmbedding
from keras_radam import RAdam
class Bert_NER(object):
    """BERT-based sequence-labeling (NER) wrapper built on kashgari.

    Wires a BERT embedding into one of kashgari's labeling models
    (BiGRU/BiLSTM, with or without a CRF head) and provides train /
    load / predict helpers.
    """

    def __init__(self, bert_folder, maxlen, modeltype, task=kashgari.LABELING,
                 trainable=True, layer_num=1, learning_rate=1e-5,
                 data=None):
        """
        :param bert_folder: directory of the pre-trained BERT model; it must
            be a TensorFlow-checkpoint BERT model.
        :param maxlen: maximum sequence length; the longest sentence length
            in the training data is a good choice.
        :param modeltype: which sequence-labeling model to use, one of
            'BiGRU_CRF_Model', 'BiLSTM_CRF_Model', 'BiGRU_Model'; any other
            value falls back to BiLSTM_Model.
        :param task: kashgari task type, default is kashgari.LABELING.
        :param trainable: whether to fine-tune BERT, default True.
        :param layer_num: number of encoder output layers.
        :param learning_rate: optimizer learning rate, default 1e-5.
        :param data: [[train_x, train_y], [dev_x, dev_y], [test_x, test_y]],
            each a list of token/label lists; defaults to empty splits.
        """
        # BUG FIX: the original used a mutable default argument
        # (data=[[[],[]],...]), shared across all instances. Use a None
        # sentinel and build fresh empty splits per call instead.
        if data is None:
            data = [[[], []], [[], []], [[], []]]
        self.traindata = data[0][0]
        self.trainlabel = data[0][1]
        self.devdata = data[1][0]
        self.devlabel = data[1][1]
        self.testdata = data[2][0]
        self.testlabel = data[2][1]
        self.bertfolder = bert_folder
        self.maxlen = maxlen
        self.modeltype = modeltype
        self.task = task
        self.trainable = trainable
        self.layer_num = layer_num
        self.learning_rate = learning_rate

    def get_model(self):
        """Build and compile the BERT sequence-labeling model.

        :return: a compiled kashgari labeling model.
        """
        emb = BERTEmbedding(model_folder=self.bertfolder,
                            task=self.task, trainable=self.trainable,
                            sequence_length=self.maxlen,
                            layer_nums=self.layer_num)
        # BUG FIX: the original read the bare name `modeltype` (NameError
        # at call time); the choice must come from the instance attribute.
        if self.modeltype == 'BiGRU_CRF_Model':
            model = BiGRU_CRF_Model(emb)
        elif self.modeltype == 'BiLSTM_CRF_Model':
            model = BiLSTM_CRF_Model(emb)
        elif self.modeltype == 'BiGRU_Model':
            model = BiGRU_Model(emb)
        else:
            model = BiLSTM_Model(emb)
        model.build_model(self.traindata, self.trainlabel)
        model.compile_model(optimizer=RAdam(self.learning_rate))
        return model

    def train(self, epoch, bacth_size, evaluate_epoch=1, savemodel=False,
              savepath=''):
        """Train the model one epoch at a time, evaluating on the test set
        from `evaluate_epoch` onward and optionally saving the result.

        :param epoch: number of training epochs.
        :param bacth_size: batch size (misspelled name kept for
            backward compatibility with existing callers).
        :param evaluate_epoch: epoch index from which to start evaluating.
        :param savemodel: whether to save the model after training.
        :param savepath: directory to save the model to.
        :return: the trained model.
        """
        model = self.get_model()
        for i in range(epoch):
            # BUG FIX: the original passed `self.trainable` (a bool) as the
            # training labels; it must be `self.trainlabel`.
            model.fit(self.traindata, self.trainlabel,
                      self.devdata, self.devlabel,
                      epochs=1, batch_size=bacth_size)
            if i >= evaluate_epoch:
                model.evaluate(self.testdata, self.testlabel)
        if savemodel:
            model.save(savepath)
        return model

    def loadsweight(self, path):
        """Load a previously trained kashgari model.

        :param path: directory the model was saved to.
        :return: the loaded model.
        """
        return kashgari.utils.load_model(path)

    def predicts(self, data, epoch=100, bacth_size=8, train=True,
                 evaluate_epoch=1, savemodel=False, savepath='', filepath=''):
        """Predict labels for `data`, training a model first or loading a
        previously saved one, and append the results to `filepath`.

        :param data: sequences to predict, a list of token lists.
        :param train: if True, train a fresh model; otherwise load one
            from `savepath`.
        :param savepath: model save/load directory.
        :param filepath: text file the `token label` results are appended to.
        :return: the predicted label sequences.
        """
        if train:
            model = self.train(epoch, bacth_size,
                               evaluate_epoch=evaluate_epoch,
                               savemodel=savemodel, savepath=savepath)
        else:
            model = self.loadsweight(path=savepath)
        result = model.predict(data)
        self._write_result(data, result, filepath)
        return result

    @staticmethod
    def _write_result(data, result, filepath):
        """Append `token label` lines to `filepath`, with a blank line
        between sentences (CoNLL-style). Extracted from the two duplicated
        loops in the original `predicts`."""
        # BUG FIX: the original passed `ecoding=` (typo) to open(), which
        # raises TypeError; also dropped the redundant f.close() inside
        # the `with` block.
        with open(filepath, 'a', encoding='utf-8') as f:
            for tokens, labels in zip(data, result):
                for token, label in zip(tokens, labels):
                    f.write(token + ' ' + label + '\n')
                f.write('\n')
# Script entry point: only announces the demo when run directly.
if __name__ == '__main__':
    banner = 'hello FAST IE,this is the simplest bert NER'
    print(banner)
#untils_any2vec是我自己写的一个文件,所以最后补一下bems2data这个方法。
def bems2data(path, encoding='utf-8'):
    """Read a BMES/CoNLL-style tagging file into parallel char/label lists.

    File format: sentences are separated by '\\n\\n'; within a sentence each
    line holds one character and its label separated by a single space.
    Lines without a label, and tokens longer than one character, are
    silently skipped (best-effort, matching the original behavior).

    :param path: path of the BMES file.
    :param encoding: file encoding, default 'utf-8'. BUG FIX: the original
        default was the invalid codec name 'utf=8', which raised
        LookupError on the first read with default arguments.
    :return: (result_x, result_y) — list of character lists and the
        corresponding list of label lists.
    """
    result_x, result_y = [], []
    with open(path, 'r', encoding=encoding) as f:
        # NOTE(review): dropping the last chunk assumes the file ends with
        # a trailing blank line ('\n\n'); confirm against the data files.
        content_list = f.read().split('\n\n')[0:-1]
    for sentence in content_list:
        chars, labels = [], []
        for line in sentence.split('\n'):
            parts = line.split(' ')
            # Explicit guard replaces the original bare `except: pass`:
            # keep only well-formed single-character `char label` lines.
            if len(parts) >= 2 and len(parts[0]) == 1:
                chars.append(parts[0])
                labels.append(parts[1])
        result_x.append(chars)
        result_y.append(labels)
    return result_x, result_y