掌握PyTorch图片分类的简明教程 | 附完整代码

本文提供了一个使用PyTorch进行交通标志图片分类的简明教程,包括数据介绍和网络构建。数据集包含62类交通标志,训练集4572张,测试集2520张。教程中使用ResNet18进行迁移学习,并分享了训练过程与结果,测试集准确率可达0.86。
摘要由CSDN通过智能技术生成

深度学习的比赛中,图片分类是很常见的比赛,同时也是很难取得特别高名次的比赛,因为图片分类已经被大家研究的很透彻,一些开源的网络很容易取得高分。如果大家还掌握不了使用开源的网络进行训练,再慢慢去模型调优,很难取得较好的成绩。

1.数据介绍

数据下载地址:
https://download.csdn.net/download/xiaosongshine/11128410
这次的实战使用的数据是交通标志数据集,共有62类交通标志。其中训练集数据有4572张照片(每个类别大概七十个),测试数据集有2520张照片(每个类别大概40个)。数据包含两个子目录分别train与test:
为什么还需要测试数据集呢?这个测试数据集不会拿来训练,是用来进行模型的评估与调优。
在这里插入图片描述
train与test每个文件夹里又有62个子文件夹,每个类别在同一个文件夹内:
在这里插入图片描述
我从中打开一个文件间,把里面图片展示出来:
在这里插入图片描述

其中每张照片都类似下面的例子,1001003的大小。100是照片的照片的长和宽,3是什么呢?这其实是照片的色彩通道数目,RGB。彩色照片存储在计算机里就是以三维数组的形式。我们送入网络的也是这些数组。

2.网络构建

1.导入Python包,定义一些参数

 1import torch as t
 2import torchvision as tv
 3import os
 4import time
 5import numpy as np
 6from tqdm import tqdm
 7
 8
 9class DefaultConfigs(object):
10
11 data_dir = "./traffic-sign/"
12 data_list = ["train","test"]
13
14 lr = 0.001
15 epochs = 10
16 num_classes = 62
17 image_size = 224
18 batch_size = 40
19 channels = 3
20 gpu = "0"
21 train_len = 4572
22 test_len = 2520
23 use_gpu = t.cuda.is_available()
24
25config = DefaultConfigs()

2.数据准备,采用PyTorch提供的读取方式
注意一点Train数据需要进行随机裁剪,Test数据不要进行裁剪了

 1normalize = tv.transforms.Normalize(mean = [0.485, 0.456, 0.406],
 2 std = [0.229, 0.224, 0.225]
 3 )
 4
 5transform = {
 6 config.data_list[0]:tv.transforms.Compose(
 7 [tv.transforms.Resize([224,224]),tv.transforms.CenterCrop([224,224]),
 8 tv.transforms.ToTensor(),normalize]#tv.transforms.Resize 用于重设图片大小
 9 ) ,
10 config.data_list[1]:tv.transforms.Compose(
11 [tv.transforms.Resize([224,224]),tv.transforms.ToTensor(),normalize]
12 ) 
13}
14
15datasets = {
16 x:tv.datasets.ImageFolder(root = os.path.join(config.data_dir,x),transform=transform[x])
17 for x in config.data_list
18}
19
20dataloader = {
21 x:t.utils.data.DataLoader(dataset= datasets[x],
22 batch_size=config.batch_size,
23
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值