学习内容GAN代码实战和原理精讲 PyTorch代码进阶 最简明易懂的GAN生成对抗网络入门课程 使用PyTorch编写GAN实例 2021.12最新课程_哔哩哔哩_bilibili
1、相关代码基础
1.1 激活函数详解
详解激活函数(Sigmoid/Tanh/ReLU/Leaky ReLu等) - 知乎 (zhihu.com)
1.2 PyTorch深度学习框架在训练时,大多都是利用GPU来提高训练速度,怎么用GPU(方法:.cuda()):
.cuda()将数据和模型送入GPU中
(59条消息) PyTorch关于以下方法使用:detach() cpu() numpy() 以及item()_Karl_G的博客-CSDN博客
1.3 阻断反向传播
pred = np.squeeze(model(test_input).detach().cpu().numpy())
得到结果;截断梯度;放在cpu上,返回值为tensor;转换为numpy数据;删除指定维度,即把shape中为1的维度去掉;
2、代码及结果
2.1 代码
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torchvision import datasets
from torchvision import transforms # 对数据进行原始处理
from torch.utils.data import DataLoader
# 数据准备
# 对数据归一化(-1,1)
transform = transforms.Compose([
transforms.ToTensor(), #ToTensor()能够把灰度范围从0-255变换到0-1之间;[channel,high,width]
transforms.Normalize(0.5, 0.5) #把0-1变换到(-1,1),因为使用tanh函数做激活
])
# 加载内置数据集,只需要图片,不需要标签,也不需要测试
#train_ds = torchvision.datasets.MNIST('data', # 读谁?
# train=True,
# transform=transform,
# download=True)
train_ds = datasets.MNIST(root='../dataset/mnist/', train=True, download=True, transform=transform)# 放置位置
dataloader = torch.utils.data.DataLoader(train_ds, batch_size=64, shuffle=True) # 怎么读? shuffle=True 不容易过拟合
# define G
# 输入: 长度为100的正态分布噪声
# 输出:图片(1,28,28)
# imgs, _ = next(iter(dataloader))#占位符是标签,不需要
# print(imgs.shape)
# torch.Size([61,1,28,28])
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.main = nn.Sequential(
nn.Linear(100, 256),#100-256
nn.ReLU(),
nn.Linear(256, 512),#256-512
nn.ReLU(),
nn.Linear(512, 28 * 28),#512-28*28
nn.Tanh() # 需要注意
)
def forward(self, x): # 长度100的噪声
img = self.main(x) #img还没有被展平(28*28)
img = img.view(-1, 28, 28) # reshape:28*28-(1,28,28)
return img
# define D
# 输入:图片
# 输出:二分类的概率值,使用sigmoid激活,BCE loss
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.main = nn.Sequential(
nn.Linear(28 * 28, 512),
nn.LeakyReLU(), # 在负数的时候也会给一个非常小的斜率,需要注意
nn.Linear(512, 256),
nn.LeakyReLU(),
nn.Linear(256, 1),
nn.Sigmoid()
)
def forward(self, x):
x = x.view(-1, 28 * 28)
x = self.main(x)
return x
# 初始化模型,优化器,损失函数计算
device = 'cuda' if torch.cuda.is_available() else 'cpu'
gen = Generator().to(device)# 初始化
dis = Discriminator().to(device)# 初始化
d_optim = torch.optim.Adam(dis.parameters(), lr=0.0001)
g_optim = torch.optim.Adam(gen.parameters(), lr=0.0001)
loss_fn = torch.nn.BCELoss()
# 绘图
def gen_img_plot(model, test_input):#test_input是同样的随机数
pred = np.squeeze(model(test_input).detach().cpu().numpy())#得到结果;截断梯度;放在cpu上,返回值为tensor;转换为numpy数据;删除指定维度,即把shape中为1的维度去掉;28*28
fig = plt.figure(figsize=(4, 4))#16张图片
for i in range(16):
plt.subplot(4, 4, i + 1)
plt.imshow((pred[i] + 1) / 2)#将(-1,1)恢复成(0,1)之间进行绘图
plt.axis('off')
plt.show()
test_input = torch.randn(16, 100, device=device)#16个长度为100的随机输入,(16*100),产生16张图片
# Gan的训练
D_loss = []
G_loss = []
# 编写训练循环
for epoch in range(20):
D_epoch_loss = 0 # 计算每一个epoch的平均loss
G_epoch_loss = 0
count = len(dataloader) # 返回批次数 len(dataset)返回样本数
for step, (img, _) in enumerate(dataloader):#对dataloader进行迭代
img = img.to(device)
size = img.size(0) # 返回img第一维大小
random_noise = torch.randn(size, 100, device=device)
d_optim.zero_grad() # 梯度归0
real_output = dis(img) # 对判别器输入真实图片,得到对真实图片的判断结果
d_real_loss = loss_fn(real_output, torch.ones_like(real_output)) # 得到判别器在真实数据上的损失
d_real_loss.backward()
gen_img = gen(random_noise)
fake_output = dis(gen_img.detach()) # 判别器输入生成图片,fake_output对生成图片的判断结果
d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output)) # 得到判别器在生成图片上的损失
d_fake_loss.backward()
d_loss = d_real_loss + d_fake_loss
d_optim.step()
# 生成器的损失和优化
g_optim.zero_grad()
fake_output = dis(gen_img)
# g_loss = loss_fn(fake_output, torch.ones_like(fake_output), device=device) # 得到生成器的损失
g_loss = loss_fn(fake_output, torch.ones_like(fake_output))
g_loss.backward()
g_optim.step()
with torch.no_grad():
D_epoch_loss += d_loss
G_epoch_loss += g_loss
with torch.no_grad():
D_epoch_loss /= count
G_epoch_loss /= count
D_loss.append(D_epoch_loss)
G_loss.append(G_epoch_loss)
print('epoch:', epoch)
gen_img_plot(gen, test_input)
2.2 结果
3、代码思路