动手学pytorch(3)用GAN来生成手写数字

一、前言

主要参考:

日月光华对GAN的讲解和演示
https://www.bilibili.com/video/BV1xm4y1X7KZ/?p=8&spm_id_from=pageDriver&vd_source=7d9525fad541a2d64d7edb7ee9f5fefa

实验环境:

pytorch 1.10.2pytorch 1.10.2+pycharm

二、代码

#自己写一遍gan网络

#用pytorch实现,数据集使用手写数字集
import torch
import torch.nn as nn#神经网络工具箱

#torchvision
from torchvision import transforms#数据预处理
from torchvision import datasets#mnist数据集

import numpy as np#
import matplotlib.pyplot as plt#绘图包

#确定超参数
batch_size=64
epochs=100
#1.首先加载数据集

#数据集预处理
#对数据集做归一化(-1,1)
img_transforms=transforms.Compose([
    transforms.ToTensor(),#转换成tensor格式
    transforms.Normalize(0.5,0.5)
])

#下载数据集
mnist=datasets.MNIST(root='./data/',train=True,transform=img_transforms,download=False)

#数据加载
dataLoader=torch.utils.data.DataLoader(mnist,batch_size=batch_size,shuffle=True)

#2.定义生成器
#搞明白输入和输出分布是什么,输入是100维的噪声(服从正态分布),输出是(1,28,28)的图片

class Generator(nn.Module):#继承
    def __init__(self):#构造器
        super(Generator,self).__init__()#super()调用父类

        #生成器的网络
        self.net=nn.Sequential(#展平
            nn.Linear(100,256),
            nn.ReLU(),
            nn.Linear(256,512),
            nn.ReLU(),
            nn.Linear(512,28*28),
            nn.Tanh()#输出范围[-1,1]
        )
    def forward(self,z):#z为长度一百的噪声输入
        img=self.net(z)
        img=img.view(-1,28,28,1)#(28*28)->(1,28,28)
        return img

#3.定义判别器
#输入为(1,28,28)的图片,输出是二分类的概率值,输出使用sigmoid激活
#判别器一般使用LeakyReLu激活函数,带一点斜率,初始为0.2
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator,self).__init__()

        #判别器网络
        self.net=nn.Sequential(
            nn.Linear(28*28,512),
            nn.LeakyReLU(),
            nn.Linear(512,256),
            nn.LeakyReLU(),
            nn.Linear(256,1),
            nn.Sigmoid()#输出在[0,1]之间
        )

    def forward(self,img):#img为输入判别器的图片(1.28,28)
        x=img.view(-1,28*28)
        x=self.net(x)
        return x


#4.初始化模型

# 定义运行设备
device = 'cuda' if torch.cuda.is_available() else 'cpu'

#在设备上运行
gen=Generator().to(device)
dis=Discriminator().to(device)

#定义优化器
d_optim=torch.optim.Adam(dis.parameters(), lr=0.0001)
g_optim=torch.optim.Adam(gen.parameters(),lr=0.0001)

#用BCELoss计算交叉熵损失(二分类)
loss_fn=nn.BCELoss()

#5.绘图函数

# 查看生成的图片
def gen_img_plot(model,test_input):## 将噪声放进去生成器中
    prediction=np.squeeze(model(test_input).detach().cpu().numpy())#保留梯度
    fig = plt.figure(figsize=(4,4))#16张图片
    for i in  range(16):
        plt.subplot(4,4,i+1)
        plt.imshow((prediction[i]+1)/2)
        plt.axis('off')
    plt.show()
test_input= torch.randn(batch_size,100,device=device)

#6.gan的训练
D_loss=[]
G_loss=[]

#循环多少个epoch
for epoch in range(epochs):
    d_epoch_loss=0
    g_epoch_loss = 0
    count=len(dataLoader)#总样本个数(一个epoch的个数)

    for step,(img,_) in enumerate(dataLoader):#enumerate对数据编号,img,_只取图片,不取标签,img的格式为(64,1,28,28)
        img=img.to(device)
        size=img.size(0)

        #在真实图片上计算判别器损失
        d_optim.zero_grad()
        real_output=dis(img)
        d_real_loss=loss_fn(real_output,torch.ones_like(real_output))#在真实图片上的输出尽可能接近1
        d_real_loss.backward()

        #在生成图片上计算判别器损失
        random_noise=torch.randn(size,100,device=device)
        gen_img=gen(random_noise)
        fake_output=dis(gen_img.detach())#这里只更新判别器,所以要截断梯度
        d_fake_loss=loss_fn(fake_output,torch.zeros_like(fake_output))
        d_fake_loss.backward()

        #判别器的总和损失
        d_loss=d_fake_loss+d_real_loss
        d_optim.step()#优化

        #计算生成器损失
        g_optim.zero_grad()
        fake_output=dis(gen_img)
        g_loss=loss_fn(fake_output,torch.ones_like(fake_output))
        g_loss.backward()
        g_optim.step()

        #损失绘图
        with torch.no_grad():#不更新梯度
            d_epoch_loss+=d_loss#计算一个epoch的判别器总和损失
            g_epoch_loss+=g_loss

    with torch.no_grad():
        d_epoch_loss/=count#计算一个epoch的判别器平均损失
        g_epoch_loss/=count

        D_loss.append(d_epoch_loss.item())#D_loss列表里面存放所有的损失值
        G_loss.append(g_epoch_loss.item())

        print('epoch:',epoch)
        gen_img_plot(gen,test_input)

  • 2
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值