Classification Networks: Handling Class Imbalance with FocalLoss

The training and test code is shown below (the complete code comes from the "CNN从搭建到部署实战" series):
train.py

import torch
import torchvision
import time
import argparse
import importlib
from loss import FocalLoss


def parse_args():
    parser = argparse.ArgumentParser('training')
    parser.add_argument('--batch_size', default=128, type=int, help='batch size in training')
    parser.add_argument('--num_epochs', default=5, type=int, help='number of epoch in training')
    parser.add_argument('--model', default='lenet', help='model name [default: lenet]')
    return parser.parse_args()


if __name__ == '__main__':
    args = parse_args()
    batch_size = args.batch_size
    num_epochs = args.num_epochs
    model = importlib.import_module('models.'+args.model) 
        
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = model.net.to(device)

    loss = torch.nn.CrossEntropyLoss()
    
    if args.model == 'mlp':
        optimizer = torch.optim.SGD(net.parameters(), lr=0.5)
    else:
        optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
          
    train_path = r'./Datasets/mnist_png/training'
    test_path = r'./Datasets/mnist_png/testing'
    transform_list = [torchvision.transforms.Grayscale(num_output_channels=1), torchvision.transforms.ToTensor()]
    if args.model == 'alexnet' or args.model == 'vgg':
        transform_list.append(torchvision.transforms.Resize(size=224))
    if args.model == 'googlenet' or args.model == 'resnet':
        transform_list.append(torchvision.transforms.Resize(size=96))
    transform = torchvision.transforms.Compose(transform_list)

    train_dataset = torchvision.datasets.ImageFolder(train_path, transform=transform)
    test_dataset = torchvision.datasets.ImageFolder(test_path, transform=transform)

    train_iter = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    test_iter = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

    for epoch in range(num_epochs):
        train_l, train_acc, test_acc, m, n, batch_count, start = 0.0, 0.0, 0.0, 0, 0, 0, time.time()
        for X, y in train_iter:
            X, y = X.to(device), y.to(device)
            y_hat = net(X)
            l = loss(y_hat, y)
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            train_l += l.cpu().item()
            train_acc += (y_hat.argmax(dim=1) == y).sum().cpu().item()
            m += y.shape[0]
            batch_count += 1
        with torch.no_grad():
            for X, y in test_iter:
                net.eval()  # switch to evaluation mode
                test_acc += (net(X.to(device)).argmax(dim=1) == y.to(device)).float().sum().cpu().item()
                net.train()  # switch back to training mode
                n += y.shape[0]
        print('epoch %d, loss %.6f, train acc %.3f, test acc %.3f, time %.1fs'% (epoch, train_l / batch_count, train_acc / m, test_acc / n, time.time() - start))
        torch.save(net, args.model+".pth")
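
The training script can then be run as, for example: python train.py --model lenet --num_epochs 5 --batch_size 128. This assumes a models package whose lenet module exposes a net instance, as the importlib.import_module call above implies.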

test.py

import cv2
import torch
import argparse
import importlib
from pathlib import Path
import torchvision.transforms.functional


def parse_args():
    parser = argparse.ArgumentParser('testing')
    parser.add_argument('--model', default='lenet', help='model name [default: lenet]')
    return parser.parse_args()


if __name__ == '__main__':
    args = parse_args()
    model = importlib.import_module('models.' + args.model) 
        
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = torch.load(args.model + '.pth', map_location=device)  # load the full model saved by train.py
    net = net.to(device)
    net.eval()

    with torch.no_grad():
        imgs_path = Path(r"./Datasets/mnist_png/testing/6/").glob("*")
        acc = 0
        count = 0
        for img_path in imgs_path:
            img = cv2.imread(str(img_path), 0)
            if args.model == 'alexnet' or args.model == 'vgg':  
                img = cv2.resize(img, (224,224))
            if args.model == 'googlenet' or args.model == 'resnet':
                img = cv2.resize(img, (96,96))
            img_tensor = torchvision.transforms.functional.to_tensor(img)
            img_tensor = torch.unsqueeze(img_tensor, 0)
            pred = net(img_tensor.to(device)).argmax(dim=1).item()
            if pred == 6:  # every image in this folder is a 6, so this counts correct predictions
                acc += 1
            count += 1
    print(acc/count)

The dataset is MNIST handwritten digits. In the training set the counts for digits 0-9 are: 0 (5923), 1 (6472), 2 (5985), 3 (6131), 4 (5842), 5 (5421), 6 (5918), 7 (6265), 8 (5851), 9 (5949); in the test set they are: 0 (980), 1 (1135), 2 (1032), 3 (1010), 4 (982), 5 (892), 6 (958), 7 (1028), 8 (974), 9 (1009). The classes are therefore essentially balanced. The test script only measures accuracy on digit 6, because later we will reduce the number of 6s in the training set for comparison. To save time, training runs for only 5 epochs.
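
As a quick check, the per-class counts above can be reproduced with a short script. This is only a sketch and is not part of the original code; it assumes the same ImageFolder directory layout used in train.py.

from collections import Counter
import torchvision

# Count images per class in both splits without loading any pixels.
for split in ('training', 'testing'):
    dataset = torchvision.datasets.ImageFolder('./Datasets/mnist_png/' + split)
    counts = Counter(dataset.targets)  # class index -> number of images
    print(split, {dataset.classes[idx]: n for idx, n in sorted(counts.items())})
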
Training results:

epoch 0, loss 1.443379, train acc 0.529, test acc 0.877, time 23.4s
epoch 1, loss 0.314123, train acc 0.913, test acc 0.939, time 22.1s
epoch 2, loss 0.174050, train acc 0.949, test acc 0.960, time 21.9s
epoch 3, loss 0.122714, train acc 0.963, test acc 0.971, time 21.8s
epoch 4, loss 0.096798, train acc 0.971, test acc 0.975, time 21.8s

Test results:

0.9780793319415448

Now reduce the number of digit-6 images in the training set to 59 (1/100 of the original) to simulate an imbalanced class.
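
One way to produce this reduced set is sketched below; it is not part of the original post, and the backup directory name is an assumption. It keeps a random 59 of the 5918 digit-6 training images and moves the rest aside.

import random
import shutil
from pathlib import Path

src = Path('./Datasets/mnist_png/training/6')
backup = Path('./Datasets/mnist_png/training_6_backup')
backup.mkdir(parents=True, exist_ok=True)

images = sorted(src.glob('*.png'))
random.seed(0)                         # make the subsample reproducible
keep = set(random.sample(images, 59))  # keep only 59 images

# Move everything except the 59 kept images out of the training folder.
for img in images:
    if img not in keep:
        shutil.move(str(img), str(backup / img.name))
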
Training results:

epoch 0, loss 2.200247, train acc 0.131, test acc 0.373, time 20.8s
epoch 1, loss 0.579792, train acc 0.840, test acc 0.855, time 20.5s
epoch 2, loss 0.177890, train acc 0.950, test acc 0.872, time 20.3s
epoch 3, loss 0.128251, train acc 0.963, test acc 0.880, time 20.5s
epoch 4, loss 0.103937, train acc 0.969, test acc 0.888, time 20.7s

Test results:

0.04801670146137787

As you can see, the overall test accuracy drops by about 9 percentage points, and the accuracy on digit 6 collapses by roughly 93 percentage points (from 0.978 to 0.048), which is dismal.

Now introduce the FocalLoss module (adapted from https://github.com/QunBB/DeepLearning/blob/main/trick/unbalance/loss_pt.py):

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Optional, Union


class FocalLoss(nn.Module):
    def __init__(self, alpha: Union[List[float], float], gamma: Optional[int] = 2, with_logits: Optional[bool] = True):
        """
        :param alpha: 每个类别的权重
        :param gamma:
        :param with_logits: 是否经过softmax或者sigmoid
        """
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = torch.FloatTensor([alpha]) if isinstance(alpha, float) else torch.FloatTensor(alpha)
        self.smooth = 1e-8
        self.with_logits = with_logits

    def _binary_class(self, input, target):
        prob = torch.sigmoid(input) if self.with_logits else input
        prob += self.smooth
        alpha = self.alpha.to(target.device)
        loss = -alpha * torch.pow(torch.sub(1.0, prob), self.gamma) * torch.log(prob)
        return loss

    def _multiple_class(self, input, target):
        prob = F.softmax(input, dim=1) if self.with_logits else input

        alpha = self.alpha.to(target.device)
        alpha = alpha.gather(0, target)

        target = target.view(-1, 1)

        prob = prob.gather(1, target).view(-1) + self.smooth  # avoid nan
        logpt = torch.log(prob)

        loss = -alpha * torch.pow(torch.sub(1.0, prob), self.gamma) * logpt
        return loss

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        :param input: 维度为[bs, num_classes]
        :param target: 维度为[bs]
        :return:
        """
        if len(input.shape) > 1 and input.shape[-1] != 1:
            loss = self._multiple_class(input, target)
        else:
            loss = self._binary_class(input, target)

        return loss.mean()
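
For reference, the multi-class branch above implements the standard focal loss FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t), where p_t is the predicted probability of the true class and alpha_t is that class's weight. The (1 - p_t)^gamma factor shrinks the loss of well-classified (easy) examples, so training focuses on hard and under-represented ones; with gamma = 0 and uniform alpha it reduces to ordinary cross-entropy.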

Then change the loss definition in train.py (the line loss = torch.nn.CrossEntropyLoss()) to:

    loss = FocalLoss([1, 1, 1, 1, 1, 1, 100, 1, 1, 1])

The numbers in the list are the weights assigned to the 10 classes; digit 6 is given a weight of 100 to compensate for having only 1/100 of its original training samples.
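
As a quick sanity check (a sketch, not part of the original post): with uniform weights and gamma set to 0, FocalLoss reduces to ordinary cross-entropy, so the two losses should agree up to the tiny smoothing constant.

import torch
from loss import FocalLoss

logits = torch.randn(8, 10)                   # a fake batch: 8 samples, 10 classes
targets = torch.randint(0, 10, (8,))          # fake integer labels

focal = FocalLoss(alpha=[1.0] * 10, gamma=0)  # gamma=0 disables the down-weighting term
ce = torch.nn.CrossEntropyLoss()

print(focal(logits, targets).item())          # approximately equal to the value below
print(ce(logits, targets).item())
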
Training results:

epoch 0, loss 2.045273, train acc 0.137, test acc 0.467, time 20.7s
epoch 1, loss 0.510476, train acc 0.810, test acc 0.907, time 21.3s
epoch 2, loss 0.148246, train acc 0.922, test acc 0.941, time 21.1s
epoch 3, loss 0.099026, train acc 0.944, test acc 0.953, time 21.2s
epoch 4, loss 0.075481, train acc 0.954, test acc 0.959, time 21.3s

Test results:

0.9196242171189979

Comparing the results, FocalLoss clearly alleviates the class-imbalance problem (it does not eliminate it completely, of course; a sufficiently balanced, high-quality dataset is still the better option).
