tf.contrib.crf.crf_log_likelihood说明

最近在 做一个 NER的项目,使用的是BILSTM+CRF 结构,github,求star。

现在 对 使用 tf.contrib.crf.crf_log_likelihood时,遇到的参数问题 说一下:

官方说明:https://www.tensorflow.org/code/stable/tensorflow/contrib/crf/python/ops/crf.py

tf.contrib.crf.crf_log_likelihood(
    inputs,
    tag_indices,
    sequence_lengths,
    transition_params=None
)

Args:

  • inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials to use as input to the CRF layer.
  • tag_indices: A [batch_size, max_seq_len] matrix of tag indices for which we compute the log-likelihood.
  • sequence_lengths: A [batch_size] vector of true sequence lengths.
  • transition_params: A [num_tags, num_tags] transition matrix, if available.

Returns:

  • log_likelihood: A [batch_size] Tensor containing the log-likelihood of each example, given the sequence of tag indices.
  • transition_params: A [num_tags, num_tags] transition matrix. This is either provided by the caller or created in this function.

下面只说入参:

inputs: 经过BILSTM层  处理后的 数据,格式为  [batch_size, max_seq_len, num_tags]

tag_indices: 就是 整个 项目的 入参

sequence_lengths: 该参数 是主要说明的,英文 直译过来就是:包括实际序列长度,形状为[batch_size] 的向量。下面会详细说的

transition_params:状态转移矩阵

详细说下:sequence_lengths

  首先,请记住 sequence_lengths是一个向量,

下面举个例子:

比如:batch_size=4, max_seq_len=5

那么,最终的 sequence_lengths 为[v1,v2,v3,v4] 且 v1<=5,v2<=5,v3<=5,v4<=5,好了,大概格式 和 数字的范围到现在已经知道了,那么 这些 v值,是 怎么确认的呢?

在NLP中 有很多句子大于max_seq_len,或者小于max_seq_len。对于大于max_seq_len的句子直接截取 为长度为max_seq_len的句子即可,在截取后的 句子中的每一个 词 都是有效的。但是 对于小于max_seq_len的句子,此时就需要 padding了,padding的词 都是无意义的,只是 为了 形成 进入NN的结构。所以 此处v的值就是 记录 该 句子 未padding前的 真实的长度。明白了吧。

到此,实际上 你已经可以正常使用这个 API了,但是,如果 你还要 问,为什么 是 这中格式呢?那咱们 继续 看源码:

   打开 上面 github的地址,

 crf_log_likelihood->crf_sequence_score->crf_unary_score 找到此方法的 293行

  masks = array_ops.sequence_mask(sequence_lengths,
                                  maxlen=array_ops.shape(tag_indices)[1],
                                  dtype=dtypes.float32)

然后找到sequence_mask源码:https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/array_ops.py

3063行  sequence_mask:

  以上面的例子 来说该方法,该方法 返回一个 sequence_lengths * max_seq_len 的矩阵,也即 4*5 的mask矩阵,该矩阵 用来 计算 后续 损失时,将 无效 词和tag 去除。这里面的值 都是 如何 形成的呢,Aij,i=0,1,2,3,j=0,1,2,3,4 其中i为 sequence_length索引,j=range(max_seq_len)

Aij=true if j<sequence_length[i] else false

下面是例子,明白了吧

  tf.sequence_mask([1, 3, 2], 5)  # [[True, False, False, False, False],
                                  #  [True, True, True, False, False],
                                  #  [True, True, False, False, False]]

 

知乎: https://zhuanlan.zhihu.com/albertwang

微信公众号:AI-Research-Studio

https://img-blog.csdnimg.cn/20190110102516916.png ​​

下面是赞赏码

  • 2
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值