1. 思路
整个模型的要实现的模型是,我们有一个内容的图片和风格的图片,我们通过迁移学习得到一张新的图片。
2. 流程
- 读取内容图像和风格图像
- 预处理和后处理
- 抽取图像特征
- 定义损失函数
- 初始化合成图像
- 训练模型
3. 代码
# -*- coding: utf-8 -*-
# @Project: zc
# @Author: ZhangChu
# @File name: style_test
# @Create time: 2022/1/7 20:10
# 1. 导入相关数据库
import matplotlib.pyplot as plt
import torch
import torchvision
from torch import nn
from d2l import torch as d2l
import os
# 2. 设置图片
d2l.set_figsize()
# path = os.path.join(os.getcwd(), 'img', 'banan.jpg')
path = 'D:/zc/img/rainier.jpg'
content_img = d2l.Image.open(path)
# d2l.plt.imshow(content_img)
# plt.show()
style_img = d2l.Image.open('D:/zc/img/autumn-oak.jpg')
# d2l.plt.imshow(style_img)
# plt.show()
# 3. 张量均一化处理,rgb_mean :均值,rgb_std:方差
rgb_mean = torch.tensor([0.485, 0.456, 0.406])
rgb_std = torch.tensor([0.229, 0.224, 0.225])
# 4. 预处理,Compose组合
def preprocess(img, image_shape):
transforms = torchvision.transforms.Compose([
torchvision.transforms.Resize(image_shape),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(mean=rgb_mean, std=rgb_std)])
return transforms(img).unsqueeze(0)
# 5. 后处理,将张量的值进行反归一化处理
def postprocess(img):
# 将图片放到 GPU 上
img = img[0].to(rgb_std.device)
# 将张量进行反归一化处理,torch.clamp限制张量的值在[0,1]之间
img = torch.clamp(img.permute(1, 2, 0) * rgb_std + rgb_mean, 0, 1)
# 将张量变成PILImage图片
return torchvision.transforms.ToPILImage()(img.permute(2, 0, 1))
# 6.将vgg19模型和对应的权重下载作为预处理模型
pretrained_net = torchvision.models.vgg19(pretrained=True)
# 7.特征提取
# 选择每个卷积块的第一个卷积层作为风格层style_layers
# 选择第四个卷积块的最后一个卷积层作为内容层 content_layers
style_layers, content_layers = [0, 5, 10, 19, 28], [25]
# 8.从vgg19里面筛选出对应需要的层,具体为我们选择的层
net = nn.Sequential(*[pretrained_net.features[i] for i in
range(max(content_layers + style_layers) + 1)])
# 9. 抽取特征,返回内容层和风格层
def extract_features(X, content_layers, style_layers):
contents = []
styles = []
for i in range(len(net)):
X = net[i](X)
if i in style_layers:
styles.append(X)
if i in content_layers:
contents.append(X)
return contents, styles
# 10.对内容图像抽取内容特征
def get_contents(image_shape, device):
"""
:param image_shape: 图像的大小
:param device: GPU
:return:
"""
# 将内容图片进行预处理
content_X = preprocess(content_img, image_shape).to(device)
# 抽取特征
content_Y, _ = extract_features(content_X, content_layers, style_layers)
return content_X, content_Y
# 11. 获取风格
def get_styles(image_shape, device):
# 对风格图片进行预处理
style_X = preprocess(style_img, image_shape).to(device)
# 对风格图片进行特征抽取
_, styles_Y = extract_features(style_X, content_layers, style_layers)
return style_X, styles_Y
# 12. 均方误差来计算得到内容损失
def content_loss(Y_hat, Y):
return torch.square(Y_hat - Y.detach()).mean()
# 13.向量克拉姆矩阵,
# 主要用来衡量合成图像与风格图像在风格上的差异
# 通过计算向量X_i,X_j的内积,来判断通道i,通道j上风格特征的相关性
def gram(X):
num_channels, n = X.shape[1], X.numel() // X.shape[1]
X = X.reshape((num_channels, n))
return torch.matmul(X, X.T) / (num_channels * n)
# 14.风格损失函数
# 将计算得到的Y_hat 风格[通过gram函数得到]与标签的风格gram_Y求均方误差得到损失
def style_loss(Y_hat, gram_Y):
return torch.square(gram(Y_hat) - gram_Y.detach()).mean()
# 15.全变分损失,一种常见的去噪方法
def tv_loss(Y_hat):
return 0.5 * (torch.abs(Y_hat[:, :, 1:, :] - Y_hat[:, :, :-1, :]).mean() +
torch.abs(Y_hat[:, :, :, 1:] - Y_hat[:, :, :, :-1]).mean())
content_weight, style_weight, tv_weight = 1, 1e3, 10
# 16. 分别计算内容损失,风格损失和全变分损失;
# 通过调节这些权重超参数,我们可以权衡合成对象在保留内容,
# 迁移风格 以及去噪三个方面的相对重要性
def compute_loss(X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram):
"""
:param X: 输入的图片张量 X
:param contents_Y_hat: 通过net(x)得到的内容Y_hat
:param styles_Y_hat: 通过net(x)得到的风格 Y_hat
:param contents_Y: 标签 Y
:param styles_Y_gram: 风格标签 Y
:return:
contents_l : 内容损失;
styles_l : 风格损失
tv_l :全变分损失
l :总损失
"""
# 计算内容损失
contents_l = [content_loss(Y_hat, Y) * content_weight for Y_hat, Y in zip(
contents_Y_hat, contents_Y)]
# 计算风格损失
styles_l = [style_loss(Y_hat, Y) * style_weight for Y_hat, Y in zip(
styles_Y_hat, styles_Y_gram)]
# 计算全变分损失来去噪
tv_l = tv_loss(X) * tv_weight
# 根据一定的比例将三种损失求和得到一个总的损失
l = sum(10 * styles_l + contents_l + [tv_l])
return contents_l, styles_l, tv_l, l
# 17. 初始化合成图像,因为权重需要更新,故设置为nn.Parameter类型
class SynthesizedImage(nn.Module):
def __init__(self, img_shape, **kwargs):
super(SynthesizedImage, self).__init__(**kwargs)
self.weight = nn.Parameter(torch.rand(*img_shape))
def forward(self):
return self.weight
# 18.创建合成图像类的一个模型实例,并将其初始化为图像X
def get_inits(X, device, lr, styles_Y):
"""
:param X: 输入矩阵 X
:param device: GPU
:param lr: 学习率
:param styles_Y: 风格 Y
:return:
"""
# 实例化一个合成图像类的一个模型实例
gen_img = SynthesizedImage(X.shape).to(device)
# 通过 X.data 初始化实例 gen_img 的权重值
gen_img.weight.data.copy_(X.data)
# 定义优化器Adam ,需要更新的参数为实例的 gen_img
# 设置学习率 lr
trainer = torch.optim.Adam(gen_img.parameters(), lr=lr)
styles_Y_gram = [gram(Y) for Y in styles_Y]
return gen_img(), styles_Y_gram, trainer
# 19. 开始训练模型
def train(X, contents_Y, styles_Y, device, lr, num_epochs, lr_decay_epoch):
"""
:param X: 输入图像X的张量
:param contents_Y: 内容_Y
:param styles_Y: 风格_Y
:param device: GPU
:param lr: 学习率
:param num_epochs: 训练迭代的次数
:param lr_decay_epoch: 学习率衰减次数
:return:
"""
# 初始化一个合成实例X,styles_Y_gram :风格Y矩阵,trainer:优化器
X, styles_Y_gram, trainer = get_inits(X, device, lr, styles_Y)
# 调整学习率
scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_decay_epoch, 0.8)
animator = d2l.Animator(xlabel='epoch', ylabel='loss',
xlim=[10, num_epochs],
legend=['content', 'style', 'TV'],
ncols=2, figsize=(7, 2.5))
# 开始迭代
for epoch in range(num_epochs):
# 优化器梯度清零
trainer.zero_grad()
# 抽取特征,获得内容
contents_Y_hat, styles_Y_hat = extract_features(
X, content_layers, style_layers)
# 计算相关损失
contents_l, styles_l, tv_l, l = compute_loss(
X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram)
# 总损失 l 的梯度回传
l.backward()
# 优化器梯度更新
trainer.step()
# 学习率衰减优化
scheduler.step()
if (epoch + 1) % 10 == 0:
animator.axes[1].imshow(postprocess(X))
animator.add(epoch + 1, [float(sum(contents_l)),
float(sum(styles_l)), float(tv_l)])
return X
# 20. 设置 device = GPU,设置图片的大小
device, image_shape = d2l.try_gpu(), (300, 450)
# 21. 将神经网络放到 GPU上
net = net.to(device)
# 22.根据图片得到内容特征_X,内容特征_Y
content_X, content_Y = get_contents(image_shape, device)
# 23.抽取特征风格_Y
_, styles_Y = get_styles(image_shape, device)
# 24.开始训练模型
output = train(content_X.contens_Y, styles_Y, device, 0.3, 500, 500)