Generating CIFAR10 Images with a GAN

Adapted from Xiao Zhiqing (肖智清)'s book《神经网络与PyTorch实战》(Neural Networks and PyTorch in Action).

import os

import torch
import torch.nn as nn
import torch.nn.init as init
import torch.optim
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms
from torchvision.utils import save_image


latent_size = 64
n_channel = 3
n_g_feature = 64
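# DCGAN-style generator: maps a (latent_size, 1, 1) noise vector to a 3x32x32 image,
# upsampling 1 -> 4 -> 8 -> 16 -> 32 with transposed convolutions; Sigmoid keeps outputs in [0, 1].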
gnet = nn.Sequential(
    nn.ConvTranspose2d(latent_size, 4 * n_g_feature, kernel_size=4, bias=False),
    nn.BatchNorm2d(4 * n_g_feature),
    nn.ReLU(),

    nn.ConvTranspose2d(4 * n_g_feature, 2 * n_g_feature, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(2 * n_g_feature),
    nn.ReLU(),

    nn.ConvTranspose2d(2 * n_g_feature, n_g_feature, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(n_g_feature),
    nn.ReLU(),

    nn.ConvTranspose2d(n_g_feature, n_channel, kernel_size=4, stride=2, padding=1),
    nn.Sigmoid()
)


n_d_feature = 64
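# DCGAN-style discriminator: maps a 3x32x32 image to a single logit
# (spatial size 32 -> 16 -> 8 -> 4 -> 1); BCEWithLogitsLoss below applies the sigmoid.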
dnet = nn.Sequential(
    nn.Conv2d(n_channel, n_d_feature, kernel_size=4, stride=2, padding=1),
    nn.LeakyReLU(0.2),

    nn.Conv2d(n_d_feature, 2 * n_d_feature, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(2 * n_d_feature),
    nn.LeakyReLU(0.2),

    nn.Conv2d(2 * n_d_feature, 4 * n_d_feature, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(4 * n_d_feature),
    nn.LeakyReLU(0.2),

    nn.Conv2d(4 * n_d_feature, 1, kernel_size=4)
)


def weights_init(m):
    # Xavier-normal initialization for (transposed) convolution weights;
    # BatchNorm scales drawn from N(1.0, 0.02), biases set to zero.
    if type(m) in [nn.ConvTranspose2d, nn.Conv2d]:
        init.xavier_normal_(m.weight)
    elif type(m) == nn.BatchNorm2d:
        init.normal_(m.weight, 1.0, 0.02)
        init.constant_(m.bias, 0)

gnet.apply(weights_init)
dnet.apply(weights_init)


dataset = CIFAR10(root='./CIFARdata', download=True, transform=transforms.ToTensor())  # pixels in [0, 1], matching the generator's Sigmoid output
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)


criterion = nn.BCEWithLogitsLoss()  # binary cross-entropy on raw logits (dnet has no final sigmoid)
goptimizer = torch.optim.Adam(gnet.parameters(), lr=0.0002, betas=(0.5, 0.999))
doptimizer = torch.optim.Adam(dnet.parameters(), lr=0.0002, betas=(0.5, 0.999))


batch_size = 64
fixed_noises = torch.randn(batch_size, latent_size, 1, 1)  # fixed noise batch, reused below to visualize generator progress
os.makedirs('./GAN_saved02', exist_ok=True)  # directory for sample grids saved during training

epoch_num = 15
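# Per minibatch: update the discriminator on a real batch (label 1) and a detached fake batch (label 0),
# then update the generator so the discriminator scores its (non-detached) fakes as real.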
for epoch in range(epoch_num):
    for batch_idx, data in enumerate(dataloader):
        real_images, _ = data
        batch_size = real_images.size(0)

        # Discriminator on a real batch: target label 1
        labels = torch.ones(batch_size)
        preds = dnet(real_images)
        outputs = preds.reshape(-1)
        dloss_real = criterion(outputs, labels)
        dmean_real = outputs.sigmoid().mean()  # average D(real), ideally close to 1

        # Discriminator on a generated batch: target label 0;
        # detach() keeps gradients from flowing back into the generator here
        noises = torch.randn(batch_size, latent_size, 1, 1)
        fake_images = gnet(noises)
        labels = torch.zeros(batch_size)
        fake = fake_images.detach()

        preds = dnet(fake)
        outputs = preds.view(-1)
        dloss_fake = criterion(outputs, labels)
        dmean_fake = outputs.sigmoid().mean()  # average D(fake) before the discriminator update

        # Update the discriminator on the combined real + fake loss
        dloss = dloss_real + dloss_fake
        dnet.zero_grad()
        dloss.backward()
        doptimizer.step()


        # Generator update: push the discriminator to score the (non-detached) fakes as real
        labels = torch.ones(batch_size)
        preds = dnet(fake_images)
        outputs = preds.view(-1)
        gloss = criterion(outputs, labels)
        gmean_fake = outputs.sigmoid().mean()  # average D(fake) after the discriminator update
        gnet.zero_grad()
        gloss.backward()
        goptimizer.step()

        if batch_idx % 100 == 0:
            with torch.no_grad():
                fake = gnet(fixed_noises)  # sample from the fixed noise to monitor progress
            save_image(fake, f'./GAN_saved02/images_epoch{epoch:02d}_batch{batch_idx:03d}.png')

            print(f'Epoch {epoch} of {epoch_num}, batch {batch_idx} (batch size {batch_size}).')
            print(f'Discriminator loss: {dloss:g}, generator loss: {gloss:g}')
            print(f'D(real): {dmean_real:g}, D(fake) before/after the D update: {dmean_fake:g}/{gmean_fake:g}')


gnet_save_path = 'gnet.pt'
torch.save(gnet, gnet_save_path)
# gnet = torch.load(gnet_save_path)
# gnet.eval()

dnet_save_path = 'dnet.pt'
torch.save(dnet, dnet_save_path)
# dnet = torch.load(dnet_save_path)
# dnet.eval()
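# Aside: saving only the parameters via state_dict() tends to be more robust across PyTorch
# versions than pickling the whole nn.Sequential (filenames here are illustrative), e.g.
# torch.save(gnet.state_dict(), 'gnet_state.pt') and later, after rebuilding the same
# nn.Sequential: gnet.load_state_dict(torch.load('gnet_state.pt')); gnet.eval()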

os.makedirs('./test_GAN', exist_ok=True)
with torch.no_grad():
    for i in range(100):
        noises = torch.randn(batch_size, latent_size, 1, 1)
        fake_images = gnet(noises)
        save_image(fake_images, f'./test_GAN/{i}.png')  # save a grid of the freshly generated batch

# print(gnet, dnet)

The generated images are still quite blurry.

(Sample grids of generated images were shown here.)

After saving the network parameters, you can write another small loop later that reloads the generator and produces more images, as sketched below.
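A minimal sketch of such a reload-and-sample loop, assuming the gnet.pt file saved above (the output filename is illustrative):

import torch
from torchvision.utils import save_image

latent_size = 64                                  # must match the latent size used during training
gnet = torch.load('gnet.pt')                      # the generator pickled above; recent PyTorch may require weights_only=False
gnet.eval()

with torch.no_grad():
    noises = torch.randn(64, latent_size, 1, 1)   # a batch of 64 random latent vectors
    fake_images = gnet(noises)
save_image(fake_images, 'reloaded_samples.png')   # illustrative output path

Separately, the snippet below just iterates the CIFAR10 dataloader and, every 100th batch, saves a grid of real images, which is handy for comparing against the generated samples: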

from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms
from torchvision.utils import save_image
import os
import torch
import torch.optim


dataset = CIFAR10(root='./CIFARdata', download=True, transform=transforms.ToTensor())
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

os.makedirs(r'I:\save_CIFAR', exist_ok=True)
for batch_idx, data in enumerate(dataloader):
    real_images, _ = data
    batch_size = real_images.size(0)
    print(f'#{batch_idx} has {batch_size} images.')
    if batch_idx % 100 == 0:
        path_ = os.path.join(r'I:\save_CIFAR', f'{batch_idx}.png')
        save_image(real_images, path_, normalize=True)