Pytorch Tensor 维度变换学习记录

1. 改变shape

torch.reshape()、torch.view()可以调整Tensor的shape,返回一个新shape的Tensor,torch.view()是老版本的实现,torch.reshape()是最新的实现,两者在功能上是一样的。

示例代码:

import torch
 
a = torch.rand(4, 1, 28, 28)
print(a.shape)
print(a.view(4 * 1, 28, 28).shape)
print(a.reshape(4 * 1, 28, 28).shape)
print(a.reshape(4, 1 * 28 * 28).shape)

输出结果:

torch.Size([4, 1, 28, 28])
torch.Size([4, 28, 28])
torch.Size([4, 28, 28])
torch.Size([4, 784])

注意:维度变换的时候要注意实际意义。

2. 增加维度

torch.unsqueeze(index)可以为Tensor增加一个维度,增加的这一个维度的位置由我们自己定义,新增加的这一个维度不会改变数据本身,只是为数据新增加了一个组别,这个组别是什么由我们自己定义。

比如定义了一个Tensor:

a = torch.randn(4, 1, 28, 28)

        这个Tensor有4个维度,我们可以在现有维度的基础上插入一个新的维度,插入维度的index在[-a.dim()-1, a.dim()+1]范围内,并且当index>=0,则在index前面插入这个新增加的维度;当index < 0,则在index后面插入这个新增的维度。

示例代码:

print(a.shape)
print(a.unsqueeze(0).shape)
print(a.unsqueeze(-1).shape)
print(a.unsqueeze(3).shape)
print(a.unsqueeze(4).shape)
print(a.unsqueeze(-4).shape)
print(a.unsqueeze(-5).shape)
print(a.unsqueeze(5).shape)

输出结果:

torch.Size([4, 1, 28, 28])
torch.Size([1, 4, 1, 28, 28])
torch.Size([4, 1, 28, 28, 1])
torch.Size([4, 1, 28, 1, 28])
torch.Size([4, 1, 28, 28, 1])
torch.Size([4, 1, 1, 28, 28])
torch.Size([1, 4, 1, 28, 28])
Traceback (most recent call last):
  File "/home/lhy/workspace/mmdetection/my_code/pytorch_ws/tensotr_0.py", line 218, in <module>
    print(a.unsqueeze(5).shape)
IndexError: Dimension out of range (expected to be in range of [-5, 4], but got 5)

在执行a.unsqueeze(5)时报错,是因为超出了index的范围。

3. 删减维度

删减维度实际上是一个维度挤压的过程,直观地看是把那些多余的[]给去掉,也就是只是去删除那些size=1的维度。

import torch
 
a = torch.Tensor(1, 4, 1, 9)
print(a.shape)
print(a.squeeze().shape) # 删除所有的size=1的维度
print(a.squeeze(0).shape) # 删除0号维度,ok
print(a.squeeze(2).shape) # 删除2号维度,ok
print(a.squeeze(3).shape) # 删除3号维度,但是3号维度是9不是1,删除失败

 输出结果:

torch.Size([1, 4, 1, 9])
torch.Size([4, 9])
torch.Size([4, 1, 9])
torch.Size([1, 4, 9])
torch.Size([1, 4, 1, 9])

4. 维度扩展

expand就是在某个size=1的维度上改变size,改成更大的一个大小,实际就是在每个size=1的维度上的标量的广播操作。

实际用例:

我们有一个shape=[4, 32, 14,14]的Tensor data,相当于4张图片,每张图片32个通道,每个通道行为14,列为14的图像数据,需要将每个通道上的所有像素增加一个偏置bias。图像数据的channel=32,因此bias = torch.rand(32),但是还不能完成data + bias的操作,因为两者dim与shape不一致。为了使得dim一致,需要增加bias的维度到4维,这就用到了unsqueeze()函数;为了使得shape一致,需要bias的4个维度的shape=[4, 32, 14, 14],这就用到了维度扩展expand。

代码实现:

import torch
 
bias = torch.rand(32)
data = torch.rand(4, 32, 14, 14)
 
# 想要把bias加到data上面去
# 先进行维度增加
bias = bias.unsqueeze(1).unsqueeze(2).unsqueeze(0)
print(bias.shape)
 
# 再进行维度扩展
bias = bias.expand(4, -1, 14, 14)  # -1表示这个维度保持不变,这里写32也可以
print(bias.shape)
 
data + bias

输出结果:

torch.Size([1, 32, 1, 1])
torch.Size([4, 32, 14, 14])

5. 维度重复

repeat就是将每个位置的维度都重复至指定的次数,以形成新的Tensor,功能与维度扩展一样,但是repeat会重新申请内存空间,repeat()参数表示各个维度指定的重复次数。

代码示例:

import torch
b = torch.Tensor(1, 32, 1, 1)
print(b.shape)
# 维度重复,32这里不想进行重复,所以就相当于"重复至1次"
b = b.repeat(4, 1, 14, 14)
print(b.shape)

6. 转置操作

Pytorch的转置操作只适用于dim=2的Tensor,也就是矩阵。

示例代码:

c = torch.Tensor(2, 4)
print(c.t().shape)

输出结果:

torch.Size([4, 2])

7. 维度变换

(1) transpose(dim1, dim2)交换dim1与dim2

注意这种交换使得存储不再连续,再执行一些reshape的操作会报错,所以要调用一下contiguous()使其变成连续的维度。

示例代码:

d = torch.Tensor(6, 3, 1, 2)
# 1号维度和3号维度交换
print(d.transpose(1, 3).contiguous().shape)  

输出结果:

torch.Size([6, 2, 1, 3])

下面这个例子比较一下每个位置上的元素都是一致的,来验证一下这个交换->压缩shape->展开shape->交换回去是没有问题的。

e = torch.rand(4, 3, 6, 7)
e2 = e.transpose(1, 3).contiguous().reshape(4, 7 * 6 * 3).reshape(4, 7, 6, 3).transpose(1, 3)
print(e2.shape)
# 比较下两个Tensor所有位置上的元素是否都相等,返回1表示相等。
print(torch.all(torch.eq(e, e2)))

输出结果:

torch.Size([4, 3, 6, 7])
tensor(1, dtype=torch.uint8)

(2) permute

如果四个维度表示上节的[batch,channel,h,w] ,如果想把channel放到最后去,形成[batch,h,w,channel],那么如果使用前面的维度交换,至少要交换两次(先13交换再12交换)。而使用permute可以直接指定维度新的所处位置,更加方便。

示例代码:

a = torch.rand(4, 3, 6, 7)
print(a.permute(0, 2, 3, 1).shape)

输出结果:

torch.Size([4, 6, 7, 3])

参考博客:Pytorch Tensor维度变换_torch tensor.shape-CSDN博客

  • 26
    点赞
  • 21
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
PyTorch中,可以使用多种方法对张量的维度进行变换。其中常用的方法有view()和squeeze()。view()方法可以用来改变张量的形状,而squeeze()方法可以去除维度中值为1的尺寸。举例如下: 1. 使用view()方法进行维度变换。view()方法会返回一个改变了形状的新张量,但张量中的元素数量必须保持不变。可以使用该方法实现维度的展平、增加或减少维度等操作。 2. 使用squeeze()方法进行维度变换。squeeze()方法可以去除张量中维度中值为1的尺寸,并返回一个新张量。可以指定具体的维度进行去除,也可以不指定维度,即默认去除所有为1的维度。 例如,假设有一个张量a的形状为(2, 1, 2, 1, 3),使用squeeze()方法可以去除其中值为1的维度,得到一个形状为(2, 2, 3)的新张量。 在PyTorch中,还可以使用unsqueeze()方法对张量进行维度扩展。unsqueeze()方法会在指定的维度增加一个尺寸为1的维度,并返回一个新张量。 总结起来,PyTorch中的维度变换包括view()、squeeze()和unsqueeze()等方法,可以根据具体需求选择合适的方法进行张量的维度变换。<span class="em">1</span><span class="em">2</span><span class="em">3</span> #### 引用[.reference_title] - *1* [PyTorchTensor维度变换实现](https://download.csdn.net/download/weixin_38698174/13988496)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_2"}}] [.reference_item style="max-width: 50%"] - *2* *3* [Pytorch 基础之维度变化](https://blog.csdn.net/zxhandroid/article/details/129192950)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_2"}}] [.reference_item style="max-width: 50%"] [ .reference_list ]

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

随机惯性粒子群

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值