CS231n课后作业 | Assignment 3 Q4 | StyleTransfer-PyTorch 风格迁移

本文介绍了使用PyTorch实现风格迁移的过程,详细讲解了内容损失、风格损失和整体方差正则化的计算,以及如何生成具有特定风格的新图片。内容包括设置、损失函数计算、风格转移应用实例,并展示了不同迭代阶段的图像效果。
摘要由CSDN通过智能技术生成

StyleTransfer-PyTorch

风格迁移

编写:BenVon



2018年的CS231n添加了StyleTransfer等新内容,同时添加了TensorFlow和PyTorch两种版本。目前网上主流的是TensorFlow的版本,在此更新一波PyTorch版本以供日后复习参考。

解答思路主要参考了 BigDataDigest 文章 ,汉化说明及TensorFlow代码可以看看这边。
我完成的PyTorch版本作业可以到 Github 上下载。

Style Transfer

这个作业里我们将实现Image Style Transfer Using Convolutional Neural Networks” (Gatys et al., CVPR 2015)提到的风格转换技巧。

主要目的是准备两张图片,生成一张反映一张图的内容和另一张图的风格的新图。我们将通过计算深度网络中某一些特征空间中对应内容和风格的损失函数,并将梯度下降应用于图片像素本身。

我们使用SqueezeNet作为抽取特征的深度网络,这是一个在ImageNet上训练的小模型。你可以用任何网络,我们在这里选择SqueezeNet是因为它小而高效。

以下是一个你在本次作业最后可以完成的例子:

StyleTransfer_Example


设置 Setup

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as T
import PIL

import numpy as np

from scipy.misc import imread
from collections import namedtuple
import matplotlib.pyplot as plt

from cs231n.image_utils import SQUEEZENET_MEAN, SQUEEZENET_STD
%matplotlib inline

我们准备了一些用于处理图片的帮助函数, 从这里 开始我们将处理真正的JPEGs数据,而非CIFAR-10数据。

def preprocess(img, size=512):
    transform = T.Compose([
        T.Resize(size),
        T.ToTensor(),
        T.Normalize(mean=SQUEEZENET_MEAN.tolist(),
                    std=SQUEEZENET_STD.tolist()),
        T.Lambda(lambda x: x[None]),
    ])
    return transform(img)

def deprocess(img):
    transform = T.Compose([
        T.Lambda(lambda x: x[0]),
        T.Normalize(mean=[0, 0, 0], std=[1.0 / s for s in SQUEEZENET_STD.tolist()]),
        T.Normalize(mean=[-m for m in SQUEEZENET_MEAN.tolist()], std=[1, 1, 1]),
        T.Lambda(rescale),
        T.ToPILImage(),
    ])
    return transform(img)

def rescale(x):
    low, high = x.min(), x.max()
    x_rescaled = (x - low) / (high - low)
    return x_rescaled

def rel_error(x,y):
    return np.max(np.abs(x - y) / (np.maximum(1e-8, np.abs(x) + np.abs(y))))

def features_from_img(imgpath, imgsize):
    img = preprocess(PIL.Image.open(imgpath), size=imgsize)
    img_var = img.type(dtype)
    return extract_features(img_var, cnn), img_var

# Older versions of scipy.misc.imresize yield different results
# from newer versions, so we check to make sure scipy is up to date.
def check_scipy():
    import scipy
    vnum = int(scipy.__version__.split('.')[1])
    major_vnum = int(scipy.__version__.split('.')[0])

    assert vnum >= 16 or major_vnum >= 1, "You must install SciPy >= 0.16.0 to complete this notebook."

check_scipy()

answers = dict(np.load('style-transfer-checks.npz'))

就像上一个assignment,我们需要设定dtype用于选择CPU或GPU。

dtype = torch.FloatTensor
# Uncomment out the following line if you're on a machine with a GPU set up for PyTorch!
#dtype = torch.cuda.FloatTensor 
# Load the pre-trained SqueezeNet model.
cnn = torchvision.models.squeezenet1_1(pretrained=True).features
cnn.type(dtype)

# We don't want to train the model any further, so we don't want PyTorch to waste computation 
# computing gradients on parameters we're never going to update.
for param in cnn.parameters():
    param.requires_grad = False

# We provide this helper code which takes an image, a model (cnn), and returns a list of
# feature maps, one per layer.
def extract_features(x, cnn):
    """
    Use the CNN to extract features from the input image x.

    Inputs:
    - x: A PyTorch Tensor of shape (N, C, H, W) holding a minibatch of images that
      will be fed to the CNN.
    - cnn: A PyTorch model that we will use to extract features.

    Returns:
    - features: A list of feature for the input images x extracted using the cnn model.
      features[i] is a PyTorch Tensor of shape (N, C_i, H_i, W_i); recall that features
      from different layers of the network may have different numbers of channels (C_i) and
      spatial dimensions (H_i, W_i).
    """
    features = []
    prev_feat = x
    for i, module in enumerate(cnn._modules.values()):
        next_feat = module(prev_feat)
        features.append(next_feat)
        prev_feat = next_feat
    return features

#please disregard warnings about initialization

计算损失 Computing Loss

我们将进行三个组成部分的损失函数的计算。损失函数 分为三个部分的和:内容损失+风格损失+整体多样性损失。

内容损失 Content Loss

我们可以通过将损失函数组合生成一张反映一张图的内容和另一张图的风格的图片。我们希望惩罚图片内容的偏移和风格的偏移。然后使用这个混合损失函数进行梯度下降,不是对模型的参数,而是对源图的像素值。

首先我们写出内容损失函数。内容损失评估了生成图像的特征图与源图像的特征图有多大区别。我们只关心网络其中一层的内容表达(假设是层 ),有特征图 AR1×C×H×W A ℓ ∈ R 1 × C ℓ × H ℓ × W ℓ C C ℓ 层滤波器/通道的数量, H H ℓ W W ℓ 是高和宽。我们将对重构空间为一维的特征图进行运算。令 FRC×M F ℓ ∈ R C ℓ × M ℓ 表示当前图像的特征图,

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值