图像分类算法SAM,ICLR2021论文讲解,代码。并训练自己的数据集

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档


前言

提示:想了解图像分类最新的算法的可以学习一下:在这里插入图片描述

在当今严重过度参数化的模型中,训练损失的值几乎无法保证模型的泛化能力。事实上,像通常所做的那样,仅优化训练损失值很容易导致模型质量不理想。在先前将损失景观的几何形状与泛化联系起来的工作的推动下,我们引入了一种新颖、有效的程序,以同时最小化损失值和损失锐度。特别是,我们的程序“锐度感知最小化”(SAM)寻找位于具有一致低损耗的邻域中的参数;这个公式导致了一个最小-最大优化问题,在这个问题上可以有效地执行梯度下降。我们提出的实证结果表明,SAM 改进了各种基准数据集(例如 CIFAR-{10, 100}、ImageNet、微调任务)和模型的模型泛化能力,从而为一些。此外,我们发现 SAM 本身对标签噪声的鲁棒性与专门针对噪声标签学习的最先进程序所提供的鲁棒性相当。

文章链接:https://openreview.net/pdf?id=6Tm1mposlrM

官方代码链接:https://github.com/google-research/sam

一、作者做了哪些工作和贡献

文章提出了一种新的高效、可扩展且有效的方法来提高模型泛化能力,该方法直接利用损失景观的几何形状及其与泛化的联系,并且是对现有技术的有力补充。我们特别做出以下贡献:

  1. 我们引入了锐度感知最小化(SAM),这是一种通过同时最小化损失值和损失锐度来提高模型泛化能力的新颖程序。 SAM 通过寻找位于具有一致低损失值的邻域中的参数(而不是仅本身具有低损失值的参数,如图 1 的中图和右图所示)来发挥作用,并且可以高效且轻松地实现。
  2. 我们通过严格的实证研究表明,使用 SAM 可以提高一系列广泛研究的计算机视觉任务(例如 CIFAR-{10, 100}、ImageNet、微调任务)和模型的模型泛化能力,如左图所示图 1. 例如,应用 SAM 为许多已经深入研究的任务(例如 ImageNet、CIFAR-{10, 100}、SVHN、Fashion-MNIST 和标准)带来了新颖的最先进性能一组图像分类微调任务(例如,Flowers、Stanford Cars、Oxford Pets 等)。
  3. 我们表明,SAM 还提供了与专门针对噪声标签学习的最先进程序所提供的鲁棒性来处理标签噪声。 • 通过 SAM 提供的镜头,我们通过提出一个有前途的新锐度概念(我们称之为 m-锐度),进一步阐明了损失锐度和泛化之间的联系。
    在这里插入图片描述
    具体论文,大家自己下载学习,研究!!不过多赘述!

二、实验结果

在这里插入图片描述
在这里插入图片描述
!!!!重点来了

3.如何训练自己的数据集

提示:文章作者并未使用Pytorch框架,这里我已Pytorch框架为例:
我参考的代码仓库是:https://github.com/davda54/sam
由于他们都是针对cifar-100数据集进行训练的,如何训练自己的数据集?
根据上述的仓库代码改写,训练自己的数据集。

/dataset
  /train
	 /class1 
	    image1.jpg
	    image2.jpg
	    ......jpg
	 /class2
	    image1.jpg
	    image2.jpg
	   .....、
  /test	
     /class1 
	    image1.jpg
	    image2.jpg
	    .....
	 /class2
	    image1.jpg
	    image2.jpg
	   .....、
	   

3.1.读入数据

在下面图片文件下,创建一个读取数据的python文件,my_dataset.py

在这里插入图片描述
注意,这里,train_dataset 和 test_dataset,是你数据集的具体路径,我这里的演示是ubuntu环境

windows 下的路径:C: \\dataset\\path\\train 具体路径,参考网上其他教程

import os
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, datasets


def load_data(batch_size, num_workers):
    train_transform = transforms.Compose([
        transforms.Resize((32, 32)),  # 调整大小以适应网络输入
        transforms.ToTensor(),
    ])
    test_transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
    ])

    train_dataset = datasets.ImageFolder(root='/path/train/', transform=train_transform)
    test_dataset = datasets.ImageFolder(root='/path/test/', transform=test_transform)

    print(f"Number of training images: {len(train_dataset)}")
    print(f"Number of testing images: {len(test_dataset)}")

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return train_loader, test_loader

3.2 修改train.py训练脚本


import argparse
import torch
import os

from model.wide_res_net import WideResNet
from model.smooth_cross_entropy import smooth_crossentropy
from data.my_dataset import load_data
from utility.log import Log
from utility.initialize import initialize
from utility.step_lr import StepLR
from utility.bypass_bn import enable_running_stats, disable_running_stats

import sys; sys.path.append("..")
from sam import SAM


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--adaptive", default=True, type=bool, help="True if you want to use the Adaptive SAM.")
    parser.add_argument("--batch_size", default=8, type=int, help="Batch size used in the training and validation loop.")
    parser.add_argument("--depth", default=16, type=int, help="Number of layers.")
    parser.add_argument("--dropout", default=0.0, type=float, help="Dropout rate.")
    parser.add_argument("--epochs", default=100, type=int, help="Total number of epochs.")
    parser.add_argument("--label_smoothing", default=0.1, type=float, help="Use 0.0 for no label smoothing.")
    parser.add_argument("--learning_rate", default=0.1, type=float, help="Base learning rate at the start of the training.")
    parser.add_argument("--momentum", default=0.9, type=float, help="SGD Momentum.")
    parser.add_argument("--threads", default=2, type=int, help="Number of CPU threads for dataloaders.")
    parser.add_argument("--rho", default=2.0, type=int, help="Rho parameter for SAM.")
    parser.add_argument("--weight_decay", default=0.0005, type=float, help="L2 weight decay.")
    parser.add_argument("--width_factor", default=8, type=int, help="How many times wider compared to normal ResNet.")
    args = parser.parse_args()

    initialize(args, seed=42)
    # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    device = torch.device("cpu")

    train_loader, test_loader = load_data(args.batch_size, args.threads)
    log = Log(log_each=10)
    model = WideResNet(args.depth, args.width_factor, args.dropout, in_channels=3, labels=10).to(device)

    base_optimizer = torch.optim.SGD
    optimizer = SAM(model.parameters(), base_optimizer, rho=args.rho, adaptive=args.adaptive, lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)
    scheduler = StepLR(optimizer, args.learning_rate, args.epochs)

    best_accuracy = 0.0
    save_path = 'best_model.pth'

    for epoch in range(args.epochs):
        model.train()
        log.train(len_dataset=len(train_loader))

        for batch in train_loader:
            inputs, targets = (b.to(device) for b in batch)

            # first forward-backward step
            enable_running_stats(model)
            predictions = model(inputs)
            loss = smooth_crossentropy(predictions, targets, smoothing=args.label_smoothing)
            loss.mean().backward()
            optimizer.first_step(zero_grad=True)

            # second forward-backward step
            disable_running_stats(model)
            smooth_crossentropy(model(inputs), targets, smoothing=args.label_smoothing).mean().backward()
            optimizer.second_step(zero_grad=True)

            with torch.no_grad():
                correct = torch.argmax(predictions.data, 1) == targets
                log(model, loss.cpu(), correct.cpu(), scheduler.lr())
                scheduler(epoch)

        model.eval()
        log.eval(len_dataset=len(test_loader))

        correct_sum = 0
        total = 0

        with torch.no_grad():
            for batch in test_loader:
                inputs, targets = (b.to(device) for b in batch)

                predictions = model(inputs)
                loss = smooth_crossentropy(predictions, targets)
                correct = torch.argmax(predictions, 1) == targets
                correct_sum += correct.sum().item()
                total += targets.size(0)
                log(model, loss.cpu(), correct.cpu())

        accuracy = correct_sum / total

        # Save the model if the accuracy is the best
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            torch.save(model.state_dict(), save_path)
            print(f"Saved best model with accuracy: {accuracy:.4f}")

    log.flush()

代码中我们添加了保存模型的权重代码。

if accuracy > best_accuracy:
            best_accuracy = accuracy
            torch.save(model.state_dict(), save_path)
            print(f"Saved best model with accuracy: {accuracy:.4f}")

3.3 创建推理验证的test.py文件


import torch
from model.wide_res_net import WideResNet
from data.my_dataset import MyDataset

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = WideResNet(args.depth, args.width_factor, args.dropout, in_channels=3, labels=10).to(device)
model.load_state_dict(torch.load('best_model.pth'))
model.eval()

dataset = MyDataset(args.batch_size, args.threads)
for batch in dataset.test:
    inputs, targets = (b.to(device) for b in batch)
    predictions = model(inputs)
    # Add code to evaluate the predictions

开始训练了。
在这里插入图片描述

总结

后期我会发布完整的源代码,并给出地址。
谢谢,如果对你有帮助,给个赞吧!!!

  • 24
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
《自适应通用广义PageRank图神经网络》是在ICLR2021中发布的一篇论文。这篇论文提出了一种新的图神经网络模型,称为自适应通用广义PageRank图神经网络。 传统的图神经网络通常使用节点和边的特征来进行节点分类和链接预测等任务,但在处理大规模图时存在计算复杂度高和难以处理隐含图结构的问题。为了解决这些问题,这篇论文引入了PageRank算法和广义反向传播法,在保留图结构信息的同时有效地处理大规模图数据。 这个模型的核心思想是将PageRank算法和图神经网络相结合,通过模拟随机游走过程对节点和边进行随机采样,并利用广义反向传播法将PageRank值传播给相邻的节点。通过这种方式,网络可以在保留图结构信息的同时,有效地进行节点嵌入和预测任务。 另外,这篇论文还提出了自适应的机制,允许网络根据不同的任务和数据集调整PageRank算法的参数。通过自适应机制,网络可以更好地适应不同的图结构和特征分布,提高模型的泛化能力。 实验证明,这个自适应通用广义PageRank图神经网络在节点分类、链路预测和社区检测等任务上都取得了比较好的效果。与传统的模型相比,该模型在保留图结构信息的同时,具有更高的计算效率和更好的预测能力。 总的来说,这篇论文提出了一种新颖的图神经网络模型,通过将PageRank算法与图神经网络相结合,可以有效地处理大规模图数据,并通过自适应机制适应不同的任务和数据集。这个模型在图神经网络领域具有一定的研究和应用价值。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值