知识回顾:Pytorch实现Least Squares GAN

相比于CNN,RNN等网络,GAN最难理解的点是其损失函数的含义以及定义,由于GAN由鉴别器和生成器两部分组成,在实现的时候我们需要定义两个损失:

  • 鉴别器损失
  • 生成器损失

对于鉴别器D,我们要遵循一个思路便是真的为真,假的为假,要有强大的鉴别能力。对于生成器G我们需要让假的为真,实现浑水摸鱼和投机倒把。
这里不在进行数学知识的展示,有兴趣的小伙伴可以参考以下两个链接。

在开始之前有如下基本规则:对于M*N的矩阵,M代表样本数,N代表每个样本的特征维数
以下为实现Least Squares GAN的代码:

#导入所需模块
import torch
import torch.nn as nn
from torch.autograd import Variable
#我们今天不进行复发分布的训练,仅通过训练好的生成器生成来拟合如下数据
#real data
real_data = torch.tensor([[1.],
                          [1.],
                          [1.]])

#这里是我想用的噪声
my_noise = Variable(torch.randn(3,3))#不用Variable定义的话,我们可能无法进行反向传播和优化算法
#定义生成器和鉴别器
#这里我就偷懒了,生成器和鉴别器网络结构一致,因为我想要生成的数据非常简单
def D_model():
    net = nn.Sequential(        
            nn.Linear(1, 3),
            nn.LeakyReLU(0.2),
            nn.Linear(3, 6),
            nn.LeakyReLU(0.2),
            nn.Linear(6, 1)
        )
    return net

def G_model(my_noise):
    net = nn.Sequential(        
            nn.Linear(my_noise, 3),
            nn.LeakyReLU(0.2),
            nn.Linear(3, 6),
            nn.LeakyReLU(0.2),
            nn.Linear(6, 1)
        )
    return net
#实例化鉴别器和生成器
D_net = D_model()
G_net = G_model(my_noise.shape[1])
#损失函数的定义
#此处是重中之重
#这里的mean很关键,其可以让向量变成标量,实现backward
#也可以使用.sum(),但是没有mean最终的效果好
def discriminator_loss(real_score, fake_score):
    loss = 0.5 * ((real_score - 1) ** 2).mean() + 0.5 * (fake_score ** 2).mean()
    return loss

def generator_loss(fake_score):
    loss = 0.5 * ((fake_score - 1) ** 2).mean()
    return loss
#定义优化算法
# 使用 adam 优化算法
def get_optimizer(net):
    optimizer = torch.optim.Adam(net.parameters(), lr=1e-3, betas=(0.5, 0.999))
    return optimizer
#分别创建优化器
#创建优化器
D_optimizer = get_optimizer(D_net)
G_optimizer = get_optimizer(G_net)
#train model
for step in range(16001):
    #真实分数
    real_score =D_net(real_data)
    
    #noise产生假数据
    fake_data = G_net(my_noise)
    #由假数据产生假分数
    fake_score = D_net(fake_data)
    
    #d的loss
    d_loss = discriminator_loss(real_score,fake_score)
    #首先优化D
    D_optimizer.zero_grad()
    d_loss.backward()
    D_optimizer.step()
    
    #在优化G
    #此时如果不在重新生成假数据,编译器会报错
    #生成假数据
    fake_data = G_net(my_noise)
    #由假数据产生假分数
    fake_score = D_net(fake_data)
    #g_loss
    g_loss = generator_loss(fake_score)
    G_optimizer.zero_grad()
    g_loss.backward()
    G_optimizer.step()
    
    if (step%1000 ==0):
        print("step:{},G_loss:{},D_loss:{}".format(step,g_loss.data,d_loss.data))
#测试结果
G_net.eval()
my_noise = torch.randn(3,3)
test = G_net(my_noise)
print(test)

tensor([[1.0201],
        [1.0191],
        [1.0153]], grad_fn=<AddmmBackward>)

到这,今天回顾的内容也就结束了。
PS:写代码有时候还是很有趣的,但是今天导师又叫我去跑别人的代码,尝试复现论文的内容,难道就我一个人觉得调通别人写好的代码很难嘛?摸不清文件之间的关系,看不懂的起名方式,怎么翻译都无法理解的英语注释,还有最重要的一点,就是大神们高深莫测的逻辑。

--------------------------------------------------2021/3/15 22:16

  • 2
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值