WGAN模型——pytorch实现

论文传送门:https://arxiv.org/pdf/1701.07875.pdf

参考文章:令人拍案叫绝的Wasserstein GAN - 知乎​​​​​​

WGAN的目的:解决GAN的梯度不稳定、多样性不足的问题。

WGAN的思想:使用Wasserstein距离代替JS散度,来描述生成分布与真实分布的距离。

WGAN的实现:与GAN相比,有四处不同:

①判别器D去掉最后一层sigmoid激活函数,使得判别器D的作用变为计算近似的Wasserstein距离(代码13-31行);

②生成器和判别器的loss不再取log,近似Wasserstein距离(代码131,140行);

③判别器参数更新时,限制其值在[-c,c]区间内,使D(x)满足Lipschitz连续条件(代码134-135行);

④不采用基于动量的优化器(Adam等),使用RMSprop或SGD(代码113-114行)。

import os
import torch
from torch.utils.data import DataLoader

import torch.nn as nn

from torchvision import datasets, transforms
from torchvision.utils import save_image

from tqdm import tqdm


class Discriminator(nn.Module):  # 定义判别器(WS-divergence)
    def __init__(self, img_shape=(1, 28, 28)):  # 初始化方法
        super(Discriminator, self).__init__()  # 继承初始化方法
        self.img_shape = img_shape  # 图片形状

        self.linear1 = nn.Linear(self.img_shape[0] * self.img_shape[1] * self.img_shape[2], 512)  # linear映射
        self.linear2 = nn.Linear(512, 256)  # linear映射
        self.linear3 = nn.Linear(256, 1)  # linear映射
        self.leakyrelu = nn.LeakyReLU(0.2, inplace=True)  # leakyrelu激活函数

    def forward(self, x):  # 前传函数
        x = torch.flatten(x, 1)  # 输入图片从三维压缩至一维特征向量,(n,1,28,28)-->(n,784)
        x = self.linear1(x)  # linear映射,(n,784)-->(n,512)
        x = self.leakyrelu(x)  # leakyrelu激活函数
        x = self.linear2(x)  # linear映射,(n,512)-->(n,256)
        x = self.leakyrelu(x)  # leakyrelu激活函数
        x = self.linear3(x)  # linear映射,(n,256)-->(n,1)

        return x  # 返回近似拟合的Wasserstein距离


class Generator(nn.Module):  # 定义生成器
    def __init__(self, img_shape=(1, 28, 28), latent_dim=100):  # 初始化方法
        super(Generator, self).__init__()
        self.img_shape = img_shape  # 图片形状
        self.latent_dim = latent_dim  # 噪声z的长度

        self.linear1 = nn.Linear(self.latent_dim, 128)  # linear映射
        self.linear2 = nn.Linear(128, 256)  # linear映射
        self.bn2 = nn.BatchNorm1d(256, 0.8)  # bn操作
        self.linear3 = nn.Linear(256, 512)  # linear映射
        self.bn3 = nn.BatchNorm1d(512, 0.8)  # bn操作
        self.linear4 = nn.Linear(512, 1024)  # linear映射
        self.bn4 = nn.BatchNorm1d(1024, 0.8)  # bn操作
        self.linear5 = nn.Linear(1024, self.img_shape[0] * self.img_shape[1] * self.img_shape[2])  # linear映射
        self.leakyrelu = nn.LeakyReLU(0.2, inplace=True)  # leakyrelu激活函数
        self.tanh = nn.Tanh()  # tanh激活函数,将输出压缩至(-1.1)

    def forward(self, z):  # 前传函数
        z = self.linear1(z)  # linear映射,(n,100)-->(n,128)
        z = self.leakyrelu(z)  # leakyrelu激活函数
        z = self.linear2(z)  # linear映射,(n,128)-->(n,256)
        z = self.bn2(z)  # 一维bn操作
        z = self.leakyrelu(z)  # leakyrelu激活函数
        z = self.linear3(z)  # linear映射,(n,256)-->(n,512)
        z = self.bn3(z)  # 一维bn操作
        z = self.leakyrelu(z)  # leakyrelu激活函数
        z = self.linear4(z)  # linear映射,(n,512)-->(n,1024)
        z = self.bn4(z)  # 一维bn操作
        z = self.leakyrelu(z)  # leakyrelu激活函数
        z = self.linear5(z)  # linear映射,(n,1024)-->(n,784)
        z = self.tanh(z)  # tanh激活函数
        z = z.view(-1, self.img_shape[0], self.img_shape[1], self.img_shape[2])  # 从一维特征向量扩展至三维图片,(n,784)-->(n,1,28,28)

        return z  # 返回生成的图片


if __name__ == "__main__":
    # 训练参数
    total_epochs = 100  # 训练轮次
    batch_size = 64  # 批大小
    lr = 5e-5  # 学习率
    num_workers = 8  # 数据加载线程数
    latent_dim = 100  # 噪声z长度
    image_size = 28  # 图片尺寸
    channel = 1  # 图片通道
    clip_value = 0.01  # 判别器参数限定范围
    dataset_dir = "dataset/mnist"  # 训练数据集路径
    gen_images_dir = "gen_images"  # 生成样例图片路径
    cuda = True if torch.cuda.is_available() else False  # 设置是否使用cuda
    os.makedirs(dataset_dir, exist_ok=True)  # 创建训练数据集路径
    os.makedirs(gen_images_dir, exist_ok=True)  # 创建样例图片路径
    image_shape = (channel, image_size, image_size)  # 图片形状

    # 模型
    D = Discriminator(image_shape)  # 实例化判别器
    G = Generator(image_shape, latent_dim)  # 实例化生成器
    if cuda:  # 如果使用cuda
        D = D.cuda()  # 模型加载到GPU
        G = G.cuda()  # 模型加载到GPU

    # 数据集
    transform = transforms.Compose(  # 数据预处理方法
        [transforms.Resize(image_size),  # resize
         transforms.ToTensor(),  # 转为tensor
         transforms.Normalize([0.5], [0.5])]  # 标准化
    )
    dataloader = DataLoader(  # dataloader
        dataset=datasets.MNIST(  # 数据集选取MNIST手写体数据集
            root=dataset_dir,  # 数据集存放路径
            train=True,  # 使用训练集
            download=True,  # 自动下载
            transform=transform  # 应用数据预处理方法
        ),
        batch_size=batch_size,  # 设置batch size
        num_workers=num_workers,  # 设置读取数据线程数
        shuffle=True  # 设置打乱数据
    )

    # 优化器
    optimizer_D = torch.optim.RMSprop(D.parameters(), lr=lr)  # 定义判别网络RMSprop优化器,传入学习率
    optimizer_G = torch.optim.RMSprop(G.parameters(), lr=lr)  # 定义生成网络RMSprop优化器,传入学习率

    # 训练循环
    for epoch in range(total_epochs):  # 循环epoch
        pbar = tqdm(total=len(dataloader), desc=f'Epoch {epoch + 1}/{total_epochs}', postfix=dict,
                    mininterval=0.3)  # 设置当前epoch显示进度
        for i, (real_imgs, _) in enumerate(dataloader):  # 循环iter
            if cuda:  # 如果使用cuda
                real_imgs = real_imgs.cuda()  # 数据加载到GPU
            bs = real_imgs.shape[0]  # batchsize

            # 开始训练判别网络D
            optimizer_D.zero_grad()  # 判别网络D清零梯度
            z = torch.randn((bs, latent_dim))  # 生成输入噪声z,服从标准正态分布,长度为latent_dim
            if cuda:  # 如果使用cuda
                z = z.cuda()  # 噪声z加载到GPU
            fake_imgs = G(z).detach()  # 噪声z输入生成网络G,得到生成图片,并阻止其反向梯度传播
            loss_D = -torch.mean(D(real_imgs)) + torch.mean(D(fake_imgs))  # 判别网络D的损失函数
            loss_D.backward()  # 反向传播,计算当前梯度
            optimizer_D.step()  # 根据梯度,更新网络参数
            for p in D.parameters():  # 遍历判别网络D的模型参数
                p.data.clamp_(-clip_value, clip_value)  # 将参数限制在[-clip_value,clip_value]区间

            # 开始训练生成网络G
            optimizer_G.zero_grad()  # 生成网络G清零梯度
            gen_imgs = G(z)  # 噪声z输入生成网络G,得到生成图片
            loss_G = -torch.mean(D(gen_imgs))  # 生成网络G的损失函数
            loss_G.backward()  # 反向传播,计算当前梯度
            optimizer_G.step()  # 根据梯度,更新网络参数

            pbar.set_postfix(**{'D_loss': loss_D.item(), 'G_loss': loss_G.item()})  # 显示判别网络D和生成网络G的损失
            pbar.update(1)  # 步进长度

        pbar.close()  # 关闭当前epoch显示进度
        save_image(gen_imgs.data[:25], "%s/ep%d.png" % (gen_images_dir, (epoch + 1)), nrow=5,
                   normalize=True)  # 保存生成图片样例(5x5)

  • 6
    点赞
  • 18
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

CV_Peach

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值