Hung-Yi Lee Homework 11: GAN


1. Introduction to the GAN Model

  GAN stands for Generative Adversarial Network.
  The GAN framework trains a deep learning model to learn the distribution of the training data, so that it can generate new samples that follow the same distribution.
  A GAN consists of two different models: a generator G and a discriminator D. G produces fake images whose distribution should resemble that of the training images, while D judges whether a given image is a real training image or a fake produced by G.
  During training, G tries to fool D by producing better and better fake images, and D in turn keeps improving at telling real images from fakes. Training reaches equilibrium when the generator produces fakes that look as if they came straight from the training data, and the discriminator can only guess "real" or "fake" with 50% probability.
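  Formally, the two networks play the minimax game from the original GAN paper:
$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]$$
  At the optimum, the generator's distribution matches the data distribution and D(x) = 1/2 everywhere, which is exactly the 50% guessing behaviour described above.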
  This experiment uses DCGAN as the model architecture. DCGAN combines CNNs with the GAN framework: the generator G and the discriminator D are both implemented as convolutional networks.

2. Training Process

utils.py

from torch.utils.data import Dataset, DataLoader
import cv2
import os
import glob
import torchvision.transforms as transforms
import random
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

class FaceDataset(Dataset):
    def __init__(self, fnames, transform):
        self.transform = transform
        self.fnames = fnames
        self.num_samples = len(self.fnames)
    def __getitem__(self,idx):
        fname = self.fnames[idx]
        img = cv2.imread(fname)
        img = self.BGR2RGB(img) 
        img = self.transform(img)
        return img

    def __len__(self):
        return self.num_samples

    def BGR2RGB(self,img):
        return cv2.cvtColor(img,cv2.COLOR_BGR2RGB)

def get_dataset(root):
    fnames = glob.glob(os.path.join(root, '*'))

    transform = transforms.Compose(
        [transforms.ToPILImage(),
         transforms.Resize((64, 64)),
         transforms.ToTensor(),
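         # Normalize maps pixel values from [0, 1] to [-1, 1], matching the generator's Tanh output range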
         transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3) ] )
    dataset = FaceDataset(fnames, transform)
    return dataset

def same_seeds(seed):
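    # Fix all RNG seeds (PyTorch, CUDA, NumPy, Python) and make cuDNN deterministic so runs are reproducible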
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  
    np.random.seed(seed)  
    random.seed(seed)  
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

def weights_init(m):
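    # DCGAN-style initialization: Conv weights ~ N(0, 0.02), BatchNorm weights ~ N(1, 0.02), biases set to 0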
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

class Generator(nn.Module):
    """
    input (N, in_dim)
    output (N, 3, 64, 64)
    """
    def __init__(self, in_dim, dim=64):
        super(Generator, self).__init__()
        def dconv_bn_relu(in_dim, out_dim):
            return nn.Sequential(
                nn.ConvTranspose2d(in_dim, out_dim, 5, 2,
                                   padding=2, output_padding=1, bias=False),
                nn.BatchNorm2d(out_dim),
                nn.ReLU())
        self.l1 = nn.Sequential(
            nn.Linear(in_dim, dim * 8 * 4 * 4, bias=False),
            nn.BatchNorm1d(dim * 8 * 4 * 4),
            nn.ReLU())
        self.l2_5 = nn.Sequential(
            dconv_bn_relu(dim * 8, dim * 4),
            dconv_bn_relu(dim * 4, dim * 2),
            dconv_bn_relu(dim * 2, dim),
            nn.ConvTranspose2d(dim, 3, 5, 2, padding=2, output_padding=1),
            nn.Tanh())
        self.apply(weights_init)
    def forward(self, x):
        y = self.l1(x)
        y = y.view(y.size(0), -1, 4, 4)
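        # reshape to (N, dim * 8, 4, 4); each dconv_bn_relu block doubles the spatial size (4 -> 8 -> 16 -> 32) and the final ConvTranspose2d reaches 64 x 64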
        y = self.l2_5(y)
        return y

class Discriminator(nn.Module):
    """
    input (N, 3, 64, 64)
    output (N, )
    """
    def __init__(self, in_dim, dim=64):
        super(Discriminator, self).__init__()
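        # Mirrors the generator: strided convolutions downsample 64 -> 32 -> 16 -> 8 -> 4, then a 4x4 conv collapses each image to a single real/fake probability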
        def conv_bn_lrelu(in_dim, out_dim):
            return nn.Sequential(
                nn.Conv2d(in_dim, out_dim, 5, 2, 2),
                nn.BatchNorm2d(out_dim),
                nn.LeakyReLU(0.2))
        self.ls = nn.Sequential(
            nn.Conv2d(in_dim, dim, 5, 2, 2), nn.LeakyReLU(0.2),
            conv_bn_lrelu(dim, dim * 2),
            conv_bn_lrelu(dim * 2, dim * 4),
            conv_bn_lrelu(dim * 4, dim * 8),
            nn.Conv2d(dim * 8, 1, 4),
            nn.Sigmoid())
        self.apply(weights_init)
    def forward(self, x):
        y = self.ls(x)
        y = y.view(-1)
        return y
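
  As a quick sanity check of the two architectures, a hypothetical snippet like the following (not part of the original homework) can verify the tensor shapes, assuming utils.py is saved as above:

# Hypothetical shape check, assuming the utils.py defined above is importable.
import torch
from utils import Generator, Discriminator

G = Generator(in_dim=100)   # latent dimension, matching z_dim used in training
D = Discriminator(3)        # 3 input channels (RGB)
z = torch.randn(4, 100)     # a batch of 4 random latent vectors
fake = G(z)                 # expected shape: (4, 3, 64, 64)
score = D(fake)             # expected shape: (4,)
print(fake.shape, score.shape)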

hw11.py

import torch
from torch import optim
from torch.autograd import Variable
import torchvision
from utils import *
import matplotlib.pyplot as plt

if __name__ == '__main__':
    # Hyperparameters
    batch_size = 64
    z_dim = 100
    lr = 1e-4
    n_epoch = 10
    save_dir = 'logs'
    os.makedirs(save_dir, exist_ok=True)

    # Build the models
    G = Generator(in_dim=z_dim).cuda()
    D = Discriminator(3).cuda()
    G.train()
    D.train()

    # loss criterion
    criterion = nn.BCELoss()

    # optimizer
    opt_D = torch.optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))
    opt_G = torch.optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))

    same_seeds(0)
    # Load the data
    dataset = get_dataset('faces')
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

    # Start training
    z_sample = Variable(torch.randn(100, z_dim)).cuda()

    for epoch in range(n_epoch):
        for i, data in enumerate(dataloader):
            imgs = data
            imgs = imgs.cuda()

            bs = imgs.size(0)

            """ 训练D网络 """
            z = Variable(torch.randn(bs, z_dim)).cuda()
            r_imgs = Variable(imgs).cuda()
            f_imgs = G(z)

            # label
            r_label = torch.ones((bs)).cuda()
            f_label = torch.zeros((bs)).cuda()

            r_logit = D(r_imgs.detach())
            f_logit = D(f_imgs.detach())
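            # detach() blocks gradients from flowing back into G while D is being updated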

            # Compute the discriminator loss on real and fake images
            r_loss = criterion(r_logit, r_label)
            f_loss = criterion(f_logit, f_label)
            loss_D = (r_loss + f_loss) / 2

            # Update the discriminator parameters
            D.zero_grad()
            loss_D.backward()
            opt_D.step()

            """ 训练G网络 """
            z = Variable(torch.randn(bs, z_dim)).cuda()
            f_imgs = G(z)
            f_logit = D(f_imgs)

            # Compute the generator loss: G wants D to classify its fakes as real
            loss_G = criterion(f_logit, r_label)

            # Update the generator parameters
            G.zero_grad()
            loss_G.backward()
            opt_G.step()

            # Print training progress
            print(
                f'\rEpoch [{epoch + 1}/{n_epoch}] {i + 1}/{len(dataloader)} Loss_D: {loss_D.item():.4f} Loss_G: {loss_G.item():.4f}',
                end='')
        G.eval()
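        # map the Tanh output from [-1, 1] back to [0, 1] for saving and display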
        f_imgs_sample = (G(z_sample).data + 1) / 2.0
        filename = os.path.join(save_dir, f'Epoch_{epoch + 1:03d}.jpg')
        torchvision.utils.save_image(f_imgs_sample, filename, nrow=10)
        print(f' | Save some samples to {filename}.')
        # Display the generated images
        grid_img = torchvision.utils.make_grid(f_imgs_sample.cpu(), nrow=10)
        plt.figure(figsize=(10, 10))
        plt.imshow(grid_img.permute(1, 2, 0))
        plt.show()
        G.train()
        if (epoch + 1) % 5 == 0:
            torch.save(G.state_dict(), 'dcgan_g.pth')
            torch.save(D.state_dict(), 'dcgan_d.pth')


Generator outputs at intermediate stages of training, after epochs 1 through 10 (images omitted).

  As the theoretical analysis predicts, the generated images become steadily sharper and increasingly resemble the real images in the dataset.

3. Results

generate.py

import torch
from torch import optim
from torch.autograd import Variable
import torchvision
import matplotlib.pyplot as plt
from utils import *

if __name__ == '__main__':
    z_dim = 100
    # Load the trained generator
    G = Generator(z_dim)
    G.load_state_dict(torch.load('dcgan_g.pth'))
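    # eval() makes BatchNorm use its running statistics, giving stable outputs at inference time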
    G.eval()
    G.cuda()
    # Generate images and save them
    n_output = 20
    z_sample = Variable(torch.randn(n_output, z_dim)).cuda()
    imgs_sample = (G(z_sample).data + 1) / 2.0
    save_dir = 'logs'
    filename = os.path.join(save_dir, f'result.jpg')
    torchvision.utils.save_image(imgs_sample, filename, nrow=10)
    # Display the images
    grid_img = torchvision.utils.make_grid(imgs_sample.cpu(), nrow=10)
    plt.figure(figsize=(10, 10))
    plt.imshow(grid_img.permute(1, 2, 0))
    plt.show()

(Generated sample image omitted.)
