官网API
参数
tf.nn.embedding_lookup(
params, ids, partition_strategy='mod', name=None, validate_indices=True,
max_norm=None
)
- params:由一个tensor或者多个tensor组成的列表(多个tensor组成时,每个tensor除了第一个维度其他维度需相等);
- ids:一个类型为int32或int64的Tensor,包含要在params中查找的id;
- partition_strategy:逻辑index是由partition_strategy指定,partition_strategy用来设定ids的切分方式,目前有两种切分方式’div’和’mod’.
- name:操作名称(可选)
- validate_indices: 是否验证收集索引
- max_norm: 如果不是None,嵌入值将被l2归一化为max_norm的值
作用
选取一个张量里面索引对应的元素;
寻找的embedding data中的对应的行下的vector。
例子
#coding:utf-8
import tensorflow as tf
import numpy as np
c = np.random.random([5,1]) ##随机生成一个5*1的数组
b = tf.nn.embedding_lookup(c, [1, 3]) ##查找数组中的序号为1和3的
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
print(sess.run(b))
print(c)
输出的结果如下所示:
[[0.5687709 ]
[0.61091257]]
[[0.31777381]
[0.5687709 ]
[0.1779548 ]
[0.61091257]
[0.65478204]]
在c中第2个元素为0.5687709,第4个元素是0.61091257(索引从0开始),刚好是b的值