define_G
有两种实现方式 分别是 ResnetGenerator 和 UnetGenerator
ResnetGenerator
NLayerDiscriminator
原始 pix2pix 论文中描述的“PatchGAN”分类器。它可以区分 70×70 重叠的补丁是真的还是假的。这种补丁级鉴别器架构比全图像鉴别器具有更少的参数,并且可以以完全卷积的方式处理任意大小的图像。
由上述的卷积层进行感受野的逆推,可以得到输出特征图的每个像素对应到原图的感受野都是70*70(具体推导可以参见相关的CSDN博客)。PatchGAN改进了原本GAN中的D
训练过程
首先根据opt.direction确定风格迁移的方向,确定哪个domain是源头,哪个domain是目标
这里面的input是一个dict,它的key是['A', 'B', 'A_paths', 'B_paths'],分别存储了图片和路径
def set_input(self, input):
    """Unpack a batch from the dataloader and move the images to the device.

    Parameters:
        input (dict): batch holding the images under the keys 'A' and 'B'
            and the matching file paths under 'A_paths' and 'B_paths'.

    The option 'direction' selects the source domain: with 'AtoB', domain A
    becomes real_A and domain B becomes real_B; with any other value the two
    domains are swapped.
    """
    AtoB = self.opt.direction == 'AtoB'
    source_key, target_key = ('A', 'B') if AtoB else ('B', 'A')
    # real_A is always the source domain and real_B the target domain,
    # whichever dataset they actually come from.
    self.real_A = input[source_key].to(self.device)
    self.real_B = input[target_key].to(self.device)
    # Paths follow the source domain so saved results are named after it.
    self.image_paths = input['A_paths' if AtoB else 'B_paths']
然后首先使用两个generator产生4张图片,假B,重建A,假A和重建B,完成一次前向过程
def forward(self):
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
self.fake_B = self.netG_A(self.real_A) # G_A(A) torch.Size([1, 3, 256, 256])
self.rec_A = self.netG_B(self.fake_B) # G_B(G_A(A)) torch.Size([1, 3, 256, 256])
self.fake_A = self.netG_B(self.real_B) # G_B(B)
self.rec_B = self.netG_A(self.fake_A) # G_A(G_B(B))
完成前向过程之后,先把两个D设置为不需要梯度(requires_grad=False),这样优化G的时候不会计算D的梯度,然后计算并回传G的梯度
def optimize_parameters(self):
    """Calculate losses, gradients, and update network weights; called in every training iteration.

    The generators are updated first with the discriminators excluded from
    gradient computation; the discriminators are then re-enabled and updated.
    """
    # forward
    self.forward()  # compute fake images and reconstruction images.
    # G_A and G_B
    self.set_requires_grad([self.netD_A, self.netD_B], False)  # Ds require no gradients when optimizing Gs
    self.optimizer_G.zero_grad()  # set G_A and G_B's gradients to zero
    self.backward_G()             # calculate gradients for G_A and G_B
    self.optimizer_G.step()       # update G_A and G_B's weights
    # D_A and D_B
    self.set_requires_grad([self.netD_A, self.netD_B], True)
    self.optimizer_D.zero_grad()  # set D_A and D_B's gradients to zero
    self.backward_D_A()           # calculate gradients for D_A
    self.backward_D_B()           # calculate gradients for D_B
    self.optimizer_D.step()       # update D_A and D_B's weights
backward_G()
首先计算一个Identity Loss,有关identity loss的讲解在https://blog.csdn.net/Teeyohuang/article/details/82729047
假设F是输入斑马生成马的generator,那么如果把马的图片输入F,生成的东西应该还是马才对,这个loss的功能主要是保证生成的图片纹理特征不变
def backward_G(self):
"""Calculate the loss for generators G_A and G_B"""
lambda_idt = self.opt.lambda_identity # 0.5
lambda_A = self.opt.lambda_A # 10
lambda_B = self.opt.lambda_B # 10
# Identity loss
if lambda_idt > 0:
# G_A should be identity if real_B is fed: ||G_A(B) - B||
self.idt_A = self.netG_A(self.real_B)
self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
# G_B should be identity if real_A is fed: ||G_B(A) - A||
self.idt_B = self.netG_B(self.real_A)
self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
else:
self.loss_idt_A = 0
self.loss_idt_B = 0
其中
self.criterionIdt = torch.nn.L1Loss()
接下来计算两个GANloss
# GAN loss D_A(G_A(A))
self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True)
# self.netD_A(self.fake_B).shape = torch.Size([1, 1, 30, 30])
# GAN loss D_B(G_B(B))
self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
其中criterionGAN的前向过程是这样的
def __call__(self, prediction, target_is_real):
    """Calculate loss given Discriminator's output and ground truth labels.

    Parameters:
        prediction (tensor) - - typically the prediction output from a discriminator
        target_is_real (bool) - - if the ground truth label is for real images or fake images

    Returns:
        the calculated loss.
    """
    if self.gan_mode in ['lsgan', 'vanilla']:  # the author's run used 'lsgan'
        # Compare the prediction against a label tensor of matching size.
        target_tensor = self.get_target_tensor(prediction, target_is_real)
        loss = self.loss(prediction, target_tensor)
    elif self.gan_mode == 'wgangp':
        # No label tensor here: the sign of the mean prediction is the loss.
        if target_is_real:
            loss = -prediction.mean()
        else:
            loss = prediction.mean()
    return loss
self.get_target_tensor
D输入一个G的生成图,把这个生成图的标签置为一个全1的矩阵
因为此时在G的训练过程中,G的训练目标是自己生成的图片,在D看起来,尽可能与真实的接近,就是尽可能的与1接近
def get_target_tensor(self, prediction, target_is_real):
    """Create label tensors with the same size as the input.

    Parameters:
        prediction (tensor) - - typically the prediction from a discriminator
        target_is_real (bool) - - if the ground truth label is for real images or fake images

    Returns:
        A label tensor filled with ground truth label, and with the size of the input
    """
    if target_is_real:
        # The author's notes give this label the value 1 and say it was
        # registered earlier (presumably as a module buffer - confirm in __init__).
        target_tensor = self.real_label
    else:
        target_tensor = self.fake_label
    # The author's notes record a returned shape of torch.Size([1, 1, 30, 30]).
    return target_tensor.expand_as(prediction)
接下来计算最后一个cycle loss
# Forward cycle loss || G_B(G_A(A)) - A||
self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
# Backward cycle loss || G_A(G_B(B)) - B||
self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
其中
self.criterionCycle = torch.nn.L1Loss()
最后把这些loss都加在一起做一次梯度计算
# combined loss and calculate gradients
self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
self.loss_G.backward()
有关这三种loss,转载两张很形象的图片
转载自https://blog.csdn.net/Teeyohuang/article/details/82729047
训练完了G就开始训练D
# D_A and D_B
self.set_requires_grad([self.netD_A, self.netD_B], True)
self.optimizer_D.zero_grad() # set D_A and D_B's gradients to zero
self.backward_D_A() # calculate gradients for D_A
self.backward_D_B() # calculate graidents for D_B
self.optimizer_D.step() # update D_A and D_B's weights
backward_D_A和backward_D_B是一样的,这里详细讲一下backward_D_A
backward_D_A()
def backward_D_A(self):
    """Calculate GAN loss for discriminator D_A.

    The fake image shown to D_A is drawn through the image pool, so it may
    be a previously generated image rather than the current fake_B.
    """
    fake_B = self.fake_B_pool.query(self.fake_B)
    self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)
self.fake_B_pool.query
这个函数的功能是维护一个存放历史生成图片的buffer(容量为pool_size,这里是50张),并返回用于训练D的图片。buffer未满时,直接把当前图片存入buffer并原样返回;buffer已满之后,有50%的概率从buffer中随机取出一张旧图片返回,同时用当前图片替换它,另外50%的概率直接返回当前图片
这里面值得借鉴的代码就是torch.cat,可以把list拼接成一个tensor
def query(self, images):
    """Return images from the pool.

    Parameters:
        images: the latest generated images from the generator

    Returns:
        A tensor with one returned image per input image, concatenated along
        dimension 0. If pool_size is 0 the input is returned unchanged.

    While the buffer is not full, each input image is stored and returned
    as is. Once it is full, each input image is, with 50% chance, swapped
    with a randomly chosen stored image (the stored one is returned);
    otherwise the input image itself is returned.
    """
    if self.pool_size == 0:  # if the buffer size is 0, do nothing
        return images
    return_images = []
    for image in images:
        # Re-add the leading dimension so the pieces can be concatenated.
        image = torch.unsqueeze(image.data, 0)
        if self.num_imgs < self.pool_size:
            # Buffer not full yet: keep inserting current images.
            self.num_imgs = self.num_imgs + 1
            self.images.append(image)
            return_images.append(image)
        else:
            p = random.uniform(0, 1)
            if p > 0.5:
                # Return a previously stored image and store the current one.
                random_id = random.randint(0, self.pool_size - 1)  # randint is inclusive
                tmp = self.images[random_id].clone()
                self.images[random_id] = image
                return_images.append(tmp)
            else:
                # Return the current image and leave the buffer untouched.
                return_images.append(image)
    return_images = torch.cat(return_images, 0)  # collect all the images and return
    # The snippet ended in a bare `return`, which handed None to the caller.
    return return_images
backward_D_basic
这个函数的损失就是D需要完成的任务,训练目标是
在输入真实图片的时候,输出的结果要与1接近
在输入生成图片的时候,结果要与0接近,把这两部分loss加在一起再除2
值得注意的地方是,在输入虚假图片fake的时候,要使用fake.detach(),因为fake也是前向生成的,要从计算图上拉下来防止梯度求的乱七八糟...
def backward_D_basic(self, netD, real, fake):
    """Calculate GAN loss for the discriminator

    Parameters:
        netD (network)      -- the discriminator D
        real (tensor array) -- real images
        fake (tensor array) -- images generated by a generator

    Return the discriminator loss.
    We also call loss_D.backward() to calculate the gradients.
    """
    # Real: the discriminator should score real images as real.
    pred_real = netD(real)  # the author's notes record shape [1, 1, 30, 30]
    loss_D_real = self.criterionGAN(pred_real, True)
    # Fake: detach so this backward pass does not reach the generator
    # that produced the fake images.
    pred_fake = netD(fake.detach())
    loss_D_fake = self.criterionGAN(pred_fake, False)
    # Combined loss and calculate gradients
    loss_D = (loss_D_real + loss_D_fake) * 0.5
    loss_D.backward()
    return loss_D
后面的代码就是一些存图和存数据的东西啦,感兴趣的同学可以自己去看下
完结撒花!