提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档
前言
提示:想了解图像分类最新的算法的可以学习一下:
在当今严重过度参数化的模型中,训练损失的值几乎无法保证模型的泛化能力。事实上,像通常所做的那样,仅优化训练损失值很容易导致模型质量不理想。在先前将损失景观的几何形状与泛化联系起来的工作的推动下,我们引入了一种新颖、有效的程序,以同时最小化损失值和损失锐度。特别是,我们的程序“锐度感知最小化”(SAM)寻找位于具有一致低损耗的邻域中的参数;这个公式导致了一个最小-最大优化问题,在这个问题上可以有效地执行梯度下降。我们提出的实证结果表明,SAM 改进了各种基准数据集(例如 CIFAR-{10, 100}、ImageNet、微调任务)和模型的模型泛化能力,从而为一些。此外,我们发现 SAM 本身对标签噪声的鲁棒性与专门针对噪声标签学习的最先进程序所提供的鲁棒性相当。
文章链接:https://openreview.net/pdf?id=6Tm1mposlrM
官方代码链接:https://github.com/google-research/sam
一、作者做了哪些工作和贡献
文章提出了一种新的高效、可扩展且有效的方法来提高模型泛化能力,该方法直接利用损失景观的几何形状及其与泛化的联系,并且是对现有技术的有力补充。我们特别做出以下贡献:
- 我们引入了锐度感知最小化(SAM),这是一种通过同时最小化损失值和损失锐度来提高模型泛化能力的新颖程序。 SAM 通过寻找位于具有一致低损失值的邻域中的参数(而不是仅本身具有低损失值的参数,如图 1 的中图和右图所示)来发挥作用,并且可以高效且轻松地实现。
- 我们通过严格的实证研究表明,使用 SAM 可以提高一系列广泛研究的计算机视觉任务(例如 CIFAR-{10, 100}、ImageNet、微调任务)和模型的模型泛化能力,如左图所示图 1. 例如,应用 SAM 为许多已经深入研究的任务(例如 ImageNet、CIFAR-{10, 100}、SVHN、Fashion-MNIST 和标准)带来了新颖的最先进性能一组图像分类微调任务(例如,Flowers、Stanford Cars、Oxford Pets 等)。
- 我们表明,SAM 还提供了与专门针对噪声标签学习的最先进程序所提供的鲁棒性来处理标签噪声。 • 通过 SAM 提供的镜头,我们通过提出一个有前途的新锐度概念(我们称之为 m-锐度),进一步阐明了损失锐度和泛化之间的联系。
具体论文,大家自己下载学习,研究!!不过多赘述!
二、实验结果
!!!!重点来了
3.如何训练自己的数据集
提示:文章作者并未使用Pytorch框架,这里我已Pytorch框架为例:
我参考的代码仓库是:https://github.com/davda54/sam
由于他们都是针对cifar-100数据集进行训练的,如何训练自己的数据集?
根据上述的仓库代码改写,训练自己的数据集。
/dataset
/train
/class1
image1.jpg
image2.jpg
......jpg
/class2
image1.jpg
image2.jpg
.....、
/test
/class1
image1.jpg
image2.jpg
.....
/class2
image1.jpg
image2.jpg
.....、
3.1.读入数据
在下面图片文件下,创建一个读取数据的python文件,my_dataset.py
注意,这里,train_dataset 和 test_dataset,是你数据集的具体路径,我这里的演示是ubuntu环境
windows 下的路径:C: \\dataset\\path\\train 具体路径,参考网上其他教程
import os
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, datasets
def load_data(batch_size, num_workers):
train_transform = transforms.Compose([
transforms.Resize((32, 32)), # 调整大小以适应网络输入
transforms.ToTensor(),
])
test_transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor(),
])
train_dataset = datasets.ImageFolder(root='/path/train/', transform=train_transform)
test_dataset = datasets.ImageFolder(root='/path/test/', transform=test_transform)
print(f"Number of training images: {len(train_dataset)}")
print(f"Number of testing images: {len(test_dataset)}")
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
return train_loader, test_loader
3.2 修改train.py训练脚本
import argparse
import torch
import os
from model.wide_res_net import WideResNet
from model.smooth_cross_entropy import smooth_crossentropy
from data.my_dataset import load_data
from utility.log import Log
from utility.initialize import initialize
from utility.step_lr import StepLR
from utility.bypass_bn import enable_running_stats, disable_running_stats
import sys; sys.path.append("..")
from sam import SAM
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--adaptive", default=True, type=bool, help="True if you want to use the Adaptive SAM.")
parser.add_argument("--batch_size", default=8, type=int, help="Batch size used in the training and validation loop.")
parser.add_argument("--depth", default=16, type=int, help="Number of layers.")
parser.add_argument("--dropout", default=0.0, type=float, help="Dropout rate.")
parser.add_argument("--epochs", default=100, type=int, help="Total number of epochs.")
parser.add_argument("--label_smoothing", default=0.1, type=float, help="Use 0.0 for no label smoothing.")
parser.add_argument("--learning_rate", default=0.1, type=float, help="Base learning rate at the start of the training.")
parser.add_argument("--momentum", default=0.9, type=float, help="SGD Momentum.")
parser.add_argument("--threads", default=2, type=int, help="Number of CPU threads for dataloaders.")
parser.add_argument("--rho", default=2.0, type=int, help="Rho parameter for SAM.")
parser.add_argument("--weight_decay", default=0.0005, type=float, help="L2 weight decay.")
parser.add_argument("--width_factor", default=8, type=int, help="How many times wider compared to normal ResNet.")
args = parser.parse_args()
initialize(args, seed=42)
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")
train_loader, test_loader = load_data(args.batch_size, args.threads)
log = Log(log_each=10)
model = WideResNet(args.depth, args.width_factor, args.dropout, in_channels=3, labels=10).to(device)
base_optimizer = torch.optim.SGD
optimizer = SAM(model.parameters(), base_optimizer, rho=args.rho, adaptive=args.adaptive, lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)
scheduler = StepLR(optimizer, args.learning_rate, args.epochs)
best_accuracy = 0.0
save_path = 'best_model.pth'
for epoch in range(args.epochs):
model.train()
log.train(len_dataset=len(train_loader))
for batch in train_loader:
inputs, targets = (b.to(device) for b in batch)
# first forward-backward step
enable_running_stats(model)
predictions = model(inputs)
loss = smooth_crossentropy(predictions, targets, smoothing=args.label_smoothing)
loss.mean().backward()
optimizer.first_step(zero_grad=True)
# second forward-backward step
disable_running_stats(model)
smooth_crossentropy(model(inputs), targets, smoothing=args.label_smoothing).mean().backward()
optimizer.second_step(zero_grad=True)
with torch.no_grad():
correct = torch.argmax(predictions.data, 1) == targets
log(model, loss.cpu(), correct.cpu(), scheduler.lr())
scheduler(epoch)
model.eval()
log.eval(len_dataset=len(test_loader))
correct_sum = 0
total = 0
with torch.no_grad():
for batch in test_loader:
inputs, targets = (b.to(device) for b in batch)
predictions = model(inputs)
loss = smooth_crossentropy(predictions, targets)
correct = torch.argmax(predictions, 1) == targets
correct_sum += correct.sum().item()
total += targets.size(0)
log(model, loss.cpu(), correct.cpu())
accuracy = correct_sum / total
# Save the model if the accuracy is the best
if accuracy > best_accuracy:
best_accuracy = accuracy
torch.save(model.state_dict(), save_path)
print(f"Saved best model with accuracy: {accuracy:.4f}")
log.flush()
代码中我们添加了保存模型的权重代码。
if accuracy > best_accuracy:
best_accuracy = accuracy
torch.save(model.state_dict(), save_path)
print(f"Saved best model with accuracy: {accuracy:.4f}")
3.3 创建推理验证的test.py文件
import torch
from model.wide_res_net import WideResNet
from data.my_dataset import MyDataset
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = WideResNet(args.depth, args.width_factor, args.dropout, in_channels=3, labels=10).to(device)
model.load_state_dict(torch.load('best_model.pth'))
model.eval()
dataset = MyDataset(args.batch_size, args.threads)
for batch in dataset.test:
inputs, targets = (b.to(device) for b in batch)
predictions = model(inputs)
# Add code to evaluate the predictions
开始训练了。
总结
后期我会发布完整的源代码,并给出地址。
谢谢,如果对你有帮助,给个赞吧!!!