WMSE

class Weighted_mse_mae(nn.Module):
	def __init__(self, mse_weight=1.0, mae_weight=1.0, NORMAL_LOSS_GLOBAL_SCALE=0.00005):
		super(Weighted_mse_mae, self).__init__()
		self.NORMAL_LOSS_GLOBAL_SCALE = NORMAL_LOSS_GLOBAL_SCALE
		self.mse_weight = mse_weight
		self.mae_weight = mae_weight

	def forward(self, input, target): #, mask
		balancing_weights = (1, 1, 5, 10, 30, 32)
		weights = torch.ones_like(input) * balancing_weights[0]
		#weights = torch.nn.Parameter(weights, requires_grad=True)

		thresholds = [ dBZ_to_pixel(ele) for ele in np.array([10, 20, 30, 40, 50]) ]
		for i, threshold in enumerate(thresholds):
			weights = weights + (balancing_weights[i + 1] - balancing_weights[i]) * (target >= threshold).float()

		#input: S*B*H*W
		mse = torch.sum(weights * ((input-target)**2), (2, 3))
		mae = torch.sum(weights * (torch.abs((input-target))), (2, 3))
		'''
		mse = weights * torch.nn.MSELoss(reduce=True, size_average=True)(input, target)
		mse = torch.sum(mse, (2,3,4))
		mae = weights * torch.nn.L1Loss(reduce=True, size_average=True)(input, target)
		mae = torch.sum(mae, (2,3,4))
		'''
		loss_value = self.NORMAL_LOSS_GLOBAL_SCALE * (self.mse_weight*torch.mean(mse) + self.mae_weight*torch.mean(mae))
		return loss_value
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值