np.expand_dims的使用

         关于np.expand_dims的使用,网上好多举了一些实例,自己在平时也常见,但总是有点迷糊,我知道它的作用是扩展一个张量的维度,但结果是如何变化得到的,想来想去不是太明了,所以去函数源码看了一下,算是明白了,np.expand_dims的源码如下:

def expand_dims(a, axis):
    """
    Expand the shape of an array.

    Insert a new axis that will appear at the `axis` position in the expanded
    array shape.

    .. note:: Previous to NumPy 1.13.0, neither ``axis < -a.ndim - 1`` nor
       ``axis > a.ndim`` raised errors or put the new axis where documented.
       Those axis values are now deprecated and will raise an AxisError in the
       future.

    Parameters
    ----------
    a : array_like
        Input array.
    axis : int
        Position in the expanded axes where the new axis is placed.

    Returns
    -------
    res : ndarray
        Output array. The number of dimensions is one greater than that of
        the input array.

    See Also
    --------
    squeeze : The inverse operation, removing singleton dimensions
    reshape : Insert, remove, and combine dimensions, and resize existing ones
    doc.indexing, atleast_1d, atleast_2d, atleast_3d

    Examples
    --------
    >>> x = np.array([1,2])
    >>> x.shape
    (2,)

    The following is equivalent to ``x[np.newaxis,:]`` or ``x[np.newaxis]``:

    >>> y = np.expand_dims(x, axis=0)
    >>> y
    array([[1, 2]])
    >>> y.shape
    (1, 2)

    >>> y = np.expand_dims(x, axis=1)  # Equivalent to x[:,np.newaxis]
    >>> y
    array([[1],
           [2]])
    >>> y.shape
    (2, 1)

    Note that some examples may use ``None`` instead of ``np.newaxis``.  These
    are the same objects:

    >>> np.newaxis is None
    True

    """
    if isinstance(a, matrix):
        a = asarray(a)
    else:
        a = asanyarray(a)

    shape = a.shape
    if axis > a.ndim or axis < -a.ndim - 1:
        # 2017-05-17, 1.13.0
        warnings.warn("Both axis > a.ndim and axis < -a.ndim - 1 are "
                      "deprecated and will raise an AxisError in the future.",
                      DeprecationWarning, stacklevel=2)
    # When the deprecation period expires, delete this if block,
    if axis < 0:
        axis = axis + a.ndim + 1
    # and uncomment the following line.
    # axis = normalize_axis_index(axis, a.ndim + 1)
    return a.reshape(shape[:axis] + (1,) + shape[axis:])

上述是官方原始文件,最重要的地方就最后三行代码,下面对其稍微注解一下:

    if axis < 0:
        axis = axis + a.ndim + 1#当采用倒数的方式指定维度位置时需要转化为正常顺序的位置
    # and uncomment the following line.
    # axis = normalize_axis_index(axis, a.ndim + 1)
    return a.reshape(shape[:axis] + (1,) + shape[axis:])#expand_dims最重要的地方就这里了,在axis位置把新维度插入原始维度中,然后reshape一下。上面的+(1,)可能不好理解(python中两个tuple类型相加,不是求和,而是拼接),举个例子:
例1:
shape=(2,3,4)
axis=1
newshape=shape[:axis] + (1,) + shape[axis:]
print(newshape)
#输出:(2, 1, 3, 4)
例2:
shape=(2,3,4)
axis=1
newshape=shape[:axis] + (1,11) + shape[axis:]
print(newshape)
#输出:(2, 1, 11, 3, 4)

 

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值