np.rollaxis函数说明
函数
rollaxis(a, axis, start=0)
'''
a: n维数组
axis: 需要滚动的轴
start: 滚动到位置的,默认为0
'''
用于将数组a的某个轴滚动到指定的维度。
例子:
import numpy as np
trainset_x = np.zeros(shape=[50000, 3072])
print('raw trainset_x.shape:', trainset_x.shape)
trainset_x = trainset_x.reshape(-1, 3, 32, 32)
print('after reshape trainset_x.shape:', trainset_x.shape)
trainset_x = np.rollaxis(trainset_x, 1, 4)
print('after rollaxis trainset_x.shape:', trainset_x.shape)
结果:
raw trainset_x.shape: (50000, 3072)
after reshape trainset_x.shape: (50000, 3, 32, 32)
after rollaxis trainset_x.shape: (50000, 32, 32, 3)
说明
我们指定的轴是axis=1(从0开始),滚动到的位置start=4,而实际上是滚动到了第3轴,也就是第4轴前面的那一个轴。
如果我们使用默认的轴
# 接着上面的来计算
trainset_x = np.rollaxis(trainset_x, 3)
print('use default start value, trainset_x.shape:', trainset_x.shape)
结果
use default start value, trainset_x.shape: (3, 50000, 32, 32)
它会滚动到最前面的轴。