PyTorch对DCGANs网络的实现

​ 生成对抗网络(GANs)是现在深度学习的热点之一,下面我们通过PyTorch实现深度卷积生成对抗网络(DCGANs),数据集使用最为经典的MNIST手写数据集。

1.导入需要使用到的包

import os
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import DataLoader,sampler
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torchvision import transforms
from torchvision.utils import save_image
import time
import numpy as np

2.定义判别网络(判别器)和生成网络(生成器)

定义深层卷积判别网络,它就是一般的卷积网络

class DC_Discriminator(nn.Module):
	def __init__(self):
		super(DC_Discriminator,self).__init__()

		self.conv = nn.Sequential(
			nn.Conv2d(1,32,5,1), # 输入深度1,输出深度32,滤波器5*5,步长1
			nn.LeakyReLU(0.01), # 斜率为0.01
			nn.MaxPool2d(2,2), # 滤波器2*2,步长2
			nn.Conv2d(32,64,5,1), # 输入深度32,输出深度64,滤波器5*5,步长1
			nn.LeakyReLU(0.01),
			nn.MaxPool2d(2,2)
			)

		self.fc = nn.Sequential(
			nn.Linear(1024,1024),
			nn.LeakyReLU(0.01),
			nn.Linear(1024,1)
			)

	def forward(self,x):
		x = self.conv(x)
		x = x.view(x.size(0),-1)
		x = self.fc(x)
		return x

定义深度卷积生成网络,其需要将一个低维的噪声向量变成一个图片数据

class DC_Generator(nn.Module):
	def __init__(self,noise_size=NOISE_DIM):
		super(DC_Generator,self).__init__()

		self.fc = nn.Sequential(
			nn.Linear(noise_size,1024), # noise_size表示输入的维度
			nn.ReLU(True),
			nn.BatchNorm1d(1024),
			nn.Linear(1024,7*7*128),
			nn.ReLU(True),
			nn.BatchNorm1d(7*7*128)
			)

		self.conv = nn.Sequential(
			nn.ConvTranspose2d(128,64,4,2,padding=1), # 转置卷积
			nn.ReLU(True),
			nn.BatchNorm2d(64),
			nn.ConvTranspose2d(64,1,4,2,padding=1),
			nn.Tanh()
			)

	def forward(self,x):
		x = self.fc(x)
		x = x.view(x.size(0),128,7,7) # reshape通道是128,大小是7*7
		x = self.conv(x)
		return x

3.定义图片预处理函数,及取样函数

训练图片时的预处理函数:

def preprocess_img(x): # 训练测试图片的预处理
	x = transforms.ToTensor()(x)
	return (x-0.5)/0.5 # 这一步是标准化处理

存储图片时的预处理函数:

def deprocess_img(x):
	return (x+1.0)/2.0

def to_img(x):
	x = 0.5*(x+1.)
	x = x.clamp(0,1) # 将输入张量每个元素夹紧到区间中,并返回结果到一个新张量
	x = x.view(x.size(0),1,28,28)
	return x

定义一个取样的函数,即取样器

class ChunkSampler(sampler.Sampler):
	def __init__(self,num_samples,start=0):
		# num_samples:表示我们需要的样本数量
		# start:表示我们从哪(下标)开始取样本
		self.num_samples = num_samples
		self.start = start

	def __iter__(self): # 迭代器
		return iter(range(self.start,self.start+self.num_samples))

	def __len__(self): # 返回取样的样本总数
		return self.num_samples

4.定义超参数,及加载数据集

# 定义超参数
NUM_TRAIN = 50000 # 训练样本数
learning_rate = 3e-4 # 学习率
NOISE_DIM = 96 # 噪音维度,就是生成器中的noise_size
batch_size = 128 # 一批次中的样本数

# 加载数据集,参数root表示数据集的路径,trainsform表示数据的预处理
train_set = MNIST(root='./data/',train=True,transform=preprocess_img)
train_loader = DataLoader(train_set,batch_size=batch_size,sampler=ChunkSampler(NUM_TRAIN,0))

5.定义损失函数和优化方法

# 定义判别网络的损失函数
def ls_discriminator_loss(scores_real,scores_fake):
    # scores_real:数据集里图片输入到判别网络得到的输出
    # scores_fake:生成器生成的图片输入到判别网络得到的输出
	loss = 0.5*((scores_real - 1)**2).mean() + 0.5*(scores_fake**2).mean()
	return loss

# 定义生成网络的损失函数
def ls_generator_loss(scores_fake):
	loss = 0.5*((scores_fake - 1)**2).mean()
	return loss

# 使用adam来进行优化,学习率是3e-4,beta1是0.5,beta2是0.999
def get_optimizer(net):
	optimizer = torch.optim.Adam(net.parameters(),lr=learning_rate,betas=(0.5,0.999))
    # betas (Tuple[float, float], 可选) – 用于计算梯度以及梯度平方的运行平均值的系数(默认:0.9,0.999)
	return optimizer

​ 注意,这里并没有像平常一样使用PyTorch已经集成的损失函数,而是自己定义的。而这个自己定义的损失函数其实是最基本的生成对抗网络的一个变式,这是在损失函数上的一个变式,叫做Least Squares GAN。

​ Least Squares GAN比最原始的GANs的loss更加稳定,因此这里我们使用了它。

6.定义训练函数

def train_a_gan(D_net,G_net,D_optimizer,G_optimzer,discriminator_loss,generator_loss,show_every=250,noise_size=NOISE_DIM,num_epoches=100):
    # 参数分别是判别网络、生成网络、判别网络优化方法、生成网络优化方法、判别网络损失函数、生成网络损失函数
    # show_every用于计数为了多少次保存一次图片、noise_size输入生成网络的噪音维度、训练轮数
	iter_count = 0 # 定义的计数器,为了和show_every一起用,目的是显示图片
	now = time.clock()
	for epoch in range(num_epoches):
		for img,_ in train_loader:
			bs = img.size(0) # batch_size的简写,为什么不直接使用batch_size的来代替,是因为有时到最后一个batch的size不一定够128

			# 判别网络
			# 判别网络的前向传播
			real_data = Variable(img).cuda() # 用于深度卷积生成对抗网络的真实数据
			logits_real = D_net(real_data) # 判别网络得分,判别真图片

			sample_noise = (torch.rand(bs,noise_size)-0.5) / 0.5 # -1~1的均匀分布
			g_fake_seed = Variable(sample_noise).cuda() # 生成噪声维度
			fake_data = G_net(g_fake_seed) # 生成假的数据,即生成假图片
			logits_fake = D_net(fake_data) # 判别网络得分,判别假图片
            
			d_total_error = discriminator_loss(logits_real,logits_fake) # 计算判别器的loss

			# 判别网络的反向传播
			D_optimizer.zero_grad()
			d_total_error.backward()
			D_optimizer.step() # 优化判别器,更新参数


			# 生成网络
			# 生成网络的前向传播
			g_fake_seed = Variable(sample_noise).cuda() # 这里为什么不用上面的数据,而重新生成一下???
			fake_data = G_net(g_fake_seed) # 生成假数据
			gen_logits_fake = D_net(fake_data) # 判别网络得分(就是网络的输出结果)
            
			g_error = generator_loss(gen_logits_fake) # 计算生成网络的loss

			# 生成网络的反向传播
			G_optimzer.zero_grad()
			g_error.backward()
			G_optimzer.step() # 优化生成器,更新参数

			if(iter_count % show_every == 0): # 每训练250批次就让生成器生成一次图片,看看生成器的效果
				print("Iter:{},D_Loss:{:.4f},G_Loss:{:.4f},Time:{:.4f}".format(iter_count,d_total_error.item(),g_error.item(),time.clock()-now))
				imgs_numpy = deprocess_img(fake_data.data.cpu().numpy())
				imgs_variable = Variable(torch.from_numpy(imgs_numpy))
				pic = to_img(imgs_variable.cpu().data) # 将数据转化为图片
                
				if not os.path.exists('./dc_gan_imgs'):
					os.mkdir('./dc_gan_imgs')
				save_image(pic,'./dc_gan_imgs/image_{}.png'.format(iter_count))
			iter_count += 1

7.定义主程序及结果显示

# 定义深度卷积生成对抗网络的训练过程
D_DC = DC_Discriminator().cuda()
G_DC = DC_Generator().cuda()

D_DC_optim = get_optimizer(D_DC)
G_DC_optim = get_optimizer(G_DC)

train_a_gan(D_DC,G_DC,D_DC_optim,G_DC_optim,ls_discriminator_loss,ls_generator_loss,num_epoches=10)

Iter-500Iter-1500Iter-2500Iter-3500

图片依次是Iter-500、Iter-1500、Iter-2500和Iter-3500的生成器生成的图片(注:训练10轮,图片保存到Iter-3750)。

可以发现图片生成的越来越清晰,已经跟数据集中的图片没多少区别了,已经基本达到以假乱真的境界了。

8.总结

  • 上面我们只用深度卷积生成对抗网络(DCGANs)训练10轮,就已经达到这么好的效果了(输入数据虽然只是普通的灰度图)。这差不多能超过普通的GANs训练30轮的效果。

  • 生成对抗网络的目的是,判别网络尽力将输入的真实图片判断成1,把生成网络生成的图片判断成0。而生成网络则是尽力生成假图片使得判别网络将其判断成1,以达到以假乱真的目的。所以二者间存在这种对抗的关系。

  • 虽然DCGANs的效果非常好,但是其实它并没有真正地学习到它要表示的物体,通过对抗它只是生成了一张尽可能真的图片。这意味着我们没办法决定哪种噪声能够生成想要的图片。不过,在GANs的变式中变相解决了这个问题,它是Conditional GAN,在训练的时候,将一句话(可以是描述)和噪声一起输入到生成网络,这样训练之后,我们就可以通过输入一句话,得到差不多我们想要的图片了。

9.参考资料

来自于:https://github.com/L1aoXingyu/code-of-learn-deep-learning-with-pytorch/blob/master/chapter6_GAN/gan.ipynb

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值