小白带你深度学习DAY3

本次代码是一个简单的GAN神经网络

代码说明:

  1. 数据准备:使用MNIST手写数字数据集,进行归一化处理

  2. 网络结构

    • 生成器:将随机噪声转换为图像

    • 判别器:判断输入图像的真假

  3. 训练过程

    • 交替训练判别器和生成器

    • 使用二元交叉熵损失(BCELoss)

    • 使用Adam优化器

  4. 可视化功能

    • 每5个epoch显示一次生成的图像

    • 使用matplotlib显示16张生成的手写数字

运行说明:

  1. 依赖安装

    pip install torch torchvision matplotlib

  2. 执行步骤

    • 在PyCharm中新建Python文件

    • 复制粘贴上述代码

    • 确保Python解释器正确配置(建议使用Python 3.7+)

    • 首次运行会自动下载MNIST数据集(约100MB)

  3. 输出效果

    • 每5个epoch显示一次生成的数字图像

    • 控制台打印损失值变化

    • 随着训练进行,生成的数字会逐渐变得清晰

# -*- coding: utf-8 -*-
# 导入必要的库
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'

# 设备配置(自动检测GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ====================
#   超参数设置
# ====================
latent_dim = 100  # 噪声向量维度
img_size = 28  # 图像尺寸(MNIST为28x28)
batch_size = 128  # 批大小
epochs = 100  # 训练轮数
lr = 0.0002  # 学习率
sample_interval = 5  # 采样间隔(每n个epoch保存一次生成样本)

# ====================
#   数据预处理
# ====================
# 定义数据转换(将图像转换为Tensor并归一化到[-1, 1])
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # 单通道的MNIST
])

# 加载MNIST数据集
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform
)

# 创建数据加载器
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True
)


# ====================
#   生成器定义(改进版)
# ====================
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        # 使用全连接层构建生成器
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.BatchNorm1d(256),  # 添加批归一化
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(1024, img_size ** 2),  # 输出28x28图像
            nn.Tanh()  # 使用Tanh将输出限制在[-1, 1]
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), 1, img_size, img_size)
        return img


# ====================
#   判别器定义(改进版)
# ====================
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        # 使用全连接层构建判别器
        self.model = nn.Sequential(
            nn.Linear(img_size ** 2, 512),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(256, 1),
            nn.Sigmoid()  # 输出0-1的概率值
        )

    def forward(self, img):
        flattened = img.view(img.size(0), -1)
        validity = self.model(flattened)
        return validity


# ====================
#   初始化网络和优化器
# ====================
# 创建生成器和判别器实例
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# 配置优化器(使用Adam优化器)
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

# 定义损失函数
adversarial_loss = nn.BCELoss()


# ====================
#   可视化函数
# ====================
def sample_images(epoch):
    """保存生成的图像样本"""
    with torch.no_grad():
        # 生成随机噪声
        z = torch.randn(16, latent_dim).to(device)
        gen_imgs = generator(z).cpu().numpy()

    # 调整图像格式(反归一化)
    gen_imgs = 0.5 * gen_imgs + 0.5  # 从[-1,1]映射到[0,1]

    # 创建绘图窗口
    plt.figure(figsize=(4, 4))
    for i in range(16):
        plt.subplot(4, 4, i + 1)
        plt.imshow(gen_imgs[i].reshape(28, 28), cmap='gray')
        plt.axis('off')
    plt.tight_layout()
    plt.suptitle(f'Epoch {epoch}')
    plt.show()


# ====================
#   训练循环(改进版)
# ====================
for epoch in range(epochs):
    for i, (imgs, _) in enumerate(train_loader):

        # ---------------------
        #  准备真实数据和标签
        # ---------------------
        real_imgs = imgs.to(device)
        valid = torch.ones((imgs.size(0), 1), device=device) * 0.95  # 标签平滑
        fake = torch.zeros((imgs.size(0), 1), device=device) * 0.05  # 标签平滑

        # ====================
        #  训练判别器
        # ====================
        optimizer_D.zero_grad()

        # 真实图像的损失
        real_loss = adversarial_loss(discriminator(real_imgs), valid)

        # 生成假图像
        z = torch.randn(imgs.size(0), latent_dim).to(device)
        gen_imgs = generator(z)

        # 假图像的损失(注意使用detach()阻断梯度传播)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)

        # 总判别器损失并反向传播
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        # ====================
        #  训练生成器
        # ====================
        optimizer_G.zero_grad()

        # 生成器试图欺骗判别器
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)

        # 反向传播和优化
        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  打印训练进度
        # ---------------------
        if i % 200 == 0:
            print(f"[Epoch {epoch}/{epochs}] [Batch {i}/{len(train_loader)}] "
                  f"D_loss: {d_loss.item():.4f} G_loss: {g_loss.item():.4f}")

    # 定期生成样本图像
    if epoch % sample_interval == 0:
        sample_images(epoch)

print("训练完成!")
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值