PyTorch常用API

🚀Torch 常用API之(reshape; view; permute; transpose; squeeze; unsqueeze; cat; split; chunk; stack)

  1. torch.reshape()input变换成指定形状的output。🌞参数有两个: inputshape

    """
    Parameters
    	input (Tensor) – the tensor to be reshaped
    	shape (tuple of int) – the new shape
    """
    >>> import torch
    >>> a = torch.arange(4)
    >>> a
    tensor([0, 1, 2, 3])
    >>> torch.reshape(a, (2, 2))
    tensor([[0, 1],
            [2, 3]])
    
  2. torch.Tensor.viewinput变换成指定形状的output。不能保证output在内存中的连续性,需要使用contiguous()

    >>> import torch
    >>> a = torch.arange(6)
    >>> a
    tensor([0, 1, 2, 3, 4, 5])
    >>> a.view(2, 3)
    tensor([[0, 1, 2],
            [3, 4, 5]])
    
  3. torch.permuteinput的维度进行置换。🌞参数有两个: inputdims

    """
    Parameters
    	input (Tensor) – the input tensor.
    	dims (tuple of int) – The desired ordering of dimensions
    """
    >>> a = torch.arange(8).reshape(4, 2)
    >>> a
    tensor([[0, 1],
            [2, 3],
            [4, 5],
            [6, 7]])
    >>> torch.permute(a, (1, 0))
    tensor([[0, 2, 4, 6],
            [1, 3, 5, 7]])
    
  4. torch.transposeinput的维度进行置换。🌞参数有三个: inputdim1dim2

    """
    Parameters
    	input (Tensor) – the input tensor.
        dim0 (int) – the first dimension to be transposed
        dim1 (int) – the second dimension to be transposed
    """
    >>> a = torch.arange(10).reshape(2, 5)
    >>> a
    tensor([[0, 1, 2, 3, 4],
            [5, 6, 7, 8, 9]])
    >>> torch.transpose(a, 0, 1)
    tensor([[0, 5],
            [1, 6],
            [2, 7],
            [3, 8],
            [4, 9]])
    
  5. torch.squeezeinput中所有为1的维度删掉。🌞参数有两个: inputdim(可选)

    """
        Parameters
        input (Tensor) – the input tensor.
        dim (int or tuple of ints, optional) –
    """
    >>> import torch
    >>> a = torch.arange(10).reshape(1,2, 5, 1)
    >>> a.size()
    torch.Size([1, 2, 5, 1])
    >>> torch.squeeze(a).size()
    torch.Size([2, 5])
    
  6. torch.unsqueezeinput在指定位置扩展1个维度。🌞参数有两个: inputdim

    """
    Parameters
    	input (Tensor) – the input tensor.
    	dim (int) – the index at which to insert the singleton dimension
    """
    >>> import torch
    >>> a = torch.arange(10).reshape(2, 5)
    >>> a.size()
    torch.Size([2, 5])
    >>> torch.unsqueeze(a, 0).size()
    torch.Size([1, 2, 5])
    
  7. torch.cat 将多个input按照指定维度拼接在一起。🌞参数有两个: 多个inputdim(可选),默认为0维

    """
    Parameters
        tensors (sequence of Tensors) – any python sequence of tensors of the same type. Non-empty tensors 			provided must have the same shape, except in the cat dimension.
    	dim (int, optional) – the dimension over which the tensors are concatenated
    """
    >>> a = torch.randn((2, 3))
    >>> b = torch.randn((2, 3))
    >>> c = torch.cat((a, b))
    >>> c.size()
    torch.Size([4, 3])
    >>> d = torch.cat((a, b), 1)
    >>> d.size()
    torch.Size([2, 6])
    
  8. torch.split 可看作torch.cat 的反操作,将input分为多块。🌞参数有三个: inputsplit_size_or_sections dim

    """
    Parameters
        tensor (Tensor) – tensor to split.
        split_size_or_sections (int) or (list(int)) – size of a single chunk or list of sizes for each chunk
        dim (int) – dimension along which to split the tensor.
    """
    >>> a = torch.arange(10).reshape(5, 2)
    >>> a
    tensor([[0, 1],
            [2, 3],
            [4, 5],
            [6, 7],
            [8, 9]])
    >>> torch.split(a, 2)
    (tensor([[0, 1],
             [2, 3]]),
     tensor([[4, 5],
             [6, 7]]),
     tensor([[8, 9]]))
    >>> torch.split(a, [1, 4])
    (tensor([[0, 1]]),
     tensor([[2, 3],
             [4, 5],
             [6, 7],
             [8, 9]]))
    
  9. torch.chunk 可看作torch.cat 的反操作,将input分为指定块数。🌞参数有三个: inputchunks dim

    """
    Parameters
    	input (Tensor) – the tensor to split
    	chunks (int) – number of chunks to return
    	dim (int) – dimension along which to split the tensor
    """
    >>> torch.arange(11).chunk(6)
    (tensor([0, 1]),
     tensor([2, 3]),
     tensor([4, 5]),
     tensor([6, 7]),
     tensor([8, 9]),
     tensor([10]))
    >>> torch.arange(12).chunk(6)
    (tensor([0, 1]),
     tensor([2, 3]),
     tensor([4, 5]),
     tensor([6, 7]),
     tensor([8, 9]),
     tensor([10, 11]))
    >>> torch.arange(13).chunk(6)
    (tensor([0, 1, 2]),
     tensor([3, 4, 5]),
     tensor([6, 7, 8]),
     tensor([ 9, 10, 11]),
     tensor([12]))
    
  10. torch.stack 沿着新的维度拼接多个input。 🌞参数有两个: 多个inputdim(可选),默认在0维新增一个维度。

    """
    Parameters
    	tensors (sequence of Tensors) – sequence of tensors to concatenate
        dim (int, optional) – dimension to insert. Has to be between 0 and the number of dimensions of 				concatenated tensors (inclusive). Default: 0
    """
    >>> a = torch.arange(9).reshape(3, 3)
    >>> a
    tensor([[0, 1, 2],
            [3, 4, 5],
            [6, 7, 8]])
    >>> torch.stack((a, a), 0)
    tensor([[[0, 1, 2],
             [3, 4, 5],
             [6, 7, 8]],
    
            [[0, 1, 2],
             [3, 4, 5],
             [6, 7, 8]]])
    >>> torch.stack((a, a), 0).size()
    torch.Size([2, 3, 3])
    
  • 20
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值