CycleGAN(循环生成对抗网络)是一种用于图像到图像转换的深度学习技术,它能够在没有成对训练样本的情况下,将一种风格的图像转换成另一种风格。CycleGAN 通常用于图像风格迁移、季节转换、艺术风格模仿等任务。它是由朱俊彦等人提出的,并在论文《Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks》中进行了详细描述。
CycleGAN 的工作原理如下:
生成器(Generator):CycleGAN 有两个生成器,分别用于将图像从源域转换到目标域,以及将图像从目标域转换回源域。生成器的目标是生成足够真实的图像,以欺骗判别器。
判别器(Discriminator):CycleGAN 也有两个判别器,分别用于判断图像是否属于源域或目标域。判别器的目标是能够准确地区分真实图像和生成器生成的假图像。
循环一致性损失(Cycle Consistency Loss):这是 CycleGAN 的核心概念之一。循环一致性损失确保图像经过一次完整的转换循环(即从源域转换到目标域,再从目标域转换回源域)后,能够回到原始图像。这种损失函数有助于在没有成对数据的情况下保持图像内容的一致性。
身份损失(Identity Loss):这个损失函数确保当源域图像输入到目标域的生成器时,输出图像应该与输入图像尽可能相似。这样可以避免生成器在转换过程中改变不应该改变的内容。
对抗损失(Adversarial Loss):这是生成对抗网络(GAN)中常用的损失函数,用于让生成器生成的图像更接近于真实图像。
通过最小化这些损失函数,CycleGAN 能够学习到源域和目标域之间的映射关系,从而实现图像风格的转换。
CycleGAN 的应用非常广泛,例如:
将普通照片转换为印象派风格的画作。
将夏天的风景照片转换为冬天的风景。
将马的图像转换为斑马的图像。
由于CycleGAN 不需要成对的训练数据,它为图像风格转换提供了一种灵活且强大的工具。
from random import randint
import numpy as np
import torch
torch.set_default_tensor_type(torch.FloatTensor)
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import os
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.utils import save_image
import shutil
import cv2
import random
from PIL import Image
import itertools
导入代码中需要的各种库和模块,包括PyTorch及其相关库,以及Python的标准库和其他图像处理库。
def to_img(x):
# 将归一化后的图像转换为0-1范围
out = 0.5 * (x + 1)
out = out.clamp(0, 1)
# 改变图像的形状以适应显示和保存
out = out.view(-1, 3, 256, 256)
return out
将归一化后的图像数据转换回0到1的范围,并调整其形状以适应图像显示和保存。
数据加载
data_path = os.path.abspath('../data')
image_size = 256
batch_size = 1
transform = transforms.Compose([transforms.Resize(int(image_size * 1.12),
Image.BICUBIC),
transforms.RandomCrop(image_size),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))])
def _get_train_data(batch_size=1):
train_a_filepath = data_path + '\\trainA\\'
train_b_filepath = data_path + '\\trainB\\'
train_a_list = os.listdir(train_a_filepath)
train_b_list = os.listdir(train_b_filepath)
train_a_result = []
train_b_result = []
numlist = random.sample(range(0, len(train_a_list)), batch_size)
for i in numlist:
a_filename = train_a_list[i]
a_img = Image.open(train_a_filepath + a_filename).convert('RGB')
res_a_img = transform(a_img)
train_a_result.append(torch.unsqueeze(res_a_img, 0))
b_filename = train_b_list[i]
b_img = Image.open(train_b_filepath + b_filename).convert('RGB')
res_b_img = transform(b_img)
train_b_result.append(torch.unsqueeze(res_b_img, 0)