Using crf_decode and viterbi_decode in TensorFlow
tf.contrib.crf
Official documentation: https://www.tensorflow.org/api_docs/python/tf/contrib/crf?hl=en
Classes
class CrfDecodeBackwardRnnCell: Computes backward decoding in a linear-chain CRF.
class CrfDecodeForwardRnnCell: Computes the forward decoding in a linear-chain CRF.
class CrfForwardRnnCell: Computes the alpha values in a linear-chain CRF.
Functions
crf_binary_score(...): Computes the binary scores of tag sequences.
crf_decode(...): Decode the highest scoring sequence of tags in TensorFlow.
crf_log_likelihood(...): Computes the log-likelihood of tag sequences in a CRF.
crf_log_norm(...): Computes the normalization for a CRF.
crf_multitag_sequence_score(...): Computes the unnormalized score of all tag sequences matching tag_bitmap.
crf_sequence_score(...): Computes the unnormalized score for a tag sequence.
crf_unary_score(...): Computes the unary scores of tag sequences.
viterbi_decode(...): Decode the highest scoring sequence of tags outside of TensorFlow.
Potential functions:
unary potential: scores how well each label fits a single state (a character, or a pixel)
pairwise potential: scores the compatibility between the labels of adjacent states
http://ai.stanford.edu/~pawan/teaching/optimization/lecture5.pdf
https://blog.csdn.net/lansatiankongxxc/article/details/45590545
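To make the two kinds of potential concrete, here is a minimal sketch of how they combine into the unnormalized score of one tag sequence in a linear-chain CRF. The function name sequence_score and the toy values are hypothetical and chosen only for illustration; this is plain Python, not the tf.contrib.crf implementation.

```python
def sequence_score(unary, pairwise, tags):
    """Return (unary part, pairwise part) of the score of one tag sequence."""
    # Unary part: the score of the chosen tag at every time step.
    unary_score = sum(unary[t][tag] for t, tag in enumerate(tags))
    # Pairwise part: the transition score between each pair of adjacent tags.
    pairwise_score = sum(pairwise[a][b] for a, b in zip(tags, tags[1:]))
    return unary_score, pairwise_score


# Hypothetical toy values: 3 time steps, 2 tags.
unary = [[0.5, 1.5],
         [2.0, 0.0],
         [1.0, 1.0]]
pairwise = [[1.0, -1.0],   # pairwise[i][j]: score of moving from tag i to tag j
            [0.5, 2.0]]

u, p = sequence_score(unary, pairwise, [1, 0, 1])
print(u, p, u + p)  # 4.5 -0.5 4.0
```

The total score is simply the sum of the two parts, which is why the functions listed above come in unary, binary, and sequence variants.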
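Training (the parameter estimation problem)
tf.contrib.crf.crf_log_likelihood in TensorFlow
is used to compute the CRF loss.
In a bi-lstm + crf or idcnn + crf architecture, the negative of the log-likelihood it returns serves as the loss function of the CRF layer.
For implementation details of the loss function, see this article: CRF/Seq2Seq/CTC的Loss实现对比 (a comparison of the loss implementations for CRF, Seq2Seq, and CTC)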
tf.contrib.crf.crf_log_likelihood(
    inputs,
    tag_indices,
    sequence_lengths,
    transition_params=None
)
Defined in tensorflow/contrib/crf/python/ops/crf.py.
Computes the log-likelihood of tag sequences in a CRF.
Args:
- inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials to use as input to the CRF layer.
- tag_indices: A [batch_size, max_seq_len] matrix of tag indices for which we compute the log-likelihood.
- sequence_lengths: A [batch_size] vector of true sequence lengths.
- transition_params: A [num_tags, num_tags] transition matrix, if available.
Returns:
- log_likelihood: A [batch_size] Tensor containing the log-likelihood of each example, given the sequence of tag indices.
- transition_params: A [num_tags, num_tags] transition matrix. This is either provided by the caller or created in this function.
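As a rough illustration of what the returned log-likelihood means, the sketch below computes it for a single unpadded sequence in plain Python: the unnormalized score of the given tag sequence minus the log of the normalizer, obtained with the forward (alpha) recursion. The helper names are hypothetical, and batching and sequence_lengths are deliberately left out, so treat this as a sketch under those assumptions rather than the library's code.

```python
import itertools
import math


def log_sum_exp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def sequence_score(unary, trans, tags):
    """Unnormalized score: unary scores plus transition scores along the path."""
    score = sum(unary[t][tag] for t, tag in enumerate(tags))
    return score + sum(trans[a][b] for a, b in zip(tags, tags[1:]))


def log_norm(unary, trans):
    """Log of the sum of exp(score) over all tag sequences (forward recursion)."""
    num_tags = len(trans)
    alpha = list(unary[0])
    for t in range(1, len(unary)):
        alpha = [
            unary[t][j]
            + log_sum_exp([alpha[i] + trans[i][j] for i in range(num_tags)])
            for j in range(num_tags)
        ]
    return log_sum_exp(alpha)


def log_likelihood(unary, trans, tags):
    return sequence_score(unary, trans, tags) - log_norm(unary, trans)


# Illustrative values: 4 time steps, 3 tags.
unary = [[1.0, 2.0, 3.0], [2.0, 1.0, 3.0], [1.0, 3.0, 2.0], [3.0, 2.0, 1.0]]
trans = [[2.0, 1.0, 3.0], [1.0, 3.0, 2.0], [3.0, 2.0, 1.0]]

ll = log_likelihood(unary, trans, [2, 0, 2, 0])
loss = -ll  # the quantity a training loop would minimize

# Sanity check: the probabilities of all 3**4 tag sequences sum to 1.
total = sum(
    math.exp(log_likelihood(unary, trans, list(tags)))
    for tags in itertools.product(range(3), repeat=4)
)
print(ll, loss, total)
```

In a real model the loss is usually averaged over the batch, for example the mean of the negated log_likelihood vector.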
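Decoding (the sequence labeling problem)
According to the TensorFlow documentation, viterbi_decode and crf_decode implement the same functionality: the former is a NumPy implementation and the latter operates on tensors. This article verifies that the two produce the same decoding result.
Test environment: python3.6 + tensorflow1.8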
import numpy as np
import tensorflow as tf
from tensorflow.contrib.crf import viterbi_decode
from tensorflow.contrib.crf import crf_decode

# Unary potentials for a batch containing one sequence.
score = [[
    [1, 2, 3],
    [2, 1, 3],
    [1, 3, 2],
    [3, 2, 1]
]]  # (batch_size, time_step, num_tags)
# transition[i][j] is the score of moving from tag i to tag j.
transition = [
    [2, 1, 3],
    [1, 3, 2],
    [3, 2, 1]
]  # (num_tags, num_tags)
lengths = [len(score[0])]  # (batch_size,)

# numpy: viterbi_decode takes a single sequence, so pass score[0]
print("[numpy]")
np_op = viterbi_decode(
    score=np.array(score[0]),
    transition_params=np.array(transition))
print(np_op[0])  # best tag sequence
print(np_op[1])  # score of that sequence
print("=============")

# tensorflow: crf_decode takes the whole batch as tensors
score_t = tf.constant(score, dtype=tf.int64)
transition_t = tf.constant(transition, dtype=tf.int64)
lengths_t = tf.constant(lengths, dtype=tf.int64)
tf_op = crf_decode(
    potentials=score_t,
    transition_params=transition_t,
    sequence_length=lengths_t)
with tf.Session() as sess:
    paths_tf, scores_tf = sess.run(tf_op)
print("[tensorflow]")
print(paths_tf)
print(scores_tf)
Output:
[numpy]
[2, 0, 2, 0]
19
=============
[tensorflow]
[[2 0 2 0]]
[19]
The results are identical, which indicates that crf_decode is the TensorFlow version of viterbi_decode.
The result is computed as follows:
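The score 19 for the path [2, 0, 2, 0] can be checked by hand: the unary scores along the path are 3 + 2 + 2 + 3 = 10, and the three transitions 2→0, 0→2, 2→0 each contribute 3, for 9 more. The short sketch below repeats that arithmetic and also enumerates all 3^4 = 81 paths to confirm that no other path scores higher. It uses plain Python with the same score and transition values as the test above; path_score is a hypothetical helper name.

```python
import itertools

score = [[1, 2, 3],
         [2, 1, 3],
         [1, 3, 2],
         [3, 2, 1]]          # (time_step, num_tags)
transition = [[2, 1, 3],
              [1, 3, 2],
              [3, 2, 1]]     # (num_tags, num_tags)


def path_score(path):
    """Sum of unary scores plus transition scores along one tag path."""
    unary = sum(score[t][tag] for t, tag in enumerate(path))
    binary = sum(transition[a][b] for a, b in zip(path, path[1:]))
    return unary + binary


print(path_score([2, 0, 2, 0]))  # 19

# Brute force over every possible path of length 4 with 3 tags.
all_paths = list(itertools.product(range(3), repeat=4))
best = max(path_score(p) for p in all_paths)
best_paths = [p for p in all_paths if path_score(p) == best]
print(best, best_paths)  # 19 [(2, 0, 2, 0)]
```

Brute force is only feasible for tiny inputs like this one; Viterbi decoding reaches the same answer with dynamic programming.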
Documentation for the two functions in TensorFlow: https://www.tensorflow.org/api_docs/python/tf/contrib/crf
tf.contrib.crf.crf_decode(
    potentials,
    transition_params,
    sequence_length
)
Defined in tensorflow/contrib/crf/python/ops/crf.py.
Decode the highest scoring sequence of tags in TensorFlow.
This is a function for tensor.
Args:
potentials: A [batch_size, max_seq_len, num_tags] tensor of unary potentials.
transition_params: A [num_tags, num_tags] matrix of binary potentials.
sequence_length: A [batch_size] vector of true sequence lengths.
Returns:
decode_tags: A [batch_size, max_seq_len] matrix, with dtype tf.int32. Contains the highest scoring tag indices.
best_score: A [batch_size] vector, containing the score of decode_tags.
tf.contrib.crf.viterbi_decode(
    score,
    transition_params
)
Defined in tensorflow/contrib/crf/python/ops/crf.py.
Decode the highest scoring sequence of tags outside of TensorFlow.
This should only be used at test time.
Args:
score: A [seq_len, num_tags] matrix of unary potentials.
transition_params: A [num_tags, num_tags] matrix of binary potentials.
Returns:
viterbi: A [seq_len] list of integers containing the highest scoring tag indices.
viterbi_score: A float containing the score for the Viterbi sequence.
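For readers who want to see the dynamic programming behind these two functions, here is a minimal pure-Python sketch of Viterbi decoding for a single sequence. It follows the same inputs and outputs as viterbi_decode above (a [seq_len, num_tags] score matrix and a [num_tags, num_tags] transition matrix), but the function name viterbi and its internals are written for illustration and are not the library's source. Ties are broken in favor of the lowest tag index, which is an assumption of this sketch.

```python
def viterbi(score, transition):
    """Return (best tag path, best path score) for one sequence."""
    num_tags = len(transition)
    # trellis[t][j]: best score of any path that ends in tag j at step t.
    trellis = [list(score[0])]
    # backpointers[t - 1][j]: the previous tag on that best path.
    backpointers = []
    for t in range(1, len(score)):
        prev = trellis[-1]
        row, back = [], []
        for j in range(num_tags):
            # Pick the previous tag i that maximizes prev[i] + transition[i][j].
            best_i = max(range(num_tags),
                         key=lambda i: prev[i] + transition[i][j])
            back.append(best_i)
            row.append(score[t][j] + prev[best_i] + transition[best_i][j])
        trellis.append(row)
        backpointers.append(back)

    # Start from the best final tag and follow the backpointers to step 0.
    last = max(range(num_tags), key=lambda j: trellis[-1][j])
    path = [last]
    for back in reversed(backpointers):
        path.append(back[path[-1]])
    path.reverse()
    return path, trellis[-1][last]


score = [[1, 2, 3], [2, 1, 3], [1, 3, 2], [3, 2, 1]]
transition = [[2, 1, 3], [1, 3, 2], [3, 2, 1]]
print(viterbi(score, transition))  # ([2, 0, 2, 0], 19)
```

On the test data from this article the sketch returns the same path and score as both TensorFlow functions.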