观众老爷们大家好!最近实在太忙,回首一看上篇专栏文章已经是 4 个月前的事了,所以今天想着写出一篇来撑撑场子(喂
但感觉已经没有当初写专栏的感觉了,所以可能画风会变不少,观众老爷们还望不要介意(逃
这次想和大家分享的是 numpy 中的 axis 这个东西。当初学的时候也没太在意,向来都是感觉差不多就直接过去了,没有去深究背后的一些逻辑。前些天被问起的时候一时懵懂,查了下资料后发现还有点意思,于是就打算写这么一篇专栏来分享一下所得
要想学习 axis,首先要知道的就是 axis 的计数方式。我们在使用 numpy 的各种函数——比如说 np.sum——的时候,有一个参数就叫做 axis。那么这个参数的意思是什么呢?最直白地来说的话,就是“最外面的括号代表着 axis=0,依次往里的括号对应的 axis 的计数就依次加 1”
举个例子,现在我们有一个矩阵:
;在 Python,或说在 numpy 里面,这个矩阵是这样被表达出来的:x = [ [0, 1], [2, 3] ],然后 axis 的对应方式就是:不管画风怎么变,很丑这一点都无法改变啊……
所以相应的运算就是:
对应的代码实现和运行结果如下:
可以看到,貌似出来的结果比我们推导的结果的括号要少一些。这是因为诸如 np.sum 这种函数中有一个参数叫 keepdims,它的默认值是 False,此时它会把多余的括号给删掉。假如我们把它设为 True 的话,就可以得到和推导中一致的结果了:
下面来看一个更“高维”一点的例子:
对应的代码实现和运行结果如下:
以及
可以看到结果和我们推导的确实一样
现在我们知道哪个 axis 对应于数组中的哪些元素了,接下来还需要知道的就是 transpose 这个函数到底在背后干了什么。从纸面上来看,如果一个高维数组 x 的 shape 是 (2, 3, 4),那么 transpose 的作用就是把这个 shape 中各个数的顺序改一改。比如说:
但是 transpose 返回的结果究竟是如何得到的,可能就比较难理解了。幸运的是,这个回答非常好地阐明了这背后的原理。为了方便观众老爷们,我在这里就当一个搬运 and 润色工
首先是对这个 shape 的理解。直观地说,shape 中的各个数就是对应 axis 的元素个数。比如说上图中的 x,它画出来会是这个样子的:字比画还丑呢……
如果我们换一种思路的话,以 axis=0 为例,由于我们现在整个数组里面一共有 24 个数,而 axis=0 只有两个元素,所以可以理解为在 axis=0 这个 axis 上,每隔 24 / 2 = 12 个数就跳一下。比如说上面这个图中就可以看出,两个橙色矩阵对应的数之间差的都是 12
类似的,由于一个橙色矩阵中只有 24 / 2 = 12 个数,所以我们可以理解为在 axis=1 这个 axis 上,每隔 12 / 3 = 4 个数就跳一下。表现在图中,就是同一个橙色矩阵的两个相邻的蓝色向量对应的数之间差的都是 4
再次类似的,由于一个蓝色向量中只有 12 / 3 = 4 个数,我们可以理解为在 axis=2 这个 axis 上,每隔 4 / 4 = 1 个数就跳一下。表现在图中……观众老爷们想必也知道是怎样的了 ( σ'ω')σ
所以我们现在可以定义一个新的东西,比如说叫做 strides 吧,它记录着每个 axis 上跳过的数。比如说上图对应的三维数组,它的 strides 就是 (12, 4, 1)
那么接下来激动人心的时刻到了:transpose 的本质,其实就是对 strides 中各个数的顺序进行调换。举个例子:
在 transpose(1, 0, 2) 后,相应的 strides 会变成 (4, 12, 1)。而从上图可以看出,transpose 的结果确实满足:axis=0 的 axis 上,每隔 4 个数跳一下
axis=1 的 axis 上,每隔 12 个数跳一下
axis=2 的 axis 上,每隔 1 个数跳一下
至此,transpose 背后的逻辑就理顺啦!撒花!*★,°*:.☆\( ̄▽ ̄)/$:*.°★* 。
总之,如果这篇专栏能让大家对 numpy 的 axis 这玩意儿能有更好的认知,从而能够更加得心应手地驾驭 numpy 这个数据分析 & 机器学习领域中的神器的话,就再好不过了。然后就是由于我其实也不咋熟悉 axis,所以如果上文中有什么错漏的话,还请大家指出 ( σ'ω')σ
希望观众老爷们能够喜欢~