Official documentation
tf.nn.embedding_lookup: looks up the embedding vectors (word vectors) in params for the given ids
tf.nn.embedding_lookup(
    params, ids, max_norm=None, name=None
)
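The max_norm argument is the only option above that changes values rather than shapes. According to the TF documentation, if it is not None, each looked-up embedding is clipped when its L2 norm is larger than max_norm. The sketch below emulates that behaviour in NumPy so it can be tried without TF. The helper name lookup_with_max_norm is hypothetical, and the sketch assumes a 2-D params table of shape (vocab, dim).

```python
import numpy as np

def lookup_with_max_norm(params, ids, max_norm=None):
    """Hypothetical NumPy sketch of the documented max_norm behaviour.

    Rows are looked up along axis 0; any looked-up vector whose L2 norm
    exceeds max_norm is rescaled so that its norm equals max_norm.
    """
    params = np.asarray(params, dtype=np.float64)
    assert params.ndim == 2, "sketch assumes params of shape (vocab, dim)"
    out = params[np.asarray(ids)]  # shape(ids) + (dim,)
    if max_norm is None:
        return out
    # L2 norm of each embedding vector (the last axis)
    norms = np.linalg.norm(out, axis=-1, keepdims=True)
    # Shrink only the vectors that are too long; leave the rest untouched
    scale = np.where(norms > max_norm, max_norm / np.maximum(norms, 1e-12), 1.0)
    return out * scale

table = [[3.0, 4.0], [0.3, 0.4]]  # row norms are 5.0 and 0.5
print(lookup_with_max_norm(table, [0, 1], max_norm=1.0))
```

The first row (norm 5.0) is scaled down to norm 1.0. The second row (norm 0.5) is already under the limit and comes back unchanged.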
Code walkthrough
Key rule: the output shape is shape(ids) + shape(params)[1:], i.e. every id is replaced by the row params[id]
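The shape rule can be written as a one-line function, which makes it easy to sanity-check each example below before running it. This is a plain-Python sketch. The function name lookup_output_shape is hypothetical and not part of TF.

```python
def lookup_output_shape(params_shape, ids_shape):
    """Shape rule of embedding_lookup: shape(ids) + shape(params)[1:].

    Axis 0 of params is the axis being indexed, so it is replaced by the
    full shape of ids; the remaining params axes are carried over unchanged.
    """
    return tuple(ids_shape) + tuple(params_shape)[1:]

# The four cases walked through in this post
print(lookup_output_shape((10,), (3,)))     # (3,)
print(lookup_output_shape((10,), (2, 1)))   # (2, 1)
print(lookup_output_shape((5, 2), (3,)))    # (3, 2)
print(lookup_output_shape((2, 5), (2, 2)))  # (2, 2, 5)
```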
Example 1: params and ids are both 1-D, so this is plain indexing
Input: params=(pdim1,), ids=(idim1,)
Output: shape=shape(ids), i.e. shape=(idim1,)
import tensorflow as tf
params = tf.range(1, 11)
ids = [0, 3, 4]
# Output: tf.Tensor([ 1  2  3  4  5  6  7  8  9 10], shape=(10,), dtype=int32)
print(params)
y = tf.nn.embedding_lookup(params, ids)
# Output: tf.Tensor([1 4 5], shape=(3,), dtype=int32)
print(y)
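With a single params tensor and max_norm=None, the lookup amounts to a gather along axis 0, just like tf.gather(params, ids). NumPy's integer-array indexing and np.take(..., axis=0) follow the same rule. That makes them a convenient way to predict expected outputs without running TF. This is a NumPy analogue, not TF itself.

```python
import numpy as np

params = np.arange(1, 11)  # same values as tf.range(1, 11)
ids = [0, 3, 4]

y_index = params[ids]                   # integer-array indexing along axis 0
y_take = np.take(params, ids, axis=0)   # explicit gather along axis 0

print(y_index)  # [1 4 5]
print(y_take)   # [1 4 5]
```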
Example 2: params is 1-D, ids is multi-dimensional
Input: params=(pdim1,), ids=(idim1, idim2)
Output: shape=(idim1, idim2)
import tensorflow as tf
params = tf.range(1, 11)
ids = [[0], [1]]
# Output: tf.Tensor([ 1  2  3  4  5  6  7  8  9 10], shape=(10,), dtype=int32)
print(params)
y = tf.nn.embedding_lookup(params, ids)
# Output: tf.Tensor(
# [[1]
#  [2]], shape=(2, 1), dtype=int32)
print(y)
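Example 2 shows that the nesting of ids survives the lookup: each integer is swapped for params[id], and the surrounding list structure stays put. A recursive pure-Python sketch makes this explicit. The helper nested_lookup is hypothetical. Because Python lists would otherwise accept negative indices, it rejects ids outside [0, len(params)) to stay closer to TF.

```python
def nested_lookup(params, ids):
    """Pure-Python sketch: replace every int in a nested list of ids by params[id]."""
    if isinstance(ids, int):
        # Python lists accept negative indices; reject them to stay closer to TF
        if not 0 <= ids < len(params):
            raise IndexError(f"id {ids} not in [0, {len(params)})")
        return params[ids]
    # Recurse while keeping the nesting of ids: this is where shape(ids) comes from
    return [nested_lookup(params, i) for i in ids]

params = list(range(1, 11))               # [1, 2, ..., 10]
print(nested_lookup(params, [[0], [1]]))  # [[1], [2]]
```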
Example 3: params is multi-dimensional, ids is 1-D
Input: params=(pdim1, pdim2), ids=(idim1,)
Output: shape=(idim1, pdim2)
import tensorflow as tf
params = tf.range(1, 11)
params = tf.reshape(params, (5, 2))
ids = [0, 3, 4]
# Output: tf.Tensor(
# [[ 1  2]
#  [ 3  4]
#  [ 5  6]
#  [ 7  8]
#  [ 9 10]], shape=(5, 2), dtype=int32)
print(params)
y = tf.nn.embedding_lookup(params, ids)
# Output: tf.Tensor(
# [[ 1  2]
#  [ 7  8]
#  [ 9 10]], shape=(3, 2), dtype=int32)
print(y)
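Example 3 is the typical embedding-table case, where each id selects one row of params. Selecting rows this way gives the same result as multiplying a one-hot encoding of ids by params. This is why an embedding layer is often described as a dense layer applied to one-hot inputs; the lookup simply skips building the one-hot matrix. Below is a NumPy check using the table from example 3. It is an illustration of the equivalence, not how TF computes the lookup.

```python
import numpy as np

params = np.arange(1, 11).reshape(5, 2)  # same table as example 3
ids = np.array([0, 3, 4])

# One-hot encode ids: row k has a single 1 in column ids[k]
one_hot = np.eye(params.shape[0], dtype=params.dtype)[ids]  # shape (3, 5)

y_matmul = one_hot @ params  # (3, 5) @ (5, 2) -> (3, 2)
y_lookup = params[ids]       # direct row selection, shape (3, 2)

print(np.array_equal(y_matmul, y_lookup))  # True
```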
Example 4: params is multi-dimensional, ids is multi-dimensional
Input: params=(pdim1, pdim2), ids=(idim1, idim2)
Output: shape=(idim1, idim2, pdim2)
import tensorflow as tf
params = tf.range(1, 11)
params = tf.reshape(params, (2, 5))
ids = [[1, 0], [2, 3]]  # ids 2 and 3 are out of range: params has only 2 rows
# Output: tf.Tensor(
# [[ 1  2  3  4  5]
#  [ 6  7  8  9 10]], shape=(2, 5), dtype=int32)
print(params)
y = tf.nn.embedding_lookup(params, ids)
# Output on GPU (on CPU this call raises InvalidArgumentError instead):
# tf.Tensor(
# [[[ 6  7  8  9 10]
#   [ 1  2  3  4  5]]
#
#  [[ 0  0  0  0  0]
#   [ 0  0  0  0  0]]], shape=(2, 2, 5), dtype=int32)
print(y)
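In the output above, ids 2 and 3 lie outside the 2 rows of params and come back as rows of zeros, which matches TF's GPU behaviour for out-of-range gather indices. The NumPy sketch below reproduces that result. The helper gather_zero_oob is hypothetical and only emulates the observed output; it is not how TF implements the lookup, and it assumes params is non-empty.

```python
import numpy as np

def gather_zero_oob(params, ids):
    """Hypothetical helper: gather rows of params along axis 0, writing zeros
    wherever an id falls outside [0, len(params)) (mimics the GPU output)."""
    params = np.asarray(params)
    ids = np.asarray(ids)
    valid = (ids >= 0) & (ids < params.shape[0])
    safe_ids = np.where(valid, ids, 0)  # any in-range placeholder for bad ids
    out = params[safe_ids]              # shape(ids) + shape(params)[1:]
    # Broadcast the mask over the trailing embedding axes, then zero invalid rows
    mask = valid.reshape(valid.shape + (1,) * (params.ndim - 1))
    return np.where(mask, out, 0)

params = np.arange(1, 11).reshape(2, 5)
y = gather_zero_oob(params, [[1, 0], [2, 3]])
print(y.shape)  # (2, 2, 5)
```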
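Note: when an id falls outside the range of params, the result depends on the device. Because the lookup is a gather underneath, tf.gather's documented behaviour applies. On GPU the corresponding output is filled with 0, as in example 4, where ids 2 and 3 exceed the 2 rows of params. On CPU an InvalidArgumentError is raised instead.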
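Out-of-range ids fail loudly on CPU but silently become zeros on GPU. It can therefore be worth validating ids before the lookup so that the behaviour is the same everywhere. Below is a NumPy sketch with a hypothetical helper check_ids. In TF code, tf.debugging.assert_non_negative and tf.debugging.assert_less can play a similar role.

```python
import numpy as np

def check_ids(ids, num_rows):
    """Hypothetical pre-check: raise if any id falls outside [0, num_rows)."""
    ids = np.asarray(ids)
    bad = ids[(ids < 0) | (ids >= num_rows)]  # flat array of offending ids
    if bad.size:
        raise ValueError(f"ids out of range [0, {num_rows}): {bad.tolist()}")
    return ids

params = np.arange(1, 11).reshape(2, 5)  # same table as example 4
try:
    check_ids([[1, 0], [2, 3]], params.shape[0])
except ValueError as err:
    print(err)  # ids out of range [0, 2): [2, 3]
```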