利用GAN生成动漫人脸头像

引言

近年来,生成对抗网络(Generative Adversarial Networks, GANs)在图像生成领域取得了显著进展,尤其是在动漫头像生成方面。本文将详细介绍如何使用GAN来生成高质量的动漫人脸头像,并分享整个项目的实施步骤、关键技术和最终效果。

GAN基础

GAN由两部分组成:生成器(Generator)和判别器(Discriminator)。生成器的目标是生成尽可能接近真实数据的假数据,而判别器的目标是区分这些数据是真实的还是由生成器生成的。通过这两个部分的不断对抗训练,最终生成器能够生成难以区分的假数据。

数据集准备

首先,我们需要一个高质量的动漫头像数据集。该数据集中的插画风格一致、质量高、噪声小,非常适合用于训练GAN模型。

数据集链接:动漫头像数据集(提取码:crd8)

数据预处理

在将数据输入模型之前,我们需要对数据进行预处理。主要包括以下几个步骤:

  1. 图片缩放:将所有图片缩放到统一的尺寸,如64x64像素。
  2. 标准化:对图片进行标准化处理,使其均值为0,方差为1。

在Python中,我们可以使用torchvision.transforms模块来完成这些操作。

import os
import torch.optim as optim
from matplotlib import pyplot as plt
from torchvision import transforms
from torchvision.utils import save_image, make_grid
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from tqdm import tqdm  # 导入tqdm
from Model import *

# Image preprocessing pipeline: resize, center-crop to a square, convert to
# a tensor, then normalize each RGB channel from [0, 1] to [-1, 1] so real
# images match the generator's tanh output range.
# NOTE(review): assumes image_size is defined elsewhere (64 in the full
# script below) — this snippet alone does not define it.
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

GAN模型设计

生成器(Generator)

生成器的设计通常采用深度卷积生成对抗网络(DCGAN)结构。输入是一个随机噪声向量,通过一系列转置卷积(Transposed Convolution)上采样层和激活函数(如ReLU)逐步放大特征图,最终生成目标尺寸的图像。

# Generator network definition
class Generator(nn.Module):
    """DCGAN generator.

    Maps a (N, 100, 1, 1) latent noise tensor through five transposed
    convolutions to a (N, 3, 64, 64) image with values in [-1, 1] (tanh).
    """

    def __init__(self):
        super(Generator, self).__init__()
        layers = []
        # (in_channels, out_channels, kernel, stride, padding) per upsampling stage
        specs = [
            (100, 512, 4, 1, 0),
            (512, 256, 4, 2, 1),
            (256, 128, 4, 2, 1),
            (128, 64, 4, 2, 1),
        ]
        for c_in, c_out, k, s, p in specs:
            layers.append(nn.ConvTranspose2d(c_in, c_out, k, s, p, bias=False))
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.ReLU(True))
        # Final stage: project to 3 channels and squash into [-1, 1].
        layers.append(nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False))
        layers.append(nn.Tanh())
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        return self.main(input)

判别器(Discriminator)

判别器通常是一个卷积神经网络,用于区分输入图像是真实的还是由生成器生成的。

class Discriminator(nn.Module):
    """DCGAN discriminator.

    Downsamples a (N, 3, 64, 64) image through strided convolutions to a
    (N, 1, 1, 1) sigmoid probability that the input is real.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        blocks = [
            # First stage has no batch-norm, per the DCGAN recipe.
            nn.Conv2d(3, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        for c_in, c_out in ((64, 128), (128, 256), (256, 512)):
            blocks += [
                nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        # Collapse the remaining 4x4 feature map to one logit, then sigmoid.
        blocks += [nn.Conv2d(512, 1, 4, 1, 0, bias=False), nn.Sigmoid()]
        self.main = nn.Sequential(*blocks)

    def forward(self, input):
        return self.main(input)

效果展示

经过迭代训练后产生相应的动漫人脸图片:

完整代码

import os
import torch.optim as optim
from matplotlib import pyplot as plt
from torchvision import transforms
from torchvision.utils import save_image, make_grid
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from tqdm import tqdm  # 导入tqdm
from Model import *

# Custom dataset: loads all images from a single flat directory.
class Mydataset(Dataset):
    """Image-folder dataset for GAN training.

    Collects every file in ``image_dir`` whose extension is
    .png/.jpg/.jpeg (case-insensitive) and returns each as an RGB PIL
    image, optionally passed through ``transform``.
    """

    # Accepted file extensions (lower-cased, with the leading dot).
    _EXTENSIONS = {'.png', '.jpg', '.jpeg'}

    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        self.transform = transform
        # Fix: the original filter was img.endswith(('png', 'jpg', 'jpeg')),
        # which matches any name merely ending in those letters (e.g.
        # "notapng") and misses upper-case extensions like "A.PNG".
        # sorted() makes the sample order deterministic across runs.
        self.image_paths = sorted(
            os.path.join(image_dir, name)
            for name in os.listdir(image_dir)
            if os.path.splitext(name)[1].lower() in self._EXTENSIONS
        )

    def __len__(self):
        # Number of image files discovered at construction time.
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Decode as RGB so grayscale/RGBA files become 3-channel.
        image_path = self.image_paths[idx]
        image = Image.open(image_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image

# Generator network definition
class Generator(nn.Module):
    """DCGAN generator: (N, 100, 1, 1) noise -> (N, 3, 64, 64) image in [-1, 1]."""

    # Channel progression of the transposed-convolution pipeline.
    _CHANNELS = (100, 512, 256, 128, 64)

    def __init__(self):
        super(Generator, self).__init__()
        modules = []
        for idx in range(len(self._CHANNELS) - 1):
            c_in, c_out = self._CHANNELS[idx], self._CHANNELS[idx + 1]
            # The first stage projects 1x1 noise to 4x4; later stages double
            # the spatial size (kernel 4, stride 2, padding 1).
            stride, pad = (1, 0) if idx == 0 else (2, 1)
            modules.extend([
                nn.ConvTranspose2d(c_in, c_out, 4, stride, pad, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ])
        # Output stage: 3-channel image squashed into [-1, 1].
        modules.extend([
            nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False),
            nn.Tanh(),
        ])
        self.main = nn.Sequential(*modules)

    def forward(self, input):
        return self.main(input)

# Discriminator network definition
class Discriminator(nn.Module):
    """DCGAN discriminator: (N, 3, 64, 64) image -> (N, 1, 1, 1) real/fake probability."""

    def __init__(self):
        super(Discriminator, self).__init__()

        def down(c_in, c_out, batch_norm=True):
            # One strided-conv downsampling stage; the first stage skips
            # batch-norm per the DCGAN recipe.
            stage = [nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False)]
            if batch_norm:
                stage.append(nn.BatchNorm2d(c_out))
            stage.append(nn.LeakyReLU(0.2, inplace=True))
            return stage

        self.main = nn.Sequential(
            *down(3, 64, batch_norm=False),
            *down(64, 128),
            *down(128, 256),
            *down(256, 512),
            # Collapse the final 4x4 map to a single sigmoid score.
            nn.Conv2d(512, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, input):
        return self.main(input)

# Hyper-parameter settings
image_size = 64  # side length images are resized/cropped to
batch_size = 128  # samples per training batch
nz = 100  # dimensionality of the generator's latent noise vector
num_epochs = 20  # total training epochs
lr = 0.0002  # Adam learning rate (standard DCGAN value)
beta1 = 0.5  # Adam beta1 (standard DCGAN value)
image_dir = 'F:\\extra_data\\extra_data\\images'  # NOTE(review): hard-coded Windows path — adjust per machine

# Image preprocessing: resize, center-crop to 64x64, convert to tensor,
# then normalize each channel from [0, 1] to [-1, 1] to match the
# generator's tanh output range.
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Load the dataset (Mydataset scans image_dir for image files)
dataset = Mydataset(image_dir=image_dir, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Create the network instances.
# Fix: the original instantiated NetG(60, 100) / NetD(60) from the external
# Model module, leaving the Generator/Discriminator classes defined above
# unused and inconsistent with nz = 100 (the commented-out lines show the
# intended wiring). Use the local DCGAN definitions, which match nz and the
# 64x64 preprocessing pipeline.
generator = Generator().cuda()
discriminator = Discriminator().cuda()

# Loss function and optimizers (binary cross-entropy + Adam with DCGAN betas)
criterion = nn.BCELoss()
optimizerDis = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerGen = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))

def train(num_epochs):
    """Run the adversarial training loop for ``num_epochs`` epochs.

    Per batch: (1) update the discriminator on a real batch (label 1) and a
    detached fake batch (label 0); (2) update the generator so the
    discriminator outputs 1 on that same fake batch. Every 5 epochs a 6x6
    grid of samples is displayed, and both models are checkpointed after
    every epoch.

    NOTE(review): relies on CUDA and on the module-level globals
    (dataloader, generator, discriminator, criterion, optimizerDis,
    optimizerGen, nz).
    """
    # Training loop
    for epoch in range(num_epochs):
        progress_bar = tqdm(enumerate(dataloader, 0), total=len(dataloader), desc=f'Epoch {epoch + 1}/{num_epochs}',
                            unit='batch')
        for i, data in progress_bar:
            # ---- Update the discriminator ----
            discriminator.zero_grad()
            real_cpu = data.cuda()
            batch_size = real_cpu.size(0)  # actual batch size (last batch may be smaller)
            label = torch.full((batch_size,), 1., dtype=torch.float, device='cuda')
            output = discriminator(real_cpu).view(-1)
            errD_real = criterion(output, label)  # loss on real images vs label 1
            errD_real.backward()
            D_x = output.mean().item()  # mean D output on real images (ideally -> 1)

            noise = torch.randn(batch_size, nz, 1, 1, device='cuda')
            fake = generator(noise)
            label.fill_(0.)  # fakes carry label 0 for the discriminator pass
            output = discriminator(fake.detach()).view(-1)  # detach: no grad into G here
            errD_fake = criterion(output, label)
            errD_fake.backward()
            D_G_z1 = output.mean().item()  # mean D output on fakes before G update
            errD = errD_real + errD_fake
            optimizerDis.step()

            # ---- Update the generator ----
            generator.zero_grad()
            label.fill_(1.)  # generator wants D to answer "real"
            output = discriminator(fake).view(-1)
            errG = criterion(output, label)
            errG.backward()
            D_G_z2 = output.mean().item()  # mean D output on fakes after D update
            optimizerGen.step()

            # Refresh the progress-bar postfix with current statistics
            progress_bar.set_postfix({
                'Loss_Discriminator': f'{errD.item():.4f}',
                'Loss_Generator': f'{errG.item():.4f}',
                'D(x)': f'{D_x:.4f}',
                'D_G_z1 / D_G_z2': f'{D_G_z1:.4f} / {D_G_z2:.4f}'
            })

        # Every 5 epochs, render a 6x6 grid of generator samples
        if epoch % 5 == 0:
            with torch.no_grad():
                noise = torch.randn(36, nz, 1, 1, device='cuda')
                fake = generator(noise).cpu()  # no detach() needed inside torch.no_grad()
                grid = make_grid(fake, nrow=6, normalize=True)

                # Display the grid with matplotlib
            img = grid.numpy().transpose((1, 2, 0))  # CHW tensor -> HWC numpy array
            plt.imshow(img)
            plt.axis('off')  # hide the axes
            plt.title(f'Fake Samples - Epoch {epoch}')
            plt.show()

            # To also save the grid to disk, use save_image:
            # save_image(grid, f'results/fake_samples_epoch_{epoch}.png')
            # print(f'Saved fake_samples_epoch_{epoch}.png')
        # Checkpoint both models (files are overwritten each epoch)
        torch.save(generator.state_dict(), 'generator.pth')
        torch.save(discriminator.state_dict(), 'discriminator.pth')

train(num_epochs)

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值