starGAN的笔记(代码)

一、图片和标签融合输入CNN:
把标签转为one_hot(记为c), 维度是类别的个数, 假设是5个类别, 那么
x.size()==>[nb, cn, h, w] (cn是图片通道)
y.size()==>[nb, 1]
c.size()==>[nb, 5]
在generator的forward时, 把c扩展到四个维度(记为c_expand), 第3 4维度值和x一样
首先通过unsqueeze添加维度
c_expand = c.unsqueeze(2).unsqueeze(3)
新的维度值默认为1
c_expand.size()==>[nb, 5, 1, 1]
然后用expand扩展维度
c_expand = c_expand.expand(c.size(0), c.size(1), x.size(2), x.size(3))
c_expand.size()==>[nb, 5, h, w]
此时每个标签相当于是一个5通道的图, 每个通道图的值都一样, 即
c_expand[0, 0, :, :]上所有像素点的值都为c[nb, 0]
c_expand[0, 1, :, :]上所有像素点的值都为c[nb, 1]
c_expand[0, 2, :, :]上所有像素点的值都为c[nb, 2]
c_expand[0, 3, :, :]上所有像素点的值都为c[nb, 3]
c_expand[0, 4, :, :]上所有像素点的值都为c[nb, 4]
最后把c_expand和x进行cat
x = torch.cat([x, c_expand], dim=1)

 

此时x.size()==>[nb, cn+5, h, w]

 

 

 

 

二、改进的GAN的训练方式,Improved GAN Training
文章中提到了 为了使GAN的训练过程更稳定并且生成图片的质量更好, 作者把开始的GAN目标函数替换为以下:


这种目标函数参考自 《Improved Training of Wasserstein GANs》 (是《Wasserstein GAN》的变种),里面提到具体的算法过程:


中间的L(i)是用来优化D的,最后的θ更新就是优化G。
与原始的GAN损失相比: 
1. D去掉了sigmoid
2. D和G的loss不取log
3. D的损失多了一个gradient penalty项
3. 对D优化n_{critic}次后,对G优化1次


上面算法流程中,中间的L(i)对比starGAN的那个损失,可以发现两者的符号相反了,
这是因为在starGAN中
D要maximize这个目标函数,
进一步查看代码可知,
D: 
对于x_fake, 
out_src, out_cls = self.D(fake_x)
d_loss_fake = torch.mean(out_src) # 直接minimize,就是让
对于x_real,
out_src, out_cls = self.D(real_x)
d_loss_real = - torch.mean(out_src) # 取负号


对于G的损失就都一样,都要minimize,都是-D(G())
代码中对也是取负号:
out_src, out_cls = self.D(fake_x)
g_loss_fake = - torch.mean(out_src)



看看这个gradient penalty项在starGAN中的实现:
 

# Compute gradient penalty
alpha = torch.rand(real_x.size(0), 1, 1, 1).cuda().expand_as(real_x)
interpolated = Variable(alpha * real_x.data + (1 - alpha) * fake_x.data, requires_grad=True)
out, out_cls = self.D(interpolated)

grad = torch.autograd.grad(outputs=out,
                           inputs=interpolated,
                           grad_outputs=torch.ones(out.size()).cuda(),
                           retain_graph=True,
                           create_graph=True,
                           only_inputs=True)[0]

grad = grad.view(grad.size(0), -1)
grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
d_loss_gp = torch.mean((grad_l2norm - 1)**2)

# Backward + Optimize
d_loss = self.lambda_gp * d_loss_gp
self.reset_grad()
d_loss.backward()
self.d_optimizer.step()

 


和上面算法流程图说的一样,
1.获取对x_real和x_fake随机插值后的图interpolated
2.计算D(interpolated)的损失out (out_cls是分类损失,忽略)
3.用torch.autograd.grad,让out对interpolated求梯度grad
4.最后的惩罚就是grad的二范数,减去1,再平方,最后乘以惩罚系数lambda_gp




所以对于D,GAN总损失是: d_loss_fake + d_loss_real + d_loss_gp * lambda_gp
starGAN中是先传播d_loss_fake + d_loss_real, 再传播d_loss_gp * lambda_gp

 

 

 

 

 

  • 3
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 7
    评论
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值