embedding_lookup是什么?
tf.nn.embedding_lookup函数的用法主要是选取一个张量里面索引对应的元素。embedding就是将输入文本表达成向量形式。向量化后需要用索引来查询对应向量,embedding_lookup就是帮助开发者来完成索引向量查询的。
为什么用embedding_lookup?
tf框架里,它会为输入的张量自动建立one-hot索引,但建立好的索引该如何与之后embedding向量对应起来并查询呢?这就需要通过索引-向量mapping表中去拿,此时,embedding_lookup就会帮助你完成这个操作。其实embedding_lookup本质是做了一次常规的线性变换,Z = WX + b。相当于通过one-hot的Weight矩阵,帮助使用者取出了矩阵中对应的那一行。相当于变相进行了一次矩阵相乘运算。看起来像查表一样。
什么时候使用embedding_lookup?
通常在训练之初,开始进行embedding的时候。看一下embedding_lookup的定义。
tf.nn.embedding_lookup(params, ids, partition_strategy='mod', name=None, validate_indices=True, max_norm=None)
TensorFlow官方文档定义见这里
其中params输入整型矩阵,用于给出索引embedding的idx。注意输出结果的shape其实相当于[params.shape(), embedding_shape()]
怎么使用embedding_lookup?
直接使用代码更容易说明:
import tensorflow as tf
import numpy as np
#a = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 2.2, 2.3], [3.1, 3.2, 3.3], [4.1, 4.2, 4.3]]
a = tf.Variable(np.identity(6, dtype=np.int32)) #np.asarray(a)
idx1 = tf.Variable([0, 2, 3, 1], tf.int32)
idx2 = tf.Variable([[0, 2, 3], [0, 2, 2]], tf.int32)
out1 = tf.nn.embedding_lookup(a, idx1)
out2 = tf.nn.embedding_lookup(a, idx2)
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
print(sess.run(out1))
print(out1)
print('=' * 30)
print(sess.run(out2))
print(out2)
输出如下:
[[1 0 0 0 0 0]
[0 0 1 0 0 0]
[0 0 0 1 0 0]
[0 1 0 0 0 0]]
Tensor("embedding_lookup/Identity:0", shape=(4, 6), dtype=int32)
==============================
[[[1 0 0 0 0 0]
[0 0 1 0 0 0]
[0 0 0 1 0 0]]
[[1 0 0 0 0 0]
[0 0 1 0 0 0]
[0 0 1 0 0 0]]]
Tensor("embedding_lookup_1/Identity:0", shape=(2, 3, 6), dtype=int32)
之后,我们的out1和out2其实就可以作为训练的输入向量进行训练了。