1. Installing MXNet
Even though the network connection is unstable, it is still worth trying a pip install first:
pip install mxnet
I tried several domestic (Chinese) mirrors; unfortunately, every one of them failed to install, so the supporting packages went in via conda:
#opencv
conda install opencv
#matplotlib
conda install matplotlib
2. Downloading from GitHub
https://github.com/apache/incubator-mxnet
Clone the whole repository, then go into example, find the ssd directory, and copy just that folder into a project directory of your own.
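The clone-and-copy step can be run as below; ~/yyCode is the project directory used later in this post, so substitute your own path:

```shell
# Clone the full MXNet repository (the SSD example lives inside it)
git clone https://github.com/apache/incubator-mxnet.git
# Copy only the SSD example into your own project directory
cp -r incubator-mxnet/example/ssd ~/yyCode/ssd
```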
3. Preparing the dataset
The format-conversion commands shipped with ssd can be used to convert a VOC-format dataset.
1) The easiest route is to create a data folder inside the project and put the entire VOCdevkit directory into it; alternatively, change the root parameter in ssd/tools/prepare_dataset.py, or pass --root on the command line.
2) Try the conversion; it will almost certainly error out at first:
python tools/prepare_dataset.py --dataset pascal --year 2007,2012 --set trainval --target ./data/train.lst
python tools/prepare_dataset.py --dataset pascal --year 2007 --set test --target ./data/val.lst --no-shuffle
Solutions:
1) Change the classes in ssd/dataset/names/pascal_voc.names to your own classes.
Running python train.py then fails with: error: no image in imdb
Fix: the names in ssd/dataset/names/pascal_voc.names must match the label names used in the annotations exactly.
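The label mismatch above can be caught before conversion with a quick consistency check. A minimal sketch (the check_names helper is my own addition, not part of the ssd example):

```python
import os
import xml.etree.ElementTree as ET

def collect_annotation_labels(ann_dir):
    """Collect every <object><name> appearing in the VOC XML annotations."""
    labels = set()
    for fname in os.listdir(ann_dir):
        if not fname.endswith(".xml"):
            continue
        tree = ET.parse(os.path.join(ann_dir, fname))
        for obj in tree.findall("object"):
            labels.add(obj.find("name").text.strip())
    return labels

def check_names(names_file, ann_dir):
    """Return the annotation labels that are missing from pascal_voc.names."""
    with open(names_file) as f:
        known = {line.strip() for line in f if line.strip()}
    return collect_annotation_labels(ann_dir) - known
```

If check_names returns a non-empty set, those objects are silently dropped during conversion and train.py ends up with "no image in imdb".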
2) Watch the path to im2rec.py:
cmd_arguments = ["python",
os.path.join(curr_path, "im2rec.py"),
os.path.abspath(args.target), os.path.abspath(args.root_path),
"--pack-label", "--num-thread", str(args.num_thread)]
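If im2rec.py is not sitting next to prepare_dataset.py in your checkout, a small helper can search a few candidate directories before building cmd_arguments; this find_im2rec function is my own addition, not part of the ssd tools:

```python
import os

def find_im2rec(candidate_dirs):
    """Return the path of im2rec.py from the first candidate directory that contains it."""
    for d in candidate_dirs:
        path = os.path.join(d, "im2rec.py")
        if os.path.isfile(path):
            return path
    raise IOError("im2rec.py not found in: %s" % ", ".join(candidate_dirs))
```

Reasonable candidates are the tools directory of the cloned repository and the tools directory of the installed mxnet package.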
4. Trying to train from the model given on the official site; unsurprisingly, it still errors.
whut@whut-Z370-AORUS-Gaming-5:~/yyCode/ssd$ python train.py
[17:14:03] src/io/iter_image_det_recordio.cc:281: ImageDetRecordIOParser: /home/whut/yyCode/ssd/data/train.rec, use 11 threads for decoding..
[17:14:03] src/io/iter_image_det_recordio.cc:334: ImageDetRecordIOParser: /home/whut/yyCode/ssd/data/train.rec, label padding width: 350
[17:14:04] src/io/iter_image_det_recordio.cc:281: ImageDetRecordIOParser: /home/whut/yyCode/ssd/data/val.rec, use 11 threads for decoding..
[17:14:04] src/io/iter_image_det_recordio.cc:334: ImageDetRecordIOParser: /home/whut/yyCode/ssd/data/val.rec, label padding width: 350
INFO:root:Start finetuning with (gpu(0),gpu(1)) from epoch 1
[17:14:05] src/nnvm/legacy_json_util.cc:209: Loading symbol saved by previous version v0.9.5. Attempting to upgrade...
[17:14:05] src/nnvm/legacy_json_util.cc:217: Symbol successfully upgraded!
INFO:root:Freezed parameters: [conv1_1_weight,conv1_1_bias,conv1_2_weight,conv1_2_bias,conv2_1_weight,conv2_1_bias,conv2_2_weight,conv2_2_bias,conv3_1_weight,conv3_1_bias,conv3_2_weight,conv3_2_bias,conv3_3_weight,conv3_3_bias,conv4_1_weight,conv4_1_bias,conv4_2_weight,conv4_2_bias,conv4_3_weight,conv4_3_bias,conv5_1_weight,conv5_1_bias,conv5_2_weight,conv5_2_bias,conv5_3_weight,conv5_3_bias]
['conv2_1_bias', 'relu7_cls_pred_conv_bias', 'relu4_3_scale', 'conv10_1_weight', 'conv4_2_bias', 'conv5_2_bias', 'conv2_2_bias', 'relu10_2_cls_pred_conv_bias', 'relu8_2_cls_pred_conv_weight', 'conv4_3_bias', 'conv9_2_bias', 'conv5_2_weight', 'relu8_2_cls_pred_conv_bias', 'conv9_1_weight', 'relu7_loc_pred_conv_weight', 'conv5_3_bias', 'conv5_1_weight', 'conv11_1_bias', 'conv4_1_bias', 'relu11_2_loc_pred_conv_weight', 'relu10_2_loc_pred_conv_weight', 'relu4_3_loc_pred_conv_bias', 'conv6_weight', 'conv11_2_bias', 'relu8_2_loc_pred_conv_weight', 'conv4_3_weight', 'relu10_2_cls_pred_conv_weight', 'relu7_cls_pred_conv_weight', 'relu4_3_cls_pred_conv_bias', 'conv3_1_weight', 'conv3_2_weight', 'relu11_2_loc_pred_conv_bias', 'conv11_2_weight', 'conv5_1_bias', 'relu4_3_loc_pred_conv_weight', 'conv1_2_bias', 'relu8_2_loc_pred_conv_bias', 'conv8_1_bias', 'relu10_2_loc_pred_conv_bias', 'conv8_2_weight', 'conv6_bias', 'conv7_bias', 'conv3_2_bias', 'relu9_2_cls_pred_conv_bias', 'conv9_1_bias', 'conv3_1_bias', 'conv9_2_weight', 'conv10_2_weight', 'relu7_loc_pred_conv_bias', 'conv1_1_weight', 'conv8_1_weight', 'conv11_1_weight', 'relu4_3_cls_pred_conv_weight', 'conv3_3_bias', 'conv8_2_bias', 'conv5_3_weight', 'conv7_weight', 'conv2_2_weight', 'conv1_2_weight', 'conv10_2_bias', 'relu9_2_loc_pred_conv_bias', 'conv10_1_bias', 'conv2_1_weight', 'conv3_3_weight', 'relu9_2_cls_pred_conv_weight', 'relu11_2_cls_pred_conv_bias', 'relu9_2_loc_pred_conv_weight', 'conv1_1_bias', 'relu11_2_cls_pred_conv_weight', 'conv4_2_weight', 'conv4_1_weight']
Traceback (most recent call last):
File "train.py", line 151, in <module>
voc07_metric=args.use_voc07_metric)
File "/home/whut/yyCode/ssd/train/train_net.py", line 277, in train_net
monitor=monitor)
File "/home/whut/anaconda2/lib/python2.7/site-packages/mxnet/module/base_module.py", line 502, in fit
allow_missing=allow_missing, force_init=force_init)
File "/home/whut/anaconda2/lib/python2.7/site-packages/mxnet/module/module.py", line 309, in init_params
_impl(desc, arr, arg_params)
File "/home/whut/anaconda2/lib/python2.7/site-packages/mxnet/module/module.py", line 297, in _impl
cache_arr.copyto(arr)
File "/home/whut/anaconda2/lib/python2.7/site-packages/mxnet/ndarray/ndarray.py", line 2066, in copyto
return _internal._copyto(self, out=other)
File "<string>", line 25, in _copyto
File "/home/whut/anaconda2/lib/python2.7/site-packages/mxnet/_ctypes/ndarray.py", line 92, in _imperative_invoke
ctypes.byref(out_stypes)))
File "/home/whut/anaconda2/lib/python2.7/site-packages/mxnet/base.py", line 251, in check_call
raise MXNetError(py_str(_LIB.MXGetLastError()))
mxnet.base.MXNetError: [17:14:07] /home/travis/build/dmlc/mxnet-distro/mxnet-build/3rdparty/mshadow/../../src/operator/tensor/../elemwise_op_common.h:133: Check failed: assign(&dattr, (*vec)[i]) Incompatible attr in node at 0-th output: expected [84], got [76]
The check expected 84 but only got 76. Thinking about what differs in my setup: VOC has 21 classes (20 foreground plus background), while my dataset has only 19. So the model on the official site is not a pretrained backbone but a model already fully trained on VOC.
Solution:
Download MXNet's pretrained backbone model yourself.
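The 84-vs-76 mismatch is consistent with the class-prediction convolution, whose output channel count is (number of classes including background) × (anchors per position). Assuming 4 anchors at the layer in question (an assumption about this SSD configuration), the numbers line up:

```python
def cls_pred_channels(num_classes_incl_bg, anchors_per_position):
    """Output channels of an SSD class-prediction conv layer."""
    return num_classes_incl_bg * anchors_per_position

voc_channels = cls_pred_channels(21, 4)  # VOC: 21 classes -> 84
my_channels = cls_pred_channels(19, 4)   # my dataset: 19 classes -> 76
```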
5. Modifications before training
1) Modifying parameters. Personal habit: I prefer to edit the default values directly, e.g. in train.py:
# set these according to your machine and to how the values relate to each other
--gpus default='0,1'
--batch-size default=128  # a multiple of 16
--lr default=0.005  # should scale roughly with the square root of the batch size
--lr-steps  # the epochs after which the learning rate starts to decay
# set these according to your dataset
--num-class
--num-example
--class-names
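The square-root relationship between --lr and --batch-size mentioned above can be sketched as a helper. The rule is a heuristic, not something the ssd example enforces, and the baseline values here are simply the ones from my own runs:

```python
import math

def scale_lr(base_lr, base_batch_size, new_batch_size):
    """Scale the learning rate with the square root of the batch-size ratio."""
    return base_lr * math.sqrt(float(new_batch_size) / base_batch_size)

# Going from batch size 32 at lr 0.002 up to batch size 128:
new_lr = scale_lr(0.002, 32, 128)  # -> 0.004
```

This lands in the same ballpark as the lr of 0.005 eventually used for the batch-size-128 runs.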
2) Modify ssd/train/train_net.py:
mod.fit(train_iter,
val_iter,
eval_metric=MultiBoxMetric(),
validation_metric=valid_metric,
batch_end_callback=batch_end_callback,
epoch_end_callback=epoch_end_callback,
optimizer='sgd',
optimizer_params=optimizer_params,
begin_epoch=begin_epoch,
num_epoch=end_epoch,
initializer=mx.init.Xavier(),
            #arg_params=args,  # commented out so no model is loaded and parameters are randomly initialized
            #aux_params=auxs,
allow_missing=True,
monitor=monitor)
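With arg_params/aux_params commented out, every weight comes from mx.init.Xavier(). Under what I believe are its defaults (uniform distribution, 'avg' factor type, magnitude 3; worth double-checking against your MXNet version), the sampling half-width works out to:

```python
import math

def xavier_uniform_bound(fan_in, fan_out, magnitude=3.0):
    """Half-width of the uniform Xavier distribution:
    sqrt(magnitude / ((fan_in + fan_out) / 2))."""
    factor = (fan_in + fan_out) / 2.0
    return math.sqrt(magnitude / factor)

# e.g. a 3x3 conv with 512 input and 512 output channels: fan = 3*3*512
bound = xavier_uniform_bound(3 * 3 * 512, 3 * 3 * 512)
```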
6. Training from scratch
Attempt 1: batch-size=32, lr=0.002, --lr-steps='80', --end-epoch=115
Attempt 2: batch-size=128, lr=0.002, --lr-steps='80,160', --end-epoch=226
The learning rate was reduced too early, while mAP was still climbing, so one more try:
Attempt 3: batch-size=128, lr=0.005, --lr-steps='120,240', --end-epoch=360
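With --lr-steps, the ssd example builds a multi-step schedule; assuming a decay factor of 0.1 (check the --lr-factor default in your train.py), attempt three's learning rate over time can be sketched as:

```python
def lr_at_epoch(base_lr, lr_steps, epoch, factor=0.1):
    """Learning rate after multiplying by `factor` at every epoch listed in lr_steps."""
    lr = base_lr
    for step in lr_steps:
        if epoch >= step:
            lr *= factor
    return lr

# Attempt 3 (lr=0.005, --lr-steps='120,240'):
# epochs 0-119 -> 0.005, epochs 120-239 -> 0.0005, epochs 240+ -> 0.00005
```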