[Adversarial Networks] CycleGAN: Model Explanation and Code Implementation

CycleGAN theory:

Paper: https://arxiv.org/pdf/1703.10593.pdf
        CycleGAN targets unpaired image-to-image translation; it removes the difficulty of having to collect paired training data.

        The comparison figure in the paper shows paired data on the left — a shoe sketch and the corresponding real photo form one pair — and unpaired data on the right, where X contains real photographs and Y contains oil-painting-style images.

        CycleGAN's principle can be summarized as translating one class of images into another. That is, given two sample spaces X and Y, we want to map samples from X into samples from Y — essentially a transfer of style.

        Viewed this way, the actual goal is to learn a mapping from X to Y. Call this mapping G (the paper also learns an inverse mapping F: Y → X). G corresponds to the generator in a GAN: it converts an image x from X into an image G(x) in Y. For the generated image we still need a GAN discriminator to judge whether it looks real, and together these form the adversarial network.

The overall CycleGAN architecture:
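CycleGAN trains two generators and two discriminators: G: X → Y with a discriminator D_Y that distinguishes real images y from translations G(x), and F: Y → X with a discriminator D_X that does the same for domain X. In the code below, gen_AB plays the role of G, gen_BA plays F, and dis_A/dis_B are D_X/D_Y.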

About the loss functions:

        One issue: with a large enough sample capacity, the network can map the same set of input images to any random permutation of images in the target domain, and any such learned mapping induces an output distribution that matches the target distribution. In other words, the mapping could collapse every image in X onto one and the same image in Y while still keeping the adversarial loss low, rendering it useless as a training signal. An adversarial loss alone therefore cannot guarantee that the learned function maps an individual input x_i to a desired output y_i. To address this, the paper's authors propose the "cycle consistency loss".
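For reference, the adversarial loss for the mapping G: X → Y and its discriminator D_Y, as given in the paper (the term for F and D_X is symmetric):

$$\mathcal{L}_{\mathrm{GAN}}(G, D_Y, X, Y) = \mathbb{E}_{y \sim p_{\mathrm{data}}(y)}[\log D_Y(y)] + \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}[\log(1 - D_Y(G(x)))]$$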

Cycle consistency loss:
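Translating an image to the other domain and back should recover the original, i.e. F(G(x)) ≈ x and G(F(y)) ≈ y, penalized with an L1 norm:

$$\mathcal{L}_{\mathrm{cyc}}(G, F) = \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}\left[\lVert F(G(x)) - x \rVert_1\right] + \mathbb{E}_{y \sim p_{\mathrm{data}}(y)}\left[\lVert G(F(y)) - y \rVert_1\right]$$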

There is also an identity loss:

It can be understood as follows: the generator G handles translation from domain X to domain Y, so if it is fed an image y that already belongs to domain Y, it should return essentially the same image y''; the identity loss is computed between y'' and the input y.
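In the paper's notation (again an L1 penalty):

$$\mathcal{L}_{\mathrm{identity}}(G, F) = \mathbb{E}_{y \sim p_{\mathrm{data}}(y)}\left[\lVert G(y) - y \rVert_1\right] + \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}\left[\lVert F(x) - x \rVert_1\right]$$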

Total loss:
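The full objective combines both adversarial terms with the cycle term; the paper uses λ = 10 and, when the identity loss is enabled, weights it with 0.5λ:

$$\mathcal{L}(G, F, D_X, D_Y) = \mathcal{L}_{\mathrm{GAN}}(G, D_Y, X, Y) + \mathcal{L}_{\mathrm{GAN}}(F, D_X, Y, X) + \lambda\,\mathcal{L}_{\mathrm{cyc}}(G, F) + 0.5\lambda\,\mathcal{L}_{\mathrm{identity}}(G, F)$$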

Training results:

Sample outputs (images omitted) at epoch = 1, epoch = 15, epoch = 30, and epoch = 45.

Imports:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torch.utils import data

import matplotlib.pyplot as plt
import glob
from PIL import Image
import itertools
import numpy as np

Training code:

# Prepare the training dataset
mans_path = glob.glob("data/man_woman/trainA/*.jpg")
print(len(mans_path))
plt.figure(figsize=(12, 8))
for i, man_path in enumerate(mans_path[:4]):
    man_img = Image.open(man_path)
    man_np_img = np.array(man_img)
    plt.subplot(2, 2, i + 1)
    plt.imshow(man_np_img)
    plt.title(str(man_np_img.shape))
plt.show()

womans_path = glob.glob("data/man_woman/trainB/*.jpg")
print(len(womans_path))
plt.figure(figsize=(12, 8))
for i, woman_path in enumerate(womans_path[:4]):
    woman_img = Image.open(woman_path)
    woman_np_img = np.array(woman_img)
    plt.subplot(2, 2, i + 1)
    plt.imshow(woman_np_img)
    plt.title(str(woman_np_img.shape))
plt.show()

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((256, 256)),
    transforms.Normalize(0.5, 0.5)
])


class MW_dataset(data.Dataset):
    def __init__(self, img_path):
        self.img_path = img_path

    def __getitem__(self, index):
        imgpath = self.img_path[index]
        pil_img = Image.open(imgpath).convert("RGB")  # ensure 3 channels
        pil_img = transform(pil_img)
        return pil_img

    def __len__(self):
        return len(self.img_path)


man_dataset = MW_dataset(mans_path)
woman_dataset = MW_dataset(womans_path)

BATCHSIZE = 4
man_dl = data.DataLoader(man_dataset, batch_size=BATCHSIZE, shuffle=True)
woman_dl = data.DataLoader(woman_dataset, batch_size=BATCHSIZE, shuffle=True)

man_batch = next(iter(man_dl))
woman_batch = next(iter(woman_dl))
# Visualize a batch of processed images
fig = plt.figure(figsize=(8, 15))
for i, (m, w) in enumerate(zip(man_batch[:3], woman_batch[:3])):
    m = (m.permute(1, 2, 0).numpy() + 1) / 2
    w = (w.permute(1, 2, 0).numpy() + 1) / 2
    plt.subplot(3, 2, 2 * i + 1)
    plt.title("man")
    plt.imshow(m)
    plt.subplot(3, 2, 2 * i + 2)
    plt.title("woman")
    plt.imshow(w)
plt.show()

# Prepare the test datasets
mans_path_test = glob.glob("data/man_woman/testA/*.jpg")
womans_path_test = glob.glob("data/man_woman/testB/*.jpg")

man_dataset_test = MW_dataset(mans_path_test)
woman_dataset_test = MW_dataset(womans_path_test)

man_dl_test = data.DataLoader(man_dataset_test, batch_size=BATCHSIZE, shuffle=True)
woman_dl_test = data.DataLoader(woman_dataset_test, batch_size=BATCHSIZE, shuffle=True)


# Downsampling block
class Downsample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Downsample, self).__init__()
        self.conv_relu = nn.Sequential(
            nn.Conv2d(in_channels=in_channels,
                      out_channels=out_channels,
                      kernel_size=3,
                      stride=2,
                      padding=1),
            nn.LeakyReLU(inplace=True)  # inplace=True: modify the input tensor in place
        )
        self.bn = nn.InstanceNorm2d(out_channels)  # InstanceNorm2d normalizes each image individually, which suits per-image style translation

    def forward(self, x, is_bn=True):
        x = self.conv_relu(x)
        if is_bn:
            x = self.bn(x)
        return x


# Upsampling block
class Upsample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Upsample, self).__init__()
        self.upconv1 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=in_channels,
                               out_channels=out_channels,
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=1),
            nn.LeakyReLU(inplace=True)  # inplace=True: modify the input tensor in place
        )
        self.bn = nn.InstanceNorm2d(out_channels)  # InstanceNorm2d normalizes each image individually, which suits per-image style translation

    def forward(self, x, is_drop=False):
        x = self.upconv1(x)
        x = self.bn(x)
        if is_drop:
            x = F.dropout2d(x)
        return x


# Generator: 6 downsampling blocks, 5 upsampling blocks + 1 output layer
# Tip: in practice, draw the model diagram to keep the encoder-decoder skip connections clear (U-Net)
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.down1 = Downsample(3, 64)  # (64, 128, 128)
        self.down2 = Downsample(64, 128)  # (128, 64, 64)
        self.down3 = Downsample(128, 256)  # (256, 32, 32)
        self.down4 = Downsample(256, 512)  # (512, 16, 16)
        self.down5 = Downsample(512, 512)  # (512, 8, 8)
        self.down6 = Downsample(512, 512)  # (512, 4, 4)

        self.up1 = Upsample(512, 512)  # (512, 8, 8)
        self.up2 = Upsample(1024, 512)  # (512, 16, 16)
        self.up3 = Upsample(1024, 256)  # (256, 32, 32)
        self.up4 = Upsample(512, 128)  # (128, 64, 64)
        self.up5 = Upsample(256, 64)  # (64, 128, 128)

        self.last = nn.ConvTranspose2d(128, 3,
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=1)

    def forward(self, x):
        x1 = self.down1(x)
        x2 = self.down2(x1)
        x3 = self.down3(x2)
        x4 = self.down4(x3)
        x5 = self.down5(x4)
        x6 = self.down6(x5)

        x6 = self.up1(x6, is_drop=True)
        x6 = torch.cat([x6, x5], dim=1)
        x6 = self.up2(x6, is_drop=True)
        x6 = torch.cat([x6, x4], dim=1)
        x6 = self.up3(x6, is_drop=True)
        x6 = torch.cat([x6, x3], dim=1)
        x6 = self.up4(x6)
        x6 = torch.cat([x6, x2], dim=1)
        x6 = self.up5(x6)
        x6 = torch.cat([x6, x1], dim=1)

        x6 = torch.tanh(self.last(x6))
        return x6
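
# Sanity check (a sketch, assuming 256x256 RGB inputs as produced by the
# Resize((256, 256)) transform above): the U-Net returns an image of the same size.
_gen_check = Generator()
print(_gen_check(torch.randn(1, 3, 256, 256)).shape)  # torch.Size([1, 3, 256, 256])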


# Discriminator (PatchGAN): takes a single image (real or generated) and scores it patch by patch
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.down1 = Downsample(3, 64)  # (64, 128, 128)
        self.down2 = Downsample(64, 128)  # (128, 64, 64)
        self.last = nn.Conv2d(128, 1, 3)  # (1, 62, 62)

    def forward(self, img):
        x = self.down1(img)
        x = self.down2(x)
        x = torch.sigmoid(self.last(x))  # (batch, 1, 62, 62)
        return x
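
# Sanity check (same 256x256 assumption): the PatchGAN outputs a 62x62 map,
# where each element scores one patch of the input as real/fake.
_dis_check = Discriminator()
print(_dis_check(torch.randn(1, 3, 256, 256)).shape)  # torch.Size([1, 1, 62, 62])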


device = "cuda" if torch.cuda.is_available() else "cpu"

# Create two generators and two discriminators
gen_AB = Generator().to(device)
gen_BA = Generator().to(device)
dis_A = Discriminator().to(device)
dis_B = Discriminator().to(device)

# Loss functions: 1. GAN loss  2. cycle consistency loss  3. identity loss
bceloss = torch.nn.BCELoss()
l1_loss = torch.nn.L1Loss()

# Optimizers
gen_optimizer = torch.optim.Adam(
    itertools.chain(gen_AB.parameters(), gen_BA.parameters()),
    lr=2e-4,
    betas=(0.5, 0.999)
)
# Each discriminator optimizer must update only its own discriminator's parameters
dis_optimizer_A = torch.optim.Adam(
    dis_A.parameters(),
    lr=2e-4,
    betas=(0.5, 0.999)
)
dis_optimizer_B = torch.optim.Adam(
    dis_B.parameters(),
    lr=2e-4,
    betas=(0.5, 0.999)
)


# Plotting helper
def generate_image(model, test_input):
    predictions = model(test_input).permute(0, 2, 3, 1).detach().cpu().numpy()
    test_input = test_input.permute(0, 2, 3, 1).cpu().numpy()
    title_list = ["input", "output"]
    display_list = [test_input[0], predictions[0]]
    fig = plt.figure(figsize=(10, 6))
    for i in range(2):
        plt.subplot(1, 2, i + 1)
        plt.title(title_list[i])
        plt.imshow(display_list[i] * 0.5 + 0.5)
        plt.axis("off")
    plt.show()


test_batch = next(iter(man_dl_test))
# Only one image is used as test_input, so unsqueeze adds a batch dimension (batch size 1)
test_input = torch.unsqueeze(test_batch[0], 0).to(device)
# Train the model
D_loss = []
G_loss = []
best_gen_loss = float("inf")
EPOCH = 5
for epoch in range(EPOCH):
    D_epoch_loss = 0
    G_epoch_loss = 0
    count = min(len(man_dl), len(woman_dl))  # zip() stops at the shorter DataLoader

    for step, (real_A, real_B) in enumerate(zip(man_dl, woman_dl)):
        real_A = real_A.to(device)
        real_B = real_B.to(device)
        # Train the generators
        gen_optimizer.zero_grad()
        # identity loss
        same_A = gen_BA(real_A)
        same_A_loss = l1_loss(same_A, real_A)

        same_B = gen_AB(real_B)
        same_B_loss = l1_loss(same_B, real_B)

        # GAN (adversarial) loss
        fake_A = gen_BA(real_B)
        fake_A_output = dis_A(fake_A)
        fake_A_output_loss = bceloss(fake_A_output, torch.ones_like(fake_A_output, device=device))

        fake_B = gen_AB(real_A)
        fake_B_output = dis_B(fake_B)
        fake_B_output_loss = bceloss(fake_B_output, torch.ones_like(fake_B_output, device=device))

        # cycle consistency loss
        recovered_A = gen_BA(fake_B)
        cycle_ABA_loss = l1_loss(recovered_A, real_A)

        recovered_B = gen_AB(fake_A)
        cycle_BAB_loss = l1_loss(recovered_B, real_B)

        # Note: the paper weights the cycle terms by lambda = 10 and the identity terms
        # by 0.5 * lambda; here all six terms are summed with equal weight for simplicity
        g_loss = (same_A_loss + same_B_loss + fake_A_output_loss + fake_B_output_loss + cycle_ABA_loss + cycle_BAB_loss)
        g_loss.backward()
        gen_optimizer.step()

        # Train the discriminators
        # Train dis_A
        dis_optimizer_A.zero_grad()
        real_A_output = dis_A(real_A)
        real_A_loss = bceloss(real_A_output, torch.ones_like(real_A_output))

        fake_A_output = dis_A(fake_A.detach())
        fake_A_loss = bceloss(fake_A_output, torch.zeros_like(fake_A_output))

        dis_A_loss = real_A_loss + fake_A_loss
        dis_A_loss.backward()
        dis_optimizer_A.step()
        # Train dis_B
        dis_optimizer_B.zero_grad()
        real_B_output = dis_B(real_B)
        real_B_loss = bceloss(real_B_output, torch.ones_like(real_B_output))

        fake_B_output = dis_B(fake_B.detach())
        fake_B_loss = bceloss(fake_B_output, torch.zeros_like(fake_B_output))

        dis_B_loss = real_B_loss + fake_B_loss
        dis_B_loss.backward()
        dis_optimizer_B.step()

        with torch.no_grad():
            G_epoch_loss += g_loss.item()
            D_epoch_loss += (dis_A_loss + dis_B_loss).item()

    # Epoch-level bookkeeping (all values are plain Python floats at this point)
    D_epoch_loss /= count
    G_epoch_loss /= count
    D_loss.append(D_epoch_loss)
    G_loss.append(G_epoch_loss)
    print("Epoch:{}".format(epoch),
          "g_epoch_loss:{}".format(G_epoch_loss),
          "d_epoch_loss:{}".format(D_epoch_loss))
    # Save the generator's state dict whenever the average generator loss improves
    if G_epoch_loss < best_gen_loss:
        best_gen_loss = G_epoch_loss
        torch.save(gen_AB.state_dict(), 'best_cycleGAN_model.pth')
    # if epoch % 5 == 0:
    #     generate_image(gen_AB, test_input)
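
# After training, plot the recorded per-epoch averages to monitor convergence
plt.plot(range(1, EPOCH + 1), G_loss, label="generator loss")
plt.plot(range(1, EPOCH + 1), D_loss, label="discriminator loss")
plt.xlabel("epoch")
plt.legend()
plt.show()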

Using the trained model:

import os
import torch
import torchvision.utils
from torchvision import transforms
from PIL import Image
import torch.nn as nn
import torch.nn.functional as F

# Downsampling block
class Downsample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Downsample, self).__init__()
        self.conv_relu = nn.Sequential(
            nn.Conv2d(in_channels=in_channels,
                      out_channels=out_channels,
                      kernel_size=3,
                      stride=2,
                      padding=1),
            nn.LeakyReLU(inplace=True)  # inplace=True: modify the input tensor in place
        )
        self.bn = nn.InstanceNorm2d(out_channels)  # InstanceNorm2d normalizes each image individually, which suits per-image style translation

    def forward(self, x, is_bn=True):
        x = self.conv_relu(x)
        if is_bn:
            x = self.bn(x)
        return x


# Upsampling block
class Upsample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Upsample, self).__init__()
        self.upconv1 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=in_channels,
                               out_channels=out_channels,
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=1),
            nn.LeakyReLU(inplace=True)  # inplace=True: modify the input tensor in place
        )
        self.bn = nn.InstanceNorm2d(out_channels)  # InstanceNorm2d normalizes each image individually, which suits per-image style translation

    def forward(self, x, is_drop=False):
        x = self.upconv1(x)
        x = self.bn(x)
        if is_drop:
            x = F.dropout2d(x)
        return x


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.down1 = Downsample(3, 64)  # (64, 128, 128)
        self.down2 = Downsample(64, 128)  # (128, 64, 64)
        self.down3 = Downsample(128, 256)  # (256, 32, 32)
        self.down4 = Downsample(256, 512)  # (512, 16, 16)
        self.down5 = Downsample(512, 512)  # (512, 8, 8)
        self.down6 = Downsample(512, 512)  # (512, 4, 4)

        self.up1 = Upsample(512, 512)  # (512, 8, 8)
        self.up2 = Upsample(1024, 512)  # (512, 16, 16)
        self.up3 = Upsample(1024, 256)  # (256, 32, 32)
        self.up4 = Upsample(512, 128)  # (128, 64, 64)
        self.up5 = Upsample(256, 64)  # (64, 128, 128)

        self.last = nn.ConvTranspose2d(128, 3,
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=1)

    def forward(self, x):
        x1 = self.down1(x)
        x2 = self.down2(x1)
        x3 = self.down3(x2)
        x4 = self.down4(x3)
        x5 = self.down5(x4)
        x6 = self.down6(x5)

        x6 = self.up1(x6, is_drop=True)
        x6 = torch.cat([x6, x5], dim=1)
        x6 = self.up2(x6, is_drop=True)
        x6 = torch.cat([x6, x4], dim=1)
        x6 = self.up3(x6, is_drop=True)
        x6 = torch.cat([x6, x3], dim=1)
        x6 = self.up4(x6)
        x6 = torch.cat([x6, x2], dim=1)
        x6 = self.up5(x6)
        x6 = torch.cat([x6, x1], dim=1)

        x6 = torch.tanh(self.last(x6))
        return x6


print("00000000")
# Make sure the output folder exists
output_folder = "output"
os.makedirs(output_folder, exist_ok=True)

# Instantiate the generator
gen_AB = Generator()
# Load the saved state dict (map_location keeps this runnable on CPU-only machines)
gen_AB.load_state_dict(torch.load("best_cycleGAN_model.pth", map_location="cpu"))
# Load and preprocess the input image
img_path = "input.jpg"
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((256, 256)),
    transforms.Normalize(0.5, 0.5)
])
img = Image.open(img_path).convert("RGB")  # ensure 3 channels
img = transform(img)
img = img.unsqueeze(0)  # add a batch dimension
output = gen_AB(img).detach().cpu()
# Un-normalize from [-1, 1] back to [0, 1] and save the translated image
torchvision.utils.save_image((output + 1) / 2, os.path.join(output_folder, "output.jpg"))
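
# Optional: save input and translation side by side for quick comparison
pair = torch.cat([(img + 1) / 2, (output + 1) / 2], dim=0)
torchvision.utils.save_image(pair, os.path.join(output_folder, "comparison.jpg"), nrow=2)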


 
