A Simple DCGAN Application on the MNIST Dataset

This post shows how to build a generative adversarial network (GAN) in PyTorch to generate images of MNIST handwritten digits. It covers defining the generator and discriminator models, setting up the loss function and optimizers, and training and saving the model.
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from IPython import display
import time
import imageio
import glob
from PIL import Image

EPOCHS = 30
BATCH_SIZE = 128
# Select the device (use the GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# ... other code ...

# Create a fixed batch of noise before the training loop (a tensor of shape torch.Size([16, 100]));
# reusing the same inputs at every epoch makes it easy to compare the generator's outputs as training progresses
num_examples_to_generate = 16
seed = torch.randn(num_examples_to_generate, 100, device=device)
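# Note: a fixed noise tensor by itself does not make separate runs identical; to also fix
# the random numbers across runs, call torch.manual_seed() with a value of your choice
# (for example torch.manual_seed(0)) before creating `seed`.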
# MNIST preprocessing: convert to tensors and normalize to [-1, 1] (matching the generator's tanh output)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Load the MNIST training data
train_dataset = datasets.MNIST(root="./data/mnist", train=True, download=False, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
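# If the MNIST files are not already under ./data/mnist, pass download=True to datasets.MNIST above.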

# Generator model: input noise [BATCH_SIZE, 100] >>> output images [BATCH_SIZE, 1, 28, 28]
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.fc = nn.Linear(100, 7 * 7 * 256)

        self.batch_norm1 = nn.BatchNorm2d(256)
        self.deconv1 = nn.ConvTranspose2d(256, 128, kernel_size=5, stride=1, padding=2, output_padding=0, bias=False)

        self.batch_norm2 = nn.BatchNorm2d(128)
        self.deconv2 = nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2, padding=2, output_padding=1, bias=False)

        self.batch_norm3 = nn.BatchNorm2d(64)
        self.deconv3 = nn.ConvTranspose2d(64, 1, kernel_size=5, stride=2, padding=2, output_padding=1, bias=False)

    def forward(self, x):
        x = self.fc(x)
        x = x.view(x.size(0), 256, 7, 7)
        x = self.batch_norm1(x)
        x = nn.functional.leaky_relu(x)

        x = self.deconv1(x)
        x = self.batch_norm2(x)
        x = nn.functional.leaky_relu(x)

        x = self.deconv2(x)
        x = self.batch_norm3(x)
        x = nn.functional.leaky_relu(x)

        x = self.deconv3(x)
        x = torch.tanh(x)
        return x
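# Optional sanity check (a small added sketch, not part of the training pipeline):
# the generator should map [N, 100] noise to [N, 1, 28, 28] images.
with torch.no_grad():
    print(Generator()(torch.randn(2, 100)).shape)  # expected: torch.Size([2, 1, 28, 28])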

# Discriminator model: input [BATCH_SIZE, 1, 28, 28] >>> output [BATCH_SIZE, 1] logits
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.conv1 = nn.Conv2d(1, 64, kernel_size=5, stride=2, padding=2)
        self.leaky_relu1 = nn.LeakyReLU(0.2)
        self.dropout1 = nn.Dropout(0.3)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2)
        self.leaky_relu2 = nn.LeakyReLU(0.2)
        self.dropout2 = nn.Dropout(0.3)
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(7 * 7 * 128, 1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.leaky_relu1(x)
        x = self.dropout1(x)
        x = self.conv2(x)
        x = self.leaky_relu2(x)
        x = self.dropout2(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x
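# Optional sanity check (added sketch): the discriminator should map [N, 1, 28, 28]
# images to one logit per sample, i.e. shape [N, 1].
with torch.no_grad():
    print(Discriminator()(torch.randn(2, 1, 28, 28)).shape)  # expected: torch.Size([2, 1])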

# Instantiate the models and move them to the device
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# Loss function and optimizers
criterion = nn.BCEWithLogitsLoss()
optimizer_generator = optim.Adam(generator.parameters(), lr=1e-4)
optimizer_discriminator = optim.Adam(discriminator.parameters(), lr=1e-4)
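# The DCGAN paper (Radford et al.) recommends Adam with lr=2e-4 and beta1=0.5; if training
# turns out to be unstable, one option worth trying is, for example:
# optimizer_generator = optim.Adam(generator.parameters(), lr=2e-4, betas=(0.5, 0.999))
# optimizer_discriminator = optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))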

# Training step: update the discriminator on a real batch and a fake batch,
# then update the generator so that the discriminator labels its samples as real
def train_step(images):
    batch_size = images.size(0)  # actual batch size (the last batch may be smaller)
    noise = torch.randn(batch_size, 100, device=device)

    generated_images = generator(noise)

    real_labels = torch.ones(batch_size, 1, device=device)
    fake_labels = torch.zeros(batch_size, 1, device=device)

    # --- Discriminator update ---
    optimizer_discriminator.zero_grad()
    real_output = discriminator(images)
    # Detach the generated images so this loss does not backpropagate into the generator
    fake_output = discriminator(generated_images.detach())
    disc_loss_real = criterion(real_output, real_labels)
    disc_loss_fake = criterion(fake_output, fake_labels)
    disc_loss = disc_loss_real + disc_loss_fake
    disc_loss.backward()
    optimizer_discriminator.step()

    # --- Generator update ---
    optimizer_generator.zero_grad()
    gen_loss = criterion(discriminator(generated_images), real_labels)
    gen_loss.backward()
    optimizer_generator.step()

def generate_and_save_images(model, epoch, test_input):
    model.eval()
    with torch.no_grad():
        predictions = model(test_input).cpu().detach()
    model.train()

    plt.figure(figsize=(4, 4))
    for i in range(predictions.shape[0]):
        plt.subplot(4, 4, i + 1)
        plt.imshow(predictions[i, 0, :, :] * 0.5 + 0.5, cmap='gray')  # map [-1, 1] back to [0, 1]
        plt.axis('off')  # hide the axes

    plt.savefig('C:/Users/Administrator/Desktop/GAN/images2/image_at_epoch_{:04d}.png'.format(epoch))
    plt.close()  # free the figure so figures do not accumulate across epochs

def train_model(generator, discriminator, train_loader, num_epochs):
    for epoch in range(num_epochs):
        start = time.time()

        for i, (images, _) in enumerate(train_loader):
            images = images.to(device)
            train_step(images)

        display.clear_output(wait=True)
        generate_and_save_images(generator, epoch + 1, seed)

        if (epoch + 1) % 10 == 0:
            checkpoint_path = 'C:/Users/Administrator/Desktop/GAN/checkpoints/checkpoint_epoch_{}.tar'.format(epoch + 1)
            torch.save({
                'generator_state_dict': generator.state_dict(),
                'discriminator_state_dict': discriminator.state_dict(),
                'optimizer_generator_state_dict': optimizer_generator.state_dict(),
                'optimizer_discriminator_state_dict': optimizer_discriminator.state_dict(),
            }, checkpoint_path)

        print('Time for epoch {} is {} sec'.format(epoch + 1, time.time() - start))

    display.clear_output(wait=True)
    generate_and_save_images(generator, num_epochs, seed)

# Train the model
train_model(generator, discriminator, train_loader, EPOCHS)

# Restore the latest checkpoint
latest_checkpoint_path = 'C:/Users/Administrator/Desktop/GAN/checkpoints/checkpoint_epoch_{}.tar'.format(EPOCHS)
checkpoint = torch.load(latest_checkpoint_path, map_location=device)
generator.load_state_dict(checkpoint['generator_state_dict'])
discriminator.load_state_dict(checkpoint['discriminator_state_dict'])


# Evaluate: display the image grid saved at a given epoch
def display_image(epoch_no):
    image_path = 'C:/Users/Administrator/Desktop/GAN/images2/image_at_epoch_{:04d}.png'.format(epoch_no)
    img = Image.open(image_path)
    display.display(img)

# Show the grid from the final epoch
display_image(EPOCHS)
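# Optional: the imageio and glob imports above can be used to stitch the saved per-epoch
# grids into a GIF. A minimal sketch using the imageio v2-style API; the output filename
# dcgan_mnist.gif is an arbitrary choice:
image_dir = 'C:/Users/Administrator/Desktop/GAN/images2'
filenames = sorted(glob.glob(image_dir + '/image_at_epoch_*.png'))  # zero-padded names sort in epoch order
frames = [imageio.imread(f) for f in filenames]
imageio.mimsave(image_dir + '/dcgan_mnist.gif', frames)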

