GAN学习笔记 (2):pytorch实现naive GAN
我们这里做个demo,就不直接生成图片了,而是事先准备好一些“点”,以这些“点”来代替图片。我们训练一个GAN,看看训练出的这个GAN的Generator能不能拟合我们实现准备好的“点”的分布。我们这里准备一个8-Gaussian Mixture Distribution,但我们假装并不知道这些“点”的分布(因为我们并不知道高维空间中的图片符合什么分布),让GAN来学习出他们的分布。
先定两个变量:
h_dim = 400
batchsz = 512
1.数据生成
生成数据的代码如下,这些“点”就相当于real image:
def data_generator():
# 8-gaussian mixture model
scale = 2.
centers = [
(1, 0),
(-1, 0),
(0, 1),
(0, -1),
(1. / np.sqrt(2), 1. / np.sqrt(2)),
(1. / np.sqrt(2), -1. / np.sqrt(2)),
(-1. / np.sqrt(2), 1. / np.sqrt(2)),
(-1. / np.sqrt(2), -1. / np.sqrt(2))
]
centers = [(scale * x, scale * y) for x, y in centers]
while True:
dataset = []
for i in range(batchsz):
point = np.random.randn(2) * 0.02
center = random.choice(centers)
point[0] += center[0]
point[1] += center[1]
dataset.append(point)
dataset = np.array(dataset, dtype='float32')
dataset /= 1.414 # stdev
yield dataset
2.模型搭建
这里我们就随便搞几个层来搭一个Generator和Discriminator:
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.net = nn.Sequential(
# 这个2也可以换成变的,只不过是你noise特征的维度
nn.Linear(2, h_dim),
nn.ReLU(True),
nn.Linear(h_dim, h_dim),
nn.ReLU(True),
nn.Linear(h_dim, h_dim),
nn.ReLU(True),
# “点”是二维的,所以输出必须是2维
nn.Linear(h_dim, 2),
)
def forward(self, z):
output = self.net(z)
return output
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.net = nn.Sequential(
nn.Linear(2, h_dim),
nn.ReLU(True),
nn.Linear(h_dim, h_dim),
nn.ReLU(True),
nn.Linear(h_dim, h_dim),
nn.ReLU(True),
nn.Linear(h_dim, 1),
nn.Sigmoid()
)
def forward(self, x):
output = self.net(x)
return output.view(-1)
3.训练模型
首先我们得到数据生成器:
data_iter = data_generator()
根据上篇博客(Discriminator多训练,Generator少训练),我们训练五次Discriminator,一次Generator。下面看代码:
G = Generator().cuda()
D = Discriminator().cuda()
optim_G = optim.Adam(G.parameters(), lr=1e-3, betas=(0.5, 0.9))
optim_D = optim.Adam(D.parameters(), lr=1e-3, betas=(0.5, 0.9))
for epoch in range(50000):
# train Discriminator
for _ in range(5):
##############for real data###############
# 得到real data
x = next(data_iter)
xr = torch.from_numpy(x).cuda()
# 打分
predr = D(xr)
# 给真实数据高分
lossr = - (predr.mean())
##############for fake data###############
# noise
z = torch.randn(batchsz, 2).cuda()
# 生成的数据, 我们这时训练的是Discriminator不需要更新Generator的梯度
xf = G(z).detach()
# 打分
predf = (D(xf))
# 给生成的数据低分
lossf = (predf.mean())
##############for Discriminator###############
loss_D = lossr + lossf
################update parameter#################
optim_D.zero_grad()
loss_D.backward()
optim_D.step()
# train Generator
z = torch.randn(batchsz, 2).cuda()
xf = G(z)
predf = D(xf)
# 让Discriminator给fake数据打高分
loss_G = - (predf.mean())
optim_G.zero_grad()
loss_G.backward()
optim_G.step()
if epoch % 100 == 0:
print(loss_D.item(), loss_G.item())
至此,最naive的GAN的代码demo就全部完成了,下一篇讲讲WGAN解决的问题和WGAN的代码。