PyTorch实战,基于GAN实现表情包生成

这里写自定义目录标题

这是一个基于PyTorch框架实现的表情包生成工具,使用条件生成对抗网络(cGAN)进行训练。假设表情包是由描述和图像组成,我们将基于描述生成图像。

首先,我们需要准备我们的数据集。我们假设数据集是一个名为input的文件夹,其中包含表情包图像,每张图像的名称就是它所代表的描述。为了使用这些图像训练我们的cGAN,我们需要创建一个PyTorch数据集类,将图像加载到内存中。以下是一个简单的数据集类的示例代码:

import os
from PIL import Image
from torch.utils.data import Dataset

class EmoticonDataset(Dataset):
    def __init__(self, data_dir):
        self.data_dir = data_dir
        self.filenames = os.listdir(self.data_dir)

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        filename = self.filenames[idx]
        img = Image.open(os.path.join(self.data_dir, filename)).convert('RGB')
        return img, filename.split('.')[0]

我们将数据集加载到内存中,将图像和它所代表的描述一起返回。


接下来,我们需要定义我们的cGAN模型。cGAN由两个部分组成:生成器和判别器。生成器将描述作为输入,输出一张图像。判别器将描述和图像作为输入,输出一个标量,表示图像是真实的还是生成的。我们的生成器和判别器将基于DCGAN结构。以下是生成器和判别器的示例代码:

import torch.nn as nn

class Generator(nn.Module):
    def __init__(self, latent_dim):
        super(Generator, self).__init__()

        self.latent_dim = latent_dim

        self.model = nn.Sequential(
            nn.Linear(latent_dim+256, 256*8*8),
            nn.BatchNorm1d(256*8*8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Unflatten(-1, (256, 8, 8)),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, z, y):
        y_onehot = nn.functional.one_hot(y, num_classes=256).float()
        x = torch.cat([z, y_onehot], dim=1)
        x = self.model(x)
        return x

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Conv2d(3+256, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Flatten(),
            nn.Linear(256*8*8, 1),
            nn.Sigmoid()
        )

	def forward(self, x, y):
	   	y_onehot = nn.functional.one_hot(y, num_classes=256).float()
	    y_onehot = y_onehot.view(-1, 256, 1, 1)
	    y_onehot = y_onehot.repeat(1, 1, 8, 8)
	    x = torch.cat([x, y_onehot], dim=1)
	    x = self.model(x)
	    return x


我们的生成器接受一个随机噪声向量和一个描述,输出一张图像。我们将描述转换为一个one-hot编码,然后与噪声向量连接。我们的判别器接受一个图像和一个描述,并输出一个标量,表示图像是真实的还是生成的。


接下来,我们需要定义训练过程。在每个训练迭代中,我们将随机选择一个描述和一个真实图像,然后使用生成器生成一张假图像。我们将使用真实图像和假图像来训练判别器。然后,我们将再次生成一张假图像,但这一次将其与所选描述一起用于训练生成器。我们将交替训练判别器和生成器,以达到最佳效果。以下是训练循环的示例代码:

import torch.optim as optim
import torchvision.utils as vutils

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

latent_dim = 100
batch_size = 64
lr = 0.0002
beta1 = 0.5
num_epochs = 50

dataset = EmoticonDataset('input')
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

netG = Generator(latent_dim).to(device)
netD = Discriminator().to(device)

criterion = nn.BCELoss()

optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

fixed_noise = torch.randn(64, latent_dim, device=device)
fixed_labels = torch.arange(0, 256).repeat(8).to(device)

for epoch in range(num_epochs):
    for i, (real_images, labels) in enumerate(dataloader):
        real_images = real_images.to(device)
        labels = labels.to(device)

        # Train discriminator
        netD.zero_grad()
        real_labels = torch.full((batch_size,), 1, device=device)
        fake_labels = torch.full((batch_size,), 0, device=device)

        real_scores = netD(real_images, labels).squeeze()
        real_loss = criterion(real_scores, real_labels)

        noise = torch.randn(batch_size, latent_dim)
        noise = noise.to(device)
	    fake_images = netG(noise, labels)
	    fake_scores = netD(fake_images.detach(), labels).squeeze()
	    fake_loss = criterion(fake_scores, fake_labels)
	
	    d_loss = real_loss + fake_loss
	    d_loss.backward()
	    optimizerD.step()
	
	    # Train generator
	    netG.zero_grad()
	    fake_scores = netD(fake_images, labels).squeeze()
	    g_loss = criterion(fake_scores, real_labels)
	
	    g_loss.backward()
	    optimizerG.step()

	    if i % 100 == 0:
	        print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f'
	              % (epoch, num_epochs, i, len(dataloader),
	                 d_loss.item(), g_loss.item()))

# Save generated images
with torch.no_grad():
    fake = netG(fixed_noise, fixed_labels).detach().cpu()
img_grid = vutils.make_grid(fake, nrow=8, normalize=True)
vutils.save_image(img_grid, 'output/fake_samples_epoch_%03d.png' % epoch)

# Save models
torch.save(netG.state_dict(), 'netG.pth')
torch.save(netD.state_dict(), 'netD.pth')

在训练循环中,我们使用Adam优化器来训练生成器和判别器。我们还定义了一个固定的噪声向量和标签,以便在每个时代中生成样本图像。我们使用PyTorch的vutils工具库将生成的图像保存到output文件夹中。


最后,我们需要编写推理脚本,以便用户可以使用我们的生成器生成表情包图像。以下是一个简单的推理脚本的示例代码:

import torch
from PIL import Image
import torchvision.transforms as transforms

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
netG = Generator(latent_dim).to(device)
netG.load_state_dict(torch.load('netG.pth', map_location=device))

def generate_emoticon(description):
    netG.eval()

    with torch.no_grad():
        z = torch.randn(1, latent_dim).to(device)
        label = torch.tensor([ord(c) for c in description]).to(device)
        image = netG(z, label).squeeze().detach().cpu()

    image = transforms.ToPILImage()(image)
    return image

推理脚本将输入描述作为字符串,将其转换为标签,并将标签和随机噪声向量传递给我们的生成器。生成的图像将作为PIL图像返回。

以上是一个简单的基于PyTorch框架实现的表情包生成工具,使用cGAN进行训练,并提供了一个推理脚本,以便用户可以使用我们的生成器生成表情包图像。当然,这只是一个基础版本,可以通过增加模型复杂度、优化超参数等方式进一步改进。
[1]: http://chat.openai.com

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值