在苏神写的unlim代码,本身由于keras不友好的构件图逻辑判断,所以没办法只能按照原始tensorflow去重新理解一下,为torch的模型蒸馏提供基础。
首先我们假设Input-Segment的数值为:
a=tf.constant([[0,0,0,0,1,1,1,1,1,1]])
之所以第一个句子为0,原因是下边我们需要计算第二个句子预测每一时刻time_step需要几个单词信息第一行代码为:
idxs = K.cumsum(a, axis=1)
此行代码得到的结果为:
[[0 0 0 0 1 2 3 4 5 6]]
所以是对于两个句子的逐渐递增,数值的作用就是来说明每个时刻需要几个字符的信息来预测当前的字符。
第二行代码应该是unlim实现的核心代码,比较简洁
mask = idxs[:, None, :] <= idxs[:, :, None]
一行代码,但是涵盖的信息比较丰富,直接给出结果就了解了,最后得到的结果为:
[[[ True True True True False False False False False False]
[ True True True True False False False False False False]
[ True True True True False False False False False False]
[ True True True True False False False False False False]
[ True True True True True False False False False False]
[ True True True True True True False False False False]
[ True True True True True True True False False False]
[ True True True True True True True True False False]
[ True True True True True True True True True False]
[ True True True True True True True True True True]]]
很明显就是我们要实现unlim对两个句子mask的结果。 然后把bool类型转化为float类型
mask = K.cast(mask, K.floatx())
结果为:
[[[1. 1. 1. 1. 0. 0. 0. 0. 0. 0.]
[1. 1. 1. 1. 0. 0. 0. 0. 0. 0.]
[1. 1. 1. 1. 0. 0. 0. 0. 0. 0.]
[1. 1. 1. 1. 0. 0. 0. 0. 0. 0.]
[1. 1. 1. 1. 1. 0. 0. 0. 0. 0.]
[1. 1. 1. 1. 1. 1. 0. 0. 0. 0.]
[1. 1. 1. 1. 1. 1. 1. 0. 0. 0.]
[1. 1. 1. 1. 1. 1. 1. 1. 0. 0.]
[1. 1. 1. 1. 1. 1. 1. 1. 1. 0.]
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]]
然后把0的部份利用极小值替代,在整个attention以及self-attention过程中都不参与当前时刻time_step的预测:
mask=-(1 - mask[:, None]) * 1e12
得到结果为
[[[[ 0.e+00 0.e+00 0.e+00 0.e+00 -1.e+12 -1.e+12 -1.e+12 -1.e+12
-1.e+12 -1.e+12]
[ 0.e+00 0.e+00 0.e+00 0.e+00 -1.e+12 -1.e+12 -1.e+12 -1.e+12
-1.e+12 -1.e+12]
[ 0.e+00 0.e+00 0.e+00 0.e+00 -1.e+12 -1.e+12 -1.e+12 -1.e+12
-1.e+12 -1.e+12]
[ 0.e+00 0.e+00 0.e+00 0.e+00 -1.e+12 -1.e+12 -1.e+12 -1.e+12
-1.e+12 -1.e+12]
[ 0.e+00 0.e+00 0.e+00 0.e+00 0.e+00 -1.e+12 -1.e+12 -1.e+12
-1.e+12 -1.e+12]
[ 0.e+00 0.e+00 0.e+00 0.e+00 0.e+00 0.e+00 -1.e+12 -1.e+12
-1.e+12 -1.e+12]
[ 0.e+00 0.e+00 0.e+00 0.e+00 0.e+00 0.e+00 0.e+00 -1.e+12
-1.e+12 -1.e+12]
[ 0.e+00 0.e+00 0.e+00 0.e+00 0.e+00 0.e+00 0.e+00 0.e+00
-1.e+12 -1.e+12]
[ 0.e+00 0.e+00 0.e+00 0.e+00 0.e+00 0.e+00 0.e+00 0.e+00
0.e+00 -1.e+12]
[ 0.e+00 0.e+00 0.e+00 0.e+00 0.e+00 0.e+00 -0.e+00 -0.e+00
-0.e+00 -0.e+00]]]]