PyTorch常用张量切割和拼接方法(torch.chunk、torch.split、torch.cat和torch.stack用法详解)

目录

1、torch.chunk

2、torch.split

3、torch.cat

4、torch.stack


1、torch.chunk

首先看官方文档的解释:

功能:尝试将张量拆分为指定数量的数据块,每个数据块都是输入张量的一个视图。

例子:

>>> torch.arange(5).chunk(3)
(tensor([0, 1]),
 tensor([2, 3]),
 tensor([4]))

>>> torch.arange(6).chunk(3)
(tensor([0, 1]),
 tensor([2, 3]),
 tensor([4, 5]))

>>> torch.arange(7).chunk(3)
(tensor([0, 1, 2]),
 tensor([3, 4, 5]),
 tensor([6]))

         通过例子很好理解,torch.chunk是将tensor切割为k个tensor。

2、torch.split

首先看官方文档的解释:

功能:将张量分成数据块,每个数据块都是原始张量的视图。

官方解释的功能与torch.chunk大致相同,具体的区别通过例子很容易理解。

例子:

>>> a = torch.arange(10).reshape(5,2)
>>> a
tensor([[0, 1],
        [2, 3],
        [4, 5],
        [6, 7],
        [8, 9]])

>>> torch.split(a, 2)
(tensor([[0, 1],
         [2, 3]]),
 tensor([[4, 5],
         [6, 7]]),
 tensor([[8, 9]]))

>>> torch.split(a, [1,2])
(tensor([[0, 1]]),
 tensor([[2, 3],
         [4, 5],
         [6, 7],
         [8, 9]]))

        如果split_size_or_sections传入的是整数的话,就是沿着dim方向尽量分割出长度为split_size_or_sections的tensor。

        如果split_size_or_sections传入的是list的话,就按照list的参数分割不同长度的tensor。

注:chunk是指定分割数量,split是指定分割完的tensor的样式。

3、torch.cat

首先看官方文档的解释:

功能:在给定维度上对输入的张量序列seq 进行连接操作,所有的张量必须有相同的形状(连接维度除外)或为空。

           torch.cat()可以被看作是torch.split()和torch.chunk()的逆向操作。

例子:

>>> a = torch.arange(6).reshape(2, 3)
>>> a
tensor([[0, 1, 2],
        [3, 4, 5]])

>>> torch.cat((a, a, a), 0)
tensor([[0, 1, 2],
        [3, 4, 5],
        [0, 1, 2],
        [3, 4, 5],
        [0, 1, 2],
        [3, 4, 5]])

>>> torch.cat((a, a, a), 1)
tensor([[0, 1, 2, 0, 1, 2, 0, 1, 2],
        [3, 4, 5, 3, 4, 5, 3, 4, 5]])

        通过例子很好理解,torch.cat将tensor沿着dim进行拼接。

4、torch.stack

功能:沿着一个新的维度连接一系列张量。所有张量需要是相同的大小。

例子:

>>> a = torch.arange(9).reshape(2, 3)
>>> a
tensor([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]])

>>> torch.stack((a, a), 0) # torch.Size([2, 3, 3])
tensor([[[0, 1, 2],
         [0, 1, 2]],

        [[3, 4, 5],
         [3, 4, 5]],

        [[6, 7, 8],
         [6, 7, 8]]])

>>> torch.stack((a, a), 1) # torch.Size([3, 2, 3])
tensor([[[0, 1, 2],
         [0, 1, 2]],

        [[3, 4, 5],
         [3, 4, 5]]])

>>> torch.stack((a, a), 2) # torch.Size([3, 3, 2])
tensor([[[0, 0],
         [1, 1],
         [2, 2]],

        [[3, 3],
         [4, 4],
         [5, 5]],

        [[6, 6],
         [7, 7],
         [8, 8]]])

        与cat相比不同的是,stack是在添加一个新的维度去连接,连接方式通过例子中的output和Size很容易理解。 

  • 9
    点赞
  • 35
    收藏
    觉得还不错? 一键收藏
  • 5
    评论
PyTorch提供了full stack trace的功能,它可以在程序发生异常时输出详细的调用堆栈信息,帮助开发者快速定位错误。 在PyTorch中,我们可以通过设置环境变量`PYTHONFAULTHANDLER=1`来启用full stack trace功能。例如,在Linux系统下,我们可以在终端中输入如下命令: ``` export PYTHONFAULTHANDLER=1 ``` 然后,在我们的程序中发生异常时,PyTorch会输出完整的调用堆栈信息,包括每个函数的参数和变量值等。例如: ``` Traceback (most recent call last): File "test.py", line 5, in <module> y = model(x) File "/path/to/model.py", line 10, in __call__ y = self.fc(x) File "/path/to/torch/nn/modules/module.py", line 1051, in _call_impl return forward_call(*input, **kwargs) File "/path/to/torch/nn/modules/linear.py", line 94, in forward return F.linear(input, self.weight, self.bias) File "/path/to/torch/nn/functional.py", line 1692, in linear if input.dim() == 2 and bias is not None: AttributeError: 'NoneType' object has no attribute 'dim' ``` 从上述输出中,我们可以看到程序异常发生在`test.py`的第5行,即调用了`model(x)`。接着,程序进入到了`model.py`中的`__call__`函数,在该函数中调用了`self.fc(x)`(假设`fc`是一个线性层),然后进入了PyTorch的内部函数调用堆栈,最终出现了一个`AttributeError`异常。 通过full stack trace功能,我们可以清晰地看到程序执行过程中每个函数的调用关系,方便我们快速定位错误。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值