tf.gather：用一个一维索引数组，从张量中按索引（默认沿第 0 维）取出对应的向量。下面的第一个示例演示的是 tf.square（逐元素求平方）；tf.gather 的实际用法见后面的 bbox_ohem。
import tensorflow as tf
import numpy as np
# Element-wise square of a small int32 constant, evaluated in a TF1 session.
a = tf.constant([1, 2, 3, 4])
b = tf.square(a)
with tf.Session() as sess:
    squared = sess.run(b)
    print("b:%s" % (squared,))
# b:[ 1  4  9 16]
import numpy as np
import tensorflow as tf
def bbox_ohem(bbox_pred, bbox_target, label):
    """Return the mean squared-error bounding-box loss over pos and part samples.

    Only samples whose ``label`` has absolute value 1 contribute (the pos and
    part samples); all other samples (e.g. neg / landmark samples) are masked
    out before averaging.

    :param bbox_pred: predicted box regression values, one row per sample
        (assumed float32, shape (batch, 4) -- TODO confirm against caller).
    :param bbox_target: ground-truth regression targets, same shape and dtype
        as ``bbox_pred``.
    :param label: per-sample class label, one entry per row of ``bbox_pred``.
    :return: scalar float32 tensor -- the mean, over pos/part samples, of the
        per-sample sum of squared errors; 0.0 if the batch has no such sample.
    """
    zeros_index = tf.zeros_like(label, dtype=tf.float32)
    ones_index = tf.ones_like(label, dtype=tf.float32)
    # 1.0 for pos/part samples (|label| == 1), 0.0 for everything else.
    valid_inds = tf.where(tf.equal(tf.abs(label), 1), ones_index, zeros_index)

    # Per-sample sum of squared errors across each row -> shape (batch,).
    square_error = tf.reduce_sum(tf.square(bbox_pred - bbox_target), axis=1)

    # Number of pos/part samples.
    num_valid = tf.reduce_sum(valid_inds)
    keep_num = tf.cast(num_valid, dtype=tf.int32)

    # Zero out the errors of samples that are not pos/part.
    square_error = square_error * valid_inds
    # Select the keep_num largest masked errors. Masked-out entries are 0 and
    # errors are non-negative, so the selection sums to the total error of the
    # pos/part samples.
    _, k_index = tf.nn.top_k(square_error, k=keep_num)
    square_error = tf.gather(square_error, k_index)

    # Divide by max(keep_num, 1) instead of tf.reduce_mean: an empty selection
    # (no pos/part sample in the batch) then yields 0.0 instead of NaN.
    denominator = tf.maximum(tf.cast(keep_num, tf.float32), 1.0)
    return tf.reduce_sum(square_error) / denominator
# Demo: evaluate bbox_ohem on two random boxes. Sample 0 has label 1 and is
# kept; sample 1 has label 0 and is masked out because |label| != 1.
bbox_pred = tf.random_uniform([2, 4], 10, 100, seed=100)
bbox_target = tf.random_uniform([2, 4], 15, 150, seed=100)
label = np.array([1, 0])
bbox_loss = bbox_ohem(bbox_pred, bbox_target, label)
with tf.Session() as sess:
    # Fetch both tensors in a single run so the printed boxes are exactly the
    # random values the loss was computed from (each run re-samples them).
    pred_value, loss_value = sess.run([bbox_pred, bbox_loss])
    print("bbox_pred:%s" % (pred_value,))
    print("bbox_loss:%s" % (loss_value,))
landmark_ohem：返回 landmark（关键点）回归损失，只使用 label 为 -2 的 landmark 样本计算。
def landmark_ohem(landmark_pred, landmark_target, label):
    """Return the mean squared-error landmark loss over landmark samples.

    Only samples with ``label == -2`` contribute; every other sample is masked
    out before averaging. All such samples are kept (keep_num equals the
    number of landmark samples), so no hard-example subsampling happens here.

    :param landmark_pred: predicted landmark values, one row per sample
        (assumed float32, shape (batch, num_coords) -- TODO confirm).
    :param landmark_target: ground-truth landmark values, same shape and dtype
        as ``landmark_pred``.
    :param label: per-sample label, one entry per row of ``landmark_pred``.
    :return: scalar float32 tensor -- the mean, over landmark samples, of the
        per-sample sum of squared errors; 0.0 if the batch has no such sample.
    """
    # Keep only samples with label == -2 for the landmark loss.
    ones = tf.ones_like(label, dtype=tf.float32)
    zeros = tf.zeros_like(label, dtype=tf.float32)
    valid_inds = tf.where(tf.equal(label, -2), ones, zeros)

    # Per-sample sum of squared errors across each row -> shape (batch,).
    square_error = tf.reduce_sum(tf.square(landmark_pred - landmark_target), axis=1)

    num_valid = tf.reduce_sum(valid_inds)
    keep_num = tf.cast(num_valid, dtype=tf.int32)

    # Zero out non-landmark samples, then select the keep_num largest errors.
    # Masked entries are 0 and errors are non-negative, so the selection sums
    # to the total error of the landmark samples.
    square_error = square_error * valid_inds
    _, k_index = tf.nn.top_k(square_error, k=keep_num)
    square_error = tf.gather(square_error, k_index)

    # Divide by max(keep_num, 1) instead of tf.reduce_mean: an empty selection
    # (no landmark sample in the batch) then yields 0.0 instead of NaN.
    denominator = tf.maximum(tf.cast(keep_num, tf.float32), 1.0)
    return tf.reduce_sum(square_error) / denominator