pytorch_gan_MNIST

该教程详细讲解了如何使用PyTorch实现GAN,包括激活函数的使用、GPU加速训练、数据预处理步骤,以及Generator和Discriminator的构建和训练过程。通过训练MNIST数据集,展示了生成器生成手写数字的示例。
摘要由CSDN通过智能技术生成

学习内容GAN代码实战和原理精讲 PyTorch代码进阶 最简明易懂的GAN生成对抗网络入门课程 使用PyTorch编写GAN实例 2021.12最新课程_哔哩哔哩_bilibili

1、相关代码基础

1.1 激活函数详解

详解激活函数(Sigmoid/Tanh/ReLU/Leaky ReLu等) - 知乎 (zhihu.com)

1.2 PyTorch深度学习框架在训练时,大多都是利用GPU来提高训练速度,怎么用GPU(方法:.cuda()):

.cuda()将数据和模型送入GPU中

(59条消息) PyTorch关于以下方法使用:detach() cpu() numpy() 以及item()_Karl_G的博客-CSDN博客

1.3 阻断反向传播

pred = np.squeeze(model(test_input).detach().cpu().numpy())
得到结果;截断梯度;放在cpu上,返回值为tensor;转换为numpy数据;删除指定维度,即把shape中为1的维度去掉;

2、代码及结果

 2.1 代码

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torchvision import datasets
from torchvision import transforms  # 对数据进行原始处理
from torch.utils.data import DataLoader

# 数据准备
# 对数据归一化(-1,1)
transform = transforms.Compose([
	transforms.ToTensor(), #ToTensor()能够把灰度范围从0-255变换到0-1之间;[channel,high,width]
	transforms.Normalize(0.5, 0.5) #把0-1变换到(-1,1),因为使用tanh函数做激活
])

# 加载内置数据集,只需要图片,不需要标签,也不需要测试
#train_ds = torchvision.datasets.MNIST('data',  # 读谁?
#                                      train=True,
#                                      transform=transform,
#                                      download=True)
train_ds = datasets.MNIST(root='../dataset/mnist/', train=True, download=True, transform=transform)# 放置位置
dataloader = torch.utils.data.DataLoader(train_ds, batch_size=64, shuffle=True)  # 怎么读? shuffle=True 不容易过拟合

# define G
# 输入: 长度为100的正态分布噪声
# 输出:图片(1,28,28)
# imgs, _ = next(iter(dataloader))#占位符是标签,不需要
# print(imgs.shape)
# torch.Size([61,1,28,28])

class Generator(nn.Module):
	def __init__(self):
		super(Generator, self).__init__()
		self.main = nn.Sequential(
			nn.Linear(100, 256),#100-256
			nn.ReLU(),
			nn.Linear(256, 512),#256-512
			nn.ReLU(),
			nn.Linear(512, 28 * 28),#512-28*28
			nn.Tanh()  # 需要注意
		)

	def forward(self, x):  # 长度100的噪声
		img = self.main(x) #img还没有被展平(28*28)
		img = img.view(-1, 28, 28)  # reshape:28*28-(1,28,28)
		return img


# define D
# 输入:图片
# 输出:二分类的概率值,使用sigmoid激活,BCE loss
class Discriminator(nn.Module):
	def __init__(self):
		super(Discriminator, self).__init__()
		self.main = nn.Sequential(
			nn.Linear(28 * 28, 512),
			nn.LeakyReLU(),  # 在负数的时候也会给一个非常小的斜率,需要注意
			nn.Linear(512, 256),
			nn.LeakyReLU(),
			nn.Linear(256, 1),
			nn.Sigmoid()
		)

	def forward(self, x):
		x = x.view(-1, 28 * 28)
		x = self.main(x)
		return x


# 初始化模型,优化器,损失函数计算
device = 'cuda' if torch.cuda.is_available() else 'cpu'

gen = Generator().to(device)# 初始化
dis = Discriminator().to(device)# 初始化

d_optim = torch.optim.Adam(dis.parameters(), lr=0.0001)
g_optim = torch.optim.Adam(gen.parameters(), lr=0.0001)

loss_fn = torch.nn.BCELoss()


# 绘图
def gen_img_plot(model, test_input):#test_input是同样的随机数
	pred = np.squeeze(model(test_input).detach().cpu().numpy())#得到结果;截断梯度;放在cpu上,返回值为tensor;转换为numpy数据;删除指定维度,即把shape中为1的维度去掉;28*28
	fig = plt.figure(figsize=(4, 4))#16张图片
	for i in range(16):
		plt.subplot(4, 4, i + 1)
		plt.imshow((pred[i] + 1) / 2)#将(-1,1)恢复成(0,1)之间进行绘图
		plt.axis('off')
	plt.show()


test_input = torch.randn(16, 100, device=device)#16个长度为100的随机输入,(16*100),产生16张图片

# Gan的训练
D_loss = []
G_loss = []

# 编写训练循环
for epoch in range(20):
	D_epoch_loss = 0  # 计算每一个epoch的平均loss
	G_epoch_loss = 0
	count = len(dataloader)  # 返回批次数 len(dataset)返回样本数
	for step, (img, _) in enumerate(dataloader):#对dataloader进行迭代
		img = img.to(device)
		size = img.size(0)  # 返回img第一维大小
		random_noise = torch.randn(size, 100, device=device)

		d_optim.zero_grad()  # 梯度归0

		real_output = dis(img)  # 对判别器输入真实图片,得到对真实图片的判断结果
		d_real_loss = loss_fn(real_output, torch.ones_like(real_output))  # 得到判别器在真实数据上的损失

		d_real_loss.backward()

		gen_img = gen(random_noise)
		fake_output = dis(gen_img.detach())  # 判别器输入生成图片,fake_output对生成图片的判断结果

		d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output))  # 得到判别器在生成图片上的损失

		d_fake_loss.backward()
		d_loss = d_real_loss + d_fake_loss
		d_optim.step()

		# 生成器的损失和优化
		g_optim.zero_grad()
		fake_output = dis(gen_img)
		# g_loss = loss_fn(fake_output, torch.ones_like(fake_output), device=device)  # 得到生成器的损失
		g_loss = loss_fn(fake_output, torch.ones_like(fake_output))
		g_loss.backward()
		g_optim.step()

		with torch.no_grad():
			D_epoch_loss += d_loss
			G_epoch_loss += g_loss

	with torch.no_grad():
		D_epoch_loss /= count
		G_epoch_loss /= count
		D_loss.append(D_epoch_loss)
		G_loss.append(G_epoch_loss)
		print('epoch:', epoch)
		gen_img_plot(gen, test_input)











2.2 结果

3、代码思路

 

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值