np.stack()函数详解

np.stack(array, axis)

背景

在python的numpy库中,数组的stack堆叠是个很常见的操作,如何堆叠涉及到axis这个参数,本文以np.stack()函数为例,去讲解axis这个参数的解释。

语法

stack(arrays, axis=0, out=None)
    Join a sequence of arrays along a new axis.
    
    The `axis` parameter specifies the index of the new axis in the dimensions
    of the result. For example, if ``axis=0`` it will be the first dimension
    and if ``axis=-1`` it will be the last dimension.

    Parameters
    ----------
    arrays : sequence of array_like
        Each array must have the same shape.
    axis : int, optional
        The axis in the result array along which the input arrays are stacked.

从官方文档的解释可以看出,stack()中的axis参数是在维度中加入了一个新轴,也即是stack堆叠的最终结果是返回一个array数组,堆叠后数组的维(轴)数比原始数组的维(轴)数要多一个维(轴),且多的那一个(维)轴上的数值为需要进行stack堆叠的array数组个数。例如要堆叠的数组是二维(轴),shape为(5,4),要堆叠的数组个数为3个,那么返回的结果一定是在二维(轴)上增加一维(轴)成三维(轴),且增加的维(轴)上的数值为3,那么新增的数值为3的维(轴)是增加在第0维(轴)上的(3,5,4)、第1维(轴)的(5,3,4)还是第2维(轴)的(5,4,3),而新增的维(轴)插在不同的位置也正是axis参数的真正含义。

测试

import numpy as np
a = np.array([1,2,3])
b = np.array([10,10,10])
>>> np.stack((a,b),axis=0)
array([[ 1,  2,  3],
       [10, 10, 10]])
>>> np.stack((a,b),axis=1)
array([[ 1, 10],
       [ 2, 10],
       [ 3, 10]])
>>> np.stack((a,b),axis=2)
numpy.AxisError: axis 2 is out of bounds for array of dimension 2

可以看出由于原始数组的形状为(3,)是一维(轴)数组,使用stack堆叠2个数组后必然返回的是新增一维(轴)的二维(轴)数组形状为(2,3)或(3,2),如果axis=2则超出了二维(轴)的设定,因此不可能实现,再来看当axis=0则得到的是形状为(2,3),当axis=1则得到的是形状为(3,2)。由此可见,axis的值正是表示的新维(轴)新增的位置。

规则

不同的axis值得到的是不同形状的数组,那么原始数组中的元素又是如何堆叠成新数组的呢,stack实际上是利用了python广播机制先扩展为设定形状的数组再执行简单堆叠方法(简单堆叠函数vstack,hstack一般不改变原数组的维(轴)数,只对元素进行纵向或横向拼接)。

np.stack((a,b),axis=0)为例,数组a是array([1, 2, 3])的形状是(3,),数组b是array([10, 10, 10])的形状是(3,),由于axis=0,所以新增的维(轴)出现在第0维(轴)的位置得到形状假设为(x,3)的数组,而数组a和数组b是2个数组进行堆叠,则第0维(轴)上的形状数值x应当为2,所以最终的返回数组形状是(2,3)。注意,新增的维(轴)位置上的数值2并不是替换原来数组形状第0维(轴)位置上的数值3而是将原来第0维(轴)向后挤形成多一层级的第1维(轴)。

再考虑如何由两个形状(3,)的数组堆叠为最终形状(2,3)的数组,由于已知最终形状是(2,3),则原来(3,)的数组通过广播机制将形状扩展为(1,3),则a=array([1,2,3])将扩展为a'=array([[1,2,3]]),同理b'=array([[10,10,10]]),广播扩展后的2个数组再沿着axis=0的row方向堆叠(即按纵向堆叠row行数的vstack方法,array[[1,2,3]]沿0轴堆叠array[[10,10,10]]),因此才得到了array([[1,2,3],[10,10,10]])

>>> np.vstack((np.array([[1,2,3]]),np.array([[10,10,10]])))
array([[ 1,  2,  3],
       [10, 10, 10]])

同理np.stack((a,b),axis=1),最终返回的数组形状是(3,2),因此远来(3,)的数组扩展为(3,1),即a'=array([[1],[2],[3]])和b’=array([[10],[10],[10]]),然后沿着axis=1的column方向堆叠(即按横向堆叠column列数的hstack方法,array([[1],[2],[3]]沿1轴堆叠array[[10],[10],[10]]),因此得到array([[1,10],[2,10],[3,10]])

>>> np.hstack((np.array([[1],[2],[3]]),np.array([[10],[10],[10]])))
array([[ 1, 10],
       [ 2, 10],
       [ 3, 10]])
  • 1
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值