MNIST Handwritten Digit Generation with a GAN

This post describes a GAN implemented in PyTorch for generating 28x28 handwritten digit images. The network has two parts, a Generator and a Discriminator: the Generator produces images from normally distributed noise, and the Discriminator distinguishes real images from generated ones. A complete, runnable code example is given, showing the path from random noise to generated images, together with generation results at different stages of training.

(Reproduced following the Bilibili tutorial by 日月光华.)

Principle:

This simple handwritten-digit generation network illustrates the basic idea of a GAN. It has two parts, a generator and a discriminator. The generator takes normally distributed noise as input and outputs a 28x28 image; the discriminator takes 28x28 images as input, both real images and images produced by the generator, and outputs a probability of the image being real. The adversarial part lies in the training objectives: the generator tries to get its outputs classified as real by the discriminator, while the discriminator tries to classify generated images as fake and real images as real.
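
In standard GAN notation, this adversarial game is the minimax objective

    \min_G \max_D \; V(D, G) = \mathbb{E}_{x \sim p_{\mathrm{data}}}\left[\log D(x)\right] + \mathbb{E}_{z \sim \mathcal{N}(0, I)}\left[\log\left(1 - D(G(z))\right)\right]

In the code below, the generator is trained with the common non-saturating variant: rather than minimizing \log(1 - D(G(z))), it is optimized against the "real" label with BCELoss, which amounts to maximizing \log D(G(z)).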

Interested readers can extend this network by:

  • Adding a variable learning rate (a scheduler sketch follows this list)

  • Adding convolutional layers

  • Increasing the network depth
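
As a sketch of the first suggestion: the fixed learning rate of 0.0001 used below can be made variable by attaching a scheduler to each optimizer. This snippet assumes the g_optim, d_optim, dataloader and epoch_size objects defined in the script that follows; StepLR with these step_size/gamma values is just one possible choice.

from torch.optim import lr_scheduler

g_sched = lr_scheduler.StepLR(g_optim, step_size=50, gamma=0.5)  # halve G's learning rate every 50 epochs
d_sched = lr_scheduler.StepLR(d_optim, step_size=50, gamma=0.5)  # halve D's learning rate every 50 epochs

for epoch in range(epoch_size):
    for step, (img, _) in enumerate(dataloader):
        ...  # generator and discriminator updates, exactly as in the script below
    g_sched.step()  # advance both schedules once per epoch
    d_sched.step()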

Without further ado, here is the code (runnable as-is):


import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms


# Visualize a batch of generated images for a fixed set of noise vectors.
def draw_genImg(model, noise):
    # generator output is (N, 28, 28, 1); squeeze drops the trailing channel dim
    pred = np.squeeze(model(noise).detach().cpu().numpy())
    size = noise.shape[0]
    for i in range(size):
        plt.subplot(4, int(size / 4), i + 1)
        plt.imshow((pred[i] + 1) / 2)  # map Tanh output from [-1, 1] to [0, 1]
    plt.show()

# Generator: maps a 100-dim noise vector to a 28x28 image in [-1, 1] (Tanh output)
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(100, 256),
            nn.LeakyReLU(),
            nn.Linear(256, 512),
            nn.LeakyReLU(),
            nn.Linear(512, 28 * 28),
            nn.Tanh()
        )
    def forward(self, x):
        x = self.main(x)
        img = x.view(-1, 28, 28, 1)  # reshape flat output to image form (N, 28, 28, 1)
        return img

# Discriminator: maps a 28x28 image to the probability that it is real (Sigmoid output)
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.LeakyReLU(),
            nn.Linear(512, 256),
            nn.LeakyReLU(),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
    def forward(self, x):
        x = x.view(-1, 28 * 28)  # flatten; works for both real (N, 1, 28, 28) and generated (N, 28, 28, 1) batches
        conf = self.main(x)
        return conf

if __name__=="__main__":

    batch_size = 64
    epoch_size = 200
    pred_size  = 16   # number of samples to visualize
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # fixed noise, reused so generated samples are comparable across epochs
    test_input = torch.randn(pred_size, 100, device=device)

    # data: ToTensor scales pixels to [0, 1], Normalize(mean=0.5, std=0.5) maps them to [-1, 1]
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(0.5, 0.5),
    ])

    train_ds   = torchvision.datasets.MNIST('data', train=True, transform=transform, download=True)  # downloads to ./data
    dataloader = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True)

    
    gen  = Generator().to(device)
    dis  = Discriminator().to(device)
   
    g_optim = torch.optim.Adam(gen.parameters(), lr = 0.0001)
    d_optim = torch.optim.Adam(dis.parameters(), lr = 0.0001)
    loss_fn = nn.BCELoss()

    D_loss =  []
    G_loss =  []

    for epoch in range(epoch_size):
        d_epoch_loss = 0
        g_epoch_loss = 0
        count = len(dataloader)
        
        for step,(img,_) in enumerate(dataloader):
            img = img.to(device)
            size= img.size(0)
            random_noise  = torch.randn(size, 100, device = device)
            
            # update the generator: push D(G(z)) toward 1, i.e. get fakes classified as real
            g_optim.zero_grad()

            fake_out    =  dis(gen(random_noise))
            g_loss      = loss_fn(fake_out, torch.ones_like(fake_out))
            g_loss.backward()
            g_optim.step()


            # update the discriminator: real images -> 1, generated images -> 0
            d_optim.zero_grad()

            real_out    = dis(img)
            d_real_loss = loss_fn(real_out, torch.ones_like(real_out))
            d_real_loss.backward()

            # detach so this backward pass does not touch the generator's parameters
            fake_out    = dis(gen(random_noise).detach())
            d_fake_loss = loss_fn(fake_out, torch.zeros_like(fake_out))
            d_fake_loss.backward()

            d_loss = d_real_loss + d_fake_loss
            d_optim.step()

            # accumulate running losses for this epoch (as plain floats)
            g_epoch_loss += g_loss.item()
            d_epoch_loss += d_loss.item()
    
        # per-epoch statistics: average losses and, near the end of training, sample images
        g_epoch_loss /= count
        d_epoch_loss /= count
        G_loss.append(g_epoch_loss)
        D_loss.append(d_epoch_loss)
        print('epoch:', epoch, 'g_epoch_loss:', g_epoch_loss, 'd_epoch_loss:', d_epoch_loss)

        if epoch > epoch_size - 5:   # show generated samples only for the last few epochs
            draw_genImg(gen, test_input)
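
The loop above records the per-epoch losses in G_loss and D_loss but never displays them. A minimal follow-up, placed at the end of the if __name__=="__main__": block (it only reuses plt, G_loss and D_loss from the script above), plots the two curves once training finishes:

    # optional: visualize how the generator and discriminator losses evolved over training
    plt.figure()
    plt.plot(G_loss, label='generator loss')
    plt.plot(D_loss, label='discriminator loss')
    plt.xlabel('epoch')
    plt.ylabel('BCE loss')
    plt.legend()
    plt.show()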
        


                


The figures above show the generated results after epoch 1, epoch 30, and epoch 200, respectively.
