有关条件GAN(cgan)的相关原理,可以参考:
其他类型的GAN原理介绍以及应用,可以查看我的GANs专栏
一、数据集介绍,加载数据
依旧使用到的是我们的老朋友-----MNIST手写数字数据集, 本文不再详细做介绍
相关数据集介绍可以参考:深度学习入门--MNIST数据集及创建自己的手写数字数据集
传统GAN生成手写数字参考:入门GAN实战---生成MNIST手写数据集代码实现pytorch
DCGAN生成手写数字参考:Pytorch 使用DCGAN生成MNIST手写数字 入门级教程
# 独热编码
def one_hot(x, class_count=10):
return torch.eye(class_count)[x, :]
transform =transforms.Compose([transforms.ToTensor(),
transforms.Normalize(0.5, 0.5)])
# 这个数据集其实包含两部分,第一部分是数据,第二部分是标签 print(dataset[0])
#这个第二部分就是我们所需要的condition,这个condition是数值类型,1就是1,2就是2。
#作为输入的condition并不是很合适,一种处理方法就是作为一种向量输入,就是独热编码化。
#比如说现在有10个类别,10个类别将被独热编码为长度为10的tensor,使用这个tensor作为我们的condition是比较合适的
dataset = torchvision.datasets.MNIST('data',
train=True,
transform=transform,
target_transform=one_hot,
download=False)
dataloader = data.DataLoader(dataset, batch_size=64, shuffle=True)
这里有个小技巧是作者用到独热编码化
One-Hot编码,又称为一位有效编码,主要是采用位状态寄存器来对个状态进行编码,每个状态都由他独立的寄存器位,并且在任意时候只有一位有效。独热编码 是利用0和1表示一些参数,使用N位状态寄存器来对N个状态进行编码。
例如:参考数字手写体识别中:如