GAN基础知识及代码

GAN也叫做生成对抗网络,分为两部分,一个是生成网络G,一个是对抗网络D。生成网络和对抗网络进行竞争,生成模型可以被认为是造假者,他们试图制造假币并在不被发现的情况下使用它,而鉴别模型类似于警察,视图发现假币。在这个游戏中,竞争促使两个团队改进他们的方法,直到冒充的产品和正品无法区分。

生成模型和判别模型都是多层感知器。

噪声就是随机生成的数,通过生成器随机生成一张图。(所以生成器只能随机生成图像,不能指定一些条件)

判别器的作用是尽可能的把真实数据集和生成数据集区分开,对于真实数据希望输出1,对于生成数据希望输出0。

相反,生成器希望判别器读入生成数据,输出1。

损失函数:

简化一点就是 D(x) +( 1-D(G(z)) ) , log是单调递增函数,此处的作用是放大损失。

对于生成器G,希望这个函数尽可能小,即D(x)接近0,1-D(G(z))接近0,即D(G(z))接近1.  事实上生成器不管D(x)是否是0,只要确保D(G(z))接近1,即生成的图像判别器判成了真实图像。

标准代码 - (pytorch)

在MNIST手写数字数据集上训练。

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torchvision import transforms

1. 数据准备

对真实数据做归一化(-1,1),gan要求的,因为生成器生成的数据是(-1,1),保持两个数据分布一样

transform = transforms.Compose([
    transforms.ToTensor(),     # 归一化为0~1
    transforms.Normalize(0.5,0.5) # 归一化为-1~1
])
train_ds = torchvision.datasets.MNIST('datasets',  # 下载到那个目录下
                                      train=True,
                                      transform=transform,
                                      download=True)
dataloader = torch.utils.data.DataLoader(train_ds, batch_size=64,shuffle=True)
imgs,_ = next(iter(dataloader))
imgs.shape
# torch.Size([64, 1, 28, 28])

2. 定义生成器

输入是长度100的噪声z(正态分布随机数)
输出为(1,28,28)的图片,和MNIST数据集保持一致
Linear1: 100->256
Linear2: 256->512
Linear3: 512->2828
reshape: 28
28->(1,28,28)

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__() # 继承父类
        self.main = nn.Sequential(
            nn.Linear(100,256), nn.ReLU(),
            nn.Linear(256,512), nn.ReLU(),
            nn.Linear(512,28*28), 
            nn.Tanh()  # 最后必须用tanh,把数据分布到(-1,1)之间
        )
    def forward(self, x):  # x表示长度为100的噪声输入
        img = self.main(x)
        img = img.view(-1,28,28,1) # 方便等会绘图
        return img

3. 定义判别器

输入为(1,28,28)的mnist图片
输出为二分类的概率,使用sigmoid激活,范围为0~1
BCEloss计算交叉熵损失
判别器推荐使用LeakReLU激活

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator,self).__init__()
        self.main = nn.Sequential(
            nn.Linear(28*28,512),
            nn.LeakyReLU(), # x小于零是是一个很小的值不是0,x大于0是还是x
            nn.Linear(512,256),
            nn.LeakyReLU(),
            nn.Linear(256,1),
            nn.Sigmoid() # 保证输出范围为(0,1)的概率
        )
    def forward(self, x): # x表示28*28的mnist图片
        img = x.view(-1,28*28)
        img = self.main(img)
        return img

4. 初始化模型、优化器、损失函数

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('training on ',device)
# 模型
gen = Generator().to(device)
dis = Discriminator().to(device)
# 优化器
g_opt = torch.optim.Adam(gen.parameters(),lr=0.0001)
d_opt = torch.optim.Adam(dis.parameters(),lr=0.0001)
# 损失
loss = torch.nn.BCELoss()

5. 绘图函数

随时看到生成器生成的图像

def gen_img_plot(model, test_input):
    prediction = np.squeeze(model(test_input).detach().cpu().numpy())
    fig = plt.figure(figsize=(4,4))
    for i in range(16):
        plt.subplot(4,4,i+1) # 四行四列的第一个
        # imshow函数绘图的输入是(0,1)的float,或者(1,256)的int
        # 但prediction是tanh出来的范围是[-1,1]没法绘图,需要转成0~1(即加1除2)。
        plt.imshow( (prediction[i]+1)/2 )
        plt.axis('off')
    plt.show()
test_input = torch.randn(16, 100, device=device)

6. GAN训练

D_loss = []
G_loss = []
epochs = 40
for epoch in range(epochs):
    d_epoch_loss = 0
    g_epoch_loss = 0
    count = len(dataloader) # 一个epoch的大小
    for step, (img, _) in enumerate(dataloader):
        img = img.to(device) # 一个批次的图片
        size = img.size(0)   # 和和图片对应的原始噪音
        random_noise = torch.randn(size, 100, device=device)
        gen_img = gen(random_noise) # 生成的图像

        d_opt.zero_grad()
        real_output = dis(img)  # 判别器输入真实图片,对真实图片的预测结果,希望是1
        # 判别器在真实图像上的损失
        d_real_loss = loss(real_output, torch.ones_like(real_output)) # size一样全一的tensor
        d_real_loss.backward()

        g_opt.zero_grad()
        # 记得切断生成器的梯度
        fake_output = dis(gen_img.detach()) # 判别器输入生成图片,对生成图片的预测结果,希望是0
        # 判别器在生成图像上的损失
        d_fake_loss = loss(fake_output, torch.zeros_like(fake_output)) # size一样全一的tensor
        d_fake_loss.backward()
        
        d_loss = d_real_loss + d_fake_loss
        d_opt.step()
        
        
        # 生成器的损失
        g_opt.zero_grad()
        fake_output = dis(gen_img) # 希望被判定为1
        g_loss = loss(fake_output, torch.ones_like(fake_output))
        g_loss.backward()
        g_opt.step()
        
        # 每个epoch内的loss累加,循环外再除epoch大小,得到平均loss
        with torch.no_grad():
            d_epoch_loss += d_loss
            g_epoch_loss += g_loss
    # 一个epoch训练完成
    with torch.no_grad():
        d_epoch_loss /= count
        g_epoch_loss /= count
        D_loss.append(d_epoch_loss)
        G_loss.append(g_epoch_loss)
        print('Epoch: ',epoch)
        gen_img_plot(gen, test_input)

结果:

 

 可以看到效果还算可以,这是2014年提出的最基础的GAN,后续要有若干改进的工作,效果更好,有机会再学。

  • 7
    点赞
  • 81
    收藏
    觉得还不错? 一键收藏
  • 7
    评论
【资源说明】 课程设计-基于GAN的甲骨文自动摹写python实现源码+项目说明.zip 模型 是一个 41.823 M 参数的 U-Net。判别器是一个 2.765 M 参数的 CNN。 数据 数据处理部分不知道怎么跑起来了,但是已经处理的数据在我本地的电脑,是叫做 `rt7381` 的文件夹。 用 --dataroot 来指定,该文件夹的格式应该如下: - dataroot - dev - test - train - 白_xxxxx_r.png - 白_xxxxx_t.png - 百_xxxxx_r.png - 百_xxxxx_t.png - ... 摹写 > 这是基于某个 pix2pix 的项目。 # 训练 直接跑 ```bash ./scripts/train_oracle.sh ``` # 生成 用 `transcribe.py`,直接在 `get_opt()` 函数里面手动设置参数。 主要需要关注的参数是 --checkpoint_dir 和 --name,用来指定训练好的 checkpoint,然后在 `main` 函数里面设置 `src_dir` 和 `dst_dir` 来设置输入和输出的图片。 【备注】 1、该资源内项目代码都经过测试运行成功,功能ok的情况下才上传的,请放心下载使用! 2、本项目适合计算机相关专业(如计科、人工智能、通信工程、自动化、电子信息等)的在校学生、老师或者企业员工下载使用,也适合小白学习进阶,当然也可作为毕设项目、课程设计、作业、项目初期立项演示等。 3、如果基础还行,也可在此代码基础上进行修改,以实现其他功能,也可直接用于毕设、课设、作业等。 欢迎下载,沟通交流,互相学习,共同进步!
### 回答1: cvae-gan-zoos-pytorch-beginner这个词汇代表一个初学者使用PyTorch框架进行CVAE-GAN(生成式对抗网络变分自编码器)的编码器,这个网络可以在数据集中进行分析学习,并将数据转换为可以生成新数据的潜在向量空间。该网络不需要通过监督学习标签分类,而是直接使用数据的分布。这个编码器的目的是从潜在空间中生成新数据。此模型可以用于不同的任务,例如图像生成和语音生成。 为了实现这一目标,这一模型采用了CVAE-GAN网络结构,其中CVAE(条件变分自编码器)被用来建立机器学习模型的潜在空间,GAN(生成式对抗网络)作为一个反馈网络,以实现生成数据的目的。最后,这个模型需要使用PyTorch框架进行编程实现,并对数据集进行分析和处理,以便输入到模型中进行训练。这个编码器是一个比较复杂的模型,因此,初学者需要掌握深度学习知识和PyTorch框架的相关知识,并有一定的编程经验,才能实现这一任务。 总的来说,CVAE-GAN是一个在生成数据方面取得了重大成就的深度学习模型,可以应用于各种领域,例如图像、语音和自然语言处理等。然而,对于初学者来说,这是一个相对复杂的任务,需要掌握相关知识和技能,才能成功实现这一模型。 ### 回答2: cvae-gan-zoos-pytorch-beginner是一些机器学习领域的技术工具,使用深度学习方法来实现动物园场景的生成。这些技术包括:生成式对抗网络(GAN)、变分自编码器(CVAE)和pytorch。GAN是一种基于对抗机制的深度学习网络,它可以训练出生成逼真的场景图像;CVAE也是一种深度学习网络,它可以从潜在空间中提取出高质量的场景特征,并生成与原图像相似的图像;pytorch是一个深度学习框架,它可以支持这些技术的开发和实现。 在这个动物园场景生成的过程中,通过GAN和CVAE的组合使用可以从多个角度来创建逼真而多样化的动物园场景。此外,pytorch提供了很多工具和函数来简化代码编写和管理数据,使得训练过程更加容易和高效。对于初学者们来说,这些技术和框架提供了一个良好的起点,可以探索深度学习和图像处理领域的基础理论和实践方法,有助于了解如何使用技术来生成更好的图像结果。 ### 回答3: CVaE-GAN-ZOOS-PyTorch-Beginner是一种结合了条件变分自编码器(CVaE)、生成对抗网络GAN)和零样本学习(Zero-Shot Learning)的深度学习框架。它使用PyTorch深度学习库,适合初学者学习和使用。 CVaE-GAN-ZOOS-PyTorch-Beginner的主要目的是提供一个通用的模型结构,以实现Zero-Shot Learning任务。在这种任务中,模型要从未见过的类别中推断标签。CVaE-GAN-ZOOS-PyTorch-Beginner框架旨在使模型能够从已知类别中学习无监督的表示,并从中推断未知类别的标签。 CVaE-GAN-ZOOS-PyTorch-Beginner的结构由两个关键部分组成:生成器和判别器。生成器使用条件变分自编码器生成潜在特征,并进一步生成样本。判别器使用生成的样本和真实样本区分它们是否相似。这样,生成器被迫学习产生真实的样本,而判别器则被迫学习区分真实的样本和虚假的样本。 总的来说,CVaE-GAN-ZOOS-PyTorch-Beginner框架是一个强大的工具,可以用于解决Zero-Shot Learning问题。它是一个易于使用的框架,适合初学者学习和使用。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值