在caffe的reshape层参数设定中有两个特别的参数:axis和num_axes。caffe.proto里对reshape层的参考文档解释得不够清楚,所举的例子也存在错误,因此特地写了一段仿照reshape层实现的代码来探究这两个参数的作用,代码如下:
#include <iostream>
#include <stdexcept>
#include <vector>
using namespace std;
/**
 * Multiplies the entries input[begin], ..., input[end - 1].
 *
 * @param input Dimensions to multiply (taken by const reference, no copy).
 * @param begin Index of the first entry to include.
 * @param end   Index one past the last entry to include.
 * @return The product of the selected entries; 1 when the range is empty
 *         (begin >= end).
 * @throws std::out_of_range if a non-empty range reaches outside `input`.
 */
int count(const std::vector<int>& input, int begin, int end){
    // An empty range never touches the vector, so it is always valid.
    if (begin >= end) {
        return 1;
    }
    // Reject ranges that would index outside the vector instead of reading
    // out of bounds.
    if (begin < 0 || end > static_cast<int>(input.size())) {
        throw std::out_of_range("count: [begin, end) lies outside the input shape");
    }
    int num = 1;
    for (int i = begin; i < end; ++i) {
        num *= input[i];
    }
    return num;
}
/**
 * Computes the output ("top") shape of a Caffe-style Reshape layer.
 *
 * The input axes [start, end) are replaced by the entries of `shape`; the
 * axes before `start` and from `end` onwards are kept unchanged.
 *
 * @param input    Shape of the incoming blob.
 * @param shape    Replacement dimensions. A 0 copies the dimension of the
 *                 input axis at the same offset from the start axis. A single
 *                 -1 is inferred so that the total element count is kept.
 * @param axis     First input axis to replace. A negative value counts from
 *                 the end, with -1 meaning "one past the last axis".
 * @param num_axes Number of input axes to replace. -1 means every axis from
 *                 `axis` to the end; 0 inserts `shape` without removing axes.
 * @return The resulting shape.
 * @throws std::invalid_argument if `axis`/`num_axes` select axes outside
 *         `input`, if `shape` holds more than one -1, if a 0 entry has no
 *         matching input axis, if the -1 dimension cannot be inferred as an
 *         exact integer, or if the result does not hold the same number of
 *         elements as `input`.
 */
std::vector<int> reshape(const std::vector<int>& input, const std::vector<int>& shape, int axis=0, int num_axes=-1){
    // Work with signed axis counts so negative axes never mix with size_t.
    const int input_num_axes = static_cast<int>(input.size());
    const int num_new_axes = static_cast<int>(shape.size());

    // Product of input[first..last); an empty range yields 1.
    const auto input_count = [&input](int first, int last) {
        long long product = 1;
        for (int i = first; i < last; ++i) {
            product *= input[i];
        }
        return product;
    };

    // Classify the requested dims: 0 = copy from input, -1 = infer,
    // anything else contributes to the explicitly given element count.
    int inferred_axis = -1;
    std::vector<int> copy_axes;
    long long constant_count = 1;
    for (int i = 0; i < num_new_axes; ++i) {
        const int top_dim = shape[i];
        if (top_dim == 0) {
            copy_axes.push_back(i);
        } else if (top_dim == -1) {
            if (inferred_axis != -1) {
                throw std::invalid_argument("reshape: new shape contains multiple -1 dims");
            }
            inferred_axis = i;
        } else {
            constant_count *= top_dim;
        }
    }

    // Resolve the replaced axis range [start_axis, end_axis).
    const int start_axis = (axis >= 0) ? axis : input_num_axes + axis + 1;
    if (start_axis < 0 || start_axis > input_num_axes) {
        throw std::invalid_argument("reshape: axis out of range");
    }
    // Compare against the remaining axis count rather than computing
    // start_axis + num_axes first, so a huge num_axes cannot overflow.
    if (num_axes < -1 || num_axes > input_num_axes - start_axis) {
        throw std::invalid_argument("reshape: num_axes out of range");
    }
    const int end_axis = (num_axes == -1) ? input_num_axes : start_axis + num_axes;

    // Assemble: retained leading axes, the new dims, retained trailing axes.
    std::vector<int> top_shape;
    top_shape.reserve(input.size() - (end_axis - start_axis) + shape.size());
    top_shape.insert(top_shape.end(), input.begin(), input.begin() + start_axis);
    top_shape.insert(top_shape.end(), shape.begin(), shape.end());
    top_shape.insert(top_shape.end(), input.begin() + end_axis, input.end());

    // Fill every 0 entry with the input dimension at the same absolute index.
    for (const int copy_axis_index : copy_axes) {
        const int source_axis = start_axis + copy_axis_index;
        if (source_axis >= input_num_axes) {
            throw std::invalid_argument("reshape: a 0 (copy) dim has no matching input axis");
        }
        top_shape[source_axis] = input[source_axis];
    }

    const long long total_count = input_count(0, input_num_axes);
    if (inferred_axis >= 0) {
        // A -1 dim was specified; infer it by dividing the total element
        // count by the product of all other output dimensions.
        long long explicit_count = constant_count;
        explicit_count *= input_count(0, start_axis);
        explicit_count *= input_count(end_axis, input_num_axes);
        for (const int copy_axis_index : copy_axes) {
            explicit_count *= top_shape[start_axis + copy_axis_index];
        }
        if (explicit_count == 0 || total_count % explicit_count != 0) {
            throw std::invalid_argument("reshape: input count is not divisible by the product of the specified dims");
        }
        top_shape[start_axis + inferred_axis] = static_cast<int>(total_count / explicit_count);
    }

    // A reshape must keep the number of elements unchanged.
    long long top_count = 1;
    for (const int dim : top_shape) {
        top_count *= dim;
    }
    if (top_count != total_count) {
        throw std::invalid_argument("reshape: output count must match input count");
    }
    return top_shape;
}
int main(){
vector<int> input = { 1, 2, 3, 4 };
vector<int> shape = { -1 };
vector<int> res = reshape(input, shape,1, 2);//1,6,4
for (int i : res)
cout << i<<" ";
cout << endl;
}