tf.gather_nd
记录一些很少会去用但是用起来就蛋疼的函数,方便自己回查
定义
1.x 版本都是在 tf里面 如: tf.gather_nd
2.x 版本都是在tf.compat.v1 里面 如: tf.compat.v1.gather_nd
官方的函数接口如下:
tf.gather_nd(
params,
indices,
name=None,
batch_dims=0
)
这个函数主要的作用就是根据你提供的张量坐标来收集这个张量在这个坐标下的值, 有点像下面这样
indexs = [[0,0], [1,1]]
ndarray = np.random.randint(1,10,(2,2))
result = []
for i in indexs:
result.append(ndarray[i[0],i[1]])
所以
- params 就是张量
- indices 就是包含了上面张量中感兴趣的点的坐标的张量,一般是通过tf.where()获取
- name 就是给这个操作搞个名字
- batch_dims 就是控制indices的 batch 维度的
官网例子
indices 不带batch的(一条龙地拿)
# 第一个例子 indices就是一个列表, 里面放的是张量坐标,很清楚了
# 第一个点是 [0, 0], 那么对应张量里的 'a'
# 第二个点是 [1, 1], 那么对应张量里的 'b'
# 两个结果放一个列表里 ['a', 'd']
# 这里 indices 的shape = (2,2)
# axis = -1 这个维度里 也就是shape(2,2<--这个2)
# 其实要对应张量的最大维度数 也就是 len(params.shape)
# 同时也是 params.rank
# axis[:-1] 表示的是indices的batch维度
# 比如 indices shape = (1,2,3,4)
# 那么说明indices 被分成了 len(indices.shape) 种 batch
# 也就是 indices.rank 种 batch, 对应维度上的数字比如 上面的1
# 就是 第一种(axis=0) batch 的第二个batch, 以此类推,
# 所以下面的例子才会是各种套娃的括号
# 但是不管怎样, 要找元素其实只要通过最后一个维度的列表[a,b,c,d]
# 去对应params 里的位置上的元素就好了, 前面的这些都是骚操作
# 结果 output shape = (1,2)
indices = [[0, 0], [1, 1]]
params = [['a', 'b'], ['c', 'd']]
output = ['a', 'd']
#
indices = [[1], [0]]
params = [['a', 'b'], ['c', 'd']]
output = [['c', 'd'], ['a', 'b']]
indices = [[1]]
params = [[['a0', 'b0'], ['c0', 'd0']],
[['a1', 'b1'], ['c1', 'd1']]]
output = [[['a1', 'b1'], ['c1', 'd1']]]
indices = [[0, 1], [1, 0]]
params = [[['a0', 'b0'], ['c0', 'd0']],
[['a1', 'b1'], ['c1', 'd1']]]
output = [['c0', 'd0'], ['a1', 'b1']]
indices = [[0, 0, 1], [1, 0, 1]]
params = [[['a0', 'b0'], ['c0', 'd0']],
[['a1', 'b1'], ['c1', 'd1']]]
output = ['b0', 'b1']
# 这里是indices 其实是有各种batch的例子,看的人想一头撞死在不周山
# 套路是先从最里面的括号(也就是最后一个维度)拿值, 然后根据indices的shape
# 来补括号
indices = [[[0, 0]], [[0, 1]]]
params = [['a', 'b'], ['c', 'd']]
output = [['a'], ['b']]
indices = [[[1]], [[0]]]
params = [['a', 'b'], ['c', 'd']]
output = [[['c', 'd']], [['a', 'b']]]
indices = [[[1]], [[0]]]
params = [[['a0', 'b0'], ['c0', 'd0']],
[['a1', 'b1'], ['c1', 'd1']]]
output = [[[['a1', 'b1'], ['c1', 'd1']]],
[[['a0', 'b0'], ['c0', 'd0']]]]
indices = [[[0, 1], [1, 0]], [[0, 0], [1, 1]]]
params = [[['a0', 'b0'], ['c0', 'd0']],
[['a1', 'b1'], ['c1', 'd1']]]
output = [[['c0', 'd0'], ['a1', 'b1']],
[['a0', 'b0'], ['c1', 'd1']]]
indices = [[[0, 0, 1], [1, 0, 1]], [[0, 1, 1], [1, 1, 0]]]
params = [[['a0', 'b0'], ['c0', 'd0']],
[['a1', 'b1'], ['c1', 'd1']]]
output = [['b0', 'b1'], ['d0', 'c1']]
# 这里看的有点迷,下面是我觉得的理解
# 我感觉是这样的
# indices shape = (1,2,1)
# param shape = (2,2,2)
# output shape = (1,2,2)
# 以 indices axis = 1 也为一个batch 也就是要按照这个维度把 params axis = 1 开始作为一个batch, 一共有两个batch, 分别为 [['a0', 'b0'], ['c0', 'd0']], 和 [['a1', 'b1'], ['c1', 'd1']] 分别从中选出 indices = [1] 和 [0] 的值, 第一个就是['c0', 'd0'] 第二个就是 ['a1', 'b1']
batch_dims = 1
indices = [[1], [0]]
params = [[['a0', 'b0'], ['c0', 'd0']],
[['a1', 'b1'], ['c1', 'd1']]]
output = [['c0', 'd0'], ['a1', 'b1']]
batch_dims = 1
indices = [[[1]], [[0]]]
params = [[['a0', 'b0'], ['c0', 'd0']],
[['a1', 'b1'], ['c1', 'd1']]]
output = [[['c0', 'd0']], [['a1', 'b1']]]
batch_dims = 1
indices = [[[1, 0]], [[0, 1]]]
params = [[['a0', 'b0'], ['c0', 'd0']],
[['a1', 'b1'], ['c1', 'd1']]]
output = [['c0'], ['b1']]
一般会用在计算评价指标
def calculate_model_fp(input_tensor, label_tensor):
"""
calculate fp figure
:param input_tensor:
:param label_tensor:
:return:
"""
logits = tf.nn.softmax(logits=input_tensor)
final_output = tf.expand_dims(tf.argmax(logits, axis=-1), axis=-1)
idx = tf.where(tf.equal(final_output, 1))
pix_cls_ret = tf.gather_nd(final_output, idx)
false_pred = tf.cast(tf.shape(pix_cls_ret)[0], tf.int64) - tf.count_nonzero(
tf.gather_nd(label_tensor, idx)
)
return tf.divide(false_pred, tf.cast(tf.shape(pix_cls_ret)[0], tf.int64))
def calculate_model_fn(input_tensor, label_tensor):
"""
calculate fn figure
:param input_tensor:
:param label_tensor:
:return:
"""
logits = tf.nn.softmax(logits=input_tensor)
final_output = tf.expand_dims(tf.argmax(logits, axis=-1), axis=-1)
idx = tf.where(tf.equal(label_tensor, 1))
pix_cls_ret = tf.gather_nd(final_output, idx)
label_cls_ret = tf.gather_nd(label_tensor, tf.where(tf.equal(label_tensor, 1)))
mis_pred = tf.cast(tf.shape(label_cls_ret)[0], tf.int64) - tf.count_nonzero(pix_cls_ret)
return tf.divide(mis_pred, tf.cast(tf.shape(label_cls_ret)[0], tf.int64))
在其它框架下对应的函数
np.take()
torch.take()
参考文献
https://zhuanlan.zhihu.com/p/45673869
https://github.com/MaybeShewill-CV/lanenet-lane-detection