- 对GAN摸索了一段时间,有一点心得:使用普通的网络(例如全连接网络)作为生成器和判别器时,需要使用BatchNormalization进行批量归一化,否则很难得到好的结果。另外,生成器的最后一层推荐使用tanh()函数,也可以使用sigmoid;二者的主要区别在于输出范围(tanh为[-1, 1],sigmoid为[0, 1]),因此真实图像的取值范围应与所选的激活函数保持一致,更多细节可以自己查阅。
- 这是GAN的pytorch版本的实现。
-
导入相关库
import torch import torchvision import torch.nn as nn import torch.nn.functional as F import matplotlib.pylab as plt from matplotlib import animation from IPython.display import HTML
-
设置用到的一些常量
# Training batch size; also the number of samples in the fixed noise batch below.
BATCH_SIZE = 100
# Channel count of the input images (MNIST is single-channel).
IMG_CHANNELS = 1
# Length of the latent noise vector fed to the generator.
NUM_Z = 100
# Base feature-map widths; the generator multiplies its width by 8/4/2 per stage
# (the discriminator presumably does the same — confirm in its definition).
NUM_GENERATOR_FEATURES = 64
NUM_DISCRIMINATOR_FEATURES = 64

# Run on the GPU whenever CUDA is available, otherwise fall back to the CPU.
_use_cuda = torch.cuda.is_available()
DEVICE = torch.device("cuda" if _use_cuda else "cpu")

# Fixed latent batch shaped (N, Z, 1, 1) so it can be fed straight into the
# ConvTranspose2d-based generator. It is drawn once, so it is presumably reused
# to visualise progress on the same noise — confirm against the training loop.
INPUTS_G = torch.randn((BATCH_SIZE, NUM_Z, 1, 1), device=DEVICE)
-
加载数据集(MNIST)
# Only tensor conversion is applied; ToTensor maps uint8 pixels into [0.0, 1.0].
# NOTE(review): if the generator's final layer is tanh (range [-1, 1]), consider
# appending torchvision.transforms.Normalize((0.5,), (0.5,)) so real and fake
# images share the same value range — confirm against the Generator definition.
_preprocess_steps = [torchvision.transforms.ToTensor()]
transform = torchvision.transforms.Compose(_preprocess_steps)

# MNIST training split, downloaded into ./data on first use.
# To train on CIFAR-10 instead, swap in torchvision.datasets.cifar.CIFAR10 with
# the same arguments (IMG_CHANNELS would then presumably need to be 3).
ds = torchvision.datasets.mnist.MNIST(
    root="data",
    train=True,
    transform=transform,
    download=True,
)

# Shuffled mini-batches of BATCH_SIZE images for training.
ds_loader = torch.utils.data.DataLoader(
    ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
-
查看数据
# Draw a single shuffled mini-batch to sanity-check what the loader yields.
_loader_iter = iter(ds_loader)
img_batch, lab_batch = next(_loader_iter)
del _loader_iter
# Bare expression: shown as output when this is the last line of a notebook cell.
# For MNIST this is presumably (BATCH_SIZE, 1, 28, 28) and (BATCH_SIZE,) — verify.
(img_batch.shape, lab_batch.shape)
-
绘制数据集图像
# Tile the sampled batch into a grid with 10 images per row.
preview_grid = torchvision.utils.make_grid(
    img_batch,
    nrow=10,
    padding=2,
    pad_value=1,
    normalize=True,  # min-max rescale pixel values for display
)

plt.figure(figsize=(8, 8), dpi=80)
# make_grid yields channels-first (C, H, W); imshow wants channels-last (H, W, C).
plt.imshow(preview_grid.permute(1, 2, 0))
plt.tight_layout()
plt.axis("off")
-
定义生成器和判别器
class Generator(nn.Module): def __init__(self): super(Generator, self).__init__() # (o - 1) * s - 2 * p + w self.main = nn.Sequential( # 100 x 1 x 1 --> 512 x 4 x 4 nn.ConvTranspose2d(NUM_Z, NUM_GENERATOR_FEATURES * 8, 4, 1, 0, bias=False), nn.BatchNorm2d(NUM_GENERATOR_FEATURES * 8), nn.ReLU(True), # 512 x 4 x 4 --> 512 x 8 x 8 nn.ConvTranspose2d(NUM_GENERATOR_FEATURES * 8, NUM_GENERATOR_FEATURES * 4, 4, 2, 1, bias=False), nn.BatchNorm2d(NUM_GENERATOR_FEATURES * 4), nn.ReLU(True), # 512 x 8 x 8 --> 512 x 16 x 16 nn.ConvTranspose2d(NUM_GENERATOR_FEATURES * 4, NUM_GENERATOR_FEATURES * 2, 4, 2, 1