问题描述
在研究srgnn这个图序列推荐模型的时候遇到了这个问题。
例如,你有一个4*6的列表list_,4表示batch_size, 8表示每个样本中的元素的个数,(用0补齐)你还有一个mask列表,用来表示每个样本中的元素是否是非零值:
list=[
[1,2,3,0,0,0],
[4,5,0,0,0,0],
[6,7,8,9,0,0],
[1,2,3,0,0,0]
]
mask=[
[1,1,1,0,0,0],
[1,1,0,0,0,0],
[1,1,1,1,0,0],
[1,1,1,0,0,0]
]
现在你想获得list中的每个样本的最后一个值,也就是[3,5,9,3],你可以这样做:
# 获取每个样本的有效长度,也就是[3,2,4,3]
self.last_position = tf.reduce_sum(self.mask, 1)
# 先使用stack获得一个4*2的位置序列---[[0,2],[1,1],[2,3],[3,2]],
# 然后再用这个位置序列去list_中获取对应的元素
self.last_id = tf.gather_nd(self.list_, tf.stack([tf.range(self.batch_size), tf.to_int32(self.rm)-1], axis=1))