caffe的reshape层探究

2 篇文章 0 订阅

在caffe的reshape层参数设定中有两个特别的参数:axis和num_axes,看caffe.proto里的对reshape层的参考文档发现解释的不清楚,举的例子也存在错误,因此特地写了一个仿照reshape层实现的代码来探究这两个参数的作用,代码如下:

#include <vector>
#include <iostream>
using namespace std;


int count(vector<int> input, int begin, int end){
	int num = 1;
	for (int i = begin; i < end; ++i)
		num *= input[i];
	return num;
}


vector<int> reshape(vector<int> input, vector<int> shape, int axis=0, int num_axes=-1){
	
	
	int inferred_axis_ = -1;
	vector<int> copy_axes_;
	const int top_num_axes = shape.size();
	int constant_count_ = 1;
	for (int i = 0; i < top_num_axes; ++i) {
		const int top_dim = shape[i];
		if (top_dim == 0) {
			copy_axes_.push_back(i);
		}
		else if (top_dim == -1) {
			inferred_axis_ = i;
		}
		else {
			constant_count_ *= top_dim;
		}
	}
	
	const int input_start_axis = axis;
	const int start_axis = (input_start_axis >= 0) ? input_start_axis :
		input.size() + input_start_axis + 1;
	const int end_axis =
		(num_axes == -1) ? input.size() : (start_axis + num_axes);
	const int num_axes_replaced = end_axis - start_axis;
	const int num_axes_retained = input.size() - num_axes_replaced;
	const int num_new_axes = shape.size();
	vector<int> top_shape(num_axes_retained + num_new_axes);
	int top_shape_index = 0;
	for (int i = 0; i < start_axis; ++i) {
		top_shape[top_shape_index++] = input[i];
	}
	for (int i = 0; i < num_new_axes; ++i) {
		top_shape[top_shape_index++] = shape[i];
	}
	for (int i = end_axis; i < input.size(); ++i) {
		top_shape[top_shape_index++] = input[i];
	}
	for (int i = 0; i < copy_axes_.size(); ++i) {
		const int copy_axis_index = copy_axes_[i];

		top_shape[start_axis + copy_axis_index] =
			input[start_axis + copy_axis_index];
	}
	if (inferred_axis_ >= 0) {
		// A -1 dim was specified; infer the correct dimension by computing the
		// product of the other dimensions.
		int explicit_count = constant_count_;
		explicit_count *= count(input, 0, start_axis);
		explicit_count *= count(input, end_axis, input.size());
		for (int i = 0; i < copy_axes_.size(); ++i) {
			const int copy_axis_index = copy_axes_[i];
			explicit_count *= top_shape[start_axis + copy_axis_index];
		}

		const int inferred_dim = count(input, 0, input.size()) / explicit_count;
		top_shape[start_axis + inferred_axis_] = inferred_dim;
	}
	return top_shape;
}


int main(){
	vector<int> input = { 1, 2, 3, 4 };
	vector<int> shape = { -1 };
	vector<int> res = reshape(input, shape,1, 2);//1,6,4
	for (int i : res)
		cout << i<<" ";
	cout << endl;
}


  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值