Pytorch|GAN在手写数字集上的复现

在复现开始之前需要知道几个tricky但很重要的知识点:
1.在pytorch中,神经网络层中的权值weight和偏差bias的tensor均为叶子节点,自己定义的tensor例如a=torch.tensor([1.0])定义的节点是叶子节点,中间计算产生的变量都叫非叶子节点。默认情况下,只有叶子节点的梯度值能够被保留下来,非叶子节点的梯度值在反向传播过程中使用完后就会被清除,不会被保留,除非使用 retain_grad() 方法。backward函数是计算当前tensor对计算图的叶子节点的梯度。backward函数的计算方式中,梯度是累积计算而不是被替换,所以不清0的话梯度就会累加上去。

2.fake.detach()返回的是一个新的tensor,从当前计算图中分离下来的,但是仍指向原变量fake的存放位置,不同之处只是requires_grad为false,得到的这个tensor永远不需要计算其梯度,不具有grad。进一步理解就是fake.detach()这个tensor变成了当前计算图的叶子节点但是不会求叶子节点的梯度了,计算图其实就是代表程序中变量之间的关系,计算图会在backward()函数执行后被清理掉,由于叶子节点变成了fake.detach(),那么fake前的变量的计算关系是没有被清理掉的。

3.元组中只有一个数据要加逗号,

tup1 = (23)  # 不是元组
print(type(tup1)) #<class 'int'>
tup2 = (23,)  # 是元组
print(type(tup2))#<class 'tuple'>

4.对于Gan的判别器D的损失函数,判别器的目标是希望判别器能够更好的区分出生成样本和真实样本,具体体现在数学公式中为,因为期望就是平均数随样本趋于无穷的极限,所以这里把两个式子分别变换为BCEloss后,用均值来代替期望!提一嘴在cyclegan中是因为借鉴了最小二乘gan所以用的是MSE损失而不是BCE损失:
在这里插入图片描述

现在进入正题,开始复现

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter  # to print to tensorboard

class Discriminator(nn.Module):
    def __init__(self,img_dim):
        super().__init__()
        self.disc = nn.Sequential(
            nn.Linear(img_dim,128),
            nn.LeakyReLU(0.1),
            nn.Linear(128,1),
            nn.Sigmoid(),
        )
    def forward(self,x):
        return self.disc(x)

class Generator(nn.Module):
    def __init__(self,z_dim,img_dim):
        super().__init__()
        self.gen = nn.Sequential(nn.Linear(z_dim,256),
                                 nn.LeakyReLU(0.1),
                                 nn.Linear(256,img_dim),#28*28*1 -> 784
                                 nn.Tanh(),
                                 )
    def forward(self,x):
        return self.gen(x)

device ="cuda" if torch.cuda.is_available() else "cpu"
lr = 3e-4
z_dim = 64
img_dim = 28*28*1
batch_size = 32
num_epochs =50
disc = Discriminator(img_dim).to(device)
gen = Generator(z_dim,img_dim).to(device)
fixed_noise = torch.randn((batch_size,z_dim)).to(device)
transforms = transforms.Compose(
    #transforms.Normalize是channel的对图像进行标准化,当数据维数为1时,数据后面要有逗号
    #因为传进去的数据类型是元组,所以要加逗号,元组中只有一个数据要加逗号,这传list也行
    [transforms.ToTensor(),transforms.Normalize((0.5,),(0.5,))]#image=(image-mean)/std,img.shape(28,28)
)
dataset = datasets.MNIST(root="dataset/",transform=transforms,download=True)
loader = DataLoader(dataset,batch_size=batch_size,shuffle=True)
opt_disc = optim.Adam(disc.parameters(),lr=lr)
opt_gen = optim.Adam(gen.parameters(),lr=lr)
criterion = nn.BCELoss()
writer_fake = SummaryWriter(f"runs/GAN_MNIST/fake")#for tensorboard
writer_real = SummaryWriter(f"runs/GAN_MNIST/real")#for tensorboard
step = 0#for tensorboard

for epoch in range(num_epochs):
    for batch_idx, (real,_) in enumerate(loader):
        real = real.view(-1,784).to(device)
        batch_size =real.shape[0]

        ####Traning Discriminator: 最大化 log(D(real))+log(1-D(G(z)))
        noise = torch.randn(batch_size,z_dim).to(device)
        fake = gen(noise)
        disc_real = disc(real).view(-1) #shape: torch.Size([32])
        lossD_real = criterion(disc_real,torch.ones_like(disc_real))#最大化log(D(real))这一项就等于最小化这个的bce损失

        #用.detach()或者77行的retain_graph=True都可以,二选一
       # disc_fake = disc(fake.detach()).view(-1) #截断fake节点前的梯度传播,所以fake.detach这个tensor此时就成了叶子节点
        disc_fake = disc(fake).view(-1)
        lossD_fake = criterion(disc_fake,torch.zeros_like(disc_fake))
        lossD = (lossD_fake+lossD_real)/2

        disc.zero_grad() #只对判别器模型的梯度清0,不清0的话梯度会叠加,这里生成器参数的梯度就没有清掉
      #lossD.backward()

        #backward函数是计算当前tensor对图叶子结点的梯度
        #计算图在backward一次之后各个节点的值会清除,但因为我们下面还要backward一次,所以需要retain_graph=True保存这个图。
        #因为进行了backward后,叶子节点的梯度值是保存了,但计算图被释放了
        lossD.backward(retain_graph=True)
        opt_disc.step()#只更新判别器的参数

        ####Traning Generator: 最小化 log(1-D(G(z))) ->最大化 log(D(G(Z))),因为这样梯度比较大
        output = disc(fake).view(-1)
        lossG = criterion(output,torch.ones_like(output))
        gen.zero_grad() #对生成器模型参数的梯度清0,如果不清0,下次backward计算就会形成累加
        lossG.backward()#由于保留了计算图,这样就可以求叶子结点生成器的参数:gen.grad
        opt_gen.step()#只更新生成器的参数
        if batch_idx == 0:
            print(
                f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(loader)} \
                              Loss D: {lossD:.4f}, loss G: {lossG:.4f}"
            )

            with torch.no_grad():
                fake = gen(fixed_noise).reshape(-1, 1, 28, 28)
                data = real.reshape(-1, 1, 28, 28)
                img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
                img_grid_real = torchvision.utils.make_grid(data, normalize=True)

                writer_fake.add_image(
                    "Mnist Fake Images", img_grid_fake, global_step=step
                )
                writer_real.add_image(
                    "Mnist Real Images", img_grid_real, global_step=step
                )
                step += 1


  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值