记录一下个人学习和使用Pytorch中的一些问题。强烈推荐 《深度学习框架PyTorch:入门与实战》.写的非常好而且作者也十分用心,大家都可以看一看,本文为学习第七章GAN生成动漫头像的学习笔记。
主要分析实现代码里面main,model,visualize这3个代码文件完成整个项目模型结构定义,训练及生成,还有输出展示的整个过程。
model文件
整个模型结构是经典的生成器-判别器架构,model文件也只有这两个类,分别用于生成和判别图片。
生成器
生成器是从无到有,从一个噪声扩充数据到指定的大小,所以其中的网络层是反卷积层。除了第一层的输入通道数为参数设置之外,其他各层与判别器对称,可以从1x1x opt.nz的噪音,生成一个3x96x96的图片
class NetG(nn.Module):
"""
生成器定义
"""
def __init__(self, opt):
super(NetG, self).__init__()
ngf = opt.ngf # 生成器feature map数
self.main = nn.Sequential(
# 输入是一个nz维度的噪声,我们可以认为它是一个1*1*nz的feature map
nn.ConvTranspose2d(opt.nz, ngf * 8, 4, 1, 0, bias=False),
nn.BatchNorm2d(ngf * 8),
nn.ReLU(True),
# 上一步的输出形状:(ngf*8) x 4 x 4
nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
nn.BatchNorm2d(ngf * 4),
nn.ReLU(True),
# 上一步的输出形状: (ngf*4) x 8 x 8
nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(ngf * 2),
nn.ReLU(True),
# 上一步的输出形状: (ngf*2) x 16 x 16
nn.ConvTranspose2d(ngf *