tensorflow tf.nn.embedding_lookup partition strategy tf.gather tf.nn.top_k

tf.nn.embedding_lookup:

定义
tf.nn.embedding_lookup(
    params,
    ids,
    partition_strategy='mod',
    name=None,
    validate_indices=True,
    max_norm=None
)

当params为一个tensor时,很容易理解,ids里的每个数字代表选取在params中0轴所在的对应index。是tf.gather的泛化版。

params = tf.constant([10,20,30,40])
ids = tf.constant([0,1,3])
print tf.nn.embedding_lookup(params,ids).eval()
return [10 20 40]

如上,根据[0,1,3]中的数字,选取[10,20,30,40]的0轴中对应index的[10 20 40]。
但当params是个list的tensor时,ids中的数字就要根据partition_strategy来进行划分。

params1 = tf.constant([1,2])
params2 = tf.constant([10,20])
ids = tf.constant([2,0,2,1,2,3])
result = tf.nn.embedding_lookup([params1, params2], ids)
return [ 2  1  2 10  2 20]

默认划分规则是mod,mod划分规则是这样:
数字0表示params1中的第一个元素
数字1表示params2中的第一个元素
数字2表示params1中的第二个元素
数字3表示params2中的第二个元素
以此类推。。。
这样在上代码中,ids中的2代表的params1中的第二个元素,也就是2;ids中的1代表的params2中的第一个元素,也就是10。

另一种划分规则是div,其规则是:
数字0表示params1中的第一个元素
数字1表示params1中的第二个元素
数字2表示params2中的第一个元素
数字3表示params2中的第二个元素

# tf.nn.embedding_lookup通常用来根据id提取embedding
embeddings = tf.nn.embedding_lookup(self.weights['feature_embeddings'], self.feat_ids)
# self.feat_ids可形如:shape[batch_size,ids_size]
#[[0,1,2],
  [2,3,4]]

tf.gather:

定义:
tf.gather(
    params,
    indices,
    validate_indices=None,
    name=None,
    axis=0
)

沿着axis,根据indices提取params的一个切片。indices 可以是任何shape的整数张量,更高级点的api有tf.batch_gather、tf.gather_nd(二维id).
在这里插入图片描述

emb_trans = tf.transpose(embeddings, [1, 0, 2])
emb_left = tf.gather(emb_trans, self.left_index)

tf.gather_nd:

# 根据索引提取数据,可用于topk的索引生成
import tensorflow as tf
index = tf.constant([[1],[1]])
values = tf.constant([[0.2, 0.8],[0.4, 0.6]])

index = tf.stack([tf.range(index.shape[0])[:, None], index], axis=2)
result = tf.gather_nd(values, index)

index.eval(session=tf.Session())
array([[[0, 1]],
       [[1, 1]]], dtype=int32)
result.eval(session=tf.Session())
array([[0.8],
       [0.6]], dtype=float32)

top_k索引

import tensorflow as tf

# Input data
a = tf.placeholder(tf.float32, [None, None])
num_top = tf.placeholder(tf.int32, [])
# Find top elements
a_top, a_top_idx = tf.nn.top_k(a, num_top, sorted=False)
# Apply softmax
a_top_sm = tf.nn.softmax(a_top)
# Reconstruct into original shape
a_shape = tf.shape(a)
a_row_idx = tf.tile(tf.range(a_shape[0])[:, tf.newaxis], (1, num_top))
scatter_idx = tf.stack([a_row_idx, a_top_idx], axis=-1)
result = tf.scatter_nd(scatter_idx, a_top_sm, a_shape)
# Test
with tf.Session() as sess:
    result_val = sess.run(result, feed_dict={a: [[2, 5, 4, 7], [7, 5, 6, 8]], num_top: 2})
    print(result_val)

[[0.         0.11920291 0.         0.880797  ]
 [0.26894143 0.         0.         0.7310586 ]]
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值