A Simple Implementation of CycleGAN (PyTorch)

CycleGAN, published at ICCV 2017, is an unsupervised algorithm built on GANs for image-to-image style transfer. Before it, pix2pix could already translate images between styles, but it has a major limitation: it requires paired training images, one of each style showing the same content, and such pairs are hard to find or to photograph in practice. CycleGAN removes this requirement and translates between two image domains without any paired correspondence, for example turning photos into oil paintings, oranges into apples, or horses into zebras.
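
The key mechanism is cycle consistency: CycleGAN trains two generators G: X → Y and F: Y → X together with one discriminator per domain, and penalizes reconstructions F(G(x)) and G(F(y)) that drift away from the originals. Below is a minimal sketch of the combined generator objective, mirroring the losses computed in the training code later in this post (the function and variable names here are illustrative, not part of that code):

import torch

def generator_objective(G, F, D_X, D_Y, x, y, l1, mse, lambda_cycle=10, lambda_identity=5):
    # adversarial (least-squares GAN) terms: each generator tries to fool its discriminator
    fake_y, fake_x = G(x), F(y)
    pred_y, pred_x = D_Y(fake_y), D_X(fake_x)
    adv = mse(pred_y, torch.ones_like(pred_y)) + mse(pred_x, torch.ones_like(pred_x))
    # cycle-consistency terms: translating there and back should reproduce the input
    cyc = l1(F(fake_y), x) + l1(G(fake_x), y)
    # identity terms: a generator fed an image already in its target domain should leave it unchanged
    idt = l1(G(y), y) + l1(F(x), x)
    return adv + lambda_cycle * cyc + lambda_identity * idt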

Results:

[Figure: horse-to-zebra translation example]

Implementation:

Network definitions and training code:

'''
Description: 
version: 
Author: MAPLE
Date: 2022-06-12 23:23:54
LastEditors: MAPLE
LastEditTime: 2022-06-28 23:24:09
'''
import os
import torch
import random
import torch.nn as nn
import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torch.nn import init
from torch.optim import lr_scheduler
from tqdm import tqdm
from torchvision.utils import save_image
import torch.optim as optim
import torchvision.transforms as transforms

def seed_torch(seed=2018):
    """Seed every RNG for reproducibility (note: set_seed below is what the script actually calls)."""
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TRAIN_DIR = "data/horse2zebra"
VAL_DIR = "data/horse2zebra"

BATCH_SIZE = 1
LEARNING_RATE = 2e-4  # learning rate
LAMBDA_IDENTITY = 5  # weight of the identity loss
LAMBDA_CYCLE = 10  # weight of the cycle-consistency loss
NUM_WORKERS = 2
LOAD_MODEL = True  # load saved weights (set to False on the first run, before any checkpoints exist)
SAVE_MODEL = True  # save weights after each epoch

# where model checkpoints are stored
CHECKPOINT_GEN_H = "genh.pth.tar"
CHECKPOINT_GEN_Z = "genz.pth.tar"
CHECKPOINT_CRITIC_H = "critich.pth.tar"
CHECKPOINT_CRITIC_Z = "criticz.pth.tar"

# learning-rate schedule hyperparameters
EPOCH_COUNT = 1
N_EPOCHS = 100
N_EPOCHS_DECAY = 100

# named train_transforms so it does not shadow the torchvision.transforms module
train_transforms = transforms.Compose(
    [
        transforms.Resize(286, Image.BICUBIC),  # upscale so the random crop has room
        transforms.RandomCrop(256),  # random 256x256 crop
        transforms.RandomHorizontalFlip(p=0.5),  # random horizontal flip
        transforms.ToTensor(),  # convert to a tensor
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # scale to [-1, 1], matching the generator's Tanh output
    ]
)

# Custom weight initialization, applied to every Conv/Linear/BatchNorm layer
def init_weights(net, init_type='normal', init_gain=0.02):
    """Initialize network weights with a zero-mean normal distribution.
    """
    def init_func(m):  # define the initialization function
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, init_gain)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
        elif classname.find('BatchNorm2d') != -1:
            init.normal_(m.weight.data, 1.0, init_gain)
            init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.to(DEVICE)
    net.apply(init_func)  # apply the initialization function <init_func>

class ImagePool():
    def __init__(self, pool_size):
        self.pool_size = pool_size
        if self.pool_size > 0:  # create an empty pool
            self.num_imgs = 0
            self.images = []

    def query(self, images):
        """Return images from the buffer (the image-history trick used to stabilize the discriminators).
        """
        if self.pool_size == 0:  # if the buffer size is 0, do nothing
            return images
        return_images = []
        for image in images:
            image = torch.unsqueeze(image.data, 0)
            if self.num_imgs < self.pool_size:   # if the buffer is not full; keep inserting current images to the buffer
                self.num_imgs = self.num_imgs + 1
                self.images.append(image)
                return_images.append(image)
            else:
                p = random.uniform(0, 1)
                if p > 0.5:  # with 50% probability, return a previously generated image
                    random_id = random.randint(
                        0, self.pool_size - 1)  # randint is inclusive
                    tmp = self.images[random_id].clone()
                    self.images[random_id] = image  # store the newly generated image in the buffer
                    return_images.append(tmp)
                else:  # with the other 50%, return the current image
                    return_images.append(image)
        # collect all the images and return
        return_images = torch.cat(return_images, 0)
        return return_images

GLOBAL_SEED = 1
def set_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs; this is the seeding the script actually applies."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
set_seed(GLOBAL_SEED)

# The region linking the encoder and decoder uses residual blocks; 9 repeated blocks by default
class ResnetBlock(nn.Module):
    """Define a Resnet block"""
    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Initialize the Resnet block
        """
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(
            dim, padding_type, norm_layer, use_dropout, use_bias)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Construct a convolutional block."""
        conv_block = []
        p = 0
        if padding_type == 'reflect':  # only reflection padding is handled in this simplified version
            conv_block += [nn.ReflectionPad2d(1)]

        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p,
                                 bias=use_bias), norm_layer(dim), nn.ReLU(True)]
        # empirically, dropout rarely helps inside convolutional blocks
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]
        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]

        conv_block += [nn.Conv2d(dim, dim, kernel_size=3,
                                 padding=p, bias=use_bias), norm_layer(dim)]
        return nn.Sequential(*conv_block)

    def forward(self, x):
        """Forward function (with skip connections)"""
        out = x + self.conv_block(x)  # add skip connections
        return out

# Generator built from residual blocks
class ResnetGenerator(nn.Module):

    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=9, padding_type='reflect'):
        """Construct a Resnet-based generator
        """
        super(ResnetGenerator, self).__init__()

        use_bias = norm_layer == nn.InstanceNorm2d

        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
                 norm_layer(ngf),
                 nn.ReLU(True)]

        n_downsampling = 2
        for i in range(n_downsampling):  # add downsampling layers
            mult = 2 ** i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
                      norm_layer(ngf * mult * 2),
                      nn.ReLU(True)]

        mult = 2 ** n_downsampling
        for i in range(n_blocks):       # add ResNet blocks

            model += [ResnetBlock(ngf * mult, padding_type=padding_type,
                                  norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]

        for i in range(n_downsampling):  # add upsampling layers
            mult = 2 ** (n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1, bias=use_bias),
                      norm_layer(int(ngf * mult / 2)),
                      nn.ReLU(True)]
        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        model += [nn.Tanh()]

        self.model = nn.Sequential(*model)
        init_weights(self.model)

    def forward(self, input):
        """Standard forward"""
        return self.model(input)

# Markovian discriminator (PatchGAN): a stack of conv layers whose output is an N x N map of per-patch predictions
class NLayerDiscriminator(nn.Module):
    """Defines a PatchGAN discriminator"""

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
        """Construct a PatchGAN discriminator"""
        super(NLayerDiscriminator, self).__init__()
        use_bias = norm_layer == nn.InstanceNorm2d

        kw = 4
        padw = 1
        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]
        # output 1-channel prediction map
        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]
        self.model = nn.Sequential(*sequence)
        init_weights(self.model)

    def forward(self, input):
        """Standard forward."""
        return self.model(input)


# learning-rate scheduling

def get_scheduler(optimizer):
    """Return a learning-rate scheduler:
    constant for the first N_EPOCHS epochs, then linear decay to 0 over the next N_EPOCHS_DECAY epochs.
    """
    def lambda_rule(epoch):
        lr_l = 1.0 - max(0, epoch + EPOCH_COUNT - N_EPOCHS) / float(N_EPOCHS_DECAY + 1)
        return lr_l
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    return scheduler
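
# A quick worked check of the schedule with the defaults above (EPOCH_COUNT=1,
# N_EPOCHS=100, N_EPOCHS_DECAY=100). LambdaLR multiplies the base lr by lambda_rule(e),
# where e is the scheduler's 0-based internal epoch counter:
#   e = 0..99:  1 - max(0, e + 1 - 100) / 101 = 1.0   -> lr stays at 2e-4
#   e = 100:    1 - 1 / 101 ≈ 0.990                   -> decay begins
#   e = 200:    1 - 101 / 101 = 0.0                   -> lr has decayed to 0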

def train_fn(disc_H, disc_Z, gen_H, gen_Z, loader, opt_disc, opt_gen, l1, mse):
    fake_H_pool = ImagePool(50)
    fake_Z_pool = ImagePool(50)
    H_reals = 0
    H_fakes = 0
    Z_reals = 0
    Z_fakes = 0
    loop = tqdm(loader, leave=True)

    for idx, data in enumerate(loop):
        zebra = data['B'].to(DEVICE)
        horse = data['A'].to(DEVICE)

        # Train Discriminators H and Z
        fake_horse = gen_H(zebra)
        fake_horse_train = fake_H_pool.query(fake_horse)
        D_H_real = disc_H(horse)
        D_H_fake = disc_H(fake_horse_train.detach())
        H_reals += D_H_real.mean().item()
        H_fakes += D_H_fake.mean().item()
        D_H_real_loss = mse(D_H_real, torch.ones_like(D_H_real))
        D_H_fake_loss = mse(D_H_fake, torch.zeros_like(D_H_fake))
        D_H_loss = D_H_real_loss + D_H_fake_loss

        fake_zebra = gen_Z(horse)
        fake_zebra_train = fake_Z_pool.query(fake_zebra)
        D_Z_real = disc_Z(zebra)
        D_Z_fake = disc_Z(fake_zebra_train.detach())
        Z_reals += D_Z_real.mean().item()
        Z_fakes += D_Z_fake.mean().item()
        D_Z_real_loss = mse(D_Z_real, torch.ones_like(D_Z_real))
        D_Z_fake_loss = mse(D_Z_fake, torch.zeros_like(D_Z_fake))
        D_Z_loss = D_Z_real_loss + D_Z_fake_loss

        # put it together
        D_loss = (D_H_loss + D_Z_loss) / 2

        opt_disc.zero_grad()
        D_loss.backward()
        opt_disc.step()

        # Train Generators H and Z
        # adversarial loss for both generators
        D_H_fake = disc_H(fake_horse)
        D_Z_fake = disc_Z(fake_zebra)
        loss_G_H = mse(D_H_fake, torch.ones_like(D_H_fake))
        loss_G_Z = mse(D_Z_fake, torch.ones_like(D_Z_fake))

        # cycle loss
        cycle_zebra = gen_Z(fake_horse)
        cycle_horse = gen_H(fake_zebra)
        cycle_zebra_loss = l1(zebra, cycle_zebra)
        cycle_horse_loss = l1(horse, cycle_horse)

        # identity loss (remove these for efficiency if you set lambda_identity=0)
        identity_zebra = gen_Z(zebra)
        identity_horse = gen_H(horse)
        identity_zebra_loss = l1(zebra, identity_zebra)
        identity_horse_loss = l1(horse, identity_horse)

        # add all together
        G_loss = (
            loss_G_Z
            + loss_G_H
            + cycle_zebra_loss * LAMBDA_CYCLE
            + cycle_horse_loss * LAMBDA_CYCLE
            + identity_horse_loss * LAMBDA_IDENTITY
            + identity_zebra_loss * LAMBDA_IDENTITY
        )

        opt_gen.zero_grad()
        G_loss.backward()
        opt_gen.step()

        if idx % 200 == 0:
            os.makedirs("train_images", exist_ok=True)  # make sure the output directory exists
            save_image(fake_horse * 0.5 + 0.5, f"train_images/horse_{idx}.png")  # un-normalize from [-1, 1]
            save_image(fake_zebra * 0.5 + 0.5, f"train_images/zebra_{idx}.png")

        loop.set_postfix(H_real=H_reals / (idx + 1), H_fake=H_fakes / (idx + 1),
                         Z_real=Z_reals / (idx + 1), Z_fake=Z_fakes / (idx + 1))

class CombineDataset(Dataset):
    def __init__(self, root_A, root_B, transform):
        self.root_A = root_A
        self.root_B = root_B
        self.transform = transform

        self.A_paths = os.listdir(root_A)
        self.B_paths = os.listdir(root_B)
        self.length_dataset = max(len(self.A_paths), len(self.B_paths))
        self.A_len = len(self.A_paths)
        self.B_len = len(self.B_paths)

    def __len__(self):
        return self.length_dataset

    def __getitem__(self, index):
        A_path = self.A_paths[index % self.A_len]
        B_path = self.B_paths[index % self.B_len]

        A_img = Image.open(os.path.join(self.root_A, A_path)).convert("RGB")
        B_img = Image.open(os.path.join(self.root_B, B_path)).convert("RGB")

        A = self.transform(A_img)
        B = self.transform(B_img)

        return {'A': A, 'B': B}

def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    print("=> Saving checkpoint")
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    torch.save(checkpoint, filename)


def load_checkpoint(checkpoint_file, model, optimizer, lr):
    print("=> Loading checkpoint")
    checkpoint = torch.load(checkpoint_file, map_location=DEVICE)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    # overwrite the stored learning rate with the current value
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr

dataset = CombineDataset(root_A=TRAIN_DIR + "/trainA/",
                         root_B=TRAIN_DIR + "/trainB/", transform=train_transforms)
data_loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True
)
dataset_size = len(dataset)
print('The number of training images = %d' % dataset_size)

disc_H = NLayerDiscriminator(input_nc=3).to(DEVICE)
disc_Z = NLayerDiscriminator(input_nc=3).to(DEVICE)
gen_Z = ResnetGenerator(input_nc=3, output_nc=3).to(DEVICE)
gen_H = ResnetGenerator(input_nc=3, output_nc=3).to(DEVICE)

opt_disc = optim.Adam(
    list(disc_H.parameters()) + list(disc_Z.parameters()),
    lr=LEARNING_RATE,
    betas=(0.5, 0.999),
)

opt_gen = optim.Adam(
    list(gen_Z.parameters()) + list(gen_H.parameters()),
    lr=LEARNING_RATE,
    betas=(0.5, 0.999),
)

scheduler_disc = get_scheduler(opt_disc)
scheduler_gen = get_scheduler(opt_gen)
L1 = nn.L1Loss()
mse = nn.MSELoss()

if LOAD_MODEL:
    load_checkpoint(
        CHECKPOINT_GEN_H, gen_H, opt_gen, LEARNING_RATE,
    )
    load_checkpoint(
        CHECKPOINT_GEN_Z, gen_Z, opt_gen, LEARNING_RATE,
    )
    load_checkpoint(
        CHECKPOINT_CRITIC_H, disc_H, opt_disc, LEARNING_RATE,
    )
    load_checkpoint(
        CHECKPOINT_CRITIC_Z, disc_Z, opt_disc, LEARNING_RATE,
    )

for epoch in range(EPOCH_COUNT, N_EPOCHS+N_EPOCHS_DECAY+1):

    train_fn(disc_H, disc_Z, gen_H, gen_Z,
             data_loader, opt_disc, opt_gen, L1, mse)
    scheduler_disc.step()
    scheduler_gen.step()
    if SAVE_MODEL:
        save_checkpoint(gen_H, opt_gen, filename=CHECKPOINT_GEN_H)
        save_checkpoint(gen_Z, opt_gen, filename=CHECKPOINT_GEN_Z)
        save_checkpoint(disc_H, opt_disc, filename=CHECKPOINT_CRITIC_H)
        save_checkpoint(disc_Z, opt_disc, filename=CHECKPOINT_CRITIC_Z)
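
Once training has produced checkpoints, a saved generator can be applied to new images on its own. Below is a minimal inference sketch; it assumes the genh.pth.tar checkpoint written by the loop above and the ResnetGenerator class defined earlier, and the input/output file names are only illustrative:

gen_H = ResnetGenerator(input_nc=3, output_nc=3).to(DEVICE)
checkpoint = torch.load("genh.pth.tar", map_location=DEVICE)
gen_H.load_state_dict(checkpoint["state_dict"])
gen_H.eval()

test_transforms = transforms.Compose([
    transforms.Resize(256, Image.BICUBIC),  # no random augmentation at test time
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
img = test_transforms(Image.open("zebra.jpg").convert("RGB")).unsqueeze(0).to(DEVICE)
with torch.no_grad():
    fake_horse = gen_H(img)
save_image(fake_horse * 0.5 + 0.5, "horse_from_zebra.png")  # un-normalize from [-1, 1]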





If you need the complete project, the training parameters, or the dataset, please leave a comment.

