keras 的实现unilm的核心代码讲解

在苏神写的unlim代码,本身由于keras不友好的构件图逻辑判断,所以没办法只能按照原始tensorflow去重新理解一下,为torch的模型蒸馏提供基础。

   首先我们假设Input-Segment的数值为:

a=tf.constant([[0,0,0,0,1,1,1,1,1,1]])

之所以第一个句子为0,原因是下边我们需要计算第二个句子预测每一时刻time_step需要几个单词信息第一行代码为:

 idxs = K.cumsum(a, axis=1)

此行代码得到的结果为:

[[0 0 0 0 1 2 3 4 5 6]]

所以是对于两个句子的逐渐递增,数值的作用就是来说明每个时刻需要几个字符的信息来预测当前的字符。

第二行代码应该是unlim实现的核心代码,比较简洁

 mask = idxs[:, None, :] <= idxs[:, :, None]

一行代码,但是涵盖的信息比较丰富,直接给出结果就了解了,最后得到的结果为:

[[[ True  True  True  True False False False False False False]
  [ True  True  True  True False False False False False False]
  [ True  True  True  True False False False False False False]
  [ True  True  True  True False False False False False False]
  [ True  True  True  True  True False False False False False]
  [ True  True  True  True  True  True False False False False]
  [ True  True  True  True  True  True  True False False False]
  [ True  True  True  True  True  True  True  True False False]
  [ True  True  True  True  True  True  True  True  True False]
  [ True  True  True  True  True  True  True  True  True  True]]]

很明显就是我们要实现unlim对两个句子mask的结果。  然后把bool类型转化为float类型               

 mask = K.cast(mask, K.floatx())

结果为:

[[[1. 1. 1. 1. 0. 0. 0. 0. 0. 0.]
  [1. 1. 1. 1. 0. 0. 0. 0. 0. 0.]
  [1. 1. 1. 1. 0. 0. 0. 0. 0. 0.]
  [1. 1. 1. 1. 0. 0. 0. 0. 0. 0.]
  [1. 1. 1. 1. 1. 0. 0. 0. 0. 0.]
  [1. 1. 1. 1. 1. 1. 0. 0. 0. 0.]
  [1. 1. 1. 1. 1. 1. 1. 0. 0. 0.]
  [1. 1. 1. 1. 1. 1. 1. 1. 0. 0.]
  [1. 1. 1. 1. 1. 1. 1. 1. 1. 0.]
  [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]]

  然后把0的部份利用极小值替代,在整个attention以及self-attention过程中都不参与当前时刻time_step的预测:

mask=-(1 - mask[:, None]) * 1e12

 得到结果为

[[[[ 0.e+00  0.e+00  0.e+00  0.e+00 -1.e+12 -1.e+12 -1.e+12 -1.e+12
    -1.e+12 -1.e+12]
   [ 0.e+00  0.e+00  0.e+00  0.e+00 -1.e+12 -1.e+12 -1.e+12 -1.e+12
    -1.e+12 -1.e+12]
   [ 0.e+00  0.e+00  0.e+00  0.e+00 -1.e+12 -1.e+12 -1.e+12 -1.e+12
    -1.e+12 -1.e+12]
   [ 0.e+00  0.e+00  0.e+00  0.e+00 -1.e+12 -1.e+12 -1.e+12 -1.e+12
    -1.e+12 -1.e+12]
   [ 0.e+00  0.e+00  0.e+00  0.e+00  0.e+00 -1.e+12 -1.e+12 -1.e+12
    -1.e+12 -1.e+12]
   [ 0.e+00  0.e+00  0.e+00  0.e+00  0.e+00  0.e+00 -1.e+12 -1.e+12
    -1.e+12 -1.e+12]
   [ 0.e+00  0.e+00  0.e+00  0.e+00  0.e+00  0.e+00  0.e+00 -1.e+12
    -1.e+12 -1.e+12]
   [ 0.e+00  0.e+00  0.e+00  0.e+00  0.e+00  0.e+00  0.e+00  0.e+00
    -1.e+12 -1.e+12]
   [ 0.e+00  0.e+00  0.e+00  0.e+00  0.e+00  0.e+00  0.e+00  0.e+00
     0.e+00 -1.e+12]
   [ 0.e+00  0.e+00  0.e+00  0.e+00  0.e+00  0.e+00 -0.e+00 -0.e+00
    -0.e+00 -0.e+00]]]]


 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值