GAN的基本总结和小型demo

GAN的基本总结和小型demo

关于GANS(Generative Adversarial Networks)

属于生成模型(generative models)

属于无监督学习(unsupervised learning)

在不给定目的值的情况下,学习所给数据的底层结构。

目前可生成最清晰的图像。

易于训练(不需要统计推断),只需要反向推断就能够获得梯度。

由于训练动态不稳定,难以优化。

基本不能做统计推断。

属于直接隐式密度模型,没有明确定义概率分布函数模型。

Generator和Discriminator

Discriminator

最大化被分类为属于真数据集的真数据输入

最小化被分类为属于真数据集的假数据输入

Generator

最大化被分类为属于真数据集的假数据输入

这意味着用于此网络的损耗/误差函数(loss/error函数)要最大化

经过许多步的训练,Generator和Discriminator都有足够的能力,均不能再进行改进,此时Generator就能生成真实的合成数据,而Discriminator已经无法区分。

训练GAN的基本步骤

1.采样噪声集和真实数据集,每个数据集具有大小m。

2.在这个数据上训练鉴别器。

3.采样具有大小m的不同噪声子集。

4.根据这个数据训练生成器。

5.从步骤1重复。

GAN的小型demo

1.导入相关库

import torch
import torch.nn as nn
from torch.autograd import Variable
import torchvision.transforms as t
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import sampler
import torchvision.datasets as dataset
import numpy as np
# 绘制图像库
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

2.设置plt属性

plt.rcParams['figure.figsize'] = (10.0, 8.0)  # 设置大小
plt.rcParams['image.interpolation'] = 'nearest'  # 设置插值模式
plt.rcParams['image.cmap'] = 'gray'  # 设置颜色

3.图片显示

def show_images(images):
    images = np.reshape(images, [images.shape[0], -1])  # -1代表自动计算
    sqrtn = int(np.ceil(np.sqrt(images.shape[0])))      # np.ceil取整
    sqrtimg = int(np.ceil(np.sqrt(images.shape[1])))
​
    gs = gridspec.GridSpec(sqrtn, sqrtn)
    gs.update(wspace=0.05, hspace=0.05)
​
    for i, img in enumerate(images):
        ax = plt.subplot(gs[i])
        plt.axis('off')  # 去掉坐标轴
        ax.set_xticklabels([])  # 设置x标记为空
        ax.set_yticklabels([])  # 设置y标记为空
        ax.set_aspect('equal')
        plt.imshow(img.reshape([sqrtimg, sqrtimg]))
    return


4.采样函数

# 采样函数为自己定义的序列采样(即按顺序采样)
class Sampler(sampler.Sampler):
    def __init__(self, num_samples, start=0):
        self.num_samples = num_samples
        self.start = start
​
    def __iter__(self):
        return iter(range(self.start, self.start + self.num_samples))
​
    def __len__(self):
        return self.num_samples

5.训练集和测试集的设置

NUM_TRAIN = 60000   # 训练集数量
NUM_VAL = 10000      # 测试集数量
​
NOISE_DIM = 96       # 噪声维度
batch_size = 128     # 批尺寸
​
mnist_train = dataset.MNIST('./cs231n/datasets/MNIST_data', train=True, download=True, transform=t.ToTensor())
loader_train = DataLoader(mnist_train, batch_size=batch_size, sampler=Sampler(NUM_TRAIN, 0))
# 从0位置开始采样NUM_TRAIN个数
​
mnist_val = dataset.MNIST('./cs231n/datasets/MNIST_data', train=True, download=True, transform=t.ToTensor())
loader_val = DataLoader(mnist_val, batch_size=batch_size, sampler=Sampler(NUM_VAL, NUM_TRAIN))
# 从NUM_TRAIN位置开始采样NUM_VAL个数
​
imgs = loader_train.__iter__().next()[0].view(batch_size, 784).numpy().squeeze()
show_images(imgs)  # 显示训练集图片

6.均匀噪声函数

def sample_noise(batch_size, dim):
    """
    - 产生一个从-1 ~ 1的均匀噪声函数,形状为 [batch_size, dim].
    参数:
    - batch_size: 整型 提供生成的batch_size
    - dim: 整型 提供生成维度
    """
    temp = torch.rand(batch_size, dim) + torch.rand(batch_size, dim)*(-1)
​
    return temp

7.平铺函数

# 平铺函数
​
​
class Flatten(nn.Module):
    def forward(self, x):
        n, c, h, w = x.size()  # 读取为n,c,h,w
        return x.view(n, -1)  # 每张图片把c*h*w的值传入单向量用于后期处理

8.判别器

# 判别器  判断generator产生的图像是否为假,同时判断正确的图像是否为真
​
​
def discriminator():
    model = nn.Sequential(
        Flatten(),
        nn.Linear(784, 256),
        nn.LeakyReLU(0.01, inplace=True),
        nn.Linear(256, 256),
        nn.LeakyReLU(0.01, inplace=True),
        nn.Linear(256, 1)
    )
    return model

9.生成器

# 生成器
​
​
def generator(noise_dim=NOISE_DIM):
    model = nn.Sequential(
        nn.Linear(noise_dim, 1024),
        nn.ReLU(inplace=True),
        nn.Linear(1024, 1024),
        nn.ReLU(inplace=True),
        nn.Linear(1024, 784),
        nn.Tanh(),
    )
    return model

10.损失函数

# GAN中指出的最大化最小化损失的算法
​
​
Bce_loss = nn.BCEWithLogitsLoss()
def discriminator_loss(logits_real, logits_fake):
    loss = None
​
    # Batch size.
    n = logits_real.size()
​
    # 目标label,全部设置为1意味着判别器需要做到的是将正确的全识别为正确,错误的全识别为错误
    true_labels = Variable(torch.ones(n))
​
    real_image_loss = Bce_loss(logits_real, true_labels)  # 识别正确的为正确
    fake_image_loss = Bce_loss(logits_fake, 1 - true_labels)  # 识别错误的为错误
​
    loss = real_image_loss + fake_image_loss
​
    return loss
​
​
def generator_loss(logits_fake):
    n = logits_fake.size()
​
    # 生成器的作用是将所有“假”的向真的(1)靠拢
    true_labels = Variable(torch.ones(n))
​
    # 计算生成器损失
    loss = Bce_loss(logits_fake, true_labels)
​
    return loss

11.Adam优化器

def get_optimizer(model):
    """
    为模型构建并返回一个Adam优化器
    learning rate 1e-3,
    beta1=0.5, and beta2=0.999.
    """
    # params(iterable):可用于迭代优化的参数或者定义参数组的dicts。
    # lr (float, optional) :学习率(默认: 1e-3)
    # betas (Tuple[float, float], optional):用于计算梯度的平均和平方的系数(默认: (0.9, 0.999))
​
    optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.5, 0.999))
    return optimizer

12.GAN函数

def run_a_gan(D, G, D_solver, G_solver, discriminator_loss, generator_loss, show_every=250,
              batch_size=128, noise_size=96, num_epochs=10):
    """
    训练GAN
    - D, G: 分别为判别器和生成器
    - D_solver, G_solver: D,G的优化器
    - discriminator_loss, generator_loss: 计算D,G的损失
    - show_every: 设置每show_every次显示样本
    - batch_size: 每次训练在训练集中取batch_size个样本训练
    - noise_size: 输入进生成器的噪声维度
    - num_epochs: 训练迭代次数
    """
    iter_count = 0
    for epoch in range(num_epochs):
        for x, _ in loader_train:
            if len(x) != batch_size:
                continue
​
            D_solver.zero_grad()
            real_data = Variable(x)
            logits_real = D(2 * (real_data - 0.5))
​
            g_fake_seed = Variable(sample_noise(batch_size, noise_size))
            fake_images = G(g_fake_seed).detach()
            logits_fake = D(fake_images.view(batch_size, 1, 28, 28))
​
            d_total_error = discriminator_loss(logits_real, logits_fake)
            d_total_error.backward()
            D_solver.step()
​
            G_solver.zero_grad()
            g_fake_seed = Variable(sample_noise(batch_size, noise_size))
            fake_images = G(g_fake_seed)
​
            gen_logits_fake = D(fake_images.view(batch_size, 1, 28, 28))
            g_error = generator_loss(gen_logits_fake)
            g_error.backward()
            G_solver.step()
​
            print(iter_count)
​
            if iter_count % show_every == 0:
                print('Iter: {}, D: {:.4}, G:{:.4}'.format(iter_count, d_total_error, g_error))
                imgs_numpy = fake_images.data.cpu().numpy()
                show_images(imgs_numpy[0:16])
                plt.show()
                print()
            iter_count += 1
​
    print("Completed!")
    imgs_numpy = fake_images.data.cpu().numpy()
    show_images(imgs_numpy[0:16])
    plt.show()
    print()

13.

# 创建判别器
D = discriminator()
​
# 创建生成器
G = generator()
​
# 创建D,G的优化器
D_solver = get_optimizer(D)
G_solver = get_optimizer(G)
​
# 运行GAN
run_a_gan(D, G, D_solver, G_solver, discriminator_loss, generator_loss)

最终结果为:

 

效果很差,后续我还需要对这个初步的GAN进行完善。

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值