在pytorch上利用GAN网络实现0-9数字生成

在pytorch上利用GAN生成对抗网络实现0-9手写数字图片生成

1.背景
GAN生成对抗网络可以生成图片,音频,视频,主要分为生成器网络和判别器网络。

①大致流程
输入一段随机信号,让生成器随机生成一张图片,然后让判别器去识别真图片和假图片。

②何为对抗?
简而言之,让生成器生成的图片让判别器更加难以识别为假图,另一边让判别器不断提升识别真假图片的能力,这样就形成了一个对抗网络。

2.数据集
数据集来自torchvision的dataset的MNIST手写0-9数据集(28x28)
具体请自行了解

3.模型
生成器(Generator)和判别器(Discriminator)
model.py文件

from torch import nn

# 图像生成器
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.gen = nn.Sequential(
            nn.Linear(100, 256),  
            nn.ReLU(),  
            nn.Linear(256, 512),  
            nn.ReLU(),  
            nn.Linear(512, 784),  #图片大小为28*28=784
            nn.Tanh()  # Tanh激活使得生成数据分布在[-1,1]之间,因为输入的真实数据的经过transforms之后也是这个分布
        )

    def forward(self, x):
        x = self.gen(x)
        return x


# 图像判别器
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.f1 = nn.Sequential(
            nn.Linear(784, 512),
            nn.LeakyReLU(0.2)
        )
        self.f2 = nn.Sequential(
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2)
        )
        self.out = nn.Sequential(
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.f1(x)
        x = self.f2(x)
        x = self.out(x)
        return x

4.训练模型
train.py

import torch
from torch import nn
from torch.autograd import Variable
from torchvision import transforms, datasets
from torchvision.utils import save_image

from model import Discriminator, Generator


def to_img(x):
    out = 0.5 * (x + 1)
    out = out.clamp(0, 1)  # Clamp函数可以将随机变化的数值限制在一个给定的区间[min, max]内:
    out = out.view(-1, 1, 28, 28)  # view()函数作用是将一个多行的Tensor,拼接成一行
    return out


def GAN_train_model(dataset, generator, discriminator, batch_size, epoch, lr, z_dim, device):
    device = device
    batch_size = batch_size
    epoch = epoch
    lr = lr
    z_dim = z_dim

    # 返回一个数据迭代器
    # shuffle:是否打乱顺序
    data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                              batch_size=batch_size,
                                              shuffle=True)

    if device == "CUDA":
        D = discriminator.cuda()
        G = generator.cuda()
    else:
        D = discriminator.cpu()
        G = generator.cpu()

    criterion = nn.BCELoss()  # 定义损失函数

    d_optimizer = torch.optim.Adam(D.parameters(), lr=lr)
    g_optimizer = torch.optim.Adam(G.parameters(), lr=lr)



    steps_per_epoch = len(data_loader)
    # 开始训练
    for cur_epoch in range(epoch):  # 进行多个epoch的训练
        total_d_loss = 0
        total_g_loss = 0

        for i, (img, _) in enumerate(data_loader):

            num_img = img.size(0)
            # 将图像变为1维数据
            img = img.view(num_img, -1)
            real_img = img

            # 定义真实的图片label为1
            real_label = torch.ones(num_img, 1)
            # 定义假的图片的label为0
            fake_label = torch.zeros(num_img, 1)

            if device == "CUDA":
                real_img = real_img.cuda()
                real_label = real_label.cuda()
                fake_label = fake_label.cuda()
            else:
                real_img = real_img.cpu()
                real_label = real_label.cpu()
                fake_label = fake_label.cpu()

            # 判别器训练

            # 将真实图片放入判别器中
            real_out = D(real_img)

            # 得到真实图片的loss
            d_loss_real = criterion(real_out, real_label)
            # 得到真实图片的判别值,real_out输出的值越接近1越好
            real_scores = real_out

            # 计算假的图片的损失
            z = torch.randn(num_img, z_dim)  # 随机生成一些噪声

            if device == "CUDA":
                z = z.cuda()
            else:
                z = z.cpu()

            # 随机噪声放入生成网络中,生成一张假的图片。
            # 避免梯度传到G,因为G不用更新, detach分离
            fake_img = G(z).detach()
            # 判别器判断假的图片
            fake_out = D(fake_img)
            # 得到假的图片的loss
            d_loss_fake = criterion(fake_out, fake_label)
            # 得到假图片的判别值,对于判别器来说,假图片的d_loss_fake损失越接近0越好
            fake_scores = fake_out
            # 损失函数和优化,总的来讲就是训练判别器能判断图片是真图还是假图(生成图)
            d_loss = d_loss_real + d_loss_fake  # 损失包括判真损失和判假损失

            total_d_loss += d_loss.data.item()

            d_optimizer.zero_grad()  # 在反向传播之前,先将梯度归0
            d_loss.backward()  # 将误差反向传播
            d_optimizer.step()  # 更新参数

            # 训练生成器
            # 原理:目的是希望生成的假的图片被判别器判断为真的图片,
            # 在此过程中,将判别器固定,将假的图片传入判别器的结果与real_label的对应,
            # 使得生成的图片让判别器以为是真的
            # 这样就达到了对抗的目的
            # 计算假的图片的损失
            z = torch.randn(num_img, z_dim)  # 得到随机噪声

            if device == "CUDA":
                z = z.cuda()
            else:
                z = z.cpu()

            fake_img = G(z)  # 随机噪声输入到生成器中,得到一副假的图片
            output = D(fake_img)  # 经过判别器得到的结果
            g_loss = criterion(output, real_label)  # 得到的假的图片与真实的图片的label的loss

            total_g_loss += g_loss.data.item()

            # bp and optimize
            g_optimizer.zero_grad()  # 梯度归0
            g_loss.backward()  # 进行反向传播
            g_optimizer.step()  # .step()一般用在反向传播后面,用于更新生成网络的参数


        # 打印每个epoch的损失
        print('Epoch[{}/{}],d_loss:{:.6f},g_loss:{:.6f} '.format(
            cur_epoch, epoch, total_d_loss / steps_per_epoch, total_g_loss / steps_per_epoch)  # 打印的是真实图片的损失均值
        )

        if cur_epoch == 0:
            real_images = to_img(real_img.cpu().data)
            save_image(real_images, './img/real_images.png')

        fake_images = to_img(fake_img.data)
        save_image(fake_images, './img/fake_images-{}.png'.format(cur_epoch + 1))

    # 保存生成器和判别器模型
    torch.save(generator, "model/generator.pkl")
    torch.save(discriminator, "model/discriminator.pkl")


if __name__ == "__main__":
    # 图像变化器,转为tensor并标准化数据
    transform = transforms.Compose([
        transforms.ToTensor(),  # 数据范围[0,1],归一化
        transforms.Normalize((0.5,), (0.5,))  # (x-mean) / std,数据范围[-1,1],经过Normalize后,可以加快模型的收敛速度(不确定)
    ])

    # 加载数据集
    dataset = datasets.MNIST(root='./data/',
                             train=True,
                             transform=transform,
                             download=True)

    # 初始生成器generator与判别器discriminator
    discriminator = Discriminator()
    generator = Generator()

    # batch_size
    batch_size = 128

    # epoch次数
    epoch = 100

    # lr学习率
    lr = 3e-4

    # 噪声维度
    z_dim = 100

    GAN_train_model(
        dataset=dataset,
        discriminator=discriminator,
        generator=generator,
        batch_size=batch_size,
        epoch=epoch,
        lr=lr,
        z_dim=z_dim,
        device="CUDA"
    )

5.训练效果(生成器)
第100个epoch训练效果
在这里插入图片描述
真实数据
在这里插入图片描述
这里可以看到生成器生成的图片跟真的一样。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值