原代码及数据集下载:【度学习pytorch实战六:ResNet50网络图像分类篇自建花数据集图像分类(5类)超详细代码 - CSDN App】/windows度学习pytorch实战六:ResNet50网络图像分类篇自建花数据集图像分类(5类)超详细代码_resnet50图像分类-CSDN博客
1.数据集
下载数据集至本地(点击原文链接进行下载)
2.代码
2.1数据预处理
**windows:**
import os
from shutil import copy
import random
def mkfile(file):
if not os.path.exists(file):
os.makedirs(file)
# 获取 photos 文件夹下除 .txt 文件以外所有文件夹名(即3种分类的类名)
file_path = 'yourdataroad/flower_photos'
flower_class = [cla for cla in os.listdir(file_path) if ".txt" not in cla]
# 创建 训练集train 文件夹,并由3种类名在其目录下创建3个子目录
mkfile('flower_data/train')
for cla in flower_class:
mkfile('flower_data/train/' + cla)
# 创建 验证集val 文件夹,并由3种类名在其目录下创建3个子目录
mkfile('flower_data/val')
for cla in flower_class:
mkfile('flower_data/val/' + cla)
# 划分比例,训练集 : 验证集 = 9 : 1
split_rate = 0.1
# 遍历3种花的全部图像并按比例分成训练集和验证集
for cla in flower_class:
cla_path = file_path + '/' + cla + '/' # 某一类别动作的子目录
images = os.listdir(cla_path) # iamges 列表存储了该目录下所有图像的名称
num = len(images)
eval_index = random.sample(images, k=int(num * split_rate)) # 从images列表中随机抽取 k 个图像名称
for index, image in enumerate(images):
# eval_index 中保存验证集val的图像名称
if image in eval_index:
image_path = cla_path + image
new_path = 'flower_data/val/' + cla
copy(image_path, new_path) # 将选中的图像复制到新路径
# 其余的图像保存在训练集train中
else:
image_path = cla_path + image
new_path = 'flower_data/train/' + cla
copy(image_path, new_path)
print("\\r[{}] processing [{}/{}]".format(cla, index + 1, num), end="") # processing bar
print()
print("processing done!")
mac设备报错:
File "/*/pytorch/resnet/data预处理.py", line 32, in <module>
images = os.listdir(cla_path) # iamges 列表存储了该目录下所有图像的名称
NotADirectoryError: [Errno 20] Not a directory: '/yourdataroad/flower_data/.DS_Store/'
错误信息表明在尝试列出目录内容时,cla_path
指向了一个不是目录的文件,具体来说是 /yourdataroad/flower_data/.DS_Store/
。.DS_Store
文件是 macOS 系统用来存储文件夹的自定义属性的隐藏文件,例如文件的图标位置和背景色等,它不是一个目录。
要修正这个错误,你需要在获取目录列表时排除 .DS_Store
文件
**macos**
import os
from shutil import copy
import random
def mkfile(file):
if not os.path.exists(file):
os.makedirs(file)
# 获取 photos 文件夹下除 .txt 文件以外所有文件夹名(即3种分类的类名)
file_path = '/yourdataroad/flower_data'
flower_class = [cla for cla in os.listdir(file_path) if os.path.isdir(os.path.join(file_path, cla)) and ".txt" not in cla]
# 创建 训练集train 文件夹,并由3种类名在其目录下创建3个子目录
mkfile('flower_data/train')
for cla in flower_class:
mkfile('flower_data/train/' + cla)
# 创建 验证集val 文件夹,并由3种类名在其目录下创建3个子目录
mkfile('flower_data/val')
for cla in flower_class:
mkfile('flower_data/val/' + cla)
# 划分比例,训练集 : 验证集 = 9 : 1
split_rate = 0.1
# 遍历3种花的全部图像并按比例分成训练集和验证集
for cla in flower_class:
cla_path = os.path.join(file_path, cla) # 某一类别动作的子目录
images = [img for img in os.listdir(cla_path) if os.path.isfile(os.path.join(cla_path, img))] # 确保是文件
num = len(images)
eval_index = random.sample(images, k=int(num * split_rate)) # 从images列表中随机抽取 k 个图像名称
for index, image in enumerate(images):
# eval_index 中保存验证集val的图像名称
if image in eval_index:
image_path = os.path.join(cla_path, image)
new_path = os.path.join('flower_data/val', cla)
copy(image_path, new_path) # 将选中的图像复制到新路径
# 其余的图像保存在训练集train中
else:
image_path = os.path.join(cla_path, image)
new_path = os.path.join('flower_data/train', cla)
copy(image_path, new_path)
print("\\r[{}] processing [{}/{}]".format(cla, index + 1, num), end="") # processing bar
print()
print("processing done!")
代码运行结果:
2.2训练模型
假设你已知restnet网络原理,不知道也没关系。以下是在macos上运行的完整代码,windows或其他版本可以去看本文开头链接。
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
import numpy as np
import matplotlib.pyplot as plt
import os
# 定义数据转换
data_transforms = {
'train': transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
'val': transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
}
# 加载数据集
data_dir = '/Users/xiaqizai/Downloads/flower_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
data_transforms[x])
for x in ['train', 'val']}
dataloaders = {x: DataLoader(image_datasets[x], batch_size=4,
shuffle=True, num_workers=4)
for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes
# 加载预训练的ResNet-50模型
model = models.resnet50(weights=models.ResNe