# WMSE: weighted MSE / MAE loss (see Weighted_mse_mae below)
class Weighted_mse_mae(nn.Module):
    """Loss module configured with separate MSE and MAE weights.

    Judging by the class and parameter names, this is presumably a weighted
    combination of mean-squared-error and mean-absolute-error terms, scaled by
    a global factor. Only the constructor is visible in this snippet.

    NOTE(review): no ``forward`` method is visible here; confirm whether it
    was truncated from this paste or is defined elsewhere.

    Args:
        mse_weight: Weight stored for the MSE term. Defaults to 1.0.
        mae_weight: Weight stored for the MAE term. Defaults to 1.0.
        NORMAL_LOSS_GLOBAL_SCALE: Global scale factor stored on the module
            (presumably multiplied into the final loss; confirm in forward).
            Defaults to 0.00005.
    """

    def __init__(self, mse_weight=1.0, mae_weight=1.0,
                 NORMAL_LOSS_GLOBAL_SCALE=0.00005):
        super().__init__()
        self.NORMAL_LOSS_GLOBAL_SCALE = NORMAL_LOSS_GLOBAL_SCALE
        self.mse_weight = mse_weight
        # The original accepted mae_weight but never stored it, so the MAE
        # weight was silently discarded; keep it alongside mse_weight.
        self.mae_weight = mae_weight





