tf.nn.embedding_lookup:
定义
tf.nn.embedding_lookup(
params,
ids,
partition_strategy='mod',
name=None,
validate_indices=True,
max_norm=None
)
当params为一个tensor时,很容易理解,ids里的每个数字代表选取在params中0轴所在的对应index。是tf.gather的泛化版。
params = tf.constant([10,20,30,40])
ids = tf.constant([0,1,3])
print tf.nn.embedding_lookup(params,ids).eval()
return [10 20 40]
如上,根据[0,1,3]中的数字,选取[10,20,30,40]的0轴中对应index的[10 20 40]。
但当params是个list的tensor时,ids中的数字就要根据partition_strategy来进行划分。
params1 = tf.constant([1,2])
params2 = tf.constant([10,20])
ids = tf.constant([2,0,2,1,2,3])
result = tf.nn.embedding_lookup([params1, params2], ids)
return [ 2 1 2 10 2 20]
默认划分规则是mod,mod划分规则是这样:
数字0表示params1中的第一个元素
数字1表示params2中的第一个元素
数字2表示params1中的第二个元素
数字3表示params2中的第二个元素
以此类推。。。
这样在上代码中,ids中的2代表的params1中的第二个元素,也就是2;ids中的1代表的params2中的第一个元素,也就是10。
另一种划分规则是div,其规则是:
数字0表示params1中的第一个元素
数字1表示params1中的第二个元素
数字2表示params2中的第一个元素
数字3表示params2中的第二个元素
# tf.nn.embedding_lookup通常用来根据id提取embedding
embeddings = tf.nn.embedding_lookup(self.weights['feature_embeddings'], self.feat_ids)
# self.feat_ids可形如:shape[batch_size,ids_size]
#[[0,1,2],
[2,3,4]]
tf.gather:
定义:
tf.gather(
params,
indices,
validate_indices=None,
name=None,
axis=0
)
沿着axis,根据indices提取params的一个切片。indices 可以是任何shape的整数张量,更高级点的api有tf.batch_gather、tf.gather_nd(二维id).
emb_trans = tf.transpose(embeddings, [1, 0, 2])
emb_left = tf.gather(emb_trans, self.left_index)
tf.gather_nd:
# 根据索引提取数据,可用于topk的索引生成
import tensorflow as tf
index = tf.constant([[1],[1]])
values = tf.constant([[0.2, 0.8],[0.4, 0.6]])
index = tf.stack([tf.range(index.shape[0])[:, None], index], axis=2)
result = tf.gather_nd(values, index)
index.eval(session=tf.Session())
array([[[0, 1]],
[[1, 1]]], dtype=int32)
result.eval(session=tf.Session())
array([[0.8],
[0.6]], dtype=float32)
top_k索引
import tensorflow as tf
# Input data
a = tf.placeholder(tf.float32, [None, None])
num_top = tf.placeholder(tf.int32, [])
# Find top elements
a_top, a_top_idx = tf.nn.top_k(a, num_top, sorted=False)
# Apply softmax
a_top_sm = tf.nn.softmax(a_top)
# Reconstruct into original shape
a_shape = tf.shape(a)
a_row_idx = tf.tile(tf.range(a_shape[0])[:, tf.newaxis], (1, num_top))
scatter_idx = tf.stack([a_row_idx, a_top_idx], axis=-1)
result = tf.scatter_nd(scatter_idx, a_top_sm, a_shape)
# Test
with tf.Session() as sess:
result_val = sess.run(result, feed_dict={a: [[2, 5, 4, 7], [7, 5, 6, 8]], num_top: 2})
print(result_val)
[[0. 0.11920291 0. 0.880797 ]
[0.26894143 0. 0. 0.7310586 ]]