上篇文章提到的INDArray中的repeat重写,我项目中使用的是40*1*300的矩阵操作,本人亲测性能得到了150倍的提升。
现在给出个直观的算法概念和具体实现代码:
Nd4j的repeat实现算法如图:
本人重写后的实现算法如图:
下面是针对我的项目情况的代码实现, 我的实现只针对被repeat的维度长度为1的情况,如果大家需要对任意长度维度的repeat操作,需要做少许修该。
/*
This funciton only support repeat the dimension of which the length of it is one.
* */
static INDArray repeat_ndarray(INDArray input, int dimension, int repeats){
if (1 != input.shape()[dimension]){
return input;
}
int[] permute_dimension = input.shape().clone();
for (int i = 0; i < permute_dimension.length; ++i)
permute_dimension[i] = i;
permute_dimension[0] = dimension;
permute_dimension[dimension] = 0;
INDArray input_dup_permute = input.permute(permute_dimension);
int[] shape_tmp = input.shape();
shape_tmp[dimension] = shape_tmp[0];
shape_tmp[0] = repeats;
INDArray array_tmp = Nd4j.create(shape_tmp);
for (int i = 0; i < repeats; ++ i){
array_tmp.putRow(i, input_dup_permute.getRow(0));
}
return array_tmp.permute(permute_dimension);
}
如果对这个算法实现有疑问或者感兴趣, 欢迎提问