焦点损失函数 Pytorch实现

看了一篇论文,然后这篇论文里面就只是将损失函数从常见的单分类多分类函数转化为了焦点损失函数,看论文里面的效果,效果还出奇的好,很奇怪,难道就只是改一个损失函数就这么牛逼?于是,我也想着把之前的网络的损失函数改一改,看看这个损失函数是否是真的这么牛逼。

pytorch自带的损失函数中,我没有找到对应的焦点损失函数的接口,没办法,只能去看原理然后参考其他人的代码改改或者写吧,找到几个实现焦点损失函数的大佬代码,但是总感觉他们的代码有些奇怪。比如:pytorch实现焦点函数比较靠谱的参考,这个作者给的代码,我如果将识别的类别改为2,那么它在计算α的时候就没有在0到1之间,所以开始怀疑这位大佬的代码,然后,我又去看了看上面我提到的论文里面怎么处理的,公式如下:
在这里插入图片描述这里面[x1, x2, x3],结合上面链接作者的源代码,我理解为分别代表每一种类别在对应的图中所占的像素个数,然后结合图片中的公式,我更改了求解α的代码。然后,我感觉作者给的代码其实不够简洁,当我看他代码理解了原理之后又更改了一下,最后版本如下:

def computer_weights(piexlist):
    sum_list = sum(piexlist) # 先求和
    piexlist_w = [sum_list/x for x in piexlist]
    sum_w = sum(piexlist_w) # 再求和
    w_final = [x/sum_w for x in piexlist_w]
    return w_final


'''在多分类中,我记得有篇文章里面有个动态图,当我们对神经网络的输出图output进行softmax之后,我们需要根据target,也就是打的标签
里面的每一个像素的值来选择output中每个像素某一个维度的值,然后我看网上公开的代码,其实就是类似这样操作的'''
def focus_loss(num_classes, input_data, target, cuda=True):
    n, c, h, w = target.shape
    input_data = torch.softmax(input_data, dim=1) # 先对数据进行softmax

    '''接下来就要根据打标签的图像来选择选取output中每个像素的最终输出值了'''
    classes_mask = torch.zeros_like(input_data) # 这个mask用来跟input相乘,相当于要去对softmax输出的维度进行筛选
    classes_mask.scatter_(1, target, 1)
    # classes_mask = classes_mask.permute(0, 2, 3, 1) # 这一行代码是用来验证上一行代码是否正确
    # print(classes_mask)
    # 这个函数,我感觉我这样用是没问题的,将最终值输出进行验证,感觉也没有问题,这个函数就是根据target的数值做索引。
    input_data = torch.sum(input_data * classes_mask, dim=1) # 二者相乘,并且求和,这样就将每个像素对应的预测值选取成功

    gamma = 2

    '''接下来就是对一张输入图中不同标签类型的权重进行求解'''
    num_class_list = []
    for i in range(num_classes):
        num_class_list.append(torch.sum(target == i).item()) # 将每一种类别所占的像素总和找到并且放到一个list里面
    weights_alpha = computer_weights(num_class_list) # 这是固定的求解公式,相当于所有类别的权重之和为1,我们通过一定的公式将各自的权重找到
    weights_alpha = torch.tensor(weights_alpha)
    weights_alpha = weights_alpha[target.view(-1)].reshape(n, c, h, w) # 这一行代码有意思,值得注意,这一行代码实现了将target中的每一个像素点进行了权值分配

    # print(weights_alpha.shape, weights_alpha)
    if cuda:
        weights_alpha = weights_alpha.cuda() # 这里是需要的,否则会因为数据在cpu和在gpu上的差距报错

    # loss = -(torch.pow((1-input_data), gamma))*torch.log(input_data) # 这个是没有α参数的损失函数
    loss = -(weights_alpha * torch.pow((1-input_data), gamma) * torch.log(input_data)) # 这个是加了α的损失函数
    loss = torch.mean(loss)
    # print(loss.item())
    return loss


if __name__ == '__main__':
    pred = torch.rand((1, 2, 512, 512)).cuda()
    y = torch.from_numpy(np.random.randint(0, 2, (1, 1, 512, 512))).long().cuda()
    focus_loss(2, pred, y)

代码中的注释,大家将就看,我更改之后,函数的接口的参数的维度就比较正常了,分别是[batchsize, n_classes, h, w]和[batchsize, 1, h, w],分别对应神经网络的输出和打好标签的图像。

2020.11.4

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值