相比于CNN,RNN等网络,GAN最难理解的点是其损失函数的含义以及定义,由于GAN由鉴别器和生成器两部分组成,在实现的时候我们需要定义两个损失:
- 鉴别器损失
- 生成器损失
对于鉴别器D,我们要遵循一个思路便是真的为真,假的为假,要有强大的鉴别能力。对于生成器G我们需要让假的为真,实现浑水摸鱼和投机倒把。
这里不在进行数学知识的展示,有兴趣的小伙伴可以参考以下两个链接。
在开始之前有如下基本规则:对于M*N的矩阵,M代表样本数,N代表每个样本的特征维数
以下为实现Least Squares GAN的代码:
#导入所需模块
import torch
import torch.nn as nn
from torch.autograd import Variable
#我们今天不进行复发分布的训练,仅通过训练好的生成器生成来拟合如下数据
#real data
real_data = torch.tensor([[1.],
[1.],
[1.]])
#这里是我想用的噪声
my_noise = Variable(torch.randn(3,3))#不用Variable定义的话,我们可能无法进行反向传播和优化算法
#定义生成器和鉴别器
#这里我就偷懒了,生成器和鉴别器网络结构一致,因为我想要生成的数据非常简单
def D_model():
net = nn.Sequential(
nn.Linear(1, 3),
nn.LeakyReLU(0.2),
nn.Linear(3, 6),
nn.LeakyReLU(0.2),
nn.Linear(6, 1)
)
return net
def G_model(my_noise):
net = nn.Sequential(
nn.Linear(my_noise, 3),
nn.LeakyReLU(0.2),
nn.Linear(3, 6),
nn.LeakyReLU(0.2),
nn.Linear(6, 1)
)
return net
#实例化鉴别器和生成器
D_net = D_model()
G_net = G_model(my_noise.shape[1])
#损失函数的定义
#此处是重中之重
#这里的mean很关键,其可以让向量变成标量,实现backward
#也可以使用.sum(),但是没有mean最终的效果好
def discriminator_loss(real_score, fake_score):
loss = 0.5 * ((real_score - 1) ** 2).mean() + 0.5 * (fake_score ** 2).mean()
return loss
def generator_loss(fake_score):
loss = 0.5 * ((fake_score - 1) ** 2).mean()
return loss
#定义优化算法
# 使用 adam 优化算法
def get_optimizer(net):
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3, betas=(0.5, 0.999))
return optimizer
#分别创建优化器
#创建优化器
D_optimizer = get_optimizer(D_net)
G_optimizer = get_optimizer(G_net)
#train model
for step in range(16001):
#真实分数
real_score =D_net(real_data)
#noise产生假数据
fake_data = G_net(my_noise)
#由假数据产生假分数
fake_score = D_net(fake_data)
#d的loss
d_loss = discriminator_loss(real_score,fake_score)
#首先优化D
D_optimizer.zero_grad()
d_loss.backward()
D_optimizer.step()
#在优化G
#此时如果不在重新生成假数据,编译器会报错
#生成假数据
fake_data = G_net(my_noise)
#由假数据产生假分数
fake_score = D_net(fake_data)
#g_loss
g_loss = generator_loss(fake_score)
G_optimizer.zero_grad()
g_loss.backward()
G_optimizer.step()
if (step%1000 ==0):
print("step:{},G_loss:{},D_loss:{}".format(step,g_loss.data,d_loss.data))
#测试结果
G_net.eval()
my_noise = torch.randn(3,3)
test = G_net(my_noise)
print(test)
tensor([[1.0201],
[1.0191],
[1.0153]], grad_fn=<AddmmBackward>)
到这,今天回顾的内容也就结束了。
PS:写代码有时候还是很有趣的,但是今天导师又叫我去跑别人的代码,尝试复现论文的内容,难道就我一个人觉得调通别人写好的代码很难嘛?摸不清文件之间的关系,看不懂的起名方式,怎么翻译都无法理解的英语注释,还有最重要的一点,就是大神们高深莫测的逻辑。
--------------------------------------------------2021/3/15 22:16