Implementing a GAN in PyTorch to Generate Handwritten Digit Images


1. Import Required Libraries

import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import torchvision
import torchvision.transforms as transforms

2. Training Set

# mini-batch size
mb_size=64
# convert the data to tensor format, which is what PyTorch expects
# (renamed to `transform` so it doesn't shadow the transforms module)
transform=transforms.Compose([transforms.ToTensor()])
# training set (download=True fetches MNIST if it is not already in ./NewData)
trainset= torchvision.datasets.MNIST(root='./NewData',download=True,train=True,transform=transform)
trainloader=torch.utils.data.DataLoader(trainset,shuffle=True,batch_size=mb_size)
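
As a quick sanity check (a minimal sketch, reusing the objects defined above; not part of the original post), you can inspect the dataset before wiring up the loader:

print(len(trainset))      # 60000 training images in MNIST
img, label = trainset[0]
print(img.shape, label)   # torch.Size([1, 28, 28]) and an integer label in 0..9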

Notes

  1. torchvision.transforms is PyTorch's image preprocessing package.
    Compose() is commonly used to chain several preprocessing steps together, for example:
transforms.Compose([
    transforms.CenterCrop(10),
    transforms.ToTensor(),
])

Other commonly used functions in transforms:

Resize: resizes the given image to the given size

Normalize: normalizes a tensor image with the given mean and standard deviation

ToTensor: converts a PIL Image (H * W * C) with values in [0, 255] to a
torch.Tensor (C * H * W) with values in [0.0, 1.0]

ToPILImage: converts a tensor to a PIL Image

Reference: https://blog.csdn.net/ftimes/article/details/105202795
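
For example, a typical MNIST pipeline chains ToTensor and Normalize (a sketch reusing the imports from step 1; the mean 0.1307 and std 0.3081 are the commonly quoted MNIST statistics, not values taken from this post):

mnist_transform = transforms.Compose([
    transforms.ToTensor(),                       # PIL image -> (C, H, W) float tensor in [0, 1]
    transforms.Normalize((0.1307,), (0.3081,)),  # per-channel (x - mean) / std
])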

3. Visualization

Reference: https://blog.csdn.net/xiongchengluo1129/article/details/79078478

# define an iterator
data_iter=iter(trainloader)
# get the next batch of images and labels
# (the old .next() method was removed in Python 3; use the built-in next())
images,labels=next(data_iter)
test=images.view(images.size(0),-1)
print(test.size())

#dims and learning rate
z_dim=100
x_dim=test.size(1)
h_dim=128
lr=0.003

def imshow(img):
    # tile the batch of images into one grid image
    im=torchvision.utils.make_grid(img)
    # convert to a numpy array for matplotlib
    npimg=im.numpy()
    plt.figure(figsize=(8,8))
    plt.imshow(np.transpose(npimg,(1,2,0)))
    plt.xticks([])
    plt.yticks([])
    plt.show()

imshow(images)

Output:
[image: a grid of the 64 MNIST training digits in the batch]
Notes:

  1. x.view(x.size(0), -1) flattens everything after the batch dimension into one dimension.
    So for a batch of 64 images of size 28 * 28, the output size is 64 * 784 (see the sketch after this list).

  2. make_grid tiles a batch of images into a single image; its padding argument controls how wide the gap between neighboring sub-images is.

  3. plt.figure() syntax:
    figure(num=None, figsize=None, dpi=None, facecolor=None, edgecolor=None, frameon=True)

num: figure number or name (a number is an id, a string is a name)
figsize: width and height of the figure in inches
dpi: resolution of the figure in dots per inch (default 80); 1 inch is about 2.54 cm, and an A4 sheet is roughly 21 * 30 cm
facecolor: background color
edgecolor: border color
frameon: whether to draw the figure frame

  4. np.transpose(npimg, (1, 2, 0)) converts the image from (channels, height, width) to (height, width, channels) so that plt.imshow() can display it.
  5. For plt.xticks() usage, see: https://blog.csdn.net/Tenderness___/article/details/82972845
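
A minimal sketch of the flattening described in note 1 (a hypothetical tensor, not from the post):

x = torch.randn(64, 1, 28, 28)   # a fake batch shaped like MNIST
flat = x.view(x.size(0), -1)     # keep the batch dim, flatten the rest
print(flat.size())               # torch.Size([64, 784])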

4. Initialize Weights and Biases

def init_weights(m):
    if type(m)==nn.Linear:
        # initialize the weights with Xavier/Glorot uniform initialization
        # (xavier_uniform is deprecated; use the in-place xavier_uniform_)
        nn.init.xavier_uniform_(m.weight)
        # set all biases to 0
        m.bias.data.fill_(0)

Reference: https://blog.csdn.net/dss_dssssd/article/details/83959474
The official PyTorch tutorial shows a similar example; a sketch in that spirit follows.
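
A minimal sketch of nn.Module.apply() (hypothetical layer sizes, reusing the init_weights function defined above):

# a tiny two-layer network; apply() visits every submodule, then the network itself
net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
net.apply(init_weights)   # runs init_weights on each nn.Linear
print(net[0].bias)        # biases are now all zeros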

5. Generator and Discriminator

class Generate(nn.Module):
    def __init__(self):
        super(Generate,self).__init__()
        # maps a z_dim noise vector to a flattened 28*28 image;
        # Sigmoid keeps outputs in [0, 1], matching ToTensor's pixel range
        self.predict=nn.Sequential(
            nn.Linear(z_dim,h_dim),
            nn.ReLU(),
            nn.Linear(h_dim,x_dim),
            nn.Sigmoid()
        )
        self.predict.apply(init_weights)

    def forward(self,input):
        return self.predict(input)

class Dis(nn.Module):
    def __init__(self):
        super(Dis,self).__init__()
        # maps a flattened image to the probability that it is real
        self.predict=nn.Sequential(
            nn.Linear(x_dim,h_dim),
            nn.ReLU(),
            nn.Linear(h_dim,1),
            nn.Sigmoid()
        )
        self.predict.apply(init_weights)

    def forward(self,input):
        return self.predict(input)

G=Generate()
D=Dis()
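
A quick shape check (a sketch, not part of the original post) confirms the two networks fit together:

z = torch.randn(4, z_dim)
fake = G(z)        # torch.Size([4, 784]): flattened fake images
score = D(fake)    # torch.Size([4, 1]): probability each image is real
print(fake.size(), score.size())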

6. Optimizer

G_solver=optim.Adam(G.parameters(),lr=lr)
D_solver=optim.Adam(D.parameters(),lr=lr)
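
A common GAN tweak (not used in this post) is to lower Adam's first momentum coefficient, which is often reported to stabilize adversarial training; a hypothetical alternative:

# beta1 = 0.5 instead of Adam's default 0.9
G_solver = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
D_solver = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))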

7. Training

for epoch in range(2):
    G_loss_run=0.0
    D_loss_run=0.0
    for i,data in enumerate(trainloader):
        # data contains the image tensors (inputs) and the label tensors
        X,label=data
        mb_size=X.size(0)
        X=X.view(X.size(0),-1)

        one_labels=torch.ones(mb_size,1)
        zero_labels=torch.zeros(mb_size,1)

        # ---- train the discriminator on real and fake images ----
        z=torch.randn(mb_size,z_dim)
        G_samples=G(z)
        D_fake=D(G_samples)
        D_real=D(X)
        D_fake_loss=F.binary_cross_entropy(D_fake,zero_labels)
        D_real_loss=F.binary_cross_entropy(D_real,one_labels)

        D_loss=D_fake_loss+D_real_loss
        D_solver.zero_grad()
        D_loss.backward()
        D_solver.step()

        # ---- train the generator on a fresh batch of noise ----
        z=torch.randn(mb_size,z_dim)
        G_samples=G(z)
        D_fake=D(G_samples)

        G_loss=F.binary_cross_entropy(D_fake,one_labels)
        G_solver.zero_grad()
        G_loss.backward()
        G_solver.step()

        # accumulate running losses so the epoch averages below are correct
        G_loss_run+=G_loss.item()
        D_loss_run+=D_loss.item()

    print('Epoch: {},   G_loss: {},   D_loss: {}'.format(epoch,G_loss_run/(i+1),D_loss_run/(i+1)))
    samples=G(z).detach()
    samples=samples.view(samples.size(0),1,28,28)
    imshow(samples)
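
After training, you would typically persist the generator and sample from it on demand; a minimal sketch (the file name is hypothetical):

# save only the generator; the discriminator is usually discarded after training
torch.save(G.state_dict(), 'gan_generator.pt')

# later: reload the weights and generate a fresh grid of digits
G2 = Generate()
G2.load_state_dict(torch.load('gan_generator.pt'))
G2.eval()
with torch.no_grad():
    imshow(G2(torch.randn(64, z_dim)).view(64, 1, 28, 28))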

Complete code:

import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import torchvision
import torchvision.transforms as transforms

torch.manual_seed(0)
mb_size=64
# convert the data to tensor format, which is what PyTorch expects
transform=transforms.Compose([transforms.ToTensor()])
# training set (download=True fetches MNIST if it is not already present)
trainset= torchvision.datasets.MNIST(root='./NewData',download=True,train=True,transform=transform)
trainloader=torch.utils.data.DataLoader(trainset,shuffle=True,batch_size=mb_size)

#可视化
#define an iterator
dataiter = iter(trainloader)
# get the next batch of images and labels
imgs, labels = next(dataiter)

test=imgs.view(imgs.size(0),-1)
print(test.size())

h_dim = 128    # number of hidden neurons in our hidden layer
Z_dim = 100    # dimension of the input noise for generator
lr = 1e-3      # learning rate
X_dim = imgs.view(imgs.size(0), -1).size(1)
print(X_dim)

def imshow(img):
    im=torchvision.utils.make_grid(img)
    npimg=im.numpy()
    plt.figure(figsize=(8,8))
    plt.imshow(np.transpose(npimg,(1,2,0)))
    plt.xticks([])
    plt.yticks([])

    plt.show()

imshow(imgs)

def xavier_init(m):
    """ Xavier initialization """
    if type(m) == nn.Linear:
        nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0)


class Gen(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(Z_dim, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, X_dim),
            nn.Sigmoid()
        )
        self.model.apply(xavier_init)

    def forward(self, input):
        return self.model(input)


class Dis(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(X_dim, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, 1),
            nn.Sigmoid()
        )
        self.model.apply(xavier_init)

    def forward(self, input):
        return self.model(input)


test = Dis()
print(test)


G=Gen()
D=Dis()

G_solver=optim.Adam(G.parameters(),lr=lr)
D_solver=optim.Adam(D.parameters(),lr=lr)

for epoch in range(20):
    G_loss_run = 0.0
    D_loss_run = 0.0
    for i, data in enumerate(trainloader):
        X, _ = data
        X = X.view(X.size(0), -1)
        mb_size = X.size(0)

        # Defining labels for real (1s) and fake (0s) images
        one_labels = torch.ones(mb_size, 1)
        zero_labels = torch.zeros(mb_size, 1)

        # Random normal distribution for each image
        z = torch.randn(mb_size, Z_dim)

        # Feed both real and fake images
        # forward through the discriminator
        D_real = D(X)
        # fakes = G(z)
        D_fake = D(G(z))

        # Defining the loss for Discriminator
        D_real_loss = F.binary_cross_entropy(D_real, one_labels)
        D_fake_loss = F.binary_cross_entropy(D_fake, zero_labels)
        D_loss = D_fake_loss + D_real_loss

        # backward propagation for discriminator
        D_solver.zero_grad()
        D_loss.backward()
        D_solver.step()

        # Feed forward for generator
        z = torch.randn(mb_size, Z_dim)
        D_fake = D(G(z))

        # loss function of generator
        G_loss = F.binary_cross_entropy(D_fake, one_labels)

        # backward propagation for generator
        G_solver.zero_grad()
        G_loss.backward()
        G_solver.step()

        G_loss_run += G_loss.item()
        D_loss_run += D_loss.item()

    # printing loss after each epoch
    print('Epoch:{},   G_loss:{},   D_loss:{}'.format(epoch, G_loss_run / (i + 1), D_loss_run / (i + 1)))

    # Plotting fake images generated after each epoch by generator
    samples = G(z).detach()
    samples = samples.view(samples.size(0), 1, 28, 28)
    imshow(samples)