Pytorch-《Deep learning with pytorch》1.2.2 使用GAN将“马变成斑马”

最近在学习《Deep learning with pytorch》,跟着b站的一个up主敲代码,本篇内容对应视频
(实验在colab上完成,对此感兴趣的可以看这一篇,有使用介绍。)

实现内容:

使用GAN生成式对抗网络,将图中的马变成斑马。
在这里插入图片描述

实验准备:

实验所需要的文件可以通过百度网盘获得:

  • horse.jpg
  • horse2zebra_0.4.0.pth

链接:https://pan.baidu.com/s/1iOSDc00eZjzjwEEGS7ph7Q
提取码:oad1

实验步骤:
第一步:构建模型
import torch
import torch.nn as nn

本章节主要的学习目的是体会这些模型是做什么的,而不是怎么做的,所以这一部分的代码先不用深究,复制粘贴即可。如果你很感兴趣,也可以自己敲一遍。

#本章主要是体会模型可以做什么,暂时不用深究他是怎么实现的
class ResNetBlock(nn.Module): # <1>

    def __init__(self, dim):
        super(ResNetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim)

    def build_conv_block(self, dim):
        conv_block = []

        conv_block += [nn.ReflectionPad2d(1)]

        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),
                       nn.InstanceNorm2d(dim),
                       nn.ReLU(True)]

        conv_block += [nn.ReflectionPad2d(1)]

        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),
                       nn.InstanceNorm2d(dim)]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        out = x + self.conv_block(x) # <2>
        return out
        
 class ResNetGenerator(nn.Module):

    def __init__(self, input_nc=3, output_nc=3, ngf=64, n_blocks=9): # <3> 

        assert(n_blocks >= 0)
        super(ResNetGenerator, self).__init__()

        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf

        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=True),
                 nn.InstanceNorm2d(ngf),
                 nn.ReLU(True)]

        n_downsampling = 2
        for i in range(n_downsampling):
            mult = 2**i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
                                stride=2, padding=1, bias=True),
                      nn.InstanceNorm2d(ngf * mult * 2),
                      nn.ReLU(True)]

        mult = 2**n_downsampling
        for i in range(n_blocks):
            model += [ResNetBlock(ngf * mult)]

        for i in range(n_downsampling):
            mult = 2**(n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
                                         kernel_size=3, stride=2,
                                         padding=1, output_padding=1,
                                         bias=True),
                      nn.InstanceNorm2d(int(ngf * mult / 2)),
                      nn.ReLU(True)]

        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        model += [nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, input): # <3>
        return self.model(input)
#生成一个模型实体
netG = ResNetGenerator()

到这一步为止,我们只定义了网络的结构,网络的参数没有被训练。目前,它包含随机权重。所以下一步我们可以加载已经训练好的参数。

第二步:加载预训练所得的参数

我们将运行一个已经在马-斑马数据集上预训练的生成器模型,其训练集分别包含马和斑马的两组1068和1335张图像。文件horse2zebra_0.4.0.pth中包含了预训练好的张量参数,我们将它加载到网络模型中。

#上传 horse2zebra_0.4.0.pth,这个文件里有根据马和斑马数据集已经训练好的参数权重
from google.colab import files
files.upload()

在这里插入图片描述

#加载一个预训练的网络的参数
model_path = '/content/horse2zebra_0.4.0.pth'
model_data = torch.load(model_path)
#在Pytorch中构建好一个模型后,一般需要进行预训练权重中加载。
#torch.load_state_dict()函数就是用于将预训练的参数权重加载到新的模型之中,
netG.load_state_dict(model_data)

在这里插入图片描述至此,netG已经掌握了其在训练中所获得的全部知识。

第三步:我们选择一张图片
from PIL import Image
from torchvision import transforms
#然后我们定义了一些输入变换来确保数据以正确的形状和大小进入网络
preprocess = transforms.Compose([transforms.Resize(256),
                  transforms.ToTensor()]) 
#上传 horse.jpg
from google.colab import files
files.upload()
img = Image.open("/content/horse.jpg")
img

可以看到此时分辨率还是比较高的。
在这里插入图片描述

第四步:将图片放入模型中

首先将我们的对象转化成张量才能放入模型:

#将图片转化成了张量
img_t = preprocess(img)
#torch.unsqueeze(input,dim),参数dim表示在哪个地方加一个维度,注意dim范围在:[-input.dim()-1,input.dim()+1]之间
#比如输入input是一维,则dim=0时数据为行方向扩,dim=1时为列方向扩
batch_t = torch.unsqueeze(img_t,0)

对比一下升维度的前后变化:

img_t.shape,batch_t.shape

在这里插入图片描述

#batch_out为模型输出的结果
batch_out = netG(batch_t)
batch_out.shape

在这里插入图片描述
可以看到,对比上一篇文章中给图片分类,输出维度是【1,1000】,这里的输出维度和输入维度是一样的,因为我们需要的也是一张图片,而不是一个分类结果。

#图片是三维的,所以如果想打印出来,要先降为batch_out.data.squeeze()
out_t = (batch_out.data.squeeze()+1.0)/2.0
#将张量转化为图片
out_img = transforms.ToPILImage()(out_t)
out_img

最后来看一下我们的输出结果:
原图中的马已经被改编成了斑马,但是显然这个效果并不是那么逼真,牛仔的衣服、马的鬓毛也有部分变成了斑马纹。
在这里插入图片描述
好啦,第二次尝试告一段落~

  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值