Fastai/Pytorch 的 BCEWITHLOGITSLOSS/AdaptiveLoss

本文深入解析FastAI中的鉴别器结构,包括卷积层、dropout层、self_attention机制等关键组件,以及BCEWithLogitsLoss和AdaptiveLoss在GAN训练中的应用。通过实例演示,帮助理解损失函数计算过程。
摘要由CSDN通过智能技术生成

最近在学习一篇有关于fastai的鉴别器知识,整理相关的有意思的可以学习的点。

1、鉴别器结构
def custom_gan_critic(
    n_channels: int = 3, nf: int = 256, n_blocks: int = 3, p: int = 0.15):
    "Critic to train a `GAN`."
    layers = [_conv(n_channels, nf, ks=4, stride=2), nn.Dropout2d(p / 2)]
    for i in range(n_blocks):
        layers += [
            _conv(nf, nf, ks=3, stride=1),
            nn.Dropout2d(p),
            _conv(nf, nf * 2, ks=4, stride=2, self_attention=(i == 0)),
        ]
        nf *= 2
    layers += [
        _conv(nf, nf, ks=3, stride=1),
        _conv(nf, 1, ks=4, bias=False, padding=0, use_activ=False),
        Flatten(),
    ]
    return nn.Sequential(*layers)

Fastai中鉴别器的结构我在其他文档中看到结构极其类似。
分析一下网络结构,首先是一层(3,256),stride为2的下采样卷积层,跟着一层dropout层防止过拟合;接着是三次相同的网络结构叠加,当然不同的是self_attention机制只在i==0的情况下有效。最后是一次卷积,加上一次kernel为4的卷积层。
对于鉴别器网络而言:输入为(3,n,n),则输出为(1,(n//16-3),(n//16-3))。

2、 BCEWITHLOGITSLOSS

BCE的含义是bilinear cross entropy,意思就是说线性交叉熵。LOGITLOSS就是逻辑损失。pytorch里面的BCEWITHLOGITSLOSS就是将以上两种LOSS进行融合的计算方式,本质上就是一种综合性的损失函数。

一个非常简洁易懂的计算说明链接:https://blog.csdn.net/qq_22210253/article/details/85222093
配合以上计算说明,会发现源码中举的例子通俗易懂,更为简明。

target = torch.ones([10, 64], dtype=torch.float32)  # 64 classes, batch size = 10
input = torch.full([10, 64], 1.5)  # A prediction (logit)
pos_weight = torch.ones([64])  # All weights are equal to 1
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
criterion(input, target)  # -log(sigmoid(1.5))
#output: tensor(0.2014)

target 某种意义上就是本次计算的label,如果样本标签为1则为对应维度的[1,1,1,1,…1],如果样本标签为0则为对应维度的[0,0,0,0…0],在本例中向量的维度为[10, 64];input是本次输入的向量,在本例中为[10,64],填充值为1.5的向量;weight的含义是逻辑回归中wx+b的向量w,默认值为1/n。比如在本例中如果没有输入的数值,则默认数值为1/(10*64)。

3、AdaptiveLoss

在fastai.forum上有相关的讨论:https://forums.fast.ai/t/why-use-adaptive-loss-for-gan-critic/52956

class AdaptiveLoss(Module):
    "Expand the `target` to match the `output` size before applying `crit`."
    def __init__(self, crit):
        self.crit = crit

    def forward(self, output, target):
        return self.crit(output, target[:,None].expand_as(output).float())

gan_loss_from_func 是 fastai 中封装generator和discriminator的方法,AdaptiveLoss主要是用于将已有的计算得到的loss,比如在上文中的AdaptiveLoss(nn.BCEWithLogitsLoss())中,将output的向量值与target保持一致。(output与target之间的关系可以参考上文中的input与target之间的关系)

def gan_loss_from_func(loss_gen, loss_crit, weights_gen:Tuple[float,float]=None):
    "Define loss functions for a GAN from `loss_gen` and `loss_crit`."
    def _loss_G(fake_pred, output, target, weights_gen=weights_gen):
        ones = fake_pred.new_ones(fake_pred.shape[0])
        weights_gen = ifnone(weights_gen, (1.,1.))
        return weights_gen[0] * loss_crit(fake_pred, ones) + weights_gen[1] * loss_gen(output, target)

    def _loss_C(real_pred, fake_pred):
        ones  = real_pred.new_ones (real_pred.shape[0])
        zeros = fake_pred.new_zeros(fake_pred.shape[0])
        return (loss_crit(real_pred, ones) + loss_crit(fake_pred, zeros)) / 2

    return _loss_G, _loss_C
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值