pytorch中常见的维度操作

1、view ;reshape;Flatten:维度合并和分解
2、squeeze;unsqueeze:压缩维度和增加维度(相对于维度为1的数据)
3、transpose;t;permute:维度顺顺序变换(转置)
4、expand;repeat:维度扩展

import torch

'''
维度变换
1、view ;reshape;Flatten:维度合并和分解
2、squeeze;unsqueeze:压缩维度和增加维度(相对于维度为1的数据)
3、transpose;t;permute:维度顺顺序变换(转置)
4、expand;repeat:维度扩展
'''
a = torch.rand(4, 1, 32, 32)

'''
view()的原理很简单,其实就是把原先tensor中的数据进行排列,排成一行,然后根据所给的view()中的参数从一行中按顺序选择组成最终的tensor。
view()可以有多个参数,这取决于你想要得到的是几维的tensor,一般设置两个参数,也是神经网络中常用的(一般在全连接之前),代表二维。
view(h,w),h代表行(想要变为几行),当不知道要变为几行,但知道要变为几列时可取-1;w代表的是列(想要变为几列),当不知道要变为几列,但知道要变为几行时可取-1。
'''


def zqb_view():
    print(a.shape)  # torch.Size([4, 1, 32, 32])
    a1 = a.view(4, 32 * 32)
    print(a1.shape)  # torch.Size([4, 1024])
    a2 = a1.view(4, 1, 32, 32)
    print(a2.shape)  # torch.Size([4, 1, 32, 32])

    # a3 = a1.view(4,28,28) #RuntimeError: shape '[4, 28, 28]' is invalid for input of size 4096
    #     要保持输出数据与输入数据总量,防止数据污染
    # a4 = a1.view(4,32,32,1) # 逻辑错误,改变了原来数据的存储方式,虽然不会报错,但是数据已经被污染,无法正常使用

    a5 = a.view(-1, 32 * 32)  # torch.Size([4, 1024])  -1表示该维度保持不变
    print(a5.shape)


'''
reshape()可以由torch.reshape(),也可由torch.Tensor.reshape()调用,
其作用是在不改变tensor元素数目的情况下改变tensor的shape。
torch.reshape() 需要两个参数,一个是待被改变的张量tensor,一个是想要改变的形状。
'''


def zqb_reshape():
    print(a.shape)  # torch.Size([4, 1, 32, 32])
    a1 = a.reshape(-1, 32 * 32)  # torch.Size([4, 1024])
    print(a1.shape)
    a2 = a1.reshape(4, 1, 32, 32)  # torch.Size([4, 1, 32, 32])
    print(a2.shape)


'''
torch.nn.Flatten(start_dim=1,end_dim=-1)
start_dim与end_dim分别表示开始的维度和终止的维度,默认值为1和-1,其中1表示第一维度,-1表示最后的维度。结合起来看意思就是从第一维度到最后一个维度全部给展平为张量。
(注意:数据的维度是从0开始的,也就是存在第0维度,第一维度并不是真正意义上的第一个)。
因为其被用在神经网络中,输入为一批数据,第 0 维为batch(输入数据的个数),通常要把一个数据拉成一维,而不是将一批数据拉为一维。所以torch.nn.Flatten()默认从第一维开始平坦化。

'''


def zqb_Flatten():
    print(a.shape)  # torch.Size([4, 1, 32, 32])
    a1 = a.flatten(1, -1)  # torch.Size([4, 1024]) ,
    print(a1.shape)
    a2 = a.flatten(2, -1)  # torch.Size([4, 1, 1024])
    print(a2.shape)


'''
unsqueeze(dim=idx)
表示插入的维度占据输出数据的维度,比如
[4, 1, 32, 32].unsqueeze(0),表示新插入维度占据输出数据的0维度torch.Size([1, 4, 1, 32, 32])
[4, 1, 32, 32].unsqueeze(-1),表示新插入维度占据输出数据的-1维度torch.Size([1, 4, 1, 32, 32,1])
'''


def zqb_unsqueeze():
    print(a.shape)  # torch.Size([4, 1, 32, 32])
    a1 = a.unsqueeze(0)
    print(a1.shape)  # torch.Size([1, 4, 1, 32, 32])
    a2 = a.unsqueeze(-1)
    print(a2.shape)  # torch.Size([4, 1, 32, 32, 1])


'''
squeeze()不给参数,表示删除所有1的维度
squeeze(index) 给参数,删除指定index维度,若刚该维度不为1则不做处理
'''


def zqb_squeeze():
    print(a.shape)  # torch.Size([4, 1, 32, 32])
    a1 = a.squeeze()
    print(a1.shape)  # torch.Size([4, 32, 32])
    a2 = a.squeeze(1)
    print(a2.shape)  # torch.Size([4, 32, 32])
    a3 = a.squeeze(0)
    print(a3.shape)  # torch.Size([4, 1, 32, 32])


'''
转置t操作仅仅针对2维数据操作
'''


def zqb_t():
    a = torch.rand(2, 3)
    print(a.shape)  # torch.Size([2, 3])
    a1 = a.t()
    print(a1.shape)  # torch.Size([3, 2])


'''
transpose(dim0,dim1)指定需要调换的两个维度,与顺序无关
对3维及以上的进行操作,输入需要调换的两个维度
'''


def zqb_transpose():
    print(a.shape)  # torch.Size([4, 1, 32, 32])对应的维度0,1,2,3
    a1 = a.transpose(0, 1)
    print(a1.shape)  # torch.Size([1, 4, 32, 32])
    a2 = a.transpose(1, 0)
    print(a2.shape)  # torch.Size([1, 4, 32, 32])


'''
permute(*dims),指定新维度的顺序
'''


def zqb_permute():
    print(a.shape)  # torch.Size([4, 1, 32, 32])对应的维度0,1,2,3
    a1 = a.permute(1, 0, 2, 3)
    print(a1.shape)  # torch.Size([1, 4, 32, 32])


'''
只能对维度值为1的维度进行扩展,无需扩展的维度,维度值不变,
对应位置可写上原始维度大小或直接写作-1;且扩展的Tensor不会分配新的内存,
只是原来的基础上创建新的视图并返回,返回的张量内存是不连续的。
类似于numpy中的broadcast_to函数的作用。如果希望张量内存连续,可以调用contiguous函数。

'''


def zqb_expand():
    print(a.shape)  # torch.Size([4, 1, 32, 32])
    a1 = a.expand(-1, 4, -1, -1)
    print(a1.shape)  # torch.Size([4, 4, 32, 32])


'''
repeat参数*sizes指定了原始张量在各维度上复制的次数。
整个原始张量作为一个整体进行复制,这与Numpy中的repeat函数截然不同,
而更接近于tile函数的效果。
'''


def zqb_repeat():
    print(a.shape)  # torch.Size([4, 1, 32, 32])
    a1 = a.repeat(1, 4, 1, 1)  # torch.Size([4, 4, 32, 32])
    print(a1.shape)
    a2 = a.repeat(4, 1, 1, 1)
    print(a2.shape)  # torch.Size([16, 1, 32, 32])


if __name__ == '__main__':
    zqb_view()
    zqb_reshape()
    zqb_Flatten()
    zqb_unsqueeze()
    zqb_squeeze()
    zqb_t()
    zqb_transpose()
    zqb_permute()
    zqb_expand()
    zqb_repeat()

  • 9
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
PyTorch维度用`dim`表示,可以用来指定在哪个维度上进行操作。下面是一些常见PyTorch维度操作: 1. `torch.unsqueeze(input, dim)`:在指定维度增加一个维度,返回一个新的张量。例如: ```python import torch x = torch.tensor([1, 2, 3]) # 一维张量 x = torch.unsqueeze(x, 0) # 在第0维增加一个维度 print(x) # 输出:tensor([[1, 2, 3]]) ``` 2. `torch.squeeze(input, dim)`:在指定维度上去掉一个维度,返回一个新的张量。例如: ```python import torch x = torch.tensor([[1, 2, 3]]) # 二维张量 x = torch.squeeze(x, 0) # 去掉第0维 print(x) # 输出:tensor([1, 2, 3]) ``` 3. `torch.transpose(input, dim0, dim1)`:交换两个维度的位置,返回一个新的张量。例如: ```python import torch x = torch.tensor([[1, 2, 3], [4, 5, 6]]) # 二维张量 x = torch.transpose(x, 0, 1) # 交换第0维和第1维 print(x) # 输出:tensor([[1, 4], # [2, 5], # [3, 6]]) ``` 4. `torch.cat(inputs, dim)`:在指定维度上将多个张量拼接起来,返回一个新的张量。例如: ```python import torch x1 = torch.tensor([[1, 2, 3]]) x2 = torch.tensor([[4, 5, 6]]) x = torch.cat((x1, x2), dim=0) # 在第0维上拼接 print(x) # 输出:tensor([[1, 2, 3], # [4, 5, 6]]) ``` 5. `torch.stack(inputs, dim)`:在指定维度上将多个张量堆叠起来,返回一个新的张量。例如: ```python import torch x1 = torch.tensor([1, 2, 3]) x2 = torch.tensor([4, 5, 6]) x = torch.stack((x1, x2), dim=0) # 在第0维上堆叠 print(x) # 输出:tensor([[1, 2, 3], # [4, 5, 6]]) ```

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值