Focal loss Tensorflow实现

数据不均衡--Focal loss Tensorflow实现

def focal_loss(prediction_tesnsor,target_tensor,weights = None,alpha = 0.25,gamma=2):
	"""
	FL = -alpha *(z-p)^gamma*log(p) - (1-alpha)*p^gamma *log(1-p)
	which alpha = 0.25,gamma = 2,p = sigmoid(x),z = target_tensor
	"""
	sigmoid_p = tf.nn.sigmoid(prediction_tesnsor)
	zeros = array_ops.zeros.like(sigmoid_p,dtype = sigmoid_p.type)

	#对比正标签
	pos_p_sub = array_ops.where(target_tensor > zeros,target_tensor - sigmoid_p,zeros)

	#对于负标签
	neg_p_sub = array_ops.where(target_tensor > zeros,zeros,sigmoid_p)

	fl_loss = -alpha * (pos_p_sub**gamma)*tf.log(tf.clip_by_value(sigmoid_p,1e-8,1.0))\
			- (1-alpha) * (neg_p_sub**gamma)*tf.log(tf.clip_by_value(1-sigmoid_p,1e-8,1.0))
	return tf.reduce_sum(fl_loss)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值