深度学习初探/02-Pytorch知识/07-基本数学运算
一、基本四则运算
有“直接使用运算符”和“调用函数”2种方法,效果完全一致
“ + ”、“ - ”、“ * ”、“ / ”分别对应"add"、“sub”、“mul”、“div”
1、四则运算的broadcast自动对齐
在基本四则运算中,会默认使用broadcast进行自动对齐,如下:
a = torch.rand([3, 4])
b = torch.rand([4])
c = a + b # 或 c = torch.add(a, b)
print(c)
Out: tensor([[0.8953, 0.9858, 0.5930, 0.9340],
[1.0404, 1.1417, 1.2256, 1.0447],
[0.7720, 1.5505, 0.8553, 0.9860]])
二、矩阵相乘
三种方式:
(1)torch.mm( ) ⇒ \Rightarrow ⇒ 只适用于2d的矩阵,不推荐使用
(2)torch.matmul( ) 推荐使用
(3)@ ⇒ \Rightarrow ⇒numpy中的@符号,相当于matmul的重载
a = torch.tensor([[2., 2.],
[2., 2.]])
b = torch.ones(2, 2)
print(torch.mm(a, b))
print(torch.matmul(a, b))
print(a @ b)
Out: tensor([[4., 4.],
[4., 4.]])
tensor([[4., 4.],
[4., 4.]])
tensor([[4., 4.],
[4., 4.]])
对于大于2d的tensor,mm无法使用,而matmul和@是只取最后2个dimension进行矩阵乘法运算:
a = torch.rand(4, 3, 28, 64)
b = torch.rand(4, 3, 64, 32)
c = torch.matmul(a, b)
d = a @ b
print(c.shape)
print(d.shape)
Out: torch.Size([4, 3, 28, 32])
torch.Size([4, 3, 28, 32])
三、指数运算
1、a ** b
a为底数,b为指数
a = torch.full([2, 2], 3.)
print(a)
print(a ** 2)
print(a ** 3)
print(a ** 0.5)
Out:
tensor([[3., 3.],
[3., 3.]])
tensor([[9., 9.],
[9., 9.]])
tensor([[27., 27.],
[27., 27.]])
tensor([[1.7321, 1.7321],
[1.7321, 1.7321]])
2、sqrt( ):开根运算
a = torch.full([2, 2], 3.)
c = a ** 2
print(c.sqrt())
Out:tensor([[3., 3.],
[3., 3.]])
3、rsqrt( ):平方根的倒数
a = torch.full([2, 2], 3.)
c = a ** 2
print(c.rsqrt())
Out: tensor([[0.3333, 0.3333],
[0.3333, 0.3333]])
4、pow( 指数 )
a = torch.full([2, 2], 3.)
print(a)
print(a.pow(2))
print(a.pow(3))
Out:
tensor([[3., 3.],
[3., 3.]])
tensor([[9., 9.],
[9., 9.]])
tensor([[27., 27.],
[27., 27.]])
四、exp & log
1、exp
以e为底,以原矩阵数据为指数
a = torch.full([3, 3], 2.)
e = torch.exp(a)
print(e)
Out: tensor([[7.3891, 7.3891, 7.3891],
[7.3891, 7.3891, 7.3891],
[7.3891, 7.3891, 7.3891]])
2、log
默认以e为底:
a = torch.full([3, 3], 2.)
e = torch.exp(a)
print(torch.log(e))
Out: tensor([[2., 2., 2.],
[2., 2., 2.],
[2., 2., 2.]])
如果想以x为底,就用logx:
a = torch.full([3, 3], 4.)
print(torch.log2(a))
Out: tensor([[2., 2., 2.],
[2., 2., 2.],
[2., 2., 2.]])
五、近似值
1、floor 向下取整,ceil 向上取整
a = torch.tensor(3.14)
print(a.floor(), a.ceil())
Out: tensor(3.) tensor(4.)
2、trunc 取整数部分,frac 取小数部分
a = torch.tensor(3.14)
print(a.trunc(), a.frac())
Out: tensor(3.) tensor(0.1400)
3、round 四舍五入
a = torch.tensor(3.14)
b = torch.tensor(3.5)
print(torch.round(a), torch.round(b))
Out: tensor(3.) tensor(4.)
六、裁剪clamp
通常用于梯度裁剪,以解决梯度爆炸问题
grad = torch.rand(2, 3) * 15 # 假设given一个比较爆炸的梯度
print(grad) # 打印当前梯度信息
print(grad.clamp(10)) # 设置梯度下限进行裁剪(小于10的都变为10)
print(grad.clamp(0, 10)) # 设置梯度上下限进行裁剪
Out:
tensor([[ 5.7563, 8.6068, 0.1151],
[ 7.5975, 13.7174, 13.1848]])
tensor([[10.0000, 10.0000, 10.0000],
[10.0000, 13.7174, 13.1848]])
tensor([[ 5.7563, 8.6068, 0.1151],
[ 7.5975, 10.0000, 10.0000]])