Pytorch实现102类鲜花分类(VGG19和ResNet152模型)
本文主要讲解该算法的实现过程,原理部分需读者自行研究,可以找一些论文之类的。
实验环境
python3.6+pytorch1.2+cuda10.1
数据集
102 Category Flower Dataset数据集由102类产自英国的花卉组成,每类由40-258张图片组成
至于数据集我放个链接给大家,并且是划分好的数据集:https://download.csdn.net/download/ntntg/15535184?spm=1001.2014.3001.5503
接下来是代码的实现过程
导入需要的库
import pandas as pd
import torch
import torch.nn as nn
import torch.utils.data as D
import torchvision
from torchvision import transforms
import time
import os
import matplotlib.pyplot as plt
数据加载
# 读取文件
train_path = pd.read_csv('E:/花卉分类考核项目/train.txt', sep=' ', names=['name', 'classes'])
test_path = pd.read_csv('E:/花卉分类考核项目/test.txt', sep=' ', names=['name', 'classes'])
valid_path = pd.read_csv('E:/花卉分类考核项目/valid.txt', sep=' ', names=['name', 'classes'])
# 数据增强
data_transforms = {
'train': transforms.Compose([
transforms.RandomRotation(45), # 随机旋转,-45到45度之间随机
transforms.CenterCrop(224), # 从中心开始裁剪
transforms.RandomHorizontalFlip(p=0.5), # 随机水平翻转 选择一个概率
transforms.RandomVerticalFlip(p=0.5), # 随机垂直翻转
transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1), # 参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225]) # 均值,标准差
]),
'valid': transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])
]),
'test': transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize