# Pytorch 网络串联中loss.backward的相关问题

5 篇文章 0 订阅

## 情况一

End-to-End Pseudo-LiDAR for Image-Based 3D Object Detection 中的 CoR（Change of Representation）模块。

import torch
from torch import nn

# Case 1: two chained networks (net1 -> net2) trained jointly.
# The two losses are summed and backward() is called ONCE, so gradients from
# loss2 flow through net2 and back into net1 (tmp is not detached), and the
# autograd graph is built and freed exactly once per iteration.
x = torch.ones(2, 3) * 0.2
net1 = nn.Linear(3, 3)
net2 = nn.Linear(3, 3)

tgt1 = torch.ones(2, 3) * 0.5  # target for net1's output
tgt2 = torch.ones(2, 3)        # target for net2's output
loss_fun = torch.nn.MSELoss()

# BUG FIX: opt1/opt2 were used below but never defined in the original snippet.
opt1 = torch.optim.SGD(net1.parameters(), lr=0.1)
opt2 = torch.optim.SGD(net2.parameters(), lr=0.1)

for i in range(1000):
    # BUG FIX: gradients must be cleared every iteration; otherwise they
    # accumulate across epochs and the updates are wrong.
    opt1.zero_grad()
    opt2.zero_grad()

    tmp = net1(x)
    loss1 = loss_fun(tmp, tgt1)

    output = net2(tmp)
    loss2 = loss_fun(output, tgt2)

    # A single backward over the summed loss is the recommended pattern.
    tol_loss = loss2 + loss1
    tol_loss.backward()

    opt1.step()
    opt2.step()

    if i % 100 == 0:  # log sparsely instead of flooding stdout every epoch
        print(f'EPOCH:{i}, loss1={loss1.item():.6f}, loss2={loss2.item():.6f}')

output1 = net1(x)
output2 = net2(output1)

print(f'output1:\n {output1}')
print(f'output2\n {output2}')


## 情况二

import torch
from torch import nn

# Case 2: same chained setup as case 1, but backward() is called separately
# on loss1 and loss2. Because loss1 and loss2 share the sub-graph through
# net1 (tmp is not detached), the first backward would free that graph and
# the second backward would raise:
#   "Trying to backward through the graph a second time, but the buffers
#    have already been freed. Specify retain_graph=True ..."
x = torch.ones(2, 3) * 0.2
net1 = nn.Linear(3, 3)
net2 = nn.Linear(3, 3)

tgt1 = torch.ones(2, 3) * 0.5  # target for net1's output
tgt2 = torch.ones(2, 3)        # target for net2's output
loss_fun = torch.nn.MSELoss()

# BUG FIX: opt1/opt2 were used below but never defined in the original snippet.
opt1 = torch.optim.SGD(net1.parameters(), lr=0.1)
opt2 = torch.optim.SGD(net2.parameters(), lr=0.1)

for i in range(1000):
    # BUG FIX: clear accumulated gradients every iteration.
    opt1.zero_grad()
    opt2.zero_grad()

    tmp = net1(x)
    loss1 = loss_fun(tmp, tgt1)

    output = net2(tmp)
    loss2 = loss_fun(output, tgt2)

    # BUG FIX (the article's own recommendation): keep the shared graph alive
    # for the second backward. Gradients from both losses accumulate in net1,
    # which is equivalent to backward() on (loss1 + loss2).
    loss1.backward(retain_graph=True)
    loss2.backward()

    opt1.step()
    opt2.step()

    if i % 100 == 0:  # log sparsely instead of flooding stdout every epoch
        print(f'EPOCH:{i}, loss1={loss1.item():.6f}, loss2={loss2.item():.6f}')

output1 = net1(x)
output2 = net2(output1)

print(f'output1:\n {output1}')
print(f'output2\n {output2}')


Trying to backward through the graph a second time, but the buffers have already been freed.
Specify retain_graph=True when calling backward the first time.


loss1.backward()改为loss1.backward(retain_graph=True)，将计算图保留下来，不被清除掉就好了。但是，还是建议使用tol_loss.backward()进行反向梯度的计算。

## 情况三

import torch
from torch import nn

# Case 3: like case 1, but the intermediate tensor is scaled (tmp * 0.5)
# before feeding net2. The scaling is a differentiable op, so gradients from
# loss2 still flow back through it into net1; a single backward over the
# summed loss handles the whole graph.
x = torch.ones(2, 3) * 0.2
net1 = nn.Linear(3, 3)
net2 = nn.Linear(3, 3)

tgt1 = torch.ones(2, 3) * 0.5  # target for net1's output
tgt2 = torch.ones(2, 3)        # target for net2's output
loss_fun = torch.nn.MSELoss()

# BUG FIX: opt1/opt2 were used below but never defined in the original snippet.
opt1 = torch.optim.SGD(net1.parameters(), lr=0.1)
opt2 = torch.optim.SGD(net2.parameters(), lr=0.1)

for i in range(3000):
    # BUG FIX: clear accumulated gradients every iteration.
    opt1.zero_grad()
    opt2.zero_grad()

    tmp = net1(x)
    loss1 = loss_fun(tmp, tgt1)

    tmp = tmp * 0.5  # differentiable scaling; gradients still reach net1

    output = net2(tmp)
    loss2 = loss_fun(output, tgt2)

    tol_loss = loss2 + loss1
    tol_loss.backward()

    opt1.step()
    opt2.step()

    if i % 100 == 0:  # log sparsely instead of flooding stdout every epoch
        print(f'EPOCH:{i}, loss1={loss1.item():.6f}, loss2={loss2.item():.6f}')

output1 = net1(x)
output1 = output1 * 0.5
output2 = net2(output1)

print(f'output1:\n {output1}')
print(f'output2\n {output2}')


## 情况四

import torch
from torch import nn

# Case 4: the intermediate tensor is detached (round-tripped through NumPy),
# which CUTS the autograd graph between net1 and net2. loss1 only trains
# net1; loss2 only trains net2. Because the two graphs are now disjoint,
# calling backward() on each loss separately is valid — no retain_graph
# needed (unlike case 2).
x = torch.ones(2, 3) * 0.2
net1 = nn.Linear(3, 3)
net2 = nn.Linear(3, 3)

tgt1 = torch.ones(2, 3) * 0.5  # target for net1's output
tgt2 = torch.ones(2, 3)        # target for net2's output
loss_fun = torch.nn.MSELoss()

# BUG FIX: opt1/opt2 were used below but never defined in the original snippet.
opt1 = torch.optim.SGD(net1.parameters(), lr=0.1)
opt2 = torch.optim.SGD(net2.parameters(), lr=0.1)

for i in range(3000):
    # BUG FIX: clear accumulated gradients every iteration.
    opt1.zero_grad()
    opt2.zero_grad()

    tmp = net1(x)
    loss1 = loss_fun(tmp, tgt1)

    # The NumPy round trip deliberately severs the graph (same effect as
    # tmp.detach() * 0.5, kept here to match the article's demonstration).
    tmp = tmp.detach().numpy()
    tmp = tmp * 0.5
    tmp = torch.from_numpy(tmp)

    output = net2(tmp)
    loss2 = loss_fun(output, tgt2)

    loss1.backward()  # trains net1 only
    loss2.backward()  # trains net2 only; separate graph, so this is safe
    opt1.step()
    opt2.step()

    if i % 100 == 0:  # log sparsely instead of flooding stdout every epoch
        print(f'EPOCH:{i}, loss1={loss1.item():.6f}, loss2={loss2.item():.6f}')

output1 = net1(x)
output1 = output1 * 0.5
output2 = net2(output1)

print(f'output1:\n {output1}')
print(f'output2\n {output2}')


import torch
from torch import nn

# Case 4b: same detached setup as case 4, but ONLY loss2.backward() is
# called. Since the graph is cut at the detach, net1 never receives any
# gradients and its parameters never change — demonstrated by raw_output1
# (computed before training) equalling output1 (computed after training).
x = torch.ones(2, 3) * 0.2
net1 = nn.Linear(3, 3)
raw_output1 = net1(x)
raw_output1 = raw_output1 * 0.5
net2 = nn.Linear(3, 3)

tgt1 = torch.ones(2, 3) * 0.5  # target for net1's output (never optimized here)
tgt2 = torch.ones(2, 3)        # target for net2's output
loss_fun = torch.nn.MSELoss()

# BUG FIX: opt1/opt2 were used below but never defined in the original snippet.
opt1 = torch.optim.SGD(net1.parameters(), lr=0.1)
opt2 = torch.optim.SGD(net2.parameters(), lr=0.1)

for i in range(3000):
    # BUG FIX: clear accumulated gradients every iteration.
    opt1.zero_grad()
    opt2.zero_grad()

    tmp = net1(x)
    loss1 = loss_fun(tmp, tgt1)

    # NumPy round trip severs the graph between net1 and net2.
    tmp = tmp.detach().numpy()
    tmp = tmp * 0.5
    tmp = torch.from_numpy(tmp)

    output = net2(tmp)
    loss2 = loss_fun(output, tgt2)

    loss2.backward()  # only net2 gets gradients
    opt1.step()       # no-op: net1's grads are None, SGD skips them
    opt2.step()

    if i % 100 == 0:  # log sparsely instead of flooding stdout every epoch
        print(f'EPOCH:{i}, loss1={loss1.item():.6f}, loss2={loss2.item():.6f}')

output1 = net1(x)
output1 = output1 * 0.5
output2 = net2(output1)

# BUG FIX: the original printed `row_output1`, a NameError — the variable is
# named raw_output1.
print(f'raw_output1:\n {raw_output1}')
print(f'output1:\n {output1}')
print(f'output2\n {output2}')


raw_output1:
tensor([[ 0.0029, -0.1812, -0.0505],
output1:
tensor([[ 0.0029, -0.1812, -0.0505],
output2
tensor([[1.0000, 1.0000, 1.0000],


## 总结

• 30
点赞
• 83
收藏
觉得还不错? 一键收藏
• 打赏
• 0
评论
10-10 1万+
12-02 211
03-30 1万+
03-14 623
11-08 5141
08-15 1036
05-28 2940
12-20 5223

¥1 ¥2 ¥4 ¥6 ¥10 ¥20

1.余额是钱包充值的虚拟货币，按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载，可以购买VIP、付费专栏及课程。