以下代码是跟着B站UP主敲的:
# Vanilla (fully connected) GAN trained on MNIST.
#
# Fixes relative to the original tutorial transcription:
#   * Discriminator.forward returned its *input* instead of self.main(x), so
#     the losses never depended on D's weights (the reason the
#     `loss.requires_grad_(True)` hacks were "needed"). Those hacks are gone.
#   * D already ends in nn.Sigmoid(); the extra torch.sigmoid() calls in the
#     training loop squashed scores into ~(0.5, 0.73) and were removed.
#   * Epoch losses are accumulated as Python floats, not tensors.
#   * Training only runs when the file is executed as a script.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F  # noqa: F401  (kept from the original script)
import torch.optim as optim

NOISE_DIM = 100  # length of the latent vector fed to the generator
IMG_SIZE = 28  # MNIST images are 28x28, single channel
IMG_PIXELS = IMG_SIZE * IMG_SIZE
BATCH_SIZE = 64
LEARNING_RATE = 1e-4  # 1e-3 (original) makes this small GAN noticeably less stable
EPOCHS = 100


# --------------------------------- data -------------------------------------#
def build_dataloader(root="data", batch_size=BATCH_SIZE):
    """Return a shuffled DataLoader over the MNIST training split.

    Pixels are mapped to [-1, 1] (ToTensor gives [0, 1], then
    Normalize(mean=0.5, std=0.5)), matching the generator's Tanh output.
    """
    # Imported lazily so the model definitions below can be imported without
    # torchvision installed.
    import torchvision
    from torchvision import transforms

    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
    )
    train_ds = torchvision.datasets.MNIST(
        root, train=True, transform=transform, download=True
    )
    return torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True)


# ------------------------------- generator ----------------------------------#
class Generator(nn.Module):
    """Map a (N, noise_dim) noise batch to (N, 1, 28, 28) images in [-1, 1]."""

    def __init__(self, noise_dim=NOISE_DIM):
        super().__init__()
        self.main = nn.Sequential(
            nn.Linear(noise_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, IMG_PIXELS),
            nn.Tanh(),
        )

    def forward(self, x):
        img = self.main(x)
        # Channel-first layout, the same as the real MNIST batches.
        return img.view(-1, 1, IMG_SIZE, IMG_SIZE)


# ----------------------------- discriminator --------------------------------#
class Discriminator(nn.Module):
    """Score a batch of 28x28 images; returns (N, 1) probabilities of "real"."""

    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            nn.Linear(IMG_PIXELS, 512),
            nn.LeakyReLU(),
            nn.Linear(512, 256),
            nn.LeakyReLU(),
            nn.Linear(256, 1),
            nn.Sigmoid(),  # output is already a probability -> use nn.BCELoss directly
        )

    def forward(self, x):
        x = x.view(-1, IMG_PIXELS)
        # The original called self.main(x) and discarded the result.
        return self.main(x)


# -------------------------------- plotting ----------------------------------#
def gen_img_plot(model, test_input):
    """Show up to 16 images generated from `test_input` in a 4x4 grid."""
    import matplotlib.pyplot as plt  # lazy: only needed when plotting

    with torch.no_grad():
        prediction = model(test_input).cpu().numpy().reshape(-1, IMG_SIZE, IMG_SIZE)
    _, axs = plt.subplots(4, 4, figsize=(8, 8))
    for ax in axs.flat:
        ax.axis("off")
    for ax, img in zip(axs.flat, prediction[:16]):
        # Undo the [-1, 1] normalisation; grayscale instead of the default colormap.
        ax.imshow((img + 1) / 2.0, cmap="gray")
    plt.show()


# -------------------------------- training ----------------------------------#
def train_step(gen, dis, d_optim, g_optim, loss_fn, real_imgs, noise_dim=NOISE_DIM):
    """Run one discriminator update followed by one generator update.

    Args:
        gen, dis: the generator and discriminator modules.
        d_optim, g_optim: optimizers over dis / gen parameters respectively.
        loss_fn: a loss taking (probabilities, targets), e.g. nn.BCELoss().
        real_imgs: a batch of real images already on the training device.
        noise_dim: latent size expected by `gen`.

    Returns:
        (d_loss, g_loss) as Python floats.
    """
    batch = real_imgs.size(0)
    noise = torch.randn(batch, noise_dim, device=real_imgs.device)

    # Discriminator: push real -> 1 and fake -> 0.
    d_optim.zero_grad()
    real_out = dis(real_imgs)
    d_real_loss = loss_fn(real_out, torch.ones_like(real_out))
    fake_imgs = gen(noise)
    # detach(): this step must not propagate gradients into the generator.
    fake_out = dis(fake_imgs.detach())
    d_fake_loss = loss_fn(fake_out, torch.zeros_like(fake_out))
    d_loss = d_real_loss + d_fake_loss
    d_loss.backward()
    d_optim.step()

    # Generator: make the (just updated) discriminator call the fakes real.
    g_optim.zero_grad()
    fake_out = dis(fake_imgs)
    g_loss = loss_fn(fake_out, torch.ones_like(fake_out))
    g_loss.backward()
    g_optim.step()

    return d_loss.item(), g_loss.item()


def main(epochs=EPOCHS):
    """Train the GAN on MNIST, plotting samples after every epoch.

    Returns the per-epoch mean discriminator and generator losses.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    dataloader = build_dataloader()
    imgs, _ = next(iter(dataloader))
    print(imgs.shape)

    gen = Generator().to(device)
    dis = Discriminator().to(device)
    d_optim = optim.Adam(dis.parameters(), lr=LEARNING_RATE)
    g_optim = optim.Adam(gen.parameters(), lr=LEARNING_RATE)
    loss_fn = nn.BCELoss()

    # Fixed noise so the per-epoch plots show how the same samples evolve.
    test_input = torch.randn(16, NOISE_DIM, device=device)

    D_loss, G_loss = [], []
    count = len(dataloader)
    for epoch in range(epochs):
        d_epoch_loss = 0.0
        g_epoch_loss = 0.0
        for img, _ in dataloader:
            d_loss, g_loss = train_step(
                gen, dis, d_optim, g_optim, loss_fn, img.to(device)
            )
            d_epoch_loss += d_loss
            g_epoch_loss += g_loss
        D_loss.append(d_epoch_loss / count)
        G_loss.append(g_epoch_loss / count)
        print("Epoch:", epoch, "D loss:", D_loss[-1], "G loss:", G_loss[-1])
        gen_img_plot(gen, test_input)
    return D_loss, G_loss


if __name__ == "__main__":
    main()
训练 100 个 epoch 后的生成结果:
遇到报错就百度,一步步解决后程序终于能成功运行,但生成结果惨不忍睹,希望路过的大佬帮忙看看问题出在哪里。