Pytorch 网络串联中loss.backward的相关问题

Pytorch 网络串联中loss.backward的相关问题

情况一

情况描述:
两个神经网络net1和net2,net1的输出是net2的输入,并且net1和net2的输出都有truth,我们需要对这两个网络进行训练。

应用背景:
End-to-End Pseudo-LiDAR for Image-Based 3D Object Detection中的Cor模块。

相关代码:

import  torch
from torch import nn

x = torch.ones(2, 3)*0.2
net1 = nn.Linear(3, 3)
net2 = nn.Linear(3, 3)

tgt1 = torch.ones(2, 3)*0.5
tgt2 = torch.ones(2, 3)
loss_fun = torch.nn.MSELoss()
opt1 = torch.optim.Adam(net1.parameters(), 0.002)
opt2 = torch.optim.Adam(net2.parameters(), 0.002)

for i in range(1000):
    tmp = net1(x)
    loss1=loss_fun(tmp, tgt1)
    
    output = net2(tmp)
    loss2 = loss_fun(output, tgt2)
    
    tol_loss= loss2+loss1

    opt1.zero_grad()
    opt2.zero_grad()

    tol_loss.backward()
    
    opt1.step()
    opt2.step()

    print(f'EPOCH:{i}, loss={loss1}, loss2={loss2}')
    
output1 = net1(x)
output2 = net2(output1)

print(f'output1:\n {output1}')
print(f'output2\n {output2}')

实验结果:
两个网络都能够完成正常训练,tol_loss.backward()实现了这个目标。

情况二

情况描述:
在情况一的基础上进行思考,如果将tol_loss.backward()改成使用loss1.backward() 和 loss2.backward()能否完成目标呢

相关代码:

import  torch
from torch import nn

x = torch.ones(2, 3)*0.2
net1 = nn.Linear(3, 3)
net2 = nn.Linear(3, 3)

tgt1 = torch.ones(2, 3)*0.5
tgt2 = torch.ones(2, 3)
loss_fun = torch.nn.MSELoss()
opt1 = torch.optim.Adam(net1.parameters(), 0.002)
opt2 = torch.optim.Adam(net2.parameters(), 0.002)

for i in range(1000):
    tmp = net1(x)
    loss1=loss_fun(tmp, tgt1)
    
    output = net2(tmp)
    loss2 = loss_fun(output, tgt2)

    opt1.zero_grad()
    opt2.zero_grad()

    loss1.backward()
    loss2.backward()
    
    opt1.step()
    opt2.step()

    print(f'EPOCH:{i}, loss={loss1}, loss2={loss2}')

output1 = net1(x)
output2 = net2(output1)

print(f'output1:\n {output1}')
print(f'output2\n {output2}')

实验结果: 不能,会出现一个很常见的报错

Trying to backward through the graph a second time, but the buffers have already been freed.
 Specify retain_graph=True when calling backward the first time.

结果分析:
这个就牵涉到pytorch的计算图,pytorch使用的是动态计算图,树状的结构,我们想要优化的参数就是叶子节点,在forward的过程中被建立起来,在loss.backward()后被释放掉。因此,当我们使用loss1.backward()后,net1的计算图就被free掉了,loss2.backward()便无法正常进行。

解决办法:
loss1.backward()改为loss1.backward(retain_graph=True),将计算图保留下来,不被清除掉就好了。但是,还是建议使用tol_loss.backward()进行反向梯度的计算。

注意】:
这从一个方面说明了net1和net2的计算图是建立在一起的,不是分开的!!!记住这句话,对下面的情况三出现的问题才能理解。

情况三

情况描述:
在情况一的基础上,我们对net1的输出需要做一些数据的计。

应用背景: 数据增强等等

相关代码:

import  torch
from torch import nn

x = torch.ones(2, 3)*0.2
net1 = nn.Linear(3, 3)
net2 = nn.Linear(3, 3)

tgt1 = torch.ones(2, 3)*0.5
tgt2 = torch.ones(2, 3)
loss_fun = torch.nn.MSELoss()
opt1 = torch.optim.Adam(net1.parameters(), 0.002)
opt2 = torch.optim.Adam(net2.parameters(), 0.002)

for i in range(3000):
    tmp = net1(x)
    loss1=loss_fun(tmp, tgt1)
    
    tmp=tmp*0.5
    
    output = net2(tmp)
    loss2 = loss_fun(output, tgt2)
    
    tol_loss= loss2+loss1

    opt1.zero_grad()
    opt2.zero_grad()

    tol_loss.backward()
    
    opt1.step()
    opt2.step()

    print(f'EPOCH:{i}, loss={loss1}, loss2={loss2}')

output1 = net1(x)
output1=output1*0.5
output2 = net2(output1)

print(f'output1:\n {output1}')
print(f'output2\n {output2}')

实验结果:
发现能够完成正常的训练,但是提出一个疑问,现在net1和net2的计算图还是联系在一起的吗?

测试:
我们将tol_loss.backward()换成los1.backward()和loss2.backward()

结果:
出现同样的报错,说明此时两者的计算图还是一个。

情况四

情况描述:
如果此时改变tmp的数据类型,从tensor改为np的array类型,进行计算后在改回tensor,又会是怎样的情况?

相关代码:

import  torch
from torch import nn

x = torch.ones(2, 3)*0.2
net1 = nn.Linear(3, 3)
net2 = nn.Linear(3, 3)

tgt1 = torch.ones(2, 3)*0.5
tgt2 = torch.ones(2, 3)
loss_fun = torch.nn.MSELoss()
opt1 = torch.optim.Adam(net1.parameters(), 0.002)
opt2 = torch.optim.Adam(net2.parameters(), 0.002)

for i in range(3000):
    tmp = net1(x)
    loss1=loss_fun(tmp, tgt1)
    
    tmp=tmp.detach().numpy()
    tmp=tmp*0.5
    tmp=torch.from_numpy(tmp)
    tmp.requires_grad_(True)
    
    output = net2(tmp)
    loss2 = loss_fun(output, tgt2)

    opt1.zero_grad()
    opt2.zero_grad()

    loss1.backward()
    loss2.backward()
    opt1.step()
    opt2.step()

   print(f'EPOCH:{i}, loss={loss1}, loss2={loss2}')

output1 = net1(x)
output1=output1*0.5
output2 = net2(output1)

print(f'output1:\n {output1}')
print(f'output2\n {output2}')

实验结果:
居然可以正常运行,这就说明了一个问题,现在这两个net的计算图已经是分开的了,就是说,虽然把tmp.requires_grad设置为了True,但是没有什么用,net2的梯度无法传播到net1了。(其实在这里,不设置tmp.requires_grad为True也不影响net2的更新。)我们可以做个简单的验证。

验证代码:

import  torch
from torch import nn

x = torch.ones(2, 3)*0.2
net1 = nn.Linear(3, 3)
raw_output1=net1(x)
raw_output1=raw_output1*0.5
net2 = nn.Linear(3, 3)

tgt1 = torch.ones(2, 3)*0.5
tgt2 = torch.ones(2, 3)
loss_fun = torch.nn.MSELoss()
opt1 = torch.optim.Adam(net1.parameters(), 0.002)
opt2 = torch.optim.Adam(net2.parameters(), 0.002)

for i in range(3000):
    tmp = net1(x)
    loss1=loss_fun(tmp, tgt1)
    
    tmp=tmp.detach().numpy()
    tmp=tmp*0.5
    tmp=torch.from_numpy(tmp)
    tmp.requires_grad_(True)
    
    output = net2(tmp)
    loss2 = loss_fun(output, tgt2)

    opt1.zero_grad()
    opt2.zero_grad()
    
    loss2.backward()
    opt1.step()
    opt2.step()

    print(f'EPOCH:{i}, loss={loss1}, loss2={loss2}')

output1 = net1(x)
output1=output1*0.5
output2 = net2(output1)

print(f'raw_output1:\n {row_output1}')
print(f'output1:\n {output1}')
print(f'output2\n {output2}')

验证结果:

raw_output1:
 tensor([[ 0.0029, -0.1812, -0.0505],
        [ 0.0029, -0.1812, -0.0505]], grad_fn=<MulBackward0>)
output1:
 tensor([[ 0.0029, -0.1812, -0.0505],
        [ 0.0029, -0.1812, -0.0505]], grad_fn=<MulBackward0>)
output2
 tensor([[1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000]], grad_fn=<AddmmBackward>)

验证结果分析:
发现在没有进行梯度更新前的输出raw_output1和进行loss2.backward()后毫无变化,说明net2没有办法传播到net1中。

总结

网络的串联,在进行forward的所创建的动态计算图是同一个。如果不改变数据的类型,在网络串联之间做一些数据的计算是没有什么问题的。但是如果把tensor张量改成了np.array类型,就不可行了,因为这个时候,前面的计算图会和后面的计算图断开。后面网络梯度无法传播到前面的网络之中。这也就失去了cor的意义。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

一米七八_FZH

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值