(5) CycleGAN Paper Notes and Hands-On Practice

1. CycleGAN Architecture and Objective Function

[Figure: CycleGAN architecture with generators G, F and discriminators D_X, D_Y]
CycleGAN has two generators and two discriminators, and its core idea is cycle consistency: the original input $x$ is passed through generator $G$ to obtain the image $\hat{Y}$, which is then passed through generator $F$ to obtain $\hat{x}$; the aim is to make $x \approx \hat{x}$ as closely as possible. This leads to the cycle-consistency loss, which has two parts:

  1. Forward cycle-consistency loss: $x \rightarrow G(x) \rightarrow F(G(x)) \approx x$
  2. Backward cycle-consistency loss: $y \rightarrow F(y) \rightarrow G(F(y)) \approx y$

Expressed as a formula:

        $L_{cyc}(G, F) = \mathbb{E}_{x \sim p_{data}(x)}[\| F(G(x)) - x \|_1] + \mathbb{E}_{y \sim p_{data}(y)}[\| G(F(y)) - y \|_1]$

Note: the 1-norm comes in a vector version and a matrix version. The vector 1-norm is the sum of the absolute values of the vector's entries:

        $\| x \|_1 = \sum_{i=1}^n |x_i|$

The matrix 1-norm, also called the column-sum norm, is the largest of the column-wise sums of absolute values:

        $\| X \|_1 = \max_j \sum_{i=1}^n |a_{i,j}|$
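A minimal PyTorch sketch of this loss, assuming `G` and `F` are the two generators and `x`, `y` are batches from the two domains (the complete code below does the same thing with `criterion_cycle = torch.nn.L1Loss()`):

```python
import torch
import torch.nn as nn

l1 = nn.L1Loss()  # mean absolute error; the ||.||_1 above, up to averaging

def cycle_consistency_loss(G, F, x, y):
    # forward cycle:  x -> G(x) -> F(G(x)) should reconstruct x
    # backward cycle: y -> F(y) -> G(F(y)) should reconstruct y
    return l1(F(G(x)), x) + l1(G(F(y)), y)
```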

There is also the adversarial loss between each generator and its discriminator, just as in an ordinary GAN:

i) For the GAN formed by $G$ and $D_Y$, the loss function is:

        $L_{GAN}(G, D_Y, X, Y) = \mathbb{E}_{y \sim p_{data}(y)}[\log D_Y(y)] + \mathbb{E}_{x \sim p_{data}(x)}[\log(1 - D_Y(G(x)))]$

ii) For the GAN formed by $F$ and $D_X$, the loss function is:

        $L_{GAN}(F, D_X, Y, X) = \mathbb{E}_{x \sim p_{data}(x)}[\log D_X(x)] + \mathbb{E}_{y \sim p_{data}(y)}[\log(1 - D_X(F(y)))]$

However, because a vanilla GAN produces images of limited quality and is unstable to train, CycleGAN avoids both problems by adopting the objective of the Least Squares GAN (LSGAN), i.e., replacing the $\log$-likelihood with a squared-error loss. The formulas are:
        $L_{LSGAN}(G, D_Y, X, Y) = \mathbb{E}_{y \sim p_{data}(y)}[(D_Y(y) - 1)^2] + \mathbb{E}_{x \sim p_{data}(x)}[D_Y(G(x))^2]$

        $L_{LSGAN}(F, D_X, Y, X) = \mathbb{E}_{x \sim p_{data}(x)}[(D_X(x) - 1)^2] + \mathbb{E}_{y \sim p_{data}(y)}[D_X(F(y))^2]$
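A hedged sketch of these objectives in PyTorch, assuming `D` is a discriminator and `real`, `fake` are batches of real and generated images (the complete code below realizes the same idea with `criterion_GAN = torch.nn.MSELoss()`):

```python
import torch
import torch.nn as nn

mse = nn.MSELoss()  # squared error in place of the vanilla GAN's log-likelihood

def lsgan_d_loss(D, real, fake):
    # the discriminator pushes D(real) toward 1 and D(fake) toward 0
    out_real, out_fake = D(real), D(fake.detach())
    return 0.5 * (mse(out_real, torch.ones_like(out_real))
                  + mse(out_fake, torch.zeros_like(out_fake)))

def lsgan_g_loss(D, fake):
    # the generator pushes D(fake) toward 1
    out_fake = D(fake)
    return mse(out_fake, torch.ones_like(out_fake))
```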

The total objective is then:

        $L(G, F, D_X, D_Y) = L_{GAN}(G, D_Y, X, Y) + L_{GAN}(F, D_X, Y, X) + \lambda L_{cyc}(G, F)$

Our goal is then to solve:

        $G^*, F^* = \arg\min_{G,F} \max_{D_X, D_Y} L(G, F, D_X, D_Y)$
Note: our model can be viewed as training two "autoencoders": we learn one autoencoder F∘G : X → X jointly with another G∘F : Y → Y. Each of these autoencoders, however, has a special internal structure: it maps an image to itself via an intermediate representation that lives in the other domain. This setup can also be seen as a special case of the "adversarial autoencoder", which uses an adversarial loss to train the bottleneck layer of an autoencoder to match an arbitrary target distribution. In our case, the target distribution for the X → X autoencoder is the distribution of domain Y.

2. Training Details

(1) To reduce model oscillation, the strategy of Shrivastava et al. is followed: the discriminators are updated using a history of generated images rather than only the latest generator outputs, i.e., an image buffer is kept that stores the 50 previously created images.

(2) The objective replaces the vanilla GAN loss with the LSGAN squared-error loss.

(3) The generator uses residual blocks to better preserve image semantics; the discriminator is a PatchGAN.

(4) Instance Normalization (IN) replaces Batch Normalization (BN); see the sketch below.
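A minimal sketch of the difference, with illustrative tensor sizes: `BatchNorm2d` normalizes each channel with statistics pooled across the whole batch, while `InstanceNorm2d` normalizes each image independently, which matters here since training runs with `batch_size = 1`.

```python
import torch
import torch.nn as nn

x = torch.randn(4, 3, 256, 256)      # a batch of 4 three-channel images
bn = nn.BatchNorm2d(3)               # per-channel statistics shared across the batch
inorm = nn.InstanceNorm2d(3)         # per-channel statistics computed per image
print(bn(x).shape, inorm(x).shape)   # both preserve the shape: [4, 3, 256, 256]
```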

3. Complete Code

import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import itertools
import datetime
import time
import glob
import os
import random
import numpy as np
import torchvision.transforms as transforms
from torchvision.utils import make_grid, save_image
from torch.autograd import Variable
from torchvision import datasets
from torch.utils.data import DataLoader, Dataset
from PIL import Image
from torch.utils.tensorboard import SummaryWriter

parser = argparse.ArgumentParser()
parser.add_argument('--epoch', type=int, default=0, help='epoch to start training from')
parser.add_argument('--n_epochs', type=int, default=200, help='number of epochs of training')
parser.add_argument('--dataset_name', type=str, default='monet2photo', help='name of the dataset')
parser.add_argument('--batch_size', type=int, default=1, help='size of the batches')
parser.add_argument('--lr', type=float, default=0.0002, help='learning rate')
parser.add_argument('--b1', type=float, default=0.5, help='adam: decay of first order momentum of gradient')
parser.add_argument('--b2', type=float, default=0.999, help='adam: decay of second order momentum of gradient')
parser.add_argument('--decay_epoch', type=int, default=100, help='epoch from which to start lr decay')
parser.add_argument('--img_height', type=int, default=256, help='the height of image')
parser.add_argument('--img_width', type=int, default=256, help='the width of image')
parser.add_argument('--channels', type=int, default=3, help='number of image channels')
parser.add_argument('--sample_interval', type=int, default=100, help='interval between saving generator outputs')
parser.add_argument('--checkpoint_interval', type=int, default=1, help='interval between saving model checkpoints')
parser.add_argument('--n_residual_blocks', type=int, default=9, help='number of residual blocks in generator')
parser.add_argument('--lambda_cyc', type=float, default=10.0, help='cycle loss weight')
parser.add_argument('--lambda_id', type=float, default=5.0, help='identity loss weight')

opt = parser.parse_args(args = [])
print(opt)

random.seed(22)
torch.manual_seed(22)
os.makedirs(name = 'Picture/cycleGAN', exist_ok=True)
os.makedirs(name = 'Model/cycleGAN', exist_ok=True)
os.makedirs(name = 'runs/cycleGAN', exist_ok=True)
cuda = True if torch.cuda.is_available() else False
input_shape = (opt.channels, opt.img_height, opt.img_width)
'''Dataset class'''
# Convert an image to RGB mode
def to_rgb(image):
    rgb_image = Image.new(mode='RGB', size=image.size)
    rgb_image.paste(image)
    return rgb_image
    
class ImageDataset(Dataset):
    def __init__(self, root, transforms_ = None, unaligned = False, mode = 'train'):
        self.transform = transforms.Compose(transforms_)
        self.unaligned = unaligned
        
        self.file_A = sorted(glob.glob(os.path.join(root, '{}A'.format(mode)) + '/*.*'))
        self.file_B = sorted(glob.glob(os.path.join(root, '{}B'.format(mode)) + '/*.*'))
        
    def __getitem__(self, index):
        # take index modulo len(self.file_A) / len(self.file_B): trainA and trainB
        # contain different numbers of images, and without the modulo the index
        # would run past the end of the shorter list
        image_A = Image.open(self.file_A[index % len(self.file_A)])
        if self.unaligned:
            image_B = Image.open(self.file_B[np.random.randint(0, len(self.file_B))])
        else:
            image_B = Image.open(self.file_B[index % len(self.file_B)])
        # Convert grayscale images to RGB, in case some images are not RGB
        if image_A.mode != 'RGB':
            image_A = to_rgb(image_A)
        if image_B.mode != 'RGB':
            image_B = to_rgb(image_B)
        image_A = self.transform(image_A)
        image_B = self.transform(image_B)
        
        return {'A':image_A, 'B':image_B}
    def __len__(self):
        return max(len(self.file_A), len(self.file_B))
    
'''Image replay buffer class'''
class ReplayBuffer:
    def __init__(self, max_size = 50):
        assert max_size > 0, "Empty buffer or trying to create a black hole. Be careful."
        self.max_size = max_size
        self.buffer = [] # list that holds the buffered images

    def push_pop(self, images): # images are the latest images produced by the generator
        to_return = [] # images handed back to the discriminator for training
        for img in images:
            img = torch.unsqueeze(img, 0)
            if len(self.buffer) < self.max_size: # while the buffer holds fewer than 50 images, store the new image and also return it
                self.buffer.append(img)
                to_return.append(img)
            else:
                # by 50% chance, the buffer will return a previously stored image, and insert the current image into the buffer (authors' comment)
                if random.uniform(0, 1) > 0.5:
                    i = random.randint(0, self.max_size-1)
                    to_return.append(self.buffer[i].clone())
                    self.buffer[i] = img
                else: # by another 50% chance, the buffer will return the current image (authors' comment)
                    to_return.append(img)

        return Variable(torch.cat(to_return, 0)) # collect all the images and return
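# Usage sketch (hypothetical shapes): push a freshly generated batch and get back a
# same-sized batch that mixes new fakes with previously stored ones, e.g.
#   buf = ReplayBuffer(max_size=50)
#   mixed = buf.push_pop(torch.randn(2, 3, 256, 256))  # mixed.shape == (2, 3, 256, 256)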
    
'''Custom learning-rate schedule class'''
class LambdaLR: # not overriding a library class, just a plain helper class
    def __init__(self, n_epochs, offset, decay_start_epoch):
        '''parameters:
                 n_epochs: total number of training epochs
                 offset:   epoch from which training starts
                 decay_start_epoch: epoch from which to start lr decay
        '''
        assert decay_start_epoch < n_epochs, 'Decay must start before the training session ends!'
        self.n_epochs = n_epochs
        self.offset = offset
        self.decay_start_epoch = decay_start_epoch

    def step(self, epoch): # epoch is the current epoch, passed in by the scheduler; returns a multiplier on the initial learning rate
        return 1.0 - max(0, epoch + self.offset - self.decay_start_epoch) / (self.n_epochs - self.decay_start_epoch) # the multiplier decays linearly
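# Example: with n_epochs=200, offset=0, decay_start_epoch=100 (the defaults used below),
# step(0) -> 1.0, step(100) -> 1.0, step(150) -> 0.5, step(200) -> 0.0,
# i.e. the learning rate is constant for the first 100 epochs, then decays linearly to 0.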

'''Load the datasets'''
transform = [
    transforms.Resize(int(opt.img_height * 1.12), Image.BICUBIC),
    transforms.RandomCrop(size=(opt.img_height, opt.img_width)),
    transforms.RandomHorizontalFlip(), # horizontally flips the PIL image with probability p (default p = 0.5)
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
]

# Training data loader
dataloader = DataLoader(ImageDataset('../dataset/{}'.format(opt.dataset_name), transforms_=transform, 
                                       unaligned=True, mode='train'),
                                       batch_size = opt.batch_size,
                                       shuffle=True,
                                       num_workers=0)
# Testing data loader
val_dataloader = DataLoader(ImageDataset('../dataset/{}'.format(opt.dataset_name), transforms_=transform, 
                                       unaligned=True, mode='test'),
                                       batch_size=5,
                                       shuffle=True,
                                       num_workers=0)        
'''Network architecture'''
def weights_init_normal(m):
    classname = m.__class__.__name__ # get the class name of m
    if classname.find('Conv') != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
        if hasattr(m, 'bias') and m.bias is not None: # hasattr(): check whether the module has a 'bias' attribute
            torch.nn.init.constant_(m.bias.data, 0.0)
    elif classname.find('BatchNorm2d') != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)
    
##############################
#           RESNET
##############################
'''
Rk denotes a residual block that contains two 3 × 3 convolutional layers with the same number of filters
on both layers.
'''
class ResidualBlock(nn.Module):
    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()
        self.block = nn.Sequential(
            nn.ReflectionPad2d(1), # Reflection padding was used to reduce artifacts.
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features),
            
            nn.ReLU(True),
            
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features)
        )
    def forward(self, x):
        return F.relu(x+self.block(x), True)

##############################
#           Generator
##############################
'''
The network with 9 residual blocks consists of:c7s1-64,d128,d256,R256,R256,R256,R256,R256,R256,R256,R256,
R256,u128, u64,c7s1-3
'''
class Generator(nn.Module):
    def __init__(self, input_shape, num_residual_blocks):
        super(Generator, self).__init__()
        channels = input_shape[0]
        # first layer : c7s1-64 Let c7s1-k denote a 7×7 Convolution-InstanceNorm-ReLU layer with k filters and 
        # stride 1.
        self.c7s1_64 = nn.Sequential( # input shape [b, 3, 256, 256]
            nn.ReflectionPad2d(channels), # --> [b, 3, 262, 262]
            nn.Conv2d(channels, 64, 7), # --> [b, 64, 256, 256]
            nn.InstanceNorm2d(64),
            nn.ReLU(True),
        )
        
        # Downsampling 
        # dk denotes a 3 × 3 Convolution-InstanceNorm-ReLU layer with k filters and stride 2.
        self.dk = nn.Sequential(
            nn.Conv2d(64, 128, 3, 2, 1), # --> [b, 128, 128, 128]
            nn.InstanceNorm2d(128),
            nn.ReLU(True),
            
            nn.Conv2d(128, 256, 3, 2, 1), # --> [b, 256, 64, 64]
            nn.InstanceNorm2d(256),
            nn.ReLU(True),
        )
        
        # Residual blocks
        # Rk denotes a residual block that contains two 3 × 3 convolutional layers with the same number of filters
        # on both layers.
        model = []
        for _ in range(num_residual_blocks): # input: [b, 256, 64, 64]
            model += [ResidualBlock(256)]
        self.rk = nn.Sequential(*model) # output: [b, 256, 64, 64]
        
        # Upsampling
        # uk denotes a 3 × 3 fractional-strided-Convolution-InstanceNorm-ReLU layer with k filters and stride 1/2
        self.uk = nn.Sequential(
            nn.Upsample(scale_factor=2), # --> [b, 256, 128, 128]
            nn.ConvTranspose2d(256, 128, 3, 1, 1), # --> [b, 128, 128, 128]
            nn.InstanceNorm2d(128),
            nn.ReLU(True),
            
            nn.Upsample(scale_factor=2),
            nn.ConvTranspose2d(128, 64, 3, 1, 1), # --> [b, 64, 256, 256]
            nn.InstanceNorm2d(64),
            nn.ReLU(True),
        )
        # output layer c7s1-3
        self.c7s1_3 = nn.Sequential(
            nn.ReflectionPad2d(channels), # [b, 64, 262, 262]
            nn.Conv2d(64, channels, 7), # --> [b, 3, 256, 256]
            nn.Tanh(),
        )
        
    def forward(self, x):
        x = self.c7s1_64(x)
        x = self.dk(x)
        x = self.rk(x)
        x = self.uk(x)
        x = self.c7s1_3(x)
        return x
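# Shape sanity check (commented out): with 9 residual blocks the generator maps an
# image to an image of the same size:
#   G = Generator((3, 256, 256), 9)
#   print(G(torch.randn(1, 3, 256, 256)).shape)  # torch.Size([1, 3, 256, 256])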
    
##############################
#        Discriminator
##############################
class Discriminator(nn.Module):
    def __init__(self, input_shape):
        super(Discriminator, self).__init__()
        channels, height, width = input_shape
        # Calculate output shape of image discriminator (PatchGAN)
        self.output_shape = (1, height // 2 ** 4, width // 2 ** 4)
        
        def discriminator_block(in_filters, out_filters, normalize = True):
            """Returns downsampling layers of each discriminator block"""
            layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
            if normalize:
                layers.append(nn.InstanceNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers
        
        self.model = nn.Sequential( # input: [b, 3, 256, 256]
            *discriminator_block(channels, 64, normalize=False), # --> [b, 64, 128, 128]
            *discriminator_block(64, 128), # --> [b, 128, 64, 64]
            *discriminator_block(128, 256), # --> [b, 256, 32, 32]
            *discriminator_block(256, 512), # --> [b, 512, 16, 16]
            nn.ZeroPad2d((1, 0, 1, 0)), # --> [b, 512, 17, 17]
            nn.Conv2d(512, 1, 4, padding=1), # --> [b, 1, 16, 16]; working backwards, the receptive field here is 94 × 94, not the 70 × 70 mentioned in the paper
        )
    def forward(self, img):
        return self.model(img)
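# Shape sanity check (commented out): the PatchGAN maps a 256 x 256 image to a 16 x 16
# grid of patch scores, matching output_shape = (1, 16, 16):
#   D = Discriminator((3, 256, 256))
#   print(D(torch.randn(1, 3, 256, 256)).shape)  # torch.Size([1, 1, 16, 16])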
'''Training'''
writer = SummaryWriter('runs/cycleGAN')
# Losses
criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()

# Initialize generator and discriminator
G_AB = Generator(input_shape = input_shape, num_residual_blocks=opt.n_residual_blocks)
G_BA = Generator(input_shape = input_shape, num_residual_blocks=opt.n_residual_blocks)
D_A = Discriminator(input_shape = input_shape)
D_B = Discriminator(input_shape = input_shape)

# if cuda: # my machine does not have enough GPU memory, so I can only train on CPU
#     G_AB = G_AB.cuda()
#     G_BA = G_BA.cuda()
#     D_A = D_A.cuda()
#     D_B = D_B.cuda()
#     criterion_cycle.cuda()
#     criterion_GAN.cuda()
#     criterion_identity.cuda()

if opt.epoch != 0:
    # Load pretrained models
    G_AB.load_state_dict(torch.load('Model/cycleGAN/G_AB_{}.pth'.format(opt.epoch)))
    G_BA.load_state_dict(torch.load('Model/cycleGAN/G_BA_{}.pth'.format(opt.epoch)))
    D_A.load_state_dict(torch.load('Model/cycleGAN/D_A_{}.pth'.format(opt.epoch)))
    D_B.load_state_dict(torch.load('Model/cycleGAN/D_B_{}.pth'.format(opt.epoch)))
else:
    # Initialize weights
    G_AB.apply(weights_init_normal)
    G_BA.apply(weights_init_normal)
    D_A.apply(weights_init_normal)
    D_B.apply(weights_init_normal)
    
# Optimizers
optimizer_G = torch.optim.Adam(itertools.chain(G_AB.parameters(), G_BA.parameters()), 
                               lr = opt.lr, betas=(opt.b1, opt.b2))
optimizer_D_A = torch.optim.Adam(D_A.parameters(), lr = opt.lr, betas=(opt.b1, opt.b2))
optimizer_D_B = torch.optim.Adam(D_B.parameters(), lr = opt.lr, betas=(opt.b1, opt.b2))

# Learning rate update schedulers
lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer_G, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, 
                                                                                  opt.decay_epoch).step)
lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(optimizer_D_A, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch,
                                                                                      opt.decay_epoch).step)
lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(optimizer_D_B, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch,
                                                                                      opt.decay_epoch).step)
Tensor = torch.FloatTensor 

# Buffers of previously generated samples
fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()

def save_images(batchs_done):
    """Saves a generated sample from the test set"""
    imgs = next(iter(val_dataloader))
    G_AB.eval() # switch to eval mode
    G_BA.eval() # switch to eval mode
    real_A = Variable(imgs['A'].type(Tensor))
    fake_B = G_AB(real_A).detach()
    real_B = Variable(imgs['B'].type(Tensor))
    fake_A = G_BA(real_B).detach()
    # Arrange images along x-axis
    real_A = make_grid(real_A, nrow=5, normalize=True)
    real_B = make_grid(real_B, nrow=5, normalize=True)
    fake_A = make_grid(fake_A, nrow=5, normalize=True)
    fake_B = make_grid(fake_B, nrow=5, normalize=True)
    # Arrange images along y-axis
    img_grid = torch.cat((real_A, real_B, fake_A,fake_B), dim = 1)
    save_image(img_grid, 'Picture/cycleGAN/{}.png'.format(batchs_done), nrow=5, normalize=False)

# ----------
#  Training
# ----------
prev_time = time.time()
for epoch in range(opt.epoch, opt.n_epochs):
    for i ,batch in enumerate(dataloader):
        # Set model input
        real_A = Variable(batch['A'].type(Tensor))
        real_B = Variable(batch['B'].type(Tensor))
        
        # Adversarial ground truths
        valid = Variable(Tensor(np.ones((real_A.size(0), *D_A.output_shape))), requires_grad = False)
        fake = Variable(Tensor(np.zeros((real_A.size(0), *D_A.output_shape))), requires_grad = False)
        
        # -----------------------
        #  Train Discriminator A
        # -----------------------
        optimizer_D_A.zero_grad()
        # real loss
        loss_real = criterion_GAN(D_A(real_A), valid)
        # fake loss
        fake_A = G_BA(real_B).detach()
        fake_A = fake_A_buffer.push_pop(fake_A)
        loss_fake = criterion_GAN(D_A(fake_A), fake)
        # total loss
        loss_D_A = (loss_real + loss_fake)/2
        loss_D_A.backward()
        optimizer_D_A.step()
        
        # -----------------------
        #  Train Discriminator B
        # -----------------------
        optimizer_D_B.zero_grad()
        # real loss
        loss_real = criterion_GAN(D_B(real_B), valid) # D_B judges domain-B images, so its real samples are real_B
        # fake loss
        fake_B = G_AB(real_A).detach()
        fake_B = fake_B_buffer.push_pop(fake_B)
        loss_fake = criterion_GAN(D_B(fake_B), fake)
        # total loss
        loss_D_B = (loss_real + loss_fake)/2
        loss_D_B.backward()
        optimizer_D_B.step()
        
        loss_D = (loss_D_A + loss_D_B)/2
        
        # ------------------
        #  Train Generators
        # ------------------
        G_AB.train()
        G_BA.train()
        optimizer_G.zero_grad()
        # Identity loss
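        # (G_BA should leave a real domain-A image unchanged and G_AB a real domain-B image;
        # this identity term helps preserve the color composition of the inputs)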
        loss_id_A = criterion_identity(G_BA(real_A), real_A)
        loss_id_B = criterion_identity(G_AB(real_B), real_B)
        
        loss_identity = (loss_id_A + loss_id_B) / 2
        
        # GAN loss
        fake_B = G_AB(real_A)
        loss_GAN_AB = criterion_GAN(D_B(fake_B), valid)
        fake_A = G_BA(real_B)
        loss_GAN_BA = criterion_GAN(D_A(fake_A), valid)
        
        loss_GAN = (loss_GAN_AB + loss_GAN_BA)/2
        
        # cycle loss
        recov_A = G_BA(fake_B)
        loss_cycle_A = criterion_cycle(recov_A, real_A)
        recov_B = G_AB(fake_A)
        loss_cycle_B = criterion_cycle(recov_B, real_B)
        
        loss_cycle = (loss_cycle_A + loss_cycle_B)/2
        
        # total loss
        loss_G = loss_GAN + opt.lambda_cyc * loss_cycle + opt.lambda_id * loss_identity
        
        loss_G.backward()
        optimizer_G.step()
        
        # --------------
        #  Log Progress
        # --------------
        # Determine approximate time left
        batches_done = epoch * len(dataloader) + i
        batches_left = opt.n_epochs * len(dataloader) - batches_done
        time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
        prev_time = time.time()

        # Print log
        print(
            "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, GAN: %f, cycle: %f, identity: %f] ETA: %s"
            % (
                epoch,
                opt.n_epochs,
                i,
                len(dataloader),
                loss_D.item(),
                loss_G.item(),
                loss_GAN.item(),
                loss_cycle.item(),
                loss_identity.item(),
                time_left,
            )
        )
        writer.add_scalar('loss_D', loss_D.item(), global_step = batches_done) # log per batch; using epoch here would reuse the same step for every batch
        writer.add_scalar('loss_G', loss_G.item(), global_step = batches_done)

        # If at sample interval save image
        if batches_done % opt.sample_interval == 0:
            save_images(batches_done)

    # Update learning rates
    lr_scheduler_G.step()
    lr_scheduler_D_A.step()
    lr_scheduler_D_B.step()

    if opt.checkpoint_interval != -1 and epoch % opt.checkpoint_interval == 0:
        # Save model checkpoints
        torch.save(G_AB.state_dict(), "Model/cycleGAN/G_AB_%d.pth" % (epoch))
        torch.save(G_BA.state_dict(), "Model/cycleGAN/G_BA_%d.pth" % (epoch))
        torch.save(D_A.state_dict(), "Model/cycleGAN/D_A_%d.pth" % (epoch))
        torch.save(D_B.state_dict(), "Model/cycleGAN/D_B_%d.pth" % (epoch))

4. Results

My machine cannot keep up with this model: finishing 200 epochs would take about two years, so below are the results after roughly 1/10 of an epoch.
The top two rows are the original images; the bottom two rows are the generated images.
[Result screenshots]

5. Problems Encountered and Solutions

1. Understanding PatchGAN: link one, link two
2. The receptive field is not the same as the kernel size
3. The mode parameter of images in the PIL library
4. itertools.chain()
5. torch.optim.lr_scheduler: adjusting the learning rate
6. The twenty-two methods of transforms
7. output_padding
8. Understanding the identity loss

Paper Appendix

[Figure: details from the paper's appendix]
