EfficientNet
Code
This is an open-set experiment on a garbage dataset containing 40 classes: only 24 classes are used during training, while all 40 garbage classes appear at test time.
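In other words, 24 of the 40 class ids are sampled as "known" classes, and every held-out class collapses into a single "unknown" label at test time. A minimal sketch of that split (the uppercase constant names are illustrative, not taken from the script; the seed, range, and sentinel 999 do match the code below):

import random

NUM_CLASSES = 40     # total garbage classes
NUM_KNOWN = 24       # classes seen during training
UNKNOWN_LABEL = 999  # sentinel covering every held-out class

random.seed(42)
train_classes = random.sample(range(NUM_CLASSES), NUM_KNOWN)
test_classes = train_classes + [UNKNOWN_LABEL]  # unknowns collapse into one bucket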
Garbage dataset download
First, the training code.
task_garbage.py
'''
@File :task_garbage.py
@Author:cjh
@Date :2022/1/16 14:45
@Desc :open-set training on the 40-class garbage dataset (24 known classes)
'''
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision
import numpy as np
import torchvision.transforms as transforms
from torchvision.transforms import autoaugment
import os
import argparse
import sys
import warnings
warnings.filterwarnings("ignore")
# os.chdir(os.path.dirname('X:/PyCharm/211211-DL-OSR/DL_OSR/model/OpenMax'))
# sys.path.append("../..")
from torch.optim import lr_scheduler
import backbones.cifar10 as models
from datasets import GARBAGE40_Dataset
from utils import adjust_learning_rate, progress_bar, Logger, mkdir_p, Evaluation
from openmax import compute_train_score_and_mavs_and_dists,fit_weibull,openmax
from Modelbuilder import Network
from Plotter import plot_feature
from garbage_transform import Resize, Cutout, RandomErasing
from garbage_loss import LabelSmoothSoftmaxCE, LabelSmoothingLoss, FocalLoss
from checkpoints import efficientnet
# from pytorch_toolbelt import losses as L
parser=argparse.ArgumentParser()
parser.add_argument('--lr',default=0.01,type=float,help='learning rate')
# ./checkpoints/garbage/ResNet/ResNet18.pth
parser.add_argument('--resume', default=None, type=str, metavar='PATH', help='path to the latest checkpoint to load')
parser.add_argument('--arch',default='EfficientNet_B5',type=str,help='choosing network')
parser.add_argument('--bs',default=8,type=int,help='batch size')
parser.add_argument('--es', default=40, type=int, help='number of training epochs')
parser.add_argument('--train_class_num',default=24,type=int,help='classes used in training')
parser.add_argument('--test_class_num',default=40,type=int,help='classes used in testing')
parser.add_argument('--includes_all_train_class',default=True,action='store_true',
help='testing uses all known classes')
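# note: default=True combined with action='store_true' means this flag is effectively always on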
parser.add_argument('--embed_dim', default=2, type=int, help='embedding feature dimension')
parser.add_argument('--evaluate',default=False,action='store_true',help='evaluating')
parser.add_argument('--weibull_tail', default=20, type=int, help='tail size used to fit the per-class Weibull models')
parser.add_argument('--weibull_alpha', default=5, type=int, help='number of top classes whose scores OpenMax revises')
parser.add_argument('--weibull_threshold', default=0.9, type=float, help='probability threshold below which a sample is rejected as unknown')
# Parameters for stage plotting
# parser.add_argument('--plot', default=False, action='store_true', help='Plotting the training set.')
# parser.add_argument('--plot_max', default=0, type=int, help='max examples to plot in each class, 0 indicates all.')
# parser.add_argument('--plot_quality', default=200, type=int, help='DPI of plot figure')
args=parser.parse_args()
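# A typical invocation (illustrative; these are just the defaults spelled out):
# python task_garbage.py --arch EfficientNet_B5 --bs 8 --es 40 --lr 0.01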
def main():
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)
best_acc = 0 # best test accuracy
start_epoch = 0 # start from epoch 0 or last checkpoint epoch
# checkpoint
args.checkpoint = './checkpoints/garbage/' + args.arch
if not os.path.isdir(args.checkpoint):
mkdir_p(args.checkpoint)
# folder to save figures
args.plotfolder = './checkpoints/garbage/' + args.arch + '/plotter'
if not os.path.isdir(args.plotfolder):
mkdir_p(args.plotfolder)
# Data
print('==> Preparing data..')
picture_size = 256
train_transforms = transforms.Compose([
Resize((int(288 * (256 / 224)), int(288 * (256 / 224)))),
transforms.CenterCrop((picture_size, picture_size)),
transforms.RandomAffine(degrees=0, translate=(0.05, 0.05)),
transforms.RandomVerticalFlip(),
autoaugment.AutoAugment(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
Cutout(probability=0.5, size=64, mean=[0.0, 0.0, 0.0]),
RandomErasing(probability=0.0, mean=[0.485, 0.456, 0.406]),
])
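    # note: Cutout fires with probability 0.5, while RandomErasing is configured with
    # probability=0.0 and is therefore effectively disabled here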
test_transforms = transforms.Compose([
Resize((int(288 * (256 / 224)), int(288 * (256 / 224)))),
transforms.CenterCrop((picture_size, picture_size)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
random.seed(42)
train_classes = random.sample(range(0, 40), args.train_class_num)
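    # every held-out class is mapped to the single sentinel label 999 ("unknown") at test time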
    test_classes = train_classes + [999]
trainset = GARBAGE40_Dataset(root='../../data/garbage', train=True,
transform=train_transforms,
train_class_num=args.train_class_num, test_class_num=args.test_class_num,
includes_all_train_class=args.includes_all_train_class,
train_classes=train_classes)
testset = GARBAGE40_Dataset(root='../../data/garbage', train=False,
transform=test_transforms,
train_class_num=args.train_class_num, test_class_num=args.test_class_num,
includes_all_train_class=args.includes_all_train_class,
train_classes=train_classes)
# data loader
trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.bs, shuffle=True, num_workers=0)
testloader = torch.utils.data.DataLoader(testset, batch_size=args.bs, shuffle=False, num_workers=0)
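    # num_workers=0 loads batches in the main process; a larger value usually speeds up loading on Linux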
#Model
# net=Network(backbone=args.arch,num_classes=args.train_class_num, embed_dim=args.embed_dim)
# fea_dim = net.classifier.in_features
# net = net.to(device)
if args.arch=='ResNet18':
net = torchvision.models.resnet18(pretrained=True).to(device)
model_wight_path = "checkpoints/garbage/ResNet18/best_model.pth"
assert os.path.exists(model_wight_path), "file {} dose not exist.".format(model_wight_path) # 若路径不存在,则打印信息
net.load_state_dict(torch.load(model_wight_path, map_location=device), strict=False)
net.fc = nn.Sequential(
nn.Linear(net.fc.in_features, 256),
nn.ReLU(),
nn.Dropout(0.4),
nn.Linear(256, args.train_class_num)
)
if args.arch == 'ResNet50':
net = torchvision.models.resnet50(pretrained=True).to(device)
model_wight_path = "checkpoints/garbage/ResNet50/best_model.pth"
assert os.path.exists(model_wight_path), "file {} dose not exist.".format(model_wight_path) # 若路径不存在,则打印信息
net.load_state_dict(torch.load(model_wight_path, map_location=device), strict=False)
net.fc = nn.Sequential(
nn.Linear(net.fc.in_features, 256),
nn.ReLU(),
nn.Dropout(0.4),
nn.Linear(256, args.train_class_num)
)
if args.arch == 'EfficientNet_B5':
# net = torchvision.models.efficientnet_b5(pretrained=True).to(device)
net = efficientnet.efficientnet_b5().to(device)
        # model_weight_path = "checkpoints/garbage/EfficientNet_B5/efficientnetb5.pth"
        model_weight_path = "checkpoints/garbage/EfficientNet_B5/best_model.pth"
        assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)  # abort if the weight file is missing
        net.load_state_dict(torch.load(model_weight_path, map_location=device), strict=False)
net.classifier= nn.Sequential(
nn.Dropout(p=0.4, inplace=True),
nn.Linear(2048, args.train_class_num),
)
if args.arch == 'EfficientNet_B7':
# net = torchvision.models.efficientnet_b7(pretrained=True).to(device)
net = efficientnet.efficientnet_b7().to(device)
        net.classifier = nn.Sequential(
            nn.Dropout(p=0.4, inplace=True),
            nn.Linear(2560, args.train_class_num),  # efficientnet-b7's feature width is 2560 (2048 is b5's)
        )
if args.arch == 'ResNext101_32x16d_wsl':
net = torch.hub.load('facebookresearch/WSL-Images', 'resnext101_32x16d_wsl')
net.fc = nn.Sequential(
nn.Dropout(0.2),
nn.Linear(2048, args.train_class_num)
)
if args.arch == 'Resnext101_32x8d_wsl':
net = torch.hub.load('facebookresearch/WSL-Images', 'resnext101_32x8d_wsl')
net.fc = nn.Sequential(
nn.Dropout(0.2),
nn.Linear(2048, args.train_class_num)
)
if args.arch == 'Resnext50_32x4d':
net = torchvision.models.resnext50_32x4d(pretrained=True).to(device)
net.fc = nn.Sequential(
nn.Dropout(0.2),
nn.Linear(2048, args.train_class_num)
)
# from efficientnet_pytorch import EfficientNet
# model = EfficientNet.from_pretrained('efficientnet-b0')
# model = EfficientNet.from_pretrained(,num_classes=args.train_class_num)
if args.arch == 'EfficientNet_B3':
net = torchvision.models.efficientnet_b3(pretrained=True).to(device)
net.classifier= nn.Sequential(
nn.Linear(1536, 256),
nn.ReLU(),
nn.Dropout(p=0.4),
nn.Linear(256, args.train_class_num),
# nn.Dropout(p=0.4, inplace=True),
# nn.Linear(1024, args.train_class_num),
)
if device == 'cuda':
net = torch.nn.DataParallel(net)
cudnn.benchmark = True
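    # checkpoints saved from a DataParallel model prefix every key with 'module.';
    # the resume branch below strips that prefix so the weights also load on CPU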
    if args.resume is not None:
# Load checkpoint.
if os.path.isfile(args.resume):
print('==> Resuming from checkpoint..')
#for cpu load cuda model
checkpoint = torch.load(args.resume,map_location=torch.device('cpu'))
net.load_state_dict({
k.replace('module.', ''): v for k, v in checkpoint['net'].items()})
#for gpu load cuda model for cpu load cpu model
# checkpoint = torch.load(args.resume)
# net.load_state_dict(checkpoint['net'])
# best_acc = checkpoint['acc']
# print("BEST_ACCURACY: "+str(best_acc))
start_epoch = checkpoint['epoch']
logger = Logger(os.path.join(args.checkpoint, 'log.txt'), resume=True)
else:
print("=> no checkpoint found at '{}'".format(args.resume))
else:
logger = Logger(os.path.join(args.checkpoint, 'log.txt'))
logger.set_names(['Epoch', 'Learning Rate', 'Train Loss','Train Acc.', 'Test Loss', 'Test Acc.'])
criterion = nn.CrossEntropyLoss()
# criterion = LabelSmoothSoftmaxCE(lb_pos=0.9, lb_neg=5e-3)
# criterion = LabelSmoothingLoss(classes=args.train_class_num, smoothing=0.1)
# criterion = FocalLoss(alpha=0.25)
optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
# optimizer = optim.RAdam(net.parameters(),lr=args.lr,betas=(0.9, 0.999), eps=1e-8,weight_decay=5e-4)
    # only the last assignment takes effect; the alternatives are kept commented for reference
    # scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.7, patience=3, verbose=True)
    # scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, patience=2, verbose=False)
    # scheduler = lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=3, T_mult=2)
    # scheduler = lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=2, T_mult=2, eta_min=1e-5)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=8, gamma=0.5)
# epoch=0
    best_acc = 0
if not args.evaluate:
for epoch in range(start_epoch, args.es):
print('\nEpoch: %d Learning rate: %f' % (epoch+1, optimizer.param_groups[0]['lr']))
# adjust_learning_rate(optimizer, epoch, args.lr, step=20)
train_loss, train_acc = train(net, trainloader, optimizer, criterion, device, train_classes)
if epoch == args.es - 1:
save_model(net, None, epoch, os.path.join(args.checkpoint,'last_model.pth'))
test_loss, test_acc = 0, 0
            try:
                test_loss, test_acc = test(epoch, net, trainloader, testloader, criterion, device, test_classes)
            except Exception as e:
                print('test() failed: %s' % e)  # keep training even if evaluation fails
            # scheduler.step(test_loss)  # use this form with ReduceLROnPlateau
            scheduler.step()  # StepLR takes no metric; passing the loss here would be misread as an epoch index
if best_ac<test_acc:
best_ac=test_acc
print("The best Acc: ",best_ac)
# save_model(net, None, epoch, os.path.join(args.checkpoint, 'best_model.pth'))
torch.save(net.state_dict(), os.path.join(args.checkpoint, 'best_model.pth'))
# save_model(net, best_ac, epoch, os.path.join(args.checkpoint, 'best_model.pth'))
#
logger.append([epoch+1, optimizer.param_groups[0]['lr'], train_loss, train_acc, test_loss, test_acc])
# plot_feature(net, trainloader, device, args.plotfolder,train_classes, epoch=epoch,
# plot_class_num=args.train_class_num, maximum=args.plot_max, plot_quality=args.plot_quality)
# if (epoch+1)%20==0:
# try:
# test(epoch, net, trainloader, testloader, criterion, device,test_classes)
# except:
# pass
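        # final pass over the test set; the epoch argument 99999 is just a sentinel tag for logging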
test(99999, net, trainloader, testloader, criterion, device, test_classes)
# plot_feature(net, testloader, device, args.plotfolder,train_classes, epoch="test",
# plot_class_num=args.train_class_num+1, maximum=args.plot_max, plot_quality=args.plot_quality)
logger.close()
# Training
def train(net,trainloader,optimizer,criterion,device,train_classes):
net.train()
train_loss = 0
correct = 0
total = 0
for batch_idx, (inputs, targets) in enumerate(trainloader):
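        # remap raw dataset labels to contiguous indices 0..train_class_num-1 for CrossEntropyLoss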
        onehot_targets_index = [train_classes.index(int(t)) for t in targets]
        targets = torch.LongTensor(onehot_targets_index)
inputs, targets = inputs.to(device), targets.to(device)
optimizer.zero_grad()
outputs = net(inputs)
# onehot_targets=torch.zeros((outputs.shape[0],outputs.shape[1]))
# onehot_targets[range(outputs.shape[0]), onehot_targets_index]=1
loss = criterion(outputs, targets)
# loss = torch.nn.functional.cross_entropy(outputs, targets)
loss.backward()
optimizer.step()
train_loss += loss.item()
_, predicted = outputs.max(1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()
progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
% (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
return train_loss/(batch_idx+1), correct/total
def test(epoch, net, trainloader, testloader, criterion, device, test_classes):
net.eval()
test_loss = 0
correct = 0
total = 0
scores, labels = [], []
with torch.no_grad():
for batch_idx, (inputs, targets) in enumerate(testloader):
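            # known classes map to indices 0..23 (with the default 24 known classes);
            # the sentinel 999, i.e. any unknown sample, maps to index 24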
            onehot_targets_index = [test_classes.index(int(t)) for t in targets]
            targets = torch.LongTensor(onehot_targets_index)
# image_2 = transforms.RandomAffine(degrees=0, translate=(0.05, 0.05))(inputs).to(device)
# image_3 = transforms.RandomHorizontalFlip()(inputs).to(device)
# image_4 = Cutout(probability=0.5, size=64, mean=[0.0, 0.0, 0.0])(inputs).to(device)
# image_5 = transforms.RandomVerticalFlip()(inputs).to(device)
inputs