cifar-10数据集
cifar10数据集包含60000张32×32的彩色图像,分为10类,每一类6000张。
50000张训练集
10000张测试集
数据集准备
执行get_cifar10.sh下载数据
$ ./data/cifar10/get_cifar10.sh
执行create_cifar10.sh将数据转换为lmdb格式,并计算数据集均值
$ ./examples/cifar10/create_cifar10.sh
运行之后,将会在examples/cifar10/中出现数据库文件cifar10-leveldb和数据库图像均值二进制文件mean.binaryproto
训练网络
通过执行train_quick.sh脚本来训练网络,没有使用GPU的需要先将cifar10_quick_solver.prototxt 和 cifar10_quick_solver_lr1.prototxt 中的 GPU 改成 CPU 。
$ ./examples/cifar10/train_quick.sh
打开train_quick脚本,内容如下:
#!/usr/bin/env sh
set -e
TOOLS=./build/tools
$TOOLS/caffe train \
--solver=examples/cifar10/cifar10_quick_solver.prototxt $@
# reduce learning rate by factor of 10 after 8 epochs
$TOOLS/caffe train \
--solver=examples/cifar10/cifar10_quick_solver_lr1.prototxt \
--snapshot=examples/cifar10/cifar10_quick_iter_4000.solverstate.h5 $@
可以看到该脚本先用cifar10_quick_solver.prototxt迭代4000次,再用cifar10_quick_solver_lr1.prototxt,在上一次训练的基础上再迭代1000次。其中cifar10_quick_solver.prototxt的学习率为0.001,cifar10_quick_solver_lr1.prototxt的学习率为0.0001。
训练完成后,准确度约为75%。
用自己的图片测试训练好的网络
使用classification.bin程序测试:
$ ./build/examples/cpp_classification/classification.bin \
examples/cifar10/cifar10_quick.prototxt \
examples/cifar10/cifar10_quick_iter_5000.caffemodel.h5 \
examples/cifar10/mean.binaryproto \
data/cifar10/batches.meta.txt \
examples/images/cat.jpg
---------- Prediction for examples/images/cat.jpg ----------
0.7405 - "deer"
0.1773 - "dog"
0.0566 - "cat"
0.0160 - "bird"
0.0079 - "horse"
这个训练好的网络准确率并不高,要提高准确率,可以使用train_full.sh来训练。
train_full.sh使用的网络模型比train_quick.sh使用的网络模型少了一个全连接层。
train_full.sh使用0.001的学习率迭代60000次,再0.0001的学习率迭代65000次,再使用0.0000的学习率迭代70000次。