收集下标对应值: tf.gather_nd()


记录一些很少会去用但是用起来就蛋疼的函数,方便自己回查


定义

1.x 版本都是在 tf里面 如: tf.gather_nd

2.x 版本都是在tf.compat.v1 里面 如: tf.compat.v1.gather_nd

官方的函数接口如下:

tf.gather_nd(
    params,
    indices,
    name=None,
    batch_dims=0
)

这个函数主要的作用就是根据你提供的张量坐标来收集这个张量在这个坐标下的值, 有点像下面这样

indexs = [[0,0], [1,1]]
ndarray = np.random.randint(1,10,(2,2))
result = []
for i in indexs:
	result.append(ndarray[i[0],i[1]])

所以

  1. params 就是张量
  2. indices 就是包含了上面张量中感兴趣的点的坐标的张量,一般是通过tf.where()获取
  3. name 就是给这个操作搞个名字
  4. batch_dims 就是控制indices的 batch 维度的

官网例子

indices 不带batch的(一条龙地拿)

    # 第一个例子 indices就是一个列表, 里面放的是张量坐标,很清楚了
    # 第一个点是 [0, 0], 那么对应张量里的 'a'
    # 第二个点是 [1, 1], 那么对应张量里的 'b'
    # 两个结果放一个列表里 ['a', 'd']
    # 这里 indices 的shape = (2,2) 
    # axis = -1 这个维度里 也就是shape(2,2<--这个2)
    # 其实要对应张量的最大维度数 也就是 len(params.shape) 
    # 同时也是 params.rank
    # axis[:-1] 表示的是indices的batch维度
    # 比如 indices shape = (1,2,3,4)
    # 那么说明indices 被分成了 len(indices.shape) 种 batch
    # 也就是 indices.rank 种 batch, 对应维度上的数字比如 上面的1
    # 就是 第一种(axis=0) batch 的第二个batch, 以此类推,
    # 所以下面的例子才会是各种套娃的括号
    # 但是不管怎样, 要找元素其实只要通过最后一个维度的列表[a,b,c,d]
    # 去对应params 里的位置上的元素就好了, 前面的这些都是骚操作
    # 结果 output shape = (1,2)
    indices = [[0, 0], [1, 1]]
    params = [['a', 'b'], ['c', 'd']]
    output = ['a', 'd']

	# 
    indices = [[1], [0]]
    params = [['a', 'b'], ['c', 'd']]
    output = [['c', 'd'], ['a', 'b']]

    indices = [[1]]
    params = [[['a0', 'b0'], ['c0', 'd0']],
              [['a1', 'b1'], ['c1', 'd1']]]
    output = [[['a1', 'b1'], ['c1', 'd1']]]


    indices = [[0, 1], [1, 0]]
    params = [[['a0', 'b0'], ['c0', 'd0']],
              [['a1', 'b1'], ['c1', 'd1']]]
    output = [['c0', 'd0'], ['a1', 'b1']]


    indices = [[0, 0, 1], [1, 0, 1]]
    params = [[['a0', 'b0'], ['c0', 'd0']],
              [['a1', 'b1'], ['c1', 'd1']]]
    output = ['b0', 'b1']


# 这里是indices 其实是有各种batch的例子,看的人想一头撞死在不周山
# 套路是先从最里面的括号(也就是最后一个维度)拿值, 然后根据indices的shape
# 来补括号

    indices = [[[0, 0]], [[0, 1]]]
    params = [['a', 'b'], ['c', 'd']]
    output = [['a'], ['b']]

    indices = [[[1]], [[0]]]
    params = [['a', 'b'], ['c', 'd']]
    output = [[['c', 'd']], [['a', 'b']]]

    indices = [[[1]], [[0]]]
    params = [[['a0', 'b0'], ['c0', 'd0']],
              [['a1', 'b1'], ['c1', 'd1']]]
    output = [[[['a1', 'b1'], ['c1', 'd1']]],
              [[['a0', 'b0'], ['c0', 'd0']]]]

    indices = [[[0, 1], [1, 0]], [[0, 0], [1, 1]]]
    params = [[['a0', 'b0'], ['c0', 'd0']],
              [['a1', 'b1'], ['c1', 'd1']]]
    output = [[['c0', 'd0'], ['a1', 'b1']],
              [['a0', 'b0'], ['c1', 'd1']]]


    indices = [[[0, 0, 1], [1, 0, 1]], [[0, 1, 1], [1, 1, 0]]]
    params = [[['a0', 'b0'], ['c0', 'd0']],
              [['a1', 'b1'], ['c1', 'd1']]]
    output = [['b0', 'b1'], ['d0', 'c1']]


# 这里看的有点迷,下面是我觉得的理解
# 我感觉是这样的
# indices shape = (1,2,1)
# param shape = (2,2,2)
# output shape = (1,2,2)
# 以 indices axis = 1 也为一个batch 也就是要按照这个维度把 params axis = 1 开始作为一个batch, 一共有两个batch, 分别为 [['a0', 'b0'], ['c0', 'd0']], 和 [['a1', 'b1'], ['c1', 'd1']] 分别从中选出 indices = [1] 和 [0] 的值, 第一个就是['c0', 'd0'] 第二个就是 ['a1', 'b1']
    batch_dims = 1
    indices = [[1], [0]]
    params = [[['a0', 'b0'], ['c0', 'd0']],
              [['a1', 'b1'], ['c1', 'd1']]]
    output = [['c0', 'd0'], ['a1', 'b1']]

    batch_dims = 1
    indices = [[[1]], [[0]]]
    params = [[['a0', 'b0'], ['c0', 'd0']],
              [['a1', 'b1'], ['c1', 'd1']]]
    output = [[['c0', 'd0']], [['a1', 'b1']]]

    batch_dims = 1
    indices = [[[1, 0]], [[0, 1]]]
    params = [[['a0', 'b0'], ['c0', 'd0']],
              [['a1', 'b1'], ['c1', 'd1']]]
    output = [['c0'], ['b1']]

一般会用在计算评价指标

这里借用Maybeshewill 大佬的代码

def calculate_model_fp(input_tensor, label_tensor):
    """
    calculate fp figure
    :param input_tensor:
    :param label_tensor:
    :return:
    """
    logits = tf.nn.softmax(logits=input_tensor)
    final_output = tf.expand_dims(tf.argmax(logits, axis=-1), axis=-1)

    idx = tf.where(tf.equal(final_output, 1))
    pix_cls_ret = tf.gather_nd(final_output, idx)
    false_pred = tf.cast(tf.shape(pix_cls_ret)[0], tf.int64) - tf.count_nonzero(
        tf.gather_nd(label_tensor, idx)
    )

    return tf.divide(false_pred, tf.cast(tf.shape(pix_cls_ret)[0], tf.int64))


def calculate_model_fn(input_tensor, label_tensor):
    """
    calculate fn figure
    :param input_tensor:
    :param label_tensor:
    :return:
    """
    logits = tf.nn.softmax(logits=input_tensor)
    final_output = tf.expand_dims(tf.argmax(logits, axis=-1), axis=-1)

    idx = tf.where(tf.equal(label_tensor, 1))
    pix_cls_ret = tf.gather_nd(final_output, idx)
    label_cls_ret = tf.gather_nd(label_tensor, tf.where(tf.equal(label_tensor, 1)))
    mis_pred = tf.cast(tf.shape(label_cls_ret)[0], tf.int64) - tf.count_nonzero(pix_cls_ret)

    return tf.divide(mis_pred, tf.cast(tf.shape(label_cls_ret)[0], tf.int64))

在其它框架下对应的函数

np.take()
torch.take()

参考文献

https://zhuanlan.zhihu.com/p/45673869
https://github.com/MaybeShewill-CV/lanenet-lane-detection

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值