cGAN/cDCGAN,MNIST数据集初体验(内含原理,代码)

​生成式对抗网络(Generative Adversarial Networks, GAN),简称GAN网络。有人说这是21世纪最让人激动的“发明”,虽然我忘了我是从哪看到的这句话,貌似是发明了卷积神经网络那位大佬说的。我试过以后,对于AI兴趣爱好者来说

确实挺激动的!

对于标题中的cGAN/cDCGAN,小c,全称是conditional,条件的。DC,全称是Deep Convolution,深度卷积。都是GAN网络的一个变种。对于DCGAN与GAN的关系,也很简单,因为最开始GAN网络是用神经网络设计的,而后来出现了计算能力更强的卷积(CNN),训练逻辑相同,只是计算操作不同,当然可以相互替换。

对于原理,网传:一个生成器(Generator),一个判别器(Discriminator),他两相互博弈,相爱相杀,最后产生一个好的结果。。。

What?还要动手吗?

对于此种高端解释,我等菜鸡无法领会,我只想知道网络是怎么训练的?两个部分的输入输出分别是什么?网络如何搭建?Loss如何设计?有了这些,你的程序就可以跑了

还是从代码中理解啥是相爱相杀吧。

先放一张整体原理图,来个大致印象
图片来源:github一个老哥的仓库
那个G,就是生成器,那个D,就是判别器。其余就是常规表示网络的结构了,是如何设计的。各位应该发现图中还有个小y,这就是cGAN网络中的c

较常规GAN网络,多了个条件标签

这里想啰嗦一句,这个版本的cGAN在条件标签的处理上,用的是concatenate操作,也就是在某个维度上,直接叠加相关数据,一会在代码中也有显现。其余的还可不可以用别的操作来改善效果,本人很菜,还没有试过。

如图所示,因为用的是MNIST(手写数字体)数据集,每张图片的shape是[1, 28, 28],也就是单通道,分辨率是28x28。又因为是使用神经网络提取特征,所以需要将图片打平操作,所以生成器(G)最后生成的本来应该是一张图片的shape,这里的话就是784,这个数字各位应该不陌生,不多废话。

可以看到,G的输入就是100维的一个随机数,shape是[100, 100],这里生成100张假的数字体图像,对应的label,小y,也就是[100, 10],做了One-Hot编码。然而输出就是[100, 784],经过一些类似imshow等显示图片的函数的时候,在reshape成[100, 1, 28, 28], 就可以显示啦

再看D,判别器,这个就相对简单一些,就是平常看到的分类网络的结构。输入是由G生成的假的图像数据,输出只有两个,真or假,real or fake,用1和0代替结果,shape为[batch, 1],只有一堆0或1作为label。在代码中一看便知,就理解了。

ok,少废话,上代码(下面是完整的,来源也是GitHub的那位老哥的仓库,稍做了些修改,要不在我的环境下直接跑不了)

import os, time
import matplotlib.pyplot as plt
import itertools
import pickle
import imageio
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
​
# G(z)
class generator(nn.Module):
    # initializers
    def __init__(self):
        super(generator, self).__init__()
        self.fc1_1 = nn.Linear(100, 256)
        self.fc1_1_bn = nn.BatchNorm1d(256)
        self.fc1_2 = nn.Linear(10, 256)
        self.fc1_2_bn = nn.BatchNorm1d(256)
        self.fc2 = nn.Linear(512, 512)
        self.fc2_bn = nn.BatchNorm1d(512)
        self.fc3 = nn.Linear(512, 1024)
        self.fc3_bn = nn.BatchNorm1d(1024)
        self.fc4 = nn.Linear(1024, 784)# weight_init
    def weight_init(self, mean, std):
        for m in self._modules:
            normal_init(self._modules[m], mean, std)# forward method
    def forward(self, input, label):
        x = F.relu(self.fc1_1_bn(self.fc1_1(input)))
        y = F.relu(self.fc1_2_bn(self.fc1_2(label)))
        x = torch.cat([x, y], 1)
        x = F.relu(self.fc2_bn(self.fc2(x)))
        x = F.relu(self.fc3_bn(self.fc3(x)))
        x = F.tanh(self.fc4(x))return x
​
class discriminator(nn.Module):
    # initializers
    def __init__(self):
        super(discriminator, self).__init__()
        self.fc1_1 = nn.Linear(784, 1024)
        self.fc1_2 = nn.Linear(10, 1024)
        self.fc2 = nn.Linear(2048, 512)
        self.fc2_bn = nn.BatchNorm1d(512)
        self.fc3 = nn.Linear(512, 256)
        self.fc3_bn = nn.BatchNorm1d(256)
        self.fc4 = nn.Linear(256, 1)# weight_init
    def weight_init(self, mean,
  • 8
    点赞
  • 27
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
好的,以下是使用Jittor训练CGAN模型生成手写数字的代码示例: ```python import jittor as jt from jittor import nn from jittor import Dataset from jittor import transform import jittor.random as random import numpy as np class MNIST(Dataset): def __init__(self, train=True, transform=None): self.train = train self.transform = transform if self.train: self.images = np.load("train_images.npy") self.labels = np.load("train_labels.npy") else: self.images = np.load("test_images.npy") self.labels = np.load("test_labels.npy") def __getitem__(self, index): img = self.images[index].astype('float32') / 255.0 label = self.labels[index] if self.transform is not None: img = self.transform(img) return img, label def __len__(self): return len(self.images) class Generator(nn.Module): def __init__(self, latent_dim=100, img_shape=(1, 28, 28)): super(Generator, self).__init__() self.img_shape = img_shape self.model = nn.Sequential( nn.Linear(latent_dim, 128), nn.ReLU(), nn.Linear(128, 256), nn.BatchNorm1d(256, 0.8), nn.ReLU(), nn.Linear(256, 512), nn.BatchNorm1d(512, 0.8), nn.ReLU(), nn.Linear(512, 1024), nn.BatchNorm1d(1024, 0.8), nn.ReLU(), nn.Linear(1024, int(np.prod(img_shape))), nn.Tanh() ) def execute(self, z): img = self.model(z) img = img.reshape((img.shape[0],) + self.img_shape) return img class Discriminator(nn.Module): def __init__(self, img_shape=(1, 28, 28)): super(Discriminator, self).__init__() self.img_shape = img_shape self.model = nn.Sequential( nn.Linear(int(np.prod(img_shape)), 512), nn.LeakyReLU(0.2), nn.Linear(512, 256), nn.LeakyReLU(0.2), nn.Linear(256, 1), nn.Sigmoid(), ) def execute(self, img): img_flat = img.flatten(1) validity = self.model(img_flat) return validity # 定义超参数 latent_dim = 100 img_shape = (1, 28, 28) lr = 0.0002 b1 = 0.5 b2 = 0.999 batch_size = 64 n_epochs = 200 # 定义数据集和数据转换 transform = transform.Compose([ transform.Resize(28), transform.ImageNormalize(mean=0.5, std=0.5), ]) train_dataset = MNIST(train=True, transform=transform) train_loader = jt.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True) # 定义生成器和判别器 generator = Generator(latent_dim=latent_dim, img_shape=img_shape) discriminator = Discriminator(img_shape=img_shape) # 定义损失函数和优化器 adversarial_loss = nn.BCELoss() optimizer_G = nn.Adam(generator.parameters(), lr=lr, betas=(b1, b2)) optimizer_D = nn.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2)) # 训练CGAN模型 for epoch in range(n_epochs): for i, (imgs, _) in enumerate(train_loader): # 训练判别器 optimizer_D.zero_grad() real_imgs = jt.array(imgs) z = jt.array(random.normal((imgs.shape[0], latent_dim))) fake_imgs = generator(z) real_labels = jt.ones((batch_size, 1)) fake_labels = jt.zeros((batch_size, 1)) d_loss_real = adversarial_loss(discriminator(real_imgs), real_labels) d_loss_fake = adversarial_loss(discriminator(fake_imgs), fake_labels) d_loss = (d_loss_real + d_loss_fake) / 2.0 d_loss.backward() optimizer_D.step() # 训练生成器 optimizer_G.zero_grad() z = jt.array(random.normal((imgs.shape[0], latent_dim))) fake_imgs = generator(z) g_loss = adversarial_loss(discriminator(fake_imgs), real_labels) g_loss.backward() optimizer_G.step() # 输出训练信息 batches_done = epoch * len(train_loader) + i if batches_done % 400 == 0: print( f"[Epoch {epoch}/{n_epochs}] [Batch {i}/{len(train_loader)}] [D loss: {d_loss.item()}] [G loss: {g_loss.item()}]" ) ``` 在上述代码中,我们定义了MNIST数据集类和它的数据转换,生成器和判别器模型,以及损失函数和优化器。在训练循环中,我们依次训练判别器和生成器,并输出训练信息。运行以上代码,训练200个epoch后,我们可以得到生成器生成的手写数字图片。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值