import random
import numpy as np
import torch
from matplotlib import pyplot as plt
from torch import nn, optim, autograd
from visdom import Visdom
# generate the real-data dataset
def data_generator():
"""
预设数据样本分布为8个高斯分布叠加的分布模型
"""
scale = 2.
centers = [
(1, 0), (-1, 0), (0, 1), (0, -1),
(1. / np.sqrt(2), 1. / np.sqrt(2)),
(1. / np.sqrt(2), -1. / np.sqrt(2)),
(-1. / np.sqrt(2), 1. / np.sqrt(2)),
(-1. / np.sqrt(2), -1. / np.sqrt(2))
]
centers = [(scale * x, scale * y) for x, y in centers]
while True:
dataset = []
        for _ in range(batch_size):
point = np.random.randn(2) * 0.02
center = random.choice(centers)
point[0] += center[0]
point[1] += center[1]
dataset.append(point)
dataset = np.array(dataset, dtype='float32')
        dataset /= 1.414  # divide by sqrt(2) to normalize the spread
yield dataset
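# Usage sketch: data_generator() is an infinite iterator over batches, e.g.
#   data_iter = data_generator(); x = next(data_iter)  # x: [batch_size, 2] float32
# Note: it reads the global batch_size defined below, so call it only after
# the hyper-parameters have been set.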
# hyper-parameters
hidden_dim = 200
batch_size = 256
epochs = 5000
# visdom object
vis = Visdom()
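# NOTE: requires a running visdom server (start one with `python -m visdom.server`)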
# device (fall back to CPU when CUDA is unavailable)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
        self.net = nn.Sequential(
            # input: [batch, 2], a batch of randomly sampled 2-D noise points
            nn.Linear(2, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(inplace=True),
            # output: [batch, 2], fake-data meant to match the real distribution
            nn.Linear(hidden_dim, 2)
        )
def forward(self, z):
"""
:param z: [batch, 2] 随机生成的二维点
:return: [batch, 2] Generator生成的尽量满足真实数据分布的fake-data
"""
output = self.net(z)
return output
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.net = nn.Sequential(
            # input: [batch, 2], a 2-D data point (real-data or fake-data)
nn.Linear(2, hidden_dim),
nn.ReLU(inplace=True),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(inplace=True),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(inplace=True),
nn.Linear(hidden_dim, 1),
            # high output => judged real-data, low output => judged fake-data
nn.Sigmoid()
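            # NOTE: a canonical WGAN-GP critic omits the Sigmoid and outputs an
            # unbounded score; the Sigmoid is kept here to match this toy setup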
)
def forward(self, x):
"""
输入二维数据点,判断是否满足预定义分布(real-data or fake-data)
:param x:
:return:
"""
output = self.net(x)
return output
def weights_init(m):
if isinstance(m, nn.Linear):
        # alternative DCGAN-style init: m.weight.data.normal_(0.0, 0.02)
nn.init.kaiming_normal_(m.weight)
m.bias.data.fill_(0)
def gradient_penalty(D, xr, xf):
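    """
    WGAN-GP gradient penalty (Gulrajani et al., 2017):
        gp = LAMBDA * E[(||grad_{x_hat} D(x_hat)||_2 - 1)^2]
    where x_hat = alpha * x_real + (1 - alpha) * x_fake, alpha ~ U(0, 1).
    """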
LAMBDA = 0.3
    # the penalty constrains only the Discriminator, so detach both inputs
xf = xf.detach()
xr = xr.detach()
    # per-sample mixing weight alpha ~ U(0, 1), broadcast from [b, 1] to [b, 2]
    alpha = torch.rand(batch_size, 1, device=device)
    alpha = alpha.expand_as(xr)
interpolates = alpha * xr + (1 - alpha) * xf
interpolates.requires_grad_()
pred = D(interpolates)
gradients = autograd.grad(outputs=pred, inputs=interpolates,
grad_outputs=torch.ones_like(pred),
create_graph=True, retain_graph=True, only_inputs=True)[0]
gp = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * LAMBDA
return gp
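
# A minimal visualization sketch (assumption: not part of the original training
# flow). It scatter-plots one real batch against Generator samples to eyeball
# mode coverage; the name plot_samples and the n parameter are illustrative.
def plot_samples(G, data_iter, n=512):
    with torch.no_grad():
        z = torch.randn(n, 2).to(device)
        xf = G(z).cpu().numpy()
    xr = next(data_iter)  # one real batch, [batch_size, 2]
    plt.figure(figsize=(5, 5))
    plt.scatter(xr[:, 0], xr[:, 1], s=5, label='real')
    plt.scatter(xf[:, 0], xf[:, 1], s=5, label='fake')
    plt.legend()
    plt.title('real vs. generated samples')
    plt.show()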
def train_GAN():
    # fix the random seeds so training runs are reproducible
torch.manual_seed(23)
np.random.seed(23)
G = Generator().to(device)
D = Discriminator().to(device)
G.apply(weights_init)
D.apply(weights_init)
optimizer_G = optim.Adam(G.parameters(), lr=1e-3, betas=(0.5, 0.9))
optimizer_D = optim.Adam(D.parameters(), lr=1e-3, betas=(0.5, 0.9))
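    # Adam with betas=(0.5, 0.9): a low beta1 is the common choice for GAN
    # training, since heavy momentum can destabilize the adversarial game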
data_iter = data_generator()
    vis.line([[0., 0.]], [0.], win='Loss Info',
             opts=dict(title='Loss Info', legend=['Loss_G', 'Loss_D']))
    # Train the Generator and the Discriminator alternately:
    # in each epoch the Discriminator is trained k times first
for epoch in range(epochs):
        # 1. train Discriminator for k steps (k = 5 here)
for _ in range(5):
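            # Critic objective (WGAN form plus gradient penalty):
            #   loss_D = E[D(fake)] - E[D(real)] + gp
            # minimizing it raises D's scores on real data and lowers them on fakes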
# real-data loss
x = next(data_iter)
xr = torch.from_numpy(x).to(device)
            predr = D(xr)  # predicted probability that the input is real-data
lossr = -predr.mean()
# fake-data loss
x_random = torch.randn(batch_size, 2).to(device)
            xf = G(x_random).detach()  # detach: no gradients flow back into G here
predf = D(xf)
lossf = predf.mean()
            # gradient penalty
gp = gradient_penalty(D, xr, xf)
            # update the Discriminator
loss_D = lossr + lossf + gp
optimizer_D.zero_grad()
loss_D.backward()
optimizer_D.step()
# 2. train Generator
x_random = torch.randn(batch_size, 2).to(device)
xf = G(x_random)
predf = D(xf)
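        # Generator objective: loss_G = -E[D(G(z))], i.e. push D's fake scores up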
loss_G = -predf.mean()
optimizer_G.zero_grad()
loss_G.backward()
optimizer_G.step()
        # log training info
if (epoch + 1) % 100 == 0:
            vis.line([[loss_G.item(), loss_D.item()]], [epoch + 1],
                     win='Loss Info', update='append')
print("epoch:%-5i" % (epoch + 1), "Loss_G=%-5.5f" % loss_G.item(),
"Loss_D=%-5.5f" % loss_D.item())
if __name__ == '__main__':
train_GAN()
print("Done!")