简介:
Caffe的tools目录下提供了一个caffe.cpp,经过编译工具的编译会生成一个命令行工具caffe。该工具提供的功能有: train, test, device_query, time。
训练的方法:
这里先分析train的功能,也就是训练。训练既支持从零开始训练,也支持在已有的模型基础上进行finetune。
finetune的命令行格式:caffe train --solver=solver.prototxt --model=model.ptototxt --weights=weights.caffemodel
其中通过--solver设置solver的prototxt文件
通过--model设置网络的prototxt文件
通过--weights设置一个训练好的网络,训练时会用该网络的参数对待训练的网络的参数进行初始化
训练相关的代码及简要分析:
在caffe.cpp中,所有训练相关的代码都在 int train() 这个函数中。代码及分析如下:
int train()
{
// 检查输入参数solver,snapshot和weight。其中solver为solver的ptototxt文件
// snapshot为训练时产生的快照,以便在训练中断后,不至于从头开始训练
// weights为一个已有的训练好的网络,如果指定了weights,则训练的时候会用指定
// 的weights初始化网络参数,然后再训练,主要用于对网络进行finetune
// 注意:snapshot和weights不能同时使用
CHECK_GT(FLAGS_solver.size(), 0) << "Need a solver definition to train.";
CHECK(!FLAGS_snapshot.size() || !FLAGS_weights.size())
<< "Give a snapshot to resume training or weights to finetune "
"but not both.";
// 从指定的solver的prototxt文件中读取SolverParameter
caffe::SolverParameter solver_param;
caffe::ReadSolverParamsFromTextFileOrDie(FLAGS_solver, &solver_param);
// If the gpus flag is not provided, allow the mode and device to be set
// in the solver prototxt.
if (FLAGS_gpu.size() == 0
&& solver_param.solver_mode() == caffe::SolverParameter_SolverMode_GPU) {
if (solver_param.has_device_id()) {
FLAGS_gpu = "" +
boost::lexical_cast<string>(solver_param.device_id());
} else { // Set default GPU if unspecified
FLAGS_gpu = "" + boost::lexical_cast<string>(0);
}
}
vector<int> gpus;
get_gpus(&gpus);
if (gpus.size() == 0) {
LOG(INFO) << "Use CPU.";
Caffe::set_mode(Caffe::CPU);
} else {
ostringstream s;
for (int i = 0; i < gpus.size(); ++i) {
s << (i ? ", " : "") << gpus[i];
}
LOG(INFO) << "Using GPUs " << s.str();
#ifndef CPU_ONLY
cudaDeviceProp device_prop;
for (int i = 0; i < gpus.size(); ++i) {
cudaGetDeviceProperties(&device_prop, gpus[i]);
LOG(INFO) << "GPU " << gpus[i] << ": " << device_prop.name;
}
#endif
solver_param.set_device_id(gpus[0]);
Caffe::SetDevice(gpus[0]);
Caffe::set_mode(Caffe::GPU);
Caffe::set_solver_count(gpus.size());
}
caffe::SignalHandler signal_handler(
GetRequestedAction(FLAGS_sigint_effect),
GetRequestedAction(FLAGS_sighup_effect));
// 用读取的SolverParameter创建Solver
shared_ptr<caffe::Solver<float> >
solver(caffe::SolverRegistry<float>::CreateSolver(solver_param));
solver->SetActionFunction(signal_handler.GetActionFunction());
// 利用snapshot restore网络或利用weights初始化网络的参数
if (FLAGS_snapshot.size()) {
LOG(INFO) << "Resuming from " << FLAGS_snapshot;
solver->Restore(FLAGS_snapshot.c_str());
} else if (FLAGS_weights.size()) {
CopyLayers(solver.get(), FLAGS_weights);
}
// 进行训练
if (gpus.size() > 1) {
caffe::P2PSync<float> sync(solver, NULL, solver->param());
sync.Run(gpus);
} else {
LOG(INFO) << "Starting Optimization";
solver->Solve();
}
LOG(INFO) << "Optimization Done.";
return 0;
}