代码地址:Yuki0614/Series-of-GAN-for-Anime (github.com)
数据集地址:Anime Faces | Kaggle
FFT - Fast File Transfer Service (iqiyi.com) 密码: 5Kv2M1
Cartoon Face Recognition: A Benchmark Dataset (arxiv.org)
最近在学习GAN网络,感觉挺有意思的。GAN网络的系列有很多,包括GAN,CGAN,WGAN,CycleGAN等,差别在于训练方式,网络结构,不同任务(风格迁移,多分类图像生成等),在知乎和csdn上找了一些代码,然后找了个动漫图像的数据集复现了一下,想来和大伙交流学习一下。
第一个数据集是无类别的动漫头像,第二是有类别的,可以用于CGAN等网络。数据集比较大,我是选了一部分数据跑的,链接放上面了。
原始GAN网络结构由生成器和判别器组成。生成器输入噪声输出图片,其目的是生成尽可能接近真实图片的图片,来骗过判别器,故其损失函数为判别器对其生成图片真伪的判断和真实图片的标签(全1)构成的损失函数。判别器输入图片,输出真伪(0到1的概率),其目的是尽可能区分是真实图片还是生成器生成的图片,故其损失函数为真实图片标签(1)与输出概率,生成图片标签(0)与输出概率,两部分构成。在GAN中,生成器和判别器均由全连接网络构成。理想状态,经过不断训练,生成器可以生成接近真实的图片,判别器对所有输入都认为概率为0.5(无法分类)。
下面是原始GAN网络的代码,其他网络的放在github了,以后还会继续学习一些新的网络,弄明白以后就会更新。
网络结构:
import torch.nn as nn class Generator(nn.Module): def __init__(self): super(Generator, self).__init__() self.main = nn.Sequential( nn.Linear(100, 512), nn.ReLU(), nn.Linear(512, 1024), nn.ReLU(), nn.Linear(1024, 64*64*3), # 数据集图片大小为64*64,3通道 nn.Tanh() # 最后必须用tanh,把数据分布到(-1,1)之间 ) def forward(self, x): # x表示长度为100的噪声输入 img = self.main(x) img = img.view(-1,3,64,64) return img class Discriminator(nn.Module): def __init__(self): super(Discriminator,self).__init__() self.main = nn.Sequential( nn.Linear(64*64*3, 1024), nn.LeakyReLU(), # x小于零是是一个很小的值不是0,x大于0是还是x nn.Linear(1024,512), nn.LeakyReLU(), nn.Linear(512,1), nn.Sigmoid() # 保证输出范围为(0,1)的概率 ) def forward(self, x): img = x.view(-1, 64*64*3) img = self.main(img) return img
训练代码:
import torch import torch.nn as nn from torchvision import transforms from create_dataset import My_dataset, save_img from torch.utils.data import DataLoader from model64 import Generator, Discriminator # 图像变换 transform = transforms.Compose([ transforms.Resize((64, 64)), # 网络设置图片大小为 64*64,保证图片大小符合网络结构要求 transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) dataset = My_dataset(r'E:\work\GAN\1', transform=transform) # 数据集位置 batch_size, epochs = 32, 500 my_dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, drop_last=True) discriminator = Discriminator() generator = Generator() if torch.cuda.is_available(): discriminator = discriminator.cuda() generator = generator.cuda() d_optimizer = torch.optim.Adam(discriminator.parameters(), betas=(0.5, 0.99), lr=1e-4) # betas为adam算法两个动量参数 g_optimizer = torch.optim.Adam(generator.parameters(), betas=(0.5, 0.99), lr=1e-4) criterion = nn.BCELoss() for epoch in range(epochs): for i, img in enumerate(my_dataloader): noise = torch.randn(batch_size, 100).cuda() # 随机噪声作为输入 real_img = img.cuda() fake_img = generator(noise) real_out = discriminator(real_img) fake_out = discriminator(fake_img) real_label = torch.ones_like(real_out).cuda() # 真实图片标签为1,生成图片标签为0 fake_label = torch.zeros_like(fake_out).cuda() real_loss = criterion(real_out, real_label) fake_loss = criterion(fake_out, fake_label) d_loss = real_loss + fake_loss # 训练判别器 d_optimizer.zero_grad() d_loss.backward() d_optimizer.step() noise = torch.randn(batch_size, 100).cuda() fake_img = generator(noise) output = discriminator(fake_img) g_loss = criterion(output, real_label) # 训练生成器 g_optimizer.zero_grad() g_loss.backward() g_optimizer.step() if (i + 1) % 5 == 0: print('Epoch[{}/{}],d_loss:{:.6f},g_loss:{:.6f} ' 'D_real: {:.6f},D_fake: {:.6f}'.format( epoch, epochs, d_loss.data.item(), g_loss.data.item(), real_out.data.mean(), fake_out.data.mean() # 打印的是真实图片的损失均值 )) if epoch == 0 and i == len(my_dataloader) - 1: # 保存真实图像 save_img(img[:64, :, :, :], './sample/real_images.png') if (epoch+1) % 50 == 0 and i == len(my_dataloader)-1: # 每50个epoch保存一次预测图像 save_img(fake_img[:64, :, :, :], './sample/fake_images_{}.png'.format(epoch + 1)) torch.save(generator.state_dict(), './generator.pth') # 保存权重文件 torch.save(discriminator.state_dict(), './discriminator.pth')
数据集工具:
import os import numpy as np import torch from PIL import Image from torchvision.utils import make_grid import torch.utils.data.dataset as Dataset class My_dataset(Dataset.Dataset): def __init__(self, path, transform): self.path = path loc_list = os.listdir(self.path) self.loc_list = loc_list self.tranform = transform def __getitem__(self, index): loc_data = os.path.join(self.path, self.loc_list[index]) data = Image.open(loc_data) data = np.array(data) data = Image.fromarray(data) # 若导入图像不为PIL格式则需要转换 data = self.tranform(data) return data def __len__(self): return len(self.loc_list) def save_img(tensor, fp): grid = make_grid(tensor) ndarr = (grid.mul(0.5).add_(0.5)).mul(255).permute(1, 2, 0).to('cpu', torch.uint8).numpy() im = Image.fromarray(ndarr, mode='RGB') im.save(fp)
训练效果:GAN50epoch
GAN200epoch
对比了其他网络训练的效果:
WGAN
DCGAN
以上是由同一数据集不同GAN网络学习生成结果
感觉好像是有点差别的,不知道为什么GAN的图像好像有很多噪声?而且画风都有些不一样了,不过DCGAN和WGAN确实要比GAN效果好些,我这里没有用完整数据集,而且训练epoch比较少,增加训练,或者调大网络结构效果可能会更好些。