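Training an image classifier with MXNet
1. Prepare the data:
(1) Create a root directory, then create one subfolder per image class and put each class's images into its subfolder.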
--root:
----class1
----class2
......
----classn
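First, generate the list files for the training and validation sets with the following command:
python ~/mxnet/tools/im2rec.py --list True --recursive True --train-ratio 0.9 myData /home/xxx/root/
--list: must be set to True when generating list files; it tells the tool to produce .lst files. By default the tool generates rec files.
--recursive: walks the dataset directory recursively; set it to True.
--train-ratio: splits the whole dataset into a training set (train) and a validation set (val). The value is the fraction of images assigned to the training set; with 0.9, the remaining 10% go to the validation set.
--test-ratio: works the same way, but sets the fraction of images held out as a test set.
--exts: the file extensions of your data (here, image files). At the time of writing, MXNet supports only two image formats: jpg and jpeg.
After running this command, two files are generated: myData_train.lst and myData_val.lst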
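To make the list step concrete, here is a minimal sketch of roughly what it produces: one line per image containing an index, a label, and a relative path, separated by tabs, split into a train file and a val file. It is not the im2rec.py implementation. The function names `build_image_list` and `split_and_write`, the sorted-folder label order, and the fixed shuffle seed are assumptions made for illustration; the real tool may order and split images differently.

```python
import os
import random

IMAGE_EXTS = ('.jpg', '.jpeg')  # the extensions mentioned in the text above


def build_image_list(root, exts=IMAGE_EXTS):
    """Return (index, label, relative_path) tuples, one per image under root.

    Each immediate subfolder of root is treated as one class. Labels follow
    sorted folder order, which is an assumption made for this sketch.
    """
    entries = []
    class_dirs = sorted(
        d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d))
    )
    for label, class_dir in enumerate(class_dirs):
        for fname in sorted(os.listdir(os.path.join(root, class_dir))):
            if fname.lower().endswith(exts):
                rel_path = os.path.join(class_dir, fname)
                entries.append((len(entries), label, rel_path))
    return entries


def split_and_write(entries, prefix, train_ratio=0.9, seed=0):
    """Shuffle entries, split them by train_ratio, and write two .lst files."""
    shuffled = list(entries)
    random.Random(seed).shuffle(shuffled)
    n_train = int(len(shuffled) * train_ratio)
    parts = {'train': shuffled[:n_train], 'val': shuffled[n_train:]}
    for name, part in parts.items():
        with open('%s_%s.lst' % (prefix, name), 'w') as f:
            for index, label, rel_path in part:
                # One image per line: index, label, path, separated by tabs.
                f.write('%d\t%f\t%s\n' % (index, label, rel_path))
    return {name: len(part) for name, part in parts.items()}
```

Calling `split_and_write(build_image_list('/home/xxx/root/'), 'myData')` would write `myData_train.lst` and `myData_val.lst` in the current directory, mirroring the file names above.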
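(2) Generate the rec files:
python ~/mxnet/tools/im2rec.py --num-thread 4 --pass-through 1 myData /home/xxx/root/
--pass-through: set to 1 to skip the image transformation; otherwise the tool fails with an "unknown array type" error.
myData is the prefix of the .lst files generated in the first step; it is used here to generate the rec files.
After running this command, you will see two files: myData_train.rec and myData_val.rec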
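Before moving on to training, it can help to sanity-check the list files that the rec files were built from, for example to confirm that every class appears in both splits. The helper below is a hypothetical sketch (the name `count_labels` is not part of MXNet) that assumes the tab-separated layout of index, label, and path.

```python
from collections import Counter


def count_labels(lst_path):
    """Count how many images each label has in a tab-separated .lst file."""
    counts = Counter()
    with open(lst_path) as f:
        for line in f:
            fields = line.rstrip('\n').split('\t')
            if len(fields) < 3:
                continue  # skip blank or malformed lines
            # Assumed layout: index, label, image path.
            label = int(float(fields[1]))
            counts[label] += 1
    return counts
```

For instance, comparing `count_labels('myData_train.lst')` with `count_labels('myData_val.lst')` shows whether a class is missing from the validation split.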
2. Load the training and validation datasets:
import mxnet as mx

def get_iterators(batch_size, data_shape=(3, 224, 224)):
    # Training iterator: shuffled, with random crop and mirror augmentation.
    train = mx.io.ImageRecordIter(
        path_imgrec = './ld_train/my_images_train.rec',
        data_name   = 'data',
        label_name  = 'softmax_label',
        batch_size  = batch_size,
        data_shape  = data_shape,
        shuffle     = True,
        rand_crop   = True,
        rand_mirror = True)
    # Validation iterator: no augmentation.
    val = mx.io.ImageRecordIter(
        path_imgrec = './ld_train/my_images_val.rec',
        data_name   = 'data',
        label_name  = 'softmax_label',
        batch_size  = batch_size,
        data_shape  = data_shape,
        rand_crop   = False,
        rand_mirror = False)
    return (train, val)

train, val = get_iterators(128, (3, 128, 128))  # set the batch size and the image size
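The training iterator enables rand_crop and rand_mirror, while the validation iterator turns both off. As a rough illustration of what these two augmentations mean, here is a pure-Python sketch on a single-channel image stored as a list of rows. It is not MXNet's implementation; the function name `random_crop_and_mirror` and the 50% flip probability are assumptions made for this example.

```python
import random


def random_crop_and_mirror(image, crop_h, crop_w, rng=random):
    """Cut a random crop_h x crop_w window from a 2-D image and maybe flip it.

    `image` is a list of rows (one channel). A real pipeline works on
    multi-channel arrays, but the indexing idea is the same per channel.
    """
    height, width = len(image), len(image[0])
    if crop_h > height or crop_w > width:
        raise ValueError('crop size must not exceed image size')
    top = rng.randint(0, height - crop_h)   # randint includes both endpoints
    left = rng.randint(0, width - crop_w)
    crop = [row[left:left + crop_w] for row in image[top:top + crop_h]]
    if rng.random() < 0.5:                  # assumed: mirror half of the time
        crop = [row[::-1] for row in crop]
    return crop
```

Because the crop position and the flip change on every call, the network sees slightly different versions of each training image across epochs; validation images are left unaugmented so that scores stay comparable between runs.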
3. Define the network structure:
ResNet is used as the example here:
'''
Adapted from https://github.com/tornadomeet/ResNet/blob/master/symbol_resnet.py
Original author Wei Wu
Implemented the following paper:
Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. "Identity Mappings in Deep Residual Networks"
'''
import mxnet as mx
import numpy as np
def residual_unit(data, num_filter, stride, dim_match, name, bottle_neck=True, bn_mom=0.9, workspace=256, memonger=False):
    """Return ResNet Unit symbol for building ResNet
    Parameters
    ----------
    data : str
        Input data
    num_filter : int
        Number of output channels
    bnf : int
        Bottle neck channels factor with regard to num_filter
    stride : tuple
        Stride used in convolution
    dim_match : Boolean
        True means the input and output have the same number of channels; False means they differ
    name : str
        Base name of the operators
    workspace : int
        Workspace used in convolution operator
    """
    if bottle_neck:
        # the same as https://github.com/facebook/fb.resnet.torch#notes, slightly different from the original paper
        bn1 = mx.sym.BatchNorm(data=data, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn1')
        act1 = mx.sym.Activation(data=bn1, act_type='relu', name=name + '_relu1')
        # declare the conv1 weight in float32, then cast it to float16
        weight = mx.symbol.Variable(name=name + '_conv1_weight', dtype=np.float32)
        weight = mx.symbol.Cast(data=weight, dtype=np.float16)
c