简单实现基于pytorch的gan网络

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档


前言

提示:这里可以添加本文要记录的大概内容:

例如:随着人工智能的不断发展,参与计算机视觉研究的人员也越来越多,今天简单地介绍一下基于pytorch的gan的实现。


提示:以下是本篇文章正文内容,下面案例可供参考

一、gan是什么?

示例:生成对抗网络(gan)是一种神经网络,主要由判别器(Discriminator)和生成器(Generator)组成。其中生成器的作用是根据输入的随机噪音产出一个相似与原始数据的图片。而判别器的作用是根据来自数据源的图片不断调整自己的参数和判断生成器产生的照片是否为真。并促使生成器调整期参数。

二、使用步骤

1.引入库

代码如下(示例):

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torchvision import transforms

2.下载数据并对数据进行规范

代码如下(示例):

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(0.5 , 0.5)
])
train_ds = torchvision.datasets.MNIST('data', train=True, transform=transform, download=True)
dataloader = torch.utils.data.DataLoader(train_ds, batch_size=64, shuffle= True)

该处使用torchvision的数据集。


3.生成器的代码和判别器的代码

生成器代码如下(示例):

class Generator(nn.Module):
    def __init__(self):
        super(Generator,self).__init__()
        self.main = nn.Sequential(
               nn.Linear(100,256),
               nn.ReLU(),
               nn.Linear(256,512),
               nn.ReLU(),
               nn.Linear(512, 28*28),
               nn.Tanh()
           )
    def forward(self, x):
        img = self.main(x)
        img = img.reshape(-1, 28, 28)
        return img

判别器代码如下(示例):

class Discraiminator(nn.Module):
    def __init__(self):
        super(Discraiminator, self).__init__()
        self.main = nn.Sequential(
                    nn.Linear(28*28, 512),
                    nn.LeakyReLU(),
                    nn.Linear(512,256),
                    nn.LeakyReLU(),
                    nn.Linear(256,1),
                    nn.Sigmoid()
        )
    def forward(self, x):
        x = x.view(-1, 28*28)
        x = self.main(x)
        return x

4.定义损失函数和优化函数

代码如下(示例):

device = 'cuda' if torch.cuda.is_available() else 'cpu'
gen = Generator().to(device)
dis = Discraiminator().to(device)###定义生成器和判别器
gen_opt = optim.Adam(gen.parameters(), lr=0.0001)
dis_opt = optim.Adam(dis.parameters(), lr=0.0001)
loss_fn = torch.nn.BCELoss()

5.定义绘图函数

代码如下(示例):

def gen_img_plot(model,test_input):
    prediction = np.squeeze(model(test_input).detach().cpu().numpy())
    fig = plt.figure(figsize=(4, 4))
    for i in range(16):
        plt.subplot(4, 4, i+1)
        plt.imshow((prediction[i]+1)/2)
        plt.axis('off')
    plt.show()

6.开始训练,并显示出生成器所产生的图像

代码如下(示例):

test_input = torch.randn(16, 100, device=device)
D_loss =[]
G_loss =[]
for epoch in range(30):
    d_epoch_loss = 0
    g_epoch_loss = 0
    count = len(dataloader)
    for step, (img, _) in enumerate(dataloader):
        img = img.to(device)
        size = img.size(0)
        random_noise = torch.randn(size, 100, device = device)
        
        dis_opt.zero_grad()
        real_output = dis(img)
        d_real_loss = loss_fn(real_output, torch.ones_like(real_output))
        d_real_loss.backward()
        
        gen_img = gen(random_noise)
        fake_output = dis(gen_img.detach())
        d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output))
        d_fake_loss.backward()
         
        d_loss = d_real_loss + d_fake_loss
        dis_opt.step()
        
        gen_opt.zero_grad()
        fake_output = dis(gen_img)
        g_loss = loss_fn(fake_output,torch.ones_like(fake_output))
        g_loss.backward()
        gen_opt.step()
        
        with torch.no_grad():
            d_epoch_loss += d_loss
            g_epoch_loss += g_loss
    with torch.no_grad():
        d_epoch_loss /= count
        g_epoch_loss /= count
        D_loss.append(d_epoch_loss)
        G_loss.append(g_epoch_loss)
        print('epoch:', epoch)
        gen_img_plot(gen,test_input)

最后产生的图片

如下图所示,分别为第一轮生成器产生图片结果和第三十轮的结果
第一轮训练结果
第30轮训练结果

  • 4
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值