Test code for a GAN: a WGAN-GP trained on a toy 8-Gaussian dataset, visualized with visdom

import torch
from torch import nn,optim,autograd
import numpy as np
import visdom
import random
import matplotlib
matplotlib.use('TkAgg')
from matplotlib import pyplot as plt

h_dim =400
batchsz =512
viz =visdom.Visdom()
class Generator(nn.Module):
    def __init__(self):
        super(Generator,self).__init__()
        self.net=nn.Sequential(
            # z: [b, 2] => [b, 2]
            nn.Linear(2,h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim,h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim,h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim,2),
        )
    def forward(self, x):
        output = self.net(x)
        return output

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator,self).__init__()

        self.net = nn.Sequential(
            nn.Linear(2,h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, 1),
            nn.Sigmoid()   # note: a textbook WGAN-GP critic usually omits this final Sigmoid
        )
    def forward(self,x):
        output = self.net(x)
        return output.view(-1)

# data generator: yields batches sampled from an 8-Gaussian mixture
def data_generator():
    scale =2
    centers = [
        (1,0),
        (-1,0),
        (0,1),
        (0,-1),
        (1. /np.sqrt(2),1. /np.sqrt(2)),
        (1. / np.sqrt(2), -1. / np.sqrt(2)),
        (-1. / np.sqrt(2), 1. / np.sqrt(2)),
        (-1. / np.sqrt(2), -1. / np.sqrt(2)),
    ]

    centers = [(scale*x,scale*y) for x,y in centers]

    while True:
        dataset = []
        for i in range(batchsz):
            point = np.random.randn(2)*0.02      # small Gaussian noise, N(0, 0.02^2) per dimension
            center = random.choice(centers)      # pick one of the 8 modes at random
            point[0] += center[0]
            point[1] += center[1]
            dataset.append(point)

        dataset = np.array(dataset).astype(np.float32)
        dataset /= 1.414
        yield dataset
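
# Each yielded batch has shape [512, 2] (float32): 8 Gaussian modes placed on a circle of
# radius 2, then divided by 1.414, i.e. modes at radius ~1.414 with per-dimension std ~0.014.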
def generate_image(D,G,xr,epoch):
    N_POINTS =128
    RANGE =3
    plt.clf()
    points = np.zeros((N_POINTS,N_POINTS,2),dtype='float32')
    points[:,:,0] = np.linspace(-RANGE,RANGE,N_POINTS)[:,None]
    points[:,:,1] = np.linspace(-RANGE,RANGE,N_POINTS)[None,:]
    points =points.reshape((-1,2))


    with torch.no_grad():
        points = torch.Tensor(points).cuda() #[16384,2]
        disc_map = D(points).cpu().numpy() #[16384]

    x = y =np.linspace(-RANGE,RANGE,N_POINTS)
    cs = plt.contour(x,y,disc_map.reshape((len(x),len(y))).transpose())
    plt.clabel(cs,inline=1,fontsize=10)
    #plt.colorbar()

    #draw samples
    with torch.no_grad():
        z = torch.randn(batchsz,2).cuda() #[b,2]
        samples = G(z).cpu().numpy() #[b,2]
    plt.scatter(xr[:,0],xr[:,1],c='orange',marker='.')
    plt.scatter(samples[:,0],samples[:,1],c='green',marker='+')

    viz.matplot(plt,win='contour',opts=dict(title='p(x):%d'%epoch))

def gradient_penalty(D,xr,xf):
    # WGAN-GP: penalize the critic's gradient norm on random interpolations
    # t: [b, 1], uniform in [0, 1)
    t = torch.rand(batchsz, 1).cuda()
    #[b,1] =>[b,2]
    t = t.expand_as(xr)
    # linear interpolation between real and fake samples
    mid = t * xr + (1 - t) * xf
    # the interpolation must require grad so D(mid) can be differentiated w.r.t. it
    mid.requires_grad_()

    pred = D(mid)
    grads = autograd.grad(outputs=pred, inputs=mid, grad_outputs=torch.ones_like(pred),
                          create_graph=True, retain_graph=True,
                          only_inputs=True)[0]  # keep the graph: the penalty itself is backpropagated later

    gp = torch.pow(grads.norm(2,dim=1)-1,2).mean()

    return gp
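
# For reference, the penalty above implements the WGAN-GP term
#     gp = E_{x_hat}[ (||grad_{x_hat} D(x_hat)||_2 - 1)^2 ],  x_hat = t*xr + (1-t)*xf,  t ~ U(0,1)
# i.e. the critic's gradient norm is pushed towards 1 along lines between real and fake samples.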

def main():
    torch.manual_seed(23)
    np.random.seed(23)

    data_iter = data_generator()
    x = next(data_iter)
    #[b,2]
    print(x.shape)

    G = Generator().cuda()
    D = Discriminator().cuda()
    # print the network structures
    print(G)
    print(D)
    optim_G = optim.Adam(G.parameters(), lr=5e-4, betas=(0.5, 0.9))
    optim_D = optim.Adam(D.parameters(), lr=5e-4, betas=(0.5, 0.9))
    # one visdom window with two loss curves (D and G)
    viz.line([[0, 0]], [0], win='loss', opts=dict(title='loss', legend=['D', 'G']))
    for epoch in range(50000):
        # 1. train the Discriminator first (5 critic steps per generator step)
        for _ in range(5):
            # 1.1 train on real data
            xr = next(data_iter)                  # numpy array, [b, 2]
            xr = torch.from_numpy(xr).cuda()      # convert to a GPU tensor
            # [b, 2] => [b]
            predr = D(xr)
            # maximize predr on real data, i.e. minimize -predr
            lossr = -predr.mean()

            # 1.2 train on fake data
            # z: [b, 2]
            z = torch.randn(batchsz, 2).cuda()
            xf = G(z).detach()                    # detach = tf.stop_gradient(): G is frozen in this step
            predf = D(xf)
            lossf = predf.mean()

            # 1.3 gradient penalty, computed on interpolations between real and fake data
            gp = gradient_penalty(D,xr,xf.detach())

            #aggregate all
            loss_D=lossr+lossf+gp

            #optimize
            optim_D.zero_grad()
            loss_D.backward()
            optim_D.step()

        # 2. train the Generator
        z = torch.randn(batchsz, 2).cuda()
        xf = G(z)
        predf = D(xf)
        loss_G = -predf.mean()   # maximize D(G(z)), hence the minus sign

        #optimize
        optim_G.zero_grad()
        loss_G.backward()
        optim_G.step()

        if epoch % 100 == 0:
            viz.line([[loss_D.item(),loss_G.item()]],[epoch],win='loss',update='append')
            print(loss_D.item(),loss_G.item())
            generate_image(D,G,xr,epoch)

if __name__ == '__main__':
    main()

# Start the visdom server first:  python -m visdom.server   (or simply:  visdom)
# Then open the dashboard at:
# http://localhost:8097
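
Before launching the full run, a quick CPU-only sanity check of the tensor shapes can help. The snippet below is only a sketch and not part of the script above: it assumes the definitions above (Generator, Discriminator, data_generator, batchsz) are already available in the Python session, and it skips CUDA and visdom entirely.

import torch

xr = torch.from_numpy(next(data_generator()))   # one real batch, [512, 2]
z = torch.randn(batchsz, 2)                     # latent noise, [512, 2]
G, D = Generator(), Discriminator()             # CPU instances
print(G(z).shape)    # expected: torch.Size([512, 2])
print(D(xr).shape)   # expected: torch.Size([512])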
