pytorch实现Resnet系列的分类任务

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档


前言

最近需要一系列传统分类方法做对比,所以就顺便把自己复现resnet系列分类实验的过程记录一下,还是老传统:文末有源码。


一、数据集格式

在这里插入图片描述在这里插入图片描述
其中training文件夹下的basal、her2都是要分类的类别,basal里面就是一张张图片了。
在这里给大家一个公开的10种猴子的分类数据集,已经分好类了,大家可以直接下载使用。数据集地址:https://www.kaggle.com/slothkong/10-monkey-species

二、训练部分

train.py文件下可以自行修改权值文件保存的地址和batch size:

if __name__ == '__main__':

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    if not os.path.exists('./logs'):
        os.makedirs('./logs')

    BATCH_SIZE = 16

修改数据集的地址

    train_dataset = datasets.ImageFolder("./datasets/training", transform=data_transform["train"])  # 训练集数据
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                                               num_workers=2)  # 加载数据
    len_train = len(train_dataset)
    val_dataset = datasets.ImageFolder("./datasets/validation", transform=data_transform["val"])  # 测试集数据
    val_loader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                                             num_workers=2)  # 加载数据
    len_val = len(val_dataset)

可以自行选择网络类型,resnet50或者resnet34或者其他都可以,损失函数就只有一个CEloss,优化器使用的是adam,epoch根据自己的数据多少和bacth size自行修改。

    net = resnet50()
    loss_function = nn.CrossEntropyLoss()  # 设置损失函数
    optimizer = optim.Adam(net.parameters(), lr=0.0001)  # 设置优化器和学习率
    epoch = 100

三、评价

改完上面的参数后就可以训练了,训练结束之后可以对于分类结果进行评价,需要续改evaluate.py的相关内容,首先修改训练生产的权值文件的路径。

if __name__ == '__main__':
    model = torch.load("./logs/best.pth")

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    class_correct = [0.] * 10
    class_total = [0.] * 10
    y_test, y_pred = [], []
    X_test = []

下面修改验证集或者测试集的指向路径。

    data_transform = transforms.Compose([transforms.Resize(256),
                                       transforms.CenterCrop(224),
                                       transforms.ToTensor(),
                                       transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
    val_dataset = datasets.ImageFolder("./datasets/validation", transform=data_transform)  # 测试集数据
    val_loader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                                                 num_workers=2)  # 加载数据

    classes = val_dataset.classes

最后运行evaluate.py就可以。

结果

在这里插入图片描述

总结

以上就是今天要讲的内容,本文仅仅简单介绍了resnet分类网络的使用。
源码分享在网盘里面:网盘
提取码:57hs

  • 2
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值