提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档
前言
在复现Faster R-CNN过程中,计算相应的模型损失时,遇到tf.where()
等函数不是很清楚,今天记录下来以备查看。
一、tf.where的介绍
tf.where()对于非布尔型张量,返回的是非零数值的位置。
import tensorflow as tf
test=tf.constant([[1,-1,3],[0,0,6]])
tf.where(test)
二、tf.gather_nd()的介绍
在tensor中进行不连续的取值,就必须用到tf.gather_nd()
。其基本语法如下:
tf.gather_nd(
params,
indices,
name=None
)
例子如下:
# coding=utf-8
# tf 2.0+
import tensorflow as tf
a = tf.constant([[1, 2, 3, 4, 5],
[6, 7, 8, 9, 10],
[11, 12, 13, 14, 15]])
index_a1 = tf.constant([[0, 2], [0, 4], [2, 2]]) # 随便选几个
index_a2 = tf.constant([0, 1]) # 0行1列的元素——2
index_a3 = tf.constant([[0], [1]]) # [第0行,第1行]
index_a4 = tf.constant([0]) # 第0行
print(tf.gather_nd(a, index_a1))
print(tf.gather_nd(a, index_a2))
print(tf.gather_nd(a, index_a3))
print(tf.gather_nd(a, index_a4))