PyTorch 中 backward() 详解

转自:Pytorch中文网

接触了PyTorch这么长的时间,也玩了很多PyTorch的骚操作,都特别简单直观地实现了,但是有一个网络训练过程中的操作之前一直没有仔细去考虑过,那就是loss.backward(),看到这个大家一定都很熟悉,loss是网络的损失函数,是一个标量,你可能会说这不就是反向传播吗,有什么好讲的。

但是不知道大家思考过没有,如果loss不是一个标量,而是一个向量,那么loss.backward()是什么结果呢?

大家可以去试试,写一个简单的小程序

 

1

2

3

4

5

import torch as t

from torch.autograd import Variable as v

x = v(t.ones(2, 2), requires_grad=True)

y = x   1

y.backward()

运行一下程序,恭喜你报错了,错误显示如下

backwarderror.png

我们来读一读这个错误是什么意思。backward只能被应用在一个标量上,也就是一个一维tensor,或者传入跟变量相关的梯度。

嗯,前面一句话很简单,backward应用在一个标量,平时我们也是这么使用的,但是后面一句话,with gradient w.r.t variable是什么鬼,传入一个变量相关的梯度。不理解啊不理解,看不懂没关系我们还可以做实验来解决这个问题,俗话说自己动手丰衣足食(我也想做个伸手党去看看别人写的,然后不幸地是并没有什么人写过这方面的东西)。

首先我们开始做一个简单的实验,就是复习一下标量的形式

 

1

2

3

4

5

6

7

8

9

10

11

12

13

14

# simple gradient

a = v(t.FloatTensor([2, 3]), requires_grad=True)

b = a   3

c = b * b * 3

out = c.mean()

out.backward()

print(\'*\'*10)

print(\'=====simple gradient======\')

print(\'input\')

print(a.data)

print(\'compute result is\')

print(out.data[0])

print(\'input gradients are\')

print(a.grad.data)

很简单,我们把数学表达式写出来,传入的参数x1=2,x2=3

x1=2,x2=3

,特别注意Variable里面默认的参数requires_grad=False,所以这里我们要重新传入requires_grad=True让它成为一个叶子节点。

那么我们对其求偏导也很简单(分别为15,18) 这样依靠简单的微积分知识我们就能够算出他们的结果,运行一下程序,确保结果一致,ok。

Paste_Image.png

下面我们研究一下如何能够对非标量的情况下使用backward,下面开始做实验(瞎试)。

 

1

2

3

4

m = v(t.FloatTensor([[2, 3]]), requires_grad=True)

n = v(t.zeros(1, 2))

n[0, 0] = m[0, 0] ** 2

n[0, 1] = m[0, 1] ** 3

第一想法就是里面这个参数是要求梯度的对象,我们这样调用n.backward(m.data),有有报错诶,是不是成功了,我真的是个天才,这么难的东西都能想到,等等,我好想看到了一个很神奇的结果。

Paste_Image.png

这是什么鬼,这跟说好的结果不一样啊,我们想要的结果是4和27,现在给我们的结果是8和81,为什么会出现这样神奇的结果呢,想不通啊。我们看看我们传入的参数是m.data,这是一个(2, 3)的向量,我们希望得到的梯度是(4, 27),好像(4×2=8, 27×3=81),我的内心毫无波动,甚至有点想笑,似乎backward将我传入的参数m.data乘上了得到的梯度,既然要乘上我传入的参数,那么我就给你传入1,这样总能得到我想要的结果了吧,n.backward(t.FloatTensor([[1, 1]])),看看结果呢

backwardresult2.png

哇,跟我们想要的结果一样诶,撒花,我们解决了一个大问题,就是这么简单,扔进去一个1就可以了,这个问题也没有那么难嘛,哈哈哈。

似乎又有一点不对,如果这么简单那么写PyTorch的人为什么不把这一步直接集成进去,那我们不就不会遇到这个问题了嘛。

我们来试试另外一种情况

 

1

2

3

4

5

6

m = v(t.FloatTensor([[2, 3]]), requires_grad=True)

j = t.zeros(2 ,2)

k = v(t.zeros(1, 2))

m.grad.data.zero_()

k[0, 0] = m[0, 0] ** 2 + 3 * m[0 ,1]

k[0, 1] = m[0, 1] ** 2 + 2 * m[0, 0]

么我们直接对k反向传播k.backward(t.FloatTensor([[1, 1]]),结果是什么呢?

首先我们手动算一算结果是什么:4,3,2,6,我们是希望能够得到上面四个结果,这个时候你可能已经开始怀疑了,能够得到这4个结果吗?我们可以输出结果来看看

非常遗憾,我们只得到了两个结果,并且数值并不对,这个时候你就会疑惑了,到底是哪里出了问题呢,为什么会得到这样的结果呢?

经过不断地尝试,我终于发现了其中的奥秘,k.backward(parameters)接受的参数parameters必须要和k的大小一模一样,然后作为k的系数传回去,什么意思呢,我们通过上面的例子来解释这个问题你就知道了。

我们已经知道我们得到的k=(k1,k2)

k=(k1,k2)

,以及传入的参数是1和1,那么是如何得到这6和9这两个结果的呢?

我们知道了这个操作具体是怎么完成的,我们就可以求求我们需要的这个jacobian矩阵了,非常简单。

 

1

2

3

4

5

6

7

8

9

10

11

12

13

# jacobian

j = t.zeros(2 ,2)

k = v(t.zeros(1, 2))

m.grad.data.zero_()

k[0, 0] = m[0, 0] ** 2 + 3 * m[0 ,1]

k[0, 1] = m[0, 1] ** 2 + 2 * m[0, 0]

k.backward(t.FloatTensor([[1, 0]]), retain_variables=True)

j[:, 0] = m.grad.data

m.grad.data.zero_()

k.backward(t.FloatTensor([[0, 1]]))

j[:, 1] = m.grad.data

print('jacobian matrix is')

print(j)

我们可以得到如下结果

这里我们要注意backward()里面另外的一个参数retain_variables=True,这个参数默认是False,也就是反向传播之后这个计算图的内存会被释放,这样就没办法进行第二次反向传播了,所以我们需要设置为True,因为这里我们需要进行两次反向传播求得jacobian矩阵。

最后我们再举一个矩阵乘法的例子试验一下我们的结果

 

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

x = t.FloatTensor([2, 1]).view(1, 2)

x = v(x, requires_grad=True)

y = v(t.FloatTensor([[1, 2], [3, 4]]))

 

z = t.mm(x, y)

jacobian = t.zeros((2, 2))

z.backward(t.FloatTensor([[1, 0]]), retain_variables=True)  # dz1/dx1, dz2/dx1

jacobian[:, 0] = x.grad.data

x.grad.data.zero_()

z.backward(t.FloatTensor([[0, 1]]))  # dz1/dx2, dz2/dx2

jacobian[:, 1] = x.grad.data

print('=========jacobian========')

print('x')

print(x.data)

print('y')

print(y.data)

print('compute result')

print(z.data)

print('jacobian matrix is')

print(jacobian)

上面是代码,仔细阅读,作为一个小练习回顾一下本篇文章讲的内容,妈妈再也不用担心我不会用backward了。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值