GAN入门:基本思想,损失函数,基于pytorch用GAN实现mnist手写数字生成

1.基本思想

GAN分为一个生成器(Discriminator,简称D)和一个生成器(Generator,简称G),简单的说,G和D就是两个多层感知器或卷积神经网络,他的基本思想,即为G和D的生成博弈过程。
训练D来让他能辨明真假数据,即给D输入真数据,将label赋值为1,输入假数据,将label赋值为0.
而G是要愚弄D,使他认为G生成的为真数据,即给G输入噪声z,让他生成一个假数据G(z),将G(z)输入D,赋值为1。此G的训练过程中,固定D的参数不变,只调整G的参数,否则D只需简单的迎合G就能达到G的目的。
基本结构:

在这里插入图片描述

2.损失函数

结合上述基本思想,我们可以得出以下损失函数:
在这里插入图片描述
如何理解这个式子呢?首先,固定G,只训练D,要使D(real)尽量的大,D(G(z))尽量的小,即1-(G(z))尽量的大,所以对于D,要max V(D,G)。其次,固定D,只训练G,此时与上式的第一项D(x)就没有关系了,只看后一项,要使D(G(z))尽量的大,即1-(G(z))尽量的小,所以对于G,要min V(D,G)。

3.基于pytorch用GAN实现mnist手写数字生成

3.1 定义一些要用的模块
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torchvision.utils import save_image

z_dimension = 100  # the dimension of noise tensor
3.2 读minst数据集

我们使用torchvision扩展库读取mnist数据,只需调用torchvision.datasets.MNIST(),该函数的参数:
root:表示数据将要存在哪里,我们这里设置的是’./data’,那么函数会将解压后的文件存在‘./data/raw’,将处理过的文件存在‘./data/processed’
train:为True表示要读取训练集,为False表示要读取测试集
download:表示是否要从网络上下载数据,一般设为True,如果指定的root位置没有数据,才会下载数据,否则不需要重新下载数据
transform:表示要将读取的原始数据转换为什么格式,为了方便pytorch使用,一般转换为tensor,而这里,我们先将原始数据转换为tensor,再将其做归一化操作,使用torchvision.transforms.Compose函数,把多个步骤合在一起

读取完数据,使用torch.utils.data,DataLoader分批读取类实例trainset和testset的内容

transform = transforms.Compose([
    #将PILImage或者numpy的ndarray转化成Tensor,这样才能进行下一步归一化
    transforms.ToTensor(),
    #transforms.Normalize(mean,std)参数:
    transforms.Normalize([0.5], [0.5]),
])

trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)

testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False)

可以验证一下,如图,一个batch内,image.size=128,labels.size=128,与我们在torch.utils.data,DataLoader参数中设置的一样
在这里插入图片描述

3.3 构建生成器和判别器

为了运行速度快一点,我们使用简单的线性结构构建生成器和判别器

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.dis = nn.Sequential(
            nn.Linear(784, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.dis(x)
        return x


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.gen = nn.Sequential(
            nn.Linear(z_dimension, 256),
            nn.ReLU(True),
            nn.Linear(256, 256),
            nn.ReLU(True),
            nn.Linear(256, 784),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.gen(x)
        return x
3.4 数据处理,将x的范围由(-1,1)伸缩到(0,1)
def to_img(x):
    out = 0.5 * (x + 1)  # 将x的范围由(-1,1)伸缩到(0,1)
    out = out.view(-1, 1, 28, 28)
    return out
3.5 定义生成器、判别器、优化器
D = Discriminator().to('cpu')
G = Generator().to('cpu')

#因为我们只需要区分real和fake,所以使用二分类交叉熵损失函数即可
criterion = nn.BCELoss()
D_optimizer = torch.optim.Adam(D.parameters(), lr=0.0003)
G_optimizer = torch.optim.Adam(G.parameters(), lr=0.0003)

os.makedirs("MNIST_FAKE", exist_ok=True)
3.6 训练
def train(epoch):
    print('\nEpoch: %d' % epoch)
    #将模型调整到训练状态
    D.train()
    G.train()
    all_D_loss = 0.
    all_G_loss = 0.
    for batch_idx, (inputs, targets) in enumerate(trainloader):
    #使网络在GPU上进行训练
        inputs, targets = inputs.to('cpu'), targets.to('cpu')
        #num_img即为图片的数量
        num_img = targets.size(0)
        #real的标签是1,fake的标签是0
        real_labels = torch.ones_like(targets, dtype=torch.float)
        fake_labels = torch.zeros_like(targets, dtype=torch.float)
        #把输入的28*28图片压平成784,便于输入D进行运算
        inputs_flatten = torch.flatten(inputs, start_dim=1)

        # Train Discriminator
        real_outputs = D(inputs_flatten)
        #criterion就是上一步定义的nn.BCELoss()
        D_real_loss = criterion(real_outputs, real_labels)

        z = torch.randn((num_img, z_dimension))  # Random noise from N(0,1)
        fake_img = G(z)  # Generate fake images
        fake_outputs = D(fake_img.detach())
        D_fake_loss = criterion(fake_outputs, fake_labels)

        D_loss = D_real_loss + D_fake_loss
        #清空上一步的残余更新参数值
        D_optimizer.zero_grad()
        # 误差反向传播, 计算参数更新值
        D_loss.backward()
        # 将参数更新值施加到 net 的 parameters 上
        D_optimizer.step()

        # Train Generator
        z = torch.randn((num_img, z_dimension))
        #fake_img是G从噪声生成的
        fake_img = G(z)
        #再把fake_img送入D,让D判别真假
        G_outputs = D(fake_img)
        G_loss = criterion(G_outputs, real_labels)
        G_optimizer.zero_grad()
        G_loss.backward()
        G_optimizer.step()

        all_D_loss += D_loss.item()
        all_G_loss += G_loss.item()
        print('Epoch {}, d_loss: {:.6f}, g_loss: {:.6f} '
              'D real: {:.6f}, D fake: {:.6f}'.format
              (epoch, all_D_loss/(batch_idx+1), all_G_loss/(batch_idx+1),
               torch.mean(real_outputs), torch.mean(fake_outputs)))

    # Save generated images for every epoch
    fake_images = to_img(fake_img)
    save_image(fake_images, 'MNIST_FAKE/fake_images-{}.png'.format(epoch + 1))

for epoch in range(40):
    train(epoch)

为了更清楚一点,在上述代码中加入这几行,打印batch_idx的内容,inputs, targets的大小:
在这里插入图片描述
输出的部分内容:
在这里插入图片描述
可以看出,batch_idx是循环的次数,inputs是图片,targets是他的标签

3.7 可视化训练过程

代码是上一部分代码的最后两行:

# Save generated images for every epoch
    fake_images = to_img(fake_img)
    save_image(fake_images, 'MNIST_FAKE/fake_images-{}.png'.format(epoch + 1))

由于最后一个batch的大小是96,所以我们输出的也是96张图的一个集合,如图,epoch 0几乎都是噪声:
在这里插入图片描述
在这里插入图片描述
可以看出,在epoch0,G和D的loss都比较大
epoch20:
在这里插入图片描述
epoch40:
在这里插入图片描述
这时的效果虽比epoch0好了很多,至少噪声大大减小,也能模糊的辨别数字,但还不是很理想,可能是因为我们使用线性分类器的原因,训练次数过小也是一个原因,也可以从损失函数等方面改进

  • 12
    点赞
  • 43
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
以下是一个基于GANmnist手写数字生成PyTorch代码示例: ```python import torch import torch.nn as nn from torchvision.datasets import MNIST from torchvision.transforms import ToTensor from torch.utils.data import DataLoader # 定义生成器 class Generator(nn.Module): def __init__(self, input_size, hidden_size, output_size): super(Generator, self).__init__() self.fc1 = nn.Linear(input_size, hidden_size) self.fc2 = nn.Linear(hidden_size, hidden_size) self.fc3 = nn.Linear(hidden_size, output_size) self.relu = nn.ReLU() def forward(self, x): x = self.relu(self.fc1(x)) x = self.relu(self.fc2(x)) x = self.fc3(x) return x # 定义判别器 class Discriminator(nn.Module): def __init__(self, input_size, hidden_size): super(Discriminator, self).__init__() self.fc1 = nn.Linear(input_size, hidden_size) self.fc2 = nn.Linear(hidden_size, hidden_size) self.fc3 = nn.Linear(hidden_size, 1) self.sigmoid = nn.Sigmoid() def forward(self, x): x = self.sigmoid(self.fc1(x)) x = self.sigmoid(self.fc2(x)) x = self.sigmoid(self.fc3(x)) return x # 定义超参数 input_size = 100 hidden_size = 256 output_size = 784 batch_size = 128 num_epochs = 200 # 加载MNIST数据集 train_dataset = MNIST(root='./data', train=True, transform=ToTensor(), download=True) train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) # 初始化网络 G = Generator(input_size, hidden_size, output_size) D = Discriminator(output_size, hidden_size) # 定义损失函数和优化器 criterion = nn.BCELoss() lr = 0.0002 G_optimizer = torch.optim.Adam(G.parameters(), lr=lr) D_optimizer = torch.optim.Adam(D.parameters(), lr=lr) # 定义真实和假的标签 real_label = torch.ones(batch_size, 1) fake_label = torch.zeros(batch_size, 1) # 训练网络 for epoch in range(num_epochs): for i, (images, _) in enumerate(train_loader): # 定义真实和假的图像 real_images = images.view(batch_size, -1) z = torch.randn(batch_size, input_size) fake_images = G(z) # 训练判别器 D_real_loss = criterion(D(real_images), real_label) D_fake_loss = criterion(D(fake_images.detach()), fake_label) D_loss = D_real_loss + D_fake_loss D_optimizer.zero_grad() D_loss.backward() D_optimizer.step() # 训练生成器 G_loss = criterion(D(fake_images), real_label) G_optimizer.zero_grad() G_loss.backward() G_optimizer.step() # 打印损失 if (i+1) % 100 == 0: print('Epoch [{}/{}], Step [{}/{}], D Loss: {:.4f}, G Loss: {:.4f}' .format(epoch+1, num_epochs, i+1, len(train_loader), D_loss.item(), G_loss.item())) # 保存模型 torch.save(G.state_dict(), 'generator.pth') ``` 在训练完成后,可以使用生成器来生成新的手写数字图像,例如: ```python import matplotlib.pyplot as plt import numpy as np # 加载生成器 G = Generator(input_size, hidden_size, output_size) G.load_state_dict(torch.load('generator.pth')) # 生成图像 z = torch.randn(1, input_size) fake_image = G(z).detach().numpy() fake_image = np.reshape(fake_image, (28, 28)) # 显示图像 plt.imshow(fake_image, cmap='gray') plt.show() ``` 这样就可以生成一个随机的手写数字图像了。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值