PyTorch入门4——张量变形

在搭建神经网络模型时,数据都是基于张量形式的表示,网络层与层之间很多都是以不同的 shape 即形状(比如:3行4列、2片5行9列等等)的方式进行表现和运算,因此对张量形状的变换非常普遍,这一操作将会改变张量维度或者改变某维度上数据个数。

reshape 函数

reshape 函数可以在保证张量数据不变的前提下改变数据的维度,将其转换成指定的形状,在神经网络中经常使用该函数来调节数据的形状,以适配不同网络层之间的数据传递。

import torch

data = torch.randint(0, 10, [3, 4, 5])
print(data)
print(data.shape)      # torch.Size([3, 4, 5])
new_data = torch.reshape(data, [5, 12])  # new_data = data.reshape(5, 12) 亦可
print(new_data)
print(new_data.shape)  # torch.Size([5, 12])

transpose 和 permute 函数

transpose 函数可以实现交换张量形状的指定维度,例如:一个张量的形状为 (2, 3, 4) 可以通过 transpose 函数把 3 和 4 进行交换,将张量的形状变为 (2, 4, 3)。permute 函数则可以一次交换更多的维度。

import torch

data = torch.randint(0, 10, [3, 4, 5])
print(data)
print(data.size())      # torch.Size([3, 4, 5])

new_data = torch.transpose(data, 1, 2)  # new_data = data.transpose(1, 2) 亦可
print(new_data.size())  # torch.Size([3, 5, 4])

new_data = torch.transpose(data, 0, 1)
new_data = torch.transpose(new_data, 1, 2)
print(new_data.size())  # torch.Size([4, 5, 3])

new_data = torch.permute(data, [1, 2, 0])  # new_data = data.permute(1, 2, 0) 亦可
# 等价于 new_data = data.transpose(0,1).transpose(1,2)
print(new_data.size())  # torch.Size([4, 5, 3])

view 和 contigous 函数

view 函数也可以用于修改张量的形状,但是其用法比较局限,只能用于存储在整块内存中的张量。在 PyTorch 中,有些张量是由不同的数据块组成的,它们并没有存储在整块的内存中,view 函数无法对这样的张量进行变形处理,例如:一个张量经过了 transpose 或者 permute 函数的处理之后,就无法使用 view 函数进行形状操作,这时候可以先使用 contiguous 函数转换为整块内存的张量,再使用 view 函数。

import torch

data = torch.tensor([[10, 20, 30], [40, 50, 60]])
print(data.size())     # torch.Size([2, 3])

new_data = data.view(3, 2)
print(new_data.shape)  # torch.Size([3, 2])

print(data.is_contiguous())  # True  判断张量是否使用整块内存

new_data = torch.transpose(data, 0, 1)
print(new_data.is_contiguous())  # False
new_data = new_data.view(2, 3)   # RuntimeError!

# 使用 contiguous 函数转换为整块内存的张量,再用 view 函数变形
print(new_data.contiguous().is_contiguous())  # True
new_data = new_data.contiguous().view(2, 3)
print(new_data.shape)  # torch.Size([2, 3])

squeeze 和 unsqueeze 函数

squeeze 函数用于张量删除 shape 为 1 的维度,unsqueeze 在指定维度添加 1,以增加张量的维度。注意:unsqueeze 函数也可用张量特殊切片来增加维度!比如:data[:, None] 等价于 data.unsqueeze(1),data[:, :, None] 等价于 data.unsqueeze(2),data[None, :, None] 等价于 data.unsqueeze(0).unsqueeze(2)

import torch

data = torch.randint(0, 10, [1, 3, 1, 5])
print(data)
print(data.size())  # torch.Size([1, 3, 1, 5])

new_data = data.squeeze()  # 去掉值为1的维度
print(new_data.size())     # torch.Size([3, 5])

new_data = data.squeeze(2)  # 去掉指定位置为1的维度,注意: 如果指定位置不是1则不删除!
print(new_data.size())      # torch.Size([1, 3, 5])

new_data = data.unsqueeze(-1)  # 在最后维度增加一个维度
# 等价于 new_data = data1[:, :, :, :, None]
print(new_data.size())         # torch.Size([1, 3, 1, 5, 1])

flatten 函数

flatten 函数用于压缩张量的维度,与 squeeze 不同的是,它不要求被压缩的维度必须是1。

import torch

data = torch.randint(0, 5, [3,4,5])  # 三维张量
print(data)

fd = torch.flatten(data)  # fd = data.flatten() 亦可,张量压缩成一维
# 等价于 fd = data.view(-1)
print(fd.shape)  # torch.Size([60])

fd = torch.flatten(data, start_dim=1, end_dim=-1)  # 保持原张量第0维(即3)压缩其他维度
# 等价于 fd = data.view(3, -1)
print(fd.shape)  # torch.Size([3, 20])

fd = torch.flatten(t, start_dim=0, end_dim=1)  # 保持原张量第2维(即5)压缩其他维度
# 等价于 fd = data.view(-1, 5)
print(fd.shape)  # torch.Size([12, 5])

以上。

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值