引言
近年来,生成对抗网络(Generative Adversarial Networks, GANs)在图像生成领域取得了显著进展,尤其是在动漫头像生成方面。本文将详细介绍如何使用GAN来生成高质量的动漫人脸头像,并分享整个项目的实施步骤、关键技术和最终效果。
GAN基础
GAN由两部分组成:生成器(Generator)和判别器(Discriminator)。生成器的目标是生成尽可能接近真实数据的假数据,而判别器的目标是区分这些数据是真实的还是由生成器生成的。通过这两个部分的不断对抗训练,最终生成器能够生成难以区分的假数据。
数据集准备
首先,我们需要一个高质量的动漫头像数据集。该数据集中的插画风格一致、质量高、噪声小,非常适合用于训练GAN模型。
数据集链接:动漫头像数据集(提取码:crd8)
数据预处理
在将数据输入模型之前,我们需要对数据进行预处理。主要包括以下几个步骤:
- 图片缩放:将所有图片缩放到统一的尺寸,如64x64像素。
- 标准化:对图片进行标准化处理,使其均值为0,方差为1。
在Python中,我们可以使用torchvision.transforms模块来完成这些操作。
import os
import torch.optim as optim
from matplotlib import pyplot as plt
from torchvision import transforms
from torchvision.utils import save_image, make_grid
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from tqdm import tqdm # 导入tqdm
from Model import *
# 图像数据预处理
transform = transforms.Compose([
transforms.Resize(image_size),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
GAN模型设计
生成器(Generator)
生成器的设计通常采用深度卷积神经网络(DCGAN)结构。输入是一个随机噪声向量,通过一系列的上采样(Transpose Convolution)和激活函数(如ReLU)生成最终的图像。
# 定义生成器网络
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.main = nn.Sequential(
nn.ConvTranspose2d(100, 512, 4, 1, 0, bias=False),
nn.BatchNorm2d(512),
nn.ReLU(True),
nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
nn.BatchNorm2d(256),
nn.ReLU(True),
nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(True),
nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(True),
nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False),
nn.Tanh()
)
def forward(self, input):
return self.main(input)
判别器(Discriminator)
判别器通常是一个卷积神经网络,用于区分输入图像是真实的还是由生成器生成的。
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.main = nn.Sequential(
nn.Conv2d(3, 64, 4, 2, 1, bias=False),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(64, 128, 4, 2, 1, bias=False),
nn.BatchNorm2d(128),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(128, 256, 4, 2, 1, bias=False),
nn.BatchNorm2d(256),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(256, 512, 4, 2, 1, bias=False),
nn.BatchNorm2d(512),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(512, 1, 4, 1, 0, bias=False),
nn.Sigmoid()
)
def forward(self, input):
return self.main(input)
效果展示
经过迭代训练后产生相应的动漫人脸图片:
完整代码
import os
import torch.optim as optim
from matplotlib import pyplot as plt
from torchvision import transforms
from torchvision.utils import save_image, make_grid
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from tqdm import tqdm # 导入tqdm
from Model import *
# 自定义数据集类
class Mydataset(Dataset):
def __init__(self, image_dir, transform=None):
self.image_dir = image_dir
self.transform = transform
self.image_paths = [os.path.join(image_dir, img) for img in os.listdir(image_dir) if img.endswith(('png', 'jpg', 'jpeg'))]
def __len__(self):
return len(self.image_paths)
def __getitem__(self, idx):
image_path = self.image_paths[idx]
image = Image.open(image_path).convert('RGB')
if self.transform:
image = self.transform(image)
return image
# 定义生成器网络
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.main = nn.Sequential(
nn.ConvTranspose2d(100, 512, 4, 1, 0, bias=False),
nn.BatchNorm2d(512),
nn.ReLU(True),
nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
nn.BatchNorm2d(256),
nn.ReLU(True),
nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(True),
nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(True),
nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False),
nn.Tanh()
)
def forward(self, input):
return self.main(input)
# 定义判别器网络
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.main = nn.Sequential(
nn.Conv2d(3, 64, 4, 2, 1, bias=False),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(64, 128, 4, 2, 1, bias=False),
nn.BatchNorm2d(128),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(128, 256, 4, 2, 1, bias=False),
nn.BatchNorm2d(256),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(256, 512, 4, 2, 1, bias=False),
nn.BatchNorm2d(512),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(512, 1, 4, 1, 0, bias=False),
nn.Sigmoid()
)
def forward(self, input):
return self.main(input)
# 参数设置
image_size = 64
batch_size = 128
nz = 100
num_epochs = 20
lr = 0.0002
beta1 = 0.5
image_dir = 'F:\\extra_data\\extra_data\\images'
# 图像数据预处理
transform = transforms.Compose([
transforms.Resize(image_size),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
# 加载数据集
dataset = Mydataset(image_dir=image_dir, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# 创建网络实例
# generator = Generator().cuda()
# discriminator = Discriminator().cuda()
generator=NetG(60,100).cuda()
discriminator=NetD(60).cuda()
# 损失函数和优化器
criterion = nn.BCELoss()
optimizerDis = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerGen = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
def train(num_epochs):
# 开始训练
for epoch in range(num_epochs):
progress_bar = tqdm(enumerate(dataloader, 0), total=len(dataloader), desc=f'Epoch {epoch + 1}/{num_epochs}',
unit='batch')
for i, data in progress_bar:
# 更新判别器网络
discriminator.zero_grad()
real_cpu = data.cuda()
batch_size = real_cpu.size(0)
label = torch.full((batch_size,), 1., dtype=torch.float, device='cuda')
output = discriminator(real_cpu).view(-1)
errD_real = criterion(output, label)
errD_real.backward()
D_x = output.mean().item()
noise = torch.randn(batch_size, nz, 1, 1, device='cuda')
fake = generator(noise)
label.fill_(0.)
output = discriminator(fake.detach()).view(-1)
errD_fake = criterion(output, label)
errD_fake.backward()
D_G_z1 = output.mean().item()
errD = errD_real + errD_fake
optimizerDis.step()
# 更新生成器网络
generator.zero_grad()
label.fill_(1.)
output = discriminator(fake).view(-1)
errG = criterion(output, label)
errG.backward()
D_G_z2 = output.mean().item()
optimizerGen.step()
# 更新进度条描述信息
progress_bar.set_postfix({
'Loss_Discriminator': f'{errD.item():.4f}',
'Loss_Generator': f'{errG.item():.4f}',
'D(x)': f'{D_x:.4f}',
'D_G_z1 / D_G_z2': f'{D_G_z1:.4f} / {D_G_z2:.4f}'
})
# 每五个epoch保存一次生成器的输出图片
if epoch % 5 == 0:
with torch.no_grad():
noise = torch.randn(36, nz, 1, 1, device='cuda')
fake = generator(noise).cpu() # 不需要 detach(),因为 with torch.no_grad(): 下不会计算梯度
grid = make_grid(fake, nrow=6, normalize=True)
# 使用 matplotlib 显示图像
img = grid.numpy().transpose((1, 2, 0)) # 转换 numpy 数组并调整通道顺序
plt.imshow(img)
plt.axis('off') # 不显示坐标轴
plt.title(f'Fake Samples - Epoch {epoch}')
plt.show()
# 如果你还想保存图像到文件,你可以调用 save_image
# save_image(grid, f'results/fake_samples_epoch_{epoch}.png')
# print(f'Saved fake_samples_epoch_{epoch}.png')
# 保存模型
torch.save(generator.state_dict(), 'generator.pth')
torch.save(discriminator.state_dict(), 'discriminator.pth')
train(num_epochs)