GAN实战——生成手写字体

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torchvision import transforms
import time

time_start = time.time()

# 生成器生成的数据在 [-1, 1]
transform = transforms.Compose([
    transforms.ToTensor(),  # 会做0-1归一化,也会channels, height, width
    transforms.Normalize((0.5,), (0.5,))
])

train_ds = torchvision.datasets.MNIST('data', train=True, transform=transform)
dataLoader = torch.utils.data.DataLoader(train_ds, batch_size=64, shuffle=True)


# 生成器网络定义
# 输入是长度为100的噪声(正态分布随机数)
# 输出为(1, 28, 28)的图片
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 28*28),
            nn.Tanh()
        )

    def forward(self, x):
        img = self.main(x)
        img = img.view(-1, 28, 28, 1)
        return img

# 判别器网络定义
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.LeakyReLU(),
            nn.Linear(512, 256),
            nn.LeakyReLU(),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = x.view(-1, 28*28)
        x = self.main(x)
        return x


device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)
gen = Generator().to(device)
dis = Discriminator().to(device)
d_optimizer = torch.optim.Adam(dis.parameters(), lr=0.0001)
g_optimizer = torch.optim.Adam(gen.parameters(), lr=0.0001)

# 损失函数
loss_fn = torch.nn.BCELoss()

# 绘图函数
def gen_img_plot(model, test_input):
    prediction = np.squeeze(model(test_input).detach().cpu().numpy())
    fig = plt.figure(figsize=(4, 4))
    for i in range(16):
        plt.subplot(4, 4, i+1)
        plt.imshow((prediction[i] + 1)/2)
        plt.axis('off')
    plt.show()


test_input = torch.randn(16, 100, device=device)


# GAN训练
D_loss = []
G_loss = []


# 训练循环
for epoch in range(20):
    d_epoch_loss = 0
    g_epoch_loss = 0
    count = len(dataLoader)  # 返回批次数
    for step, (img, _) in enumerate(dataLoader):
        img = img.to(device)
        size = img.size(0)
        random_noise = torch.randn(size, 100, device=device)

        # 判别器的损失与优化
        d_optimizer.zero_grad()
        real_output = dis(img)  # 对判别器输入真实图片, real_output是对真实图片的判断结果
        d_real_loss = loss_fn(real_output, torch.ones_like(real_output))  # 判别器在真实图像上的损失
        d_real_loss.backward()

        gen_img = gen(random_noise)
        fake_output = dis(gen_img.detach())  # 判别器输入生成的图片,fake_output对生成图片的预测
        d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output))  # 判别器在生成图像上的损失
        d_fake_loss.backward()
        d_loss = d_real_loss + d_fake_loss
        d_optimizer.step()

        # 生成器的损失与优化
        g_optimizer.zero_grad()
        fake_output = dis(gen_img)
        g_loss = loss_fn(fake_output, torch.ones_like(fake_output))  # 生成器的损失
        g_loss.backward()
        g_optimizer.step()

        with torch.no_grad():
            d_epoch_loss += d_loss
            g_epoch_loss += g_loss

    with torch.no_grad():
        d_epoch_loss /= count
        g_epoch_loss /= count
        D_loss.append(d_epoch_loss)
        G_loss.append(g_epoch_loss)
        print("Epoch:", epoch)
        gen_img_plot(gen, test_input)

time_end = time.time()
print("花费总时间为:", time_end - time_start)

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
CycleGAN和StyleGANGAN生成式对抗网络)的两个重要应用。GAN是一种深度学习模型,它可以生成新的数据,比如图片、音频等。CycleGAN和StyleGAN的区别在于它们生成数据的方式以及应用领域。 CycleGAN是一种能够将一种图像转换成另一种图像的模型,例如将马变成斑马,将夏天的图片转换成冬天的图片等。它是由两个生成器和两个判别器组成的。其中一个生成器将一种图像转换成另一种图像,另一个生成器则将转换回来。两个判别器用于判断生成的图片是否真实。CycleGAN的优点是可以无需成对的图片进行训练,而且训练数据集不需要太大,只需要一些相关的图片即可。 StyleGAN则是一种用于生成逼真的图像的模型,它是在GAN的基础上进行了改进。StyleGAN可以生成逼真的人脸、汽车、动物等图像。StyleGAN的优点是可以生成高分辨率的图像,并且可以控制图像的风格和内容。StyleGAN可以使用一个具有连续变化的潜在空间来控制所生成图像的不同部分,从而可以在不同样本之间无缝地转换,这使得生成的图像更加逼真和自然。 总之,CycleGAN和StyleGAN都是GAN的应用,CycleGAN主要用于图像的风格转换,而StyleGAN则用于逼真图像的生成。它们的成功使得生成式对抗网络的应用得到了广泛的关注,并且将继续在图像、视频和音频数据的生成和处理中发挥重要作用。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值