简述
之前认真学习了网上的一份代码,并做了很详细的笔记:
【Gans入门】Pytorch实现Gans代码详解【70+代码】
但是上面的任务只是画一条在一定区间下的曲线。
这里对这个进行迁移,到可以进行图像的生成。
手头没有现成的图像数据集,但突然想到 sklearn 自带的 digits 数据集就是非常简单的图片。
这里我想到之前的一份笔记
sklearn学习(一)
这里会使用sklearn自带的小数据来做训练
目标是让神经网络自己学会生成数字。
任务描述
为了让神经网络的任务更简单,这里的输入数据只选择某一特定数字的图片数据,然后交给对抗生成网络学习,让其中的生成器学会如何生成手写数字。
下面是选择用数值1的生成过程
可以发现,生成的图片其实已经有点数字1的样子了。
下面的这个是让它学习数字0的效果
可能是由于数字0的轮廓更粗、细节更少,可以发现这个0生成得更好。(数字1和数字4其实有点像,所以生成1时会有点问题;另外图片本身的像素也比较低。)
代码详解
导入包
- torch、numpy:数据处理过程中需要的包
- matplotlib:用来画图
- sklearn:主要是为了使用它自带的数据集
- random:让选取的标准数据更具随机性
- os、shutil、imageio:这三个库是为了生成 gif 动态图
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
import random
import os
import shutil
import imageio
创建临时文件夹
PNGFILE = './png/'
# Always start from an empty scratch directory: wipe leftovers from a
# previous run, then (re)create it.
if os.path.exists(PNGFILE):
    shutil.rmtree(PNGFILE)
os.mkdir(PNGFILE)
这里会创建一个临时文件夹 png,训练中途生成的图片都会存在这里,之后再用这些 png 图片生成 gif 文件。
模型参数
- BATCH_SIZE:每次用多少数据来进行训练(数值大一点的话,模型进化得会稍微快点)
- LR_G 跟 LR_D:分别表示两个模型的学习率
- N_IDEAS:启发式因子(生成器输入层的节点数,即随机噪声的维度)。图像问题要输出的节点数量特别大,但如果输入节点也过大,会需要大量的计算资源,所以用一个小一点、基本够用的值就行了
- target_num:表示想要生成的数字。由于数据集中只有 0 到 9,所以这里也只能取 0 到 9
- image_max:表示图片像素点的最大值。这个一开始我用到了,但后来修改代码之后就用不到了
- ART_COMPONENTS:像素点数量(本质上跟前一个版本的参考节点数是一样的)
# Hyper Parameters
BATCH_SIZE = 64
LR_G = 0.00001  # learning rate for generator
LR_D = 0.00001  # learning rate for discriminator
N_IDEAS = 6  # think of this as number of ideas for generating an art work (Generator)
target_num = 0  # target Number: which digit (0-9) the GAN learns to draw

digits = datasets.load_digits()
target = digits.target
data = digits.data[target == target_num]  # every flattened image of target_num
# artist_works() draws BATCH_SIZE rows without replacement, so fail early with
# a clear message instead of deep inside random.sample (or max() on no data).
if len(data) < BATCH_SIZE:
    raise ValueError(
        'target_num=%r yields only %d images; need at least BATCH_SIZE=%d'
        % (target_num, len(data), BATCH_SIZE)
    )
image_max = data.max()  # largest pixel value (not used further in this script)
ART_COMPONENTS = data.shape[-1]  # it could be total point G can draw in the canvas
标准数据
这个函数本质上是从选定数字的图片数据中随机选出 BATCH_SIZE 个标准(真实)数据。
但是,random.sample 需要输入一个序列(这里用 list),所以要先把 data 转成 list;而转出来的 list 又不能直接变成 torch 中的 Tensor,需要再转成 ndarray,之后再转成 Tensor。还要注意在最后加一个 .float() 操作,把数据转成 float32,与模型参数的类型保持一致。
def artist_works():  # painting from the famous artist (real target)
    """Sample a batch of real images of ``target_num``.

    Draws BATCH_SIZE distinct rows of ``data`` at random (without
    replacement) and returns them as a float32 tensor whose rows are the
    sampled images.
    """
    # random.sample needs a sequence, so hand it the rows as a list ...
    picked_rows = random.sample(list(data), BATCH_SIZE)
    # ... then stack them back into a single ndarray for torch.
    batch = np.array(picked_rows)
    return torch.from_numpy(batch).float()
构建模型
生成器模型:Linear 层的输出有可能是负数,但作为图片像素值这是不可以的,因为像素值一定是大于等于 0 的数据。
所以搭建的这个模型最后一定要加一个 ReLU() 这样的激活函数,来保证输出中没有负数。
# Generator: maps an N_IDEAS-dim noise vector to ART_COMPONENTS pixel values.
G = nn.Sequential(*[
    nn.Linear(N_IDEAS, 128),  # random ideas (could from normal distribution)
    nn.ReLU(),
    nn.Linear(128, ART_COMPONENTS),  # making a painting from these random ideas
    nn.ReLU(),  # clamps every generated pixel to >= 0, like real intensities
])
# Discriminator: maps ART_COMPONENTS pixel values to a single probability.
D = nn.Sequential(*[
    nn.Linear(ART_COMPONENTS, 128),  # receive art work either from the famous artist or a newbie like G
    nn.ReLU(),
    nn.Linear(128, 1),
    nn.Sigmoid(),  # tell the probability that the art work is made by artist
])
构建优化器
# One Adam optimizer per network; each only updates its own model's weights.
opt_G = torch.optim.Adam(params=G.parameters(), lr=LR_G)
opt_D = torch.optim.Adam(params=D.parameters(), lr=LR_D)
迭代优化
这跟之前的是类似的。
times = 0  # index of the next frame image written to PNGFILE
filedatalist = []  # paths of the saved frame images, in order, for the GIF
LOG_EPS = 1e-8  # keeps log() finite if D saturates at exactly 0 or 1
for step in range(10000):
    artist_paintings = artist_works()  # real painting from artist
    G_ideas = torch.randn(BATCH_SIZE, N_IDEAS)  # random ideas
    G_paintings = G(G_ideas)  # fake painting from G (random ideas)

    # --- Discriminator step ---
    # detach() so D's loss does not backpropagate into G.
    prob_artist0 = D(artist_paintings)  # D try to increase this prob
    prob_artist1 = D(G_paintings.detach())  # D try to reduce this prob
    D_loss = - torch.mean(torch.log(prob_artist0 + LOG_EPS)
                          + torch.log(1. - prob_artist1 + LOG_EPS))
    opt_D.zero_grad()
    D_loss.backward()
    opt_D.step()

    # --- Generator step ---
    # opt_D.step() modified D's weights in place, so the graph of the D pass
    # above can no longer be backpropagated (PyTorch >= 1.5 raises a
    # RuntimeError). Score the fakes again with the updated D.
    prob_artist1 = D(G_paintings)
    G_loss = torch.mean(torch.log(1. - prob_artist1 + LOG_EPS))
    opt_G.zero_grad()
    G_loss.backward()
    opt_G.step()
画图并保存
if step % 100 == 0:  # plotting
    # Render the first fake of this batch as an 8x8 image and save it as
    # the next GIF frame.
    frame_path = PNGFILE + '%d.png' % times
    plt.cla()
    sample_image = G_paintings[0].detach().numpy().reshape((8, 8))
    plt.imshow(sample_image, cmap=plt.cm.gray_r)
    plt.savefig(frame_path)
    filedatalist.append(frame_path)
    times += 1
    plt.pause(0.01)
生成gif
# Load the saved frames back, in the order they were written.
generated_images = [imageio.imread(png_path) for png_path in filedatalist]
# The frames are in memory now, so the scratch directory can go.
shutil.rmtree(PNGFILE)
# NOTE(review): recent imageio releases may interpret GIF ``duration`` in
# milliseconds rather than seconds -- confirm against the installed version.
imageio.mimsave('gan.gif', generated_images, 'GIF', duration=0.1)
全部代码
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
import random
import os
import shutil
import imageio
PNGFILE = './png/'
# Always start from an empty scratch directory: wipe leftovers from a
# previous run, then (re)create it.
if os.path.exists(PNGFILE):
    shutil.rmtree(PNGFILE)
os.mkdir(PNGFILE)
# Hyper Parameters
BATCH_SIZE = 64
LR_G = 0.00001  # learning rate for generator
LR_D = 0.00001  # learning rate for discriminator
N_IDEAS = 6  # think of this as number of ideas for generating an art work (Generator)
target_num = 0  # target Number: which digit (0-9) the GAN learns to draw

digits = datasets.load_digits()
target = digits.target
data = digits.data[target == target_num]  # every flattened image of target_num
# artist_works() draws BATCH_SIZE rows without replacement, so fail early with
# a clear message instead of deep inside random.sample (or max() on no data).
if len(data) < BATCH_SIZE:
    raise ValueError(
        'target_num=%r yields only %d images; need at least BATCH_SIZE=%d'
        % (target_num, len(data), BATCH_SIZE)
    )
image_max = data.max()  # largest pixel value (not used further in this script)
ART_COMPONENTS = data.shape[-1]  # it could be total point G can draw in the canvas
def artist_works():  # painting from the famous artist (real target)
    """Sample a batch of real images of ``target_num``.

    Draws BATCH_SIZE distinct rows of ``data`` at random (without
    replacement) and returns them as a float32 tensor whose rows are the
    sampled images.
    """
    # random.sample needs a sequence, so hand it the rows as a list ...
    picked_rows = random.sample(list(data), BATCH_SIZE)
    # ... then stack them back into a single ndarray for torch.
    batch = np.array(picked_rows)
    return torch.from_numpy(batch).float()
# Generator: maps an N_IDEAS-dim noise vector to ART_COMPONENTS pixel values.
G = nn.Sequential(*[
    nn.Linear(N_IDEAS, 128),  # random ideas (could from normal distribution)
    nn.ReLU(),
    nn.Linear(128, ART_COMPONENTS),  # making a painting from these random ideas
    nn.ReLU(),  # clamps every generated pixel to >= 0, like real intensities
])
# Discriminator: maps ART_COMPONENTS pixel values to a single probability.
D = nn.Sequential(*[
    nn.Linear(ART_COMPONENTS, 128),  # receive art work either from the famous artist or a newbie like G
    nn.ReLU(),
    nn.Linear(128, 1),
    nn.Sigmoid(),  # tell the probability that the art work is made by artist
])
# One Adam optimizer per network; each only updates its own model's weights.
opt_G = torch.optim.Adam(params=G.parameters(), lr=LR_G)
opt_D = torch.optim.Adam(params=D.parameters(), lr=LR_D)
filedatalist = []  # paths of the saved frame images, in order, for the GIF
times = 0  # index of the next frame image written to PNGFILE
LOG_EPS = 1e-8  # keeps log() finite if D saturates at exactly 0 or 1
for step in range(10000):
    artist_paintings = artist_works()  # real painting from artist
    G_ideas = torch.randn(BATCH_SIZE, N_IDEAS)  # random ideas
    G_paintings = G(G_ideas)  # fake painting from G (random ideas)

    # --- Discriminator step ---
    # detach() so D's loss does not backpropagate into G.
    prob_artist0 = D(artist_paintings)  # D try to increase this prob
    prob_artist1 = D(G_paintings.detach())  # D try to reduce this prob
    D_loss = - torch.mean(torch.log(prob_artist0 + LOG_EPS)
                          + torch.log(1. - prob_artist1 + LOG_EPS))
    opt_D.zero_grad()
    D_loss.backward()
    opt_D.step()

    # --- Generator step ---
    # opt_D.step() modified D's weights in place, so the graph of the D pass
    # above can no longer be backpropagated (PyTorch >= 1.5 raises a
    # RuntimeError). Score the fakes again with the updated D.
    prob_artist1 = D(G_paintings)
    G_loss = torch.mean(torch.log(1. - prob_artist1 + LOG_EPS))
    opt_G.zero_grad()
    G_loss.backward()
    opt_G.step()

    if step % 100 == 0:  # plotting
        # Save the first fake of this batch as an 8x8 image / next GIF frame.
        plt.cla()
        tempdata = G_paintings[0].detach().numpy()
        tempdata = tempdata.reshape((8, 8))
        plt.imshow(tempdata, cmap=plt.cm.gray_r)
        frame_path = PNGFILE + '%d.png' % times
        plt.savefig(frame_path)
        filedatalist.append(frame_path)
        times += 1
        plt.pause(0.01)
# Load the saved frames back, in the order they were written.
generated_images = [imageio.imread(png_path) for png_path in filedatalist]
# The frames are in memory now, so the scratch directory can go.
shutil.rmtree(PNGFILE)
# NOTE(review): recent imageio releases may interpret GIF ``duration`` in
# milliseconds rather than seconds -- confirm against the installed version.
imageio.mimsave('gan.gif', generated_images, 'GIF', duration=0.1)