摘要
实验是使用CycleGAN网络,实现的一种图片对图片的变化,输入原始图片,即可得到变换后的输出图片。网络主体由四部分组成:
G_A2B:生成器1,输入为A类图片(马的图片),输出为B类图片(斑马图片)。real_A —> fake_B
G_B2A:生成器2,输入为B类图片(斑马图片),输出为A类图片(马的图片)。real_B —> fake_A
D_A:判别器1,输入为real_A或fake_A,输出是0或1。
D_B:判别器2,输入为real_B或fake_B,输出是0或1。
使生成器和判别器在训练过程中不断对抗,生成网络G的目标就是尽量生成真实的图片去欺骗判别网络D,而D的目标就是尽量辨别出G生成的假图像和真实的图像。这样,G和D构成了一个动态的“博弈过程”,最终的平衡点即纳什均衡点。最终保存的生成器即可实现图对图的转换。
源代码在此Pytorch-CycleGAN
文章目录
前言
此网络可以使输入图片的部分区域产生变化,即原图中马的部分变成斑马;也可以使图片整体发生变化,如图片整体风格变化,使真实图片转化为油画风格、不同季节的转变等。
使用的显卡为8G显存的GTX2070S
一、数据 load
1.图片数据保存格式:
可以通过设置上述目录结构来构建自己的数据集(和ImageDataset函数相匹配)。
本次使用的是horse2zebra数据集,共有两类图片A、B,A为马,B为斑马,图像大小为256×256,RGB三通道。
2. 图像预处理
预处理封装在了transforms_列表中,包括调整大小、裁剪、翻转、归一化。
transforms_ = [transforms.Resize(int(256 * 1.12), Image.BICUBIC), # 调整输入图片的大小
transforms.RandomCrop(256), # 随机裁剪
transforms.RandomHorizontalFlip(), # 随机水平翻转图像
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
# 归一化,这两行不能颠倒顺序呢,归一化需要用到tensor型
]
产生迭代器,这里batch_size选取的是2,因为显卡显存限制。
mydata = ImageDataset(opt.dataroot, transforms_=transforms_, unaligned=True)
dataloader = DataLoader(mydata, batch_size=opt.batchSize, shuffle=True, num_workers=opt.n_cpu) # 两个数据集均打乱
二、网络结构
1.生成器网络结构
生成器的主体是基于残差网络,由残差块构成,因为最后需要输出一幅图片,所以需要反卷积上采样,增大特征图尺寸,具体组成如下图所示:
2.判别器网络结构
判别器就是由数个卷积块构成,最后输出0或1。
二、迭代内部过程
1.生成器和判别器
迭代过程中对于从loader中载入的A图和B图:
1.第一步如下图
2.第二部继续使用上一步生成的Fake_B和Fake_A,分别通入对应的生成器,如下图所示。
2.损失构成
1.生成器损失
生成器损失函数:损失函数=自身一致性损失(绿色)+对抗损失(红色)+循环一致性损失(蓝色)
下图的对抗损失因为我们想让他能尽量骗过判别器,所以期望的输出是“1”
上面三种损失都需要double,因为有A的就得有B的。所以总损失:
loss_G=(A的自身一致性损失+A的对抗损失+A的循环一致性损失)+(B的自身一致性损失+B的对抗损失+B的循环一致性损失)
# 对两个生成器所有的参数进行反向传播更新
loss_G.backward()
optimizer_G.step()
2.判别器损失
判别器的损失即是对真假图片的判别能力,如下图。
判别器A的损失同上类似。
三、附加内容
1.argparse模块
2.pytorch的可视化工具:visdom
可视化工具,可以将训练过程中的图片文本和loss和acc等绘制。