本文分享手动实现DCGAN生成动漫头像的Pytorch代码。
简单来说,DCGAN(Deep Convolutional GAN)就是用全卷积代替了原始GAN的全连接结构,提升了GAN的训练稳定性和生成结果质量。
我使用的数据集包含5万张96×96的动漫头像。
import torch
import torch.nn as nn
from torch.utils.data.dataloader import DataLoader
from torchvision import transforms, datasets
from torchvision.utils import save_image
import os
class D_Net(nn.Module):
    """DCGAN discriminator: maps a 3x96x96 image to one real/fake score in (0, 1)."""

    def __init__(self):
        super(D_Net, self).__init__()
        # 3 x 96 x 96 -> 64 x 32 x 32 (no BatchNorm on the input stage, per DCGAN)
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, 5, 3, 1, bias=False),
            nn.LeakyReLU(0.2, True),
        )
        # 64 x 32 x 32 -> 128 x 16 x 16
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, True),
        )
        # 128 x 16 x 16 -> 256 x 8 x 8
        self.conv3 = nn.Sequential(
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, True),
        )
        # 256 x 8 x 8 -> 512 x 4 x 4
        self.conv4 = nn.Sequential(
            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, True),
        )
        # 512 x 4 x 4 -> 1 x 1 x 1; Sigmoid squashes the score to a probability
        self.conv5 = nn.Sequential(
            nn.Conv2d(512, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Run the five conv stages in order; returns an (N, 1, 1, 1) tensor."""
        out = x
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            out = stage(out)
        return out

    # Discriminator parameter initialization (DCGAN default: N(0, 0.02) for convs).
    def d_weight_init(self, m):
        if isinstance(m, nn.Conv2d):
            nn.init.normal_(m.weight, mean=0, std=0.02)
class G_Net(nn.Module):
    """DCGAN generator: maps a 128x1x1 noise vector to a 3x96x96 image in [-1, 1]."""

    def __init__(self):
        super(G_Net, self).__init__()
        # 128 x 1 x 1 -> 512 x 4 x 4
        self.conv1 = nn.Sequential(
            nn.ConvTranspose2d(128, 512, 4, 1, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
        )
        # 512 x 4 x 4 -> 256 x 8 x 8
        self.conv2 = nn.Sequential(
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
        )
        # 256 x 8 x 8 -> 128 x 16 x 16
        self.conv3 = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
        )
        # 128 x 16 x 16 -> 64 x 32 x 32 (original comment said 128 channels; it is 64)
        self.conv4 = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
        )
        # 64 x 32 x 32 -> 3 x 96 x 96; Tanh matches inputs normalized to [-1, 1]
        self.conv5 = nn.Sequential(
            nn.ConvTranspose2d(64, 3, 5, 3, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, x):
        """Run the five transposed-conv stages; returns an (N, 3, 96, 96) tensor."""
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        return x

    # Generator parameter initialization (DCGAN default: N(0, 0.02) for convs).
    def g_weight_init(self, m):
        # BUG FIX: this network is built from ConvTranspose2d layers, which are
        # NOT instances of nn.Conv2d, so the original isinstance(m, nn.Conv2d)
        # check never matched and the init was silently a no-op.
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight, mean=0, std=0.02)
if __name__ == '__main__':
    batch_size = 225
    # Output folders for generated samples and checkpoints.
    if not os.path.exists("./dcgan_img"):
        os.mkdir("./dcgan_img")
    if not os.path.exists("./dcgan_params"):
        os.mkdir("./dcgan_params")
    # Normalize to [-1, 1] so real images match the generator's Tanh output range.
    img_transf = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5, ], [0.5, ])
    ])
    img_dir = r"C:\Cartoon_faces0.1"
    # ImageFolder builds the Dataset for us; no custom Dataset class needed.
    dataset = datasets.ImageFolder(img_dir, transform=img_transf)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    d_net = D_Net().to(device)
    g_net = G_Net().to(device)
    d_weight_file = r"dcgan_params/d_net.pth"
    g_weight_file = r"dcgan_params/g_net.pth"
    # Resume from checkpoints when present, otherwise apply the DCGAN init.
    # FIX: map_location=device makes resuming robust when the checkpoint was
    # saved on a different device (e.g. trained on GPU, resumed on CPU).
    if os.path.exists(d_weight_file) and os.path.getsize(d_weight_file) != 0:
        d_net.load_state_dict(torch.load(d_weight_file, map_location=device))
        print("加载判别器保存参数成功")
    else:
        d_net.apply(d_net.d_weight_init)
        print("加载判别器随机参数成功")
    if os.path.exists(g_weight_file) and os.path.getsize(g_weight_file) != 0:
        g_net.load_state_dict(torch.load(g_weight_file, map_location=device))
        print("加载生成器保存参数成功")
    else:
        g_net.apply(g_net.g_weight_init)
        print("加载生成器随机参数成功")
    loss_fn = nn.BCELoss()
    # Adam betas (0.5, 0.999) and lr 2e-4 are the DCGAN-paper defaults.
    d_opt = torch.optim.Adam(d_net.parameters(), lr=0.0002, betas=(0.5, 0.999))
    g_opt = torch.optim.Adam(g_net.parameters(), lr=0.0002, betas=(0.5, 0.999))
    epoch = 1
    while True:
        print("epoch--{}".format(epoch))
        for i, (x, y) in enumerate(loader):
            # ---- discriminator step: real images -> 1, generated images -> 0 ----
            real_img = x.to(device)
            real_label = torch.ones(x.size(0), 1, 1, 1).to(device)
            fake_label = torch.zeros(x.size(0), 1, 1, 1).to(device)
            real_out = d_net(real_img)
            d_real_loss = loss_fn(real_out, real_label)
            z = torch.randn(x.size(0), 128, 1, 1).to(device)
            # detach() cuts the graph so no generator gradients are computed here
            fake_img = g_net(z).detach()
            fake_out = d_net(fake_img)
            d_fake_loss = loss_fn(fake_out, fake_label)
            d_loss = d_real_loss + d_fake_loss
            d_opt.zero_grad()
            d_real_loss.backward()
            d_fake_loss.backward()
            d_opt.step()
            # ---- generator step: make the discriminator score fakes as real ----
            fake_img = g_net(z)
            fake_out = d_net(fake_img)
            g_loss = loss_fn(fake_out, real_label)
            g_opt.zero_grad()
            g_loss.backward()
            g_opt.step()
            # Once per epoch (at batch 100): log, dump a sample grid, checkpoint.
            if i == 100:
                print("d_loss:{:.3f}\tg_loss:{:.3f}\td_real:{:.3f}\td_fake:{:.3f}".
                      format(d_loss.item(), g_loss.item(), real_out.data.mean(), fake_out.data.mean()))
                fake_image = fake_img.cpu().data
                save_image(fake_image, "./dcgan_img/{}_{}-fake_img.jpg".
                           format(epoch, i), nrow=15, normalize=True, scale_each=True)
                torch.save(d_net.state_dict(), "dcgan_params/d_net.pth")
                torch.save(g_net.state_dict(), "dcgan_params/g_net.pth")
        epoch += 1
- 生成网络G和判别网络D结构几乎完全对称,G网络用转置卷积实现上采样,参数设置在我的另一篇文章中已解释。偶数卷积核在正常网络模型中很少见,但在生成模型中效果比较好,避免生成图像不均匀的现象。
- 两个网络的激活函数和输出函数、网络参数初始化、优化器参数的选择大都是DCGAN论文的默认值,是实验结果。
- real_label和fake_label就是全1和全0的值,判别器训练时,真图标签为1,假图标签为0。产生的真图假图两个loss,其实可以合成一个,进行一次backward()就行,但是实验发现分开效果会比较好。
- D网络训练时只需要正常判别输入图片是真图还是G网络生成的假图,而G网络则需要混淆视听,尽量提高生成假图在D网络的输出评分,互相对抗学习。关键代码为:
g_loss = loss_fn(fake_out, real_label)。
- 判别器和生成器交替训练,训练一个时,另一个的参数应该固定。这在Pytorch中不需要我们做什么处理,因为在优化器中已经给定了需要优化的参数,虽然backward()计算了两个网络的全部梯度,但step()只更新了对应参数。训练D网络时,生成的假图用detach()操作截断了计算图,即不计算G网络的梯度,没什么大用,只是略微节省了一点时间。
- 交替训练时可以指定两个网络的训练频率,比如D网络每个batch训练1次,G网络每个batch训练2次。但是这个比例怎么取比较好需要实验,我这老年机就放弃了……同学们有兴趣可以调一下试试。
一开始是酱的:
200 epochs later……
远观尚可,别细看!
本文实现的DCGAN是最基础的,需要自己小心调参。为了解决GAN训练不稳定以及生成器和判别器的训练平衡问题,同学们可以参考一下WGAN,改动很小,效果很好。