PyTorch教程(六):算数运算

import torch

a = torch.rand(3,4)
# tensor([[0.0422, 0.9152, 0.5486, 0.8753],
#        [0.3918, 0.6056, 0.0634, 0.9498],
#        [0.4444, 0.2209, 0.8165, 0.1533]])
b = torch.rand(4)
# tensor([0.4145, 0.2913, 0.1655, 0.3705])


a + b
# tensor([[0.4567, 1.2065, 0.7141, 1.2458],
#        [0.8064, 0.8969, 0.2289, 1.3203],
#        [0.8589, 0.5122, 0.9819, 0.5238]])
torch.add(a,b)
# tensor([[0.4567, 1.2065, 0.7141, 1.2458],
#        [0.8064, 0.8969, 0.2289, 1.3203],
#        [0.8589, 0.5122, 0.9819, 0.5238]])
torch.all(torch.eq(a+b,torch.add(a,b))) # 判断是不是每一个位置对应的值都形同
tensor(True)

torch.all(torch.eq(a-b,torch.sub(a,b)))
tensor(True)

torch.all(torch.eq(a*b,torch.mul(a,b)))
tensor(True)

torch.all(torch.eq(a/b,torch.div(a,b)))
tensor(True)

矩阵相乘matmul

注意: * 表示element-wise,相同位置的相乘,
而matmul表示矩阵相乘

a = torch.full([2,2],3)
# tensor([[3., 3.],
#        [3., 3.]])
b = torch.ones(2,2)
# tensor([[1., 1.],
#        [1., 1.]])
torch.mm(a,b) # 矩阵相乘
#tensor([[6., 6.],
#        [6., 6.]])
torch.matmul(a,b) # 矩阵相乘
#tensor([[6., 6.],
#        [6., 6.]])
a@b # 矩阵相乘
#tensor([[6., 6.],
#        [6., 6.]])

上面三种方式表示矩阵相乘,torch.mm只能支持2D的矩阵相乘,一般使用torch.matmul或者a@b来表示。

a = torch.rand(4,3,26,64)
b = torch.rand(4,3,64,32)
torch.matmul(a,b).shape  # 4维矩阵相乘
torch.Size([4, 3, 26, 32])

b = torch.rand(4,1,64,32) # 自动进行Broadcasting

对于上面的4维矩阵相乘,前面的两维保持不变,后面的两维进行相乘。

指数与平方根

a = torch.full([2,2],3)
# tensor([[3, 3],
#        [3, 3]])
a.pow(2) # 平方
# tensor([[9, 9],
#        [9, 9]])
b = a ** 2 # 平方
# tensor([[9, 9],
#        [9, 9]])

b ** (0.5) # 开根号
# tensor([[3., 3.],
#        [3., 3.]])

torch.sqrt(b.to(torch.double)) # 开根号
# tensor([[3., 3.],
#        [3., 3.]], dtype=torch.float64)

torch.rsqrt(b.to(torch.double)) # 开根号之后求倒数
# tensor([[0.3333, 0.3333],
#        [0.3333, 0.3333]], dtype=torch.float64)

这里求平方根的时候为什么不能直接使用b.sqrt()?如果使用b.sqrt()会报如下错误:RuntimeError: sqrt_vml_cpu not implemented for 'Long'。原因是Long类型的数据不支持log对数运算, 为什么Tensor是Long类型? 因为创建List数组时默认使用的是int, 所以从List转成torch.Tensor后, 数据类型变成了Long。

以e为底的指数函数

a = torch.exp(torch.ones(2,2)) 
# tensor([[2.7183, 2.7183],
#        [2.7183, 2.7183]])
torch.log(a) # 默认以e为底
# tensor([[1., 1.],
#        [1., 1.]])

floor、ceil、trunc、frac

a.floor(),a.ceil(),a.trunc(),a.frac()
(tensor(3.), tensor(4.), tensor(3.), tensor(0.1400))
  • floor:向下取整
  • ceil:向上取整
  • trunc:裁剪整数部分
  • frac:裁剪小数部分

四舍五入

a = torch.tensor(3.499)
a.round()
# tensor(3.)
a = torch.tensor(3.50)
a.round()
# tensor(4.)

clamp

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值