tf.nn.embedding_lookup函数的工作原理
函数定义:
tf.nn.embedding_lookup(
params,
ids,
partition_strategy='mod',
name=None,
validate_indices=True,
max_norm=None
)
官方解释:
This function is used to perform parallel lookups on the list of tensors in params
,where params
is interpreted as a partitioning of a large embedding tensor.
操作方式:
该函数按照ids
顺序返回params
中的第ids
行。
实例解释:
如上例所示,输入数据为
[
0
,
0
,
0
,
1
,
0
]
[0, 0, 0, 1, 0]
[0,0,0,1,0],对params
执行乘法操作,由于输入是one hot 的原因,
x
⋅
W
x·W
x⋅W的矩阵乘法看起来就像是取了
W
W
W中对应的一行,看起来就像是在查表。这个实例中ids = 3,查表得到第四行数据
[
10
,
12
,
19
]
[10, 12, 19]
[10,12,19]。
实现过程:
针对输入是超高维向量,但是是one hot向量的一种特殊的全连接层的实现方法,其内部实际是包含一个网络结构的,如下图所示。
假设我们想要找到2的embedding值,这个值其实是输入层第二个神经元与embedding层连线的权重值。
使用该函数时,params
给定的初始值是随机的,即全连接层的权值,但Embedding矩阵(params)会跟随网络中的其它参数一起训练,最终得到合适的Embedding向量,该过程可以类比word2vec学习词向量的过程,Embedding向量就是网络的副产品,即神经元与embedding层连线的权重值。