template <typename Dtype>
void Solver<Dtype>::Init(const SolverParameter& param) {
  // Non-root solvers must already have root_solver_ set: InitTrainNet()
  // dereferences root_solver_ when Caffe::root_solver() is false.
  CHECK(Caffe::root_solver() || root_solver_)
      << "root_solver_ needs to be set for all non-root solvers";
  LOG_IF(INFO, Caffe::root_solver()) << "Initializing solver from parameters: "
      << std::endl << param.DebugString();
  param_ = param;
  // The check requires average_loss >= 1, so the message says "positive".
  // The previous text ("non-negative") wrongly implied that 0 was accepted.
  CHECK_GE(param_.average_loss(), 1) << "average_loss should be positive.";
  // Defined elsewhere; presumably fails early if snapshots cannot be
  // written (NOTE(review): confirm against its definition).
  CheckSnapshotWritePermissions();
  // Only the root solver seeds the RNG, and only for a non-negative
  // random_seed. A negative value leaves the seed untouched here.
  if (Caffe::root_solver() && param_.random_seed() >= 0) {
    Caffe::set_random_seed(param_.random_seed());
  }
  // Scaffolding: every solver builds the training net; only the root solver
  // also builds the test nets.
  InitTrainNet();
  if (Caffe::root_solver()) {
    InitTestNets();
    LOG(INFO) << "Solver scaffolding done.";
  }
  // Reset iteration bookkeeping.
  iter_ = 0;
  current_step_ = 0;
}
//这段代码主要是在构建网络之前做一些检查和设置（如设置随机种子），基本没有需要特别关注的地方（也可能是我功力不够），重点看它调用的函数 InitTrainNet();
template <typename Dtype>
void Solver<Dtype>::InitTrainNet() {
  // Count how many SolverParameter fields describe the training net.
  // Each has_*() contributes 0 or 1; exactly one source is required.
  const int source_count =
      param_.has_train_net_param() + param_.has_train_net() +
      param_.has_net_param() + param_.has_net();
  const string field_names = "net, net_param, train_net, train_net_param";
  CHECK_GE(source_count, 1) << "SolverParameter must specify a train net "
      << "using one of these fields: " << field_names;
  CHECK_LE(source_count, 1) << "SolverParameter must not contain more than "
      << "one of these fields specifying a train_net: " << field_names;

  // After the two checks above exactly one source is present, so a single
  // else-if chain selects it. Inline *_param messages are copied; file paths
  // (train_net / net) are parsed as text prototxt.
  NetParameter train_param;
  if (param_.has_train_net_param()) {
    LOG_IF(INFO, Caffe::root_solver())
        << "Creating training net specified in train_net_param.";
    train_param.CopyFrom(param_.train_net_param());
  } else if (param_.has_train_net()) {
    LOG_IF(INFO, Caffe::root_solver())
        << "Creating training net from train_net file: " << param_.train_net();
    ReadNetParamsFromTextFileOrDie(param_.train_net(), &train_param);
  } else if (param_.has_net_param()) {
    LOG_IF(INFO, Caffe::root_solver())
        << "Creating training net specified in net_param.";
    train_param.CopyFrom(param_.net_param());
  } else if (param_.has_net()) {
    LOG_IF(INFO, Caffe::root_solver())
        << "Creating training net from net file: " << param_.net();
    ReadNetParamsFromTextFileOrDie(param_.net(), &train_param);
  }

  // Assemble the NetState in increasing precedence:
  //   1. solver default (phase = TRAIN),
  //   2. whatever state the net definition itself carries,
  //   3. the solver's train_state (wins over both).
  NetState merged_state;
  merged_state.set_phase(TRAIN);
  merged_state.MergeFrom(train_param.state());
  merged_state.MergeFrom(param_.train_state());
  train_param.mutable_state()->CopyFrom(merged_state);

  // Non-root solvers additionally hand the root solver's train net to the
  // Net constructor (presumably so parameters are shared -- confirm in net.cpp).
  Net<Dtype>* const built_net = Caffe::root_solver()
      ? new Net<Dtype>(train_param)
      : new Net<Dtype>(train_param, root_solver_->net_.get());
  net_.reset(built_net);
}
以上差不多就是 solver.cpp 对网络进行初始化的主要内容，其它部分也大同小异。
总结一下：Solver 的 Init() 主要是读入训练网络的 prototxt 文件（本例中是 lenet_train_test.prototxt），
然后对参数做一些检查，最后调用 net.cpp 中 Net 的 Init() 完成网络的初始化。
附1：net_param 赋值前（$27）与赋值后（$29）的对比。可以看到赋值后 name_ 被设置，layer_ 的 current_size_ 从 0 变为 11，_has_bits_ 从 {0} 变为 {1}。
(gdb) p net_param
$27 = {<google::protobuf::Message> = {<No data fields>},
static kNameFieldNumber = 1, static kInputFieldNumber = 3,
static kInputShapeFieldNumber = 8, static kInputDimFieldNumber = 4,
static kForceBackwardFieldNumber = 5, static kStateFieldNumber = 6,
static kDebugInfoFieldNumber = 7, static kLayerFieldNumber = 100,
static kLayersFieldNumber = 2, _unknown_fields_ = {fields_ = 0x0},
name_ = 0x7ffff692e3a8 <google::protobuf::internal::kEmptyString>,
input_ = {<google::protobuf::internal::RepeatedPtrFieldBase> = {
static kInitialSize = 0, elements_ = 0x0, current_size_ = 0,
allocated_size_ = 0, total_size_ = 0}, <No data fields>},
input_shape_ = {<google::protobuf::internal::RepeatedPtrFieldBase> = {
static kInitialSize = 0, elements_ = 0x0, current_size_ = 0,
allocated_size_ = 0, total_size_ = 0}, <No data fields>}, input_dim_ = {
static kInitialSize = <optimized out>, elements_ = 0x0, current_size_ = 0,
total_size_ = 0}, state_ = 0x0,
layer_ = {<google::protobuf::internal::RepeatedPtrFieldBase> = {
static kInitialSize = 0, elements_ = 0x0, current_size_ = 0,
allocated_size_ = 0, total_size_ = 0}, <No data fields>},
layers_ = {<google::protobuf::internal::RepeatedPtrFieldBase> = {
static kInitialSize = 0, elements_ = 0x0, current_size_ = 0,
allocated_size_ = 0, total_size_ = 0}, <No data fields>},
force_backward_ = false, debug_info_ = false, _cached_size_ = 0,
_has_bits_ = {0}, static default_instance_ = 0x6a2d70}
(gdb) p net_param
$29 = {<google::protobuf::Message> = {<No data fields>},
static kNameFieldNumber = 1, static kInputFieldNumber = 3,
static kInputShapeFieldNumber = 8, static kInputDimFieldNumber = 4,
static kForceBackwardFieldNumber = 5, static kStateFieldNumber = 6,
static kDebugInfoFieldNumber = 7, static kLayerFieldNumber = 100,
static kLayersFieldNumber = 2, _unknown_fields_ = {fields_ = 0x0},
name_ = 0x69f260,
input_ = {<google::protobuf::internal::RepeatedPtrFieldBase> = {
static kInitialSize = 0, elements_ = 0x0, current_size_ = 0,
allocated_size_ = 0, total_size_ = 0}, <No data fields>},
input_shape_ = {<google::protobuf::internal::RepeatedPtrFieldBase> = {
static kInitialSize = 0, elements_ = 0x0, current_size_ = 0,
allocated_size_ = 0, total_size_ = 0}, <No data fields>}, input_dim_ = {
static kInitialSize = <optimized out>, elements_ = 0x0, current_size_ = 0,
total_size_ = 0}, state_ = 0x0,
layer_ = {<google::protobuf::internal::RepeatedPtrFieldBase> = {
static kInitialSize = 0, elements_ = 0x6be760, current_size_ = 11,
allocated_size_ = 11, total_size_ = 16}, <No data fields>},
layers_ = {<google::protobuf::internal::RepeatedPtrFieldBase> = {
static kInitialSize = 0, elements_ = 0x0, current_size_ = 0,
allocated_size_ = 0, total_size_ = 0}, <No data fields>},
force_backward_ = false, debug_info_ = false, _cached_size_ = 0,
_has_bits_ = {1}, static default_instance_ = 0x6a2d70}
caffe:solver.cpp——init()
最新推荐文章于 2022-07-16 13:35:07 发布