GAN(生成对抗网络)学习——pytorch实现MINIST手写数据集生成

一、生成对抗网络原理

生成对抗网络,是一种基于博弈思想的网络训练思路,其主要网络模块由两部分组成,分别为generator(生成器)和discriminator(判别器)。
我们以GAM生成Minist手写数据集为例,在这个例子中,我们的目的是为了生成可以以假乱真的手写数字图片。而我们的训练思路,是使用生成器来产生一张照片,并且由判别器来判断这张照片是否是真实的照片。
在这个过程中,生成器会根据判别器返回的结果,一步一步的学习,生成器学习的目的是产生可以骗过判别器的照片,而判别器在这个过程中也会不断的进行学习,其学习的目的是,可以正确的分别出假图片和真图片

二、GAN网络结构详解

  • 生成器定义
    生成器的输入是一组噪声,网络中包含着全连接层和激活函数,最终会生成一张大小为28x28(784)的图片。
#生成网络
def generator(noise_dim=NOISE_DIM):
    net = nn.Sequential(
        nn.Linear(noise_dim, 1024),
        nn.ReLU(True),
        nn.Linear(1024, 1024),
        nn.ReLU(True),
        nn.Linear(1024, 784), #最终输出大小为784
        nn.Tanh()
    )
    return net
  • 判别器定义
    判别器的输入是一张28x28的图片,不过在输入到判别器网络之前,会先将其展开成28x28 = 784的一维向量。
    这个一维向量,可以是真实的图片(从MINIST数据集取出的图片),也可以是生成器生成的图片。
# 判别网络
def discriminator():
    net = nn.Sequential(
        nn.Linear(784, 256),
        nn.LeakyReLU(0.2),
        nn.Linear(256, 256),
        nn.LeakyReLU(0.2),
        nn.Linear(256, 1) #输出结果为置信度
    )
    return net

三、生成器与判别器训练思路

按照生成对抗网络的基本原理,我们对生成器和判别器的损失函数进行定义。

  • 判别器损失函数
    判别器的损失函数分为两部分
    第一步是要能准确的将正确的图片识别为正确
    第二步是要能准确的将错误的图片识别为错误
    因此,需要为真正的图片生成的图片分别生成标签
    正确的图片标签为1,生成的图片标签为0
    因此生成器的损失函数即为两部分的损失函数的和
    在这里损失计算函数采用交叉熵的计算方式
    即只需要关注是否分类正确来进行损失计算,而忽略分类错误的损失部分
	real_data = Variable(x).view(bs, -1)  # 真实数据
	logits_real = D_net(real_data)  # 判别网络得分

	sample_noise = (torch.rand(bs, noise_size) - 0.5) / 0.5  # -1 ~ 1 的均匀分布
	g_fake_seed = Variable(sample_noise)
	fake_images = G_net(g_fake_seed)  # 生成的假的数据
	logits_fake = D_net(fake_images)  # 判别网络得分

	d_total_error = discriminator_loss(logits_real, logits_fake)  # 判别器的 loss
def discriminator_loss(logits_real, logits_fake):  # 判别器的 loss
    size = logits_real.shape[0]
    true_labels = Variable(torch.ones(size, 1)).float()
    false_labels = Variable(torch.zeros(size, 1)).float()
    loss = bce_loss(logits_real, true_labels) + bce_loss(logits_fake, false_labels)
    return loss

  • 生成器损失函数
    与判别器不同,生成器的目的是要生成出可以骗过判别器的假图片
    因此生成器所生成的图片在经过判别器判别后的结果,需要最大限度的接近真实
    因此需要将判别后的结果,与正确的标签进行损失函数计算
    这里损失函数也是采用交叉熵的方式
 # 生成网络
  g_fake_seed = Variable(sample_noise)
  fake_images = G_net(g_fake_seed)  # 生成的假的数据

  gen_logits_fake = D_net(fake_images)
  g_error = generator_loss(gen_logits_fake)  # 生成网络的 loss
def generator_loss(logits_fake):  # 生成器的 loss
    size = logits_fake.shape[0]
    true_labels = Variable(torch.ones(size, 1)).float()
    loss = bce_loss(logits_fake, true_labels)
    return loss

四、实现代码

import torch
from torch import nn
from torch.autograd import Variable

import torchvision.transforms as tfs
from torch.utils.data import DataLoader, sampler
from torchvision.datasets import MNIST

import numpy as np

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
plt.rcParams['figure.figsize'] = (10.0, 8.0)  # 设置画图的尺寸
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'


def show_images(images):  # 定义画图工具
    images = np.reshape(images, [images.shape[0], -1])
    sqrtn = int(np.ceil(np.sqrt(images.shape[0])))
    sqrtimg = int(np.ceil(np.sqrt(images.shape[1])))

    fig = plt.figure(figsize=(sqrtn, sqrtn))
    gs = gridspec.GridSpec(sqrtn, sqrtn)
    gs.update(wspace=0.05, hspace=0.05)

    for i, img in enumerate(images):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(img.reshape([sqrtimg, sqrtimg]))
    return


def preprocess_img(x):
    x = tfs.ToTensor()(x)
    return (x - 0.5) / 0.5


def deprocess_img(x):
    return (x + 1.0) / 2.0


class ChunkSampler(sampler.Sampler):  # 定义一个取样的函数
    """Samples elements sequentially from some offset.
    Arguments:
      num_samples: # of desired datapoints
      start: offset where we should start selecting from
    """

    def __init__(self, num_samples, start=0):
        self.num_samples = num_samples
        self.start = start

    def __iter__(self):
        return iter(range(self.start, self.start + self.num_samples))

    def __len__(self):
        return self.num_samples


NUM_TRAIN = 50000
NUM_VAL = 5000

NOISE_DIM = 96
batch_size = 128

train_set = MNIST('./data', train=True, transform=preprocess_img,download=True)

train_data = DataLoader(train_set, batch_size=batch_size, sampler=ChunkSampler(NUM_TRAIN, 0))

val_set = MNIST('./data', train=True, transform=preprocess_img,download=True)

val_data = DataLoader(val_set, batch_size=batch_size, sampler=ChunkSampler(NUM_VAL, NUM_TRAIN))

imgs = deprocess_img(train_data.__iter__().next()[0].view(batch_size, 784)).numpy().squeeze()  # 可视化图片效果
show_images(imgs)


# 判别网络
def discriminator():
    net = nn.Sequential(
        nn.Linear(784, 256),
        nn.LeakyReLU(0.2),
        nn.Linear(256, 256),
        nn.LeakyReLU(0.2),
        nn.Linear(256, 1)
    )
    return net


# 生成网络
def generator(noise_dim=NOISE_DIM):
    net = nn.Sequential(
        nn.Linear(noise_dim, 1024),
        nn.ReLU(True),
        nn.Linear(1024, 1024),
        nn.ReLU(True),
        nn.Linear(1024, 784),
        nn.Tanh()
    )
    return net


# 判别器的 loss 就是将真实数据的得分判断为 1,假的数据的得分判断为 0,而生成器的 loss 就是将假的数据判断为 1

bce_loss = nn.BCEWithLogitsLoss()  # 交叉熵损失函数


def discriminator_loss(logits_real, logits_fake):  # 判别器的 loss
    size = logits_real.shape[0]
    true_labels = Variable(torch.ones(size, 1)).float()
    false_labels = Variable(torch.zeros(size, 1)).float()
    loss = bce_loss(logits_real, true_labels) + bce_loss(logits_fake, false_labels)
    return loss


def generator_loss(logits_fake):  # 生成器的 loss
    size = logits_fake.shape[0]
    true_labels = Variable(torch.ones(size, 1)).float()
    loss = bce_loss(logits_fake, true_labels)
    return loss


# 使用 adam 来进行训练,学习率是 3e-4, beta1 是 0.5, beta2 是 0.999
def get_optimizer(net):
    optimizer = torch.optim.Adam(net.parameters(), lr=3e-4, betas=(0.5, 0.999))
    return optimizer


def train_a_gan(D_net, G_net, D_optimizer, G_optimizer, discriminator_loss, generator_loss, show_every=250,
                noise_size=96, num_epochs=10):
    iter_count = 0
    for epoch in range(num_epochs):
        for x, _ in train_data:
            bs = x.shape[0]
            # 判别网络
            real_data = Variable(x).view(bs, -1)  # 真实数据
            logits_real = D_net(real_data)  # 判别网络得分

            sample_noise = (torch.rand(bs, noise_size) - 0.5) / 0.5  # -1 ~ 1 的均匀分布
            g_fake_seed = Variable(sample_noise)
            fake_images = G_net(g_fake_seed)  # 生成的假的数据
            logits_fake = D_net(fake_images)  # 判别网络得分

            d_total_error = discriminator_loss(logits_real, logits_fake)  # 判别器的 loss
            D_optimizer.zero_grad()
            d_total_error.backward()
            D_optimizer.step()  # 优化判别网络

            # 生成网络
            g_fake_seed = Variable(sample_noise)
            fake_images = G_net(g_fake_seed)  # 生成的假的数据

            gen_logits_fake = D_net(fake_images)
            g_error = generator_loss(gen_logits_fake)  # 生成网络的 loss
            G_optimizer.zero_grad()
            g_error.backward()
            G_optimizer.step()  # 优化生成网络

            if (iter_count % show_every == 0):
                print('Iter: {}, D: {:.4}, G:{:.4}'.format(iter_count, d_total_error.item(), g_error.item()))
                imgs_numpy = deprocess_img(fake_images.data.cpu().numpy())
                show_images(imgs_numpy[0:16])
                plt.show()
                print()
            iter_count += 1


D = discriminator()
G = generator()

D_optim = get_optimizer(D)
G_optim = get_optimizer(G)

train_a_gan(D, G, D_optim, G_optim, discriminator_loss, generator_loss)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值