今天我们学习的模型是CycleGAN,到目前为止我们需要先熟悉部分的模型,对应内部的算法和运行逻辑我并没有进行学习,大致只知道这个模型的部分结构并且这个网络在这个章节中是干什么的。对于本书的编辑作者也是想要我们想了解部分模型,并对模型产生兴趣后面的章节在进行深度的学习并了解模型的搭建。
本篇文章我们看到的是一个把马变成斑马的网络。该网络学习了一匹或多匹马的图像,从而在你输入一张图片的时候,它可以将他们全部变成斑马,图像的其余部分尽可能的不被修改。接下来的部分我们来看一下这个到底是怎么具体实现的。
- 首先我们建立一个模型
netG = ResNetGenerator()
这个模型虽然已经被建立,但它包含的是随机权重。而这个模型已经有着被训练的权重,文件名为‘horse2zebra_0.4.0.pth’,也就是已经被训练过的模型。
- 加载文件并使用load_state_dict()加载模型
model_path = './dict/horse2zebra_0.4.0.pth'
model_data = torch.load(model_path) # 转化为张量
netG.load_state_dict(model_data)
同样,上节我们也提到到eval模式,就像对resnet101所作的那样。
# 切换模型到评估模式
netG.eval()
- 预处理操作
from torchvision import transforms
# 定义预处理操作,包括调整大小和转换为张量
preprocess = transforms.Compose([transforms.Resize(256),
transforms.ToTensor()])
- 加载和预处理图像
import torch
from PIL import Image
# 加载和预处理图像
img = Image.open('./pic/horse.jpg')
img_t = preprocess(img)
batch_t = torch.unsqueeze(img_t, dim=0)
- 模型推理
# 执行模型推理
batch_out = netG(batch_t)
- 处理输出,转化为图片并显示
from torchvision import transforms
out_t = (batch_out.data.squeeze() + 1.0) / 2.0
out_img = transforms.ToPILImage()(out_t)
# out_img.save('./zebra.jpg') # 对文件进行保存
out_img.show() # 展示图片
其中还有对ResNetBlock进行定义的类,具体的实现我目前也不清楚,需要等到后面的章节原书有着相应的讲解,现在我们只需要了解建立模型和使用模型大致需要几个步骤就行了。
- 相关的代码如下
# 定义ResNetBlock类,继承自nn.Module
class ResNetBlock(nn.Module):
def __init__(self, dim):
"""
初始化ResNetBlock
:param dim: 卷积层的输入和输出通道数
"""
super(ResNetBlock, self).__init__()
self.conv_block = self.build_conv_block(dim) # 构建卷积块
def build_conv_block(self, dim):
"""
构建一个卷积块
:param dim: 卷积层的输入和输出通道数
:return: nn.Sequential包含多个卷积层的顺序容器
"""
conv_block = []
# 添加反射填充层
conv_block += [nn.ReflectionPad2d(1)]
# 添加卷积层、实例归一化层和ReLU激活函数
conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),
nn.InstanceNorm2d(dim),
nn.ReLU(True)]
# 再次添加反射填充层
conv_block += [nn.ReflectionPad2d(1)]
# 再次添加卷积层和实例归一化层
conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),
nn.InstanceNorm2d(dim)]
# 使用nn.Sequential将层组合成一个卷积块
return nn.Sequential(*conv_block)
def forward(self, x):
"""
前向传播函数
:param x: 输入张量
:return: 加上输入张量的卷积块输出
"""
out = x + self.conv_block(x) # 跳跃连接
return out
# 定义ResNetGenerator类,继承自nn.Module
class ResNetGenerator(nn.Module):
def __init__(self, input_nc=3, output_nc=3, ngf=64, n_blocks=9): # <3>
"""
初始化ResNetGenerator
:param input_nc: 输入图像的通道数
:param output_nc: 输出图像的通道数
:param ngf: 第一层卷积层的过滤器数量
:param n_blocks: ResNetBlock的数量
"""
assert(n_blocks >= 0)
super(ResNetGenerator, self).__init__()
self.input_nc = input_nc
self.output_nc = output_nc
self.ngf = ngf
# 构建模型结构
model = [nn.ReflectionPad2d(3),
nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=True),
nn.InstanceNorm2d(ngf),
nn.ReLU(True)]
n_downsampling = 2
mult = 2**n_downsampling
# 添加下采样层
for i in range(n_downsampling):
mult = 2**i
model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=True),
nn.InstanceNorm2d(ngf * mult * 2),
nn.ReLU(True)]
mult = 2**n_downsampling
# 添加ResNet Block
for i in range(n_blocks):
model += [ResNetBlock(ngf * mult)]
# 添加上采样层
for i in range(n_downsampling):
mult = 2**(n_downsampling - i)
model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
kernel_size=3, stride=2,
padding=1, output_padding=1,
bias=True),
nn.InstanceNorm2d(int(ngf * mult / 2)),
nn.ReLU(True)]
model += [nn.ReflectionPad2d(3)]
model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
model += [nn.Tanh()]
self.model = nn.Sequential(*model)
def forward(self, input):
"""
前向传播函数
:param input: 输入张量
:return: 输出张量
"""
return self.model(input)
本文总体代码如下:
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
# 定义ResNetBlock类,继承自nn.Module
class ResNetBlock(nn.Module):
def __init__(self, dim):
"""
初始化ResNetBlock
:param dim: 卷积层的输入和输出通道数
"""
super(ResNetBlock, self).__init__()
self.conv_block = self.build_conv_block(dim) # 构建卷积块
def build_conv_block(self, dim):
"""
构建一个卷积块
:param dim: 卷积层的输入和输出通道数
:return: nn.Sequential包含多个卷积层的顺序容器
"""
conv_block = []
# 添加反射填充层
conv_block += [nn.ReflectionPad2d(1)]
# 添加卷积层、实例归一化层和ReLU激活函数
conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),
nn.InstanceNorm2d(dim),
nn.ReLU(True)]
# 再次添加反射填充层
conv_block += [nn.ReflectionPad2d(1)]
# 再次添加卷积层和实例归一化层
conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),
nn.InstanceNorm2d(dim)]
# 使用nn.Sequential将层组合成一个卷积块
return nn.Sequential(*conv_block)
def forward(self, x):
"""
前向传播函数
:param x: 输入张量
:return: 加上输入张量的卷积块输出
"""
out = x + self.conv_block(x) # 跳跃连接
return out
# 定义ResNetGenerator类,继承自nn.Module
class ResNetGenerator(nn.Module):
def __init__(self, input_nc=3, output_nc=3, ngf=64, n_blocks=9): # <3>
"""
初始化ResNetGenerator
:param input_nc: 输入图像的通道数
:param output_nc: 输出图像的通道数
:param ngf: 第一层卷积层的过滤器数量
:param n_blocks: ResNetBlock的数量
"""
assert(n_blocks >= 0)
super(ResNetGenerator, self).__init__()
self.input_nc = input_nc
self.output_nc = output_nc
self.ngf = ngf
# 构建模型结构
model = [nn.ReflectionPad2d(3),
nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=True),
nn.InstanceNorm2d(ngf),
nn.ReLU(True)]
n_downsampling = 2
mult = 2**n_downsampling
# 添加下采样层
for i in range(n_downsampling):
mult = 2**i
model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=True),
nn.InstanceNorm2d(ngf * mult * 2),
nn.ReLU(True)]
mult = 2**n_downsampling
# 添加ResNet Block
for i in range(n_blocks):
model += [ResNetBlock(ngf * mult)]
# 添加上采样层
for i in range(n_downsampling):
mult = 2**(n_downsampling - i)
model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
kernel_size=3, stride=2,
padding=1, output_padding=1,
bias=True),
nn.InstanceNorm2d(int(ngf * mult / 2)),
nn.ReLU(True)]
model += [nn.ReflectionPad2d(3)]
model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
model += [nn.Tanh()]
self.model = nn.Sequential(*model)
def forward(self, input):
"""
前向传播函数
:param input: 输入张量
:return: 输出张量
"""
return self.model(input)
# 加载模型权重
netG = ResNetGenerator()
model_path = './dict/horse2zebra_0.4.0.pth'
model_data = torch.load(model_path)
netG.load_state_dict(model_data)
# 切换模型到评估模式
netG.eval()
# 定义预处理操作,包括调整大小和转换为张量
preprocess = transforms.Compose([transforms.Resize(256),
transforms.ToTensor()])
# 加载和预处理图像
img = Image.open('./pic/horse.jpg')
img_t = preprocess(img)
batch_t = torch.unsqueeze(img_t, dim=0)
# 执行模型推理
batch_out = netG(batch_t)
# 处理输出,转换为图像并显示
out_t = (batch_out.data.squeeze() + 1.0) / 2.0
out_img = transforms.ToPILImage()(out_t)
# out_img.save('./data/pic2/zebra.jpg')
out_img.show()