有如下代码块
self.gan_mode = gan_mode
if gan_mode == 'lsgan':
self.loss = nn.MSELoss()
elif gan_mode == 'vanilla':
self.loss = nn.BCEWithLogitsLoss()
elif gan_mode in ['wgangp']:
self.loss = None
vanilla 即为 BCEWithLogitsLoss
参见:https://blog.csdn.net/qq_22210253/article/details/85222093
lsgan 在pytorch中为 nn.MSELoss(),均方误差
公式:
其中y为target,x为模型输出值
示例:
import torch
import torch.nn as nn
output = torch.rand(2,2)
print(output)
tensor([[0.1234, 0.8351],
[0.9274, 0.8286]])
target = torch.FloatTensor([[0,1],[1,0]])
print(target)
tensor([[0., 1.],
[1., 0.]])
利用nn.MSELoss()计算损失:
crit = nn.MSELoss()
cost = crit(input,target)
输出结果为:
利用公式手工验证计算:(注意除的数为2n,即为4)
MSELoss_handle = ((0-0.1234)*(0-0.1234) + (1-0.8351)*(1-0.8351))/4 + ((1-0.9274)*(1-0.9274) + (0-0.8286)*(0-0.8286))/4
输出结果为:
四舍五入之后结果一致!
wgangp 参见
https://blog.csdn.net/weixin_37993251/article/details/87120269