pytorch常用API(2)

学习内容:pytorch常用API(2)

1、张量的索引

a = torch.Tensor(2,3,32,32)
print(a[:,:,:,:].shape)#全要,一个冒号代表全要
print(a[0:1,:,:,:].shape)#取第一张图像,通道、宽度、高度全要
print(a[:,:,0:32:2,0:32:2].shape)#所有图像,通道全要,宽度高度全要但是每隔两个要一个进行一个下采样,变成了16*16

输出:
torch.Size([2, 3, 32, 32])
torch.Size([1, 3, 32, 32])
torch.Size([2, 3, 16, 16])

#pytorch索引
a = torch.linspace(1,12,steps=12)
a = a.view(3,4)
print(a)
b = torch.index_select(a,0,torch.tensor([0,2]))#0在行维度,取第0行和第2行
print(b)
c = torch.index_select(a,1,torch.tensor([1,3]))#1在列维度,取第1列和第3列
print(c)

输出:
在这里插入图片描述

2、torch.masked_select()

#torch.masked_select()
a = torch.randn(3,3)#随机三行三列
mask = torch.eye(3,3,dtype=torch.bool)
print(a)
print(mask)
c = torch.masked_select(a,mask)
print(c)#模板取

tensor([[ 0.6159, 0.6601, 2.2437],
[ 0.1718, -0.2925, -0.0763],
[-0.4906, 0.7846, -1.4077]])
tensor([[ True, False, False],
[False, True, False],
[False, False, True]])
tensor([ 0.6159, -0.2925, -1.4077])

3、torch.take()

#torch.take()
a = torch.randn(3,3)
b = torch.tensor([0,2,4,6])#0,2,4,6
c = torch.take(a,b)#先将a打平,再按0,2,4,6去索引
print(a)
print(b)
print(c)

tensor([[-0.2457, -0.8254, -0.2993],
[-0.1371, -0.7704, 0.5670],
[ 0.3647, 0.8135, -0.8036]])
tensor([0, 2, 4, 6])
tensor([-0.2457, -0.2993, -0.7704, 0.3647])

4、维度变化

permute(),可以同时换挪多个维度

#permute(),可以同时换挪多个维度
a = torch.rand(4,3,32,32)
b = a.permute(0,3,2,1)#可以同时换挪多个维度
print(a.shape)
print(b.shape)

torch.Size([4, 3, 32, 32])
torch.Size([4, 32, 32, 3])

view() reshape()这两个意思一样,可以变换维度,一般情况用reshape,鲁棒性更强一些

#view() reshape()这两个意思一样,可以变换维度,一般情况用reshape,鲁棒性更强一些
a = torch.rand(4,3,32,32)#四维
b = a.view(4,3,32*32)#变成三维,维度变换维数要相等(乘起来相等)
c = a.view(4,-1)#打成两维,剩余多少,用-1代替,自动会计算
print(a.shape)
print(b.shape)
print(c.shape)
d = a.reshape(4,3,8,4,8,4)#也可以升维
print(d.shape)

torch.Size([4, 3, 32, 32])
torch.Size([4, 3, 1024])
torch.Size([4, 3072])
torch.Size([4, 3, 8, 4, 8, 4])

unsqueeze()扩张

#unsqueeze()扩张
a = torch.rand(4,3,32,32)
b = a.unsqueeze(0)#第0维进行扩张
c = a.unsqueeze(2)#第2维进行扩张
d = a.unsqueeze(4)
e = a.unsqueeze(-1)#负索引进行扩张
print(a.shape)
print(b.shape)
print(c.shape)
print(d.shape)
print(e.shape)

torch.Size([4, 3, 32, 32])
torch.Size([1, 4, 3, 32, 32])
torch.Size([4, 3, 1, 32, 32])
torch.Size([4, 3, 32, 32, 1])
torch.Size([4, 3, 32, 32, 1])

squeeze()压缩,只能压缩一维

#squeeze()压缩,只能压缩一维
a = torch.rand(1,1,32,32)
b = a.squeeze(0)#第0维进行压缩
c = a.squeeze(1)#第1维进行压缩
d = a.squeeze(3)#压不了,因为索引的维度不为1,但不会报错
e = a.squeeze(-1)#负索引
print(a.shape)
print(b.shape)
print(c.shape)
print(d.shape)
print(e.shape)

torch.Size([1, 1, 32, 32])
torch.Size([1, 32, 32])
torch.Size([1, 32, 32])
torch.Size([1, 1, 32, 32])
torch.Size([1, 1, 32, 32])

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值