DCGAN代码练习

基于DCGAN生成CIFAR10代码练习
DCGAN:
将有监督学习的CNN和无监督学习的GAN整合,得到了DCGAN深度卷积生成式对抗网络。

  1. 限制CNN的网络拓扑,得到稳定训练。
  2. 利用无标记数据初始化DCGAN生成器和判别器参数。
  3. 定性分析GAN的filter,GAN的可视化工作流程。
  4. 生成的特征表示的向量计算特性。
    DCGAN的改造:
  5. 去掉了G,D网络中的pooling layer
  6. 在G,D网络中都使用Batch Bormalization
  7. 去掉全连接的隐藏层
  8. G的最后使用ReLU,D的最后使用Tanh
  9. D的每一层都使用leakyRELU。
    DCGAN生成器网络结构

代码链接: Hyfred/gan-torch​github.com
首先读取数据集。

#代码地址:https://github.com/Hyfred/gan-torch/blob/master/dcgan_torch.ipynb
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.utils import save_image

加载CIFA10图片数据集,读取图片批次,张量转换,保存图片。

dataset = CIFAR10(root = './data',
                 download = True, transform = transforms.ToTensor())
dataloader = DataLoader(dataset, batch_size = 64, shuffle = True)

for batch_idx, data in enumerate(dataloader):
    real_images, _ = data
    batch_size = real_images.shape[0]
    print('#{} has {} images.'.format(batch_idx, batch_size))
    if batch_idx % 100 == 0:
        path = './data/CIFAR10_shuffled_batch{:03d}.png'.format(batch_idx)
        save_image(real_images, path, nrow = 8, normalize=False)

下载数据集,将数据集转换成Tensor格式。
然后通过pytorch内置函数DataLoader按批次读取数据集,每批次64张,一共有50000张图,所以有782批次,最后一批次只有16张图。
读取完后打乱数据。

由于数据集本身的结构,所以调用数据集的时候需要使用for循环,enumerate(dataloader)这个命令是将数据集(批次)的索引和数据集(批次)输出。
也就是一批一批进行循环,除了最后一批是16张图,其他每批64张图。

数据集命名为两部分,我们只需要第一部分real_images。它是一个四维向量(批次数,通道数,长宽) = (64 * 3 * 32 * 32), 最后一批为(16 * 3 * 32 * 32)

batch_idx是有几批,batch_size是该批次中有几张图。
最后一个for循环,通过if语句,每100批保存一次原图。所以782批保存(0, 100,… , 700)共8张每100批次保存一张图。
每一张图有batch_size = 64小张,设置显示行数为nrow = 8,即生成8 * 8 = 64 的小图来拼成大图。normalize归一化确认是否调整张量的范围到[0, 1]。

在扩大图像的时候有两种思路,一种是插值(上采样),另一种是反卷积。在搭建生成网络G时,使用了反卷积(内置卷积)“nn.ConvTranspose2d”

转置卷积参考:https://link.zhihu.com/?target=https%3A//blog.csdn.net/qq_37879432/article/details/80297263

n_d_feature = 64
n_channel = 3
dnet = nn.Sequential(
        nn.Conv2d(n_channel, n_d_feature, kernel_size = 4,
                  stride = 2, padding = 1),
        nn.LeakyReLU(0.2),
        nn.Conv2d(n_d_feature, 2 * n_d_feature, kernel_size = 4,
                 stride = 2, padding = 1, bias = False),
        nn.BatchNorm2d(2 * n_d_feature),
        nn.LeakyReLU(0.2),
        nn.Conv2d(2 * n_d_feature, 4 * n_d_feature, kernel_size = 4,
                 stride = 2, padding = 1, bias = False),
        nn.BatchNorm2d(4 * n_d_feature),
        nn.LeakyReLU(0.2),
        nn.Conv2d(4 * n_d_feature, 1, kernel_size=4)
)
print(dnet)

这一步先构建discriminator,潜在大小设置为64 = batch_size?,通道数设置为RGB3通道。
然后采用Sequential顺序模型,定义鉴别器的网络结构。
可以看到包含4个卷积层,3个线性ReLU,2个批次Batch_Norm。

接下来构建生成器Generator

latent_size = 64
n_g_feature = 64

gnet = nn.Sequential(
        nn.ConvTranspose2d(latent_size, 4 * n_g_feature, kernel_size = 4,
                           bias = False),
        nn.BatchNorm2d(4 * n_g_feature),
        nn.ReLU(),
        nn.ConvTranspose2d(4 * n_g_feature, 2 * n_g_feature, kernel_size = 4,
                          stride = 2, padding = 1, bias = False),
        nn.BatchNorm2d(2 * n_g_feature),
        nn.ReLU(),
        nn.ConvTranspose2d(2 * n_g_feature, n_g_feature, kernel_size = 4,
                          stride = 2, padding =1),
        nn.BatchNorm2d(n_g_feature),
        nn.ReLU(),
        nn.ConvTranspose2d(n_g_feature, n_channel, kernel_size=4,
                          stride=2,padding=1),
        nn.Sigmoid()
)
print(gnet)

可以看到生成器由4个卷积层,3个Batch_Norm,3个线性ReLU组成,最后通过Sigmoid将输出限定在(0,1)之间。

import torch.nn.init as init
def weights_init(m):
    if type(m) in [nn.ConvTranspose2d, nn.Conv2d]:
        init.xavier_normal_(m.weight)
    elif type(m) == nn.BatchNorm2d:
        init.normal_(m.weight, 1.0, 0.02)
        init.constant_(m.bias, 0)

gnet.apply(weights_init)
dnet.apply(weights_init)

初始化网络参数。接下来开始训练网络。

import torch
import torch.optim

#loss
criterion = nn.BCEWithLogitsLoss()
#opt
goptimizer = torch.optim.Adam(gnet.parameters(),
                             lr = 0.0002, betas = (0.5, 0.999))
doptimizer = torch.optim.Adam(dnet.parameters(),
                             lr = 0.0002, betas = (0.5, 0.999))
batch_size = 64
fixed_noise = torch.randn(batch_size, latent_size, 1, 1)

定义损失,使用BCELogits损失, 然后优化生成器和判别器。
接下来生成噪声信号,输入generator。

#start training
epoch_num = 5
for epoch in range(epoch_num):
    for batch_idx, data in enumerate(dataloader):
        real_images, _ = data
        batch_size = real_images.shape[0]
        
        #train discriminator
        labels = torch.ones(batch_size)#real data label:1
        preds = dnet(real_images)#let real data in discriminator
        outputs = preds.reshape(-1) #-1.means output row in auto
        dloss_real = criterion(outputs, labels)
        dmean_real = outputs.sigmoid().mean()#cacl the real possibility in (0,1)
        
        noise = torch.randn(batch_size, latent_size, 1, 1)
        fake_images = gnet(noise) #prodece fake data
        labels = torch.zeros(batch_size)  #fake_data label:0
        fake = fake_images.detach()#similiar with fixed parameters of generators
        preds = dnet(fake)#let fake data in discriminator
        outputs = preds.reshape(-1) #transform to auto row
        dloss_fake = criterion(outputs, labels)
        dmean_fake = outputs.sigmoid().mean()#cacl the possbility of fake data
        
        dloss = dloss_real + dloss_fake 
        dnet.zero_grad()
        dloss.backward()
        doptimizer.step()
        
        #train generator
        labels = torch.ones(batch_size) #label 1 of generator
        preds = dnet(fake_images)
        outputs = preds.reshape(-1)
        gloss = criterion(outputs, labels)
        gmean_fake = outputs.sigmoid().mean()
        
        gnet.zero_grad()
        gloss.backward()
        goptimizer.step()
        
        print('[{}/{}]'.format(epoch,epoch_num) + '[{}/{}]'.format(batch_idx,len(dataloader)) + 
             '鉴别器G损失:{:g} 生成器D损失:{:g}'.format(dloss,gloss) +
             '真数据判真比例:{:g}假数据判真比例:{:g}/{:g}'.format(dmean_real, dmean_fake, gmean_fake))
        if batch_idx % 100 == 0:       
            fake = gnet(fixed_noise)
            path = './data/images_epoch{:02d}_batch{:03d}.png'.format(epoch,batch_idx)
            save_image(fake, path, normalize = False)

训练代码。

torch.save(gnet.state_dict(),'./generator.pth')
torch.save(dnet.state_dict(),'./discriminator.pth')

保存路径。

GPU源码:https://github.com/Hyfred/gan-torch/blob/master/dcgan_torch_gpu.ipynb

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值