PyTorch: Generating MNIST Handwritten Digits with DCGAN (An Introductory Tutorial)

This post does not cover the theory behind DCGAN; for that, see: DCGAN论文解读-----DCGAN原理简介与基础GAN的区别 (a DCGAN paper walkthrough explaining how DCGAN differs from the basic GAN).

A previous post walked through generating handwritten digits with a basic GAN; see: 入门GAN实战---生成MNIST手写数据集代码实现pytorch.

Because the images produced by the basic GAN contained a lot of noise, this post repeats the experiment with DCGAN. The two approaches differ very little; only the definitions of the generator and the discriminator change. The walkthrough is as follows:

1. Load the MNIST handwritten digit dataset

    If you have already downloaded the MNIST dataset, remember to change download to False in the code below. For how to obtain the MNIST dataset, see: 深度学习入门--MNIST数据集及创建自己的手写数字数据集.

```python
# Load the data
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms

# ToTensor scales pixels to [0, 1]; Normalize then maps them to [-1, 1] to match the generator's tanh output
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize(mean=(0.5,), std=(0.5,))])

train_ds = torchvision.datasets.MNIST('data/',
                                      train=True,
                                      transform=transform,
                                      download=True)
dataloader = torch.utils.data.DataLoader(train_ds, batch_size=64, shuffle=True)
```
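
As a quick sanity check, you can pull one batch from the dataloader and confirm the channel-first shape and value range that the generator below is designed to produce (this check is just illustrative, not part of the training code):

```python
# Illustrative check: MNIST batches come out as (batch, channels, height, width)
images, labels = next(iter(dataloader))
print(images.shape)                               # torch.Size([64, 1, 28, 28])
print(images.min().item(), images.max().item())   # roughly -1.0 and 1.0 after Normalize
```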

2. Define the generator (Generator)

Compared with the generator of the basic GAN, this one uses transposed convolutions (deconvolutions) for upsampling and adds BatchNorm layers.

```python
# Define the generator
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        # Target: 1*28*28 images (channels first in PyTorch): 7x7, deconv to 14x14, deconv again to 28x28
        self.linear1 = nn.Linear(100, 256*7*7)
        self.bn1 = nn.BatchNorm1d(256*7*7)
        self.deconv1 = nn.ConvTranspose2d(256, 128, kernel_size=(3, 3), stride=1, padding=1)  # -> 128*7*7
        self.bn2 = nn.BatchNorm2d(128)
        self.deconv2 = nn.ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=2, padding=1)   # -> 64*14*14
        self.bn3 = nn.BatchNorm2d(64)
        self.deconv3 = nn.ConvTranspose2d(64, 1, kernel_size=(4, 4), stride=2, padding=1)     # -> 1*28*28

    def forward(self, x):
        x = self.bn1(F.relu(self.linear1(x))).view(-1, 256, 7, 7)
        x = self.bn2(F.relu(self.deconv1(x)))
        x = self.bn3(F.relu(self.deconv2(x)))
        return torch.tanh(self.deconv3(x))
```
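
To verify the 7 to 14 to 28 upsampling path, a minimal shape check on a batch of noise vectors (again illustrative only):

```python
# Illustrative check: a batch of 100-dim noise vectors should produce 64 images of shape 1x28x28
gen = Generator()
noise = torch.randn(64, 100)
print(gen(noise).shape)   # torch.Size([64, 1, 28, 28])
```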
Below are the end-to-end steps for training a GAN to generate MNIST digits. The Generator and Discriminator in this reference script are the simple fully connected versions from the basic GAN; for the DCGAN variant, replace them with convolutional definitions such as the generator above, and the rest of the pipeline stays the same.

1. Import the necessary libraries and modules

```python
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
```

2. Load the dataset

```python
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize(mean=(0.5,), std=(0.5,))])
train_ds = torchvision.datasets.MNIST('data/', train=True, transform=transform, download=True)
dataloader = torch.utils.data.DataLoader(train_ds, batch_size=64, shuffle=True)
```

3. Define the generator (Generator)

```python
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.fc1 = nn.Linear(100, 256)
        self.fc2 = nn.Linear(256, 512)
        self.fc3 = nn.Linear(512, 1024)
        self.fc4 = nn.Linear(1024, 784)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.relu(self.fc3(x))
        x = self.tanh(self.fc4(x))
        return x
```

4. Define the discriminator (Discriminator)

```python
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(784, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 1)
        self.leaky_relu = nn.LeakyReLU(0.2)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = self.leaky_relu(self.fc1(x))
        x = self.leaky_relu(self.fc2(x))
        x = self.sigmoid(self.fc3(x))
        return x
```

5. Instantiate the generator and the discriminator

```python
generator = Generator()
discriminator = Discriminator()
```

6. Define the loss function and the optimizers

```python
criterion = nn.BCELoss()
lr = 0.0002
optimizer_g = torch.optim.Adam(generator.parameters(), lr=lr)
optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=lr)
```

7. Train the model

```python
num_epochs = 50
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(dataloader):
        # Train the discriminator
        discriminator.zero_grad()
        real_images = images.view(-1, 784)
        real_labels = torch.ones(images.size(0), 1)
        fake_labels = torch.zeros(images.size(0), 1)
        z = torch.randn(images.size(0), 100)
        fake_images = generator(z)
        outputs_real = discriminator(real_images)
        # detach() so the discriminator update does not backpropagate into the generator
        outputs_fake = discriminator(fake_images.detach())
        loss_d_real = criterion(outputs_real, real_labels)
        loss_d_fake = criterion(outputs_fake, fake_labels)
        loss_d = loss_d_real + loss_d_fake
        loss_d.backward()
        optimizer_d.step()

        # Train the generator
        generator.zero_grad()
        z = torch.randn(images.size(0), 100)
        fake_images = generator(z)
        outputs = discriminator(fake_images)
        loss_g = criterion(outputs, real_labels)
        loss_g.backward()
        optimizer_g.step()

        # Print the losses
        if (i + 1) % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}'
                  .format(epoch + 1, num_epochs, i + 1, len(dataloader), loss_d.item(), loss_g.item()))
```

8. Generate images

```python
# Generate random noise
z = torch.randn(64, 100)
# Generate images
fake_images = generator(z)
# Convert the images to a numpy array
fake_images = fake_images.detach().numpy()
# Visualize the images in an 8x8 grid
fig, axs = plt.subplots(8, 8, figsize=(10, 10))
cnt = 0
for i in range(8):
    for j in range(8):
        axs[i, j].imshow(fake_images[cnt].reshape(28, 28), cmap='gray')
        axs[i, j].axis('off')
        cnt += 1
plt.show()
```
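
The reference script above still uses a fully connected Discriminator. In the DCGAN setup the discriminator is also convolutional; the sketch below is one possible definition (the layer sizes are assumptions for illustration, not taken from the original post), mirroring the generator by downsampling 28x28 to 14x14 to 7x7 with strided convolutions:

```python
import torch
import torch.nn as nn

# A possible DCGAN-style discriminator (illustrative sketch; the exact layer sizes are assumed)
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=2, padding=1)    # 1*28*28 -> 64*14*14
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)  # 64*14*14 -> 128*7*7
        self.bn = nn.BatchNorm2d(128)
        self.fc = nn.Linear(128*7*7, 1)
        self.leaky_relu = nn.LeakyReLU(0.2)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.leaky_relu(self.conv1(x))            # downsample to 14x14
        x = self.bn(self.leaky_relu(self.conv2(x)))   # downsample to 7x7
        x = x.view(-1, 128*7*7)                       # flatten for the final real/fake score
        return self.sigmoid(self.fc(x))
```

If you swap this Discriminator and the convolutional Generator into the training loop above, keep the images in their (batch, 1, 28, 28) shape: drop the real_images = images.view(-1, 784) line and feed images (and the generator's 4-D output) to the discriminator directly.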
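
One optimizer detail is also worth noting: the DCGAN paper trains both networks with Adam at a learning rate of 0.0002 and beta1 = 0.5 instead of the default 0.9, which tends to make training more stable. If you want to follow that setting, the two optimizers from step 6 become:

```python
# Adam settings suggested by the DCGAN paper: lr=0.0002, beta1=0.5
optimizer_g = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
```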
