【个人学习笔记】真代码0基础从0开始一步步python实现一个基础GAN网络

使用mnist数据集简单实现GAN,代码源码来源于github,在这里仅供学习!
源代码只进行了训练没有进行验证
(只涉及到代码实现,不涉及到任何理论)

代码模块及细节

1 初始化

#导入库
import argparse
import os
import numpy as np
import torch
#-------------设置参数-------------------
#不明白参数含义的可以看看help的英文注释
parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches") 
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space") # 随机噪声z的维度
parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension") # 输入图像的尺寸
parser.add_argument("--channels", type=int, default=1, help="number of image channels") # 输入图像的channel数
parser.add_argument("--sample_interval", type=int, default=400, help="interval betwen image samples") # 保存生成图像和模型的间隔
opt = parser.parse_known_args()[0]#parser.parse_known_args[0]查看上面设置的参数
print("opt =", opt)
#设置图片的大小
img_shape = (opt.channels, opt.img_size, opt.img_size)
print("img_shape =", img_shape)

2 数据加载

数据格式为 28 * 28 的灰度图,以及对应 0-9 的数字标签,在这里我们以常见的MNIST数据集来实验

#导入库
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets

# Configure data loader
#常用os方式,如果没有某个文件夹,直接创建这个文件夹
if not os.path.exists('./home/featurize/mnist'):
    os.makedirs("./home/featurize/mnist", exist_ok=True)
#调用pytorch自带的数据集,如果没有mnist数据集会自动下载  
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "./home/featurize/mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),

    batch_size=opt.batch_size,
    shuffle=True,
)
from torch.autograd import Variable
import matplotlib.pyplot as plt

def show_img(img, trans=True):
    if trans:
        img = np.transpose(img.detach().cpu().numpy(), (1, 2, 0))  # 把channel维度放到最后
        plt.imshow(img[:, :, 0], cmap="gray")
    else:
        plt.imshow(img, cmap="gray")
    plt.show()
    
mnist = datasets.MNIST("../../data/mnist")

for i in range(3):
    sample = mnist[i][0]
    label = mnist[i][1]
    show_img(np.array(sample), trans=False)
    print("label =", label, '\n')

代码在实现的时候,用到了Variable对象。Variable对Tensor对象进行封装,只需要Variable::data即可取出Tensor,并且Variable还封装了该Tensor的梯度Variable::grad(是个Variable对象)。现在用Variable作为计算图的节点,则通过反向传播自动求得的导数就保存在Variable对象中了。简单来说就是Variable可以自动求梯度,嫌麻烦可以不使用,自己手动更新梯度

from torch.autograd import Variable
import matplotlib.pyplot as plt
#定义一个展示图片的函数,detach()切断反向传播,transpose()用来更换维度,在这里我把第一维channel通道数放到了最后
def show_img(img, trans=True):
    if trans:
        img = np.transpose(img.detach().cpu().numpy(), (1, 2, 0))  # 把channel维度放到最后,这是因为plt.imshow只接受RGB图像,即高*宽*channel数
        plt.imshow(img[:, :, 0], cmap="gray")
    else:
        plt.imshow(img, cmap="gray")
    plt.show()
    
mnist = datasets.MNIST("../../data/mnist")
#查看一下前三个图片,打印他们的图像和标签
for i in range(3):
    sample = mnist[i][0]
    label = mnist[i][1]
    show_img(np.array(sample), trans=False)
    print("label =", label, '\n')

为了更好地观察transforms的机制,在这里展示了transforms的每一步变化和输出,希望对初学者能够对transforms的使用更清晰一点。我作为一个纯小白却是清晰很多了,最主要的是print他们,看看他们的样子

trans_resize = transforms.Resize(opt.img_size)
trans_to_tensor = transforms.ToTensor()
trans_normalize = transforms.Normalize([0.5], [0.5]) # x_n = (x - 0.5) / 0.5

print("shape =", np.array(sample).shape, '\n')
print("data =", np.array(sample), '\n')
samlpe = trans_resize(sample)
print("(trans_resize) shape =", np.array(sample).shape, '\n')
sample = trans_to_tensor(sample)
print("(trans_to_tensor) data =", sample, '\n')
sample = trans_normalize(sample)
print("(trans_normalize) data =", sample, '\n')

3 模型

3.1生成器

包含5个全连接层,使用LeakyReLU和Tanh激活函数,使用了BatchNorm
并且,个人认为LeakyReLU激活函数是GAN网络中最常用到的激活函数,当然仅是个人看法

GAN的模型关键的两个部分分别是生成器和鉴别器(或者也可以叫判别器,都可以)
价值函数或者说是LOSS是最重要的

首先定义一个生成器,这里为了简单,只使用了简单的几个全连接层,也就是说使用了多层感知机,这也与原文做了对照,并进行了简单的修改。

import torch.nn as nn

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(opt.latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            #np.prod()函数用来计算所有元素的乘积,对于有多个维度的数组可以指定轴,如axis=1指定计算每一行的乘积。
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        #view()的作用相当于numpy中的reshape,重新定义矩阵的形状。view()主要用在pytoch的tensor
        img = img.view(img.size(0), *img_shape)
        return img
    
generator = Generator()
print(generator)

在这里插入图片描述

3.2鉴别器

包含3个全连接层,使用LeakyReLU和Sigmoid激活函数

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)

        return validity
    
discriminator = Discriminator()
print(discriminator)

在这里插入图片描述

4 损失函数

使用 Binary Cross Entropy Loss (应该是最常见的)

# Loss function--BCELOSS
adversarial_loss = torch.nn.BCELoss()

5 使用Cuda,加速计算

#判断有没有GPU,如果有就用,没有就使用CPu;并且如果使用GPU,模型model,loss和数据都需要挂cuda
#比如 device = torch.device("cuda" if use_cuda else "cpu")
#x=x.to(device),x就是一个输入的tensor类型数据

#是否使用cuda
cuda = True if torch.cuda.is_available() else False
print("cuda_is_available =", cuda)
if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

6 优化器

使用Adam优化器(最常见也是最好用的优化器之一,一般不用改优化器,就是用adam)
在这里对生成器和鉴别器都采用adam优化器

# Optimizers,b1,b2代表adam的参数
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
print("learning_rate =", opt.lr)

7 创建输入

分别从数据集和随机向量中获取输入
在这里我们取出前几个图片

for i, (imgs, _) in list(enumerate(dataloader))[:1]:
    real_imgs = Variable(imgs.type(Tensor))
    # Sample noise as generator input生成符合正态分布的噪声,大小为(64,100)
    z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))
    print("i =", i, '\n')
    print("shape of z =", z.shape, '\n')#shape of z = torch.Size([64, 100]) 
    print("shape of real_imgs =", real_imgs.shape, '\n')#shape of real_imgs = torch.Size([64, 1, 28, 28]) 
    print("z =", z, '\n')
    print("real_imgs =")#<built-in method type of Tensor object at 0x7efcf792b6b0>
    for img in real_imgs[:3]:
        show_img(img)

这里是全部的

for i, (imgs, _) in enumerate(dataloader):
    print(real_imgs.type)
    real_imgs = Variable(imgs.type(Tensor))
    # Sample noise as generator input
    z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))
    print("i =", i, '\n')
    print("shape of z =", z.shape, '\n')
    print("shape of real_imgs =", real_imgs.shape, '\n')
    print("z =", z, '\n')
    print("real_imgs =")

8 计算loss,反向传播

分别对生成器和判别器计算loss,使用反向传播更新模型参数

    # Adversarial ground truths
    #定义两个tensor,一个为真,全为1,;一个为假,全为0.大小为imgs的imgs.size(0)
    valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False) # 为1时判定为真
    fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False) # 为0时判定为假
    
    # ---------------------
    #  Train Generator
    # ---------------------
    #梯度清0
    optimizer_G.zero_grad()
    #通过生成器生成图像
    gen_imgs = generator(z) 
    print("gen_imgs =")
    for img in gen_imgs:
        show_img(img)

    # Loss measures generator's ability to fool the discriminator
    g_loss = adversarial_loss(discriminator(gen_imgs), valid)
    print("g_loss =", g_loss, '\n')
    

    g_loss.backward()
    #更新参数
    optimizer_G.step()

    # ---------------------
    #  Train Discriminator
    # ---------------------

    optimizer_D.zero_grad()

    # Measure discriminator's ability to classify real from generated samples
    real_loss = adversarial_loss(discriminator(real_imgs), valid)
    fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
    #总的loss是real_loss+fake_loss
    d_loss = (real_loss + fake_loss) / 2
    print("real_loss =", real_loss, '\n')
    print("fake_loss =", fake_loss, '\n')
    print("d_loss =", d_loss, '\n')    
    
    d_loss.backward()
    optimizer_D.step()

9 保存生成图像和模型文件

    from torchvision.utils import save_image

    epoch = 0 # temporary
    batches_done = epoch * len(dataloader) + i
    if batches_done % opt.sample_interval == 0:
        save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True) # 保存生成图像
        
        os.makedirs("./home/featurize/model", exist_ok=True) # 保存模型
        torch.save(generator, './home/featurize/model/generator.pkl') 
        torch.save(discriminator, './home/featurize/model/discriminator.pkl')
        
        print("gen images saved!\n")
        print("model saved!")
  • 0
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值