提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档
文章目录
前言
提示:这里可以添加本文要记录的大概内容:
例如:随着人工智能的不断发展,参与计算机视觉研究的人员也越来越多,今天简单地介绍一下基于pytorch的gan的实现。
提示:以下是本篇文章正文内容,下面案例可供参考
一、gan是什么?
示例:生成对抗网络(gan)是一种神经网络,主要由判别器(Discriminator)和生成器(Generator)组成。其中生成器的作用是根据输入的随机噪音产出一个相似与原始数据的图片。而判别器的作用是根据来自数据源的图片不断调整自己的参数和判断生成器产生的照片是否为真。并促使生成器调整期参数。
二、使用步骤
1.引入库
代码如下(示例):
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torchvision import transforms
2.下载数据并对数据进行规范
代码如下(示例):
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(0.5 , 0.5)
])
train_ds = torchvision.datasets.MNIST('data', train=True, transform=transform, download=True)
dataloader = torch.utils.data.DataLoader(train_ds, batch_size=64, shuffle= True)
该处使用torchvision的数据集。
3.生成器的代码和判别器的代码
生成器代码如下(示例):
class Generator(nn.Module):
def __init__(self):
super(Generator,self).__init__()
self.main = nn.Sequential(
nn.Linear(100,256),
nn.ReLU(),
nn.Linear(256,512),
nn.ReLU(),
nn.Linear(512, 28*28),
nn.Tanh()
)
def forward(self, x):
img = self.main(x)
img = img.reshape(-1, 28, 28)
return img
判别器代码如下(示例):
class Discraiminator(nn.Module):
def __init__(self):
super(Discraiminator, self).__init__()
self.main = nn.Sequential(
nn.Linear(28*28, 512),
nn.LeakyReLU(),
nn.Linear(512,256),
nn.LeakyReLU(),
nn.Linear(256,1),
nn.Sigmoid()
)
def forward(self, x):
x = x.view(-1, 28*28)
x = self.main(x)
return x
4.定义损失函数和优化函数
代码如下(示例):
device = 'cuda' if torch.cuda.is_available() else 'cpu'
gen = Generator().to(device)
dis = Discraiminator().to(device)###定义生成器和判别器
gen_opt = optim.Adam(gen.parameters(), lr=0.0001)
dis_opt = optim.Adam(dis.parameters(), lr=0.0001)
loss_fn = torch.nn.BCELoss()
5.定义绘图函数
代码如下(示例):
def gen_img_plot(model,test_input):
prediction = np.squeeze(model(test_input).detach().cpu().numpy())
fig = plt.figure(figsize=(4, 4))
for i in range(16):
plt.subplot(4, 4, i+1)
plt.imshow((prediction[i]+1)/2)
plt.axis('off')
plt.show()
6.开始训练,并显示出生成器所产生的图像
代码如下(示例):
test_input = torch.randn(16, 100, device=device)
D_loss =[]
G_loss =[]
for epoch in range(30):
d_epoch_loss = 0
g_epoch_loss = 0
count = len(dataloader)
for step, (img, _) in enumerate(dataloader):
img = img.to(device)
size = img.size(0)
random_noise = torch.randn(size, 100, device = device)
dis_opt.zero_grad()
real_output = dis(img)
d_real_loss = loss_fn(real_output, torch.ones_like(real_output))
d_real_loss.backward()
gen_img = gen(random_noise)
fake_output = dis(gen_img.detach())
d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output))
d_fake_loss.backward()
d_loss = d_real_loss + d_fake_loss
dis_opt.step()
gen_opt.zero_grad()
fake_output = dis(gen_img)
g_loss = loss_fn(fake_output,torch.ones_like(fake_output))
g_loss.backward()
gen_opt.step()
with torch.no_grad():
d_epoch_loss += d_loss
g_epoch_loss += g_loss
with torch.no_grad():
d_epoch_loss /= count
g_epoch_loss /= count
D_loss.append(d_epoch_loss)
G_loss.append(g_epoch_loss)
print('epoch:', epoch)
gen_img_plot(gen,test_input)
最后产生的图片
如下图所示,分别为第一轮生成器产生图片结果和第三十轮的结果