今天进行了一项图片分类工作用到了resnet50来进行,现对代码进行记录。本文前半部分中快速搭建resnet网络进行训练这一部分代码主要参考博客使用 resnet50 网络训练多分类模型完整代码,我在这篇博客的基础上对代码进行了一定的修改和添加注释;此外,在本博客的后一部分内容主要是我自己写的一个对前面训练好的resnet50网络模型进行快速搭建并进行图片检测的一个脚本。
1.resnet50模型快速搭建并进行数据集训练
这里的快速部署主要是用到了pytorch中自带的resnet50预训练网络,省去了自己依次添加结构的过程。
关于在训练准备前的数据集的文件夹存放结构形式,如原博客中所述:
--train
|--dog
|--cat ...
--valid
|--dog
|--cat
在项目根目录下创建一个 data 文件夹(名字可以任意),用来存放数据集。
在 data 文件夹下依次创建 train、valid、test 文件夹(test 文件夹可以没有,依据自己需求确定)
在 train 文件夹下创建类别文件夹,如 cat、dog 等
在类别文件夹如 cat 下,存放 cat 类别的图片。
…
在 val 文件夹下创建类别文件夹,如 cat、dog 等
在类别文件夹如 cat 下,存放 cat 类别的图片。
…
…
————————————————
注:这一引用部分摘自CSDN博主「悄悄地努力」的原创文章
原文链接:https://blog.csdn.net/weixin_46034990/article/details/124859877
此外,还要注意的是还要在代码文件夹根目录下新建一个名为 models用于保存训练过程中的模型文件。同时要根据自己的训练集情况来自定义修改代码中的目标分类的种类数变量num_classes。
下面展示模型部署及训练代码:
import torch
from torchvision import datasets, models, transforms
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import time
import numpy as np
import matplotlib.pyplot as plt
import os
from tqdm import tqdm
# 一、建立数据集
# animals-6
# --train
# |--dog
# |--cat
# ...
# --valid
# |--dog
# |--cat
# ...
# --test
# |--dog
# |--cat
# ...
# 我的数据集中 train 中每个类别60张图片,valid 中每个类别 10 张图片,test 中每个类别几张到几十张不等,一共 6 个类别。
# 二、数据增强
# 建好的数据集在输入网络之前先进行数据增强,包括随机 resize 裁剪到 256 x 256,随机旋转,随机水平翻转,中心裁