MMLAB库学习

首先介绍一下MMLab,用于计算机视觉研究的基础Python库,支持OpenMMLab旗下其他开源库。

Github | https://github.com/open-mmlab/mmcv

主要功能是I/O、图像视频处理、标注可视化、各种CNN架构、各类CUDA操作算子

在这里插入图片描述

这里以mmcv,mmclassification为例子。

1,首先安装mmcv,这里可以参考mmlab的参考文档

依赖环境 — MMClassification 0.25.0 文档

2,安装完毕之后,我们去相关网站下载mmlabclassification-master 源码,然后可以看见里面有很多代码。

 上图是该文件的树结构,因为其本身是一个库,因此是具有setup安装文件的,我们这里将其视为一个项目来尝试。

3,MMLab代码的策略是将配置与源码分开,需要什么就直接调用,config就是各种配置,例如打开其中的查看一下:

4,里面有众多我们熟悉的网络配置,再展开其中一个,以resnet为例。

 renet下面有很多其网络架构,其命名规则为:网络名_层数_gpu数量_batch数量_数据集等

还有的会写一些策略,例如学习率,epoch数量等。

resnet18_18*b16_cifar10.py你能看出来吗?

5,打开文件看一看其结构:

_base_ = [
    '../_base_/models/resnet18_cifar.py', '../_base_/datasets/cifar10_bs16.py',
    '../_base_/schedules/cifar10_bs128.py', '../_base_/default_runtime.py'
]

其仅仅写出了一个框架,均为继承关系,因此不会有大片两的代码,共分为4个文件,分别是模型,数据,迭代策略,保存等打印配置。

那么真正代码结构在哪里呢? 在config/_base_下,让我们去看看吧,哈哈哈

6,结构

6.1model

# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='ResNet_CIFAR',
        depth=18,
        num_stages=4,
        out_indices=(3, ),
        style='pytorch'),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=10,
        in_channels=512,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
    ))

上面展现了模型的结构,主干backbone使用的是ResNet网络,18层,一共4块,输出的是第4个(0,1,2,3),3对应的是4.style对应的是,我也不太清楚。只知道有pytorch 3*3,和cafe1*1,stride均为2.下面是neck,一般来说,我们对模型进行改进就是对neck进行改进,这里采用的是全局平均池化,后面的是头,由于我们这里是分类故使用全连接做分类。一共10类,输入通道为512,分为10类。后面就是损失。

同理其他同样。

7那么就来跑一个属于自己的数据吧。

数据如下是102个花的分类,以及classification

链接: https://pan.baidu.com/s/1617As76-0Us-JfaVoHZRcA 提取码: affr 

这是数据:

 

将其解压后放进这个文件夹内:

。首先进行模型生成找到tools/train

我们\mmclassification-master\configs\resnet\resnet18_8xb32_in1k.py作为参数,输入进train文件的参数中,图如下:

 执行完毕会在

/mmclassification-master/tools/work_dirs下生成一些文件,我们选取和resnet18_8xb32_in1k.py一样文件名的文件将其改名字复制到/home/wxq/mmlab/mmclassification-master/configs/resnet下并打开,我将其改为tod_resnet18_8xb32_in1k.py

如下图最后一个文件:

 下面是内容修改:

将12行改为102

# 改前:
num_classes=1000,
# 改后:
num_classes=102,

第44行里面data加载的数据路径均改为flower的路径

其中包含第49行data_prefix='..\mmcls\data\flower_data\train'

65行改为:data_prefix=''..\mmcls\data\flower_data\valid'

第66行注释掉:这里是加载标签的路径,如果没有或者注释掉,那么就是将其文件所在文件夹作为一类,例如花朵的1,2,3,,,102就是102类:

 同理由于test也是这样改,由于没有测试集,用valid作为测试集:

81行改为:data_prefix=''..\mmcls\data\flower_data\valid'

82行注释掉。

最后由于我自己显存小,将100行检查点改为50个记录一次,其余不变。

下面是改其他内容:

还有就是因为他源码用的是imagenet的数据加载,原来有1000个种类:

 

需要将这部分改为102个花的种类,才可以。

 

接下来就是运行:

成功啦

  • 2
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值