torch计算图直观体验

为什么把所有不同作用的损失加起来,可以对模型进行有效更新呢?而不会引起各个部分之间的干扰呢?
这是因为虽然不同的损失加起来作为一个标量对模型直接更新,但是在torch的计算图中记录了loss中不同loss的来源,根据不同其来源可以有效的对对应模型部件进行更新。下面举个例子进行介绍。

h t , r t , t t ht,rt,tt ht,rt,tt分别是文本模态对应的知识三元组; h i , r i , t i hi,ri,ti hi,ri,ti分别是图像模态对应的知识三元组,目标是学习两组三元组的向量表示,基于平移规则(TransE),并使用MSELoss,先对其进行联合学习,所以损失如下:
L o s s = F . m s e l o s s ( h t + r t , t t ) + F . m s e l o s s ( h i + r i , t i ) Loss=F.mse_loss(ht+rt,tt)+F.mse_loss(hi+ri,ti) Loss=F.mseloss(ht+rt,tt)+F.mseloss(hi+ri,ti)

torch.manual_seed(1)
ht,rt,tt = nn.Parameter(torch.randn(2,10)),nn.Parameter(torch.randn(2,10)),nn.Parameter(torch.randn(2,10))  #  [batch,dim]
hi,ri,ti = nn.Parameter(torch.randn(2,10)),nn.Parameter(torch.randn(2,10)),nn.Parameter(torch.randn(2,10))
loss = F.mse_loss(ht+rt,tt)+F.mse_loss(hi+ri,ti)
loss.backward()

print(hi.grad)
print(ht.grad)

以上是将两个任务的损失加起来一起进行更新,并打印对应的梯度。

torch.manual_seed(1)
ht,rt,tt = nn.Parameter(torch.randn(2,10)),nn.Parameter(torch.randn(2,10)),nn.Parameter(torch.randn(2,10))  #  [batch,dim]
hi,ri,ti = nn.Parameter(torch.randn(2,10)),nn.Parameter(torch.randn(2,10)),nn.Parameter(torch.randn(2,10))
loss = F.mse_loss(ht+rt,tt)
loss.backward()

print(hi.grad)
print(ht.grad)

以上是只计算文本模态的损失,打印发现 h t ht ht和之前损失加和时的梯度是一样的,但是 h i hi hi的梯度为空。

torch.manual_seed(1)
ht,rt,tt = nn.Parameter(torch.randn(2,10)),nn.Parameter(torch.randn(2,10)),nn.Parameter(torch.randn(2,10))  #  [batch,dim]
hi,ri,ti = nn.Parameter(torch.randn(2,10)),nn.Parameter(torch.randn(2,10)),nn.Parameter(torch.randn(2,10))
loss = F.mse_loss(hi+ri,ti)
loss.backward()

print(hi.grad)
print(ht.grad)

以上是只计算图像模态的损失,打印发现 h i hi hi和之前损失加和时的梯度是一样的,但是 h t ht ht的梯度为空。
以上结果说损失加和后并不会导致各个损失对不同的部件造成混乱更新,仍只会更新相应的模型部件。这是通过torch中的计算图实现,加和后的损失可以根据计算图进行“溯源”。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值