解决方案
示例
def get_median(v):
# 把v拉伸成1维tensor
v = tf.reshape(v, [-1])
# 计算分位点(中位数)是第几(m)大元素
m = v.shape[0]//2
# 先获取前m个元素,再选取最小元素,(分位点为前m大元素中的最小值)。
return tf.reduce_min(tf.nn.top_k(v, m, sorted=False).values)
附:tf.nn.top_k函数详解
tf.nn.top_k(
input,
k=1,
sorted=True,
name=None
)
参数:
- input:输入的tensor,不能是array这些啊!要么输入1-D,要是更高维度必须保证最后的一个维度长度必须大于等于K
- k:0-D的int32的数字张量。
- sorted:如果sorted=True,那么选出来的k个数字就需要按照降序的顺序排序
- name:可选项,也就是这个操作的名字
返回:
- values:也就是每一行的最大的k个数字
- indices:这里的下标是在输入的张量的最后一个维度的下标