102类花卉分类数据集(已划分,有训练集、测试集、验证集标签)+完整运行代码
数据集已经经过处理划分好了,并且附带了训练集,测试集,验证集的txt文本标签。配合完整运行代码即可训练。
数据集链接在文章中间部分
代码如下:
VGG19
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as D
import torchvision
from torchvision import transforms
import time
import os
import matplotlib.pyplot as plt
# 读取文件
train_path = pd.read_csv('E:/花卉分类考核项目/train.txt', sep=' ', names=['name', 'classes'])
test_path = pd.read_csv('E:/花卉分类考核项目/test.txt', sep=' ', names=['name', 'classes'])
valid_path = pd.read_csv('E:/花卉分类考核项目/valid.txt', sep=' ', names=['name', 'classes'])
# 数据增强
data_transforms = {
'train': transforms.Compose([
transforms.RandomRotation(45), # 随机旋转,-45到45度之间随机
transforms.CenterCrop(224), # 从中心开始裁剪
transforms.RandomHorizontalFlip(p=0.5), # 随机水平翻转 选择一个概率
transforms.RandomVerticalFlip(p=0.5), # 随机垂直翻转
transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1), # 参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225]) # 均值,标准差
]),
'valid': transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])
]),
'test': transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])
])
}
# 获取三个数据集
data_dir = 'E:/花卉分类考核项目/data'
image_datasets = {
x: torchvision.datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in
['train', 'valid', 'test']}
traindataset = image_datasets['train']
validdataset = image_datasets['valid']
testdataset = image_datasets['test']
batch_size = 1
dataloaders = {
x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size,
shuffle=True) for x in ['train', 'valid', 'test']}
print(dataloaders)
traindataloader = dataloaders['train']
validdataloader = dataloaders['valid']
testdataloader = dataloaders['test']
dataset_sizes = {
x: len(image_datasets[x]) for x in ['train', 'valid', 'test']}
# 是否用GPU训练
train_on_gpu = torch.cuda.is_available()
if not train_on_gpu:
print('CUDA is not available. Training on CPU ...')
else:
print('CUDA is available! Training on GPU ...')
device = torch.device("cuda:0" if torch.cuda.is_available(