def lstm_crf_viterbi(observation, transition, pi):
    """Viterbi-decode the highest-scoring tag sequence for every batch item.

    Scores are combined additively. A path s_0..s_{T-1} for batch item ``b``
    scores ``pi[b][s_0] + observation[b][0][s_0]
    + sum_{t>=1} (transition[s_{t-1}][s_t] + observation[b][t][s_t])``.

    Args:
        observation: tensor of shape [batch_size, num_step, output_dims] holding
            the neural-network outputs (per-step tag scores). The batch and step
            dimensions must be statically known because they drive Python loops.
        transition: tensor of shape [output_dims, output_dims]; ``transition[i][j]``
            is the score for moving from tag ``i`` to tag ``j``.
        pi: tensor of shape [batch_size, output_dims] with the initial scores.

    Returns:
        A tuple ``(previous, all_path_tag_sequence)``:
          previous: list of ``batch_size`` tensors of shape [output_dims].
              Entry ``j`` is the best score of any path ending in tag ``j`` at
              the last step, so the top score is ``tf.reduce_max(previous[b])``.
          all_path_tag_sequence: list of ``batch_size`` lists, each with
              ``num_step`` scalar index tensors (from ``tf.argmax``) that form
              the highest-scoring tag path.
    """
    # int() accepts both TF1 ``Dimension`` objects and TF2 plain ints;
    # ``.value`` only exists on the former.
    batch_size = int(observation.shape[0])
    num_step = int(observation.shape[1])

    previous = []
    all_path_tag_sequence = []
    for b in range(batch_size):
        # Best score of each length-1 path, one per starting tag: shape [O].
        score = observation[b][0] + pi[b]
        backpointers = []
        for x in range(1, num_step):
            # scores[i][j] = score[i] + transition[i][j] + observation[b][x][j]:
            # rows index the previous tag, columns the current tag.
            scores = (tf.expand_dims(score, 1)
                      + transition
                      + tf.expand_dims(observation[b][x], 0))
            # Reduce over the previous-tag axis (0) so each current tag keeps
            # its best predecessor. Using axis 1 on this 2-D tensor would
            # instead pick the best *next* tag, which is not Viterbi.
            backpointers.append(tf.argmax(scores, 0))
            score = tf.reduce_max(scores, 0)
        previous.append(score)

        # Backtrack (highest-scoring path only): start at the best final tag
        # and follow the backpointers towards step 0.
        best_path = [tf.argmax(score)]
        for bp in reversed(backpointers):
            best_path.append(bp[best_path[-1]])
        best_path.reverse()
        all_path_tag_sequence.append(best_path)
    return previous, all_path_tag_sequence
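# tensorflow 使用HMM的 viterbi 计算误差
# (TensorFlow: Viterbi decoding with HMM-style transition scores)
# 最新推荐文章于 2021-07-05 17:46:38 发布
# (blog page footer: latest recommended article published 2021-07-05 17:46:38)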