实验:打造自己的MNIST-GAN

实验:打造自己的MNIST-GAN

1 实验内容

借助Keras,Tensorfolow 或Pytorch 等框架,设计和搭建自己的MNIST-GAN 图像生成器,生成新的手写数字图片

要求:

  • 实现MNIST 数据加载和可视化

  • 搜索和阅读相关资料和论文,在Keras,Tensorfolow或Pytorch 任意框架下实现MNIST-GAN网络的构建和训练

  • 使用训练好的MNIST-GAN 网络产生新的0-9 手写数字图片,并在训练数据集中找出和新生成图片‘‘最接近’’(可自行定义接近程度,或者尝试多种方式后人工比较)的训练图片

  • 使用linearly interpolating 完成下图中效果(图片来源:Figure 3 in Generative Adversarial Nets, Ian J. Goodfellow, et al.)

    image-20211018184337448

  • (选做)GAN 的训练被认为相对困难(可参见‘‘参考资料’’),总结在实验中遇到的问题,搜索资料,尝试不同的解决方案并总结

2 实验原理

Basic Idea of GAN

Algorithm

3 具体实现

使用原生GAN实现

加载MNIST数据
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets

# Configure data loader
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "../../data/mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),

    batch_size=opt.batch_size,
    shuffle=True,
)

这里随机取几张图片观察。

def show_img(img, trans=True):
    if trans:
        img = np.transpose(img.detach().cpu().numpy(), (1, 2, 0))  # 把channel维度放到最后
        plt.imshow(img[:, :, 0], cmap="gray")
    else:
        plt.imshow(img, cmap="gray")
    plt.show()
    
mnist = datasets.MNIST("../../data/mnist")

image-20211018214050251

image-20211018214113981

构建生成器

仿照下图的原生GAN的结构来搭建。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-un1cmauM-1634715042902)(https://i.loli.net/2021/10/19/HYN87qkdefZhmyl.png)]

我们的生成器包含5个全连接层,使用LeakyReLU和Tanh激活函数,使用了BatchNorm。

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(opt.latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), *img_shape)
        return img

结构如下:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-b7qAO5Zw-1634715042904)(https://i.loli.net/2021/10/18/ZsElTonhgqWweQv.png)]

构建判别器

仿照原生GAN,使用全连接网络,把Maxout激活函数换为ReLU与Sigmoid。

image-20211019143401129

包含3个全连接层,使用LeakyReLU和Sigmoid激活函数。

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)

        return validity
    
discriminator = Discriminator()
print(discriminator)

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-TpTQFtR8-1634715042908)(https://i.loli.net/2021/10/18/4ALVpd8lhnOPWzi.png)]

损失函数与优化

判别器使用 Binary Cross Entropy Loss。

优化都使用Adam,lr = 0.0002。

optimizer_G = torch**.**optim**.**Adam(generator**.**parameters(), lr=opt**.**lr, betas=(opt**.**b1, opt**.**b2))

optimizer_D = torch**.**optim**.**Adam(discriminator**.**parameters(), lr=opt**.**lr, betas=(opt**.**b1, opt**.**b2))
随机采样

从100维的正态分布中采样作为z。

一个batch有64组输入。

z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))
交替训练
valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False)
fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False)

real_imgs = Variable(imgs.type(Tensor))

#更新生成器

optimizer_G.zero_grad()

#采样z
z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))
gen_imgs = generator(z)

#生成器权值更新
g_loss = adversarial_loss(discriminator(gen_imgs), valid)
g_loss.backward()
optimizer_G.step()

#更新判别器
optimizer_D.zero_grad()
real_loss = adversarial_loss(discriminator(real_imgs), valid)
fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
d_loss = (real_loss + fake_loss) / 2
d_loss.backward()
optimizer_D.step()
生成结果

每400次迭代观察一次当前生成图像。

最开始,生成全是杂讯。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-NHUs123v-1634715042910)(https://i.loli.net/2021/10/19/gdMUDkeC42OJ6sR.png)]

开始设置的epoch数很少,结果很差,下图是第6000次迭代的结果:

image-20211019151416280

20000次:

image-20211019151533525

100000次:

image-20211019151626088

200个epoch以后,也就是十八万多次迭代以后的最终结果:

image-20211019184330709

感觉没有很好的结果,还需要继续train下去,但没有继续尝试了。

使用CNN+GAN实现

更改生成网络结构
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        self.l1 = nn.Sequential(nn.Linear(opt.latent_dim, 128 * self.init_size ** 2))

        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )
        
    def forward(self, z):
        out = self.l1(z)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img

网络结构为:

更改判别网络结构
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, bn=True):
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]
            if bn:
                block.append(nn.BatchNorm2d(out_filters, 0.8))
            return block

        self.model = nn.Sequential(
            *discriminator_block(opt.channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )

        ds_size = opt.img_size // 2 ** 4
        self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid())

网络结构为:

训练过程

生成结果

比用原生GAN的结果好很多。

比如:

第6000次迭代:

6000

第20000次迭代:

]

第100个epoch:

94000

第120个epoch:

image-20211019194510594

观察linearly interpolating结果

随机选两个点,在两点中取10个点观察变化过程:

image-20211020110239495

Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
g = torch.load('model/generator.pkl')
z = Variable(Tensor(np.random.normal(0, 1, (2, 100))))
a = torch.FloatTensor(100, 20)
for i in range(100):
    a[i] = torch.linspace(z[0][i], z[1][i], 10)

b = Variable(a.t())
b = b.to('cuda')
gen_imgs = g(b)
save_image(gen_imgs.data[:], "images_trans.png", normalize=True)

再次尝试观察更细致的变化:

]

使用CGAN实现

为了可以控制输出我们可以使用CGAN

在原生GAN结构基础上,更改网络结构如下:

更改生成网络结构

更改判别网络结构
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.label_embedding = nn.Embedding(opt.n_classes, opt.n_classes)

        self.model = nn.Sequential(
            nn.Linear(opt.n_classes + int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 512),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 512),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 1),
        )

    def forward(self, img, labels):
        d_in = torch.cat((img.view(img.size(0), -1), self.label_embedding(labels)), -1)
        validity = self.model(d_in)
        return validity

结构如下:

image-20211020112211492

交替训练

把标签引入训练。

    batch_size = imgs.shape[0]
    valid = Variable(FloatTensor(batch_size, 1).fill_(1.0), requires_grad=False) # 为1时判定为真
    fake = Variable(FloatTensor(batch_size, 1).fill_(0.0), requires_grad=False) # 为0时判定为假
    
    optimizer_G.zero_grad()
    gen_labels = Variable(LongTensor(np.random.randint(0, opt.n_classes, batch_size)))
    
	#更新生成器
    gen_imgs = generator(z, gen_labels)
    print("gen_imgs =")
    for img in gen_imgs[:3]:
        show_img(img)

    validity = discriminator(gen_imgs, gen_labels)
    g_loss = adversarial_loss(validity, valid)
    print("g_loss =", g_loss, '\n')

    g_loss.backward()
    optimizer_G.step()

   #更新判别器

    optimizer_D.zero_grad()

    validity_real = discriminator(real_imgs, labels)
    d_real_loss = adversarial_loss(validity_real, valid)
    validity_fake = discriminator(gen_imgs.detach(), gen_labels)
    
    d_fake_loss = adversarial_loss(validity_fake, fake)
    d_loss = (d_real_loss + d_fake_loss) / 2
    print("real_loss =", d_real_loss, '\n')
    print("fake_loss =", d_fake_loss, '\n')
    print("d_loss =", d_loss, '\n')    
    
    d_loss.backward()
    optimizer_D.step()
生成结果:

100个epoch后的结果

  • 4
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Yuetianw

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值