训练正常&异常的GAN损失函数loss变化应该是怎么样的

这里以个人用到的一个网络为例,仅供参考,不代表所有情形。
用tensorboard记录loss曲线的走向,横轴为迭代次数(iter),纵轴为损失(loss)值。


正常的

判别器loss
在这里插入图片描述
loss在2.6到3.4之间来回上下波动。注意,在tensorboard中,最好将Smoothing值调整为0,如果使用了Smoothing将比较难观察到loss的波动趋势,这里以Smoothing为0.999为例,还是展示判别器损失:
在这里插入图片描述
这个时候容易误判为判别器loss在稳步下降。不过实际上可以发现这个"下降"从3.25降到了3.05,所以本质上还是小范围波动。
生成器loss
在这里插入图片描述
loss在3.0到3.6之间来回上下波动。类似的,观察Smoothing为0.999时的情况:
在这里插入图片描述
给人的感觉是生成器loss一直在往上跑,可能会怀疑生成器没在进一步学习;但实际上仔细看容易发现只是从3.25跑到了3.35,依旧是在一个小区间内波动,生成器的能力确实是在继续逐步增强的。


生成器崩溃

判别器loss
在这里插入图片描述
可以发现从3一路下降到了0.8,在训练一开始就有在快速下降。观察Smoothing为0.999时的情况:
在这里插入图片描述
生成器loss
在这里插入图片描述
可以发现从4一路上升到了6.5,在训练一开始就有在快速上升。观察Smoothing为0.999时的情况:
在这里插入图片描述
从结果的角度出发,在生成器loss快速上升的时候,其已经开始生成无意义的噪声图像:
在这里插入图片描述


小结

从原理上来说,生成器和判别器从一开始都是非常弱的,因此一般不会在训练一开始两者损失就非常剧烈的波动。在训练一段时间达到稳定期后,生成器和判别器的损失都应该在一个小区间内波动,而不会有明显的持续上升/下降趋势。
如果生成器损失持续明显上升,表明其无法学习怎么欺骗判别器,体现在结果上就是开始生成噪声。
如果判别器损失持续明显上升,表示其无法学习怎么识别生成器,体现在结果上就是生成器可能会生成一致的,无意义的但是能欺骗判别器的图像(比如直接输出训练集中的样本)。

好的,以下是使用Pytorch实现Focal Loss损失函数的代码: ```python import torch import torch.nn as nn import torch.nn.functional as F class FocalLoss(nn.Module): def __init__(self, gamma=2, alpha=None, size_average=True): super(FocalLoss, self).__init__() self.gamma = gamma self.alpha = alpha if isinstance(alpha, (float, int)): self.alpha = torch.Tensor([alpha, 1 - alpha]) if isinstance(alpha, list): self.alpha = torch.Tensor(alpha) self.size_average = size_average def forward(self, input, target): if input.dim() > 2: input = input.view(input.size(0), input.size(1), -1) # N,C,H,W => N,C,H*W input = input.transpose(1, 2) # N,C,H*W => N,H*W,C input = input.contiguous().view(-1, input.size(2)) # N,H*W,C => N*H*W,C target = target.view(-1, 1) logpt = F.log_softmax(input) logpt = logpt.gather(1, target) logpt = logpt.view(-1) pt = logpt.data.exp() if self.alpha is not None: if self.alpha.type() != input.data.type(): self.alpha = self.alpha.type_as(input.data) at = self.alpha.gather(0, target.data.view(-1)) logpt = logpt * at loss = -1 * (1 - pt) ** self.gamma * logpt if self.size_average: return loss.mean() else: return loss.sum() ``` 其中,`gamma`是超参数,控制着难易样本的权重,一般默认为2;`alpha`是类别权重系数,可以为None、float、int或list类型;`size_average`控制是否对每个样本的损失求平均,默认为True。 使用时,只需在训练代码中调用该损失函数即可: ```python loss_fn = FocalLoss(gamma=2, alpha=[0.25, 0.75]) optimizer = torch.optim.Adam(model.parameters(), lr=0.001) for epoch in range(num_epochs): for images, labels in train_loader: images = images.to(device) labels = labels.to(device) outputs = model(images) loss = loss_fn(outputs, labels) optimizer.zero_grad() loss.backward() optimizer.step() ```
评论 16
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值