对于deep learning初学者来说, Kaggle上的手写数字比赛是一个不错的练手任务。
caffe上有一个example就是解决基于Mnist库的手写数字识别问题的,但是example中的输入输出Kaggle的格式不一样,因此首先其实这个练手任务最重要的就是数据预处理、后处理。很多博客已经给出了具体的流程和代码,可以参考@小村长的博文。
基本思路:
1) 用matlab将Kaggle的train.csv分为训练集和验证集(取前1000个样本作验证集)。
2) 将给定csv数据转变成lmdb格式。
3) Training(用Caffe的example中的LeNet进行训练)。
4) Testing并提取prob层的特征。
5) 对结果进行后处理(转化为Kaggle要求的csv格式)。
后面我将所有matlab代码部分一并改写成c++,代码分预处理-训练-测试,下面给出一些思路分析,代码见github。
1、预处理
将train.csv前1000个样本写入验证集lmdb,剩余样本写入训练集lmdb,无需事先用matlab分割。参照caffe中数据转化成lmdb的原理(图侵删):
//convert_digitRecog_train_lmdb.cpp
创建db::DB类型的对象db(scoped_ptr是一种智能指针)
scoped_ptr<db::DB> db(db::GetDB(FLAGS_backend)); // new DB object
创建环境,打开环境
db->Open(argv[2], db::NEW); // open the lmdb file tostore the data
创建并打开transaction操作句柄
<span style="font-size:14px;">scoped_ptr<db::Transaction>txn(db->NewTransaction()); // newTransaction object to put and commit the data</span>
数据转换成Datum对象格式,并序列化
<span style="font-size:14px;">Datum datum; // this data structure store thedata and label
datum.set_channels(1); // the channels
datum.set_height(rows); // rows
datum.set_width(cols); // cols
datum.set_data(buffer,rows*cols);
datum.set_label(pixes);
CHECK(datum.SerializeToString(&out)); // serialize to string</span>
通过句柄写入数据库
<span style="font-size:14px;">txn->Put(string(key_cstr,length), out); // put it, both the key and value</span>
数据库写入lmdb文件中
<span style="font-size:14px;">txn->Commit();</span>
重置句柄,将训练集后面41000个样本写入训练集lmdb中
<span style="font-size:14px;">txn.reset(db->NewTransaction());</span>
训练集和验证集的比例可以调整,这里选择1000个可以得到比较好的结果。
测试集的数据转换基本类似,需要注意测试集没有label,转成datum时需要注意相应问题。详细见代码。
2.训练网络参数
Caffe的mnist例子就是手写数字识别的示例,直接用这个示例的LeNet网络结构就可以得到很好的效果,具体网络结构见lenet_train_test.prototxt,需要注意修改网络输入数据的路径等。
lenet_solver.prototxt里面是关于迭代次数、学习率等的设置,可以根据实际情况调整。
训练时可以将日志保存起来,linux下命令为:
./examples/digitRecog/digitRecog_train.sh2>&1 |tee digitRecog_train.log
Caffe自带日志解析器,可以解析日志中的迭代次数和准确率等。
./tools/extra/parse_log.shdigitRecog_train.log
3、测试&后处理
在训练得到的model上进行测试,可以得到测试结果。Caffe自带的feature_extract的代码的输出为lmdb格式,将其最后输出按Kaggle提交数据格式要求另写入csv文件提交即可。
Caffe的数据是以blob格式存储的,以起始地址+位移的方式获取某个数据。
feature_blob_data =feature_blob->cpu_data() +feature_blob->offset(n); //the features of which imag
其他详细见代码