为什么把所有不同作用的损失加起来,可以对模型进行有效更新呢?而不会引起各个部分之间的干扰呢?
这是因为虽然不同的损失加起来作为一个标量对模型直接更新,但是在torch的计算图中记录了loss中不同loss的来源,根据不同其来源可以有效的对对应模型部件进行更新。下面举个例子进行介绍。
设
h
t
,
r
t
,
t
t
ht,rt,tt
ht,rt,tt分别是文本模态对应的知识三元组;
h
i
,
r
i
,
t
i
hi,ri,ti
hi,ri,ti分别是图像模态对应的知识三元组,目标是学习两组三元组的向量表示,基于平移规则(TransE),并使用MSELoss,先对其进行联合学习,所以损失如下:
L
o
s
s
=
F
.
m
s
e
l
o
s
s
(
h
t
+
r
t
,
t
t
)
+
F
.
m
s
e
l
o
s
s
(
h
i
+
r
i
,
t
i
)
Loss=F.mse_loss(ht+rt,tt)+F.mse_loss(hi+ri,ti)
Loss=F.mseloss(ht+rt,tt)+F.mseloss(hi+ri,ti)
torch.manual_seed(1)
ht,rt,tt = nn.Parameter(torch.randn(2,10)),nn.Parameter(torch.randn(2,10)),nn.Parameter(torch.randn(2,10)) # [batch,dim]
hi,ri,ti = nn.Parameter(torch.randn(2,10)),nn.Parameter(torch.randn(2,10)),nn.Parameter(torch.randn(2,10))
loss = F.mse_loss(ht+rt,tt)+F.mse_loss(hi+ri,ti)
loss.backward()
print(hi.grad)
print(ht.grad)
以上是将两个任务的损失加起来一起进行更新,并打印对应的梯度。
torch.manual_seed(1)
ht,rt,tt = nn.Parameter(torch.randn(2,10)),nn.Parameter(torch.randn(2,10)),nn.Parameter(torch.randn(2,10)) # [batch,dim]
hi,ri,ti = nn.Parameter(torch.randn(2,10)),nn.Parameter(torch.randn(2,10)),nn.Parameter(torch.randn(2,10))
loss = F.mse_loss(ht+rt,tt)
loss.backward()
print(hi.grad)
print(ht.grad)
以上是只计算文本模态的损失,打印发现 h t ht ht和之前损失加和时的梯度是一样的,但是 h i hi hi的梯度为空。
torch.manual_seed(1)
ht,rt,tt = nn.Parameter(torch.randn(2,10)),nn.Parameter(torch.randn(2,10)),nn.Parameter(torch.randn(2,10)) # [batch,dim]
hi,ri,ti = nn.Parameter(torch.randn(2,10)),nn.Parameter(torch.randn(2,10)),nn.Parameter(torch.randn(2,10))
loss = F.mse_loss(hi+ri,ti)
loss.backward()
print(hi.grad)
print(ht.grad)
以上是只计算图像模态的损失,打印发现
h
i
hi
hi和之前损失加和时的梯度是一样的,但是
h
t
ht
ht的梯度为空。
以上结果说损失加和后并不会导致各个损失对不同的部件造成混乱更新,仍只会更新相应的模型部件。这是通过torch中的计算图实现,加和后的损失可以根据计算图进行“溯源”。