PyTorch 实现 GAN:基于 MNIST 生成手写数字

27 篇文章 3 订阅
12 篇文章 0 订阅
import torch
import torch.nn as nn
from torch.nn import init
from torch.autograd import Variable
import torchvision
import torchvision.transforms as T
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import sampler
import torchvision.datasets as dset

import numpy as np
import os
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

# Hyper-parameters shared by the data loader, the networks and the training loop.
batch_size = 64   # images per mini-batch
noise_dim = 96    # length of each noise vector fed to the generator

# MNIST lives in ./mnist under the current working directory (download=True
# fetches it there if it is not already present).
path = os.path.join(os.getcwd(), 'mnist')
mnist_train = dset.MNIST(path, train=True, download=True, transform=T.ToTensor())
loader_train = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)

# Quick visual check of one real batch:
# imgs = next(iter(loader_train))[0].view(batch_size, 784).numpy().squeeze()
# show_images(imgs)
def show_images(images):
    """Draw a batch of flattened square images on a near-square grid.

    ``images`` is reshaped to (num_images, D). Each row is then shown as a
    ceil(sqrt(D)) x ceil(sqrt(D)) picture in its own grid cell, with axes,
    tick labels and spacing suppressed. A new pyplot figure is created but not
    displayed; the caller decides whether to ``plt.show()`` or save it.
    Rows whose length D is not a perfect square cannot be reshaped and raise.

    Returns:
        None.
    """
    flat = np.reshape(images, [images.shape[0], -1])
    grid_side = int(np.ceil(np.sqrt(flat.shape[0])))   # cells per grid row/column
    img_side = int(np.ceil(np.sqrt(flat.shape[1])))    # pixels per image row/column

    plt.figure(figsize=(grid_side, grid_side))
    grid = gridspec.GridSpec(grid_side, grid_side)
    grid.update(wspace=0.05, hspace=0.05)

    for idx, picture in enumerate(flat):
        axes = plt.subplot(grid[idx])
        plt.axis('off')
        axes.set_xticklabels([])
        axes.set_yticklabels([])
        axes.set_aspect('equal')
        plt.imshow(picture.reshape([img_side, img_side]))

# Noise generation
def sample_noise(batch_size, dim):
    """
    Generate a PyTorch Tensor of uniform random noise.

    Input:
    - batch_size: Integer giving the batch size of noise to generate.
    - dim: Integer giving the dimension of noise to generate.

    Output:
    - A PyTorch Tensor of shape (batch_size, dim) containing uniform
      random noise in the range [-1, 1).
    """
    # torch.rand draws from U[0, 1); the affine map 2u - 1 gives U[-1, 1).
    # The difference of two independent uniforms (rand() - rand()) would
    # instead be triangularly distributed, concentrated around 0.
    return 2 * torch.rand(batch_size, dim) - 1

# Flatten
class Flatten(nn.Module):
    """Collapse a 4-D (N, C, H, W) tensor to (N, C*H*W).

    The input shape is unpacked into four names, so a tensor of any other
    rank raises ValueError.
    """

    def forward(self, x):
        batch, _, _, _ = x.size()  # unpacking enforces a 4-D input
        return x.view(batch, -1)

# Reshape flat vectors into a conv-ready (N, C, H, W) layout
class Unflatten(nn.Module):
    """
    Reshape an input of shape (N, C*H*W) into (N, C, H, W).

    N defaults to -1 so ``view`` infers the batch size. C, H and W default to
    128 x 7 x 7, the same layout ``generator.forward`` reshapes to.
    """

    def __init__(self, N=-1, C=128, H=7, W=7):
        super().__init__()
        self.N, self.C, self.H, self.W = N, C, H, W

    def forward(self, x):
        target_shape = (self.N, self.C, self.H, self.W)
        return x.view(*target_shape)


class generator(nn.Module):
    """Map noise vectors (N, noise_dim) to single-channel 28x28 images.

    A fully connected stack expands the noise to 128 feature maps of 7x7.
    Two transposed convolutions then upsample 7 -> 14 -> 28. The final Tanh
    keeps pixel values in (-1, 1).

    The submodule names (``fc``, ``conv``) and their Sequential indices define
    the state_dict keys, so they must not change if saved checkpoints
    (e.g. ``g.pth``) are to keep loading.
    """

    def __init__(self, noise_dim=noise_dim):
        super().__init__()
        hidden = 7 * 7 * 128
        self.fc = nn.Sequential(
            nn.Linear(noise_dim, 1024),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(1024),
            nn.Linear(1024, hidden),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(hidden),
        )
        # ConvTranspose2d(kernel 4, stride 2, padding 1) doubles H and W.
        self.conv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(64),
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        features = self.fc(x)
        # (N, 7*7*128) -> (N, 128 channels, 7, 7)
        feature_maps = features.view(features.shape[0], 128, 7, 7)
        return self.conv(feature_maps)

# Loss definitions (least-squares GAN)
def generator_loss(scores_fake):
    """Return 0.5 * mean((scores_fake - 1)^2).

    This pushes the discriminator's scores on generated images towards the
    "real" target of 1.
    """
    residual = scores_fake - 1
    return 0.5 * residual.pow(2).mean()


class discriminator(nn.Module):
    """Score a (N, 1, 28, 28) image batch, returning raw (N, 1) scores.

    Spatial sizes: 28 -conv5-> 24 -pool2-> 12 -conv5-> 8 -pool2-> 4. The
    flattened features therefore have 64 * 4 * 4 = 1024 entries, which matches
    the first Linear layer. No output activation is applied.

    The submodule names (``conv``, ``fc``) and their Sequential indices define
    the state_dict keys, so they must not change.
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=5, stride=1),
            nn.LeakyReLU(0.01),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=5, stride=1),
            nn.LeakyReLU(0.01),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.fc = nn.Sequential(
            nn.Linear(64 * 4 * 4, 1024),
            nn.LeakyReLU(0.01),
            nn.Linear(1024, 1),
        )

    def forward(self, x):
        feats = self.conv(x)
        flat = feats.view(feats.shape[0], -1)  # (N, 1024)
        return self.fc(flat)

def discriminatro_loss(scores_real, scores_fake):
    """Least-squares discriminator loss.

    Returns 0.5 * mean((scores_real - 1)^2) + 0.5 * mean(scores_fake^2), which
    pulls real scores towards 1 and fake scores towards 0.
    (The misspelled function name is kept because callers use it.)
    """
    real_term = ((scores_real - 1) ** 2).mean()
    fake_term = (scores_fake ** 2).mean()
    return 0.5 * real_term + 0.5 * fake_term

# Optimizer factory
def get_optim(model, lr=0.001, betas=(0.5, 0.999)):
    """Build an Adam optimizer over ``model.parameters()``.

    Args:
        model: module whose parameters are optimized.
        lr: learning rate (default 0.001).
        betas: coefficients for the running averages of the gradient and of
            its square. The default (0.5, 0.999) lowers beta1 from Adam's
            usual 0.9. Weight decay is left at Adam's default of 0.

    Returns:
        torch.optim.Adam instance.
    """
    return torch.optim.Adam(model.parameters(), lr=lr, betas=betas)

def train(G, D, G_loss, D_loss, G_optim, D_optim, batch_size=64, noise_size=96,
          num_epochs=10, max_iters_per_epoch=100, print_every=100, loader=None):
    """Train generator ``G`` against discriminator ``D``.

    Each full batch gets one discriminator step followed by one generator
    step.

    Args:
        G, D: generator / discriminator modules. Both must already be on the
            same device; that device is read from G's parameters.
        G_loss: callable(scores_fake) -> scalar generator loss.
        D_loss: callable(scores_real, scores_fake) -> scalar discriminator loss.
        G_optim, D_optim: optimizers over G's / D's parameters.
        batch_size: batches of any other size (e.g. a final partial batch)
            are skipped.
        noise_size: length of each noise vector fed to G.
        num_epochs: number of passes over ``loader``.
        max_iters_per_epoch: stop each epoch after this many trained batches.
            The default of 100 keeps the original quick-demo behaviour;
            pass ``None`` to train on every batch.
        print_every: log both losses every this many trained batches.
        loader: iterable of (images, labels). Defaults to the module-level
            ``loader_train``.

    Side effects:
        After every epoch, shows 16 generated samples with matplotlib and
        saves ``G.state_dict()`` to 'g.pth'.
    """
    if loader is None:
        loader = loader_train
    # Derive the device from the model, rather than from a global ``device``
    # that only exists when this file runs as a script.
    device = next(G.parameters()).device

    for epoch in range(num_epochs):
        print('epoch', epoch)
        fake_images = None
        i = 1
        for x, _ in loader:
            # Skip batches that are not full.
            if len(x) != batch_size:
                continue

            # ---- Discriminator step ----
            D_optim.zero_grad()
            real_data = x.to(device)
            # Map real pixels to [-1, 1] to match G's Tanh output range
            # (assumes the loader yields pixels in [0, 1], as T.ToTensor does).
            logits_real = D(2 * (real_data - 0.5))
            g_fake_seed = sample_noise(batch_size, noise_size).to(device)
            # detach(): no gradient flows into G while D is being updated.
            fake_images = G(g_fake_seed).detach()
            logits_fake = D(fake_images.view(batch_size, 1, 28, 28))
            d_total_error = D_loss(logits_real, logits_fake)
            d_total_error.backward()
            D_optim.step()

            # ---- Generator step ----
            G_optim.zero_grad()
            g_fake_seed = sample_noise(batch_size, noise_size).to(device)
            fake_images = G(g_fake_seed)
            gen_logits_fake = D(fake_images.view(batch_size, 1, 28, 28))
            g_error = G_loss(gen_logits_fake)
            g_error.backward()
            G_optim.step()

            if i % print_every == 0:
                print('Iter: {}, D: {:.4}, G:{:.4}'.format(
                    i, d_total_error.item(), g_error.item()))
            if max_iters_per_epoch is not None and i >= max_iters_per_epoch:
                break
            i += 1

        # Only preview samples if at least one full batch was trained.
        if fake_images is not None:
            imgs_numpy = fake_images.detach().cpu().numpy()
            show_images(imgs_numpy[0:16])
            plt.show()
        torch.save(G.state_dict(), 'g.pth')
        print('模型已保存')


if __name__ == '__main__':
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # ---- Training (uncomment to train from scratch; writes g.pth) ----
    # G = generator().to(device)
    # D = discriminator().to(device)
    #
    # D_optim = get_optim(D)
    # G_optim = get_optim(G)
    # train(G, D, generator_loss, discriminatro_loss, G_optim, D_optim)

    # ---- Generation: load the saved generator and write one sample to a.jpg ----
    model = generator().to(device)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    model.load_state_dict(torch.load('g.pth', map_location=device))
    print(model)
    # The noise must be on the same device as the model.
    noise = sample_noise(batch_size, noise_dim).to(device)
    # NOTE: the model is left in training mode (no .eval()), so BatchNorm
    # normalises with this batch's statistics. That is why a whole batch is
    # sampled even though only one image is saved.
    with torch.no_grad():
        img = model(noise)

    plt.imsave('a.jpg', img.cpu().numpy()[0][0])

结果:

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值