【Generative Networks】 Getting Started (5): Pix2Pix Code and Results Notes

Theory reference: https://zhuanlan.zhihu.com/p/464673225
Code adapted from https://github.com/LibreCV/blog/blob/master/_notebooks/2021-02-13-Pix2Pix%20explained%20with%20code.ipynb

import os
# os.chdir(os.path.dirname(__file__))
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision import datasets
from torchvision import models
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from PIL import Image
import argparse
from glob import glob
import random
import itertools

sample_dir = 'samples_pix2pix'
if not os.path.exists(sample_dir):
    os.makedirs(sample_dir, exist_ok=True)

writer = SummaryWriter(sample_dir)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
np.random.seed(0)
torch.manual_seed(0)

class DownSampleConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel=4, strides=2, padding=1, activation=True, batchnorm=True):
        """
        Paper details:
        - C64-C128-C256-C512-C512-C512-C512-C512
        - All convolutions are 4×4 spatial filters applied with stride 2
        - Convolutions in the encoder downsample by a factor of 2
        """
        super().__init__()
        self.activation = activation
        self.batchnorm = batchnorm

        self.conv = nn.Conv2d(in_channels, out_channels, kernel, strides, padding)

        if batchnorm:
            self.bn = nn.BatchNorm2d(out_channels)

        if activation:
            self.act = nn.LeakyReLU(0.2)

    def forward(self, x):
        x = self.conv(x)
        if self.batchnorm:
            x = self.bn(x)
        if self.activation:
            x = self.act(x)
        return x

class UpSampleConv(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel=4,
        strides=2,
        padding=1,
        activation=True,
        batchnorm=True,
        dropout=False
    ):
        super().__init__()
        self.activation = activation
        self.batchnorm = batchnorm
        self.dropout = dropout

        self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel, strides, padding)

        if batchnorm:
            self.bn = nn.BatchNorm2d(out_channels)

        if activation:
            self.act = nn.ReLU(True)

        if dropout:
            self.drop = nn.Dropout2d(0.5)

    def forward(self, x):
        x = self.deconv(x)
        if self.batchnorm:
            x = self.bn(x)

        if self.dropout:
            x = self.drop(x)

        if self.activation:
            x = self.act(x)
        return x

class Generator(nn.Module):
    def __init__(self, in_channels, out_channels):
        """
        Paper details:
        - Encoder: C64-C128-C256-C512-C512-C512-C512-C512
        - All convolutions are 4×4 spatial filters applied with stride 2
        - Convolutions in the encoder downsample by a factor of 2
        - Decoder: CD512-CD1024-CD1024-C1024-C1024-C512-C256-C128
        """
        super().__init__()

        # encoder/downsample convs
        self.encoders = [
            DownSampleConv(in_channels, 64, batchnorm=False),  # bs x 64 x 128 x 128
            DownSampleConv(64, 128),  # bs x 128 x 64 x 64
            DownSampleConv(128, 256),  # bs x 256 x 32 x 32
            DownSampleConv(256, 512),  # bs x 512 x 16 x 16
            DownSampleConv(512, 512),  # bs x 512 x 8 x 8
            DownSampleConv(512, 512),  # bs x 512 x 4 x 4
            DownSampleConv(512, 512),  # bs x 512 x 2 x 2
            DownSampleConv(512, 512, batchnorm=False),  # bs x 512 x 1 x 1
        ]

        # decoder/upsample convs
        self.decoders = [
            UpSampleConv(512, 512, dropout=True),  # bs x 512 x 2 x 2
            UpSampleConv(1024, 512, dropout=True),  # bs x 512 x 4 x 4
            UpSampleConv(1024, 512, dropout=True),  # bs x 512 x 8 x 8
            UpSampleConv(1024, 512),  # bs x 512 x 16 x 16
            UpSampleConv(1024, 256),  # bs x 256 x 32 x 32
            UpSampleConv(512, 128),  # bs x 128 x 64 x 64
            UpSampleConv(256, 64),  # bs x 64 x 128 x 128
        ]
        self.decoder_channels = [512, 512, 512, 512, 256, 128, 64]
        self.final_conv = nn.ConvTranspose2d(64, out_channels, kernel_size=4, stride=2, padding=1)
        self.tanh = nn.Tanh()

        self.encoders = nn.ModuleList(self.encoders)
        self.decoders = nn.ModuleList(self.decoders)

    def forward(self, x):
        skips_cons = []
        for encoder in self.encoders:
            x = encoder(x)

            skips_cons.append(x)

        skips_cons = list(reversed(skips_cons[:-1]))
        decoders = self.decoders[:-1]

        for decoder, skip in zip(decoders, skips_cons):
            x = decoder(x)
            # print(x.shape, skip.shape)
            x = torch.cat((x, skip), axis=1)

        x = self.decoders[-1](x)
        # print(x.shape)
        x = self.final_conv(x)
        return self.tanh(x)

class PatchGAN(nn.Module):
    def __init__(self, input_channels):
        super().__init__()
        self.d1 = DownSampleConv(input_channels, 64, batchnorm=False)
        self.d2 = DownSampleConv(64, 128)
        self.d3 = DownSampleConv(128, 256)
        self.d4 = DownSampleConv(256, 512)
        self.final = nn.Conv2d(512, 1, kernel_size=1)

    def forward(self, x, y):
        x = torch.cat([x, y], axis=1)
        x0 = self.d1(x)
        x1 = self.d2(x0)
        x2 = self.d3(x1)
        x3 = self.d4(x2)
        xn = self.final(x3)
        return xn

def _weights_init(m):
    # DCGAN-style initialization: conv weights ~ N(0, 0.02); BatchNorm scale ~ N(1, 0.02), bias 0
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    if isinstance(m, nn.BatchNorm2d):
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)

class ImageDataset(torch.utils.data.Dataset):
    def __init__(self, root, transforms=None, unaligned=False, mode='train'):
        self.transforms = transforms
        self.unaligned = unaligned
        self.files_A = sorted(glob(os.path.join(root, mode, 'A', '*.*')))
        self.files_B = sorted(glob(os.path.join(root, mode, 'B', '*.*')))

    def __getitem__(self, idx):
        img = Image.open(self.files_A[idx % len(self.files_A)]).convert('RGB')
        itemA = self.transforms(img)

        if self.unaligned:
            rand_idx = random.randint(0, len(self.files_B)-1)
            img = Image.open(self.files_B[rand_idx]).convert('RGB')
            itemB = self.transforms(img)
        else:
            img = Image.open(self.files_B[idx % len(self.files_B)]).convert('RGB')
            itemB = self.transforms(img)

        return {
            'A' : itemA,
            'B' : itemB
        }

    def __len__(self):
        return max(len(self.files_A), len(self.files_B))

def denorm(x):
    # map a tensor normalized to [-1, 1] back into [0, 1] for visualization
    out = (x + 1) / 2
    return out.clamp(0, 1)

# Losses
adv_criterion = nn.BCEWithLogitsLoss() 
recon_criterion = nn.L1Loss() 

lambda_recon = 200
n_epochs = 200
display_step = 100
batch_size = 4
lr = 0.0002
target_size = 256
input_size = 256

dataroot = 'data/cycle_gan/datasets/facades'
input_nc = 3
output_nc = 3
G = Generator(input_nc, output_nc).to(device)
D = PatchGAN(input_nc + output_nc).to(device)


G.apply(_weights_init)
D.apply(_weights_init)

optimG = torch.optim.Adam(G.parameters(), lr=lr)
optimD = torch.optim.Adam(D.parameters(), lr=lr)


# Dataset loader
transforms_data = transforms.Compose([ 
                transforms.Resize(int(input_size*1.12), Image.BICUBIC), 
                transforms.RandomCrop(input_size), 
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)) 
                ])

dataset = ImageDataset(dataroot, transforms=transforms_data, unaligned=False)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=1, drop_last=True)

###### Training ######
cnt = 0
log_step = 10
for epoch in range(0, n_epochs):
    for i, batch in enumerate(dataloader):
        # set model input
        real = batch['A'].to(device)
        condition = batch['B'].to(device)

        # discriminator
        fake_images = G(condition).detach()
        fake_logits = D(fake_images, condition)

        real_logits = D(real, condition)

        fake_loss = adv_criterion(fake_logits, torch.zeros_like(fake_logits))
        real_loss = adv_criterion(real_logits, torch.ones_like(real_logits))
        d_loss =  (real_loss + fake_loss) / 2

        optimD.zero_grad()
        d_loss.backward()
        optimD.step()

        # generator
        fake_images = G(condition)
        disc_logits = D(fake_images, condition)
        adversarial_loss = adv_criterion(disc_logits, torch.ones_like(disc_logits))

        # calculate reconstruction loss
        recon_loss = recon_criterion(fake_images, real)

        g_loss = adversarial_loss + lambda_recon * recon_loss

        optimG.zero_grad()
        g_loss.backward()
        optimG.step()

        cnt += 1
        if cnt % log_step == 0:
            print('Epoch [{}/{}], Step [{}], g_loss: {:.4f}, d_loss: {:.4f}'.\
                format(epoch, n_epochs, cnt, g_loss.item(), d_loss.item()))

            writer.add_scalar('g_loss', g_loss.item(), global_step=cnt)
            writer.add_scalar('d_loss', d_loss.item(), global_step=cnt)

        if cnt % 100 == 0:
            writer.add_images('real', denorm(real), global_step=cnt)
            writer.add_images('condition', denorm(condition), global_step=cnt)
            writer.add_images('fake_images', denorm(fake_images), global_step=cnt)

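The training loop above only logs scalars and image grids to TensorBoard; nothing is written to disk. A minimal sketch of what could be added at the end of each epoch (indented one level inside the for-epoch loop; the file names and the 10-epoch interval are my own assumptions, not part of the original script):

if (epoch + 1) % 10 == 0:
    # dump a grid of the last generated batch and snapshot both networks
    torchvision.utils.save_image(denorm(fake_images), os.path.join(sample_dir, 'fake_epoch{}.png'.format(epoch + 1)))
    torch.save(G.state_dict(), os.path.join(sample_dir, 'G_epoch{}.pth'.format(epoch + 1)))
    torch.save(D.state_dict(), os.path.join(sample_dir, 'D_epoch{}.pth'.format(epoch + 1)))
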
The overall structure follows the Conditional GAN: the paired image serves as the condition and is fed to both the generator and the discriminator (in the code above, condition = batch['B']).
(figure: Conditional GAN structure)
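
To make the conditioning concrete: the discriminator never judges an image in isolation; PatchGAN.forward stacks the candidate image and its condition along the channel dimension, so with 3-channel inputs it receives 6 channels. A standalone shape check with random tensors (illustration only, not part of the training script):

cond = torch.randn(1, 3, 256, 256)    # e.g. the label map
img = torch.randn(1, 3, 256, 256)     # a real photo or G(cond)
pair = torch.cat([img, cond], dim=1)  # what PatchGAN.forward builds internally
print(pair.shape)                     # torch.Size([1, 6, 256, 256])
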

Also worth a closer look are the U-Net design of the generator and the PatchGAN design of the discriminator; see the figure below for a detailed explanation.
(figure: U-Net generator and PatchGAN discriminator)
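
The four stride-2 convolutions in PatchGAN reduce a 256x256 input pair to a 16x16 feature map, and the final 1x1 convolution turns it into a 16x16 grid of logits, one per local patch of the input; this is why the BCE targets in the training loop are built with torch.ones_like / torch.zeros_like rather than a single scalar label. A quick check (illustration only):

d = PatchGAN(6)  # 3 channels for the image + 3 for the condition
out = d(torch.randn(1, 3, 256, 256), torch.randn(1, 3, 256, 256))
print(out.shape)  # torch.Size([1, 1, 16, 16]): one real/fake logit per patch
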
The experimental results are as follows.

  • real image
    (figure)

  • condition image
    (figure)

  • generated image: the quality is poor, probably because the model is undertrained; I will keep debugging later (a few settings worth checking are sketched after this list).
    (figure)
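
For later debugging, two settings differ from the original pix2pix paper and are worth checking first: the paper trains Adam with beta1 = 0.5 (the script above keeps the PyTorch default of 0.9) and weights the L1 reconstruction term with lambda = 100 rather than 200. A hedged tweak, intended only as a starting point:

lambda_recon = 100
optimG = torch.optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
optimD = torch.optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))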
