2021-04-04

生成式对抗网络(GAN, Generative Adversarial Networks )是一种深度学习模型,是近年来复杂分布上无监督学习最具前景的方法之一。模型通过框架中(至少)两个模块:生成模型(Generative Model)和判别模型(Discriminative Model)的互相博弈学习产生相当好的输出。原始 GAN理论中,并不要求 G 和 D 都是神经网络,只需要是能拟合相应生成和判别的函数即可。但使用中一般均使用深度神经网络作为 G 和 D 。上文定义摘自百度百科,通过定义,我们知道GAN的组件包括,G,D两个网络,但从定义没有具体说明G,D两个网络是如何相互作用,本文通过实现一维GAN拟合二次函数,理解G,D网络相互作用原理。

from numpy import zeros
from numpy import ones
from numpy import hstack
from numpy.random import rand
from numpy.random import randn

import numpy as np
from torch import nn
from tqdm import tqdm
import torch
import matplotlib.pyplot as plt
def summarize_performance(generator, latent_dim, n=100):
    x_real, y_real = generate_real_samples(n)
    l_x = generate_latent_points(latent_dim, 100)
    l_x = torch.from_numpy(l_x).to(device).float()
    x_fake = generator(l_x)
    x_fake = x_fake.cpu().detach().numpy()
    plt.scatter(x_real[:, 0], x_real[:, 1], color='red')
    plt.scatter(x_fake[:, 0], x_fake[:, 1], color='blue')
    plt.show()
def generate_real_samples(n):
    X1 = rand(n) - 0.5
    X2 = X1 * X1
    X1 = X1.reshape(n, 1)
    X2 = X2.reshape(n, 1)
    X = hstack((X1, X2))
    y = ones((n, 1))
    return X, y
##生成隐语意向量
def generate_latent_points(latent_dim, n):
    x_input = randn(latent_dim * n)
    x_input = x_input.reshape(n, latent_dim)
    
    return x_input
    
## 使用G网络生成虚假的点
def generate_fake_samples(generator, latent_dim, n=1000):
    f_x = generate_latent_points(latent_dim, n)
    f_y = zeros((n, 1))
    f_x = torch.from_numpy(f_x).to(device).float()
    f_y = torch.from_numpy(f_y).to(device).long()
    f_x = generator(f_x)
    return f_x, f_y
class discriminatorModel(nn.Module):
    def __init__(self, input_features:int=2, output_features:int=2):
        super(discriminatorModel, self).__init__()
        self.linear1 = nn.Linear(input_features, 20)
        self.linear2 = nn.Linear(20, output_features)
        self.__initweight__()
    def __initweight__(self):
        nn.init.xavier_uniform_(self.linear1.weight, gain=nn.init.calculate_gain('relu'))
        nn.init.xavier_uniform_(self.linear2.weight, gain=nn.init.calculate_gain('relu'))
    def forward(self, x):
        x = self.linear1(x)
        x = torch.relu(x)
        x = self.linear2(x)
        return x
class generatorModel(nn.Module):
    def __init__(self, input_features:int=5, output_features:int=2):
        super(generatorModel, self).__init__()
        self.linear1 = nn.Linear(input_features, 25)
        self.linear2 = nn.Linear(25, 2)
        self.__initweight__()
    def __initweight__(self):
        nn.init.xavier_uniform_(self.linear1.weight, gain=nn.init.calculate_gain('relu'))
        nn.init.xavier_uniform_(self.linear2.weight, gain=nn.init.calculate_gain('relu'))
    def forward(self, x):
        x = self.linear1(x)
        x = torch.relu(x)
        x = self.linear2(x)
        return x
g = generatorModel()
d = discriminatorModel()
EPOCHS = 36000




criterion = nn.CrossEntropyLoss()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device = 'cpu'
d_optim = torch.optim.Adam(d.parameters(), lr=2e-3)
g_optim = torch.optim.Adam(g.parameters(), lr=2e-3)
d = d.to(device)
g = g.to(device)
pbar = tqdm(range(EPOCHS))
losses = []
for epoch in pbar:

    g_optim.zero_grad()
    d_optim.zero_grad()
    ## 生成真实点
    r_x, r_y = generate_real_samples(1000)
    r_x = torch.from_numpy(r_x).to(device).float()
    r_y = torch.from_numpy(r_y).to(device).long()
    r_preds = d(r_x)
    loss1 = criterion(r_preds, r_y.view(-1))
    ## 通过G网络生成虚假点
    f_x, f_y = generate_fake_samples(g, 5, 1000)
    f_preds = d(f_x)
    loss2 = criterion(f_preds, f_y.view(-1))
    loss = loss1 + loss2
    loss.backward()
    d_optim.step()
    losses.append(loss.cpu().detach().numpy())
    
    ## 通过隐语意向量
    l_x = generate_latent_points(5, 1000)
    l_x = torch.from_numpy(l_x).to(device).float()
    l_y = torch.ones(1000).to(device).long()
    l_x = g(l_x)
    l_preds = d(l_x)
    loss = criterion(l_preds, l_y)
    loss.backward()
    g_optim.step()
    pbar.set_description(f'losses {np.mean(losses):.3f}, loss: {loss.cpu().detach().numpy():.3f}')
    if epoch % 2000 == 0:
        summarize_performance(g, 5)

生成结果如下图,蓝色为拟合的二次曲线。有时候生成效果不好,可能也会生成一大堆点。
在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值
>