# 保姆级讲解生成对抗网络GAN，及原始GAN的torch复现

152 篇文章 23 订阅
45 篇文章 4 订阅
19 篇文章 2 订阅

# coding:utf-8
# @Email: wangguisen@infinities.com.cn
# @Time: 2022/11/11 10:44 下午
# @File: GAN_demo.py
'''

'''
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torchvision import transforms

'''

其和GAN的训练技巧有关，对于生成器最后使用tanh激活，tanh的取值范围就是（-1,1），
为了方便生成的图片和输入噪声取值范围相同，所以将输入归一化到（-1,1）
'''
# 对数据归一化（-1，1）
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(0.5, 0.5)
])

# print(imgs.shape)

'''   定义生成器   '''
'''

'''
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()

self.linears = nn.Sequential(
nn.Linear(100, 256),
nn.ReLU(),
nn.Linear(256, 512),
nn.ReLU(),
nn.Linear(512, 28*28),
nn.Tanh()
)

def forward(self, x):
# x 为长度为100的noise
out = self.linears(x)
out = out.view(-1, 28, 28, 1)
return out

'''   定义判别器   '''
'''

'''
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()

self.linears = nn.Sequential(
nn.Linear(28*28, 256),
nn.LeakyReLU(),
nn.Linear(256, 512),
nn.LeakyReLU(),
nn.Linear(512, 1),
nn.Sigmoid()
)

def forward(self, x):
x = x.view(-1, 28*28)
return self.linears(x)

'''   初始化模型、优化器、损失   '''
device = 'cuda' if torch.cuda.is_available() else 'cpu'

gen = Generator().to(device)
dis = Discriminator().to(device)

loss_fn = nn.BCELoss()

'''   绘图函数  '''
def gen_img_plot(net, test_input):
prediction = np.squeeze(net(test_input).detach().cpu().numpy())
fig = plt.figure(figsize=(4, 4))
for i in range(16):
plt.subplot(4, 4, i+1)
plt.imshow((prediction[i] + 1) / 2)
plt.axis('off')
plt.show()

test_input = torch.randn(16, 100, device=device)

'''   训练   '''
D_loss = []
G_loss = []
for epoch in range(20):
d_epoch_loss = 0
g_epoch_loss = 0
for step, (img, label) in enumerate(dataloader):
img = img.to(device)
size = img.size(0)
random_noise = torch.randn(size, 100, device=device)

'''判别器优化'''
# 判别器输入真实的图片，得到对真实图片的预测结果
real_output = dis(img)
# 判别器在真实图片上的损失
d_real_loss = loss_fn(real_output, torch.ones_like(real_output))    # 希望判别器将真实的数据判别为全1
d_real_loss.backward()

# 判别器输入生成的图片，得到判别器在生成图像上的损失
gen_img = gen(random_noise)
fake_output = dis(gen_img.detach())   # 对于生成图片产生的损失，我们的优化目标是判别器，希望fake_output被判定为0，来优化判别器，所以要截断梯度，detach会得到一个没有tensor的梯度
d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output))   # 希望判别器将生成的数据判别为全0
d_fake_loss.backward()

# 判别器总损失
d_loss = d_real_loss + d_fake_loss
d_optim.step()

'''生成器优化'''
fake_output = dis(gen_img)   # 优化生成器，所以不用截断 detach
# 对于生成器，希望生成的图片判定为1
g_loss = loss_fn(fake_output, torch.ones_like(fake_output))   # 生成器的损失
g_loss.backward()
g_optim.step()

d_epoch_loss += d_loss
g_epoch_loss += g_loss

d_epoch_loss /= count
g_epoch_loss /= count
D_loss.append(d_epoch_loss)
G_loss.append(g_epoch_loss)
print('Epoch: ', epoch)
gen_img_plot(net=gen, test_input=test_input)


• 2
点赞
• 7
收藏 更改收藏夹
• 打赏
• 0
评论
07-19 4万+
01-16 6636
11-24 444
03-16
07-11 73
11-24 166
07-10 559
05-23 335
02-02 4177
05-15 993

### “相关推荐”对你有帮助么？

• 非常没帮助
• 没帮助
• 一般
• 有帮助
• 非常有帮助

WGS.

¥2 ¥4 ¥6 ¥10 ¥20

1.余额是钱包充值的虚拟货币，按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载，可以购买VIP、C币套餐、付费专栏及课程。