简述
其实是根据我之前写的两个代码改的。(之前已经有过非常详细的解释了,可以去看看)
同时,这份代码也结合了我之前写的 DCGANs 的实现。
MNIST上选特定的数值,是根据下面的这篇文章得到的。
之前的代码上都有非常详细的解释。这里只是在上面的基础上做了一点点改进而已,所以就不给出特别详细的解释了,但是代码中仍然保留有注释部分。
图形演变过程
代码
import torch
import torch.nn as nn
import torchvision
import torch.utils.data as Data
import matplotlib.pyplot as plt
import os
import shutil
import imageio
# Scratch directory for the snapshot images written during training.
PNGFILE = './png/'
# Always start from an empty directory: remove whatever an earlier run left
# behind, then create it afresh.
if os.path.exists(PNGFILE):
    shutil.rmtree(PNGFILE)
os.mkdir(PNGFILE)
# Hyper Parameters
# Mini-batch size handed to the DataLoader; the filtered dataset is also
# truncated to a multiple of this value.
BATCH_SIZE = 64
LR_G = 0.0001  # learning rate for generator
LR_D = 0.0001  # learning rate for discriminator
# Length of the random noise vector fed to the generator (input width of G).
N_IDEAS = 100
# Digit label to keep; passed to myMNIST as `targetNum`.
target_num = 0
# Number of full passes over the training data.
EPOCH = 10
# Forwarded to the dataset's `download=` argument.
# NOTE(review): False presumably assumes the data already exists under
# ./mnist/ - confirm before running on a fresh checkout.
DOWNLOAD_MNIST = False
# Number of values in one flattened 28x28 image (output width of G, input
# width of D).
ART_COMPONENTS = 28 * 28
# MNIST handwritten digits, optionally restricted to a single digit class.
class myMNIST(torchvision.datasets.MNIST):
    """MNIST dataset that can keep only the samples of one digit.

    Args:
        root: Directory the dataset is stored in / loaded from.
        train: Select the training split (True) or the test split (False).
        transform: Optional transform applied to each image.
        target_transform: Optional transform applied to each label.
        download: Forwarded to ``torchvision.datasets.MNIST``.
        targetNum: Digit label to keep. ``None`` (the default) keeps every
            sample. When given, the filtered set is additionally truncated to
            a whole number of ``BATCH_SIZE``-sized batches.
    """

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False, targetNum=None):
        super(myMNIST, self).__init__(
            root,
            train=train,
            transform=transform,
            target_transform=target_transform,
            download=download)
        if targetNum is not None:
            # Work on ``data`` / ``targets``: ``train_data`` / ``train_labels``
            # are read-only, deprecated properties in current torchvision, so
            # assigning to them raises AttributeError.
            labels = torch.as_tensor(self.targets)
            keep = labels == targetNum
            kept_images = self.data[keep]
            kept_labels = labels[keep]
            # Drop the tail so the sample count is a multiple of BATCH_SIZE.
            usable = (kept_images.shape[0] // BATCH_SIZE) * BATCH_SIZE
            self.data = kept_images[:usable]
            self.targets = kept_labels[:usable]

    def __len__(self):
        """Return the number of samples actually held (after any filtering)."""
        # Derive the length from the stored tensor for both splits instead of
        # hard-coding 10000 for the test split, which is wrong once filtered.
        return self.data.shape[0]
# Training split of MNIST, stored under ./mnist/ and restricted to the digit
# selected by `target_num`. ToTensor converts each PIL.Image / numpy.ndarray
# into a torch.FloatTensor (C x H x W) scaled to the [0.0, 1.0] range.
train_data = myMNIST(
    './mnist/',
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=DOWNLOAD_MNIST,
    targetNum=target_num,
)
print(len(train_data))

# Shuffled mini-batches of BATCH_SIZE images, each image 28*28.
train_loader = Data.DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
# Generator: turns a vector of N_IDEAS random values into a flattened image
# of ART_COMPONENTS values.
G = nn.Sequential(
    nn.Linear(in_features=N_IDEAS, out_features=128),
    nn.ReLU(),
    nn.Linear(in_features=128, out_features=ART_COMPONENTS),
    nn.ReLU(),
)

# Discriminator: takes a flattened image (real or generated) and outputs one
# sigmoid value, read as the probability that the image is real.
D = nn.Sequential(
    nn.Linear(in_features=ART_COMPONENTS, out_features=128),
    nn.ReLU(),
    nn.Linear(in_features=128, out_features=1),
    nn.Sigmoid(),
)
# Optimizers: one Adam instance per network, each with its own learning rate.
optimD = torch.optim.Adam(D.parameters(), lr=LR_D)
optimG = torch.optim.Adam(G.parameters(), lr=LR_G)

# Constant float vectors of length BATCH_SIZE, filled with 1 and 0.
label_Real = torch.ones(BATCH_SIZE)
label_Fake = torch.zeros(BATCH_SIZE)

# Paths of the snapshot images saved during training, in creation order.
filePath = []
# Adversarial training: for every batch, update the discriminator first and
# then the generator. Every 20 steps a sample from G is saved as a PNG.
for epoch in range(EPOCH):
    for step, (images, _) in enumerate(train_loader):
        # Size everything from the batch actually received, so a short final
        # batch cannot break the reshape or the noise tensor.
        batch_size = images.size(0)
        real_images = images.reshape(batch_size, -1)  # flatten each image

        G_ideas = torch.randn(batch_size, N_IDEAS)
        G_paintings = G(G_ideas)

        # ---- discriminator step ----
        prob_artist0 = D(real_images)  # D tries to increase this prob
        # detach(): D's loss must not back-propagate into G.
        prob_artist1 = D(G_paintings.detach())  # D tries to reduce this prob
        D_loss = -torch.mean(torch.log(prob_artist0) + torch.log(1. - prob_artist1))
        optimD.zero_grad()
        D_loss.backward()
        optimD.step()

        # ---- generator step ----
        # Score the fakes again with the updated D. Back-propagating through
        # the earlier D(G_paintings) graph would need D's pre-step weights,
        # which optimD.step() has just overwritten in place; autograd rejects
        # that with a RuntimeError.
        G_loss = torch.mean(torch.log(1. - D(G_paintings)))
        optimG.zero_grad()
        G_loss.backward()
        optimG.step()

        if step % 20 == 0:
            # Snapshot the first generated image of this batch.
            plt.cla()
            picture = G_paintings[0].detach().numpy().reshape((28, 28))
            plt.imshow(picture, cmap=plt.cm.gray_r)
            png_path = PNGFILE + '%d-%d.png' % (epoch, step)
            plt.savefig(png_path)
            filePath.append(png_path)
# Read the saved snapshots back in creation order, remove the scratch
# directory, and write the frames out as an animated GIF.
generated_images = [imageio.imread(png_path) for png_path in filePath]
shutil.rmtree(PNGFILE)
imageio.mimsave('gan-mnist.gif', generated_images, 'GIF', duration=0.1)