Pix2Pix (Implementing an Image-to-Image Translation Task from Scratch)

This post implements an image-to-image translation task from scratch with a CGAN. Each sample in the dataset used here is a side-by-side pair: the goal is to generate the colored image on the left from the line-art image of the character on the right (in plain terms, to colorize a black-and-white sketch).

You can download the images from Kaggle via the link provided: anime image dataset.
The downloaded dataset contains some duplicated folders; in the end you only need to keep three of them.

A Brief Introduction to the CGAN Paper

The authors point out that in an ordinary GAN, the generator is expected to learn a mapping from a noise distribution (e.g. a Gaussian or uniform distribution) to the probability distribution we want:
$$\text{random } z, \qquad G : z \rightarrow y$$
In a CGAN, however, the mapping becomes
$$\text{random } z,\ \text{input\_image } x, \qquad G : \{x, z\} \rightarrow y$$
Here input_image refers to the image that contains only the line art.
In practice, the GitHub project I referenced does not feed a random noise z into the Generator. My understanding is that the Generator's job is already to map one distribution onto the distribution we want, so adding an extra noise input does not seem to contribute useful information. The paper's authors also note that adding noise increases the stochasticity of the results only marginally (they instead inject randomness through dropout in the generator); of course, readers are free to add a noise input in their own implementation and compare the results.

The Generator consists of two parts: an encoder that downsamples the input image, and a decoder that upsamples it back up. The authors observe that the input (the line-art image) and the output (the colored image) are structurally aligned; for example, the contours in the two images are identical. They therefore concatenate the encoder's feature maps onto the corresponding layers of the decoder (these are the skip connections drawn as dashed lines in the paper's architecture figure).

What makes the discriminator in this CGAN special is that its input is the concatenation of $x$ and $y$, and its output is no longer a single scalar but a four-dimensional tensor of real/fake scores, which the paper's authors call a patch.
That covers the model design. The only change in our implementation is that BatchNorm is replaced with InstanceNorm.
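
InstanceNorm2d is a drop-in replacement here: it normalizes each sample's feature map over the spatial dimensions only, so statistics are not shared across the batch. A minimal sketch (not from the original code) showing that the two layers are interchangeable shape-wise:

import torch
from torch import nn

x = torch.randn(4, 64, 32, 32)     #[N, C, H, W]
bn = nn.BatchNorm2d(64)            #normalizes each channel over (N, H, W)
inorm = nn.InstanceNorm2d(64)      #normalizes each channel over (H, W), per sample
print(bn(x).shape, inorm(x).shape) #both keep the shape: torch.Size([4, 64, 32, 32])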

Model Design

#Generator
import torch
from torch import nn
class CNNBlock(nn.Module):
	def __init__(self, in_channels, out_channels, relu=False, down=True, use_dropout=False):
		#in_channels: number of input channels
		#out_channels: number of output channels
		#relu: use ReLU if True, otherwise LeakyReLU
		#down: downsample (Conv2d) if True, otherwise upsample (ConvTranspose2d)
		#use_dropout: whether to apply dropout
		super().__init__()
		self.block = nn.Sequential(
			nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=False, padding_mode='reflect')
			if down else
			nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False),  #ConvTranspose2d only supports zero padding
			nn.InstanceNorm2d(out_channels),
			nn.Dropout(0.5) if use_dropout else nn.Identity(),
			nn.ReLU() if relu else nn.LeakyReLU(0.2)
		)

	def forward(self, x):
		return self.block(x)
class Generator(nn.Module):
	def __init__(self, img_channels=3, feature=64):
		super().__init__()
		#the first encoder layer has no normalization
		self.down1 = nn.Sequential(
			nn.Conv2d(img_channels, feature, 4, 2, 1, padding_mode='reflect'),
			nn.LeakyReLU(0.2)
		)
		self.down2 = CNNBlock(feature, feature * 2)
		self.down3 = CNNBlock(feature * 2, feature * 4)
		self.down4 = CNNBlock(feature * 4, feature * 8)
		self.down5 = CNNBlock(feature * 8, feature * 8)	
		self.down6 = CNNBlock(feature * 8, feature * 8)	
		self.down7 = CNNBlock(feature * 8, feature * 8)	
		self.bottleneck = nn.Sequential(
			nn.Conv2d(feature * 8, feature * 8, 4, 2, 1, padding_mode='reflect'),
			nn.ReLU()
		)
		#decoder; only the first three up-blocks use dropout
		self.up1 = CNNBlock(feature * 8, feature * 8, relu=True, down=False, use_dropout=True)
		self.up2 = CNNBlock(feature * 16, feature * 8, relu=True, down=False, use_dropout=True)
		self.up3 = CNNBlock(feature * 16, feature * 8, relu=True, down=False, use_dropout=True)	
		self.up4 = CNNBlock(feature * 16, feature * 8, relu=True, down=False)	
		self.up5 = CNNBlock(feature * 16, feature * 4, relu=True, down=False)	
		self.up6 = CNNBlock(feature * 8, feature * 2, relu=True, down=False)	
		self.up7 = CNNBlock(feature * 4, feature, relu=True, down=False)
		#the last layer uses Tanh as the activation
		self.final_conv = nn.Sequential(
			nn.ConvTranspose2d(feature * 2, img_channels, 4, 2, 1),
			nn.Tanh()
		)
	def forward(self, x):
	#x[N, img_channels, 256, 256]
		down1 = self.down1(x)#[N, 64, 128, 128]
		down2 = self.down2(down1)#[N, 128, 64, 64]
		down3 = self.down3(down2)#[N, 256, 32, 32]
		down4 = self.down4(down3)#[N, 512, 16, 16]
		down5 = self.down5(down4)#[N, 512, 8, 8]
		down6 = self.down6(down5)#[N, 512, 4, 4]
		down7 = self.down7(down6)#[N, 512, 2, 2]
		bottleneck = self.bottleneck(down7)#[N, 512, 1, 1]
		up1 = self.up1(bottleneck) #[N, 512, 2, 2]
		up2 = self.up2(torch.cat([up1, down7], 1))#[N, 512, 4, 4]
		up3 = self.up3(torch.cat([up2, down6], 1))#[N, 512, 8, 8]
		up4 = self.up4(torch.cat([up3, down5], 1))#[N, 512, 16, 16]
		up5 = self.up5(torch.cat([up4, down4], 1))#[N, 256, 32, 32]
		up6 = self.up6(torch.cat([up5, down3], 1))#[N, 128, 64, 64]
		up7 = self.up7(torch.cat([up6, down2], 1))#[N, 64, 128, 128]
		final_conv = self.final_conv(torch.cat([up7, down1], 1))#[N, 3, 256, 256]
		return final_conv
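
A quick shape check, not part of the original post, to verify that the Generator returns an image at the same resolution as its input:

import torch

gen = Generator(img_channels=3, feature=64)
x = torch.randn(2, 3, 256, 256)  #a dummy batch of line-art images
with torch.no_grad():
	y = gen(x)
print(y.shape)  #expected: torch.Size([2, 3, 256, 256])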

The Discriminator is just a simple stack of downsampling layers.

import torch
from torch import nn
class CNNBlock(nn.Module):
	def __init__(self, in_channels, out_channels, stride):
		super().__init__()
		self.conv = nn.Sequential(
			nn.Conv2d(in_channels, out_channels, 4, stride, 1, bias=False, padding_mode='reflect'),
			nn.InstanceNorm2d(out_channels),
			nn.LeakyReLU(0.2)
		)

	def forward(self, x):
		return self.conv(x)
class Discriminator(nn.Module):
	def __init__(self, img_channels, features=[64, 128, 256, 512]):
		super().__init__()
		#the first layer sees x and y concatenated along the channel dimension
		self.init_conv = nn.Sequential(
			nn.Conv2d(img_channels * 2, features[0], 4, 2, 1, padding_mode='reflect'),
			nn.LeakyReLU(0.2)
		)
		layers = []
		in_channels = features[0]
		for feature in features[1:]:
			layers.append(CNNBlock(in_channels, feature, stride=1 if feature == features[-1] else 2))
			in_channels = feature
		#map to a single-channel patch of real/fake scores
		layers.append(nn.Conv2d(in_channels, 1, 4, 1, 1, padding_mode='reflect'))
		self.model = nn.Sequential(*layers)
	
	def forward(self, x, y):
		x = torch.cat([x, y], 1)
		x = self.init_conv(x)
		return self.model(x)
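
A similar sanity check, again not from the original post, confirming that the discriminator outputs a patch of scores rather than a single scalar (with 256x256 inputs and the layers above, the patch is 30x30):

import torch

disc = Discriminator(img_channels=3)
x = torch.randn(1, 3, 256, 256)  #line-art image
y = torch.randn(1, 3, 256, 256)  #colored image (real or generated)
with torch.no_grad():
	out = disc(x, y)
print(out.shape)  #expected: torch.Size([1, 1, 30, 30])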

Preparing the Dataset

Once the dataset is downloaded, we need to wrap it in a Dataset class so that it can be loaded with a DataLoader.

#config.py
#data augmentation is defined here
import albumentations as A
from albumentations.pytorch import ToTensorV2
#transforms applied to both images
both_transforms = A.Compose(
	[A.Resize(width=256, height=256), A.HorizontalFlip(p=0.5)],  #resize; flip horizontally with probability 0.5
	additional_targets={"image0": "image"}
)
transform_only_input = A.Compose([
	A.ColorJitter(p=0.1),
	A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.),
	ToTensorV2()
])
transform_only_mask = A.Compose([
	A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.),
	ToTensorV2()
])
import torch
from torch.utils.data import Dataset
import config
import numpy as np
import os
from PIL import Image
class AnimeDataset(Dataset):
	def __init__(self, root_dir):
		self.root_dir = root_dir
		self.paths = os.listdir(root_dir)
	def __len__(self):
		return len(self.paths)
	def __getitem__(self, index):
		img_path = os.path.join(self.root_dir, self.paths[index])
		img = np.array(Image.open(img_path))
		#each image is 1024 pixels wide: the right 512 columns are the input, the left 512 the target
		input_image = img[:, 512:, :]
		target_image = img[:, :512, :]
		augmentations = config.both_transforms(image=input_image, image0=target_image)
		input_image, target_image = augmentations['image'], augmentations['image0']
		input_image = config.transform_only_input(image=input_image)['image']
		target_image = config.transform_only_mask(image=target_image)['image']
		return input_image, target_image
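
A short sketch, not in the original post, for verifying that the Dataset returns correctly shaped pairs (the directory path matches the one used later in the training script):

from torch.utils.data import DataLoader

train_set = AnimeDataset(root_dir='../data/anime/train')
loader = DataLoader(train_set, batch_size=4, shuffle=True)
x, y = next(iter(loader))
print(x.shape, y.shape)  #expected: torch.Size([4, 3, 256, 256]) for both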

Utilities Used During Training

config.py is just a configuration file that stores the hyperparameters and the data-augmentation code.

#config.py
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
DEVICE = "cuda" if torch.cuda.is_available() else 'cpu'

LEARNING_RATE = 2e-4
BATCH_SIZE = 16
NUM_WORKERS = 2
IMAGE_SIZE = 256
CHANNELS_IMG = 3
L1_LAMBDA = 100
NUM_EPOCHS = 500
LOAD_MODEL = True
SAVE_MODEL = True
CHECKPOINT_DISC = "disc.pth.tar"
CHECKPOINT_GEN = "gen.pth.tar"
both_transforms = A.Compose(
    [A.Resize(width=IMAGE_SIZE, height=IMAGE_SIZE), A.HorizontalFlip(p=0.5)], additional_targets={"image0":"image"}
)
transform_only_input = A.Compose(
    [
        A.ColorJitter(p=0.1),
        A.Normalize(mean=[0.5, 0.5, 0.5],std=[0.5, 0.5, 0.5], max_pixel_value=255.0),
        ToTensorV2()
    ]
)
transform_only_mask = A.Compose(
    [
        A.Normalize(mean=[0.5, 0.5, 0.5],std=[0.5, 0.5, 0.5], max_pixel_value=255.0),
        ToTensorV2()
    ]
)

utils.py collects a few helper functions used during training.

#utils.py
import torch
import config
from torchvision.utils import save_image

def save_some_examples(gen, val_loader, epoch, folder):
	x, y = next(iter(val_loader))
	x, y = x.to(config.DEVICE), y.to(config.DEVICE)
	gen.eval()
	with torch.no_grad():
		y_fake = gen(x)
		y_fake = y_fake * 0.5 + 0.5  #undo the normalization back to [0, 1]
		save_image(y_fake, folder + f"/y_gen_{epoch}.png")
		save_image(x * 0.5 + 0.5, folder + f"/input_{epoch}.png")
		if epoch == 1:
			save_image(y * 0.5 + 0.5, folder + f"/label_{epoch}.png")
	gen.train()

def save_checkpoint(model, optimizer, filename):
	checkpoint = {
	"model":model.state_dict(),
	"optimizer":optimizer.state_dict()
	}
	torch.save(checkpoint, filename)

def load_checkpoint(model, optimizer, filename, lr):
	state_dict = torch.load(filename, map_location=config.DEVICE)
	model.load_state_dict(state_dict['model'])
	optimizer.load_state_dict(state_dict['optimizer'])
	for param_group in optimizer.param_groups:
		param_group['lr'] = lr
		

Training

import torch
from torch import optim, nn
from generator import Generator
from discriminator import Discriminator
from dataset import AnimeDataset
from torch.utils.data import DataLoader
from utils import *
import config
from tqdm import tqdm

disc = Discriminator(img_channels=config.CHANNELS_IMG).to(config.DEVICE)
gen = Generator(img_channels=config.CHANNELS_IMG).to(config.DEVICE)
opt_disc = optim.Adam(disc.parameters(), lr=config.LEARNING_RATE, betas=(0.5,0.999))
opt_gen = optim.Adam(gen.parameters(), lr=config.LEARNING_RATE, betas=(0.5,0.999))
train_set = AnimeDataset(root_dir='../data/anime/train')
train_loader = DataLoader(train_set, batch_size=config.BATCH_SIZE, shuffle=True, num_workers=config.NUM_WORKERS)
val_set = AnimeDataset(root_dir='../data/anime/val')
val_loader = DataLoader(val_set, batch_size=1, shuffle=False)

bce = nn.BCEWithLogitsLoss()  #adversarial loss; the standard GAN objective corresponding to a JS-divergence criterion
L1_loss = nn.L1Loss()  #L1 term that keeps the output close to the target
if config.LOAD_MODEL:
	load_checkpoint(gen, opt_gen, config.CHECKPOINT_GEN, config.LEARNING_RATE)
	load_checkpoint(disc, opt_disc, config.CHECKPOINT_DISC, config.LEARNING_RATE)

for epoch in range(1, config.NUM_EPOCHS + 1):
	loop = tqdm(train_loader, leave=True)  #progress bar
	for x, y in loop:
		x, y = x.to(config.DEVICE), y.to(config.DEVICE)
		y_fake = gen(x)

		#train the discriminator: real pairs -> 1, fake pairs -> 0
		disc_real = disc(x, y).reshape(-1)
		disc_fake = disc(x, y_fake.detach()).reshape(-1)
		lossD_real = bce(disc_real, torch.ones_like(disc_real))
		lossD_fake = bce(disc_fake, torch.zeros_like(disc_fake))
		lossD = lossD_real + lossD_fake
		opt_disc.zero_grad()
		lossD.backward()
		opt_disc.step()

		#train the generator: fool the discriminator while staying close to the target (L1)
		D_fake = disc(x, y_fake).reshape(-1)
		G_fake_loss = bce(D_fake, torch.ones_like(D_fake))
		L1 = L1_loss(y_fake, y) * config.L1_LAMBDA
		G_loss = G_fake_loss + L1
		opt_gen.zero_grad()
		G_loss.backward()
		opt_gen.step()

	if config.SAVE_MODEL and epoch % 5 == 0:
		save_checkpoint(gen, opt_gen, filename=config.CHECKPOINT_GEN)
		save_checkpoint(disc, opt_disc, filename=config.CHECKPOINT_DISC)
		save_some_examples(gen, val_loader, epoch, folder='evaluation')
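
Once training has finished, the generator can be used on its own to colorize new line-art images. Below is a minimal inference sketch, not part of the original post; the input file name is a placeholder, and the plain normalization transform (transform_only_mask, which has no color jitter) is reused to preprocess the sketch:

#inference.py (sketch)
import torch
import numpy as np
import albumentations as A
from PIL import Image
from torch import optim
from torchvision.utils import save_image
import config
from generator import Generator
from utils import load_checkpoint

gen = Generator(img_channels=config.CHANNELS_IMG).to(config.DEVICE)
opt_gen = optim.Adam(gen.parameters(), lr=config.LEARNING_RATE, betas=(0.5, 0.999))
load_checkpoint(gen, opt_gen, config.CHECKPOINT_GEN, config.LEARNING_RATE)

#'sketch.png' is a placeholder path for a line-art image
img = np.array(Image.open('sketch.png').convert('RGB'))
img = A.Resize(width=config.IMAGE_SIZE, height=config.IMAGE_SIZE)(image=img)['image']
x = config.transform_only_mask(image=img)['image'].unsqueeze(0).to(config.DEVICE)

gen.eval()
with torch.no_grad():
	y_fake = gen(x) * 0.5 + 0.5  #undo normalization back to [0, 1]
save_image(y_fake, 'colorized.png')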