When the lookup table is a 2-D tensor
The code (TensorFlow 1.x) is as follows:
import tensorflow as tf
import numpy as np

# A 1-D placeholder holding the lookup indices
input_ids = tf.placeholder(tf.int32, shape=[None], name="input_ids")
# A 5x5 identity matrix serving as the embedding table
embedding = tf.Variable(np.identity(5, dtype=np.int32))
# Gathers the rows of `embedding` selected by `input_ids`
input_embedding = tf.nn.embedding_lookup(embedding, input_ids)

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
print("embedding=\n", embedding.eval())
print("input_embedding=\n", sess.run(input_embedding, feed_dict={input_ids: [1, 2, 3, 0, 3, 2, 1]}))
Output:
embedding=
[[1 0 0 0 0]
[0 1 0 0 0]
[0 0 1 0 0]
[0 0 0 1 0]
[0 0 0 0 1]]
input_embedding=
[[0 1 0 0 0]
[0 0 1 0 0]
[0 0 0 1 0]
[1 0 0 0 0]
[0 0 0 1 0]
[0 0 1 0 0]
[0 1 0 0 0]]
[Finished in 3.8s]
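For a 2-D table with 1-D ids, embedding_lookup behaves exactly like plain row indexing. A minimal NumPy-only sketch (no TensorFlow needed) reproduces the printout above:

```python
import numpy as np

# Same 5x5 identity table as the example above
embedding = np.identity(5, dtype=np.int32)
input_ids = [1, 2, 3, 0, 3, 2, 1]

# embedding_lookup over a 2-D table with 1-D ids is just row indexing:
# row i of the result is embedding[input_ids[i]]
result = embedding[input_ids]
print(result)
```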
When the lookup indices are 2-D
The code is as follows:
import tensorflow as tf
import numpy as np
input_ids = tf.placeholder(dtype=tf.int32, shape=[3, 2])
embedding = tf.Variable(np.identity(5, dtype=np.int32))
input_embedding = tf.nn.embedding_lookup(embedding, input_ids)
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
print("embedding=\n", embedding.eval())
print("input_embedding=\n", sess.run(input_embedding, feed_dict={input_ids: [[1, 2], [2, 1], [3, 3]]}))
The output is as follows:
embedding=
[[1 0 0 0 0]
[0 1 0 0 0]
[0 0 1 0 0]
[0 0 0 1 0]
[0 0 0 0 1]]
input_embedding=
[[[0 1 0 0 0]
[0 0 1 0 0]]
[[0 0 1 0 0]
[0 1 0 0 0]]
[[0 0 0 1 0]
[0 0 0 1 0]]]
[Finished in 4.0s]
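With 2-D ids the result simply gains a dimension; the general rule is output.shape == ids.shape + params.shape[1:]. A NumPy-only sketch of the same lookup:

```python
import numpy as np

embedding = np.identity(5, dtype=np.int32)
ids = np.array([[1, 2], [2, 1], [3, 3]])

# Each id is replaced by its row, so the output shape is
# ids.shape + embedding.shape[1:] == (3, 2) + (5,) == (3, 2, 5)
out = embedding[ids]
print(out.shape)
```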
When the lookup table is a 3-D tensor
I tried a simple experiment myself; looking at the code and its result makes this easy to understand:
import tensorflow as tf
import numpy as np
input_ids = tf.placeholder(tf.int32, shape=[None], name="input_ids")
# Build a 4x2x3 lookup table holding the values 1..24
a = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24]
aa = tf.reshape(a, [4, 2, 3])
embedding = tf.Variable(aa)
input_embedding = tf.nn.embedding_lookup(embedding, input_ids)
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
print("embedding=\n", embedding.eval())
print("input_embedding=\n", sess.run(input_embedding, feed_dict={input_ids: [1, 2]}))
The output is as follows:
embedding=
[[[ 1 2 3]
[ 4 5 6]]
[[ 7 8 9]
[10 11 12]]
[[13 14 15]
[16 17 18]]
[[19 20 21]
[22 23 24]]]
input_embedding=
[[[ 7 8 9]
[10 11 12]]
[[13 14 15]
[16 17 18]]]
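With a 3-D table, each id selects a whole 2x3 slice along axis 0. A NumPy sketch of the same experiment:

```python
import numpy as np

# The same 4x2x3 table of values 1..24 as above
table = np.arange(1, 25, dtype=np.int32).reshape(4, 2, 3)

# ids [1, 2] pick out table[1] and table[2], each a 2x3 slice
out = table[[1, 2]]
print(out)
```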
When the lookup indices are 2-D
The code is as follows:
import tensorflow as tf
import numpy as np
input_ids = tf.placeholder(tf.int32, shape=[2,3], name="input_ids")
a=[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24]
aa=tf.reshape(a,[4,2,3])
embedding = tf.Variable(aa)
input_embedding = tf.nn.embedding_lookup(embedding, input_ids)
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
print("input_embedding=\n", sess.run(input_embedding, feed_dict={input_ids: [[1,2,3], [3,2,1]]}))
The output is as follows (the printout of the original table is the same as above and is omitted); each id is replaced by the corresponding 2x3 slice, so the result is a tensor of shape (2, 3, 2, 3).
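The omitted output can be worked out from the shape rule: with ids of shape (2, 3) and a (4, 2, 3) table, the result has shape (2, 3) + (2, 3) = (2, 3, 2, 3). A NumPy sketch:

```python
import numpy as np

table = np.arange(1, 25, dtype=np.int32).reshape(4, 2, 3)
ids = np.array([[1, 2, 3], [3, 2, 1]])

# output.shape == ids.shape + table.shape[1:]
out = table[ids]
print(out.shape)
print(out[0, 0])  # the slice for id 1: [[7 8 9], [10 11 12]]
```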
Result of running the AMHEN algorithm example:
epoch 9: 23%|| 1606/7066 [00:10<00:35, 155.83it/s]
The bold fields are the ones that keep changing while the program runs: this is tqdm's dynamic progress display.
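tqdm achieves this by rewriting a single line with carriage returns rather than printing new lines. A minimal stdlib-only sketch of the same mechanism (the `progress` function is made up for illustration, not tqdm's API):

```python
import sys
import time

# Minimal in-place progress line, mimicking how tqdm redraws its bar:
# '\r' returns the cursor to the start of the line, so each write
# overwrites the previous one instead of printing a new line.
def progress(iterable, total, desc="epoch"):
    start = time.time()
    for n, item in enumerate(iterable, 1):
        rate = n / max(time.time() - start, 1e-9)
        sys.stdout.write("\r%s: %3d%%| %d/%d [%.1fit/s]"
                         % (desc, 100 * n // total, n, total, rate))
        sys.stdout.flush()
        yield item
    sys.stdout.write("\n")

for _ in progress(range(50), total=50, desc="epoch 9"):
    pass
```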