前面几章简单介绍了CRF层的作用以及CRF层的损失函数,详见:
下面使用 TensorFlow 1.x 实现 BiLSTM+CRF 模型,并基于“万创杯”中医药天池大数据竞赛——中药说明书实体识别挑战的比赛数据完成中药 NER 任务。
1.bilstm+crf模型
该文件定义了 embedding 层、BiLSTM 层、全连接层、CRF 层等模型组件。
# -*- coding: utf-8 -*-
# @Time : 2020-10-09 21:15
# @Author : xudong
# @email : dongxu222mk@163.com
# @Site :
# @File : bilstm_crf.py
# @Software: PyCharm
import tensorflow as tf
from tensorflow.contrib.rnn import LSTMCell
from tensorflow.contrib.rnn import MultiRNNCell
class Linear:
    """Fully connected (dense) layer applied along the last axis.

    Owns a weight matrix ``W`` of shape ``[input_size, output_size]``
    (uniformly initialised in [-0.25, 0.25]) and a zero-initialised bias
    ``b`` of shape ``[output_size]``, both created under ``scope_name``.
    Dropout is applied to the projected output.
    """

    def __init__(self, scope_name, input_size, output_size,
                 drop_out=0., trainable=True):
        """Create the layer variables and the dropout layer.

        Args:
            scope_name: variable scope under which ``W`` and ``b`` live.
            input_size: size of the last axis of the inputs.
            output_size: size of the last axis of the outputs.
            drop_out: rate handed to ``tf.layers.Dropout``.
            trainable: whether ``W`` and ``b`` are trainable variables.
        """
        weight_init = tf.random_uniform_initializer(-0.25, 0.25)
        bias_init = tf.zeros_initializer()

        with tf.variable_scope(scope_name):
            self.W = tf.get_variable(
                'W',
                [input_size, output_size],
                initializer=weight_init,
                trainable=trainable,
            )
            self.b = tf.get_variable(
                'b',
                [output_size],
                initializer=bias_init,
                trainable=trainable,
            )

        self.drop_out = tf.layers.Dropout(drop_out)
        self.output_size = output_size

    def __call__(self, inputs, training):
        """Project ``inputs`` and apply dropout.

        Args:
            inputs: tensor whose last axis is multiplied by ``W``; axis 1 is
                reused for the output shape, so a rank-3 tensor is expected.
            training: forwarded to the dropout layer's ``training`` argument.

        Returns:
            Tensor reshaped to ``[-1, shape(inputs)[1], output_size]``.
        """
        dyn_shape = tf.shape(inputs)
        # Collapse all leading axes so a single matmul handles every position.
        flat = tf.reshape(inputs, [-1, dyn_shape[-1]])
        projected = tf.nn.xw_plus_b(flat, self.W, self.b)
        dropped = self.drop_out(projected, training=training)
        # Restore the axis-1 length taken from the original input.
        return tf.reshape(dropped, [-1, dyn_shape[1], self.output_size])
class LookupTable:
    """Embedding layer backed by a ``[vocab_size, embed_size]`` variable.

    The table is created (or reused) under ``scope_name`` and initialised
    uniformly in [-0.25, 0.25].
    """

    def __init__(self, scope_name, vocab_size, embed_size, reuse=False, trainable=True):
        """Create or reuse the embedding variable.

        Args:
            scope_name: variable scope holding the ``embedding`` variable.
            vocab_size: number of rows in the embedding table.
            embed_size: number of columns in the embedding table.
            reuse: coerced to ``bool`` and passed to ``tf.variable_scope``.
            trainable: whether the embedding table is trainable.
        """
        self.vocab_size = vocab_size
        self.embed_size = embed_size

        table_shape = [vocab_size, embed_size]
        table_init = tf.random_uniform_initializer(-0.25, 0.25)
        with tf.variable_scope(scope_name, reuse=bool(reuse)):
            self.embedding = tf.get_variable(
                'embedding',
                table_shape,
                initializer=table_init,
                trainable=trainable,
            )

    def __call__(self, input):
        """Look up embeddings for the ids in ``input``.

        Ids that are not strictly less than ``vocab_size`` are replaced by
        1 before the lookup (presumably the unknown-token id; confirm
        against the vocabulary builder).
        """
        in_vocab = tf.less(input, self.vocab_size)
        fallback_ids = tf.ones_like(input)
        safe_ids = tf.where(in_vocab, input, fallback_ids)
        return tf.nn.embedding_lookup(self.embedding, safe_ids)
class LstmBase:
    """Base class providing a helper that builds a stacked LSTM cell."""

    def build_rnn(self, hidden_size, num_layes):
        """Build a ``MultiRNNCell`` of ``num_layes`` LSTM cells.

        Args:
            hidden_size: ``num_units`` of every ``LSTMCell`` in the stack.
            num_layes: number of stacked LSTM cells.

        Returns:
            A ``MultiRNNCell`` wrapping the created cells, using tuple state.
        """
        # Each layer gets its own cell and its own initializer instance.
        layers = [
            LSTMCell(
                num_units=hidden_size,
                state_is_tuple=True,
                initializer=tf.random_uniform_initializer(-0.25, 0.25),
            )
            for _ in range(num_layes)
        ]
        return MultiRNNCell(layers, state_is_tuple=True)
class BiLstm(LstmBase):
"""
define the lstm
"""
def __init__(self, scope_name, hidden_size, num_layers):
super(BiLstm, self).__init__()
assert hidden_size % 2 == 0
hidden_size /= 2