StyleTransfer-PyTorch
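Style Transfer
Written by: BenVon
The 2018 edition of CS231n added new material such as style transfer, and it provides both TensorFlow and PyTorch versions of the assignments. Most solutions available online use the TensorFlow version, so this post provides a PyTorch version for later review and reference.
The approach mainly follows an article by BigDataDigest; see that article for a Chinese translation of the instructions and the TensorFlow code.
My completed PyTorch version of the assignment can be downloaded from GitHub.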
Style Transfer
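In this assignment we will implement the style transfer technique from “Image Style Transfer Using Convolutional Neural Networks” (Gatys et al., CVPR 2016).
The general idea is to take two images and produce a new image that reflects the content of one and the style of the other. We do this by defining loss functions that measure content and style in the feature spaces of a deep network, and then performing gradient descent on the pixels of the image itself.
We use SqueezeNet, a small model trained on ImageNet, as the feature extractor. Any network would work; we chose SqueezeNet here because it is small and efficient.
Here is an example of what you will be able to produce by the end of this assignment:
Setup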
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as T
import PIL
import numpy as np
from scipy.misc import imread
from collections import namedtuple
import matplotlib.pyplot as plt
from cs231n.image_utils import SQUEEZENET_MEAN, SQUEEZENET_STD
%matplotlib inline
We provide some helper functions for working with images; from this point on we will be dealing with real JPEG images rather than CIFAR-10 data.
def preprocess(img, size=512):
    transform = T.Compose([
        T.Resize(size),
        T.ToTensor(),
        T.Normalize(mean=SQUEEZENET_MEAN.tolist(),
                    std=SQUEEZENET_STD.tolist()),
        T.Lambda(lambda x: x[None]),
    ])
    return transform(img)

def deprocess(img):
    transform = T.Compose([
        T.Lambda(lambda x: x[0]),
        T.Normalize(mean=[0, 0, 0], std=[1.0 / s for s in SQUEEZENET_STD.tolist()]),
        T.Normalize(mean=[-m for m in SQUEEZENET_MEAN.tolist()], std=[1, 1, 1]),
        T.Lambda(rescale),
        T.ToPILImage(),
    ])
    return transform(img)

def rescale(x):
    low, high = x.min(), x.max()
    x_rescaled = (x - low) / (high - low)
    return x_rescaled

def rel_error(x, y):
    return np.max(np.abs(x - y) / (np.maximum(1e-8, np.abs(x) + np.abs(y))))

def features_from_img(imgpath, imgsize):
    img = preprocess(PIL.Image.open(imgpath), size=imgsize)
    img_var = img.type(dtype)
    return extract_features(img_var, cnn), img_var

# Older versions of scipy.misc.imresize yield different results
# from newer versions, so we check to make sure scipy is up to date.
def check_scipy():
    import scipy
    vnum = int(scipy.__version__.split('.')[1])
    major_vnum = int(scipy.__version__.split('.')[0])
    assert vnum >= 16 or major_vnum >= 1, "You must install SciPy >= 0.16.0 to complete this notebook."

check_scipy()
answers = dict(np.load('style-transfer-checks.npz'))
As in the previous assignment, we set `dtype` to choose between the CPU and the GPU.
dtype = torch.FloatTensor
# Uncomment the following line if you're on a machine with a GPU set up for PyTorch!
#dtype = torch.cuda.FloatTensor
# Load the pre-trained SqueezeNet model.
cnn = torchvision.models.squeezenet1_1(pretrained=True).features
cnn.type(dtype)
# We don't want to train the model any further, so we don't want PyTorch to waste computation
# computing gradients on parameters we're never going to update.
for param in cnn.parameters():
    param.requires_grad = False
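The `dtype` switch above is the older PyTorch way of choosing between CPU and GPU. A minimal sketch of the `torch.device` idiom that more recent PyTorch code commonly uses instead is shown below; it is an alternative, not what this assignment's code does, and the tiny `nn.Sequential` is a hypothetical stand-in for SqueezeNet so the snippet runs without downloading weights.

```python
import torch
import torch.nn as nn

# Pick the GPU if PyTorch can see one, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical stand-in for the SqueezeNet feature extractor (avoids a download)
cnn = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU()).to(device)

# Freeze the network: only the image will be optimized, never these weights
for param in cnn.parameters():
    param.requires_grad_(False)

# Tensors are created directly on the chosen device instead of via a dtype cast
img = torch.rand(1, 3, 16, 16, device=device)
print(img.device.type, all(not p.requires_grad for p in cnn.parameters()))
```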
# We provide this helper code which takes an image, a model (cnn), and returns a list of
# feature maps, one per layer.
def extract_features(x, cnn):
    """
    Use the CNN to extract features from the input image x.
    Inputs:
    - x: A PyTorch Tensor of shape (N, C, H, W) holding a minibatch of images that
      will be fed to the CNN.
    - cnn: A PyTorch model that we will use to extract features.
    Returns:
    - features: A list of features for the input images x extracted using the cnn model.
      features[i] is a PyTorch Tensor of shape (N, C_i, H_i, W_i); recall that features
      from different layers of the network may have different numbers of channels (C_i) and
      spatial dimensions (H_i, W_i).
    """
    features = []
    prev_feat = x
    for i, module in enumerate(cnn._modules.values()):
        next_feat = module(prev_feat)
        features.append(next_feat)
        prev_feat = next_feat
    return features
#please disregard warnings about initialization
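To see what `extract_features` returns, the sketch below runs the same loop over a hypothetical four-layer `nn.Sequential` (not SqueezeNet) and prints the shape of each collected feature map. It iterates with `cnn.children()`, which for an `nn.Sequential` visits the same modules as `cnn._modules.values()`.

```python
import torch
import torch.nn as nn

def extract_features(x, cnn):
    # Same logic as the helper above: feed x through each child module in order
    # and keep every intermediate output
    features = []
    prev_feat = x
    for module in cnn.children():
        next_feat = module(prev_feat)
        features.append(next_feat)
        prev_feat = next_feat
    return features

# Hypothetical toy network standing in for SqueezeNet's `features` block
toy_cnn = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(16, 32, kernel_size=3, padding=1),
)

x = torch.rand(1, 3, 32, 32)
feats = extract_features(x, toy_cnn)
for i, f in enumerate(feats):
    print(i, tuple(f.shape))
```

Each entry keeps the batch dimension N; the spatial size halves after the max-pool and the channel count changes after the second convolution, which is exactly the variation in C_i, H_i, W_i that the docstring warns about.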
Computing Loss
We now compute the three components of the loss function. The total loss is the sum of three parts: content loss + style loss + total variation loss.
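Written as a formula, and assuming (as in the CS231n notebook) that each term already carries its own weight, the objective and a plain gradient step on the image $x$ look like this; $\eta$ is a step size introduced only for illustration, and an adaptive optimizer such as Adam can take the place of the plain step:

$$\mathcal{L}(x) = \mathcal{L}_{c}(x) + \mathcal{L}_{s}(x) + \mathcal{L}_{tv}(x), \qquad x \leftarrow x - \eta \, \nabla_x \mathcal{L}(x)$$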
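Content Loss
We can generate an image that reflects the content of one image and the style of another by combining both goals in a single loss function: we penalize deviations from the content of the content image and deviations from the style of the style image. We then use this hybrid loss to perform gradient descent not on the parameters of the model, but on the pixel values of the image being generated.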
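Optimizing pixels rather than weights comes down to two things: the network's parameters have `requires_grad` turned off, and the image tensor itself is created with `requires_grad=True` and handed to the optimizer. The sketch below shows only that mechanism. It uses a hypothetical one-layer network in place of SqueezeNet and a single squared feature-distance term as a placeholder for the full content + style + total variation loss; the assignment's own loop may differ in optimizer settings and schedule.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical frozen feature extractor standing in for SqueezeNet
net = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU())
for p in net.parameters():
    p.requires_grad_(False)

# Features of a "target" image that the generated image should match
target_img = torch.rand(1, 3, 16, 16)
with torch.no_grad():
    target_feat = net(target_img)

# The image itself is the parameter being optimized
img = torch.rand(1, 3, 16, 16, requires_grad=True)
optimizer = torch.optim.Adam([img], lr=0.05)

def loss_fn(x):
    # Placeholder objective: squared distance between feature maps
    return ((net(x) - target_feat) ** 2).sum()

initial_loss = loss_fn(img).item()
for _ in range(200):
    optimizer.zero_grad()
    loss = loss_fn(img)
    loss.backward()   # gradients reach img; the frozen weights receive none
    optimizer.step()  # update the pixel values only

final_loss = loss_fn(img).item()
print(f"loss: {initial_loss:.3f} -> {final_loss:.3f}")
```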
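Let's first write the content loss function. Content loss measures how much the feature map of the generated image differs from the feature map of the source image. We only care about the content representation of one layer of the network (say, layer $\ell$), which has feature maps $A^\ell \in \mathbb{R}^{1 \times C_\ell \times H_\ell \times W_\ell}$. $C_\ell$ is the number of filters/channels in layer $\ell$, and $H_\ell$ and $W_\ell$ are the height and width. We will work with reshaped versions of these feature maps that combine all spatial positions into one dimension. Let $F^\ell \in \mathbb{R}^{C_\ell \times M_\ell}$ be the feature map for the current image,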