Pytorch实现102类鲜花分类(102 Category Flower Dataset)

Pytorch实现102类鲜花分类(VGG19和ResNet152模型)

本文主要讲解该算法的实现过程,原理部分需读者自行研究,可以找一些论文之类的。


实验环境

python3.6+pytorch1.2+cuda10.1


数据集

102 Category Flower Dataset数据集由102类产自英国的花卉组成,每类由40-258张图片组成

至于数据集我放个链接给大家,并且是划分好的数据集https://download.csdn.net/download/ntntg/15535184?spm=1001.2014.3001.5503


接下来是代码的实现过程

导入需要的库

import pandas as pd
import torch
import torch.nn as nn
import torch.utils.data as D
import torchvision
from torchvision import transforms
import time
import os
import matplotlib.pyplot as plt

数据加载

# 读取文件
train_path = pd.read_csv('E:/花卉分类考核项目/train.txt', sep=' ', names=['name', 'classes'])
test_path = pd.read_csv('E:/花卉分类考核项目/test.txt', sep=' ', names=['name', 'classes'])
valid_path = pd.read_csv('E:/花卉分类考核项目/valid.txt', sep=' ', names=['name', 'classes'])
# 数据增强
data_transforms = {
   
    'train': transforms.Compose([
        transforms.RandomRotation(45),  # 随机旋转,-45到45度之间随机
        transforms.CenterCrop(224),  # 从中心开始裁剪
        transforms.RandomHorizontalFlip(p=0.5),  # 随机水平翻转 选择一个概率
        transforms.RandomVerticalFlip(p=0.5),  # 随机垂直翻转
        transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),  # 参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])  # 均值,标准差
    ]),
    'valid': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
    ]),
    'test': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize
  • 6
    点赞
  • 41
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
PyTorch 102花的分类数据集是一个经典的机器学习数据集,用于花朵分类任务。该数据集由102个不同种的花朵组成,每个别包含一些样本。每个样本都包含花朵的图像和相应的标签。在此数据集中,每个花朵别都具有不同的形状、颜色和纹理。 使用PyTorch进行花朵分类任务时,首先需要加载数据集。PyTorch提供了一个方便的数据加载器,可以方便地将数据集加载到模型中。加载数据集时,需要将每个图像进行预处理,以使其适合模型输入。预处理可能包括调整图像大小、标准化像素值和数据增强等。 加载完数据集后,可以构建一个深度学习模型来训练和测试数据集。可以使用PyTorch中的各种深度学习模型构建函数,如卷积神经网络(CNN)或预训练的模型(如ResNet、AlexNet等)。 在训练过程中,需要将数据集划分为训练集和验证集。训练集用于模型参数的更新,而验证集用于监控模型在未见过的数据上的性能。训练过程通常包括定义损失函数、选择优化器、迭代数据集、计算梯度和更新模型参数等步骤。 完成模型的训练后,可以使用测试集来评估模型的性能。可以计算准确率、精确率、召回率等指标来评估模型在不同别上的分类能力。 最后,根据模型的性能,可以使用训练好的模型对新的花朵图像进行分类预测。可以加载模型并将测试图像输入到模型中,然后将输出结果与数据集的标签进行比较,以获得花朵的分类预测。 总之,PyTorch 102花的分类数据集是一个广泛使用的用于花朵分类任务的数据集。通过使用PyTorch提供的强大功能,可以构建和训练深度学习模型来实现准确的花朵分类预测。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值