import numpy as np
import tensorflow as tf
# Embedding table: 5 rows (vocab_size) x 8 columns (embedding_size),
# filled with uniform random values in [0, 1).
c = np.random.random([5, 8])  # [vocab_size, embedding_size]
embedding = tf.Variable(c)
# 2-D ids of shape [batch_size, sentence_max_len] = [2, 2]; each id selects one
# row of `embedding`, so the result is [batch_size, sentence_max_len, embedding_size]
# = [2, 2, 8].
b = tf.nn.embedding_lookup(embedding, [[1, 3], [2, 4]])  # inputs: [batch_size, sentence_max_len]
print(b.shape)  # static shape, available before the graph is run
# NOTE: tf.Session / tf.global_variables_initializer are TF 1.x graph-mode APIs
# (under TF 2.x they live in tf.compat.v1 and need eager execution disabled).
with tf.Session() as sess:
    # Variables must be initialised before they can be read in graph mode.
    sess.run(tf.global_variables_initializer())
    print(sess.run(embedding))
    print(sess.run(b))
(2, 2, 8)
[[0.95798472 0.74286751 0.7751764 0.86245725 0.53302332 0.69279576
0.00691748 0.55526062]
[0.31221902 0.89075832 0.83511314 0.62149053 0.76858983 0.96638756
0.3284697 0.54494426]
[0.1799067 0.91003457 0.04127924 0.26612211 0.82033345 0.2144692
0.54308715 0.05850638]
[0.20705396 0.59211751 0.81545149 0.6189804 0.33866672 0.70442117
0.22983621 0.64041192]
[0.14165099 0.62781607 0.72576032 0.39587963 0.60724242 0.47600276
0.98226275 0.49256804]]
[[[0.31221902 0.89075832 0.83511314 0.62149053 0.76858983 0.96638756
0.3284697 0.54494426]
[0.20705396 0.59211751 0.81545149 0.6189804 0.33866672 0.70442117
0.22983621 0.64041192]]
[[0.1799067 0.91003457 0.04127924 0.26612211 0.82033345 0.2144692
0.54308715 0.05850638]
[0.14165099 0.62781607 0.72576032 0.39587963 0.60724242 0.47600276
0.98226275 0.49256804]]]
tf.nn.embedding_lookup(embedding_dict,inputs)
就是以inputs中的数值作为索引，从embedding_dict中取出对应的行（即对应的embedding向量）。inputs是二维时，形状一般为[batch_size, sentence_max_len]，
则输出为[batch_size, sentence_max_len, embedding_size]
如果inputs是一维的，就更好理解了：inputs就是需要映射的index列表，
则输出为[len(inputs), embedding_size]
import numpy as np
import tensorflow as tf
# Embedding table: 5 rows (vocab_size) x 8 columns (embedding_size),
# filled with uniform random values in [0, 1).
c = np.random.random([5, 8])  # [vocab_size, embedding_size]
embedding = tf.Variable(c)
# 1-D ids: a plain list of row indices, shape [len(inputs)] = [2].
# Rows 1 and 4 are gathered, so the result is [len(inputs), embedding_size] = [2, 8].
b = tf.nn.embedding_lookup(embedding, [1, 4])  # inputs: [len(inputs)] (1-D index list)
print(b.shape)  # static shape, available before the graph is run
# NOTE: tf.Session / tf.global_variables_initializer are TF 1.x graph-mode APIs
# (under TF 2.x they live in tf.compat.v1 and need eager execution disabled).
with tf.Session() as sess:
    # Variables must be initialised before they can be read in graph mode.
    sess.run(tf.global_variables_initializer())
    print(sess.run(embedding))
    print(sess.run(b))
(2, 8)
[[0.01986623 0.79118754 0.54255391 0.79699991 0.70534842 0.64375866
0.19962736 0.89929266]
[0.52227914 0.50747695 0.64987427 0.55481673 0.0735536 0.582565
0.65931548 0.41810564]
[0.22290179 0.8301516 0.26772186 0.23693508 0.78434498 0.64220977
0.4993044 0.54788336]
[0.80896985 0.1762523 0.72304219 0.12117844 0.90894192 0.55810879
0.52866756 0.98105626]
[0.71503674 0.7357023 0.58702779 0.67004222 0.37100633 0.65433607
0.20375213 0.38989389]]
[[0.52227914 0.50747695 0.64987427 0.55481673 0.0735536 0.582565
0.65931548 0.41810564]
[0.71503674 0.7357023 0.58702779 0.67004222 0.37100633 0.65433607
0.20375213 0.38989389]]