GAN学习笔记 (2):pytorch实现naive GAN

GAN学习笔记 (2):pytorch实现naive GAN

我们这里做个demo,就不直接生成图片了,而是事先准备好一些“点”,以这些“点”来代替图片。我们训练一个GAN,看看训练出的这个GAN的Generator能不能拟合我们实现准备好的“点”的分布。我们这里准备一个8-Gaussian Mixture Distribution,但我们假装并不知道这些“点”的分布(因为我们并不知道高维空间中的图片符合什么分布),让GAN来学习出他们的分布。

先定两个变量:

h_dim = 400
batchsz = 512

1.数据生成

生成数据的代码如下,这些“点”就相当于real image:

def data_generator():
    # 8-gaussian mixture model
    scale = 2.
    centers = [
        (1, 0),
        (-1, 0),
        (0, 1),
        (0, -1),
        (1. / np.sqrt(2), 1. / np.sqrt(2)),
        (1. / np.sqrt(2), -1. / np.sqrt(2)),
        (-1. / np.sqrt(2), 1. / np.sqrt(2)),
        (-1. / np.sqrt(2), -1. / np.sqrt(2))
    ]
    centers = [(scale * x, scale * y) for x, y in centers]
    while True:
        dataset = []
        for i in range(batchsz):
            point = np.random.randn(2) * 0.02
            center = random.choice(centers)
            point[0] += center[0]
            point[1] += center[1]
            dataset.append(point)
        dataset = np.array(dataset, dtype='float32')
        dataset /= 1.414  # stdev
        yield dataset

2.模型搭建

这里我们就随便搞几个层来搭一个Generator和Discriminator:

class Generator(nn.Module):

    def __init__(self):
        super(Generator, self).__init__()

        self.net = nn.Sequential(
            # 这个2也可以换成变的,只不过是你noise特征的维度
            nn.Linear(2, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            # “点”是二维的,所以输出必须是2维
            nn.Linear(h_dim, 2),
        )

    def forward(self, z):
        output = self.net(z)
        return output


class Discriminator(nn.Module):

    def __init__(self):
        super(Discriminator, self).__init__()

        self.net = nn.Sequential(
            nn.Linear(2, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        output = self.net(x)
        return output.view(-1)

3.训练模型

首先我们得到数据生成器:

data_iter = data_generator()

根据上篇博客(Discriminator多训练,Generator少训练),我们训练五次Discriminator,一次Generator。下面看代码:

 G = Generator().cuda()
 D = Discriminator().cuda()

 optim_G = optim.Adam(G.parameters(), lr=1e-3, betas=(0.5, 0.9))
 optim_D = optim.Adam(D.parameters(), lr=1e-3, betas=(0.5, 0.9))

for epoch in range(50000):
    # train Discriminator
    for _ in range(5):
    ##############for real data###############    

        # 得到real data
        x = next(data_iter)
        xr = torch.from_numpy(x).cuda()
        # 打分
        predr = D(xr)
        # 给真实数据高分
        lossr = - (predr.mean())
        
    ##############for fake data###############
    
        # noise
         z = torch.randn(batchsz, 2).cuda() 
        # 生成的数据, 我们这时训练的是Discriminator不需要更新Generator的梯度
         xf = G(z).detach()
        # 打分
         predf = (D(xf))
        # 给生成的数据低分	
         lossf = (predf.mean())

    ##############for Discriminator###############

         loss_D = lossr + lossf
        
    ################update parameter#################    
         optim_D.zero_grad()
         loss_D.backward()
         optim_D.step()
      
    
    # train Generator
    z = torch.randn(batchsz, 2).cuda()
    xf = G(z)
    predf = D(xf)
    
    # 让Discriminator给fake数据打高分
    loss_G = - (predf.mean())
    
    optim_G.zero_grad()
    loss_G.backward()
    optim_G.step()
    
    
    if epoch % 100 == 0:
        print(loss_D.item(), loss_G.item())
    

至此,最naive的GAN的代码demo就全部完成了,下一篇讲讲WGAN解决的问题和WGAN的代码。

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值