深度学习-第G1周:生成对抗网络(GAN)入门

本文详细介绍了生成对抗网络(GAN)的工作原理,包括生成器和判别器的设计,以及如何使用PyTorch实现GAN在MNIST数据集上的训练过程。通过逐层构建神经网络并优化损失函数,最终生成逼真的图像样本。
摘要由CSDN通过智能技术生成

🍨 本文为[🔗365天深度学习训练营]内部限免文章(版权归 *K同学啊* 所有)
🍖 作者:[K同学啊]

 一、前言

生成对抗网络(Generative Adversarial Network, GAN)是一种通过两个神经网络相互博弈的方式进行学习的生成模型。生成对抗网络能够在不使用标注数据的情况下来进行生成任务的学习。生成对抗网络由一个生成器和一个判别器组成。生成器从潜在空间随机取样作为输入,其输出结果需要尽量模仿训练集中的真实样本。判别器的输入则为真实样本或生成器的输出,其目的是将生成器的输出从真实样本中尽可能分别出来。生成器和判别器相互对抗、不断学习,最终目的使得判别器无法判断生成器的输出结果是否真实。

基本框架:生成器+鉴别器

二、前期准备

下载经典minist数据到datasets文件夹,使用经典dataloader = DataLoader(mnist, batch_size= batch_size, shuffle=True)调用批处理数据

# -*- coding:utf-8 -*-
import torch
from torchvision import datasets, transforms
import torch.nn as nn
import time
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F 
import torchsummary as summary
import copy
import os
import argparse
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torch.autograd import Variable

## 创建文件夹
os.makedirs("./images/", exist_ok=True) # 记录训练过程的图片效果
os.makedirs("./save/", exist_ok=True) # 记录训练完成的图片效果
os.makedirs("./datasets/minst", exist_ok=True) # 下载数据集存放位置

# 超参数
n_epochs = 50
batch_size = 64
lr=0.0002
b1=0.5
b2=0.999
n_cpu=2
latent_dim = 100
img_size=28
channels=1
sample_interval=500

## 图像的尺寸(1,28,28),和图像的像素面积784
img_shape=(channels, img_size, img_size)
img_area=np.prod(img_shape)

transforms = transforms.Compose(
        [
        transforms.Resize(img_size),#中心裁剪到224*224
        transforms.ToTensor(),#转化成张量
        transforms.Normalize(0.5, 0.5)
])
## 下载mnist数据集
mnist = datasets.MNIST(root='./datasets/', train=True, download=True, transform = transforms)

# 
dataloader = DataLoader(mnist, batch_size= batch_size, shuffle=True)

三、生成器

生成网络G(𝒛) :生成网络 G 和自编码器的 Decoder 功能类似,从先验分布pzp_zpz​(∙)采样获得潜在空间点向量,经过网络生成图片样本xˉ\bar{x}xˉ~𝑝𝑔(x∣z)𝑝_𝑔(x|z)pg​(x∣z)。

生成器的网络(𝑝𝑔(x∣z)𝑝_𝑔(x|z)pg​(x∣z))可以由深度神经网络来参数化,如:卷积网络和转置卷积网络。下图中从均匀分布𝑝𝒛𝑝𝒛pz(∙)中采样出隐藏变量zzz,经过多层转置卷积层网络参数化的𝑝𝑔(x∣z)𝑝_𝑔(x|z)pg​(x∣z)分布中采样出样本xfx_fxf​,从输入输出层面来看,生成器 G 的功能是将隐向量𝒛通过神经网络转换为样本向量xfx_fxf​,下标𝑓代表假样本(Fake samples)

## 定义生成器 Generator
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers
        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, img_area),
            nn.Tanh()
            )
    def forward(self, z):
        imgs = self.model(z)
        imgs = imgs.view(imgs.size(0), *img_shape)
        return imgs

生成网络 G 负责学习样本的真实分布,不断生成新的假数据

输出一个100维度的高斯分布数据,通过不断线性变换映射到784 维(1 * 28 * 28) 

具体流程如下:

1、线性变换映射100 —> 128 

2、线性变换映射128 —> 256

3、线性变换映射256—> 512

4、线性变换映射512—> 1024

5、线性变换映射1024—> 784

6、784—> Fake samples(1 * 28 * 28) 

最重要的是最后一步,生成假样本(Fake samples)

四、鉴别器

判别网络D(𝒙):判别网络和普通的二分类网络功能类似,网络的输入数据集由采样自真实数据分布p𝑟p_𝑟pr​(∙)的样本x𝑟x_𝑟xr​ ~ 𝑝𝑟𝑝_𝑟pr​(∙)和采样自生成网络的假样本x𝑓x_𝑓xf​ ~ 𝑝𝑔(x∣z)𝑝_𝑔(x|z)pg​(x∣z)组成。判别网络输出为xxx属于真实样本的概率𝑃(xxx为真|xxx),我们把所有真实样本xrx_rxr​的标签标注为真(1),所有生成网络产生的样本,所有生成网络产生的样本xfx_fxf​标注为假(0),通过最小化判别网络 D 的预测值与标签之间的误差来优化判别网络参数。

## 定义判别器 Discriminator 
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(img_area, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
            )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)
        return validity

判别器是另一个神经网络,它的目标是区分真实数据和假数据。具体来说,判别器将输入数据分为两个类别:真实数据和假数据

判别器跟其他神经网络模型接近,通过线性变换,最终用nn.Sigmoid()输出【0,1】的概率

五、训练实例

generator = Generator()
discriminator = Discriminator()

criterion = torch.nn.BCELoss()

optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

##
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
generator.to(device)
discriminator.to(device)
criterion.to(device)

for epoch in range(n_epochs):
    for i, (imgs, _) in enumerate(dataloader):
        
        imgs = imgs.view(imgs.size(0), -1)
        real_img = Variable(imgs).to(device)
        real_label = Variable(torch.ones(imgs.size(0), 1)).to(device)
        fake_label = Variable(torch.zeros(imgs.size(0), 1)).to(device)
    
        ## Train Discriminator
        ## 计算真实图片的损失
        real_out = discriminator(real_img)
        loss_real_D = criterion(real_out, real_label)
        real_scores = real_out
        ## 计算假图片的损失
        z = Variable(torch.randn(imgs.size(0), latent_dim)).to(device)
        fake_img = generator(z).detach()
        fake_out = discriminator(fake_img)
        loss_fake_D =  criterion(fake_out, fake_label)
        fake_scores = fake_out

        ## 损失函数和优化
        loss_D = loss_real_D + loss_fake_D
        optimizer_D.zero_grad()
        loss_D.backward()
        optimizer_D.step()
        
        ## Train Generator
        z = Variable(torch.randn(imgs.size(0), latent_dim)).to(device)
        fake_img = generator(z)
        output = discriminator(fake_img)
        ## 损失函数和优化
        loss_G = criterion(output, real_label)
        optimizer_G.zero_grad()
        loss_G.backward()
        optimizer_G.step()        
    
        if (i+1) % 300 == 0:
            print(
                "[Epoch %d/%d][Batch %d/%d][D loss: %f][G loss: %f][D real: %f][D fake: %f]"
                % (epoch,n_epochs,i,len(dataloader),loss_D.item(),loss_G.item(),real_scores.data.mean(),fake_scores.data.mean(),))
            
        ## 保存训练的图像
        batches_done = epoch * len(dataloader) + i
        if batches_done % sample_interval == 0:
            save_image(fake_img.data[:25], "./images/%d.png" % batches_done, nrow=5, normalize=True)
            
torch.save(generator.state_dict(), './save/generator.pth')   
torch.save(discriminator.state_dict(), './save/discriminator.pth')  

将训练的过程用图片保存在 "./images/%d.png" % batches_done

 训练过程如上图,训练结果逐渐接近真实图

训练46500次之后的训练结果

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值