今天看论文看到了联合损失函数的概念,查了一下,pytorch下可以通过重写损失函数来达到联合损失函数的效果。
自定义损失函数:
- 继承
nn.Module
类 - 重写
forward()
方法
class MyLoss(nn.Module):
def forward(self,output,target):
loss1 = ...
loss2 = ...
loss = (loss1 + loss2) / 2 # 计算平均值
return loss