DCGAN的原理本文不再介绍,可以参考:DCGAN论文解读-----DCGAN原理简介与基础GAN的区别
之前发过一篇利用GAN生成手写数字的实战演示,具体参考:入门GAN实战---生成MNIST手写数据集代码实现pytorch
由于利用GAN生成的图像噪声较多,因此利用DCGAN再次完成该实验。两种方法区别不大,只是在定义生成器和鉴别器的时候稍有改动。具体演示如下:
1.加载MNIST手写数据集
如果已经提前下载好MNIST手写数据集,记得把代码中的download改为False。具体MNIST数据集下载方法参考:深度学习入门--MNIST数据集及创建自己的手写数字数据集
# 加载数据
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize(mean=0.5, std=0.5)])
train_ds = torchvision.datasets.MNIST('data/',
train=True,
transform=transform,
download= True)
dataloader = torch.utils.data.DataLoader(train_ds, batch_size=64, shuffle=True)
2.定义生成器Generator
与基础GAN的生成器相比,利用了反卷积并添加了BN层
# 定义生成器
class Generator(nn.Module):
def __init__(self):
super(Generator,self).__init__()
self.linear1 = nn.Linear(100, 256*7*7) # 希望生成1*28*28的图片 7反卷积后14,再反卷积28 pytorch中channel在前
self.bn1 = nn.BatchNorm1d(256*7*7)
self.deconv1 = nn.ConvTranspose2d(256, 128,
kernel_size=(3,3),
str