pytorch基本操作

1.索引与切片

a = torch.rand(4, 3, 2, 2)

print(a.shape, a[0].shape, a[0, 0].shape, a[0, 0, 0].shape, a[0, 2, 1, 0])
print(a[:2].shape, a[:2, :1].shape, a[:2, :1, :1].shape)
print(a.index_select(0, torch.tensor([0, 1, 2])).shape
      )  # 选择第一个维度的下标为0、1和2。第二个参数必须为tensor
print(a[0, ...].shape, a[:, 1, ...].shape)

要想获得tensor某个位置的值,直接输出a[0,2,1,0]
要想获得在某一维度下的,可以输出a[0],a[0,0],a[0,0,0]
获取第一个维度下标0到1的切片,输出a[:2],类似于python的切片
a[0,…]等价于a[0]等价于a[0, :, :, :]
在这里插入图片描述

2.维度变换

1.view和reshape

a = torch.rand(4,3,2,2)
print(a)
print(a.view(4,3*2*2))
print(a.view(4*3*2,2))

a原来是四维,每个维度的值分别为4,3,2,2
第一个view,将其输出为4行12列的tensor
第二个view,将其输出为24行2列的tensor
view和reshape可以交换使用

2. unsqueeze

a = torch.rand(4,3,2,2)
print(a.unsqueeze(0).shape)
print(a.unsqueeze(-1).shape)
print(a.unsqueeze(-5).shape)
print(a.unsqueeze(1).shape)
print(a.unsqueeze(4).shape)  # 对于此tensor,参数的取值为[-5,4]

a = torch.tensor([1.2,2.3])  # shape 是 torch.Size([2])
print(a.unsqueeze(-1))
print(a.unsqueeze(0))

unsqueeze是扩展维度,参数值即是在第几个位置扩展
在这里插入图片描述
对于a = torch.tensor([1.2,2.3]) 来说,其shape是torch.Size([2])
经过以下两个操作,输出的shape分别变为torch.Size([2,1]),torch.Size([1,2]),实现了从一维到二维的维度变换
在这里插入图片描述
3.squeeze

b = torch.rand(1,32,1,1)
print(b.squeeze(0).shape)  # 只能squeeze某一维度值为1的,否则不变化。此处参数可以是0,2,3,-1,-2,-4
print(b.squeeze(-1).shape)
print(b.squeeze(1).shape)
print(b.squeeze().shape)

注意squeeze压缩维度,只有对应维度的值为1时才能压缩,否则不变化
在这里插入图片描述
4.expand (在某一/某几个维度上复制到几)

b = torch.rand(1, 3, 1, 1)
print(b.shape)
print(b.expand(2, 3, 2, 2).shape)  # 3不能扩张,只能扩张某维度长度为1的
print(b.expand(2, -1, -1, -1).shape)  # -1代表不变化

在这里插入图片描述
5.repeat (在某一/某几个维度上复制几次)

b = torch.rand(1, 3, 2, 2)
print(b.shape)
print(b.repeat(4, 1, 1, 1).shape)  # 在第一维度上复制4次

在这里插入图片描述
6.转置(交换维度,适用二维)

# .t 转置只适用于二维矩阵
a = torch.rand(3, 4)
print(a.t().shape)

7.transpose(交换维度)

a = torch.rand(1, 3, 2, 2)
print(a)
print(a.transpose(1, 3))  # 将1号维度和3号维度交换
# print(a.transpose(1, 3).view(1,3*2*2).view(1,3,2,2))  # 因为不连续报错
print(a.transpose(1, 3).contiguous().view(1, 2, 2, 3).transpose(1,3))  # 经过此操作又变回a

8.permute(给出下标重排维度)

# permute
a = torch.rand(1, 3, 2, 2)
print(a.permute(0, 2, 3, 1).shape)

3.合并与分割

1. cat和stack (合并)

# cat 合并
a = torch.rand(4, 8, 8)
b = torch.rand(5, 8, 8)
c = torch.rand(3, 8, 8)
d = torch.cat([a, b, c], dim=0)  # 第一个参数指明哪个需要合并,第二个参数指明在哪个维度上合并,其他维度的长度必须相等
print(d.shape)
# stack 创建新的维度
a = torch.rand(4, 8, 8)
b = torch.rand(4, 8, 8)
c = torch.rand(4, 8, 8)
d = torch.stack([a, b, c], dim=1)  # 每个shape值必须完全一致
print(d.shape)

cat需要某一个维度的值不相同,以此实现在这一维度上合并
stack需要shape完全一致,在某一个位置上扩展维度
在这里插入图片描述
2.split和chunk

# split
c = torch.rand(5, 8, 8)
a, b = c.split([2, 3], dim=0)  # 在0维度上拆分成,一个0维度的值是1,一个是3
print(a.shape, b.shape)
a, b = c.split(3, dim=0)  # 在0维度上按3个3个拆分
print(a.shape, b.shape)
# chunk
c = torch.rand(4, 8, 8)
a, b, d, e = c.chunk(4, dim=0)  # 在0维度上拆分为4份
print(a.shape, b.shape)

4.数学运算

1.加减乘除

a = torch.randint(0, 10, [3, 4])
b = torch.randint(0, 10, [4])

# 加法,根据boardcasting,b会变成维度3,4的tensor
print(a + b)
print(torch.add(a, b))
# 减法
print(a - b)
print(torch.sub(a, b))
# 乘法 (对应位置的值相乘,不是矩阵乘)
print(a * b)
print(torch.mul(a, b))
# 除法 (对应位置的值相除)
print(a / b)
print(torch.div(a, b))

2.矩阵乘法与高维度乘法

# mm,matmul
a = torch.full([2, 2], 3.)
b = torch.ones(2, 2)
print(torch.mm(a, b))  # 只适用于维度为2
print(torch.matmul(a, b))  # 可以适用与各种维度
print(a @ b)  # 与matmul一致

3.平方,开方,e,ln,取整数取小数

# power
a = torch.randint(0, 10, [3, 3])
print(a)
print(a.pow(2))  # 对每个位置上的值作平方
print(a**2)  # 对每个位置上的值作平方
print(a.sqrt())  # 对每个位置上的值作开方
print(a.rsqrt())  # 对每个位置上的值作开方,取他们的倒数

# e的次方,ln,向下取整,向上取整,取整数部分,取小数部分,四舍五入
a = torch.exp(torch.ones(3, 3))  # a的每个值为e^1
b = a.log()  # b的每个值为 ln(a的每个值)
# 对a的每个值分别作:
# 向下取整,向上取整,取整数部分,取小数部分,四舍五入
print(a.floor(), a.ceil(), a.trunc(), a.frac(), a.round())

4.clamp (限制取值)

# clamp
grad = torch.rand(2, 3)*15
print(grad.clamp(10))  # 小于10的值全部变为10
print(grad.clamp(5, 10))  # 小于5的值变为5,大于10的值变为10

5.统计属性

1.norm

# norm
a = torch.rand(8)
b = a.reshape(2, 4)
c = a.reshape(2, 2, 2)
print(a, '\n', b, '\n', c)
print(a.norm(1))  # 绝对值之和
print(a.norm(2))  # 平方之和的开方
print(b.norm(1, dim=1))  # 对于1维度,即每行绝对值之和
print(b.norm(2, dim=1))  # 每行平方之和的开方
print(c.norm(1, dim=0))  # 对于0维度,求绝对值之和
print(c.norm(1, dim=1))
print(c.norm(1, dim=2))

2.mean,sum,min,max,prod,argmin,argmax

# mean,sum,min,max,prod
a = torch.randint(0, 10, [2, 4]).float()

print(a.mean(), a.sum(), a.min(), a.max(), a.prod(),)  # 均值,总和,最小,最大,累乘,
print(a.argmin(), a.argmax())  # 求最小,最大值的索引,不论维度多少,都看作顺序关系,只返回一个值

3.dim,keepdim

a = torch.randint(0, 10, [2, 4]).float()
print(a.max(dim=0))
print(a.max(dim=1))
print(a.max(dim=0, keepdim=True))
print(a.max(dim=1, keepdim=True))

可以看出,keepdim会使丢失的括号出现,
一般2行4列,求0维度上max的会显示shape是(4),但是keepdim会显示(1,4)
求1维度上max会显示(2),但是keepdim会显示(2,1)
在这里插入图片描述
4.比较

a = torch.randint(0, 10, [2, 4]).float()
print(a > 0)  # 判断a,值大于0,对应位置为True(此处比较运算符改为其他都可以,比如<,==,>=)
print(torch.gt(a,0))  # 判断a,值大于0,对应位置为True
b = torch.randint(0, 10, [2, 4]).float()
print(torch.eq(a,b))  # 在每个位置上显示a,b对应的值是否相等
print(torch.equal(a,b))  # 直接返回两个tensor是否相等
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值