For the theory, see: "Hung-yi Lee Machine Learning: Generative Adversarial Networks (GAN)" on iwill323's CSDN blog
Task and dataset
1. Input: random numbers, with shape (batch size, number of features)
2. Output: anime character faces
3. Implementation requirement: DCGAN & WGAN & WGAN-GP
4. Target: generate 1,000 anime character faces
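To make the input shape in item 1 concrete, here is a minimal pure-Python sketch that draws a batch of latent vectors from a standard normal distribution. The helper name `sample_latent` is hypothetical, and the latent size of 100 is taken from the comment in the Generator code further down. In the actual notebook, a call such as `torch.randn(batch_size, z_dim)` would normally do this job.

```python
import random

def sample_latent(batch_size, z_dim, seed=None):
    """Return a batch of latent vectors as a list of lists, shape (batch_size, z_dim)."""
    rng = random.Random(seed)  # local RNG so the global random state is untouched
    return [[rng.gauss(0.0, 1.0) for _ in range(z_dim)] for _ in range(batch_size)]

z = sample_latent(batch_size=4, z_dim=100, seed=2022)
print(len(z), len(z[0]))  # batch size, number of features
```

Each row is one random input vector that the generator would turn into one image.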
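The data comes from the Crypko website and contains 71,314 images. It can be downloaded via the post "Hung-yi Lee 2022 Machine Learning HW6 Walkthrough" on 机器学习手艺人's CSDN blog.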
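Evaluation metrics
FID (Fréchet Inception Distance)
Real and generated images are fed into another model to extract features, and the distance between the real and generated feature distributions is computed.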
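The distance in question is the Fréchet distance between two Gaussians fitted to the feature sets: d² = ||μ_r − μ_g||² + Tr(Σ_r + Σ_g − 2(Σ_r Σ_g)^(1/2)). The sketch below is a simplified illustration only. It assumes diagonal covariance matrices, in which case the trace term reduces to a sum of (σ_r − σ_g)² over feature dimensions. Real FID uses features from an Inception network and full covariance matrices. The function name `frechet_distance_diag` is hypothetical.

```python
import math
from statistics import fmean, pvariance

def frechet_distance_diag(real_feats, fake_feats):
    """Squared Fréchet distance between two Gaussians, assuming diagonal covariance.

    real_feats, fake_feats: lists of feature vectors of equal length.
    """
    dims = len(real_feats[0])
    total = 0.0
    for d in range(dims):
        real_col = [f[d] for f in real_feats]
        fake_col = [f[d] for f in fake_feats]
        mu_r, mu_g = fmean(real_col), fmean(fake_col)
        sd_r = math.sqrt(pvariance(real_col))
        sd_g = math.sqrt(pvariance(fake_col))
        # mean term + diagonal covariance term for this dimension
        total += (mu_r - mu_g) ** 2 + (sd_r - sd_g) ** 2
    return total

real = [[0.0, 1.0], [2.0, 3.0]]
fake = [[1.0, 2.0], [3.0, 4.0]]  # same spread, every feature shifted by 1
print(frechet_distance_diag(real, real))  # identical sets give 0.0
print(frechet_distance_diag(real, fake))  # only the mean term contributes
```

A smaller value means the generated feature distribution is closer to the real one.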
AFD (Anime face detection) rate
1. Detects how many anime faces appear in your submission
2. The higher, the better
Code
Imports
# import module
import os
import glob
import random
from datetime import datetime
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch import optim
from torch.utils.data import Dataset, DataLoader
from torch import autograd
from torch.autograd import Variable
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import logging
from tqdm import tqdm
# seed setting
def same_seeds(seed):
    # Python built-in random module
    random.seed(seed)
    # Numpy
    np.random.seed(seed)
    # Torch
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

same_seeds(2022)
workspace_dir = '../input'
Building the dataset
Note that fnames is a list of file paths. Unlike the original code, Image.open() is used here to read the images.
# prepare for CrypkoDataset
class CrypkoDataset(Dataset):
    def __init__(self, fnames, transform):
        self.transform = transform
        self.fnames = fnames
        self.num_samples = len(self.fnames)

    def __getitem__(self, idx):
        fname = self.fnames[idx]
        img = Image.open(fname)
        img = self.transform(img)
        return img

    def __len__(self):
        return self.num_samples

def get_dataset(root):
    # glob.glob returns the list of files matching the given wildcard pattern
    fnames = glob.glob(os.path.join(root, '*'))  # list
    transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ])
    dataset = CrypkoDataset(fnames, transform)
    return dataset
Display some images
temp_dataset = get_dataset(os.path.join(workspace_dir, 'faces'))
# The dataset normalizes images to [-1, 1]; map them back to [0, 1] for display
images = [(temp_dataset[i] + 1) / 2 for i in range(4)]
grid_img = torchvision.utils.make_grid(images, nrow=4)
plt.figure(figsize=(10, 10))
plt.imshow(grid_img.permute(1, 2, 0))
plt.show()
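get_dataset normalizes every channel with mean 0.5 and std 0.5, so the tensors it returns lie in [-1, 1], whereas plt.imshow expects float RGB values in [0, 1]. The scalar sketch below shows the mapping and its inverse. The helper names `normalize` and `denormalize` are hypothetical and only mirror the arithmetic that transforms.Normalize applies per element.

```python
def normalize(x, mean=0.5, std=0.5):
    """Map a pixel value from [0, 1] to [-1, 1], as Normalize(0.5, 0.5) does."""
    return (x - mean) / std

def denormalize(y, mean=0.5, std=0.5):
    """Inverse mapping: bring a value from [-1, 1] back to [0, 1] for display."""
    return y * std + mean

for pixel in (0.0, 0.5, 1.0):
    y = normalize(pixel)
    print(pixel, "->", y, "->", denormalize(y))
```

With these defaults, denormalize(y) equals (y + 1) / 2, which is the form usually applied to a whole image tensor before plotting or saving.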
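Model setup
Generator
The generator maps the input vector z to the real data space. Since our data are images, this means converting z into a 3x64x64 RGB image. In practice this is done with a series of 2D transposed convolutions, each one except the last followed by a 2D batch norm layer and a ReLU activation. The generator's output passes through tanh so that it lies in the range [−1, 1]. Notably, placing a batch norm layer after the transposed convolutions is one of the main contributions of the DCGAN paper; these layers help gradients flow during training.
For transposed convolution (deconvolution), see: "How ConvTranspose2d works: how do deep networks upsample?" on 月下花弄影's CSDN blog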
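To see how the transposed convolutions grow a 4x4 feature map into a 64x64 image, the sketch below evaluates the output-size formula documented for nn.ConvTranspose2d. It assumes that all four upsampling layers use the settings visible in the final layer of the Generator code (kernel_size=5, stride=2, padding=2, output_padding=1). The helper name `conv_transpose2d_out` is hypothetical.

```python
def conv_transpose2d_out(size, kernel_size, stride, padding, output_padding, dilation=1):
    """Spatial output size of a transposed convolution along one dimension."""
    return (size - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1

size = 4  # spatial size after reshaping the linear layer's output
sizes = [size]
for _ in range(4):  # three dconv_bn_relu blocks plus the final ConvTranspose2d
    size = conv_transpose2d_out(size, kernel_size=5, stride=2, padding=2, output_padding=1)
    sizes.append(size)
print(sizes)  # [4, 8, 16, 32, 64]
```

With these settings each layer exactly doubles the height and width, which is why four layers are enough to go from 4x4 to 64x64.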
# Generator
class Generator(nn.Module):
    """
    Input shape: (batch, in_dim)
    Output shape: (batch, 3, 64, 64)
    """
    def __init__(self, in_dim, feature_dim=64):
        super().__init__()

        # input: (batch, 100)
        self.l1 = nn.Sequential(
            nn.Linear(in_dim, feature_dim * 8 * 4 * 4, bias=False),
            nn.BatchNorm1d(feature_dim * 8 * 4 * 4),
            nn.ReLU()
        )
        self.l2 = nn.Sequential(
            self.dconv_bn_relu(feature_dim * 8, feature_dim * 4),  # (batch, feature_dim * 4, 8, 8)
            self.dconv_bn_relu(feature_dim * 4, feature_dim * 2),  # (batch, feature_dim * 2, 16, 16)
            self.dconv_bn_relu(feature_dim * 2, feature_dim),      # (batch, feature_dim, 32, 32)
        )
        self.l3 = nn.Sequential(
            nn.ConvTranspose2d(feature_dim, 3, kernel_size=5, stride=2,
                               padding=2, output_padding=1, bias=False),
            nn.Tanh()
        )
        self.apply(weights_init)

    def dconv_bn_relu(self, in_dim, out_dim):
        return nn.Sequential(
nn.ConvTrans