def embedding_lookup(
params,
ids,
partition_strategy="mod",
name=None,
validate_indices=True,
max_norm=None):
主要的两个参数的理解:
params:一个tensor形式,一般为embedding后的shape,在进行分类中是[vocab_size, embedding_size]形式的值
ids:这个表示的是int数字形式的tensor,我们会根据这个ids进行对params里的值进行查询操作
所以总结来说就是:tf.nn.embedding_lookup()就是根据input_ids中的id,寻找embeddings中的第id行。比如input_ids=[2,8,10],则找出embeddings中第2,8,10行,组成一个tensor返回。
下面给出一段代码示例:
import tensorflow as tf
import numpy as np
a = np.random.rand(5,3)
b = np.array([[1,2],[3,1]])
print(a)
print("*" * 10 + "上面是a的输出结果" + "*" * 10)
print(b)
print("*" * 10 + "上面是b的输出结果" + "*" * 10)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
c = tf.nn.embedding_lookup(a ,b)
print(sess.run(c))
print("*" * 10 + "上面是c的输出结果" + "*" * 10)
下面是输出的结果,从结果中我们可以很明显看到通过b的给出的ids值查找到了a中对应索引位置的值
[[0.84679692 0.06116903 0.12382439]
[0.49653969 0.74192834 0.07095517]
[0.6960441 0.10387642 0.23318349]
[0.43168901 0.5049978 0.83005329]
[0.7418392 0.54963284 0.03186683]]
**********上面是a的输出结果**********
[[1 2]
[3 1]]
**********上面是b的输出结果**********
[[[0.49653969 0.74192834 0.07095517]
[0.6960441 0.10387642 0.23318349]]
[[0.43168901 0.5049978 0.83005329]
[0.49653969 0.74192834 0.07095517]]]
**********上面是c的输出结果**********