生成对抗网络mnist数据集基于pytorch

跟着B站up主敲的代码

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import torchvision
from torchvision import transforms

# -------------------------------------------------------------------------------#
# --------------------------------数据准备-----------------------------------------#



transform = transforms.Compose(
    [transforms.ToTensor(),
    transforms.Normalize((0.5),(0.5))]
)
train_ds = torchvision.datasets.MNIST("data",train=True,transform=transform,download=True)
dataloader = torch.utils.data.DataLoader(train_ds,batch_size=64,shuffle=True)
imgs,_ = next(iter(dataloader))
print(imgs.shape)


# ---------------------------------------------------------------------------------#
# --------------------------------定义生成器-----------------------------------------#

class Generator(nn.Module):
    def __init__(self):
        super(Generator,self).__init__()
        self.main = nn.Sequential(
            nn.Linear(100,256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 784),
            nn.Tanh()
        )

    def forward(self,x):
        img = self.main(x)
        img = img.view(-1,28,28,1)
        return img


# ---------------------------------------------------------------------------------#
# --------------------------------定义判别器-----------------------------------------#

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator,self).__init__()
        self.main = nn.Sequential(
            nn.Linear(784, 512),
            nn.LeakyReLU(),
            nn.Linear(512, 256),
            nn.LeakyReLU(),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self,x):
        x = x.view(-1,784)
        self.main(x)
        return x


# ---------------------------------------------------------------------------------#
# ------------------------初始化模型,优化器及损失函数-----------------------------------#

# device = "cpu"
device = "cuda" if torch.cuda.is_available() else "cpu"

gen = Generator().to(device)
dis = Discriminator().to(device)
d_optim = torch.optim.Adam(dis.parameters(),lr=0.001)
g_optim = torch.optim.Adam(gen.parameters(),lr=0.001)
loss_fn = torch.nn.BCELoss()



# ---------------------------------------------------------------------------------#
# -----------------------------------绘图函数----------------------------------------#

# def gen_img_plot(model,test_input):
#     prediction = np.squeeze(model(test_input).detach().cpu().numpy())
#     fig = plt.figure(figsize=(4,4))
#     for i in range(16):
#         plt.subplot(4,4,i+1)
#         plt.imshow((prediction[i]+1)/2.0)
#         plt.axis("off")
#         plt.show()

def gen_img_plot(model, test_input):
    prediction = np.squeeze(model(test_input).detach().cpu().numpy())
    fig, axs = plt.subplots(4, 4, figsize=(8, 8))
    for i in range(16):
        row = i // 4
        col = i % 4
        axs[row, col].imshow((prediction[i] + 1) / 2.0)
        axs[row, col].axis("off")
    plt.show()

test_input = torch.randn(16,100,device=device)



# ---------------------------------------------------------------------------------#
# ----------------------------------GAN训练-----------------------------------------#

D_loss = []
G_loss = []

for epoch in range(100):
    d_epoch_loss = 0
    g_epoch_loss = 0
    count = len(dataloader)

    for step,(img,_) in enumerate(dataloader):
        img = img.to(device)
        size = img.size(0)
        random_noise = torch.randn(size,100,device=device)

        d_optim.zero_grad()
        real_output = dis(img)
        real_output = torch.sigmoid(real_output)
        d_real_loss = loss_fn(real_output,torch.ones_like(real_output))
        d_real_loss.requires_grad_(True)
        d_real_loss.backward()

        gen_img = gen(random_noise)
        fake_output = dis(gen_img.detach())
        fake_output = torch.sigmoid(fake_output)
        d_fake_loss = loss_fn(fake_output,torch.zeros_like(fake_output))
        d_fake_loss.requires_grad_(True)
        d_fake_loss.backward()

        d_loss = d_real_loss + d_fake_loss
        d_optim.step()


        g_optim.zero_grad()
        fake_output = dis(gen_img)
        fake_output = torch.sigmoid(fake_output)
        g_loss = loss_fn(fake_output,torch.ones_like(fake_output))
        g_loss.requires_grad_(True)
        g_loss.backward()
        g_optim.step()

        with torch.no_grad():
            d_epoch_loss = d_epoch_loss + d_loss
            g_epoch_loss = g_epoch_loss + g_loss

    with torch.no_grad():
        d_epoch_loss = d_epoch_loss / count
        g_epoch_loss = g_epoch_loss / count
        D_loss.append(d_epoch_loss)
        G_loss.append(g_epoch_loss)
        print("Epoch:",epoch)
        gen_img_plot(gen,test_input)


运行100次结果:

 遇到BUG就百度,一步一步解决问题后能成功运行,但结果惨不忍睹,希望有路过大佬帮忙看看,解决下。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

彭毓众

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值