Pytorch简单实现GAN细节及原理

GAN(

https://blog.csdn.net/zcyzcyjava/article/details/127535536?utm_medium=distribute.pc_relevant.none-task-blog-2defaultbaidujs_baidulandingword~default-0-127535536-blog-132029778.235^v43^pc_blog_bottom_relevance_base5&spm=1001.2101.3001.4242.1&utm_relevant_index=3

网络架构

在这里插入图片描述

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torchvision import transforms

# 对数据做归一化 (-1, 1)对gan的输入数据全部规范化到(-1,1)之间,normalized = (original - mean) / std
transform = transforms.Compose([   #transform做变形
    transforms.ToTensor(),         # ToTensor会将图像像素值转换为0-1; channel, high, witch,
    transforms.Normalize(0.5, 0.5) #标准化,将均值设置为0.5,标准差为0.5将数据规范化到(-1,1)
])
 
 
train_ds = torchvision.datasets.MNIST(root ='F:\code\Dataset',
                                      train=True,
                                      transform=transform,
                                      download=True)#定义MNIST数据集
 
#加载自定义数据集,打乱,batch_size设置为64
dataloader = torch.utils.data.DataLoader(train_ds, batch_size=64, shuffle=True)
#%%
imgs, _ = next(iter(dataloader))#加载一个批次的图片(64张)
#%%
imgs.shape

生成器

"""
输入是长度为 100 的 噪声(符合正态分布的随机数)
输出为(1, 28, 28)的图片
linear 1 :   100----256
linear 2:    256----512
linear 3:    512----28*28
output:     1*28*28----(1, 28, 28)
"""

class Generator(nn.Module):
    def __init__(self):
        # 继承父类
        super(Generator, self).__init__()
        self.main = nn.Sequential(
                                  nn.Linear(100, 256),
                                  nn.ReLU(),
                                  nn.Linear(256, 512),
                                  nn.ReLU(),
                                  nn.Linear(512, 28*28),
            # 对于-1, 1之间的数据分布,Tanh效果最好。输出的取值范围是-1,1之间
                                  nn.Tanh() 
        )
    def forward(self, x):  # 前向传播,x 表示长度为100 的noise输入
        img = self.main(x) #将x输入到main模型中 得到img
        img = img.view(-1, 28, 28, 1)#通过view函数reshape成(28,28,1),channel在后
        return img

判别器

# 输入为(1, 28, 28)的图片  输出为二分类的概率值,输出使用sigmoid激活 0-1
# BCEloss计算交叉熵损失
 
# nn.LeakyReLU   f(x) : x>0 输出 x, 如果x<0 ,输出 a*x  a表示一个很小的斜率,比如0.1
# 判别器中一般推荐使用 LeakyReLU,RELU激活函数在小于0没有任何梯度,会非常难以训练
 
 
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()#继承父类的属性
        self.main = nn.Sequential(
                 #输入一张图片(28,28),然后展平成28*28,再卷积到256
                                  nn.Linear(28*28, 512),
                                  nn.LeakyReLU(),
                                  nn.Linear(512, 256),
                                  nn.LeakyReLU(),
                                  nn.Linear(256, 1),
                                  nn.Sigmoid()
        )
    def forward(self, x):#x输入的是28,28的图片
        x = x.view(-1, 28*28)# 展平,会丢失一定的空间信息,但是在minst数据集上够用
        x = self.main(x)
        return x

初始化及绘图

device = 'cuda' if torch.cuda.is_available() else 'cpu'#默认使用cuda,否则cpu
#%%
gen = Generator().to(device)    #初始化Generator模型
dis = Discriminator().to(device)#初始化Discriminator模型
#%%
d_optim = torch.optim.Adam(dis.parameters(), lr=0.0001)#定义优化器,学习率
g_optim = torch.optim.Adam(gen.parameters(), lr=0.0001)
#%%
loss_fn = torch.nn.BCELoss()#二分类判别模型

def gen_img_plot(model, test_input):#每次都给一个同样的test_input正态分布随机数
    #detach用来截断梯度,放到cpu上,转换为numpy,squeeze用于去掉维度为一的值,鲁棒性更高==>28*28的数组
    prediction = np.squeeze(model(test_input).detach().cpu().numpy())
    fig = plt.figure(figsize=(4, 4))#绘制16张图片
    for i in range(16):#循环
        plt.subplot(4, 4, i+1)#四行四列的第一张
        plt.imshow((prediction[i] + 1)/2)#转换成0,1之间的数值(预测的结果恢复到0,1之间
        plt.axis('off')#关闭
    plt.show()
#%%
test_input = torch.randn(16, 100, device=device)#生成长度为100的一个批次16张的随机噪声输入

训练

D_loss = []
G_loss = []#定义空列表用来放两个模型生成的loss

"""
初始化损失函数为0
累加,最后除以总epoch得到每个epoch的平均损失
"""
for epoch in range(100):
    d_epoch_loss = 0
    g_epoch_loss = 0
    count = len(dataloader)#返回批次数,len(dataset)返回样本数
    """
    对dataloader进行迭代(总图片/批次大小=迭代次数)
    enumerate:返回值有两个:一个是序号,也就是在这里的batch地址,一个是数据train_ids
    enumerate用于给dataloader的每次迭代加上索引编号step
    img代表当前batch中的图片tensor
    _表示同时解包出来的label张量,但是没有用到,所以直接用_占位
    """
    for step, (img, _) in enumerate(dataloader):
        img = img.to(device)#将照片上传到设备上
        size = img.size(0)  #获批次大小(64),随机噪声的输入64个
        random_noise = torch.randn(size, 100, device=device)#生成噪声随机数,大小个数是size
        
        d_optim.zero_grad()#将梯度归0
        """
        1、输入真实的图片,real_output对真实图片的预测结果,真实图片为1,假图片为0
        得到判别器在真实图像上的损失  ones_like:全1数组,因为target(real)就是1
        
        2、输入生成的图片gen_img,fake_output是对生成图片的预测
        得到判别器在生成图像上的损失,zeros_like:全0数组,因为target(fake)就是0
        """
        # real图像loss
        real_output = dis(img)    
        d_real_loss = loss_fn(real_output, 
                              torch.ones_like(real_output)) 
        d_real_loss.backward()
        # fake图像loss
        gen_img = gen(random_noise)
        # 我们需要更新的参数是D的,所有截断G的梯度传递
        fake_output = dis(gen_img.detach()) 
        d_fake_loss = loss_fn(fake_output,
                              torch.zeros_like(fake_output)) 
        d_fake_loss.backward()
        d_loss = d_real_loss + d_fake_loss
        d_optim.step()#进行优化
        #以上是用来优化判别器

        
        g_optim.zero_grad()
        fake_output = dis(gen_img)    #将生成图片放到判别器当中--不要梯度截断
        g_loss = loss_fn(fake_output, #我们这里就希望fake_output被判定为1用来优化生成器
                         torch.ones_like(fake_output))      # 生成器的损失
        g_loss.backward()
        g_optim.step()
        
        with torch.no_grad():#两个模型的损失函数做累加(不需要计算梯度)---每个批次累加==一个epoch
            d_epoch_loss += d_loss
            g_epoch_loss += g_loss
            
    with torch.no_grad():#得到平均loss
        d_epoch_loss /= count
        g_epoch_loss /= count
        D_loss.append(d_epoch_loss.item())
        G_loss.append(g_epoch_loss.item())#这样列表当中会保存每个epoch的平均loss
        print('Epoch:', epoch)
        print('D_loss', D_loss)
        print('G_loss', G_loss)
        gen_img_plot(gen, test_input)#绘图

fake_output = dis(gen_img.detach())进行detach操作是因为在计算判别器对生成图片的损失fake_output时,gen_img被传入到判别器函数dis中执行前向传播计算。此时gen_img作为输入,其计算梯度需要回传到生成器更新参数。但是这里我们要计算的只是判别器d_loss,更新判别器的参数。如果不detach gen_img,那么计算完fake_output之后,其梯度会同时回传到生成器更新参数,因为这里我们只关心如何根据fake_output来更新判别器的参数,而不是更新生成器的参数。detach操作就将gen_img与后续计算图断开联系,使其不再参与梯度计算和回传。如下图所示

举个例子来说明一下detach有什么用。 如果A网络的输出被喂给B网络作为输入, 如果我们希望在梯度反传的时候只更新B中参数的值,而不更新A中的参数值,这时候就可以使用detach()

  • detach仅仅断开了gen_img本身与后续计算图的连接。
  • 但g_loss计算时,fake_output依然通过dis来计算,此时dis还保留参数。
  • 所以理论上,g_loss计算完后,dis的参数可能已经更新,影响后续d_loss计算。

为什么实际训练中g_loss计算不会影响d_loss,原因是:

  • GAN训练采用交替优化策略, g_loss和d_loss分别单独计算一次。
  • 也就是g_loss和d_loss不会在同一个backward()与step()中一起计算。
  • 所以实际上,g_loss计算完之后,下一步就计算d_loss,中间又reset参数,利用最新dis。

所以总结:

  • detach仅断开输入gen_img,但不确保g_loss计算不影响d_loss
  • 但GAN训练机制下,实际g_loss和d_loss得到了良好分离,各自使用更新后的网络。
  • 9
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值