GAN-生成对抗网络(PyTorch)合集(2)--pix2pix 与 CycleGAN

pix2pix(pixel-to-pixel,像素到像素)

原文链接:https://arxiv.org/pdf/1611.07004.pdf
将一个域的图片转换为另一个域的图片(例如把白天的照片转成黑夜)。
如下图所示,输入标注图片,输出真实图片。缺点是训练集中两个域的图片必须一一对应(成对数据),所以叫 pix2pix。
在这里插入图片描述

网络结构稍复杂:生成器采用语义分割中常用的 U-Net 结构(编码器-解码器加跳跃连接),判别器输出一张逐块(patch)判别真假的概率图。
在这里插入图片描述
数据集:
地址忘了,用的也是官方数据集,想起来再补。代码中 base 目录作为训练集、extended 目录作为测试集,其中 .jpg 是真实的建筑立面照片,.png 是对应的标注(分割)图。
代码:这里是由建筑物标注图生成立面照片(labels → facade)的例子

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
import torchvision
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
import os
import glob
from PIL import Image

# Training data lives in the "base" directory:
#   *.jpg -> real facade photos (generator targets)
#   *.png -> matching label/segmentation maps (generator inputs)
# CMP_dataset pairs the two lists BY INDEX, so both must be sorted.
# glob() returns files in arbitrary filesystem order, which would silently
# pair a photo with the wrong label map.
# os.path.join keeps the pattern portable: a literal backslash is only a
# path separator on Windows.
images_path = sorted(glob.glob(os.path.join('base', '*.jpg')))
annos_path = sorted(glob.glob(os.path.join('base', '*.png')))

# Shared preprocessing for photos and label maps:
# PIL image -> float tensor in [0, 1] -> 256 x 256 -> [-1, 1].
# Normalize(0.5, 0.5) computes (x - 0.5) / 0.5, which matches the range of
# the generator's tanh output.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((256, 256)),
    transforms.Normalize(0.5, 0.5)
])


class CMP_dataset(data.Dataset):
    """Paired (label map, photo) dataset for pix2pix.

    Item ``i`` pairs ``annos_path[i]`` with ``imgs_path[i]``. Callers must
    therefore pass two lists in matching order, e.g. both sorted by file name.

    Each item is ``(anno_tensor, img_tensor)``, both produced by the
    module-level ``transform``.

    Raises:
        ValueError: if the two path lists have different lengths.
    """

    def __init__(self, imgs_path, annos_path):
        if len(imgs_path) != len(annos_path):
            # Index-based pairing is meaningless when the counts differ.
            raise ValueError(
                f'imgs_path has {len(imgs_path)} files but annos_path has '
                f'{len(annos_path)}; every photo needs exactly one label map'
            )
        self.imgs_path = imgs_path
        self.annos_path = annos_path

    def __getitem__(self, item):
        # Force both images to 3-channel RGB. A grayscale or RGBA file would
        # otherwise give a tensor whose channel count differs from the 3 the
        # networks expect, and batching would fail.
        # `with` closes the underlying files as soon as they are decoded.
        with Image.open(self.imgs_path[item]) as img:
            pil_img = transform(img.convert('RGB'))
        with Image.open(self.annos_path[item]) as anno:
            pil_anno = transform(anno.convert('RGB'))
        # Order is (input, target): the label map conditions the generator.
        return pil_anno, pil_img

    def __len__(self):
        return len(self.imgs_path)


# Build the paired dataset and a shuffling loader over it.
dataset = CMP_dataset(images_path, annos_path)
batchsize = 32
dataloader = data.DataLoader(dataset, batch_size=batchsize, shuffle=True)

# Sanity check: show the first three (label map, photo) pairs of one batch.
# Tensors are CHW in [-1, 1]; imshow wants HWC in [0, 1].
annos_batch, images_batch = next(iter(dataloader))

for row, pair in enumerate(zip(annos_batch[:3], images_batch[:3])):
    for col, (title, tensor) in enumerate(zip(('input_img', 'output_img'), pair)):
        plt.subplot(3, 2, row * 2 + col + 1)
        plt.title(title)
        plt.imshow((tensor.permute(1, 2, 0).numpy() + 1) / 2)
plt.show()

# Encoder block: halves the spatial size (3x3 conv, stride 2, padding 1).
class Downsample(nn.Module):
    """Strided conv -> LeakyReLU -> optional BatchNorm.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                         stride=2, padding=1)
        self.conv_relu = nn.Sequential(conv, nn.LeakyReLU(inplace=True))
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x, is_bn=True):
        """Return the downsampled map; ``is_bn=False`` skips the BatchNorm."""
        out = self.conv_relu(x)
        return self.bn(out) if is_bn else out


# Decoder block: doubles the spatial size (3x3 transposed conv, stride 2,
# padding 1, output_padding 1).
class Upsample(nn.Module):
    """Transposed conv -> LeakyReLU -> BatchNorm -> optional channel dropout.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the transposed convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        upconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=3,
                                    stride=2, padding=1, output_padding=1)
        self.upconv_relu = nn.Sequential(upconv, nn.LeakyReLU(inplace=True))
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x, is_drop=False):
        """Upsample ``x``; ``is_drop=True`` adds F.dropout2d with p=0.5.

        F.dropout2d is called with its default ``training=True``, so the
        dropout stays active even after ``model.eval()``.
        """
        out = self.bn(self.upconv_relu(x))
        return F.dropout2d(out) if is_drop else out


# U-Net generator: 6 downsampling blocks, 5 upsampling blocks with skip
# connections, and a final transposed conv back to 3 channels.
class Generator(nn.Module):
    """pix2pix U-Net generator; output passes through tanh, so it is in [-1, 1].

    The shape comments assume a 3 x 256 x 256 input.
    """

    def __init__(self):
        super().__init__()
        # Encoder                            output (C, H, W)
        self.down1 = Downsample(3, 64)      # 64, 128, 128
        self.down2 = Downsample(64, 128)    # 128, 64, 64
        self.down3 = Downsample(128, 256)   # 256, 32, 32
        self.down4 = Downsample(256, 512)   # 512, 16, 16
        self.down5 = Downsample(512, 512)   # 512, 8, 8
        self.down6 = Downsample(512, 512)   # 512, 4, 4

        # Decoder: from up2 on, each block's input is the previous output
        # concatenated with the matching encoder map, hence the doubled
        # in_channels.
        self.up1 = Upsample(512, 512)       # 512, 8, 8
        self.up2 = Upsample(1024, 512)      # 512, 16, 16
        self.up3 = Upsample(1024, 256)      # 256, 32, 32
        self.up4 = Upsample(512, 128)       # 128, 64, 64
        self.up5 = Upsample(256, 64)        # 64, 128, 128

        # 128 = up5 output (64) + down1 skip (64); restores the input size.
        self.last = nn.ConvTranspose2d(128, 3, kernel_size=3, stride=2,
                                       padding=1, output_padding=1)

    def forward(self, x):
        # Encoder: keep every intermediate map for the skip connections.
        skips = []
        h = x
        for block in (self.down1, self.down2, self.down3,
                      self.down4, self.down5, self.down6):
            h = block(h)
            skips.append(h)

        # Decoder: start from the bottleneck (down6 output). After each up
        # block, concatenate the next encoder map, deepest first.
        # Dropout is on for up1-up4 only.
        h = skips.pop()
        decoder = ((self.up1, True), (self.up2, True), (self.up3, True),
                   (self.up4, True), (self.up5, False))
        for block, drop in decoder:
            h = torch.cat([block(h, is_drop=drop), skips.pop()], dim=1)

        return torch.tanh(self.last(h))


# Conditional patch discriminator: input is the label map and the photo
# concatenated along the channel axis.
class Discriminator(nn.Module):
    """Return a per-patch real/fake probability map for an (anno, img) pair.

    For 256 x 256 inputs the output is batch x 1 x 60 x 60.
    """

    def __init__(self):
        super().__init__()
        self.down1 = Downsample(6, 64)       # 64 x 128 x 128
        self.down2 = Downsample(64, 128)     # 128 x 64 x 64
        self.conv1 = nn.Conv2d(128, 256, 3)  # 256 x 62 x 62 (no padding)
        self.bn1 = nn.BatchNorm2d(256)
        self.conv2 = nn.Conv2d(256, 1, 3)    # 1 x 60 x 60

    def forward(self, anno, img):
        pair = torch.cat([anno, img], dim=1)  # batch x 6 x H x W
        # The first block is not normalized.
        feats = self.down2(self.down1(pair, is_bn=False))
        feats = self.bn1(F.leaky_relu(self.conv1(feats)))
        feats = F.dropout2d(feats)
        return torch.sigmoid(self.conv2(feats))


# Use the GPU when available and report which device was picked.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
if device != 'cuda':
    print(device)
else:
    print('using cuda:', torch.cuda.get_device_name(0))

Gen = Generator().to(device)
Dis = Discriminator().to(device)

# Both networks use Adam with lr=1e-3 and betas=(0.5, 0.999).
adam_config = {'lr': 1e-3, 'betas': (0.5, 0.999)}
d_optimizer = torch.optim.Adam(Dis.parameters(), **adam_config)
g_optimizer = torch.optim.Adam(Gen.parameters(), **adam_config)

# Adversarial (cGAN) term: binary cross-entropy on the discriminator's
# sigmoid outputs. The L1 term (mean absolute difference between generated
# and real images) is computed inline in the training loop.
loss_fn = torch.nn.BCELoss()
# Plotting helper
def generator_images(model, test_anno, test_real):
    """Show input label map, ground truth and generator output side by side.

    Only the first sample of each batch is displayed.

    Args:
        model: generator mapping a batch of label maps to images.
        test_anno: input batch, N x C x H x W, normalized to [-1, 1].
        test_real: ground-truth batch with the same layout.
    """
    # Inference only: no autograd graph is needed for a preview.
    with torch.no_grad():
        prediction = model(test_anno)

    def to_display(batch):
        # NCHW tensor in [-1, 1] -> NHWC numpy array in [0, 1]. Without the
        # rescale, imshow clips every negative value to 0 and the images
        # render dark and distorted.
        arr = batch.permute(0, 2, 3, 1).detach().cpu().numpy()
        return (arr * 0.5 + 0.5).clip(0, 1)

    display_list = [to_display(test_anno)[0], to_display(test_real)[0],
                    to_display(prediction)[0]]
    title = ['input', 'ground truth', 'output']
    plt.figure(figsize=(10, 10))
    for i in range(3):
        plt.subplot(1, 3, i + 1)
        plt.title(title[i])
        plt.imshow(display_list[i])
        plt.axis('off')
    plt.show()

# Held-out data: the "extended" directory, loaded like the training set.
# Both lists are sorted so photo i is paired with label map i (glob order
# is arbitrary); os.path.join keeps the pattern portable.
test_imgs_path = sorted(glob.glob(os.path.join('extended', '*.jpg')))
test_annos_path = sorted(glob.glob(os.path.join('extended', '*.png')))

test_dataset = CMP_dataset(test_imgs_path, test_annos_path)
test_daloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batchsize
)

# Fetch one held-out batch. The training loop passes it to generator_images
# after every epoch, so progress is shown on images the generator never
# trains on. (Previously this read from the training loader, leaving the
# test loader unused.)
annos_batch, images_batch = next(iter(test_daloader))

# Preview the first three (label map, photo) pairs, mapped from [-1, 1] to
# [0, 1] for display.
plt.figure(figsize=(6, 10))
for i, (anno, img) in enumerate(zip(annos_batch[:3], images_batch[:3])):
    anno = (anno.permute(1, 2, 0).numpy()+1)/2
    img = (img.permute(1, 2, 0).numpy()+1)/2
    plt.subplot(3, 2, i*2+1)
    plt.title('input_img')
    plt.imshow(anno)

    plt.subplot(3, 2, i*2+2)
    plt.title('output_img')
    plt.imshow(img)
plt.show()

# Move the fixed preview batch to the device once; it is passed to
# generator_images at the end of every epoch.
annos_batch, images_batch = annos_batch.to(device), images_batch.to(device)
LAMBDA = 7  # weight of the L1 reconstruction term in the generator loss

D_loss = []  # mean discriminator loss per epoch
G_loss = []  # mean generator loss per epoch
for epoch in range(300):
    D_epoch_loss = 0
    G_epoch_loss = 0
    count = len(dataloader)
    for step, (annos, imgs) in enumerate(dataloader):
        imgs = imgs.to(device)
        annos = annos.to(device)

        # ---- Discriminator: real pairs -> 1, generated pairs -> 0 ----
        d_optimizer.zero_grad()
        disc_real_output = Dis(annos, imgs)  # real (label map, photo) pair
        d_real_loss = loss_fn(disc_real_output,
                              torch.ones_like(disc_real_output, device=device))
        d_real_loss.backward()

        gen_output = Gen(annos)
        # detach(): the discriminator update must not backprop into Gen.
        dis_gen_output = Dis(annos, gen_output.detach())
        d_fake_loss = loss_fn(dis_gen_output,
                              torch.zeros_like(dis_gen_output, device=device))
        d_fake_loss.backward()

        disc_loss = d_real_loss + d_fake_loss

        d_optimizer.step()

        # ---- Generator: fool Dis and stay close to the target (L1) ----
        # Clear Gen's gradients from the previous iteration. Without this
        # call they accumulate across every step of training.
        g_optimizer.zero_grad()
        disc_gen_out = Dis(annos, gen_output)
        gen_loss_crossentropyloss = loss_fn(
            disc_gen_out, torch.ones_like(disc_gen_out, device=device))
        gen_l1_loss = torch.mean(torch.abs(gen_output - imgs))
        gen_loss = LAMBDA * gen_l1_loss + gen_loss_crossentropyloss
        # This backward also writes into Dis's .grad. Those gradients are
        # cleared by d_optimizer.zero_grad() at the start of the next step.
        gen_loss.backward()
        g_optimizer.step()

        D_epoch_loss += disc_loss.item()
        G_epoch_loss += gen_loss.item()
    with torch.no_grad():
        D_epoch_loss /= count
        G_epoch_loss /= count
        D_loss.append(D_epoch_loss)
        G_loss.append(G_epoch_loss)
        print('Epoch', epoch)
        generator_images(Gen, annos_batch, images_batch)

给动漫素描自动上色(AI 上色)的例子,请移步我的 Kaggle:
https://www.kaggle.com/code/jiyuanhai/pix2pix-test-pytorch

CycleGAN

这个厉害👍,我愿称之为最强:它克服了 pix2pix 需要成对(一一对应)数据集的缺点。
论文地址:https://arxiv.org/pdf/1703.10593.pdf
【推荐同济子豪兄】或者看论文详解:https://www.bilibili.com/video/BV1Ya411a78P?spm_id_from=333.999.0.0&vd_source=66d85dad339b02807124d27ef76332c9
B站也有很多讲的不错的视频。
创新性地提出了循环一致性损失(cycle-consistency loss):图片经过 A→B→A(或 B→A→B)两次转换后应还原为原图,用 L1 距离约束两者的差异,因此不再需要成对数据。具体技术细节不多赘述,有些复杂。

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
import torchvision
from torchvision import transforms

import numpy as np
import matplotlib.pyplot as plt
import os
import glob
from PIL import Image
import itertools

# Unpaired training images: domain A = apples, domain B = oranges.
# os.path.join keeps the patterns portable. A literal backslash is only a
# path separator on Windows, so r'trainA\*.jpg' matches nothing elsewhere.
apples_path = glob.glob(os.path.join('trainA', '*.jpg'))

# Optional preview of four apples and their array shapes:
# plt.figure(figsize=(8, 8))
# for i, imh_path in enumerate(apples_path[:4]):
#     img = Image.open(imh_path)
#     np_image = np.array(img)
#     plt.subplot(2, 2, i+1)
#     plt.imshow(np_image)
#     plt.title(str(np_image.shape))
# plt.show()

oranges_path = glob.glob(os.path.join('trainB', '*.jpg'))

# Optional preview of four oranges and their array shapes:
# plt.figure(figsize=(8, 8))
# for i, imh_path in enumerate(oranges_path[:4]):
#     img = Image.open(imh_path)
#     np_image = np.array(img)
#     plt.subplot(2, 2, i+1)
#     plt.imshow(np_image)
#     plt.title(str(np_image.shape))
# plt.show()

# Held-out apples, used only to visualize gen_AB during training. They must
# come from testA: reading trainA here would preview images the generator
# is trained on.
apples_test_path = glob.glob(os.path.join('testA', '*.jpg'))

# The images are already 256 x 256, so no resize or crop is needed. Only
# convert to a tensor in [0, 1] and normalize to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(0.5, 0.5)
])

class AO_Dataset(data.Dataset):
    """Single-domain, unpaired image dataset (apples or oranges).

    Each item is one image passed through the module-level ``transform``.
    """

    def __init__(self, img_path):
        # img_path: list of image file paths for one domain.
        self.img_path = img_path

    def __getitem__(self, index):
        # Force 3 channels. A grayscale ('L') or RGBA file would otherwise
        # produce a 1- or 4-channel tensor, which breaks batching with the
        # other images and does not fit the 3-channel generators.
        # `with` closes the file once the image has been decoded.
        with Image.open(self.img_path[index]) as pil_img:
            return transform(pil_img.convert('RGB'))

    def __len__(self):
        return len(self.img_path)


# One dataset per domain, plus held-out apples for visualization.
apple_dataset = AO_Dataset(apples_path)
orange_dataset = AO_Dataset(oranges_path)
apple_test_dataset = AO_Dataset(apples_test_path)

BATHSIZE = 2
NUMWORKERS = 10  # currently unused; see _shuffled_loader


def _shuffled_loader(dataset):
    """Return a shuffling DataLoader with BATHSIZE samples per batch.

    Add ``num_workers=NUMWORKERS`` to load batches in worker processes.
    """
    return data.DataLoader(dataset, batch_size=BATHSIZE, shuffle=True)


apple_dataloader = _shuffled_loader(apple_dataset)
orange_dataloader = _shuffled_loader(orange_dataset)
apple_dl_test = _shuffled_loader(apple_test_dataset)
# Model definitions.
# Encoder block: halves the spatial size (3x3 conv, stride 2, padding 1).
class Downsample(nn.Module):
    """Strided conv -> LeakyReLU -> optional InstanceNorm.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                         stride=2, padding=1)
        self.conv_relu = nn.Sequential(conv, nn.LeakyReLU(inplace=True))
        # Named `bn` for checkpoint compatibility, but it is an InstanceNorm.
        self.bn = nn.InstanceNorm2d(out_channels)

    def forward(self, x, is_bn=True):
        """Return the downsampled map; ``is_bn=False`` skips normalization."""
        out = self.conv_relu(x)
        return self.bn(out) if is_bn else out

# Decoder block: doubles the spatial size (3x3 transposed conv, stride 2,
# padding 1, output_padding 1).
class Upsample(nn.Module):
    """Transposed conv -> LeakyReLU -> InstanceNorm -> optional dropout.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the transposed convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        upconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=3,
                                    stride=2, padding=1, output_padding=1)
        self.upconv_relu = nn.Sequential(upconv, nn.LeakyReLU(inplace=True))
        # Named `bn` for checkpoint compatibility, but it is an InstanceNorm.
        self.bn = nn.InstanceNorm2d(out_channels)

    def forward(self, x, is_drop=False):
        """Upsample ``x``; ``is_drop=True`` adds F.dropout2d with p=0.5.

        F.dropout2d is called with its default ``training=True``, so the
        dropout stays active even after ``model.eval()``.
        """
        out = self.bn(self.upconv_relu(x))
        return F.dropout2d(out) if is_drop else out

# U-Net generator: 6 downsampling blocks, 5 upsampling blocks with skip
# connections, and a final transposed conv back to 3 channels.
class Generator(nn.Module):
    """U-Net generator; output passes through tanh, so it is in [-1, 1].

    The shape comments assume a 3 x 256 x 256 input.
    """

    def __init__(self):
        super().__init__()
        # Encoder                            output (C, H, W)
        self.down1 = Downsample(3, 64)      # 64, 128, 128
        self.down2 = Downsample(64, 128)    # 128, 64, 64
        self.down3 = Downsample(128, 256)   # 256, 32, 32
        self.down4 = Downsample(256, 512)   # 512, 16, 16
        self.down5 = Downsample(512, 512)   # 512, 8, 8
        self.down6 = Downsample(512, 512)   # 512, 4, 4

        # Decoder: from up2 on, each block's input is the previous output
        # concatenated with the matching encoder map, hence the doubled
        # in_channels.
        self.up1 = Upsample(512, 512)       # 512, 8, 8
        self.up2 = Upsample(1024, 512)      # 512, 16, 16
        self.up3 = Upsample(1024, 256)      # 256, 32, 32
        self.up4 = Upsample(512, 128)       # 128, 64, 64
        self.up5 = Upsample(256, 64)        # 64, 128, 128

        # 128 = up5 output (64) + down1 skip (64); restores the input size.
        self.last = nn.ConvTranspose2d(128, 3, kernel_size=3, stride=2,
                                       padding=1, output_padding=1)

    def forward(self, x):
        # Encoder: keep every intermediate map for the skip connections.
        skips = []
        h = x
        for block in (self.down1, self.down2, self.down3,
                      self.down4, self.down5, self.down6):
            h = block(h)
            skips.append(h)

        # Decoder: start from the bottleneck (down6 output). After each up
        # block, concatenate the next encoder map, deepest first.
        # Dropout is on for up1-up4 only.
        h = skips.pop()
        decoder = ((self.up1, True), (self.up2, True), (self.up3, True),
                   (self.up4, True), (self.up5, False))
        for block, drop in decoder:
            h = torch.cat([block(h, is_drop=drop), skips.pop()], dim=1)

        return torch.tanh(self.last(h))

# Patch discriminator for one domain (unconditional: it sees only the image).
class Discriminator(nn.Module):
    """Return a per-patch real/fake probability map for an image batch.

    For 256 x 256 inputs the output is batch x 1 x 62 x 62
    (256 -> 128 -> 64 through the two stride-2 blocks, then 64 -> 62 through
    the unpadded 3x3 conv).
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        self.down1 = Downsample(3, 64)    # 64 x 128 x 128
        self.down2 = Downsample(64, 128)  # 128 x 64 x 64
        self.last = nn.Conv2d(128, 1, 3)  # 1 x 62 x 62 (no padding)

    def forward(self, img):
        # The first layer is not normalized. The CycleGAN paper omits
        # InstanceNorm on the first C64 layer of its discriminator, and the
        # pix2pix Discriminator above does the same with is_bn=False.
        x = self.down1(img, is_bn=False)
        x = self.down2(x)
        x = torch.sigmoid(self.last(x))  # batch x 1 x 62 x 62
        return x


# Use the GPU when available and report which device was picked.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
if device != 'cuda':
    print(device)
else:
    print('using cuda:', torch.cuda.get_device_name(0))

# Two generators: gen_AB maps domain A -> B, gen_BA maps B -> A.
gen_AB = Generator().to(device)
gen_BA = Generator().to(device)

# One discriminator per domain.
dis_A = Discriminator().to(device)
dis_B = Discriminator().to(device)

# Losses:
#   1. adversarial loss: BCE on the discriminators' sigmoid outputs
#   2. cycle-consistency loss: L1
#   3. identity loss: L1
BECLoss = torch.nn.BCELoss()
L1_loss = torch.nn.L1Loss()

# Every optimizer is Adam(lr=2e-4, betas=(0.5, 0.999)). The two generators
# share a single optimizer and are updated jointly.
adam_config = {'lr': 2e-4, 'betas': (0.5, 0.999)}
gen_optimizer = torch.optim.Adam(
    itertools.chain(gen_AB.parameters(), gen_BA.parameters()),
    **adam_config
)
dis_A_optimizer = torch.optim.Adam(dis_A.parameters(), **adam_config)
dis_B_optimizer = torch.optim.Adam(dis_B.parameters(), **adam_config)


def generate_images(model, test_input):
    """Plot the first image of ``test_input`` next to ``model``'s translation.

    Args:
        model: generator to visualize (e.g. gen_AB).
        test_input: batch N x C x H x W, normalized to [-1, 1].
    """
    # Inference only. no_grad makes the helper safe to call outside a
    # torch.no_grad() block as well.
    with torch.no_grad():
        prediction = model(test_input)
    prediction = prediction.permute(0, 2, 3, 1).cpu().numpy()
    test_input = test_input.permute(0, 2, 3, 1).detach().cpu().numpy()
    plt.figure(figsize=(10, 6))
    display_list = [test_input[0], prediction[0]]
    title = ['Input Image', 'Generated Image']
    for i in range(2):
        plt.subplot(1, 2, i + 1)
        plt.title(title[i])
        # Map [-1, 1] back to [0, 1] for imshow.
        plt.imshow(display_list[i] * 0.5 + 0.5)
        plt.axis('off')
    plt.show()


# Fixed preview input: the first apple of one held-out batch, kept 4-D
# (1 x C x H x W) so it can be fed to the generator directly.
test_batch = next(iter(apple_dl_test))
test_input = torch.unsqueeze(test_batch[0], 0).to(device)

D_loss = []  # per-epoch mean of dis_A_loss + dis_B_loss
G_loss = []  # per-epoch mean of the total generator loss
epoches = 50
for epoch in range(epoches):
    d_epoch_loss = 0
    g_epoch_loss = 0
    n_batches = 0
    # The domains are unpaired. zip draws one batch from each loader and
    # stops when the shorter one is exhausted.
    for real_A, real_B in zip(apple_dataloader, orange_dataloader):
        real_A = real_A.to(device)
        real_B = real_B.to(device)

        # ---------- generators (gen_AB and gen_BA, one optimizer) ----------
        gen_optimizer.zero_grad()

        # Identity loss: a generator fed an image that is already in its
        # target domain should return it unchanged.
        same_B = gen_AB(real_B)
        identity_B_loss = L1_loss(same_B, real_B)

        same_A = gen_BA(real_A)
        identity_A_loss = L1_loss(same_A, real_A)

        # Adversarial loss: each generator tries to make the target domain's
        # discriminator output 1 ("real") for its fakes.
        fake_B = gen_AB(real_A)
        D_pre_fake_B = dis_B(fake_B)
        gen_loss_AB = BECLoss(D_pre_fake_B,
                              torch.ones_like(D_pre_fake_B, device=device))

        fake_A = gen_BA(real_B)
        D_pre_fake_A = dis_A(fake_A)
        gen_loss_BA = BECLoss(D_pre_fake_A,
                              torch.ones_like(D_pre_fake_A, device=device))

        # Cycle-consistency loss: A -> B -> A and B -> A -> B should
        # reconstruct the original images.
        recovered_A = gen_BA(fake_B)
        cycle_loss_ABA = L1_loss(recovered_A, real_A)

        recovered_B = gen_AB(fake_A)
        cycle_loss_BAB = L1_loss(recovered_B, real_B)

        # All six terms are weighted equally.
        g_loss = identity_A_loss + identity_B_loss + gen_loss_AB + \
            gen_loss_BA + cycle_loss_ABA + cycle_loss_BAB

        g_loss.backward()
        gen_optimizer.step()

        # ---------- dis_A: real apples -> 1, generated apples -> 0 ----------
        # zero_grad also clears the gradients g_loss.backward() left in dis_A.
        dis_A_optimizer.zero_grad()
        dis_A_real_output = dis_A(real_A)
        dis_A_real_loss = BECLoss(dis_A_real_output,
                                  torch.ones_like(dis_A_real_output, device=device))
        # detach(): this update must not backprop into the generators.
        dis_A_fake_output = dis_A(fake_A.detach())
        dis_A_fake_loss = BECLoss(dis_A_fake_output,
                                  torch.zeros_like(dis_A_fake_output, device=device))

        dis_A_loss = dis_A_real_loss + dis_A_fake_loss
        dis_A_loss.backward()
        dis_A_optimizer.step()

        # ---------- dis_B: real oranges -> 1, generated oranges -> 0 --------
        dis_B_optimizer.zero_grad()
        dis_B_real_output = dis_B(real_B)
        dis_B_real_loss = BECLoss(dis_B_real_output,
                                  torch.ones_like(dis_B_real_output, device=device))

        dis_B_fake_output = dis_B(fake_B.detach())
        dis_B_fake_loss = BECLoss(dis_B_fake_output,
                                  torch.zeros_like(dis_B_fake_output, device=device))
        dis_B_loss = dis_B_fake_loss + dis_B_real_loss
        dis_B_loss.backward()
        dis_B_optimizer.step()

        g_epoch_loss += g_loss.item()
        d_epoch_loss += (dis_A_loss + dis_B_loss).item()
        n_batches += 1

    # Fail with a clear message instead of a NameError on an undefined loop
    # index when a loader yields nothing (e.g. wrong data directory).
    if n_batches == 0:
        raise RuntimeError('no training batches: check that trainA/ and '
                           'trainB/ contain images')
    with torch.no_grad():
        g_epoch_loss /= n_batches
        d_epoch_loss /= n_batches
        D_loss.append(d_epoch_loss)
        G_loss.append(g_epoch_loss)
        print('Epoch:', epoch+1)
        print('g_epoch_loss:', g_epoch_loss)
        print('d_epoch_loss:', d_epoch_loss)
        generate_images(gen_AB, test_input)  # test_input is an apple

# Save all four networks as whole pickled modules in the legacy (non-zip)
# format. Loading these files requires the model classes to be importable;
# saving model.state_dict() instead would avoid that dependency.
checkpoints = (
    ('Gen_AB.pth', gen_AB),
    ('Gen_BA.pth', gen_BA),
    ('Dis_B.pth', dis_B),
    ('Dis_A.pth', dis_A),
)
for filename, net in checkpoints:
    torch.save(net, filename, _use_new_zipfile_serialization=False)


  • 1
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

JiYH

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值