使用Dataset制作好数据集之后,可以用Dataloader进行读取,然后用resnet34进行训练。
关于data_read数据预处理代码编写,我放在了付费内容里!
具体代码及注释如下
1 模块导入
其中data_read是利用Dataset制作数据集时写的文件
# 从data_read文件中读取函数
# data_read是创建的数据集制作函数
from data_read import ImageFloder, train_transform, test_transform
import numpy as np
import torch
# torch.nn用于网络的自定义
import torch.nn as nn
# torch.optim用于训练过程中参数的更新
import torch.optim as optim
# DataLoader用于储存数据,方便使用
from torch.utils.data import DataLoader
import torchvision
# 如果不是自己建立网络,可以从torchvision.models读取到已有的网络
# 然后对网络做适当的修改
from torchvision.models import vgg16, resnet34
import os
from os.path import join
最后链接文章包含代码可以训练图像分类(基于tiny-imagenet200数据集,包含数据预处理和分类模型训练两部分代码)
亲测cpu环境下2天时间可以达到40%左右的图像分类精度
(我把作者的网络模型改为pytorch中的vgg16之后,作者的模型我没有尝试长时间训练,代码能跑我就改了,大家可以改成任意模型)
我训练这个网络是为了验证图像风格迁移效果与模型图像分类精度的关系:精度越高,图像风格迁移效果越好。这之后可能会出单独的风格迁移文章再归纳一下我的项目。
全文imagenet数据集训练链接如下: