训练目标: 识别向日葵和玫瑰花两种类别。
数据集: 花卉数据集
原始数据集中的图片分辨率大小不一,这里在生成lmdb的时候,统一裁剪为160x160的分辨率大小。
create_lmdb.sh脚本中使用参数 --resize_height=160 --resize_width=160
网络模型: alex net
制作训练图片文件
在caffe-1.0\data下新建flowers目录,作为数据的根目录。
在flowers文件夹下分别放置train和val。
在train和val中放置不同类别的图rose和sunflower。
每一个类别的图中,训练图片是测试图片的5倍以上,测试图片100张,那么训练图片应该大于500张。
以下是rose训练图片集
生成数据脚本
在caffe-1.0\examples中新建 flower_class文件夹。作为此分类项目的根目录,然后创建以下3个脚本
create_filelist.sh:
DATA=data/flowers
MY=examples/flower_class
echo "Create train.txt..."
rm -rf $DATA/train.txt
find $DATA/train/rose -name *.jpg | cut -d '/' -f5 | sed "s:^:rose/:" | sed "s/$/ 0/">>$MY/train.txt
find $DATA/train/sunflower -name *.jpg | cut -d '/' -f5 | sed "s:^:sunflower/:" | sed "s/$/ 1/">>$MY/train.txt
echo "Create val.txt..."
rm -rf $DATA/val.txt
find $DATA/val/rose -name *.jpg | cut -d '/' -f5 | sed "s:^:rose/:" | sed "s/$/ 0/">>$MY/val.txt
find $DATA/val/sunflower -name *.jpg | cut -d '/' -f5 | sed "s:^:sunflower/:" | sed "s/$/ 1/">>$MY/val.txt
echo "All done"
create_lmdb.sh
MY=examples/flower_class
echo "Create train lmdb.."
rm -rf $MY/train_lmdb
convert_imageset --shuffle --resize_height=160 --resize_width=160 /home/lsc/work/tools/caffe-1.0/data/flowers/train/ $MY/train.txt $MY/train_lmdb
echo "Create val lmdb.."
rm -rf $MY/val_lmdb
convert_imageset --shuffle --resize_width=160 --resize_height=160 /home/lsc/work/tools/caffe-1.0/data/flowers/val/ $MY/val.txt $MY/val_lmdb
echo "All Done.."
create_meanfile.sh:
EXAMPLE=examples/flower_class
DATA=examples/flower_class/
TOOLS=build/tools
compute_image_mean $EXAMPLE/train_lmdb $EXAMPLE/mean.binaryproto
echo "Done."
依次执行以上三个脚本文件。
生成以下文件或文件夹:
开启训练
我们使用的是alex网络,所以我们先从models\bvlc_alexnet文件夹下拷贝配置文件到flower_class目录下。然后做出修改。
solver.prototxt
net: "./examples/flower_class/train_val.prototxt"
test_iter: 50
test_interval: 50
base_lr: 0.01
lr_policy: "step"
gamma: 0.1
stepsize: 100000
display: 20
max_iter: 1500
momentum: 0.9
weight_decay: 0.0005
snapshot: 500
snapshot_prefix: "flower_alexnet_train"
solver_mode: GPU
train_val.prototxt:
layer {
name: "data"
type: "Data"
top: "data"
top: "label"
include {
phase: TRAIN
}
transform_param {
mirror: false
crop_size: 160
mean_file: "examples/flower_class/mean.binaryproto"
}
data_param {
source: "examples/flower_class/train_lmdb"
batch_size: 64
backend: LMDB
}
}
layer {
name: "data"
type: "Data"
top: "data"
top: "label"
include {
phase: TEST
}
transform_param {
mirror: false
crop_size: 160
mean_file: "examples/flower_class/mean.binaryproto"
}
data_param {
source: "examples/flower_class/val_lmdb"
batch_size: 50
backend: LMDB
}
}
最后的输出层,更改类别为2,因为我们只有2个分类。
inner_product_param {
num_output: 2
weight_filler {
type: "gaussian"
std: 0.01
}
bias_filler {
type: "constant"
value: 0
}
}
训练
ubuntu虚拟机上训练太慢,所以我把flower_class拷贝到windows系统上,进行训练。
建立脚本 train_flowers.bat:
caffe train --solver=examples/flower_class/solver.prototxt --gpu 0
最后的输出结果:
可以看到训练的精度挺高。
测试
test_flowers.bat:
.\build\examples\cpp_classification\Release\classification.exe .\examples\flower_class\deploy.prototxt .\examples\flower_class\flower_alexnet_train_iter_1500.caffemodel .\examples\flower_class\mean.binaryproto .\examples\flower_class\words.txt .\examples\images\sunflowers001.jpg
执行test_flowers.bat输出结果:
可以看出已经识别到正确的结果。
正确识别到玫瑰花。
到现在为止,我们已经训练好了一个模型。
把模型部署到海思3559A
新建一个目录alexnet_flower,这里先不做仿真,直接部署在板卡。所以从alexnet文件夹中拷贝相关的文件到alexnet_flower。
cp ../alexnet/alexnet_inst.cfg .
拷贝训练中用到的deploy.prototxt caffemodel,mean 均值文件到alexnet_flower下面,拷贝20个参考图片到images/ref中。此图片作为量化的参考。
修改后的alexnet_flowers_inst.cfg
[prototxt_file] ./../data/classification/alexnet_flower/model/deploy.prototxt
[caffemodel_file] ./../data/classification/alexnet_flower/model/flower_alexnet_train_iter_1500.caffemodel
[batch_num] 128
[net_type] 0
[sparse_rate] 0
[compile_mode] 0
[is_simulation] 0
[log_level] 2
[instruction_name] ./../data/classification/alexnet_flower/inst/alexnet_flowers_inst
[RGB_order] BGR
[data_scale] 0.0039062
[internal_stride] 16
[image_list] ./../data/classification/alexnet_flower/image_ref_list.txt
[image_type] 1
[mean_file] ./../data/classification/alexnet_flower/model/mean.binaryproto
[norm_type] 1
修改对应的路径,
image_type 1:表示网络数据输入为SVP_BLOB_TYPE_U8(普通的灰度图和RGB图)类型; 此时要求image_list配置是RGB图或者灰度图片的list文件。
注意:mage_ref_list.txt 这个文件中不能有空格,空行之类的字符,否则会出错,真实坑。
因为我们训练的是彩色图片,所以在代码中给定输入的时候,也要给BGR文件,
HI_CHAR *pcSrcFile = "./data/nnie_image/rgb_planar/rose_597_413.bgr";
HI_CHAR *pcModelName = "./data/nnie_model/classification/alexnet_flowers_inst.wk";
实测部署到板卡端,检测结果不对。
那么可以生成一个仿真的WK文件,在PC上试一下。
经过试验,发现是输入的数据的分辨率不对。
jpg转bgr文件:/home/lsc/work/tools/python_img/jpg_2_bgr/cvt
参考: https://blog.csdn.net/u011622434/article/details/97247736
因为我们给定的训练图片文件是160x160,所以给定的测试文件也必须是160x160的分辨率。如果给定错误的分辨率图像,那么很可能无法正确识别。