GAN简介与复习 基于pytorch

本文介绍了GAN的基本原理,并通过PyTorch实现了一个基于MNIST数据集的例子。在训练过程中,Generator和Discriminator相互对抗,分别试图生成逼真数据和识别真假数据。此外,还提及了使用名人照片数据集进行GAN训练的经验,虽然数据量较大,但最终效果令人满意。
摘要由CSDN通过智能技术生成

概要

GAN(Generative Adversarial Network)生成对抗网络。其实要理解GAN的构想逻辑并不难,像其他的一些模型比如说最最基础的nn.Linear() + nn.ReLU(),或者是RNN模型,我们不妨把这个模型看成一位武侠,他的目的是要跟江湖上尽可能多的人(data)过招(train),目的是在未来遇到邪恶的坏蛋(真实情景应用)时能够一招制敌(给出正确的结果)。
但是天不遂人愿,在茫茫的人海中,真正的武林高手有几个?又有几个能被我遇到?今天打过了丐帮的降龙十八掌,明天谁知道会不会被一记九阳神功拍的头昏眼花?(能接触到的数据总是有限的)武侠仰天沉思,他想起那一年去西域,自己仗着在中原打遍天下无敌手(过拟合)四处张扬得不行,结果被一旁的扫地大爷一套西洋拳术带走(模型不适应其他数据)。
可是家中有老母要照顾,忠孝难两全,武侠也因此一直呆在中原。由于放眼神州已无敌手,便打起了木人桩。机会总是留给有准备的人,有一天武侠捡到了阿拉丁神灯,神灯答应了他的愿望,点化了他的木人桩,让他能主动与武侠打斗并且不断增强自己的武力值,直击武侠痛点。武侠大喜,从此开始了与被点化的木人桩的切磋之路,技艺日增,终成一代地球大侠。

这个木人桩和武侠就是GAN中的Generator和Discriminator。对于Discriminator而言,它的目标是分辨出真假数据,对于Generator而言,它的目标是要制造出能以假乱真的数据。在学习的过程中,Generator的输入我们用torch.randn产生随机数据,以此希望通过Generator产生各种各样的输入。
简单地说,二者的目标总结为:

  • Discriminator: 给定数据x,我希望分辨出这是真实产生的数据,还是Generator模拟的假数据,输出0-1

  • Generator: 给定随机数random,我希望能蒙混过关,尽可能模拟真实数据 output.shape == x.shape

二者在训练的过程中我们应该可以看到两边的loss大致是一个此消彼长的关系,这也是GAN中A(Adversarial)的本意,两者对抗。

一个例子 (base on MNIST)

我用在暑假跟着学深度学习中一课时的代码复现给大家分享一下。MNIST数据集是一个图片集,都是手写的单个数字,有images和labels两个部分。用torch.utils.data.Dataset或者torchvision.Datasets.MNIST可以读入为dataset实例,进一步构造dataloader。废话不多说,上代码。

# 一些经常用的库和函数
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision 
from torch.utils.data import DataLoader
from torchvision import transforms
import matplotlib.pyplot as plt
# 定义一些超参数
device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch_size = 64
latent_size = 64	# 就是上文中生成的random的长度
image_size = 28*28  # 这是MNIST数据集中图片的大小
hidden_size = 256  # 定义Discriminator和Generator模型中的隐层的大小
output_size = 10  # 最终输出为10维的向量,代表1~
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值