使用mnist数据集简单实现GAN,代码源码来源于github,在这里仅供学习!
源代码只进行了训练没有进行验证
(只涉及到代码实现,不涉及到任何理论)
代码模块及细节
1 初始化
#导入库
import argparse
import os
import numpy as np
import torch
#-------------设置参数-------------------
#不明白参数含义的可以看看help的英文注释
parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space") # 随机噪声z的维度
parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension") # 输入图像的尺寸
parser.add_argument("--channels", type=int, default=1, help="number of image channels") # 输入图像的channel数
parser.add_argument("--sample_interval", type=int, default=400, help="interval betwen image samples") # 保存生成图像和模型的间隔
opt = parser.parse_known_args()[0]#parser.parse_known_args[0]查看上面设置的参数
print("opt =", opt)
#设置图片的大小
img_shape = (opt.channels, opt.img_size, opt.img_size)
print("img_shape =", img_shape)
2 数据加载
数据格式为 28 * 28 的灰度图,以及对应 0-9 的数字标签,在这里我们以常见的MNIST数据集来实验
#导入库
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets
# Configure data loader
#常用os方式,如果没有某个文件夹,直接创建这个文件夹
if not os.path.exists('./home/featurize/mnist'):
os.makedirs("./home/featurize/mnist", exist_ok=True)
#调用pytorch自带的数据集,如果没有mnist数据集会自动下载
dataloader = torch.utils.data.DataLoader(
datasets.MNIST(
"./home/featurize/mnist",
train=True,
download=True,
transform=transforms.Compose(
[transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
),
),
batch_size=opt.batch_size,
shuffle=True,
)
from torch.autograd import Variable
import matplotlib.pyplot as plt
def show_img(img, trans=True):
if trans:
img = np.transpose(img.detach().cpu().numpy(), (1, 2, 0)) # 把channel维度放到最后
plt.imshow(img[:, :, 0], cmap="gray")
else:
plt.imshow(img, cmap="gray")
plt.show()
mnist = datasets.MNIST("../../data/mnist")
for i in range(3):
sample = mnist[i][0]
label = mnist[i][1]
show_img(np.array(sample), trans=False)
print("label =", label, '\n')
代码在实现的时候,用到了Variable对象。Variable对Tensor对象进行封装,只需要Variable::data即可取出Tensor,并且Variable还封装了该Tensor的梯度Variable::grad(是个Variable对象)。现在用Variable作为计算图的节点,则通过反向传播自动求得的导数就保存在Variable对象中了。简单来说就是Variable可以自动求梯度,嫌麻烦可以不使用,自己手动更新梯度
from torch.autograd import Variable
import matplotlib.pyplot as plt
#定义一个展示图片的函数,detach()切断反向传播,transpose()用来更换维度,在这里我把第一维channel通道数放到了最后
def show_img(img, trans=True):
if trans:
img = np.transpose(img.detach().cpu().numpy(), (1, 2, 0)) # 把channel维度放到最后,这是因为plt.imshow只接受RGB图像,即高*宽*channel数
plt.imshow(img[:, :, 0], cmap="gray")
else:
plt.imshow(img, cmap="gray")
plt.show()
mnist = datasets.MNIST("../../data/mnist")
#查看一下前三个图片,打印他们的图像和标签
for i in range(3):
sample = mnist[i][0]
label = mnist[i][1]
show_img(np.array(sample), trans=False)
print("label =", label, '\n')
为了更好地观察transforms的机制,在这里展示了transforms的每一步变化和输出,希望对初学者能够对transforms的使用更清晰一点。我作为一个纯小白却是清晰很多了,最主要的是print他们,看看他们的样子
trans_resize = transforms.Resize(opt.img_size)
trans_to_tensor = transforms.ToTensor()
trans_normalize = transforms.Normalize([0.5], [0.5]) # x_n = (x - 0.5) / 0.5
print("shape =", np.array(sample).shape, '\n')
print("data =", np.array(sample), '\n')
samlpe = trans_resize(sample)
print("(trans_resize) shape =", np.array(sample).shape, '\n')
sample = trans_to_tensor(sample)
print("(trans_to_tensor) data =", sample, '\n')
sample = trans_normalize(sample)
print("(trans_normalize) data =", sample, '\n')
3 模型
3.1生成器
包含5个全连接层,使用LeakyReLU和Tanh激活函数,使用了BatchNorm
并且,个人认为LeakyReLU激活函数是GAN网络中最常用到的激活函数,当然仅是个人看法
GAN的模型关键的两个部分分别是生成器和鉴别器(或者也可以叫判别器,都可以)
价值函数或者说是LOSS是最重要的
首先定义一个生成器,这里为了简单,只使用了简单的几个全连接层,也就是说使用了多层感知机,这也与原文做了对照,并进行了简单的修改。
import torch.nn as nn
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
def block(in_feat, out_feat, normalize=True):
layers = [nn.Linear(in_feat, out_feat)]
if normalize:
layers.append(nn.BatchNorm1d(out_feat, 0.8))
layers.append(nn.LeakyReLU(0.2, inplace=True))
return layers
self.model = nn.Sequential(
*block(opt.latent_dim, 128, normalize=False),
*block(128, 256),
*block(256, 512),
*block(512, 1024),
#np.prod()函数用来计算所有元素的乘积,对于有多个维度的数组可以指定轴,如axis=1指定计算每一行的乘积。
nn.Linear(1024, int(np.prod(img_shape))),
nn.Tanh()
)
def forward(self, z):
img = self.model(z)
#view()的作用相当于numpy中的reshape,重新定义矩阵的形状。view()主要用在pytoch的tensor
img = img.view(img.size(0), *img_shape)
return img
generator = Generator()
print(generator)
3.2鉴别器
包含3个全连接层,使用LeakyReLU和Sigmoid激活函数
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.model = nn.Sequential(
nn.Linear(int(np.prod(img_shape)), 512),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(512, 256),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(256, 1),
nn.Sigmoid(),
)
def forward(self, img):
img_flat = img.view(img.size(0), -1)
validity = self.model(img_flat)
return validity
discriminator = Discriminator()
print(discriminator)
4 损失函数
使用 Binary Cross Entropy Loss (应该是最常见的)
# Loss function--BCELOSS
adversarial_loss = torch.nn.BCELoss()
5 使用Cuda,加速计算
#判断有没有GPU,如果有就用,没有就使用CPu;并且如果使用GPU,模型model,loss和数据都需要挂cuda
#比如 device = torch.device("cuda" if use_cuda else "cpu")
#x=x.to(device),x就是一个输入的tensor类型数据
#是否使用cuda
cuda = True if torch.cuda.is_available() else False
print("cuda_is_available =", cuda)
if cuda:
generator.cuda()
discriminator.cuda()
adversarial_loss.cuda()
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
6 优化器
使用Adam优化器(最常见也是最好用的优化器之一,一般不用改优化器,就是用adam)
在这里对生成器和鉴别器都采用adam优化器
# Optimizers,b1,b2代表adam的参数
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
print("learning_rate =", opt.lr)
7 创建输入
分别从数据集和随机向量中获取输入
在这里我们取出前几个图片
for i, (imgs, _) in list(enumerate(dataloader))[:1]:
real_imgs = Variable(imgs.type(Tensor))
# Sample noise as generator input生成符合正态分布的噪声,大小为(64,100)
z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))
print("i =", i, '\n')
print("shape of z =", z.shape, '\n')#shape of z = torch.Size([64, 100])
print("shape of real_imgs =", real_imgs.shape, '\n')#shape of real_imgs = torch.Size([64, 1, 28, 28])
print("z =", z, '\n')
print("real_imgs =")#<built-in method type of Tensor object at 0x7efcf792b6b0>
for img in real_imgs[:3]:
show_img(img)
这里是全部的
for i, (imgs, _) in enumerate(dataloader):
print(real_imgs.type)
real_imgs = Variable(imgs.type(Tensor))
# Sample noise as generator input
z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))
print("i =", i, '\n')
print("shape of z =", z.shape, '\n')
print("shape of real_imgs =", real_imgs.shape, '\n')
print("z =", z, '\n')
print("real_imgs =")
8 计算loss,反向传播
分别对生成器和判别器计算loss,使用反向传播更新模型参数
# Adversarial ground truths
#定义两个tensor,一个为真,全为1,;一个为假,全为0.大小为imgs的imgs.size(0)
valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False) # 为1时判定为真
fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False) # 为0时判定为假
# ---------------------
# Train Generator
# ---------------------
#梯度清0
optimizer_G.zero_grad()
#通过生成器生成图像
gen_imgs = generator(z)
print("gen_imgs =")
for img in gen_imgs:
show_img(img)
# Loss measures generator's ability to fool the discriminator
g_loss = adversarial_loss(discriminator(gen_imgs), valid)
print("g_loss =", g_loss, '\n')
g_loss.backward()
#更新参数
optimizer_G.step()
# ---------------------
# Train Discriminator
# ---------------------
optimizer_D.zero_grad()
# Measure discriminator's ability to classify real from generated samples
real_loss = adversarial_loss(discriminator(real_imgs), valid)
fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
#总的loss是real_loss+fake_loss
d_loss = (real_loss + fake_loss) / 2
print("real_loss =", real_loss, '\n')
print("fake_loss =", fake_loss, '\n')
print("d_loss =", d_loss, '\n')
d_loss.backward()
optimizer_D.step()
9 保存生成图像和模型文件
from torchvision.utils import save_image
epoch = 0 # temporary
batches_done = epoch * len(dataloader) + i
if batches_done % opt.sample_interval == 0:
save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True) # 保存生成图像
os.makedirs("./home/featurize/model", exist_ok=True) # 保存模型
torch.save(generator, './home/featurize/model/generator.pkl')
torch.save(discriminator, './home/featurize/model/discriminator.pkl')
print("gen images saved!\n")
print("model saved!")