GAN学习总结三-Pytorch实现利用GAN进行MNIST手写数字生成

16 篇文章 2 订阅
5 篇文章 1 订阅

GAN学习总结三-Pytorch实现利用GAN进行MNIST手写数字生成

前面两篇博客分别介绍了GAN的基本概念和理论推导.理论联系实际,本节从代码的角度理解GAN网络的实现及相关细节,以加深对GAN的理解.

整个实现过程如下:

在这里插入图片描述

导入相关库

import torch
from torch import nn
from torch.autograd import Variable

import torchvision.transforms as tfs
from torch.utils.data import DataLoader, sampler
from torchvision.datasets import MNIST

import numpy as np

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

# IPython/Jupyter "magic" command: render figures inline in the notebook output.
# This line is not valid Python syntax; delete it when running as a plain .py script.
%matplotlib inline
plt.rcParams['figure.figsize'] = (10.0, 8.0) # set the default figure size
plt.rcParams['image.interpolation'] = 'nearest'  # draw image pixels without smoothing
plt.rcParams['image.cmap'] = 'gray'  # default colormap for images: grayscale

def show_images(images):
    """Draw a batch of square images on a near-square grid of subplots.

    ``images`` is first flattened to shape ``(N, -1)``. Each row is then
    reshaped to a ``side x side`` picture, where ``side`` is the ceiling of
    the square root of the row length. Each picture is drawn into its own
    cell of a ``ceil(sqrt(N)) x ceil(sqrt(N))`` grid on a newly created
    figure, with axes and tick labels hidden.

    Returns ``None``. It does not call ``plt.show()``; that is left to the
    caller.
    """
    flat = np.reshape(images, [images.shape[0], -1])
    num_images, num_pixels = flat.shape
    grid_side = int(np.ceil(np.sqrt(num_images)))
    img_side = int(np.ceil(np.sqrt(num_pixels)))

    plt.figure(figsize=(grid_side, grid_side))
    grid = gridspec.GridSpec(grid_side, grid_side)
    grid.update(wspace=0.05, hspace=0.05)  # tight spacing between cells

    for idx, pixels in enumerate(flat):
        axis = plt.subplot(grid[idx])
        plt.axis('off')
        axis.set_xticklabels([])
        axis.set_yticklabels([])
        axis.set_aspect('equal')
        plt.imshow(pixels.reshape([img_side, img_side]))

def preprocess_img(x):
    """Convert an image to a tensor and rescale it with ``(t - 0.5) / 0.5``.

    Per the torchvision docs, ``ToTensor`` yields values in [0, 1] for PIL
    images and uint8 arrays. For such inputs the result therefore lies in
    [-1, 1], which is the range of the generator's final ``Tanh``.
    """
    tensor = tfs.ToTensor()(x)
    return tensor.sub(0.5).div(0.5)

def deprocess_img(x):
    """Map values from the model range [-1, 1] back to [0, 1] for display.

    This is the inverse of the ``(x - 0.5) / 0.5`` normalisation applied when
    images are loaded. It works element-wise on anything supporting ``+``
    and ``/`` with floats: Python floats, NumPy arrays and torch tensors.
    """
    shifted = x + 1.0
    return shifted / 2.0
class ChunkSampler(sampler.Sampler):
    """Sampler yielding a contiguous, unshuffled run of indices from an offset.

    Iteration produces ``start, start + 1, ..., start + num_samples - 1`` in
    order. This lets several DataLoaders over the same dataset read
    disjoint slices of it, e.g. a training chunk and a validation chunk.

    Args:
        num_samples: number of consecutive indices to yield.
        start: first index to yield (default 0).
    """

    def __init__(self, num_samples, start=0):
        self.start = start
        self.num_samples = num_samples

    def __iter__(self):
        first, count = self.start, self.num_samples
        return iter(range(first, first + count))

    def __len__(self):
        return self.num_samples

NUM_TRAIN = 50000  # number of images in the training slice
NUM_VAL = 5000     # number of images in the validation slice that follows it

NOISE_DIM = 96     # default length of the generator's input noise vector
batch_size = 128

# Training and validation both read the MNIST *training* split (train=True).
# ChunkSampler hands each loader a disjoint, sequential index range:
# [0, NUM_TRAIN) for training and [NUM_TRAIN, NUM_TRAIN + NUM_VAL) for validation.
train_set = MNIST('mnist', train=True, download=True, transform=preprocess_img)
train_data = DataLoader(train_set, batch_size=batch_size, sampler=ChunkSampler(NUM_TRAIN, 0))

# Same split and transform as train_set, so reuse that object instead of
# loading the dataset a second time.
val_set = train_set
val_data = DataLoader(val_set, batch_size=batch_size, sampler=ChunkSampler(NUM_VAL, NUM_TRAIN))

# Preview one batch of (de-normalised) training images.
# Use the builtin next(): recent PyTorch DataLoader iterators no longer provide
# the Python 2 style `.next()` method the original code called.
sample_batch, _ = next(iter(train_data))
imgs = deprocess_img(sample_batch.view(sample_batch.shape[0], -1)).numpy().squeeze()
show_images(imgs)


在这里插入图片描述

定义卷积判别网络

class build_dc_classifier(nn.Module):
    """Convolutional discriminator mapping a 1-channel image to one logit per sample.

    The network has two feature stages, ``Conv2d(5x5, stride 1)`` ->
    ``LeakyReLU(0.01)`` -> ``MaxPool2d(2, 2)``, going from 1 to 32 to 64
    channels. They are followed by a fully connected head
    ``1024 -> 1024 -> 1``.

    The head's 1024 input features equal 64 * 4 * 4, the flattened conv
    output for a 28x28 image. The output is a raw logit with no sigmoid
    applied.
    """

    def __init__(self):
        super().__init__()
        feature_layers = [
            nn.Conv2d(1, 32, 5, 1),
            nn.LeakyReLU(0.01),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, 5, 1),
            nn.LeakyReLU(0.01),
            nn.MaxPool2d(2, 2),
        ]
        self.conv = nn.Sequential(*feature_layers)
        self.fc = nn.Sequential(
            nn.Linear(1024, 1024),
            nn.LeakyReLU(0.01),
            nn.Linear(1024, 1),
        )

    def forward(self, x):
        features = self.conv(x)
        batch = features.size(0)
        return self.fc(features.view(batch, -1))

定义卷积生成网络

class build_dc_generator(nn.Module):
    """Convolutional generator mapping a noise vector to a 1x28x28 image.

    An MLP, ``noise_dim -> 1024 -> 7*7*128`` with each Linear followed by
    ReLU and BatchNorm1d, produces a 128-channel 7x7 feature map. Two
    stride-2 transposed convolutions then upsample it:

    - first to 64 channels at 14x14;
    - then to 1 channel at 28x28.

    The final ``Tanh`` bounds pixel values to [-1, 1].

    Args:
        noise_dim: length of each input noise vector (defaults to NOISE_DIM).
    """

    def __init__(self, noise_dim=NOISE_DIM):
        super().__init__()
        hidden = 1024
        flat_map = 7 * 7 * 128
        self.fc = nn.Sequential(
            nn.Linear(noise_dim, hidden),
            nn.ReLU(True),
            nn.BatchNorm1d(hidden),
            nn.Linear(hidden, flat_map),
            nn.ReLU(True),
            nn.BatchNorm1d(flat_map),
        )
        self.conv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, padding=1),  # 7x7 -> 14x14
            nn.ReLU(True),
            nn.BatchNorm2d(64),
            nn.ConvTranspose2d(64, 1, 4, 2, padding=1),    # 14x14 -> 28x28
            nn.Tanh(),
        )

    def forward(self, x):
        hidden = self.fc(x)
        feature_map = hidden.view(hidden.size(0), 128, 7, 7)  # 128 channels, 7x7
        return self.conv(feature_map)

定义损失函数

判别网络的目标是最大化下式(代码中等价地最小化它的相反数,即把真实样本标签设为 1、生成样本标签设为 0 的二元交叉熵损失):

$$\ell_D = \mathbb{E}_{x \sim p_\text{data}}\left[\log D(x)\right] + \mathbb{E}_{z \sim p(z)}\left[\log \left(1-D(G(z))\right)\right]$$

生成网络的目标是最大化下式(代码中最小化它的相反数,即把生成样本标签设为 1 的二元交叉熵损失).这是常用的非饱和形式,而不是直接最小化 $\log(1-D(G(z)))$:

$$\ell_G = \mathbb{E}_{z \sim p(z)}\left[\log D(G(z))\right]$$

bce_loss = nn.BCEWithLogitsLoss()  # sigmoid + binary cross-entropy on raw logits, mean-reduced


def discriminator_loss(logits_real, logits_fake): # discriminator loss
    """Binary cross-entropy loss for the discriminator.

    Real samples are labelled 1 and generated samples 0. The loss is
    ``bce_loss(logits_real, 1) + bce_loss(logits_fake, 0)``, each term
    averaged over its own batch. Both arguments are raw logits, not
    probabilities.

    Labels are built with ``ones_like``/``zeros_like`` so they always match
    each input's shape, dtype and device. The loss therefore works on CPU as
    well as GPU, and the real and fake batches may differ in size.
    """
    true_labels = torch.ones_like(logits_real)
    false_labels = torch.zeros_like(logits_fake)
    return bce_loss(logits_real, true_labels) + bce_loss(logits_fake, false_labels)


def generator_loss(logits_fake): # generator loss
    """Non-saturating generator loss, ``bce_loss(logits_fake, 1)``.

    Labelling generated samples as "real" makes minimizing this loss the same
    as maximizing the mean of ``log D(G(z))``. The labels follow the input's
    shape, dtype and device.
    """
    true_labels = torch.ones_like(logits_fake)
    return bce_loss(logits_fake, true_labels)

定义优化器

# Train with Adam: learning rate 3e-4, beta1 = 0.5, beta2 = 0.999 by default.
def get_optimizer(net, *, lr=3e-4, betas=(0.5, 0.999)):
    """Create an Adam optimizer over all parameters of ``net``.

    Args:
        net: module whose ``parameters()`` will be optimized.
        lr: learning rate (default 3e-4).
        betas: Adam ``(beta1, beta2)`` coefficients. The default is
            ``(0.5, 0.999)``, i.e. beta1 is lower than PyTorch's Adam
            default of 0.9.

    Returns:
        A ``torch.optim.Adam`` instance.
    """
    return torch.optim.Adam(net.parameters(), lr=lr, betas=betas)

定义训练函数

def train_dc_gan(D_net, G_net, D_optimizer, G_optimizer, discriminator_loss, generator_loss, show_every=250,
                noise_size=96, num_epochs=10, *, train_loader=None):
    """Alternately update a discriminator and a generator, one step each per batch.

    For every batch of real images:

    1. Discriminator step: score the real batch and a batch generated from
       uniform noise in [-1, 1), then minimize ``discriminator_loss``.
    2. Generator step: run the generator again on the same noise, score the
       result with the just-updated discriminator, and minimize
       ``generator_loss``.

    Every ``show_every`` iterations, starting at iteration 0, the current
    losses are printed and the first 16 generated images are plotted.

    Args:
        D_net, G_net: discriminator and generator modules.
        D_optimizer, G_optimizer: optimizers over D_net's / G_net's parameters.
        discriminator_loss: callable ``(logits_real, logits_fake) -> scalar loss tensor``.
        generator_loss: callable ``(logits_fake) -> scalar loss tensor``.
        show_every: logging/plotting period, in iterations.
        noise_size: length of each noise vector fed to ``G_net``.
        num_epochs: number of passes over the training data.
        train_loader: iterable of ``(images, labels)`` batches. Defaults to the
            module-level ``train_data`` loader.

    Input tensors are moved to the device holding ``D_net``'s parameters,
    rather than being hard-wired to CUDA.
    """
    loader = train_data if train_loader is None else train_loader
    device = next(D_net.parameters()).device
    iter_count = 0
    for epoch in range(num_epochs):
        for x, _ in loader:
            bs = x.shape[0]

            # --- Discriminator step ---
            real_data = x.to(device)  # real data
            logits_real = D_net(real_data)  # discriminator scores for real data

            # Uniform noise in [-1, 1).
            sample_noise = ((torch.rand(bs, noise_size) - 0.5) / 0.5).to(device)
            fake_images = G_net(sample_noise)  # generated (fake) data
            # detach(): the discriminator loss must not backpropagate into the
            # generator. Those gradients would be wasted work; they are zeroed
            # before the generator step anyway.
            logits_fake = D_net(fake_images.detach())

            d_total_error = discriminator_loss(logits_real, logits_fake)
            D_optimizer.zero_grad()
            d_total_error.backward()
            D_optimizer.step()  # update the discriminator

            # --- Generator step ---
            # Generate again from the same noise and score it with the updated discriminator.
            fake_images = G_net(sample_noise)
            gen_logits_fake = D_net(fake_images)
            g_error = generator_loss(gen_logits_fake)
            G_optimizer.zero_grad()
            g_error.backward()
            G_optimizer.step()  # update the generator

            if iter_count % show_every == 0:
                # .item(): indexing a 0-dim loss tensor with [0] raises in modern PyTorch.
                print('Iter: {}, D: {:.4}, G:{:.4}'.format(iter_count, d_total_error.item(), g_error.item()))
                imgs_numpy = deprocess_img(fake_images.detach().cpu().numpy())
                show_images(imgs_numpy[0:16])
                plt.show()
                print()
            iter_count += 1

开始训练:

# Build the discriminator/generator pair on the GPU, give each its own Adam
# optimizer, and train for 20 epochs.
D_DC, G_DC = build_dc_classifier().cuda(), build_dc_generator().cuda()
D_DC_optim, G_DC_optim = get_optimizer(D_DC), get_optimizer(G_DC)

train_dc_gan(
    D_DC, G_DC, D_DC_optim, G_DC_optim,
    discriminator_loss, generator_loss,
    num_epochs=20,
)

训练过程中生成结果如下,刚开始图像模糊,后面图像越来越清晰:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

参考:

https://github.com/L1aoXingyu/code-of-learn-deep-learning-with-pytorch

  • 14
    点赞
  • 77
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
以下是一个基于GAN的MNIST手写数字生成的PyTorch代码示例:

```python
import torch
import torch.nn as nn
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader

# 定义生成器
class Generator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Generator, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, output_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# 定义判别器
class Discriminator(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.sigmoid(self.fc1(x))
        x = self.sigmoid(self.fc2(x))
        x = self.sigmoid(self.fc3(x))
        return x

# 定义超参数
input_size = 100
hidden_size = 256
output_size = 784
batch_size = 128
num_epochs = 200

# 加载MNIST数据集
train_dataset = MNIST(root='./data', train=True, transform=ToTensor(), download=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# 初始化网络
G = Generator(input_size, hidden_size, output_size)
D = Discriminator(output_size, hidden_size)

# 定义损失函数和优化器
criterion = nn.BCELoss()
lr = 0.0002
G_optimizer = torch.optim.Adam(G.parameters(), lr=lr)
D_optimizer = torch.optim.Adam(D.parameters(), lr=lr)

# 定义真实和假的标签
real_label = torch.ones(batch_size, 1)
fake_label = torch.zeros(batch_size, 1)

# 训练网络
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(train_loader):
        # 定义真实和假的图像
        real_images = images.view(batch_size, -1)
        z = torch.randn(batch_size, input_size)
        fake_images = G(z)

        # 训练判别器
        D_real_loss = criterion(D(real_images), real_label)
        D_fake_loss = criterion(D(fake_images.detach()), fake_label)
        D_loss = D_real_loss + D_fake_loss
        D_optimizer.zero_grad()
        D_loss.backward()
        D_optimizer.step()

        # 训练生成器
        G_loss = criterion(D(fake_images), real_label)
        G_optimizer.zero_grad()
        G_loss.backward()
        G_optimizer.step()

        # 打印损失
        if (i+1) % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], D Loss: {:.4f}, G Loss: {:.4f}'
                  .format(epoch+1, num_epochs, i+1, len(train_loader), D_loss.item(), G_loss.item()))

# 保存模型
torch.save(G.state_dict(), 'generator.pth')
```

在训练完成后,可以使用生成器来生成新的手写数字图像,例如:

```python
import matplotlib.pyplot as plt
import numpy as np

# 加载生成器
G = Generator(input_size, hidden_size, output_size)
G.load_state_dict(torch.load('generator.pth'))

# 生成图像
z = torch.randn(1, input_size)
fake_image = G(z).detach().numpy()
fake_image = np.reshape(fake_image, (28, 28))

# 显示图像
plt.imshow(fake_image, cmap='gray')
plt.show()
```

这样就可以生成一个随机的手写数字图像了。

注意:MNIST 训练集共 60000 张图片,不能被 batch_size=128 整除,最后一个 batch 只有 96 张.此时 `images.view(batch_size, -1)` 得到的形状不再是 (128, 784),判别器前向时会因维度不匹配而报错,固定大小的 `real_label`/`fake_label` 也对不上.可以给 DataLoader 传入 `drop_last=True`,或用 `images.size(0)` 作为当前 batch 大小并据此构造标签。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值