pytorch中的cat、stack、tranpose、permute、unsqeeze

pytorch中提供了对tensor常用的变换操作。

cat 连接

对数据沿着某一维度进行拼接。cat后数据的总维数不变。
比如下面代码对两个2维tensor(分别为2*3,1*3)进行拼接,拼接完后变为3*3还是2维的tensor。
代码如下:

import torch
torch.manual_seed(1)
x = torch.randn(2,3)
y = torch.randn(1,3)
print(x,y)
 
 
  • 1
  • 2
  • 3
  • 4
  • 5

结果:

0.6614  0.2669  0.0617
 0.6213 -0.4519 -0.1661
[torch.FloatTensor of size 2x3]

-1.5228  0.3817 -1.0276
[torch.FloatTensor of size 1x3]
 
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

将两个tensor拼在一起:

torch.cat((x,y),0)
 
 
  • 1

结果:

 0.6614  0.2669  0.0617
 0.6213 -0.4519 -0.1661
-1.5228  0.3817 -1.0276
[torch.FloatTensor of size 3x3]
 
 
  • 1
  • 2
  • 3
  • 4

更灵活的拼法:

torch.manual_seed(1)
x = torch.randn(2,3)
print(x)
print(torch.cat((x,x),0))
print(torch.cat((x,x),1))
 
 
  • 1
  • 2
  • 3
  • 4
  • 5

结果

// x
0.6614  0.2669  0.0617
 0.6213 -0.4519 -0.1661
[torch.FloatTensor of size 2x3]

// torch.cat((x,x),0)
 0.6614  0.2669  0.0617
 0.6213 -0.4519 -0.1661
 0.6614  0.2669  0.0617
 0.6213 -0.4519 -0.1661
[torch.FloatTensor of size 4x3]

// torch.cat((x,x),1)
 0.6614  0.2669  0.0617  0.6614  0.2669  0.0617
 0.6213 -0.4519 -0.1661  0.6213 -0.4519 -0.1661
[torch.FloatTensor of size 2x6]
 
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16

stack,增加新的维度进行堆叠

而stack则会增加新的维度。
如对两个1*2维的tensor在第0个维度上stack,则会变为2*1*2的tensor;在第1个维度上stack,则会变为1*2*2的tensor。
见代码:

a = torch.ones([1,2])
b = torch.ones([1,2])
c= torch.stack([a,b],0) // 第0个维度stack
 
 
  • 1
  • 2
  • 3

输出:

(0 ,.,.) = 
  1  1

(1 ,.,.) = 
  1  1
[torch.FloatTensor of size 2x1x2]
 
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
c= torch.stack([a,b],1) // 第1个维度stack
 
 
  • 1

输出:

(0 ,.,.) = 
  1  1
  1  1
[torch.FloatTensor of size 1x2x2]
 
 
  • 1
  • 2
  • 3
  • 4

transpose ,交换维度

代码如下:

torch.manual_seed(1)
x = torch.randn(2,3)
print(x)
 
 
  • 1
  • 2
  • 3

原来x的结果:

 0.6614  0.2669  0.0617
 0.6213 -0.4519 -0.1661
[torch.FloatTensor of size 2x3]
 
 
  • 1
  • 2
  • 3

将x的维度互换

x.transpose(0,1)
 
 
  • 1

结果

0.6614  0.6213
 0.2669 -0.4519
 0.0617 -0.1661
[torch.FloatTensor of size 3x2]
 
 
  • 1
  • 2
  • 3
  • 4

permute,适合多维数据,更灵活的transpose

permute是更灵活的transpose,可以灵活的对原数据的维度进行调换,而数据本身不变。
代码如下:

x = torch.randn(2,3,4)
print(x.size())
x_p = x.permute(1,0,2) # 将原来第1维变为0维,同理,0→1,2→2
print(x_p.size())
 
 
  • 1
  • 2
  • 3
  • 4

结果:

torch.Size([2, 3, 4])
torch.Size([3, 2, 4])
 
 
  • 1
  • 2

squeeze 和 unsqueeze

squeeze(dim_n)压缩,即去掉元素数量为1的dim_n维度。同理unsqueeze(dim_n),增加dim_n维度,元素数量为1。

上代码:

# 定义张量
import torch

b = torch.Tensor(2,1)
b.shape
Out[28]: torch.Size([2, 1])

# 不加参数,去掉所有为元素个数为1的维度
b_ = b.squeeze()
b_.shape
Out[30]: torch.Size([2])

# 加上参数,去掉第一维的元素为1,不起作用,因为第一维有2个元素
b_ = b.squeeze(0)
b_.shape 
Out[32]: torch.Size([2, 1])

# 这样就可以了
b_ = b.squeeze(1)
b_.shape
Out[34]: torch.Size([2])

# 增加一个维度
b_ = b.unsqueeze(2)
b_.shape
Out[36]: torch.Size([2, 1, 1])
 
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
					<link href="https://csdnimg.cn/release/phoenix/mdeditor/markdown_views-7b4cdcb592.css" rel="stylesheet">
            </div>
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
PyTorch提供了full stack trace的功能,它可以在程序发生异常时输出详细的调用堆栈信息,帮助开发者快速定位错误。 在PyTorch,我们可以通过设置环境变量`PYTHONFAULTHANDLER=1`来启用full stack trace功能。例如,在Linux系统下,我们可以在终端输入如下命令: ``` export PYTHONFAULTHANDLER=1 ``` 然后,在我们的程序发生异常时,PyTorch会输出完整的调用堆栈信息,包括每个函数的参数和变量值等。例如: ``` Traceback (most recent call last): File "test.py", line 5, in <module> y = model(x) File "/path/to/model.py", line 10, in __call__ y = self.fc(x) File "/path/to/torch/nn/modules/module.py", line 1051, in _call_impl return forward_call(*input, **kwargs) File "/path/to/torch/nn/modules/linear.py", line 94, in forward return F.linear(input, self.weight, self.bias) File "/path/to/torch/nn/functional.py", line 1692, in linear if input.dim() == 2 and bias is not None: AttributeError: 'NoneType' object has no attribute 'dim' ``` 从上述输出,我们可以看到程序异常发生在`test.py`的第5行,即调用了`model(x)`。接着,程序进入到了`model.py`的`__call__`函数,在该函数调用了`self.fc(x)`(假设`fc`是一个线性层),然后进入了PyTorch的内部函数调用堆栈,最终出现了一个`AttributeError`异常。 通过full stack trace功能,我们可以清晰地看到程序执行过程每个函数的调用关系,方便我们快速定位错误。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值