首先介绍一下MMLab,用于计算机视觉研究的基础Python库,支持OpenMMLab旗下其他开源库。
Github | https://github.com/open-mmlab/mmcv
主要功能是I/O、图像视频处理、标注可视化、各种CNN架构、各类CUDA操作算子
这里以mmcv,mmclassification为例子。
1,首先安装mmcv,这里可以参考mmlab的参考文档
依赖环境 — MMClassification 0.25.0 文档
2,安装完毕之后,我们去相关网站下载mmlabclassification-master 源码,然后可以看见里面有很多代码。
上图是该文件的树结构,因为其本身是一个库,因此是具有setup安装文件的,我们这里将其视为一个项目来尝试。
3,MMLab代码的策略是将配置与源码分开,需要什么就直接调用,config就是各种配置,例如打开其中的查看一下:
4,里面有众多我们熟悉的网络配置,再展开其中一个,以resnet为例。
renet下面有很多其网络架构,其命名规则为:网络名_层数_gpu数量_batch数量_数据集等
还有的会写一些策略,例如学习率,epoch数量等。
resnet18_18*b16_cifar10.py你能看出来吗?
5,打开文件看一看其结构:
_base_ = [
'../_base_/models/resnet18_cifar.py', '../_base_/datasets/cifar10_bs16.py',
'../_base_/schedules/cifar10_bs128.py', '../_base_/default_runtime.py'
]
其仅仅写出了一个框架,均为继承关系,因此不会有大片两的代码,共分为4个文件,分别是模型,数据,迭代策略,保存等打印配置。
那么真正代码结构在哪里呢? 在config/_base_下,让我们去看看吧,哈哈哈
6,结构
6.1model
# model settings
model = dict(
type='ImageClassifier',
backbone=dict(
type='ResNet_CIFAR',
depth=18,
num_stages=4,
out_indices=(3, ),
style='pytorch'),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
num_classes=10,
in_channels=512,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
))
上面展现了模型的结构,主干backbone使用的是ResNet网络,18层,一共4块,输出的是第4个(0,1,2,3),3对应的是4.style对应的是,我也不太清楚。只知道有pytorch 3*3,和cafe1*1,stride均为2.下面是neck,一般来说,我们对模型进行改进就是对neck进行改进,这里采用的是全局平均池化,后面的是头,由于我们这里是分类故使用全连接做分类。一共10类,输入通道为512,分为10类。后面就是损失。
同理其他同样。
7那么就来跑一个属于自己的数据吧。
数据如下是102个花的分类,以及classification
链接: https://pan.baidu.com/s/1617As76-0Us-JfaVoHZRcA 提取码: affr
这是数据:
将其解压后放进这个文件夹内:
。首先进行模型生成找到tools/train
我们\mmclassification-master\configs\resnet\resnet18_8xb32_in1k.py作为参数,输入进train文件的参数中,图如下:
执行完毕会在
/mmclassification-master/tools/work_dirs下生成一些文件,我们选取和resnet18_8xb32_in1k.py一样文件名的文件将其改名字复制到/home/wxq/mmlab/mmclassification-master/configs/resnet下并打开,我将其改为tod_resnet18_8xb32_in1k.py
如下图最后一个文件:
下面是内容修改:
将12行改为102
# 改前:
num_classes=1000,
# 改后:
num_classes=102,
第44行里面data加载的数据路径均改为flower的路径
其中包含第49行data_prefix='..\mmcls\data\flower_data\train'
65行改为:data_prefix=''..\mmcls\data\flower_data\valid'
第66行注释掉:这里是加载标签的路径,如果没有或者注释掉,那么就是将其文件所在文件夹作为一类,例如花朵的1,2,3,,,102就是102类:
同理由于test也是这样改,由于没有测试集,用valid作为测试集:
81行改为:data_prefix=''..\mmcls\data\flower_data\valid'
82行注释掉。
最后由于我自己显存小,将100行检查点改为50个记录一次,其余不变。
下面是改其他内容:
还有就是因为他源码用的是imagenet的数据加载,原来有1000个种类:
需要将这部分改为102个花的种类,才可以。
接下来就是运行:
成功啦