PyTorch构建GAN,实现MNIST手写数字生成

# !/usr/bin/python
# -*- coding: UTF-8 -*-

import os

import matplotlib.pylab as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as func
import torchvision



batch_size = 160

# Convert each image to a tensor and normalize it.
#
# ToTensor() rescales the 0-255 grayscale range to [0, 1]. Normalize() then
# applies, per channel:
#
#     image = (image - mean) / std
#
# With mean = std = 0.5 the minimum 0 becomes (0 - 0.5) / 0.5 = -1 and the
# maximum 1 becomes (1 - 0.5) / 0.5 = 1, so pixels end up in [-1, 1].
# MNIST is single-channel grayscale, so mean and std each take one value
# (an RGB dataset would pass three, e.g. (0.5, 0.5, 0.5)).
#
# Author's rationale: inputs confined to (0, 1) may call for a comparatively
# large bias while biases start at 0, which slows convergence; centering the
# data around 0 is meant to speed training up.
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.5], std=[0.5]),
])

# download=True fetches MNIST into ./mnist when it is not there yet; without
# it the constructor raises if the dataset has not been downloaded before.
dataset = torchvision.datasets.MNIST(
    "./mnist", train=True, transform=transform, download=True
)
# shuffle=True re-orders the samples every epoch.
data_loader = torch.utils.data.DataLoader(
    dataset=dataset, batch_size=batch_size, shuffle=True
)
# Use the GPU when one is available, otherwise fall back to the CPU instead
# of failing on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def denomalize(x):
    """Map normalized image data back to displayable images.

    Undoes a ``(image - 0.5) / 0.5`` normalization by mapping values from
    [-1, 1] to [0, 1], and restores the image layout with an explicit
    channel axis.

    Args:
        x: Tensor holding any number of 28x28 images, e.g. flattened as
            ``(batch, 784)``.

    Returns:
        Tensor of shape ``(batch, 1, 28, 28)`` with values clamped to [0, 1].
    """
    out = (x + 1) / 2
    # Infer the batch dimension instead of hard-coding 32, so batches of any
    # size can be converted; the singleton axis is the channel.
    out = out.reshape(-1, 1, 28, 28)
    return out.clamp(0, 1)

def imshow(img, epoch):
    """Tile a batch of generated images into a grid, save it, and show it.

    The figure is titled with the one-based epoch number and written to
    ``./save_pic/<epoch + 1>.jpg`` before being displayed.

    Args:
        img: Image batch tensor laid out as (batch, channel, height, width).
        epoch: Zero-based epoch index; displayed and saved as ``epoch + 1``.
    """
    # make_grid joins the batch into a single image with 8 images per row.
    # detach() drops the autograd graph; cpu() lets the numpy conversion
    # succeed even when the caller passes a GPU tensor.
    im = torchvision.utils.make_grid(img, nrow=8).detach().cpu().numpy()
    # "%" binds tighter than "+": without the parentheses the expression is
    # ("Epoch on %d" % epoch) + 1, which raises TypeError (str + int).
    plt.title("Epoch on %d" % (epoch + 1))
    # matplotlib expects (height, width, channel), so move the channel axis last.
    plt.imshow(im.transpose(1, 2, 0))
    # savefig does not create missing directories.
    os.makedirs('./save_pic', exist_ok=True)
    plt.savefig('./save_pic/%d.jpg' % (epoch + 1))
    plt.show()

# Discriminator model
class DNet(nn.Module):
    """Fully connected discriminator for flattened 28x28 images.

    Layer sizes are 784 -> 256 -> 128 -> 1, with ReLU after the first two
    linear layers and a sigmoid on the output, so every input row yields a
    single value in (0, 1).
    """

    def __init__(self):
        super().__init__()

        self.l1 = nn.Linear(28 * 28, 256)
        self.a = nn.ReLU()  # shared by both hidden layers
        self.l2 = nn.Linear(256, 128)
        self.l3 = nn.Linear(128, 1)
        self.s = nn.Sigmoid()

    def forward(self, x):
        """Return the sigmoid score for each row of ``x``."""
        hidden = self.a(self.l1(x))
        hidden = self.a(self.l2(hidden))
        return self.s(self.l3(hidden))


# Generator model
class GNet(nn.Module):
    """Fully connected generator mapping 10-dim noise to flattened images.

    Layer sizes are 10 -> 128 -> 256 -> 784, with ReLU after the first two
    linear layers and tanh on the output, so every output value lies in
    (-1, 1).
    """

    def __init__(self):
        super().__init__()

        self.l1 = nn.Linear(10, 128)
        self.a = nn.ReLU()  # shared by both hidden layers
        self.l2 = nn.Linear(128, 256)
        self.l3 = nn.Linear(256, 28 * 28)
        self.tanh = nn.Tanh()

    def forward(self, x):
        """Return one flattened 28x28 output per row of noise in ``x``."""
        hidden = self.a(self.l1(x))
        hidden = self.a(self.l2(hidden))
        return self.tanh(self.l3(hidden))

# Instantiate both networks (discriminator first, then generator) and move
# them to the selected device.
D, G = DNet().to(device), GNet().to(device)

# Show the layer structure of each network.
print(D, G, sep="\n")

# One Adam optimizer per network, each with a learning rate of 1e-3.
D_optimizer = torch.optim.Adam(params=D.parameters(), lr=1e-3)
G_optimizer = torch.optim.Adam(params=G.parameters(), lr=1e-3)




# Adversarial training: each step updates the discriminator on one real and
# one generated batch, then updates the generator against the discriminator.
for epoch in range(250):
    for step, data in enumerate(data_loader):
        # Flatten the real images. Use the size of the batch actually
        # delivered rather than the configured batch_size, so a smaller final
        # batch does not break the reshape or the label shapes.
        images = data[0]
        current_batch = images.size(0)
        real_images = images.reshape(current_batch, -1).to(device)

        # Targets: 1 for real images, 0 for generated ones.
        real_labels = torch.ones(current_batch, 1, device=device)
        fake_labels = torch.zeros(current_batch, 1, device=device)

        # ---- Discriminator step ----
        # The discriminator loss is the sum of its loss on real images
        # (target 1) and on generated images (target 0).
        real_outputs = D(real_images)
        real_loss = func.binary_cross_entropy(real_outputs, real_labels)

        z = torch.randn(current_batch, 10).to(device)
        # detach(): this step only updates D, so skip backpropagating
        # through the generator.
        fake_images = G(z).detach()
        d_fake_outputs = D(fake_images)
        fake_loss = func.binary_cross_entropy(d_fake_outputs, fake_labels)

        d_loss = real_loss + fake_loss
        # Also clears the D gradients left over from the previous generator step.
        D_optimizer.zero_grad()
        d_loss.backward()
        D_optimizer.step()

        # ---- Generator step ----
        # Feed fresh fakes to D with "real" targets: the smaller g_loss is,
        # the more the discriminator takes the generated images for real.
        z = torch.randn(current_batch, 10).to(device)
        fake_images = G(z)
        fake_outputs = D(fake_images)
        g_loss = func.binary_cross_entropy(fake_outputs, real_labels)
        G_optimizer.zero_grad()
        g_loss.backward()
        G_optimizer.step()

        # Log every 20 steps. The "acc" fields are the mean discriminator
        # outputs on the real and on the generated batch respectively.
        if step % 20 == 19:
            print("epoch: " , epoch+1, "  step: ", step+1, "   d_loss: %.4f" % d_loss.mean().item(),
                  "   g_loss: %.4f" % g_loss.mean().item(), "   d_acc:  %.4f" %  real_outputs.mean().item(),
                  "   d(g)_acc: %.4f" % d_fake_outputs.mean().item())

    # Every 10 epochs, generate and plot a grid of 32 samples.
    if epoch % 10 == 9:
        # Sampling for visualization needs no autograd graph.
        with torch.no_grad():
            z = torch.randn(32, 10).to(device)
            img = G(z)

        imshow(denomalize(img.to("cpu")), epoch)


# Persist the trained weights of both networks.
torch.save(D.state_dict(), "./D.pth")
torch.save(G.state_dict(), "./G.pth")

 

训练250个epoch后的结果

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值