这是一个知乎上的经典问题,为什么 Bert 的三个 Embedding 可以进行相加?
其中,苏剑林老师的解释感觉很有意思:
Embedding的数学本质,就是以one hot为输入的单层全连接。请参考: https://kexue.fm/archives/4122
也就是说,世界上本没什么Embedding,有的只是one hot。
现在我们将token,position,segment三者都用one hot表示,然后concat起来,然后才去过一个单层全连接,等价的效果就是三个Embedding相加
简单理解起来就是:三个向量concat之后走一次全连接,等价于各自embedding之后相加。
数学解释
对于token,seg,pos的三个one-hot向量:v1T,v2T,v3T。先contact在全连接,等价于对应embedding之和
[ v 1 T , v 2 T , v 3 T ] W = [ v 1 T , v 2 T , v 3 T ] [ w 1 , w 2 , w 3 ] = v 1 T w 1 + v 2 T w 2 + v 3 T w 3 [v1^T,v2^T,v3^T]W=[v1^T,v2^T,v3^T][w1,w2,w3]=v1^Tw1+v2^Tw2+v3^Tw3 [v1T,v2T,v3T]W=[v1T,v2T,v3T][w1,w2,w3]=v1Tw1+v2Tw2+v3Tw3
代码解释
首先我们简单地假设我们有一个token,我们假设我们的字典大小(vocabulary_size) = 5, 对应的的token_id 是2,这个token所在的位置是第0个位置,我们最大的位置长度为max_position_size = 6,以及我们可以有两种segment,这个token是属于segment = 0的情况。首先我们分别对三种不同类型的分别进行 embedding lookup的操作,
下面的代码中我们,固定了三种类型的embedding matrix,分别是token_embedding,position_embedding,segment_embedding。
首先我们要清楚,正常的embedding lookup就是embedding id 进行onehot之后,然后在和embedding matrix 进行矩阵相乘,具体看例子中的 embd_embd_onehot_impl 和 embd_token,这两个的结果是一致的。
我们分别得到了三个类别数据的embedding之后(embd_token, embd_position, embd_sum),再将它们进行相加,得到embd_sum。
其结果跟,将三个类别进行onehot之后的结果concat起来,再进行embedding lookup的结果是一致的。
比如下面,我们将token_id_onehot, position_id_onehot, segment_id_onehot 这三个onehot后的结果concat起来得到concat_id_onehot, 与三者的embedding matrix的concat后的结果concat_embedding,进行矩阵相乘,得到的结果 embd_cat。
import tensorflow as tf
token_id = 2
vocabulary_size = 5
position = 0
max_position_size = 6
segment_id = 0
segment_size = 2
embedding_size = 4
token_embedding = tf.constant([[-3.,-2,-1, 0],[1,2,3,4], [5,6,7,8], [9,10, 11,12], [13,14,15,16]]) #size: (vocabulary_size, embedding_size)
position_embedding = tf.constant([[17., 18, 19, 20], [21,22,23,24], [25,26,27,28], [29,30,31,32], [33,34,35,36], [37,38,39,40]]) #size:(max_position_size, embedding_size)
segment_embedding = tf.constant([[41.,42,43,44], [45,46,47,48]]) #size:(segment_size, embedding_size)
token_id_onehot = tf.one_hot(token_id, vocabulary_size)
position_id_onehot = tf.one_hot(position, max_position_size)
segment_id_onehot = tf.one_hot(segment_id, segment_size)
embd_embd_onehot_impl = tf.matmul([token_id_onehot], token_embedding)
embd_token = tf.nn.embedding_lookup(token_embedding, token_id)
embd_position = tf.nn.embedding_lookup(position_embedding, position)
embd_segment = tf.nn.embedding_lookup(segment_embedding, segment_id)
embd_sum = tf.reduce_sum([embd_token, embd_position, embd_segment], axis=0)
concat_id_onehot = tf.concat([token_id_onehot, position_id_onehot, segment_id_onehot], axis=0)
concat_embedding = tf.concat([token_embedding, position_embedding, segment_embedding], axis=0)
embd_cat = tf.matmul([concat_id_onehot], concat_embedding)
with tf.Session() as sess:
print(sess.run(embd_embd_onehot_impl)) # [[5. 6. 7. 8.]]
print(sess.run(embd_token)) # [5. 6. 7. 8.]
print(sess.run(embd_position)) # [17. 18. 19. 20.]
print(sess.run(embd_segment)) # [41. 42. 43. 44.]
print(sess.run(embd_sum)) # [63. 66. 69. 72.]
print(sess.run(embd_cat)) # [[63. 66. 69. 72.]] # 结果一样