dnn分类鸢尾花 pytorch_从实例掌握 pytorch 进行图像分类

背景

从入门 Tensorflow 到沉迷 keras 再到跳出安逸选择pytorch,根本原因是在参加天池雪浪AI制造数据竞赛的时候,几乎同样的网络模型和参数,以及相似的数据预处理方式,结果得到的成绩差距之大让我无法接受,故转为 pytorch,keras 只用来做一些 NLP 的项目(毕竟积累了一些"祖传模型")~

更新 :2018年10月22日第二次更新,版本 0.1.1

更改:

数据增强方式由 pytorch 内置方式改为自定义,便于后期多 channels 模型更改,同时也可以借用 opencv 的强大库进行数据预处理(pytorch 的数据读取采用的是 PIL 库)。

输出打印方式采用 logger 的形式,动态更新。

保存最优模型的方式采用半个 epoch 计算一次

pytorch 0.4.0

0. 图像分类框架结构

在我们学习完机器学习、深度学习、卷积神经网络以及结构化机器学习项目等理论知识后,如何动手完成一个实际的项目往往是一个瓶颈期,只有将所学知识灵活运用,才敢说自己学了这些。前面的那些课程,我在研一上学期的时候都学习过,但直到研一下开始实习后,才逐渐能够独立完成项目,甚至参加一些数据竞赛。

在我使用 pytorch 的过程中,将其分为七大部分:数据加载,模型定义,评测标准定义,训练过程定义,验证过程定义,测试过程定义,参数定义。

文件组织如下:

==============================================================

checkpoints/

bestmodels/

dataset/

aug.py

dataloader.py

logs/

models/

pretrained_models/

model.py

submit/

config.py

main.py

utils.py

==============================================================

checkpoints/ : 存放训练保存的模型( bestmodels/ 保存在验证集上效果最好的模型);

models/ : 存放一些自定义的模型,如果不想使用 pytorch 自定义的网络模型,可以在这里添加(记得添加__init__.py文件);

submit/ : 输出的预测文件或者说比赛所需要你提交的结果文件,常见的是csv格式的;

logs/: 存放记录训练日志(.txt格式文件)

dataset/:包含 aug.py dataloader.py 两文件,主要实现数据增强和数据加载两个功能

config.py: 参数定义文件,以参数类的形式定义所需要提前设定或者修改的参数,例如:数据路径,学习率,训练 epoch 等;

model.py: 定义模型加载,可有可无,为了方便进行模型的 fine tune 我喜欢单独列出来;

utils.py: 定义了一些常用的评测标准,比如 mAP,Accuracy,loss 等。

main.py: 主文件,包含训练、测试、验证等过程;

1. 参数定义: config.py

参数定义的方式有很多种,有的人喜欢直接在主文件中进行设置;有的喜欢用 argparse 这个模块;也有人喜欢用 json 格式的文件,但是总的来说都不够简洁,我个人喜欢单独创建个 config.py 然后创建个 Python 类,以类属性的形式定义参数,详情见下:

class DefaultConfigs(object):

#1.string parameters

train_data = "../data/train/"

test_data = ""

val_data = "../data/val/"

model_name = "resnet50"

weights = "./checkpoints/"

best_models = weights + "best_model/"

submit = "./submit/"

logs = "./logs/"

gpus = "1"

#2.numeric parameters

epochs = 40

batch_size = 4

img_height = 224

img_weight = 224

num_classes = 62

seed = 888

lr = 1e-3

lr_decay = 1e-4

weight_decay = 1e-4

config = DefaultConfigs()

2. 数据加载: data_loader.py

pytorch 的数据读取方式有两种,一种是不同类别的图像按照文件夹进行划分,比如交通标志数据集:

train/

00000/

01153_00000.png

01153_00001.png

00001/

00025_00000.png

00025_00001.png

train_data = torchvision.datasets.ImageFolder(

"/data2/dockspace_zcj/traffic-sign/train/",#图片文件存放路径

transform = None #定义的数据增强方式

)

data_loader = torch.utils.data.DataLoader(train_data,

batch_size=20,

shuffle=True

)

"""

在模型训练过程时只需要加载data_loader就可以了,

具体方式在main文件中可见

"""

aug.py 由于代码较多,不在此展示,详情请移步 github 。常用的增强方式均在此文件中列举出来,如果需要添加,可根据样例,自行添加。

因此采用继承 torch.utils.data.Dataset 类,新建一个数据加载的 python 类,在__get_item__(self,index)函数中添加数据增强,代码如下:

from torch.utils.data import Dataset

from torchvision import transforms as T

from config import config

from PIL import Image

from dataset.aug import *

from itertools import chain

from glob import glob

from tqdm import tqdm

import random

import numpy as np

import pandas as pd

import os

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值