深度学习初探/02-Pytorch知识/07-基本数学运算

深度学习初探/02-Pytorch知识/07-基本数学运算

一、基本四则运算

有“直接使用运算符”和“调用函数”2种方法,效果完全一致
“ + ”、“ - ”、“ * ”、“ / ”分别对应"add"、“sub”、“mul”、“div”

1、四则运算的broadcast自动对齐

在基本四则运算中,会默认使用broadcast进行自动对齐,如下:

a = torch.rand([3, 4])
b = torch.rand([4])
c = a + b   # 或 c = torch.add(a, b)
print(c)

Out: tensor([[0.8953, 0.9858, 0.5930, 0.9340],
      	     [1.0404, 1.1417, 1.2256, 1.0447],
             [0.7720, 1.5505, 0.8553, 0.9860]])
二、矩阵相乘

三种方式:
(1)torch.mm( ) ⇒ \Rightarrow 只适用于2d的矩阵,不推荐使用
(2)torch.matmul( ) 推荐使用
(3)@ ⇒ \Rightarrow numpy中的@符号,相当于matmul的重载

a = torch.tensor([[2., 2.],
                  [2., 2.]])
b = torch.ones(2, 2)

print(torch.mm(a, b))
print(torch.matmul(a, b))
print(a @ b)

Out: tensor([[4., 4.],
             [4., 4.]])
	 tensor([[4., 4.],
     	     [4., 4.]])
	 tensor([[4., 4.],
             [4., 4.]])

对于大于2d的tensor,mm无法使用,而matmul和@是只取最后2个dimension进行矩阵乘法运算:

a = torch.rand(4, 3, 28, 64)
b = torch.rand(4, 3, 64, 32)
c = torch.matmul(a, b)
d = a @ b
print(c.shape)
print(d.shape)

Out: torch.Size([4, 3, 28, 32])
	 torch.Size([4, 3, 28, 32])
三、指数运算
1、a ** b

a为底数,b为指数

a = torch.full([2, 2], 3.)
print(a)
print(a ** 2)
print(a ** 3)
print(a ** 0.5)

Out:
tensor([[3., 3.],
        [3., 3.]])
tensor([[9., 9.],
        [9., 9.]])
tensor([[27., 27.],
        [27., 27.]])
tensor([[1.7321, 1.7321],
        [1.7321, 1.7321]])
2、sqrt( ):开根运算
a = torch.full([2, 2], 3.)
c = a ** 2
print(c.sqrt())

Out:tensor([[3., 3.],
            [3., 3.]])
3、rsqrt( ):平方根的倒数
a = torch.full([2, 2], 3.)
c = a ** 2
print(c.rsqrt())

Out: tensor([[0.3333, 0.3333],
             [0.3333, 0.3333]])
4、pow( 指数 )
a = torch.full([2, 2], 3.)
print(a)
print(a.pow(2))
print(a.pow(3))

Out:
tensor([[3., 3.],
        [3., 3.]])
tensor([[9., 9.],
        [9., 9.]])
tensor([[27., 27.],
        [27., 27.]])
四、exp & log
1、exp

以e为底,以原矩阵数据为指数

a = torch.full([3, 3], 2.)
e = torch.exp(a)
print(e)

Out: tensor([[7.3891, 7.3891, 7.3891],
             [7.3891, 7.3891, 7.3891],
             [7.3891, 7.3891, 7.3891]])
2、log

默认以e为底:

a = torch.full([3, 3], 2.)
e = torch.exp(a)
print(torch.log(e))

Out: tensor([[2., 2., 2.],
             [2., 2., 2.],
             [2., 2., 2.]])

如果想以x为底,就用logx:

a = torch.full([3, 3], 4.)
print(torch.log2(a))

Out: tensor([[2., 2., 2.],
             [2., 2., 2.],
             [2., 2., 2.]])
五、近似值
1、floor 向下取整,ceil 向上取整
a = torch.tensor(3.14)
print(a.floor(), a.ceil())

Out: tensor(3.) tensor(4.)
2、trunc 取整数部分,frac 取小数部分
a = torch.tensor(3.14)
print(a.trunc(), a.frac())

Out: tensor(3.) tensor(0.1400)
3、round 四舍五入
a = torch.tensor(3.14)
b = torch.tensor(3.5)
print(torch.round(a), torch.round(b))

Out: tensor(3.) tensor(4.)
六、裁剪clamp

通常用于梯度裁剪,以解决梯度爆炸问题

grad = torch.rand(2, 3) * 15   # 假设given一个比较爆炸的梯度

print(grad) # 打印当前梯度信息
print(grad.clamp(10)) # 设置梯度下限进行裁剪(小于10的都变为10)
print(grad.clamp(0, 10)) # 设置梯度上下限进行裁剪

Out: 
tensor([[ 5.7563,  8.6068,  0.1151],
        [ 7.5975, 13.7174, 13.1848]])
tensor([[10.0000, 10.0000, 10.0000],
        [10.0000, 13.7174, 13.1848]])
tensor([[ 5.7563,  8.6068,  0.1151],
        [ 7.5975, 10.0000, 10.0000]])

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值