最近在学习一篇有关于fastai的鉴别器知识,整理相关的有意思的可以学习的点。
1、鉴别器结构
def custom_gan_critic(
n_channels: int = 3, nf: int = 256, n_blocks: int = 3, p: int = 0.15):
"Critic to train a `GAN`."
layers = [_conv(n_channels, nf, ks=4, stride=2), nn.Dropout2d(p / 2)]
for i in range(n_blocks):
layers += [
_conv(nf, nf, ks=3, stride=1),
nn.Dropout2d(p),
_conv(nf, nf * 2, ks=4, stride=2, self_attention=(i == 0)),
]
nf *= 2
layers += [
_conv(nf, nf, ks=3, stride=1),
_conv(nf, 1, ks=4, bias=False, padding=0, use_activ=False),
Flatten(),
]
return nn.Sequential(*layers)
Fastai中鉴别器的结构我在其他文档中看到结构极其类似。
分析一下网络结构,首先是一层(3,256),stride为2的下采样卷积层,跟着一层dropout层防止过拟合;接着是三次相同的网络结构叠加,当然不同的是self_attention机制只在i==0的情况下有效。最后是一次卷积,加上一次kernel为4的卷积层。
对于鉴别器网络而言:输入为(3,n,n),则输出为(1,(n//16-3),(n//16-3))。
2、 BCEWITHLOGITSLOSS
BCE的含义是bilinear cross entropy,意思就是说线性交叉熵。LOGITLOSS就是逻辑损失。pytorch里面的BCEWITHLOGITSLOSS就是将以上两种LOSS进行融合的计算方式,本质上就是一种综合性的损失函数。
一个非常简洁易懂的计算说明链接:https://blog.csdn.net/qq_22210253/article/details/85222093
配合以上计算说明,会发现源码中举的例子通俗易懂,更为简明。
target = torch.ones([10, 64], dtype=torch.float32) # 64 classes, batch size = 10
input = torch.full([10, 64], 1.5) # A prediction (logit)
pos_weight = torch.ones([64]) # All weights are equal to 1
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
criterion(input, target) # -log(sigmoid(1.5))
#output: tensor(0.2014)
target 某种意义上就是本次计算的label,如果样本标签为1则为对应维度的[1,1,1,1,…1],如果样本标签为0则为对应维度的[0,0,0,0…0],在本例中向量的维度为[10, 64];input是本次输入的向量,在本例中为[10,64],填充值为1.5的向量;weight的含义是逻辑回归中wx+b的向量w,默认值为1/n。比如在本例中如果没有输入的数值,则默认数值为1/(10*64)。
3、AdaptiveLoss
在fastai.forum上有相关的讨论:https://forums.fast.ai/t/why-use-adaptive-loss-for-gan-critic/52956
class AdaptiveLoss(Module):
"Expand the `target` to match the `output` size before applying `crit`."
def __init__(self, crit):
self.crit = crit
def forward(self, output, target):
return self.crit(output, target[:,None].expand_as(output).float())
gan_loss_from_func 是 fastai 中封装generator和discriminator的方法,AdaptiveLoss主要是用于将已有的计算得到的loss,比如在上文中的AdaptiveLoss(nn.BCEWithLogitsLoss())中,将output的向量值与target保持一致。(output与target之间的关系可以参考上文中的input与target之间的关系)
def gan_loss_from_func(loss_gen, loss_crit, weights_gen:Tuple[float,float]=None):
"Define loss functions for a GAN from `loss_gen` and `loss_crit`."
def _loss_G(fake_pred, output, target, weights_gen=weights_gen):
ones = fake_pred.new_ones(fake_pred.shape[0])
weights_gen = ifnone(weights_gen, (1.,1.))
return weights_gen[0] * loss_crit(fake_pred, ones) + weights_gen[1] * loss_gen(output, target)
def _loss_C(real_pred, fake_pred):
ones = real_pred.new_ones (real_pred.shape[0])
zeros = fake_pred.new_zeros(fake_pred.shape[0])
return (loss_crit(real_pred, ones) + loss_crit(fake_pred, zeros)) / 2
return _loss_G, _loss_C