import torch
from torch import nn,optim,autograd
import numpy as np
import visdom
import random
import matplotlib
matplotlib.use('TkAgg')
from matplotlib import pyplot as plt
# Width of every hidden layer in both the generator and the discriminator.
h_dim = 400
# Points per batch: real data, generator noise, and plotted samples.
batchsz = 512
# Visdom client shared by plotting and training code. It is created at import
# time, so a visdom server should already be running (see the launch notes at
# the end of this file).
viz = visdom.Visdom()
class Generator(nn.Module):
    """MLP generator mapping 2-D noise vectors to 2-D sample points.

    Architecture: Linear(2, H) -> ReLU -> Linear(H, H) -> ReLU ->
    Linear(H, H) -> ReLU -> Linear(H, 2), where H is the hidden width.

    Args:
        hidden_dim: hidden width H. ``None`` (the default) uses the
            module-level ``h_dim``, so ``Generator()`` keeps its original size.
    """

    def __init__(self, hidden_dim=None):
        super().__init__()
        width = h_dim if hidden_dim is None else hidden_dim
        # x: [b, 2] => [b, 2]
        self.net = nn.Sequential(
            nn.Linear(2, width),
            nn.ReLU(True),
            nn.Linear(width, width),
            nn.ReLU(True),
            nn.Linear(width, width),
            nn.ReLU(True),
            nn.Linear(width, 2),
        )

    def forward(self, x):
        """Map noise ``x`` of shape [b, 2] to generated points of shape [b, 2]."""
        return self.net(x)
class Discriminator(nn.Module):
    """MLP critic that assigns one scalar score to each 2-D point.

    Architecture: Linear(2, H) -> ReLU -> Linear(H, H) -> ReLU ->
    Linear(H, H) -> ReLU -> Linear(H, 1) [-> Sigmoid if requested].

    This file trains the network with the WGAN-GP objective
    (-D(real) + D(fake) + gradient penalty). A Wasserstein critic must output
    unbounded scores: a final sigmoid saturates and works against the
    gradient-norm-1 penalty. The sigmoid is therefore off by default.

    Args:
        hidden_dim: hidden width H. ``None`` (the default) uses the
            module-level ``h_dim``.
        use_sigmoid: if True, squash scores into (0, 1) with a final sigmoid
            (vanilla-GAN style discriminator).
    """

    def __init__(self, hidden_dim=None, *, use_sigmoid=False):
        super().__init__()
        width = h_dim if hidden_dim is None else hidden_dim
        layers = [
            nn.Linear(2, width),
            nn.ReLU(True),
            nn.Linear(width, width),
            nn.ReLU(True),
            nn.Linear(width, width),
            nn.ReLU(True),
            nn.Linear(width, 1),
        ]
        if use_sigmoid:
            layers.append(nn.Sigmoid())
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Score points ``x`` of shape [b, 2]; return a flat tensor of shape [b]."""
        return self.net(x).view(-1)
# Generate training data.
def data_generator(batch_size=None, scale=2, std=0.02):
    """Yield an endless stream of 2-D toy batches drawn from 8 Gaussians.

    The 8 centers are the unit vectors at multiples of 45 degrees, multiplied
    by ``scale``. Every point picks a center with ``random.choice`` and adds
    ``np.random.randn`` noise multiplied by ``std``. Each finished batch is
    cast to float32 and divided by 1.414 (roughly sqrt(2)).

    Args:
        batch_size: points per batch; ``None`` uses the module-level ``batchsz``.
        scale: multiplier applied to the unit-circle centers.
        std: standard deviation of the per-point Gaussian noise.

    Yields:
        float32 NumPy arrays of shape [batch_size, 2].
    """
    n = batchsz if batch_size is None else batch_size
    diag = 1. / np.sqrt(2)
    unit_centers = [
        (1, 0), (-1, 0), (0, 1), (0, -1),
        (diag, diag), (diag, -diag), (-diag, diag), (-diag, -diag),
    ]
    centers = [(scale * cx, scale * cy) for cx, cy in unit_centers]
    while True:
        # Build the whole batch with array ops instead of a per-point Python loop.
        noise = np.random.randn(n, 2) * std
        # reshape keeps the [n, 2] shape even when n == 0.
        picked = np.array(
            [random.choice(centers) for _ in range(n)], dtype=np.float64
        ).reshape(n, 2)
        batch = (noise + picked).astype(np.float32)
        batch /= 1.414
        yield batch
def generate_image(D, G, xr, epoch):
    """Render the critic's landscape with real and generated samples to visdom.

    Evaluates ``D`` on a 128x128 grid covering [-3, 3] x [-3, 3] and draws
    labelled contour lines of the scores. It then overlays the real batch
    ``xr`` (orange dots) and ``batchsz`` samples from ``G`` (green crosses).
    The figure is sent to the visdom window ``'contour'``.

    Args:
        D: critic module; called on a [N, 2] float32 tensor and expected to
            return N scores.
        G: generator module; called on [batchsz, 2] standard-normal noise.
        xr: real samples of shape [b, 2], either a NumPy array or a torch
            tensor on any device. Tensors are copied to host NumPy first,
            because CUDA tensors cannot be converted for matplotlib.
        epoch: iteration number shown in the plot title.
    """
    N_POINTS = 128
    RANGE = 3
    # Evaluate on whatever device D lives on instead of assuming CUDA.
    first_param = next(D.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    if torch.is_tensor(xr):
        xr = xr.detach().cpu().numpy()

    plt.clf()

    # Evaluation grid: points[i, j] = (x_i, y_j), flattened to [N*N, 2].
    points = np.zeros((N_POINTS, N_POINTS, 2), dtype='float32')
    points[:, :, 0] = np.linspace(-RANGE, RANGE, N_POINTS)[:, None]
    points[:, :, 1] = np.linspace(-RANGE, RANGE, N_POINTS)[None, :]
    points = points.reshape((-1, 2))
    with torch.no_grad():
        grid = torch.from_numpy(points).to(device)  # [16384, 2]
        disc_map = D(grid).cpu().numpy()  # [16384]
    x = y = np.linspace(-RANGE, RANGE, N_POINTS)
    # The grid is x-major, while contour wants Z[row=y, col=x]: transpose.
    cs = plt.contour(x, y, disc_map.reshape((len(x), len(y))).transpose())
    plt.clabel(cs, inline=1, fontsize=10)

    # Overlay real data and generator samples.
    with torch.no_grad():
        z = torch.randn(batchsz, 2).to(device)  # [b, 2]
        samples = G(z).cpu().numpy()  # [b, 2]
    plt.scatter(xr[:, 0], xr[:, 1], c='orange', marker='.')
    plt.scatter(samples[:, 0], samples[:, 1], c='green', marker='+')
    viz.matplot(plt, win='contour', opts=dict(title='p(x):%d' % epoch))
def gradient_penalty(D, xr, xf):
    """Return the WGAN-GP gradient penalty for critic ``D``.

    Builds x_hat = t * xr + (1 - t) * xf with one t ~ U[0, 1) per sample.
    Returns mean((||dD(x_hat)/dx_hat||_2 - 1) ** 2), where each sample's norm
    is taken over all of its non-batch dimensions.

    Args:
        D: callable critic. Its output is differentiated with ``grad_outputs``
            of ones, i.e. the gradient of its summed output.
        xr: real batch, shape [b, ...].
        xf: fake batch with the same shape and device as ``xr``.

    Returns:
        Scalar tensor. It is built with ``create_graph=True`` so that
        ``backward()`` on a loss containing it reaches D's parameters.
    """
    batch = xr.size(0)
    # One mixing coefficient per sample, broadcast over the remaining dims.
    # The batch size comes from xr itself, not the global batch size. The
    # draw happens on the CPU and is then moved to xr's device.
    t = torch.rand((batch,) + (1,) * (xr.dim() - 1)).to(xr.device)
    t = t.expand_as(xr)
    # Interpolate, then detach so x_hat is a fresh leaf we differentiate w.r.t.
    mid = (t * xr + (1 - t) * xf).detach().requires_grad_(True)
    pred = D(mid)
    grads = autograd.grad(
        outputs=pred,
        inputs=mid,
        grad_outputs=torch.ones_like(pred),
        create_graph=True,  # the penalty itself must be differentiable
        retain_graph=True,  # keep the graph for the later loss.backward()
    )[0]
    grad_norm = grads.reshape(batch, -1).norm(2, dim=1)
    return torch.pow(grad_norm - 1, 2).mean()
def main():
    """Train the 2-D WGAN-GP toy model and stream progress to visdom.

    Runs 50000 iterations. Each iteration does 5 critic updates (WGAN loss
    plus gradient penalty) followed by one generator update. Every 100
    iterations it appends both losses to the visdom 'loss' window, prints
    them, and redraws the contour plot with ``generate_image``.
    """
    # Seed every RNG the script draws from: torch (init and noise), NumPy
    # (data noise) and Python's `random` (data_generator picks centers with
    # random.choice).
    torch.manual_seed(23)
    np.random.seed(23)
    random.seed(23)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    data_iter = data_generator()
    x = next(data_iter)
    # [b, 2]
    print(x.shape)

    G = Generator().to(device)
    D = Discriminator().to(device)
    # Show the network architectures.
    print(G)
    print(D)

    optim_G = optim.Adam(G.parameters(), lr=5e-4, betas=(0.5, 0.9))
    # The critic's optimizer must own the critic's parameters. Building it
    # from G.parameters() would leave D untrained.
    optim_D = optim.Adam(D.parameters(), lr=5e-4, betas=(0.5, 0.9))

    # One visdom window holding two curves: the D and G losses.
    viz.line([[0, 0]], [0], win='loss', opts=dict(title='loss', legend=['D', 'G']))

    for epoch in range(50000):
        # 1. Train the critic several times per generator step.
        for _ in range(5):
            # 1.1 Real data: maximise D(xr) by minimising -D(xr).
            xr_np = next(data_iter)  # NumPy [b, 2]
            xr = torch.from_numpy(xr_np).to(device)
            predr = D(xr)
            lossr = -predr.mean()

            # 1.2 Fake data: minimise D(G(z)). Detach so G gets no gradient here.
            z = torch.randn(batchsz, 2).to(device)
            xf = G(z).detach()
            predf = D(xf)
            lossf = predf.mean()

            # 1.3 Gradient penalty on interpolations between real and fake data.
            gp = gradient_penalty(D, xr, xf)

            loss_D = lossr + lossf + gp
            optim_D.zero_grad()
            loss_D.backward()
            optim_D.step()

        # 2. Train the generator: maximise D(G(z)) by minimising -D(G(z)).
        # The gradients this leaves on D's parameters are cleared by
        # optim_D.zero_grad() before the next critic update.
        z = torch.randn(batchsz, 2).to(device)
        xf = G(z)
        predf = D(xf)
        loss_G = -predf.mean()
        optim_G.zero_grad()
        loss_G.backward()
        optim_G.step()

        if epoch % 100 == 0:
            viz.line([[loss_D.item(), loss_G.item()]], [epoch], win='loss', update='append')
            print(loss_D.item(), loss_G.item())
            # Plot the host-side NumPy batch: matplotlib cannot consume CUDA tensors.
            generate_image(D, G, xr_np, epoch)
if __name__ == '__main__':
    # Start training only when run as a script, never on import.
    main()
# To run: first start a visdom server with `python -m visdom.server`
# (or simply `visdom`), then view the plots in a browser at:
# http://localhost:8097
# GAN (adversarial network) test code.