Building an EfficientNet Network with OpenMax in PyTorch

This code implements an EfficientNet-based deep learning model for classifying a 40-class garbage dataset. The model is trained on 24 classes and tested on all 40, using OpenMax to recalibrate the softmax output and handle the open-set recognition problem. The pipeline covers data preprocessing, model construction, training and validation, and model saving, plus utilities for computing the training set's mean and standard deviation, a progress bar, and learning-rate scheduling.
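Before the full script, here is a minimal sketch of what the OpenMax recalibration step does to a single logit vector. This is a toy illustration, not the helpers imported below: openmax_recalibrate, mavs, and toy_cdf are hypothetical names standing in for the real fit_weibull/openmax functions. The top-alpha logits are damped by a Weibull CDF of the distance to each class's mean activation vector (MAV), and the removed activation mass becomes the score of an extra "unknown" class before the softmax:

import numpy as np

def openmax_recalibrate(logits, mavs, weibull_cdf, alpha=5):
    """Revise the top-alpha logits by their Weibull 'unknownness' and route
    the removed activation mass to an extra (K+1-th) unknown class."""
    ranked = np.argsort(logits)[::-1]               # class ids, most confident first
    weights = np.ones_like(logits, dtype=float)
    for rank, cls in enumerate(ranked[:alpha]):
        dist = np.linalg.norm(logits - mavs[cls])   # distance to class-cls mean activation vector
        weights[cls] = 1.0 - ((alpha - rank) / alpha) * weibull_cdf(dist, cls)
    revised = logits * weights
    unknown = np.sum(logits * (1.0 - weights))      # mass shifted to "unknown"
    scores = np.append(revised, unknown)
    exp = np.exp(scores - scores.max())             # numerically stable softmax over K+1 scores
    return exp / exp.sum()

# Toy usage: 3 known classes, a made-up Weibull CDF standing in for a fitted model.
logits = np.array([4.0, 1.0, 0.5])
mavs = np.array([[3.5, 1.2, 0.6], [0.8, 3.0, 0.4], [0.5, 0.9, 2.8]])
toy_cdf = lambda d, c: 1.0 - np.exp(-(d / 2.0) ** 1.5)
print(openmax_recalibrate(logits, mavs, toy_cdf, alpha=2))

If the unknown slot ends up with the largest probability, or the winning probability falls below a threshold (cf. --weibull_threshold in the script), the sample is rejected as unknown.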


Bilibili account @狼群里的小杨. Remember to like, save, and follow, and give it a one-click triple!

EfficientNet

(Figure: EfficientNet architecture)
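As a quick reference (from the EfficientNet paper, Tan & Le 2019), the family scales network depth d, width w, and input resolution r jointly with a single compound coefficient φ:

d = α^φ, w = β^φ, r = γ^φ, subject to α · β² · γ² ≈ 2 and α, β, γ ≥ 1,

with α = 1.2, β = 1.1, γ = 1.15 found by grid search on the B0 baseline. The B3/B5/B7 variants used below are simply larger settings of φ.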

Code

This is an open-set experiment on a garbage dataset containing 40 classes. Only 24 of them are used during training; all 40 appear at test time.
Garbage dataset download
First, the training code:
task_garbage.py

'''
@File  :task_garbage.py
@Author:cjh
@Date  :2022/1/16 14:45
@Desc  :
'''
import random

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import numpy as np
import torchvision.transforms as transforms
from torchvision.transforms import autoaugment

import os
import argparse
import sys
import warnings
warnings.filterwarnings("ignore")


# os.chdir(os.path.dirname('X:/PyCharm/211211-DL-OSR/DL_OSR/model/OpenMax'))
# sys.path.append("../..")
from torch.optim import lr_scheduler

import backbones.cifar10 as models
from datasets import GARBAGE40_Dataset
from utils import adjust_learning_rate, progress_bar, Logger, mkdir_p, Evaluation

from openmax import compute_train_score_and_mavs_and_dists,fit_weibull,openmax
from Modelbuilder import Network
from Plotter import plot_feature
from garbage_transform import Resize, Cutout, RandomErasing
from garbage_loss import LabelSmoothSoftmaxCE, LabelSmoothingLoss, FocalLoss
from checkpoints import efficientnet
# from pytorch_toolbelt import losses as L

parser=argparse.ArgumentParser()
parser.add_argument('--lr',default=0.01,type=float,help='learning rate')
# ./checkpoints/garbage/ResNet/ResNet18.pth
parser.add_argument('--resume',default=None,type=str,metavar='PATH',help='path to the latest checkpoint (.pth) to load')
parser.add_argument('--arch',default='EfficientNet_B5',type=str,help='choosing network')
parser.add_argument('--bs',default=8,type=int,help='batch size')
parser.add_argument('--es',default=40,type=int,help='epoches')
parser.add_argument('--train_class_num',default=24,type=int,help='classes used in training')
parser.add_argument('--test_class_num',default=40,type=int,help='classes used in testing')
parser.add_argument('--includes_all_train_class',default=True,action='store_true',
                    help='testing uses all known classes (default=True, so this flag is effectively always on)')
parser.add_argument('--embed_dim', default=2, type=int, help='embedding feature dimension')
parser.add_argument('--evaluate',default=False,action='store_true',help='evaluating')

parser.add_argument('--weibull_tail', default=20, type=int, help='tail size used to fit each per-class Weibull model')
parser.add_argument('--weibull_alpha', default=5, type=int, help='number of top classes whose scores OpenMax revises')
parser.add_argument('--weibull_threshold', default=0.9, type=float, help='probability threshold below which a sample is rejected as unknown')

# Parameters for stage plotting
# parser.add_argument('--plot', default=False, action='store_true', help='Plotting the training set.')
# parser.add_argument('--plot_max', default=0, type=int, help='max examples to plot in each class, 0 indicates all.')
# parser.add_argument('--plot_quality', default=200, type=int, help='DPI of plot figure')

args=parser.parse_args()

def main():
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(device)
    best_acc = 0  # best test accuracy
    start_epoch = 0  # start from epoch 0 or last checkpoint epoch

    # checkpoint
    args.checkpoint = './checkpoints/garbage/' + args.arch
    if not os.path.isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    # folder to save figures
    args.plotfolder = './checkpoints/garbage/' + args.arch + '/plotter'
    if not os.path.isdir(args.plotfolder):
        mkdir_p(args.plotfolder)

    # Data
    print('==> Preparing data..')
    picture_size = 256
    train_transforms = transforms.Compose([
        Resize((int(288 * (256 / 224)), int(288 * (256 / 224)))),
        transforms.CenterCrop((picture_size, picture_size)),
        transforms.RandomAffine(degrees=0, translate=(0.05, 0.05)),
        transforms.RandomVerticalFlip(),
        autoaugment.AutoAugment(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        Cutout(probability=0.5, size=64, mean=[0.0, 0.0, 0.0]),
        RandomErasing(probability=0.0, mean=[0.485, 0.456, 0.406]),
    ])
    test_transforms = transforms.Compose([
        Resize((int(288 * (256 / 224)), int(288 * (256 / 224)))),
        transforms.CenterCrop((picture_size, picture_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
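    # Note: int(288 * (256 / 224)) = 329, so images are resized to 329x329 and
    # then center-cropped to 256x256 in both the train and test pipelines.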

    random.seed(42)
    train_classes = random.sample(range(0, 40), args.train_class_num)
    test_classes = train_classes + [999]  # 999 stands in for every unseen class ("unknown")

    trainset = GARBAGE40_Dataset(root='../../data/garbage', train=True,
                                 transform=train_transforms,
                     train_class_num=args.train_class_num, test_class_num=args.test_class_num,
                     includes_all_train_class=args.includes_all_train_class,
                                 train_classes=train_classes)
    testset = GARBAGE40_Dataset(root='../../data/garbage', train=False,
                                transform=test_transforms,
                    train_class_num=args.train_class_num, test_class_num=args.test_class_num,
                    includes_all_train_class=args.includes_all_train_class,
                                train_classes=train_classes)
    # data loader
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.bs, shuffle=True, num_workers=0)
    testloader = torch.utils.data.DataLoader(testset, batch_size=args.bs, shuffle=False, num_workers=0)
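    # num_workers=0 keeps data loading in the main process (safest on Windows);
    # on Linux a larger value typically speeds up the input pipeline.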

    #Model
    # net=Network(backbone=args.arch,num_classes=args.train_class_num, embed_dim=args.embed_dim)
    # fea_dim = net.classifier.in_features
    # net = net.to(device)
    if args.arch=='ResNet18':
        net = torchvision.models.resnet18(pretrained=True).to(device)
        model_wight_path = "checkpoints/garbage/ResNet18/best_model.pth"
        assert os.path.exists(model_wight_path), "file {} dose not exist.".format(model_wight_path)  # 若路径不存在,则打印信息
        net.load_state_dict(torch.load(model_wight_path, map_location=device), strict=False)
        net.fc = nn.Sequential(
            nn.Linear(net.fc.in_features, 256),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(256, args.train_class_num)
        )

    if args.arch == 'ResNet50':
        net = torchvision.models.resnet50(pretrained=True).to(device)
        model_wight_path = "checkpoints/garbage/ResNet50/best_model.pth"
        assert os.path.exists(model_wight_path), "file {} dose not exist.".format(model_wight_path)  # 若路径不存在,则打印信息
        net.load_state_dict(torch.load(model_wight_path, map_location=device), strict=False)
        net.fc = nn.Sequential(
            nn.Linear(net.fc.in_features, 256),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(256, args.train_class_num)
        )
    if args.arch == 'EfficientNet_B5':
        # net = torchvision.models.efficientnet_b5(pretrained=True).to(device)
        net = efficientnet.efficientnet_b5().to(device)
        # model_weight_path = "checkpoints/garbage/EfficientNet_B5/efficientnetb5.pth"
        model_weight_path = "checkpoints/garbage/EfficientNet_B5/best_model.pth"
        assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)
        net.load_state_dict(torch.load(model_weight_path, map_location=device), strict=False)
        net.classifier= nn.Sequential(
            nn.Dropout(p=0.4, inplace=True),
            nn.Linear(2048, args.train_class_num),
        )
    if args.arch == 'EfficientNet_B7':
        # net = torchvision.models.efficientnet_b7(pretrained=True).to(device)
        net = efficientnet.efficientnet_b7().to(device)
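        # NOTE: torchvision's stock efficientnet_b7 ends in a 2560-wide feature
        # vector; the 2048 below assumes the local checkpoints.efficientnet
        # variant matches the B5 head width.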
        net.classifier= nn.Sequential(
            nn.Dropout(p=0.4, inplace=True),
            nn.Linear(2048, args.train_class_num),
        )
    if args.arch == 'ResNext101_32x16d_wsl':
        net = torch.hub.load('facebookresearch/WSL-Images', 'resnext101_32x16d_wsl')
        net.fc = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(2048, args.train_class_num)
        )


    if args.arch == 'Resnext101_32x8d_wsl':
        net = torch.hub.load('facebookresearch/WSL-Images', 'resnext101_32x8d_wsl')
        net.fc = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(2048, args.train_class_num)
        )
    if args.arch == 'Resnext50_32x4d':
        net = torchvision.models.resnext50_32x4d(pretrained=True).to(device)
        net.fc = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(2048, args.train_class_num)
        )
    # from efficientnet_pytorch import EfficientNet
    # model = EfficientNet.from_pretrained('efficientnet-b0')
    # model = EfficientNet.from_pretrained(,num_classes=args.train_class_num)
    if args.arch == 'EfficientNet_B3':
        net = torchvision.models.efficientnet_b3(pretrained=True).to(device)
        net.classifier= nn.Sequential(
            nn.Linear(1536, 256),
            nn.ReLU(),
            nn.Dropout(p=0.4),
            nn.Linear(256, args.train_class_num),
            # nn.Dropout(p=0.4, inplace=True),
            # nn.Linear(1024, args.train_class_num),
        )

    if device == 'cuda':
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = True
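        # DataParallel prefixes parameter names with 'module.', which is why
        # the resume branch below strips that prefix when loading on CPU.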
    if args.resume is not None:
        # Load checkpoint.
        if os.path.isfile(args.resume):
            print('==> Resuming from checkpoint..')

            #for cpu load cuda model
            checkpoint = torch.load(args.resume,map_location=torch.device('cpu'))
            net.load_state_dict({k.replace('module.', ''): v for k, v in checkpoint['net'].items()})

            #for gpu load cuda model for cpu load cpu model
            # checkpoint = torch.load(args.resume)
            # net.load_state_dict(checkpoint['net'])


            # best_acc = checkpoint['acc']
            # print("BEST_ACCURACY: "+str(best_acc))
            start_epoch = checkpoint['epoch']
            logger = Logger(os.path.join(args.checkpoint, 'log.txt'), resume=True)
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
            # fall back to a fresh log so `logger` is always defined
            logger = Logger(os.path.join(args.checkpoint, 'log.txt'))
            logger.set_names(['Epoch', 'Learning Rate', 'Train Loss', 'Train Acc.', 'Test Loss', 'Test Acc.'])
    else:
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'))
        logger.set_names(['Epoch', 'Learning Rate', 'Train Loss', 'Train Acc.', 'Test Loss', 'Test Acc.'])

    criterion = nn.CrossEntropyLoss()
    # criterion = LabelSmoothSoftmaxCE(lb_pos=0.9, lb_neg=5e-3)
    # criterion = LabelSmoothingLoss(classes=args.train_class_num, smoothing=0.1)
    # criterion = FocalLoss(alpha=0.25)


    optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
    # optimizer = optim.RAdam(net.parameters(),lr=args.lr,betas=(0.9, 0.999), eps=1e-8,weight_decay=5e-4)
    # scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.7, patience=3, verbose=True)
    # Pick one scheduler; the alternatives are kept commented for reference.
    # scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, patience=2, verbose=False)
    # scheduler = lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=3, T_mult=2)
    # scheduler = lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=2, T_mult=2, eta_min=1e-5)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=8, gamma=0.5)

    if not args.evaluate:
        for epoch in range(start_epoch, args.es):
            print('\nEpoch: %d   Learning rate: %f' % (epoch+1, optimizer.param_groups[0]['lr']))
            # adjust_learning_rate(optimizer, epoch, args.lr, step=20)
            train_loss, train_acc = train(net, trainloader, optimizer, criterion, device, train_classes)
            if epoch == args.es - 1:
                save_model(net, None, epoch, os.path.join(args.checkpoint,'last_model.pth'))
            test_loss, test_acc = 0, 0
            try:
                test_loss, test_acc = test(epoch, net, trainloader, testloader, criterion, device, test_classes)
            except Exception as e:
                # unknown-class targets can make the closed-set loss fail; skip this epoch's eval
                print('test() skipped:', e)
            # StepLR advances once per epoch; a ReduceLROnPlateau scheduler
            # would need the metric instead: scheduler.step(train_loss)
            scheduler.step()

            if best_acc < test_acc:
                best_acc = test_acc
                print("The best Acc: ", best_acc)
                # save_model(net, None, epoch, os.path.join(args.checkpoint, 'best_model.pth'))
                torch.save(net.state_dict(), os.path.join(args.checkpoint, 'best_model.pth'))
                # save_model(net, best_ac, epoch, os.path.join(args.checkpoint, 'best_model.pth'))
            #
            logger.append([epoch+1, optimizer.param_groups[0]['lr'], train_loss, train_acc, test_loss, test_acc])
            # plot_feature(net, trainloader, device, args.plotfolder,train_classes, epoch=epoch,
            #              plot_class_num=args.train_class_num, maximum=args.plot_max, plot_quality=args.plot_quality)
            # if (epoch+1)%20==0:
            #     try:
            #         test(epoch, net, trainloader, testloader, criterion, device,test_classes)
            #     except:
            #         pass
    test(99999, net, trainloader, testloader, criterion, device, test_classes)
    # plot_feature(net, testloader, device, args.plotfolder,train_classes, epoch="test",
    #              plot_class_num=args.train_class_num+1, maximum=args.plot_max, plot_quality=args.plot_quality)
    logger.close()

# Training
def train(net,trainloader,optimizer,criterion,device,train_classes):
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        # map original dataset labels to contiguous training indices 0..train_class_num-1
        onehot_targets_index = [train_classes.index(int(t)) for t in targets]
        targets = torch.LongTensor(onehot_targets_index)

        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)

        # onehot_targets=torch.zeros((outputs.shape[0],outputs.shape[1]))
        # onehot_targets[range(outputs.shape[0]), onehot_targets_index]=1

        loss = criterion(outputs, targets)
        # loss = torch.nn.functional.cross_entropy(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
            % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
    return train_loss/(batch_idx+1), correct/total



def test(epoch, net, trainloader, testloader, criterion, device, test_classes):
    net.eval()

    test_loss = 0
    correct = 0
    total = 0

    scores, labels = [], []
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            # map labels to training indices; unknown (999) maps to index train_class_num
            onehot_targets_index = [test_classes.index(int(t)) for t in targets]
            targets = torch.LongTensor(onehot_targets_index)

            # image_2 = transforms.RandomAffine(degrees=0, translate=(0.05, 0.05))(inputs).to(device)
            # image_3 = transforms.RandomHorizontalFlip()(inputs).to(device)
            # image_4 = Cutout(probability=0.5, size=64, mean=[0.0, 0.0, 0.0])(inputs).to(device)
            # image_5 = transforms.RandomVerticalFlip()(inputs).to(device)

            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            # unknown targets (index == train_class_num) have no logit, so the
            # closed-set loss is computed on known samples only
            known = targets < outputs.shape[1]
            if known.any():
                test_loss += criterion(outputs[known], targets[known]).item()
            scores.append(outputs)
            labels.append(targets)
            total += targets.size(0)
            progress_bar(batch_idx, len(testloader), 'Loss: %.3f' % (test_loss / (batch_idx + 1)))

    # The original post is truncated at this point; the remainder reconstructs
    # the usual OpenMax evaluation flow from the helpers imported above.
    scores = torch.cat(scores, dim=0).cpu().numpy()[:, np.newaxis, :]
    labels = torch.cat(labels, dim=0).cpu().numpy()
    print("Fitting Weibull distribution...")
    _, mavs, dists = compute_train_score_and_mavs_and_dists(args.train_class_num, trainloader, device, net)
    categories = list(range(args.train_class_num))
    weibull_model = fit_weibull(mavs, dists, categories, args.weibull_tail, "euclidean")

    pred_softmax, pred_openmax = [], []
    for score in scores:
        so, ss = openmax(weibull_model, categories, score, 0.5, args.weibull_alpha, "euclidean")
        pred_softmax.append(np.argmax(ss) if np.max(ss) >= args.weibull_threshold else args.train_class_num)
        pred_openmax.append(np.argmax(so) if np.max(so) >= args.weibull_threshold else args.train_class_num)

    print("Evaluation...")
    eval_softmax = Evaluation(pred_softmax, labels)
    eval_openmax = Evaluation(pred_openmax, labels)
    print("Softmax-with-threshold accuracy: %.3f" % eval_softmax.accuracy)
    print("OpenMax accuracy: %.3f" % eval_openmax.accuracy)
    return test_loss / (batch_idx + 1), eval_openmax.accuracy


if __name__ == '__main__':
    main()
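With the dataset and checkpoints laid out as the paths above assume, a typical run looks like:

python task_garbage.py --arch EfficientNet_B5 --bs 8 --es 40 --lr 0.01

Passing --evaluate skips the training loop and goes straight to the OpenMax test pass.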