pytorch 维度变换

view/reshape 功能完全一样。对一个张量进行维度变换,但是最终的总数量必须保持不变

import torch
import numpy as np
a = torch.rand(4,2,28,28)  # 四维张量
print(a.shape) #  原始的张量
print(a.view(8,28*28).shape) # 维度转换成一维,总大小不变
print(a.view(-1,8)) # -1 表示任意维度(pytorch根据,后面的维度自己推导,如总维度是 28*28*8,此时-1代表的维度就是28*28*8 / 8 =28*28)
torch.Size([4, 2, 28, 28])
torch.Size([8, 784])
torch.Size([784, 8])

suqeeze/unsqueeze 维度的压缩与扩展

unsuqeeze

import torch
import numpy as np
a = torch.rand(4,2,28,28)  # 四维张量
print("a 的初始维度是:{}".format(a.shape)) #  原始的张量
print("a 增加一个维度是:{}".format(a.unsqueeze(0).shape)) # 最前面增加一个维度
print("a 增加一个维度是:{}".format(a.unsqueeze(1).shape)) # 在原来第二个维度前增加一个维度
print("a 增加一个维度是:{}".format(a.unsqueeze(-1).shape)) # 最后面增加一个维度

a 的初始维度是:torch.Size([4, 2, 28, 28])
a 增加一个维度是:torch.Size([1, 4, 2, 28, 28])
a 增加一个维度是:torch.Size([4, 1, 2, 28, 28])
a 增加一个维度是:torch.Size([4, 2, 28, 28, 1])

suqeeze

import torch
import numpy as np
a = torch.rand(4,2,28,28,1,1)  # 四维张量
print("a 的初始维度是:{}".format(a.shape)) #  原始的张量
print("a 删除维度为1的维度是:{}".format(a.squeeze().shape)) # 删除所有维度为1的维度
print("a 删除指定位置为1的维度是:{}".format(a.squeeze(4).shape)) # 
a 的初始维度是:torch.Size([4, 2, 28, 28, 1, 1])
a 删除维度为1的维度是:torch.Size([4, 2, 28, 28])
a 删除指定位置为1的维度是:torch.Size([4, 2, 28, 28, 1])

转置

转置 .t() (二维)

import torch
import numpy as np
a = torch.rand(4,2)  # 四维张量
print("a 转置后是:{}".format(a.t().shape)) 
a 转置后是:torch.Size([2, 4])

transpose 两个维度相互交换

import torch
import numpy as np
a = torch.rand(4,3,32,32)  # 四维张量
#transpose 后需要接 contiguous 保证数据在内存的连续性
a1 = a.transpose(1,3).contiguous().view(4,3*32*32).view(4,3,32,32)
a2 = a.transpose(1,3).contiguous().view(4,3*32*32).view(4,32,32,3).transpose(1,3)
# torch.all 判断所有元素   torch.eq 判断两个张量是否相等
print("a1 shape : {}  a2 shape : {}".format(a1.shape,a2.shape))
# 判断变换后 a 与 a1 是否相同
print("a1 与 a 是否相同:{}".format(torch.all(torch.eq(a,a1))))
# 判断变换后 a 与 a2 是否相同
print("a2 与 a 是否相同:{}".format(torch.all(torch.eq(a,a2))))
a1 shape : torch.Size([4, 3, 32, 32])  a2 shape : torch.Size([4, 3, 32, 32])
a1 与 a 是否相同:False
a2 与 a 是否相同:True

permute 任意维度的交换

import torch
import numpy as np
a = torch.rand(4,3,32,32)  # 四维张量
#permute 后需要接 contiguous 保证数据在内存的连续性
a1 = a.permute(0,3,2,1).contiguous() # 按各维度的索引进行排列
print("a1 shape : {} ".format(a1.shape))
a1 shape : torch.Size([4, 32, 32, 3])
  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值