GAN实现
1.概述
在Encoder-Decoder结构中,Encoder提取数据特征,Decoder还原数据特征。但是,Decoder输入特征的分布必须得统一分布,否则生成的图片是破碎的。
这很好理解:
对于
MINST
分类任务来说,Encoder输出的是MNIST
数据集上的所有数据的特征,而这个特征分布在这个数据集上,如果我们输入的特征并不是在这个特征分布上,那么Decoder还原的数据肯定也不是在原本的数据里。
就像给一个没见过狗的画家叫他画狗,你可以给他描述狗的特征,但是画家画的狗也只是一个各种动物的组合体。
理论上,我们可以让网络学到一个数据变换,让符合正态分布的数据经过网络后能自动变为符合生成任务的特征分布(VAE,变分自编码,相对熵)。
对于GAN来说确实另一个思路,通过Generator
和Discriminator
在训练过程中的博弈,既要使得Generator
生成的图像能够使得分辨不出真伪,也要使得Discriminator
能够完美分辨出是真实的图片还Generator
生成的图片。当然,这是理想情况😄。
2.GAN
对于判别器(Discriminator
)来说,数据标签只有两类,真实数据为正例,生成数据为负例。而对于生成器(Generator
)来说数据没有标签,使用判别器输出的结果作为loss,这样判别器就能给生成器输出的结果一个反馈了。而生成器通过这一反馈来调整,使得最后输出能够’骗过’判别器。
生成器在训练和使用过程中必须从同一个分布进行采样,这是因为生成器在训练过程中能够学到采样的数据和需要生成数据之间的映射关系,倘如数据分布变了,这个映射关系就失效了。
使用MNIST
数据集,定义两个网络:
1️⃣判别网络
class DNet(nn.Module):
def __init__(self):
super(DNet, self).__init__()
self.dnet=nn.Sequential(
nn.Linear(784,512),
nn.LeakyReLU(),
nn.Linear(512,1024),
nn.LeakyReLU(),
nn.Linear(1024,256),
nn.LeakyReLU(),
nn.Linear(256,2),#二分类,真、假
nn.Sigmoid(),#0-1
# nn.Softmax()
)
def forward(self,x):
out=self.dnet(x)
return out
2️⃣生成模型
class GNet(nn.Module):
def __init__(self):
super(GNet, self).__init__()
self.gnet=nn.Sequential(
nn.Linear(128,512),
nn.LeakyReLU(),
nn.Linear(512,1024),
nn.LeakyReLU(),
nn.Linear(1024,512),
nn.LeakyReLU(),
nn.Linear(512, 784),
nn.ReLU(), #像素值是[0,255]
)
def forward(self,x):
return self.gnet(x)
3️⃣定义训练过程
if __name__ == '__main__':
dataset = datasets.MNIST('F:\Dataset', train=True, transform=torchvision.transforms.ToTensor(), download=False)
train_data = DataLoader(dataset, batch_size=100, shuffle=True)
d_net = DNet().cuda()
g_net = GNet().cuda()
loss = nn.BCELoss()
fake_img_save=0.
dnet_opt = torch.optim.Adam(d_net.parameters(), lr=0.0002, betas=(0.5, 0.999))
gnet_opt = torch.optim.Adam(g_net.parameters(), lr=0.0002, betas=(0.5, 0.999))
summary=SummaryWriter('logs')
for epoch in range(100):
for i, (img, _) in enumerate(train_data):
real_img = img.reshape(-1, 784).cuda()
real_label = torch.ones(real_img.shape[0],1).cuda()
fake_label = torch.zeros(real_img.shape[0],1).cuda()
# 训练dnet
# real loss
real_out = d_net(real_img)
real_loss = loss(real_out, real_label)
# fake loss
g_data1 = torch.randn(real_img.shape[0], 128).cuda()
# ++++++++++++++++++++++#
# 生成fake img
fake_img = g_net(g_data1)
# g网络判断
fake_out1 = d_net(fake_img)
# ++++++++++++++++++++++#
fake_loss = loss(fake_out1, fake_label)
loss_d = real_loss + fake_loss
dnet_opt.zero_grad()
loss_d.backward()
dnet_opt.step()
# 训练gnet
g_data2 = torch.randn(real_img.shape[0], 128).cuda()
# ++++++++++++++++++++++#
fake_img = g_net(g_data2)
# 拿fake img给dnet判断
fake_out2 = d_net(fake_img)
# ++++++++++++++++++++++#
# 损失:生成数据与真实数据做损失
loss_g = loss(fake_out2, real_label)
gnet_opt.zero_grad()
loss_g.backward()
gnet_opt.step()
summary.add_scalars('loss',{'g_loss':loss_g,'d_loss':loss_d})
if i %100==1:
print("dnet loss=>",loss_d.item())
print('gnet loss=>',loss_g.item())
print('='*25)
fake_img_save=fake_img.reshape(-1,1,28,28)#[1,784]->[1,28,28]
save_image(fake_img_save,'gan_mnist_img/fake_{}.jpg'.format(epoch),nrow=10,normalize=True,scale_each=True)
torch.save(g_net.state_dict(), 'GNet/gnet.pt')
最后生成在训练过程中的GNet的输出如下:
红色为判别器,蓝色为生成器的损失图像。
随机生成一个正态分布的数,得到如下图像:
从结果来看,GAN似乎只会输出数字0,初步分析是判别器能力太强了。
将网络重新定义:
class DNet(nn.Module):
def __init__(self):
super(DNet, self).__init__()
self.dnet = nn.Sequential(
nn.Linear(784, 512),
nn.LeakyReLU(),
nn.Linear(512, 256),
nn.LeakyReLU(),
# nn.Linear(256, 512),
# nn.LeakyReLU(),
# nn.Linear(512, 256),
# nn.LeakyReLU(),
nn.Linear(256, 1), # 二分类,真、假
nn.Sigmoid(), # 0-1
# nn.Softmax()
)
def forward(self, x):
out = self.dnet(x)
return out
class GNet(nn.Module):
def __init__(self):
super(GNet, self).__init__()
self.gnet = nn.Sequential(
nn.Linear(128, 256),
nn.LeakyReLU(),
nn.Linear(256, 512),
nn.LeakyReLU(),
# nn.Linear(512, 1024),
# nn.LeakyReLU(),
# nn.Linear(1024, 512),
# nn.LeakyReLU(),
nn.Linear(512, 784),
# nn.ReLU(), # 像素值是[0,255]
)
def forward(self, x):
return self.gnet(x)
重新训练一遍后,从标准正态分布中采样得到随机数,输入生成器中得到结果:
不过似乎效果也不咋样,只会输出3
,8
,0
。
对比整个训练可以看出,GNet和DNet的训练过程是一个此消彼长的过程。
3.DCGAN
DCGAN是对GAN的改进,主要是将全连接变为了卷积。除此之外,还:
- 将图像归一化到[-1,1]之间;
- 参数初始化为 N ( 0 , 0.02 ) N(0,0.02) N(0,0.02)分布;
- LeakyReLU的斜率为0.02;
- 用卷积代替池化;
- 去掉sigmoid激活
1️⃣定义判别网络
class Conv_bn_leak(nn.Module):
def __init__(self, in_channel, out_channel, ksize, stride, padding, bias=False):
super(Conv_bn_leak, self).__init__()
self.conv_bn_leak = nn.Sequential(
nn.Conv2d(in_channel, out_channel, ksize, stride, padding, bias=bias),
nn.BatchNorm2d(out_channel),
nn.LeakyReLU(0.2)
)
class DNet(nn.Module):
def __init__(self):
super(DNet, self).__init__()
self.dnet = nn.Sequential(
Conv_bn_leak(1, 64, 5, 3, padding=1),
Conv_bn_leak(64, 128, 4, 2, padding=1),
Conv_bn_leak(128, 256, 4, 2, padding=1),
Conv_bn_leak(256, 512, 4, 2, padding=1),
nn.Conv2d(512, 1, 4, 1, padding=0, bias=False) # 不激活
)
def forward(self, x):
out = self.dnet(x)
return out.reshape(-1) # [NCHW]->[N]
2️⃣定义生成网络
class ConvT_bn_relu(nn.Module):
def __init__(self, in_chanel, out_channel, ksize, stride, padding, bias=False):
super(ConvT_bn_relu, self).__init__()
self.convT_bn_leak = nn.Sequential(
nn.ConvTranspose2d(in_chanel, out_channel, kernel_size=ksize,stride=stride, padding=padding, bias=bias),
nn.BatchNorm2d(out_channel),
nn.ReLU()
)
def forward(self,x):
return self.convT_bn_leak(x)
# GNet和DNet对称
class GNet(nn.Module):
def __init__(self):
super(GNet, self).__init__()
self.gnet = nn.Sequential(
ConvT_bn_relu(128, 512, 4, 1, padding=0),
ConvT_bn_relu(512, 256, 4, 2, padding=1),
ConvT_bn_relu(256, 128, 4, 2, padding=1),
ConvT_bn_relu(128, 64, 4, 2, padding=1),
nn.ConvTranspose2d(64, 1, 5, 3, padding=1, bias=False),
nn.Tanh()
)
def forward(self, x):
return self.gnet(x)
3️⃣定义DCGAN
由于判别网络最后没有经过sigmoid激活,所以不能使用BCELoss。使用nn.BCEWithLogitsLoss()
,公式如下:
ℓ
c
(
x
,
y
)
=
L
c
=
{
l
1
,
c
,
…
,
l
N
,
c
}
⊤
,
l
n
,
c
=
−
w
n
,
c
[
p
c
y
n
,
c
⋅
log
σ
(
x
n
,
c
)
+
(
1
−
y
n
,
c
)
⋅
log
(
1
−
σ
(
x
n
,
c
)
)
]
\ell_c(x, y) = L_c = \{l_{1,c},\dots,l_{N,c}\}^\top, \quad l_{n,c} = - w_{n,c} \left[ p_c y_{n,c} \cdot \log \sigma(x_{n,c}) + (1 - y_{n,c}) \cdot \log (1 - \sigma(x_{n,c})) \right]
ℓc(x,y)=Lc={l1,c,…,lN,c}⊤,ln,c=−wn,c[pcyn,c⋅logσ(xn,c)+(1−yn,c)⋅log(1−σ(xn,c))]
而BCELoss的公式如下:
ℓ
(
x
,
y
)
=
L
=
{
l
1
,
…
,
l
N
}
⊤
,
l
n
=
−
w
n
[
y
n
⋅
log
x
n
+
(
1
−
y
n
)
⋅
log
(
1
−
x
n
)
]
\ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad l_n = - w_n \left[ y_n \cdot \log x_n + (1 - y_n) \cdot \log (1 - x_n) \right]
ℓ(x,y)=L={l1,…,lN}⊤,ln=−wn[yn⋅logxn+(1−yn)⋅log(1−xn)]
class DCGAN(nn.Module):
def __init__(self):
super(DCGAN, self).__init__()
self.gnet = GNet().cuda()
self.dnet = DNet().cuda()
self.loss_fn = nn.BCEWithLogitsLoss()
def forward(self, x):
return self.gnet(x)
def get_d_loss(self, noise, real_img):
real_y = self.dnet(real_img)
fake_img = self.gnet(noise)
fake_y = self.dnet(fake_img)
real_label = torch.ones(real_img.shape[0]).cuda()
fake_label = torch.zeros(real_img.shape[0]).cuda()
loss_real = self.loss_fn(real_y,real_label)
loss_fake = self.loss_fn(fake_y,real_label,)
loss=loss_fake+loss_real
return loss
def get_g_loss(self,noise):
fake_img=self.gnet(noise)
real_label=torch.ones(noise.shape[0]).cuda()
fake_y=self.dnet(fake_img)
return self.loss_fn(fake_y,real_label)
BCEWithLogitsLoss参数位置不要搞混,否则输出的是错误的值。
if __name__ == '__main__':
gan = DCGAN()
dataset = FaceDate('x:\Cartoon_faces')
train_data = DataLoader(dataset, shuffle=True, batch_size=100,drop_last=True)
opt_d=opt.Adam(gan.dnet.parameters(),lr=0.0002,betas=(0.5,0.999))
opt_g = opt.Adam(gan.gnet.parameters(), lr=0.0002, betas=(0.5, 0.999))
# opt_d = opt.Adam(gan.dnet.parameters(), lr=0.0002)#=>dloss偏小
# opt_g = opt.Adam(gan.gnet.parameters(), lr=0.0002)
summary = SummaryWriter('logs')
for epoch in range(100):
for i, img in enumerate(train_data):
real_img = img.cuda()
noise = torch.normal(0, 0.02, (100, 128, 1, 1)).cuda() # 正态分布
loss_d = gan.get_d_loss(noise, real_img)
opt_d.zero_grad()
loss_d.backward()
opt_d.step()
loss_g = gan.get_g_loss(noise)
opt_g.zero_grad()
loss_g.backward()
opt_g.step()
summary.add_scalars('loss', {'g_loss': loss_g, 'd_loss': loss_d})
if i % 100 == 0:
print("dnet loss=>", loss_d.item())
print('gnet loss=>', loss_g.item())
print('=' * 25)
noise = torch.normal(0, 0.02, (100, 128, 1, 1)).cuda()
fake_img = gan(noise)
save_image(fake_img, 'dcgan_gen_img/fake_{}.jpg'.format(epoch), nrow=10, normalize=True,scale_each=True)
torch.save(gan.gnet.state_dict(), 'save/dcgan_g.pt')
使用二次元图片作为数据集,经过35轮,结果如下。
似乎就是把数据集中各种数据的特征缝合起来😕,比如眼睛异色的概率很大,显然网络没有学到脸部特征的信息。