使用 GAN / WGAN / WGAN-GP 生成 MNIST 手写数字图像

首先定义生成器(generator)和判别器(discriminator)模型。

import os
import sys

import numpy as np
import torch
import torch.nn as nn
from torch import optim
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader

class Generator(nn.Module):
    """Map a noise vector to a single-channel 28x28 image with values in [-1, 1].

    A fully connected stem expands the noise to a 128x7x7 feature map. Two
    stride-2 transposed convolutions then upsample it 7 -> 14 -> 28, and a
    final Tanh bounds the output.

    Args:
        noise_size: length of each input noise vector.
    """

    def __init__(self, noise_size):
        super().__init__()
        hidden = 1024
        feat_ch, feat_hw = 128, 7
        flat = feat_hw * feat_hw * feat_ch
        self.fc = nn.Sequential(
            nn.Linear(noise_size, hidden),
            nn.ReLU(True),
            nn.BatchNorm1d(hidden),
            nn.Linear(hidden, flat),
            nn.ReLU(True),
            nn.BatchNorm1d(flat),
        )
        self.conv = nn.Sequential(
            nn.ConvTranspose2d(feat_ch, 64, 4, 2, padding=1),  # 7x7 -> 14x14
            nn.ReLU(True),
            nn.BatchNorm2d(64),
            nn.ConvTranspose2d(64, 1, 4, 2, padding=1),  # 14x14 -> 28x28
            nn.Tanh(),
        )

    def forward(self, x):
        """Turn a (batch, noise_size) tensor into a (batch, 1, 28, 28) image batch."""
        batch = x.size(0)
        # Reshape the flat stem output into 128 channels of 7x7.
        features = self.fc(x).view(batch, 128, 7, 7)
        return self.conv(features)




class Discriminator(nn.Module):
    """Convolutional discriminator / critic producing one raw score per image.

    No sigmoid is applied in ``forward``. The score can serve either as a
    WGAN critic value or as a logit for ``nn.BCEWithLogitsLoss``.

    Args:
        input_size: number of input image channels.
        wgan: stored on the instance as ``self.wgan``; ``forward`` does not read it.
    """

    def __init__(self, input_size=1, wgan=False):
        super().__init__()
        self.wgan = wgan
        conv_layers = [
            nn.Conv2d(input_size, 32, 5, 1),
            nn.LeakyReLU(0.01),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, 5, 1),
            nn.LeakyReLU(0.01),
            nn.MaxPool2d(2, 2),
        ]
        self.conv = nn.Sequential(*conv_layers)
        # For 28x28 inputs: 28 -conv5-> 24 -pool-> 12 -conv5-> 8 -pool-> 4,
        # so the flattened feature size is 64 * 4 * 4 = 1024.
        self.fc = nn.Sequential(
            nn.Linear(64 * 4 * 4, 1024),
            nn.LeakyReLU(0.01),
            nn.Linear(1024, 1),
        )
        # Kept as an attribute; forward() does not use it.
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return a (batch, 1) tensor of unnormalized scores."""
        feats = self.conv(x)
        flat = feats.view(feats.size(0), -1)
        return self.fc(flat)

定义计算梯度惩罚(gradient penalty)的函数,以及生成真/假标签(label)的函数。

def gen_label(batch_size):
    """Build the real/fake target labels for the standard-GAN BCE loss.

    Args:
        batch_size: number of samples in the batch.

    Returns:
        ``(real, fake)``: float tensors of shape ``(batch_size, 1)``. ``real``
        is filled with ones and ``fake`` with zeros.
    """
    shape = (batch_size, 1)
    return torch.ones(shape), torch.zeros(shape)

def cal_gradient_penalty(disc_net, device, real, fake):
    """Compute the WGAN-GP gradient penalty ``E[(||grad_x D(x_hat)||_2 - 1)^2]``.

    ``x_hat`` is a random per-sample interpolation between ``real`` and
    ``fake``. The gradient norm is taken over all non-batch dimensions of each
    sample, so every sample contributes one norm value.

    Args:
        disc_net: callable critic mapping a batch shaped like ``real`` to scores.
        device: device on which the interpolation coefficients are created.
        real: batch of real samples, shape ``(batch, ...)``.
        fake: batch of generated samples with the same shape as ``real``.

    Returns:
        A scalar tensor. It is built with ``create_graph=True``, so it can be
        back-propagated into the critic's parameters.
    """
    batch_size = real.size(0)
    # One mixing coefficient per sample, broadcast over the remaining dims.
    alpha_shape = (batch_size,) + (1,) * (real.dim() - 1)
    alpha = torch.rand(alpha_shape, device=device)
    # Detach the inputs so the penalty differentiates only through the critic,
    # and make the interpolates a leaf that requires grad.
    interpolates = (alpha * real.detach() + (1 - alpha) * fake.detach()).requires_grad_(True)
    disc_interpolates = disc_net(interpolates)
    gradients = torch.autograd.grad(
        outputs=disc_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(disc_interpolates),
        create_graph=True,
        retain_graph=True,
    )[0]
    # Flatten each sample so the L2 norm is per sample, not per pixel/channel.
    grad_norm = gradients.reshape(batch_size, -1).norm(2, dim=1)
    return ((grad_norm - 1) ** 2).mean()

定义数据集和数据预处理,并在 gan、wgan、wgan-gp 三种模型中选择一种进行训练。

# Seed NumPy's and PyTorch's global RNGs so repeated runs draw the same random numbers.
np.random.seed(23)
torch.manual_seed(23)

def preprocess_img(x):
    """Convert an image to a tensor and apply ``(t - 0.5) / 0.5``.

    The transform maps values in [0, 1] to [-1, 1], which is the range of
    the generator's Tanh output.
    """
    tensor = transforms.ToTensor()(x)
    return tensor.sub(0.5).div(0.5)

def main():
    """Train a GAN, WGAN or WGAN-GP on MNIST and plot progress in visdom.

    The variant is selected by the local ``model`` string ('gan', 'wgan' or
    'wgan-gp'). MNIST is read from ``./mnist`` with ``download=False``, so the
    files must already be present. ``visdom.Visdom()`` is created at start-up
    and used to plot losses and sample images.

    NOTE(review): this function uses ``visdom``, ``datasets``, ``tqdm``,
    ``os``, ``sys``, ``optim``, ``ExponentialLR`` and ``DataLoader`` as
    module-level names; confirm they are imported at the top of the file.
    """
    epochs = 50
    # Weight-clipping bounds, only used by plain WGAN.
    clamp_lower = -0.01
    clamp_upper = 0.01
    global_step = 0
    viz = visdom.Visdom()
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    batch_size = 64

    # Data loading. os.cpu_count() can return None, so fall back to 1.
    nw = min([os.cpu_count() or 1, batch_size if batch_size > 1 else 0, 8])  # number of workers
    transform = preprocess_img
    mnist_train = datasets.MNIST('./mnist', train=True, transform=transform, download=False)
    train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=nw)
    # NOTE(review): test_loader and fix_noise are never used by the training
    # loop; they are kept so the data-file checks and RNG stream are unchanged.
    mnist_test = datasets.MNIST('./mnist', train=False, transform=transform, download=False)
    test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=nw)
    train_steps = len(train_loader)

    # Noise
    noise_size = 96
    fix_noise = torch.empty((batch_size, noise_size), dtype=torch.float32).uniform_(-1, +1).to(device)

    # Model: 'gan', 'wgan' or 'wgan-gp'.
    model = 'wgan-gp'
    is_wgan = model in ('wgan', 'wgan-gp')
    gen_net = Generator(noise_size).to(device)
    # Both WGAN variants build the critic with wgan=True; plain GAN uses the default.
    disc_net = Discriminator(input_size=1, wgan=is_wgan).to(device)

    # Define the loss and pick the optimizers; viz is the visdom object used to
    # plot the training losses in real time.
    criterion = nn.BCEWithLogitsLoss()  # only used by the plain-GAN branch
    optimizer_D = optim.Adam(params=disc_net.parameters(), lr=3e-4, betas=(0.5, 0.999))
    optimizer_G = optim.Adam(params=gen_net.parameters(), lr=3e-4, betas=(0.5, 0.999))
    # Learning-rate decay: multiply the lr by 0.9 once per epoch.
    scheduler_D = ExponentialLR(optimizer_D, gamma=0.9)
    scheduler_G = ExponentialLR(optimizer_G, gamma=0.9)
    viz.line([0.], [0], win='real loss', opts=dict(title='real image'))
    viz.line([0.], [0], win='fake loss', opts=dict(title='fake image'))
    viz.line([0.], [0], win='discriminator_loss', opts=dict(title='discriminator loss'))
    viz.line([0.], [0], win='generator_loss', opts=dict(title='generator loss'))

    # Training
    for epoch in range(epochs):
        G_epochloss = 0.0
        D_epochloss = 0.0
        # Nothing in the loop switches to eval mode, so set train mode once per epoch.
        gen_net.train()
        disc_net.train()
        train_bar = tqdm(train_loader, file=sys.stdout)  # progress bar for this epoch
        for index, (real_img, _) in enumerate(train_bar):
            real_img = real_img.to(device)
            batch_size = real_img.size(0)  # the last batch may be smaller
            reallabel, fakelabel = gen_label(batch_size)
            reallabel, fakelabel = reallabel.to(device), fakelabel.to(device)
            # Noise drawn uniformly from [-1, 1).
            noise = ((torch.rand(batch_size, noise_size) - 0.5) / 0.5).to(device)

            # (1) Update the discriminator / critic.
            if model == 'wgan':
                # WGAN clips every critic weight to [clamp_lower, clamp_upper].
                for p in disc_net.parameters():
                    p.data.clamp_(clamp_lower, clamp_upper)
            disc_net.zero_grad()
            # Only the critic is updated in this step, so do not back-propagate
            # into the generator.
            fake = gen_net(noise).detach()
            if is_wgan:
                # The critic outputs raw scores, so the objective is a difference
                # of means; maximizing A is done by minimizing -A.
                D_Loss_real = disc_net(real_img).mean()
                D_Loss_fake = disc_net(fake).mean()
                D_Loss = -(D_Loss_real - D_Loss_fake)
                if model == 'wgan-gp':
                    # NOTE(review): Gulrajani et al. use a penalty weight of 10;
                    # 0.1 is kept from the original script -- confirm intent.
                    gradient_penalty = cal_gradient_penalty(disc_net, device, real_img, fake)
                    D_Loss = D_Loss + gradient_penalty * 0.1
            else:
                # Standard GAN: maximize log(D(x)) + log(1 - D(G(z))) via BCE on logits.
                D_Loss_real = criterion(disc_net(real_img), reallabel)
                D_Loss_fake = criterion(disc_net(fake), fakelabel)
                D_Loss = D_Loss_real + D_Loss_fake
            D_Loss.backward()
            D_epochloss += D_Loss.item()
            optimizer_D.step()

            # (2) Update the generator (the minimize half of the min-max game).
            gen_net.zero_grad()
            fake = gen_net(noise)
            if is_wgan:
                G_Loss = -disc_net(fake).mean()
            else:
                G_Loss = criterion(disc_net(fake), reallabel)
            G_Loss.backward()
            G_epochloss += G_Loss.item()
            optimizer_G.step()

            global_step += 1
            viz.line([D_Loss_real.item()], [global_step], win='real loss', opts=dict(title='real image'), update='append')
            viz.line([D_Loss_fake.item()], [global_step], win='fake loss', opts=dict(title='fake image'), update='append')
            viz.line([D_Loss.item()], [global_step], win='discriminator_loss', opts=dict(title='discriminator loss'), update='append')
            viz.line([G_Loss.item()], [global_step], win='generator_loss', opts=dict(title='generator loss'), update='append')
            if index % 100 == 0:
                viz.images(real_img, nrow=16, win='real_image', opts=dict(title='real image'))
                viz.images(fake.detach(), nrow=16, win='fake image', opts=dict(title='fake image'))
        # Advance the lr schedule once per epoch. step(epoch) is deprecated, and
        # calling step(0) after the first epoch left the lr undecayed for an
        # extra epoch.
        scheduler_D.step()
        scheduler_G.step()
        print("%d / %d  discriminator loss is %.3f  generator loss is %.3f"
              % (epoch + 1, epochs, D_epochloss / train_steps, G_epochloss / train_steps))


# Start training only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
  • 1
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值