Pytorch 张量变形、转置及维度调整

张量变形

reshape()

张量变形较为常用的为reshape()函数,其可以无视对应张量在内存上是否连续进行张量变形,参数为形状,注意变形后的张量不会自动覆盖原本张量,需要重新赋值

import torch
import numpy as np

data1 = torch.randint(0, 10, [2, 3])
print(data1)
data1 = data1.reshape(3, 2)
print(data1)
# tensor([[9, 1, 2],
#         [0, 4, 3]])
# tensor([[9, 1],
#         [2, 0],
#         [4, 3]])
view()

 view()同样可以用于张量变形,但是要求张量连续,具体演示请见下文 ,其余参数设置与reshape相同。

import torch
import numpy as np

data1 = torch.randint(0, 10, [2, 3])
print(data1)
data1 = data1.view(3, 2)
print(data1)
# tensor([[0, 0, 7],
#         [2, 8, 5]])
# tensor([[0, 0],
#         [7, 2],
#         [8, 5]])

张量转置

transpose()

transpose本身具有转置的意思,transpose()函数通常可以用于两个维度的转置,参数为需要进行转置的两个维度

import torch
import numpy as np

data1 = torch.randint(0, 10, [2, 3, 2])
print(data1)
data1 = data1.transpose(1, 2)
print(data1)
# tensor([[[6, 3],
#          [1, 6],
#          [1, 7]],
# 
#         [[1, 8],
#          [9, 6],
#          [3, 4]]])
# tensor([[[6, 1, 1],
#          [3, 6, 7]],
# 
#         [[1, 9, 3],
#          [8, 6, 4]]])
permute()

permute具有排列、转置的意思,但是permute可以一次进行多个维度的转置,

import torch
import numpy as np

data1 = torch.randint(0, 10, [2, 3, 4])
print(data1)
data1 = data1.permute(2, 1, 0)
print(data1)
# tensor([[[1, 0, 5, 2],
#          [8, 5, 4, 0],
#          [0, 4, 1, 4]],
# 
#         [[4, 4, 1, 9],
#          [3, 4, 6, 8],
#          [1, 2, 1, 8]]])
# tensor([[[1, 4],
#          [8, 3],
#          [0, 1]],
# 
#         [[0, 4],
#          [5, 4],
#          [4, 2]],
# 
#         [[5, 1],
#          [4, 6],
#          [1, 1]],
# 
#         [[2, 9],
#          [0, 8],
#          [4, 8]]])

注意,permute和transpose都使张量在内存上特定排列的快照,实际上并不会改变内存上的排列顺序,所以这里就不能直接通过view()直接变形转置后的张量

张量连续判断is_contiguous和contiguous()
import torch
import numpy as np

data1 = torch.randint(0, 10, [2, 3, 4])
data1 = data1.permute(2, 1, 0)
print(data1.view(4, 2, 3))
# Traceback (most recent call last):
#   File "D:\Pythonproject\teach_day_01\demo02.py", line 6, in <module>
#     print(data1.view(4,2,3))
#           ^^^^^^^^^^^^^^^^^
# RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

可以看到这里的张量在转置后直接使用view()进行变形的报错,这里说的not compatible with input tensor's size and stride,compatible是可兼容,stride是步长(这里指从一个元素移动到下一个元素所需要跳过的字节数)

 contiguous()/is_contiguous()

contiguous()函数可以将内存不连续排列的的张量进行连续化(注意这里进行连续化也需要手动赋值,函数本身不会直接覆盖原张量),而is_contiguous()函数则是用于判断张量在内存上是否连续的,若连续则返回True,否则返回False。

import torch
import numpy as np

data1 = torch.randint(0, 10, [2, 3, 4])
print("转置前是否连续:", data1.is_contiguous())
data1 = data1.permute(2, 1, 0)
print("转置后是否连续:", data1.is_contiguous())
data1 = data1.contiguous()
print(data1.is_contiguous())
print(data1.view(4, 2, 3))
# 转置前是否连续: True
# 转置后是否连续: False
# True
# tensor([[[9, 5, 4],
#          [7, 0, 0]],
# 
#         [[4, 6, 1],
#          [0, 3, 8]],
# 
#         [[8, 5, 0],
#          [1, 6, 9]],
# 
#         [[4, 0, 6],
#          [9, 8, 4]]])

可以看到这里在进行了contiguous()连续化后就可以成功进行view()变形了

张量升维与降维

张量升维与降维一般是通过squeeze()与unsqueeze()函数进行的,squeeze进行降维,unsqueeze用于升维

squeeze()

squeeze()函数在不设置参数的情况下会自动压缩掉值为1的维度,参数可以设置为想要进行压缩的维度(注意要压缩的维度值只能是1),

import torch
import numpy as np

data1 = torch.randint(0, 10, [1, 2, 3, 4, 1])
print(data1.shape)
print(data1.size())
data1 = data1.squeeze()
print(data1.shape)
# torch.Size([1, 2, 3, 4, 1])
# torch.Size([1, 2, 3, 4, 1])
# torch.Size([2, 3, 4])

上面为不设置参数的情况,进行了自动压缩

import torch
import numpy as np

data1 = torch.randint(0, 10, [1, 2, 3, 4, 1])
print(data1.shape)
print(data1.size())
data1 = data1.squeeze(0)
print(data1.shape)
# torch.Size([1, 2, 3, 4, 1])
# torch.Size([1, 2, 3, 4, 1])
# torch.Size([2, 3, 4, 1])

这里设置了参数为0,则压缩了第一个维度

unsqueeze()

unsqueeze()函数会在指定维度添加一个维度,值为1

import torch
import numpy as np

data1 = torch.randint(0, 10, [1, 2, 3, 4, 1])
print(data1.shape)
print(data1.size())
data1 = data1.unsqueeze(0)
print(data1.shape)
# torch.Size([1, 2, 3, 4, 1])
# torch.Size([1, 2, 3, 4, 1])
# torch.Size([1, 1, 2, 3, 4, 1])

注意,在进行张量形状输出的时候.shape和.size()的作用是相同的

  • 7
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值