tf.nn.embedding_lookup(params, ids)
params: 一个张量或者数组
ids: 一个整型列表或一个二维矩阵,当输入为二维矩阵的时候,在CNN的时候会用,批量输入的时候ids为二维矩阵
该函数的目的是从params矩阵中返回行索引=ids中的元素的行向量组成矩阵
ids输入为一维列表的时候
import tensorflow as tf
table = tf.Variable(tf.random_normal([10, 5]))
b = tf.nn.embedding_lookup(table, [1, 4, 6, 7])
with tf.Session()as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(table))
print(sess.run(b))
输出
[[ 0.2844954 1.0876138 0.2640958 -1.3939503 1.9493129 ]
[-0.34022513 -0.22206968 0.19959041 -0.43038854 0.7214721 ]
[ 1.2583389 -0.41636813 0.5526711 -0.04547537 -2.220672 ]
[ 0.6416701 -0.04626859 -1.2670921 -1.0875092 -1.1969252 ]
[-0.9369289 0.01590852 -1.0708148 -1.0230598 0.6950529 ]
[-1.109506 0.43983954 1.1148814 0.48612115 -0.22546312]
[ 0.7978611 -0.32981223 0.9465104 0.11148026 -0.8291709 ]
[ 1.7482463 -0.84183437 -0.5938833 1.2219574 1.6940571 ]
[ 0.3316857 -0.0637491 1.3450751 1.5049508 -0.66448265]
[-0.56729424 -0.5770627 1.1358143 0.52266353 -2.49519 ]]
[[-0.34022513 -0.22206968 0.19959041 -0.43038854 0.7214721 ]
[-0.9369289 0.01590852 -1.0708148 -1.0230598 0.6950529 ]
[ 0.7978611 -0.32981223 0.9465104 0.11148026 -0.8291709 ]
[ 1.7482463 -0.84183437 -0.5938833 1.2219574 1.6940571 ]]
ids为二维矩阵
import tensorflow as tf
table = tf.Variable(tf.random_normal([10, 5]))
b = tf.nn.embedding_lookup(table, [[1, 4, 6, 7],[1,4,2,5]])
with tf.Session()as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(table))
print(sess.run(b))