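功能
当某个 top blob 的使用次数（被后续层用作 bottom 的次数；若该 top 带有非零 loss_weight 再加 1）大于 1 时，InsertSplits 会在产生它的层之后插入一个由 ConfigureSplitLayer 配置的 split 层，并把各使用者的 bottom 改名为 SplitBlobName 生成的 split 输出名。
例如下文示例中 data 同时是第 1、2 层的 bottom，通过 InsertSplits 后变成: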
特殊变量说明
// Bookkeeping maps used by InsertSplits, shown with the contents they hold
// after its first pass over an example net of three layers:
//   layer 0 produces top 'data';
//   layer 1 reads 'data' as bottom 0 and produces top 'ip1';
//   layer 2 reads 'data' as bottom 0 and produces top 'ip2'.
// Pairs written as {a, b} are (layer index, blob index within that layer).

// blob name -> (index of the layer that last produced it, top index in it)
std::map<std::string, std::pair<int, int> > blob_name_to_last_top_idx = {
    {"data", {0, 0}},
    {"ip1", {1, 0}},
    {"ip2", {2, 0}},
};

// (consuming layer index, bottom index) -> (source layer index, source top index)
std::map<std::pair<int, int>, std::pair<int, int> > bottom_idx_to_source_top_idx = {
    {{1, 0}, {0, 0}},
    {{2, 0}, {0, 0}},
};

// (layer index, top index) -> number of uses of that top: one per bottom that
// reads it, plus one if the top carries a nonzero loss weight. 'data' is read
// by two layers, so its count is 2; a count above 1 is what triggers a split.
std::map<std::pair<int, int>, int> top_idx_to_bottom_count = {
    {{0, 0}, 2},
};

// (layer index, top index) -> declared loss weight; empty in this example
// because no loss weights are set.
std::map<std::pair<int, int>, float> top_idx_to_loss_weight;

// (layer index, top index) -> next split output index to hand to a consumer;
// only written during the second pass, so still empty here.
std::map<std::pair<int, int>, int> top_idx_to_bottom_split_idx;

// layer index -> layer name; the first pass stores one entry per layer, but
// the example's layer names are not given in these notes, so none are listed.
std::map<int, std::string> layer_idx_to_layer_name;
源码分析
void InsertSplits(const NetParameter& param, NetParameter* param_split) {
// Initialize by copying from the input NetParameter.
param_split->CopyFrom(param);
param_split->clear_layer();
map<string, pair<int, int> > blob_name_to_last_top_idx;//<blob_name,<layer_idx,top_idx_in_layer>>
map<pair<int, int>, pair<int, int> > bottom_idx_to_source_top_idx;//<<layer_idx,bottom_idx_layer>,<layer_idx_to_bottom,top_idx_to_bottom>>
map<pair<int, int>, int> top_idx_to_bottom_count;//<<layer_idx,top_idx_in_layer>,top_cnt>
map<pair<int, int>, float> top_idx_to_loss_weight;
map<pair<int, int>, int> top_idx_to_bottom_split_idx;
map<int, string> layer_idx_to_layer_name;
for (int i = 0; i < param.layer_size(); ++i) {
const LayerParameter& layer_param = param.layer(i);
layer_idx_to_layer_name[i] = layer_param.name();
for (int j = 0; j < layer_param.bottom_size(); ++j) {
const string& blob_name = layer_param.bottom(j);
if (blob_name_to_last_top_idx.find(blob_name) ==
blob_name_to_last_top_idx.end()) {
LOG(FATAL) << "Unknown bottom blob '" << blob_name << "' (layer '"
<< layer_param.name() << "', bottom index " << j << ")";
}
const pair<int, int>& bottom_idx = make_pair(i, j);
const pair<int, int>& top_idx = blob_name_to_last_top_idx[blob_name];
bottom_idx_to_source_top_idx[bottom_idx] = top_idx;
++top_idx_to_bottom_count[top_idx];
}
for (int j = 0; j < layer_param.top_size(); ++j) {
const string& blob_name = layer_param.top(j);
blob_name_to_last_top_idx[blob_name] = make_pair(i, j);
}
// A use of a top blob as a loss should be handled similarly to the use of
// a top blob as a bottom blob to another layer.
const int last_loss =
std::min(layer_param.loss_weight_size(), layer_param.top_size());
for (int j = 0; j < last_loss; ++j) {
const string& blob_name = layer_param.top(j);
const pair<int, int>& top_idx = blob_name_to_last_top_idx[blob_name];
top_idx_to_loss_weight[top_idx] = layer_param.loss_weight(j);
if (top_idx_to_loss_weight[top_idx]) {
++top_idx_to_bottom_count[top_idx];
}
}
}
for (int i = 0; i < param.layer_size(); ++i) {
LayerParameter* layer_param = param_split->add_layer();
layer_param->CopyFrom(param.layer(i));
// Replace any shared bottom blobs with split layer outputs.
for (int j = 0; j < layer_param->bottom_size(); ++j) {
const pair<int, int>& top_idx =
bottom_idx_to_source_top_idx[make_pair(i, j)];
const int split_count = top_idx_to_bottom_count[top_idx];
if (split_count > 1) {
const string& layer_name = layer_idx_to_layer_name[top_idx.first];
const string& blob_name = layer_param->bottom(j);
layer_param->set_bottom(j, SplitBlobName(layer_name,
blob_name, top_idx.second, top_idx_to_bottom_split_idx[top_idx]++));
}
}
// Create split layer for any top blobs used by other layer as bottom
// blobs more than once.
for (int j = 0; j < layer_param->top_size(); ++j) {
const pair<int, int>& top_idx = make_pair(i, j);
const int split_count = top_idx_to_bottom_count[top_idx];
if (split_count > 1) {
const string& layer_name = layer_idx_to_layer_name[i];
const string& blob_name = layer_param->top(j);
LayerParameter* split_layer_param = param_split->add_layer();
const float loss_weight = top_idx_to_loss_weight[top_idx];
ConfigureSplitLayer(layer_name, blob_name, j, split_count,
loss_weight, split_layer_param);
if (loss_weight) {
layer_param->clear_loss_weight();
top_idx_to_bottom_split_idx[top_idx]++;
}
}
}
}
}