【生成对抗网络系列】六、CycleGAN


参考资料

论文

  Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks

  论文主页

博客

  CycleGAN:图片风格,想换就换 | ICCV 2017论文解读

  CycleGAN详细解读

  CycleGAN 原理详解

视频

  生成对抗网络GAN开山之作论文精读

代码

  PyTorch-GAN


第1章 CycleGAN的作用

 CycleGAN的一个重要应用领域是 Domain Adaptation(域迁移:可以通俗的理解为画风迁移),比如可以把一张普通的风景照变成梵高化作,或者将游戏画面变化成真实世界画面等等。以下是原论文中给出的一些应用:

在这里插入图片描述


第2章 CycleGAN的优势

 其实在CycleGAN之前,就已经有了Domain Adaptation模型,比如Pix2Pix,不过 Pix2Pix 要求训练数据必须是成对的,而现实生活中,要找到两个域(画风)中成对出现的图片是相当困难的,而CycleGAN只需要两种域的数据,而不需要他们有严格对应关系,这使得CycleGAN的应用更为广泛。原论文中是这样解释的:

在这里插入图片描述


第3章 CycleGAN的网络结构

 CycleGAN 可以让两个 domain 的图片互相转化。传统的 GAN 是单向生成,而 CycleGAN 是互相生成,网络是个环形,所以命名为 Cycle。并且 CycleGAN 一个非常实用的地方就是输入的两张图片可以是任意的两张图片,也就是 unpaired

3.1 单向GAN

CycleGAN 本质上是两个镜像对称的 GAN,构成了一个环形网络。其实只要理解了一半的单向 GAN 就等于理解了整个CycleGAN。

在这里插入图片描述

 上图是一个单向 GAN 的示意图。我们希望能够把 domain A 的图片(命名为 A)转化为 domain B 的图片(命名为图片 B)。为了实现这个过程,我们需要两个生成器 G A B G_{AB} GAB G B A G_{BA} GBA,分别把 domain Adomain B 的图片进行互相转换。

 图片 A A A 经过生成器 G A B G_{AB} GAB 表示为 Fake Image in domain B,用 G A B ( A ) G_{AB}(A) GAB(A) 表示。而 经 G A B ( A ) G_{AB}(A) GAB(A)过生成器 G B A G_{BA} GBA表示为图片 A A A 的重建图片,用 G B A ( G A B ( A ) ) G_{BA}(G_{AB}(A)) GBA(GAB(A)) 表示。

 最后为了训练这个单向 GAN 需要两个 loss,分别是 生成器的重建 loss判别器的判别 loss

(1)判别 loss:判别器 D B D_B DB 是用来判断输入的图片是否是真实的 domain B 图片,于是生成的假图片 G A B ( A ) G_{AB}(A) GAB(A)和原始的真图片 B 都会输入到判别器里面,公式挺好理解的,就是一个 0,1 二分类的损失。最后的 loss 表示为:

在这里插入图片描述

(2)生成 loss:生成器用来重建图片 A,目的是希望生成的图片 G B A ( G A B ( A ) ) G_{BA}(G_{AB}(A)) GBA(GAB(A)) 和原图 A 尽可能的相似,那么可以很简单的采取 L 1   L o s s L_1\ Loss L1 Loss 或者 。最 L 2   L o s s L_2\ Loss L2 Loss 后生成 Loss 就表示为:

在这里插入图片描述

以上就是 A→B 单向 GAN 的原理。


3.2 CycleGAN

CycleGAN 其实就是一个 A→B 单向 GAN 加上一个 B→A 单向 GAN。两个 GAN 有两个生成器,然后各自带一个判别器,所以加起来总共有两个判别器和两个生成器。一个单向 GAN 有两个 loss,而 CycleGAN 加起来总共有四个 loss。CycleGAN 论文的原版原理图和公式如下,其实理解了单向 GAN 那么 CycleGAN 已经很好理解。

在这里插入图片描述

 (1) X → Y X→Y XY 的判别器损失为:

在这里插入图片描述

 (2) Y → X Y→X YX 的判别器损失为:

在这里插入图片描述

 (3)两个生成器的 loss 加起来表示为:

在这里插入图片描述

 (4)最终网络的所有损失加起来为:

在这里插入图片描述

 (5)论文里面提到判别器如果是对数损失(BCE Loss)训练不是很稳定,所以改成的均方误差损失(MSE Loss),如下:

在这里插入图片描述

在这里插入图片描述


3.3 改进Loss

 上面我们提到,我们希望生成器只进行风格的迁移而保证内容不变,具体而言:

  • 风格迁移,内容不变: G 吃一张房子的照片,吐一张梵高风格的房子的照片;
  • 风格迁移,内容改变: G 吃一张房子的照片,任意吐一张梵高风格的照片

 仅靠上面的 Loss 能否保证风格迁移,内容不变呢?我认为不能!以下图为例:

在这里插入图片描述

 正常情况下,我们希望 G ( x ) = a G(x)=a G(x)=a ,但是根据上面的 L o s s Loss Loss 会不会导致 G ( x ) = b , F ( b ) = x G(x)=b , F(b)=x G(x)=b,F(b)=x 的情况发生呢?答案是肯定的,对于 G ( x ) = b G(x)=b G(x)=b ,虽然产生的图片 b b b 并不是我们希望的,但是由于 b b b 的确是梵高画风,所以判别器会给它高分,这会鼓励生成器错误产生 b b b 的这个行为。

 其次,在更新 F F F 参数的时候,由于 L o s s c y c l e Loss_{cycle} Losscycle E x ∼ p d a t a ( x ) [ ‖ F ( G ( x ) ) − x ‖ 1 ] E_{x∼pdata} (x)[‖F(G(x))−x‖_1] Expdata(x)[F(G(x))x1] 一项的存在,即使 G ( x ) G(x) G(x) 错误产生了 b b b F F F 任然会努力把 G ( x ) G(x) G(x) 错误的结果“掰”回 x x x ,这就像 F F F 在“包庇” G G G 的错误。

 同样,在更新 G G G 参数的时候,由于 E y ∼ p d a t a ( y ) [ ‖ G ( F ( y ) ) − y ‖ 1 ] E_{y∼pdata} (y)[‖G(F(y))−y‖_1] Eypdata(y)[G(F(y))y1] 的存在, G G G 也会去“包庇” F F F 。这样一来,就会出现上图中风格迁移,内容改变的情况。

 这一点在原文中有提到,但原文说 Identity Loss 的作用主要是保证色调不变。Identity Loss 的形式为:

在这里插入图片描述

 就是将真实的B输入到A生成B的判别器中,查看判别器的识别损失,希望越小越好!说明生成器网络真正的理解了B的结构。

 加上 Identity Loss 后,整个损失函数的表达式为:

在这里插入图片描述

在这里插入图片描述


 总结一下 Loss 实现:

在这里插入图片描述


3.4 Instance Normalization

 图片使用了Instance Normalization而非经典DCGAN中所使用的Batch NormalizationInstance NormalizationBatch Normalization 一样,也是Normalization的一种方法,只是IN是作用于单张图片,但是BN作用于一个Batch

 参考:深度学习归一化方法总结(BN、LN、IN、GN)


在这里插入图片描述

 假如现在图像先进行了卷积运算得到如上图所示的激活状态 ( N , C , H , W ) (N,C,H,W) (N,C,H,W) ,其中 N N N 是样本数, C C C 为通道数即特征图数。

  • BN:取不同样本的同一个通道的特征做归一化,逐特征维度归一化。这个就是对batch维度进行计算。所以假设5个100通道的特征图的话,就会计算出100个均值方差。5个batch中每一个通道就会计算出来一个均值方差。
  • LN:取的是同一个样本的不同通道做归一化,逐个样本归一化。5个10通道的特征图,LN会给出5个均值方差。
  • IN:仅仅对每一个图片的每一个通道最归一化。也就是说,对【H,W】维度做归一化。假设一个特征图有10个通道,那么就会得到10个均值和10个方差;要是一个batch有5个样本,每个样本有10个通道,那么IN总共会计算出50个均值方差。
  • GN:这个是介于LN和IN之间的一种方法。假设Group分成2个,那么10个通道就会被分成5和5两组。然后5个10通道特征图会计算出10个均值方差。

3.5 PatchGAN

参考

  深度学习《patchGAN》

  PatchGAN笔记(个人理解)

  烂泥上墙之路:PatchGAN的理解


 CycleGAN网络中的判别器使用的是一种叫 PatchGAN 的设计,原始GAN的discriminator的设计是仅输出一个评价值(True or False),该值是对生成器生成的整幅图像的一个评价。

 在以往的GAN学习中,判别器D网络的输出是一个标量,介于0~1之间,代表是真实图片的概率。

 而PatchGAN的设计不同,PatchGAN设计成全卷积的形式,图像经过各种卷积层后,并不会输入到全连接层或者激活函数中,而是使用卷积将输入映射为N*N矩阵,该矩阵等同于原始GAN中的最后的评价值用以评价生成器的生成图像。

N × N N\times N N×N 矩阵中每个点(true or false)即代表原始图像中的一块小区域(这也就是patch含义)评价值,这也就是“感受野(下图)”的应用。

 原来用一个值衡量整幅图,现在使用 N × N N\times N N×N 的矩阵来评价整幅图(使用 PatchGAN 标签也需要设置成为 N × N N\times N N×N 的格式,这样就可以进行损失计算了),显然后者可以关注更多的区域,这也就是 PatchGAN 的优势。

在这里插入图片描述

 PatchGAN主要是用于判别器,普通的判别器我们所得到的是判断一张图像是否为目标图像(输入可以是期望的图像,也可以是生成器生成的图像)。PatchGAN则是基于映射的关系,通过卷积的感受野来判断某个小区域是否为我们想要的目标图片,并最终进行加权

在这里插入图片描述


3.6 训练细节

  • (1)图片使用了Instance Normalization而非经典DCGAN中所使用的Batch Normalization
  • (2)写代码时,并没有使用上面Loss中的 log likelihood 形式,而是使用的least-squares loss
  • (3)判别器采用的70×70 PatchGAN形式;
  • (4)生成器网络使用了residual blocks
  • (5)训练时的Batch Size为1;

【有以下几种说法】:(参考:为什么 CycleGAN 的batchSize等于1?)
 
 1、如果bs>1的话,一个batch里面存在不同内容的源域数据和不同风格的目标数据,会混淆图片的生成,在风格迁移里,就应该一张图一张图地训练。

 2、想在高分辨率图像上训练;

 3、为了将train和test的batchsize保持一致;

 4、instancenormal的使用,使用batchsize=1比较好;

  • (6)学习率在前100个epochs不变,在后面的epochs线性衰减
  • (7)使用了Reflection padding而非普通的Zero padding;
  • (8)生成器各层激活函数主要为ReLU,判别器各层激活函数主要为LeakyReLU;
  • (9)训练判别器时还会用到生成器产生的历史数据(Buffer)

第4章 Pytorch实现CycleGAN

参考资料

  Cycle-GAN代码解读

  CycleGAN 生成对抗网络图像处理工具

  GAN生成对抗网络—cycleGAN


4.1 Models.py

import torch.nn as nn
import torch.nn.functional as F
import torch


# 初始化权重函数
def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
        if hasattr(m, "bias") and m.bias is not None:
            torch.nn.init.constant_(m.bias.data, 0.0)
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)


##############################
#           RESNET
##############################


class ResidualBlock(nn.Module):
    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()

        # 经过两次3x3的卷积,(WxH) -> (W-4)x(H-4)
        # 经过两次pandding,(WxH) -> (W+4)x(H+4)
        # 所以经过整个操作后,(WxH) -> (WxH)
        self.block = nn.Sequential(
            # nn.ReflectionPad2d()函数用法参考:https://blog.csdn.net/LionZYT/article/details/120181586
            nn.ReflectionPad2d(1),  # 对四周都填充1行
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features),
        )

    def forward(self, x):
        return x + self.block(x)

##############################
#        Generator
##############################


class GeneratorResNet(nn.Module):
    def __init__(self, input_shape, num_residual_blocks):
        super(GeneratorResNet, self).__init__()

        # input_shape = (3, 256, 256)
        channels = input_shape[0]

        # Initial convolution block
        out_features = 64
        model = [
            # (3, 256, 256) -> (3, 262, 262)
            nn.ReflectionPad2d(channels),
            # (3, 262, 262) -> (64, 256, 256)
            nn.Conv2d(channels, out_features, 7),
            nn.InstanceNorm2d(out_features),
            nn.ReLU(inplace=True),
        ]
        # in_features = 64
        in_features = out_features

        # Downsampling下采样
        for _ in range(2):
            # 1:out_features = 128
            # 2:out_features = 256
            out_features *= 2
            # (64, 256, 256) -> (256, 64, 64)
            model += [
                nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            # in_features = 256
            in_features = out_features

        # Residual blocks
        for _ in range(num_residual_blocks):
            model += [ResidualBlock(out_features)]

        # Upsampling上采样
        for _ in range(2):
            out_features //= 2
            # (256, 64, 64) -> (64, 256, 256)
            model += [
                nn.Upsample(scale_factor=2),    # (W,H)->(Wx2, Hx2)
                nn.Conv2d(in_features, out_features, 3, stride=1, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features

        # Output layer
        # (64, 256, 256) -> (3, 256, 256)
        model += [
            nn.ReflectionPad2d(channels),
            nn.Conv2d(out_features, channels, 7),
            nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)


##############################
#        Discriminator
##############################


class Discriminator(nn.Module):
    def __init__(self, input_shape):
        super(Discriminator, self).__init__()

        # input_shape = (3, 256, 256)
        channels, height, width = input_shape

        # Calculate output shape of image discriminator (PatchGAN)
        # (1, 16, 16)
        self.output_shape = (1, height // 2 ** 4, width // 2 ** 4)

        def discriminator_block(in_filters, out_filters, normalize=True):
            """Returns downsampling layers of each discriminator block"""
            layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
            if normalize:
                layers.append(nn.InstanceNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        # (3, 256, 256) -> (1, 16, 16)
        self.model = nn.Sequential(
            *discriminator_block(channels, 64, normalize=False),
            *discriminator_block(64, 128),
            *discriminator_block(128, 256),
            *discriminator_block(256, 512),
            nn.ZeroPad2d((1, 0, 1, 0)),     # 左右上下的顺序
            nn.Conv2d(512, 1, 4, padding=1)
        )

    def forward(self, img):
        return self.model(img)


4.2 Datasets.py

import glob
import random
import os

from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms
"""
    主要是ImageDataset函数的操作,__init__操作将trainA和trainB的路径读入files_A 和files_B;
    __getitem__对两个文件夹的图片进行读取,若不是RGB图片则进行转换;__len__返回两个文件夹数据数量的大值。
"""


# 转为rgb图片
def to_rgb(image):
    rgb_image = Image.new("RGB", image.size)
    rgb_image.paste(image)
    return rgb_image


# 对数据进行读取
class ImageDataset(Dataset):
    def __init__(self, root, transforms_=None, unaligned=False, mode="train"):
        self.transform = transforms.Compose(transforms_)
        self.unaligned = unaligned

        self.files_A = sorted(glob.glob(os.path.join(root, "%sA" % mode) + "/*.*"))
        self.files_B = sorted(glob.glob(os.path.join(root, "%sB" % mode) + "/*.*"))

    def __getitem__(self, index):
        image_A = Image.open(self.files_A[index % len(self.files_A)])

        if self.unaligned:
            image_B = Image.open(self.files_B[random.randint(0, len(self.files_B) - 1)])
        else:
            image_B = Image.open(self.files_B[index % len(self.files_B)])

        # Convert grayscale images to rgb
        if image_A.mode != "RGB":
            image_A = to_rgb(image_A)
        if image_B.mode != "RGB":
            image_B = to_rgb(image_B)

        item_A = self.transform(image_A)
        item_B = self.transform(image_B)
        return {"A": item_A, "B": item_B}

    def __len__(self):
        return max(len(self.files_A), len(self.files_B))


4.3 Utils.py

import random
import time
import datetime
import sys

from torch.autograd import Variable
import torch
import numpy as np

from torchvision.utils import save_image

"""
  主要关注学习率衰减(LambdaLR)。
"""


class ReplayBuffer:
    def __init__(self, max_size=50):
        assert max_size > 0, "Empty buffer or trying to create a black hole. Be careful."
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data):
        to_return = []
        for element in data.data:
            element = torch.unsqueeze(element, 0)
            if len(self.data) < self.max_size:
                self.data.append(element)
                to_return.append(element)
            else:
                if random.uniform(0, 1) > 0.5:
                    i = random.randint(0, self.max_size - 1)
                    to_return.append(self.data[i].clone())
                    self.data[i] = element
                else:
                    to_return.append(element)
        return Variable(torch.cat(to_return))

# 学习率在前100个epochs不变,在后面的epochs线性衰减
class LambdaLR:
    def __init__(self, n_epochs, offset, decay_start_epoch):
        assert (n_epochs - decay_start_epoch) > 0, "Decay must start before the training session ends!"
        self.n_epochs = n_epochs
        self.offset = offset
        self.decay_start_epoch = decay_start_epoch

    def step(self, epoch):
        return 1.0 - max(0, epoch + self.offset - self.decay_start_epoch) / (self.n_epochs - self.decay_start_epoch)


4.4 Cyclegan.py

import argparse
import os
import numpy as np
import math
import itertools
from tqdm.autonotebook import tqdm
from torchvision.transforms import InterpolationMode

import torchvision.transforms as transforms
from torchvision.utils import save_image, make_grid

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

from models import *
from datasets import *
from utils import *

import torch.nn as nn
import torch.nn.functional as F
import torch


'''
参数表格
    epoch:使用数据集的所有数据进行一次模型训练,一代训练,从第0代开始训练
    n_epochs:训练的次数,默认200次
    dataset_name:数据集文件夹的名字,默认"monet2photo"
    batch_size:使用数据中的一部分数据进行模型权重更新的这部分数据大小,默认1
    lr:adam学习率
    b1&b2:adam学习参数
    decay_epoch:lr学习率开始衰减
    n_cpu:训练过程中用到的CPU线程数目
    img_height:输入图片的高度,默认256
    img_width:输入图片的宽度,默认256
    channels:图片的通道数,默认为彩色图片,channels=3
    sample_interval:每隔一段时间对训练输出进行采样并展示,默认100
    n_residual_blocks:生成器中的residual模块的数量
    lambda_cyc:cycle loss权重参数
    lambda_id:identity loss权重参数
'''
parser = argparse.ArgumentParser()
parser.add_argument("--epoch", type=int, default=0, help="epoch to start training from")
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--dataset_name", type=str, default="monet2photo", help="name of the dataset")
parser.add_argument("--batch_size", type=int, default=1, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--decay_epoch", type=int, default=100, help="epoch from which to start lr decay")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--img_height", type=int, default=256, help="size of image height")
parser.add_argument("--img_width", type=int, default=256, help="size of image width")
parser.add_argument("--channels", type=int, default=3, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=100, help="interval between saving generator outputs")
parser.add_argument("--checkpoint_interval", type=int, default=-1, help="interval between saving model checkpoints")
parser.add_argument("--n_residual_blocks", type=int, default=9, help="number of residual blocks in generator")
parser.add_argument("--lambda_cyc", type=float, default=10.0, help="cycle loss weight")
parser.add_argument("--lambda_id", type=float, default=5.0, help="identity loss weight")
opt = parser.parse_args()
print(opt)

# 创建文件夹保存模型和采样输出图片
os.makedirs("images/%s" % opt.dataset_name, exist_ok=True)
os.makedirs("saved_models/%s" % opt.dataset_name, exist_ok=True)

# 损失函数定义和初始化
criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()

# 判断电脑是否可以使用GPU进行训练
cuda = torch.cuda.is_available()
# input_shape保存输入图片的通道数,高度,宽度
input_shape = (opt.channels, opt.img_height, opt.img_width)


# 初始化四个网络(G_AB,G_BA,D_A,D_B)
G_AB = GeneratorResNet(input_shape, opt.n_residual_blocks)
G_BA = GeneratorResNet(input_shape, opt.n_residual_blocks)
D_A = Discriminator(input_shape)
D_B = Discriminator(input_shape)

# 采用GPU进行训练
if cuda:
    G_AB = G_AB.cuda()
    G_BA = G_BA.cuda()
    D_A = D_A.cuda()
    D_B = D_B.cuda()
    criterion_GAN.cuda()
    criterion_cycle.cuda()
    criterion_identity.cuda()


# 如果不是从第0代开始训练,则从保存的模型中调用模型以及加载开始训练的代数,继续训练
if opt.epoch != 0:
    # Load pretrained models
    G_AB.load_state_dict(torch.load("saved_models/%s/G_AB_%d.pth" % (opt.dataset_name, opt.epoch)))
    G_BA.load_state_dict(torch.load("saved_models/%s/G_BA_%d.pth" % (opt.dataset_name, opt.epoch)))
    D_A.load_state_dict(torch.load("saved_models/%s/D_A_%d.pth" % (opt.dataset_name, opt.epoch)))
    D_B.load_state_dict(torch.load("saved_models/%s/D_B_%d.pth" % (opt.dataset_name, opt.epoch)))
else:
    # 如果从头开始训练,就初始化权重
    # Initialize weights
    G_AB.apply(weights_init_normal)
    G_BA.apply(weights_init_normal)
    D_A.apply(weights_init_normal)
    D_B.apply(weights_init_normal)

# 定义初始化模型的优化器
optimizer_G = torch.optim.Adam(
    itertools.chain(G_AB.parameters(), G_BA.parameters()), lr=opt.lr, betas=(opt.b1, opt.b2)
)
optimizer_D_A = torch.optim.Adam(D_A.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D_B = torch.optim.Adam(D_B.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))


# 按照epoch的次数自动调整学习率
lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
    optimizer_G, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step
)
lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_A, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step
)
lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_B, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step
)

Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor

# Buffers of previously generated samples
fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()

# 数据预处理包括resize、crop、flip、normalize等操作
transforms_ = [
    transforms.Resize(int(opt.img_height * 1.12), interpolation=InterpolationMode.BICUBIC), # 调整Image对象的尺寸
    transforms.RandomCrop((opt.img_height, opt.img_width)),  # 扩大后剪切成img_height*img_width大小的图片
    transforms.RandomHorizontalFlip(),  # 依据概率p对PIL图片进行水平翻转,p默认0.5
    transforms.ToTensor(),  # 转为tensor格式
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), # 归一化
]

# 加载训练数据
# Training data loader
dataloader = DataLoader(
    # ../表示当前目录的父目录
    ImageDataset("../../data/%s" % opt.dataset_name, transforms_=transforms_, unaligned=True, mode="train"),
    batch_size=opt.batch_size,
    shuffle=True,  # 将数据打乱,数值越大,混乱程度越大
    # num_workers=0,
    num_workers=opt.n_cpu,  # 线程数
)

# 测试数据加载
# Test data loader
val_dataloader = DataLoader(
    ImageDataset("../../data/%s" % opt.dataset_name, transforms_=transforms_, unaligned=True, mode="test"),
    batch_size=5,
    shuffle=True,
    num_workers=0,
)


# 定义测试数据喂进网络的输出展示函数
def sample_images(batches_done):
    """Saves a generated sample from the test set"""
    imgs = next(iter(val_dataloader))
    G_AB.eval()
    G_BA.eval()
    real_A = Variable(imgs["A"].type(Tensor))
    fake_B = G_AB(real_A)
    real_B = Variable(imgs["B"].type(Tensor))
    fake_A = G_BA(real_B)
    # Arange images along x-axis
    real_A = make_grid(real_A, nrow=5, normalize=True)
    real_B = make_grid(real_B, nrow=5, normalize=True)
    fake_A = make_grid(fake_A, nrow=5, normalize=True)
    fake_B = make_grid(fake_B, nrow=5, normalize=True)
    # Arange images along y-axis
    image_grid = torch.cat((real_A, fake_B, real_B, fake_A), 1)
    save_image(image_grid, "images/%s/%s.png" % (opt.dataset_name, batches_done), normalize=False)


# ----------
#  Training
#  开始训练
# ----------
if __name__ == '__main__':
    prev_time = time.time()
    for epoch in range(opt.epoch, opt.n_epochs):
        loop = tqdm(dataloader, colour='red', unit='img')
        for i, batch in enumerate(loop):

            # 设置模型输入
            # Set model input
            real_A = Variable(batch["A"].type(Tensor))
            real_B = Variable(batch["B"].type(Tensor))

            # 对抗生成网络中的真实图片和虚假图片
            # Adversarial ground truths
            valid = Variable(Tensor(np.ones((real_A.size(0), *D_A.output_shape))), requires_grad=False)
            fake = Variable(Tensor(np.zeros((real_A.size(0), *D_A.output_shape))), requires_grad=False)

            # ------------------
            #  Train Generators
            #  训练生成器
            # ------------------

            G_AB.train()
            G_BA.train()

            # 梯度清零,方便下代训练
            optimizer_G.zero_grad()

            # Identity loss :
            # 用于保证生成图像的连续性,一个图像x,经过其中一个生成器生成图像 G(x),尽可能与原来图像接近。
            loss_id_A = criterion_identity(G_BA(real_A), real_A)
            loss_id_B = criterion_identity(G_AB(real_B), real_B)

            loss_identity = (loss_id_A + loss_id_B) / 2

            # GAN loss
            fake_B = G_AB(real_A)
            loss_GAN_AB = criterion_GAN(D_B(fake_B), valid)
            fake_A = G_BA(real_B)
            loss_GAN_BA = criterion_GAN(D_A(fake_A), valid)

            loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

            # Cycle loss
            recov_A = G_BA(fake_B)
            loss_cycle_A = criterion_cycle(recov_A, real_A)
            recov_B = G_AB(fake_A)
            loss_cycle_B = criterion_cycle(recov_B, real_B)

            loss_cycle = (loss_cycle_A + loss_cycle_B) / 2

            # 总损失函数
            # Total loss
            loss_G = loss_GAN + opt.lambda_cyc * loss_cycle + opt.lambda_id * loss_identity

            # 反向传播
            loss_G.backward()

            # 权重更新
            optimizer_G.step()

            # -----------------------
            #  Train Discriminator A
            #  训练分类器A
            # -----------------------

            optimizer_D_A.zero_grad()

            # Real loss
            loss_real = criterion_GAN(D_A(real_A), valid)
            # Fake loss (on batch of previously generated samples)
            fake_A_ = fake_A_buffer.push_and_pop(fake_A)
            loss_fake = criterion_GAN(D_A(fake_A_.detach()), fake)
            # Total loss
            loss_D_A = (loss_real + loss_fake) / 2

            loss_D_A.backward()
            optimizer_D_A.step()

            # -----------------------
            #  Train Discriminator B
            #  训练分类器B
            # -----------------------

            optimizer_D_B.zero_grad()

            # Real loss
            loss_real = criterion_GAN(D_B(real_B), valid)
            # Fake loss (on batch of previously generated samples)
            fake_B_ = fake_B_buffer.push_and_pop(fake_B)
            loss_fake = criterion_GAN(D_B(fake_B_.detach()), fake)
            # Total loss
            loss_D_B = (loss_real + loss_fake) / 2

            loss_D_B.backward()
            optimizer_D_B.step()

            # 两个分类器损失之和
            loss_D = (loss_D_A + loss_D_B) / 2

            # --------------
            #  Log Progress
            #  训练进程显示
            # --------------

            # Determine approximate time left
            batches_done = epoch * len(dataloader) + i
            # batches_left = opt.n_epochs * len(dataloader) - batches_done
            # time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
            # prev_time = time.time()

            # # Print log
            # sys.stdout.write(
            #     "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, adv: %f, cycle: %f, identity: %f] ETA: %s"
            #     % (
            #         epoch,
            #         opt.n_epochs,
            #         i,
            #         len(dataloader),
            #         loss_D.item(),
            #         loss_G.item(),
            #         loss_GAN.item(),
            #         loss_cycle.item(),
            #         loss_identity.item(),
            #         time_left,
            #     )
            # )

            # 进度条参数
            loop.set_description(f"Epoch [{epoch}/{opt.n_epochs}] Batch[{i}/{len(dataloader)}]")
            loop.set_postfix(D_loss=loss_D.item(), G_loss=loss_G.item(),
                             loss_GAN=loss_GAN.item(), loss_Cycle=loss_cycle.item(),
                             loss_Identity=loss_identity.item())
            # If at sample interval save image
            if batches_done % opt.sample_interval == 0:
                sample_images(batches_done)

        # Update learning rates
        # 更新学习率
        lr_scheduler_G.step()
        lr_scheduler_D_A.step()
        lr_scheduler_D_B.step()

        if opt.checkpoint_interval != -1 and epoch % opt.checkpoint_interval == 0:
            # Save model checkpoints
            torch.save(G_AB.state_dict(), "saved_models/%s/G_AB_%d.pth" % (opt.dataset_name, epoch))
            torch.save(G_BA.state_dict(), "saved_models/%s/G_BA_%d.pth" % (opt.dataset_name, epoch))
            torch.save(D_A.state_dict(), "saved_models/%s/D_A_%d.pth" % (opt.dataset_name, epoch))
            torch.save(D_B.state_dict(), "saved_models/%s/D_B_%d.pth" % (opt.dataset_name, epoch))

  • 9
    点赞
  • 70
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

travellerss

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值