【GANs入门】pytorch-GANs任务迁移-单个目标(数字的生成)

简述

之前认真学习了网上的一份,代码做了很详细的笔记。
【Gans入门】Pytorch实现Gans代码详解【70+代码】

但是上面的任务只是画一条在一定区间下的曲线。
这里对这个进行迁移,到可以进行图像的生成。

图像的很多数据都没有,但是突然想到在sklearn上的digits是一个非常简单的图片。
这里我想到之前的一份笔记
sklearn学习(一)

这里会使用sklearn自带的小数据来做训练
目标是让神经网络自己学会生成数字。

任务描述

为了让神经网络操作更简单。这里的输入数据只会选择特定数值的数字图片数据。然后丢给对抗生成神经网络学习。让其中的生成器学会如何生成手写数字。

下面是选择用数值1的生成过程

其实可以发现其实是有点这样的感觉了。

在这里插入图片描述

下面的这个是让它学习数字0的效果

可能是由于数字0的细节更粗糙一点,所以,可以发现,我们认为这个0生成的更好。(数字1和数字4其实是有点像的,所以会有点问题,还有这是因为图片像素有点低

在这里插入图片描述

代码详解

导入包

  • torch,numpy这些都是数据处理过程中需要的包
  • matplotlib为了画图
  • sklearn主要是为了它本身带的数据
  • random主要是为了选择标准数据更具有随机性
  • os,shutil,imageio这三个库是为了画出gif动态图
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
import random
import os
import shutil
import imageio

创建临时文件夹

PNGFILE = './png/'
if not os.path.exists(PNGFILE):
    os.mkdir(PNGFILE)
else:
    shutil.rmtree(PNGFILE)
    os.mkdir(PNGFILE)

这里会创建一个临时的文件夹png,会把中途生成的那些图片都存在这,然后我就可以用这些png来生成gif文件

模型参数

  • BATCH_SIZE这个参数表示每次用多少的数据来进行考量。(数值多的话模型进化的会稍微快点)
  • LR_GLR_D表示两个模型的学习率
  • N_IDEAS:启发式因子(生成函数的初始层的节点数)。因为我们要操作的节点数量会特别大(特别是图像问题,但是如果输入节点过于大的话,会需要大量的计算资源。所以用小一点的这个基本够用就行了)
  • target_num :表示的是想要生成的数字。由于数据集中只有(0到9)所以,这里也只能取0到9。
  • image_max表示图片像素点的最大值,这个一开始我用到了,但是后来我修改了代码之后,就用不到了。
  • ART_COMPONENTS:像素点数量(其实本质上跟前一个版本的参考节点数都是一样的)
# Hyper Parameters
BATCH_SIZE = 64
LR_G = 0.00001  # learning rate for generator
LR_D = 0.00001  # learning rate for discriminator
N_IDEAS = 6  # think of this as number of ideas for generating an art work (Generator)
target_num = 0 # target Number

digits = datasets.load_digits()
target = digits.target
data = digits.data[target == target_num]
image_max = max(data.reshape((-1,)))
ART_COMPONENTS = data.shape[-1]  # it could be total point G can draw in the canvas

标准数据

这个函数本质上,这个区间上选BATCH_SIZE个标准数据。
但是,random.sample只能输入的是list所以需要先把data转成list,但是转出来的list又不能直接变成torch中的Tensor,这里需要再转成ndarray,之后再转成Tensor,但是要注意在后面加一个.float()函数的操作。

def artist_works():  # painting from the famous artist (real target)
    return torch.from_numpy(np.array(random.sample(list(data), BATCH_SIZE))).float()

构建模型

生成器模型,但是Linear转成的数据是有可能有负数的数据的,但是作为图片肯定是不可以有这样的数据的。因为数据一定是需要为大于等于0的数据。

所以搭建的这个模型最后一定要加一个ReLU()这样的类似的,来保证没有0的情况。

G = nn.Sequential(  # Generator
    nn.Linear(N_IDEAS, 128),  # random ideas (could from normal distribution)
    nn.ReLU(),
    nn.Linear(128, ART_COMPONENTS),  # making a painting from these random ideas
    nn.ReLU(),
)

D = nn.Sequential(  # Discriminator
    nn.Linear(ART_COMPONENTS, 128),  # receive art work either from the famous artist or a newbie like G
    nn.ReLU(),
    nn.Linear(128, 1),
    nn.Sigmoid(),  # tell the probability that the art work is made by artist
)

构建最优化的模型

opt_D = torch.optim.Adam(D.parameters(), lr=LR_D)
opt_G = torch.optim.Adam(G.parameters(), lr=LR_G)

迭代优化

这跟之前的是类似的。

for step in range(10000):
    artist_paintings = artist_works()  # real painting from artist
    G_ideas = torch.randn(BATCH_SIZE, N_IDEAS)  # random ideas

    G_paintings = G(G_ideas)  # fake painting from G (random ideas)

    prob_artist0 = D(artist_paintings)  # D try to increase this prob
    prob_artist1 = D(G_paintings)  # D try to reduce this prob

    D_loss = - torch.mean(torch.log(prob_artist0) + torch.log(1. - prob_artist1))
    G_loss = torch.mean(torch.log(1. - prob_artist1))

    opt_D.zero_grad()
    D_loss.backward(retain_graph=True)  # reusing computational graph
    opt_D.step()

    opt_G.zero_grad()
    G_loss.backward(retain_graph=True)
    opt_G.step()

画图并保存

    if step % 100 == 0:  # plotting
        plt.cla()
        tempdata = G_paintings[0].detach().numpy()

        tempdata = tempdata.reshape((8, 8))
        plt.imshow(tempdata, cmap=plt.cm.gray_r)
        # plt.draw()
        plt.savefig(PNGFILE + '%d.png' % times)
        filedatalist.append(PNGFILE + '%d.png' % times)
        times += 1
        plt.pause(0.01)

生成gif

generated_images = []
for png_path in filedatalist:
    generated_images.append(imageio.imread(png_path))
shutil.rmtree(PNGFILE)
imageio.mimsave('gan.gif', generated_images, 'GIF', duration=0.1)

全部代码

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
import random
import os
import shutil
import imageio

PNGFILE = './png/'
if not os.path.exists(PNGFILE):
    os.mkdir(PNGFILE)
else:
    shutil.rmtree(PNGFILE)
    os.mkdir(PNGFILE)

# Hyper Parameters
BATCH_SIZE = 64
LR_G = 0.00001  # learning rate for generator
LR_D = 0.00001  # learning rate for discriminator
N_IDEAS = 6  # think of this as number of ideas for generating an art work (Generator)
target_num = 0 # target Number

digits = datasets.load_digits()
target = digits.target
data = digits.data[target == target_num]
image_max = max(data.reshape((-1,)))
ART_COMPONENTS = data.shape[-1]  # it could be total point G can draw in the canvas


def artist_works():  # painting from the famous artist (real target)
    return torch.from_numpy(np.array(random.sample(list(data), BATCH_SIZE))).float()


G = nn.Sequential(  # Generator
    nn.Linear(N_IDEAS, 128),  # random ideas (could from normal distribution)
    nn.ReLU(),
    nn.Linear(128, ART_COMPONENTS),  # making a painting from these random ideas
    nn.ReLU(),
)

D = nn.Sequential(  # Discriminator
    nn.Linear(ART_COMPONENTS, 128),  # receive art work either from the famous artist or a newbie like G
    nn.ReLU(),
    nn.Linear(128, 1),
    nn.Sigmoid(),  # tell the probability that the art work is made by artist
)

opt_D = torch.optim.Adam(D.parameters(), lr=LR_D)
opt_G = torch.optim.Adam(G.parameters(), lr=LR_G)
times = 0

filedatalist = []

for step in range(10000):
    artist_paintings = artist_works()  # real painting from artist
    G_ideas = torch.randn(BATCH_SIZE, N_IDEAS)  # random ideas

    G_paintings = G(G_ideas)  # fake painting from G (random ideas)

    prob_artist0 = D(artist_paintings)  # D try to increase this prob
    prob_artist1 = D(G_paintings)  # D try to reduce this prob

    D_loss = - torch.mean(torch.log(prob_artist0) + torch.log(1. - prob_artist1))
    G_loss = torch.mean(torch.log(1. - prob_artist1))

    opt_D.zero_grad()
    D_loss.backward(retain_graph=True)  # reusing computational graph
    opt_D.step()

    opt_G.zero_grad()
    G_loss.backward(retain_graph=True)
    opt_G.step()

    if step % 100 == 0:  # plotting
        plt.cla()
        tempdata = G_paintings[0].detach().numpy()

        tempdata = tempdata.reshape((8, 8))
        plt.imshow(tempdata, cmap=plt.cm.gray_r)
        # plt.draw()
        plt.savefig(PNGFILE + '%d.png' % times)
        filedatalist.append(PNGFILE + '%d.png' % times)
        times += 1
        plt.pause(0.01)

generated_images = []
for png_path in filedatalist:
    generated_images.append(imageio.imread(png_path))
shutil.rmtree(PNGFILE)
imageio.mimsave('gan.gif', generated_images, 'GIF', duration=0.1)

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
条件GAN(Conditional GANs)是一种生成对抗网络(GAN)的变体,它通过给生成器和鉴别器提供额外的条件信息来改进生成过程。在PyTorch中实现条件GANs时,需要对生成器和判别器的结构进行修改。 引用中给出了生成器的具体实现。生成器接收两个输入:一个是噪声向量x,一个是条件向量c。首先,将x通过线性层进行处理,得到一个大小为128x7x7的张量。然后,将x和c通过cat操作在channels方向上进行合并,形成一个大小为256x7x7的张量。最后,通过三次转置卷积操作将张量的尺寸逐渐放大,最终生成一个大小为1x28x28的图像。 引用中给出了判别器的具体实现。判别器接收两个输入:一个是真实图像x,一个是条件向量c。首先,将c通过线性层进行处理,得到一个大小为1x28x28的张量。然后,将x和c通过cat操作在channels方向上进行合并,形成一个大小为2x28x28的张量。接下来,通过卷积层、LeakyReLU激活函数和Dropout层对张量进行处理。最后,将张量展平后通过全连接层得到一个概率值,表示输入图像为真实图像的概率。 通过以上改进,条件GANs可以在生成过程中根据给定的条件生成特定的图像。这种结构可以应用于各种任务,如图像生成、图像修复和图像转换等。<span class="em">1</span><span class="em">2</span><span class="em">3</span> #### 引用[.reference_title] - *1* *2* [Conditional GAN代码实现(Pytorch)](https://blog.csdn.net/weixin_40330033/article/details/127212518)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_2"}}] [.reference_item style="max-width: 50%"] - *3* [pytorch-GANs:我对各种GAN生成对抗网络)架构的实现,例如香草GAN(Goodfellow等),cGAN(Mirza等),...](https://download.csdn.net/download/weixin_42116701/15910571)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_2"}}] [.reference_item style="max-width: 50%"] [ .reference_list ]

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

肥宅_Sean

公众号“肥宅Sean”欢迎关注

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值