生成对抗网络代码详解(一):GAN

首先导入必要的模块

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.utils import save_image

import os

然后设置一些超参数

z_dim = 100 #噪声维度
batch_size = 64
learning_rate = 0.0002
total_epochs = 100

如果你的计算机有GPU,可以指定使用哪块GPU

gpu_id ='0'
if gpu_id is not None:
    os.environ['CUDA_VISIBLE_DEVICES'] = gpu_id
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

创建路径,存放每次迭代生成的图片

if os.path.exists('gan_images') is False:
    os.makedirs('gan_images')

以上都是一些准备工作,接下来开始定义模型,首先是判别器Discriminator:

class Discriminator(nn.Module):
    '''定义全连接判别器'''
    def __init__(self):
         super(Discriminator,self).__init__()

         layers =[]
         #first floor
         layers.append(nn.Linear(in_features=28*28,out_features=512,bias=True))
         layers.append(nn.LeakyReLU(0.2,inplace=True))
         #second floor
         layers.append(nn.Linear(in_features=512, out_features=256, bias=True))
         layers.append(nn.LeakyReLU(0.2, inplace=True))
         # outpur floor
         layers.append(nn.Linear(in_features=256, out_features=1, bias=True))
         layers.append(nn.Sigmoid())

         self.model = nn.Sequential(*layers)

    def forward(self, x):
         x = x.view(x.size(0),-1)
         validity = self.model(x)
         return validity

首先创建一个空列表layers,存放判别器的全连接层。由于我使用的是MNIST数据集,因此第一层的输入为28*28,输出特征设置为512,并添加激活层LeakyReLU,第三层输出特征维度为1,并且添加sigmoid函数,然后定义前项传播函数foward,首先将输入x转换为一个列向量,传入定义好的model模型中。
然后定义生成器Generator,同意使用全连接层。

class Generator(nn.Module):
    '''全连接生成器'''
    def __init__(self,z_dim):
        super(Generator, self).__init__()

        layers = []
        # first floor
        layers.append(nn.Linear(in_features=z_dim, out_features=128))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
        # second floor
        layers.append(nn.Linear(in_features=128, out_features=256))
        layers.append(nn.BatchNorm1d(256,0.8))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
        # third floor
        layers.append(nn.Linear(in_features=256, out_features=512))
        layers.append(nn.BatchNorm1d(512, 0.8))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
        # outpur floor
        layers.append(nn.Linear(in_features=512, out_features=28*28))
        layers.append(nn.Tanh())

        self.model = nn.Sequential(*layers)
    def forward(self, z):
        x = self.model(z)
        x = x.view(-1,1,28,28)
        return  x

生成器的输入为z_dim维的噪声,并且在生成器的全连接层后增加了BN层,有利于网络的收敛。噪声z经过model后输出,需要进行resize为12828的tensor。
定义完成判别器和生成器之后,就要对他们进行初始化,如果电脑上有GPU就可以将整个模型运算的过程放入GPU中运算。

discriminator = Discriminator().to(device)
generator = Generator(z_dim=z_dim).to(device)

初始化交叉熵损失和优化器:

bce = nn.BCELoss().to(device)
ones = torch.ones(batch_size).to(device)
zeros = torch.zeros(batch_size).to(device)

#初始化优化器
g_optimizer = optim.Adam(generator.parameters(),lr = learning_rate,betas=[0.5,0.999])
d_optimizer = optim.Adam(discriminator.parameters(),lr = learning_rate,betas=[0.5,0.999])

在做完以上的准备工作后,就可以加载数据集,并且生成一些随机的噪声作为生成器的输入:

#load dataset
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,),(0.5,))]) #归一化-0.5 /0.5
dataset = torchvision.datasets.MNIST(root='data/',train=True,transform=transform,download=True)
dataloader = torch.utils.data.DataLoader(dataset,batch_size=batch_size,shuffle=True,drop_last=True)

#随机产生100个向量,用于生成效果图
fixed_z = torch.randn([100,z_dim]).to(device)

transform.compose将数据转换为Torch可处理的tensor,并进行归一化处理。
接下来就是训练的过程,先将生成器设置为训练模式,对应于后面的测试阶段,再生成器设置为测试模式,分别计算判别器和生成器的损失,并且进行优化:

        #计算判别器损失,并优化判别器
        real_loss = bce(discriminator(real_images),ones)
        fake_loss = bce(discriminator(fake_images.detach()),zeros)
        d_loss = real_loss + fake_loss

        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        #计算生成器损失,并优化生成器
        g_loss = bce(discriminator(fake_images),ones)

        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

完整代码下载地址

  • 1
    点赞
  • 37
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值