对这个代码的理解:https://blog.csdn.net/buppt/article/details/81180361
一.训练集数据格式
19980101-01-001-004/m 12月/t 31日/t ,/w 美国白宫/nt 发言人/n 、/w 国家/n 总统/n 奥/nr 巴马/nr 发表/v 1998年/t 新年/t 讲话/n 《/w 迈向/v 充满/v 希望/n 的/u 新/a 世纪/n 》/w 。/w (/w 白宫/nt 记者/n 工藤/nr 新一/nr 摄/Vg )/w
二 数据预处理1:把[]中间的合并成一个词,把相邻的/nr 前面的字合并成一个人名。
def originHandle():
with open('./renmin.txt','r',encoding='utf-8') as inp,open('./renmin2.txt','w',encoding='utf-8') as outp:
for line in inp.readlines():
line = line.split(' ') #每行样本形成一个列表。
i = 1 #把第一个元素 时期排除
while i<len(line)-1: # 遍历列表的每个元素
# 第一个if语句判断[]中包含的字为一个词语
if line[i][0]=='[': # line[i] : 第i个元素 line[i][0] :该元素的第一个字符是‘['。比如 '[中国/ns'
outp.write(line[i].split('/')[0][1:]) # line[i].split('/') = ['[中国’,‘ns']。所以[0][1:]的结果是'中国'
i+=1
while i<len(line)-1 and line[i].find(']')==-1: # returen -1 如果没找到。
if line[i]!='':
outp.write(line[i].split('/')[0]) # 取出']'左边的汉字。
i+=1
outp.write(line[i].split('/')[0].strip()+'/'+line[i].split('/')[1][-2:]+' ') # []中间为一个词,故将其合并。
# 第二个if语句判断相邻的两个nr组成一个人名。
elif line[i].split('/')[1]=='nr':
word = line[i].split('/')[0]
i+=1
if i<len(line)-1 and line[i].split('/')[1]=='nr':
outp.write(word+line[i].split('/')[0]+'/nr ')
else:
outp.write(word+'/nr ')
continue
# 最后 其他的不需要处理直接写入即可。
else:
outp.write(line[i]+' ')
i+=1
outp.write('\n')
数据预处理2
def originHandle2():
with open('./renmin2.txt', 'r', encoding='utf-8') as inp, open('./renmin3.txt', 'w',encoding='utf-8') as outp:
for line in inp.readlines():
line = line.split(' ')
i = 0
while i<len(line)-1:
if line[i]=='':
i+=1
continue
# 把词语和标注分离
word = line[i].split('/')[0]
tag = line[i].split('/')[1]
if tag=='nr' or tag=='ns' or tag=='nt':
outp.write(word[0]+"/B_"+tag+" ")
for j in word[1:len(word)-1]:
if j!=' ':
outp.write(j+"/M_"+tag+" ")
outp.write(word[-1]+"/E_"+tag+" ")
else:
for wor in word:
outp.write(wor+'/O ')
i+=1
outp.write('\n')
数据预处理3 把每段话分词多个短句or 短语 ,格式如下:
致/O 以/O 诚/O 挚/O 的/O 问/O 候/O 和/O 良/O 好/O 的/O 祝/O 愿/O
# 把长句分成短句。
def sentence2split():
with open('./renmin3.txt','r',encoding='utf-8') as inp,codecs.open('./renmin4.txt','w','utf-8') as outp:
texts = inp.read()
sentences = re.split('[,。!?、‘’“”:]/[O]', texts)
for sentence in sentences:
if sentence != " ":
outp.write(sentence.strip()+'\n')
数据预处理4 把生成的训练样本序列化保存为pkl文件,以提供为后续调用。
三 训练模型
通过 python train.py 开始训练
1 open 预处理过程中保存的pkl,因为pkl是按顺序保存了这10类数据。
with open('../data/renmindata.pkl', 'rb') as inp:
word2id = pickle.load(inp)
id2word = pickle.load(inp)
tag2id = pickle.load(inp)
id2tag = pickle.load(inp)
x_train = pickle.load(inp)
y_train = pickle.load(inp)
x_test = pickle.load(inp)
y_test = pickle.load(inp)
x_valid = pickle.load(inp)
y_valid = pickle.load(inp)
2 生成可迭代的数据 (BatchGenerator 的作用)
data_train = BatchGenerator(x_train, y_train, shuffle=True)
data_valid = BatchGenerator(x_valid, y_valid, shuffle=False)
data_test = BatchGenerator(x_test, y_test, shuffle=False)
3 进入训练代码段
else:
print ("begin to train...")
# 定义要训练的LSTM模型
model = Model(config,embedding_pre,dropout_keep=0.5)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
# 保存和恢复都需要实例化一个 tf.train.Saver
saver = tf.train.Saver()
train(model,sess,saver,epochs,batch_size,data_train,data_test,id2word,id2tag)
4 Model 是bilstm_crf 的一个类, 输入参数为config:配置信息。embedding_pre 预训练词向量。dropout 防止过拟合。
有了输入的X 和Y,嵌入层的数据样式,可以开始搭建LSTM。
#输入数据为[32,60]
self.input_data = tf.placeholder(tf.int32, shape=[self.batch_size,self.sen_len], name="input_data")
self.labels = tf.placeholder(tf.int32,shape=[self.batch_size,self.sen_len], name="labels")
# [3918,100]
self.embedding_placeholder = tf.placeholder(tf.float32,shape=[self.embedding_size,self.embedding_dim], name="embedding_placeholder")
with tf.variable_scope("bilstm_crf") as scope:
self._build_net()
_build_net()方法