#tf.nn.embedding_lookup:可以把这个函数的数据切片原理想象成numpy的花式索引。 #index_1 = tf.Variable([1, 2, 3, 0], tf.int32),以这个为例,就是按顺序,逐行取出a矩阵中第2,3,4,1,行 #index_2 = tf.Variable([[1, 2, 3, 0], [4, 1, 3, 3], [4, 1, 3, 3]], tf.int32) ;把index_1中取出的看做"一堆",那这个就是取了"三堆" #下面的小栗子,run一下,你就知道了 import tensorflow as tf import numpy as np np.random.seed(24) a = 0.1 * np.random.randint(1,10,size=(5,3)) index_1 = tf.Variable([1, 2, 3, 0], tf.int32) index_2 = tf.Variable([[1, 2, 3, 0], [4, 1, 3, 3], [4, 1, 3, 3]], tf.int32) out1 = tf.nn.embedding_lookup(a, index_1) out2 = tf.nn.embedding_lookup(a, index_2) init = tf.global_variables_initializer() with tf.Session() as sess: print('a:---------------\n', a) sess.run(init) print('out1:---------------\n',sess.run(out1)) print(out1) print(' -- * -- '*10) print('out2:---------------\n',sess.run(out2)) print(out2)
tf.nn.embedding_lookup 举例说明,简单易懂
最新推荐文章于 2023-09-18 20:48:41 发布